-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathMLR_test.py
64 lines (47 loc) · 1.98 KB
/
MLR_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
'''
Created on May 31, 2016
@author: qwaider
'''
from __future__ import division
import sys, getopt, os
from __main__ import config
import subprocess
import numpy as np
import errno
def mkdir_p(path):
    """Create directory ``path`` together with any missing parents.

    An ``OSError`` raised by ``os.makedirs`` is suppressed only when its
    errno is ``EEXIST`` and ``path`` is a directory; every other ``OSError``
    (including ``EEXIST`` for a non-directory) is re-raised unchanged.
    """
    try:
        os.makedirs(path)
    except OSError as err:
        already_a_directory = err.errno == errno.EEXIST and os.path.isdir(path)
        if not already_a_directory:
            raise
def main(testFile, modelsdir, outfile):
    """Score a test file with a saved RankLib model and evaluate the ranking.

    Steps, in order:
      1. read one integer label per line of ``testFile`` (the first
         space-separated token) and reshape them to ``CHANNELS`` columns;
      2. run RankLib with the model ``modelsdir + "/MLR.model"`` so that it
         writes its scores to ``BASEDIR + "/temp/MLR/Scores"``;
      3. read the third tab-separated field of every score line, rounded to
         three decimals, and reshape to ``CHANNELS`` columns;
      4. turn the scores into ranks with ``rank_array.main`` and save them to
         ``outfile`` as integers;
      5. hand labels and predicted ranks to ``compute_NDCG.main``.

    Reads the ``CHANNELS``, ``RANKLIBDIR``, ``BASEDIR`` and ``use_ties``
    entries of the module-level ``config`` mapping.

    Args:
        testFile: path of the file to score.
        modelsdir: directory that contains ``MLR.model``.
        outfile: destination passed to ``np.savetxt`` for the rank matrix.

    Raises:
        subprocess.CalledProcessError: if the RankLib process exits with a
            non-zero status.
        ValueError: if a label/score cannot be parsed, or the number of lines
            is not a multiple of ``CHANNELS``.
    """
    # Single-argument parenthesised form prints identically on Python 2 and 3.
    print("\nTest \"" + modelsdir + "\" on \"" + testFile + "\"...")

    channels = int(config['CHANNELS'])
    # One variable for the score file so the path RankLib writes and the path
    # read back below can never drift apart.
    scores_path = config['BASEDIR'] + "/temp/MLR/Scores"

    # load the labels
    with open(testFile, 'r') as fin:
        labels = [int(line.split(' ')[0]) for line in fin]
    # dtype=float keeps the float64 matrix the previous np.append-based
    # accumulation produced.
    labels_rank_mat = np.array(labels, dtype=float).reshape(-1, channels)

    # predict the ranks
    # Argument list + no shell: paths containing spaces or shell
    # metacharacters are passed to java verbatim.
    command = [
        "java", "-jar", config['RANKLIBDIR'] + "/RankLib-2.6.jar",
        "-load", modelsdir + "/MLR.model",
        "-rank", testFile,
        "-norm", "zscore",
        "-metric2T", "NDCG@" + str(config['CHANNELS']),
        "-score", scores_path,
    ]
    # stdout is captured and discarded; a non-zero exit raises instead of
    # silently continuing with a stale or missing score file.
    subprocess.check_output(command)

    # load the predicted scores into a matrix
    with open(scores_path, 'r') as fin:
        # Round through the "%.3f" text form, exactly as before.
        scores = [
            float("{0:.3f}".format(float(line.split('\t')[2])))
            for line in fin
        ]
    scores_mat = np.array(scores, dtype=float).reshape(-1, channels)

    # convert ranking scores into ranks
    import rank_array
    pred_rank_mat = rank_array.main(scores_mat, config['use_ties'])
    np.savetxt(outfile, pred_rank_mat, fmt='%d')

    # ----------------- Compute NDCG
    import compute_NDCG
    compute_NDCG.main(labels_rank_mat, pred_rank_mat)
if __name__ == "__main__":
    # Command line: <test file> <models directory> <output file>
    test_file_arg = sys.argv[1]
    models_dir_arg = sys.argv[2]
    out_file_arg = sys.argv[3]
    main(test_file_arg, models_dir_arg, out_file_arg)