CNN-based Cascaded Multi-task Learning of High-level Prior and Density Estimation for Crowd Counting (Single Image Crowd Counting)
This is an implementation of the paper "CNN-based Cascaded Multi-task Learning of High-level Prior and Density Estimation for Crowd Counting" for single-image crowd counting, which was accepted at AVSS 2017.
Installation
1. Install PyTorch.
2. Clone this repository:
   git clone https://github.com/svishwa/crowdcount-cascaded-mtl.git
   We'll call the directory that you cloned crowdcount-cascaded-mtl ROOT.
Data Setup
1. Download the ShanghaiTech dataset:
   Dropbox: https://www.dropbox.com/s/fipgjqxl7uj8hd5/ShanghaiTech.zip?dl=0
   Baidu Disk: http://pan.baidu.com/s/1nuAYslz
2. Create the directory (-p also creates any missing parent directories):
   mkdir -p ROOT/data/original/shanghaitech/
3. Save "part_A_final" under ROOT/data/original/shanghaitech/
4. Save "part_B_final" under ROOT/data/original/shanghaitech/
5. cd ROOT/data_preparation/
   Run create_gt_test_set_shtech.m in MATLAB to create the ground-truth files for the test data.
6. cd ROOT/data_preparation/
   Run create_training_set_shtech.m in MATLAB to create the training and validation sets along with their ground-truth files.
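The ground-truth files produced in steps 5 and 6 are density maps derived from the head annotations. The MATLAB scripts are the reference; the Python sketch below only illustrates the common idea, assuming one fixed-width Gaussian per annotated head so that the map sums to the head count. The function names, the sigma = 4.0 / size = 15 defaults, and the 0-indexed (x, y) convention are assumptions for illustration, not values read from the scripts.

```python
import numpy as np


def gaussian_kernel(size: int, sigma: float) -> np.ndarray:
    """Return a (size x size) 2-D Gaussian kernel normalized to sum to 1."""
    if size % 2 == 0:
        raise ValueError("size must be odd so the kernel has a centre pixel")
    ax = np.arange(size) - (size - 1) / 2.0
    g = np.exp(-(ax ** 2) / (2.0 * sigma ** 2))
    kernel = np.outer(g, g)
    return kernel / kernel.sum()


def density_map(points, height: int, width: int,
                sigma: float = 4.0, size: int = 15) -> np.ndarray:
    """Place one unit-mass Gaussian at each annotated head (hypothetical helper).

    points: iterable of (x, y) pixel coordinates, x = column, y = row,
    assumed 0-indexed (MATLAB annotations are usually 1-indexed).
    sigma and size are assumed defaults, not taken from the MATLAB scripts.
    """
    dmap = np.zeros((height, width), dtype=np.float64)
    kernel = gaussian_kernel(size, sigma)
    half = size // 2
    for x, y in points:
        # Round to the nearest pixel and clamp heads lying on or past the border.
        cx = min(max(int(round(x)), 0), width - 1)
        cy = min(max(int(round(y)), 0), height - 1)
        # Image window covered by the kernel, clipped to the image bounds.
        r0, r1 = max(cy - half, 0), min(cy + half + 1, height)
        c0, c1 = max(cx - half, 0), min(cx + half + 1, width)
        # Matching part of the kernel for that (possibly clipped) window.
        patch = kernel[r0 - (cy - half):r1 - (cy - half),
                       c0 - (cx - half):c1 - (cx - half)]
        # Renormalize the clipped patch so every head still contributes exactly 1.
        dmap[r0:r1, c0:c1] += patch / patch.sum()
    return dmap
```

Renormalizing clipped kernels is a design choice of this sketch: it keeps dmap.sum() equal to the number of annotated heads even for people at the image border, so the count can be read back by summing the map.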
Test
1. Follow steps 1, 2, 3, 4 and 5 from Data Setup.
2. Download the pre-trained model files:
3. Save the model files under ROOT/final_models.
4. Run test.py.
   a. Set save_output = True to save the output density maps.
   b. Errors are saved in the output directory.
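The errors reported by the test step are count errors over the test images. As a hedged illustration (not the code in test.py), the sketch below takes the predicted count as the sum of a density map and computes MAE and MSE; it assumes the usual crowd-counting convention in which "MSE" is the square root of the mean squared count error. The helper names estimated_count and mae_mse are hypothetical.

```python
import math
from typing import Sequence


def estimated_count(density) -> float:
    """Predicted count = sum over all pixels of a density map given as rows."""
    return float(sum(sum(row) for row in density))


def mae_mse(gt_counts: Sequence[float], est_counts: Sequence[float]):
    """Return (MAE, MSE) over a test set.

    Assumes the crowd-counting convention where "MSE" is the square root
    of the mean squared count error; check test.py for the exact formula.
    """
    if len(gt_counts) != len(est_counts) or len(gt_counts) == 0:
        raise ValueError("need two non-empty sequences of equal length")
    n = len(gt_counts)
    diffs = [g - e for g, e in zip(gt_counts, est_counts)]
    mae = sum(abs(d) for d in diffs) / n
    mse = math.sqrt(sum(d * d for d in diffs) / n)
    return mae, mse


if __name__ == "__main__":
    gt = [10, 20]
    # Two toy "density maps" whose sums are 13 and 16.
    est = [estimated_count([[6.5, 6.5]]), estimated_count([[8.0], [8.0]])]
    print(mae_mse(gt, est))
```

Because the square root is taken after averaging, this MSE weights large per-image errors more heavily than MAE, which is why it is typically the larger of the two numbers.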
Training
1. Follow steps 1, 2, 3, 4 and 6 from Data Setup.
2. Run train.py.
Training with TensorBoard
With the aid of Crayon, we can access the visualisation power of TensorBoard for any deep learning framework.
To use TensorBoard, install Crayon (https://github.com/torrvision/crayon) and set use_tensorboard = True in ROOT/train.py.
Other notes
- During training, the best model is chosen based on its error on the validation set.
- 10% of the training set is set aside for validation. The validation set is chosen randomly.
- Following are the results on the ShanghaiTech Part A and Part B datasets:

  | Part | MAE | MSE |
  |------|-----|-----|
  | A    | 101 | 148 |
  | B    | 17  | 29  |
  Note that these results differ slightly from those reported in the paper. This is due to a few implementation differences, as the earlier implementation was in Torch (Lua). Contact me if the Torch models (those used for the paper) are required.
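The two training notes (a random 10% validation hold-out, and keeping the model with the lowest validation error) can be sketched in Python as follows. In this repository the split itself is produced by create_training_set_shtech.m; this is only an illustration under assumptions: the helper names split_train_val and BestModelTracker, the fixed seed, and the use of a single scalar validation error (for example MAE) as the selection criterion are all hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def split_train_val(names: Sequence[str], val_fraction: float = 0.1,
                    seed: int = 0) -> Tuple[List[str], List[str]]:
    """Randomly hold out val_fraction of the training images for validation."""
    rng = random.Random(seed)  # fixed seed -> reproducible split
    shuffled = list(names)
    rng.shuffle(shuffled)
    n_val = max(1, round(len(shuffled) * val_fraction))
    return shuffled[n_val:], shuffled[:n_val]


class BestModelTracker:
    """Remember the epoch with the lowest validation error seen so far."""

    def __init__(self) -> None:
        self.best_error = math.inf
        self.best_epoch: Optional[int] = None

    def update(self, epoch: int, val_error: float) -> bool:
        """Return True when this epoch is the new best (caller should save the model)."""
        if val_error < self.best_error:
            self.best_error = val_error
            self.best_epoch = epoch
            return True
        return False


if __name__ == "__main__":
    # Illustrative file names only.
    train, val = split_train_val([f"IMG_{i}.jpg" for i in range(1, 301)])
    print(len(train), len(val))
    tracker = BestModelTracker()
    for epoch, err in enumerate([120.0, 110.0, 115.0, 101.0]):  # toy validation errors
        if tracker.update(epoch, err):
            pass  # save the model checkpoint here
    print(tracker.best_epoch, tracker.best_error)
```

Selecting on held-out error rather than training loss guards against keeping an over-fitted epoch; the fixed seed simply makes the random hold-out repeatable across runs.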