-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy path05_SHAP_importance.py
71 lines (60 loc) · 2.24 KB
/
05_SHAP_importance.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""SHAP features importance."""
import os
import shap
import numpy as np
import pandas as pd
from lightgbm import LGBMClassifier
# Define paths
# parent_dir = os.getcwd()
parent_dir = "/home/walker/rvallat/yasa_classifier"  # Neurocluster
wdir = parent_dir + '/output/features/'
outdir = parent_dir + "/output/classifiers/"
# Fail early (before the expensive fit) if a directory is missing. Explicit
# exceptions are used instead of ``assert`` so that the checks are still
# executed when Python runs with the -O flag.
for required_dir in (wdir, outdir):
    if not os.path.isdir(required_dir):
        raise NotADirectoryError("Directory not found: " + required_dir)

# Load the full dataframe
df = pd.read_parquet(wdir + "features_all.parquet")

# Define predictors and target: every column except 'stage' (the target) and
# 'dataset' is used as a predictor, with columns in sorted order.
X = df[df.columns.difference(['stage', 'dataset'])].sort_index(axis=1)
y = df['stage']
print(df.shape)
print(X.columns)

# Define hyper-parameters
params = dict(
    boosting_type='gbdt',
    n_estimators=400,
    max_depth=5,
    num_leaves=90,
    colsample_bytree=0.5,
    importance_type='gain',
    n_jobs=20
)

# Manually define class weight
# See output/classifiers/gridsearch_class_weights.xlsx
params['class_weight'] = {'N1': 2.2, 'N2': 1, 'N3': 1.2, 'R': 1.4, 'W': 1}
fname = outdir + 'clf_eeg+eog+emg+demo_lgb_gbdt_custom_shap'

# Fit classifier
clf = LGBMClassifier(**params)
clf.fit(X, y)
print("Fitting done!")

# Calculate SHAP feature importance - we limit the number of trees for speed
explainer = shap.TreeExplainer(clf)
shap_values = explainer.shap_values(X, tree_limit=50)

# Normalise the SHAP output to a single (n_classes, n_samples, n_features)
# array, which is the layout assumed by the aggregation and the export below.
# Depending on the installed SHAP version, multi-class output is either a list
# with one (n_samples, n_features) array per class (np.asarray stacks it
# class-first) or one (n_samples, n_features, n_classes) array (class-last).
shap_values = np.asarray(shap_values)
expected_shape = (len(clf.classes_),) + X.shape
if shap_values.ndim == 3 and shap_values.shape != expected_shape:
    # Class-last layout: move the class axis to the front.
    shap_values = np.moveaxis(shap_values, -1, 0)
if shap_values.shape != expected_shape:
    raise ValueError(
        "Unexpected SHAP values shape %s, expected %s"
        % (shap_values.shape, expected_shape))

# Sum absolute values across all stages and then average across all samples
# Gives similar results as summing the absolute values across all samples and
# then taking average or sum across sleep stage to get one value per feature.
shap_sum = np.abs(shap_values).sum(axis=0).mean(axis=0)
df_shap = pd.Series(shap_sum, index=X.columns.tolist(), name="Importance")
df_shap.sort_values(ascending=False, inplace=True)
df_shap.index.name = 'Features'

# Export: the raw per-class SHAP values (compressed .npz) and the per-feature
# importance sorted in descending order (.csv)
np.savez_compressed(fname + ".npz", shap_values=shap_values)
df_shap.to_csv(fname + ".csv")

# Disabled: plot
# NOTE: class_inds must be computed before cmap, which indexes with it.
# from matplotlib import colors
# cmap_stages = ['#99d7f1', '#009DDC', 'xkcd:twilight blue',
#                'xkcd:rich purple', 'xkcd:sunflower']
# class_inds = np.argsort(
#     [-np.abs(shap_values[i]).mean() for i in range(len(shap_values))])
# cmap = colors.ListedColormap(np.array(cmap_stages)[class_inds])
# shap.summary_plot(shap_values, X, plot_type='bar', max_display=15,
#                   color=cmap, class_names=clf.classes_)