# _quan_base.py (forked from hustzxd/LSQuantization)
"""
Quantized modules: the base class
"""
import torch
import torch.nn as nn
from enum import Enum
from torch.nn.parameter import Parameter
__all__ = ['Qmodes', '_Conv2dQ', '_LinearQ', '_ActQ']
class Qmodes(Enum):
    layer_wise = 1
    kernel_wise = 2
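

# NOTE: the classes below call a `get_default_kwargs_q` helper that is not
# defined (or imported) in this file. The version below is an assumption,
# modelled on the upstream hustzxd/LSQuantization code: it fills per-layer-type
# defaults into kwargs_q and returns the merged dict.
def get_default_kwargs_q(kwargs_q, layer_type):
    default = {'nbits': 4}
    if isinstance(layer_type, _Conv2dQ):
        default.update({'mode': Qmodes.layer_wise})
    elif isinstance(layer_type, _LinearQ):
        pass
    elif isinstance(layer_type, _ActQ):
        default.update({'signed': False})
    else:
        raise NotImplementedError('Unknown layer type: {}'.format(type(layer_type)))
    for k, v in default.items():
        if k not in kwargs_q:
            kwargs_q[k] = v
    return kwargs_q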
class _Conv2dQ(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, **kwargs_q):
        super(_Conv2dQ, self).__init__(in_channels, out_channels, kernel_size, stride=stride,
                                       padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = self.kwargs_q['nbits']
        if self.nbits < 0:
            # A negative bit-width disables quantization: no step size is learned.
            self.register_parameter('alpha', None)
            return
        self.q_mode = self.kwargs_q['mode']
        if self.q_mode == Qmodes.kernel_wise:
            # One learnable step size (alpha) per output channel.
            self.alpha = Parameter(torch.Tensor(out_channels))
        else:  # layer-wise quantization: a single step size for the whole layer
            self.alpha = Parameter(torch.Tensor(1))
        # init_state flags whether alpha has been initialized from the first batch.
        self.register_buffer('init_state', torch.zeros(1))

    def add_param(self, param_k, param_v):
        self.kwargs_q[param_k] = param_v

    def extra_repr(self):
        s_prefix = super(_Conv2dQ, self).extra_repr()
        if self.alpha is None:
            return '{}, fake'.format(s_prefix)
        return '{}, {}'.format(s_prefix, self.kwargs_q)
class _LinearQ(nn.Linear):
    def __init__(self, in_features, out_features, bias=True, **kwargs_q):
        super(_LinearQ, self).__init__(in_features=in_features, out_features=out_features, bias=bias)
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = self.kwargs_q['nbits']
        if self.nbits < 0:
            self.register_parameter('alpha', None)
            return
        # Linear layers use a single, layer-wise step size.
        self.alpha = Parameter(torch.Tensor(1))
        self.register_buffer('init_state', torch.zeros(1))

    def add_param(self, param_k, param_v):
        self.kwargs_q[param_k] = param_v

    def extra_repr(self):
        s_prefix = super(_LinearQ, self).extra_repr()
        if self.alpha is None:
            return '{}, fake'.format(s_prefix)
        return '{}, {}'.format(s_prefix, self.kwargs_q)
class _ActQ(nn.Module):
    def __init__(self, **kwargs_q):
        super(_ActQ, self).__init__()
        self.kwargs_q = get_default_kwargs_q(kwargs_q, layer_type=self)
        self.nbits = self.kwargs_q['nbits']
        if self.nbits < 0:
            self.register_parameter('alpha', None)
            return
        # signed activations quantize to a symmetric range; otherwise unsigned levels are used.
        self.signed = self.kwargs_q['signed']
        self.alpha = Parameter(torch.Tensor(1))
        self.register_buffer('init_state', torch.zeros(1))

    def add_param(self, param_k, param_v):
        self.kwargs_q[param_k] = param_v

    def extra_repr(self):
        # nn.Module.extra_repr() is empty, so only report the quantization kwargs.
        if self.alpha is None:
            return 'fake'
        return '{}'.format(self.kwargs_q)
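

# Illustration only (not part of the original module): concrete quantized layers
# are expected to subclass _Conv2dQ / _LinearQ / _ActQ, implement forward(), and
# initialize alpha from the first batch via the init_state buffer. The hypothetical
# subclass below assumes layer-wise mode and uses a plain clamp/round fake-quantizer
# with a straight-through estimator; it is NOT the repository's LSQ forward pass
# (which also scales alpha's gradient), just a minimal sketch of the intended pattern.
class Conv2dQExample(_Conv2dQ):
    def forward(self, x):
        if self.alpha is None:  # nbits < 0: behave like a plain nn.Conv2d
            return torch.nn.functional.conv2d(x, self.weight, self.bias, self.stride,
                                              self.padding, self.dilation, self.groups)
        Qn, Qp = -2 ** (self.nbits - 1), 2 ** (self.nbits - 1) - 1
        if self.training and self.init_state == 0:
            # Heuristic first-batch initialization of the step size.
            self.alpha.data.copy_(2 * self.weight.abs().mean() / (Qp ** 0.5))
            self.init_state.fill_(1)
        w = self.weight / self.alpha
        w_q = (w.clamp(Qn, Qp).round() - w).detach() + w  # straight-through estimator
        w_q = w_q * self.alpha
        return torch.nn.functional.conv2d(x, w_q, self.bias, self.stride,
                                          self.padding, self.dilation, self.groups)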