-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathconv.py
41 lines (33 loc) · 1.09 KB
/
conv.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from accelerate import init_empty_weights
import torch
@classmethod
def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
    """Create a ``cls`` convolution and populate it from a weights store.

    Args:
        cls: Module class to instantiate; it is called with the keyword
            arguments ``in_channels``, ``out_channels``, ``kernel_size``
            and ``stride``.
        prefix: Name prefix of the layer; the tensors ``{prefix}.weight``
            and ``{prefix}.bias`` are requested from ``weights``.
        weights: Object exposing ``get_tensor(name)``.
        in_channels: Forwarded to the ``cls`` constructor.
        out_channels: Forwarded to the ``cls`` constructor.
        kernel_size: Forwarded to the ``cls`` constructor.
        stride: Forwarded to the ``cls`` constructor.

    Returns:
        The new module, whose ``weight`` and ``bias`` are
        ``torch.nn.Parameter`` wrappers around the fetched tensors.
    """
    # Fetch both tensors up front (weight first, then bias) so a missing
    # entry surfaces before any module is built.
    loaded_weight = weights.get_tensor(f"{prefix}.weight")
    loaded_bias = weights.get_tensor(f"{prefix}.bias")

    ctor_kwargs = {
        "in_channels": in_channels,
        "out_channels": out_channels,
        "kernel_size": kernel_size,
        "stride": stride,
    }
    # The parameters the constructor creates are placeholders only: both
    # are overwritten right after the context exits.
    with init_empty_weights():
        layer = cls(**ctor_kwargs)

    layer.weight = torch.nn.Parameter(loaded_weight)
    layer.bias = torch.nn.Parameter(loaded_bias)
    return layer
@classmethod
def load_conv2d_no_bias(
    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
):
    """Create a bias-free ``cls`` convolution and load its weight.

    Args:
        cls: Module class to instantiate; it is called with the keyword
            arguments ``in_channels``, ``out_channels``, ``kernel_size``,
            ``stride`` and ``bias=False``.
        prefix: Name prefix of the layer; the tensor ``{prefix}.weight``
            is requested from ``weights``.
        weights: Object exposing ``get_tensor(name)``.
        in_channels: Forwarded to the ``cls`` constructor.
        out_channels: Forwarded to the ``cls`` constructor.
        kernel_size: Forwarded to the ``cls`` constructor.
        stride: Forwarded to the ``cls`` constructor.

    Returns:
        The new module, whose ``weight`` is a ``torch.nn.Parameter``
        wrapping the fetched tensor and whose ``bias`` is ``None``.
    """
    weight = weights.get_tensor(f"{prefix}.weight")
    with init_empty_weights():
        # bias=False makes the constructor register ``bias`` as None
        # directly, instead of creating and initialising a bias parameter
        # that would be thrown away immediately afterwards.
        conv2d = cls(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            bias=False,
        )

    conv2d.weight = torch.nn.Parameter(weight)
    return conv2d
# Attach the loaders to torch.nn.Conv2d when this module is imported. Both
# objects are classmethods, so they are invoked on the class itself, e.g.
# torch.nn.Conv2d.load(prefix, weights, in_channels, out_channels,
# kernel_size, stride), and receive the class as ``cls``.
torch.nn.Conv2d.load = load_conv2d
torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias