Llama3-70B LoRA multi GPU #802

Merged (6 commits) on Apr 19, 2024
22 changes: 14 additions & 8 deletions README.md
@@ -6,7 +6,7 @@
 
 

- torchtune now officially supports Meta Llama3! Check out our recipes for Llama3-8B with LoRA, QLoRA and Full fine-tune in the [Llama3](#llama3) section! 🚀 🦙
+ torchtune now officially supports Meta Llama3! Check out our recipes for Llama3-8B with LoRA, QLoRA and Full fine-tune in the [Llama3](#llama3) section! We also support 70B fine-tuning with LoRA! 🚀 🦙

# torchtune

@@ -44,7 +44,7 @@ torchtune currently supports the following models.

| Model | Sizes |
|-----------------------------------------------|-----------|
- | [Llama3](https://llama.meta.com/llama3) | 8B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] |
+ | [Llama3](https://llama.meta.com/llama3) | 8B, 70B [[models](torchtune/models/llama3/_model_builders.py), [configs](recipes/configs/llama3/)] |
| [Llama2](https://llama.meta.com/llama2/) | 7B, 13B, 70B [[models](torchtune/models/llama2/_model_builders.py), [configs](recipes/configs/llama2/)] |
| [Mistral](https://huggingface.co/mistralai) | 7B [[model](torchtune/models/mistral/_model_builders.py), [configs](recipes/configs/mistral/)] |
| [Gemma](https://huggingface.co/collections/google/gemma-release-65d5efbccdbb8c4202ec078b) | 2B [[model](torchtune/models/gemma/_model_builders.py), [configs](recipes/configs/gemma/)] |
@@ -86,35 +86,41 @@ This table captures the minimum memory requirements for our different recipes us

## Llama3

- torchtune supports fine-tuning for the Llama3 8B models with support for 70B on its way. We currently support LoRA, QLoRA and Full-finetune on a single GPU as well as LoRA and Full fine-tune on multiple devices. For all the details, take a look at our [tutorial](https://pytorch.org/torchtune/stable/tutorials/llama3.html).
+ torchtune supports fine-tuning for the Llama3 8B and 70B models. For the 8B model we currently support LoRA, QLoRA and full fine-tuning on a single GPU, as well as LoRA and full fine-tuning on multiple devices; for the 70B model we support LoRA on multiple devices. For all the details, take a look at our [tutorial](https://pytorch.org/torchtune/stable/tutorials/llama3.html).


- In our initial experiments, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training.
+ In our initial experiments for Llama3-8B, QLoRA has a peak allocated memory of ``~9GB`` while LoRA on a single GPU has a peak allocated memory of ``~19GB``. To get started, you can use our default configs to kick off training.

- - LoRA on a single GPU.
+ - 8B LoRA on a single GPU.

```bash
tune run lora_finetune_single_device --config llama3/8B_lora_single_device
```

- - QLoRA on a single GPU
+ - 8B QLoRA on a single GPU

```bash
tune run lora_finetune_single_device --config llama3/8B_qlora_single_device
```

- - LoRA on 2 GPUs
+ - 8B LoRA on 2 GPUs

```bash
tune run --nproc_per_node 2 lora_finetune_distributed --config llama3/8B_lora
```

- - Full fine-tune on 2 GPUs
+ - 8B Full fine-tune on 2 GPUs

```bash
tune run --nproc_per_node 2 full_finetune_distributed --config llama3/8B_full
```

- 70B LoRA finetune on 8 GPUs
Contributor: Not sure what the right place to do this is, but do we wanna mention memory requirements to run 70B somewhere?

Member (author): Yeah, maybe we can add it to the table?


```bash
tune run --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/llama3/70B_lora.yaml
```
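
On the reviewer question above about 70B memory requirements, here is a rough, unofficial back-of-the-envelope (not a measured number from this PR, and assuming the distributed recipe fully shards the bf16 base weights across the 8 GPUs via FSDP):

```python
# Rough, unofficial estimate only: per-GPU memory for the bf16 base weights of a
# ~70B-parameter model sharded across 8 GPUs. Activations, gradients, and LoRA
# optimizer state come on top of this, so treat it as a lower bound.
params = 70e9        # ~70B base parameters
bytes_per_param = 2  # bf16
num_gpus = 8

total_gb = params * bytes_per_param / 1e9  # ~140 GB of weights in total
per_gpu_gb = total_gb / num_gpus           # ~17.5 GB of weights per GPU
print(f"~{total_gb:.0f} GB total, ~{per_gpu_gb:.1f} GB per GPU")
```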


 

2 changes: 2 additions & 0 deletions docs/source/api_ref_models.rst
@@ -19,8 +19,10 @@ All models from the `Llama3 family <https://llama.meta.com/llama3/>`_.
:nosignatures:

llama3.llama3_8b
llama3.llama3_70b
llama3.lora_llama3_8b
llama3.qlora_llama3_8b
llama3.lora_llama3_70b


llama2
100 changes: 100 additions & 0 deletions recipes/configs/llama3/70B_lora.yaml
@@ -0,0 +1,100 @@
# Config for multi-device LoRA in lora_finetune_distributed.py
# using a Llama3 70B model
#
# This config assumes that you've run the following command before launching
# this run:
# tune download meta-llama/Meta-Llama-3-70b --hf-token <TOKEN> --output-dir /tmp/Meta-Llama-3-70b --ignore-patterns "original/consolidated*"
#
# This config needs 8 GPUs to run
#   tune run --nproc_per_node 8 lora_finetune_distributed --config recipes/configs/llama3/70B_lora.yaml
#

# Model Arguments
model:
  _component_: torchtune.models.llama3.lora_llama3_70b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  apply_lora_to_output: False
  lora_rank: 16
  lora_alpha: 32
Comment on lines +16 to +19

Contributor: How are these defaults set? Are rank and alpha higher than our 7/8B defaults cause the embedding dim is larger?

Member (author): I mostly copied these from https://github.com/pytorch/torchtune/blob/main/recipes/configs/llama2/70B_lora.yaml. I'm not sure whether rank and alpha being higher is because of the embedding dim being larger - what do those have to do with each other?
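
On the rank/alpha question above, a rough illustration may help (not part of this diff; the 8B embedding dim of 4096 is assumed from the Llama3-8B architecture): LoRA adds two small matrices per adapted projection and scales their product by ``alpha / rank``, so adapter size grows with both the rank and the layer dimensions, while the effective scaling ``alpha / rank`` is 2.0 for both the rank-8/alpha-16 defaults and this config's rank-16/alpha-32.

```python
# Illustrative sketch (not part of this PR): trainable parameters added by LoRA
# to a single q_proj, which maps embed_dim -> embed_dim. LoRA learns
# A (rank x in_dim) and B (out_dim x rank) and scales B @ A by alpha / rank.
def lora_qproj_param_count(embed_dim: int, rank: int) -> int:
    in_dim = out_dim = embed_dim
    return rank * in_dim + out_dim * rank

print(lora_qproj_param_count(embed_dim=4096, rank=8))   # 8B defaults: 65,536
print(lora_qproj_param_count(embed_dim=8192, rank=16))  # this config: 262,144
print(16 / 8, 32 / 16)                                  # alpha/rank: 2.0, 2.0
```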


tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Meta-Llama-3-70b/original/tokenizer.model

checkpointer:
  _component_: torchtune.utils.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Meta-Llama-3-70b
  checkpoint_files: [
    model-00001-of-00030.safetensors,
    model-00002-of-00030.safetensors,
    model-00003-of-00030.safetensors,
    model-00004-of-00030.safetensors,
    model-00005-of-00030.safetensors,
    model-00006-of-00030.safetensors,
    model-00007-of-00030.safetensors,
    model-00008-of-00030.safetensors,
    model-00009-of-00030.safetensors,
    model-00010-of-00030.safetensors,
    model-00011-of-00030.safetensors,
    model-00012-of-00030.safetensors,
    model-00013-of-00030.safetensors,
    model-00014-of-00030.safetensors,
    model-00015-of-00030.safetensors,
    model-00016-of-00030.safetensors,
    model-00017-of-00030.safetensors,
    model-00018-of-00030.safetensors,
    model-00019-of-00030.safetensors,
    model-00020-of-00030.safetensors,
    model-00021-of-00030.safetensors,
    model-00022-of-00030.safetensors,
    model-00023-of-00030.safetensors,
    model-00024-of-00030.safetensors,
    model-00025-of-00030.safetensors,
    model-00026-of-00030.safetensors,
    model-00027-of-00030.safetensors,
    model-00028-of-00030.safetensors,
    model-00029-of-00030.safetensors,
    model-00030-of-00030.safetensors,
  ]
Comment on lines +28 to +59

Contributor: Lol we really need to have a way to generate these programmatically
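
In the spirit of the comment above, a minimal sketch (hypothetical helper, not part of this PR or of torchtune) of building the shard list programmatically rather than spelling out all 30 entries:

```python
# Hypothetical helper: generate the HF-style safetensors shard names listed
# above instead of writing them out by hand in the YAML.
def shard_names(num_shards: int = 30) -> list[str]:
    return [
        f"model-{i:05d}-of-{num_shards:05d}.safetensors"
        for i in range(1, num_shards + 1)
    ]

print(shard_names()[0])   # model-00001-of-00030.safetensors
print(shard_names()[-1])  # model-00030-of-00030.safetensors
```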

  recipe_checkpoint: null
  output_dir: /tmp/Meta-Llama-3-70b
  model_type: LLAMA3
resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  train_on_input: True
seed: null
shuffle: True
batch_size: 2

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4
lr_scheduler:
  _component_: torchtune.modules.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torch.nn.CrossEntropyLoss

# Training
epochs: 1
max_steps_per_epoch: null
gradient_accumulation_steps: 1

# Logging
output_dir: /tmp/lora_finetune_output
metric_logger:
  _component_: torchtune.utils.metric_logging.DiskLogger
  log_dir: ${output_dir}
log_every_n_steps: null

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: True
1 change: 1 addition & 0 deletions torchtune/_recipe_registry.py
@@ -107,6 +107,7 @@ class Recipe:
Config(name="llama2/7B_lora", file_path="llama2/7B_lora.yaml"),
Config(name="llama2/13B_lora", file_path="llama2/13B_lora.yaml"),
Config(name="llama2/70B_lora", file_path="llama2/70B_lora.yaml"),
Config(name="llama3/70B_lora", file_path="llama3/70B_lora.yaml"),
Config(name="llama3/8B_lora", file_path="llama3/8B_lora.yaml"),
Config(name="mistral/7B_lora", file_path="mistral/7B_lora.yaml"),
],
4 changes: 4 additions & 0 deletions torchtune/models/llama3/__init__.py
@@ -7,8 +7,10 @@
from ._component_builders import llama3, lora_llama3

from ._model_builders import ( # noqa
    llama3_70b,
    llama3_8b,
    llama3_tokenizer,
    lora_llama3_70b,
    lora_llama3_8b,
    qlora_llama3_8b,
)
@@ -17,9 +19,11 @@
__all__ = [
"llama3",
"llama3_8b",
"llama3_70b",
"llama3_tokenizer",
"lora_llama3",
"lora_llama3_8b",
"lora_llama3_70b",
"qlora_llama3_8b",
"scale_hidden_dim_for_mlp",
]
73 changes: 73 additions & 0 deletions torchtune/models/llama3/_model_builders.py
@@ -44,6 +44,27 @@ def llama3_8b() -> TransformerDecoder:
    )


def llama3_70b() -> TransformerDecoder:
    """
    Builder for creating a Llama3 model initialized w/ the default 70B parameter values.

    Returns:
        TransformerDecoder: Instantiation of Llama3 70B model
    """
    return llama3(
        vocab_size=128_256,
        num_layers=80,
        num_heads=64,
        num_kv_heads=8,
        embed_dim=8192,
        max_seq_len=4096,
        intermediate_dim=28672,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500000.0,
    )
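
As a quick sanity check on the architecture above, here is a sketch (not part of this diff, and assuming the builder initializes cleanly under meta-device construction) that counts parameters without allocating real memory:

```python
import torch

from torchtune.models.llama3 import llama3_70b

# Sketch only: build the model on the meta device so no real weights are
# allocated, then count parameters to confirm the ~70B scale.
with torch.device("meta"):
    model = llama3_70b()

num_params = sum(p.numel() for p in model.parameters())
print(f"~{num_params / 1e9:.1f}B parameters")  # roughly 70B
```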


def llama3_tokenizer(path: str) -> TikTokenTokenizer:
    tiktoken = TikTokenTokenizer(path)
    tiktoken.pad_id = 0
@@ -100,6 +121,58 @@ def lora_llama3_8b(
        quantize_base=quantize_base,
    )


def lora_llama3_70b(
    lora_attn_modules: List[LORA_ATTN_MODULES],
    apply_lora_to_mlp: bool = False,
    apply_lora_to_output: bool = False,
    lora_rank: int = 8,
    lora_alpha: float = 16,
    quantize_base: bool = False,
) -> TransformerDecoder:
    """
    Builder for creating a Llama3 70B model with LoRA enabled.

    The Llama3 defaults are the same as in :func:`~torchtune.models.llama3.llama3_70b`,
    while LoRA default params are based on
    https://github.com/tloen/alpaca-lora/blob/8bb8579e403dc78e37fe81ffbb253c413007323f/finetune.py#L41-L43.

    Args:
        lora_attn_modules (List[LORA_ATTN_MODULES]): list of which linear layers
            LoRA should be applied to in each self-attention block. Options are
            ``{"q_proj", "k_proj", "v_proj", "output_proj"}``.
        apply_lora_to_mlp (bool): whether to apply LoRA to the MLP in each transformer layer.
            Default: False
        apply_lora_to_output (bool): whether to apply LoRA to the model's final output projection.
            Default: False
        lora_rank (int): rank of each low-rank approximation
        lora_alpha (float): scaling factor for the low-rank approximation
        quantize_base (bool): Whether to quantize base model weights

    Returns:
        TransformerDecoder: Instantiation of Llama3 70B model with LoRA applied
    """
    return lora_llama3(
        lora_attn_modules=lora_attn_modules,
        apply_lora_to_mlp=apply_lora_to_mlp,
        apply_lora_to_output=apply_lora_to_output,
        vocab_size=128_256,
        num_layers=80,
        num_heads=64,
        num_kv_heads=8,
        embed_dim=8192,
        max_seq_len=4096,
        intermediate_dim=28672,
        attn_dropout=0.0,
        norm_eps=1e-5,
        rope_base=500000.0,
        lora_rank=lora_rank,
        lora_alpha=lora_alpha,
        lora_dropout=0.05,
        quantize_base=quantize_base,
    )
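
For illustration (not part of this diff), the ``model:`` section of ``recipes/configs/llama3/70B_lora.yaml`` above corresponds roughly to calling this builder directly:

```python
from torchtune.models.llama3 import lora_llama3_70b

# Roughly what the `model:` section of 70B_lora.yaml instantiates; the recipe
# normally does this via config instantiation rather than by hand. Note that
# building it for real requires memory for the full 70B base weights.
model = lora_llama3_70b(
    lora_attn_modules=["q_proj", "k_proj", "v_proj"],
    apply_lora_to_mlp=False,
    apply_lora_to_output=False,
    lora_rank=16,
    lora_alpha=32,
)
```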


qlora_llama3_8b = partial(lora_llama3_8b, quantize_base=True)

qlora_llama3_8b.__doc__ = """