Adding Model Data Init and Training Progress to MosaicMLLogger #2633

Merged
32 commits merged into dev from jane/training-events on Oct 26, 2023
Changes from 15 commits

Commits (32)
800ffaf
created run events callback
jjanezhang Oct 10, 2023
4375fd2
added callback desc
jjanezhang Oct 10, 2023
f0ec59e
moved mosaiclogger testing out
jjanezhang Oct 10, 2023
2776f9a
updated metric type to be compatible with other loggers
jjanezhang Oct 11, 2023
2bf9629
fixed merge conflict
jjanezhang Oct 11, 2023
5beca25
moved timestamp to mosaic logger
jjanezhang Oct 13, 2023
3971fd6
logging info to get [epoch x/xx][batch y/yy] or [token x/xx] for users
jjanezhang Oct 13, 2023
9057fa7
logging info to get [epoch x/xx][batch y/yy] or [token x/xx] for users
jjanezhang Oct 14, 2023
7684fd6
Merge github.com:mosaicml/composer into jane/training-events
jjanezhang Oct 14, 2023
92eea62
formatting
jjanezhang Oct 14, 2023
4db203a
updated to calculate progress instead of returning metrics to mapi
jjanezhang Oct 20, 2023
594b508
Merge branch 'dev' of github.com:mosaicml/composer into jane/training…
jjanezhang Oct 20, 2023
b8ffb0b
moved state init in mosaicmllogger
jjanezhang Oct 20, 2023
0da9b36
calculation updates for epoch and batches
jjanezhang Oct 23, 2023
bc4ad38
added none assertion for max duration
jjanezhang Oct 24, 2023
8ae2d05
cached the dataloader len at batch start and edited test names
jjanezhang Oct 24, 2023
2fd2a61
formatting
jjanezhang Oct 24, 2023
ba1d7bd
Merge branch 'dev' into jane/training-events
dakinggg Oct 24, 2023
c2aefa7
added short circuit for enabled
jjanezhang Oct 25, 2023
cb34732
Merge branch 'jane/training-events' of github.com:mosaicml/composer i…
jjanezhang Oct 25, 2023
16141fe
Merge branch 'dev' into jane/training-events
jjanezhang Oct 25, 2023
0e31151
added enabled to test
jjanezhang Oct 25, 2023
406de6a
merged
jjanezhang Oct 25, 2023
aa1a438
Merge branch 'dev' of github.com:mosaicml/composer into jane/training…
jjanezhang Oct 25, 2023
2d00303
Merge branch 'dev' into jane/training-events
jjanezhang Oct 25, 2023
3435e28
Update composer/loggers/mosaicml_logger.py
jjanezhang Oct 26, 2023
1758db3
Update composer/loggers/mosaicml_logger.py
jjanezhang Oct 26, 2023
5a923c3
added training log to fit end
jjanezhang Oct 26, 2023
86216c9
Merge branch 'jane/training-events' of github.com:mosaicml/composer i…
jjanezhang Oct 26, 2023
60c4f2f
moved import
jjanezhang Oct 26, 2023
1e24149
fixed imports again
jjanezhang Oct 26, 2023
5e03da7
moved import back out
jjanezhang Oct 26, 2023
46 changes: 45 additions & 1 deletion composer/loggers/mosaicml_logger.py
@@ -18,6 +18,7 @@
import mcli
import torch

from composer.core.time import TimeUnit
from composer.loggers import Logger
from composer.loggers.logger import Logger
from composer.loggers.logger_destination import LoggerDestination
@@ -72,7 +73,6 @@ def __init__(
self.time_last_logged = 0
self.time_failed_count_adjusted = 0
self.buffered_metadata: Dict[str, Any] = {}

self.run_name = os.environ.get(RUN_NAME_ENV_VAR)
if self.run_name is not None:
log.info(f'Logging to mosaic run {self.run_name}')
@@ -88,14 +88,58 @@ def log_metrics(self, metrics: Dict[str, Any], step: Optional[int] = None) -> None:
self._log_metadata(metrics)

def after_load(self, state: State, logger: Logger) -> None:
# Log model data downloaded and initialized for run events
self._log_metadata({'model_initialized_time': time.time()})
# Log WandB run URL if it exists. Must run on after_load as WandB is setup on event init
for callback in state.callbacks:
if isinstance(callback, WandBLogger):
run_url = callback.run_url
if run_url is not None:
self._log_metadata({'wandb/run_url': run_url})

def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]:
"""Calculates training progress metrics.

If user submits max duration:
- in tokens -> format: [token=x/xx]
- in batches -> format: [batch=x/xx]
- in epoch -> format: [epoch=x/xx] [batch=x/xx] (where batch refers to batches completed in current epoch)
If batches per epoch cannot be calculated, return [epoch=x/xx]

If no training duration given -> format: ''
"""
assert state.max_duration is not None
if state.max_duration.unit == TimeUnit.TOKEN:
return {
'training_progress': f'[token={state.timestamp.token.value}/{state.max_duration.value}]',
}
if state.max_duration.unit == TimeUnit.BATCH:
return {
'training_progress': f'[batch={state.timestamp.batch.value}/{state.max_duration.value}]',
}
training_progress_metrics = {}
if state.max_duration.unit == TimeUnit.EPOCH:
cur_batch = int(state.timestamp.batch_in_epoch)
cur_epoch = int(state.timestamp.epoch)
if int(state.timestamp.epoch) >= 1:
batches_per_epoch = int(
(state.timestamp.batch - state.timestamp.batch_in_epoch).value / state.timestamp.epoch.value)
curr_progress = f'[batch={cur_batch}/{batches_per_epoch}]'
elif state.dataloader_len is None:
curr_progress = f'[batch={cur_batch}]'
else:
total = int(state.dataloader_len)
curr_progress = f'[batch={cur_batch}/{total}]'
if cur_epoch < state.max_duration.value:
cur_epoch += 1
training_progress_metrics = {
'training_progress': f'[epoch={cur_epoch}/{state.max_duration.value}]',
}
training_progress_metrics['training_sub_progress'] = curr_progress
return training_progress_metrics

def batch_end(self, state: State, logger: Logger) -> None:
self._log_metadata(self._get_training_progress_metrics(state))
self._flush_metadata()

def epoch_end(self, state: State, logger: Logger) -> None:
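
For reference, a minimal standalone sketch (not part of this PR) of the progress strings the logger writes to run metadata. The format_progress helper and its arguments are hypothetical; they only mirror the [unit=x/xx] strings described in the docstring above, without Composer's State/Time machinery.

# Hypothetical helper mirroring the '[unit=x/xx]' progress strings above.
from typing import Dict, Optional


def format_progress(unit: str,
                    max_value: int,
                    current: int,
                    batch_in_epoch: int = 0,
                    batches_per_epoch: Optional[int] = None) -> Dict[str, str]:
    if unit == 'token':
        return {'training_progress': f'[token={current}/{max_value}]'}
    if unit == 'batch':
        return {'training_progress': f'[batch={current}/{max_value}]'}
    # Epoch-based duration: report the epoch currently in progress plus the
    # batches completed within that epoch.
    cur_epoch = current + 1 if current < max_value else current
    sub_progress = (f'[batch={batch_in_epoch}/{batches_per_epoch}]'
                    if batches_per_epoch is not None else f'[batch={batch_in_epoch}]')
    return {
        'training_progress': f'[epoch={cur_epoch}/{max_value}]',
        'training_sub_progress': sub_progress,
    }


print(format_progress('epoch', 3, 2, batch_in_epoch=1, batches_per_epoch=5))
# -> {'training_progress': '[epoch=3/3]', 'training_sub_progress': '[batch=1/5]'}

In the PR itself, batches_per_epoch is derived from (batch - batch_in_epoch) / epoch once at least one epoch has completed, and falls back to state.dataloader_len (or no total at all) during epoch 0.
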
92 changes: 92 additions & 0 deletions tests/loggers/test_mosaicml_logger.py
@@ -3,13 +3,15 @@

import json
from typing import Type
from unittest.mock import MagicMock

import mcli
import pytest
import torch
from torch.utils.data import DataLoader

from composer.core import Callback
from composer.core.time import Time, TimeUnit
from composer.loggers import WandBLogger
from composer.loggers.mosaicml_logger import (MOSAICML_ACCESS_TOKEN_ENV_VAR, MOSAICML_PLATFORM_ENV_VAR, MosaicMLLogger,
format_data_to_json_serializable)
@@ -194,3 +196,93 @@ def test_auto_add_logger(monkeypatch, platform_env_var, access_token_env_var, lo
# Otherwise, no logger
else:
assert logger_count == 0


def test_model_initialized_time_logged(monkeypatch):
mock_mapi = MockMAPI()
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
run_name = 'test-run-name'
monkeypatch.setenv('RUN_NAME', run_name)
trainer = Trainer(model=SimpleModel(),
train_dataloader=DataLoader(RandomClassificationDataset()),
train_subset_num_batches=1,
max_duration='1ep',
loggers=[MosaicMLLogger()])
trainer.fit()
assert isinstance(mock_mapi.run_metadata[run_name]['mosaicml/model_initialized_time'], float)


def test_progress_logged(monkeypatch, tiny_bert_tokenizer):
mock_mapi = MockMAPI()
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
run_name = 'test-run-name'
monkeypatch.setenv('RUN_NAME', run_name)
trainer = Trainer(model=SimpleModel(),
train_dataloader=DataLoader(RandomClassificationDataset()),
train_subset_num_batches=1,
max_duration='4ba',
loggers=[MosaicMLLogger()])
trainer.fit()
metadata = mock_mapi.run_metadata[run_name]
assert 'mosaicml/training_progress' in metadata
assert metadata['mosaicml/training_progress'] == '[batch=4/4]'
assert 'mosaicml/training_sub_progress' not in metadata


def test_token_training_progress_logged():
logger = MosaicMLLogger()
state = MagicMock()
state.max_duration.unit = TimeUnit.TOKEN
state.max_duration.value = 64
state.timestamp.token.value = 50
training_progress = logger._get_training_progress_metrics(state)
assert 'training_progress' in training_progress
assert training_progress['training_progress'] == '[token=50/64]'
assert 'training_sub_progress' not in training_progress


def test_epoch_training_progress_logged():
logger = MosaicMLLogger()
state = MagicMock()
state.max_duration.unit = TimeUnit.EPOCH
state.max_duration = Time(3, TimeUnit.EPOCH)
state.timestamp.epoch = Time(2, TimeUnit.EPOCH)
state.timestamp.batch = Time(11, TimeUnit.BATCH)
state.timestamp.batch_in_epoch = Time(1, TimeUnit.BATCH)
training_progress = logger._get_training_progress_metrics(state)
assert 'training_progress' in training_progress
assert training_progress['training_progress'] == '[epoch=3/3]'
assert 'training_sub_progress' in training_progress
assert training_progress['training_sub_progress'] == '[batch=1/5]'


def test_epoch_zero_progress_logged():
logger = MosaicMLLogger()
state = MagicMock()
state.dataloader_len = 5
state.max_duration.unit = TimeUnit.EPOCH
state.max_duration = Time(3, TimeUnit.EPOCH)
state.timestamp.epoch = Time(0, TimeUnit.EPOCH)
state.timestamp.batch = Time(0, TimeUnit.BATCH)
state.timestamp.batch_in_epoch = Time(0, TimeUnit.BATCH)
training_progress = logger._get_training_progress_metrics(state)
assert 'training_progress' in training_progress
assert training_progress['training_progress'] == '[epoch=1/3]'
assert 'training_sub_progress' in training_progress
assert training_progress['training_sub_progress'] == '[batch=0/5]'


def test_epoch_zero_no_dataloader_progress_logged():
logger = MosaicMLLogger()
state = MagicMock()
state.dataloader_len = None
state.max_duration.unit = TimeUnit.EPOCH
state.max_duration = Time(3, TimeUnit.EPOCH)
state.timestamp.epoch = Time(0, TimeUnit.EPOCH)
state.timestamp.batch = Time(1, TimeUnit.BATCH)
state.timestamp.batch_in_epoch = Time(1, TimeUnit.BATCH)
training_progress = logger._get_training_progress_metrics(state)
assert 'training_progress' in training_progress
assert training_progress['training_progress'] == '[epoch=1/3]'
assert 'training_sub_progress' in training_progress
assert training_progress['training_sub_progress'] == '[batch=1]'
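
For context, a minimal usage sketch modeled on the tests above, not part of this PR. It assumes the RUN_NAME environment variable is set (as on the MosaicML platform) and that the test helpers SimpleModel and RandomClassificationDataset are importable from tests.common; the run name and the '4ba' duration are illustrative only.

# Usage sketch: against a real platform run, mcli.update_run_metadata receives
# the metadata shown below; the tests mock it out with MockMAPI instead.
import os

from torch.utils.data import DataLoader

from composer import Trainer
from composer.loggers import MosaicMLLogger
from tests.common import RandomClassificationDataset, SimpleModel  # test helpers

os.environ['RUN_NAME'] = 'progress-demo'  # hypothetical run name

trainer = Trainer(
    model=SimpleModel(),
    train_dataloader=DataLoader(RandomClassificationDataset()),
    max_duration='4ba',  # batch-based duration -> '[batch=x/4]' progress strings
    loggers=[MosaicMLLogger()],
)
trainer.fit()
# The run metadata now includes 'mosaicml/model_initialized_time' and
# 'mosaicml/training_progress' == '[batch=4/4]', plus a WandB run URL if a
# WandBLogger is also attached.
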