Skip to content

Commit

Permalink
Merge pull request #40 from alexandrainst/fix/smaller-fixes
Browse files Browse the repository at this point in the history
Fix/smaller fixes
  • Loading branch information
saattrupdan authored Oct 9, 2023
2 parents 29aa2a5 + c7f032c commit b3dcf94
Show file tree
Hide file tree
Showing 16 changed files with 284 additions and 124 deletions.
6 changes: 0 additions & 6 deletions .flake8

This file was deleted.

12 changes: 5 additions & 7 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,12 @@ repos:
rev: 23.7.0
hooks:
- id: black
- repo: https://github.com/timothycrosley/isort
rev: 5.12.0
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.290
hooks:
- id: isort
- repo: https://github.com/pycqa/flake8
rev: 6.0.0
hooks:
- id: flake8
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
types_or: [python, pyi, jupyter]
- repo: https://github.com/kynan/nbstripout
rev: 0.6.0
hooks:
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ ______________________________________________________________________
[![Documentation](https://img.shields.io/badge/docs-passing-green)](https://alexandrainst.github.io/CoRal-models/coral_models.html)
[![License](https://img.shields.io/github/license/alexandrainst/CoRal-models)](https://github.com/alexandrainst/CoRal-models/blob/main/LICENSE)
[![LastCommit](https://img.shields.io/github/last-commit/alexandrainst/CoRal-models)](https://github.com/alexandrainst/CoRal-models/commits/main)
[![Code Coverage](https://img.shields.io/badge/Coverage-61%25-yellow.svg)](https://github.com/alexandrainst/CoRal-models/tree/main/tests)
[![Code Coverage](https://img.shields.io/badge/Coverage-60%25-yellow.svg)](https://github.com/alexandrainst/CoRal-models/tree/main/tests)


Developers:
Expand Down
4 changes: 2 additions & 2 deletions config/config.yaml
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
defaults:
- model: wav2vec2
- datasets:
- common_voice_da
- nst_da
- common_voice_9_da
- override hydra/job_logging: custom
- _self_

Expand Down Expand Up @@ -38,6 +38,6 @@ logging_steps: 10
eval_steps: 100
save_steps: 100
save_total_limit: 2
early_stopping: true
early_stopping: false
early_stopping_patience: 50
fp16: true
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
common_voice_da:
common_voice_13_da:
id: mozilla-foundation/common_voice_13_0
subset: da
train_name: train
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
common_voice_nn:
common_voice_13_nn:
id: mozilla-foundation/common_voice_13_0
subset: nn-NO
train_name: train
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
common_voice_sv:
common_voice_13_sv:
id: mozilla-foundation/common_voice_13_0
subset: sv-SE
train_name: train
Expand Down
7 changes: 7 additions & 0 deletions config/datasets/common_voice_9_da.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
common_voice_9_da:
id: mozilla-foundation/common_voice_9_0
subset: da
train_name: train
val_name: validation
test_name: test
text_column: sentence
2 changes: 2 additions & 0 deletions config/model/test_wav2vec2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,6 @@ learning_rate: 4e-5
warmup_steps: 1
early_stopping: true
early_stopping_patience: 5
adam_first_momentum: 0.9
adam_second_momentum: 0.999
fp16: false
17 changes: 10 additions & 7 deletions config/model/wav2vec2.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '
# Model hyperparameters
sampling_rate: 16_000
activation_dropout: 0.1
attention_dropout: 0.1
hidden_dropout: 0.1
feat_proj_dropout: 0.1
final_dropout: 0.1
attention_dropout: 0.0
hidden_dropout: 0.0
feat_proj_dropout: 0.0
feat_quantizer_dropout: 0.0
final_dropout: 0.0
mask_time_prob: 0.5
mask_time_length: 10
mask_feature_prob: 0.5
Expand All @@ -31,7 +32,9 @@ decoder:

# Training hyperparameters
batch_size: 8
gradient_accumulation: 4
max_steps: 120_000
gradient_accumulation: 32
max_steps: 13_000 # Based on the XLS-R paper, section 4.3
warmup_steps: 1_300 # Based on the XLS-R paper, section 4.3
learning_rate: 3e-5
warmup_steps: 500
adam_first_momentum: 0.9
adam_second_momentum: 0.98
40 changes: 40 additions & 0 deletions config/model/wav2vec2_no_reg.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
name: wav2vec2_no_reg
type: wav2vec2
pretrained_model_id: chcaa/xls-r-300m-danish
freeze_feature_encoder: false

# Data hyperparameters
clean_dataset: true
characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '

# Model hyperparameters
sampling_rate: 16_000
activation_dropout: 0.0
hidden_dropout: 0.0
attention_dropout: 0.0
feat_proj_dropout: 0.0
feat_quantizer_dropout: 0.0
final_dropout: 0.0
mask_time_prob: 0.0
mask_time_length: 10
mask_feature_prob: 0.0
mask_feature_length: 64
layerdrop: 0.0
ctc_loss_reduction: mean

# Decoder hyperparameters
language_model_decoder: ngram
decoder:
dataset_id: DDSC/reddit-da-asr-preprocessed
dataset_subset: null
dataset_split: train
n: 5

# Training hyperparameters
batch_size: 8
gradient_accumulation: 32
max_steps: 13_000
learning_rate: 3e-5
adam_first_momentum: 0.9
adam_second_momentum: 0.98
warmup_steps: 1_300
195 changes: 130 additions & 65 deletions poetry.lock

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ license = "MIT"
python = ">=3.10,<3.12"
hydra-core = "^1.1.1"
evaluate = ">=0.4.0,<1.0.0"
transformers = "^4.30.0"
transformers = "^4.33.0"
torch = "^2.0.0,!=2.0.1"
librosa = ">=0.10.0.post2,<1.0.0"
soundfile = ">=0.12.1,<1.0.0"
Expand Down Expand Up @@ -73,8 +73,11 @@ exclude = '''
)/
'''

[tool.isort]
profile = "black"
[tool.ruff]
target-version = "py311"

[tool.ruff.extend-per-file-ignores]
"__init__.py" = ["F401"]

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
42 changes: 35 additions & 7 deletions src/coral_models/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import logging
import os
import re
from pathlib import Path
from unicodedata import normalize

from datasets import (
Expand Down Expand Up @@ -36,15 +37,43 @@ def load_data(cfg: DictConfig) -> DatasetDict | IterableDatasetDict:
for dataset_name, dataset_cfg in cfg.datasets.items():
logger.info(f"Loading dataset {dataset_name!r}")

# Load from disk if the dataset ID is a path
if Path(dataset_cfg.id).exists():
dataset_paths = {
dataset_cfg.train_name: Path(dataset_cfg.id) / dataset_cfg.train_name
}
if dataset_cfg.val_name is not None:
dataset_paths[dataset_cfg.val_name] = (
Path(dataset_cfg.id) / dataset_cfg.val_name
)
if dataset_cfg.test_name is not None:
dataset_paths[dataset_cfg.test_name] = (
Path(dataset_cfg.id) / dataset_cfg.test_name
)
data_files = {
split: list(map(str, split_path.glob("data-*.arrow")))
for split, split_path in dataset_paths.items()
}
for split, files in data_files.items():
if len(files) == 0:
raise FileNotFoundError(
f"No data files found for split {split!r} in dataset "
f"{dataset_name!r}. Please check that the provided dataset "
f"directory {dataset_paths[split]!r} contains arrow files of "
"the form 'data-*.arrow'."
)
dataset = load_dataset("arrow", data_files=data_files, streaming=True)

# Load dataset from the Hugging Face Hub. The HUGGINGFACE_HUB_TOKEN is only used
# during CI - normally it is expected that the user is logged in to the Hugging
# Face Hub using the `huggingface-cli login` command.
dataset = load_dataset(
path=dataset_cfg.id,
name=dataset_cfg.subset,
token=os.getenv("HUGGINGFACE_HUB_TOKEN", True),
streaming=True,
)
else:
dataset = load_dataset(
path=dataset_cfg.id,
name=dataset_cfg.subset,
token=os.getenv("HUGGINGFACE_HUB_TOKEN", True),
streaming=True,
)

assert isinstance(dataset, DatasetDict) or isinstance(
dataset, IterableDatasetDict
Expand Down Expand Up @@ -212,7 +241,6 @@ def clean_dataset(
"è": "e",
"kg": " kilo ",
"μg": " mikrogram ",
"μg": " mikrogram ",
"-": " minus ",
"+": " plus ",
"μ": " mikro ",
Expand Down
8 changes: 6 additions & 2 deletions src/coral_models/wav2vec2.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def load_model(self) -> Wav2Vec2ForCTC:
hidden_dropout=self.cfg.model.hidden_dropout,
feat_proj_dropout=self.cfg.model.feat_proj_dropout,
final_dropout=self.cfg.model.final_dropout,
apply_spec_augment=True,
mask_time_prob=self.cfg.model.mask_time_prob,
mask_time_length=self.cfg.model.mask_time_length,
mask_feature_prob=self.cfg.model.mask_feature_prob,
Expand All @@ -154,6 +155,7 @@ def load_model(self) -> Wav2Vec2ForCTC:
bos_token_id=self.processor.tokenizer.bos_token_id,
eos_token_id=self.processor.tokenizer.eos_token_id,
vocab_size=len(self.processor.tokenizer.get_vocab()),
ctc_zero_infinity=True,
)
assert isinstance(model, Wav2Vec2ForCTC)

Expand Down Expand Up @@ -204,11 +206,13 @@ def load_training_arguments(self) -> TrainingArguments:
seed=self.cfg.seed,
remove_unused_columns=False,
optim=OptimizerNames.ADAMW_TORCH,
use_mps_device=mps_is_available(),
adam_beta1=self.cfg.model.adam_first_momentum,
adam_beta2=self.cfg.model.adam_second_momentum,
report_to=["wandb"] if self.cfg.wandb else [],
ignore_data_skip=self.cfg.ignore_data_skip,
save_safetensors=True,
no_cuda=hasattr(sys, "_called_from_test"),
use_cpu=hasattr(sys, "_called_from_test"),
auto_find_batch_size=True,
)
return args

Expand Down
Loading

0 comments on commit b3dcf94

Please sign in to comment.