@@ -61,20 +61,38 @@ function ChainRulesCore.rrule(::typeof(upsample_nearest), x::AbstractArray, s::T
61
61
end
62
62
63
63
"""
64
+ upsample_bilinear(x::AbstractArray{<:Number,4}, ks::NTuple{2,Int})
65
+ upsample_bilinear(x::AbstractArray{<:Number,4}, k::Int)
64
66
65
- Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `k`,
66
- using bilinear interpolation.
67
+ Upsamples the first 2 dimensions of the array `x` by the upsample factors stored in `ks`,
68
+ using bilinear interpolation. One integer is equivalent to `ks = (k,k)`.
67
69
68
- The size of the output is equal to
69
- `(k[1]*S1, k[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
70
+ The size of the output is equal to
71
+ `(ks[1]*S1, ks[2]*S2, S3, S4)`, where `S1, S2, S3, S4 = size(x)`.
70
72
71
73
The interpolation grid is identical to the one used by `imresize` from `Images.jl`.
72
74
73
- Currently only 2d upsampling is supported.
75
+ Only two-dimensional upsampling is supported, hence "bi-linear".
76
+ See also [`upsample_nearest`](@ref) which allows any dimensions.
77
+
78
+ # Example
79
+ ```jldoctest
80
+ julia> upsample_bilinear(reshape([1 2 3; 4 5 6], 2,3,1,1), (2,4))
81
+ 4×12×1×1 Array{Float64, 4}:
82
+ [:, :, 1, 1] =
83
+ 1.0 1.0 1.125 1.375 1.625 1.875 2.125 2.375 2.625 2.875 3.0 3.0
84
+ 1.75 1.75 1.875 2.125 2.375 2.625 2.875 3.125 3.375 3.625 3.75 3.75
85
+ 3.25 3.25 3.375 3.625 3.875 4.125 4.375 4.625 4.875 5.125 5.25 5.25
86
+ 4.0 4.0 4.125 4.375 4.625 4.875 5.125 5.375 5.625 5.875 6.0 6.0
87
+ ```
74
88
"""
75
- function upsample_bilinear (x:: AbstractArray{T,4} , k:: NTuple{2,Int} ) where T
89
+ upsample_bilinear (x:: AbstractArray{<:Number,4} , k:: Int ) = upsample_bilinear (x, (k,k))
90
+
91
+ upsample_bilinear (x:: AbstractArray{<:Integer,4} , k:: NTuple{2,Int} ) = upsample_bilinear (float (x), k)
92
+
93
+ function upsample_bilinear (x:: AbstractArray{<:Number,4} , k:: NTuple{2,Int} )
76
94
# This function is gpu friendly
77
-
95
+
78
96
imgsize = size (x)
79
97
newsize = get_newsize (imgsize, k)
80
98
@@ -284,20 +302,67 @@ end
284
302
285
303
"""
286
304
pixel_shuffle(x, r)
287
-
305
+
288
306
Pixel shuffling operation. `r` is the upscale factor for shuffling.
289
307
The operation converts an input of size [W,H,r²C,N] to size [rW,rH,C,N]
290
- Used extensively in super-resolution networks to upsample
308
+
309
+ Used extensively in super-resolution networks to upsample
291
310
towards high resolution features.
311
+ Reference : https://arxiv.org/abs/1609.05158
292
312
293
- Reference : https://arxiv.org/pdf/1609.05158.pdf
313
+ # Example
314
+ ```jldoctest
315
+ julia> x = [10i + j + channel/10 for i in 1:2, j in 1:3, channel in 1:4, batch in 1:1]
316
+ 2×3×4×1 Array{Float64, 4}:
317
+ [:, :, 1, 1] =
318
+ 11.1 12.1 13.1
319
+ 21.1 22.1 23.1
320
+
321
+ [:, :, 2, 1] =
322
+ 11.2 12.2 13.2
323
+ 21.2 22.2 23.2
324
+
325
+ [:, :, 3, 1] =
326
+ 11.3 12.3 13.3
327
+ 21.3 22.3 23.3
328
+
329
+ [:, :, 4, 1] =
330
+ 11.4 12.4 13.4
331
+ 21.4 22.4 23.4
332
+
333
+ julia> pixel_shuffle(x, 2)
334
+ 4×6×1×1 Array{Float64, 4}:
335
+ [:, :, 1, 1] =
336
+ 11.1 11.3 12.1 12.3 13.1 13.3
337
+ 11.2 11.4 12.2 12.4 13.2 13.4
338
+ 21.1 21.3 22.1 22.3 23.1 23.3
339
+ 21.2 21.4 22.2 22.4 23.2 23.4
340
+
341
+ julia> y = [i + channel/10 for i in 1:3, channel in 1:6, batch in 1:1]
342
+ 3×6×1 Array{Float64, 3}:
343
+ [:, :, 1] =
344
+ 1.1 1.2 1.3 1.4 1.5 1.6
345
+ 2.1 2.2 2.3 2.4 2.5 2.6
346
+ 3.1 3.2 3.3 3.4 3.5 3.6
347
+
348
+ julia> pixel_shuffle(y, 2)
349
+ 6×3×1 Array{Float64, 3}:
350
+ [:, :, 1] =
351
+ 1.1 1.3 1.5
352
+ 1.2 1.4 1.6
353
+ 2.1 2.3 2.5
354
+ 2.2 2.4 2.6
355
+ 3.1 3.3 3.5
356
+ 3.2 3.4 3.6
357
+
358
+ ```
294
359
"""
295
360
function pixel_shuffle (x:: AbstractArray , r:: Integer )
296
- @assert ndims (x) > 2
361
+ ndims (x) > 2 || throw ( ArgumentError ( " expected x with at least 3 dimensions " ))
297
362
d = ndims (x) - 2
298
363
sizein = size (x)[1 : d]
299
- cin, n = size (x, d+ 1 ), size (x, d+ 2 )
300
- @assert cin % r^ d == 0
364
+ cin, n = size (x, d+ 1 ), size (x, d+ 2 )
365
+ cin % r^ d == 0 || throw ( ArgumentError ( " expected channel dimension to be divisible by r^d = $(r ^ d) , where d= $d is the number of spatial dimensions. Given r= $r , input size(x) = $( size (x)) " ))
301
366
cout = cin ÷ r^ d
302
367
# x = reshape(x, sizein..., fill(r, d)..., cout, n) # bug https://github.com/FluxML/Zygote.jl/issues/866
303
368
x = reshape (x, sizein... , ntuple (i-> r, d)... , cout, n)
0 commit comments