[feat] batch broadcast requests into a configurable buffer #43
Conversation
Codecov Report
```diff
@@            Coverage Diff             @@
##           master      #43      +/-   ##
==========================================
+ Coverage   94.18%   94.26%   +0.07%
==========================================
  Files          35       35
  Lines        2065     2092      +27
==========================================
+ Hits         1945     1972      +27
  Misses        120      120
```
benchmarks/oss.py (outdated)
```diff
@@ -119,7 +119,7 @@ def closure():
     print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

     if use_oss and check_regression and dist.get_rank() == 0:
-        assert (mean - 3.0 * std) < reference_speed, "Speed regression detected"
+        assert (mean + 3.0 * std) > reference_speed, "Speed regression detected"
```
This was actually wrong before: we want speed to be equal to or better than the baseline, not lower. With the corrected direction, the assert only fires when even mean + 3 * std falls short of the reference speed, i.e. when the slowdown is beyond measurement noise.
benchmarks/oss.py (outdated)
```diff
-parser.add_argument("--check_regression", action="store", default=True, type=bool)
-parser.add_argument("--reference_speed", action="store", default=39.82, type=float)
+parser.add_argument("--check_regression", action="store_true", default=False)
+parser.add_argument("--reference_speed", action="store", default=33, type=float)
```
A previous PR switched the optimizer from SGD to RMSprop, which is a tad slower. I checked that this 33 fps value is not degraded by this PR.
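For context, the `store_true` change in this diff also fixes a classic argparse pitfall: with `type=bool`, argparse calls `bool()` on the raw string, and any non-empty string is truthy. A minimal illustration (`--broken_flag` is a made-up name for demonstration):

```python
import argparse

parser = argparse.ArgumentParser()
# Pitfall: bool("False") is True, so "--broken_flag False" still yields True.
parser.add_argument("--broken_flag", action="store", default=True, type=bool)
# Fix: a store_true flag is False unless it is passed on the command line.
parser.add_argument("--check_regression", action="store_true", default=False)

args = parser.parse_args(["--broken_flag", "False"])
assert args.broken_flag is True      # surprising, but that's how bool() works
assert args.check_regression is False
```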
It is not safe to assume that all params in a param_group are on the same device.
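A minimal sketch of what honoring that could look like, bucketing each parameter by its own device instead of reading the device off the first one (`group_params_by_device` is an illustrative helper, not the PR's code):

```python
from collections import defaultdict
from typing import Dict, List

import torch

def group_params_by_device(param_groups) -> Dict[torch.device, List[torch.Tensor]]:
    """Bucket parameters by their actual device; a single param_group
    may legitimately mix CPU and CUDA tensors."""
    per_device: Dict[torch.device, List[torch.Tensor]] = defaultdict(list)
    for group in param_groups:
        for param in group["params"]:
            per_device[param.device].append(param)
    return per_device
```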
fairscale/optim/oss.py
```diff
@@ -44,12 +45,25 @@ class OSS(Optimizer):
         optimizer to shard (default: SGD)
     group (group):
         torch.distributed group (default: group.WORLD)
+    buffer_size (int, optional): number of elements to buffer before
```
What does optional mean in this context? The parameter does not look optional.
I meant to write that people are free to pass it in or not; there's a default provided.
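In other words, `optional` follows the usual docstring convention: the argument has a default and may be omitted. A hypothetical call site (assuming, as the OSS signature quoted below suggests, that extra kwargs such as `lr` are forwarded to the wrapped optimizer; the explicit buffer value is illustrative, not the library default):

```python
import torch
from fairscale.optim.oss import OSS

model = torch.nn.Linear(128, 128)

# buffer_size can be omitted; OSS supplies a default.
opt_default = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1)

# ...or passed explicitly to tune the broadcast batching.
opt_tuned = OSS(model.parameters(), optim=torch.optim.SGD, lr=0.1, buffer_size=2**20)
```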
fairscale/optim/oss.py (outdated)
```diff
@@ -67,6 +81,12 @@ def __init__(self, params: _params_t, optim: Type[Optimizer] = SGD, group: Any =
     # Current device is set by the parameters allocated to this rank
     self._device = split_param_groups[self.rank][0]["params"][0].device

+    # Broadcast buffer settings
+    self._buffer: Optional[torch.Tensor] = None
```
Parameters cannot change after init, so you could pre-process the params and pre-create the batch buffers.
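A minimal sketch of that suggestion (the helper name and sizing policy are illustrative, not the PR's eventual code):

```python
import torch

def precreate_broadcast_buffers(split_param_groups, buffer_size, device):
    """One flat broadcast buffer per rank, sized at construction time.

    Since the parameter set is fixed after __init__, each rank's buffer
    can be sized up front (capped at buffer_size) so that step() never
    has to allocate.
    """
    buffers = []
    for rank_groups in split_param_groups:
        # Total elements of this rank's "small" params, i.e. those that
        # will go through the batched path rather than a direct broadcast.
        small_total = sum(
            p.numel()
            for group in rank_groups
            for p in group["params"]
            if p.numel() < buffer_size
        )
        buffers.append(torch.empty(min(small_total, buffer_size), device=device))
    return buffers
```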
Now deduplicating between oss_ddp and oss @min-xu-ai, slightly tighter. I think it would be nicer to adapt the buffering parameters to the loaded model, since that's something we know at construction time (very large model and few shards -> adapt the buffer size).
Something else I'm planning to do is to pre-sort the parameters by size at construction time, so that the step and oss_ddp logic is simpler, as suggested by Min; a sketch of the idea follows.
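A sketch of that pre-sort (illustrative helper, not the merged code): sorting once at construction lets step() fill the broadcast buffer greedily and fall back to direct broadcasts only for the oversized tail.

```python
def presort_params_by_size(param_groups):
    """Sort each group's params smallest-first, in place, once at init."""
    for group in param_groups:
        group["params"].sort(key=lambda p: p.numel())
    return param_groups
```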
What does this PR do?
Improves on #42 by batching the small broadcasts into a bigger one. There's a tradeoff between doing more copies and incurring less latency on the communication side, so both the broadcast buffer and the size above which a broadcast goes out directly are configurable.
This is a long-lived PR, so some listed commits are unrelated and come from merges with upstream master over time.
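The batching idea in miniature, as a hedged sketch: `SMALL_LIMIT`, `buffer`, and `batched_broadcast` are stand-ins for the PR's configurable settings and actual API, which also handles async requests and per-rank buffers.

```python
import torch
import torch.distributed as dist

SMALL_LIMIT = 2**17          # illustrative: above this, broadcast directly
buffer = torch.empty(2**20)  # illustrative: pre-allocated flat buffer

def batched_broadcast(tensors, src):
    """Pack small tensors into one flat broadcast; send large ones directly.

    Assumes contiguous tensors on the buffer's device.
    """
    offset, packed = 0, []
    for t in tensors:
        n = t.numel()
        if n > SMALL_LIMIT or offset + n > buffer.numel():
            dist.broadcast(t, src=src)  # big (or overflow): one direct call
        else:
            buffer[offset : offset + n].copy_(t.view(-1))  # only matters on src
            packed.append((t, offset, n))
            offset += n
    if packed:
        dist.broadcast(buffer[:offset], src=src)  # one call for all small tensors
        for t, off, n in packed:
            t.view(-1).copy_(buffer[off : off + n])  # unpack on every rank
```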
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Make sure you had fun coding 🙃