1+ using System . Data ;
2+ using System . Diagnostics ;
3+ using System . Diagnostics . CodeAnalysis ;
4+ using System . Globalization ;
5+ using System . Linq . Expressions ;
6+ using System . Reflection ;
7+ using System . Text ;
18using Microsoft . EntityFrameworkCore . Query ;
9+ using Microsoft . EntityFrameworkCore . Query . SqlExpressions ;
10+ using Microsoft . EntityFrameworkCore . Storage ;
11+ using ArgumentOutOfRangeException = System . ArgumentOutOfRangeException ;
212
313namespace EfCore . Ydb . Query . Internal ;
414
@@ -9,5 +19,292 @@ QueryableMethodTranslatingExpressionVisitor queryableMethodTranslatingExpression
919) : RelationalSqlTranslatingExpressionVisitor (
1020 dependencies ,
1121 queryCompilationContext ,
12- queryableMethodTranslatingExpressionVisitor
13- ) ;
22+ queryableMethodTranslatingExpressionVisitor )
23+ {
24+ private readonly QueryCompilationContext _queryCompilationContext = queryCompilationContext ;
25+
26+ private readonly YdbSqlExpressionFactory _sqlExpressionFactory =
27+ ( YdbSqlExpressionFactory ) dependencies . SqlExpressionFactory ;
28+
29+ private readonly IRelationalTypeMappingSource _typeMappingSource = dependencies . TypeMappingSource ;
30+
31+
32+ private static readonly MethodInfo StringStartsWithMethod
33+ = typeof ( string ) . GetRuntimeMethod ( nameof ( string . StartsWith ) , [ typeof ( string ) ] ) ! ;
34+
35+ private static readonly MethodInfo StringEndsWithMethod
36+ = typeof ( string ) . GetRuntimeMethod ( nameof ( string . EndsWith ) , [ typeof ( string ) ] ) ! ;
37+
38+ private static readonly MethodInfo StringContainsMethod
39+ = typeof ( string ) . GetRuntimeMethod ( nameof ( string . Contains ) , [ typeof ( string ) ] ) ! ;
40+
41+ private static readonly MethodInfo EscapeLikePatternParameterMethod =
42+ typeof ( YdbSqlTranslatingExpressionVisitor ) . GetTypeInfo ( )
43+ . GetDeclaredMethod ( nameof ( ConstructLikePatternParameter ) ) ! ;
44+
45+
46+ protected override Expression VisitMethodCall ( MethodCallExpression methodCallExpression )
47+ {
48+ var method = methodCallExpression . Method ;
49+
50+ if ( method == StringStartsWithMethod
51+ && TryTranslateStartsEndsWithContains (
52+ methodCallExpression . Object ! ,
53+ methodCallExpression . Arguments [ 0 ] ,
54+ StartsEndsWithContains . StartsWith ,
55+ out var translation1 )
56+ )
57+ {
58+ return translation1 ;
59+ }
60+
61+ if ( method == StringEndsWithMethod
62+ && TryTranslateStartsEndsWithContains (
63+ methodCallExpression . Object ! ,
64+ methodCallExpression . Arguments [ 0 ] ,
65+ StartsEndsWithContains . EndsWith ,
66+ out var translation2 )
67+ )
68+ {
69+ return translation2 ;
70+ }
71+
72+ if ( method == StringContainsMethod
73+ && TryTranslateStartsEndsWithContains (
74+ methodCallExpression . Object ! ,
75+ methodCallExpression . Arguments [ 0 ] ,
76+ StartsEndsWithContains . Contains ,
77+ out var translation3 )
78+ )
79+ {
80+ return translation3 ;
81+ }
82+
83+ return base . VisitMethodCall ( methodCallExpression ) ;
84+ }
85+
86+ private bool TryTranslateStartsEndsWithContains (
87+ Expression instance ,
88+ Expression pattern ,
89+ StartsEndsWithContains methodType ,
90+ [ NotNullWhen ( true ) ] out SqlExpression ? translation
91+ )
92+ {
93+ if ( Visit ( instance ) is not SqlExpression translatedInstance
94+ || Visit ( pattern ) is not SqlExpression translatedPattern )
95+ {
96+ translation = null ;
97+ return false ;
98+ }
99+
100+ var stringTypeMapping = ExpressionExtensions . InferTypeMapping ( translatedInstance , translatedPattern ) ;
101+
102+ // UTF8 is DbType.String whereas STRING is DbType.Binary
103+ var isUtf8 = stringTypeMapping . DbType == DbType . String ;
104+
105+ translatedInstance = _sqlExpressionFactory . ApplyTypeMapping ( translatedInstance , stringTypeMapping ) ;
106+ translatedPattern = _sqlExpressionFactory . ApplyTypeMapping ( translatedPattern , stringTypeMapping ) ;
107+
108+ switch ( translatedPattern )
109+ {
110+ case SqlConstantExpression patternConstant :
111+ {
112+ translation = patternConstant . Value switch
113+ {
114+ null => _sqlExpressionFactory . Like (
115+ translatedInstance ,
116+ _sqlExpressionFactory . Constant ( null , typeof ( string ) , stringTypeMapping )
117+ ) ,
118+ "" => _sqlExpressionFactory . Like ( translatedInstance , _sqlExpressionFactory . Constant ( "%" ) ) ,
119+ string s => _sqlExpressionFactory . Like (
120+ translatedInstance ,
121+ _sqlExpressionFactory . Constant (
122+ methodType switch
123+ {
124+ StartsEndsWithContains . StartsWith => EscapeLikePattern ( s ) + '%' ,
125+ StartsEndsWithContains . EndsWith => '%' + EscapeLikePattern ( s ) ,
126+ StartsEndsWithContains . Contains => $ "%{ EscapeLikePattern ( s ) } %",
127+
128+ _ => throw new ArgumentOutOfRangeException ( nameof ( methodType ) , methodType , null )
129+ } ) ) ,
130+
131+ _ => throw new UnreachableException ( )
132+ } ;
133+
134+ return true ;
135+ }
136+
137+ case SqlParameterExpression patternParameter :
138+ {
139+ var lambda = Expression . Lambda (
140+ Expression . Call (
141+ EscapeLikePatternParameterMethod ,
142+ QueryCompilationContext . QueryContextParameter ,
143+ Expression . Constant ( patternParameter . Name ) ,
144+ Expression . Constant ( methodType ) ) ,
145+ QueryCompilationContext . QueryContextParameter ) ;
146+
147+ var escapedPatternParameter =
148+ _queryCompilationContext . RegisterRuntimeParameter (
149+ $ "{ patternParameter . Name } _{ methodType . ToString ( ) . ToLower ( CultureInfo . InvariantCulture ) } ",
150+ lambda ) ;
151+
152+ translation = _sqlExpressionFactory . Like (
153+ translatedInstance ,
154+ new SqlParameterExpression ( escapedPatternParameter . Name ! , escapedPatternParameter . Type ,
155+ stringTypeMapping ) ) ;
156+
157+ return true ;
158+ }
159+
160+ default :
161+ switch ( methodType )
162+ {
163+ case StartsEndsWithContains . StartsWith or StartsEndsWithContains . EndsWith :
164+ var substringArguments = new SqlExpression [ 3 ] ;
165+ substringArguments [ 0 ] = translatedInstance ;
166+ substringArguments [ 2 ] = _sqlExpressionFactory . Function (
167+ "len" ,
168+ [ translatedPattern ] ,
169+ nullable : true ,
170+ argumentsPropagateNullability : [ true ] ,
171+ typeof ( int )
172+ ) ;
173+
174+ if ( methodType == StartsEndsWithContains . StartsWith )
175+ {
176+ substringArguments [ 1 ] = _sqlExpressionFactory . Constant ( 1 ) ;
177+ }
178+ else
179+ {
180+ substringArguments [ 1 ] = _sqlExpressionFactory . Subtract (
181+ _sqlExpressionFactory . Function (
182+ "len" ,
183+ [ translatedInstance ] ,
184+ nullable : true ,
185+ argumentsPropagateNullability : [ true ] ,
186+ typeof ( int )
187+ ) ,
188+ _sqlExpressionFactory . Function (
189+ "len" ,
190+ [ translatedPattern ] ,
191+ nullable : true ,
192+ argumentsPropagateNullability : [ true ] ,
193+ typeof ( int )
194+ )
195+ ) ;
196+ }
197+
198+ var substringFunction = _sqlExpressionFactory . Function (
199+ "substring" ,
200+ substringArguments ,
201+ nullable : true ,
202+ argumentsPropagateNullability : [ true , false , false ] ,
203+ typeof ( string ) ,
204+ stringTypeMapping
205+ ) ;
206+
207+ translation = _sqlExpressionFactory . AndAlso (
208+ _sqlExpressionFactory . IsNotNull ( translatedInstance ) ,
209+ _sqlExpressionFactory . AndAlso (
210+ _sqlExpressionFactory . IsNotNull ( translatedPattern ) ,
211+ _sqlExpressionFactory . OrElse (
212+ _sqlExpressionFactory . Equal (
213+ isUtf8
214+ ? _sqlExpressionFactory . Function (
215+ "unwrap" ,
216+ [
217+ _sqlExpressionFactory . Convert (
218+ substringFunction ,
219+ typeof ( string ) ,
220+ typeMapping : StringTypeMapping . Default
221+ )
222+ ] ,
223+ nullable : false ,
224+ argumentsPropagateNullability : [ true ] ,
225+ typeof ( string )
226+ )
227+ : substringFunction ,
228+ translatedPattern
229+ ) ,
230+ _sqlExpressionFactory . Equal ( translatedPattern ,
231+ _sqlExpressionFactory . Constant ( string . Empty )
232+ )
233+ )
234+ )
235+ ) ;
236+ break ;
237+ case StartsEndsWithContains . Contains :
238+ translation =
239+ _sqlExpressionFactory . AndAlso (
240+ _sqlExpressionFactory . IsNotNull ( translatedInstance ) ,
241+ _sqlExpressionFactory . AndAlso (
242+ _sqlExpressionFactory . IsNotNull ( translatedPattern ) ,
243+ _sqlExpressionFactory . GreaterThan (
244+ _sqlExpressionFactory . Function (
245+ "strpos" , [ translatedInstance , translatedPattern ] , nullable : true ,
246+ argumentsPropagateNullability : [ true , true ] , typeof ( int ) ) ,
247+ _sqlExpressionFactory . Constant ( 0 ) ) ) ) ;
248+ break ;
249+
250+ default :
251+ throw new UnreachableException ( ) ;
252+ }
253+
254+ return true ;
255+ }
256+ }
257+
258+
259+ public enum StartsEndsWithContains
260+ {
261+ StartsWith ,
262+ EndsWith ,
263+ Contains
264+ }
265+
266+ public static string ? ConstructLikePatternParameter (
267+ QueryContext queryContext ,
268+ string baseParameterName ,
269+ StartsEndsWithContains methodType
270+ )
271+ => queryContext . ParameterValues [ baseParameterName ] switch
272+ {
273+ null => null ,
274+
275+ // In .NET, all strings start/end with the empty string, but SQL LIKE return false for empty patterns.
276+ // Return % which always matches instead.
277+ "" => "%" ,
278+
279+ string s => methodType switch
280+ {
281+ StartsEndsWithContains . StartsWith => EscapeLikePattern ( s ) + '%' ,
282+ StartsEndsWithContains . EndsWith => '%' + EscapeLikePattern ( s ) ,
283+ StartsEndsWithContains . Contains => $ "%{ EscapeLikePattern ( s ) } %",
284+ _ => throw new ArgumentOutOfRangeException ( nameof ( methodType ) , methodType , null )
285+ } ,
286+
287+ _ => throw new UnreachableException ( )
288+ } ;
289+
290+ private const char LikeEscapeChar = '\\ ' ;
291+
292+ private static bool IsLikeWildChar ( char c )
293+ => c is '%' or '_' ;
294+
295+ private static string EscapeLikePattern ( string pattern )
296+ {
297+ var builder = new StringBuilder ( ) ;
298+ foreach ( var c in pattern )
299+ {
300+ if ( IsLikeWildChar ( c ) || c == LikeEscapeChar )
301+ {
302+ builder . Append ( LikeEscapeChar ) ;
303+ }
304+
305+ builder . Append ( c ) ;
306+ }
307+
308+ return builder . ToString ( ) ;
309+ }
310+ }
0 commit comments