Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Fix generic parameter data flow validation in NativeAOT #81532

Merged
merged 2 commits into from
Feb 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/coreclr/tools/Common/Compiler/DisplayNameHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,10 @@ public static string GetDisplayName(this MethodDesc method)
if (method.Signature.Length > 0)
{
for (int i = 0; i < method.Signature.Length - 1; i++)
sb.Append(method.Signature[i].GetDisplayNameWithoutNamespace()).Append(',');
{
TypeDesc instantiatedType = method.Signature[i].InstantiateSignature(method.OwningType.Instantiation, method.Instantiation);
sb.Append(instantiatedType.GetDisplayNameWithoutNamespace()).Append(',');
}

sb.Append(method.Signature[method.Signature.Length - 1].GetDisplayNameWithoutNamespace());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,8 @@ public bool RequiresDataflowAnalysis(MethodDesc method)
try
{
method = method.GetTypicalMethodDefinition();
return GetAnnotations(method.OwningType).TryGetAnnotation(method, out var methodAnnotations)
&& (methodAnnotations.ReturnParameterAnnotation != DynamicallyAccessedMemberTypes.None || methodAnnotations.ParameterAnnotations != null);
TypeAnnotations typeAnnotations = GetAnnotations(method.OwningType);
return typeAnnotations.HasGenericParameterAnnotation() || typeAnnotations.TryGetAnnotation(method, out _);
}
catch (TypeSystemException)
{
Expand All @@ -73,7 +73,8 @@ public bool RequiresDataflowAnalysis(FieldDesc field)
try
{
field = field.GetTypicalFieldDefinition();
return GetAnnotations(field.OwningType).TryGetAnnotation(field, out _);
TypeAnnotations typeAnnotations = GetAnnotations(field.OwningType);
return typeAnnotations.HasGenericParameterAnnotation() || typeAnnotations.TryGetAnnotation(field, out _);
}
catch (TypeSystemException)
{
Expand Down Expand Up @@ -105,6 +106,31 @@ public bool HasAnyAnnotations(TypeDesc type)
}
}

public bool HasGenericParameterAnnotation(TypeDesc type)
{
try
{
return GetAnnotations(type.GetTypeDefinition()).HasGenericParameterAnnotation();
}
catch (TypeSystemException)
{
return false;
}
}

public bool HasGenericParameterAnnotation(MethodDesc method)
{
try
{
method = method.GetTypicalMethodDefinition();
return GetAnnotations(method.OwningType).TryGetAnnotation(method, out var annotation) && annotation.GenericParameterAnnotations != null;
}
catch (TypeSystemException)
{
return false;
}
}

internal DynamicallyAccessedMemberTypes GetParameterAnnotation(ParameterProxy param)
{
MethodDesc method = param.Method.Method.GetTypicalMethodDefinition();
Expand Down Expand Up @@ -884,6 +910,8 @@ public bool TryGetAnnotation(GenericParameterDesc genericParameter, out Dynamica

return false;
}

public bool HasGenericParameterAnnotation() => _genericParameterAnnotations != null;
}

private readonly struct MethodAnnotations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,45 +18,83 @@

namespace ILCompiler.Dataflow
{
public readonly struct GenericArgumentDataFlow
public static class GenericArgumentDataFlow
{
private readonly Logger _logger;
private readonly NodeFactory _factory;
private readonly FlowAnnotations _annotations;
private readonly MessageOrigin _origin;
public static void ProcessGenericArgumentDataFlow(ref DependencyList dependencies, NodeFactory factory, in MessageOrigin origin, TypeDesc type, TypeDesc contextType)
{
ProcessGenericArgumentDataFlow(ref dependencies, factory, origin, type, contextType.Instantiation, Instantiation.Empty);
}

public GenericArgumentDataFlow(Logger logger, NodeFactory factory, FlowAnnotations annotations, in MessageOrigin origin)
public static void ProcessGenericArgumentDataFlow(ref DependencyList dependencies, NodeFactory factory, in MessageOrigin origin, TypeDesc type, MethodDesc contextMethod)
{
_logger = logger;
_factory = factory;
_annotations = annotations;
_origin = origin;
ProcessGenericArgumentDataFlow(ref dependencies, factory, origin, type, contextMethod.OwningType.Instantiation, contextMethod.Instantiation);
}

public DependencyList ProcessGenericArgumentDataFlow(GenericParameterDesc genericParameter, TypeDesc genericArgument)
private static void ProcessGenericArgumentDataFlow(ref DependencyList dependencies, NodeFactory factory, in MessageOrigin origin, TypeDesc type, Instantiation typeContext, Instantiation methodContext)
{
var genericParameterValue = _annotations.GetGenericParameterValue(genericParameter);
Debug.Assert(genericParameterValue.DynamicallyAccessedMemberTypes != DynamicallyAccessedMemberTypes.None);
if (!type.HasInstantiation)
return;

MultiValue genericArgumentValue = _annotations.GetTypeValueFromGenericArgument(genericArgument);
TypeDesc instantiatedType = type.InstantiateSignature(typeContext, methodContext);

var mdManager = (UsageBasedMetadataManager)factory.MetadataManager;

var diagnosticContext = new DiagnosticContext(
_origin,
_logger.ShouldSuppressAnalysisWarningsForRequires(_origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute),
_logger);
return RequireDynamicallyAccessedMembers(diagnosticContext, genericArgumentValue, genericParameterValue, genericParameter.GetDisplayName());
origin,
!mdManager.Logger.ShouldSuppressAnalysisWarningsForRequires(origin.MemberDefinition, DiagnosticUtilities.RequiresUnreferencedCodeAttribute),
mdManager.Logger);
var reflectionMarker = new ReflectionMarker(mdManager.Logger, factory, mdManager.FlowAnnotations, typeHierarchyDataFlowOrigin: null, enabled: true);

ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, instantiatedType);

if (reflectionMarker.Dependencies.Count > 0)
{
if (dependencies == null)
dependencies = reflectionMarker.Dependencies;
else
dependencies.AddRange(reflectionMarker.Dependencies);
}
}

public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, TypeDesc type)
{
TypeDesc typeDefinition = type.GetTypeDefinition();
if (typeDefinition != type)
{
ProcessGenericInstantiation(diagnosticContext, reflectionMarker, type.Instantiation, typeDefinition.Instantiation);
}
}

public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, MethodDesc method)
{
MethodDesc typicalMethod = method.GetTypicalMethodDefinition();
if (typicalMethod != method)
{
ProcessGenericInstantiation(diagnosticContext, reflectionMarker, method.Instantiation, typicalMethod.Instantiation);
}

ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, method.OwningType);
}

public static void ProcessGenericArgumentDataFlow(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, FieldDesc field)
{
ProcessGenericArgumentDataFlow(diagnosticContext, reflectionMarker, field.OwningType);
}

private DependencyList RequireDynamicallyAccessedMembers(
in DiagnosticContext diagnosticContext,
in MultiValue value,
ValueWithDynamicallyAccessedMembers targetValue,
string reason)
private static void ProcessGenericInstantiation(in DiagnosticContext diagnosticContext, ReflectionMarker reflectionMarker, Instantiation instantiation, Instantiation typicalInstantiation)
{
var reflectionMarker = new ReflectionMarker(_logger, _factory, _annotations, typeHierarchyDataFlowOrigin: null, enabled: true);
var requireDynamicallyAccessedMembersAction = new RequireDynamicallyAccessedMembersAction(reflectionMarker, diagnosticContext, reason);
requireDynamicallyAccessedMembersAction.Invoke(value, targetValue);
return reflectionMarker.Dependencies;
for (int i = 0; i < instantiation.Length; i++)
{
var genericParameter = (GenericParameterDesc)typicalInstantiation[i];
if (reflectionMarker.Annotations.GetGenericParameterAnnotation(genericParameter) != default)
{
var genericParameterValue = reflectionMarker.Annotations.GetGenericParameterValue(genericParameter);
Debug.Assert(genericParameterValue.DynamicallyAccessedMemberTypes != DynamicallyAccessedMemberTypes.None);
MultiValue genericArgumentValue = reflectionMarker.Annotations.GetTypeValueFromGenericArgument(instantiation[i]);
var requireDynamicallyAccessedMembersAction = new RequireDynamicallyAccessedMembersAction(reflectionMarker, diagnosticContext, genericParameter.GetDisplayName());
requireDynamicallyAccessedMembersAction.Invoke(genericArgumentValue, genericParameterValue);
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -800,7 +800,7 @@ protected virtual void Scan(MethodIL methodBody, ref InterproceduralState interp
StackSlot retValue = PopUnknown(currentStack, 1, methodBody, offset);
// If the return value is a reference, treat it as the value itself for now
// We can handle ref return values better later
ReturnValue = MultiValueLattice.Meet(ReturnValue, DereferenceValue(retValue.Value, locals, ref interproceduralState));
ReturnValue = MultiValueLattice.Meet(ReturnValue, DereferenceValue(methodBody, offset, retValue.Value, locals, ref interproceduralState));
ValidateNoReferenceToReference(locals, methodBody, offset);
}
ClearStack(ref currentStack);
Expand Down Expand Up @@ -947,23 +947,24 @@ private void ScanLdtoken(MethodIL methodBody, int offset, object operand, Stack<
var nullableDam = new RuntimeTypeHandleForNullableValueWithDynamicallyAccessedMembers(new TypeProxy(type),
new RuntimeTypeHandleForGenericParameterValue(genericParam));
currentStack.Push(new StackSlot(nullableDam));
return;
break;
case MetadataType underlyingType:
var nullableType = new RuntimeTypeHandleForNullableSystemTypeValue(new TypeProxy(type), new SystemTypeValue(underlyingType));
currentStack.Push(new StackSlot(nullableType));
return;
break;
default:
PushUnknown(currentStack);
return;
break;
}
}
else
{
var typeHandle = new RuntimeTypeHandleValue(new TypeProxy(type));
currentStack.Push(new StackSlot(typeHandle));
return;
}
}

HandleTypeReflectionAccess(methodBody, offset, type);
}
else if (operand is MethodDesc method)
{
Expand Down Expand Up @@ -1026,7 +1027,7 @@ protected void StoreInReference(MultiValue target, MultiValue source, MethodIL m
StoreMethodLocalValue(locals, source, localReference.LocalIndex, curBasicBlock);
break;
case FieldReferenceValue fieldReference
when GetFieldValue(fieldReference.FieldDefinition).AsSingleValue() is FieldValue fieldValue:
when HandleGetField(method, offset, fieldReference.FieldDefinition).AsSingleValue() is FieldValue fieldValue:
HandleStoreField(method, offset, fieldValue, source);
break;
case ParameterReferenceValue parameterReference
Expand All @@ -1038,7 +1039,7 @@ when GetMethodParameterValue(parameterReference.Parameter) is MethodParameterVal
HandleStoreMethodReturnValue(method, offset, methodReturnValue, source);
break;
case FieldValue fieldValue:
HandleStoreField(method, offset, fieldValue, DereferenceValue(source, locals, ref ipState));
HandleStoreField(method, offset, fieldValue, DereferenceValue(method, offset, source, locals, ref ipState));
break;
case IValueWithStaticType valueWithStaticType:
if (valueWithStaticType.StaticType is not null && FlowAnnotations.IsTypeInterestingForDataflow(valueWithStaticType.StaticType))
Expand All @@ -1057,7 +1058,25 @@ when GetMethodParameterValue(parameterReference.Parameter) is MethodParameterVal

}

protected abstract MultiValue GetFieldValue(FieldDesc field);
/// <summary>
/// HandleGetField is called every time the scanner needs to represent a value of the field
/// either as a source or target. It is not called when just a reference to field is created,
/// But if such reference is dereferenced then it will get called.
/// It is NOT called for hoisted locals.
/// </summary>
/// <remarks>
/// There should be no need to perform checks for hoisted locals. All of our reflection checks are based
/// on an assumption that problematic things happen because of running code. Doing things purely in the type system
/// (declaring new types which are never instantiated, declaring fields which are never assigned to, ...)
/// don't cause problems (or better way, they won't show observable behavioral differences).
/// Typically that would mean that accessing fields is also an uninteresting operation, unfortunately
/// static fields access can cause execution of static .cctor and that is running code -> possible problems.
/// So we have to track accesses in that case.
/// Hoisted locals are fields on closure classes/structs which should not have static .ctors, so we don't
/// need to track those. It makes the design a bit cleaner because hoisted locals are purely handled in here
/// and don't leak over to the reflection handling code in any way.
/// </remarks>
protected abstract MultiValue HandleGetField(MethodIL methodBody, int offset, FieldDesc field);

private void ScanLdfld(
MethodIL methodBody,
Expand All @@ -1083,7 +1102,7 @@ private void ScanLdfld(
}
else
{
value = GetFieldValue(field);
value = HandleGetField(methodBody, offset, field);
}
currentStack.Push(new StackSlot(value));
}
Expand Down Expand Up @@ -1119,15 +1138,15 @@ private void ScanStfld(
return;
}

foreach (var value in GetFieldValue(field))
foreach (var value in HandleGetField(methodBody, offset, field))
{
// GetFieldValue may return different node types, in which case they can't be stored to.
// At least not yet.
if (value is not FieldValue fieldValue)
continue;

// Incomplete handling of ref fields -- if we're storing a reference to a value, pretend it's just the value
MultiValue valueToStore = DereferenceValue(valueToStoreSlot.Value, locals, ref interproceduralState);
MultiValue valueToStore = DereferenceValue(methodBody, offset, valueToStoreSlot.Value, locals, ref interproceduralState);

HandleStoreField(methodBody, offset, fieldValue, valueToStore);
}
Expand Down Expand Up @@ -1163,7 +1182,12 @@ private ValueNodeList PopCallArguments(
return methodParams;
}

internal MultiValue DereferenceValue(MultiValue maybeReferenceValue, ValueBasicBlockPair?[] locals, ref InterproceduralState interproceduralState)
internal MultiValue DereferenceValue(
MethodIL methodBody,
int offset,
MultiValue maybeReferenceValue,
ValueBasicBlockPair?[] locals,
ref InterproceduralState interproceduralState)
{
MultiValue dereferencedValue = MultiValueLattice.Top;
foreach (var value in maybeReferenceValue)
Expand All @@ -1175,7 +1199,7 @@ internal MultiValue DereferenceValue(MultiValue maybeReferenceValue, ValueBasicB
dereferencedValue,
CompilerGeneratedState.IsHoistedLocal(fieldReferenceValue.FieldDefinition)
? interproceduralState.GetHoistedLocal(new HoistedLocalKey(fieldReferenceValue.FieldDefinition))
: GetFieldValue(fieldReferenceValue.FieldDefinition));
: HandleGetField(methodBody, offset, fieldReferenceValue.FieldDefinition));
break;
case ParameterReferenceValue parameterReferenceValue:
dereferencedValue = MultiValue.Meet(
Expand Down Expand Up @@ -1224,6 +1248,11 @@ protected void AssignRefAndOutParameters(
}
}

/// <summary>
/// Called when type is accessed directly (basically only ldtoken)
/// </summary>
protected abstract void HandleTypeReflectionAccess(MethodIL methodBody, int offset, TypeDesc accessedType);

/// <summary>
/// Called to handle reflection access to a method without any other specifics (ldtoken or ldftn for example)
/// </summary>
Expand Down Expand Up @@ -1260,7 +1289,7 @@ private void HandleCall(

var dereferencedMethodParams = new List<MultiValue>();
foreach (var argument in methodArguments)
dereferencedMethodParams.Add(DereferenceValue(argument, locals, ref interproceduralState));
dereferencedMethodParams.Add(DereferenceValue(callingMethodBody, offset, argument, locals, ref interproceduralState));
MultiValue methodReturnValue;
bool handledFunction = HandleCall(
callingMethodBody,
Expand Down
Loading