-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrain.py
55 lines (42 loc) · 1.98 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import pickle
import argparse
import pandas as pd
from loguru import logger
from model.utilities.classifier import Classifier
from preprocess import Preprocessor
from utilities.utils import log_and_exit, save_attendance
from model.utilities.config import config
def train(data, subject=None, output_dir=None):
    """Save the attendance sheet for *subject*, then train and persist its SVC.

    Args:
        data: DataFrame for one subject. It must expose a ``Labels`` column
            (read via ``data.Labels``). The remaining columns are consumed by
            ``Classifier``; TODO confirm the expected schema there.
        subject: Subject name. It is passed to ``save_attendance`` and used
            as the prefix of the saved classifier filename.
        output_dir: Directory the pickled classifier is written to. If None,
            ``config.TRAINED_DATA`` is used, matching the previous fixed
            location.
    """
    if output_dir is None:
        output_dir = config.TRAINED_DATA

    # Build a frame indexed by each distinct label, with no data columns, and
    # hand it to save_attendance (the CSV used for marking attendance).
    labels = pd.DataFrame(pd.unique(data.Labels), columns=['Labels'])
    labels.set_index('Labels', inplace=True)
    save_attendance(subject=subject, data=labels)

    # Train the SVC classifier on this subject's data.
    classifier = Classifier(data)
    logger.debug("classifier file suffix: {}", config.CLASSIFIER_FILE_SUFFIX)
    svc = classifier.train_svc()

    model_path = os.path.join(
        output_dir, '{}{}.sav'.format(subject, config.CLASSIFIER_FILE_SUFFIX)
    )
    # The context manager guarantees the handle is flushed and closed, even if
    # pickling fails.
    with open(model_path, 'wb') as model_file:
        pickle.dump(svc, model_file)


def main(arguments):
    """Validate the CLI paths, preprocess the input, and train one classifier per subject.

    Args:
        arguments: Parsed ``argparse.Namespace`` providing ``input`` and
            ``output`` directory paths.
    """
    # Both directories must already exist. log_and_exit reports the problem
    # through logger.error (presumably it also terminates the process;
    # confirm in utilities.utils).
    if not os.path.isdir(arguments.input):
        log_and_exit("invalid input path is given: directory does not exist", logger.error)
    if not os.path.isdir(arguments.output):
        log_and_exit("invalid output path is given: directory does not exist", logger.error)

    preprocessor = Preprocessor(path=arguments.input)
    preprocessor.preprocess()
    for subject in preprocessor.data:
        data = Preprocessor.to_dataframe(subject)
        # Save into the validated --output directory. Its CLI default is
        # config.TRAINED_DATA, so default runs are unchanged.
        train(data=data, subject=subject.name, output_dir=arguments.output)
if __name__ == '__main__':
    # Command-line entry point. Both directories default to the locations
    # configured in model.utilities.config.
    parser = argparse.ArgumentParser(
        description='Encoding Faces and Training classifier on Encodings',
    )
    parser.add_argument(
        '-i', '--input',
        type=str,
        metavar='',
        default=os.path.join(config.TRAINING_DATA),
        help='path to the input folder or subject folder',
    )
    parser.add_argument(
        '-o', '--output',
        type=str,
        metavar='',
        default=os.path.join(config.TRAINED_DATA),
        help='path to the folder where trained results are to be saved',
    )
    args = parser.parse_args()
    main(args)