Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch mosaicml logger to use futures to enable better error handling #2702

Merged
merged 55 commits into from
Nov 14, 2023
Merged
Show file tree
Hide file tree
Changes from 47 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
d953709
ok
j316chuck Nov 8, 2023
f235dfe
fix
siriuslee Oct 24, 2023
e45d410
log message
siriuslee Oct 24, 2023
e1c0079
bump mcli version
siriuslee Nov 7, 2023
ad92077
commit change
j316chuck Nov 8, 2023
5583ef6
commit change
j316chuck Nov 8, 2023
a2195f6
ok test
j316chuck Nov 8, 2023
e3b0320
commit change
j316chuck Nov 8, 2023
2b7e2d0
commit change
j316chuck Nov 8, 2023
54b5792
commit change
j316chuck Nov 10, 2023
0226afb
Merge branch 'dev' into chuck/protected
j316chuck Nov 10, 2023
5913d65
commit change
j316chuck Nov 10, 2023
a2136a7
commit change
j316chuck Nov 10, 2023
4d32153
commit change
j316chuck Nov 10, 2023
4fb7034
Revert "commit change"
j316chuck Nov 10, 2023
f4700a2
commit change
j316chuck Nov 10, 2023
9da1d73
commit change
j316chuck Nov 10, 2023
10c1bd3
commit change
j316chuck Nov 10, 2023
933d96c
commit change
j316chuck Nov 11, 2023
47b5a4e
commit change
j316chuck Nov 11, 2023
3891185
commit change
j316chuck Nov 11, 2023
6f46a1a
commit change
j316chuck Nov 11, 2023
95aed5f
commit change
j316chuck Nov 11, 2023
22dc85e
fix
j316chuck Nov 11, 2023
ce95c18
commit change
j316chuck Nov 11, 2023
2797909
commit change
j316chuck Nov 11, 2023
0ce1d1c
commit change
j316chuck Nov 11, 2023
4764f00
commit change
j316chuck Nov 11, 2023
089fe31
Merge branch 'dev' into chuck/protected
j316chuck Nov 11, 2023
3ed7b58
commit change
j316chuck Nov 11, 2023
822175d
commit change
j316chuck Nov 11, 2023
6e6adb4
commit change
j316chuck Nov 12, 2023
80830a4
Merge branch 'dev' into chuck/protected
mvpatel2000 Nov 13, 2023
7095089
Disable MosaicMLLogger by default
siriuslee Nov 13, 2023
70bf91f
sort import
siriuslee Nov 13, 2023
c975070
Disable MosaicMLLogger in docstest
siriuslee Nov 13, 2023
c5af8fa
remove try/except
dakinggg Nov 14, 2023
da6818f
erge branch 'dev' into chuck/protected
dakinggg Nov 14, 2023
208da63
try again
dakinggg Nov 14, 2023
6ff530e
fix?
dakinggg Nov 14, 2023
16a31c2
fix?
dakinggg Nov 14, 2023
c4ffd2d
fix?
dakinggg Nov 14, 2023
3a71eaf
fix?
dakinggg Nov 14, 2023
3e02574
add cleanup
mvpatel2000 Nov 14, 2023
9e77e8f
cache
mvpatel2000 Nov 14, 2023
dbb4ed2
commit change
j316chuck Nov 14, 2023
90866ab
fix logging
j316chuck Nov 14, 2023
723ca92
Update docs/source/doctest_fixtures.py
j316chuck Nov 14, 2023
99696de
commit change
j316chuck Nov 14, 2023
bb235b3
commit change
j316chuck Nov 14, 2023
6ec08de
commit change
j316chuck Nov 14, 2023
602ac9e
commit change
j316chuck Nov 14, 2023
cc40eca
commit change
j316chuck Nov 14, 2023
b30e413
Merge branch 'dev' into chuck/protected
j316chuck Nov 14, 2023
5b1adb3
Update docs/source/doctest_fixtures.py
j316chuck Nov 14, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions composer/algorithms/seq_length_warmup/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def training_loop(model, train_loader):
<!--pytest.mark.gpu-->
<!--
```python
import os
previous_platform_env = os.environ["MOSAICML_PLATFORM"]
os.environ["MOSAICML_PLATFORM"] = "false"
from tests.common.models import configure_tiny_bert_hf_model
from tests.common.datasets import dummy_bert_lm_dataloader

Expand All @@ -64,6 +67,12 @@ trainer = Trainer(model=model,

trainer.fit()
```
<!--pytest-codeblocks:cont-->
<!--
```python
os.environ["MOSAICML_PLATFORM"] = previous_platform_env
```
-->

### Implementation Details

Expand Down
33 changes: 21 additions & 12 deletions composer/loggers/mosaicml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import os
import time
import warnings
from concurrent.futures import wait
from functools import reduce
from typing import TYPE_CHECKING, Any, Dict, List, Optional

Expand Down Expand Up @@ -57,22 +58,25 @@ class MosaicMLLogger(LoggerDestination):
Example 2: ``ignore_keys = ["wall_clock/*"]`` would ignore all wall clock metrics.

(default: ``None``)
ignore_exceptions: Flag to disable logging exceptions. Defaults to False.
"""

def __init__(
self,
log_interval: int = 60,
ignore_keys: Optional[List[str]] = None,
ignore_exceptions: bool = False,
) -> None:
self.log_interval = log_interval
self.ignore_keys = ignore_keys
self.ignore_exceptions = ignore_exceptions
self._enabled = dist.get_global_rank() == 0
if self._enabled:
self.allowed_fails_left = 3
self.time_last_logged = 0
self.train_dataloader_len = None
self.time_failed_count_adjusted = 0
self.buffered_metadata: Dict[str, Any] = {}
self._futures = []

self.run_name = os.environ.get(RUN_NAME_ENV_VAR)
if self.run_name is not None:
log.info(f'Logging to mosaic run {self.run_name}')
Expand Down Expand Up @@ -140,20 +144,25 @@ def _flush_metadata(self, force_flush: bool = False) -> None:
"""Flush buffered metadata to MosaicML if enough time has passed since last flush."""
if self._enabled and (time.time() - self.time_last_logged > self.log_interval or force_flush):
try:
mcli.update_run_metadata(self.run_name, self.buffered_metadata)
f = mcli.update_run_metadata(self.run_name, self.buffered_metadata, future=True, protect=True)
dakinggg marked this conversation as resolved.
Show resolved Hide resolved
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
self.buffered_metadata = {}
self.time_last_logged = time.time()
# If we have not failed in the last hour, increase the allowed fails. This increases
# robustness to transient network issues.
if time.time() - self.time_failed_count_adjusted > 3600 and self.allowed_fails_left < 3:
self.allowed_fails_left += 1
self.time_failed_count_adjusted = time.time()
self._futures.append(f)
done, incomplete = wait(self._futures, timeout=0.01)
log.info(f'Logged {len(done)} metadata to MosaicML, waiting on {len(incomplete)}')
# Raise any exceptions
for f in done:
if f.exception() is not None:
raise f.exception() # type: ignore
self._futures = list(incomplete)
except Exception as e:
log.error(f'Failed to log metadata to Mosaic with error: {e}')
self.allowed_fails_left -= 1
self.time_failed_count_adjusted = time.time()
if self.allowed_fails_left <= 0:
log.exception(f'Failed to log metadata to Mosaic with error: {e}')
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
if self.ignore_exceptions:
log.info('Ignoring exception and disabling MosaicMLLogger.')
self._enabled = False
else:
log.info('Raising exception. To ignore exceptions, set ignore_exceptions=True.')
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
raise

def _get_training_progress_metrics(self, state: State) -> Dict[str, Any]:
"""Calculates training progress metrics.
Expand Down
6 changes: 6 additions & 0 deletions docs/source/doctest_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from composer.loggers import InMemoryLogger as InMemoryLogger
from composer.loggers import Logger as Logger
from composer.loggers import RemoteUploaderDownloader
from composer.loggers.mosaicml_logger import MOSAICML_PLATFORM_ENV_VAR
j316chuck marked this conversation as resolved.
Show resolved Hide resolved
from composer.models import ComposerModel as ComposerModel
from composer.optim.scheduler import ConstantScheduler
from composer.utils import LibcloudObjectStore
Expand Down Expand Up @@ -92,6 +93,11 @@
# Disable wandb
os.environ['WANDB_MODE'] = 'disabled'

# Disable MosaicMLLogger
os.environ['MOSAICML_PLATFORM'] = 'false'
if 'MOSAICML_ACCESS_TOKEN_FILE' in os.environ:
del os.environ['MOSAICML_ACCESS_TOKEN_FILE']
j316chuck marked this conversation as resolved.
Show resolved Hide resolved

# Change the cwd to be the tempfile, so we don't pollute the documentation source folder
tmpdir = tempfile.mkdtemp()
cwd = os.path.abspath('.')
Expand Down
1 change: 1 addition & 0 deletions docs/source/trainer/logging.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ and also saves them to the file

os.environ["WANDB_MODE"] = "disabled"
os.environ["COMET_API_KEY"] = "<comet_api_key>"
os.environ["MLFLOW_TRACKING_URI"] = ""

.. testcode::
:skipif: not _WANDB_INSTALLED or not _COMETML_INSTALLED
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def package_files(prefix: str, directory: str, extension: str):
'py-cpuinfo>=8.0.0,<10',
'packaging>=21.3.0,<23',
'importlib-metadata>=5.0.0,<7',
'mosaicml-cli>=0.5.8,<0.6',
'mosaicml-cli>=0.5.25,<0.6',
]
extra_deps = {}

Expand Down
8 changes: 7 additions & 1 deletion tests/fixtures/autouse_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import os
import pathlib
from concurrent.futures import Future

import mcli
import pytest
Expand All @@ -13,6 +14,7 @@

import composer
from composer.devices import DeviceCPU, DeviceGPU
from composer.loggers.mosaicml_logger import MOSAICML_PLATFORM_ENV_VAR
from composer.utils import dist, reproducibility


Expand Down Expand Up @@ -118,8 +120,12 @@ def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch):
@pytest.fixture(autouse=True)
def mapi_fixture(monkeypatch):
# Composer auto-adds mosaicml logger when running on platform. Disable logging for tests.
mock_update = lambda *args, **kwargs: None
future_obj = Future()
future_obj.set_result(None)
mock_update = lambda *args, **kwargs: future_obj
monkeypatch.setattr(mcli, 'update_run_metadata', mock_update)
# Disable the MosaicMLLogger by default
monkeypatch.setenv(MOSAICML_PLATFORM_ENV_VAR, 'false')
j316chuck marked this conversation as resolved.
Show resolved Hide resolved


@pytest.fixture(autouse=True)
Expand Down
49 changes: 46 additions & 3 deletions tests/loggers/test_mosaicml_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import json
from concurrent.futures import Future
from typing import Type
from unittest.mock import MagicMock

Expand All @@ -23,10 +24,26 @@

class MockMAPI:

def __init__(self):
def __init__(self, simulate_exception: bool = False):
self.run_metadata = {}

def update_run_metadata(self, run_name, new_metadata):
self.simulate_exception = simulate_exception

def update_run_metadata(self, run_name, new_metadata, future=False, protect=True):
if future:
# Simulate asynchronous behavior using Future
future_obj = Future()
try:
self._update_metadata(run_name, new_metadata)
future_obj.set_result(None) # Set a result to indicate completion
except Exception as e:
future_obj.set_exception(e) # Set an exception if something goes wrong
return future_obj
else:
self._update_metadata(run_name, new_metadata)

def _update_metadata(self, run_name, new_metadata):
if self.simulate_exception:
raise RuntimeError('Simulated exception')
if run_name not in self.run_metadata:
self.run_metadata[run_name] = {}
for k, v in new_metadata.items():
Expand Down Expand Up @@ -94,6 +111,32 @@ def test_logged_data_is_json_serializable(monkeypatch, callback_cls: Type[Callba
assert len(mock_mapi.run_metadata.keys()) == 0


@world_size(1, 2)
@pytest.mark.parametrize('ignore_exceptions', [True, False])
def test_logged_data_exception_handling(monkeypatch, world_size: int, ignore_exceptions: bool):
"""Test that exceptions in MAPI are raised properly."""
mock_mapi = MockMAPI(simulate_exception=True)
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
run_name = 'small_chungus'
monkeypatch.setenv('RUN_NAME', run_name)

logger = MosaicMLLogger(ignore_exceptions=ignore_exceptions)
if dist.get_global_rank() != 0:
assert logger._enabled is False
logger._flush_metadata(force_flush=True)
assert logger._enabled is False
elif ignore_exceptions:
assert logger._enabled is True
logger._flush_metadata(force_flush=True)
assert logger._enabled is False
else:
with pytest.raises(RuntimeError, match='Simulated exception'):
assert logger._enabled is True
mock_mapi = MockMAPI(simulate_exception=True)
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
logger._flush_metadata(force_flush=True)


def test_metric_partial_filtering(monkeypatch):
mock_mapi = MockMAPI()
monkeypatch.setattr(mcli, 'update_run_metadata', mock_mapi.update_run_metadata)
Expand Down
Loading