Remove linear lowering pass and converter #3323

Open
wants to merge 1 commit into base: main
Conversation

@HolyWu (Contributor) commented Dec 13, 2024

1. The linear lowering pass only takes effect with ir="torch_compile" and nn.Linear(bias=True); it has no effect at all with ir="dynamo", regardless of whether bias is True or False. Using the code below (a sketch of what the pass does follows the logs):

import os

import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(20, 30, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


with torch.inference_mode():
    model = MyModule().eval().cuda().half()
    inputs = [torch.randn(128, 20, dtype=torch.half, device="cuda")]

    trt_model = torch_tensorrt.compile(
        model, "torch_compile", inputs, enabled_precisions={torch.half}, debug=True, min_block_size=1
    )

    trt_model(*inputs)

ir="torch_compile", bias=True

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%arg1_1, %arg2_1, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.lower_linear:Graph after lowering linear:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %linear_default : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg2_1, %arg0_1, %arg1_1), kwargs = {})
    return (linear_default,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %linear_default : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg2_1, %arg0_1, %arg1_1), kwargs = {})
    return (linear_default,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %arg2_1 : [num_users=1] = placeholder[target=arg2_1]
    %linear_default : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%arg2_1, %arg0_1, %arg1_1), kwargs = {})
    return (linear_default,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.linear.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

ir="torch_compile", bias=False

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_input_alias_fixing_clones:Removed auxiliary clone nodes for placeholders:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo.backend.backends:Lowered Input graph:
 graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%arg0_1, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%arg1_1, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.permute.default + Operator Count: 1
- torch.ops.aten.mm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

ir="dynamo", bias=True

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %p_linear_bias : [num_users=1] = placeholder[target=p_linear_bias]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight, %p_linear_bias), kwargs = {})
    return (linear,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %permute), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %linear_bias : [num_users=1] = get_attr[target=linear.bias]
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %addmm : [num_users=1] = call_function[target=torch.ops.aten.addmm.default](args = (%linear_bias, %x, %_frozen_param0), kwargs = {})
    return (addmm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.addmm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported

ir="dynamo", bias=False

DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_detach:Removed 0 detach nodes:
graph():
    %p_linear_weight : [num_users=1] = placeholder[target=p_linear_weight]
    %x : [num_users=1] = placeholder[target=x]
    %linear : [num_users=1] = call_function[target=torch.ops.aten.linear.default](args = (%x, %p_linear_weight), kwargs = {})
    return (linear,)
DEBUG:torch_tensorrt.dynamo._compiler:Input graph: graph():
    %linear_weight : [num_users=1] = get_attr[target=linear.weight]
    %x : [num_users=1] = placeholder[target=x]
    %permute : [num_users=1] = call_function[target=torch.ops.aten.permute.default](args = (%linear_weight, [1, 0]), kwargs = {})
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%x, %permute), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.constant_folding:Graph after constant folding:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%x, %_frozen_param0), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.remove_assert_scalar:Removed 0 assert_scalar nodes:
graph():
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%x, %_frozen_param0), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.lowering.passes.accumulate_fp32_matmul:Skipping FP32 accumulation for matmul layers as use_fp32_acc is not enabled in the compilation settings
DEBUG:torch_tensorrt.dynamo._compiler:Lowered Input graph: graph():
    %x : [num_users=1] = placeholder[target=x]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    %mm : [num_users=1] = call_function[target=torch.ops.aten.mm.default](args = (%x, %_frozen_param0), kwargs = {})
    return (mm,)
DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
Supported Nodes:
- torch.ops.aten.mm.default + Operator Count: 1

DEBUG:torch_tensorrt.dynamo.partitioning._global_partitioner:
All Nodes Supported
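
For context, the pass being removed pattern-matches the decomposed linear (aten.permute followed by aten.addmm) back into a single aten.linear node. Below is a minimal, hypothetical sketch of that kind of FX rewrite (not the actual Torch-TensorRT implementation). It only matches the addmm pattern, which is consistent with the logs above: the mm pattern (bias=False) is never rewritten, and under ir="dynamo" constant folding has already folded the permute into _frozen_param0 before the pattern could match.

import torch


def lower_linear_sketch(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Rewrite addmm(bias, x, permute(weight)) into linear(x, weight, bias).
    for node in list(gm.graph.nodes):
        if node.target != torch.ops.aten.addmm.default:
            continue
        bias, inp, permuted_weight = node.args
        if (
            isinstance(permuted_weight, torch.fx.Node)
            and permuted_weight.target == torch.ops.aten.permute.default
        ):
            weight = permuted_weight.args[0]
            with gm.graph.inserting_after(node):
                linear = gm.graph.call_function(
                    torch.ops.aten.linear.default, args=(inp, weight, bias)
                )
            node.replace_all_uses_with(linear)
            gm.graph.erase_node(node)
    gm.graph.eliminate_dead_code()  # drops the now-unused permute node
    gm.recompile()
    return gm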

2. The linear converter's performance is essentially the same as without the converter. Using the code below (a quick comparison of the mean latencies follows the timings):

from __future__ import annotations

import os

import numpy as np
import torch
import torch_tensorrt

os.environ["CI_BUILD"] = "1"

times = 100


@torch.inference_mode()
def benchmark(model: torch.nn.Module, inputs: list[torch.Tensor]) -> np.ndarray:
    # Warm up
    for i in range(3):
        model(inputs[i])

    torch.cuda.synchronize()

    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(times)]

    for i in range(times):
        torch.cuda._sleep(1_000_000)

        start_events[i].record()
        model(inputs[i])
        end_events[i].record()

    torch.cuda.synchronize()

    timings = [s.elapsed_time(e) for s, e in zip(start_events, end_events)]
    return np.array(timings)


class MyModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(4096, 8192)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)


torch.manual_seed(12345)

model = MyModule().eval().cuda().half()

inputs = [torch_tensorrt.Input((2048, 4096), dtype=torch.half)]

trt_model = torch_tensorrt.compile(
    model, "torch_compile", inputs, enabled_precisions={torch.half}, debug=False, min_block_size=1
)

inputs = [torch.randn(2048, 4096, dtype=torch.half, device="cuda") for _ in range(times)]

timing = benchmark(trt_model, inputs)

print("\nTiming:")
print(f"Min={timing.min()} ms, Mean={timing.mean()} ms, Max={timing.max()} ms")

torch._dynamo.reset()

ir="torch_compile", with linear pass

Timing:
Min=2.2077438831329346 ms, Mean=2.27082528591156 ms, Max=3.074399948120117 ms

ir="torch_compile", linear pass removed

Timing:
Min=2.20467209815979 ms, Mean=2.2625676822662353 ms, Max=2.8375039100646973 ms

ir="dynamo"

Timing:
Min=2.0244479179382324 ms, Mean=2.063061113357544 ms, Max=2.6060800552368164 ms
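
To quantify "essentially the same", comparing the two mean latencies reported above for ir="torch_compile" gives a difference of well under one percent:

with_pass_mean = 2.27082528591156       # ms, with the linear lowering pass
without_pass_mean = 2.2625676822662353  # ms, with the pass removed

rel_diff = (with_pass_mean - without_pass_mean) / without_pass_mean
print(f"Relative mean latency difference: {rel_diff:.2%}")  # ≈ 0.36%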

3. TestLowerLinear is flaky, and TestLinearConverter has recently been observed to hit threshold failures on Windows CI as well (a sketch of the kind of tolerance check involved is shown below).
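
For illustration, here is a minimal, hypothetical sketch of the kind of tolerance check such tests perform (not the actual TestLowerLinear/TestLinearConverter code). With fp16 weights and a large matmul, small accumulation-order differences between TensorRT and PyTorch can intermittently exceed a fixed threshold:

import torch
import torch_tensorrt

model = torch.nn.Linear(4096, 8192).eval().cuda().half()
x = torch.randn(2048, 4096, dtype=torch.half, device="cuda")

trt_model = torch_tensorrt.compile(
    model, "torch_compile", [x], enabled_precisions={torch.half}, min_block_size=1
)

with torch.inference_mode():
    ref = model(x)
    out = trt_model(x)

# A fixed tolerance like this can pass or fail from run to run in fp16.
torch.testing.assert_close(out, ref, rtol=1e-3, atol=1e-3)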

For the reasons above, I think the linear lowering pass and converter provide no benefit and should be removed.

@github-actions github-actions bot added component: tests Issues re: Tests component: lowering Issues re: The lowering / preprocessing passes component: conversion Issues re: Conversion stage component: converters Issues re: Specific op converters component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Dec 13, 2024
@github-actions github-actions bot requested a review from peri044 December 13, 2024 17:20
@peri044 (Collaborator) left a comment
This is great and thanks for the perf analysis. LGTM

@peri044 (Collaborator) commented Dec 16, 2024

@HolyWu can you rebase?
