This repo is forked from https://github.com/lizhengwei1992/Semantic_Human_Matting, a wonderful implementation of the paper *Semantic Human Matting* from Alibaba. I added code for data preparation. The dataset became available when the company released it, but to run the code you need to relocate the image files, generate text files with the image names, and, last but not least, generate binary masks for the training images.
- python3.5 / 3.6
- pytorch >= 0.4
- opencv-contrib-python==4.1.2.30
- opencv-python==3.4.1.15
- tqdm==4.41.1
The directory structure of the project is as follows. It is almost identical to the parent repository, except for a few images and the Python files I added.
```
Semantic Human Matting
│   README.md
│   train.py
│   train.sh
│   test_camera.py
│   test_camera.sh
└───model
│   │   M_Net.py
│   │   T_Net.py
│   │   network.py
└───data
│   │   data_prepare.py
│   │   dataset.py
│   │   gen_trimap.py
│   │   gen_trimap.sh
│   │   knn_matting.py
│   │   knn_matting.sh
│   └───image
│   └───mask
│   └───trimap
│   └───alpha
```
First, you need to download the dataset from Kaggle and put it in `./data`. The file size is about 16 GB, so make sure you have enough disk space available beforehand.
*(Figure: example training image alongside its corresponding binary mask.)*
Go to `./data` and run `data_prepare.sh`.
PNG files support transparent pixels, so these pixels should be converted to white prior to generating the binary masks (sketched below).
The script creates two new directories and two text files at `./data`, named `image`, `mask`, `train.txt`, and `mask.txt` respectively. The `image` folder holds the training images and `mask` holds the binary mask images corresponding to them. The file names of the training and mask images are written to `train.txt` and `mask.txt`.
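For reference, here is a minimal sketch of the transparent-to-white conversion and mask extraction described above (the function name is hypothetical and it assumes the dataset PNGs carry an alpha channel; see `data_prepare.py` for the actual implementation):

```python
import cv2
import numpy as np

def rgba_to_white_bg_and_mask(png_path):
    # Load the PNG with its alpha channel intact (H x W x 4).
    rgba = cv2.imread(png_path, cv2.IMREAD_UNCHANGED)
    bgr = rgba[:, :, :3].astype(np.float32)
    alpha = rgba[:, :, 3:].astype(np.float32) / 255.0
    # Composite onto a white background: out = fg * a + white * (1 - a).
    white = np.full_like(bgr, 255.0)
    composite = (bgr * alpha + white * (1.0 - alpha)).astype(np.uint8)
    # Any pixel with nonzero alpha belongs to the person in the binary mask.
    mask = (rgba[:, :, 3] > 0).astype(np.uint8) * 255
    return composite, mask
```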
Use `./data/gen_trimap.sh` to get trimaps of the masks.
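Trimaps are commonly derived from binary masks by eroding and dilating them to carve out an unknown band around the silhouette; a minimal sketch of that idea follows (the kernel size is an assumption, check `gen_trimap.py` for the settings actually used):

```python
import cv2
import numpy as np

def mask_to_trimap(mask, kernel_size=10):
    # mask: uint8 binary mask with values 0 or 255.
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    dilated = cv2.dilate(mask, kernel, iterations=1)
    eroded = cv2.erode(mask, kernel, iterations=1)
    trimap = np.full(mask.shape, 128, dtype=np.uint8)  # unknown band by default
    trimap[eroded >= 255] = 255  # definite foreground
    trimap[dilated == 0] = 0     # definite background
    return trimap
```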
Use `./data/knn_matting.sh` to get alpha mattes (it will take a long time...).
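For the curious, KNN matting (Chen et al., 2013) propagates the known trimap labels through a k-nearest-neighbor graph built over color and position features, then solves a sparse linear system. A condensed sketch of the technique (all parameters are illustrative and `knn_matting.py` may differ):

```python
import numpy as np
import scipy.sparse
import scipy.sparse.linalg
from sklearn.neighbors import NearestNeighbors

def knn_matte(img, trimap, n_neighbors=10, my_lambda=100):
    # img: H x W x 3 uint8 image; trimap: H x W uint8 (0 / 128 / 255).
    h, w, c = img.shape
    img, trimap = img / 255.0, trimap / 255.0
    fg = trimap > 0.99            # known foreground pixels
    bg = trimap < 0.01            # known background pixels
    known = (fg | bg).ravel()

    # Per-pixel feature: color plus normalized spatial coordinates.
    y, x = np.mgrid[0:h, 0:w]
    feats = np.concatenate(
        [img.reshape(-1, c), np.stack([y.ravel() / h, x.ravel() / w], axis=1)],
        axis=1)

    nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(feats)
    dists, idx = nbrs.kneighbors(feats)

    # Sparse affinity matrix: nearer neighbors get higher weights.
    n = h * w
    rows = np.repeat(np.arange(n), n_neighbors)
    vals = 1.0 - dists.ravel() / dists.max()
    A = scipy.sparse.coo_matrix((vals, (rows, idx.ravel())), shape=(n, n))

    D = scipy.sparse.diags(np.asarray(A.sum(axis=1)).ravel())
    L = D - A                                        # graph Laplacian
    M = scipy.sparse.diags(known.astype(np.float64)) # constrained pixels
    v = fg.ravel().astype(np.float64)                # constraint values (1 = fg)

    # Solve (L + lambda * M) alpha = lambda * v for the alpha matte.
    alpha = scipy.sparse.linalg.spsolve(
        scipy.sparse.csc_matrix(L + my_lambda * M), my_lambda * v)
    return np.clip(alpha, 0.0, 1.0).reshape(h, w)
```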
- Trimap generation: T-Net

  The T-Net plays the role of semantic segmentation. I use mobilenet_v2 + unet as the T-Net to predict the trimap.
- Matting network: M-Net

  The M-Net aims to capture detail information and generate the alpha matte. I build the M-Net the same as in the paper, but reduce the channels of the original net.
- Fusion Module

  The fusion module blends the two outputs as alpha_p = F_s + U_s * alpha_r, where F_s and U_s are the foreground and unknown-region probabilities predicted by the T-Net and alpha_r is the matte predicted by the M-Net. The overall prediction loss for alpha_p at each pixel is L_p = 0.5 * L_alpha + 0.5 * L_composition (see the code below).
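A minimal sketch of that fusion step (variable names and the channel order of the trimap logits are assumptions; see `model/network.py` for the real thing):

```python
import torch
import torch.nn.functional as F

# Dummy tensors standing in for real network outputs:
trimap_pre = torch.randn(1, 3, 320, 320)  # T-Net logits (bg / unknown / fg assumed)
alpha_r = torch.rand(1, 1, 320, 320)      # M-Net raw matte

trimap_softmax = F.softmax(trimap_pre, dim=1)
bg, unsure, fg = torch.split(trimap_softmax, 1, dim=1)
# Definite foreground contributes 1, definite background 0, and the
# unknown band defers to the M-Net's fine-grained prediction.
alpha_p = fg + unsure * alpha_r
```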
Read the paper for more details; my code for the two loss functions is below:
```python
import torch
import torch.nn as nn

# -------------------------------------
# classification loss L_t
# -------------------------------------
criterion = nn.CrossEntropyLoss()
L_t = criterion(trimap_pre, trimap_gt[:, 0, :, :].long())

# -------------------------------------
# prediction loss L_p
# -------------------------------------
eps = 1e-6
# L_alpha
L_alpha = torch.sqrt(torch.pow(alpha_pre - alpha_gt, 2.) + eps).mean()
# L_composition
fg = torch.cat((alpha_gt, alpha_gt, alpha_gt), 1) * img
fg_pre = torch.cat((alpha_pre, alpha_pre, alpha_pre), 1) * img
L_composition = torch.sqrt(torch.pow(fg - fg_pre, 2.) + eps).mean()
L_p = 0.5 * L_alpha + 0.5 * L_composition
```
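For the end-to-end phase, the paper combines the two terms as L = L_p + λ·L_t; a one-line sketch with an assumed weight (check `train.py` for the value actually used):

```python
lam = 0.01              # assumed weight, not taken from this repo
loss = L_p + lam * L_t  # total loss for the end-to-end phase
```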
First, pre-train the T-Net by running `./train.sh` as:
```bash
python3 train.py \
    --dataDir='./data' \
    --saveDir='./ckpt' \
    --trainData='human_matting_data' \
    --trainList='./data/train.txt' \
    --load='human_matting' \
    --nThreads=4 \
    --patch_size=320 \
    --train_batch=8 \
    --lr=1e-3 \
    --lrdecayType='keep' \
    --nEpochs=1000 \
    --save_epoch=1 \
    --train_phase='pre_train_t_net'
```
Then, train end-to-end by running `./train.sh` as:
```bash
python3 train.py \
    --dataDir='./data' \
    --saveDir='./ckpt' \
    --trainData='human_matting_data' \
    --trainList='./data/train.txt' \
    --load='human_matting' \
    --nThreads=4 \
    --patch_size=320 \
    --train_batch=8 \
    --lr=1e-4 \
    --lrdecayType='keep' \
    --nEpochs=2000 \
    --save_epoch=1 \
    --finetuning \
    --train_phase='end_to_end'
```
Run `./test_camera.sh` to test the model with your camera.
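Roughly, `test_camera.py` grabs webcam frames, runs the network, and composites the frame with the predicted matte. A minimal sketch of such a loop (the checkpoint path and the model's output signature are assumptions):

```python
import cv2
import numpy as np
import torch

model = torch.load('./ckpt/human_matting/model/model_obj.pth', map_location='cpu')
model.eval()

cap = cv2.VideoCapture(0)
with torch.no_grad():
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        # Resize to the training patch size and convert to a CHW float tensor.
        inp = cv2.resize(frame, (320, 320)).astype(np.float32) / 255.0
        inp = torch.from_numpy(inp).permute(2, 0, 1).unsqueeze(0)
        _, alpha = model(inp)                     # assumed (trimap, alpha) output
        matte = alpha[0, 0].numpy()
        matte = cv2.resize(matte, (frame.shape[1], frame.shape[0]))
        # Composite: keep the person, black out the background.
        out = (frame.astype(np.float32) * matte[..., None]).astype(np.uint8)
        cv2.imshow('matting', out)
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
cap.release()
cv2.destroyAllWindows()
```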