Skip to content

Commit

Permalink
Enable basic end-to-end scenarios of extension method invocation. (#7…
Browse files Browse the repository at this point in the history
  • Loading branch information
AlekseyTs authored Feb 25, 2025
1 parent a5f65f6 commit 74d162a
Show file tree
Hide file tree
Showing 33 changed files with 2,369 additions and 390 deletions.
26 changes: 13 additions & 13 deletions src/Compilers/CSharp/Portable/BoundTree/BoundTreeRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,55 +62,55 @@ private ImmutableArray<T> DoVisitList<T>(ImmutableArray<T> list) where T : Bound
}

[return: NotNullIfNotNull(nameof(symbol))]
protected virtual AliasSymbol? VisitAliasSymbol(AliasSymbol? symbol) => symbol;
public virtual AliasSymbol? VisitAliasSymbol(AliasSymbol? symbol) => symbol;

protected virtual DiscardSymbol VisitDiscardSymbol(DiscardSymbol symbol)
public virtual DiscardSymbol VisitDiscardSymbol(DiscardSymbol symbol)
{
Debug.Assert(symbol is not null);
return symbol;
}

protected virtual EventSymbol VisitEventSymbol(EventSymbol symbol)
public virtual EventSymbol VisitEventSymbol(EventSymbol symbol)
{
Debug.Assert(symbol is not null);
return symbol;
}

[return: NotNullIfNotNull(nameof(symbol))]
protected virtual LabelSymbol? VisitLabelSymbol(LabelSymbol? symbol) => symbol;
public virtual LabelSymbol? VisitLabelSymbol(LabelSymbol? symbol) => symbol;

protected virtual LocalSymbol VisitLocalSymbol(LocalSymbol symbol)
public virtual LocalSymbol VisitLocalSymbol(LocalSymbol symbol)
{
Debug.Assert(symbol is not null);
return symbol;
}

protected virtual NamespaceSymbol VisitNamespaceSymbol(NamespaceSymbol symbol)
public virtual NamespaceSymbol VisitNamespaceSymbol(NamespaceSymbol symbol)
{
Debug.Assert(symbol is not null);
return symbol;
}

[return: NotNullIfNotNull(nameof(symbol))]
protected virtual RangeVariableSymbol? VisitRangeVariableSymbol(RangeVariableSymbol? symbol) => symbol;
public virtual RangeVariableSymbol? VisitRangeVariableSymbol(RangeVariableSymbol? symbol) => symbol;

[return: NotNullIfNotNull(nameof(symbol))]
protected virtual FieldSymbol? VisitFieldSymbol(FieldSymbol? symbol) => symbol;
public virtual FieldSymbol? VisitFieldSymbol(FieldSymbol? symbol) => symbol;

protected virtual ParameterSymbol VisitParameterSymbol(ParameterSymbol symbol)
public virtual ParameterSymbol VisitParameterSymbol(ParameterSymbol symbol)
{
Debug.Assert(symbol is not null);
return symbol;
}

[return: NotNullIfNotNull(nameof(symbol))]
protected virtual PropertySymbol? VisitPropertySymbol(PropertySymbol? symbol) => symbol;
public virtual PropertySymbol? VisitPropertySymbol(PropertySymbol? symbol) => symbol;

[return: NotNullIfNotNull(nameof(symbol))]
protected virtual MethodSymbol? VisitMethodSymbol(MethodSymbol? symbol) => symbol;
public virtual MethodSymbol? VisitMethodSymbol(MethodSymbol? symbol) => symbol;

[return: NotNullIfNotNull(nameof(symbol))]
protected Symbol? VisitSymbol(Symbol? symbol)
public Symbol? VisitSymbol(Symbol? symbol)
{
if (symbol is null)
{
Expand Down Expand Up @@ -158,7 +158,7 @@ protected virtual ParameterSymbol VisitParameterSymbol(ParameterSymbol symbol)
return (FunctionTypeSymbol?)VisitType(symbol);
}

protected ImmutableArray<T> VisitSymbols<T>(ImmutableArray<T> symbols) where T : Symbol?
public ImmutableArray<T> VisitSymbols<T>(ImmutableArray<T> symbols) where T : Symbol?
{
if (symbols.IsDefault)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ internal override void GenerateMethodBody(TypeCompilationState compilationState,
}
}

internal sealed partial class SynthesizedSealedPropertyAccessor : SynthesizedInstanceMethodSymbol
internal sealed partial class SynthesizedSealedPropertyAccessor : SynthesizedMethodSymbol
{
internal override bool SynthesizesLoweredBoundBody
{
Expand Down
59 changes: 40 additions & 19 deletions src/Compilers/CSharp/Portable/Compiler/MethodCompiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,12 @@ private void CompileNamedType(NamedTypeSymbol containingType)

if (compilationState.Emitting)
{
if (containingType is { IsExtension: true, ExtensionParameter: not null })
{
// PROTOTYPE the extension marker should also be added in metadata-only emit scenario (see SynthesizedMetadataCompiler)
CompileSynthesizedExtensionMarker(containingType, compilationState);
}

CompileSynthesizedExplicitImplementations(sourceTypeSymbol, compilationState);
}
}
Expand Down Expand Up @@ -661,6 +667,24 @@ private void CompileNamedType(NamedTypeSymbol containingType)
compilationState.Free();
}

private void CompileSynthesizedExtensionMarker(NamedTypeSymbol sourceExtension, TypeCompilationState compilationState)
{
if (!_globalHasErrors)
{
var extensionMarker = new SynthesizedExtensionMarker(sourceExtension, _diagnostics);

#if DEBUG
var discardedDiagnostics = BindingDiagnosticBag.GetInstance(withDiagnostics: true, withDependencies: false);
#else
var discardedDiagnostics = BindingDiagnosticBag.Discarded;
#endif
extensionMarker.GenerateMethodBody(compilationState, discardedDiagnostics);
Debug.Assert(!discardedDiagnostics.HasAnyErrors());

_moduleBeingBuiltOpt.AddSynthesizedDefinition(sourceExtension, extensionMarker.GetCciAdapter());
}
}

internal static MethodSymbol GetMethodToCompile(MethodSymbol method)
{
if (method.IsPartialDefinition())
Expand Down Expand Up @@ -1436,18 +1460,6 @@ private static MethodSymbol GetSymbolForEmittedBody(MethodSymbol methodSymbol)
return methodSymbol.PartialDefinitionPart ?? methodSymbol;
}

internal static SourceExtensionImplementationMethodSymbol TryGetCorrespondingExtensionImplementationMethod(MethodSymbol methodSymbol)
{
if (methodSymbol.ContainingType.IsExtension)
{
return methodSymbol.ContainingType.ContainingType.
GetMembers((methodSymbol.IsStatic ? SourceExtensionImplementationMethodSymbol.StaticExtensionNamePrefix : SourceExtensionImplementationMethodSymbol.InstanceExtensionNamePrefix) + methodSymbol.Name).
OfType<SourceExtensionImplementationMethodSymbol>().Where(m => (object)m.UnderlyingMethod == methodSymbol).SingleOrDefault();
}

return null;
}

// internal for testing
internal static BoundStatement LowerBodyOrInitializer(
MethodSymbol method,
Expand Down Expand Up @@ -1500,6 +1512,22 @@ internal static BoundStatement LowerBodyOrInitializer(
return loweredBody;
}

if (extensionImplementationMethod is not null)
{
var extensionRewriter = new ExtensionMethodBodyRewriter(method, extensionImplementationMethod);
loweredBody = (BoundStatement)extensionRewriter.Visit(loweredBody);
method = extensionImplementationMethod;
}
else
{
loweredBody = ExtensionMethodReferenceRewriter.Rewrite(loweredBody);
}

if (loweredBody.HasErrors)
{
return loweredBody;
}

if (sawAwaitInExceptionHandler)
{
// If we have awaits in handlers, we need to
Expand All @@ -1520,13 +1548,6 @@ internal static BoundStatement LowerBodyOrInitializer(
return loweredBody;
}

if (extensionImplementationMethod is not null)
{
var extensionRewriter = new ExtensionMethodBodyRewriter(method, extensionImplementationMethod);
loweredBody = (BoundStatement)extensionRewriter.Visit(loweredBody);
method = extensionImplementationMethod;
}

lazyVariableSlotAllocator ??= compilationState.ModuleBuilderOpt.TryCreateVariableSlotAllocator(method, method, diagnostics.DiagnosticBag);

BoundStatement bodyWithoutLambdas = loweredBody;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2105,7 +2105,7 @@ public override void AddSynthesizedDefinition(NamedTypeSymbol container, Cci.IFi

public override void AddSynthesizedDefinition(NamedTypeSymbol container, Cci.IMethodDefinition method)
{
Debug.Assert(container is not NamedTypeSymbol { IsExtension: true });
Debug.Assert(container is not NamedTypeSymbol { IsExtension: true } || method.GetInternalSymbol() is SynthesizedExtensionMarker);
base.AddSynthesizedDefinition(container, method);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ protected sealed override ImmutableArray<LocalSymbol> VisitLocals(ImmutableArray
return newLocals.ToImmutableAndFree();
}

protected sealed override LocalSymbol VisitLocalSymbol(LocalSymbol local)
public sealed override LocalSymbol VisitLocalSymbol(LocalSymbol local)
{
if (!TryRewriteLocal(local, out var newLocal))
{
Expand Down Expand Up @@ -148,7 +148,7 @@ public override BoundNode VisitBinaryOperator(BoundBinaryOperator node)
}

[return: NotNullIfNotNull(nameof(method))]
protected override MethodSymbol? VisitMethodSymbol(MethodSymbol? method)
public override MethodSymbol? VisitMethodSymbol(MethodSymbol? method)
{
if (method is null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ public ExtensionMethodBodyRewriter(MethodSymbol sourceMethod, SourceExtensionImp
throw ExceptionUtilities.Unreachable();
}

protected override ParameterSymbol VisitParameterSymbol(ParameterSymbol symbol)
public override ParameterSymbol VisitParameterSymbol(ParameterSymbol symbol)
{
return (ParameterSymbol)_symbolMap[symbol];
}
Expand Down Expand Up @@ -141,7 +141,7 @@ protected override ImmutableArray<MethodSymbol> VisitDeclaredLocalFunctions(Immu
}

[return: NotNullIfNotNull(nameof(symbol))]
protected override MethodSymbol? VisitMethodSymbol(MethodSymbol? symbol)
public override MethodSymbol? VisitMethodSymbol(MethodSymbol? symbol)
{
switch (symbol?.MethodKind)
{
Expand All @@ -162,7 +162,7 @@ protected override ImmutableArray<MethodSymbol> VisitDeclaredLocalFunctions(Immu
}

[return: NotNullIfNotNull(nameof(symbol))]
protected override FieldSymbol? VisitFieldSymbol(FieldSymbol? symbol)
public override FieldSymbol? VisitFieldSymbol(FieldSymbol? symbol)
{
if (symbol is null)
{
Expand All @@ -173,6 +173,11 @@ protected override ImmutableArray<MethodSymbol> VisitDeclaredLocalFunctions(Immu
.AsMember((NamedTypeSymbol)TypeMap.SubstituteType(symbol.ContainingType).AsTypeSymbolOnly());
}

// PROTOTYPE: Handle deep recursion on long chain of binary operators, calls, etc.
public override BoundNode? VisitCall(BoundCall node)
{
return ExtensionMethodReferenceRewriter.VisitCall(this, node);
}

// PROTOTYPE: Handle deep recursion on long chain of binary operators, etc.
}
}
Loading

0 comments on commit 74d162a

Please # to comment.