Skip to content

Commit

Permalink
♻️ Add typehints.
Browse files Browse the repository at this point in the history
  • Loading branch information
futabato committed Sep 3, 2024
1 parent 7dad97e commit 9c44a8e
Showing 1 changed file with 157 additions and 0 deletions.
157 changes: 157 additions & 0 deletions src/federatedlearning/models/resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""ResNet in PyTorch.
For Pre-activation ResNet, see 'preact_resnet.py'.
Reference:
[1] Kaiming He, Xiangyu Zhang, Shaoqing Ren, Jian Sun
Deep Residual Learning for Image Recognition. arXiv:1512.03385
"""
from typing import Type, Union

import torch
import torch.nn as nn
import torch.nn.functional as F


class BasicBlock(nn.Module):
expansion = 1

def __init__(self, in_planes: int, planes: int, stride: int = 1) -> None:
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes,
planes,
kernel_size=3,
stride=stride,
padding=1,
bias=False,
)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(self.expansion * planes),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = F.relu(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = F.relu(out)
return out


class Bottleneck(nn.Module):
expansion: int = 4

def __init__(self, in_planes: int, planes: int, stride: int = 1) -> None:
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(
planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(
planes, self.expansion * planes, kernel_size=1, bias=False
)
self.bn3 = nn.BatchNorm2d(self.expansion * planes)

self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(
in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False,
),
nn.BatchNorm2d(self.expansion * planes),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = F.relu(self.bn1(self.conv1(x)))
out = F.relu(self.bn2(self.conv2(out)))
out = self.bn3(self.conv3(out))
out += self.shortcut(x)
out = F.relu(out)
return out


class ResNet(nn.Module):
def __init__(
self,
block: Union[Type[BasicBlock], Type[Bottleneck]],
num_blocks: list[int],
num_classes: int = 10,
) -> None:
super(ResNet, self).__init__()
self.in_planes = 64

self.conv1 = nn.Conv2d(
3, 64, kernel_size=3, stride=1, padding=1, bias=False
)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.fc2 = nn.Linear(512 * block.expansion, num_classes)

def _make_layer(
self,
block: Union[Type[BasicBlock], Type[Bottleneck]],
planes: int,
num_blocks: int,
stride: int,
) -> nn.Sequential:
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = F.relu(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = F.avg_pool2d(out, 4)
out = out.view(out.size(0), -1)
out = self.fc2(out)
return out


def ResNet18() -> ResNet:
return ResNet(BasicBlock, [2, 2, 2, 2])


def ResNet34() -> ResNet:
return ResNet(BasicBlock, [3, 4, 6, 3])


def ResNet50() -> ResNet:
return ResNet(Bottleneck, [3, 4, 6, 3])


def ResNet101() -> ResNet:
return ResNet(Bottleneck, [3, 4, 23, 3])


def ResNet152() -> ResNet:
return ResNet(Bottleneck, [3, 8, 36, 3])

0 comments on commit 9c44a8e

Please # to comment.