-
Notifications
You must be signed in to change notification settings - Fork 3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow layer-wise recompute #18566
Allow layer-wise recompute #18566
Conversation
|
…s even smaller than the other one in forward pass (FusedMatmul which is replaced by a new node after gradient graph is built)
…pengwa/add_aggresive_recompute
Find those stashed activations that are used by backward operators. Put all those activations as candidates; For each candidate, https://github.com/microsoft/onnxruntime/blob/f3369a8bf87190552ad551a6de56df01cccf7a62/orttraining/orttraining/core/optimizer/memory_optimizer/memory_insight.cc#L272C9-L272C30 will check whether it is recomputable, and how the subgraph looks like,.
No, that's the long term goal to have all those feature ready, to help dynamically choose a good plan for users. |
…pengwa/add_aggresive_recompute
orttraining/orttraining/python/training/ortmodule/_training_manager.py
Outdated
Show resolved
Hide resolved
orttraining/orttraining/core/optimizer/memory_optimizer/recompute_analysis.cc
Outdated
Show resolved
Hide resolved
…pengwa/add_aggresive_recompute
Thank you @askhade, @zhijxu-MS !! |
Allow layer-wise recompute
Early, we need users/developers to specify the subgraphs to recompute, now we introduced a more user-friendly way to enable recompute for all detected stashed activation recomputation subgraphs. This scarifies getting the best configs while makes it easier to support user requirements when they switches from PyTorch per-layer gradient checkpoint to ORTModule.
ORTMODULE_MEMORY_OPT_LEVEL
is introduced to control the usage, by default, it is 0, e.g.USER_SPECIFIED
, all subgraphs definedinORTMODULE_MEMORY_OPT_CONFIG
will be recomputed. So this is compatible to existing recompute usage in ORTModule integrated models.Using
ORTMODULE_MEMORY_OPT_LEVEL=1
, we will enable all recompute plans detected, so those configs inORTMODULE_MEMORY_OPT_CONFIG
will not be respected any more.Add Unit Tests using 3 layer blooms.
https://github.com/microsoft/onnxruntime/blob/pengwa/add_aggresive_recompute/docs/Memory_Optimizer.md