-
Notifications
You must be signed in to change notification settings - Fork 27.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add support for ZeRO-2/3 and ZeRO-offload in fairscale #10354
Changes from all commits
2ba8601
3fc784b
442e221
1ea56d0
0f20f1d
ab10ada
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -94,6 +94,7 @@ | |
EvalPrediction, | ||
HPSearchBackend, | ||
PredictionOutput, | ||
ShardedDDPOption, | ||
TrainerMemoryTracker, | ||
TrainOutput, | ||
default_compute_objective, | ||
|
@@ -132,10 +133,16 @@ | |
import torch_xla.distributed.parallel_loader as pl | ||
|
||
if is_fairscale_available(): | ||
import fairscale | ||
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP | ||
from fairscale.optim import OSS | ||
from fairscale.optim.grad_scaler import ShardedGradScaler | ||
|
||
if version.parse(fairscale.__version__) >= version.parse("0.3"): | ||
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this may introduce a confusion here, should we stick to DP and not DDP to match the real name? i.e. FullyShardedDP and ShardedDP? Perhaps change the original flag to reflect that as well? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, made a request to make those names renamed to match DDP here: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks Stas. I personally think the distinction between DDP and DP is not going to matter anymore. Even pytorch DDP itself is moving to remove the "device_ids" argument in the future so that there isn't a support for a single process DP (as opposed to distributed/multiprocess DP). Therefore, I think sticking with FSDP is fine within fairscale. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thank you for your follow up, @min-xu-ai |
||
else: | ||
FullyShardedDDP = None | ||
|
||
if is_sagemaker_distributed_available(): | ||
import smdistributed.dataparallel.torch.distributed as dist | ||
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP | ||
|
@@ -276,9 +283,38 @@ def __init__( | |
else: | ||
self.is_model_parallel = False | ||
|
||
# Setup Sharded DDP training | ||
self.sharded_ddp = None | ||
if len(args.sharded_ddp) > 0: | ||
if args.deepspeed: | ||
raise ValueError( | ||
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags." | ||
) | ||
|
||
if args.local_rank == -1: | ||
raise ValueError("Using sharded DDP only works in distributed training.") | ||
elif not is_fairscale_available(): | ||
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") | ||
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None: | ||
raise ImportError( | ||
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found " | ||
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`." | ||
) | ||
elif ShardedDDPOption.SIMPLE in args.sharded_ddp: | ||
self.sharded_ddp = ShardedDDPOption.SIMPLE | ||
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp: | ||
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2 | ||
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp: | ||
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3 | ||
|
||
# one place to sort out whether to place the model on device or not | ||
self.place_model_on_device = args.place_model_on_device | ||
if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train): | ||
if ( | ||
self.is_model_parallel | ||
or (args.deepspeed and args.do_train) | ||
or (args.fp16_full_eval and not args.do_train) | ||
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3]) | ||
): | ||
self.place_model_on_device = False | ||
|
||
default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer) | ||
|
@@ -345,21 +381,6 @@ def __init__( | |
if isinstance(eval_dataset, datasets.Dataset): | ||
self._remove_unused_columns(self.eval_dataset, description="evaluation") | ||
|
||
# Setup Sharded DDP training | ||
self.sharded_dpp = False | ||
if args.sharded_ddp: | ||
if args.deepspeed: | ||
raise ValueError( | ||
"Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags." | ||
) | ||
|
||
if args.local_rank == -1: | ||
raise ValueError("Using sharded DDP only works in distributed training.") | ||
elif not is_fairscale_available(): | ||
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.") | ||
else: | ||
self.sharded_dpp = True | ||
|
||
# Mixed precision setup | ||
self.use_apex = False | ||
self.use_amp = False | ||
|
@@ -375,7 +396,7 @@ def __init__( | |
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16 | ||
if self.fp16_backend == "amp": | ||
self.use_amp = True | ||
self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler() | ||
self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler() | ||
else: | ||
if not is_apex_available(): | ||
raise ImportError( | ||
|
@@ -618,7 +639,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int): | |
"eps": self.args.adam_epsilon, | ||
} | ||
optimizer_kwargs["lr"] = self.args.learning_rate | ||
if self.sharded_dpp: | ||
if self.sharded_ddp == ShardedDDPOption.SIMPLE: | ||
self.optimizer = OSS( | ||
params=optimizer_grouped_parameters, | ||
optim=optimizer_cls, | ||
|
@@ -736,8 +757,19 @@ def _wrap_model(self, model, training=True): | |
return model | ||
|
||
# Distributed training (should be after apex fp16 initialization) | ||
if self.sharded_dpp: | ||
model = ShardedDDP(model, self.optimizer) | ||
if self.sharded_ddp is not None: | ||
# Sharded DDP! | ||
if self.sharded_ddp == ShardedDDPOption.SIMPLE: | ||
model = ShardedDDP(model, self.optimizer) | ||
else: | ||
mixed_precision = self.args.fp16 | ||
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp | ||
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3 | ||
# XXX: Breaking the self.model convention but I see no way around it for now. | ||
self.model = model = FullyShardedDDP( | ||
model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload | ||
).to(self.args.device) | ||
|
||
elif is_sagemaker_distributed_available(): | ||
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False) | ||
elif self.args.local_rank != -1: | ||
|
@@ -854,14 +886,15 @@ def train( | |
num_train_epochs = 1 | ||
num_update_steps_per_epoch = max_steps | ||
|
||
delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE | ||
if self.args.deepspeed: | ||
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps) | ||
self.model = model.module | ||
self.model_wrapped = model # will get further wrapped in DDP | ||
self.deepspeed = model # DeepSpeedEngine object | ||
self.optimizer = optimizer | ||
self.lr_scheduler = lr_scheduler | ||
else: | ||
elif not delay_optimizer_creation: | ||
self.create_optimizer_and_scheduler(num_training_steps=max_steps) | ||
|
||
self.state = TrainerState() | ||
|
@@ -876,6 +909,9 @@ def train( | |
if model is not self.model: | ||
self.model_wrapped = model | ||
|
||
if delay_optimizer_creation: | ||
self.create_optimizer_and_scheduler(num_training_steps=max_steps) | ||
|
||
# important: at this point: | ||
# self.model is the Transformers Model | ||
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc. | ||
|
@@ -1025,6 +1061,9 @@ def train( | |
if hasattr(self.optimizer, "clip_grad_norm"): | ||
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping | ||
self.optimizer.clip_grad_norm(self.args.max_grad_norm) | ||
elif hasattr(model, "clip_grad_norm_"): | ||
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping | ||
model.clip_grad_norm_(self.args.max_grad_norm) | ||
else: | ||
# Revert to normal clipping otherwise, handling Apex or full precision | ||
torch.nn.utils.clip_grad_norm_( | ||
|
@@ -1151,8 +1190,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch): | |
|
||
def _save_checkpoint(self, model, trial, metrics=None): | ||
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we | ||
# want to save. | ||
assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model" | ||
# want to save except FullyShardedDDP. | ||
# assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model" | ||
|
||
# Save model checkpoint | ||
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" | ||
|
@@ -1176,7 +1215,7 @@ def _save_checkpoint(self, model, trial, metrics=None): | |
self.deepspeed.save_checkpoint(output_dir) | ||
|
||
# Save optimizer and scheduler | ||
if self.sharded_dpp: | ||
if self.sharded_ddp == ShardedDDPOption.SIMPLE: | ||
self.optimizer.consolidate_state_dict() | ||
|
||
if is_torch_tpu_available(): | ||
|
@@ -1537,7 +1576,11 @@ def _save_tpu(self, output_dir: Optional[str] = None): | |
# They can then be reloaded using `from_pretrained()` | ||
xm.rendezvous("saving_checkpoint") | ||
if not isinstance(self.model, PreTrainedModel): | ||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") | ||
if isinstance(_model_unwrap(self.model), PreTrainedModel): | ||
if xm.is_master_ordinal(): | ||
_model_unwrap(self.model).config.save_pretrained(output_dir) | ||
else: | ||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") | ||
state_dict = self.model.state_dict() | ||
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) | ||
else: | ||
|
@@ -1552,7 +1595,10 @@ def _save(self, output_dir: Optional[str] = None): | |
# Save a trained model and configuration using `save_pretrained()`. | ||
# They can then be reloaded using `from_pretrained()` | ||
if not isinstance(self.model, PreTrainedModel): | ||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") | ||
if isinstance(_model_unwrap(self.model), PreTrainedModel): | ||
_model_unwrap(self.model).config.save_pretrained(output_dir) | ||
else: | ||
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.") | ||
state_dict = self.model.state_dict() | ||
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME)) | ||
else: | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Love the API