Skip to content

Commit

Permalink
Enable additional test-time source views for json dataset provider v2
Browse files Browse the repository at this point in the history
Summary: Adds additional source views to the eval batches for evaluating many-view models on the CO3D Challenge

Reviewed By: bottler

Differential Revision: D38705904

fbshipit-source-id: cf7d00dc7db926fbd1656dd97a729674e9ff5adb
  • Loading branch information
davnov134 authored and facebook-github-bot committed Aug 15, 2022
1 parent e8616cc commit 2ff2c7c
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 23 deletions.
1 change: 1 addition & 0 deletions projects/implicitron_trainer/tests/experiment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ data_source_ImplicitronDataSource_args:
test_on_train: false
only_test_set: false
load_eval_batches: true
n_known_frames_for_test: 0
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory
dataset_JsonIndexDataset_args:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,15 @@
# LICENSE file in the root directory of this source tree.


import copy
import json
import logging
import os
import warnings
from typing import Dict, List, Optional, Tuple, Type
from collections import defaultdict
from typing import Dict, List, Optional, Tuple, Type, Union

import numpy as np

from omegaconf import DictConfig
from pytorch3d.implicitron.dataset.dataset_map_provider import (
Expand Down Expand Up @@ -152,6 +156,9 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
only_test_set: Load only the test set. Incompatible with `test_on_train`.
load_eval_batches: Load the file containing eval batches pointing to the
test dataset.
n_known_frames_for_test: Add a certain number of known frames to each
eval batch. Useful for evaluating models that require
source views as input (e.g. NeRF-WCE / PixelNeRF).
dataset_args: Specifies additional arguments to the
JsonIndexDataset constructor call.
path_manager_factory: (Optional) An object that generates an instance of
Expand All @@ -167,6 +174,8 @@ class JsonIndexDatasetMapProviderV2(DatasetMapProviderBase): # pyre-ignore [13]
only_test_set: bool = False
load_eval_batches: bool = True

n_known_frames_for_test: int = 0

dataset_class_type: str = "JsonIndexDataset"
dataset: JsonIndexDataset

Expand Down Expand Up @@ -264,6 +273,18 @@ def __post_init__(self):
val_dataset = dataset.subset_from_frame_index(subset_mapping["val"])
logger.info(f"Val dataset: {str(val_dataset)}")
logger.debug("Extracting test dataset.")

if (self.n_known_frames_for_test > 0) and self.load_eval_batches:
# extend the test subset mapping and the dataset with additional
# known views from the train dataset
(
eval_batch_index,
subset_mapping["test"],
) = self._extend_test_data_with_known_views(
subset_mapping,
eval_batch_index,
)

test_dataset = dataset.subset_from_frame_index(subset_mapping["test"])
logger.info(f"Test dataset: {str(test_dataset)}")
if self.load_eval_batches:
Expand Down Expand Up @@ -369,6 +390,40 @@ def _get_available_subset_names(self):
dataset_root = self.dataset_root
return get_available_subset_names(dataset_root, self.category)

def _extend_test_data_with_known_views(
    self,
    subset_mapping: Dict[str, List[Union[Tuple[str, int], Tuple[str, int, str]]]],
    eval_batch_index: List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
) -> Tuple[
    List[List[Union[Tuple[str, int, str], Tuple[str, int]]]],
    List[Union[Tuple[str, int], Tuple[str, int, str]]],
]:
    """
    Append known (train) frames to every eval batch and to the test subset.

    For each eval batch, up to `self.n_known_frames_for_test` frame entries
    are drawn from the "train" entries that share the sequence name of the
    batch's first entry. The draw uses a generator with a fixed seed, so the
    same inputs always select the same frames. If a sequence has fewer train
    entries than requested, all of them are appended.

    Args:
        subset_mapping: Maps a subset name to its list of frame entries.
            Only the "train" and "test" keys are read. The first element
            of every entry is the sequence name.
        eval_batch_index: List of eval batches, each a non-empty list of
            frame entries. It is not modified; a deep copy is extended.

    Returns:
        A tuple `(eval_batch_index_out, test_subset_mapping)` where
        `eval_batch_index_out` is the extended copy of `eval_batch_index`
        and `test_subset_mapping` holds the de-duplicated test entries
        (as tuples) in their original order, followed by the newly added
        known entries in the order they were first drawn.

    Raises:
        KeyError: If the sequence of an eval batch has no "train" entries.
    """
    # group the train frame entries by sequence:
    #   sequence_to_train_frames: {sequence_name: [frame_entry, ...]}
    sequence_to_train_frames = defaultdict(list)
    for frame_entry in subset_mapping["train"]:
        sequence_to_train_frames[frame_entry[0]].append(frame_entry)
    # a plain dict so that a missing sequence fails instead of silently
    # yielding an empty list of known frames
    sequence_to_train_frames = dict(sequence_to_train_frames)

    # An insertion-ordered dict is used as an ordered set: iterating a real
    # set of tuples containing strings gives an order that depends on the
    # interpreter's hash seed, which would make the test subset order
    # non-reproducible between runs.
    test_subset_mapping = dict.fromkeys(tuple(s) for s in subset_mapping["test"])

    # extend the eval batches / subset mapping with the additional examples
    eval_batch_index_out = copy.deepcopy(eval_batch_index)
    generator = np.random.default_rng(seed=0)
    for batch in eval_batch_index_out:
        sequence_name = batch[0][0]
        try:
            sequence_known_entries = sequence_to_train_frames[sequence_name]
        except KeyError:
            raise KeyError(
                f"No train frames available for sequence {sequence_name!r};"
                " cannot add known frames to its eval batch."
            ) from None
        idx_to_add = generator.permutation(len(sequence_known_entries))[
            : self.n_known_frames_for_test
        ]
        # the entries come straight from the train grouping built above,
        # so they are members of subset_mapping["train"] by construction
        entries_to_add = [sequence_known_entries[a] for a in idx_to_add]

        # extend the eval batch with the known views
        batch.extend(entries_to_add)

        # also add these new entries to the test subset mapping
        test_subset_mapping.update(
            dict.fromkeys(tuple(e) for e in entries_to_add)
        )

    return eval_batch_index_out, list(test_subset_mapping)


def get_available_subset_names(dataset_root: str, category: str) -> List[str]:
"""
Expand Down
1 change: 1 addition & 0 deletions tests/implicitron/data/data_source.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ dataset_map_provider_JsonIndexDatasetMapProviderV2_args:
test_on_train: false
only_test_set: false
load_eval_batches: true
n_known_frames_for_test: 0
dataset_class_type: JsonIndexDataset
path_manager_factory_class_type: PathManagerFactory
dataset_JsonIndexDataset_args:
Expand Down
66 changes: 44 additions & 22 deletions tests/implicitron/test_json_index_dataset_provider_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,47 @@ def test_random_dataset(self):
expand_args_fields(JsonIndexDatasetMapProviderV2)
categories = ["A", "B"]
subset_name = "test"
eval_batch_size = 5
with tempfile.TemporaryDirectory() as tmpd:
_make_random_json_dataset_map_provider_v2_data(tmpd, categories)
for category in categories:
dataset_provider = JsonIndexDatasetMapProviderV2(
category=category,
subset_name="test",
dataset_root=tmpd,
)
dataset_map = dataset_provider.get_dataset_map()
for set_ in ["train", "val", "test"]:
dataloader = torch.utils.data.DataLoader(
getattr(dataset_map, set_),
batch_size=3,
shuffle=True,
collate_fn=FrameData.collate,
_make_random_json_dataset_map_provider_v2_data(
tmpd,
categories,
eval_batch_size=eval_batch_size,
)
for n_known_frames_for_test in [0, 2]:
for category in categories:
dataset_provider = JsonIndexDatasetMapProviderV2(
category=category,
subset_name="test",
dataset_root=tmpd,
n_known_frames_for_test=n_known_frames_for_test,
)
for _ in dataloader:
pass
category_to_subset_list = (
dataset_provider.get_category_to_subset_name_list()
)
category_to_subset_list_ = {c: [subset_name] for c in categories}
self.assertTrue(category_to_subset_list == category_to_subset_list_)
dataset_map = dataset_provider.get_dataset_map()
for set_ in ["train", "val", "test"]:
if set_ in ["train", "val"]:
dataloader = torch.utils.data.DataLoader(
getattr(dataset_map, set_),
batch_size=3,
shuffle=True,
collate_fn=FrameData.collate,
)
else:
dataloader = torch.utils.data.DataLoader(
getattr(dataset_map, set_),
batch_sampler=dataset_map[set_].get_eval_batches(),
collate_fn=FrameData.collate,
)
for batch in dataloader:
if set_ == "test":
self.assertTrue(
batch.image_rgb.shape[0]
== n_known_frames_for_test + eval_batch_size
)
category_to_subset_list = (
dataset_provider.get_category_to_subset_name_list()
)
category_to_subset_list_ = {c: [subset_name] for c in categories}
self.assertTrue(category_to_subset_list == category_to_subset_list_)


def _make_random_json_dataset_map_provider_v2_data(
Expand All @@ -70,6 +88,7 @@ def _make_random_json_dataset_map_provider_v2_data(
H: int = 50,
W: int = 30,
subset_name: str = "test",
eval_batch_size: int = 5,
):
os.makedirs(root, exist_ok=True)
category_to_subset_list = {}
Expand Down Expand Up @@ -142,7 +161,10 @@ def _make_random_json_dataset_map_provider_v2_data(
with open(set_list_file, "w") as f:
json.dump(set_list, f)

eval_batches = [random.sample(test_frame_index, 5) for _ in range(10)]
eval_batches = [
random.sample(test_frame_index, eval_batch_size) for _ in range(10)
]

eval_b_dir = os.path.join(root, category, "eval_batches")
os.makedirs(eval_b_dir, exist_ok=True)
eval_b_file = os.path.join(eval_b_dir, f"eval_batches_{subset_name}.json")
Expand Down

0 comments on commit 2ff2c7c

Please sign in to comment.