Hi,
I am using Ulysses Attention and the DeepSpeed ZeRO-3 optimizer for DPO training.
My question is: what is the right way to aggregate the loss?
When training with CrossEntropyLoss, each rank yields a loss for its own subset of the sequence.
But when training with the DPO loss, there is only a single loss for the whole example.
What is the right way to handle this?
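For what it's worth, here is a minimal single-process sketch of the aggregation I would expect to be correct: sum the per-rank partial sequence log-probs across the sequence parallel group (an `all_reduce(SUM)` in real code, simulated here with a plain Python sum over chunks) and only then apply the DPO loss. The helper names and the group size of 2 are illustrative, not from any library.

```python
import torch
import torch.nn.functional as F

def dpo_loss(pi_c, pi_r, ref_c, ref_r, beta=0.1):
    # Standard DPO loss on full-sequence log-probs:
    # -log sigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r)))
    return -F.logsigmoid(beta * ((pi_c - ref_c) - (pi_r - ref_r)))

torch.manual_seed(0)
# Per-token log-probs for one chosen/rejected pair, sequence length 8.
logp_c, logp_r = torch.randn(8), torch.randn(8)
ref_c, ref_r = torch.randn(8), torch.randn(8)

# Baseline: no sequence parallelism, sum log-probs over the whole sequence.
loss_full = dpo_loss(logp_c.sum(), logp_r.sum(), ref_c.sum(), ref_r.sum())

# Simulated SP group of 2: each "rank" sums its local token slice; summing
# those partials across the group (an all-reduce in real code) recovers the
# full-sequence log-prob, so every rank computes the same DPO loss.
def sp_sum(x, world_size=2):
    return sum(part.sum() for part in x.chunk(world_size))

loss_sp = dpo_loss(sp_sum(logp_c), sp_sum(logp_r),
                   sp_sum(ref_c), sp_sum(ref_r))
```

With this ordering the loss is identical on all ranks, so no further loss aggregation is needed, only the log-prob reduction.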
Yes, I have seen this before. As far as I understand, in that approach every rank stores all the logits, and on the backward pass only the required parts of the gradients are taken into account.
I was trying an approach where some reduction is done before the communication inside the sequence parallel group (to reduce the communication load); I am going to try to build something along those lines.
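If it helps, the behavior described above (each rank holds the full logits, but backward only propagates the gradient of its own sequence slice) can be mimicked single-process with a custom autograd function. The class name and slice bounds here are hypothetical, purely to illustrate the mechanics:

```python
import torch

class KeepLocalSliceGrad(torch.autograd.Function):
    """Identity in forward; in backward, zero out every gradient entry
    outside this rank's sequence slice [start, end) (illustrative only)."""
    @staticmethod
    def forward(ctx, logits, start, end):
        ctx.start, ctx.end = start, end
        return logits.clone()

    @staticmethod
    def backward(ctx, grad_out):
        local = torch.zeros_like(grad_out)
        local[ctx.start:ctx.end] = grad_out[ctx.start:ctx.end]
        return local, None, None

logits = torch.randn(8, requires_grad=True)
# Pretend we are rank 0 of 2, responsible for tokens [0, 4).
out = KeepLocalSliceGrad.apply(logits, 0, 4)
out.sum().backward()
# logits.grad is 1 on the local slice and 0 elsewhere.
```

In a real Ulysses setup this masking falls out of the backward of the all-gather op itself; the sketch only shows why each rank ends up with a partial gradient.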
I see a gradient difference.
I have two ranks in the sequence parallel group (SPG). I compute the gradient of the loss and slice it with respect to the rank in the SPG, so I get two gradients, g_1 and g_2, one on each rank.
When I run the same setup without sequence parallelism, I get a gradient g, and for all layers I have approximately g == g_1 + g_2.
So my question is whether the DeepSpeed ZeRO-3 optimizer handles this case correctly, or whether it just takes the average across all ranks?
I see a difference in the gradient norm as well: around 3.24 on each rank with sequence parallelism, and around 9.11 without sequence parallelism.
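My understanding (worth verifying against the DeepSpeed source) is that the ZeRO gradient all-reduce averages across all ranks, including the sequence parallel ones, so partial gradients that should be summed get divided by the SP group size. A tiny numeric sketch of the mismatch and of the usual workaround, scaling the loss by the SP group size before the reduce:

```python
import torch

# Two SP ranks each hold the gradient contribution of their sequence slice.
g1 = torch.tensor([1.0, 2.0])
g2 = torch.tensor([3.0, 4.0])

g_true = g1 + g2            # what a run without sequence parallelism gives
g_averaged = (g1 + g2) / 2  # what an averaging all-reduce yields instead

# Multiplying the loss (and hence each partial gradient) by the SP group
# size before the reduce compensates for the unwanted division.
sp_world = 2
g_fixed = (sp_world * g1 + sp_world * g2) / sp_world
```

Whether the scale factor belongs in the loss or in a gradient hook depends on how the SP group is wired into the data-parallel process group, so this is a sketch of the arithmetic rather than a drop-in fix.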