From 70281aa70f4a3cbfa653b287896a89edc8b375a2 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Nov 2023 17:05:51 -0500 Subject: [PATCH 1/4] Ignore `p` if it is NullParameters --- Project.toml | 2 ++ src/SparseDiffTools.jl | 2 +- src/differentiation/common.jl | 2 +- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 34aa591b..8e8b916e 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -49,6 +50,7 @@ LinearAlgebra = "1.6" PackageExtensionCompat = "1" Random = "1.6" Reexport = "1" +SciMLBase = "2" SciMLOperators = "0.3.7" Setfield = "1" SparseArrays = "1.6" diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 0d971507..0044d93b 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -18,7 +18,7 @@ import ArrayInterface: matrix_colors import StaticArrays import StaticArrays: StaticArray # Others -using SciMLOperators, LinearAlgebra, Random +using SciMLBase, SciMLOperators, LinearAlgebra, Random import DataStructures: DisjointSets, find_root!, union! import SciMLOperators: update_coefficients, update_coefficients! import Setfield: @set! diff --git a/src/differentiation/common.jl b/src/differentiation/common.jl index 2117986f..35ac12eb 100644 --- a/src/differentiation/common.jl +++ b/src/differentiation/common.jl @@ -52,7 +52,7 @@ function JacFunctionWrapper(f::F, fu_, u, p, t) where {F} return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, fu, p, t) end - elseif p !== nothing + elseif p !== nothing && !(p isa SciMLBase.NullParameters) iip = static_hasmethod(f, typeof((fu, u, p))) oop = static_hasmethod(f, typeof((u, p))) if !iip && !oop From 1dff8c3ef4c92bfde7a772d5be9a1ed556cb4204 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Nov 2023 17:28:33 -0500 Subject: [PATCH 2/4] Remove unused deps --- Project.toml | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 8e8b916e..8c06e926 100644 --- a/Project.toml +++ b/Project.toml @@ -64,8 +64,6 @@ Zygote = "0.6" julia = "1.6" [extras] -ArrayInterfaceBandedMatrices = "2e50d22c-5be1-4042-81b1-c572ed69783d" -ArrayInterfaceBlockBandedMatrices = "5331f1e9-51c7-46b0-a9b0-df4434785e0a" BandedMatrices = "aae01518-5342-5314-be14-df237901396f" BlockBandedMatrices = "ffab5731-97b5-5995-9138-79e8c1846df0" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" @@ -79,4 +77,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Test", "ArrayInterfaceBandedMatrices", "ArrayInterfaceBlockBandedMatrices", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"] +test = ["Test", "BandedMatrices", "BlockBandedMatrices", "Enzyme", "IterativeSolvers", "Pkg", "Random", "SafeTestsets", "Symbolics", "Zygote", "StaticArrays"] From 8254f9ec599e476ae84a6016f1b7e889f6a0f00e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 15 Nov 2023 20:41:11 -0500 Subject: [PATCH 3/4] Try to make it non-breaking as much as possible --- Project.toml | 2 -- src/SparseDiffTools.jl | 2 +- src/differentiation/common.jl | 52 ++++++++++++++++++++++++----------- 3 files changed, 37 insertions(+), 19 deletions(-) diff --git a/Project.toml b/Project.toml index 8c06e926..e92e4b00 100644 --- a/Project.toml +++ b/Project.toml @@ -16,7 +16,6 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" -SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -50,7 +49,6 @@ LinearAlgebra = "1.6" PackageExtensionCompat = "1" Random = "1.6" Reexport = "1" -SciMLBase = "2" SciMLOperators = "0.3.7" Setfield = "1" SparseArrays = "1.6" diff --git a/src/SparseDiffTools.jl b/src/SparseDiffTools.jl index 0044d93b..0d971507 100644 --- a/src/SparseDiffTools.jl +++ b/src/SparseDiffTools.jl @@ -18,7 +18,7 @@ import ArrayInterface: matrix_colors import StaticArrays import StaticArrays: StaticArray # Others -using SciMLBase, SciMLOperators, LinearAlgebra, Random +using SciMLOperators, LinearAlgebra, Random import DataStructures: DisjointSets, find_root!, union! import SciMLOperators: update_coefficients, update_coefficients! import Setfield: @set! diff --git a/src/differentiation/common.jl b/src/differentiation/common.jl index 35ac12eb..60155546 100644 --- a/src/differentiation/common.jl +++ b/src/differentiation/common.jl @@ -42,30 +42,50 @@ function JacFunctionWrapper(f::F, fu_, u, p, t) where {F} # The warning instead of error ensures a non-breaking change for users relying on an # undefined / undocumented feature fu = fu_ === nothing ? copy(u) : copy(fu_) + + # Check this first else we were breaking things + # In the next breaking release, we will fix the ordering of the checks + iip = static_hasmethod(f, typeof((fu, u))) + oop = static_hasmethod(f, typeof((u,))) + if iip || oop + if p !== nothing || t !== nothing + Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we + potentially detected `f(du, u)` or `f(u)`. This can be caused by: + + 1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not be + supplied. + 2. `f(args...)` is defined, in which case `hasmethod` can be spurious. + + Currently, we perform the check for `f(du, u)` and `f(u)` first, but in future + breaking releases, this check will be performed last, which means that if `t` + is provided `f(du, u, p, t)`/`f(u, p, t)` will be given precedence, similarly + if `p` is provided `f(du, u, p)`/`f(u, p)` will be given precedence.""", + :JacFunctionWrapper) + end + return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + end + if t !== nothing iip = static_hasmethod(f, typeof((fu, u, p, t))) oop = static_hasmethod(f, typeof((u, p, t))) if !iip && !oop - @warn """`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` not defined - for `f`! Will fallback to `f(u)` or `f(fu, u)`.""" maxlog=1 - else - return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) + throw(ArgumentError("""`p` and `t` provided but `f(u, p, t)` or `f(fu, u, p, t)` + not defined for `f`!""")) end - elseif p !== nothing && !(p isa SciMLBase.NullParameters) + return JacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + elseif p !== nothing iip = static_hasmethod(f, typeof((fu, u, p))) oop = static_hasmethod(f, typeof((u, p))) if !iip && !oop - @warn """`p` provided but `f(u, p)` or `f(fu, u, p)` not defined for `f`! Will - fallback to `f(u)` or `f(fu, u)`.""" maxlog=1 - else - return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) + throw(ArgumentError("""`p` is provided but `f(u, p)` or `f(fu, u, p)` + not defined for `f`!""")) end + return JacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) end - iip = static_hasmethod(f, typeof((fu, u))) - oop = static_hasmethod(f, typeof((u,))) - !iip && !oop && throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`")) - return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) + + throw(ArgumentError("""Couldn't determine the function signature of `f` to construct a + JacobianWrapper!""")) end From 77a337a7f0ce3d2e2f4deec6790485def0a89c51 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 16 Nov 2023 10:06:51 -0500 Subject: [PATCH 4/4] Allow a kwarg to force use the correct ordering --- src/differentiation/common.jl | 58 ++++++++++++++++---------- src/differentiation/jaches_products.jl | 5 ++- src/differentiation/vecjac_products.jl | 4 +- 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/src/differentiation/common.jl b/src/differentiation/common.jl index 60155546..c55a8716 100644 --- a/src/differentiation/common.jl +++ b/src/differentiation/common.jl @@ -38,32 +38,37 @@ __internal_oop(::JacFunctionWrapper{iip, oop}) where {iip, oop} = oop (f::JacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p) (f::JacFunctionWrapper{false, true, 3})(u) = f.f(u) -function JacFunctionWrapper(f::F, fu_, u, p, t) where {F} +# NOTE: `use_deprecated_ordering` is a way for external libraries to update to the correct +# style. In the next release, we will drop the first check +function JacFunctionWrapper(f::F, fu_, u, p, t; + use_deprecated_ordering::Val{deporder} = Val(true)) where {F, deporder} # The warning instead of error ensures a non-breaking change for users relying on an # undefined / undocumented feature fu = fu_ === nothing ? copy(u) : copy(fu_) - # Check this first else we were breaking things - # In the next breaking release, we will fix the ordering of the checks - iip = static_hasmethod(f, typeof((fu, u))) - oop = static_hasmethod(f, typeof((u,))) - if iip || oop - if p !== nothing || t !== nothing - Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we - potentially detected `f(du, u)` or `f(u)`. This can be caused by: + if deporder + # Check this first else we were breaking things + # In the next breaking release, we will fix the ordering of the checks + iip = static_hasmethod(f, typeof((fu, u))) + oop = static_hasmethod(f, typeof((u,))) + if iip || oop + if p !== nothing || t !== nothing + Base.depwarn("""`p` and/or `t` provided and are not `nothing`. But we + potentially detected `f(du, u)` or `f(u)`. This can be caused by: - 1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not be - supplied. - 2. `f(args...)` is defined, in which case `hasmethod` can be spurious. + 1. `f(du, u)` or `f(u)` is defined, in-which case `p` and/or `t` should not + be supplied. + 2. `f(args...)` is defined, in which case `hasmethod` can be spurious. - Currently, we perform the check for `f(du, u)` and `f(u)` first, but in future - breaking releases, this check will be performed last, which means that if `t` - is provided `f(du, u, p, t)`/`f(u, p, t)` will be given precedence, similarly - if `p` is provided `f(du, u, p)`/`f(u, p)` will be given precedence.""", - :JacFunctionWrapper) + Currently, we perform the check for `f(du, u)` and `f(u)` first, but in + future breaking releases, this check will be performed last, which means + that if `t` is provided `f(du, u, p, t)`/`f(u, p, t)` will be given + precedence, similarly if `p` is provided `f(du, u, p)`/`f(u, p)` will be + given precedence.""", :JacFunctionWrapper) + end + return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) end - return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, - fu, p, t) end if t !== nothing @@ -86,6 +91,17 @@ function JacFunctionWrapper(f::F, fu_, u, p, t) where {F} fu, p, t) end - throw(ArgumentError("""Couldn't determine the function signature of `f` to construct a - JacobianWrapper!""")) + if !deporder + iip = static_hasmethod(f, typeof((fu, u))) + oop = static_hasmethod(f, typeof((u,))) + if !iip && !oop + throw(ArgumentError("""`p` is provided but `f(u)` or `f(fu, u)` not defined for + `f`!""")) + end + return JacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + else + throw(ArgumentError("""Couldn't determine the function signature of `f` to + construct a JacobianWrapper!""")) + end end diff --git a/src/differentiation/jaches_products.jl b/src/differentiation/jaches_products.jl index 2c4b0d79..3c50d210 100644 --- a/src/differentiation/jaches_products.jl +++ b/src/differentiation/jaches_products.jl @@ -263,8 +263,9 @@ f(du, u) # Otherwise ``` """ function JacVec(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, - autodiff = AutoForwardDiff(), tag = DeivVecTag(), kwargs...) - ff = JacFunctionWrapper(f, fu, u, p, t) + autodiff = AutoForwardDiff(), tag = DeivVecTag(), + use_deprecated_ordering::Val = Val(true), kwargs...) + ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering) fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u) cache, vecprod, vecprod! = if autodiff isa AutoFiniteDiff diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index 4836d958..998c47ac 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -72,8 +72,8 @@ f(du, u) # Otherwise ``` """ function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, - autodiff = AutoFiniteDiff(), kwargs...) - ff = JacFunctionWrapper(f, fu, u, p, t) + autodiff = AutoFiniteDiff(), use_deprecated_ordering::Val = Val(true), kwargs...) + ff = JacFunctionWrapper(f, fu, u, p, t; use_deprecated_ordering) if !__internal_oop(ff) && autodiff isa AutoZygote msg = "Zygote requires an out of place method with signature f(u)."