"""
Visualization utilities for policies and value functions on GridWorldMDP environments.

Author: Vojtěch Drahý
Copyright (c) 2025 Vojtěch Drahý. All rights reserved.
"""

import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional, Set


from GridWorld import GridWorldMDP


State = Tuple[int, int]
Action = str


class PolicyVisualizer:
    """
    Visualization utilities for policies and value functions
    on GridWorldMDP environments.

    The wrapped MDP is used through ``n_rows``, ``n_cols``, a 2-D ``grid``
    array (cells equal to 1 are obstacles), ``goal`` (a ``(row, col)`` pair),
    ``ACTIONS`` (mapping an action to a ``(d_row, d_col)`` offset) and
    ``is_terminal(state)``.
    """

    # Categorical codes for the background image. Each code is an index into
    # _CELL_COLORS, so adding a category means adding one code and one colour.
    _FREE = 0
    _OBSTACLE = 1
    _UNREACHABLE = 2
    _CELL_COLORS = (
        "#FFFFFF",  # free = white
        "#888888",  # obstacle = gray
        "#FFB3B3",  # unreachable = light red
    )

    def __init__(self, mdp: GridWorldMDP):
        self.mdp = mdp

    def plot_policy(
        self,
        policy: Dict[State, Optional[Action]],
        V: Optional[Dict[State, float]] = None,
        unreachable: Optional[Set[State]] = None,
        title: str = "Policy (Value Iteration)",
        *,
        show: bool = True,
    ) -> None:
        """
        Draw the grid, the goal, one arrow per policy action and optional values.

        Row 0 of the grid is drawn at the top of the figure.

        Args:
            policy: Mapping from state ``(row, col)`` to its action, or ``None``
                when no arrow should be drawn for that state.
            V: Optional state values, printed (one decimal) in each drawn cell.
            unreachable: Optional states to shade light red; no arrow or value
                is drawn for them. Obstacles stay gray even if listed here.
            title: Axes title.
            show: If True (default), call ``plt.show()``. Pass False to leave
                the figure open, e.g. to save it with ``plt.savefig``.
        """
        from matplotlib.colors import ListedColormap

        n_rows, n_cols = self.mdp.n_rows, self.mdp.n_cols
        unreachable = unreachable if unreachable is not None else set()

        fig, ax = plt.subplots(figsize=(n_cols, n_rows))
        # vmin/vmax are centred on the integer codes so every code falls in
        # exactly one colormap bin; "nearest" prevents blending categories.
        ax.imshow(
            self._cell_codes(unreachable),
            cmap=ListedColormap(self._CELL_COLORS),
            vmin=-0.5,
            vmax=len(self._CELL_COLORS) - 0.5,
            origin="upper",
            interpolation="nearest",
        )

        self._draw_grid_lines(ax, n_rows, n_cols)
        self._draw_goal(ax)
        self._draw_policy(ax, policy, V, unreachable)

        ax.set_title(title)
        # NOTE: origin="upper" already inverts the y-axis (row 0 at the top);
        # calling invert_yaxis() here would flip the grid upside down.
        fig.tight_layout()
        if show:
            plt.show()

    def _cell_codes(self, unreachable: Set[State]) -> np.ndarray:
        """Return an int array holding the background category of every cell."""
        codes = np.full((self.mdp.n_rows, self.mdp.n_cols), self._FREE, dtype=int)
        for r, c in unreachable:
            codes[r, c] = self._UNREACHABLE
        # Applied last so obstacles remain gray even if also marked unreachable.
        codes[self.mdp.grid == 1] = self._OBSTACLE
        return codes

    @staticmethod
    def _draw_grid_lines(ax, n_rows: int, n_cols: int) -> None:
        """Draw black cell borders and blank out the tick labels."""
        # Cells are centred on integer coordinates, so borders sit at k - 0.5.
        ax.set_xticks(np.arange(-0.5, n_cols, 1), minor=True)
        ax.set_yticks(np.arange(-0.5, n_rows, 1), minor=True)
        ax.grid(which="minor", color="black", linestyle="-", linewidth=0.5)

        # Major ticks at cell centres, with their labels removed.
        ax.set_xticks(np.arange(0, n_cols, 1))
        ax.set_yticks(np.arange(0, n_rows, 1))
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    def _draw_goal(self, ax) -> None:
        """Outline the goal cell in green and label it "G"."""
        gr, gc = self.mdp.goal
        ax.add_patch(
            plt.Rectangle((gc - 0.5, gr - 0.5), 1, 1, fill=False, edgecolor="green", linewidth=2)
        )
        ax.text(gc, gr, "G", color="green", fontsize=12, ha="center", va="center", fontweight="bold")

    def _draw_policy(
        self,
        ax,
        policy: Dict[State, Optional[Action]],
        V: Optional[Dict[State, float]],
        unreachable: Set[State],
    ) -> None:
        """Draw the action arrow and (optionally) the value for each eligible state."""
        for (r, c), a in policy.items():
            # Obstacles, unreachable states and terminal states get no annotation.
            if (
                self.mdp.grid[r, c] == 1
                or (r, c) in unreachable
                or self.mdp.is_terminal((r, c))
            ):
                continue

            if a is not None:
                dr, dc = self.mdp.ACTIONS[a]
                # Arrow starts at the cell centre and spans 0.3 cell in the
                # action's (d_col, d_row) direction.
                ax.arrow(
                    c, r,
                    0.3 * dc, 0.3 * dr,
                    head_width=0.15,
                    head_length=0.15,
                    length_includes_head=True,
                    color="red",
                )

            if V is not None and (r, c) in V:
                ax.text(c, r + 0.2, f"{V[(r, c)]:.1f}", color="blue", fontsize=8, ha="center", va="center")
