Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Avoid duplicate null check and fix null serialization property name #6066

Merged
merged 5 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1464,50 +1464,27 @@ private MethodBodyStatement WrapInIsDefined(
bool propertyIsNullable,
MethodBodyStatement writePropertySerializationStatement)
{
// Create the first conditional statement to check if the property is defined
if (propertyIsNullable)
// Non-nullable value types can be serialized directly
if (IsNonNullableValueType(propertyType))
{
writePropertySerializationStatement = CheckPropertyIsInitialized(
propertyType,
wireInfo,
propertyIsRequired,
propertyExpression,
writePropertySerializationStatement);
return writePropertySerializationStatement;
}

// Directly return the statement if the property is required or a non-nullable value type that is not JsonElement
if (IsRequiredOrNonNullableValueType(propertyType, propertyIsRequired))
// Required properties that are not nullable can be serialized directly
if (propertyIsRequired && !propertyIsNullable)
{
return writePropertySerializationStatement;
}

// Conditionally serialize based on whether the property is a collection or a single value
return CreateConditionalSerializationStatement(propertyType, propertyExpression, propertyIsReadOnly, writePropertySerializationStatement);
}

private IfElseStatement CheckPropertyIsInitialized(
CSharpType propertyType,
PropertyWireInformation wireInfo,
bool isPropRequired,
MemberExpression propertyMemberExpression,
MethodBodyStatement writePropertySerializationStatement)
{
ScopedApi<bool> propertyIsInitialized;

if (propertyType.IsCollection && !propertyType.IsReadOnlyMemory && isPropRequired)
{
propertyIsInitialized = propertyMemberExpression.NotEqual(Null)
.And(OptionalSnippets.IsCollectionDefined(propertyMemberExpression));
}
else
{
propertyIsInitialized = propertyMemberExpression.NotEqual(Null);
}

return new IfElseStatement(
propertyIsInitialized,
writePropertySerializationStatement,
_utf8JsonWriterSnippet.WriteNull(wireInfo.SerializedName.ToVariableName()));
return CreateConditionalSerializationStatement(
propertyType,
propertyExpression,
propertyIsReadOnly,
propertyIsNullable,
propertyIsRequired,
wireInfo.SerializedName,
writePropertySerializationStatement);
}

/// <summary>
Expand Down Expand Up @@ -1764,20 +1741,31 @@ private static ScopedApi GetEnumerableExpression(ValueExpression expression, CSh
return expression.As(new CSharpType(typeof(IEnumerable<>), itemType));
}

private static bool IsRequiredOrNonNullableValueType(CSharpType propertyType, bool isRequired)
=> isRequired || (!propertyType.IsNullable && propertyType.IsValueType && !propertyType.Equals(typeof(JsonElement)));
private static bool IsNonNullableValueType(CSharpType propertyType)
=> propertyType is { IsNullable: false, IsValueType: true } && !propertyType.Equals(typeof(JsonElement));

private IfStatement CreateConditionalSerializationStatement(
private MethodBodyStatement CreateConditionalSerializationStatement(
CSharpType propertyType,
MemberExpression propertyMemberExpression,
bool isReadOnly,
bool isNullable,
bool isRequired,
string serializedName,
MethodBodyStatement writePropertySerializationStatement)
{
var isDefinedCondition = propertyType.IsCollection && !propertyType.IsReadOnlyMemory
var isDefinedCondition = propertyType is { IsCollection: true, IsReadOnlyMemory: false }
? OptionalSnippets.IsCollectionDefined(propertyMemberExpression)
: OptionalSnippets.IsDefined(propertyMemberExpression);
var condition = isReadOnly ? _isNotEqualToWireConditionSnippet.And(isDefinedCondition) : isDefinedCondition;

if (isRequired && isNullable)
{
return new IfElseStatement(
condition,
writePropertySerializationStatement,
_utf8JsonWriterSnippet.WriteNull(serializedName));
}

return new IfStatement(condition) { writePropertySerializationStatement };
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
"optionalLiteralInt": 456,
"optionalLiteralFloat": 4.56,
"optionalLiteralBool": false,
"optionalNullableList": []
"optionalNullableList": [],
"requiredNullableString": null
},
"intExtensibleEnum": 1,
"intExtensibleEnumCollection": [1, 3],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
"optionalLiteralInt": 456,
"optionalLiteralFloat": 4.56,
"optionalLiteralBool": false,
"optionalNullableList": []
"optionalNullableList": [],
"requiredNullableString": null
},
"intExtensibleEnum": 1,
"intExtensibleEnumCollection": [1, 3],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
"optionalLiteralFloat": 4.56,
"optionalLiteralBool": false,
"optionalNullableList": [],
"requiredNullableString": null,
"extra": "stuff"
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@
"optionalLiteralInt": 456,
"optionalLiteralFloat": 4.56,
"optionalLiteralBool": false,
"optionalNullableList": []
"optionalNullableList": [],
"requiredNullableString": null
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"name": "Example Thing",
"requiredUnion": "mockUnion",
"requiredBadDescription": "This is a description with potentially problematic characters like < or >.",
"requiredNullableList": null,
"requiredLiteralString": "accept",
"requiredLiteralInt": 123,
"requiredLiteralFloat": 1.23,
"requiredLiteralBool": false,
"optionalLiteralString": "hi",
"optionalLiteralInt": 456,
"optionalLiteralFloat": 4.56,
"optionalLiteralBool": false,
"requiredNullableString": null,
"extra": "stuff"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"name": "Example Thing",
"requiredUnion": "mockUnion",
"requiredBadDescription": "This is a description with potentially problematic characters like < or >.",
"requiredNullableList": null,
"requiredLiteralString": "accept",
"requiredLiteralInt": 123,
"requiredLiteralFloat": 1.23,
"requiredLiteralBool": false,
"optionalLiteralString": "hi",
"optionalLiteralInt": 456,
"optionalLiteralFloat": 4.56,
"optionalLiteralBool": false,
"requiredNullableString": null
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System.ClientModel;
using System.IO;
using System.Text.Json;
using Microsoft.TypeSpec.Generator.Tests.Common;
using NUnit.Framework;
using UnbrandedTypeSpec;

namespace Microsoft.TypeSpec.Generator.ClientModel.Tests.ModelReaderWriterValidation.TestProjects.Unbranded_TypeSpec
{
internal class ThingWithNullsTests : LocalModelJsonTests<Thing>
{
protected override string JsonPayload => File.ReadAllText(ModelTestHelper.GetLocation("TestData/Thing/ThingWithNulls.json"));
protected override string WirePayload => File.ReadAllText(ModelTestHelper.GetLocation("TestData/Thing/ThingWithNullsWireFormat.json"));
protected override Thing ToModel(ClientResult result) => (Thing)result;
protected override BinaryContent ToBinaryContent(Thing model) => model;

protected override void CompareModels(Thing model, Thing model2, string format)
{
Assert.AreEqual(model.Name, model2.Name);
Assert.AreEqual(model.RequiredUnion.ToString(), model2.RequiredUnion.ToString());
Assert.AreEqual(model.RequiredLiteralString, model2.RequiredLiteralString);
Assert.AreEqual(model.RequiredLiteralInt, model2.RequiredLiteralInt);
Assert.AreEqual(model.RequiredLiteralFloat, model2.RequiredLiteralFloat);
Assert.AreEqual(model.RequiredLiteralBool, model2.RequiredLiteralBool);
Assert.AreEqual(model.OptionalLiteralString, model2.OptionalLiteralString);
Assert.AreEqual(model.OptionalLiteralInt, model2.OptionalLiteralInt);
Assert.AreEqual(model.OptionalLiteralFloat, model2.OptionalLiteralFloat);
Assert.AreEqual(model.OptionalLiteralBool, model2.OptionalLiteralBool);
Assert.AreEqual(model.RequiredBadDescription, model2.RequiredBadDescription);
Assert.AreEqual(model.OptionalNullableList, model2.OptionalNullableList);
Assert.AreEqual(model.RequiredNullableList, model2.RequiredNullableList);

if (format == "J")
{
var rawData = GetRawData(model);
var rawData2 = GetRawData(model2);
Assert.IsNotNull(rawData);
Assert.IsNotNull(rawData2);
Assert.AreEqual(rawData.Count, rawData2.Count);
Assert.AreEqual(rawData["extra"].ToObjectFromJson<string>(), rawData2["extra"].ToObjectFromJson<string>());
}
}

protected override void VerifyModel(Thing model, string format)
{
var parsedWireJson = JsonDocument.Parse(WirePayload).RootElement;
Assert.IsNotNull(parsedWireJson);
Assert.AreEqual(parsedWireJson.GetProperty("name").GetString(), model.Name);
Assert.AreEqual("\"mockUnion\"", model.RequiredUnion.ToString());
Assert.AreEqual(parsedWireJson.GetProperty("requiredBadDescription").GetString(), model.RequiredBadDescription);
Assert.AreEqual(JsonValueKind.Null, parsedWireJson.GetProperty("requiredNullableList").ValueKind);
Assert.IsEmpty(model.RequiredNullableList);
Assert.AreEqual(new ThingRequiredLiteralString(parsedWireJson.GetProperty("requiredLiteralString").GetString()), model.RequiredLiteralString);
Assert.AreEqual(new ThingRequiredLiteralInt(parsedWireJson.GetProperty("requiredLiteralInt").GetInt32()), model.RequiredLiteralInt);
Assert.AreEqual(new ThingRequiredLiteralFloat(parsedWireJson.GetProperty("requiredLiteralFloat").GetSingle()), model.RequiredLiteralFloat);
Assert.AreEqual(parsedWireJson.GetProperty("requiredLiteralBool").GetBoolean(), model.RequiredLiteralBool);
Assert.AreEqual(new ThingOptionalLiteralString(parsedWireJson.GetProperty("optionalLiteralString").GetString()), model.OptionalLiteralString);
Assert.AreEqual(new ThingOptionalLiteralInt(parsedWireJson.GetProperty("optionalLiteralInt").GetInt32()), model.OptionalLiteralInt);
Assert.AreEqual(new ThingOptionalLiteralFloat(parsedWireJson.GetProperty("optionalLiteralFloat").GetSingle()), model.OptionalLiteralFloat);
Assert.AreEqual(parsedWireJson.GetProperty("optionalLiteralBool").GetBoolean(), model.OptionalLiteralBool);
Assert.IsFalse(parsedWireJson.TryGetProperty("optionalNullableList", out _));
Assert.IsEmpty(model.OptionalNullableList);


var rawData = GetRawData(model);
Assert.IsNotNull(rawData);
if (format == "J")
{
var parsedJson = JsonDocument.Parse(JsonPayload).RootElement;
Assert.AreEqual(parsedJson.GetProperty("extra").GetString(), rawData["extra"].ToObjectFromJson<string>());
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,19 @@ public void TestBuildDeserializationMethod()
Assert.IsNotNull(methodBody);
}

[TestCase(true)]
[TestCase(false)]
public void SerializedNameIsUsed(bool isRequired)
{
var property = InputFactory.Property("mockProperty", new InputNullableType(InputPrimitiveType.Int32), isRequired: isRequired, wireName: "mock_wire_name");
var inputModel = InputFactory.Model("mockInputModel", properties: [property]);
var (_, serialization) = CreateModelAndSerialization(inputModel);

var serializationMethod = serialization.Methods.Single(m => m.Signature.Name == "JsonModelWriteCore");
var methodBody = serializationMethod.BodyStatements!.ToDisplayString();
Assert.AreEqual(Helpers.GetExpectedFromFile(isRequired.ToString()), methodBody);
}

[Test]
public void TestBuildDeserializationMethodNestedSARD()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,8 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{format}' format.");
}
if ((Prop1 != null))
{
writer.WritePropertyName("prop1"u8);
writer.WriteStringValue(Prop1.ToSerialString());
}
else
{
writer.WriteNull("prop1"u8);
}
writer.WritePropertyName("prop1"u8);
writer.WriteStringValue(Prop1.ToSerialString());
if (((options.Format != "W") && (_additionalBinaryDataProperties != null)))
{
foreach (var item in _additionalBinaryDataProperties)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "J"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{format}' format.");
}
if (global::Sample.Optional.IsDefined(MockProperty))
{
writer.WritePropertyName("mock_wire_name"u8);
writer.WriteNumberValue(MockProperty.Value);
}
if (((options.Format != "W") && (_additionalBinaryDataProperties != null)))
{
foreach (var item in _additionalBinaryDataProperties)
{
writer.WritePropertyName(item.Key);
#if NET6_0_OR_GREATER
writer.WriteRawValue(item.Value);
#else
using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(item.Value))
{
global::System.Text.Json.JsonSerializer.Serialize(writer, document.RootElement);
}
#endif
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
string format = (options.Format == "W") ? ((global::System.ClientModel.Primitives.IPersistableModel<global::Sample.Models.MockInputModel>)this).GetFormatFromOptions(options) : options.Format;
if ((format != "J"))
{
throw new global::System.FormatException($"The model {nameof(global::Sample.Models.MockInputModel)} does not support writing '{format}' format.");
}
if (global::Sample.Optional.IsDefined(MockProperty))
{
writer.WritePropertyName("mock_wire_name"u8);
writer.WriteNumberValue(MockProperty.Value);
}
else
{
writer.WriteNull("mock_wire_name"u8);
}
if (((options.Format != "W") && (_additionalBinaryDataProperties != null)))
{
foreach (var item in _additionalBinaryDataProperties)
{
writer.WritePropertyName(item.Key);
#if NET6_0_OR_GREATER
writer.WriteRawValue(item.Value);
#else
using (global::System.Text.Json.JsonDocument document = global::System.Text.Json.JsonDocument.Parse(item.Value))
{
global::System.Text.Json.JsonSerializer.Serialize(writer, document.RootElement);
}
#endif
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,8 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
if (global::Sample.Optional.IsDefined(Prop2))
{
if ((Prop2 != null))
{
writer.WritePropertyName("prop2"u8);
this.SerializationMethod(writer, options);
}
else
{
writer.WriteNull("prop2"u8);
}
writer.WritePropertyName("prop2"u8);
this.SerializationMethod(writer, options);
}
if (((options.Format != "W") && (_additionalBinaryDataProperties != null)))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,8 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
if (global::Sample.Optional.IsDefined(Prop3))
{
if ((Prop3 != null))
{
writer.WritePropertyName("prop2"u8);
this.SerializationMethod(writer, options);
}
else
{
writer.WriteNull("prop2"u8);
}
writer.WritePropertyName("prop2"u8);
this.SerializationMethod(writer, options);
}
if (((options.Format != "W") && (_additionalBinaryDataProperties != null)))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,8 @@ protected virtual void JsonModelWriteCore(global::System.Text.Json.Utf8JsonWrite
}
if (global::Sample.Optional.IsDefined(Prop2))
{
if ((Prop2 != null))
{
writer.WritePropertyName("prop2"u8);
writer.WriteStringValue(Prop2);
}
else
{
writer.WriteNull("prop2"u8);
}
writer.WritePropertyName("prop2"u8);
writer.WriteStringValue(Prop2);
}
if (((options.Format != "W") && (_additionalBinaryDataProperties != null)))
{
Expand Down
Loading
Loading