Skip to content

Commit

Permalink
fix: Fix copying metadata during lowering (#3320)
Browse files Browse the repository at this point in the history
  • Loading branch information
peri044 authored Dec 12, 2024
1 parent de39fa3 commit 184f601
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 8 deletions.
12 changes: 6 additions & 6 deletions py/torch_tensorrt/dynamo/lowering/passes/view_to_reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
from torch_tensorrt.dynamo.utils import get_metadata, set_metadata
from torch_tensorrt.dynamo.utils import copy_metadata

logger = logging.getLogger(__name__)

Expand All @@ -26,14 +26,14 @@ def orig(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
def replacement(input: torch.Tensor, shape: List[torch.SymInt]) -> torch.Tensor:
return replacement_op(input, shape)

# Store metadata of the orig_op
metadata = get_metadata(gm, orig_op)

if torch.fx.subgraph_rewriter.replace_pattern(gm, orig, replacement):
match_and_replacements = torch.fx.subgraph_rewriter._replace_pattern(
gm, orig, replacement
)
if match_and_replacements:
gm = clean_up_graph_after_modifications(gm)
logger.debug(f"Graph after replacing view with reshape:\n{gm.graph}")

# Copy the orig_op's metadata to the replacement op
set_metadata(gm, replacement_op, metadata)
copy_metadata(match_and_replacements)

return gm
18 changes: 16 additions & 2 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
import tensorrt as trt
import torch
from torch._subclasses.fake_tensor import FakeTensor

from packaging import version
from torch_tensorrt._Device import Device
from torch_tensorrt._enums import dtype
from torch_tensorrt._features import ENABLED_FEATURES
Expand All @@ -22,6 +20,8 @@
from torch_tensorrt.dynamo._engine_cache import BaseEngineCache
from torch_tensorrt.dynamo._settings import CompilationSettings

from packaging import version

from .types import TRTDataType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -716,6 +716,20 @@ def set_metadata(
node.meta = metadata[idx]


def copy_metadata(match_and_replacements: List[Any]) -> None:
"""
Copy the metadata from anchor node to the replacement node. This should be used
if the anchor node is replaced with only a single replacement node i.e one-one replacement.
"""
for match_and_replacement in match_and_replacements:
anchor_node = match_and_replacement.nodes_map[match_and_replacement.anchor]
assert (
len(match_and_replacement.replacements) == 1
), "Found more than 1 replacements for the anchor node."
replacement_node = match_and_replacement.replacements[0]
replacement_node.meta = anchor_node.meta


def flatten_nodes(nodes: Any) -> List[torch.fx.node.Node]:
ret = []
if isinstance(nodes, torch.fx.node.Node):
Expand Down

0 comments on commit 184f601

Please sign in to comment.