import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# --- 1. The Landscape ---
def f(x):
    """Objective landscape f(x) = x**4 - 4*x**2 + x.

    Pure arithmetic, so it also evaluates elementwise on NumPy arrays.
    """
    quartic = x**4
    quadratic = 4 * x**2
    return quartic - quadratic + x

def df(x):
    """First derivative of ``f``: f'(x) = 4*x**3 - 8*x + 1."""
    cubic = 4 * x**3
    linear = 8 * x
    return cubic - linear + 1

def ddf(x):
    """Second derivative (curvature) of ``f``: f''(x) = 12*x**2 - 8."""
    curvature = 12 * x**2
    return curvature - 8

# --- 2. Parameters ---
# f has two minima: a deeper (global) one near x ~ -1.47 (f ~ -5.44) and a
# shallower (local) one near x ~ 1.35 (f ~ -2.62), separated by a local
# maximum near x ~ 0.13.
start_x = -1       # Left slope: df(-1) = 5 > 0, so descent heads toward the global minimum
steps = 200        # Number of optimizer updates (each history gets steps + 1 points)
lr_gd = 0.1        # Step size for GD and Momentum; also the damping factor for Newton
lr_adam = 0.1      # Adam step size; currently equal to lr_gd, tune separately if needed
gamma = 0.9        # Momentum decay factor (fraction of velocity kept each step)

# --- 3. Initialize Optimizers ---
# One position trail per optimizer, each seeded with the shared start point.
# Key order (GD, Momentum, Newton, Adam) is the order used by later loops.
history = {name: [start_x] for name in ('GD', 'Momentum', 'Newton', 'Adam')}

# Current position of every optimizer, keyed exactly like `history`.
curr = dict.fromkeys(history, start_x)

# Momentum state: running velocity.
vel_mom = 0.0

# Adam state: first moment (momentum) and second moment (variance)
# estimates, their decay rates, the division stabilizer, and the step
# counter used for bias correction.
m, v = 0.0, 0.0
beta1, beta2 = 0.9, 0.999
epsilon = 1e-8
t = 0

# --- 4. Optimization Loop ---
# `t` is both the loop counter and Adam's 1-based step count; after the loop
# it equals `steps`, the same value the manual increment would leave.
for t in range(1, steps + 1):
    # A. Standard GD (Red): step straight down the gradient.
    x_gd = curr['GD']
    curr['GD'] = x_gd - lr_gd * df(x_gd)

    # B. GD + Momentum (Purple): accumulate a decaying velocity, then move.
    g_mom = df(curr['Momentum'])
    vel_mom = gamma * vel_mom + lr_gd * g_mom
    curr['Momentum'] = curr['Momentum'] - vel_mom

    # C. Newton (Blue): gradient divided by |f''| floored at 0.5, so the
    # step stays bounded near inflection points and always moves against
    # the gradient even where f'' < 0; damped by lr_gd.
    x_nt = curr['Newton']
    curv = abs(ddf(x_nt))
    if not curv > 0.5:
        curv = 0.5
    curr['Newton'] = x_nt - lr_gd * (df(x_nt) / curv)

    # D. Adam (Orange): bias-corrected first/second moment estimates.
    g_adam = df(curr['Adam'])
    m = beta1 * m + (1 - beta1) * g_adam
    v = beta2 * v + (1 - beta2) * (g_adam**2)
    m_corr = m / (1 - beta1**t)
    v_corr = v / (1 - beta2**t)
    curr['Adam'] = curr['Adam'] - lr_adam * m_corr / (np.sqrt(v_corr) + epsilon)

    # Append every optimizer's new position to its trail.
    for name, trail in history.items():
        trail.append(curr[name])

# --- 5. Visualization ---
# Dense sample of the landscape for the static background curve.
x_vals = np.linspace(-2.5, 2.5, 400)
y_vals = f(x_vals)

fig, ax = plt.subplots(figsize=(10, 6))

ax.plot(x_vals, y_vals, 'k-', lw=2, alpha=0.3, label='Landscape')
ax.set_ylim(-15, 10)
ax.set_xlabel('x')
ax.set_ylabel('f(x)')
# All four optimizers are plotted, so name all four.
ax.set_title("Comparison: GD vs. Momentum vs. Newton vs. Adam")

# Marker style per optimizer; keys match `history`.
styles = {
    'GD': {'c': 'red', 'marker': 'o', 'alpha': 0.5, 'label': 'GD'},
    'Momentum': {'c': 'purple', 'marker': 'o', 'label': 'GD + Momentum'},
    'Newton': {'c': 'blue', 'marker': '*', 'ms': 12, 'label': 'Newton'},
    'Adam': {'c': 'orange', 'marker': 's', 'label': 'Adam'},
}

# One initially empty Line2D per optimizer; init()/update() move its marker.
# ax.plot returns a list with a single line here, hence the [0].
lines = {k: ax.plot([], [], **style)[0] for k, style in styles.items()}

def init():
    """Clear every optimizer marker before the first frame.

    Returns the marker artists so FuncAnimation can blit them.
    """
    artists = lines.values()
    for artist in artists:
        artist.set_data([], [])
    return artists

def update(i):
    """Place each optimizer's marker at its position after step ``i``.

    ``i`` indexes into every trail in ``history``; the marker sits on the
    landscape at (x, f(x)). Returns the marker artists for blitting.
    """
    for name, marker in lines.items():
        pos = history[name][i]
        marker.set_data([pos], [f(pos)])
    return lines.values()

# Custom frame timing: the first `slow_frames_count` history points are held
# on screen `slow_down_factor` times longer than the rest.
slow_frames_count = 20
slow_down_factor = 15
fast_interval = 50  # ms per animation tick


def frame_generator():
    """Yield history indices, repeating early ones so they stay on screen longer.

    FuncAnimation ticks at a single fixed ``interval``, so a per-frame delay
    cannot be handed to it directly. Instead, the animation ticks at
    ``fast_interval`` and each history index is yielded once per tick it
    should remain visible, with the desired durations taken from
    ``interval_generator``.
    """
    # Every trail holds steps + 1 points (start + one per update), so this
    # includes the final positions, which a range(steps) would skip.
    n_frames = len(history['GD'])
    intervals = interval_generator()
    for i in range(n_frames):
        repeats = max(1, next(intervals) // fast_interval)
        for _ in range(repeats):
            yield i


def interval_generator():
    """Yield the desired on-screen duration (ms) for each successive frame."""
    # Interval for the first `slow_frames_count` frames
    for _ in range(slow_frames_count):
        yield fast_interval * slow_down_factor
    # Interval for the rest of the frames
    while True:
        yield fast_interval


# Total number of ticks, so a saved animation (e.g. ani.save) keeps every frame.
total_frames = sum(1 for _ in frame_generator())

# Add legend/grid before creating the animation so the blitted background
# already contains them.
ax.legend()
ax.grid(True, alpha=0.3)

ani = animation.FuncAnimation(
    fig,
    update,
    frames=frame_generator,
    init_func=init,
    blit=True,
    interval=fast_interval,
    save_count=total_frames,
)

plt.show()