Create dataloader on trainer __init__() (mosaicml#92)
mosaicml#65 made the global rank available at process start, so it is no longer necessary to wait until training_start() to create the dataloaders. Instead, dataloaders are now initialized in __init__().

This change will help with dataloader profiling, since the dataloaders are now bound to the state as soon as the trainer is constructed.
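
For context, the core of the change is that the trainer can compute the per-device batch size and build rank-aware dataloaders the moment it is constructed, because the global rank is already known. The snippet below is a standalone illustration of that idea in plain PyTorch — it is not Composer's API, and the names and sizes (world_size, global_train_batch_size, etc.) are made up for the example:

    # Standalone illustration (plain PyTorch, not Composer's internals): once the
    # global rank and world size are known at construction time, the per-device
    # batch size and a rank-aware dataloader can be created immediately.
    import torch
    from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

    world_size, rank = 8, 0                      # normally provided by the launcher / DDP setup
    global_train_batch_size = 1024
    per_device_batch_size = global_train_batch_size // world_size   # 128 samples per device

    dataset = TensorDataset(torch.randn(4096, 3, 32, 32),
                            torch.randint(0, 10, (4096,)))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    train_dataloader = DataLoader(dataset, batch_size=per_device_batch_size,
                                  sampler=sampler, num_workers=2, pin_memory=True)

    # Steps per epoch are computed from the global batch size, as in the trainer diff below:
    steps_per_epoch = len(dataset) // global_train_batch_size       # 4096 // 1024 = 4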
ravi-mosaicml authored and coryMosaicML committed Feb 23, 2022
1 parent 106d006 commit 63828bf
Showing 9 changed files with 87 additions and 67 deletions.
8 changes: 4 additions & 4 deletions composer/core/state.py
@@ -108,6 +108,10 @@ class State(Serializable):
# stopping conditions
max_epochs: int

# dataloaders
train_dataloader: types.DataLoader
eval_dataloader: types.DataLoader

# precision
# storing precision internally so strings can be passed into the constructor and setter
# but the getter will always return a Precision enum
@@ -134,10 +138,6 @@ class State(Serializable):
# scaler
scaler: Optional[types.Scaler] = None

# dataloaders
train_dataloader: Optional[types.DataLoader] = None
eval_dataloader: Optional[types.DataLoader] = None

# algorithms
algorithms: Sequence[Algorithm] = tuple()
callbacks: Sequence[Callback] = tuple()
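
Note that in state.py the dataloader fields move above the defaulted fields and drop their Optional[...] = None defaults, so every State(...) construction must now supply both dataloaders — which is why each test fixture below gains dummy_train_dataloader / dummy_val_dataloader arguments. A minimal sketch of the updated construction pattern, mirroring the test fixtures in this commit (the surrounding fixtures are assumed; this is not a standalone script):

    # Sketch of constructing State after this change: dataloaders are required up front.
    state = State(
        model=simple_conv_model,
        train_batch_size=100,
        eval_batch_size=100,
        grad_accum=1,
        max_epochs=10,
        precision=Precision.FP32,
        train_dataloader=dummy_train_dataloader,   # no longer Optional[...] = None
        eval_dataloader=dummy_val_dataloader,      # no longer Optional[...] = None
    )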
75 changes: 35 additions & 40 deletions composer/trainer/trainer.py
@@ -203,15 +203,39 @@ def __init__(
timeout=ddp_timeout,
)

self.state = State(max_epochs=max_epochs,
train_batch_size=train_batch_size,
eval_batch_size=eval_batch_size,
algorithms=algorithms,
callbacks=callbacks,
model=model,
grad_accum=grad_accum,
precision=precision,
precision_context=self.device.precision_context)
dl_hparams = DataloaderHparams(num_workers=num_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
timeout=timeout)

train_gpu_batch_size = train_batch_size // self.ddp.world_size
train_dataloader = self.device.dataloader_to_device(
self.ddp.create_dataloader(train_gpu_batch_size, dl_hparams, train_dataloader_spec),
train_dataloader_spec.prefetch_fn,
)
self.train_dl_spec = train_dataloader_spec

eval_gpu_batch_size = eval_batch_size // self.ddp.world_size
eval_dataloader = self.device.dataloader_to_device(
self.ddp.create_dataloader(eval_gpu_batch_size, dl_hparams, eval_dataloader_spec),
eval_dataloader_spec.prefetch_fn,
)
self.eval_dl_spec = eval_dataloader_spec

self.state = State(
max_epochs=max_epochs,
train_batch_size=train_batch_size,
eval_batch_size=eval_batch_size,
algorithms=algorithms,
callbacks=callbacks,
model=model,
grad_accum=grad_accum,
precision=precision,
precision_context=self.device.precision_context,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
)

if not log_destinations:
log_destinations = [TQDMLoggerBackend()]
@@ -220,15 +244,6 @@ def __init__(

self.engine = Engine(self.state, self.state.algorithms, self.logger, self.state.callbacks)

self.train_dl_spec = train_dataloader_spec
self.eval_dl_spec = eval_dataloader_spec

self.dl_hparams = DataloaderHparams(num_workers=num_workers,
prefetch_factor=prefetch_factor,
persistent_workers=persistent_workers,
pin_memory=pin_memory,
timeout=timeout)

self.validate_every_n_batches = validate_every_n_batches
self.validate_every_n_epochs = validate_every_n_epochs
self.compute_training_metrics = compute_training_metrics
@@ -243,8 +258,8 @@ def __init__(
# run INIT event before optimizers and schedulers are created
self.engine.run_event(Event.INIT)

assert isinstance(self.train_dl_spec.dataset, collections.abc.Sized)
steps_per_epoch = len(self.train_dl_spec.dataset) // train_batch_size
assert isinstance(self.state.train_dataloader.dataset, collections.abc.Sized)
steps_per_epoch = len(self.state.train_dataloader.dataset) // train_batch_size
# Need to use hparams here because optimizer and schedulers need to be created after Event.INIT
if not optimizer_hparams:
optimizer_hparams = DecoupledSGDWHparams(lr=0.1, momentum=0.9, weight_decay=1.0e-4)
@@ -365,21 +380,6 @@ def _create_dataloaders(self) -> None:
state = self.state

# compute per gpu batch size
train_gpu_batch_size = state.train_batch_size // state.world_size
eval_gpu_batch_size = state.eval_batch_size // state.world_size

train_dataloader = self.ddp.create_dataloader(train_gpu_batch_size, self.dl_hparams, self.train_dl_spec)
eval_dataloader = self.ddp.create_dataloader(eval_gpu_batch_size, self.dl_hparams, self.eval_dl_spec)

# move to device
state.train_dataloader = self.device.dataloader_to_device(
train_dataloader,
self.train_dl_spec.prefetch_fn,
)
state.eval_dataloader = self.device.dataloader_to_device(
eval_dataloader,
self.eval_dl_spec.prefetch_fn,
)

def _get_metrics_as_collection(self, *, is_train: bool) -> MetricCollection:
"""Get metrics relevant to the model. Metrics are all implemented as subclasses
@@ -477,11 +477,6 @@ def _train_loop(self) -> None:
state.model = self.device.module_to_device(state.model)
state.optimizers = map_collection(state.optimizers, self.device.optimizer_to_device)

# create dataloaders here after distributed training has started
self._create_dataloaders()
if state.train_dataloader is None or state.eval_dataloader is None:
raise ValueError('Dataloaders were not created properly, and are None.')

# wrap model with DDP
state.model = self.ddp.prepare_module(state.model)
original_model = state.model.module
6 changes: 4 additions & 2 deletions tests/algorithms/test_blurpool_algorithm.py
@@ -12,12 +12,12 @@
from composer.algorithms import BlurPool, BlurPoolHparams
from composer.algorithms.blurpool.blurpool_layers import BlurConv2d, BlurMaxPool2d
from composer.core import Event, State
from composer.core.types import Model, Precision
from composer.core.types import DataLoader, Model, Precision
from tests.fixtures.models import SimpleConvModel


@pytest.fixture
def state(simple_conv_model: Model):
def state(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
state = State(
epoch=50,
step=50,
@@ -27,6 +27,8 @@ def state(simple_conv_model: Model):
max_epochs=100,
model=simple_conv_model,
precision=Precision.FP32,
train_dataloader=dummy_train_dataloader,
eval_dataloader=dummy_val_dataloader,
)
return state

6 changes: 4 additions & 2 deletions tests/algorithms/test_channels_last.py
@@ -7,7 +7,7 @@
from composer.algorithms import ChannelsLastHparams
from composer.core.event import Event
from composer.core.state import State
from composer.core.types import Model, Precision, Tensor
from composer.core.types import DataLoader, Model, Precision, Tensor


def _has_singleton_dimension(tensor: Tensor) -> bool:
@@ -31,14 +31,16 @@ def _infer_memory_format(tensor: Tensor) -> str:


@pytest.fixture()
def state(simple_conv_model: Model):
def state(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
return State(
model=simple_conv_model,
train_batch_size=100,
eval_batch_size=100,
precision=Precision.FP32,
grad_accum=1,
max_epochs=10,
train_dataloader=dummy_train_dataloader,
eval_dataloader=dummy_val_dataloader,
)


25 changes: 19 additions & 6 deletions tests/algorithms/test_layer_freezing.py
@@ -6,14 +6,15 @@

from composer.algorithms import LayerFreezing, LayerFreezingHparams
from composer.core.state import State
from composer.core.types import Event, Model, Precision
from composer.core.types import DataLoader, Event, Model, Precision
from composer.loggers import Logger
from composer.trainer.trainer_hparams import TrainerHparams
from composer.utils import ensure_tuple
from tests.utils.trainer_fit import train_model


def _generate_state(epoch: int, max_epochs: int, model: Model):
def _generate_state(epoch: int, max_epochs: int, model: Model, train_dataloader: DataLoader,
val_dataloader: DataLoader):
state = State(
epoch=epoch,
step=epoch,
Expand All @@ -24,6 +25,8 @@ def _generate_state(epoch: int, max_epochs: int, model: Model):
model=model,
optimizers=(torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.99),),
precision=Precision.FP32,
train_dataloader=train_dataloader,
eval_dataloader=val_dataloader,
)
return state

@@ -39,8 +42,13 @@ def _check_param_groups(expected_groups, actual_groups):
assert (actual_groups[i]['params'][j] == expected_params).all()


def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Logger):
state = _generate_state(epoch=10, max_epochs=100, model=simple_conv_model)
def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Logger,
dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
state = _generate_state(epoch=10,
max_epochs=100,
model=simple_conv_model,
train_dataloader=dummy_train_dataloader,
val_dataloader=dummy_val_dataloader)

first_optimizer = ensure_tuple(state.optimizers)[0]
assert first_optimizer is not None
@@ -53,8 +61,13 @@ def test_freeze_layers_no_freeze(simple_conv_model: Model, noop_dummy_logger: Lo
_check_param_groups(expected_param_groups, updated_param_groups)


def test_freeze_layers_with_freeze(simple_conv_model: Model, noop_dummy_logger: Logger):
state = _generate_state(epoch=80, max_epochs=100, model=simple_conv_model)
def test_freeze_layers_with_freeze(simple_conv_model: Model, noop_dummy_logger: Logger,
dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
state = _generate_state(epoch=80,
max_epochs=100,
model=simple_conv_model,
train_dataloader=dummy_train_dataloader,
val_dataloader=dummy_val_dataloader)

first_optimizer = ensure_tuple(state.optimizers)[0]
assert first_optimizer is not None
1 change: 1 addition & 0 deletions tests/algorithms/test_stochastic_depth.py
@@ -35,6 +35,7 @@ def dummy_state(dummy_dataloader_hparams: DataloaderHparams):
grad_accum=1,
max_epochs=100,
model=model,
eval_dataloader=train_dataloader,
precision=Precision.FP32)


15 changes: 8 additions & 7 deletions tests/fixtures/dummy_fixtures.py
@@ -5,9 +5,7 @@

import pytest
import torch
import torch.distributed as dist
import torch.utils.data
from _pytest.monkeypatch import MonkeyPatch

from composer import Logger, State
from composer.core.types import DataLoader, Model, Precision
@@ -71,8 +69,8 @@ def dummy_val_dataloader_spec(dummy_train_dataset_hparams: SyntheticDatasetHpara


@pytest.fixture()
def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batch_size: int,
dummy_val_batch_size: int) -> State:
def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batch_size: int, dummy_val_batch_size: int,
dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader) -> State:
state = State(
model=dummy_model,
epoch=5,
@@ -81,6 +79,8 @@ def dummy_state_without_rank(dummy_model: SimpleBatchPairModel, dummy_train_batc
grad_accum=1,
train_batch_size=dummy_train_batch_size,
eval_batch_size=dummy_val_batch_size,
train_dataloader=dummy_train_dataloader,
eval_dataloader=dummy_val_dataloader,
max_epochs=10,
)
return state
@@ -110,8 +110,7 @@ def dummy_val_dataloader(dummy_dataloader_hparams: DataloaderHparams, dummy_val_


@pytest.fixture()
def dummy_state(dummy_state_without_rank: State, monkeypatch: MonkeyPatch) -> State:
monkeypatch.setattr(dist, "get_rank", lambda: 0)
def dummy_state(dummy_state_without_rank: State) -> State:
return dummy_state_without_rank


@@ -191,7 +190,7 @@ def simple_conv_model_input():


@pytest.fixture()
def state_with_model(simple_conv_model: Model):
def state_with_model(simple_conv_model: Model, dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
state = State(
epoch=50,
step=50,
@@ -201,6 +200,8 @@ def state_with_model(simple_conv_model: Model):
max_epochs=100,
model=simple_conv_model,
precision=Precision.FP32,
train_dataloader=dummy_train_dataloader,
eval_dataloader=dummy_val_dataloader,
)
return state

11 changes: 7 additions & 4 deletions tests/test_state.py
@@ -22,7 +22,7 @@ def random_tensor(size=(4, 10)):
return torch.rand(*size)


def get_dummy_state(model: BaseMosaicModel):
def get_dummy_state(model: BaseMosaicModel, train_dataloader: types.DataLoader, val_dataloader: types.DataLoader):
optimizers = torch.optim.Adadelta(model.parameters())

return State(model=model,
@@ -36,6 +36,8 @@ def get_dummy_state(model: BaseMosaicModel):
loss=random_tensor(),
batch=(random_tensor(), random_tensor()),
outputs=random_tensor(),
train_dataloader=train_dataloader,
eval_dataloader=val_dataloader,
optimizers=optimizers,
schedulers=torch.optim.lr_scheduler.StepLR(optimizers, step_size=3),
algorithms=[DummyHparams().initialize_object()])
@@ -106,12 +108,13 @@ def get_batch(model: SimpleBatchPairModel, dataloader_hparams: DataloaderHparams


def test_state_serialize(tmpdir: pathlib.Path, dummy_model: BaseMosaicModel,
dummy_dataloader_hparams: DataloaderHparams):
dummy_dataloader_hparams: DataloaderHparams, dummy_train_dataloader: types.DataLoader,
dummy_val_dataloader: types.DataLoader):

assert isinstance(dummy_model, SimpleBatchPairModel)

state1 = get_dummy_state(dummy_model)
state2 = get_dummy_state(dummy_model)
state1 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader)
state2 = get_dummy_state(dummy_model, dummy_train_dataloader, dummy_val_dataloader)

# train one step to set the optimizer states
batch = get_batch(dummy_model, dummy_dataloader_hparams)
7 changes: 5 additions & 2 deletions tests/trainer/test_ddp_sync_strategy.py
@@ -7,7 +7,7 @@
import torch.nn as nn

from composer.core.state import State
from composer.core.types import Tensor
from composer.core.types import DataLoader, Tensor
from composer.trainer.ddp import DDP


@@ -44,7 +44,8 @@ def loss(self, output: Tensor, target: Tensor):
pytest.param('forced_sync', ([-1, None, None], [-1, -1, None], [-1.5, -1.5, None]), id='forced_sync'),
])
@pytest.mark.world_size(2)
def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]]):
def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional[float]],
dummy_train_dataloader: DataLoader, dummy_val_dataloader: DataLoader):
original_model = MinimalConditionalModel()
ddp = DDP(backend="gloo", find_unused_parameters=True, sync_strategy=ddp_sync_strategy, timeout=5.)
optimizer = torch.optim.SGD(original_model.parameters(), 0.1)
@@ -55,6 +56,8 @@ def test_ddp_sync_strategy(ddp_sync_strategy: str, expected_grads: List[Optional
eval_batch_size=1,
grad_accum=2,
max_epochs=1,
train_dataloader=dummy_train_dataloader,
eval_dataloader=dummy_val_dataloader,
precision='fp32')

batches = [[(1, Tensor([1])), (1, Tensor([2]))], [(2, Tensor([1])), (2, Tensor([2]))]]
