[fix] OSS pytorch-compliant state dict #61
Conversation
What if you un-sharded in state_dict and then re-sharded in load_state_dict? That way, we work exactly like PyTorch is expecting.
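A minimal sketch of that idea on plain Python dicts (the helper names, and the assumption that per-parameter state is keyed by a global parameter index, are illustrative and not the actual OSS code): state_dict would merge the per-rank shards into one PyTorch-style dict, and load_state_dict would pick each rank's slice back out.

```python
# Hypothetical helpers sketching the un-shard / re-shard idea; names and
# the index-keyed layout are assumptions, not the fairscale OSS implementation.

def unshard(shard_state_dicts):
    """Merge per-rank optimizer state dicts into one PyTorch-style dict."""
    merged = {"state": {}, "param_groups": []}
    for shard in shard_state_dicts:  # one entry per rank
        merged["state"].update(shard["state"])
        merged["param_groups"].extend(shard["param_groups"])
    return merged


def reshard(full_state_dict, local_param_indices):
    """Keep only the per-parameter state owned by this rank."""
    return {
        "state": {
            idx: s
            for idx, s in full_state_dict["state"].items()
            if idx in local_param_indices
        },
        "param_groups": full_state_dict["param_groups"],
    }
```

With that shape, the merged dict would round-trip through torch.save/torch.load like a vanilla optimizer checkpoint.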
return {"state": self._all_states} | ||
return { | ||
"state": [s["state"] for s in self._all_states], | ||
"param_groups": [s["param_groups"] for s in self._all_states], |
Remove param_groups from _all_states to avoid having a redundant copy?
We don't own this state dict object; it's returned and could have any lifetime. I don't see the issue with a redundant copy?
A redundant copy won't consume memory since they are just references, but I'm wondering if it will increase disk space. Would be good to check that. I'm also wondering what will happen on load: will the two copies refer to the same objects, or will there be two copies? That could cause OOM.
OSS created the state_dicts in _collect_sharded_states. There are no references outside OSS at this point.
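For what it's worth, one quick way to check the identity question, leaning on the fact that torch.save pickles the whole structure in one pass and pickle memoizes objects, so two references to the same dict should come back as a single object after torch.load (the file path below is just an example):

```python
import torch

state = {"lr": 0.01, "params": [0, 1, 2]}
# Two references to the same object, mimicking the suspected redundancy.
blob = {"a": state, "b": state}

torch.save(blob, "/tmp/dup_check.pt")
loaded = torch.load("/tmp/dup_check.pt")

# Pickle memoization should preserve identity, so this prints True,
# i.e. one object after load rather than two copies.
print(loaded["a"] is loaded["b"])
```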
Ah, I see what you mean, but that's not related to this PR, right? In general I agree that _all_states could be changed to remove duplicates, but to me this is a relatively independent change; I was trying to address the get/load state interface only. In the same fashion, it feels to me like there must be duplicates between self.param_groups, self._all_states and self.optim.param_groups.
Re: lifetime, I was referring to the object you commented on, return {}; I thought that was the subject of your comment.
I'm worried about this potentially introducing a regression to fairseq. They are now using this in master and I tested with only a single copy of the parameters. Maybe just confirm that disk space does not increase and we don't get duplicate tensors for the params on load. Might be cool to add a memory benchmark (@andersonic added a memory assertion in one of the gpipe benchmarks). If the duplicates are referencing the same object then it is not an issue.
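A rough illustration of what such an assertion could look like (the helper name, the budget argument, and using peak CUDA memory as the proxy are assumptions, not the existing gpipe benchmark):

```python
import torch


def assert_load_within_budget(checkpoint_path, budget_bytes):
    """Hypothetical check: loading the checkpoint must stay under a memory budget."""
    torch.cuda.reset_peak_memory_stats()
    state_dict = torch.load(checkpoint_path, map_location="cuda")
    peak = torch.cuda.max_memory_allocated()
    assert peak <= budget_bytes, f"load peaked at {peak} bytes, budget {budget_bytes}"
    return state_dict
```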
The PyTorch expectation is not super clear to me: param_groups is a list (even if described as a dict), and there do not seem to be many more constraints than that in https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.state_dict. You mean unrolling the list so that param_groups is …
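For reference, the layout that torch.optim.Optimizer.state_dict() documents is roughly the following, with state keyed by packed parameter index and param_groups a list of dicts whose "params" entry holds those indices (the hyperparameter values below are just illustrative, SGD-style):

```python
# Illustrative shape of a vanilla torch.optim.Optimizer.state_dict();
# the actual per-parameter keys depend on the optimizer (SGD with momentum shown).
{
    "state": {
        0: {"momentum_buffer": ...},  # per-parameter state, keyed by index
        1: {"momentum_buffer": ...},
    },
    "param_groups": [
        {"lr": 0.1, "momentum": 0.9, "params": [0, 1]},  # indices into "state"
    ],
}
```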
Before submitting
What does this PR do?
Fixes #60.
PR review
Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.
Did you have fun?
Except for the realization that all my lint errors were due to isort and black being version-dependent in their behaviour (FFS), all good 🙃