Skip to content

Commit 0e78092

Browse files
committed
refactor(odbc): enhance OdbcConnection and SQL execution handling
This commit refactors the OdbcConnection structure to utilize a new type for prepared statements, improving type safety and clarity. It also modifies the execute_sql function to handle both prepared and non-prepared SQL statements through a new MaybePrepared enum, streamlining execution logic. Additionally, the prepare method is updated to cache prepared statements more effectively, enhancing performance.
1 parent 6148a21 commit 0e78092

File tree

2 files changed

+78
-83
lines changed

2 files changed

+78
-83
lines changed

sqlx-core/src/odbc/connection/mod.rs

Lines changed: 46 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -9,27 +9,27 @@ use either::Either;
99
mod odbc_bridge;
1010
use futures_core::future::BoxFuture;
1111
use futures_util::future;
12+
use odbc_api::ConnectionTransitions;
1213
use odbc_bridge::{establish_connection, execute_sql};
1314
// no direct spawn_blocking here; use run_blocking helper
1415
use crate::odbc::{OdbcStatement, OdbcStatementMetadata};
15-
use odbc_api::ResultSetMetadata;
16+
use odbc_api::{handles::StatementConnection, Prepared, ResultSetMetadata, SharedConnection};
1617
use std::borrow::Cow;
1718
use std::collections::HashMap;
1819
use std::sync::Arc;
1920

20-
fn collect_columns(
21-
prepared: &mut odbc_api::Prepared<odbc_api::handles::StatementImpl<'_>>,
22-
) -> Vec<OdbcColumn> {
21+
mod executor;
22+
23+
type PreparedStatement = Prepared<StatementConnection<SharedConnection<'static>>>;
24+
25+
fn collect_columns(prepared: &mut PreparedStatement) -> Vec<OdbcColumn> {
2326
let count = prepared.num_result_cols().unwrap_or(0);
2427
(1..=count)
2528
.map(|i| create_column(prepared, i as u16))
2629
.collect()
2730
}
2831

29-
fn create_column(
30-
stmt: &mut odbc_api::Prepared<odbc_api::handles::StatementImpl<'_>>,
31-
index: u16,
32-
) -> OdbcColumn {
32+
fn create_column(stmt: &mut PreparedStatement, index: u16) -> OdbcColumn {
3333
let mut cd = odbc_api::ColumnDescription::default();
3434
let _ = stmt.describe_col(index, &mut cd);
3535

@@ -44,16 +44,21 @@ fn decode_column_name(name_bytes: Vec<u8>, index: u16) -> String {
4444
String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1))
4545
}
4646

47-
mod executor;
48-
4947
/// A connection to an ODBC-accessible database.
5048
///
5149
/// ODBC uses a blocking C API, so we offload blocking calls to the runtime's blocking
5250
/// thread-pool via `spawn_blocking` and synchronize access with a mutex.
53-
#[derive(Debug)]
5451
pub struct OdbcConnection {
55-
pub(crate) conn: odbc_api::SharedConnection<'static>,
56-
pub(crate) stmt_cache: HashMap<u64, crate::odbc::statement::OdbcStatementMetadata>,
52+
pub(crate) conn: SharedConnection<'static>,
53+
pub(crate) stmt_cache: HashMap<Arc<str>, PreparedStatement>,
54+
}
55+
56+
impl std::fmt::Debug for OdbcConnection {
57+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58+
f.debug_struct("OdbcConnection")
59+
.field("conn", &self.conn)
60+
.finish()
61+
}
5762
}
5863

5964
impl OdbcConnection {
@@ -139,11 +144,16 @@ impl OdbcConnection {
139144
args: Option<OdbcArguments>,
140145
) -> Result<flume::Receiver<Result<Either<OdbcQueryResult, OdbcRow>, Error>>, Error> {
141146
let (tx, rx) = flume::bounded(64);
142-
let sql = sql.to_string();
143-
let args_move = args;
147+
148+
// !!TODO!!!: Put back the prepared statement after usage
149+
let maybe_prepared = if let Some(prepared) = self.stmt_cache.remove(sql) {
150+
MaybePrepared::Prepared(prepared)
151+
} else {
152+
MaybePrepared::NotPrepared(sql.to_string())
153+
};
144154

145155
self.with_conn("execute_stream", move |conn| {
146-
if let Err(e) = execute_sql(conn, &sql, args_move, &tx) {
156+
if let Err(e) = execute_sql(conn, maybe_prepared, args, &tx) {
147157
let _ = tx.send(Err(e));
148158
}
149159
Ok(())
@@ -153,61 +163,38 @@ impl OdbcConnection {
153163
Ok(rx)
154164
}
155165

156-
pub(crate) async fn prepare_metadata(
157-
&mut self,
158-
sql: &str,
159-
) -> Result<(u64, Vec<OdbcColumn>, usize), Error> {
160-
use std::collections::hash_map::DefaultHasher;
161-
use std::hash::{Hash, Hasher};
162-
163-
let mut hasher = DefaultHasher::new();
164-
sql.hash(&mut hasher);
165-
let key = hasher.finish();
166-
167-
// Check cache first
168-
if let Some(metadata) = self.stmt_cache.get(&key) {
169-
return Ok((key, metadata.columns.clone(), metadata.parameters));
170-
}
171-
172-
// Create new prepared statement to get metadata
173-
let sql = sql.to_string();
174-
self.with_conn("prepare_metadata", move |conn| {
175-
let mut prepared = conn.prepare(&sql)?;
176-
let columns = collect_columns(&mut prepared);
177-
let params = usize::from(prepared.num_params().unwrap_or(0));
178-
Ok((columns, params))
179-
})
180-
.await
181-
.map(|(columns, params)| {
182-
// Cache the metadata
183-
let metadata = crate::odbc::statement::OdbcStatementMetadata {
184-
columns: columns.clone(),
185-
parameters: params,
186-
};
187-
self.stmt_cache.insert(key, metadata);
188-
(key, columns, params)
189-
})
190-
}
191-
192166
pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> {
193167
// Clear the statement metadata cache
194168
self.stmt_cache.clear();
195169
Ok(())
196170
}
197171

198-
pub async fn prepare(&mut self, sql: &str) -> Result<OdbcStatement<'static>, Error> {
199-
let (_, columns, parameters) = self.prepare_metadata(sql).await?;
200-
let metadata = OdbcStatementMetadata {
201-
columns,
202-
parameters,
203-
};
172+
pub async fn prepare<'a>(&mut self, sql: &'a str) -> Result<OdbcStatement<'a>, Error> {
173+
let conn = Arc::clone(&self.conn);
174+
let sql_arc = Arc::from(sql.to_string());
175+
let sql_clone = Arc::clone(&sql_arc);
176+
let (prepared, metadata) = run_blocking(move || {
177+
let mut prepared = conn.into_prepared(&sql_clone)?;
178+
let metadata = OdbcStatementMetadata {
179+
columns: collect_columns(&mut prepared),
180+
parameters: usize::from(prepared.num_params().unwrap_or(0)),
181+
};
182+
Ok((prepared, metadata))
183+
})
184+
.await?;
185+
self.stmt_cache.insert(Arc::clone(&sql_arc), prepared);
204186
Ok(OdbcStatement {
205-
sql: Cow::Owned(sql.to_string()),
187+
sql: Cow::Borrowed(sql),
206188
metadata,
207189
})
208190
}
209191
}
210192

193+
pub(crate) enum MaybePrepared {
194+
Prepared(PreparedStatement),
195+
NotPrepared(String),
196+
}
197+
211198
impl Connection for OdbcConnection {
212199
type Database = Odbc;
213200

sqlx-core/src/odbc/connection/odbc_bridge.rs

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use crate::error::Error;
22
use crate::odbc::{
3-
OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult, OdbcRow, OdbcTypeInfo,
3+
connection::MaybePrepared, OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult,
4+
OdbcRow, OdbcTypeInfo,
45
};
56
use either::Either;
67
use flume::{SendError, Sender};
7-
use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, Preallocated, ResultSetMetadata};
8+
use odbc_api::handles::{AsStatementRef, Statement};
9+
use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, ResultSetMetadata};
810

911
pub type ExecuteResult = Result<Either<OdbcQueryResult, OdbcRow>, Error>;
1012
pub type ExecuteSender = Sender<ExecuteResult>;
@@ -21,43 +23,49 @@ pub fn establish_connection(
2123

2224
pub fn execute_sql(
2325
conn: &mut odbc_api::Connection<'static>,
24-
sql: &str,
26+
maybe_prepared: MaybePrepared,
2527
args: Option<OdbcArguments>,
2628
tx: &ExecuteSender,
2729
) -> Result<(), Error> {
2830
let params = prepare_parameters(args);
2931

30-
let mut preallocated = conn.preallocate().map_err(Error::from)?;
31-
32-
if let Some(mut cursor) = preallocated.execute(sql, &params[..])? {
33-
handle_cursor(&mut cursor, tx);
34-
return Ok(());
35-
}
32+
let affected = match maybe_prepared {
33+
MaybePrepared::Prepared(mut prepared) => {
34+
if let Some(mut cursor) = prepared.execute(&params[..])? {
35+
handle_cursor(&mut cursor, tx);
36+
}
37+
extract_rows_affected(&mut prepared)
38+
}
39+
MaybePrepared::NotPrepared(sql) => {
40+
let mut preallocated = conn.preallocate().map_err(Error::from)?;
41+
if let Some(mut cursor) = preallocated.execute(&sql, &params[..])? {
42+
handle_cursor(&mut cursor, tx);
43+
}
44+
extract_rows_affected(&mut preallocated)
45+
}
46+
};
3647

37-
let affected = extract_rows_affected(&mut preallocated);
3848
let _ = send_done(tx, affected);
3949
Ok(())
4050
}
4151

42-
fn extract_rows_affected<S>(stmt: &mut Preallocated<S>) -> u64
43-
where
44-
S: odbc_api::handles::AsStatementRef,
45-
{
46-
let count_opt = match stmt.row_count() {
47-
Ok(count_opt) => count_opt,
48-
Err(_) => {
52+
fn extract_rows_affected<S: AsStatementRef>(stmt: &mut S) -> u64 {
53+
let mut stmt_ref = stmt.as_stmt_ref();
54+
let count = match stmt_ref.row_count().into_result(&stmt_ref) {
55+
Ok(count) => count,
56+
Err(e) => {
57+
log::warn!("Failed to get row count: {}", e);
4958
return 0;
5059
}
5160
};
5261

53-
let count = match count_opt {
54-
Some(count) => count,
55-
None => {
56-
return 0;
62+
match u64::try_from(count) {
63+
Ok(count) => count,
64+
Err(e) => {
65+
log::warn!("Failed to get row count: {}", e);
66+
0
5767
}
58-
};
59-
60-
u64::try_from(count).unwrap_or_default()
68+
}
6169
}
6270

6371
fn prepare_parameters(

0 commit comments

Comments
 (0)