Implementation of model average. #1
Conversation
The code is based on egs/librispeech/pruned_transducer_stateless2.
When trained on full LibriSpeech with 6 GPUs for 30 epochs and average_period=100, I got the following results with greedy search decoding:
""" | ||
Usage: | ||
(1) greedy search | ||
./pruned_transducer_stateless2/decode.py \ |
Suggested change:
- ./pruned_transducer_stateless2/decode.py \
+ ./pruned_transducer_stateless3/decode.py \
Also, please sync with the latest k2/icefall and rename it to pruned_transducer_stateless4.
Ok.
        model.load_state_dict(average_checkpoints(filenames, device=device))
    else:
        assert params.iter == 0
        start = params.epoch - params.avg
Please add more documentation for --use-average-model.
It is not clear from the current help info how it is used in the code.
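For illustration, the help text could spell out what the flag does. The wording below is only a suggestion, and str2bool is assumed to be imported from icefall.utils as in other recipes:

parser.add_argument(
    "--use-average-model",
    type=str2bool,
    default=False,
    help="When True, load the averaged model that is maintained during "
    "training (updated every average_period batches) and use it for "
    "decoding, instead of averaging the saved epoch-*.pt checkpoints "
    "directly.",
)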
filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
    f"averaging modes over range with {filename_start} (excluded) "
f"averaging modes over range with {filename_start} (excluded) " | |
f"averaging models over range with {filename_start} (excluded) " |
@@ -118,6 +126,10 @@ def load_checkpoint(

    checkpoint.pop("model")

    if model_avg is not None and "model_avg" in checkpoint:
        model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)
Please add a log here, e.g., saying "loading averaged model".
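For example, a minimal sketch of the requested change, placed right before loading the averaged state dict:

if model_avg is not None and "model_avg" in checkpoint:
    logging.info("Loading averaged model")
    model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)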
ok.
icefall/checkpoint.py
Outdated
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
    v_data_ptr = v.data_ptr()
    if v_data_ptr in uniqued:
        continue
    uniqued[v_data_ptr] = k

uniqued_names = list(uniqued.values())
for k in uniqued_names:
    avg[k] *= weight_end
    avg[k] += model_start[k] * weight_start
This part is almost the same as the above function. Please refactor it to reduce redundant code.
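One possible refactor is to move the in-place weighted combination into a small helper shared by both call sites. The name and signature below are only a sketch, not necessarily what was eventually merged:

import torch
from typing import Dict


def average_state_dict(
    state_dict_1: Dict[str, torch.Tensor],
    state_dict_2: Dict[str, torch.Tensor],
    weight_1: float,
    weight_2: float,
) -> None:
    """Compute state_dict_1 = state_dict_1 * weight_1 + state_dict_2 * weight_2
    in place, touching each shared parameter (same data_ptr) only once."""
    # Identify shared parameters. Two parameters are said to be shared
    # if they have the same data_ptr.
    uniqued: Dict[int, str] = dict()
    for k, v in state_dict_1.items():
        v_data_ptr = v.data_ptr()
        if v_data_ptr in uniqued:
            continue
        uniqued[v_data_ptr] = k

    for k in uniqued.values():
        state_dict_1[k] *= weight_1
        state_dict_1[k] += state_dict_2[k] * weight_2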
parser.add_argument(
    "--start-epoch",
    type=int,
    default=0,
Please change it so that epoch is counted from 1, not 0.
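A sketch of 1-based counting; the resume logic shown is an assumption about how --start-epoch would be used:

parser.add_argument(
    "--start-epoch",
    type=int,
    default=1,
    help="Resume training from this epoch, counted from 1. If it is "
    "larger than 1, the checkpoint from epoch start_epoch - 1 is loaded.",
)

# When resuming:
if params.start_epoch > 1:
    filename = f"{params.exp_dir}/epoch-{params.start_epoch - 1}.pt"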
def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: nn.Module = None,
Suggested change:
- model_avg: nn.Module = None,
+ model_avg: Optional[nn.Module] = None,
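Note that this also requires "from typing import Optional" at the top of the file if it is not already imported there.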
    The return value of :func:`get_params`.
  model:
    The training model.
  optimizer:
Please update the doc to include model_avg.
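For example, the Args section could gain an entry along the lines of "model_avg: The stored model averaged periodically during training (every average_period batches)"; the exact wording is only a suggestion.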
logging.info(f"Number of model parameters: {num_param}")

assert params.save_every_n >= params.average_period
model_avg: nn.Module = None
Suggested change:
- model_avg: nn.Module = None
+ model_avg: Optional[nn.Module] = None
…n the pruned_transducer_stateless4/decode.py
This is the implementation of Dan's idea about model average. (see k2-fsa#337)
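As a rough sketch of the idea (not the code in this PR): in addition to the training model, an averaged copy is kept and pulled toward the current weights every average_period batches, so that it approximates the mean of the parameters over all batches seen so far. The function name and exact weighting below are assumptions for illustration:

import torch
import torch.nn as nn


@torch.no_grad()
def update_averaged_model(
    model_cur: nn.Module,
    model_avg: nn.Module,
    batch_idx_train: int,
    average_period: int,
) -> None:
    """Pull model_avg toward model_cur so that model_avg stays an
    approximate running mean of the parameters over training batches."""
    # Weight for the current model: the fraction of all batches seen so
    # far that the most recent average_period batches represent.
    weight_cur = average_period / batch_idx_train
    weight_avg = 1.0 - weight_cur

    avg = model_avg.state_dict()
    cur = model_cur.state_dict()
    for k in avg:
        # Skip integer buffers such as BatchNorm's num_batches_tracked.
        if avg[k].dtype.is_floating_point:
            avg[k].mul_(weight_avg).add_(cur[k], alpha=weight_cur)


# Usage during the training loop (sketch):
#   model_avg = copy.deepcopy(model)           # once, before training starts
#   if batch_idx_train % average_period == 0:  # every average_period batches
#       update_averaged_model(model, model_avg, batch_idx_train, average_period)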