# import communication messages
from messages import *
import os

import matplotlib.pyplot as plt
import numpy as np

# import ground map
from riskmap import *
import dubins

import datetime

"""
    Plot the given impact map
"""
def plot_impact_prob(impact_map, ground_map,title=None, lim=None):
    map_end = ground_map.GridToENU(impact_map.offset_idx + np.array(impact_map.prob_map.shape))

    ys = np.arange(ground_map.GridToENU(impact_map.offset_idx)[0], map_end[0], ground_map.resolution)
    xs = np.arange(ground_map.GridToENU(impact_map.offset_idx)[1], map_end[1], ground_map.resolution)
    
    pcol = plt.pcolormesh(xs, ys, impact_map.prob_map, shading='nearest')
    pcol.set_edgecolor('face')
    
    plt.axis('equal')
    plt.xlabel("x [m]")
    plt.ylabel("y [m]")
    if title is not None:
        plt.title(title)
    
    if lim is not None:
        plt.xlim([-lim, lim])
        plt.ylim([-lim, lim])
    plt.show()

def clean_plot(space):
    """
    Clear the current figure and return the axes to draw ``space`` on.

    For the planar spaces 'R2' and 'SE(2)', the current 2-D axes are reused
    with an equal aspect ratio. Any other space gets a new 3-D subplot viewed
    from elevation 1 and azimuth 1. In both cases the x/y axes are labelled.
    """
    plt.clf()
    is_planar = space in ('R2', 'SE(2)')
    if not is_planar:
        axes = plt.gcf().add_subplot(111, projection='3d')
        axes.view_init(1, 1)
    else:
        axes = plt.gca()
        plt.axis('equal')
    plt.xlabel("x")
    plt.ylabel("y")
    return axes

"""
    Convert the given Dubins path to a standard Path Msg
"""
def dubins_to_path_msg(path, flight_alt, sampling_step):
    samples, _ = path.sample_many(sampling_step)
    poses = []
    for s in samples:
        position = Vector3(s[0], s[1], flight_alt)
        orientation = Quaternion()
        orientation.from_Euler(s[2], 0, 0)
        poses.append(Pose(position, orientation))

    msg = Path(poses=poses)
    return msg

if __name__ == "__main__":
    resolution = 10
    width = 1500
    height = 1500
    origin = Pose(Vector3(0, 0, 0), Quaternion(0, 0, 0, 1))
    gm = GroundMap(resolution, width, height, origin)
    
    layers = {'height':os.getcwd() + "/data/height_raw.png", 'density':os.getcwd() + "/data/density_raw.png", 'shelter':os.getcwd() + "/data/shelter_raw.png"}
    gm.populate_map(layers)

    a = Aircraft(0.2, 10, 0.7, 0.2, 0.5)
    rm = BallisticRiskMap(gm, a, 3, 8, 1000, 100, 15, [-400, -600, -800], 10/180 * np.pi)

    # # Plot ballistic fall maps
    # for hdg in rm.risk_map.keys():
    #     for alt in rm.risk_map[hdg].keys():
    #         plot_impact_prob(rm.risk_map[hdg][alt], gm, title="Hdg: " + str(hdg) + ", alt: " + str(alt)) 

    plt.figure(1)
    ax = clean_plot('SE(2)')
    xs = np.arange(gm.resolution / 2, gm.width * gm.resolution, gm.resolution)
    ys = np.arange(gm.resolution / 2, gm.height * gm.resolution, gm.resolution)
    pcol = plt.pcolormesh(ys, xs, gm.layers["height"], shading='nearest')
    pcol.set_edgecolor('face')

    r = 50. # m
    sampling_step = 10. # m
    flight_alt = 650. # m

    dubins = dubins.shortest_path((6100., 6460, -3.14), (5685, 5895, 1.4587), r)
    path_msg = dubins_to_path_msg(dubins, flight_alt, sampling_step)
    path_msg.plot(ax, style='point')
    print("Maneuver risk: {}".format(rm.get_risk(path_msg)))

    dubins = dubins.shortest_path((6350., 12450, 1.48), (6650, 13350, 1.4587), r)
    path_msg = dubins_to_path_msg(dubins, flight_alt, sampling_step)
    path_msg.plot(ax, style='point')
    print("Maneuver risk: {}".format(rm.get_risk(path_msg)))


    plt.show()

    
