# prepro_util.py
class SquadExample(object):
    """A single training/test example for SQuAD-style question answering."""

    def __init__(self,
                 qas_id,
                 question_text,
                 doc_tokens,
                 orig_answer_text=None,
                 all_answers=None,
                 start_position=None,
                 end_position=None,
                 switch=None):
        self.qas_id = qas_id
        self.question_text = question_text
        self.doc_tokens = doc_tokens
        self.orig_answer_text = orig_answer_text
        self.all_answers = all_answers
        self.start_position = start_position
        self.end_position = end_position
        self.switch = switch

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        return "question: " + self.question_text
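# Illustrative usage sketch (added for clarity; not part of the original file).
# The field values below are hypothetical and only show how one question/
# paragraph pair is typically wrapped in a SquadExample.
def _make_example_squad_example():
    return SquadExample(
        qas_id="example-0",
        question_text="Who wrote Hamlet?",
        doc_tokens=["Hamlet", "was", "written", "by", "William", "Shakespeare", "."],
        orig_answer_text="William Shakespeare",
        all_answers=["William Shakespeare"],
        start_position=4,   # word index of "William"
        end_position=5,     # word index of "Shakespeare"
        switch=0)           # 0 = span answer (1 = yes, 2 = no; see detect_span)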
class InputFeatures(object):
    def __init__(self,
                 unique_id,
                 example_index,
                 doc_span_index,
                 doc_tokens,
                 tokens,
                 token_to_orig_map,
                 token_is_max_context,
                 input_ids,
                 input_mask,
                 segment_ids,
                 start_position=None,
                 end_position=None,
                 switch=None,
                 answer_mask=None):
        self.unique_id = unique_id
        self.example_index = example_index
        self.doc_span_index = doc_span_index
        self.doc_tokens = doc_tokens
        self.tokens = tokens
        self.token_to_orig_map = token_to_orig_map
        self.token_is_max_context = token_is_max_context
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.segment_ids = segment_ids
        self.start_position = start_position
        self.end_position = end_position
        self.switch = switch
        self.answer_mask = answer_mask
def find_span_from_text(context, tokens, answer):
    # Returns every character offset in `context` at which `answer` starts,
    # matching token by token so that multi-token answers are also found.
    if answer.strip() in ['yes', 'no']:
        return [{'text': answer, 'answer_start': 0}]
    assert answer in context

    offset = 0
    spans = []
    scanning = None
    process = []  # debugging trace; collected but never read

    for i, token in enumerate(tokens):
        # Advance `offset` until it is aligned with the current token in the context.
        while context[offset:offset+len(token)] != token:
            offset += 1
            if offset >= len(context):
                break
        if scanning is not None:
            # In the middle of a candidate multi-token match.
            end = offset + len(token)
            if answer.startswith(context[scanning[-1]:end]):
                if context[scanning[-1]:end] == answer:
                    spans.append(scanning[0])
            elif len(context[scanning[-1]:end]) >= len(answer):
                scanning = None
            else:
                scanning = None
        if scanning is None and answer.startswith(token):
            if token == answer:
                spans.append(offset)
            if token != answer:
                # The answer starts with this token but continues past it.
                scanning = [offset]
        offset += len(token)
        if offset >= len(context):
            break
        process.append((token, offset, scanning, spans))

    answers = []
    for span in spans:
        if context[span:span+len(answer)] != answer:
            print(context[span:span+len(answer)], answer)
            print(context)
            assert False
        answers.append({'text': answer, 'answer_start': span})

    #if len(answers) == 0:
    #    print("*" * 30)
    #    print(context, answer)
    return answers
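# Illustrative usage sketch (added for clarity; not part of the original file):
# with whitespace tokens, every occurrence of the answer string is returned
# with its character offset.
def _example_find_span_from_text():
    context = "Paris is the capital of France. Paris is large."
    tokens = context.split()
    # Expected: [{'text': 'Paris', 'answer_start': 0},
    #            {'text': 'Paris', 'answer_start': 32}]
    return find_span_from_text(context, tokens, "Paris")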
def detect_span(_answers, context, doc_tokens, char_to_word_offset):
    """Converts character-level answer annotations into word-level start/end
    positions. Returns (orig_answer_texts, switches, start_positions,
    end_positions), where switch is 0 for a span answer, 1 for "yes", 2 for "no"."""
    orig_answer_texts = []
    start_positions = []
    end_positions = []
    switches = []

    # If the annotations carry no character offsets, recover them by searching
    # the answer text in the context.
    if 'answer_start' not in _answers[0]:
        answers = []
        for answer in _answers:
            answers += find_span_from_text(context, doc_tokens, answer['text'])
    else:
        answers = _answers

    for answer in answers:
        orig_answer_text = answer["text"]
        answer_offset = answer["answer_start"]
        answer_length = len(orig_answer_text)

        if orig_answer_text in ["yes", "no"]:
            start_position, end_position = 0, 0
            switch = 1 if orig_answer_text == "yes" else 2
        else:
            switch = 0
            start_position = char_to_word_offset[answer_offset]
            end_position = char_to_word_offset[answer_offset + answer_length - 1]

        # Only add answers where the text can be exactly recovered from the
        # document. If this CAN'T happen it's likely due to weird Unicode
        # stuff so we will just skip the example.
        #
        # Note that this means for training mode, every example is NOT
        # guaranteed to be preserved.
        #actual_text = " ".join(doc_tokens[start_position:(end_position + 1)])
        #cleaned_answer_text = " ".join(
        #    tokenization.whitespace_tokenize(orig_answer_text))
        #if actual_text.replace(' ', '').find(cleaned_answer_text.replace(' ', '')) == -1:
        #    print("Could not find answer: '%s' vs. '%s'" % (actual_text, cleaned_answer_text))

        orig_answer_texts.append(orig_answer_text)
        start_positions.append(start_position)
        end_positions.append(end_position)
        switches.append(switch)

    return orig_answer_texts, switches, start_positions, end_positions
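# Hedged end-to-end sketch (illustrative only). The char_to_word_offset map is
# built here the way SQuAD-style preprocessing usually builds it; that
# construction is an assumption about the caller, not something defined in
# this file.
if __name__ == "__main__":
    context = "Paris is the capital of France."
    doc_tokens = []
    char_to_word_offset = []
    prev_is_whitespace = True
    for c in context:
        if c in " \t\r\n":
            prev_is_whitespace = True
        else:
            if prev_is_whitespace:
                doc_tokens.append(c)
            else:
                doc_tokens[-1] += c
            prev_is_whitespace = False
        char_to_word_offset.append(len(doc_tokens) - 1)

    # No 'answer_start' in the annotation, so detect_span recovers the span
    # from the raw text via find_span_from_text.
    texts, switches, starts, ends = detect_span(
        [{"text": "Paris"}], context, doc_tokens, char_to_word_offset)
    print(texts, switches, starts, ends)  # ['Paris'] [0] [0] [0]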