diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index 1583926ece..04fc0265f5 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -79,6 +79,9 @@ jobs:
           pip uninstall -y protobuf
           pip install --no-binary protobuf protobuf

+          pip install kaldifst
+          pip install onnxruntime
+
           pip install -r requirements.txt

       - name: Install graphviz
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/model.py b/egs/librispeech/ASR/lstm_transducer_stateless/model.py
index efbc88a55d..d71132b4ac 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless/model.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/model.py
@@ -81,6 +81,7 @@ def forward(
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -108,6 +109,11 @@
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.

@@ -164,6 +170,7 @@ def forward(
                 am_only_scale=am_scale,
                 boundary=boundary,
                 reduction=reduction,
+                delay_penalty=delay_penalty,
                 return_grad=True,
             )

@@ -196,6 +203,7 @@ def forward(
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction=reduction,
             )

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless/train.py b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
index a50686df98..fbb4e72242 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless/train.py
@@ -318,6 +318,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -611,6 +621,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py
index b0fb6ab89c..fadeb4ac2f 100644
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/model.py
@@ -106,6 +106,7 @@ def forward(
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -136,6 +137,11 @@
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.

@@ -203,6 +209,7 @@ def forward(
                 am_only_scale=am_scale,
                 boundary=boundary,
                 reduction=reduction,
+                delay_penalty=delay_penalty,
                 return_grad=True,
             )

@@ -235,6 +242,7 @@ def forward(
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction=reduction,
             )

diff --git a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
index 9eed2dfcb4..ac6bf7e048 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless2/train.py
@@ -341,6 +341,16 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )

+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -665,6 +675,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
diff --git a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
index fa50576d8a..f2aa846258 100755
--- a/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/lstm_transducer_stateless3/train.py
@@ -328,6 +328,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -623,6 +633,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py b/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
index 1858d6bf02..fc82d8c69f 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless/test_model.py
@@ -23,7 +23,6 @@
 python ./pruned_transducer_stateless/test_model.py
 """

-import torch
 from train import get_params, get_transducer_model


@@ -43,8 +42,6 @@ def test_model():
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")

-    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
-    torch.jit.script(model)


 def test_model_streaming():
@@ -63,8 +60,6 @@ def test_model_streaming():
     num_param = sum([p.numel() for p in model.parameters()])
     print(f"Number of model parameters: {num_param}")

-    model.__class__.forward = torch.jit.ignore(model.__class__.forward)
-    torch.jit.script(model)


 def main():
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
index ba7616c610..417c391d9b 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/model.py
@@ -81,6 +81,7 @@ def forward(
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -108,6 +109,12 @@
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
+
         Returns:
         Returns:
           Return the transducer loss.
@@ -164,6 +171,7 @@ def forward(
                 am_only_scale=am_scale,
                 boundary=boundary,
                 reduction=reduction,
+                delay_penalty=delay_penalty,
                 return_grad=True,
             )

@@ -196,6 +204,7 @@ def forward(
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction=reduction,
             )

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
index 5c2f675345..7ce2ca7791 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless2/train.py
@@ -317,6 +317,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -607,6 +617,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
index 0d5f7cc6d9..7852f84e97 100644
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/model.py
@@ -106,6 +106,7 @@ def forward(
         lm_scale: float = 0.0,
         warmup: float = 1.0,
         reduction: str = "sum",
+        delay_penalty: float = 0.0,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
@@ -136,6 +137,11 @@
             "sum" to sum the losses over all utterances in the batch.
             "none" to return the loss in a 1-D tensor for each utterance
             in the batch.
+          delay_penalty:
+            A constant value used to penalize symbol delay, to encourage
+            streaming models to emit symbols earlier.
+            See https://github.com/k2-fsa/k2/issues/955 and
+            https://arxiv.org/pdf/2211.00490.pdf for more details.
         Returns:
           Return the transducer loss.

@@ -203,6 +209,7 @@ def forward(
                 am_only_scale=am_scale,
                 boundary=boundary,
                 reduction=reduction,
+                delay_penalty=delay_penalty,
                 return_grad=True,
             )

@@ -235,6 +242,7 @@ def forward(
                 ranges=ranges,
                 termination_symbol=blank_id,
                 boundary=boundary,
+                delay_penalty=delay_penalty,
                 reduction=reduction,
             )

diff --git a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
index a74975caff..6cc34f18a1 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless3/train.py
@@ -328,6 +328,16 @@ def get_parser():
         help="The probability to select a batch from the GigaSpeech dataset",
     )

+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -645,6 +655,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
        )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
index 4c55fd6096..57548270d0 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless4/train.py
@@ -335,6 +335,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -638,6 +648,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
diff --git a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
index 1fa6682935..b964cd05d2 100755
--- a/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
+++ b/egs/librispeech/ASR/pruned_transducer_stateless5/train.py
@@ -368,6 +368,16 @@ def get_parser():
         help="Whether to use half precision training.",
     )

+    parser.add_argument(
+        "--delay-penalty",
+        type=float,
+        default=0.0,
+        help="""A constant value used to penalize symbol delay,
+        to encourage streaming models to emit symbols earlier.
+        See https://github.com/k2-fsa/k2/issues/955 and
+        https://arxiv.org/pdf/2211.00490.pdf for more details.""",
+    )
+
     add_model_arguments(parser)

     return parser
@@ -662,6 +672,7 @@ def compute_loss(
             lm_scale=params.lm_scale,
             warmup=warmup,
             reduction="none",
+            delay_penalty=params.delay_penalty if warmup >= 2.0 else 0,
         )
         simple_loss_is_finite = torch.isfinite(simple_loss)
         pruned_loss_is_finite = torch.isfinite(pruned_loss)
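
Note (reviewer sketch, not part of the patch): the new delay_penalty value flows from the --delay-penalty flag through each model's forward() into k2's rnnt_loss_smoothed and rnnt_loss_pruned, and compute_loss enables it only once warmup >= 2.0, so early training is unaffected. Below is a minimal, self-contained sketch of the underlying k2 call. It assumes a k2 build whose RNN-T losses accept the delay_penalty argument; every shape and value is a toy placeholder, not taken from this PR.

    # Sketch only: where delay_penalty lands in k2 (toy shapes and values).
    import k2
    import torch

    B, T, S, C = 2, 10, 4, 30  # batch, frames, symbols, vocab size
    blank_id = 0

    lm = torch.randn(B, S + 1, C)  # decoder (prediction network) output
    am = torch.randn(B, T, C)  # encoder output
    symbols = torch.randint(1, C, (B, S))  # int64 label sequences

    # Each boundary row is [0, 0, num_symbols, num_frames]; the delay
    # penalty needs it to know each utterance's true length.
    boundary = torch.zeros(B, 4, dtype=torch.int64)
    boundary[:, 2] = S
    boundary[:, 3] = T

    simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
        lm=lm,
        am=am,
        symbols=symbols,
        termination_symbol=blank_id,
        lm_only_scale=0.25,
        am_only_scale=0.0,
        boundary=boundary,
        reduction="sum",
        delay_penalty=0.1,  # > 0 biases the lattice toward earlier emission
        return_grad=True,
    )

To try it in any of the touched recipes, pass e.g. --delay-penalty 0.1 to the corresponding train.py; per the linked paper, larger values lower emission latency at some cost in WER.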