diff --git a/projects/implicitron_trainer/tests/experiment.yaml b/projects/implicitron_trainer/tests/experiment.yaml index f70a3c234..00343b16f 100644 --- a/projects/implicitron_trainer/tests/experiment.yaml +++ b/projects/implicitron_trainer/tests/experiment.yaml @@ -61,6 +61,7 @@ data_source_ImplicitronDataSource_args: test_on_train: false only_test_set: false load_eval_batches: true + n_known_frames_for_test: 0 dataset_class_type: JsonIndexDataset path_manager_factory_class_type: PathManagerFactory dataset_JsonIndexDataset_args: diff --git a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py index d38c4bf1c..b05b4cfb8 100644 --- a/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py +++ b/pytorch3d/implicitron/dataset/json_index_dataset_map_provider_v2.py @@ -5,11 +5,15 @@ # LICENSE file in the root directory of this source tree. +import copy import json import logging import os import warnings -from typing import Dict, List, Optional, Tuple, Type +from collections import defaultdict +from typing import Dict, List, Optional, Tuple, Type, Union + +import numpy as np from omegaconf import DictConfig from pytorch3d.implicitron.dataset.dataset_map_provider import ( @@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] only_test_set: Load only the test set. Incompatible with `test_on_train`. load_eval_batches: Load the file containing eval batches pointing to the test dataset. + n_known_frames_for_test: Add a certain number of known frames to each + eval batch. Useful for evaluating models that require + source views as input (e.g. NeRF-WCE / PixelNeRF). dataset_args: Specifies additional arguments to the JsonIndexDataset constructor call. path_manager_factory: (Optional) An object that generates an instance of @@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13] only_test_set: bool = False load_eval_batches: bool = True + n_known_frames_for_test: int = 0 + dataset_class_type: str = "JsonIndexDataset" dataset: JsonIndexDataset @@ -264,6 +273,18 @@ def __post_init__(self): val_dataset = dataset.subset_from_frame_index(subset_mapping["val"]) logger.info(f"Val dataset: {str(val_dataset)}") logger.debug("Extracting test dataset.") + + if (self.n_known_frames_for_test > 0) and self.load_eval_batches: + # extend the test subset mapping and the dataset with additional + # known views from the train dataset + ( + eval_batch_index, + subset_mapping["test"], + ) = self._extend_test_data_with_known_views( + subset_mapping, + eval_batch_index, + ) + test_dataset = dataset.subset_from_frame_index(subset_mapping["test"]) logger.info(f"Test dataset: {str(test_dataset)}") if self.load_eval_batches: @@ -369,6 +390,40 @@ def _get_available_subset_names(self): dataset_root = self.dataset_root return get_available_subset_names(dataset_root, self.category) + def _extend_test_data_with_known_views( + self, + subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]], + eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]], + ): + # convert the train subset mapping to a dict: + # sequence_to_train_frames: {sequence_name: frame_index} + sequence_to_train_frames = defaultdict(list) + for frame_entry in subset_mapping["train"]: + sequence_name = frame_entry[0] + sequence_to_train_frames[sequence_name].append(frame_entry) + sequence_to_train_frames = dict(sequence_to_train_frames) + test_subset_mapping_set = {tuple(s) for s in subset_mapping["test"]} + + # extend the eval batches / subset mapping with the additional examples + eval_batch_index_out = copy.deepcopy(eval_batch_index) + generator = np.random.default_rng(seed=0) + for batch in eval_batch_index_out: + sequence_name = batch[0][0] + sequence_known_entries = sequence_to_train_frames[sequence_name] + idx_to_add = generator.permutation(len(sequence_known_entries))[ + : self.n_known_frames_for_test + ] + entries_to_add = [sequence_known_entries[a] for a in idx_to_add] + assert all(e in subset_mapping["train"] for e in entries_to_add) + + # extend the eval batch with the known views + batch.extend(entries_to_add) + + # also add these new entries to the test subset mapping + test_subset_mapping_set.update(tuple(e) for e in entries_to_add) + + return eval_batch_index_out, list(test_subset_mapping_set) + def get_available_subset_names(dataset_root: str, category: str) -> List[str]: """ diff --git a/tests/implicitron/data/data_source.yaml b/tests/implicitron/data/data_source.yaml index 55e3971ab..bd6f4aade 100644 --- a/tests/implicitron/data/data_source.yaml +++ b/tests/implicitron/data/data_source.yaml @@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args: test_on_train: false only_test_set: false load_eval_batches: true + n_known_frames_for_test: 0 dataset_class_type: JsonIndexDataset path_manager_factory_class_type: PathManagerFactory dataset_JsonIndexDataset_args: diff --git a/tests/implicitron/test_json_index_dataset_provider_v2.py b/tests/implicitron/test_json_index_dataset_provider_v2.py index bd698692b..abec3b9c0 100644 --- a/tests/implicitron/test_json_index_dataset_provider_v2.py +++ b/tests/implicitron/test_json_index_dataset_provider_v2.py @@ -37,29 +37,47 @@ def test_random_dataset(self): expand_args_fields(JsonIndexDatasetMapProviderV2) categories = ["A", "B"] subset_name = "test" + eval_batch_size = 5 with tempfile.TemporaryDirectory() as tmpd: - _make_random_json_dataset_map_provider_v2_data(tmpd, categories) - for category in categories: - dataset_provider = JsonIndexDatasetMapProviderV2( - category=category, - subset_name="test", - dataset_root=tmpd, - ) - dataset_map = dataset_provider.get_dataset_map() - for set_ in ["train", "val", "test"]: - dataloader = torch.utils.data.DataLoader( - getattr(dataset_map, set_), - batch_size=3, - shuffle=True, - collate_fn=FrameData.collate, + _make_random_json_dataset_map_provider_v2_data( + tmpd, + categories, + eval_batch_size=eval_batch_size, + ) + for n_known_frames_for_test in [0, 2]: + for category in categories: + dataset_provider = JsonIndexDatasetMapProviderV2( + category=category, + subset_name="test", + dataset_root=tmpd, + n_known_frames_for_test=n_known_frames_for_test, ) - for _ in dataloader: - pass - category_to_subset_list = ( - dataset_provider.get_category_to_subset_name_list() - ) - category_to_subset_list_ = {c: [subset_name] for c in categories} - self.assertTrue(category_to_subset_list == category_to_subset_list_) + dataset_map = dataset_provider.get_dataset_map() + for set_ in ["train", "val", "test"]: + if set_ in ["train", "val"]: + dataloader = torch.utils.data.DataLoader( + getattr(dataset_map, set_), + batch_size=3, + shuffle=True, + collate_fn=FrameData.collate, + ) + else: + dataloader = torch.utils.data.DataLoader( + getattr(dataset_map, set_), + batch_sampler=dataset_map[set_].get_eval_batches(), + collate_fn=FrameData.collate, + ) + for batch in dataloader: + if set_ == "test": + self.assertTrue( + batch.image_rgb.shape[0] + == n_known_frames_for_test + eval_batch_size + ) + category_to_subset_list = ( + dataset_provider.get_category_to_subset_name_list() + ) + category_to_subset_list_ = {c: [subset_name] for c in categories} + self.assertTrue(category_to_subset_list == category_to_subset_list_) def _make_random_json_dataset_map_provider_v2_data( @@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data( H: int = 50, W: int = 30, subset_name: str = "test", + eval_batch_size: int = 5, ): os.makedirs(root, exist_ok=True) category_to_subset_list = {} @@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data( with open(set_list_file, "w") as f: json.dump(set_list, f) - eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)] + eval_batches = [ + random.sample(test_frame_index, eval_batch_size) for _ in range(10) + ] + eval_b_dir = os.path.join(root, category, "eval_batches") os.makedirs(eval_b_dir, exist_ok=True) eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")