
Support fine-tuning #944

Merged: 16 commits into k2-fsa:master on Mar 17, 2023
Conversation

@marcoyang1998 (Collaborator) commented Mar 14, 2023

This PR adds a fine-tune script for recipe pruned_transducer_stateless7.

It fine-tunes a model trained on LibriSpeech using GigaSpeech data. To fine-tune, you need to provide the path to the checkpoint from which training will resume, and you need to set --do-finetune True. Below is an example of fine-tuning on GigaSpeech subset S:

base_lr=0.005
lr_epochs=100
lr_batches=100000

./pruned_transducer_stateless7/finetune.py \
    --world-size 2 \
    --master-port 18181 \
    --num-epochs 20 \
    --start-epoch 1 \
    --exp-dir pruned_transducer_stateless7/exp_giga_finetune_S_baselr_${base_lr}_lrepochs_${lr_epochs}_lrbatches_${lr_batches} \
    --subset S \
    --use-fp16 1 \
    --base-lr $base_lr \
    --lr-epochs $lr_epochs \
    --lr-batches $lr_batches \
    --bpe-model data/lang_bpe_500/bpe.model \
    --do-finetune True \
    --finetune-ckpt /ceph-data4/yangxiaoyu/softwares/icefall_development/icefall_finetune/egs/librispeech/ASR/icefall-asr-librispeech-pruned-transducer-stateless7-2022-11-11/exp/pretrained-epoch-30-avg-9.pt \
    --max-duration 500

The WERs on the GigaSpeech dev and test sets after fine-tuning are shown below. For reference, the WERs of the same model trained from scratch on GigaSpeech subset S are also shown:

| model | dev | test |
| --- | --- | --- |
| fine-tune from Libri | 13.63 | 13.59 |
| train from scratch | 14.88 | 15.04 |

@marcoyang1998 (Collaborator, Author)

To fine-tune on another dataset, or your own dataset, you just need to replace gigaspeech.py with your own script for creating the train/test dataloaders.
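
For illustration, here is a minimal sketch of what such a script needs to provide, assuming lhotse as in the icefall recipes. The MyDataModule name and the cuts path are placeholders, and the real gigaspeech.py adds many more options (bucketing, augmentation, etc.):

from lhotse import CutSet
from lhotse.dataset import DynamicBucketingSampler, K2SpeechRecognitionDataset
from torch.utils.data import DataLoader

class MyDataModule:
    def train_cuts(self) -> CutSet:
        # Load your own pre-computed cuts (path is a placeholder).
        return CutSet.from_file("data/fbank/my_cuts_train.jsonl.gz")

    def train_dataloaders(self, cuts: CutSet) -> DataLoader:
        dataset = K2SpeechRecognitionDataset()
        # Batch by total audio duration, mirroring --max-duration above.
        sampler = DynamicBucketingSampler(cuts, max_duration=500.0, shuffle=True)
        # lhotse samplers yield whole batches of cuts, hence batch_size=None.
        return DataLoader(dataset, sampler=sampler, batch_size=None, num_workers=2)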

@marcoyang1998 (Collaborator, Author) commented Mar 16, 2023

Here are a few experiments with different learning rate schedules/initializations:

| Fine-tune | Init modules | base_lr | lr_epochs | lr_batches | Epoch-avg | Dev | Test |
| --- | --- | --- | --- | --- | --- | --- | --- |
| No | None | 0.05 | 3.5 | 5000 | 30-13 | 14.88 | 15.04 |
| Yes | all | 0.05 | 3.5 | 5000 | 15-10 | 15.52 | 15.61 |
| Yes | all | 0.01 | 3.5 | 5000 | 15-10 | 13.85 | 13.96 |
| Yes | all | 0.005 | 3.5 | 5000 | 15-10 | 13.63 | 13.81 |
| Yes | all | 0.001 | 3.5 | 5000 | 15-10 | 13.82 | 13.89 |
| Yes | all | 0.005 | 100 | 1e5 | 11-10 | 13.59 | 13.49 |
| Yes | all | 0.001 | 100 | 1e5 | 17-10 | 13.63 | 13.69 |
| Yes | encoder | 0.005 | 100 | 1e5 | 13-10 | 13.62 | 13.62 |
| Yes | encoder+decoder | 0.005 | 100 | 1e5 | 13-10 | 13.65 | 13.61 |
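
For context, the Eden scheduler in icefall's optim.py decays the learning rate with both the batch and the epoch count, which is why the large lr_epochs=100/lr_batches=1e5 settings keep the LR close to base_lr for a 20-epoch fine-tune. A rough sketch (my reading of optim.py, ignoring warmup and per-parameter scaling, so treat the numbers as approximations):

def eden_lr(base_lr, batch, epoch, lr_batches, lr_epochs):
    # Each factor stays near 1 while batch << lr_batches and
    # epoch << lr_epochs, then falls off as a -0.25 power.
    return base_lr * (
        ((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25
        * ((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25
    )

# The default schedule decays quickly:
print(eden_lr(0.005, batch=20_000, epoch=20, lr_batches=5e3, lr_epochs=3.5))  # ~1.0e-3
# The fine-tuning setting above stays near base_lr:
print(eden_lr(0.005, batch=20_000, epoch=20, lr_batches=1e5, lr_epochs=100))  # ~4.9e-3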

marcoyang1998 merged commit 7948624 into k2-fsa:master on Mar 17, 2023
@pingfengluo (Contributor)

@marcoyang1998 Can you upload the tensorboard logs of the fine-tuning experiment?

@marcoyang1998 (Collaborator, Author)

See https://tensorboard.dev/experiment/5xPOGv6CQHO02YEApc1jxg/. @pingfengluo

It's from the experiment with base_lr=0.005, lr_epochs=100, lr_batches=1e5.

@pingfengluo (Contributor)

> See https://tensorboard.dev/experiment/5xPOGv6CQHO02YEApc1jxg/. @pingfengluo
>
> It's from the experiment with base_lr=0.005, lr_epochs=100, lr_batches=1e5.

OK, thank you.

@HsunGong commented Sep 1, 2023

@marcoyang1998 Hello, I found that using the same config, I cannot reproduce the result. My running command is:

./finetune.py \
    --num-epochs 20 \
    --start-epoch 1 \
    --exp-dir exp/ft_giga_lr5e-3 \
    --use-fp16 1 \
    --base-lr 0.005 \
    --lr-epochs 100 \
    --lr-batches 100000 \
    --do-finetune True \
    --use-mux false \
    --finetune-ckpt exp/pretrained-pt7/pretrained.pt \
    --max-duration 500

But my result only reaches 14.66/15.04.


My tensorboard logs are as follows:

  • min valid_pruned_loss: 0.055
  • min valid_simple_loss: 0.297

Compared with https://tensorboard.dev/experiment/5xPOGv6CQHO02YEApc1jxg/#scalars, I found that my valid loss is smaller, but the performance is worse?


Could you please provide a checkpoint of epoch 11, avg 10?

@HsunGong commented Sep 1, 2023

@marcoyang1998 By the way, can you provide the GigaSpeech result decoded with the LibriSpeech pretrained model (without fine-tuning), so I can validate whether my pretrained model is correct? Thanks.

The results above in #944 and #1059 do not include a result without fine-tuning.

@marcoyang1998 (Collaborator, Author)

@HsunGong How many GPUs are you using? And are you using decode_gigaspeech.py for decoding?

The model without fine-tuning gets 20.19/19.58 with greedy_search.
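
For reference, the decoding command would look roughly like this (flag names assumed to mirror the recipe's decode.py; --exp-dir is a placeholder):

./pruned_transducer_stateless7/decode_gigaspeech.py \
    --epoch 11 \
    --avg 10 \
    --exp-dir pruned_transducer_stateless7/exp_giga_finetune \
    --bpe-model data/lang_bpe_500/bpe.model \
    --decoding-method greedy_search \
    --max-duration 500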

@HsunGong commented Sep 1, 2023

I got 20.96/20.26 with modified_beam_search (beam=8) without fine-tuning.

Two 3090s are used with max-duration=500.


There are some differences between https://tensorboard.dev/experiment/5xPOGv6CQHO02YEApc1jxg/#scalars and my experiments:

1. LR scale: yours decays from 5e-3 to 1.5e-3 during training (while mine only reaches 4.5e-3).
2. Training steps: your max step count is ~50K (while mine is ~20K).
3. Valid loss: your min valid loss is 0.3 (while mine is 0.19).


The dataset is built differently: I reuse kaldi-fsa/kaldi:egs/gigaspeech as the data dir and use on-the-fly features (rather than rebuilding it with k2/icefall, since I found the features remain the same).

@marcoyang1998 (Collaborator, Author)

> LR scale: yours decays from 5e-3 to 1.5e-3 during training (while mine only reaches 4.5e-3).

This seems to be caused by this change, which skips the warmup period by setting the batch count to a very large number:

https://github.com/yfyeung/icefall/blob/9df058f05c356eb411e1d6db53d29c1690d59ac5/egs/librispeech/ASR/pruned_transducer_stateless7/finetune.py#L923-L927

If you want to reproduce my results, you might need to remove those few lines so that warmup is not skipped. But theoretically, warmup is not needed during fine-tuning as long as a small enough learning rate is used.
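
For context, Eden's warmup is a simple function of the batch count (my reading of icefall's optim.py; treat the constant as an assumption), which is why initializing the batch count to a huge value disables the ramp:

def eden_warmup_factor(batch: int, warmup_batches: float = 500.0) -> float:
    # Linear ramp from 0.5 to 1.0 over the first warmup_batches updates;
    # a very large initial batch count makes this 1.0 from the first step,
    # i.e. warmup is skipped entirely.
    if batch >= warmup_batches:
        return 1.0
    return 0.5 + 0.5 * (batch / warmup_batches)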

> Training steps: your max step count is ~50K (while mine is ~20K).

At 20k updates, my experiment had only reached the 8th epoch. Since we are both using 2 GPUs with the same max-duration, this looks a bit weird to me.

But overall I think the WER discrepancy is reasonable, as your initialization is worse than mine.

@HsunGong commented Sep 1, 2023

Thanks. I think the main problem is the dataset scale: since max-duration is the same, there is no reason for the total steps per epoch to differ (mine is ~900, yours is ~2800). I'll recheck my dataset setup to see if it is using the full train-S (250h).


This may be due to the lr-batches issue. Following your tensorboard setting (epoch=20, per-epoch=2837), I find that lr_batches=3000 matches the tensorboard's LR curve more closely. I'll redo my experiments.


@marcoyang1998 (Collaborator, Author)

We are using 3x speed perturbation, which explains the total step difference.
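
For anyone comparing step counts: 3x speed perturbation triples the number of cuts per epoch. In lhotse terms it is roughly the following standard idiom (not the verbatim recipe code; the path is a placeholder):

from lhotse import CutSet

cuts = CutSet.from_file("data/fbank/gigaspeech_cuts_S.jsonl.gz")
# Original cuts plus 0.9x and 1.1x speed-perturbed copies.
cuts_sp = cuts + cuts.perturb_speed(0.9) + cuts.perturb_speed(1.1)
# len(cuts_sp) == 3 * len(cuts), hence ~2800 vs ~900 steps per epoch.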

@OswaldoBornemann

@marcoyang1998 I want to train on a new, larger dataset starting from a previous checkpoint (I think this is quite similar to fine-tuning). The previous model used to take 4-5 days to train for one epoch. However, I've found that it now takes only about 2 days per epoch, even though the new dataset is larger.

Therefore, I suspect there might be a problem in the dataloader's iteration over the larger dataset, causing the model to not actually traverse all the data during training.

@OswaldoBornemann

I haven't used finetune.py. I just used the previous train.py, but I commented out the following in load_checkpoint_if_available:

    # keys = [
    #     "best_train_epoch",
    #     "best_valid_epoch",
    #     "batch_idx_train",
    #     "best_train_loss",
    #     "best_valid_loss",
    # ]
    # for k in keys:
    #     params[k] = saved_params[k]

    # if params.start_batch > 0:
    #     if "cur_epoch" in saved_params:
    #         params["start_epoch"] = saved_params["cur_epoch"]
    #
    #     if "cur_batch_idx" in saved_params:
    #         params["cur_batch_idx"] = saved_params["cur_batch_idx"]

@OswaldoBornemann

I think load_checkpoint_if_available is used to continue training from a previous checkpoint on the same dataset, while load_model_params in finetune.py is used to continue training from a previous checkpoint on a different dataset. Am I correct? @marcoyang1998

@marcoyang1998 (Collaborator, Author)

Yes, in most cases.
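
For readers following along, a rough sketch of the difference: load_model_params copies only the model weights, optionally restricted to selected modules (cf. the "Init modules" column above), and does not restore the optimizer, scheduler, or batch counters. The sketch below treats names as assumptions, not the verbatim icefall code:

import torch

def load_model_params(ckpt: str, model: torch.nn.Module, init_modules=None):
    # Load only the "model" weights; no optimizer/scheduler/batch state.
    state = torch.load(ckpt, map_location="cpu")["model"]
    if init_modules is not None:
        # Keep only weights belonging to the chosen modules, e.g.
        # init_modules=["encoder"] initializes just the encoder.
        state = {
            k: v for k, v in state.items()
            if any(k.startswith(m) for m in init_modules)
        }
    model.load_state_dict(state, strict=False)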
