Skip to content

Commit 4411c86

Browse files
authoredJun 1, 2021
Merge pull request #320 from maxfreu/upsample-linear
introduce linear upsampling
2 parents 1dbdcef + b5abee5 commit 4411c86

File tree

2 files changed

+130
-2
lines changed

2 files changed

+130
-2
lines changed
 

‎src/upsample.jl

+115
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
export upsample_nearest, ∇upsample_nearest,
2+
upsample_linear, ∇upsample_linear,
23
upsample_bilinear, ∇upsample_bilinear,
34
upsample_trilinear, ∇upsample_trilinear,
45
pixel_shuffle
@@ -96,6 +97,120 @@ end
9697
return input_index0, input_index1, lambda0, lambda1
9798
end
9899

100+
###########
101+
# linear
102+
###########
103+
"""
104+
upsample_linear(x::AbstractArray{T,3}, scale::Real)
105+
upsample_linear(x::AbstractArray{T,3}; size::Integer)
106+
107+
Upsamples the first dimension of the array `x` by the upsample provided `scale`,
108+
using linear interpolation. As an alternative to using `scale`, the resulting array `size`
109+
can be directly specified with a keyword argument.
110+
111+
The size of the output is equal to
112+
`(scale*S1, S2, S3)`, where `S1, S2, S3 = size(x)`.
113+
"""
114+
function upsample_linear(x::AbstractArray{<:Any,3}, scale::Real)
115+
outsize = floor(Int, scale * Base.size(x)[1])
116+
return upsample_linear(x; size=outsize)
117+
end
118+
119+
function upsample_linear(x::AbstractArray{T,3}; size::Integer) where T
120+
w,c,n = Base.size(x)
121+
if w == size
122+
return x
123+
end
124+
y = similar(x, T, size, c, n)
125+
return upsample_linear_wcn!(y, x)
126+
end
127+
128+
function upsample_linear(x::AbstractArray{T,3}; size::Integer) where T<:Integer
129+
y = float.(x)
130+
res = upsample_linear(y; size=size)
131+
return round.(T, res)
132+
end
133+
134+
function upsample_linear_wcn!(output::AbstractArray{T,3}, input::AbstractArray{T,3}) where T
135+
size(input)[2:3] == size(output)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
136+
in_w, channels, batches = size(input)
137+
# treat batch and channel dimension as one for better parallelization granularity
138+
channels *= batches
139+
out_w, _, _ = size(output)
140+
output_slice_size = out_w
141+
142+
# T() and // so that we can handle rationals (super slow)
143+
width_scale = T((in_w - 1) // (out_w - 1))
144+
145+
@inline idx(c, w) = c * in_w + w + 1
146+
147+
@inbounds Threads.@threads for c in 0:channels-1
148+
for ow in 0:out_w-1
149+
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
150+
output_offset = c * output_slice_size + ow + 1
151+
output[output_offset] = (w0lambda * input[idx(c, iw0)] + # w0 * i00
152+
w1lambda * input[idx(c, iw1)]) # w1 * i01
153+
end
154+
end
155+
return output
156+
end
157+
158+
"""
159+
∇upsample_linear(Δ::AbstractArray{T,3}; size::Integer) where T
160+
161+
# Arguments
162+
- `Δ`: Incoming gradient array, backpropagated from downstream layers
163+
- `size`: Size of the image upsampled in the first place
164+
165+
# Outputs
166+
- `dx`: Downsampled version of `Δ`
167+
"""
168+
function ∇upsample_linear::AbstractArray{T,3}; size::Integer) where T
169+
w, c, n = Base.size(Δ)
170+
out_w = size
171+
if w == out_w
172+
return Δ
173+
end
174+
dx = zero(similar(Δ, T, out_w, c, n))
175+
return ∇upsample_linear_wcn!(dx, Δ)
176+
end
177+
178+
function ∇upsample_linear_wcn!(dx::AbstractArray{T,3}, Δ::AbstractArray{T,3}) where T
179+
size(dx)[2:3] == size(Δ)[2:3] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
180+
in_w, channels, batches = size(dx)
181+
182+
# treat batch and channel dimension as one for better parallelization granularity
183+
channels *= batches
184+
out_w, _, _ = size(Δ)
185+
output_slice_size = out_w
186+
187+
width_scale = T((in_w - 1) // (out_w - 1))
188+
189+
@inline idx(c, w) = c * in_w + w + 1
190+
191+
@inbounds Threads.@threads for c in 0:channels-1
192+
for ow in 0:out_w-1
193+
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
194+
output_offset = c * output_slice_size + ow + 1
195+
Δ_value = Δ[output_offset]
196+
dx[idx(c, iw0)] += w0lambda * Δ_value # i00
197+
dx[idx(c, iw1)] += w1lambda * Δ_value # i01
198+
end
199+
end
200+
return dx
201+
end
202+
203+
function rrule(::typeof(upsample_linear), x; size)
204+
Ω = upsample_linear(x; size=size)
205+
function upsample_linear_pullback(Δ)
206+
(NO_FIELDS, ∇upsample_linear(Δ; size=Base.size(x,1)))
207+
end
208+
return Ω, upsample_linear_pullback
209+
end
210+
211+
###########
212+
# bilinear
213+
###########
99214
"""
100215
upsample_bilinear(x::AbstractArray{T,4}, scale::NTuple{2,Real})
101216
upsample_bilinear(x::AbstractArray{T,4}; size::NTuple{2,Integer})

‎test/upsample.jl

+15-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,20 @@
1717
@test_throws ArgumentError upsample_nearest(x, size=(3,4))
1818
end
1919

20-
@testset "upsample_bilinear 2d" begin
20+
@testset "Linear upsampling (1D)" begin
21+
x = Float64[1,2,3,4]
22+
x = hcat(x,x,x)[:,:,:]
23+
24+
y = collect(1:1//3:4)
25+
y = hcat(y,y,y)[:,:,:]
26+
yF64 = Float64.(y)
27+
28+
@test y upsample_linear(x, 2.5)
29+
@test y upsample_linear(x; size=10)
30+
gradtest(x->upsample_linear(x, 2.5), x)
31+
end
32+
33+
@testset "Bilinear upsampling (2D)" begin
2134
x = Float32[1 2; 3 4][:,:,:,:]
2235
x = cat(x,x; dims=3)
2336
x = cat(x,x; dims=4)
@@ -65,7 +78,7 @@ end
6578
@test y == y_true_int
6679
end
6780

68-
@testset "Trilinear upsampling" begin
81+
@testset "Trilinear upsampling (3D)" begin
6982
# Layout: WHDCN, where D is depth
7083
# we generate data which is constant along W & H and differs in D
7184
# then we upsample along all dimensions

0 commit comments

Comments
 (0)