Skip to content

Commit

Permalink
fixup! Added basic file logger Lightning-AI#1803
Browse files Browse the repository at this point in the history
  • Loading branch information
xmotli02 authored and Borda committed Aug 5, 2020
1 parent 75d1159 commit 29de2d8
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 16 deletions.
3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

from pytorch_lightning.loggers.base import LightningLoggerBase, LoggerCollection
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.loggers.file_logger import FileLogger
from pytorch_lightning.loggers.csv import CSVLogger


__all__ = [
'LightningLoggerBase',
'LoggerCollection',
'TensorBoardLogger',
'CSVLogger',
]

try:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
"""
File logger
-----------
CSV logger
----------
CSV logger for basic experiment logging that does not require opening ports
"""
import io
import os
Expand All @@ -14,10 +17,25 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.core.saving import save_hparams_to_yaml
from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities.distributed import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_warn, rank_zero_only


class ExperimentWriter(object):
r"""
Experiment writer for CSVLogger.
Currently supports to log hyperparameters and metrics in YAML and CSV
format. Creates the directory structure:
```
log_dir/
hparams.yaml
metrics.csv
```
Args:
log_dir: Directory for the experiment logs
"""

NAME_HPARAMS_FILE = 'hparams.yaml'
NAME_METRICS_FILE = 'metrics.csv'

Expand All @@ -27,15 +45,21 @@ def __init__(self, log_dir):
self.metrics_keys = ["step"]

self.log_dir = log_dir
if not os.path.exists(log_dir):
os.makedirs(log_dir)
if os.path.exists(self.log_dir):
rank_zero_warn(
f"Experiment logs directory {self.log_dir} exists and is not empty. "
"Previous log files in this directory will be deleted when the new ones are saved!"
)
os.makedirs(self.log_dir, exist_ok=True)

self.metrics_file_path = os.path.join(self.log_dir, self.NAME_METRICS_FILE)

def log_hparams(self, params):
"""Record hparams"""
self.hparams.update(params)

def log_metrics(self, metrics_dict, step=None):
"""Record metrics"""
def _handle_value(value):
if isinstance(value, torch.Tensor):
return value.item()
Expand All @@ -53,6 +77,7 @@ def _handle_value(value):
self.metrics.append(new_row)

def save(self):
"""Save recorded hparams and metrics into files"""
hparams_file = os.path.join(self.log_dir, self.NAME_HPARAMS_FILE)
save_hparams_to_yaml(hparams_file, self.hparams)

Expand All @@ -63,15 +88,15 @@ def save(self):
self.writer.writerows(self.metrics)


class FileLogger(LightningLoggerBase):
class CSVLogger(LightningLoggerBase):
r"""
Log to local file system in yaml and CSV format. Logs are saved to
``os.path.join(save_dir, name, version)``.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.loggers import FileLogger
>>> logger = FileLogger("logs", name="my_exp_name")
>>> from pytorch_lightning.loggers import CSVLogger
>>> logger = CSVLogger("logs", name="my_exp_name")
>>> trainer = Trainer(logger=logger)
Args:
Expand Down
14 changes: 7 additions & 7 deletions tests/loggers/test_file_logger.py → tests/loggers/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
import os

from pytorch_lightning.loggers import FileLogger
from pytorch_lightning.loggers import CSVLogger


def test_file_logger_automatic_versioning(tmpdir):
Expand All @@ -14,7 +14,7 @@ def test_file_logger_automatic_versioning(tmpdir):
root_dir.mkdir("version_0")
root_dir.mkdir("version_1")

logger = FileLogger(save_dir=tmpdir, name="exp")
logger = CSVLogger(save_dir=tmpdir, name="exp")

assert logger.version == 2

Expand All @@ -27,7 +27,7 @@ def test_file_logger_manual_versioning(tmpdir):
root_dir.mkdir("version_1")
root_dir.mkdir("version_2")

logger = FileLogger(save_dir=tmpdir, name="exp", version=1)
logger = CSVLogger(save_dir=tmpdir, name="exp", version=1)

assert logger.version == 1

Expand All @@ -39,7 +39,7 @@ def test_file_logger_named_version(tmpdir):
tmpdir.mkdir(exp_name)
expected_version = "2020-02-05-162402"

logger = FileLogger(save_dir=tmpdir, name=exp_name, version=expected_version)
logger = CSVLogger(save_dir=tmpdir, name=exp_name, version=expected_version)
logger.log_hyperparams({"a": 1, "b": 2})
logger.save()
assert logger.version == expected_version
Expand All @@ -50,15 +50,15 @@ def test_file_logger_named_version(tmpdir):
@pytest.mark.parametrize("name", ['', None])
def test_file_logger_no_name(tmpdir, name):
"""Verify that None or empty name works"""
logger = FileLogger(save_dir=tmpdir, name=name)
logger = CSVLogger(save_dir=tmpdir, name=name)
logger.save()
assert logger.root_dir == tmpdir
assert os.listdir(tmpdir / 'version_0')


@pytest.mark.parametrize("step_idx", [10, None])
def test_file_logger_log_metrics(tmpdir, step_idx):
logger = FileLogger(tmpdir)
logger = CSVLogger(tmpdir)
metrics = {
"float": 0.3,
"int": 1,
Expand All @@ -70,7 +70,7 @@ def test_file_logger_log_metrics(tmpdir, step_idx):


def test_file_logger_log_hyperparams(tmpdir):
logger = FileLogger(tmpdir)
logger = CSVLogger(tmpdir)
hparams = {
"float": 0.3,
"int": 1,
Expand Down

0 comments on commit 29de2d8

Please sign in to comment.