chore: update in cuda graphs api
keehyuna committed Dec 13, 2024
1 parent 614b670 commit 5edf79a
Showing 13 changed files with 168 additions and 142 deletions.
2 changes: 0 additions & 2 deletions docsrc/index.rst
@@ -67,7 +67,6 @@ Tutorials
* :ref:`custom_kernel_plugins`
* :ref:`mutable_torchtrt_module_example`
* :ref:`weight_streaming_example`
* :ref:`cudagraphs_wrapper_example`

.. toctree::
:caption: Tutorials
@@ -85,7 +84,6 @@ Tutorials
tutorials/_rendered_examples/dynamo/custom_kernel_plugins
tutorials/_rendered_examples/dynamo/mutable_torchtrt_module_example
tutorials/_rendered_examples/dynamo/weight_streaming_example
tutorials/_rendered_examples/dynamo/cudagraphs_wrapper_example

Dynamo Frontend
----------------
2 changes: 1 addition & 1 deletion docsrc/user_guide/runtime.rst
@@ -86,7 +86,7 @@ Cudagraphs can accelerate certain models by reducing kernel overheads, as docume
torch_tensorrt.runtime.set_cudagraphs_mode(False)
# Enables Cudagraphs Mode, then resets the mode to its prior setting
with torch_tensorrt.runtime.enable_cudagraphs():
with torch_tensorrt.runtime.enable_cudagraphs(trt_module):
...
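For reference, a fuller hedged sketch of the same flow (an editorial illustration, not part of this change; the toy SampleModel, shapes, and options are borrowed from the examples later in this commit):

import torch
import torch_tensorrt

class SampleModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu((x + 2) * 0.5)

model = SampleModel().eval().cuda()
x = torch.randn((1, 3, 224, 224)).cuda()
trt_module = torch_tensorrt.compile(model, ir="dynamo", inputs=[x], min_block_size=1)

# Session-wide toggle
torch_tensorrt.runtime.set_cudagraphs_mode(True)
out = trt_module(x)
torch_tensorrt.runtime.set_cudagraphs_mode(False)

# Scoped usage: the context manager takes the compiled module and yields the module to run
# (a wrapped module when the compiled module contains graph breaks)
with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    out = cudagraphs_module(x)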
In the current implementation, use of a new input shape (for instance in dynamic shape
31 changes: 22 additions & 9 deletions examples/dynamo/cudagraphs_wrapper_example.py
@@ -36,7 +36,7 @@ def forward(self, x):
# Compiler options
# ----------------------------------
#
# The 'torch_executed_ops' compiler option is used to demonstrate graph breaks for this example.
# The 'torch_executed_ops' compiler option is used in this example to build a module with graph breaks.
# The debug=True compiler option provides detailed insight into the compilation process and helps
# pinpoint where graph breaks occur

@@ -47,6 +47,14 @@ def forward(self, x):
inputs=[input],
min_block_size=1,
pass_through_build_failures=True,
)

trt_model_with_graph_break = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=[input],
min_block_size=1,
pass_through_build_failures=True,
debug=True,
torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)
@@ -76,23 +84,28 @@ def forward(self, x):
# trt module with cuda graphs
# ----------------------------------
#
# When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
# overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
# optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
# kernel launch overhead and improved execution efficiency, may be diminished.
with torch_tensorrt.runtime.enable_cudagraphs():
trt_model(input)

with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as cudagraphs_module:
cudagraphs_module(input)

# %%
# Running the wrapped module with cuda graphs
# ---------------------------------------------
#
# When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
# overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
# optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
# kernel launch overhead and improved execution efficiency, may be diminished.
# Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs
# that can be executed efficiently, even in the presence of graph breaks. When a CUDA Graph context manager is
# used with the TensorRT module as a positional argument, it returns a wrapped_module. This module captures the
# execution graph, enabling efficient replay during subsequent inferences by reducing kernel launch overheads
# and improving performance. Note that initializing with the wrapper module involves a warm-up phase where the
# module is executed several times. This warm-up ensures that memory allocations and initializations are not
# recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize performance.
with torch_tensorrt.runtime.enable_cudagraphs(trt_model) as wrapped_module:
wrapped_module(input)
with torch_tensorrt.runtime.enable_cudagraphs(
trt_model_with_graph_break
) as cudagraphs_module:
cudagraphs_module(input)
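To check whether the wrapped execution helps for a given model, a rough timing comparison can be appended (an editorial sketch that reuses trt_model_with_graph_break and input from above; the iteration counts are illustrative and the measured gap will vary by model and hardware):

def benchmark(mod, x, iters=100):
    # Warm up outside the measurement, then average latency with CUDA events
    for _ in range(10):
        mod(x)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()
    start.record()
    for _ in range(iters):
        mod(x)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per call

baseline_ms = benchmark(trt_model_with_graph_break, input)
with torch_tensorrt.runtime.enable_cudagraphs(
    trt_model_with_graph_break
) as cudagraphs_module:
    wrapped_ms = benchmark(cudagraphs_module, input)
print(f"TRT with graph breaks: {baseline_ms:.3f} ms, CUDA Graphs wrapper: {wrapped_ms:.3f} ms")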

# %%
54 changes: 48 additions & 6 deletions examples/dynamo/torch_export_cudagraphs.py
@@ -11,9 +11,8 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

import torch
import torchvision.models as models

import torch_tensorrt
import torchvision.models as models

# %%
# Compilation with `torch_tensorrt.compile` Using Default Settings
@@ -47,7 +46,7 @@
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

# We can enable the cudagraphs API with a context manager
with torch_tensorrt.runtime.enable_cudagraphs():
with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
out_trt = opt(inputs)

# Alternatively, we can set the cudagraphs mode for the session
@@ -64,6 +63,49 @@
inputs_2 = torch.randn((8, 3, 224, 224)).cuda()
inputs_3 = torch.randn((4, 3, 224, 224)).cuda()

with torch_tensorrt.runtime.enable_cudagraphs():
out_trt_2 = opt(inputs_2)
out_trt_3 = opt(inputs_3)
with torch_tensorrt.runtime.enable_cudagraphs(opt) as cudagraphs_module:
out_trt_2 = cudagraphs_module(inputs_2)
out_trt_3 = cudagraphs_module(inputs_3)

# %%
# Cuda graphs with a module that contains graph breaks
# ------------------------------------------------------
#
# When CUDA Graphs are applied to a TensorRT model that contains graph breaks, each break introduces additional
# overhead. This occurs because graph breaks prevent the entire model from being executed as a single, continuous
# optimized unit. As a result, some of the performance benefits typically provided by CUDA Graphs, such as reduced
# kernel launch overhead and improved execution efficiency, may be diminished.
# Using a wrapped runtime module with CUDA Graphs allows you to encapsulate sequences of operations into graphs
# that can be executed efficiently, even in the presence of graph breaks.
# If the TensorRT module has graph breaks, the CUDA Graph context manager returns a wrapped_module. This module
# captures the entire execution graph, enabling efficient replay during subsequent inferences by reducing kernel
# launch overheads and improving performance. Note that initializing the wrapper module involves a warm-up phase where the
# module is executed several times. This warm-up ensures that memory allocations and initializations are not
# recorded in CUDA Graphs, which helps maintain consistent execution paths and optimize performance.


class SampleModel(torch.nn.Module):
def forward(self, x):
return torch.relu((x + 2) * 0.5)


model = SampleModel().eval().cuda()
input = torch.randn((1, 3, 224, 224)).to("cuda")

# The 'torch_executed_ops' compiler option is used in this example to intentionally introduce graph breaks within the module.
# Note: The Dynamo backend is required for the CUDA Graph context manager to handle modules in an Ahead-Of-Time (AOT) manner.
opt_with_graph_break = torch_tensorrt.compile(
model,
ir="dynamo",
inputs=[input],
min_block_size=1,
pass_through_build_failures=True,
torch_executed_ops={"torch.ops.aten.mul.Tensor"},
)

# %%
# If the module has graph breaks, the entire sequence of submodules is recorded and replayed as one cuda graph
with torch_tensorrt.runtime.enable_cudagraphs(
opt_with_graph_break
) as cudagraphs_module:
cudagraphs_module(input)
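As a quick sanity check (an editorial addition, in the same spirit as the runtime tests later in this commit), the wrapped module's output can be compared against the eager model; the tolerances below are illustrative:

with torch_tensorrt.runtime.enable_cudagraphs(
    opt_with_graph_break
) as cudagraphs_module:
    out_trt = cudagraphs_module(input)
out_ref = model(input)
# TRT results typically match eager PyTorch within a small numerical tolerance
assert torch.allclose(out_ref, out_trt, rtol=5e-3, atol=5e-3)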
9 changes: 0 additions & 9 deletions py/torch_tensorrt/dynamo/_compiler.py
@@ -36,9 +36,6 @@
post_lowering,
pre_export_lowering,
)
from torch_tensorrt.dynamo.runtime._WrapperTorchTensorRTModule import (
WrapperTorchTensorRTModule,
)
from torch_tensorrt.dynamo.utils import (
get_flat_args_with_check,
get_output_metadata,
@@ -414,7 +411,6 @@ def compile(
strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
enable_wrapper_module: bool = _defaults.ENABLE_WRAPPER_MODULE,
**kwargs: Any,
) -> torch.fx.GraphModule:
"""Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -662,7 +658,6 @@ def compile(
"immutable_weights": immutable_weights,
"enable_cross_compile_for_windows": False,
"enable_weight_streaming": enable_weight_streaming,
"enable_wrapper_module": enable_wrapper_module,
}

settings = CompilationSettings(**compilation_options)
@@ -906,10 +901,6 @@ def contains_metadata(gm: torch.fx.GraphModule) -> bool:

dryrun_stats_display(dryrun_tracker, settings.dryrun)

if settings.enable_wrapper_module:
# Capture/replay a series of CUDA operations in subgraphs in a wrapped runtime module.
partitioned_module = WrapperTorchTensorRTModule(partitioned_module)

return partitioned_module


1 change: 0 additions & 1 deletion py/torch_tensorrt/dynamo/_defaults.py
@@ -46,7 +46,6 @@
IMMUTABLE_WEIGHTS = True
ENABLE_WEIGHT_STREAMING = False
ENABLE_CROSS_COMPILE_FOR_WINDOWS = False
ENABLE_WRAPPER_MODULE = False


def default_device() -> Device:
2 changes: 0 additions & 2 deletions py/torch_tensorrt/dynamo/_settings.py
@@ -16,7 +16,6 @@
ENABLE_CROSS_COMPILE_FOR_WINDOWS,
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
ENABLE_WEIGHT_STREAMING,
ENABLE_WRAPPER_MODULE,
ENABLED_PRECISIONS,
ENGINE_CAPABILITY,
HARDWARE_COMPATIBLE,
@@ -132,7 +131,6 @@ class CompilationSettings:
immutable_weights: bool = IMMUTABLE_WEIGHTS
enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
enable_wrapper_module: bool = ENABLE_WRAPPER_MODULE


_SETTINGS_TO_BE_ENGINE_INVARIANT = (
@@ -249,11 +249,10 @@ def forward(self, *inputs: torch.Tensor) -> torch.Tensor | Tuple[torch.Tensor, .
self._check_initialized()

cudagraphs_enabled = torch_tensorrt.runtime.get_cudagraphs_mode()

shape_changed = self.cudagraphs_validate_shapes(inputs)
# Cudagraphs record is required if cudagraphs_enabled is switched to True regardless of shape change
need_cudagraphs_record = cudagraphs_enabled and (
(not self.prev_cudagraphs_enabled)
or (not self.cudagraphs_validate_shapes(inputs))
(not self.prev_cudagraphs_enabled) or (not shape_changed)
)
self.prev_cudagraphs_enabled = cudagraphs_enabled

12 changes: 1 addition & 11 deletions py/torch_tensorrt/dynamo/runtime/_WrapperTorchTensorRTModule.py
@@ -35,17 +35,7 @@ def __init__(
self.prev_cudagraphs_enabled = False
self._caller_stream: Optional[torch.cuda.Stream] = None
self._engine_stream: Optional[torch.cuda.Stream] = None

num_torch_mod = 0
for name, _ in self.compiled_module.named_children():
if "_run_on_acc" not in name:
num_torch_mod += 1
if num_torch_mod > 0:
self.warm_up()
else:
logger.warning(
"Wrapper runtime module provides no benefit for a graph module that doesn't have graph breaks"
)
self.warm_up()

def warm_up(self) -> None:
"""
38 changes: 31 additions & 7 deletions py/torch_tensorrt/runtime/_cudagraphs.py
@@ -1,5 +1,5 @@
import logging
from typing import Any, Optional
from typing import Any

import torch
import torch_tensorrt
@@ -9,7 +9,9 @@


class CudaGraphsMode:
# No cuda graphs
STANDARD = 0
# Cuda graphs are applied to each TRT submodule
SUBGRAPH_CUDAGRAPHS = 1
# Internal mode that applies cuda graphs to the wrapped runtime module
WHOLE_GRAPH_CUDAGRAPHS = 2
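(Editorial note: a rough sketch of how these modes relate to the public runtime helpers, based on this diff rather than documented behavior.)

import torch_tensorrt

# STANDARD: cuda graphs disabled (the default)
torch_tensorrt.runtime.set_cudagraphs_mode(False)

# SUBGRAPH_CUDAGRAPHS: capture/replay each TRT submodule
torch_tensorrt.runtime.set_cudagraphs_mode(True)

# WHOLE_GRAPH_CUDAGRAPHS is not set directly; enable_cudagraphs(module) switches to it
# internally when the module has graph breaks and yields a wrapped module instead
torch_tensorrt.runtime.set_cudagraphs_mode(False)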
@@ -62,30 +64,52 @@ class _CudagraphsContextManager(object):
Used to enable cudagraphs as a context manager
"""

def __init__(self, compiled_module: Optional[torch.nn.Module]) -> None:
def __init__(self, compiled_module: torch.nn.Module) -> None:
global _PY_RT_CUDAGRAPHS
self.old_mode = _PY_RT_CUDAGRAPHS
self.compiled_module = compiled_module

def __enter__(self) -> "_CudagraphsContextManager":
def __enter__(self) -> torch.nn.Module:
global _PY_RT_CUDAGRAPHS
if self.compiled_module:

num_torch_module = 0
num_trt_module = 0
for name, _ in self.compiled_module.named_children():
if "_run_on_acc" in name:
num_trt_module += 1
elif "_run_on_gpu" in name:
num_torch_module += 1

if num_torch_module > 0:
# Set whole-graph cudagraphs mode and return the wrapped module
_PY_RT_CUDAGRAPHS = CudaGraphsMode.WHOLE_GRAPH_CUDAGRAPHS
# Set new mode for C++
if torch_tensorrt.ENABLED_FEATURES.torch_tensorrt_runtime:
torch.ops.tensorrt.set_cudagraphs_mode(_PY_RT_CUDAGRAPHS)

logger.debug(
f"{num_torch_module} torch modules are in subgraphs. Using wrapper module for cuda graphs"
)
return WrapperTorchTensorRTModule(self.compiled_module)
else:
# Enable cudagraphs
if num_trt_module > 0:
logger.debug(
"There is no graph breaks. Using original module for cuda graphs"
)
else:
logger.warning(
"Please consider dynamo if there is graph breaks. Using original module for cuda graphs"
)
# Enable cudagraphs for TRT submodule
set_cudagraphs_mode(True)
return self
return self.compiled_module

def __exit__(self, *args: Any) -> None:
# Set cudagraphs back to old mode
set_cudagraphs_mode(self.old_mode)


def enable_cudagraphs(
compiled_module: Optional[torch.nn.Module] = None,
compiled_module: torch.nn.Module,
) -> _CudagraphsContextManager:
return _CudagraphsContextManager(compiled_module)
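With this change the compiled module is a required positional argument and the context manager yields the module to execute. A minimal hedged usage sketch (trt_module and x are placeholder names; the mode checks assume the C++ runtime is available, mirroring the test below):

with torch_tensorrt.runtime.enable_cudagraphs(trt_module) as cudagraphs_module:
    assert torch.ops.tensorrt.get_cudagraphs_mode()  # enabled inside the context
    output = cudagraphs_module(x)
assert not torch.ops.tensorrt.get_cudagraphs_mode()  # prior mode restored on exit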
32 changes: 25 additions & 7 deletions tests/py/dynamo/runtime/test_002_cudagraphs_cpp.py
@@ -27,7 +27,19 @@ def test_cudagraphs_off(self):
self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode())

def test_cudagraphs_context(self):
with torch_tensorrt.runtime.enable_cudagraphs():
class SampleModel(torch.nn.Module):
def forward(self, input):
return torch.ops.aten.abs.default(input)

fx_graph = torch.fx.symbolic_trace(SampleModel())
inputs = [torch.randn((2, 3), dtype=torch.float).cuda()]
optimized_model = torch_tensorrt.compile(
fx_graph,
"torch_compile",
inputs,
min_block_size=1,
)
with torch_tensorrt.runtime.enable_cudagraphs(optimized_model) as _:
self.assertTrue(torch.ops.tensorrt.get_cudagraphs_mode())
self.assertFalse(torch.ops.tensorrt.get_cudagraphs_mode())

@@ -51,9 +63,11 @@ def forward(self, x):

result_samples = []
torch_results_samples = []
with torch_tensorrt.runtime.enable_cudagraphs():
with torch_tensorrt.runtime.enable_cudagraphs(
optimized_model
) as cudagraphs_module:
for i in inputs:
result_samples.append(optimized_model(i).detach().cpu())
result_samples.append(cudagraphs_module(i).detach().cpu())
torch_results_samples.append(fx_graph(i).detach().cpu())

for i, (optimized_model_results, torch_model_results) in enumerate(
@@ -92,9 +106,11 @@ def forward(self, x):

result_samples = []
torch_results_samples = []
with torch_tensorrt.runtime.enable_cudagraphs():
with torch_tensorrt.runtime.enable_cudagraphs(
optimized_model
) as cudagraphs_module:
for i in inputs:
result_samples.append(optimized_model(i).detach().cpu())
result_samples.append(cudagraphs_module(i).detach().cpu())
torch_results_samples.append(fx_graph(i).detach().cpu())

for i, (optimized_model_results, torch_model_results) in enumerate(
@@ -141,9 +157,11 @@ def forward(self, x):

result_samples = []
torch_results_samples = []
with torch_tensorrt.runtime.enable_cudagraphs():
with torch_tensorrt.runtime.enable_cudagraphs(
optimized_model
) as cudagraphs_module:
for i in inputs:
result_samples.append(optimized_model(i).detach().cpu())
result_samples.append(cudagraphs_module(i).detach().cpu())
torch_results_samples.append(fx_graph(i).detach().cpu())

for i, (optimized_model_results, torch_model_results) in enumerate(