
support align_corners for Resize operator #418

Closed · wants to merge 4 commits

Conversation


@Bohrhh Bohrhh commented Mar 21, 2020

This PR handles ONNX nodes like the one below:

input: "358"
input: "360"
input: "360"
input: "367"
output: "368"
op_type: "Resize"
attribute {
name: "coordinate_transformation_mode"
s: "align_corners"
type: STRING
}
attribute {
name: "cubic_coeff_a"
f: -0.75
type: FLOAT
}
attribute {
name: "mode"
s: "linear"
type: STRING
}
attribute {
name: "nearest_mode"
s: "floor"
type: STRING
}


CLAassistant commented Mar 21, 2020

CLA assistant check
All committers have signed the CLA.

@allenling

hi @Bohrhh

Is there a test script for this PR?

From #273, setting `layer->setAlignCorners(true)` in Python:

layer = network.add_resize(trt_tensor)
layer.resize_mode = trt.ResizeMode.LINEAR
layer.align_corners = align_corners

does not seem to help.

Bohrhh (Author) commented Mar 28, 2020

hi @allenling
The `align_corners` attribute was introduced in opset 11. The Resize operator spec describes it as follows:

if coordinate_transformation_mode is "align_corners",
x_original = x_resized * (length_original - 1) / (length_resized - 1)

You should supply the size you want to resize the input to, rather than a scale. In TensorRT, do not use `layer->setScales`; use `layer->setInput` instead. Then `layer->setAlignCorners(true)` will work.
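The difference between the two coordinate transforms can be seen in a few lines of pure Python (a standalone sketch; the function names are mine, not from the ONNX spec):

```python
def map_align_corners(x_resized, length_original, length_resized):
    # ONNX Resize, coordinate_transformation_mode = "align_corners"
    return x_resized * (length_original - 1) / (length_resized - 1)

def map_half_pixel(x_resized, length_original, length_resized):
    # Default "half_pixel" mapping, driven purely by the scale factor
    scale = length_resized / length_original
    return (x_resized + 0.5) / scale - 0.5

# Upsampling a length-4 axis to length 8:
# align_corners maps both output endpoints exactly onto the input endpoints...
print(map_align_corners(0, 4, 8), map_align_corners(7, 4, 8))  # 0.0 3.0
# ...while half_pixel overshoots at the edges (coordinates are clamped later)
print(map_half_pixel(0, 4, 8), map_half_pixel(7, 4, 8))        # -0.25 3.25
```

This is also why align_corners cannot be reproduced by a scale factor alone: the step between samples depends on both lengths, not just their ratio.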

You can use the following code to generate resize_align_corners.onnx, then use the patched onnx2trt to build a TensorRT engine and compare the results.

import torch
import torch.nn as nn
import torch.nn.functional as F

class Resize(nn.Module):
    def forward(self, x):
        # Bilinear upsample to a fixed size with align_corners=True
        return F.interpolate(x, size=(448, 448), mode='bilinear', align_corners=True)

def export_resize_onnx_model():
    model = Resize()
    model.eval()
    x = torch.rand(1, 3, 224, 224)
    torch.onnx.export(
        model,
        x,
        'resize_align_corners.onnx',
        export_params=True,
        opset_version=11,  # align_corners requires the opset 11 Resize
        input_names=['x'],
        output_names=['y'],
    )
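For comparing the engine output against expected behavior without ONNX Runtime, a 1-D linear resize with the align_corners mapping can be written in pure Python (a reference sketch of my own, not part of the PR; 2-D bilinear is this applied along each spatial axis):

```python
def linear_resize_align_corners(values, out_len):
    """1-D linear interpolation using the align_corners coordinate mapping."""
    in_len = len(values)
    if out_len == 1 or in_len == 1:
        return [values[0]] * out_len
    out = []
    for i in range(out_len):
        x = i * (in_len - 1) / (out_len - 1)  # align_corners mapping
        lo = int(x)
        hi = min(lo + 1, in_len - 1)
        frac = x - lo
        out.append(values[lo] * (1 - frac) + values[hi] * frac)
    return out

print(linear_resize_align_corners([0.0, 1.0, 2.0, 3.0], 7))
# [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0]
```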

@allenling
Copy link

allenling commented Mar 30, 2020

@Bohrhh

Sorry, I posted the wrong issue; it should be NVIDIA/TensorRT#273.

If you could take a look at that issue, I would really appreciate it.


bing1zhi2 commented Sep 3, 2020

@Bohrhh I also have a problem converting a model from PyTorch to TRT with an Interpolate op. Even with the latest onnx-tensorrt code I still get the error "Assertion failed: ctx->tensors().count(inputName)". So I used your Docker image (bohrhh/tensorrt:7.0), and the conversion worked, but the outputs of the ONNX model and the TRT model differ significantly (the ONNX result matches PyTorch). Strangely, when I export my model to ONNX without loading the weights and then convert it to a TRT engine with the onnx2trt command in your image, all the results match. What onnx-tensorrt version is in your Docker image? Is there really no way to convert my model other than the TRT network API?

Bohrhh (Author) commented Sep 3, 2020

> (quoting the comment above)

In the Docker image, onnx2trt matches the TensorRT version, i.e. 7.0. Your situation is very strange; I have never seen the ONNX model weights affect the converted TensorRT model so noticeably.


bing1zhi2 commented Sep 4, 2020

> (quoting the two comments above)

Haha, forgive my clumsy English. I tested again this morning, and it seems the problem is not related to your environment; your environment is correct. I removed the last Interpolate (mode='bilinear', align_corners=True) from the forward pass; since its output size equals its input size, that layer was effectively a no-op there. With it gone, the remaining Interpolate layers (mode='nearest') could be exported to ONNX with opset 10 (the size argument is a constant), and the exported ONNX output matches the original model. I could then convert the model with the official TensorRT image; the resulting TRT model matches the one produced by the onnx2trt in your environment, but both differ from the original model's results. It is really strange that the mismatch only appears after loading the weights, and I don't even know where to start looking. As for redefining the network with the TRT API: the original structure is too complex, and there are few demos, so it would take a lot of time. Also, for some reason, the TensorRT 7.1 image I built myself following the official steps, with the latest onnx2trt installed, reports "Assertion failed: ctx->tensors().count(inputName)" during conversion.


handoku commented Sep 16, 2020

I have tested this code on a DeepLabV3 model (PyTorch → ONNX → TensorRT), and it worked. The TensorRT model's output matches PyTorch's.

kevinch-nv (Collaborator)

Closing in favor of #538. If anyone on this thread still has issues with Resize, feel free to open an issue.

kevinch-nv closed this Dec 15, 2020

handoku commented Jan 7, 2021

I found that after Interpolate, the resulting TRT engine becomes a fixed-batch-size model, even though isDynamic(layer->getOutput(0)->getDimensions()) returns true.

This makes it impossible to create a TRT plan file that supports dynamic batching.

A very simple example:

import numpy as np
import torch
import torch.nn.functional as F

class MyModel(torch.nn.Module):
    def forward(self, x):
        # Resize to a fixed spatial size; exported as a Resize node with constant sizes
        return F.interpolate(x, size=(32, 32), mode='bilinear', align_corners=True)

onnx_model_file = 'resize.onnx'
torch_input = torch.from_numpy(np.random.randn(1, 3, 64, 64).astype(np.float32)).cuda()
model = MyModel()

torch.onnx.export(
    model, torch_input, onnx_model_file, verbose=False,
    export_params=True,
    input_names=['input'], output_names=['out'],
    opset_version=11, keep_initializers_as_inputs=True,
    dynamic_axes={'input': {0: 'bs'}, 'out': {0: 'bs'}},
)

Then create a plan file with trtexec:

./bin/trtexec --onnx=./resize.onnx --explicitBatch --minShapes=\'input\':1x3x64x64 --optShapes=\'input\':4x3x64x64 --maxShapes=\'input\':8x3x64x64  --buildOnly --saveEngine=./resize.plan --workspace=11288

When loading the TRT model with trtserver, it outputs:

I0107 13:40:31.901351 48889 autofill.cc:213] TensorRT autofill: OK: 
W0107 13:40:31.901367 48889 autofill.cc:165] The TRT engine doesn't specify appropriate dimensions to support dynamic batching
I0107 13:40:31.901388 48889 model_config_utils.cc:276] autofilled config: name: "cdcn_trt"
platform: "tensorrt_plan"
input {
  name: "input"
  data_type: TYPE_FP32
  dims: -1
  dims: 3
  dims: 64
  dims: 64
}
output {
  name: "out"
  data_type: TYPE_FP32
  dims: 1
  dims: 3
  dims: 32
  dims: 32
}
instance_group {
  count: 2
  gpus: 0
  kind: KIND_GPU
}
default_model_filename: "model.plan"

This looks like the same problem as NVIDIA/TensorRT#996.
I really don't know how to fix it; could anyone help?
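The shape behavior can be illustrated without TensorRT: when Resize receives a constant `sizes` tensor (which is what exporting with `size=(32, 32)` produces), every output dimension, including batch, is pinned to a constant; with a `scales` input, a dynamic input dimension stays dynamic. A toy sketch (my own helper, not a TensorRT API, assuming -1 marks a dynamic dim):

```python
def resize_output_dims(input_dims, sizes=None, scales=None):
    # ONNX Resize takes either explicit output sizes or per-axis scales
    if sizes is not None:
        return list(sizes)  # constants: every dim fixed, batch included
    return [-1 if d == -1 else int(d * s)  # dynamic dims stay dynamic
            for d, s in zip(input_dims, scales)]

# Exported with size=(32, 32): batch is frozen to the trace-time value 1
print(resize_output_dims([-1, 3, 64, 64], sizes=[1, 3, 32, 32]))      # [1, 3, 32, 32]
# Exported with scale_factor=0.5 instead: batch stays dynamic
print(resize_output_dims([-1, 3, 64, 64], scales=[1.0, 1.0, 0.5, 0.5]))  # [-1, 3, 32, 32]
```

So one possible workaround (untested here) might be to export with `F.interpolate(x, scale_factor=0.5, ...)` so that the graph carries scales rather than a constant sizes tensor.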
