fix: Remove pytorch overhead while finding fusions for fully convertible models (#3311)
peri044 authored Dec 16, 2024
1 parent 3dded9d commit 1840e36
Showing 2 changed files with 7 additions and 0 deletions.
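The PyTorch overhead named in the commit title is the FxNetAccFusionsFinder pass that the FX splitter runs before partitioning. The diff below derives a new skip_fusion flag in the compiler and threads it through to the splitter settings so that pass can be bypassed when every operator is convertible. A hedged approximation of the gate this flag toggles, based on the torch.fx.passes.splitter_base internals the docstring below refers to (not part of this commit; details may differ):

from torch.fx.passes.splitter_base import FxNetAccFusionsFinder

def find_fusions_if_needed(module, acc_nodes, skip_fusion: bool):
    # Approximate sketch: when skip_fusion is set, no fusion groups are computed.
    if skip_fusion:
        # Fully convertible graph: every node lands in TRT, so fusion groups
        # cannot change the partition and computing them is pure overhead.
        return {}
    return FxNetAccFusionsFinder(module, acc_nodes)()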
2 changes: 2 additions & 0 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -765,7 +765,9 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:
min_block_size=settings.min_block_size,
torch_executed_ops=settings.torch_executed_ops,
require_full_compilation=settings.require_full_compilation,
skip_fusion=(num_supported_ops == total_ops),
)

except torch.fx.passes.splitter_base.FxNetSplitterInternalError:
logger.error(
"Partitioning failed on the subgraph with fast partition. See trace above. "
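On the compiler side the entire change is the keyword argument above. For readability, a hedged reconstruction of the call site that the hunk touches; the partitioning.fast_partition name, the settings object, and the operator counts are taken from the surrounding context lines and are assumptions where the diff does not show them:

import logging

import torch
from torch.fx.passes.splitter_base import FxNetSplitterInternalError

logger = logging.getLogger(__name__)

def run_fast_partition(gm: torch.fx.GraphModule, partitioning, settings,
                       num_supported_ops: int, total_ops: int):
    # Sketch of the _compiler.py call site after this commit.
    try:
        partitioned_module, supported_ops = partitioning.fast_partition(
            gm,
            min_block_size=settings.min_block_size,
            torch_executed_ops=settings.torch_executed_ops,
            require_full_compilation=settings.require_full_compilation,
            # New: skip the fusion search when the model is fully convertible.
            skip_fusion=(num_supported_ops == total_ops),
        )
    except FxNetSplitterInternalError:
        logger.error(
            "Partitioning failed on the subgraph with fast partition. See trace above. "
        )
        raise
    return partitioned_module, supported_ops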
5 changes: 5 additions & 0 deletions
@@ -111,6 +111,7 @@ def __init__(
min_block_size: int = MIN_BLOCK_SIZE,
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
return_tuple: bool = False,
skip_fusion: bool = False,
):
"""
Preprocesses graph before splitting:
@@ -127,6 +128,7 @@
self.settings = _SplitterSettingBase(
min_acc_module_size=min_block_size,
allow_non_tensor=True,
skip_fusion=skip_fusion,
)
self.operator_support = operator_support

@@ -252,6 +254,7 @@ def partition(
min_block_size: int = MIN_BLOCK_SIZE,
torch_executed_ops: Collection[Target] = set(),
require_full_compilation: bool = REQUIRE_FULL_COMPILATION,
skip_fusion: bool = False,
) -> Tuple[torch.fx.GraphModule, OpSupportTester]:
"""Partition an FX GraphModule with aten ops into TRT engines
Partitioning is based on converter operator support
@@ -262,6 +265,7 @@
min_block_size: Minimum number of operators per TRT-Engine Block
torch_executed_ops: Collection of operations to run in Torch, regardless of converter coverage
require_full_compilation: Require that all computational operators be run in TRT
skip_fusion: Skip fusions found by FxNetAccFusionsFinder
Returns:
torch.fx.GraphModule, OpSupportTester
"""
@@ -277,6 +281,7 @@
supported_ops,
min_block_size=min_block_size,
require_full_compilation=require_full_compilation,
skip_fusion=skip_fusion,
)

partitioned_graph = partitioner.partition_graph()
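The partitioner-side changes are plumbing: partition() gains a skip_fusion parameter, forwards it to the partitioner's constructor, and the constructor stores it on the torch.fx _SplitterSettingBase it already builds (see the hunks above). A hedged sketch of that forwarding path in isolation; the MiniPartitioner class name and its defaults are placeholders, while the _SplitterSettingBase keywords match the diff:

from torch.fx.passes.splitter_base import _SplitterSettingBase

class MiniPartitioner:
    """Placeholder standing in for the partitioner class modified above."""

    def __init__(self, min_block_size: int = 5, skip_fusion: bool = False):
        # skip_fusion rides on the same settings object that already carries
        # min_acc_module_size; the upstream splitter reads it when deciding
        # whether to run FxNetAccFusionsFinder.
        self.settings = _SplitterSettingBase(
            min_acc_module_size=min_block_size,
            allow_non_tensor=True,
            skip_fusion=skip_fusion,
        )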
