Skip to content

Commit

Permalink
Make StructType a composed type
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 658047577
  • Loading branch information
jcking authored and copybara-github committed Aug 1, 2024
1 parent 761798e commit d10acf9
Show file tree
Hide file tree
Showing 43 changed files with 1,027 additions and 287 deletions.
9 changes: 9 additions & 0 deletions common/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ cc_library(
"types/*_test.cc",
],
) + [
"type.cc",
"type_factory.cc",
"type_introspector.cc",
"type_manager.cc",
Expand Down Expand Up @@ -554,6 +555,7 @@ cc_library(
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/functional:function_ref",
"@com_google_absl//absl/functional:overload",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/meta:type_traits",
Expand All @@ -564,6 +566,7 @@ cc_library(
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
"@com_google_absl//absl/types:variant",
"@com_google_protobuf//:protobuf",
],
)

Expand All @@ -581,10 +584,15 @@ cc_test(
":memory_testing",
":native_type",
":type",
":type_kind",
"//internal:testing",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/hash:hash_testing",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:die_if_null",
"@com_google_absl//absl/types:optional",
"@com_google_protobuf//:protobuf",
],
)

Expand Down Expand Up @@ -631,6 +639,7 @@ cc_library(
":unknown",
":value_kind",
"//base:attributes",
"//base/internal:message_wrapper",
"//common/internal:arena_string",
"//common/internal:data_interface",
"//common/internal:reference_count",
Expand Down
6 changes: 2 additions & 4 deletions common/legacy_value.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1198,8 +1198,7 @@ absl::Status ModernValue(google::protobuf::Arena* arena,
result = TypeValue{TypeType{}};
return absl::OkStatus();
}
result = TypeValue{
StructType{extensions::ProtoMemoryManagerRef(arena), type_name}};
result = TypeValue{common_internal::MakeBasicStructType(type_name)};
return absl::OkStatus();
}
case CelValue::Type::kError:
Expand Down Expand Up @@ -1642,8 +1641,7 @@ TypeValue CreateTypeValueFromView(google::protobuf::Arena* arena,
return TypeValue{TypeType{}};
}
// This is bad, but technically OK since we are using an arena.
return TypeValue{
StructType{extensions::ProtoMemoryManagerRef(arena), type_name}};
return TypeValue{common_internal::MakeBasicStructType(type_name)};
}

bool TestOnly_IsLegacyListBuilder(const ListValueBuilder& builder) {
Expand Down
123 changes: 123 additions & 0 deletions common/type.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "common/type.h"

#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "common/types/type_cache.h"
#include "common/types/types.h"
#include "google/protobuf/descriptor.h"

namespace cel {

using ::google::protobuf::Descriptor;

Type Type::Message(absl::Nonnull<const Descriptor*> descriptor) {
switch (descriptor->well_known_type()) {
case Descriptor::WELLKNOWNTYPE_BOOLVALUE:
return BoolWrapperType();
case Descriptor::WELLKNOWNTYPE_INT32VALUE:
ABSL_FALLTHROUGH_INTENDED;
case Descriptor::WELLKNOWNTYPE_INT64VALUE:
return IntWrapperType();
case Descriptor::WELLKNOWNTYPE_UINT32VALUE:
ABSL_FALLTHROUGH_INTENDED;
case Descriptor::WELLKNOWNTYPE_UINT64VALUE:
return UintWrapperType();
case Descriptor::WELLKNOWNTYPE_FLOATVALUE:
ABSL_FALLTHROUGH_INTENDED;
case Descriptor::WELLKNOWNTYPE_DOUBLEVALUE:
return DoubleWrapperType();
case Descriptor::WELLKNOWNTYPE_BYTESVALUE:
return BytesWrapperType();
case Descriptor::WELLKNOWNTYPE_STRINGVALUE:
return StringWrapperType();
case Descriptor::WELLKNOWNTYPE_ANY:
return AnyType();
case Descriptor::WELLKNOWNTYPE_DURATION:
return DurationType();
case Descriptor::WELLKNOWNTYPE_TIMESTAMP:
return TimestampType();
case Descriptor::WELLKNOWNTYPE_VALUE:
return DynType();
case Descriptor::WELLKNOWNTYPE_LISTVALUE:
return ListType();
case Descriptor::WELLKNOWNTYPE_STRUCT:
return common_internal::ProcessLocalTypeCache::Get()
->GetStringDynMapType();
default:
return MessageType(descriptor);
}
}

common_internal::StructTypeVariant Type::ToStructTypeVariant() const {
if (const auto* other = absl::get_if<MessageType>(&variant_);
other != nullptr) {
return common_internal::StructTypeVariant(*other);
}
if (const auto* other =
absl::get_if<common_internal::BasicStructType>(&variant_);
other != nullptr) {
return common_internal::StructTypeVariant(*other);
}
return common_internal::StructTypeVariant();
}

absl::optional<StructType> Type::AsStruct() const {
if (const auto* alt =
absl::get_if<common_internal::BasicStructType>(&variant_);
alt != nullptr) {
return *alt;
}
if (const auto* alt = absl::get_if<MessageType>(&variant_); alt != nullptr) {
return *alt;
}
return absl::nullopt;
}

absl::optional<MessageType> Type::AsMessage() const {
if (const auto* alt = absl::get_if<MessageType>(&variant_); alt != nullptr) {
return *alt;
}
return absl::nullopt;
}

Type::operator StructType() const {
ABSL_DCHECK(IsStruct()) << DebugString();
if (const auto* alt =
absl::get_if<common_internal::BasicStructType>(&variant_);
alt != nullptr) {
return *alt;
}
if (const auto* alt = absl::get_if<MessageType>(&variant_); alt != nullptr) {
return *alt;
}
return StructType();
}

Type::operator MessageType() const {
ABSL_DCHECK(IsMessage()) << DebugString();
if (const auto* alt = absl::get_if<MessageType>(&variant_);
ABSL_PREDICT_TRUE(alt != nullptr)) {
return *alt;
}
return MessageType();
}

} // namespace cel
107 changes: 88 additions & 19 deletions common/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@

#include "absl/algorithm/container.h"
#include "absl/base/attributes.h"
#include "absl/base/nullability.h"
#include "absl/container/fixed_array.h"
#include "absl/log/absl_check.h"
#include "absl/meta/type_traits.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "common/casting.h"
Expand All @@ -51,6 +53,7 @@
#include "common/types/int_wrapper_type.h" // IWYU pragma: export
#include "common/types/list_type.h" // IWYU pragma: export
#include "common/types/map_type.h" // IWYU pragma: export
#include "common/types/message_type.h" // IWYU pragma: export
#include "common/types/null_type.h" // IWYU pragma: export
#include "common/types/opaque_type.h" // IWYU pragma: export
#include "common/types/optional_type.h" // IWYU pragma: export
Expand All @@ -64,6 +67,7 @@
#include "common/types/uint_type.h" // IWYU pragma: export
#include "common/types/uint_wrapper_type.h" // IWYU pragma: export
#include "common/types/unknown_type.h" // IWYU pragma: export
#include "google/protobuf/descriptor.h"

namespace cel {

Expand All @@ -76,6 +80,12 @@ class Type;
// best to fail.
class Type final {
public:
// Returns an appropriate `Type` for the dynamic protobuf message. For well
// known message types, the appropriate `Type` is returned. All others return
// `MessageType`.
static Type Message(absl::Nonnull<const google::protobuf::Descriptor*> descriptor
ABSL_ATTRIBUTE_LIFETIME_BOUND);

Type() = default;
Type(const Type&) = default;
Type(Type&&) = default;
Expand Down Expand Up @@ -118,6 +128,15 @@ class Type final {
return *this;
}

// NOLINTNEXTLINE(google-explicit-constructor)
Type(StructType alternative) : variant_(alternative.ToTypeVariant()) {}

// NOLINTNEXTLINE(google-explicit-constructor)
Type& operator=(StructType alternative) {
variant_ = alternative.ToTypeVariant();
return *this;
}

TypeKind kind() const {
AssertIsValid();
return absl::visit(
Expand Down Expand Up @@ -266,18 +285,51 @@ class Type final {

const Type* operator->() const { return this; }

bool IsStruct() const {
return absl::holds_alternative<common_internal::BasicStructType>(
variant_) ||
absl::holds_alternative<MessageType>(variant_);
}

bool IsMessage() const {
return absl::holds_alternative<MessageType>(variant_);
}

// AsStruct performs a checked cast, returning `StructType` if this type is a
// struct or `absl::nullopt` otherwise. If you have already called
// `IsStruct()` it is more performant to perform to do
// `static_cast<StructType>(type)`.
absl::optional<StructType> AsStruct() const;

// AsMessage performs a checked cast, returning `MessageType` if this type is
// both a struct and a message or `absl::nullopt` otherwise. If you have
// already called `IsMessage()` it is more performant to perform to do
// `static_cast<MessageType>(type)`.
absl::optional<MessageType> AsMessage() const;

explicit operator bool() const {
return !absl::holds_alternative<absl::monostate>(variant_);
}

explicit operator StructType() const;

explicit operator MessageType() const;

private:
friend struct NativeTypeTraits<Type>;
friend struct CompositionTraits<Type>;
friend class StructType;
friend class MessageType;
friend class common_internal::BasicStructType;

constexpr bool IsValid() const {
return !absl::holds_alternative<absl::monostate>(variant_);
}
bool IsValid() const { return static_cast<bool>(*this); }

void AssertIsValid() const {
ABSL_DCHECK(IsValid()) << "use of invalid Type";
}

common_internal::StructTypeVariant ToStructTypeVariant() const;

common_internal::TypeVariant variant_;
};

Expand Down Expand Up @@ -340,6 +392,13 @@ struct CompositionTraits<Type> final {
}
}

template <typename U>
static std::enable_if_t<std::is_same_v<StructType, U>, bool> HasA(
const Type& type) {
type.AssertIsValid();
return type.IsStruct();
}

template <typename U>
static std::enable_if_t<common_internal::IsTypeAlternativeV<U>, const U&> Get(
const Type& type) {
Expand Down Expand Up @@ -387,6 +446,32 @@ struct CompositionTraits<Type> final {
return Cast<U>(absl::get<Base>(std::move(type.variant_)));
}
}

template <typename U>
static std::enable_if_t<std::is_same_v<StructType, U>, U> Get(
const Type& type) {
type.AssertIsValid();
return static_cast<StructType>(type);
}

template <typename U>
static std::enable_if_t<std::is_same_v<StructType, U>, U> Get(Type& type) {
type.AssertIsValid();
return static_cast<StructType>(type);
}

template <typename U>
static std::enable_if_t<std::is_same_v<StructType, U>, U> Get(
const Type&& type) {
type.AssertIsValid();
return static_cast<StructType>(type);
}

template <typename U>
static std::enable_if_t<std::is_same_v<StructType, U>, U> Get(Type&& type) {
type.AssertIsValid();
return static_cast<StructType>(type);
}
};

template <typename To, typename From>
Expand Down Expand Up @@ -437,12 +522,6 @@ struct OpaqueTypeData final {
const absl::FixedArray<Type, 1> parameters;
};

struct StructTypeData final {
explicit StructTypeData(std::string name) : name(std::move(name)) {}

const std::string name;
};

struct TypeParamTypeData final {
explicit TypeParamTypeData(std::string name) : name(std::move(name)) {}

Expand Down Expand Up @@ -540,16 +619,6 @@ inline H AbslHashValue(H state, const MapType& type) {
return H::combine(std::move(state), type.key(), type.value());
}

inline StructType::StructType(MemoryManagerRef memory_manager,
absl::string_view name)
: data_(memory_manager.MakeShared<common_internal::StructTypeData>(
std::string(name))) {}

inline absl::string_view StructType::name() const
ABSL_ATTRIBUTE_LIFETIME_BOUND {
return data_->name;
}

inline TypeParamType::TypeParamType(MemoryManagerRef memory_manager,
absl::string_view name)
: data_(memory_manager.MakeShared<common_internal::TypeParamTypeData>(
Expand Down
5 changes: 0 additions & 5 deletions common/type_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ MapType TypeFactory::CreateMapType(const Type& key, const Type& value) {
return CreateMapTypeImpl(key, value);
}

StructType TypeFactory::CreateStructType(absl::string_view name) {
ABSL_DCHECK(internal::IsValidRelativeName(name)) << name;
return CreateStructTypeImpl(name);
}

OpaqueType TypeFactory::CreateOpaqueType(absl::string_view name,
absl::Span<const Type> parameters) {
ABSL_DCHECK(internal::IsValidRelativeName(name)) << name;
Expand Down
Loading

0 comments on commit d10acf9

Please # to comment.