Skip to content

Commit 44c43a6

Browse files
committed
Cosmos: Add support for OfType/is operator
Translate OfType using is operator in all providers Add Cosmos version of InheritanceTestBase. No IncompleteMapping version as all discriminator predicates in cosmos are required. Resolves #16391
1 parent 7a6fca0 commit 44c43a6

16 files changed

+717
-778
lines changed

src/EFCore.Cosmos/Query/Internal/CosmosQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -711,6 +711,53 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s
711711
Check.NotNull(source, nameof(source));
712712
Check.NotNull(resultType, nameof(resultType));
713713

714+
if (source.ShaperExpression is EntityShaperExpression entityShaperExpression)
715+
{
716+
var entityType = entityShaperExpression.EntityType;
717+
if (entityType.ClrType == resultType)
718+
{
719+
return source;
720+
}
721+
722+
var parameterExpression = Expression.Parameter(entityShaperExpression.Type);
723+
var predicate = Expression.Lambda(Expression.TypeIs(parameterExpression, resultType), parameterExpression);
724+
var translation = TranslateLambdaExpression(source, predicate);
725+
if (translation == null)
726+
{
727+
// EntityType is not part of hierarchy
728+
return null;
729+
}
730+
731+
var selectExpression = (SelectExpression)source.QueryExpression;
732+
if (!(translation is SqlConstantExpression sqlConstantExpression
733+
&& sqlConstantExpression.Value is bool constantValue
734+
&& constantValue))
735+
{
736+
selectExpression.ApplyPredicate(translation);
737+
}
738+
739+
var baseType = entityType.GetAllBaseTypes().SingleOrDefault(et => et.ClrType == resultType);
740+
if (baseType != null)
741+
{
742+
return source.UpdateShaperExpression(entityShaperExpression.WithEntityType(baseType));
743+
}
744+
745+
var derivedType = entityType.GetDerivedTypes().Single(et => et.ClrType == resultType);
746+
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
747+
748+
var projectionMember = projectionBindingExpression.ProjectionMember;
749+
Check.DebugAssert(new ProjectionMember().Equals(projectionMember), "Invalid ProjectionMember when processing OfType");
750+
751+
var entityProjectionExpression = (EntityProjectionExpression)selectExpression.GetMappedProjection(projectionMember);
752+
selectExpression.ReplaceProjectionMapping(
753+
new Dictionary<ProjectionMember, Expression>
754+
{
755+
{ projectionMember, entityProjectionExpression.UpdateEntityType(derivedType) }
756+
});
757+
758+
return source.UpdateShaperExpression(entityShaperExpression.WithEntityType(derivedType));
759+
}
760+
714761
return null;
715762
}
716763

src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -598,6 +598,43 @@ protected override Expression VisitUnary(UnaryExpression unaryExpression)
598598
return null;
599599
}
600600

601+
/// <inheritdoc />
602+
protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExpression)
603+
{
604+
Check.NotNull(typeBinaryExpression, nameof(typeBinaryExpression));
605+
606+
var innerExpression = Visit(typeBinaryExpression.Expression);
607+
608+
if (typeBinaryExpression.NodeType == ExpressionType.TypeIs
609+
&& innerExpression is EntityReferenceExpression entityReferenceExpression)
610+
{
611+
var entityType = entityReferenceExpression.EntityType;
612+
if (entityType.GetAllBaseTypesInclusive().Any(et => et.ClrType == typeBinaryExpression.TypeOperand))
613+
{
614+
return _sqlExpressionFactory.Constant(true);
615+
}
616+
617+
var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == typeBinaryExpression.TypeOperand);
618+
if (derivedType != null
619+
&& TryBindMember(entityReferenceExpression,
620+
MemberIdentity.Create(entityType.GetDiscriminatorProperty().Name)) is SqlExpression discriminatorColumn)
621+
{
622+
var concreteEntityTypes = derivedType.GetConcreteDerivedTypesInclusive().ToList();
623+
624+
return concreteEntityTypes.Count == 1
625+
? _sqlExpressionFactory.Equal(
626+
discriminatorColumn,
627+
_sqlExpressionFactory.Constant(concreteEntityTypes[0].GetDiscriminatorValue()))
628+
: (SqlExpression)_sqlExpressionFactory.In(
629+
discriminatorColumn,
630+
_sqlExpressionFactory.Constant(concreteEntityTypes.Select(et => et.GetDiscriminatorValue()).ToList()),
631+
negated: false);
632+
}
633+
}
634+
635+
return null;
636+
}
637+
601638
private Expression TryBindMember(Expression source, MemberIdentity member)
602639
{
603640
if (!(source is EntityReferenceExpression entityReferenceExpression))

src/EFCore.Cosmos/Query/Internal/EntityProjectionExpression.cs

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ namespace Microsoft.EntityFrameworkCore.Cosmos.Query.Internal
2222
/// </summary>
2323
public class EntityProjectionExpression : Expression, IPrintableExpression, IAccessExpression
2424
{
25-
private readonly IDictionary<IProperty, IAccessExpression> _propertyExpressionsCache
25+
private readonly IDictionary<IProperty, IAccessExpression> _propertyExpressionsMap
2626
= new Dictionary<IProperty, IAccessExpression>();
2727

28-
private readonly IDictionary<INavigation, IAccessExpression> _navigationExpressionsCache
28+
private readonly IDictionary<INavigation, IAccessExpression> _navigationExpressionsMap
2929
= new Dictionary<INavigation, IAccessExpression>();
3030

3131
/// <summary>
@@ -121,10 +121,10 @@ public virtual Expression BindProperty([NotNull] IProperty property, bool client
121121
"GetProperty", nameof(IProperty), EntityType.DisplayName(), $"Property:{property.Name}"));
122122
}
123123

124-
if (!_propertyExpressionsCache.TryGetValue(property, out var expression))
124+
if (!_propertyExpressionsMap.TryGetValue(property, out var expression))
125125
{
126126
expression = new KeyAccessExpression(property, AccessExpression);
127-
_propertyExpressionsCache[property] = expression;
127+
_propertyExpressionsMap[property] = expression;
128128
}
129129

130130
if (!clientEval
@@ -153,20 +153,15 @@ public virtual Expression BindNavigation([NotNull] INavigation navigation, bool
153153
"GetNavigation", nameof(INavigation), EntityType.DisplayName(), $"Navigation:{navigation.Name}"));
154154
}
155155

156-
if (!_navigationExpressionsCache.TryGetValue(navigation, out var expression))
156+
if (!_navigationExpressionsMap.TryGetValue(navigation, out var expression))
157157
{
158-
if (navigation.IsCollection)
159-
{
160-
expression = new ObjectArrayProjectionExpression(navigation, AccessExpression);
161-
}
162-
else
163-
{
164-
expression = new EntityProjectionExpression(
158+
expression = navigation.IsCollection
159+
? new ObjectArrayProjectionExpression(navigation, AccessExpression)
160+
: (IAccessExpression)new EntityProjectionExpression(
165161
navigation.TargetEntityType,
166162
new ObjectAccessExpression(navigation, AccessExpression));
167-
}
168163

169-
_navigationExpressionsCache[navigation] = expression;
164+
_navigationExpressionsMap[navigation] = expression;
170165
}
171166

172167
if (!clientEval
@@ -231,6 +226,19 @@ private Expression BindMember(MemberIdentity member, Type entityClrType, bool cl
231226
return null;
232227
}
233228

229+
/// <summary>
230+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
231+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
232+
/// any release. You should only use it directly in your code with extreme caution and knowing that
233+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
234+
/// </summary>
235+
public virtual EntityProjectionExpression UpdateEntityType([NotNull] IEntityType derivedType)
236+
{
237+
Check.NotNull(derivedType, nameof(derivedType));
238+
239+
return new EntityProjectionExpression(derivedType, AccessExpression);
240+
}
241+
234242
/// <summary>
235243
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
236244
/// the same compatibility standards as public APIs. It may be changed or removed without notice in

src/EFCore.InMemory/Query/Internal/InMemoryExpressionTranslatingExpressionVisitor.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,7 @@ protected override Expression VisitTypeBinary(TypeBinaryExpression typeBinaryExp
808808
}
809809
}
810810

811-
return Expression.Constant(false);
811+
return null;
812812
}
813813

814814
/// <summary>

src/EFCore.InMemory/Query/Internal/InMemoryQueryableMethodTranslatingExpressionVisitor.cs

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -800,56 +800,36 @@ protected override ShapedQueryExpression TranslateOfType(ShapedQueryExpression s
800800
return source;
801801
}
802802

803+
var parameterExpression = Expression.Parameter(entityShaperExpression.Type);
804+
var predicate = Expression.Lambda(Expression.TypeIs(parameterExpression, resultType), parameterExpression);
805+
source = TranslateWhere(source, predicate);
806+
if (source == null)
807+
{
808+
// EntityType is not part of hierarchy
809+
return null;
810+
}
811+
803812
var baseType = entityType.GetAllBaseTypes().SingleOrDefault(et => et.ClrType == resultType);
804813
if (baseType != null)
805814
{
806815
return source.UpdateShaperExpression(entityShaperExpression.WithEntityType(baseType));
807816
}
808817

809-
var derivedType = entityType.GetDerivedTypes().SingleOrDefault(et => et.ClrType == resultType);
810-
if (derivedType != null)
811-
{
812-
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;
813-
var discriminatorProperty = entityType.GetDiscriminatorProperty();
814-
var parameter = Expression.Parameter(entityType.ClrType);
818+
var derivedType = entityType.GetDerivedTypes().Single(et => et.ClrType == resultType);
819+
var inMemoryQueryExpression = (InMemoryQueryExpression)source.QueryExpression;
815820

816-
var equals = Expression.Equal(
817-
parameter.CreateEFPropertyExpression(discriminatorProperty),
818-
Expression.Constant(derivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType));
821+
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
822+
var projectionMember = projectionBindingExpression.ProjectionMember;
823+
Check.DebugAssert(new ProjectionMember().Equals(projectionMember), "Invalid ProjectionMember when processing OfType");
819824

820-
foreach (var derivedDerivedType in derivedType.GetDerivedTypes())
825+
var entityProjectionExpression = (EntityProjectionExpression)inMemoryQueryExpression.GetMappedProjection(projectionMember);
826+
inMemoryQueryExpression.ReplaceProjectionMapping(
827+
new Dictionary<ProjectionMember, Expression>
821828
{
822-
equals = Expression.OrElse(
823-
equals,
824-
Expression.Equal(
825-
parameter.CreateEFPropertyExpression(discriminatorProperty),
826-
Expression.Constant(derivedDerivedType.GetDiscriminatorValue(), discriminatorProperty.ClrType)));
827-
}
829+
{ projectionMember, entityProjectionExpression.UpdateEntityType(derivedType) }
830+
});
828831

829-
var discriminatorPredicate = TranslateLambdaExpression(source, Expression.Lambda(equals, parameter));
830-
if (discriminatorPredicate == null)
831-
{
832-
return null;
833-
}
834-
835-
inMemoryQueryExpression.UpdateServerQueryExpression(
836-
Expression.Call(
837-
EnumerableMethods.Where.MakeGenericMethod(typeof(ValueBuffer)),
838-
inMemoryQueryExpression.ServerQueryExpression,
839-
discriminatorPredicate));
840-
841-
var projectionBindingExpression = (ProjectionBindingExpression)entityShaperExpression.ValueBufferExpression;
842-
var projectionMember = projectionBindingExpression.ProjectionMember;
843-
var entityProjection = (EntityProjectionExpression)inMemoryQueryExpression.GetMappedProjection(projectionMember);
844-
845-
inMemoryQueryExpression.ReplaceProjectionMapping(
846-
new Dictionary<ProjectionMember, Expression>
847-
{
848-
{ projectionMember, entityProjection.UpdateEntityType(derivedType) }
849-
});
850-
851-
return source.UpdateShaperExpression(entityShaperExpression.WithEntityType(derivedType));
852-
}
832+
return source.UpdateShaperExpression(entityShaperExpression.WithEntityType(derivedType));
853833
}
854834

855835
return null;

0 commit comments

Comments
 (0)