Skip to content

Commit 83b2dfb

Browse files
ekagra-ranjanfmassa
authored andcommitted
Changing to AdaptiveAvgPool2d on VGG (#747)
The update allows VGG to process images larger or smaller than prescribed imagenet size using adaptive average pooling. Will be useful while finetuning or testing on different resolution images. Similar to #643 and #672. I did not include adaptive avg pool in features or classifier block so that these predefined blocks can be used as it is.
1 parent 6434dea commit 83b2dfb

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

torchvision/models/vgg.py

+2
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ class VGG(nn.Module):
2525
def __init__(self, features, num_classes=1000, init_weights=True):
2626
super(VGG, self).__init__()
2727
self.features = features
28+
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
2829
self.classifier = nn.Sequential(
2930
nn.Linear(512 * 7 * 7, 4096),
3031
nn.ReLU(True),
@@ -39,6 +40,7 @@ def __init__(self, features, num_classes=1000, init_weights=True):
3940

4041
def forward(self, x):
4142
x = self.features(x)
43+
x = self.avgpool(x)
4244
x = x.view(x.size(0), -1)
4345
x = self.classifier(x)
4446
return x

0 commit comments

Comments
 (0)