from math import radians, degrees

import numpy as np
from drw_tools import confidence_ellipse
from drw_tools import plot_gaussian_1d
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import scipy
import scipy.optimize

# Print NumPy arrays compactly: 3 decimals, fixed-point notation (no scientific notation).
np.set_printoptions(suppress=True, precision=3)


def g(state, control):
    """Motion model (state transition function).

    Advances the angle ``state[0]`` by the increment ``control[0]`` and returns
    the result as a new one-element array; the inputs are left untouched.
    """
    next_theta = state[0] + control[0]
    return np.array([next_theta])


def h(state):
    """Measurement function (projection from state space to measurement space).

    Maps the angle ``state[0]`` to its point on the unit circle and returns
    ``[cos(theta), sin(theta)]`` as a two-element array.
    """
    angle = state[0]
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    return np.array([cos_a, sin_a])


def ekf(z_t, x_t_minus_1_posterior, u_t, P_t_minus_1, R_t, Q_t, ax_1d, ax_2d):
    """Run one extended-Kalman-filter predict/update step and plot its intermediate results.

    Args:
        z_t: Measurement at time t; compared against ``h`` of the prior (2 components).
        x_t_minus_1_posterior: Posterior state estimate at t-1 (1-element array, angle).
        u_t: Control applied for this step (1-element array, angle increment).
        P_t_minus_1: Posterior state covariance at t-1 (1x1).
        R_t: Motion model noise covariance (1x1).
        Q_t: Measurement noise covariance (2x2).
        ax_1d: Axes on which the prior state Gaussian is drawn.
        ax_2d: Axes on which the prior direction and the measurement Jacobian
            direction (the column of H, drawn in both senses) are shown.

    Returns:
        tuple: ``(x_t_posterior, P_t_posterior)``.
    """
    ###################### Linearize motion model g #########################
    # Jacobian of g with respect to the state: g = theta + omega, so dg/dtheta = 1.
    G_t = np.array([[1.0]])

    ######################### Predict #############################
    # Predict the state at time t from the estimate at t-1 and the control u_t.
    x_t_prior = g(x_t_minus_1_posterior, u_t)
    # Where the prior would appear in measurement space with a noise-free sensor.
    x_t_prior_perfect_obs = h(x_t_prior)

    P_t_prior = G_t @ P_t_minus_1 @ G_t.T + R_t

    # Index element 0: math.degrees on a 1-element array relies on NumPy's
    # deprecated (>= 1.25) conversion of ndim > 0 arrays to scalars.
    print(f'Prior State Estimate={degrees(x_t_prior[0])}, P={P_t_prior}')

    plot_gaussian_1d(x_t_prior, P_t_prior, ax_1d, 1, color='magenta', label='x prior')
    ax_2d.plot([0, x_t_prior_perfect_obs[0]], [0, x_t_prior_perfect_obs[1]], '-', color='magenta', label='x prior')

    ###################### Linearize measurement model #########################
    theta_t_prior = x_t_prior[0]
    # Jacobian of h = [cos(theta), sin(theta)] with respect to theta (2x1).
    H_t = np.array([[-np.sin(theta_t_prior)], [np.cos(theta_t_prior)]])

    # Visualize the linearization: H is tangent to the unit circle at the prior.
    ax_2d.arrow(x_t_prior_perfect_obs[0], x_t_prior_perfect_obs[1], 1 * H_t[0, 0], 1 * H_t[1, 0], color='k',
                head_width=0.1, head_length=0.1)
    ax_2d.arrow(x_t_prior_perfect_obs[0], x_t_prior_perfect_obs[1], -1 * H_t[0, 0], -1 * H_t[1, 0], color='k',
                head_width=0.1, head_length=0.1)

    ######################### Measurement update #############################
    # Innovation is formed in measurement space, so no angle wrapping is needed.
    innovation_t = z_t - x_t_prior_perfect_obs

    # Measurement residual covariance
    S_t = H_t @ P_t_prior @ H_t.T + Q_t

    # Kalman gain (1x2)
    K_t = P_t_prior @ H_t.T @ np.linalg.inv(S_t)

    # Posterior state estimate for time t
    x_t_posterior = x_t_prior + (K_t @ innovation_t)

    # Posterior state covariance for time t
    P_t_posterior = P_t_prior - (K_t @ H_t @ P_t_prior)

    print(f'Posterior State Estimate={degrees(x_t_posterior[0])}, P={P_t_posterior}, K={K_t}')

    return x_t_posterior, P_t_posterior


def fg(z_t, x_t_minus_1_posterior, u_t, P_t_minus_1, R_t, Q_t, ax_1d, ax_2d):
    """Estimate the state at time t by least squares over a two-factor graph, and plot it.

    Factors, each whitened by the square root of its information matrix:
      * motion:      g(x_{t-1}, u_t) - x_t   (covariance R_t)
      * measurement: h(x_t) - z_t            (covariance Q_t)
    The previous estimate x_{t-1} is held fixed, so ``P_t_minus_1`` does not enter
    the residuals; it is accepted only to keep the same signature as ``ekf``.

    Args:
        z_t: Measurement at time t (2 components).
        x_t_minus_1_posterior: Estimate at t-1 (1-element array, angle).
        u_t: Control applied for this step (1-element array, angle increment).
        P_t_minus_1: Covariance at t-1 (unused by the optimization).
        R_t: Motion model noise covariance (1x1).
        Q_t: Measurement noise covariance (2x2).
        ax_1d: Axes on which the prior state Gaussian is drawn.
        ax_2d: Axes on which the prior direction is drawn.

    Returns:
        tuple: ``(x_t_posterior, P_t_posterior)`` where ``P_t_posterior`` is the
        Gauss-Newton covariance ``inv(J^T J)`` of the whitened residuals at the solution.
    """
    # Square-root information matrices used to whiten the residuals.
    R = scipy.linalg.sqrtm(np.linalg.inv(R_t))
    Q = scipy.linalg.sqrtm(np.linalg.inv(Q_t))
    # Motion-model prediction; constant with respect to the optimization variable.
    x_t_predicted = g(x_t_minus_1_posterior, u_t)

    def res_g_only(x):
        return R @ (x_t_predicted - x)

    sol_prior = scipy.optimize.least_squares(res_g_only, x_t_minus_1_posterior, '3-point')
    x_t_prior = sol_prior.x
    # For whitened residuals, J^T J approximates the information matrix, so its
    # inverse is the covariance. Previously this was just P_t_minus_1 (never updated).
    P_t_prior = np.linalg.inv(sol_prior.jac.T @ sol_prior.jac)
    x_t_prior_perfect_obs = h(x_t_prior)

    # Index element 0: math.degrees on a 1-element array relies on NumPy's
    # deprecated (>= 1.25) conversion of ndim > 0 arrays to scalars.
    print(f'Prior State Estimate={degrees(x_t_prior[0])}, P={P_t_prior}')

    plot_gaussian_1d(x_t_prior, P_t_prior, ax_1d, 1, color='magenta', label='x prior')
    ax_2d.plot([0, x_t_prior_perfect_obs[0]], [0, x_t_prior_perfect_obs[1]], '-', color='magenta', label='x prior')

    def res(x):
        return np.hstack((
            R @ (x_t_predicted - x),
            Q @ (h(x) - z_t),
        ))

    # Start from the prior, the motion model's best guess for x_t.
    sol = scipy.optimize.least_squares(res, x_t_prior, '3-point')
    x_t_posterior = sol.x
    cost = sol.cost

    P_t_posterior = np.linalg.inv(sol.jac.T @ sol.jac)
    print(f'Posterior State Estimate={degrees(x_t_posterior[0])}, P={P_t_posterior}, cost={cost}')
    return x_t_posterior, P_t_posterior


def sim(x0, u, noise=0.1, *, rng=None):
    """Simulate a ground-truth angle trajectory and noisy unit-circle measurements.

    Args:
        x0: Initial state; element 0 is the angle that gets propagated.
        u: Sequence of controls; ``u_t[0]`` is added to the angle at each step.
        noise: Standard deviation of the zero-mean Gaussian noise added
            independently to each measurement component.
        rng: Optional random source exposing ``standard_normal`` (e.g. a
            ``numpy.random.Generator``); defaults to the global ``numpy.random``
            state so ``np.random.seed`` still controls reproducibility.

    Returns:
        tuple: ``(x, z)`` where ``x`` holds ``len(u) + 1`` states starting with
        ``x0`` and ``z`` holds one measurement per control.
    """
    if rng is None:
        rng = np.random
    x = [x0]  # ground truth trajectory
    z = []  # measurements
    for u_t in u:
        omega_t = u_t[0]
        theta_t_plus_1 = x[-1][0] + omega_t
        x.append(np.array([theta_t_plus_1]))
        z_t = h(x[-1])
        # Zero-mean noise. The previous np.random.rand drew from U[0, 1), which
        # shifted every measurement component by +noise/2 on average.
        z_t = z_t + noise * rng.standard_normal(z_t.shape[0])
        z.append(z_t)

    return x, z


def _label_ticks_in_degrees(ax):
    """Show the x tick labels of ``ax`` in whole degrees (the plotted angles are radians)."""
    from matplotlib.ticker import FuncFormatter

    # A formatter labels whatever ticks matplotlib picks at draw time. By contrast,
    # set_xticklabels on auto-located ticks warns and freezes labels that go stale.
    ax.xaxis.set_major_formatter(FuncFormatter(lambda value, _pos: "%.0f" % (degrees(value),)))


def _style_2d_axes(ax):
    """Apply the equal-aspect, gridded, fixed-limit framing of the 2D plot."""
    ax.set_aspect("equal", "box")
    ax.grid()
    ax.set_xlim(-2.2, 2.2)
    ax.set_ylim(-2.2, 2.2)


def main():
    """Simulate a turning heading angle, run the selected estimator each step, and animate it."""
    # control: angle increment applied at each timestep
    u = [  # PLAY HERE
        np.array([radians(5)]),
        np.array([radians(5)]),
        np.array([radians(5)]),
        np.array([radians(5)]),
        np.array([radians(5)]),
    ]
    measurement_noise = 0.1  # PLAY HERE

    x, z = sim(np.array([0.0]), u, measurement_noise)
    # Simulate broken motor: the last command has no effect, so the final
    # ground-truth state and measurement repeat the previous ones.
    x[-1] = x[-2]
    z[-1] = z[-2]

    x_est = []  # estimated trajectory

    # Measurement noise covariance matrix (could be different for different times)
    Q = measurement_noise * np.eye(2)

    # Motion model noise covariance matrix (could be different for different times)
    R = 0.1 * np.eye(1)

    # INITIALIZATION

    # Initial state vector in the global reference frame.
    x_t_minus_1 = np.array([radians(5.0)])  # PLAY HERE

    # Initial state covariance matrix
    P_t_minus_1 = 1.0 * np.eye(1)

    # States are 1-element arrays. Index them before math.degrees, which would
    # otherwise rely on NumPy's deprecated array-to-scalar conversion.
    print(f'Initial GT State={degrees(x[0][0])}')
    print(f'Initial Predicted State={degrees(x_t_minus_1[0])}')

    _, (ax_2d, ax_1d) = plt.subplots(2, gridspec_kw={'height_ratios': [3, 1]}, figsize=(10, 7))

    x0_perfect_obs = h(x[0])
    plot_gaussian_1d(x[0], np.array([0.001]), ax_1d, 1.0, 'b--', mew=2, label='x GT')
    ax_2d.plot([0, x0_perfect_obs[0]], [0, x0_perfect_obs[1]], 'b--', markersize=10, mew=3, label='x GT')

    _style_2d_axes(ax_2d)
    ax_2d.legend()
    ax_1d.legend()
    _label_ticks_in_degrees(ax_1d)
    plt.pause(2)

    # PLAY HERE: choose the estimator
    opt = ekf
    # opt = fg

    for t, (u_t, z_t) in enumerate(zip(u, z), start=1):
        ax_2d.clear()
        ax_1d.clear()

        # Print the current timestep
        print(f'Timestep t={t}, u_t={degrees(u_t[0])}')

        x_t_posterior, P_t_posterior = opt(
            z_t,  # Most recent sensor measurement
            x_t_minus_1,  # Our most recent estimate of the state
            u_t,  # Our most recent control input
            P_t_minus_1,  # Our most recent state covariance matrix
            R,  # Motion model noise
            Q,  # Measurement noise
            ax_1d, ax_2d)

        # Get ready for the next timestep by updating the variable values
        x_t_minus_1 = x_t_posterior
        P_t_minus_1 = P_t_posterior

        x_est.append(x_t_posterior)

        print(f'GT State={degrees(x[t][0])}')

        # VISU
        ax_2d.add_patch(plt.Circle((0, 0), 1.0, color='gray', fill=False))
        # Estimated pose PDF
        plot_gaussian_1d(x_t_posterior, P_t_posterior, ax_1d, 1.0, 'r-', mew=2, label='x post')
        # GT pose PDF
        plot_gaussian_1d(x[t], np.array([0.001]), ax_1d, 1.0, 'b--', mew=2, label='x GT')

        # 2D measurement of robot pose
        ax_2d.scatter(z_t[0], z_t[1], color='green', label='z')
        confidence_ellipse(z_t, Q, ax_2d, 1, edgecolor='green')

        # Estimated 2D robot pose
        x_t_posterior_perfect_obs = h(x_t_posterior)
        ax_2d.plot([0, x_t_posterior_perfect_obs[0]], [0, x_t_posterior_perfect_obs[1]], 'r-', markersize=10, mew=3,
                   label='x post')

        # GT 2D robot pose
        x_t_perfect_obs = h(x[t])
        ax_2d.plot([0, x_t_perfect_obs[0]], [0, x_t_perfect_obs[1]], 'b--', markersize=10, mew=3, label='x GT')

        _style_2d_axes(ax_2d)
        ax_1d.grid()
        ax_1d.legend()
        handles, _ = ax_2d.get_legend_handles_labels()
        # Legend entry for the black Jacobian arrows drawn by the estimator.
        handles.append(Line2D([0], [0], label='H', color='k'))
        ax_2d.legend(handles=handles)
        # clear() resets the formatter, so re-apply it every frame.
        _label_ticks_in_degrees(ax_1d)
        plt.pause(2)

    plt.pause(20)


# Program starts running here with the main method. The guard keeps the
# simulation and plot windows from launching when this module is imported.
if __name__ == "__main__":
    main()
