#!/usr/bin/env python

import sys
import traceback
from pathlib import Path

import numpy as np


def backprop(seed: int):
    """Check the autograd API calls ``backward``, ``zero_grad`` and ``step``.

    Two seeded scalar tensors ``a`` and ``b`` are summed into ``c``. Each
    API call is then exercised on ``c`` in turn. A failure in one sub-check
    is reported and does not prevent the following sub-checks from running.

    Args:
        seed: Base seed for ``np.random``. The ``step`` reference values are
            hard-coded and only hold for the seed used by ``test_all`` (42).

    Returns:
        True if tensor construction and every sub-check succeeded,
        False otherwise.
    """
    from engine import Tensor

    ret = True
    try:
        np.random.seed(seed)
        a = Tensor(np.random.rand(1)[0], req_grad=True)
        np.random.seed(seed + 42)
        b = Tensor(np.random.rand(1)[0], req_grad=True)

        c = a + b
    except Exception as err:
        print(f"FAILED! {err}\n{traceback.format_exc()}")
        # Nothing else can be checked without the graph.
        return False

    try:  # Backward function
        print("Backward:")
        c.backward()
    except Exception as err:
        print(f"FAILED! {err}\n{traceback.format_exc()}")
        ret = False
    else:
        print("PASSED!")

    try:  # Zero grad function
        print("Zero grad:")
        c.zero_grad()
    except Exception as err:
        print(f"FAILED! {err}\n{traceback.format_exc()}")
        ret = False
    else:
        if a.grad == 0 and b.grad == 0:
            print("PASSED!")
        else:
            print("FAILED! Gradient is left non-zero.")
            ret = False

    try:  # Step function
        print("Step:")
        # Inject known, seeded gradients so the outcome of step() is
        # independent of whatever backward() produced above.
        np.random.seed(seed - 42)
        a.grad = np.random.rand(1)[0] / 10
        np.random.seed(seed + 13)
        b.grad = np.random.rand(1)[0] / 10
        # NOTE(review): c.grad is also set, presumably so step() finds a
        # gradient on every node of the graph — confirm against engine.
        c.grad = b.grad * 0.987

        c.step(learning_rate=1)
    except Exception as err:
        print(f"FAILED! {err}\n{traceback.format_exc()}")
        ret = False
    else:
        # Expected a.data / b.data after one step (seed 42). For example,
        # a_ref equals a's initial value minus its injected gradient.
        a_ref = 0.31965876845463004
        b_ref = 0.036730164964299705
        # Tolerant comparison: exact float equality would reject an engine
        # that merely orders its arithmetic differently.
        if np.allclose(a.data, a_ref) and np.allclose(b.data, b_ref):
            print("PASSED!")
        else:
            print("FAILED! Results do not match reference.")
            ret = False
    return ret


def basic_operations(seed: int):
    """Check Tensor arithmetic (+, -, *, /, **) against plain NumPy scalars.

    Args:
        seed: Base seed for ``np.random``. The operands are drawn with
            ``seed`` and ``seed + 1``.

    Returns:
        True if every operation ran and matched its NumPy reference,
        False if any operation raised or produced a mismatching value.
    """
    from engine import Tensor

    np.random.seed(seed)
    a = np.random.rand(1)[0]
    a_t = Tensor(a)
    np.random.seed(seed + 1)
    b = np.random.rand(1)[0]
    b_t = Tensor(b)

    # (failure message, engine computation, NumPy reference).
    # The engine side is deferred via lambdas so any exception it raises
    # is handled by the try below.
    checks = [
        ("Addition incorrect", lambda: (a_t + b_t).data, a + b),
        ("Substraction incorrect", lambda: (a_t - b_t).data, a - b),
        ("Multiplication incorrect", lambda: (a_t * b_t).data, a * b),
        ("Division incorrect", lambda: (a_t / b_t).data, a / b),
        ("Power incorrect", lambda: (a_t**3).data, a**3),
    ]

    print("Basic operations:")
    mismatches = 0
    try:
        for message, compute, expected in checks:
            # Tolerant comparison: the engine may round differently.
            if not np.allclose(compute(), expected):
                print(message)
                mismatches += 1
    except Exception as err:
        print(f"FAILED! {err}\n{traceback.format_exc()}")
        return False

    if mismatches:
        # Previously a mismatch was printed but the suite still passed.
        print("FAILED! Results do not match reference.")
        return False

    print("PASSED!")
    return True


def basic_functions(seed: int):
    """Check Tensor sin/cos/exp/log against their NumPy counterparts.

    Args:
        seed: Seed for ``np.random`` used to draw the scalar input.

    Returns:
        True if every function ran and matched NumPy, False if any
        function raised or produced a mismatching value.
    """
    from engine import Tensor

    np.random.seed(seed)
    a = np.random.rand(1)[0]
    # Defensive only: rand() already draws from [0, 1), so `a` is non-negative.
    a = -a if a < 0 else a
    a_t = Tensor(a)

    # (failure message, engine computation, NumPy reference).
    checks = [
        ("Sine function incorrect", lambda: a_t.sin().data, np.sin(a)),
        ("Cosine function incorrect", lambda: a_t.cos().data, np.cos(a)),
        ("Exponential function incorrect", lambda: a_t.exp().data, np.exp(a)),
        ("Logarithm function incorrect", lambda: a_t.log().data, np.log(a)),
    ]

    print("Basic functions:")
    mismatches = 0
    try:
        for message, compute, expected in checks:
            # Tolerant comparison: the engine may use a different libm path.
            if not np.allclose(compute(), expected):
                print(message)
                mismatches += 1
    except Exception as err:
        print(f"FAILED! {err}\n{traceback.format_exc()}")
        return False

    if mismatches:
        print("FAILED! Results do not match reference.")
        return False

    print("PASSED!")
    return True


def activation_functions(seed: int):
    """Check Tensor relu/sigmoid/tanh against reference NumPy formulas.

    Args:
        seed: Seed for ``np.random`` used to draw the scalar input.

    Returns:
        True if every activation ran and matched its reference, False if
        any activation raised or produced a mismatching value.
    """
    from engine import Tensor

    np.random.seed(seed)
    a = np.random.rand(1)[0]
    # Defensive only: rand() already draws from [0, 1), so `a` is non-negative
    # (which also means only the positive branch of ReLU is exercised).
    a = -a if a < 0 else a
    a_t = Tensor(a)

    # (failure message, engine computation, reference value).
    checks = [
        ("ReLU function incorrect", lambda: a_t.relu().data, max(a, 0)),
        (
            "Sigmoid function incorrect",
            lambda: a_t.sigmoid().data,
            1 / (1 + np.exp(-a)),
        ),
        (
            "Hyperbolic tangens function incorrect",
            lambda: a_t.tanh().data,
            np.tanh(a),
        ),
    ]

    print("Activation functions:")
    mismatches = 0
    try:
        for message, compute, expected in checks:
            # Tolerant comparison: e.g. sigmoid has several equivalent formulas.
            if not np.allclose(compute(), expected):
                print(message)
                mismatches += 1
    except Exception as err:
        print(f"FAILED! {err}\n{traceback.format_exc()}")
        return False

    if mismatches:
        print("FAILED! Results do not match reference.")
        return False

    print("PASSED!")
    return True


def backward(seed: int):
    """Check gradients produced by a backward pass through a small graph.

    The graph mixes arithmetic, trigonometric, exp/log, power and activation
    nodes. After ``g.backward()``, the ``.grad`` of every node is compared
    with hard-coded reference values.

    Args:
        seed: Base seed for ``np.random``. The reference gradients are
            hard-coded and only hold for the seed used by ``test_all`` (42).

    Returns:
        True if the pass ran and all gradients match the references,
        False otherwise.
    """
    from engine import Tensor

    np.random.seed(seed)
    a = Tensor(np.random.rand(1)[0])
    np.random.seed(seed + 1)
    b = Tensor(np.random.rand(1)[0])

    print("Backward pass:")
    try:
        c = a + b
        d = a * (b - 3)
        e = c.cos() / a.sin()
        f = d.exp().log() ** -2
        g = e.relu() + f.sigmoid()

        g.backward()
        res = np.array([a.grad, b.grad, c.grad, d.grad, e.grad, f.grad, g.grad])
    except Exception as err:  # `err`, not `e`: `e` is a graph node above
        print(f"FAILED! {err}\n{traceback.format_exc()}")
        return False

    # Reference gradients in the order a, b, c, d, e, f, g.
    expected = np.array(
        [
            -8.378997995644227,
            -1.161199827420063,
            -1.2854317620892728,
            0.33169192942943027,
            1.0,
            0.20922460207654955,
            1.0,
        ]
    )
    if not np.allclose(res, expected):
        print("FAILED! Result does not match expected results.")
        # Previously this branch still returned True.
        return False

    print("PASSED!")
    return True


def test_all():
    """Run every engine test suite and exit non-zero if any of them failed.

    Exits with status 1 in these cases:

    * ``./engine.py`` (relative to the working directory) does not exist.
    * ``engine.Tensor`` cannot be imported.
    * ``backprop`` fails. The remaining suites rely on the same API, so
      they are skipped.
    * Any of the remaining suites fails. All of them still run, so every
      failure is reported before exiting.
    """
    module_path = Path("./engine.py")
    if not module_path.exists():
        print("Module file engine could not be found")
        sys.exit(1)

    try:
        from engine import Tensor  # noqa: F401  (import check only)
    except Exception as e:
        print(f"Error importing module: {e}\n{traceback.format_exc()}")
        sys.exit(1)

    seed = 42  # DO NOT CHANGE! or certain test will not work

    if not backprop(seed):
        sys.exit(1)

    # Evaluate every suite eagerly (a list, not a short-circuiting all(...))
    # so that each one runs and prints its report.
    results = [
        basic_operations(seed),
        basic_functions(seed),
        activation_functions(seed),
        backward(seed),
    ]
    if not all(results):
        # Previously these results were discarded and the process exited 0.
        sys.exit(1)


if __name__ == "__main__":
    test_all()
