Skip to content

Commit de8086c

Browse files
committed
Fix JacVec for not inplace problems
1 parent 5b46c2d commit de8086c

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

src/jacobian.jl

+8-5
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
@concrete struct JacobianWrapper
1+
@concrete struct JacobianWrapper{iip}
22
f
33
p
44
end
55

6-
(uf::JacobianWrapper)(u) = uf.f(u, uf.p)
7-
(uf::JacobianWrapper)(res, u) = uf.f(res, u, uf.p)
6+
# Previous Implementation did not hold onto `iip`, but this causes problems in packages
7+
# where we check for the presence of function signatures to check which dispatch to call
8+
(uf::JacobianWrapper{false})(u) = uf.f(u, uf.p)
9+
(uf::JacobianWrapper{false})(res, u) = (vec(res) .= vec(uf.f(u, uf.p)))
10+
(uf::JacobianWrapper{true})(res, u) = uf.f(res, u, uf.p)
811

912
sparsity_detection_alg(f, ad) = NoSparsityDetection()
1013
function sparsity_detection_alg(f, ad::AbstractSparseADType)
@@ -48,7 +51,7 @@ jacobian!!(::Number, cache) = last(value_derivative(cache.uf, cache.u))
4851
# Build Jacobian Caches
4952
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u, p,
5053
::Val{iip}) where {iip}
51-
uf = JacobianWrapper(f, p)
54+
uf = JacobianWrapper{iip}(f, p)
5255

5356
haslinsolve = hasfield(typeof(alg), :linsolve)
5457

@@ -98,6 +101,6 @@ end
98101
function jacobian_caches(alg::AbstractNonlinearSolveAlgorithm, f, u::Number, p,
99102
::Val{false})
100103
# NOTE: Scalar `u` assumes scalar output from `f`
101-
uf = JacobianWrapper(f, p)
104+
uf = JacobianWrapper{false}(f, p)
102105
return uf, nothing, u, nothing, nothing, u
103106
end

0 commit comments

Comments
 (0)