# Authors: Jiri Kubik (kubikji2@fel.cvut.cz), Marek Seltenhofer

import math
from pathlib import Path
import numpy as np

from mushroom_rl.environments.mujoco import MuJoCo, ObservationType
# observation types are located in mushroom_rl.utils.mujoco.observation_helper
# MuJoCo is located in mushroom_rl.environments.mujoco
"""
An enum indicating the type of data that should be added to the observation
of the environment, can be Joint-/Body-/Site- positions, rotations, and velocities.
The observation types return the following:
    BODY_POS: (3,) x, y, z position of the body
    BODY_ROT: (4,) quaternion of the body, scalar-first (w, x, y, z) as in MuJoCo
    BODY_VEL: (6,) first angular velocity around x, y, z. Then linear velocity for x, y, z
    JOINT_POS: (1,) rotation of the joint OR (7,) position, quaternion of a free joint
    JOINT_VEL: (1,) velocity of the joint OR (6,) FIRST linear then angular velocity !different to BODY_VEL!
    SITE_POS: (3,) x, y, z position of the site
    SITE_ROT: (9,) rotation matrix of the site
"""
# MuJoCo Units: https://mujoco.readthedocs.io/en/stable/overview.html#units-are-unspecified

from inchworm_interface import InchwormInterface

# friction model
import general_friction_model

# quaternion to Euler angle conversion
from scipy.spatial.transform import Rotation

class InchwormEnvironment(MuJoCo, InchwormInterface):
    """mushroom-rl MuJoCo environment of the inchworm robot.

    Besides the mushroom-rl ``MuJoCo`` machinery, this class implements the
    ``InchwormInterface`` query methods (part / joint state and contacts) and a
    per-step friction update driven by ``general_friction_model``. Reward and
    absorbing-state decisions are delegated to the ``evaluator`` object.
    """

    # Number of servomotor joints / actuators in the model.
    _N_JOINTS = 4

    # Bodies whose position, rotation and velocity are registered as
    # additional data (in registration order).
    _TRACKED_BODIES = ("bracket-front", "bracket-middle", "bracket-back",
                       "bumper-front", "bumper-back",
                       "scales-front", "scales-back",
                       "servo-0", "servo-1", "servo-2", "servo-3")

    # Data-name suffix and observation type recorded for every tracked body.
    _BODY_OBSERVATIONS = (("position", ObservationType.BODY_POS),
                          ("rotation", ObservationType.BODY_ROT),
                          ("velocity", ObservationType.BODY_VEL))

    # Collision groups that consist of exactly one geom named "<group>-collision".
    _SINGLE_GEOM_GROUPS = ("bumper-front", "scales-front",
                           "bumper-back", "scales-back",
                           "bracket-front", "bracket-middle", "bracket-back",
                           "servo-0", "servo-1", "servo-2", "servo-3")

    # Geoms of the "no-touch" collision group (parts that should not hit the
    # ground); they also always receive the stiff friction coefficient.
    _NO_TOUCH_GEOMS = ("bracket-front-collision",
                       "bracket-middle-collision",
                       "bracket-back-collision",
                       "servo-0-collision",
                       "servo-1-collision",
                       "servo-2-collision",
                       "servo-3-collision")

    def __init__(self,
                 gamma,
                 horizon,
                 n_intermediate_steps,
                 timestep,
                 evaluator,
                 terminate=True,
                 xml_path: str = (Path(__file__).resolve().parent / "model" / "inchworm.xml").as_posix()):
        """Build the observation/collision specs and the underlying MuJoCo env.

        :param gamma: forwarded unchanged to the mushroom-rl ``MuJoCo`` constructor.
        :param horizon: forwarded unchanged to the mushroom-rl ``MuJoCo`` constructor.
        :param n_intermediate_steps: forwarded unchanged to the mushroom-rl ``MuJoCo`` constructor.
        :param timestep: forwarded unchanged to the mushroom-rl ``MuJoCo`` constructor.
        :param evaluator: object providing ``compute_reward(env)`` and ``is_absorbing(env)``.
        :param terminate: stored as ``self.terminate``; not used inside this class.
        :param xml_path: MJCF model file; defaults to ``model/inchworm.xml`` located
            in the same directory as this source file.
        """
        self.evaluator = evaluator
        self.terminate = terminate

        joint_ids = range(self._N_JOINTS)

        # NOTE: actuators are named "joint-<i>" while the observed joints are
        # "joint-servomotor-<i>"; both must match the names in the XML model.
        actuation_spec = ["joint-{}".format(i) for i in joint_ids]

        observation_spec = [("servomotor-{}".format(i), "joint-servomotor-{}".format(i), ObservationType.JOINT_POS)
                            for i in joint_ids]

        # Order: joint positions, joint velocities, then for each tracked body
        # its position, rotation and velocity.
        additional_data_spec = (
            [("joint-{}-position".format(i), "joint-servomotor-{}".format(i), ObservationType.JOINT_POS)
             for i in joint_ids]
            + [("joint-{}-velocity".format(i), "joint-servomotor-{}".format(i), ObservationType.JOINT_VEL)
               for i in joint_ids]
            + [("{}-{}".format(body, obs_name), body, obs_type)
               for body in self._TRACKED_BODIES
               for obs_name, obs_type in self._BODY_OBSERVATIONS]
        )

        # Data names accepted by the get_part_* / get_joint_* query methods.
        self.valid_data = [name for name, _, _ in additional_data_spec]

        collision_groups = ([(group, [group + "-collision"]) for group in self._SINGLE_GEOM_GROUPS]
                            + [("ground", ["floor"]),
                               ("no-touch", list(self._NO_TOUCH_GEOMS))])

        # Collision-group names accepted by is_touching.
        self.valid_colliders = [group for group, _ in collision_groups]

        # NOTE(review): meaning of the arguments (0.1, 1.5) is defined by
        # GeneralFrictionModel — confirm there before retuning.
        self.general_friction_model = general_friction_model.GeneralFrictionModel(0.1, 1.5)

        # x-coordinate of "servo-3" captured on the first distance query.
        self.init_dist = None
        # Latched flag: set once a "no-touch" geom contacts the ground; cleared by reset_touched().
        self.touched_ground = False

        super().__init__(xml_file=xml_path,
                         actuation_spec=actuation_spec,
                         observation_spec=observation_spec,
                         gamma=gamma,
                         horizon=horizon,
                         additional_data_spec=additional_data_spec,
                         n_intermediate_steps=n_intermediate_steps,
                         timestep=timestep,
                         collision_groups=collision_groups)

    ####################
    #                  #
    #  MUJOCO OVERIDE  #
    #                  #
    ####################

    def _simulation_pre_step(self):
        """Update geom friction coefficients before a simulation step.

        Each scale geom gets the forward soft friction when its segment moves
        in +x and the backward soft friction otherwise; bumpers and every
        "no-touch" geom get the stiff friction.
        """
        # 1. BODY_VEL is (angular x, y, z, linear x, y, z): index 3 is linear x.
        speed_front = self._read_data("servo-0-velocity")[3]
        speed_back = self._read_data("servo-3-velocity")[3]

        # 2. advance the friction model and query this frame's coefficients
        friction = self.general_friction_model
        friction.step()
        mu_soft_forward = friction.get_soft_friction_forward()
        mu_soft_backward = friction.get_soft_friction_backward()
        mu_hard = friction.get_stiff_friction()

        # 3. contact segments (only the first friction component is changed)
        model = self._model
        model.geom("scales-front-collision").friction[0] = mu_soft_forward if speed_front > 0 else mu_soft_backward
        model.geom("bumper-front-collision").friction[0] = mu_hard

        model.geom("scales-back-collision").friction[0] = mu_soft_forward if speed_back > 0 else mu_soft_backward
        model.geom("bumper-back-collision").friction[0] = mu_hard

        # 4. every other collision geom
        for geom_name in self._NO_TOUCH_GEOMS:
            model.geom(geom_name).friction[0] = mu_hard

    def reward(self, obs, action, next_obs, absorbing):
        """mushroom-rl reward hook; delegates to ``evaluator.compute_reward(self)``."""
        return self.evaluator.compute_reward(self)

    def is_absorbing(self, state):
        """mushroom-rl absorbing-state hook; delegates to ``evaluator.is_absorbing(self)``."""
        return self.evaluator.is_absorbing(self)

    ########################
    #                      #
    #  Inchworm Interface  #
    #                      #
    ########################

    def quaternion_to_euler(self, quat, degrees, scalar_first=False):
        """Convert a quaternion to intrinsic ``"XYZ"`` Euler angles.

        :param quat: quaternion; scalar-last ``(x, y, z, w)`` (scipy's default)
            unless ``scalar_first`` is True.
        :param degrees: return degrees if True, radians otherwise.
        :param scalar_first: set to True for scalar-first ``(w, x, y, z)``
            quaternions such as MuJoCo's body orientations.
        :return: array of three Euler angles.
        """
        if scalar_first:
            # (w, x, y, z) -> (x, y, z, w), the layout scipy expects by default
            quat = np.roll(np.asarray(quat), -1, axis=-1)
        # see: https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.transform.Rotation.as_euler.html
        return Rotation.from_quat(quat).as_euler("XYZ", degrees)

    def _read_valid_data(self, key):
        """Return ``self._read_data(key)``, or ``None`` if ``key`` is not in ``valid_data``."""
        if key not in self.valid_data:
            return None
        return self._read_data(key)

    def get_part_position(self, part_name):
        """Return the (3,) x, y, z position of a tracked body, or ``None`` if unknown."""
        return self._read_valid_data(part_name + "-position")

    def get_part_rotation(self, part_name, degrees=True):
        """Return the intrinsic XYZ Euler angles of a tracked body, or ``None`` if unknown."""
        quat = self._read_valid_data(part_name + "-rotation")
        if quat is None:
            return None
        # MuJoCo body quaternions are scalar-first (w, x, y, z).
        return self.quaternion_to_euler(quat, degrees, scalar_first=True)

    def get_part_velocity(self, part_name):
        """Return the (6,) body velocity (angular x, y, z then linear x, y, z), or ``None`` if unknown."""
        return self._read_valid_data(part_name + "-velocity")

    def get_joint_position(self, joint_name, degrees=True):
        """Return the joint angle (degrees by default, else radians), or ``None`` if unknown."""
        value = self._read_valid_data(joint_name + "-position")
        if value is None:
            return None
        return value[0] * (180 / math.pi if degrees else 1)

    def get_joint_velocity(self, joint_name):
        """Return the scalar joint velocity, or ``None`` if unknown."""
        value = self._read_valid_data(joint_name + "-velocity")
        if value is None:
            return None
        return value[0]

    def is_touching(self, part_name_1, part_name_2):
        """Return whether two collision groups are in contact, or ``None`` if a name is unknown."""
        if part_name_1 not in self.valid_colliders or part_name_2 not in self.valid_colliders:
            return None
        return self._check_collision(part_name_1, part_name_2)

    ################
    #              #
    #  Evaluation  #
    #              #
    ################

    def _get_current_distance(self):
        """Return the x displacement of "servo-3" since the first call.

        The reference x-coordinate is captured lazily on the first call and
        stored in ``self.init_dist``.
        """
        current_x = self.get_part_position("servo-3")[0]
        if self.init_dist is None:
            self.init_dist = current_x
        return current_x - self.init_dist

    def _touched_ground(self):
        """Update and return the latched "no-touch" parts hit the ground flag."""
        self.touched_ground |= self.is_touching("ground", "no-touch")
        return self.touched_ground

    def reset_touched(self):
        """Clear the latched ground-contact flag."""
        self.touched_ground = False

    def _create_info_dictionary(self, obs):
        """mushroom-rl info hook: report the travelled distance and ground-contact flag."""
        return {"inchworm-pos": self._get_current_distance(),
                "touched-ground": self._touched_ground()}




