Add support for ZeRO-2/3 and ZeRO-offload in fairscale #10354

Merged · 6 commits · Feb 25, 2021
Changes from 2 commits
4 changes: 2 additions & 2 deletions examples/tests/trainer/test_trainer_ext.py
@@ -89,13 +89,13 @@ def test_run_seq2seq_ddp(self):
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_ddp_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

# test --sharded_ddp w/ --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

@require_apex
def test_run_seq2seq_apex(self):
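For context on the flag change these tests exercise: the boolean form of --sharded_ddp remains accepted and maps onto the new "simple" mode. A minimal sketch, assuming a transformers install that includes this PR (the output_dir is a placeholder):

```python
# Hedged sketch: the legacy boolean and the new string value resolve to the same mode.
from transformers import TrainingArguments

old_style = TrainingArguments(output_dir="/tmp/out", sharded_ddp=True)      # legacy boolean flag
new_style = TrainingArguments(output_dir="/tmp/out", sharded_ddp="simple")  # new string value
# __post_init__ normalizes both to ShardedDDPType.SIMPLE (see trainer_utils.py below).
assert old_style.sharded_ddp == new_style.sharded_ddp
```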
57 changes: 47 additions & 10 deletions src/transformers/trainer.py
@@ -94,6 +94,7 @@
EvalPrediction,
HPSearchBackend,
PredictionOutput,
ShardedDDPType,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
@@ -132,10 +133,16 @@
import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
import fairscale
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

if version.parse(fairscale.__version__) >= version.parse("0.3"):
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
Contributor comment:
I think this may introduce some confusion here: should we stick to DP and not DDP to match the real name, i.e. FullyShardedDP and ShardedDP?
Perhaps change the original flag to reflect that as well? --sharded_dp?

@stas00 (Contributor), Feb 23, 2021:
OK, I made a request to have those names renamed to match DDP here:
facebookresearch/fairscale#413 (comment)

Contributor comment:
Thanks Stas. I personally think the distinction between DDP and DP is not going to matter anymore. Even pytorch DDP itself is moving to remove the "device_ids" argument in the future, so that there isn't support for single-process DP (as opposed to distributed/multiprocess DP). Therefore, I think sticking with FSDP is fine within fairscale.

Contributor comment:
Thank you for your follow-up, @min-xu-ai

else:
FullyShardedDDP = None

if is_sagemaker_distributed_available():
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
@@ -347,16 +354,21 @@ def __init__(

# Setup Sharded DDP training
self.sharded_dpp = False
if args.sharded_ddp:
if args.sharded_ddp != ShardedDDPType.NO:
if args.deepspeed:
raise ValueError(
"Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags."
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)

if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
elif args.sharded_ddp != ShardedDDPType.SIMPLE and FullyShardedDDP is None:
raise ImportError(
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
)
else:
self.sharded_dpp = True

@@ -618,7 +630,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_dpp:
if self.args.sharded_ddp == ShardedDDPType.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
@@ -737,7 +749,18 @@ def _wrap_model(self, model, training=True):

# Distributed training (should be after apex fp16 initialization)
if self.sharded_dpp:
model = ShardedDDP(model, self.optimizer)
# Sharded DDP!
if self.args.sharded_ddp == ShardedDDPType.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
mixed_precision = self.args.fp16
cpu_offload = self.args.sharded_ddp in [ShardedDDPType.ZERO_2_OFFLOAD, ShardedDDPType.ZERO_3_OFFLOAD]
zero_3 = self.args.sharded_ddp in [ShardedDDPType.ZERO_3, ShardedDDPType.ZERO_3_OFFLOAD]
# Breaking the self.model convention but I see no way around it for now.
self.model = FullyShardedDDP(
model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
)

elif is_sagemaker_distributed_available():
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
elif self.args.local_rank != -1:
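As an aside, the zero* branches above reduce to three FullyShardedDDP constructor arguments. A minimal standalone sketch, assuming fairscale >= 0.3, a CUDA device, and an already-initialized torch.distributed process group (this is not the Trainer code):

```python
# Hedged sketch of the wrapping choices made in _wrap_model above.
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP

model = torch.nn.Linear(32, 32).cuda()
# zero2         -> reshard_after_forward=False, cpu_offload=False
# zero2_offload -> reshard_after_forward=False, cpu_offload=True
# zero3         -> reshard_after_forward=True,  cpu_offload=False
# zero3_offload -> reshard_after_forward=True,  cpu_offload=True
model = FullyShardedDDP(
    model,
    mixed_precision=True,        # follows --fp16 in the Trainer
    reshard_after_forward=True,  # ZeRO-3 behavior
    cpu_offload=True,            # ZeRO-offload
)
```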
@@ -854,14 +877,15 @@ def train(
num_train_epochs = 1
num_update_steps_per_epoch = max_steps

delay_optimizer_creation = self.args.sharded_ddp not in [ShardedDDPType.NO, ShardedDDPType.SIMPLE]
if self.args.deepspeed:
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
self.model = model.module
self.model_wrapped = model # will get further wrapped in DDP
self.deepspeed = model # DeepSpeedEngine object
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
else:
elif not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

self.state = TrainerState()
@@ -876,6 +900,9 @@ def train(
if model is not self.model:
self.model_wrapped = model

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
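The delayed creation matters because FullyShardedDDP flattens and shards the parameters it wraps, so an optimizer built beforehand would reference the original, replaced tensors. A minimal sketch of the intended ordering, under the same fairscale/process-group assumptions as above:

```python
# Hedged sketch: wrap first, then build the optimizer from the wrapped parameters.
import torch
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP

model = FullyShardedDDP(torch.nn.Linear(32, 32).cuda())     # wrap (and shard) first ...
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)  # ... then create the optimizer
# For the "simple" mode the order is reversed: OSS is created first and passed to ShardedDDP.
```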
@@ -1025,6 +1052,9 @@ def train(
if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
model.clip_grad_norm_(self.args.max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
@@ -1151,8 +1181,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save.
assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
# want to save except FullyShardedDDP.
# assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"

# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -1176,7 +1206,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
self.deepspeed.save_checkpoint(output_dir)

# Save optimizer and scheduler
if self.sharded_dpp:
if self.args.sharded_ddp == ShardedDDPType.SIMPLE:
self.optimizer.consolidate_state_dict()

if is_torch_tpu_available():
@@ -1537,7 +1567,11 @@ def _save_tpu(self, output_dir: Optional[str] = None):
# They can then be reloaded using `from_pretrained()`
xm.rendezvous("saving_checkpoint")
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if isinstance(_model_unwrap(self.model), PreTrainedModel):
if xm.is_master_ordinal():
_model_unwrap(self.model).config.save_pretrained(output_dir)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
@@ -1552,7 +1586,10 @@ def _save(self, output_dir: Optional[str] = None):
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if isinstance(_model_unwrap(self.model), PreTrainedModel):
_model_unwrap(self.model).config.save_pretrained(output_dir)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
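For reference, the _model_unwrap helper relied on in the save paths above is essentially a recursive .module unwrap; a sketch from memory (details may differ from the actual helper):

```python
# Hedged sketch: wrappers such as DistributedDataParallel and FullyShardedDDP expose the
# inner model as .module, so peeling it off recursively recovers the PreTrainedModel
# whose config can be saved next to the state dict.
import torch.nn as nn

def unwrap_model(model: nn.Module) -> nn.Module:
    return unwrap_model(model.module) if hasattr(model, "module") else model
```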
9 changes: 9 additions & 0 deletions src/transformers/trainer_utils.py
@@ -421,3 +421,12 @@ def stop_and_update_metrics(self, metrics=None):
# init doesn't have metrics to update so we just save that data for later stages to retrieve
if metrics is not None:
self.update_metrics(stage, metrics)


class ShardedDDPType(ExplicitEnum):
NO = "no"
SIMPLE = "simple"
ZERO_2 = "zero2"
ZERO_2_OFFLOAD = "zero2_offload"
ZERO_3 = "zero3"
ZERO_3_OFFLOAD = "zero3_offload"
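A hedged sketch of how these values are consumed, mirroring the _wrap_model logic in this PR (assumes a transformers install that includes it; the helper name is hypothetical):

```python
# Hypothetical helper showing how a ShardedDDPType maps onto fairscale wrapping options.
from transformers.trainer_utils import ShardedDDPType

def fsdp_options(mode: ShardedDDPType, fp16: bool) -> dict:
    cpu_offload = mode in (ShardedDDPType.ZERO_2_OFFLOAD, ShardedDDPType.ZERO_3_OFFLOAD)
    reshard_after_forward = mode in (ShardedDDPType.ZERO_3, ShardedDDPType.ZERO_3_OFFLOAD)
    return {"mixed_precision": fp16, "cpu_offload": cpu_offload, "reshard_after_forward": reshard_after_forward}

print(fsdp_options(ShardedDDPType("zero3_offload"), fp16=True))
# {'mixed_precision': True, 'cpu_offload': True, 'reshard_after_forward': True}
```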
29 changes: 24 additions & 5 deletions src/transformers/training_args.py
@@ -25,7 +25,7 @@
is_torch_tpu_available,
torch_required,
)
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType
from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPType
from .utils import logging


@@ -236,9 +236,24 @@ class TrainingArguments:
When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
step can take a long time) but will not yield the same results as the interrupted training would have.
sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
sharded_ddp (:obj:`bool`, :obj:`str` or :class:`~transformers.trainer_utils.ShardedDDPType`, `optional`, defaults to :obj:`False`):
Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
training only). This is an experimental feature.

Can take up to six values:
Contributor comment (suggested change):
"Can take up to six values:" → "Can be one of the following values:"
- clarifying that it's one of them
- the total count is of no useful value to the user

- :obj:`"no"`: for no sharded DataParallelism (default behavior)
- :obj:`"simple"`: to use first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`) similar
to ZeRO-2.
- :obj:`"zero_2"`: to use the second instance of sharded DPP released by fairscale (:obj:`FullyShardedDDP`)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we are smashing concepts a bit here. ZeRO is a big territory with many features. the 3 stages belong to ZeRO-DP part of ZeRO, so ideally this should be zero_dp_(1|2|3) or zero_dp(1|2|3).

This is just a suggestion though, if you strongly feel having just the number is clear enough, that's OK too.

Copy link
Contributor

@stas00 stas00 Feb 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, and that's why they call it DP and not DDP, because it's ZeRO-DP.

in Zero-2 mode (with :obj:`reshard_after_forward=False`).
- :obj:`"zero_2_offload"`: to use add ZeRO-offload to ZeRO-2.
- :obj:`"zero_3"`: to use the second instance of sharded DPP released by fairscale (:obj:`FullyShardedDDP`)
in Zero-3 mode (with :obj:`reshard_after_forward=True`).
- :obj:`"zero_3_offload"`: to use add ZeRO-offload to ZeRO-3.

If a bool is passed, it will be converted to :obj:`"no"` for :obj:`False` and :obj:`"simple"` for
:obj:`True`.
deepspeed (:obj:`str`, `optional`):
Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
@@ -443,8 +458,8 @@ class TrainingArguments:
"help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
},
)
sharded_ddp: bool = field(
default=False,
sharded_ddp: ShardedDDPType = field(
default="no",
metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
)
deepspeed: Optional[str] = field(
@@ -535,6 +550,10 @@ def __post_init__(self):
"Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
)

if isinstance(self.sharded_ddp, bool):
self.sharded_ddp = "simple" if self.sharded_ddp else "no"
self.sharded_ddp = ShardedDDPType(self.sharded_ddp)

def __repr__(self):
# We override the default repr to remove deprecated arguments from the repr. This method should be removed once
# those deprecated arguments are removed form TrainingArguments. (TODO: v5)
@@ -662,7 +681,7 @@ def parallel_mode(self):

- :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
- :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each ahving its own process (uses
- :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
:obj:`torch.nn.DistributedDataParallel`).
- :obj:`ParallelMode.TPU`: several TPU cores.
"""