Skip to content

Commit c8c9b55

Browse files
committed
Tests [2/?]
Added several functions to work with math and date
1 parent d50f218 commit c8c9b55

35 files changed

+1653
-74
lines changed
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
using System;
2+
using System.Diagnostics;
3+
using System.Diagnostics.CodeAnalysis;
4+
using System.Reflection;
5+
using EntityFrameworkCore.Ydb.Storage.Internal;
6+
using EntityFrameworkCore.Ydb.Storage.Internal.Mapping;
7+
using EntityFrameworkCore.Ydb.Utilities;
8+
using Microsoft.EntityFrameworkCore;
9+
using Microsoft.EntityFrameworkCore.Diagnostics;
10+
using Microsoft.EntityFrameworkCore.Query;
11+
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
12+
using Microsoft.EntityFrameworkCore.Storage;
13+
14+
namespace EntityFrameworkCore.Ydb.Query.Internal.Translators;
15+
16+
public class YdbDateTimeMemberTranslator : IMemberTranslator
17+
{
18+
private readonly IRelationalTypeMappingSource _typeMappingSource;
19+
private readonly YdbSqlExpressionFactory _sqlExpressionFactory;
20+
private readonly RelationalTypeMapping _timestampMapping;
21+
22+
public YdbDateTimeMemberTranslator(
23+
IRelationalTypeMappingSource typeMappingSource,
24+
YdbSqlExpressionFactory sqlExpressionFactory
25+
)
26+
{
27+
_typeMappingSource = typeMappingSource;
28+
_timestampMapping = typeMappingSource.FindMapping(typeof(DateTime), "TimeStamp")!;
29+
_sqlExpressionFactory = sqlExpressionFactory;
30+
}
31+
32+
public virtual SqlExpression? Translate(
33+
SqlExpression? instance,
34+
MemberInfo member,
35+
Type returnType,
36+
IDiagnosticsLogger<DbLoggerCategory.Query> logger
37+
)
38+
{
39+
var declaringType = member.DeclaringType;
40+
41+
if (declaringType == typeof(TimeOnly))
42+
{
43+
throw new InvalidOperationException("Ydb doesn't support TimeOnly right now");
44+
}
45+
46+
if (declaringType != typeof(DateTime) && declaringType != typeof(DateOnly))
47+
{
48+
return null;
49+
}
50+
51+
if (member.Name == nameof(DateTime.Date))
52+
{
53+
switch (instance)
54+
{
55+
case { TypeMapping: YdbDateTimeTypeMapping }:
56+
case { Type: var type } when type == typeof(DateTime):
57+
return _sqlExpressionFactory.Convert(
58+
_sqlExpressionFactory.Convert(instance, typeof(DateOnly)),
59+
typeof(DateTime)
60+
);
61+
case { TypeMapping: YdbDateOnlyTypeMapping }:
62+
case { Type: var type } when type == typeof(DateOnly):
63+
return instance;
64+
default:
65+
return null;
66+
}
67+
}
68+
69+
return member.Name switch
70+
{
71+
// TODO: Find out how to add
72+
// nameof(DateTime.Now) => ???,
73+
// nameof(DateTime.Today) => ???
74+
75+
nameof(DateTime.UtcNow) => UtcNow(),
76+
77+
nameof(DateTime.Year) => DatePart(instance!, "GetYear"),
78+
nameof(DateTime.Month) => DatePart(instance!, "GetMonth"),
79+
nameof(DateTime.Day) => DatePart(instance!, "GetDay"),
80+
nameof(DateTime.Hour) => DatePart(instance!, "GetHour"),
81+
nameof(DateTime.Minute) => DatePart(instance!, "GetMinute"),
82+
nameof(DateTime.Second) => DatePart(instance!, "GetSecond"),
83+
nameof(DateTime.Millisecond) => DatePart(instance!, "GetMillisecondOfSecond"),
84+
85+
nameof(DateTime.DayOfYear) => DatePart(instance!, "GetDayOfYear"),
86+
nameof(DateTime.DayOfWeek) => DatePart(instance!, "GetDayOfWeek"),
87+
88+
// TODO: Research if it's possible to implement
89+
nameof(DateTime.Ticks) => null,
90+
_ => null
91+
};
92+
93+
SqlExpression UtcNow()
94+
=> _sqlExpressionFactory.Function(
95+
"CurrentUtc" + returnType.Name == "DateOnly" ? "Date" : returnType.Name,
96+
[],
97+
nullable: false,
98+
argumentsPropagateNullability: ArrayUtil.TrueArrays[0],
99+
returnType,
100+
_typeMappingSource.FindMapping(returnType)
101+
);
102+
}
103+
104+
private SqlExpression? DatePart(SqlExpression instance, string partName)
105+
{
106+
var result = _sqlExpressionFactory.Function(
107+
$"DateTime::{partName}",
108+
[instance],
109+
nullable: true,
110+
argumentsPropagateNullability: ArrayUtil.TrueArrays[1],
111+
typeof(short) // Doesn't matter because we cast it to int in next line anyway
112+
);
113+
114+
return _sqlExpressionFactory.Convert(result, typeof(int));
115+
}
116+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq.Expressions;
4+
using System.Reflection;
5+
using EntityFrameworkCore.Ydb.Utilities;
6+
using Microsoft.EntityFrameworkCore;
7+
using Microsoft.EntityFrameworkCore.Diagnostics;
8+
using Microsoft.EntityFrameworkCore.Query;
9+
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
10+
using Microsoft.EntityFrameworkCore.Storage;
11+
12+
namespace EntityFrameworkCore.Ydb.Query.Internal.Translators;
13+
14+
public class YdbDateTimeMethodTranslator : IMethodCallTranslator
15+
{
16+
private static readonly Dictionary<MethodInfo, string> MethodInfoDatePartMapping = new()
17+
{
18+
{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddYears), [typeof(int)])!, " Years" },
19+
{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddMonths), [typeof(int)])!, " Months" },
20+
{ typeof(DateOnly).GetRuntimeMethod(nameof(DateOnly.AddDays), [typeof(int)])!, " Days" },
21+
22+
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddYears), [typeof(int)])!, "Years" },
23+
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMonths), [typeof(int)])!, "Months" },
24+
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddDays), [typeof(double)])!, "Days" },
25+
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddHours), [typeof(double)])!, "Hours" },
26+
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddMinutes), [typeof(double)])!, "Mins" },
27+
{ typeof(DateTime).GetRuntimeMethod(nameof(DateTime.AddSeconds), [typeof(double)])!, "Secs" },
28+
29+
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddYears), [typeof(int)])!, "Years" },
30+
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMonths), [typeof(int)])!, "Months" },
31+
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddDays), [typeof(double)])!, "Days" },
32+
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddHours), [typeof(double)])!, "Hours" },
33+
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddMinutes), [typeof(double)])!, "Mins" },
34+
{ typeof(DateTimeOffset).GetRuntimeMethod(nameof(DateTimeOffset.AddSeconds), [typeof(double)])!, "Secs" },
35+
};
36+
37+
private readonly YdbSqlExpressionFactory _sqlExpressionFactory;
38+
39+
public YdbDateTimeMethodTranslator(YdbSqlExpressionFactory sqlExpressionFactory)
40+
{
41+
_sqlExpressionFactory = sqlExpressionFactory;
42+
}
43+
44+
45+
public virtual SqlExpression? Translate(
46+
SqlExpression? instance,
47+
MethodInfo method,
48+
IReadOnlyList<SqlExpression> arguments,
49+
IDiagnosticsLogger<DbLoggerCategory.Query> logger
50+
) => TranslateDatePart(instance, method, arguments);
51+
52+
private SqlExpression? TranslateDatePart(
53+
SqlExpression? instance,
54+
MethodInfo method,
55+
IReadOnlyList<SqlExpression> arguments
56+
)
57+
{
58+
if (
59+
instance is not null
60+
&& MethodInfoDatePartMapping.TryGetValue(method, out var datePart))
61+
{
62+
var shiftDatePartFunction = _sqlExpressionFactory.Function(
63+
"DateTime::Shift" + datePart,
64+
[instance, arguments[0]],
65+
nullable: true,
66+
ArrayUtil.TrueArrays[2],
67+
returnType: typeof(DateTime)
68+
);
69+
70+
return _sqlExpressionFactory.Function(
71+
"DateTime::MakeDate",
72+
arguments: [shiftDatePartFunction],
73+
nullable: true,
74+
ArrayUtil.TrueArrays[1],
75+
returnType: typeof(DateTime)
76+
);
77+
}
78+
79+
return null;
80+
}
81+
}
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Reflection;
5+
using EntityFrameworkCore.Ydb.Utilities;
6+
using Microsoft.EntityFrameworkCore;
7+
using Microsoft.EntityFrameworkCore.Diagnostics;
8+
using Microsoft.EntityFrameworkCore.Query;
9+
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
10+
using ExpressionExtensions = Microsoft.EntityFrameworkCore.Query.ExpressionExtensions;
11+
12+
namespace EntityFrameworkCore.Ydb.Query.Internal.Translators;
13+
14+
public class YdbMathTranslator : IMethodCallTranslator
15+
{
16+
private static readonly Dictionary<MethodInfo, string> SupportedMethods = new()
17+
{
18+
{ typeof(Math).GetMethod(nameof(Math.Abs), [typeof(double)])!, "Abs" },
19+
{ typeof(Math).GetMethod(nameof(Math.Abs), [typeof(float)])!, "Abs" },
20+
{ typeof(Math).GetMethod(nameof(Math.Abs), [typeof(int)])!, "Abs" },
21+
{ typeof(Math).GetMethod(nameof(Math.Abs), [typeof(long)])!, "Abs" },
22+
{ typeof(Math).GetMethod(nameof(Math.Abs), [typeof(sbyte)])!, "Abs" },
23+
{ typeof(Math).GetMethod(nameof(Math.Abs), [typeof(short)])!, "Abs" },
24+
{ typeof(Math).GetMethod(nameof(Math.Acos), [typeof(double)])!, "Acos" },
25+
{ typeof(Math).GetMethod(nameof(Math.Acosh), [typeof(double)])!, "Acosh" },
26+
{ typeof(Math).GetMethod(nameof(Math.Asin), [typeof(double)])!, "Asin" },
27+
{ typeof(Math).GetMethod(nameof(Math.Asinh), [typeof(double)])!, "Asinh" },
28+
{ typeof(Math).GetMethod(nameof(Math.Atan), [typeof(double)])!, "Atan" },
29+
{ typeof(Math).GetMethod(nameof(Math.Atan2), [typeof(double), typeof(double)])!, "Atan2" },
30+
{ typeof(Math).GetMethod(nameof(Math.Atanh), [typeof(double)])!, "Atanh" },
31+
{ typeof(Math).GetMethod(nameof(Math.Ceiling), [typeof(double)])!, "Ceil" },
32+
{ typeof(Math).GetMethod(nameof(Math.Cos), [typeof(double)])!, "Cos" },
33+
{ typeof(Math).GetMethod(nameof(Math.Cosh), [typeof(double)])!, "Cosh" },
34+
{ typeof(Math).GetMethod(nameof(Math.Exp), [typeof(double)])!, "Exp" },
35+
{ typeof(Math).GetMethod(nameof(Math.Floor), [typeof(double)])!, "Floor" },
36+
{ typeof(Math).GetMethod(nameof(Math.Log), [typeof(double)])!, "Log" },
37+
{ typeof(Math).GetMethod(nameof(Math.Log2), [typeof(double)])!, "Log2" },
38+
{ typeof(Math).GetMethod(nameof(Math.Log10), [typeof(double)])!, "Log10" },
39+
{ typeof(Math).GetMethod(nameof(Math.Pow), [typeof(double), typeof(double)])!, "Pow" },
40+
{ typeof(Math).GetMethod(nameof(Math.Round), [typeof(double)])!, "Round" },
41+
{ typeof(Math).GetMethod(nameof(Math.Sign), [typeof(double)])!, "Sign" },
42+
{ typeof(Math).GetMethod(nameof(Math.Sign), [typeof(float)])!, "Sign" },
43+
{ typeof(Math).GetMethod(nameof(Math.Sign), [typeof(long)])!, "Sign" },
44+
{ typeof(Math).GetMethod(nameof(Math.Sign), [typeof(sbyte)])!, "Sign" },
45+
{ typeof(Math).GetMethod(nameof(Math.Sign), [typeof(short)])!, "Sign" },
46+
{ typeof(Math).GetMethod(nameof(Math.Sin), [typeof(double)])!, "Sin" },
47+
{ typeof(Math).GetMethod(nameof(Math.Sinh), [typeof(double)])!, "Sinh" },
48+
{ typeof(Math).GetMethod(nameof(Math.Sqrt), [typeof(double)])!, "Sqrt" },
49+
{ typeof(Math).GetMethod(nameof(Math.Tan), [typeof(double)])!, "Tan" },
50+
{ typeof(Math).GetMethod(nameof(Math.Tanh), [typeof(double)])!, "Tanh" },
51+
{ typeof(Math).GetMethod(nameof(Math.Truncate), [typeof(double)])!, "Trunc" },
52+
{ typeof(MathF).GetMethod(nameof(MathF.Acos), [typeof(float)])!, "Acos" },
53+
{ typeof(MathF).GetMethod(nameof(MathF.Acosh), [typeof(float)])!, "Acosh" },
54+
{ typeof(MathF).GetMethod(nameof(MathF.Asin), [typeof(float)])!, "Asin" },
55+
{ typeof(MathF).GetMethod(nameof(MathF.Asinh), [typeof(float)])!, "Asinh" },
56+
{ typeof(MathF).GetMethod(nameof(MathF.Atan), [typeof(float)])!, "Atan" },
57+
{ typeof(MathF).GetMethod(nameof(MathF.Atan2), [typeof(float), typeof(float)])!, "Atan2" },
58+
{ typeof(MathF).GetMethod(nameof(MathF.Atanh), [typeof(float)])!, "Atanh" },
59+
{ typeof(MathF).GetMethod(nameof(MathF.Ceiling), [typeof(float)])!, "Ceil" },
60+
{ typeof(MathF).GetMethod(nameof(MathF.Cos), [typeof(float)])!, "Cos" },
61+
{ typeof(MathF).GetMethod(nameof(MathF.Cosh), [typeof(float)])!, "Cosh" },
62+
{ typeof(MathF).GetMethod(nameof(MathF.Exp), [typeof(float)])!, "Exp" },
63+
{ typeof(MathF).GetMethod(nameof(MathF.Floor), [typeof(float)])!, "Floor" },
64+
{ typeof(MathF).GetMethod(nameof(MathF.Log), [typeof(float)])!, "Log" },
65+
{ typeof(MathF).GetMethod(nameof(MathF.Log10), [typeof(float)])!, "Log10" },
66+
{ typeof(MathF).GetMethod(nameof(MathF.Log2), [typeof(float)])!, "Log2" },
67+
{ typeof(MathF).GetMethod(nameof(MathF.Pow), [typeof(float), typeof(float)])!, "Pow" },
68+
{ typeof(MathF).GetMethod(nameof(MathF.Round), [typeof(float)])!, "Round" },
69+
{ typeof(MathF).GetMethod(nameof(MathF.Sin), [typeof(float)])!, "Sin" },
70+
{ typeof(MathF).GetMethod(nameof(MathF.Sinh), [typeof(float)])!, "Sinh" },
71+
{ typeof(MathF).GetMethod(nameof(MathF.Sqrt), [typeof(float)])!, "Sqrt" },
72+
{ typeof(MathF).GetMethod(nameof(MathF.Tan), [typeof(float)])!, "Tan" },
73+
{ typeof(MathF).GetMethod(nameof(MathF.Tanh), [typeof(float)])!, "Tanh" },
74+
{ typeof(MathF).GetMethod(nameof(MathF.Truncate), [typeof(float)])!, "Trunc" },
75+
};
76+
77+
private static readonly List<MethodInfo> _roundWithDecimalMethods =
78+
[
79+
typeof(Math).GetMethod(nameof(Math.Round), [typeof(double), typeof(int)])!,
80+
typeof(MathF).GetMethod(nameof(MathF.Round), [typeof(float), typeof(int)])!
81+
];
82+
83+
private static readonly List<MethodInfo> _logWithBaseMethods =
84+
[
85+
typeof(Math).GetMethod(nameof(Math.Log), [typeof(double), typeof(double)])!,
86+
typeof(MathF).GetMethod(nameof(MathF.Log), [typeof(float), typeof(float)])!
87+
];
88+
89+
private readonly ISqlExpressionFactory _sqlExpressionFactory;
90+
91+
public YdbMathTranslator(ISqlExpressionFactory sqlExpressionFactory)
92+
=> _sqlExpressionFactory = sqlExpressionFactory;
93+
94+
public virtual SqlExpression? Translate(
95+
SqlExpression? instance,
96+
MethodInfo method,
97+
IReadOnlyList<SqlExpression> arguments,
98+
IDiagnosticsLogger<DbLoggerCategory.Query> logger
99+
)
100+
{
101+
if (SupportedMethods.TryGetValue(method, out var sqlFunctionName))
102+
{
103+
var typeMapping = ExpressionExtensions.InferTypeMapping(arguments.ToArray());
104+
var newArguments = arguments
105+
.Select(a => _sqlExpressionFactory.ApplyTypeMapping(a, typeMapping))
106+
.ToList();
107+
108+
return _sqlExpressionFactory.Function(
109+
"Math::" + sqlFunctionName,
110+
newArguments,
111+
nullable: true,
112+
argumentsPropagateNullability: newArguments.Select(_ => true).ToList(),
113+
method.ReturnType,
114+
typeMapping
115+
);
116+
}
117+
118+
if (_roundWithDecimalMethods.Contains(method))
119+
{
120+
return _sqlExpressionFactory.Function(
121+
"Math::Round",
122+
arguments,
123+
nullable: true,
124+
argumentsPropagateNullability: ArrayUtil.TrueArrays[2],
125+
method.ReturnType,
126+
arguments[0].TypeMapping);
127+
}
128+
129+
return null;
130+
}
131+
}

src/EFCore.Ydb/src/Query/Internal/YdbMemberTranslatorProvider.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
using EntityFrameworkCore.Ydb.Query.Internal.Translators;
22
using Microsoft.EntityFrameworkCore.Query;
3+
using Microsoft.EntityFrameworkCore.Storage;
34

45
namespace EntityFrameworkCore.Ydb.Query.Internal;
56

67
public sealed class YdbMemberTranslatorProvider : RelationalMemberTranslatorProvider
78
{
8-
public YdbMemberTranslatorProvider(RelationalMemberTranslatorProviderDependencies dependencies) : base(dependencies)
9+
public YdbMemberTranslatorProvider(
10+
RelationalMemberTranslatorProviderDependencies dependencies,
11+
IRelationalTypeMappingSource typeMappingSource
12+
) : base(dependencies)
913
{
14+
var sqlExpressionFactory = (YdbSqlExpressionFactory)dependencies.SqlExpressionFactory;
15+
1016
AddTranslators(
1117
[
12-
new StubTranslator()
18+
new YdbDateTimeMemberTranslator(typeMappingSource, sqlExpressionFactory),
1319
]
1420
);
1521
}

src/EFCore.Ydb/src/Query/Internal/YdbMethodCallTranslatorProvider.cs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,21 @@
11
using EntityFrameworkCore.Ydb.Query.Internal.Translators;
2+
using EntityFrameworkCore.Ydb.Storage.Internal;
23
using Microsoft.EntityFrameworkCore.Query;
34

45
namespace EntityFrameworkCore.Ydb.Query.Internal;
56

67
public sealed class YdbMethodCallTranslatorProvider : RelationalMethodCallTranslatorProvider
78
{
8-
public YdbMethodCallTranslatorProvider(
9-
RelationalMethodCallTranslatorProviderDependencies dependencies
10-
) : base(dependencies)
9+
public YdbMethodCallTranslatorProvider(RelationalMethodCallTranslatorProviderDependencies dependencies) :
10+
base(dependencies)
1111
{
12+
var sqlExpressionFactory = (YdbSqlExpressionFactory)dependencies.SqlExpressionFactory;
13+
1214
AddTranslators(
1315
[
14-
new StubTranslator()
16+
new YdbDateTimeMethodTranslator(sqlExpressionFactory),
17+
new YdbMathTranslator(sqlExpressionFactory),
18+
new YdbMathTranslator(sqlExpressionFactory)
1519
]
1620
);
1721
}

0 commit comments

Comments
 (0)