codegen.py
import os
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from torch.utils.data import Dataset
import torch.utils.checkpoint
class VoltronDataset(Dataset):
    def __init__(self, data_root):
        self.samples = []
        for i, code_file_path in enumerate(os.listdir(data_root)):
            # Only the first two files under data_root are read.
            if i > 1:
                break
            with open(os.path.join(data_root, code_file_path), 'r') as code_file:
                # Each line is one pre-encoded program: tab-separated token ids.
                for code_block in code_file.read().splitlines():
                    encoded = [int(x) for x in code_block.split('\t')]
                    self.samples.append(encoded)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
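
# A minimal sketch (not part of the original file) of how a data file under
# data_root could be produced, assuming the format VoltronDataset expects:
# one program per line, written as tab-separated token ids. The checkpoint
# name and output path are illustrative assumptions.
def write_example_data_file(programs, out_path,
                            tokenizer_name="Salesforce/codegen-350M-mono"):
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    with open(out_path, 'w') as out_file:
        for program in programs:
            token_ids = tokenizer(program).input_ids
            out_file.write('\t'.join(str(t) for t in token_ids) + '\n')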

class CodeGenPass():
    def collate_batch(self, batch):
        # Pads a batch of variable-length sequences to the longest one.
        tensor_batch = [torch.tensor(x) for x in batch]
        max_len = max([x.squeeze().numel() for x in tensor_batch])
        padded_batch = [torch.nn.functional.pad(x, pad=(
            0, max_len - x.numel()), mode='constant', value=0) for x in tensor_batch]
        padded_batch = torch.stack(padded_batch)
        return padded_batch
    def setup_model(self, type):
        print('Loading codegen model ...')
        starcoder = "bigcode/starcoder"  # alternative checkpoint, unused here
        codegen = f"Salesforce/codegen-{type}-multi"
        codegen_token = "Salesforce/codegen-350M-mono"
        model = AutoModelForCausalLM.from_pretrained(
            codegen, output_hidden_states=True, torch_dtype=torch.bfloat16,
            device_map="balanced")
        model.eval()
        # fp16 is not a tokenizer option, so it is not passed here.
        tokenizer = AutoTokenizer.from_pretrained(codegen_token)
        print('Finished loading')
        return model, tokenizer
    def get_hidden_state(self, decoded_program=None, model=None, tokenizer=None, device=None):
        nl_replacement = '\n'
        if not isinstance(decoded_program, str):
            decoded_program = " ".join(decoded_program)
        # Truncate very long programs (a character-level cap, not a token cap).
        if len(decoded_program) > 2048:
            decoded_program = decoded_program[:2048]
        # Restore whitespace from the '#TAB#' / '#NL#' placeholders.
        decoded_program = decoded_program.replace(
            '#TAB#', '\t').replace('#NL#', nl_replacement)
        input_ids = tokenizer(
            decoded_program, return_tensors='pt').input_ids.to(device)
        # Token id 198 is '\n' in the GPT-2-style CodeGen vocabulary.
        nl_indices = torch.where(input_ids == 198)
        try:
            with torch.no_grad():
                outputs = model(input_ids=input_ids)
        except Exception:
            return None
        # With output_hidden_states=True, hidden_states holds the embedding
        # output followed by one tensor per transformer layer.
        hidden_states = outputs.hidden_states
        attention_hidden_states = hidden_states[1:]
        final_attention_states = attention_hidden_states[-1]
        # Gather the final-layer states at every newline position.
        nl_final_attention_states = final_attention_states[torch.arange(
            final_attention_states.size(0)), nl_indices[1]]
        # TODO: project the hidden states to a fixed size (1024).
        return nl_final_attention_states, len(nl_indices[1])
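
# Hypothetical end-to-end sketch tying the pieces above together. The data
# directory, batch size, and the '350M' checkpoint size are assumptions made
# for illustration; they are not fixed anywhere in the original file.
def run_example(data_root='./data'):
    from torch.utils.data import DataLoader
    codegen_pass = CodeGenPass()
    model, tokenizer = codegen_pass.setup_model('350M')
    dataset = VoltronDataset(data_root)
    loader = DataLoader(dataset, batch_size=4,
                        collate_fn=codegen_pass.collate_batch)
    for padded_batch in loader:
        # Decode the first padded sequence back to text, then pull its
        # newline-position hidden states from the final layer.
        program_text = tokenizer.decode(padded_batch[0])
        result = codegen_pass.get_hidden_state(
            program_text, model, tokenizer, model.device)
        if result is not None:
            nl_states, num_newlines = result
            print(nl_states.shape, num_newlines)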

if __name__ == '__main__':
    codegen_trainer = CodeGenPass()
    # '350M' and the sample snippet below are placeholders for a smoke test.
    model, tokenizer = codegen_trainer.setup_model('350M')
    nl_final_attention_states = codegen_trainer.get_hidden_state(
        'def add(x, y):#NL##TAB#return x + y#NL#', model, tokenizer, model.device)
    print('\n\n\n' + 'done\n\n\n')