Skip to content

Commit

Permalink
CA2251: support for string.CompareOrdinal
Browse files Browse the repository at this point in the history
  • Loading branch information
allantargino committed Mar 1, 2025
1 parent 345816f commit 459168a
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ private static ImmutableArray<OperationReplacer> GetOperationReplacers(RequiredS
return ImmutableArray.Create<OperationReplacer>(
new StringStringCaseReplacer(symbols),
new StringStringBoolReplacer(symbols),
new StringStringStringComparisonReplacer(symbols));
new StringStringStringComparisonReplacer(symbols),
new OrdinalStringStringCaseReplacer(symbols));
}

/// <summary>
Expand All @@ -86,15 +87,15 @@ protected OperationReplacer(RequiredSymbols symbols)
/// <summary>
/// Indicates whether the current <see cref="OperationReplacer"/> applies to the specified violation.
/// </summary>
/// <param name="violation">The <see cref="IBinaryOperation"/> at the location reported by the analyzer.</param>
/// <param name="violation">The <see cref="IBinaryOperation"/> or <see cref="IInvocationOperation"/> at the location reported by the analyzer.</param>
/// <returns>True if the current <see cref="OperationReplacer"/> applies to the specified violation.</returns>
public abstract bool IsMatch(IOperation violation);

/// <summary>
/// Creates a replacement node for a violation that the current <see cref="OperationReplacer"/> applies to.
/// Asserts if the current <see cref="OperationReplacer"/> does not apply to the specified violation.
/// </summary>
/// <param name="violation">The <see cref="IBinaryOperation"/> obtained at the location reported by the analyzer.
/// <param name="violation">The <see cref="IBinaryOperation"/> or <see cref="IInvocationOperation"/> obtained at the location reported by the analyzer.
/// <see cref="IsMatch(IOperation)"/> must return <see langword="true"/> for this operation.</param>
/// <param name="generator"></param>
/// <returns></returns>
Expand Down Expand Up @@ -229,5 +230,29 @@ public override SyntaxNode CreateReplacementExpression(IOperation violation, Syn
return InvertIfNotEquals(equalsInvocationSyntax, violation, generator);
}
}

/// <summary>
/// Replaces <see cref="string.CompareOrdinal(string, string)"/> violations.
/// </summary>
private sealed class OrdinalStringStringCaseReplacer : OperationReplacer
{
public OrdinalStringStringCaseReplacer(RequiredSymbols symbols)
: base(symbols)
{ }

public override bool IsMatch(IOperation violation) => UseStringEqualsOverStringCompare.IsOrdinalStringStringCase(violation, Symbols);

public override SyntaxNode CreateReplacementExpression(IOperation violation, SyntaxGenerator generator)
{
RoslynDebug.Assert(IsMatch(violation));

var compareInvocation = GetInvocation(violation);
var equalsInvocationSyntax = generator.InvocationExpression(
CreateEqualsMemberAccess(generator),
compareInvocation.Arguments.GetArgumentsInParameterOrder().Select(x => x.Value.Syntax));

return InvertIfNotEquals(equalsInvocationSyntax, violation, generator);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ private RequiredSymbols(
IMethodSymbol? compareStringString,
IMethodSymbol? compareStringStringBool,
IMethodSymbol? compareStringStringStringComparison,
IMethodSymbol? compareOrdinalStringString,
IMethodSymbol? equalsStringString,
IMethodSymbol? equalsStringStringStringComparison,
IMethodSymbol intEquals)
Expand All @@ -89,6 +90,7 @@ private RequiredSymbols(
CompareStringString = compareStringString;
CompareStringStringBool = compareStringStringBool;
CompareStringStringStringComparison = compareStringStringStringComparison;
CompareOrdinalStringString = compareOrdinalStringString;
EqualsStringString = equalsStringString;
EqualsStringStringStringComparison = equalsStringStringStringComparison;
IntEquals = intEquals;
Expand All @@ -115,12 +117,17 @@ public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] ou
var compareStringStringBool = compareMethods.GetFirstOrDefaultMemberWithParameterTypes(stringType, stringType, boolType);
var compareStringStringStringComparison = compareMethods.GetFirstOrDefaultMemberWithParameterTypes(stringType, stringType, stringComparisonType);

var compareOrdinalMethods = stringType.GetMembers(nameof(string.CompareOrdinal))
.OfType<IMethodSymbol>()
.Where(x => x.IsStatic);
var compareOrdinalStringString = compareOrdinalMethods.GetFirstOrDefaultMemberWithParameterTypes(stringType, stringType);

var equalsMethods = stringType.GetMembers(nameof(string.Equals))
.OfType<IMethodSymbol>()
.Where(x => x.IsStatic);
var equalsStringString = equalsMethods.GetFirstOrDefaultMemberWithParameterTypes(stringType, stringType);
var equalsStringStringStringComparison = equalsMethods.GetFirstOrDefaultMemberWithParameterTypes(stringType, stringType, stringComparisonType);
var intType = typeProvider.GetOrCreateTypeByMetadataName(WellKnownTypeNames.SystemInt32);
var intType = compilation.GetSpecialType(SpecialType.System_Int32);
var intEquals = intType
?.GetMembers(nameof(int.Equals))
.OfType<IMethodSymbol>()
Expand All @@ -133,14 +140,15 @@ public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] ou
// Bail if we do not have at least one complete pair of Compare-Equals methods in the compilation.
if ((compareStringString is null || equalsStringString is null) &&
(compareStringStringBool is null || equalsStringStringStringComparison is null) &&
(compareStringStringStringComparison is null || equalsStringStringStringComparison is null))
(compareStringStringStringComparison is null || equalsStringStringStringComparison is null) &&
(compareOrdinalStringString is null || equalsStringString is null))
{
return false;
}

symbols = new RequiredSymbols(
stringType, boolType, stringComparisonType,
compareStringString, compareStringStringBool, compareStringStringStringComparison,
compareStringString, compareStringStringBool, compareStringStringStringComparison, compareOrdinalStringString,
equalsStringString, equalsStringStringStringComparison, intEquals);
return true;
}
Expand All @@ -151,6 +159,7 @@ public static bool TryGetSymbols(Compilation compilation, [NotNullWhen(true)] ou
public IMethodSymbol? CompareStringString { get; }
public IMethodSymbol? CompareStringStringBool { get; }
public IMethodSymbol? CompareStringStringStringComparison { get; }
public IMethodSymbol? CompareOrdinalStringString { get; }
public IMethodSymbol? EqualsStringString { get; }
public IMethodSymbol? EqualsStringStringStringComparison { get; }
public IMethodSymbol IntEquals { get; }
Expand Down Expand Up @@ -291,10 +300,39 @@ internal static bool IsStringStringStringComparisonCase(IOperation operation, Re
invocation.TargetMethod.Equals(symbols.CompareStringStringStringComparison, SymbolEqualityComparer.Default);
}

/// <summary>
/// Returns true if the specified <see cref="IBinaryOperation"/> or <see cref="IInvocationOperation"/>:
/// <list type="bullet">
/// <item>Is an equals or not-equals operation</item>
/// <item>One operand is a literal zero</item>
/// <item>The other operand is any invocation of <see cref="string.CompareOrdinal(string, string)"/></item>
/// </list>
/// </summary>
/// <param name="operation"></param>
/// <param name="symbols"></param>
/// <returns></returns>
internal static bool IsOrdinalStringStringCase(IOperation operation, RequiredSymbols symbols)
{
// Don't report a diagnostic if either the string.CompareOrdinal overload or the
// corresponding string.Equals overload is missing.
if (symbols.CompareOrdinalStringString is null ||
symbols.EqualsStringString is null)
{
return false;
}

var invocation = GetInvocationFromEqualityCheckWithLiteralZero(operation as IBinaryOperation)
?? GetInvocationFromEqualsCheckWithLiteralZero(operation as IInvocationOperation, symbols.IntEquals);

return invocation is not null &&
invocation.TargetMethod.Equals(symbols.CompareOrdinalStringString, SymbolEqualityComparer.Default);
}

private static readonly ImmutableArray<Func<IOperation, RequiredSymbols, bool>> CaseSelectors =
ImmutableArray.Create<Func<IOperation, RequiredSymbols, bool>>(
IsStringStringCase,
IsStringStringBoolCase,
IsStringStringStringComparisonCase);
IsStringStringStringComparisonCase,
IsOrdinalStringStringCase);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ private static readonly (string CompareCall, string EqualsCall)[] CS_ComparisonE
("string.Compare(x, y, StringComparison.CurrentCulture)", "string.Equals(x, y, StringComparison.CurrentCulture)"),
("string.Compare(x, y, StringComparison.Ordinal)", "string.Equals(x, y, StringComparison.Ordinal)"),
("string.Compare(x, y, StringComparison.OrdinalIgnoreCase)", "string.Equals(x, y, StringComparison.OrdinalIgnoreCase)"),
("string.CompareOrdinal(x, y)", "string.Equals(x, y)"),
};

private static readonly (string CompareCall, string EqualsCall)[] VB_ComparisonEqualityMethodPairs = new[]
Expand All @@ -39,6 +40,7 @@ private static readonly (string CompareCall, string EqualsCall)[] VB_ComparisonE
("String.Compare(x, y, StringComparison.CurrentCulture)", "String.Equals(x, y, StringComparison.CurrentCulture)"),
("String.Compare(x, y, StringComparison.Ordinal)", "String.Equals(x, y, StringComparison.Ordinal)"),
("String.Compare(x, y, StringComparison.OrdinalIgnoreCase)", "String.Equals(x, y, StringComparison.OrdinalIgnoreCase)"),
("String.CompareOrdinal(x, y)", "String.Equals(x, y)"),
};

public static IEnumerable<object[]> CS_ComparisonLeftOfLiteralTestData { get; } = CS_ComparisonEqualityMethodCallPairs
Expand Down Expand Up @@ -83,6 +85,7 @@ public static IEnumerable<object[]> CS_IneligibleStringCompareOverloadTestData
{
yield return new[] { "string.Compare(x, y, true, System.Globalization.CultureInfo.InvariantCulture)" };
yield return new[] { "string.Compare(x, y, System.Globalization.CultureInfo.InvariantCulture, System.Globalization.CompareOptions.None)" };
yield return new[] { "string.CompareOrdinal(x, indexA: 0, y, indexB: 0, length: 0)" };
}
}

Expand All @@ -92,6 +95,7 @@ public static IEnumerable<object[]> VB_IneligibleStringCompareOverloadTestData
{
yield return new[] { "String.Compare(x, y, true, System.Globalization.CultureInfo.InvariantCulture)" };
yield return new[] { "String.Compare(x, y, System.Globalization.CultureInfo.InvariantCulture, System.Globalization.CompareOptions.None)" };
yield return new[] { "String.CompareOrdinal(x, indexA:= 0, y, indexB:= 0, length:= 0)" };
}
}

Expand Down

0 comments on commit 459168a

Please # to comment.