
✨[Feature] Add pickle_protocol argument for torch_tensorrt.save #3294

Open
fortminors opened this issue Nov 14, 2024 · 4 comments

Labels
feature request New feature or request

Comments

@fortminors

fortminors commented Nov 14, 2024

Is your feature request related to a problem? Please describe.
I am trying to save an optimized CLIP model to disk, but it is too large (around 8 GB, if I remember correctly) for pickle to handle with pickle_protocol < 4, which is what is used by default. It fails with OverflowError: serializing a string larger than 4 GiB requires pickle protocol 4 or higher
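The error follows from how pickle encodes string lengths. As a small standard-library illustration (plain pickle, not PyTorch code): the binary protocols below 4 write a str's length in a 4-byte unsigned field (BINUNICODE), and only protocol 4 adds BINUNICODE8 with an 8-byte length, so a payload of roughly 8 GB cannot be length-prefixed without protocol 4. The helper names below are hypothetical.

```python
import pickle
import pickletools
import struct

# Binary pickle protocols below 4 store a str's byte length in a 4-byte
# unsigned field (the BINUNICODE opcode), so the largest length is:
MAX_4BYTE_LENGTH = 2**32 - 1  # 4,294,967,295 bytes, just under 4 GiB


def fits_4byte_length(nbytes: int) -> bool:
    """Return True if nbytes can be written as a 4-byte unsigned length."""
    try:
        struct.pack("<I", nbytes)
    except struct.error:
        return False
    return True


def str_opcode_protocols() -> dict:
    """Map pickle's binary str opcodes to the protocol that introduced them."""
    wanted = {"SHORT_BINUNICODE", "BINUNICODE", "BINUNICODE8"}
    return {op.name: op.proto for op in pickletools.opcodes if op.name in wanted}


def opcodes_used(obj, protocol: int) -> list:
    """List the opcode names pickle emits for obj at the given protocol."""
    data = pickle.dumps(obj, protocol=protocol)
    return [op.name for op, _arg, _pos in pickletools.genops(data)]


print(fits_4byte_length(MAX_4BYTE_LENGTH))  # True
print(fits_4byte_length(8 * 10**9))         # False: an ~8 GB payload
print(str_opcode_protocols())
print(opcodes_used("abc", 2))
```

Only BINUNICODE8 (protocol 4) has room for an 8-byte length, which is why the pickler refuses to continue at protocol 2.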

Describe the solution you'd like
It would be great if you could add a pickle_protocol argument to the torch_tensorrt.save method, similar to the one torch.save has, so that large optimized models can be saved.

Describe alternatives you've considered
I've tried torch.save, but it does not save all of the information required for the optimized model, so loading it with model = torch.export.load(model_path).module() yields the following error:

    model = torch.export.load(model_path).module()
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 421, in load
    version = zipf.read("version").decode().split(".")
              ^^^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1527, in read
    with self.open(name, "r", pwd) as fp:
         ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1564, in open
    zinfo = self.getinfo(name)
            ^^^^^^^^^^^^^^^^^^
  File "/python/cpython-3.11.10-linux-x86_64-gnu/lib/python3.11/zipfile.py", line 1493, in getinfo
    raise KeyError(
KeyError: "There is no item named 'version' in the archive"
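The KeyError comes from torch.export.load() reading an archive entry named exactly "version" (the zipf.read("version") call in the traceback). A minimal standard-library sketch of that failing step, assuming the file written by torch.save simply has no top-level entry with that name; the "archive/version" entry below is a hypothetical layout chosen for illustration, and make_archive / read_top_level_version are hypothetical helpers.

```python
import io
import zipfile


def make_archive(entries: dict) -> bytes:
    """Build an in-memory zip archive from a {name: bytes} mapping."""
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        for name, data in entries.items():
            zf.writestr(name, data)
    return buf.getvalue()


def read_top_level_version(archive: bytes) -> list:
    """Mimic the failing step: read an entry named exactly 'version'."""
    with zipfile.ZipFile(io.BytesIO(archive)) as zipf:
        return zipf.read("version").decode().split(".")


print(read_top_level_version(make_archive({"version": b"8.1"})))

try:
    # Hypothetical archive whose version record lives under a prefix.
    read_top_level_version(make_archive({"archive/version": b"3"}))
except KeyError as err:
    print("KeyError:", err)
```

zipfile matches entry names exactly, so an archive that only contains a prefixed entry raises the same KeyError seen in the traceback.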

Additional context
MRE (requires pip install torch torch_tensorrt open_clip_torch):

import torch
import torch_tensorrt
import open_clip


torch.set_float32_matmul_precision('high')


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, _, preprocess = open_clip.create_model_and_transforms("convnext_xxlarge", pretrained="laion2b_s34b_b82k_augreg_soup")

    model = model.visual.to(device).eval()
    image_size = model.image_size

    image = torch.randn((1, 3, *image_size)).to(device) # (1, 3, 256, 256)

    model = torch_tensorrt.compile(model, ir="dynamo", inputs=[image])
    model(image)
    torch_tensorrt.save(model, "convnext_xxlarge_compiled.ep", inputs=[image])

Running the above script gives the following result:

WARNING:py.warnings:/.venv/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py:387: UserWarning: Attempted to insert a get_attr Node with no underlying reference in the owning GraphModule! Call GraphModule.add_submodule to add the necessary submodule, GraphModule.add_parameter to add the necessary Parameter, or nn.Module.register_buffer to add the necessary buffer
  engine_node = gm.graph.get_attr(engine_name)

WARNING:py.warnings:/.venv/lib/python3.11/site-packages/torch/fx/graph.py:1586: UserWarning: Node _run_on_acc_0_engine target _run_on_acc_0_engine _run_on_acc_0_engine of  does not reference an nn.Module, nn.Parameter, or buffer, which is what 'get_attr' Nodes typically target
  warnings.warn(f'Node {node} target {node.target} {atom} of {seen_qualname} does '

Traceback (most recent call last):
  File "/src/repro.py", line 20, in <module>
    torch_tensorrt.save(model, "convnext_xxlarge_compiled.ep", inputs=[image])
  File "/.venv/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 529, in save
    torch.export.save(exp_program, file_path)
  File "/.venv/lib/python3.11/site-packages/torch/export/__init__.py", line 341, in save
    artifact: SerializedArtifact = serialize(ep, opset_version)
                                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 2374, in serialize
    serialized_program = ExportedProgramSerializer(opset_version).serialize(
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 1410, in serialize
    serialize_torch_artifact(constants),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/.venv/lib/python3.11/site-packages/torch/_export/serde/serialize.py", line 312, in serialize_torch_artifact
    torch.save(artifact, buffer)
  File "/.venv/lib/python3.11/site-packages/torch/serialization.py", line 850, in save
    _save(
  File "/.venv/lib/python3.11/site-packages/torch/serialization.py", line 1088, in _save
    pickler.dump(obj)
OverflowError: serializing a string larger than 4 GiB requires pickle protocol 4 or higher

By default, torch.save uses pickle protocol 2 (DEFAULT_PROTOCOL in torch/serialization.py).
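Pickle streams written with protocol 2 or higher declare their protocol in a leading PROTO opcode, so one way to check which protocol a pickled artifact used is to read that opcode. A small standard-library sketch; pickle_protocol_of is a hypothetical helper, and applying it to a torch.save file would first require extracting the pickled record from the zip archive, which is not shown here.

```python
import pickle
import pickletools
from typing import Optional


def pickle_protocol_of(data: bytes) -> Optional[int]:
    """Return the protocol declared by a pickle's leading PROTO opcode.

    Protocols 2 and above start with PROTO; protocols 0 and 1 do not,
    so None is returned for them.
    """
    opcode, arg, _pos = next(pickletools.genops(data))
    return arg if opcode.name == "PROTO" else None


payload = {"weights": b"\x00" * 8}
print(pickle_protocol_of(pickle.dumps(payload, protocol=2)))  # 2
print(pickle_protocol_of(pickle.dumps(payload, protocol=4)))  # 4

# Loading needs no protocol argument: the stream declares its own.
assert pickle.loads(pickle.dumps(payload, protocol=4)) == payload
```

Because the protocol is self-describing, only the saving side needs a new argument; the loader reads whatever protocol was written.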

@lanluo-nvidia
Collaborator

torch.save() does have pickle_protocol support;
however, we use torch.export.save(), which does not expose it.

torch.export.save()
|--> serialize_torch_artifact()
|----> torch.save() without pickle_protocol options being exposed.
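One way to picture the requested change, as a minimal standard-library sketch: each layer in the chain above accepts a pickle_protocol argument and forwards it until it reaches the pickler. The names trt_save, export_save and serialize_artifact are hypothetical stand-ins for torch_tensorrt.save(), torch.export.save() and serialize_torch_artifact(), and plain pickle stands in for torch.save(); this is not the actual PyTorch or Torch-TensorRT code.

```python
import io
import os
import pickle
import tempfile

# Assumption: mirrors the protocol-2 default mentioned in the issue.
DEFAULT_PROTOCOL = 2


def serialize_artifact(artifact, pickle_protocol=DEFAULT_PROTOCOL) -> bytes:
    """Hypothetical stand-in for serialize_torch_artifact(): pickle to a buffer."""
    buffer = io.BytesIO()
    pickle.dump(artifact, buffer, protocol=pickle_protocol)
    return buffer.getvalue()


def export_save(artifact, file_path, pickle_protocol=DEFAULT_PROTOCOL):
    """Hypothetical stand-in for torch.export.save(): forward the protocol."""
    with open(file_path, "wb") as f:
        f.write(serialize_artifact(artifact, pickle_protocol=pickle_protocol))


def trt_save(artifact, file_path, pickle_protocol=DEFAULT_PROTOCOL):
    """Hypothetical stand-in for torch_tensorrt.save() with the new argument."""
    export_save(artifact, file_path, pickle_protocol=pickle_protocol)


with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "model.ep")
    trt_save({"engine": b"\x00" * 16}, path, pickle_protocol=4)
    with open(path, "rb") as f:
        print(f.read()[:2])  # b'\x80\x04': PROTO opcode followed by protocol 4
```

Keeping the default at 2 preserves existing behavior; callers whose engines exceed 4 GiB would opt in with pickle_protocol=4.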

@peri044 will bring this up with Meta to get the pickle_protocol option exposed.

@fortminors
Author

Would be awesome! Looking forward to it.

@peri044
Collaborator

peri044 commented Dec 6, 2024

Here's the issue I brought up with Meta: pytorch/pytorch#142004

@peri044
Collaborator

peri044 commented Dec 11, 2024

This PR should fix this issue: pytorch/pytorch#142253

5 participants