Skip to content

Commit 06dda23

Browse files
committed
Rework to tokenize the SQL query instead of using string manipulation
1 parent 79ec803 commit 06dda23

File tree

5 files changed

+186
-33
lines changed

5 files changed

+186
-33
lines changed

python/datafusion/context.py

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
except ImportError:
2828
from typing_extensions import deprecated # Python 3.12
2929

30-
import uuid
3130

3231
import pyarrow as pa
3332

@@ -612,25 +611,41 @@ def sql(
612611
Returns:
613612
DataFrame representation of the SQL query.
614613
"""
615-
if named_params:
616-
for alias, param in named_params.items():
617-
if isinstance(param, DataFrame):
618-
view_name = str(uuid.uuid4()).replace("-", "_")
619-
view_name = f"view_{view_name}"
620-
self.ctx.create_temporary_view(
621-
view_name, param.df, replace_if_exists=True
622-
)
623-
replace_str = view_name
624-
else:
625-
replace_str = str(param)
626614

627-
query = query.replace(f"{{{alias}}}", replace_str)
615+
def scalar_params(**p: Any) -> list[tuple[str, pa.Scalar]]:
616+
if p is None:
617+
return []
618+
619+
return [
620+
(name, pa.scalar(value))
621+
for (name, value) in p.items()
622+
if not isinstance(value, DataFrame)
623+
]
628624

629-
if options is None:
630-
return DataFrame(self.ctx.sql(query))
631-
return DataFrame(self.ctx.sql_with_options(query, options.options_internal))
625+
def dataframe_params(**p: Any) -> list[tuple[str, DataFrame]]:
626+
if p is None:
627+
return []
632628

633-
def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
629+
return [
630+
(name, value.df)
631+
for (name, value) in p.items()
632+
if isinstance(value, DataFrame)
633+
]
634+
635+
options_raw = options.options_internal if options is not None else None
636+
637+
return DataFrame(
638+
self.ctx.sql_with_options(
639+
query,
640+
options=options_raw,
641+
scalar_params=scalar_params(**named_params),
642+
dataframe_params=dataframe_params(**named_params),
643+
)
644+
)
645+
646+
def sql_with_options(
647+
self, query: str, options: SQLOptions, **named_params: Any
648+
) -> DataFrame:
634649
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
635650
636651
This function will first validate that the query is allowed by the
@@ -639,11 +654,12 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
639654
Args:
640655
query: SQL query text.
641656
options: SQL options.
657+
named_params: Provides substitution in the query string.
642658
643659
Returns:
644660
DataFrame representation of the SQL query.
645661
"""
646-
return self.sql(query, options)
662+
return self.sql(query, options, **named_params)
647663

648664
def create_dataframe(
649665
self,

python/tests/test_sql.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pyarrow as pa
2222
import pyarrow.dataset as ds
2323
import pytest
24-
from datafusion import col, udf
24+
from datafusion import SessionContext, col, udf
2525
from datafusion.object_store import Http
2626
from pyarrow.csv import write_csv
2727

@@ -552,11 +552,25 @@ def test_register_listing_table(
552552
assert dict(zip(rd["grp"], rd["count"], strict=False)) == {"a": 3, "b": 2}
553553

554554

555-
def test_parameterized_sql(ctx, tmp_path) -> None:
555+
def test_parameterized_df_in_sql(ctx, tmp_path) -> None:
    """A DataFrame passed as a named parameter is usable as a table via ``$name``."""
    parquet_path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
    source_df = ctx.read_parquet(parquet_path)

    # ``$replaced_df`` in the query text refers to the ``replaced_df`` keyword argument.
    batches = ctx.sql(
        "SELECT COUNT(a) AS cnt FROM $replaced_df", replaced_df=source_df
    ).collect()

    table = pa.Table.from_batches(batches)
    assert table.to_pydict() == {"cnt": [100]}
564+
565+
566+
def test_parameterized_pass_through_in_sql(ctx: SessionContext) -> None:
    """A non-DataFrame named parameter is usable as a ``$name`` value in the query."""
    # This exercises the parameters that should be handled by the SQL parser
    # itself rather than by our token-based rewriting of the query string.
    values = pa.array([1, 2, 3, 4])
    batch = pa.RecordBatch.from_arrays([values], names=["a"])
    ctx.register_record_batches("t", [[batch]])

    filtered = ctx.sql("SELECT a FROM t WHERE a < $val", val=3)
    assert filtered.to_pydict() == {"a": [1, 2]}

src/context.rs

Lines changed: 41 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -55,14 +55,16 @@ use url::Url;
5555
use uuid::Uuid;
5656

5757
use crate::catalog::{PyCatalog, RustWrappedPyCatalogProvider};
58+
use crate::common::data_type::PyScalarValue;
5859
use crate::dataframe::PyDataFrame;
5960
use crate::dataset::Dataset;
60-
use crate::errors::{py_datafusion_err, PyDataFusionResult};
61+
use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
6162
use crate::expr::sort_expr::PySortExpr;
6263
use crate::physical_plan::PyExecutionPlan;
6364
use crate::record_batch::PyRecordBatchStream;
6465
use crate::sql::exceptions::py_value_err;
6566
use crate::sql::logical::PyLogicalPlan;
67+
use crate::sql::util::replace_placeholders_with_table_names;
6668
use crate::store::StorageContexts;
6769
use crate::table::{PyTable, TempViewTable};
6870
use crate::udaf::PyAggregateUDF;
@@ -422,27 +424,54 @@ impl PySessionContext {
422424
self.ctx.register_udtf(&name, func);
423425
}
424426

425-
/// Returns a PyDataFrame whose plan corresponds to the SQL statement.
426-
pub fn sql(&self, query: &str, py: Python) -> PyDataFusionResult<PyDataFrame> {
427-
let result = self.ctx.sql(query);
428-
let df = wait_for_future(py, result)??;
429-
Ok(PyDataFrame::new(df))
430-
}
431-
432-
#[pyo3(signature = (query, options=None))]
427+
#[pyo3(signature = (query, options=None, scalar_params=vec![], dataframe_params=vec![]))]
433428
pub fn sql_with_options(
434429
&self,
430+
py: Python,
435431
query: &str,
436432
options: Option<PySQLOptions>,
437-
py: Python,
433+
scalar_params: Vec<(String, PyScalarValue)>,
434+
dataframe_params: Vec<(String, PyDataFrame)>,
438435
) -> PyDataFusionResult<PyDataFrame> {
439436
let options = if let Some(options) = options {
440437
options.options
441438
} else {
442439
SQLOptions::new()
443440
};
444-
let result = self.ctx.sql_with_options(query, options);
445-
let df = wait_for_future(py, result)??;
441+
442+
let scalar_params = scalar_params
443+
.into_iter()
444+
.map(|(name, value)| (name, ScalarValue::from(value)))
445+
.collect::<Vec<_>>();
446+
447+
let dataframe_params = dataframe_params
448+
.into_iter()
449+
.map(|(name, df)| {
450+
let uuid = Uuid::new_v4().to_string().replace("-", "");
451+
let view_name = format!("view_{uuid}");
452+
453+
self.create_temporary_view(py, view_name.as_str(), df, true)?;
454+
Ok((name, view_name))
455+
})
456+
.collect::<Result<HashMap<_, _>, PyDataFusionError>>()?;
457+
458+
let state = self.ctx.state();
459+
let dialect = state.config().options().sql_parser.dialect.as_str();
460+
461+
let query = replace_placeholders_with_table_names(query, dialect, dataframe_params)?;
462+
463+
println!("using scalar params: {scalar_params:?}");
464+
let df = wait_for_future(py, async {
465+
self.ctx
466+
.sql_with_options(&query, options)
467+
.await
468+
.map_err(|err| {
469+
println!("error before param replacement: {}", err);
470+
err
471+
})?
472+
.with_param_values(scalar_params)
473+
})??;
474+
446475
Ok(PyDataFrame::new(df))
447476
}
448477

src/sql.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717

1818
pub mod exceptions;
1919
pub mod logical;
20+
pub(crate) mod util;

src/sql/util.rs

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
use datafusion::common::{internal_err, plan_datafusion_err, DataFusionError};
2+
use datafusion::logical_expr::sqlparser::dialect::dialect_from_str;
3+
use datafusion::sql::sqlparser::dialect::Dialect;
4+
use datafusion::sql::sqlparser::keywords::Keyword;
5+
use datafusion::sql::sqlparser::parser::Parser;
6+
use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer, Word};
7+
use std::collections::HashMap;
8+
9+
fn value_from_replacements(
10+
placeholder: &str,
11+
replacements: &HashMap<String, String>,
12+
) -> Option<Token> {
13+
if let Some(pattern) = placeholder.strip_prefix("$") {
14+
replacements.get(pattern).map(|replacement| {
15+
Token::Word(Word {
16+
value: replacement.to_owned(),
17+
quote_style: None,
18+
keyword: Keyword::NoKeyword,
19+
})
20+
})
21+
} else {
22+
None
23+
}
24+
}
25+
26+
fn table_names_are_valid(dialect: &dyn Dialect, replacements: &HashMap<String, String>) -> bool {
27+
for name in replacements.values() {
28+
let tokens = Tokenizer::new(dialect, name).tokenize().unwrap();
29+
if tokens.len() != 1 {
30+
// We should get exactly one token for our temporary table name
31+
return false;
32+
}
33+
34+
if let Token::Word(word) = &tokens[0] {
35+
// Generated table names should be not quoted or have keywords
36+
if word.quote_style.is_some() || word.keyword != Keyword::NoKeyword {
37+
return false;
38+
}
39+
} else {
40+
// We should always parse table names to a Word
41+
return false;
42+
}
43+
}
44+
45+
true
46+
}
47+
48+
pub(crate) fn replace_placeholders_with_table_names(
49+
query: &str,
50+
dialect: &str,
51+
replacements: HashMap<String, String>,
52+
) -> Result<String, DataFusionError> {
53+
let dialect = dialect_from_str(dialect).ok_or_else(|| {
54+
plan_datafusion_err!(
55+
"Unsupported SQL dialect: {dialect}. Available dialects: \
56+
Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \
57+
MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks."
58+
)
59+
})?;
60+
61+
if !table_names_are_valid(dialect.as_ref(), &replacements) {
62+
return internal_err!("Invalid generated table name when replacing placeholders");
63+
}
64+
let tokens = Tokenizer::new(dialect.as_ref(), query).tokenize().unwrap();
65+
66+
let replaced_tokens = tokens
67+
.into_iter()
68+
.map(|token| {
69+
if let Token::Word(word) = &token {
70+
let Word {
71+
value,
72+
quote_style: _,
73+
keyword: _,
74+
} = word;
75+
76+
value_from_replacements(value, &replacements).unwrap_or(token)
77+
} else if let Token::Placeholder(placeholder) = &token {
78+
value_from_replacements(placeholder, &replacements).unwrap_or(token)
79+
} else {
80+
token
81+
}
82+
})
83+
.collect::<Vec<Token>>();
84+
85+
Ok(Parser::new(dialect.as_ref())
86+
.with_tokens(replaced_tokens)
87+
.parse_statements()
88+
.map_err(|err| DataFusionError::External(Box::new(err)))?
89+
.into_iter()
90+
.map(|s| s.to_string())
91+
.collect::<Vec<_>>()
92+
.join(" "))
93+
}

0 commit comments

Comments
 (0)