diff --git a/docsrc/index.rst b/docsrc/index.rst
index 5d88c8ecae..c762080649 100644
--- a/docsrc/index.rst
+++ b/docsrc/index.rst
@@ -134,6 +134,7 @@ Model Zoo
 * :ref:`torch_compile_resnet`
 * :ref:`torch_compile_transformer`
 * :ref:`torch_compile_stable_diffusion`
+* :ref:`torch_compile_gpt2`
 * :ref:`torch_export_gpt2`
 * :ref:`torch_export_llama2`
 * :ref:`notebooks`
@@ -148,6 +149,7 @@ Model Zoo
    tutorials/_rendered_examples/dynamo/torch_compile_stable_diffusion
    tutorials/_rendered_examples/distributed_inference/data_parallel_gpt2
    tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
+   tutorials/_rendered_examples/dynamo/torch_compile_gpt2
    tutorials/_rendered_examples/dynamo/torch_export_gpt2
    tutorials/_rendered_examples/dynamo/torch_export_llama2
    tutorials/notebooks
diff --git a/examples/dynamo/README.rst b/examples/dynamo/README.rst
index 60f1969be2..5d3b9d4261 100644
--- a/examples/dynamo/README.rst
+++ b/examples/dynamo/README.rst
@@ -17,5 +17,6 @@ Model Zoo
 * :ref:`torch_compile_resnet`: Compiling a ResNet model using the Torch Compile Frontend for ``torch_tensorrt.compile``
 * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
 * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
+* :ref:`_torch_compile_gpt2`: Compiling a GPT2 model using ``torch.compile``
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
 * :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
\ No newline at end of file
diff --git a/examples/dynamo/torch_compile_gpt2.py b/examples/dynamo/torch_compile_gpt2.py
new file mode 100644
index 0000000000..5d41c3ed84
--- /dev/null
+++ b/examples/dynamo/torch_compile_gpt2.py
@@ -0,0 +1,117 @@
+"""
+.. _torch_compile_gpt2:
+
+Compiling GPT2 using the Torch-TensorRT ``torch.compile`` frontend
+===================================================================
+
+This example illustrates the state-of-the-art model `GPT2 `_ optimized using
+the ``torch.compile`` frontend of Torch-TensorRT. Install the following dependencies before compilation:
+
+.. code-block:: bash
+
+    pip install -r requirements.txt
+
+GPT2 is a causal (unidirectional) transformer pretrained using language modeling on a very large corpus of text data.
+In this example, we use the GPT2 model available at `HuggingFace `_ and apply ``torch.compile`` to it to
+get the graph module representation of the model. Torch-TensorRT converts this graph into an optimized TensorRT engine.
+"""
+
+# %%
+# Import necessary libraries
+# --------------------------
+import torch
+import torch_tensorrt
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# %%
+# Define the necessary parameters
+# -------------------------------
+# Torch-TensorRT requires a GPU for successful compilation of the model.
+# ``MAX_LENGTH`` is the maximum length the generated tokens can have. This corresponds to the length of the input prompt
+# plus the number of new tokens generated.
+MAX_LENGTH = 32
+DEVICE = torch.device("cuda:0")
+
+# %%
+# Model definition
+# ----------------
+# We use the ``AutoModelForCausalLM`` class to load the pretrained GPT2 model from Hugging Face.
+# ``kv_cache`` is not supported in Torch-TRT currently, so we set ``use_cache=False``.
+with torch.no_grad():
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    model = (
+        AutoModelForCausalLM.from_pretrained(
+            "gpt2",
+            pad_token_id=tokenizer.eos_token_id,
+            use_cache=False,
+            attn_implementation="eager",
+        )
+        .eval()
+        .cuda()
+    )
+
+# %%
+# PyTorch inference
+# -----------------
+# Tokenize a sample input prompt and get the PyTorch model outputs.
+prompt = "I enjoy walking with my cute dog"
+model_inputs = tokenizer(prompt, return_tensors="pt")
+input_ids = model_inputs["input_ids"].cuda()
+
+# %%
+# The ``generate()`` API of the ``AutoModelForCausalLM`` class is used for auto-regressive generation with greedy decoding.
+pyt_gen_tokens = model.generate(
+    input_ids,
+    max_length=MAX_LENGTH,
+    use_cache=False,
+    pad_token_id=tokenizer.eos_token_id,
+)
+
+# %%
+# Torch-TensorRT compilation and inference
+# ----------------------------------------
+# The input sequence length is dynamic, so we mark it as dynamic using the ``torch._dynamo.mark_dynamic`` API.
+# We provide a (min, max) range of this value so that TensorRT knows in advance what values to optimize for.
+# Usually, this would be the context length for the model. We start with ``min=2`` due to the `0/1 specialization `_.
+torch._dynamo.mark_dynamic(input_ids, 1, min=2, max=1023)
+model.forward = torch.compile(
+    model.forward,
+    backend="tensorrt",
+    dynamic=None,
+    options={
+        "enabled_precisions": {torch.float32},
+        "disable_tf32": True,
+        "min_block_size": 1,
+    },
+)
+
+# %%
+# Auto-regressive generation loop for greedy decoding using the TensorRT model.
+# The first token generation compiles the model using TensorRT, and the second token
+# triggers a recompilation (a known issue that will be resolved in the future).
+trt_gen_tokens = model.generate(
+    inputs=input_ids,
+    max_length=MAX_LENGTH,
+    use_cache=False,
+    pad_token_id=tokenizer.eos_token_id,
+)
+
+# %%
+# Decode the output sentences of PyTorch and TensorRT
+# ---------------------------------------------------
+print(
+    "Pytorch model generated text: ",
+    tokenizer.decode(pyt_gen_tokens[0], skip_special_tokens=True),
+)
+print("=============================")
+print(
+    "TensorRT model generated text: ",
+    tokenizer.decode(trt_gen_tokens[0], skip_special_tokens=True),
+)
+
+# %%
+# The output sentences should look like the following:
+
+"""
+Pytorch model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
+=============================
+TensorRT model generated text: I enjoy walking with my cute dog, but I'm not sure if I'll ever be able to walk with my dog. I'm not sure if I'll
+"""
diff --git a/py/torch_tensorrt/dynamo/backend/backends.py b/py/torch_tensorrt/dynamo/backend/backends.py
index e15ed0495f..f60cdf3fca 100644
--- a/py/torch_tensorrt/dynamo/backend/backends.py
+++ b/py/torch_tensorrt/dynamo/backend/backends.py
@@ -80,7 +80,8 @@ def _pretraced_backend(
             repair_input_aliasing(gm, settings)
 
             # Remove sym_int placeholders and inputs
-            remove_sym_nodes(gm, settings)
+            remove_sym_nodes(gm, sample_inputs, settings)
+
             torch_inputs = [
                 input for input in sample_inputs if isinstance(input, torch.Tensor)
             ]
@@ -91,7 +92,7 @@ def _pretraced_backend(
             # Invoke AOTAutograd to translate operators to aten
             gm = aot_export_joint_simple(
                 gm,
-                torch_inputs,
+                sample_inputs,
                 trace_joint=False,
                 decompositions=get_decompositions(
                     settings.enable_experimental_decompositions
diff --git a/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py b/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
index 2605dccba6..9f69572059 100644
--- a/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
+++ b/py/torch_tensorrt/dynamo/lowering/passes/remove_sym_nodes.py
@@ -1,4 +1,5 @@
 import logging
+from typing import Any, Sequence
 
 import torch
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -7,15 +8,17 @@
 
 
 def remove_sym_nodes(
-    gm: torch.fx.GraphModule, settings: CompilationSettings
+    gm: torch.fx.GraphModule,
+    sample_inputs: Sequence[Any],
+    settings: CompilationSettings,
 ) -> torch.fx.GraphModule:
     """Remove sym_int placeholders which get inserted due to torch.compile's
     dynamic=True behavior
     """
     # Extract SymInt placeholder Tensors
-    placeholder_sym_ints = [
-        node
-        for node in gm.graph.nodes
+    placeholder_idx_sym_ints = [
+        (idx, node)
+        for idx, node in enumerate(gm.graph.nodes)
         if (
             node.op == "placeholder"
             and isinstance(node.type, type)
@@ -24,8 +27,9 @@ def remove_sym_nodes(
         )
     ]
 
-    for node in placeholder_sym_ints:
+    for idx, node in placeholder_idx_sym_ints:
         gm.graph.erase_node(node)
+        sample_inputs.pop(idx)
 
     gm.graph.lint()
     gm.recompile()
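The lowering change above passes ``sample_inputs`` into ``remove_sym_nodes`` because, under ``torch.compile`` with dynamic shapes, the traced graph gains extra SymInt placeholders; once those placeholders are erased, the matching positional entries must also be dropped from the input list so the remaining inputs still line up with the graph placeholders handed to ``aot_export_joint_simple``. Below is a minimal, self-contained sketch of that bookkeeping idea on toy data; the names and types are illustrative only and are not the library's code (collecting the indices first and popping in reverse order is one simple way to keep positions stable):

from typing import Any, List, Tuple

def drop_sym_placeholders(
    placeholders: List[Tuple[str, type]], sample_inputs: List[Any]
) -> None:
    # Collect positions of int-typed (SymInt-like) placeholders first, then
    # remove them from both lists in reverse so earlier pops do not shift
    # the indices of the later ones.
    sym_positions = [i for i, (_, typ) in enumerate(placeholders) if typ is int]
    for i in reversed(sym_positions):
        placeholders.pop(i)
        sample_inputs.pop(i)

# Toy example: two symbolic sizes followed by one tensor-like input.
placeholders = [("s0", int), ("s1", int), ("input_ids", object)]
sample_inputs = [2, 1023, "dummy_tensor"]
drop_sym_placeholders(placeholders, sample_inputs)
print(placeholders)   # [('input_ids', <class 'object'>)]
print(sample_inputs)  # ['dummy_tensor']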