Commit ddb5e9c

bors[bot], CarloLucibello, and DhairyaLGandhi authored
Merge #1468
1468: add Upsample and PixelShuffle layers r=DhairyaLGandhi a=CarloLucibello

### PR Checklist

- [x] Tests are added
- [x] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@dhairyagandhi96` (for API changes).

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
Co-authored-by: Carlo Lucibello <carlo.lucibello@unibocconi.it>
Co-authored-by: Dhairya Gandhi <dhairya@juliacomputing.com>
2 parents cdb445c + 9acaae9 commit ddb5e9c

File tree

12 files changed: +184 −12 lines

Manifest.toml

+4-4
@@ -93,9 +93,9 @@ uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
 version = "0.3.4+0"
 
 [[DataAPI]]
-git-tree-sha1 = "6d64b28d291cb94a0d84e6e41081fb081e7f717f"
+git-tree-sha1 = "8ab70b4de35bb3b8cc19654f6b893cf5164f8ee8"
 uuid = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
-version = "1.5.0"
+version = "1.5.1"
 
 [[DataStructures]]
 deps = ["Compat", "InteractiveUtils", "OrderedCollections"]
@@ -236,9 +236,9 @@ uuid = "a63ad114-7e13-5084-954f-fe012c677804"
 
 [[NNlib]]
 deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Pkg", "Requires", "Statistics"]
-git-tree-sha1 = "573cc0d31f9697b9d2b060130a7a3c05a4f36b78"
+git-tree-sha1 = "df42d0816edfc24f5b82a728f46381613c4dff79"
 uuid = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-version = "0.7.12"
+version = "0.7.14"
 
 [[NaNMath]]
 git-tree-sha1 = "bfe47e760d60b82b66b61d2d44128b62e3a369fb"

NEWS.md

+1
@@ -13,6 +13,7 @@
 * Moved GPU CI to use buildkite instead of GitLab
 * New [`Parallel` layer](https://github.com/FluxML/Flux.jl/pull/1462) adds inception module-like building blocks.
 * Feature additions and bug fixes for BatchNorm, LayerNorm, InstanceNorm, and GroupNorm [normalization layers](https://github.com/FluxML/Flux.jl/pull/1397)
+* Added [Upsample and PixelShuffle layers](https://github.com/FluxML/Flux.jl/pull/1468)
 
 ## v0.11.2

Project.toml

+1-1
@@ -34,7 +34,7 @@ Colors = "0.12"
 Functors = "0.1, 0.2"
 Juno = "0.8"
 MacroTools = "0.5"
-NNlib = "0.7.10"
+NNlib = "0.7.14"
 Reexport = "0.2, 1.0"
 StatsBase = "0.33"
 ZipFile = "0.9"

docs/src/models/layers.md

+7
@@ -29,6 +29,13 @@ Flux.convfilter
 Flux.depthwiseconvfilter
 ```
 
+## Upsampling Layers
+
+```@docs
+Upsample
+PixelShuffle
+```
+
 ## Recurrent Layers
 
 Much like the core layers above, but can be used to process sequence data (as well as other kinds of structured data).

docs/src/models/nnlib.md

+8
@@ -51,6 +51,14 @@ NNlib.conv
 NNlib.depthwiseconv
 ```
 
+## Upsampling
+
+```@docs
+NNlib.upsample_nearest
+NNlib.upsample_bilinear
+NNlib.pixel_shuffle
+```
+
 ## Batched Operations
 
 ```@docs
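These NNlib functions are what the new `Upsample` and `PixelShuffle` layers dispatch to, and they can also be called directly. A minimal shape-level sketch, assuming NNlib 0.7.14 as pinned in Project.toml above (the shapes match the tests added in this PR):

```julia
using NNlib

x = rand(Float32, 3, 4, 2, 1)  # (width, height, channels, batch)

# Nearest-neighbour: each spatial dimension is repeated by its factor.
size(NNlib.upsample_nearest(x, (2, 3)))          # (6, 12, 2, 1)

# Bilinear: the keyword form targets an output size directly.
size(NNlib.upsample_bilinear(x; size = (4, 6)))  # (4, 6, 2, 1)

# pixel_shuffle trades channels for resolution:
# (w, h, c*r^2, n) -> (w*r, h*r, c, n)
y = rand(Float32, 3, 4, 18, 1)
size(NNlib.pixel_shuffle(y, 3))                  # (9, 12, 2, 1)
```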

src/Flux.jl

+2
@@ -16,6 +16,7 @@ export Chain, Dense, Maxout, SkipConnection, Parallel, flatten,
        SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
        AdaptiveMaxPool, AdaptiveMeanPool, GlobalMaxPool, GlobalMeanPool, MaxPool, MeanPool,
        Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
+       Upsample, PixelShuffle,
        params, fmap, cpu, gpu, f32, f64,
        testmode!, trainmode!
 
@@ -42,6 +43,7 @@ include("layers/basic.jl")
 include("layers/conv.jl")
 include("layers/recurrent.jl")
 include("layers/normalise.jl")
+include("layers/upsample.jl")
 
 include("outputsize.jl")

src/layers/upsample.jl

+79
@@ -0,0 +1,79 @@
+"""
+    Upsample(mode = :nearest; [scale, size])
+    Upsample(scale, mode = :nearest)
+
+An upsampling layer. One of two keywords must be given:
+
+If `scale` is a number, this applies to all but the last two dimensions (channel and batch) of the input.
+It may also be a tuple, to control dimensions individually. Alternatively, keyword
+`size` accepts a tuple, to directly specify the leading dimensions of the output.
+
+Currently supported upsampling `mode`s and the corresponding NNlib methods are:
+  - `:nearest` -> [`NNlib.upsample_nearest`](@ref)
+  - `:bilinear` -> [`NNlib.upsample_bilinear`](@ref)
+
+# Examples
+
+```julia-repl
+julia> m = Upsample(scale = (2, 3))
+Upsample(:nearest, scale = (2, 3))
+
+julia> m(ones(2, 2, 1, 1)) |> size
+(4, 6, 1, 1)
+
+julia> m = Upsample(:bilinear, size = (4, 5))
+Upsample(:bilinear, size = (4, 5))
+
+julia> m(ones(2, 2, 1, 1)) |> size
+(4, 5, 1, 1)
+```
+"""
+struct Upsample{mode, S, T}
+  scale::S
+  size::T
+end
+
+function Upsample(mode::Symbol = :nearest; scale = nothing, size = nothing)
+  mode in [:nearest, :bilinear] ||
+    throw(ArgumentError("mode=:$mode is not supported."))
+  if !(isnothing(scale) ⊻ isnothing(size))
+    throw(ArgumentError("Either scale or size should be specified (but not both)."))
+  end
+  return Upsample{mode,typeof(scale),typeof(size)}(scale, size)
+end
+
+Upsample(scale, mode::Symbol = :nearest) = Upsample(mode; scale)
+
+(m::Upsample{:nearest})(x::AbstractArray) =
+  NNlib.upsample_nearest(x, m.scale)
+function (m::Upsample{:nearest, Int})(x::AbstractArray{T, N}) where {T, N}
+  NNlib.upsample_nearest(x, ntuple(i -> m.scale, N-2))
+end
+(m::Upsample{:nearest, Nothing})(x::AbstractArray) =
+  NNlib.upsample_nearest(x; size=m.size)
+
+(m::Upsample{:bilinear})(x::AbstractArray) =
+  NNlib.upsample_bilinear(x, m.scale)
+(m::Upsample{:bilinear, Nothing})(x::AbstractArray) =
+  NNlib.upsample_bilinear(x; size=m.size)
+
+function Base.show(io::IO, u::Upsample{mode}) where {mode}
+  print(io, "Upsample(")
+  print(io, ":", mode)
+  u.scale !== nothing && print(io, ", scale = $(u.scale)")
+  u.size !== nothing && print(io, ", size = $(u.size)")
+  print(io, ")")
+end
+
+"""
+    PixelShuffle(r::Int)
+
+Pixel shuffling layer with upscale factor `r`.
+
+See [`NNlib.pixel_shuffle`](@ref).
+"""
+struct PixelShuffle
+  r::Int
+end
+
+(m::PixelShuffle)(x) = NNlib.pixel_shuffle(x, m.r)
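Taken together, a typical use of these layers is a super-resolution-style decoder head, where `PixelShuffle` converts channels into resolution and `Upsample` rescales spatially. A minimal sketch under this commit's API; the `Chain`/`Conv` wiring and all sizes are illustrative, not part of the PR:

```julia
using Flux

# PixelShuffle consumes r^2 input channels per output channel, so the
# preceding Conv must emit out_channels * r^2 channels.
r = 2
decoder = Chain(
  Conv((3, 3), 8 => 3 * r^2, pad = 1),  # (w, h, 12, n)
  PixelShuffle(r),                      # (2w, 2h, 3, n)
  Upsample(:bilinear, scale = (2, 2)),  # (4w, 4h, 3, n)
)

x = rand(Float32, 16, 16, 8, 1)
size(decoder(x))  # expected (64, 64, 3, 1)
```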

src/outputsize.jl

+1-1
@@ -35,7 +35,7 @@ Base.isless(::Nil, ::Number) = true
 Base.isless(::Number, ::Nil) = true
 
 Base.isnan(::Nil) = false
-
+Base.isfinite(::Nil) = true
 Base.typemin(::Type{Nil}) = nil
 Base.typemax(::Type{Nil}) = nil
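For context, `Flux.outputsize` evaluates a model on arrays of the internal `Nil` element, so any scalar predicate a layer might call (here `isfinite`) needs a `Nil` method. A sketch of the API this supports; the expected shape is an illustrative inference, not a test from this PR:

```julia
using Flux

m = Chain(Conv((3, 3), 3 => 8, pad = 1), Upsample(scale = 2))

# outputsize pushes Nil values through the model, so no real convolution or
# interpolation arithmetic runs; only output shapes are computed.
Flux.outputsize(m, (32, 32, 3, 1))  # expected (64, 64, 8, 1)
```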

test/cuda/layers.jl

+10-1
@@ -79,6 +79,15 @@ gpu_gradtest("GroupNorm 3d", groupnorm, rand(Float32, 8, 8, 8, 12, 4), 12, 3, setmode=true)
 gpu_gradtest("GroupNorm 2d", groupnorm, rand(Float32, 8, 8, 12, 4), 12, 3, setmode=true)
 gpu_gradtest("GroupNorm 1d", groupnorm, rand(Float32, 8, 3, 12, 4), 12, 3, setmode=true)
 
+upsample = [x -> Upsample(scale=x)]
+gpu_gradtest("Upsample 2d", upsample, rand(Float32, 3, 4, 2, 3), (2,2))
+gpu_gradtest("Upsample 1d", upsample, rand(Float32, 3, 4, 2, 3), (2,))
+
+pixelshuffle = [PixelShuffle]
+gpu_gradtest("PixelShuffle 2d", pixelshuffle, rand(Float32, 3, 4, 18, 3), 3)
+gpu_gradtest("PixelShuffle 1d", pixelshuffle, rand(Float32, 3, 18, 3), 3)
+
+
 @testset "function layers" begin
   x = rand(Float32, 3,3)
   gpu_autodiff_test(x -> sum(Flux.normalise(x; dims=1)), x)
@@ -168,4 +177,4 @@ end
   @test sum(l(ip)) ≈ 0.f0
   gs = gradient(() -> sum(l(ip)), Flux.params(l))
   @test l.b ∈ gs.params
-end
+end

test/layers/upsample.jl

+67
@@ -0,0 +1,67 @@
+@testset "upsample bilinear" begin
+  m = Upsample(:bilinear, scale=(2, 3))
+  x = rand(Float32, 3, 4, 2, 3)
+  y = m(x)
+  @test y isa Array{Float32, 4}
+  @test size(y) == (6, 12, 2, 3)
+
+  m = Upsample(:bilinear, scale=3)
+  x = rand(Float32, 3, 4, 2, 3)
+  y = m(x)
+  @test y isa Array{Float32, 4}
+  @test size(y) == (9, 12, 2, 3)
+
+  m = Upsample(:bilinear, size=(4, 6))
+  x = rand(Float32, 3, 4, 2, 3)
+  y = m(x)
+  @test y isa Array{Float32, 4}
+  @test size(y) == (4, 6, 2, 3)
+end
+
+@testset "upsample nearest" begin
+  x = rand(Float32, 3, 2, 3)
+  m = Upsample(:nearest, scale=(2,))
+  y = m(x)
+  @test y isa Array{Float32, 3}
+  @test size(y) == (6, 2, 3)
+
+  x = rand(Float32, 3, 4, 2, 3)
+
+  m = Upsample(:nearest, scale=(2, 3))
+  y = m(x)
+  @test y isa Array{Float32, 4}
+  @test size(y) == (6, 12, 2, 3)
+
+  m = Upsample(:nearest, scale=(2,))
+  y = m(x)
+  @test y isa Array{Float32, 4}
+  @test size(y) == (6, 4, 2, 3)
+
+  m = Upsample(:nearest, scale=2)
+  y = m(x)
+  @test y isa Array{Float32, 4}
+  @test size(y) == (6, 8, 2, 3)
+
+  m = Upsample(2)
+  y2 = m(x)
+  @test y2 ≈ y
+
+  m = Upsample(:nearest, size=(6,8))
+  y = m(x)
+  @test y isa Array{Float32, 4}
+  @test size(y) == (6, 8, 2, 3)
+end
+
+@testset "PixelShuffle" begin
+  m = PixelShuffle(2)
+  x = rand(Float32, 3, 18, 3)
+  y = m(x)
+  @test y isa Array{Float32, 3}
+  @test size(y) == (6, 9, 3)
+
+  m = PixelShuffle(3)
+  x = rand(Float32, 3, 4, 18, 3)
+  y = m(x)
+  @test y isa Array{Float32, 4}
+  @test size(y) == (9, 12, 2, 3)
+end

test/outputsize.jl

+3-5
@@ -146,9 +146,7 @@
   @test outputsize(m, (32, 32, 3, 16)) == (32, 32, 3, 16)
   @test outputsize(m, (32, 32, 3); padbatch=true) == (32, 32, 3, 1)
 
-  if VERSION >= v"1.1"
-    m = GroupNorm(16, 4)
-    @test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
-    @test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
-  end
+  m = GroupNorm(16, 4)
+  @test outputsize(m, (32, 32, 16, 16)) == (32, 32, 16, 16)
+  @test outputsize(m, (32, 32, 16); padbatch=true) == (32, 32, 16, 1)
 end

test/runtests.jl

+1
@@ -34,6 +34,7 @@ end
 include("layers/stateless.jl")
 include("layers/recurrent.jl")
 include("layers/conv.jl")
+include("layers/upsample.jl")
 end
 
 @testset "outputsize" begin
