Skip to content

Commit 908ec40

Browse files
Merge pull request #315 from maxfreu/trilinear
introduce trilinear upsampling
2 parents a72e085 + a2bfabd commit 908ec40

File tree

2 files changed

+185
-2
lines changed

2 files changed

+185
-2
lines changed

src/upsample.jl

+153-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
export upsample_nearest, ∇upsample_nearest,
22
upsample_bilinear, ∇upsample_bilinear,
3+
upsample_trilinear, ∇upsample_trilinear,
34
pixel_shuffle
45

56
"""
@@ -9,7 +10,7 @@ export upsample_nearest, ∇upsample_nearest,
910
Upsamples the array `x` by integer multiples along the first `S` dimensions.
1011
Subsequent dimensions of `x` are not altered.
1112
12-
Either the `scale` factors or the final output `size` can be specified.
13+
Either the `scale` factors or the final output `size` can be specified.
1314
1415
See also [`upsample_bilinear`](@ref), for two dimensions of an `N=4` array.
1516
@@ -257,6 +258,157 @@ function rrule(::typeof(upsample_bilinear), x; size)
257258
return Ω, upsample_bilinear_pullback
258259
end
259260

261+
###########
262+
# trilinear
263+
###########
264+
"""
265+
upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real})
266+
upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer})
267+
268+
Upsamples the first 3 dimensions of the array `x` by the upsample factors stored in `scale`,
269+
using trilinear interpolation. As an alternative to using `scale`, the resulting image `size`
270+
can be directly specified with a keyword argument.
271+
272+
The size of the output is equal to
273+
`(scale[1]*S1, scale[2]*S2, scale[3]*S3, S4, S5)`, where `S1, S2, S3, S4, S5 = size(x)`.
274+
275+
# Examples
276+
277+
```julia
278+
upsample_trilinear(x, (2, 3, 4))
279+
upsample_trilinear(x; size=(4, 9, 11)) # specify ouput size instead
280+
upsample_trilinear(x, (2.5, 3.5, pi)) # non-integer scaling factors are allowed
281+
```
282+
"""
283+
function upsample_trilinear(x::AbstractArray{<:Any,5}, scale::NTuple{3,Real})
284+
outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), 3)
285+
return upsample_trilinear(x; size=outsize)
286+
end
287+
288+
upsample_trilinear(x, scale::Real) = upsample_trilinear(x, (scale,scale,scale))
289+
290+
function upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
291+
w,h,d,c,n = Base.size(x)
292+
if (w,h,d) == size
293+
return x
294+
end
295+
y = similar(x, T, size..., c, n)
296+
return upsample_trilinear_whdcn!(y, x)
297+
end
298+
299+
function upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) where T<:Integer
300+
y = float.(x)
301+
res = upsample_trilinear(y; size=size)
302+
return round.(T, res)
303+
end
304+
305+
function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
306+
size(input)[4:5] == size(output)[4:5] || error("Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output))")
307+
in_w, in_h, in_d, channels, batches = size(input)
308+
# treat batch and channel dimension as one for better parallelization granularity
309+
channels *= batches
310+
out_w, out_h, out_d, _, _ = size(output)
311+
output_slice_size = out_h * out_w * out_d
312+
313+
# T() and // so that we can handle rationals (super slow)
314+
width_scale = T((in_w - 1) // (out_w - 1))
315+
height_scale = T((in_h - 1) // (out_h - 1))
316+
depth_scale = T((in_d - 1) // (out_d - 1))
317+
318+
@inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1
319+
320+
@inbounds Threads.@threads for c in 0:channels-1
321+
for od in 0:out_d-1
322+
id0, id1, d0lambda, d1lambda = compute_source_index_and_lambda(depth_scale, od, in_d, out_d)
323+
for oh in 0:out_h-1
324+
ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
325+
for ow in 0:out_w-1
326+
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
327+
output_offset = c * output_slice_size + od * out_w * out_h + oh * out_w + ow + 1
328+
output[output_offset] =
329+
d0lambda * h0lambda * w0lambda * input[idx(c, id0, ih0, iw0)] + # d0 * h0 * w0 * i000
330+
d0lambda * h0lambda * w1lambda * input[idx(c, id0, ih0, iw1)] + # d0 * h0 * w1 * i001
331+
d0lambda * h1lambda * w0lambda * input[idx(c, id0, ih1, iw0)] + # d0 * h1 * w0 * i010
332+
d0lambda * h1lambda * w1lambda * input[idx(c, id0, ih1, iw1)] + # d0 * h1 * w1 * i011
333+
d1lambda * h0lambda * w0lambda * input[idx(c, id1, ih0, iw0)] + # d1 * h0 * w0 * i100
334+
d1lambda * h0lambda * w1lambda * input[idx(c, id1, ih0, iw1)] + # d1 * h0 * w1 * i101
335+
d1lambda * h1lambda * w0lambda * input[idx(c, id1, ih1, iw0)] + # d1 * h1 * w0 * i110
336+
d1lambda * h1lambda * w1lambda * input[idx(c, id1, ih1, iw1)] # d1 * h1 * w1 * i111
337+
end
338+
end
339+
end
340+
end
341+
return output
342+
end
343+
344+
"""
345+
∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
346+
347+
# Arguments
348+
- `Δ`: Incoming gradient array, backpropagated from downstream layers
349+
- `size`: Lateral size & depth (W,H,D) of the image upsampled in the first place
350+
351+
# Outputs
352+
- `dx`: Downsampled version of `Δ`
353+
"""
354+
function ∇upsample_trilinear::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
355+
w, h, d, c, n = Base.size(Δ)
356+
out_w, out_h, out_d = size
357+
if (w,h,d) == (out_w, out_h, out_d)
358+
return Δ
359+
end
360+
dx = zero(similar(Δ, T, size..., c, n))
361+
return ∇upsample_trilinear_whdcn!(dx, Δ)
362+
end
363+
364+
function ∇upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
365+
size(dx)[4:5] == size(Δ)[4:5] || error("Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ))")
366+
in_w, in_h, in_d, channels, batches = size(dx)
367+
# treat batch and channel dimension as one for better parallelization granularity
368+
channels *= batches
369+
out_w, out_h, out_d, _, _ = size(Δ)
370+
output_slice_size = out_h * out_w * out_d
371+
372+
# T() and // so that we can handle rationals (super slow)
373+
width_scale = T((in_w - 1) // (out_w - 1))
374+
height_scale = T((in_h - 1) // (out_h - 1))
375+
depth_scale = T((in_d - 1) // (out_d - 1))
376+
377+
@inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1
378+
379+
@inbounds Threads.@threads for c in 0:channels-1
380+
for od in 0:out_d-1
381+
id0, id1, d0lambda, d1lambda = compute_source_index_and_lambda(depth_scale, od, in_d, out_d)
382+
for oh in 0:out_h-1
383+
ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
384+
for ow in 0:out_w-1
385+
iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
386+
output_offset = c * output_slice_size + od * out_w * out_h + oh * out_w + ow + 1
387+
Δ_value = Δ[output_offset]
388+
dx[idx(c, id0, ih0, iw0)] += d0lambda * h0lambda * w0lambda * Δ_value # /* i000 */
389+
dx[idx(c, id0, ih0, iw1)] += d0lambda * h0lambda * w1lambda * Δ_value # /* i001 */
390+
dx[idx(c, id0, ih1, iw0)] += d0lambda * h1lambda * w0lambda * Δ_value # /* i010 */
391+
dx[idx(c, id0, ih1, iw1)] += d0lambda * h1lambda * w1lambda * Δ_value # /* i011 */
392+
dx[idx(c, id1, ih0, iw0)] += d1lambda * h0lambda * w0lambda * Δ_value # /* i100 */
393+
dx[idx(c, id1, ih0, iw1)] += d1lambda * h0lambda * w1lambda * Δ_value # /* i101 */
394+
dx[idx(c, id1, ih1, iw0)] += d1lambda * h1lambda * w0lambda * Δ_value # /* i110 */
395+
dx[idx(c, id1, ih1, iw1)] += d1lambda * h1lambda * w1lambda * Δ_value # /* i111 */
396+
end
397+
end
398+
end
399+
end
400+
return dx
401+
end
402+
403+
function rrule(::typeof(upsample_trilinear), x; size)
404+
Ω = upsample_trilinear(x; size=size)
405+
function upsample_trilinear_pullback(Δ)
406+
(NO_FIELDS, ∇upsample_trilinear(Δ; size=(Base.size(x,1), Base.size(x,2), Base.size(x,3))))
407+
end
408+
return Ω, upsample_trilinear_pullback
409+
end
410+
411+
260412
"""
261413
pixel_shuffle(x, r::Integer)
262414

test/upsample.jl

+32-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
y = upsample_nearest(x, (2,3))
66
@test size(y) == (4,6,1,1)
77
∇upsample_nearest(y, (2,3)) == [6 12; 18 24]
8-
8+
99
gradtest(x -> upsample_nearest(x, (2,3)), rand(2,2,1,1))
1010

1111
y2 = upsample_nearest(x, size=(4,6))
@@ -65,6 +65,37 @@ end
6565
@test y == y_true_int
6666
end
6767

68+
@testset "Trilinear upsampling" begin
69+
# Layout: WHDCN, where D is depth
70+
# we generate data which is constant along W & H and differs in D
71+
# then we upsample along all dimensions
72+
x = ones(Float32, 3,3,3,1,1)
73+
x[:,:,1,:,:] .= 1.
74+
x[:,:,2,:,:] .= 2.
75+
x[:,:,3,:,:] .= 3.
76+
77+
y_true = ones(Float32, 5,5,5,1,1)
78+
y_true[:,:,1,:,:] .= 1.
79+
y_true[:,:,2,:,:] .= 1.5
80+
y_true[:,:,3,:,:] .= 2.
81+
y_true[:,:,4,:,:] .= 2.5
82+
y_true[:,:,5,:,:] .= 3.
83+
84+
y = upsample_trilinear(x; size=(5,5,5))
85+
86+
@test size(y) == size(y_true)
87+
@test eltype(y) == Float32
88+
@test collect(y) collect(y_true)
89+
90+
# this test only works when align_corners=false (not present for CPU yet)
91+
# o = ones(Float32,8,8,8,1,1)
92+
# grad_true = 8*ones(Float32,4,4,4,1,1)
93+
# @test ∇upsample_trilinear(o; size=(4,4,4)) ≈ grad_true
94+
95+
x = Float64.(x)
96+
gradtest(x -> upsample_trilinear(x, (2,2,2)), x)
97+
end
98+
6899
@testset "pixel_shuffle" begin
69100
x = reshape(1:16, (2, 2, 4, 1))
70101
# [:, :, 1, 1] =

0 commit comments

Comments
 (0)