-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinstr_interface.py
140 lines (114 loc) · 5.13 KB
/
instr_interface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import abc
import json
import angr
# Exec Tree
class Node:
    """One node of the execution tree.

    Instances hash by ``addr``; equality is the default identity comparison
    (no ``__eq__`` is defined).
    """

    def __init__(self):
        # Addresses of child nodes (used as keys into ``Instrumentation.execution_tree``).
        self.children = set()
        # Probabilities matched positionally to the iteration order of ``children``.
        self.children_prob = []
        # child addr -> container of hit counts, tested with ``in`` when deciding
        # whether a branch was missed. Populated outside this class; TODO confirm semantics.
        self.max_encounter_child = {}
        # Node address; also the value returned by ``__hash__``.
        self.addr = 0
        # Gates both probability assignment and missed-branch detection.
        self.is_comp = False
        self.visit_count = 1
        # (start, end) address pair rendered in hex by ``__str__``; ``None`` until set.
        self.addr_range = None
        self.led_by = ""

    @staticmethod
    def to_addr(node):
        """Return ``node.addr``, or 0 when ``node`` is falsy (e.g. ``None``)."""
        if node:
            return node.addr
        return 0

    def __hash__(self):
        return self.addr

    def __str__(self):
        """Render a one-line, ``; ``-separated summary of the node."""
        # ``addr_range`` defaults to None; indexing it unconditionally would raise
        # TypeError and break ``dump_execution_tree`` for such nodes.
        if self.addr_range is None:
            addr_range = "None"
        else:
            addr_range = f"{hex(self.addr_range[0])} - {hex(self.addr_range[1])}"
        return (
            f"comp: {self.is_comp}; "
            f"vc: {self.visit_count}; "
            f"children: {self.children}; "
            f"prob: {self.children_prob}; "
            f"led_by: {self.led_by}; "
            f"addr_range: {addr_range}"
        )
class UnknownNode(Node):
    """``Node`` subclass that adds no attributes or behavior of its own.

    NOTE(review): presumably distinguishes nodes of unknown origin by type
    alone; confirm against the code that constructs it.
    """
# Multiplier applied to every probability written by
# ``Instrumentation.assign_prob`` (child ratios, the padding entries, and the
# uniform 1.0 fallback). With the value 1 the probabilities are unscaled.
SCALE = 1
class Instrumentation(abc.ABC):
    """Execution-tree bookkeeping over corpus traces, used to rank missed branches.

    ``build_execution_tree`` is a no-op in this base class; presumably subclasses
    override it to populate ``execution_tree`` and ``corpus_traces``.
    """

    def __init__(self, executor):
        self.executor = executor
        self.execution_tree = {}  # addr -> Node
        # testcase filename -> ordered sequence of the Nodes that testcase visited
        self.corpus_traces = {}
        self.dfs_visited_nodes = set()  # not read by this class
        # (testcase_fn, flip_start, flip_end, nth) tuples
        self.unsolvable = set()
        self.solved = set()  # same tuple shape as ``unsolvable``
        self.basic_block = {}  # BB start => size
        self.__get_basic_block_size()

    def __get_basic_block_size(self):
        """Record start address -> size for every block in every function angr's CFGFast finds."""
        p = angr.Project(self.executor.uninstrumented_path, load_options={'auto_load_libs': False})
        cfg = p.analyses.CFGFast()
        for key in cfg.kb.functions:
            for bb in cfg.kb.functions[key].blocks:
                self.basic_block[bb.addr] = bb.size

    def build_execution_tree(self, new_testcase_filenames: "list[str]"):
        """Incorporate new testcases into the execution tree; a no-op in this base class."""

    def dump_execution_tree(self):
        """Print the execution tree as JSON, mapping hex address -> ``str(node)``."""
        print(json.dumps({hex(x): str(self.execution_tree[x]) for x in self.execution_tree}, sort_keys=True, indent=4))

    def assign_prob(self):
        """(Re)compute ``children_prob`` for every node in the execution tree.

        Each child gets ``visit_count * SCALE / total``, where ``total`` is one plus
        the children's summed visit counts. The list is then padded to at least two
        entries with ``3 * SCALE / total``. If the node is not ``is_comp``, or
        ``total`` is below 30, every entry is replaced with ``1.0 * SCALE``.
        """
        for current_node in self.execution_tree.values():
            # Start from an empty list: without this, a second call appends after the
            # old entries, and ``__get_prob`` (which indexes by child position) would
            # keep returning the first call's stale values.
            current_node.children_prob = []
            child_nodes = [self.execution_tree[child_addr] for child_addr in current_node.children]
            # Starting at 1 prevents div by 0. todo: this causes left + right != 1
            sum_of_children = sum((child.visit_count for child in child_nodes), 1)
            for child_node in child_nodes:
                current_node.children_prob.append(child_node.visit_count * SCALE / sum_of_children)
            while len(current_node.children_prob) < 2:
                current_node.children_prob.append(3 * SCALE / sum_of_children)
            if not current_node.is_comp or sum_of_children < 30:
                current_node.children_prob = [1.0 * SCALE] * len(current_node.children_prob)

    def __get_prob(self, parent, child):
        """Return the stored probability of the ``parent`` -> ``child`` edge.

        Raises:
            AssertionError: if ``child`` is not among ``parent``'s children.
        """
        parent_node = self.execution_tree[parent]
        child_node_addr = self.execution_tree[child].addr
        for k, _child_addr in enumerate(parent_node.children):
            if _child_addr == child_node_addr:
                return parent_node.children_prob[k]
        print(f"[Exec] {parent} {child} not in execution tree")
        # Explicit raise rather than ``assert False``: asserts are stripped under
        # ``python -O``, which would silently return None to the caller.
        raise AssertionError(f"[Exec] {parent} {child} not in execution tree")

    def __is_branch_missed(self, parent_addr, child_addr, nth=0):
        """Return True if the branch at ``parent_addr`` counts as missed on visit ``nth``.

        Only ``is_comp`` nodes qualify. Such a node is missed when it has fewer
        than two recorded children, or when ``nth + 1`` is not in
        ``max_encounter_child[child_addr]``.
        """
        parent_node = self.execution_tree[parent_addr]
        # Check ``is_comp`` first so non-comp nodes never touch
        # ``max_encounter_child``; the answer for them is False regardless.
        if not parent_node.is_comp:
            return False
        hit_count = nth + 1
        return (
            len(parent_node.children) < 2
            or hit_count not in parent_node.max_encounter_child[child_addr]
        )

    def __should_i_solve(self, testcase_fn, flip_pcs, nth=0):
        """Return True if this (testcase, flip range, nth) is neither unsolvable nor solved."""
        return ((testcase_fn, flip_pcs[0], flip_pcs[1], nth) not in self.unsolvable) and \
            ((testcase_fn, flip_pcs[0], flip_pcs[1], nth) not in self.solved)

    def add_unsolvable_path(self, testcase_fn, flip_pcs, nth=0):
        """Mark (testcase, flip range, nth) as unsolvable so it is no longer proposed."""
        self.unsolvable.add((testcase_fn, flip_pcs[0], flip_pcs[1], nth))

    def add_solved_path(self, testcase_fn, flip_pcs, nth=0):
        """Mark (testcase, flip range, nth) as solved so it is no longer proposed."""
        self.solved.add((testcase_fn, flip_pcs[0], flip_pcs[1], nth))

    def get_sorted_missed_path(self, num=10):
        """Return up to ``num`` candidate branch flips, least probable first.

        Each trace is walked edge by edge while carrying ``prob``, the product of
        the stored probabilities of the edges taken so far. A candidate is recorded
        at a node when all of the following hold:

        - ``__is_branch_missed`` reports the branch as missed;
        - this is the node's first or second visit in the trace;
        - the flip is neither solved nor unsolvable.

        Each candidate is a dict with these keys:

        - ``flip``: the node's ``addr_range``;
        - ``prob``: ``prob`` times the node's last ``children_prob`` entry;
        - ``fn``: the trace's key in ``corpus_traces``;
        - ``nth``: the zero-based visit index of the node within that trace.
        """
        missed_paths = []
        for filename, trace in self.corpus_traces.items():
            hit_counts = {}
            prob = 1
            for node, next_node in zip(trace, trace[1:]):
                hit_counts[node] = hit_counts.get(node, 0) + 1
                nth = hit_counts[node] - 1
                if self.__is_branch_missed(node.addr, next_node.addr, nth=nth):
                    flip_it = node.addr_range
                    # todo: find out why nth
                    if nth < 2 and self.__should_i_solve(filename, flip_it, nth=nth):
                        missed_paths.append({
                            "flip": flip_it,
                            "prob": prob * node.children_prob[-1],
                            "fn": filename,
                            "nth": nth
                        })
                # The trace always takes node -> next_node, so fold that edge in
                # even when the flip above was skipped as solved/unsolvable.
                prob *= self.__get_prob(node.addr, next_node.addr)
        return sorted(missed_paths, key=lambda x: x["prob"])[:num]