import os.path
from pathlib import Path
import requests
import matplotlib.pyplot as plt
import numpy as np
import torch.utils.data


def Download_MNIST(path='data', show=False):
    """Download (if missing) and load the pickled MNIST splits.

    The archive ``mnist.pkl.gz`` is fetched from the PyTorch tutorials
    repository and cached as ``<path>/mnist/mnist.pkl.gz``. Later calls reuse
    the cached file without touching the network.

    :param path: root data directory; the archive is stored in ``<path>/mnist``.
    :param show: if True, plot the first training image of every label found
        in ``y_train`` (in sorted label order) on a 2x5 grid, then call
        ``plt.show()``.
    :return: ``((x_train, y_train), (x_valid, y_valid), (x_test, y_test))``
        exactly as stored in the pickle.
    :raises requests.HTTPError: if the download answers with an HTTP error status.
    """
    import gzip
    import pickle

    mnist_dir = Path(path) / "mnist"
    mnist_dir.mkdir(parents=True, exist_ok=True)
    url = "https://github.com/pytorch/tutorials/raw/master/_static/"
    filename = "mnist.pkl.gz"
    archive = mnist_dir / filename

    if not archive.exists():
        # The timeout keeps a stalled connection from hanging forever.
        response = requests.get(url + filename, timeout=60)
        # Otherwise an HTTP error page would be cached as the archive and,
        # because of the exists() check above, never re-downloaded.
        response.raise_for_status()
        # Write to a temporary name and rename, so an interrupted write cannot
        # leave a truncated archive in place. write_bytes also closes the file.
        partial = archive.with_name(filename + ".part")
        partial.write_bytes(response.content)
        partial.replace(archive)

    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only load an archive obtained from a trusted source.
    with gzip.open(archive, "rb") as f:
        (x_train, y_train), (x_valid, y_valid), (x_test, y_test) = pickle.load(
            f, encoding="latin-1"
        )

    if show:
        _, axes = plt.subplots(2, 5)
        # np.unique returns the labels sorted. With return_index=True it also
        # gives the index of each label's first occurrence in y_train.
        _, first_idx = np.unique(y_train, return_index=True)
        # axes.flat walks the grid row by row, matching the label order.
        for ax, idx in zip(axes.flat, first_idx):
            ax.imshow(x_train[idx].reshape(28, 28))
            ax.set_xticklabels([])
            ax.set_yticklabels([])
        plt.show()

    return (x_train, y_train), (x_valid, y_valid), (x_test, y_test)

def Download_CIFAR10(path='data', show=False):
    """Load the CIFAR-10 train/test splits with torchvision, downloading if absent.

    :param path: root directory handed to ``torchvision.datasets.CIFAR10``.
    :param show: if True, plot the first training image of each distinct label
        (in sorted label order) on a 2x5 grid, then call ``plt.show()``.
    :return: ``(train_dataset, test_dataset)``; images are converted with
        ``torchvision.transforms.ToTensor``.
    """
    import torchvision

    root = f'{path}'
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor()
    ])
    # torchvision extracts CIFAR-10 into '<root>/cifar-10-batches-py'.
    # Only request a download when that folder is missing.
    download = not os.path.isdir(os.path.join(root, 'cifar-10-batches-py'))
    train_dataset = torchvision.datasets.CIFAR10(root, download=download, transform=transform)
    test_dataset = torchvision.datasets.CIFAR10(root, train=False, download=download, transform=transform)

    if show:
        _, axes = plt.subplots(2, 5)
        # Keep the first image seen for each label. Stop once every grid cell
        # (10 of them) has a distinct label.
        first_per_label = {}
        for image, label in train_dataset:
            first_per_label.setdefault(label, image)
            if len(first_per_label) == axes.size:
                break

        for ax, label in zip(axes.flat, sorted(first_per_label)):
            # ToTensor produces channel-first tensors (C, H, W); imshow wants
            # channels last, hence the permute.
            ax.imshow(first_per_label[label].permute(1, 2, 0))
            ax.set_xticklabels([])
            ax.set_yticklabels([])
        plt.show()

    return train_dataset, test_dataset


class Classification_Dataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-loaded samples and their labels.

    Indexing returns a dict with the keys ``'data'``, ``'label'`` and
    ``'index'``. The dataset's length is the length of the data container.
    """

    def __init__(self, preloaded_data=None, preloaded_labels=None, paths=None):
        """
        :param preloaded_data: indexable container of input samples
        :param preloaded_labels: indexable container of labels, aligned with the data
        :param paths: stored on the instance but not used by this class
        """
        self.data, self.labels, self.paths = preloaded_data, preloaded_labels, paths

    def __getitem__(self, index):
        """
        :param index: position of the sample requested (e.g. by a DataLoader)
        :return: dict holding the sample, its label and the index itself
        """
        return dict(data=self.data[index], label=self.labels[index], index=index)

    def __len__(self):
        """
        :return: number of samples, i.e. one past the largest valid index
        """
        sample_count = len(self.data)
        return sample_count

if __name__ == '__main__':
    # Fetch MNIST as ((x_train, y_train), (x_valid, y_valid), (x_test, y_test))
    # and CIFAR-10 as (train, test) datasets, plotting a preview of each.
    train_data, validation_data, test_data = Download_MNIST(show=True)
    train_cifar10, test_cifar10 = Download_CIFAR10(show=True)

    # Pair the MNIST training images with their labels in an indexable Dataset.
    mnist_images, mnist_labels = train_data
    dataset = Classification_Dataset(
        preloaded_data=mnist_images,
        preloaded_labels=mnist_labels,
    )

    # Serve batches of 32 samples; shuffle=True reshuffles the sample order
    # every epoch so batches differ from epoch to epoch.
    trn_loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

