🐛 [Bug] Encountered bug when using Torch-TensorRT (convert part of the model) #3127
This is the example I would expect to work once user objects are supported:

```python
import torch
import torch_tensorrt


class MySubmodule(torch.nn.Module):
    def __init__(self):
        super(MySubmodule, self).__init__()
        self.layer = torch.nn.Linear(10, 10)

    def forward(self, x):
        return self.layer(x)


class MyMod(torch.nn.Module):
    def __init__(self):
        super(MyMod, self).__init__()
        self.submod = MySubmodule()

    def forward(self, x):
        return self.submod(x)


def patch_submod(mod):
    # Swap the submodule for its TensorRT-compiled counterpart
    mod.submod = torch_tensorrt.compile(
        mod.submod,
        ir="dynamo",
        inputs=[torch_tensorrt.Input(shape=(1, 10))],
        min_block_size=1,
    )


if __name__ == "__main__":
    model = MyMod()
    model.to("cuda")
    patch_submod(model)
    exported_program = torch_tensorrt.dynamo.trace(
        model, arg_inputs=[torch.zeros(1, 10).to("cuda")]
    )
    mod = exported_program.module()
    mod(torch.zeros(1, 10).cuda())
    print(exported_program.graph)
    torch.save(exported_program, "test.pt")
```

Currently fails with:
Yes, in our actual scenario our code framework is quite complex and involves some conditionals, so if we used dynamo mode directly in the TRT conversion stage we would hit these kinds of conditional statements. If we use dynamo, torch_tensorrt.compile may raise those errors. Therefore, we use torch.jit.trace + the TorchScript frontend (ir="ts") instead, as sketched below.
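For reference, a minimal sketch of that workaround, reusing the toy MyMod from the example above as a stand-in for the (much more complex) framework model with Python conditionals:

```python
import torch
import torch_tensorrt

model = MyMod().eval().cuda()  # toy stand-in for our framework model

# Trace first so the Python conditionals are baked into a static graph,
# then compile that graph with the TorchScript frontend.
traced = torch.jit.trace(model, torch.zeros(1, 10).cuda())
trt_ts = torch_tensorrt.compile(
    traced,
    ir="ts",
    inputs=[torch_tensorrt.Input(shape=(1, 10))],
)
torch.jit.save(trt_ts, "model_trt.ts")
```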
I also encountered another bug: an error is raised when I use allow_shape_tensors=True. Does aten::arange support nvinfer1::ITensor?
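A minimal sketch of the kind of graph being asked about, assuming the dynamic dimension feeds aten::arange (the model and shape ranges are illustrative):

```python
import torch
import torch_tensorrt


class ArangeMod(torch.nn.Module):
    # torch.arange over a dynamic dimension: with allow_shape_tensors=True,
    # the length must reach aten::arange as a shape tensor (nvinfer1::ITensor).
    def forward(self, x):
        positions = torch.arange(x.shape[1], device=x.device)
        return x + positions.unsqueeze(0)


traced = torch.jit.trace(ArangeMod().cuda(), torch.zeros(1, 5).cuda())
trt_ts = torch_tensorrt.compile(
    traced,
    ir="ts",
    inputs=[
        torch_tensorrt.Input(min_shape=(1, 5), opt_shape=(1, 16), max_shape=(1, 64))
    ],
    allow_shape_tensors=True,
)
```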
@narendasan With the same model: when I use a static shape with input sequence_length=5, the TRT model's accuracy matches the original model's; but when I use sequence_length=50, the TRT model's accuracy does not match the original model (-0.720 vs -0.516). I don't know whether it's caused by multi-stream, dynamic shapes, or some other reason.
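For context, a sketch of how such an accuracy gap can be measured, again reusing the toy MyMod (tolerances and shapes are illustrative, not the real model's):

```python
import torch
import torch_tensorrt

model = MyMod().eval().cuda()
trt_model = torch_tensorrt.compile(
    torch.jit.trace(model, torch.zeros(1, 10).cuda()),
    ir="ts",
    inputs=[torch_tensorrt.Input(shape=(1, 10))],
)

# Compare TRT output against the original model on the same batch.
x = torch.randn(1, 10).cuda()
with torch.no_grad():
    ref = model(x)
    out = trt_model(x)
print("max abs diff:", (ref - out).abs().max().item())
print("allclose:", torch.allclose(ref, out, atol=1e-3, rtol=1e-3))
```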
Would torch.compile work in your use case? It is able to support conditionals, and you can use engine caching to shortcut setup. It's unlikely we will add any improvements to the TorchScript path. In TorchScript, if there are no dynamic inputs there should be no dynamic shapes. Multistream (at least how it is used for us, where TRT has a non-default execution stream) cannot be turned off, since TRT requires it. You can file an issue for the accuracy problem with a repro and we can try to figure out what is going on.
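A sketch of that suggestion, assuming a recent Torch-TensorRT where the dynamo backend accepts engine-caching options (the option names are assumptions; check your version's docs):

```python
import torch
import torch_tensorrt  # importing registers the "torch_tensorrt" backend

model = MyMod().eval().cuda()  # toy model from above

# torch.compile tolerates Python conditionals via graph breaks, and cached
# engines avoid rebuilding TRT engines on later runs.
compiled = torch.compile(
    model,
    backend="torch_tensorrt",
    options={
        "min_block_size": 1,
        "cache_built_engines": True,   # assumed option names; verify against
        "reuse_cached_engines": True,  # your torch_tensorrt version
        "engine_cache_dir": "/tmp/trt_engine_cache",
    },
)
out = compiled(torch.zeros(1, 10).cuda())
```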
I found a solution in my code: use symbolic_trace(model) + torch.export.export + torch_tensorrt.compile(ir="dynamo") in place of torch.jit.trace(model) + torch_tensorrt.compile(ir="ts"). The new solution's accuracy is correct, and I can save emb + dense_trt in one model.
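A minimal sketch of that pipeline on the toy model from above (shapes and min_block_size are illustrative):

```python
import torch
import torch_tensorrt
from torch.fx import symbolic_trace

model = MyMod().eval().cuda()  # toy stand-in for the emb + dense model

# symbolic_trace flattens the Python-level control flow; torch.export then
# produces an ExportedProgram that the dynamo frontend can consume.
gm = symbolic_trace(model)
exp_program = torch.export.export(gm, (torch.zeros(1, 10).cuda(),))
trt_gm = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=[torch_tensorrt.Input(shape=(1, 10))],
    min_block_size=1,
)
out = trt_gm(torch.zeros(1, 10).cuda())
```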
Simplified version of the actual code: when I save the model I must use torch.jit.trace(model) + torch.save, but torch.jit.trace doesn't support torch.device arguments, so I use *args in forward. @narendasan, when I use *args as the input with symbolic_trace, trt_gm = torch_tensorrt.dynamo.compile(exp_program, inputs) raises an error. How can I update the code? I want symbolic_trace to support torch.device; can you help me solve it?
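A sketch of the pattern in question (names hypothetical): the traced forward takes *args because torch.jit.trace cannot type a torch.device argument, and dynamo.compile then rejects the non-tensor input. One possible workaround, assuming the device is fixed at trace time, is to store it as module state:

```python
import torch


class ArgsMod(torch.nn.Module):
    # Original pattern: the device rides along in *args because
    # torch.jit.trace only understands tensor inputs.
    def forward(self, *args):
        x, device = args
        return (x + 1).to(device)


class FixedDeviceMod(torch.nn.Module):
    # Possible workaround (an assumption, not a confirmed fix): bake the
    # device into the module so only tensors cross the forward boundary.
    def __init__(self, device):
        super().__init__()
        self.device = device

    def forward(self, x):
        return (x + 1).to(self.device)
```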
Hello @yjjinjie, is this still an issue? If so, can you point me to the exact model?
@peri044
error:
Bug Description
But it raises an exception:

```
RuntimeError: method.qualname() == QualifiedName(selfClass->name()->qualifiedName(), methodName) INTERNAL ASSERT FAILED at "../torch/csrc/jit/serialization/python_print.cpp":1105, please report a bug to PyTorch.
```
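A guess at a minimal failing pattern, assuming (as the title says) only part of the model is converted, with ir="ts", and the whole model is then traced and saved; this reuses the toy MyMod from the first comment:

```python
import torch
import torch_tensorrt

model = MyMod().eval().cuda()

# Convert only the submodule with the TorchScript frontend ...
model.submod = torch_tensorrt.compile(
    torch.jit.trace(model.submod, torch.zeros(1, 10).cuda()),
    ir="ts",
    inputs=[torch_tensorrt.Input(shape=(1, 10))],
)

# ... then trace and serialize the whole model; saving is where the
# INTERNAL ASSERT in python_print.cpp fires.
traced = torch.jit.trace(model, torch.zeros(1, 10).cuda())
torch.jit.save(traced, "test.ts")
```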
If I use dynamo instead, the error is:

The environment: