Add support for ZeRO-2/3 and ZeRO-offload in fairscale #10354

Merged
merged 6 commits into from Feb 25, 2021

Changes from 4 commits
51 changes: 46 additions & 5 deletions docs/source/main_classes/trainer.rst
@@ -241,6 +241,8 @@ provides support for the following features from `the ZeRO paper <https://arxiv.

1. Optimizer State Sharding
2. Gradient Sharding
3. Model Parameters Sharding (new and very experimental)
4. CPU offload (new and very experimental)

You will need at least two GPUs to use this feature.

@@ -255,8 +257,9 @@ To deploy this feature:
or find more details on `FairScale's GitHub page
<https://github.com/facebookresearch/fairscale/#installation>`__.

2. Add ``--sharded_ddp`` to the command line arguments, and make sure you have added the distributed launcher ``-m
torch.distributed.launch --nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
2. To use the first version of Sharded data-parallelism, add ``--sharded_ddp simple`` to the command line arguments,
and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.

For example, here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:

@@ -268,17 +271,55 @@ For example here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
--fp16 --sharded_ddp
--fp16 --sharded_ddp simple

Notes:

- This feature requires distributed training (so multiple GPUs).
- It is not implemented for TPUs.
- It works with ``--fp16`` too, to make things even faster.
- One of the main benefits of enabling ``--sharded_ddp`` is that it uses a lot less GPU memory, so you should be able
to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
- One of the main benefits of enabling ``--sharded_ddp simple`` is that it uses a lot less GPU memory, so you should be
able to use significantly larger batch sizes using the same hardware (e.g. 3x and even bigger) which should lead to
significantly shorter training time.
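
For reference, under the hood the ``simple`` mode corresponds roughly to fairscale's :obj:`OSS` optimizer wrapped in
:obj:`ShardedDataParallel`, along the lines of this hedged sketch (the helper name and hyperparameters are
illustrative, not part of the library):

.. code-block:: python

    import torch
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.optim import OSS

    def wrap_simple(model, lr=3e-3):
        # Assumes torch.distributed has already been initialized by the launcher.
        # OSS shards the optimizer state across ranks; ShardedDDP handles the
        # matching gradient reduction.
        optimizer = OSS(params=model.parameters(), optim=torch.optim.AdamW, lr=lr)
        model = ShardedDDP(model, optimizer)
        return model, optimizer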

3. To use the second version of Sharded data-parallelism, add ``--sharded_ddp zero_dp_2`` or ``--sharded_ddp zero_dp_3``
to the command line arguments, and make sure you have added the distributed launcher ``-m torch.distributed.launch
--nproc_per_node=NUMBER_OF_GPUS_YOU_HAVE`` if you haven't been using it already.
Comment on lines +285 to +287
Member
Love the API

For example, here is how you could use it for ``run_seq2seq.py`` with 2 GPUs:

.. code-block:: bash

python -m torch.distributed.launch --nproc_per_node=2 examples/seq2seq/run_seq2seq.py \
--model_name_or_path t5-small --per_device_train_batch_size 1 \
--output_dir output_dir --overwrite_output_dir \
--do_train --max_train_samples 500 --num_train_epochs 1 \
--dataset_name wmt16 --dataset_config "ro-en" \
--task translation_en_to_ro --source_prefix "translate English to Romanian: " \
--fp16 --sharded_ddp zero_dp_2

:obj:`zero_dp_2` is an optimized version of the simple wrapper, while :obj:`zero_dp_3` fully shards model weights,
gradients and optimizer states.

Both are compatible with adding :obj:`cpu_offload` to enable ZeRO-offload (activate it like this: :obj:`--sharded_ddp
"zero_dp_2 cpu_offload"`).

Notes:

- This feature requires distributed training (so multiple GPUs).
- It is not implemented for TPUs.
- It works with ``--fp16`` too, to make things even faster.
- The additional ``cpu_offload`` option requires ``--fp16``.
- This is an area of active development, so make sure you have a source install of fairscale to use this feature, as
  some bugs you encounter may have been fixed there already.

Known caveats:

- This feature is incompatible with :obj:`--predict_with_generate` in the run_seq2seq script.
- Using :obj:`--sharded_ddp zero_dp_3` requires wrapping each layer of the model in the special container
  :obj:`FullyShardedDataParallel` of fairscale; a rough sketch of such wrapping is shown below. This is not done
  automatically by any of the example scripts of the :class:`~transformers.Trainer`.
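
A rough sketch of what such manual wrapping could look like, assuming the distributed process group has already been
initialized and using a toy model (this is not code from this repository):

.. code-block:: python

    import torch.nn as nn
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

    class ToyModel(nn.Module):
        def __init__(self, hidden_size=1024, num_layers=12):
            super().__init__()
            # Wrap each layer individually so its parameters can be sharded and
            # are only gathered while that layer is running.
            self.layers = nn.ModuleList(
                [FSDP(nn.Linear(hidden_size, hidden_size)) for _ in range(num_layers)]
            )

        def forward(self, x):
            for layer in self.layers:
                x = layer(x)
            return x

    # The outer wrapper shards whatever parameters were left unwrapped.
    model = FSDP(ToyModel())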


DeepSpeed
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
37 changes: 28 additions & 9 deletions examples/tests/trainer/test_trainer_ext.py
@@ -64,12 +64,13 @@ def require_apex(test_case):


class TestTrainerExt(TestCasePlus):
def run_seq2seq_quick(self, distributed=False, extra_args_str=None):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str)
def run_seq2seq_quick(self, distributed=False, extra_args_str=None, eval=True, predict_with_generate=True):
output_dir = self.run_trainer(1, "12", MBART_TINY, 1, distributed, extra_args_str, predict_with_generate)
logs = TrainerState.load_from_json(os.path.join(output_dir, "trainer_state.json")).log_history
eval_metrics = [log for log in logs if "eval_loss" in log.keys()]
first_step_stats = eval_metrics[0]
assert "eval_bleu" in first_step_stats
if predict_with_generate:
assert "eval_bleu" in first_step_stats

@require_torch_non_multi_gpu
def test_run_seq2seq_no_dist(self):
@@ -88,14 +89,28 @@ def test_run_seq2seq_ddp(self):
# test --sharded_ddp w/o --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_ddp_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp")
def test_run_seq2seq_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple")

# test --sharded_ddp w/ --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_ddp_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp --fp16")
def test_run_seq2seq_sharded_ddp_fp16(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp simple --fp16")

# test --sharded_ddp zero2 w/o --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp(self):
self.run_seq2seq_quick(distributed=True, extra_args_str="--sharded_ddp zero2", predict_with_generate=False)

# test --sharded_ddp zero2 w/ --fp16
@require_torch_multi_gpu
@require_fairscale
def test_run_seq2seq_fully_sharded_ddp_fp16(self):
self.run_seq2seq_quick(
distributed=True, extra_args_str="--sharded_ddp zero2 --fp16", predict_with_generate=False
)

@require_apex
def test_run_seq2seq_apex(self):
@@ -131,6 +146,7 @@ def run_trainer(
num_train_epochs: int,
distributed: bool = False,
extra_args_str: str = None,
predict_with_generate: bool = True,
):
data_dir = self.examples_dir / "test_data/wmt_en_ro"
output_dir = self.get_auto_remove_tmp_dir()
@@ -155,7 +171,6 @@ def run_trainer(
--learning_rate 3e-3
--warmup_steps 8
--evaluation_strategy steps
--predict_with_generate
--logging_steps 0
--save_steps {str(eval_steps)}
--eval_steps {str(eval_steps)}
@@ -165,7 +180,11 @@ def run_trainer(
--task translation
--target_lang ro_RO
--source_lang en_XX
""".split()
"""
if predict_with_generate:
args += "--predict_with_generate"

args = args.split()

if extra_args_str is not None:
args.extend(extra_args_str.split())
Expand Down
98 changes: 72 additions & 26 deletions src/transformers/trainer.py
@@ -94,6 +94,7 @@
EvalPrediction,
HPSearchBackend,
PredictionOutput,
ShardedDDPOption,
TrainerMemoryTracker,
TrainOutput,
default_compute_objective,
@@ -132,10 +133,16 @@
import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
import fairscale
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim import OSS
from fairscale.optim.grad_scaler import ShardedGradScaler

if version.parse(fairscale.__version__) >= version.parse("0.3"):
from fairscale.nn.data_parallel import FullyShardedDataParallel as FullyShardedDDP
Contributor
I think this may introduce some confusion here: should we stick to DP and not DDP to match the real name, i.e. FullyShardedDP and ShardedDP?

Perhaps change the original flag to reflect that as well? --sharded_dp?

Contributor
@stas00 stas00 Feb 23, 2021
OK, I made a request to have those names renamed to match DDP here:
facebookresearch/fairscale#413 (comment)


Thanks Stas. I personally think the distinction between DDP and DP is not going to matter anymore. Even pytorch DDP itself is moving to remove the "device_ids" argument in the future, so that there isn't support for single-process DP (as opposed to distributed/multiprocess DP). Therefore, I think sticking with FSDP is fine within fairscale.

Contributor
Thank you for your follow up, @min-xu-ai

else:
FullyShardedDDP = None

if is_sagemaker_distributed_available():
import smdistributed.dataparallel.torch.distributed as dist
from smdistributed.dataparallel.torch.parallel.distributed import DistributedDataParallel as DDP
@@ -276,9 +283,38 @@ def __init__(
else:
self.is_model_parallel = False

# Setup Sharded DDP training
self.sharded_ddp = None
if len(args.sharded_ddp) > 0:
if args.deepspeed:
raise ValueError(
"Using --sharded_ddp xxx together with --deepspeed is not possible, deactivate one of those flags."
)

if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
elif ShardedDDPOption.SIMPLE not in args.sharded_ddp and FullyShardedDDP is None:
raise ImportError(
"Sharded DDP in a mode other than simple training requires fairscale version >= 0.3, found "
f"{fairscale.__version__}. Upgrade your fairscale library: `pip install --upgrade fairscale`."
)
elif ShardedDDPOption.SIMPLE in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.SIMPLE
elif ShardedDDPOption.ZERO_DP_2 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_2
elif ShardedDDPOption.ZERO_DP_3 in args.sharded_ddp:
self.sharded_ddp = ShardedDDPOption.ZERO_DP_3

# one place to sort out whether to place the model on device or not
self.place_model_on_device = args.place_model_on_device
if self.is_model_parallel or (args.deepspeed and args.do_train) or (args.fp16_full_eval and not args.do_train):
if (
self.is_model_parallel
or (args.deepspeed and args.do_train)
or (args.fp16_full_eval and not args.do_train)
or (self.sharded_ddp in [ShardedDDPOption.ZERO_DP_2, ShardedDDPOption.ZERO_DP_3])
):
self.place_model_on_device = False

default_collator = default_data_collator if tokenizer is None else DataCollatorWithPadding(tokenizer)
@@ -345,21 +381,6 @@ def __init__(
if isinstance(eval_dataset, datasets.Dataset):
self._remove_unused_columns(self.eval_dataset, description="evaluation")

# Setup Sharded DDP training
self.sharded_dpp = False
if args.sharded_ddp:
if args.deepspeed:
raise ValueError(
"Using --sharded_ddp together with --deepspeed is not possible, deactivate one of those flags."
)

if args.local_rank == -1:
raise ValueError("Using sharded DDP only works in distributed training.")
elif not is_fairscale_available():
raise ImportError("Sharded DDP training requires fairscale: `pip install fairscale`.")
else:
self.sharded_dpp = True

# Mixed precision setup
self.use_apex = False
self.use_amp = False
@@ -375,7 +396,7 @@ def __init__(
if args.fp16 and not args.deepspeed: # deepspeed manages its own fp16
if self.fp16_backend == "amp":
self.use_amp = True
self.scaler = ShardedGradScaler() if self.sharded_dpp else torch.cuda.amp.GradScaler()
self.scaler = ShardedGradScaler() if self.sharded_ddp is not None else torch.cuda.amp.GradScaler()
else:
if not is_apex_available():
raise ImportError(
@@ -618,7 +639,7 @@ def create_optimizer_and_scheduler(self, num_training_steps: int):
"eps": self.args.adam_epsilon,
}
optimizer_kwargs["lr"] = self.args.learning_rate
if self.sharded_dpp:
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer = OSS(
params=optimizer_grouped_parameters,
optim=optimizer_cls,
@@ -736,8 +757,19 @@ def _wrap_model(self, model, training=True):
return model

# Distributed training (should be after apex fp16 initialization)
if self.sharded_dpp:
model = ShardedDDP(model, self.optimizer)
if self.sharded_ddp is not None:
# Sharded DDP!
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
model = ShardedDDP(model, self.optimizer)
else:
mixed_precision = self.args.fp16
cpu_offload = ShardedDDPOption.OFFLOAD in self.args.sharded_ddp
zero_3 = self.sharded_ddp == ShardedDDPOption.ZERO_DP_3
# Breaking the self.model convention but I see no way around it for now.
self.model = model = FullyShardedDDP(
model, mixed_precision=mixed_precision, reshard_after_forward=zero_3, cpu_offload=cpu_offload
).to(self.args.device)

elif is_sagemaker_distributed_available():
model = DDP(model, device_ids=[dist.get_local_rank()], broadcast_buffers=False)
elif self.args.local_rank != -1:
@@ -854,14 +886,15 @@ def train(
num_train_epochs = 1
num_update_steps_per_epoch = max_steps

delay_optimizer_creation = self.sharded_ddp is not None and self.sharded_ddp != ShardedDDPOption.SIMPLE
if self.args.deepspeed:
model, optimizer, lr_scheduler = init_deepspeed(self, num_training_steps=max_steps)
self.model = model.module
self.model_wrapped = model # will get further wrapped in DDP
self.deepspeed = model # DeepSpeedEngine object
self.optimizer = optimizer
self.lr_scheduler = lr_scheduler
else:
elif not delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

self.state = TrainerState()
@@ -876,6 +909,9 @@ def train(
if model is not self.model:
self.model_wrapped = model

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

# important: at this point:
# self.model is the Transformers Model
# self.model_wrapped is DDP(Transformers Model), Deepspeed(Transformers Model), etc.
@@ -1025,6 +1061,9 @@ def train(
if hasattr(self.optimizer, "clip_grad_norm"):
# Some optimizers (like the sharded optimizer) have a specific way to do gradient clipping
self.optimizer.clip_grad_norm(self.args.max_grad_norm)
elif hasattr(model, "clip_grad_norm_"):
# Some models (like FullyShardedDDP) have a specific way to do gradient clipping
model.clip_grad_norm_(self.args.max_grad_norm)
else:
# Revert to normal clipping otherwise, handling Apex or full precision
torch.nn.utils.clip_grad_norm_(
@@ -1151,8 +1190,8 @@ def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):

def _save_checkpoint(self, model, trial, metrics=None):
# In all cases, including ddp/dp/deepspeed, self.model is always a reference to the model we
# want to save.
assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"
# want to save except FullyShardedDDP.
# assert _model_unwrap(model) is self.model, "internal model should be a reference to self.model"

# Save model checkpoint
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
@@ -1176,7 +1215,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
self.deepspeed.save_checkpoint(output_dir)

# Save optimizer and scheduler
if self.sharded_dpp:
if self.sharded_ddp == ShardedDDPOption.SIMPLE:
self.optimizer.consolidate_state_dict()

if is_torch_tpu_available():
@@ -1537,7 +1576,11 @@ def _save_tpu(self, output_dir: Optional[str] = None):
# They can then be reloaded using `from_pretrained()`
xm.rendezvous("saving_checkpoint")
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if isinstance(_model_unwrap(self.model), PreTrainedModel):
if xm.is_master_ordinal():
_model_unwrap(self.model).config.save_pretrained(output_dir)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
@@ -1552,7 +1595,10 @@ def _save(self, output_dir: Optional[str] = None):
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`
if not isinstance(self.model, PreTrainedModel):
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
if isinstance(_model_unwrap(self.model), PreTrainedModel):
_model_unwrap(self.model).config.save_pretrained(output_dir)
else:
logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
state_dict = self.model.state_dict()
torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
else:
7 changes: 7 additions & 0 deletions src/transformers/trainer_utils.py
@@ -421,3 +421,10 @@ def stop_and_update_metrics(self, metrics=None):
# init doesn't have metrics to update so we just save that data for later stages to retrieve
if metrics is not None:
self.update_metrics(stage, metrics)


class ShardedDDPOption(ExplicitEnum):
SIMPLE = "simple"
ZERO_DP_2 = "zero2"
ZERO_DP_3 = "zero3"
OFFLOAD = "offload"