Switch to fsdp training #81
Conversation
Signed-off-by: Michael Clifford <[email protected]>
just a small question, otherwise LGTM
@@ -191,7 +191,7 @@ def list_phase1_final_model():
export XDG_CACHE_HOME=/tmp
export HF_HOME=/tmp
export TRANSFORMERS_CACHE=/tmp
- torchrun --nnodes {nnodes} --nproc_per_node {nproc_per_node} --node_rank \$(RANK) --rdzv_endpoint \$(MASTER_ADDR):\$(MASTER_PORT) -m instructlab.training.main_ds --model_name_or_path={path_to_model} --data_path=/input_data/processed_data/data.jsonl --output_dir=/tmp/model --num_epochs=2 --effective_batch_size=3840 --learning_rate=2e-6 --num_warmup_steps=800 --save_samples=0 --log_level=INFO --max_batch_len=20000 --seed=42 --cpu_offload_optimizer --sharding_strategy=FULL_SHARD --is_granite --checkpoint_at_epoch
+ torchrun --nnodes {nnodes} --nproc_per_node {nproc_per_node} --node_rank \$(RANK) --rdzv_endpoint \$(MASTER_ADDR):\$(MASTER_PORT) -m instructlab.training.main_ds --model_name_or_path={path_to_model} --data_path=/input_data/processed_data/data.jsonl --output_dir=/tmp/model --num_epochs=2 --effective_batch_size=3840 --learning_rate=1e-4 --num_warmup_steps=800 --save_samples=0 --log_level=INFO --max_batch_len=20000 --seed=42 --cpu_offload_optimizer --distributed_training_framework fsdp --is_granite --checkpoint_at_epoch
What is the --learning_rate fix about?
This was just fixing a typo where I forgot to update the learning_rate; it should be the same value in both the master and worker nodes' calls to torchrun.
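For illustration, here is a minimal sketch of defining the shared hyperparameters once and rendering the same torchrun command for both the master and worker pod templates, so the values cannot drift apart. The COMMON_TRAINING_ARGS constant and the build_torchrun_cmd helper are hypothetical names, not code from this repository; the flag values are taken from the diff above.

# Hypothetical sketch: define the hyperparameters once and interpolate them into
# both the master and the worker torchrun commands so the values cannot diverge.
COMMON_TRAINING_ARGS = (
    "--num_epochs=2 "
    "--effective_batch_size=3840 "
    "--learning_rate=1e-4 "
    "--num_warmup_steps=800 "
    "--save_samples=0 "
    "--log_level=INFO "
    "--max_batch_len=20000 "
    "--seed=42 "
    "--cpu_offload_optimizer "
    "--distributed_training_framework fsdp "
    "--is_granite "
    "--checkpoint_at_epoch"
)


def build_torchrun_cmd(nnodes: int, nproc_per_node: int, path_to_model: str) -> str:
    """Render the torchrun command shared by the master and worker pod specs."""
    # The backslash-escaping of $(RANK), $(MASTER_ADDR) and $(MASTER_PORT) is
    # kept exactly as in the original command in the diff.
    return (
        f"torchrun --nnodes {nnodes} --nproc_per_node {nproc_per_node} "
        "--node_rank \\$(RANK) --rdzv_endpoint \\$(MASTER_ADDR):\\$(MASTER_PORT) "
        "-m instructlab.training.main_ds "
        f"--model_name_or_path={path_to_model} "
        "--data_path=/input_data/processed_data/data.jsonl "
        "--output_dir=/tmp/model "
        f"{COMMON_TRAINING_ARGS}"
    )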
LGTM 🚀
Related to #51
This PR simply adds the distributed_training_framework parameter to the torchrun call of our PyTorchJob and sets it to fsdp. This requires the latest RHEL AI image (1.2), since it has the recent updates needed for FSDP. This has been tested on the MOC and seems to work as expected.
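As a rough sketch of where the new flag fits: the helper below is hypothetical and not part of the pipeline (this PR simply hard-codes fsdp in the torchrun call), and the assumption that deepspeed is the other accepted value for this option is mine, not stated in the PR.

# Hypothetical sketch: build the framework selector instead of hard-coding it.
def distributed_framework_flag(framework: str = "fsdp") -> str:
    # Assumed set of accepted values; only fsdp is exercised in this PR.
    supported = {"fsdp", "deepspeed"}
    if framework not in supported:
        raise ValueError(f"unsupported distributed training framework: {framework!r}")
    return f"--distributed_training_framework {framework}"


print(distributed_framework_flag())  # --distributed_training_framework fsdp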