Skip to content

Commit

Permalink
fix: Fix additional mem copy of the model during re-export (#3302)
Browse files Browse the repository at this point in the history
  • Loading branch information
peri044 authored Dec 12, 2024
1 parent 3e376c4 commit d3a8880
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
21 changes: 19 additions & 2 deletions py/torch_tensorrt/dynamo/runtime/register_fake_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def fake_tensorrt_execute_engine(
modes = ["opt"]

# Get the TRTEngine class and infer output shapes based on input shapes
trt_engine = fake_trt_engine.wrapped_obj.engine
trt_engine = fake_trt_engine.real_obj
outputs_mode_dict = defaultdict(list)
for mode in modes:
input_shapes = [unwrap_tensor_shape(input, mode=mode) for input in inputs]
Expand Down Expand Up @@ -79,7 +79,21 @@ def fake_tensorrt_execute_engine(
@torch._library.register_fake_class("tensorrt::Engine")
class FakeTRTEngine:
def __init__(self, engine_info: List[str]) -> None:
self.engine = torch.classes.tensorrt.Engine(engine_info)
self.version = engine_info[torch.ops.tensorrt.ABI_TARGET_IDX()]
self.name = engine_info[torch.ops.tensorrt.NAME_IDX()]
self.device_info = engine_info[torch.ops.tensorrt.DEVICE_IDX()]
self.serialized_engine = engine_info[torch.ops.tensorrt.ENGINE_IDX()]
self.in_binding_names = engine_info[
torch.ops.tensorrt.INPUT_BINDING_NAMES_IDX()
]
self.out_binding_names = engine_info[
torch.ops.tensorrt.OUTPUT_BINDING_NAMES_IDX()
]
self.hardware_compatible = engine_info[torch.ops.tensorrt.HW_COMPATIBLE_IDX()]
self.serialized_metadata = engine_info[
torch.ops.tensorrt.SERIALIZED_METADATA_IDX()
]
self.target_platform = engine_info[torch.ops.tensorrt.TARGET_PLATFORM_IDX()]

@classmethod
def __obj_unflatten__(cls, flattened_tq: Any) -> Any:
Expand Down Expand Up @@ -127,3 +141,6 @@ def infer_outputs(self, input_shapes: List[Any]) -> Any:

def __setstate__(self, serialized_state: List[str]) -> Any:
pass

def __getstate__(self) -> Any:
pass
1 change: 0 additions & 1 deletion tests/py/dynamo/models/test_reexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@ def forward(self, x):

exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path)

# Reexport
trt_exp_program = torch.export.export(trt_module, (input,), strict=False)
Expand Down

0 comments on commit d3a8880

Please sign in to comment.