
[FSDP] breaks when all modules within FSDP root are wrapped #437

Closed
SeanNaren opened this issue Feb 26, 2021 · 3 comments · Fixed by #441
@SeanNaren

🐛 Bug

When all child modules that have trainable parameters are wrapped with FSDP, leaving the FSDP root with no parameters of its own, training breaks. I spent some time digging before narrowing it down to the reproduction below.
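For illustration, here is a hypothetical minimal sketch of the shape of the failing configuration (not code from the issue; it assumes torch.distributed is already initialized): every parameter-holding child is itself an FSDP instance, so the root FSDP wrapper is left with no direct parameters of its own.

import torch.nn as nn
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Every leaf module with trainable parameters is wrapped individually...
inner = nn.Sequential(FSDP(nn.Linear(8, 4)), FSDP(nn.Linear(4, 8)))
# ...so the outermost FSDP instance holds no unwrapped parameters of its own.
model = FSDP(inner)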

To Reproduce

To reproduce, change one line in the test file linked below and add a new test.

Wrap all layers instead of just some:

https://github.com/facebookresearch/fairscale/blob/master/tests/nn/data_parallel/test_fsdp.py#L691

class NestedWrappedModule(nn.Module):
    def __init__(self, group, wrapper_config):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        self.wrapper_config = wrapper_config

        def _maybe_wrap(layer):
            if wrapper_config is not None:
                return FullyShardedDataParallel(layer, group, **wrapper_config)
            return layer

        torch.manual_seed(0)  # keep everything deterministic
        # Wrap every parameter-holding layer, so the root FSDP instance is
        # left with no direct parameters of its own.
        self.module = nn.Sequential(
            _maybe_wrap(nn.Linear(8, 4)),
            _maybe_wrap(nn.Linear(4, 16)),
            _maybe_wrap(nn.Linear(16, 4)),
            _maybe_wrap(nn.Linear(4, 8)),
        )

I added a new test like the one below:

class TestComparisonToPyTorchDDP(DistributedTest):
    ...
    def test_nested_all_wrapped_model(self):
        # Verify that a model whose parameter-holding submodules are all
        # wrapped in FSDP produces the same results as the PyTorch DDP baseline.
        config = {"mixed_precision": True}
        model_fn = NestedWrappedModule
        test_fn = functools.partial(self._test_identical_outputs, model_fn, config)
        spawn_and_init(test_fn)
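Assuming a standard checkout of the fairscale repo (and an environment where the distributed test helpers can spawn workers), the new test can then be selected with, for example:

pytest tests/nn/data_parallel/test_fsdp.py -k test_nested_all_wrapped_model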

Expected behavior

I'm unsure whether this is a limitation of FSDP or whether it should just work. It isn't the highest priority, since in most cases you wouldn't wrap every single trainable module in FSDP.

@min-xu-ai
Contributor

Found the bug, will have a fix for this soon.

@SeanNaren
Author

Ridiculous turnaround time, you guys are machines! Thanks so much @min-xu-ai for fixing this :)

@min-xu-ai
Contributor

As a follow-up, we may want to explore the idea Myle mentioned in the review: add an extra parameter to the outer FSDP so that it always gets a callback in the backward pass. I see three options below; please let me know your thoughts:

  1. We can add a single float param with value 0.0, say p, and for the outermost FSDP, do output += p before the output is returned from forward() (a rough sketch follows below). I think (but haven't verified) that this would hook p into the autograd graph and ensure the outer FSDP gets a chance in the backward callback to queue a final callback. We would just ignore the grad on p, so it always keeps the value 0.0.
  2. Similar to 1, we can keep a single param with value 1.0 and do output *= p right before returning the output. Both 1 and 2 add some FLOPs.
  3. Since we want to move toward auto-wrap, we may not need to do anything here and can assume that auto-wrap will not leave the outermost FSDP empty in the future.

Any other options I missed?
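For concreteness, here is a rough, hypothetical sketch of what option 1 might look like (the wrapper class and the _dummy name are made up for illustration and are not the actual change in #441); option 2 would have the same shape, with a parameter initialized to 1.0 and a multiply instead of an add:

import torch
import torch.nn as nn


class OuterWithDummyParam(nn.Module):
    # Hypothetical illustration of option 1: a zero-valued dummy parameter on
    # the outermost wrapper means the forward output always depends on one of
    # this module's own parameters, so a backward hook is guaranteed to fire
    # here even when every real parameter lives in an inner wrapped instance.
    def __init__(self, wrapped: nn.Module):
        super().__init__()
        self.wrapped = wrapped
        self._dummy = nn.Parameter(torch.zeros(1))  # single float param, value 0.0

    def forward(self, *args, **kwargs):
        output = self.wrapped(*args, **kwargs)
        # Adding 0.0 leaves the value unchanged but ties `output` to
        # `self._dummy` in the autograd graph; the grad on `_dummy` is ignored.
        return output + self._dummy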
