# Authors: Jiri Kubik (kubikji2@fel.cvut.cz), Marek Seltenhofer

import argparse
from copy import deepcopy
import datetime
import os
import json
import shutil
import time

import numpy as np

# environment-related imports
from inchworm_environment import InchwormEnvironment
from evaluator import Evaluator

# SAC-related imports
from sac_networks import ActorNetwork, CriticNetwork

import torch
import torch.optim as optim
import torch.nn.functional as F

# mushroom-rl-related imports
from mushroom_rl.algorithms.actor_critic import SAC
from mushroom_rl.core import Core
from mushroom_rl.utils.dataset import compute_J, parse_dataset

# utils
from tqdm import trange, tqdm

DEFINED_ENVIRONMENTS_PARAMETERS = {
    # Discount factor.
    "gamma": 0.99,
    # Horizon passed to InchwormEnvironment; presumably the maximum episode length in steps
    # (NOTE(review): confirm). The number of training steps per epoch is the separate
    # --n_steps command-line option, not this value.
    "horizon": 1000,
    # Number of intermediate steps between observations in InchwormEnvironment.
    "n_intermediate_steps": 1,
    # Time [s] between two consecutive environment steps, same as the robot's control period
    # (the CLI help states 1 s ~ 100 steps).
    "timestep": 1e-2,
}

DEFINED_SAC_PARAMETERS = {
    # Learning rate of the Adam optimizer training the actor (policy) networks.
    "actor_lr": 0.0001,
    # Minibatch size passed to SAC for each network update.
    "batch_size": 256,
    # Learning rate of the Adam optimizer training the critic network.
    "critic_lr": 0.0003,
    # Number of samples to collect before learning starts; the script also uses it as the
    # length of the initial training phase.
    "initial_replay_size": 5000,
    # Learning rate for the entropy coefficient (alpha).
    "lr_alpha": 0.0003,
    # Maximum number of samples kept in the replay memory.
    "max_replay_size": 50000,
    # Value passed as n_features to both the actor and the critic networks.
    "n_features": 256,
    # (Polyak) coefficient for soft updates.
    # NOTE: See "Polyak" in https://spinningup.openai.com/en/latest/algorithms/ddpg.html
    "tau": 0.005,
    # Number of samples to accumulate in the replay memory before policy fitting starts.
    "warmup_transitions": 10000,
}

# Reference evaluation length in simulation steps; presumably the episode length used by the
# BRUTE grading (NOTE(review): confirm).
BRUTE_EVALUATION_LENGTH = 3000


def merge_defined_and_user_parameters(user_args):
    """Overlay the fixed environment and SAC hyper-parameters onto the parsed CLI arguments.

    The namespace is modified in place. On a name clash the fixed value wins over the user's.
    The SAC parameters are applied after the environment ones.

    Returns:
        The same ``user_args`` object, for convenient chaining.
    """
    fixed_parameters = {**DEFINED_ENVIRONMENTS_PARAMETERS, **DEFINED_SAC_PARAMETERS}
    user_args.__dict__.update(fixed_parameters)
    return user_args


def save_learning_setup(args, time_based_identifier):
    """Create ``logs/<time_based_identifier>/`` and back up the run configuration into it.

    Writes ``config.json`` with every attribute of ``args``. Copies ``evaluator.py`` into a
    ``sources/`` sub-directory of the run directory.

    Args:
        args: namespace whose attributes must be JSON-serialisable.
        time_based_identifier: name of the run directory to create inside ``logs/``.

    Raises:
        FileExistsError: if the run directory already exists.
    """
    run_dir = os.path.join("logs", time_based_identifier)
    # makedirs creates "logs" on demand without an exists()/mkdir() race. Without exist_ok it
    # still refuses to reuse an already existing run directory.
    os.makedirs(run_dir)

    # copy configuration
    with open(os.path.join(run_dir, "config.json"), "w", encoding="utf-8") as f:
        json.dump(vars(args), f, indent=4)

    # copy sources
    sources_dir = os.path.join(run_dir, "sources")
    os.mkdir(sources_dir)
    shutil.copy("evaluator.py", sources_dir)



def setup_new_sac_agent(args, env, use_cuda) -> SAC:
    """Construct a fresh (untrained) SAC agent for ``env``.

    Args:
        args: namespace with the SAC hyper-parameters (n_features, actor_lr, critic_lr,
            batch_size, initial_replay_size, max_replay_size, warmup_transitions, tau,
            lr_alpha).
        env: environment whose ``info`` supplies the observation/action spaces used to size
            the networks, and which is passed to SAC as its MDP info.
        use_cuda: forwarded as ``use_cuda`` to every network approximator.

    Returns:
        The newly built SAC agent.
    """
    observation_shape = env.info.observation_space.shape
    action_shape = env.info.action_space.shape

    # The policy's mu and sigma approximators use identical network specs; build one and copy it.
    actor_mu_params = dict(network=ActorNetwork,
                           n_features=args.n_features,
                           input_shape=observation_shape,
                           output_shape=action_shape,
                           use_cuda=use_cuda)
    actor_sigma_params = dict(actor_mu_params)

    actor_optimizer = {'class': optim.Adam,
                       'params': {'lr': args.actor_lr}}

    # The critic input size is the observation dimension plus the action dimension.
    critic_input_shape = (observation_shape[0] + action_shape[0],)
    critic_params = dict(network=CriticNetwork,
                         optimizer={'class': optim.Adam,
                                    'params': {'lr': args.critic_lr}},
                         loss=F.mse_loss,
                         n_features=args.n_features,
                         input_shape=critic_input_shape,
                         output_shape=(1,),
                         use_cuda=use_cuda)

    # Use the ``env`` argument rather than the script-level ``inchworm_env`` global. The global
    # only exists when this file runs as __main__, and it may not be the environment passed in.
    agent = SAC(mdp_info=env.info,
                actor_mu_params=actor_mu_params,
                actor_sigma_params=actor_sigma_params,
                actor_optimizer=actor_optimizer,
                critic_params=critic_params,
                batch_size=args.batch_size,
                initial_replay_size=args.initial_replay_size,
                max_replay_size=args.max_replay_size,
                warmup_transitions=args.warmup_transitions,
                tau=args.tau,
                lr_alpha=args.lr_alpha,
                critic_fit_params=None)

    return agent


def load_sac_agent(args, use_cuda) -> SAC:
    """Load a previously saved SAC agent from ``args.agent_path_to_load``.

    When ``use_cuda`` is False, two steps keep the loaded agent off CUDA. First, ``torch.load``
    is temporarily replaced by a wrapper that forces ``map_location=cpu`` while the agent is
    deserialised. Then the ``_use_cuda`` flags stored in the critic, target-critic and policy
    approximators are cleared.

    Args:
        args: namespace providing ``agent_path_to_load``.
        use_cuda: whether the loaded agent may keep using CUDA.

    Returns:
        The loaded SAC agent.
    """
    original_torch_load = torch.load

    def _cpu_only_load(*load_args, **load_kwargs):
        # Force CPU placement. Overriding (rather than adding) map_location also avoids a
        # "multiple values for keyword argument" TypeError when the caller passes one.
        load_kwargs["map_location"] = torch.device("cpu")
        return original_torch_load(*load_args, **load_kwargs)

    if not use_cuda:
        torch.load = _cpu_only_load
    try:
        agent: SAC = SAC.load(args.agent_path_to_load)
    finally:
        # Always restore the real torch.load, even if loading fails, so the patch never leaks.
        torch.load = original_torch_load

    if not use_cuda:
        # Clear the CUDA flags stored in the critic and target-critic approximators.
        for critic in (agent._critic_approximator, agent._target_critic_approximator):
            for model in critic._impl.model._model:
                model.__dict__["_use_cuda"] = False
        # ... and in the policy's mu and sigma approximators.
        for policy_head in (agent.policy._mu_approximator, agent.policy._sigma_approximator):
            policy_head._impl.model.__dict__["_use_cuda"] = False

    return agent


def evaluate(core, args):
    """Run one evaluation pass with ``core``, print its metrics and return them.

    Args:
        core: mushroom-rl Core wrapping the agent and the environment to evaluate.
        args: namespace providing ``n_steps_eval``, ``render`` and ``quiet``.

    Returns:
        Tuple ``(J, R, E, inchworm_pos, touched_ground)``:
        - ``J``: mean discounted return.
        - ``R``: mean undiscounted return.
        - ``E``: policy entropy over the visited states.
        - ``inchworm_pos``: mean of the last (up to) 10 recorded ``inchworm-pos`` values.
        - ``touched_ground``: the ``touched-ground`` flag of the last recorded step.
        All values are plain Python float/bool.
    """
    dataset, env_info = core.evaluate(n_steps=args.n_steps_eval, render=args.render,
                                      get_env_info=True, quiet=args.quiet)
    states, *_ = parse_dataset(dataset)

    # Take the agent and the environment from ``core`` instead of the script-level globals
    # ``agent``/``inchworm_env``, which only exist when this file runs as __main__.
    agent = core.agent
    gamma = core.mdp.info.gamma

    J = float(np.mean(compute_J(dataset, gamma)))
    R = float(np.mean(compute_J(dataset)))
    E = float(agent.policy.entropy(states))
    # Average the last (up to) 10 samples. A plain ``[-10]`` index would pick a single sample,
    # making np.mean a no-op, and would raise IndexError on runs shorter than 10 steps.
    inchworm_pos = float(np.mean(np.asarray(env_info["inchworm-pos"])[-10:]))
    # bool() keeps the flag JSON-serialisable even if the environment reports a numpy bool.
    touched_ground = bool(env_info["touched-ground"][-1])

    print(f'J: {J:15.3f}, R: {R:15.3f}, Entropy: {E:7.3f}, Inchworm pos [m]: {inchworm_pos:7.3f}, {"TOUCHED GROUND" if touched_ground else "without touching ground"}')
    total_points, touch_points, non_backward_points, distance_points = get_brute_points(inchworm_pos, touched_ground)

    print("BRUTE: TOTAL POINTS {}, non-backward points {}, distance points {}, touch points {}".format(total_points, non_backward_points, distance_points, touch_points))

    return J, R, E, inchworm_pos, touched_ground


def get_brute_points(distance, touched_ground, distance_step=0.05, max_points=5):
    """Compute the BRUTE score for an evaluation run.

    Points awarded:
      * 1 point for not touching the ground,
      * 1 point for a strictly positive (non-backward) distance,
      * 1 point per full ``distance_step`` travelled forward.
    The total is clipped to ``[0, max_points]``.

    Args:
        distance: travelled distance, positive meaning forward (the evaluation printout
            labels it in metres).
        touched_ground: whether the inchworm touched the ground.
        distance_step: distance that earns one distance point (default 0.05).
        max_points: upper bound on the total (default 5).

    Returns:
        Tuple ``(total_points, touch_points, non_backward_points, distance_points)``.
        Only the total is capped; ``distance_points`` is reported uncapped.
    """
    touch_points = 0 if touched_ground else 1
    moved_forward = distance > 0
    non_backward_points = 1 if moved_forward else 0
    # Round before truncating so exact multiples survive float error:
    # 0.15 / 0.05 == 2.9999999999999996, which int() alone would turn into 2.
    distance_points = int(round(distance / distance_step, 9)) if moved_forward else 0
    total_points = min(max(touch_points + non_backward_points + distance_points, 0), max_points)
    return total_points, touch_points, non_backward_points, distance_points



# arguments


def _str_to_bool(value):
    """Parse a command-line boolean such as ``True``, ``false``, ``1`` or ``no``.

    ``type=bool`` cannot be used with argparse: ``bool("False")`` is ``True``, so
    ``--render False`` would *enable* rendering. This converter accepts the usual spellings
    instead. An empty string maps to False, as ``bool("")`` did.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean spelling.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y", "on"):
        return True
    if normalized in ("0", "false", "f", "no", "n", "off", ""):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))


parser = argparse.ArgumentParser()

# General parameters
parser.add_argument("--load_agent",
                    help="If given, previously saved agent is loaded and retrained, rather than starting from scratch.",
                    default=False,
                    type=_str_to_bool)

parser.add_argument("--agent_path_to_load",
                    help="Path to agent msh file.",
                    default="agent.msh",
                    type=str)

parser.add_argument("--epoch_save_frequency",
                    help="Frequency of model saves.",
                    default=1,
                    type=int)

parser.add_argument("--seed",
                    help="Random seed.",
                    default=42,
                    type=int)

parser.add_argument("--render",
                    help="If given renders evaluation runs during training.",
                    default=True,
                    type=_str_to_bool)

parser.add_argument("--quiet",
                    help="Suppress progress output.",
                    default=False,
                    type=_str_to_bool)

parser.add_argument("--use_cuda",
                    help="Use CUDA, if present on PC.",
                    default=False,
                    type=_str_to_bool)

# Training parameters
parser.add_argument("--n_epochs",
                    help="Number of training epochs.",
                    default=200,
                    type=int)

parser.add_argument("--n_steps",
                    help="Number of simulation steps in each epochs used for training (1 s ~ 100 steps).",
                    default=30000,
                    type=int)

# Evaluation parameters
parser.add_argument("--n_steps_eval",
                    help="Number of simulation steps in each epochs used for evaluation (1 s ~ 100 steps).",
                    default=3000,
                    type=int)

if __name__ == "__main__":
    # NOTE: this stays top-level (not wrapped in a main() function) so that ``agent`` and
    # ``inchworm_env`` remain module globals for helpers that may still reference them.

    # 1. setup
    # Second-resolution timestamp naming this run's directory under logs/.
    time_based_identifier = datetime.datetime.now().strftime("%y-%m-%d-%H-%M-%S")

    # Command-line options merged with the fixed environment/SAC hyper-parameters.
    args = merge_defined_and_user_parameters(parser.parse_args())

    # CUDA is used only if it was requested AND is available on this machine.
    use_cuda = args.use_cuda and torch.cuda.is_available()

    # 1.1. backup arguments and source files
    save_learning_setup(args, time_based_identifier)
    run_dir = os.path.join("logs", time_based_identifier)

    # 1.2. setup the evaluator
    evaluator = Evaluator()

    # 1.3. setup the environment
    inchworm_env = InchwormEnvironment(n_intermediate_steps=args.n_intermediate_steps,
                                       gamma=args.gamma,
                                       horizon=args.horizon,
                                       timestep=args.timestep,
                                       evaluator=evaluator,
                                       terminate=True,
                                       xml_path="model/inchworm.xml")

    # 1.4. setup SAC
    if args.load_agent:
        agent = load_sac_agent(args, use_cuda)
    else:
        agent = setup_new_sac_agent(args, inchworm_env, use_cuda)

    # 1.5. setup mushroom-rl core
    core = Core(agent, inchworm_env)

    # 2. initial training: collect initial_replay_size steps with a single fit at the end
    print("Initial training:")
    core.learn(n_steps=args.initial_replay_size, n_steps_per_fit=args.initial_replay_size, quiet=args.quiet)

    # 2.1 initial evaluation
    print("Initial evaluation:")
    # Clear the ground-contact flag left over from the initial training, exactly as is done
    # before every per-epoch evaluation below.
    inchworm_env.reset_touched()
    evaluate(core, args)

    # 3. training epochs
    print("Starting training epochs:")

    # Simulated time covered by one training epoch (loop-invariant).
    simulation_duration = args.n_steps * args.timestep

    for n in trange(args.n_epochs, leave=True):
        epoch = n + 1  # checkpoint/metadata files are numbered from 1
        print("Epoch {}".format(n))
        # 3.1. train
        print("  Training epoch n. {}".format(n))
        learning_start_t = time.perf_counter()
        inchworm_env.reset_touched()
        core.learn(n_steps=args.n_steps, n_steps_per_fit=1, quiet=args.quiet)
        learning_duration = time.perf_counter() - learning_start_t

        # 3.2. evaluate
        print("  Evaluating epoch n. {}".format(n))
        inchworm_env.reset_touched()
        discounted_reward, reward, entropy, distance, touched_ground = evaluate(core, args)

        # 3.3. save agent every `epoch_save_frequency` epochs and after the last epoch;
        # a frequency of 0 or less disables agent checkpoints.
        save_every = args.epoch_save_frequency
        if save_every > 0 and (epoch % save_every == 0 or epoch == args.n_epochs):
            checkpoint_path = os.path.join(run_dir, "checkpoint-{}.msh".format(epoch))
            agent.save(checkpoint_path, full_save=True)

        # 3.4. save metadata about the epoch evaluation (written every epoch)
        metadata_path = os.path.join(run_dir, "checkpoint-{}.metadata".format(epoch))
        pts_total, touch_points, forward_points, distance_points = get_brute_points(distance, touched_ground)
        # Plain float/bool casts keep json.dump from failing on numpy scalar types.
        metadata = {"discounted_reward": float(discounted_reward),
                    "reward": float(reward),
                    "entropy": float(entropy),
                    "distance": float(distance),
                    "duration": learning_duration,
                    "touched_ground": bool(touched_ground),
                    "brute_points_for_distance": distance_points,
                    "brute_points_for_touching": touch_points,
                    "brute_points_for_forward": forward_points,
                    "brute_total_points": pts_total,
                    "simulation_time": simulation_duration,
                    # None (JSON null) instead of a ZeroDivisionError when no steps were simulated.
                    "realtime_factor": (learning_duration / simulation_duration
                                        if simulation_duration > 0 else None)}
        with open(metadata_path, "w", encoding="utf-8") as f:
            json.dump(metadata, f, indent=4)
