diff --git a/models/graphcnn.py b/models/graphcnn.py index f5de7df..ab17f4f 100644 --- a/models/graphcnn.py +++ b/models/graphcnn.py @@ -16,8 +16,8 @@ def __init__(self, num_layers, num_mlp_layers, input_dim, hidden_dim, output_dim output_dim: number of classes for prediction final_dropout: dropout ratio on the final linear layer learn_eps: If True, learn epsilon to distinguish center nodes from neighboring nodes. If False, aggregate neighbors and center nodes altogether. - neighbor_pooling_type: how to aggregate neighbors (mean, average, or max) - graph_pooling_type: how to aggregate entire nodes in a graph (mean, average) + neighbor_pooling_type: how to aggregate neighbors (sum, average, or max) + graph_pooling_type: how to aggregate entire nodes in a graph (sum, average) device: which device to use '''