Skip to content

Commit

Permalink
Add semantic segmentation (Mask2Former) code (#186)
Browse files Browse the repository at this point in the history
Add semantic segmentation (Mask2Former based on ViT-Adapter) code + update demo notebook for segmentation with a dedicated section.
  • Loading branch information
patricklabatut authored Aug 31, 2023
1 parent d5b0405 commit 91d8cd8
Show file tree
Hide file tree
Showing 40 changed files with 6,335 additions and 13 deletions.
8 changes: 8 additions & 0 deletions dinov2/eval/segmentation_m2f/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .core import * # noqa: F403
from .models import * # noqa: F403
from .ops import * # noqa: F403
11 changes: 11 additions & 0 deletions dinov2/eval/segmentation_m2f/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmseg.core.evaluation import * # noqa: F403
from mmseg.core.seg import * # noqa: F403

from .anchor import * # noqa: F403
from .box import * # noqa: F403
from .utils import * # noqa: F403
6 changes: 6 additions & 0 deletions dinov2/eval/segmentation_m2f/core/anchor/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .point_generator import MlvlPointGenerator # noqa: F403
21 changes: 21 additions & 0 deletions dinov2/eval/segmentation_m2f/core/anchor/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import warnings

from mmcv.utils import Registry, build_from_cfg

PRIOR_GENERATORS = Registry("Generator for anchors and points")

ANCHOR_GENERATORS = PRIOR_GENERATORS


def build_prior_generator(cfg, default_args=None):
return build_from_cfg(cfg, PRIOR_GENERATORS, default_args)


def build_anchor_generator(cfg, default_args=None):
warnings.warn("``build_anchor_generator`` would be deprecated soon, please use " "``build_prior_generator`` ")
return build_prior_generator(cfg, default_args=default_args)
205 changes: 205 additions & 0 deletions dinov2/eval/segmentation_m2f/core/anchor/point_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

import numpy as np
import torch
from torch.nn.modules.utils import _pair

from .builder import PRIOR_GENERATORS


@PRIOR_GENERATORS.register_module()
class MlvlPointGenerator:
"""Standard points generator for multi-level (Mlvl) feature maps in 2D
points-based detectors.
Args:
strides (list[int] | list[tuple[int, int]]): Strides of anchors
in multiple feature levels in order (w, h).
offset (float): The offset of points, the value is normalized with
corresponding stride. Defaults to 0.5.
"""

def __init__(self, strides, offset=0.5):
self.strides = [_pair(stride) for stride in strides]
self.offset = offset

@property
def num_levels(self):
"""int: number of feature levels that the generator will be applied"""
return len(self.strides)

@property
def num_base_priors(self):
"""list[int]: The number of priors (points) at a point
on the feature grid"""
return [1 for _ in range(len(self.strides))]

def _meshgrid(self, x, y, row_major=True):
yy, xx = torch.meshgrid(y, x)
if row_major:
# warning .flatten() would cause error in ONNX exporting
# have to use reshape here
return xx.reshape(-1), yy.reshape(-1)

else:
return yy.reshape(-1), xx.reshape(-1)

def grid_priors(self, featmap_sizes, dtype=torch.float32, device="cuda", with_stride=False):
"""Generate grid points of multiple feature levels.
Args:
featmap_sizes (list[tuple]): List of feature map sizes in
multiple feature levels, each size arrange as
as (h, w).
dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
device (str): The device where the anchors will be put on.
with_stride (bool): Whether to concatenate the stride to
the last dimension of points.
Return:
list[torch.Tensor]: Points of multiple feature levels.
The sizes of each tensor should be (N, 2) when with stride is
``False``, where N = width * height, width and height
are the sizes of the corresponding feature level,
and the last dimension 2 represent (coord_x, coord_y),
otherwise the shape should be (N, 4),
and the last dimension 4 represent
(coord_x, coord_y, stride_w, stride_h).
"""

assert self.num_levels == len(featmap_sizes)
multi_level_priors = []
for i in range(self.num_levels):
priors = self.single_level_grid_priors(
featmap_sizes[i], level_idx=i, dtype=dtype, device=device, with_stride=with_stride
)
multi_level_priors.append(priors)
return multi_level_priors

def single_level_grid_priors(self, featmap_size, level_idx, dtype=torch.float32, device="cuda", with_stride=False):
"""Generate grid Points of a single level.
Note:
This function is usually called by method ``self.grid_priors``.
Args:
featmap_size (tuple[int]): Size of the feature maps, arrange as
(h, w).
level_idx (int): The index of corresponding feature map level.
dtype (:obj:`dtype`): Dtype of priors. Default: torch.float32.
device (str, optional): The device the tensor will be put on.
Defaults to 'cuda'.
with_stride (bool): Concatenate the stride to the last dimension
of points.
Return:
Tensor: Points of single feature levels.
The shape of tensor should be (N, 2) when with stride is
``False``, where N = width * height, width and height
are the sizes of the corresponding feature level,
and the last dimension 2 represent (coord_x, coord_y),
otherwise the shape should be (N, 4),
and the last dimension 4 represent
(coord_x, coord_y, stride_w, stride_h).
"""
feat_h, feat_w = featmap_size
stride_w, stride_h = self.strides[level_idx]
shift_x = (torch.arange(0, feat_w, device=device) + self.offset) * stride_w
# keep featmap_size as Tensor instead of int, so that we
# can convert to ONNX correctly
shift_x = shift_x.to(dtype)

shift_y = (torch.arange(0, feat_h, device=device) + self.offset) * stride_h
# keep featmap_size as Tensor instead of int, so that we
# can convert to ONNX correctly
shift_y = shift_y.to(dtype)
shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
if not with_stride:
shifts = torch.stack([shift_xx, shift_yy], dim=-1)
else:
# use `shape[0]` instead of `len(shift_xx)` for ONNX export
stride_w = shift_xx.new_full((shift_xx.shape[0],), stride_w).to(dtype)
stride_h = shift_xx.new_full((shift_yy.shape[0],), stride_h).to(dtype)
shifts = torch.stack([shift_xx, shift_yy, stride_w, stride_h], dim=-1)
all_points = shifts.to(device)
return all_points

def valid_flags(self, featmap_sizes, pad_shape, device="cuda"):
"""Generate valid flags of points of multiple feature levels.
Args:
featmap_sizes (list(tuple)): List of feature map sizes in
multiple feature levels, each size arrange as
as (h, w).
pad_shape (tuple(int)): The padded shape of the image,
arrange as (h, w).
device (str): The device where the anchors will be put on.
Return:
list(torch.Tensor): Valid flags of points of multiple levels.
"""
assert self.num_levels == len(featmap_sizes)
multi_level_flags = []
for i in range(self.num_levels):
point_stride = self.strides[i]
feat_h, feat_w = featmap_sizes[i]
h, w = pad_shape[:2]
valid_feat_h = min(int(np.ceil(h / point_stride[1])), feat_h)
valid_feat_w = min(int(np.ceil(w / point_stride[0])), feat_w)
flags = self.single_level_valid_flags((feat_h, feat_w), (valid_feat_h, valid_feat_w), device=device)
multi_level_flags.append(flags)
return multi_level_flags

def single_level_valid_flags(self, featmap_size, valid_size, device="cuda"):
"""Generate the valid flags of points of a single feature map.
Args:
featmap_size (tuple[int]): The size of feature maps, arrange as
as (h, w).
valid_size (tuple[int]): The valid size of the feature maps.
The size arrange as as (h, w).
device (str, optional): The device where the flags will be put on.
Defaults to 'cuda'.
Returns:
torch.Tensor: The valid flags of each points in a single level \
feature map.
"""
feat_h, feat_w = featmap_size
valid_h, valid_w = valid_size
assert valid_h <= feat_h and valid_w <= feat_w
valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
valid_x[:valid_w] = 1
valid_y[:valid_h] = 1
valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
valid = valid_xx & valid_yy
return valid

def sparse_priors(self, prior_idxs, featmap_size, level_idx, dtype=torch.float32, device="cuda"):
"""Generate sparse points according to the ``prior_idxs``.
Args:
prior_idxs (Tensor): The index of corresponding anchors
in the feature map.
featmap_size (tuple[int]): feature map size arrange as (w, h).
level_idx (int): The level index of corresponding feature
map.
dtype (obj:`torch.dtype`): Date type of points. Defaults to
``torch.float32``.
device (obj:`torch.device`): The device where the points is
located.
Returns:
Tensor: Anchor with shape (N, 2), N should be equal to
the length of ``prior_idxs``. And last dimension
2 represent (coord_x, coord_y).
"""
height, width = featmap_size
x = (prior_idxs % width + self.offset) * self.strides[level_idx][0]
y = ((prior_idxs // width) % height + self.offset) * self.strides[level_idx][1]
prioris = torch.stack([x, y], 1).to(dtype)
prioris = prioris.to(device)
return prioris
7 changes: 7 additions & 0 deletions dinov2/eval/segmentation_m2f/core/box/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .builder import * # noqa: F403
from .samplers import MaskPseudoSampler # noqa: F403
19 changes: 19 additions & 0 deletions dinov2/eval/segmentation_m2f/core/box/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from mmcv.utils import Registry, build_from_cfg

BBOX_SAMPLERS = Registry("bbox_sampler")
BBOX_CODERS = Registry("bbox_coder")


def build_sampler(cfg, **default_args):
"""Builder of box sampler."""
return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)


def build_bbox_coder(cfg, **default_args):
"""Builder of box coder."""
return build_from_cfg(cfg, BBOX_CODERS, default_args)
6 changes: 6 additions & 0 deletions dinov2/eval/segmentation_m2f/core/box/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from .mask_pseudo_sampler import MaskPseudoSampler # noqa: F403
92 changes: 92 additions & 0 deletions dinov2/eval/segmentation_m2f/core/box/samplers/base_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from abc import ABCMeta, abstractmethod

import torch

from .sampling_result import SamplingResult


class BaseSampler(metaclass=ABCMeta):
"""Base class of samplers."""

def __init__(self, num, pos_fraction, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs):
self.num = num
self.pos_fraction = pos_fraction
self.neg_pos_ub = neg_pos_ub
self.add_gt_as_proposals = add_gt_as_proposals
self.pos_sampler = self
self.neg_sampler = self

@abstractmethod
def _sample_pos(self, assign_result, num_expected, **kwargs):
"""Sample positive samples."""
pass

@abstractmethod
def _sample_neg(self, assign_result, num_expected, **kwargs):
"""Sample negative samples."""
pass

def sample(self, assign_result, bboxes, gt_bboxes, gt_labels=None, **kwargs):
"""Sample positive and negative bboxes.
This is a simple implementation of bbox sampling given candidates,
assigning results and ground truth bboxes.
Args:
assign_result (:obj:`AssignResult`): Bbox assigning results.
bboxes (Tensor): Boxes to be sampled from.
gt_bboxes (Tensor): Ground truth bboxes.
gt_labels (Tensor, optional): Class labels of ground truth bboxes.
Returns:
:obj:`SamplingResult`: Sampling result.
Example:
>>> from mmdet.core.bbox import RandomSampler
>>> from mmdet.core.bbox import AssignResult
>>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
>>> rng = ensure_rng(None)
>>> assign_result = AssignResult.random(rng=rng)
>>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
>>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
>>> gt_labels = None
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
>>> add_gt_as_proposals=False)
>>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
"""
if len(bboxes.shape) < 2:
bboxes = bboxes[None, :]

bboxes = bboxes[:, :4]

gt_flags = bboxes.new_zeros((bboxes.shape[0],), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
if gt_labels is None:
raise ValueError("gt_labels must be given when add_gt_as_proposals is True")
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
gt_flags = torch.cat([gt_ones, gt_flags])

num_expected_pos = int(self.num * self.pos_fraction)
pos_inds = self.pos_sampler._sample_pos(assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
# We found that sampled indices have duplicated items occasionally.
# (may be a bug of PyTorch)
pos_inds = pos_inds.unique()
num_sampled_pos = pos_inds.numel()
num_expected_neg = self.num - num_sampled_pos
if self.neg_pos_ub >= 0:
_pos = max(1, num_sampled_pos)
neg_upper_bound = int(self.neg_pos_ub * _pos)
if num_expected_neg > neg_upper_bound:
num_expected_neg = neg_upper_bound
neg_inds = self.neg_sampler._sample_neg(assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()

sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags)
return sampling_result
Loading

0 comments on commit 91d8cd8

Please sign in to comment.