From adba4ded58beee4c96ac34e5f92c1e9fd8436f7f Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sun, 25 Apr 2021 11:19:17 +0800 Subject: [PATCH 1/7] add gradient for gather --- src/gather.jl | 11 +++++++++++ test/gather.jl | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/src/gather.jl b/src/gather.jl index 47793e0c0..d3ead1a71 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -71,3 +71,14 @@ function gather(src::AbstractArray{Tsrc, Nsrc}, colons = ntuple(i -> Colon(), Nsrc-1) return src[colons..., idx] end + +# Gradient + +∇gather_dst!(Δ, dst, y) = (dst .!= y) .* Δ +∇gather_src!(Δ, src, idx) = scatter!(+, similar(src), Δ, idx) + +function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) + y = gather!(copy(dst), src, idx) + gather!_pullback(Δ) = (NO_FIELDS, ∇gather_dst!(Δ, dst, y), ∇gather_src!(Δ, src, idx), DoesNotExist()) + y, gather!_pullback +end diff --git a/test/gather.jl b/test/gather.jl index a9ce4a246..4ce796bc1 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -124,3 +124,23 @@ end @test y isa Array{T,3} @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) end + +@testset "gather gradient" begin + T = Float64 + src = T[3, 4, 5, 6, 7] + index = [1 2 3 4; + 4 2 1 3; + 3 5 5 3] + dst = T[3 4 5 6; + 6 4 3 5; + 5 7 7 5] + + @testset "∂dst" begin + gradtest(xs -> gather!(copy(xs), src, index), dst) + end + + @testset "∂src" begin + gradtest(xs -> gather!(dst, xs, index), src) + gradtest(xs -> gather(xs, index), src) + end +end From 941e5235ba07b0f8cccaebcb9f5e2256c515c662 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sun, 16 May 2021 13:16:40 +0800 Subject: [PATCH 2/7] drop gradient of dst --- src/gather.jl | 3 +-- test/gather.jl | 10 ++-------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/gather.jl b/src/gather.jl index d3ead1a71..b1fabe7a8 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -74,11 +74,10 @@ end # Gradient -∇gather_dst!(Δ, dst, y) = (dst .!= y) .* Δ ∇gather_src!(Δ, src, idx) = scatter!(+, similar(src), Δ, idx) function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) y = gather!(copy(dst), src, idx) - gather!_pullback(Δ) = (NO_FIELDS, ∇gather_dst!(Δ, dst, y), ∇gather_src!(Δ, src, idx), DoesNotExist()) + gather!_pullback(Δ) = (NO_FIELDS, DoesNotExist(), ∇gather_src!(Δ, src, idx), DoesNotExist()) y, gather!_pullback end diff --git a/test/gather.jl b/test/gather.jl index 4ce796bc1..6d11be8fc 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -135,12 +135,6 @@ end 6 4 3 5; 5 7 7 5] - @testset "∂dst" begin - gradtest(xs -> gather!(copy(xs), src, index), dst) - end - - @testset "∂src" begin - gradtest(xs -> gather!(dst, xs, index), src) - gradtest(xs -> gather(xs, index), src) - end + gradtest(xs -> gather!(dst, xs, index), src) + gradtest(xs -> gather(xs, index), src) end From 79bf21620d76d1d8bac8df3c5f7db274108c3f8a Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Sun, 16 May 2021 13:19:28 +0800 Subject: [PATCH 3/7] fill zeros --- src/gather.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gather.jl b/src/gather.jl index b1fabe7a8..c06158b04 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -74,7 +74,7 @@ end # Gradient -∇gather_src!(Δ, src, idx) = scatter!(+, similar(src), Δ, idx) +∇gather_src!(Δ, src, idx) = scatter!(+, fill!(similar(src), 0), Δ, idx) function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) y = gather!(copy(dst), src, idx) From b19177bec9d57d43a0d587ab2164f1d0c927c3f8 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 18 May 2021 16:01:25 +0800 Subject: [PATCH 4/7] =?UTF-8?q?fix=20=E2=88=87gather=5Fsrc!?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/gather.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gather.jl b/src/gather.jl index c06158b04..762f048ef 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -74,10 +74,10 @@ end # Gradient -∇gather_src!(Δ, src, idx) = scatter!(+, fill!(similar(src), 0), Δ, idx) +∇gather_src(Δ, src, idx) = scatter!(+, fill!(similar(src), 0), Δ, idx) function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) y = gather!(copy(dst), src, idx) - gather!_pullback(Δ) = (NO_FIELDS, DoesNotExist(), ∇gather_src!(Δ, src, idx), DoesNotExist()) + gather!_pullback(Δ) = (NO_FIELDS, DoesNotExist(), ∇gather_src(Δ, src, idx), DoesNotExist()) y, gather!_pullback end From 2a04cabf1ee09dd1925f421a585261e4b6bc9210 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 18 May 2021 16:03:33 +0800 Subject: [PATCH 5/7] refactor --- src/gather.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/gather.jl b/src/gather.jl index 762f048ef..deddca666 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -74,10 +74,10 @@ end # Gradient -∇gather_src(Δ, src, idx) = scatter!(+, fill!(similar(src), 0), Δ, idx) +∇gather_src(Δ, idx) = scatter(+, Δ, idx) function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) y = gather!(copy(dst), src, idx) - gather!_pullback(Δ) = (NO_FIELDS, DoesNotExist(), ∇gather_src(Δ, src, idx), DoesNotExist()) + gather!_pullback(Δ) = (NO_FIELDS, DoesNotExist(), ∇gather_src(Δ, idx), DoesNotExist()) y, gather!_pullback end From 068e151745c7608fa3a3d8620a526da05c8ab935 Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 18 May 2021 16:21:30 +0800 Subject: [PATCH 6/7] consider gc --- src/gather.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/gather.jl b/src/gather.jl index deddca666..20681acdc 100644 --- a/src/gather.jl +++ b/src/gather.jl @@ -74,10 +74,11 @@ end # Gradient -∇gather_src(Δ, idx) = scatter(+, Δ, idx) +∇gather_src(Δ, src_size, idx) = scatter!(+, fill!(similar(Δ, eltype(Δ), src_size), 0), Δ, idx) function rrule(::typeof(gather!), dst::AbstractArray, src::AbstractArray, idx::AbstractArray) y = gather!(copy(dst), src, idx) - gather!_pullback(Δ) = (NO_FIELDS, DoesNotExist(), ∇gather_src(Δ, idx), DoesNotExist()) + src_size = size(src) + gather!_pullback(Δ) = (NO_FIELDS, DoesNotExist(), ∇gather_src(Δ, src_size, idx), DoesNotExist()) y, gather!_pullback end From b9962d89ad26ea04f2d3853f02b6a729edd3b4ee Mon Sep 17 00:00:00 2001 From: Yueh-Hua Tu Date: Tue, 18 May 2021 16:26:51 +0800 Subject: [PATCH 7/7] add gradient test for tuple index --- test/gather.jl | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/test/gather.jl b/test/gather.jl index 6d11be8fc..eb6b8f6f9 100644 --- a/test/gather.jl +++ b/test/gather.jl @@ -125,7 +125,7 @@ end @test size(y) == (size(src)[1:Nsrc-M]..., size(index)...) end -@testset "gather gradient" begin +@testset "gather gradient for scalar index" begin T = Float64 src = T[3, 4, 5, 6, 7] index = [1 2 3 4; @@ -138,3 +138,14 @@ end gradtest(xs -> gather!(dst, xs, index), src) gradtest(xs -> gather(xs, index), src) end + +@testset "gather gradient for tuple index" begin + T = Float64 + src = T[3 5 7 + 4 6 8] + index = [(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)] + dst = T[3, 5, 7, 4, 6, 8] + + gradtest(xs -> gather!(dst, xs, index), src) + gradtest(xs -> gather(xs, index), src) +end