Skip to content

Commit

Permalink
Move arithmetic operators to runtime/standard.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 549769613
  • Loading branch information
jnthntatum authored and copybara-github committed Jul 20, 2023
1 parent dfc8d87 commit d0131ea
Show file tree
Hide file tree
Showing 6 changed files with 400 additions and 212 deletions.
1 change: 1 addition & 0 deletions eval/public/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ cc_library(
"//internal:utf8",
"//runtime:function_registry",
"//runtime:runtime_options",
"//runtime/standard:arithmetic_functions",
"//runtime/standard:comparison_functions",
"//runtime/standard:container_functions",
"//runtime/standard:logical_functions",
Expand Down
215 changes: 3 additions & 212 deletions eval/public/builtin_func_registrar.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include <array>
#include <cstdint>
#include <functional>
#include <limits>
#include <string>

#include "absl/status/status.h"
Expand Down Expand Up @@ -49,6 +48,7 @@
#include "internal/utf8.h"
#include "runtime/function_registry.h"
#include "runtime/runtime_options.h"
#include "runtime/standard/arithmetic_functions.h"
#include "runtime/standard/comparison_functions.h"
#include "runtime/standard/container_functions.h"
#include "runtime/standard/logical_functions.h"
Expand All @@ -68,216 +68,6 @@ using ::cel::Value;
using ::cel::ValueFactory;
using ::google::protobuf::Arena;

// Template functions providing arithmetic operations
template <class Type>
Handle<Value> Add(ValueFactory&, Type v0, Type v1);

template <>
Handle<Value> Add<int64_t>(ValueFactory& value_factory, int64_t v0,
int64_t v1) {
auto sum = cel::internal::CheckedAdd(v0, v1);
if (!sum.ok()) {
return value_factory.CreateErrorValue(sum.status());
}
return value_factory.CreateIntValue(*sum);
}

template <>
Handle<Value> Add<uint64_t>(ValueFactory& value_factory, uint64_t v0,
uint64_t v1) {
auto sum = cel::internal::CheckedAdd(v0, v1);
if (!sum.ok()) {
return value_factory.CreateErrorValue(sum.status());
}
return value_factory.CreateUintValue(*sum);
}

template <>
Handle<Value> Add<double>(ValueFactory& value_factory, double v0, double v1) {
return value_factory.CreateDoubleValue(v0 + v1);
}

template <class Type>
Handle<Value> Sub(ValueFactory&, Type v0, Type v1);

template <>
Handle<Value> Sub<int64_t>(ValueFactory& value_factory, int64_t v0,
int64_t v1) {
auto diff = cel::internal::CheckedSub(v0, v1);
if (!diff.ok()) {
return value_factory.CreateErrorValue(diff.status());
}
return value_factory.CreateIntValue(*diff);
}

template <>
Handle<Value> Sub<uint64_t>(ValueFactory& value_factory, uint64_t v0,
uint64_t v1) {
auto diff = cel::internal::CheckedSub(v0, v1);
if (!diff.ok()) {
return value_factory.CreateErrorValue(diff.status());
}
return value_factory.CreateUintValue(*diff);
}

template <>
Handle<Value> Sub<double>(ValueFactory& value_factory, double v0, double v1) {
return value_factory.CreateDoubleValue(v0 - v1);
}

template <class Type>
Handle<Value> Mul(ValueFactory&, Type v0, Type v1);

template <>
Handle<Value> Mul<int64_t>(ValueFactory& value_factory, int64_t v0,
int64_t v1) {
auto prod = cel::internal::CheckedMul(v0, v1);
if (!prod.ok()) {
return value_factory.CreateErrorValue(prod.status());
}
return value_factory.CreateIntValue(*prod);
}

template <>
Handle<Value> Mul<uint64_t>(ValueFactory& value_factory, uint64_t v0,
uint64_t v1) {
auto prod = cel::internal::CheckedMul(v0, v1);
if (!prod.ok()) {
return value_factory.CreateErrorValue(prod.status());
}
return value_factory.CreateUintValue(*prod);
}

template <>
Handle<Value> Mul<double>(ValueFactory& value_factory, double v0, double v1) {
return value_factory.CreateDoubleValue(v0 * v1);
}

template <class Type>
Handle<Value> Div(ValueFactory&, Type v0, Type v1);

// Division operations for integer types should check for
// division by 0
template <>
Handle<Value> Div<int64_t>(ValueFactory& value_factory, int64_t v0,
int64_t v1) {
auto quot = cel::internal::CheckedDiv(v0, v1);
if (!quot.ok()) {
return value_factory.CreateErrorValue(quot.status());
}
return value_factory.CreateIntValue(*quot);
}

// Division operations for integer types should check for
// division by 0
template <>
Handle<Value> Div<uint64_t>(ValueFactory& value_factory, uint64_t v0,
uint64_t v1) {
auto quot = cel::internal::CheckedDiv(v0, v1);
if (!quot.ok()) {
return value_factory.CreateErrorValue(quot.status());
}
return value_factory.CreateUintValue(*quot);
}

template <>
Handle<Value> Div<double>(ValueFactory& value_factory, double v0, double v1) {
static_assert(std::numeric_limits<double>::is_iec559,
"Division by zero for doubles must be supported");

// For double, division will result in +/- inf
return value_factory.CreateDoubleValue(v0 / v1);
}

// Modulo operation
template <class Type>
Handle<Value> Modulo(ValueFactory& value_factory, Type v0, Type v1);

// Modulo operations for integer types should check for
// division by 0
template <>
Handle<Value> Modulo<int64_t>(ValueFactory& value_factory, int64_t v0,
int64_t v1) {
auto mod = cel::internal::CheckedMod(v0, v1);
if (!mod.ok()) {
return value_factory.CreateErrorValue(mod.status());
}
return value_factory.CreateIntValue(*mod);
}

template <>
Handle<Value> Modulo<uint64_t>(ValueFactory& value_factory, uint64_t v0,
uint64_t v1) {
auto mod = cel::internal::CheckedMod(v0, v1);
if (!mod.ok()) {
return value_factory.CreateErrorValue(mod.status());
}
return value_factory.CreateUintValue(*mod);
}

// Helper method
// Registers all arithmetic functions for template parameter type.
template <class Type>
absl::Status RegisterArithmeticFunctionsForType(CelFunctionRegistry* registry) {
using FunctionAdapter = cel::BinaryFunctionAdapter<Handle<Value>, Type, Type>;
CEL_RETURN_IF_ERROR(registry->Register(
FunctionAdapter::CreateDescriptor(cel::builtin::kAdd, false),
FunctionAdapter::WrapFunction(&Add<Type>)));

CEL_RETURN_IF_ERROR(registry->Register(
FunctionAdapter::CreateDescriptor(cel::builtin::kSubtract, false),
FunctionAdapter::WrapFunction(&Sub<Type>)));

CEL_RETURN_IF_ERROR(registry->Register(
FunctionAdapter::CreateDescriptor(cel::builtin::kMultiply, false),
FunctionAdapter::WrapFunction(&Mul<Type>)));

return registry->Register(
FunctionAdapter::CreateDescriptor(cel::builtin::kDivide, false),
FunctionAdapter::WrapFunction(&Div<Type>));
}

// Register basic Arithmetic functions for numeric types.
absl::Status RegisterNumericArithmeticFunctions(
CelFunctionRegistry* registry, const InterpreterOptions& options) {
CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType<int64_t>(registry));
CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType<uint64_t>(registry));
CEL_RETURN_IF_ERROR(RegisterArithmeticFunctionsForType<double>(registry));

// Modulo
CEL_RETURN_IF_ERROR(registry->Register(
BinaryFunctionAdapter<Handle<Value>, int64_t, int64_t>::CreateDescriptor(
cel::builtin::kModulo, false),
BinaryFunctionAdapter<Handle<Value>, int64_t, int64_t>::WrapFunction(
&Modulo<int64_t>)));

CEL_RETURN_IF_ERROR(registry->Register(
BinaryFunctionAdapter<Handle<Value>, uint64_t,
uint64_t>::CreateDescriptor(cel::builtin::kModulo,
false),
BinaryFunctionAdapter<Handle<Value>, uint64_t, uint64_t>::WrapFunction(
&Modulo<uint64_t>)));

// Negation group
CEL_RETURN_IF_ERROR(registry->Register(
UnaryFunctionAdapter<Handle<Value>, int64_t>::CreateDescriptor(
cel::builtin::kNeg, false),
UnaryFunctionAdapter<Handle<Value>, int64_t>::WrapFunction(
[](ValueFactory& value_factory, int64_t value) -> Handle<Value> {
auto inv = cel::internal::CheckedNegation(value);
if (!inv.ok()) {
return value_factory.CreateErrorValue(inv.status());
}
return value_factory.CreateIntValue(*inv);
})));

return registry->Register(
UnaryFunctionAdapter<double, double>::CreateDescriptor(cel::builtin::kNeg,
false),
UnaryFunctionAdapter<double, double>::WrapFunction(
[](ValueFactory&, double value) -> double { return -value; }));
}

template <class T>
bool ValueEquals(const CelValue& value, T other);

Expand Down Expand Up @@ -1197,11 +987,12 @@ absl::Status RegisterBuiltinFunctions(CelFunctionRegistry* registry,
cel::RegisterContainerFunctions(modern_registry, runtime_options));
CEL_RETURN_IF_ERROR(
cel::RegisterTypeConversionFunctions(modern_registry, runtime_options));
CEL_RETURN_IF_ERROR(
cel::RegisterArithmeticFunctions(modern_registry, runtime_options));

return registry->RegisterAll(
{
&RegisterEqualityFunctions,
&RegisterNumericArithmeticFunctions,
&RegisterTimeFunctions,
&RegisterStringFunctions,
&RegisterRegexFunctions,
Expand Down
32 changes: 32 additions & 0 deletions runtime/standard/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -167,3 +167,35 @@ cc_test(
"//internal:testing",
],
)

cc_library(
name = "arithmetic_functions",
srcs = ["arithmetic_functions.cc"],
hdrs = ["arithmetic_functions.h"],
deps = [
"//base:builtins",
"//base:data",
"//base:function_adapter",
"//base:handle",
"//internal:overflow",
"//internal:status_macros",
"//runtime:function_registry",
"//runtime:runtime_options",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
],
)

cc_test(
name = "arithmetic_functions_test",
size = "small",
srcs = [
"arithmetic_functions_test.cc",
],
deps = [
":arithmetic_functions",
"//base:builtins",
"//base:function_descriptor",
"//internal:testing",
],
)
Loading

0 comments on commit d0131ea

Please # to comment.