chore: example fixes #3176

Merged · 19 commits · Dec 16, 2024
2 changes: 2 additions & 0 deletions docsrc/index.rst
@@ -134,6 +134,7 @@ Model Zoo
* :ref:`torch_compile_resnet`
* :ref:`torch_compile_transformer`
* :ref:`torch_compile_stable_diffusion`
+ * :ref:`torch_compile_gpt2`
* :ref:`torch_export_gpt2`
* :ref:`torch_export_llama2`
* :ref:`notebooks`
@@ -148,6 +149,7 @@ Model Zoo
tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
+ tutorials/_rendered_examples/dynamo/torch_compile_gpt2
tutorials/_rendered_examples/dynamo/torch_export_gpt2
tutorials/_rendered_examples/dynamo/torch_export_llama2
tutorials/notebooks
1 change: 1 addition & 0 deletions examples/dynamo/README.rst
@@ -17,5 +17,6 @@ Model Zoo
* :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
* :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
* :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
+ * :ref:`torch_compile_gpt2`: Compiling a GPT2 model using ``torch.compile``
* :ref:`torch_export_gpt2`: Compiling a GPT2 model using the AOT workflow (``ir=dynamo``)
* :ref:`torch_export_llama2`: Compiling a Llama2 model using the AOT workflow (``ir=dynamo``)
117 changes: 117 additions & 0 deletions examples/dynamo/torch_compile_gpt2.py
@@ -0,0 +1,117 @@
"""
.. _torch_compile_gpt2:

Compiling GPT2 using the Torch-TensorRT ``torch.compile`` frontend
======================================================================

This example illustrates how the `GPT2 <https://cdn.openai.com/better-language-models/language_models_are_unsupervised_multitask_learners.pdf>`_ model can be optimized using the
``torch.compile`` frontend of Torch-TensorRT. Install the following dependencies before running this example:

.. code-block:: bash

pip install -r requirements.txt

GPT2 is a causal (unidirectional) transformer pretrained with a language-modeling objective on a very large text corpus. In this example, we use the GPT2 model available at `HuggingFace <https://huggingface.co/docs/transformers/en/model_doc/gpt2>`_ and apply ``torch.compile`` to it to
obtain a graph module representation. Torch-TensorRT then converts this graph into an optimized TensorRT engine.
"""

# %%
# Import necessary libraries
# -----------------------------
import torch
import torch_tensorrt
from transformers import AutoModelForCausalLM, AutoTokenizer

# %%
# Define the necessary parameters
# ---------------------------------
# Torch-TensorRT requires a GPU to compile the model successfully.
# ``MAX_LENGTH`` is the maximum length the generated tokens can have: the number
# of input prompt tokens plus the number of newly generated tokens.
MAX_LENGTH = 32
DEVICE = torch.device("cuda:0")

# %%
# Model definition
# -----------------------------
# We use the ``AutoModelForCausalLM`` class to load the pretrained GPT2 model from
# Hugging Face. Since ``kv_cache`` is not currently supported in Torch-TRT, we set ``use_cache=False``.
with torch.no_grad():
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = (
        AutoModelForCausalLM.from_pretrained(
            "gpt2",
            pad_token_id=tokenizer.eos_token_id,
            use_cache=False,
            attn_implementation="eager",
        )
        .eval()
        .cuda()
    )

# %%
# PyTorch inference
# -----------------------------
# Tokenize a sample input prompt and get the PyTorch model outputs.
prompt = "I enjoy walking with my cute dog"
model_inputs = tokenizer(prompt, return_tensors="pt")
input_ids = model_inputs["input_ids"].cuda()

# %%
# The ``generate()`` API of the ``AutoModelForCausalLM`` class is used for auto-regressive generation with greedy decoding.
pyt_gen_tokens = model.generate(
    input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)

# %%
# Torch-TensorRT compilation and inference
# ------------------------------------------
# The input sequence length is dynamic, so we mark it using the ``torch._dynamo.mark_dynamic`` API.
# We provide a (min, max) range for this value so that TensorRT knows in advance which values to optimize for.
# Usually, this would be the context length of the model. We start with ``min=2`` due to the `0/1 specialization <https://docs.google.com/document/d/16VPOa3d-Liikf48teAOmxLc92rgvJdfosIy-yoT38Io/edit?fbclid=IwAR3HNwmmexcitV0pbZm_x1a4ykdXZ9th_eJWK-3hBtVgKnrkmemz6Pm5jRQ&tab=t.0#heading=h.ez923tomjvyk>`_.
torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
model.forward = torch.compile(
    model.forward,
    backend="tensorrt",
    dynamic=None,
    options={
        "enabled_precisions": {torch.float32},  # compile in FP32
        "disable_tf32": True,  # keep full FP32 accuracy to match the PyTorch baseline
        "min_block_size": 1,
    },
)

# %%
# Auto-regressive generation loop for greedy decoding using the TensorRT model.
# Generating the first token compiles the model using TensorRT, and generating
# the second token triggers a recompilation (a known issue that will be resolved
# in a future release).
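# %%
# The recompilation can be observed while experimenting by enabling Dynamo's
# recompile logging before calling ``generate()`` (an optional debugging aid,
# not required by this example):
#
# .. code-block:: python
#
#     torch._logging.set_logs(recompiles=True)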
trt_gen_tokens = model.generate(
    inputs=input_ids,
    max_length=MAX_LENGTH,
    use_cache=False,
    pad_token_id=tokenizer.eos_token_id,
)

# %%
# Decode the output sentences of PyTorch and TensorRT
# ------------------------------------------------------
print(
    "PyTorch model generated text: ",
    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
)
print("=============================")
print(
    "TensorRT model generated text: ",
    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
)

# %%
# The output sentences should look like the following:

"""
PyTorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
=============================
TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
"""
5 changes: 3 additions & 2 deletions py/torch_tensorrt/dynamo/backend/backends.py
@@ -80,7 +80,8 @@ def _pretraced_backend(
repair_input_aliasing(gm, settings)

# Remove sym_int placeholders and inputs
- remove_sym_nodes(gm, settings)
+ remove_sym_nodes(gm, sample_inputs, settings)
+
torch_inputs = [
    input for input in sample_inputs if isinstance(input, torch.Tensor)
]
@@ -91,7 +92,7 @@
# Invoke AOTAutograd to translate operators to aten
gm = aot_export_joint_simple(
    gm,
-   torch_inputs,
+   sample_inputs,
    trace_joint=False,
    decompositions=get_decompositions(
        settings.enable_experimental_decompositions
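For context, the SymInt placeholders this pass removes can be observed with a minimal custom Dynamo backend (a standalone sketch assuming a recent PyTorch; the function and variable names are illustrative):

    import torch

    def inspect_placeholders(gm: torch.fx.GraphModule, example_inputs):
        # With dynamic shapes enabled, SymInt placeholders appear in the FX
        # graph alongside the tensor inputs.
        for node in gm.graph.nodes:
            if node.op == "placeholder":
                print(node.name, node.type)
        return gm.forward  # run the captured graph unmodified

    fn = torch.compile(lambda x: x * 2, backend=inspect_placeholders, dynamic=True)
    fn(torch.randn(4))  # prints something like: s0 <class 'torch.SymInt'>, l_x_ ...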
14 changes: 9 additions & 5 deletions py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
@@ -1,4 +1,5 @@
import logging
+ from typing import Any, Sequence

import torch
from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -7,15 +8,17 @@


def remove_sym_nodes(
-   gm: torch.fx.GraphModule, settings: CompilationSettings
+   gm: torch.fx.GraphModule,
+   sample_inputs: Sequence[Any],
+   settings: CompilationSettings,
) -> torch.fx.GraphModule:
    """Remove sym_int placeholders which get inserted due to torch.compile's
    dynamic=True behavior
    """
    # Extract SymInt placeholder Tensors
-   placeholder_sym_ints = [
-       node
-       for node in gm.graph.nodes
+   placeholder_idx_sym_ints = [
+       (idx, node)
+       for idx, node in enumerate(gm.graph.nodes)
        if (
            node.op == "placeholder"
            and isinstance(node.type, type)
@@ -24,8 +27,9 @@ def remove_sym_nodes(
        )
    ]

-   for node in placeholder_sym_ints:
+   for idx, node in placeholder_idx_sym_ints:
        gm.graph.erase_node(node)
+       sample_inputs.pop(idx)

    gm.graph.lint()
    gm.recompile()
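The list surgery performed by the updated pass can be illustrated on a plain Python list (a standalone sketch; ``sample_inputs`` here is a stand-in for the real compiler inputs, and popping in reverse index order keeps the remaining indices valid):

    import torch

    sample_inputs = [8, torch.randn(8, 16), torch.randn(16, 4)]  # the int stands in for a SymInt

    # Collect the indices of non-tensor (SymInt-like) inputs, mirroring the pass above.
    sym_idxs = [idx for idx, x in enumerate(sample_inputs) if not isinstance(x, torch.Tensor)]

    # Pop in reverse so the indices collected earlier remain valid after each removal.
    for idx in reversed(sym_idxs):
        sample_inputs.pop(idx)

    assert all(isinstance(x, torch.Tensor) for x in sample_inputs)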