chore: Rename to CudaGraphsTorchTensorRTModule class
keehyuna committed Dec 13, 2024
1 parent 5edf79a commit 11886fe
Showing 6 changed files with 13 additions and 329 deletions.
111 changes: 0 additions & 111 deletions examples/dynamo/cudagraphs_wrapper_example.py

This file was deleted.

8 changes: 4 additions & 4 deletions py/torch_tensorrt/_compile.py
@@ -12,8 +12,8 @@
 from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
-from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
-    WrapperTorchTensorRTModule,
+from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
+    CudaGraphsTorchTensorRTModule,
 )
 from torch_tensorrt.fx import InputTensorSpec
 from torch_tensorrt.fx.lower import compile as fx_compile
@@ -589,15 +589,15 @@ def save(
     Save the model to disk in the specified output format.
     Arguments:
-        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | WrapperTorchTensorRTModule)): Compiled Torch-TensorRT module
+        module (Optional(torch.jit.ScriptModule | torch.export.ExportedProgram | torch.fx.GraphModule | CudaGraphsTorchTensorRTModule)): Compiled Torch-TensorRT module
         inputs (torch.Tensor): Torch input tensors
         arg_inputs (Tuple[Any, ...]): Same as inputs. Alias for better understanding with kwarg_inputs.
         kwarg_inputs (dict[Any, ...]): Optional, kwarg inputs to the module forward function.
         output_format (str): Format to save the model. Options include exported_program | torchscript.
         retrace (bool): When the module type is a fx.GraphModule, this option re-exports the graph using torch.export.export(strict=False) to save it.
            This flag is experimental for now.
     """
-    if isinstance(module, WrapperTorchTensorRTModule):
+    if isinstance(module, CudaGraphsTorchTensorRTModule):
         module = module.compiled_module
     module_type = _parse_module_type(module)
     accepted_formats = {"exported_program", "torchscript"}
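For context, a minimal usage sketch of torch_tensorrt.save() with a dynamo-compiled module; the model, inputs, and output path below are hypothetical placeholders, not taken from this commit.

import torch
import torch_tensorrt

# Hypothetical model and inputs -- placeholders, not part of this commit.
model = torch.nn.Linear(16, 16).eval().cuda()
inputs = [torch.randn(8, 16).cuda()]

trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# Per the hunk above, save() now unwraps a CudaGraphsTorchTensorRTModule
# (module.compiled_module) before serializing the compiled program.
torch_tensorrt.save(trt_module, "trt.ep", inputs=inputs, output_format="exported_program")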
py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py → py/torch_tensorrt/dynamo/runtime/_CudaGraphsTorchTensorRTModule.py
@@ -11,7 +11,7 @@
 logger = logging.getLogger(__name__)
 
 
-class WrapperTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
+class CudaGraphsTorchTensorRTModule(torch.nn.Module):  # type: ignore[misc]
     """This Wrapper runtime module is to record/replay whole cuda graph in sub modules
     Args:
@@ -24,7 +24,7 @@ def __init__(
         self,
         compiled_module: torch.nn.Module,
     ):
-        super(WrapperTorchTensorRTModule, self).__init__()
+        super(CudaGraphsTorchTensorRTModule, self).__init__()
         self.compiled_module = compiled_module
         self.inputs = partitioning.construct_submodule_inputs(compiled_module)
 
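For background, the record/replay pattern such a wrapper builds on is PyTorch's torch.cuda.CUDAGraph API. The sketch below is a generic, illustrative example of that pattern only; the wrapper's actual forward() is not shown in this diff.

import torch

# Generic CUDA graph record/replay sketch (illustrative only; not the
# CudaGraphsTorchTensorRTModule implementation, whose forward() this diff omits).
module = torch.nn.Linear(16, 16).cuda().eval()
static_input = torch.randn(8, 16, device="cuda")

# Warm up on a side stream so one-time allocations are not captured.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        module(static_input)
torch.cuda.current_stream().wait_stream(s)

# Record the whole forward pass into a single CUDA graph.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = module(static_input)

# Replay: copy fresh data into the captured input buffer and relaunch the graph.
static_input.copy_(torch.randn(8, 16, device="cuda"))
g.replay()
print(static_output.shape)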
6 changes: 3 additions & 3 deletions py/torch_tensorrt/runtime/_cudagraphs.py
@@ -3,8 +3,8 @@
 
 import torch
 import torch_tensorrt
-from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
-    WrapperTorchTensorRTModule,
+from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
+    CudaGraphsTorchTensorRTModule,
 )
 
 
@@ -90,7 +90,7 @@ def __enter__(self) -> torch.nn.Module:
            logger.debug(
                f"{num_torch_module} torch modules are in subgraphs. Using wrapper module for cuda graphs"
            )
-            return WrapperTorchTensorRTModule(self.compiled_module)
+            return CudaGraphsTorchTensorRTModule(self.compiled_module)
        else:
            if num_trt_module > 0:
                logger.debug(
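A minimal usage sketch of the context manager after this change, assuming enable_cudagraphs() accepts the compiled module as the __enter__ hunk above implies; the model and inputs are hypothetical placeholders. When the compiled graph contains torch submodules (graph breaks), __enter__ returns a CudaGraphsTorchTensorRTModule that records/replays the whole graph.

import torch
import torch_tensorrt

# Hypothetical model and inputs -- placeholders, not part of this commit.
model = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(8, 16).cuda()]

trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)

# If subgraphs fell back to torch (graph breaks), the returned module is a
# CudaGraphsTorchTensorRTModule wrapping the whole compiled program.
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    out = cudagraphs_module(*inputs)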
8 changes: 4 additions & 4 deletions py/torch_tensorrt/runtime/_weight_streaming.py
@@ -3,8 +3,8 @@
 
 import torch
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule, TorchTensorRTModule
-from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
-    WrapperTorchTensorRTModule,
+from torch_tensorrt.dynamo.runtime._CudaGraphsTorchTensorRTModule import (
+    CudaGraphsTorchTensorRTModule,
 )
 
 logger = logging.getLogger(__name__)
@@ -16,12 +16,12 @@ class _WeightStreamingContextManager(object):
     """
 
     def __init__(
-        self, module: torch.fx.GraphModule | WrapperTorchTensorRTModule
+        self, module: torch.fx.GraphModule | CudaGraphsTorchTensorRTModule
     ) -> None:
         rt_mods = []
         self.current_device_budget = 0
 
-        if isinstance(module, WrapperTorchTensorRTModule):
+        if isinstance(module, CudaGraphsTorchTensorRTModule):
             module = module.compiled_module
         for name, rt_mod in module.named_children():
             if "_run_on_acc" in name and isinstance(
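A hedged usage sketch of the weight-streaming context manager with the renamed wrapper in mind; the compile flags (use_explicit_typing, enable_weight_streaming) and the device_budget / total_device_budget attributes are assumptions drawn from the public weight-streaming API, not from this diff, and the model is a placeholder.

import torch
import torch_tensorrt

# Hypothetical model; flags and budget attributes below are assumptions,
# not shown in this diff.
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()).half().eval().cuda()
inputs = [torch.randn(8, 64, dtype=torch.half).cuda()]

trt_module = torch_tensorrt.compile(
    model,
    ir="dynamo",
    inputs=inputs,
    use_explicit_typing=True,       # assumed prerequisite for weight streaming
    enable_weight_streaming=True,   # assumed compile-time switch
)

# Per the hunk above, the context manager also accepts a
# CudaGraphsTorchTensorRTModule and unwraps it via module.compiled_module.
with torch_tensorrt.runtime.weight_streaming(trt_module) as ws_ctx:
    ws_ctx.device_budget = int(ws_ctx.total_device_budget * 0.5)
    out = trt_module(*inputs)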