
Llama3-70B LoRA multi GPU #802

Merged
merged 6 commits into main from 70b_llama3 on Apr 19, 2024
Conversation

@rohan-varma (Member) commented Apr 18, 2024

Context

  • Adds LoRA support for distributed training of the Llama3-70B model.
  • Note that HF currently bundles both the Meta and HF checkpoints in a single repo for Llama3, which makes downloading only a single copy of the model a little tricky (and we should avoid users having to download two copies). For Llama3-8B, the documented download commands work out of the box, since safetensors are ignored by `tune download` by default and we use the Meta checkpointer.
  • However, using the Meta checkpointer for the 70B model is much more complicated and would require significant development work to support sharded checkpoints and unsharding/resharding. We can get 70B support much more quickly by using the HF checkpoints and safetensors. Hence, we document (currently in the configs only) how to run the download command so that it fetches the safetensors and NOT the Meta checkpoint files (so users only keep one copy of the model), and we use the HF checkpointer as the default.

Changelog

  • Render recipe in recipe registry
  • Add config for 70B LoRA training
  • NOTE: users need to download with --ignore-patterns "original/consolidated*" so that `tune download` fetches the safetensors (overriding its default ignore pattern) and skips the original Meta checkpoint files
  • Update README and docs as appropriate

Test plan

  • tune download meta-llama/Meta-Llama-3-70b --hf-token <> --output-dir /tmp/Meta-Llama-3-70b --ignore-patterns "original/consolidated*"
  • tune run --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/llama3/70B_lora.yaml
  • Loss curve: (screenshot)
  • Can't easily do eval due to lack of distributed support. Generation result is pending.
  • Renders in `tune ls` (screenshot)
  • Some logs:
2024-04-18:13:49:43,502 INFO     [lora_finetune_distributed.py:280] FSDP is enabled. Instantiating Model on CPU for Rank 0 ...
2024-04-18:13:51:59,203 INFO     [lora_finetune_distributed.py:286] Model instantiation took 135.70 secs
2024-04-18:13:53:10,922 INFO     [lora_finetune_distributed.py:360] Memory Stats after model init:
{'peak_memory_active': 28.329404416, 'peak_memory_alloc': 26.228041728, 'peak_memory_reserved': 32.377929728}
  • Peak memory allocated is < 30 GB on GPU 0 (screenshot).
  • Peak memory reserved on GPU 0 (screenshot).
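As a reference for how numbers like these can be gathered, here is a minimal sketch using plain torch.cuda calls (not necessarily the exact utility the recipe uses):

```python
import torch

def peak_memory_gb(device: int = 0) -> dict:
    """Return peak CUDA memory usage on one device in GB (torch reports bytes)."""
    stats = torch.cuda.memory_stats(device)
    return {
        "peak_memory_active": stats["active_bytes.all.peak"] / 1e9,
        "peak_memory_alloc": torch.cuda.max_memory_allocated(device) / 1e9,
        "peak_memory_reserved": torch.cuda.max_memory_reserved(device) / 1e9,
    }

# e.g. on rank 0 after model init:
# print(f"Memory Stats after model init:\n{peak_memory_gb()}")
```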


pytorch-bot bot commented Apr 18, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/802

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 652cd3d with merge base a9180b5:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label Apr 18, 2024
@rohan-varma changed the title from "Llama3-70B LoRA multi GPU" to "[WIP] Llama3-70B LoRA multi GPU" Apr 18, 2024
@rohan-varma marked this pull request as draft April 18, 2024 20:22
@rohan-varma marked this pull request as ready for review April 18, 2024 23:47
@rohan-varma requested a review from ebsmothers April 18, 2024 23:47
@rohan-varma changed the title from "[WIP] Llama3-70B LoRA multi GPU" to "Llama3-70B LoRA multi GPU" Apr 18, 2024
README.md Outdated


In our initial experiments, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training.
In our initial experiments for Llama3 8B, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training.
Contributor

nit

Suggested change
In our initial experiments for Llama3 8B, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training.
In our initial experiments for Llama3-8B, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training.


```bash
tune run --nproc_per_node 2 full_finetune_distributed --config llama3/8B_full
```

- 70B LoRA finetune on 8 GPUs
Contributor

Not sure what the right place to do this is, but do we wanna mention memory requirements to run 70B somewhere?

Member Author

Yeah, maybe we can add it to the table?

# Model Arguments
model:
_component_: torchtune.models.llama3.lora_llama3_70b
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
Contributor

nit: I feel like this is clearer

Suggested change
lora_attn_modules: ['q_proj', 'v_proj', 'k_proj']
lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
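For anyone reading the config cold, the `_component_` entry above is what the recipe instantiates. A rough sketch of the equivalent Python, assuming the builder accepts keyword arguments that mirror the config keys quoted in this review:

```python
from torchtune.models.llama3 import lora_llama3_70b

# Build the 70B model with LoRA applied only to the attention projections,
# using the values from this config (rank/alpha and the apply_lora_to_* flags
# are quoted a few lines further down).
model = lora_llama3_70b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    lora_rank=16,
    lora_alpha=32,
)
```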

Comment on lines +16 to +19
apply_lora_to_mlp: False
apply_lora_to_output: False
lora_rank: 16
lora_alpha: 32
Contributor

How are these defaults set? Are rank and alpha higher than our 7/8B defaults because the embedding dim is larger?

Member Author

I mostly copied these from https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/70B_lora.yaml.

I'm not sure whether rank and alpha being higher is because of the embedding dim being larger - what do those have to do with each other?
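For context on the rank/alpha question: in the standard LoRA formulation the low-rank update is scaled by alpha/rank, so what matters is the ratio rather than the embedding dim. A tiny numeric sketch of generic LoRA math (not torchtune internals; the dim below is illustrative):

```python
import torch

rank, alpha = 16, 32                # values from this config
dim = 1024                          # illustrative only; the real 70B hidden dim is larger
A = torch.randn(rank, dim) * 0.01   # LoRA "down" projection
B = torch.zeros(dim, rank)          # LoRA "up" projection, zero-initialized
scale = alpha / rank                # = 2.0, the factor applied to the low-rank update

delta_W = scale * (B @ A)           # update added to the frozen base weight during forward
print(delta_W.shape, scale)         # torch.Size([1024, 1024]) 2.0
```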

Comment on lines +28 to +59
checkpoint_files: [
model-00001-of-00030.safetensors,
model-00002-of-00030.safetensors,
model-00003-of-00030.safetensors,
model-00004-of-00030.safetensors,
model-00005-of-00030.safetensors,
model-00006-of-00030.safetensors,
model-00007-of-00030.safetensors,
model-00008-of-00030.safetensors,
model-00009-of-00030.safetensors,
model-00010-of-00030.safetensors,
model-00011-of-00030.safetensors,
model-00012-of-00030.safetensors,
model-00013-of-00030.safetensors,
model-00014-of-00030.safetensors,
model-00015-of-00030.safetensors,
model-00016-of-00030.safetensors,
model-00017-of-00030.safetensors,
model-00018-of-00030.safetensors,
model-00019-of-00030.safetensors,
model-00020-of-00030.safetensors,
model-00021-of-00030.safetensors,
model-00022-of-00030.safetensors,
model-00023-of-00030.safetensors,
model-00024-of-00030.safetensors,
model-00025-of-00030.safetensors,
model-00026-of-00030.safetensors,
model-00027-of-00030.safetensors,
model-00028-of-00030.safetensors,
model-00029-of-00030.safetensors,
model-00030-of-00030.safetensors,
]
Contributor

Lol we really need to have a way to generate these programmatically
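As a stopgap, the list can at least be generated rather than typed by hand; a quick sketch (not something the checkpointer supports natively today):

```python
# Generate the sharded-safetensors filenames instead of listing them manually.
num_shards = 30
checkpoint_files = [
    f"model-{i:05d}-of-{num_shards:05d}.safetensors" for i in range(1, num_shards + 1)
]
print("\n".join(checkpoint_files[:3]))  # model-00001-of-00030.safetensors, ...
```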

@musabgultekin (Contributor) commented Apr 19, 2024

@rohan-varma Have you tested if full-weight training works with 8x80GB? Maybe it works if we use 8-bit AdamW?

@rohan-varma (Member Author)

@musabgultekin I've dug into this a bit, and haven't gotten it into a full working state yet. This is something on our immediate list of priorities, so stay tuned!
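For anyone who wants to experiment in the meantime, the 8-bit AdamW idea above would look roughly like this with bitsandbytes (untested here; whether it makes 70B full finetuning fit in 8x80GB is exactly the open question):

```python
import torch.nn as nn
import bitsandbytes as bnb

model = nn.Linear(10, 10)  # stand-in; in the recipe this would be the wrapped Llama3-70B
# 8-bit AdamW stores optimizer state in int8, roughly quartering AdamW's state memory.
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=2e-5, weight_decay=0.01)
```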

@musab-mk (Contributor)

Thank you! Will be following you!

@rohan-varma merged commit c69fba1 into main Apr 19, 2024
27 checks passed
@joecummings deleted the 70b_llama3 branch April 28, 2024 22:30