[Research] Speed up evaluation for XTREME-S (#16785)
* Avoid repeated per-lang filtering

* Language groups and logits preprocessing

* Style
anton-l authored Apr 27, 2022
1 parent 2d91e3c commit a4a88fa
Showing 1 changed file with 46 additions and 8 deletions.
examples/research_projects/xtreme-s/run_xtreme_s.py
@@ -136,6 +136,10 @@ class ModelArguments:
metadata={"help": "Length of vector span to mask along the feature axis."},
)
layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
ctc_zero_infinity: bool = field(
default=False,
metadata={"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`."},
)
ctc_loss_reduction: Optional[str] = field(
default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
)
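For context on the new `ctc_zero_infinity` flag, here is a minimal PyTorch sketch (illustrative only, with made-up shapes, not part of this diff) of what the option controls. With `zero_infinity=True`, samples that admit no valid CTC alignment, e.g. a target longer than the input, contribute a zero loss instead of an infinite one that would poison the gradients.

import torch

# Dummy batch: 2 utterances, 50 frames, vocabulary of 20 tokens (blank = 0).
log_probs = torch.randn(50, 2, 20).log_softmax(-1)   # (time, batch, vocab)
targets = torch.randint(1, 20, (2, 60))              # targets longer than the inputs
input_lengths = torch.full((2,), 50)
target_lengths = torch.full((2,), 60)

ctc = torch.nn.CTCLoss(blank=0, zero_infinity=True)
loss = ctc(log_probs, targets, input_lengths, target_lengths)
# loss == 0.0 here; with zero_infinity=False it would be inf.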
@@ -166,6 +170,15 @@ class DataTrainingArguments:
default="all",
metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
)
language_group: str = field(
default=None,
metadata={
"help": "The language group to select a subset of languages to train on. "
"This option is only used the 'fleurs-asr' task. Should be one of: "
"'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', "
"'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'."
},
)
train_split_name: str = field(
default="train",
metadata={
@@ -441,6 +454,11 @@ def main():
"config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
" for multi-lingual fine-tuning."
)
if data_args.language_group is not None:
if data_args.task != "fleurs-asr":
raise ValueError("--language_group should only be used with --task=fleurs-asr")
if data_args.language != "all":
raise ValueError("--language_group should only be used with --language=all")

if data_args.target_column_name is None:
target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
@@ -502,11 +520,23 @@ def main():
if data_args.max_predict_samples is not None:
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))

lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
if not is_text_target:
label_list = next(iter(raw_datasets.values())).features[target_column_name].names
lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
num_labels = len(label_list)

num_workers = data_args.preprocessing_num_workers

lang_group = data_args.language_group
if lang_group is not None:
with training_args.main_process_first(desc="language group filter"):
lang_group_id = next(iter(raw_datasets.values())).features["lang_group_id"].str2int(lang_group)
raw_datasets = raw_datasets.filter(
lambda lang_group: lang_group == lang_group_id,
num_proc=num_workers,
input_columns=["lang_group_id"],
)

# 2. We remove some special characters from the datasets
# that make training complicated and do not help in transcribing the speech
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
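The language-group filter above (and the per-language evaluation filter later in this diff) both use `datasets.Dataset.filter` with `input_columns` and `num_proc`. Here is a minimal, self-contained sketch of that pattern on a toy dataset (not XTREME-S): `input_columns` hands the predicate only the named column, so the heavy audio/feature columns never need to be materialized per example, and `num_proc` shards the filtering across worker processes. On the command line, the group itself would be selected with something like `--task="fleurs-asr" --language="all" --language_group="western_european_we"`.

from datasets import Dataset

# Toy stand-in for the XTREME-S splits.
ds = Dataset.from_dict({
    "lang_group_id": [0, 1, 0, 2],
    "transcription": ["a", "b", "c", "d"],
})
lang_group_id = 0

filtered = ds.filter(
    lambda group_id: group_id == lang_group_id,  # receives only "lang_group_id"
    input_columns=["lang_group_id"],
    num_proc=2,
)
print(filtered["transcription"])  # ['a', 'c']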
@@ -616,6 +646,7 @@ def remove_special_characters(batch):
"mask_feature_length": model_args.mask_feature_length,
"gradient_checkpointing": training_args.gradient_checkpointing,
"layerdrop": model_args.layerdrop,
"ctc_zero_infinity": model_args.ctc_zero_infinity,
"ctc_loss_reduction": model_args.ctc_loss_reduction,
"activation_dropout": model_args.activation_dropout,
}
@@ -675,7 +706,6 @@ def remove_special_characters(batch):
max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
audio_column_name = data_args.audio_column_name
num_workers = data_args.preprocessing_num_workers

# `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
phoneme_language = data_args.phoneme_language
@@ -740,13 +770,13 @@ def is_audio_in_length_range(length):
logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
return

def compute_asr_metric(pred):
pred_logits = pred.predictions
pred_ids = np.argmax(pred_logits, axis=-1)
def asr_logits_argmax(logits, labels):
return logits.argmax(dim=-1)

def compute_asr_metric(pred):
pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id

pred_str = tokenizer.batch_decode(pred_ids)
pred_str = tokenizer.batch_decode(pred.predictions)
# we do not want to group tokens when computing the metrics
label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)

@@ -783,6 +813,7 @@ def compute_classification_metric(pred):
model=model,
data_collator=data_collator,
args=training_args,
preprocess_logits_for_metrics=asr_logits_argmax if training_args.predict_with_generate else None,
compute_metrics=compute_asr_metric if training_args.predict_with_generate else None,
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
@@ … @@
model=model,
data_collator=data_collator,
args=training_args,
preprocess_logits_for_metrics=asr_logits_argmax if is_text_target else None,
compute_metrics=compute_asr_metric if is_text_target else compute_classification_metric,
train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
@@ -837,11 +869,17 @@ def compute_classification_metric(pred):
average_metrics = defaultdict(list)
for lang_id in range(len(lang_list)):
lang_name = lang_list[lang_id]
lang_dataset = vectorized_datasets["predict"].filter(lambda example: example["lang"] == lang_id)
with training_args.main_process_first(desc="per-language dataset filter"):
lang_dataset = vectorized_datasets["predict"].filter(
lambda lang: lang == lang_id,
num_proc=num_workers,
input_columns=["lang"],
)
lang_metrics = trainer.evaluate(lang_dataset)
redundant_metrics = ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second", "eval_epoch"]
for metric_name, value in lang_metrics.items():
average_metrics[metric_name].append(value)
if metric_name not in ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second"]:
if metric_name not in redundant_metrics:
metrics[f"{metric_name}_{lang_name}"] = value
for metric_name, value in average_metrics.items():
metrics[metric_name] = np.mean(value)
