diff --git a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs index 34bb6b21c1c..bb4787b27c6 100644 --- a/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Cosmos/Query/Internal/CosmosSqlTranslatingExpressionVisitor.cs @@ -307,7 +307,7 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp && methodCallExpression.Arguments.Count == 1) { var left = Visit(methodCallExpression.Object); - var right = Visit(methodCallExpression.Arguments[0]); + var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); if (TryRewriteEntityEquality(ExpressionType.Equal, left ?? methodCallExpression.Object, @@ -332,8 +332,8 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp && methodCallExpression.Object == null && methodCallExpression.Arguments.Count == 2) { - var left = Visit(methodCallExpression.Arguments[0]); - var right = Visit(methodCallExpression.Arguments[1]); + var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); + var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[1])); if (TryRewriteEntityEquality(ExpressionType.Equal, left ?? methodCallExpression.Arguments[0], @@ -417,6 +417,13 @@ protected override Expression VisitMethodCall(MethodCallExpression methodCallExp } return _methodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); + + static Expression RemoveObjectConvert(Expression expression) + => expression is UnaryExpression unaryExpression + && (unaryExpression.NodeType == ExpressionType.Convert || unaryExpression.NodeType == ExpressionType.ConvertChecked) + && unaryExpression.Type == typeof(object) + ? unaryExpression.Operand + : expression; } /// diff --git a/src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs b/src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs index a9c34364e6e..5b02718cd19 100644 --- a/src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs +++ b/src/EFCore.Cosmos/Query/Internal/EqualsTranslator.cs @@ -50,14 +50,14 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method && arguments.Count == 1) { left = instance; - right = RemoveObjectConvert(arguments[0]); + right = arguments[0]; } else if (instance == null && method.Name == nameof(object.Equals) && arguments.Count == 2) { - left = RemoveObjectConvert(arguments[0]); - right = RemoveObjectConvert(arguments[1]); + left = arguments[0]; + right = arguments[1]; } if (left != null @@ -70,17 +70,5 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method return null; } - - private SqlExpression RemoveObjectConvert(SqlExpression expression) - { - if (expression is SqlUnaryExpression sqlUnaryExpression - && sqlUnaryExpression.OperatorType == ExpressionType.Convert - && sqlUnaryExpression.Type == typeof(object)) - { - return sqlUnaryExpression.Operand; - } - - return expression; - } } } diff --git a/src/EFCore.Relational/Query/Internal/EqualsTranslator.cs b/src/EFCore.Relational/Query/Internal/EqualsTranslator.cs index 77fc27c2c8c..57eac0091e3 100644 --- a/src/EFCore.Relational/Query/Internal/EqualsTranslator.cs +++ b/src/EFCore.Relational/Query/Internal/EqualsTranslator.cs @@ -33,14 +33,14 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method && arguments.Count == 1) { left = instance; - right = RemoveObjectConvert(arguments[0]); + right = arguments[0]; } else if (instance == null && method.Name == nameof(object.Equals) && arguments.Count == 2) { - left = RemoveObjectConvert(arguments[0]); - right = RemoveObjectConvert(arguments[1]); + left = arguments[0]; + right = arguments[1]; } if (left != null @@ -56,17 +56,5 @@ public virtual SqlExpression Translate(SqlExpression instance, MethodInfo method return null; } - - private SqlExpression RemoveObjectConvert(SqlExpression expression) - { - if (expression is SqlUnaryExpression sqlUnaryExpression - && sqlUnaryExpression.OperatorType == ExpressionType.Convert - && sqlUnaryExpression.Type == typeof(object)) - { - return sqlUnaryExpression.Operand; - } - - return expression; - } } } diff --git a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs index 81a684d578e..90b507f6546 100644 --- a/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs +++ b/src/EFCore.Relational/Query/RelationalSqlTranslatingExpressionVisitor.cs @@ -446,7 +446,7 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) && methodCallExpression.Arguments.Count == 1) { var left = Visit(methodCallExpression.Object); - var right = Visit(methodCallExpression.Arguments[0]); + var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); if (TryRewriteEntityEquality(ExpressionType.Equal, left ?? methodCallExpression.Object, @@ -471,8 +471,8 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) && methodCallExpression.Object == null && methodCallExpression.Arguments.Count == 2) { - var left = Visit(methodCallExpression.Arguments[0]); - var right = Visit(methodCallExpression.Arguments[1]); + var left = Visit(RemoveObjectConvert(methodCallExpression.Arguments[0])); + var right = Visit(RemoveObjectConvert(methodCallExpression.Arguments[1])); if (TryRewriteEntityEquality(ExpressionType.Equal, left ?? methodCallExpression.Arguments[0], @@ -556,6 +556,13 @@ static bool IsAggregateResultWithCustomShaper(MethodInfo method) } return Dependencies.MethodCallTranslatorProvider.Translate(_model, sqlObject, methodCallExpression.Method, arguments); + + static Expression RemoveObjectConvert(Expression expression) + => expression is UnaryExpression unaryExpression + && (unaryExpression.NodeType == ExpressionType.Convert || unaryExpression.NodeType == ExpressionType.ConvertChecked) + && unaryExpression.Type == typeof(object) + ? unaryExpression.Operand + : expression; } protected override Expression VisitNew(NewExpression newExpression) diff --git a/test/EFCore.Specification.Tests/ConvertToProviderTypesTestBase.cs b/test/EFCore.Specification.Tests/ConvertToProviderTypesTestBase.cs index ffce64d5069..c17d410873b 100644 --- a/test/EFCore.Specification.Tests/ConvertToProviderTypesTestBase.cs +++ b/test/EFCore.Specification.Tests/ConvertToProviderTypesTestBase.cs @@ -1,6 +1,9 @@ // Copyright (c) .NET Foundation. All rights reserved. // Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information. +using System.Linq; +using Xunit; + namespace Microsoft.EntityFrameworkCore { public abstract class ConvertToProviderTypesTestBase : BuiltInDataTypesTestBase @@ -11,6 +14,26 @@ protected ConvertToProviderTypesTestBase(TFixture fixture) { } + [ConditionalFact] + public virtual void Equals_method_over_enum_works() + { + using var context = CreateContext(); + + var query = context.Set().Where(t => t.Id == -1 && t.Enum8.Equals(Enum8.SomeValue)).ToList(); + + Assert.Empty(query); + } + + [ConditionalFact] + public virtual void Object_equals_method_over_enum_works() + { + using var context = CreateContext(); + + var query = context.Set().Where(t => t.Id == -1 && Equals(t.Enum8, Enum8.SomeValue)).ToList(); + + Assert.Empty(query); + } + public override void Object_to_string_conversion() {} public abstract class ConvertToProviderTypesFixtureBase : BuiltInDataTypesFixtureBase diff --git a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindFunctionsQuerySqlServerTest.cs b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindFunctionsQuerySqlServerTest.cs index 2dbc3e0fd1e..a43e48afa42 100644 --- a/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindFunctionsQuerySqlServerTest.cs +++ b/test/EFCore.SqlServer.FunctionalTests/Query/NorthwindFunctionsQuerySqlServerTest.cs @@ -1395,7 +1395,7 @@ public override async Task Static_equals_nullable_datetime_compared_to_non_nulla await base.Static_equals_nullable_datetime_compared_to_non_nullable(async); AssertSql( - @"@__arg_0='1996-07-04T00:00:00.0000000' + @"@__arg_0='1996-07-04T00:00:00.0000000' (DbType = DateTime) SELECT [o].[OrderID], [o].[CustomerID], [o].[EmployeeID], [o].[OrderDate] FROM [Orders] AS [o]