
def f(y):
    """Return the sum of squared entries of the matrix ``y``.

    This is the scalar loss ``l = sum(y[i][j] ** 2)``. Its derivative with
    respect to each entry is ``2 * y[i][j]``.

    Args:
        y: Matrix given as a list of rows (each row an iterable of numbers).

    Returns:
        The sum of squares, or ``0`` for an empty matrix.
    """
    # Summing a flat generator avoids ``sum(rows, [])``, which re-copies the
    # accumulated list for every row and is quadratic in the matrix size.
    # Entries are still visited in row-major order starting from 0.
    return sum(b ** 2 for row in y for b in row)


def g(y):
    """Return a new matrix holding every entry of ``y`` multiplied by two.

    Elementwise this is d/dv (v ** 2) = 2 * v, i.e. the gradient of the
    sum-of-squares loss with respect to each entry of ``y``.

    Args:
        y: Matrix given as a list of rows (each row an iterable of numbers).

    Returns:
        A freshly built list of lists with the same shape as ``y``.
    """
    doubled = []
    for row in y:
        doubled.append([2 * value for value in row])
    return doubled


def conv(x, w):
    """Return the valid-mode 2D cross-correlation of ``x`` with kernel ``w``.

    The kernel is slid over ``x`` without flipping or padding::

        y[i][j] = sum over m, n of w[m][n] * x[i + m][j + n]

    Rows and columns are sized independently, so rectangular inputs and
    kernels are supported. The result has ``len(x) - len(w) + 1`` rows and
    ``len(x[0]) - len(w[0]) + 1`` columns. An axis on which the kernel is
    larger than the input comes out empty.

    Args:
        x: Input matrix as a list of rows (rows expected to be equal length).
        w: Kernel matrix as a list of rows (rows expected to be equal length).

    Returns:
        The output matrix as a new list of lists.
    """
    x_rows = len(x)
    x_cols = len(x[0]) if x else 0
    w_rows = len(w)
    w_cols = len(w[0]) if w else 0
    # Size each output axis separately; using the row count for both axes
    # silently truncated (or overran) non-square inputs.
    y_rows = x_rows - w_rows + 1
    y_cols = x_cols - w_cols + 1

    y = []
    for i in range(y_rows):
        row = []
        for j in range(y_cols):
            # Explicit accumulation keeps the exact left-to-right summation
            # order for floating-point inputs.
            stamp = 0
            for m in range(w_rows):
                for n in range(w_cols):
                    stamp += w[m][n] * x[i + m][j + n]
            row.append(stamp)
        y.append(row)
    return y


if __name__ == "__main__":
    def show(matrix, width):
        """Print ``matrix`` one row per line, each entry right-aligned in ``width`` columns."""
        lines = ("".join(f"{value:{width}}" for value in row) for row in matrix)
        print("\n".join(lines))

    # Each case: (input x, kernel w, column width used when printing the gradient).
    examples = [
        (
            [[1, 0, 2],
             [0, 0, 2],
             [1, 1, 0]],
            [[0, 1],
             [0, 1]],
            4,
        ),
        (
            [[6, 2, 2, 0, 4],
             [1, 3, 7, 3, 7],
             [1, 1, 0, 4, 1],
             [4, 2, 8, 1, 3],
             [0, 1, 4, 0, 9]],
            [[0, 1, 2],
             [1, 2, 3],
             [2, 3, 4]],
            5,
        ),
    ]

    # Expected values:
    #   case 1: y = [[0, 4], [1, 2]], l = 21,
    #           w gradient = [[0, 24], [6, 18]]
    #   case 2: y = [[39, 46, 58], [66, 58, 59], [52, 43, 69]], l = 27516,
    #           w gradient = [[2380, 2558, 3200], [3106, 2864, 3548],
    #                         [2512, 2300, 3668]]
    for x, w, grad_width in examples:
        y = conv(x, w)
        loss = f(y)
        # dl/dw[m][n] = sum over i, j of dl/dy[i][j] * x[i + m][j + n],
        # which is the cross-correlation of x with dl/dy = g(y).
        w_grad = conv(x, g(y))

        print("y:")
        show(y, 4)
        print(f'\nl: {loss}\n')
        print("w gradient:")
        show(w_grad, grad_width)
