
[feat] batch broadcast requests into a configurable buffer #43

Closed
wants to merge 51 commits into master from oss_batch_broadcast

Conversation


@blefaudeux (Contributor) commented on Aug 14, 2020:

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Improves on #42 by batching the small broadcasts into a bigger one. There's a tradeoff between doing more copies and incurring less latency on the communication side, so both the broadcast buffer size and the size above which a broadcast is sent directly are configurable.
This is a long-lived PR, so some of the listed commits are unrelated and come from merges with upstream master over time.
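
A minimal sketch of the batching strategy (hedged: `buffered_broadcast`, `buffer`, and `size_threshold` are illustrative names, not this PR's actual API; the real logic lives in `fairscale/optim/oss.py` and `oss_ddp.py`):

```python
import torch.distributed as dist


def buffered_broadcast(tensors, src, buffer, size_threshold):
    """Broadcast many small tensors through a single staging buffer.

    Tensors above size_threshold are broadcast directly; the rest are
    packed into `buffer`, broadcast in one call, then copied back out.
    Assumes buffer.numel() >= size_threshold.
    """
    offset, packed = 0, []

    def flush():
        nonlocal offset, packed
        if offset > 0:
            dist.broadcast(buffer[:offset], src=src)  # one call for the whole batch
            for view, flat in packed:
                flat.copy_(view)  # unpack into the original tensors
            offset, packed = 0, []

    for t in tensors:
        flat = t.view(-1)
        if flat.numel() > size_threshold:
            dist.broadcast(t, src=src)  # large tensor: not worth the extra copies
            continue
        if offset + flat.numel() > buffer.numel():
            flush()  # staging buffer is full
        view = buffer[offset : offset + flat.numel()]
        view.copy_(flat)  # pack; on non-src ranks this is overwritten by the broadcast
        packed.append((view, flat))
        offset += flat.numel()
    flush()
```

Raising `size_threshold` trades extra pack/unpack copies for fewer collective calls, which is exactly the tradeoff the two knobs expose.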

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


@blefaudeux requested a review from min-xu-ai on August 14, 2020 23:09
@facebook-github-bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on Aug 14, 2020
@codecov bot commented on Aug 14, 2020:

Codecov Report

Merging #43 into master will increase coverage by 0.07%.
The diff coverage is 92.85%.


@@            Coverage Diff             @@
##           master      #43      +/-   ##
==========================================
+ Coverage   94.18%   94.26%   +0.07%     
==========================================
  Files          35       35              
  Lines        2065     2092      +27     
==========================================
+ Hits         1945     1972      +27     
  Misses        120      120              
Flag Coverage Δ
#Python 94.26% <92.85%> (+0.07%) ⬆️

Flags with carried forward coverage won't be shown.

Impacted Files Coverage Δ
fairscale/optim/oss.py 97.11% <86.36%> (-2.89%) ⬇️
fairscale/nn/data_parallel/oss_ddp.py 84.29% <100.00%> (+1.75%) ⬆️
fairscale/optim/utils.py 100.00% <100.00%> (ø)

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Last update c2d6f4b...56974ed.

@blefaudeux marked this pull request as draft on August 14, 2020 23:25
@blefaudeux marked this pull request as ready for review on August 15, 2020 00:09
@blefaudeux marked this pull request as draft on August 15, 2020 00:41
@blefaudeux marked this pull request as ready for review on August 15, 2020 02:15
@blefaudeux requested a review from msbaines on August 17, 2020 16:19
@blefaudeux marked this pull request as draft on August 17, 2020 22:38
@blefaudeux marked this pull request as ready for review on August 17, 2020 23:04
@blefaudeux self-assigned this on Aug 18, 2020
(Three outdated review threads on fairscale/optim/utils.py and fairscale/optim/oss.py were resolved.)
@blefaudeux marked this pull request as draft on September 3, 2020 21:19
@@ -119,7 +119,7 @@ def closure():
     print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

     if use_oss and check_regression and dist.get_rank() == 0:
-        assert (mean - 3.0 * std) < reference_speed, "Speed regression detected"
+        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
@blefaudeux (Contributor, Author) commented on Sep 3, 2020:
This was actually wrong before: we want speed to be equal to or better than the baseline, not lower.

-parser.add_argument("--check_regression", action="store", default=True, type=bool)
-parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
+parser.add_argument("--check_regression", action="store_true", default=False)
+parser.add_argument("--reference_speed", action="store", default=33, type=float)
@blefaudeux (Contributor, Author) commented:
A previous PR switched the optimizer from SGD to RMSProp, which is a tad slower. I checked that this 33 fps value is not degraded by this PR.
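
The change to the first flag also fixes a standard argparse pitfall (a property of Python's argparse itself, not of this PR): with `action="store"` and `type=bool`, the parser calls `bool()` on the raw string, and any non-empty string is truthy, so the old flag could never actually be turned off from the command line.

```python
import argparse

old = argparse.ArgumentParser()
old.add_argument("--check_regression", action="store", default=True, type=bool)
# bool("False") is True: every non-empty string parses as True.
assert old.parse_args(["--check_regression", "False"]).check_regression is True

new = argparse.ArgumentParser()
new.add_argument("--check_regression", action="store_true", default=False)
assert new.parse_args([]).check_regression is False
assert new.parse_args(["--check_regression"]).check_regression is True
```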

@blefaudeux marked this pull request as ready for review on September 3, 2020 23:01
@msbaines (Contributor) left a comment:
It is not safe to assume all params in a param_group are on the same device.
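
A hedged sketch of one way to honour that (illustrative only, not necessarily the fix that landed): bucket parameters by device and keep one staging buffer per device.

```python
from collections import defaultdict

import torch


def bucket_params_by_device(param_groups):
    """A single param_group may mix devices, so group tensors by p.device first."""
    buckets = defaultdict(list)
    for group in param_groups:
        for p in group["params"]:
            buckets[p.device].append(p)
    return buckets


def make_buffers(buckets, buffer_size, dtype=torch.float32):
    """One pre-allocated broadcast buffer per device instead of a single one."""
    return {
        device: torch.empty(buffer_size, dtype=dtype, device=device)
        for device in buckets
    }
```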

@@ -44,12 +45,25 @@ class OSS(Optimizer):
         optimizer to shard (default: SGD)
     group (group):
         torch.distributed group (default: group.WORLD)
+    buffer_size (int, optional): number of elements to buffer before
@msbaines (Contributor) commented:
What does optional mean in this context? The parameter does not look like an optional.

@blefaudeux (Contributor, Author) replied:
I meant to write that people are free to pass it in or not; there's a default provided.
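
In other words, "optional" here follows the common docstring convention of "has a default and may be omitted", which is distinct from `typing.Optional` ("may be None"). A small illustration with hypothetical names:

```python
from typing import Optional

import torch


def make_staging_buffer(buffer_size: int = 2 ** 24) -> torch.Tensor:
    """buffer_size (int, optional): a default exists, so callers may omit it."""
    return torch.empty(buffer_size)


# typing.Optional is the other notion: the value itself may be None.
staging: Optional[torch.Tensor] = None
```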

@@ -67,6 +81,12 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any =
         # Current device is set by the parameters allocated to this rank
         self._device = split_param_groups[self.rank][0]["params"][0].device

+        # Broadcast buffer settings
+        self._buffer: Optional[torch.Tensor] = None
@msbaines (Contributor) commented:
Parameters cannot change after init, so you could pre-process the params and pre-create the batch buffers.
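
A sketch of that suggestion, assuming (as noted) that the parameter set is frozen after `__init__`; the function and argument names are hypothetical:

```python
import torch


def precreate_broadcast_state(params, threshold, buffer_size, device, dtype):
    """Partition params once at construction instead of lazily inside step():
    small tensors go through the staging buffer, large ones are broadcast directly."""
    small = [p for p in params if p.numel() <= threshold]
    large = [p for p in params if p.numel() > threshold]
    needed = sum(p.numel() for p in small)
    buffer = torch.empty(min(needed, buffer_size), dtype=dtype, device=device)
    return small, large, buffer
```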

(Another outdated review thread on fairscale/optim/oss.py was resolved.)
@msbaines requested a review from shruti-bh on September 3, 2020 23:43
@blefaudeux marked this pull request as draft on September 9, 2020 20:54
@blefaudeux (Contributor, Author) commented:
Now deduplicating between oss_ddp and oss, @min-xu-ai; slightly tighter. I think it would be nicer to adapt the buffering parameters to the loaded model, which is something we know at construction time (very large model and few shards -> adapt the buffer size).

@blefaudeux (Contributor, Author) commented:
Something else that I'm planning to do is to pre-sort the parameters by size at construction time, so that the step and oss_ddp logic is simpler, as suggested by Min.
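
A sketch of that plan (an assumed shape, not necessarily what landed): sort once at construction so the per-step logic reduces to a single split at the direct-broadcast threshold.

```python
def split_sorted_params(params, threshold):
    """Sort parameters by element count once; everything up to the first
    tensor above `threshold` is buffered, the rest is broadcast directly."""
    params_sorted = sorted(params, key=lambda p: p.numel())
    split = next(
        (i for i, p in enumerate(params_sorted) if p.numel() > threshold),
        len(params_sorted),
    )
    return params_sorted[:split], params_sorted[split:]
```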

@blefaudeux closed this on Oct 1, 2020
@blefaudeux deleted the oss_batch_broadcast branch on October 1, 2020 18:49
myleott added a commit that referenced this pull request Feb 22, 2021
Labels: CLA Signed (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed)
Projects: none yet
Development: successfully merging this pull request may close the issue "Faster OSS"
3 participants