diff --git a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs index a9949af0d0d..0af3351d300 100644 --- a/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs +++ b/src/EFCore.Relational/Query/SqlExpressions/SelectExpression.Helper.cs @@ -924,145 +924,136 @@ private sealed class CloningExpressionVisitor : ExpressionVisitor [return: NotNullIfNotNull("expression")] public override Expression? Visit(Expression? expression) { - if (expression is SelectExpression selectExpression) + switch (expression) { - var newProjectionMappings = new Dictionary(selectExpression._projectionMapping.Count); - foreach (var (projectionMember, value) in selectExpression._projectionMapping) + case SelectExpression selectExpression: { - newProjectionMappings[projectionMember] = Visit(value); - } - - var newProjections = selectExpression._projection.Select(Visit).ToList(); + var newProjectionMappings = new Dictionary(selectExpression._projectionMapping.Count); + foreach (var (projectionMember, value) in selectExpression._projectionMapping) + { + newProjectionMappings[projectionMember] = Visit(value); + } - var newTables = selectExpression._tables.Select(Visit).ToList(); - var tpcTablesMap = selectExpression._tables.Select(UnwrapJoinExpression).Zip(newTables.Select(UnwrapJoinExpression)) - .Where(e => e.First is TpcTablesExpression) - .ToDictionary(e => (TpcTablesExpression)e.First, e => (TpcTablesExpression)e.Second); + var newProjections = selectExpression._projection.Select(Visit).ToList(); + + var newTables = selectExpression._tables.Select(Visit).ToList(); + var tpcTablesMap = selectExpression._tables.Select(UnwrapJoinExpression).Zip(newTables.Select(UnwrapJoinExpression)) + .Where(e => e.First is TpcTablesExpression) + .ToDictionary(e => (TpcTablesExpression)e.First, e => (TpcTablesExpression)e.Second); + + // Since we are cloning we need to generate new table references + // In other cases (like VisitChildren), we just reuse the same table references and update the SelectExpression inside it. + // We initially assign old SelectExpression in table references and later update it once we construct clone + var newTableReferences = selectExpression._tableReferences + .Select(e => new TableReferenceExpression(selectExpression, e.Alias)).ToList(); + Check.DebugAssert( + newTables.Select(e => GetAliasFromTableExpressionBase(e)).SequenceEqual(newTableReferences.Select(e => e.Alias)), + "Alias of updated tables must match the old tables."); + + var predicate = (SqlExpression?)Visit(selectExpression.Predicate); + var newGroupBy = selectExpression._groupBy.Select(Visit) + .Where(e => !(e is SqlConstantExpression || e is SqlParameterExpression)) + .ToList(); + var havingExpression = (SqlExpression?)Visit(selectExpression.Having); + var newOrderings = selectExpression._orderings.Select(Visit).ToList(); + var offset = (SqlExpression?)Visit(selectExpression.Offset); + var limit = (SqlExpression?)Visit(selectExpression.Limit); + + var newSelectExpression = new SelectExpression( + selectExpression.Alias, newProjections, newTables, newTableReferences, newGroupBy, newOrderings, + selectExpression.GetAnnotations()) + { + Predicate = predicate, + Having = havingExpression, + Offset = offset, + Limit = limit, + IsDistinct = selectExpression.IsDistinct, + Tags = selectExpression.Tags, + _usedAliases = selectExpression._usedAliases.ToHashSet(), + _projectionMapping = newProjectionMappings, + }; + newSelectExpression._mutable = selectExpression._mutable; + + newSelectExpression._removableJoinTables.AddRange(selectExpression._removableJoinTables); + + foreach (var kvp in selectExpression._tpcDiscriminatorValues) + { + newSelectExpression._tpcDiscriminatorValues[tpcTablesMap[kvp.Key]] = kvp.Value; + } - // Since we are cloning we need to generate new table references - // In other cases (like VisitChildren), we just reuse the same table references and update the SelectExpression inside it. - // We initially assign old SelectExpression in table references and later update it once we construct clone - var newTableReferences = selectExpression._tableReferences - .Select(e => new TableReferenceExpression(selectExpression, e.Alias)).ToList(); - Check.DebugAssert( - newTables.Select(e => GetAliasFromTableExpressionBase(e)).SequenceEqual(newTableReferences.Select(e => e.Alias)), - "Alias of updated tables must match the old tables."); - - var predicate = (SqlExpression?)Visit(selectExpression.Predicate); - var newGroupBy = selectExpression._groupBy.Select(Visit) - .Where(e => !(e is SqlConstantExpression || e is SqlParameterExpression)) - .ToList(); - var havingExpression = (SqlExpression?)Visit(selectExpression.Having); - var newOrderings = selectExpression._orderings.Select(Visit).ToList(); - var offset = (SqlExpression?)Visit(selectExpression.Offset); - var limit = (SqlExpression?)Visit(selectExpression.Limit); - - var newSelectExpression = new SelectExpression( - selectExpression.Alias, newProjections, newTables, newTableReferences, newGroupBy, newOrderings, - selectExpression.GetAnnotations()) - { - Predicate = predicate, - Having = havingExpression, - Offset = offset, - Limit = limit, - IsDistinct = selectExpression.IsDistinct, - Tags = selectExpression.Tags, - _usedAliases = selectExpression._usedAliases.ToHashSet(), - _projectionMapping = newProjectionMappings, - }; - newSelectExpression._mutable = selectExpression._mutable; - - newSelectExpression._removableJoinTables.AddRange(selectExpression._removableJoinTables); + // Since identifiers are ColumnExpression, they are not visited since they don't contain SelectExpression inside it. + newSelectExpression._identifier.AddRange(selectExpression._identifier); + newSelectExpression._childIdentifiers.AddRange(selectExpression._childIdentifiers); - foreach (var kvp in selectExpression._tpcDiscriminatorValues) - { - newSelectExpression._tpcDiscriminatorValues[tpcTablesMap[kvp.Key]] = kvp.Value; - } + // Remap tableReferences in new select expression + foreach (var tableReference in newTableReferences) + { + tableReference.UpdateTableReference(selectExpression, newSelectExpression); + } - // Since identifiers are ColumnExpression, they are not visited since they don't contain SelectExpression inside it. - newSelectExpression._identifier.AddRange(selectExpression._identifier); - newSelectExpression._childIdentifiers.AddRange(selectExpression._childIdentifiers); + // Now that we have SelectExpression, we visit all components and update table references inside columns + newSelectExpression = (SelectExpression)new ColumnExpressionReplacingExpressionVisitor( + selectExpression, newSelectExpression._tableReferences).Visit(newSelectExpression); - // Remap tableReferences in new select expression - foreach (var tableReference in newTableReferences) - { - tableReference.UpdateTableReference(selectExpression, newSelectExpression); + return newSelectExpression; } - // Now that we have SelectExpression, we visit all components and update table references inside columns - newSelectExpression = (SelectExpression)new ColumnExpressionReplacingExpressionVisitor( - selectExpression, newSelectExpression._tableReferences).Visit(newSelectExpression); - - return newSelectExpression; - } - - if (expression is TpcTablesExpression tpcTablesExpression) - { - // Deep clone - var subSelectExpressions = tpcTablesExpression.SelectExpressions.Select(Visit).ToList(); - var newTpcTable = new TpcTablesExpression(tpcTablesExpression.Alias, tpcTablesExpression.EntityType, subSelectExpressions); - foreach (var annotation in tpcTablesExpression.GetAnnotations()) + case TpcTablesExpression tpcTablesExpression: { - newTpcTable.AddAnnotation(annotation.Name, annotation.Value); - } - - return newTpcTable; - } + // Deep clone + var subSelectExpressions = tpcTablesExpression.SelectExpressions.Select(Visit).ToList(); + var newTpcTable = new TpcTablesExpression(tpcTablesExpression.Alias, tpcTablesExpression.EntityType, subSelectExpressions); + foreach (var annotation in tpcTablesExpression.GetAnnotations()) + { + newTpcTable.AddAnnotation(annotation.Name, annotation.Value); + } - if (expression is TableValuedFunctionExpression tableValuedFunctionExpression) - { - var newArguments = new SqlExpression[tableValuedFunctionExpression.Arguments.Count]; - for (var i = 0; i < newArguments.Length; i++) - { - newArguments[i] = (SqlExpression)Visit(tableValuedFunctionExpression.Arguments[i]); + return newTpcTable; } - var newTableValuedFunctionExpression = new TableValuedFunctionExpression( - tableValuedFunctionExpression.StoreFunction, - newArguments) + case TableValuedFunctionExpression tableValuedFunctionExpression: { - Alias = tableValuedFunctionExpression.Alias - }; + var newArguments = new SqlExpression[tableValuedFunctionExpression.Arguments.Count]; + for (var i = 0; i < newArguments.Length; i++) + { + newArguments[i] = (SqlExpression)Visit(tableValuedFunctionExpression.Arguments[i]); + } - foreach (var annotation in tableValuedFunctionExpression.GetAnnotations()) - { - newTableValuedFunctionExpression.AddAnnotation(annotation.Name, annotation.Value); + var newTableValuedFunctionExpression = new TableValuedFunctionExpression( + tableValuedFunctionExpression.StoreFunction, + newArguments) + { + Alias = tableValuedFunctionExpression.Alias + }; + + foreach (var annotation in tableValuedFunctionExpression.GetAnnotations()) + { + newTableValuedFunctionExpression.AddAnnotation(annotation.Name, annotation.Value); + } + + return newTableValuedFunctionExpression; } - return newTableValuedFunctionExpression; - } + case IClonableTableExpressionBase cloneable: + return cloneable.Clone(); - if (expression is IClonableTableExpressionBase cloneable) - { - return cloneable.Clone(); - } + // join and set operations are fine, because they contain other TableExpressionBases inside, that will get cloned + // and therefore set expression's Update function will generate a new instance. + case JoinExpressionBase or SetOperationBase: + return base.Visit(expression); - // join and set operations are fine, because they contain other TableExpressionBases inside, that will get cloned - // and therefore set expression's Update function will generate a new instance. - if (expression is CrossJoinExpression - or InnerJoinExpression - or LeftJoinExpression - or CrossApplyExpression - or OuterApplyExpression - or ExceptExpression - or IntersectExpression - or UnionExpression) - { - return base.Visit(expression); - } + case TableExpressionBase: + throw new InvalidOperationException( + RelationalStrings.TableExpressionBaseWithoutCloningLogic( + expression.GetType().Name, + nameof(TableExpressionBase), + nameof(IClonableTableExpressionBase), + nameof(CloningExpressionVisitor), + nameof(SelectExpression))); - if (expression is TableExpressionBase) - { - throw new InvalidOperationException( - RelationalStrings.TableExpressionBaseWithoutCloningLogic( - expression.GetType().Name, - nameof(TableExpressionBase), - nameof(IClonableTableExpressionBase), - nameof(CloningExpressionVisitor), - nameof(SelectExpression))); ; + default: + return base.Visit(expression); } - - return base.Visit(expression); } }