From 4cbb0b8c7e5ba25ab5171ace5401a06e2216cf19 Mon Sep 17 00:00:00 2001 From: jian Date: Sun, 6 Nov 2022 12:52:04 -0500 Subject: [PATCH] Fixed beam search ts update order -fixed beam search timestamp tokens and logits not updating order properly with the reordering of beams (i.e. should increase accuracy of beam search word-level timestamps slightly) -fixed incorrect indexing when not enough sequences are finished for BeamSearchDecoderWordLevel.finalize -corrected typos/incorrect naming --- stable_whisper.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/stable_whisper.py b/stable_whisper.py index 5379ffc..b819cd7 100644 --- a/stable_whisper.py +++ b/stable_whisper.py @@ -752,7 +752,7 @@ def transcribe_word_level( The path to the audio file to open, or the audio waveform verbose: bool - Whether to display the text (with finalized timestamps) being decoded to the console + Whether to display the decoded text (with finalized timestamps) to the console temperature: Union[float, Tuple[float, ...]] Temperature for sampling. It can be a tuple of temperatures, which will be successfully used @@ -789,6 +789,7 @@ def transcribe_word_level( print_unstab: bool Whether to display the text (without stabilize timestamps) being decoded to the console + (i.e. behaves like verbose before model was modified) suppress_silence: bool Suppress timestamp tokens that are marked as silent @@ -1135,7 +1136,7 @@ def _ts_topk(ts_logits: Tensor, k: int, prev_ts: Tensor = None) -> Tensor: return temp_ts if prev_ts is None else torch.cat([prev_ts, temp_ts], dim=-2) -# modified version of whisper.GreedyDecoderWordLevel +# modified version of whisper.GreedyDecoder class GreedyDecoderWordLevel(GreedyDecoder): def __init__(self, *args, **kwargs): self.ts_num: int = kwargs.pop('ts_num', 10) @@ -1264,6 +1265,7 @@ def update_with_ts(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, t def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): # collect all finished sequences, including patience, and add unfinished ones if not enough + self.ts = self.ts.reshape(self.ts.shape[0], *preceding_tokens.shape[:2], *self.ts.shape[2:]) sum_logprobs = sum_logprobs.cpu() for i, (sequences, ts_) in \ enumerate(zip(self.finished_sequences, self.finished_ts_ls)): @@ -1272,7 +1274,7 @@ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor): sequence = preceding_tokens[i, j].tolist() + [self.eot] seq_tuple = tuple(sequence) sequences[seq_tuple] = sum_logprobs[i][j].item() - ts_[i][seq_tuple] = self.ts[i, j] + ts_[seq_tuple] = self.ts[:, i, j] if len(sequences) >= self.beam_size: break @@ -1318,7 +1320,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor): sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device) no_speech_probs = [np.nan] * n_batch - ts = None + # ts = None try: for i in range(self.sample_len): @@ -1338,7 +1340,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor): ts_logits = torch.clone(logits[:, self.tokenizer.timestamp_begin:]) if self.suppress_word_ts: _suppress_ts(ts_logits, self.suppress_ts_mask) - ts = _ts_topk(ts_logits, k=self.ts_num, prev_ts=ts) + ts = _ts_topk(ts_logits, k=self.ts_num, prev_ts=self.decoder.ts) # apply the logit filters, e.g. for suppressing or applying penalty to for logit_filter in self.logit_filters: @@ -1474,13 +1476,13 @@ def decode_word_level(model: "Whisper", mel: Tensor, options: DecodingOptions = if single: result = result[0] - ts_token = ts[0][1] + ts_tokens = ts[0][1] ts_logits = ts[0][0] else: - ts_token = [ts_[1] for ts_ in ts] + ts_tokens = [ts_[1] for ts_ in ts] ts_logits = [ts_[0] for ts_ in ts] - return result, ts_token, ts_logits + return result, ts_tokens, ts_logits def modify_model(model: whisper.model.Whisper):