diff --git a/easynmt/models/OpusMT.py b/easynmt/models/OpusMT.py index 0f1e8eb..72086ed 100644 --- a/easynmt/models/OpusMT.py +++ b/easynmt/models/OpusMT.py @@ -1,8 +1,11 @@ +from lib2to3.pgen2 import token +from lib2to3.pgen2.tokenize import tokenize import time from transformers import MarianMTModel, MarianTokenizer import torch from typing import List import logging +import os logger = logging.getLogger(__name__) @@ -12,7 +15,10 @@ def __init__(self, easynmt_path: str = None, max_loaded_models: int = 10): self.models = {} self.max_loaded_models = max_loaded_models self.max_length = None - + self.easynmt_path = easynmt_path + self.src_lang = "" + self.trgt_lang = "" + def load_model(self, model_name): if model_name in self.models: self.models[model_name]['last_loaded'] = time.time() @@ -36,7 +42,21 @@ def load_model(self, model_name): return tokenizer, model def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str, device: str, beam_size: int = 5, **kwargs): - model_name = 'Helsinki-NLP/opus-mt-{}-{}'.format(source_lang, target_lang) + # model_name = 'Helsinki-NLP/opus-mt-{}-{}'.format(source_lang, target_lang) + + ################################################ + + # This is for the loading of the already downloaded models available in download_path + # we specify in the application's algo file as:- + # Example for loading pre downloaded opus-mt models + # model_trans = app.imports.EasyNMT(download_path) + + self.src_lang = source_lang + self.trgt_lang = target_lang + model_name = os.path.join(self.easynmt_path,self.src_lang + "-" + self.trgt_lang) + + + ################################################ tokenizer, model = self.load_model(model_name) model.to(device) @@ -52,5 +72,26 @@ def translate_sentences(self, sentences: List[str], source_lang: str, target_lan return output def save(self, output_path): - return {"max_loaded_models": self.max_loaded_models} + 
######################################################################################## + + # Modified by - Aniket Sood + + + # This function will not be required for normal run + # but can be used in case we want to save the models in the desired path. + + # To download the models online and save them in any folder we want defined by output_path (eg - output_path - cache_folder/) + # or alternatively the user can download from online sources and store the models in any folder, then provide its path. + save_path = os.path.join(output_path,self.src_lang + "-" + self.trgt_lang) + try: + os.mkdir(save_path) + except OSError as error: + print(error) + self.models['Helsinki-NLP/opus-mt-' + self.src_lang + '-' + self.trgt_lang]['model'].save_pretrained(save_path) + self.models['Helsinki-NLP/opus-mt-' + self.src_lang + '-' + self.trgt_lang]['tokenizer'].save_pretrained(save_path) + + ######################################################################################## + return { + "max_loaded_models": self.max_loaded_models + }