Skip to content

Commit 5531d91

Browse files
author
Max Freudenberg
committed
introduce trilinear upsampling
1 parent b1633f5 commit 5531d91

File tree

1 file changed

+84
-1
lines changed

1 file changed

+84
-1
lines changed

src/upsample.jl

+84-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
export upsample_nearest, ∇upsample_nearest,
22
upsample_bilinear, ∇upsample_bilinear,
3+
upsample_trilinear, ∇upsample_trilinear,
34
pixel_shuffle
45

56
"""
@@ -9,7 +10,7 @@ export upsample_nearest, ∇upsample_nearest,
910
Upsamples the array `x` by integer multiples along the first `S` dimensions.
1011
Subsequent dimensions of `x` are not altered.
1112
12-
Either the `scale` factors or the final output `size` can be specified.
13+
Either the `scale` factors or the final output `size` can be specified.
1314
1415
See also [`upsample_bilinear`](@ref), for two dimensions of an `N=4` array.
1516
@@ -257,6 +258,88 @@ function rrule(::typeof(upsample_bilinear), x; size)
257258
return Ω, upsample_bilinear_pullback
258259
end
259260

261+
###########
262+
# trilinear
263+
###########
264+
"""
265+
upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real})
266+
upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer})
267+
268+
Upsamples the first 3 dimensions of the array `x` by the upsample factors stored in `scale`,
269+
using trilinear interpolation. As an alternative to using `scale`, the resulting image `size`
270+
can be directly specified with a keyword argument.
271+
272+
The size of the output is equal to
273+
`(scale[1]*S1, scale[2]*S2, scale[3]*S3, S4, S5)`, where `S1, S2, S3, S4, S5 = size(x)`.
274+
275+
# Examples
276+
277+
```julia
278+
upsample_trilinear(x, (2, 3, 4))
279+
upsample_trilinear(x; size=(4, 9, 11)) # specify ouput size instead
280+
upsample_trilinear(x, (2.5, 3.5, pi)) # non-integer scaling factors are allowed
281+
```
282+
"""
283+
function upsample_trilinear(x::AbstractArray{<:Any,5}, scale::NTuple{3,Real})
284+
outsize = ntuple(i -> floor(Int, scale[i] * Base.size(x, i)), 3)
285+
return upsample_trilinear(x; size=outsize)
286+
end
287+
288+
upsample_trilinear(x, scale::Real) = upsample_trilinear(x, (scale,scale,scale))
289+
290+
function upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
291+
w,h,d,c,n = Base.size(x)
292+
if (w,h,d) == size
293+
return x
294+
end
295+
y = similar(x, T, size..., c, n)
296+
return upsample_trilinear_whdcn!(y, x)
297+
end
298+
299+
function upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) where T<:Integer
300+
y = float.(x)
301+
res = upsample_trilinear(y; size=size)
302+
return round.(T, res)
303+
end
304+
305+
# placeholder for CPU implementation
306+
# which is overloaded in NNlibCUDA
307+
# I think a CPU implementation doesn't make sense, as it will be way too slow
308+
# for any meaningful data
309+
function upsample_trilinear_whdcn! end
310+
311+
"""
312+
∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
313+
314+
# Arguments
315+
- `Δ`: Incoming gradient array, backpropagated from downstream layers
316+
- `size`: Lateral size & depth (W,H,D) of the image upsampled in the first place
317+
318+
# Outputs
319+
- `dx`: Downsampled version of `Δ`
320+
"""
321+
function ∇upsample_trilinear::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
322+
w, h, d, c, n = Base.size(Δ)
323+
out_w, out_h, out_d = size
324+
if (w,h,d) == (out_w, out_h, out_d)
325+
return Δ
326+
end
327+
dx = zero(similar(Δ, T, size..., c, n))
328+
return ∇upsample_trilinear_whdcn!(dx, Δ)
329+
end
330+
331+
# placeholder
332+
function ∇upsample_trilinear_whdcn! end
333+
334+
function rrule(::typeof(upsample_trilinear), x; size)
335+
Ω = upsample_trilinear(x; size=size)
336+
function upsample_trilinear_pullback(Δ)
337+
(NO_FIELDS, ∇upsample_trilinear(Δ; size=(Base.size(x,1), Base.size(x,2), Base.size(x,3))))
338+
end
339+
return Ω, upsample_trilinear_pullback
340+
end
341+
342+
260343
"""
261344
pixel_shuffle(x, r::Integer)
262345

0 commit comments

Comments
 (0)