From 51ba35217debfbf34cdfbae14e09cb9df9c3be5b Mon Sep 17 00:00:00 2001 From: Myle Ott Date: Mon, 22 Jul 2019 08:11:58 -0700 Subject: [PATCH] Simplify hubconf Summary: Pull Request resolved: https://github.com/fairinternal/fairseq-py/pull/758 Differential Revision: D16418932 Pulled By: myleott fbshipit-source-id: 59f005164b61b9fa712922eeb23525f7eec38f38 --- hubconf.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/hubconf.py b/hubconf.py index 1eb25f870a..992c259fa3 100644 --- a/hubconf.py +++ b/hubconf.py @@ -5,6 +5,8 @@ # the root directory of this source tree. An additional grant of patent rights # can be found in the PATENTS file in the same directory. +import functools + from fairseq.models import MODEL_REGISTRY @@ -18,5 +20,11 @@ ] -for model, cls in MODEL_REGISTRY.items(): - globals()[model] = cls.from_pretrained +for model_type, _cls in MODEL_REGISTRY.items(): + for model_name in _cls.hub_models().keys(): + globals()[model_name] = functools.partial( + _cls.from_pretrained, + model_name_or_path=model_name, + ) + # to simplify the interface we only expose named models + #globals()[model_type] = _cls.from_pretrained