训练数据集 | 权值文件名称 | 测试数据集 | 精度 |
---|---|---|---|
MNIST-train | resnet50_mnist.pth | MNIST-test | 99.64% |
torch==1.7.1
训练所需的resnet50_mnist.pth可以在百度云或google drive下载。 百度云链接: 链接:https://pan.baidu.com/s/1apl5kspxGvjg4y6hLjSktQ?pwd=4mlv 提取码:4mlv
google drive 链接: https://drive.google.com/file/d/1rFNsKgbUWKfp533Znsu0Jwz3XhQSxksM/view?usp=sharing
百度云链接: 链接: https://pan.baidu.com/s/1MYMs_axknMm2g5Ou-cWmgQ 提取码: 8ce2
- 下载好预训练的模型或按照训练步骤训练好模型;
- 在prediction.py文件里面,在如下部分修改PAHT使其对应训练好的模型路径;
PATH = './logs/resnet50-mnist.pth'
- 运行prediction.py,输入每次预测的图片个数。
- 本文使用MNIST数据集进行训练,调用pytorch接口可以直接进行下载(代码已写好);
- 如果使用pytorch接口下载速度慢,可使用百度云进行下载。将下载后的文件放入data文件夹中即可;
- 运行train.py即可开始训练。
https://github.com/bubbliiiing/faster-rcnn-pytorch.git