@@ -155,21 +155,26 @@ fn split_eq_and_noneq_join_predicate(
155155#[ cfg( test) ]
156156mod tests {
157157 use super :: * ;
158+ use crate :: assert_optimized_plan_eq_display_indent_snapshot;
158159 use crate :: test:: * ;
159160 use arrow:: datatypes:: DataType ;
160161 use datafusion_expr:: {
161162 col, lit, logical_plan:: builder:: LogicalPlanBuilder , JoinType ,
162163 } ;
163164 use std:: sync:: Arc ;
164165
165- fn assert_plan_eq ( plan : LogicalPlan , expected : & str ) -> Result < ( ) > {
166- assert_optimized_plan_eq_display_indent (
167- Arc :: new ( ExtractEquijoinPredicate { } ) ,
168- plan,
169- expected,
170- ) ;
171-
172- Ok ( ( ) )
166+ macro_rules! assert_optimized_plan_equal {
167+ (
168+ $plan: expr,
169+ @ $expected: literal $( , ) ?
170+ ) => { {
171+ let rule: Arc <dyn crate :: OptimizerRule + Send + Sync > = Arc :: new( ExtractEquijoinPredicate { } ) ;
172+ assert_optimized_plan_eq_display_indent_snapshot!(
173+ rule,
174+ $plan,
175+ @ $expected,
176+ )
177+ } } ;
173178 }
174179
175180 #[ test]
@@ -180,11 +185,14 @@ mod tests {
180185 let plan = LogicalPlanBuilder :: from ( t1)
181186 . join_on ( t2, JoinType :: Left , Some ( col ( "t1.a" ) . eq ( col ( "t2.a" ) ) ) ) ?
182187 . build ( ) ?;
183- let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
184- \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
185- \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
186-
187- assert_plan_eq ( plan, expected)
188+ assert_optimized_plan_equal ! (
189+ plan,
190+ @r"
191+ Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
192+ TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
193+ TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
194+ "
195+ )
188196 }
189197
190198 #[ test]
@@ -199,11 +207,14 @@ mod tests {
199207 Some ( ( col ( "t1.a" ) + lit ( 10i64 ) ) . eq ( col ( "t2.a" ) * lit ( 2u32 ) ) ) ,
200208 ) ?
201209 . build ( ) ?;
202- let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
203- \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
204- \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
205-
206- assert_plan_eq ( plan, expected)
210+ assert_optimized_plan_equal ! (
211+ plan,
212+ @r"
213+ Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
214+ TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
215+ TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
216+ "
217+ )
207218 }
208219
209220 #[ test]
@@ -222,11 +233,14 @@ mod tests {
222233 ) ,
223234 ) ?
224235 . build ( ) ?;
225- let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
226- \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
227- \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
228-
229- assert_plan_eq ( plan, expected)
236+ assert_optimized_plan_equal ! (
237+ plan,
238+ @r"
239+ Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
240+ TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
241+ TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
242+ "
243+ )
230244 }
231245
232246 #[ test]
@@ -249,11 +263,14 @@ mod tests {
249263 ) ,
250264 ) ?
251265 . build ( ) ?;
252- let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
253- \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
254- \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
255-
256- assert_plan_eq ( plan, expected)
266+ assert_optimized_plan_equal ! (
267+ plan,
268+ @r"
269+ Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
270+ TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
271+ TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
272+ "
273+ )
257274 }
258275
259276 #[ test]
@@ -275,11 +292,14 @@ mod tests {
275292 ) ,
276293 ) ?
277294 . build ( ) ?;
278- let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
279- \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
280- \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
281-
282- assert_plan_eq ( plan, expected)
295+ assert_optimized_plan_equal ! (
296+ plan,
297+ @r"
298+ Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
299+ TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
300+ TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
301+ "
302+ )
283303 }
284304
285305 #[ test]
@@ -310,13 +330,16 @@ mod tests {
310330 ) ,
311331 ) ?
312332 . build ( ) ?;
313- let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
314- \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
315- \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
316- \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
317- \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
318-
319- assert_plan_eq ( plan, expected)
333+ assert_optimized_plan_equal ! (
334+ plan,
335+ @r"
336+ Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
337+ TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
338+ Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
339+ TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
340+ TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
341+ "
342+ )
320343 }
321344
322345 #[ test]
@@ -343,13 +366,16 @@ mod tests {
343366 Some ( col ( "t1.a" ) . eq ( col ( "t2.a" ) ) . and ( col ( "t2.c" ) . eq ( col ( "t3.c" ) ) ) ) ,
344367 ) ?
345368 . build ( ) ?;
346- let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
347- \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
348- \n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
349- \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
350- \n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
351-
352- assert_plan_eq ( plan, expected)
369+ assert_optimized_plan_equal ! (
370+ plan,
371+ @r"
372+ Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
373+ TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
374+ Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
375+ TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
376+ TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
377+ "
378+ )
353379 }
354380
355381 #[ test]
@@ -369,10 +395,13 @@ mod tests {
369395 let plan = LogicalPlanBuilder :: from ( t1)
370396 . join_on ( t2, JoinType :: Left , Some ( filter) ) ?
371397 . build ( ) ?;
372- let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
373- \n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
374- \n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
375-
376- assert_plan_eq ( plan, expected)
398+ assert_optimized_plan_equal ! (
399+ plan,
400+ @r"
401+ Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
402+ TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
403+ TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
404+ "
405+ )
377406 }
378407}
0 commit comments