Skip to content

Use automatic downloads of pretrained weights instead of hardcoded paths #31

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@

__pycache__/
models/seg_net_bayes.py

2 changes: 2 additions & 0 deletions models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@
from .fcn8s import *
from .gcn import *
from .psp_net import *
from .psp_net_multihead import *
from .seg_net import *
from .seg_net_bayes import *
from .u_net import *
14 changes: 4 additions & 10 deletions models/config.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
import os

# here (https://github.com/pytorch/vision/tree/master/torchvision/models) to find the download link of pretrained models

root = '/media/b3-542/LIBRARY/ZijunDeng/PyTorch Pretrained'
res101_path = os.path.join(root, 'ResNet', 'resnet101-5d3b4d8f.pth')
res152_path = os.path.join(root, 'ResNet', 'resnet152-b121ed2d.pth')
inception_v3_path = os.path.join(root, 'Inception', 'inception_v3_google-1a9a5a14.pth')
vgg19_bn_path = os.path.join(root, 'VggNet', 'vgg19_bn-c79401a0.pth')
vgg16_path = os.path.join(root, 'VggNet', 'vgg16-397923af.pth')
dense201_path = os.path.join(root, 'DenseNet', 'densenet201-4c113574.pth')
# PyTorch will automatically download pretrained weights into `os.environ['TORCH_MODEL_ZOO']`
# using the mechanism described here: (http://pytorch.org/docs/master/model_zoo.html)
# Download links used are also listed here: (https://github.com/pytorch/vision/tree/master/torchvision/models)

'''
vgg16 trained using caffe
visit this (https://github.com/jcjohnson/pytorch-vgg) to download the converted vgg16
'''
vgg16_caffe_path = os.path.join(root, 'VggNet', 'vgg16-caffe.pth')
vgg16_caffe_path = os.path.join(os.environ.get('TORCH_MODEL_ZOO', '.'), 'vgg16-caffe.pth')
7 changes: 1 addition & 6 deletions models/duc_hdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@
from torch import nn
from torchvision import models

from .config import res152_path


class _DenseUpsamplingConvModule(nn.Module):
def __init__(self, down_factor, in_dim, num_classes):
super(_DenseUpsamplingConvModule, self).__init__()
Expand All @@ -26,9 +23,7 @@ class ResNetDUC(nn.Module):
# the size of image should be multiple of 8
def __init__(self, num_classes, pretrained=True):
super(ResNetDUC, self).__init__()
resnet = models.resnet152()
if pretrained:
resnet.load_state_dict(torch.load(res152_path))
resnet = models.resnet152(pretrained=pretrained)
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
self.layer1 = resnet.layer1
self.layer2 = resnet.layer2
Expand Down
4 changes: 1 addition & 3 deletions models/fcn16s.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
class FCN16VGG(nn.Module):
def __init__(self, num_classes, pretrained=True):
super(FCN16VGG, self).__init__()
vgg = models.vgg16()
if pretrained:
vgg.load_state_dict(torch.load(vgg16_caffe_path))
vgg = models.vgg16(pretrained=pretrained)
features, classifier = list(vgg.features.children()), list(vgg.classifier.children())

features[0].padding = (100, 100)
Expand Down
5 changes: 1 addition & 4 deletions models/fcn32s.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@
from torchvision import models

from ..utils import get_upsampling_weight
from .config import vgg16_caffe_path


class FCN32VGG(nn.Module):
def __init__(self, num_classes, pretrained=True):
super(FCN32VGG, self).__init__()
vgg = models.vgg16()
if pretrained:
vgg.load_state_dict(torch.load(vgg16_caffe_path))
vgg = models.vgg16(pretrained=pretrained)
features, classifier = list(vgg.features.children()), list(vgg.classifier.children())

features[0].padding = (100, 100)
Expand Down
18 changes: 10 additions & 8 deletions models/fcn8s.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,22 @@
from torchvision import models

from ..utils import get_upsampling_weight
from .config import vgg16_path, vgg16_caffe_path
from .config import vgg16_caffe_path


# This is implemented in full accordance with the original one (https://github.com/shelhamer/fcn.berkeleyvision.org)
class FCN8s(nn.Module):
def __init__(self, num_classes, pretrained=True, caffe=False):
super(FCN8s, self).__init__()
vgg = models.vgg16()
if pretrained:
if caffe:
# load the pretrained vgg16 used by the paper's author
vgg.load_state_dict(torch.load(vgg16_caffe_path))
else:
vgg.load_state_dict(torch.load(vgg16_path))

if pretrained and caffe:
vgg = models.vgg16()
# load the pretrained vgg16 used by the paper's author
vgg.load_state_dict(torch.load(vgg16_caffe_path))
else:
# if pretrained, load the weights from PyTorch model zoo
vgg = models.vgg16(pretrained=pretrained)

features, classifier = list(vgg.features.children()), list(vgg.classifier.children())

'''
Expand Down
5 changes: 1 addition & 4 deletions models/gcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from torchvision import models

from ..utils import initialize_weights
from .config import res152_path


# many are borrowed from https://github.com/ycszen/pytorch-ss/blob/master/gcn.py
Expand Down Expand Up @@ -52,9 +51,7 @@ class GCN(nn.Module):
def __init__(self, num_classes, input_size, pretrained=True):
super(GCN, self).__init__()
self.input_size = input_size
resnet = models.resnet152()
if pretrained:
resnet.load_state_dict(torch.load(res152_path))
resnet = models.resnet152(pretrained=pretrained)
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu)
self.layer1 = nn.Sequential(resnet.maxpool, resnet.layer1)
self.layer2 = resnet.layer2
Expand Down
18 changes: 8 additions & 10 deletions models/psp_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@

from ..utils import initialize_weights
from ..utils.misc import Conv2dDeformable
from .config import res101_path


class _PyramidPoolingModule(nn.Module):
def __init__(self, in_dim, reduction_dim, setting):
super(_PyramidPoolingModule, self).__init__()
super().__init__()
self.features = []
for s in setting:
self.features.append(nn.Sequential(
Expand All @@ -25,18 +24,16 @@ def forward(self, x):
x_size = x.size()
out = [x]
for f in self.features:
out.append(F.upsample(f(x), x_size[2:], mode='bilinear'))
out.append(F.interpolate(f(x), x_size[2:], mode='bilinear'))
out = torch.cat(out, 1)
return out


class PSPNet(nn.Module):
def __init__(self, num_classes, pretrained=True, use_aux=True):
super(PSPNet, self).__init__()
super().__init__()
self.use_aux = use_aux
resnet = models.resnet101()
if pretrained:
resnet.load_state_dict(torch.load(res101_path))
resnet = models.resnet101(pretrained=pretrained)
self.layer0 = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
self.layer1, self.layer2, self.layer3, self.layer4 = resnet.layer1, resnet.layer2, resnet.layer3, resnet.layer4

Expand Down Expand Up @@ -66,7 +63,8 @@ def __init__(self, num_classes, pretrained=True, use_aux=True):

initialize_weights(self.ppm, self.final)

def forward(self, x):
def forward(self, image):
x = image
x_size = x.size()
x = self.layer0(x)
x = self.layer1(x)
Expand All @@ -78,8 +76,8 @@ def forward(self, x):
x = self.ppm(x)
x = self.final(x)
if self.training and self.use_aux:
return F.upsample(x, x_size[2:], mode='bilinear'), F.upsample(aux, x_size[2:], mode='bilinear')
return F.upsample(x, x_size[2:], mode='bilinear')
return F.interpolate(x, x_size[2:], mode='bilinear'), F.interpolate(aux, x_size[2:], mode='bilinear')
return F.interpolate(x, x_size[2:], mode='bilinear')


# just a try, not recommend to use
Expand Down
75 changes: 75 additions & 0 deletions models/psp_net_multihead.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import torch
import torch.nn.functional as F
from torch import nn
from torchvision import models

from ..utils import initialize_weights

from .psp_net import _PyramidPoolingModule


class PSPHead(nn.Module):
def __init__(self, num_classes):
super().__init__()

self.ppm = _PyramidPoolingModule(2048, 512, (1, 2, 3, 6))
self.final = nn.Sequential(
nn.Conv2d(4096, 512, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(512, momentum=.95),
nn.ReLU(inplace=True),
nn.Dropout(0.1),
nn.Conv2d(512, num_classes, kernel_size=1)
)

initialize_weights(self.ppm, self.final)

def forward(self, features_from_backbone, img_size):
result = self.final(self.ppm(features_from_backbone))
return F.interpolate(result, img_size[2:], mode='bilinear')


class PSPNet_Multihead(nn.Module):
def __init__(self, num_heads, num_classes, pretrained=True):
super().__init__()

self.init_heads(num_heads, num_classes=num_classes)
self.init_backbone(pretrained=pretrained)

def init_backbone(self, pretrained):
resnet = models.resnet101(pretrained=pretrained)

for n, m in resnet.layer3.named_modules():
if 'conv2' in n:
m.dilation, m.padding, m.stride = (2, 2), (2, 2), (1, 1)
elif 'downsample.0' in n:
m.stride = (1, 1)
for n, m in resnet.layer4.named_modules():
if 'conv2' in n:
m.dilation, m.padding, m.stride = (4, 4), (4, 4), (1, 1)
elif 'downsample.0' in n:
m.stride = (1, 1)

self.backbone = nn.Sequential(
nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool), # layer 0
resnet.layer1,
resnet.layer2,
resnet.layer3,
resnet.layer4,
)

def init_heads(self, num_heads, num_classes):

self.heads = nn.Sequential(
*[PSPHead(num_classes=num_classes) for i in range(num_heads)]
)

def forward(self, image):
img_size = image.size()

backbone_features = self.backbone(image)

return torch.cat([
head(backbone_features, img_size=img_size)
for head in self.heads
], dim=1)

7 changes: 2 additions & 5 deletions models/seg_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
from torchvision import models

from ..utils import initialize_weights
from .config import vgg19_bn_path


class _DecoderBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_conv_layers):
super(_DecoderBlock, self).__init__()
middle_channels = in_channels / 2
middle_channels = in_channels // 2
layers = [
nn.ConvTranspose2d(in_channels, in_channels, kernel_size=2, stride=2),
nn.Conv2d(in_channels, middle_channels, kernel_size=3, padding=1),
Expand All @@ -35,9 +34,7 @@ def forward(self, x):
class SegNet(nn.Module):
def __init__(self, num_classes, pretrained=True):
super(SegNet, self).__init__()
vgg = models.vgg19_bn()
if pretrained:
vgg.load_state_dict(torch.load(vgg19_bn_path))
vgg = models.vgg19_bn(pretrained=pretrained)
features = list(vgg.features.children())
self.enc1 = nn.Sequential(*features[0:7])
self.enc2 = nn.Sequential(*features[7:14])
Expand Down
99 changes: 99 additions & 0 deletions models/seg_net_bayes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@

import torch
from torch import nn
from torchvision import models

from ..utils import initialize_weights
from .seg_net import _DecoderBlock, SegNet

class SegNetBayes(SegNet):
def __init__(self, num_classes, dropout_p=0.5, pretrained=True, num_samples=16, min_batch_size=4):
"""
@param num_samples: number of samples for the Monte-Carlo simulation,
how many times to run with random dropout
"""
super().__init__(num_classes=num_classes, pretrained=pretrained)

self.drop = nn.Dropout2d(p=dropout_p, inplace=False)
self.num_samples = num_samples
self.min_batch_size = min_batch_size

def forward(self, x):
enc1 = self.enc1(x)
#print('enc1', enc1.shape)

enc2 = self.enc2(enc1)
#print('enc2', enc2.shape)

enc3 = self.enc3(enc2)
#print('enc3', enc3.shape)
enc3 = self.drop(enc3)
#print('enc3d', enc3.shape)

enc4 = self.enc4(enc3)
#print('enc4', enc4.shape)
enc4 = self.drop(enc4)
#print('enc4d', enc4.shape)

enc5 = self.enc5(enc4)
#print('enc5', enc5.shape)
enc5 = self.drop(enc5)
#print('enc5d', enc5.shape)

dec5 = self.dec5(enc5)
#print('dec5', dec5.shape)
dec5 = self.drop(dec5)
#print('dec5d', dec5.shape)

dec4 = self.dec4(torch.cat([enc4, dec5], 1))
#print('dec4', dec4.shape)
dec4 = self.drop(dec4)
#print('dec4d', dec4.shape)

dec3 = self.dec3(torch.cat([enc3, dec4], 1))
dec3 = self.drop(dec3)

dec2 = self.dec2(torch.cat([enc2, dec3], 1))
dec1 = self.dec1(torch.cat([enc1, dec2], 1))
return dec1

def forward_multisample(self, x, num_samples=None):
# dropout must be active
backup_train_mode = self.drop.training
self.drop.train()

softmax = torch.nn.Softmax2d()

num_samples = num_samples if num_samples else self.num_samples

results = [softmax(self.forward(x)).data.cpu() for i in range(num_samples)]

preds = torch.stack(results).cuda()
avg = torch.mean(preds, 0)
var = torch.var(preds, 0)
del preds

# restore mode
self.drop.train(backup_train_mode)

return dict(
mean = avg,
var = var,
)

#def sample(self, x, num_samples, batch_size):
#infer desired batch size from input shape
#we will divide a num_samples into batches
#num_frames = x.shape[0]
#batch_size = max(num_frames, self.min_batch_size)

#for sample_idx in range(num_samples):
#pred =


#for fr_idx in range(num_frames):
#x_single = x[fr_idx:fr_idx+1, :, :, :]
#self.sample(x_single, num_samples, batch_size)



Loading