from __future__ import print_function

import os
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import torch
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401 (registers the '3d' projection on older matplotlib)

#######################################################
# fitting a 2D line to points by gradient descent     #
# zimmerk@fel.cvut.cz                                 #
# Vision for Robotics (VIR)                           #
#######################################################
# Set to True to save every animation frame to ./movie/ and encode them into
# ./movie/output.mp4 with ffmpeg at the end of the run.
CREATE_MOVIE = False


# Let the user click N points in the unit square; they are the data to fit.
N = 5
print("choose {} points:".format(N))
plt.clf()
plt.axis([0, 1., 0, 1.])
# timeout=0 waits indefinitely. With matplotlib's default 30 s timeout a slow
# user would silently end up with fewer than N points.
a = plt.ginput(N, timeout=0)
if not a:
    # With no points, indexing pts[:, 0] below would fail with an opaque error.
    raise SystemExit("no points selected - nothing to fit")
# One row per clicked point: column 0 is x, column 1 is y.
pts = torch.tensor(a, dtype=torch.double)

# Alternative data source: load previously saved points instead of clicking.
# pts = np.load('pts.npy')
# pts= torch.tensor(pts)

# Initial guess for the line parameters w = [slope, intercept]
# (the model evaluated in the loop is y = w[0] * x + w[1]).
w = torch.tensor([-2, 2], requires_grad=True, dtype=torch.double)

# Loop-invariant data: the clicked points as numpy, the x-samples used to draw
# the fitted line, the grid on which p(y|x,w) is evaluated, and the step size.
PTS = pts.detach().numpy()  # convert to numpy
T = torch.linspace(0, 1, 10).numpy()
X, Y = np.meshgrid(np.arange(0, 1, 0.02), np.arange(0, 1, 0.02))
learning_rate = 0.5

OUTPUT_PATH = './movie/'
if CREATE_MOVIE and not os.path.isdir(OUTPUT_PATH):
    # savefig does not create missing directories.
    os.makedirs(OUTPUT_PATH)

# Create the 3D axes once and clear it each iteration. Calling plt.axes() inside
# the loop would add a new Axes3D to the figure on every pass.
fig = plt.figure(1)
fig.clf()  # drop the 2D axes that was used for point selection
ax = fig.add_subplot(111, projection='3d')

for i in range(0, 100):
    # criterion: mean squared vertical distance of the points to y = w0*x + w1
    y = w[0] * pts[:, 0] + w[1]
    dy = y - pts[:, 1]
    loss = torch.mean(dy * dy)

    # gradient descent step, in place on the leaf tensor w; the gradient is
    # then reset so the next backward() does not accumulate into it
    loss.backward()
    with torch.no_grad():
        w -= learning_rate * w.grad
        # prints the updated w together with the loss of the previous w
        print(i, w.detach(), loss.item())
    w.grad.zero_()

    # visualize result
    W = w.detach().numpy()
    ax.cla()
    # ridge exp(-50 * (w0*x + w1 - y)^2) centred on the current line
    P = np.exp(-(W[0] * X + W[1] - Y)**2*50)
    ax.plot_surface(X, Y, P, linewidth=0, antialiased=False, alpha=0.5)
    ax.scatter(PTS[:, 0], PTS[:, 1], np.zeros_like(PTS[:, 1]), s=50, marker='x', color='r')
    ax.plot3D(T, W[0] * T + W[1], np.zeros_like(T))
    ax.set_xlim3d(0, 1)
    ax.set_ylim3d(0, 1)
    ax.set_zlim3d(0, 1)
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_zlabel('p(y|x,w)')
    # request the redraw first so pause() renders this frame now
    plt.draw()
    plt.pause(0.01)
    if CREATE_MOVIE:
        plt.savefig(OUTPUT_PATH + '{:04d}_frame'.format(i) + '.png')

if CREATE_MOVIE:
    # Encode the saved frames into OUTPUT_PATH/output.mp4. A stale output file
    # is removed first so ffmpeg does not stop at its overwrite prompt.
    # Arguments go to ffmpeg as a list, so no shell is involved.
    movie_file = os.path.join(OUTPUT_PATH, 'output.mp4')
    if os.path.exists(movie_file):
        os.remove(movie_file)
    try:
        subprocess.call(['ffmpeg', '-i', os.path.join(OUTPUT_PATH, '%04d_frame.png'),
                         '-c:v', 'libx264', '-vf', 'scale=1280:-2',
                         '-pix_fmt', 'yuv420p', movie_file])
    except OSError as err:
        # e.g. ffmpeg not installed; the individual frames are still on disk
        print('could not run ffmpeg:', err)
