-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathBig_mart_sales_prediction.py
135 lines (86 loc) · 5.01 KB
/
Big_mart_sales_prediction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# -*- coding: utf-8 -*-
"""
Created on Wed Jun 5 19:27:35 2024
@author: prachet
"""
import numpy as np
import pickle
import streamlit as st
@st.cache_resource
def _load_pickle(path):
    """Unpickle and return the object stored in the file at *path*.

    Streamlit re-executes this whole script on every widget interaction.
    ``st.cache_resource`` keeps one loaded object per path across reruns, so
    the model files are read from disk only once instead of on every click.

    Note: ``pickle.load`` can execute arbitrary code while loading, so only
    use this for the model/encoder files shipped with the app, never for
    user-supplied files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# Trained artifacts used by the prediction functions below.
model1 = _load_pickle("model.pkl")
# model2 = _load_pickle("model2.pkl")  # disabled (Random Forest)
model3 = _load_pickle("model3.pkl")
encoders = _load_pickle("encoders.pkl")
# Raw feature order expected by every model; callers pass values in this order.
_FEATURE_COLUMNS = (
    'Item_Identifier', 'Item_Weight', 'Item_Fat_Content', 'Item_Visibility',
    'Item_Type', 'Item_MRP', 'Outlet_Identifier', 'Outlet_Establishment_Year',
    'Outlet_Size', 'Outlet_Location_Type', 'Outlet_Type',
)


def _encode_features(input_data):
    """Encode one raw input row into a single-row 2-D numpy array.

    Each value whose column has an entry in the module-level ``encoders``
    mapping is replaced by ``encoders[col].transform([value])[0]``. Other
    values are passed through unchanged.

    Args:
        input_data: a sequence with one raw value per name in
            ``_FEATURE_COLUMNS``, in that order.

    Returns:
        numpy.ndarray of shape ``(1, len(_FEATURE_COLUMNS))``.

    Raises:
        ValueError: if ``input_data`` does not have exactly one value per
            feature column, or if an encoder rejects a value (presumably a
            category it was not fitted on).
    """
    encoded = list(input_data)
    if len(encoded) != len(_FEATURE_COLUMNS):
        raise ValueError(
            f"expected {len(_FEATURE_COLUMNS)} feature values, got {len(encoded)}"
        )
    for i, col in enumerate(_FEATURE_COLUMNS):
        if col in encoders:
            try:
                encoded[i] = encoders[col].transform([encoded[i]])[0]
            except ValueError as err:
                # Name the offending column so the UI can show a useful message.
                raise ValueError(f"{col}: cannot encode {encoded[i]!r} ({err})") from err
    return np.array(encoded).reshape(1, -1)


def _predict_with(model, input_data):
    """Return the single prediction *model* makes for one raw input row."""
    return model.predict(_encode_features(input_data))[0]


def big_mart_sales_prediction1(input_data):
    """Predict sales for one raw input row with ``model1`` (shown as XGBoost in the UI).

    Raises:
        ValueError: if the row has the wrong length or a value cannot be encoded.
    """
    return _predict_with(model1, input_data)


# model2 (Random Forest) is disabled because model2.pkl is not loaded.
# To re-enable it, load model2 and uncomment:
# def big_mart_sales_prediction2(input_data):
#     return _predict_with(model2, input_data)


def big_mart_sales_prediction3(input_data):
    """Predict sales for one raw input row with ``model3`` (shown as Decision Tree in the UI).

    Raises:
        ValueError: if the row has the wrong length or a value cannot be encoded.
    """
    return _predict_with(model3, input_data)
def _collect_inputs():
    """Render the input widgets in a two-column layout and return their values.

    Returns:
        tuple of the 11 raw feature values in the order the prediction
        functions expect: Item_Identifier, Item_Weight, Item_Fat_Content,
        Item_Visibility, Item_Type, Item_MRP, Outlet_Identifier,
        Outlet_Establishment_Year, Outlet_Size, Outlet_Location_Type,
        Outlet_Type.
    """
    col1, col2 = st.columns(2)
    with col1:
        # Strip stray spaces so they don't turn a known identifier into an
        # unrecognised label.
        item_identifier = st.text_input("Item Identifier").strip()
    with col2:
        item_weight = st.number_input("Item Weight", min_value=0.0)
    with col1:
        item_fat_content = st.selectbox('Item Fat Content', ('Low Fat', 'Regular'))
    with col2:
        item_visibility = st.number_input("Item Visibility", min_value=0.0)
    with col1:
        item_type = st.selectbox("Item Type", ('Fruits and Vegetables', 'Snack Foods', 'Household', 'Frozen Foods', 'Dairy', 'Canned', 'Baking Goods', 'Health and Hygiene', 'Soft Drinks', 'Meat', 'Breads', 'Hard Drinks', 'Starchy Foods', 'Breakfast', 'Seafood', 'Others'))
    with col2:
        item_mrp = st.number_input("Item MRP", min_value=0.0)
    with col1:
        outlet_identifier = st.selectbox("Outlet_Identifier", ('OUT010', 'OUT013', 'OUT017', 'OUT018', 'OUT019', 'OUT027', 'OUT035', 'OUT045', 'OUT046', 'OUT049'))
    with col2:
        outlet_establishment_year = int(st.selectbox("Outlet Establishment Year", ('1985', '1987', '1997', '1998', '1999', '2002', '2004', '2007', '2009')))
    with col1:
        outlet_size = st.selectbox("Outlet_Size", ('Small', 'Medium', 'High'))
    with col2:
        outlet_location_type = st.selectbox("Outlet_Location_Type", ('Tier 1', 'Tier 2', 'Tier 3'))
    with col1:
        outlet_type = st.selectbox("Outlet_Type", ('Grocery Store', 'Supermarket Type1', 'Supermarket Type2', 'Supermarket Type3'))
    return (
        item_identifier, item_weight, item_fat_content, item_visibility,
        item_type, item_mrp, outlet_identifier, outlet_establishment_year,
        outlet_size, outlet_location_type, outlet_type,
    )


def _show_prediction(predict, features, model_label):
    """Run *predict* on *features* and show the result, or the failure, in the UI.

    Args:
        predict: one of the ``big_mart_sales_prediction*`` functions.
        features: the tuple returned by ``_collect_inputs``.
        model_label: human-readable model name shown next to the result.
    """
    if not features[0]:  # Item_Identifier is the first feature
        st.error("Please enter an Item Identifier.")
        return
    try:
        sales = predict(features)
    except ValueError as err:
        # Presumably raised by an encoder for a value it was not fitted on
        # (e.g. an unknown Item Identifier); report it instead of crashing.
        st.error(f"Cannot predict with these inputs: {err}")
        return
    st.success(f"The Predicted Sales: {float(sales):.2f}$ ({model_label})")


def main():
    """Render the Big Mart sales prediction page and handle the predict buttons."""
    st.title('Big Mart Sales Prediction Web App')
    features = _collect_inputs()

    if st.button('Predict Big Mart Sales using Model1'):
        _show_prediction(big_mart_sales_prediction1, features, 'XGBoost')
    # Model2 (Random Forest) is disabled. To re-enable it, restore model2 and
    # big_mart_sales_prediction2, then uncomment:
    # if st.button('Predict Big Mart Sales using Model2'):
    #     _show_prediction(big_mart_sales_prediction2, features, 'Random Forest')
    if st.button('Predict Big Mart Sales using Model3'):
        _show_prediction(big_mart_sales_prediction3, features, 'Decision Tree')


if __name__ == '__main__':
    main()