@@ -251,7 +251,8 @@ function test_gradients_expr(__module__, __source__, f, args...;
         rtol::Real=atol > 0 ? 0.0 : √eps(typeof(atol)),
         nans::Bool=false,
         kwargs...)
-    orig_exprs = map(x -> QuoteNode(Expr(:macrocall,
+    orig_exprs = map(
+        x -> QuoteNode(Expr(:macrocall,
             GlobalRef(@__MODULE__, Symbol("@test_gradients{$x}")), __source__, f, args...)),
         ("Tracker", "ReverseDiff", "ForwardDiff", "FiniteDifferences"))
     len = length(args)
@@ -269,8 +270,9 @@ function test_gradients_expr(__module__, __source__, f, args...;
             skip=skip_reverse_diff)
         reverse_diff_broken = $reverse_diff_broken && !skip_reverse_diff
 
-        arr_len = length.(filter(Base.Fix2(isa, AbstractArray) ∘
-                                 Base.Fix1(__correct_arguments, identity),
+        arr_len = length.(filter(
+            Base.Fix2(isa, AbstractArray) ∘
+            Base.Fix1(__correct_arguments, identity),
             tuple($(esc.(args)...))))
         large_arrays = any(x -> x ≥ $large_array_length, arr_len) ||
                        sum(arr_len) ≥ $max_total_array_size
@@ -365,13 +367,15 @@ function __gradient(gradient_function::F, f, args...; skip::Bool) where {F}
             length(args))
     end
     function __f(inputs...)
-        updated_inputs = ntuple(i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i],
+        updated_inputs = ntuple(
+            i -> aa_inputs[i] ? inputs[__aa_input_idx[i]] : args[i],
             length(args))
         return f(updated_inputs...)
     end
     gs = gradient_function(__f, [corrected_args...][aa_inputs]...)
-    return ntuple(i -> aa_inputs[i] ?
-                       __uncorrect_arguments(gs[__aa_input_idx[i]],
+    return ntuple(
+        i -> aa_inputs[i] ?
+             __uncorrect_arguments(gs[__aa_input_idx[i]],
             args[__aa_input_idx[i]],
             corrected_args[__aa_input_idx[i]]) : GradientComputationSkipped(),
         length(args))
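
For context: the helper patched above is the expansion backend for a gradient-testing macro, and the generated expressions reference `@test_gradients{Tracker}`, `@test_gradients{ReverseDiff}`, `@test_gradients{ForwardDiff}`, and `@test_gradients{FiniteDifferences}`. Below is a minimal usage sketch, assuming this is LuxTestUtils.jl's `@test_gradients` front-end with its space-separated keyword syntax; the `atol` and `rtol` keywords appear in the signature shown in the first hunk, while the function `f` and the inputs here are made up for illustration.

```julia
using LuxTestUtils  # assumed package providing @test_gradients

f(x, y) = sum(abs2, x .* y)
x, y = randn(Float32, 4), randn(Float32, 4)

# Cross-check Tracker, ReverseDiff, ForwardDiff, and FiniteDifferences
# gradients of `f` against one another within the given tolerances.
@test_gradients f x y atol=1.0f-3 rtol=1.0f-3
```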