We organize most of the existing SSL pretrained models under the SUPERB Benchmark framework and provide an all-in-one unified interface for them. All the upstream models share the same input / output format:
- input: a list of unpadded wavs `[wav1, wav2, ...]`, where each wav is a `torch.FloatTensor`
- output: a dictionary where each key's corresponding value is either a padded sequence in `torch.FloatTensor` or a list of padded sequences, each in `torch.FloatTensor`. Every padded sequence has the shape `(batch_size, max_sequence_length_of_batch, hidden_size)`. At least the key `hidden_states` is available, and its value is a list.
For upstream models that operate on features other than wav (for example, log Mel, fbank, etc.), the wav -> feature preprocessing is done on-the-fly during the model's forward pass. Rest assured that this will not increase your runtime.
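Below is a minimal sketch of this I/O contract, assuming s3prl is installed; the classic `fbank` upstream is used since it needs no pretrained checkpoint and also illustrates the on-the-fly wav -> feature preprocessing:

```python
import torch
import s3prl.hub as hub

# build the classic FBANK upstream (no pretrained weights needed)
model = getattr(hub, 'fbank')()

# a batch of unpadded wavs with different lengths (16 kHz assumed)
wavs = [torch.randn(16000 * sec, dtype=torch.float) for sec in (1, 2, 3)]

with torch.no_grad():
    outputs = model(wavs)

# "hidden_states" is always available and is a list of padded sequences
for layer in outputs["hidden_states"]:
    print(layer.shape)  # (batch_size, max_sequence_length_of_batch, hidden_size)
```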
The `Name` field in the upstream information below is the string we use to specify different models; that is, different upstreams are identified by the exact string of their `Name`. The upstreams are loaded with pretrained weights.
To evaluate upstreams with the SUPERB Benchmark, we provide a unified script for all upstreams: `run_downstream.py`. Please refer to `downstream/README.md` for detailed usage. In this script, we can use `-u` with the `Name` to switch between upstreams for benchmarking. Take wav2vec 2.0 Base for example:
```bash
# benchmark the classic fbank baseline
python3 run_downstream.py -m train -u fbank -d example -n ExpName

# benchmark wav2vec 2.0 Base by switching only the -u argument
python3 run_downstream.py -m train -u wav2vec2 -d example -n ExpName

# list all options
python3 run_downstream.py -h
```
After installing s3prl, you can use the upstreams in your own codebase:

```python
import torch
import s3prl.hub as hub

model_0 = getattr(hub, 'fbank')()         # use classic FBANK
model_1 = getattr(hub, 'modified_cpc')()  # build the CPC model with pre-trained weights
model_2 = getattr(hub, 'tera')()          # build the TERA model with pre-trained weights
model_3 = getattr(hub, 'wav2vec2')()      # build the Wav2Vec 2.0 model with pre-trained weights

device = 'cuda'  # or 'cpu'
model_3 = model_3.to(device)
wavs = [torch.randn(160000, dtype=torch.float).to(device) for _ in range(16)]
with torch.no_grad():
    reps = model_3(wavs)["hidden_states"]
```
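`reps` is a list with one padded tensor per layer. For downstream use you can simply pick one layer (e.g. `reps[-1]`), or mix all layers with a learnable weighted sum in the spirit of SUPERB. The sketch below assumes all layers share the same shape (true for most upstreams); `LayerWeightedSum` is a hypothetical helper, not part of s3prl:

```python
import torch
import torch.nn as nn

class LayerWeightedSum(nn.Module):
    """Learn one scalar weight per layer and mix the hidden states."""
    def __init__(self, num_layers):
        super().__init__()
        self.weights = nn.Parameter(torch.zeros(num_layers))

    def forward(self, hidden_states):
        # hidden_states: list of (batch_size, max_len, hidden_size) tensors
        stacked = torch.stack(hidden_states, dim=0)        # (num_layers, batch, len, hidden)
        norm_weights = torch.softmax(self.weights, dim=0)  # (num_layers,)
        return (norm_weights.view(-1, 1, 1, 1) * stacked).sum(dim=0)

mixer = LayerWeightedSum(num_layers=len(reps)).to(device)
features = mixer(reps)  # (batch_size, max_sequence_length_of_batch, hidden_size)
```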
To list all available upstream entries:

```python
import s3prl.hub as hub

print(dir(hub))
```
We support most of the existing SSL pretrained models. You can refer to the SUPERB paper for their pre-training loss styles.
Publication Date | Model | Name | Paper | Input | Stride | Pre-train Data | Official Ckpt | Official Repo |
---|---|---|---|---|---|---|---|---|
10 Jul 2018 | CPC | - | arxiv | wav | 10ms | - | X | Unavailable |
5 Apr 2019 | APC | apc | arxiv | Mel | 10ms | LibriSpeech-360 | O | APC |
6 Apr 2019 | PASE | pase_plus | arxiv | wav | 10ms | LibriSpeech-960 | X | PASE |
11 Apr 2019 | Wav2Vec | wav2vec | arxiv | wav | 10ms | LibriSpeech-960 | O | Fairseq |
12 Oct 2019 | VQ-Wav2Vec | vq_wav2vec | arxiv | wav | 10ms | LibriSpeech-960 | O | Fairseq |
25 Oct 2019 | Mockingjay | mockingjay | arxiv | Mel | 10ms | LibriSpeech-960 | O | S3PRL |
7 Feb 2020 | Modified-CPC | modified_cpc | arxiv | wav | 10ms | LibriLight-60k | O | FAIR |
17 May 2020 | VQ-APC | vq_apc | arxiv | Mel | 10ms | LibriSpeech-360 | O | NPC |
18 May 2020 | Audio Albert | audio_albert | arxiv | Mel | 10ms | LibriSpeech-960 | X | S3PRL |
20 Jun 2020 | Wav2Vec 2.0 | wav2vec2 / wav2vec2_large_ll60k | arxiv | wav | 20ms | LibriSpeech-960 | O | Fairseq |
12 Jul 2020 | TERA | tera | arxiv | Mel | 10ms | LibriSpeech-960 | O | S3PRL |
1 Nov 2020 | NPC | npc | arxiv | Mel | 10ms | LibriSpeech-360 | X | NPC |
14 Jun 2021 | HuBERT | hubert / hubert_large_ll60k | arxiv | wav | 20ms | LibriSpeech-960 | O | Fairseq |
3 Dec 2019 | DeCoAR | decoar | arxiv | Mel | 10ms | LibriSpeech-960 | O | speech-representations |
11 Dec 2020 | DeCoAR 2.0 | decoar2 | arxiv | Mel | 10ms | LibriSpeech-960 | O | speech-representations |
5 Oct 2021 | DistilHuBERT | distilhubert | arxiv | wav | 20ms | LibriSpeech-960 | O | S3PRL |
27 May 2022 | Robust HuBERT | hubert_base_robust_mgr | arxiv | wav | 20ms | LibriSpeech-960 with mgr distortion | X | Unavailable |
We also provide classic acoustic features as baselines. For each upstream with a given `Name`, you can configure its options (which depend on its `Backend`) in `s3prl/upstream/baseline/Name.yaml`.
Feature | Name | Default Dim | Stride | Window | Backend |
---|---|---|---|---|---|
Spectrogram | spectrogram | 257 | 10ms | 25ms | torchaudio-kaldi |
FBANK | fbank | 80 + delta1 + delta2 | 10ms | 25ms | torchaudio-kaldi |
MFCC | mfcc | 13 + delta1 + delta2 | 10ms | 25ms | torchaudio-kaldi |
Mel | mel | 80 | 10ms | 25ms | torchaudio |
Linear | linear | 201 | 10ms | 25ms | torchaudio |
The upstreams can take two options, `ckpt` and `model_config`, whose types are both `str`. You can refer to each upstream's `hubconf.py` and `expert.py` for the options it supports. The `hubconf.py` under each upstream folder contains the entries you can use as the `Name` to initialize an upstream, following the protocol documented at `torch.hub.load`. SSL upstreams with pretrained checkpoints typically have a pre-registered `ckpt` in their `hubconf.py` specifying the location of the pre-trained checkpoint. On the other hand, acoustic feature upstreams typically accept `model_config` as the configuration file for feature extraction. Below is an example of how to pass options into an upstream entry to get different upstream instances.
```python
import torch
import s3prl.hub as hub

device = 'cuda'  # or 'cpu'
config_path = 's3prl/upstream/baseline/mfcc.yaml'
extracter = getattr(hub, 'baseline_local')(model_config=config_path).to(device)

wavs = [torch.zeros(160000, dtype=torch.float).to(device) for _ in range(16)]
with torch.no_grad():
    mfcc = extracter(wavs)["hidden_states"]
```
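Similarly, SSL upstreams accept `ckpt` so you can load your own checkpoint instead of the pre-registered one. A minimal sketch, assuming a `wav2vec2_local` entry is listed by `dir(hub)` in your s3prl version; the checkpoint path is a placeholder you would replace with your own file:

```python
import s3prl.hub as hub

# hypothetical local checkpoint path; replace with a real fairseq checkpoint
ckpt_path = '/path/to/your/wav2vec2_checkpoint.pt'
model = getattr(hub, 'wav2vec2_local')(ckpt=ckpt_path)
```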