support return heatmaps and features for bu models (open-mmlab#229)

jin-s13 authored Nov 4, 2020
1 parent bb029d1 commit b97df6b

Showing 7 changed files with 85 additions and 34 deletions.
13 changes: 11 additions & 2 deletions demo/bottom_up_img_demo.py
@@ -48,15 +48,24 @@ def main():

img_keys = list(coco.imgs.keys())

# optional
return_heatmap = False

# e.g. use ('backbone', ) to return backbone feature
output_layer_names = None

# process each image
for i in range(len(img_keys)):
image_id = img_keys[i]
image = coco.loadImgs(image_id)[0]
image_name = os.path.join(args.img_root, image['file_name'])

# test a single image, with a list of bboxes.
pose_results, heatmaps = inference_bottom_up_pose_model(
pose_model, image_name)
pose_results, returned_outputs = inference_bottom_up_pose_model(
pose_model,
image_name,
return_heatmap=return_heatmap,
outputs=output_layer_names)

if args.out_img_root == '':
out_file = None
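For reference, a minimal usage sketch (not part of the commit) showing how the heatmaps added by this change can be retrieved, assuming `pose_model` and `image_name` are prepared as in the demo above:

    return_heatmap = True
    output_layer_names = None  # request no extra layer outputs

    pose_results, returned_outputs = inference_bottom_up_pose_model(
        pose_model,
        image_name,
        return_heatmap=return_heatmap,
        outputs=output_layer_names)

    # One dict is returned per call; the aggregated heatmaps are stored under
    # the 'heatmap' key as an np.ndarray of shape [N, K, H, W].
    heatmaps = returned_outputs[0]['heatmap']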
13 changes: 11 additions & 2 deletions demo/bottom_up_video_demo.py
@@ -57,13 +57,22 @@ def main():
f'vis_{os.path.basename(args.video_path)}'), fourcc,
fps, size)

# optional
return_heatmap = False

# e.g. use ('backbone', ) to return backbone feature
output_layer_names = None

while (cap.isOpened()):
flag, img = cap.read()
if not flag:
break

pose_results, heatmaps = inference_bottom_up_pose_model(
pose_model, img)
pose_results, returned_outputs = inference_bottom_up_pose_model(
pose_model,
img,
return_heatmap=return_heatmap,
outputs=output_layer_names)

# show the results
vis_img = vis_pose_result(
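Similarly, a short sketch (illustration only, not part of the diff) of using the new `outputs` argument in the video loop to collect backbone features per frame; the layer name 'backbone' follows the hint in the comment above:

    output_layer_names = ('backbone', )

    while cap.isOpened():
        flag, img = cap.read()
        if not flag:
            break
        pose_results, returned_outputs = inference_bottom_up_pose_model(
            pose_model,
            img,
            return_heatmap=False,
            outputs=output_layer_names)
        # OutputHook is created with as_tensor=False, so hooked outputs arrive
        # as numpy arrays (or a list of numpy arrays for multi-output layers).
        backbone_feats = returned_outputs[0]['backbone']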
2 changes: 1 addition & 1 deletion demo/top_down_img_demo.py
@@ -77,7 +77,7 @@ def main():
pose_model,
image_name,
person_bboxes,
bbox_thr=args.bbox_thr,
bbox_thr=None,
format='xywh',
dataset=dataset,
return_heatmap=return_heatmap,
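The demo now passes bbox_thr=None, which keeps every detected box. As a hypothetical sketch of the alternative: when a threshold is supplied, each box is expected to carry a confidence score as its fifth element, which the new assert in mmpose/apis/inference.py (next file) checks before filtering:

    import numpy as np

    # x, y, w, h, score -- the score column is required when bbox_thr is set
    person_bboxes = np.array([[50.0, 60.0, 120.0, 200.0, 0.9]])

    pose_results, returned_outputs = inference_top_down_pose_model(
        pose_model,
        image_name,
        person_bboxes,
        bbox_thr=0.3,
        format='xywh',
        dataset=dataset,
        return_heatmap=return_heatmap,
        outputs=output_layer_names)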
57 changes: 38 additions & 19 deletions mmpose/apis/inference.py
@@ -233,10 +233,10 @@ def _inference_single_pose_model(model,
# forward the model
with torch.no_grad():
all_preds, _, _, heatmap = model(
return_loss=False,
return_heatmap=return_heatmap,
img=data['img'],
img_metas=data['img_metas'])
img_metas=data['img_metas'],
return_loss=False,
return_heatmap=return_heatmap)
return all_preds[0], heatmap


@@ -290,9 +290,10 @@ def inference_top_down_pose_model(model,

if len(person_bboxes) > 0:
if bbox_thr is not None:
assert person_bboxes.shape[1] == 5
person_bboxes = person_bboxes[person_bboxes[:, 4] > bbox_thr]

with OutputHook(model, outputs=outputs, as_tensor=True) as h:
with OutputHook(model, outputs=outputs, as_tensor=False) as h:
for bbox in person_bboxes:
pose, heatmap = _inference_single_pose_model(
model,
@@ -315,7 +316,10 @@
return pose_results, returned_outputs


def inference_bottom_up_pose_model(model, img_or_path):
def inference_bottom_up_pose_model(model,
img_or_path,
return_heatmap=False,
outputs=None):
"""Inference a single image.
num_people: P
@@ -326,16 +330,22 @@ def inference_bottom_up_pose_model(model, img_or_path):
Args:
model (nn.Module): The loaded pose model.
        img_or_path (str | np.ndarray): Image name or loaded image.
        return_heatmap (bool): Flag to return heatmap. Default: False.
        outputs (list(str) | tuple(str)): Names of layers whose outputs
            need to be returned. Default: None.
Returns:
list[ndarray]: The predicted pose info.
The length of the list
is the number of people (P). Each item in the
list is a ndarray, containing each person's
pose (ndarray[Kx3]): x, y, score
The length of the list is the number of people (P).
Each item in the list is a ndarray, containing each person's
pose (ndarray[Kx3]): x, y, score.
list[dict[np.ndarray[N, K, H, W] | torch.tensor[N, K, H, W]]]:
Output feature maps from layers specified in `outputs`.
Includes 'heatmap' if `return_heatmap` is True.
"""
pose_results = []
returned_outputs = []

cfg = model.cfg
device = next(model.parameters()).device

@@ -366,17 +376,26 @@
# just get the actual data from DataContainer
data['img_metas'] = data['img_metas'].data[0]

# forward the model
with torch.no_grad():
all_preds, _, _, heatmap = model(
return_loss=False, img=data['img'], img_metas=data['img_metas'])
with OutputHook(model, outputs=outputs, as_tensor=False) as h:
# forward the model
with torch.no_grad():
all_preds, _, _, heatmap = model(
img=data['img'],
img_metas=data['img_metas'],
return_loss=False,
return_heatmap=return_heatmap)

if return_heatmap:
h.layer_outputs['heatmap'] = heatmap

for pred in all_preds:
pose_results.append({
'keypoints': pred[:, :3],
})
returned_outputs.append(h.layer_outputs)

return pose_results, heatmap
for pred in all_preds:
pose_results.append({
'keypoints': pred[:, :3],
})

return pose_results, returned_outputs


def vis_pose_result(model,
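As a usage note, a small sketch (assuming `returned_outputs` comes from the bottom-up call above with return_heatmap=True) of inspecting the returned heatmap array, e.g. to locate the per-keypoint peaks:

    import numpy as np

    heatmap = returned_outputs[0]['heatmap']          # np.ndarray[N, K, H, W]
    n, k, h, w = heatmap.shape
    flat_idx = heatmap.reshape(n, k, -1).argmax(-1)   # peak index per keypoint
    peak_y, peak_x = np.unravel_index(flat_idx, (h, w))
    print(peak_x.shape, peak_y.shape)                 # both (N, K)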
4 changes: 2 additions & 2 deletions mmpose/datasets/datasets/bottom_up/bottom_up_coco.py
@@ -210,7 +210,7 @@ def evaluate(self, outputs, res_folder, metric='mAP', **kwargs):
num_keypoints: K
Args:
outputs (list(preds, scores, image_path, output_heatmap)):
outputs (list(preds, scores, image_path, heatmap)):
* preds (list[images x np.ndarray(P, K, 3+tag_num)]):
Pose predictions for all people in images.
@@ -219,7 +219,7 @@ def evaluate(self, outputs, res_folder, metric='mAP', **kwargs):
'/', 'i', 'm', 'a', 'g', 'e', 's', '/', 'v', 'a', 'l',
'2', '0', '1', '7', '/', '0', '0', '0', '0', '0',
'0', '3', '9', '7', '1', '3', '3', '.', 'j', 'p', 'g']
* output_heatmap (np.ndarray[N, K, H, W]): model outputs.
* heatmap (np.ndarray[N, K, H, W]): model outputs.
res_folder (str): Path of directory to save the results.
metric (str | list[str]): Metric to be performed. Defaults: 'mAP'.
22 changes: 15 additions & 7 deletions mmpose/models/detectors/bottom_up.py
@@ -69,6 +69,7 @@ def forward(self,
joints=None,
img_metas=None,
return_loss=True,
return_heatmap=False,
**kwargs):
"""Calls either forward_train or forward_test depending on whether
return_loss is True.
@@ -88,8 +89,6 @@ def forward(self,
heatmaps
joints(List(torch.Tensor[NxMxKx2])): Joints of multi-scale target
heatmaps for ae loss
return loss(bool): Option to 'return_loss'. 'return_loss=True' for
training, 'return_loss=False' for validation & test
img_metas(dict):Information about val&test
By default this includes:
- "image_file": image path
@@ -99,17 +98,23 @@ def forward(self,
- "center": center of image
- "scale": scale of image
- "flip_index": flip index of keypoints
            return_loss (bool): Option to 'return_loss'. 'return_loss=True' for
                training, 'return_loss=False' for validation & test.
            return_heatmap (bool): Option to return heatmap.
Returns:
dict|tuple: if 'return_loss' is true, then return losses.
Otherwise, return predicted poses, scores and image
paths.
Otherwise, return predicted poses, scores, image
paths and heatmaps.
"""

if return_loss:
return self.forward_train(img, targets, masks, joints, img_metas,
**kwargs)
else:
return self.forward_test(img, img_metas, **kwargs)
return self.forward_test(
img, img_metas, return_heatmap=return_heatmap, **kwargs)

def forward_train(self, img, targets, masks, joints, img_metas, **kwargs):
"""Forward the bottom-up model and calculate the loss.
@@ -170,7 +175,7 @@ def forward_train(self, img, targets, masks, joints, img_metas, **kwargs):
losses['all_loss'] = loss
return losses

def forward_test(self, img, img_metas, **kwargs):
def forward_test(self, img, img_metas, return_heatmap=False, **kwargs):
"""Inference the bottom-up model.
Note:
@@ -243,7 +248,10 @@ def forward_test(self, img, img_metas, **kwargs):
image_path = []
image_path.extend(img_metas['image_file'])

output_heatmap = aggregated_heatmaps.detach().cpu().numpy()
if return_heatmap:
output_heatmap = aggregated_heatmaps.detach().cpu().numpy()
else:
output_heatmap = None

return results, scores, image_path, output_heatmap

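For illustration, a minimal sketch (assuming `data` has been prepared by the test pipeline, as in mmpose/apis/inference.py above) of calling the detector in test mode directly; with return_heatmap=False the fourth return value is now None, so no heatmap is copied to the CPU:

    import torch

    with torch.no_grad():
        results, scores, image_path, output_heatmap = pose_model(
            img=data['img'],
            img_metas=data['img_metas'],
            return_loss=False,
            return_heatmap=False)

    assert output_heatmap is None  # populated only when return_heatmap=True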
8 changes: 7 additions & 1 deletion mmpose/utils/hooks.py
@@ -17,7 +17,13 @@ def hook(model, input, output):
if self.as_tensor:
self.layer_outputs[name] = output
else:
self.layer_outputs[name] = output.detach().cpu().numpy()
if isinstance(output, list):
self.layer_outputs[name] = [
out.detach().cpu().numpy() for out in output
]
else:
self.layer_outputs[name] = output.detach().cpu().numpy(
)

return hook

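A short sketch of why the hook now special-cases list outputs: a multi-branch backbone (assumed here, e.g. an HRNet-style model returning a list of feature maps) cannot be converted with a single .detach() call, so each element is converted to numpy individually when as_tensor=False:

    import torch
    from mmpose.utils.hooks import OutputHook

    with OutputHook(pose_model, outputs=('backbone', ), as_tensor=False) as h:
        with torch.no_grad():
            pose_model(
                img=data['img'],
                img_metas=data['img_metas'],
                return_loss=False,
                return_heatmap=False)

    feats = h.layer_outputs['backbone']  # list of np.ndarray, one per branch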
