Add support for ZeRO-2/3 and ZeRO-offload in fairscale #10354
```diff
@@ -25,7 +25,7 @@
     is_torch_tpu_available,
     torch_required,
 )
-from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType
+from .trainer_utils import EvaluationStrategy, LoggingStrategy, SchedulerType, ShardedDDPOption
 from .utils import logging
```
```diff
@@ -236,9 +236,22 @@ class TrainingArguments:
            When resuming training, whether or not to skip the epochs and batches to get the data loading at the same
            stage as in the previous training. If set to :obj:`True`, the training will begin faster (as that skipping
            step can take a long time) but will not yield the same results as the interrupted training would have.
-        sharded_ddp (:obj:`bool`, `optional`, defaults to :obj:`False`):
+        sharded_ddp (:obj:`bool`, :obj:`str` or list of :class:`~transformers.trainer_utils.ShardedDDPOption`, `optional`, defaults to :obj:`False`):
            Use Sharded DDP training from `FairScale <https://github.com/facebookresearch/fairscale>`__ (in distributed
            training only). This is an experimental feature.
+
+            A list of options along the following:
+
+            - :obj:`"simple"`: to use first instance of sharded DDP released by fairscale (:obj:`ShardedDDP`) similar
+              to ZeRO-2.
+            - :obj:`"zero_dp_2"`: to use the second instance of sharded DDP released by fairscale
+              (:obj:`FullyShardedDDP`) in Zero-2 mode (with :obj:`reshard_after_forward=False`).
+            - :obj:`"zero_dp_3"`: to use the second instance of sharded DDP released by fairscale
+              (:obj:`FullyShardedDDP`) in Zero-3 mode (with :obj:`reshard_after_forward=True`).
+            - :obj:`"offload"`: to add ZeRO-offload (only compatible with :obj:`"zero_dp_2"` and :obj:`"zero_dp_3"`).
+
+            If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty
+            list for :obj:`False` and :obj:`["simple"]` for :obj:`True`.
         deepspeed (:obj:`str`, `optional`):
            Use `Deepspeed <https://github.com/microsoft/deepspeed>`__. This is an experimental feature and its API may
            evolve in the future. The value is the location of its json config file (usually ``ds_config.json``).
```
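Since the docstring above says the string form is split on spaces, two options can be combined in a single value. Here is a minimal usage sketch (the output directory, script name, and launcher command are placeholders for illustration, not part of this PR):

```python
from transformers import TrainingArguments

# Hypothetical example: ZeRO-2 style sharding via fairscale plus CPU offload.
# The value is split on whitespace, so two options go into one string.
args = TrainingArguments(
    output_dir="out",                 # placeholder
    sharded_ddp="zero_dp_2 offload",  # combine two options in one value
)

# Roughly equivalent command-line usage (distributed training only), e.g.:
#   python -m torch.distributed.launch --nproc_per_node=8 your_script.py \
#       --sharded_ddp "zero_dp_2 offload" ...
```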
```diff
@@ -443,8 +456,8 @@ class TrainingArguments:
            "help": "When resuming training, whether or not to skip the first epochs and batches to get to the same training data."
        },
    )
-    sharded_ddp: bool = field(
-        default=False,
+    sharded_ddp: str = field(
+        default="",
        metadata={"help": "Whether or not to use sharded DDP training (in distributed training only)."},
    )
    deepspeed: Optional[str] = field(
```

Review comments on this field:

"perhaps list the choices here? and perhaps a very small example of combining 2 of them in the value, since it's not a usual pattern - a user might struggle here."

"I agree with @stas00!"
```diff
@@ -535,6 +548,20 @@ def __post_init__(self):
                "Both warmup_ratio and warmup_steps given, warmup_steps will override any effect of warmup_ratio during training"
            )

+        if isinstance(self.sharded_ddp, bool):
+            self.sharded_ddp = "simple" if self.sharded_ddp else ""
+        if isinstance(self.sharded_ddp, str):
+            self.sharded_ddp = [ShardedDDPOption(s) for s in self.sharded_ddp.split()]
+        if self.sharded_ddp == [ShardedDDPOption.OFFLOAD]:
+            raise ValueError(
+                "`--sharded_ddp offload` can't work on its own. It needs to be added to `--sharded_ddp zero_dp_2` or "
+                "`--sharded_ddp zero_dp_3`"
+            )
+        elif len(self.sharded_ddp) > 1 and ShardedDDPOption.SIMPLE in self.sharded_ddp:
+            raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
+        elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp:
+            raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
+
    def __repr__(self):
        # We override the default repr to remove deprecated arguments from the repr. This method should be removed once
        # those deprecated arguments are removed form TrainingArguments. (TODO: v5)
```

(sgugger marked an inline conversation on these lines as resolved.)
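To make the accepted combinations concrete, here is a small self-contained sketch that mirrors the validation above outside of `TrainingArguments` (the enum below is a stand-in for `transformers.trainer_utils.ShardedDDPOption`, whose exact definition is not shown in this diff):

```python
from enum import Enum


class ShardedDDPOption(Enum):
    # Stand-in enum; values match the strings documented in the docstring above.
    SIMPLE = "simple"
    ZERO_DP_2 = "zero_dp_2"
    ZERO_DP_3 = "zero_dp_3"
    OFFLOAD = "offload"


def normalize_sharded_ddp(value):
    """Simplified mirror of the __post_init__ checks in the diff above."""
    if isinstance(value, bool):
        value = "simple" if value else ""
    if isinstance(value, str):
        value = [ShardedDDPOption(s) for s in value.split()]
    if value == [ShardedDDPOption.OFFLOAD]:
        raise ValueError("`--sharded_ddp offload` can't work on its own.")
    elif len(value) > 1 and ShardedDDPOption.SIMPLE in value:
        raise ValueError("`--sharded_ddp simple` is not compatible with any other option.")
    elif ShardedDDPOption.ZERO_DP_2 in value and ShardedDDPOption.ZERO_DP_3 in value:
        raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.")
    return value


print(normalize_sharded_ddp(True))                 # [<ShardedDDPOption.SIMPLE: 'simple'>]
print(normalize_sharded_ddp("zero_dp_3 offload"))  # [ZERO_DP_3, OFFLOAD]
# normalize_sharded_ddp("offload")                 # would raise ValueError
```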
```diff
@@ -662,7 +689,7 @@ def parallel_mode(self):

        - :obj:`ParallelMode.NOT_PARALLEL`: no parallelism (CPU or one GPU).
        - :obj:`ParallelMode.NOT_DISTRIBUTED`: several GPUs in one single process (uses :obj:`torch.nn.DataParallel`).
-        - :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each ahving its own process (uses
+        - :obj:`ParallelMode.DISTRIBUTED`: several GPUs, each having its own process (uses
          :obj:`torch.nn.DistributedDataParallel`).
        - :obj:`ParallelMode.TPU`: several TPU cores.
        """
```
```diff
@@ -692,6 +719,8 @@ def to_dict(self):
        for k, v in d.items():
            if isinstance(v, Enum):
                d[k] = v.value
+            if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
+                d[k] = [x.value for x in v]
        return d

    def to_json_string(self):
```
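The two added lines in `to_dict` make the new list-of-enum value serializable in the same way single enum values already were. A tiny illustration of that normalization, using a stand-in enum rather than the real `TrainingArguments`:

```python
import json
from enum import Enum


class Option(Enum):  # stand-in for ShardedDDPOption
    ZERO_DP_2 = "zero_dp_2"
    OFFLOAD = "offload"


d = {"sharded_ddp": [Option.ZERO_DP_2, Option.OFFLOAD], "seed": 42}

# Same normalization as the added lines: a list of Enum members is replaced
# by the list of their underlying values so json.dumps can handle it.
for k, v in d.items():
    if isinstance(v, Enum):
        d[k] = v.value
    if isinstance(v, list) and len(v) > 0 and isinstance(v[0], Enum):
        d[k] = [x.value for x in v]

print(json.dumps(d))  # {"sharded_ddp": ["zero_dp_2", "offload"], "seed": 42}
```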
Review discussion on the option names:

"I think this may introduce a confusion here: should we stick to DP and not DDP to match the real name, i.e. FullyShardedDP and ShardedDP? Perhaps change the original flag to reflect that as well, e.g. `--sharded_dp`?"

"OK, made a request to have those names renamed to match DDP here: facebookresearch/fairscale#413 (comment)"

"Thanks Stas. I personally think the distinction between DDP and DP is not going to matter anymore. Even pytorch DDP itself is moving to remove the `device_ids` argument in the future, so that there isn't support for a single-process DP (as opposed to distributed/multiprocess DP). Therefore, I think sticking with FSDP is fine within fairscale."

"Thank you for your follow up, @min-xu-ai"