Problem in position embedding #4

Closed
jmercat opened this issue Aug 31, 2023 · 8 comments · Fixed by #5

@jmercat (Collaborator) commented Aug 31, 2023

queries, keys, vals = self.pos_embed(queries, keys, vals)

It seems to me that the rotary position embedding is being applied on the head dimension (dim -2) of the q and k vectors instead of the sequence dimension (dim 1).
I think the head and sequence dimensions should be swapped before calling the position embedding
(see https://github.com/facebookresearch/xformers/blob/748c159096d4f9fcfe3eaf22801e5aed4777210b/xformers/components/positional_embedding/rotary.py#L85).
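For concreteness, a minimal sketch of the mismatch (shapes assumed to be (batch, seq_len, n_heads, head_dim) as described above; illustrative only, and it assumes xformers is installed):

import torch
from xformers.components.positional_embedding.rotary import RotaryEmbedding

B, S, H, D = 1, 8, 4, 16
rope = RotaryEmbedding(dim_model=D)
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)

# Passed as (B, S, H, D): dim -2 is the head axis, so the rotation angle varies
# per head and is identical for every sequence position, i.e. no positional signal.
q_bug, k_bug = rope(q, k)

# Permuted to (B, H, S, D): dim -2 is now the sequence axis, as intended.
q_fix, k_fix = rope(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3))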

What I'm proposing is simply to re-write RotaryWithCast as follows:

class RotaryWithCast(RotaryEmbedding):
    def forward(self, q, k, v):
        # q, k, v arrive as (batch, seq_len, n_heads, head_dim); permute to
        # (batch, n_heads, seq_len, head_dim) so the rotation is applied along
        # the sequence dimension (dim -2), then permute back afterwards.
        q, k = super().forward(q.permute(0, 2, 1, 3), k.permute(0, 2, 1, 3))
        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        return q.to(v.dtype), k.to(v.dtype), v
@jmercat changed the title from "Problem in position embedding?" to "Problem in position embedding" on Aug 31, 2023
@jmercat (Collaborator, Author) commented Aug 31, 2023

Here are the runs I made with a custom subset of StarCoder data. The original 11m training is in brown. My implementation, using a different positional encoding (including the proposed fix), is in orange.
[Screenshot from 2023-08-31 10-22-09: loss curves for the two runs]

@sagadre (Collaborator) commented Aug 31, 2023

Good catch! The blow-up curves you are seeing are similar to the ones we were seeing before we introduced qk norm for the smaller models. Will do some testing with this fix on my end as well. Would you like to open a PR?

@mitchellnw (Contributor)

Wow, amazing catch! We really appreciate this.

@mitchellnw (Contributor)

We've added your name to the README because this is a very substantial bug catch. It's pretty interesting that our first 1B/7B runs do pretty well even without proper posembeds, but we should fix this going forward.

@jmercat (Collaborator, Author) commented Aug 31, 2023

Great code base by the way. It's a pleasure to read.
Thanks for proposing to include me. I could open a PR, but it's probably simpler for you to just include what I wrote (or a better version... I haven't tested whether calling contiguous would make a difference).

@sagadre (Collaborator) commented Aug 31, 2023

Looking into a way to implement this directly with the xformers API. Thanks so much @jmercat!

@jmercat (Collaborator, Author) commented Aug 31, 2023

Actually, moving that line before the call to view would be enough:

queries, keys, vals = self.pos_embed(queries, keys, vals)

@sagadre (Collaborator) commented Aug 31, 2023

The problem actually seems to be upstream in xformers. Opened an issue here: facebookresearch/xformers#841

@sagadre linked a pull request on Sep 2, 2023 that will close this issue
@sagadre closed this as completed in #5 on Sep 3, 2023
sedrickkeh pushed a commit to sedrick-keh-tri/open_lm_fork that referenced this issue on May 23, 2024