Skip to content

Commit 29f113d

Browse files
committed
refactor test cases
1 parent 43d82eb commit 29f113d

File tree

1 file changed

+117
-102
lines changed

1 file changed

+117
-102
lines changed

lib/NNlibCUDA/test/scatter.jl

+117-102
Original file line numberDiff line numberDiff line change
@@ -1,131 +1,146 @@
1-
ys = cu([3 3 4 4 5;
2-
5 5 6 6 7])
3-
us = cu(ones(Int, 2, 3, 4))
4-
xs = CuArray{Int64}([1 2 3 4;
5-
4 2 1 3;
6-
3 5 5 3])
7-
xs_tup = CuArray([(1,) (2,) (3,) (4,);
8-
(4,) (2,) (1,) (3,);
9-
(3,) (5,) (5,) (3,)])
10-
11-
12-
@testset "cuda/scatter" begin
13-
for T = [UInt32, UInt64, Int32, Int64]
1+
dsts = Dict(
2+
0 => cu([3, 4, 5, 6, 7]),
3+
1 => cu([3 3 4 4 5;
4+
5 5 6 6 7]),
5+
)
6+
srcs = Dict(
7+
(0, true) => cu(ones(Int, 3, 4)),
8+
(0, false) => cu(ones(Int, 3) * collect(1:4)'),
9+
(1, true) => cu(ones(Int, 2, 3, 4)),
10+
(1, false) => cu([1, 2] .* reshape(ones(Int, 3) * collect(1:4)', 1,3,4)),
11+
)
12+
idxs = Dict(
13+
:int => cu([1 2 3 4;
14+
4 2 1 3;
15+
3 5 5 3]),
16+
:tup => cu([(1,) (2,) (3,) (4,);
17+
(4,) (2,) (1,) (3,);
18+
(3,) (5,) (5,) (3,)]),
19+
)
20+
res = Dict(
21+
(+, 0, true) => cu([5, 6, 9, 8, 9]),
22+
(+, 1, true) => cu([5 5 8 6 7;
23+
7 7 10 8 9]),
24+
(+, 0, false) => cu([4, 4, 12, 5, 5]),
25+
(+, 1, false) => cu([4 4 12 5 5;
26+
8 8 24 10 10]),
27+
(-, 0, true) => cu([1, 2, 1, 4, 5]),
28+
(-, 1, true) => cu([1 1 0 2 3;
29+
3 3 2 4 5]),
30+
(-, 0, false) => cu([-4, -4, -12, -5, -5]),
31+
(-, 1, false) => cu([-4 -4 -12 -5 -5;
32+
-8 -8 -24 -10 -10]),
33+
(max, 0, true) => cu([3, 4, 5, 6, 7]),
34+
(max, 1, true) => cu([3 3 4 4 5;
35+
5 5 6 6 7]),
36+
(max, 0, false) => cu([3, 2, 4, 4, 3]),
37+
(max, 1, false) => cu([3 2 4 4 3;
38+
6 4 8 8 6]),
39+
(min, 0, true) => cu([1, 1, 1, 1, 1]),
40+
(min, 1, true) => cu([1 1 1 1 1;
41+
1 1 1 1 1]),
42+
(min, 0, false) => cu([1, 2, 1, 1, 2]),
43+
(min, 1, false) => cu([1 2 1 1 2;
44+
2 4 2 2 4]),
45+
(*, 0, true) => cu([3, 4, 5, 6, 7]),
46+
(*, 1, true) => cu([3 3 4 4 5;
47+
5 5 6 6 7]),
48+
(*, 0, false) => cu([3, 4, 48, 4, 6]),
49+
(*, 1, false) => cu([3 4 48 4 6;
50+
12 16 768 16 24]),
51+
(/, 0, true) => cu([0.75, 1., 0.3125, 1.5, 1.75]),
52+
(/, 1, true) => cu([0.75 0.75 0.25 1. 1.25;
53+
1.25 1.25 0.375 1.5 1.75]),
54+
(/, 0, false) => cu([1//3, 1//4, 1//48, 1//4, 1//6]),
55+
(/, 1, false) => cu([1//3 1//4 1//48 1//4 1//6;
56+
1//12 1//16 1//768 1//16 1//24]),
57+
(mean, 0, true) => cu([4., 5., 6., 7., 8.]),
58+
(mean, 1, true) => cu([4. 4. 5. 5. 6.;
59+
6. 6. 7. 7. 8.]),
60+
(mean, 0, false) => cu([2, 2, 3, 2.5, 2.5]),
61+
(mean, 1, false) => cu([2. 2. 3. 2.5 2.5;
62+
4. 4. 6. 5. 5.]),
63+
)
64+
65+
types = [UInt32, UInt64, Int32, Int64, Float32, Float64]
66+
67+
68+
@testset "scatter" begin
69+
for T = types
1470
@testset "$(T)" begin
15-
@testset "add" begin
16-
ys_ = cu([5 5 8 6 7;
17-
7 7 10 8 9])
18-
@test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_)
19-
@test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_)
20-
21-
@test scatter_add!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
22-
@test scatter!(:add, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
71+
@testset "+" begin
72+
for idx = values(idxs), dims = [0, 1]
73+
mutated = true
74+
@test scatter!(+, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(+, dims, mutated)])
75+
76+
mutated = false
77+
# @test scatter(+, srcs[(dims, mutated)], idx) == T.(res[(+, dims, mutated)])
78+
end
2379
end
2480

25-
@testset "sub" begin
26-
ys_ = cu([1 1 0 2 3;
27-
3 3 2 4 5])
28-
@test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_)
29-
@test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_)
81+
@testset "-" begin
82+
for idx = values(idxs), dims = [0, 1]
83+
mutated = true
84+
@test scatter!(-, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(-, dims, mutated)])
3085

31-
@test scatter_sub!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
32-
@test scatter!(:sub, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
86+
mutated = false
87+
# @test scatter(-, srcs[(dims, mutated)], idx) == T.(res[(-, dims, mutated)])
88+
end
3389
end
3490

3591
@testset "max" begin
36-
ys_ = cu([3 3 4 4 5;
37-
5 5 6 6 7])
38-
@test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_)
39-
@test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_)
92+
for idx = values(idxs), dims = [0, 1]
93+
mutated = true
94+
@test scatter!(max, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(max, dims, mutated)])
4095

41-
@test scatter_max!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
42-
@test scatter!(:max, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
96+
mutated = false
97+
# @test scatter(max, srcs[(dims, mutated)], idx) == T.(res[(max, dims, mutated)])
98+
end
4399
end
44100

45101
@testset "min" begin
46-
ys_ = cu([1 1 1 1 1;
47-
1 1 1 1 1])
48-
@test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_)
49-
@test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_)
102+
for idx = values(idxs), dims = [0, 1]
103+
mutated = true
104+
@test scatter!(min, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(min, dims, mutated)])
50105

51-
@test scatter_min!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
52-
@test scatter!(:min, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
106+
mutated = false
107+
# @test scatter(min, srcs[(dims, mutated)], idx) == T.(res[(min, dims, mutated)])
108+
end
53109
end
54110
end
55111
end
56112

57113

58114
for T = [Float32, Float64]
59115
@testset "$(T)" begin
60-
@testset "add" begin
61-
ys_ = cu([5 5 8 6 7;
62-
7 7 10 8 9])
63-
@test scatter_add!(T.(copy(ys)), T.(us), xs) == T.(ys_)
64-
@test scatter!(:add, T.(copy(ys)), T.(us), xs) == T.(ys_)
65-
66-
@test scatter_add!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
67-
@test scatter!(:add, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
116+
@testset "*" begin
117+
for idx = values(idxs), dims = [0, 1]
118+
mutated = true
119+
@test scatter!(*, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(*, dims, mutated)])
120+
121+
mutated = false
122+
# @test scatter(*, srcs[(dims, mutated)], idx) == T.(res[(*, dims, mutated)])
123+
end
68124
end
69125

70-
@testset "sub" begin
71-
ys_ = cu([1 1 0 2 3;
72-
3 3 2 4 5])
73-
@test scatter_sub!(T.(copy(ys)), T.(us), xs) == T.(ys_)
74-
@test scatter!(:sub, T.(copy(ys)), T.(us), xs) == T.(ys_)
126+
@testset "/" begin
127+
for idx = values(idxs), dims = [0, 1]
128+
mutated = true
129+
@test scatter!(/, T.(dsts[dims]), srcs[(dims, mutated)].*2, idx) == T.(res[(/, dims, mutated)])
75130

76-
@test scatter_sub!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
77-
@test scatter!(:sub, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
78-
end
79-
80-
@testset "max" begin
81-
ys_ = cu([3 3 4 4 5;
82-
5 5 6 6 7])
83-
@test scatter_max!(T.(copy(ys)), T.(us), xs) == T.(ys_)
84-
@test scatter!(:max, T.(copy(ys)), T.(us), xs) == T.(ys_)
85-
86-
@test scatter_max!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
87-
@test scatter!(:max, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
88-
end
89-
90-
@testset "min" begin
91-
ys_ = cu([1 1 1 1 1;
92-
1 1 1 1 1])
93-
@test scatter_min!(T.(copy(ys)), T.(us), xs) == T.(ys_)
94-
@test scatter!(:min, T.(copy(ys)), T.(us), xs) == T.(ys_)
95-
96-
@test scatter_min!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
97-
@test scatter!(:min, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
98-
end
99-
100-
@testset "mul" begin
101-
ys_ = cu([3 3 4 4 5;
102-
5 5 6 6 7])
103-
@test scatter_mul!(T.(copy(ys)), T.(us), xs) == T.(ys_)
104-
@test scatter!(:mul, T.(copy(ys)), T.(us), xs) == T.(ys_)
105-
106-
@test scatter_mul!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
107-
@test scatter!(:mul, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
108-
end
109-
110-
@testset "div" begin
111-
us_div = us .* 2
112-
ys_ = cu([0.75 0.75 0.25 1. 1.25;
113-
1.25 1.25 0.375 1.5 1.75])
114-
@test scatter_div!(T.(copy(ys)), T.(us_div), xs) == T.(ys_)
115-
@test scatter!(:div, T.(copy(ys)), T.(us_div), xs) == T.(ys_)
116-
117-
@test scatter_div!(T.(copy(ys)), T.(us_div), xs_tup) == T.(ys_)
118-
@test scatter!(:div, T.(copy(ys)), T.(us_div), xs_tup) == T.(ys_)
131+
mutated = false
132+
# @test scatter(/, srcs[(dims, mutated)], idx) == T.(res[(/, dims, mutated)])
133+
end
119134
end
120135

121136
@testset "mean" begin
122-
ys_ = cu([4 4 5 5 6;
123-
6 6 7 7 8])
124-
@test scatter_mean!(T.(copy(ys)), T.(us), xs) == T.(ys_)
125-
@test scatter!(:mean, T.(copy(ys)), T.(us), xs) == T.(ys_)
137+
for idx = values(idxs), dims = [0, 1]
138+
mutated = true
139+
@test scatter!(mean, T.(dsts[dims]), srcs[(dims, mutated)], idx) == T.(res[(mean, dims, mutated)])
126140

127-
@test scatter_mean!(T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
128-
@test scatter!(:mean, T.(copy(ys)), T.(us), xs_tup) == T.(ys_)
141+
mutated = false
142+
# @test scatter(mean, srcs[(dims, mutated)], idx) == T.(res[(mean, dims, mutated)])
143+
end
129144
end
130145
end
131146
end

0 commit comments

Comments
 (0)