Model averaging #337
I will write the function.
I think PyTorch has something similar here: https://pytorch.org/docs/stable/optim.html#stochastic-weight-averaging
OK, that's interesting. I think we can still write our own, though, because it doesn't look to me like it's easy to use that PyTorch utility together with our batch_idx_train, which lets us choose the period for averaging at decode time.
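For reference, here is a minimal sketch of the PyTorch utility linked above, torch.optim.swa_utils.AveragedModel; the toy model, optimizer, and loss are stand-ins, not the real training code:

```python
import torch
from torch.optim.swa_utils import AveragedModel

# Stand-in model and optimizer; the real training loop would differ.
model = torch.nn.Linear(80, 500)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# AveragedModel keeps a running (by default, equal-weight) average of
# model's parameters across every update_parameters() call.
swa_model = AveragedModel(model)

for batch_idx in range(1, 1001):
    optimizer.zero_grad()
    loss = model(torch.randn(16, 80)).pow(2).mean()  # dummy loss
    loss.backward()
    optimizer.step()
    # Fold the current weights into the running average.
    swa_model.update_parameters(model)
```

The limitation noted above is that AveragedModel keeps a single running average, with no record of batch_idx_train, so there is no obvious way to choose the averaging period afterwards at decode time.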
OK, we have some results locally (thanks, @yaozengwei!) showing that model averaging over finely spaced checkpoints is a bit better than averaging over epochs. The gain is about 0.05% on test-clean and 0.15% on test-other, at around 3%/7% WER; small, but probably still worth doing, since at this point we are picking up pennies in WER.
... OK, here's the idea. We always store a separate version of the model, say model_avg, which for each floating-point parameter contains the average of that parameter over all of training so far. We update this every average_period batches, for, say, average_period = 10 or 100 (it could be every batch, but this is for speed). Each time we average, we do something like:

model_avg = model_avg * (batch_idx_train - average_period) / batch_idx_train + model * average_period / batch_idx_train

[This is not the syntax we'd use; we'd have to write a function to do this weighted average.]
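A minimal sketch of such a weighted-average function, assuming model_avg is an ordinary nn.Module copy of the model and that batch_idx_train counts batches since the start of training; the function name and signature are invented here for illustration:

```python
import torch


def update_averaged_model(
    model: torch.nn.Module,
    model_avg: torch.nn.Module,
    batch_idx_train: int,
    average_period: int,
) -> None:
    """Fold the current parameters into model_avg so that, if called once
    every `average_period` batches, model_avg approximates the average of
    the parameters over all batches since the start of training."""
    # Weight given to the current parameters: the fraction of training
    # covered by the batches seen since the last update.
    weight_cur = average_period / batch_idx_train
    weight_avg = 1.0 - weight_cur
    with torch.no_grad():
        cur = model.state_dict()
        avg = model_avg.state_dict()
        for key, p_avg in avg.items():
            if p_avg.is_floating_point():
                # In-place update; state_dict tensors share storage with
                # the module's parameters and buffers.
                p_avg.mul_(weight_avg).add_(cur[key], alpha=weight_cur)
```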
I propose that we include this averaged version inside the checkpoints-*.pt and epoch-*.pt files, as a separate key in the dict. Then the way we would implement decoding epoch-29.pt with --avg 5 would be something like the following.
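A hedged sketch of that decode-time computation, assuming each checkpoint stores the running-average state_dict under a "model_avg" key and the batch counter under "batch_idx_train" (both key names are assumptions here). Decoding epoch-29.pt with --avg 5 would combine epoch-29.pt with epoch-24.pt, since the weighted difference of the two running averages isolates the average over epochs 25-29:

```python
import torch


def average_between_checkpoints(end_ckpt: dict, start_ckpt: dict) -> dict:
    """Given two checkpoints whose "model_avg" entries are running averages
    from the start of training, return the average of the parameters over
    the batches between the two checkpoints."""
    n_end = end_ckpt["batch_idx_train"]
    n_start = start_ckpt["batch_idx_train"]
    avg_end = end_ckpt["model_avg"]
    avg_start = start_ckpt["model_avg"]

    averaged = {}
    for key, p_end in avg_end.items():
        if p_end.is_floating_point():
            # avg_end covers batches [1, n_end], avg_start covers [1, n_start];
            # their weighted difference covers (n_start, n_end].
            averaged[key] = (p_end * n_end - avg_start[key] * n_start) / (n_end - n_start)
        else:
            averaged[key] = p_end
    return averaged


# e.g. epoch-29.pt with --avg 5: combine epoch-29.pt with epoch-24.pt.
end_ckpt = torch.load("epoch-29.pt", map_location="cpu")
start_ckpt = torch.load("epoch-24.pt", map_location="cpu")
model_state = average_between_checkpoints(end_ckpt, start_ckpt)
# model.load_state_dict(model_state)  # load into the model used for decoding
```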