|
| 1 | +dsts = Dict( |
| 2 | + 0 => cu([3, 4, 5, 6, 7]), |
| 3 | + 1 => cu([3 3 4 4 5; |
| 4 | + 5 5 6 6 7]), |
| 5 | +) |
| 6 | +srcs = Dict( |
| 7 | + (0, true) => cu(ones(Int, 3, 4)), |
| 8 | + (0, false) => cu(ones(Int, 3) * collect(1:4)'), |
| 9 | + (1, true) => cu(ones(Int, 2, 3, 4)), |
| 10 | + (1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)), |
| 11 | +) |
| 12 | +idxs = [ |
| 13 | + cu([1 2 3 4; |
| 14 | + 4 2 1 3; |
| 15 | + 3 5 5 3]), # integer index |
| 16 | + cu([(1,) (2,) (3,) (4,); |
| 17 | + (4,) (2,) (1,) (3,); |
| 18 | + (3,) (5,) (5,) (3,)]), # tuple index |
| 19 | +] |
| 20 | +res = Dict( |
| 21 | + (+, 0, true) => cu([5, 6, 9, 8, 9]), |
| 22 | + (+, 1, true) => cu([5 5 8 6 7; |
| 23 | + 7 7 10 8 9]), |
| 24 | + (+, 0, false) => cu([4, 4, 12, 5, 5]), |
| 25 | + (+, 1, false) => cu([4 4 12 5 5; |
| 26 | + 8 8 24 10 10]), |
| 27 | + (-, 0, true) => cu([1, 2, 1, 4, 5]), |
| 28 | + (-, 1, true) => cu([1 1 0 2 3; |
| 29 | + 3 3 2 4 5]), |
| 30 | + (-, 0, false) => cu([-4, -4, -12, -5, -5]), |
| 31 | + (-, 1, false) => cu([-4 -4 -12 -5 -5; |
| 32 | + -8 -8 -24 -10 -10]), |
| 33 | + (max, 0, true) => cu([3, 4, 5, 6, 7]), |
| 34 | + (max, 1, true) => cu([3 3 4 4 5; |
| 35 | + 5 5 6 6 7]), |
| 36 | + (max, 0, false) => cu([3, 2, 4, 4, 3]), |
| 37 | + (max, 1, false) => cu([3 2 4 4 3; |
| 38 | + 6 4 8 8 6]), |
| 39 | + (min, 0, true) => cu([1, 1, 1, 1, 1]), |
| 40 | + (min, 1, true) => cu([1 1 1 1 1; |
| 41 | + 1 1 1 1 1]), |
| 42 | + (min, 0, false) => cu([1, 2, 1, 1, 2]), |
| 43 | + (min, 1, false) => cu([1 2 1 1 2; |
| 44 | + 2 4 2 2 4]), |
| 45 | + (*, 0, true) => cu([3, 4, 5, 6, 7]), |
| 46 | + (*, 1, true) => cu([3 3 4 4 5; |
| 47 | + 5 5 6 6 7]), |
| 48 | + (*, 0, false) => cu([3, 4, 48, 4, 6]), |
| 49 | + (*, 1, false) => cu([3 4 48 4 6; |
| 50 | + 12 16 768 16 24]), |
| 51 | + (/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]), |
| 52 | + (/, 1, true) => cu([0.75 0.75 0.25 1. 1.25; |
| 53 | + 1.25 1.25 0.375 1.5 1.75]), |
| 54 | + (/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]), |
| 55 | + (/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6; |
| 56 | + 1//12 1//16 1//768 1//16 1//24]), |
| 57 | + (mean, 0, true) => cu([4., 5., 6., 7., 8.]), |
| 58 | + (mean, 1, true) => cu([4. 4. 5. 5. 6.; |
| 59 | + 6. 6. 7. 7. 8.]), |
| 60 | + (mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]), |
| 61 | + (mean, 1, false) => cu([2. 2. 3. 2.5 2.5; |
| 62 | + 4. 4. 6. 5. 5.]), |
| 63 | +) |
| 64 | + |
| 65 | +types = [CuArray{Int32}, CuArray{Int64}, CuArray{Float32}, CuArray{Float64}] |
| 66 | + |
| 67 | + |
| 68 | +@testset "scatter" begin |
| 69 | + for T = types |
| 70 | + @testset "$(T)" begin |
| 71 | + @testset "+" begin |
| 72 | + for idx = idxs, dims = [0, 1] |
| 73 | + mutated = true |
| 74 | + @test NNlib.scatter!(+, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)]) |
| 75 | + |
| 76 | + mutated = false |
| 77 | + @test NNlib.scatter(+, T(srcs[(dims, mutated)]), idx) == T(res[(+, dims, mutated)]) |
| 78 | + end |
| 79 | + end |
| 80 | + |
| 81 | + @testset "-" begin |
| 82 | + for idx = idxs, dims = [0, 1] |
| 83 | + mutated = true |
| 84 | + @test NNlib.scatter!(-, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)]) |
| 85 | + |
| 86 | + mutated = false |
| 87 | + @test NNlib.scatter(-, T(srcs[(dims, mutated)]), idx) == T(res[(-, dims, mutated)]) |
| 88 | + end |
| 89 | + end |
| 90 | + |
| 91 | + @testset "max" begin |
| 92 | + for idx = idxs, dims = [0, 1] |
| 93 | + mutated = true |
| 94 | + @test NNlib.scatter!(max, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)]) |
| 95 | + |
| 96 | + mutated = false |
| 97 | + @test NNlib.scatter(max, T(srcs[(dims, mutated)]), idx) == T(res[(max, dims, mutated)]) |
| 98 | + end |
| 99 | + end |
| 100 | + |
| 101 | + @testset "min" begin |
| 102 | + for idx = idxs, dims = [0, 1] |
| 103 | + mutated = true |
| 104 | + @test NNlib.scatter!(min, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)]) |
| 105 | + |
| 106 | + mutated = false |
| 107 | + @test NNlib.scatter(min, T(srcs[(dims, mutated)]), idx) == T(res[(min, dims, mutated)]) |
| 108 | + end |
| 109 | + end |
| 110 | + end |
| 111 | + end |
| 112 | + |
| 113 | + |
| 114 | + for T = [CuArray{Float32}, CuArray{Float64}] |
| 115 | + @testset "$(T)" begin |
| 116 | + @testset "*" begin |
| 117 | + for idx = idxs, dims = [0, 1] |
| 118 | + mutated = true |
| 119 | + @test NNlib.scatter!(*, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)]) |
| 120 | + |
| 121 | + mutated = false |
| 122 | + @test NNlib.scatter(*, T(srcs[(dims, mutated)]), idx) == T(res[(*, dims, mutated)]) |
| 123 | + end |
| 124 | + end |
| 125 | + |
| 126 | + @testset "/" begin |
| 127 | + for idx = idxs, dims = [0, 1] |
| 128 | + mutated = true |
| 129 | + @test NNlib.scatter!(/, T(copy(dsts[dims])), T(srcs[(dims, mutated)].*2), idx) == T(res[(/, dims, mutated)]) |
| 130 | + |
| 131 | + mutated = false |
| 132 | + @test NNlib.scatter(/, T(srcs[(dims, mutated)]), idx) == T(res[(/, dims, mutated)]) |
| 133 | + end |
| 134 | + end |
| 135 | + |
| 136 | + @testset "mean" begin |
| 137 | + for idx = idxs, dims = [0, 1] |
| 138 | + mutated = true |
| 139 | + @test NNlib.scatter!(mean, T(copy(dsts[dims])), T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)]) |
| 140 | + |
| 141 | + mutated = false |
| 142 | + @test NNlib.scatter(mean, T(srcs[(dims, mutated)]), idx) == T(res[(mean, dims, mutated)]) |
| 143 | + end |
| 144 | + end |
| 145 | + end |
| 146 | + end |
| 147 | +end |
0 commit comments