@@ -6,25 +6,34 @@ if isdefined(Base, :get_extension)
     using LuxCUDA
 else
     using ..Tracker
-    import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector,
-                      TrackedReal
+    import ..Tracker: @grad,
+        data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
     using ..LuxCUDA
 end
 using NNlib, LuxLib
-import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅,
-               __is_tracked
+import LuxLib: AA,
+    AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked

 # api/batchnorm.jl
-const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}},
-                                     TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}},
-                                     TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}}
-const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}},
-                             CuVector{<:FP_32_64}}
-
-function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType,
-                          bias::TR_BNParamType, running_mean::TR_BNParamType,
-                          running_var::TR_BNParamType; momentum::Real, training::Val,
-                          epsilon::Real)
+const TR_CUDNN_BN_ARRAY_TYPE = Union{
+    TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}},
+    TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}},
+    TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}},
+}
+const TR_BNParamType = Union{
+    Nothing,
+    TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}},
+    CuVector{<:FP_32_64},
+}
+
+function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE,
+    scale::TR_BNParamType,
+    bias::TR_BNParamType,
+    running_mean::TR_BNParamType,
+    running_var::TR_BNParamType;
+    momentum::Real,
+    training::Val,
+    epsilon::Real)
     rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training)

     x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training)
@@ -39,21 +48,52 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),

     __is_tracked(RM, RV, S, B, XT) || continue

-    @eval function _batchnorm_cudnn!(running_mean::$RM, running_var::$RV, scale::$S,
-                                     bias::$B, x::$XT, momentum, eps, training::Val)
-        return track(_batchnorm_cudnn!, running_mean, running_var, scale, bias, x, momentum,
-                     eps, training)
+    @eval function _batchnorm_cudnn!(running_mean::$RM,
+        running_var::$RV,
+        scale::$S,
+        bias::$B,
+        x::$XT,
+        momentum,
+        eps,
+        training::Val)
+        return track(_batchnorm_cudnn!,
+            running_mean,
+            running_var,
+            scale,
+            bias,
+            x,
+            momentum,
+            eps,
+            training)
     end
 end

-@grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum,
-                                        eps, training)
-    y = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), data(bias),
-                          data(x), momentum, eps, training)
+@grad function LuxLib._batchnorm_cudnn!(running_mean,
+    running_var,
+    scale,
+    bias,
+    x,
+    momentum,
+    eps,
+    training)
+    y = _batchnorm_cudnn!(data(running_mean),
+        data(running_var),
+        data(scale),
+        data(bias),
+        data(x),
+        momentum,
+        eps,
+        training)
     function ∇_batchnorm_cudnn!(Δ)
-        ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(data(scale), data(bias), data(x), Δ,
-                                          data(running_mean), data(running_var), momentum;
-                                          eps, training)
+        ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(data(scale),
+            data(bias),
+            data(x),
+            Δ,
+            data(running_mean),
+            data(running_var),
+            momentum;
+            eps,
+            training)
         return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing)
     end
     return y, ∇_batchnorm_cudnn!
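
For context, here is a minimal usage sketch of the code path this patch reformats; it is not part of the diff. It assumes a CUDA-capable device, that LuxLib.batchnorm returns the normalized output together with the updated statistics, and uses hypothetical array sizes.

using LuxCUDA, LuxLib, Tracker  # LuxCUDA re-exports the CUDA module

# Tracked FP32 CuArrays hit the TR_CUDNN_BN_ARRAY_TYPE / TR_BNParamType methods above.
x = Tracker.param(CUDA.rand(Float32, 8, 8, 3, 2))  # W × H × C × N activations
scale = Tracker.param(CUDA.ones(Float32, 3))
bias = Tracker.param(CUDA.zeros(Float32, 3))
rmean = CUDA.zeros(Float32, 3)  # running mean (untracked)
rvar = CUDA.ones(Float32, 3)    # running variance (untracked)

# Assumed (y, stats) return shape; keyword names match the signature in the diff.
y, _ = LuxLib.batchnorm(x, scale, bias, rmean, rvar;
    momentum=0.1f0, training=Val(true), epsilon=1.0f-5)

Tracker.back!(sum(y))  # pulls gradients through ∇_batchnorm_cudnn! defined above
Tracker.grad(scale)    # cuDNN-computed gradient with respect to scale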