
[Whisper] Add word timestamps and confidence scores #201

Merged: 8 commits into ml-explore:main on Jan 7, 2024

Conversation

@bofenghuang (Contributor) commented on Dec 28, 2023

Hi @awni 👋

In this PR I've added word-level timestamps and confidence scores to the Whisper implementation, following the implementation in the original repository.

This is still a draft version that may require some optimizations:

  • Move certain numpy operations to MLX. I ran into issues running some operations in MLX, so I left them in np for now. However, you may have better solutions :)
  • More efficient implementations of median_filter and dtw. I used scipy's median_filter directly, since I didn't find an unfold function in MLX. As for dtw, I kept the original numba version (the overall alignment flow is sketched after this list).
  • Better handling of the qk cross-attention scores in the model forward pass.
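
For readers unfamiliar with the approach, here is a minimal numpy/scipy sketch of the alignment flow these items refer to, roughly as done in openai-whisper: average and median-filter the cross-attention weights (qk) from the alignment heads, then run DTW over them to get a monotonic token-to-frame path. The array names, shapes, and the pure-Python DTW loop below are illustrative assumptions, not the PR's actual numba implementation.

```python
import numpy as np
from scipy.ndimage import median_filter


def dtw_path(cost):
    """DTW over a (num_tokens, num_frames) cost matrix; returns the minimal-cost
    monotonic path as (token_indices, frame_indices)."""
    N, M = cost.shape
    D = np.full((N + 1, M + 1), np.inf)
    D[0, 0] = 0.0
    trace = np.zeros((N + 1, M + 1), dtype=np.int8)
    for i in range(1, N + 1):
        for j in range(1, M + 1):
            choices = (D[i - 1, j - 1], D[i - 1, j], D[i, j - 1])
            k = int(np.argmin(choices))
            D[i, j] = cost[i - 1, j - 1] + choices[k]
            trace[i, j] = k
    i, j, path = N, M, []
    while i > 0 and j > 0:  # backtrack from the bottom-right corner
        path.append((i - 1, j - 1))
        if trace[i, j] == 0:
            i, j = i - 1, j - 1
        elif trace[i, j] == 1:
            i -= 1
        else:
            j -= 1
    tokens, frames = zip(*reversed(path))
    return np.array(tokens), np.array(frames)


def align(qk, medfilt_width=7):
    """qk: (num_alignment_heads, num_tokens, num_frames) cross-attention weights."""
    weights = qk.mean(axis=0)                                   # average the alignment heads
    weights = median_filter(weights, size=(1, medfilt_width))   # smooth along the frame axis
    return dtw_path(-weights)                                   # higher attention = lower cost
```

The aligned frame indices then map to times via Whisper's fixed 20 ms per encoder frame, which is how token and word timestamps are derived.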

Below are the benchmark times (in seconds) from tests run on my M1 Pro.

Feature time 0.038

Model: TINY
Model forward time 0.038
Decode time 0.211
Everything time 0.266
Everything (w/ word_timestamps) time 0.320

--------------------------------------------------


Model: SMALL
Model forward time 0.233
Decode time 0.644
Everything time 0.859
Everything (w/ word_timestamps) time 1.113

--------------------------------------------------


Model: MEDIUM
Model forward time 0.684
Decode time 1.700
Everything time 2.356
Everything (w/ word_timestamps) time 2.914

--------------------------------------------------


Model: LARGE
Model forward time 2.782
Decode time 2.701
Everything time 3.597
Everything (w/ word_timestamps) time 4.823

--------------------------------------------------

@bofenghuang bofenghuang marked this pull request as draft December 28, 2023 20:09
@awni (Member) commented on Dec 29, 2023

Super cool, thanks for adding that!

@awni awni self-requested a review December 29, 2023 00:40
@awni (Member) commented on Dec 29, 2023

Addresses #146

@bofenghuang (Contributor, Author) commented:

After measuring the time taken by the operations that add word-level timestamps/scores, I found that most of it is spent in the extra model forward pass. There also appears to be overhead on the first run of DTW, likely due to Numba JIT compilation.

Below are the measured times from tests with the large model.

extra forward time: 1.2198s
median_filter time: 0.0046s
dtw time: 0.7094s
extra forward time: 1.2341s
median_filter time: 0.0044s
dtw time: 0.0012s
extra forward time: 1.2124s
median_filter time: 0.0045s
dtw time: 0.0010s
extra forward time: 1.3064s
median_filter time: 0.0081s
dtw time: 0.0005s
extra forward time: 1.2168s
median_filter time: 0.0045s
dtw time: 0.0003s
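
That first-run spike is consistent with Numba behavior: the dtw kernel is JIT-compiled on its first call and runs as machine code afterwards. Below is a small self-contained illustration using a simplified stand-in kernel (not the PR's dtw function); calling it once on a tiny dummy input before the timed path would move the one-time compilation cost out of the measurements.

```python
import time

import numpy as np
from numba import njit


@njit
def dtw_total_cost(cost):
    # Simplified stand-in for a numba DTW kernel: accumulate the minimal monotonic path cost.
    N, M = cost.shape
    D = np.full((N + 1, M + 1), np.inf)
    D[0, 0] = 0.0
    for i in range(1, N + 1):
        for j in range(1, M + 1):
            D[i, j] = cost[i - 1, j - 1] + min(D[i - 1, j - 1], D[i - 1, j], D[i, j - 1])
    return D[N, M]


# dtw_total_cost(np.zeros((2, 2)))  # uncommenting this warm-up call moves the JIT cost here

for run in range(3):
    x = np.random.rand(100, 1500)
    tic = time.perf_counter()
    dtw_total_cost(x)
    # run 0 includes JIT compilation (unless warmed up above); later runs are much faster
    print(f"run {run}: {time.perf_counter() - tic:.4f}s")
```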

@bofenghuang bofenghuang marked this pull request as ready for review January 1, 2024 14:02
Resolved review threads (outdated) on whisper/test.py and whisper/whisper/whisper.py
@awni (Member) left a comment

Hi @bofenghuang this looks really nice to me and I think we can merge it!

One thing I'm wondering is if we can test the alignment code and/or the word timestamp code at all? It is a bit involved so it would be good to have a test or two to cover it.

Resolved review thread (outdated) on whisper/whisper/timing.py
@bofenghuang (Contributor, Author) commented:

Hi @awni, thanks for the review!

I've just done a rebase and added a test for word-level timestamps & confidence, comparing the results with those from openai-whisper.
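
The PR's actual test compares against openai-whisper reference output; as a hedged illustration of an additional, reference-free check, the sketch below only asserts structural invariants of the word-level output (ordered timestamps, valid confidence values). The transcribe import path, call signature, audio path, and result keys are assumptions modeled on the openai-whisper output format, not necessarily the exact API added in this PR.

```python
import unittest


class TestWordLevelOutput(unittest.TestCase):
    def test_word_timestamps_and_confidence(self):
        # Import path, function name, and arguments are illustrative assumptions.
        from whisper.transcribe import transcribe

        # "audio.wav" stands in for any short test clip (assumption).
        result = transcribe("audio.wav", word_timestamps=True)

        for segment in result["segments"]:
            words = segment["words"]
            self.assertGreater(len(words), 0)

            # Word start times should be non-decreasing within a segment.
            starts = [w["start"] for w in words]
            self.assertEqual(starts, sorted(starts))

            for w in words:
                # Each word spans a non-negative interval ...
                self.assertLessEqual(w["start"], w["end"])
                # ... and carries a confidence score that is a probability.
                self.assertGreaterEqual(w["probability"], 0.0)
                self.assertLessEqual(w["probability"], 1.0)


if __name__ == "__main__":
    unittest.main()
```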

@bofenghuang (Contributor, Author) commented:

Below are the new measured times (in seconds) from tests run on my Mac M1 Pro:

Selected models: ['tiny', 'small', 'medium', 'large-v3']

Feature time 0.035

Model: TINY
Model forward time 0.034
Decode time 0.186
Everything time 0.251
Everything (w/ word_timestamps) time 0.291

--------------------------------------------------


Model: SMALL
Model forward time 0.221
Decode time 0.650
Everything time 0.855
Everything (w/ word_timestamps) time 1.137

--------------------------------------------------


Model: MEDIUM
Model forward time 0.646
Decode time 1.559
Everything time 2.176
Everything (w/ word_timestamps) time 2.832

--------------------------------------------------


Model: LARGE-V3
Model forward time 1.209
Decode time 2.753
Everything time 3.609
Everything (w/ word_timestamps) time 4.953

--------------------------------------------------

@awni (Member) left a comment

Thank you for adding this!! I updated the README to reflect the addition.

@awni awni merged commit bf99264 into ml-explore:main Jan 7, 2024
@bofenghuang bofenghuang deleted the word-timestamps branch January 7, 2024 20:19
Blaizzy pushed a commit to Blaizzy/mlx-examples that referenced this pull request Mar 13, 2024
* Add word timestamps and confidence scores

* Create a separate forward_with_cross_qk function

* Move multiple ops from np to mlx, clean comments

* Save alignment_heads

* Cast qk to fp32

* Add test for word-level timestamps and confidence scores

* format + readme

* nit

---------

Co-authored-by: Awni Hannun <[email protected]>