Apply delay penalty on transducer (#654)
* add delay penalty

* fix CI

* fix CI
yaozengwei authored Nov 4, 2022
1 parent 65b85b7 commit 3600ce1
Showing 13 changed files with 113 additions and 5 deletions.
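
Note on what the penalty does: per the linked issue (k2-fsa/k2#955) and paper (arXiv:2211.00490), the loss is adjusted so that emitting a non-blank symbol late costs more than emitting it early. Below is a toy sketch of that idea only; the midpoint-centered bonus is an assumption drawn from those references, not code from this commit, and k2 applies the adjustment inside its rnnt_loss kernels rather than in user code.

```python
import torch

# Toy illustration only -- k2 handles this internally.
# px: log-scores for emitting the next non-blank symbol at each frame,
# shape (batch, num_symbols, num_frames).
B, S, T = 2, 5, 20
px = torch.randn(B, S, T)

delay_penalty = 0.1
t = torch.arange(T, dtype=torch.float32)
# A bonus that shrinks linearly with the frame index makes early
# emissions cheaper than late ones; centering it at the midpoint keeps
# the average score roughly unchanged while still favoring early frames.
px = px + delay_penalty * ((T - 1) / 2 - t)
```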
3 changes: 3 additions & 0 deletions .github/workflows/test.yml
@@ -79,6 +79,9 @@ jobs:
pip uninstall -y protobuf
pip install --no-binary protobuf protobuf
pip install kaldifst
pip install onnxruntime
pip install -r requirements.txt
- name: Install graphviz
8 changes: 8 additions & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless/model.py
@@ -81,6 +81,7 @@ def forward(
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -108,6 +109,11 @@ def forward(
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
delay_penalty:
A constant value used to penalize symbol delay, to encourage
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Returns:
Return the transducer loss.
@@ -164,6 +170,7 @@ def forward(
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
delay_penalty=delay_penalty,
return_grad=True,
)

@@ -196,6 +203,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)

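For reference, a minimal standalone sketch of the two k2 call sites above with `delay_penalty` threaded through. Shapes, scales, and `s_range` are illustrative, and it assumes a k2 build recent enough to accept the `delay_penalty` keyword (see k2-fsa/k2#955); consult k2's rnnt_loss docs for the authoritative signatures.

```python
import k2
import torch

B, S, T, C = 2, 5, 20, 500  # batch, symbols, frames, vocab (illustrative)
blank_id = 0
delay_penalty = 0.1

lm = torch.randn(B, S + 1, C)           # prediction-network output
am = torch.randn(B, T, C)               # encoder output
symbols = torch.randint(1, C, (B, S))
boundary = torch.zeros(B, 4, dtype=torch.int64)
boundary[:, 2] = S
boundary[:, 3] = T

# Simple (smoothed) loss, as in the first call site above.
simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
    lm=lm.float(),
    am=am.float(),
    symbols=symbols,
    termination_symbol=blank_id,
    lm_only_scale=0.25,
    am_only_scale=0.0,
    boundary=boundary,
    reduction="sum",
    delay_penalty=delay_penalty,
    return_grad=True,
)

ranges = k2.get_rnnt_prune_ranges(
    px_grad=px_grad, py_grad=py_grad, boundary=boundary, s_range=5
)
# In the real model `logits` comes from the joiner applied to the
# pruned (am, lm) pairs; a random tensor stands in for it here.
logits = torch.randn(B, T, 5, C)

# Pruned loss, as in the second call site above.
pruned_loss = k2.rnnt_loss_pruned(
    logits=logits.float(),
    symbols=symbols,
    ranges=ranges,
    termination_symbol=blank_id,
    boundary=boundary,
    delay_penalty=delay_penalty,
    reduction="sum",
)
```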
11 changes: 11 additions & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -318,6 +318,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -611,6 +621,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
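Each train.py gates the penalty on model warmup, so the first phase of training runs without it and the penalty only kicks in once warmup reaches 2.0. A small sketch of the flag plus that gate (the 0.1 value is illustrative only):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--delay-penalty",
    type=float,
    default=0.0,
    help="Constant used to penalize symbol delay.",
)
params = parser.parse_args(["--delay-penalty", "0.1"])

# Same gate as in compute_loss() above: zero until warmup >= 2.0.
for warmup in (0.5, 1.9, 2.0, 3.5):
    effective = params.delay_penalty if warmup >= 2.0 else 0
    print(f"warmup={warmup}: delay_penalty={effective}")
# warmup=0.5: delay_penalty=0
# warmup=1.9: delay_penalty=0
# warmup=2.0: delay_penalty=0.1
# warmup=3.5: delay_penalty=0.1
```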
8 changes: 8 additions & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless2/model.py
@@ -106,6 +106,7 @@ def forward(
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -136,6 +137,11 @@ def forward(
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
delay_penalty:
A constant value used to penalize symbol delay, to encourage
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Returns:
Return the transducer loss.
@@ -203,6 +209,7 @@ def forward(
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
delay_penalty=delay_penalty,
return_grad=True,
)

@@ -235,6 +242,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)

11 changes: 11 additions & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless2/train.py
@@ -341,6 +341,16 @@ def get_parser():
help="The probability to select a batch from the GigaSpeech dataset",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -665,6 +675,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
11 changes: 11 additions & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless3/train.py
@@ -328,6 +328,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -623,6 +633,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
5 changes: 0 additions & 5 deletions egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
@@ -23,7 +23,6 @@
python ./pruned_transducer_stateless/test_model.py
"""

import torch
from train import get_params, get_transducer_model


@@ -43,8 +42,6 @@ def test_model():

num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)


def test_model_streaming():
@@ -63,8 +60,6 @@ def test_model_streaming():

num_param = sum([p.numel() for p in model.parameters()])
print(f"Number of model parameters: {num_param}")
model.__class__.forward = torch.jit.ignore(model.__class__.forward)
torch.jit.script(model)


def main():
9 changes: 9 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -81,6 +81,7 @@ def forward(
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -108,6 +109,12 @@ def forward(
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
delay_penalty:
A constant value used to penalize symbol delay, to encourage
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Returns:
Returns:
Return the transducer loss.
@@ -164,6 +171,7 @@ def forward(
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
delay_penalty=delay_penalty,
return_grad=True,
)

@@ -196,6 +204,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)

11 changes: 11 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -317,6 +317,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -607,6 +617,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
8 changes: 8 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless3/model.py
@@ -106,6 +106,7 @@ def forward(
lm_scale: float = 0.0,
warmup: float = 1.0,
reduction: str = "sum",
delay_penalty: float = 0.0,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
@@ -136,6 +137,11 @@ def forward(
"sum" to sum the losses over all utterances in the batch.
"none" to return the loss in a 1-D tensor for each utterance
in the batch.
delay_penalty:
A constant value used to penalize symbol delay, to encourage
streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.
Returns:
Return the transducer loss.
@@ -203,6 +209,7 @@ def forward(
am_only_scale=am_scale,
boundary=boundary,
reduction=reduction,
delay_penalty=delay_penalty,
return_grad=True,
)

@@ -235,6 +242,7 @@ def forward(
ranges=ranges,
termination_symbol=blank_id,
boundary=boundary,
delay_penalty=delay_penalty,
reduction=reduction,
)

11 changes: 11 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -328,6 +328,16 @@ def get_parser():
help="The probability to select a batch from the GigaSpeech dataset",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)
return parser

@@ -645,6 +655,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
11 changes: 11 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -335,6 +335,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -638,6 +648,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
11 changes: 11 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -368,6 +368,16 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--delay-penalty",
type=float,
default=0.0,
help="""A constant value used to penalize symbol delay,
to encourage streaming models to emit symbols earlier.
See https://github.com/k2-fsa/k2/issues/955 and
https://arxiv.org/pdf/2211.00490.pdf for more details.""",
)

add_model_arguments(parser)

return parser
@@ -662,6 +672,7 @@ def compute_loss(
lm_scale=params.lm_scale,
warmup=warmup,
reduction="none",
delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
)
simple_loss_is_finite = torch.isfinite(simple_loss)
pruned_loss_is_finite = torch.isfinite(pruned_loss)
