-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfrom simpletransformers.ini
51 lines (39 loc) · 2 KB
/
from simpletransformers.ini
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
from simpletransformers.classification import ClassificationModel, ClassificationArgs
from transformers import BertTokenizer
import shap
import torch
import time
# 初始化模型和 tokenizer
dir_name = r'C:\114project\outputs\bert-base-Chinese-bs-64-epo-3'
model_args = ClassificationArgs()
model = ClassificationModel('bert', dir_name, use_cuda=True, cuda_device=0, num_labels=6, args=model_args)
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
# 預測函數封裝
def predict_proba(texts):
"""
處理 SHAP 輸入,將分詞結果轉換為字符串列表後進行預測。
"""
# 如果是 token 列表,將其轉回字符串
if isinstance(texts, list) and all(isinstance(t, list) and all(isinstance(token, str) for token in t) for t in texts):
texts = ["".join(tokens) for tokens in texts] # 將 token 列表拼接為字符串
elif isinstance(texts, str):
texts = [texts]
elif not isinstance(texts, list) or not all(isinstance(t, str) for t in texts):
raise ValueError("`texts` must be a string or a list of strings.")
# 預測並返回概率
predictions, raw_outputs = model.predict(texts)
probabilities = torch.nn.functional.softmax(torch.tensor(raw_outputs), dim=1)
return probabilities.numpy()
if __name__ == "__main__":
tStart = time.time()
# 初始化 SHAP Explainer
masker = shap.maskers.Text(tokenizer) # 定義文本遮蔽方式
explainer = shap.Explainer(predict_proba, masker)
# 示例文章
article = "唯一缺點就是隔音太糟了!其他倒是不錯,但隔音很不好,雖然房內會附2副耳塞,但對於睡眠品質有點講究的朋友,真的不建議選這間。"
# 使用 SHAP 解釋文章情緒分類
shap_values = explainer([article]) # 傳遞單條文本(包裝為列表)
# 可視化解釋
shap.plots.text(shap_values[0]) # 可視化第一條文本的 SHAP 值
tEnd = time.time()
print(f"SHAP 解釋執行花費 {tEnd - tStart:.2f} 秒。")