import torch
import matplotlib.pyplot as plt
import torch.nn as nn
import matplotlib as mpl


#######################################################
# Show what the convolution does to a grayscale image #
#######################################################

# Load the stored image tensor; it is indexed below as
# [batch, channel, row, column], and the first image / first channel is shown.
im_gray = torch.load('austrich')
gray_plane = im_gray[0, 0, :, :]

plt.figure(1)
# Normalize() stretches the colormap over the data's own min/max range.
gray_axes_image = plt.imshow(
    gray_plane,
    cmap='gray',
    norm=mpl.colors.Normalize(),
)
plt.colorbar(gray_axes_image)

# Hand-crafted 6x6 kernel: the left three columns are +1 and the right three
# columns are -1, so the output is the difference between the sums of the
# left and right halves of every 6x6 window of the image.
conv1 = nn.Conv2d(1, 1, 6)  # input channels, output channels/kernels, kernel_size
# conv1.weight is a leaf Parameter with requires_grad=True; autograd rejects
# in-place writes to such a tensor, so the kernel is filled under no_grad().
with torch.no_grad():
    conv1.weight[:, :, :, :3] = 1
    conv1.weight[:, :, :, 3:] = -1
# NOTE: the bias keeps its random initial value, which only adds a constant
# offset to every output pixel.

# Show the kernel itself.
plt.figure(2)
plt.imshow(conv1.weight.detach()[0, 0, :, :], cmap='gray')
plt.colorbar()

# Show the filtered image (first image, first output channel).
plt.figure(3)
plt.imshow(conv1(im_gray).detach()[0, 0, :, :], cmap='gray', norm=mpl.colors.Normalize())
plt.colorbar()


###### padding
# Same left(+1)/right(-1) 6x6 kernel, but the input is padded with 15 pixels
# on every side before the convolution is applied.
conv1 = nn.Conv2d(1, 1, 6, padding=(15, 15))  # input channels, output channels/kernels, kernel_size
# conv1.weight is a leaf Parameter with requires_grad=True; autograd rejects
# in-place writes to such a tensor, so the kernel is filled under no_grad().
with torch.no_grad():
    conv1.weight[:, :, :, :3] = 1
    conv1.weight[:, :, :, 3:] = -1

# Show the filtered image (first image, first output channel).
plt.figure(4)
plt.imshow(conv1(im_gray).detach()[0, 0, :, :], cmap='gray', norm=mpl.colors.Normalize())
plt.colorbar()


###### stride
# Same left(+1)/right(-1) 6x6 kernel, but the window moves 3 pixels at a time
# in both directions, which yields a smaller output image.
conv1 = nn.Conv2d(1, 1, 6, stride=(3, 3))  # input channels, output channels/kernels, kernel_size
# conv1.weight is a leaf Parameter with requires_grad=True; autograd rejects
# in-place writes to such a tensor, so the kernel is filled under no_grad().
with torch.no_grad():
    conv1.weight[:, :, :, :3] = 1
    conv1.weight[:, :, :, 3:] = -1

# Show the filtered image (first image, first output channel).
plt.figure(5)
plt.imshow(conv1(im_gray).detach()[0, 0, :, :], cmap='gray', norm=mpl.colors.Normalize())
plt.colorbar()


