
import argparse
import numpy as np
import matplotlib.pyplot as plt
from typing import Dict, Tuple, List, Optional

from GridWorld import GridWorldMDP
from Visualizer import PolicyVisualizer
from Solver import ValueIterationSolver


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line arguments for the gridworld solver.

    Three positional arguments are expected, in this order:

    * ``maze_file`` -- path to the maze file (kept as a string).
    * ``goal_r``    -- goal row index (converted to ``int``).
    * ``goal_c``    -- goal column index (converted to ``int``).

    Args:
        argv: Argument strings to parse, *without* the program name.
            When ``None`` (the default) argparse reads ``sys.argv[1:]``,
            which is the behaviour existing zero-argument callers rely on.
            Passing an explicit list lets the parser be driven
            programmatically, e.g. from tests.

    Returns:
        The populated namespace with ``maze_file``, ``goal_r`` and
        ``goal_c`` attributes.

    Raises:
        SystemExit: Raised by argparse when arguments are missing or a
            goal index cannot be converted to ``int``.
    """
    parser = argparse.ArgumentParser(
        description="Gridworld Value Iteration Solver and Visualizer"
    )

    parser.add_argument("maze_file", type=str, help="Path to maze file")
    parser.add_argument("goal_r", type=int, help="Goal row index")
    parser.add_argument("goal_c", type=int, help="Goal column index")

    # argparse treats ``None`` as "use sys.argv[1:]", so forwarding argv
    # unchanged preserves the original command-line behaviour.
    return parser.parse_args(argv)



if __name__ == "__main__":
    # Example invocations:
    #   python3 main.py maze/small1.txt 4 4
    #   python3 main.py maze/large2.txt 23 23

    cli = parse_args()

    # Load the maze into an MDP; the goal cell is passed as (row, column),
    # taken straight from the two integer command-line arguments.
    world = GridWorldMDP.from_file(
        filepath=cli.maze_file,
        goal=(cli.goal_r, cli.goal_c),
        step_reward=-1.0,
        goal_reward=100.0,
        wall_reward=-10.0,
    )

    # Run value iteration on the MDP. The solver hands back three results,
    # which are forwarded to the visualizer below.
    value_iteration = ValueIterationSolver(
        world,
        gamma=0.99,
        theta=1e-4,
        max_iterations=1000,
    )
    state_values, best_policy, unreachable_cells = value_iteration.run(verbose=True)

    # Draw the resulting policy together with the values and the
    # states the solver reported as unreachable.
    PolicyVisualizer(world).plot_policy(
        best_policy,
        V=state_values,
        title="Optimal Policy (Value Iteration)",
        unreachable=unreachable_cells,
    )
