
Commit 1e18d89

Merge pull request #17 from LuxDL/auto-juliaformatter-pr
Automatic JuliaFormatter.jl run
2 parents f1f4363 + eb38698 commit 1e18d89

16 files changed: +523 −185 lines
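
For context, a pass like this is typically produced by running JuliaFormatter over the whole repository. Below is a minimal sketch, assuming JuliaFormatter.jl is available in the active environment; the style options actually used by the LuxDL automation live in the repository's own configuration and may differ.

using JuliaFormatter

# Reformat every .jl file under the current directory in place.
# `format` returns `true` when all files were already formatted, so an
# automation job can open a PR like this one whenever it returns `false`.
already_formatted = format(".")
already_formatted || @warn "Some files were reformatted"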

lib/LuxLib/docs/make.jl (+27 −9)

@@ -3,13 +3,31 @@ using Documenter, DocumenterMarkdown, LuxLib
 deployconfig = Documenter.auto_detect_deploy_system()
 Documenter.post_status(deployconfig; type="pending", repo="github.com/LuxDL/LuxLib.jl.git")
 
-makedocs(; sitename="LuxLib", authors="Avik Pal et al.", clean=true, doctest=true,
-         modules=[LuxLib],
-         strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs],
-         checkdocs=:all, format=Markdown(), draft=false, build=joinpath(@__DIR__, "docs"))
+makedocs(;
+    sitename="LuxLib",
+    authors="Avik Pal et al.",
+    clean=true,
+    doctest=true,
+    modules=[LuxLib],
+    strict=[:doctest, :linkcheck, :parse_error, :example_block, :missing_docs],
+    checkdocs=:all,
+    format=Markdown(),
+    draft=false,
+    build=joinpath(@__DIR__, "docs"))
 
-deploydocs(; repo="github.com/LuxDL/LuxLib.jl.git", push_preview=true,
-           deps=Deps.pip("mkdocs", "pygments", "python-markdown-math", "mkdocs-material",
-                         "pymdown-extensions", "mkdocstrings", "mknotebooks",
-                         "pytkdocs_tweaks", "mkdocs_include_exclude_files", "jinja2"),
-           make=() -> run(`mkdocs build`), target="site", devbranch="main")
+deploydocs(;
+    repo="github.com/LuxDL/LuxLib.jl.git",
+    push_preview=true,
+    deps=Deps.pip("mkdocs",
+        "pygments",
+        "python-markdown-math",
+        "mkdocs-material",
+        "pymdown-extensions",
+        "mkdocstrings",
+        "mknotebooks",
+        "pytkdocs_tweaks",
+        "mkdocs_include_exclude_files",
+        "jinja2"),
+    make=() -> run(`mkdocs build`),
+    target="site",
+    devbranch="main")

lib/LuxLib/ext/LuxLibLuxCUDAExt.jl (+47 −13)

@@ -10,31 +10,65 @@ LuxLib._replicate(rng::CUDA.RNG) = deepcopy(rng)
 
 # api/batchnorm.jl
 
-const CUDNN_BN_ARRAY_TYPE = Union{CuArray{<:FP_32_64, 2}, CuArray{<:FP_32_64, 4},
-                                  CuArray{<:FP_32_64, 5}}
+const CUDNN_BN_ARRAY_TYPE = Union{
+    CuArray{<:FP_32_64, 2},
+    CuArray{<:FP_32_64, 4},
+    CuArray{<:FP_32_64, 5},
+}
 const BNParamType = Union{Nothing, CuVector{<:FP_32_64}}
 
-function batchnorm(x::CUDNN_BN_ARRAY_TYPE, scale::BNParamType, bias::BNParamType,
-                   running_mean::BNParamType, running_var::BNParamType; momentum::Real,
-                   training::Val, epsilon::Real)
+function batchnorm(x::CUDNN_BN_ARRAY_TYPE,
+    scale::BNParamType,
+    bias::BNParamType,
+    running_mean::BNParamType,
+    running_var::BNParamType;
+    momentum::Real,
+    training::Val,
+    epsilon::Real)
     rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training)
 
     x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training)
     return x_, (; running_mean=rm, running_var=rv)
 end
 
-function _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, eps,
-                           ::Val{training}) where {training}
-    return NNlibCUDA.batchnorm(scale, bias, x, running_mean, running_var, momentum; eps,
-                               training)
+function _batchnorm_cudnn!(running_mean,
+    running_var,
+    scale,
+    bias,
+    x,
+    momentum,
+    eps,
+    ::Val{training}) where {training}
+    return NNlibCUDA.batchnorm(scale,
+        bias,
+        x,
+        running_mean,
+        running_var,
+        momentum;
+        eps,
+        training)
 end
 
-function CRC.rrule(::typeof(_batchnorm_cudnn!), running_mean, running_var, scale, bias, x,
-                   momentum, epsilon, t::Val{training}) where {training}
+function CRC.rrule(::typeof(_batchnorm_cudnn!),
+    running_mean,
+    running_var,
+    scale,
+    bias,
+    x,
+    momentum,
+    epsilon,
+    t::Val{training}) where {training}
     y = _batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum, epsilon, t)
     function ∇_batchnorm_cudnn!(Δ)
-        ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(scale, bias, x, CRC.unthunk(Δ), running_mean,
-                                          running_var, momentum; eps=epsilon, training)
+        ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(scale,
+            bias,
+            x,
+            CRC.unthunk(Δ),
+            running_mean,
+            running_var,
+            momentum;
+            eps=epsilon,
+            training)
         return (∂∅, ∂∅, ∂∅, ∂g, ∂b, ∂x, ∂∅, ∂∅, ∂∅)
     end
     return y, ∇_batchnorm_cudnn!
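
For reference, the cuDNN-backed `batchnorm` entry point reformatted above is called as in the sketch below. The array sizes, hyperparameter values, and the `CUDA.rand`/`CUDA.ones` helpers used to build the inputs are illustrative assumptions, not part of the diff.

using LuxCUDA, LuxLib

x = CUDA.rand(Float32, 8, 8, 4, 2)        # W × H × C × N, a CuArray{Float32, 4}
scale = CUDA.ones(Float32, 4)             # one entry per channel
bias = CUDA.zeros(Float32, 4)
running_mean = CUDA.zeros(Float32, 4)
running_var = CUDA.ones(Float32, 4)

# Returns the normalized array plus the updated running statistics,
# mirroring `return x_, (; running_mean=rm, running_var=rv)` in the diff.
y, stats = batchnorm(x, scale, bias, running_mean, running_var;
    momentum=0.1f0, training=Val(true), epsilon=1.0f-5)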

lib/LuxLib/ext/LuxLibLuxCUDATrackerExt.jl (+65 −25)

@@ -6,25 +6,34 @@ if isdefined(Base, :get_extension)
     using LuxCUDA
 else
     using ..Tracker
-    import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector,
-                      TrackedReal
+    import ..Tracker: @grad,
+        data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
     using ..LuxCUDA
 end
 using NNlib, LuxLib
-import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅,
-               __is_tracked
+import LuxLib: AA,
+    AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked
 
 # api/batchnorm.jl
-const TR_CUDNN_BN_ARRAY_TYPE = Union{TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}},
-                                     TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}},
-                                     TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}}}
-const TR_BNParamType = Union{Nothing, TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}},
-                             CuVector{<:FP_32_64}}
-
-function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE, scale::TR_BNParamType,
-                          bias::TR_BNParamType, running_mean::TR_BNParamType,
-                          running_var::TR_BNParamType; momentum::Real, training::Val,
-                          epsilon::Real)
+const TR_CUDNN_BN_ARRAY_TYPE = Union{
+    TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 2}},
+    TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 4}},
+    TrackedArray{<:Any, <:Any, <:CuArray{<:FP_32_64, 5}},
+}
+const TR_BNParamType = Union{
+    Nothing,
+    TrackedArray{<:Any, <:Any, <:CuVector{<:FP_32_64}},
+    CuVector{<:FP_32_64},
+}
+
+function LuxLib.batchnorm(x::TR_CUDNN_BN_ARRAY_TYPE,
+    scale::TR_BNParamType,
+    bias::TR_BNParamType,
+    running_mean::TR_BNParamType,
+    running_var::TR_BNParamType;
+    momentum::Real,
+    training::Val,
+    epsilon::Real)
     rm, rv = _get_batchnorm_statistics(x, running_mean, running_var, training)
 
     x_ = _batchnorm_cudnn!(rm, rv, scale, bias, x, momentum, epsilon, training)

@@ -39,21 +48,52 @@ for RM in (:TrackedVector, :Nothing, :AbstractVector),
 
     __is_tracked(RM, RV, S, B, XT) || continue
 
-    @eval function _batchnorm_cudnn!(running_mean::$RM, running_var::$RV, scale::$S,
-                                     bias::$B, x::$XT, momentum, eps, training::Val)
-        return track(_batchnorm_cudnn!, running_mean, running_var, scale, bias, x, momentum,
-                     eps, training)
+    @eval function _batchnorm_cudnn!(running_mean::$RM,
+        running_var::$RV,
+        scale::$S,
+        bias::$B,
+        x::$XT,
+        momentum,
+        eps,
+        training::Val)
+        return track(_batchnorm_cudnn!,
+            running_mean,
+            running_var,
+            scale,
+            bias,
+            x,
+            momentum,
+            eps,
+            training)
     end
 end
 
-@grad function LuxLib._batchnorm_cudnn!(running_mean, running_var, scale, bias, x, momentum,
-                                        eps, training)
-    y = _batchnorm_cudnn!(data(running_mean), data(running_var), data(scale), data(bias),
-                          data(x), momentum, eps, training)
+@grad function LuxLib._batchnorm_cudnn!(running_mean,
+    running_var,
+    scale,
+    bias,
+    x,
+    momentum,
+    eps,
+    training)
+    y = _batchnorm_cudnn!(data(running_mean),
+        data(running_var),
+        data(scale),
+        data(bias),
+        data(x),
+        momentum,
+        eps,
+        training)
     function ∇_batchnorm_cudnn!(Δ)
-        ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(data(scale), data(bias), data(x), Δ,
-                                          data(running_mean), data(running_var), momentum;
-                                          eps, training)
+        ∂g, ∂b, ∂x = NNlibCUDA.∇batchnorm(data(scale),
+            data(bias),
+            data(x),
+            Δ,
+            data(running_mean),
+            data(running_var),
+            momentum;
+            eps,
+            training)
         return (nothing, nothing, ∂g, ∂b, ∂x, nothing, nothing, nothing)
     end
     return y, ∇_batchnorm_cudnn!

lib/LuxLib/ext/LuxLibReverseDiffExt.jl (+45 −19)

@@ -2,14 +2,28 @@ module LuxLibReverseDiffExt
 
 if isdefined(Base, :get_extension)
     using ReverseDiff
-    import ReverseDiff: SpecialInstruction, TrackedArray, TrackedReal, decrement_deriv!,
-                        increment_deriv!, track, value, special_reverse_exec!,
-                        special_forward_exec!, @grad_from_chainrules
+    import ReverseDiff: SpecialInstruction,
+        TrackedArray,
+        TrackedReal,
+        decrement_deriv!,
+        increment_deriv!,
+        track,
+        value,
+        special_reverse_exec!,
+        special_forward_exec!,
+        @grad_from_chainrules
 else
     using ..ReverseDiff
-    import ..ReverseDiff: SpecialInstruction, TrackedArray, TrackedReal, decrement_deriv!,
-                          increment_deriv!, track, value, special_reverse_exec!,
-                          special_forward_exec!, @grad_from_chainrules
+    import ..ReverseDiff: SpecialInstruction,
+        TrackedArray,
+        TrackedReal,
+        decrement_deriv!,
+        increment_deriv!,
+        track,
+        value,
+        special_reverse_exec!,
+        special_forward_exec!,
+        @grad_from_chainrules
 end
 using ChainRulesCore, LuxLib, NNlib
 import ChainRulesCore as CRC

@@ -45,36 +59,48 @@ for func in (:conv, :depthwiseconv, :∇conv_data, :∇conv_filter),
         return track(NNlib.$(func), x, w, cdims; kwargs...)
     end
 
-    function ReverseDiff.track(::typeof(NNlib.$(func)), x::$(xType), w::$(wType),
-                               cdims::ConvDims; kwargs...)
+    function ReverseDiff.track(::typeof(NNlib.$(func)),
+        x::$(xType),
+        w::$(wType),
+        cdims::ConvDims;
+        kwargs...)
         tape = ReverseDiff.tape(x, w, cdims)
-        output_value, back = CRC.rrule(NNlib.$(func), value(x), value(w), cdims;
-                                       kwargs...)
+        output_value, back = CRC.rrule(NNlib.$(func),
+            value(x),
+            value(w),
+            cdims;
+            kwargs...)
         output = track(output_value, tape)
         function closure(cls_args...; cls_kwargs...)
             return CRC.rrule(NNlib.$(func), value(x), value(w), cdims; kwargs...)
         end
-        ReverseDiff.record!(tape, SpecialInstruction, NNlib.$(func), (x, w, cdims),
-                            output, (back, closure, kwargs))
+        ReverseDiff.record!(tape,
+            SpecialInstruction,
+            NNlib.$(func),
+            (x, w, cdims),
+            output,
+            (back, closure, kwargs))
         return output
     end
 
-    function special_reverse_exec!(instr::SpecialInstruction{typeof(NNlib.$(func)),
-                                                             <:Tuple{$(xType), $(wType),
-                                                                     ConvDims}})
+    function special_reverse_exec!(instr::SpecialInstruction{
+            typeof(NNlib.$(func)),
+            <:Tuple{$(xType), $(wType), ConvDims},
+        })
         back_output = instr.cache[1](ReverseDiff.deriv(instr.output))
         input_derivs = back_output[2:end]
         ReverseDiff._add_to_deriv!.(instr.input, input_derivs)
         ReverseDiff.unseed!(instr.output)
         return nothing
     end
 
-    function special_forward_exec!(instr::SpecialInstruction{typeof(NNlib.$(func)),
-                                                             <:Tuple{$(xType), $(wType),
-                                                                     ConvDims}})
+    function special_forward_exec!(instr::SpecialInstruction{
+            typeof(NNlib.$(func)),
+            <:Tuple{$(xType), $(wType), ConvDims},
+        })
         ReverseDiff.pull_value!.(instr.input)
         out_value = instr.cache[2](ReverseDiff.value.(instr.input)...;
-                                   instr.cache[3]...)
+            instr.cache[3]...)
         ReverseDiff.value!(instr.output, out_value)
         return nothing
     end

lib/LuxLib/ext/LuxLibTrackerExt.jl (+22 −10)

@@ -5,12 +5,12 @@ if isdefined(Base, :get_extension)
     import Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
 else
     using ..Tracker
-    import ..Tracker: @grad, data, nobacksies, track, TrackedArray, TrackedVector,
-                      TrackedReal
+    import ..Tracker: @grad,
+        data, nobacksies, track, TrackedArray, TrackedVector, TrackedReal
 end
 using NNlib, LuxLib
-import LuxLib: AA, AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅,
-               __is_tracked
+import LuxLib: AA,
+    AV, _batchnorm_cudnn!, _get_batchnorm_statistics, FP_32_64, ∂∅, __is_tracked
 import ChainRulesCore as CRC
 
 # NNlib: batched_mul

@@ -86,14 +86,20 @@ for T1 in (:TrackedArray, :AbstractArray),
 
     __is_tracked(T1, T2, T3) || continue
 
-    @eval function LuxLib.groupnorm(x::$T1{T, 4}, scale::$T2{T}, bias::$T3{T}; groups::Int,
-                                    epsilon::Real) where {T <: FP_32_64}
+    @eval function LuxLib.groupnorm(x::$T1{T, 4},
+        scale::$T2{T},
+        bias::$T3{T};
+        groups::Int,
+        epsilon::Real) where {T <: FP_32_64}
         return track(LuxLib.groupnorm, x, scale, bias; groups, epsilon)
     end
 end
 
-@grad function LuxLib.groupnorm(x::AA{T, 4}, scale::AV{T}, bias::AV{T}; groups::Int,
-                                epsilon::Real) where {T <: FP_32_64}
+@grad function LuxLib.groupnorm(x::AA{T, 4},
+    scale::AV{T},
+    bias::AV{T};
+    groups::Int,
+    epsilon::Real) where {T <: FP_32_64}
     LuxLib._assert_same_backend(data(x), data(scale), data(bias))
     if length(scale) != length(bias) != size(x, 3)
         throw(ArgumentError("Length of `scale` and `bias` must be equal to the number of channels (N - 1 dim of the input array)."))

@@ -104,8 +110,14 @@ end
 
     y, mu, rsig = LuxLib._groupnorm(data(x), groups, data(scale), data(bias), epsilon)
     function groupnorm_pullback(dy)
-        dx, dscale, dbias = LuxLib._dgroupnorm(dy, y, data(x), groups, data(scale),
-                                               data(bias), mu, rsig)
+        dx, dscale, dbias = LuxLib._dgroupnorm(dy,
+            y,
+            data(x),
+            groups,
+            data(scale),
+            data(bias),
+            mu,
+            rsig)
         return nobacksies(:groupnorm, (dx, dscale, dbias))
     end
     return y, groupnorm_pullback
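
For reference, the `LuxLib.groupnorm` method being wrapped for Tracker above is used as in the minimal sketch below. The shapes and values are illustrative assumptions; the channel count must match `length(scale)`/`length(bias)` (per the `ArgumentError` check in the diff) and be divisible by `groups`.

using LuxLib

x = rand(Float32, 8, 8, 6, 2)     # W × H × C × N with C = 6 channels
scale = ones(Float32, 6)          # length(scale) == length(bias) == size(x, 3)
bias = zeros(Float32, 6)

# Normalizes over 3 groups of 2 channels each.
y = groupnorm(x, scale, bias; groups=3, epsilon=1.0f-5)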
