Skip to content

Commit 23c7402

Browse files
Merge pull request #17 from LuxDL/auto-juliaformatter-pr
Automatic JuliaFormatter.jl run
2 parents 75223ea + 845c354 commit 23c7402

File tree

1 file changed

+10
-6
lines changed

1 file changed

+10
-6
lines changed

lib/LuxTestUtils/src/LuxTestUtils.jl

+10-6
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ function test_gradients_expr(__module__, __source__, f, args...;
251251
rtol::Real=atol > 0 ? 0.0 : eps(typeof(atol)),
252252
nans::Bool=false,
253253
kwargs...)
254-
orig_exprs = map(x -> QuoteNode(Expr(:macrocall,
254+
orig_exprs = map(
255+
x -> QuoteNode(Expr(:macrocall,
255256
GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)),
256257
("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences"))
257258
len = length(args)
@@ -269,8 +270,9 @@ function test_gradients_expr(__module__, __source__, f, args...;
269270
skip=skip_reverse_diff)
270271
reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff
271272

272-
arr_len = length.(filter(Base.Fix2(isa, AbstractArray)
273-
Base.Fix1(__correct_arguments, identity),
273+
arr_len = length.(filter(
274+
Base.Fix2(isa, AbstractArray)
275+
Base.Fix1(__correct_arguments, identity),
274276
tuple($(esc.(args)...))))
275277
large_arrays = any(x -> x $large_array_length, arr_len) ||
276278
sum(arr_len) $max_total_array_size
@@ -365,13 +367,15 @@ function __gradient(gradient_function::F, f, args...; skip::Bool) where {F}
365367
length(args))
366368
end
367369
function __f(inputs...)
368-
updated_inputs = ntuple(i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i],
370+
updated_inputs = ntuple(
371+
i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i],
369372
length(args))
370373
return f(updated_inputs...)
371374
end
372375
gs = gradient_function(__f, [corrected_args...][aa_inputs]...)
373-
return ntuple(i -> aa_inputs[i] ?
374-
__uncorrect_arguments(gs[__aa_input_idx[i]],
376+
return ntuple(
377+
i -> aa_inputs[i] ?
378+
__uncorrect_arguments(gs[__aa_input_idx[i]],
375379
args[__aa_input_idx[i]],
376380
corrected_args[__aa_input_idx[i]]) : GradientComputationSkipped(),
377381
length(args))

0 commit comments

Comments
 (0)