diff --git a/internal/BUILD b/internal/BUILD index 2b6cbc740..84236fd9d 100644 --- a/internal/BUILD +++ b/internal/BUILD @@ -782,3 +782,54 @@ cc_test( "@com_google_protobuf//:protobuf", ], ) + +cc_library( + name = "message_equality", + srcs = ["message_equality.cc"], + hdrs = ["message_equality.h"], + deps = [ + ":json", + ":number", + ":status_macros", + ":well_known_types", + "//common:memory", + "//extensions/protobuf/internal:map_reflection", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/time", + "@com_google_absl//absl/types:variant", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "message_equality_test", + srcs = ["message_equality_test.cc"], + deps = [ + ":message_equality", + ":message_type_name", + ":parse_text_proto", + ":testing", + ":testing_descriptor_pool", + ":testing_message_factory", + ":well_known_types", + "//common:allocator", + "//common:memory", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status:status_matchers", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_cel_spec//proto/test/v1/proto3:test_all_types_cc_proto", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/internal/message_equality.cc b/internal/message_equality.cc new file mode 100644 index 000000000..ebcff644c --- /dev/null +++ b/internal/message_equality.cc @@ -0,0 +1,1492 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/message_equality.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/functional/overload.h" +#include "absl/log/absl_check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/time/time.h" +#include "absl/types/variant.h" +#include "common/memory.h" +#include "extensions/protobuf/internal/map_reflection.h" +#include "internal/json.h" +#include "internal/number.h" +#include "internal/status_macros.h" +#include "internal/well_known_types.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/util/message_differencer.h" + +namespace cel::internal { + +namespace { + +using ::cel::extensions::protobuf_internal::LookupMapValue; +using ::cel::extensions::protobuf_internal::MapBegin; +using ::cel::extensions::protobuf_internal::MapEnd; +using ::cel::extensions::protobuf_internal::MapSize; +using ::google::protobuf::Descriptor; +using ::google::protobuf::DescriptorPool; +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::MessageFactory; +using ::google::protobuf::util::MessageDifferencer; + +class EquatableListValue final + : public std::reference_wrapper { + public: + using std::reference_wrapper::reference_wrapper; +}; + +class EquatableStruct final + : public std::reference_wrapper { + public: + using std::reference_wrapper::reference_wrapper; +}; + +class EquatableAny final + : public std::reference_wrapper { + public: + using std::reference_wrapper::reference_wrapper; +}; + +class EquatableMessage final + : public std::reference_wrapper { + public: + using std::reference_wrapper::reference_wrapper; +}; + +using EquatableValue = + absl::variant; + +struct NullValueEqualer { + bool operator()(std::nullptr_t, std::nullptr_t) const { return true; } + + template + std::enable_if_t>, bool> + operator()(std::nullptr_t, const T&) const { + return false; + } +}; + +struct BoolValueEqualer { + bool operator()(bool lhs, bool rhs) const { return lhs == rhs; } + + template + std::enable_if_t>, bool> operator()( + bool, const T&) const { + return false; + } +}; + +struct BytesValueEqualer { + bool operator()(const well_known_types::BytesValue& lhs, + const well_known_types::BytesValue& rhs) const { + return lhs == rhs; + } + + template + std::enable_if_t< + std::negation_v>, bool> + operator()(const well_known_types::BytesValue&, const T&) const { + return false; + } +}; + +struct IntValueEqualer { + bool operator()(int64_t lhs, int64_t rhs) const { return lhs == rhs; } + + bool operator()(int64_t lhs, uint64_t rhs) const { + return Number::FromInt64(lhs) == Number::FromUint64(rhs); + } + + bool operator()(int64_t lhs, double rhs) const { + return Number::FromInt64(lhs) == Number::FromDouble(rhs); + } + + template + std::enable_if_t>, + std::negation>, + std::negation>>, + bool> + operator()(int64_t, const T&) const { + return false; + } +}; + +struct UintValueEqualer { + bool operator()(uint64_t lhs, int64_t rhs) const { + return Number::FromUint64(lhs) == Number::FromInt64(rhs); + } + + bool operator()(uint64_t lhs, uint64_t rhs) const { return lhs == rhs; } + + bool operator()(uint64_t lhs, double rhs) const { + return Number::FromUint64(lhs) == Number::FromDouble(rhs); + } + + template + std::enable_if_t>, + std::negation>, + std::negation>>, + bool> + operator()(uint64_t, const T&) const { + return false; + } +}; + +struct DoubleValueEqualer { + bool operator()(double lhs, int64_t rhs) const { + return Number::FromDouble(lhs) == Number::FromInt64(rhs); + } + + bool operator()(double lhs, uint64_t rhs) const { + return Number::FromDouble(lhs) == Number::FromUint64(rhs); + } + + bool operator()(double lhs, double rhs) const { return lhs == rhs; } + + template + std::enable_if_t>, + std::negation>, + std::negation>>, + bool> + operator()(double, const T&) const { + return false; + } +}; + +struct StringValueEqualer { + bool operator()(const well_known_types::StringValue& lhs, + const well_known_types::StringValue& rhs) const { + return lhs == rhs; + } + + template + std::enable_if_t< + std::negation_v>, bool> + operator()(const well_known_types::StringValue&, const T&) const { + return false; + } +}; + +struct DurationEqualer { + bool operator()(absl::Duration lhs, absl::Duration rhs) const { + return lhs == rhs; + } + + template + std::enable_if_t>, bool> + operator()(absl::Duration, const T&) const { + return false; + } +}; + +struct TimestampEqualer { + bool operator()(absl::Time lhs, absl::Time rhs) const { return lhs == rhs; } + + template + std::enable_if_t>, bool> + operator()(absl::Time, const T&) const { + return false; + } +}; + +struct ListValueEqualer { + bool operator()(EquatableListValue lhs, EquatableListValue rhs) const { + return JsonListEquals(lhs, rhs); + } + + template + std::enable_if_t>, bool> + operator()(EquatableListValue, const T&) const { + return false; + } +}; + +struct StructEqualer { + bool operator()(EquatableStruct lhs, EquatableStruct rhs) const { + return JsonMapEquals(lhs, rhs); + } + + template + std::enable_if_t>, bool> + operator()(EquatableStruct, const T&) const { + return false; + } +}; + +struct AnyEqualer { + bool operator()(EquatableAny lhs, EquatableAny rhs) const { + auto lhs_reflection = + well_known_types::GetAnyReflectionOrDie(lhs.get().GetDescriptor()); + std::string lhs_type_url_scratch; + std::string lhs_value_scratch; + auto rhs_reflection = + well_known_types::GetAnyReflectionOrDie(rhs.get().GetDescriptor()); + std::string rhs_type_url_scratch; + std::string rhs_value_scratch; + return lhs_reflection.GetTypeUrl(lhs.get(), lhs_type_url_scratch) == + rhs_reflection.GetTypeUrl(rhs.get(), rhs_type_url_scratch) && + lhs_reflection.GetValue(lhs.get(), lhs_value_scratch) == + rhs_reflection.GetValue(rhs.get(), rhs_value_scratch); + } + + template + std::enable_if_t>, bool> + operator()(EquatableAny, const T&) const { + return false; + } +}; + +struct MessageEqualer { + bool operator()(EquatableMessage lhs, EquatableMessage rhs) const { + return lhs.get().GetDescriptor() == rhs.get().GetDescriptor() && + MessageDifferencer::Equals(lhs.get(), rhs.get()); + } + + template + std::enable_if_t>, bool> + operator()(EquatableMessage, const T&) const { + return false; + } +}; + +struct EquatableValueReflection final { + well_known_types::DoubleValueReflection double_value_reflection; + well_known_types::FloatValueReflection float_value_reflection; + well_known_types::Int64ValueReflection int64_value_reflection; + well_known_types::UInt64ValueReflection uint64_value_reflection; + well_known_types::Int32ValueReflection int32_value_reflection; + well_known_types::UInt32ValueReflection uint32_value_reflection; + well_known_types::StringValueReflection string_value_reflection; + well_known_types::BytesValueReflection bytes_value_reflection; + well_known_types::BoolValueReflection bool_value_reflection; + well_known_types::AnyReflection any_reflection; + well_known_types::DurationReflection duration_reflection; + well_known_types::TimestampReflection timestamp_reflection; + well_known_types::ValueReflection value_reflection; + well_known_types::ListValueReflection list_value_reflection; + well_known_types::StructReflection struct_reflection; +}; + +absl::StatusOr AsEquatableValue( + EquatableValueReflection& reflection, + const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nonnull descriptor, + Descriptor::WellKnownType well_known_type, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + switch (well_known_type) { + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + CEL_RETURN_IF_ERROR( + reflection.double_value_reflection.Initialize(descriptor)); + return reflection.double_value_reflection.GetValue(message); + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: + CEL_RETURN_IF_ERROR( + reflection.float_value_reflection.Initialize(descriptor)); + return static_cast( + reflection.float_value_reflection.GetValue(message)); + case Descriptor::WELLKNOWNTYPE_INT64VALUE: + CEL_RETURN_IF_ERROR( + reflection.int64_value_reflection.Initialize(descriptor)); + return reflection.int64_value_reflection.GetValue(message); + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: + CEL_RETURN_IF_ERROR( + reflection.uint64_value_reflection.Initialize(descriptor)); + return reflection.uint64_value_reflection.GetValue(message); + case Descriptor::WELLKNOWNTYPE_INT32VALUE: + CEL_RETURN_IF_ERROR( + reflection.int32_value_reflection.Initialize(descriptor)); + return static_cast( + reflection.int32_value_reflection.GetValue(message)); + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: + CEL_RETURN_IF_ERROR( + reflection.uint32_value_reflection.Initialize(descriptor)); + return static_cast( + reflection.uint32_value_reflection.GetValue(message)); + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: + CEL_RETURN_IF_ERROR( + reflection.string_value_reflection.Initialize(descriptor)); + return reflection.string_value_reflection.GetValue(message, scratch); + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: + CEL_RETURN_IF_ERROR( + reflection.bytes_value_reflection.Initialize(descriptor)); + return reflection.bytes_value_reflection.GetValue(message, scratch); + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: + CEL_RETURN_IF_ERROR( + reflection.bool_value_reflection.Initialize(descriptor)); + return reflection.bool_value_reflection.GetValue(message); + case Descriptor::WELLKNOWNTYPE_VALUE: { + CEL_RETURN_IF_ERROR(reflection.value_reflection.Initialize(descriptor)); + const auto kind_case = reflection.value_reflection.GetKindCase(message); + switch (kind_case) { + case google::protobuf::Value::KIND_NOT_SET: + ABSL_FALLTHROUGH_INTENDED; + case google::protobuf::Value::kNullValue: + return nullptr; + case google::protobuf::Value::kBoolValue: + return reflection.value_reflection.GetBoolValue(message); + case google::protobuf::Value::kNumberValue: + return reflection.value_reflection.GetNumberValue(message); + case google::protobuf::Value::kStringValue: + return reflection.value_reflection.GetStringValue(message, scratch); + case google::protobuf::Value::kListValue: + return EquatableListValue( + reflection.value_reflection.GetListValue(message)); + case google::protobuf::Value::kStructValue: + return EquatableStruct( + reflection.value_reflection.GetStructValue(message)); + default: + return absl::InternalError( + absl::StrCat("unexpected value kind case: ", kind_case)); + } + } + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + return EquatableListValue(message); + case Descriptor::WELLKNOWNTYPE_STRUCT: + return EquatableStruct(message); + case Descriptor::WELLKNOWNTYPE_DURATION: + CEL_RETURN_IF_ERROR( + reflection.duration_reflection.Initialize(descriptor)); + return reflection.duration_reflection.ToAbslDuration(message); + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: + CEL_RETURN_IF_ERROR( + reflection.timestamp_reflection.Initialize(descriptor)); + return reflection.timestamp_reflection.ToAbslTime(message); + case Descriptor::WELLKNOWNTYPE_ANY: + return EquatableAny(message); + default: + return EquatableMessage(message); + } +} + +absl::StatusOr AsEquatableValue( + EquatableValueReflection& reflection, + const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nonnull descriptor, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + return AsEquatableValue(reflection, message, descriptor, + descriptor->well_known_type(), scratch); +} + +absl::StatusOr AsEquatableValue( + EquatableValueReflection& reflection, + const Message& message ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nonnull field, + std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) { + ABSL_DCHECK(!field->is_repeated() && !field->is_map()); + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return static_cast( + message.GetReflection()->GetInt32(message, field)); + case FieldDescriptor::CPPTYPE_INT64: + return message.GetReflection()->GetInt64(message, field); + case FieldDescriptor::CPPTYPE_UINT32: + return static_cast( + message.GetReflection()->GetUInt32(message, field)); + case FieldDescriptor::CPPTYPE_UINT64: + return message.GetReflection()->GetUInt64(message, field); + case FieldDescriptor::CPPTYPE_DOUBLE: + return message.GetReflection()->GetDouble(message, field); + case FieldDescriptor::CPPTYPE_FLOAT: + return static_cast( + message.GetReflection()->GetFloat(message, field)); + case FieldDescriptor::CPPTYPE_BOOL: + return message.GetReflection()->GetBool(message, field); + case FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return nullptr; + } + return static_cast( + message.GetReflection()->GetEnumValue(message, field)); + case FieldDescriptor::CPPTYPE_STRING: + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return well_known_types::GetBytesField(message, field, scratch); + } + return well_known_types::GetStringField(message, field, scratch); + case FieldDescriptor::CPPTYPE_MESSAGE: + return AsEquatableValue( + reflection, message.GetReflection()->GetMessage(message, field), + field->message_type(), scratch); + default: + return absl::InternalError( + absl::StrCat("unexpected field type: ", field->cpp_type_name())); + } +} + +bool IsAny(const Message& message) { + return message.GetDescriptor()->well_known_type() == + Descriptor::WELLKNOWNTYPE_ANY; +} + +bool IsAnyField(absl::Nonnull field) { + return field->type() == FieldDescriptor::TYPE_MESSAGE && + field->message_type()->well_known_type() == + Descriptor::WELLKNOWNTYPE_ANY; +} + +absl::StatusOr MapValueAsEquatableValue( + absl::Nonnull arena, + absl::Nonnull pool, + absl::Nonnull factory, + EquatableValueReflection& reflection, const google::protobuf::MapValueConstRef& value, + absl::Nonnull field, std::string& scratch, + Unique& unpacked) { + if (IsAnyField(field)) { + CEL_ASSIGN_OR_RETURN(unpacked, well_known_types::UnpackAnyIfResolveable( + arena, reflection.any_reflection, + value.GetMessageValue(), pool, factory)); + if (unpacked) { + return AsEquatableValue(reflection, *unpacked, unpacked->GetDescriptor(), + scratch); + } + return AsEquatableValue(reflection, value.GetMessageValue(), + value.GetMessageValue().GetDescriptor(), scratch); + } + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return static_cast(value.GetInt32Value()); + case FieldDescriptor::CPPTYPE_INT64: + return value.GetInt64Value(); + case FieldDescriptor::CPPTYPE_UINT32: + return static_cast(value.GetUInt32Value()); + case FieldDescriptor::CPPTYPE_UINT64: + return value.GetUInt64Value(); + case FieldDescriptor::CPPTYPE_DOUBLE: + return value.GetDoubleValue(); + case FieldDescriptor::CPPTYPE_FLOAT: + return static_cast(value.GetFloatValue()); + case FieldDescriptor::CPPTYPE_BOOL: + return value.GetBoolValue(); + case FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return nullptr; + } + return static_cast(value.GetEnumValue()); + case FieldDescriptor::CPPTYPE_STRING: + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return well_known_types::BytesValue( + absl::string_view(value.GetStringValue())); + } + return well_known_types::StringValue( + absl::string_view(value.GetStringValue())); + case FieldDescriptor::CPPTYPE_MESSAGE: { + const auto& message = value.GetMessageValue(); + return AsEquatableValue(reflection, message, message.GetDescriptor(), + scratch); + } + default: + return absl::InternalError( + absl::StrCat("unexpected field type: ", field->cpp_type_name())); + } +} + +absl::StatusOr RepeatedFieldAsEquatableValue( + absl::Nonnull arena, + absl::Nonnull pool, + absl::Nonnull factory, + EquatableValueReflection& reflection, const Message& message, + absl::Nonnull field, int index, + std::string& scratch, Unique& unpacked) { + if (IsAnyField(field)) { + const auto& field_value = + message.GetReflection()->GetRepeatedMessage(message, field, index); + CEL_ASSIGN_OR_RETURN(unpacked, well_known_types::UnpackAnyIfResolveable( + arena, reflection.any_reflection, + field_value, pool, factory)); + if (unpacked) { + return AsEquatableValue(reflection, *unpacked, unpacked->GetDescriptor(), + scratch); + } + return AsEquatableValue(reflection, field_value, + field_value.GetDescriptor(), scratch); + } + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return static_cast( + message.GetReflection()->GetRepeatedInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_INT64: + return message.GetReflection()->GetRepeatedInt64(message, field, index); + case FieldDescriptor::CPPTYPE_UINT32: + return static_cast( + message.GetReflection()->GetRepeatedUInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT64: + return message.GetReflection()->GetRepeatedUInt64(message, field, index); + case FieldDescriptor::CPPTYPE_DOUBLE: + return message.GetReflection()->GetRepeatedDouble(message, field, index); + case FieldDescriptor::CPPTYPE_FLOAT: + return static_cast( + message.GetReflection()->GetRepeatedFloat(message, field, index)); + case FieldDescriptor::CPPTYPE_BOOL: + return message.GetReflection()->GetRepeatedBool(message, field, index); + case FieldDescriptor::CPPTYPE_ENUM: + if (field->enum_type()->full_name() == "google.protobuf.NullValue") { + return nullptr; + } + return static_cast( + message.GetReflection()->GetRepeatedEnumValue(message, field, index)); + case FieldDescriptor::CPPTYPE_STRING: + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return well_known_types::GetRepeatedBytesField(message, field, index, + scratch); + } + return well_known_types::GetRepeatedStringField(message, field, index, + scratch); + case FieldDescriptor::CPPTYPE_MESSAGE: { + const auto& submessage = + message.GetReflection()->GetRepeatedMessage(message, field, index); + return AsEquatableValue(reflection, submessage, + submessage.GetDescriptor(), scratch); + } + default: + return absl::InternalError( + absl::StrCat("unexpected field type: ", field->cpp_type_name())); + } +} + +// Compare two `EquatableValue` for equality. +bool EquatableValueEquals(const EquatableValue& lhs, + const EquatableValue& rhs) { + return absl::visit( + absl::Overload(NullValueEqualer{}, BoolValueEqualer{}, + BytesValueEqualer{}, IntValueEqualer{}, UintValueEqualer{}, + DoubleValueEqualer{}, StringValueEqualer{}, + DurationEqualer{}, TimestampEqualer{}, ListValueEqualer{}, + StructEqualer{}, AnyEqualer{}, MessageEqualer{}), + lhs, rhs); +} + +// Attempts to coalesce one map key to another. Returns true if it was possible, +// false otherwise. +bool CoalesceMapKey(const google::protobuf::MapKey& src, + FieldDescriptor::CppType dest_type, + absl::Nonnull dest) { + switch (src.type()) { + case FieldDescriptor::CPPTYPE_BOOL: + if (dest_type != FieldDescriptor::CPPTYPE_BOOL) { + return false; + } + dest->SetBoolValue(src.GetBoolValue()); + return true; + case FieldDescriptor::CPPTYPE_INT32: { + const auto src_value = src.GetInt32Value(); + switch (dest_type) { + case FieldDescriptor::CPPTYPE_INT32: + dest->SetInt32Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_INT64: + dest->SetInt64Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_UINT32: + if (src_value < 0) { + return false; + } + dest->SetUInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_UINT64: + if (src_value < 0) { + return false; + } + dest->SetUInt64Value(static_cast(src_value)); + return true; + default: + return false; + } + } + case FieldDescriptor::CPPTYPE_INT64: { + const auto src_value = src.GetInt64Value(); + switch (dest_type) { + case FieldDescriptor::CPPTYPE_INT32: + if (src_value < std::numeric_limits::min() || + src_value > std::numeric_limits::max()) { + return false; + } + dest->SetInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_INT64: + dest->SetInt64Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_UINT32: + if (src_value < 0 || + src_value > std::numeric_limits::max()) { + return false; + } + dest->SetUInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_UINT64: + if (src_value < 0) { + return false; + } + dest->SetUInt64Value(static_cast(src_value)); + return true; + default: + return false; + } + } + case FieldDescriptor::CPPTYPE_UINT32: { + const auto src_value = src.GetUInt32Value(); + switch (dest_type) { + case FieldDescriptor::CPPTYPE_INT32: + if (src_value > std::numeric_limits::max()) { + return false; + } + dest->SetInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_INT64: + dest->SetInt64Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_UINT32: + dest->SetUInt32Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_UINT64: + dest->SetUInt64Value(static_cast(src_value)); + return true; + default: + return false; + } + } + case FieldDescriptor::CPPTYPE_UINT64: { + const auto src_value = src.GetUInt64Value(); + switch (dest_type) { + case FieldDescriptor::CPPTYPE_INT32: + if (src_value > std::numeric_limits::max()) { + return false; + } + dest->SetInt32Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_INT64: + if (src_value > std::numeric_limits::max()) { + return false; + } + dest->SetInt64Value(static_cast(src_value)); + return true; + case FieldDescriptor::CPPTYPE_UINT32: + if (src_value > std::numeric_limits::max()) { + return false; + } + dest->SetUInt32Value(src_value); + return true; + case FieldDescriptor::CPPTYPE_UINT64: + dest->SetUInt64Value(src_value); + return true; + default: + return false; + } + } + case FieldDescriptor::CPPTYPE_STRING: + if (dest_type != FieldDescriptor::CPPTYPE_STRING) { + return false; + } + dest->SetStringValue(src.GetStringValue()); + return true; + default: + // Only bool, integrals, and string may be map keys. + ABSL_UNREACHABLE(); + } +} + +// Bits used for categorizing equality. Can be used to cheaply check whether two +// categories are comparable for equality by performing an AND and checking if +// the result against `kNone`. +enum class EquatableCategory { + kNone = 0, + + kNullLike = 1 << 0, + kBoolLike = 1 << 1, + kNumericLike = 1 << 2, + kBytesLike = 1 << 3, + kStringLike = 1 << 4, + kList = 1 << 5, + kMap = 1 << 6, + kMessage = 1 << 7, + kDuration = 1 << 8, + kTimestamp = 1 << 9, + + kAny = kNullLike | kBoolLike | kNumericLike | kBytesLike | kStringLike | + kList | kMap | kMessage | kDuration | kTimestamp, + kValue = kNullLike | kBoolLike | kNumericLike | kStringLike | kList | kMap, +}; + +constexpr EquatableCategory operator&(EquatableCategory lhs, + EquatableCategory rhs) { + return static_cast( + static_cast>(lhs) & + static_cast>(rhs)); +} + +constexpr bool operator==(EquatableCategory lhs, EquatableCategory rhs) { + return static_cast>(lhs) == + static_cast>(rhs); +} + +EquatableCategory GetEquatableCategory( + absl::Nonnull descriptor) { + switch (descriptor->well_known_type()) { + case Descriptor::WELLKNOWNTYPE_BOOLVALUE: + return EquatableCategory::kBoolLike; + case Descriptor::WELLKNOWNTYPE_FLOATVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT32VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_INT64VALUE: + ABSL_FALLTHROUGH_INTENDED; + case Descriptor::WELLKNOWNTYPE_UINT64VALUE: + return EquatableCategory::kNumericLike; + case Descriptor::WELLKNOWNTYPE_BYTESVALUE: + return EquatableCategory::kBytesLike; + case Descriptor::WELLKNOWNTYPE_STRINGVALUE: + return EquatableCategory::kStringLike; + case Descriptor::WELLKNOWNTYPE_VALUE: + return EquatableCategory::kValue; + case Descriptor::WELLKNOWNTYPE_LISTVALUE: + return EquatableCategory::kList; + case Descriptor::WELLKNOWNTYPE_STRUCT: + return EquatableCategory::kMap; + case Descriptor::WELLKNOWNTYPE_ANY: + return EquatableCategory::kAny; + case Descriptor::WELLKNOWNTYPE_DURATION: + return EquatableCategory::kDuration; + case Descriptor::WELLKNOWNTYPE_TIMESTAMP: + return EquatableCategory::kTimestamp; + default: + return EquatableCategory::kAny; + } +} + +EquatableCategory GetEquatableFieldCategory( + absl::Nonnull field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_ENUM: + return field->enum_type()->full_name() == "google.protobuf.NullValue" + ? EquatableCategory::kNullLike + : EquatableCategory::kNumericLike; + case FieldDescriptor::CPPTYPE_BOOL: + return EquatableCategory::kBoolLike; + case FieldDescriptor::CPPTYPE_FLOAT: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_DOUBLE: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_INT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_UINT32: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_INT64: + ABSL_FALLTHROUGH_INTENDED; + case FieldDescriptor::CPPTYPE_UINT64: + return EquatableCategory::kNumericLike; + case FieldDescriptor::CPPTYPE_STRING: + return field->type() == FieldDescriptor::TYPE_BYTES + ? EquatableCategory::kBytesLike + : EquatableCategory::kStringLike; + case FieldDescriptor::CPPTYPE_MESSAGE: + return GetEquatableCategory(field->message_type()); + default: + // Ugh. Force any future additions to compare instead of short circuiting. + return EquatableCategory::kAny; + } +} + +class MessageEqualsState final { + public: + MessageEqualsState(absl::Nonnull pool, + absl::Nonnull factory) + : pool_(pool), factory_(factory) {} + + // Equality between messages. + absl::StatusOr Equals(const Message& lhs, const Message& rhs) { + const auto* lhs_descriptor = lhs.GetDescriptor(); + const auto* rhs_descriptor = rhs.GetDescriptor(); + // Deal with well known types, starting with any. + auto lhs_well_known_type = lhs_descriptor->well_known_type(); + auto rhs_well_known_type = rhs_descriptor->well_known_type(); + absl::Nonnull lhs_ptr = &lhs; + absl::Nonnull rhs_ptr = &rhs; + Unique lhs_unpacked; + Unique rhs_unpacked; + // Deal with any first. We could in theory check if we should bother + // unpacking, but that is more complicated. We can always implement it + // later. + if (lhs_well_known_type == Descriptor::WELLKNOWNTYPE_ANY) { + CEL_ASSIGN_OR_RETURN( + lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, lhs, pool_, factory_)); + if (lhs_unpacked) { + lhs_ptr = cel::to_address(lhs_unpacked); + lhs_descriptor = lhs_ptr->GetDescriptor(); + lhs_well_known_type = lhs_descriptor->well_known_type(); + } + } + if (rhs_well_known_type == Descriptor::WELLKNOWNTYPE_ANY) { + CEL_ASSIGN_OR_RETURN( + rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, rhs, pool_, factory_)); + if (rhs_unpacked) { + rhs_ptr = cel::to_address(rhs_unpacked); + rhs_descriptor = rhs_ptr->GetDescriptor(); + rhs_well_known_type = rhs_descriptor->well_known_type(); + } + } + CEL_ASSIGN_OR_RETURN( + auto lhs_value, + AsEquatableValue(lhs_reflection_, *lhs_ptr, lhs_descriptor, + lhs_well_known_type, lhs_scratch_)); + CEL_ASSIGN_OR_RETURN( + auto rhs_value, + AsEquatableValue(rhs_reflection_, *rhs_ptr, rhs_descriptor, + rhs_well_known_type, rhs_scratch_)); + return EquatableValueEquals(lhs_value, rhs_value); + } + + // Equality between map message fields. + absl::StatusOr MapFieldEquals( + const Message& lhs, absl::Nonnull lhs_field, + const Message& rhs, absl::Nonnull rhs_field) { + ABSL_DCHECK(lhs_field->is_map()); + ABSL_DCHECK_EQ(lhs_field->containing_type(), lhs.GetDescriptor()); + ABSL_DCHECK(rhs_field->is_map()); + ABSL_DCHECK_EQ(rhs_field->containing_type(), rhs.GetDescriptor()); + const auto* lhs_entry = lhs_field->message_type(); + const auto* lhs_entry_key_field = lhs_entry->map_key(); + const auto* lhs_entry_value_field = lhs_entry->map_value(); + const auto* rhs_entry = rhs_field->message_type(); + const auto* rhs_entry_key_field = rhs_entry->map_key(); + const auto* rhs_entry_value_field = rhs_entry->map_value(); + // Perform cheap test which checks whether the left and right can even be + // compared for equality. + if (lhs_field != rhs_field && + ((GetEquatableFieldCategory(lhs_entry_key_field) & + GetEquatableFieldCategory(rhs_entry_key_field)) == + EquatableCategory::kNone || + (GetEquatableFieldCategory(lhs_entry_value_field) & + GetEquatableFieldCategory(rhs_entry_value_field)) == + EquatableCategory::kNone)) { + // Short-circuit. + return false; + } + const auto* lhs_reflection = lhs.GetReflection(); + const auto* rhs_reflection = rhs.GetReflection(); + if (MapSize(*lhs_reflection, lhs, *lhs_field) != + MapSize(*rhs_reflection, rhs, *rhs_field)) { + return false; + } + auto lhs_begin = MapBegin(*lhs_reflection, lhs, *lhs_field); + const auto lhs_end = MapEnd(*lhs_reflection, lhs, *lhs_field); + Unique lhs_unpacked; + EquatableValue lhs_value; + Unique rhs_unpacked; + EquatableValue rhs_value; + google::protobuf::MapKey rhs_map_key; + google::protobuf::MapValueConstRef rhs_map_value; + for (; lhs_begin != lhs_end; ++lhs_begin) { + if (!CoalesceMapKey(lhs_begin.GetKey(), rhs_entry_key_field->cpp_type(), + &rhs_map_key)) { + return false; + } + if (!LookupMapValue(*rhs_reflection, rhs, *rhs_field, rhs_map_key, + &rhs_map_value)) { + return false; + } + CEL_ASSIGN_OR_RETURN(lhs_value, + MapValueAsEquatableValue( + &arena_, pool_, factory_, lhs_reflection_, + lhs_begin.GetValueRef(), lhs_entry_value_field, + lhs_scratch_, lhs_unpacked)); + CEL_ASSIGN_OR_RETURN( + rhs_value, + MapValueAsEquatableValue(&arena_, pool_, factory_, rhs_reflection_, + rhs_map_value, rhs_entry_value_field, + rhs_scratch_, rhs_unpacked)); + if (!EquatableValueEquals(lhs_value, rhs_value)) { + return false; + } + } + return true; + } + + // Equality between repeated message fields. + absl::StatusOr RepeatedFieldEquals( + const Message& lhs, absl::Nonnull lhs_field, + const Message& rhs, absl::Nonnull rhs_field) { + ABSL_DCHECK(lhs_field->is_repeated() && !lhs_field->is_map()); + ABSL_DCHECK_EQ(lhs_field->containing_type(), lhs.GetDescriptor()); + ABSL_DCHECK(rhs_field->is_repeated() && !rhs_field->is_map()); + ABSL_DCHECK_EQ(rhs_field->containing_type(), rhs.GetDescriptor()); + // Perform cheap test which checks whether the left and right can even be + // compared for equality. + if (lhs_field != rhs_field && + (GetEquatableFieldCategory(lhs_field) & + GetEquatableFieldCategory(rhs_field)) == EquatableCategory::kNone) { + // Short-circuit. + return false; + } + const auto* lhs_reflection = lhs.GetReflection(); + const auto* rhs_reflection = rhs.GetReflection(); + const auto size = lhs_reflection->FieldSize(lhs, lhs_field); + if (size != rhs_reflection->FieldSize(rhs, rhs_field)) { + return false; + } + Unique lhs_unpacked; + EquatableValue lhs_value; + Unique rhs_unpacked; + EquatableValue rhs_value; + for (int i = 0; i < size; ++i) { + CEL_ASSIGN_OR_RETURN(lhs_value, + RepeatedFieldAsEquatableValue( + &arena_, pool_, factory_, lhs_reflection_, lhs, + lhs_field, i, lhs_scratch_, lhs_unpacked)); + CEL_ASSIGN_OR_RETURN(rhs_value, + RepeatedFieldAsEquatableValue( + &arena_, pool_, factory_, rhs_reflection_, rhs, + rhs_field, i, rhs_scratch_, rhs_unpacked)); + if (!EquatableValueEquals(lhs_value, rhs_value)) { + return false; + } + } + return true; + } + + // Equality between singular message fields and/or messages. If the field is + // `nullptr`, we are performing equality on the message itself rather than the + // corresponding field. + absl::StatusOr SingularFieldEquals( + const Message& lhs, absl::Nullable lhs_field, + const Message& rhs, absl::Nullable rhs_field) { + ABSL_DCHECK(lhs_field == nullptr || + (!lhs_field->is_repeated() && !lhs_field->is_map())); + ABSL_DCHECK(lhs_field == nullptr || + lhs_field->containing_type() == lhs.GetDescriptor()); + ABSL_DCHECK(rhs_field == nullptr || + (!rhs_field->is_repeated() && !rhs_field->is_map())); + ABSL_DCHECK(rhs_field == nullptr || + rhs_field->containing_type() == rhs.GetDescriptor()); + // Perform cheap test which checks whether the left and right can even be + // compared for equality. + if (lhs_field != rhs_field && + ((lhs_field != nullptr ? GetEquatableFieldCategory(lhs_field) + : GetEquatableCategory(lhs.GetDescriptor())) & + (rhs_field != nullptr ? GetEquatableFieldCategory(rhs_field) + : GetEquatableCategory(rhs.GetDescriptor()))) == + EquatableCategory::kNone) { + // Short-circuit. + return false; + } + absl::Nonnull lhs_ptr = &lhs; + absl::Nonnull rhs_ptr = &rhs; + Unique lhs_unpacked; + Unique rhs_unpacked; + if (lhs_field != nullptr && IsAnyField(lhs_field)) { + CEL_ASSIGN_OR_RETURN(lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, + lhs.GetReflection()->GetMessage(lhs, lhs_field), + pool_, factory_)); + if (lhs_unpacked) { + lhs_ptr = cel::to_address(lhs_unpacked); + lhs_field = nullptr; + } + } else if (lhs_field == nullptr && IsAny(lhs)) { + CEL_ASSIGN_OR_RETURN( + lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, lhs, pool_, factory_)); + if (lhs_unpacked) { + lhs_ptr = cel::to_address(lhs_unpacked); + } + } + if (rhs_field != nullptr && IsAnyField(rhs_field)) { + CEL_ASSIGN_OR_RETURN(rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, + rhs.GetReflection()->GetMessage(rhs, rhs_field), + pool_, factory_)); + if (rhs_unpacked) { + rhs_ptr = cel::to_address(rhs_unpacked); + rhs_field = nullptr; + } + } else if (rhs_field == nullptr && IsAny(rhs)) { + CEL_ASSIGN_OR_RETURN( + rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, rhs, pool_, factory_)); + if (rhs_unpacked) { + rhs_ptr = cel::to_address(rhs_unpacked); + } + } + EquatableValue lhs_value; + if (lhs_field != nullptr) { + CEL_ASSIGN_OR_RETURN( + lhs_value, + AsEquatableValue(lhs_reflection_, *lhs_ptr, lhs_field, lhs_scratch_)); + } else { + CEL_ASSIGN_OR_RETURN( + lhs_value, AsEquatableValue(lhs_reflection_, *lhs_ptr, + lhs_ptr->GetDescriptor(), lhs_scratch_)); + } + EquatableValue rhs_value; + if (rhs_field != nullptr) { + CEL_ASSIGN_OR_RETURN( + rhs_value, + AsEquatableValue(rhs_reflection_, *rhs_ptr, rhs_field, rhs_scratch_)); + } else { + CEL_ASSIGN_OR_RETURN( + rhs_value, AsEquatableValue(rhs_reflection_, *rhs_ptr, + rhs_ptr->GetDescriptor(), rhs_scratch_)); + } + return EquatableValueEquals(lhs_value, rhs_value); + } + + absl::StatusOr FieldEquals( + const Message& lhs, absl::Nullable lhs_field, + const Message& rhs, absl::Nullable rhs_field) { + ABSL_DCHECK(lhs_field != nullptr || + rhs_field != nullptr); // Both cannot be null. + if (lhs_field != nullptr && lhs_field->is_map()) { + // map == map + // map == google.protobuf.Value + // map == google.protobuf.Struct + // map == google.protobuf.Any + + // Right hand side should be a map, `google.protobuf.Value`, + // `google.protobuf.Struct`, or `google.protobuf.Any`. + if (rhs_field != nullptr && rhs_field->is_map()) { + // map == map + return MapFieldEquals(lhs, lhs_field, rhs, rhs_field); + } + if (rhs_field != nullptr && + (rhs_field->is_repeated() || + rhs_field->type() != FieldDescriptor::TYPE_MESSAGE)) { + return false; + } + absl::Nullable rhs_packed = nullptr; + Unique rhs_unpacked; + if (rhs_field != nullptr && IsAnyField(rhs_field)) { + rhs_packed = &rhs.GetReflection()->GetMessage(rhs, rhs_field); + } else if (rhs_field == nullptr && IsAny(rhs)) { + rhs_packed = &rhs; + } + if (rhs_packed != nullptr) { + CEL_RETURN_IF_ERROR(rhs_reflection_.any_reflection.Initialize( + rhs_packed->GetDescriptor())); + auto rhs_type_url = rhs_reflection_.any_reflection.GetTypeUrl( + *rhs_packed, rhs_scratch_); + if (!rhs_type_url.ConsumePrefix("type.googleapis.com/") && + !rhs_type_url.ConsumePrefix("type.googleprod.com/")) { + return false; + } + if (rhs_type_url != "google.protobuf.Value" && + rhs_type_url != "google.protobuf.Struct" && + rhs_type_url != "google.protobuf.Any") { + return false; + } + CEL_ASSIGN_OR_RETURN(rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, + *rhs_packed, pool_, factory_)); + if (rhs_unpacked) { + rhs_field = nullptr; + } + } + absl::Nonnull rhs_message = + rhs_field != nullptr + ? &rhs.GetReflection()->GetMessage(rhs, rhs_field) + : rhs_unpacked != nullptr ? cel::to_address(rhs_unpacked) + : &rhs; + const auto* rhs_descriptor = rhs_message->GetDescriptor(); + const auto rhs_well_known_type = rhs_descriptor->well_known_type(); + switch (rhs_well_known_type) { + case Descriptor::WELLKNOWNTYPE_VALUE: { + // map == google.protobuf.Value + CEL_RETURN_IF_ERROR( + rhs_reflection_.value_reflection.Initialize(rhs_descriptor)); + if (rhs_reflection_.value_reflection.GetKindCase(*rhs_message) != + google::protobuf::Value::kStructValue) { + return false; + } + CEL_RETURN_IF_ERROR(rhs_reflection_.struct_reflection.Initialize( + rhs_reflection_.value_reflection.GetStructDescriptor())); + return MapFieldEquals( + lhs, lhs_field, + rhs_reflection_.value_reflection.GetStructValue(*rhs_message), + rhs_reflection_.struct_reflection.GetFieldsDescriptor()); + } + case Descriptor::WELLKNOWNTYPE_STRUCT: { + // map == google.protobuf.Struct + CEL_RETURN_IF_ERROR( + rhs_reflection_.struct_reflection.Initialize(rhs_descriptor)); + return MapFieldEquals( + lhs, lhs_field, *rhs_message, + rhs_reflection_.struct_reflection.GetFieldsDescriptor()); + } + default: + return false; + } + // Explicitly unreachable, for ease of reading. Control never leaves this + // if statement. + ABSL_UNREACHABLE(); + } + if (rhs_field != nullptr && rhs_field->is_map()) { + // google.protobuf.Value == map + // google.protobuf.Struct == map + // google.protobuf.Any == map + + // Left hand side should be singular `google.protobuf.Value` + // `google.protobuf.Struct`, or `google.protobuf.Any`. + ABSL_DCHECK(lhs_field == nullptr || + !lhs_field->is_map()); // Handled above. + if (lhs_field != nullptr && + (lhs_field->is_repeated() || + lhs_field->type() != FieldDescriptor::TYPE_MESSAGE)) { + return false; + } + absl::Nullable lhs_packed = nullptr; + Unique lhs_unpacked; + if (lhs_field != nullptr && IsAnyField(lhs_field)) { + lhs_packed = &lhs.GetReflection()->GetMessage(lhs, lhs_field); + } else if (lhs_field == nullptr && IsAny(lhs)) { + lhs_packed = &lhs; + } + if (lhs_packed != nullptr) { + CEL_RETURN_IF_ERROR(lhs_reflection_.any_reflection.Initialize( + lhs_packed->GetDescriptor())); + auto lhs_type_url = lhs_reflection_.any_reflection.GetTypeUrl( + *lhs_packed, lhs_scratch_); + if (!lhs_type_url.ConsumePrefix("type.googleapis.com/") && + !lhs_type_url.ConsumePrefix("type.googleprod.com/")) { + return false; + } + if (lhs_type_url != "google.protobuf.Value" && + lhs_type_url != "google.protobuf.Struct" && + lhs_type_url != "google.protobuf.Any") { + return false; + } + CEL_ASSIGN_OR_RETURN(lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, + *lhs_packed, pool_, factory_)); + if (lhs_unpacked) { + lhs_field = nullptr; + } + } + absl::Nonnull lhs_message = + lhs_field != nullptr + ? &lhs.GetReflection()->GetMessage(lhs, lhs_field) + : lhs_unpacked != nullptr ? cel::to_address(lhs_unpacked) + : &lhs; + const auto* lhs_descriptor = lhs_message->GetDescriptor(); + const auto lhs_well_known_type = lhs_descriptor->well_known_type(); + switch (lhs_well_known_type) { + case Descriptor::WELLKNOWNTYPE_VALUE: { + // map == google.protobuf.Value + CEL_RETURN_IF_ERROR( + lhs_reflection_.value_reflection.Initialize(lhs_descriptor)); + if (lhs_reflection_.value_reflection.GetKindCase(*lhs_message) != + google::protobuf::Value::kStructValue) { + return false; + } + CEL_RETURN_IF_ERROR(lhs_reflection_.struct_reflection.Initialize( + lhs_reflection_.value_reflection.GetStructDescriptor())); + return MapFieldEquals( + lhs_reflection_.value_reflection.GetStructValue(*lhs_message), + lhs_reflection_.struct_reflection.GetFieldsDescriptor(), rhs, + rhs_field); + } + case Descriptor::WELLKNOWNTYPE_STRUCT: { + // map == google.protobuf.Struct + CEL_RETURN_IF_ERROR( + lhs_reflection_.struct_reflection.Initialize(lhs_descriptor)); + return MapFieldEquals( + *lhs_message, + lhs_reflection_.struct_reflection.GetFieldsDescriptor(), rhs, + rhs_field); + } + default: + return false; + } + // Explicitly unreachable, for ease of reading. Control never leaves this + // if statement. + ABSL_UNREACHABLE(); + } + ABSL_DCHECK(lhs_field == nullptr || + !lhs_field->is_map()); // Handled above. + ABSL_DCHECK(rhs_field == nullptr || + !rhs_field->is_map()); // Handled above. + if (lhs_field != nullptr && lhs_field->is_repeated()) { + // repeated == repeated + // repeated == google.protobuf.Value + // repeated == google.protobuf.ListValue + // repeated == google.protobuf.Any + + // Right hand side should be a repeated, `google.protobuf.Value`, + // `google.protobuf.ListValue`, or `google.protobuf.Any`. + if (rhs_field != nullptr && rhs_field->is_repeated()) { + // map == map + return RepeatedFieldEquals(lhs, lhs_field, rhs, rhs_field); + } + if (rhs_field != nullptr && + rhs_field->type() != FieldDescriptor::TYPE_MESSAGE) { + return false; + } + absl::Nullable rhs_packed = nullptr; + Unique rhs_unpacked; + if (rhs_field != nullptr && IsAnyField(rhs_field)) { + rhs_packed = &rhs.GetReflection()->GetMessage(rhs, rhs_field); + } else if (rhs_field == nullptr && IsAny(rhs)) { + rhs_packed = &rhs; + } + if (rhs_packed != nullptr) { + CEL_RETURN_IF_ERROR(rhs_reflection_.any_reflection.Initialize( + rhs_packed->GetDescriptor())); + auto rhs_type_url = rhs_reflection_.any_reflection.GetTypeUrl( + *rhs_packed, rhs_scratch_); + if (!rhs_type_url.ConsumePrefix("type.googleapis.com/") && + !rhs_type_url.ConsumePrefix("type.googleprod.com/")) { + return false; + } + if (rhs_type_url != "google.protobuf.Value" && + rhs_type_url != "google.protobuf.ListValue" && + rhs_type_url != "google.protobuf.Any") { + return false; + } + CEL_ASSIGN_OR_RETURN(rhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, rhs_reflection_.any_reflection, + *rhs_packed, pool_, factory_)); + if (rhs_unpacked) { + rhs_field = nullptr; + } + } + absl::Nonnull rhs_message = + rhs_field != nullptr + ? &rhs.GetReflection()->GetMessage(rhs, rhs_field) + : rhs_unpacked != nullptr ? cel::to_address(rhs_unpacked) + : &rhs; + const auto* rhs_descriptor = rhs_message->GetDescriptor(); + const auto rhs_well_known_type = rhs_descriptor->well_known_type(); + switch (rhs_well_known_type) { + case Descriptor::WELLKNOWNTYPE_VALUE: { + // map == google.protobuf.Value + CEL_RETURN_IF_ERROR( + rhs_reflection_.value_reflection.Initialize(rhs_descriptor)); + if (rhs_reflection_.value_reflection.GetKindCase(*rhs_message) != + google::protobuf::Value::kListValue) { + return false; + } + CEL_RETURN_IF_ERROR(rhs_reflection_.list_value_reflection.Initialize( + rhs_reflection_.value_reflection.GetListValueDescriptor())); + return RepeatedFieldEquals( + lhs, lhs_field, + rhs_reflection_.value_reflection.GetListValue(*rhs_message), + rhs_reflection_.list_value_reflection.GetValuesDescriptor()); + } + case Descriptor::WELLKNOWNTYPE_LISTVALUE: { + // map == google.protobuf.ListValue + CEL_RETURN_IF_ERROR( + rhs_reflection_.list_value_reflection.Initialize(rhs_descriptor)); + return RepeatedFieldEquals( + lhs, lhs_field, *rhs_message, + rhs_reflection_.list_value_reflection.GetValuesDescriptor()); + } + default: + return false; + } + // Explicitly unreachable, for ease of reading. Control never leaves this + // if statement. + ABSL_UNREACHABLE(); + } + if (rhs_field != nullptr && rhs_field->is_repeated()) { + // google.protobuf.Value == repeated + // google.protobuf.ListValue == repeated + // google.protobuf.Any == repeated + + // Left hand side should be singular `google.protobuf.Value` + // `google.protobuf.ListValue`, or `google.protobuf.Any`. + ABSL_DCHECK(lhs_field == nullptr || + !lhs_field->is_repeated()); // Handled above. + if (lhs_field != nullptr && + lhs_field->type() != FieldDescriptor::TYPE_MESSAGE) { + return false; + } + absl::Nullable lhs_packed = nullptr; + Unique lhs_unpacked; + if (lhs_field != nullptr && IsAnyField(lhs_field)) { + lhs_packed = &lhs.GetReflection()->GetMessage(lhs, lhs_field); + } else if (lhs_field == nullptr && IsAny(lhs)) { + lhs_packed = &lhs; + } + if (lhs_packed != nullptr) { + CEL_RETURN_IF_ERROR(lhs_reflection_.any_reflection.Initialize( + lhs_packed->GetDescriptor())); + auto lhs_type_url = lhs_reflection_.any_reflection.GetTypeUrl( + *lhs_packed, lhs_scratch_); + if (!lhs_type_url.ConsumePrefix("type.googleapis.com/") && + !lhs_type_url.ConsumePrefix("type.googleprod.com/")) { + return false; + } + if (lhs_type_url != "google.protobuf.Value" && + lhs_type_url != "google.protobuf.ListValue" && + lhs_type_url != "google.protobuf.Any") { + return false; + } + CEL_ASSIGN_OR_RETURN(lhs_unpacked, + well_known_types::UnpackAnyIfResolveable( + &arena_, lhs_reflection_.any_reflection, + *lhs_packed, pool_, factory_)); + if (lhs_unpacked) { + lhs_field = nullptr; + } + } + absl::Nonnull lhs_message = + lhs_field != nullptr + ? &lhs.GetReflection()->GetMessage(lhs, lhs_field) + : lhs_unpacked != nullptr ? cel::to_address(lhs_unpacked) + : &lhs; + const auto* lhs_descriptor = lhs_message->GetDescriptor(); + const auto lhs_well_known_type = lhs_descriptor->well_known_type(); + switch (lhs_well_known_type) { + case Descriptor::WELLKNOWNTYPE_VALUE: { + // map == google.protobuf.Value + CEL_RETURN_IF_ERROR( + lhs_reflection_.value_reflection.Initialize(lhs_descriptor)); + if (lhs_reflection_.value_reflection.GetKindCase(*lhs_message) != + google::protobuf::Value::kListValue) { + return false; + } + CEL_RETURN_IF_ERROR(lhs_reflection_.list_value_reflection.Initialize( + lhs_reflection_.value_reflection.GetListValueDescriptor())); + return RepeatedFieldEquals( + lhs_reflection_.value_reflection.GetListValue(*lhs_message), + lhs_reflection_.list_value_reflection.GetValuesDescriptor(), rhs, + rhs_field); + } + case Descriptor::WELLKNOWNTYPE_LISTVALUE: { + // map == google.protobuf.ListValue + CEL_RETURN_IF_ERROR( + lhs_reflection_.list_value_reflection.Initialize(lhs_descriptor)); + return RepeatedFieldEquals( + *lhs_message, + lhs_reflection_.list_value_reflection.GetValuesDescriptor(), rhs, + rhs_field); + } + default: + return false; + } + // Explicitly unreachable, for ease of reading. Control never leaves this + // if statement. + ABSL_UNREACHABLE(); + } + return SingularFieldEquals(lhs, lhs_field, rhs, rhs_field); + } + + private: + const absl::Nonnull pool_; + const absl::Nonnull factory_; + google::protobuf::Arena arena_; + EquatableValueReflection lhs_reflection_; + EquatableValueReflection rhs_reflection_; + std::string lhs_scratch_; + std::string rhs_scratch_; +}; + +} // namespace + +absl::StatusOr MessageEquals(const Message& lhs, const Message& rhs, + absl::Nonnull pool, + absl::Nonnull factory) { + ABSL_DCHECK(pool != nullptr); + ABSL_DCHECK(factory != nullptr); + if (&lhs == &rhs) { + return true; + } + // MessageEqualsState has quite a large size, so we allocate it on the heap. + // Ideally we should just hold most of the state at runtime in something like + // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly. + return std::make_unique(pool, factory)->Equals(lhs, rhs); +} + +absl::StatusOr MessageFieldEquals( + const Message& lhs, absl::Nonnull lhs_field, + const Message& rhs, absl::Nonnull rhs_field, + absl::Nonnull pool, + absl::Nonnull factory) { + ABSL_DCHECK(lhs_field != nullptr); + ABSL_DCHECK(rhs_field != nullptr); + ABSL_DCHECK(pool != nullptr); + ABSL_DCHECK(factory != nullptr); + if (&lhs == &rhs && lhs_field == rhs_field) { + return true; + } + // MessageEqualsState has quite a large size, so we allocate it on the heap. + // Ideally we should just hold most of the state at runtime in something like + // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly. + return std::make_unique(pool, factory) + ->FieldEquals(lhs, lhs_field, rhs, rhs_field); +} + +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, const google::protobuf::Message& rhs, + absl::Nonnull rhs_field, + absl::Nonnull pool, + absl::Nonnull factory) { + ABSL_DCHECK(rhs_field != nullptr); + ABSL_DCHECK(pool != nullptr); + ABSL_DCHECK(factory != nullptr); + // MessageEqualsState has quite a large size, so we allocate it on the heap. + // Ideally we should just hold most of the state at runtime in something like + // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly. + return std::make_unique(pool, factory) + ->FieldEquals(lhs, nullptr, rhs, rhs_field); +} + +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, + absl::Nonnull lhs_field, + const google::protobuf::Message& rhs, + absl::Nonnull pool, + absl::Nonnull factory) { + ABSL_DCHECK(lhs_field != nullptr); + ABSL_DCHECK(pool != nullptr); + ABSL_DCHECK(factory != nullptr); + // MessageEqualsState has quite a large size, so we allocate it on the heap. + // Ideally we should just hold most of the state at runtime in something like + // `FlatExpressionEvaluatorState`, so we can avoid allocating this repeatedly. + return std::make_unique(pool, factory) + ->FieldEquals(lhs, lhs_field, rhs, nullptr); +} + +} // namespace cel::internal diff --git a/internal/message_equality.h b/internal/message_equality.h new file mode 100644 index 000000000..948393bed --- /dev/null +++ b/internal/message_equality.h @@ -0,0 +1,55 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_ +#define THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::internal { + +// Tests whether one message is equal to another following CEL equality +// semantics. +absl::StatusOr MessageEquals( + const google::protobuf::Message& lhs, const google::protobuf::Message& rhs, + absl::Nonnull pool, + absl::Nonnull factory); + +// Tests whether one message field is equal to another following CEL equality +// semantics. +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, + absl::Nonnull lhs_field, + const google::protobuf::Message& rhs, + absl::Nonnull rhs_field, + absl::Nonnull pool, + absl::Nonnull factory); +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, const google::protobuf::Message& rhs, + absl::Nonnull rhs_field, + absl::Nonnull pool, + absl::Nonnull factory); +absl::StatusOr MessageFieldEquals( + const google::protobuf::Message& lhs, + absl::Nonnull lhs_field, + const google::protobuf::Message& rhs, + absl::Nonnull pool, + absl::Nonnull factory); + +} // namespace cel::internal + +#endif // THIRD_PARTY_CEL_CPP_INTERNAL_MESSAGE_EQUALITY_H_ diff --git a/internal/message_equality_test.cc b/internal/message_equality_test.cc new file mode 100644 index 000000000..6b43635ed --- /dev/null +++ b/internal/message_equality_test.cc @@ -0,0 +1,1041 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "internal/message_equality.h" + +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "google/protobuf/duration.pb.h" +#include "google/protobuf/struct.pb.h" +#include "google/protobuf/timestamp.pb.h" +#include "google/protobuf/wrappers.pb.h" +#include "absl/base/nullability.h" +#include "absl/log/absl_check.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status_matchers.h" +#include "absl/strings/cord.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/memory.h" +#include "internal/message_type_name.h" +#include "internal/parse_text_proto.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "internal/testing_message_factory.h" +#include "internal/well_known_types.h" +#include "proto/test/v1/proto3/test_all_types.pb.h" +#include "google/protobuf/arena.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace cel::internal { +namespace { + +using ::absl_testing::IsOkAndHolds; +using ::testing::IsFalse; +using ::testing::IsTrue; +using ::testing::TestParamInfo; +using ::testing::TestWithParam; +using ::testing::ValuesIn; + +using TestAllTypesProto3 = ::google::api::expr::test::v1::proto3::TestAllTypes; + +template +Owned ParseTextProto(absl::string_view text) { + return DynamicParseTextProto(NewDeleteAllocator(), text, + GetTestingDescriptorPool(), + GetTestingMessageFactory()); +} + +struct UnaryMessageEqualsTestParam { + std::string name; + std::vector> ops; + bool equal; +}; + +std::string UnaryMessageEqualsTestParamName( + const TestParamInfo& param_info) { + return param_info.param.name; +} + +using UnaryMessageEqualsTest = TestWithParam; + +Owned PackMessage(const google::protobuf::Message& message) { + const auto* descriptor = + ABSL_DIE_IF_NULL(GetTestingDescriptorPool()->FindMessageTypeByName( + MessageTypeNameFor())); + const auto* prototype = + ABSL_DIE_IF_NULL(GetTestingMessageFactory()->GetPrototype(descriptor)); + auto instance = WrapShared(prototype->New(), NewDeleteAllocator()); + auto reflection = well_known_types::GetAnyReflectionOrDie(descriptor); + reflection.SetTypeUrl( + cel::to_address(instance), + absl::StrCat("type.googleapis.com/", message.GetTypeName())); + absl::Cord value; + ABSL_CHECK(message.SerializeToCord(&value)); + reflection.SetValue(cel::to_address(instance), value); + return instance; +} + +TEST_P(UnaryMessageEqualsTest, Equals) { + const auto* pool = GetTestingDescriptorPool(); + auto* factory = GetTestingMessageFactory(); + const auto& test_case = GetParam(); + for (const auto& lhs : test_case.ops) { + for (const auto& rhs : test_case.ops) { + if (!test_case.equal && &lhs == &rhs) { + continue; + } + EXPECT_THAT(MessageEquals(*lhs, *rhs, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs << " " << *rhs; + EXPECT_THAT(MessageEquals(*rhs, *lhs, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs << " " << *rhs; + // Test any. + auto lhs_any = PackMessage(*lhs); + auto rhs_any = PackMessage(*rhs); + EXPECT_THAT(MessageEquals(*lhs_any, *rhs, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_any << " " << *rhs; + EXPECT_THAT(MessageEquals(*lhs, *rhs_any, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs << " " << *rhs_any; + EXPECT_THAT(MessageEquals(*lhs_any, *rhs_any, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_any << " " << *rhs_any; + } + } +} + +INSTANTIATE_TEST_SUITE_P( + UnaryMessageEqualsTest, UnaryMessageEqualsTest, + ValuesIn({ + { + .name = "NullValue_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(null_value: NULL_VALUE)pb"), + }, + .equal = true, + }, + { + .name = "BoolValue_False_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(value: false)pb"), + ParseTextProto( + R"pb(bool_value: false)pb"), + }, + .equal = true, + }, + { + .name = "BoolValue_True_Equal", + .ops = + { + ParseTextProto( + R"pb(value: true)pb"), + ParseTextProto(R"pb(bool_value: + true)pb"), + }, + .equal = true, + }, + { + .name = "StringValue_Empty_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(value: "")pb"), + ParseTextProto( + R"pb(string_value: "")pb"), + }, + .equal = true, + }, + { + .name = "StringValue_Equal", + .ops = + { + ParseTextProto( + R"pb(value: "foo")pb"), + ParseTextProto( + R"pb(string_value: "foo")pb"), + }, + .equal = true, + }, + { + .name = "BytesValue_Empty_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(value: "")pb"), + }, + .equal = true, + }, + { + .name = "BytesValue_Equal", + .ops = + { + ParseTextProto( + R"pb(value: "foo")pb"), + ParseTextProto( + R"pb(value: "foo")pb"), + }, + .equal = true, + }, + { + .name = "ListValue_Equal", + .ops = + { + ParseTextProto( + R"pb(list_value: { values { bool_value: true } })pb"), + ParseTextProto( + R"pb(values { bool_value: true })pb"), + }, + .equal = true, + }, + { + .name = "ListValue_NotEqual", + .ops = + { + ParseTextProto( + R"pb(list_value: { values { number_value: 0.0 } })pb"), + ParseTextProto( + R"pb(values { number_value: 1.0 })pb"), + ParseTextProto( + R"pb(list_value: { values { number_value: 2.0 } })pb"), + ParseTextProto( + R"pb(values { number_value: 3.0 })pb"), + }, + .equal = false, + }, + { + .name = "StructValue_Equal", + .ops = + { + ParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: { bool_value: true } + } + })pb"), + ParseTextProto( + R"pb(fields { + key: "foo" + value: { bool_value: true } + })pb"), + }, + .equal = true, + }, + { + .name = "StructValue_NotEqual", + .ops = + { + ParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: { number_value: 0.0 } + } + })pb"), + ParseTextProto( + R"pb( + fields { + key: "bar" + value: { number_value: 0.0 } + })pb"), + ParseTextProto( + R"pb(struct_value: { + fields { + key: "foo" + value: { number_value: 1.0 } + } + })pb"), + ParseTextProto( + R"pb( + fields { + key: "bar" + value: { number_value: 1.0 } + })pb"), + }, + .equal = false, + }, + { + .name = "Heterogeneous_Equal", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb(number_value: + 0.0)pb"), + }, + .equal = true, + }, + { + .name = "Message_Equals", + .ops = + { + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb()pb"), + }, + .equal = true, + }, + { + .name = "Heterogeneous_NotEqual", + .ops = + { + ParseTextProto( + R"pb(value: false)pb"), + ParseTextProto( + R"pb(value: 0)pb"), + ParseTextProto( + R"pb(value: 1)pb"), + ParseTextProto( + R"pb(value: 2)pb"), + ParseTextProto( + R"pb(value: 3)pb"), + ParseTextProto( + R"pb(value: 4.0)pb"), + ParseTextProto( + R"pb(value: 5.0)pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto(R"pb(bool_value: + true)pb"), + ParseTextProto(R"pb(number_value: + 6.0)pb"), + ParseTextProto( + R"pb(string_value: "bar")pb"), + ParseTextProto( + R"pb(value: "foo")pb"), + ParseTextProto( + R"pb(value: "")pb"), + ParseTextProto( + R"pb(value: "foo")pb"), + ParseTextProto( + R"pb(list_value: {})pb"), + ParseTextProto( + R"pb(values { bool_value: true })pb"), + ParseTextProto(R"pb(struct_value: + {})pb"), + ParseTextProto( + R"pb(fields { + key: "foo" + value: { bool_value: false } + })pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(seconds: 1 nanos: 1)pb"), + ParseTextProto(R"pb()pb"), + ParseTextProto( + R"pb(single_bool: true)pb"), + }, + .equal = false, + }, + }), + UnaryMessageEqualsTestParamName); + +struct UnaryMessageFieldEqualsTestParam { + std::string name; + std::string message; + std::vector fields; + bool equal; +}; + +std::string UnaryMessageFieldEqualsTestParamName( + const TestParamInfo& param_info) { + return param_info.param.name; +} + +using UnaryMessageFieldEqualsTest = + TestWithParam; + +void PackMessageTo(const google::protobuf::Message& message, google::protobuf::Message* instance) { + auto reflection = + *well_known_types::GetAnyReflection(instance->GetDescriptor()); + reflection.SetTypeUrl( + instance, absl::StrCat("type.googleapis.com/", message.GetTypeName())); + absl::Cord value; + ABSL_CHECK(message.SerializeToCord(&value)); + reflection.SetValue(instance, value); +} + +absl::optional, + absl::Nonnull>> +PackTestAllTypesProto3Field( + const google::protobuf::Message& message, + absl::Nonnull field) { + if (field->is_map()) { + return absl::nullopt; + } + if (field->is_repeated() && + field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + const auto* descriptor = message.GetDescriptor(); + const auto* any_field = descriptor->FindFieldByName("repeated_any"); + auto packed = WrapShared(message.New(), NewDeleteAllocator()); + const int size = message.GetReflection()->FieldSize(message, field); + for (int i = 0; i < size; ++i) { + PackMessageTo( + message.GetReflection()->GetRepeatedMessage(message, field, i), + packed->GetReflection()->AddMessage(cel::to_address(packed), + any_field)); + } + return std::pair{packed, any_field}; + } + if (!field->is_repeated() && + field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + const auto* descriptor = message.GetDescriptor(); + const auto* any_field = descriptor->FindFieldByName("single_any"); + auto packed = WrapShared(message.New(), NewDeleteAllocator()); + PackMessageTo(message.GetReflection()->GetMessage(message, field), + packed->GetReflection()->MutableMessage( + cel::to_address(packed), any_field)); + return std::pair{packed, any_field}; + } + return absl::nullopt; +} + +TEST_P(UnaryMessageFieldEqualsTest, Equals) { + // We perform exhaustive comparison by testing for equality (or inequality) + // against all combinations of fields. Additionally we convert to + // `google.protobuf.Any` where applicable. This is all done for coverage and + // to ensure different combinations, regardless of argument order, produce the + // same result. + + const auto* pool = GetTestingDescriptorPool(); + auto* factory = GetTestingMessageFactory(); + const auto& test_case = GetParam(); + auto lhs_message = ParseTextProto(test_case.message); + auto rhs_message = ParseTextProto(test_case.message); + const auto* descriptor = ABSL_DIE_IF_NULL( + pool->FindMessageTypeByName(MessageTypeNameFor())); + for (const auto& lhs : test_case.fields) { + for (const auto& rhs : test_case.fields) { + if (!test_case.equal && lhs == rhs) { + // When testing for inequality, do not compare the same field to itself. + continue; + } + const auto* lhs_field = + ABSL_DIE_IF_NULL(descriptor->FindFieldByName(lhs)); + const auto* rhs_field = + ABSL_DIE_IF_NULL(descriptor->FindFieldByName(rhs)); + EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_message, + rhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_message << " " << lhs_field->name() << " " << *rhs_message + << " " << rhs_field->name(); + EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, *lhs_message, + lhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_message << " " << lhs_field->name() << " " << *rhs_message + << " " << rhs_field->name(); + if (!lhs_field->is_repeated() && + lhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + EXPECT_THAT(MessageFieldEquals(lhs_message->GetReflection()->GetMessage( + *lhs_message, lhs_field), + *rhs_message, rhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_message << " " << lhs_field->name() << " " << *rhs_message + << " " << rhs_field->name(); + EXPECT_THAT(MessageFieldEquals(*rhs_message, rhs_field, + lhs_message->GetReflection()->GetMessage( + *lhs_message, lhs_field), + pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_message << " " << lhs_field->name() << " " << *rhs_message + << " " << rhs_field->name(); + } + if (!rhs_field->is_repeated() && + rhs_field->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE) { + EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, + rhs_message->GetReflection()->GetMessage( + *rhs_message, rhs_field), + pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_message << " " << lhs_field->name() << " " << *rhs_message + << " " << rhs_field->name(); + EXPECT_THAT(MessageFieldEquals(rhs_message->GetReflection()->GetMessage( + *rhs_message, rhs_field), + *lhs_message, lhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_message << " " << lhs_field->name() << " " << *rhs_message + << " " << rhs_field->name(); + } + // Test `google.protobuf.Any`. + absl::optional, + absl::Nonnull>> + lhs_any = PackTestAllTypesProto3Field(*lhs_message, lhs_field); + absl::optional, + absl::Nonnull>> + rhs_any = PackTestAllTypesProto3Field(*rhs_message, rhs_field); + if (lhs_any) { + EXPECT_THAT(MessageFieldEquals(*lhs_any->first, lhs_any->second, + *rhs_message, rhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_any->first << " " << *rhs_message; + if (!lhs_any->second->is_repeated()) { + EXPECT_THAT( + MessageFieldEquals(lhs_any->first->GetReflection()->GetMessage( + *lhs_any->first, lhs_any->second), + *rhs_message, rhs_field, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_any->first << " " << *rhs_message; + } + } + if (rhs_any) { + EXPECT_THAT(MessageFieldEquals(*lhs_message, lhs_field, *rhs_any->first, + rhs_any->second, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_message << " " << *rhs_any->first; + if (!rhs_any->second->is_repeated()) { + EXPECT_THAT( + MessageFieldEquals(*lhs_message, lhs_field, + rhs_any->first->GetReflection()->GetMessage( + *rhs_any->first, rhs_any->second), + pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_message << " " << *rhs_any->first; + } + } + if (lhs_any && rhs_any) { + EXPECT_THAT( + MessageFieldEquals(*lhs_any->first, lhs_any->second, + *rhs_any->first, rhs_any->second, pool, factory), + IsOkAndHolds(test_case.equal)) + << *lhs_any->first << " " << *rhs_any->second; + } + } + } +} + +INSTANTIATE_TEST_SUITE_P( + UnaryMessageFieldEqualsTest, UnaryMessageFieldEqualsTest, + ValuesIn({ + { + .name = "Heterogeneous_Single_Equal", + .message = R"pb( + single_int32: 1 + single_int64: 1 + single_uint32: 1 + single_uint64: 1 + single_float: 1 + single_double: 1 + single_value: { number_value: 1 } + single_int32_wrapper: { value: 1 } + single_int64_wrapper: { value: 1 } + single_uint32_wrapper: { value: 1 } + single_uint64_wrapper: { value: 1 } + single_float_wrapper: { value: 1 } + single_double_wrapper: { value: 1 } + standalone_enum: BAR + )pb", + .fields = + { + "single_int32", + "single_int64", + "single_uint32", + "single_uint64", + "single_float", + "single_double", + "single_value", + "single_int32_wrapper", + "single_int64_wrapper", + "single_uint32_wrapper", + "single_uint64_wrapper", + "single_float_wrapper", + "single_double_wrapper", + "standalone_enum", + }, + .equal = true, + }, + { + .name = "Heterogeneous_Single_NotEqual", + .message = R"pb( + null_value: NULL_VALUE + single_bool: false + single_int32: 2 + single_int64: 3 + single_uint32: 4 + single_uint64: 5 + single_float: NaN + single_double: NaN + single_string: "foo" + single_bytes: "foo" + single_value: { number_value: 8 } + single_int32_wrapper: { value: 9 } + single_int64_wrapper: { value: 10 } + single_uint32_wrapper: { value: 11 } + single_uint64_wrapper: { value: 12 } + single_float_wrapper: { value: 13 } + single_double_wrapper: { value: 14 } + single_string_wrapper: { value: "bar" } + single_bytes_wrapper: { value: "bar" } + standalone_enum: BAR + )pb", + .fields = + { + "null_value", + "single_bool", + "single_int32", + "single_int64", + "single_uint32", + "single_uint64", + "single_float", + "single_double", + "single_string", + "single_bytes", + "single_value", + "single_int32_wrapper", + "single_int64_wrapper", + "single_uint32_wrapper", + "single_uint64_wrapper", + "single_float_wrapper", + "single_double_wrapper", + "standalone_enum", + }, + .equal = false, + }, + { + .name = "Heterogeneous_Repeated_Equal", + .message = R"pb( + repeated_int32: 1 + repeated_int64: 1 + repeated_uint32: 1 + repeated_uint64: 1 + repeated_float: 1 + repeated_double: 1 + repeated_value: { number_value: 1 } + repeated_int32_wrapper: { value: 1 } + repeated_int64_wrapper: { value: 1 } + repeated_uint32_wrapper: { value: 1 } + repeated_uint64_wrapper: { value: 1 } + repeated_float_wrapper: { value: 1 } + repeated_double_wrapper: { value: 1 } + repeated_nested_enum: BAR + single_value: { list_value: { values { number_value: 1 } } } + list_value: { values { number_value: 1 } } + )pb", + .fields = + { + "repeated_int32", + "repeated_int64", + "repeated_uint32", + "repeated_uint64", + "repeated_float", + "repeated_double", + "repeated_value", + "repeated_int32_wrapper", + "repeated_int64_wrapper", + "repeated_uint32_wrapper", + "repeated_uint64_wrapper", + "repeated_float_wrapper", + "repeated_double_wrapper", + "repeated_nested_enum", + "single_value", + "list_value", + }, + .equal = true, + }, + { + .name = "Heterogeneous_Repeated_NotEqual", + .message = R"pb( + repeated_null_value: NULL_VALUE + repeated_bool: false + repeated_int32: 2 + repeated_int64: 3 + repeated_uint32: 4 + repeated_uint64: 5 + repeated_float: 6 + repeated_double: 7 + repeated_string: "foo" + repeated_bytes: "foo" + repeated_value: { number_value: 8 } + repeated_int32_wrapper: { value: 9 } + repeated_int64_wrapper: { value: 10 } + repeated_uint32_wrapper: { value: 11 } + repeated_uint64_wrapper: { value: 12 } + repeated_float_wrapper: { value: 13 } + repeated_double_wrapper: { value: 14 } + repeated_string_wrapper: { value: "bar" } + repeated_bytes_wrapper: { value: "bar" } + repeated_nested_enum: BAR + )pb", + .fields = + { + "repeated_null_value", + "repeated_bool", + "repeated_int32", + "repeated_int64", + "repeated_uint32", + "repeated_uint64", + "repeated_float", + "repeated_double", + "repeated_string", + "repeated_bytes", + "repeated_value", + "repeated_int32_wrapper", + "repeated_int64_wrapper", + "repeated_uint32_wrapper", + "repeated_uint64_wrapper", + "repeated_float_wrapper", + "repeated_double_wrapper", + "repeated_nested_enum", + }, + .equal = false, + }, + { + .name = "Heterogeneous_Map_Equal", + .message = R"pb( + map_int32_int32 { key: 1 value: 1 } + map_int32_uint32 { key: 1 value: 1 } + map_int32_int64 { key: 1 value: 1 } + map_int32_uint64 { key: 1 value: 1 } + map_int32_float { key: 1 value: 1 } + map_int32_double { key: 1 value: 1 } + map_int32_enum { key: 1 value: BAR } + map_int32_value { + key: 1 + value: { number_value: 1 } + } + map_int32_int32_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_uint32_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_int64_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_uint64_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_float_wrapper { + key: 1 + value: { value: 1 } + } + map_int32_double_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_int32 { key: 1 value: 1 } + map_int64_uint32 { key: 1 value: 1 } + map_int64_int64 { key: 1 value: 1 } + map_int64_uint64 { key: 1 value: 1 } + map_int64_float { key: 1 value: 1 } + map_int64_double { key: 1 value: 1 } + map_int64_enum { key: 1 value: BAR } + map_int64_value { + key: 1 + value: { number_value: 1 } + } + map_int64_int32_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_uint32_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_int64_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_uint64_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_float_wrapper { + key: 1 + value: { value: 1 } + } + map_int64_double_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_int32 { key: 1 value: 1 } + map_uint32_uint32 { key: 1 value: 1 } + map_uint32_int64 { key: 1 value: 1 } + map_uint32_uint64 { key: 1 value: 1 } + map_uint32_float { key: 1 value: 1 } + map_uint32_double { key: 1 value: 1 } + map_uint32_enum { key: 1 value: BAR } + map_uint32_value { + key: 1 + value: { number_value: 1 } + } + map_uint32_int32_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_uint32_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_int64_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_uint64_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_float_wrapper { + key: 1 + value: { value: 1 } + } + map_uint32_double_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_int32 { key: 1 value: 1 } + map_uint64_uint32 { key: 1 value: 1 } + map_uint64_int64 { key: 1 value: 1 } + map_uint64_uint64 { key: 1 value: 1 } + map_uint64_float { key: 1 value: 1 } + map_uint64_double { key: 1 value: 1 } + map_uint64_enum { key: 1 value: BAR } + map_uint64_value { + key: 1 + value: { number_value: 1 } + } + map_uint64_int32_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_uint32_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_int64_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_uint64_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_float_wrapper { + key: 1 + value: { value: 1 } + } + map_uint64_double_wrapper { + key: 1 + value: { value: 1 } + } + )pb", + .fields = + { + "map_int32_int32", "map_int32_uint32", + "map_int32_int64", "map_int32_uint64", + "map_int32_float", "map_int32_double", + "map_int32_enum", "map_int32_value", + "map_int32_int32_wrapper", "map_int32_uint32_wrapper", + "map_int32_int64_wrapper", "map_int32_uint64_wrapper", + "map_int32_float_wrapper", "map_int32_double_wrapper", + "map_int64_int32", "map_int64_uint32", + "map_int64_int64", "map_int64_uint64", + "map_int64_float", "map_int64_double", + "map_int64_enum", "map_int64_value", + "map_int64_int32_wrapper", "map_int64_uint32_wrapper", + "map_int64_int64_wrapper", "map_int64_uint64_wrapper", + "map_int64_float_wrapper", "map_int64_double_wrapper", + "map_uint32_int32", "map_uint32_uint32", + "map_uint32_int64", "map_uint32_uint64", + "map_uint32_float", "map_uint32_double", + "map_uint32_enum", "map_uint32_value", + "map_uint32_int32_wrapper", "map_uint32_uint32_wrapper", + "map_uint32_int64_wrapper", "map_uint32_uint64_wrapper", + "map_uint32_float_wrapper", "map_uint32_double_wrapper", + "map_uint64_int32", "map_uint64_uint32", + "map_uint64_int64", "map_uint64_uint64", + "map_uint64_float", "map_uint64_double", + "map_uint64_enum", "map_uint64_value", + "map_uint64_int32_wrapper", "map_uint64_uint32_wrapper", + "map_uint64_int64_wrapper", "map_uint64_uint64_wrapper", + "map_uint64_float_wrapper", "map_uint64_double_wrapper", + }, + .equal = true, + }, + { + .name = "Heterogeneous_Map_NotEqual", + .message = R"pb( + map_bool_bool { key: false value: false } + map_bool_int32 { key: false value: 1 } + map_bool_uint32 { key: false value: 0 } + map_int32_int32 { key: 0x7FFFFFFF value: 1 } + map_int64_int64 { key: 0x7FFFFFFFFFFFFFFF value: 1 } + map_uint32_uint32 { key: 0xFFFFFFFF value: 1 } + map_uint64_uint64 { key: 0xFFFFFFFFFFFFFFFF value: 1 } + map_string_string { key: "foo" value: "bar" } + map_string_bytes { key: "foo" value: "bar" } + map_int32_bytes { key: -2147483648 value: "bar" } + map_int64_bytes { key: -9223372036854775808 value: "bar" } + map_int32_float { key: -2147483648 value: 1 } + map_int64_double { key: -9223372036854775808 value: 1 } + map_uint32_string { key: 0xFFFFFFFF value: "bar" } + map_uint64_string { key: 0xFFFFFFFF value: "foo" } + map_uint32_bytes { key: 0xFFFFFFFF value: "bar" } + map_uint64_bytes { key: 0xFFFFFFFF value: "foo" } + map_uint32_bool { key: 0xFFFFFFFF value: false } + map_uint64_bool { key: 0xFFFFFFFF value: true } + single_value: { + struct_value: { + fields { + key: "bar" + value: { string_value: "foo" } + } + } + } + single_struct: { + fields { + key: "baz" + value: { string_value: "foo" } + } + } + standalone_message: {} + )pb", + .fields = + { + "map_bool_bool", "map_bool_int32", + "map_bool_uint32", "map_int32_int32", + "map_int64_int64", "map_uint32_uint32", + "map_uint64_uint64", "map_string_string", + "map_string_bytes", "map_int32_bytes", + "map_int64_bytes", "map_int32_float", + "map_int64_double", "map_uint32_string", + "map_uint64_string", "map_uint32_bytes", + "map_uint64_bytes", "map_uint32_bool", + "map_uint64_bool", "single_value", + "single_struct", "standalone_message", + }, + .equal = false, + }, + }), + UnaryMessageFieldEqualsTestParamName); + +TEST(MessageEquals, AnyFallback) { + const auto* pool = GetTestingDescriptorPool(); + auto* factory = GetTestingMessageFactory(); + google::protobuf::Arena arena; + auto message1 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "foo" + })pb", + pool, factory); + auto message2 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "foo" + })pb", + pool, factory); + auto message3 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "bar" + })pb", + pool, factory); + EXPECT_THAT(MessageEquals(*message1, *message2, pool, factory), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(MessageEquals(*message2, *message1, pool, factory), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(MessageEquals(*message1, *message3, pool, factory), + IsOkAndHolds(IsFalse())); + EXPECT_THAT(MessageEquals(*message3, *message1, pool, factory), + IsOkAndHolds(IsFalse())); +} + +TEST(MessageFieldEquals, AnyFallback) { + const auto* pool = GetTestingDescriptorPool(); + auto* factory = GetTestingMessageFactory(); + google::protobuf::Arena arena; + auto message1 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "foo" + })pb", + pool, factory); + auto message2 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "foo" + })pb", + pool, factory); + auto message3 = DynamicParseTextProto( + &arena, R"pb(single_any: { + type_url: "type.googleapis.com/message.that.does.not.Exist" + value: "bar" + })pb", + pool, factory); + EXPECT_THAT(MessageFieldEquals( + *message1, + ABSL_DIE_IF_NULL( + message1->GetDescriptor()->FindFieldByName("single_any")), + *message2, + ABSL_DIE_IF_NULL( + message2->GetDescriptor()->FindFieldByName("single_any")), + pool, factory), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(MessageFieldEquals( + *message2, + ABSL_DIE_IF_NULL( + message2->GetDescriptor()->FindFieldByName("single_any")), + *message1, + ABSL_DIE_IF_NULL( + message1->GetDescriptor()->FindFieldByName("single_any")), + pool, factory), + IsOkAndHolds(IsTrue())); + EXPECT_THAT(MessageFieldEquals( + *message1, + ABSL_DIE_IF_NULL( + message1->GetDescriptor()->FindFieldByName("single_any")), + *message3, + ABSL_DIE_IF_NULL( + message3->GetDescriptor()->FindFieldByName("single_any")), + pool, factory), + IsOkAndHolds(IsFalse())); + EXPECT_THAT(MessageFieldEquals( + *message3, + ABSL_DIE_IF_NULL( + message3->GetDescriptor()->FindFieldByName("single_any")), + *message1, + ABSL_DIE_IF_NULL( + message1->GetDescriptor()->FindFieldByName("single_any")), + pool, factory), + IsOkAndHolds(IsFalse())); +} + +} // namespace +} // namespace cel::internal diff --git a/internal/well_known_types.cc b/internal/well_known_types.cc index 8fe0fef6d..5939c451b 100644 --- a/internal/well_known_types.cc +++ b/internal/well_known_types.cc @@ -365,6 +365,21 @@ absl::Status CheckMapField(absl::Nonnull field) { } // namespace +bool StringValue::ConsumePrefix(absl::string_view prefix) { + return absl::visit(absl::Overload( + [&](absl::string_view& value) { + return absl::ConsumePrefix(&value, prefix); + }, + [&](absl::Cord& cord) { + if (cord.StartsWith(prefix)) { + cord.RemovePrefix(prefix.size()); + return true; + } + return false; + }), + AsVariant(*this)); +} + StringValue GetStringField(absl::Nonnull reflection, const google::protobuf::Message& message, absl::Nonnull field, @@ -926,6 +941,13 @@ absl::StatusOr GetAnyReflection( return reflection; } +AnyReflection GetAnyReflectionOrDie( + absl::Nonnull descriptor) { + AnyReflection reflection; + ABSL_CHECK_OK(reflection.Initialize(descriptor)); // Crash OK + return reflection; +} + absl::Status DurationReflection::Initialize( absl::Nonnull pool) { CEL_ASSIGN_OR_RETURN(const auto* descriptor, @@ -979,6 +1001,29 @@ void DurationReflection::SetNanos(absl::Nonnull mess message->GetReflection()->SetInt32(message, nanos_field_, value); } +absl::StatusOr DurationReflection::ToAbslDuration( + const google::protobuf::Message& message) const { + ABSL_DCHECK(IsInitialized()); + ABSL_DCHECK_EQ(message.GetDescriptor(), descriptor_); + int64_t seconds = GetSeconds(message); + if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kDurationMinSeconds || + seconds > TimeUtil::kDurationMaxSeconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid duration seconds: ", seconds)); + } + int32_t nanos = GetNanos(message); + if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kDurationMinNanoseconds || + nanos > TimeUtil::kDurationMaxNanoseconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid duration nanoseconds: ", nanos)); + } + if ((seconds < 0 && nanos > 0) || (seconds > 0 && nanos < 0)) { + return absl::InvalidArgumentError(absl::StrCat( + "duration sign mismatch: seconds=", seconds, ", nanoseconds=", nanos)); + } + return absl::Seconds(seconds) + absl::Nanoseconds(nanos); +} + absl::StatusOr GetDurationReflection( absl::Nonnull descriptor) { DurationReflection reflection; @@ -1039,6 +1084,23 @@ void TimestampReflection::SetNanos(absl::Nonnull mes message->GetReflection()->SetInt32(message, nanos_field_, value); } +absl::StatusOr TimestampReflection::ToAbslTime( + const google::protobuf::Message& message) const { + int64_t seconds = GetSeconds(message); + if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kTimestampMinSeconds || + seconds > TimeUtil::kTimestampMaxSeconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid timestamp seconds: ", seconds)); + } + int32_t nanos = GetNanos(message); + if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kTimestampMinNanoseconds || + nanos > TimeUtil::kTimestampMaxNanoseconds)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid timestamp nanoseconds: ", nanos)); + } + return absl::UnixEpoch() + absl::Seconds(seconds) + absl::Nanoseconds(nanos); +} + absl::StatusOr GetTimestampReflection( absl::Nonnull descriptor) { TimestampReflection reflection; @@ -1591,7 +1653,8 @@ absl::StatusOr> AdaptAny( absl::Nullable arena, AnyReflection& reflection, const google::protobuf::Message& message, absl::Nonnull descriptor, absl::Nonnull pool, - absl::Nonnull factory) { + absl::Nonnull factory, + bool error_if_unresolveable) { ABSL_DCHECK_EQ(descriptor->well_known_type(), Descriptor::WELLKNOWNTYPE_ANY); absl::Nonnull to_unwrap = &message; Unique unwrapped; @@ -1604,11 +1667,17 @@ absl::StatusOr> AdaptAny( FlatStringValue(type_url, type_url_scratch); if (!absl::ConsumePrefix(&type_url_view, "type.googleapis.com/") && !absl::ConsumePrefix(&type_url_view, "type.googleprod.com/")) { + if (!error_if_unresolveable) { + break; + } return absl::InvalidArgumentError(absl::StrCat( "unable to find descriptor for type URL: ", type_url_view)); } const auto* packed_descriptor = pool->FindMessageTypeByName(type_url_view); if (packed_descriptor == nullptr) { + if (!error_if_unresolveable) { + break; + } return absl::InvalidArgumentError(absl::StrCat( "unable to find descriptor for type name: ", type_url_view)); } @@ -1655,7 +1724,18 @@ absl::StatusOr> UnpackAnyFrom( ABSL_DCHECK_EQ(message.GetDescriptor()->well_known_type(), Descriptor::WELLKNOWNTYPE_ANY); return AdaptAny(arena, reflection, message, message.GetDescriptor(), pool, - factory); + factory, /*error_if_unresolveable=*/true); +} + +absl::StatusOr> UnpackAnyIfResolveable( + absl::Nullable arena, AnyReflection& reflection, + const google::protobuf::Message& message, + absl::Nonnull pool, + absl::Nonnull factory) { + ABSL_DCHECK_EQ(message.GetDescriptor()->well_known_type(), + Descriptor::WELLKNOWNTYPE_ANY); + return AdaptAny(arena, reflection, message, message.GetDescriptor(), pool, + factory, /*error_if_unresolveable=*/false); } absl::StatusOr AdaptFromMessage( @@ -1744,41 +1824,11 @@ absl::StatusOr AdaptFromMessage( ABSL_UNREACHABLE(); case Descriptor::WELLKNOWNTYPE_DURATION: { CEL_ASSIGN_OR_RETURN(auto reflection, GetDurationReflection(descriptor)); - int64_t seconds = reflection.GetSeconds(*to_adapt); - if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kDurationMinSeconds || - seconds > TimeUtil::kDurationMaxSeconds)) { - return absl::InvalidArgumentError( - absl::StrCat("invalid duration seconds: ", seconds)); - } - int32_t nanos = reflection.GetNanos(*to_adapt); - if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kDurationMinNanoseconds || - nanos > TimeUtil::kDurationMaxNanoseconds)) { - return absl::InvalidArgumentError( - absl::StrCat("invalid duration nanoseconds: ", nanos)); - } - if ((seconds < 0 && nanos > 0) || (seconds > 0 && nanos < 0)) { - return absl::InvalidArgumentError( - absl::StrCat("duration sign mismatch: seconds=", seconds, - ", nanoseconds=", nanos)); - } - return absl::Seconds(seconds) + absl::Nanoseconds(nanos); + return reflection.ToAbslDuration(*to_adapt); } case Descriptor::WELLKNOWNTYPE_TIMESTAMP: { CEL_ASSIGN_OR_RETURN(auto reflection, GetTimestampReflection(descriptor)); - int64_t seconds = reflection.GetSeconds(*to_adapt); - if (ABSL_PREDICT_FALSE(seconds < TimeUtil::kTimestampMinSeconds || - seconds > TimeUtil::kTimestampMaxSeconds)) { - return absl::InvalidArgumentError( - absl::StrCat("invalid timestamp seconds: ", seconds)); - } - int32_t nanos = reflection.GetNanos(*to_adapt); - if (ABSL_PREDICT_FALSE(nanos < TimeUtil::kTimestampMinNanoseconds || - nanos > TimeUtil::kTimestampMaxNanoseconds)) { - return absl::InvalidArgumentError( - absl::StrCat("invalid timestamp nanoseconds: ", nanos)); - } - return absl::UnixEpoch() + absl::Seconds(seconds) + - absl::Nanoseconds(nanos); + return reflection.ToAbslTime(*to_adapt); } case Descriptor::WELLKNOWNTYPE_VALUE: { CEL_ASSIGN_OR_RETURN(auto reflection, GetValueReflection(descriptor)); diff --git a/internal/well_known_types.h b/internal/well_known_types.h index 96dfcbd62..f35d27849 100644 --- a/internal/well_known_types.h +++ b/internal/well_known_types.h @@ -65,6 +65,8 @@ namespace cel::well_known_types { class StringValue final : public absl::variant { public: using absl::variant::variant; + + bool ConsumePrefix(absl::string_view prefix); }; // Older versions of GCC do not deal with inheriting from variant correctly when @@ -647,6 +649,10 @@ absl::StatusOr GetAnyReflection( absl::Nonnull descriptor ABSL_ATTRIBUTE_LIFETIME_BOUND); +AnyReflection GetAnyReflectionOrDie( + absl::Nonnull descriptor + ABSL_ATTRIBUTE_LIFETIME_BOUND); + class DurationReflection final { public: static constexpr google::protobuf::Descriptor::WellKnownType kWellKnownType = @@ -690,6 +696,9 @@ class DurationReflection final { void SetNanos(absl::Nonnull message, int32_t value) const; + absl::StatusOr ToAbslDuration( + const google::protobuf::Message& message) const; + private: absl::Nullable descriptor_ = nullptr; absl::Nullable seconds_field_ = nullptr; @@ -743,6 +752,8 @@ class TimestampReflection final { void SetNanos(absl::Nonnull message, int32_t value) const; + absl::StatusOr ToAbslTime(const google::protobuf::Message& message) const; + private: absl::Nullable descriptor_ = nullptr; absl::Nullable seconds_field_ = nullptr; @@ -982,6 +993,11 @@ class ListValueReflection final { return values_field_->message_type(); } + absl::Nonnull GetValuesDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return values_field_; + } + int ValuesSize(const google::protobuf::Message& message) const; google::protobuf::RepeatedFieldRef Values( @@ -1081,6 +1097,11 @@ class StructReflection final { return fields_value_field_->message_type(); } + absl::Nonnull GetFieldsDescriptor() const { + ABSL_DCHECK(IsInitialized()); + return fields_field_; + } + int FieldsSize(const google::protobuf::Message& message) const; google::protobuf::MapIterator BeginFields( @@ -1241,6 +1262,15 @@ absl::StatusOr> UnpackAnyFrom( absl::Nonnull factory ABSL_ATTRIBUTE_LIFETIME_BOUND); +// Unpacks the given instance of `google.protobuf.Any` if it is resolvable. +absl::StatusOr> UnpackAnyIfResolveable( + absl::Nullable arena ABSL_ATTRIBUTE_LIFETIME_BOUND, + AnyReflection& reflection, const google::protobuf::Message& message, + absl::Nonnull pool + ABSL_ATTRIBUTE_LIFETIME_BOUND, + absl::Nonnull factory + ABSL_ATTRIBUTE_LIFETIME_BOUND); + // Performs any necessary unwrapping of a well known message type. If no // unwrapping is necessary, the resulting `Value` holds the alternative // `absl::monostate`.