Semantic Human Matting

This repo is forked from https://github.com/lizhengwei1992/Semantic_Human_Matting, a wonderful implementation of the Semantic Human Matting paper from Alibaba. I added code for data preparation. The dataset became available when the company released it, but to run the code you need to relocate the images, generate text files with the image names, and, last but not least, generate binary masks for the training images.

Requirements

  • python3.5 / 3.6
  • pytorch >= 0.4
  • opencv-contrib-python==4.1.2.30
  • opencv-python==3.4.1.15
  • tqdm==4.41.1

Usage

The directory structure of the project is as follows. It is almost identical to the parent repository, except for a few images and the Python files I added.

Semantic Human Matting
│   README.md
│   train.py
│   train.sh
│   test_camera.py
│   test_camera.sh
└───model
│   │   M_Net.py
│   │   T_Net.py
│   │   network.py
└───data
    │   data_prepare.py
    │   dataset.py
    │   gen_trimap.py
    │   gen_trimap.sh
    │   knn_matting.py
    │   knn_matting.sh
    └───image
    └───mask
    └───trimap
    └───alpha

Step 1: prepare dataset

First, download the dataset from Kaggle and put it in ./data. The file size is about 16 GB, so make sure you have enough disk space available beforehand.

[Images: a training image with its binary mask, and a reference mask]

Go to ./data and run data_prepare.sh. PNG files support transparent pixels, so these pixels should be converted to white before generating the binary masks.
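As a rough illustration of that conversion, here is a minimal sketch (not the repository's data_prepare.py; the file paths and the threshold are assumptions) that composites a transparent PNG over white and derives a binary mask from its alpha channel:

    import cv2
    import numpy as np

    # Load a PNG with its alpha channel intact (shape: H x W x 4)
    rgba = cv2.imread('person.png', cv2.IMREAD_UNCHANGED)
    bgr = rgba[..., :3].astype(np.float32)
    alpha = rgba[..., 3:].astype(np.float32) / 255.0

    # Composite over a white background so transparent pixels become white
    image = (bgr * alpha + 255.0 * (1.0 - alpha)).astype(np.uint8)

    # Binarize the alpha channel to get the mask (threshold is a guess)
    mask = (rgba[..., 3] > 127).astype(np.uint8) * 255

    cv2.imwrite('image/person.png', image)
    cv2.imwrite('mask/person.png', mask)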

It will create two new directories, image and mask, and two text files, train.txt and mask.txt, in ./data. The image folder holds the training images and mask holds the corresponding binary masks. The file names of the training and mask images are written to train.txt and mask.txt respectively.

Use ./data/gen_trimap.sh to get trimaps of the masks.
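Trimaps are typically derived from a binary mask by eroding and dilating it; below is a sketch of the idea (the kernel size and trimap values are assumptions, not necessarily what gen_trimap.py uses):

    import cv2
    import numpy as np

    def mask_to_trimap(mask, size=10):
        # mask: uint8 binary mask with values 0/255
        kernel = np.ones((size, size), np.uint8)
        eroded = cv2.erode(mask, kernel)     # certain foreground
        dilated = cv2.dilate(mask, kernel)   # anything that might be foreground
        trimap = np.full_like(mask, 128)     # unknown band by default
        trimap[eroded == 255] = 255          # foreground
        trimap[dilated == 0] = 0             # background
        return trimap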

Use ./data/knn_matting.sh to get alpha mattes (this will take a long time...).
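knn_matting.py implements KNN matting (Chen, Li and Tang, CVPR 2012). Below is a minimal self-contained sketch of the technique, not the repository's exact code; the feature vector, k, and lambda are assumptions:

    import numpy as np
    import scipy.sparse
    import scipy.sparse.linalg
    from sklearn.neighbors import NearestNeighbors

    def knn_matte(img, trimap, k=10, lam=100.0):
        # img: H x W x 3 uint8, trimap: H x W uint8 with values near 0/128/255
        h, w, c = img.shape
        img = img / 255.0
        fg = trimap > 200            # known foreground pixels
        bg = trimap < 50             # known background pixels
        known = (fg | bg).ravel()

        # Feature per pixel: color plus (scaled) spatial coordinates
        yy, xx = np.mgrid[0:h, 0:w]
        feat = np.concatenate([img.reshape(-1, c),
                               (xx.ravel() / w)[:, None],
                               (yy.ravel() / h)[:, None]], axis=1)

        # k nearest neighbours in feature space -> sparse affinity matrix A
        dist, ind = NearestNeighbors(n_neighbors=k).fit(feat).kneighbors(feat)
        n = h * w
        rows = np.repeat(np.arange(n), k)
        vals = 1.0 - dist.ravel() / dist.max()
        A = scipy.sparse.coo_matrix((vals, (rows, ind.ravel())), shape=(n, n))

        # Solve (L + lam*M) alpha = lam*v, where L = D - A is the graph
        # Laplacian, M selects the known pixels, and v = 1 on known foreground
        D = scipy.sparse.diags(np.ravel(A.sum(axis=1)))
        M = scipy.sparse.diags(known.astype(np.float64))
        v = fg.ravel().astype(np.float64)
        alpha = scipy.sparse.linalg.spsolve(
            scipy.sparse.csr_matrix(D - A) + lam * M, lam * v)
        return np.clip(alpha, 0.0, 1.0).reshape(h, w)

Solving the sparse system for every pixel is what makes this step so slow on a 16 GB dataset.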

Step 2: build network

[Image: SHM network architecture]

  • Trimap generation: T-Net

    The T-Net plays the role of semantic segmentation. I use mobilenet_v2 + unet as the T-Net to predict the trimap.

  • Matting network: M-Net

    The M-Net aims to capture detail information and generate the alpha matte. I build the M-Net the same as in the paper, but reduce the number of channels of the original net.

  • Fusion Module

    Probabilistic estimation of the alpha matte can be written as alpha_p = F_s + U_s * alpha_r, where F_s and U_s are the foreground and unknown-region probabilities predicted by the T-Net and alpha_r is the raw matte predicted by the M-Net; see the sketch after this list.
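A minimal sketch of that fusion step, assuming the T-Net outputs 3-channel logits ordered background / unknown / foreground (the channel order is a guess) and the M-Net outputs alpha_r:

    import torch.nn.functional as F

    def fuse(trimap_logits, alpha_r):
        # Per-pixel probabilities for background / unknown / foreground
        probs = F.softmax(trimap_logits, dim=1)
        U_s = probs[:, 1:2, :, :]  # unknown-region probability (assumed channel 1)
        F_s = probs[:, 2:3, :, :]  # foreground probability (assumed channel 2)
        # alpha_p = F_s + U_s * alpha_r: keep the T-Net's confident regions and
        # blend in the M-Net's detailed matte inside the unknown band
        return F_s + U_s * alpha_r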

Step 3: build loss

The overall prediction loss for alpha_p at each pixel combines an alpha-prediction term and a compositional term: L_p = 0.5 * L_alpha + 0.5 * L_composition.

The total loss is L = L_p + lambda * L_t, the prediction loss plus the weighted trimap classification loss.

Read the paper for more details, and see my code for the two loss functions:

    import torch
    import torch.nn as nn

    # -------------------------------------
    # classification loss L_t
    # -------------------------------------
    # cross-entropy between predicted trimap logits and the ground-truth trimap
    criterion = nn.CrossEntropyLoss()
    L_t = criterion(trimap_pre, trimap_gt[:, 0, :, :].long())

    # -------------------------------------
    # prediction loss L_p
    # -------------------------------------
    eps = 1e-6  # stabilizes the square root around zero

    # L_alpha: error between predicted and ground-truth alpha
    L_alpha = torch.sqrt(torch.pow(alpha_pre - alpha_gt, 2.) + eps).mean()

    # L_composition: error of the image composited with predicted vs. true alpha
    fg = torch.cat((alpha_gt, alpha_gt, alpha_gt), 1) * img
    fg_pre = torch.cat((alpha_pre, alpha_pre, alpha_pre), 1) * img
    L_composition = torch.sqrt(torch.pow(fg - fg_pre, 2.) + eps).mean()
    L_p = 0.5 * L_alpha + 0.5 * L_composition

Step 4: train

First, pre-train the T-Net, using ./train.sh as:

python3 train.py \
	--dataDir='./data' \
	--saveDir='./ckpt' \
	--trainData='human_matting_data' \
	--trainList='./data/train.txt' \
	--load='human_matting' \
	--nThreads=4 \
	--patch_size=320 \
	--train_batch=8 \
	--lr=1e-3 \
	--lrdecayType='keep' \
	--nEpochs=1000 \
	--save_epoch=1 \
	--train_phase='pre_train_t_net'

Then, train end to end, using ./train.sh as:

python3 train.py \
	--dataDir='./data' \
	--saveDir='./ckpt' \
	--trainData='human_matting_data' \
	--trainList='./data/train.txt' \
	--load='human_matting' \
	--nThreads=4 \
	--patch_size=320 \
	--train_batch=8 \
	--lr=1e-4 \
	--lrdecayType='keep' \
	--nEpochs=2000 \
	--save_epoch=1 \
	--finetuning \
	--train_phase='end_to_end'

Test

Run ./test_camera.sh to try the trained model on a live camera feed.
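For reference, here is a rough sketch of what a camera test loop might look like; the checkpoint path, input size, preprocessing, and the model's return values are all assumptions, not necessarily what test_camera.py does:

    import cv2
    import numpy as np
    import torch

    # Assumed checkpoint location; the repo saves checkpoints under ./ckpt
    model = torch.load('./ckpt/human_matting/model/model_obj.pth',
                       map_location='cpu')
    model.eval()

    cap = cv2.VideoCapture(0)
    with torch.no_grad():
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # Resize to the training patch size and normalize (assumed)
            inp = cv2.resize(frame, (320, 320)).astype(np.float32) / 255.0
            inp = torch.from_numpy(inp.transpose(2, 0, 1)).unsqueeze(0)
            trimap, alpha = model(inp)  # assumed: SHM returns trimap and alpha
            alpha = alpha[0, 0].numpy()
            alpha = cv2.resize(alpha, (frame.shape[1], frame.shape[0]))
            # Composite the matted foreground over a white background
            out = frame * alpha[..., None] + 255 * (1 - alpha[..., None])
            cv2.imshow('matting', out.astype(np.uint8))
            if cv2.waitKey(1) & 0xFF == ord('q'):
                break
    cap.release()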
