Fixed beam search ts update order
- fixed beam search timestamp tokens and logits not having their order updated along with the reordering of beams (this should slightly improve the accuracy of beam search word-level timestamps; see the sketch below)
- fixed incorrect indexing in BeamSearchDecoderWordLevel.finalize when not enough sequences are finished
- corrected typos and incorrect naming
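For context, a minimal sketch of the first bug (illustrative names and shapes only, not code from this commit): whisper's beam search reorders the surviving candidates every step, so any per-beam tensor kept alongside the token sequences must be permuted with the same indices, or each beam's tokens end up paired with another beam's timestamp history.

```python
import torch

# Illustrative only -- tensor names and shapes are assumptions, not this repo's code.
beam_size, n_steps, k = 3, 4, 10
ts = torch.randn(2, beam_size, n_steps, k)   # per-beam history: dim 0 = (logits, tokens)
source_indices = torch.tensor([2, 0, 0])     # old beam each surviving beam came from

ts = ts[:, source_indices]  # the fix in spirit: permute the history with the beams
# Skipping this permutation (the old behavior) silently pairs each beam's
# tokens with a stale timestamp history.
```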
jianfch authored Nov 6, 2022
1 parent 280999c commit 4cbb0b8
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions stable_whisper.py
@@ -752,7 +752,7 @@ def transcribe_word_level(
         The path to the audio file to open, or the audio waveform
     verbose: bool
-        Whether to display the text (with finalized timestamps) being decoded to the console
+        Whether to display the decoded text (with finalized timestamps) to the console
     temperature: Union[float, Tuple[float, ...]]
         Temperature for sampling. It can be a tuple of temperatures, which will be successively used
@@ -789,6 +789,7 @@ def transcribe_word_level(
     print_unstab: bool
         Whether to display the text (without stabilized timestamps) being decoded to the console
+        (i.e. behaves like verbose before the model was modified)
     suppress_silence: bool
         Suppress timestamp tokens that are marked as silent
@@ -1135,7 +1136,7 @@ def _ts_topk(ts_logits: Tensor, k: int, prev_ts: Tensor = None) -> Tensor:
     return temp_ts if prev_ts is None else torch.cat([prev_ts, temp_ts], dim=-2)


-# modified version of whisper.GreedyDecoderWordLevel
+# modified version of whisper.GreedyDecoder
 class GreedyDecoderWordLevel(GreedyDecoder):
     def __init__(self, *args, **kwargs):
         self.ts_num: int = kwargs.pop('ts_num', 10)
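The hunk above shows only `_ts_topk`'s return line, but that line implies the function stacks the top-k timestamp logits with their token indices and appends one slice per decoding step along `dim=-2`. A rough sketch consistent with the visible line (the `temp_ts` construction is inferred, not shown in this diff):

```python
import torch
from torch import Tensor

def _ts_topk_sketch(ts_logits: Tensor, k: int, prev_ts: Tensor = None) -> Tensor:
    # Take the k most likely timestamp tokens at this step.
    values, indices = torch.topk(ts_logits, k, dim=-1)
    # Stack logits and token indices along a new leading dim, then add a step axis.
    temp_ts = torch.stack([values, indices.to(values.dtype)]).unsqueeze(-2)
    # Append this step's slice to the running history along the step axis,
    # yielding shape (2, n_batch, n_steps, k).
    return temp_ts if prev_ts is None else torch.cat([prev_ts, temp_ts], dim=-2)
```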
@@ -1264,6 +1265,7 @@ def update_with_ts(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, t

     def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
         # collect all finished sequences, including patience, and add unfinished ones if not enough
+        self.ts = self.ts.reshape(self.ts.shape[0], *preceding_tokens.shape[:2], *self.ts.shape[2:])
         sum_logprobs = sum_logprobs.cpu()
         for i, (sequences, ts_) in \
                 enumerate(zip(self.finished_sequences, self.finished_ts_ls)):
@@ -1272,7 +1274,7 @@ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
                 sequence = preceding_tokens[i, j].tolist() + [self.eot]
                 seq_tuple = tuple(sequence)
                 sequences[seq_tuple] = sum_logprobs[i][j].item()
-                ts_[i][seq_tuple] = self.ts[i, j]
+                ts_[seq_tuple] = self.ts[:, i, j]
                 if len(sequences) >= self.beam_size:
                     break

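A shape walk-through of the two `finalize` fixes above, with dimensions assumed from the surrounding code: the new `reshape` splits the flat batch axis of `self.ts` into `(n_audio, beam_size)`, after which `self.ts[:, i, j]` selects both the logits and the token indices for audio item `i`, beam `j`. The old `self.ts[i, j]` indexed the logits/tokens axis instead, and `ts_[i][seq_tuple]` keyed into the wrong container.

```python
import torch

# Dimensions are assumptions for illustration.
n_audio, beam_size, n_steps, k = 2, 5, 6, 10
ts = torch.randn(2, n_audio * beam_size, n_steps, k)  # dim 0 = (logits, tokens)

# The added line, in effect: split the flat batch axis into (n_audio, beam_size).
ts = ts.reshape(ts.shape[0], n_audio, beam_size, *ts.shape[2:])

i, j = 1, 3
per_beam = ts[:, i, j]  # fixed indexing: both logits and tokens for beam j of audio i
assert per_beam.shape == (2, n_steps, k)
# The old ts[i, j] would have meant "logits-or-tokens i, audio item j" --
# valid only by accident while i < 2, and the wrong data either way.
```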
@@ -1318,7 +1320,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
         no_speech_probs = [np.nan] * n_batch

-        ts = None
+        # ts = None

         try:
             for i in range(self.sample_len):
@@ -1338,7 +1340,7 @@ def _main_loop(self, audio_features: Tensor, tokens: Tensor):
                     ts_logits = torch.clone(logits[:, self.tokenizer.timestamp_begin:])
                     if self.suppress_word_ts:
                         _suppress_ts(ts_logits, self.suppress_ts_mask)
-                    ts = _ts_topk(ts_logits, k=self.ts_num, prev_ts=ts)
+                    ts = _ts_topk(ts_logits, k=self.ts_num, prev_ts=self.decoder.ts)

                 # apply the logit filters, e.g. for suppressing or applying penalty to
                 for logit_filter in self.logit_filters:
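The change above feeds `prev_ts=self.decoder.ts` instead of a loop-local `ts`, which is the heart of the first fix: the running history now lives on the decoder, where it can be permuted whenever the beams are. A hypothetical sketch of that interaction (method and attribute names beyond `self.decoder.ts` are assumptions, not code from this commit):

```python
# Hypothetical sketch, not code from this commit.
class BeamSearchDecoderWordLevelSketch:
    def __init__(self):
        self.ts = None  # running (2, n_batch, n_steps, k) top-k timestamp history

    def reorder(self, source_indices):
        # Called wherever the beams themselves are reordered, so the timestamp
        # history stays aligned with the sequences it belongs to.
        if self.ts is not None:
            self.ts = self.ts[:, source_indices]
```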
@@ -1474,13 +1476,13 @@ def decode_word_level(model: "Whisper", mel: Tensor, options: DecodingOptions =

     if single:
         result = result[0]
-        ts_token = ts[0][1]
+        ts_tokens = ts[0][1]
         ts_logits = ts[0][0]
     else:
-        ts_token = [ts_[1] for ts_ in ts]
+        ts_tokens = [ts_[1] for ts_ in ts]
         ts_logits = [ts_[0] for ts_ in ts]

-    return result, ts_token, ts_logits
+    return result, ts_tokens, ts_logits


 def modify_model(model: whisper.model.Whisper):
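With the rename, the returned names match their contents: `ts_tokens` holds the top-k timestamp token candidates per decoded step and `ts_logits` their logits. A rough usage sketch, assuming a patched model; the names come from this file and from whisper, but the placeholder mel and exact call pattern are illustrative, not verified against this revision:

```python
import torch
import whisper
from stable_whisper import modify_model, decode_word_level

model = whisper.load_model("base")
modify_model(model)  # patch the model for word-level timestamp decoding

mel = torch.zeros(80, 3000)  # placeholder log-Mel spectrogram (one 30 s window)
result, ts_tokens, ts_logits = decode_word_level(model, mel, whisper.DecodingOptions())
```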