Skip to content

Commit

Permalink
SQLite: Translate more Math members
Browse files Browse the repository at this point in the history
Resolves dotnet#18843
  • Loading branch information
bricelam committed Aug 10, 2023
1 parent 0dd62a2 commit 874ec4d
Show file tree
Hide file tree
Showing 16 changed files with 680 additions and 140 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.5" />
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.6-pre20230809203314" />
</ItemGroup>

<ItemGroup>
Expand Down
116 changes: 95 additions & 21 deletions src/EFCore.Sqlite.Core/Query/Internal/SqliteMathTranslator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,21 @@ public class SqliteMathTranslator : IMethodCallTranslator
{ typeof(Math).GetMethod(nameof(Math.Abs), new[] { typeof(long) })!, "abs" },
{ typeof(Math).GetMethod(nameof(Math.Abs), new[] { typeof(sbyte) })!, "abs" },
{ typeof(Math).GetMethod(nameof(Math.Abs), new[] { typeof(short) })!, "abs" },
{ typeof(Math).GetMethod(nameof(Math.Acos), new[] { typeof(double) })!, "acos" },
{ typeof(Math).GetMethod(nameof(Math.Acosh), new[] { typeof(double) })!, "acosh" },
{ typeof(Math).GetMethod(nameof(Math.Asin), new[] { typeof(double) })!, "asin" },
{ typeof(Math).GetMethod(nameof(Math.Asinh), new[] { typeof(double) })!, "asinh" },
{ typeof(Math).GetMethod(nameof(Math.Atan), new[] { typeof(double) })!, "atan" },
{ typeof(Math).GetMethod(nameof(Math.Atan2), new[] { typeof(double), typeof(double) })!, "atan2" },
{ typeof(Math).GetMethod(nameof(Math.Atanh), new[] { typeof(double) })!, "atanh" },
{ typeof(Math).GetMethod(nameof(Math.Ceiling), new[] { typeof(double) })!, "ceiling" },
{ typeof(Math).GetMethod(nameof(Math.Cos), new[] { typeof(double) })!, "cos" },
{ typeof(Math).GetMethod(nameof(Math.Cosh), new[] { typeof(double) })!, "cosh" },
{ typeof(Math).GetMethod(nameof(Math.Exp), new[] { typeof(double) })!, "exp" },
{ typeof(Math).GetMethod(nameof(Math.Floor), new[] { typeof(double) })!, "floor" },
{ typeof(Math).GetMethod(nameof(Math.Log), new[] { typeof(double) })!, "ln" },
{ typeof(Math).GetMethod(nameof(Math.Log2), new[] { typeof(double) })!, "log2" },
{ typeof(Math).GetMethod(nameof(Math.Log10), new[] { typeof(double) })!, "log10" },
{ typeof(Math).GetMethod(nameof(Math.Max), new[] { typeof(byte), typeof(byte) })!, "max" },
{ typeof(Math).GetMethod(nameof(Math.Max), new[] { typeof(double), typeof(double) })!, "max" },
{ typeof(Math).GetMethod(nameof(Math.Max), new[] { typeof(float), typeof(float) })!, "max" },
Expand All @@ -40,15 +55,58 @@ public class SqliteMathTranslator : IMethodCallTranslator
{ typeof(Math).GetMethod(nameof(Math.Min), new[] { typeof(short), typeof(short) })!, "min" },
{ typeof(Math).GetMethod(nameof(Math.Min), new[] { typeof(uint), typeof(uint) })!, "min" },
{ typeof(Math).GetMethod(nameof(Math.Min), new[] { typeof(ushort), typeof(ushort) })!, "min" },
{ typeof(Math).GetMethod(nameof(Math.Pow), new[] { typeof(double), typeof(double) })!, "pow" },
{ typeof(Math).GetMethod(nameof(Math.Round), new[] { typeof(double) })!, "round" },
{ typeof(Math).GetMethod(nameof(Math.Round), new[] { typeof(double), typeof(int) })!, "round" },
{ typeof(Math).GetMethod(nameof(Math.Sign), new[] { typeof(double) })!, "sign" },
{ typeof(Math).GetMethod(nameof(Math.Sign), new[] { typeof(float) })!, "sign" },
{ typeof(Math).GetMethod(nameof(Math.Sign), new[] { typeof(long) })!, "sign" },
{ typeof(Math).GetMethod(nameof(Math.Sign), new[] { typeof(sbyte) })!, "sign" },
{ typeof(Math).GetMethod(nameof(Math.Sign), new[] { typeof(short) })!, "sign" },
{ typeof(Math).GetMethod(nameof(Math.Sin), new[] { typeof(double) })!, "sin" },
{ typeof(Math).GetMethod(nameof(Math.Sinh), new[] { typeof(double) })!, "sinh" },
{ typeof(Math).GetMethod(nameof(Math.Sqrt), new[] { typeof(double) })!, "sqrt" },
{ typeof(Math).GetMethod(nameof(Math.Tan), new[] { typeof(double) })!, "tan" },
{ typeof(Math).GetMethod(nameof(Math.Tanh), new[] { typeof(double) })!, "tanh" },
{ typeof(Math).GetMethod(nameof(Math.Truncate), new[] { typeof(double) })!, "trunc" },
{ typeof(double).GetRuntimeMethod(nameof(double.DegreesToRadians), new[] { typeof(double) })!, "radians" },
{ typeof(double).GetRuntimeMethod(nameof(double.RadiansToDegrees), new[] { typeof(double) })!, "degrees" },
{ typeof(MathF).GetMethod(nameof(MathF.Acos), new[] { typeof(float) })!, "acos" },
{ typeof(MathF).GetMethod(nameof(MathF.Acosh), new[] { typeof(float) })!, "acosh" },
{ typeof(MathF).GetMethod(nameof(MathF.Asin), new[] { typeof(float) })!, "asin" },
{ typeof(MathF).GetMethod(nameof(MathF.Asinh), new[] { typeof(float) })!, "asinh" },
{ typeof(MathF).GetMethod(nameof(MathF.Atan), new[] { typeof(float) })!, "atan" },
{ typeof(MathF).GetMethod(nameof(MathF.Atan2), new[] { typeof(float), typeof(float) })!, "atan2" },
{ typeof(MathF).GetMethod(nameof(MathF.Atanh), new[] { typeof(float) })!, "atanh" },
{ typeof(MathF).GetMethod(nameof(MathF.Ceiling), new[] { typeof(float) })!, "ceiling" },
{ typeof(MathF).GetMethod(nameof(MathF.Cos), new[] { typeof(float) })!, "cos" },
{ typeof(MathF).GetMethod(nameof(MathF.Cosh), new[] { typeof(float) })!, "cosh" },
{ typeof(MathF).GetMethod(nameof(MathF.Exp), new[] { typeof(float) })!, "exp" },
{ typeof(MathF).GetMethod(nameof(MathF.Floor), new[] { typeof(float) })!, "floor" },
{ typeof(MathF).GetMethod(nameof(MathF.Log), new[] { typeof(float) })!, "ln" },
{ typeof(MathF).GetMethod(nameof(MathF.Log10), new[] { typeof(float) })!, "log10" },
{ typeof(MathF).GetMethod(nameof(MathF.Log2), new[] { typeof(float) })!, "log2" },
{ typeof(MathF).GetMethod(nameof(MathF.Pow), new[] { typeof(float), typeof(float) })!, "pow" },
{ typeof(MathF).GetMethod(nameof(MathF.Round), new[] { typeof(float) })!, "round" },
{ typeof(MathF).GetMethod(nameof(MathF.Round), new[] { typeof(float), typeof(int) })!, "round" },
{ typeof(MathF).GetMethod(nameof(MathF.Sin), new[] { typeof(float) })!, "sin" },
{ typeof(MathF).GetMethod(nameof(MathF.Sinh), new[] { typeof(float) })!, "sinh" },
{ typeof(MathF).GetMethod(nameof(MathF.Sqrt), new[] { typeof(float) })!, "sqrt" },
{ typeof(MathF).GetMethod(nameof(MathF.Tan), new[] { typeof(float) })!, "tan" },
{ typeof(MathF).GetMethod(nameof(MathF.Tanh), new[] { typeof(float) })!, "tanh" },
{ typeof(MathF).GetMethod(nameof(MathF.Truncate), new[] { typeof(float) })!, "trunc" },
{ typeof(float).GetRuntimeMethod(nameof(float.DegreesToRadians), new[] { typeof(float) })!, "radians" },
{ typeof(float).GetRuntimeMethod(nameof(float.RadiansToDegrees), new[] { typeof(float) })!, "degrees" }
};

private static readonly List<MethodInfo> _roundWithDecimalMethods = new()
{
typeof(Math).GetMethod(nameof(Math.Round), new[] { typeof(double), typeof(int) })!,
typeof(MathF).GetMethod(nameof(MathF.Round), new[] { typeof(float), typeof(int) })!
};

private static readonly List<MethodInfo> _logWithBaseMethods = new()
{
typeof(Math).GetMethod(nameof(Math.Log), new[] { typeof(double), typeof(double) })!,
typeof(MathF).GetMethod(nameof(MathF.Log), new[] { typeof(float), typeof(float) })!
};

private readonly ISqlExpressionFactory _sqlExpressionFactory;
Expand Down Expand Up @@ -78,29 +136,45 @@ public SqliteMathTranslator(ISqlExpressionFactory sqlExpressionFactory)
{
if (SupportedMethods.TryGetValue(method, out var sqlFunctionName))
{
RelationalTypeMapping? typeMapping;
List<SqlExpression>? newArguments = null;
if (sqlFunctionName is "max" or "min")
{
typeMapping = ExpressionExtensions.InferTypeMapping(arguments[0], arguments[1]);
newArguments = new List<SqlExpression>
{
_sqlExpressionFactory.ApplyTypeMapping(arguments[0], typeMapping),
_sqlExpressionFactory.ApplyTypeMapping(arguments[1], typeMapping)
};
}
else
{
typeMapping = arguments[0].TypeMapping;
}

var finalArguments = newArguments ?? arguments;
var typeMapping = ExpressionExtensions.InferTypeMapping(arguments.ToArray());
var newArguments = arguments
.Select(a => _sqlExpressionFactory.ApplyTypeMapping(a, typeMapping))
.ToList();

return _sqlExpressionFactory.Function(
sqlFunctionName,
finalArguments,
newArguments,
nullable: true,
argumentsPropagateNullability: newArguments.Select(_ => true).ToList(),
method.ReturnType,
typeMapping);
}
else if (_roundWithDecimalMethods.Contains(method))
{
return _sqlExpressionFactory.Function(
"round",
arguments,
nullable: true,
argumentsPropagateNullability: new[] { true, true },
method.ReturnType,
arguments[0].TypeMapping);

}
else if (_logWithBaseMethods.Contains(method))
{
var a = arguments[0];
var newBase = arguments[1];
var typeMapping = ExpressionExtensions.InferTypeMapping(a, newBase);

return _sqlExpressionFactory.Function(
"log",
new[]
{
_sqlExpressionFactory.ApplyTypeMapping(newBase, typeMapping),
_sqlExpressionFactory.ApplyTypeMapping(a, typeMapping)
},
nullable: true,
argumentsPropagateNullability: finalArguments.Select(_ => true).ToList(),
argumentsPropagateNullability: new[] { true, true },
method.ReturnType,
typeMapping);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics.CodeAnalysis;
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;

namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal;
Expand Down Expand Up @@ -71,11 +72,11 @@ private static readonly IReadOnlyDictionary<ExpressionType, IReadOnlyCollection<
}
};

private static readonly IReadOnlyCollection<Type> FunctionModuloTypes = new HashSet<Type>
private static readonly IReadOnlyDictionary<Type, string> ModuloFunctions = new Dictionary<Type, string>
{
typeof(decimal),
typeof(double),
typeof(float)
{ typeof(decimal), "ef_mod" },
{ typeof(double), "mod" },
{ typeof(float), "mod" }
};

/// <summary>
Expand Down Expand Up @@ -192,11 +193,11 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
if (visitedExpression is SqlBinaryExpression sqlBinary)
{
if (sqlBinary.OperatorType == ExpressionType.Modulo
&& (FunctionModuloTypes.Contains(GetProviderType(sqlBinary.Left))
|| FunctionModuloTypes.Contains(GetProviderType(sqlBinary.Right))))
&& (ModuloFunctions.TryGetValue(GetProviderType(sqlBinary.Left), out var function)
|| ModuloFunctions.TryGetValue(GetProviderType(sqlBinary.Right), out function)))
{
return Dependencies.SqlExpressionFactory.Function(
"ef_mod",
function,
new[] { sqlBinary.Left, sqlBinary.Right },
nullable: true,
argumentsPropagateNullability: new[] { true, true },
Expand Down Expand Up @@ -225,6 +226,7 @@ protected override Expression VisitBinary(BinaryExpression binaryExpression)
return visitedExpression;
}

[return: NotNullIfNotNull(nameof(expression))]
private static Type? GetProviderType(SqlExpression? expression)
=> expression == null
? null
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System.Globalization;
using System.Text.RegularExpressions;
using Microsoft.Data.Sqlite;
using Microsoft.EntityFrameworkCore.Sqlite.Infrastructure.Internal;
Expand Down Expand Up @@ -112,25 +111,9 @@ private void InitializeDbConnection(DbConnection connection)
},
isDeterministic: true);

sqliteConnection.CreateFunction<object, object, object?>(
sqliteConnection.CreateFunction(
"ef_mod",
(dividend, divisor) =>
{
if (dividend == null
|| divisor == null)
{
return null;
}

if (dividend is string s)
{
return decimal.Parse(s, CultureInfo.InvariantCulture)
% Convert.ToDecimal(divisor, CultureInfo.InvariantCulture);
}

return Convert.ToDouble(dividend, CultureInfo.InvariantCulture)
% Convert.ToDouble(divisor, CultureInfo.InvariantCulture);
},
(decimal? dividend, decimal? divisor) => dividend % divisor,
isDeterministic: true);

sqliteConnection.CreateFunction(
Expand Down
3 changes: 1 addition & 2 deletions src/EFCore.Sqlite/EFCore.Sqlite.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@
</ItemGroup>

<ItemGroup>
<!-- TODO: Remove NoWarn NETSDK1206 in Directory.Build.props after updating -->
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.5" />
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.6-pre20230809203314" />
</ItemGroup>

<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ Microsoft.Data.Sqlite.SqliteTransaction</Description>
</ItemGroup>

<ItemGroup>
<PackageReference Include="SQLitePCLRaw.core" Version="2.1.5" />
<PackageReference Include="SQLitePCLRaw.core" Version="2.1.6-pre20230809203314" />
</ItemGroup>

<ItemGroup>
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.Data.Sqlite/Microsoft.Data.Sqlite.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ Microsoft.Data.Sqlite.SqliteTransaction</Description>
</PropertyGroup>

<ItemGroup>
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.5" />
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.6-pre20230809203314" />
</ItemGroup>

<ItemGroup>
Expand Down
2 changes: 1 addition & 1 deletion test/EFCore.Design.Tests/EFCore.Design.Tests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis.CSharp" Version="$(MicrosoftCodeAnalysisVersion)" />
<PackageReference Include="Microsoft.Extensions.DependencyModel" Version="$(MicrosoftExtensionsDependencyModelVersion)" />
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.5" />
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.6-pre20230809203314" />
</ItemGroup>

</Project>
2 changes: 1 addition & 1 deletion test/EFCore.NativeAotTests/EFCore.NativeAotTests.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
<ItemGroup>
<PackageReference Include="Microsoft.Extensions.Configuration.EnvironmentVariables" Version="$(MicrosoftExtensionsConfigurationEnvironmentVariablesVersion)" />
<PackageReference Include="Microsoft.Extensions.Configuration.Json" Version="$(MicrosoftExtensionsConfigurationJsonVersion)" />
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.5" />
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.6-pre20230809203314" />
</ItemGroup>

</Project>
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.5" />
<PackageReference Include="SQLitePCLRaw.bundle_e_sqlite3" Version="2.1.6-pre20230809203314" />
</ItemGroup>

</Project>
Loading

0 comments on commit 874ec4d

Please # to comment.