-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdsl.py
448 lines (320 loc) · 9.89 KB
/
dsl.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
from operator import mul, add
from numpy import zeros, ones, empty, array
from functools import partial, reduce
from copy import deepcopy
class Delta:
    """A node of a DSL expression tree.

    Attributes set by the constructor:
        head       -- the payload: returned as-is for a node without tails,
                      called with the evaluated tails otherwise.
        type       -- the node's type (overwritten by `arrow` when one is given).
        tailtypes  -- types of the expected tails, or None.
        tails      -- child nodes (or raw values), or None.
        hiddentail -- optional Delta body evaluated in place of `head`, with
                      `$0`, `$1`, ... standing for the tails.
        arrow      -- `arrow` if given, else `(tuple(tailtypes), type)` when
                      `tailtypes` is truthy, else `type`.
        ishole / isarg -- flags consulted by the equality helpers.
        repr       -- text used by `__repr__` for this node.
        idx        -- initialised to 0.
    """

    def __init__(self, head, type=None, tailtypes=None, tails=None, repr=None, hiddentail=None, arrow=None, ishole=False, isarg=False):
        self.head = head
        self.tails = tails
        self.tailtypes = tailtypes
        self.type = type
        self.ishole = ishole
        self.isarg = isarg
        self.hiddentail = hiddentail
        self.idx = 0

        # An explicit arrow wins and also replaces the node's type.
        if arrow:
            self.arrow = self.type = arrow
        elif tailtypes:
            self.arrow = (tuple(tailtypes), type)
        else:
            self.arrow = type

        if repr is None:
            repr = str(head)
            # Default text of plain string-typed nodes is shown quoted;
            # holes and args keep their bare name.
            if not (self.ishole or self.isarg) and type == str:
                repr = f"'{repr}'"
        self.repr = repr

    def __call__(self):
        """Evaluate the node and return the resulting value."""
        if self.tails is None:
            return self.head

        if self.hiddentail:
            body = deepcopy(self.hiddentail)
            for position, actual in enumerate(self.tails):
                # arg in hiddentail should only match itself for replacement
                placeholder = Delta(f'${position}', isarg=True, type=actual.type)
                body = replace_hidden(body, placeholder, actual)
            return body()

        # Delta tails are evaluated first; anything else is passed through.
        values = [t() if isinstance(t, Delta) else t for t in self.tails]
        return self.head(*values)

    def balance(self):
        """Recursively sort tails by their string form where no tail is terminal.

        Returns self so calls can be chained.
        """
        if not self.tails:
            return self

        if not any(isterminal(t) for t in self.tails):
            self.tails = sorted(self.tails, key=str)

        if self.hiddentail:
            self.hiddentail.balance()

        for child in self.tails:
            child.balance()

        return self

    def __eq__(self, other):
        """Structural equality as defined by `isequal`; non-Deltas are never equal."""
        return isinstance(other, Delta) and isequal(self, other)

    def __hash__(self):
        """Hash of the node's printed form."""
        return hash(repr(self))

    def __repr__(self):
        """Return `repr` for a node without tails, else `(repr tail tail ...)`."""
        has_tails = self.tails is not None and len(self.tails) != 0
        if not has_tails:
            return f'{self.repr}'

        rendered = " ".join(str(t) for t in self.tails)
        return f'({self.repr} {rendered})'
def isterminal(d: Delta) -> bool:
    """Tell whether `d` is fully applied all the way down.

    A node whose `tailtypes` is None is terminal. A node that declares
    `tailtypes` but has no tails (None or empty) is not. Otherwise the node
    is terminal only if every one of its tails is terminal.
    """
    # Identity check: `== None` would dispatch to a custom __eq__.
    if d.tailtypes is None:
        return True

    if not d.tails:
        return False

    return all(isterminal(tail) for tail in d.tails)
def length(tree: Delta) -> int:
    """Return the number of nodes in `tree`; a falsy tree counts as 0."""
    if not tree:
        return 0

    total = 1  # the node itself
    for tail in tree.tails or ():
        total += length(tail)
    return total
def countholes(tree: Delta) -> int:
    """Count the hole nodes in `tree`.

    A falsy tree has none. A hole counts as one and is not descended into.
    """
    if not tree:
        return 0

    if tree.ishole:
        return 1

    holes = 0
    for tail in tree.tails or ():
        holes += countholes(tail)
    return holes
def getdepth(tree: Delta) -> int:
    """Return the number of edges on the longest path from `tree` down to a leaf."""
    if not tree.tails:
        return 0

    # One step to the child plus the deepest child subtree.
    return max(1 + getdepth(tail) for tail in tree.tails)
def getast(expr):
    """Parse an s-expression string into nested lists of token strings.

    Tokens are maximal runs of characters other than '(', ')' and ' '.
    A parenthesised group becomes a nested list. Stray ')' and spaces are
    ignored. A '(' that directly follows a token (no space in between)
    opens a nested group like any other '('.

    If the first parsed item is a group, that group alone is returned, so
    "(f x)" yields ['f', 'x'] rather than [['f', 'x']]; anything following
    that first group at the same level is dropped.

    Returns [] for an empty expression.
    Raises ValueError if a '(' is never closed.
    """
    ast = []
    idx = 0
    end = len(expr)

    while idx < end:
        ch = expr[idx]

        if ch == '(':
            # Find the matching ')' and parse what lies between recursively.
            depth = 1
            close = idx
            while depth != 0:
                close += 1
                if close >= end:
                    raise ValueError(f"unbalanced '(' at position {idx} in {expr!r}")
                if expr[close] == '(':
                    depth += 1
                elif expr[close] == ')':
                    depth -= 1

            ast.append(getast(expr[idx + 1:close]))
            idx = close + 1

        elif ch in ") ":
            idx += 1

        else:
            start = idx
            while idx < end and expr[idx] not in "() ":
                idx += 1
            ast.append(expr[start:idx])
            # idx now rests on the delimiter; leave it for the next
            # iteration so a '(' right after a token is not skipped.

    if ast and isinstance(ast[0], list):
        return ast[0]

    return ast
def todelta(D, ast):
    """Build a Delta tree from a parsed expression.

    `ast` is either a token string or a list whose first item is the head
    and whose remaining items are its arguments (each again a token or a
    list).

    A token starting with '$' becomes a fresh `Delta(token)`. Any other
    token is looked up with `D.index(token)` and fetched with `D[idx]`.

    Raises ValueError if `D.index` returns None for a token, or if `ast`
    is an empty list.
    """
    if not isinstance(ast, list):
        if ast.startswith('$'):
            return Delta(ast)

        if (idx := D.index(ast)) is None:
            raise ValueError(f"what's a {ast}?")

        # NOTE(review): whether D[idx] hands out the registry's own object
        # cannot be told from here. Copy it so that attaching tails below
        # never mutates (or self-nests) an entry of D.
        return deepcopy(D[idx])

    if not ast:
        raise ValueError("cannot build a Delta from an empty expression")

    head, *rest = ast
    node = todelta(D, head)
    args = [todelta(D, item) for item in rest]

    # Without arguments the node keeps whatever tails it already had.
    if args:
        node.tails = args

    return node
def tr(D, expr):
    """Parse the expression string `expr` and turn it into a Delta tree using `D`."""
    parsed = getast(expr)
    return todelta(D, parsed)
def isequal(n1, n2):
    """Compare two nodes structurally.

    If either node is a hole, or both are args, only their types are
    compared. Otherwise the heads must be equal and the tails must match
    pairwise (two nodes without tails match; a node with tails never
    matches one without).
    """
    if n1.ishole or n2.ishole or (n1.isarg and n2.isarg):
        return n1.type == n2.type

    if not (n1.head == n2.head):
        return False

    left, right = n1.tails, n2.tails

    # At least one side has no kids: equal only when both have none.
    if not left or not right:
        return not left and not right

    if len(left) != len(right):
        return False

    return all(isequal(a, b) for a, b in zip(left, right))
def extract_matches(tree, treeholed):
    """
    Walk `tree` and `treeholed` in lockstep and collect, for every hole or
    arg met in `treeholed`, the pair (its head, deep copy of the part of
    `tree` found at the same position).

    Returns a list of such pairs; empty when either input is falsy or
    `tree` has no tails where `treeholed` is not a hole/arg.
    """
    if not tree or not treeholed:
        return []

    if treeholed.ishole or treeholed.isarg:
        return [(treeholed.head, deepcopy(tree))]

    if not tree.tails:
        return []

    # assert len(tree.tails) == len(treeholed.tails)
    pairs = []
    for sub, holedsub in zip(tree.tails, treeholed.tails):
        pairs.extend(extract_matches(sub, holedsub))
    return pairs
def replace_hidden(tree, arg, tail):
    """Substitute `tail` for nodes of `tree` that `isequal` to `arg`.

    If `tree` itself matches, a deep copy of `tail` is returned. Otherwise
    the tree is walked breadth-first and edited in place: within each
    visited node the first matching tail is replaced by `tail` and the
    rest of that node's tails are left alone (neither checked nor
    visited). Returns `tree`.
    """
    if isequal(tree, arg):
        return deepcopy(tail)

    if not tree.tails:
        return tree

    # FIFO worklist walked with a cursor instead of popping from the front.
    pending = [tree]
    cursor = 0
    while cursor < len(pending):
        node = pending[cursor]
        cursor += 1

        if not node.tails:
            continue

        for pos, child in enumerate(node.tails):
            if isequal(child, arg):
                node.tails[pos] = tail
                break
            pending.append(child)

    return tree
def replace(tree, matchbranch, newbranch):
    """Replace occurrences of `matchbranch` in `tree` with copies of `newbranch`.

    The parts of the matched subtree that sit under holes/args of
    `matchbranch` (see `extract_matches`) become the tails of the
    inserted copy. If `tree` itself matches, the new branch is returned;
    otherwise `tree` is edited in place, breadth-first, and returned.
    Replaced subtrees are not searched further.
    """
    if isequal(tree, matchbranch):
        branch = deepcopy(newbranch)
        if not tree.tails:
            return branch

        captured = dict(extract_matches(tree, matchbranch))
        if captured:
            branch.tails = list(captured.values())
        return branch

    pending = [tree]
    cursor = 0
    while cursor < len(pending):
        node = pending[cursor]
        cursor += 1

        if not node.tails:
            continue

        for pos, child in enumerate(node.tails):
            if not isequal(child, matchbranch):
                pending.append(child)
                continue

            branch = deepcopy(newbranch)
            captured = dict(extract_matches(child, matchbranch))
            # Unlike the root case above, tails are overwritten here even
            # when nothing was captured (they become an empty list).
            branch.tails = list(captured.values())
            node.tails[pos] = branch

    return tree
# d.type $ has property of wildcard matching
# making it impossible to modify hiddentails
def freeze(tree: Delta):
    """Turn every hole in `tree` into an arg, in place.

    Recurses into the hidden tail (if any) and into all tails. Returns None.
    """
    if tree.ishole:
        tree.ishole, tree.isarg = False, True

    if tree.hiddentail:
        freeze(tree.hiddentail)

    for tail in tree.tails or ():
        freeze(tail)
def normalize(tree):
    """Inline every hidden tail in `tree`.

    A node carrying a `hiddentail` is replaced by a normalized copy of
    that hidden body in which the placeholders `$0`, `$1`, ... have been
    substituted (via `replace_hidden`) by the node's normalized tails.

    If `tree` itself has a hidden tail the expanded body is returned;
    otherwise `tree` is rewritten in place and returned.
    """
    if tree.hiddentail:
        return _expand_hidden(tree)

    qq = [tree]
    while qq:
        n = qq.pop(0)
        if not n.tails:
            continue

        for idx, child in enumerate(n.tails):
            if child.hiddentail:
                n.tails[idx] = _expand_hidden(child)
            else:
                qq.append(normalize(child))

    return tree


def _expand_hidden(node):
    """Return node.hiddentail (copied, normalized) with node's tails substituted in."""
    body = normalize(deepcopy(node.hiddentail))

    for tidx, tail in enumerate(node.tails or ()):
        # Keep the return value: when the body itself is the placeholder,
        # replace_hidden returns a new object instead of editing in place.
        body = replace_hidden(body, Delta(f'${tidx}', isarg=True, type=tail.type), normalize(tail))

    return body
# not reentrant
def typize(tree: Delta):
    """Replace each hole below `tree` with a numbered `$n` hole, in place.

    Holes are numbered from 0 in breadth-first order. Only tails are
    inspected, so `tree` itself is never replaced. Returns the list of the
    replaced holes' types, in numbering order.
    """
    holetypes = []
    next_arg = 0

    pending = [tree]
    cursor = 0
    while cursor < len(pending):
        node = pending[cursor]
        cursor += 1

        if not node.tails:
            continue

        for pos, child in enumerate(node.tails):
            if not child.ishole:
                pending.append(child)
                continue

            holetype = child.type
            holetypes.append(holetype)
            # need to hole it for next tree replacement
            node.tails[pos] = Delta(f'${next_arg}', ishole=True, type=holetype)
            next_arg += 1

    return holetypes
def showoff_types(tree: Delta):
    """Return the set of `type` values found on `tree` and all nodes below it."""
    types = set()
    stack = [tree]

    while stack:
        node = stack.pop()
        types.add(node.type)
        if node.tails:
            stack.extend(node.tails)

    return types
def comp(n1, n2):
    """Search two trees for a part they have in common.

    Returns a deep copy of `n1` when `n1` and `n2` are equal per
    `isequal`. Otherwise, for each pair of tails (c1, c2) at the same
    position, tries comp(n1, c2), comp(n2, c1) and comp(c1, c2) in that
    order and returns the first truthy result. Returns False when nothing
    is found or either node has no tails.
    """
    if isequal(n1, n2):
        return deepcopy(n1)

    if not n1.tails or not n2.tails:
        return False

    for c1, c2 in zip(n1.tails, n2.tails):
        # Evaluate the candidates one at a time so a hit stops the search
        # without running the remaining recursive comparisons.
        for left, right in ((n1, c2), (n2, c1), (c1, c2)):
            out = comp(left, right)
            if out:
                return out

    return False
def showoff_kids(tree):
    """Yield the string form of every node with tails, parent before children."""
    if tree.tails:
        yield str(tree)

        for tail in tree.tails:
            yield from showoff_kids(tail)
# also known as filter: falsy items are dropped while concatenating
def flatten(xs):
    """Concatenate the truthy items of `xs` into one new list.

    Falsy items (None, empty lists, ...) are skipped. The inputs are not
    modified. Runs in time linear in the total number of elements.
    """
    flat = []
    for x in xs:
        if x:
            flat.extend(x)
    return flat
def alld(tree):
    "list tree and every node below it, parent first, tails left to right"
    nodes = [tree]
    for tail in tree.tails or ():
        nodes.extend(alld(tail))
    return nodes