
[transformer] Make MoE runnable #2474

Merged · 2 commits merged into main from xcsong-moe on Apr 14, 2024

Conversation

xingchensong (Member)

How to use:

[screenshot in the original comment]

xingchensong requested review from Mddct, robin1001 and whiteshirt0429 and removed the request for Mddct and robin1001 on April 14, 2024 11:00
xingchensong (Member, Author)

Testing locally, test_grad_ckpt.py hangs, so I removed it for now.

Mddct (Collaborator) commented Apr 14, 2024

> Testing locally, test_grad_ckpt.py hangs, so I removed it for now.

The hang is caused by this: https://github.com/wenet-e2e/wenet/pull/2473/files 😭
Reverting it fixes the problem.

xingchensong (Member, Author)

> Testing locally, test_grad_ckpt.py hangs, so I removed it for now.

> The hang is caused by this: https://github.com/wenet-e2e/wenet/pull/2473/files 😭 Reverting it fixes the problem.

OK, please take care of it.

Mddct (Collaborator) commented Apr 14, 2024

> Testing locally, test_grad_ckpt.py hangs, so I removed it for now.

> The hang is caused by this: https://github.com/wenet-e2e/wenet/pull/2473/files 😭 Reverting it fixes the problem.

> OK, please take care of it.

#2477

wenet/transformer/decoder.py (review thread resolved)
Mddct merged commit c906392 into main on Apr 14, 2024
6 checks passed
Mddct deleted the xcsong-moe branch on April 14, 2024 15:44
Mddct (Collaborator) commented Apr 19, 2024

decoder_conf:
  attention_heads: 4
  dropout_rate: 0.1
  linear_units: 512
  mlp_type: moe
  n_expert: 4
  n_expert_activated: 2
  num_blocks: 6
  positional_dropout_rate: 0.1
  self_attention_dropout_rate: 0.0
  src_attention_dropout_rate: 0.0
dtype: fp32
encoder: conformer
encoder_conf:
  activation_type: swish
  attention_dropout_rate: 0.0
  attention_heads: 4
  cnn_module_kernel: 15
  dropout_rate: 0.1
  input_layer: conv2d
  linear_units: 512
  mlp_type: moe
  n_expert: 4
  n_expert_activated: 2
  normalize_before: true
  num_blocks: 12
  output_size: 256
  pos_enc_layer_type: rel_pos
  positional_dropout_rate: 0.1
  selfattention_layer_type: rel_selfattn
  use_cnn_module: true

[screenshot in the original comment]
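For readers wondering what these keys do, below is a minimal sketch of a top-k routed MoE feed-forward block, which is the kind of layer mlp_type: moe selects; n_expert is the number of expert FFNs and n_expert_activated is how many of them each token is routed to. The class and variable names here are illustrative, not wenet's exact implementation.

import torch
import torch.nn as nn

# Minimal sketch of a top-k routed MoE feed-forward block (illustrative only).
# With the config above, each token is routed to its top n_expert_activated=2
# experts out of n_expert=4.
class MoEFeedForwardSketch(nn.Module):
    def __init__(self, idim: int, hidden_units: int,
                 n_expert: int = 4, n_expert_activated: int = 2):
        super().__init__()
        self.gate = nn.Linear(idim, n_expert, bias=False)  # token-to-expert router
        self.experts = nn.ModuleList(
            nn.Sequential(nn.Linear(idim, hidden_units), nn.SiLU(),
                          nn.Linear(hidden_units, idim))
            for _ in range(n_expert))
        self.n_expert_activated = n_expert_activated

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, T, D = x.size()
        tokens = x.reshape(-1, D)                          # (B*T, D)
        router_logits = self.gate(tokens)                  # (B*T, n_expert)
        topk_logits, topk_idx = router_logits.topk(
            self.n_expert_activated, dim=-1)
        weights = torch.softmax(topk_logits, dim=-1)       # renormalize over the chosen experts
        out = torch.zeros_like(tokens)
        for e, expert in enumerate(self.experts):
            # which tokens (and which of their k routing slots) picked expert e
            token_ids, slot = torch.where(topk_idx == e)
            if token_ids.numel() == 0:
                continue
            out[token_ids] += weights[token_ids, slot].unsqueeze(-1) * expert(tokens[token_ids])
        return out.reshape(B, T, D)

Only the selected experts run for each token, so the parameter count grows with n_expert while the per-token compute stays close to that of a dense FFN with the same hidden size.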

srdfjy pushed a commit to srdfjy/wenet that referenced this pull request Jul 2, 2024
srdfjy (Contributor) commented Jul 15, 2024

Hi, is there any ablation data related to MoE?

xingchensong (Member, Author)

> Hi, is there any ablation data related to MoE?

https://arxiv.org/abs/2404.16407

srdfjy (Contributor) commented Jul 16, 2024

Hi @xingchensong, when running the MoE jit script model (non-streaming) with the libtorch runtime, the RTF is actually higher than that of the model without MoE. Is my test method wrong?

Config:

n_expert: 4
n_expert_activated: 2

xingchensong (Member, Author)

Run at least a few thousand utterances and average the RTF; a single utterance gives a very noisy measurement.
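For context, RTF (real-time factor) is decoding time divided by audio duration. Below is a minimal sketch of how a pooled RTF over a whole test set could be computed; wav_paths, decode_fn and the torchaudio-based duration lookup are placeholder assumptions, not part of wenet's runtime.

import time
import torchaudio  # assumed available for reading audio durations

def pooled_rtf(wav_paths, decode_fn):
    # Pooled RTF: total decoding time / total audio duration over the whole set,
    # which is far less noisy than the RTF of any single utterance.
    total_decode_s = 0.0
    total_audio_s = 0.0
    for path in wav_paths:
        info = torchaudio.info(path)
        total_audio_s += info.num_frames / info.sample_rate
        start = time.perf_counter()
        decode_fn(path)          # placeholder for the actual recognition call
        total_decode_s += time.perf_counter() - start
    return total_decode_s / total_audio_s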

srdfjy (Contributor) commented Jul 17, 2024

> Run at least a few thousand utterances and average the RTF; a single utterance gives a very noisy measurement.

Total utterances: 35,562
0.12B model: RTF 0.24
0.31B MoE model: RTF 0.31

The benefit is still clear, and the MoE model even shows a slight CER improvement, but training time doubles.

Zth9730 pushed a commit to Zth9730/wenet that referenced this pull request Aug 7, 2024
add casual model

fix typo

rm ckpt

add topk topp sampler

fix positoin

[train_engine] support fsdp (wenet-e2e#2412)

* [train_engine] support fsdp

* [train_engine] support fsdp

* unify scaler and amp

* fp32&&fp16 works in fsdp env

* fix fsdp in cv auto cast

* try to fix wenet.join fsdp

* implementing zero1 under fsdp is almost equivalent to deepspeed's zero1

* fix clip_and_grad_

* fix train summary

* all wenet xxxformer works (-paraformer -transducer)

* try to fix nan

* add barrier for cv

* add destroy group for end of all train

* refactor wrap methods and ckpt works

* fix ckpt

* fix cv in dtype != float32

* fix ckpt in model mode

* fix bf16 amp

* refactor scaler and autocast, fix fp32 fp16 bf16 for fsdp

* fix fp32 nullcontext to nullcontext()

* modify after review

* fix lint

* fix lint

LoRA support (wenet-e2e#2049)

* support lora for v3.0.1

* format code and update lora attention && encoder

* fix bug when lora_list is None

---------

Co-authored-by: Xingchen Song(宋星辰) <xingchensong1996@163.com>

[env] update python version and deepspeed version (wenet-e2e#2462)

* [env] update python version and deepspeed version

* [env] fix lint

fix rope pos embdining (wenet-e2e#2463)

* fix rope pos embdining

* fix dropout

* fix comment

[transformer] add multi warmup and learning rate for different modules (wenet-e2e#2449)

* [transformer] add multi warmup and learning rate for different modules

* fix typo

* it works in warmuplr

* fix lr in tensorboard in step mode

* fix cv log

* cv works

* refactor cv log

* add helper lrs_to_string

* fix lrstr

* fix ddp multiple lr

* fix initial step

* revert to -1

* fix sub params dup

* fix step

* fix step

* fix log

* add assert for scheduler

* add comment for log

---------

Co-authored-by: Xingchen Song(宋星辰) <xingchensong1996@163.com>

add generate

add toto

support sft & pretrain training forward

gemm conversion works

support init casual model

[whisper] limit language to Chinese (wenet-e2e#2470)

[train] convert tensor to scalar (wenet-e2e#2471)

[workflow] upgrad python version to 3.10 (wenet-e2e#2472)

* [workflow] upgrad python version to 3.10

* [workflow] try to pass

refactor cache behaviour in training mode (reduce compute cost and memory) (wenet-e2e#2473)

all gemma model works

fix ut

fix ut (wenet-e2e#2477)

* fix ut

* fix py version

[transformer] Make MoE runnable (wenet-e2e#2474)

[transformer] fix mqa (wenet-e2e#2478)

enable mmap in torch.load (wenet-e2e#2479)

[example] Add deespeed configs of different stages for illustration (wenet-e2e#2485)

[example] Fix prefetch and step_save (wenet-e2e#2486)

[ctl] simplified ctl (wenet-e2e#2483)

* [ctl] simplified ctl

* [ctl] unify

[branchformer] simplified branchformer (wenet-e2e#2482)

* [transformer] simplified branchformer

* fix yaml

* support mqa  gradiengt ckpt sdpa

* fix gradient checkponit

* add deepspeed comment in layer dropout

* fix comment

[e_branchformer] simplified e_branchformer (wenet-e2e#2484)

* [e_branchformer] simplified ctl

* try to fix ut

* try to fix ut

* fix activation

* fix att args

* e-branformer works

[transformer] refactor cache (wenet-e2e#2481)

* [transformer] refactor cache

* fix ut

* unify cache type in branchformer and ebranchformer

fix cache

fix gradient ckpt in branchformer/ebranformer (wenet-e2e#2488)

fix search after refactor cache (wenet-e2e#2490)

generate works!

unify chat pattern

convert llama3 works

[transformer] set use_reentrant=False for gradient ckpt (wenet-e2e#2491)

[transformer] fix warning: ignore(True) has been deprecated (wenet-e2e#2492)

* [transformer] fix warning: ignore(True) has been deprecated

* [transformer] fix warning: ignore(True) has been deprecated

[log] avoid reduntant logging (wenet-e2e#2493)

fix w1 w2 w3 in feedforward

add 70b temporarily

mv LLM to wenet

support llm dataset

unify config

add dataset yaml in script

support llm dataset

dynamic static bucket works

[transformer] refacgtor mqa repeat (wenet-e2e#2497)

[transformer] fix mqa in cross att (wenet-e2e#2498)

[deepspeed] update json config (wenet-e2e#2499)

training works

pretrain works

refactor covert

fix flash att in generate

llama works

fix llama3

fix speed

try fix ut

support stop tokens in gen and support ppl

support stop tokens in gen and support ppl
srdfjy pushed a commit to srdfjy/wenet that referenced this pull request Oct 8, 2024