Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Llama3-8b memory efficient full finetune #990

Merged
merged 1 commit into from
May 17, 2024
Merged

Llama3-8b memory efficient full finetune #990

merged 1 commit into from
May 17, 2024

Conversation

rohan-varma
Copy link
Member

@rohan-varma rohan-varma commented May 16, 2024

Context

What is the purpose of this PR? Is it to

  • add a new feature
  • fix a bug
  • update tests and/or documentation
  • other (please add here)

TL;DR: This PR saves ~46% peak memory for llama3-8b single device full finetune while keeping performance at parity to current offering, just by switching Adamw8bit -> PagedAdamw8bit. @ebsmothers reminded me that this exists after I took a much more complicated approach.

Changelog

  • Previous experiments using PagedOptimizer from bnb for llama3 workloads resulted in prohibitvely slow QPS (> 6s/it, compared to using paged optim in llama2 workload still providing > 1 it/s QPS). After some debugging, this is primarily due to paging in and out large optimizer states associated with the embedding and output projection.
  • After a chat with @ebsmothers, we realized that we can just try PagedAdamW8bit, which reduces the size of the optimizer states for the output projection and embedding, and experiments using this proved to benefit memory usage at no cost to QPS. A previous version of this PR made output projection and embedding use Adamw8bit and the others use PagedOptimizer separately, but this approach is much simpler.

Test plan

Current 8B full single device:

Step 132 | loss:0.8253529071807861 lr:2e-05 tokens_per_second_per_gpu:335.27384473126125 peak_memory_active:33.644823552 peak_memory_alloc:33.644823552 peak_memory_reserved:36.320575488
1.27it/s

8B_full_single_device using PagedAdamW:

Step 44 | loss:0.7166110873222351 lr:2e-05 tokens_per_second_per_gpu:23.2328866347064 peak_memory_active:17.456531968 peak_memory_alloc:17.456531968 peak_memory_reserved:19.302187008
7.64s/it

8B full single device with this PR (PagedAdamW8bit):

Step 44 | loss:0.7101777195930481 lr:2e-05 tokens_per_second_per_gpu:223.38999360355982 peak_memory_active:17.486964224 peak_memory_alloc:17.486964224 peak_memory_reserved:19.333644288

For comparison, current llama2-7b 7B_full_low_memory:

Step 44 | loss:0.6592501997947693 lr:2e-05 tokens_per_second_per_gpu:326.61038570142404 peak_memory_active:13.924915712 peak_memory_alloc:13.924915712 peak_memory_reserved:14.845739008
1.41it/s

TL;DR: This PR reduces peak memory by ~46% while maintaining approximately the same perf, getting us to a < 24 GB full finetune.

Loss curves are the same (comparing today's baseline versus with these changes) -

image

Follow-ups

  • Documentation for optimizer in backward and using bnb optimizers - this documentation is sparse, AFAIk we barely mention bitsandbytes in our docs and don't explain running optimizer in backward at all. We should add comprehensive documentation around these full finetuning memory optimizations.
  • Update numbers in README table (though I think these are for llama2 at the moment).

Copy link

pytorch-bot bot commented May 16, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/990

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c99b4d7 with merge base 3883081 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 16, 2024
@rohan-varma rohan-varma requested a review from ebsmothers May 16, 2024 21:26
@RdoubleA
Copy link
Contributor

Few high level questions:

  • These memory tricks are awesome but it comes with the tradeoff of making the recipe harder to understand. But we should support more of these and have a place for them. What do you think about having separate configs/recipes for low memory optimizations? We have this for llama2. Or is that the intention for the single device recipes / configs?
  • I am not too keen on having three different optimizer fields in the config. If we have a separate low memory recipe maybe this is ok, but if we use the existing recipe what are your thoughts on hardcoding the AdamW8bit optimizers for embedding and output projection and using the same learning rate as the main optimizer? Do you think we need to expose these as configurable flags to the user? That way you don't have to do all the config gymnastics in the recipe

@rohan-varma
Copy link
Member Author

what are your thoughts on hardcoding the AdamW8bit optimizers for embedding and output projection and using the same learning rate as the main optimizer

I'd like to avoid this especially if we stick with the current recipe. Since this recipe is used for other workloads, users who change optimizer via config for those workloads would be surprised that their changes don't have effect since we hardcode here.

@rohan-varma
Copy link
Member Author

@RdoubleA Yeah the UX concerns definitely make sense. I think this config can just be renamed appropriately to match what we have for llama2.

Open to authoring a separate recipe or moving this to a helper function - feel free to let me know what you and @ebsmothers think or if any additional input is needed from me, thanks!

@rohan-varma
Copy link
Member Author

Refactored to simply use PagedAdamW8bit after @ebsmothers suggestion!

Copy link
Contributor

@ebsmothers ebsmothers left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great! Really happy to see we were able to get good memory with PagedAdamW8bit. We should see if it helps on the Llama2 memory-efficient config at all too (ofc not as urgent since we already have reasonable peak memory there)

Copy link
Contributor

@RdoubleA RdoubleA left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

dang how did all that become one line

@RdoubleA RdoubleA merged commit 46d7c83 into main May 17, 2024
29 checks passed
@joecummings joecummings deleted the 8b_lowmem branch May 17, 2024 13:54
weifengpy pushed a commit to weifengpy/torchtune that referenced this pull request Jun 4, 2024
maximegmd pushed a commit to maximegmd/torchtune that referenced this pull request Jul 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants