Skip to content

Commit

Permalink
Prepare for "Fix type-safety of torch.nn.Module instances": wave 2
Browse files Browse the repository at this point in the history
Summary: See D52890934

Reviewed By: malfet, r-barnes

Differential Revision: D66245100

fbshipit-source-id: 019058106ac7eaacf29c1c55912922ea55894d23
  • Loading branch information
ezyang authored and facebook-github-bot committed Nov 21, 2024
1 parent 754469e commit c69939a
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ def forward(
return torch.stack(losses, dim=0).mean()

def fake_value(self, densepose_predictor_outputs: Any, embedder: nn.Module):
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is not a
# function.
losses = [embedder(mesh_name).sum() * 0 for mesh_name in embedder.mesh_names]
losses.append(densepose_predictor_outputs.embedding.sum() * 0)
return torch.mean(torch.stack(losses))
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ def forward(self, embedder: nn.Module):

def fake_value(self, embedder: nn.Module):
losses = []
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is not a
# function.
for mesh_name in embedder.mesh_names:
losses.append(embedder(mesh_name).sum() * 0)
return torch.mean(torch.stack(losses))
Expand Down
4 changes: 4 additions & 0 deletions projects/DensePose/densepose/modeling/losses/embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ def __call__(
) / (-self.embdist_gauss_sigma)
losses[mesh_name] = F.cross_entropy(scores, vertex_indices_i, ignore_index=-1)

# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is not a
# function.
for mesh_name in embedder.mesh_names:
if mesh_name not in losses:
losses[mesh_name] = self.fake_value(
Expand All @@ -113,6 +115,8 @@ def __call__(

def fake_values(self, densepose_predictor_outputs: Any, embedder: nn.Module):
losses = {}
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is not a
# function.
for mesh_name in embedder.mesh_names:
losses[mesh_name] = self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
return losses
Expand Down
4 changes: 4 additions & 0 deletions projects/DensePose/densepose/modeling/losses/soft_embed.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def __call__(
)
losses[mesh_name] = (-geodist_softmax_values * embdist_logsoftmax_values).sum(1).mean()

# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is not a
# function.
for mesh_name in embedder.mesh_names:
if mesh_name not in losses:
losses[mesh_name] = self.fake_value(
Expand All @@ -127,6 +129,8 @@ def __call__(

def fake_values(self, densepose_predictor_outputs: Any, embedder: nn.Module):
losses = {}
# pyre-fixme[29]: `Union[(self: Tensor) -> Any, Module, Tensor]` is not a
# function.
for mesh_name in embedder.mesh_names:
losses[mesh_name] = self.fake_value(densepose_predictor_outputs, embedder, mesh_name)
return losses
Expand Down
2 changes: 2 additions & 0 deletions projects/DensePose/densepose/modeling/roi_heads/roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def _forward_densepose(self, features: Dict[str, torch.Tensor], instances: List[
proposal_boxes = [x.proposal_boxes for x in proposals]

if self.use_decoder:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
features_list = [self.decoder(features_list)]

features_dp = self.densepose_pooler(features_list, proposal_boxes)
Expand All @@ -171,6 +172,7 @@ def _forward_densepose(self, features: Dict[str, torch.Tensor], instances: List[
pred_boxes = [x.pred_boxes for x in instances]

if self.use_decoder:
# pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
features_list = [self.decoder(features_list)]

features_dp = self.densepose_pooler(features_list, pred_boxes)
Expand Down

0 comments on commit c69939a

Please sign in to comment.