diff --git a/axonn/models/transformers/__init__.py b/axonn/models/transformers/__init__.py index b28ee96..951c18e 100644 --- a/axonn/models/transformers/__init__.py +++ b/axonn/models/transformers/__init__.py @@ -40,8 +40,8 @@ @contextmanager -def parallelize(model_id): - config = AutoConfig.from_pretrained(model_id) +def parallelize(model_id, model_kwargs={}): + config = AutoConfig.from_pretrained(model_id, **model_kwargs) architecture = config.architectures[0] # config.architectures is a list, not sure what to do # if it has multiple elements