-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathq2_algorithm.py
146 lines (113 loc) · 4.97 KB
/
q2_algorithm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
WeChat: cstutorcs
QQ: 749389476
Email: tutorcs@163.com
import typing as T
from math import inf
import torch
from torch.nn.functional import pad
from torch import Tensor
import einops
def is_projective(heads: T.Iterable[int]) -> bool:
    """
    Determines whether the dependency tree for a sentence is projective.

    A tree is treated as projective when no two of its arcs cross once the
    words (with ROOT in front) are laid out in sentence order. Two arcs cross
    when exactly one endpoint of one arc lies strictly between the endpoints
    of the other; arcs that merely share an endpoint or are nested do not
    cross. The arc leaving ROOT takes part in the check like any other arc.

    Args:
        heads: The indices of the heads of the words in sentence. Since ROOT
            has no head, it is not expected to be part of the input, but the
            index values in heads are such that ROOT is assumed in the
            starting (zeroth) position. See the examples below. Any iterable
            is accepted; it is consumed exactly once.

    Returns:
        True if and only if the tree represented by the input is
        projective. An empty input is trivially projective.

    Examples:
        The projective tree from the assignment handout:
        >>> is_projective([2, 5, 4, 2, 0, 7, 5, 7])
        True

        The non-projective tree from the assignment handout:
        >>> is_projective([2, 0, 2, 2, 6, 3, 6])
        False
    """
    # Word k (1-based, because ROOT sits at position 0) hangs off heads[k - 1].
    # Only the positions an arc covers matter here, so store each arc as an
    # ordered (left, right) pair and forget its direction.
    spans = [
        (min(head, dep), max(head, dep))
        for dep, head in enumerate(heads, start=1)
    ]

    # Compare every unordered pair of arcs once; stop at the first crossing.
    for idx, (left, right) in enumerate(spans):
        for other_left, other_right in spans[idx + 1:]:
            if (left < other_left < right < other_right
                    or other_left < left < other_right < right):
                return False
    return True
def is_single_root(heads: Tensor, lengths: Tensor) -> Tensor:
    """
    Determines whether the selected arcs for a sentence constitute a tree with
    a single root word.

    Remember that index 0 indicates the ROOT node. A tree with "a single root
    word" has exactly one outgoing edge from ROOT.

    As in the examples below, `heads` carries no column for ROOT itself:
    column j holds the head of word j + 1, so a sentence of 8 words occupies
    8 columns. Columns beyond a sentence's length are padding and are ignored.

    A sentence is accepted when all of the following hold for its real words:
      * exactly one word has head 0 (ROOT);
      * every head index lies in [0, length], i.e. is ROOT or a real word;
      * following head links from any word eventually arrives at ROOT (so
        there are no cycles, including self-loops).

    Args:
        heads (Tensor): a Tensor of dimensions (batch_sz, sent_len) and dtype
            int where the entry at index (b, i) indicates the index of the
            predicted head for vertex i for input b in the batch
        lengths (Tensor): a Tensor of dimensions (batch_sz,) and dtype int
            where each element indicates the number of words (this doesn't
            include ROOT) in the corresponding sentence.

    Returns:
        A Tensor of dtype bool and dimensions (batch_sz,) where the value
        for each element is True if and only if the corresponding arcs
        constitute a single-root-word tree as defined above

    Examples:
        Valid trees from the assignment handout:
        >>> is_single_root(torch.tensor([[2, 5, 4, 2, 0, 7, 5, 7],
        ...                              [2, 0, 2, 2, 6, 3, 6, 0]]),
        ...                torch.tensor([8, 7]))
        tensor([True, True])

        Invalid trees (the first has a cycle; the second has multiple roots):
        >>> is_single_root(torch.tensor([[2, 5, 4, 2, 0, 8, 6, 7],
        ...                              [2, 0, 2, 2, 6, 3, 6, 0]]),
        ...                torch.tensor([8, 8]))
        tensor([False, False])
    """
    sent_len = heads.shape[1]

    # Number of real words per sentence, capped at the columns that exist.
    word_count = lengths.to(heads.device).clamp(max=sent_len).unsqueeze(1)
    # 1-based word index of every column, used to tell words from padding.
    positions = torch.arange(1, sent_len + 1, device=heads.device)
    is_word = positions.unsqueeze(0) <= word_count

    # A head must be ROOT (0) or one of the sentence's own words.
    head_ok = (heads >= 0) & (heads <= word_count)

    # Count arcs leaving ROOT, looking at real words only.
    root_arcs = ((heads == 0) & is_word).sum(dim=1)

    # Build an "ancestor" table with a ROOT column in front. ROOT points at
    # itself, and padding / out-of-range entries are redirected to ROOT so
    # they are safe gather indices (they are rejected or masked out below).
    usable = is_word & head_ok
    ancestors = pad(heads.masked_fill(~usable, 0).long(), (1, 0), value=0)

    # Pointer doubling: each gather squares the number of head links followed.
    # A word in a tree is at most sent_len links away from ROOT, and ROOT is a
    # fixed point, so once hops >= sent_len every tree word has landed on 0,
    # while a word on (or leading into) a cycle never does.
    hops = 1
    while hops < sent_len:
        ancestors = ancestors.gather(1, ancestors)
        hops *= 2
    reaches_root = ancestors[:, 1:] == 0

    # Padding columns impose no constraint.
    well_formed = ((head_ok & reaches_root) | ~is_word).all(dim=1)
    tree_single_root = (root_arcs == 1) & well_formed
    return tree_single_root
def _find_cycle(heads: T.List[int]) -> T.List[int]:
    """Return the vertices of one cycle in a head assignment, or [] if none.

    heads[v] is the head chosen for vertex v; vertex 0 (ROOT) is never part
    of a cycle.
    """
    # visited_by[v] records which walk (named after its start vertex, >= 1)
    # first reached v; 0 means "not reached yet". ROOT is pre-marked with a
    # value no walk uses, so every walk that arrives there simply stops.
    visited_by = [0] * len(heads)
    visited_by[0] = -1
    for start in range(1, len(heads)):
        path = []
        vertex = start
        while visited_by[vertex] == 0:
            visited_by[vertex] = start
            path.append(vertex)
            vertex = heads[vertex]
        # Stopping on a vertex marked by this very walk means the walk bit
        # its own tail: the cycle is the part of the path from there on.
        if visited_by[vertex] == start:
            return path[path.index(vertex):]
    return []


def _chu_liu_edmonds(
        weights: T.List[T.List[T.Tuple[int, float]]]) -> T.List[int]:
    """Return the maximum spanning arborescence rooted at vertex 0.

    weights[d][h] is the weight of the arc from head h to dependent d, given
    as a pair that is compared lexicographically and subtracted
    component-wise. Entries with d == h or d == 0 are never used. The result
    lists the chosen head of every vertex, with 0 for the root itself.
    """
    size = len(weights)

    # Greedy step: every non-root vertex takes its best incoming arc.
    heads = [0] * size
    for dep in range(1, size):
        heads[dep] = max(
            (head for head in range(size) if head != dep),
            key=weights[dep].__getitem__,
        )

    cycle = _find_cycle(heads)
    if not cycle:
        return heads

    # Contract the cycle into one vertex. Vertices outside the cycle keep
    # their relative order (so the root stays at index 0) and the contracted
    # vertex is appended at index `merged`.
    in_cycle = set(cycle)
    outside = [v for v in range(size) if v not in in_cycle]
    merged = len(outside)

    contracted = [[(0, 0.0)] * (merged + 1) for _ in range(merged + 1)]
    leave_via = [0] * merged   # cycle vertex acting as head of outside[i]
    enter_at = [0] * merged    # cycle vertex that outside[i] would attach to
    for i, vertex in enumerate(outside):
        for j, other in enumerate(outside):
            contracted[i][j] = weights[vertex][other]

        # Best arc from somewhere in the cycle to this outside vertex.
        leave_via[i] = max(cycle, key=weights[vertex].__getitem__)
        contracted[i][merged] = weights[vertex][leave_via[i]]

        # Best arc from this outside vertex into the cycle. Entering at a
        # cycle member evicts that member's cycle arc, so an entering arc is
        # worth what it gains over the arc it replaces.
        gains = []
        for member in cycle:
            new_penalty, new_score = weights[member][vertex]
            old_penalty, old_score = weights[member][heads[member]]
            gains.append((new_penalty - old_penalty, new_score - old_score))
        best = max(range(len(cycle)), key=gains.__getitem__)
        enter_at[i] = cycle[best]
        contracted[merged][i] = gains[best]

    sub_heads = _chu_liu_edmonds(contracted)

    # Expand the solution of the contracted graph. Cycle members keep their
    # cycle arcs, except the one member where the cycle is entered.
    for i in range(1, merged):
        chosen = sub_heads[i]
        heads[outside[i]] = leave_via[i] if chosen == merged else outside[chosen]
    entering = sub_heads[merged]
    heads[enter_at[entering]] = outside[entering]
    return heads


def mst_single_root(arc_tensor: Tensor, lengths: Tensor) -> Tensor:
    """
    Finds the maximum spanning tree (more technically, arborescence) for the
    given sentences such that each tree has a single root word.

    Remember that index 0 indicates the ROOT node. A tree with "a single root
    word" has exactly one outgoing edge from ROOT.

    Each sentence is decoded independently with the Chu-Liu/Edmonds
    algorithm. The single-root requirement is enforced exactly rather than
    with a large numeric penalty: every arc weight is the pair
    (-1 if the arc leaves ROOT else 0, score), compared lexicographically, so
    the decoder first minimises the number of ROOT arcs (any spanning tree
    needs at least one) and only then maximises the total score.

    Only vertices 0..length of each sentence take part; self-arcs and arcs
    into ROOT are ignored. Scores among those vertices are assumed to be
    finite (no inf/nan).

    Args:
        arc_tensor (Tensor): a Tensor of dimensions (batch_sz, x, y) and dtype
            float where x=y and the entry at index (b, i, j) indicates the
            score for a candidate arc from vertex j to vertex i.
        lengths (Tensor): a Tensor of dimensions (batch_sz,) and dtype int
            where each element indicates the number of words (this doesn't
            include ROOT) in the corresponding sentence.

    Returns:
        A Tensor of dtype int and dimensions (batch_sz, x) where the value at
        index (b, i) indicates the head for vertex i according to the
        maximum spanning tree for the input graph. ROOT (index 0) and
        padding positions beyond a sentence's length are given head 0.

    Examples:
        >>> mst_single_root(torch.tensor(
        ...     [[[0, 0, 0, 0],
        ...       [12, 0, 6, 5],
        ...       [4, 5, 0, 7],
        ...       [4, 7, 8, 0]],
        ...      [[0, 0, 0, 0],
        ...       [1.5, 0, 4, 0],
        ...       [2, 0.1, 0, 0],
        ...       [0, 0, 0, 0]],
        ...      [[0, 0, 0, 0],
        ...       [4, 0, 3, 1],
        ...       [6, 2, 0, 1],
        ...       [1, 1, 8, 0]]]),
        ...     torch.tensor([3, 2, 3]))
        tensor([[0, 0, 3, 1],
                [0, 2, 0, 0],
                [0, 2, 0, 2]])
    """
    batch_sz, num_vertices = arc_tensor.shape[0], arc_tensor.shape[1]

    rows = []
    # The contraction algorithm is inherently sequential, so it is run on
    # plain Python numbers, one sentence at a time.
    for scores, num_words in zip(arc_tensor.detach().tolist(),
                                 lengths.tolist()):
        # ROOT plus the real words, never more than the matrix provides.
        size = min(int(num_words), num_vertices - 1) + 1
        if size < 2:
            # No words to attach: everything keeps head 0.
            rows.append([0] * num_vertices)
            continue

        weights = [
            [(-1 if head == 0 else 0, scores[dep][head])
             for head in range(size)]
            for dep in range(size)
        ]
        heads = _chu_liu_edmonds(weights)
        rows.append(heads + [0] * (num_vertices - size))

    best_arcs = torch.tensor(rows, dtype=torch.long,
                             device=arc_tensor.device)
    # reshape keeps the (batch_sz, x) contract even for an empty batch.
    return best_arcs.reshape(batch_sz, num_vertices)
if __name__ == '__main__':
    # When executed as a script, run every docstring example in this module
    # as a self-test.
    from doctest import testmod
    testmod()