Skip to content

Commit

Permalink
Implement partial support for setting singular google.protobuf.Any
Browse files Browse the repository at this point in the history
…fields in `ProtoStructValueBuilder`

PiperOrigin-RevId: 540957263
  • Loading branch information
jcking authored and copybara-github committed Jun 16, 2023
1 parent 0d23085 commit 98fbc3b
Show file tree
Hide file tree
Showing 9 changed files with 436 additions and 11 deletions.
2 changes: 2 additions & 0 deletions extensions/protobuf/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,13 @@ cc_library(
"//eval/internal:interop",
"//eval/public:message_wrapper",
"//eval/public/structs:proto_message_type_adapter",
"//extensions/protobuf/internal:any",
"//extensions/protobuf/internal:map_reflection",
"//extensions/protobuf/internal:reflection",
"//extensions/protobuf/internal:time",
"//extensions/protobuf/internal:wrappers",
"//internal:casts",
"//internal:proto_time_encoding",
"//internal:rtti",
"//internal:status_macros",
"@com_google_absl//absl/base",
Expand Down
28 changes: 28 additions & 0 deletions extensions/protobuf/internal/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,31 @@ cc_test(
"@com_google_protobuf//:protobuf",
],
)

cc_library(
name = "any",
srcs = ["any.cc"],
hdrs = ["any.h"],
deps = [
"//internal:casts",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_protobuf//:protobuf",
],
)

cc_test(
name = "any_test",
srcs = ["any_test.cc"],
deps = [
":any",
"//internal:testing",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings:cord",
"@com_google_protobuf//:protobuf",
],
)
76 changes: 76 additions & 0 deletions extensions/protobuf/internal/any.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "extensions/protobuf/internal/any.h"

#include <string>

#include "google/protobuf/any.pb.h"
#include "absl/base/optimization.h"
#include "absl/log/absl_check.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "internal/casts.h"

namespace cel::extensions::protobuf_internal {

absl::Status SetAny(google::protobuf::Message& message, absl::string_view type_url,
const absl::Cord& value) {
ABSL_DCHECK_EQ(message.GetTypeName(), "google.protobuf.Any");
const auto* desc = message.GetDescriptor();
if (ABSL_PREDICT_FALSE(desc == nullptr)) {
return absl::InternalError(
absl::StrCat(message.GetTypeName(), " missing descriptor"));
}
if (ABSL_PREDICT_TRUE(desc == google::protobuf::Any::descriptor())) {
auto& any = cel::internal::down_cast<google::protobuf::Any&>(message);
any.set_type_url(type_url);
any.set_value(std::string(value));
return absl::OkStatus();
}
const auto* reflect = message.GetReflection();
if (ABSL_PREDICT_FALSE(reflect == nullptr)) {
return absl::InternalError(
absl::StrCat(message.GetTypeName(), " missing reflection"));
}
const auto* type_url_field =
desc->FindFieldByNumber(google::protobuf::Any::kTypeUrlFieldNumber);
if (ABSL_PREDICT_FALSE(type_url_field == nullptr)) {
return absl::InternalError(absl::StrCat(
message.GetTypeName(), " missing type_url field descriptor"));
}
if (ABSL_PREDICT_FALSE(type_url_field->cpp_type() !=
google::protobuf::FieldDescriptor::CPPTYPE_STRING)) {
return absl::InternalError(absl::StrCat(
message.GetTypeName(), " has unexpected type_url field type: ",
type_url_field->cpp_type_name()));
}
const auto* value_field =
desc->FindFieldByNumber(google::protobuf::Any::kValueFieldNumber);
if (ABSL_PREDICT_FALSE(value_field == nullptr)) {
return absl::InternalError(
absl::StrCat(message.GetTypeName(), " missing value field descriptor"));
}
if (ABSL_PREDICT_FALSE(value_field->cpp_type() !=
google::protobuf::FieldDescriptor::CPPTYPE_STRING)) {
return absl::InternalError(absl::StrCat(
message.GetTypeName(),
" has unexpected value field type: ", value_field->cpp_type_name()));
}
reflect->SetString(&message, type_url_field, std::string(type_url));
reflect->SetString(&message, value_field, value);
return absl::OkStatus();
}

} // namespace cel::extensions::protobuf_internal
30 changes: 30 additions & 0 deletions extensions/protobuf/internal/any.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_ANY_H_
#define THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_ANY_H_

#include "absl/status/status.h"
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "google/protobuf/message.h"

namespace cel::extensions::protobuf_internal {

absl::Status SetAny(google::protobuf::Message& message, absl::string_view type_url,
const absl::Cord& value);

} // namespace cel::extensions::protobuf_internal

#endif // THIRD_PARTY_CEL_CPP_EXTENSIONS_PROTOBUF_INTERNAL_ANY_H_
67 changes: 67 additions & 0 deletions extensions/protobuf/internal/any_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright 2023 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "extensions/protobuf/internal/any.h"

#include <memory>

#include "google/protobuf/any.pb.h"
#include "google/protobuf/descriptor.pb.h"
#include "absl/memory/memory.h"
#include "absl/strings/cord.h"
#include "internal/testing.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/descriptor_database.h"
#include "google/protobuf/dynamic_message.h"

namespace cel::extensions::protobuf_internal {
namespace {

TEST(Any, GeneratedToProto) {
google::protobuf::Any proto;
ASSERT_OK(SetAny(proto, "type.googleapis.com/foo.Bar", absl::Cord("baz")));
EXPECT_EQ(proto.type_url(), "type.googleapis.com/foo.Bar");
EXPECT_EQ(proto.value(), "baz");
}

TEST(Any, CustomToProto) {
google::protobuf::SimpleDescriptorDatabase database;
{
google::protobuf::FileDescriptorProto fd;
google::protobuf::Any::descriptor()->file()->CopyTo(&fd);
ASSERT_TRUE(database.Add(fd));
}
google::protobuf::DescriptorPool pool(&database);
pool.AllowUnknownDependencies();
google::protobuf::DynamicMessageFactory factory(&pool);
factory.SetDelegateToGeneratedFactory(false);
std::unique_ptr<google::protobuf::Message> proto = absl::WrapUnique(
factory.GetPrototype(pool.FindMessageTypeByName("google.protobuf.Any"))
->New());
const auto* descriptor = proto->GetDescriptor();
const auto* reflection = proto->GetReflection();
const auto* type_url_field = descriptor->FindFieldByName("type_url");
ASSERT_NE(type_url_field, nullptr);
const auto* value_field = descriptor->FindFieldByName("value");
ASSERT_NE(value_field, nullptr);

ASSERT_OK(SetAny(*proto, "type.googleapis.com/foo.Bar", absl::Cord("baz")));

EXPECT_EQ(reflection->GetString(*proto, type_url_field),
"type.googleapis.com/foo.Bar");
EXPECT_EQ(reflection->GetString(*proto, value_field), "baz");
}

} // namespace
} // namespace cel::extensions::protobuf_internal
104 changes: 98 additions & 6 deletions extensions/protobuf/struct_type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
#include <string>
#include <utility>

#include "google/protobuf/duration.pb.h"
#include "google/protobuf/struct.pb.h"
#include "google/protobuf/wrappers.pb.h"
#include "absl/base/attributes.h"
#include "absl/base/optimization.h"
#include "absl/functional/any_invocable.h"
Expand All @@ -27,19 +30,22 @@
#include "absl/memory/memory.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/cord.h"
#include "absl/strings/str_cat.h"
#include "base/type_manager.h"
#include "base/value_factory.h"
#include "base/values/struct_value_builder.h"
#include "eval/internal/errors.h"
#include "extensions/protobuf/enum_type.h"
#include "extensions/protobuf/internal/any.h"
#include "extensions/protobuf/internal/map_reflection.h"
#include "extensions/protobuf/internal/reflection.h"
#include "extensions/protobuf/internal/time.h"
#include "extensions/protobuf/internal/wrappers.h"
#include "extensions/protobuf/memory_manager.h"
#include "extensions/protobuf/struct_value.h"
#include "extensions/protobuf/type.h"
#include "internal/proto_time_encoding.h"
#include "internal/status_macros.h"
#include "google/protobuf/arena.h"
#include "google/protobuf/descriptor.h"
Expand Down Expand Up @@ -282,6 +288,10 @@ struct CheckedCast<uint64_t, uint32_t> {
}
};

std::string MakeAnyTypeUrl(absl::string_view type_name) {
return absl::StrCat("type.googleapis.com/", type_name);
}

// TODO(uncreated-issue/47): handle subtle implicit conversions around mixed numeric
class ProtoStructValueBuilder final : public StructValueBuilderInterface {
public:
Expand Down Expand Up @@ -1206,16 +1216,98 @@ class ProtoStructValueBuilder final : public StructValueBuilderInterface {
}
}

absl::Status SetSingularAnyField(const StructType::Field& field,
const google::protobuf::Reflection& reflect,
const google::protobuf::FieldDescriptor& field_desc,
Handle<Value>&& value) {
std::string type_url;
absl::Cord payload;
switch (value->kind()) {
case ValueKind::kNull: {
google::protobuf::Value proto;
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kBool: {
google::protobuf::BoolValue proto;
proto.set_value(value->As<BoolValue>().value());
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kInt: {
google::protobuf::Int64Value proto;
proto.set_value(value->As<IntValue>().value());
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kUint: {
google::protobuf::UInt64Value proto;
proto.set_value(value->As<UintValue>().value());
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kDouble: {
google::protobuf::DoubleValue proto;
proto.set_value(value->As<DoubleValue>().value());
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kBytes: {
google::protobuf::BytesValue proto;
proto.set_value(value->As<BytesValue>().ToString());
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kString: {
google::protobuf::StringValue proto;
proto.set_value(value->As<StringValue>().ToString());
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kDuration: {
google::protobuf::Duration proto;
CEL_RETURN_IF_ERROR(internal::EncodeDuration(
value->As<DurationValue>().value(), &proto));
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kTimestamp: {
google::protobuf::Timestamp proto;
CEL_RETURN_IF_ERROR(
internal::EncodeTime(value->As<TimestampValue>().value(), &proto));
type_url = MakeAnyTypeUrl(proto.GetDescriptor()->full_name());
payload = proto.SerializeAsCord();
} break;
case ValueKind::kStruct: {
if (ABSL_PREDICT_FALSE(!value->Is<ProtoStructValue>())) {
return absl::InvalidArgumentError(
"StructValueBuilderInterface::SetField does not yet support "
"converting custom types to "
"google.protobuf.Any");
}
type_url = MakeAnyTypeUrl(
value->type()->As<ProtoStructType>().descriptor().full_name());
CEL_ASSIGN_OR_RETURN(payload,
value->As<ProtoStructValue>().SerializeAsCord());
} break;
default:
return absl::InvalidArgumentError(absl::StrCat(
"StructValueBuilderInterface::SetField does not yet support "
"converting ",
value->type()->DebugString(), " to google.protobuf.Any"));
}
return protobuf_internal::SetAny(
*reflect.MutableMessage(message_, &field_desc, factory_), type_url,
payload);
}

absl::Status SetSingularMessageField(
const StructType::Field& field, const google::protobuf::Reflection& reflect,
const google::protobuf::FieldDescriptor& field_desc, Handle<Value>&& value) {
switch (field.type->kind()) {
case TypeKind::kAny: {
// google.protobuf.Any
return absl::UnimplementedError(
"StructValueBuilderInterface::SetField does not yet implement "
"google.protobuf.Any support");
}
case TypeKind::kAny:
return SetSingularAnyField(field, reflect, field_desc,
std::move(value));
case TypeKind::kDyn: {
// google.protobuf.Value
return absl::UnimplementedError(
Expand Down
Loading

0 comments on commit 98fbc3b

Please # to comment.