-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathuser_input.py
58 lines (41 loc) · 2.09 KB
/
user_input.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
"""
TEST FROM USER INPUT
"""
import os.path
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from tensorflow import keras
# -- Load the trained LSTM model from model/LSTM_1_model_saved_model --
model_path = os.path.join("model", "LSTM_1_model_saved_model")
model = keras.models.load_model(model_path)

# -- Rebuild the feature scaler --
# The scaler is not loaded from disk: a fresh MinMaxScaler is fitted on the
# seven feature columns of dataset/pre_processed_data.csv.
# NOTE(review): this reproduces the training-time normalization only if the
# training script fitted its scaler on this same data -- confirm.
training_data = pd.read_csv(os.path.join("dataset", "pre_processed_data.csv"))
scaler = MinMaxScaler()
scaler.fit(
    training_data[
        [
            'arrival_hour',
            'arrival_minute',
            'stop_lat',
            'stop_lon',
            'next_lat',
            'next_lon',
            'direction_id',
        ]
    ]
)
# -- Function to get congestion level based on user input --
def predict_congestion(user_input, keras_model=None, fitted_scaler=None):
    """Predict the congestion level for one stop-to-next-stop observation.

    Args:
        user_input: Sequence of exactly seven values in the order
            (arrival_hour, arrival_minute, stop_lat, stop_lon, next_lat,
            next_lon, direction_id). This is the column order the module's
            scaler is fitted on, so callers must follow it.
        keras_model: Object exposing ``predict``. Defaults to the module-level
            ``model``; a different model can be supplied instead.
        fitted_scaler: Already-fitted object exposing ``transform``. Defaults
            to the module-level ``scaler``.

    Returns:
        Element ``[0][0]`` of the model's prediction output.

    Raises:
        ValueError: If ``user_input`` does not contain exactly seven values.
    """
    feature_columns = ['arrival_hour', 'arrival_minute', 'stop_lat', 'stop_lon',
                       'next_lat', 'next_lon', 'direction_id']
    values = list(user_input)
    if len(values) != len(feature_columns):
        raise ValueError(
            f"user_input must contain {len(feature_columns)} values "
            f"({', '.join(feature_columns)}), got {len(values)}"
        )

    # Fall back to the module-level objects only when none were supplied.
    keras_model = model if keras_model is None else keras_model
    fitted_scaler = scaler if fitted_scaler is None else fitted_scaler

    # Give the row the training column names. The scaler was fitted on a
    # DataFrame, and transforming a bare ndarray makes scikit-learn warn that
    # X has no valid feature names.
    user_features = pd.DataFrame([values], columns=feature_columns)
    scaled = np.asarray(fitted_scaler.transform(user_features))

    # Reshape (1, n_features) -> (1, 1, n_features): a 3-D input, presumably
    # (batch, timesteps, features) for the LSTM -- confirm against the model.
    scaled = scaled.reshape((1, scaled.shape[0], scaled.shape[1]))

    predicted_congestion = keras_model.predict(scaled)
    return predicted_congestion[0][0]
# -- Get user input --
arrival_hour = int(input("Enter arrival hour: "))
arrival_minute = int(input("Enter arrival minute: "))
stop_lat = float(input("Enter stop latitude: "))
stop_lon = float(input("Enter stop longitude: "))
next_lat = float(input("Enter next stop latitude: "))
next_lon = float(input("Enter next stop longitude: "))
direction_id = int(input("Enter direction id: "))
# -- Make a prediction --
# The order must match the unpacking order in predict_congestion and the column
# order the scaler was fitted on: arrival time first, then the stop and
# next-stop coordinates, then the direction id. The previous order put the
# coordinates first, so the time and coordinate fields were scaled and fed to
# the model as the wrong features.
user_input = [arrival_hour, arrival_minute, stop_lat, stop_lon, next_lat, next_lon, direction_id]
predicted_congestion = predict_congestion(user_input)
print(f"Predicted Congestion Level: {predicted_congestion}")