Skip to content

Commit 6e90b48

Browse files
committed
refactor(odbc): restructure OdbcStatement and enhance metadata handling
This commit refactors the OdbcStatement structure to encapsulate metadata, including columns and parameters, within a dedicated OdbcStatementMetadata struct. It also updates the OdbcConnection to cache prepared statement metadata, improving performance and reducing redundant metadata retrieval. Additionally, the prepare method is streamlined to utilize the new metadata structure.
1 parent 650ef81 commit 6e90b48

File tree

7 files changed

+203
-148
lines changed

7 files changed

+203
-148
lines changed

sqlx-core/src/any/connection/executor.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,15 +128,7 @@ impl<'c> Executor<'c> for &'c mut AnyConnection {
128128
AnyConnectionKind::Mssql(conn) => conn.prepare(sql).await.map(Into::into)?,
129129

130130
#[cfg(feature = "odbc")]
131-
AnyConnectionKind::Odbc(conn) => {
132-
let (_, columns, parameters) = conn.prepare_metadata(sql).await?;
133-
crate::odbc::OdbcStatement {
134-
sql: sql.into(),
135-
columns,
136-
parameters,
137-
}
138-
.into()
139-
}
131+
AnyConnectionKind::Odbc(conn) => conn.prepare(sql).await.map(Into::into)?,
140132
})
141133
})
142134
}

sqlx-core/src/odbc/connection/executor.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ use either::Either;
66
use futures_core::future::BoxFuture;
77
use futures_core::stream::BoxStream;
88
use futures_util::TryStreamExt;
9-
use std::borrow::Cow;
109

1110
// run method removed; fetch_many implements streaming directly
1211

@@ -59,14 +58,7 @@ impl<'c> Executor<'c> for &'c mut OdbcConnection {
5958
where
6059
'c: 'e,
6160
{
62-
Box::pin(async move {
63-
let (_, columns, parameters) = self.prepare_metadata(sql).await?;
64-
Ok(OdbcStatement {
65-
sql: Cow::Borrowed(sql),
66-
columns,
67-
parameters,
68-
})
69-
})
61+
Box::pin(async move { self.prepare(sql).await })
7062
}
7163

7264
#[doc(hidden)]
Lines changed: 147 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,48 @@
11
use crate::connection::Connection;
22
use crate::error::Error;
33
use crate::odbc::blocking::run_blocking;
4-
use crate::odbc::{Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow};
4+
use crate::odbc::{
5+
Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, OdbcTypeInfo,
6+
};
57
use crate::transaction::Transaction;
68
use either::Either;
79
mod odbc_bridge;
810
use futures_core::future::BoxFuture;
911
use futures_util::future;
10-
use odbc_bridge::{do_prepare, establish_connection, execute_sql, OdbcConn};
12+
use odbc_bridge::{establish_connection, execute_sql};
1113
// no direct spawn_blocking here; use run_blocking helper
12-
use std::sync::{Arc, Mutex};
14+
use crate::odbc::{OdbcStatement, OdbcStatementMetadata};
15+
use odbc_api::ResultSetMetadata;
16+
use std::borrow::Cow;
17+
use std::collections::HashMap;
18+
use std::sync::Arc;
19+
20+
fn collect_columns(
21+
prepared: &mut odbc_api::Prepared<odbc_api::handles::StatementImpl<'_>>,
22+
) -> Vec<OdbcColumn> {
23+
let count = prepared.num_result_cols().unwrap_or(0);
24+
(1..=count)
25+
.map(|i| create_column(prepared, i as u16))
26+
.collect()
27+
}
28+
29+
fn create_column(
30+
stmt: &mut odbc_api::Prepared<odbc_api::handles::StatementImpl<'_>>,
31+
index: u16,
32+
) -> OdbcColumn {
33+
let mut cd = odbc_api::ColumnDescription::default();
34+
let _ = stmt.describe_col(index, &mut cd);
35+
36+
OdbcColumn {
37+
name: decode_column_name(cd.name, index),
38+
type_info: OdbcTypeInfo::new(cd.data_type),
39+
ordinal: usize::from(index.checked_sub(1).unwrap()),
40+
}
41+
}
42+
43+
fn decode_column_name(name_bytes: Vec<u8>, index: u16) -> String {
44+
String::from_utf8(name_bytes).unwrap_or_else(|_| format!("col{}", index - 1))
45+
}
1346

1447
mod executor;
1548

@@ -19,86 +52,88 @@ mod executor;
1952
/// thread-pool via `spawn_blocking` and synchronize access with a mutex.
2053
#[derive(Debug)]
2154
pub struct OdbcConnection {
22-
pub(crate) conn: Arc<Mutex<OdbcConn>>,
55+
pub(crate) conn: odbc_api::SharedConnection<'static>,
56+
pub(crate) stmt_cache: HashMap<u64, crate::odbc::statement::OdbcStatementMetadata>,
2357
}
2458

2559
impl OdbcConnection {
26-
#[inline]
27-
async fn with_conn<T, F>(&self, f: F) -> Result<T, Error>
28-
where
29-
T: Send + 'static,
30-
F: FnOnce(&mut OdbcConn) -> Result<T, Error> + Send + 'static,
31-
{
32-
let inner = self.conn.clone();
33-
run_blocking(move || {
34-
let mut conn = inner.lock().unwrap();
35-
f(&mut conn)
36-
})
37-
.await
38-
}
39-
40-
#[inline]
41-
async fn with_conn_map<T, E, F>(&self, ctx: &'static str, f: F) -> Result<T, Error>
42-
where
43-
T: Send + 'static,
44-
E: std::fmt::Display,
45-
F: FnOnce(&mut OdbcConn) -> Result<T, E> + Send + 'static,
46-
{
47-
let inner = self.conn.clone();
48-
run_blocking(move || {
49-
let mut conn = inner.lock().unwrap();
50-
f(&mut conn).map_err(|e| Error::Protocol(format!("{}: {}", ctx, e)))
51-
})
52-
.await
53-
}
54-
5560
pub(crate) async fn establish(options: &OdbcConnectOptions) -> Result<Self, Error> {
56-
let conn = run_blocking({
61+
let shared_conn = run_blocking({
5762
let options = options.clone();
58-
move || establish_connection(&options)
63+
move || {
64+
let conn = establish_connection(&options)?;
65+
let shared_conn = odbc_api::SharedConnection::new(std::sync::Mutex::new(conn));
66+
Ok::<_, Error>(shared_conn)
67+
}
5968
})
6069
.await?;
6170

6271
Ok(Self {
63-
conn: Arc::new(Mutex::new(conn)),
72+
conn: shared_conn,
73+
stmt_cache: HashMap::new(),
6474
})
6575
}
6676

6777
/// Returns the name of the actual Database Management System (DBMS) this
6878
/// connection is talking to as reported by the ODBC driver.
6979
pub async fn dbms_name(&mut self) -> Result<String, Error> {
70-
self.with_conn_map::<_, _, _>("Failed to get DBMS name", |conn| {
71-
conn.conn.database_management_system_name()
80+
let conn = Arc::clone(&self.conn);
81+
run_blocking(move || {
82+
let conn_guard = conn
83+
.lock()
84+
.map_err(|_| Error::Protocol("Failed to lock connection".into()))?;
85+
conn_guard
86+
.database_management_system_name()
87+
.map_err(Error::from)
7288
})
7389
.await
7490
}
7591

7692
pub(crate) async fn ping_blocking(&mut self) -> Result<(), Error> {
77-
self.with_conn_map::<_, _, _>("Ping failed", |conn| {
78-
conn.conn.execute("SELECT 1", (), None).map(|_| ())
93+
let conn = Arc::clone(&self.conn);
94+
run_blocking(move || {
95+
let conn_guard = conn
96+
.lock()
97+
.map_err(|_| Error::Protocol("Failed to lock connection".into()))?;
98+
conn_guard
99+
.execute("SELECT 1", (), None)
100+
.map_err(Error::from)
101+
.map(|_| ())
79102
})
80103
.await
81104
}
82105

83106
pub(crate) async fn begin_blocking(&mut self) -> Result<(), Error> {
84-
self.with_conn_map::<_, _, _>("Failed to begin transaction", |conn| {
85-
conn.conn.set_autocommit(false)
107+
let conn = Arc::clone(&self.conn);
108+
run_blocking(move || {
109+
let conn_guard = conn
110+
.lock()
111+
.map_err(|_| Error::Protocol("Failed to lock connection".into()))?;
112+
conn_guard.set_autocommit(false).map_err(Error::from)
86113
})
87114
.await
88115
}
89116

90117
pub(crate) async fn commit_blocking(&mut self) -> Result<(), Error> {
91-
self.with_conn_map::<_, _, _>("Failed to commit transaction", |conn| {
92-
conn.conn.commit()?;
93-
conn.conn.set_autocommit(true)
118+
let conn = Arc::clone(&self.conn);
119+
run_blocking(move || {
120+
let conn_guard = conn
121+
.lock()
122+
.map_err(|_| Error::Protocol("Failed to lock connection".into()))?;
123+
conn_guard.commit()?;
124+
conn_guard.set_autocommit(true).map_err(Error::from)
94125
})
95126
.await
96127
}
97128

98129
pub(crate) async fn rollback_blocking(&mut self) -> Result<(), Error> {
99-
self.with_conn_map::<_, _, _>("Failed to rollback transaction", |conn| {
100-
conn.conn.rollback()?;
101-
conn.conn.set_autocommit(true)
130+
let conn = Arc::clone(&self.conn);
131+
run_blocking(move || {
132+
let conn_guard = conn
133+
.lock()
134+
.map_err(|_| Error::Protocol("Failed to lock connection".into()))?;
135+
conn_guard.rollback()?;
136+
conn_guard.set_autocommit(true).map_err(Error::from)
102137
})
103138
.await
104139
}
@@ -111,23 +146,79 @@ impl OdbcConnection {
111146
let (tx, rx) = flume::bounded(64);
112147
let sql = sql.to_string();
113148
let args_move = args;
114-
self.with_conn(move |conn| {
115-
if let Err(e) = execute_sql(conn, &sql, args_move, &tx) {
149+
let conn = Arc::clone(&self.conn);
150+
151+
run_blocking(move || {
152+
let mut conn_guard = conn
153+
.lock()
154+
.map_err(|_| Error::Protocol("Failed to lock connection".into()))?;
155+
if let Err(e) = execute_sql(&mut conn_guard, &sql, args_move, &tx) {
116156
let _ = tx.send(Err(e));
117157
}
118158
Ok(())
119159
})
120160
.await?;
161+
121162
Ok(rx)
122163
}
123164

124165
pub(crate) async fn prepare_metadata(
125166
&mut self,
126167
sql: &str,
127168
) -> Result<(u64, Vec<OdbcColumn>, usize), Error> {
169+
use std::collections::hash_map::DefaultHasher;
170+
use std::hash::{Hash, Hasher};
171+
172+
let mut hasher = DefaultHasher::new();
173+
sql.hash(&mut hasher);
174+
let key = hasher.finish();
175+
176+
// Check cache first
177+
if let Some(metadata) = self.stmt_cache.get(&key) {
178+
return Ok((key, metadata.columns.clone(), metadata.parameters));
179+
}
180+
181+
// Create new prepared statement to get metadata
128182
let sql = sql.to_string();
129-
self.with_conn(move |conn| do_prepare(conn, sql.into()))
130-
.await
183+
let conn = Arc::clone(&self.conn);
184+
185+
run_blocking(move || {
186+
let conn_guard = conn
187+
.lock()
188+
.map_err(|_| Error::Protocol("Failed to lock connection".into()))?;
189+
let mut prepared = conn_guard.prepare(&sql).map_err(Error::from)?;
190+
let columns = collect_columns(&mut prepared);
191+
let params = usize::from(prepared.num_params().unwrap_or(0));
192+
Ok::<_, Error>((columns, params))
193+
})
194+
.await
195+
.map(|(columns, params)| {
196+
// Cache the metadata
197+
let metadata = crate::odbc::statement::OdbcStatementMetadata {
198+
columns: columns.clone(),
199+
parameters: params,
200+
};
201+
self.stmt_cache.insert(key, metadata);
202+
(key, columns, params)
203+
})
204+
}
205+
206+
pub(crate) async fn clear_cached_statements(&mut self) -> Result<(), Error> {
207+
// Clear the statement metadata cache
208+
self.stmt_cache.clear();
209+
Ok(())
210+
}
211+
212+
pub async fn prepare(&mut self, sql: &str) -> Result<OdbcStatement<'static>, Error> {
213+
let (_, columns, parameters) = self.prepare_metadata(sql).await?;
214+
let metadata = OdbcStatementMetadata {
215+
columns,
216+
parameters,
217+
};
218+
Ok(OdbcStatement {
219+
sql: Cow::Owned(sql.to_string()),
220+
metadata,
221+
})
131222
}
132223
}
133224

@@ -168,6 +259,10 @@ impl Connection for OdbcConnection {
168259
fn should_flush(&self) -> bool {
169260
false
170261
}
262+
263+
fn clear_cached_statements(&mut self) -> BoxFuture<'_, Result<(), Error>> {
264+
Box::pin(self.clear_cached_statements())
265+
}
171266
}
172267

173268
// moved helpers to connection/inner.rs

0 commit comments

Comments
 (0)