Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add RandomRotate transform #215

Merged
merged 13 commits into from
Nov 7, 2020
85 changes: 85 additions & 0 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,91 @@ def __repr__(self):
return self.__class__.__name__ + f'(crop_size={self.crop_size})'


@PIPELINES.register_module()
class RandomRotate(object):
"""Rotate the image & seg.

Args:
rotate_ratio (float): The rotation probability.
degree (float, tuple[float]): Range of degrees to select from. If
degree is a number instead of tuple like (min, max),
the range of degree will be (``-degree``, ``+degree``)
pad_val (float, optional): Padding value. Default: 0.
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
seg_pad_val (float, optional): Padding value of segmentation map.
Default: 255.
center (tuple[float], optional): Center point (w, h) of the rotation in
the source image. If not specified, the center of the image will be
used. Default: None.
auto_bound (bool): Whether to adjust the image size to cover the whole
rotated image. Default: False
"""

def __init__(self,
rotate_ratio,
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
degree,
pad_val=0,
seg_pad_val=255,
center=None,
auto_bound=False):
self.rotate_ratio = rotate_ratio
assert rotate_ratio >= 0 and rotate_ratio <= 1
if isinstance(degree, (float, int)):
assert degree > 0, f'degree {degree} should be positive'
self.degree = (-degree, degree)
else:
self.degree = degree
assert len(self.degree) == 2, f'degree {self.degree} should be a ' \
f'tuple of (min, max)'
self.pal_val = pad_val
self.seg_pad_val = seg_pad_val
self.center = center
self.auto_bound = auto_bound

def __call__(self, results):
"""Call function to flip bounding boxes, masks, semantic segmentation
maps.

Args:
results (dict): Result dict from loading pipeline.

Returns:
dict: Flipped results, 'flip', 'flip_direction' keys are added into
xvjiarui marked this conversation as resolved.
Show resolved Hide resolved
result dict.
"""

rotate = True if np.random.rand() < self.rotate_ratio else False
degree = np.random.uniform(min(*self.degree), max(*self.degree))
if rotate:
# rotate image
results['img'] = mmcv.imrotate(
results['img'],
angle=degree,
border_value=self.pal_val,
center=self.center,
auto_bound=self.auto_bound)

# rotate segs
for key in results.get('seg_fields', []):
results[key] = mmcv.imrotate(
results[key],
angle=degree,
border_value=self.seg_pad_val,
center=self.center,
auto_bound=self.auto_bound,
interpolation='nearest')
return results

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(rotate_ratio={self.rotate_ratio}, ' \
f'degree={self.degree}, ' \
f'pad_val={self.pal_val}, ' \
f'seg_pad_val={self.seg_pad_val}, ' \
f'center={self.center}, ' \
f'auto_bound={self.auto_bound})'
return repr_str


@PIPELINES.register_module()
class SegRescale(object):
"""Rescale semantic segmentation maps.
Expand Down
42 changes: 42 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,48 @@ def test_pad():
assert img_shape[1] % 32 == 0


def test_rotate():
# test assertion degree should be tuple[float] or float
with pytest.raises(AssertionError):
transform = dict(type='RandomRotate', rotate_ratio=0.5, degree=-10)
build_from_cfg(transform, PIPELINES)
# test assertion degree should be tuple[float] or float
with pytest.raises(AssertionError):
transform = dict(
type='RandomRotate', rotate_ratio=0.5, degree=(10., 20., 30.))
build_from_cfg(transform, PIPELINES)

transform = dict(type='RandomRotate', degree=10., rotate_ratio=1.)
transform = build_from_cfg(transform, PIPELINES)

assert str(transform) == f'RandomRotate(' \
f'rotate_ratio={1.}, ' \
f'degree=({-10.}, {10.}), ' \
f'pad_val={0}, ' \
f'seg_pad_val={255}, ' \
f'center={None}, ' \
f'auto_bound={False})'

results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
h, w, _ = img.shape
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0

results = transform(results)
assert results['img'].shape[:2] == (h, w)
assert results['gt_semantic_seg'].shape[:2] == (h, w)


def test_normalize():
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53],
Expand Down