-
Notifications
You must be signed in to change notification settings - Fork 204
/
Copy pathgraph2seq.py
98 lines (95 loc) · 2.61 KB
/
graph2seq.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
from graph4nlp.pytorch.models.graph2seq import Graph2Seq
# from examples.pytorch.rgcn.rgcn import RGCN
from graph4nlp.pytorch.modules.graph_embedding_learning.rgcn import RGCN
class RGCNGraph2Seq(Graph2Seq):
    """Graph2Seq variant whose GNN encoder is an RGCN.

    The constructor interface is identical to the parent ``Graph2Seq``;
    every argument is forwarded unchanged. The only specialization is
    :meth:`_build_gnn_encoder`, which installs an :class:`RGCN` module
    as ``self.gnn_encoder`` instead of the parent's default encoder.
    """

    def __init__(
        self,
        vocab_model,
        emb_input_size,
        emb_hidden_size,
        embedding_style,
        graph_construction_name,
        gnn_direction_option,
        gnn_input_size,
        gnn_hidden_size,
        gnn_output_size,
        gnn,
        gnn_num_layers,
        dec_hidden_size,
        share_vocab=False,
        gnn_feat_drop=0,
        gnn_attn_drop=0,
        emb_fix_word_emb=False,
        emb_fix_bert_emb=False,
        emb_word_dropout=0,
        emb_rnn_dropout=0,
        dec_max_decoder_step=50,
        dec_use_copy=False,
        dec_use_coverage=False,
        dec_graph_pooling_strategy=None,
        dec_rnn_type="lstm",
        dec_tgt_emb_as_output_layer=False,
        dec_teacher_forcing_rate=1,
        dec_attention_type="uniform",
        dec_fuse_strategy="average",
        dec_node_type_num=None,
        dec_dropout=0,
        **kwargs
    ):
        # Collect every positional argument in the parent's declared order,
        # then forward the whole bundle untouched. Keeping the tuple explicit
        # makes the forwarding order easy to audit against Graph2Seq.__init__.
        forwarded = (
            vocab_model,
            emb_input_size,
            emb_hidden_size,
            embedding_style,
            graph_construction_name,
            gnn_direction_option,
            gnn_input_size,
            gnn_hidden_size,
            gnn_output_size,
            gnn,
            gnn_num_layers,
            dec_hidden_size,
            share_vocab,
            gnn_feat_drop,
            gnn_attn_drop,
            emb_fix_word_emb,
            emb_fix_bert_emb,
            emb_word_dropout,
            emb_rnn_dropout,
            dec_max_decoder_step,
            dec_use_copy,
            dec_use_coverage,
            dec_graph_pooling_strategy,
            dec_rnn_type,
            dec_tgt_emb_as_output_layer,
            dec_teacher_forcing_rate,
            dec_attention_type,
            dec_fuse_strategy,
            dec_node_type_num,
            dec_dropout,
        )
        super().__init__(*forwarded, **kwargs)

    def _build_gnn_encoder(
        self,
        gnn,
        num_layers,
        input_size,
        hidden_size,
        output_size,
        direction_option,
        feats_dropout,
        gnn_num_rels=80,
        gnn_num_bases=4,
        **kwargs
    ):
        """Build the RGCN graph encoder and store it as ``self.gnn_encoder``.

        NOTE(review): ``gnn_num_bases`` is accepted but never forwarded — the
        ``num_bases=`` argument to RGCN was commented out upstream. Confirm
        whether basis decomposition should be enabled before relying on it.
        The ``gnn`` argument is likewise unused here (the encoder type is
        fixed to RGCN by this subclass).
        """
        # Keyword options gathered in one place before construction.
        encoder_options = dict(
            num_rels=gnn_num_rels,
            direction_option=direction_option,
            feat_drop=feats_dropout,
        )
        self.gnn_encoder = RGCN(
            num_layers,
            input_size,
            hidden_size,
            output_size,
            **encoder_options,
        )