Skip to content

Query: Rewrite Entity Equality during translation phase #20447

New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Merged
merged 1 commit into from
Mar 31, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 2 additions & 6 deletions src/EFCore.Cosmos/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,8 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
return _sqlExpressionFactory.In(arguments[1], arguments[0], false);
}

if (method.Name == nameof(IList.Contains)
&& arguments.Count == 1
&& method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any(
t => t == typeof(IList)
|| (t.IsGenericType
&& t.GetGenericTypeDefinition() == typeof(ICollection<>)))
if (arguments.Count == 1
&& method.IsContainsMethod()
&& ValidateValues(instance))
{
return _sqlExpressionFactory.In(arguments[0], instance, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public CosmosQueryableMethodTranslatingExpressionVisitor(
_model = queryCompilationContext.Model;
_sqlExpressionFactory = sqlExpressionFactory;
_sqlTranslator = new CosmosSqlTranslatingExpressionVisitor(
_model,
queryCompilationContext,
sqlExpressionFactory,
memberTranslatorProvider,
methodCallTranslatorProvider);
Expand Down

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,18 @@
using System.Linq.Expressions;
using System.Reflection;
using JetBrains.Annotations;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.InMemory.Internal;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Query;
using Microsoft.EntityFrameworkCore.Query.Internal;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.EntityFrameworkCore.Utilities;

namespace Microsoft.EntityFrameworkCore.InMemory.Query.Internal
{
public class InMemoryQueryableMethodTranslatingExpressionVisitor : QueryableMethodTranslatingExpressionVisitor
{
private static readonly MethodInfo _efPropertyMethod = typeof(EF).GetTypeInfo().GetDeclaredMethod(nameof(EF.Property));

private readonly InMemoryExpressionTranslatingExpressionVisitor _expressionTranslator;
private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor;
private readonly InMemoryProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
Expand All @@ -32,7 +29,7 @@ public InMemoryQueryableMethodTranslatingExpressionVisitor(
[NotNull] QueryCompilationContext queryCompilationContext)
: base(dependencies, subquery: false)
{
_expressionTranslator = new InMemoryExpressionTranslatingExpressionVisitor(this, queryCompilationContext.Model);
_expressionTranslator = new InMemoryExpressionTranslatingExpressionVisitor(queryCompilationContext, this);
_weakEntityExpandingExpressionVisitor = new WeakEntityExpandingExpressionVisitor(_expressionTranslator);
_projectionBindingExpressionVisitor = new InMemoryProjectionBindingExpressionVisitor(this, _expressionTranslator);
_model = queryCompilationContext.Model;
Expand Down Expand Up @@ -402,16 +399,14 @@ protected override ShapedQueryExpression TranslateJoin(
Check.NotNull(inner, nameof(inner));
Check.NotNull(resultSelector, nameof(resultSelector));

outerKeySelector = TranslateLambdaExpression(outer, outerKeySelector);
innerKeySelector = TranslateLambdaExpression(inner, innerKeySelector);
(outerKeySelector, innerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector);

if (outerKeySelector == null
|| innerKeySelector == null)
{
return null;
}

(outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector);

var transparentIdentifierType = TransparentIdentifierFactory.Create(
resultSelector.Parameters[0].Type,
resultSelector.Parameters[1].Type);
Expand All @@ -429,6 +424,71 @@ protected override ShapedQueryExpression TranslateJoin(
transparentIdentifierType);
}

private (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector) ProcessJoinKeySelector(
ShapedQueryExpression outer, ShapedQueryExpression inner, LambdaExpression outerKeySelector, LambdaExpression innerKeySelector)
{
var left = RemapLambdaBody(outer, outerKeySelector);
var right = RemapLambdaBody(inner, innerKeySelector);

var joinCondition = TranslateExpression(Expression.Equal(left, right));

var (outerKeyBody, innerKeyBody) = DecomposeJoinCondition(joinCondition);

if (outerKeyBody == null
|| innerKeyBody == null)
{
return (null, null);
}

outerKeySelector = Expression.Lambda(outerKeyBody, ((InMemoryQueryExpression)outer.QueryExpression).CurrentParameter);
innerKeySelector = Expression.Lambda(innerKeyBody, ((InMemoryQueryExpression)inner.QueryExpression).CurrentParameter);

return AlignKeySelectorTypes(outerKeySelector, innerKeySelector);
}

private static (Expression, Expression) DecomposeJoinCondition(Expression joinCondition)
{
var leftExpressions = new List<Expression>();
var rightExpressions = new List<Expression>();

return ProcessJoinCondition(joinCondition, leftExpressions, rightExpressions)
? leftExpressions.Count == 1
? (leftExpressions[0], rightExpressions[0])
: (CreateAnonymousObject(leftExpressions), CreateAnonymousObject(rightExpressions))
: (null, null);

static Expression CreateAnonymousObject(List<Expression> expressions)
=> Expression.New(
AnonymousObject.AnonymousObjectCtor,
Expression.NewArrayInit(
typeof(object),
expressions.Select(e => Expression.Convert(e, typeof(object)))));
}


private static bool ProcessJoinCondition(
Expression joinCondition, List<Expression> leftExpressions, List<Expression> rightExpressions)
{
if (joinCondition is BinaryExpression binaryExpression)
{
if (binaryExpression.NodeType == ExpressionType.Equal)
{
leftExpressions.Add(binaryExpression.Left);
rightExpressions.Add(binaryExpression.Right);

return true;
}

if (binaryExpression.NodeType == ExpressionType.AndAlso)
{
return ProcessJoinCondition(binaryExpression.Left, leftExpressions, rightExpressions)
&& ProcessJoinCondition(binaryExpression.Right, leftExpressions, rightExpressions);
}
}

return false;
}

private static (LambdaExpression OuterKeySelector, LambdaExpression InnerKeySelector)
AlignKeySelectorTypes(LambdaExpression outerKeySelector, LambdaExpression innerKeySelector)
{
Expand Down Expand Up @@ -477,15 +537,14 @@ protected override ShapedQueryExpression TranslateLeftJoin(
Check.NotNull(inner, nameof(inner));
Check.NotNull(resultSelector, nameof(resultSelector));

outerKeySelector = TranslateLambdaExpression(outer, outerKeySelector);
innerKeySelector = TranslateLambdaExpression(inner, innerKeySelector);
(outerKeySelector, innerKeySelector) = ProcessJoinKeySelector(outer, inner, outerKeySelector, innerKeySelector);

if (outerKeySelector == null
|| innerKeySelector == null)
{
return null;
}

(outerKeySelector, innerKeySelector) = AlignKeySelectorTypes(outerKeySelector, innerKeySelector);

var transparentIdentifierType = TransparentIdentifierFactory.Create(
resultSelector.Parameters[0].Type,
Expand Down Expand Up @@ -579,22 +638,16 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s
var discriminatorProperty = entityType.GetDiscriminatorProperty();
var parameter = Expression.Parameter(entityType.ClrType);

var callEFProperty = Expression.Call(
_efPropertyMethod.MakeGenericMethod(
discriminatorProperty.ClrType),
parameter,
Expression.Constant(discriminatorProperty.Name));

var equals = Expression.Equal(
callEFProperty,
parameter.CreateEFPropertyExpression(discriminatorProperty),
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));

foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
{
equals = Expression.OrElse(
equals,
Expression.Equal(
callEFProperty,
parameter.CreateEFPropertyExpression(discriminatorProperty),
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
}

Expand Down
8 changes: 2 additions & 6 deletions src/EFCore.Relational/Query/Internal/ContainsTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,8 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method
return _sqlExpressionFactory.In(arguments[1], arguments[0], negated: false);
}

if (method.Name == nameof(IList.Contains)
&& arguments.Count == 1
&& method.DeclaringType.GetInterfaces().Append(method.DeclaringType).Any(
t => t == typeof(IList)
|| (t.IsGenericType
&& t.GetGenericTypeDefinition() == typeof(ICollection<>)))
if (arguments.Count == 1
&& method.IsContainsMethod()
&& ValidateValues(instance))
{
return _sqlExpressionFactory.In(arguments[0], instance, negated: false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class RelationalQueryableMethodTranslatingExpressionVisitor : QueryableMe
private readonly RelationalSqlTranslatingExpressionVisitor _sqlTranslator;
private readonly WeakEntityExpandingExpressionVisitor _weakEntityExpandingExpressionVisitor;
private readonly RelationalProjectionBindingExpressionVisitor _projectionBindingExpressionVisitor;
private readonly QueryCompilationContext _queryCompilationContext;
private readonly IModel _model;
private readonly ISqlExpressionFactory _sqlExpressionFactory;
private readonly bool _subquery;
Expand Down Expand Up @@ -54,7 +55,7 @@ protected RelationalQueryableMethodTranslatingExpressionVisitor(
: base(parentVisitor.Dependencies, subquery: true)
{
RelationalDependencies = parentVisitor.RelationalDependencies;
_model = parentVisitor._model;
_queryCompilationContext = parentVisitor._queryCompilationContext;
_sqlTranslator = parentVisitor._sqlTranslator;
_weakEntityExpandingExpressionVisitor = parentVisitor._weakEntityExpandingExpressionVisitor;
_projectionBindingExpressionVisitor = new RelationalProjectionBindingExpressionVisitor(this, _sqlTranslator);
Expand Down Expand Up @@ -116,7 +117,7 @@ protected override ShapedQueryExpression CreateShapedQueryExpression(Type elemen
{
Check.NotNull(elementType, nameof(elementType));

var entityType = _model.FindEntityType(elementType);
var entityType = _queryCompilationContext.Model.FindEntityType(elementType);
var queryExpression = _sqlExpressionFactory.Select(entityType);

return CreateShapedQueryExpression(entityType, queryExpression);
Expand Down
Loading