How to use Optimizer State Sharding with Sharpness-Aware Minimization? #989
It's a super good question, and SAM is pretty great I think! I would wrap the other way around (i.e. SAM as the outer wrap). I would also suggest you give the upstreamed PyTorch implementation a go (here): the interface should be the same, and I'm not sure that the version in fairscale is still up to date (cc @anj-s @min-xu-ai). Context is that I'm one of the original authors of OSS (among others).
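For context, the upstreamed implementation is presumably `torch.distributed.optim.ZeroRedundancyOptimizer`; a minimal sketch of swapping it in for a plain optimizer might look like this (the `model` and the process-group setup are assumed):

```python
import torch
from torch.distributed.optim import ZeroRedundancyOptimizer

# Sketch only: assumes torch.distributed.init_process_group() has already been
# called and `model` is an nn.Module living on the current rank's device.
optimizer = ZeroRedundancyOptimizer(
    model.parameters(),
    optimizer_class=torch.optim.Adam,  # the wrapped, per-shard optimizer
    lr=1e-3,
)
```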
@blefaudeux Thank you for your suggestion. I am able to get my code running following your idea. However, when I use AMP with ShardedGradScaler, I am hitting the following assertion: `assert len(found_infs) > 0, "No inf checks were recorded prior to update."` The assertion is raised only with AMP, not with full precision. Any insights on what I am missing?
np, and great to hear that the basics work!

This kind of makes sense. What happens with AMP is that `optimizer.step()` is sometimes a no-op: the algorithm tries to find the right scaling between fp32 and fp16 (which has a very short dynamic range, anything above ~65k is inf), so it (1) looks for infs prior to the step (so that inf gradients, if present, don't kill the model) and (2) if there are any, adapts the scale and skips this step. From a distance, it looks like one of the SAM optimizer steps did not go through the grad scaler (yes, this is yet another `.step()` overload). I think that the doc is not complete there, but from a distance I would say that the correct wrap order is probably the grad scaler around SAM(OSS(Adam)). It's entirely possible that it does not work (I only have a high-level understanding of how SAM works); in that case it would require some plumbing in some of these pieces, it would not be turnkey I believe.
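For reference, the standard (non-sharded) grad scaler loop that implements this inf-check / skip / rescale logic looks roughly like this; `model`, `optimizer`, `loss_fn` and `loader` are assumed:

```python
import torch
from torch.cuda.amp import autocast, GradScaler

scaler = GradScaler()

for inputs, targets in loader:
    optimizer.zero_grad()
    with autocast():
        loss = loss_fn(model(inputs), targets)
    scaler.scale(loss).backward()   # scale the loss so fp16 gradients don't underflow
    scaler.step(optimizer)          # unscales, checks for infs/NaNs, and may skip the step
    scaler.update()                 # adjusts the scale for the next iteration
```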
@blefaudeux Thank you for explaining the mechanisms of the grad scaler for mixed-precision training. As per your suggestion I looked at wrapping the grad scaler around SAM(OSS(Adam)), but the scaler is not an optimizer wrapper like OSS or SAM (its constructor does not take an optimizer). Could you kindly elaborate further on how the wrap order would look in practice?
Hey, sorry that it was not too clear indeed, I mixed up two things here: the actual steps would be wrapped this way, but not the constructors, you're right. The `.step()` call looks like this: it takes the optimizer as a parameter and will do what I tentatively explained above (also explained here). There are examples in the PyTorch doc, and the sharded version should behave in the same way. To explain why it's there: it just needs to consolidate the gradient check over all agents, so that if one of them has to skip a step(), they all do.

Update:
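A minimal sketch of the sharded variant, assuming fairscale's OSS and ShardedGradScaler and an already-initialized process group; the loop body is the same as with the vanilla GradScaler:

```python
import torch
from fairscale.optim.oss import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

# Sketch only: `model` is assumed to be built and distributed already.
optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-3)
scaler = ShardedGradScaler()  # consolidates the inf check across ranks before stepping

# ... inside the training loop:
# scaler.scale(loss).backward()
# scaler.step(optimizer)
# scaler.update()
```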
Thank you for your clarifications and for explaining the differences between GradScaler and its sharded version. I noticed that the step function of SAM only calls the closure and its ascent/descent steps:

```python
@torch.no_grad()
def step(self, closure=None):
    assert closure is not None, "Sharpness Aware Minimization requires closure, but it was not provided"
    closure = torch.enable_grad()(closure)  # the closure should do a full forward-backward pass
    self.ascent_step()
    closure()
    self.descent_step()
```

My training loop currently looks like this:

```python
for input, output in data:
    def closure():
        with autocast():
            loss = loss_function(output, model(input))
        scaler.scale(loss).backward()
        return loss

    with autocast():
        loss = loss_function(output, model(input))
    scaler.scale(loss).backward()
    scaler.step(optimizer, closure)  # <-- optimizer here is SAM(OSS(Adam))
    scaler.update()
    ...
```

Update: it turns out that GradScaler.step() explicitly rejects closures:

```python
if "closure" in kwargs:
    raise RuntimeError("Closure use is not currently supported if GradScaler is enabled.")
```
Hmm ok, thanks for the pointer. So this is the "it will require more plumbing" part that I mentioned: I don't think that it can work out of the box, one of the issues being that there's a scaling to be done before the backward pass, which a closure would mask. Something is probably doable through backward hooks (make sure that the scaling is done automatically by attaching an appropriate hook, for instance), but it's not something that I have time to investigate unfortunately :/

A low-hanging fruit that you may try is to train using bfloat16, if your hardware supports it. This does not require any scaler, and you get the same memory savings as float16. A tradeoff is that bfloat16 is not very precise, but depending on the workload the influence can be rather small (and it could well be that the round minimum that SAM finds is perfect for that).
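A minimal sketch of the bfloat16 route, assuming a GPU that supports it and the same `model` / `optimizer` / `loss_fn` / `loader` names as above; note that there is no scaler at all in this loop:

```python
import torch

for inputs, targets in loader:
    optimizer.zero_grad()
    # bfloat16 has the same exponent range as fp32, so no GradScaler is needed
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = loss_fn(model(inputs), targets)
    loss.backward()
    optimizer.step()
```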
Thank you for the directions. I tried bfloat16 on both a 2080 Ti and a V100 using the approach above, but it does not appear to be supported on those cards.

I also came across some attempts at using SAM with AMP, but they don't seem to handle the grad scaler.
Yes, you need Ampere cards for that, or TPUs. I had no idea what you had available.

Yep, that option is to skip the grad scaler altogether, but then you don't have overflow/underflow protection, and float16 is not a gentle type to work with.
@blefaudeux I have been looking at different variations of SAM these days and apparently none of them supports AMP at the moment. I guess I will put AMP on hold since it is not trivial.

Currently, I am more concerned with how different implementations handle gradient synchronization, and I am confused about which one is the correct usage when paired with OSS and ShardedDataParallel. The unofficial SAM implementation moves the parameters to the same device before computing the gradient norm. The official Adaptive SAM (ASAM) implementation computes the gradient norm on individual workers and does not do any gradient synchronization. The Gap-guided SAM (GSAM) implementation also computes the gradient norm on individual workers but has an explicit gradient synchronization at the end.

So far my experience is that ASAM converges quite slowly, while GSAM converges much more quickly but the gradients explode after a few epochs. I have been considering different factors, including the neighborhood size in SAM, the loss functions, and gradient synchronization, but so far I do not have a good answer. Unfortunately, I am not able to do a parameter sweep due to limited resources and the scale of my experiment. I would appreciate it if you could kindly shed light on the issue.
Hey @kenmbkr, sorry for the delay, I've had very little time these days (and I'm not at Facebook anymore). I would need to dive into this to say something halfway smart, but I think you're right in assuming that coupling this with AMP is probably for another day. Purely with SAM, I think that it should work with OSS as long as SAM is the outer wrap (but this is on a conceptual basis, I'm not sure that the API will help you there).
The short answer is that I don't think you would need a gradient sync step, unless the SAM method looks at the gradient values across the whole model (I need to re-read the paper). In that case you would still need no sync with DDP, but you would need one with ShardedDDP (sync across the agents <> sync across the model). If SAM only considers the optimizer's tensors one by one, this is not required with either option.

edit: ----> below is the real TL;DR
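To illustrate the "sync across the agents" case: a rough sketch of computing a gradient norm that is consistent across sharded workers, where each rank only holds the gradients of its own optimizer shard (names and setup assumed, this is not fairscale API):

```python
import torch
import torch.distributed as dist

def global_grad_norm(local_params) -> torch.Tensor:
    # Each rank sums the squared norms of the gradients it owns...
    local_sq = torch.zeros(1, device="cuda")
    for p in local_params:
        if p.grad is not None:
            local_sq += p.grad.detach().pow(2).sum()
    # ...then the partial sums are reduced over all ranks before taking the sqrt.
    dist.all_reduce(local_sq, op=dist.ReduceOp.SUM)
    return local_sq.sqrt()
```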
Thank you very much for getting back to me, and I really appreciate your time in looking at the issues. From my understanding of your words, whether individual implementations sync the gradients or not should not (theoretically) affect the results. Please correct me if I am wrong. The unofficial SAM implementation actually recommends: "To compute the SAM update when parallelizing across multiple accelerators, we divide each data batch evenly among the accelerators, independently compute the SAM gradient on each accelerator, and average the resulting sub-batch SAM gradients to obtain the final SAM update."
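A sketch of that recommendation in plain torch.distributed terms, assuming each rank has already computed its sub-batch SAM gradients into `p.grad` (the helper name is hypothetical):

```python
import torch.distributed as dist

def average_sam_gradients(params):
    # Average the per-accelerator SAM gradients so every rank applies the same update.
    world_size = dist.get_world_size()
    for p in params:
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad.div_(world_size)
```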
I have also come across this paper while experimenting with SAM, and I think it is a promising direction for getting a good balance between training efficiency and the loss landscape. However, their implementation is not publicly available and it is not trivial for me to implement it. I hope they will open-source their implementation someday.
Oh interesting. It makes sense if the SAM gradients are smaller than those of the model (and the SAM computation is commutative with the reduction); otherwise it may be an approximation, but a good-enough one. It could maybe explain (I've not read the papers, disclaimer) why the different xSAM variants suggest something different in that particular case? In any case, you could also give FSDP a try.
Thank you for your valuable insights into the gradient synchronization problem :) As for FSDP, it is actually the first thing I tried when I came across fairscale. |
I am trying to set up OSS with sharpness-aware minimization (SAM).
https://github.com/davda54/sam
https://github.com/SamsungLabs/ASAM
I followed this guide to set up OSS, but I have difficulty wrapping SAM with OSS:
https://fairscale.readthedocs.io/en/stable/tutorials/oss.html

Since both OSS and SAM are optimizer wrappers with their own step functions, I am not sure how to combine the two and call those functions. My initial clue is to wrap like this: OSS(SAM(Adam)).
Is there a minimal working example showing how to do that?
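For reference, this is roughly the un-wrapped setup I have working so far by following the tutorial (sketch only; the distributed init and `model` construction are omitted, and pairing with ShardedDataParallel is assumed):

```python
import torch
from fairscale.nn.data_parallel import ShardedDataParallel
from fairscale.optim.oss import OSS

# OSS takes the optimizer *class* plus its constructor arguments and shards its state.
optimizer = OSS(params=model.parameters(), optim=torch.optim.Adam, lr=1e-3)
# ShardedDataParallel reduces each gradient only to the rank that owns its optimizer shard.
model = ShardedDataParallel(model, optimizer)
```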