zipformer BF16 training recipe #1700

Merged 12 commits on Aug 23, 2024
17 changes: 17 additions & 0 deletions egs/librispeech/ASR/RESULTS.md
@@ -307,6 +307,23 @@ done

To decode with external language models, please refer to the documentation [here](https://k2-fsa.github.io/icefall/decoding-with-langugage-models/index.html).

We also support training Zipformer with AMP + bf16 (this requires a GPU with bf16 support). See [here](https://github.com/k2-fsa/icefall/pull/1700) for more details and pre-trained models. **The same commands as for the regular recipe can be used for decoding and exporting the model.**

The AMP + bf16 training command is:
```bash
export CUDA_VISIBLE_DEVICES="0,1,2,3"
./zipformer/train.py \
--world-size 4 \
--num-epochs 50 \
--start-epoch 1 \
--use-fp16 0 \
--use-bf16 1 \
--exp-dir zipformer/exp_amp_bf16 \
--causal 0 \
--full-libri 1 \
--max-duration 1000
```
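
Since AMP + bf16 requires hardware bf16 support (train.py asserts this via `torch.cuda.is_bf16_supported()`), a quick sanity check before launching the run can save a failed job. A minimal sketch, not part of the recipe:

```python
# Sanity check for bf16 support before launching AMP + bf16 training (sketch, not part of the recipe).
import torch

if not torch.cuda.is_available():
    print("No CUDA device visible; AMP + bf16 training needs a GPU.")
elif not torch.cuda.is_bf16_supported():
    print("This GPU does not support bf16; fall back to --use-fp16 1 or plain fp32.")
else:
    print("bf16 is supported; --use-bf16 1 should work.")
```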

##### small-scaled model, number of model parameters: 23285615, i.e., 23.3 M

The tensorboard log can be found at
12 changes: 6 additions & 6 deletions egs/librispeech/ASR/zipformer/scaling.py
@@ -297,7 +297,7 @@ def forward(ctx, x: Tensor, dim: int):
# (presumably) that op does not support float16, and autocast
# is enabled.
if torch.is_autocast_enabled():
ans = ans.to(torch.float16)
ans = ans.to(torch.get_autocast_gpu_dtype())
ctx.save_for_backward(ans)
ctx.x_dtype = x.dtype
ctx.dim = dim
@@ -1234,7 +1234,7 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)

s = torch.sigmoid(x - 1.0)
@@ -1346,7 +1346,7 @@ class SwooshLFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad
if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)

zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1379,7 +1379,7 @@ def forward(ctx, x: Tensor) -> Tensor:
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
y = y.to(torch.get_autocast_gpu_dtype())
return y

@staticmethod
@@ -1425,7 +1425,7 @@ class SwooshRFunction(torch.autograd.Function):
def forward(ctx, x: Tensor) -> Tensor:
requires_grad = x.requires_grad

if x.dtype == torch.float16:
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = x.to(torch.float32)

zero = torch.tensor(0.0, dtype=x.dtype, device=x.device)
@@ -1455,7 +1455,7 @@ def forward(ctx, x: Tensor) -> Tensor:
d_int = d_scaled.to(torch.uint8)
ctx.save_for_backward(d_int)
if x.dtype == torch.float16 or torch.is_autocast_enabled():
y = y.to(torch.float16)
y = y.to(torch.get_autocast_gpu_dtype())
return y

@staticmethod
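
The scaling.py changes above all follow one pattern: the custom autograd functions used to upcast only `float16` inputs to `float32` and hard-coded `float16` when casting results back under autocast, which breaks once autocast runs in bf16. The sketch below illustrates the same pattern on a hypothetical activation (`SquareFunction` is illustrative only, not from scaling.py):

```python
# Sketch of the dtype handling that the scaling.py changes implement.
# SquareFunction is a hypothetical activation used only to illustrate the casting logic.
import torch
from torch import Tensor


class SquareFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: Tensor) -> Tensor:
        ctx.x_dtype = x.dtype
        # Upcast half-precision inputs (fp16 *or* bf16) so the activation math runs in fp32.
        if x.dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)
        y = x * x
        ctx.save_for_backward(x)
        # Cast the result back to the active autocast dtype instead of hard-coding
        # float16, so the same code works under both fp16 and bf16 autocast.
        if torch.is_autocast_enabled():
            y = y.to(torch.get_autocast_gpu_dtype())
        return y

    @staticmethod
    def backward(ctx, y_grad: Tensor) -> Tensor:
        (x,) = ctx.saved_tensors
        # Compute the gradient in fp32, then return it in the original input dtype.
        return (2.0 * x * y_grad.to(torch.float32)).to(ctx.x_dtype)
```

Under `torch.cuda.amp.autocast(dtype=torch.bfloat16)`, `SquareFunction.apply(x)` then returns a bf16 tensor instead of silently downcasting to fp16.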
47 changes: 36 additions & 11 deletions egs/librispeech/ASR/zipformer/train.py
@@ -521,6 +521,13 @@ def get_parser():
help="Whether to use half precision training.",
)

parser.add_argument(
"--use-bf16",
type=str2bool,
default=False,
help="Whether to use bf16 in AMP.",
)

add_model_arguments(parser)

return parser
@@ -1027,7 +1034,9 @@ def save_bad_model(suffix: str = ""):
batch_size = len(batch["supervisions"]["text"])

try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
loss, loss_info = compute_loss(
params=params,
model=model,
@@ -1047,9 +1056,7 @@ def save_bad_model(suffix: str = ""):
scaler.update()
optimizer.zero_grad()
except Exception as e:
logging.info(
f"Caught exception: {e}."
)
logging.info(f"Caught exception: {e}.")
save_bad_model()
display_and_save_batch(batch, params=params, sp=sp)
raise
@@ -1090,7 +1097,7 @@ def save_bad_model(suffix: str = ""):
rank=rank,
)

if batch_idx % 100 == 0 and params.use_fp16:
if batch_idx % 100 == 0 and params.use_autocast:
# If the grad scale was less than 1, try increasing it. The _growth_interval
# of the grad scaler is configurable, but we can't configure it to have different
# behavior depending on the current grad scale.
@@ -1109,14 +1116,14 @@ def save_bad_model(suffix: str = ""):

if batch_idx % params.log_interval == 0:
cur_lr = max(scheduler.get_last_lr())
cur_grad_scale = scaler._scale.item() if params.use_fp16 else 1.0
cur_grad_scale = scaler._scale.item() if params.use_autocast else 1.0

logging.info(
f"Epoch {params.cur_epoch}, "
f"batch {batch_idx}, loss[{loss_info}], "
f"tot_loss[{tot_loss}], batch size: {batch_size}, "
f"lr: {cur_lr:.2e}, "
+ (f"grad_scale: {scaler._scale.item()}" if params.use_fp16 else "")
+ (f"grad_scale: {scaler._scale.item()}" if params.use_autocast else "")
)

if tb_writer is not None:
@@ -1128,7 +1135,7 @@ def save_bad_model(suffix: str = ""):
tb_writer, "train/current_", params.batch_idx_train
)
tot_loss.write_summary(tb_writer, "train/tot_", params.batch_idx_train)
if params.use_fp16:
if params.use_autocast:
tb_writer.add_scalar(
"train/grad_scale", cur_grad_scale, params.batch_idx_train
)
@@ -1204,9 +1211,25 @@ def run(rank, world_size, args):
params.ctc_loss_scale = 1.0
else:
assert params.ctc_loss_scale + params.attention_decoder_loss_scale == 1.0, (
params.ctc_loss_scale, params.attention_decoder_loss_scale
params.ctc_loss_scale,
params.attention_decoder_loss_scale,
)

if params.use_bf16: # amp + bf16
assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
assert not params.use_fp16, "You can only use either fp16 or bf16"
params.dtype = torch.bfloat16
params.use_autocast = True
elif params.use_fp16: # amp + fp16
params.dtype = torch.float16
params.use_autocast = True
else: # fp32
params.dtype = torch.float32
params.use_autocast = False

logging.info(f"Using dtype={params.dtype}")
logging.info(f"Use AMP={params.use_autocast}")

logging.info(params)

logging.info("About to create model")
@@ -1339,7 +1362,7 @@ def remove_short_and_long_utt(c: Cut):
params=params,
)

scaler = GradScaler(enabled=params.use_fp16, init_scale=1.0)
scaler = GradScaler(enabled=params.use_autocast, init_scale=1.0)
if checkpoints and "grad_scaler" in checkpoints:
logging.info("Loading grad scaler state dict")
scaler.load_state_dict(checkpoints["grad_scaler"])
@@ -1439,7 +1462,9 @@ def scan_pessimistic_batches_for_oom(
for criterion, cuts in batches.items():
batch = train_dl.dataset[cuts]
try:
with torch.cuda.amp.autocast(enabled=params.use_fp16):
with torch.cuda.amp.autocast(
enabled=params.use_autocast, dtype=params.dtype
):
loss, _ = compute_loss(
params=params,
model=model,
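
Taken together, the train.py changes select a dtype from the `--use-fp16`/`--use-bf16` flags, drive autocast with that dtype, and tie the `GradScaler` to `use_autocast` rather than to fp16 alone. A simplified, self-contained sketch of how these pieces fit; the model, data, and loss below are placeholders, not from the recipe:

```python
# Simplified sketch of the AMP setup added to train.py; model, data, and loss are placeholders.
import torch
from torch.cuda.amp import GradScaler, autocast

use_fp16, use_bf16 = False, True  # mirrors --use-fp16 / --use-bf16

if use_bf16:  # amp + bf16
    assert torch.cuda.is_bf16_supported(), "Your GPU does not support bf16!"
    assert not use_fp16, "You can only use either fp16 or bf16"
    dtype, use_autocast = torch.bfloat16, True
elif use_fp16:  # amp + fp16
    dtype, use_autocast = torch.float16, True
else:  # fp32
    dtype, use_autocast = torch.float32, False

model = torch.nn.Linear(80, 512).cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# As in the recipe, the scaler is enabled whenever autocast is enabled, with init_scale=1.0.
scaler = GradScaler(enabled=use_autocast, init_scale=1.0)

for step in range(10):
    x = torch.randn(16, 80, device="cuda")
    with autocast(enabled=use_autocast, dtype=dtype):
        loss = model(x).pow(2).mean()  # placeholder loss
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```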