Skip to content

Turing complete expressions #123

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
2 changes: 2 additions & 0 deletions src/DynamicExpressions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ using DispatchDoctor: @stable, @unstable
include("OperatorEnumConstruction.jl")
include("Expression.jl")
include("ExpressionAlgebra.jl")
include("SpecialOperators.jl")
include("Random.jl")
include("Parse.jl")
include("ParametricExpression.jl")
Expand Down Expand Up @@ -76,6 +77,7 @@ import .StringsModule: get_op_name
@reexport import .EvaluateModule:
eval_tree_array, differentiable_eval_tree_array, EvalOptions
import .EvaluateModule: ArrayBuffer
@reexport import .SpecialOperatorsModule: AssignOperator, WhileOperator
@reexport import .EvaluateDerivativeModule: eval_diff_tree_array, eval_grad_tree_array
@reexport import .ChainRulesModule: NodeTangent, extract_gradient
@reexport import .SimplifyModule: combine_operators, simplify_tree!
Expand Down
41 changes: 34 additions & 7 deletions src/Evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ import ..NodeUtilsModule: is_constant
import ..ExtensionInterfaceModule: bumper_eval_tree_array, _is_loopvectorization_loaded
import ..ValueInterfaceModule: is_valid, is_valid_array

# Overloaded by SpecialOperators.jl:
function any_special_operators end
function special_operator end
function deg2_eval_special end
function deg1_eval_special end

const OPERATOR_LIMIT_BEFORE_SLOWDOWN = 15

macro return_on_nonfinite_val(eval_options, val, X)
Expand Down Expand Up @@ -218,6 +224,10 @@ function eval_tree_array(
"Bumper and LoopVectorization features are only compatible with numeric element types",
)
end
if any_special_operators(operators)
cX = copy(cX)
# TODO: This is dangerous if the element type is mutable
end
if _eval_options.bumper isa Val{true}
return bumper_eval_tree_array(tree, cX, operators, _eval_options)
end
Expand Down Expand Up @@ -264,7 +274,7 @@ function _eval_tree_array(
# we can just return the constant result.
if tree.degree == 0
return deg0_eval(tree, cX, eval_options)
elseif is_constant(tree)
elseif !any_special_operators(operators) && is_constant(tree)
# Speed hack for constant trees.
const_result = dispatch_constant_tree(tree, operators)::ResultOk{T}
!const_result.ok &&
Expand Down Expand Up @@ -329,30 +339,37 @@ end
long_compilation_time = nbin > OPERATOR_LIMIT_BEFORE_SLOWDOWN
if long_compilation_time
return quote
op = operators.binops[op_idx]
special_operator(op) &&
return deg2_eval_special(tree, cX, operators, op, eval_options)
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
!result_l.ok && return result_l
@return_on_nonfinite_array(eval_options, result_l.x)
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
!result_r.ok && return result_r
@return_on_nonfinite_array(eval_options, result_r.x)
# op(x, y), for any x or y
deg2_eval(result_l.x, result_r.x, operators.binops[op_idx], eval_options)
deg2_eval(result_l.x, result_r.x, op, eval_options)
end
end
return quote
return Base.Cartesian.@nif(
$nbin,
i -> i == op_idx,
i -> let op = operators.binops[i]
if tree.l.degree == 0 && tree.r.degree == 0
if special_operator(op)
deg2_eval_special(tree, cX, operators, op, eval_options)
elseif tree.l.degree == 0 && tree.r.degree == 0
deg2_l0_r0_eval(tree, cX, op, eval_options)
elseif tree.r.degree == 0
result_l = _eval_tree_array(tree.l, cX, operators, eval_options)
!result_l.ok && return result_l
@return_on_nonfinite_array(eval_options, result_l.x)
# op(x, y), where y is a constant or variable but x is not.
deg2_r0_eval(tree, result_l.x, cX, op, eval_options)
elseif tree.l.degree == 0
elseif !any_special_operators(operators) && tree.l.degree == 0
# This branch changes the execution order, so we cannot
# use this branch when special operators are present.
result_r = _eval_tree_array(tree.r, cX, operators, eval_options)
!result_r.ok && return result_r
@return_on_nonfinite_array(eval_options, result_r.x)
Expand Down Expand Up @@ -383,10 +400,13 @@ end
long_compilation_time = nuna > OPERATOR_LIMIT_BEFORE_SLOWDOWN
if long_compilation_time
return quote
op = operators.unaops[op_idx]
special_operator(op) &&
return deg1_eval_special(tree, cX, operators, op, eval_options)
result = _eval_tree_array(tree.l, cX, operators, eval_options)
!result.ok && return result
@return_on_nonfinite_array(eval_options, result.x)
deg1_eval(result.x, operators.unaops[op_idx], eval_options)
deg1_eval(result.x, op, eval_options)
end
end
# This @nif lets us generate an if statement over choice of operator,
Expand All @@ -396,13 +416,20 @@ end
$nuna,
i -> i == op_idx,
i -> let op = operators.unaops[i]
if tree.l.degree == 2 && tree.l.l.degree == 0 && tree.l.r.degree == 0
if special_operator(op)
deg1_eval_special(tree, cX, operators, op, eval_options)
elseif !any_special_operators(operators) &&
tree.l.degree == 2 &&
tree.l.l.degree == 0 &&
tree.l.r.degree == 0
# op(op2(x, y)), where x, y, z are constants or variables.
l_op_idx = tree.l.op
dispatch_deg1_l2_ll0_lr0_eval(
tree, cX, op, l_op_idx, operators.binops, eval_options
)
elseif tree.l.degree == 1 && tree.l.l.degree == 0
elseif !any_special_operators(operators) &&
tree.l.degree == 1 &&
tree.l.l.degree == 0
# op(op2(x)), where x is a constant or variable.
l_op_idx = tree.l.op
dispatch_deg1_l1_ll0_eval(
Expand Down
18 changes: 15 additions & 3 deletions src/Simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import ..NodeModule: AbstractExpressionNode, constructorof, Node, copy_node, set
import ..NodeUtilsModule: tree_mapreduce, is_node_constant
import ..OperatorEnumModule: AbstractOperatorEnum
import ..ValueInterfaceModule: is_valid
import ..EvaluateModule: any_special_operators

_una_op_kernel(f::F, l::T) where {F,T} = f(l)
_bin_op_kernel(f::F, l::T, r::T) where {F,T} = f(l, r)
Expand All @@ -19,6 +20,12 @@ combine_operators(tree::AbstractExpressionNode, ::AbstractOperatorEnum) = tree
# This is only defined for `Node` as it is not possible for, e.g.,
# `GraphNode`.
function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
# Skip simplification if special operators are in use
any_special_operators(operators) && return tree
return _combine_operators(tree, operators)
end

function _combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where {T}
# NOTE: (const (+*-) const) already accounted for. Call simplify_tree! before.
# ((const + var) + const) => (const + var)
# ((const * var) * const) => (const * var)
Expand All @@ -28,10 +35,10 @@ function combine_operators(tree::Node{T}, operators::AbstractOperatorEnum) where
if tree.degree == 0
return tree
elseif tree.degree == 1
tree.l = combine_operators(tree.l, operators)
tree.l = _combine_operators(tree.l, operators)
elseif tree.degree == 2
tree.l = combine_operators(tree.l, operators)
tree.r = combine_operators(tree.r, operators)
tree.l = _combine_operators(tree.l, operators)
tree.r = _combine_operators(tree.r, operators)
end

top_level_constant =
Expand Down Expand Up @@ -123,6 +130,11 @@ end

# Simplify tree
function simplify_tree!(tree::AbstractExpressionNode, operators::AbstractOperatorEnum)
# Skip simplification if special operators are in use
if any_special_operators(operators)
return tree
end

return tree_mapreduce(
identity, (p, c...) -> combine_children!(operators, p, c...), tree, typeof(tree);
)
Expand Down
84 changes: 84 additions & 0 deletions src/SpecialOperators.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
module SpecialOperatorsModule

using ..OperatorEnumModule: OperatorEnum
using ..EvaluateModule:
_eval_tree_array, @return_on_nonfinite_array, deg2_eval, ResultOk, get_filled_array
using ..ExpressionModule: AbstractExpression
using ..ExpressionAlgebraModule: @declare_expression_operator

import ..EvaluateModule:
special_operator, deg2_eval_special, deg1_eval_special, any_special_operators
import ..StringsModule: get_op_name

# Use this to customize evaluation behavior for operators:
@inline special_operator(::Type{F}) where {F} = false
@inline special_operator(::F) where {F} = special_operator(F)

@generated function any_special_operators(
::Union{O,Type{O}}
) where {B,U,O<:OperatorEnum{B,U}}
return any(special_operator, B.types) || any(special_operator, U.types)
end

Base.@kwdef struct AssignOperator <: Function
target_register::Int
end
@declare_expression_operator((op::AssignOperator), 1)
@inline special_operator(::Type{AssignOperator}) = true
get_op_name(o::AssignOperator) = "ASSIGN_OP:{FEATURE_" * string(o.target_register) * "}"

function deg1_eval_special(tree, cX, operators, op::AssignOperator, eval_options)
result = _eval_tree_array(tree.l, cX, operators, eval_options)
!result.ok && return result
@return_on_nonfinite_array(eval_options, result.x)
target_register = op.target_register
@inbounds @simd for i in eachindex(axes(cX, 2))
cX[target_register, i] = result.x[i]
end
return result
end

Base.@kwdef struct WhileOperator <: Function
max_iters::Int = 100
end

@declare_expression_operator((op::WhileOperator), 2)
@inline special_operator(::Type{WhileOperator}) = true
get_op_name(o::WhileOperator) = "while"

# TODO: Need to void any instance of buffer when using while loop.
function deg2_eval_special(tree, cX, operators, op::WhileOperator, eval_options)
cond = tree.l
body = tree.r
mask = trues(size(cX, 2))
X = @view cX[:, mask]
# Initialize the result array for all columns
result_array = get_filled_array(eval_options.buffer, zero(eltype(cX)), cX, axes(cX, 2))
body_result = ResultOk(result_array, true)

for _ in 1:(op.max_iters)
cond_result = _eval_tree_array(cond, X, operators, eval_options)
!cond_result.ok && return cond_result
@return_on_nonfinite_array(eval_options, cond_result.x)

new_mask = cond_result.x .> 0.0
any(new_mask) || return body_result

# Track which columns are still active
mask[mask] .= new_mask
X = @view cX[:, mask]

# Evaluate just for active columns
iter_result = _eval_tree_array(body, X, operators, eval_options)
!iter_result.ok && return iter_result

# Update the corresponding elements in the result array
body_result.x[mask] .= iter_result.x
@return_on_nonfinite_array(eval_options, body_result.x)
end

# We passed max_iters, so this result is invalid
return ResultOk(body_result.x, false)
end

end
49 changes: 42 additions & 7 deletions src/Strings.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,18 @@ end
end
end

const FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH = length("{FEATURE_")
function replace_feature_placeholders(s::String, f_variable::Function, variable_names)
return replace(
s,
r"\{FEATURE_(\d+)\}" =>
m -> f_variable(
parse(Int, m[(begin + FEATURE_PLACEHOLDER_FIRST_HALF_LENGTH):(end - 1)]),
variable_names,
),
)
end

# Can overload these for custom behavior:
needs_brackets(val::Real) = false
needs_brackets(val::AbstractArray) = false
Expand Down Expand Up @@ -104,12 +116,33 @@ function combine_op_with_inputs(op, l, r)::Vector{Char}
end
end
function combine_op_with_inputs(op, l)
# "op(l)"
out = copy(op)
push!(out, '(')
append!(out, strip_brackets(l))
push!(out, ')')
return out
# Check if this is an assignment operator with our special prefix
op_str = String(op)
if startswith(op_str, "ASSIGN_OP:")
# Extract the variable name from the operator name
var_name = op_str[11:end]
# Format: (var ← expr)
out = ['(']
append!(out, collect(var_name))
append!(out, collect(" ← "))
# Ensure the expression is always wrapped in parentheses for clarity
if l[1] == '(' && l[end] == ')'
append!(out, l)
else
push!(out, '(')
append!(out, strip_brackets(l))
push!(out, ')')
end
push!(out, ')')
return out
else
# Regular unary operator: "op(l)"
out = copy(op)
push!(out, '(')
append!(out, strip_brackets(l))
push!(out, ')')
return out
end
end

"""
Expand Down Expand Up @@ -179,7 +212,9 @@ function string_tree(
c
end,
)
return String(strip_brackets(raw_output))
string_output = String(strip_brackets(raw_output))
string_output = replace_feature_placeholders(string_output, f_variable, variable_names)
return string_output
end

# Print an equation
Expand Down
Loading
Loading