From 0f2f3396e6bc05beeec2521d4f6db7eee5942940 Mon Sep 17 00:00:00 2001
From: Mddct
Date: Wed, 20 Mar 2024 14:42:40 +0800
Subject: [PATCH] unify shuff typo

---
 wenet/dataset/dataset.py | 10 +++++++---
 1 file changed, 7 insertions(+), 3 deletions(-)

diff --git a/wenet/dataset/dataset.py b/wenet/dataset/dataset.py
index c09250343..6bb49c4da 100644
--- a/wenet/dataset/dataset.py
+++ b/wenet/dataset/dataset.py
@@ -42,9 +42,13 @@ def Dataset(data_type,
     assert data_type in ['raw', 'shard']
     # cycle dataset
     cycle = conf.get('cycle', 1)
-    list_shuffle = conf.get('list_shuffle', False)
-    list_shuffle_size = conf.get('list_shuffle_size', 10000)
-
+    # stage1 shuffle: source
+    list_shuffle = conf.get('list_shuffle', True)
+    list_shuffle_size = 10000000
+    if list_shuffle:
+        list_shuffle_conf = conf.get('list_shuffle_conf', {})
+        list_shuffle_size = list_shuffle_conf.get('shuffle_size',
+                                                  list_shuffle_size)
     if data_type == 'raw':
         dataset = WenetRawDatasetSource(data_list_file,
                                         partition=partition,