Implemented alternative experiment tracking functionality (#102)
Leminen authored Nov 7, 2024
1 parent 859ca42 commit 7b5acfe
Showing 15 changed files with 1,718 additions and 795 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -111,8 +111,9 @@ data/
# Models
models/

-# Weights and Biases experiment tracking
+# Experiment tracking
wandb/
+mlruns/

# Data files
*.xlsx
3 changes: 3 additions & 0 deletions README.md
@@ -13,7 +13,10 @@ ______________________________________________________________________

Developers:

- Anders Jess Pedersen ([email protected])
- Dan Saattrup Nielsen ([email protected])
+- Simon Leminen Madsen ([email protected])



## Installation
9 changes: 4 additions & 5 deletions config/asr_finetuning.yaml
@@ -1,16 +1,19 @@
defaults:
-  - model: wav2vec2-small
+  - model: whisper-xxsmall
  - datasets:
      - coral
  - decoder_datasets:
      - wikipedia
      - common_voice
      - reddit
+  - experiment_tracking: wandb
  - override hydra/job_logging: custom
  - _self_

seed: 4242

+experiment_tracking: null

evaluation_dataset:
  id: alexandrainst/coral
  subset: read_aloud
@@ -48,10 +51,6 @@ fp16_allowed: true
bf16_allowed: true

# Training parameters
-wandb: false
-wandb_project: CoRal
-wandb_group: default
-wandb_name: ${model_id}
resume_from_checkpoint: false
ignore_data_skip: false
save_total_limit: 0 # Will automatically be set to >=1 if `early_stopping` is enabled
3 changes: 3 additions & 0 deletions config/experiment_tracking/mlflow.yaml
@@ -0,0 +1,3 @@
type: mlflow
name_experiment: CoRal
name_run: ${model_id}
4 changes: 4 additions & 0 deletions config/experiment_tracking/wandb.yaml
@@ -0,0 +1,4 @@
type: wandb
name_experiment: CoRal
name_run: ${model_id}
name_group: default
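Note (not part of the diff): these two small YAML files are Hydra config-group options, and once composed into the primary config the chosen tracker's settings sit under a single experiment_tracking key. A minimal OmegaConf sketch of the composed shape — the config here is hand-built for illustration, and model_id is an assumed top-level value feeding the ${model_id} interpolation:

from omegaconf import OmegaConf

# Hand-built stand-in for the Hydra-composed config; keys mirror wandb.yaml above.
config = OmegaConf.create(
    {
        "model_id": "whisper-xxsmall",
        "experiment_tracking": {
            "type": "wandb",
            "name_experiment": "CoRal",
            "name_run": "${model_id}",
            "name_group": "default",
        },
    }
)
resolved = OmegaConf.to_container(config, resolve=True)
print(resolved["experiment_tracking"]["name_run"])  # prints: whisper-xxsmall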
2,328 changes: 1,552 additions & 776 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
@@ -36,6 +36,7 @@ gradio = {version = "^5.5.0", optional=true}
samplerate = {version="^0.2.1", optional=true}
punctfix = {version="^0.11.1", optional=true}
matplotlib = {version = "^3.9.2", optional = true}
+mlflow = "^2.17.2"

[tool.poetry.group.dev.dependencies]
pytest = ">=8.1.1"
4 changes: 4 additions & 0 deletions src/coral/experiment_tracking/__init__.py
@@ -0,0 +1,4 @@
"""The CoRal project.
Experiment tracking.
"""
28 changes: 28 additions & 0 deletions src/coral/experiment_tracking/extracking_factory.py
@@ -0,0 +1,28 @@
"""Factory for experiment tracking setup."""

from omegaconf import DictConfig

from .extracking_setup import ExTrackingSetup
from .mlflow_setup import MLFlowSetup
from .wandb_setup import WandbSetup


def load_extracking_setup(config: DictConfig) -> ExTrackingSetup:
    """Return the experiment tracking setup.

    Args:
        config:
            The configuration object.

    Returns:
        The experiment tracking setup.
    """
    match config.experiment_tracking.type:
        case "wandb":
            return WandbSetup(config=config)
        case "mlflow":
            return MLFlowSetup(config=config)
        case _:
            raise ValueError(
                f"Unknown experiment tracking type: {config.experiment_tracking.type}"
            )
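A hedged usage sketch of the factory, not taken from the repo — the config is hand-built with only the fields the mlflow branch reads:

from omegaconf import OmegaConf

from coral.experiment_tracking.extracking_factory import load_extracking_setup

# Minimal hand-built config; in the repo this object comes from Hydra.
config = OmegaConf.create(
    {
        "experiment_tracking": {
            "type": "mlflow",
            "name_experiment": "CoRal",
            "name_run": "demo",
        }
    }
)
setup = load_extracking_setup(config=config)  # dispatches to MLFlowSetup
setup.run_initialization()
# ... training would run here ...
setup.run_finalization()

Note that the match statement used for dispatch pins the package to Python 3.10 or newer.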
34 changes: 34 additions & 0 deletions src/coral/experiment_tracking/extracking_setup.py
@@ -0,0 +1,34 @@
"""This module contains the base class for an experiment tracking setup."""

from abc import ABC, abstractmethod

from omegaconf import DictConfig


class ExTrackingSetup(ABC):
    """Base class for an experiment tracking setup."""

    @abstractmethod
    def __init__(self, config: DictConfig) -> None:
        """Initialise the experiment tracking setup.

        Args:
            config:
                The configuration object.
        """

    @abstractmethod
    def run_initialization(self) -> None:
        """Run the initialization of the experiment tracking setup.

        Returns:
            True if the initialization was successful, False otherwise.
        """

    @abstractmethod
    def run_finalization(self) -> None:
        """Run the finalization of the experiment tracking setup.

        Returns:
            True if the finalization was successful, False otherwise.
        """
33 changes: 33 additions & 0 deletions src/coral/experiment_tracking/mlflow_setup.py
@@ -0,0 +1,33 @@
"""MLFlow experiment tracking setup class."""

import os

import mlflow
from omegaconf import DictConfig

from .extracking_setup import ExTrackingSetup


class MLFlowSetup(ExTrackingSetup):
    """MLFlow setup class."""

    def __init__(self, config: DictConfig) -> None:
        """Initialise the MLFlow setup.

        Args:
            config:
                The configuration object.
        """
        self.config = config
        self.is_main_process = os.getenv("RANK", "0") == "0"

    def run_initialization(self) -> None:
        """Run the initialization of the experiment tracking setup."""
        mlflow.set_experiment(self.config.experiment_tracking.name_experiment)
        mlflow.start_run(run_name=self.config.experiment_tracking.name_run)
        return

    def run_finalization(self) -> None:
        """Run the finalization of the experiment tracking setup."""
        mlflow.end_run()
        return
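For context on the two MLflow calls (a standalone sketch, not repo code): without a configured tracking server, runs are written to a local mlruns/ directory, which is why that path was added to .gitignore above.

import mlflow

mlflow.set_experiment("CoRal")  # selects the experiment, creating it if missing
with mlflow.start_run(run_name="demo"):  # context manager ends the run on exit
    mlflow.log_metric("wer", 0.123)

The setup class uses explicit start_run/end_run rather than the context manager because initialization and finalization happen in separate methods around the training loop.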
37 changes: 37 additions & 0 deletions src/coral/experiment_tracking/wandb_setup.py
@@ -0,0 +1,37 @@
"""wandb experiment tracking setup class."""

import os

import wandb
from omegaconf import DictConfig

from .extracking_setup import ExTrackingSetup


class WandbSetup(ExTrackingSetup):
    """Wandb setup class."""

    def __init__(self, config: DictConfig) -> None:
        """Initialise the Wandb setup.

        Args:
            config:
                The configuration object.
        """
        self.config = config
        self.is_main_process = os.getenv("RANK", "0") == "0"

    def run_initialization(self) -> None:
        """Run the initialization of the experiment tracking setup."""
        wandb.init(
            project=self.config.experiment_tracking.name_experiment,
            name=self.config.experiment_tracking.name_run,
            group=self.config.experiment_tracking.name_group,
            config=dict(self.config),
        )
        return

    def run_finalization(self) -> None:
        """Run the finalization of the experiment tracking setup."""
        wandb.finish()
        return
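The wandb counterpart, sketched standalone for comparison (all values are placeholders):

import wandb

run = wandb.init(project="CoRal", group="default", name="demo", config={"lr": 3e-4})
run.log({"wer": 0.123})
wandb.finish()

The group argument is wandb-specific, which is why wandb.yaml carries a name_group field and mlflow.yaml does not.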
18 changes: 7 additions & 11 deletions src/coral/finetune.py
@@ -5,11 +5,10 @@

from omegaconf import DictConfig
from transformers import EarlyStoppingCallback, TrainerCallback
-from wandb import finish as wandb_finish
-from wandb.sdk.wandb_init import init as wandb_init

from .data import load_data_for_finetuning
from .data_models import ModelSetup
+from .experiment_tracking.extracking_factory import load_extracking_setup
from .model_setup import load_model_setup
from .ngram import train_and_store_ngram_model
from .utils import block_terminal_output, disable_tqdm, push_model_to_hub
@@ -33,13 +32,9 @@ def finetune(config: DictConfig) -> None:
    model = model_setup.load_model()
    dataset = load_data_for_finetuning(config=config, processor=processor)

-    if config.wandb and is_main_process:
-        wandb_init(
-            project=config.wandb_project,
-            group=config.wandb_group,
-            name=config.wandb_name,
-            config=dict(config),
-        )
+    if bool(config.experiment_tracking) and is_main_process:
+        extracking_setup = load_extracking_setup(config=config)
+        extracking_setup.run_initialization()

    if "val" not in dataset and is_main_process:
        logger.info("No validation set found. Disabling early stopping.")
@@ -58,8 +53,9 @@
    block_terminal_output()
    with disable_tqdm():
        trainer.train(resume_from_checkpoint=config.resume_from_checkpoint)
-    if config.wandb and is_main_process:
-        wandb_finish()
+
+    if bool(config.experiment_tracking) and is_main_process:
+        extracking_setup.run_finalization()

    model.save_pretrained(save_directory=config.model_dir)

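Why bool(config.experiment_tracking) works as the on/off switch: experiment_tracking: null in the primary config resolves to None, which is falsy, while a selected config group yields a non-empty DictConfig, which is truthy. A quick sketch (hand-built configs, assuming standard OmegaConf behaviour):

from omegaconf import OmegaConf

off = OmegaConf.create({"experiment_tracking": None})
on = OmegaConf.create({"experiment_tracking": {"type": "mlflow"}})
print(bool(off.experiment_tracking))  # False: null resolves to None
print(bool(on.experiment_tracking))  # True: non-empty DictConfig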
4 changes: 3 additions & 1 deletion src/coral/wav2vec2.py
@@ -206,7 +206,9 @@ def load_training_arguments(self) -> TrainingArguments:
            optim=OptimizerNames.ADAMW_TORCH,
            adam_beta1=self.config.adam_first_momentum,
            adam_beta2=self.config.adam_second_momentum,
-            report_to=["wandb"] if self.config.wandb else [],
+            report_to=[self.config.experiment_tracking.type]
+            if self.config.experiment_tracking
+            else [],
            ignore_data_skip=self.config.ignore_data_skip,
            save_safetensors=True,
            use_cpu=hasattr(sys, "_called_from_test"),
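Passing the config's type field straight through to report_to works because transformers ships logging integrations under these exact names: both "wandb" and "mlflow" are accepted report_to values. A minimal sketch (output_dir is a placeholder):

from transformers import TrainingArguments

args = TrainingArguments(output_dir="out", report_to=["mlflow"])
print(args.report_to)  # ['mlflow']

whisper.py below receives the identical change.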
4 changes: 3 additions & 1 deletion src/coral/whisper.py
@@ -190,7 +190,9 @@ def load_training_arguments(self) -> TrainingArguments:
            optim=OptimizerNames.ADAMW_TORCH,
            adam_beta1=self.config.adam_first_momentum,
            adam_beta2=self.config.adam_second_momentum,
-            report_to=["wandb"] if self.config.wandb else [],
+            report_to=[self.config.experiment_tracking.type]
+            if self.config.experiment_tracking
+            else [],
            ignore_data_skip=self.config.ignore_data_skip,
            save_safetensors=True,
            predict_with_generate=True,
