From 94f321fa3dd776da5edd1efa80e8a094ee5e6b02 Mon Sep 17 00:00:00 2001 From: David Novotny Date: Mon, 28 Nov 2022 04:36:41 -0800 Subject: [PATCH] render_flyaround bugfix Summary: Fixes a bug which would crash render_flyaround anytime visualize_preds_keys is adjusted Reviewed By: shapovalov Differential Revision: D41124462 fbshipit-source-id: 127045a91a055909f8bd56c8af81afac02c00f60 --- .../models/visualization/render_flyaround.py | 32 +++++++++++++++---- 1 file changed, 25 insertions(+), 7 deletions(-) diff --git a/pytorch3d/implicitron/models/visualization/render_flyaround.py b/pytorch3d/implicitron/models/visualization/render_flyaround.py index f3d868d5a..2a3afadbb 100644 --- a/pytorch3d/implicitron/models/visualization/render_flyaround.py +++ b/pytorch3d/implicitron/models/visualization/render_flyaround.py @@ -10,7 +10,17 @@ import math import os import random -from typing import Any, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Sequence, + Tuple, + TYPE_CHECKING, + Union, +) import numpy as np import torch @@ -180,7 +190,7 @@ def render_flyaround( preds.update(net_input) # merge everything into one big dict # Render the predictions to images - rendered_pred = _images_from_preds(preds) + rendered_pred = _images_from_preds(preds, extract_keys=visualize_preds_keys) preds_total.append(rendered_pred) # show the preds every 5% of the export iterations @@ -223,9 +233,9 @@ def _load_whole_dataset( return next(iter(load_all_dataloader)) -def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]: - imout = {} - for k in ( +def _images_from_preds( + preds: Dict[str, Any], + extract_keys: Iterable[str] = ( "image_rgb", "images_render", "fg_probability", @@ -233,7 +243,10 @@ def _images_from_preds(preds: Dict[str, Any]) -> Dict[str, torch.Tensor]: "depths_render", "depth_map", "_all_source_images", - ): + ), +) -> Dict[str, torch.Tensor]: + imout = {} + for k in extract_keys: if k == "_all_source_images" and "image_rgb" in preds: src_ims = preds["image_rgb"][1:].cpu().detach().clone() v = _stack_images(src_ims, None)[None] @@ -343,6 +356,9 @@ def _generate_prediction_videos( # init a video writer for each predicted key vws = {} for k in predicted_keys: + if k not in preds[0]: + logger.warn(f"Cannot generate video for prediction key '{k}'") + continue cache_dir = ( None if video_frames_dir is None @@ -355,13 +371,15 @@ def _generate_prediction_videos( ) for rendered_pred in tqdm(preds): - for k in predicted_keys: + for k in vws: vws[k].write_frame( rendered_pred[k][0].clip(0.0, 1.0).detach().cpu().numpy(), resize=resize, ) for k in predicted_keys: + if k not in vws: + continue vws[k].get_video() logger.info(f"Generated {vws[k].out_path}.") if viz is not None: