Skip to content

Commit

Permalink
Add support for inspecting and extracting some recursive program step…
Browse files Browse the repository at this point in the history
…s (function calls and constants).

Updated regex precompilation to support recursive programs.

PiperOrigin-RevId: 623962418
  • Loading branch information
jnthntatum authored and copybara-github committed Apr 27, 2024
1 parent 76a6894 commit 6ba6dce
Show file tree
Hide file tree
Showing 15 changed files with 497 additions and 107 deletions.
14 changes: 12 additions & 2 deletions eval/compiler/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@ cc_library(
"//base:ast",
"//base/ast_internal:ast_impl",
"//base/ast_internal:expr",
"//common:native_type",
"//common:value",
"//eval/eval:direct_expression_step",
"//eval/eval:evaluator_core",
"//eval/eval:trace_step",
"//internal:casts",
"//runtime:runtime_options",
"//runtime/internal:issue_collector",
"@com_google_absl//absl/algorithm:container",
Expand Down Expand Up @@ -108,6 +111,7 @@ cc_library(
"//eval/eval:select_step",
"//eval/eval:shadowable_value_step",
"//eval/eval:ternary_step",
"//eval/eval:trace_step",
"//eval/public:ast_traverse_native",
"//eval/public:ast_visitor_native",
"//eval/public:cel_type_registry",
Expand Down Expand Up @@ -253,6 +257,7 @@ cc_test(
deps = [
":cel_expression_builder_flat_impl",
":constant_folding",
":regex_precompilation_optimization",
"//eval/eval:cel_expression_flat_impl",
"//eval/public:activation",
"//eval/public:builtin_func_registrar",
Expand Down Expand Up @@ -489,14 +494,17 @@ cc_library(
"//common:native_type",
"//common:value",
"//eval/eval:compiler_constant_step",
"//eval/eval:direct_expression_step",
"//eval/eval:evaluator_core",
"//eval/eval:regex_match_step",
"//internal:casts",
"//internal:status_macros",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
"@com_googlesource_code_re2//:re2",
],
)

Expand All @@ -512,16 +520,18 @@ cc_test(
"//base/ast_internal:ast_impl",
"//base/ast_internal:expr",
"//common:memory",
"//common:type",
"//common:value",
"//eval/eval:cel_expression_flat_impl",
"//eval/eval:evaluator_core",
"//eval/public:activation",
"//eval/public:builtin_func_registrar",
"//eval/public:cel_expression",
"//eval/public:cel_options",
"//eval/public:cel_value",
"//internal:testing",
"//parser",
"//runtime:runtime_issue",
"//runtime/internal:issue_collector",
"@com_google_absl//absl/status",
"@com_google_googleapis//google/api/expr/v1alpha1:checked_cc_proto",
"@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto",
"@com_google_protobuf//:protobuf",
Expand Down
11 changes: 10 additions & 1 deletion eval/compiler/cel_expression_builder_flat_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
#include "eval/compiler/constant_folding.h"
#include "eval/compiler/regex_precompilation_optimization.h"
#include "eval/eval/cel_expression_flat_impl.h"
#include "eval/public/activation.h"
#include "eval/public/builtin_func_registrar.h"
Expand Down Expand Up @@ -201,6 +202,8 @@ TEST_P(RecursivePlanTest, ParsedExprRecursiveOptimizedImpl) {
builder.flat_expr_builder().AddProgramOptimizer(
cel::runtime_internal::CreateConstantFoldingOptimizer(
cel::extensions::ProtoMemoryManagerRef(&arena)));
builder.flat_expr_builder().AddProgramOptimizer(
CreateRegexPrecompilationExtension(options.regex_max_program_size));

ASSERT_OK_AND_ASSIGN(std::unique_ptr<CelExpression> plan,
builder.CreateExpression(&parsed_expr.expr(),
Expand Down Expand Up @@ -323,7 +326,13 @@ INSTANTIATE_TEST_SUITE_P(
{"shadowable_value_shadowed", R"(TestEnum.BAR == -1)",
test::IsCelBool(true)},
{"lazily_resolved_function", "LazilyBoundMult(123, 2) == 246",
test::IsCelBool(true)}}),
test::IsCelBool(true)},
{"re_matches", "matches(string_abc, '[ad][be][cf]')",
test::IsCelBool(true)},
{"re_matches_receiver",
"(string_abc + string_def).matches(r'(123)?' + r'abc' + r'def')",
test::IsCelBool(true)},
}),

[](const testing::TestParamInfo<RecursiveTestCase>& info) -> std::string {
return info.param.test_name;
Expand Down
32 changes: 9 additions & 23 deletions eval/compiler/flat_expr_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
#include "eval/eval/select_step.h"
#include "eval/eval/shadowable_value_step.h"
#include "eval/eval/ternary_step.h"
#include "eval/eval/trace_step.h"
#include "eval/public/ast_traverse_native.h"
#include "eval/public/ast_visitor_native.h"
#include "eval/public/source_position_native.h"
Expand Down Expand Up @@ -124,25 +125,6 @@ class IndexManager {
size_t max_slot_count_;
};

class TraceDecorator : public DirectExpressionStep {
public:
explicit TraceDecorator(std::unique_ptr<DirectExpressionStep> expression)
: DirectExpressionStep(-1), expression_(std::move(expression)) {}

absl::Status Evaluate(ExecutionFrameBase& frame, Value& result,
AttributeTrail& trail) const override {
CEL_RETURN_IF_ERROR(expression_->Evaluate(frame, result, trail));
if (!frame.callback()) {
return absl::OkStatus();
}
return frame.callback()(expression_->expr_id(), result,
frame.value_manager());
}

private:
std::unique_ptr<DirectExpressionStep> expression_;
};

// Helper for computing jump offsets.
//
// Jumps should be self-contained to a single expression node -- jumping
Expand Down Expand Up @@ -452,6 +434,14 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
}
}

auto* subexpression = program_builder_.current();
if (subexpression != nullptr && options_.enable_recursive_tracing &&
subexpression->IsRecursive()) {
auto program = subexpression->ExtractRecursiveProgram();
subexpression->set_recursive_program(
std::make_unique<TraceStep>(std::move(program.step)), program.depth);
}

program_builder_.ExitSubexpression(expr);

if (!comprehension_stack_.empty() &&
Expand Down Expand Up @@ -1426,10 +1416,6 @@ class FlatExprVisitor : public cel::ast_internal::AstVisitor {
"CEL AST traversal out of order in flat_expr_builder."));
return;
}
if (options_.enable_recursive_tracing) {
auto tmp = std::make_unique<TraceDecorator>(std::move(step));
step = std::move(tmp);
}
program_builder_.current()->set_recursive_program(std::move(step), depth);
}

Expand Down
30 changes: 30 additions & 0 deletions eval/compiler/flat_expr_builder_extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,14 @@
#include "base/ast.h"
#include "base/ast_internal/ast_impl.h"
#include "base/ast_internal/expr.h"
#include "common/native_type.h"
#include "common/value.h"
#include "common/value_manager.h"
#include "eval/compiler/resolver.h"
#include "eval/eval/direct_expression_step.h"
#include "eval/eval/evaluator_core.h"
#include "eval/eval/trace_step.h"
#include "internal/casts.h"
#include "runtime/internal/issue_collector.h"
#include "runtime/runtime_options.h"

Expand Down Expand Up @@ -290,6 +293,31 @@ class ProgramBuilder {
std::shared_ptr<SubprogramMap> subprogram_map_;
};

// Attempt to downcast a specific type of recursive step.
template <typename Subclass>
const Subclass* TryDowncastDirectStep(const DirectExpressionStep* step) {
if (step == nullptr) {
return nullptr;
}

auto type_id = step->GetNativeTypeId();
if (type_id == cel::NativeTypeId::For<TraceStep>()) {
const auto* trace_step = cel::internal::down_cast<const TraceStep*>(step);
auto deps = trace_step->GetDependencies();
if (!deps.has_value() || deps->size() != 1) {
return nullptr;
}
step = deps->at(0);
type_id = step->GetNativeTypeId();
}

if (type_id == cel::NativeTypeId::For<Subclass>()) {
return cel::internal::down_cast<const Subclass*>(step);
}

return nullptr;
}

// Class representing FlatExpr internals exposed to extensions.
class PlannerContext {
public:
Expand All @@ -304,6 +332,8 @@ class PlannerContext {
issue_collector_(issue_collector),
program_builder_(program_builder) {}

ProgramBuilder& program_builder() { return program_builder_; }

// Returns true if the subplan is inspectable.
//
// If false, the node is not mapped to a subexpression in the program builder.
Expand Down
Loading

0 comments on commit 6ba6dce

Please # to comment.