Skip to content

Commit 69a9ede

Browse files
committed
migrate assert_optimized_plan_equal in extract_equijoin_predicate.rs to use snapshot assertions
1 parent 6ba4152 commit 69a9ede

File tree

1 file changed

+81
-52
lines changed

1 file changed

+81
-52
lines changed

datafusion/optimizer/src/extract_equijoin_predicate.rs

Lines changed: 81 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -155,21 +155,26 @@ fn split_eq_and_noneq_join_predicate(
155155
#[cfg(test)]
156156
mod tests {
157157
use super::*;
158+
use crate::assert_optimized_plan_eq_display_indent_snapshot;
158159
use crate::test::*;
159160
use arrow::datatypes::DataType;
160161
use datafusion_expr::{
161162
col, lit, logical_plan::builder::LogicalPlanBuilder, JoinType,
162163
};
163164
use std::sync::Arc;
164165

165-
fn assert_plan_eq(plan: LogicalPlan, expected: &str) -> Result<()> {
166-
assert_optimized_plan_eq_display_indent(
167-
Arc::new(ExtractEquijoinPredicate {}),
168-
plan,
169-
expected,
170-
);
171-
172-
Ok(())
166+
macro_rules! assert_optimized_plan_equal {
167+
(
168+
$plan:expr,
169+
@ $expected:literal $(,)?
170+
) => {{
171+
let rule: Arc<dyn crate::OptimizerRule + Send + Sync> = Arc::new(ExtractEquijoinPredicate {});
172+
assert_optimized_plan_eq_display_indent_snapshot!(
173+
rule,
174+
$plan,
175+
@ $expected,
176+
)
177+
}};
173178
}
174179

175180
#[test]
@@ -180,11 +185,14 @@ mod tests {
180185
let plan = LogicalPlanBuilder::from(t1)
181186
.join_on(t2, JoinType::Left, Some(col("t1.a").eq(col("t2.a"))))?
182187
.build()?;
183-
let expected = "Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
184-
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
185-
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
186-
187-
assert_plan_eq(plan, expected)
188+
assert_optimized_plan_equal!(
189+
plan,
190+
@r"
191+
Left Join: t1.a = t2.a [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
192+
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
193+
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
194+
"
195+
)
188196
}
189197

190198
#[test]
@@ -199,11 +207,14 @@ mod tests {
199207
Some((col("t1.a") + lit(10i64)).eq(col("t2.a") * lit(2u32))),
200208
)?
201209
.build()?;
202-
let expected = "Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
203-
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
204-
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
205-
206-
assert_plan_eq(plan, expected)
210+
assert_optimized_plan_equal!(
211+
plan,
212+
@r"
213+
Left Join: t1.a + Int64(10) = t2.a * UInt32(2) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
214+
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
215+
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
216+
"
217+
)
207218
}
208219

209220
#[test]
@@ -222,11 +233,14 @@ mod tests {
222233
),
223234
)?
224235
.build()?;
225-
let expected = "Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
226-
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
227-
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
228-
229-
assert_plan_eq(plan, expected)
236+
assert_optimized_plan_equal!(
237+
plan,
238+
@r"
239+
Left Join: Filter: t1.a + Int64(10) >= t2.a * UInt32(2) AND t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
240+
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
241+
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
242+
"
243+
)
230244
}
231245

232246
#[test]
@@ -249,11 +263,14 @@ mod tests {
249263
),
250264
)?
251265
.build()?;
252-
let expected = "Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
253-
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
254-
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
255-
256-
assert_plan_eq(plan, expected)
266+
assert_optimized_plan_equal!(
267+
plan,
268+
@r"
269+
Left Join: t1.a + UInt32(11) = t2.a * UInt32(2), t1.a + Int64(10) = t2.a * UInt32(2) Filter: t1.b < Int32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
270+
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
271+
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
272+
"
273+
)
257274
}
258275

259276
#[test]
@@ -275,11 +292,14 @@ mod tests {
275292
),
276293
)?
277294
.build()?;
278-
let expected = "Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
279-
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
280-
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
281-
282-
assert_plan_eq(plan, expected)
295+
assert_optimized_plan_equal!(
296+
plan,
297+
@r"
298+
Left Join: t1.a = t2.a, t1.b = t2.b Filter: t1.c = t2.c OR t1.a + t1.b > t2.b + t2.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
299+
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
300+
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
301+
"
302+
)
283303
}
284304

285305
#[test]
@@ -310,13 +330,16 @@ mod tests {
310330
),
311331
)?
312332
.build()?;
313-
let expected = "Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
314-
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
315-
\n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
316-
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
317-
\n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
318-
319-
assert_plan_eq(plan, expected)
333+
assert_optimized_plan_equal!(
334+
plan,
335+
@r"
336+
Left Join: t1.a = t2.a Filter: t1.c + t2.c + t3.c < UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
337+
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
338+
Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
339+
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
340+
TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
341+
"
342+
)
320343
}
321344

322345
#[test]
@@ -343,13 +366,16 @@ mod tests {
343366
Some(col("t1.a").eq(col("t2.a")).and(col("t2.c").eq(col("t3.c")))),
344367
)?
345368
.build()?;
346-
let expected = "Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
347-
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
348-
\n Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
349-
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]\
350-
\n TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]";
351-
352-
assert_plan_eq(plan, expected)
369+
assert_optimized_plan_equal!(
370+
plan,
371+
@r"
372+
Left Join: t1.a = t2.a Filter: t2.c = t3.c [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N, a:UInt32;N, b:UInt32;N, c:UInt32;N]
373+
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
374+
Left Join: t2.a = t3.a Filter: t2.a + t3.b > UInt32(100) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
375+
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
376+
TableScan: t3 [a:UInt32, b:UInt32, c:UInt32]
377+
"
378+
)
353379
}
354380

355381
#[test]
@@ -369,10 +395,13 @@ mod tests {
369395
let plan = LogicalPlanBuilder::from(t1)
370396
.join_on(t2, JoinType::Left, Some(filter))?
371397
.build()?;
372-
let expected = "Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]\
373-
\n TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]\
374-
\n TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]";
375-
376-
assert_plan_eq(plan, expected)
398+
assert_optimized_plan_equal!(
399+
plan,
400+
@r"
401+
Left Join: t1.a + CAST(Int64(1) AS UInt32) = t2.a + CAST(Int32(2) AS UInt32) [a:UInt32, b:UInt32, c:UInt32, a:UInt32;N, b:UInt32;N, c:UInt32;N]
402+
TableScan: t1 [a:UInt32, b:UInt32, c:UInt32]
403+
TableScan: t2 [a:UInt32, b:UInt32, c:UInt32]
404+
"
405+
)
377406
}
378407
}

0 commit comments

Comments
 (0)