Fix image classification scripts and Improve Fp16 tutorial #11533
@@ -102,9 +102,17 @@ python fine-tune.py --network resnet --num-layers 50 --pretrained-model imagenet
```

## Example training results
-Here is a plot to compare the training curves of a Resnet50 v1 network on the Imagenet 2012 dataset. These training jobs ran for 95 epochs with a batch size of 1024 using a learning rate of 0.4 decayed by a factor of 1 at epochs 30,60,90 and used Gluon. The only changes made for the float16 job when compared to the float32 job were that the network and data were cast to float16, and the multi-precision mode was used for optimizer. The final accuracies at 95th epoch were **76.598% for float16** and **76.486% for float32**. The difference is within what's normal random variation, and there is no reason to expect float16 to have better accuracy than float32 in general. This run was approximately **65% faster** to train with float16.
+Let us consider training a Resnet50 v1 model on the Imagenet 2012 dataset. For this model, GPU memory usage is close to the capacity of a V100 GPU with a batch size of 128 when using float32; using float16 allows a batch size of 256. Shared below are results using 8 V100 GPUs. Let us compare the three scenarios that arise here: float32 with a batch size of 1024, float16 with a batch size of 1024, and float16 with a batch size of 2048. These jobs trained for 90 epochs using a learning rate of 0.4 for the 1024 batch size and 0.8 for the 2048 batch size, decayed by a factor of 0.1 at the 30th, 60th and 80th epochs. The only changes made for the float16 jobs compared to the float32 job were that the network and data were cast to float16, and the multi-precision mode was used for the optimizer. The final accuracy at the 90th epoch and the time to train are tabulated below for these three scenarios, and the top-1 validation errors at the end of each epoch are plotted below as well.
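The decay described above (multiply the learning rate by 0.1 at epochs 30, 60 and 80) maps directly onto MXNet's `MultiFactorScheduler`, which counts optimizer updates rather than epochs. The sketch below is illustrative only and is not part of this diff; the dataset size and batch size are assumptions taken from the paragraph above.

```python
import mxnet as mx

# Step decay as described above: multiply the learning rate by 0.1 at epochs 30, 60 and 80.
# MultiFactorScheduler counts optimizer updates (batches), so convert epochs to updates.
num_examples = 1281167                      # ImageNet 2012 training set size (assumed here)
batch_size = 1024
updates_per_epoch = num_examples // batch_size

schedule = mx.lr_scheduler.MultiFactorScheduler(
    step=[30 * updates_per_epoch, 60 * updates_per_epoch, 80 * updates_per_epoch],
    factor=0.1)
schedule.base_lr = 0.4                      # starting learning rate for the 1024-batch jobs

print(schedule(1))                          # 0.4, i.e. before the first decay point
```

The resulting scheduler can then be handed to the optimizer (or to a `gluon.Trainer`) through its `lr_scheduler` argument.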
-![Training curves of Resnet50 v1 on Imagenet 2012](https://raw.githubusercontent.com/rahul003/web-data/03929a8beb8ac574f2392ed34cc6d4b2f052826a/mxnet/tutorials/mixed-precision/resnet50v1b_imagenet_fp16_fp32_training.png)
+Batch size | Data type | Top 1 validation accuracy | Time to train | Speedup |
+--- | --- | --- | --- | --- |
+1024 | float32 | 76.18% | 11.8 hrs | 1x |
+1024 | float16 | 76.34% | 7.3 hrs | 1.62x |
+2048 | float16 | 76.29% | 6.5 hrs | 1.82x |
+
+![Training curves of Resnet50 v1 on Imagenet 2012](https://github.com/rahul003/web-data/blob/d415abf4a1c6df007483169c81807c250135f9a5/mxnet/tutorials/mixed-precision/resnet50v1b_imagenet_fp16_fp32_training.png?raw=true)
Can we not use a personal repo for images?

Unfortunately I have a hard time getting anything merged into DMLC repos. There have been more than 25 commits since I opened a PR there 4 weeks back (dmlc/web-data#79), but it is still unnoticed!
+The differences in accuracy above are within normal random variation, and there is no reason to expect float16 to have better accuracy than float32 in general. As the plot indicates, training behaves similarly in all three cases even though no other hyperparameters were changed. The table also shows that float16 speeds up training both through faster computation and by allowing larger batch sizes.
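To make the "cast the network and data, use a multi-precision optimizer" recipe concrete, here is a minimal Gluon sketch. It is not the tutorial's actual training script: the model constructor, dummy data, and hyperparameters are placeholders, and a GPU context is assumed.

```python
import mxnet as mx
from mxnet import autograd, gluon

ctx = mx.gpu(0)                                   # float16 training assumes a GPU
net = gluon.model_zoo.vision.resnet50_v1(classes=1000)
net.cast('float16')                               # 1) cast the network parameters to float16
net.initialize(mx.init.Xavier(), ctx=ctx)

# 2) multi_precision keeps a float32 master copy of the weights for the update step
trainer = gluon.Trainer(net.collect_params(), 'sgd',
                        {'learning_rate': 0.4, 'momentum': 0.9,
                         'multi_precision': True})
loss_fn = gluon.loss.SoftmaxCrossEntropyLoss()

# dummy batches stand in for the real ImageNet DataLoader
X = mx.nd.random.uniform(shape=(8, 3, 224, 224))
y = mx.nd.array([0, 1, 2, 3, 4, 5, 6, 7])
train_data = gluon.data.DataLoader(gluon.data.ArrayDataset(X, y), batch_size=4)

for data, label in train_data:
    # 3) cast the inputs (and labels; exact for class indices below 2048) to float16
    data = data.astype('float16').as_in_context(ctx)
    label = label.astype('float16').as_in_context(ctx)
    with autograd.record():
        loss = loss_fn(net(data), label)
    loss.backward()
    trainer.step(data.shape[0])
```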
## Things to keep in mind
@@ -20,6 +20,7 @@
import os
import random
import logging
+import tarfile
logging.basicConfig(level=logging.INFO)

import mxnet as mx
@@ -91,9 +92,10 @@ def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='
def get_caltech101_data():
    url = "https://s3.us-east-2.amazonaws.com/mxnet-public/101_ObjectCategories.tar.gz"
    dataset_name = "101_ObjectCategories"
-    if not os.path.isdir("data"):
+    data_folder = "data"
+    if not os.path.isdir(data_folder):
        os.makedirs(data_folder)
-    tar_path = mx.gluon.utils.download(url, path='data')
+    tar_path = mx.gluon.utils.download(url, path=data_folder)
    if (not os.path.isdir(os.path.join(data_folder, "101_ObjectCategories")) or
            not os.path.isdir(os.path.join(data_folder, "101_ObjectCategories_test"))):
        tar = tarfile.open(tar_path, "r:gz")
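The corrected download logic is a download-once, extract-once pattern built on `mx.gluon.utils.download` and `tarfile`. Below is a self-contained sketch of that pattern; the helper name and the marker directory are illustrative, not the script's exact code.

```python
import os
import tarfile

import mxnet as mx

def download_and_extract(url, data_folder="data", marker_dir="101_ObjectCategories"):
    """Download a .tar.gz once and extract it only if the marker directory is missing."""
    if not os.path.isdir(data_folder):
        os.makedirs(data_folder)
    # gluon.utils.download returns the local file path and skips the
    # download when the file is already present
    tar_path = mx.gluon.utils.download(url, path=data_folder)
    extracted = os.path.join(data_folder, marker_dir)
    if not os.path.isdir(extracted):
        with tarfile.open(tar_path, "r:gz") as tar:
            tar.extractall(path=data_folder)
    return extracted

# usage mirroring the script above
caltech_root = download_and_extract(
    "https://s3.us-east-2.amazonaws.com/mxnet-public/101_ObjectCategories.tar.gz")
```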
@@ -109,18 +111,17 @@ def transform(image, label):
    # resize the shorter edge to 224, the longer edge will be greater or equal to 224
    resized = mx.image.resize_short(image, 224)
    # center and crop an area of size (224,224)
-    cropped, crop_info = mx.image.center_crop(resized, 224)
+    cropped, crop_info = mx.image.center_crop(resized, (224, 224))
    # transpose the channels to be (3,224,224)
-    transposed = nd.transpose(cropped, (2, 0, 1))
-    image = mx.nd.cast(image, dtype)
-    return image, label
+    transposed = mx.nd.transpose(cropped, (2, 0, 1))

Is dtype casting no longer necessary?

Not in this script, as it does astype() later. That will also be more general for any dataset iterator.
+    return transposed, label

    training_path, testing_path = get_caltech101_data()
    dataset_train = ImageFolderDataset(root=training_path, transform=transform)
    dataset_test = ImageFolderDataset(root=testing_path, transform=transform)

-    train_data = gluon.data.DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
-    test_data = gluon.data.DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
+    train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
+    test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
    return DataLoaderIter(train_data), DataLoaderIter(test_data)

class DummyIter(mx.io.DataIter):
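For reference, the dtype-agnostic `transform` above can be exercised on its own roughly as follows. The dataset path, batch size, and the `astype()` call are placeholders standing in for what the surrounding script does; as noted in the review reply, the cast to float16 happens later, on the batches.

```python
import mxnet as mx
from mxnet.gluon.data import DataLoader
from mxnet.gluon.data.vision import ImageFolderDataset

def transform(image, label):
    # resize the shorter edge to 224; the longer edge ends up >= 224
    resized = mx.image.resize_short(image, 224)
    # crop a centered (224, 224) region
    cropped, _ = mx.image.center_crop(resized, (224, 224))
    # channels-first layout (3, 224, 224); no dtype cast here, so the same
    # transform works for both the float32 and float16 runs
    return mx.nd.transpose(cropped, (2, 0, 1)), label

# assumes the Caltech101 images were already extracted to this folder
dataset = ImageFolderDataset(root="data/101_ObjectCategories", transform=transform)
loader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

for data, label in loader:
    data = data.astype('float16')   # the dtype cast happens here, at batch time
    break
```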
It's better to be specific about the overall hardware setup (it's not done on a DGX):

"Shared below are results using 8 V100 GPUs" -> "Shared below are results using 8 V100 GPUs on an AWS p3.16xlarge instance."

Ok, yeah.