Skip to content

Commit

Permalink
Add standard ResNet data augmentation for ImageRecordIter (apache#11027)
Browse files Browse the repository at this point in the history
* add resnet augmentation

* add test

* fix scope

* fix warning

* fix lint

* fix lint

* add color jitter and pca noise

* fix center crop

* merge

* fix lint

* Trigger CI

* fix

* fix augmentation implementation

* add checks for parameters

* modify training script

* fix compile error

* Trigger CI

* Trigger CI

* modify error message

* Trigger CI

* Trigger CI

* Trigger CI

* improve script in example

* fix script

* clear code

* Trigger CI

* set min_aspect_ratio to optional, move rotation and pad before random resized crop

* fix

* Trigger CI

* Trigger CI

* Trigger CI

* fix default values

* Trigger CI
  • Loading branch information
hetong007 authored and zheng-da committed Jun 28, 2018
1 parent bfaca2b commit 977026c
Show file tree
Hide file tree
Showing 4 changed files with 435 additions and 31 deletions.
48 changes: 37 additions & 11 deletions example/image-classification/common/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,23 @@ def add_data_args(parser):
def add_data_aug_args(parser):
aug = parser.add_argument_group(
'Image augmentations', 'implemented in src/io/image_aug_default.cc')
aug.add_argument('--random-crop', type=int, default=1,
aug.add_argument('--random-crop', type=int, default=0,
help='if or not randomly crop the image')
aug.add_argument('--random-mirror', type=int, default=1,
aug.add_argument('--random-mirror', type=int, default=0,
help='if or not randomly flip horizontally')
aug.add_argument('--max-random-h', type=int, default=0,
help='max change of hue, whose range is [0, 180]')
aug.add_argument('--max-random-s', type=int, default=0,
help='max change of saturation, whose range is [0, 255]')
aug.add_argument('--max-random-l', type=int, default=0,
help='max change of intensity, whose range is [0, 255]')
aug.add_argument('--min-random-aspect-ratio', type=float, default=None,
help='min value of aspect ratio, whose value is either None or a positive value.')
aug.add_argument('--max-random-aspect-ratio', type=float, default=0,
help='max change of aspect ratio, whose range is [0, 1]')
help='max value of aspect ratio. If min_random_aspect_ratio is None, '
'the aspect ratio range is [1-max_random_aspect_ratio, '
'1+max_random_aspect_ratio], otherwise it is '
'[min_random_aspect_ratio, max_random_aspect_ratio].')
aug.add_argument('--max-random-rotate-angle', type=int, default=0,
help='max angle to rotate, whose range is [0, 360]')
aug.add_argument('--max-random-shear-ratio', type=float, default=0,
Expand All @@ -63,16 +68,28 @@ def add_data_aug_args(parser):
help='max ratio to scale')
aug.add_argument('--min-random-scale', type=float, default=1,
help='min ratio to scale, should >= img_size/input_shape. otherwise use --pad-size')
aug.add_argument('--max-random-area', type=float, default=1,
help='max area to crop in random resized crop, whose range is [0, 1]')
aug.add_argument('--min-random-area', type=float, default=1,
help='min area to crop in random resized crop, whose range is [0, 1]')
aug.add_argument('--brightness', type=float, default=0,
help='brightness jittering, whose range is [0, 1]')
aug.add_argument('--contrast', type=float, default=0,
help='contrast jittering, whose range is [0, 1]')
aug.add_argument('--saturation', type=float, default=0,
help='saturation jittering, whose range is [0, 1]')
aug.add_argument('--pca-noise', type=float, default=0,
help='pca noise, whose range is [0, 1]')
aug.add_argument('--random-resized-crop', type=int, default=0,
help='whether to use random resized crop')
return aug

def set_data_aug_level(aug, level):
if level >= 1:
aug.set_defaults(random_crop=1, random_mirror=1)
if level >= 2:
aug.set_defaults(max_random_h=36, max_random_s=50, max_random_l=50)
if level >= 3:
aug.set_defaults(max_random_rotate_angle=10, max_random_shear_ratio=0.1, max_random_aspect_ratio=0.25)

def set_resnet_aug(aug):
# standard data augmentation setting for resnet training
aug.set_defaults(random_crop=1, random_resized_crop=1)
aug.set_defaults(min_random_area=0.08)
aug.set_defaults(max_random_aspect_ratio=4./3., min_random_aspect_ratio=3./4.)
aug.set_defaults(brightness=0.4, contrast=0.4, saturation=0.4, pca_noise=0.1)

class SyntheticDataIter(DataIter):
def __init__(self, num_classes, data_shape, max_iter, dtype):
Expand Down Expand Up @@ -135,8 +152,16 @@ def get_rec_iter(args, kv=None):
max_random_scale = args.max_random_scale,
pad = args.pad_size,
fill_value = 127,
random_resized_crop = args.random_resized_crop,
min_random_scale = args.min_random_scale,
max_aspect_ratio = args.max_random_aspect_ratio,
min_aspect_ratio = args.min_random_aspect_ratio,
max_random_area = args.max_random_area,
min_random_area = args.min_random_area,
brightness = args.brightness,
contrast = args.contrast,
saturation = args.saturation,
pca_noise = args.pca_noise,
random_h = args.max_random_h,
random_s = args.max_random_s,
random_l = args.max_random_l,
Expand All @@ -156,6 +181,7 @@ def get_rec_iter(args, kv=None):
mean_r = rgb_mean[0],
mean_g = rgb_mean[1],
mean_b = rgb_mean[2],
resize = 256,
data_name = 'data',
label_name = 'softmax_label',
batch_size = args.batch_size,
Expand Down
4 changes: 2 additions & 2 deletions example/image-classification/train_imagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
fit.add_fit_args(parser)
data.add_data_args(parser)
data.add_data_aug_args(parser)
# use a large aug level
data.set_data_aug_level(parser, 3)
# uncomment to set standard augmentation for resnet training
# data.set_resnet_aug(parser)
parser.set_defaults(
# network
network = 'resnet',
Expand Down
Loading

0 comments on commit 977026c

Please sign in to comment.