diff --git a/fairseq/checkpoint_utils.py b/fairseq/checkpoint_utils.py index 66e16ea16d..0284fa2aa1 100644 --- a/fairseq/checkpoint_utils.py +++ b/fairseq/checkpoint_utils.py @@ -153,11 +153,11 @@ def load_model_ensemble(filenames, arg_overrides=None, task=None): were used during model training task (fairseq.tasks.FairseqTask, optional): task to use for loading """ - ensemble, args, _task = _load_model_ensemble(filenames, arg_overrides, task) + ensemble, args, _task = load_model_ensemble_and_task(filenames, arg_overrides, task) return ensemble, args -def _load_model_ensemble(filenames, arg_overrides=None, task=None): +def load_model_ensemble_and_task(filenames, arg_overrides=None, task=None): from fairseq import tasks ensemble = [] diff --git a/fairseq/models/fairseq_model.py b/fairseq/models/fairseq_model.py index 1d534188f8..8f52adc7de 100644 --- a/fairseq/models/fairseq_model.py +++ b/fairseq/models/fairseq_model.py @@ -191,7 +191,7 @@ def from_pretrained(cls, model_name_or_path, checkpoint_file='model.pt', data_na if os.path.exists(path): kwargs[arg] = path - models, args, task = checkpoint_utils._load_model_ensemble( + models, args, task = checkpoint_utils.load_model_ensemble_and_task( [os.path.join(model_path, cpt) for cpt in checkpoint_file.split(':')], arg_overrides=kwargs, )