Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow layer-wise recompute #18566

Merged
merged 27 commits into from
Dec 12, 2023
Merged

Allow layer-wise recompute #18566

merged 27 commits into from
Dec 12, 2023

Conversation

pengwa
Copy link
Contributor

@pengwa pengwa commented Nov 23, 2023

Allow layer-wise recompute

Early, we need users/developers to specify the subgraphs to recompute, now we introduced a more user-friendly way to enable recompute for all detected stashed activation recomputation subgraphs. This scarifies getting the best configs while makes it easier to support user requirements when they switches from PyTorch per-layer gradient checkpoint to ORTModule.

ORTMODULE_MEMORY_OPT_LEVEL is introduced to control the usage, by default, it is 0, e.g. USER_SPECIFIED, all subgraphs definedin ORTMODULE_MEMORY_OPT_CONFIG will be recomputed. So this is compatible to existing recompute usage in ORTModule integrated models.

Using ORTMODULE_MEMORY_OPT_LEVEL=1, we will enable all recompute plans detected, so those configs in ORTMODULE_MEMORY_OPT_CONFIG will not be respected any more.

Add Unit Tests using 3 layer blooms.

https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md

@pengwa pengwa added the training issues related to ONNX Runtime training; typically submitted using template label Nov 23, 2023
@pengwa pengwa changed the title Allow AGGRESSIVE_FULL_RECOMPUTE in memory optimization Allow layer-wise recompute Nov 24, 2023
@AdamLouly
Copy link
Contributor

  • How does ORT finds all the recompute opportunities in this implementation?
  • Are there numbers that shows how much memory optimization we have against how much perf we lost?

…s even smaller than the other one in forward pass (FusedMatmul which is replaced by a new node after gradient graph is built)
@pengwa
Copy link
Contributor Author

pengwa commented Nov 30, 2023

Find those stashed activations that are used by backward operators. Put all those activations as candidates; For each candidate, https://github.com/microsoft/onnxruntime/blob/f3369a8bf87190552ad551a6de56df01cccf7a62/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc#L272C9-L272C30 will check whether it is recomputable, and how the subgraph looks like,.

  • Are there numbers that shows how much memory optimization we have against how much perf we lost?

No, that's the long term goal to have all those feature ready, to help dynamically choose a good plan for users.

zhijxu-MS
zhijxu-MS previously approved these changes Dec 8, 2023
@pengwa pengwa merged commit ccf3b20 into main Dec 12, 2023
96 checks passed
@pengwa pengwa deleted the pengwa/add_aggresive_recompute branch December 12, 2023 00:44
@pengwa
Copy link
Contributor Author

pengwa commented Dec 12, 2023

Thank you @askhade, @zhijxu-MS !!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
training issues related to ONNX Runtime training; typically submitted using template
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants