diff --git a/extensions/protobuf/internal/BUILD b/extensions/protobuf/internal/BUILD index b6cb08066..414649f54 100644 --- a/extensions/protobuf/internal/BUILD +++ b/extensions/protobuf/internal/BUILD @@ -74,6 +74,7 @@ cc_library( "//common:constant", "//internal:proto_time_encoding", "//internal:status_macros", + "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", "@com_google_absl//absl/functional:overload", "@com_google_absl//absl/status", @@ -93,6 +94,7 @@ cc_test( "//common:ast", "//internal:proto_matchers", "//internal:testing", + "@com_google_absl//absl/status", "@com_google_googleapis//google/api/expr/v1alpha1:syntax_cc_proto", "@com_google_protobuf//:protobuf", ], diff --git a/extensions/protobuf/internal/ast.cc b/extensions/protobuf/internal/ast.cc index aa9ea027d..b02385f18 100644 --- a/extensions/protobuf/internal/ast.cc +++ b/extensions/protobuf/internal/ast.cc @@ -22,6 +22,7 @@ #include "google/api/expr/v1alpha1/syntax.pb.h" #include "google/protobuf/struct.pb.h" +#include "absl/base/attributes.h" #include "absl/base/nullability.h" #include "absl/functional/overload.h" #include "absl/status/status.h" @@ -39,6 +40,7 @@ namespace { using ExprProto = google::api::expr::v1alpha1::Expr; using ConstExprProto = google::api::expr::v1alpha1::Constant; +using StructExprProto = google::api::expr::v1alpha1::Expr::CreateStruct; class ExprToProtoState final { private: @@ -463,7 +465,7 @@ class ExprFromProtoState final { } absl::Status StructExprFromProto(const ExprProto& proto, - const ExprProto::CreateStruct& struct_proto, + const StructExprProto& struct_proto, Expr& expr) { expr.Clear(); expr.set_id(proto.id()); @@ -472,6 +474,17 @@ class ExprFromProtoState final { struct_expr.mutable_fields().reserve( static_cast(struct_proto.entries().size())); for (const auto& field_proto : struct_proto.entries()) { + switch (field_proto.key_kind_case()) { + case StructExprProto::Entry::KEY_KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case StructExprProto::Entry::kFieldKey: + break; + case StructExprProto::Entry::kMapKey: + return absl::InvalidArgumentError("encountered map entry in struct"); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected struct field kind: ", field_proto.key_kind_case())); + } auto& field_expr = struct_expr.add_fields(); field_expr.set_id(field_proto.id()); field_expr.set_name(field_proto.field_key()); @@ -492,6 +505,17 @@ class ExprFromProtoState final { map_expr.mutable_entries().reserve( static_cast(map_proto.entries().size())); for (const auto& entry_proto : map_proto.entries()) { + switch (entry_proto.key_kind_case()) { + case StructExprProto::Entry::KEY_KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case StructExprProto::Entry::kMapKey: + break; + case StructExprProto::Entry::kFieldKey: + return absl::InvalidArgumentError("encountered struct field in map"); + default: + return absl::InvalidArgumentError(absl::StrCat( + "unexpected map entry kind: ", entry_proto.key_kind_case())); + } auto& entry_expr = map_expr.add_entries(); entry_expr.set_id(entry_proto.id()); if (entry_proto.has_map_key()) { diff --git a/extensions/protobuf/internal/ast_test.cc b/extensions/protobuf/internal/ast_test.cc index 49ad10500..c947e3ddc 100644 --- a/extensions/protobuf/internal/ast_test.cc +++ b/extensions/protobuf/internal/ast_test.cc @@ -17,6 +17,7 @@ #include #include "google/api/expr/v1alpha1/syntax.pb.h" +#include "absl/status/status.h" #include "common/ast.h" #include "internal/proto_matchers.h" #include "internal/testing.h" @@ -27,6 +28,7 @@ namespace { using ::cel::internal::test::EqualsProto; using cel::internal::IsOk; +using cel::internal::StatusIs; using ExprProto = google::api::expr::v1alpha1::Expr; @@ -220,5 +222,53 @@ INSTANTIATE_TEST_SUITE_P( )pb"}, })); +TEST(ExprFromProto, StructFieldInMap) { + ExprProto original_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb( + id: 1 + struct_expr: { + entries: { + id: 2 + field_key: "foo" + value: { + id: 3 + ident_expr: { name: "bar" } + } + } + } + )pb", + &original_proto)); + Expr expr; + ASSERT_THAT(ExprFromProto(original_proto, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + +TEST(ExprFromProto, MapEntryInStruct) { + ExprProto original_proto; + ASSERT_TRUE( + google::protobuf::TextFormat::ParseFromString(R"pb( + id: 1 + struct_expr: { + message_name: "some.Message" + entries: { + id: 2 + map_key: { + id: 3 + ident_expr: { name: "foo" } + } + value: { + id: 4 + ident_expr: { name: "bar" } + } + } + } + )pb", + &original_proto)); + Expr expr; + ASSERT_THAT(ExprFromProto(original_proto, expr), + StatusIs(absl::StatusCode::kInvalidArgument)); +} + } // namespace } // namespace cel::extensions::protobuf_internal