Skip to content

Commit 96ec8f9

Browse files
author
Michael Abbott
committedJan 17, 2021
extend all the docstrings in this file, and fix some types & errors
unrelated to the PR, but while I'm here...
1 parent bfa9dc1 commit 96ec8f9

File tree

1 file changed

+78
-13
lines changed

1 file changed

+78
-13
lines changed
 

‎src/upsample.jl

+78-13
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,38 @@ function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::T
6161
end
6262

6363
"""
64+
upsample_bilinear(x::AbstractArray{<:Number,4}, ks::NTuple{2,Int})
65+
upsample_bilinear(x::AbstractArray{<:Number,4}, k::Int)
6466
65-
Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
66-
using bilinear interpolation.
67+
Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `ks`,
68+
using bilinear interpolation. One integer is equivalent to `ks = (k,k)`.
6769
68-
The size of the output is equal to
69-
`(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
70+
The size of the output is equal to
71+
`(ks[1]*S1, ks[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
7072
7173
The interpolation grid is identical to the one used by `imresize` from `Images.jl`.
7274
73-
Currently only 2d upsampling is supported.
75+
Only two-dimensional upsampling is supported, hence "bi-linear".
76+
See also [`upsample_nearest`](@ref) which allows any dimensions.
77+
78+
# Example
79+
```jldoctest
80+
julia> upsample_bilinear(reshape([1 2 3; 4 5 6], 2,3,1,1), (2,4))
81+
4×12×1×1 Array{Float64, 4}:
82+
[:, :, 1, 1] =
83+
1.0 1.0 1.125 1.375 1.625 1.875 2.125 2.375 2.625 2.875 3.0 3.0
84+
1.75 1.75 1.875 2.125 2.375 2.625 2.875 3.125 3.375 3.625 3.75 3.75
85+
3.25 3.25 3.375 3.625 3.875 4.125 4.375 4.625 4.875 5.125 5.25 5.25
86+
4.0 4.0 4.125 4.375 4.625 4.875 5.125 5.375 5.625 5.875 6.0 6.0
87+
```
7488
"""
75-
function upsample_bilinear(x::AbstractArray{T,4}, k::NTuple{2,Int}) where T
89+
upsample_bilinear(x::AbstractArray{<:Number,4}, k::Int) = upsample_bilinear(x, (k,k))
90+
91+
upsample_bilinear(x::AbstractArray{<:Integer,4}, k::NTuple{2,Int}) = upsample_bilinear(float(x), k)
92+
93+
function upsample_bilinear(x::AbstractArray{<:Number,4}, k::NTuple{2,Int})
7694
# This function is gpu friendly
77-
95+
7896
imgsize = size(x)
7997
newsize = get_newsize(imgsize, k)
8098

@@ -284,20 +302,67 @@ end
284302

285303
"""
286304
pixel_shuffle(x, r)
287-
305+
288306
Pixel shuffling operation. `r` is the upscale factor for shuffling.
289307
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
290-
Used extensively in super-resolution networks to upsample
308+
309+
Used extensively in super-resolution networks to upsample
291310
towards high resolution features.
311+
Reference : https://arxiv.org/abs/1609.05158
292312
293-
Reference : https://arxiv.org/pdf/1609.05158.pdf
313+
# Example
314+
```jldoctest
315+
julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
316+
2×3×4×1 Array{Float64, 4}:
317+
[:, :, 1, 1] =
318+
11.1 12.1 13.1
319+
21.1 22.1 23.1
320+
321+
[:, :, 2, 1] =
322+
11.2 12.2 13.2
323+
21.2 22.2 23.2
324+
325+
[:, :, 3, 1] =
326+
11.3 12.3 13.3
327+
21.3 22.3 23.3
328+
329+
[:, :, 4, 1] =
330+
11.4 12.4 13.4
331+
21.4 22.4 23.4
332+
333+
julia> pixel_shuffle(x, 2)
334+
4×6×1×1 Array{Float64, 4}:
335+
[:, :, 1, 1] =
336+
11.1 11.3 12.1 12.3 13.1 13.3
337+
11.2 11.4 12.2 12.4 13.2 13.4
338+
21.1 21.3 22.1 22.3 23.1 23.3
339+
21.2 21.4 22.2 22.4 23.2 23.4
340+
341+
julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1]
342+
3×6×1 Array{Float64, 3}:
343+
[:, :, 1] =
344+
1.1 1.2 1.3 1.4 1.5 1.6
345+
2.1 2.2 2.3 2.4 2.5 2.6
346+
3.1 3.2 3.3 3.4 3.5 3.6
347+
348+
julia> pixel_shuffle(y, 2)
349+
6×3×1 Array{Float64, 3}:
350+
[:, :, 1] =
351+
1.1 1.3 1.5
352+
1.2 1.4 1.6
353+
2.1 2.3 2.5
354+
2.2 2.4 2.6
355+
3.1 3.3 3.5
356+
3.2 3.4 3.6
357+
358+
```
294359
"""
295360
function pixel_shuffle(x::AbstractArray, r::Integer)
296-
@assert ndims(x) > 2
361+
ndims(x) > 2 || throw(ArgumentError("expected x with at least 3 dimensions"))
297362
d = ndims(x) - 2
298363
sizein = size(x)[1:d]
299-
cin, n = size(x, d+1), size(x, d+2)
300-
@assert cin % r^d == 0
364+
cin, n = size(x, d+1), size(x, d+2)
365+
cin % r^d == 0 || throw(ArgumentError("expected channel dimension to be divisible by r^d = $(r^d), where d=$d is the number of spatial dimensions. Given r=$r, input size(x) = $(size(x))"))
301366
cout = cin ÷ r^d
302367
# x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866
303368
x = reshape(x, sizein..., ntuple(i->r, d)..., cout, n)

0 commit comments

Comments
 (0)