Skip to content

Commit 7226089

Browse files
authored
Use DuckDB parser to qualify WHERE (#11)
1 parent 6001370 commit 7226089

File tree

5 files changed

+269
-9
lines changed

5 files changed

+269
-9
lines changed

include/yardstick_ffi.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,6 +252,8 @@ char* yardstick_replace_range(
252252
const char* replacement
253253
);
254254

255+
char* yardstick_qualify_expression(const char* expr, const char* qualifier);
256+
255257
/**
256258
* Free a string allocated by yardstick functions.
257259
*/

src/yardstick_extension.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ extern "C" {
2121
void yardstick_free_create_view_info(YardstickCreateViewInfo* info);
2222
char* yardstick_replace_range(const char* sql, uint32_t start, uint32_t end, const char* replacement);
2323
char* yardstick_apply_replacements(const char* sql, const YardstickReplacement* replacements, size_t count);
24+
char* yardstick_qualify_expression(const char* expr, const char* qualifier);
2425
void yardstick_free_string(char* ptr);
2526
char* yardstick_expand_aggregate_call(
2627
const char* measure_name,
@@ -59,6 +60,7 @@ extern "C" {
5960
void (*free_create_view_info)(YardstickCreateViewInfo*),
6061
char* (*replace_range)(const char*, uint32_t, uint32_t, const char*),
6162
char* (*apply_replacements)(const char*, const YardstickReplacement*, size_t),
63+
char* (*qualify_expression)(const char*, const char*),
6264
void (*free_string)(char*),
6365
char* (*expand_aggregate_call)(const char*, const char*, const YardstickAtModifier*, size_t, const char*, const char*, const char*, const char* const*, size_t)
6466
);
@@ -469,6 +471,7 @@ static void LoadInternal(ExtensionLoader &loader) {
469471
yardstick_free_create_view_info,
470472
yardstick_replace_range,
471473
yardstick_apply_replacements,
474+
yardstick_qualify_expression,
472475
yardstick_free_string,
473476
yardstick_expand_aggregate_call
474477
);

src/yardstick_parser_ffi.cpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ static void FindAggregateCalls(ParsedExpression* expr, std::vector<AggregateCall
8787
static void CollectTablesFromTableRef(TableRef* ref, std::vector<YardstickTableRef>& tables);
8888
static bool ExpressionContainsAggregate(ParsedExpression* expr);
8989
static bool ExpressionContainsMeasureRef(ParsedExpression* expr);
90+
static void QualifyColumnRefs(ParsedExpression* expr, const std::string& qualifier);
9091

9192
//=============================================================================
9293
// AST Walking: Find AGGREGATE() function calls
@@ -458,6 +459,95 @@ static bool ExpressionContainsMeasureRef(ParsedExpression* expr) {
458459
}
459460
}
460461

462+
/**
 * Recursively prefix unqualified column references in `expr` with `qualifier`.
 *
 * A column reference is rewritten only when it consists of a single name
 * (e.g. `year` becomes `<qualifier>.year`); references that already carry
 * more than one name part are left untouched. The tree is modified in place.
 *
 * Expression classes without a case below (the `default` branch) are not
 * descended into, so column references nested inside them stay as written.
 *
 * @param expr       Root of the (sub)tree to rewrite; a null pointer is a no-op.
 * @param qualifier  Name inserted in front of every bare column name.
 */
static void QualifyColumnRefs(ParsedExpression* expr, const std::string& qualifier) {
    if (!expr) return;

    switch (expr->expression_class) {
        case ExpressionClass::COLUMN_REF: {
            auto* col = static_cast<ColumnRefExpression*>(expr);
            // Only bare names are qualified; anything with 2+ parts is kept as-is.
            if (col->column_names.size() == 1) {
                col->column_names.insert(col->column_names.begin(), qualifier);
            }
            break;
        }
        case ExpressionClass::FUNCTION: {
            auto* func = static_cast<FunctionExpression*>(expr);
            for (auto& child : func->children) {
                QualifyColumnRefs(child.get(), qualifier);
            }
            if (func->filter) {
                QualifyColumnRefs(func->filter.get(), qualifier);
            }
            // Ordered aggregates (e.g. string_agg(x ORDER BY y)) keep their sort
            // keys outside `children`; qualify them as well so the whole call
            // refers to the same relation.
            if (func->order_bys) {
                for (auto& order : func->order_bys->orders) {
                    QualifyColumnRefs(order.expression.get(), qualifier);
                }
            }
            break;
        }
        case ExpressionClass::COMPARISON: {
            auto* comp = static_cast<ComparisonExpression*>(expr);
            QualifyColumnRefs(comp->left.get(), qualifier);
            QualifyColumnRefs(comp->right.get(), qualifier);
            break;
        }
        case ExpressionClass::CONJUNCTION: {
            auto* conj = static_cast<ConjunctionExpression*>(expr);
            for (auto& child : conj->children) {
                QualifyColumnRefs(child.get(), qualifier);
            }
            break;
        }
        case ExpressionClass::OPERATOR: {
            auto* op = static_cast<OperatorExpression*>(expr);
            for (auto& child : op->children) {
                QualifyColumnRefs(child.get(), qualifier);
            }
            break;
        }
        case ExpressionClass::CASE: {
            auto* case_expr = static_cast<CaseExpression*>(expr);
            for (auto& check : case_expr->case_checks) {
                QualifyColumnRefs(check.when_expr.get(), qualifier);
                QualifyColumnRefs(check.then_expr.get(), qualifier);
            }
            if (case_expr->else_expr) {
                QualifyColumnRefs(case_expr->else_expr.get(), qualifier);
            }
            break;
        }
        case ExpressionClass::CAST: {
            auto* cast = static_cast<CastExpression*>(expr);
            QualifyColumnRefs(cast->child.get(), qualifier);
            break;
        }
        case ExpressionClass::SUBQUERY: {
            auto* subq = static_cast<SubqueryExpression*>(expr);
            // Only the expression attached to the subquery node is rewritten;
            // the nested SELECT itself is not visited here.
            // NOTE(review): `child` is presumably the left operand of IN/ANY/ALL - confirm.
            if (subq->child) {
                QualifyColumnRefs(subq->child.get(), qualifier);
            }
            break;
        }
        case ExpressionClass::WINDOW: {
            auto* window = static_cast<WindowExpression*>(expr);
            for (auto& child : window->children) {
                QualifyColumnRefs(child.get(), qualifier);
            }
            for (auto& part : window->partitions) {
                QualifyColumnRefs(part.get(), qualifier);
            }
            // ORDER BY keys of the OVER clause.
            for (auto& order : window->orders) {
                QualifyColumnRefs(order.expression.get(), qualifier);
            }
            if (window->filter_expr) {
                QualifyColumnRefs(window->filter_expr.get(), qualifier);
            }
            // Frame bounds and LEAD/LAG offset/default may be arbitrary
            // expressions; each call is a no-op when the pointer is null.
            QualifyColumnRefs(window->start_expr.get(), qualifier);
            QualifyColumnRefs(window->end_expr.get(), qualifier);
            QualifyColumnRefs(window->offset_expr.get(), qualifier);
            QualifyColumnRefs(window->default_expr.get(), qualifier);
            break;
        }
        case ExpressionClass::BETWEEN: {
            auto* between = static_cast<BetweenExpression*>(expr);
            QualifyColumnRefs(between->input.get(), qualifier);
            QualifyColumnRefs(between->lower.get(), qualifier);
            QualifyColumnRefs(between->upper.get(), qualifier);
            break;
        }
        default:
            // Constants, parameters and any class not listed above are left unchanged.
            break;
    }
}
550+
461551
//=============================================================================
462552
// FFI Implementation: yardstick_find_aggregates
463553
//=============================================================================
@@ -919,6 +1009,34 @@ extern "C" char* yardstick_replace_range(
9191009
return safe_strdup(result);
9201010
}
9211011

1012+
extern "C" char* yardstick_qualify_expression(const char* expr_str, const char* qualifier) {
1013+
if (!expr_str || !qualifier) return nullptr;
1014+
1015+
try {
1016+
auto expressions = Parser::ParseExpressionList(expr_str);
1017+
if (expressions.empty()) {
1018+
return safe_strdup(expr_str);
1019+
}
1020+
1021+
for (auto& expr : expressions) {
1022+
QualifyColumnRefs(expr.get(), qualifier);
1023+
}
1024+
1025+
if (expressions.size() == 1) {
1026+
return safe_strdup(expressions[0]->ToString());
1027+
}
1028+
1029+
std::string result;
1030+
for (size_t i = 0; i < expressions.size(); i++) {
1031+
if (i > 0) result += ", ";
1032+
result += expressions[i]->ToString();
1033+
}
1034+
return safe_strdup(result);
1035+
} catch (const std::exception&) {
1036+
return nullptr;
1037+
}
1038+
}
1039+
9221040
//=============================================================================
9231041
// FFI Implementation: yardstick_free_string
9241042
//=============================================================================

yardstick-rs/src/parser_ffi.rs

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,7 @@ type FnParseCreateView = unsafe extern "C" fn(*const c_char) -> *mut YardstickCr
168168
type FnFreeCreateViewInfo = unsafe extern "C" fn(*mut YardstickCreateViewInfo);
169169
type FnReplaceRange = unsafe extern "C" fn(*const c_char, u32, u32, *const c_char) -> *mut c_char;
170170
type FnApplyReplacements = unsafe extern "C" fn(*const c_char, *const YardstickReplacement, usize) -> *mut c_char;
171+
type FnQualifyExpression = unsafe extern "C" fn(*const c_char, *const c_char) -> *mut c_char;
171172
type FnFreeString = unsafe extern "C" fn(*mut c_char);
172173
type FnExpandAggregateCall = unsafe extern "C" fn(
173174
*const c_char, *const c_char, *const YardstickAtModifier, usize,
@@ -185,6 +186,7 @@ static FN_PARSE_CREATE_VIEW: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
185186
static FN_FREE_CREATE_VIEW_INFO: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
186187
static FN_REPLACE_RANGE: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
187188
static FN_APPLY_REPLACEMENTS: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
189+
static FN_QUALIFY_EXPRESSION: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
188190
static FN_FREE_STRING: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
189191
static FN_EXPAND_AGGREGATE_CALL: AtomicPtr<()> = AtomicPtr::new(ptr::null_mut());
190192

@@ -201,6 +203,7 @@ pub extern "C" fn yardstick_init_parser_ffi(
201203
free_create_view_info: FnFreeCreateViewInfo,
202204
replace_range: FnReplaceRange,
203205
apply_replacements: FnApplyReplacements,
206+
qualify_expression: FnQualifyExpression,
204207
free_string: FnFreeString,
205208
expand_aggregate_call: FnExpandAggregateCall,
206209
) {
@@ -214,6 +217,7 @@ pub extern "C" fn yardstick_init_parser_ffi(
214217
FN_FREE_CREATE_VIEW_INFO.store(free_create_view_info as *mut (), Ordering::SeqCst);
215218
FN_REPLACE_RANGE.store(replace_range as *mut (), Ordering::SeqCst);
216219
FN_APPLY_REPLACEMENTS.store(apply_replacements as *mut (), Ordering::SeqCst);
220+
FN_QUALIFY_EXPRESSION.store(qualify_expression as *mut (), Ordering::SeqCst);
217221
FN_FREE_STRING.store(free_string as *mut (), Ordering::SeqCst);
218222
FN_EXPAND_AGGREGATE_CALL.store(expand_aggregate_call as *mut (), Ordering::SeqCst);
219223
}
@@ -790,6 +794,27 @@ pub fn apply_replacements(sql: &str, replacements: &[Replacement]) -> Result<Str
790794
}
791795
}
792796

797+
pub fn qualify_expression(expr: &str, qualifier: &str) -> Result<String, String> {
798+
let expr_ptr = CString::new(expr).map_err(|e| format!("Invalid expression string: {e}"))?;
799+
let qualifier_ptr = CString::new(qualifier).map_err(|e| format!("Invalid qualifier: {e}"))?;
800+
801+
let fn_ptr = FN_QUALIFY_EXPRESSION.load(Ordering::SeqCst);
802+
if fn_ptr.is_null() {
803+
return Err("Parser FFI not initialized".to_string());
804+
}
805+
806+
unsafe {
807+
let f: FnQualifyExpression = std::mem::transmute(fn_ptr);
808+
let result_ptr = f(expr_ptr.as_ptr(), qualifier_ptr.as_ptr());
809+
if result_ptr.is_null() {
810+
return Err("Failed to qualify expression".to_string());
811+
}
812+
let result = c_str_to_string(result_ptr).unwrap_or_default();
813+
yardstick_free_string(result_ptr);
814+
Ok(result)
815+
}
816+
}
817+
793818
/// Expand a single AGGREGATE() call to SQL.
794819
///
795820
/// Generates a correlated subquery for the measure based on the aggregation function
@@ -987,6 +1012,17 @@ mod tests {
9871012
assert_eq!(result, "SELECT baz FROM bar");
9881013
}
9891014

1015+
#[test]
#[ignore = "requires C++ library to be linked"]
fn test_qualify_expression() {
    // The bare `year` column should gain the `_inner.` prefix; the expected
    // text is the parser's rendering of the BETWEEN expression.
    let input = "year between date '2023-01-01' and date '2025-01-01'";
    let expected = "_inner.year BETWEEN DATE '2023-01-01' AND DATE '2025-01-01'";

    let qualified = qualify_expression(input, "_inner").unwrap();

    assert_eq!(qualified, expected);
}
1025+
9901026
#[test]
9911027
#[ignore = "requires C++ library to be linked"]
9921028
fn test_expand_aggregate_call() {

0 commit comments

Comments
 (0)