-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy path inference.py
146 lines (131 loc) · 4.78 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import joblib
import pandas as pd
import os
import time
import json
import numpy as np
from aro.model.zone_utils import zone_based_tsp
def get_route_with_station(df, route_id):
    """Return the rows of *df* belonging to one route, depot row first.

    Rows are filtered to ``route_id`` and sorted so that any row with a
    truthy ``is_station`` flag (the depot) precedes the ordinary stops,
    which are then ordered by their ``stop`` value.
    """
    route_rows = df.loc[df["route_id"] == route_id]
    # Descending on is_station floats the depot to the top; ascending on
    # stop keeps the remaining stops in their natural order.
    return route_rows.sort_values(
        by=["is_station", "stop"], ascending=[False, True]
    )
if __name__ == "__main__":
    # Batch inference driver: for every route in the input parquet file,
    # run the zone-based TSP solver with the pre-trained PPM model and
    # write the proposed stop sequences as one JSON file.
    # See example_inference.sh for usage.
    parser = argparse.ArgumentParser(description="inference")
    parser.add_argument(
        "--model_dir", required=True, type=str,
        help="Directory where the PPM model file is"
    )
    parser.add_argument(
        "--ppm_model_fn", required=True, type=str,
        help="File name of the PPM model"
    )
    parser.add_argument(
        "--data_dir", required=True, type=str,
        help="Directory where the parquet file lmc_route_full_16xxxxx.parquet is"
    )
    parser.add_argument(
        "--dist_matrix_dir", required=True, type=str,
        help="Directory of distance_matrix, where files like RouteID_*_raw_w_st.npy are"
    )
    parser.add_argument(
        "--zone_list_dir", required=True, type=str,
        help="Directory of zone_list, where files like RouteID_*_zone_w_st.joblib are"
    )
    parser.add_argument(
        "--output_dir", required=True, type=str, help="Directory of output"
    )
    parser.add_argument(
        "--output_score_dir", required=True, type=str, help="Directory of output"
    )
    parser.add_argument(
        "--data_fn", required=True, type=str,
        help="data file name, e.g. lmc_route_full_16xxxxx.parquet"
    )
    parser.add_argument(
        "--model_score_input_dir", required=True, type=str, help="Directory of model score input"
    )
    parser.add_argument(
        "--actual_seq_fn", required=True, type=str, help="data file name for actual sequences"
    )
    parser.add_argument(
        "--invalid_seq_fn", required=True, type=str, help="data file name for invalid sequences"
    )
    parser.add_argument(
        "--model_apply_input_dir", required=True, type=str, help="Directory of model apply input"
    )
    parser.add_argument(
        "--travel_time_fn", required=True, type=str, help="data file name for travel time"
    )
    parser.add_argument(
        "--debug",
        dest="debug",
        action="store_true",
        default=False,
        help="debug or not",
    )
    args = parser.parse_args()

    df_val = pd.read_parquet(os.path.join(args.data_dir, args.data_fn))
    json_dict = dict()
    stt = time.time()

    # Load the pre-trained PPM zone-sequencing model; fail fast if absent.
    mfn = f"{args.model_dir}/{args.ppm_model_fn}"
    if not os.path.exists(mfn):
        # Was a generic Exception; FileNotFoundError is the accurate type
        # (still caught by any `except Exception` handler).
        raise FileNotFoundError(f'Cannot find {mfn}')
    zone_prob_model = joblib.load(mfn)

    for idx, route_id in enumerate(df_val.route_id.unique()):
        # Depot-first ordering so node 0 of the distance matrix is the station.
        df_route = get_route_with_station(
            df_val,
            route_id,
        )

        # Guard clauses: skip routes whose precomputed artifacts are missing.
        # (Previously a missing file only printed a message and then fell
        # through, reusing an undefined or stale array from the prior route.)
        fn = f"{args.dist_matrix_dir}/{route_id}_raw_w_st.npy"
        if not os.path.exists(fn):
            print(f'distance matrix missing: {fn}')
            continue
        pre_dist_matrix = np.load(fn)

        zfn = f"{args.zone_list_dir}/{route_id}_zone_w_st.joblib"
        if not os.path.exists(zfn):
            print(f'zone files missing: {zfn}')
            continue
        zone_list = joblib.load(zfn)

        cw = [0.25, 0.25, 0.25, 0.25]  # equal cluster weights
        zs_algo = 'ppm'
        # sol_list is a permutation of node indices: sol_list[i] gives the
        # visiting rank of node i (presumably — TODO confirm against
        # zone_based_tsp's contract).
        sol_list = zone_based_tsp(
            pre_dist_matrix,
            zone_list,
            zone_prob_model,
            route_id,
            cluster_weights=cw,
            zone_sort_algo=zs_algo
        )

        # Invert the permutation: rank_list[rank] = node index.
        rank_list = [-1] * len(sol_list)
        for i, rank in enumerate(sol_list):
            try:
                rank_list[rank] = i
            except (IndexError, TypeError) as err:
                # Narrowed from a bare `except:`; dump the offending solution
                # and re-raise with the failing (index, rank) pair.
                print(sol_list)
                raise Exception(f"{i}, {rank}") from err

        df_route["myrank"] = rank_list
        df_route = df_route.sort_values(["stop"])

        # Emit {stop_id: rank} in the expected submission format.
        sequence_dict = {}
        for _, stop_item in df_route.iterrows():
            rank_id = stop_item.myrank
            stop_id = stop_item.stop
            sequence_dict[stop_id] = rank_id
        proposal_dict = {"proposed": sequence_dict}
        json_dict[route_id] = proposal_dict

        if args.debug:
            # Debug mode: process a single route only.
            break
        if (idx + 1) % 10 == 0:
            avg_speed = (time.time() - stt) / (idx + 1)
            print(f'Number of routes done: {idx + 1}, avg speed = {avg_speed:.3f} per route')

    # One timestamped JSON file with every route's proposed sequence.
    ttt = int(time.time())
    args_input_perf_file = f"eval-ppm-rollout-{ttt}.json"
    output_json_file = os.path.join(args.output_dir, args_input_perf_file)
    with open(output_json_file, "w") as fout:
        json.dump(json_dict, fout)