Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update skip layer norm #22719

Merged
merged 21 commits into from
Nov 12, 2024
Merged

Conversation

amarin16
Copy link
Collaborator

@amarin16 amarin16 commented Nov 4, 2024

Update the SkipLayerNorm implementation to address issues.

@amarin16 amarin16 requested review from fs-eire and jywu-msft November 4, 2024 22:09
@amarin16 amarin16 marked this pull request as ready for review November 4, 2024 23:07
@fs-eire
Copy link
Contributor

fs-eire commented Nov 5, 2024

Does this also fix SkipSimplifiedLayerNormalization ?

@amarin16
Copy link
Collaborator Author

amarin16 commented Nov 5, 2024

Does this also fix SkipSimplifiedLayerNormalization ?

Yes, both SkipLayerNormalization and SkipSimplifiedLayerNormalization are implemented using the SkipLayerNorm class

@amarin16 amarin16 requested a review from fajin-corp November 5, 2024 11:52
@fajin-corp
Copy link
Contributor

in terms of bug fix, the changes should have fixed the race condition for the memory buffer. @amarin16 do you see any unit test failures?

@amarin16
Copy link
Collaborator Author

amarin16 commented Nov 5, 2024

in terms of bug fix, the changes should have fixed the race condition for the memory buffer. @amarin16 do you see any unit test failures?

I don't see any failures in the existing unit tests

fajin-corp
fajin-corp previously approved these changes Nov 5, 2024
fs-eire
fs-eire previously approved these changes Nov 6, 2024
Copy link
Contributor

@fs-eire fs-eire left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@amarin16 amarin16 dismissed stale reviews from fs-eire and fajin-corp via d5850a4 November 6, 2024 15:34
fs-eire
fs-eire previously requested changes Nov 7, 2024
Copy link
Contributor

@fs-eire fs-eire left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The latest change is incorrect.

You need to be very careful to store a const reference from a function's parameter, unless you are quite sure about the ownership and lifecycle of the object. Specifically, in this case you stores the tensor reference as a parameter of PrePack() into a pointer of a class field. The tensor object is no longer valid in the context of OpKernel::Compute() and you are actually accessing dangling pointers.

If you want to validate the initializers, you need to do that inside PrePack() and returns a failure OrtStatus.

I can also observe this failure in the original issue, but seems it is not caught by the CI pipelines.

My suggestion:

EDIT: I debugged the code: the input validation accesses nullptr.
image

@amarin16
Copy link
Collaborator Author

amarin16 commented Nov 7, 2024

Before my latest changes, I initially considered doing the validation when the gamma pointer is not null like you suggested. The problem with that is that you no longer enforce the fact that gamma is not optional. And this is not the only code that uses that helper.
An alternative is to do the validation in prepack, however that's tricky because the current logic validates prepacked and non prepacked tensors in the same if statement.

I will revert bc2c820 and come up with a different way to validate

@amarin16
Copy link
Collaborator Author

amarin16 commented Nov 7, 2024

Unfortunately we need input->Shape().GetDims() in PrePack in order to do any of the validations there and we don't have access to it.

@amarin16
Copy link
Collaborator Author

amarin16 commented Nov 8, 2024

I added some unit tests and manually validated that they fail if you try to access gamma when prepacked

@sophies927 sophies927 added the triage:approved Approved for cherrypicks for release label Nov 11, 2024
@amarin16 amarin16 dismissed fs-eire’s stale review November 12, 2024 15:01

I've addressed all of Yulong's comments, but he is currently OOF

@amarin16 amarin16 merged commit f0ac5e0 into microsoft:main Nov 12, 2024
92 checks passed
yf711 pushed a commit that referenced this pull request Nov 12, 2024
Update the `SkipLayerNorm` implementation to address issues.
@sophies927 sophies927 added the cherry-picked Cherry-picked for a cherrypicks branch label Nov 18, 2024
ishwar-raut1 pushed a commit to ishwar-raut1/onnxruntime that referenced this pull request Nov 19, 2024
Update the `SkipLayerNorm` implementation to address issues.
guschmue pushed a commit that referenced this pull request Dec 2, 2024
Update the `SkipLayerNorm` implementation to address issues.
ankitm3k pushed a commit to intel/onnxruntime that referenced this pull request Dec 11, 2024
Update the `SkipLayerNorm` implementation to address issues.
ankitm3k pushed a commit to intel/onnxruntime that referenced this pull request Dec 11, 2024
Update the `SkipLayerNorm` implementation to address issues.
ankitm3k pushed a commit to intel/onnxruntime that referenced this pull request Dec 11, 2024
Update the `SkipLayerNorm` implementation to address issues.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cherry-picked Cherry-picked for a cherrypicks branch release:1.20.1 triage:approved Approved for cherrypicks for release
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants