
import numpy as np
from typing import Dict, Tuple, List, Optional, Set

from GridWorld import GridWorldMDP

State = Tuple[int, int]
Action = str


class ValueIterationSolver:
    """
    Value Iteration solver for a deterministic 2D GridWorld MDP.

    ``run()`` performs three steps:

    -----------------------------------------------------------------------
    (1) VALUE ITERATION
        Repeatedly applies the Bellman optimality backup

            V(s) <- max_a [ r(s, a) + gamma * V(s') ]

        in place over every non-terminal state (no bootstrap term is added
        when the transition reports ``done``). Sweeps stop once the largest
        absolute change of V within one sweep is below ``theta``, or after
        ``max_iterations`` sweeps. Terminal states and states with no
        available action are never updated and keep their initial value of
        0.0. If per-step rewards are negative, states that cannot get to a
        rewarding region drift toward negative values.

    (2) GREEDY POLICY EXTRACTION
        Derives a deterministic policy π(s) that chooses, in each state, the
        action maximising the one-step lookahead value. Ties are broken by
        the order in which ``mdp.get_actions(s)`` lists the actions. States
        that are terminal or have no valid actions get π(s) = None.

    (3) POLICY-BASED UNREACHABILITY ANALYSIS
        A state s is classified as *unreachable* if:
            - its value V(s) < 0 (the optimal expected return is negative),
              AND
            - following the greedy policy π repeatedly from s never visits
              any state t with positive value V(t) > 0 (the walk either
              enters a cycle or hits a state whose action is None).

        In other words:
            "Negative-value states that are unable to reach any positive-valued
             region under their own optimal policy are considered unreachable."

        This analysis uses only the final value function, the greedy policy
        and ``mdp.transition`` to follow that policy; it performs no
        reverse-graph search.

    -----------------------------------------------------------------------
    Expected MDP object (GridWorldMDP) interface used by this solver:
        - mdp.get_all_states() -> List[State]
        - mdp.get_actions(state) -> List[Action]
        - mdp.transition(state, action) -> (next_state, reward, done)
        - mdp.is_terminal(state) -> bool
    Every ``next_state`` returned by ``transition`` must be one of
    ``get_all_states()``. (``mdp.goal`` / ``mdp.grid`` are not accessed.)

    -----------------------------------------------------------------------
    Outputs of run():
        V : Dict[State, float]
            Final value function (converged, or after ``max_iterations``).

        policy : Dict[State, Optional[Action]]
            Deterministic greedy policy. States with no valid action or
            terminal states map to None.

        unreachable_states : Set[State]
            States that:
                - have V(s) < 0,
                - and cannot, by following π, reach any state t with V(t) > 0.

    -----------------------------------------------------------------------
    Notes:
    - Given a deterministic MDP, the whole procedure is deterministic.
    - Unreachable states are defined behaviourally, based on the final policy.
    """

    def __init__(
        self,
        mdp: GridWorldMDP,
        gamma: float = 0.99,
        theta: float = 1e-4,
        max_iterations: int = 1000,
    ):
        """
        Parameters
        ----------
        mdp : GridWorldMDP
            Environment exposing the interface listed in the class docstring.
        gamma : float
            Discount factor, must lie in [0, 1].
        theta : float
            Convergence threshold on the largest per-sweep change of V.
            Must be non-negative.
        max_iterations : int
            Upper bound on the number of value-iteration sweeps. Must be
            non-negative.

        Raises
        ------
        ValueError
            If ``gamma``, ``theta`` or ``max_iterations`` is out of range.
        """
        if not 0.0 <= gamma <= 1.0:
            raise ValueError(f"gamma must be in [0, 1], got {gamma!r}")
        if theta < 0:
            raise ValueError(f"theta must be non-negative, got {theta!r}")
        if max_iterations < 0:
            raise ValueError(
                f"max_iterations must be non-negative, got {max_iterations!r}"
            )
        self.mdp = mdp
        self.gamma = gamma
        self.theta = theta
        self.max_iterations = max_iterations

    # ----------------------------------------------------------------------
    # Public entry point
    # ----------------------------------------------------------------------
    def run(
        self, verbose: bool = False
    ) -> Tuple[Dict[State, float], Dict[State, Optional[Action]], Set[State]]:
        """
        Execute Value Iteration, extract a greedy policy, and use the (V, policy)
        pair to detect unreachable states.

        Parameters
        ----------
        verbose : bool
            If True, progress information (iteration number and max update)
            is printed during value iteration.

        Returns
        -------
        V : Dict[State, float]
            State-value function after value iteration.

        policy : Dict[State, Optional[Action]]
            Greedy deterministic policy derived from V.

        unreachable_states : Set[State]
            States with V(s) < 0 which cannot reach any V > 0 state when following π.
        """
        states = self.mdp.get_all_states()

        # Initialize all values to zero; terminal states are never updated,
        # so they keep this value.
        V: Dict[State, float] = {s: 0.0 for s in states}

        self._value_iteration(states, V, verbose)
        policy = self._extract_policy(states, V)
        unreachable_states = self._find_unreachable(states, V, policy)

        return V, policy, unreachable_states

    # ----------------------------------------------------------------------
    # Helpers
    # ----------------------------------------------------------------------
    def _q_value(self, V: Dict[State, float], state: State, action: Action) -> float:
        """Return the one-step lookahead value r + gamma * V(s') of (state, action)."""
        next_state, reward, done = self.mdp.transition(state, action)
        # An episode-ending transition has no future value to bootstrap from.
        if done:
            return float(reward)
        return reward + self.gamma * V[next_state]

    def _value_iteration(
        self, states: List[State], V: Dict[State, float], verbose: bool
    ) -> None:
        """Apply in-place Bellman optimality sweeps to V until convergence."""
        for iteration in range(1, self.max_iterations + 1):
            max_delta = 0.0
            for s in states:
                if self.mdp.is_terminal(s):
                    continue
                actions = self.mdp.get_actions(s)
                if not actions:
                    # Nothing to maximise over: keep the current value.
                    continue
                best = max(self._q_value(V, s, a) for a in actions)
                max_delta = max(max_delta, abs(best - V[s]))
                # In-place update: later states in this sweep already see it.
                V[s] = best

            if verbose:
                print(f"[ValueIteration] iteration {iteration}: max update = {max_delta:.6g}")

            if max_delta < self.theta:
                if verbose:
                    print(f"[ValueIteration] converged after {iteration} iterations")
                return

        if verbose and self.max_iterations > 0:
            print(
                f"[ValueIteration] stopped after max_iterations={self.max_iterations} "
                "without reaching theta"
            )

    def _extract_policy(
        self, states: List[State], V: Dict[State, float]
    ) -> Dict[State, Optional[Action]]:
        """Return the greedy policy w.r.t. V (None for terminal / action-less states)."""
        policy: Dict[State, Optional[Action]] = {}
        for s in states:
            if self.mdp.is_terminal(s):
                policy[s] = None
                continue
            actions = self.mdp.get_actions(s)
            if not actions:
                policy[s] = None
                continue
            # max() returns the first maximiser, so ties are broken by the
            # order of get_actions(s), keeping the policy deterministic.
            policy[s] = max(actions, key=lambda a, s=s: self._q_value(V, s, a))
        return policy

    def _find_unreachable(
        self,
        states: List[State],
        V: Dict[State, float],
        policy: Dict[State, Optional[Action]],
    ) -> Set[State]:
        """Return negative-value states whose greedy trajectory never hits V > 0."""
        # Cache: state -> whether following π from it reaches a V > 0 state.
        # Transitions are deterministic, so every state on one walk shares
        # the walk's outcome; this keeps the whole analysis O(|S|).
        reaches_positive: Dict[State, bool] = {}
        unreachable_states: Set[State] = set()

        for start in states:
            # Only negative-valued states are candidates for "unreachable".
            if V[start] >= 0:
                continue

            path: List[State] = []
            on_path: Set[State] = set()
            current = start
            while True:
                if current in reaches_positive:
                    outcome = reaches_positive[current]
                    break
                if V[current] > 0:
                    # Reached a value-positive region.
                    outcome = True
                    break
                if current in on_path:
                    # Loop detected: no positive-value state is reachable.
                    outcome = False
                    break
                action = policy.get(current)
                if action is None:
                    # Dead end (terminal or no valid action).
                    outcome = False
                    break
                path.append(current)
                on_path.add(current)
                current, _, _ = self.mdp.transition(current, action)

            for s in path:
                reaches_positive[s] = outcome
            if not outcome:
                unreachable_states.add(start)

        return unreachable_states
