
Add Shallow fusion in modified_beam_search (#630)
* Add utility for shallow fusion

* test batch size == 1 without shallow fusion

* Use shallow fusion for modified-beam-search

* Modified beam search with ngram rescoring

* Fix code according to review

Co-authored-by: Fangjun Kuang <[email protected]>
ezerhouni and csukuangfj authored Oct 21, 2022
1 parent c30b8d3 commit 9b671e1
Showing 6 changed files with 476 additions and 0 deletions.
20 changes: 20 additions & 0 deletions egs/librispeech/ASR/generate-lm.sh
@@ -0,0 +1,20 @@
#!/usr/bin/env bash

lang_dir=data/lang_bpe_500

for ngram in 2 3 5; do
  if [ ! -f $lang_dir/${ngram}gram.arpa ]; then
    ./shared/make_kn_lm.py \
      -ngram-order ${ngram} \
      -text $lang_dir/transcript_tokens.txt \
      -lm $lang_dir/${ngram}gram.arpa
  fi

  if [ ! -f $lang_dir/${ngram}gram.fst.txt ]; then
    python3 -m kaldilm \
      --read-symbol-table="$lang_dir/tokens.txt" \
      --disambig-symbol='#0' \
      --max-order=${ngram} \
      $lang_dir/${ngram}gram.arpa > $lang_dir/${ngram}gram.fst.txt
  fi
done
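
The ${ngram}gram.fst.txt files written here are what decode.py later loads for rescoring. A minimal sketch of consuming one of them (a sketch only; the path assumes the default lang_dir above, and backoff_id=500 assumes a 500-token BPE vocabulary where the backoff symbol #0 takes the next free ID):

from icefall import NgramLm

# Load the 3-gram token-level LM produced by this script.
ngram_lm = NgramLm(
    "data/lang_bpe_500/3gram.fst.txt",
    backoff_id=500,
    is_binary=False,  # the .fst.txt output is in text format
)
print(ngram_lm.lm.num_states)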
49 changes: 49 additions & 0 deletions egs/librispeech/ASR/lstm_transducer_stateless2/decode.py
@@ -115,10 +115,12 @@
    greedy_search,
    greedy_search_batch,
    modified_beam_search,
    modified_beam_search_ngram_rescoring,
)
from librispeech import LibriSpeech
from train import add_model_arguments, get_params, get_transducer_model

from icefall import NgramLm
from icefall.checkpoint import (
    average_checkpoints,
    average_checkpoints_with_averaged_model,
@@ -214,6 +216,7 @@ def get_parser():
          - fast_beam_search_nbest
          - fast_beam_search_nbest_oracle
          - fast_beam_search_nbest_LG
          - modified_beam_search_ngram_rescoring
        If you use fast_beam_search_nbest_LG, you have to specify
        `--lang-dir`, which should contain `LG.pt`.
        """,
@@ -303,6 +306,22 @@ def get_parser():
        fast_beam_search_nbest_LG, and fast_beam_search_nbest_oracle""",
    )

    parser.add_argument(
        "--tokens-ngram",
        type=int,
        default=3,
        help="""Order of the token-level n-gram LM used for rescoring.
        Used only when the decoding method is
        modified_beam_search_ngram_rescoring.""",
    )

    parser.add_argument(
        "--backoff-id",
        type=int,
        default=500,
        help="""ID of the backoff symbol.
        Used only when the decoding method is
        modified_beam_search_ngram_rescoring.""",
    )

    add_model_arguments(parser)

    return parser
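
For reference, a hedged sketch of selecting the new decoding method through these flags (illustrative only; checkpoint and data options are omitted and assumed to keep their defaults):

parser = get_parser()
args = parser.parse_args(
    [
        "--decoding-method", "modified_beam_search_ngram_rescoring",
        "--tokens-ngram", "3",  # rescore with the 3-gram token LM
        "--backoff-id", "500",  # backoff symbol ID for a 500-token BPE vocab
    ]
)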
@@ -315,6 +334,8 @@ def decode_one_batch(
    batch: dict,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
) -> Dict[str, List[List[str]]]:
    """Decode one batch and return the result in a dict. The dict has the
    following format:
@@ -448,6 +469,17 @@
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    elif params.decoding_method == "modified_beam_search_ngram_rescoring":
        hyp_tokens = modified_beam_search_ngram_rescoring(
            model=model,
            encoder_out=encoder_out,
            encoder_out_lens=encoder_out_lens,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
            beam=params.beam_size,
        )
        for hyp in sp.decode(hyp_tokens):
            hyps.append(hyp.split())
    else:
        batch_size = encoder_out.size(0)

@@ -497,6 +529,8 @@ def decode_dataset(
    sp: spm.SentencePieceProcessor,
    word_table: Optional[k2.SymbolTable] = None,
    decoding_graph: Optional[k2.Fsa] = None,
    ngram_lm: Optional[NgramLm] = None,
    ngram_lm_scale: float = 1.0,
) -> Dict[str, List[Tuple[str, List[str], List[str]]]]:
    """Decode dataset.
@@ -546,6 +580,8 @@
            decoding_graph=decoding_graph,
            word_table=word_table,
            batch=batch,
            ngram_lm=ngram_lm,
            ngram_lm_scale=ngram_lm_scale,
        )

        for name, hyps in hyps_dict.items():
@@ -631,6 +667,7 @@ def main():
"fast_beam_search_nbest_LG",
"fast_beam_search_nbest_oracle",
"modified_beam_search",
"modified_beam_search_ngram_rescoring",
)
params.res_dir = params.exp_dir / params.decoding_method

@@ -655,6 +692,7 @@
    else:
        params.suffix += f"-context-{params.context_size}"
        params.suffix += f"-max-sym-per-frame-{params.max_sym_per_frame}"
    params.suffix += f"-ngram-lm-scale-{params.ngram_lm_scale}"

    if params.use_averaged_model:
        params.suffix += "-use-averaged-model"
@@ -768,6 +806,15 @@
    model.to(device)
    model.eval()

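    # Load the token-level n-gram LM built by generate-lm.sh; it is consumed
    # only by the modified_beam_search_ngram_rescoring decoding method.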
    lm_filename = f"{params.tokens_ngram}gram.fst.txt"
    logging.info(f"lm filename: {lm_filename}")
    ngram_lm = NgramLm(
        str(params.lang_dir / lm_filename),
        backoff_id=params.backoff_id,
        is_binary=False,
    )
    logging.info(f"num states: {ngram_lm.lm.num_states}")

if "fast_beam_search" in params.decoding_method:
if params.decoding_method == "fast_beam_search_nbest_LG":
lexicon = Lexicon(params.lang_dir)
@@ -812,6 +859,8 @@
            sp=sp,
            word_table=word_table,
            decoding_graph=decoding_graph,
            ngram_lm=ngram_lm,
            ngram_lm_scale=params.ngram_lm_scale,
        )

        save_results(
173 changes: 173 additions & 0 deletions egs/librispeech/ASR/pruned_transducer_stateless2/beam_search.py
@@ -23,6 +23,7 @@
import torch
from model import Transducer

from icefall import NgramLm, NgramLmStateCost
from icefall.decode import Nbest, one_best_decoding
from icefall.utils import add_eos, add_sos, get_texts

@@ -656,6 +657,8 @@ class Hypothesis:
    # It contains only one entry.
    log_prob: torch.Tensor

    state_cost: Optional[NgramLmStateCost] = None

    @property
    def key(self) -> str:
        """Return a string representation of self.ys"""
@@ -1539,3 +1542,173 @@ def fast_beam_search_with_nbest_rnn_rescoring(
        ans[key] = hyps

    return ans


def modified_beam_search_ngram_rescoring(
    model: Transducer,
    encoder_out: torch.Tensor,
    encoder_out_lens: torch.Tensor,
    ngram_lm: NgramLm,
    ngram_lm_scale: float,
    beam: int = 4,
    temperature: float = 1.0,
) -> List[List[int]]:
    """Beam search in batch mode with --max-sym-per-frame=1 being hardcoded.
    Args:
      model:
        The transducer model.
      encoder_out:
        Output from the encoder. Its shape is (N, T, C).
      encoder_out_lens:
        A 1-D tensor of shape (N,), containing the number of valid frames in
        encoder_out before padding.
      ngram_lm:
        The token-level n-gram LM used for rescoring.
      ngram_lm_scale:
        Scale applied to the n-gram LM scores during the search.
      beam:
        Number of active paths during the beam search.
      temperature:
        Softmax temperature.
    Returns:
      Return a list-of-list of token IDs. ans[i] is the decoding result
      for the i-th utterance.
    """
    assert encoder_out.ndim == 3, encoder_out.shape
    assert encoder_out.size(0) >= 1, encoder_out.size(0)

    packed_encoder_out = torch.nn.utils.rnn.pack_padded_sequence(
        input=encoder_out,
        lengths=encoder_out_lens.cpu(),
        batch_first=True,
        enforce_sorted=False,
    )

    blank_id = model.decoder.blank_id
    unk_id = getattr(model, "unk_id", blank_id)
    context_size = model.decoder.context_size
    device = next(model.parameters()).device
    lm_scale = ngram_lm_scale

    batch_size_list = packed_encoder_out.batch_sizes.tolist()
    N = encoder_out.size(0)
    assert torch.all(encoder_out_lens > 0), encoder_out_lens
    assert N == batch_size_list[0], (N, batch_size_list)

    B = [HypothesisList() for _ in range(N)]
    for i in range(N):
        B[i].add(
            Hypothesis(
                ys=[blank_id] * context_size,
                log_prob=torch.zeros(1, dtype=torch.float32, device=device),
                state_cost=NgramLmStateCost(ngram_lm),
            )
        )

    encoder_out = model.joiner.encoder_proj(packed_encoder_out.data)

    offset = 0
    finalized_B = []
    for batch_size in batch_size_list:
        start = offset
        end = offset + batch_size
        current_encoder_out = encoder_out.data[start:end]
        current_encoder_out = current_encoder_out.unsqueeze(1).unsqueeze(1)
        # current_encoder_out's shape is (batch_size, 1, 1, encoder_out_dim)
        offset = end

        finalized_B = B[batch_size:] + finalized_B
        B = B[:batch_size]

        hyps_shape = get_hyps_shape(B).to(device)

        A = [list(b) for b in B]
        B = [HypothesisList() for _ in range(batch_size)]

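        # Shallow fusion: rank each hypothesis by its transducer (AM)
        # log-prob plus the scaled n-gram LM score held in state_cost.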
        ys_log_probs = torch.cat(
            [
                hyp.log_prob.reshape(1, 1) + hyp.state_cost.lm_score * lm_scale
                for hyps in A
                for hyp in hyps
            ]
        )  # (num_hyps, 1)

        decoder_input = torch.tensor(
            [hyp.ys[-context_size:] for hyps in A for hyp in hyps],
            device=device,
            dtype=torch.int64,
        )  # (num_hyps, context_size)

        decoder_out = model.decoder(decoder_input, need_pad=False).unsqueeze(1)
        decoder_out = model.joiner.decoder_proj(decoder_out)
        # decoder_out is of shape (num_hyps, 1, 1, joiner_dim)

        # Note: For torch 1.7.1 and below, it requires a torch.int64 tensor
        # as index, so we use `to(torch.int64)` below.
        current_encoder_out = torch.index_select(
            current_encoder_out,
            dim=0,
            index=hyps_shape.row_ids(1).to(torch.int64),
        )  # (num_hyps, 1, 1, encoder_out_dim)

        logits = model.joiner(
            current_encoder_out,
            decoder_out,
            project_input=False,
        )  # (num_hyps, 1, 1, vocab_size)

        logits = logits.squeeze(1).squeeze(1)  # (num_hyps, vocab_size)

        log_probs = (logits / temperature).log_softmax(
            dim=-1
        )  # (num_hyps, vocab_size)

        log_probs.add_(ys_log_probs)
        vocab_size = log_probs.size(-1)
        log_probs = log_probs.reshape(-1)

        row_splits = hyps_shape.row_splits(1) * vocab_size
        log_probs_shape = k2.ragged.create_ragged_shape2(
            row_splits=row_splits, cached_tot_size=log_probs.numel()
        )
        ragged_log_probs = k2.RaggedTensor(
            shape=log_probs_shape, value=log_probs
        )

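        # For each stream, keep the `beam` best (hypothesis, token) pairs
        # from the flattened (num_hyps * vocab_size) score vector.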
        for i in range(batch_size):
            topk_log_probs, topk_indexes = ragged_log_probs[i].topk(beam)

            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                topk_hyp_indexes = (topk_indexes // vocab_size).tolist()
                topk_token_indexes = (topk_indexes % vocab_size).tolist()

            for k in range(len(topk_hyp_indexes)):
                hyp_idx = topk_hyp_indexes[k]
                hyp = A[i][hyp_idx]

                new_ys = hyp.ys[:]
                new_token = topk_token_indexes[k]
                if new_token not in (blank_id, unk_id):
                    new_ys.append(new_token)
                    state_cost = hyp.state_cost.forward_one_step(new_token)
                else:
                    state_cost = hyp.state_cost

                # We only keep AM scores in new_hyp.log_prob
                new_log_prob = (
                    topk_log_probs[k] - hyp.state_cost.lm_score * lm_scale
                )

                new_hyp = Hypothesis(
                    ys=new_ys, log_prob=new_log_prob, state_cost=state_cost
                )
                B[i].add(new_hyp)

    B = B + finalized_B
    best_hyps = [b.get_most_probable(length_norm=True) for b in B]

    sorted_ans = [h.ys[context_size:] for h in best_hyps]
    ans = []
    unsorted_indices = packed_encoder_out.unsorted_indices.tolist()
    for i in range(N):
        ans.append(sorted_ans[unsorted_indices[i]])

    return ans
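
A minimal usage sketch of the new function, mirroring the decode.py call site above (model, encoder_out, encoder_out_lens, and ngram_lm are assumed to be prepared as in decode.py; the scale is illustrative):

hyp_tokens = modified_beam_search_ngram_rescoring(
    model=model,
    encoder_out=encoder_out,
    encoder_out_lens=encoder_out_lens,
    ngram_lm=ngram_lm,
    ngram_lm_scale=0.4,  # illustrative value; tune on dev data
    beam=4,
)
# hyp_tokens[i] is the decoded token-ID sequence for the i-th utterance.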
2 changes: 2 additions & 0 deletions icefall/__init__.py
@@ -65,3 +65,5 @@
    subsequent_chunk_mask,
    write_error_stats,
)

from .ngram_lm import NgramLm, NgramLmStateCost
