diff --git a/Project.toml b/Project.toml index 2dea498..12a7fa5 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,13 @@ version = "0.1.2" [deps] FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] FastDifferentiation = "0.3,0.4" +SciMLOperators = "0.3.12" SparseArrays = "1.10" Symbolics = "4,5,6" julia = "1.10" diff --git a/src/SymbolicTracingUtils.jl b/src/SymbolicTracingUtils.jl index a0fc6c9..605b87b 100644 --- a/src/SymbolicTracingUtils.jl +++ b/src/SymbolicTracingUtils.jl @@ -9,12 +9,15 @@ module SymbolicTracingUtils using Symbolics: Symbolics using FastDifferentiation: FastDifferentiation as FD using SparseArrays: SparseArrays +using SciMLOperators: FunctionOperator export build_function, + build_linear_operator, FastDifferentiationBackend, get_constant_entries, get_result_buffer, gradient, + infer_backend, jacobian, make_variables, sparse_jacobian, @@ -25,6 +28,8 @@ export build_function, struct SymbolicsBackend end struct FastDifferentiationBackend end const SymbolicNumber = Union{Symbolics.Num,FD.Node} +infer_backend(v::Union{Symbolics.Num,AbstractArray{<:Symbolics.Num}}) = SymbolicsBackend() +infer_backend(v::Union{FD.Node,AbstractArray{<:FD.Node}}) = FastDifferentiationBackend() include("tracing.jl") include("derivatives.jl") diff --git a/src/tracing.jl b/src/tracing.jl index bc7cd91..0d04479 100644 --- a/src/tracing.jl +++ b/src/tracing.jl @@ -66,5 +66,37 @@ function build_function( in_place, backend_options = (;), ) where {T<:FD.Node} - FD.make_function(f_symbolic, args_symbolic...; in_place, backend_options...) + f = FD.make_function(f_symbolic, args_symbolic...; in_place, backend_options...) + + if in_place + function (result, args...) + f(result, reduce(vcat, args)) + end + else + function (args...) + f(reduce(vcat, args)) + end + end +end + +""" +Build a linear SciMLOperators.FunctionOperator from a matrix-valued function `A(p)` +to represent the matrix-vector product `A(p) * u` in matrix-free form. +""" +function build_linear_operator(A_of_p::AbstractMatrix{<:SymbolicNumber}, p; in_place) + u = make_variables(infer_backend(A_of_p), gensym(), size(A_of_p)[end]) + A_of_p_times_u = build_function(A_of_p * u, p, u; in_place) + # TODO: also analyze symmetry and other matrix properties to forward to the operator + input_prototype = zeros(size(u)) + p_prototype = zeros(size(p)) + + if in_place + FunctionOperator(input_prototype; p = p_prototype, islinear = true) do result, u, p, _t + A_of_p_times_u(result, p, u) + end + else + FunctionOperator(input_prototype; p = p_prototype, islinear = true) do u, p, _t + A_of_p_times_u(p, u) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index f060dc9..94aa189 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using SymbolicTracingUtils using Test: @test, @testset, @test_broken -using LinearAlgebra: Diagonal +using LinearAlgebra: Diagonal, mul! using SparseArrays: spzeros, findnz, nnz, rowvals function dummy_function(x) @@ -17,10 +17,13 @@ end global x = make_variables(backend, :x, 10) global fx = dummy_function(x) global x_value = [1:10;] + global y_true = dummy_function(x_value) + global g_true = dummy_function_gradient(x_value) + global J_true = Diagonal(dummy_function_gradient(x_value)) + @testset "non-ad-tracing" begin f = build_function(fx, x; in_place = false) f! = build_function(fx, x; in_place = true) - y_true = dummy_function(x_value) y_out_of_place = f(x_value) y_in_place = zeros(10) f!(y_in_place, x_value) @@ -32,19 +35,17 @@ end gx = gradient(sum(fx), x) g = build_function(gx, x; in_place = false) g! = build_function(gx, x; in_place = true) - y_true = dummy_function_gradient(x_value) - y_out_of_place = g(x_value) - y_in_place = zeros(10) - g!(y_in_place, x_value) - @test y_out_of_place ≈ y_true - @test y_in_place ≈ y_true + g_out_of_place = g(x_value) + g_in_place = zeros(10) + g!(g_in_place, x_value) + @test g_out_of_place ≈ g_true + @test g_in_place ≈ g_true end @testset "jacobian" begin Jx = jacobian(fx, x) J = build_function(Jx, x; in_place = false) J! = build_function(Jx, x; in_place = true) - J_true = Diagonal(dummy_function_gradient(x_value)) J_out_of_place = J(x_value) J_in_place = zeros(10, 10) J!(J_in_place, x_value) @@ -85,6 +86,22 @@ end @test nnz(J_sparse!) == nnz(Jx) # same structure as symbolic version @test rowvals(J_sparse!) == rows end + + @testset "build_linear_operator" begin + J_op = build_linear_operator(Jx, x; in_place = false) + J_op! = build_linear_operator(Jx, x; in_place = true) + v_value = [11.0:20.0;] + Jv_true = J_true * v_value + + J_op.p = x_value + Jv_out_of_place = J_op * v_value + @test Jv_out_of_place ≈ Jv_true + + J_op!.p = x_value + Jv_in_place = zeros(10) + mul!(Jv_in_place, J_op!, v_value) + @test Jv_in_place ≈ Jv_true + end end end end