Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

local offline usage of opusmt #91

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions easynmt/models/OpusMT.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from lib2to3.pgen2 import token
from lib2to3.pgen2.tokenize import tokenize
import time
from transformers import MarianMTModel, MarianTokenizer
import torch
from typing import List
import logging
import os

logger = logging.getLogger(__name__)

Expand All @@ -12,7 +15,10 @@ def __init__(self, easynmt_path: str = None, max_loaded_models: int = 10):
self.models = {}
self.max_loaded_models = max_loaded_models
self.max_length = None

self.easynmt_path = easynmt_path
self.src_lang = ""
self.trgt_lang = ""

def load_model(self, model_name):
if model_name in self.models:
self.models[model_name]['last_loaded'] = time.time()
Expand All @@ -36,7 +42,21 @@ def load_model(self, model_name):
return tokenizer, model

def translate_sentences(self, sentences: List[str], source_lang: str, target_lang: str, device: str, beam_size: int = 5, **kwargs):
model_name = 'Helsinki-NLP/opus-mt-{}-{}'.format(source_lang, target_lang)
# model_name = 'Helsinki-NLP/opus-mt-{}-{}'.format(source_lang, target_lang)

################################################

# This is for the loading of the already downloaded models available in download_path
# we specify in the application's algo file as:-
# Example for loading pre downloaded opus-mt models
# model_trans = app.imports.EasyNMT(download_path)

self.src_lang = source_lang
self.trgt_lang = target_lang
model_name = os.path.join(self.easynmt_path,self.src_lang + "-" + self.trgt_lang)


################################################
tokenizer, model = self.load_model(model_name)
model.to(device)

Expand All @@ -52,5 +72,26 @@ def translate_sentences(self, sentences: List[str], source_lang: str, target_lan
return output

def save(self, output_path):
    """Save the currently loaded model/tokenizer pair under *output_path*.

    Not needed for a normal offline run; useful when the models were
    fetched online and should be persisted to a local folder for later
    offline use (e.g. ``output_path = "cache_folder/"``).

    The pair is written to ``<output_path>/<src>-<trgt>/`` via
    ``save_pretrained``, matching the layout ``translate_sentences``
    expects when loading from ``self.easynmt_path``.

    :param output_path: Directory under which the language-pair
        sub-folder is created.
    :return: Dict of constructor kwargs needed to restore this wrapper
        (EasyNMT's save contract).
    :raises KeyError: If no model for the current src/trgt pair has been
        loaded via ``translate_sentences`` yet.
    """
    save_path = os.path.join(output_path, self.src_lang + "-" + self.trgt_lang)
    # Create the target folder (and any missing parents); an already
    # existing folder is fine — we overwrite its contents below.
    os.makedirs(save_path, exist_ok=True)

    # BUG FIX: translate_sentences caches models under the local-path key
    # os.path.join(self.easynmt_path, "<src>-<trgt>"), NOT under the
    # 'Helsinki-NLP/opus-mt-...' hub name — look up with the same key.
    model_name = os.path.join(self.easynmt_path, self.src_lang + "-" + self.trgt_lang)
    entry = self.models[model_name]
    entry['model'].save_pretrained(save_path)
    entry['tokenizer'].save_pretrained(save_path)

    return {
        "max_loaded_models": self.max_loaded_models
    }