Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

streaming model distillation with codebook loss #1

Open
wants to merge 1 commit into
base: streaming-conformer
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 34 additions & 9 deletions egs/librispeech/ASR/pruned_transducer_stateless2/conformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __init__(
short_chunk_size: int = 25,
num_left_chunks: int = -1,
causal: bool = False,
middle_output_layer: int = None, # 0-based layer index
) -> None:
super(Conformer, self).__init__()

Expand Down Expand Up @@ -121,12 +122,27 @@ def __init__(
cnn_module_kernel,
causal,
)
self.encoder = ConformerEncoder(encoder_layer, num_encoder_layers)

output_layers = []
if middle_output_layer is not None:
assert (
middle_output_layer >= 0
and middle_output_layer < num_encoder_layers
)
output_layers.append(middle_output_layer)

# The last layer is always needed.
output_layers.append(num_encoder_layers - 1)

self.encoder = ConformerEncoder(
encoder_layer, num_encoder_layers, output_layers=output_layers
)

self._init_state: List[torch.Tensor] = [torch.empty(0)]

def forward(
self, x: torch.Tensor, x_lens: torch.Tensor, warmup: float = 1.0
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[List[torch.Tensor], torch.Tensor]:
"""
Args:
x:
Expand Down Expand Up @@ -176,24 +192,23 @@ def forward(
num_left_chunks=self.num_left_chunks,
device=x.device,
)
x = self.encoder(
layer_results = self.encoder(
x,
pos_emb,
mask=mask,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)
else:
x = self.encoder(
layer_results = self.encoder(
x,
pos_emb,
mask=None,
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
) # (T, N, C)

x = x.permute(1, 0, 2) # (T, N, C) ->(N, T, C)
return x, lengths
return layer_results, lengths

@torch.jit.export
def get_init_state(
Expand Down Expand Up @@ -647,12 +662,18 @@ class ConformerEncoder(nn.Module):
>>> out = conformer_encoder(src, pos_emb)
"""

def __init__(
    self,
    encoder_layer: nn.Module,
    num_layers: int,
    output_layers: List[int],
) -> None:
    """Stack ``num_layers`` deep copies of ``encoder_layer``.

    Args:
      encoder_layer:
        A template layer; each of the ``num_layers`` layers is an
        independent ``copy.deepcopy`` of it.
      num_layers:
        Number of stacked encoder layers.
      output_layers:
        0-based indexes of the layers whose outputs should be
        collected and returned by ``forward``.
    """
    super().__init__()
    clones = [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
    self.layers = nn.ModuleList(clones)
    self.num_layers = num_layers
    self.output_layers = output_layers

def forward(
self,
Expand All @@ -661,7 +682,7 @@ def forward(
mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
warmup: float = 1.0,
) -> Tensor:
) -> List[Tensor]:
r"""Pass the input through the encoder layers in turn.

Args:
Expand All @@ -682,6 +703,7 @@ def forward(
"""
output = src

layer_results = []
for layer_index, mod in enumerate(self.layers):
output = mod(
output,
Expand All @@ -690,8 +712,11 @@ def forward(
src_key_padding_mask=src_key_padding_mask,
warmup=warmup,
)
if layer_index in self.output_layers:
# (T, N, C) --> (N, T, C)
layer_results.append(output.permute(1, 0, 2))

return output
return layer_results

@torch.jit.export
def chunk_forward(
Expand Down
58 changes: 56 additions & 2 deletions egs/librispeech/ASR/pruned_transducer_stateless2/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

from icefall.utils import add_sos

from multi_quantization.prediction import JointCodebookLoss

class Transducer(nn.Module):
"""It implements https://arxiv.org/pdf/1211.3711.pdf
Expand All @@ -38,6 +39,7 @@ def __init__(
decoder_dim: int,
joiner_dim: int,
vocab_size: int,
num_codebooks: int = 0,
):
"""
Args:
Expand All @@ -55,6 +57,8 @@ def __init__(
(N, U, decoder_dim).
Its output shape is (N, T, U, vocab_size). Note that its output
contains unnormalized probs, i.e., not processed by log-softmax.
num_codebooks:
Used by distillation loss.
"""
super().__init__()
assert isinstance(encoder, EncoderInterface), type(encoder)
Expand All @@ -68,6 +72,10 @@ def __init__(
encoder_dim, vocab_size, initial_speed=0.5
)
self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)
if num_codebooks > 0:
self.codebook_loss_net = JointCodebookLoss(
predictor_channels=encoder_dim, num_codebooks=num_codebooks
)

def forward(
self,
Expand All @@ -78,6 +86,7 @@ def forward(
am_scale: float = 0.0,
lm_scale: float = 0.0,
warmup: float = 1.0,
codebook_indexes: torch.Tensor = None,
) -> torch.Tensor:
"""
Args:
Expand All @@ -101,6 +110,8 @@ def forward(
warmup:
A value warmup >= 0 that determines which modules are active, values
warmup > 1 "are fully warmed up" and all modules will be active.
codebook_indexes:
codebook_indexes extracted from a teacher model.
Returns:
Return the transducer loss.

Expand All @@ -116,7 +127,22 @@ def forward(

assert x.size(0) == x_lens.size(0) == y.dim0

encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
layer_results, x_lens = self.encoder(x, x_lens, warmup=warmup)
encoder_out = layer_results[-1]
middle_layer_output = layer_results[0]
if self.training and codebook_indexes is not None:
assert hasattr(self, "codebook_loss_net")
if codebook_indexes.shape[1] != middle_layer_output.shape[1]:
codebook_indexes = self.concat_successive_codebook_indexes(
middle_layer_output, codebook_indexes
)
codebook_loss = self.codebook_loss_net(
middle_layer_output, codebook_indexes
)
else:
# when codebook index is not available.
codebook_loss = None

assert torch.all(x_lens > 0)

# Now for the decoder, i.e., the prediction network
Expand Down Expand Up @@ -191,4 +217,32 @@ def forward(
reduction="sum",
)

return (simple_loss, pruned_loss)
return (simple_loss, pruned_loss, codebook_loss)

@staticmethod
def concat_successive_codebook_indexes(
    middle_layer_output, codebook_indexes
):
    """Align teacher codebook indexes to the student's frame rate.

    Two successive teacher frames are concatenated along the last
    dimension so that each student frame is paired with one (merged)
    teacher frame.

    Args:
      middle_layer_output:
        Student encoder output of shape (N, T_student, C'); only its
        time dimension (``shape[1]``) is read here.
      codebook_indexes:
        Teacher codebook indexes of shape (N, T_teacher, C).
    Returns:
      Tensor of shape (N, T_student, 2 * C).
    """
    # Output rate of hubert is 50 frames per second,
    # while that of the current encoder is 25.
    # The code below handles two issues:
    # 1.
    # Roughly speaking, to generate one more output frame, hubert
    # needs two extra input frames, while the current encoder needs
    # four.  If only three extra frames are provided, hubert emits
    # one more frame while the current encoder emits nothing, so any
    # surplus teacher frames must be dropped.
    # 2.
    # The codebook loss is a frame-wise loss; to let the 25-fps
    # student output learn from the 50-fps teacher output, two
    # successive frames of the teacher output are concatenated
    # together.
    t_expected = middle_layer_output.shape[1]
    N, T, C = codebook_indexes.shape

    # Handling issue 1: drop surplus teacher frames beyond
    # 2 * t_expected.
    # NOTE(review): if T < 2 * t_expected the reshape below will
    # fail; presumably upstream guarantees enough teacher frames —
    # confirm against the caller.
    if T >= t_expected * 2:
        codebook_indexes = codebook_indexes[:, : t_expected * 2, :]
    # Handling issue 2: merge each pair of successive teacher frames
    # into one frame of width 2 * C.
    codebook_indexes = codebook_indexes.reshape(N, t_expected, C * 2)
    assert middle_layer_output.shape[1] == codebook_indexes.shape[1]
    return codebook_indexes
55 changes: 50 additions & 5 deletions egs/librispeech/ASR/pruned_transducer_stateless4/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
--world-size 4 \
--num-epochs 30 \
--start-epoch 1 \
--exp-dir pruned_transducer_stateless2/exp \
--exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \
--max-duration 300

Expand All @@ -37,7 +37,7 @@
--num-epochs 30 \
--start-epoch 1 \
--use-fp16 1 \
--exp-dir pruned_transducer_stateless2/exp \
--exp-dir pruned_transducer_stateless4/exp \
--full-libri 1 \
--max-duration 550

Expand Down Expand Up @@ -74,9 +74,10 @@
from conformer import Conformer
from decoder import Decoder
from joiner import Joiner
from lhotse.cut import Cut
from lhotse.cut import Cut, MonoCut
from lhotse.dataset.sampling.base import CutSampler
from lhotse.utils import fix_random_seed
from lhotse.dataset.collation import collate_custom_field
from model import Transducer
from optim import Eden, Eve
from torch import Tensor
Expand Down Expand Up @@ -235,6 +236,13 @@ def get_parser():
"with this parameter before adding to the final loss.",
)

parser.add_argument(
"--codebook-loss-scale",
type=float,
default=0.1,
help="The scale of codebook loss.",
)

parser.add_argument(
"--seed",
type=int,
Expand Down Expand Up @@ -398,6 +406,13 @@ def get_params() -> AttributeDict:
# parameters for Noam
"model_warm_step": 3000, # arg given to model, not for lrate
"env_info": get_env_info(),
# parameters for distillation with codebook indexes.
"enable_distiallation": True,
"distillation_layer": 5, # 0-based index
# Since output rate of hubert is 50, while that of encoder is 8,
# two successive codebook_index are concatenated together.
# Detailed in function Transducer::concat_successive_codebook_indexes.
"num_codebooks": 16, # used to construct distillation loss
}
)

Expand All @@ -417,6 +432,9 @@ def get_encoder_model(params: AttributeDict) -> nn.Module:
short_chunk_size=params.short_chunk_size,
num_left_chunks=params.num_left_chunks,
causal=params.causal_convolution,
middle_output_layer=params.distillation_layer
if params.enable_distiallation
else None,
)
return encoder

Expand Down Expand Up @@ -454,6 +472,9 @@ def get_transducer_model(params: AttributeDict) -> nn.Module:
decoder_dim=params.decoder_dim,
joiner_dim=params.joiner_dim,
vocab_size=params.vocab_size,
num_codebooks=params.num_codebooks
if params.enable_distiallation
else 0,
)
return model

Expand Down Expand Up @@ -577,6 +598,18 @@ def save_checkpoint(
copyfile(src=filename, dst=best_valid_filename)


def extract_codebook_indexes(batch):
    """Collate per-cut teacher codebook indexes from a dataloader batch.

    Args:
      batch:
        A batch dict whose ``batch["supervisions"]["cut"]`` entry holds
        the cuts of this mini-batch.
    Returns:
      A tuple ``(codebook_indexes, codebook_indexes_lens)`` as produced
      by ``collate_custom_field``.
    """
    all_cuts = batch["supervisions"]["cut"]
    # For mixed cuts, read the codebook indexes from the first track's
    # underlying cut; plain MonoCuts carry them directly.
    source_cuts = []
    for cut in all_cuts:
        if isinstance(cut, MonoCut):
            source_cuts.append(cut)
        else:
            source_cuts.append(cut.tracks[0].cut)
    # pad_value=-100 is identical to the ignore_value used in the CE
    # loss computation.
    indexes, index_lens = collate_custom_field(
        source_cuts, "codebook_indexes", pad_value=-100
    )
    return indexes, index_lens


def compute_loss(
params: AttributeDict,
model: Union[nn.Module, DDP],
Expand Down Expand Up @@ -620,15 +653,23 @@ def compute_loss(
y = sp.encode(texts, out_type=int)
y = k2.RaggedTensor(y).to(device)

info = MetricsTracker()
if is_training and params.enable_distiallation:
codebook_indexes, _ = extract_codebook_indexes(batch)
codebook_indexes = codebook_indexes.to(device)
else:
codebook_indexes = None

with torch.set_grad_enabled(is_training):
simple_loss, pruned_loss = model(
simple_loss, pruned_loss, codebook_loss = model(
x=feature,
x_lens=feature_lens,
y=y,
prune_range=params.prune_range,
am_scale=params.am_scale,
lm_scale=params.lm_scale,
warmup=warmup,
codebook_indexes=codebook_indexes,
)
# after the main warmup step, we keep pruned_loss_scale small
# for the same amount of time (model_warm_step), to avoid
Expand All @@ -643,10 +684,12 @@ def compute_loss(
params.simple_loss_scale * simple_loss
+ pruned_loss_scale * pruned_loss
)
if is_training and params.enable_distiallation:
assert codebook_loss is not None
loss += params.codebook_loss_scale * codebook_loss

assert loss.requires_grad == is_training

info = MetricsTracker()
with warnings.catch_warnings():
warnings.simplefilter("ignore")
info["frames"] = (
Expand All @@ -657,6 +700,8 @@ def compute_loss(
info["loss"] = loss.detach().cpu().item()
info["simple_loss"] = simple_loss.detach().cpu().item()
info["pruned_loss"] = pruned_loss.detach().cpu().item()
if is_training and params.enable_distiallation:
info["codebook_loss"] = codebook_loss.detach().cpu().item()

return loss, info

Expand Down