In this assignment, your task is to train a neural network with a multi-loss objective, namely hierarchical classification and semantic segmentation.
The dataset consists of images of pets, where each image corresponds to a species (cat or dog) and a breed (25 dog breeds and 12 cat breeds). For each image there is also a semantic segmentation map with three classes: foreground, background & boundary. The task is to train a model that can determine the species p(species|image), the breed p(breed|image) and the segmentation mask p(mask|image).
hw03.zip
UPDATE: make sure that in your image and mask transforms you use transforms.Resize(128) and not transforms.Resize((128,128)), as was originally in the homework template!
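For reference, a minimal sketch of an image transform consistent with this note and the template steps (Resize, CenterCrop, ToTensor, Normalize); the normalization statistics below are an assumption (ImageNet values), keep whatever your template uses:

import torchvision.transforms as T

# Sketch only: Resize(128) rescales the shorter side to 128 while keeping the aspect
# ratio, CenterCrop then makes the image a square 128 x 128. The mean/std values are
# assumed (ImageNet statistics), not taken from the homework template.
image_transform = T.Compose([
    T.Resize(128),
    T.CenterCrop(128),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])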
Task 1 - Species classification (1 point)
Classify the species of the pet: dog or cat.
Task 2 - Breed classification (3 points)
Dog breeds (25): 'american_bulldog', 'american_pit_bull_terrier', 'basset_hound', 'beagle', 'boxer', 'chihuahua', 'english_cocker_spaniel', 'english_setter', 'german_shorthaired', 'great_pyrenees', 'havanese', 'japanese_chin', 'keeshond', 'leonberger', 'miniature_pinscher', 'newfoundland', 'pomeranian', 'pug', 'saint_bernard', 'samoyed', 'scottish_terrier', 'shiba_inu', 'staffordshire_bull_terrier', 'wheaten_terrier', 'yorkshire_terrier'
Cat breeds (12): 'Abyssinian', 'Bengal', 'Birman', 'Bombay', 'British_Shorthair', 'Egyptian_Mau', 'Maine_Coon', 'Persian', 'Ragdoll', 'Russian_Blue', 'Siamese', 'Sphynx'
Task 3 - Semantic segmentation (6 points)
Submit a .zip file containing all your training & inference code. There needs to be a model.py file, containing a Net class which has a method predict. There also needs to be a weights.pth file, which will be loaded in BRUTE with:
model.load_state_dict(torch.load(weights_path, map_location=torch.device('cpu')))
You can save the model with:
torch.save(model.state_dict(), "weights.pth")
The predict method takes a single 3 x 128 x 128 image as input (processed with the same transform as in the template: Resize, CenterCrop, ToTensor, Normalize). After computing the predictions, it outputs them in the following format: a 128 x 128 tensor with values 0, 1 or 2, representing the background, foreground and boundary classes respectively.
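A minimal sketch of how predict could produce this label map. It assumes the forward pass returns segmentation logits of shape 1 x 3 x 128 x 128 together with the two classification outputs suggested in the hints below; adapt it to your own Net:

import torch
import torch.nn as nn

class Net(nn.Module):
    # ... layers defined in __init__; forward is assumed to return
    # (seg_logits, species_logits, breed_logits) as in the hints below.

    def predict(self, image):                          # image: 3 x 128 x 128, already transformed
        self.eval()
        with torch.no_grad():
            seg_logits, _, _ = self.forward(image.unsqueeze(0))
        return seg_logits.argmax(dim=1).squeeze(0)     # 128 x 128 tensor with values 0, 1 or 2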
Accompanying the assignment, there will be a tournament in which the models will be ranked based on their performance. There will be a separate ranking for each task: species accuracy, top-3 breed accuracy and mean IoU. The final ranking will be determined by the sum of the ranks in all three tasks, and the scoring will be based on this ranking.
If your model overfits, use regularization such as Dropout and do not go around the internet trying to gather more data. The training code you upload must reach a similar result to yours on the provided data subset.
Hints:
Add nn.BatchNorm2d and nn.BatchNorm1d into your conv, transposed conv and fully connected layers.
Make sure the torch.flatten output isn't too big and doesn't cause the MLP head to have too many parameters (e.g. a Bx128x16x16 feature map flattens to 128*16*16 = 32 768 values, so the matrix in the first linear layer might have dimension 32768×256; if the output were Bx128x4x4 instead, the matrix would be 2048×256, which is a smaller jump in dimensionality).
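A quick shape check illustrating the arithmetic above (the batch size and the 256-unit layer are arbitrary choices):

import torch
import torch.nn as nn

features = torch.randn(8, 128, 4, 4)          # dummy encoder output, B x 128 x 4 x 4
flat = torch.flatten(features, start_dim=1)   # B x 2048
fc = nn.Linear(128 * 4 * 4, 256)              # maps 2048 -> 256 features
print(fc(flat).shape)                         # torch.Size([8, 256])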
A possible architecture with a shared convolutional encoder and separate heads (a code sketch follows this outline):
- Encoder: stacked ConvBlock, input: Bx3x128x128, output: Bx128x4x4 (each block is Conv2D(3×3,pad=1) → BatchNorm2d → ReLU → MaxPool)
- Segmentation decoder: stacked TransposedConvBlock, input: Bx128x4x4, output: BxCx128x128 (each block except the last is TransposedConv2D(2×2,stride=2) → BatchNorm2d → ReLU; the last layer is only a TransposedConv; C=3 in our case, as we are segmenting into 3 classes)
- Species classifier: fully connected layers, input: Bx(128*4*4)=Bx2048, output: Bx2 (each layer except the last is Linear → BatchNorm1d → ReLU; the last layer is only Linear)
- Breed classifier: fully connected layers, input: Bx(128*4*4)=Bx2048, output: Bx37 (one output per breed; otherwise the same structure as the species classifier)
- Total loss: loss = segment_loss + species_loss + breed_loss
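A minimal sketch of this layout; only the stated input/output shapes come from the outline above, while the intermediate channel widths, the 256-unit hidden layer and the unweighted sum of losses are assumptions:

import torch
import torch.nn as nn

def conv_block(in_ch, out_ch):
    # Conv2D(3x3, pad=1) -> BatchNorm2d -> ReLU -> MaxPool; halves the spatial size
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )

def upconv_block(in_ch, out_ch):
    # TransposedConv2D(2x2, stride=2) -> BatchNorm2d -> ReLU; doubles the spatial size
    return nn.Sequential(
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=2, stride=2),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
    )

def mlp_head(in_features, hidden, n_classes):
    # Linear -> BatchNorm1d -> ReLU, then a plain Linear output layer
    return nn.Sequential(
        nn.Linear(in_features, hidden),
        nn.BatchNorm1d(hidden),
        nn.ReLU(inplace=True),
        nn.Linear(hidden, n_classes),
    )

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder: B x 3 x 128 x 128 -> B x 128 x 4 x 4 (intermediate widths are assumptions)
        self.encoder = nn.Sequential(
            conv_block(3, 16),     # 128 -> 64
            conv_block(16, 32),    # 64 -> 32
            conv_block(32, 64),    # 32 -> 16
            conv_block(64, 128),   # 16 -> 8
            conv_block(128, 128),  # 8 -> 4
        )
        # Segmentation decoder: B x 128 x 4 x 4 -> B x 3 x 128 x 128, last layer without BN/ReLU
        self.decoder = nn.Sequential(
            upconv_block(128, 64),  # 4 -> 8
            upconv_block(64, 32),   # 8 -> 16
            upconv_block(32, 16),   # 16 -> 32
            upconv_block(16, 16),   # 32 -> 64
            nn.ConvTranspose2d(16, 3, kernel_size=2, stride=2),  # 64 -> 128
        )
        # Classification heads on the flattened encoder features (B x 2048)
        self.species_head = mlp_head(128 * 4 * 4, 256, 2)
        self.breed_head = mlp_head(128 * 4 * 4, 256, 37)

    def forward(self, x):
        features = self.encoder(x)               # B x 128 x 4 x 4
        seg_logits = self.decoder(features)      # B x 3 x 128 x 128
        flat = torch.flatten(features, 1)        # B x 2048
        return seg_logits, self.species_head(flat), self.breed_head(flat)

def total_loss(seg_logits, species_logits, breed_logits, mask, species, breed):
    # loss = segment_loss + species_loss + breed_loss; nn.CrossEntropyLoss accepts
    # B x 3 x 128 x 128 logits with B x 128 x 128 integer targets for the mask term.
    ce = nn.CrossEntropyLoss()
    return ce(seg_logits, mask) + ce(species_logits, species) + ce(breed_logits, breed)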
You can also use blocks with two convolutions (Conv2d → BatchNorm2d → ReLU → Conv2d → BatchNorm2d → ReLU → MaxPool).
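A sketch of such a two-convolution block:

import torch.nn as nn

def double_conv_block(in_ch, out_ch):
    # Two 3x3 convolutions (each followed by BatchNorm and ReLU) before the MaxPool downsampling
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
        nn.BatchNorm2d(out_ch),
        nn.ReLU(inplace=True),
        nn.MaxPool2d(2),
    )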
nn.CrossEntropyLoss has an argument weight, which takes a (#classes,)-shaped tensor and weighs the loss for each example based on its ground truth class (this helps greatly with class imbalance).
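One common way to build such a weight tensor is inverse-frequency weighting from the training-set class counts (the counts below are made up):

import torch
import torch.nn as nn

counts = torch.tensor([520.0, 310.0, 150.0, 90.0])   # hypothetical per-class counts, shape (#classes,)
weights = counts.sum() / (len(counts) * counts)      # rarer classes get a larger weight
criterion = nn.CrossEntropyLoss(weight=weights)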
Add nn.Dropout(prob) to your network after some / all blocks.
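For example, appended to one of the convolutional blocks (the probability 0.25 is just an example value):

import torch.nn as nn

block_with_dropout = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(inplace=True),
    nn.MaxPool2d(2),
    nn.Dropout(0.25),   # randomly zeroes activations during training only
)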
Good luck! If you get stuck, feel free to consult the web or various chatbots, just make sure to acquire true understanding in the process and not just copy stuff. In case of any questions or concerns, please contact siproman@fel.cvut.cz.