Skip to content

Commit

Permalink
Add support for looking up type constants and enums.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 680677854
  • Loading branch information
jnthntatum authored and copybara-github committed Sep 30, 2024
1 parent 16cab3d commit bee5bd9
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 3 deletions.
4 changes: 4 additions & 0 deletions checker/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,20 @@ cc_library(
srcs = ["type_check_env.cc"],
hdrs = ["type_check_env.h"],
deps = [
"//common:constant",
"//common:decl",
"//common:type",
"//internal:status_macros",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_protobuf//:protobuf",
],
)

Expand Down
55 changes: 55 additions & 0 deletions checker/internal/type_check_env.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,22 @@

#include "checker/internal/type_check_env.h"

#include <cstddef>
#include <cstdint>
#include <string>

#include "absl/base/nullability.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "common/constant.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_factory.h"
#include "common/type_introspector.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"

namespace cel::checker_internal {

Expand Down Expand Up @@ -64,6 +73,52 @@ absl::StatusOr<absl::optional<Type>> TypeCheckEnv::LookupTypeName(
return absl::nullopt;
}

absl::StatusOr<absl::optional<VariableDecl>> TypeCheckEnv::LookupEnumConstant(
TypeFactory& type_factory, absl::string_view type,
absl::string_view value) const {
const TypeCheckEnv* scope = this;
while (scope != nullptr) {
for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend();
++iter) {
auto enum_constant = (*iter)->FindEnumConstant(type_factory, type, value);
if (!enum_constant.ok()) {
return enum_constant.status();
}
if (enum_constant->has_value()) {
auto decl =
MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".",
(**enum_constant).value_name),
(**enum_constant).type);
decl.set_value(
Constant(static_cast<int64_t>((**enum_constant).number)));
return decl;
}
}
scope = scope->parent_;
}
return absl::nullopt;
}

absl::StatusOr<absl::optional<VariableDecl>> TypeCheckEnv::LookupTypeConstant(
TypeFactory& type_factory, absl::Nonnull<google::protobuf::Arena*> arena,
absl::string_view name) const {
CEL_ASSIGN_OR_RETURN(absl::optional<Type> type,
LookupTypeName(type_factory, name));
if (type.has_value()) {
return MakeVariableDecl(std::string(type->name()), TypeType(arena, *type));
}

if (name.find('.') != name.npos) {
size_t last_dot = name.rfind('.');
absl::string_view enum_name_candidate = name.substr(0, last_dot);
absl::string_view value_name_candidate = name.substr(last_dot + 1);
return LookupEnumConstant(type_factory, enum_name_candidate,
value_name_candidate);
}

return absl::nullopt;
}

absl::StatusOr<absl::optional<StructTypeField>> TypeCheckEnv::LookupStructField(
TypeFactory& type_factory, absl::string_view type_name,
absl::string_view field_name) const {
Expand Down
11 changes: 11 additions & 0 deletions checker/internal/type_check_env.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,12 @@
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "common/constant.h"
#include "common/decl.h"
#include "common/type.h"
#include "common/type_factory.h"
#include "common/type_introspector.h"
#include "google/protobuf/arena.h"

namespace cel::checker_internal {

Expand Down Expand Up @@ -154,10 +157,18 @@ class TypeCheckEnv {
TypeFactory& type_factory, absl::string_view type_name,
absl::string_view field_name) const;

absl::StatusOr<absl::optional<VariableDecl>> LookupTypeConstant(
TypeFactory& type_factory, absl::Nonnull<google::protobuf::Arena*> arena,
absl::string_view type_name) const;

TypeCheckEnv MakeExtendedEnvironment() const { return TypeCheckEnv(this); }
VariableScope MakeVariableScope() const { return VariableScope(*this); }

private:
absl::StatusOr<absl::optional<VariableDecl>> LookupEnumConstant(
TypeFactory& type_factory, absl::string_view type,
absl::string_view value) const;

std::string container_;
absl::Nullable<const TypeCheckEnv*> parent_;

Expand Down
45 changes: 42 additions & 3 deletions checker/internal/type_checker_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,10 @@ class ResolveVisitor : public AstVisitorBase {
absl::string_view function_name,
int arg_count, bool is_receiver);

// Resolves the function call shape (i.e. the number of arguments and call
// style) for the given function call.
absl::Nullable<const VariableDecl*> LookupIdentifier(absl::string_view name);

// Resolves the applicable function overloads for the given function call.
//
// If found, assigns a new function decl with the resolved overloads.
Expand Down Expand Up @@ -900,12 +904,39 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr,
types_[&expr] = resolution->result_type;
}

absl::Nullable<const VariableDecl*> ResolveVisitor::LookupIdentifier(
absl::string_view name) {
if (const VariableDecl* decl = current_scope_->LookupVariable(name);
decl != nullptr) {
return decl;
}
absl::StatusOr<absl::optional<VariableDecl>> constant =
env_->LookupTypeConstant(*type_factory_, arena_, name);

if (!constant.ok()) {
status_.Update(constant.status());
return nullptr;
}

if (constant->has_value()) {
if (constant->value().type().kind() == TypeKind::kEnum) {
// Treat enum constant as just an int after resolving the reference.
// This preserves existing behavior in the other type checkers.
constant->value().set_type(IntType());
}
return google::protobuf::Arena::Create<VariableDecl>(
arena_, std::move(constant).value().value());
}

return nullptr;
}

void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr,
absl::string_view name) {
const VariableDecl* decl = nullptr;
namespace_generator_.GenerateCandidates(
name, [&decl, this](absl::string_view candidate) {
decl = current_scope_->LookupVariable(candidate);
decl = LookupIdentifier(candidate);
// continue searching.
return decl == nullptr;
});
Expand All @@ -931,7 +962,7 @@ void ResolveVisitor::ResolveQualifiedIdentifier(
namespace_generator_.GenerateCandidates(
qualifiers, [&decl, &segment_index_out, this](absl::string_view candidate,
int segment_index) {
decl = current_scope_->LookupVariable(candidate);
decl = LookupIdentifier(candidate);
if (decl != nullptr) {
segment_index_out = segment_index;
return false;
Expand Down Expand Up @@ -984,7 +1015,12 @@ void ResolveVisitor::ResolveSelectOperation(const Expr& expr,
ReportUndefinedField(expr.id(), field, struct_type.name());
return absl::nullopt;
}
return field_info->value().GetType();
auto type = field_info->value().GetType();
if (type.kind() == TypeKind::kEnum) {
// Treat enum as just an int.
return IntType();
}
return type;
}

if (operand_type.kind() == TypeKind::kMap) {
Expand Down Expand Up @@ -1040,6 +1076,9 @@ class ResolveRewriter : public AstRewriterBase {
const VariableDecl* decl = iter->second;
auto& ast_ref = reference_map_[expr.id()];
ast_ref.set_name(decl->name());
if (decl->has_value()) {
ast_ref.set_value(decl->value());
}
expr.mutable_ident_expr().set_name(decl->name());
rewritten = true;
} else if (auto iter = visitor_.functions().find(&expr);
Expand Down
68 changes: 68 additions & 0 deletions checker/internal/type_checker_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,13 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull<google::protobuf::Arena*> are
MakeOverloadDecl("to_dyn",
/*return_type=*/DynType(), TypeParamType("A"))));

FunctionDecl to_type;
to_type.set_name("type");
CEL_RETURN_IF_ERROR(to_type.AddOverload(
MakeOverloadDecl("to_type",
/*return_type=*/TypeType(arena, TypeParamType("A")),
TypeParamType("A"))));

env.InsertFunctionIfAbsent(std::move(not_op));
env.InsertFunctionIfAbsent(std::move(not_strictly_false));
env.InsertFunctionIfAbsent(std::move(add_op));
Expand All @@ -258,6 +265,7 @@ absl::Status RegisterMinimalBuiltins(absl::Nonnull<google::protobuf::Arena*> are
env.InsertFunctionIfAbsent(std::move(eq_op));
env.InsertFunctionIfAbsent(std::move(ternary_op));
env.InsertFunctionIfAbsent(std::move(to_dyn));
env.InsertFunctionIfAbsent(std::move(to_type));
env.InsertFunctionIfAbsent(std::move(to_duration));
env.InsertFunctionIfAbsent(std::move(to_timestamp));

Expand Down Expand Up @@ -1135,6 +1143,26 @@ TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) {
"google.protobuf.Int32Value"))));
}

TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) {
TypeCheckEnv env;
env.set_container("google.api.expr.test.v1.proto3");
env.AddTypeProvider(std::make_unique<cel::extensions::ProtoTypeReflector>());

TypeCheckerImpl impl(std::move(env));
ASSERT_OK_AND_ASSIGN(auto ast,
MakeTestParsedAst("TestAllTypes.NestedEnum.BAZ"));
ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast)));

ASSERT_OK_AND_ASSIGN(std::unique_ptr<Ast> checked_ast, result.ReleaseAst());

const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast);
auto ref_iter = ast_impl.reference_map().find(ast_impl.root_expr().id());
ASSERT_NE(ref_iter, ast_impl.reference_map().end());
EXPECT_EQ(ref_iter->second.name(),
"google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum.BAZ");
EXPECT_EQ(ref_iter->second.value().int_value(), 2);
}

struct CheckedExprTestCase {
std::string expr;
ast_internal::Type expected_result_type;
Expand Down Expand Up @@ -1538,6 +1566,32 @@ INSTANTIATE_TEST_SUITE_P(
.expr = "TestAllTypes{map_string_int64: {'string': 1}}",
.expected_result_type = AstType(ast_internal::MessageType(
"google.api.expr.test.v1.proto3.TestAllTypes")),
},
CheckedExprTestCase{
.expr = "TestAllTypes{single_nested_enum: 1}",
.expected_result_type = AstType(ast_internal::MessageType(
"google.api.expr.test.v1.proto3.TestAllTypes")),
},
CheckedExprTestCase{
.expr =
"TestAllTypes{single_nested_enum: TestAllTypes.NestedEnum.BAR}",
.expected_result_type = AstType(ast_internal::MessageType(
"google.api.expr.test.v1.proto3.TestAllTypes")),
},
CheckedExprTestCase{
.expr = "TestAllTypes.NestedEnum.BAR",
.expected_result_type =
AstType(ast_internal::PrimitiveType::kInt64),
},
CheckedExprTestCase{
.expr = "TestAllTypes",
.expected_result_type =
AstType(std::make_unique<AstType>(ast_internal::MessageType(
"google.api.expr.test.v1.proto3.TestAllTypes"))),
},
CheckedExprTestCase{
.expr = "TestAllTypes == type(TestAllTypes{})",
.expected_result_type = AstType(ast_internal::PrimitiveType::kBool),
}));

INSTANTIATE_TEST_SUITE_P(
Expand All @@ -1554,6 +1608,20 @@ INSTANTIATE_TEST_SUITE_P(
.expected_result_type =
AstType(ast_internal::PrimitiveType::kInt64),
},
CheckedExprTestCase{
.expr = "test_msg.single_nested_enum",
.expected_result_type =
AstType(ast_internal::PrimitiveType::kInt64),
},
CheckedExprTestCase{
.expr = "test_msg.single_nested_enum == 1",
.expected_result_type = AstType(ast_internal::PrimitiveType::kBool),
},
CheckedExprTestCase{
.expr =
"test_msg.single_nested_enum == TestAllTypes.NestedEnum.BAR",
.expected_result_type = AstType(ast_internal::PrimitiveType::kBool),
},
CheckedExprTestCase{
.expr = "has(test_msg.not_a_field)",
.expected_result_type = AstType(),
Expand Down
4 changes: 4 additions & 0 deletions checker/internal/type_inference_context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,10 @@ bool TypeInferenceContext::IsAssignableInternal(
return true;
}

if (to_subs.kind() == TypeKind::kEnum && from_subs.kind() == TypeKind::kInt) {
return true;
}

if (IsWildCardType(from_subs) || IsWildCardType(to_subs)) {
return true;
}
Expand Down

0 comments on commit bee5bd9

Please # to comment.