Skip to content

Commit

Permalink
Avoid copies in function adapter code and CEL standard implementation…
Browse files Browse the repository at this point in the history
…s for string value args.

PiperOrigin-RevId: 609798481
  • Loading branch information
jnthntatum authored and copybara-github committed Feb 26, 2024
1 parent 4ac6b7c commit 37fc0f9
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 56 deletions.
34 changes: 33 additions & 1 deletion base/function_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
#include <vector>

#include "absl/functional/bind_front.h"
#include "absl/log/die_if_null.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
Expand All @@ -50,6 +49,39 @@ struct AdaptedTypeTraits {

// Specialization for cref parameters without forcing a temporary copy of the
// underlying handle argument.
template <>
struct AdaptedTypeTraits<const Value&> {
using AssignableType = const Value*;

static std::reference_wrapper<const Value> ToArg(AssignableType v) {
return *v;
}
};

template <>
struct AdaptedTypeTraits<const StringValue&> {
using AssignableType = const StringValue*;

static std::reference_wrapper<const StringValue> ToArg(AssignableType v) {
return *v;
}
};

template <>
struct AdaptedTypeTraits<const BytesValue&> {
using AssignableType = const BytesValue*;

static std::reference_wrapper<const BytesValue> ToArg(AssignableType v) {
return *v;
}
};

// Partial specialization for other cases.
//
// These types aren't referenceable since they aren't actually
// represented as alternatives in the underlying variant.
//
// This still requires an implicit copy and corresponding ref-count increase.
template <typename T>
struct AdaptedTypeTraits<const T&> {
using AssignableType = T;
Expand Down
2 changes: 1 addition & 1 deletion base/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ cc_test(
srcs = ["function_adapter_test.cc"],
deps = [
":function_adapter",
"//base:data",
"//base:kind",
"//common:casting",
"//common:memory",
"//common:type",
"//common:value",
Expand Down
20 changes: 18 additions & 2 deletions base/internal/function_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "base/kind.h"
#include "common/casting.h"
#include "common/value.h"
#include "common/value_factory.h"
#include "common/value_manager.h"
#include "internal/status_macros.h"

Expand Down Expand Up @@ -173,6 +174,19 @@ struct HandleToAdaptedVisitor {
return absl::OkStatus();
}

template <typename T>
absl::Status operator()(T** out) {
if (!InstanceOf<std::remove_const_t<T>>(input)) {
return absl::InvalidArgumentError(
absl::StrCat("expected ", ValueKindToString(T::kKind), " value"));
}
static_assert(std::is_lvalue_reference_v<
decltype(Cast<std::remove_const_t<T>>(input))>,
"expected l-value reference return type for Cast.");
*out = &Cast<std::remove_const_t<T>>(input);
return absl::OkStatus();
}

const Value& input;
};

Expand Down Expand Up @@ -210,8 +224,10 @@ struct AdaptedToHandleVisitor {
// present, otherwise return the status.
template <typename T>
absl::StatusOr<Value> operator()(absl::StatusOr<T> wrapped) {
CEL_ASSIGN_OR_RETURN(auto value, wrapped);
return this->operator()(std::move(value));
if (!wrapped.ok()) {
return std::move(wrapped).status();
}
return this->operator()(std::move(wrapped).value());
}

cel::ValueFactory& value_factory;
Expand Down
43 changes: 21 additions & 22 deletions base/internal/function_adapter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
#include "base/internal/function_adapter.h"

#include <cstdint>
#include <string>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/time/time.h"
#include "base/kind.h"
#include "base/type_provider.h"
#include "common/casting.h"
#include "common/memory.h"
#include "common/type_factory.h"
#include "common/type_manager.h"
Expand Down Expand Up @@ -244,8 +243,8 @@ TEST_F(AdaptedToHandleVisitorTest, Int) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<IntValue>());
EXPECT_EQ(result.As<IntValue>().NativeValue(), 10);
ASSERT_TRUE(InstanceOf<IntValue>(result));
EXPECT_EQ(Cast<IntValue>(result).NativeValue(), 10);
}

TEST_F(AdaptedToHandleVisitorTest, Double) {
Expand All @@ -254,8 +253,8 @@ TEST_F(AdaptedToHandleVisitorTest, Double) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<DoubleValue>());
EXPECT_EQ(result.As<DoubleValue>().NativeValue(), 10.0);
ASSERT_TRUE(InstanceOf<DoubleValue>(result));
EXPECT_EQ(Cast<DoubleValue>(result).NativeValue(), 10.0);
}

TEST_F(AdaptedToHandleVisitorTest, Uint) {
Expand All @@ -264,8 +263,8 @@ TEST_F(AdaptedToHandleVisitorTest, Uint) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<UintValue>());
EXPECT_EQ(result.As<UintValue>().NativeValue(), 10);
ASSERT_TRUE(InstanceOf<UintValue>(result));
EXPECT_EQ(Cast<UintValue>(result).NativeValue(), 10);
}

TEST_F(AdaptedToHandleVisitorTest, Bool) {
Expand All @@ -274,8 +273,8 @@ TEST_F(AdaptedToHandleVisitorTest, Bool) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<BoolValue>());
EXPECT_EQ(result.As<BoolValue>().NativeValue(), true);
ASSERT_TRUE(InstanceOf<BoolValue>(result));
EXPECT_EQ(Cast<BoolValue>(result).NativeValue(), true);
}

TEST_F(AdaptedToHandleVisitorTest, Timestamp) {
Expand All @@ -284,8 +283,8 @@ TEST_F(AdaptedToHandleVisitorTest, Timestamp) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<TimestampValue>());
EXPECT_EQ(result.As<TimestampValue>().NativeValue(),
ASSERT_TRUE(InstanceOf<TimestampValue>(result));
EXPECT_EQ(Cast<TimestampValue>(result).NativeValue(),
absl::UnixEpoch() + absl::Seconds(10));
}

Expand All @@ -295,8 +294,8 @@ TEST_F(AdaptedToHandleVisitorTest, Duration) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<DurationValue>());
EXPECT_EQ(result.As<DurationValue>().NativeValue(), absl::Seconds(5));
ASSERT_TRUE(InstanceOf<DurationValue>(result));
EXPECT_EQ(Cast<DurationValue>(result).NativeValue(), absl::Seconds(5));
}

TEST_F(AdaptedToHandleVisitorTest, String) {
Expand All @@ -306,8 +305,8 @@ TEST_F(AdaptedToHandleVisitorTest, String) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<StringValue>());
EXPECT_EQ(result.As<StringValue>().ToString(), "str");
ASSERT_TRUE(InstanceOf<StringValue>(result));
EXPECT_EQ(Cast<StringValue>(result).ToString(), "str");
}

TEST_F(AdaptedToHandleVisitorTest, Bytes) {
Expand All @@ -317,8 +316,8 @@ TEST_F(AdaptedToHandleVisitorTest, Bytes) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<BytesValue>());
EXPECT_EQ(result.As<BytesValue>().ToString(), "bytes");
ASSERT_TRUE(InstanceOf<BytesValue>(result));
EXPECT_EQ(Cast<BytesValue>(result).ToString(), "bytes");
}

TEST_F(AdaptedToHandleVisitorTest, StatusOrValue) {
Expand All @@ -327,8 +326,8 @@ TEST_F(AdaptedToHandleVisitorTest, StatusOrValue) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(value));

ASSERT_TRUE(result->Is<IntValue>());
EXPECT_EQ(result.As<IntValue>().NativeValue(), 10);
ASSERT_TRUE(InstanceOf<IntValue>(result));
EXPECT_EQ(Cast<IntValue>(result).NativeValue(), 10);
}

TEST_F(AdaptedToHandleVisitorTest, StatusOrError) {
Expand All @@ -345,8 +344,8 @@ TEST_F(AdaptedToHandleVisitorTest, Any) {
ASSERT_OK_AND_ASSIGN(auto result,
AdaptedToHandleVisitor{value_factory()}(handle));

ASSERT_TRUE(result->Is<ErrorValue>());
EXPECT_THAT(result.As<ErrorValue>().NativeValue(),
ASSERT_TRUE(InstanceOf<ErrorValue>(result));
EXPECT_THAT(Cast<ErrorValue>(result).NativeValue(),
StatusIs(absl::StatusCode::kInternal, "test_error"));
}

Expand Down
2 changes: 2 additions & 0 deletions runtime/standard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ cc_library(
"//base:builtins",
"//base:function_adapter",
"//base:kind",
"//common:casting",
"//common:type",
"//common:value",
"//internal:number",
Expand Down Expand Up @@ -153,6 +154,7 @@ cc_library(
deps = [
"//base:builtins",
"//base:function_adapter",
"//common:casting",
"//common:value",
"//internal:status_macros",
"//runtime:function_registry",
Expand Down
63 changes: 35 additions & 28 deletions runtime/standard/equality_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <cstdint>
#include <functional>
#include <optional>
#include <type_traits>
#include <utility>

#include "absl/base/macros.h"
Expand All @@ -29,6 +30,7 @@
#include "base/builtins.h"
#include "base/function_adapter.h"
#include "base/kind.h"
#include "common/casting.h"
#include "common/type.h"
#include "common/value.h"
#include "common/value_manager.h"
Expand All @@ -42,6 +44,8 @@
namespace cel {
namespace {

using ::cel::Cast;
using ::cel::InstanceOf;
using ::cel::builtin::kEqual;
using ::cel::builtin::kInequal;
using ::cel::internal::Number;
Expand Down Expand Up @@ -162,7 +166,7 @@ absl::StatusOr<absl::optional<Value>> CheckAlternativeNumericType(
return absl::nullopt;
}

if (!key->Is<IntValue>() && number->LosslessConvertibleToInt()) {
if (!InstanceOf<IntValue>(key) && number->LosslessConvertibleToInt()) {
Value entry;
bool ok;
CEL_ASSIGN_OR_RETURN(
Expand All @@ -173,7 +177,7 @@ absl::StatusOr<absl::optional<Value>> CheckAlternativeNumericType(
}
}

if (!key->Is<UintValue>() && number->LosslessConvertibleToUint()) {
if (!InstanceOf<UintValue>(key) && number->LosslessConvertibleToUint()) {
Value entry;
bool ok;
CEL_ASSIGN_OR_RETURN(std::tie(entry, ok),
Expand Down Expand Up @@ -405,41 +409,44 @@ absl::StatusOr<absl::optional<bool>> HomogenousValueEqual(ValueManager& factory,
return absl::nullopt;
}

static_assert(std::is_lvalue_reference_v<decltype(Cast<StringValue>(v1))>,
"unexpected value copy");

switch (v1->kind()) {
case ValueKind::kBool:
return Equal<bool>(v1->As<BoolValue>().NativeValue(),
v2->As<BoolValue>().NativeValue());
return Equal<bool>(Cast<BoolValue>(v1).NativeValue(),
Cast<BoolValue>(v2).NativeValue());
case ValueKind::kNull:
return Equal<const NullValue&>(v1->As<NullValue>(), v2->As<NullValue>());
return Equal<const NullValue&>(Cast<NullValue>(v1), Cast<NullValue>(v2));
case ValueKind::kInt:
return Equal<int64_t>(v1->As<IntValue>().NativeValue(),
v2->As<IntValue>().NativeValue());
return Equal<int64_t>(Cast<IntValue>(v1).NativeValue(),
Cast<IntValue>(v2).NativeValue());
case ValueKind::kUint:
return Equal<uint64_t>(v1->As<UintValue>().NativeValue(),
v2->As<UintValue>().NativeValue());
return Equal<uint64_t>(Cast<UintValue>(v1).NativeValue(),
Cast<UintValue>(v2).NativeValue());
case ValueKind::kDouble:
return Equal<double>(v1->As<DoubleValue>().NativeValue(),
v2->As<DoubleValue>().NativeValue());
return Equal<double>(Cast<DoubleValue>(v1).NativeValue(),
Cast<DoubleValue>(v2).NativeValue());
case ValueKind::kDuration:
return Equal<absl::Duration>(v1->As<DurationValue>().NativeValue(),
v2->As<DurationValue>().NativeValue());
return Equal<absl::Duration>(Cast<DurationValue>(v1).NativeValue(),
Cast<DurationValue>(v2).NativeValue());
case ValueKind::kTimestamp:
return Equal<absl::Time>(v1->As<TimestampValue>().NativeValue(),
v2->As<TimestampValue>().NativeValue());
return Equal<absl::Time>(Cast<TimestampValue>(v1).NativeValue(),
Cast<TimestampValue>(v2).NativeValue());
case ValueKind::kCelType:
return Equal<const TypeValue&>(v1->As<TypeValue>(), v2->As<TypeValue>());
return Equal<const TypeValue&>(Cast<TypeValue>(v1), Cast<TypeValue>(v2));
case ValueKind::kString:
return Equal<const StringValue&>(v1->As<StringValue>(),
v2->As<StringValue>());
return Equal<const StringValue&>(Cast<StringValue>(v1),
Cast<StringValue>(v2));
case ValueKind::kBytes:
return Equal<const cel::BytesValue&>(v1->As<cel::BytesValue>(),
v2->As<cel::BytesValue>());
case ValueKind::kList:
return ListEqual<EqualsProvider>(factory, v1->As<ListValue>(),
v2->As<ListValue>());
return ListEqual<EqualsProvider>(factory, Cast<ListValue>(v1),
Cast<ListValue>(v2));
case ValueKind::kMap:
return MapEqual<EqualsProvider>(factory, v1->As<MapValue>(),
v2->As<MapValue>());
return MapEqual<EqualsProvider>(factory, Cast<MapValue>(v1),
Cast<MapValue>(v2));
default:

return absl::nullopt;
Expand Down Expand Up @@ -499,11 +506,11 @@ absl::StatusOr<absl::optional<bool>> ValueEqualImpl(ValueManager& value_factory,
const Value& v1,
const Value& v2) {
if (v1->kind() == v2->kind()) {
if (v1->Is<StructValue>() && v2->Is<StructValue>()) {
if (InstanceOf<StructValue>(v1) && InstanceOf<StructValue>(v2)) {
CEL_ASSIGN_OR_RETURN(Value result,
v1->As<StructValue>().Equal(value_factory, v2));
if (result->Is<BoolValue>()) {
return result->As<BoolValue>().NativeValue();
Cast<StructValue>(v1).Equal(value_factory, v2));
if (InstanceOf<BoolValue>(result)) {
return Cast<BoolValue>(result).NativeValue();
}
return false;
}
Expand All @@ -521,8 +528,8 @@ absl::StatusOr<absl::optional<bool>> ValueEqualImpl(ValueManager& value_factory,
// TODO(uncreated-issue/6): It's currently possible for the interpreter to create a
// map containing an Error. Return no matching overload to propagate an error
// instead of a false result.
if (v1->Is<ErrorValue>() || v1->Is<UnknownValue>() || v2->Is<ErrorValue>() ||
v2->Is<UnknownValue>()) {
if (InstanceOf<ErrorValue>(v1) || InstanceOf<UnknownValue>(v1) ||
InstanceOf<ErrorValue>(v2) || InstanceOf<UnknownValue>(v2)) {
return absl::nullopt;
}

Expand Down
Loading

0 comments on commit 37fc0f9

Please # to comment.