-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpredict.py
157 lines (124 loc) · 4.81 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from simpletransformers.classification import ClassificationModel, ClassificationArgs
import time
from transformers import BertTokenizer, BertModel
import torch
import model_db
from datetime import datetime
import random
import sys
# Load the fine-tuned classifier once, at import time, so every call to
# predict() reuses the same weights instead of reloading them.
dir_name = r'C:\114project\outputs\bert-base-Chinese-bs-64-epo-3'
model_args = ClassificationArgs()
# Training-time settings; they appear to mirror the run encoded in the
# directory name (batch size 64, 3 epochs) and are kept for reference.
model_args.train_batch_size = 64
model_args.num_train_epochs = 3
# Use the GPU when one is available and fall back to CPU otherwise:
# passing use_cuda=True on a machine without CUDA makes ClassificationModel
# raise, which would abort the import of this whole module.
model = ClassificationModel(
    'bert',
    dir_name,
    use_cuda=torch.cuda.is_available(),
    cuda_device=0,
    num_labels=6,  # label ids 0-5: other, like, sadness, disgust, anger, happiness
    args=model_args,
)
# Inference entry point: classify a batch of sentences with the global model.
def predict(listData):
    """Return the predicted label id for each sentence in ``listData``.

    Delegates to the module-level ``model``'s ``predict`` method, which
    yields a ``(predictions, raw_outputs)`` pair; only the predictions are
    handed back, and the raw model outputs are discarded.
    """
    label_ids, _unused_raw_outputs = model.predict(listData)
    return label_ids
def split(article):
    """Split ``article`` into sentences at sentence-ending punctuation.

    A sentence ends after a run of one or more terminator characters
    (full- or half-width ``。 ．. ；; ？? ！! ～~``); commas do not end a
    sentence. Text after the last terminator is kept as a final sentence.
    Each piece is stripped of surrounding whitespace, and pieces of one
    character or less (e.g. a stray punctuation mark) are discarded.

    Args:
        article: The text to segment. ``None`` or ``""`` yields ``[]``.

    Returns:
        list[str]: The sentences in their original order, each with its
        terminating punctuation attached.
    """
    import re

    if not article:
        return []

    # Split on raw characters rather than BERT tokens: the classifier used
    # by predict() tokenizes each sentence itself, so no tokenizer or model
    # has to be loaded here, and sentences come back exactly as written
    # (no [CLS] marker, no spaces inserted between characters).
    terminators = re.escape("。．.；;？?！!～~")
    # A maximal run of non-terminators followed by any terminators, so
    # "真的嗎?!" stays one sentence and an unterminated tail still matches.
    sentence_pattern = "[^{0}]+[{0}]*".format(terminators)

    stripped = (piece.strip() for piece in re.findall(sentence_pattern, article))
    return [sentence for sentence in stripped if len(sentence) > 1]
def stats(finalpredict):
    """Return the percentage of predictions that fall into each emotion.

    Args:
        finalpredict: A sized sequence of label ids as returned by
            ``predict`` (0 other, 1 like, 2 sadness, 3 disgust, 4 anger,
            5 happiness).

    Returns:
        list[float]: Six percentages in the order
        ``[sadness, disgust, like, anger, happiness, other]``. This differs
        from the label-id order on purpose; ``randomToChooose`` reads these
        exact positions. Ids outside 0-5 land in no bucket but still count
        toward the total. An empty input yields all zeros instead of
        dividing by zero.
    """
    from collections import Counter

    # Label id that feeds each output slot, in output order.
    output_label_order = (2, 3, 1, 4, 5, 0)

    # len() instead of truthiness so NumPy prediction arrays also work.
    total = len(finalpredict)
    if total == 0:
        return [0.0] * len(output_label_order)

    counts = Counter(finalpredict)
    return [counts[label] / total * 100 for label in output_label_order]
def randomToChooose(data):
    """Draw one random id for the dominant positive and negative emotion.

    ``data`` is the percentage list produced by ``stats``:
    ``[sadness, disgust, like, anger, happiness, other]``.

    Positive side: like -> ``randint(0, 5)``, happiness -> ``randint(6, 10)``.
    Negative side: sadness -> ``randint(0, 5)``, disgust -> ``randint(6, 10)``,
    anger -> ``randint(11, 15)``. Ties go to the emotion listed first.
    The positive id is drawn before the negative one.

    Returns:
        tuple[int, int]: ``(posRandom, negRandom)``.
    """
    # (score, inclusive id range) per candidate, in tie-break order.
    positive_buckets = ((data[2], (0, 5)), (data[4], (6, 10)))
    negative_buckets = ((data[0], (0, 5)), (data[1], (6, 10)), (data[3], (11, 15)))

    def _winning_range(buckets):
        # First bucket holding the highest score, i.e. list.index(max(...)).
        scores = [score for score, _ in buckets]
        return buckets[scores.index(max(scores))][1]

    low, high = _winning_range(positive_buckets)
    positive_pick = random.randint(low, high)
    low, high = _winning_range(negative_buckets)
    negative_pick = random.randint(low, high)
    return positive_pick, negative_pick
if __name__ =="__main__":
    tStart = time.time()

    # Smoke test: run the split -> predict -> stats pipeline on a sample review.
    article = "唯一缺點就是隔音太糟了!其他倒是不錯,但隔音很不好,雖然房內會附2副耳塞,但對於睡眠品質有點講究的朋友,真的不建議選這間。另外,我訂的是雙人房,一大床,不是拿兩小床來併成一張類雙人床!中間怎麼樣都是一個凸起線在哪裡,叫人怎麼睡!我被安排在117房,整晚時不時就會聽到加壓馬達的聲音!再加上隔壁房間的朋友打呼!真的是快瘋掉了!"
    listData = split(article)  # sentence segmentation
    print("斷句結果:", listData)
    finalpredict = predict(listData)  # per-sentence emotion label ids
    print("預測結果:", finalpredict)

    # Production path, currently disabled: analyse a diary loaded from the
    # database by id (from the command line or an interactive prompt) and
    # store the results. Re-enable together with the save_db_* calls below.
    #
    # if len(sys.argv) != 2:
    #     print("Usage:python predict.py <diary_id>")
    #     sys.exit(1)
    # diary_Id = sys.argv[1]
    # diary_Id = str(input("請輸入日記id: "))
    # articles = model_db.get_diary_content(diary_Id)
    # listData = split(articles)  # sentence segmentation
    # finalpredict = predict(listData)  # final predictions
    # print(finalpredict)

    statistics = stats(finalpredict)  # emotion percentages
    posRandom, negRandom = randomToChooose(statistics)
    date = datetime.now().strftime('%Y-%m-%d')
    # model_db.save_db_analysis(date, statistics, diary_Id)
    # model_db.save_db_diaries_mon_ang(diary_Id, posRandom, negRandom)

    tEnd = time.time()
    print(f"執行花費{tEnd-tStart}秒。")