Skip to content

Commit 953279e

Browse files
Merge pull request #269 from mcabbott/upsample
Add `upsample_nearest`
2 parents a3e2847 + b8f40f8 commit 953279e

File tree

2 files changed

+79
-2
lines changed

2 files changed

+79
-2
lines changed

src/upsample.jl

+65-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,68 @@
1-
export upsample_bilinear, ∇upsample_bilinear, pixel_shuffle
1+
export upsample_nearest, ∇upsample_nearest,
2+
upsample_bilinear, ∇upsample_bilinear,
3+
pixel_shuffle
4+
5+
"""
6+
upsample_nearest(x::AbstractArray, scale::NTuple{S,Int})
7+
8+
Upsamples by integer multiples along the first `S` dimensions.
9+
Subsequent dimensions of `x` are not altered.
10+
11+
See also [`upsample_bilinear`](@ref), for two dimensions of an `N=4` array.
12+
13+
# Example
14+
```jldoctest
15+
julia> upsample_nearest([1 2 3; 4 5 6], (2,3))
16+
4×9 Array{$Int,2}:
17+
1 1 1 2 2 2 3 3 3
18+
1 1 1 2 2 2 3 3 3
19+
4 4 4 5 5 5 6 6 6
20+
4 4 4 5 5 5 6 6 6
21+
22+
julia> upsample_nearest([1 2 3; 4 5 6], (2,))
23+
4×3 Array{$Int,1}:
24+
1 2 3
25+
1 2 3
26+
4 5 6
27+
4 5 6
28+
```
29+
"""
30+
function upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S}
31+
S in 1:N || throw(ArgumentError("can't upsample ndims(x)=$N with scale=$scales"))
32+
outsize = ntuple(d -> d<=S ? scales[d] * size(x,d) : size(x,d), N)
33+
out = similar(x, T, outsize)
34+
writesize = ntuple(N+S) do d
35+
d > 2S && return size(x, d-S)
36+
isodd(d) ? scales[cld(d,2)] : size(x, cld(d,2))
37+
end
38+
readsize = ntuple(N+S) do d
39+
d > 2S && return size(x, d-S)
40+
isodd(d) ? 1 : size(x, cld(d,2))
41+
end
42+
reshape(out, writesize) .= reshape(x, readsize)
43+
out
44+
end
45+
46+
function ∇upsample_nearest(x::AbstractArray{T,N}, scales::NTuple{S, <:Integer}) where {T,N,S}
47+
outsize = ntuple(N) do d
48+
d > S && return size(x,d)
49+
rem(size(x,d), scales[d]) == 0 || throw(ArgumentError("expected input array evenly divisible by scale=$scales, got size(x)=$(size(x))"))
50+
div(size(x,d), scales[d])
51+
end
52+
tempsize = ntuple(N+S) do d
53+
d > 2S && return size(x, d-S)
54+
s = scales[cld(d,2)]
55+
isodd(d) ? s : div(size(x, cld(d,2)),s)
56+
end
57+
mid = sum(reshape(x, tempsize), dims=ntuple(d -> 2d-1, S))
58+
reshape(mid, outsize)
59+
end
60+
61+
function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::Tuple)
62+
Ω = upsample_nearest(x, s)
63+
upsample_nearest_pullback(Δ) = (NO_FIELDS, ∇upsample_nearest(Δ, s), DoesNotExist())
64+
return Ω, upsample_nearest_pullback
65+
end
266

367
"""
468
upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})

test/upsample.jl

+14-1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,17 @@
1+
@testset "upsample_nearest, integer scale via reshape" begin
2+
x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1))
3+
@test upsample_nearest(x, (3,3))[1,:] == [1,1,1, 2,2,2]
4+
5+
y = upsample_nearest(x, (2,3))
6+
@test size(y) == (4,6,1,1)
7+
∇upsample_nearest(y, (2,3)) == [6 12; 18 24]
8+
9+
gradtest(x -> upsample_nearest(x, (2,3)), rand(2,2,1,1), check_rrule=false)
10+
11+
@test_throws ArgumentError ∇upsample_nearest(y, (2,4))
12+
@test_throws ArgumentError upsample_nearest(x, (1,2,3,4,5))
13+
end
14+
115
@testset "upsample_bilinear 2d" begin
216
x = reshape(Float32[1. 2.; 3. 4.], (2,2,1,1))
317
y_true = [1//1 5//4 7//4 2//1;
@@ -90,4 +104,3 @@ end
90104
gradtest(x -> pixel_shuffle(x, r), x)
91105
end
92106
end
93-

0 commit comments

Comments
 (0)