forked from mbhenaff/spectral-lib
-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathBias.lua
65 lines (57 loc) · 1.53 KB
/
Bias.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
require 'nn'
local Bias, parent = torch.class('nn.Bias', 'nn.Module')

--- Constructs a module holding a learnable bias vector of length `nPlanes`.
-- The bias is applied by the `libspectralnet` C routines used in the
-- forward/backward methods of this class.
-- @param nPlanes number of bias entries; sizes both `bias` and `gradBias`
-- @param stdv optional; when set, argument-less `reset()` calls use it
--   directly as the bound of the uniform initialisation
function Bias:__init(nPlanes, stdv)
   parent.__init(self)
   -- torch.Tensor(n) leaves its contents uninitialised; the bias is filled
   -- by reset() below. NOTE(review): gradBias is not zeroed here —
   -- presumably callers run zeroGradParameters() before accumulating.
   self.bias = torch.Tensor(nPlanes)
   self.gradBias = torch.Tensor(nPlanes)
   -- Remembered so later reset() calls without an argument reuse it.
   self.stdv = stdv
   self:reset()
end
--- Re-initialises `self.bias` from a uniform distribution on [-b, b].
-- The bound b is chosen as follows:
--   * explicit `stdv` argument: b = stdv * sqrt(3). A uniform variable on
--     [-b, b] has standard deviation b / sqrt(3), so `stdv` becomes the
--     resulting standard deviation;
--   * otherwise `self.stdv`, if set, used as b directly (not rescaled);
--   * otherwise b = 1 / sqrt(#bias).
-- @param stdv optional target standard deviation
function Bias:reset(stdv)
   local bound
   if stdv then
      bound = stdv * math.sqrt(3)
   else
      bound = self.stdv or 1 / math.sqrt(self.bias:size(1))
   end
   self.bias:uniform(-bound, bound)
end
--- Forward pass: copies `input` into `self.output` and lets the C routine
-- `libspectralnet.bias_updateOutput` modify that buffer in place with
-- `self.bias`.
-- A 3-D input is presented to the routine as 4-D with a trailing singleton
-- dimension (NOTE(review): presumably the kernel only accepts 4-D tensors —
-- confirm in the C source). Higher-dimensional inputs are passed with their
-- own shape.
-- `input` itself is never resized, so shared or non-contiguous inputs keep
-- their sizes and strides even if the C routine raises an error.
-- @param input tensor with at least 3 dimensions
-- @return self.output, with the same sizes as `input`
function Bias:updateOutput(input)
   local nDim = input:nDimension()
   if nDim < 3 then
      error(string.format(
         'nn.Bias: expected input with at least 3 dimensions, got %d', nDim), 2)
   end
   local output = self.output
   if nDim == 3 then
      -- copy() only requires matching element counts, so the 4-D output
      -- buffer can be filled straight from the 3-D input.
      output:resize(input:size(1), input:size(2), input:size(3), 1)
   else
      output:resizeAs(input)
   end
   output:copy(input)
   libspectralnet.bias_updateOutput(self.bias, output)
   if nDim == 3 then
      -- Hand the caller an output shaped like its input again.
      output:resizeAs(input)
   end
   return output
end
--- Backward pass with respect to the input.
-- The gradient is passed through unchanged: `self.gradInput` is sized like
-- `input` and filled with a copy of `gradOutput`. That is the identity
-- Jacobian expected of an additive bias; the forward kernel is presumed to
-- only add `bias`.
-- @param input the forward-pass input (only its sizes are used)
-- @param gradOutput gradient w.r.t. the output
-- @return self.gradInput
function Bias:updateGradInput(input, gradOutput)
   local gradIn = self.gradInput
   gradIn:resizeAs(input)
   gradIn:copy(gradOutput)
   return gradIn
end
--- Accumulates the bias gradient into `self.gradBias`.
-- The work is done by the C routine `libspectralnet.bias_accGradParameters`,
-- which receives `scale`.
-- A 3-D `gradOutput` is handed to that routine as a 4-D alias with a trailing
-- singleton dimension, mirroring the forward pass. `gradOutput` itself is
-- never resized, so the caller's tensor keeps its sizes and strides.
-- @param input unused; kept for the nn.Module interface
-- @param gradOutput gradient w.r.t. the output, at least 3 dimensions
-- @param scale optional gradient multiplier, defaults to 1
function Bias:accGradParameters(input, gradOutput, scale)
   scale = scale or 1
   local nDim = gradOutput:nDimension()
   if nDim < 3 then
      error(string.format(
         'nn.Bias: expected gradOutput with at least 3 dimensions, got %d',
         nDim), 2)
   end
   local grad = gradOutput
   if nDim == 3 then
      -- view() needs contiguous memory. contiguous() returns gradOutput
      -- itself when it already is contiguous, otherwise a compact copy.
      grad = gradOutput:contiguous():view(
         gradOutput:size(1), gradOutput:size(2), gradOutput:size(3), 1)
   end
   libspectralnet.bias_accGradParameters(self.gradBias, grad, scale)
end