
Implementation of model average. #1

Open · wants to merge 13 commits into master

Conversation

yaozengwei
Owner

@yaozengwei yaozengwei commented May 1, 2022

This is an implementation of Dan's idea about model averaging (see k2-fsa#337).

@yaozengwei
Owner Author

yaozengwei commented May 2, 2022

The code is based on egs/librispeech/pruned_transducer_stateless2.
During training, the averaged model model_avg is updated every average_period batches with:
model_avg = (average_period / batch_idx_train) * model + ((batch_idx_train - average_period) / batch_idx_train) * model_avg
During decoding, let start = batch_idx_train of model-start and end = batch_idx_train of model-end. Then the model averaged over epochs [start+1, start+2, ..., end] is avg = (model_end * end - model_start * start) / (end - start).
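The two formulas can be sanity-checked with plain scalars standing in for model tensors (a minimal sketch; the real code operates on state dicts):

```python
def update_running_avg(param, param_avg, batch_idx_train, average_period):
    # Training-time update: param_avg is a running mean of the parameter
    # values sampled every average_period batches.
    t = batch_idx_train
    return (average_period / t) * param + ((t - average_period) / t) * param_avg


def average_over_range(avg_start, start, avg_end, end):
    # Decoding-time combination: recover the mean over batches
    # (start, end] from two running-average snapshots.
    return (avg_end * end - avg_start * start) / (end - start)


# Simulate a "parameter" whose value equals the batch index,
# with average_period = 1 so every batch is sampled.
snapshots = {}
avg = 0.0
for t in range(1, 11):
    avg = update_running_avg(float(t), avg, t, average_period=1)
    snapshots[t] = avg

# After t batches, avg is the mean of 1..t (snapshots[10] ~= 5.5).
# The subtraction trick over the range (4, 10] yields mean(5..10) = 7.5.
m = average_over_range(snapshots[4], 4, snapshots[10], 10)
```

The point of the subtraction trick is that only two checkpoints need to be stored to average over any range between them.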
When trained on train-clean-100 with 3 GPUs for 30 epochs and average_period=100, I got the following results with greedy search decoding:

  • decode with epoch-29, avg=5, 7.14 & 19.33 (without averaged model) -> 7.03 & 18.85 (with averaged model);
  • decode with epoch-29, avg=10, 6.99 & 18.93 (without averaged model) -> 6.91 & 18.65 (with averaged model).

When trained on full librispeech with 6 GPUs for 30 epochs and average_period=100, I got the following results with greedy search decoding:

  • decode with epoch-29, avg=5, 2.77 & 6.77 (without averaged model) -> 2.72 & 6.67 (with averaged model);
  • decode with epoch-29, avg=10, 2.78 & 6.68 (without averaged model) -> 2.74 & 6.67 (with averaged model).

"""
Usage:
(1) greedy search
./pruned_transducer_stateless2/decode.py \

@csukuangfj csukuangfj May 2, 2022


Suggested change
./pruned_transducer_stateless2/decode.py \
./pruned_transducer_stateless3/decode.py \

Also, please sync with the latest k2/icefall and rename it to pruned_transducer_stateless4.

Owner Author


Ok.

    model.load_state_dict(average_checkpoints(filenames, device=device))
else:
    assert params.iter == 0
    start = params.epoch - params.avg


Please add more doc to --use-average-model.
It is not clear from the current help info how it is used in the code.

filename_start = f"{params.exp_dir}/epoch-{start}.pt"
filename_end = f"{params.exp_dir}/epoch-{params.epoch}.pt"
logging.info(
    f"averaging modes over range with {filename_start} (excluded) "


Suggested change
f"averaging modes over range with {filename_start} (excluded) "
f"averaging models over range with {filename_start} (excluded) "

@@ -118,6 +126,10 @@ def load_checkpoint(

checkpoint.pop("model")

if model_avg is not None and "model_avg" in checkpoint:
    model_avg.load_state_dict(checkpoint["model_avg"], strict=strict)


Please add a log here, e.g., saying "loading averaged model".

Owner Author


ok.

Comment on lines 423 to 436
# Identify shared parameters. Two parameters are said to be shared
# if they have the same data_ptr
uniqued: Dict[int, str] = dict()
for k, v in avg.items():
    v_data_ptr = v.data_ptr()
    if v_data_ptr in uniqued:
        continue
    uniqued[v_data_ptr] = k

uniqued_names = list(uniqued.values())
for k in uniqued_names:
    avg[k] *= weight_end
    avg[k] += model_start[k] * weight_start


This part is almost the same as the above function. Please refactor it to reduce redundant code.
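One way to reduce the duplication (names hypothetical, sketched against the snippet above) would be to pull the data_ptr deduplication into a small helper that both averaging functions call:

```python
from typing import Dict, List


def unique_parameter_names(state_dict) -> List[str]:
    """Return one key per underlying storage.

    Parameters shared via weight tying have the same data_ptr; keeping
    only the first key per data_ptr ensures each tensor is scaled once.
    """
    uniqued: Dict[int, str] = {}
    for k, v in state_dict.items():
        ptr = v.data_ptr()
        if ptr not in uniqued:
            uniqued[ptr] = k
    return list(uniqued.values())
```

Both averaging loops could then iterate over unique_parameter_names(avg) instead of repeating the deduplication inline.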

parser.add_argument(
"--start-epoch",
type=int,
default=0,


Please change it so that epoch is counted from 1, not 0.
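A minimal sketch of the requested change (the help text is illustrative, not from the PR):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--start-epoch",
    type=int,
    default=1,  # epochs are counted from 1, not 0
    help="Resume training from this epoch. If larger than 1, the "
    "checkpoint saved at the end of the previous epoch is loaded.",
)

args = parser.parse_args([])
```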

def load_checkpoint_if_available(
    params: AttributeDict,
    model: nn.Module,
    model_avg: nn.Module = None,


Suggested change
model_avg: nn.Module = None,
model_avg: Optional[nn.Module] = None,

The return value of :func:`get_params`.
model:
The training model.
optimizer:


Please update the doc to include model_avg.

logging.info(f"Number of model parameters: {num_param}")

assert params.save_every_n >= params.average_period
model_avg: nn.Module = None


Suggested change
model_avg: nn.Module = None
model_avg: Optional[nn.Module] = None
