🐛 Bug

When all child modules that have trainable parameters are wrapped with FSDP, leaving the outermost FSDP wrapper with no parameters of its own, the backward pass breaks. I spent some time digging to figure out the issue until I came to this.
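For concreteness, here is a minimal sketch of the wrapping pattern in question; it is not taken from the test suite, the module names and sizes are made up, and it assumes `torch.distributed` has already been initialized, since FSDP needs a process group:

```python
import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Illustrative only; assumes torch.distributed.init_process_group() has been called.
inner = nn.Sequential(
    FSDP(nn.Linear(8, 8)),  # every child that owns trainable parameters...
    FSDP(nn.Linear(8, 8)),  # ...is wrapped individually,
)
model = FSDP(inner)  # ...so the outer FSDP wrapper holds no parameters of its own
```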
To Reproduce
To reproduce in the test file, you have to change one line and add a new test. First, wrap all layers instead of just some:
https://github.com/facebookresearch/fairscale/blob/master/tests/nn/data_parallel/test_fsdp.py#L691
(an illustrative sketch of that kind of change is shown after the test below). Then add a new test like this one:
```python
class TestComparisonToPyTorchDDP(DistributedTest):
    ...

    def test_nested_all_wrapped_model(self):
        # We use a model with a long CUDA delay right before the optimizer step.
        # This tests our streams logic, and that we don't start the FP32 -> FP16
        # transfer until after the optimization step completes.
        config = {"mixed_precision": True}
        model_fn = NestedWrappedModule
        test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
        spawn_and_init(test_fn)
```
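I'm not reproducing the exact line from test_fsdp.py here; purely as an illustration of the kind of change meant by "wrap all layers instead of just some", the idea is to apply the test's FSDP-wrapping helper (called `wrap` below as a placeholder, with made-up layer sizes) to every layer that has trainable parameters:

```python
import torch.nn as nn

# Placeholder sketch -- not the actual contents of test_fsdp.py#L691.
# `wrap` stands in for whatever helper the test uses to apply FSDP to a layer.
def make_all_wrapped(wrap):
    return nn.Sequential(
        wrap(nn.Linear(8, 4)),
        wrap(nn.Linear(4, 16)),
        wrap(nn.Linear(16, 4)),
        wrap(nn.Linear(4, 8)),
    )
```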
Expected behavior
I'm unsure whether this is a limitation of FSDP or whether it should just work. It isn't the highest priority, since in most cases I don't think you would wrap every single trainable module in FSDP.
As a follow-up, we may want to explore the idea Myle mentioned in the review, where we add an extra param to the outer FSDP so that it always gets a callback on the backward pass. I see three options below; please let me know your thoughts:
1. We could add a single float param with value 0.0, say `p`, and for the outermost FSDP, do `output += p` before the output is returned from `forward()`. I think (but haven't verified it) that this would hook `p` into the autograd graph and ensure the outer FSDP gets a chance in the backward callback to queue a final callback. We would simply ignore the grad on `p`, so it always keeps the value 0.0. (A rough sketch of this follows the list.)
2. Similar to 1, we could keep a single param with value 1.0 and do `output *= p` right before returning the output. Both 1 and 2 add some FLOPs.
3. Since we want to go toward auto-wrap, we don't need to do anything here anymore and can assume that auto-wrap will not leave the outermost FSDP empty in the future.
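To make option 1 concrete, here is a rough, untested sketch of the mechanism, written as a standalone wrapper module rather than as a patch to FSDP's own `forward()`; it assumes the wrapped module returns a single tensor:

```python
import torch
import torch.nn as nn

class OuterWithDummyParam(nn.Module):
    """Sketch of option 1: a zero-valued dummy param hooks the outer module
    into the autograd graph so it always sees the backward pass."""

    def __init__(self, wrapped: nn.Module):
        super().__init__()
        self.wrapped = wrapped
        # Single float param with value 0.0; its grad is ignored, so it stays 0.0.
        self.dummy_p = nn.Parameter(torch.zeros(1))

    def forward(self, *args, **kwargs):
        output = self.wrapped(*args, **kwargs)
        # Adding 0.0 leaves the value unchanged but makes this module's parameter
        # part of the autograd graph, so its backward hooks will fire.
        return output + self.dummy_p.to(output.dtype)
```

The grad computed for the dummy param would simply never be applied, so its value stays 0.0 as described in option 1; option 2 would be the same shape with the param initialized to 1.0 and a multiply instead of an add.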