diff --git a/Project.toml b/Project.toml index 34aa591b..e92e4b00 100644 --- a/Project.toml +++ b/Project.toml @@ -62,8 +62,6 @@ Zygote = "0.6" julia = "1.6" [extras] -ArrayInterfaceBandedMatrices = "2e50d22c-5be1-4042-81b1-c572ed69783d" -ArrayInterfaceBlockBandedMatrices = "5331f1e9-51c7-46b0-a9b0-df4434785e0a" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -77,4 +75,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "ArrayInterfaceBandedMatrices", "ArrayInterfaceBlockBandedMatrices", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"] +test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"] diff --git a/src/differentiation/common.jl b/src/differentiation/common.jl index 2117986f..c55a8716 100644 --- a/src/differentiation/common.jl +++ b/src/differentiation/common.jl @@ -38,34 +38,70 @@ __internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop (f::JacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p) (f::JacFunctionWrapper{false, true, 3})(u) = f.f(u) -function JacFunctionWrapper(f::F, fu_, u, p, t) where {F} +# NOTE: `use_deprecated_ordering` is a way for external libraries to update to the correct +# style. In the next release, we will drop the first check +function JacFunctionWrapper(f::F, fu_, u, p, t; + use_deprecated_ordering::Val{deporder} = Val(true)) where {F, deporder} # The warning instead of error ensures a non-breaking change for users relying on an # undefined / undocumented feature fu = fu_ === nothing ? copy(u) : copy(fu_) + + if deporder + # Check this first else we were breaking things + # In the next breaking release, we will fix the ordering of the checks + iip = static_hasmethod(f, typeof((fu, u))) + oop = static_hasmethod(f, typeof((u,))) + if iip || oop + if p !== nothing || t !== nothing + Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we + potentially detected `f(du, u)` or `f(u)`. This can be caused by: + + 1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not + be supplied. + 2. `f(args...)` is defined, in which case `hasmethod` can be spurious. + + Currently, we perform the check for `f(du, u)` and `f(u)` first, but in + future breaking releases, this check will be performed last, which means + that if `t` is provided `f(du, u, p, t)`/`f(u, p, t)` will be given + precedence, similarly if `p` is provided `f(du, u, p)`/`f(u, p)` will be + given precedence.""", :JacFunctionWrapper) + end + return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + end + end + if t !== nothing iip = static_hasmethod(f, typeof((fu, u, p, t))) oop = static_hasmethod(f, typeof((u, p, t))) if !iip && !oop - @warn """`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` not defined - for `f`! Will fallback to `f(u)` or `f(fu, u)`.""" maxlog=1 - else - return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) + throw(ArgumentError("""`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` + not defined for `f`!""")) end + return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) elseif p !== nothing iip = static_hasmethod(f, typeof((fu, u, p))) oop = static_hasmethod(f, typeof((u, p))) if !iip && !oop - @warn """`p` provided but `f(u, p)` or `f(fu, u, p)` not defined for `f`! Will - fallback to `f(u)` or `f(fu, u)`.""" maxlog=1 - else - return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) + throw(ArgumentError("""`p` is provided but `f(u, p)` or `f(fu, u, p)` + not defined for `f`!""")) + end + return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + end + + if !deporder + iip = static_hasmethod(f, typeof((fu, u))) + oop = static_hasmethod(f, typeof((u,))) + if !iip && !oop + throw(ArgumentError("""`p` is provided but `f(u)` or `f(fu, u)` not defined for + `f`!""")) end + return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + else + throw(ArgumentError("""Couldn't determine the function signature of `f` to + construct a JacobianWrapper!""")) end - iip = static_hasmethod(f, typeof((fu, u))) - oop = static_hasmethod(f, typeof((u,))) - !iip && !oop && throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`")) - return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) end diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 2c4b0d79..3c50d210 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -263,8 +263,9 @@ f(du, u) # Otherwise ``` """ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, - autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...) - ff = JacFunctionWrapper(f, fu, u, p, t) + autodiff = AutoForwardDiff(), tag = DeivVecTag(), + use_deprecated_ordering::Val = Val(true), kwargs...) + ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering) fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u) cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index 4836d958..998c47ac 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -72,8 +72,8 @@ f(du, u) # Otherwise ``` """ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, - autodiff = AutoFiniteDiff(), kwargs...) - ff = JacFunctionWrapper(f, fu, u, p, t) + autodiff = AutoFiniteDiff(), use_deprecated_ordering::Val = Val(true), kwargs...) + ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering) if !__internal_oop(ff) && autodiff isa AutoZygote msg = "Zygote requires an out of place method with signature f(u)."