主代码请浏览 vis_nbdt.py 改代码主要生成不带指标测试的可视化树
如果想生成带有指标的可视化树可以运行代码 vnbdt_quant.py
大部分函数在 vnbdt.py
可直接执行:
python vis_nbdt.py
或者:
python vnbdt_quant.py
以上两个代码可支持输入一张图像or一个文件夹,具体需要填写的参数见代码:
以下两个代码是计算iou与afc两个指标,用以评价模型对不同树结构or不同类别的鲁棒性:
python vnbdt_quant_for_all_class.py
python vnbdt_quant_for_different_tree.py
可以增加你自己的模型: 于vnbdt.py文件的_call-xxx-model()_与_get-layer()_函数中增加对应的模型结构及所需映射的网络层,于nbdt/graph.py中的MODEL_FC_KEYS增加对应的FC层名称
ResNet50, ResNet18
vgg16
wrn28_10_cifar10
细粒度模型DFLCNN
如果需要扩展自己想要的数据集,首先需要在nbdt/hierarchies中创建空文件夹(数据集名如Emo),在nbdt/wnids中创建 对应名称的txt文件(如Emo.txt)并输入对应分类数量的ID,最后在nbdt/utils.py中增加对应数据集的各类信息,才能保证 代码能在新数据集上运行
CIFAR10, 时装数据集10子类 Fashion10
PARA美学数据的情感分类数据集 Emo
'FGVC', 'FGVC12','FGVC10': FGVC飞行器数据集的子类
'Imagenet10': imagenet10子类
"gradcam": GradCAM,
"scorecam": ScoreCAM,
"gradcam++": GradCAMPlusPlus,
"ablationcam": AblationCAM,
"xgradcam": XGradCAM,
"eigencam": EigenCAM,
"eigengradcam": EigenGradCAM,
"layercam": LayerCAM,
"fullgrad": FullGrad,
'efccam': EFC_CAM
Decision-probability-oriented fusion of multi-CAM
Similarity-attention-oriented fusion of multi-CAM
Without any weight
专家树需要重新在nbdt/hierarcies/dataset 文件夹下重新定义一个固定的专家树结构json文件,详情邮件咨询:li_zhili0105@163.com
induced: 诱导树
random: 随机树
pro: 专家树(结构固定)
效果展示
由于visualization依托于HTML页面,因此在生成树结构和热力图以后对其保存路径有比较严格的文件夹设置要求:
1.samples/img_xxx: 保存需要解释的图像集(文件夹)img_xxx在sample中,此处代码假设img_xxx下没有子文件夹,全是图像
如果测试图像不在项目文件夹下,就会出现html页面根节点不显示原图
img_1.jpg/png ---> 为了更好地展示结果,这里的图片名称最好与原本的类别对应,比如 mastiff1.jpg
img_2.jpg/png
......
2.xx_output: 运行代码后保存CAM解释生成的图像,每一个图像都会生成一个文件夹,文件夹下保存链路所有节点的图像,这个生成的文件夹名称直接体现CAM方法和融合方法
img_1_[CAM_METHODS]_[MERGE_METHODS]:
img_1_[CAM_METHODS]_[MERGE_METHODS]_1.jpg
img_1_[CAM_METHODS]_[MERGE_METHODS]_2.jpg
......
img_2_[CAM_METHODS]_[MERGE_METHODS]:
img_2_[CAM_METHODS]_[MERGE_METHODS]_1.jpg
img_2_[CAM_METHODS]_[MERGE_METHODS]_2.jpg
......
3.xx_html: 运行代码以后生成html文件,打开即可监视解释效果并互动
xxx1.html
xxx2.html
......
依赖环境:在requirements.txt 中
python 3.6.2 torch 1.10.1 cuda 10.2
detail==0.2.2
matplotlib==3.3.4
networkx==2.5.1
nltk==3.6.7
numpy==1.19.5
opencv_python==4.4.0.42
Pillow==9.2.0
pytorchcv==0.0.67
scikit_learn==1.1.2
torch==1.10.1
torchvision==0.11.2
tqdm==4.64.0
ttach==0.0.3