diff --git a/checker/BUILD b/checker/BUILD index 4802bd01f..c52438b7d 100644 --- a/checker/BUILD +++ b/checker/BUILD @@ -84,6 +84,7 @@ cc_library( "//checker/internal:type_check_env", "//checker/internal:type_checker_impl", "//common:decl", + "//common:type", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/functional:any_invocable", "@com_google_absl//absl/status", diff --git a/checker/internal/BUILD b/checker/internal/BUILD index 1eacebbe8..b8a297ae6 100644 --- a/checker/internal/BUILD +++ b/checker/internal/BUILD @@ -61,12 +61,20 @@ cc_library( srcs = ["type_check_env.cc"], hdrs = ["type_check_env.h"], deps = [ + "//common:constant", "//common:decl", + "//common:type", + "//internal:status_macros", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/memory", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", ], ) @@ -118,15 +126,16 @@ cc_library( "//common:constant", "//common:decl", "//common:expr", + "//common:memory", "//common:source", "//common:type", "//common:type_kind", + "//extensions/protobuf:memory_manager", "//internal:status_macros", "@com_google_absl//absl/base:no_destructor", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/container:flat_hash_map", - "@com_google_absl//absl/log:absl_check", - "@com_google_absl//absl/log:absl_log", + "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -152,6 +161,7 @@ cc_test( "//common:ast", "//common:decl", "//common:type", + "//extensions/protobuf:value", "//internal:status_macros", "//internal:testing", "@com_google_absl//absl/base:no_destructor", @@ -161,6 +171,7 @@ cc_test( "@com_google_absl//absl/status", "@com_google_absl//absl/status:status_matchers", "@com_google_absl//absl/strings", + "@com_google_cel_spec//proto/test/v1/proto2:test_all_types_cc_proto", "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/checker/internal/type_check_env.cc b/checker/internal/type_check_env.cc index e4232dae2..6b428990c 100644 --- a/checker/internal/type_check_env.cc +++ b/checker/internal/type_check_env.cc @@ -14,9 +14,22 @@ #include "checker/internal/type_check_env.h" +#include +#include +#include + #include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/constant.h" #include "common/decl.h" +#include "common/type.h" +#include "common/type_factory.h" +#include "common/type_introspector.h" +#include "internal/status_macros.h" +#include "google/protobuf/arena.h" namespace cel::checker_internal { @@ -44,6 +57,90 @@ absl::Nullable TypeCheckEnv::LookupFunction( return nullptr; } +absl::StatusOr> TypeCheckEnv::LookupTypeName( + TypeFactory& type_factory, absl::string_view name) const { + const TypeCheckEnv* scope = this; + while (scope != nullptr) { + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto type = (*iter)->FindType(type_factory, name); + if (!type.ok() || type->has_value()) { + return type; + } + } + scope = scope->parent_; + } + return absl::nullopt; +} + +absl::StatusOr> TypeCheckEnv::LookupEnumConstant( + TypeFactory& type_factory, absl::string_view type, + absl::string_view value) const { + const TypeCheckEnv* scope = this; + while (scope != nullptr) { + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto enum_constant = (*iter)->FindEnumConstant(type_factory, type, value); + if (!enum_constant.ok()) { + return enum_constant.status(); + } + if (enum_constant->has_value()) { + auto decl = + MakeVariableDecl(absl::StrCat((**enum_constant).type_full_name, ".", + (**enum_constant).value_name), + (**enum_constant).type); + decl.set_value( + Constant(static_cast((**enum_constant).number))); + return decl; + } + } + scope = scope->parent_; + } + return absl::nullopt; +} + +absl::StatusOr> TypeCheckEnv::LookupTypeConstant( + TypeFactory& type_factory, absl::Nonnull arena, + absl::string_view name) const { + CEL_ASSIGN_OR_RETURN(absl::optional type, + LookupTypeName(type_factory, name)); + if (type.has_value()) { + return MakeVariableDecl(std::string(type->name()), TypeType(arena, *type)); + } + + if (name.find('.') != name.npos) { + size_t last_dot = name.rfind('.'); + absl::string_view enum_name_candidate = name.substr(0, last_dot); + absl::string_view value_name_candidate = name.substr(last_dot + 1); + return LookupEnumConstant(type_factory, enum_name_candidate, + value_name_candidate); + } + + return absl::nullopt; +} + +absl::StatusOr> TypeCheckEnv::LookupStructField( + TypeFactory& type_factory, absl::string_view type_name, + absl::string_view field_name) const { + const TypeCheckEnv* scope = this; + while (scope != nullptr) { + // Check the type providers in reverse registration order. + // Note: this doesn't allow for shadowing a type with a subset type of the + // same name -- the parent type provider will still be considered when + // checking field accesses. + for (auto iter = type_providers_.rbegin(); iter != type_providers_.rend(); + ++iter) { + auto field_info = (*iter)->FindStructTypeFieldByName( + type_factory, type_name, field_name); + if (!field_info.ok() || field_info->has_value()) { + return field_info; + } + } + scope = scope->parent_; + } + return absl::nullopt; +} + absl::Nullable VariableScope::LookupVariable( absl::string_view name) const { const VariableScope* scope = this; diff --git a/checker/internal/type_check_env.h b/checker/internal/type_check_env.h index 1d33a1c03..91ca09f72 100644 --- a/checker/internal/type_check_env.h +++ b/checker/internal/type_check_env.h @@ -18,13 +18,22 @@ #include #include #include +#include #include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" #include "absl/memory/memory.h" +#include "absl/status/statusor.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "absl/types/span.h" +#include "common/constant.h" #include "common/decl.h" +#include "common/type.h" +#include "common/type_factory.h" +#include "common/type_introspector.h" +#include "google/protobuf/arena.h" namespace cel::checker_internal { @@ -68,9 +77,10 @@ class VariableScope { absl::flat_hash_map variables_; }; -// Class managing the type check environment. +// Class managing the state of the type check environment. // -// Maintains lookup maps for variables and functions. +// Maintains lookup maps for variables and functions and the set of type +// providers. // // This class is thread-compatible. class TypeCheckEnv { @@ -95,6 +105,14 @@ class TypeCheckEnv { container_ = std::move(container); } + absl::Span> type_providers() const { + return type_providers_; + } + + void AddTypeProvider(std::unique_ptr provider) { + type_providers_.push_back(std::move(provider)); + } + const absl::flat_hash_map& variables() const { return variables_; } @@ -132,16 +150,34 @@ class TypeCheckEnv { absl::Nullable LookupFunction( absl::string_view name) const; + absl::StatusOr> LookupTypeName( + TypeFactory& type_factory, absl::string_view name) const; + + absl::StatusOr> LookupStructField( + TypeFactory& type_factory, absl::string_view type_name, + absl::string_view field_name) const; + + absl::StatusOr> LookupTypeConstant( + TypeFactory& type_factory, absl::Nonnull arena, + absl::string_view type_name) const; + TypeCheckEnv MakeExtendedEnvironment() const { return TypeCheckEnv(this); } VariableScope MakeVariableScope() const { return VariableScope(*this); } private: + absl::StatusOr> LookupEnumConstant( + TypeFactory& type_factory, absl::string_view type, + absl::string_view value) const; + std::string container_; absl::Nullable parent_; // Maps fully qualified names to declarations. absl::flat_hash_map variables_; absl::flat_hash_map functions_; + + // Type providers for custom types. + std::vector> type_providers_; }; } // namespace cel::checker_internal diff --git a/checker/internal/type_checker_impl.cc b/checker/internal/type_checker_impl.cc index 0493f24db..196192857 100644 --- a/checker/internal/type_checker_impl.cc +++ b/checker/internal/type_checker_impl.cc @@ -24,6 +24,7 @@ #include "absl/base/no_destructor.h" #include "absl/base/nullability.h" #include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" @@ -48,15 +49,31 @@ #include "common/constant.h" #include "common/decl.h" #include "common/expr.h" +#include "common/memory.h" #include "common/source.h" #include "common/type.h" +#include "common/type_factory.h" #include "common/type_kind.h" +#include "extensions/protobuf/memory_manager.h" #include "internal/status_macros.h" #include "google/protobuf/arena.h" namespace cel::checker_internal { namespace { +class TrivialTypeFactory : public TypeFactory { + public: + explicit TrivialTypeFactory(absl::Nonnull arena) + : arena_(arena) {} + + MemoryManagerRef GetMemoryManager() const override { + return extensions::ProtoMemoryManagerRef(arena_); + } + + private: + absl::Nonnull arena_; +}; + using cel::ast_internal::AstImpl; using AstType = cel::ast_internal::Type; @@ -211,6 +228,8 @@ absl::StatusOr FlattenType(const Type& type) { return AstType(ast_internal::DynamicType()); case TypeKind::kType: return FlattenTypeType(type.GetType()); + case TypeKind::kAny: + return AstType(ast_internal::WellKnownType::kAny); default: return absl::InternalError( absl::StrCat("Unsupported type: ", type.DebugString())); @@ -229,7 +248,7 @@ class ResolveVisitor : public AstVisitorBase { const TypeCheckEnv& env, const AstImpl& ast, TypeInferenceContext& inference_context, std::vector& issues, - absl::Nonnull arena) + absl::Nonnull arena, TypeFactory& type_factory) : container_(container), namespace_generator_(std::move(namespace_generator)), env_(&env), @@ -238,6 +257,7 @@ class ResolveVisitor : public AstVisitorBase { ast_(&ast), root_scope_(env.MakeVariableScope()), arena_(arena), + type_factory_(&type_factory), current_scope_(&root_scope_) {} void PreVisitExpr(const Expr& expr) override { expr_stack_.push_back(&expr); } @@ -276,11 +296,7 @@ class ResolveVisitor : public AstVisitorBase { void PostVisitCall(const Expr& expr, const CallExpr& call) override; void PostVisitStruct(const Expr& expr, - const StructExpr& create_struct) override { - // TODO: For now, skip resolving create struct type. To allow - // checking other behaviors. The C++ runtime should still resolve the type - // based on the runtime configuration. - } + const StructExpr& create_struct) override; // Accessors for resolved values. const absl::flat_hash_map& functions() @@ -293,6 +309,10 @@ class ResolveVisitor : public AstVisitorBase { return attributes_; } + const absl::flat_hash_map& struct_types() const { + return struct_types_; + } + const absl::flat_hash_map& types() const { return types_; } const absl::Status& status() const { return status_; } @@ -325,6 +345,10 @@ class ResolveVisitor : public AstVisitorBase { absl::string_view function_name, int arg_count, bool is_receiver); + // Resolves the function call shape (i.e. the number of arguments and call + // style) for the given function call. + absl::Nullable LookupIdentifier(absl::string_view name); + // Resolves the applicable function overloads for the given function call. // // If found, assigns a new function decl with the resolved overloads. @@ -332,6 +356,9 @@ class ResolveVisitor : public AstVisitorBase { int arg_count, bool is_receiver, bool is_namespaced); + void ResolveSelectOperation(const Expr& expr, absl::string_view field, + const Expr& operand); + void ReportMissingReference(const Expr& expr, absl::string_view name) { issues_->push_back(TypeCheckIssue::CreateError( ComputeSourceLocation(*ast_, expr.id()), @@ -339,6 +366,48 @@ class ResolveVisitor : public AstVisitorBase { container_, "')"))); } + void ReportUndefinedField(int64_t expr_id, absl::string_view field_name, + absl::string_view struct_name) { + issues_->push_back(TypeCheckIssue::CreateError( + ComputeSourceLocation(*ast_, expr_id), + absl::StrCat("undefined field '", field_name, "' not found in struct '", + struct_name, "'"))); + } + + absl::Status CheckFieldAssignments(const Expr& expr, + const StructExpr& create_struct, + Type struct_type, + absl::string_view resolved_name) { + for (const auto& field : create_struct.fields()) { + const Expr* value = &field.value(); + Type value_type = GetTypeOrDyn(value); + + // Lookup message type by name to support WellKnownType creation. + CEL_ASSIGN_OR_RETURN( + absl::optional field_info, + env_->LookupStructField(*type_factory_, resolved_name, field.name())); + if (!field_info.has_value()) { + ReportUndefinedField(field.id(), field.name(), resolved_name); + continue; + } + Type field_type = field_info->GetType(); + if (field.optional()) { + field_type = OptionalType(arena_, field_type); + } + if (!inference_context_->IsAssignable(value_type, field_type)) { + issues_->push_back(TypeCheckIssue::CreateError( + ComputeSourceLocation(*ast_, field.id()), + absl::StrCat("expected type of field '", field_info->name(), + "' is '", field_type.DebugString(), + "' but provided type is '", value_type.DebugString(), + "'"))); + continue; + } + } + + return absl::OkStatus(); + } + // TODO: This should switch to a failing check once all core // features are supported. For now, we allow dyn for implementing the // typechecker behaviors in isolation. @@ -355,12 +424,17 @@ class ResolveVisitor : public AstVisitorBase { absl::Nonnull ast_; VariableScope root_scope_; absl::Nonnull arena_; + absl::Nonnull type_factory_; // state tracking for the traversal. const VariableScope* current_scope_; std::vector expr_stack_; absl::flat_hash_map> maybe_namespaced_functions_; + // Select operations that need to be resolved outside of the traversal. + // These are handled separately to disambiguate between namespaces and field + // accesses + absl::flat_hash_set deferred_select_operations_; absl::Status status_; std::vector> comprehension_vars_; std::vector comprehension_scopes_; @@ -368,6 +442,7 @@ class ResolveVisitor : public AstVisitorBase { // References that were resolved and may require AST rewrites. absl::flat_hash_map functions_; absl::flat_hash_map attributes_; + absl::flat_hash_map struct_types_; absl::flat_hash_map types_; }; @@ -403,7 +478,11 @@ void ResolveVisitor::PostVisitIdent(const Expr& expr, const IdentExpr& ident) { } qualifiers.push_back(parent->select_expr().field()); + deferred_select_operations_.insert(parent); root_candidate = parent; + if (parent->select_expr().test_only()) { + break; + } } if (receiver_call == nullptr) { @@ -562,6 +641,51 @@ void ResolveVisitor::PostVisitList(const Expr& expr, const ListExpr& list) { types_[&expr] = inference_context_->InstantiateTypeParams(FreeListType()); } +void ResolveVisitor::PostVisitStruct(const Expr& expr, + const StructExpr& create_struct) { + absl::Status status; + std::string resolved_name; + Type resolved_type; + namespace_generator_.GenerateCandidates( + create_struct.name(), [&](const absl::string_view name) { + auto type = env_->LookupTypeName(*type_factory_, name); + if (!type.ok()) { + status.Update(type.status()); + return false; + } else if (type->has_value()) { + resolved_name = name; + resolved_type = **type; + return false; + } + return true; + }); + + if (!status.ok()) { + status_.Update(status); + return; + } + + if (resolved_name.empty()) { + ReportMissingReference(expr, create_struct.name()); + return; + } + + if (resolved_type.kind() != TypeKind::kStruct && + !IsWellKnownMessageType(resolved_name)) { + issues_->push_back(TypeCheckIssue::CreateError( + ComputeSourceLocation(*ast_, expr.id()), + absl::StrCat("type '", resolved_name, + "' does not support message creation"))); + return; + } + + types_[&expr] = resolved_type; + struct_types_[&expr] = resolved_name; + + status_.Update( + CheckFieldAssignments(expr, create_struct, resolved_type, resolved_name)); +} + void ResolveVisitor::PostVisitCall(const Expr& expr, const CallExpr& call) { // Handle disambiguation of namespaced functions. if (auto iter = maybe_namespaced_functions_.find(&expr); @@ -706,6 +830,13 @@ void ResolveVisitor::PostVisitComprehensionSubexpression( } } +void ResolveVisitor::PostVisitSelect(const Expr& expr, + const SelectExpr& select) { + if (!deferred_select_operations_.contains(&expr)) { + ResolveSelectOperation(expr, select.field(), select.operand()); + } +} + const FunctionDecl* ResolveVisitor::ResolveFunctionCallShape( const Expr& expr, absl::string_view function_name, int arg_count, bool is_receiver) { @@ -773,12 +904,39 @@ void ResolveVisitor::ResolveFunctionOverloads(const Expr& expr, types_[&expr] = resolution->result_type; } +absl::Nullable ResolveVisitor::LookupIdentifier( + absl::string_view name) { + if (const VariableDecl* decl = current_scope_->LookupVariable(name); + decl != nullptr) { + return decl; + } + absl::StatusOr> constant = + env_->LookupTypeConstant(*type_factory_, arena_, name); + + if (!constant.ok()) { + status_.Update(constant.status()); + return nullptr; + } + + if (constant->has_value()) { + if (constant->value().type().kind() == TypeKind::kEnum) { + // Treat enum constant as just an int after resolving the reference. + // This preserves existing behavior in the other type checkers. + constant->value().set_type(IntType()); + } + return google::protobuf::Arena::Create( + arena_, std::move(constant).value().value()); + } + + return nullptr; +} + void ResolveVisitor::ResolveSimpleIdentifier(const Expr& expr, absl::string_view name) { const VariableDecl* decl = nullptr; namespace_generator_.GenerateCandidates( name, [&decl, this](absl::string_view candidate) { - decl = current_scope_->LookupVariable(candidate); + decl = LookupIdentifier(candidate); // continue searching. return decl == nullptr; }); @@ -804,7 +962,7 @@ void ResolveVisitor::ResolveQualifiedIdentifier( namespace_generator_.GenerateCandidates( qualifiers, [&decl, &segment_index_out, this](absl::string_view candidate, int segment_index) { - decl = current_scope_->LookupVariable(candidate); + decl = LookupIdentifier(candidate); if (decl != nullptr) { segment_index_out = segment_index; return false; @@ -819,16 +977,87 @@ void ResolveVisitor::ResolveQualifiedIdentifier( const int num_select_opts = qualifiers.size() - segment_index_out - 1; const Expr* root = &expr; + std::vector select_opts; + select_opts.reserve(num_select_opts); for (int i = 0; i < num_select_opts; ++i) { + select_opts.push_back(root); root = &root->select_expr().operand(); } attributes_[root] = decl; types_[root] = inference_context_->InstantiateTypeParams(decl->type()); + + // fix-up select operations that were deferred. + for (auto iter = select_opts.rbegin(); iter != select_opts.rend(); ++iter) { + ResolveSelectOperation(**iter, (*iter)->select_expr().field(), + (*iter)->select_expr().operand()); + } } -void ResolveVisitor::PostVisitSelect(const Expr& expr, - const SelectExpr& select) {} +void ResolveVisitor::ResolveSelectOperation(const Expr& expr, + absl::string_view field, + const Expr& operand) { + auto impl = [&](const Type& operand_type) -> absl::optional { + if (operand_type.kind() == TypeKind::kDyn || + operand_type.kind() == TypeKind::kAny) { + return DynType(); + } + + if (operand_type.kind() == TypeKind::kStruct) { + StructType struct_type = operand_type.GetStruct(); + auto field_info = + env_->LookupStructField(*type_factory_, struct_type.name(), field); + if (!field_info.ok()) { + status_.Update(field_info.status()); + return absl::nullopt; + } + if (!field_info->has_value()) { + ReportUndefinedField(expr.id(), field, struct_type.name()); + return absl::nullopt; + } + auto type = field_info->value().GetType(); + if (type.kind() == TypeKind::kEnum) { + // Treat enum as just an int. + return IntType(); + } + return type; + } + + if (operand_type.kind() == TypeKind::kMap) { + MapType map_type = operand_type.GetMap(); + if (inference_context_->IsAssignable(StringType(), map_type.GetKey())) { + return map_type.GetValue(); + } + // else fall though. + } + + issues_->push_back(TypeCheckIssue::CreateError( + ComputeSourceLocation(*ast_, expr.id()), + absl::StrCat("expression of type '", operand_type.DebugString(), + "' cannot be the operand of a select operation"))); + return absl::nullopt; + }; + + const Type& operand_type = GetTypeOrDyn(&operand); + + absl::optional result_type; + // Support short-hand optional chaining. + if (operand_type.IsOptional()) { + auto optional_type = operand_type.GetOptional(); + Type held_type = optional_type.GetParameter(); + result_type = impl(held_type); + } else { + result_type = impl(operand_type); + } + + if (result_type.has_value()) { + if (expr.select_expr().test_only()) { + types_[&expr] = BoolType(); + } else { + types_[&expr] = *result_type; + } + } +} class ResolveRewriter : public AstRewriterBase { public: @@ -847,6 +1076,9 @@ class ResolveRewriter : public AstRewriterBase { const VariableDecl* decl = iter->second; auto& ast_ref = reference_map_[expr.id()]; ast_ref.set_name(decl->name()); + if (decl->has_value()) { + ast_ref.set_value(decl->value()); + } expr.mutable_ident_expr().set_name(decl->name()); rewritten = true; } else if (auto iter = visitor_.functions().find(&expr); @@ -864,6 +1096,14 @@ class ResolveRewriter : public AstRewriterBase { expr.mutable_call_expr().set_target(nullptr); } rewritten = true; + } else if (auto iter = visitor_.struct_types().find(&expr); + iter != visitor_.struct_types().end()) { + auto& ast_ref = reference_map_[expr.id()]; + ast_ref.set_name(iter->second); + if (expr.has_struct_expr()) { + expr.mutable_struct_expr().set_name(iter->second); + } + rewritten = true; } if (auto iter = visitor_.types().find(&expr); @@ -904,8 +1144,10 @@ absl::StatusOr TypeCheckerImpl::Check( NamespaceGenerator::Create(env_.container())); TypeInferenceContext type_inference_context(&type_arena); + TrivialTypeFactory type_factory(&type_arena); ResolveVisitor visitor(env_.container(), std::move(generator), env_, ast_impl, - type_inference_context, issues, &type_arena); + type_inference_context, issues, &type_arena, + type_factory); TraversalOptions opts; opts.use_comprehension_callbacks = true; diff --git a/checker/internal/type_checker_impl_test.cc b/checker/internal/type_checker_impl_test.cc index fd40efe15..71f018db1 100644 --- a/checker/internal/type_checker_impl_test.cc +++ b/checker/internal/type_checker_impl_test.cc @@ -38,10 +38,14 @@ #include "common/ast.h" #include "common/decl.h" #include "common/type.h" +#include "common/type_introspector.h" +#include "extensions/protobuf/type_reflector.h" #include "internal/status_macros.h" #include "internal/testing.h" +#include "proto/test/v1/proto2/test_all_types.pb.h" #include "proto/test/v1/proto3/test_all_types.pb.h" #include "google/protobuf/arena.h" +#include "google/protobuf/message.h" namespace cel { namespace checker_internal { @@ -55,14 +59,16 @@ using ::google::api::expr::test::v1::proto3::TestAllTypes; using ::testing::_; using ::testing::Contains; using ::testing::ElementsAre; +using ::testing::Eq; using ::testing::IsEmpty; using ::testing::Pair; - -using AstType = cel::ast_internal::Type; +using ::testing::Property; using AstType = ast_internal::Type; using Severity = TypeCheckIssue::Severity; +namespace testpb3 = ::google::api::expr::test::v1::proto3; + std::string SevString(Severity severity) { switch (severity) { case Severity::kDeprecated: @@ -147,115 +153,130 @@ MATCHER_P2(IsFunctionReference, fn_name, overloads, "") { return reference.name() == fn_name && got_overload_set == want_overload_set; } -class TypeCheckerImplTest : public ::testing::Test { - public: - TypeCheckerImplTest() = default; - - absl::Status RegisterMinimalBuiltins(TypeCheckEnv& env) { - Type list_of_a = ListType(&arena_, TypeParamType("A")); - - FunctionDecl add_op; - - add_op.set_name("_+_"); - CEL_RETURN_IF_ERROR(add_op.AddOverload( - MakeOverloadDecl("add_int_int", IntType(), IntType(), IntType()))); - CEL_RETURN_IF_ERROR(add_op.AddOverload( - MakeOverloadDecl("add_uint_uint", UintType(), UintType(), UintType()))); - CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( - "add_double_double", DoubleType(), DoubleType(), DoubleType()))); - - CEL_RETURN_IF_ERROR(add_op.AddOverload( - MakeOverloadDecl("add_list", list_of_a, list_of_a, list_of_a))); - - FunctionDecl not_op; - not_op.set_name("!_"); - CEL_RETURN_IF_ERROR(not_op.AddOverload( - MakeOverloadDecl("logical_not", - /*return_type=*/BoolType{}, BoolType{}))); - FunctionDecl not_strictly_false; - not_strictly_false.set_name("@not_strictly_false"); - CEL_RETURN_IF_ERROR(not_strictly_false.AddOverload( - MakeOverloadDecl("not_strictly_false", - /*return_type=*/BoolType{}, DynType{}))); - FunctionDecl mult_op; - mult_op.set_name("_*_"); - CEL_RETURN_IF_ERROR(mult_op.AddOverload( - MakeOverloadDecl("mult_int_int", - /*return_type=*/IntType(), IntType(), IntType()))); - FunctionDecl or_op; - or_op.set_name("_||_"); - CEL_RETURN_IF_ERROR(or_op.AddOverload( - MakeOverloadDecl("logical_or", - /*return_type=*/BoolType{}, BoolType{}, BoolType{}))); - - FunctionDecl and_op; - and_op.set_name("_&&_"); - CEL_RETURN_IF_ERROR(and_op.AddOverload( - MakeOverloadDecl("logical_and", - /*return_type=*/BoolType{}, BoolType{}, BoolType{}))); - - FunctionDecl lt_op; - lt_op.set_name("_<_"); - CEL_RETURN_IF_ERROR(lt_op.AddOverload( - MakeOverloadDecl("lt_int_int", - /*return_type=*/BoolType{}, IntType(), IntType()))); - - FunctionDecl gt_op; - gt_op.set_name("_>_"); - CEL_RETURN_IF_ERROR(gt_op.AddOverload( - MakeOverloadDecl("gt_int_int", - /*return_type=*/BoolType{}, IntType(), IntType()))); - - FunctionDecl eq_op; - eq_op.set_name("_==_"); - CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl( - "equals", - /*return_type=*/BoolType{}, TypeParamType("A"), TypeParamType("A")))); - - FunctionDecl ternary_op; - ternary_op.set_name("_?_:_"); - CEL_RETURN_IF_ERROR(eq_op.AddOverload( - MakeOverloadDecl("conditional", - /*return_type=*/ - TypeParamType("A"), BoolType{}, TypeParamType("A"), - TypeParamType("A")))); - - FunctionDecl to_int; - to_int.set_name("int"); - CEL_RETURN_IF_ERROR(to_int.AddOverload( - MakeOverloadDecl("to_int", - /*return_type=*/IntType(), DynType()))); - - FunctionDecl to_dyn; - to_dyn.set_name("dyn"); - CEL_RETURN_IF_ERROR(to_dyn.AddOverload( - MakeOverloadDecl("to_dyn", - /*return_type=*/DynType(), TypeParamType("A")))); - - env.InsertFunctionIfAbsent(std::move(not_op)); - env.InsertFunctionIfAbsent(std::move(not_strictly_false)); - env.InsertFunctionIfAbsent(std::move(add_op)); - env.InsertFunctionIfAbsent(std::move(mult_op)); - env.InsertFunctionIfAbsent(std::move(or_op)); - env.InsertFunctionIfAbsent(std::move(and_op)); - env.InsertFunctionIfAbsent(std::move(lt_op)); - env.InsertFunctionIfAbsent(std::move(gt_op)); - env.InsertFunctionIfAbsent(std::move(to_int)); - env.InsertFunctionIfAbsent(std::move(eq_op)); - env.InsertFunctionIfAbsent(std::move(ternary_op)); - env.InsertFunctionIfAbsent(std::move(to_dyn)); - - return absl::OkStatus(); - } - - private: - google::protobuf::Arena arena_; -}; +absl::Status RegisterMinimalBuiltins(absl::Nonnull arena, + TypeCheckEnv& env) { + Type list_of_a = ListType(arena, TypeParamType("A")); + + FunctionDecl add_op; + + add_op.set_name("_+_"); + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl("add_int_int", IntType(), IntType(), IntType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl("add_uint_uint", UintType(), UintType(), UintType()))); + CEL_RETURN_IF_ERROR(add_op.AddOverload(MakeOverloadDecl( + "add_double_double", DoubleType(), DoubleType(), DoubleType()))); + + CEL_RETURN_IF_ERROR(add_op.AddOverload( + MakeOverloadDecl("add_list", list_of_a, list_of_a, list_of_a))); + + FunctionDecl not_op; + not_op.set_name("!_"); + CEL_RETURN_IF_ERROR(not_op.AddOverload( + MakeOverloadDecl("logical_not", + /*return_type=*/BoolType{}, BoolType{}))); + FunctionDecl not_strictly_false; + not_strictly_false.set_name("@not_strictly_false"); + CEL_RETURN_IF_ERROR(not_strictly_false.AddOverload( + MakeOverloadDecl("not_strictly_false", + /*return_type=*/BoolType{}, DynType{}))); + FunctionDecl mult_op; + mult_op.set_name("_*_"); + CEL_RETURN_IF_ERROR(mult_op.AddOverload( + MakeOverloadDecl("mult_int_int", + /*return_type=*/IntType(), IntType(), IntType()))); + FunctionDecl or_op; + or_op.set_name("_||_"); + CEL_RETURN_IF_ERROR(or_op.AddOverload( + MakeOverloadDecl("logical_or", + /*return_type=*/BoolType{}, BoolType{}, BoolType{}))); + + FunctionDecl and_op; + and_op.set_name("_&&_"); + CEL_RETURN_IF_ERROR(and_op.AddOverload( + MakeOverloadDecl("logical_and", + /*return_type=*/BoolType{}, BoolType{}, BoolType{}))); + + FunctionDecl lt_op; + lt_op.set_name("_<_"); + CEL_RETURN_IF_ERROR(lt_op.AddOverload( + MakeOverloadDecl("lt_int_int", + /*return_type=*/BoolType{}, IntType(), IntType()))); + + FunctionDecl gt_op; + gt_op.set_name("_>_"); + CEL_RETURN_IF_ERROR(gt_op.AddOverload( + MakeOverloadDecl("gt_int_int", + /*return_type=*/BoolType{}, IntType(), IntType()))); + + FunctionDecl eq_op; + eq_op.set_name("_==_"); + CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl( + "equals", + /*return_type=*/BoolType{}, TypeParamType("A"), TypeParamType("A")))); + + FunctionDecl ternary_op; + ternary_op.set_name("_?_:_"); + CEL_RETURN_IF_ERROR(eq_op.AddOverload(MakeOverloadDecl( + "conditional", + /*return_type=*/ + TypeParamType("A"), BoolType{}, TypeParamType("A"), TypeParamType("A")))); + + FunctionDecl to_int; + to_int.set_name("int"); + CEL_RETURN_IF_ERROR(to_int.AddOverload( + MakeOverloadDecl("to_int", + /*return_type=*/IntType(), DynType()))); + + FunctionDecl to_duration; + to_duration.set_name("duration"); + CEL_RETURN_IF_ERROR(to_duration.AddOverload( + MakeOverloadDecl("to_duration", + /*return_type=*/DurationType(), StringType()))); + + FunctionDecl to_timestamp; + to_timestamp.set_name("timestamp"); + CEL_RETURN_IF_ERROR(to_timestamp.AddOverload( + MakeOverloadDecl("to_timestamp", + /*return_type=*/TimestampType(), IntType()))); + + FunctionDecl to_dyn; + to_dyn.set_name("dyn"); + CEL_RETURN_IF_ERROR(to_dyn.AddOverload( + MakeOverloadDecl("to_dyn", + /*return_type=*/DynType(), TypeParamType("A")))); + + FunctionDecl to_type; + to_type.set_name("type"); + CEL_RETURN_IF_ERROR(to_type.AddOverload( + MakeOverloadDecl("to_type", + /*return_type=*/TypeType(arena, TypeParamType("A")), + TypeParamType("A")))); + + env.InsertFunctionIfAbsent(std::move(not_op)); + env.InsertFunctionIfAbsent(std::move(not_strictly_false)); + env.InsertFunctionIfAbsent(std::move(add_op)); + env.InsertFunctionIfAbsent(std::move(mult_op)); + env.InsertFunctionIfAbsent(std::move(or_op)); + env.InsertFunctionIfAbsent(std::move(and_op)); + env.InsertFunctionIfAbsent(std::move(lt_op)); + env.InsertFunctionIfAbsent(std::move(gt_op)); + env.InsertFunctionIfAbsent(std::move(to_int)); + env.InsertFunctionIfAbsent(std::move(eq_op)); + env.InsertFunctionIfAbsent(std::move(ternary_op)); + env.InsertFunctionIfAbsent(std::move(to_dyn)); + env.InsertFunctionIfAbsent(std::move(to_type)); + env.InsertFunctionIfAbsent(std::move(to_duration)); + env.InsertFunctionIfAbsent(std::move(to_timestamp)); + + return absl::OkStatus(); +} -TEST_F(TypeCheckerImplTest, SmokeTest) { +TEST(TypeCheckerImplTest, SmokeTest) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("1 + 2")); @@ -266,10 +287,11 @@ TEST_F(TypeCheckerImplTest, SmokeTest) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST_F(TypeCheckerImplTest, SimpleIdentsResolved) { +TEST(TypeCheckerImplTest, SimpleIdentsResolved) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); @@ -283,10 +305,11 @@ TEST_F(TypeCheckerImplTest, SimpleIdentsResolved) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST_F(TypeCheckerImplTest, ReportMissingIdentDecl) { +TEST(TypeCheckerImplTest, ReportMissingIdentDecl) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); @@ -301,10 +324,11 @@ TEST_F(TypeCheckerImplTest, ReportMissingIdentDecl) { "undeclared reference to 'y'"))); } -TEST_F(TypeCheckerImplTest, QualifiedIdentsResolved) { +TEST(TypeCheckerImplTest, QualifiedIdentsResolved) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("x.z", IntType())); @@ -318,10 +342,11 @@ TEST_F(TypeCheckerImplTest, QualifiedIdentsResolved) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST_F(TypeCheckerImplTest, ReportMissingQualifiedIdentDecl) { +TEST(TypeCheckerImplTest, ReportMissingQualifiedIdentDecl) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); @@ -336,10 +361,11 @@ TEST_F(TypeCheckerImplTest, ReportMissingQualifiedIdentDecl) { Severity::kError, "undeclared reference to 'y.x'"))); } -TEST_F(TypeCheckerImplTest, ResolveMostQualfiedIdent) { +TEST(TypeCheckerImplTest, ResolveMostQualfiedIdent) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("x.y", MapType())); @@ -354,7 +380,7 @@ TEST_F(TypeCheckerImplTest, ResolveMostQualfiedIdent) { Contains(Pair(_, IsVariableReference("x.y")))); } -TEST_F(TypeCheckerImplTest, MemberFunctionCallResolved) { +TEST(TypeCheckerImplTest, MemberFunctionCallResolved) { TypeCheckEnv env; env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); @@ -377,7 +403,7 @@ TEST_F(TypeCheckerImplTest, MemberFunctionCallResolved) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST_F(TypeCheckerImplTest, MemberFunctionCallNotDeclared) { +TEST(TypeCheckerImplTest, MemberFunctionCallNotDeclared) { TypeCheckEnv env; env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); @@ -394,7 +420,7 @@ TEST_F(TypeCheckerImplTest, MemberFunctionCallNotDeclared) { Severity::kError, "undeclared reference to 'foo'"))); } -TEST_F(TypeCheckerImplTest, FunctionShapeMismatch) { +TEST(TypeCheckerImplTest, FunctionShapeMismatch) { TypeCheckEnv env; // foo(int, int) -> int ASSERT_OK_AND_ASSIGN( @@ -413,7 +439,7 @@ TEST_F(TypeCheckerImplTest, FunctionShapeMismatch) { Severity::kError, "undeclared reference to 'foo'"))); } -TEST_F(TypeCheckerImplTest, NamespaceFunctionCallResolved) { +TEST(TypeCheckerImplTest, NamespaceFunctionCallResolved) { TypeCheckEnv env; // Variables env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); @@ -443,10 +469,40 @@ TEST_F(TypeCheckerImplTest, NamespaceFunctionCallResolved) { EXPECT_FALSE(ast_impl.root_expr().call_expr().has_target()); } -TEST_F(TypeCheckerImplTest, MixedListTypeToDyn) { +TEST(TypeCheckerImplTest, NamespacedFunctionSkipsFieldCheck) { + TypeCheckEnv env; + // Variables + env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); + + // add x.foo as a namespaced function. + FunctionDecl foo; + foo.set_name("x.y.foo"); + ASSERT_THAT( + foo.AddOverload(MakeOverloadDecl("x_y_foo_int", + /*return_type=*/IntType(), IntType())), + IsOk()); + env.InsertFunctionIfAbsent(std::move(foo)); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("x.y.foo(x)")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + EXPECT_TRUE(result.IsValid()); + EXPECT_THAT(result.GetIssues(), IsEmpty()); + + ASSERT_OK_AND_ASSIGN(auto checked_ast, result.ReleaseAst()); + auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + EXPECT_TRUE(ast_impl.root_expr().has_call_expr()) + << absl::StrCat("kind: ", ast_impl.root_expr().kind().index()); + EXPECT_EQ(ast_impl.root_expr().call_expr().function(), "x.y.foo"); + EXPECT_FALSE(ast_impl.root_expr().call_expr().has_target()); +} + +TEST(TypeCheckerImplTest, MixedListTypeToDyn) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[1, 'a']")); @@ -459,10 +515,11 @@ TEST_F(TypeCheckerImplTest, MixedListTypeToDyn) { EXPECT_TRUE(ast_impl.type_map().at(1).list_type().elem_type().has_dyn()); } -TEST_F(TypeCheckerImplTest, FreeListTypeToDyn) { +TEST(TypeCheckerImplTest, FreeListTypeToDyn) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("[]")); @@ -475,10 +532,11 @@ TEST_F(TypeCheckerImplTest, FreeListTypeToDyn) { EXPECT_TRUE(ast_impl.type_map().at(1).list_type().elem_type().has_dyn()); } -TEST_F(TypeCheckerImplTest, FreeMapTypeToDyn) { +TEST(TypeCheckerImplTest, FreeMapTypeToDyn) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{}")); @@ -492,10 +550,11 @@ TEST_F(TypeCheckerImplTest, FreeMapTypeToDyn) { EXPECT_TRUE(ast_impl.type_map().at(1).map_type().value_type().has_dyn()); } -TEST_F(TypeCheckerImplTest, MapTypeWithMixedKeys) { +TEST(TypeCheckerImplTest, MapTypeWithMixedKeys) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'a': 1, 2: 3}")); @@ -510,10 +569,11 @@ TEST_F(TypeCheckerImplTest, MapTypeWithMixedKeys) { ast_internal::PrimitiveType::kInt64); } -TEST_F(TypeCheckerImplTest, MapTypeUnsupportedKeyWarns) { +TEST(TypeCheckerImplTest, MapTypeUnsupportedKeyWarns) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{{}: 'a'}")); @@ -526,10 +586,11 @@ TEST_F(TypeCheckerImplTest, MapTypeUnsupportedKeyWarns) { "unsupported map key type:"))); } -TEST_F(TypeCheckerImplTest, MapTypeWithMixedValues) { +TEST(TypeCheckerImplTest, MapTypeWithMixedValues) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("{'a': 1, 'b': '2'}")); @@ -544,10 +605,11 @@ TEST_F(TypeCheckerImplTest, MapTypeWithMixedValues) { EXPECT_TRUE(ast_impl.type_map().at(1).map_type().value_type().has_dyn()); } -TEST_F(TypeCheckerImplTest, ComprehensionVariablesResolved) { +TEST(TypeCheckerImplTest, ComprehensionVariablesResolved) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, @@ -559,10 +621,11 @@ TEST_F(TypeCheckerImplTest, ComprehensionVariablesResolved) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST_F(TypeCheckerImplTest, MapComprehensionVariablesResolved) { +TEST(TypeCheckerImplTest, MapComprehensionVariablesResolved) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, @@ -574,10 +637,11 @@ TEST_F(TypeCheckerImplTest, MapComprehensionVariablesResolved) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST_F(TypeCheckerImplTest, NestedComprehensions) { +TEST(TypeCheckerImplTest, NestedComprehensions) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN( @@ -590,10 +654,11 @@ TEST_F(TypeCheckerImplTest, NestedComprehensions) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST_F(TypeCheckerImplTest, ComprehensionVarsFollowNamespacePriorityRules) { +TEST(TypeCheckerImplTest, ComprehensionVarsFollowNamespacePriorityRules) { TypeCheckEnv env; env.set_container("com"); - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); // Namespace resolution still applies, compre var doesn't shadow com.x env.InsertVariableIfAbsent(MakeVariableDecl("com.x", IntType())); @@ -612,9 +677,10 @@ TEST_F(TypeCheckerImplTest, ComprehensionVarsFollowNamespacePriorityRules) { Contains(Pair(_, IsVariableReference("com.x")))); } -TEST_F(TypeCheckerImplTest, ComprehensionVarsFollowQualifiedIdentPriority) { +TEST(TypeCheckerImplTest, ComprehensionVarsFollowQualifiedIdentPriority) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); // Namespace resolution still applies, compre var doesn't shadow x.y env.InsertVariableIfAbsent(MakeVariableDecl("x.y", IntType())); @@ -816,7 +882,7 @@ INSTANTIATE_TEST_SUITE_P( .expected_type = AstType(ast_internal::MessageType( "google.api.expr.test.v1.proto3.TestAllTypes"))})); -TEST_F(TypeCheckerImplTest, NullLiteral) { +TEST(TypeCheckerImplTest, NullLiteral) { TypeCheckEnv env; TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("null")); @@ -828,9 +894,10 @@ TEST_F(TypeCheckerImplTest, NullLiteral) { EXPECT_TRUE(ast_impl.type_map()[1].has_null()); } -TEST_F(TypeCheckerImplTest, ComprehensionUnsupportedRange) { +TEST(TypeCheckerImplTest, ComprehensionUnsupportedRange) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("y", IntType())); @@ -846,9 +913,10 @@ TEST_F(TypeCheckerImplTest, ComprehensionUnsupportedRange) { "the range of a comprehension"))); } -TEST_F(TypeCheckerImplTest, ComprehensionDynRange) { +TEST(TypeCheckerImplTest, ComprehensionDynRange) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("range", DynType())); @@ -861,9 +929,10 @@ TEST_F(TypeCheckerImplTest, ComprehensionDynRange) { EXPECT_THAT(result.GetIssues(), IsEmpty()); } -TEST_F(TypeCheckerImplTest, BasicOvlResolution) { +TEST(TypeCheckerImplTest, BasicOvlResolution) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); @@ -884,9 +953,10 @@ TEST_F(TypeCheckerImplTest, BasicOvlResolution) { "_+_", std::vector{"add_double_double"})); } -TEST_F(TypeCheckerImplTest, OvlResolutionMultipleOverloads) { +TEST(TypeCheckerImplTest, OvlResolutionMultipleOverloads) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); @@ -908,9 +978,10 @@ TEST_F(TypeCheckerImplTest, OvlResolutionMultipleOverloads) { "add_list", "add_uint_uint"})); } -TEST_F(TypeCheckerImplTest, BasicFunctionResultTypeResolution) { +TEST(TypeCheckerImplTest, BasicFunctionResultTypeResolution) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", DoubleType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", DoubleType())); @@ -938,9 +1009,10 @@ TEST_F(TypeCheckerImplTest, BasicFunctionResultTypeResolution) { ast_internal::PrimitiveType::kDouble); } -TEST_F(TypeCheckerImplTest, BasicOvlResolutionNoMatch) { +TEST(TypeCheckerImplTest, BasicOvlResolutionNoMatch) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", StringType())); @@ -957,9 +1029,10 @@ TEST_F(TypeCheckerImplTest, BasicOvlResolutionNoMatch) { " applied to (int, string)"))); } -TEST_F(TypeCheckerImplTest, ParmeterizedOvlResolutionMatch) { +TEST(TypeCheckerImplTest, ParmeterizedOvlResolutionMatch) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertVariableIfAbsent(MakeVariableDecl("x", IntType())); env.InsertVariableIfAbsent(MakeVariableDecl("y", StringType())); @@ -971,9 +1044,10 @@ TEST_F(TypeCheckerImplTest, ParmeterizedOvlResolutionMatch) { EXPECT_TRUE(result.IsValid()); } -TEST_F(TypeCheckerImplTest, AliasedTypeVarSameType) { +TEST(TypeCheckerImplTest, AliasedTypeVarSameType) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, @@ -987,9 +1061,11 @@ TEST_F(TypeCheckerImplTest, AliasedTypeVarSameType) { Severity::kError, "no matching overload for '_==_' applied to"))); } -TEST_F(TypeCheckerImplTest, TypeVarRange) { +TEST(TypeCheckerImplTest, TypeVarRange) { TypeCheckEnv env; - ASSERT_THAT(RegisterMinimalBuiltins(env), IsOk()); + google::protobuf::Arena arena; + + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); env.InsertFunctionIfAbsent(MakeIdentFunction()); TypeCheckerImpl impl(std::move(env)); ASSERT_OK_AND_ASSIGN(auto ast, @@ -999,6 +1075,716 @@ TEST_F(TypeCheckerImplTest, TypeVarRange) { EXPECT_TRUE(result.IsValid()) << absl::StrJoin(result.GetIssues(), "\n"); } +TEST(TypeCheckerImplTest, WellKnownTypeCreation) { + TypeCheckEnv env; + env.AddTypeProvider(std::make_unique()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN( + auto ast, MakeTestParsedAst("google.protobuf.Int32Value{value: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + EXPECT_THAT(ast_impl.type_map(), + Contains(Pair(ast_impl.root_expr().id(), + Eq(AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64)))))); + EXPECT_THAT(ast_impl.reference_map(), + Contains(Pair(ast_impl.root_expr().id(), + Property(&ast_internal::Reference::name, + "google.protobuf.Int32Value")))); +} + +TEST(TypeCheckerImplTest, TypeInferredFromStructCreation) { + TypeCheckEnv env; + env.AddTypeProvider(std::make_unique()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("google.protobuf.Struct{fields: {}}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + int64_t map_expr_id = + ast_impl.root_expr().struct_expr().fields().at(0).value().id(); + ASSERT_NE(map_expr_id, 0); + EXPECT_THAT( + ast_impl.type_map(), + Contains(Pair( + map_expr_id, + Eq(AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique(ast_internal::DynamicType()))))))); +} + +TEST(TypeCheckerImplTest, ContainerLookupForMessageCreation) { + TypeCheckEnv env; + env.set_container("google.protobuf"); + env.AddTypeProvider(std::make_unique()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst("Int32Value{value: 10}")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + EXPECT_THAT(ast_impl.type_map(), + Contains(Pair(ast_impl.root_expr().id(), + Eq(AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64)))))); + EXPECT_THAT(ast_impl.reference_map(), + Contains(Pair(ast_impl.root_expr().id(), + Property(&ast_internal::Reference::name, + "google.protobuf.Int32Value")))); +} + +TEST(TypeCheckerImplTest, EnumValueCopiedToReferenceMap) { + TypeCheckEnv env; + env.set_container("google.api.expr.test.v1.proto3"); + env.AddTypeProvider(std::make_unique()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, + MakeTestParsedAst("TestAllTypes.NestedEnum.BAZ")); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + auto ref_iter = ast_impl.reference_map().find(ast_impl.root_expr().id()); + ASSERT_NE(ref_iter, ast_impl.reference_map().end()); + EXPECT_EQ(ref_iter->second.name(), + "google.api.expr.test.v1.proto3.TestAllTypes.NestedEnum.BAZ"); + EXPECT_EQ(ref_iter->second.value().int_value(), 2); +} + +struct CheckedExprTestCase { + std::string expr; + ast_internal::Type expected_result_type; + std::string error_substring; +}; + +class WktCreationTest : public testing::TestWithParam {}; + +TEST_P(WktCreationTest, MessageCreation) { + const CheckedExprTestCase& test_case = GetParam(); + TypeCheckEnv env; + env.AddTypeProvider(std::make_unique()); + env.set_container("google.protobuf"); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + EXPECT_THAT(ast_impl.type_map(), + Contains(Pair(ast_impl.root_expr().id(), + Eq(test_case.expected_result_type)))); +} + +INSTANTIATE_TEST_SUITE_P( + WellKnownTypes, WktCreationTest, + ::testing::Values( + CheckedExprTestCase{ + .expr = "google.protobuf.Int32Value{value: 10}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = ".google.protobuf.Int32Value{value: 10}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = "Int32Value{value: 10}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = "google.protobuf.Int32Value{value: '10'}", + .expected_result_type = AstType(), + .error_substring = "expected type of field 'value' is 'int' but " + "provided type is 'string'"}, + CheckedExprTestCase{ + .expr = "google.protobuf.Int32Value{not_a_field: '10'}", + .expected_result_type = AstType(), + .error_substring = "undefined field 'not_a_field' not found in " + "struct 'google.protobuf.Int32Value'"}, + CheckedExprTestCase{ + .expr = "NotAType{not_a_field: '10'}", + .expected_result_type = AstType(), + .error_substring = + "undeclared reference to 'NotAType' (in container " + "'google.protobuf')"}, + CheckedExprTestCase{ + .expr = ".protobuf.Int32Value{value: 10}", + .expected_result_type = AstType(), + .error_substring = + "undeclared reference to '.protobuf.Int32Value' (in container " + "'google.protobuf')"}, + CheckedExprTestCase{ + .expr = "Int32Value{value: 10}.value", + .expected_result_type = AstType(), + .error_substring = + "expression of type 'google.protobuf.Int64Value' cannot be the " + "operand of a select operation"}, + CheckedExprTestCase{ + .expr = "Int64Value{value: 10}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = "BoolValue{value: true}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kBool)), + }, + CheckedExprTestCase{ + .expr = "UInt64Value{value: 10u}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kUint64)), + }, + CheckedExprTestCase{ + .expr = "UInt32Value{value: 10u}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kUint64)), + }, + CheckedExprTestCase{ + .expr = "FloatValue{value: 1.25}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kDouble)), + }, + CheckedExprTestCase{ + .expr = "DoubleValue{value: 1.25}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kDouble)), + }, + CheckedExprTestCase{ + .expr = "StringValue{value: 'test'}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kString)), + }, + CheckedExprTestCase{ + .expr = "BytesValue{value: b'test'}", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kBytes)), + }, + CheckedExprTestCase{ + .expr = "Duration{seconds: 10, nanos: 11}", + .expected_result_type = + AstType(ast_internal::WellKnownType::kDuration), + }, + CheckedExprTestCase{ + .expr = "Timestamp{seconds: 10, nanos: 11}", + .expected_result_type = + AstType(ast_internal::WellKnownType::kTimestamp), + }, + CheckedExprTestCase{ + .expr = "Struct{fields: {'key': 'value'}}", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique(ast_internal::DynamicType()))), + }, + CheckedExprTestCase{ + .expr = "ListValue{values: [1, 2, 3]}", + .expected_result_type = AstType(ast_internal::ListType( + std::make_unique(ast_internal::DynamicType()))), + }, + CheckedExprTestCase{ + .expr = R"cel( + Any{ + type_url:'type.googleapis.com/google.protobuf.Int32Value', + value: b'' + })cel", + .expected_result_type = AstType(ast_internal::WellKnownType::kAny), + })); + +class GenericMessagesTest : public testing::TestWithParam { +}; + +TEST_P(GenericMessagesTest, TypeChecksProto3) { + const CheckedExprTestCase& test_case = GetParam(); + google::protobuf::Arena arena; + + TypeCheckEnv env; + env.AddTypeProvider(std::make_unique()); + env.set_container("google.api.expr.test.v1.proto3"); + google::protobuf::LinkMessageReflection(); + + ASSERT_TRUE(env.InsertVariableIfAbsent(MakeVariableDecl( + "test_msg", MessageType(testpb3::TestAllTypes::descriptor())))); + ASSERT_THAT(RegisterMinimalBuiltins(&arena, env), IsOk()); + + TypeCheckerImpl impl(std::move(env)); + ASSERT_OK_AND_ASSIGN(auto ast, MakeTestParsedAst(test_case.expr)); + ASSERT_OK_AND_ASSIGN(ValidationResult result, impl.Check(std::move(ast))); + + if (!test_case.error_substring.empty()) { + EXPECT_THAT(result.GetIssues(), + Contains(IsIssueWithSubstring(Severity::kError, + test_case.error_substring))); + return; + } + + ASSERT_TRUE(result.IsValid()) + << absl::StrJoin(result.GetIssues(), "\n", + [](std::string* out, const TypeCheckIssue& issue) { + absl::StrAppend(out, issue.message()); + }); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr checked_ast, result.ReleaseAst()); + + const auto& ast_impl = AstImpl::CastFromPublicAst(*checked_ast); + EXPECT_THAT(ast_impl.type_map(), + Contains(Pair(ast_impl.root_expr().id(), + Eq(test_case.expected_result_type)))); +} + +INSTANTIATE_TEST_SUITE_P( + TestAllTypesCreation, GenericMessagesTest, + ::testing::Values( + CheckedExprTestCase{ + .expr = "TestAllTypes{not_a_field: 10}", + .expected_result_type = AstType(), + .error_substring = + "undefined field 'not_a_field' not found in " + "struct 'google.api.expr.test.v1.proto3.TestAllTypes'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int64: 10}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int64: 'string'}", + .expected_result_type = AstType(), + .error_substring = + "expected type of field 'single_int64' is 'int' but " + "provided type is 'string'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int32: 10}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_uint64: 10u}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_uint32: 10u}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_sint64: 10}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_sint32: 10}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_fixed64: 10u}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_fixed32: 10u}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_sfixed64: 10}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_sfixed32: 10}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_double: 1.25}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_float: 1.25}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_string: 'string'}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_bool: true}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_bytes: b'string'}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + // Well-known + CheckedExprTestCase{ + .expr = "TestAllTypes{single_any: TestAllTypes{single_int64: 10}}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_any: 1}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_any: 'string'}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_any: ['string']}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_duration: duration('1s')}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_timestamp: timestamp(0)}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_struct: {}}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_struct: {'key': 'value'}}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_struct: {1: 2}}", + .expected_result_type = AstType(), + .error_substring = "expected type of field 'single_struct' is " + "'map' but " + "provided type is 'map'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{list_value: [1, 2, 3]}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{list_value: []}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{list_value: 1}", + .expected_result_type = AstType(), + .error_substring = + "expected type of field 'list_value' is 'list' but " + "provided type is 'int'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int64_wrapper: 1}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_int64_wrapper: null}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: null}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: 1.0}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: 'string'}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: {'string': 'string'}}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_value: ['string']}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{repeated_int64: [1, 2, 3]}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{repeated_int64: ['string']}", + .expected_result_type = AstType(), + .error_substring = + "expected type of field 'repeated_int64' is 'list'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{map_string_int64: ['string']}", + .expected_result_type = AstType(), + .error_substring = "expected type of field 'map_string_int64' is " + "'map'"}, + CheckedExprTestCase{ + .expr = "TestAllTypes{map_string_int64: {'string': 1}}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes{single_nested_enum: 1}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = + "TestAllTypes{single_nested_enum: TestAllTypes.NestedEnum.BAR}", + .expected_result_type = AstType(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes")), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes.NestedEnum.BAR", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes", + .expected_result_type = + AstType(std::make_unique(ast_internal::MessageType( + "google.api.expr.test.v1.proto3.TestAllTypes"))), + }, + CheckedExprTestCase{ + .expr = "TestAllTypes == type(TestAllTypes{})", + .expected_result_type = AstType(ast_internal::PrimitiveType::kBool), + })); + +INSTANTIATE_TEST_SUITE_P( + TestAllTypesFieldSelection, GenericMessagesTest, + ::testing::Values( + CheckedExprTestCase{ + .expr = "test_msg.not_a_field", + .expected_result_type = AstType(), + .error_substring = + "undefined field 'not_a_field' not found in " + "struct 'google.api.expr.test.v1.proto3.TestAllTypes'"}, + CheckedExprTestCase{ + .expr = "test_msg.single_int64", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_nested_enum", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_nested_enum == 1", + .expected_result_type = AstType(ast_internal::PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = + "test_msg.single_nested_enum == TestAllTypes.NestedEnum.BAR", + .expected_result_type = AstType(ast_internal::PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "has(test_msg.not_a_field)", + .expected_result_type = AstType(), + .error_substring = + "undefined field 'not_a_field' not found in " + "struct 'google.api.expr.test.v1.proto3.TestAllTypes'"}, + CheckedExprTestCase{ + .expr = "has(test_msg.single_int64)", + .expected_result_type = AstType(ast_internal::PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_int32", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_uint64", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kUint64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_uint32", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kUint64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_sint64", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_sint32", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_fixed64", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kUint64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_fixed32", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kUint64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_sfixed64", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_sfixed32", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_float", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kDouble), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_double", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kDouble), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_string", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kString), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_bool", + .expected_result_type = AstType(ast_internal::PrimitiveType::kBool), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_bytes", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kBytes), + }, + // Basic tests for containers. This is covered in more detail in + // conformance tests and the type provider implementation. + CheckedExprTestCase{ + .expr = "test_msg.repeated_int32", + .expected_result_type = + AstType(ast_internal::ListType(std::make_unique( + ast_internal::PrimitiveType::kInt64))), + }, + CheckedExprTestCase{ + .expr = "test_msg.repeated_string", + .expected_result_type = + AstType(ast_internal::ListType(std::make_unique( + ast_internal::PrimitiveType::kString))), + }, + CheckedExprTestCase{ + .expr = "test_msg.map_bool_bool", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kBool), + std::make_unique(ast_internal::PrimitiveType::kBool))), + }, + CheckedExprTestCase{ + .expr = "test_msg.map_bool_bool.field_like_key", + .expected_result_type = AstType(), + .error_substring = + "expression of type 'map' cannot be the operand" + " of a select operation", + }, + CheckedExprTestCase{ + .expr = "test_msg.map_string_int64", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique( + ast_internal::PrimitiveType::kInt64))), + }, + CheckedExprTestCase{ + .expr = "test_msg.map_string_int64.field_like_key", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + }, + // Well-known + CheckedExprTestCase{ + .expr = "test_msg.single_duration", + .expected_result_type = + AstType(ast_internal::WellKnownType::kDuration), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_timestamp", + .expected_result_type = + AstType(ast_internal::WellKnownType::kTimestamp), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_any", + .expected_result_type = AstType(ast_internal::WellKnownType::kAny), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_int64_wrapper", + .expected_result_type = AstType(ast_internal::PrimitiveTypeWrapper( + ast_internal::PrimitiveType::kInt64)), + }, + CheckedExprTestCase{ + .expr = "test_msg.single_struct", + .expected_result_type = AstType(ast_internal::MapType( + std::make_unique(ast_internal::PrimitiveType::kString), + std::make_unique(ast_internal::DynamicType()))), + }, + CheckedExprTestCase{ + .expr = "test_msg.list_value", + .expected_result_type = AstType(ast_internal::ListType( + std::make_unique(ast_internal::DynamicType()))), + }, + CheckedExprTestCase{ + .expr = "test_msg.list_value", + .expected_result_type = AstType(ast_internal::ListType( + std::make_unique(ast_internal::DynamicType()))), + }, + // Basic tests for nested messages. + CheckedExprTestCase{ + .expr = "NestedTestAllTypes{}.child.child.payload.single_int64", + .expected_result_type = + AstType(ast_internal::PrimitiveType::kInt64), + } + + )); + } // namespace } // namespace checker_internal } // namespace cel diff --git a/checker/internal/type_inference_context.cc b/checker/internal/type_inference_context.cc index 4c7b1fa73..8f384d289 100644 --- a/checker/internal/type_inference_context.cc +++ b/checker/internal/type_inference_context.cc @@ -229,12 +229,12 @@ Type TypeInferenceContext::InstantiateTypeParams( } bool TypeInferenceContext::IsAssignable(const Type& from, const Type& to) { - // Simple assignablility check assuming parameters are correctly bound. - // TODO: handle resolving type parameter substitution. - if (IsWildCardType(from) || IsWildCardType(to)) { - return true; + SubstitutionMap prospective_substitutions; + bool result = IsAssignableInternal(from, to, prospective_substitutions); + if (result) { + UpdateTypeParameterBindings(prospective_substitutions); } - return common_internal::TypeIsAssignable(from, to); + return result; } bool TypeInferenceContext::IsAssignableInternal( @@ -289,6 +289,10 @@ bool TypeInferenceContext::IsAssignableInternal( return true; } + if (to_subs.kind() == TypeKind::kEnum && from_subs.kind() == TypeKind::kInt) { + return true; + } + if (IsWildCardType(from_subs) || IsWildCardType(to_subs)) { return true; } diff --git a/checker/internal/type_inference_context_test.cc b/checker/internal/type_inference_context_test.cc index f1a970de7..36d21b65a 100644 --- a/checker/internal/type_inference_context_test.cc +++ b/checker/internal/type_inference_context_test.cc @@ -226,15 +226,15 @@ TEST(TypeInferenceContextTest, WrapperTypeAssignable) { google::protobuf::Arena arena; TypeInferenceContext context(&arena); - EXPECT_TRUE(context.IsAssignable(StringWrapperType(), StringType())); - EXPECT_TRUE(context.IsAssignable(StringWrapperType(), NullType())); + EXPECT_TRUE(context.IsAssignable(StringType(), StringWrapperType())); + EXPECT_TRUE(context.IsAssignable(NullType(), StringWrapperType())); } TEST(TypeInferenceContextTest, MismatchedTypeNotAssignable) { google::protobuf::Arena arena; TypeInferenceContext context(&arena); - EXPECT_FALSE(context.IsAssignable(StringWrapperType(), IntType())); + EXPECT_FALSE(context.IsAssignable(IntType(), StringWrapperType())); } TEST(TypeInferenceContextTest, OverloadResolution) { diff --git a/checker/type_checker_builder.cc b/checker/type_checker_builder.cc index aa7a7e735..b0a84cb1d 100644 --- a/checker/type_checker_builder.cc +++ b/checker/type_checker_builder.cc @@ -25,10 +25,16 @@ #include "checker/internal/type_checker_impl.h" #include "checker/type_checker.h" #include "common/decl.h" +#include "common/type_introspector.h" namespace cel { absl::StatusOr> TypeCheckerBuilder::Build() && { + if (env_.type_providers().empty() && env_.parent() == nullptr) { + // Add a default type provider if none have been added to cover + // WellKnownTypes. + env_.AddTypeProvider(std::make_unique()); + } return std::make_unique(std::move(env_)); } @@ -61,6 +67,11 @@ absl::Status TypeCheckerBuilder::AddFunction(const FunctionDecl& decl) { return absl::OkStatus(); } +void TypeCheckerBuilder::AddTypeProvider( + std::unique_ptr provider) { + env_.AddTypeProvider(std::move(provider)); +} + void TypeCheckerBuilder::set_container(absl::string_view container) { env_.set_container(std::string(container)); } diff --git a/checker/type_checker_builder.h b/checker/type_checker_builder.h index b95b9e0f4..97807166f 100644 --- a/checker/type_checker_builder.h +++ b/checker/type_checker_builder.h @@ -28,6 +28,7 @@ #include "checker/internal/type_check_env.h" #include "checker/type_checker.h" #include "common/decl.h" +#include "common/type_introspector.h" namespace cel { @@ -72,6 +73,8 @@ class TypeCheckerBuilder { absl::Status AddVariable(const VariableDecl& decl); absl::Status AddFunction(const FunctionDecl& decl); + void AddTypeProvider(std::unique_ptr provider); + void set_container(absl::string_view container); const CheckerOptions& options() const { return options_; } diff --git a/common/type_introspector.cc b/common/type_introspector.cc index 9970cd5d4..23151654a 100644 --- a/common/type_introspector.cc +++ b/common/type_introspector.cc @@ -222,6 +222,17 @@ absl::StatusOr> TypeIntrospector::FindType( return FindTypeImpl(type_factory, name); } +absl::StatusOr> +TypeIntrospector::FindEnumConstant(TypeFactory& type_factory, + absl::string_view type, + absl::string_view value) const { + if (type == "google.protobuf.NullValue" && value == "NULL_VALUE") { + return EnumConstant{NullType{}, "google.protobuf.NullValue", "NULL_VALUE", + 0}; + } + return FindEnumConstantImpl(type_factory, type, value); +} + absl::StatusOr> TypeIntrospector::FindStructTypeFieldByName(TypeFactory& type_factory, absl::string_view type, @@ -238,6 +249,12 @@ absl::StatusOr> TypeIntrospector::FindTypeImpl( return absl::nullopt; } +absl::StatusOr> +TypeIntrospector::FindEnumConstantImpl(TypeFactory&, absl::string_view, + absl::string_view) const { + return absl::nullopt; +} + absl::StatusOr> TypeIntrospector::FindStructTypeFieldByNameImpl(TypeFactory&, absl::string_view, absl::string_view) const { diff --git a/common/type_introspector.h b/common/type_introspector.h index ec7ba21e3..2e504465b 100644 --- a/common/type_introspector.h +++ b/common/type_introspector.h @@ -15,6 +15,8 @@ #ifndef THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ #define THIRD_PARTY_CEL_CPP_COMMON_TYPE_INTROSPECTOR_H_ +#include + #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/optional.h" @@ -32,12 +34,27 @@ class TypeFactory; // is used by the runtime. class TypeIntrospector { public: + struct EnumConstant { + // The type of the enum. For JSON null, this may be a specific type rather + // than an enum type. + Type type; + absl::string_view type_full_name; + absl::string_view value_name; + int32_t number; + }; + virtual ~TypeIntrospector() = default; // `FindType` find the type corresponding to name `name`. absl::StatusOr> FindType(TypeFactory& type_factory, absl::string_view name) const; + // `FindEnumConstant` find a fully qualified enumerator name `name` in enum + // type `type`. + absl::StatusOr> FindEnumConstant( + TypeFactory& type_factory, absl::string_view type, + absl::string_view value) const; + // `FindStructTypeFieldByName` find the name, number, and type of the field // `name` in type `type`. absl::StatusOr> FindStructTypeFieldByName( @@ -56,6 +73,10 @@ class TypeIntrospector { virtual absl::StatusOr> FindTypeImpl( TypeFactory& type_factory, absl::string_view name) const; + virtual absl::StatusOr> FindEnumConstantImpl( + TypeFactory& type_factory, absl::string_view type, + absl::string_view value) const; + virtual absl::StatusOr> FindStructTypeFieldByNameImpl(TypeFactory& type_factory, absl::string_view type, diff --git a/common/types/basic_struct_type.cc b/common/types/basic_struct_type.cc index 5d5cc25c1..a3b31544c 100644 --- a/common/types/basic_struct_type.cc +++ b/common/types/basic_struct_type.cc @@ -24,9 +24,25 @@ namespace cel { bool IsWellKnownMessageType(absl::string_view name) { static constexpr absl::string_view kPrefix = "google.protobuf."; static constexpr std::array kNames = { - "Any", "BoolValue", "BytesValue", "DoubleValue", "Duration", - "FloatValue", "Int32Value", "Int64Value", "ListValue", "UInt32Value", - "UInt64Value", "StringValue", "Struct", "Timestamp", "Value", + // clang-format off + // keep-sorted start + "Any", + "BoolValue", + "BytesValue", + "DoubleValue", + "Duration", + "FloatValue", + "Int32Value", + "Int64Value", + "ListValue", + "StringValue", + "Struct", + "Timestamp", + "UInt32Value", + "UInt64Value", + "Value", + // keep-sorted end + // clang-format on }; if (!absl::ConsumePrefix(&name, kPrefix)) { return false; diff --git a/common/types/basic_struct_type.h b/common/types/basic_struct_type.h index 2cf9b9411..8852c9d71 100644 --- a/common/types/basic_struct_type.h +++ b/common/types/basic_struct_type.h @@ -32,6 +32,11 @@ namespace cel { class Type; class TypeParameters; +// Returns true if the given type name is one of the well known message types +// that CEL treats specially. +// +// For familiarity with textproto, these types may be created using the struct +// creation syntax, even though they are not considered a struct type in CEL. bool IsWellKnownMessageType(absl::string_view name); namespace common_internal { diff --git a/conformance/service.cc b/conformance/service.cc index b5fe9131c..6adaa33ff 100644 --- a/conformance/service.cc +++ b/conformance/service.cc @@ -610,8 +610,9 @@ class ModernConformanceServiceImpl : public ConformanceServiceInterface { cel::extensions::CreateAstFromParsedExpr(parsed_expr)); cel::TypeCheckerBuilder builder; - // TODO: apply the type env to the checker builder for custom - // variables and functions. + + builder.AddTypeProvider( + std::make_unique()); for (const auto& decl : request.type_env()) { const auto& name = decl.name(); diff --git a/extensions/protobuf/BUILD b/extensions/protobuf/BUILD index c3984bae6..ec0a53a6d 100644 --- a/extensions/protobuf/BUILD +++ b/extensions/protobuf/BUILD @@ -170,6 +170,7 @@ cc_test( deps = [ ":type", "//common:type", + "//common:type_kind", "//common:type_testing", "//internal:testing", "@com_google_absl//absl/types:optional", diff --git a/extensions/protobuf/type_introspector.cc b/extensions/protobuf/type_introspector.cc index e0fb00698..f681d41fc 100644 --- a/extensions/protobuf/type_introspector.cc +++ b/extensions/protobuf/type_introspector.cc @@ -20,6 +20,7 @@ #include "common/type.h" #include "common/type_factory.h" #include "common/type_introspector.h" +#include "google/protobuf/descriptor.h" namespace cel::extensions { @@ -34,6 +35,31 @@ absl::StatusOr> ProtoTypeIntrospector::FindTypeImpl( return MessageType(desc); } +absl::StatusOr> +ProtoTypeIntrospector::FindEnumConstantImpl(TypeFactory&, + absl::string_view type, + absl::string_view value) const { + const google::protobuf::EnumDescriptor* enum_desc = + descriptor_pool()->FindEnumTypeByName(type); + // google.protobuf.NullValue is special cased in the base class. + if (enum_desc == nullptr) { + return absl::nullopt; + } + + // Note: we don't support strong enum typing at this time so only the fully + // qualified enum values are meaningful, so we don't provide any signal if the + // enum type is found but can't match the value name. + const google::protobuf::EnumValueDescriptor* value_desc = + enum_desc->FindValueByName(value); + if (value_desc == nullptr) { + return absl::nullopt; + } + + return TypeIntrospector::EnumConstant{ + EnumType(enum_desc), enum_desc->full_name(), value_desc->name(), + value_desc->number()}; +} + absl::StatusOr> ProtoTypeIntrospector::FindStructTypeFieldByNameImpl( TypeFactory& type_factory, absl::string_view type, diff --git a/extensions/protobuf/type_introspector.h b/extensions/protobuf/type_introspector.h index 430418345..eae18aa06 100644 --- a/extensions/protobuf/type_introspector.h +++ b/extensions/protobuf/type_introspector.h @@ -43,6 +43,10 @@ class ProtoTypeIntrospector : public virtual TypeIntrospector { absl::StatusOr> FindTypeImpl( TypeFactory& type_factory, absl::string_view name) const final; + absl::StatusOr> + FindEnumConstantImpl(TypeFactory&, absl::string_view type, + absl::string_view value) const final; + absl::StatusOr> FindStructTypeFieldByNameImpl( TypeFactory& type_factory, absl::string_view type, absl::string_view name) const final; diff --git a/extensions/protobuf/type_introspector_test.cc b/extensions/protobuf/type_introspector_test.cc index f44d13496..35cb0a5e3 100644 --- a/extensions/protobuf/type_introspector_test.cc +++ b/extensions/protobuf/type_introspector_test.cc @@ -16,6 +16,7 @@ #include "absl/types/optional.h" #include "common/type.h" +#include "common/type_kind.h" #include "common/type_testing.h" #include "internal/testing.h" #include "proto/test/v1/proto2/test_all_types.pb.h" @@ -62,6 +63,54 @@ TEST_P(ProtoTypeIntrospectorTest, FindStructTypeFieldByName) { IsOkAndHolds(Eq(absl::nullopt))); } +TEST_P(ProtoTypeIntrospectorTest, FindEnumConstant) { + ProtoTypeIntrospector introspector; + const auto* enum_desc = TestAllTypes::NestedEnum_descriptor(); + ASSERT_OK_AND_ASSIGN( + auto enum_constant, + introspector.FindEnumConstant( + type_manager(), + "google.api.expr.test.v1.proto2.TestAllTypes.NestedEnum", "BAZ")); + ASSERT_TRUE(enum_constant.has_value()); + EXPECT_EQ(enum_constant->type.kind(), TypeKind::kEnum); + EXPECT_EQ(enum_constant->type_full_name, enum_desc->full_name()); + EXPECT_EQ(enum_constant->value_name, "BAZ"); + EXPECT_EQ(enum_constant->number, 2); +} + +TEST_P(ProtoTypeIntrospectorTest, FindEnumConstantNull) { + ProtoTypeIntrospector introspector; + ASSERT_OK_AND_ASSIGN( + auto enum_constant, + introspector.FindEnumConstant(type_manager(), "google.protobuf.NullValue", + "NULL_VALUE")); + ASSERT_TRUE(enum_constant.has_value()); + EXPECT_EQ(enum_constant->type.kind(), TypeKind::kNull); + EXPECT_EQ(enum_constant->type_full_name, "google.protobuf.NullValue"); + EXPECT_EQ(enum_constant->value_name, "NULL_VALUE"); + EXPECT_EQ(enum_constant->number, 0); +} + +TEST_P(ProtoTypeIntrospectorTest, FindEnumConstantUnknownEnum) { + ProtoTypeIntrospector introspector; + + ASSERT_OK_AND_ASSIGN( + auto enum_constant, + introspector.FindEnumConstant(type_manager(), "NotARealEnum", "BAZ")); + EXPECT_FALSE(enum_constant.has_value()); +} + +TEST_P(ProtoTypeIntrospectorTest, FindEnumConstantUnknownValue) { + ProtoTypeIntrospector introspector; + + ASSERT_OK_AND_ASSIGN( + auto enum_constant, + introspector.FindEnumConstant( + type_manager(), + "google.api.expr.test.v1.proto2.TestAllTypes.NestedEnum", "QUX")); + ASSERT_FALSE(enum_constant.has_value()); +} + INSTANTIATE_TEST_SUITE_P( ProtoTypeIntrospectorTest, ProtoTypeIntrospectorTest, ::testing::Values(MemoryManagement::kPooling,