Skip to content

Commit

Permalink
Fix generic parameter data flow validation in NativeAOT (#80956)
Browse files Browse the repository at this point in the history
Previously generic data flow was done from generic dictionary nodes. Problem with that approach is that there's no origin information at that point. The warnings can't point to the place where the problematic instantiation is in the code - we only know that it exists.
Aside from it being unfriendly for the users, it means any RUC or suppressions don't work on these warnings the same way they do in linker/analyzer.

This change modifies the logic to tag the method as "needs data flow" whenever we spot an instantiation of an annotated generic in it somewhere.
Then the actualy validation/marking is done from data flow using the trim analysis patterns.

The only exception to this is generic data flow for base types and interface implementations, that one is done on the EEType nodes.

Note that AOT implements a much more precise version of the generic data flow validation as compared to linker/analyzer. See the big comment at the bening of GenericParameterWarningLocation.cs for how that works.

Changes the expected warning validation to use tokens to compare message origins (same reason as with Kept validation - consistently converting things to string is hard)

Adds a new dependency node with the generic type definition to the graph and then does analysis on that node.
This is to avoid potential noise warnings which could happen due to multiple instantiations calling the checking code multiple times. With this the check is done only once on the type definition.

Tweaked some tests to try to cover the multiple instances scenario.
  • Loading branch information
vitek-karas authored Jan 25, 2023
1 parent 30cc26f commit aa5e313
Show file tree
Hide file tree
Showing 23 changed files with 2,637 additions and 178 deletions.
5 changes: 4 additions & 1 deletion src/coreclr/tools/Common/Compiler/DisplayNameHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ public static string GetDisplayName(this MethodDesc method)
if (method.Signature.Length > 0)
{
for (int i = 0; i < method.Signature.Length - 1; i++)
sb.Append(method.Signature[i].GetDisplayNameWithoutNamespace()).Append(',');
{
TypeDesc instantiatedType = method.Signature[i].InstantiateSignature(method.OwningType.Instantiation, method.Instantiation);
sb.Append(instantiatedType.GetDisplayNameWithoutNamespace()).Append(',');
}

sb.Append(method.Signature[method.Signature.Length - 1].GetDisplayNameWithoutNamespace());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ public bool RequiresDataflowAnalysis(MethodDesc method)
try
{
method = method.GetTypicalMethodDefinition();
return GetAnnotations(method.OwningType).TryGetAnnotation(method, out var methodAnnotations)
&& (methodAnnotations.ReturnParameterAnnotation != DynamicallyAccessedMemberTypes.None || methodAnnotations.ParameterAnnotations != null);
TypeAnnotations typeAnnotations = GetAnnotations(method.OwningType);
return typeAnnotations.HasGenericParameterAnnotation() || typeAnnotations.TryGetAnnotation(method, out _);
}
catch (TypeSystemException)
{
Expand All @@ -73,7 +73,8 @@ public bool RequiresDataflowAnalysis(FieldDesc field)
try
{
field = field.GetTypicalFieldDefinition();
return GetAnnotations(field.OwningType).TryGetAnnotation(field, out _);
TypeAnnotations typeAnnotations = GetAnnotations(field.OwningType);
return typeAnnotations.HasGenericParameterAnnotation() || typeAnnotations.TryGetAnnotation(field, out _);
}
catch (TypeSystemException)
{
Expand Down Expand Up @@ -105,6 +106,31 @@ public bool HasAnyAnnotations(TypeDesc type)
}
}

public bool HasGenericParameterAnnotation(TypeDesc type)
{
try
{
return GetAnnotations(type.GetTypeDefinition()).HasGenericParameterAnnotation();
}
catch (TypeSystemException)
{
return false;
}
}

public bool HasGenericParameterAnnotation(MethodDesc method)
{
try
{
method = method.GetTypicalMethodDefinition();
return GetAnnotations(method.OwningType).TryGetAnnotation(method, out var annotation) && annotation.GenericParameterAnnotations != null;
}
catch (TypeSystemException)
{
return false;
}
}

internal DynamicallyAccessedMemberTypes GetParameterAnnotation(ParameterProxy param)
{
MethodDesc method = param.Method.Method.GetTypicalMethodDefinition();
Expand Down Expand Up @@ -884,6 +910,8 @@ public bool TryGetAnnotation(GenericParameterDesc genericParameter, out Dynamica

return false;
}

public bool HasGenericParameterAnnotation() => _genericParameterAnnotations != null;
}

private readonly struct MethodAnnotations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,66 @@

namespace ILCompiler.Dataflow
{
public readonly struct GenericArgumentDataFlow
public static class GenericArgumentDataFlow
{
private readonly Logger _logger;
private readonly NodeFactory _factory;
private readonly FlowAnnotations _annotations;
private readonly MessageOrigin _origin;
public static void ProcessGenericArgumentDataFlow(ref DependencyList dependencies, Logger logger, NodeFactory factory, FlowAnnotations annotations, in MessageOrigin origin, TypeDesc type)
{
var diagnosticContext = new DiagnosticContext(
origin,
!logger.ShouldSuppressAnalysisWarningsForRequires(origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute),
logger);
var reflectionMarker = new ReflectionMarker(logger, factory, annotations, typeHierarchyDataFlowOrigin: null, enabled: true);

ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, type);

if (reflectionMarker.Dependencies.Count > 0)
{
if (dependencies == null)
dependencies = reflectionMarker.Dependencies;
else
dependencies.AddRange(reflectionMarker.Dependencies);
}
}

public GenericArgumentDataFlow(Logger logger, NodeFactory factory, FlowAnnotations annotations, in MessageOrigin origin)
public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, TypeDesc type)
{
_logger = logger;
_factory = factory;
_annotations = annotations;
_origin = origin;
TypeDesc typeDefinition = type.GetTypeDefinition();
if (typeDefinition != type)
{
ProcessGenericInstantiation(diagnosticContext, reflectionMarker, type.Instantiation, typeDefinition.Instantiation);
}
}

public DependencyList ProcessGenericArgumentDataFlow(GenericParameterDesc genericParameter, TypeDesc genericArgument)
public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, MethodDesc method)
{
var genericParameterValue = _annotations.GetGenericParameterValue(genericParameter);
Debug.Assert(genericParameterValue.DynamicallyAccessedMemberTypes != DynamicallyAccessedMemberTypes.None);
MethodDesc typicalMethod = method.GetTypicalMethodDefinition();
if (typicalMethod != method)
{
ProcessGenericInstantiation(diagnosticContext, reflectionMarker, method.Instantiation, typicalMethod.Instantiation);
}

MultiValue genericArgumentValue = _annotations.GetTypeValueFromGenericArgument(genericArgument);
ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, method.OwningType);
}

var diagnosticContext = new DiagnosticContext(
_origin,
_logger.ShouldSuppressAnalysisWarningsForRequires(_origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute),
_logger);
return RequireDynamicallyAccessedMembers(diagnosticContext, genericArgumentValue, genericParameterValue, genericParameter.GetDisplayName());
public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, FieldDesc field)
{
ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, field.OwningType);
}

private DependencyList RequireDynamicallyAccessedMembers(
in DiagnosticContext diagnosticContext,
in MultiValue value,
ValueWithDynamicallyAccessedMembers targetValue,
string reason)
private static void ProcessGenericInstantiation(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, Instantiation instantiation, Instantiation typicalInstantiation)
{
var reflectionMarker = new ReflectionMarker(_logger, _factory, _annotations, typeHierarchyDataFlowOrigin: null, enabled: true);
var requireDynamicallyAccessedMembersAction = new RequireDynamicallyAccessedMembersAction(reflectionMarker, diagnosticContext, reason);
requireDynamicallyAccessedMembersAction.Invoke(value, targetValue);
return reflectionMarker.Dependencies;
for (int i = 0; i < instantiation.Length; i++)
{
var genericParameter = (GenericParameterDesc)typicalInstantiation[i];
if (reflectionMarker.Annotations.GetGenericParameterAnnotation(genericParameter) != default)
{
var genericParameterValue = reflectionMarker.Annotations.GetGenericParameterValue(genericParameter);
Debug.Assert(genericParameterValue.DynamicallyAccessedMemberTypes != DynamicallyAccessedMemberTypes.None);
MultiValue genericArgumentValue = reflectionMarker.Annotations.GetTypeValueFromGenericArgument(instantiation[i]);
var requireDynamicallyAccessedMembersAction = new RequireDynamicallyAccessedMembersAction(reflectionMarker, diagnosticContext, genericParameter.GetDisplayName());
requireDynamicallyAccessedMembersAction.Invoke(genericArgumentValue, genericParameterValue);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ protected virtual void Scan(MethodIL methodBody, ref InterproceduralState interp
StackSlot retValue = PopUnknown(currentStack, 1, methodBody, offset);
// If the return value is a reference, treat it as the value itself for now
// We can handle ref return values better later
ReturnValue = MultiValueLattice.Meet(ReturnValue, DereferenceValue(retValue.Value, locals, ref interproceduralState));
ReturnValue = MultiValueLattice.Meet(ReturnValue, DereferenceValue(methodBody, offset, retValue.Value, locals, ref interproceduralState));
ValidateNoReferenceToReference(locals, methodBody, offset);
}
ClearStack(ref currentStack);
Expand Down Expand Up @@ -947,23 +947,24 @@ private void ScanLdtoken(MethodIL methodBody, int offset, object operand, Stack<
var nullableDam = new RuntimeTypeHandleForNullableValueWithDynamicallyAccessedMembers(new TypeProxy(type),
new RuntimeTypeHandleForGenericParameterValue(genericParam));
currentStack.Push(new StackSlot(nullableDam));
return;
break;
case MetadataType underlyingType:
var nullableType = new RuntimeTypeHandleForNullableSystemTypeValue(new TypeProxy(type), new SystemTypeValue(underlyingType));
currentStack.Push(new StackSlot(nullableType));
return;
break;
default:
PushUnknown(currentStack);
return;
break;
}
}
else
{
var typeHandle = new RuntimeTypeHandleValue(new TypeProxy(type));
currentStack.Push(new StackSlot(typeHandle));
return;
}
}

HandleTypeReflectionAccess(methodBody, offset, type);
}
else if (operand is MethodDesc method)
{
Expand Down Expand Up @@ -1026,7 +1027,7 @@ protected void StoreInReference(MultiValue target, MultiValue source, MethodIL m
StoreMethodLocalValue(locals, source, localReference.LocalIndex, curBasicBlock);
break;
case FieldReferenceValue fieldReference
when GetFieldValue(fieldReference.FieldDefinition).AsSingleValue() is FieldValue fieldValue:
when HandleGetField(method, offset, fieldReference.FieldDefinition).AsSingleValue() is FieldValue fieldValue:
HandleStoreField(method, offset, fieldValue, source);
break;
case ParameterReferenceValue parameterReference
Expand All @@ -1038,7 +1039,7 @@ when GetMethodParameterValue(parameterReference.Parameter) is MethodParameterVal
HandleStoreMethodReturnValue(method, offset, methodReturnValue, source);
break;
case FieldValue fieldValue:
HandleStoreField(method, offset, fieldValue, DereferenceValue(source, locals, ref ipState));
HandleStoreField(method, offset, fieldValue, DereferenceValue(method, offset, source, locals, ref ipState));
break;
case IValueWithStaticType valueWithStaticType:
if (valueWithStaticType.StaticType is not null && FlowAnnotations.IsTypeInterestingForDataflow(valueWithStaticType.StaticType))
Expand All @@ -1057,7 +1058,25 @@ when GetMethodParameterValue(parameterReference.Parameter) is MethodParameterVal

}

protected abstract MultiValue GetFieldValue(FieldDesc field);
/// <summary>
/// HandleGetField is called every time the scanner needs to represent a value of the field
/// either as a source or target. It is not called when just a reference to field is created,
/// But if such reference is dereferenced then it will get called.
/// It is NOT called for hoisted locals.
/// </summary>
/// <remarks>
/// There should be no need to perform checks for hoisted locals. All of our reflection checks are based
/// on an assumption that problematic things happen because of running code. Doing things purely in the type system
/// (declaring new types which are never instantiated, declaring fields which are never assigned to, ...)
/// don't cause problems (or better way, they won't show observable behavioral differences).
/// Typically that would mean that accessing fields is also an uninteresting operation, unfortunately
/// static fields access can cause execution of static .cctor and that is running code -> possible problems.
/// So we have to track accesses in that case.
/// Hoisted locals are fields on closure classes/structs which should not have static .ctors, so we don't
/// need to track those. It makes the design a bit cleaner because hoisted locals are purely handled in here
/// and don't leak over to the reflection handling code in any way.
/// </remarks>
protected abstract MultiValue HandleGetField(MethodIL methodBody, int offset, FieldDesc field);

private void ScanLdfld(
MethodIL methodBody,
Expand All @@ -1083,7 +1102,7 @@ private void ScanLdfld(
}
else
{
value = GetFieldValue(field);
value = HandleGetField(methodBody, offset, field);
}
currentStack.Push(new StackSlot(value));
}
Expand Down Expand Up @@ -1119,15 +1138,15 @@ private void ScanStfld(
return;
}

foreach (var value in GetFieldValue(field))
foreach (var value in HandleGetField(methodBody, offset, field))
{
// GetFieldValue may return different node types, in which case they can't be stored to.
// At least not yet.
if (value is not FieldValue fieldValue)
continue;

// Incomplete handling of ref fields -- if we're storing a reference to a value, pretend it's just the value
MultiValue valueToStore = DereferenceValue(valueToStoreSlot.Value, locals, ref interproceduralState);
MultiValue valueToStore = DereferenceValue(methodBody, offset, valueToStoreSlot.Value, locals, ref interproceduralState);

HandleStoreField(methodBody, offset, fieldValue, valueToStore);
}
Expand Down Expand Up @@ -1163,7 +1182,12 @@ private ValueNodeList PopCallArguments(
return methodParams;
}

internal MultiValue DereferenceValue(MultiValue maybeReferenceValue, ValueBasicBlockPair?[] locals, ref InterproceduralState interproceduralState)
internal MultiValue DereferenceValue(
MethodIL methodBody,
int offset,
MultiValue maybeReferenceValue,
ValueBasicBlockPair?[] locals,
ref InterproceduralState interproceduralState)
{
MultiValue dereferencedValue = MultiValueLattice.Top;
foreach (var value in maybeReferenceValue)
Expand All @@ -1175,7 +1199,7 @@ internal MultiValue DereferenceValue(MultiValue maybeReferenceValue, ValueBasicB
dereferencedValue,
CompilerGeneratedState.IsHoistedLocal(fieldReferenceValue.FieldDefinition)
? interproceduralState.GetHoistedLocal(new HoistedLocalKey(fieldReferenceValue.FieldDefinition))
: GetFieldValue(fieldReferenceValue.FieldDefinition));
: HandleGetField(methodBody, offset, fieldReferenceValue.FieldDefinition));
break;
case ParameterReferenceValue parameterReferenceValue:
dereferencedValue = MultiValue.Meet(
Expand Down Expand Up @@ -1224,6 +1248,11 @@ protected void AssignRefAndOutParameters(
}
}

/// <summary>
/// Called when type is accessed directly (basically only ldtoken)
/// </summary>
protected abstract void HandleTypeReflectionAccess(MethodIL methodBody, int offset, TypeDesc accessedType);

/// <summary>
/// Called to handle reflection access to a method without any other specifics (ldtoken or ldftn for example)
/// </summary>
Expand Down Expand Up @@ -1260,7 +1289,7 @@ private void HandleCall(

var dereferencedMethodParams = new List<MultiValue>();
foreach (var argument in methodArguments)
dereferencedMethodParams.Add(DereferenceValue(argument, locals, ref interproceduralState));
dereferencedMethodParams.Add(DereferenceValue(callingMethodBody, offset, argument, locals, ref interproceduralState));
MultiValue methodReturnValue;
bool handledFunction = HandleCall(
callingMethodBody,
Expand Down
Loading

0 comments on commit aa5e313

Please # to comment.