@@ -55,14 +55,16 @@ use url::Url;
5555use uuid:: Uuid ;
5656
5757use crate :: catalog:: { PyCatalog , RustWrappedPyCatalogProvider } ;
58+ use crate :: common:: data_type:: PyScalarValue ;
5859use crate :: dataframe:: PyDataFrame ;
5960use crate :: dataset:: Dataset ;
60- use crate :: errors:: { py_datafusion_err, PyDataFusionResult } ;
61+ use crate :: errors:: { py_datafusion_err, PyDataFusionError , PyDataFusionResult } ;
6162use crate :: expr:: sort_expr:: PySortExpr ;
6263use crate :: physical_plan:: PyExecutionPlan ;
6364use crate :: record_batch:: PyRecordBatchStream ;
6465use crate :: sql:: exceptions:: py_value_err;
6566use crate :: sql:: logical:: PyLogicalPlan ;
67+ use crate :: sql:: util:: replace_placeholders_with_table_names;
6668use crate :: store:: StorageContexts ;
6769use crate :: table:: { PyTable , TempViewTable } ;
6870use crate :: udaf:: PyAggregateUDF ;
@@ -422,27 +424,54 @@ impl PySessionContext {
422424 self . ctx . register_udtf ( & name, func) ;
423425 }
424426
425- /// Returns a PyDataFrame whose plan corresponds to the SQL statement.
426- pub fn sql ( & self , query : & str , py : Python ) -> PyDataFusionResult < PyDataFrame > {
427- let result = self . ctx . sql ( query) ;
428- let df = wait_for_future ( py, result) ??;
429- Ok ( PyDataFrame :: new ( df) )
430- }
431-
432- #[ pyo3( signature = ( query, options=None ) ) ]
427+ #[ pyo3( signature = ( query, options=None , scalar_params=vec![ ] , dataframe_params=vec![ ] ) ) ]
433428 pub fn sql_with_options (
434429 & self ,
430+ py : Python ,
435431 query : & str ,
436432 options : Option < PySQLOptions > ,
437- py : Python ,
433+ scalar_params : Vec < ( String , PyScalarValue ) > ,
434+ dataframe_params : Vec < ( String , PyDataFrame ) > ,
438435 ) -> PyDataFusionResult < PyDataFrame > {
439436 let options = if let Some ( options) = options {
440437 options. options
441438 } else {
442439 SQLOptions :: new ( )
443440 } ;
444- let result = self . ctx . sql_with_options ( query, options) ;
445- let df = wait_for_future ( py, result) ??;
441+
442+ let scalar_params = scalar_params
443+ . into_iter ( )
444+ . map ( |( name, value) | ( name, ScalarValue :: from ( value) ) )
445+ . collect :: < Vec < _ > > ( ) ;
446+
447+ let dataframe_params = dataframe_params
448+ . into_iter ( )
449+ . map ( |( name, df) | {
450+ let uuid = Uuid :: new_v4 ( ) . to_string ( ) . replace ( "-" , "" ) ;
451+ let view_name = format ! ( "view_{uuid}" ) ;
452+
453+ self . create_temporary_view ( py, view_name. as_str ( ) , df, true ) ?;
454+ Ok ( ( name, view_name) )
455+ } )
456+ . collect :: < Result < HashMap < _ , _ > , PyDataFusionError > > ( ) ?;
457+
458+ let state = self . ctx . state ( ) ;
459+ let dialect = state. config ( ) . options ( ) . sql_parser . dialect . as_str ( ) ;
460+
461+ let query = replace_placeholders_with_table_names ( query, dialect, dataframe_params) ?;
462+
463+ println ! ( "using scalar params: {scalar_params:?}" ) ;
464+ let df = wait_for_future ( py, async {
465+ self . ctx
466+ . sql_with_options ( & query, options)
467+ . await
468+ . map_err ( |err| {
469+ println ! ( "error before param replacement: {}" , err) ;
470+ err
471+ } ) ?
472+ . with_param_values ( scalar_params)
473+ } ) ??;
474+
446475 Ok ( PyDataFrame :: new ( df) )
447476 }
448477
0 commit comments