From 9540c29023c2b6bb53e5a26a5e7a9d34ce88e9b1 Mon Sep 17 00:00:00 2001
From: Jeremy Reizenstein
Date: Fri, 27 Jan 2023 07:07:46 -0800
Subject: [PATCH] Make Module.__init__ automatic

Summary: If a configurable class inherits torch.nn.Module and is instantiated, automatically call `torch.nn.Module.__init__` on it before doing anything else.

Reviewed By: shapovalov

Differential Revision: D42760349

fbshipit-source-id: 409894911a4252b7987e1fd218ee9ecefbec8e62
---
 projects/implicitron_trainer/README.md        |  2 --
 pytorch3d/implicitron/eval_demo.py            |  2 +-
 pytorch3d/implicitron/models/base_model.py    |  3 --
 .../feature_extractor/feature_extractor.py    |  3 --
 .../resnet_feature_extractor.py               |  1 -
 pytorch3d/implicitron/models/generic_model.py |  2 --
 .../models/global_encoder/autodecoder.py      |  2 --
 .../models/global_encoder/global_encoder.py   |  5 ---
 .../models/implicit_function/base.py          |  3 --
 .../implicit_function/decoding_functions.py   |  7 ----
 .../implicit_function/idr_feature_field.py    |  2 --
 .../neural_radiance_field.py                  |  1 -
 .../scene_representation_networks.py          |  5 ---
 .../models/implicit_function/voxel_grid.py    |  2 --
 .../voxel_grid_implicit_function.py           |  1 -
 pytorch3d/implicitron/models/metrics.py       |  6 ----
 pytorch3d/implicitron/models/model_dbir.py    |  3 --
 pytorch3d/implicitron/models/renderer/base.py |  3 --
 .../models/renderer/lstm_renderer.py          |  1 -
 .../models/renderer/multipass_ea.py           |  1 -
 .../models/renderer/ray_point_refiner.py      |  3 --
 .../models/renderer/ray_sampler.py            |  5 ---
 .../models/renderer/ray_tracing.py            |  3 --
 .../implicitron/models/renderer/raymarcher.py |  5 ---
 .../models/renderer/sdf_renderer.py           |  1 -
 .../models/view_pooler/feature_aggregator.py  | 12 -------
 .../models/view_pooler/view_pooler.py         |  1 -
 .../models/view_pooler/view_sampler.py        |  3 --
 pytorch3d/implicitron/tools/config.py         | 35 +++++++++++++++++++
 29 files changed, 36 insertions(+), 87 deletions(-)

diff --git a/projects/implicitron_trainer/README.md b/projects/implicitron_trainer/README.md
index 01664a68b..232d697bb 100644
--- a/projects/implicitron_trainer/README.md
+++ b/projects/implicitron_trainer/README.md
@@ -212,9 +212,7 @@ from pytorch3d.implicitron.tools.config import registry
 class XRayRenderer(BaseRenderer, torch.nn.Module):
     n_pts_per_ray: int = 64
 
-    # if there are other base classes, make sure to call `super().__init__()` explicitly
     def __post_init__(self):
-        super().__init__()
         # custom initialization
 
     def forward(
diff --git a/pytorch3d/implicitron/eval_demo.py b/pytorch3d/implicitron/eval_demo.py
index bffc5da7c..91e696945 100644
--- a/pytorch3d/implicitron/eval_demo.py
+++ b/pytorch3d/implicitron/eval_demo.py
@@ -130,7 +130,7 @@ def evaluate_dbir_for_category(
         raise ValueError("Image size should be set in the dataset")
 
     # init the simple DBIR model
-    model = ModelDBIR(  # pyre-ignore[28]: c’tor implicitly overridden
+    model = ModelDBIR(
         render_image_width=image_size,
         render_image_height=image_size,
         bg_color=bg_color,
diff --git a/pytorch3d/implicitron/models/base_model.py b/pytorch3d/implicitron/models/base_model.py
index ec34caed1..56efa69cd 100644
--- a/pytorch3d/implicitron/models/base_model.py
+++ b/pytorch3d/implicitron/models/base_model.py
@@ -49,9 +49,6 @@ class ImplicitronModelBase(ReplaceableBase, torch.nn.Module):
     # the training loop.
     log_vars: List[str] = field(default_factory=lambda: ["objective"])
 
-    def __init__(self) -> None:
-        super().__init__()
-
     def forward(
         self,
         *,  # force keyword-only arguments
diff --git a/pytorch3d/implicitron/models/feature_extractor/feature_extractor.py b/pytorch3d/implicitron/models/feature_extractor/feature_extractor.py
index 32d3d4835..9ce7f5e56 100644
--- a/pytorch3d/implicitron/models/feature_extractor/feature_extractor.py
+++ b/pytorch3d/implicitron/models/feature_extractor/feature_extractor.py
@@ -15,9 +15,6 @@ class FeatureExtractorBase(ReplaceableBase, torch.nn.Module):
     Base class for an extractor of a set of features from images.
     """
 
-    def __init__(self):
-        super().__init__()
-
     def get_feat_dims(self) -> int:
         """
         Returns:
diff --git a/pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py b/pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py
index baa20c208..32cd2d42d 100644
--- a/pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py
+++ b/pytorch3d/implicitron/models/feature_extractor/resnet_feature_extractor.py
@@ -78,7 +78,6 @@ class ResNetFeatureExtractor(FeatureExtractorBase):
     feature_rescale: float = 1.0
 
     def __post_init__(self):
-        super().__init__()
         if self.normalize_image:
             # register buffers needed to normalize the image
             for k, v in (("_resnet_mean", _RESNET_MEAN), ("_resnet_std", _RESNET_STD)):
diff --git a/pytorch3d/implicitron/models/generic_model.py b/pytorch3d/implicitron/models/generic_model.py
index 56ea080bb..979a8435d 100644
--- a/pytorch3d/implicitron/models/generic_model.py
+++ b/pytorch3d/implicitron/models/generic_model.py
@@ -304,8 +304,6 @@ class GenericModel(ImplicitronModelBase):  # pyre-ignore: 13
     )
 
     def __post_init__(self):
-        super().__init__()
-
         if self.view_pooler_enabled:
             if self.image_feature_extractor_class_type is None:
                 raise ValueError(
diff --git a/pytorch3d/implicitron/models/global_encoder/autodecoder.py b/pytorch3d/implicitron/models/global_encoder/autodecoder.py
index b03d55887..52077b292 100644
--- a/pytorch3d/implicitron/models/global_encoder/autodecoder.py
+++ b/pytorch3d/implicitron/models/global_encoder/autodecoder.py
@@ -29,8 +29,6 @@ class Autodecoder(Configurable, torch.nn.Module):
     ignore_input: bool = False
 
     def __post_init__(self):
-        super().__init__()
-
         if self.n_instances <= 0:
             raise ValueError(f"Invalid n_instances {self.n_instances}")
 
diff --git a/pytorch3d/implicitron/models/global_encoder/global_encoder.py b/pytorch3d/implicitron/models/global_encoder/global_encoder.py
index 641433ad1..19dcb93a7 100644
--- a/pytorch3d/implicitron/models/global_encoder/global_encoder.py
+++ b/pytorch3d/implicitron/models/global_encoder/global_encoder.py
@@ -26,9 +26,6 @@ class GlobalEncoderBase(ReplaceableBase):
     (`SequenceAutodecoder`).
     """
 
-    def __init__(self) -> None:
-        super().__init__()
-
     def get_encoding_dim(self):
         """
         Returns the dimensionality of the returned encoding.
@@ -69,7 +66,6 @@ class SequenceAutodecoder(GlobalEncoderBase, torch.nn.Module):  # pyre-ignore: 1
     autodecoder: Autodecoder
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def get_encoding_dim(self):
@@ -103,7 +99,6 @@ class HarmonicTimeEncoder(GlobalEncoderBase, torch.nn.Module):
     time_divisor: float = 1.0
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             n_harmonic_functions=self.n_harmonic_functions,
             append_input=self.append_input,
diff --git a/pytorch3d/implicitron/models/implicit_function/base.py b/pytorch3d/implicitron/models/implicit_function/base.py
index 75bd36538..7cd67edeb 100644
--- a/pytorch3d/implicitron/models/implicit_function/base.py
+++ b/pytorch3d/implicitron/models/implicit_function/base.py
@@ -14,9 +14,6 @@
 
 
 class ImplicitFunctionBase(ABC, ReplaceableBase):
-    def __init__(self):
-        super().__init__()
-
     @abstractmethod
     def forward(
         self,
diff --git a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
index c516df26f..5722a2ea2 100644
--- a/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
+++ b/pytorch3d/implicitron/models/implicit_function/decoding_functions.py
@@ -45,9 +45,6 @@ class DecoderFunctionBase(ReplaceableBase, torch.nn.Module):
     space and transforms it into the required quantity (for example density and color).
     """
 
-    def __post_init__(self):
-        super().__init__()
-
     def forward(
         self, features: torch.Tensor, z: Optional[torch.Tensor] = None
     ) -> torch.Tensor:
@@ -83,7 +80,6 @@ class ElementwiseDecoder(DecoderFunctionBase):
     operation: DecoderActivation = DecoderActivation.IDENTITY
 
     def __post_init__(self):
-        super().__post_init__()
         if self.operation not in [
             DecoderActivation.RELU,
             DecoderActivation.SOFTPLUS,
@@ -163,8 +159,6 @@ class MLPWithInputSkips(Configurable, torch.nn.Module):
     use_xavier_init: bool = True
 
     def __post_init__(self):
-        super().__init__()
-
         try:
             last_activation = {
                 DecoderActivation.RELU: torch.nn.ReLU(True),
@@ -284,7 +278,6 @@ class MLPDecoder(DecoderFunctionBase):
     network: MLPWithInputSkips
 
     def __post_init__(self):
-        super().__post_init__()
         run_auto_creation(self)
 
     def forward(
diff --git a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py
index f43a2932e..c0be4a288 100644
--- a/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py
+++ b/pytorch3d/implicitron/models/implicit_function/idr_feature_field.py
@@ -66,8 +66,6 @@ class IdrFeatureField(ImplicitFunctionBase, torch.nn.Module):
     encoding_dim: int = 0
 
     def __post_init__(self):
-        super().__init__()
-
         dims = [self.d_in] + list(self.dims) + [self.d_out + self.feature_vector_size]
 
         self.embed_fn = None
diff --git a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py
index aecd91051..1b96c7eee 100644
--- a/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py
+++ b/pytorch3d/implicitron/models/implicit_function/neural_radiance_field.py
@@ -56,7 +56,6 @@ class NeuralRadianceFieldBase(ImplicitFunctionBase, torch.nn.Module):
     """
 
    def __post_init__(self):
-        super().__init__()
         # The harmonic embedding layer converts input 3D coordinates
         # to a representation that is more suitable for
         # processing with a deep neural network.
diff --git a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py
index b9e3cc1e5..4ca71ee16 100644
--- a/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py
+++ b/pytorch3d/implicitron/models/implicit_function/scene_representation_networks.py
@@ -44,7 +44,6 @@ class SRNRaymarchFunction(Configurable, torch.nn.Module):
     raymarch_function: Any = None
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             self.n_harmonic_functions, append_input=True
         )
@@ -135,7 +134,6 @@ class SRNPixelGenerator(Configurable, torch.nn.Module):
     ray_dir_in_camera_coords: bool = False
 
     def __post_init__(self):
-        super().__init__()
         self._harmonic_embedding = HarmonicEmbedding(
             self.n_harmonic_functions, append_input=True
         )
@@ -249,7 +247,6 @@ class SRNRaymarchHyperNet(Configurable, torch.nn.Module):
     xyz_in_camera_coords: bool = False
 
     def __post_init__(self):
-        super().__init__()
         raymarch_input_embedding_dim = (
             HarmonicEmbedding.get_output_dim_static(
                 self.in_features,
@@ -335,7 +332,6 @@ class SRNImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     pixel_generator: SRNPixelGenerator
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def create_raymarch_function(self) -> None:
@@ -393,7 +389,6 @@ class SRNHyperNetImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     pixel_generator: SRNPixelGenerator
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
 
     def create_hypernet(self) -> None:
diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
index c9d518eaf..fd14717ed 100644
--- a/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
+++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid.py
@@ -81,7 +81,6 @@ class VoxelGridBase(ReplaceableBase, torch.nn.Module):
     )
 
     def __post_init__(self):
-        super().__init__()
         if 0 not in self.resolution_changes:
             raise ValueError("There has to be key `0` in `resolution_changes`.")
 
@@ -857,7 +856,6 @@ class VoxelGridModule(Configurable, torch.nn.Module):
     param_groups: Dict[str, str] = field(default_factory=lambda: {})
 
     def __post_init__(self):
-        super().__init__()
         run_auto_creation(self)
         n_grids = 1  # Voxel grid objects are batched. We need only a single grid.
         shapes = self.voxel_grid.get_shapes(epoch=0)
diff --git a/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py b/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py
index b21e253a5..f04443b9a 100644
--- a/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py
+++ b/pytorch3d/implicitron/models/implicit_function/voxel_grid_implicit_function.py
@@ -186,7 +186,6 @@ class VoxelGridImplicitFunction(ImplicitFunctionBase, torch.nn.Module):
     volume_cropping_epochs: Tuple[int, ...] = ()
 
     def __post_init__(self) -> None:
-        super().__init__()
         run_auto_creation(self)
         # pyre-ignore[16]
         self.voxel_grid_scaffold = self._create_voxel_grid_scaffold()
diff --git a/pytorch3d/implicitron/models/metrics.py b/pytorch3d/implicitron/models/metrics.py
index 174c73e5b..edd4b9408 100644
--- a/pytorch3d/implicitron/models/metrics.py
+++ b/pytorch3d/implicitron/models/metrics.py
@@ -25,9 +25,6 @@ class RegularizationMetricsBase(ReplaceableBase, torch.nn.Module):
     depend on the model's parameters.
""" - def __post_init__(self) -> None: - super().__init__() - def forward( self, model: Any, keys_prefix: str = "loss_", **kwargs ) -> Dict[str, Any]: @@ -56,9 +53,6 @@ class ViewMetricsBase(ReplaceableBase, torch.nn.Module): `forward()` method produces losses and other metrics. """ - def __post_init__(self) -> None: - super().__init__() - def forward( self, raymarched: RendererOutput, diff --git a/pytorch3d/implicitron/models/model_dbir.py b/pytorch3d/implicitron/models/model_dbir.py index c14dab9dd..826731fce 100644 --- a/pytorch3d/implicitron/models/model_dbir.py +++ b/pytorch3d/implicitron/models/model_dbir.py @@ -41,9 +41,6 @@ class ModelDBIR(ImplicitronModelBase): bg_color: Tuple[float, float, float] = (0.0, 0.0, 0.0) max_points: int = -1 - def __post_init__(self): - super().__init__() - def forward( self, *, # force keyword-only arguments diff --git a/pytorch3d/implicitron/models/renderer/base.py b/pytorch3d/implicitron/models/renderer/base.py index 9b29bdeba..a8644dde6 100644 --- a/pytorch3d/implicitron/models/renderer/base.py +++ b/pytorch3d/implicitron/models/renderer/base.py @@ -141,9 +141,6 @@ class BaseRenderer(ABC, ReplaceableBase): Base class for all Renderer implementations. """ - def __init__(self) -> None: - super().__init__() - def requires_object_mask(self) -> bool: """ Whether `forward` needs the object_mask. diff --git a/pytorch3d/implicitron/models/renderer/lstm_renderer.py b/pytorch3d/implicitron/models/renderer/lstm_renderer.py index b24c253f5..b579ff3dd 100644 --- a/pytorch3d/implicitron/models/renderer/lstm_renderer.py +++ b/pytorch3d/implicitron/models/renderer/lstm_renderer.py @@ -57,7 +57,6 @@ class LSTMRenderer(BaseRenderer, torch.nn.Module): verbose: bool = False def __post_init__(self): - super().__init__() self._lstm = torch.nn.LSTMCell( input_size=self.n_feature_channels, hidden_size=self.hidden_size, diff --git a/pytorch3d/implicitron/models/renderer/multipass_ea.py b/pytorch3d/implicitron/models/renderer/multipass_ea.py index b937afc0c..18ee8f5b6 100644 --- a/pytorch3d/implicitron/models/renderer/multipass_ea.py +++ b/pytorch3d/implicitron/models/renderer/multipass_ea.py @@ -90,7 +90,6 @@ class MultiPassEmissionAbsorptionRenderer( # pyre-ignore: 13 return_weights: bool = False def __post_init__(self): - super().__init__() self._refiners = { EvaluationMode.TRAINING: RayPointRefiner( n_pts_per_ray=self.n_pts_per_ray_fine_training, diff --git a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py index a69398c6a..22db11f4b 100644 --- a/pytorch3d/implicitron/models/renderer/ray_point_refiner.py +++ b/pytorch3d/implicitron/models/renderer/ray_point_refiner.py @@ -38,9 +38,6 @@ class RayPointRefiner(Configurable, torch.nn.Module): random_sampling: bool add_input_samples: bool = True - def __post_init__(self) -> None: - super().__init__() - def forward( self, input_ray_bundle: ImplicitronRayBundle, diff --git a/pytorch3d/implicitron/models/renderer/ray_sampler.py b/pytorch3d/implicitron/models/renderer/ray_sampler.py index 225084fcd..555c5db8e 100644 --- a/pytorch3d/implicitron/models/renderer/ray_sampler.py +++ b/pytorch3d/implicitron/models/renderer/ray_sampler.py @@ -20,9 +20,6 @@ class RaySamplerBase(ReplaceableBase): Base class for ray samplers. 
""" - def __init__(self): - super().__init__() - def forward( self, cameras: CamerasBase, @@ -102,8 +99,6 @@ class AbstractMaskRaySampler(RaySamplerBase, torch.nn.Module): stratified_point_sampling_evaluation: bool = False def __post_init__(self): - super().__init__() - if (self.n_rays_per_image_sampled_from_mask is not None) and ( self.n_rays_total_training is not None ): diff --git a/pytorch3d/implicitron/models/renderer/ray_tracing.py b/pytorch3d/implicitron/models/renderer/ray_tracing.py index 890b055e9..5c0dd0a40 100644 --- a/pytorch3d/implicitron/models/renderer/ray_tracing.py +++ b/pytorch3d/implicitron/models/renderer/ray_tracing.py @@ -43,9 +43,6 @@ class RayTracing(Configurable, nn.Module): n_steps: int = 100 n_secant_steps: int = 8 - def __post_init__(self): - super().__init__() - def forward( self, sdf: Callable[[torch.Tensor], torch.Tensor], diff --git a/pytorch3d/implicitron/models/renderer/raymarcher.py b/pytorch3d/implicitron/models/renderer/raymarcher.py index 37ddd6c5c..3e42815dc 100644 --- a/pytorch3d/implicitron/models/renderer/raymarcher.py +++ b/pytorch3d/implicitron/models/renderer/raymarcher.py @@ -22,9 +22,6 @@ class RaymarcherBase(ReplaceableBase): and marching along them in order to generate a feature render. """ - def __init__(self): - super().__init__() - def forward( self, rays_densities: torch.Tensor, @@ -98,8 +95,6 @@ def __post_init__(self): surface_thickness: Denotes the overlap between the absorption function and the density function. """ - super().__init__() - bg_color = torch.tensor(self.bg_color) if bg_color.ndim != 1: raise ValueError(f"bg_color (shape {bg_color.shape}) should be a 1D tensor") diff --git a/pytorch3d/implicitron/models/renderer/sdf_renderer.py b/pytorch3d/implicitron/models/renderer/sdf_renderer.py index d8782911e..4326e09ef 100644 --- a/pytorch3d/implicitron/models/renderer/sdf_renderer.py +++ b/pytorch3d/implicitron/models/renderer/sdf_renderer.py @@ -35,7 +35,6 @@ class SignedDistanceFunctionRenderer(BaseRenderer, torch.nn.Module): # pyre-ign def __post_init__( self, ): - super().__init__() render_features_dimensions = self.render_features_dimensions if len(self.bg_color) not in [1, render_features_dimensions]: raise ValueError( diff --git a/pytorch3d/implicitron/models/view_pooler/feature_aggregator.py b/pytorch3d/implicitron/models/view_pooler/feature_aggregator.py index 798541ea6..fa11e93d8 100644 --- a/pytorch3d/implicitron/models/view_pooler/feature_aggregator.py +++ b/pytorch3d/implicitron/models/view_pooler/feature_aggregator.py @@ -118,9 +118,6 @@ class IdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): the outputs. 
""" - def __post_init__(self): - super().__init__() - def get_aggregated_feature_dim( self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] ): @@ -181,9 +178,6 @@ class ReductionFeatureAggregator(torch.nn.Module, FeatureAggregatorBase): ReductionFunction.STD, ) - def __post_init__(self): - super().__init__() - def get_aggregated_feature_dim( self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] ): @@ -275,9 +269,6 @@ class AngleWeightedReductionFeatureAggregator(torch.nn.Module, FeatureAggregator weight_by_ray_angle_gamma: float = 1.0 min_ray_angle_weight: float = 0.1 - def __post_init__(self): - super().__init__() - def get_aggregated_feature_dim( self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] ): @@ -377,9 +368,6 @@ class AngleWeightedIdentityFeatureAggregator(torch.nn.Module, FeatureAggregatorB weight_by_ray_angle_gamma: float = 1.0 min_ray_angle_weight: float = 0.1 - def __post_init__(self): - super().__init__() - def get_aggregated_feature_dim( self, feats_or_feats_dim: Union[Dict[str, torch.Tensor], int] ): diff --git a/pytorch3d/implicitron/models/view_pooler/view_pooler.py b/pytorch3d/implicitron/models/view_pooler/view_pooler.py index eca64b306..a47ef72de 100644 --- a/pytorch3d/implicitron/models/view_pooler/view_pooler.py +++ b/pytorch3d/implicitron/models/view_pooler/view_pooler.py @@ -38,7 +38,6 @@ class ViewPooler(Configurable, torch.nn.Module): feature_aggregator: FeatureAggregatorBase def __post_init__(self): - super().__init__() run_auto_creation(self) def get_aggregated_feature_dim(self, feats: Union[Dict[str, torch.Tensor], int]): diff --git a/pytorch3d/implicitron/models/view_pooler/view_sampler.py b/pytorch3d/implicitron/models/view_pooler/view_sampler.py index 66e65e069..56f91ed2f 100644 --- a/pytorch3d/implicitron/models/view_pooler/view_sampler.py +++ b/pytorch3d/implicitron/models/view_pooler/view_sampler.py @@ -29,9 +29,6 @@ class ViewSampler(Configurable, torch.nn.Module): masked_sampling: bool = False sampling_mode: str = "bilinear" - def __post_init__(self): - super().__init__() - def forward( self, *, # force kw args diff --git a/pytorch3d/implicitron/tools/config.py b/pytorch3d/implicitron/tools/config.py index 5e5d945c5..3289bd32b 100644 --- a/pytorch3d/implicitron/tools/config.py +++ b/pytorch3d/implicitron/tools/config.py @@ -184,6 +184,7 @@ def __post_init__(self): CREATE_PREFIX: str = "create_" IMPL_SUFFIX: str = "_impl" TWEAK_SUFFIX: str = "_tweak_args" +_DATACLASS_INIT: str = "__dataclass_own_init__" class ReplaceableBase: @@ -834,6 +835,9 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None then the default_factory of x_args will also have a call to x_tweak_args(X, x_args) and the default_factory of x_Y_args will also have a call to x_tweak_args(Y, x_Y_args). + In addition, if the class inherits torch.nn.Module, the generated __init__ will + call torch.nn.Module's __init__ before doing anything else. + Note that although the *_args members are intended to have type DictConfig, they are actually internally annotated as dicts. OmegaConf is happy to see a DictConfig in place of a dict, but not vice-versa. 
     Allowing dict lets a class user specify
@@ -912,9 +916,40 @@ def x_tweak_args(cls, member_type: Type, args: DictConfig) -> None
     some_class._known_implementations = known_implementations
 
     dataclasses.dataclass(eq=False)(some_class)
+    _fixup_class_init(some_class)
     return some_class
 
 
+def _fixup_class_init(some_class) -> None:
+    """
+    In-place modification of the some_class class which happens
+    after dataclass processing.
+
+    If the dataclass some_class inherits torch.nn.Module, then
+    makes torch.nn.Module's __init__ be called before anything else
+    on instantiation of some_class.
+    This is a bit like attr's __pre_init__.
+    """
+
+    assert _is_actually_dataclass(some_class)
+    try:
+        import torch
+    except ModuleNotFoundError:
+        return
+
+    if not issubclass(some_class, torch.nn.Module):
+        return
+
+    def init(self, *args, **kwargs) -> None:
+        torch.nn.Module.__init__(self)
+        getattr(self, _DATACLASS_INIT)(*args, **kwargs)
+
+    assert not hasattr(some_class, _DATACLASS_INIT)
+
+    setattr(some_class, _DATACLASS_INIT, some_class.__init__)
+    some_class.__init__ = init
+
+
 def get_default_args_field(
     C,
     *,
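
As an illustration of the mechanism above (this sketch is not part of the
commit): _fixup_class_init stashes the dataclass-generated __init__ under a
sentinel attribute and installs a wrapper that runs torch.nn.Module.__init__
first. Below is a minimal, runnable standalone version; `fixup_class_init`
and `ToyNet` are hypothetical names used only for illustration, and unlike
the real helper this version omits the guarded torch import and the
dataclass assertions.

import dataclasses

import torch

_DATACLASS_INIT = "__dataclass_own_init__"


def fixup_class_init(some_class) -> None:
    # Simplified stand-in for the _fixup_class_init added in this patch.
    if not issubclass(some_class, torch.nn.Module):
        return

    def init(self, *args, **kwargs) -> None:
        # Set up Module machinery (_parameters, _modules, ...) before the
        # dataclass-generated __init__ assigns fields and __post_init__ runs.
        torch.nn.Module.__init__(self)
        getattr(self, _DATACLASS_INIT)(*args, **kwargs)

    # Stash the dataclass-generated __init__ and install the wrapper.
    setattr(some_class, _DATACLASS_INIT, some_class.__init__)
    some_class.__init__ = init


@dataclasses.dataclass(eq=False)  # eq=False keeps Module's identity semantics
class ToyNet(torch.nn.Module):
    hidden: int = 8

    def __post_init__(self):
        # Registering a submodule here works: Module.__init__ already ran.
        self.layer = torch.nn.Linear(self.hidden, self.hidden)


fixup_class_init(ToyNet)
net = ToyNet(hidden=4)  # no explicit super().__init__() needed
assert sum(p.numel() for p in net.parameters()) == 4 * 4 + 4

The ordering matters because torch.nn.Module.__setattr__ consults _parameters
and _modules, which exist only after Module.__init__ has run; calling it first
is what makes every super().__init__() removed in this diff redundant.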