- Harshavardhan P - 2021111003
- Kapil Rajesh Kavitha - 2021101028
python mixmatch.py --epochs 3
- MixMatch is a semi-supervised learning technique that combines labeled and unlabeled data to improve the performance of the model. This is a Python implementation of a deep learning model using the MixMatch algorithm.
- We utilise the parse_args() function from the argparse library to parse the arguments passed to the script. The arguments are used for choosing the dataset, setting hyperparameters, and specifying GPU usage. This allows us to customise the training process without having to change the code.
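As a rough illustration of the argument handling described above, here is a minimal sketch using `argparse`. The flag names `--dataset`, `--labeled_n`, `--K`, `--alpha` and `--epochs` appear in this README; every default value, the `build_parser` helper, and the `--gpu` flag are assumptions made for the example and may differ from the actual script.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser for the flags mentioned in this README (sketch)."""
    parser = argparse.ArgumentParser(description="MixMatch training (sketch)")
    # Flags named in this README; every default below is an assumption.
    parser.add_argument("--dataset", type=str, default="cifar10",
                        help="name of the dataset to load")
    parser.add_argument("--labeled_n", type=int, default=250,
                        help="number of labeled examples")
    parser.add_argument("--K", type=int, default=2,
                        help="augmented copies generated per image")
    parser.add_argument("--alpha", type=float, default=0.75,
                        help="controls the MixUp mixing ratio")
    parser.add_argument("--epochs", type=int, default=3,
                        help="number of training epochs")
    # Hypothetical flag: the README only says GPU usage is configurable.
    parser.add_argument("--gpu", type=int, default=0,
                        help="index of the GPU to use (hypothetical)")
    return parser


# Passing an explicit list mirrors `python mixmatch.py --epochs 3`.
args = build_parser().parse_args(["--epochs", "3"])
print(args.dataset, args.epochs, args.K)
```

Keeping the parser in its own function makes it easy to test with an explicit argument list instead of `sys.argv`.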
- The dataset specified by the `--dataset` argument is loaded using a custom `dataloader` module, which splits the data into labeled and unlabeled subsets. The number of labeled examples is controlled by the `--labeled_n` argument. The data is also split into training, validation, and test sets.
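The labeled/unlabeled split could look roughly like the following sketch. It assumes `labeled_n` is the total number of labeled examples and that the split is a plain random one; the helper name `split_labeled_unlabeled` and the `seed` parameter are hypothetical, and the actual `dataloader` module may instead stratify the split by class.

```python
import random


def split_labeled_unlabeled(num_examples, labeled_n, seed=0):
    """Return (labeled_idx, unlabeled_idx) as disjoint lists of indices."""
    if not 0 <= labeled_n <= num_examples:
        raise ValueError("labeled_n must be between 0 and num_examples")
    indices = list(range(num_examples))
    random.Random(seed).shuffle(indices)  # fixed seed for a reproducible split
    # The first labeled_n shuffled indices keep their labels; the rest do not.
    return sorted(indices[:labeled_n]), sorted(indices[labeled_n:])


labeled_idx, unlabeled_idx = split_labeled_unlabeled(num_examples=10, labeled_n=4)
print(len(labeled_idx), len(unlabeled_idx))
```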
- The model used is a WideResNet (Wide Residual Network). Xavier weight initialization is used to initialize the weights of the model's layers.
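For reference, Xavier (Glorot) initialization in its uniform variant draws each weight from $U(-a, a)$ with $a = \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$. The README does not say whether the uniform or the normal variant is used, so the sketch below assumes the uniform one; it is written in plain Python and the `xavier_uniform` helper is illustrative, not the project's code.

```python
import math
import random


def xavier_uniform(fan_in, fan_out, rng=None):
    """Return a fan_out x fan_in weight matrix drawn from U(-a, a)."""
    rng = rng or random.Random(0)
    bound = math.sqrt(6.0 / (fan_in + fan_out))  # Xavier/Glorot uniform bound
    return [[rng.uniform(-bound, bound) for _ in range(fan_in)]
            for _ in range(fan_out)]


weights = xavier_uniform(fan_in=4, fan_out=3)
print(len(weights), len(weights[0]))
```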
- The Adam optimizer is used for training the model. Additionally, an EMA optimizer maintains an exponential moving average of the model's parameters. This averaged model is used to generate predictions for the unlabeled data.
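The exponential moving average mentioned above follows the update rule `ema = decay * ema + (1 - decay) * param`, applied after each optimizer step. A minimal sketch on plain dictionaries of floats, assuming a hypothetical `ema_update` helper and an assumed `decay` value (the actual EMA optimizer operates on the model's tensors and may use a different decay):

```python
def ema_update(ema_params, model_params, decay=0.999):
    """Update ema_params in place: ema = decay * ema + (1 - decay) * param."""
    for name, value in model_params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


ema = {"w": 1.0}
ema_update(ema, {"w": 0.0}, decay=0.9)
print(ema["w"])
```

A decay close to 1 makes the averaged parameters change slowly, which tends to give more stable predictions than the raw parameters.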
- Before the training images are passed to the model, they are augmented: for each image, `k` augmented images are generated and added to the dataloader. Augmentation applies a series of transformations:
  - random cropping,
  - horizontal flipping,
  - random color jitter,
  - random grayscale conversion, and
  - random Gaussian blur.
- The images are then normalized using the mean and standard deviation of the dataset. The number of augmented images generated for each image is controlled by the `--K` argument.
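To make the "`K` augmented copies per image" idea concrete, here is a small sketch in plain Python. It implements only two of the listed transformations (horizontal flip and random crop with zero padding) on a single-channel image stored as a list of rows; color jitter, grayscale conversion, Gaussian blur and normalization are omitted. The helper names and the `pad` and `seed` parameters are assumptions for the example.

```python
import random


def random_flip_and_crop(image, rng, pad=1):
    """Randomly flip a 2-D image (list of rows), then take a random crop."""
    height, width = len(image), len(image[0])
    if rng.random() < 0.5:
        image = [row[::-1] for row in image]  # horizontal flip
    # Zero-pad the border, then crop a window of the original size.
    padded = [[0] * (width + 2 * pad) for _ in range(pad)]
    padded += [[0] * pad + list(row) + [0] * pad for row in image]
    padded += [[0] * (width + 2 * pad) for _ in range(pad)]
    top = rng.randint(0, 2 * pad)   # randint is inclusive on both ends
    left = rng.randint(0, 2 * pad)
    return [row[left:left + width] for row in padded[top:top + height]]


def augment_k(image, k, seed=0):
    """Return k independently augmented copies of the same image."""
    rng = random.Random(seed)
    return [random_flip_and_crop(image, rng) for _ in range(k)]


copies = augment_k([[1, 2, 3], [4, 5, 6]], k=2)
print(len(copies))
```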
- Labeled and unlabeled data are mixed using the MixUp algorithm, with a randomly generated mixing ratio controlled by the `--alpha` argument. Two sets of mixed input pairs and their corresponding target pairs are generated, which are then combined to create the final `mixed_input` and `mixed_target` tensors that are used for training the model.
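The mixing step for a single pair of examples can be sketched as follows. In the MixMatch paper the ratio is sampled as $\lambda \sim \mathrm{Beta}(\alpha, \alpha)$ and then replaced by $\max(\lambda, 1 - \lambda)$ so that the mixed example stays closer to the first input; whether this implementation does exactly the same is an assumption. The `mixup` helper works on plain lists rather than tensors and is illustrative only.

```python
import random


def mixup(x1, y1, x2, y2, alpha=0.75, rng=None):
    """Mix two (input, target) pairs with a Beta(alpha, alpha) ratio."""
    rng = rng or random.Random(0)
    lam = rng.betavariate(alpha, alpha)
    lam = max(lam, 1.0 - lam)  # keep the result closer to (x1, y1)
    mixed_input = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    mixed_target = [lam * a + (1.0 - lam) * b for a, b in zip(y1, y2)]
    return mixed_input, mixed_target, lam


mixed_input, mixed_target, lam = mixup(
    x1=[1.0, 0.0], y1=[1.0, 0.0],   # first example with a one-hot target
    x2=[0.0, 1.0], y2=[0.0, 1.0],   # second example with a one-hot target
)
print(mixed_target, lam)
```

Because both targets are probability vectors and the two weights sum to 1, the mixed target is again a valid probability vector.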
- The model is trained on the data generated by the MixUp algorithm. The labeled and unlabeled data are combined and processed together using the `interleave_tensors` function, which is a critical step in the MixMatch algorithm.
- The model is evaluated on the validation set after every epoch, using accuracy as the metric. Its performance on the test set is also evaluated after training is complete.
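The interleaving step used during training can be illustrated with the sketch below, which follows the scheme common to public MixMatch implementations: each batch is cut into equal chunks, and chunk `i` of the labeled batch is swapped with chunk `i` of the `i`-th unlabeled batch. The usual explanation is that every forward pass then sees a mix of labeled and unlabeled samples, which keeps batch-normalization statistics consistent. This is not the project's `interleave_tensors` function; the `interleave` helper here works on plain lists of equal length and is an assumption about how that function behaves.

```python
def interleave_offsets(batch_size, num_groups):
    """Split batch_size into num_groups near-equal chunks; return boundaries."""
    base, extra = divmod(batch_size, num_groups)
    # The last `extra` chunks receive one additional element each.
    sizes = [base + (1 if i >= num_groups - extra else 0)
             for i in range(num_groups)]
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)
    return offsets


def interleave(batches):
    """Swap the i-th chunk of batches[0] with the i-th chunk of batches[i]."""
    offsets = interleave_offsets(len(batches[0]), len(batches))
    chunks = [[b[offsets[p]:offsets[p + 1]] for p in range(len(batches))]
              for b in batches]
    for i in range(1, len(batches)):
        chunks[0][i], chunks[i][i] = chunks[i][i], chunks[0][i]
    # Flatten each row of chunks back into a single batch.
    return [[item for chunk in row for item in chunk] for row in chunks]


batches = [["l0", "l1", "l2"], ["u0", "u1", "u2"], ["v0", "v1", "v2"]]
mixed = interleave(batches)
print(mixed)
```

Applying `interleave` a second time undoes the swap, which is how the model outputs are typically restored to their original order after the forward pass.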
- After around 3 epochs, the model achieves 100% accuracy on the training set, and the validation accuracy also increases gradually. This happens without the model overfitting the training data, which we attribute to the augmentation applied to the training data.