
Allow sharded grad scaler to cpu offload with FSDP #831

Merged: 33 commits into main on Nov 15, 2021

Conversation


@anupambhatnagar anupambhatnagar commented Oct 26, 2021

What does this PR do?

Fixes Issue 421 and Issue 834

The ShardedGradScaler class implements the _amp_update_scale_cpu_ and _foreach_check_finite_and_unscale_cpu_ functions, which are required to enable loss scaling when FSDP is used along with cpu_offload.
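
For context, here is a minimal sketch of how the scaler is meant to be used with FSDP and CPU offload. The model, sizes, and training loop are illustrative and not taken from this PR; it assumes a CUDA device and an already-initialized torch.distributed process group.

import torch
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.optim.grad_scaler import ShardedGradScaler

# Wrap a toy module with FSDP, keeping parameters on CPU between steps.
model = FSDP(torch.nn.Linear(16, 4), mixed_precision=True, move_params_to_cpu=True)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = ShardedGradScaler()

for _ in range(10):
    optimizer.zero_grad()
    out = model(torch.randn(8, 16).cuda())
    loss = out.float().pow(2).mean()   # toy loss, cast back to fp32
    scaler.scale(loss).backward()      # scale the loss so fp16 grads do not underflow
    scaler.step(optimizer)             # unscale the (CPU-resident) grads, step only if finite
    scaler.update()                    # grow or back off the scale based on inf/nan history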

Additional Changes:

  • Several code paths now require PyTorch version >= 1.8.0.
  • Removed test_fsdp_grad_scaler.py since the test has been reimplemented in test_fsdp.py.
  • Increased the number of epochs in test_fsdp_state_dict.py. With fewer than 6 epochs, test_load_local_state_dict__True_True fails: whenever scale * loss is inf/nan the update is skipped, and roughly 6 epochs are needed before the scaled loss is finite and the update actually happens (see the sketch after this list).
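
For readers unfamiliar with loss-scaling semantics, here is a simplified sketch of the standard update rule the test has to wait for (illustrative Python, not the PR's code):

def update_scale(scale, found_inf, growth_tracker,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    # Standard GradScaler-style rule: on inf/nan the step is skipped and the
    # scale is backed off; after `growth_interval` consecutive finite steps
    # the scale grows.
    if found_inf:
        return scale * backoff_factor, 0      # skip the update, shrink the scale
    if growth_tracker + 1 == growth_interval:
        return scale * growth_factor, 0       # enough finite steps: grow the scale
    return scale, growth_tracker + 1          # otherwise keep the scale unchanged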

Before submitting

  • Did you have fun?
    • Make sure you had fun coding 🙃
  • Did you read the contributor guideline?
  • Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
    • N/A
  • Did you make sure to update the docs?
    • N/A
  • Did you write any new necessary tests?
    • N/A
  • Did you update the changelog? (if needed)
    • N/A

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Oct 26, 2021
@anupambhatnagar anupambhatnagar requested a review from anj-s October 26, 2021 22:58
@anupambhatnagar (Author) commented:

Added a test for ShardedGradScaler class with no cpu offload.

@anupambhatnagar anupambhatnagar changed the title [DRAFT] allow sharded grad scaler to cpu offload with FSDP Allow sharded grad scaler to cpu offload with FSDP Oct 29, 2021
@anupambhatnagar (Author) commented:

@anj-s some of the tests are failing because the isort checks fail on many files that I did not touch. I will apply isort to the entire repo in a separate PR so that this one stays easy to review.

@anupambhatnagar anupambhatnagar self-assigned this Nov 2, 2021
# Diff context from the ShardedGradScaler constructor; the review thread below
# refers to the amp_definitely_not_available() check on the last line.
super().__init__(
    init_scale=init_scale,
    growth_factor=growth_factor,
    backoff_factor=backoff_factor,
    growth_interval=growth_interval,
    enabled=enabled,
)
self.display_warning = True
self.group = process_group
if enabled and amp_definitely_not_available():  # type: ignore
Contributor:
How does this work for the CPU-only version of GradScaler? Even though this is a CPU-only version, we are assuming that we have GPUs to run the model. Just wondering if that is a fair assumption to make.

Author (@anupambhatnagar):

I don't have a clear answer to this. Would the cpu-tests be enough to ensure that this works?
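
For context, the core idea in this thread is a device-based dispatch inside the scaler's unscale step. The function names _amp_update_scale_cpu_ and _foreach_check_finite_and_unscale_cpu_ come from the PR description; the body below is only an illustrative sketch, not the PR's code.

import torch

def unscale_grads_(grads, inv_scale, found_inf):
    # All tensors are assumed to live on the same device within each branch.
    if all(g.device.type == "cpu" for g in grads):
        # CPU path (what this PR adds conceptually): emulate the fused CUDA op.
        for g in grads:
            if not torch.isfinite(g).all():
                found_inf.fill_(1.0)
            g.mul_(inv_scale)
    else:
        # GPU path: the fused op used by torch.cuda.amp.GradScaler.
        torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)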

@min-xu-ai (Contributor) left a review comment:

A bunch of tests are now skipped with 1.8 as well. Since 1.8 is LTS, do we want those to still be tested?

tests/nn/data_parallel/test_fsdp_regnet.py (outdated review thread, resolved)
@@ -305,6 +314,24 @@ def test_cpu_offload_and_cpu_grads(self):
        )
        spawn_and_init(test_fn)

    def test_no_cpu_offload_with_sharded_grad_scaler(self):
        # We don't test the False condition because that requires the optimizer to internally do
Contributor:

nit: I know you did not add it, but I wonder if we need this comment duplicated in all three tests. Also, if it could mention what the False property refers to, it would be a lot clearer.

Author (@anupambhatnagar):

updated comments to reflect what we are testing now.

        )
        spawn_and_init(test_fn)

    def test_no_cpu_offload_with_sharded_grad_scaler_and_mixed_precision(self):
Contributor:

Can we use parametrization to combine these 4 tests? You are testing mixed_precision=True/False and cpu_offload=True/False. We can skip move_grads_to_cpu since that is set by default by the cpu_offload param. Also note that we have deprecated the cpu_offload param and you should use move_params_to_cpu instead.
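
A minimal sketch of the kind of parametrization being suggested, reusing the rename_test, spawn_and_init, and _test_identical_outputs helpers that appear elsewhere in this conversation (the class name, method name, and helper signature are illustrative):

import functools
import itertools

from parameterized import parameterized


class TestShardedGradScaler(DistributedTest):  # hypothetical class name
    @parameterized.expand(
        list(itertools.product([True, False], [True, False])),
        name_func=rename_test,
    )
    def test_scaler_cpu_offload(self, mixed_precision, move_params_to_cpu):
        config = {
            "mixed_precision": mixed_precision,
            "move_params_to_cpu": move_params_to_cpu,
        }
        # The real helper signature in test_fsdp.py may differ.
        test_fn = functools.partial(self._test_identical_outputs, config=config)
        spawn_and_init(test_fn)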

Author (@anupambhatnagar):

done.

Author (@anupambhatnagar):

Upon changing the config dictionary key from "cpu_offload" to "move_params_to_cpu", the test_cpu_offload_and_cpu_grads test breaks with the following error message:

E       Traceback (most recent call last):
E         File "/private/home/anupamb/fairscale/tests/nn/data_parallel/test_fsdp.py", line 133, in _test_identical_outputs
E           torch.testing.assert_allclose(ref_loss, shard_loss)
E         File "/private/home/anupamb/miniconda3/lib/python3.9/site-packages/torch/testing/__init__.py", line 222, in assert_allclose
E           result, debug_msg = _compare_tensors_internal(actual, expected,
E         File "/private/home/anupamb/miniconda3/lib/python3.9/site-packages/torch/testing/__init__.py", line 130, in _compare_tensors_internal
E           if torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan):
E       RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
E
E       During handling of the above exception, another exception occurred:
E
E       Traceback (most recent call last):
E         File "/private/home/anupamb/miniconda3/lib/python3.9/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
E           fn(i, *args)
E         File "/private/home/anupamb/fairscale/tests/nn/data_parallel/test_fsdp.py", line 814, in init_and_run
E           fn(rank, group, *args)
E         File "/private/home/anupamb/fairscale/tests/nn/data_parallel/test_fsdp.py", line 136, in _test_identical_outputs
E           raise Exception(f"FullyShardedDataParallel didn't match PyTorch DDP using config: {config}\n\n {e}")
E       Exception: FullyShardedDataParallel didn't match PyTorch DDP using config: {'mixed_precision': True, 'move_params_to_cpu': True, 'compute_dtype': torch.float32}
E
E        Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

@@ -268,6 +276,7 @@ def rename_test(testcase_func, param_num, param):
return "%s_%s" % (testcase_func.__name__, parameterized.to_safe_name(str(param.args)),)


@pytest.mark.skipif(torch_version() < (1, 9, 0), reason="pytorch version >= 1.9.0 required")
Contributor:

Should we add a class-level pytest skip annotation? Or maybe I missed that another test class is being run that does not call the _train_for_several_steps function.
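
For reference, a class-level skip in pytest looks like the following (a generic pattern, not code from this PR; torch_version() is the helper already used in the diff hunk above and is assumed importable from the repo's utilities):

import pytest


@pytest.mark.skipif(torch_version() < (1, 9, 0), reason="pytorch version >= 1.9.0 required")
class TestShardedGradScalerVariants:
    # A skipif marker placed on the class applies to every test method in it,
    # so it does not need to be repeated on each function.
    def test_with_cpu_offload(self):
        ...

    def test_without_cpu_offload(self):
        ...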

Author (@anupambhatnagar):

I don't follow this comment. Can you please elaborate?

@anj-s (Contributor) commented Nov 5, 2021

Thank you for the PR! We have a lot of skip messages added to tests. I am worried that it might reduce our test coverage. Is there any way we can import selected functions? (I know we discussed this offline as well.)

@anj-s (Contributor) commented Nov 5, 2021

We should also update the CHANGELOG.md file, which contains the release notes.

@facebook-github-bot commented:

Hi @anupambhatnagar!

Thank you for your pull request.

We require contributors to sign our Contributor License Agreement, and yours needs attention.

You currently have a record in our system, but the CLA is no longer valid, and will need to be resubmitted.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

@anupambhatnagar (Author) commented:

> A bunch of tests are now skipped with 1.8 as well. Since 1.8 is LTS, do we want those to still be tested?

Changed the minimum version requirement to 1.8 and made sure all tests pass with 1.8.

@anupambhatnagar (Author) commented:

> Thank you for the PR! We have a lot of skip messages added to tests. I am worried that it might reduce our test coverage. Is there any way we can import selected functions? (I know we discussed this offline as well.)

I have made code changes allowing us to support PyTorch 1.8. The skip messages are still there, but they do not reduce the coverage. It is impossible to support 1.7 because it does not have the _amp_update_scale_ function.
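
The 1.8 gating relies on a version check such as the torch_version() helper seen in the earlier test diff; a simplified sketch of such a helper is shown below (the repo's actual implementation may differ):

import torch


def torch_version_tuple():
    # "1.8.1+cu111" -> (1, 8, 1); ignores any local build suffix after "+".
    parts = torch.__version__.split("+")[0].split(".")[:3]
    return tuple(int(p) for p in parts)


# Example gate: the CPU-offload scaler path needs torch internals added in 1.8.
if torch_version_tuple() < (1, 8, 0):
    raise RuntimeError("ShardedGradScaler with CPU offload requires PyTorch >= 1.8.0")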

@anj-s (Contributor) commented Nov 15, 2021

nit: Would it be possible to add comments highlighting where the CPU-scaler-specific code has been added? I know we discussed this offline, but it would be good to have this in place to help future development/debugging.

@anupambhatnagar (Author) commented:

> nit: Would it be possible to add comments highlighting where the CPU-scaler-specific code has been added? I know we discussed this offline, but it would be good to have this in place to help future development/debugging.

I have added comments above both functions, which are the key pieces of this implementation. See lines 147-149 and 320-322.

@anupambhatnagar anupambhatnagar merged commit ba5785f into main Nov 15, 2021
@min-xu-ai min-xu-ai deleted the cpu_gradscaler branch September 23, 2022 21:33
Labels: CLA Signed
Development

Successfully merging this pull request may close these issues.

[FSDP] Support with AMP Grad scaler
4 participants