Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

[transformer] add multi warmup and learning rate for different modules #2449

Merged
merged 19 commits into from
Apr 11, 2024

Conversation

Mddct
Copy link
Collaborator

@Mddct Mddct commented Mar 28, 2024

当微调的时候,模型的不同部分需要的学习率和warmup不一样,比如以下 微调whisper的手 因为没有ctc的weight

我门可以设置ctc的warmu up为12000, lr 0.01, 其他部分为1200,0.00001

optim: adam
optim_conf:
  lr: [0.01, 0.00001]
  modules: ['ctc']
scheduler: warmuplr
scheduler_conf:
  warmup_steps: [1500,5000]

TODO

  • whisper aishell result
  • tensorboard multiple lr
  • refactor code and rebase main

@Mddct Mddct changed the title [transformer] add multi warmup and learning rate for different modules [WIP][transformer] add multi warmup and learning rate for different modules Mar 28, 2024
@Mddct Mddct mentioned this pull request Mar 28, 2024
24 tasks
@Mddct
Copy link
Collaborator Author

Mddct commented Mar 28, 2024

截屏2024-03-28 20 09 28

@Mddct
Copy link
Collaborator Author

Mddct commented Mar 28, 2024

tensorboard works!
截屏2024-03-28 22 56 02

@Mddct
Copy link
Collaborator Author

Mddct commented Mar 29, 2024

b39a8b1 这个commit 重构了cv 和train 以及 lr 的打印逻辑

step 模式单个lr+warmup , cv和train

image

epoch 模式 单个lr+warmup , cv和train

image

epoch 模式 多个个lr+warmup , cv和train

image

step 模式多个lr+warmup , cv和train

image

@Mddct Mddct closed this Mar 29, 2024
@Mddct Mddct reopened this Mar 29, 2024
@Mddct
Copy link
Collaborator Author

Mddct commented Mar 30, 2024

fsdp (TBD)

8卡A100, step 模式 ,3000 save interval, fp32, zero2

  • Feature info: using log_mel_spectrogram feature, no cmvn, no speed perturb
  • Training info: conf/finetune_whisper_largev3_conv2d4.yaml
  • Decoding info: ctc_weight 0.3, average_num 3
  • Git hash: TBD
decoding mode CER
attention decoder 3.09 % N=104765 C=101638 S=3024 D=103 I=108
ctc greedy search 6.24 % N=104765 C=99161 S=5422 D=182 I=936
ctc prefix beam search 6.18 % N=104765 C=99158 S=5429 D=178 I=863
attention rescoring 4.64 % N=104765 C=100207 S=4452 D=106 I=301

8卡A100, step 模式 ,3000 save interval, bf16, zero2

decoding mode CER
attention decoder 3.22 % N=104765 C=101503 S=3165 D=97 I=108
ctc greedy search 6.57 % N=104765 C=98913 S=5667 D=185 I=1033
ctc prefix beam search 6.46 % N=104765 C=98928 S=5655 D=182 I=936
attention rescoring 4.80 % N=104765 C=99997 S=4665 D=103 I=260

@fclearner
Copy link
Contributor

fsdp (TBD)

8卡A100, step 模式 ,3000 save interval, fp32, zero2

  • Feature info: using log_mel_spectrogram feature, no cmvn, no speed perturb
  • Training info: conf/finetune_whisper_largev3_conv2d4.yaml
  • Decoding info: ctc_weight 0.3, average_num 3
  • Git hash: TBD

decoding mode CER
attention decoder 3.09 % N=104765 C=101638 S=3024 D=103 I=108
ctc greedy search 6.24 % N=104765 C=99161 S=5422 D=182 I=936
ctc prefix beam search 6.18 % N=104765 C=99158 S=5429 D=178 I=863
attention rescoring 4.64 % N=104765 C=100207 S=4452 D=106 I=301

强,ctc的收益是因为使用了不同的学习率吗

@Mddct
Copy link
Collaborator Author

Mddct commented Apr 1, 2024

fsdp (TBD)

8卡A100, step 模式 ,3000 save interval, fp32, zero2

  • Feature info: using log_mel_spectrogram feature, no cmvn, no speed perturb
  • Training info: conf/finetune_whisper_largev3_conv2d4.yaml
  • Decoding info: ctc_weight 0.3, average_num 3
  • Git hash: TBD

decoding mode CER
attention decoder 3.09 % N=104765 C=101638 S=3024 D=103 I=108
ctc greedy search 6.24 % N=104765 C=99161 S=5422 D=182 I=936
ctc prefix beam search 6.18 % N=104765 C=99158 S=5429 D=178 I=863
attention rescoring 4.64 % N=104765 C=100207 S=4452 D=106 I=301

强,ctc的收益是因为使用了不同的学习率吗

直觉上,整个模型是pretrain的 除了ctc和conv2d, 所以只让这两个较大的学习率 ‘快速的学习’ (上边这个只有ctc的lr和enc+dec的lr),其他的stable training
可以参考 bestrq paper: https://arxiv.org/pdf/2202.01855.pdf
101551711637944_ pic

未来也会支持w2vbert的fintune 所以这里就引入了

@Mddct
Copy link
Collaborator Author

Mddct commented Apr 7, 2024

#2412 修复了A100 和v100精度问题, 这里实验需要重新跑下 (预期还会更好些)

@Mddct
Copy link
Collaborator Author

Mddct commented Apr 8, 2024

#2412 使用修复后的的结果:

8卡A100, step 模式 ,3000 save interval, bf16, zero2

decoding mode CER
attention decoder 3.13 % N=104765 C=101598 S=3068 D=99 I=111 |
ctc greedy search 6.10 % N=104765 C=99224 S=5366 D=175 I=845
ctc prefix beam search 6.03 % N=104765 C=99223 S=5367 D=175 I=776
attention rescoring 4.52 % N=104765 C=100241 S=4420 D=104 I=215

@Mddct
Copy link
Collaborator Author

Mddct commented Apr 9, 2024

optim: adam
optim_conf:
  lr: [0.001, 0.00005, 0.00001]
  modules: ['ctc', 'encoder.embed']
scheduler: warmuplr
scheduler_conf:
  warmup_steps: [1500, 10000, 5000]

8卡A100, step 模式 ,3000 save interval, bf16, zero2

decoding mode CER
attention decoder 3.09 % N=104765 C=101626 S=3039 D=100 I=103|
ctc greedy search 6.04 % N=104765 C=99335 S=5266 D=164 I=899
ctc prefix beam search 5.96 % N=104765 C=99345 S=5258 D=162 I=821
attention rescoring 4.47 % N=104765 C=100289 S=4370 D=106 I=212

@Mddct Mddct changed the title [WIP][transformer] add multi warmup and learning rate for different modules [transformer] add multi warmup and learning rate for different modules Apr 10, 2024
@Mddct
Copy link
Collaborator Author

Mddct commented Apr 10, 2024

image 代表我们已经跑了0个epoch 600个step,保存模型599.pt (从0.pt开始) 保存的yaml里边对应的step 也为599 截屏2024-04-10 21 27 37

@Mddct Mddct requested a review from xingchensong April 10, 2024 13:27
Comment on lines +468 to +469
if isinstance(lr, List):
optim_conf['lr'] = lr[-1]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这一步是为啥,optimizer必须传一个lr参数且不能是list?

@xingchensong xingchensong merged commit 26ba7d1 into main Apr 11, 2024
6 checks passed
@xingchensong xingchensong deleted the Mddct-multiple-warmupLR branch April 11, 2024 11:39
Zth9730 pushed a commit to Zth9730/wenet that referenced this pull request Aug 7, 2024
add casual model

fix typo

rm ckpt

add topk topp sampler

fix positoin

[train_engine] support fsdp (wenet-e2e#2412)

* [train_engine] support fsdp

* [train_engine] support fsdp

* unify scaler and amp

* fp32&&fp16 works in fsdp env

* fix fsdp in cv auto cast

* try to fix wenet.join fsdp

* implementing zero1 under fsdp is almost equivalent to deepspeed's zero1

* fix clip_and_grad_

* fix train summary

* all wenet xxxformer works (-paraformer -transducer)

* try to fix nan

* add barrier for cv

* add destroy group for end of all train

* refactor wrap methods and ckpt works

* fix ckpt

* fix cv in dtype != float32

* fix ckpt in model mode

* fix bf16 amp

* refactor scaler and autocast, fix fp32 fp16 bf16 for fsdp

* fix fp32 nullcontext to nullcontext()

* modify after review

* fix lint

* fix lint

LoRA support (wenet-e2e#2049)

* support lora for v3.0.1

* format code and update lora attention && encoder

* fix bug when lora_list is None

---------

Co-authored-by: Xingchen Song(宋星辰) <xingchensong1996@163.com>

[env] update python version and deepspeed version (wenet-e2e#2462)

* [env] update python version and deepspeed version

* [env] fix lint

fix rope pos embdining (wenet-e2e#2463)

* fix rope pos embdining

* fix dropout

* fix comment

[transformer] add multi warmup and learning rate for different modules (wenet-e2e#2449)

* [transformer] add multi warmup and learning rate for different modules

* fix typo

* it works in warmuplr

* fix lr in tensorboard in step mode

* fix cv log

* cv works

* refactor cv log

* add helper lrs_to_string

* fix lrstr

* fix ddp multiple lr

* fix initial step

* revert to -1

* fix sub params dup

* fix step

* fix step

* fix log

* add assert for scheduler

* add comment for log

---------

Co-authored-by: Xingchen Song(宋星辰) <xingchensong1996@163.com>

add generate

add toto

support sft & pretrain training forward

gemm conversion works

support init casual model

[whisper] limit language to Chinese (wenet-e2e#2470)

[train] convert tensor to scalar (wenet-e2e#2471)

[workflow] upgrad python version to 3.10 (wenet-e2e#2472)

* [workflow] upgrad python version to 3.10

* [workflow] try to pass

refactor cache behaviour in training mode (reduce compute cost and memory) (wenet-e2e#2473)

all gemma model works

fix ut

fix ut (wenet-e2e#2477)

* fix ut

* fix py version

[transformer] Make MoE runnable (wenet-e2e#2474)

[transformer] fix mqa (wenet-e2e#2478)

enable mmap in torch.load (wenet-e2e#2479)

[example] Add deespeed configs of different stages for illustration (wenet-e2e#2485)

[example] Fix prefetch and step_save (wenet-e2e#2486)

[ctl] simplified ctl (wenet-e2e#2483)

* [ctl] simplified ctl

* [ctl] unify

[branchformer] simplified branchformer (wenet-e2e#2482)

* [transformer] simplified branchformer

* fix yaml

* support mqa  gradiengt ckpt sdpa

* fix gradient checkponit

* add deepspeed comment in layer dropout

* fix comment

[e_branchformer] simplified e_branchformer (wenet-e2e#2484)

* [e_branchformer] simplified ctl

* try to fix ut

* try to fix ut

* fix activation

* fix att args

* e-branformer works

[transformer] refactor cache (wenet-e2e#2481)

* [transformer] refactor cache

* fix ut

* unify cache type in branchformer and ebranchformer

fix cache

fix gradient ckpt in branchformer/ebranformer (wenet-e2e#2488)

fix search after refactor cache (wenet-e2e#2490)

generate works!

unify chat pattern

convert llama3 works

[transformer] set use_reentrant=False for gradient ckpt (wenet-e2e#2491)

[transformer] fix warning: ignore(True) has been deprecated (wenet-e2e#2492)

* [transformer] fix warning: ignore(True) has been deprecated

* [transformer] fix warning: ignore(True) has been deprecated

[log] avoid reduntant logging (wenet-e2e#2493)

fix w1 w2 w3 in feedforward

add 70b temporarily

mv LLM to wenet

support llm dataset

unify config

add dataset yaml in script

support llm dataset

dynamic static bucket works

[transformer] refacgtor mqa repeat (wenet-e2e#2497)

[transformer] fix mqa in cross att (wenet-e2e#2498)

[deepspeed] update json config (wenet-e2e#2499)

training works

pretrain works

refactor covert

fix flash att in generate

llama works

fix llama3

fix speed

try fix ut

support stop tokens in gen and support ppl

support stop tokens in gen and support ppl
# for free to join this conversation on GitHub. Already have an account? # to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants