diff --git a/composer/core/state.py b/composer/core/state.py
index b980255d6d0..98b22bb890a 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -108,6 +108,10 @@ class State(Serializable):
     # stopping conditions
     max_epochs: int
 
+    # dataloaders
+    train_dataloader: types.DataLoader
+    eval_dataloader: types.DataLoader
+
     # precision
     # storing precision internally so strings can be passed into the constructor and setter
     # but the getter will always return a Precision enum
@@ -134,10 +138,6 @@ class State(Serializable):
     # scaler
     scaler: Optional[types.Scaler] = None
 
-    # dataloaders
-    train_dataloader: Optional[types.DataLoader] = None
-    eval_dataloader: Optional[types.DataLoader] = None
-
     # algorithms
     algorithms: Sequence[Algorithm] = tuple()
     callbacks: Sequence[Callback] = tuple()
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index e835891cc88..6f62dd3dd42 100755
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -203,15 +203,39 @@ def __init__(
             timeout=ddp_timeout,
         )
 
-        self.state = State(max_epochs=max_epochs,
-                           train_batch_size=train_batch_size,
-                           eval_batch_size=eval_batch_size,
-                           algorithms=algorithms,
-                           callbacks=callbacks,
-                           model=model,
-                           grad_accum=grad_accum,
-                           precision=precision,
-                           precision_context=self.device.precision_context)
+        dl_hparams = DataloaderHparams(num_workers=num_workers,
+                                       prefetch_factor=prefetch_factor,
+                                       persistent_workers=persistent_workers,
+                                       pin_memory=pin_memory,
+                                       timeout=timeout)
+
+        train_gpu_batch_size = train_batch_size // self.ddp.world_size
+        train_dataloader = self.device.dataloader_to_device(
+            self.ddp.create_dataloader(train_gpu_batch_size, dl_hparams, train_dataloader_spec),
+            train_dataloader_spec.prefetch_fn,
+        )
+        self.train_dl_spec = train_dataloader_spec
+
+        eval_gpu_batch_size = eval_batch_size // self.ddp.world_size
+        eval_dataloader = self.device.dataloader_to_device(
+            self.ddp.create_dataloader(eval_gpu_batch_size, dl_hparams, eval_dataloader_spec),
+            eval_dataloader_spec.prefetch_fn,
+        )
+        self.eval_dl_spec = eval_dataloader_spec
+
+        self.state = State(
+            max_epochs=max_epochs,
+            train_batch_size=train_batch_size,
+            eval_batch_size=eval_batch_size,
+            algorithms=algorithms,
+            callbacks=callbacks,
+            model=model,
+            grad_accum=grad_accum,
+            precision=precision,
+            precision_context=self.device.precision_context,
+            train_dataloader=train_dataloader,
+            eval_dataloader=eval_dataloader,
+        )
 
         if not log_destinations:
             log_destinations = [TQDMLoggerBackend()]
@@ -220,15 +244,6 @@ def __init__(
 
         self.engine = Engine(self.state, self.state.algorithms, self.logger, self.state.callbacks)
 
-        self.train_dl_spec = train_dataloader_spec
-        self.eval_dl_spec = eval_dataloader_spec
-
-        self.dl_hparams = DataloaderHparams(num_workers=num_workers,
-                                            prefetch_factor=prefetch_factor,
-                                            persistent_workers=persistent_workers,
-                                            pin_memory=pin_memory,
-                                            timeout=timeout)
-
         self.validate_every_n_batches = validate_every_n_batches
         self.validate_every_n_epochs = validate_every_n_epochs
         self.compute_training_metrics = compute_training_metrics
@@ -243,8 +258,8 @@ def __init__(
         # run INIT event before optimizers and schedulers are created
         self.engine.run_event(Event.INIT)
 
-        assert isinstance(self.train_dl_spec.dataset, collections.abc.Sized)
-        steps_per_epoch = len(self.train_dl_spec.dataset) // train_batch_size
+        assert isinstance(self.state.train_dataloader.dataset, collections.abc.Sized)
+        steps_per_epoch = len(self.state.train_dataloader.dataset) // train_batch_size
         # Need to use hparams here because optimizer and schedulers need to be created after Event.INIT
         if not optimizer_hparams:
             optimizer_hparams = DecoupledSGDWHparams(lr=0.1, momentum=0.9, weight_decay=1.0e-4)
@@ -365,21 +380,6 @@ def _create_dataloaders(self) -> None:
         state = self.state
 
         # compute per gpu batch size
-        train_gpu_batch_size = state.train_batch_size // state.world_size
-        eval_gpu_batch_size = state.eval_batch_size // state.world_size
-
-        train_dataloader = self.ddp.create_dataloader(train_gpu_batch_size, self.dl_hparams, self.train_dl_spec)
-        eval_dataloader = self.ddp.create_dataloader(eval_gpu_batch_size, self.dl_hparams, self.eval_dl_spec)
-
-        # move to device
-        state.train_dataloader = self.device.dataloader_to_device(
-            train_dataloader,
-            self.train_dl_spec.prefetch_fn,
-        )
-        state.eval_dataloader = self.device.dataloader_to_device(
-            eval_dataloader,
-            self.eval_dl_spec.prefetch_fn,
-        )
 
     def _get_metrics_as_collection(self, *, is_train: bool) -> MetricCollection:
         """Get metrics relevant to the model. Metrics are all implemented as subclasses
@@ -477,11 +477,6 @@ def _train_loop(self) -> None:
         state.model = self.device.module_to_device(state.model)
         state.optimizers = map_collection(state.optimizers, self.device.optimizer_to_device)
 
-        # create dataloaders here after distributed training has started
-        self._create_dataloaders()
-        if state.train_dataloader is None or state.eval_dataloader is None:
-            raise ValueError('Dataloaders were not created properly, and are None.')
-
         # wrap model with DDP
         state.model = self.ddp.prepare_module(state.model)
         original_model = state.model.module
diff --git a/tests/algorithms/test_blurpool_algorithm.py b/tests/algorithms/test_blurpool_algorithm.py
index d3fea2d2354..4baaed5613e 100644
--- a/tests/algorithms/test_blurpool_algorithm.py
+++ b/tests/algorithms/test_blurpool_algorithm.py
@@ -12,12 +12,12 @@
 from composer.algorithms import BlurPool, BlurPoolHparams
 from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d
 from composer.core import Event, State
-from composer.core.types import Model, Precision
+from composer.core.types import DataLoader, Model, Precision
 from tests.fixtures.models import SimpleConvModel
 
 
 @pytest.fixture
-def state(simple_conv_model: Model):
+def state(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
     state = State(
         epoch=50,
         step=50,
@@ -27,6 +27,8 @@ def state(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
         max_epochs=100,
         model=simple_conv_model,
         precision=Precision.FP32,
+        train_dataloader=dummy_train_dataloader,
+        eval_dataloader=dummy_val_dataloader,
     )
     return state
 
diff --git a/tests/algorithms/test_channels_last.py b/tests/algorithms/test_channels_last.py
index 4bdf82386d8..9cae2873d1a 100644
--- a/tests/algorithms/test_channels_last.py
+++ b/tests/algorithms/test_channels_last.py
@@ -7,7 +7,7 @@
 from composer.algorithms import ChannelsLastHparams
 from composer.core.event import Event
 from composer.core.state import State
-from composer.core.types import Model, Precision, Tensor
+from composer.core.types import DataLoader, Model, Precision, Tensor
 
 
 def _has_singleton_dimension(tensor: Tensor) -> bool:
@@ -31,7 +31,7 @@ def _infer_memory_format(tensor: Tensor) -> str:
 
 
 @pytest.fixture()
-def state(simple_conv_model: Model):
+def state(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
     return State(
         model=simple_conv_model,
         train_batch_size=100,
@@ -39,6 +39,8 @@ def state(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
         precision=Precision.FP32,
         grad_accum=1,
         max_epochs=10,
+        train_dataloader=dummy_train_dataloader,
+        eval_dataloader=dummy_val_dataloader,
     )
 
 
diff --git a/tests/algorithms/test_layer_freezing.py b/tests/algorithms/test_layer_freezing.py
index a77a358d652..6f70a167b67 100644
--- a/tests/algorithms/test_layer_freezing.py
+++ b/tests/algorithms/test_layer_freezing.py
@@ -6,14 +6,15 @@
 
 from composer.algorithms import LayerFreezing, LayerFreezingHparams
 from composer.core.state import State
-from composer.core.types import Event, Model, Precision
+from composer.core.types import DataLoader, Event, Model, Precision
 from composer.loggers import Logger
 from composer.trainer.trainer_hparams import TrainerHparams
 from composer.utils import ensure_tuple
 from tests.utils.trainer_fit import train_model
 
 
-def _generate_state(epoch: int, max_epochs: int, model: Model):
+def _generate_state(epoch: int, max_epochs: int, model: Model, train_dataloader: DataLoader,
+                    val_dataloader: DataLoader):
     state = State(
         epoch=epoch,
         step=epoch,
@@ -24,6 +25,8 @@ def _generate_state(epoch: int, max_epochs: int, model: Model, train_dataloader
         model=model,
         optimizers=(torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99),),
         precision=Precision.FP32,
+        train_dataloader=train_dataloader,
+        eval_dataloader=val_dataloader,
     )
     return state
 
@@ -39,8 +42,13 @@ def _check_param_groups(expected_groups, actual_groups):
             assert (actual_groups[i]['params'][j] == expected_params).all()
 
 
-def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Logger):
-    state = _generate_state(epoch=10, max_epochs=100, model=simple_conv_model)
+def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Logger,
+                                 dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
+    state = _generate_state(epoch=10,
+                            max_epochs=100,
+                            model=simple_conv_model,
+                            train_dataloader=dummy_train_dataloader,
+                            val_dataloader=dummy_val_dataloader)
 
     first_optimizer = ensure_tuple(state.optimizers)[0]
     assert first_optimizer is not None
@@ -53,8 +61,13 @@ def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Lo
     _check_param_groups(expected_param_groups, updated_param_groups)
 
 
-def test_freeze_layers_with_freeze(simple_conv_model: Model, noop_dummy_logger: Logger):
-    state = _generate_state(epoch=80, max_epochs=100, model=simple_conv_model)
+def test_freeze_layers_with_freeze(simple_conv_model: Model, noop_dummy_logger: Logger,
+                                   dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
+    state = _generate_state(epoch=80,
+                            max_epochs=100,
+                            model=simple_conv_model,
+                            train_dataloader=dummy_train_dataloader,
+                            val_dataloader=dummy_val_dataloader)
 
     first_optimizer = ensure_tuple(state.optimizers)[0]
     assert first_optimizer is not None
diff --git a/tests/algorithms/test_stochastic_depth.py b/tests/algorithms/test_stochastic_depth.py
index 697f2f4a3ea..0fa95507a4c 100644
--- a/tests/algorithms/test_stochastic_depth.py
+++ b/tests/algorithms/test_stochastic_depth.py
@@ -35,6 +35,7 @@ def dummy_state(dummy_dataloader_hparams: DataloaderHparams):
                   grad_accum=1,
                   max_epochs=100,
                   model=model,
+                  eval_dataloader=train_dataloader,
                   precision=Precision.FP32)
 
 
diff --git a/tests/fixtures/dummy_fixtures.py b/tests/fixtures/dummy_fixtures.py
index aafc319e7dc..9934ecea98e 100755
--- a/tests/fixtures/dummy_fixtures.py
+++ b/tests/fixtures/dummy_fixtures.py
@@ -5,9 +5,7 @@
 
 import pytest
 import torch
-import torch.distributed as dist
 import torch.utils.data
-from _pytest.monkeypatch import MonkeyPatch
 
 from composer import Logger, State
 from composer.core.types import DataLoader, Model, Precision
@@ -71,8 +69,8 @@ def dummy_val_dataloader_spec(dummy_train_dataset_hparams: SyntheticDatasetHpara
 
 
 @pytest.fixture()
-def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batch_size: int,
-                             dummy_val_batch_size: int) -> State:
+def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batch_size: int, dummy_val_batch_size: int,
+                             dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader) -> State:
     state = State(
         model=dummy_model,
         epoch=5,
@@ -81,6 +79,8 @@ def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batc
         grad_accum=1,
         train_batch_size=dummy_train_batch_size,
         eval_batch_size=dummy_val_batch_size,
+        train_dataloader=dummy_train_dataloader,
+        eval_dataloader=dummy_val_dataloader,
         max_epochs=10,
     )
     return state
@@ -110,8 +110,7 @@ def dummy_val_dataloader(dummy_dataloader_hparams: DataloaderHparams, dummy_val_
 
 
 @pytest.fixture()
-def dummy_state(dummy_state_without_rank: State, monkeypatch: MonkeyPatch) -> State:
-    monkeypatch.setattr(dist, "get_rank", lambda: 0)
+def dummy_state(dummy_state_without_rank: State) -> State:
     return dummy_state_without_rank
 
 
@@ -191,7 +190,7 @@ def simple_conv_model_input():
 
 
 @pytest.fixture()
-def state_with_model(simple_conv_model: Model):
+def state_with_model(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
     state = State(
         epoch=50,
         step=50,
@@ -201,6 +200,8 @@ def state_with_model(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
         max_epochs=100,
         model=simple_conv_model,
         precision=Precision.FP32,
+        train_dataloader=dummy_train_dataloader,
+        eval_dataloader=dummy_val_dataloader,
     )
     return state
 
diff --git a/tests/test_state.py b/tests/test_state.py
index 356acbc559c..745b94663c7 100755
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -22,7 +22,7 @@ def random_tensor(size=(4, 10)):
     return torch.rand(*size)
 
 
-def get_dummy_state(model: BaseMosaicModel):
+def get_dummy_state(model: BaseMosaicModel, train_dataloader: types.DataLoader, val_dataloader: types.DataLoader):
     optimizers = torch.optim.Adadelta(model.parameters())
 
     return State(model=model,
@@ -36,6 +36,8 @@ def get_dummy_state(model: BaseMosaicModel, train_dataloader: types.DataLoader, val_dataloader: types.DataLoader):
                  loss=random_tensor(),
                  batch=(random_tensor(), random_tensor()),
                  outputs=random_tensor(),
+                 train_dataloader=train_dataloader,
+                 eval_dataloader=val_dataloader,
                  optimizers=optimizers,
                  schedulers=torch.optim.lr_scheduler.StepLR(optimizers, step_size=3),
                  algorithms=[DummyHparams().initialize_object()])
@@ -106,12 +108,13 @@ def get_batch(model: SimpleBatchPairModel, dataloader_hparams: DataloaderHparams
 
 
 def test_state_serialize(tmpdir: pathlib.Path,
                          dummy_model: BaseMosaicModel,
-                         dummy_dataloader_hparams: DataloaderHparams):
+                         dummy_dataloader_hparams: DataloaderHparams, dummy_train_dataloader: types.DataLoader,
+                         dummy_val_dataloader: types.DataLoader):
     assert isinstance(dummy_model, SimpleBatchPairModel)
 
-    state1 = get_dummy_state(dummy_model)
-    state2 = get_dummy_state(dummy_model)
+    state1 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader)
+    state2 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader)
 
     # train one step to set the optimizer states
     batch = get_batch(dummy_model, dummy_dataloader_hparams)
diff --git a/tests/trainer/test_ddp_sync_strategy.py b/tests/trainer/test_ddp_sync_strategy.py
index 09ef3fab720..94362407943 100755
--- a/tests/trainer/test_ddp_sync_strategy.py
+++ b/tests/trainer/test_ddp_sync_strategy.py
@@ -7,7 +7,7 @@
 import torch.nn as nn
 
 from composer.core.state import State
-from composer.core.types import Tensor
+from composer.core.types import DataLoader, Tensor
 from composer.trainer.ddp import DDP
 
 
@@ -44,7 +44,8 @@ def loss(self, output: Tensor, target: Tensor):
     pytest.param('forced_sync', ([-1, None, None], [-1, -1, None], [-1.5, -1.5, None]), id='forced_sync'),
 ])
 @pytest.mark.world_size(2)
-def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]]):
+def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]],
+                           dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
     original_model = MinimalConditionalModel()
     ddp = DDP(backend="gloo", find_unused_parameters=True, sync_strategy=ddp_sync_strategy, timeout=5.)
     optimizer = torch.optim.SGD(original_model.parameters(), 0.1)
@@ -55,6 +56,8 @@ def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional
                   eval_batch_size=1,
                   grad_accum=2,
                   max_epochs=1,
+                  train_dataloader=dummy_train_dataloader,
+                  eval_dataloader=dummy_val_dataloader,
                   precision='fp32')
 
     batches = [[(1, Tensor([1])), (1, Tensor([2]))], [(2, Tensor([1])), (2, Tensor([2]))]]
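Usage note (illustrative, not part of the diff): with this change, `train_dataloader` and `eval_dataloader` become required, non-defaulted fields on `State`, and the `Trainer` builds both loaders directly in `__init__` rather than lazily in `_create_dataloaders()`, splitting the global batch size across ranks (e.g. a global train batch size of 512 on 8 processes gives 512 // 8 = 64 samples per device). The sketch below mirrors the fixture-style construction used in the updated tests; the helper name `make_state` and its arguments are hypothetical.

```python
# Hypothetical helper mirroring the updated test fixtures above; not part of this PR.
from composer.core import State
from composer.core.types import DataLoader, Model, Precision


def make_state(model: Model, train_dl: DataLoader, eval_dl: DataLoader) -> State:
    # Dataloaders are now required at construction time instead of being
    # attached later by Trainer._create_dataloaders().
    return State(
        model=model,
        train_batch_size=100,
        eval_batch_size=100,
        grad_accum=1,
        max_epochs=10,
        precision=Precision.FP32,
        train_dataloader=train_dl,
        eval_dataloader=eval_dl,
    )
```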