full_like to full decomposition moving to decomposition.py for dynami… #3289

Open
wants to merge 2 commits into
base: main
from


apbose
Collaborator

@apbose apbose commented Nov 11, 2024

No description provided.

@github-actions github-actions bot added component: lowering Issues re: The lowering / preprocessing passes component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Nov 11, 2024
@narendasan
Collaborator

@apbose do you have a test case?

@peri044
Collaborator

peri044 commented Nov 18, 2024

@apbose I see your comment: #3140 (comment). Can you provide more context on why this change is required?

@apbose apbose requested review from chohk88 and peri044 November 19, 2024 19:31
@apbose
Collaborator Author

apbose commented Nov 19, 2024

@peri044 I removed `replace_full_like_to_full` and moved the conversion into `_decompositions.py` instead. In the dynamic case, the `full` op tries to take its shape from the metadata of the `full_like` input tensor. Because that input has a dynamic shape, the graph ends up with undefined symbolic sizes and the forward function fails. This is the node the `full_like` node was lowered to:

full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)

Here `s0` and `s2` are undefined. Now that the whole `full_like` op is handled as a decomposition, the graph instead contains:

%sym_size_int_11 : [num_users=10] = call_function[target=torch.ops.aten.sym_size.int](args = (%args0, 0), kwargs = {})
%sym_size_int_12 : [num_users=11] = call_function[target=torch.ops.aten.sym_size.int](args = (%args1, 1), kwargs = {})
%full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_11, 1, %sym_size_int_12], 1), kwargs = {dtype: torch.float32, device: cuda:0, pin_memory: False})
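To make the failure mode concrete, here is a toy, pure-Python model of the two graphs above. It is only a sketch under stated assumptions: `FakeTensor`, `run_graph`, and the op strings are hypothetical stand-ins, not PyTorch or Torch-TensorRT APIs, and the single input `x` of shape `[s0, 1, s2]` is a simplification (the real graph reads one of its sizes from `args1`). Baking trace-time symbols into the `full` node leaves them unresolvable at run time, while the decomposed form reads the sizes from the input through `sym_size`-style nodes.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a tensor: only its shape and fill value matter here."""
    shape: Tuple[int, ...]
    fill: float = 0.0


# (output name, op, args); string args refer to values produced by earlier nodes.
Node = Tuple[str, str, Tuple[Any, ...]]


def run_graph(nodes: List[Node], inputs: Dict[str, FakeTensor]) -> Any:
    """Interpret a flat toy graph and return the value of its last node."""
    env: Dict[str, Any] = dict(inputs)

    def resolve(arg: Any) -> Any:
        # A string must name a value that some earlier node produced.
        if isinstance(arg, str):
            if arg not in env:
                raise NameError(f"symbol {arg!r} is undefined in this graph")
            return env[arg]
        return arg

    result: Any = None
    for name, op, args in nodes:
        if op == "placeholder":
            result = env[name]  # supplied by the caller at run time
        elif op == "sym_size":  # plays the role of aten.sym_size.int
            tensor_name, dim = args
            result = resolve(tensor_name).shape[dim]
        elif op == "full":  # plays the role of aten.full.default
            size, fill_value = args
            result = FakeTensor(tuple(resolve(d) for d in size), fill_value)
        else:
            raise ValueError(f"unsupported op {op!r}")
        env[name] = result
    return result


# Shape copied from trace-time metadata: no node ever defines s0 or s2.
static_graph: List[Node] = [
    ("x", "placeholder", ()),
    ("full_default", "full", (["s0", 1, "s2"], 1.0)),
]

# Decomposed form: the dynamic sizes are read from the input when the graph runs.
decomposed_graph: List[Node] = [
    ("x", "placeholder", ()),
    ("sym_size_int_11", "sym_size", ("x", 0)),
    ("sym_size_int_12", "sym_size", ("x", 2)),
    ("full", "full", (["sym_size_int_11", 1, "sym_size_int_12"], 1.0)),
]

if __name__ == "__main__":
    x = FakeTensor((4, 1, 7))
    print(run_graph(decomposed_graph, {"x": x}).shape)  # (4, 1, 7)
    try:
        run_graph(static_graph, {"x": x})
    except NameError as err:
        print(err)  # symbol 's0' is undefined in this graph
```

The point of the model is that a symbolic size is only usable if some node in the graph produces it; the decomposed graph satisfies that by emitting the size reads alongside the `full` call.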


@github-actions github-actions bot left a comment


There are some changes that do not conform to Python style guidelines:

--- /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py	2024-11-19 20:01:46.728609+00:00
+++ /home/runner/work/TensorRT/TensorRT/py/torch_tensorrt/dynamo/lowering/_decompositions.py	2024-11-19 20:02:09.200571+00:00
@@ -416,12 +416,12 @@
        return torch.nn.functional.group_norm(input, input.shape[1], weight, bias, eps)
    else:
        return torch.nn.functional.batch_norm(
            input, running_mean, running_var, weight, bias, False, momentum, eps
        )
-    
-    
+
+
@register_torch_trt_decomposition(
    torch.ops.aten.full_like, registry=TORCH_TRT_DECOMPOSITIONS
)
def full_like_decomposition(*args, **kwargs) -> torch.Tensor:
    input = args[0]

@apbose
Collaborator Author

apbose commented Nov 20, 2024

Oh, the PR is now failing because the graph after lowering is empty.

The graph now is:

graph():
    %arg0_1 : [num_users=0] = placeholder[target=arg0_1]
    %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
    return (_frozen_param0,)

Modifying the test now so that the graph is non-empty.
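The empty graph above is consistent with the statically shaped `full` result having been constant-folded into `_frozen_param0`: its value does not depend on `arg0_1`, which is why that placeholder has `num_users=0`. The following sketch models that effect with a deliberately simplified folding rule; `fold_constants`, `num_compute_nodes`, and the op strings are hypothetical and are not Torch-TensorRT's actual pass. It also shows one assumed way to keep a test graph non-empty: make the output depend on the input's values.

```python
from typing import Any, List, Set, Tuple

# (output name, op, args); string args refer to values produced by earlier nodes.
Node = Tuple[str, str, Tuple[Any, ...]]

COMPUTE_OPS = {"full", "add"}


def fold_constants(nodes: List[Node]) -> List[Node]:
    """Toy constant folding (hypothetical rule, not Torch-TensorRT's pass).

    A "full" node with an all-integer size needs nothing at run time, so it is
    replaced by a frozen constant ("get_attr"). Any other compute node is
    folded only when every node it references has already been frozen.
    """
    frozen: Set[str] = set()
    folded: List[Node] = []
    for name, op, args in nodes:
        refs = [a for a in args if isinstance(a, str)]
        static_full = op == "full" and all(isinstance(d, int) for d in args[0])
        all_refs_frozen = bool(refs) and all(r in frozen for r in refs)
        if op in COMPUTE_OPS and (static_full or all_refs_frozen):
            frozen.add(name)
            folded.append((name, "get_attr", ()))
        else:
            folded.append((name, op, args))
    return folded


def num_compute_nodes(nodes: List[Node]) -> int:
    """Count nodes that still execute, i.e. not placeholders or frozen constants."""
    return sum(1 for _, op, _ in nodes if op in COMPUTE_OPS)


if __name__ == "__main__":
    # The output is only a statically shaped fill of the input: fully foldable.
    only_full: List[Node] = [
        ("arg0_1", "placeholder", ()),
        ("full_default", "full", ([2, 1, 3], 1.0)),
    ]
    # Hypothetical test change: the output now also depends on the input's values.
    with_add: List[Node] = only_full + [("add", "add", ("arg0_1", "full_default"))]

    print(num_compute_nodes(fold_constants(only_full)))  # 0
    print(num_compute_nodes(fold_constants(with_add)))  # 1
```

Under this rule a `full` whose size contains symbolic entries is never folded, which is the dynamic-shape situation discussed earlier in the thread.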

@github-actions github-actions bot added the component: tests Issues re: Tests label Nov 20, 2024
Collaborator

@peri044 peri044 left a comment


Changes LGTM. Can you add a test for the dynamic shape case as well?

Labels
cla signed component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes component: tests Issues re: Tests
Projects
None yet

4 participants