Merge pull request #54 from alexandrainst/chore/small-changes-to-wav2vec2-finetuning

Chore/small changes to wav2vec2 finetuning
saattrupdan authored Dec 14, 2023
2 parents bb0d073 + 6b80161 commit d9d09de
Showing 7 changed files with 71 additions and 43 deletions.
6 changes: 3 additions & 3 deletions config/config.yaml
@@ -16,7 +16,7 @@ dirs:
 seed: 4242
 
 # Dataset parameters
-characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü '
+characters_to_keep: 'abcdefghijklmnopqrstuvwxyzæøå0123456789éü'
 max_seconds_per_example: 10
 dataloader_num_workers: 8
@@ -48,12 +48,12 @@ ignore_data_skip: false
 save_total_limit: 2
 
 # Optimisation parameters
-learning_rate: 3e-5
+learning_rate: 1e-4
 adam_first_momentum: 0.9
 adam_second_momentum: 0.98
 total_batch_size: 256
 per_device_batch_size: 16
-max_steps: 50_000
+max_steps: 10_000
 warmup_steps: 1_000
 logging_steps: 10
 eval_steps: 100
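
The optimisation block above realises total_batch_size via gradient accumulation over per_device_batch_size batches. A minimal sketch of the arithmetic, assuming a single-GPU run and that the trainer derives the accumulation steps from these two values (num_devices is illustrative, not a config key):

total_batch_size = 256
per_device_batch_size = 16
num_devices = 1  # assumed single GPU; more GPUs shrink the accumulation steps
gradient_accumulation_steps = total_batch_size // (per_device_batch_size * num_devices)
assert gradient_accumulation_steps == 16  # 16 accumulated steps per optimiser update
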
19 changes: 10 additions & 9 deletions config/model/test_wav2vec2.yaml
@@ -9,16 +9,17 @@ clean_dataset: true
 # Model hyperparameters
 sampling_rate: 16_000
 activation_dropout: 0.1
-attention_dropout: 0.1
-hidden_dropout: 0.1
-feat_proj_dropout: 0.1
-final_dropout: 0.1
-mask_time_prob: 0.075
+attention_dropout: 0.0
+hidden_dropout: 0.0
+feat_proj_dropout: 0.0
+feat_quantizer_dropout: 0.0
+final_dropout: 0.0
+mask_time_prob: 0.5
 mask_time_length: 10
-mask_feature_prob: 0.075
-mask_feature_length: 10
-layerdrop: 0.0 # NOTE: This parameter cannot be used in a multi-gpu setting!
-ctc_loss_reduction: sum
+mask_feature_prob: 0.5
+mask_feature_length: 64
+layerdrop: 0.1 # NOTE: This will automatically be set to 0 in a multi-gpu setting
+ctc_loss_reduction: mean
 
 # Decoder hyperparameters
 language_model_decoder: null
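
The mask_time_* and mask_feature_* values above drive SpecAugment-style masking during fine-tuning. As a rough, overlap-ignoring approximation of how wav2vec2-style models turn these numbers into masked spans (illustrative frame counts, not repo code):

mask_time_prob = 0.5
mask_time_length = 10
sequence_length = 499  # assumed frame count for ~10 s of 16 kHz audio
num_spans = int(mask_time_prob * sequence_length / mask_time_length)  # 24 spans
max_masked_frames = num_spans * mask_time_length  # at most 240 frames, ~half the clip
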
4 changes: 2 additions & 2 deletions config/model/wav2vec2.yaml
@@ -18,8 +18,8 @@ mask_time_prob: 0.5
 mask_time_length: 10
 mask_feature_prob: 0.5
 mask_feature_length: 64
-layerdrop: 0.1 # This will automatically be set to 0 in a multi-gpu setting
-ctc_loss_reduction: mean
+layerdrop: 0.1 # NOTE: This will automatically be set to 0 in a multi-gpu setting
+ctc_loss_reduction: sum
 
 # Decoder hyperparameters
 language_model_decoder: ngram
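
ctc_loss_reduction is handed through to the CTC loss inside Wav2Vec2ForCTC. A minimal PyTorch sketch of what sum versus mean changes, using random illustrative tensors rather than repo code:

import torch
import torch.nn.functional as F

log_probs = torch.randn(50, 2, 32).log_softmax(dim=-1)  # (time, batch, vocab)
targets = torch.randint(1, 32, (2, 12))  # random label ids; blank id is 0
input_lengths = torch.full((2,), 50)
target_lengths = torch.full((2,), 12)

# "sum" adds the per-utterance losses, so the scale grows with batch size and
# target lengths; "mean" normalises, decoupling the scale from batch makeup.
loss_sum = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction="sum")
loss_mean = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, reduction="mean")
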
24 changes: 15 additions & 9 deletions src/coral_models/compute_metrics.py
@@ -40,19 +40,27 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
         if predictions.dtype == np.int_:
             vocab_size = tokenizer.get_vocab()
             mismatch_dim = len(vocab_size) - predictions.shape[-1]
-            predictions = np.pad(predictions, ((0, 0), (0, 0), (0, mismatch_dim)))
-            predictions_str = tokenizer.batch_decode(
-                predictions, skip_special_tokens=True
+            predictions = np.pad(
+                array=predictions,
+                pad_width=((0, 0), (0, 0), (0, mismatch_dim)),
+                mode="constant",
+                constant_values=pad_token,
             )
+            predictions_str = tokenizer.batch_decode(sequences=predictions)
 
         # Otherwise, if we are not using a language model, we need to convert the
         # logits to token IDs and then decode the token IDs to get the predicted string
         else:
+            # If all the logits are -100 for a token, then we set the logit for the
+            # padding token for that token to 0. This is to ensure that this token gets
+            # decoded to a padding token, and are therefore ignored
+            predictions[np.all(predictions == -100, axis=-1), pad_token] = 0
+
             pred_ids: NDArray[np.int_] = np.argmax(predictions, axis=-1)
-            predictions_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
+            predictions_str = tokenizer.batch_decode(pred_ids)
 
     elif len(predictions.shape) == 2 and predictions.dtype == np.int_:
-        predictions_str = tokenizer.batch_decode(predictions, skip_special_tokens=True)
+        predictions_str = tokenizer.batch_decode(sequences=predictions)
 
     else:
         raise ValueError(
@@ -67,11 +67,9 @@ def compute_wer_metrics(pred: EvalPrediction, processor: Processor) -> dict[str,
     labels[labels == -100] = pad_token
 
     # Decode the ground truth labels
-    labels_str = tokenizer.batch_decode(
-        sequences=labels, skip_special_tokens=True, group_tokens=False
-    )
+    labels_str = tokenizer.batch_decode(sequences=labels, group_tokens=False)
 
-    # TEMP: Log both the predictions and the ground truth labels
+    # Log both the predictions and the ground truth labels
     is_main_process = os.getenv("RANK", "0") == "0"
     if is_main_process:
         random_idx = np.random.randint(0, len(predictions_str))
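
Both decode paths above hinge on the -100 sentinel that the collator writes into padded label positions: -100 is not a valid token id, so it must be mapped back to the pad token (or, for logits, given a winning pad logit) before decoding. A tiny self-contained illustration with hypothetical ids:

import numpy as np

pad_token = 0  # hypothetical pad token id
labels = np.array([[5, 12, 7, -100, -100]])  # -100 marks padded positions
labels[labels == -100] = pad_token  # every id is now decodable
# batch_decode can then drop the pad tokens when assembling the string.
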
2 changes: 1 addition & 1 deletion src/coral_models/data.py
@@ -281,7 +281,7 @@ def clean_dataset(
     # transcriptions, as they do not have an influence on the pronunciation of the
     # words.
     non_standard_characters_regex = re.compile(
-        f"[^{re.escape(cfg.characters_to_keep)}]"
+        f"[^{re.escape(cfg.characters_to_keep + ' |')}]"
     )
 
     mapped = dataset.map(
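
Appending ' |' in code means characters_to_keep in the config no longer needs a trailing space (matching the config.yaml change above), while spaces and the '|' word delimiter still survive cleaning. A quick illustration of the compiled pattern on an example string:

import re

characters_to_keep = "abcdefghijklmnopqrstuvwxyzæøå0123456789éü"
regex = re.compile(f"[^{re.escape(characters_to_keep + ' |')}]")
assert regex.sub("", "hej, verden!") == "hej verden"  # punctuation is stripped
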
40 changes: 24 additions & 16 deletions src/coral_models/wav2vec2.py
@@ -43,6 +43,8 @@ class DataCollatorCTCWithPadding(DataCollatorMixin):
     Args:
         processor (Wav2Vec2Processor)
             The processor used for proccessing the data.
+        max_seconds_per_example (float):
+            The maximum number of seconds per example.
         padding (bool, str or PaddingStrategy, optional):
             Select a strategy to pad the returned sequences (according to the model's
             padding side and padding index) among:
@@ -60,6 +62,7 @@ class DataCollatorCTCWithPadding(DataCollatorMixin):
     """
 
     processor: Wav2Vec2Processor
+    max_seconds_per_example: float
     padding: bool | str
     return_tensors: str = "pt"
 
@@ -86,12 +89,12 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
             audio_features,
             padding=self.padding,
             return_tensors=self.return_tensors,
-            max_length=16_000 * 10,
+            max_length=16_000 * self.max_seconds_per_example,
         )
 
         label_features = [dict(input_ids=feature["labels"]) for feature in features]
-        labels_batch: BatchEncoding = self.processor.tokenizer.pad(
-            label_features,
+        labels_batch: BatchEncoding = self.processor.pad(
+            labels=label_features,
             padding=self.padding,
             return_tensors=self.return_tensors,
             max_length=512,
@@ -125,19 +128,21 @@ def load_processor(self) -> Wav2Vec2Processor:
                 dump_vocabulary(self.cfg)
                 tokenizer: Wav2Vec2CTCTokenizer = Wav2Vec2CTCTokenizer.from_pretrained(
                     self.cfg.model_dir,
-                    unk_token="<unk>",
                     pad_token="<pad>",
+                    unk_token="<unk>",
                     bos_token="<s>",
                     eos_token="</s>",
-                    word_delimiter_token=" ",
+                    word_delimiter_token="|",
+                    replace_word_delimiter_char=" ",
                 )
                 break
             except json.decoder.JSONDecodeError:
-                process_id = os.getenv("RANK", 0)
-                logger.warning(
-                    f"JSONDecodeError while loading tokenizer on process {process_id}. "
-                    "Retrying in a second."
-                )
+                log_message = "JSONDecodeError while loading tokenizer"
+                process_id = os.getenv("RANK")
+                if process_id is not None:
+                    log_message += f" in process {process_id}"
+                log_message += ". Retrying in a second."
+                logger.warning(log_message)
                 time.sleep(1)
 
         # Set the `model_max_length` attribute of the tokenizer, if it hasn't been set,
@@ -155,6 +160,7 @@ def load_processor(self) -> Wav2Vec2Processor:
         self.processor = Wav2Vec2Processor(
             feature_extractor=extractor, tokenizer=tokenizer
         )
+
         return self.processor
 
     def load_model(self) -> Wav2Vec2ForCTC:
@@ -179,7 +185,7 @@ def load_model(self) -> Wav2Vec2ForCTC:
             vocab_size=len(self.processor.tokenizer.get_vocab()),
             ctc_zero_infinity=True,
         )
-            assert isinstance(model, Wav2Vec2ForCTC)
+        assert isinstance(model, Wav2Vec2ForCTC)
 
         if self.cfg.model.freeze_feature_encoder:
             for param in model.wav2vec2.parameters():
@@ -189,7 +195,9 @@ def load_model(self) -> Wav2Vec2ForCTC:
 
     def load_data_collator(self) -> DataCollatorCTCWithPadding:
         return DataCollatorCTCWithPadding(
-            processor=self.processor, padding=self.cfg.padding
+            processor=self.processor,
+            max_seconds_per_example=self.cfg.max_seconds_per_example,
+            padding=self.cfg.padding,
         )
 
     def load_trainer_class(self) -> Type[Trainer]:
@@ -275,7 +283,9 @@ def load_saved(self) -> PreTrainedModelData:
 
         model = Wav2Vec2ForCTC.from_pretrained(self.cfg.hub_id, token=True)
         data_collator = DataCollatorCTCWithPadding(
-            processor=processor, padding=self.cfg.padding
+            processor=processor,
+            max_seconds_per_example=self.cfg.max_seconds_per_example,
+            padding=self.cfg.padding,
         )
         compute_metrics = partial(compute_wer_metrics, processor=processor)
         return PreTrainedModelData(
@@ -296,12 +306,10 @@ def dump_vocabulary(cfg: DictConfig) -> None:
             The Hydra configuration object.
     """
     # Build the set of all unique characters in the dataset
-    unique_characters: set[str] = set(cfg.characters_to_keep)
+    unique_characters: set[str] = set(cfg.characters_to_keep + "|")
 
     # Build vocabulary
     vocab = {char: idx for idx, char in enumerate(unique_characters)}
-    for tok in ["<unk>", "<pad>", "<s>", "</s>"]:
-        vocab[tok] = len(vocab)
 
     # Dump the vocabulary to a json file
     model_dir = Path(cfg.model_dir)
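
The tokenizer and dump_vocabulary changes work together: '|' joins the character vocabulary as the word delimiter, replace_word_delimiter_char=" " restores spaces on decoding, and the manual special-token loop is dropped, presumably because Wav2Vec2CTCTokenizer manages those tokens itself. A small sketch of the delimiter round-trip with a hypothetical vocabulary:

# CTC character vocabularies conventionally store "|" between words.
vocab = {"|": 0, "h": 1, "e": 2, "j": 3, "m": 4, "d": 5, "i": 6, "g": 7}
encoded = "hej|med|dig"  # what the tokenizer sees internally
decoded = encoded.replace("|", " ")  # replace_word_delimiter_char=" " on decode
assert decoded == "hej med dig"
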
19 changes: 16 additions & 3 deletions src/coral_models/whisper.py
@@ -36,6 +36,8 @@ class DataCollatorSpeechSeq2SeqWithPadding(DataCollatorMixin):
     Args:
         processor (WhisperProcessor)
             The processor used for proccessing the data.
+        max_seconds_per_example (float):
+            The maximum number of seconds per example.
         padding (bool, str or PaddingStrategy, optional):
             Select a strategy to pad the returned sequences (according to the model's
             padding side and padding index) among:
@@ -53,6 +55,7 @@ class DataCollatorSpeechSeq2SeqWithPadding(DataCollatorMixin):
     """
 
     processor: WhisperProcessor
+    max_seconds_per_example: float
     padding: bool | str = True
     return_tensors: str = "pt"
 
@@ -78,14 +81,22 @@ def torch_call(self, features: list[dict]) -> BatchFeature:
                 "Features must contain either 'input_features' or 'audio' key."
             )
         batch = self.processor.feature_extractor.pad(
-            audio_features, return_tensors="pt"
+            audio_features,
+            padding=self.padding,
+            return_tensors=self.return_tensors,
+            max_length=16_000 * self.max_seconds_per_example,
         )
 
         # Get the tokenized label sequences
        label_features = [{"input_ids": feature["labels"]} for feature in features]
 
         # Pad the labels to max length
-        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")
+        labels_batch = self.processor.tokenizer.pad(
+            label_features,
+            padding=self.padding,
+            return_tensors=self.return_tensors,
+            max_length=512,
+        )
 
         # replace padding with -100 to ignore loss correctly
         labels = labels_batch["input_ids"].masked_fill(
@@ -162,7 +173,9 @@ def load_model(self) -> WhisperForConditionalGeneration:
 
     def load_data_collator(self) -> DataCollatorSpeechSeq2SeqWithPadding:
         return DataCollatorSpeechSeq2SeqWithPadding(
-            processor=self.processor, padding=self.cfg.padding
+            processor=self.processor,
+            max_seconds_per_example=self.cfg.max_seconds_per_example,
+            padding=self.cfg.padding,
         )
 
     def load_trainer_class(self) -> Type[Trainer]:
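
The Whisper collator now mirrors the wav2vec2 one: max_length is expressed in samples (sampling rate times max_seconds_per_example), and padded label positions are flipped to -100 so the loss skips them. A compact PyTorch sketch of the label half, with a hypothetical pad id:

import torch

pad_token_id = 50257  # hypothetical pad id
input_ids = torch.tensor([[7, 9, pad_token_id, pad_token_id]])
attention_mask = torch.tensor([[1, 1, 0, 0]])

# Padding becomes -100, the ignore_index of the cross-entropy loss, so the
# padded positions contribute nothing to training.
labels = input_ids.masked_fill(attention_mask.ne(1), -100)
assert labels.tolist() == [[7, 9, -100, -100]]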
