diff --git a/stable_whisper.py b/stable_whisper.py
index c994866..5379ffc 100644
--- a/stable_whisper.py
+++ b/stable_whisper.py
@@ -4,13 +4,15 @@
 import numpy as np
 import torch
 from torch import Tensor
+from torch.nn import functional as F
+from torch.distributions import Categorical
 from typing import List, Optional, Tuple, Union
 from whisper.audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
 from whisper.decoding import DecodingOptions, DecodingResult
 from whisper.tokenizer import LANGUAGES
 from whisper.utils import exact_div, format_timestamp, compression_ratio
 from whisper.model import Whisper
-from whisper.decoding import DecodingTask
+from whisper.decoding import DecodingTask, BeamSearchDecoder, GreedyDecoder
 from whisper.tokenizer import Tokenizer, get_tokenizer
 from types import MethodType
 from itertools import chain, repeat
@@ -112,13 +114,18 @@ def secs_to_hhmmss(secs: (float, int)):
     return srt_str
 
 
-def group_word_timestamps(res: (dict, list), one_group=True, combine_compound=False, ts_key='whole_word_timestamps'):
+def group_word_timestamps(res: (dict, list), one_group=True, combine_compound=False,
+                          ts_key='whole_word_timestamps', min_dur: float = None):
+
+    if min_dur is None:
+        min_dur = 0.02
+
     def group_ts(ts_: List[dict], start) -> List[dict]:
         first_group: List[dict] = []
         for w_ts in ts_:
             if first_group:
                 if (not combine_compound or w_ts['word'].startswith(' ')) and \
-                        (w_ts['timestamp'] - first_group[-1]['start']) > 0.02 and \
+                        (w_ts['timestamp'] - first_group[-1]['start']) >= min_dur and \
                         first_group[-1]['end'] < w_ts['timestamp']:
                     first_group.append(dict(start=first_group[-1]['end'],
                                             end=w_ts['timestamp'],
@@ -224,7 +231,7 @@ def results_to_sentence_srt(res: dict, srt_path,
     to_srt(segs, srt_path, strip=strip)
 
 
-def results_to_word_srt(res: dict, srt_path, combine_compound=False, strip=False):
+def results_to_word_srt(res: dict, srt_path, combine_compound=False, strip=False, min_dur: float = None):
     """
 
     Parameters
@@ -237,13 +244,15 @@
         concatenate words without inbetween spacing
     strip: bool
         perform strip() on each word
+    min_dur: float
+        minimum duration for each word (i.e. words shorter than this value are merged with a neighbor; Default: 0.02)
     """
-    to_srt(group_word_timestamps(res, combine_compound=combine_compound),
+    to_srt(group_word_timestamps(res, combine_compound=combine_compound, min_dur=min_dur),
            srt_path,
           strip=strip)
 
 
-def results_to_token_srt(res: dict, srt_path, combine_compound=False, strip=False):
+def results_to_token_srt(res: dict, srt_path, combine_compound=False, strip=False, min_dur: float = None):
     """
 
     Parameters
@@ -256,9 +265,11 @@ def results_to_token_srt(res: dict, srt_path, combine_compound=False, strip=Fals
         concatenate words without inbetween spacing
     strip: bool
         perform strip() on each token
+    min_dur: float
+        minimum duration for each token (i.e. tokens shorter than this value are merged with a neighbor; Default: 0.02)
     """
-    to_srt(group_word_timestamps(res, combine_compound=combine_compound, ts_key='word_timestamps'),
+    to_srt(group_word_timestamps(res, combine_compound=combine_compound, ts_key='word_timestamps', min_dur=min_dur),
            srt_path,
           strip=strip)
 
 
@@ -678,8 +689,8 @@ def _remove_lower_quantile(waveform: np.ndarray,
         lower_threshold = 0.15
     waveform = deepcopy(waveform)
     wave_sums = waveform.sum(0)
-    mx = np.quantile(wave_sums, upper_quantile, 0)
-    mn = np.quantile(wave_sums, lower_quantile, 0)
+    mx = np.quantile(wave_sums, upper_quantile, -1)
+    mn = np.quantile(wave_sums, lower_quantile, -1)
     mn_threshold = (mx - mn) * lower_threshold + mn
     waveform[:, wave_sums < mn_threshold] = 0
     return waveform
@@ -724,6 +735,7 @@ def transcribe_word_level(
         suppress_middle: bool = True,
         suppress_word_ts: bool = True,
         remove_background: bool = True,
+        silence_threshold: float = 0.1,
         prepend_punctuations: Union[List[str], Tuple[str]] = None,
         append_punctuations: Union[List[str], Tuple[str]] = None,
         audio_for_mask: (str, bytes) = None,
@@ -797,6 +809,11 @@ def transcribe_word_level(
     lower_threshold: float
         Suppressed sections of waveform where amplitude < lower_threshold*(mx-mn) + mn. (Default: 0.15)
 
+    silence_threshold: float
+        If the average silence of an audio segment is >= silence_threshold, background will not be
+        removed for that segment even if remove_background=True.
+        (e.g. 0.5 means background is removed only when less than half of the segment is silent; Default: 0.1)
+
     prepend_punctuations: Union[List[str], Tuple[str]]
         Punctuations to prepend to next word (Default: “¿([{)
 
@@ -983,7 +1000,8 @@ def add_segment(
         if suppress_silence:
             wf_seek = int(seek * ts_scale)
             segment_wf = wf[..., wf_seek:wf_seek + 1501]
-            if remove_background:
+            if remove_background and \
+                    (1 - segment_wf.sum(0).clip(max=1).mean()) < silence_threshold:
                 segment_wf = _remove_lower_quantile(segment_wf.astype(np.float32),
                                                     upper_quantile=upper_quantile,
                                                     lower_quantile=lower_quantile,
@@ -1034,7 +1052,7 @@ def add_segment(
                         sliced_tokens[-1].item() - tokenizer.timestamp_begin
                     )
 
-                    word_ts = timestamp_offset + (sliced_ts_tokens - tokenizer.timestamp_begin) * time_precision
+                    word_ts = timestamp_offset + sliced_ts_tokens * time_precision
 
                     add_segment(
                         offset=timestamp_offset,
@@ -1065,7 +1083,7 @@ def add_segment(
                     last_timestamp_position = min(timestamps[-1].item() - tokenizer.timestamp_begin, segment_max_ts)
                     duration = last_timestamp_position * time_precision
 
-                word_ts = timestamp_offset + (finalized_ts_tokens - tokenizer.timestamp_begin) * time_precision
+                word_ts = timestamp_offset + finalized_ts_tokens * time_precision
 
                 add_segment(
                     offset=timestamp_offset,
@@ -1107,28 +1125,206 @@ def add_segment(
     return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
 
 
+def _suppress_ts(ts_logits: Tensor, suppress_ts_mask: Tensor = None):
+    if suppress_ts_mask is not None:
+        ts_logits[:, suppress_ts_mask] = -np.inf
+
+
+def _ts_topk(ts_logits: Tensor, k: int, prev_ts: Tensor = None) -> Tensor:
+    temp_ts = torch.stack(torch.topk(ts_logits, k, dim=-1), 0).unsqueeze(-2)
+    return temp_ts if prev_ts is None else torch.cat([prev_ts, temp_ts], dim=-2)
+
+
+# modified version of whisper.GreedyDecoder
+class GreedyDecoderWordLevel(GreedyDecoder):
+    def __init__(self, *args, **kwargs):
+        self.ts_num: int = kwargs.pop('ts_num', 10)
+        self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
+        self.timestamp_begin = kwargs.pop('timestamp_begin', 50364)
+        super(GreedyDecoderWordLevel, self).__init__(*args, **kwargs)
+        self.ts = None
+
+    def _suppress_ts(self, logits: Tensor):
+        _suppress_ts(logits[:, self.timestamp_begin:],
+                     suppress_ts_mask=self.suppress_ts_mask)
+
+    def update_with_ts(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, ts: Tensor) -> Tuple[Tensor, bool]:
+        self.ts = ts
+
+        self._suppress_ts(logits)
+
+        if self.temperature == 0:
+            next_tokens = logits.argmax(dim=-1)
+        else:
+            next_tokens = Categorical(logits=logits / self.temperature).sample()
+
+        logprobs = F.log_softmax(logits.float(), dim=-1)
+        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
+        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
+
+        next_tokens[tokens[:, -1] == self.eot] = self.eot
+        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
+
+        completed = (tokens[:, -1] == self.eot).all()
+        return tokens, completed
+
+    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
+        # make sure each sequence has at least one EOT token at the end
+        tokens = F.pad(tokens, (0, 1), value=self.eot)
+        return tokens, sum_logprobs.tolist(), self.ts.transpose(1, 0)[None]
+
+
+# modified version of whisper.BeamSearchDecoder
+class BeamSearchDecoderWordLevel(BeamSearchDecoder):
+
+    def __init__(self, *args, **kwargs):
+        self.ts_num: int = kwargs.pop('ts_num', 10)
+        self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
+        self.timestamp_begin = kwargs.pop('timestamp_begin', 50364)
+        super(BeamSearchDecoderWordLevel, self).__init__(*args, **kwargs)
+        self.ts = None
+        self.finished_ts_ls = None
+
+    def reset(self):
+        self.finished_sequences = None
+        self.finished_ts_ls = None
+
+    def _suppress_ts(self, logits: Tensor):
+        _suppress_ts(logits[:, self.timestamp_begin:],
+                     suppress_ts_mask=self.suppress_ts_mask)
+
+    def update_with_ts(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, ts: Tensor) -> Tuple[Tensor, bool]:
+        if tokens.shape[0] % self.beam_size != 0:
+            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
+
+        self.ts = ts
+
+        n_audio = tokens.shape[0] // self.beam_size
+        if self.finished_sequences is None:  # for the first update
+            self.finished_sequences = [{} for _ in range(n_audio)]
+            self.finished_ts_ls = [{} for _ in range(n_audio)]
+
+        logprobs = F.log_softmax(logits.float(), dim=-1)
+        next_tokens, source_indices, finished_sequences, finished_ts_ls = [], [], [], []
+
+        self._suppress_ts(logprobs)
+
+        for i in range(n_audio):
+            scores, sources, finished, finished_ts = {}, {}, {}, {}
+
+            # STEP 1: calculate the cumulative log probabilities for possible candidates
+            for j in range(self.beam_size):
+                idx = i * self.beam_size + j
+                prefix = tokens[idx].tolist()
+                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
+                    new_logprob = (sum_logprobs[idx] + logprob).item()
+                    sequence = tuple(prefix + [token.item()])
+                    scores[sequence] = new_logprob
+                    sources[sequence] = idx
+
+            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
+            saved = 0
+            for sequence in sorted(scores, key=scores.get, reverse=True):
+                if sequence[-1] == self.eot:
+                    finished[sequence] = scores[sequence]
+                    finished_ts[sequence] = self.ts[:, sources[sequence]]
+                else:
+                    sum_logprobs[len(next_tokens)] = scores[sequence]
+                    next_tokens.append(sequence)
+                    source_indices.append(sources[sequence])
+
+                    saved += 1
+                    if saved == self.beam_size:
+                        break
+
+            finished_sequences.append(finished)
+            finished_ts_ls.append(finished_ts)
+
+        tokens = torch.tensor(next_tokens, device=tokens.device)
+        self.inference.rearrange_kv_cache(source_indices)
+        self.ts = self.ts[:, source_indices]
+
+        # add newly finished sequences to self.finished_sequences
+        assert len(self.finished_sequences) == len(finished_sequences)
+        for previously_finished, newly_finished, \
+                prev_ts_ls, new_ts_ls in \
+                zip(self.finished_sequences, finished_sequences,
+                    self.finished_ts_ls, finished_ts_ls):
+            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
+                if len(previously_finished) >= self.max_candidates:
+                    break  # the candidate list is full
+                previously_finished[seq] = newly_finished[seq]
+                prev_ts_ls[seq] = new_ts_ls[seq]
+
+        # mark as completed if all audio has enough number of samples
+        completed = all(
+            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
+        )
+        return tokens, completed
+
+    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
+        # collect all finished sequences, including patience, and add unfinished ones if not enough
+        sum_logprobs = sum_logprobs.cpu()
+        for i, (sequences, ts_) in \
+                enumerate(zip(self.finished_sequences, self.finished_ts_ls)):
+            if len(sequences) < self.beam_size:  # when not enough sequences are finished
+                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
+                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
+                    seq_tuple = tuple(sequence)
+                    sequences[seq_tuple] = sum_logprobs[i][j].item()
+                    # ts_ is the per-audio dict; index self.ts by the flattened beam position
+                    ts_[seq_tuple] = self.ts[:, i * self.beam_size + j]
+                    if len(sequences) >= self.beam_size:
+                        break
+
+        tokens: List[List[Tensor]] = [
+            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
+        ]
+        sum_logprobs: List[List[float]] = [
+            list(sequences.values()) for sequences in self.finished_sequences
+        ]
+        final_ts: List[List[Tensor]] = [
+            list(sequences.values()) for sequences in self.finished_ts_ls
+        ]
+        return tokens, sum_logprobs, final_ts
+
+
 class DecodingTaskWordLevel(DecodingTask):
 
     def __init__(self, *args, **kwargs):
+        self.ts_num: int = kwargs.pop('ts_num', 10)
+        self.alpha: float = kwargs.pop('alpha', None)  # experimental
+        self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
+        self.suppress_word_ts: bool = kwargs.pop('suppress_word_ts', True)
         super(DecodingTaskWordLevel, self).__init__(*args, **kwargs)
+        if hasattr(self.decoder, 'beam_size'):
+            self.decoder = BeamSearchDecoderWordLevel(self.decoder.beam_size,
+                                                      self.decoder.eot,
+                                                      self.inference,
+                                                      self.decoder.patience,
+                                                      ts_num=self.ts_num,
+                                                      suppress_ts_mask=self.suppress_ts_mask,
+                                                      timestamp_begin=self.tokenizer.timestamp_begin)
+        else:
+            self.decoder = GreedyDecoderWordLevel(self.decoder.temperature,
+                                                  self.decoder.eot,
+                                                  ts_num=self.ts_num,
+                                                  suppress_ts_mask=self.suppress_ts_mask,
+                                                  timestamp_begin=self.tokenizer.timestamp_begin)
 
     # modified version of whisper.DecodingTask._main_loop
-    def _main_loop(self, audio_features: Tensor, tokens: Tensor, ts_num: int = None, alpha: float = None,
-                   suppress_ts_mask: Tensor = None, suppress_word_ts: bool = False):
+    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
         assert audio_features.shape[0] == tokens.shape[0]
         n_batch = tokens.shape[0]
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
         no_speech_probs = [np.nan] * n_batch
-        ts_num = 5 if ts_num is None else max(ts_num, 1)
-        initial_tk_len = tokens.shape[-1]
-        ts_tokens = torch.zeros([*tokens.shape[:-1], 1], device=tokens.device, dtype=tokens.dtype)
-        ts_logits = torch.zeros_like(ts_tokens)
+        ts = None
+
         try:
             for i in range(self.sample_len):
-                if alpha:
+                if self.alpha:
                     logits = self.inference.logits(tokens,
-                                                   audio_features * (torch.rand_like(audio_features) * alpha + 1))
+                                                   audio_features * (torch.rand_like(audio_features) * self.alpha + 1))
                 else:
                     logits = self.inference.logits(tokens, audio_features)
 
@@ -1139,52 +1335,33 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor, ts_num: int = None,
                 # now we need to consider the logits at the last token only
                 logits = logits[:, -1]
 
-                logits_clone = torch.clone(logits)
-                if suppress_word_ts and suppress_ts_mask is not None:
-                    logits_clone[:, self.tokenizer.timestamp_begin:][:, suppress_ts_mask] = -np.inf
-                logits_clone[:, : self.tokenizer.timestamp_begin] = -np.inf
-                temp_ts_logits, temp_ts_token = torch.topk(logits_clone, ts_num)
-                ts_tokens = torch.cat([ts_tokens, temp_ts_token], -1)
-                ts_logits = torch.cat([ts_logits, temp_ts_logits], -1)
-
-                del logits_clone
-
-                # if suppress_ts_mask is not None:
-                #     logits[:, self.tokenizer.timestamp_begin:][suppress_ts_mask] = -np.inf
+                ts_logits = torch.clone(logits[:, self.tokenizer.timestamp_begin:])
+                if self.suppress_word_ts:
+                    _suppress_ts(ts_logits, self.suppress_ts_mask)
+                ts = _ts_topk(ts_logits, k=self.ts_num, prev_ts=ts)
 
                 # apply the logit filters, e.g. for suppressing or applying penalty to
                 for logit_filter in self.logit_filters:
                     logit_filter.apply(logits, tokens)
 
-                if suppress_ts_mask is not None:
-                    logits[:, self.tokenizer.timestamp_begin:][:, suppress_ts_mask] = -np.inf
-
                 # expand the tokens tensor with the selected next tokens
-                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
+                tokens, completed = self.decoder.update_with_ts(tokens, logits, sum_logprobs, ts)
 
                 if completed or tokens.shape[-1] > self.n_ctx:
                     break
         finally:
             self.inference.cleanup_caching()
 
-        new_ts_token_count = tokens.shape[-1] - initial_tk_len
-        ts_tokens = ts_tokens[..., 1:].reshape(
-            [*tokens.shape[:-1], new_ts_token_count, ts_num])
-        ts_logits = ts_logits[..., 1:].reshape(
-            [*tokens.shape[:-1], new_ts_token_count, ts_num])
-        return tokens, sum_logprobs, no_speech_probs, ts_tokens, ts_logits
+        return tokens, sum_logprobs, no_speech_probs
 
     # modified version of whisper.DecodingTask.run
     @torch.no_grad()
-    def run(self, mel: Tensor, ts_num: int = None, alpha: float = None, suppress_ts_mask: Tensor = None,
-            suppress_word_ts=False) \
-            -> Union[List[DecodingResult], Tuple[List[DecodingResult], List[List[int]], List[List[int]]]]:
+    def run(self, mel: Tensor) \
+            -> Union[List[DecodingResult], Tuple[List[DecodingResult], List[List[int]]]]:
         self.decoder.reset()
         tokenizer: Tokenizer = self.tokenizer
         n_audio: int = mel.shape[0]
 
-        ts_num = 10 if ts_num is None else max(ts_num, 1)
-
         audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
         tokens: Tensor = torch.tensor([self.initial_tokens]).expand(n_audio, -1)
@@ -1201,10 +1378,7 @@
         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
 
         # call the main sampling loop
-        tokens, sum_logprobs, no_speech_probs, ts_tokens, ts_logits = self._main_loop(audio_features, tokens,
-                                                                                      ts_num=ts_num, alpha=alpha,
-                                                                                      suppress_ts_mask=suppress_ts_mask,
-                                                                                      suppress_word_ts=suppress_word_ts)
+        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
 
         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
         audio_features = audio_features[:: self.n_group]
@@ -1212,25 +1386,19 @@ def run(self, mel: Tensor, ts_num: int = None, alpha: float = None, suppress_ts_
         assert audio_features.shape[0] == len(no_speech_probs) == n_audio
 
         tokens = tokens.reshape(n_audio, self.n_group, -1)
-        ts_tokens = ts_tokens.reshape(n_audio, self.n_group, -1, ts_num)
-        ts_logits = ts_logits.reshape(n_audio, self.n_group, -1, ts_num)
         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
 
         # get the final candidates for each group, and slice between the first sampled token and EOT
-        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
+        tokens, sum_logprobs, ts = self.decoder.finalize(tokens, sum_logprobs)
         tokens: List[List[Tensor]] = [
             [t[self.sample_begin: (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
         ]
-        ts_tokens: List[List[Tensor]] = [[t[:len(tokens[i][j])] for j, t in enumerate(s)] for i, s in
-                                         enumerate(ts_tokens)]
-        ts_logits: List[List[Tensor]] = [[t[:len(tokens[i][j])] for j, t in enumerate(s)] for i, s in
-                                         enumerate(ts_logits)]
+        ts: List[List[Tensor]] = [[t[:, :tokens[i][j].shape[-1]] for j, t in enumerate(s)] for i, s in enumerate(ts)]
 
         # select the top-ranked sample in each group
         selected = self.sequence_ranker.rank(tokens, sum_logprobs)
         tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
-        ts_tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, ts_tokens)]
-        ts_logits: List[List[int]] = [t[i].tolist() for i, t in zip(selected, ts_logits)]
+        ts: List[List[int]] = [t[i].tolist() for i, t in zip(selected, ts)]
 
         texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
@@ -1253,7 +1421,7 @@ def run(self, mel: Tensor, ts_num: int = None, alpha: float = None, suppress_ts_
                 compression_ratio=compression_ratio(text),
             )
             for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
-        ], ts_tokens, ts_logits
+        ], ts
 
 
 # modified version of whisper.decoding.decode
@@ -1298,15 +1466,21 @@ def decode_word_level(model: "Whisper", mel: Tensor, options: DecodingOptions =
     if single:
         mel = mel.unsqueeze(0)
 
-    result, ts_tokens, ts_logits = DecodingTaskWordLevel(model, options).run(mel, ts_num=ts_num,
-                                                                             alpha=alpha,
-                                                                             suppress_ts_mask=suppress_ts_mask,
-                                                                             suppress_word_ts=suppress_word_ts)
+    result, ts = DecodingTaskWordLevel(model, options,
+                                       ts_num=ts_num,
+                                       alpha=alpha,
+                                       suppress_ts_mask=suppress_ts_mask,
+                                       suppress_word_ts=suppress_word_ts).run(mel)
 
     if single:
         result = result[0]
+        ts_token = ts[0][1]
+        ts_logits = ts[0][0]
+    else:
+        ts_token = [ts_[1] for ts_ in ts]
+        ts_logits = [ts_[0] for ts_ in ts]
 
-    return result, ts_tokens, ts_logits
+    return result, ts_token, ts_logits
 
 
 def modify_model(model: whisper.model.Whisper):
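
# ---------------------------------------------------------------------------
# Usage sketch (not part of the patch): how the options added above surface at
# the top level. This assumes modify_model() patches transcribe_word_level and
# decode_word_level onto the model instance, as stable_whisper.py does
# elsewhere; the file names below are illustrative only.
import whisper
from stable_whisper import modify_model, results_to_word_srt

model = whisper.load_model('base')
modify_model(model)

# silence_threshold=0.5: a segment keeps its background unless less than half
# of it is silent (remove_background is on by default)
result = model.transcribe('audio.mp3', silence_threshold=0.5)

# min_dur=0.05: words shorter than 50 ms are merged with a neighboring word
results_to_word_srt(result, 'audio.srt', min_dur=0.05)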