1
1
export upsample_nearest, ∇upsample_nearest,
2
2
upsample_bilinear, ∇upsample_bilinear,
3
+ upsample_trilinear, ∇upsample_trilinear,
3
4
pixel_shuffle
4
5
5
6
"""
@@ -9,7 +10,7 @@ export upsample_nearest, ∇upsample_nearest,
9
10
Upsamples the array `x` by integer multiples along the first `S` dimensions.
10
11
Subsequent dimensions of `x` are not altered.
11
12
12
- Either the `scale` factors or the final output `size` can be specified.
13
+ Either the `scale` factors or the final output `size` can be specified.
13
14
14
15
See also [`upsample_bilinear`](@ref), for two dimensions of an `N=4` array.
15
16
@@ -257,6 +258,157 @@ function rrule(::typeof(upsample_bilinear), x; size)
257
258
return Ω, upsample_bilinear_pullback
258
259
end
259
260
261
+ # ##########
262
+ # trilinear
263
+ # ##########
264
+ """
265
+ upsample_trilinear(x::AbstractArray{T,5}, scale::NTuple{3,Real})
266
+ upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer})
267
+
268
+ Upsamples the first 3 dimensions of the array `x` by the upsample factors stored in `scale`,
269
+ using trilinear interpolation. As an alternative to using `scale`, the resulting image `size`
270
+ can be directly specified with a keyword argument.
271
+
272
+ The size of the output is equal to
273
+ `(scale[1]*S1, scale[2]*S2, scale[3]*S3, S4, S5)`, with each product rounded down to an integer, where `S1, S2, S3, S4, S5 = size(x)`.
274
+
275
+ # Examples
276
+
277
+ ```julia
278
+ upsample_trilinear(x, (2, 3, 4))
279
+ upsample_trilinear(x; size=(4, 9, 11)) # specify output size instead
280
+ upsample_trilinear(x, (2.5, 3.5, pi)) # non-integer scaling factors are allowed
281
+ ```
282
+ """
283
function upsample_trilinear(x::AbstractArray{<:Any,5}, scale::NTuple{3,Real})
    # Only the three leading (spatial) extents are scaled; each product is
    # rounded down so that non-integer factors still give an integer size.
    w, h, d = Base.size(x)
    target = (
        floor(Int, scale[1] * w),
        floor(Int, scale[2] * h),
        floor(Int, scale[3] * d),
    )
    # Delegate to the keyword method that takes the final spatial size.
    return upsample_trilinear(x; size=target)
end
287
+
288
# A single scalar factor is applied uniformly to all three spatial dimensions.
function upsample_trilinear(x, scale::Real)
    return upsample_trilinear(x, ntuple(_ -> scale, Val(3)))
end
289
+
290
# Keyword method: upsample the three leading dimensions of `x` to exactly `size`,
# keeping dimensions 4 and 5 unchanged.
function upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
    in_dims = Base.size(x)
    # Nothing to interpolate when the spatial extent already matches; the input
    # array itself (not a copy) is handed back in that case.
    in_dims[1:3] == size && return x
    # Allocate the result with the same array type and element type as `x`.
    out = similar(x, T, size..., in_dims[4], in_dims[5])
    return upsample_trilinear_whdcn!(out, x)
end
298
+
299
# Integer-element method: interpolate a floating-point copy of `x`, then round
# every value back to the original element type `T`.
function upsample_trilinear(x::AbstractArray{T,5}; size::NTuple{3,Integer}) where T<:Integer
    interpolated = upsample_trilinear(float.(x); size=size)
    return round.(T, interpolated)
end
304
+
305
"""
    upsample_trilinear_whdcn!(output, input)

Fill `output` with the trilinear interpolation of `input` and return `output`.

Both arguments are 5-dimensional arrays with the same element type `T`. Their 4th and
5th dimensions must agree, otherwise an error is raised. The first three dimensions of
`output` define the target size. Elements are addressed by linear index, with
dimensions 4 and 5 merged into a single channel index.

Along each of the three leading dimensions the scale is `(in - 1) / (out - 1)`. When the
target length of a dimension is 1 that ratio has a zero denominator, so a scale of
`zero(T)` is used instead.
"""
function upsample_trilinear_whdcn!(output::AbstractArray{T,5}, input::AbstractArray{T,5}) where T
    size(input)[4:5] == size(output)[4:5] || error(" Number of input and output channels and batches must match. Got input $(size(input)) and output $(size(output)) ")
    in_w, in_h, in_d, channels, batches = size(input)
    # treat batch and channel dimension as one for better parallelization granularity
    channels *= batches
    out_w, out_h, out_d, _, _ = size(output)
    output_slice_size = out_h * out_w * out_d

    # T() and // so that we can handle rationals (super slow)
    # An output length of 1 would give `k // 0` (or the invalid `0 // 0`), so fall back
    # to a zero scale for such dimensions.
    width_scale = out_w > 1 ? T((in_w - 1) // (out_w - 1)) : zero(T)
    height_scale = out_h > 1 ? T((in_h - 1) // (out_h - 1)) : zero(T)
    depth_scale = out_d > 1 ? T((in_d - 1) // (out_d - 1)) : zero(T)

    # Zero-based (channel, depth, height, width) -> one-based linear index into `input`.
    @inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1

    @inbounds Threads.@threads for c in 0:channels-1
        for od in 0:out_d-1
            id0, id1, d0lambda, d1lambda = compute_source_index_and_lambda(depth_scale, od, in_d, out_d)
            for oh in 0:out_h-1
                ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
                for ow in 0:out_w-1
                    iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
                    output_offset = c * output_slice_size + od * out_w * out_h + oh * out_w + ow + 1
                    # Weighted sum over the 8 neighbouring input samples.
                    output[output_offset] =
                        d0lambda * h0lambda * w0lambda * input[idx(c, id0, ih0, iw0)] + # d0 * h0 * w0 * i000
                        d0lambda * h0lambda * w1lambda * input[idx(c, id0, ih0, iw1)] + # d0 * h0 * w1 * i001
                        d0lambda * h1lambda * w0lambda * input[idx(c, id0, ih1, iw0)] + # d0 * h1 * w0 * i010
                        d0lambda * h1lambda * w1lambda * input[idx(c, id0, ih1, iw1)] + # d0 * h1 * w1 * i011
                        d1lambda * h0lambda * w0lambda * input[idx(c, id1, ih0, iw0)] + # d1 * h0 * w0 * i100
                        d1lambda * h0lambda * w1lambda * input[idx(c, id1, ih0, iw1)] + # d1 * h0 * w1 * i101
                        d1lambda * h1lambda * w0lambda * input[idx(c, id1, ih1, iw0)] + # d1 * h1 * w0 * i110
                        d1lambda * h1lambda * w1lambda * input[idx(c, id1, ih1, iw1)]   # d1 * h1 * w1 * i111
                end
            end
        end
    end
    return output
end
343
+
344
+ """
345
+ ∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
346
+
347
+ # Arguments
348
+ - `Δ`: Incoming gradient array, backpropagated from downstream layers
349
+ - `size`: Lateral size & depth (W,H,D) of the image upsampled in the first place
350
+
351
+ # Outputs
352
+ - `dx`: Downsampled version of `Δ`
353
+ """
354
function ∇upsample_trilinear(Δ::AbstractArray{T,5}; size::NTuple{3,Integer}) where T
    Δ_dims = Base.size(Δ)
    # When the requested spatial size equals that of `Δ`, `Δ` itself (not a copy)
    # is returned.
    Δ_dims[1:3] == size && return Δ
    # Zero-filled buffer of the requested spatial size, passed on to the
    # in-place kernel together with `Δ`.
    grad = zero(similar(Δ, T, size..., Δ_dims[4], Δ_dims[5]))
    return ∇upsample_trilinear_whdcn!(grad, Δ)
end
363
+
364
"""
    ∇upsample_trilinear_whdcn!(dx, Δ)

Accumulate the gradient `Δ` of a trilinear upsampling into `dx` and return `dx`.

Both arguments are 5-dimensional arrays with the same element type `T`. Their 4th and
5th dimensions must agree, otherwise an error is raised. `dx` has the spatial size of
the original (smaller) input and `Δ` that of the upsampled output. Contributions are
added with `+=`, so `dx` is not cleared here and should be zero-initialised by the caller.

Along each of the three leading dimensions the scale is `(in - 1) / (out - 1)`. When `Δ`
has length 1 in a dimension that ratio has a zero denominator, so a scale of `zero(T)`
is used instead.
"""
function ∇upsample_trilinear_whdcn!(dx::AbstractArray{T,5}, Δ::AbstractArray{T,5}) where T
    size(dx)[4:5] == size(Δ)[4:5] || error(" Number of input and output channels and batches must match. Got dx $(size(dx)) and Δ $(size(Δ)) ")
    in_w, in_h, in_d, channels, batches = size(dx)
    # treat batch and channel dimension as one for better parallelization granularity
    channels *= batches
    out_w, out_h, out_d, _, _ = size(Δ)
    output_slice_size = out_h * out_w * out_d

    # T() and // so that we can handle rationals (super slow)
    # A length of 1 in `Δ` would give `k // 0` (or the invalid `0 // 0`), so fall back
    # to a zero scale for such dimensions.
    width_scale = out_w > 1 ? T((in_w - 1) // (out_w - 1)) : zero(T)
    height_scale = out_h > 1 ? T((in_h - 1) // (out_h - 1)) : zero(T)
    depth_scale = out_d > 1 ? T((in_d - 1) // (out_d - 1)) : zero(T)

    # Zero-based (channel, depth, height, width) -> one-based linear index into `dx`.
    @inline idx(c, d, h, w) = c * in_d * in_h * in_w + d * in_h * in_w + h * in_w + w + 1

    @inbounds Threads.@threads for c in 0:channels-1
        for od in 0:out_d-1
            id0, id1, d0lambda, d1lambda = compute_source_index_and_lambda(depth_scale, od, in_d, out_d)
            for oh in 0:out_h-1
                ih0, ih1, h0lambda, h1lambda = compute_source_index_and_lambda(height_scale, oh, in_h, out_h)
                for ow in 0:out_w-1
                    iw0, iw1, w0lambda, w1lambda = compute_source_index_and_lambda(width_scale, ow, in_w, out_w)
                    output_offset = c * output_slice_size + od * out_w * out_h + oh * out_w + ow + 1
                    Δ_value = Δ[output_offset]
                    # Scatter the incoming gradient onto the 8 source samples, using the
                    # same weights as the forward interpolation.
                    dx[idx(c, id0, ih0, iw0)] += d0lambda * h0lambda * w0lambda * Δ_value # i000
                    dx[idx(c, id0, ih0, iw1)] += d0lambda * h0lambda * w1lambda * Δ_value # i001
                    dx[idx(c, id0, ih1, iw0)] += d0lambda * h1lambda * w0lambda * Δ_value # i010
                    dx[idx(c, id0, ih1, iw1)] += d0lambda * h1lambda * w1lambda * Δ_value # i011
                    dx[idx(c, id1, ih0, iw0)] += d1lambda * h0lambda * w0lambda * Δ_value # i100
                    dx[idx(c, id1, ih0, iw1)] += d1lambda * h0lambda * w1lambda * Δ_value # i101
                    dx[idx(c, id1, ih1, iw0)] += d1lambda * h1lambda * w0lambda * Δ_value # i110
                    dx[idx(c, id1, ih1, iw1)] += d1lambda * h1lambda * w1lambda * Δ_value # i111
                end
            end
        end
    end
    return dx
end
402
+
403
# Reverse-mode rule for `upsample_trilinear`: returns the primal result together with
# a pullback that maps an output gradient back to the spatial size of `x`.
function rrule(::typeof(upsample_trilinear), x; size)
    Ω = upsample_trilinear(x; size=size)
    function upsample_trilinear_pullback(Δ)
        # The gradient is downsampled to the first three dimensions of the primal input.
        input_whd = ntuple(i -> Base.size(x, i), 3)
        return (NO_FIELDS, ∇upsample_trilinear(Δ; size=input_whd))
    end
    return Ω, upsample_trilinear_pullback
end
410
+
411
+
260
412
"""
261
413
pixel_shuffle(x, r::Integer)
262
414
0 commit comments