import matplotlib.pyplot as plt
import numpy as np

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# select the GPU if one is available, otherwise fall back to the CPU
dev = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', dev)

# transforms: convert images to tensors and normalize pixel values from [0, 1] to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# datasets
trainset = torchvision.datasets.MNIST('./data', download=True, train=True, transform=transform)

# dataloaders
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, shuffle=True, num_workers=0)

# let's verify how the loader packs the data
(data, target) = next(iter(train_loader))
# expect [batch_size x 1 x 28 x 28]
print('Input  type:', data.type())
print('Input  size:', data.size())
# expect [batch_size]
print('Labels size:', target.size())
# number of training examples:
n_train_data = len(trainset)
print('Train data size:', n_train_data)
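
# optional sanity check: plot a few images from the batch with their labels
# (a minimal sketch using the matplotlib import above; the grid size of 6 is arbitrary)
fig, axes = plt.subplots(1, 6, figsize=(9, 2))
for ax, img, lbl in zip(axes, data, target):
    # undo the Normalize((0.5,), (0.5,)) transform for display: x * 0.5 + 0.5
    ax.imshow((img[0] * 0.5 + 0.5).numpy(), cmap='gray')
    ax.set_title(int(lbl))
    ax.axis('off')
plt.show()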

# network: a single linear layer mapping flattened 28 * 28 images to 10 class scores
net = nn.Sequential(nn.Linear(28 * 28, 10))
net.to(dev)
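# purely illustrative check of the parameter shapes: the single Linear layer
# holds a weight of shape [10, 784] and a bias of shape [10]
for name, p in net.named_parameters():
    print(name, tuple(p.shape))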
# loss function; reduction='none' keeps the per-sample losses so we can sum them for logging
# and average them for the gradient
loss = nn.CrossEntropyLoss(reduction='none')

# optimizer
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

for epoch in range(10):
    # will accumulate total loss over the dataset
    L = 0
    # loop fetching a mini-batch of data at each iteration
    for i, (data, target) in enumerate(train_loader):
        data = data.to(dev)
        target = target.to(dev)
        # flatten the data size to [batch_size x 784]
        data_vectors = data.flatten(start_dim=1)
        # apply the network (calling the module invokes its forward pass)
        y = net(data_vectors)
        # calculate mini-batch losses
        l = loss(y, target)
        # accumulate the total loss as a regular float number (important to stop graph tracking)
        L += l.sum().item()
        # gradients accumulate across backward() calls, so clear them explicitly before the next step
        optimizer.zero_grad()
        # compute the gradient from the mini-batch loss
        l.mean().backward()
        # make the optimization step
        optimizer.step()
    print(f'Epoch: {epoch} mean loss: {L / n_train_data}')

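# evaluation sketch: accuracy on the MNIST test split, loaded the same way as the training data
# (testset / test_loader are names introduced here for illustration, not part of the code above)
testset = torchvision.datasets.MNIST('./data', download=True, train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(testset, batch_size=128, shuffle=False, num_workers=0)

net.eval()  # not strictly needed for a single linear layer, but good practice
n_correct = 0
with torch.no_grad():  # no gradients needed for evaluation
    for data, target in test_loader:
        data = data.to(dev)
        target = target.to(dev)
        y = net(data.flatten(start_dim=1))
        # the predicted class is the index of the largest score
        n_correct += (y.argmax(dim=1) == target).sum().item()
print(f'Test accuracy: {n_correct / len(testset):.4f}')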