Skip to content

Commit

Permalink
Fixed initialization race-condition in generated code. (#7834)
Browse files Browse the repository at this point in the history
* Fixed initialization race-condition in generated code.

* Updated Snapshots
  • Loading branch information
michaelstaib committed Dec 16, 2024
1 parent 4d6e37d commit 63ab9f3
Show file tree
Hide file tree
Showing 10 changed files with 318 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public string WriteBeginClass(string typeName)
_writer.WriteIndentedLine("internal static class {0}", typeName);
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();
_writer.WriteIndentedLine("private static readonly object _sync = new object();");
_writer.WriteIndentedLine("private static bool _bindingsInitialized;");
return typeName;
}
Expand Down Expand Up @@ -94,15 +95,18 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL

if (first)
{
_writer.WriteIndentedLine("if (_bindingsInitialized)");
_writer.WriteIndentedLine("if (!_bindingsInitialized)");
_writer.WriteIndentedLine("{");
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine("return;");
}
_writer.IncreaseIndent();

_writer.WriteIndentedLine("lock (_sync)");
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();

_writer.WriteIndentedLine("if (!_bindingsInitialized)");
_writer.WriteIndentedLine("{");
_writer.IncreaseIndent();

_writer.WriteIndentedLine("}");
_writer.WriteIndentedLine("_bindingsInitialized = true;");
_writer.WriteLine();
_writer.WriteIndentedLine(
"const global::{0} bindingFlags =",
Expand All @@ -120,6 +124,8 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL
_writer.WriteIndentedLine("var type = typeof({0});", method.ContainingType.ToFullyQualified());
_writer.WriteIndentedLine("global::System.Reflection.MethodInfo resolver = default!;");
_writer.WriteIndentedLine("global::System.Reflection.ParameterInfo[] parameters = default!;");

_writer.WriteIndentedLine("_bindingsInitialized = true;");
first = false;
}

Expand Down Expand Up @@ -182,8 +188,8 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL
using (_writer.WriteForEach("binding", $"_args_{resolver.TypeName}_{resolver.Member.Name}"))
{
using (_writer.WriteIfClause(
"binding.Kind == global::{0}.Argument",
WellKnownTypes.ArgumentKind))
"binding.Kind == global::{0}.Argument",
WellKnownTypes.ArgumentKind))
{
_writer.WriteIndentedLine("argumentCount++;");
}
Expand All @@ -204,8 +210,8 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL
using (_writer.IncreaseIndent())
{
_writer.WriteIndentedLine(
".SetMessage(\"The node resolver `{0}.{1}` mustn't have more than one " +
"argument. Node resolvers can only have a single argument called `id`.\")",
".SetMessage(\"The node resolver `{0}.{1}` mustn't have more than one "
+ "argument. Node resolvers can only have a single argument called `id`.\")",
resolver.Member.ContainingType.ToDisplayString(),
resolver.Member.Name);
_writer.WriteIndentedLine(".Build());");
Expand All @@ -214,6 +220,16 @@ public void AddParameterInitializer(IEnumerable<Resolver> resolvers, ILocalTypeL
}
}
}

if (!first)
{
_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
_writer.DecreaseIndent();
_writer.WriteIndentedLine("}");
}
}

_writer.WriteIndentedLine("}");
Expand All @@ -224,8 +240,7 @@ private static string ToFullyQualifiedString(
IMethodSymbol resolverMethod,
ILocalTypeLookup typeLookup)
{
if (type.TypeKind is TypeKind.Error &&
typeLookup.TryGetTypeName(type, resolverMethod, out var typeDisplayName))
if (type.TypeKind is TypeKind.Error && typeLookup.TryGetTypeName(type, resolverMethod, out var typeDisplayName))
{
return typeDisplayName;
}
Expand Down Expand Up @@ -269,9 +284,9 @@ private void AddStaticStandardResolver(
ILocalTypeLookup typeLookup)
{
using (_writer.WriteMethod(
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
{
using (_writer.WriteIfClause(condition: "!_bindingsInitialized"))
{
Expand Down Expand Up @@ -341,9 +356,9 @@ private void AddStaticStandardResolver(
private void AddStaticPureResolver(Resolver resolver, IMethodSymbol resolverMethod, ILocalTypeLookup typeLookup)
{
using (_writer.WriteMethod(
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
{
using (_writer.WriteIfClause(condition: "!_bindingsInitialized"))
{
Expand Down Expand Up @@ -436,9 +451,9 @@ private void AddStaticPureResolver(Resolver resolver, IMethodSymbol resolverMeth
private void AddStaticPropertyResolver(Resolver resolver)
{
using (_writer.WriteMethod(
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
"public static",
returnType: WellKnownTypes.FieldResolverDelegates,
methodName: $"{resolver.TypeName}_{resolver.Member.Name}"))
{
using (_writer.WriteIfClause(condition: "!_bindingsInitialized"))
{
Expand Down Expand Up @@ -489,13 +504,13 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
{
var parameter = resolver.Parameters[i];

if(resolver.IsNodeResolver
&& parameter.Kind is ResolverParameterKind.Argument or ResolverParameterKind.Unknown
&& (parameter.Name == "id" || parameter.Key == "id"))
if (resolver.IsNodeResolver
&& parameter.Kind is ResolverParameterKind.Argument or ResolverParameterKind.Unknown
&& (parameter.Name == "id" || parameter.Key == "id"))
{
_writer.WriteIndentedLine(
"var args{0} = context.GetLocalState<{1}>(" +
"global::HotChocolate.WellKnownContextData.InternalId);",
"var args{0} = context.GetLocalState<{1}>("
+ "global::HotChocolate.WellKnownContextData.InternalId);",
i,
parameter.Type.ToFullyQualified());
continue;
Expand Down Expand Up @@ -524,8 +539,8 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
break;
case ResolverParameterKind.EventMessage:
_writer.WriteIndentedLine(
"var args{0} = context.GetScopedState<{1}>(" +
"global::HotChocolate.WellKnownContextData.EventMessage);",
"var args{0} = context.GetScopedState<{1}>("
+ "global::HotChocolate.WellKnownContextData.EventMessage);",
i,
parameter.Type.ToFullyQualified());
break;
Expand Down Expand Up @@ -593,8 +608,8 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
}
case ResolverParameterKind.SetGlobalState:
_writer.WriteIndentedLine(
"var args{0} = new HotChocolate.SetState<{1}>(" +
"value => context.SetGlobalState(\"{2}\", value));",
"var args{0} = new HotChocolate.SetState<{1}>("
+ "value => context.SetGlobalState(\"{2}\", value));",
i,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified(),
parameter.Key);
Expand Down Expand Up @@ -633,8 +648,8 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
}
case ResolverParameterKind.SetScopedState:
_writer.WriteIndentedLine(
"var args{0} = new HotChocolate.SetState<{1}>(" +
"value => context.SetScopedState(\"{2}\", value));",
"var args{0} = new HotChocolate.SetState<{1}>("
+ "value => context.SetScopedState(\"{2}\", value));",
i,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified(),
parameter.Key);
Expand Down Expand Up @@ -673,8 +688,8 @@ private void AddResolverArguments(Resolver resolver, IMethodSymbol resolverMetho
}
case ResolverParameterKind.SetLocalState:
_writer.WriteIndentedLine(
"var args{0} = new HotChocolate.SetState<{1}>(" +
"value => context.SetLocalState(\"{2}\", value));",
"var args{0} = new HotChocolate.SetState<{1}>("
+ "value => context.SetLocalState(\"{2}\", value));",
i,
((INamedTypeSymbol)parameter.Type).TypeArguments[0].ToFullyQualified(),
parameter.Key);
Expand Down
68 changes: 68 additions & 0 deletions src/HotChocolate/Core/test/Types.Analyzers.Tests/TestMe.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// <auto-generated/>

#nullable enable
#pragma warning disable

using System;
using System.Runtime.CompilerServices;
using HotChocolate;
using HotChocolate.Types;
using HotChocolate.Execution.Configuration;
using HotChocolate.Internal;

namespace HotChocolate.Types
{
internal static class EntityInterfaceResolvers2
{
private static readonly object _sync = new object();
private static bool _bindingsInitialized;
private readonly static global::HotChocolate.Internal.IParameterBinding[] _args_EntityInterface_IdString = new global::HotChocolate.Internal.IParameterBinding[1];

public static void InitializeBindings(global::HotChocolate.Internal.IParameterBindingResolver bindingResolver)
{
if (!_bindingsInitialized)
{
lock (_sync)
{
if (!_bindingsInitialized)
{
const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::HotChocolate.Types.EntityInterface);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;

resolver = type.GetMethod(
"IdString",
bindingFlags,
new global::System.Type[] { typeof(global::HotChocolate.Types.IEntity) })!;
parameters = resolver.GetParameters();
_args_EntityInterface_IdString[0] = bindingResolver.GetBinding(parameters[0]);

_bindingsInitialized = true;
}
}
}
}

public static HotChocolate.Resolvers.FieldResolverDelegates EntityInterface_IdString()
{
if(!_bindingsInitialized)
{
throw new global::System.InvalidOperationException("The bindings must be initialized before the resolvers can be created.");
}
return new global::HotChocolate.Resolvers.FieldResolverDelegates(pureResolver: EntityInterface_IdString_Resolver);
}

private static global::System.Object? EntityInterface_IdString_Resolver(global::HotChocolate.Resolvers.IResolverContext context)
{
var args0 = context.Parent<global::HotChocolate.Types.IEntity>();
var result = global::HotChocolate.Types.EntityInterface.IdString(args0);
return result;
}
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -19,37 +19,43 @@ namespace TestNamespace
{
internal static class BookNodeResolvers
{
private static readonly object _sync = new object();
private static bool _bindingsInitialized;
private readonly static global::HotChocolate.Internal.IParameterBinding[] _args_BookNode_GetAuthorAsync = new global::HotChocolate.Internal.IParameterBinding[2];

public static void InitializeBindings(global::HotChocolate.Internal.IParameterBindingResolver bindingResolver)
{
if (_bindingsInitialized)
if (!_bindingsInitialized)
{
return;
}
_bindingsInitialized = true;

const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::TestNamespace.BookNode);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;

resolver = type.GetMethod(
"GetAuthorAsync",
bindingFlags,
new global::System.Type[]
lock (_sync)
{
typeof(global::TestNamespace.Book),
typeof(global::System.Threading.CancellationToken)
})!;
parameters = resolver.GetParameters();
_args_BookNode_GetAuthorAsync[0] = bindingResolver.GetBinding(parameters[0]);
_args_BookNode_GetAuthorAsync[1] = bindingResolver.GetBinding(parameters[1]);
if (!_bindingsInitialized)
{

const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::TestNamespace.BookNode);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;
_bindingsInitialized = true;

resolver = type.GetMethod(
"GetAuthorAsync",
bindingFlags,
new global::System.Type[]
{
typeof(global::TestNamespace.Book),
typeof(global::System.Threading.CancellationToken)
})!;
parameters = resolver.GetParameters();
_args_BookNode_GetAuthorAsync[0] = bindingResolver.GetBinding(parameters[0]);
_args_BookNode_GetAuthorAsync[1] = bindingResolver.GetBinding(parameters[1]);
}
}
}
}

public static HotChocolate.Resolvers.FieldResolverDelegates BookNode_GetAuthorAsync()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,35 +19,41 @@ namespace TestNamespace
{
internal static class TestTypeResolvers
{
private static readonly object _sync = new object();
private static bool _bindingsInitialized;
private readonly static global::HotChocolate.Internal.IParameterBinding[] _args_TestType_GetTest = new global::HotChocolate.Internal.IParameterBinding[1];

public static void InitializeBindings(global::HotChocolate.Internal.IParameterBindingResolver bindingResolver)
{
if (_bindingsInitialized)
if (!_bindingsInitialized)
{
return;
}
_bindingsInitialized = true;

const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::TestNamespace.TestType);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;

resolver = type.GetMethod(
"GetTest",
bindingFlags,
new global::System.Type[]
lock (_sync)
{
typeof(int)
})!;
parameters = resolver.GetParameters();
_args_TestType_GetTest[0] = bindingResolver.GetBinding(parameters[0]);
if (!_bindingsInitialized)
{

const global::System.Reflection.BindingFlags bindingFlags =
global::System.Reflection.BindingFlags.Public
| global::System.Reflection.BindingFlags.NonPublic
| global::System.Reflection.BindingFlags.Static;

var type = typeof(global::TestNamespace.TestType);
global::System.Reflection.MethodInfo resolver = default!;
global::System.Reflection.ParameterInfo[] parameters = default!;
_bindingsInitialized = true;

resolver = type.GetMethod(
"GetTest",
bindingFlags,
new global::System.Type[]
{
typeof(int)
})!;
parameters = resolver.GetParameters();
_args_TestType_GetTest[0] = bindingResolver.GetBinding(parameters[0]);
}
}
}
}

public static HotChocolate.Resolvers.FieldResolverDelegates TestType_GetTest()
Expand Down
Loading

0 comments on commit 63ab9f3

Please # to comment.