Skip to content

Commit

Permalink
Support reading raw extension fields in a MessageSet. (#1755)
Browse files Browse the repository at this point in the history
If an encoder doesn't know about the bit, then this is how the fields
would be encoded, and there has been a push upstream to support this
including a new conformance test that looks for it.

- Support this in the parsing loop.
- Remove the conformance failure now that it passes.
- Add some tests (includes proto file changes and regeneration)
  • Loading branch information
thomasvl authored Feb 12, 2025
1 parent eb17584 commit 101ecdb
Show file tree
Hide file tree
Showing 6 changed files with 417 additions and 32 deletions.
21 changes: 21 additions & 0 deletions Protos/SwiftProtobufTests/unittest_mset.proto
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,21 @@ message TestMessageSetContainer {
optional swift_proto_testing.wire_format.TestMessageSet message_set = 1;
}

// A message without the message_set_wire_format option that still supports
// extensions.
message MessageEx {
// Field numbers 4 and above are open to extensions.
extensions 4 to max;
}

message TestMessageSetExtension1 {
extend swift_proto_testing.wire_format.TestMessageSet {
optional TestMessageSetExtension1 message_set_extension = 1545008;
}
// Also extend a non-MessageSet with the same field number. This will allow us
// to test parsing a normal extension into a MessageSet.
extend MessageEx {
optional TestMessageSetExtension1 doppelganger_message_set_extension = 1545008;
}
optional int32 i = 15;
optional swift_proto_testing.wire_format.TestMessageSet recursive = 16;
optional string test_aliasing = 17 [ctype = STRING_PIECE];
Expand All @@ -61,6 +72,16 @@ message TestMessageSetExtension2 {
optional string str = 25;
}

// This isn't on swift_proto_testing.wire_format.TestMessageSet, so it will be unknown
// when parsing there.
message TestMessageSetExtension3 {
// Declared as an extension of MessageEx only.
extend MessageEx {
optional TestMessageSetExtension3 doppelganger_message_set_extension = 1547770;
}
optional int32 x = 26;
}


// MessageSet wire format is equivalent to this.
message RawMessageSet {
repeated group Item = 1 {
Expand Down
153 changes: 152 additions & 1 deletion Reference/SwiftProtobufTests/unittest_mset.pb.swift
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,20 @@ struct SwiftProtoTesting_TestMessageSetContainer: Sendable {
fileprivate var _messageSet: SwiftProtoTesting_WireFormat_TestMessageSet? = nil
}

/// A message without the message_set_wire_format option that still supports
/// extensions.
struct SwiftProtoTesting_MessageEx: SwiftProtobuf.ExtensibleMessage, Sendable {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.

var unknownFields = SwiftProtobuf.UnknownStorage()

init() {}

// Holds the values of any extension fields set on, or decoded into, this message.
var _protobuf_extensionFieldValues = SwiftProtobuf.ExtensionFieldValueSet()
}

struct SwiftProtoTesting_TestMessageSetExtension1: Sendable {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
Expand Down Expand Up @@ -141,6 +155,29 @@ struct SwiftProtoTesting_TestMessageSetExtension2: Sendable {
fileprivate var _str: String? = nil
}

/// This isn't on swift_proto_testing.wire_format.TestMessageSet, so it will be unknown
/// when parsing there.
struct SwiftProtoTesting_TestMessageSetExtension3: Sendable {
// SwiftProtobuf.Message conformance is added in an extension below. See the
// `Message` and `Message+*Additions` files in the SwiftProtobuf library for
// methods supported on all messages.

// Reads return 0 when the field has never been set (backing storage is nil).
var x: Int32 {
get {return _x ?? 0}
set {_x = newValue}
}
/// Returns true if `x` has been explicitly set.
var hasX: Bool {return self._x != nil}
/// Clears the value of `x`. Subsequent reads from it will return its default value.
mutating func clearX() {self._x = nil}

var unknownFields = SwiftProtobuf.UnknownStorage()

init() {}

// Backing storage for `x`; nil means "not set".
fileprivate var _x: Int32? = nil
}

/// MessageSet wire format is equivalent to this.
struct SwiftProtoTesting_RawMessageSet: Sendable {
// SwiftProtobuf.Message conformance is added in an extension below. See the
Expand Down Expand Up @@ -222,6 +259,39 @@ struct SwiftProtoTesting_RawBreakableMessageSet: Sendable {
// declaration. To avoid naming collisions, the names are prefixed with the name of
// the scope where the extend directive occurs.

// Accessors for the extensions declared on `SwiftProtoTesting_MessageEx`. Each
// extension gets a get/set property, a `has…` query, and a `clear…` method; the
// member names are prefixed with the scope in which the extension was declared.
extension SwiftProtoTesting_MessageEx {

// Reading an unset extension yields a freshly initialized (empty) message.
var SwiftProtoTesting_TestMessageSetExtension1_doppelgangerMessageSetExtension: SwiftProtoTesting_TestMessageSetExtension1 {
get {return getExtensionValue(ext: SwiftProtoTesting_TestMessageSetExtension1.Extensions.doppelganger_message_set_extension) ?? SwiftProtoTesting_TestMessageSetExtension1()}
set {setExtensionValue(ext: SwiftProtoTesting_TestMessageSetExtension1.Extensions.doppelganger_message_set_extension, value: newValue)}
}
/// Returns true if extension `SwiftProtoTesting_TestMessageSetExtension1.Extensions.doppelganger_message_set_extension`
/// has been explicitly set.
var hasSwiftProtoTesting_TestMessageSetExtension1_doppelgangerMessageSetExtension: Bool {
return hasExtensionValue(ext: SwiftProtoTesting_TestMessageSetExtension1.Extensions.doppelganger_message_set_extension)
}
/// Clears the value of extension `SwiftProtoTesting_TestMessageSetExtension1.Extensions.doppelganger_message_set_extension`.
/// Subsequent reads from it will return its default value.
mutating func clearSwiftProtoTesting_TestMessageSetExtension1_doppelgangerMessageSetExtension() {
clearExtensionValue(ext: SwiftProtoTesting_TestMessageSetExtension1.Extensions.doppelganger_message_set_extension)
}

// Reading an unset extension yields a freshly initialized (empty) message.
var SwiftProtoTesting_TestMessageSetExtension3_doppelgangerMessageSetExtension: SwiftProtoTesting_TestMessageSetExtension3 {
get {return getExtensionValue(ext: SwiftProtoTesting_TestMessageSetExtension3.Extensions.doppelganger_message_set_extension) ?? SwiftProtoTesting_TestMessageSetExtension3()}
set {setExtensionValue(ext: SwiftProtoTesting_TestMessageSetExtension3.Extensions.doppelganger_message_set_extension, value: newValue)}
}
/// Returns true if extension `SwiftProtoTesting_TestMessageSetExtension3.Extensions.doppelganger_message_set_extension`
/// has been explicitly set.
var hasSwiftProtoTesting_TestMessageSetExtension3_doppelgangerMessageSetExtension: Bool {
return hasExtensionValue(ext: SwiftProtoTesting_TestMessageSetExtension3.Extensions.doppelganger_message_set_extension)
}
/// Clears the value of extension `SwiftProtoTesting_TestMessageSetExtension3.Extensions.doppelganger_message_set_extension`.
/// Subsequent reads from it will return its default value.
mutating func clearSwiftProtoTesting_TestMessageSetExtension3_doppelgangerMessageSetExtension() {
clearExtensionValue(ext: SwiftProtoTesting_TestMessageSetExtension3.Extensions.doppelganger_message_set_extension)
}
}

extension SwiftProtoTesting_WireFormat_TestMessageSet {

var SwiftProtoTesting_TestMessageSetExtension1_messageSetExtension: SwiftProtoTesting_TestMessageSetExtension1 {
Expand Down Expand Up @@ -264,7 +334,9 @@ extension SwiftProtoTesting_WireFormat_TestMessageSet {
/// a larger `SwiftProtobuf.SimpleExtensionMap`.
let SwiftProtoTesting_UnittestMset_Extensions: SwiftProtobuf.SimpleExtensionMap = [
SwiftProtoTesting_TestMessageSetExtension1.Extensions.message_set_extension,
SwiftProtoTesting_TestMessageSetExtension2.Extensions.message_set_extension
SwiftProtoTesting_TestMessageSetExtension1.Extensions.doppelganger_message_set_extension,
SwiftProtoTesting_TestMessageSetExtension2.Extensions.message_set_extension,
SwiftProtoTesting_TestMessageSetExtension3.Extensions.doppelganger_message_set_extension
]

// Extension Objects - The only reason these might be needed is when manually
Expand All @@ -277,6 +349,11 @@ extension SwiftProtoTesting_TestMessageSetExtension1 {
_protobuf_fieldNumber: 1545008,
fieldName: "swift_proto_testing.TestMessageSetExtension1"
)

static let doppelganger_message_set_extension = SwiftProtobuf.MessageExtension<SwiftProtobuf.OptionalMessageExtensionField<SwiftProtoTesting_TestMessageSetExtension1>, SwiftProtoTesting_MessageEx>(
_protobuf_fieldNumber: 1545008,
fieldName: "swift_proto_testing.TestMessageSetExtension1.doppelganger_message_set_extension"
)
}
}

Expand All @@ -289,6 +366,15 @@ extension SwiftProtoTesting_TestMessageSetExtension2 {
}
}

// Extension object for the extension declared inside TestMessageSetExtension3.
extension SwiftProtoTesting_TestMessageSetExtension3 {
enum Extensions {
// Optional message-typed extension of `SwiftProtoTesting_MessageEx`, carried
// as field number 1547770.
static let doppelganger_message_set_extension = SwiftProtobuf.MessageExtension<SwiftProtobuf.OptionalMessageExtensionField<SwiftProtoTesting_TestMessageSetExtension3>, SwiftProtoTesting_MessageEx>(
_protobuf_fieldNumber: 1547770,
fieldName: "swift_proto_testing.TestMessageSetExtension3.doppelganger_message_set_extension"
)
}
}

// MARK: - Code below here is support for the SwiftProtobuf runtime.

fileprivate let _protobuf_package = "swift_proto_testing"
Expand Down Expand Up @@ -334,6 +420,35 @@ extension SwiftProtoTesting_TestMessageSetContainer: SwiftProtobuf.Message, Swif
}
}

// Runtime support for `SwiftProtoTesting_MessageEx`: naming, initialization
// check, decoding, traversal, and equality. The message has no regular fields,
// so all of its content is extension values plus unknown fields.
extension SwiftProtoTesting_MessageEx: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
static let protoMessageName: String = _protobuf_package + ".MessageEx"
// Empty name map: there are no non-extension fields to name.
static let _protobuf_nameMap = SwiftProtobuf._NameMap()

// The message is initialized exactly when its extension values report that they are.
public var isInitialized: Bool {
if !_protobuf_extensionFieldValues.isInitialized {return false}
return true
}

// Field numbers in the half-open range [4, 536870912) (536870912 == 2^29) are
// handed to the decoder as extension fields; this loop does nothing explicit
// for any other field number.
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
if (4 <= fieldNumber && fieldNumber < 536870912) {
try decoder.decodeExtensionField(values: &_protobuf_extensionFieldValues, messageType: SwiftProtoTesting_MessageEx.self, fieldNumber: fieldNumber)
}
}
}

// Visits the extension values for the same [4, 536870912) range first, then
// the unknown fields.
func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
try visitor.visitExtensionFields(fields: _protobuf_extensionFieldValues, start: 4, end: 536870912)
try unknownFields.traverse(visitor: &visitor)
}

// Two values are equal when both their unknown fields and their extension values match.
static func ==(lhs: SwiftProtoTesting_MessageEx, rhs: SwiftProtoTesting_MessageEx) -> Bool {
if lhs.unknownFields != rhs.unknownFields {return false}
if lhs._protobuf_extensionFieldValues != rhs._protobuf_extensionFieldValues {return false}
return true
}
}

extension SwiftProtoTesting_TestMessageSetExtension1: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
static let protoMessageName: String = _protobuf_package + ".TestMessageSetExtension1"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
Expand Down Expand Up @@ -423,6 +538,42 @@ extension SwiftProtoTesting_TestMessageSetExtension2: SwiftProtobuf.Message, Swi
}
}

// Runtime support for `SwiftProtoTesting_TestMessageSetExtension3`: naming,
// decoding, traversal, and equality for its single field `x` (number 26).
extension SwiftProtoTesting_TestMessageSetExtension3: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
static let protoMessageName: String = _protobuf_package + ".TestMessageSetExtension3"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
26: .same(proto: "x"),
]

// Decodes field 26 into `_x`; this switch takes no action for other field numbers.
mutating func decodeMessage<D: SwiftProtobuf.Decoder>(decoder: inout D) throws {
while let fieldNumber = try decoder.nextFieldNumber() {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every case branch when no optimizations are
// enabled. https://github.com/apple/swift-protobuf/issues/1034
switch fieldNumber {
case 26: try { try decoder.decodeSingularInt32Field(value: &self._x) }()
default: break
}
}
}

// Visits `x` only when it has been set, then the unknown fields.
func traverse<V: SwiftProtobuf.Visitor>(visitor: inout V) throws {
// The use of inline closures is to circumvent an issue where the compiler
// allocates stack space for every if/case branch local when no optimizations
// are enabled. https://github.com/apple/swift-protobuf/issues/1034 and
// https://github.com/apple/swift-protobuf/issues/1182
try { if let v = self._x {
try visitor.visitSingularInt32Field(value: v, fieldNumber: 26)
} }()
try unknownFields.traverse(visitor: &visitor)
}

// Compares the optional backing storage, so "unset" and "explicitly set to 0"
// are not equal.
static func ==(lhs: SwiftProtoTesting_TestMessageSetExtension3, rhs: SwiftProtoTesting_TestMessageSetExtension3) -> Bool {
if lhs._x != rhs._x {return false}
if lhs.unknownFields != rhs.unknownFields {return false}
return true
}
}

extension SwiftProtoTesting_RawMessageSet: SwiftProtobuf.Message, SwiftProtobuf._MessageImplementationBase, SwiftProtobuf._ProtoNameProviding {
static let protoMessageName: String = _protobuf_package + ".RawMessageSet"
static let _protobuf_nameMap: SwiftProtobuf._NameMap = [
Expand Down
2 changes: 1 addition & 1 deletion Sources/Conformance/failure_list_swift.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Recommended.Proto2.ProtobufInput.ValidMessageSetEncoding.SubmessageEncoding.NotUnknown.ProtobufOutput # Output was not equivalent to reference message: added: message_set_correct.(protobuf_test_messages.proto2.TestAllTypesProto2.Ext
# Nothing failing.
71 changes: 42 additions & 29 deletions Sources/SwiftProtobuf/BinaryDecoder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -1113,39 +1113,52 @@ internal struct BinaryDecoder: Decoder {
values: inout ExtensionFieldValueSet,
messageType: any Message.Type
) throws {
// Spin looking for the Item group, everything else will end up in unknown fields.
// Anything not in an acceptable form will go into unknown fields
while let fieldNumber = try self.nextFieldNumber() {
guard fieldNumber == WireFormat.MessageSet.FieldNumbers.item && fieldWireFormat == WireFormat.startGroup
else {
continue
}

// This is similar to decodeFullGroup

try incrementRecursionDepth()
var subDecoder = self
subDecoder.groupFieldNumber = fieldNumber
subDecoder.consumed = true
// Normal MessageSet wire format (nested in a group)
if fieldNumber == WireFormat.MessageSet.FieldNumbers.item && fieldWireFormat == WireFormat.startGroup {
// This is similar to decodeFullGroup

try incrementRecursionDepth()
var subDecoder = self
subDecoder.groupFieldNumber = fieldNumber
subDecoder.consumed = true

let itemResult = try subDecoder.decodeMessageSetItem(
values: &values,
messageType: messageType
)
switch itemResult {
case .success:
// Advance over what was parsed.
consume(length: available - subDecoder.available)
consumed = true
case .handleAsUnknown:
// Nothing to do.
break

let itemResult = try subDecoder.decodeMessageSetItem(
values: &values,
messageType: messageType
)
switch itemResult {
case .success:
// Advance over what was parsed.
consume(length: available - subDecoder.available)
consumed = true
case .handleAsUnknown:
// Nothing to do.
break
case .malformed:
throw BinaryDecodingError.malformedProtobuf
}

case .malformed:
throw BinaryDecodingError.malformedProtobuf
assert(recursionBudget == subDecoder.recursionBudget)
decrementRecursionDepth()
} else if fieldWireFormat == WireFormat.lengthDelimited,
let ext = extensions?[messageType, fieldNumber]
{
// This was a raw extension field, this is possible if some encoder doesn't
// know the MessageSet wire format. Since we know the extension, promote it.
// _upb_Decoder_FindField() has this same basic logic.
try decodeExtensionField(
values: &values,
messageType: messageType,
fieldNumber: fieldNumber,
messageExtension: ext
)
if !consumed {
throw BinaryDecodingError.malformedProtobuf
}
}

assert(recursionBudget == subDecoder.recursionBudget)
decrementRecursionDepth()
}
}

Expand Down
49 changes: 49 additions & 0 deletions Tests/SwiftProtobufTests/Test_MessageSet.swift
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,55 @@ final class Test_MessageSet: XCTestCase {
validator.validate(message: msg)
}

// Serializes two extensions through `SwiftProtoTesting_MessageEx` and parses
// the bytes as a `TestMessageSet`. The extension whose field number is also
// registered on the MessageSet must surface as a set extension; the other must
// be preserved in unknown fields.
func testParse_FieldEncoding() {
let extMsg1 = SwiftProtoTesting_TestMessageSetExtension1.with { $0.i = 123 }
let extMsg3 = SwiftProtoTesting_TestMessageSetExtension3.with { $0.x = 10 }

// Build the input using the non-MessageSet message that shares field numbers.
let msgEx = SwiftProtoTesting_MessageEx.with {
$0.SwiftProtoTesting_TestMessageSetExtension1_doppelgangerMessageSetExtension = extMsg1
$0.SwiftProtoTesting_TestMessageSetExtension3_doppelgangerMessageSetExtension = extMsg3
}

let serialized: Data
do {
serialized = try msgEx.serializedBytes()
} catch let e {
XCTFail("Failed to serialize: \(e)")
return
}

// Parse the same bytes as a MessageSet, with the test file's extensions registered.
let msg: SwiftProtoTesting_WireFormat_TestMessageSet
do {
msg = try SwiftProtoTesting_WireFormat_TestMessageSet(
serializedBytes: serialized,
extensions: SwiftProtoTesting_UnittestMset_Extensions
)
} catch let e {
XCTFail("Failed to parse: \(e)")
return
}

// One comes in as a known field, the other comes in as an unknown field (not
// promoted to the group form).
XCTAssertTrue(msg.hasSwiftProtoTesting_TestMessageSetExtension1_messageSetExtension)
XCTAssertFalse(msg.hasSwiftProtoTesting_TestMessageSetExtension2_messageSetExtension)

let expectedUnknowns = Data([
210, 223, 243, 5, // Length delimited field 1547770
3, // Length (varint)
208, 1, // Varint field 26 ("x") in message
10, // Value of X
])
XCTAssertEqual(msg.unknownFields.data, expectedUnknowns)

// Cross-check through the visitor: one known extension message (not recursed
// into) and the single unknown-field blob.
var validator = ExtensionValidator()
validator.expectedMessages = [
(SwiftProtoTesting_TestMessageSetExtension1.Extensions.message_set_extension.fieldNumber, false)
]
validator.expectedUnknowns = [expectedUnknowns]
validator.validate(message: msg)
}

fileprivate struct ExtensionValidator: PBTestVisitor {
// Values are field number and if we should recurse.
var expectedMessages = [(Int, Bool)]()
Expand Down
Loading

0 comments on commit 101ecdb

Please # to comment.