-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathMethodGraphBatching.py
34 lines (24 loc) · 949 Bytes
/
MethodGraphBatching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
'''
Concrete method class that, for every node, ranks its top-k neighbors by intimacy score (read from the score matrix S).
'''
# Copyright (c) 2017 Jiawei Zhang <jwzhanggy@gmail.com>
# License: TBD
from method import method
import networkx as nx
class MethodGraphBatching(method):
    """Rank, for every node, its ``k`` highest-scoring neighbors.

    Expected inputs (assigned to ``self.data`` before calling ``run``):

    * ``'S'`` -- score matrix; row ``i`` holds node ``i``'s score towards
      every node. Presumably a 2-D NumPy array (``run`` relies on
      ``ndarray.argsort``) -- TODO confirm against the producer.
    * ``'index_id_map'`` -- dict mapping a row index of ``S`` to a node id.
    """

    # Input payload; must be assigned by the caller before ``run``.
    data = None
    # Number of neighbors kept per node.
    k = 5

    def run(self):
        """Return ``{node_id: [(neighbor_id, score), ...]}`` for every node.

        Each list holds at most ``self.k`` neighbors ordered by descending
        score. A node is never listed as its own neighbor, and ``k <= 0``
        yields empty lists. ``self.data['S']`` is read but not modified.

        Raises:
            KeyError: if a selected neighbor index is missing from
                ``index_id_map``.
        """
        S = self.data['S']
        index_id_dict = self.data['index_id_map']
        user_top_k_neighbor_intimacy_dict = {}
        for node_index, node_id in index_id_dict.items():
            s = S[node_index]
            user_top_k_neighbor_intimacy_dict[node_id] = [
                (index_id_dict[neighbor_index], s[neighbor_index])
                for neighbor_index in self._top_k_neighbor_index(s, node_index)
            ]
        return user_top_k_neighbor_intimacy_dict

    def _top_k_neighbor_index(self, s, node_index):
        """Return indices of the ``self.k`` largest entries of ``s``, excluding ``node_index``."""
        if self.k <= 0:
            # Guard: ``argsort()[-0:]`` is the whole array, not an empty one.
            return []
        # Take one spare candidate so that dropping the node itself still
        # leaves ``k`` neighbors. The row is only read, never written, so the
        # caller's matrix is not mutated (a row of an ndarray is a view).
        candidates = s.argsort()[-(self.k + 1):][::-1]
        return [i for i in candidates if i != node_index][:self.k]