-
Notifications
You must be signed in to change notification settings - Fork 146
/
Copy pathtranslate.py
276 lines (248 loc) · 9.22 KB
/
translate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
# Copyright (c) 2019-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
# Translate sentences from the input stream.
# The model will be faster if sentences are sorted by length.
# Input sentences must have the same tokenization and BPE codes as the ones used in the model.
#
import os
import argparse
import typing as tp
from pathlib import Path
import sys
import torch
from codegen_sources.model.src.logger import create_logger
from codegen_sources.model.src.data.dictionary import (
Dictionary,
BOS_WORD,
EOS_WORD,
PAD_WORD,
UNK_WORD,
MASK_WORD,
)
from codegen_sources.model.src.utils import bool_flag
from codegen_sources.model.src.constants import SUPPORTED_LANGUAGES_FOR_TESTS
from codegen_sources.model.src.model import build_model
from codegen_sources.model.src.utils import AttrDict
import codegen_sources.dataloaders.transforms as transf
# Languages accepted by this script: every language supported for tests,
# followed by the extra "ir" entry (kept last; the CLI help text relies on
# the order of this list).
SUPPORTED_LANGUAGES = [*SUPPORTED_LANGUAGES_FOR_TESTS, "ir"]

# Module-level logger shared by the translator and the script entry point.
logger = create_logger(None, 0)
def get_params(argv=None):
    """
    Build the command-line parser and parse the translation options.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default) argparse reads sys.argv[1:], which is the behavior
            callers that pass no argument have always had.

    Returns:
        The parsed argparse.Namespace. The literal string "None" given to
        --efficient_attn is normalized to the None object so that callers
        only have to deal with a single "disabled" value.
    """
    # Human-readable enumeration of the supported languages, shared by the
    # help text of --src_lang and --tgt_lang.
    language_list = (
        f"{', '.join(SUPPORTED_LANGUAGES[:-1])} or {SUPPORTED_LANGUAGES[-1]}"
    )

    # parse parameters
    parser = argparse.ArgumentParser(description="Translate sentences")

    # model
    parser.add_argument("--model_path", type=str, default="", help="Model path")
    parser.add_argument(
        "--src_lang",
        type=str,
        default="",
        help=f"Source language, should be either {language_list}",
    )
    parser.add_argument(
        "--tgt_lang",
        type=str,
        default="",
        help=f"Target language, should be either {language_list}",
    )
    parser.add_argument(
        "--BPE_path",
        type=str,
        default=str(
            Path(__file__).parents[2].joinpath("data/bpe/cpp-java-python/codes")
        ),
        help="Path to BPE codes.",
    )
    parser.add_argument(
        "--beam_size",
        type=int,
        default=1,
        help="Beam size. The beams will be printed in order of decreasing likelihood.",
    )
    parser.add_argument(
        "--input", type=str, default=None, help="input path",
    )
    # The help text used to be a copy-paste of the --input one ("input path").
    parser.add_argument(
        "--gpu",
        type=bool_flag,
        default=True,
        help="Whether to run the model on GPU.",
    )
    parser.add_argument(
        "--efficient_attn",
        type=str,
        default=None,
        choices=["None", "flash", "cutlass", "fctls_bflsh", "auto"],
        help="If set, uses efficient attention from xformers.",
    )

    parameters = parser.parse_args(argv)
    # "None" is offered as an explicit CLI choice; map it to the real None.
    if parameters.efficient_attn == "None":
        parameters.efficient_attn = None
    return parameters
class Translator:
    """
    Wraps a reloaded encoder/decoder checkpoint together with its dictionary
    and BPE transform, and exposes `translate` to convert code from one
    language to another.
    """

    def __init__(self, model_path, BPE_path, gpu=True, efficient_attn=None) -> None:
        """
        Reload the checkpoint at `model_path` and prepare it for inference.

        Args:
            model_path: Path of the checkpoint given to `torch.load`.
            BPE_path: Path to the BPE codes used by the FastBpe transform.
            gpu: When true, the encoder and decoder are moved to CUDA and
                `translate` defaults to the "cuda:0" device.
            efficient_attn: Value stored under the "efficient_attn" key of
                the reloaded parameters before the model is built.

        Raises:
            ValueError: If the checkpoint parameters select the roberta
                tokenization mode.
            AssertionError: If the special-token indices or vocabulary size
                stored in the parameters disagree with the dictionary.
        """
        self.gpu = gpu

        # reload model
        checkpoint = torch.load(model_path, map_location="cpu")
        saved_params = checkpoint["params"]

        # change params of the reloaded model so that it will
        # reload its own weights and not the MLM or DOBF pretrained model
        own_path = str(model_path)
        saved_params["reload_model"] = f"{own_path},{own_path}"
        saved_params["lgs_mapping"] = ""
        saved_params["reload_encoder_for_decoder"] = False
        self.reloaded_params = AttrDict(checkpoint["params"])
        self.reloaded_params["efficient_attn"] = efficient_attn

        # build dictionary / update parameters
        self.dico = Dictionary(
            checkpoint["dico_id2word"],
            checkpoint["dico_word2id"],
            checkpoint["dico_counts"],
        )
        # The parameters saved with the checkpoint must agree with the
        # dictionary that was saved next to them.
        assert self.reloaded_params.n_words == len(self.dico)
        special_tokens = (
            ("bos_index", BOS_WORD),
            ("eos_index", EOS_WORD),
            ("pad_index", PAD_WORD),
            ("unk_index", UNK_WORD),
            ("mask_index", MASK_WORD),
        )
        for index_name, word in special_tokens:
            assert getattr(self.reloaded_params, index_name) == self.dico.index(word)

        # build model / reload weights (in the build_model method)
        encoders, decoders = build_model(self.reloaded_params, self.dico, self.gpu)
        self.encoder = encoders[0]
        self.decoder = decoders[0]
        networks = (self.encoder, self.decoder)
        if gpu:
            for network in networks:
                network.cuda()
        for network in networks:
            network.eval()

        # reload bpe
        wants_roberta = (
            self.reloaded_params.get("roberta_mode", False)
            or self.reloaded_params.get("tokenization_mode", "") == "roberta"
        )
        if wants_roberta:
            self.bpe_transf: transf.BpeBase = transf.RobertaBpe()
            raise ValueError("This part has not be tested thoroughly yet")
        self.bpe_transf = transf.FastBpe(code_path=Path(BPE_path).absolute())

    def translate(
        self,
        input_code,
        lang1: str,
        lang2: str,
        suffix1: str = "_sa",
        suffix2: str = "_sa",
        n: int = 1,
        beam_size: int = 1,
        sample_temperature=None,
        device=None,
        tokenized=False,
        detokenize: bool = True,
        max_tokens: tp.Optional[int] = None,
        length_penalty: float = 0.5,
        max_len: tp.Optional[int] = None,
    ):
        """
        Translate `input_code` from `lang1` to `lang2`.

        Args:
            input_code: Input handed to the input pipeline's `apply`.
            lang1: Source language; must be in SUPPORTED_LANGUAGES.
            lang2: Target language; must be in SUPPORTED_LANGUAGES.
            suffix1: Appended to `lang1` to form the key looked up in the
                model's `lang2id` mapping.
            suffix2: Same as `suffix1`, for `lang2`.
            n: When greater than 1, the encoder output is repeated `n` times
                before decoding.
            beam_size: 1 uses `decoder.generate`; any other value uses
                `decoder.generate_beam` with that beam size.
            sample_temperature: Forwarded to `decoder.generate` (only used
                when `beam_size` is 1).
            device: Device the input tensors are moved to. Defaults to
                "cuda:0" when the translator was built with gpu, else "cpu".
            tokenized: When false, a `CodeTokenizer(lang1)` step is put in
                front of the BPE pipeline.
            detokenize: When true, a `CodeTokenizer(lang2)` step is put in
                front of the BPE pipeline used to revert the output.
            max_tokens: If set and the tensorized input is longer (first
                dimension), no translation is attempted.
            length_penalty: Forwarded to `decoder.generate_beam`.
            max_len: Maximum generation length. Defaults to the smaller of
                the model's `max_len` and `3 * input_length + 10`.

        Returns:
            A list with one reverted output per column of the generated
            tensor, or `max(n, beam_size)` copies of an error string when the
            input exceeds `max_tokens`.

        Raises:
            ValueError: If a suffixed language is unknown to the model.
            AssertionError: If `lang1` or `lang2` is not supported.
        """
        if device is None:
            if self.gpu:
                device = "cuda:0"
            else:
                device = "cpu"

        # Build language processors
        for requested in (lang1, lang2):
            assert requested in SUPPORTED_LANGUAGES, requested

        tensorizer = transf.BpeTensorizer()
        tensorizer.dico = self.dico  # TODO: hacky
        bpe_pipe: transf.Transform[tp.Any, torch.Tensor] = self.bpe_transf.pipe(
            tensorizer
        )
        # Both directions share the BPE + tensorizer pipeline; a code
        # tokenizer step is only added on the side that asks for it.
        in_pipe = bpe_pipe if tokenized else transf.CodeTokenizer(lang1).pipe(bpe_pipe)
        out_pipe = (
            transf.CodeTokenizer(lang2).pipe(bpe_pipe) if detokenize else bpe_pipe
        )

        lang1 = lang1 + suffix1
        lang2 = lang2 + suffix2
        avail_langs = list(self.reloaded_params.lang2id.keys())
        for lang in (lang1, lang2):
            if lang not in avail_langs:
                raise ValueError(f"{lang} should be in {avail_langs}")

        with torch.no_grad():
            lang1_id = self.reloaded_params.lang2id[lang1]
            lang2_id = self.reloaded_params.lang2id[lang2]

            # Create torch batch
            x1 = in_pipe.apply(input_code).to(device)[:, None]
            size = x1.shape[0]
            len1 = torch.LongTensor(1).fill_(size).to(device)
            too_long = max_tokens is not None and size > max_tokens
            if too_long:
                logger.info(f"Ignoring long input sentence of size {size}")
                return [f"Error: input too long: {size}"] * max(n, beam_size)
            langs1 = x1.clone().fill_(lang1_id)

            # Encode
            enc1 = self.encoder(
                "fwd", x=x1, lengths=len1, langs=langs1, causal=False
            ).transpose(0, 1)
            if n > 1:
                enc1 = enc1.repeat(n, 1, 1)
                len1 = len1.expand(n)

            # Decode
            if max_len is None:
                longest_input = len1.max().item()
                max_len = int(
                    min(self.reloaded_params.max_len, 3 * longest_input + 10)
                )
            if beam_size == 1:
                generated, _ = self.decoder.generate(
                    enc1,
                    len1,
                    lang2_id,
                    max_len=max_len,
                    sample_temperature=sample_temperature,
                )
            else:
                generated, _, _ = self.decoder.generate_beam(
                    enc1,
                    len1,
                    lang2_id,
                    max_len=max_len,
                    early_stopping=False,
                    length_penalty=length_penalty,
                    beam_size=beam_size,
                )

            # Convert out ids to text
            return [
                out_pipe.revert(generated[:, column])
                for column in range(generated.shape[1])
            ]
if __name__ == "__main__":
    # Script entry point: parse the CLI options, validate them, translate the
    # input function and print every returned hypothesis.

    # generate parser / parse parameters
    params = get_params()

    # check parameters
    assert os.path.isfile(
        params.model_path
    ), f"The path to the model checkpoint is incorrect: {params.model_path}"
    assert params.input is None or os.path.isfile(
        params.input
    ), f"The path to the input file is incorrect: {params.input}"
    assert os.path.isfile(
        params.BPE_path
    ), f"The path to the BPE tokens is incorrect: {params.BPE_path}"
    assert (
        params.src_lang in SUPPORTED_LANGUAGES
    ), f"The source language should be in {SUPPORTED_LANGUAGES}."
    assert (
        params.tgt_lang in SUPPORTED_LANGUAGES
    ), f"The target language should be in {SUPPORTED_LANGUAGES}."

    # Initialize translator
    translator = Translator(
        params.model_path, params.BPE_path, params.gpu, params.efficient_attn
    )

    # Read the code to translate from --input when given, otherwise from
    # stdin. The file is opened in a `with` block so the handle is closed
    # deterministically, and the variable no longer shadows builtin `input`.
    if params.input is not None:
        with open(params.input) as input_file:
            input_code = input_file.read().strip()
    else:
        input_code = sys.stdin.read().strip()

    print(f"Input {params.src_lang} function:")
    print(input_code)
    with torch.no_grad():
        output = translator.translate(
            input_code,
            lang1=params.src_lang,
            lang2=params.tgt_lang,
            beam_size=params.beam_size,
        )

    print(f"Translated {params.tgt_lang} function:")
    for out in output:
        print("=" * 20)
        print(out)