Skip to content

Commit

Permalink
[Feature] Add Rgb2Gray transform (#227)
Browse files Browse the repository at this point in the history
* add transformer Rgb2Gray

* restore

* fix self.weights

* restore

* fix code

* restore

* fix syntax error

* restore
  • Loading branch information
yamengxi authored Nov 9, 2020
1 parent 3d18775 commit 7c68bca
Show file tree
Hide file tree
Showing 2 changed files with 122 additions and 0 deletions.
55 changes: 55 additions & 0 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,6 +548,61 @@ def __repr__(self):
return repr_str


@PIPELINES.register_module()
class RGB2Gray(object):
    """Convert RGB image to grayscale image.

    This transform calculates the weighted sum of the input image channels
    with ``weights`` (a weighted mean when the weights sum to 1) and then
    expands the resulting single channel to ``out_channels`` identical
    channels. When ``out_channels`` is None, the number of output channels
    is the same as the number of input channels.

    Args:
        out_channels (int, optional): Expected number of output channels
            after transforming. Must be a positive integer when given.
            Default: None.
        weights (tuple[float]): The weights to calculate the weighted mean.
            Its length must equal the number of channels of the input
            image. Default: (0.299, 0.587, 0.114).
    """

    def __init__(self, out_channels=None, weights=(0.299, 0.587, 0.114)):
        # Validation deliberately uses ``assert`` so that invalid configs
        # keep raising AssertionError, as existing callers expect.
        # A non-integer ``out_channels`` (e.g. 2.5) is rejected here instead
        # of failing later inside ``ndarray.repeat``.
        assert out_channels is None or (
            isinstance(out_channels, (int, np.integer))
            and out_channels > 0), (
                'out_channels must be None or a positive integer, '
                f'but got {out_channels!r}')
        self.out_channels = out_channels

        assert isinstance(weights, tuple), (
            f'weights must be a tuple, but got {type(weights).__name__}')
        for item in weights:
            assert isinstance(item, (float, int)), (
                f'each weight must be a float or an int, but got {item!r}')
        self.weights = weights

    def __call__(self, results):
        """Call function to convert RGB image to grayscale image.

        Args:
            results (dict): Result dict from loading pipeline. Only the
                ``'img'`` entry is read; it must be a 3-dimensional array
                whose last axis has ``len(self.weights)`` channels.

        Returns:
            dict: Result dict with grayscale image. ``'img'`` is replaced
                by the converted array and ``'img_shape'`` by its shape.
                The array dtype follows NumPy's promotion of the image
                dtype with the dtype of the weights.
        """
        img = results['img']
        assert len(img.shape) == 3, (
            f'img must have 3 dimensions, but got shape {img.shape}')
        assert img.shape[2] == len(self.weights), (
            f'img has {img.shape[2]} channels, but {len(self.weights)} '
            'weights were given')

        # Shape (1, 1, C) so the weights broadcast over the two leading axes.
        weights = np.array(self.weights).reshape((1, 1, -1))
        # Collapse the channel axis to a single grayscale channel.
        img = (img * weights).sum(2, keepdims=True)

        # Duplicate the grayscale channel to reach the requested channel
        # count (the input channel count when ``out_channels`` is None).
        if self.out_channels is None:
            img = img.repeat(weights.shape[2], axis=2)
        else:
            img = img.repeat(self.out_channels, axis=2)

        results['img'] = img
        results['img_shape'] = img.shape

        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(out_channels={self.out_channels}, ' \
                    f'weights={self.weights})'
        return repr_str


@PIPELINES.register_module()
class SegRescale(object):
"""Rescale semantic segmentation maps.
Expand Down
67 changes: 67 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,73 @@ def test_normalize():
assert np.allclose(results['img'], converted_img)


def test_rgb2gray():
    """Check RGB2Gray argument validation, repr and output shapes."""
    default_weights = (0.299, 0.587, 0.114)

    def _make_results():
        # Build a fresh result dict from ../data/color.jpg and
        # ../data/seg.png, resolved relative to this test file.
        image = mmcv.imread(
            osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
        seg_map = np.array(
            Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))
        return dict(
            img=image,
            gt_semantic_seg=seg_map,
            seg_fields=['gt_semantic_seg'],
            img_shape=image.shape,
            ori_shape=image.shape,
            # Set initial values for default meta_keys
            pad_shape=image.shape,
            scale_factor=1.0)

    # test assertion out_channels should be greater than 0
    with pytest.raises(AssertionError):
        build_from_cfg(dict(type='RGB2Gray', out_channels=-1), PIPELINES)
    # test assertion weights should be tuple[float]
    with pytest.raises(AssertionError):
        build_from_cfg(
            dict(type='RGB2Gray', out_channels=1, weights=1.1), PIPELINES)

    # Each case pairs a config with the out_channels value it implies:
    # first the default (None, keeps the input channel count), then 2.
    cases = (
        (dict(type='RGB2Gray'), None),
        (dict(type='RGB2Gray', out_channels=2), 2),
    )
    for cfg, out_channels in cases:
        transform = build_from_cfg(cfg, PIPELINES)

        expected_repr = (f'RGB2Gray(out_channels={out_channels}, '
                         f'weights={default_weights})')
        assert str(transform) == expected_repr

        results = _make_results()
        h, w, c = results['img'].shape
        expected_channels = c if out_channels is None else out_channels

        results = transform(results)
        assert results['img'].shape == (h, w, expected_channels)
        assert results['img_shape'] == (h, w, expected_channels)
        # The original shape record must be left untouched.
        assert results['ori_shape'] == (h, w, c)


def test_seg_rescale():
results = dict()
seg = np.array(
Expand Down

0 comments on commit 7c68bca

Please sign in to comment.