import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# --- 1. The Landscape (Quadratic) ---
def f(x):
    """Objective: the quadratic bowl f(x) = x**2, minimized at x = 0.

    Used both on scalar positions and on the NumPy plotting grid, so it
    must work elementwise on arrays.
    """
    return pow(x, 2)

def df(x):
    """First derivative of the bowl: f'(x) = 2x (the gradient every optimizer follows)."""
    slope = x * 2
    return slope

def ddf(x):
    """Second derivative of the bowl: f''(x) = 2 everywhere.

    The curvature is constant, so ``x`` is ignored and the plain scalar 2 is
    returned even for array input.
    """
    curvature = 2
    return curvature

# --- 2. Parameters ---
start_x = -2.5    # shared starting point, far from the minimum at x = 0
steps = 50        # optimization steps run for every method
lr_gd = 0.1       # step size shared by plain GD and Momentum
lr_adam = 0.2     # Adam's base step size
gamma = 0.9       # Momentum's velocity decay factor

# --- 3. Initialize Optimizers ---
# One trajectory per optimizer, seeded with the start point; the
# optimization loop appends one position per step.
history = {name: [start_x] for name in ('GD', 'Momentum', 'Newton', 'Adam')}

# Current position of each optimizer (same keys and order as `history`).
curr = dict.fromkeys(history, start_x)
vel_mom = 0.0     # Momentum's running velocity

# Adam state: first/second moment estimates, their decay rates, the
# denominator stabilizer, and the step counter used for bias correction.
m, v = 0.0, 0.0
beta1, beta2 = 0.9, 0.999
epsilon = 1e-8
t = 0

# --- 4. Optimization Loop ---
# `t` is Adam's 1-based step index for bias correction; driving the loop with
# it leaves t == steps afterwards, exactly as incrementing it by hand would.
for t in range(1, steps + 1):
    # A. Plain gradient descent (red): x <- x - lr * f'(x).
    curr['GD'] = curr['GD'] - lr_gd * df(curr['GD'])

    # B. Heavy-ball momentum (purple): the velocity keeps a decaying memory of
    # past gradients, which carries the iterate past the minimum.
    vel_mom = gamma * vel_mom + lr_gd * df(curr['Momentum'])
    curr['Momentum'] -= vel_mom

    # C. Newton's method (blue), undamped: x <- x - f'(x) / f''(x).
    # No damping and no guard against zero curvature; on x**2 this lands
    # exactly on the minimum after one step.
    newton_x = curr['Newton']
    curr['Newton'] = newton_x - 1.0 * (df(newton_x) / ddf(newton_x))

    # D. Adam (orange): bias-corrected moving averages of the gradient and of
    # its square give a normalized step.
    g = df(curr['Adam'])
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g**2
    m_hat = m / (1 - beta1**t)
    v_hat = v / (1 - beta2**t)
    curr['Adam'] -= lr_adam * m_hat / (np.sqrt(v_hat) + epsilon)

    # Snapshot every optimizer's position after this step.
    for name, trail in history.items():
        trail.append(curr[name])

# --- 5. Visualization ---
# Plot window: wide enough to hold every recorded position with a 20% margin,
# never narrower than the classic [-3, 3] view (the default run is unchanged).
# Non-finite positions (a diverged optimizer) are skipped so they cannot turn
# the axis limits into inf/nan.
reached = [abs(x) for trail in history.values() for x in trail if np.isfinite(x)]
half_width = max([3.0] + [1.2 * r for r in reached])

x_vals = np.linspace(-half_width, half_width, 400)
y_vals = f(x_vals)

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(x_vals, y_vals, 'k-', lw=2, alpha=0.3, label='Quadratic Function')
# f is the symmetric bowl x**2, so its height at the window edge bounds the
# height of every marker inside the window.
ax.set_ylim(-1, f(half_width))
ax.set_title("Optimization on a Quadratic (Bowl)")

# Marker style and legend label for each optimizer (keys match `history`).
styles = {
    'GD': {'c': 'red', 'marker': 'o', 'label': 'GD (Slow & Steady)'},
    'Momentum': {'c': 'purple', 'marker': 'o', 'label': 'Momentum (Overshoots)'},
    'Newton': {'c': 'blue', 'marker': '*', 'ms': 15, 'label': 'Newton (1-Step Magic)'},
    'Adam': {'c': 'orange', 'marker': 's', 'label': 'Adam (Adaptive)'}
}

# One initially empty Line2D per optimizer; the animation callbacks move them.
lines = {name: ax.plot([], [], **style)[0] for name, style in styles.items()}

def init():
    """Clear every optimizer marker before the first frame.

    Returns the marker artists so FuncAnimation (blit=True) knows what to redraw.
    """
    artists = lines.values()
    for artist in artists:
        artist.set_data([], [])
    return artists

def update(i):
    """Place each optimizer's marker at (x, f(x)) for trajectory entry ``i``.

    ``i`` indexes `history` directly (entry 0 is the shared start point).
    Returns the marker artists for blitting.
    """
    for name in lines:
        pos = history[name][i]
        lines[name].set_data([pos], [f(pos)])
    return lines.values()

# Each trajectory in `history` holds steps + 1 positions (the start point plus
# one per step), so animate all of them; frames=steps never showed the final
# state. NOTE: keep the `ani` reference — Matplotlib's FuncAnimation docs warn
# that the animation stops if the object is garbage-collected.
ani = animation.FuncAnimation(
    fig, update, frames=steps + 1, init_func=init, blit=True, interval=150
)

ax.legend()
ax.grid(True, alpha=0.3)
plt.show()