|
1 | | -use datafusion::common::{internal_err, plan_datafusion_err, DataFusionError}; |
| 1 | +use datafusion::common::{exec_err, plan_datafusion_err, DataFusionError}; |
2 | 2 | use datafusion::logical_expr::sqlparser::dialect::dialect_from_str; |
3 | 3 | use datafusion::sql::sqlparser::dialect::Dialect; |
4 | | -use datafusion::sql::sqlparser::keywords::Keyword; |
5 | 4 | use datafusion::sql::sqlparser::parser::Parser; |
6 | | -use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer, Word}; |
| 5 | +use datafusion::sql::sqlparser::tokenizer::{Token, Tokenizer}; |
7 | 6 | use std::collections::HashMap; |
8 | 7 |
|
9 | | -fn value_from_replacements( |
| 8 | +fn tokens_from_replacements( |
10 | 9 | placeholder: &str, |
11 | | - replacements: &HashMap<String, String>, |
12 | | -) -> Option<Token> { |
| 10 | + replacements: &HashMap<String, Vec<Token>>, |
| 11 | +) -> Option<Vec<Token>> { |
13 | 12 | if let Some(pattern) = placeholder.strip_prefix("$") { |
14 | | - replacements.get(pattern).map(|replacement| { |
15 | | - Token::Word(Word { |
16 | | - value: replacement.to_owned(), |
17 | | - quote_style: None, |
18 | | - keyword: Keyword::NoKeyword, |
19 | | - }) |
20 | | - }) |
| 13 | + replacements.get(pattern).cloned() |
21 | 14 | } else { |
22 | 15 | None |
23 | 16 | } |
24 | 17 | } |
25 | 18 |
|
26 | | -fn table_names_are_valid(dialect: &dyn Dialect, replacements: &HashMap<String, String>) -> bool { |
27 | | - for name in replacements.values() { |
28 | | - let tokens = Tokenizer::new(dialect, name).tokenize().unwrap(); |
29 | | - if tokens.len() != 1 { |
30 | | - // We should get exactly one token for our temporary table name |
31 | | - return false; |
32 | | - } |
33 | | - |
34 | | - if let Token::Word(word) = &tokens[0] { |
35 | | - // Generated table names should be not quoted or have keywords |
36 | | - if word.quote_style.is_some() || word.keyword != Keyword::NoKeyword { |
37 | | - return false; |
38 | | - } |
39 | | - } else { |
40 | | - // We should always parse table names to a Word |
41 | | - return false; |
42 | | - } |
43 | | - } |
44 | | - |
45 | | - true |
| 19 | +fn get_tokens_for_string_replacement( |
| 20 | + dialect: &dyn Dialect, |
| 21 | + replacements: HashMap<String, String>, |
| 22 | +) -> Result<HashMap<String, Vec<Token>>, DataFusionError> { |
| 23 | + replacements |
| 24 | + .into_iter() |
| 25 | + .map(|(name, value)| { |
| 26 | + let tokens = Tokenizer::new(dialect, &value) |
| 27 | + .tokenize() |
| 28 | + .map_err(|err| DataFusionError::External(err.into()))?; |
| 29 | + Ok((name, tokens)) |
| 30 | + }) |
| 31 | + .collect() |
46 | 32 | } |
47 | 33 |
|
48 | | -pub(crate) fn replace_placeholders_with_table_names( |
| 34 | +pub(crate) fn replace_placeholders_with_strings( |
49 | 35 | query: &str, |
50 | 36 | dialect: &str, |
51 | 37 | replacements: HashMap<String, String>, |
52 | 38 | ) -> Result<String, DataFusionError> { |
53 | | - let dialect = dialect_from_str(dialect).ok_or_else(|| { |
54 | | - plan_datafusion_err!( |
55 | | - "Unsupported SQL dialect: {dialect}. Available dialects: \ |
56 | | - Generic, MySQL, PostgreSQL, Hive, SQLite, Snowflake, Redshift, \ |
57 | | - MsSQL, ClickHouse, BigQuery, Ansi, DuckDB, Databricks." |
58 | | - ) |
59 | | - })?; |
| 39 | + let dialect = dialect_from_str(dialect) |
| 40 | + .ok_or_else(|| plan_datafusion_err!("Unsupported SQL dialect: {dialect}."))?; |
60 | 41 |
|
61 | | - if !table_names_are_valid(dialect.as_ref(), &replacements) { |
62 | | - return internal_err!("Invalid generated table name when replacing placeholders"); |
63 | | - } |
64 | | - let tokens = Tokenizer::new(dialect.as_ref(), query).tokenize().unwrap(); |
| 42 | + let replacements = get_tokens_for_string_replacement(dialect.as_ref(), replacements)?; |
| 43 | + |
| 44 | + let tokens = Tokenizer::new(dialect.as_ref(), query) |
| 45 | + .tokenize() |
| 46 | + .map_err(|err| DataFusionError::External(err.into()))?; |
65 | 47 |
|
66 | 48 | let replaced_tokens = tokens |
67 | 49 | .into_iter() |
68 | | - .map(|token| { |
69 | | - if let Token::Word(word) = &token { |
70 | | - let Word { |
71 | | - value, |
72 | | - quote_style: _, |
73 | | - keyword: _, |
74 | | - } = word; |
75 | | - |
76 | | - value_from_replacements(value, &replacements).unwrap_or(token) |
77 | | - } else if let Token::Placeholder(placeholder) = &token { |
78 | | - value_from_replacements(placeholder, &replacements).unwrap_or(token) |
| 50 | + .flat_map(|token| { |
| 51 | + if let Token::Placeholder(placeholder) = &token { |
| 52 | + tokens_from_replacements(placeholder, &replacements).unwrap_or(vec![token]) |
79 | 53 | } else { |
80 | | - token |
| 54 | + vec![token] |
81 | 55 | } |
82 | 56 | }) |
83 | 57 | .collect::<Vec<Token>>(); |
84 | 58 |
|
85 | | - Ok(Parser::new(dialect.as_ref()) |
| 59 | + let statement = Parser::new(dialect.as_ref()) |
86 | 60 | .with_tokens(replaced_tokens) |
87 | 61 | .parse_statements() |
88 | | - .map_err(|err| DataFusionError::External(Box::new(err)))? |
89 | | - .into_iter() |
90 | | - .map(|s| s.to_string()) |
91 | | - .collect::<Vec<_>>() |
92 | | - .join(" ")) |
| 62 | + .map_err(|err| DataFusionError::External(Box::new(err)))?; |
| 63 | + |
| 64 | + if statement.len() != 1 { |
| 65 | + return exec_err!("placeholder replacement should return exactly one statement"); |
| 66 | + } |
| 67 | + |
| 68 | + Ok(statement[0].to_string()) |
93 | 69 | } |
0 commit comments