Skip to content

Commit 1498620

Browse files
committed
refactor(odbc): improve error handling and data fetching logic
This commit enhances the error handling in OdbcValueRef by utilizing the new MismatchedTypeError for better type mismatch reporting. It also simplifies the data fetching logic in OdbcBridge by consolidating the handling of default values and null checks, improving code clarity and maintainability. Additionally, tests are updated to reflect these changes, ensuring robust functionality across buffered and unbuffered modes.
1 parent 40f7b38 commit 1498620

File tree

3 files changed

+67
-51
lines changed

3 files changed

+67
-51
lines changed

sqlx-core/src/odbc/connection/odbc_bridge.rs

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -434,9 +434,7 @@ where
434434
vec: &mut Vec<T>,
435435
nulls: &mut Vec<bool>,
436436
) {
437-
let mut tmp = T::default();
438-
nulls.push(cursor_row.get_data(col_index, &mut tmp).is_ok());
439-
vec.push(tmp);
437+
push_get_data_with_default(cursor_row, col_index, vec, nulls, T::default());
440438
}
441439

442440
fn push_get_data_with_default<T: Copy + CElement + CDataMut>(
@@ -447,13 +445,8 @@ where
447445
default_val: T,
448446
) {
449447
let mut tmp = default_val;
450-
if cursor_row.get_data(col_index, &mut tmp).is_ok() {
451-
vec.push(tmp);
452-
nulls.push(false);
453-
} else {
454-
vec.push(default_val);
455-
nulls.push(true);
456-
}
448+
nulls.push(cursor_row.get_data(col_index, &mut tmp).is_err());
449+
vec.push(tmp);
457450
}
458451

459452
fn push_binary(
@@ -463,16 +456,8 @@ where
463456
nulls: &mut Vec<bool>,
464457
) {
465458
let mut buf = Vec::new();
466-
match cursor_row.get_text(col_index, &mut buf) {
467-
Ok(true) => {
468-
vec.push(buf);
469-
nulls.push(false);
470-
}
471-
Ok(false) | Err(_) => {
472-
vec.push(Vec::new());
473-
nulls.push(true);
474-
}
475-
}
459+
nulls.push(cursor_row.get_text(col_index, &mut buf).is_err());
460+
vec.push(buf);
476461
}
477462

478463
fn push_text(

sqlx-core/src/odbc/value.rs

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use crate::odbc::{Odbc, OdbcBatch, OdbcTypeInfo};
2+
use crate::type_info::TypeInfo;
23
use crate::value::{Value, ValueRef};
34
use odbc_api::buffers::{AnySlice, NullableSlice};
45
use odbc_api::sys::NULL_DATA;
@@ -156,32 +157,30 @@ impl<'r> OdbcValueRef<'r> {
156157
}
157158

158159
pub fn try_int<T: TryFromInt + crate::types::Type<Odbc>>(&self) -> crate::error::Result<T> {
159-
if !T::compatible(&self.batch.column_data[self.column_index].type_info) {
160-
return Err(crate::error::Error::Decode(
161-
crate::error::mismatched_types::<Odbc, T>(
162-
&self.batch.column_data[self.column_index].type_info,
163-
),
164-
));
165-
}
166160
self.int::<T>().ok_or_else(|| {
167-
crate::error::Error::Decode(crate::error::mismatched_types::<Odbc, T>(
168-
&self.batch.columns[self.column_index].type_info,
169-
))
161+
crate::error::Error::Decode(Box::new(crate::error::MismatchedTypeError {
162+
rust_type: T::type_info().name().to_string(),
163+
rust_sql_type: T::type_info().name().to_string(),
164+
sql_type: self.batch.column_data[self.column_index]
165+
.type_info
166+
.name()
167+
.to_string(),
168+
source: Some(format!("ODBC: cannot decode {:?}", self).into()),
169+
}))
170170
})
171171
}
172172

173173
pub fn try_float<T: TryFromFloat + crate::types::Type<Odbc>>(&self) -> crate::error::Result<T> {
174-
if !T::compatible(&self.batch.column_data[self.column_index].type_info) {
175-
return Err(crate::error::Error::Decode(
176-
crate::error::mismatched_types::<Odbc, T>(
177-
&self.batch.column_data[self.column_index].type_info,
178-
),
179-
));
180-
}
181174
self.float::<T>().ok_or_else(|| {
182-
crate::error::Error::Decode(crate::error::mismatched_types::<Odbc, T>(
183-
&self.batch.columns[self.column_index].type_info,
184-
))
175+
crate::error::Error::Decode(Box::new(crate::error::MismatchedTypeError {
176+
rust_type: T::type_info().name().to_string(),
177+
rust_sql_type: T::type_info().name().to_string(),
178+
sql_type: self.batch.column_data[self.column_index]
179+
.type_info
180+
.name()
181+
.to_string(),
182+
source: Some(format!("ODBC: cannot decode {:?}", self).into()),
183+
}))
185184
})
186185
}
187186

tests/odbc/odbc.rs

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,25 +1037,45 @@ async fn it_handles_prepared_statement_with_wrong_parameters() -> anyhow::Result
10371037
}
10381038

10391039
#[tokio::test]
1040-
async fn it_works_with_unbuffered_mode() -> anyhow::Result<()> {
1040+
async fn it_works_with_buffered_and_unbuffered_mode() -> anyhow::Result<()> {
10411041
use sqlx_oldapi::odbc::{OdbcBufferSettings, OdbcConnectOptions};
10421042

10431043
// Create connection with unbuffered settings
10441044
let database_url = std::env::var("DATABASE_URL").unwrap();
10451045
let mut opts = OdbcConnectOptions::from_str(&database_url)?;
10461046

1047-
for batch_size in [1, 100, 10000] {
1048-
opts.buffer_settings(OdbcBufferSettings {
1049-
batch_size,
1047+
let count = 450;
1048+
1049+
let select = (0..count)
1050+
.map(|i| format!("SELECT {i} AS n, '{}' as aas", "a".repeat(i)))
1051+
.collect::<Vec<_>>()
1052+
.join(" UNION ALL ");
1053+
1054+
for buf_settings in [
1055+
OdbcBufferSettings {
1056+
batch_size: 1,
1057+
max_column_size: None,
1058+
},
1059+
OdbcBufferSettings {
1060+
batch_size: 1,
1061+
max_column_size: Some(450),
1062+
},
1063+
OdbcBufferSettings {
1064+
batch_size: 100,
10501065
max_column_size: None,
1051-
});
1066+
},
1067+
OdbcBufferSettings {
1068+
batch_size: 10000,
1069+
max_column_size: None,
1070+
},
1071+
OdbcBufferSettings {
1072+
batch_size: 10000,
1073+
max_column_size: Some(450),
1074+
},
1075+
] {
1076+
opts.buffer_settings(buf_settings);
10521077

10531078
let mut conn = OdbcConnection::connect_with(&opts).await?;
1054-
let count = 450;
1055-
let select = (0..count)
1056-
.map(|i| format!("SELECT {i} AS n"))
1057-
.collect::<Vec<_>>()
1058-
.join(" UNION ALL ");
10591079

10601080
// Test that unbuffered mode works correctly
10611081
let s = conn
@@ -1067,8 +1087,20 @@ async fn it_works_with_unbuffered_mode() -> anyhow::Result<()> {
10671087
assert_eq!(s.len(), count);
10681088
for i in 0..count {
10691089
let row = s.get(i).expect("row expected");
1070-
let as_i64 = row.get::<'_, i64, _>(0);
1090+
let as_i64 = row
1091+
.try_get_raw(0)
1092+
.expect("1 column expected")
1093+
.to_owned()
1094+
.try_decode::<i64>()
1095+
.expect("SELECT n should be an i64");
10711096
assert_eq!(as_i64, i64::try_from(i).unwrap());
1097+
let aas = row
1098+
.try_get_raw(1)
1099+
.expect("1 column expected")
1100+
.to_owned()
1101+
.try_decode::<String>()
1102+
.expect("SELECT aas should be a string");
1103+
assert_eq!(aas, "a".repeat(i));
10721104
}
10731105
}
10741106
Ok(())

0 commit comments

Comments
 (0)