# metrics.py
# Forked from maciejkula/triplet_recommendations_keras.
import numpy as np
from sklearn.metrics import roc_auc_score
def predict(model, uid, pids, user_layer=6, item_layer=2):
    """
    Score items for a user as the dot product of their latent vectors.

    The latent vectors are read directly from the model's layer weights:
    the first weight array of model.layers[user_layer] is indexed by uid,
    and the first weight array of model.layers[item_layer] is indexed by
    pids.

    Arguments:
    - model exposing `layers`, each with a `get_weights()` method
      (presumably the Keras triplet model -- confirm against the caller)
    - uid: index (or indices) into the user weight array
    - pids: index (or indices) into the item weight array
    - int user_layer: position in model.layers of the layer holding the
      user vectors; defaults to 6, the value previously hard-coded here
    - int item_layer: position in model.layers of the layer holding the
      item vectors; defaults to 2, the value previously hard-coded here

    Returns:
    - np.dot(user_latent, item_latent.T); assuming 2-d (n, dim) weight
      arrays, a scalar uid with a 1-d pids gives one score per item
    """
    user_latent = model.layers[user_layer].get_weights()[0][uid]
    item_latent = model.layers[item_layer].get_weights()[0][pids]
    return np.dot(user_latent, item_latent.T)
def precision_at_k(model, ground_truth, k, user_features=None,
                   item_features=None, num_threads=4):
    """
    Measure precision at k for model and ground truth.

    For every user with at least one positive item (an entry equal to 1 in
    ground_truth), all items are scored with model.predict, the k
    highest-scoring items are selected, and the fraction of those k that are
    positives is recorded. Users without positives are skipped. The result
    is the mean of the recorded fractions.

    Arguments:
    - lightFM instance model
    - sparse matrix ground_truth (no_users, no_items)
    - int k
    - user_features, item_features: passed through unchanged to
      model.predict
    - int num_threads: passed to model.predict; defaults to 4, the value
      previously hard-coded here

    Returns:
    - float precision@k

    Raises:
    - ZeroDivisionError if no user has a positive item (nothing to average)
    """
    ground_truth = ground_truth.tocsr()
    no_users, no_items = ground_truth.shape
    pid_array = np.arange(no_items, dtype=np.int32)

    # Read rows straight from the CSR buffers instead of materialising a
    # one-row sparse matrix per user.
    indptr = ground_truth.indptr
    indices = ground_truth.indices
    data = ground_truth.data

    precisions = []
    for user_id in range(no_users):
        start, stop = indptr[user_id], indptr[user_id + 1]
        true_pids = indices[start:stop][data[start:stop] == 1]

        # Decide whether the user counts before paying for a prediction.
        if true_pids.size == 0:
            continue

        uid_array = np.full(no_items, user_id, dtype=np.int32)
        predictions = model.predict(uid_array, pid_array,
                                    user_features=user_features,
                                    item_features=item_features,
                                    num_threads=num_threads)

        top_k = np.argsort(-predictions)[:k]
        hits = np.count_nonzero(np.isin(top_k, true_pids))
        precisions.append(hits / float(k))

    return sum(precisions) / len(precisions)
def full_auc(model, ground_truth):
    """
    Measure AUC for model and ground truth on all items.

    For each user, the items whose ground_truth entry equals 1 are labelled
    positive and every other item negative; all items are scored with
    predict() and the per-user ROC AUC is computed with roc_auc_score. The
    result is the mean of the per-user values.

    Users with no positive item, or with every item positive, are skipped:
    ROC AUC is undefined when only one class is present.

    Arguments:
    - model accepted by predict()
    - sparse matrix ground_truth (no_users, no_items)

    Returns:
    - float AUC

    Raises:
    - ZeroDivisionError if no user has both a positive and a non-positive
      item (nothing to average)
    """
    ground_truth = ground_truth.tocsr()
    no_users, no_items = ground_truth.shape
    pid_array = np.arange(no_items, dtype=np.int32)

    # Read rows straight from the CSR buffers instead of materialising a
    # one-row sparse matrix per user.
    indptr = ground_truth.indptr
    indices = ground_truth.indices
    data = ground_truth.data

    scores = []
    for user_id in range(no_users):
        start, stop = indptr[user_id], indptr[user_id + 1]
        true_pids = indices[start:stop][data[start:stop] == 1]

        grnd = np.zeros(no_items, dtype=np.int32)
        grnd[true_pids] = 1

        # Both classes must be present for the AUC to be defined; check
        # before scoring so skipped users cost no prediction.
        n_positive = int(grnd.sum())
        if n_positive == 0 or n_positive == no_items:
            continue

        predictions = predict(model, user_id, pid_array)
        scores.append(roc_auc_score(grnd, predictions))

    return sum(scores) / len(scores)