diff --git a/fairscale/nn/pipe/async_pipe.py b/fairscale/nn/pipe/async_pipe.py
index 2e1868650..1508f6781 100644
--- a/fairscale/nn/pipe/async_pipe.py
+++ b/fairscale/nn/pipe/async_pipe.py
@@ -13,7 +13,7 @@
 
 from .async_pipeline import AsyncPipeline
 from .async_schedule import Invocation, Location, ModuleWrapper
-from .multiprocess_pipe import MultiProcessPipe, check_balance
+from .multiprocess_pipe import MultiProcessPipe
 from .skip.skippable import Skippable
 from .types import LazyModule
 
@@ -54,14 +54,8 @@ def create_pipeline(self) -> None:
         )
 
     def instantiate_partition(
-        self,
-        module: Union[nn.Sequential, List[LazyModule]],
-        balance: Iterable[int],
-        group: torch.distributed.ProcessGroup,
+        self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
     ) -> List[ModuleWrapper]:
-        balance = list(balance)
-        check_balance(module, balance, True)
-
         layers: NamedModules = OrderedDict()
 
         def maybe_realize(layer: Any) -> nn.Module:
diff --git a/fairscale/nn/pipe/multiprocess_pipe.py b/fairscale/nn/pipe/multiprocess_pipe.py
index 289a0c5ad..0c105a438 100644
--- a/fairscale/nn/pipe/multiprocess_pipe.py
+++ b/fairscale/nn/pipe/multiprocess_pipe.py
@@ -53,85 +53,22 @@
     NamedModules = OrderedDict
 
 
-def recommend_auto_balance(message: str) -> str:
-    """Expands a message with recommendation to :mod:`torchpipe.balance`."""
-    return f"""{message}
-
-If your model is still under development, its optimal balance would change
-frequently. In this case, we highly recommend 'fairscale.nn.pipe.balance' for
-naive automatic balancing:
-
-  from fairscale.nn import Pipe
-  from fairscale.nn.pipe.balance import balance_by_time
-
-  partitions = torch.cuda.device_count()
-  sample = torch.empty(...)
-  balance = balance_by_time(partitions, model, sample)
-
-  model = MultiProcessPipe(model, balance, ...)
-"""
-
-
-# FIXME(tom) make this a valid way to call
-def verify_list_of_callable(module: Union[nn.Sequential, list]) -> None:
-    for layer in module:
-        if isinstance(layer, nn.Module):
-            pass
-        elif isinstance(layer, LazyModule):
-            pass
-        else:
-            raise TypeError(f"layer {type(layer)} must be nn.Module or LazyModule to be partitioned")
-
-
 def verify_module(module: Union[nn.Sequential, List[LazyModule]]) -> None:
-    if isinstance(module, Iterable) and not isinstance(module, nn.Sequential):
-        verify_list_of_callable(module)
-    else:
-        if not isinstance(module, nn.Sequential):
-            raise TypeError("module must be nn.Sequential to be partitioned")
+    if len(set(map(id, module))) != len(module):
+        raise ValueError("module with duplicate children is not supported")
 
-        named_children = list(module.named_children())
-        if len(named_children) != len(module):
-            raise ValueError("module with duplicate children is not supported")
-
 
-def verify_splitting(module: nn.Sequential, partitions: List[nn.Sequential], balance: Iterable[int],) -> None:
-    num_parameters = len(list(module.parameters()))
-    num_child_parameters = sum(len(list(child.parameters())) for child in module.children())
-    if num_parameters == num_child_parameters:
-        return
-
-    for i in range(len(partitions)):
-        for j in range(i + 1, len(partitions)):
-            parti = partitions[i]
-            partj = partitions[j]
-            for p in parti.parameters():
-                for q in partj.parameters():
-                    if p is q:
-                        raise ValueError("module with duplicate parameters on distinct devices is not supported")
-
-
-class BalanceError(ValueError):
-    pass
-
-
-def check_balance(module: Any, balance: Iterable[int], filter_unique: bool = False) -> None:
-
-    if filter_unique:
-        module_len = len(set(map(id, module)))
-    else:
-        module_len = len(module)
-
-    if module_len != sum(balance):
-        raise BalanceError(
+def check_balance(module: Union[nn.Sequential, List[LazyModule]], balance: List[int]) -> None:
+    if len(module) != sum(balance):
+        raise ValueError(
             f"module and sum of balance have different length (module: {len(module)}, sum of balance: {sum(balance)})"
         )
 
     if any(x <= 0 for x in balance):
-        raise BalanceError(f"all balance numbers must be positive integer (balance: {balance})")
+        raise ValueError(f"all balance numbers must be positive integer (balance: {balance})")
 
 
-def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequential]:
+def split_module(module: nn.Sequential, balance: List[int]) -> List[nn.Sequential]:
     """Splits a module into multiple partitions.
 
     Returns:
@@ -148,10 +85,6 @@ def split_module(module: nn.Sequential, balance: Iterable[int],) -> List[nn.Sequ
             the number of devices is fewer than the number of partitions.
 
     """
-    balance = list(balance)
-
-    check_balance(module, balance)
-
     j = 0
     partitions = []
     layers: NamedModules = OrderedDict()
@@ -274,7 +207,7 @@ class MultiProcessPipe(Module):
     def __init__(
         self,
         module: Union[nn.Sequential, List[LazyModule]],
-        balance: Optional[Iterable[int]] = None,
+        balance: Iterable[int],
         *,
         group: Optional[torch.distributed.ProcessGroup] = None,
         worker_map: Optional[Dict[int, str]] = None,
@@ -290,14 +223,14 @@ def __init__(
         chunks = int(chunks)
         checkpoint = str(checkpoint)
 
-        if balance is None:
-            raise ValueError(recommend_auto_balance("balance is required"))
         if chunks <= 0:
             raise ValueError("number of chunks must be positive integer")
         if checkpoint not in ["always", "except_last", "never"]:
             raise ValueError("checkpoint is not one of 'always', 'except_last', or 'never'")
 
+        self.balance = list(balance)
         verify_module(module)
+        check_balance(module, self.balance)
 
         # Verify if the underlying skippable modules satisfy integrity. The
         # integrity can be verified before forward() because it is static.
@@ -320,33 +253,28 @@ def __init__(
         else:
             self.group = group
 
-        self.balance = list(balance)
-
         if self.group.size() < len(self.balance):
             raise IndexError(
                 f"too few ranks to hold given partitions (ranks: {self.group.size()}, partitions:"
                 f" {len(self.balance)})"
             )
 
-        try:
-            rank = self.group.rank()
-            if rank >= len(self.balance):
-                warnings.warn("More ranks than partitions, some ranks unused")
-                self.partitions: List[ModuleWrapper] = []
-            else:
-                self.partitions = self.instantiate_partition(module, balance, self.group)
-                if deferred_batch_norm:
-                    for part in self.partitions:
-                        part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
-                for name, part in enumerate(self.partitions):
-                    self.add_module(str(name), part.module)
-            if isinstance(module, nn.Sequential):
-                local_partitions = split_module(module, balance)
-                self._skip_layout = inspect_skip_layout(local_partitions)
-            else:
-                self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
-        except BalanceError as exc:
-            raise ValueError(recommend_auto_balance(str(exc)))
+        rank = self.group.rank()
+        if rank >= len(self.balance):
+            warnings.warn("More ranks than partitions, some ranks unused")
+            self.partitions: List[ModuleWrapper] = []
+        else:
+            self.partitions = self.instantiate_partition(module, self.balance, self.group)
+            if deferred_batch_norm:
+                for part in self.partitions:
+                    part.module = DeferredBatchNorm.convert_deferred_batch_norm(part.module, chunks)
+            for name, part in enumerate(self.partitions):
+                self.add_module(str(name), part.module)
+        if isinstance(module, nn.Sequential):
+            local_partitions = split_module(module, self.balance)
+            self._skip_layout = inspect_skip_layout(local_partitions)
+        else:
+            self._skip_layout = SkipLayout(len(module), {})  # FIXME(tom)
 
         rank = self.group.rank()
         if rank >= len(self.balance):
@@ -378,14 +306,8 @@ def create_pipeline(self) -> None:
         )
 
     def instantiate_partition(
-        self,
-        module: Union[nn.Sequential, List[LazyModule]],
-        balance: Iterable[int],
-        group: torch.distributed.ProcessGroup,
+        self, module: Union[nn.Sequential, List[LazyModule]], balance: List[int], group: torch.distributed.ProcessGroup,
     ) -> List[ModuleWrapper]:
-        balance = list(balance)
-        check_balance(module, balance, True)
-
         layers: NamedModules = OrderedDict()
 
         def maybe_realize(layer: Any) -> nn.Module:
diff --git a/tests/nn/pipe_process/test_pipe.py b/tests/nn/pipe_process/test_pipe.py
index 13b8a636f..4952b00aa 100644
--- a/tests/nn/pipe_process/test_pipe.py
+++ b/tests/nn/pipe_process/test_pipe.py
@@ -32,7 +32,7 @@
     initialize_model_parallel,
 )
 from fairscale.nn.pipe import AsyncPipe, LazyModule, MultiProcessPipe
-from fairscale.utils.testing import get_worker_map, set_random_seed, torch_spawn, torch_version
+from fairscale.utils.testing import get_worker_map, torch_spawn, torch_version
 
 
 @torch_spawn([2])
@@ -706,15 +706,11 @@ def named_children(pipe_class):
 @torch_spawn([1])
 @pytest.mark.parametrize("pipe_class", [MultiProcessPipe, AsyncPipe])
 def recommend_auto_balance(pipe_class):
-    with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
-        # balance is required
-        pipe_class(nn.Sequential())
-
-    with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
+    with pytest.raises(ValueError):
         # module and sum of balance have differen length (module: 0, sum of balance: 1)
         pipe_class(nn.Sequential(), [1])
 
-    with pytest.raises(ValueError, match="fairscale.nn.pipe.balance"):
+    with pytest.raises(ValueError):
         # module and sum of balance have different length (module: 2, sum of balance: 1)
         pipe_class(nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1)), [1])
 
@@ -805,174 +801,3 @@ def async_event_loop():
     if pipe.final_stage:
         loss = output.mean()
         loss.backward()
-
-
-@torch_spawn([4])
-def reuse_lazy():
-    if False:  # speed
-        reused = LazyModule(lambda: nn.Linear(10, 10))
-        model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-        # model = [reused, reused, nn.Linear(10, 10), nn.ReLU(), reused, reused, nn.ReLU(), reused, reused, nn.ReLU()]
-        pipe = AsyncPipe(model, [3, 1, 1], worker_map=get_worker_map())
-        pipe.eval()
-        output = pipe(torch.rand(10))
-
-        print(f"output on {pipe.group.rank()}, {output}")
-        torch.distributed.barrier()
-
-    set_random_seed(1234)
-    # test both foward
-    reused = nn.Linear(10, 10)
-    layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    model = nn.Sequential(*layers)
-    model.eval()
-
-    set_random_seed(1234)
-    # ensure identical weights but no sharing between model and pipe
-    reused = nn.Linear(10, 10)
-    layers = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    pipe = AsyncPipe(layers, [3, 1, 1], worker_map=get_worker_map())
-    pipe.eval()
-    model_optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
-    pipe_optimizer = torch.optim.SGD(pipe.parameters(), lr=0.01, momentum=0.9) if len(list(pipe.parameters())) else None
-    inputs = torch.rand(10)
-    if False:  # speed
-        model_out = model(inputs)
-        pipe_out = pipe(inputs)
-
-        torch.distributed.barrier()
-
-        if pipe.final_stage:
-            assert torch.equal(model_out, pipe_out)
-
-    model.train()
-    pipe.train()
-    model_out = model(inputs)
-    pipe_out = pipe(inputs)
-    if pipe.final_stage:
-        pipe_loss = pipe_out.mean()
-        pipe_loss.backward()
-
-    model_loss = model_out.mean()
-    model_loss.backward()
-
-    model_optimizer.step()
-    if pipe_optimizer:
-        pipe_optimizer.step()
-
-    model.eval()
-    pipe.eval()
-    model_out = model(inputs)
-    pipe_out = pipe(inputs)
-
-    print(f"before barrier on {torch.distributed.get_rank()}")
-    torch.distributed.barrier()
-    print(f"after barrier on {torch.distributed.get_rank()}")
-
-    if pipe.final_stage:
-        assert torch.equal(model_out, pipe_out)
-
-
-@torch_spawn([1])
-def instantiate_partition():
-    from fairscale.nn.pipe.async_schedule import Location
-
-    model = nn.Sequential(nn.Linear(1, 1))
-    pipe = AsyncPipe(model, balance=[1], worker_map=get_worker_map(), chunks=1)
-
-    class FakeGroup:
-        def __init__(self, rank, size):
-            self._rank = rank
-            self._size = size
-
-        def rank(self):
-            return self._rank
-
-        def size(self):
-            return self._size
-
-    def check_partitions(model, balance, expected_order, expected_ranks):
-        """Check the instantiated model matches expectation of order and rank
-        model: a list of modules or an nn.Sequential
-        balance: the balance argument to MultiProcessPipe
-        expected_order: the index of modules in `model` in the order they will
-            be executed, grouped by nn.Sequential
-        expected_rank: the rank that each module will be executed on
-        """
-
-        invocations = []
-        invocation_wrapper = dict()
-
-        # Collect `Invocation` and `Invocation` -> `ModuleWrapper` mapping from
-        # instantiated model
-        for rank in range(len(balance)):
-            instantiated = pipe.instantiate_partition(model, balance, FakeGroup(rank, len(balance)))
-            for part in instantiated:
-                assert isinstance(part.module, nn.Sequential)
-                for inv in part.invocations:
-                    invocations.append(inv)
-                    invocation_wrapper[inv] = part
-
-        modules = []
-        prev = None
-        current = Location(0, 0)
-        ranks = []
-
-        for order, inv in enumerate(sorted(invocations, key=lambda x: x.order)):
-            # Check integrity of Location chain
-            assert inv.order == order
-            assert inv.source == prev
-            assert inv.this == current
-            prev = inv.this
-            current = inv.dest
-            modules.append(list(invocation_wrapper[inv].module.children()))
-            ranks.append(inv.this.stage)
-
-        # assert len(modules) == len(expected_order)
-        for left, right in zip(modules, expected_order):
-            assert len(left) == len(right), f"{right}"
-            assert list(map(id, left)) == list(map(id, (model[e] for e in right))), f"{right}"
-
-        assert ranks == expected_ranks
-
-    reused = nn.Linear(20, 20)
-    model = [reused, nn.Linear(10, 10), nn.ReLU(), reused, nn.ReLU(), reused, nn.ReLU()]
-    balance = [3, 1, 1]
-
-    check_partitions(
-        model, balance, expected_order=[[0], [1, 2], [0], [4], [0], [6]], expected_ranks=[0, 0, 0, 1, 0, 2]
-    )
-
-    reused2 = nn.Linear(5, 5)
-    model = [reused, reused2, nn.Linear(10, 10), nn.ReLU(), reused, reused2, nn.ReLU(), reused, reused2, nn.ReLU()]
-    balance = [4, 1, 1]
-
-    check_partitions(
-        model,
-        balance,
-        expected_order=[[0], [1], [2, 3], [0], [1], [6], [0], [1], [9]],
-        expected_ranks=[0, 0, 0, 0, 0, 1, 0, 0, 2],
-    )
-
-    reused2 = nn.Linear(5, 5)
-    model = [
-        nn.Linear(10, 10),
-        reused,
-        nn.Linear(10, 10),
-        nn.ReLU(),
-        reused,
-        reused2,
-        nn.ReLU(),
-        reused,
-        reused2,
-        nn.ReLU(),
-    ]
-    # 0 1 2 3 1 5 6 1 5 9
-    balance = [4, 2, 1]
-
-    check_partitions(
-        model,
-        balance,
-        expected_order=[[0], [1], [2, 3], [1], [5], [6], [1], [5], [9]],
-        expected_ranks=[0, 0, 0, 0, 1, 1, 0, 1, 2],
-    )
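For context, a minimal sketch of the calling convention the patched code expects: balance is passed to the constructor explicitly and validated up front by check_balance, which raises ValueError on a mismatch. The layer sizes, balance, and worker_map values below are illustrative only, and the snippet assumes the torch.distributed/RPC setup that the test harness (torch_spawn, get_worker_map) normally provides.

import torch.nn as nn
from fairscale.nn.pipe import MultiProcessPipe

# Illustrative two-stage split: sum(balance) must equal len(module) and every
# entry must be positive, otherwise MultiProcessPipe.__init__ raises ValueError.
model = nn.Sequential(nn.Linear(8, 8), nn.ReLU())
pipe = MultiProcessPipe(
    model,
    [1, 1],  # balance is now a required argument
    worker_map={0: "worker0", 1: "worker1"},  # illustrative rank -> RPC worker name map
    chunks=2,
)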