"""
Markov Decision Process representation of a grid world maze.

Author: Vojtěch Drahý
Copyright (c) 2025 Vojtěch Drahý. All rights reserved.
"""

import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional


# Type aliases used in the annotations of GridWorldMDP below.
State = Tuple[int, int]   # a grid cell addressed as (row, col) indices
Action = str              # one of the ACTIONS keys: 'U', 'D', 'L', 'R'



class GridWorldMDP:
    """
    Markov Decision Process representation of a grid world maze.

    The maze is represented as a binary matrix:
    - 0 -> free cell (state is available)
    - 1 -> obstacle (state is not available)

    The agent can move in four directions: up, down, left, right.
    Attempting to walk into an obstacle or outside the grid keeps the agent
    in the same state and yields ``wall_reward`` (a large penalty by default).

    The process terminates when the agent reaches the goal state.

    Attributes
    ----------
    grid : np.ndarray
        Integer copy of the maze matrix with shape (n_rows, n_cols).
    n_rows, n_cols : int
        Grid dimensions.
    goal : tuple
        Terminal state as a (row, col) tuple.
    step_reward, goal_reward, wall_reward : float
        Rewards for a regular move, for entering the goal, and for a
        blocked move, respectively.
    """

    # Possible actions mapped to their (row delta, col delta): up, down, left, right
    ACTIONS: Dict[Action, Tuple[int, int]] = {
        'U': (-1, 0),
        'D': (1, 0),
        'L': (0, -1),
        'R': (0, 1),
    }

    def __init__(
        self,
        grid: np.ndarray,
        goal: State,
        step_reward: float = -1.0,
        goal_reward: float = 0.0,
        wall_reward: float = -100.0,
    ):
        """
        Initialize the MDP.

        Parameters
        ----------
        grid : np.ndarray
            Binary matrix with 0 for free cells and 1 for obstacles.
            Any 2-D array-like (e.g. nested lists) is accepted.
        goal : tuple[int, int]
            Goal state (row, col). Other two-element sequences are accepted
            and stored as a tuple.
        step_reward : float, default -1.0
            Reward for each non-terminal move.
        goal_reward : float, default 0.0
            Reward for entering the terminal goal state.
        wall_reward : float, default -100.0
            Reward (penalty) for attempting to walk into a wall or outside the grid.

        Raises
        ------
        ValueError
            If the grid is not two-dimensional, or if the goal is out of
            bounds or lies on an obstacle.
        """
        grid_array = np.asarray(grid)
        if grid_array.ndim != 2:
            raise ValueError(
                f"Grid must be a 2-D matrix, got {grid_array.ndim} dimension(s)."
            )
        # astype() returns a new array, so the MDP never aliases the caller's data.
        self.grid = grid_array.astype(int)
        self.n_rows, self.n_cols = self.grid.shape
        # Keep the goal as a tuple: states are tuples, and a list goal would
        # never compare equal to them, so the goal would never be terminal.
        self.goal = tuple(goal)
        self.step_reward = step_reward
        self.goal_reward = goal_reward
        self.wall_reward = wall_reward

        self._validate_states()

    @classmethod
    def from_file(
        cls,
        filepath: str,
        goal: State,
        step_reward: float = -1.0,
        goal_reward: float = 0.0,
        wall_reward: float = -100.0,
    ) -> "GridWorldMDP":
        """
        Create a GridWorldMDP from a text file containing a binary matrix.

        The file should contain rows of 0 and 1 separated by whitespace, e.g.:

        0 0 0 0 0
        0 1 1 1 0
        0 0 0 1 0
        0 1 0 0 0
        0 0 0 1 0

        Parameters
        ----------
        filepath : str
            Path to the file with the binary matrix.
        goal : tuple[int, int]
            Goal state (row, col).
        step_reward : float, optional
            Reward for each non-terminal move.
        goal_reward : float, optional
            Reward for entering the terminal goal state.
        wall_reward : float, optional
            Reward (penalty) for a blocked move.

        Returns
        -------
        GridWorldMDP
        """
        # ndmin=2 keeps single-row and single-column mazes two-dimensional;
        # without it loadtxt returns a 1-D array and construction fails.
        grid = np.loadtxt(filepath, ndmin=2)
        return cls(
            grid=grid,
            goal=goal,
            step_reward=step_reward,
            goal_reward=goal_reward,
            wall_reward=wall_reward,
        )

    def _validate_states(self) -> None:
        """
        Validate that the goal is within bounds and not an obstacle.

        Raises
        ------
        ValueError
            If either condition is violated.
        """
        # Bounds must be checked first: is_obstacle indexes the grid directly.
        if not self.in_bounds(self.goal):
            raise ValueError(f"Goal state {self.goal} is out of bounds.")
        if self.is_obstacle(self.goal):
            raise ValueError("Goal state is an obstacle.")

    def in_bounds(self, state: State) -> bool:
        """Check if a state is inside the grid boundaries."""
        r, c = state
        return 0 <= r < self.n_rows and 0 <= c < self.n_cols

    def is_obstacle(self, state: State) -> bool:
        """
        Return True if the state is an obstacle (1 in the grid).

        The grid is indexed directly, so callers should check ``in_bounds``
        first: indices past the grid raise IndexError and negative indices
        wrap around.
        """
        r, c = state
        # bool() turns numpy's np.bool_ into the plain bool the signature promises.
        return bool(self.grid[r, c] == 1)

    def is_terminal(self, state: State) -> bool:
        """Return True if the state is the goal state."""
        # tuple() lets list-like states compare correctly against the tuple goal.
        return tuple(state) == self.goal

    def get_all_states(self) -> List[State]:
        """
        Return a list of all valid states (non-obstacle cells) in the grid.

        States are listed in row-major order as (row, col) tuples of ints.
        """
        # argwhere yields the indices of matching cells in row-major order.
        free_cells = np.argwhere(self.grid != 1)
        return [(int(r), int(c)) for r, c in free_cells]

    def get_actions(self, state: State) -> List[Action]:
        """
        Return available actions in the given state.

        For simplicity, we always allow all four directions unless the state is terminal.
        """
        if self.is_terminal(state):
            return []  # no actions from the terminal (goal) state
        return list(self.ACTIONS)

    def transition(self, state: State, action: Action) -> Tuple[State, float, bool]:
        """
        Deterministic transition function.

        Parameters
        ----------
        state : tuple[int, int]
            Current state (row, col).
        action : str
            One of 'U', 'D', 'L', 'R'.

        Returns
        -------
        next_state : tuple[int, int]
        reward : float
        done : bool
            True if next_state is terminal (goal).

        Raises
        ------
        KeyError
            If ``action`` is not one of the keys of ``ACTIONS``.
        """
        if self.is_terminal(state):
            # Once in terminal state, you stay there with zero reward
            return state, 0.0, True

        dr, dc = self.ACTIONS[action]
        r, c = state
        candidate = (r + dr, c + dc)

        # The bounds check must come first so is_obstacle never sees an
        # index outside the grid.
        if not self.in_bounds(candidate) or self.is_obstacle(candidate):
            # Invalid move -> stay in place, apply wall penalty
            return state, self.wall_reward, False

        done = self.is_terminal(candidate)
        reward = self.goal_reward if done else self.step_reward
        return candidate, reward, done