diff --git a/docs/source/api_ref_modules.rst b/docs/source/api_ref_modules.rst index c62e338801..51b22c1bb9 100644 --- a/docs/source/api_ref_modules.rst +++ b/docs/source/api_ref_modules.rst @@ -131,3 +131,4 @@ Losses rlhf.loss.RSOLoss rlhf.loss.IPOLoss rlhf.loss.SimPOLoss + loss.CEWithChunkedOutputLoss diff --git a/recipes/configs/code_llama2/7B_full_low_memory.yaml b/recipes/configs/code_llama2/7B_full_low_memory.yaml index 76b59c0735..b8e91eae46 100644 --- a/recipes/configs/code_llama2/7B_full_low_memory.yaml +++ b/recipes/configs/code_llama2/7B_full_low_memory.yaml @@ -59,7 +59,7 @@ optimizer: lr: 2e-5 optimizer_in_bwd: True loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss compile: False # Training env diff --git a/recipes/configs/code_llama2/7B_lora_single_device.yaml b/recipes/configs/code_llama2/7B_lora_single_device.yaml index 6081643ece..28104ce3e0 100644 --- a/recipes/configs/code_llama2/7B_lora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_lora_single_device.yaml @@ -65,7 +65,7 @@ lr_scheduler: _component_: torchtune.modules.get_cosine_schedule_with_warmup num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss compile: False # Training env diff --git a/recipes/configs/code_llama2/7B_qlora_single_device.yaml b/recipes/configs/code_llama2/7B_qlora_single_device.yaml index 049bfd5db9..f2dd2e45fb 100644 --- a/recipes/configs/code_llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/code_llama2/7B_qlora_single_device.yaml @@ -65,7 +65,7 @@ lr_scheduler: _component_: torchtune.modules.get_cosine_schedule_with_warmup num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss compile: False # Training env diff --git a/recipes/configs/dev/8B_full_experimental.yaml b/recipes/configs/dev/8B_full_experimental.yaml index 09a6ee704b..182b748e17 100644 --- a/recipes/configs/dev/8B_full_experimental.yaml +++ b/recipes/configs/dev/8B_full_experimental.yaml @@ -55,7 +55,7 @@ optimizer: foreach: False loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml index fb715e47ba..4dd4d1c3c4 100644 --- a/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/13B_lora_fsdp2.yaml @@ -70,7 +70,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml index 4822310897..3666372b86 100644 --- a/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/70B_lora_fsdp2.yaml @@ -70,7 +70,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml b/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml index 719105daf5..34e96c18c2 100644 --- a/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/70B_qlora_fsdp2.yaml @@ -71,7 +71,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + 
_component_: torchtune.modules.loss.CEWithChunkedOutputLoss fsdp: cpu_offload: False diff --git a/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml b/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml index af3fd48e9f..4d25bf11c4 100644 --- a/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/7B_lora_fsdp2.yaml @@ -67,7 +67,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml b/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml index f88094b752..58294454b7 100644 --- a/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml +++ b/recipes/configs/dev/llama2/7B_qlora_fsdp2.yaml @@ -67,7 +67,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss fsdp: cpu_offload: False diff --git a/recipes/configs/gemma/2B_full.yaml b/recipes/configs/gemma/2B_full.yaml index 1e9b0b5632..a65b94d8fa 100644 --- a/recipes/configs/gemma/2B_full.yaml +++ b/recipes/configs/gemma/2B_full.yaml @@ -50,7 +50,7 @@ optimizer: _component_: torch.optim.AdamW lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/gemma/2B_lora.yaml b/recipes/configs/gemma/2B_lora.yaml index 11ce335740..25fdb703cf 100644 --- a/recipes/configs/gemma/2B_lora.yaml +++ b/recipes/configs/gemma/2B_lora.yaml @@ -57,7 +57,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments batch_size: 4 diff --git a/recipes/configs/gemma/2B_lora_single_device.yaml b/recipes/configs/gemma/2B_lora_single_device.yaml index 67f577e424..e7f02cc6ae 100644 --- a/recipes/configs/gemma/2B_lora_single_device.yaml +++ b/recipes/configs/gemma/2B_lora_single_device.yaml @@ -56,7 +56,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments batch_size: 4 diff --git a/recipes/configs/gemma/2B_qlora_single_device.yaml b/recipes/configs/gemma/2B_qlora_single_device.yaml index 3a5edfd2a2..641a66d6cf 100644 --- a/recipes/configs/gemma/2B_qlora_single_device.yaml +++ b/recipes/configs/gemma/2B_qlora_single_device.yaml @@ -56,7 +56,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments batch_size: 4 diff --git a/recipes/configs/gemma/7B_full.yaml b/recipes/configs/gemma/7B_full.yaml index 9e9bd13515..4aeb9f8cc3 100644 --- a/recipes/configs/gemma/7B_full.yaml +++ b/recipes/configs/gemma/7B_full.yaml @@ -52,7 +52,7 @@ optimizer: _component_: torch.optim.AdamW lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/gemma/7B_lora.yaml b/recipes/configs/gemma/7B_lora.yaml index fb78a3eebb..dfabadd4d3 100644 --- a/recipes/configs/gemma/7B_lora.yaml +++ b/recipes/configs/gemma/7B_lora.yaml @@ -59,7 +59,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments 
batch_size: 4 diff --git a/recipes/configs/gemma/7B_lora_single_device.yaml b/recipes/configs/gemma/7B_lora_single_device.yaml index 7f9fd7ea39..dde3924e46 100644 --- a/recipes/configs/gemma/7B_lora_single_device.yaml +++ b/recipes/configs/gemma/7B_lora_single_device.yaml @@ -58,7 +58,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments batch_size: 8 diff --git a/recipes/configs/gemma/7B_qlora_single_device.yaml b/recipes/configs/gemma/7B_qlora_single_device.yaml index 08f05bc6f5..539b008a6e 100644 --- a/recipes/configs/gemma/7B_qlora_single_device.yaml +++ b/recipes/configs/gemma/7B_qlora_single_device.yaml @@ -58,7 +58,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments batch_size: 4 diff --git a/recipes/configs/llama2/13B_full.yaml b/recipes/configs/llama2/13B_full.yaml index 70cedb0967..bf1085da04 100644 --- a/recipes/configs/llama2/13B_full.yaml +++ b/recipes/configs/llama2/13B_full.yaml @@ -54,7 +54,7 @@ optimizer: _component_: torch.optim.AdamW lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/llama2/13B_lora.yaml b/recipes/configs/llama2/13B_lora.yaml index 6a2eb5d87f..33a7068975 100644 --- a/recipes/configs/llama2/13B_lora.yaml +++ b/recipes/configs/llama2/13B_lora.yaml @@ -66,7 +66,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama2/13B_qlora_single_device.yaml b/recipes/configs/llama2/13B_qlora_single_device.yaml index 445b9ef8af..fe5f55ffb9 100644 --- a/recipes/configs/llama2/13B_qlora_single_device.yaml +++ b/recipes/configs/llama2/13B_qlora_single_device.yaml @@ -61,7 +61,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama2/70B_lora.yaml b/recipes/configs/llama2/70B_lora.yaml index 44f694a8df..b0acbd8aa2 100644 --- a/recipes/configs/llama2/70B_lora.yaml +++ b/recipes/configs/llama2/70B_lora.yaml @@ -66,7 +66,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama2/7B_full.yaml b/recipes/configs/llama2/7B_full.yaml index f0235b4514..00eac762f1 100644 --- a/recipes/configs/llama2/7B_full.yaml +++ b/recipes/configs/llama2/7B_full.yaml @@ -53,7 +53,7 @@ optimizer: _component_: torch.optim.AdamW lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/llama2/7B_full_low_memory.yaml b/recipes/configs/llama2/7B_full_low_memory.yaml index 1fac846f40..e00bbc5b72 100644 --- a/recipes/configs/llama2/7B_full_low_memory.yaml +++ b/recipes/configs/llama2/7B_full_low_memory.yaml @@ -56,7 +56,7 @@ optimizer: lr: 2e-5 optimizer_in_bwd: True loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 
compile: False diff --git a/recipes/configs/llama2/7B_lora.yaml b/recipes/configs/llama2/7B_lora.yaml index 9193b78aa8..b66f3651b8 100644 --- a/recipes/configs/llama2/7B_lora.yaml +++ b/recipes/configs/llama2/7B_lora.yaml @@ -63,7 +63,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama2/7B_lora_single_device.yaml b/recipes/configs/llama2/7B_lora_single_device.yaml index b147c31b45..5ce28ead82 100644 --- a/recipes/configs/llama2/7B_lora_single_device.yaml +++ b/recipes/configs/llama2/7B_lora_single_device.yaml @@ -61,7 +61,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama2/7B_qat_full.yaml b/recipes/configs/llama2/7B_qat_full.yaml index e8f54fafe3..be6b4c8a42 100644 --- a/recipes/configs/llama2/7B_qat_full.yaml +++ b/recipes/configs/llama2/7B_qat_full.yaml @@ -49,7 +49,7 @@ optimizer: _component_: torch.optim.AdamW lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/llama2/7B_qlora_single_device.yaml b/recipes/configs/llama2/7B_qlora_single_device.yaml index e114993cac..91032bb408 100644 --- a/recipes/configs/llama2/7B_qlora_single_device.yaml +++ b/recipes/configs/llama2/7B_qlora_single_device.yaml @@ -60,7 +60,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama3/70B_full.yaml b/recipes/configs/llama3/70B_full.yaml index 238d5d2ef2..51b695d29f 100644 --- a/recipes/configs/llama3/70B_full.yaml +++ b/recipes/configs/llama3/70B_full.yaml @@ -85,7 +85,7 @@ optimizer: fused: True loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/llama3/70B_lora.yaml b/recipes/configs/llama3/70B_lora.yaml index 73a8abea20..18cc0ed75a 100644 --- a/recipes/configs/llama3/70B_lora.yaml +++ b/recipes/configs/llama3/70B_lora.yaml @@ -81,7 +81,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama3/8B_full.yaml b/recipes/configs/llama3/8B_full.yaml index 6a60d58080..fb5c105dd2 100644 --- a/recipes/configs/llama3/8B_full.yaml +++ b/recipes/configs/llama3/8B_full.yaml @@ -55,7 +55,7 @@ optimizer: foreach: False loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/llama3/8B_full_single_device.yaml b/recipes/configs/llama3/8B_full_single_device.yaml index fafbbf275c..e0cc059a27 100644 --- a/recipes/configs/llama3/8B_full_single_device.yaml +++ b/recipes/configs/llama3/8B_full_single_device.yaml @@ -54,7 +54,7 @@ optimizer: _component_: bitsandbytes.optim.PagedAdamW8bit lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 optimizer_in_bwd: True diff --git 
a/recipes/configs/llama3/8B_lora.yaml b/recipes/configs/llama3/8B_lora.yaml index 13ec2fcdef..165b7fed7d 100644 --- a/recipes/configs/llama3/8B_lora.yaml +++ b/recipes/configs/llama3/8B_lora.yaml @@ -61,7 +61,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama3/8B_lora_single_device.yaml b/recipes/configs/llama3/8B_lora_single_device.yaml index 7d3f10d122..353b0d3bd4 100644 --- a/recipes/configs/llama3/8B_lora_single_device.yaml +++ b/recipes/configs/llama3/8B_lora_single_device.yaml @@ -60,7 +60,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama3/8B_qat_full.yaml b/recipes/configs/llama3/8B_qat_full.yaml index 76d88e28a6..fc184e2fe1 100644 --- a/recipes/configs/llama3/8B_qat_full.yaml +++ b/recipes/configs/llama3/8B_qat_full.yaml @@ -55,7 +55,7 @@ optimizer: foreach: False loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/llama3/8B_qlora_single_device.yaml b/recipes/configs/llama3/8B_qlora_single_device.yaml index e3b46c56e3..fc220aae60 100644 --- a/recipes/configs/llama3/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3/8B_qlora_single_device.yaml @@ -59,7 +59,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama3_1/70B_full.yaml b/recipes/configs/llama3_1/70B_full.yaml index 3af5931f4b..6e4e5e754c 100644 --- a/recipes/configs/llama3_1/70B_full.yaml +++ b/recipes/configs/llama3_1/70B_full.yaml @@ -85,7 +85,7 @@ optimizer: fused: True loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/llama3_1/70B_lora.yaml b/recipes/configs/llama3_1/70B_lora.yaml index 771d51aef9..16b5ca923b 100644 --- a/recipes/configs/llama3_1/70B_lora.yaml +++ b/recipes/configs/llama3_1/70B_lora.yaml @@ -80,7 +80,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama3_1/8B_full.yaml b/recipes/configs/llama3_1/8B_full.yaml index 45b610bd01..72df13dde1 100644 --- a/recipes/configs/llama3_1/8B_full.yaml +++ b/recipes/configs/llama3_1/8B_full.yaml @@ -58,7 +58,7 @@ optimizer: foreach: False loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/llama3_1/8B_full_single_device.yaml b/recipes/configs/llama3_1/8B_full_single_device.yaml index 8dae646284..90cae8e77e 100644 --- a/recipes/configs/llama3_1/8B_full_single_device.yaml +++ b/recipes/configs/llama3_1/8B_full_single_device.yaml @@ -57,7 +57,7 @@ optimizer: _component_: bitsandbytes.optim.PagedAdamW8bit lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 optimizer_in_bwd: True @@ -79,3 +79,28 @@ metric_logger: 
output_dir: /tmp/full-llama3.1-finetune log_every_n_steps: 1 log_peak_memory_stats: False + +# Profiler (disabled) +profiler: + _component_: torchtune.utils.setup_torch_profiler + enabled: False + + #Output directory of trace artifacts + output_dir: ${output_dir}/profiling_outputs + + #`torch.profiler.ProfilerActivity` types to trace + cpu: True + cuda: True + + #trace options passed to `torch.profiler.profile` + profile_memory: True + with_stack: False + record_shapes: True + with_flops: False + + # `torch.profiler.schedule` options: + # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat + wait_steps: 1 + warmup_steps: 2 + active_steps: 1 + num_cycles: 1 diff --git a/recipes/configs/llama3_1/8B_lora.yaml b/recipes/configs/llama3_1/8B_lora.yaml index 37df0dc957..985b77ed5b 100644 --- a/recipes/configs/llama3_1/8B_lora.yaml +++ b/recipes/configs/llama3_1/8B_lora.yaml @@ -64,7 +64,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama3_1/8B_lora_single_device.yaml b/recipes/configs/llama3_1/8B_lora_single_device.yaml index 394837851e..f5ed44560b 100644 --- a/recipes/configs/llama3_1/8B_lora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_lora_single_device.yaml @@ -63,7 +63,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/llama3_1/8B_qlora_single_device.yaml b/recipes/configs/llama3_1/8B_qlora_single_device.yaml index ee709c4f37..9efbc15a44 100644 --- a/recipes/configs/llama3_1/8B_qlora_single_device.yaml +++ b/recipes/configs/llama3_1/8B_qlora_single_device.yaml @@ -62,7 +62,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/mistral/7B_full.yaml b/recipes/configs/mistral/7B_full.yaml index 9f23c9fb74..5257d94e97 100644 --- a/recipes/configs/mistral/7B_full.yaml +++ b/recipes/configs/mistral/7B_full.yaml @@ -56,7 +56,7 @@ optimizer: _component_: torch.optim.AdamW lr: 5e-6 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/mistral/7B_full_low_memory.yaml b/recipes/configs/mistral/7B_full_low_memory.yaml index eb5578e9a6..a4e30fc698 100644 --- a/recipes/configs/mistral/7B_full_low_memory.yaml +++ b/recipes/configs/mistral/7B_full_low_memory.yaml @@ -58,7 +58,7 @@ optimizer: _component_: bitsandbytes.optim.PagedAdamW lr: 5e-6 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 optimizer_in_bwd: True diff --git a/recipes/configs/mistral/7B_lora.yaml b/recipes/configs/mistral/7B_lora.yaml index 5049f87fd6..3ce8436dca 100644 --- a/recipes/configs/mistral/7B_lora.yaml +++ b/recipes/configs/mistral/7B_lora.yaml @@ -65,7 +65,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments batch_size: 4 diff --git a/recipes/configs/mistral/7B_lora_single_device.yaml b/recipes/configs/mistral/7B_lora_single_device.yaml index 844188a43e..6f07a819a9 100644 --- 
a/recipes/configs/mistral/7B_lora_single_device.yaml +++ b/recipes/configs/mistral/7B_lora_single_device.yaml @@ -62,7 +62,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments batch_size: 4 diff --git a/recipes/configs/mistral/7B_qlora_single_device.yaml b/recipes/configs/mistral/7B_qlora_single_device.yaml index b55d2b3d99..139650989e 100644 --- a/recipes/configs/mistral/7B_qlora_single_device.yaml +++ b/recipes/configs/mistral/7B_qlora_single_device.yaml @@ -63,7 +63,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Fine-tuning arguments batch_size: 4 diff --git a/recipes/configs/phi3/mini_full.yaml b/recipes/configs/phi3/mini_full.yaml index d40c67c07a..578628f695 100644 --- a/recipes/configs/phi3/mini_full.yaml +++ b/recipes/configs/phi3/mini_full.yaml @@ -55,7 +55,7 @@ optimizer: _component_: torch.optim.AdamW lr: 5e-6 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training env device: cuda diff --git a/recipes/configs/phi3/mini_full_low_memory.yaml b/recipes/configs/phi3/mini_full_low_memory.yaml index 895a9475af..2791e004a5 100644 --- a/recipes/configs/phi3/mini_full_low_memory.yaml +++ b/recipes/configs/phi3/mini_full_low_memory.yaml @@ -58,7 +58,7 @@ optimizer: lr: 5e-6 optimizer_in_bwd: True loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss compile: False # Training env diff --git a/recipes/configs/phi3/mini_lora.yaml b/recipes/configs/phi3/mini_lora.yaml index e82a8b1df3..b120d10e1d 100644 --- a/recipes/configs/phi3/mini_lora.yaml +++ b/recipes/configs/phi3/mini_lora.yaml @@ -65,7 +65,7 @@ lr_scheduler: _component_: torchtune.modules.get_cosine_schedule_with_warmup num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training env device: cuda diff --git a/recipes/configs/phi3/mini_lora_single_device.yaml b/recipes/configs/phi3/mini_lora_single_device.yaml index 478f4c8933..2fa3d59d56 100644 --- a/recipes/configs/phi3/mini_lora_single_device.yaml +++ b/recipes/configs/phi3/mini_lora_single_device.yaml @@ -63,7 +63,7 @@ lr_scheduler: _component_: torchtune.modules.get_cosine_schedule_with_warmup num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss compile: False # Training env diff --git a/recipes/configs/phi3/mini_qlora_single_device.yaml b/recipes/configs/phi3/mini_qlora_single_device.yaml index d48cc4a666..97188db4d9 100644 --- a/recipes/configs/phi3/mini_qlora_single_device.yaml +++ b/recipes/configs/phi3/mini_qlora_single_device.yaml @@ -63,7 +63,7 @@ lr_scheduler: _component_: torchtune.modules.get_cosine_schedule_with_warmup num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss compile: False # Training env diff --git a/recipes/configs/qwen2/0.5B_full.yaml b/recipes/configs/qwen2/0.5B_full.yaml index 824e183f76..f0c6a9df51 100644 --- a/recipes/configs/qwen2/0.5B_full.yaml +++ b/recipes/configs/qwen2/0.5B_full.yaml @@ -52,7 +52,7 @@ optimizer: _component_: torch.optim.AdamW lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss 
max_steps_per_epoch: null gradient_accumulation_steps: 16 diff --git a/recipes/configs/qwen2/0.5B_full_single_device.yaml b/recipes/configs/qwen2/0.5B_full_single_device.yaml index 5039c13bad..adcc03d8aa 100644 --- a/recipes/configs/qwen2/0.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_full_single_device.yaml @@ -51,7 +51,7 @@ optimizer: lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss optimizer_in_bwd: False max_steps_per_epoch: null diff --git a/recipes/configs/qwen2/0.5B_lora.yaml b/recipes/configs/qwen2/0.5B_lora.yaml index 4c214a766f..39f5404646 100644 --- a/recipes/configs/qwen2/0.5B_lora.yaml +++ b/recipes/configs/qwen2/0.5B_lora.yaml @@ -62,7 +62,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/qwen2/0.5B_lora_single_device.yaml b/recipes/configs/qwen2/0.5B_lora_single_device.yaml index 6bbee99cc7..ad07ddddcb 100644 --- a/recipes/configs/qwen2/0.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/0.5B_lora_single_device.yaml @@ -60,7 +60,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/qwen2/1.5B_full.yaml b/recipes/configs/qwen2/1.5B_full.yaml index 067bf078e9..fb0a4c31c0 100644 --- a/recipes/configs/qwen2/1.5B_full.yaml +++ b/recipes/configs/qwen2/1.5B_full.yaml @@ -52,7 +52,7 @@ optimizer: _component_: torch.optim.AdamW lr: 2e-5 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/qwen2/1.5B_full_single_device.yaml b/recipes/configs/qwen2/1.5B_full_single_device.yaml index 950f754b3e..3124389192 100644 --- a/recipes/configs/qwen2/1.5B_full_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_full_single_device.yaml @@ -58,7 +58,7 @@ optimizer: optimizer_in_bwd: True loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 1 diff --git a/recipes/configs/qwen2/1.5B_lora.yaml b/recipes/configs/qwen2/1.5B_lora.yaml index 7cee3f6df4..065bcd28bc 100644 --- a/recipes/configs/qwen2/1.5B_lora.yaml +++ b/recipes/configs/qwen2/1.5B_lora.yaml @@ -58,7 +58,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/qwen2/1.5B_lora_single_device.yaml b/recipes/configs/qwen2/1.5B_lora_single_device.yaml index f3c3c9686a..73a6c569c5 100644 --- a/recipes/configs/qwen2/1.5B_lora_single_device.yaml +++ b/recipes/configs/qwen2/1.5B_lora_single_device.yaml @@ -58,7 +58,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/qwen2/7B_full.yaml b/recipes/configs/qwen2/7B_full.yaml index ce22c4451b..8031600802 100644 --- a/recipes/configs/qwen2/7B_full.yaml +++ b/recipes/configs/qwen2/7B_full.yaml @@ -55,7 +55,7 @@ optimizer: _component_: torch.optim.AdamW lr: 5e-6 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null 
gradient_accumulation_steps: 16 diff --git a/recipes/configs/qwen2/7B_full_single_device.yaml b/recipes/configs/qwen2/7B_full_single_device.yaml index 51c673993e..ba96873838 100644 --- a/recipes/configs/qwen2/7B_full_single_device.yaml +++ b/recipes/configs/qwen2/7B_full_single_device.yaml @@ -58,7 +58,7 @@ optimizer: lr: 5e-6 optimizer_in_bwd: True loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss max_steps_per_epoch: null gradient_accumulation_steps: 16 compile: False diff --git a/recipes/configs/qwen2/7B_lora.yaml b/recipes/configs/qwen2/7B_lora.yaml index 0e3728dbb6..9cd99bb7cc 100644 --- a/recipes/configs/qwen2/7B_lora.yaml +++ b/recipes/configs/qwen2/7B_lora.yaml @@ -64,7 +64,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/configs/qwen2/7B_lora_single_device.yaml b/recipes/configs/qwen2/7B_lora_single_device.yaml index dbbade6503..f2b84e4d61 100644 --- a/recipes/configs/qwen2/7B_lora_single_device.yaml +++ b/recipes/configs/qwen2/7B_lora_single_device.yaml @@ -62,7 +62,7 @@ lr_scheduler: num_warmup_steps: 100 loss: - _component_: torch.nn.CrossEntropyLoss + _component_: torchtune.modules.loss.CEWithChunkedOutputLoss # Training epochs: 1 diff --git a/recipes/dev/lora_finetune_fsdp2.py b/recipes/dev/lora_finetune_fsdp2.py index 10a897e2e9..98e0bff4a5 100644 --- a/recipes/dev/lora_finetune_fsdp2.py +++ b/recipes/dev/lora_finetune_fsdp2.py @@ -231,11 +231,25 @@ def setup(self, cfg: DictConfig) -> None: else None, ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) - if self._model_compile: - log.info("Compiling loss with torch.compile...") - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - self._loss_fn = torch.compile(self._loss_fn, backend=backend) + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._model_compile: + log.info("Compiling loss with torch.compile...") + # For CEWithChunkedOutputLoss, if we compile the entire class + # we lose the benefits from the chunked loss. + # Therefore, we only compile the cross entropy function + upcasting + self._loss_fn.compute_cross_entropy = torch.compile( + self._loss_fn.compute_cross_entropy, backend=backend + ) + else: + if self._model_compile: + log.info("Compiling loss with torch.compile...") + self._loss_fn = torch.compile(self._loss_fn, backend=backend) + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after all of these are setup @@ -270,6 +284,11 @@ def setup(self, cfg: DictConfig) -> None: last_epoch=self.global_step - 1, ) + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + def _setup_model( self, cfg_model: DictConfig, @@ -597,12 +616,20 @@ def train(self) -> None: ) logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. 
We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory del logits diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py index 59e6ed3f8a..bc74c611a5 100644 --- a/recipes/full_finetune_distributed.py +++ b/recipes/full_finetune_distributed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import sys import time @@ -203,6 +204,7 @@ def setup(self, cfg: DictConfig) -> None: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + self._model_compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, @@ -222,7 +224,25 @@ def setup(self, cfg: DictConfig) -> None: else None, ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._model_compile: + log.info("Compiling loss with torch.compile...") + # For CEWithChunkedOutputLoss, if we compile the entire class + # we lose the benefits from the chunked loss. + # Therefore, we only compile the cross entropy function + upcasting + self._loss_fn.compute_cross_entropy = torch.compile( + self._loss_fn.compute_cross_entropy, backend=backend + ) + else: + if self._model_compile: + log.info("Compiling loss with torch.compile...") + self._loss_fn = torch.compile(self._loss_fn, backend=backend) + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized @@ -253,6 +273,11 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False` self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: @@ -601,12 +626,20 @@ def train(self) -> None: ) logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. 
+ labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory del logits diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py index 5c0b1fa6d0..433a04dc20 100644 --- a/recipes/full_finetune_single_device.py +++ b/recipes/full_finetune_single_device.py @@ -222,11 +222,24 @@ def setup(self, cfg: DictConfig) -> None: ), ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) - if self._model_compile: - log.info("Compiling loss with torch.compile...") - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - self._loss_fn = torch.compile(self._loss_fn, backend=backend) + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._model_compile: + log.info("Compiling loss with torch.compile...") + # For CEWithChunkedOutputLoss, if we compile the entire class + # we lose the benefits from the chunked loss. + # Therefore, we only compile the cross entropy function + upcasting + self._loss_fn.compute_cross_entropy = torch.compile( + self._loss_fn.compute_cross_entropy, backend=backend + ) + else: + if self._model_compile: + log.info("Compiling loss with torch.compile...") + self._loss_fn = torch.compile(self._loss_fn, backend=backend) log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be @@ -258,6 +271,11 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False` self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: @@ -485,10 +503,17 @@ def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: input_pos = batch.get("input_pos", None) # shape [b, s] logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. 
+ labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) # free logits otherwise it peaks backward memory diff --git a/recipes/lora_finetune_distributed.py b/recipes/lora_finetune_distributed.py index a1a96d87bd..d1a5f8ffac 100644 --- a/recipes/lora_finetune_distributed.py +++ b/recipes/lora_finetune_distributed.py @@ -224,6 +224,7 @@ def setup(self, cfg: DictConfig) -> None: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + self._model_compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, @@ -243,7 +244,25 @@ def setup(self, cfg: DictConfig) -> None: else None, ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._model_compile: + log.info("Compiling loss with torch.compile...") + # For CEWithChunkedOutputLoss, if we compile the entire class + # we lose the benefits from the chunked loss. + # Therefore, we only compile the cross entropy function + upcasting + self._loss_fn.compute_cross_entropy = torch.compile( + self._loss_fn.compute_cross_entropy, backend=backend + ) + else: + if self._model_compile: + log.info("Compiling loss with torch.compile...") + self._loss_fn = torch.compile(self._loss_fn, backend=backend) + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after all of these are setup @@ -282,6 +301,11 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False` self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: @@ -686,10 +710,16 @@ def train(self) -> None: ) logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. 
+ labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) # free logits otherwise it peaks backward memory diff --git a/recipes/lora_finetune_single_device.py b/recipes/lora_finetune_single_device.py index 3a075d00d6..462cbb69dc 100644 --- a/recipes/lora_finetune_single_device.py +++ b/recipes/lora_finetune_single_device.py @@ -208,6 +208,7 @@ def setup(self, cfg: DictConfig) -> None: self._model_compile = cfg.compile checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + # set up model self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, @@ -230,11 +231,24 @@ def setup(self, cfg: DictConfig) -> None: ), ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) - if self._model_compile: - log.info("Compiling loss with torch.compile...") - backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") - self._loss_fn = torch.compile(self._loss_fn, backend=backend) + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._model_compile: + log.info("Compiling loss with torch.compile...") + # For CEWithChunkedOutputLoss, if we compile the entire class + # we lose the benefits from the chunked loss. + # Therefore, we only compile the cross entropy function + upcasting + self._loss_fn.compute_cross_entropy = torch.compile( + self._loss_fn.compute_cross_entropy, backend=backend + ) + else: + if self._model_compile: + log.info("Compiling loss with torch.compile...") + self._loss_fn = torch.compile(self._loss_fn, backend=backend) log.info("Loss is initialized.") # Dataloader depends on the tokenizer and loss_fn and should be @@ -274,6 +288,11 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: @@ -349,6 +368,7 @@ def _setup_model( base_model_state_dict: Dict[str, Any], lora_weights_state_dict: Optional[Dict[str, Any]] = None, ) -> nn.Module: + with utils.set_default_dtype(self._dtype), self._device: model = config.instantiate(cfg_model) @@ -543,18 +563,28 @@ def save_checkpoint(self, epoch: int) -> None: def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: # Both are shape [b, s] tokens, labels = batch["tokens"], batch["labels"] + # Get the attention mask and position ids from the dataset if they # exist. 
Currently, only sample packing in PackedDataset returns these mask = batch.get("mask", None) # shape [b, s, s] input_pos = batch.get("input_pos", None) # shape [b, s] + # run model logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) + # free logits otherwise it peaks backward memory del logits diff --git a/recipes/qat_distributed.py b/recipes/qat_distributed.py index 07d8b8924c..c4464c9542 100644 --- a/recipes/qat_distributed.py +++ b/recipes/qat_distributed.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import os import sys import time @@ -216,6 +217,7 @@ def setup(self, cfg: DictConfig) -> None: checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) + self._model_compile = cfg.get("compile", False) self._model = self._setup_model( cfg_model=cfg.model, enable_activation_checkpointing=cfg.enable_activation_checkpointing, @@ -236,7 +238,25 @@ def setup(self, cfg: DictConfig) -> None: else None, ) + # initialize loss self._loss_fn = config.instantiate(cfg.loss) + backend = os.environ.get("TORCH_COMPILE_BACKEND", "inductor") + if self._loss_fn.__class__.__name__ == "CEWithChunkedOutputLoss": + # set num_output_chunks for model + self._model.set_num_output_chunks(self._loss_fn.num_output_chunks) + if self._model_compile: + log.info("Compiling loss with torch.compile...") + # For CEWithChunkedOutputLoss, if we compile the entire class + # we lose the benefits from the chunked loss. + # Therefore, we only compile the cross entropy function + upcasting + self._loss_fn.compute_cross_entropy = torch.compile( + self._loss_fn.compute_cross_entropy, backend=backend + ) + else: + if self._model_compile: + log.info("Compiling loss with torch.compile...") + self._loss_fn = torch.compile(self._loss_fn, backend=backend) + log.info("Loss is initialized.") # sampler and dataloader depend on the tokenizer and loss_fn and should be # setup after both of these are initialized @@ -267,6 +287,11 @@ def setup(self, cfg: DictConfig) -> None: # if cfg is missing profiler key or if `cfg.profiler.enabled = False` self._profiler = self._setup_profiler(cfg.get(PROFILER_KEY, None)) + # Used to ignore labels for loss computation + self.ignore_labels_cache = torch.full( + (cfg.batch_size, 1), self._loss_fn.ignore_index, device=self._device + ) + def _setup_profiler( self, cfg_profiler: Optional[DictConfig] = None ) -> Union[torch.profiler.profile, DummyProfiler]: @@ -650,10 +675,17 @@ def train(self) -> None: ) logits = self._model(tokens, mask=mask, input_pos=input_pos) - # Shift so that tokens < n predict n - logits = logits[..., :-1, :].contiguous() - labels = labels[..., 1:].contiguous() - logits = logits.transpose(1, 2) + + # Shift labels to compute loss + # equivalent to doing labels[..., 1:] and logits[..., :-1, :] + # But this way we dont need to slice the logits. 
We just add an ignore index to labels. + labels = torch.hstack( + (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]]) + ) + if not isinstance(logits, list): + labels = labels.reshape(-1) + logits = logits.reshape(-1, logits.size(-1)) + # Compute loss loss = self._loss_fn(logits, labels) # free logits otherwise it peaks backward memory diff --git a/tests/recipes/test_qat_distributed.py b/tests/recipes/test_qat_distributed.py index 6eac534a22..5d4d7069f1 100644 --- a/tests/recipes/test_qat_distributed.py +++ b/tests/recipes/test_qat_distributed.py @@ -95,5 +95,5 @@ def test_loss(self, config, model_type, ckpt_type, tmpdir, monkeypatch): loss_values = get_loss_values_from_metric_logger(log_file) expected_loss_values = self._fetch_expected_loss_values(model_type) torch.testing.assert_close( - loss_values, expected_loss_values, rtol=1e-4, atol=1e-4 + loss_values, expected_loss_values, rtol=1e-3, atol=1e-3 ) diff --git a/tests/torchtune/modules/loss/test_ce_chunked_output_loss.py b/tests/torchtune/modules/loss/test_ce_chunked_output_loss.py new file mode 100644 index 0000000000..47b5596dd0 --- /dev/null +++ b/tests/torchtune/modules/loss/test_ce_chunked_output_loss.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import pytest +import torch +from tests.test_utils import assert_expected +from torchtune.modules.loss import CEWithChunkedOutputLoss +from torchtune.utils.seed import set_seed + + +@pytest.fixture(autouse=True) +def random(): + set_seed(42) + + +class TestCEWithChunkedOutputLoss: + def test_chunked_cross_entropy_loss(self): + # Create a sample input and label + ignore_index = -100 + batch_size = 3 + num_tokens = 50 + vocab_size = 50 + logits = torch.randn(batch_size, num_tokens, vocab_size, dtype=torch.bfloat16) + labels = torch.randint( + 0, vocab_size, (batch_size, num_tokens), dtype=torch.long + ) + + # add random ignore index to random tokens in the label + random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) + labels[random_indices < num_tokens // 5] = ignore_index + + # chunked CE + ce_loss = CEWithChunkedOutputLoss( + num_output_chunks=8, ignore_index=ignore_index + ) + logits_chunks = logits.chunk(ce_loss.num_output_chunks, dim=1) + chunked_loss = ce_loss(logits_chunks, labels) + + # vanilla CE + logits = logits.reshape(-1, logits.size(-1)) + labels = labels.reshape(-1) + standard_loss = torch.nn.functional.cross_entropy( + logits.float(), labels, reduction="mean", ignore_index=ignore_index + ) + + # Assert + assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) diff --git a/torchtune/models/gemma/transformer.py b/torchtune/models/gemma/transformer.py index 293fe2dcc9..ebae362f83 100644 --- a/torchtune/models/gemma/transformer.py +++ b/torchtune/models/gemma/transformer.py @@ -66,6 +66,12 @@ def __init__( self.head_dim = head_dim self.causal_mask = None self.norm_embeddings = norm_embeddings + self.num_output_chunks = 0 + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.num_output_chunks = num_output_chunks def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. 
@@ -158,6 +164,15 @@ def forward( # shape: [b, s, d] h = self.norm(h) - # shape: [b, s, v] - output = F.linear(h, self.tok_embeddings.weight).float() + if self.num_output_chunks > 0: + # shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size + # Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe, + # before calling forward. Upcasting is done inside the loss function. + output = [ + F.linear(chunk, self.tok_embeddings.weight) + for chunk in h.chunk(self.num_output_chunks, dim=1) + ] + else: + # shape: [b, seq_len, out_dim] + output = F.linear(h, self.tok_embeddings.weight).float() return output diff --git a/torchtune/modules/loss/__init__.py b/torchtune/modules/loss/__init__.py new file mode 100644 index 0000000000..ed5ad0be04 --- /dev/null +++ b/torchtune/modules/loss/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .ce_chunked_output_loss import CEWithChunkedOutputLoss + +__all__ = ["CEWithChunkedOutputLoss"] diff --git a/torchtune/modules/loss/ce_chunked_output_loss.py b/torchtune/modules/loss/ce_chunked_output_loss.py new file mode 100644 index 0000000000..eebb255a69 --- /dev/null +++ b/torchtune/modules/loss/ce_chunked_output_loss.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +import torch +import torch.nn.functional as F + + +class CEWithChunkedOutputLoss(torch.nn.Module): + """ + CE with chunked outputs that saves memory by only upcasting one chunk at a time. + + Since the model is trained with bf16, the logits have to be upcast to fp32 before + computing CE, for better accuracy and stability, and this upcasting doubles the memory usage. + Models like Llama3 have a large vocabulary size and therefore produce a large output + tensor of shape (bsz, num_tokens, vocab_size). By chunking on the token level, the cross entropy + is still computed exactly, but only one chunk is upcast at a time, which saves considerable memory. + + The CE and upcasting have to be compiled together for better performance. + When using this class, we recommend applying torch.compile only to the method `compute_cross_entropy`; + the gains from chunking won't be realized if you compile the entire class. + + For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390 + """ + + def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100): + super().__init__() + self.num_output_chunks = num_output_chunks + self.ignore_index = ignore_index + + def compute_cross_entropy( + self, logits: torch.Tensor, labels: torch.Tensor + ) -> torch.Tensor: + """ + Upcast logits to fp32 and compute cross entropy loss. + """ + return F.cross_entropy( + logits.float(), labels, ignore_index=self.ignore_index, reduction="sum" + ) + + def forward(self, logits: List[torch.Tensor], labels: torch.Tensor) -> torch.Tensor: + """ + Args: + logits (List[torch.Tensor]): List of chunked logits of length + ``self.num_output_chunks``, where each chunk has shape + (batch_size, num_tokens / num_output_chunks, vocab_size). + labels (torch.Tensor): Ground truth labels of shape (batch_size, num_tokens). + + Returns: + torch.Tensor: Scalar cross entropy loss, averaged over the non-ignored tokens. 
+ + Example: + >>> loss_fn = CEWithChunkedOutputLoss() + >>> + >>> h = torch.randn(bsz, num_tokens, dim) + >>> output_chunks = [model.output(chunk) for chunk in h.chunk(num_chunks, dim=1)] + >>> + >>> labels = torch.randint(0, vocab_size, (bsz, num_tokens)) + >>> loss = loss_fn(output_chunks, labels) + """ + + total_elements = (labels != self.ignore_index).sum() + + # chunk and reshape labels (bsz, num_tokens) -> [(bsz*num_tokens/num_chunks,)] + labels = [ + target_chunk.reshape(-1) + for target_chunk in labels.chunk(self.num_output_chunks, dim=1) + ] + # reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)] + logits = [ + logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits + ] + + # compute one chunk at a time + total_loss = 0.0 + for logits_chunk, labels_chunk in zip(logits, labels): + total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk) + + return total_loss / total_elements diff --git a/torchtune/modules/transformer.py b/torchtune/modules/transformer.py index f9d9515f61..b9e88bbd05 100644 --- a/torchtune/modules/transformer.py +++ b/torchtune/modules/transformer.py @@ -338,6 +338,12 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.causal_mask = None + self.num_output_chunks = 0 + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.num_output_chunks = num_output_chunks def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. @@ -460,8 +466,16 @@ def forward( # shape: [b, s, d] h = self.norm(h) - # shape: [b, s, out_dim] - out_dim is usually the vocab size - output = self.output(h).float() + if self.num_output_chunks > 0: + # shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size + # Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe, + # before calling forward. Upcasting is done inside the loss function. + output = [ + self.output(chunk) for chunk in h.chunk(self.num_output_chunks, dim=1) + ] + else: + # shape: [b, seq_len, out_dim] + output = self.output(h).float() # Output list if hidden states are requested, otherwise just the output # TODO: always output a list to have a consistent output type @@ -533,6 +547,12 @@ def __init__( self.num_heads = num_heads self.head_dim = head_dim self.causal_mask = None + self.num_output_chunks = 0 + + def set_num_output_chunks(self, num_output_chunks: int) -> None: + """Used to save memory in combination with :class:`~torchtune.modules.loss.CEWithChunkedOutputLoss`. + This should be called before the first forward pass, in the recipe.""" + self.num_output_chunks = num_output_chunks def setup_caches(self, batch_size: int, dtype: torch.dtype) -> None: """Setup key value caches for attention calculation. @@ -654,8 +674,17 @@ def forward( # shape: [b, s, d] h = self.norm(h) - # shape: [b, s, out_dim] - out_dim is usually the vocab size - output = F.linear(h, self.tok_embeddings.weight).float() + if self.num_output_chunks > 0: + # shape: [b, seq_len/num_chunks, out_dim] - out_dim is usually the vocab size + # Used with CEWithChunkedOutputLoss. Need to set num_output_chunks in the recipe, + # before calling forward. Upcasting is done inside the loss function. 
+ output = [ + F.linear(chunk, self.tok_embeddings.weight) + for chunk in h.chunk(self.num_output_chunks, dim=1) + ] + else: + # shape: [b, seq_len, out_dim] + output = F.linear(h, self.tok_embeddings.weight).float() # Output list if hidden states are requested, otherwise just the output # TODO: always output a list to have a consistent output type
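A minimal usage sketch of the chunked loss path introduced above, mirroring the new unit test in tests/torchtune/modules/loss/test_ce_chunked_output_loss.py. It assumes a build of torchtune that includes this patch; the batch size, token count, and vocab size below are illustrative only.

# Compare the chunked CE added in this diff against vanilla cross entropy.
import torch
import torch.nn.functional as F
from torchtune.modules.loss import CEWithChunkedOutputLoss

bsz, num_tokens, vocab_size, ignore_index = 2, 64, 128, -100
logits = torch.randn(bsz, num_tokens, vocab_size, dtype=torch.bfloat16)
labels = torch.randint(0, vocab_size, (bsz, num_tokens))
labels[:, : num_tokens // 5] = ignore_index  # mask some positions, as collators do

loss_fn = CEWithChunkedOutputLoss(num_output_chunks=8, ignore_index=ignore_index)
# In the recipes, model.set_num_output_chunks(8) makes the forward pass emit these
# chunks directly; here we chunk manually along the token dimension.
logits_chunks = list(logits.chunk(loss_fn.num_output_chunks, dim=1))
chunked_loss = loss_fn(logits_chunks, labels)

# Reference: upcast the full logits tensor at once and compute standard CE.
vanilla_loss = F.cross_entropy(
    logits.reshape(-1, vocab_size).float(),
    labels.reshape(-1),
    ignore_index=ignore_index,
)
print(chunked_loss.item(), vanilla_loss.item())  # should agree within bf16 tolerance

The two values match because both average over the non-ignored tokens; the only difference is memory, since the chunked path upcasts one (bsz, num_tokens/8, vocab_size) slice to fp32 at a time instead of the full logits tensor.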