diff --git a/common/internal/BUILD b/common/internal/BUILD index 1065b7118..23343f1dc 100644 --- a/common/internal/BUILD +++ b/common/internal/BUILD @@ -107,3 +107,45 @@ cc_library( hdrs = ["metadata.h"], deps = ["@com_google_protobuf//:protobuf"], ) + +cc_library( + name = "byte_string", + srcs = ["byte_string.cc"], + hdrs = ["byte_string.h"], + deps = [ + ":metadata", + ":reference_count", + "//common:allocator", + "//common:memory", + "//internal:new", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/functional:overload", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/log:absl_check", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "byte_string_test", + srcs = ["byte_string_test.cc"], + deps = [ + ":byte_string", + ":reference_count", + "//common:allocator", + "//common:memory", + "//internal:testing", + "@com_google_absl//absl/base:no_destructor", + "@com_google_absl//absl/hash", + "@com_google_absl//absl/strings:cord", + "@com_google_absl//absl/strings:cord_test_helpers", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:optional", + "@com_google_protobuf//:protobuf", + ], +) diff --git a/common/internal/byte_string.cc b/common/internal/byte_string.cc new file mode 100644 index 000000000..241be87e8 --- /dev/null +++ b/common/internal/byte_string.cc @@ -0,0 +1,1312 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/internal/byte_string.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/functional/overload.h" +#include "absl/hash/hash.h" +#include "absl/log/absl_check.h" +#include "absl/strings/cord.h" +#include "absl/strings/match.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/internal/metadata.h" +#include "common/internal/reference_count.h" +#include "common/memory.h" +#include "internal/new.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +namespace { + +char* CopyCordToArray(const absl::Cord& cord, char* data) { + for (auto chunk : cord.Chunks()) { + std::memcpy(data, chunk.data(), chunk.size()); + data += chunk.size(); + } + return data; +} + +class ReferenceCountedStdString final : public ReferenceCounted { + public: + template + explicit ReferenceCountedStdString(String&& string) { + ::new (static_cast(&string_[0])) + std::string(std::forward(string)); + } + + const char* data() const noexcept { + return std::launder(reinterpret_cast(&string_[0])) + ->data(); + } + + size_t size() const noexcept { + return std::launder(reinterpret_cast(&string_[0])) + ->size(); + } + + private: + void Finalize() noexcept override { + std::destroy_at(std::launder(reinterpret_cast(&string_[0]))); + } + + alignas(std::string) char string_[sizeof(std::string)]; +}; + +class ReferenceCountedString final : public ReferenceCounted { + public: + static const ReferenceCountedString* New(const char* data, size_t size) { + size_t offset = offsetof(ReferenceCountedString, data_); + return ::new (internal::New(offset + size)) + ReferenceCountedString(size, data); + } + + const char* data() const noexcept { + return reinterpret_cast(&data_); + } + + size_t size() const noexcept { return size_; } + + private: + ReferenceCountedString(size_t size, const char* data) noexcept : size_(size) { + std::memcpy(&data_, data, size); + } + + void Delete() noexcept override { + void* const that = this; + const auto size = size_; + std::destroy_at(this); + internal::SizedDelete(that, offsetof(ReferenceCountedString, data_) + size); + } + + const size_t size_; + char data_[]; +}; + +template +T ConsumeAndDestroy(T& object) { + T consumed = std::move(object); + object.~T(); // NOLINT(bugprone-use-after-move) + return consumed; +} + +} // namespace + +ByteString::ByteString(Allocator<> allocator, absl::string_view string) { + ABSL_DCHECK_LE(string.size(), max_size()); + auto* arena = allocator.arena(); + if (string.size() <= kSmallByteStringCapacity) { + SetSmall(arena, string); + } else { + SetMedium(arena, string); + } +} + +ByteString::ByteString(Allocator<> allocator, const std::string& string) { + ABSL_DCHECK_LE(string.size(), max_size()); + auto* arena = allocator.arena(); + if (string.size() <= kSmallByteStringCapacity) { + SetSmall(arena, string); + } else { + SetMedium(arena, string); + } +} + +ByteString::ByteString(Allocator<> allocator, std::string&& string) { + ABSL_DCHECK_LE(string.size(), max_size()); + auto* arena = allocator.arena(); + if (string.size() <= kSmallByteStringCapacity) { + SetSmall(arena, string); + } else { + SetMedium(arena, std::move(string)); + } +} + +ByteString::ByteString(Allocator<> allocator, const absl::Cord& cord) { + ABSL_DCHECK_LE(cord.size(), max_size()); + auto* arena = allocator.arena(); + if (cord.size() <= kSmallByteStringCapacity) { + SetSmall(arena, cord); + } else if (arena != nullptr) { + SetMedium(arena, cord); + } else { + SetLarge(cord); + } +} + +ByteString ByteString::Borrowed(Owner owner, absl::string_view string) { + ABSL_DCHECK(owner != Owner::None()) << "Borrowing from Owner::None()"; + auto* arena = owner.arena(); + if (string.size() <= kSmallByteStringCapacity || arena != nullptr) { + return ByteString(arena, string); + } + const auto* refcount = OwnerRelease(owner); + // A nullptr refcount indicates somebody called us to borrow something that + // has no owner. If this is the case, we fallback to assuming operator + // new/delete and convert it to a reference count. + if (refcount == nullptr) { + auto* refcount_string = + ReferenceCountedString::New(string.data(), string.size()); + string = absl::string_view(refcount_string->data(), string.size()); + refcount = refcount_string; + } + return ByteString(refcount, string); +} + +ByteString ByteString::Borrowed(const Owner& owner, const absl::Cord& cord) { + ABSL_DCHECK(owner != Owner::None()) << "Borrowing from Owner::None()"; + return ByteString(owner.arena(), cord); +} + +ByteString::ByteString(absl::Nonnull refcount, + absl::string_view string) { + ABSL_DCHECK_LE(string.size(), max_size()); + SetMedium(string, reinterpret_cast(refcount) | + kMetadataOwnerReferenceCountBit); +} + +absl::Nullable ByteString::GetArena() const noexcept { + switch (GetKind()) { + case ByteStringKind::kSmall: + return GetSmallArena(); + case ByteStringKind::kMedium: + return GetMediumArena(); + case ByteStringKind::kLarge: + return nullptr; + } +} + +bool ByteString::empty() const noexcept { + switch (GetKind()) { + case ByteStringKind::kSmall: + return rep_.small.size == 0; + case ByteStringKind::kMedium: + return rep_.medium.size == 0; + case ByteStringKind::kLarge: + return GetLarge().empty(); + } +} + +size_t ByteString::size() const noexcept { + switch (GetKind()) { + case ByteStringKind::kSmall: + return rep_.small.size; + case ByteStringKind::kMedium: + return rep_.medium.size; + case ByteStringKind::kLarge: + return GetLarge().size(); + } +} + +absl::string_view ByteString::Flatten() { + switch (GetKind()) { + case ByteStringKind::kSmall: + return GetSmall(); + case ByteStringKind::kMedium: + return GetMedium(); + case ByteStringKind::kLarge: + return GetLarge().Flatten(); + } +} + +absl::optional ByteString::TryFlat() const noexcept { + switch (GetKind()) { + case ByteStringKind::kSmall: + return GetSmall(); + case ByteStringKind::kMedium: + return GetMedium(); + case ByteStringKind::kLarge: + return GetLarge().TryFlat(); + } +} + +absl::string_view ByteString::GetFlat(std::string& scratch) const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return GetSmall(); + case ByteStringKind::kMedium: + return GetMedium(); + case ByteStringKind::kLarge: { + const auto& large = GetLarge(); + if (auto flat = large.TryFlat(); flat) { + return *flat; + } + scratch = static_cast(large); + return scratch; + } + } +} + +void ByteString::RemovePrefix(size_t n) { + ABSL_DCHECK_LE(n, size()); + if (n == 0) { + return; + } + switch (GetKind()) { + case ByteStringKind::kSmall: + std::memmove(rep_.small.data, rep_.small.data + n, rep_.small.size - n); + rep_.small.size -= n; + break; + case ByteStringKind::kMedium: + rep_.medium.data += n; + rep_.medium.size -= n; + if (rep_.medium.size <= kSmallByteStringCapacity) { + const auto* refcount = GetMediumReferenceCount(); + SetSmall(GetMediumArena(), GetMedium()); + StrongUnref(refcount); + } + break; + case ByteStringKind::kLarge: { + auto& large = GetLarge(); + const auto large_size = large.size(); + const auto new_large_pos = n; + const auto new_large_size = large_size - n; + large = large.Subcord(new_large_pos, new_large_size); + if (new_large_size <= kSmallByteStringCapacity) { + auto large_copy = std::move(large); + DestroyLarge(); + SetSmall(nullptr, large_copy); + } + } break; + } +} + +void ByteString::RemoveSuffix(size_t n) { + ABSL_DCHECK_LE(n, size()); + if (n == 0) { + return; + } + switch (GetKind()) { + case ByteStringKind::kSmall: + rep_.small.size -= n; + break; + case ByteStringKind::kMedium: + rep_.medium.size -= n; + if (rep_.medium.size <= kSmallByteStringCapacity) { + const auto* refcount = GetMediumReferenceCount(); + SetSmall(GetMediumArena(), GetMedium()); + StrongUnref(refcount); + } + break; + case ByteStringKind::kLarge: { + auto& large = GetLarge(); + const auto large_size = large.size(); + const auto new_large_pos = 0; + const auto new_large_size = large_size - n; + large = large.Subcord(new_large_pos, new_large_size); + if (new_large_size <= kSmallByteStringCapacity) { + auto large_copy = std::move(large); + DestroyLarge(); + SetSmall(nullptr, large_copy); + } + } break; + } +} + +std::string ByteString::ToString() const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return std::string(GetSmall()); + case ByteStringKind::kMedium: + return std::string(GetMedium()); + case ByteStringKind::kLarge: + return static_cast(GetLarge()); + } +} + +namespace { + +struct ReferenceCountReleaser { + absl::Nonnull refcount; + + void operator()() const noexcept { StrongUnref(*refcount); } +}; + +} // namespace + +absl::Cord ByteString::ToCord() const& { + switch (GetKind()) { + case ByteStringKind::kSmall: + return absl::Cord(GetSmall()); + case ByteStringKind::kMedium: { + const auto* refcount = GetMediumReferenceCount(); + if (refcount != nullptr) { + StrongRef(*refcount); + return absl::MakeCordFromExternal(GetMedium(), + ReferenceCountReleaser{refcount}); + } + return absl::Cord(GetMedium()); + } + case ByteStringKind::kLarge: + return GetLarge(); + } +} + +absl::Cord ByteString::ToCord() && { + switch (GetKind()) { + case ByteStringKind::kSmall: + return absl::Cord(GetSmall()); + case ByteStringKind::kMedium: { + const auto* refcount = GetMediumReferenceCount(); + if (refcount != nullptr) { + auto medium = GetMedium(); + SetSmallEmpty(nullptr); + return absl::MakeCordFromExternal(medium, + ReferenceCountReleaser{refcount}); + } + return absl::Cord(GetMedium()); + } + case ByteStringKind::kLarge: + return GetLarge(); + } +} + +absl::Nullable ByteString::GetMediumArena( + const MediumByteStringRep& rep) noexcept { + if ((rep.owner & kMetadataOwnerBits) == kMetadataOwnerArenaBit) { + return reinterpret_cast(rep.owner & + kMetadataOwnerPointerMask); + } + return nullptr; +} + +absl::Nullable ByteString::GetMediumReferenceCount( + const MediumByteStringRep& rep) noexcept { + if ((rep.owner & kMetadataOwnerBits) == kMetadataOwnerReferenceCountBit) { + return reinterpret_cast(rep.owner & + kMetadataOwnerPointerMask); + } + return nullptr; +} + +void ByteString::CopyFrom(const ByteString& other) { + const auto kind = GetKind(); + const auto other_kind = other.GetKind(); + switch (kind) { + case ByteStringKind::kSmall: + switch (other_kind) { + case ByteStringKind::kSmall: + CopyFromSmallSmall(other); + break; + case ByteStringKind::kMedium: + CopyFromSmallMedium(other); + break; + case ByteStringKind::kLarge: + CopyFromSmallLarge(other); + break; + } + break; + case ByteStringKind::kMedium: + switch (other_kind) { + case ByteStringKind::kSmall: + CopyFromMediumSmall(other); + break; + case ByteStringKind::kMedium: + CopyFromMediumMedium(other); + break; + case ByteStringKind::kLarge: + CopyFromMediumLarge(other); + break; + } + break; + case ByteStringKind::kLarge: + switch (other_kind) { + case ByteStringKind::kSmall: + CopyFromLargeSmall(other); + break; + case ByteStringKind::kMedium: + CopyFromLargeMedium(other); + break; + case ByteStringKind::kLarge: + CopyFromLargeLarge(other); + break; + } + break; + } +} + +void ByteString::CopyFromSmallSmall(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kSmall); + rep_.small.size = other.rep_.small.size; + std::memcpy(rep_.small.data, other.rep_.small.data, rep_.small.size); +} + +void ByteString::CopyFromSmallMedium(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kMedium); + SetMedium(GetSmallArena(), other.GetMedium()); +} + +void ByteString::CopyFromSmallLarge(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kLarge); + SetMediumOrLarge(GetSmallArena(), other.GetLarge()); +} + +void ByteString::CopyFromMediumSmall(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kSmall); + auto* arena = GetMediumArena(); + if (arena == nullptr) { + DestroyMedium(); + } + SetSmall(arena, other.GetSmall()); +} + +void ByteString::CopyFromMediumMedium(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kMedium); + auto* arena = GetMediumArena(); + auto* other_arena = other.GetMediumArena(); + if (arena == other_arena) { + // No need to call `DestroyMedium`, we take care of the reference count + // management directly. + if (other_arena == nullptr) { + StrongRef(other.GetMediumReferenceCount()); + } + if (arena == nullptr) { + StrongUnref(GetMediumReferenceCount()); + } + SetMedium(other.GetMedium(), other.GetMediumOwner()); + } else { + // Different allocator. This could be interesting. + DestroyMedium(); + SetMedium(arena, other.GetMedium()); + } +} + +void ByteString::CopyFromMediumLarge(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kLarge); + auto* arena = GetMediumArena(); + if (arena == nullptr) { + DestroyMedium(); + SetLarge(std::move(other.GetLarge())); + } else { + // No need to call `DestroyMedium`, it is guaranteed that we do not have a + // reference count because `arena` is not `nullptr`. + SetMedium(arena, other.GetLarge()); + } +} + +void ByteString::CopyFromLargeSmall(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kSmall); + DestroyLarge(); + SetSmall(nullptr, other.GetSmall()); +} + +void ByteString::CopyFromLargeMedium(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kMedium); + const auto* refcount = other.GetMediumReferenceCount(); + if (refcount != nullptr) { + StrongRef(*refcount); + DestroyLarge(); + SetMedium(other.GetMedium(), other.GetMediumOwner()); + } else { + GetLarge() = other.GetMedium(); + } +} + +void ByteString::CopyFromLargeLarge(const ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kLarge); + GetLarge() = std::move(other.GetLarge()); +} + +void ByteString::CopyFrom(ByteStringView other) { + const auto kind = GetKind(); + const auto other_kind = other.GetKind(); + switch (kind) { + case ByteStringKind::kSmall: + switch (other_kind) { + case ByteStringViewKind::kString: + CopyFromSmallString(other); + break; + case ByteStringViewKind::kCord: + CopyFromSmallCord(other); + break; + } + break; + case ByteStringKind::kMedium: + switch (other_kind) { + case ByteStringViewKind::kString: + CopyFromMediumString(other); + break; + case ByteStringViewKind::kCord: + CopyFromMediumCord(other); + break; + } + break; + case ByteStringKind::kLarge: + switch (other_kind) { + case ByteStringViewKind::kString: + CopyFromLargeString(other); + break; + case ByteStringViewKind::kCord: + CopyFromLargeCord(other); + break; + } + break; + } +} + +void ByteString::CopyFromSmallString(ByteStringView other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringViewKind::kString); + auto* arena = GetSmallArena(); + const auto other_string = other.GetString(); + if (other_string.size() <= kSmallByteStringCapacity) { + SetSmall(arena, other_string); + } else { + SetMedium(arena, other_string); + } +} + +void ByteString::CopyFromSmallCord(ByteStringView other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringViewKind::kCord); + auto* arena = GetSmallArena(); + auto other_cord = other.GetSubcord(); + if (other_cord.size() <= kSmallByteStringCapacity) { + SetSmall(arena, other_cord); + } else { + SetMediumOrLarge(arena, std::move(other_cord)); + } +} + +void ByteString::CopyFromMediumString(ByteStringView other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringViewKind::kString); + auto* arena = GetMediumArena(); + const auto other_string = other.GetString(); + if (other_string.size() <= kSmallByteStringCapacity) { + DestroyMedium(); + SetSmall(arena, other_string); + return; + } + auto* other_arena = other.GetStringArena(); + if (arena == other_arena) { + if (other_arena == nullptr) { + StrongRef(other.GetStringReferenceCount()); + } + if (arena == nullptr) { + StrongUnref(GetMediumReferenceCount()); + } + SetMedium(other_string, other.GetStringOwner()); + } else { + DestroyMedium(); + SetMedium(arena, other_string); + } +} + +void ByteString::CopyFromMediumCord(ByteStringView other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringViewKind::kCord); + auto* arena = GetMediumArena(); + auto other_cord = other.GetSubcord(); + DestroyMedium(); + if (other_cord.size() <= kSmallByteStringCapacity) { + SetSmall(arena, other_cord); + } else { + SetMediumOrLarge(arena, std::move(other_cord)); + } +} + +void ByteString::CopyFromLargeString(ByteStringView other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringViewKind::kString); + const auto other_string = other.GetString(); + if (other_string.size() <= kSmallByteStringCapacity) { + DestroyLarge(); + SetSmall(nullptr, other_string); + return; + } + auto* other_arena = other.GetStringArena(); + if (other_arena == nullptr) { + const auto* refcount = other.GetStringReferenceCount(); + if (refcount != nullptr) { + StrongRef(*refcount); + DestroyLarge(); + SetMedium(other_string, other.GetStringOwner()); + return; + } + } + GetLarge() = other_string; +} + +void ByteString::CopyFromLargeCord(ByteStringView other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringViewKind::kCord); + auto cord = other.GetSubcord(); + if (cord.size() <= kSmallByteStringCapacity) { + DestroyLarge(); + SetSmall(nullptr, cord); + } else { + GetLarge() = std::move(cord); + } +} + +void ByteString::MoveFrom(ByteString& other) { + const auto kind = GetKind(); + const auto other_kind = other.GetKind(); + switch (kind) { + case ByteStringKind::kSmall: + switch (other_kind) { + case ByteStringKind::kSmall: + MoveFromSmallSmall(other); + break; + case ByteStringKind::kMedium: + MoveFromSmallMedium(other); + break; + case ByteStringKind::kLarge: + MoveFromSmallLarge(other); + break; + } + break; + case ByteStringKind::kMedium: + switch (other_kind) { + case ByteStringKind::kSmall: + MoveFromMediumSmall(other); + break; + case ByteStringKind::kMedium: + MoveFromMediumMedium(other); + break; + case ByteStringKind::kLarge: + MoveFromMediumLarge(other); + break; + } + break; + case ByteStringKind::kLarge: + switch (other_kind) { + case ByteStringKind::kSmall: + MoveFromLargeSmall(other); + break; + case ByteStringKind::kMedium: + MoveFromLargeMedium(other); + break; + case ByteStringKind::kLarge: + MoveFromLargeLarge(other); + break; + } + break; + } +} + +void ByteString::MoveFromSmallSmall(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kSmall); + rep_.small.size = other.rep_.small.size; + std::memcpy(rep_.small.data, other.rep_.small.data, rep_.small.size); + other.SetSmallEmpty(other.GetSmallArena()); +} + +void ByteString::MoveFromSmallMedium(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kMedium); + auto* arena = GetSmallArena(); + auto* other_arena = other.GetMediumArena(); + if (arena == other_arena) { + SetMedium(other.GetMedium(), other.GetMediumOwner()); + } else { + SetMedium(arena, other.GetMedium()); + other.DestroyMedium(); + } + other.SetSmallEmpty(other_arena); +} + +void ByteString::MoveFromSmallLarge(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kLarge); + auto* arena = GetSmallArena(); + if (arena == nullptr) { + SetLarge(std::move(other.GetLarge())); + } else { + SetMediumOrLarge(arena, other.GetLarge()); + } + other.DestroyLarge(); + other.SetSmallEmpty(nullptr); +} + +void ByteString::MoveFromMediumSmall(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kSmall); + auto* arena = GetMediumArena(); + auto* other_arena = other.GetSmallArena(); + if (arena == nullptr) { + DestroyMedium(); + } + SetSmall(arena, other.GetSmall()); + other.SetSmallEmpty(other_arena); +} + +void ByteString::MoveFromMediumMedium(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kMedium); + auto* arena = GetMediumArena(); + auto* other_arena = other.GetMediumArena(); + DestroyMedium(); + if (arena == other_arena) { + SetMedium(other.GetMedium(), other.GetMediumOwner()); + } else { + SetMedium(arena, other.GetMedium()); + other.DestroyMedium(); + } + other.SetSmallEmpty(other_arena); +} + +void ByteString::MoveFromMediumLarge(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kLarge); + auto* arena = GetMediumArena(); + DestroyMedium(); + SetMediumOrLarge(arena, std::move(other.GetLarge())); + other.DestroyLarge(); + other.SetSmallEmpty(nullptr); +} + +void ByteString::MoveFromLargeSmall(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kSmall); + auto* other_arena = other.GetSmallArena(); + DestroyLarge(); + SetSmall(nullptr, other.GetSmall()); + other.SetSmallEmpty(other_arena); +} + +void ByteString::MoveFromLargeMedium(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kMedium); + auto* other_arena = other.GetMediumArena(); + if (other_arena == nullptr) { + DestroyLarge(); + SetMedium(other.GetMedium(), other.GetMediumOwner()); + } else { + GetLarge() = other.GetMedium(); + other.DestroyMedium(); + } + other.SetSmallEmpty(other_arena); +} + +void ByteString::MoveFromLargeLarge(ByteString& other) { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(other.GetKind(), ByteStringKind::kLarge); + GetLarge() = ConsumeAndDestroy(other.GetLarge()); + other.SetSmallEmpty(nullptr); +} + +void ByteString::HashValue(absl::HashState state) const { + Visit(absl::Overload( + [&state](absl::string_view string) { + absl::HashState::combine(std::move(state), string); + }, + [&state](const absl::Cord& cord) { + absl::HashState::combine(std::move(state), cord); + })); +} + +void ByteString::Swap(ByteString& other) { + const auto kind = GetKind(); + const auto other_kind = other.GetKind(); + switch (kind) { + case ByteStringKind::kSmall: + switch (other_kind) { + case ByteStringKind::kSmall: + SwapSmallSmall(*this, other); + break; + case ByteStringKind::kMedium: + SwapSmallMedium(*this, other); + break; + case ByteStringKind::kLarge: + SwapSmallLarge(*this, other); + break; + } + break; + case ByteStringKind::kMedium: + switch (other_kind) { + case ByteStringKind::kSmall: + SwapSmallMedium(other, *this); + break; + case ByteStringKind::kMedium: + SwapMediumMedium(*this, other); + break; + case ByteStringKind::kLarge: + SwapMediumLarge(*this, other); + break; + } + break; + case ByteStringKind::kLarge: + switch (other_kind) { + case ByteStringKind::kSmall: + SwapSmallLarge(other, *this); + break; + case ByteStringKind::kMedium: + SwapMediumLarge(other, *this); + break; + case ByteStringKind::kLarge: + SwapLargeLarge(*this, other); + break; + } + break; + } +} + +void ByteString::Destroy() noexcept { + switch (GetKind()) { + case ByteStringKind::kSmall: + break; + case ByteStringKind::kMedium: + DestroyMedium(); + break; + case ByteStringKind::kLarge: + DestroyLarge(); + break; + } +} + +void ByteString::SetSmallEmpty(absl::Nullable arena) { + rep_.header.kind = ByteStringKind::kSmall; + rep_.small.size = 0; + rep_.small.arena = arena; +} + +void ByteString::SetSmall(absl::Nullable arena, + absl::string_view string) { + ABSL_DCHECK_LE(string.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kSmall; + rep_.small.size = string.size(); + rep_.small.arena = arena; + std::memcpy(rep_.small.data, string.data(), rep_.small.size); +} + +void ByteString::SetSmall(absl::Nullable arena, + const absl::Cord& cord) { + ABSL_DCHECK_LE(cord.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kSmall; + rep_.small.size = cord.size(); + rep_.small.arena = arena; + (CopyCordToArray)(cord, rep_.small.data); +} + +void ByteString::SetMedium(absl::Nullable arena, + absl::string_view string) { + ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = string.size(); + if (arena != nullptr) { + char* data = static_cast( + arena->AllocateAligned(rep_.medium.size, alignof(char))); + std::memcpy(data, string.data(), rep_.medium.size); + rep_.medium.data = data; + rep_.medium.owner = + reinterpret_cast(arena) | kMetadataOwnerArenaBit; + } else { + const auto* refcount = + ReferenceCountedString::New(string.data(), string.size()); + rep_.medium.data = refcount->data(); + rep_.medium.owner = + reinterpret_cast(refcount) | kMetadataOwnerReferenceCountBit; + } +} + +void ByteString::SetMedium(absl::Nullable arena, + std::string&& string) { + ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = string.size(); + if (arena != nullptr) { + auto* data = google::protobuf::Arena::Create(arena, std::move(string)); + rep_.medium.data = data->data(); + rep_.medium.owner = + reinterpret_cast(arena) | kMetadataOwnerArenaBit; + } else { + const auto* refcount = new ReferenceCountedStdString(std::move(string)); + rep_.medium.data = refcount->data(); + rep_.medium.owner = + reinterpret_cast(refcount) | kMetadataOwnerReferenceCountBit; + } +} + +void ByteString::SetMedium(absl::Nonnull arena, + const absl::Cord& cord) { + ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = cord.size(); + char* data = static_cast( + arena->AllocateAligned(rep_.medium.size, alignof(char))); + (CopyCordToArray)(cord, data); + rep_.medium.data = data; + rep_.medium.owner = + reinterpret_cast(arena) | kMetadataOwnerArenaBit; +} + +void ByteString::SetMedium(absl::string_view string, uintptr_t owner) { + ABSL_DCHECK_GT(string.size(), kSmallByteStringCapacity); + ABSL_DCHECK_NE(owner, 0); + rep_.header.kind = ByteStringKind::kMedium; + rep_.medium.size = string.size(); + rep_.medium.data = string.data(); + rep_.medium.owner = owner; +} + +void ByteString::SetMediumOrLarge(absl::Nullable arena, + const absl::Cord& cord) { + if (arena != nullptr) { + SetMedium(arena, cord); + } else { + SetLarge(cord); + } +} + +void ByteString::SetMediumOrLarge(absl::Nullable arena, + absl::Cord&& cord) { + if (arena != nullptr) { + SetMedium(arena, cord); + } else { + SetLarge(std::move(cord)); + } +} + +void ByteString::SetLarge(const absl::Cord& cord) { + ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kLarge; + ::new (static_cast(&rep_.large.data[0])) absl::Cord(cord); +} + +void ByteString::SetLarge(absl::Cord&& cord) { + ABSL_DCHECK_GT(cord.size(), kSmallByteStringCapacity); + rep_.header.kind = ByteStringKind::kLarge; + ::new (static_cast(&rep_.large.data[0])) absl::Cord(std::move(cord)); +} + +void ByteString::SwapSmallSmall(ByteString& lhs, ByteString& rhs) { + using std::swap; + ABSL_DCHECK_EQ(lhs.GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(rhs.GetKind(), ByteStringKind::kSmall); + const auto size = lhs.rep_.small.size; + lhs.rep_.small.size = rhs.rep_.small.size; + rhs.rep_.small.size = size; + swap(lhs.rep_.small.data, rhs.rep_.small.data); +} + +void ByteString::SwapSmallMedium(ByteString& lhs, ByteString& rhs) { + ABSL_DCHECK_EQ(lhs.GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(rhs.GetKind(), ByteStringKind::kMedium); + auto* lhs_arena = lhs.GetSmallArena(); + auto* rhs_arena = rhs.GetMediumArena(); + if (lhs_arena == rhs_arena) { + SmallByteStringRep lhs_rep = lhs.rep_.small; + lhs.rep_.medium = rhs.rep_.medium; + rhs.rep_.small = lhs_rep; + } else { + SmallByteStringRep small = lhs.rep_.small; + lhs.SetMedium(lhs_arena, rhs.GetMedium()); + rhs.DestroyMedium(); + rhs.SetSmall(rhs_arena, GetSmall(small)); + } +} + +void ByteString::SwapSmallLarge(ByteString& lhs, ByteString& rhs) { + ABSL_DCHECK_EQ(lhs.GetKind(), ByteStringKind::kSmall); + ABSL_DCHECK_EQ(rhs.GetKind(), ByteStringKind::kLarge); + auto* lhs_arena = lhs.GetSmallArena(); + absl::Cord large = std::move(rhs.GetLarge()); + rhs.DestroyLarge(); + rhs.rep_.small = lhs.rep_.small; + if (lhs_arena == nullptr) { + lhs.SetLarge(std::move(large)); + } else { + rhs.rep_.small.arena = nullptr; + lhs.SetMedium(lhs_arena, large); + } +} + +void ByteString::SwapMediumMedium(ByteString& lhs, ByteString& rhs) { + using std::swap; + ABSL_DCHECK_EQ(lhs.GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(rhs.GetKind(), ByteStringKind::kMedium); + auto* lhs_arena = lhs.GetMediumArena(); + auto* rhs_arena = rhs.GetMediumArena(); + if (lhs_arena == rhs_arena) { + swap(lhs.rep_.medium, rhs.rep_.medium); + } else { + MediumByteStringRep medium = lhs.rep_.medium; + lhs.SetMedium(lhs_arena, rhs.GetMedium()); + rhs.DestroyMedium(); + rhs.SetMedium(rhs_arena, GetMedium(medium)); + DestroyMedium(medium); + } +} + +void ByteString::SwapMediumLarge(ByteString& lhs, ByteString& rhs) { + ABSL_DCHECK_EQ(lhs.GetKind(), ByteStringKind::kMedium); + ABSL_DCHECK_EQ(rhs.GetKind(), ByteStringKind::kLarge); + auto* lhs_arena = lhs.GetMediumArena(); + absl::Cord large = std::move(rhs.GetLarge()); + rhs.DestroyLarge(); + if (lhs_arena == nullptr) { + rhs.rep_.medium = lhs.rep_.medium; + lhs.SetLarge(std::move(large)); + } else { + rhs.SetMedium(nullptr, lhs.GetMedium()); + lhs.SetMedium(lhs_arena, std::move(large)); + } +} + +void ByteString::SwapLargeLarge(ByteString& lhs, ByteString& rhs) { + using std::swap; + ABSL_DCHECK_EQ(lhs.GetKind(), ByteStringKind::kLarge); + ABSL_DCHECK_EQ(rhs.GetKind(), ByteStringKind::kLarge); + swap(lhs.GetLarge(), rhs.GetLarge()); +} + +ByteStringView::ByteStringView(const ByteString& other) noexcept { + switch (other.GetKind()) { + case ByteStringKind::kSmall: { + auto* other_arena = other.GetSmallArena(); + const auto string = other.GetSmall(); + rep_.header.kind = ByteStringViewKind::kString; + rep_.string.size = string.size(); + rep_.string.data = string.data(); + if (other_arena != nullptr) { + rep_.string.owner = + reinterpret_cast(other_arena) | kMetadataOwnerArenaBit; + } else { + rep_.string.owner = 0; + } + } break; + case ByteStringKind::kMedium: { + const auto string = other.GetMedium(); + rep_.header.kind = ByteStringViewKind::kString; + rep_.string.size = string.size(); + rep_.string.data = string.data(); + rep_.string.owner = other.GetMediumOwner(); + } break; + case ByteStringKind::kLarge: { + const auto& cord = other.GetLarge(); + rep_.header.kind = ByteStringViewKind::kCord; + rep_.cord.size = cord.size(); + rep_.cord.data = &cord; + rep_.cord.pos = 0; + } break; + } +} + +bool ByteStringView::empty() const noexcept { + switch (GetKind()) { + case ByteStringViewKind::kString: + return rep_.string.size == 0; + case ByteStringViewKind::kCord: + return rep_.cord.size == 0; + } +} + +size_t ByteStringView::size() const noexcept { + switch (GetKind()) { + case ByteStringViewKind::kString: + return rep_.string.size; + case ByteStringViewKind::kCord: + return rep_.cord.size; + } +} + +absl::optional ByteStringView::TryFlat() const noexcept { + switch (GetKind()) { + case ByteStringViewKind::kString: + return GetString(); + case ByteStringViewKind::kCord: + if (auto flat = GetCord().TryFlat(); flat) { + return flat->substr(rep_.cord.pos, rep_.cord.size); + } + return absl::nullopt; + } +} + +absl::string_view ByteStringView::GetFlat(std::string& scratch) const { + switch (GetKind()) { + case ByteStringViewKind::kString: + return GetString(); + case ByteStringViewKind::kCord: { + if (auto flat = GetCord().TryFlat(); flat) { + return flat->substr(rep_.cord.pos, rep_.cord.size); + } + scratch = static_cast(GetSubcord()); + return scratch; + } + } +} + +bool ByteStringView::Equals(ByteStringView rhs) const noexcept { + switch (GetKind()) { + case ByteStringViewKind::kString: + switch (rhs.GetKind()) { + case ByteStringViewKind::kString: + return GetString() == rhs.GetString(); + case ByteStringViewKind::kCord: + return GetString() == rhs.GetSubcord(); + } + case ByteStringViewKind::kCord: + switch (rhs.GetKind()) { + case ByteStringViewKind::kString: + return GetSubcord() == rhs.GetString(); + case ByteStringViewKind::kCord: + return GetSubcord() == rhs.GetSubcord(); + } + } +} + +int ByteStringView::Compare(ByteStringView rhs) const noexcept { + switch (GetKind()) { + case ByteStringViewKind::kString: + switch (rhs.GetKind()) { + case ByteStringViewKind::kString: + return GetString().compare(rhs.GetString()); + case ByteStringViewKind::kCord: + return -rhs.GetSubcord().Compare(GetString()); + } + case ByteStringViewKind::kCord: + switch (rhs.GetKind()) { + case ByteStringViewKind::kString: + return GetSubcord().Compare(rhs.GetString()); + case ByteStringViewKind::kCord: + return GetSubcord().Compare(rhs.GetSubcord()); + } + } +} + +bool ByteStringView::StartsWith(ByteStringView rhs) const noexcept { + switch (GetKind()) { + case ByteStringViewKind::kString: + switch (rhs.GetKind()) { + case ByteStringViewKind::kString: + return absl::StartsWith(GetString(), rhs.GetString()); + case ByteStringViewKind::kCord: { + const auto string = GetString(); + const auto& cord = rhs.GetSubcord(); + const auto cord_size = cord.size(); + return string.size() >= cord_size && + string.substr(0, cord_size) == cord; + } + } + case ByteStringViewKind::kCord: + switch (rhs.GetKind()) { + case ByteStringViewKind::kString: + return GetSubcord().StartsWith(rhs.GetString()); + case ByteStringViewKind::kCord: + return GetSubcord().StartsWith(rhs.GetSubcord()); + } + } +} + +bool ByteStringView::EndsWith(ByteStringView rhs) const noexcept { + switch (GetKind()) { + case ByteStringViewKind::kString: + switch (rhs.GetKind()) { + case ByteStringViewKind::kString: + return absl::EndsWith(GetString(), rhs.GetString()); + case ByteStringViewKind::kCord: { + const auto string = GetString(); + const auto& cord = rhs.GetSubcord(); + const auto string_size = string.size(); + const auto cord_size = cord.size(); + return string_size >= cord_size && + string.substr(string_size - cord_size) == cord; + } + } + case ByteStringViewKind::kCord: + switch (rhs.GetKind()) { + case ByteStringViewKind::kString: + return GetSubcord().EndsWith(rhs.GetString()); + case ByteStringViewKind::kCord: + return GetSubcord().EndsWith(rhs.GetSubcord()); + } + } +} + +void ByteStringView::RemovePrefix(size_t n) { + ABSL_DCHECK_LE(n, size()); + switch (GetKind()) { + case ByteStringViewKind::kString: + rep_.string.data += n; + break; + case ByteStringViewKind::kCord: + rep_.cord.pos += n; + break; + } + rep_.header.size -= n; +} + +void ByteStringView::RemoveSuffix(size_t n) { + ABSL_DCHECK_LE(n, size()); + rep_.header.size -= n; +} + +std::string ByteStringView::ToString() const { + switch (GetKind()) { + case ByteStringViewKind::kString: + return std::string(GetString()); + case ByteStringViewKind::kCord: + return static_cast(GetSubcord()); + } +} + +absl::Cord ByteStringView::ToCord() const { + switch (GetKind()) { + case ByteStringViewKind::kString: { + const auto* refcount = GetStringReferenceCount(); + if (refcount != nullptr) { + StrongRef(*refcount); + return absl::MakeCordFromExternal(GetString(), + ReferenceCountReleaser{refcount}); + } + return absl::Cord(GetString()); + } + case ByteStringViewKind::kCord: + return GetSubcord(); + } +} + +absl::Nullable ByteStringView::GetArena() const noexcept { + switch (GetKind()) { + case ByteStringViewKind::kString: + return GetStringArena(); + case ByteStringViewKind::kCord: + return nullptr; + } +} + +void ByteStringView::HashValue(absl::HashState state) const { + Visit(absl::Overload( + [&state](absl::string_view string) { + absl::HashState::combine(std::move(state), string); + }, + [&state](const absl::Cord& cord) { + absl::HashState::combine(std::move(state), cord); + })); +} + +} // namespace cel::common_internal diff --git a/common/internal/byte_string.h b/common/internal/byte_string.h new file mode 100644 index 000000000..66cf44c18 --- /dev/null +++ b/common/internal/byte_string.h @@ -0,0 +1,829 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ +#define THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/base/nullability.h" +#include "absl/base/optimization.h" +#include "absl/hash/hash.h" +#include "absl/log/absl_check.h" +#include "absl/strings/cord.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/internal/metadata.h" +#include "common/internal/reference_count.h" +#include "common/memory.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +// absl::Cord is trivially relocatable IFF we are not using ASan or MSan. When +// using ASan or MSan absl::Cord will poison/unpoison its inline storage. +#if defined(ABSL_HAVE_ADDRESS_SANITIZER) || defined(ABSL_HAVE_MEMORY_SANITIZER) +#define CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI +#else +#define CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI ABSL_ATTRIBUTE_TRIVIAL_ABI +#endif + +class CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI [[nodiscard]] ByteString; +class ByteStringView; + +struct ByteStringTestFriend; +struct ByteStringViewTestFriend; + +enum class ByteStringKind : unsigned int { + kSmall = 0, + kMedium, + kLarge, +}; + +inline std::ostream& operator<<(std::ostream& out, ByteStringKind kind) { + switch (kind) { + case ByteStringKind::kSmall: + return out << "SMALL"; + case ByteStringKind::kMedium: + return out << "MEDIUM"; + case ByteStringKind::kLarge: + return out << "LARGE"; + } +} + +// Representation of small strings in ByteString, which are stored in place. +struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI SmallByteStringRep final { +#ifdef _MSC_VER +#pragma push(pack, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI { + ByteStringKind kind : 2; + size_t size : 6; + }; +#ifdef _MSC_VER +#pragma pop(pack) +#endif + char data[23 - sizeof(google::protobuf::Arena*)]; + google::protobuf::Arena* arena; +}; + +inline constexpr size_t kSmallByteStringCapacity = + sizeof(SmallByteStringRep::data); + +inline constexpr size_t kMediumByteStringSizeBits = sizeof(size_t) * 8 - 2; +inline constexpr size_t kMediumByteStringMaxSize = + (size_t{1} << kMediumByteStringSizeBits) - 1; + +inline constexpr size_t kByteStringViewSizeBits = sizeof(size_t) * 8 - 1; +inline constexpr size_t kByteStringViewMaxSize = + (size_t{1} << kByteStringViewSizeBits) - 1; + +// Representation of medium strings in ByteString. These are either owned by an +// arena or managed by a reference count. This is encoded in `owner` following +// the same semantics as `cel::Owner`. +struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI MediumByteStringRep final { +#ifdef _MSC_VER +#pragma push(pack, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI { + ByteStringKind kind : 2; + size_t size : kMediumByteStringSizeBits; + }; +#ifdef _MSC_VER +#pragma pop(pack) +#endif + const char* data; + uintptr_t owner; +}; + +// Representation of large strings in ByteString. These are stored as +// `absl::Cord` and never owned by an arena. +struct CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI LargeByteStringRep final { +#ifdef _MSC_VER +#pragma push(pack, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI { + ByteStringKind kind : 2; + size_t padding : kMediumByteStringSizeBits; + }; +#ifdef _MSC_VER +#pragma pop(pack) +#endif + alignas(absl::Cord) char data[sizeof(absl::Cord)]; +}; + +// Representation of ByteString. +union CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI ByteStringRep final { +#ifdef _MSC_VER +#pragma push(pack, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI { + ByteStringKind kind : 2; + } header; +#ifdef _MSC_VER +#pragma pop(pack) +#endif + SmallByteStringRep small; + MediumByteStringRep medium; + LargeByteStringRep large; +}; + +// `ByteString` is an vocabulary type capable of representing copy-on-write +// strings efficiently for arenas and reference counting. The contents of the +// byte string are owned by an arena or managed by a reference count. All byte +// strings have an associated allocator specified at construction, once the byte +// string is constructed the allocator will not and cannot change. Copying and +// moving between different allocators is supported and dealt with +// transparently by copying. +class CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI + [[nodiscard]] ByteString final { + public: + static ByteString Owned(Allocator<> allocator, const char* string) { + return ByteString(allocator, string); + } + + static ByteString Owned(Allocator<> allocator, absl::string_view string) { + return ByteString(allocator, string); + } + + static ByteString Owned(Allocator<> allocator, const std::string& string) { + return ByteString(allocator, string); + } + + static ByteString Owned(Allocator<> allocator, std::string&& string) { + return ByteString(allocator, std::move(string)); + } + + static ByteString Owned(Allocator<> allocator, const absl::Cord& cord) { + return ByteString(allocator, cord); + } + + static ByteString Owned(Allocator<> allocator, ByteStringView other); + + static ByteString Borrowed( + Owner owner, absl::string_view string ABSL_ATTRIBUTE_LIFETIME_BOUND); + + static ByteString Borrowed( + const Owner& owner, const absl::Cord& cord ABSL_ATTRIBUTE_LIFETIME_BOUND); + + ByteString() noexcept : ByteString(NewDeleteAllocator()) {} + + explicit ByteString(const char* string) + : ByteString(NewDeleteAllocator(), string) {} + + explicit ByteString(absl::string_view string) + : ByteString(NewDeleteAllocator(), string) {} + + explicit ByteString(const std::string& string) + : ByteString(NewDeleteAllocator(), string) {} + + explicit ByteString(std::string&& string) + : ByteString(NewDeleteAllocator(), std::move(string)) {} + + explicit ByteString(const absl::Cord& cord) + : ByteString(NewDeleteAllocator(), cord) {} + + explicit ByteString(ByteStringView other); + + ByteString(const ByteString& other) : ByteString(other.GetArena(), other) {} + + ByteString(ByteString&& other) + : ByteString(other.GetArena(), std::move(other)) {} + + explicit ByteString(Allocator<> allocator) noexcept { + SetSmallEmpty(allocator.arena()); + } + + ByteString(Allocator<> allocator, const char* string) + : ByteString(allocator, absl::NullSafeStringView(string)) {} + + ByteString(Allocator<> allocator, absl::string_view string); + + ByteString(Allocator<> allocator, const std::string& string); + + ByteString(Allocator<> allocator, std::string&& string); + + ByteString(Allocator<> allocator, const absl::Cord& cord); + + ByteString(Allocator<> allocator, ByteStringView other); + + ByteString(Allocator<> allocator, const ByteString& other) + : ByteString(allocator) { + CopyFrom(other); + } + + ByteString(Allocator<> allocator, ByteString&& other) + : ByteString(allocator) { + MoveFrom(other); + } + + ~ByteString() { Destroy(); } + + ByteString& operator=(const ByteString& other) { + if (ABSL_PREDICT_TRUE(this != &other)) { + CopyFrom(other); + } + return *this; + } + + ByteString& operator=(ByteString&& other) { + if (ABSL_PREDICT_TRUE(this != &other)) { + MoveFrom(other); + } + return *this; + } + + ByteString& operator=(ByteStringView other); + + bool empty() const noexcept; + + size_t size() const noexcept; + + size_t max_size() const noexcept { return kByteStringViewMaxSize; } + + absl::string_view Flatten() ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::optional TryFlat() const noexcept + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::string_view GetFlat(std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) + const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + bool Equals(ByteStringView rhs) const noexcept; + + int Compare(ByteStringView rhs) const noexcept; + + bool StartsWith(ByteStringView rhs) const noexcept; + + bool EndsWith(ByteStringView rhs) const noexcept; + + void RemovePrefix(size_t n); + + void RemoveSuffix(size_t n); + + std::string ToString() const; + + absl::Cord ToCord() const&; + + absl::Cord ToCord() &&; + + absl::Nullable GetArena() const noexcept; + + void HashValue(absl::HashState state) const; + + void swap(ByteString& other) { + if (ABSL_PREDICT_TRUE(this != &other)) { + Swap(other); + } + } + + template + std::common_type_t, + std::invoke_result_t> + Visit(Visitor&& visitor) const { + switch (GetKind()) { + case ByteStringKind::kSmall: + return std::invoke(std::forward(visitor), GetSmall()); + case ByteStringKind::kMedium: + return std::invoke(std::forward(visitor), GetMedium()); + case ByteStringKind::kLarge: + return std::invoke(std::forward(visitor), GetLarge()); + } + } + + friend void swap(ByteString& lhs, ByteString& rhs) { lhs.swap(rhs); } + + private: + friend class ByteStringView; + friend struct ByteStringTestFriend; + + ByteString(absl::Nonnull refcount, + absl::string_view string); + + constexpr ByteStringKind GetKind() const noexcept { return rep_.header.kind; } + + absl::string_view GetSmall() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + return GetSmall(rep_.small); + } + + static absl::string_view GetSmall(const SmallByteStringRep& rep) noexcept { + return absl::string_view(rep.data, rep.size); + } + + absl::string_view GetMedium() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + return GetMedium(rep_.medium); + } + + static absl::string_view GetMedium(const MediumByteStringRep& rep) noexcept { + return absl::string_view(rep.data, rep.size); + } + + absl::Nullable GetSmallArena() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kSmall); + return GetSmallArena(rep_.small); + } + + static absl::Nullable GetSmallArena( + const SmallByteStringRep& rep) noexcept { + return rep.arena; + } + + absl::Nullable GetMediumArena() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + return GetMediumArena(rep_.medium); + } + + static absl::Nullable GetMediumArena( + const MediumByteStringRep& rep) noexcept; + + absl::Nullable GetMediumReferenceCount() + const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + return GetMediumReferenceCount(rep_.medium); + } + + static absl::Nullable GetMediumReferenceCount( + const MediumByteStringRep& rep) noexcept; + + uintptr_t GetMediumOwner() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + return rep_.medium.owner; + } + + absl::Cord& GetLarge() noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + return GetLarge(rep_.large); + } + + static absl::Cord& GetLarge( + LargeByteStringRep& rep ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return *std::launder(reinterpret_cast(&rep.data[0])); + } + + const absl::Cord& GetLarge() const noexcept ABSL_ATTRIBUTE_LIFETIME_BOUND { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + return GetLarge(rep_.large); + } + + static const absl::Cord& GetLarge( + const LargeByteStringRep& rep ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return *std::launder(reinterpret_cast(&rep.data[0])); + } + + void SetSmallEmpty(absl::Nullable arena); + + void SetSmall(absl::Nullable arena, absl::string_view string); + + void SetSmall(absl::Nullable arena, const absl::Cord& cord); + + void SetMedium(absl::Nullable arena, + absl::string_view string); + + void SetMedium(absl::Nullable arena, std::string&& string); + + void SetMedium(absl::Nonnull arena, const absl::Cord& cord); + + void SetMedium(absl::string_view string, uintptr_t owner); + + void SetMediumOrLarge(absl::Nullable arena, + const absl::Cord& cord); + + void SetMediumOrLarge(absl::Nullable arena, + absl::Cord&& cord); + + void SetLarge(const absl::Cord& cord); + + void SetLarge(absl::Cord&& cord); + + void Swap(ByteString& other); + + static void SwapSmallSmall(ByteString& lhs, ByteString& rhs); + static void SwapSmallMedium(ByteString& lhs, ByteString& rhs); + static void SwapSmallLarge(ByteString& lhs, ByteString& rhs); + static void SwapMediumMedium(ByteString& lhs, ByteString& rhs); + static void SwapMediumLarge(ByteString& lhs, ByteString& rhs); + static void SwapLargeLarge(ByteString& lhs, ByteString& rhs); + + void CopyFrom(const ByteString& other); + + void CopyFromSmallSmall(const ByteString& other); + void CopyFromSmallMedium(const ByteString& other); + void CopyFromSmallLarge(const ByteString& other); + void CopyFromMediumSmall(const ByteString& other); + void CopyFromMediumMedium(const ByteString& other); + void CopyFromMediumLarge(const ByteString& other); + void CopyFromLargeSmall(const ByteString& other); + void CopyFromLargeMedium(const ByteString& other); + void CopyFromLargeLarge(const ByteString& other); + + void CopyFrom(ByteStringView other); + + void CopyFromSmallString(ByteStringView other); + void CopyFromSmallCord(ByteStringView other); + void CopyFromMediumString(ByteStringView other); + void CopyFromMediumCord(ByteStringView other); + void CopyFromLargeString(ByteStringView other); + void CopyFromLargeCord(ByteStringView other); + + void MoveFrom(ByteString& other); + + void MoveFromSmallSmall(ByteString& other); + void MoveFromSmallMedium(ByteString& other); + void MoveFromSmallLarge(ByteString& other); + void MoveFromMediumSmall(ByteString& other); + void MoveFromMediumMedium(ByteString& other); + void MoveFromMediumLarge(ByteString& other); + void MoveFromLargeSmall(ByteString& other); + void MoveFromLargeMedium(ByteString& other); + void MoveFromLargeLarge(ByteString& other); + + void Destroy() noexcept; + + void DestroyMedium() noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kMedium); + DestroyMedium(rep_.medium); + } + + static void DestroyMedium(const MediumByteStringRep& rep) noexcept { + StrongUnref(GetMediumReferenceCount(rep)); + } + + void DestroyLarge() noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringKind::kLarge); + DestroyLarge(rep_.large); + } + + static void DestroyLarge(LargeByteStringRep& rep) noexcept { + GetLarge(rep).~Cord(); + } + + ByteStringRep rep_; +}; + +template +H AbslHashValue(H state, const ByteString& byte_string) { + byte_string.HashValue(absl::HashState::Create(&state)); + return state; +} + +enum class ByteStringViewKind : unsigned int { + kString = 0, + kCord, +}; + +inline std::ostream& operator<<(std::ostream& out, ByteStringViewKind kind) { + switch (kind) { + case ByteStringViewKind::kString: + return out << "STRING"; + case ByteStringViewKind::kCord: + return out << "CORD"; + } +} + +struct StringByteStringViewRep final { +#ifdef _MSC_VER +#pragma push(pack, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED { + ByteStringViewKind kind : 1; + size_t size : kByteStringViewSizeBits; + }; +#ifdef _MSC_VER +#pragma pop(pack) +#endif + const char* data; + uintptr_t owner; +}; + +struct CordByteStringViewRep final { +#ifdef _MSC_VER +#pragma push(pack, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED { + ByteStringViewKind kind : 1; + size_t size : kByteStringViewSizeBits; + }; +#ifdef _MSC_VER +#pragma pop(pack) +#endif + const absl::Cord* data; + size_t pos; +}; + +union ByteStringViewRep final { +#ifdef _MSC_VER +#pragma push(pack, 1) +#endif + struct ABSL_ATTRIBUTE_PACKED { + ByteStringViewKind kind : 1; + size_t size : kByteStringViewSizeBits; + } header; +#ifdef _MSC_VER +#pragma pop(pack) +#endif + StringByteStringViewRep string; + CordByteStringViewRep cord; +}; + +// `ByteStringView` is to `ByteString` what `std::string_view` is to +// `std::string`. While it is capable of being a view over the underlying data +// of `ByteStringView`, it is also capable of being a view over `std::string`, +// `std::string_view`, and `absl::Cord`. +class ByteStringView final { + public: + ByteStringView() noexcept { + rep_.header.kind = ByteStringViewKind::kString; + rep_.string.size = 0; + rep_.string.data = ""; + rep_.string.owner = 0; + } + + ByteStringView(const ByteStringView&) = default; + ByteStringView(ByteStringView&&) = default; + ByteStringView& operator=(const ByteStringView&) = default; + ByteStringView& operator=(ByteStringView&&) = default; + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView(const char* string ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept + : ByteStringView(absl::NullSafeStringView(string)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView( + absl::string_view string ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + ABSL_DCHECK_LE(string.size(), max_size()); + rep_.header.kind = ByteStringViewKind::kString; + rep_.string.size = string.size(); + rep_.string.data = string.data(); + rep_.string.owner = 0; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView( + const std::string& string ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept + : ByteStringView(absl::string_view(string)) {} + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView( + const absl::Cord& cord ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + ABSL_DCHECK_LE(cord.size(), max_size()); + rep_.header.kind = ByteStringViewKind::kCord; + rep_.cord.size = cord.size(); + rep_.cord.data = &cord; + rep_.cord.pos = 0; + } + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView( + const ByteString& other ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept; + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView& operator=( + const char* string ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return *this = ByteStringView(string); + } + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView& operator=( + absl::string_view string ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return *this = ByteStringView(string); + } + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView& operator=( + const std::string& string ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return *this = ByteStringView(string); + } + + ByteStringView& operator=(std::string&&) = delete; + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView& operator=( + const absl::Cord& cord ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return *this = ByteStringView(cord); + } + + ByteStringView& operator=(absl::Cord&&) = delete; + + // NOLINTNEXTLINE(google-explicit-constructor) + ByteStringView& operator=( + const ByteString& other ABSL_ATTRIBUTE_LIFETIME_BOUND) noexcept { + return *this = ByteStringView(other); + } + + ByteStringView& operator=(ByteString&&) = delete; + + bool empty() const noexcept; + + size_t size() const noexcept; + + size_t max_size() const noexcept { return kByteStringViewMaxSize; } + + absl::optional TryFlat() const noexcept + ABSL_ATTRIBUTE_LIFETIME_BOUND; + + absl::string_view GetFlat(std::string& scratch ABSL_ATTRIBUTE_LIFETIME_BOUND) + const ABSL_ATTRIBUTE_LIFETIME_BOUND; + + bool Equals(ByteStringView rhs) const noexcept; + + int Compare(ByteStringView rhs) const noexcept; + + bool StartsWith(ByteStringView rhs) const noexcept; + + bool EndsWith(ByteStringView rhs) const noexcept; + + void RemovePrefix(size_t n); + + void RemoveSuffix(size_t n); + + std::string ToString() const; + + absl::Cord ToCord() const; + + absl::Nullable GetArena() const noexcept; + + void HashValue(absl::HashState state) const; + + template + std::common_type_t, + std::invoke_result_t> + Visit(Visitor&& visitor) const { + switch (GetKind()) { + case ByteStringViewKind::kString: + return std::invoke(std::forward(visitor), GetString()); + case ByteStringViewKind::kCord: + return std::invoke(std::forward(visitor), + static_cast(GetSubcord())); + } + } + + private: + friend class ByteString; + friend struct ByteStringViewTestFriend; + + constexpr ByteStringViewKind GetKind() const noexcept { + return rep_.header.kind; + } + + absl::string_view GetString() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringViewKind::kString); + return absl::string_view(rep_.string.data, rep_.string.size); + } + + absl::Nullable GetStringArena() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringViewKind::kString); + if ((rep_.string.owner & kMetadataOwnerBits) == kMetadataOwnerArenaBit) { + return reinterpret_cast(rep_.string.owner & + kMetadataOwnerPointerMask); + } + return nullptr; + } + + absl::Nullable GetStringReferenceCount() + const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringViewKind::kString); + return GetStringReferenceCount(rep_.string); + } + + static absl::Nullable GetStringReferenceCount( + const StringByteStringViewRep& rep) noexcept { + if ((rep.owner & kMetadataOwnerBits) == kMetadataOwnerReferenceCountBit) { + return reinterpret_cast(rep.owner & + kMetadataOwnerPointerMask); + } + return nullptr; + } + + uintptr_t GetStringOwner() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringViewKind::kString); + return rep_.string.owner; + } + + const absl::Cord& GetCord() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringViewKind::kCord); + return *rep_.cord.data; + } + + absl::Cord GetSubcord() const noexcept { + ABSL_DCHECK_EQ(GetKind(), ByteStringViewKind::kCord); + return GetCord().Subcord(rep_.cord.pos, rep_.cord.size); + } + + ByteStringViewRep rep_; +}; + +inline bool operator==(const ByteString& lhs, const ByteString& rhs) noexcept { + return lhs.Equals(rhs); +} + +inline bool operator!=(const ByteString& lhs, const ByteString& rhs) noexcept { + return !operator==(lhs, rhs); +} + +inline bool operator<(const ByteString& lhs, const ByteString& rhs) noexcept { + return lhs.Compare(rhs) < 0; +} + +inline bool operator<=(const ByteString& lhs, const ByteString& rhs) noexcept { + return lhs.Compare(rhs) <= 0; +} + +inline bool operator>(const ByteString& lhs, const ByteString& rhs) noexcept { + return lhs.Compare(rhs) > 0; +} + +inline bool operator>=(const ByteString& lhs, const ByteString& rhs) noexcept { + return lhs.Compare(rhs) >= 0; +} + +inline bool ByteString::Equals(ByteStringView rhs) const noexcept { + return ByteStringView(*this).Equals(rhs); +} + +inline int ByteString::Compare(ByteStringView rhs) const noexcept { + return ByteStringView(*this).Compare(rhs); +} + +inline bool ByteString::StartsWith(ByteStringView rhs) const noexcept { + return ByteStringView(*this).StartsWith(rhs); +} + +inline bool ByteString::EndsWith(ByteStringView rhs) const noexcept { + return ByteStringView(*this).EndsWith(rhs); +} + +inline bool operator==(ByteStringView lhs, ByteStringView rhs) noexcept { + return lhs.Equals(rhs); +} + +inline bool operator!=(ByteStringView lhs, ByteStringView rhs) noexcept { + return !operator==(lhs, rhs); +} + +inline bool operator<(ByteStringView lhs, ByteStringView rhs) noexcept { + return lhs.Compare(rhs) < 0; +} + +inline bool operator<=(ByteStringView lhs, ByteStringView rhs) noexcept { + return lhs.Compare(rhs) <= 0; +} + +inline bool operator>(ByteStringView lhs, ByteStringView rhs) noexcept { + return lhs.Compare(rhs) > 0; +} + +inline bool operator>=(ByteStringView lhs, ByteStringView rhs) noexcept { + return lhs.Compare(rhs) >= 0; +} + +template +H AbslHashValue(H state, ByteStringView byte_string_view) { + byte_string_view.HashValue(absl::HashState::Create(&state)); + return state; +} + +inline ByteString ByteString::Owned(Allocator<> allocator, + ByteStringView other) { + return ByteString(allocator, other); +} + +inline ByteString::ByteString(ByteStringView other) + : ByteString(NewDeleteAllocator(), other) {} + +inline ByteString::ByteString(Allocator<> allocator, ByteStringView other) + : ByteString(allocator) { + CopyFrom(other); +} + +inline ByteString& ByteString::operator=(ByteStringView other) { + CopyFrom(other); + return *this; +} + +#undef CEL_COMMON_INTERNAL_BYTE_STRING_TRIVIAL_ABI + +} // namespace cel::common_internal + +#endif // THIRD_PARTY_CEL_CPP_COMMON_INTERNAL_BYTE_STRING_H_ diff --git a/common/internal/byte_string_test.cc b/common/internal/byte_string_test.cc new file mode 100644 index 000000000..46ba4c50f --- /dev/null +++ b/common/internal/byte_string_test.cc @@ -0,0 +1,1154 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "common/internal/byte_string.h" + +#include +#include +#include +#include + +#include "absl/base/no_destructor.h" +#include "absl/hash/hash.h" +#include "absl/strings/cord.h" +#include "absl/strings/cord_test_helpers.h" +#include "absl/strings/string_view.h" +#include "absl/types/optional.h" +#include "common/allocator.h" +#include "common/internal/reference_count.h" +#include "common/memory.h" +#include "internal/testing.h" +#include "google/protobuf/arena.h" + +namespace cel::common_internal { + +struct ByteStringTestFriend { + static ByteStringKind GetKind(const ByteString& byte_string) { + return byte_string.GetKind(); + } +}; + +struct ByteStringViewTestFriend { + static ByteStringViewKind GetKind(ByteStringView byte_string_view) { + return byte_string_view.GetKind(); + } +}; + +namespace { + +using testing::Eq; +using testing::IsEmpty; +using testing::Not; +using testing::Optional; +using testing::SizeIs; +using testing::TestWithParam; + +TEST(ByteStringKind, Ostream) { + { + std::ostringstream out; + out << ByteStringKind::kSmall; + EXPECT_EQ(out.str(), "SMALL"); + } + { + std::ostringstream out; + out << ByteStringKind::kMedium; + EXPECT_EQ(out.str(), "MEDIUM"); + } + { + std::ostringstream out; + out << ByteStringKind::kLarge; + EXPECT_EQ(out.str(), "LARGE"); + } +} + +TEST(ByteStringViewKind, Ostream) { + { + std::ostringstream out; + out << ByteStringViewKind::kString; + EXPECT_EQ(out.str(), "STRING"); + } + { + std::ostringstream out; + out << ByteStringViewKind::kCord; + EXPECT_EQ(out.str(), "CORD"); + } +} + +class ByteStringTest : public TestWithParam, + public ByteStringTestFriend { + public: + Allocator<> GetAllocator() { + switch (GetParam()) { + case MemoryManagement::kPooling: + return ArenaAllocator(&arena_); + case MemoryManagement::kReferenceCounting: + return NewDeleteAllocator(); + } + } + + private: + google::protobuf::Arena arena_; +}; + +absl::string_view GetSmallStringView() { + static constexpr absl::string_view small = "A small string!"; + return small.substr(0, std::min(kSmallByteStringCapacity, small.size())); +} + +std::string GetSmallString() { return std::string(GetSmallStringView()); } + +absl::Cord GetSmallCord() { + static const absl::NoDestructor small(GetSmallStringView()); + return *small; +} + +absl::string_view GetMediumStringView() { + static constexpr absl::string_view medium = + "A string that is too large for the small string optimization!"; + return medium; +} + +std::string GetMediumString() { return std::string(GetMediumStringView()); } + +const absl::Cord& GetMediumOrLargeCord() { + static const absl::NoDestructor medium_or_large( + GetMediumStringView()); + return *medium_or_large; +} + +const absl::Cord& GetMediumOrLargeFragmentedCord() { + static const absl::NoDestructor medium_or_large( + absl::MakeFragmentedCord( + {GetMediumStringView().substr(0, kSmallByteStringCapacity), + GetMediumStringView().substr(kSmallByteStringCapacity)})); + return *medium_or_large; +} + +TEST_P(ByteStringTest, Default) { + ByteString byte_string = ByteString::Owned(GetAllocator(), ""); + EXPECT_THAT(byte_string, SizeIs(0)); + EXPECT_THAT(byte_string, IsEmpty()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); +} + +TEST_P(ByteStringTest, ConstructSmallCString) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallString().c_str()); + EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructMediumCString) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumString().c_str()); + EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructSmallRValueString) { + ByteString byte_string = ByteString::Owned(GetAllocator(), GetSmallString()); + EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructSmallLValueString) { + ByteString byte_string = ByteString::Owned( + GetAllocator(), static_cast(GetSmallString())); + EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructMediumRValueString) { + ByteString byte_string = ByteString::Owned(GetAllocator(), GetMediumString()); + EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructMediumLValueString) { + ByteString byte_string = ByteString::Owned( + GetAllocator(), static_cast(GetMediumString())); + EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructSmallCord) { + ByteString byte_string = ByteString::Owned(GetAllocator(), GetSmallCord()); + EXPECT_THAT(byte_string, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringTest, ConstructMediumOrLargeCord) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_THAT(byte_string, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string, Not(IsEmpty())); + EXPECT_EQ(byte_string, GetMediumStringView()); + if (GetAllocator().arena() == nullptr) { + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); + } else { + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + } + EXPECT_EQ(byte_string.GetArena(), GetAllocator().arena()); +} + +TEST(ByteStringTest, BorrowedUnownedString) { +#ifdef NDEBUG + ByteString byte_string = + ByteString::Borrowed(Owner::None(), GetMediumStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), + ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetMediumStringView()); +#else + EXPECT_DEBUG_DEATH(static_cast(ByteString::Borrowed( + Owner::None(), GetMediumStringView())), + ::testing::_); +#endif +} + +TEST(ByteStringTest, BorrowedUnownedCord) { +#ifdef NDEBUG + ByteString byte_string = + ByteString::Borrowed(Owner::None(), GetMediumOrLargeCord()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kLarge); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetMediumOrLargeCord()); +#else + EXPECT_DEBUG_DEATH(static_cast(ByteString::Borrowed( + Owner::None(), GetMediumOrLargeCord())), + ::testing::_); +#endif +} + +TEST(ByteStringTest, BorrowedReferenceCountSmallString) { + auto* refcount = new ReferenceCounted(); + Owner owner = Owner::ReferenceCount(refcount); + StrongUnref(refcount); + ByteString byte_string = ByteString::Borrowed(owner, GetSmallStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetSmallStringView()); +} + +TEST(ByteStringTest, BorrowedReferenceCountMediumString) { + auto* refcount = new ReferenceCounted(); + Owner owner = Owner::ReferenceCount(refcount); + StrongUnref(refcount); + ByteString byte_string = ByteString::Borrowed(owner, GetMediumStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), + ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetMediumStringView()); +} + +TEST(ByteStringTest, BorrowedArenaSmallString) { + google::protobuf::Arena arena; + ByteString byte_string = + ByteString::Borrowed(Owner::Arena(&arena), GetSmallStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetArena(), &arena); + EXPECT_EQ(byte_string, GetSmallStringView()); +} + +TEST(ByteStringTest, BorrowedArenaMediumString) { + google::protobuf::Arena arena; + ByteString byte_string = + ByteString::Borrowed(Owner::Arena(&arena), GetMediumStringView()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), + ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), &arena); + EXPECT_EQ(byte_string, GetMediumStringView()); +} + +TEST(ByteStringTest, BorrowedReferenceCountCord) { + auto* refcount = new ReferenceCounted(); + Owner owner = Owner::ReferenceCount(refcount); + StrongUnref(refcount); + ByteString byte_string = ByteString::Borrowed(owner, GetMediumOrLargeCord()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), ByteStringKind::kLarge); + EXPECT_EQ(byte_string.GetArena(), nullptr); + EXPECT_EQ(byte_string, GetMediumOrLargeCord()); +} + +TEST(ByteStringTest, BorrowedArenaCord) { + google::protobuf::Arena arena; + Owner owner = Owner::Arena(&arena); + ByteString byte_string = ByteString::Borrowed(owner, GetMediumOrLargeCord()); + EXPECT_EQ(ByteStringTestFriend::GetKind(byte_string), + ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetArena(), &arena); + EXPECT_EQ(byte_string, GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, CopyFromByteStringView) { + ByteString small_byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + ByteString medium_byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + ByteString large_byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + + ByteString new_delete_byte_string(NewDeleteAllocator()); + // Small <= Small + new_delete_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(small_byte_string)); + // Small <= Medium + new_delete_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Medium + new_delete_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Large + new_delete_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(large_byte_string)); + // Large <= Large + new_delete_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(large_byte_string)); + // Large <= Small + new_delete_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(small_byte_string)); + // Small <= Large + new_delete_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(large_byte_string)); + // Large <= Medium + new_delete_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Small + new_delete_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(new_delete_byte_string, ByteStringView(small_byte_string)); + + google::protobuf::Arena arena; + ByteString arena_byte_string(ArenaAllocator(&arena)); + // Small <= Small + arena_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(small_byte_string)); + // Small <= Medium + arena_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Medium + arena_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Large + arena_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(large_byte_string)); + // Large <= Large + arena_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(large_byte_string)); + // Large <= Small + arena_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(small_byte_string)); + // Small <= Large + arena_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(large_byte_string)); + // Large <= Medium + arena_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Small + arena_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(arena_byte_string, ByteStringView(small_byte_string)); + + ByteString allocator_byte_string(GetAllocator()); + // Small <= Small + allocator_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(small_byte_string)); + // Small <= Medium + allocator_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Medium + allocator_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Large + allocator_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(large_byte_string)); + // Large <= Large + allocator_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(large_byte_string)); + // Large <= Small + allocator_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(small_byte_string)); + // Small <= Large + allocator_byte_string = ByteStringView(large_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(large_byte_string)); + // Large <= Medium + allocator_byte_string = ByteStringView(medium_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(medium_byte_string)); + // Medium <= Small + allocator_byte_string = ByteStringView(small_byte_string); + EXPECT_EQ(allocator_byte_string, ByteStringView(small_byte_string)); + + // Miscellaneous cases not covered above. + // Small <= Small Cord + allocator_byte_string = ByteStringView(absl::Cord(GetSmallStringView())); + EXPECT_EQ(allocator_byte_string, GetSmallStringView()); + allocator_byte_string = ByteStringView(medium_byte_string); + // Medium <= Small Cord + allocator_byte_string = ByteStringView(absl::Cord(GetSmallStringView())); + EXPECT_EQ(allocator_byte_string, GetSmallStringView()); + // Large <= Small Cord + allocator_byte_string = ByteStringView(large_byte_string); + allocator_byte_string = ByteStringView(absl::Cord(GetSmallStringView())); + EXPECT_EQ(allocator_byte_string, GetSmallStringView()); + // Large <= Medium Arena String + ByteString large_new_delete_byte_string(NewDeleteAllocator(), + GetMediumOrLargeCord()); + ByteString medium_arena_byte_string(ArenaAllocator(&arena), + GetMediumStringView()); + large_new_delete_byte_string = ByteStringView(medium_arena_byte_string); + EXPECT_EQ(large_new_delete_byte_string, medium_arena_byte_string); +} + +TEST_P(ByteStringTest, CopyFromByteString) { + ByteString small_byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + ByteString medium_byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + ByteString large_byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + + ByteString new_delete_byte_string(NewDeleteAllocator()); + // Small <= Small + new_delete_byte_string = small_byte_string; + EXPECT_EQ(new_delete_byte_string, small_byte_string); + // Small <= Medium + new_delete_byte_string = medium_byte_string; + EXPECT_EQ(new_delete_byte_string, medium_byte_string); + // Medium <= Medium + new_delete_byte_string = medium_byte_string; + EXPECT_EQ(new_delete_byte_string, medium_byte_string); + // Medium <= Large + new_delete_byte_string = large_byte_string; + EXPECT_EQ(new_delete_byte_string, large_byte_string); + // Large <= Large + new_delete_byte_string = large_byte_string; + EXPECT_EQ(new_delete_byte_string, large_byte_string); + // Large <= Small + new_delete_byte_string = small_byte_string; + EXPECT_EQ(new_delete_byte_string, small_byte_string); + // Small <= Large + new_delete_byte_string = large_byte_string; + EXPECT_EQ(new_delete_byte_string, large_byte_string); + // Large <= Medium + new_delete_byte_string = medium_byte_string; + EXPECT_EQ(new_delete_byte_string, medium_byte_string); + // Medium <= Small + new_delete_byte_string = small_byte_string; + EXPECT_EQ(new_delete_byte_string, small_byte_string); + + google::protobuf::Arena arena; + ByteString arena_byte_string(ArenaAllocator(&arena)); + // Small <= Small + arena_byte_string = small_byte_string; + EXPECT_EQ(arena_byte_string, small_byte_string); + // Small <= Medium + arena_byte_string = medium_byte_string; + EXPECT_EQ(arena_byte_string, medium_byte_string); + // Medium <= Medium + arena_byte_string = medium_byte_string; + EXPECT_EQ(arena_byte_string, medium_byte_string); + // Medium <= Large + arena_byte_string = large_byte_string; + EXPECT_EQ(arena_byte_string, large_byte_string); + // Large <= Large + arena_byte_string = large_byte_string; + EXPECT_EQ(arena_byte_string, large_byte_string); + // Large <= Small + arena_byte_string = small_byte_string; + EXPECT_EQ(arena_byte_string, small_byte_string); + // Small <= Large + arena_byte_string = large_byte_string; + EXPECT_EQ(arena_byte_string, large_byte_string); + // Large <= Medium + arena_byte_string = medium_byte_string; + EXPECT_EQ(arena_byte_string, medium_byte_string); + // Medium <= Small + arena_byte_string = small_byte_string; + EXPECT_EQ(arena_byte_string, small_byte_string); + + ByteString allocator_byte_string(GetAllocator()); + // Small <= Small + allocator_byte_string = small_byte_string; + EXPECT_EQ(allocator_byte_string, small_byte_string); + // Small <= Medium + allocator_byte_string = medium_byte_string; + EXPECT_EQ(allocator_byte_string, medium_byte_string); + // Medium <= Medium + allocator_byte_string = medium_byte_string; + EXPECT_EQ(allocator_byte_string, medium_byte_string); + // Medium <= Large + allocator_byte_string = large_byte_string; + EXPECT_EQ(allocator_byte_string, large_byte_string); + // Large <= Large + allocator_byte_string = large_byte_string; + EXPECT_EQ(allocator_byte_string, large_byte_string); + // Large <= Small + allocator_byte_string = small_byte_string; + EXPECT_EQ(allocator_byte_string, small_byte_string); + // Small <= Large + allocator_byte_string = large_byte_string; + EXPECT_EQ(allocator_byte_string, large_byte_string); + // Large <= Medium + allocator_byte_string = medium_byte_string; + EXPECT_EQ(allocator_byte_string, medium_byte_string); + // Medium <= Small + allocator_byte_string = small_byte_string; + EXPECT_EQ(allocator_byte_string, small_byte_string); + + // Miscellaneous cases not covered above. + // Large <= Medium Arena String + ByteString large_new_delete_byte_string(NewDeleteAllocator(), + GetMediumOrLargeCord()); + ByteString medium_arena_byte_string(ArenaAllocator(&arena), + GetMediumStringView()); + large_new_delete_byte_string = medium_arena_byte_string; + EXPECT_EQ(large_new_delete_byte_string, medium_arena_byte_string); +} + +TEST_P(ByteStringTest, MoveFrom) { + const auto& small_byte_string = [this]() { + return ByteString::Owned(GetAllocator(), GetSmallStringView()); + }; + const auto& medium_byte_string = [this]() { + return ByteString::Owned(GetAllocator(), GetMediumStringView()); + }; + const auto& large_byte_string = [this]() { + return ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + }; + + ByteString new_delete_byte_string(NewDeleteAllocator()); + // Small <= Small + new_delete_byte_string = small_byte_string(); + EXPECT_EQ(new_delete_byte_string, small_byte_string()); + // Small <= Medium + new_delete_byte_string = medium_byte_string(); + EXPECT_EQ(new_delete_byte_string, medium_byte_string()); + // Medium <= Medium + new_delete_byte_string = medium_byte_string(); + EXPECT_EQ(new_delete_byte_string, medium_byte_string()); + // Medium <= Large + new_delete_byte_string = large_byte_string(); + EXPECT_EQ(new_delete_byte_string, large_byte_string()); + // Large <= Large + new_delete_byte_string = large_byte_string(); + EXPECT_EQ(new_delete_byte_string, large_byte_string()); + // Large <= Small + new_delete_byte_string = small_byte_string(); + EXPECT_EQ(new_delete_byte_string, small_byte_string()); + // Small <= Large + new_delete_byte_string = large_byte_string(); + EXPECT_EQ(new_delete_byte_string, large_byte_string()); + // Large <= Medium + new_delete_byte_string = medium_byte_string(); + EXPECT_EQ(new_delete_byte_string, medium_byte_string()); + // Medium <= Small + new_delete_byte_string = small_byte_string(); + EXPECT_EQ(new_delete_byte_string, small_byte_string()); + + google::protobuf::Arena arena; + ByteString arena_byte_string(ArenaAllocator(&arena)); + // Small <= Small + arena_byte_string = small_byte_string(); + EXPECT_EQ(arena_byte_string, small_byte_string()); + // Small <= Medium + arena_byte_string = medium_byte_string(); + EXPECT_EQ(arena_byte_string, medium_byte_string()); + // Medium <= Medium + arena_byte_string = medium_byte_string(); + EXPECT_EQ(arena_byte_string, medium_byte_string()); + // Medium <= Large + arena_byte_string = large_byte_string(); + EXPECT_EQ(arena_byte_string, large_byte_string()); + // Large <= Large + arena_byte_string = large_byte_string(); + EXPECT_EQ(arena_byte_string, large_byte_string()); + // Large <= Small + arena_byte_string = small_byte_string(); + EXPECT_EQ(arena_byte_string, small_byte_string()); + // Small <= Large + arena_byte_string = large_byte_string(); + EXPECT_EQ(arena_byte_string, large_byte_string()); + // Large <= Medium + arena_byte_string = medium_byte_string(); + EXPECT_EQ(arena_byte_string, medium_byte_string()); + // Medium <= Small + arena_byte_string = small_byte_string(); + EXPECT_EQ(arena_byte_string, small_byte_string()); + + ByteString allocator_byte_string(GetAllocator()); + // Small <= Small + allocator_byte_string = small_byte_string(); + EXPECT_EQ(allocator_byte_string, small_byte_string()); + // Small <= Medium + allocator_byte_string = medium_byte_string(); + EXPECT_EQ(allocator_byte_string, medium_byte_string()); + // Medium <= Medium + allocator_byte_string = medium_byte_string(); + EXPECT_EQ(allocator_byte_string, medium_byte_string()); + // Medium <= Large + allocator_byte_string = large_byte_string(); + EXPECT_EQ(allocator_byte_string, large_byte_string()); + // Large <= Large + allocator_byte_string = large_byte_string(); + EXPECT_EQ(allocator_byte_string, large_byte_string()); + // Large <= Small + allocator_byte_string = small_byte_string(); + EXPECT_EQ(allocator_byte_string, small_byte_string()); + // Small <= Large + allocator_byte_string = large_byte_string(); + EXPECT_EQ(allocator_byte_string, large_byte_string()); + // Large <= Medium + allocator_byte_string = medium_byte_string(); + EXPECT_EQ(allocator_byte_string, medium_byte_string()); + // Medium <= Small + allocator_byte_string = small_byte_string(); + EXPECT_EQ(allocator_byte_string, small_byte_string()); + + // Miscellaneous cases not covered above. + // Large <= Medium Arena String + ByteString large_new_delete_byte_string(NewDeleteAllocator(), + GetMediumOrLargeCord()); + ByteString medium_arena_byte_string(ArenaAllocator(&arena), + GetMediumStringView()); + large_new_delete_byte_string = std::move(medium_arena_byte_string); + EXPECT_EQ(large_new_delete_byte_string, GetMediumStringView()); +} + +TEST_P(ByteStringTest, Swap) { + using std::swap; + ByteString empty_byte_string(GetAllocator()); + ByteString small_byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + ByteString medium_byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + ByteString large_byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + + // Small <=> Small + swap(empty_byte_string, small_byte_string); + EXPECT_EQ(empty_byte_string, GetSmallStringView()); + EXPECT_EQ(small_byte_string, ""); + swap(empty_byte_string, small_byte_string); + EXPECT_EQ(empty_byte_string, ""); + EXPECT_EQ(small_byte_string, GetSmallStringView()); + + // Small <=> Medium + swap(small_byte_string, medium_byte_string); + EXPECT_EQ(small_byte_string, GetMediumStringView()); + EXPECT_EQ(medium_byte_string, GetSmallStringView()); + swap(small_byte_string, medium_byte_string); + EXPECT_EQ(small_byte_string, GetSmallStringView()); + EXPECT_EQ(medium_byte_string, GetMediumStringView()); + + // Small <=> Large + swap(small_byte_string, large_byte_string); + EXPECT_EQ(small_byte_string, GetMediumOrLargeCord()); + EXPECT_EQ(large_byte_string, GetSmallStringView()); + swap(small_byte_string, large_byte_string); + EXPECT_EQ(small_byte_string, GetSmallStringView()); + EXPECT_EQ(large_byte_string, GetMediumOrLargeCord()); + + // Medium <=> Medium + static constexpr absl::string_view kDifferentMediumStringView = + "A different string that is too large for the small string optimization!"; + ByteString other_medium_byte_string = + ByteString::Owned(GetAllocator(), kDifferentMediumStringView); + swap(medium_byte_string, other_medium_byte_string); + EXPECT_EQ(medium_byte_string, kDifferentMediumStringView); + EXPECT_EQ(other_medium_byte_string, GetMediumStringView()); + swap(medium_byte_string, other_medium_byte_string); + EXPECT_EQ(medium_byte_string, GetMediumStringView()); + EXPECT_EQ(other_medium_byte_string, kDifferentMediumStringView); + + // Medium <=> Large + swap(medium_byte_string, large_byte_string); + EXPECT_EQ(medium_byte_string, GetMediumOrLargeCord()); + EXPECT_EQ(large_byte_string, GetMediumStringView()); + swap(medium_byte_string, large_byte_string); + EXPECT_EQ(medium_byte_string, GetMediumStringView()); + EXPECT_EQ(large_byte_string, GetMediumOrLargeCord()); + + // Large <=> Large + const absl::Cord different_medium_or_large_cord = + absl::Cord(kDifferentMediumStringView); + ByteString other_large_byte_string = + ByteString::Owned(GetAllocator(), different_medium_or_large_cord); + swap(large_byte_string, other_large_byte_string); + EXPECT_EQ(large_byte_string, different_medium_or_large_cord); + EXPECT_EQ(other_large_byte_string, GetMediumStringView()); + swap(large_byte_string, other_large_byte_string); + EXPECT_EQ(large_byte_string, GetMediumStringView()); + EXPECT_EQ(other_large_byte_string, different_medium_or_large_cord); + + // Miscellaneous cases not covered above. These do not swap a second time to + // restore state, so they are destructive. + // Small <=> Different Allocator Medium + ByteString medium_new_delete_byte_string = + ByteString::Owned(NewDeleteAllocator(), kDifferentMediumStringView); + swap(empty_byte_string, medium_new_delete_byte_string); + EXPECT_EQ(empty_byte_string, kDifferentMediumStringView); + EXPECT_EQ(medium_new_delete_byte_string, ""); + // Small <=> Different Allocator Large + ByteString large_new_delete_byte_string = + ByteString::Owned(NewDeleteAllocator(), GetMediumOrLargeCord()); + swap(small_byte_string, large_new_delete_byte_string); + EXPECT_EQ(small_byte_string, GetMediumOrLargeCord()); + EXPECT_EQ(large_new_delete_byte_string, GetSmallStringView()); + // Medium <=> Different Allocator Large + large_new_delete_byte_string = + ByteString::Owned(NewDeleteAllocator(), different_medium_or_large_cord); + swap(medium_byte_string, large_new_delete_byte_string); + EXPECT_EQ(medium_byte_string, different_medium_or_large_cord); + EXPECT_EQ(large_new_delete_byte_string, GetMediumStringView()); + // Medium <=> Different Allocator Medium + medium_byte_string = ByteString::Owned(GetAllocator(), GetMediumStringView()); + medium_new_delete_byte_string = + ByteString::Owned(NewDeleteAllocator(), kDifferentMediumStringView); + swap(medium_byte_string, medium_new_delete_byte_string); + EXPECT_EQ(medium_byte_string, kDifferentMediumStringView); + EXPECT_EQ(medium_new_delete_byte_string, GetMediumStringView()); +} + +TEST_P(ByteStringTest, FlattenSmall) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.Flatten(), GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); +} + +TEST_P(ByteStringTest, FlattenMedium) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.Flatten(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); +} + +TEST_P(ByteStringTest, FlattenLarge) { + if (GetAllocator().arena() != nullptr) { + GTEST_SKIP(); + } + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); + EXPECT_EQ(byte_string.Flatten(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); +} + +TEST_P(ByteStringTest, TryFlatSmall) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_THAT(byte_string.TryFlat(), Optional(GetSmallStringView())); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); +} + +TEST_P(ByteStringTest, TryFlatMedium) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_THAT(byte_string.TryFlat(), Optional(GetMediumStringView())); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); +} + +TEST_P(ByteStringTest, TryFlatLarge) { + if (GetAllocator().arena() != nullptr) { + GTEST_SKIP(); + } + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeFragmentedCord()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); + EXPECT_THAT(byte_string.TryFlat(), Eq(absl::nullopt)); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kLarge); +} + +TEST_P(ByteStringTest, GetFlatSmall) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + std::string scratch; + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string.GetFlat(scratch), GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); +} + +TEST_P(ByteStringTest, GetFlatMedium) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + std::string scratch; + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + EXPECT_EQ(byte_string.GetFlat(scratch), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); +} + +TEST_P(ByteStringTest, GetFlatLarge) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + std::string scratch; + EXPECT_EQ(byte_string.GetFlat(scratch), GetMediumStringView()); +} + +TEST_P(ByteStringTest, GetFlatLargeFragmented) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeFragmentedCord()); + std::string scratch; + EXPECT_EQ(byte_string.GetFlat(scratch), GetMediumStringView()); +} + +TEST_P(ByteStringTest, Equals) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_TRUE(byte_string.Equals(GetMediumStringView())); +} + +TEST_P(ByteStringTest, Compare) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(byte_string.Compare(GetMediumStringView()), 0); + EXPECT_EQ(byte_string.Compare(GetMediumOrLargeCord()), 0); +} + +TEST_P(ByteStringTest, StartsWith) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_TRUE(byte_string.StartsWith( + GetMediumStringView().substr(0, kSmallByteStringCapacity))); + EXPECT_TRUE(byte_string.StartsWith( + GetMediumOrLargeCord().Subcord(0, kSmallByteStringCapacity))); +} + +TEST_P(ByteStringTest, EndsWith) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_TRUE(byte_string.EndsWith( + GetMediumStringView().substr(kSmallByteStringCapacity))); + EXPECT_TRUE(byte_string.EndsWith(GetMediumOrLargeCord().Subcord( + kSmallByteStringCapacity, + GetMediumOrLargeCord().size() - kSmallByteStringCapacity))); +} + +TEST_P(ByteStringTest, RemovePrefixSmall) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + byte_string.RemovePrefix(1); + EXPECT_EQ(byte_string, GetSmallStringView().substr(1)); +} + +TEST_P(ByteStringTest, RemovePrefixMedium) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + byte_string.RemovePrefix(byte_string.size() - kSmallByteStringCapacity); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string, + GetMediumStringView().substr(GetMediumStringView().size() - + kSmallByteStringCapacity)); +} + +TEST_P(ByteStringTest, RemovePrefixMediumOrLarge) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + byte_string.RemovePrefix(byte_string.size() - kSmallByteStringCapacity); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string, + GetMediumStringView().substr(GetMediumStringView().size() - + kSmallByteStringCapacity)); +} + +TEST_P(ByteStringTest, RemoveSuffixSmall) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + byte_string.RemoveSuffix(1); + EXPECT_EQ(byte_string, + GetSmallStringView().substr(0, GetSmallStringView().size() - 1)); +} + +TEST_P(ByteStringTest, RemoveSuffixMedium) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kMedium); + byte_string.RemoveSuffix(byte_string.size() - kSmallByteStringCapacity); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string, + GetMediumStringView().substr(0, kSmallByteStringCapacity)); +} + +TEST_P(ByteStringTest, RemoveSuffixMediumOrLarge) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + byte_string.RemoveSuffix(byte_string.size() - kSmallByteStringCapacity); + EXPECT_EQ(GetKind(byte_string), ByteStringKind::kSmall); + EXPECT_EQ(byte_string, + GetMediumStringView().substr(0, kSmallByteStringCapacity)); +} + +TEST_P(ByteStringTest, ToStringSmall) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.ToString(), byte_string); +} + +TEST_P(ByteStringTest, ToStringMedium) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(byte_string.ToString(), byte_string); +} + +TEST_P(ByteStringTest, ToStringLarge) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(byte_string.ToString(), byte_string); +} + +TEST_P(ByteStringTest, ToCordSmall) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + EXPECT_EQ(byte_string.ToCord(), byte_string); + EXPECT_EQ(std::move(byte_string).ToCord(), GetSmallStringView()); +} + +TEST_P(ByteStringTest, ToCordMedium) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + EXPECT_EQ(byte_string.ToCord(), byte_string); + EXPECT_EQ(std::move(byte_string).ToCord(), GetMediumStringView()); +} + +TEST_P(ByteStringTest, ToCordLarge) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + EXPECT_EQ(byte_string.ToCord(), byte_string); + EXPECT_EQ(std::move(byte_string).ToCord(), GetMediumOrLargeCord()); +} + +TEST_P(ByteStringTest, HashValue) { + EXPECT_EQ( + absl::HashOf(ByteString::Owned(GetAllocator(), GetSmallStringView())), + absl::HashOf(GetSmallStringView())); + EXPECT_EQ( + absl::HashOf(ByteString::Owned(GetAllocator(), GetMediumStringView())), + absl::HashOf(GetMediumStringView())); + EXPECT_EQ( + absl::HashOf(ByteString::Owned(GetAllocator(), GetMediumOrLargeCord())), + absl::HashOf(GetMediumOrLargeCord())); +} + +INSTANTIATE_TEST_SUITE_P( + ByteStringTest, ByteStringTest, + ::testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting)); + +class ByteStringViewTest : public TestWithParam, + public ByteStringViewTestFriend { + public: + Allocator<> GetAllocator() { + switch (GetParam()) { + case MemoryManagement::kPooling: + return ArenaAllocator(&arena_); + case MemoryManagement::kReferenceCounting: + return NewDeleteAllocator(); + } + } + + private: + google::protobuf::Arena arena_; +}; + +TEST_P(ByteStringViewTest, Default) { + ByteStringView byte_String_view; + EXPECT_THAT(byte_String_view, SizeIs(0)); + EXPECT_THAT(byte_String_view, IsEmpty()); + EXPECT_EQ(GetKind(byte_String_view), ByteStringViewKind::kString); +} + +TEST_P(ByteStringViewTest, String) { + ByteStringView byte_string_view(GetSmallStringView()); + EXPECT_THAT(byte_string_view, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string_view, Not(IsEmpty())); + EXPECT_EQ(byte_string_view, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string_view), ByteStringViewKind::kString); + EXPECT_EQ(byte_string_view.GetArena(), nullptr); +} + +TEST_P(ByteStringViewTest, Cord) { + ByteStringView byte_string_view(GetMediumOrLargeCord()); + EXPECT_THAT(byte_string_view, SizeIs(GetMediumOrLargeCord().size())); + EXPECT_THAT(byte_string_view, Not(IsEmpty())); + EXPECT_EQ(byte_string_view, GetMediumOrLargeCord()); + EXPECT_EQ(GetKind(byte_string_view), ByteStringViewKind::kCord); + EXPECT_EQ(byte_string_view.GetArena(), nullptr); +} + +TEST_P(ByteStringViewTest, ByteStringSmall) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + ByteStringView byte_string_view(byte_string); + EXPECT_THAT(byte_string_view, SizeIs(GetSmallStringView().size())); + EXPECT_THAT(byte_string_view, Not(IsEmpty())); + EXPECT_EQ(byte_string_view, GetSmallStringView()); + EXPECT_EQ(GetKind(byte_string_view), ByteStringViewKind::kString); + EXPECT_EQ(byte_string_view.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringViewTest, ByteStringMedium) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumStringView()); + ByteStringView byte_string_view(byte_string); + EXPECT_THAT(byte_string_view, SizeIs(GetMediumStringView().size())); + EXPECT_THAT(byte_string_view, Not(IsEmpty())); + EXPECT_EQ(byte_string_view, GetMediumStringView()); + EXPECT_EQ(GetKind(byte_string_view), ByteStringViewKind::kString); + EXPECT_EQ(byte_string_view.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringViewTest, ByteStringLarge) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + ByteStringView byte_string_view(byte_string); + EXPECT_THAT(byte_string_view, SizeIs(GetMediumOrLargeCord().size())); + EXPECT_THAT(byte_string_view, Not(IsEmpty())); + EXPECT_EQ(byte_string_view, GetMediumOrLargeCord()); + EXPECT_EQ(byte_string_view.ToCord(), byte_string_view); + if (GetAllocator().arena() == nullptr) { + EXPECT_EQ(GetKind(byte_string_view), ByteStringViewKind::kCord); + } else { + EXPECT_EQ(GetKind(byte_string_view), ByteStringViewKind::kString); + } + EXPECT_EQ(byte_string_view.GetArena(), GetAllocator().arena()); +} + +TEST_P(ByteStringViewTest, TryFlatString) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + ByteStringView byte_string_view(byte_string); + EXPECT_THAT(byte_string_view.TryFlat(), Optional(GetSmallStringView())); +} + +TEST_P(ByteStringViewTest, TryFlatCord) { + if (GetAllocator().arena() != nullptr) { + GTEST_SKIP(); + } + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeFragmentedCord()); + ByteStringView byte_string_view(byte_string); + EXPECT_THAT(byte_string_view.TryFlat(), Eq(absl::nullopt)); +} + +TEST_P(ByteStringViewTest, GetFlatString) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetSmallStringView()); + ByteStringView byte_string_view(byte_string); + std::string scratch; + EXPECT_EQ(byte_string_view.GetFlat(scratch), GetSmallStringView()); +} + +TEST_P(ByteStringViewTest, GetFlatCord) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeCord()); + ByteStringView byte_string_view(byte_string); + std::string scratch; + EXPECT_EQ(byte_string_view.GetFlat(scratch), GetMediumStringView()); +} + +TEST_P(ByteStringViewTest, GetFlatLargeFragmented) { + ByteString byte_string = + ByteString::Owned(GetAllocator(), GetMediumOrLargeFragmentedCord()); + ByteStringView byte_string_view(byte_string); + std::string scratch; + EXPECT_EQ(byte_string_view.GetFlat(scratch), GetMediumStringView()); +} + +TEST_P(ByteStringViewTest, RemovePrefixString) { + ByteStringView byte_string_view(GetSmallStringView()); + byte_string_view.RemovePrefix(1); + EXPECT_EQ(byte_string_view, GetSmallStringView().substr(1)); +} + +TEST_P(ByteStringViewTest, RemovePrefixCord) { + ByteStringView byte_string_view(GetMediumOrLargeCord()); + byte_string_view.RemovePrefix(1); + EXPECT_EQ(byte_string_view, GetMediumOrLargeCord().Subcord( + 1, GetMediumOrLargeCord().size() - 1)); +} + +TEST_P(ByteStringViewTest, RemoveSuffixString) { + ByteStringView byte_string_view(GetSmallStringView()); + byte_string_view.RemoveSuffix(1); + EXPECT_EQ(byte_string_view, + GetSmallStringView().substr(0, GetSmallStringView().size() - 1)); +} + +TEST_P(ByteStringViewTest, RemoveSuffixCord) { + ByteStringView byte_string_view(GetMediumOrLargeCord()); + byte_string_view.RemoveSuffix(1); + EXPECT_EQ(byte_string_view, GetMediumOrLargeCord().Subcord( + 0, GetMediumOrLargeCord().size() - 1)); +} + +TEST_P(ByteStringViewTest, ToStringString) { + ByteStringView byte_string_view(GetSmallStringView()); + EXPECT_EQ(byte_string_view.ToString(), byte_string_view); +} + +TEST_P(ByteStringViewTest, ToStringCord) { + ByteStringView byte_string_view(GetMediumOrLargeCord()); + EXPECT_EQ(byte_string_view.ToString(), byte_string_view); +} + +TEST_P(ByteStringViewTest, ToCordString) { + ByteString byte_string(GetAllocator(), GetMediumStringView()); + ByteStringView byte_string_view(byte_string); + EXPECT_EQ(byte_string_view.ToCord(), byte_string_view); +} + +TEST_P(ByteStringViewTest, ToCordCord) { + ByteStringView byte_string_view(GetMediumOrLargeCord()); + EXPECT_EQ(byte_string_view.ToCord(), byte_string_view); +} + +TEST_P(ByteStringViewTest, HashValue) { + EXPECT_EQ(absl::HashOf(ByteStringView(GetSmallStringView())), + absl::HashOf(GetSmallStringView())); + EXPECT_EQ(absl::HashOf(ByteStringView(GetMediumStringView())), + absl::HashOf(GetMediumStringView())); + EXPECT_EQ(absl::HashOf(ByteStringView(GetMediumOrLargeCord())), + absl::HashOf(GetMediumOrLargeCord())); +} + +INSTANTIATE_TEST_SUITE_P( + ByteStringViewTest, ByteStringViewTest, + ::testing::Values(MemoryManagement::kPooling, + MemoryManagement::kReferenceCounting)); + +} // namespace +} // namespace cel::common_internal diff --git a/common/memory.h b/common/memory.h index 0288d9049..7fcc6f449 100644 --- a/common/memory.h +++ b/common/memory.h @@ -76,6 +76,7 @@ struct PoolingMemoryManagerVirtualTable; class PoolingMemoryManagerVirtualDispatcher; namespace common_internal { +absl::Nullable OwnerRelease(Owner& owner) noexcept; template T* GetPointer(const Shared& shared); template @@ -187,6 +188,8 @@ class ABSL_ATTRIBUTE_TRIVIAL_ABI [[nodiscard]] Owner final { private: friend class Borrower; + friend absl::Nullable + common_internal::OwnerRelease(Owner& owner) noexcept; constexpr explicit Owner(uintptr_t ptr) noexcept : ptr_(ptr) {} @@ -246,6 +249,19 @@ inline bool operator!=(const Owner& lhs, const Owner& rhs) noexcept { return !operator==(lhs, rhs); } +namespace common_internal { + +inline absl::Nullable OwnerRelease( + Owner& owner) noexcept { + uintptr_t ptr = std::exchange(owner.ptr_, uintptr_t{0}); + if (Owner::IsReferenceCount(ptr)) { + return Owner::AsReferenceCount(ptr); + } + return nullptr; +} + +} // namespace common_internal + // `Borrower` represents a reference to some borrowed data, where the data has // at least one owner. When using reference counting, `Borrower` does not // participate in incrementing/decrementing the reference count. Thus `Borrower`