Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix image classification scripts and Improve Fp16 tutorial #11533

Merged
merged 11 commits into from
Jul 29, 2018
12 changes: 10 additions & 2 deletions docs/faq/float16.md
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,17 @@ python fine-tune.py --network resnet --num-layers 50 --pretrained-model imagenet
```

## Example training results
Here is a plot to compare the training curves of a Resnet50 v1 network on the Imagenet 2012 dataset. These training jobs ran for 95 epochs with a batch size of 1024 using a learning rate of 0.4 decayed by a factor of 1 at epochs 30,60,90 and used Gluon. The only changes made for the float16 job when compared to the float32 job were that the network and data were cast to float16, and the multi-precision mode was used for optimizer. The final accuracies at 95th epoch were **76.598% for float16** and **76.486% for float32**. The difference is within what's normal random variation, and there is no reason to expect float16 to have better accuracy than float32 in general. This run was approximately **65% faster** to train with float16.
Let us consider training a Resnet50 v1 model on the Imagenet 2012 dataset. For this model, the GPU memory usage is close to the capacity of V100 GPU with a batch size of 128 when using float32. Using float16 allows the use of 256 batch size. Shared below are results using 8 V100 GPUs. Let us compare the three scenarios that arise here: float32 with 1024 batch size, float16 with 1024 batch size and float16 with 2048 batch size. These jobs trained for 90 epochs using a learning rate of 0.4 for 1024 batch size and 0.8 for 2048 batch size. This learning rate was decayed by a factor of 0.1 at the 30th, 60th and 80th epochs. The only changes made for the float16 jobs when compared to the float32 job were that the network and data were cast to float16, and the multi-precision mode was used for optimizer. The final accuracy at 90th epoch and the time to train are tabulated below for these three scenarios. The top-1 validation errors at the end of each epoch are also plotted below.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's better to be specific on the overall hardware setup (it's not done on DGX).

Shared below are results using 8 V100 GPUs ->
Shared below are results using 8 V100 GPUs on AWS p3.16xlarge instance.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok yeah


![Training curves of Resnet50 v1 on Imagenet 2012](https://raw.githubusercontent.com/rahul003/web-data/03929a8beb8ac574f2392ed34cc6d4b2f052826a/mxnet/tutorials/mixed-precision/resnet50v1b_imagenet_fp16_fp32_training.png)
Batch size | Data type | Top 1 Validation accuracy | Time to train | Speedup |
--- | --- | --- | --- | --- |
1024 | float32 | 76.18% | 11.8 hrs | 1 |
1024 | float16 | 76.34% | 7.3 hrs | 1.62x |
2048 | float16 | 76.29% | 6.5 hrs | 1.82x |

![Training curves of Resnet50 v1 on Imagenet 2012](https://github.com/rahul003/web-data/blob/d415abf4a1c6df007483169c81807c250135f9a5/mxnet/tutorials/mixed-precision/resnet50v1b_imagenet_fp16_fp32_training.png?raw=true)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we not use personal repo for images?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately I have a hard time getting anything merged into DMLC repos. Looks like there have been more than 25 commits since I opened this PR 4 weeks back there dmlc/web-data#79 but this is still unnoticed!


The differences in accuracies above are within normal random variation, and there is no reason to expect float16 to have better accuracy than float32 in general. As the plot indicates training behaves similarly for these cases, even though we didn't have to change any other hyperparameters. We can also see from the table that using float16 helps train faster through faster computation with float16 as well as allowing the use of larger batch sizes.

## Things to keep in mind

Expand Down
17 changes: 9 additions & 8 deletions example/gluon/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import os
import random
import logging
import tarfile
logging.basicConfig(level=logging.INFO)

import mxnet as mx
Expand Down Expand Up @@ -91,9 +92,10 @@ def get_imagenet_iterator(root, batch_size, num_workers, data_shape=224, dtype='
def get_caltech101_data():
url = "https://s3.us-east-2.amazonaws.com/mxnet-public/101_ObjectCategories.tar.gz"
dataset_name = "101_ObjectCategories"
if not os.path.isdir("data"):
data_folder = "data"
if not os.path.isdir(data_folder):
os.makedirs(data_folder)
tar_path = mx.gluon.utils.download(url, path='data')
tar_path = mx.gluon.utils.download(url, path=data_folder)
if (not os.path.isdir(os.path.join(data_folder, "101_ObjectCategories")) or
not os.path.isdir(os.path.join(data_folder, "101_ObjectCategories_test"))):
tar = tarfile.open(tar_path, "r:gz")
Expand All @@ -109,18 +111,17 @@ def transform(image, label):
# resize the shorter edge to 224, the longer edge will be greater or equal to 224
resized = mx.image.resize_short(image, 224)
# center and crop an area of size (224,224)
cropped, crop_info = mx.image.center_crop(resized, 224)
cropped, crop_info = mx.image.center_crop(resized, (224, 224))
# transpose the channels to be (3,224,224)
transposed = nd.transpose(cropped, (2, 0, 1))
image = mx.nd.cast(image, dtype)
return image, label
transposed = mx.nd.transpose(cropped, (2, 0, 1))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is dtype casting no longer necessary?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not in this script as it does astype() later. And that will also be more general for any dataset iterator

return transposed, label

training_path, testing_path = get_caltech101_data()
dataset_train = ImageFolderDataset(root=training_path, transform=transform)
dataset_test = ImageFolderDataset(root=testing_path, transform=transform)

train_data = gluon.data.DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
test_data = gluon.data.DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
train_data = DataLoader(dataset_train, batch_size, shuffle=True, num_workers=num_workers)
test_data = DataLoader(dataset_test, batch_size, shuffle=False, num_workers=num_workers)
return DataLoaderIter(train_data), DataLoaderIter(test_data)

class DummyIter(mx.io.DataIter):
Expand Down
9 changes: 6 additions & 3 deletions example/image-classification/benchmark_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,13 +79,16 @@ def score(network, dev, batch_size, num_batches, dtype):
logging.info('network: %s', net)
for d in devs:
logging.info('device: %s', d)
logged_fp16_warning = False
for b in batch_sizes:
for dtype in ['float32', 'float16']:
if d == mx.cpu() and dtype == 'float16':
#float16 is not supported on CPU
continue
elif net in ['inception-bn', 'alexnet'] and dt == 'float16':
logging.info('{} does not support float16'.format(net))
elif net in ['inception-bn', 'alexnet'] and dtype == 'float16':
if not logged_fp16_warning:
logging.info('Model definition for {} does not support float16'.format(net))
logged_fp16_warning = True
else:
speed = score(network=net, dev=d, batch_size=b, num_batches=10, dtype=dtype)
logging.info('batch size %2d, dtype %s image/sec: %f', b, dtype, speed)
logging.info('batch size %2d, dtype %s, images/sec: %f', b, dtype, speed)
4 changes: 2 additions & 2 deletions example/image-classification/fine-tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ def get_fine_tune_model(symbol, arg_params, num_classes, layer_name, dtype='floa
parser.add_argument('--layer-before-fullc', type=str, default='flatten0',
help='the name of the layer before the last fullc layer')\

# use less augmentations for fine-tune
data.set_data_aug_level(parser, 1)
# use less augmentations for fine-tune. by default here it uses no augmentations

# use a small learning rate and less regularizations
parser.set_defaults(image_shape='3,224,224',
num_epochs=30,
Expand Down