diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldNumberTree.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldNumberTree.java index e77950a0c..5bb5dd04f 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldNumberTree.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldNumberTree.java @@ -117,6 +117,16 @@ static FieldNumberTree fromMessage(Message message) { return tree; } + static FieldNumberTree fromMessages(Iterable messages) { + FieldNumberTree tree = new FieldNumberTree(); + for (Message message : messages) { + if (message != null) { + tree.merge(fromMessage(message)); + } + } + return tree; + } + private static FieldNumberTree fromUnknownFieldSet(UnknownFieldSet unknownFieldSet) { FieldNumberTree tree = new FieldNumberTree(); for (int fieldNumber : unknownFieldSet.asMap().keySet()) { diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeImpl.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeImpl.java index 76fb0d6c6..4acf9916b 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeImpl.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeImpl.java @@ -81,22 +81,8 @@ static FieldScope createFromSetFields(Iterable messages) { "Cannot create scope from messages with different descriptors: %s", getDescriptors(messages)); - Message.Builder builder = null; - for (Message message : messages) { - if (message == null) { - continue; - } - - if (builder != null) { - builder.mergeFrom(message); - } else { - builder = message.toBuilder(); - } - } - - Message aggregateMessage = builder.build(); return create( - FieldScopeLogic.partialScope(aggregateMessage), + FieldScopeLogic.partialScope(messages, optDescriptor.get()), Functions.constant(String.format("FieldScopes.fromSetFields(%s)", formatList(messages)))); } diff --git a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeLogic.java b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeLogic.java index 9c4bf8f54..f112a0eab 100644 --- a/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeLogic.java +++ b/extensions/proto/src/main/java/com/google/common/truth/extensions/proto/FieldScopeLogic.java @@ -20,6 +20,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.truth.extensions.proto.FieldScopeUtil.join; +import com.google.common.base.Joiner; import com.google.common.base.Optional; import com.google.common.base.Verify; import com.google.common.collect.ImmutableList; @@ -244,13 +245,13 @@ private static PartialScopeLogic newPartialScopeLogic(FieldNumberTree fieldNumbe } private static final class RootPartialScopeLogic extends PartialScopeLogic { - private final Message message; + private final String repr; private final Descriptor expectedDescriptor; - RootPartialScopeLogic(Message message) { - super(FieldNumberTree.fromMessage(message)); - this.message = message; - this.expectedDescriptor = message.getDescriptorForType(); + RootPartialScopeLogic(FieldNumberTree fieldNumberTree, String repr, Descriptor descriptor) { + super(fieldNumberTree); + this.repr = repr; + this.expectedDescriptor = descriptor; } @Override @@ -270,12 +271,20 @@ public void validate( @Override public String toString() { - return String.format("FieldScopes.fromSetFields(%s)", message); + return String.format("FieldScopes.fromSetFields(%s)", repr); } } static FieldScopeLogic partialScope(Message message) { - return new RootPartialScopeLogic(message); + return new RootPartialScopeLogic( + FieldNumberTree.fromMessage(message), message.toString(), message.getDescriptorForType()); + } + + static FieldScopeLogic partialScope(Iterable messages, Descriptor descriptor) { + return new RootPartialScopeLogic( + FieldNumberTree.fromMessages(messages), + Joiner.on(", ").useForNull("null").join(messages), + descriptor); } // TODO(user): Performance: Optimize FieldNumbersLogic and FieldDescriptorsLogic for diff --git a/extensions/proto/src/test/java/com/google/common/truth/extensions/proto/FieldScopesTest.java b/extensions/proto/src/test/java/com/google/common/truth/extensions/proto/FieldScopesTest.java index e5d4cd945..fb0a07dec 100644 --- a/extensions/proto/src/test/java/com/google/common/truth/extensions/proto/FieldScopesTest.java +++ b/extensions/proto/src/test/java/com/google/common/truth/extensions/proto/FieldScopesTest.java @@ -507,6 +507,36 @@ public void testFromSetFields() { .contains("ignored: o_sub_test_message.o_test_message.r_string"); } + @Test + public void testFromSetFields_comparingExpectedFieldsOnly() + throws InvalidProtocolBufferException { + + Message message1 = parse("o_int: 1 o_double: 333 oneof_message1: { o_int: 3 o_double: 333 }"); + Message message2 = + parse("o_int: 333 o_double: 1.2 oneof_message2: { o_int: 333 o_double: 3.14 }"); + Message diffMessage1 = parse("o_int: 1 oneof_message1: { o_int: 4 }"); + Message diffMessage2 = parse("o_double: 1.2 oneof_message2: { o_double: 4.14 }"); + Message eqMessage1 = parse("o_int: 1 oneof_message1: { o_int: 3 }"); + Message eqMessage2 = parse("o_double: 1.2 oneof_message2: { o_double: 3.14 }"); + + expectThat(message1).comparingExpectedFieldsOnly().isEqualTo(eqMessage1); + expectThat(message2).comparingExpectedFieldsOnly().isEqualTo(eqMessage2); + expectFailureWhenTesting().that(message1).comparingExpectedFieldsOnly().isEqualTo(diffMessage1); + expectFailureWhenTesting().that(message2).comparingExpectedFieldsOnly().isEqualTo(diffMessage2); + + expectThat(listOf(message1, message2)) + .comparingExpectedFieldsOnly() + .containsExactly(eqMessage1, eqMessage2); + expectFailureWhenTesting() + .that(listOf(message1, message2)) + .comparingExpectedFieldsOnly() + .containsExactly(diffMessage1, eqMessage2); + expectFailureWhenTesting() + .that(listOf(message1, message2)) + .comparingExpectedFieldsOnly() + .containsExactly(eqMessage1, diffMessage2); + } + @Test public void testFromSetFields_unknownFields() throws InvalidProtocolBufferException { // Make sure that merging of repeated fields, separation by tag number, and separation by @@ -514,9 +544,9 @@ public void testFromSetFields_unknownFields() throws InvalidProtocolBufferExcept Message scopeMessage = fromUnknownFields( UnknownFieldSet.newBuilder() - .addField(20, Field.newBuilder().addFixed32(1).addFixed64(1).build()) + .addField(333, Field.newBuilder().addFixed32(1).addFixed64(1).build()) .addField( - 21, + 444, Field.newBuilder() .addVarint(1) .addLengthDelimited(ByteString.copyFrom("1", UTF_8)) @@ -535,9 +565,9 @@ public void testFromSetFields_unknownFields() throws InvalidProtocolBufferExcept Message message = fromUnknownFields( UnknownFieldSet.newBuilder() - .addField(19, Field.newBuilder().addFixed32(2).addFixed64(2).build()) + .addField(222, Field.newBuilder().addFixed32(2).addFixed64(2).build()) .addField( - 20, + 333, Field.newBuilder() .addFixed32(1) .addFixed64(1) @@ -549,7 +579,7 @@ public void testFromSetFields_unknownFields() throws InvalidProtocolBufferExcept .build()) .build()) .addField( - 21, + 444, Field.newBuilder() .addFixed32(2) .addFixed64(2) @@ -566,9 +596,9 @@ public void testFromSetFields_unknownFields() throws InvalidProtocolBufferExcept Message diffMessage = fromUnknownFields( UnknownFieldSet.newBuilder() - .addField(19, Field.newBuilder().addFixed32(3).addFixed64(3).build()) + .addField(222, Field.newBuilder().addFixed32(3).addFixed64(3).build()) .addField( - 20, + 333, Field.newBuilder() .addFixed32(4) .addFixed64(4) @@ -580,7 +610,7 @@ public void testFromSetFields_unknownFields() throws InvalidProtocolBufferExcept .build()) .build()) .addField( - 21, + 444, Field.newBuilder() .addFixed32(3) .addFixed64(3) @@ -597,9 +627,9 @@ public void testFromSetFields_unknownFields() throws InvalidProtocolBufferExcept Message eqMessage = fromUnknownFields( UnknownFieldSet.newBuilder() - .addField(19, Field.newBuilder().addFixed32(3).addFixed64(3).build()) + .addField(222, Field.newBuilder().addFixed32(3).addFixed64(3).build()) .addField( - 20, + 333, Field.newBuilder() .addFixed32(1) .addFixed64(1) @@ -611,7 +641,7 @@ public void testFromSetFields_unknownFields() throws InvalidProtocolBufferExcept .build()) .build()) .addField( - 21, + 444, Field.newBuilder() .addFixed32(3) .addFixed64(3) diff --git a/extensions/proto/src/test/proto/test_message2.proto b/extensions/proto/src/test/proto/test_message2.proto index fd79a975a..547caed0d 100644 --- a/extensions/proto/src/test/proto/test_message2.proto +++ b/extensions/proto/src/test/proto/test_message2.proto @@ -40,6 +40,11 @@ message TestMessage2 { map test_message_map = 17; optional .google.protobuf.Any o_any_message = 18; repeated .google.protobuf.Any r_any_message = 19; + + oneof oneof_field { + TestMessage2 oneof_message1 = 20; + TestMessage2 oneof_message2 = 21; + } } message RequiredStringMessage2 { diff --git a/extensions/proto/src/test/proto/test_message3.proto b/extensions/proto/src/test/proto/test_message3.proto index 128ed9bc3..0a6c0d12b 100644 --- a/extensions/proto/src/test/proto/test_message3.proto +++ b/extensions/proto/src/test/proto/test_message3.proto @@ -40,6 +40,11 @@ message TestMessage3 { map test_message_map = 17; .google.protobuf.Any o_any_message = 18; repeated .google.protobuf.Any r_any_message = 19; + + oneof oneof_field { + TestMessage3 oneof_message1 = 20; + TestMessage3 oneof_message2 = 21; + } } // message RequiredStringMessage3 {