import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits.mplot3d import Axes3D

# --- 1. The 2D Landscape (The Ravine) ---
def f(x, y):
    """Evaluate the ravine loss ``0.5*x**2 + 20*y**2``.

    Works element-wise, so ``x`` and ``y`` may be scalars or same-shaped
    numpy arrays (e.g. a meshgrid).
    """
    shallow = 0.5 * x**2  # gentle slope along x
    steep = 20 * y**2     # steep walls along y
    return shallow + steep

def df(x, y):
    """Return the analytic gradient of ``f`` as ``np.array([df/dx, df/dy])``."""
    d_dx = x
    d_dy = 40 * y
    return np.array([d_dx, d_dy])

# --- 2. Setup ---
start_pos = np.array((-4.0, 2.0))  # begin high up the steep y-wall, far out along x
steps = 100  # number of update steps each optimizer takes
lr = 0.1     # base learning rate

# --- 3. Optimizers ---
# Helper to create 3D points (x, y, z)
def make_3d_point(pos):
    """Lift a 2-D position onto the loss surface.

    Returns ``pos`` with ``f(pos[0], pos[1])`` appended, i.e. ``[x, y, z]``
    when ``pos`` holds two coordinates (as every caller here passes).
    """
    x, y = pos[0], pos[1]
    height = f(x, y)
    return np.append(pos, height)

# Recorded trajectories: each entry is [x, y, f(x, y)], starting at start_pos.
path_gd = [make_3d_point(start_pos)]
path_mom = [make_3d_point(start_pos)]
path_adam = [make_3d_point(start_pos)]

# State variables
curr_gd = start_pos.copy()
curr_mom = start_pos.copy()
curr_adam = start_pos.copy()

# Momentum specific
vel_mom = np.zeros(2)
gamma = 0.9

# Adam specific
m = np.zeros(2) # vector [mx, my]
v = np.zeros(2) # vector [vx, vy]
beta1 = 0.9
beta2 = 0.999
eps = 1e-8
t = 0

# Step size for the two non-adaptive methods.
# The steepest curvature of f is d2f/dy2 = 40 (the 40*y term in df).
# - Plain GD is only stable for lr < 2 / 40 = 0.05.
# - Heavy-ball momentum is only stable for lr < 2 * (1 + gamma) / 40 = 0.095.
# The base lr of 0.1 violates both bounds. GD's y-update factor would be
# 1 - 0.1 * 40 = -3, so both trajectories would blow up off the plot.
# Cap their step at 80% of the GD bound. Adam normalises each coordinate's
# step by its own gradient scale, so it keeps using `lr`.
max_curvature = 40.0
lr_plain = min(lr, 0.8 * 2.0 / max_curvature)

# --- 4. Optimization Loop ---
for _ in range(steps):
    t += 1  # 1-based step count used by Adam's bias correction

    # A. GD (Standard)
    grad = df(curr_gd[0], curr_gd[1])
    curr_gd -= lr_plain * grad
    path_gd.append(make_3d_point(curr_gd))

    # B. Momentum (heavy ball): the velocity accumulates past gradients
    grad = df(curr_mom[0], curr_mom[1])
    vel_mom = gamma * vel_mom + lr_plain * grad
    curr_mom -= vel_mom
    path_mom.append(make_3d_point(curr_mom))

    # C. Adam
    grad = df(curr_adam[0], curr_adam[1])

    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * (grad**2) # Element-wise square!

    # Bias correction: m and v start at zero, so early estimates are too small.
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)

    # Element-wise division adjusts x and y differently
    curr_adam -= lr * m_hat / (np.sqrt(v_hat) + eps)
    path_adam.append(make_3d_point(curr_adam))

# --- 5. Visualization (3D Surface Plot) ---
# Sample the loss on a 100x100 grid covering the region the optimizers move through.
x_range, y_range = np.linspace(-5, 5, 100), np.linspace(-3, 3, 100)
X, Y = np.meshgrid(x_range, y_range)
Z = f(X, Y)

fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
# A semi-transparent, coarsely strided surface keeps the paths visible on top of it.
ax.plot_surface(X, Y, Z, cmap='viridis', alpha=0.4, rstride=5, cstride=5)
ax.set(
    title="3D 'Ravine': Adam vs Momentum vs GD",
    xlabel="X",
    ylabel="Y",
    zlabel="f(x, y)",
)
ax.view_init(elev=30, azim=-120)

# Trails: one initially empty 3-D line per optimizer, filled in frame by frame.
(line_gd,) = ax.plot([], [], [], 'r-', label='GD', alpha=0.8)
(line_mom,) = ax.plot([], [], [], 'm-', label='Momentum', alpha=0.8)
(line_adam,) = ax.plot([], [], [], 'orange', lw=2, label='Adam')

# Heads: a single marker showing each optimizer's current position.
(dot_gd,) = ax.plot([], [], [], 'ro')
(dot_mom,) = ax.plot([], [], [], 'mo')
(dot_adam,) = ax.plot([], [], [], 's', color='orange')

def init():
    """Clear every trail and head marker before the first frame.

    Returns the six artists so FuncAnimation knows what to redraw when blitting.
    """
    artists = (line_gd, line_mom, line_adam, dot_gd, dot_mom, dot_adam)
    for artist in artists:
        artist.set_data_3d([], [], [])
    return artists

def update(i):
    """Draw animation frame ``i``.

    For each optimizer, the trail shows recorded points ``0..i`` and the head
    marker sits on point ``i``. Returns the six updated artists for blitting.
    """
    tracks = (
        (path_gd, line_gd, dot_gd),
        (path_mom, line_mom, dot_mom),
        (path_adam, line_adam, dot_adam),
    )
    for path, trail, head in tracks:
        pts = np.array(path[:i + 1])  # rows are [x, y, f(x, y)]
        xs, ys, zs = pts[:, 0], pts[:, 1], pts[:, 2]
        trail.set_data_3d(xs, ys, zs)
        # Marker data must be sequences, hence the one-element lists.
        head.set_data_3d([xs[-1]], [ys[-1]], [zs[-1]])

    return line_gd, line_mom, line_adam, dot_gd, dot_mom, dot_adam

# One frame per recorded point. Each path holds the start position plus one
# entry per step (steps + 1 points), so frames=steps would never draw the
# final state. The reference is kept in `ani` so the animation is not
# garbage-collected before plt.show().
ani = animation.FuncAnimation(
    fig, update, frames=len(path_gd), init_func=init, blit=True, interval=50
)

ax.legend()
plt.show()