From 132f1328d2c02ebec181851f3db47b63ca7abd51 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Fri, 11 Aug 2023 16:48:00 -0700
Subject: [PATCH] fix: Add generic evaluator function

---
 .../dynamo/conversion/aten_ops_converters.py  | 24 ++++++-----
 .../dynamo/conversion/converter_registry.py   |  6 +--
 .../dynamo/conversion/impl/__init__.py        |  1 -
 .../dynamo/conversion/impl/cast.py            | 20 ++++++++++
 .../dynamo/conversion/impl/evaluators.py      | 40 -------------------
 tests/py/dynamo/converters/test_casts.py      | 30 ++++++++++++++
 tests/py/dynamo/converters/test_evaluators.py | 30 --------------
 tests/py/dynamo/models/test_models.py         |  2 -
 8 files changed, 67 insertions(+), 86 deletions(-)
 delete mode 100644 py/torch_tensorrt/dynamo/conversion/impl/evaluators.py

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index f6b8dbcc78..96ed6ebb53 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -14,7 +14,7 @@
 from torch_tensorrt.fx.converters import acc_ops_converters
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
-from .converter_registry import dynamo_tensorrt_converter
+from .converter_registry import ConverterRegistry, dynamo_tensorrt_converter
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -421,22 +421,26 @@ def aten_ops_to_copy_dtype(
     )
 
 
-@dynamo_tensorrt_converter(operator.getitem)
-def operator_getitem(
+def getitem_validator(getitem_node: Node) -> bool:
+    from torch_tensorrt.dynamo.conversion.converter_registry import DYNAMO_CONVERTERS
+
+    # Getitem nodes can only be converted if their parent node also can
+    return getitem_node.args[0] in DYNAMO_CONVERTERS
+
+
+# TODO: Subsequent evaluators should be registered here with their own validators
+@dynamo_tensorrt_converter(operator.getitem, capability_validator=getitem_validator)
+def generic_evaluator(
     network: TRTNetwork,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.evaluators.getitem(
-        network,
-        target,
-        SourceIR.ATEN,
-        name,
-        args[0],
-        args[1],
+    _LOGGER.debug(
+        f"Evaluating {ConverterRegistry.qualified_name_or_str(target)} on object with name: {name}"
     )
+    return target(*args, **kwargs)
 
 
 @dynamo_tensorrt_converter(torch.ops.aten.clone.default)
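Note on the hunk above (not part of the patch): `generic_evaluator` builds no TensorRT layers; it simply executes the Python target eagerly at conversion time. A minimal runnable sketch of that dispatch for `operator.getitem`, where `parent_outputs` is a hypothetical stand-in for upstream converter results:

```python
import operator
from typing import Any, Dict, Sequence, Tuple


def evaluate(target, args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
    # Same dispatch as the patched generic_evaluator: run the op in Python
    return target(*args, **kwargs)


# Hypothetical outputs of a multi-output parent node (e.g. a split)
parent_outputs: Sequence[str] = ("trt_tensor_0", "trt_tensor_1")
assert evaluate(operator.getitem, (parent_outputs, 1), {}) == "trt_tensor_1"
```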
diff --git a/py/torch_tensorrt/dynamo/conversion/converter_registry.py b/py/torch_tensorrt/dynamo/conversion/converter_registry.py
index db41420367..bc22043aad 100644
--- a/py/torch_tensorrt/dynamo/conversion/converter_registry.py
+++ b/py/torch_tensorrt/dynamo/conversion/converter_registry.py
@@ -64,7 +64,7 @@ def dynamo_tensorrt_converter(
     enabled: bool = True,
     capability_validator: Optional[Callable[[Node], bool]] = None,
     priority: ConverterPriority = ConverterPriority.STANDARD,
-) -> Callable[[Any], Union[TRTTensor, Sequence[TRTTensor]]]:
+) -> Callable[[Any], TRTTensor | Sequence[TRTTensor]]:
     """Decorator for Dynamo TensorRT Converter
 
     Registers the decorated function in the DYNAMO_ATEN_CONVERTERS registry
@@ -347,8 +347,8 @@ def unique_targets(self) -> Set[Target]:
         """Returns the set of unique converter targets stored across all registries"""
         return set.union(*[set(registry.keys()) for registry in self.registries])
 
-    # TODO: Make this a static method since it does not need state
-    def qualified_name_or_str(self, target: Target) -> str:
+    @staticmethod
+    def qualified_name_or_str(target: Target) -> str:
         """Returns string representation of an FX Node target"""
         if isinstance(target, str):
             return target
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
index 611dc630fa..8f7ab1badc 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/__init__.py
@@ -6,7 +6,6 @@
     condition,
     elementwise,
     embedding,
-    evaluators,
     matmul,
     normalization,
     permutation,
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/cast.py b/py/torch_tensorrt/dynamo/conversion/impl/cast.py
index 68899de766..0c55731169 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/cast.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/cast.py
@@ -1,3 +1,4 @@
+import logging
 from typing import Optional
 
 from torch.fx.node import Target
@@ -5,6 +6,8 @@
 from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor
 from torch_tensorrt.fx.types import TRTDataType, TRTNetwork, TRTTensor
 
+LOGGER: logging.Logger = logging.getLogger(__name__)
+
 
 def to_copy(
     network: TRTNetwork,
@@ -21,3 +24,20 @@ def to_copy(
 
     casted_tensor = cast_trt_tensor(network, input, dtype, name, target, source_ir)
     return casted_tensor
+
+
+def clone(
+    network: TRTNetwork,
+    target: Target,
+    source_ir: Optional[SourceIR],
+    name: str,
+    input: TRTTensor,
+) -> TRTTensor:
+    if not isinstance(input, TRTTensor):
+        raise RuntimeError(
+            f"clone received input {input} that is not a TensorRT ITensor"
+        )
+
+    LOGGER.debug(f"Evaluating clone on object with name: {name}")
+
+    return input
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py b/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py
deleted file mode 100644
index cb61fb6158..0000000000
--- a/py/torch_tensorrt/dynamo/conversion/impl/evaluators.py
+++ /dev/null
@@ -1,40 +0,0 @@
-import logging
-import operator
-from typing import Optional, Sequence
-
-from torch.fx.node import Target
-from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
-
-LOGGER: logging.Logger = logging.getLogger(__name__)
-
-
-def getitem(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: Sequence[TRTTensor],
-    index: int,
-) -> TRTTensor:
-    LOGGER.debug(f"Evaluating getitem on object with name: {name}")
-
-    # Directly index the input sequence and return the value
-    return operator.getitem(input, index)
-
-
-def clone(
-    network: TRTNetwork,
-    target: Target,
-    source_ir: Optional[SourceIR],
-    name: str,
-    input: TRTTensor,
-) -> TRTTensor:
-    if not isinstance(input, TRTTensor):
-        raise RuntimeError(
-            f"clone received input {input} that is not a TensorRT ITensor"
-        )
-
-    LOGGER.debug(f"Evaluating clone on object with name: {name}")
-
-    return input
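Aside on the relocated `clone` converter (not part of the patch): returning the input ITensor unchanged is sound because, absent in-place mutation, an op applied to a clone equals the op applied to the original. A quick PyTorch-level check of that equivalence, mirroring the shapes used in the new tests:

```python
import torch

x = torch.randn(8, 2, 10)
# An op on a clone matches the op on the original tensor, which is why
# the converter can forward its input unchanged.
assert torch.equal(torch.clone(x) + 1, x + 1)
```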
diff --git a/tests/py/dynamo/converters/test_casts.py b/tests/py/dynamo/converters/test_casts.py
index 4bb05ef463..3a4fd65610 100644
--- a/tests/py/dynamo/converters/test_casts.py
+++ b/tests/py/dynamo/converters/test_casts.py
@@ -5,6 +5,36 @@
 from torch_tensorrt.dynamo.conversion import UnsupportedOperatorException
 
 
+class TestCloneConverter(DispatchTestCase):
+    def test_clone_contiguous(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                y = torch.clone(x, memory_format=torch.contiguous_format)
+                return y + 1
+
+        inputs = [torch.randn((1, 3, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
+    def test_clone_regular(self):
+        class Clone(nn.Module):
+            def forward(self, x):
+                y = torch.clone(x)
+                return y + 1
+
+        inputs = [torch.randn((8, 2, 10))]
+        self.run_test(
+            Clone(),
+            inputs,
+            expected_ops={torch.ops.aten.clone.default},
+            disable_passes=True,
+        )
+
+
 class TestToCopyConverter(DispatchTestCase):
     def test_to_copy_half(self):
         class ToCopyHalf(nn.Module):
diff --git a/tests/py/dynamo/converters/test_evaluators.py b/tests/py/dynamo/converters/test_evaluators.py
index cf42009495..64dd303727 100644
--- a/tests/py/dynamo/converters/test_evaluators.py
+++ b/tests/py/dynamo/converters/test_evaluators.py
@@ -7,36 +7,6 @@
 from torch.testing._internal.common_utils import run_tests
 
 
-class TestCloneConverter(DispatchTestCase):
-    def test_clone_contiguous(self):
-        class Clone(nn.Module):
-            def forward(self, x):
-                y = torch.clone(x, memory_format=torch.contiguous_format)
-                return y + 1
-
-        inputs = [torch.randn((1, 3, 10))]
-        self.run_test(
-            Clone(),
-            inputs,
-            expected_ops={torch.ops.aten.clone.default},
-            disable_passes=True,
-        )
-
-    def test_clone_regular(self):
-        class Clone(nn.Module):
-            def forward(self, x):
-                y = torch.clone(x)
-                return y + 1
-
-        inputs = [torch.randn((8, 2, 10))]
-        self.run_test(
-            Clone(),
-            inputs,
-            expected_ops={torch.ops.aten.clone.default},
-            disable_passes=True,
-        )
-
-
 # TODO: Switch this test back to self.run_test once an implementation exists
 # for a converter that returns a list, such as aten.split
 @unittest.skip("Pending aten.split converter. Currently tested by E2E")
diff --git a/tests/py/dynamo/models/test_models.py b/tests/py/dynamo/models/test_models.py
index c8f730e2e6..50d7fcbbd9 100644
--- a/tests/py/dynamo/models/test_models.py
+++ b/tests/py/dynamo/models/test_models.py
@@ -27,7 +27,6 @@ def test_resnet18(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
         "ir": "torch_compile",
     }
 
@@ -176,7 +175,6 @@ def test_resnet18_half(ir):
         "ir": ir,
         "pass_through_build_failures": True,
         "optimization_level": 1,
-        "min_block_size": 10,
         "ir": "torch_compile",
     }
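Closing note (not part of the patch): the `getitem` nodes this change targets arise whenever an FX-traced module indexes a multi-output op such as `torch.split`, the very converter the skipped test above is pending on. A small sketch of what `getitem_validator` inspects; the `supported` set is a hypothetical stand-in for membership in `DYNAMO_CONVERTERS`, whose real registry accepts the parent `Node` directly:

```python
import operator

import torch
from torch.fx import symbolic_trace


class TakeFirst(torch.nn.Module):
    def forward(self, x):
        # Indexing a multi-output op emits an operator.getitem node in FX
        return torch.split(x, 1, dim=0)[0]


gm = symbolic_trace(TakeFirst())
supported = {torch.split}  # hypothetical stand-in for DYNAMO_CONVERTERS

for node in gm.graph.nodes:
    if node.op == "call_function" and node.target is operator.getitem:
        parent = node.args[0]
        # Mirrors getitem_validator: getitem converts only if its parent does
        print(parent.target in supported)
```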