-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update skip layer norm #22719
Update skip layer norm #22719
Conversation
Does this also fix SkipSimplifiedLayerNormalization ? |
Yes, both |
in terms of bug fix, the changes should have fixed the race condition for the memory buffer. @amarin16 do you see any unit test failures? |
I don't see any failures in the existing unit tests |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The latest change is incorrect.
You need to be very careful to store a const reference from a function's parameter, unless you are quite sure about the ownership and lifecycle of the object. Specifically, in this case you stores the tensor reference as a parameter of PrePack() into a pointer of a class field. The tensor object is no longer valid in the context of OpKernel::Compute() and you are actually accessing dangling pointers.
If you want to validate the initializers, you need to do that inside PrePack() and returns a failure OrtStatus.
I can also observe this failure in the original issue, but seems it is not caught by the CI pipelines.
My suggestion:
- for this PR:
- revert the changes introduced in bc2c820
- validate issue [Web] 1.20.0 breaks SkipSimplifiedLayerNormalization backwards compatibility. Missing Input: model.layers.0.input_layernorm.weight #22704 is fixed
- fix validation: only does the validation when pointer is not null (see details below)
- merge to unblock users
- Following up:
- create a new PR to add test cases (using initializers) and initialier validation.
EDIT: I debugged the code: the input validation accesses nullptr.
Before my latest changes, I initially considered doing the validation when the I will revert bc2c820 and come up with a different way to validate |
Unfortunately we need |
I added some unit tests and manually validated that they fail if you try to access |
I've addressed all of Yulong's comments, but he is currently OOF
Update the `SkipLayerNorm` implementation to address issues.
Update the `SkipLayerNorm` implementation to address issues.
Update the `SkipLayerNorm` implementation to address issues.
Update the `SkipLayerNorm` implementation to address issues.
Update the `SkipLayerNorm` implementation to address issues.
Update the `SkipLayerNorm` implementation to address issues.
Update the
SkipLayerNorm
implementation to address issues.