
[fix] OSS pytorch-compliant state dict #61

Merged: 3 commits merged into master from oss_pytorch_state_dict on Sep 3, 2020
Conversation

@blefaudeux (Contributor):

Before submitting

  • Was this discussed/approved via a GitHub issue? (not needed for typos or doc improvements)
  • Did you read the contributor guideline?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?

What does this PR do?

Fixes #60.

PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Except for the realization that all my lint errors were due to isort and black being version-dependent in their behaviour (FFS), all good 🙃

@facebook-github-bot added the CLA Signed label on Sep 2, 2020
@msbaines (Contributor) left a comment:


What if you un-sharded in state_dict() and then re-sharded in load_state_dict()? That way, we work exactly as PyTorch expects.
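
For what it's worth, a standalone sketch of that idea, just to illustrate the shape of it (unshard/reshard and param_to_rank are hypothetical names, not the fairscale API; it assumes each rank's shard is already a torch.optim-style dict with disjoint state keys):

```python
# A standalone sketch of the "un-shard on state_dict(), re-shard on
# load_state_dict()" idea above -- illustrative only, not the fairscale code.
# Assumptions: each rank's shard is already a torch.optim-style dict, the
# per-rank "state" dicts use disjoint parameter keys, and `param_to_rank`
# (hypothetical) encodes the optimizer's own parameter partition.
from typing import Any, Dict, List


def unshard(per_rank: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Merge per-rank optimizer state dicts into one PyTorch-style dict."""
    merged: Dict[str, Any] = {"state": {}, "param_groups": []}
    for shard in per_rank:
        merged["state"].update(shard["state"])  # disjoint keys assumed
        merged["param_groups"].extend(shard["param_groups"])
    return merged


def reshard(full: Dict[str, Any], param_to_rank: Dict[int, int], n_ranks: int) -> List[Dict[str, Any]]:
    """Split the PyTorch-style dict back into per-rank shards."""
    shards: List[Dict[str, Any]] = [{"state": {}, "param_groups": []} for _ in range(n_ranks)]
    for key, value in full["state"].items():
        shards[param_to_rank[key]]["state"][key] = value
    for group in full["param_groups"]:
        owner = param_to_rank[group["params"][0]]  # group follows its first param
        shards[owner]["param_groups"].append(group)
    return shards
```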

return {"state": self._all_states}
return {
"state": [s["state"] for s in self._all_states],
"param_groups": [s["param_groups"] for s in self._all_states],
@msbaines (Contributor):

Remove param_groups from _all_states to avoid having a redundant copy?

@blefaudeux (Contributor, Author), Sep 2, 2020:

We don't own this state dict object; it's returned and could have any lifetime. I don't see the issue with a redundant copy?

@msbaines (Contributor):

A redundant copy won't consume memory since they are just references, but I'm wondering if it will increase disk space. Would be good to check that. I'm also wondering what will happen on load: will the two copies refer to the same objects, or will there be two copies? That could cause OOM.

OSS created the state_dicts in _collect_sharded_states. There are no references outside OSS at this point.

@blefaudeux (Contributor, Author), Sep 2, 2020:

Aaaah, I see what you mean, but it's not related to this PR, right? In general I agree, _all_states could be changed to remove duplicates, but to me this is a relatively independent change; I was trying to address the get/load state interface only. In the same fashion, it feels to me like there must be duplicates between self.param_groups, self._all_states and self.optim.param_groups.

@blefaudeux (Contributor, Author):

Re: lifetime, I was referring to the object you commented on (the return {...}); I thought that was the subject of your comment.

@msbaines (Contributor):

I'm worried about this potentially introducing a regression to fairseq. They are now using this in master and I tested with only a single copy of the parameters. Maybe just confirm that disk space does not increase and we don't get duplicate tensors for the params on load. Might be cool to add a memory benchmark (@andersonic added a memory assertion in one of the gpipe benchmarks). If the duplicates are referencing the same object then it is not an issue.
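
One quick way to check the reference-sharing question (a standalone sketch, not part of this PR): torch.save pickles the container, and pickle memoizes objects within a single call, so two references to the same tensor should come back as one object after a save/load round trip.

```python
# Standalone check: two references to the same tensor inside one torch.save
# should round-trip as a single object (so no extra disk space or duplicated
# memory on load beyond pickle bookkeeping).
import io

import torch

t = torch.randn(1000)
bundle = {"a": t, "b": t}  # two references to the same tensor

buffer = io.BytesIO()
torch.save(bundle, buffer)
buffer.seek(0)
loaded = torch.load(buffer)

assert loaded["a"] is loaded["b"]  # same object, not a copy
```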

@blefaudeux (Contributor, Author):

> What if you un-sharded in state_dict() and then re-sharded in load_state_dict()? That way, we work exactly as PyTorch expects.

The PyTorch expectation is not super clear to me: param_groups is a list (even if described as a dict), and there do not seem to be many more constraints than that in https://pytorch.org/docs/stable/_modules/torch/optim/optimizer.html#Optimizer.state_dict.

You mean unrolling the list so that param_groups is [dict()]? I can do that.
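
For reference, a small standalone example of the layout a stock torch.optim optimizer produces (SGD here, chosen arbitrarily for illustration):

```python
# Layout of a stock torch.optim state_dict, for reference (illustrative only).
import torch

model = torch.nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

model(torch.randn(3, 4)).sum().backward()
opt.step()  # populates the momentum buffers

sd = opt.state_dict()
# "state": a dict keyed by integer parameter index -> per-parameter buffers
# "param_groups": a list of dicts, each holding hyperparameters plus a
#                 "params" list of the indices that group covers
print(list(sd.keys()))                  # ['state', 'param_groups']
print(type(sd["state"]))                # <class 'dict'>
print(type(sd["param_groups"]))         # <class 'list'>
print(sd["param_groups"][0]["params"])  # [0, 1]
```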

@blefaudeux blefaudeux marked this pull request as draft September 2, 2020 22:01
@blefaudeux blefaudeux changed the title Oss pytorch state dict [fix] OSS pytorch-compliant state dict Sep 3, 2020
@blefaudeux blefaudeux marked this pull request as ready for review September 3, 2020 17:41
@blefaudeux blefaudeux merged commit 1d1d15e into master Sep 3, 2020
@blefaudeux blefaudeux deleted the oss_pytorch_state_dict branch September 3, 2020 18:29
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Rewrite state_dict in a more pytorch idiomatic way
3 participants