-
Notifications
You must be signed in to change notification settings - Fork 3.4k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
xmotli02
committed
Jul 27, 2020
1 parent
3f2c102
commit 54e57aa
Showing
4 changed files
with
274 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,181 @@ | ||
""" | ||
File logger | ||
----------- | ||
""" | ||
import io | ||
import os | ||
import csv | ||
import torch | ||
|
||
from argparse import Namespace | ||
from typing import Optional, Dict, Any, Union | ||
|
||
|
||
from pytorch_lightning import _logger as log | ||
from pytorch_lightning.core.saving import save_hparams_to_yaml | ||
from pytorch_lightning.loggers.base import LightningLoggerBase | ||
from pytorch_lightning.utilities.distributed import rank_zero_only | ||
|
||
|
||
class ExperimentWriter(object): | ||
NAME_HPARAMS_FILE = 'hparams.yaml' | ||
NAME_METRICS_FILE = 'metrics.csv' | ||
|
||
def __init__(self, log_dir): | ||
self.hparams = {} | ||
self.metrics = [] | ||
self.metrics_keys = ["step"] | ||
|
||
self.log_dir = log_dir | ||
if not os.path.exists(log_dir): | ||
os.makedirs(log_dir) | ||
|
||
self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE) | ||
|
||
def log_hparams(self, params): | ||
self.hparams.update(params) | ||
|
||
def log_metrics(self, metrics_dict, step=None): | ||
def _handle_value(value): | ||
if isinstance(value, torch.Tensor): | ||
return value.item() | ||
return value | ||
|
||
if step is None: | ||
step = len(self.metrics) | ||
|
||
new_row = dict.fromkeys(self.metrics_keys) | ||
new_row['step'] = step | ||
for k, v in metrics_dict.items(): | ||
if k not in self.metrics_keys: | ||
self.metrics_keys.append(k) | ||
new_row[k] = _handle_value(v) | ||
self.metrics.append(new_row) | ||
|
||
def save(self): | ||
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE) | ||
save_hparams_to_yaml(hparams_file, self.hparams) | ||
|
||
if self.metrics: | ||
with io.open(self.metrics_file_path, 'w', newline='') as f: | ||
self.writer = csv.DictWriter(f, fieldnames=self.metrics_keys) | ||
self.writer.writeheader() | ||
self.writer.writerows(self.metrics) | ||
|
||
|
||
class FileLogger(LightningLoggerBase): | ||
r""" | ||
Log to local file system in yaml and CSV format. Logs are saved to | ||
``os.path.join(save_dir, name, version)``. | ||
Example: | ||
>>> from pytorch_lightning import Trainer | ||
>>> from pytorch_lightning.loggers import FileLogger | ||
>>> logger = FileLogger("logs", name="my_exp_name") | ||
>>> trainer = Trainer(logger=logger) | ||
Args: | ||
save_dir: Save directory | ||
name: Experiment name. Defaults to ``'default'``. | ||
version: Experiment version. If version is not specified the logger inspects the save | ||
directory for existing versions, then automatically assigns the next available version. | ||
""" | ||
|
||
def __init__(self, | ||
save_dir: str, | ||
name: Optional[str] = "default", | ||
version: Optional[Union[int, str]] = None): | ||
|
||
super().__init__() | ||
self._save_dir = save_dir | ||
self._name = name or '' | ||
self._version = version | ||
self._experiment = None | ||
|
||
@property | ||
def root_dir(self) -> str: | ||
""" | ||
Parent directory for all checkpoint subdirectories. | ||
If the experiment name parameter is ``None`` or the empty string, no experiment subdirectory is used | ||
and the checkpoint will be saved in "save_dir/version_dir" | ||
""" | ||
if self.name is None or len(self.name) == 0: | ||
return self._save_dir | ||
else: | ||
return os.path.join(self._save_dir, self.name) | ||
|
||
@property | ||
def log_dir(self) -> str: | ||
""" | ||
The log directory for this run. By default, it is named | ||
``'version_${self.version}'`` but it can be overridden by passing a string value | ||
for the constructor's version parameter instead of ``None`` or an int. | ||
""" | ||
# create a pseudo standard path ala test-tube | ||
version = self.version if isinstance(self.version, str) else f"version_{self.version}" | ||
log_dir = os.path.join(self.root_dir, version) | ||
return log_dir | ||
|
||
@property | ||
def experiment(self) -> ExperimentWriter: | ||
r""" | ||
Actual ExperimentWriter object. To use ExperimentWriter features in your | ||
:class:`~pytorch_lightning.core.lightning.LightningModule` do the following. | ||
Example:: | ||
self.logger.experiment.some_experiment_writer_function() | ||
""" | ||
if self._experiment is not None: | ||
return self._experiment | ||
|
||
os.makedirs(self.root_dir, exist_ok=True) | ||
self._experiment = ExperimentWriter(log_dir=self.log_dir) | ||
return self._experiment | ||
|
||
@rank_zero_only | ||
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: | ||
params = self._convert_params(params) | ||
self.experiment.log_hparams(params) | ||
|
||
@rank_zero_only | ||
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: | ||
self.experiment.log_metrics(metrics, step) | ||
|
||
@rank_zero_only | ||
def save(self) -> None: | ||
super().save() | ||
self.experiment.save() | ||
|
||
@rank_zero_only | ||
def finalize(self, status: str) -> None: | ||
self.save() | ||
|
||
@property | ||
def name(self) -> str: | ||
return self._name | ||
|
||
@property | ||
def version(self) -> int: | ||
if self._version is None: | ||
self._version = self._get_next_version() | ||
return self._version | ||
|
||
def _get_next_version(self): | ||
root_dir = os.path.join(self._save_dir, self.name) | ||
|
||
if not os.path.isdir(root_dir): | ||
log.warning('Missing logger folder: %s', root_dir) | ||
return 0 | ||
|
||
existing_versions = [] | ||
for d in os.listdir(root_dir): | ||
if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"): | ||
existing_versions.append(int(d.split("_")[1])) | ||
|
||
if len(existing_versions) == 0: | ||
return 0 | ||
|
||
return max(existing_versions) + 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
from argparse import Namespace | ||
|
||
import pytest | ||
import torch | ||
import os | ||
|
||
from pytorch_lightning.loggers import FileLogger | ||
|
||
|
||
def test_file_logger_automatic_versioning(tmpdir): | ||
"""Verify that automatic versioning works""" | ||
|
||
root_dir = tmpdir.mkdir("exp") | ||
root_dir.mkdir("version_0") | ||
root_dir.mkdir("version_1") | ||
|
||
logger = FileLogger(save_dir=tmpdir, name="exp") | ||
|
||
assert logger.version == 2 | ||
|
||
|
||
def test_file_logger_manual_versioning(tmpdir): | ||
"""Verify that manual versioning works""" | ||
|
||
root_dir = tmpdir.mkdir("exp") | ||
root_dir.mkdir("version_0") | ||
root_dir.mkdir("version_1") | ||
root_dir.mkdir("version_2") | ||
|
||
logger = FileLogger(save_dir=tmpdir, name="exp", version=1) | ||
|
||
assert logger.version == 1 | ||
|
||
|
||
def test_file_logger_named_version(tmpdir): | ||
"""Verify that manual versioning works for string versions, e.g. '2020-02-05-162402' """ | ||
|
||
exp_name = "exp" | ||
tmpdir.mkdir(exp_name) | ||
expected_version = "2020-02-05-162402" | ||
|
||
logger = FileLogger(save_dir=tmpdir, name=exp_name, version=expected_version) | ||
logger.log_hyperparams({"a": 1, "b": 2}) | ||
logger.save() | ||
assert logger.version == expected_version | ||
assert os.listdir(tmpdir / exp_name) == [expected_version] | ||
assert os.listdir(tmpdir / exp_name / expected_version) | ||
|
||
|
||
@pytest.mark.parametrize("name", ['', None]) | ||
def test_file_logger_no_name(tmpdir, name): | ||
"""Verify that None or empty name works""" | ||
logger = FileLogger(save_dir=tmpdir, name=name) | ||
logger.save() | ||
assert logger.root_dir == tmpdir | ||
assert os.listdir(tmpdir / 'version_0') | ||
|
||
|
||
@pytest.mark.parametrize("step_idx", [10, None]) | ||
def test_file_logger_log_metrics(tmpdir, step_idx): | ||
logger = FileLogger(tmpdir) | ||
metrics = { | ||
"float": 0.3, | ||
"int": 1, | ||
"FloatTensor": torch.tensor(0.1), | ||
"IntTensor": torch.tensor(1) | ||
} | ||
logger.log_metrics(metrics, step_idx) | ||
logger.save() | ||
|
||
|
||
def test_file_logger_log_hyperparams(tmpdir): | ||
logger = FileLogger(tmpdir) | ||
hparams = { | ||
"float": 0.3, | ||
"int": 1, | ||
"string": "abc", | ||
"bool": True, | ||
"dict": {'a': {'b': 'c'}}, | ||
"list": [1, 2, 3], | ||
"namespace": Namespace(foo=Namespace(bar='buzz')), | ||
"layer": torch.nn.BatchNorm1d | ||
} | ||
logger.log_hyperparams(hparams) | ||
logger.save() |