Skip to content

Query: Key comparison should use object.Equals internally in query #21742

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
1 commit merged into from
Jul 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
&& right != null)
{
return left.Type.UnwrapNullableType() == right.Type.UnwrapNullableType()
? (SqlExpression)_sqlExpressionFactory.Equal(left, right)
: _sqlExpressionFactory.Constant(false);
|| (right.Type == typeof(object) && right is SqlParameterExpression)
|| (left.Type == typeof(object) && left is SqlParameterExpression)
? _sqlExpressionFactory.Equal(left, right)
: (SqlExpression)_sqlExpressionFactory.Constant(false);
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
&& binaryExpression.Left is NewArrayExpression
&& binaryExpression.NodeType == ExpressionType.Equal)
{
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression));
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right));
}

var newLeft = Visit(binaryExpression.Left);
Expand Down Expand Up @@ -557,6 +557,13 @@ MethodInfo GetMethod()
&& methodCallExpression.Object == null
&& methodCallExpression.Arguments.Count == 2)
{
if (methodCallExpression.Arguments[0].Type == typeof(object[])
&& methodCallExpression.Arguments[0] is NewArrayExpression)
{
return Visit(ConvertObjectArrayEqualityComparison(
methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]));
}

var left = Visit(methodCallExpression.Arguments[0]);
var right = Visit(methodCallExpression.Arguments[1]);

Expand Down Expand Up @@ -1262,10 +1269,10 @@ private static bool CanEvaluate(Expression expression)
}
}

private static Expression ConvertObjectArrayEqualityComparison(BinaryExpression binaryExpression)
private static Expression ConvertObjectArrayEqualityComparison(Expression left, Expression right)
{
var leftExpressions = ((NewArrayExpression)binaryExpression.Left).Expressions;
var rightExpressions = ((NewArrayExpression)binaryExpression.Right).Expressions;
var leftExpressions = ((NewArrayExpression)left).Expressions;
var rightExpressions = ((NewArrayExpression)right).Expressions;

return leftExpressions.Zip(
rightExpressions,
Expand Down
11 changes: 5 additions & 6 deletions src/EFCore.Relational/Query/Internal/EqualsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,11 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
if (left != null
&& right != null)
{
if (left.Type == right.Type)
{
return _sqlExpressionFactory.Equal(left, right);
}

return _sqlExpressionFactory.Constant(false);
return left.Type == right.Type
|| (right.Type == typeof(object) && right is SqlParameterExpression)
|| (left.Type == typeof(object) && left is SqlParameterExpression)
? _sqlExpressionFactory.Equal(left, right)
: (SqlExpression)_sqlExpressionFactory.Constant(false);
}

return null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
&& binaryExpression.Left is NewArrayExpression
&& binaryExpression.NodeType == ExpressionType.Equal)
{
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression));
return Visit(ConvertObjectArrayEqualityComparison(binaryExpression.Left, binaryExpression.Right));
}

var left = TryRemoveImplicitConvert(binaryExpression.Left);
Expand Down Expand Up @@ -624,6 +624,13 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method)
&& methodCallExpression.Object == null
&& methodCallExpression.Arguments.Count == 2)
{
if (methodCallExpression.Arguments[0].Type == typeof(object[])
&& methodCallExpression.Arguments[0] is NewArrayExpression)
{
return Visit(ConvertObjectArrayEqualityComparison(
methodCallExpression.Arguments[0], methodCallExpression.Arguments[1]));
}

var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0]));
var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[1]));

Expand Down Expand Up @@ -1000,10 +1007,10 @@ private static Expression RemoveObjectConvert(Expression expression)
? unaryExpression.Operand
: expression;

private static Expression ConvertObjectArrayEqualityComparison(BinaryExpression binaryExpression)
private static Expression ConvertObjectArrayEqualityComparison(Expression left, Expression right)
{
var leftExpressions = ((NewArrayExpression)binaryExpression.Left).Expressions;
var rightExpressions = ((NewArrayExpression)binaryExpression.Right).Expressions;
var leftExpressions = ((NewArrayExpression)left).Expressions;
var rightExpressions = ((NewArrayExpression)right).Expressions;

return leftExpressions.Zip(
rightExpressions,
Expand Down
47 changes: 33 additions & 14 deletions src/EFCore/Internal/EntityFinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ namespace Microsoft.EntityFrameworkCore.Internal
public class EntityFinder<TEntity> : IEntityFinder<TEntity>
where TEntity : class
{
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

private readonly IStateManager _stateManager;
private readonly IDbSetSource _setSource;
private readonly IDbSetCache _setCache;
Expand Down Expand Up @@ -354,34 +357,50 @@ private static IQueryable<TResult> Select<TSource, TResult>(
parameter));
}

private static BinaryExpression BuildPredicate(
private static Expression BuildPredicate(
IReadOnlyList<IProperty> keyProperties,
ValueBuffer keyValues,
ParameterExpression entityParameter)
{
var keyValuesConstant = Expression.Constant(keyValues);

var predicate = GenerateEqualExpression(keyProperties[0], 0);
var predicate = GenerateEqualExpression(entityParameter, keyValuesConstant, keyProperties[0], 0);

for (var i = 1; i < keyProperties.Count; i++)
{
predicate = Expression.AndAlso(predicate, GenerateEqualExpression(keyProperties[i], i));
predicate = Expression.AndAlso(predicate, GenerateEqualExpression(entityParameter, keyValuesConstant, keyProperties[i], i));
}

return predicate;

BinaryExpression GenerateEqualExpression(IProperty property, int i) =>
Expression.Equal(
Expression.Call(
EF.PropertyMethod.MakeGenericMethod(property.ClrType),
entityParameter,
Expression.Constant(property.Name, typeof(string))),
Expression.Convert(
static Expression GenerateEqualExpression(
Expression entityParameterExpression, Expression keyValuesConstantExpression, IProperty property, int i)
=> property.ClrType.IsValueType
&& property.ClrType.UnwrapNullableType() is Type nonNullableType
&& !(nonNullableType == typeof(bool) || nonNullableType.IsNumeric() || nonNullableType.IsEnum)
? Expression.Call(
_objectEqualsMethodInfo,
Expression.Call(
keyValuesConstant,
ValueBuffer.GetValueMethod,
Expression.Constant(i)),
property.ClrType));
EF.PropertyMethod.MakeGenericMethod(typeof(object)),
entityParameterExpression,
Expression.Constant(property.Name, typeof(string))),
Expression.Convert(
Expression.Call(
keyValuesConstantExpression,
ValueBuffer.GetValueMethod,
Expression.Constant(i)),
typeof(object)))
: (Expression)Expression.Equal(
Expression.Call(
EF.PropertyMethod.MakeGenericMethod(property.ClrType),
entityParameterExpression,
Expression.Constant(property.Name, typeof(string))),
Expression.Convert(
Expression.Call(
keyValuesConstantExpression,
ValueBuffer.GetValueMethod,
Expression.Constant(i)),
property.ClrType));
}

private static Expression<Func<object, object[]>> BuildProjection(IEntityType entityType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@ public partial class NavigationExpandingExpressionVisitor
/// </summary>
private class ExpandingExpressionVisitor : ExpressionVisitor
{
private static readonly MethodInfo _objectEqualsMethodInfo
= typeof(object).GetRuntimeMethod(nameof(object.Equals), new[] { typeof(object), typeof(object) });

private readonly NavigationExpandingExpressionVisitor _navigationExpandingExpressionVisitor;
private readonly NavigationExpansionExpression _source;

Expand Down Expand Up @@ -393,7 +396,7 @@ outerKey is NewArrayExpression newArrayExpression
})
.Aggregate((l, r) => Expression.AndAlso(l, r))
: Expression.NotEqual(outerKey, Expression.Constant(null, outerKey.Type)),
Expression.Equal(outerKey, innerKey));
Expression.Call(_objectEqualsMethodInfo, AddConvertToObject(outerKey), AddConvertToObject(innerKey)));

// Caller should take care of wrapping MaterializeCollectionNavigation
return Expression.Call(
Expand Down Expand Up @@ -455,6 +458,11 @@ outerKey is NewArrayExpression newArrayExpression

return innerSource.PendingSelector;
}

static Expression AddConvertToObject(Expression expression)
=> expression.Type.IsValueType
? Expression.Convert(expression, typeof(object))
: expression;
}

/// <summary>
Expand Down
Loading