[refactor] pipe: simplify balance and module checks (#346)
msbaines authored Feb 3, 2021
1 parent cd18644 commit f21b5ff
Showing 3 changed files with 32 additions and 291 deletions.
10 changes: 2 additions & 8 deletions fairscale/nn/pipe/async_pipe.py
@@ -13,7 +13,7 @@
 
 from .async_pipeline import AsyncPipeline
 from .async_schedule import Invocation, Location, ModuleWrapper
-from .multiprocess_pipe import MultiProcessPipe, check_balance
+from .multiprocess_pipe import MultiProcessPipe
 from .skip.skippable import Skippable
 from .types import LazyModule
 
@@ -54,14 +54,8 @@ def create_pipeline(self) -> None:
         )
 
     def instantiate_partition(
-        self,
-        module: Union[nn.Sequential, List[LazyModule]],
-        balance: Iterable[int],
-        group: torch.distributed.ProcessGroup,
+        self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
     ) -> List[ModuleWrapper]:
-        balance = list(balance)
-        check_balance(module, balance, True)
-
         layers: NamedModules = OrderedDict()
 
         def maybe_realize(layer: Any) -> nn.Module:
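Note: with this re-check deleted, MultiProcessPipe.__init__ becomes the single place where balance is normalized and validated (see multiprocess_pipe.py below). A minimal standalone sketch of that invariant, using hypothetical names rather than fairscale API:

    # Hypothetical sketch (not fairscale code) of the invariant this commit
    # establishes: balance is normalized and checked once in __init__, so
    # helpers such as instantiate_partition consume a clean List[int].
    from typing import List

    import torch.nn as nn

    class TinyPipe:
        def __init__(self, module: nn.Sequential, balance: List[int]) -> None:
            if len(module) != sum(balance):
                raise ValueError("module and sum of balance have different length")
            if any(x <= 0 for x in balance):
                raise ValueError("all balance numbers must be positive integer")
            self.module, self.balance = module, balance

        def instantiate_partition(self) -> List[nn.Sequential]:
            # no re-validation here: __init__ already guaranteed the invariant
            parts, j = [], 0
            for n in self.balance:
                parts.append(nn.Sequential(*list(self.module.children())[j : j + n]))
                j += n
            return parts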
132 changes: 27 additions & 105 deletions fairscale/nn/pipe/multiprocess_pipe.py
@@ -53,85 +53,22 @@
 NamedModules = OrderedDict
 
 
-def recommend_auto_balance(message: str) -> str:
-    """Expands a message with recommendation to :mod:`torchpipe.balance`."""
-    return f"""{message}
-
-If your model is still under development, its optimal balance would change
-frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
-naive automatic balancing:
-
-  from fairscale.nn import Pipe
-  from fairscale.nn.pipe.balance import balance_by_time
-
-  partitions = torch.cuda.device_count()
-  sample = torch.empty(...)
-  balance = balance_by_time(partitions, model, sample)
-
-  model = MultiProcessPipe(model, balance, ...)
-"""
-
-
-# FIXME(tom) make this a valid way to call
-def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
-    for layer in module:
-        if isinstance(layer, nn.Module):
-            pass
-        elif isinstance(layer, LazyModule):
-            pass
-        else:
-            raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
-
-
 def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
-    if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
-        verify_list_of_callable(module)
-    else:
-        if not isinstance(module, nn.Sequential):
-            raise TypeError("module must be nn.Sequential to be partitioned")
     if len(set(map(id, module))) != len(module):
         raise ValueError("module with duplicate children is not supported")
-
-        named_children = list(module.named_children())
-        if len(named_children) != len(module):
-            raise ValueError("module with duplicate children is not supported")
-
-
-def verify_splitting(module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int],) -> None:
-    num_parameters = len(list(module.parameters()))
-    num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
-    if num_parameters == num_child_parameters:
-        return
-
-    for i in range(len(partitions)):
-        for j in range(i + 1, len(partitions)):
-            parti = partitions[i]
-            partj = partitions[j]
-            for p in parti.parameters():
-                for q in partj.parameters():
-                    if p is q:
-                        raise ValueError("module with duplicate parameters on distinct devices is not supported")
-
-
-class BalanceError(ValueError):
-    pass
-
-
-def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
-
-    if filter_unique:
-        module_len = len(set(map(id, module)))
-    else:
-        module_len = len(module)
-
-    if module_len != sum(balance):
-        raise BalanceError(
+def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[int]) -> None:
+    if len(module) != sum(balance):
+        raise ValueError(
             f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
         )
 
     if any(x <= 0 for x in balance):
-        raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
+        raise ValueError(f"all balance numbers must be positive integer (balance: {balance})")
 
 
-def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]:
+def split_module(module: nn.Sequential, balance: List[int]) -> List[nn.Sequential]:
     """Splits a module into multiple partitions.
 
     Returns:
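Note: the simplified check_balance drops BalanceError and filter_unique and raises plain ValueError. A quick behavioral sketch (assuming fairscale at this commit; the toy model is illustrative):

    import torch.nn as nn
    from fairscale.nn.pipe.multiprocess_pipe import check_balance

    model = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
    for balance in ([2, 1], [2, 2], [4, -1]):
        try:
            check_balance(model, balance)
            print(balance, "ok")              # [2, 1]: 2 + 1 == len(model)
        except ValueError as exc:
            print(balance, "rejected:", exc)  # mismatched sum / non-positive entry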
@@ -148,10 +85,6 @@ def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequ
     the number of devices is fewer than the number of partitions.
     """
-    balance = list(balance)
-
-    check_balance(module, balance)
-
     j = 0
     partitions = []
     layers: NamedModules = OrderedDict()
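Note: split_module now trusts its input, so validation is the caller's job. A hedged usage sketch (toy model; fairscale at this commit):

    import torch.nn as nn
    from fairscale.nn.pipe.multiprocess_pipe import check_balance, split_module

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    balance = [2, 1]
    check_balance(model, balance)  # callers validate; split_module no longer does
    partitions = split_module(model, balance)
    assert len(partitions) == 2 and len(partitions[0]) == 2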
@@ -274,7 +207,7 @@ class MultiProcessPipe(Module):
     def __init__(
         self,
         module: Union[nn.Sequential, List[LazyModule]],
-        balance: Optional[Iterable[int]] = None,
+        balance: Iterable[int],
         *,
         group: Optional[torch.distributed.ProcessGroup] = None,
         worker_map: Optional[Dict[int, str]] = None,
@@ -290,14 +223,14 @@ def __init__(
         chunks = int(chunks)
         checkpoint = str(checkpoint)
 
-        if balance is None:
-            raise ValueError(recommend_auto_balance("balance is required"))
         if chunks <= 0:
             raise ValueError("number of chunks must be positive integer")
         if checkpoint not in ["always", "except_last", "never"]:
             raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
 
+        self.balance = list(balance)
         verify_module(module)
+        check_balance(module, self.balance)
 
         # Verify if the underlying skippable modules satisfy integrity. The
         # integrity can be verified before forward() because it is static.
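Note: balance is now required and checked up front, so a bad partitioning fails at construction. A sketch, assuming torch.distributed is already initialized with enough ranks and that MultiProcessPipe is exported from fairscale.nn.pipe:

    import torch.nn as nn
    from fairscale.nn.pipe import MultiProcessPipe

    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    pipe = MultiProcessPipe(model, balance=[2, 1], chunks=4)  # ok: sum == len(model)
    # MultiProcessPipe(model, balance=[2, 2]) would raise ValueError here, and
    # omitting balance entirely is now a TypeError, not a BalanceError hint.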
@@ -320,33 +253,28 @@ def __init__(
         else:
             self.group = group
 
-        self.balance = list(balance)
-
         if self.group.size() < len(self.balance):
             raise IndexError(
                 f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                 f" {len(self.balance)})"
             )
-        try:
-            rank = self.group.rank()
-            if rank >= len(self.balance):
-                warnings.warn("More ranks than partitions, some ranks unused")
-                self.partitions: List[ModuleWrapper] = []
-            else:
-                self.partitions = self.instantiate_partition(module, balance, self.group)
-                if deferred_batch_norm:
-                    for part in self.partitions:
-                        part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
-                for name, part in enumerate(self.partitions):
-                    self.add_module(str(name), part.module)
-            if isinstance(module, nn.Sequential):
-                local_partitions = split_module(module, balance)
-                self._skip_layout = inspect_skip_layout(local_partitions)
-            else:
-                self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
-
-        except BalanceError as exc:
-            raise ValueError(recommend_auto_balance(str(exc)))
+
+        rank = self.group.rank()
+        if rank >= len(self.balance):
+            warnings.warn("More ranks than partitions, some ranks unused")
+            self.partitions: List[ModuleWrapper] = []
+        else:
+            self.partitions = self.instantiate_partition(module, self.balance, self.group)
+            if deferred_batch_norm:
+                for part in self.partitions:
+                    part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
+            for name, part in enumerate(self.partitions):
+                self.add_module(str(name), part.module)
+        if isinstance(module, nn.Sequential):
+            local_partitions = split_module(module, self.balance)
+            self._skip_layout = inspect_skip_layout(local_partitions)
+        else:
+            self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
 
         rank = self.group.rank()
         if rank >= len(self.balance):
@@ -378,14 +306,8 @@ def create_pipeline(self) -> None:
         )
 
     def instantiate_partition(
-        self,
-        module: Union[nn.Sequential, List[LazyModule]],
-        balance: Iterable[int],
-        group: torch.distributed.ProcessGroup,
+        self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
     ) -> List[ModuleWrapper]:
-        balance = list(balance)
-        check_balance(module, balance, True)
-
         layers: NamedModules = OrderedDict()
 
         def maybe_realize(layer: Any) -> nn.Module:
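Note: with recommend_auto_balance deleted, constructor errors no longer embed the profiling recipe, but fairscale.nn.pipe.balance is untouched. The deleted docstring's own recipe, reproduced for reference (the torch.empty(...) shape and trailing arguments are elided placeholders the caller must fill in):

    from fairscale.nn import Pipe
    from fairscale.nn.pipe.balance import balance_by_time

    partitions = torch.cuda.device_count()
    sample = torch.empty(...)
    balance = balance_by_time(partitions, model, sample)

    model = MultiProcessPipe(model, balance, ...)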
[diff for the third changed file not shown]