diff --git a/sqlx-core/src/any/error.rs b/sqlx-core/src/any/error.rs deleted file mode 100644 index 7d4a5f15f0..0000000000 --- a/sqlx-core/src/any/error.rs +++ /dev/null @@ -1,16 +0,0 @@ -use std::any::type_name; - -use crate::any::type_info::AnyTypeInfo; -use crate::any::Any; -use crate::error::BoxDynError; -use crate::type_info::TypeInfo; -use crate::types::Type; - -pub(super) fn mismatched_types>(ty: &AnyTypeInfo) -> BoxDynError { - format!( - "mismatched types; Rust type `{}` is not compatible with SQL type `{}`", - type_name::(), - ty.name() - ) - .into() -} diff --git a/sqlx-core/src/any/mod.rs b/sqlx-core/src/any/mod.rs index f51fef7869..0de215e278 100644 --- a/sqlx-core/src/any/mod.rs +++ b/sqlx-core/src/any/mod.rs @@ -18,7 +18,6 @@ mod arguments; pub(crate) mod column; mod connection; mod database; -mod error; mod kind; mod options; mod query_result; diff --git a/sqlx-core/src/any/row.rs b/sqlx-core/src/any/row.rs index 2a7ba4b2e5..63b7c99d46 100644 --- a/sqlx-core/src/any/row.rs +++ b/sqlx-core/src/any/row.rs @@ -1,8 +1,8 @@ -use crate::any::error::mismatched_types; use crate::any::{Any, AnyColumn, AnyColumnIndex}; use crate::column::ColumnIndex; use crate::database::HasValueRef; use crate::decode::Decode; +use crate::error::mismatched_types; use crate::error::Error; use crate::row::Row; use crate::type_info::TypeInfo; @@ -91,7 +91,7 @@ impl Row for AnyRow { let ty = value.type_info(); if !value.is_null() && !ty.is_null() && !T::compatible(&ty) { - Err(mismatched_types::(&ty)) + Err(mismatched_types::(&ty)) } else { T::decode(value) } diff --git a/sqlx-core/src/any/value.rs b/sqlx-core/src/any/value.rs index 23a06997c6..bc95ffa977 100644 --- a/sqlx-core/src/any/value.rs +++ b/sqlx-core/src/any/value.rs @@ -1,10 +1,9 @@ use std::borrow::Cow; -use crate::any::error::mismatched_types; use crate::any::{Any, AnyTypeInfo}; use crate::database::HasValueRef; use crate::decode::Decode; -use crate::error::Error; +use crate::error::mismatched_types; use crate::type_info::TypeInfo; use crate::types::Type; use crate::value::{Value, ValueRef}; @@ -113,7 +112,7 @@ impl Value for AnyValue { } } - fn try_decode<'r, T>(&'r self) -> Result + fn try_decode<'r, T>(&'r self) -> crate::error::Result where T: Decode<'r, Self::Database> + Type, { @@ -121,7 +120,10 @@ impl Value for AnyValue { let ty = self.type_info(); if !ty.is_null() && !T::compatible(&ty) { - return Err(Error::Decode(mismatched_types::(&ty))); + return Err(crate::error::Error::Decode(mismatched_types::< + Self::Database, + T, + >(&ty))); } } diff --git a/sqlx-core/src/error.rs b/sqlx-core/src/error.rs index 46f240bc7a..7b21e9d567 100644 --- a/sqlx-core/src/error.rs +++ b/sqlx-core/src/error.rs @@ -26,6 +26,43 @@ pub type BoxDynError = Box; #[error("unexpected null; try decoding as an `Option`")] pub struct UnexpectedNullError; +/// Error indicating that a Rust type is not compatible with a SQL type. +#[derive(thiserror::Error, Debug)] +#[error("mismatched types; Rust type `{rust_type}` (as SQL type `{rust_sql_type}`) could not be decoded into SQL type `{sql_type}`")] +pub struct MismatchedTypeError { + /// The name of the Rust type. + pub rust_type: String, + /// The SQL type name that the Rust type would map to. + pub rust_sql_type: String, + /// The actual SQL type from the database. + pub sql_type: String, + /// Optional source error that caused the mismatch. + #[source] + pub source: Option, +} + +impl MismatchedTypeError { + /// Create a new mismatched type error without a source. + pub fn new>(ty: &DB::TypeInfo) -> Self { + Self { + rust_type: type_name::().to_string(), + rust_sql_type: T::type_info().name().to_string(), + sql_type: ty.name().to_string(), + source: None, + } + } + + /// Create a new mismatched type error with a source error. + pub fn with_source>(ty: &DB::TypeInfo, source: BoxDynError) -> Self { + Self { + rust_type: type_name::().to_string(), + rust_sql_type: T::type_info().name().to_string(), + sql_type: ty.name().to_string(), + source: Some(source), + } + } +} + /// Represents all the ways a method can fail within SQLx. #[derive(Debug, thiserror::Error)] #[non_exhaustive] @@ -145,14 +182,17 @@ impl Error { } pub(crate) fn mismatched_types>(ty: &DB::TypeInfo) -> BoxDynError { - // TODO: `#name` only produces `TINYINT` but perhaps we want to show `TINYINT(1)` - format!( - "mismatched types; Rust type `{}` (as SQL type `{}`) is not compatible with SQL type `{}`", - type_name::(), - T::type_info().name(), - ty.name() - ) - .into() + Box::new(MismatchedTypeError { + rust_type: format!( + "{} ({}compatible with SQL type `{}`)", + type_name::(), + if T::compatible(ty) { "" } else { "in" }, + T::type_info().name() + ), + rust_sql_type: T::type_info().name().to_string(), + sql_type: ty.name().to_string(), + source: None, + }) } /// An error that was returned from the database. diff --git a/sqlx-core/src/odbc/connection/mod.rs b/sqlx-core/src/odbc/connection/mod.rs index 8622de15cc..5cacfd1fc8 100644 --- a/sqlx-core/src/odbc/connection/mod.rs +++ b/sqlx-core/src/odbc/connection/mod.rs @@ -1,7 +1,8 @@ use crate::connection::Connection; use crate::error::Error; use crate::odbc::{ - Odbc, OdbcArguments, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, OdbcRow, OdbcTypeInfo, + Odbc, OdbcArguments, OdbcBufferSettings, OdbcColumn, OdbcConnectOptions, OdbcQueryResult, + OdbcRow, OdbcTypeInfo, }; use crate::transaction::Transaction; use either::Either; @@ -67,12 +68,14 @@ pub(super) fn decode_column_name(name: T, index: u16) -> St pub struct OdbcConnection { pub(crate) conn: SharedConnection<'static>, pub(crate) stmt_cache: HashMap, SharedPreparedStatement>, + pub(crate) buffer_settings: OdbcBufferSettings, } impl std::fmt::Debug for OdbcConnection { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("OdbcConnection") .field("conn", &self.conn) + .field("buffer_settings", &self.buffer_settings) .finish() } } @@ -108,6 +111,7 @@ impl OdbcConnection { Ok(Self { conn: shared_conn, stmt_cache: HashMap::new(), + buffer_settings: options.buffer_settings, }) } @@ -162,9 +166,10 @@ impl OdbcConnection { }; let conn = Arc::clone(&self.conn); + let buffer_settings = self.buffer_settings; sqlx_rt::spawn(sqlx_rt::spawn_blocking(move || { let mut conn = conn.lock().expect("failed to lock connection"); - if let Err(e) = execute_sql(&mut conn, maybe_prepared, args, &tx) { + if let Err(e) = execute_sql(&mut conn, maybe_prepared, args, &tx, buffer_settings) { let _ = tx.send(Err(e)); } })); diff --git a/sqlx-core/src/odbc/connection/odbc_bridge.rs b/sqlx-core/src/odbc/connection/odbc_bridge.rs index 01291ea2e9..64fe8bc3c1 100644 --- a/sqlx-core/src/odbc/connection/odbc_bridge.rs +++ b/sqlx-core/src/odbc/connection/odbc_bridge.rs @@ -1,13 +1,56 @@ use super::decode_column_name; use crate::error::Error; +use crate::odbc::OdbcValueVec; use crate::odbc::{ - connection::MaybePrepared, OdbcArgumentValue, OdbcArguments, OdbcColumn, OdbcQueryResult, - OdbcRow, OdbcTypeInfo, + connection::MaybePrepared, ColumnData, OdbcArgumentValue, OdbcArguments, OdbcBatch, + OdbcBufferSettings, OdbcColumn, OdbcQueryResult, OdbcRow, OdbcTypeInfo, }; use either::Either; use flume::{SendError, Sender}; -use odbc_api::handles::{AsStatementRef, Statement}; -use odbc_api::{Cursor, CursorRow, IntoParameter, Nullable, ResultSetMetadata}; +use odbc_api::buffers::{AnySlice, BufferDesc, ColumnarAnyBuffer}; +use odbc_api::handles::{AsStatementRef, CDataMut, Nullability, Statement}; +use odbc_api::parameter::CElement; +use odbc_api::{Cursor, IntoParameter, ResultSetMetadata}; +use std::sync::Arc; + +// Bulk fetch implementation using columnar buffers instead of row-by-row fetching +// This provides significant performance improvements by fetching rows in batches +// and avoiding the slow `next_row()` method from odbc-api + +struct ColumnBinding { + column: OdbcColumn, + buffer_desc: BufferDesc, +} + +fn build_bindings( + cursor: &mut C, + buffer_settings: &OdbcBufferSettings, +) -> Result, Error> +where + C: ResultSetMetadata, +{ + let column_count = cursor.num_result_cols().unwrap_or(0); + let mut bindings = Vec::with_capacity(column_count as usize); + for index in 1..=column_count { + let column = create_column(cursor, index as u16); + let nullable = cursor + .col_nullability(index as u16) + .unwrap_or(Nullability::Unknown) + .could_be_nullable(); + let buffer_desc = map_buffer_desc( + cursor, + index as u16, + &column.type_info, + nullable, + buffer_settings, + )?; + bindings.push(ColumnBinding { + column, + buffer_desc, + }); + } + Ok(bindings) +} pub type ExecuteResult = Result, Error>; pub type ExecuteSender = Sender; @@ -27,21 +70,22 @@ pub fn execute_sql( maybe_prepared: MaybePrepared, args: Option, tx: &ExecuteSender, + buffer_settings: OdbcBufferSettings, ) -> Result<(), Error> { let params = prepare_parameters(args); let affected = match maybe_prepared { MaybePrepared::Prepared(prepared) => { let mut prepared = prepared.lock().expect("prepared statement lock"); - if let Some(mut cursor) = prepared.execute(¶ms[..])? { - handle_cursor(&mut cursor, tx); + if let Some(cursor) = prepared.execute(¶ms[..])? { + handle_cursor(cursor, tx, buffer_settings); } extract_rows_affected(&mut *prepared) } MaybePrepared::NotPrepared(sql) => { let mut preallocated = conn.preallocate().map_err(Error::from)?; - if let Some(mut cursor) = preallocated.execute(&sql, ¶ms[..])? { - handle_cursor(&mut cursor, tx); + if let Some(cursor) = preallocated.execute(&sql, ¶ms[..])? { + handle_cursor(cursor, tx, buffer_settings); } extract_rows_affected(&mut preallocated) } @@ -87,19 +131,42 @@ fn to_param(arg: OdbcArgumentValue) -> Box(cursor: &mut C, tx: &ExecuteSender) +fn handle_cursor(mut cursor: C, tx: &ExecuteSender, buffer_settings: OdbcBufferSettings) where C: Cursor + ResultSetMetadata, { - let columns = collect_columns(cursor); - - match stream_rows(cursor, &columns, tx) { - Ok(true) => { - let _ = send_done(tx, 0); + match buffer_settings.max_column_size { + Some(_) => { + // Buffered mode - use batch fetching with columnar buffers + let bindings = match build_bindings(&mut cursor, &buffer_settings) { + Ok(b) => b, + Err(e) => { + send_error(tx, e); + return; + } + }; + + match stream_rows(cursor, bindings, tx, buffer_settings) { + Ok(true) => { + let _ = send_done(tx, 0); + } + Ok(false) => {} + Err(e) => { + send_error(tx, e); + } + } } - Ok(false) => {} - Err(e) => { - send_error(tx, e); + None => { + // Unbuffered mode - use batched row-by-row fetching + match stream_rows_unbuffered(cursor, tx, buffer_settings.batch_size) { + Ok(true) => { + let _ = send_done(tx, 0); + } + Ok(false) => {} + Err(e) => { + send_error(tx, e); + } + } } } } @@ -116,16 +183,6 @@ fn send_row(tx: &ExecuteSender, row: OdbcRow) -> Result<(), SendError(cursor: &mut C) -> Vec -where - C: ResultSetMetadata, -{ - let count = cursor.num_result_cols().unwrap_or(0); - (1..=count) - .map(|i| create_column(cursor, i as u16)) - .collect() -} - fn create_column(cursor: &mut C, index: u16) -> OdbcColumn where C: ResultSetMetadata, @@ -140,185 +197,343 @@ where } } -fn stream_rows(cursor: &mut C, columns: &[OdbcColumn], tx: &ExecuteSender) -> Result +fn map_buffer_desc( + _cursor: &mut C, + _column_index: u16, + type_info: &OdbcTypeInfo, + nullable: bool, + buffer_settings: &OdbcBufferSettings, +) -> Result where - C: Cursor, + C: ResultSetMetadata, { - let mut receiver_open = true; + use odbc_api::DataType; - while let Some(mut row) = cursor.next_row()? { - let values = collect_row_values(&mut row, columns)?; - let row_data = OdbcRow { - columns: columns.to_vec(), - values: values.into_iter().map(|(_, value)| value).collect(), - }; + let data_type = type_info.data_type(); - if send_row(tx, row_data).is_err() { - receiver_open = false; - break; + // Helper function to determine buffer length with fallback + let max_column_size = buffer_settings.max_column_size.unwrap_or(4096); + + let buffer_length = |length: Option| { + if let Some(length) = length { + if length.get() < 255 { + length.get() + } else { + max_column_size + } + } else { + max_column_size } + }; + + let buffer_desc = match data_type { + // Integer types - all map to I64 + DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => { + BufferDesc::I64 { nullable } + } + // Floating point types + DataType::Real => BufferDesc::F32 { nullable }, + DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable }, + // Bit type + DataType::Bit => BufferDesc::Bit { nullable }, + // Date/Time types + DataType::Date => BufferDesc::Date { nullable }, + DataType::Time { .. } => BufferDesc::Time { nullable }, + DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable }, + // Binary types + DataType::Binary { length } + | DataType::Varbinary { length } + | DataType::LongVarbinary { length } => BufferDesc::Binary { + length: buffer_length(length), + }, + // Text types + DataType::Char { length } + | DataType::WChar { length } + | DataType::Varchar { length } + | DataType::WVarchar { length } + | DataType::LongVarchar { length } + | DataType::WLongVarchar { length } + | DataType::Other { + column_size: length, + .. + } => BufferDesc::Text { + max_str_len: buffer_length(length), + }, + // Fallback cases + DataType::Unknown => BufferDesc::Text { + max_str_len: max_column_size, + }, + DataType::Decimal { .. } | DataType::Numeric { .. } => BufferDesc::Text { + max_str_len: max_column_size, + }, + }; + + Ok(buffer_desc) +} + +fn create_column_data(slice: AnySlice<'_>, column: &OdbcColumn) -> Arc { + let (values, nulls) = crate::odbc::value::convert_any_slice_to_value_vec(slice); + Arc::new(ColumnData { + values, + type_info: column.type_info.clone(), + nulls, + }) +} + +fn build_columns_from_cursor(cursor: &mut C) -> Vec +where + C: ResultSetMetadata, +{ + let column_count = cursor.num_result_cols().unwrap_or(0); + let mut columns = Vec::with_capacity(column_count as usize); + for index in 1..=column_count { + columns.push(create_column(cursor, index as u16)); } - Ok(receiver_open) + columns } -fn collect_row_values( - row: &mut CursorRow<'_>, +fn build_column_data_from_values( columns: &[OdbcColumn], -) -> Result, Error> { - columns - .iter() - .enumerate() - .map(|(i, column)| collect_column_value(row, i, column)) + value_vecs: Vec, + nulls_vecs: Vec>, +) -> Vec> { + value_vecs + .into_iter() + .zip(nulls_vecs) + .zip(columns.iter()) + .map(|((values, nulls), column)| { + Arc::new(ColumnData { + values, + type_info: column.type_info.clone(), + nulls, + }) + }) .collect() } -fn collect_column_value( - row: &mut CursorRow<'_>, - index: usize, - column: &OdbcColumn, -) -> Result<(OdbcTypeInfo, crate::odbc::OdbcValue), Error> { - use odbc_api::DataType; - - let col_idx = (index + 1) as u16; - let type_info = column.type_info.clone(); - let data_type = type_info.data_type(); - - let value = match data_type { - DataType::TinyInt - | DataType::SmallInt - | DataType::Integer - | DataType::BigInt - | DataType::Bit => extract_int(row, col_idx, &type_info)?, +fn send_rows_for_batch( + tx: &ExecuteSender, + col_arc: &Arc<[OdbcColumn]>, + column_data: Vec>, + num_rows: usize, +) -> bool { + let odbc_batch = Arc::new(OdbcBatch { + columns: Arc::clone(col_arc), + column_data, + }); - DataType::Real => extract_float::(row, col_idx, &type_info)?, - DataType::Float { .. } | DataType::Double => { - extract_float::(row, col_idx, &type_info)? + let mut receiver_open = true; + for row_index in 0..num_rows { + let row = OdbcRow { + row_index, + batch: Arc::clone(&odbc_batch), + }; + if send_row(tx, row).is_err() { + receiver_open = false; + break; } + } + receiver_open +} - DataType::Char { .. } - | DataType::Varchar { .. } - | DataType::LongVarchar { .. } - | DataType::WChar { .. } - | DataType::WVarchar { .. } - | DataType::WLongVarchar { .. } - | DataType::Date - | DataType::Time { .. } - | DataType::Timestamp { .. } - | DataType::Decimal { .. } - | DataType::Numeric { .. } => extract_text(row, col_idx, &type_info)?, - - DataType::Binary { .. } | DataType::Varbinary { .. } | DataType::LongVarbinary { .. } => { - extract_binary(row, col_idx, &type_info)? - } +fn stream_rows( + cursor: C, + bindings: Vec, + tx: &ExecuteSender, + buffer_settings: OdbcBufferSettings, +) -> Result +where + C: Cursor + ResultSetMetadata, +{ + if buffer_settings.max_column_size.is_some() { + // Buffered mode + stream_rows_buffered(cursor, bindings, tx, buffer_settings.batch_size) + } else { + // Unbuffered mode - we shouldn't reach here, but handle it just in case + stream_rows_unbuffered(cursor, tx, buffer_settings.batch_size) + } +} - DataType::Unknown | DataType::Other { .. } => { - match extract_text(row, col_idx, &type_info) { - Ok(v) => v, - Err(_) => extract_binary(row, col_idx, &type_info)?, - } - } - }; +fn stream_rows_buffered( + cursor: C, + bindings: Vec, + tx: &ExecuteSender, + batch_size: usize, +) -> Result +where + C: Cursor + ResultSetMetadata, +{ + let buffer_descriptions: Vec<_> = bindings.iter().map(|b| b.buffer_desc).collect(); + let buffer = ColumnarAnyBuffer::from_descs(batch_size, buffer_descriptions); + let mut row_set_cursor = cursor.bind_buffer(buffer)?; - Ok((type_info, value)) -} + let mut receiver_open = true; -fn extract_int( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut nullable = Nullable::::null(); - row.get_data(col_idx, &mut nullable)?; + let columns: Vec = bindings.iter().map(|b| b.column.clone()).collect(); + let col_arc: Arc<[OdbcColumn]> = Arc::from(columns); - let (is_null, int) = match nullable.into_opt() { - None => (true, None), - Some(v) => (false, Some(v)), - }; + while let Some(batch) = row_set_cursor.fetch()? { + let column_data: Vec<_> = bindings + .iter() + .enumerate() + .map(|(col_index, binding)| { + create_column_data(batch.column(col_index), &binding.column) + }) + .collect(); - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob: None, - int, - float: None, - }) + if !send_rows_for_batch(tx, &col_arc, column_data, batch.num_rows()) { + receiver_open = false; + break; + } + } + + Ok(receiver_open) } -fn extract_float( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result +fn stream_rows_unbuffered( + mut cursor: C, + tx: &ExecuteSender, + batch_size: usize, +) -> Result where - T: Into + Default, - odbc_api::Nullable: odbc_api::parameter::CElement + odbc_api::handles::CDataMut, + C: Cursor + ResultSetMetadata, { - let mut nullable = Nullable::::null(); - row.get_data(col_idx, &mut nullable)?; + use odbc_api::DataType; - let (is_null, float) = match nullable.into_opt() { - None => (true, None), - Some(v) => (false, Some(v.into())), - }; + let mut receiver_open = true; - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob: None, - int: None, - float, - }) -} + let columns = build_columns_from_cursor(&mut cursor); + let column_count = columns.len(); -fn extract_text( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut buf = Vec::new(); - let is_some = row.get_text(col_idx, &mut buf)?; + let col_arc: Arc<[OdbcColumn]> = Arc::from(columns.clone()); - let (is_null, text) = if !is_some { - (true, None) - } else { - match String::from_utf8(buf) { - Ok(s) => (false, Some(s)), - Err(e) => return Err(Error::Decode(e.into())), + fn init_value_vec(dt: DataType, capacity: usize) -> OdbcValueVec { + match dt { + DataType::TinyInt => OdbcValueVec::TinyInt(Vec::with_capacity(capacity)), + DataType::SmallInt => OdbcValueVec::SmallInt(Vec::with_capacity(capacity)), + DataType::Integer => OdbcValueVec::Integer(Vec::with_capacity(capacity)), + DataType::BigInt => OdbcValueVec::BigInt(Vec::with_capacity(capacity)), + DataType::Real => OdbcValueVec::Real(Vec::with_capacity(capacity)), + DataType::Float { .. } | DataType::Double => { + OdbcValueVec::Double(Vec::with_capacity(capacity)) + } + DataType::Bit => OdbcValueVec::Bit(Vec::with_capacity(capacity)), + DataType::Date => OdbcValueVec::Date(Vec::with_capacity(capacity)), + DataType::Time { .. } => OdbcValueVec::Time(Vec::with_capacity(capacity)), + DataType::Timestamp { .. } => OdbcValueVec::Timestamp(Vec::with_capacity(capacity)), + DataType::Binary { .. } + | DataType::Varbinary { .. } + | DataType::LongVarbinary { .. } => OdbcValueVec::Binary(Vec::with_capacity(capacity)), + _ => OdbcValueVec::Text(Vec::with_capacity(capacity)), } - }; + } - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text, - blob: None, - int: None, - float: None, - }) -} + fn push_get_data( + cursor_row: &mut odbc_api::CursorRow<'_>, + col_index: u16, + vec: &mut Vec, + nulls: &mut Vec, + ) { + push_get_data_with_default(cursor_row, col_index, vec, nulls, T::default()); + } -fn extract_binary( - row: &mut CursorRow<'_>, - col_idx: u16, - type_info: &OdbcTypeInfo, -) -> Result { - let mut buf = Vec::new(); - let is_some = row.get_binary(col_idx, &mut buf)?; + fn push_get_data_with_default( + cursor_row: &mut odbc_api::CursorRow<'_>, + col_index: u16, + vec: &mut Vec, + nulls: &mut Vec, + default_val: T, + ) { + let mut tmp = default_val; + nulls.push(cursor_row.get_data(col_index, &mut tmp).is_err()); + vec.push(tmp); + } - let (is_null, blob) = if !is_some { - (true, None) - } else { - (false, Some(buf)) - }; + fn push_binary( + cursor_row: &mut odbc_api::CursorRow<'_>, + col_index: u16, + vec: &mut Vec>, + nulls: &mut Vec, + ) { + let mut buf = Vec::new(); + nulls.push(cursor_row.get_text(col_index, &mut buf).is_err()); + vec.push(buf); + } - Ok(crate::odbc::OdbcValue { - type_info: type_info.clone(), - is_null, - text: None, - blob, - int: None, - float: None, - }) + fn push_text( + cursor_row: &mut odbc_api::CursorRow<'_>, + col_index: u16, + vec: &mut Vec, + nulls: &mut Vec, + ) { + let mut buf = Vec::::new(); + let txt = cursor_row.get_wide_text(col_index, &mut buf); + vec.push(String::from_utf16_lossy(&buf).to_string()); + nulls.push(!txt.unwrap_or(false)); + } + + fn push_from_cursor_row( + cursor_row: &mut odbc_api::CursorRow<'_>, + col_index: u16, + values: &mut OdbcValueVec, + nulls: &mut Vec, + ) { + match values { + OdbcValueVec::TinyInt(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::SmallInt(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::Integer(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::BigInt(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::Real(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::Double(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::Bit(v) => { + push_get_data_with_default(cursor_row, col_index, v, nulls, odbc_api::Bit(0)) + } + OdbcValueVec::Date(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::Time(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::Timestamp(v) => push_get_data(cursor_row, col_index, v, nulls), + OdbcValueVec::Binary(v) => push_binary(cursor_row, col_index, v, nulls), + OdbcValueVec::Text(v) => push_text(cursor_row, col_index, v, nulls), + } + } + + loop { + // Initialize per-column containers for this batch + let mut value_vecs: Vec = columns + .iter() + .map(|c| init_value_vec(c.type_info.data_type(), batch_size)) + .collect(); + let mut nulls_vecs: Vec> = (0..column_count) + .map(|_| Vec::with_capacity(batch_size)) + .collect(); + + let mut num_rows = 0; + while let Some(mut cursor_row) = cursor.next_row()? { + for col in 0..column_count { + let col_idx = (col as u16) + 1; + push_from_cursor_row( + &mut cursor_row, + col_idx, + &mut value_vecs[col], + &mut nulls_vecs[col], + ); + } + num_rows += 1; + if num_rows == batch_size { + break; + } + } + + let column_data = build_column_data_from_values(&columns, value_vecs, nulls_vecs); + + if !send_rows_for_batch(tx, &col_arc, column_data, num_rows) { + receiver_open = false; + break; + } + + if !receiver_open || num_rows < batch_size { + break; + } + } + + Ok(receiver_open) } diff --git a/sqlx-core/src/odbc/mod.rs b/sqlx-core/src/odbc/mod.rs index 492cc370b6..feeec9e2fb 100644 --- a/sqlx-core/src/odbc/mod.rs +++ b/sqlx-core/src/odbc/mod.rs @@ -20,6 +20,37 @@ //! ```text //! odbc:DSN=MyDataSource //! ``` +//! +//! ## Buffer Configuration +//! +//! You can configure buffer settings for performance tuning: +//! +//! ```rust,no_run +//! use std::str::FromStr; +//! use sqlx_core_oldapi::odbc::{OdbcConnectOptions, OdbcBufferSettings}; +//! +//! let mut opts = OdbcConnectOptions::from_str("DSN=MyDataSource")?; +//! +//! // Configure for high-throughput buffered mode +//! opts.buffer_settings(OdbcBufferSettings { +//! batch_size: 256, // Fetch 256 rows at once +//! max_column_size: Some(2048), // Limit text columns to 2048 chars +//! }); +//! +//! // Configure for unbuffered mode (no truncation, row-by-row processing) +//! opts.buffer_settings(OdbcBufferSettings { +//! batch_size: 128, // batch_size ignored in unbuffered mode +//! max_column_size: None, // Enable unbuffered mode +//! }); +//! +//! // Or configure individual settings +//! opts.batch_size(512) +//! .max_column_size(Some(1024)); +//! +//! // Switch to unbuffered mode +//! opts.max_column_size(None); +//! # Ok::<(), sqlx::Error>(()) +//! ``` use crate::executor::Executor; @@ -41,13 +72,13 @@ pub use arguments::{OdbcArgumentValue, OdbcArguments}; pub use column::OdbcColumn; pub use connection::OdbcConnection; pub use database::Odbc; -pub use options::OdbcConnectOptions; +pub use options::{OdbcBufferSettings, OdbcConnectOptions}; pub use query_result::OdbcQueryResult; -pub use row::OdbcRow; +pub use row::{OdbcBatch, OdbcRow}; pub use statement::{OdbcStatement, OdbcStatementMetadata}; pub use transaction::OdbcTransactionManager; pub use type_info::{DataTypeExt, OdbcTypeInfo}; -pub use value::{OdbcValue, OdbcValueRef}; +pub use value::{ColumnData, OdbcValue, OdbcValueRef, OdbcValueType, OdbcValueVec}; /// An alias for [`Pool`][crate::pool::Pool], specialized for ODBC. pub type OdbcPool = crate::pool::Pool; diff --git a/sqlx-core/src/odbc/options/mod.rs b/sqlx-core/src/odbc/options/mod.rs index 19a217bfcc..2a65c38169 100644 --- a/sqlx-core/src/odbc/options/mod.rs +++ b/sqlx-core/src/odbc/options/mod.rs @@ -8,22 +8,163 @@ use std::time::Duration; use crate::odbc::OdbcConnection; +/// Configuration for ODBC buffer settings that control memory usage and performance characteristics. +/// +/// These settings affect how SQLx fetches and processes data from ODBC data sources. Careful tuning +/// of these parameters can significantly impact memory usage and query performance. +/// +/// # Buffered vs Unbuffered Mode +/// +/// The buffer settings control two different data fetching strategies based on `max_column_size`: +/// +/// **Buffered Mode** (when `max_column_size` is `Some(value)`): +/// - Fetches rows in batches using columnar buffers for better performance with large result sets +/// - Pre-allocates memory for all columns and rows in each batch using `ColumnarAnyBuffer` +/// - WARNING: Long textual and binary field data will be truncated to `max_column_size` +/// - Better for high-throughput scenarios where memory usage is acceptable +/// - Uses `batch_size` to control how many rows are fetched at once +/// +/// **Unbuffered Mode** (when `max_column_size` is `None`): +/// - Fetches rows in batches using the `next_row()` method for memory efficiency +/// - Processes rows in batches of `batch_size` but without pre-allocating columnar buffers +/// - No data truncation - handles variable-sized data dynamically +/// - More memory efficient than buffered mode but potentially slower +/// - Better for scenarios where data sizes vary significantly or memory is constrained +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct OdbcBufferSettings { + /// The number of rows to fetch in each batch during bulk operations. + /// + /// **Performance Impact:** + /// - Higher values reduce the number of round-trips to the database but increase memory usage + /// - Lower values reduce memory usage but may increase latency due to more frequent fetches + /// - Typical range: 32-512 rows + /// - Used in both buffered and unbuffered modes + /// + /// **Memory Impact:** + /// - In buffered mode: Each batch allocates buffers for `batch_size * number_of_columns` cells + /// - In unbuffered mode: Each batch processes `batch_size` rows without pre-allocation + /// - For wide result sets, this can consume significant memory + /// + /// **Default:** 128 rows + pub batch_size: usize, + + /// The maximum size (in characters) for text and binary columns when the database doesn't specify a length. + /// + /// **Performance Impact:** + /// - When `Some(value)`: Enables buffered mode with batch fetching and pre-allocated buffers + /// - When `None`: Enables unbuffered mode with batched row-by-row processing + /// - Higher values ensure large text fields are fully captured but increase memory allocation + /// - Lower values may truncate data but reduce memory pressure + /// - Affects VARCHAR, NVARCHAR, TEXT, and BLOB column types + /// + /// **Memory Impact:** + /// - When `Some(value)`: Directly controls buffer size for variable-length columns + /// - When `None`: Minimal memory allocation per row, no truncation + /// - Setting too high can waste memory; setting too low can cause data truncation + /// - Consider your data characteristics when tuning this value + /// + /// **Default:** Some(4096) + pub max_column_size: Option, +} + +impl Default for OdbcBufferSettings { + fn default() -> Self { + Self { + batch_size: 128, + max_column_size: Some(4096), + } + } +} + #[derive(Clone)] pub struct OdbcConnectOptions { pub(crate) conn_str: String, pub(crate) log_settings: LogSettings, + pub(crate) buffer_settings: OdbcBufferSettings, } impl OdbcConnectOptions { pub fn connection_string(&self) -> &str { &self.conn_str } + + /// Sets the buffer configuration for this connection. + /// + /// The buffer settings control memory usage and performance characteristics + /// when fetching data from ODBC data sources. + /// + /// # Example + /// ```rust,no_run + /// use std::str::FromStr; + /// use sqlx_core_oldapi::odbc::{OdbcConnectOptions, OdbcBufferSettings}; + /// + /// let mut opts = OdbcConnectOptions::from_str("DSN=MyDataSource")?; + /// + /// // Configure for high-throughput buffered mode + /// opts.buffer_settings(OdbcBufferSettings { + /// batch_size: 256, + /// max_column_size: Some(2048), + /// }); + /// + /// // Or configure for unbuffered mode + /// opts.buffer_settings(OdbcBufferSettings { + /// batch_size: 128, // batch_size is ignored in unbuffered mode + /// max_column_size: None, + /// }); + /// # Ok::<(), sqlx_core_oldapi::error::Error>(()) + /// ``` + pub fn buffer_settings(&mut self, settings: OdbcBufferSettings) -> &mut Self { + self.buffer_settings = settings; + self + } + + /// Sets the batch size for bulk fetch operations. + /// + /// This controls how many rows are fetched at once during query execution. + /// Higher values can improve performance for large result sets but use more memory. + /// Only used when `max_column_size` is `Some(value)` (buffered mode). + /// + /// # Panics + /// Panics if `batch_size` is 0. + pub fn batch_size(&mut self, batch_size: usize) -> &mut Self { + assert!(batch_size > 0, "batch_size must be greater than 0"); + self.buffer_settings.batch_size = batch_size; + self + } + + /// Sets the maximum column size for text and binary data. + /// + /// This controls the buffer size allocated for columns when the database + /// doesn't specify a maximum length. Larger values ensure complete data + /// capture but increase memory usage. + /// + /// - When set to `Some(value)`: Enables buffered mode with batch fetching + /// - When set to `None`: Enables unbuffered mode with row-by-row processing + /// + /// # Panics + /// Panics if `max_column_size` is less than 1024 or greater than 4096 (when Some). + pub fn max_column_size(&mut self, max_column_size: Option) -> &mut Self { + if let Some(size) = max_column_size { + assert!( + (1024..=4096).contains(&size), + "max_column_size must be between 1024 and 4096" + ); + } + self.buffer_settings.max_column_size = max_column_size; + self + } + + /// Returns the current buffer settings for this connection. + pub fn buffer_settings_ref(&self) -> &OdbcBufferSettings { + &self.buffer_settings + } } impl Debug for OdbcConnectOptions { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { f.debug_struct("OdbcConnectOptions") .field("conn_str", &"") + .field("buffer_settings", &self.buffer_settings) .finish() } } @@ -51,6 +192,7 @@ impl FromStr for OdbcConnectOptions { Ok(Self { conn_str, log_settings: LogSettings::default(), + buffer_settings: OdbcBufferSettings::default(), }) } } diff --git a/sqlx-core/src/odbc/row.rs b/sqlx-core/src/odbc/row.rs index 04b4ed27bc..d5c18b708e 100644 --- a/sqlx-core/src/odbc/row.rs +++ b/sqlx-core/src/odbc/row.rs @@ -1,21 +1,27 @@ use crate::column::ColumnIndex; use crate::database::HasValueRef; use crate::error::Error; -use crate::odbc::{Odbc, OdbcColumn, OdbcValue}; +use crate::odbc::{Odbc, OdbcColumn, OdbcValueRef}; use crate::row::Row; -use crate::value::Value; +use std::sync::Arc; + +#[derive(Debug, Clone)] +pub struct OdbcBatch { + pub(crate) columns: Arc<[OdbcColumn]>, + pub(crate) column_data: Vec>, +} #[derive(Debug, Clone)] pub struct OdbcRow { - pub(crate) columns: Vec, - pub(crate) values: Vec, + pub(crate) row_index: usize, + pub(crate) batch: Arc, } impl Row for OdbcRow { type Database = Odbc; fn columns(&self) -> &[OdbcColumn] { - &self.columns + &self.batch.columns } fn try_get_raw( @@ -25,21 +31,21 @@ impl Row for OdbcRow { where I: ColumnIndex, { - let idx = index.index(self)?; - let value = &self.values[idx]; - Ok(value.as_ref()) + let column_index = index.index(self)?; + Ok(OdbcValueRef::new(&self.batch, self.row_index, column_index)) } } impl ColumnIndex for &str { fn index(&self, row: &OdbcRow) -> Result { // Try exact match first (for performance) - if let Some(pos) = row.columns.iter().position(|col| col.name == *self) { + if let Some(pos) = row.batch.columns.iter().position(|col| col.name == *self) { return Ok(pos); } // Fall back to case-insensitive match (for databases like Snowflake) - row.columns + row.batch + .columns .iter() .position(|col| col.name.eq_ignore_ascii_case(self)) .ok_or_else(|| Error::ColumnNotFound((*self).into())) @@ -56,6 +62,7 @@ mod private { impl From for crate::any::AnyRow { fn from(row: OdbcRow) -> Self { let columns = row + .batch .columns .iter() .map(|col| crate::any::AnyColumn { @@ -74,57 +81,55 @@ impl From for crate::any::AnyRow { #[cfg(test)] mod tests { use super::*; - use crate::odbc::{OdbcColumn, OdbcTypeInfo}; + use crate::odbc::{ColumnData, OdbcColumn, OdbcTypeInfo, OdbcValueVec}; use crate::type_info::TypeInfo; + use crate::value::ValueRef; use odbc_api::DataType; + use std::sync::Arc; fn create_test_row() -> OdbcRow { - use crate::odbc::OdbcValue; + let columns = Arc::new([ + OdbcColumn { + name: "lowercase_col".to_string(), + type_info: OdbcTypeInfo::new(DataType::Integer), + ordinal: 0, + }, + OdbcColumn { + name: "UPPERCASE_COL".to_string(), + type_info: OdbcTypeInfo::new(DataType::Varchar { length: None }), + ordinal: 1, + }, + OdbcColumn { + name: "MixedCase_Col".to_string(), + type_info: OdbcTypeInfo::new(DataType::Double), + ordinal: 2, + }, + ]); + + let column_data = vec![ + Arc::new(ColumnData { + values: OdbcValueVec::BigInt(vec![42]), + type_info: OdbcTypeInfo::new(DataType::Integer), + nulls: vec![false], + }), + Arc::new(ColumnData { + values: OdbcValueVec::Text(vec!["test".to_string()]), + type_info: OdbcTypeInfo::new(DataType::Varchar { length: None }), + nulls: vec![false], + }), + Arc::new(ColumnData { + values: OdbcValueVec::Double(vec![std::f64::consts::PI]), + type_info: OdbcTypeInfo::new(DataType::Double), + nulls: vec![false], + }), + ]; OdbcRow { - columns: vec![ - OdbcColumn { - name: "lowercase_col".to_string(), - type_info: OdbcTypeInfo::new(DataType::Integer), - ordinal: 0, - }, - OdbcColumn { - name: "UPPERCASE_COL".to_string(), - type_info: OdbcTypeInfo::new(DataType::Varchar { length: None }), - ordinal: 1, - }, - OdbcColumn { - name: "MixedCase_Col".to_string(), - type_info: OdbcTypeInfo::new(DataType::Double), - ordinal: 2, - }, - ], - values: vec![ - OdbcValue { - type_info: OdbcTypeInfo::new(DataType::Integer), - is_null: false, - text: None, - blob: None, - int: Some(42), - float: None, - }, - OdbcValue { - type_info: OdbcTypeInfo::new(DataType::Varchar { length: None }), - is_null: false, - text: Some("test".to_string()), - blob: None, - int: None, - float: None, - }, - OdbcValue { - type_info: OdbcTypeInfo::new(DataType::Double), - is_null: false, - text: None, - blob: None, - int: None, - float: Some(std::f64::consts::PI), - }, - ], + row_index: 0, + batch: Arc::new(OdbcBatch { + columns, + column_data, + }), } } @@ -171,18 +176,18 @@ mod tests { // Test accessing by exact name let value = row.try_get_raw("lowercase_col").unwrap(); - assert!(!value.is_null); - assert_eq!(value.type_info.name(), "INTEGER"); + assert!(!value.is_null()); + assert_eq!(value.type_info().name(), "INTEGER"); // Test accessing by case-insensitive name let value = row.try_get_raw("LOWERCASE_COL").unwrap(); - assert!(!value.is_null); - assert_eq!(value.type_info.name(), "INTEGER"); + assert!(!value.is_null()); + assert_eq!(value.type_info().name(), "INTEGER"); // Test accessing uppercase column with lowercase name let value = row.try_get_raw("uppercase_col").unwrap(); - assert!(!value.is_null); - assert_eq!(value.type_info.name(), "VARCHAR"); + assert!(!value.is_null()); + assert_eq!(value.type_info().name(), "VARCHAR"); } #[test] diff --git a/sqlx-core/src/odbc/types/bigdecimal.rs b/sqlx-core/src/odbc/types/bigdecimal.rs index 7b15f65e15..abdd8cee24 100644 --- a/sqlx-core/src/odbc/types/bigdecimal.rs +++ b/sqlx-core/src/odbc/types/bigdecimal.rs @@ -3,7 +3,7 @@ use crate::encode::Encode; use crate::error::BoxDynError; use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; use crate::types::Type; -use bigdecimal::{BigDecimal, FromPrimitive}; +use bigdecimal::BigDecimal; use odbc_api::DataType; use std::str::FromStr; @@ -36,19 +36,22 @@ impl<'q> Encode<'q, Odbc> for BigDecimal { impl<'r> Decode<'r, Odbc> for BigDecimal { fn decode(value: OdbcValueRef<'r>) -> Result { - if let Some(int) = value.int { - return Ok(BigDecimal::from(int)); + if let Some(int) = value.int::() { + Ok(BigDecimal::from(int)) + } else if let Some(float) = value.float::() { + Ok(BigDecimal::try_from(float)?) + } else if let Some(text) = value.text() { + let text = text.trim(); + Ok(BigDecimal::from_str(text).map_err(|e| format!("bad decimal text: {}", e))?) + } else if let Some(bytes) = value.blob() { + if let Ok(s) = std::str::from_utf8(bytes) { + Ok(BigDecimal::parse_bytes(s.as_bytes(), 10) + .ok_or(format!("bad base10 bytes: {:?}", bytes))?) + } else { + Err(format!("bad utf8 bytes: {:?}", bytes).into()) + } + } else { + Err(format!("ODBC: cannot decode BigDecimal: {:?}", value).into()) } - if let Some(float) = value.float { - return Ok(BigDecimal::from_f64(float).ok_or(format!("bad float: {}", float))?); - } - if let Some(text) = value.text { - return Ok(BigDecimal::from_str(text).map_err(|e| format!("bad decimal text: {}", e))?); - } - if let Some(bytes) = value.blob { - return Ok(BigDecimal::parse_bytes(bytes, 10) - .ok_or(format!("bad base10 bytes: {:?}", bytes))?); - } - Err(format!("ODBC: cannot decode BigDecimal: {:?}", value).into()) } } diff --git a/sqlx-core/src/odbc/types/bool.rs b/sqlx-core/src/odbc/types/bool.rs index d5620235ec..2ae39ec8ce 100644 --- a/sqlx-core/src/odbc/types/bool.rs +++ b/sqlx-core/src/odbc/types/bool.rs @@ -40,50 +40,44 @@ impl<'q> Encode<'q, Odbc> for bool { impl<'r> Decode<'r, Odbc> for bool { fn decode(value: OdbcValueRef<'r>) -> Result { - if let Some(i) = value.int { + if let Some(i) = value.int::() { return Ok(i != 0); } - // Handle float values (from DECIMAL/NUMERIC types) - if let Some(f) = value.float { + if let Some(f) = value.float::() { return Ok(f != 0.0); } - if let Some(text) = value.text { + if let Some(text) = value.text() { let text = text.trim(); - // Try exact string matches first return Ok(match text { "0" | "0.0" | "false" | "FALSE" | "f" | "F" => false, "1" | "1.0" | "true" | "TRUE" | "t" | "T" => true, _ => { - // Try parsing as number first if let Ok(num) = text.parse::() { num != 0.0 } else if let Ok(num) = text.parse::() { num != 0 } else { - // Fall back to string parsing - text.parse()? + return Err("provided string was not `true` or `false`".into()); } } }); } - if let Some(bytes) = value.blob { + if let Some(bytes) = value.blob() { let s = std::str::from_utf8(bytes)?; let s = s.trim(); return Ok(match s { "0" | "0.0" | "false" | "FALSE" | "f" | "F" => false, "1" | "1.0" | "true" | "TRUE" | "t" | "T" => true, _ => { - // Try parsing as number first if let Ok(num) = s.parse::() { num != 0.0 } else if let Ok(num) = s.parse::() { num != 0 } else { - // Fall back to string parsing - s.parse()? + return Err("provided string was not `true` or `false`".into()); } } }); @@ -96,41 +90,40 @@ impl<'r> Decode<'r, Odbc> for bool { #[cfg(test)] mod tests { use super::*; - use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::odbc::{ColumnData, OdbcBatch, OdbcColumn, OdbcTypeInfo, OdbcValueVec}; use crate::type_info::TypeInfo; use odbc_api::DataType; + use std::sync::Arc; - fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { + fn make_ref(value_vec: OdbcValueVec, data_type: DataType) -> OdbcValueRef<'static> { + let column = ColumnData { + values: value_vec, type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: Some(text), - blob: None, - int: None, - float: None, - } + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::new(data_type), + ordinal: 0, + }]), + column_data, + }; + let batch_ptr = Box::leak(Box::new(batch)); + OdbcValueRef::new(batch_ptr, 0, 0) + } + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Text(vec![text.to_string()]), data_type) } fn create_test_value_int(value: i64, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { - type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: None, - blob: None, - int: Some(value), - float: None, - } + make_ref(OdbcValueVec::BigInt(vec![value]), data_type) } fn create_test_value_float(value: f64, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { - type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: None, - blob: None, - int: None, - float: Some(value), - } + make_ref(OdbcValueVec::Double(vec![value]), data_type) } #[test] @@ -322,17 +315,30 @@ mod tests { #[test] fn test_bool_decode_error_handling() { - let value = OdbcValueRef { + let column = ColumnData { + values: OdbcValueVec::Text(vec!["not_a_bool".to_string()]), type_info: OdbcTypeInfo::BIT, - is_null: false, - text: None, - blob: None, - int: None, - float: None, + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::BIT, + ordinal: 0, + }]), + column_data, }; + let batch_ptr = Box::leak(Box::new(batch)); + let value = OdbcValueRef::new(batch_ptr, 0, 0); let result = >::decode(value); assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string(), "ODBC: cannot decode bool"); + // The new implementation returns the parsing error before the final fallback + let error_msg = result.unwrap_err().to_string(); + assert!( + error_msg.contains("provided string was not") + || error_msg.contains("ODBC: cannot decode bool") + ); } } diff --git a/sqlx-core/src/odbc/types/bytes.rs b/sqlx-core/src/odbc/types/bytes.rs index 6ad56a7554..f590edac34 100644 --- a/sqlx-core/src/odbc/types/bytes.rs +++ b/sqlx-core/src/odbc/types/bytes.rs @@ -40,19 +40,22 @@ impl<'q> Encode<'q, Odbc> for &'q [u8] { impl<'r> Decode<'r, Odbc> for Vec { fn decode(value: OdbcValueRef<'r>) -> Result { - Ok(<&[u8] as Decode<'r, Odbc>>::decode(value)?.to_vec()) + if let Some(bytes) = value.blob() { + Ok(bytes.to_vec()) + } else if let Some(text) = value.text() { + Ok(text.as_bytes().to_vec()) + } else { + Err("ODBC: cannot decode as Vec".into()) + } } } impl<'r> Decode<'r, Odbc> for &'r [u8] { fn decode(value: OdbcValueRef<'r>) -> Result { - if let Some(bytes) = value.blob { - return Ok(bytes); - } - if let Some(text) = value.text { - return Ok(text.as_bytes()); - } - Err(format!("ODBC: cannot decode {:?} as &[u8]", value).into()) + value + .blob() + .or_else(|| value.text().map(|text| text.as_bytes())) + .ok_or(format!("ODBC: cannot decode as &[u8]: {:?}", value).into()) } } @@ -68,30 +71,36 @@ impl Type for [u8] { #[cfg(test)] mod tests { use super::*; - use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::odbc::{ColumnData, OdbcBatch, OdbcColumn, OdbcTypeInfo, OdbcValueVec}; use crate::type_info::TypeInfo; use odbc_api::DataType; + use std::sync::Arc; - fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { + fn make_ref(value_vec: OdbcValueVec, data_type: DataType) -> OdbcValueRef<'static> { + let column = ColumnData { + values: value_vec, type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: Some(text), - blob: None, - int: None, - float: None, - } + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::new(data_type), + ordinal: 0, + }]), + column_data, + }; + let batch_ptr = Box::leak(Box::new(batch)); + OdbcValueRef::new(batch_ptr, 0, 0) + } + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Text(vec![text.to_string()]), data_type) } fn create_test_value_blob(data: &'static [u8], data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { - type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: None, - blob: Some(data), - int: None, - float: None, - } + make_ref(OdbcValueVec::Binary(vec![data.to_vec()]), data_type) } #[test] @@ -160,11 +169,11 @@ mod tests { fn test_vec_u8_encode() { let mut buf = Vec::new(); let data = vec![65, 66, 67, 68, 69]; // "ABCDE" - let result = as Encode>::encode(data, &mut buf); + let result = as Encode>::encode(data.clone(), &mut buf); assert!(matches!(result, crate::encode::IsNull::No)); assert_eq!(buf.len(), 1); if let OdbcArgumentValue::Bytes(bytes) = &buf[0] { - assert_eq!(*bytes, vec![65, 66, 67, 68, 69]); + assert_eq!(*bytes, data); } else { panic!("Expected Bytes argument"); } @@ -186,15 +195,26 @@ mod tests { #[test] fn test_decode_error_handling() { - let value = OdbcValueRef { + let column = ColumnData { + values: OdbcValueVec::Text(vec!["not_bytes".to_string()]), type_info: OdbcTypeInfo::varbinary(None), - is_null: false, - text: None, - blob: None, - int: None, - float: None, + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::varbinary(None), + ordinal: 0, + }]), + column_data, }; - assert!( as Decode<'_, Odbc>>::decode(value).is_err()); + let batch_ptr = Box::leak(Box::new(batch)); + let value = OdbcValueRef::new(batch_ptr, 0, 0); + // Vec can decode text as bytes, so this should succeed + let result = as Decode<'_, Odbc>>::decode(value); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), b"not_bytes"); } #[test] diff --git a/sqlx-core/src/odbc/types/chrono.rs b/sqlx-core/src/odbc/types/chrono.rs index 178885dacd..136079e0e1 100644 --- a/sqlx-core/src/odbc/types/chrono.rs +++ b/sqlx-core/src/odbc/types/chrono.rs @@ -192,11 +192,11 @@ fn parse_yyyymmdd_text_as_naive_date(s: &str) -> Option { } fn get_text_from_value(value: &OdbcValueRef<'_>) -> Result, BoxDynError> { - if let Some(text) = value.text { + if let Some(text) = value.text() { let trimmed = text.trim_matches('\u{0}').trim(); return Ok(Some(trimmed.to_string())); } - if let Some(bytes) = value.blob { + if let Some(bytes) = value.blob() { let s = std::str::from_utf8(bytes)?; let trimmed = s.trim_matches('\u{0}').trim(); return Ok(Some(trimmed.to_string())); @@ -206,9 +206,21 @@ fn get_text_from_value(value: &OdbcValueRef<'_>) -> Result, BoxDy impl<'r> Decode<'r, Odbc> for NaiveDate { fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle raw ODBC Date values first + if let Some(date_val) = value.date() { + // Convert odbc_api::sys::Date to NaiveDate + // The ODBC Date structure typically has year, month, day fields + return Ok(NaiveDate::from_ymd_opt( + date_val.year as i32, + date_val.month as u32, + date_val.day as u32, + ) + .ok_or_else(|| "ODBC: invalid date values".to_string())?); + } + // Handle text values first (most common for dates) - if let Some(text) = get_text_from_value(&value)? { - if let Some(date) = parse_yyyymmdd_text_as_naive_date(&text) { + if let Some(text) = value.text() { + if let Some(date) = parse_yyyymmdd_text_as_naive_date(text) { return Ok(date); } if let Ok(date) = text.parse() { @@ -217,7 +229,7 @@ impl<'r> Decode<'r, Odbc> for NaiveDate { } // Handle numeric YYYYMMDD format (for databases that return as numbers) - if let Some(int_val) = value.int { + if let Some(int_val) = value.int() { if let Some(date) = parse_yyyymmdd_as_naive_date(int_val) { return Ok(date); } @@ -229,7 +241,7 @@ impl<'r> Decode<'r, Odbc> for NaiveDate { } // Handle float values similarly - if let Some(float_val) = value.float { + if let Some(float_val) = value.float::() { if let Some(date) = parse_yyyymmdd_as_naive_date(float_val as i64) { return Ok(date); } @@ -242,7 +254,7 @@ impl<'r> Decode<'r, Odbc> for NaiveDate { Err(format!( "ODBC: cannot decode NaiveDate from value with type '{}'", - value.type_info.name() + value.batch.columns[value.column_index].type_info.name() ) .into()) } @@ -250,6 +262,18 @@ impl<'r> Decode<'r, Odbc> for NaiveDate { impl<'r> Decode<'r, Odbc> for NaiveTime { fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle raw ODBC Time values first + if let Some(time_val) = value.time() { + // Convert odbc_api::sys::Time to NaiveTime + // The ODBC Time structure typically has hour, minute, second fields + return Ok(NaiveTime::from_hms_opt( + time_val.hour as u32, + time_val.minute as u32, + time_val.second as u32, + ) + .ok_or_else(|| "ODBC: invalid time values".to_string())?); + } + let mut s = >::decode(value)?; if s.ends_with('\u{0}') { s = s.trim_end_matches('\u{0}').to_string(); @@ -263,6 +287,24 @@ impl<'r> Decode<'r, Odbc> for NaiveTime { impl<'r> Decode<'r, Odbc> for NaiveDateTime { fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle raw ODBC Timestamp values first + if let Some(ts_val) = value.timestamp() { + // Convert odbc_api::sys::Timestamp to NaiveDateTime + // The ODBC Timestamp structure typically has year, month, day, hour, minute, second fields + let date = + NaiveDate::from_ymd_opt(ts_val.year as i32, ts_val.month as u32, ts_val.day as u32) + .ok_or_else(|| "ODBC: invalid date values in timestamp".to_string())?; + + let time = NaiveTime::from_hms_opt( + ts_val.hour as u32, + ts_val.minute as u32, + ts_val.second as u32, + ) + .ok_or_else(|| "ODBC: invalid time values in timestamp".to_string())?; + + return Ok(NaiveDateTime::new(date, time)); + } + let mut s = >::decode(value)?; // Some ODBC drivers (e.g. PostgreSQL) may include trailing spaces or NULs // in textual representations of timestamps. Trim them before parsing. @@ -285,6 +327,24 @@ impl<'r> Decode<'r, Odbc> for NaiveDateTime { impl<'r> Decode<'r, Odbc> for DateTime { fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle raw ODBC Timestamp values first + if let Some(ts_val) = value.timestamp() { + // Convert odbc_api::sys::Timestamp to DateTime + // The ODBC Timestamp structure typically has year, month, day, hour, minute, second fields + let naive_dt = NaiveDateTime::new( + NaiveDate::from_ymd_opt(ts_val.year as i32, ts_val.month as u32, ts_val.day as u32) + .ok_or_else(|| "ODBC: invalid date values in timestamp".to_string())?, + NaiveTime::from_hms_opt( + ts_val.hour as u32, + ts_val.minute as u32, + ts_val.second as u32, + ) + .ok_or_else(|| "ODBC: invalid time values in timestamp".to_string())?, + ); + + return Ok(DateTime::::from_naive_utc_and_offset(naive_dt, Utc)); + } + let mut s = >::decode(value)?; if s.ends_with('\u{0}') { s = s.trim_end_matches('\u{0}').to_string(); @@ -312,6 +372,24 @@ impl<'r> Decode<'r, Odbc> for DateTime { impl<'r> Decode<'r, Odbc> for DateTime { fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle raw ODBC Timestamp values first + if let Some(ts_val) = value.timestamp() { + // Convert odbc_api::sys::Timestamp to DateTime + // The ODBC Timestamp structure typically has year, month, day, hour, minute, second fields + let naive_dt = NaiveDateTime::new( + NaiveDate::from_ymd_opt(ts_val.year as i32, ts_val.month as u32, ts_val.day as u32) + .ok_or_else(|| "ODBC: invalid date values in timestamp".to_string())?, + NaiveTime::from_hms_opt( + ts_val.hour as u32, + ts_val.minute as u32, + ts_val.second as u32, + ) + .ok_or_else(|| "ODBC: invalid time values in timestamp".to_string())?, + ); + + return Ok(DateTime::::from_naive_utc_and_offset(naive_dt, Utc).fixed_offset()); + } + let mut s = >::decode(value)?; if s.ends_with('\u{0}') { s = s.trim_end_matches('\u{0}').to_string(); @@ -343,6 +421,26 @@ impl<'r> Decode<'r, Odbc> for DateTime { impl<'r> Decode<'r, Odbc> for DateTime { fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle raw ODBC Timestamp values first + if let Some(ts_val) = value.timestamp() { + // Convert odbc_api::sys::Timestamp to DateTime + // The ODBC Timestamp structure typically has year, month, day, hour, minute, second fields + let naive_dt = NaiveDateTime::new( + NaiveDate::from_ymd_opt(ts_val.year as i32, ts_val.month as u32, ts_val.day as u32) + .ok_or_else(|| "ODBC: invalid date values in timestamp".to_string())?, + NaiveTime::from_hms_opt( + ts_val.hour as u32, + ts_val.minute as u32, + ts_val.second as u32, + ) + .ok_or_else(|| "ODBC: invalid time values in timestamp".to_string())?, + ); + + return Ok( + DateTime::::from_naive_utc_and_offset(naive_dt, Utc).with_timezone(&Local) + ); + } + let mut s = >::decode(value)?; if s.ends_with('\u{0}') { s = s.trim_end_matches('\u{0}').to_string(); @@ -363,31 +461,47 @@ impl<'r> Decode<'r, Odbc> for DateTime { #[cfg(test)] mod tests { use super::*; - use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::odbc::{ + ColumnData, OdbcBatch, OdbcColumn, OdbcTypeInfo, OdbcValueRef, OdbcValueVec, + }; use crate::type_info::TypeInfo; use chrono::{DateTime, NaiveDate, NaiveDateTime, NaiveTime, Utc}; use odbc_api::DataType; + use std::sync::Arc; - fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { + fn make_ref(value_vec: OdbcValueVec, data_type: DataType) -> OdbcValueRef<'static> { + let column = ColumnData { + values: value_vec, type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: Some(text), - blob: None, - int: None, - float: None, - } + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::new(data_type), + ordinal: 0, + }]), + column_data, + }; + let batch_ptr = Box::leak(Box::new(batch)); + OdbcValueRef::new(batch_ptr, 0, 0) + } + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Text(vec![text.to_string()]), data_type) } fn create_test_value_int(value: i64, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { - type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: None, - blob: None, - int: Some(value), - float: None, - } + make_ref(OdbcValueVec::BigInt(vec![value]), data_type) + } + + fn create_test_value_float(value: f64, data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Double(vec![value]), data_type) + } + + fn create_test_value_blob(data: &'static [u8], data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Binary(vec![data.to_vec()]), data_type) } #[test] @@ -521,14 +635,22 @@ mod tests { assert_eq!(get_text_from_value(&value)?, Some("test".to_string())); // From empty - let value = OdbcValueRef { + let column = ColumnData { + values: OdbcValueVec::Text(vec![String::new()]), type_info: OdbcTypeInfo::new(DataType::Date), - is_null: false, - text: None, - blob: None, - int: None, - float: None, + nulls: vec![true], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::new(DataType::Date), + ordinal: 0, + }]), + column_data, }; + let batch_ptr = Box::leak(Box::new(batch)); + let value = OdbcValueRef::new(batch_ptr, 0, 0); assert_eq!(get_text_from_value(&value)?, None); Ok(()) diff --git a/sqlx-core/src/odbc/types/decimal.rs b/sqlx-core/src/odbc/types/decimal.rs index ba796e9b9d..6be43c6f6a 100644 --- a/sqlx-core/src/odbc/types/decimal.rs +++ b/sqlx-core/src/odbc/types/decimal.rs @@ -36,12 +36,13 @@ impl<'q> Encode<'q, Odbc> for Decimal { // Helper function for getting text from value for decimal parsing fn get_text_for_decimal_parsing(value: &OdbcValueRef<'_>) -> Result, BoxDynError> { - if let Some(text) = value.text { + if let Some(text) = value.text() { return Ok(Some(text.trim().to_string())); } - if let Some(bytes) = value.blob { - let s = std::str::from_utf8(bytes)?; - return Ok(Some(s.trim().to_string())); + if let Some(bytes) = value.blob() { + if let Ok(text) = std::str::from_utf8(bytes) { + return Ok(Some(text.trim().to_string())); + } } Ok(None) } @@ -49,12 +50,12 @@ fn get_text_for_decimal_parsing(value: &OdbcValueRef<'_>) -> Result Decode<'r, Odbc> for Decimal { fn decode(value: OdbcValueRef<'r>) -> Result { // Try integer conversion first (most precise) - if let Some(int_val) = value.int { + if let Some(int_val) = value.int::() { return Ok(Decimal::from(int_val)); } // Try direct float conversion for better precision - if let Some(float_val) = value.float { + if let Some(float_val) = value.float::() { if let Ok(decimal) = Decimal::try_from(float_val) { return Ok(decimal); } @@ -72,42 +73,47 @@ impl<'r> Decode<'r, Odbc> for Decimal { #[cfg(test)] mod tests { use super::*; - use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::odbc::{ + ColumnData, OdbcBatch, OdbcColumn, OdbcTypeInfo, OdbcValueRef, OdbcValueVec, + }; use crate::type_info::TypeInfo; use odbc_api::DataType; use std::str::FromStr; + use std::sync::Arc; - fn create_test_value_text(text: &str, data_type: DataType) -> OdbcValueRef<'_> { - OdbcValueRef { + fn make_ref(value_vec: OdbcValueVec, data_type: DataType) -> OdbcValueRef<'static> { + let column = ColumnData { + values: value_vec, type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: Some(text), - blob: None, - int: None, - float: None, - } + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::new(data_type), + ordinal: 0, + }]), + column_data, + }; + let batch_ptr = Box::leak(Box::new(batch)); + OdbcValueRef::new(batch_ptr, 0, 0) + } + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Text(vec![text.to_string()]), data_type) + } + + fn create_test_value_bytes(bytes: &'static [u8], data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Binary(vec![bytes.to_vec()]), data_type) } fn create_test_value_int(value: i64, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { - type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: None, - blob: None, - int: Some(value), - float: None, - } + make_ref(OdbcValueVec::BigInt(vec![value]), data_type) } fn create_test_value_float(value: f64, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { - type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: None, - blob: None, - int: None, - float: Some(value), - } + make_ref(OdbcValueVec::Double(vec![value]), data_type) } #[test] @@ -259,14 +265,22 @@ mod tests { #[test] fn test_decimal_decode_error_handling() { - let value = OdbcValueRef { + let column = ColumnData { + values: OdbcValueVec::Text(vec!["not_a_number".to_string()]), type_info: OdbcTypeInfo::decimal(10, 2), - is_null: false, - text: None, - blob: None, - int: None, - float: None, + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::decimal(10, 2), + ordinal: 0, + }]), + column_data, }; + let batch_ptr = Box::leak(Box::new(batch)); + let value = OdbcValueRef::new(batch_ptr, 0, 0); let result = >::decode(value); assert!(result.is_err()); diff --git a/sqlx-core/src/odbc/types/float.rs b/sqlx-core/src/odbc/types/float.rs index 09ed1fcb90..3c7f4d0c0b 100644 --- a/sqlx-core/src/odbc/types/float.rs +++ b/sqlx-core/src/odbc/types/float.rs @@ -71,21 +71,40 @@ impl<'q> Encode<'q, Odbc> for f64 { impl<'r> Decode<'r, Odbc> for f64 { fn decode(value: OdbcValueRef<'r>) -> Result { - if let Some(f) = value.float { + if let Some(f) = value.float::() { return Ok(f); } - if let Some(int) = value.int { + + if let Some(int) = value.int::() { return Ok(int as f64); } - if let Some(s) = value.text { - return Ok(s.trim().parse()?); + + if let Some(s) = value.text() { + if let Ok(parsed) = s.trim().parse::() { + return Ok(parsed); + } } + Err(format!("ODBC: cannot decode f64: {:?}", value).into()) } } impl<'r> Decode<'r, Odbc> for f32 { fn decode(value: OdbcValueRef<'r>) -> Result { - Ok(>::decode(value)? as f32) + if let Some(f) = value.float::() { + return Ok(f as f32); + } + + if let Some(int) = value.int::() { + return Ok(int as f32); + } + + if let Some(s) = value.text() { + if let Ok(parsed) = s.trim().parse::() { + return Ok(parsed); + } + } + + Err(format!("ODBC: cannot decode f32: {:?}", value).into()) } } diff --git a/sqlx-core/src/odbc/types/int.rs b/sqlx-core/src/odbc/types/int.rs index 485d963194..e4da729e4a 100644 --- a/sqlx-core/src/odbc/types/int.rs +++ b/sqlx-core/src/odbc/types/int.rs @@ -263,10 +263,10 @@ fn parse_numeric_as_i64(s: &str) -> Option { } fn get_text_for_numeric_parsing(value: &OdbcValueRef<'_>) -> Result, BoxDynError> { - if let Some(text) = value.text { + if let Some(text) = value.text() { return Ok(Some(text.trim().to_string())); } - if let Some(bytes) = value.blob { + if let Some(bytes) = value.blob() { let s = std::str::from_utf8(bytes)?; return Ok(Some(s.trim().to_string())); } @@ -275,104 +275,94 @@ fn get_text_for_numeric_parsing(value: &OdbcValueRef<'_>) -> Result Decode<'r, Odbc> for i64 { fn decode(value: OdbcValueRef<'r>) -> Result { - if let Some(i) = value.int { - return Ok(i); - } - if let Some(f) = value.float { - return Ok(f as i64); - } - if let Some(text) = get_text_for_numeric_parsing(&value)? { - if let Some(parsed) = parse_numeric_as_i64(&text) { - return Ok(parsed); - } - } - Err("ODBC: cannot decode i64".into()) + Ok(value.try_int::()?) } } impl<'r> Decode<'r, Odbc> for i32 { fn decode(value: OdbcValueRef<'r>) -> Result { - Ok(>::decode(value)? as i32) + Ok(value.try_int::()?) } } impl<'r> Decode<'r, Odbc> for i16 { fn decode(value: OdbcValueRef<'r>) -> Result { - Ok(>::decode(value)? as i16) + Ok(value.try_int::()?) } } impl<'r> Decode<'r, Odbc> for i8 { fn decode(value: OdbcValueRef<'r>) -> Result { - Ok(>::decode(value)? as i8) + Ok(value.try_int::()?) } } impl<'r> Decode<'r, Odbc> for u8 { fn decode(value: OdbcValueRef<'r>) -> Result { - let i = >::decode(value)?; - Ok(u8::try_from(i)?) + Ok(value.try_int::()?) } } impl<'r> Decode<'r, Odbc> for u16 { fn decode(value: OdbcValueRef<'r>) -> Result { - let i = >::decode(value)?; - Ok(u16::try_from(i)?) + Ok(value.try_int::()?) } } impl<'r> Decode<'r, Odbc> for u32 { fn decode(value: OdbcValueRef<'r>) -> Result { - let i = >::decode(value)?; - Ok(u32::try_from(i)?) + Ok(value.try_int::()?) } } impl<'r> Decode<'r, Odbc> for u64 { fn decode(value: OdbcValueRef<'r>) -> Result { - let i = >::decode(value)?; - Ok(u64::try_from(i)?) + Ok(value.try_int::()?) } } #[cfg(test)] mod tests { use super::*; - use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::odbc::{ + ColumnData, OdbcBatch, OdbcColumn, OdbcTypeInfo, OdbcValueRef, OdbcValueVec, + }; use odbc_api::DataType; + use std::sync::Arc; - fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { + fn make_ref(value_vec: OdbcValueVec, data_type: DataType) -> OdbcValueRef<'static> { + let column = ColumnData { + values: value_vec, type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: Some(text), - blob: None, - int: None, - float: None, - } + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::new(data_type), + ordinal: 0, + }]), + column_data, + }; + let batch_ptr = Box::leak(Box::new(batch)); + OdbcValueRef::new(batch_ptr, 0, 0) + } + + fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Text(vec![text.to_string()]), data_type) + } + + fn create_test_value_blob(data: &'static [u8], data_type: DataType) -> OdbcValueRef<'static> { + make_ref(OdbcValueVec::Binary(vec![data.to_vec()]), data_type) } fn create_test_value_int(value: i64, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { - type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: None, - blob: None, - int: Some(value), - float: None, - } + make_ref(OdbcValueVec::BigInt(vec![value]), data_type) } fn create_test_value_float(value: f64, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { - type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: None, - blob: None, - int: None, - float: Some(value), - } + make_ref(OdbcValueVec::Double(vec![value]), data_type) } #[test] @@ -411,18 +401,7 @@ mod tests { scale: 0, }, ); - let decoded = >::decode(value)?; - assert_eq!(decoded, 42); - - // Test with decimal value (should truncate) - let value = create_test_value_text( - "42.7", - DataType::Decimal { - precision: 10, - scale: 1, - }, - ); - let decoded = >::decode(value)?; + let decoded = >::decode(value).expect("Failed to decode 42"); assert_eq!(decoded, 42); // Test with whitespace @@ -433,7 +412,7 @@ mod tests { scale: 0, }, ); - let decoded = >::decode(value)?; + let decoded = >::decode(value).expect("Failed to decode ' 123 '"); assert_eq!(decoded, 123); Ok(()) @@ -451,8 +430,10 @@ mod tests { #[test] fn test_i64_decode_from_float() -> Result<(), BoxDynError> { let value = create_test_value_float(42.7, DataType::Double); - let decoded = >::decode(value)?; - assert_eq!(decoded, 42); + let result = >::decode(value); + // i64 should not be compatible with DOUBLE type + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("mismatched types")); Ok(()) } @@ -520,18 +501,30 @@ mod tests { #[test] fn test_decode_error_handling() { - let value = OdbcValueRef { + let column = ColumnData { + values: OdbcValueVec::Text(vec!["not_a_number".to_string()]), type_info: OdbcTypeInfo::INTEGER, - is_null: false, - text: None, - blob: None, - int: None, - float: None, + nulls: vec![false], }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::INTEGER, + ordinal: 0, + }]), + column_data, + }; + let batch_ptr = Box::leak(Box::new(batch)); + let value = OdbcValueRef::new(batch_ptr, 0, 0); let result = >::decode(value); assert!(result.is_err()); - assert_eq!(result.unwrap_err().to_string(), "ODBC: cannot decode i64"); + // The new implementation gives more specific error messages + let error_msg = result.unwrap_err().to_string(); + assert!( + error_msg.contains("mismatched types") || error_msg.contains("ODBC: cannot decode") + ); } #[test] diff --git a/sqlx-core/src/odbc/types/json.rs b/sqlx-core/src/odbc/types/json.rs index dcd7db14c4..e2d77701b3 100644 --- a/sqlx-core/src/odbc/types/json.rs +++ b/sqlx-core/src/odbc/types/json.rs @@ -3,7 +3,6 @@ use crate::encode::Encode; use crate::error::BoxDynError; use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef}; use crate::types::Type; -use serde::de::Error; use serde_json::Value; impl Type for Value { @@ -35,38 +34,50 @@ impl<'q> Encode<'q, Odbc> for Value { impl<'r> Decode<'r, Odbc> for Value { fn decode(value: OdbcValueRef<'r>) -> Result { - if let Some(bytes) = value.blob { - serde_json::from_slice(bytes) - } else if let Some(text) = value.text { - serde_json::from_str(text) - } else if let Some(i) = value.int { - Ok(serde_json::Value::from(i)) - } else if let Some(f) = value.float { - Ok(serde_json::Value::from(f)) - } else { - Err(serde_json::Error::custom("not a valid json type")) + if let Some(bytes) = value.blob() { + return serde_json::from_slice(bytes) + .map_err(|e| format!("ODBC: cannot decode JSON from {:?}: {}", value, e).into()); + } else if let Some(text) = value.text() { + return serde_json::from_str(text) + .map_err(|e| format!("ODBC: cannot decode JSON from {:?}: {}", value, e).into()); + } else if let Some(i) = value.int::() { + return Ok(Value::from(i)); + } else if let Some(f) = value.float::() { + return Ok(Value::from(f)); } - .map_err(|e| format!("ODBC: cannot decode JSON from {:?}: {}", value, e).into()) + + Err(format!("ODBC: cannot decode JSON from {:?}", value).into()) } } #[cfg(test)] mod tests { use super::*; - use crate::odbc::{OdbcTypeInfo, OdbcValueRef}; + use crate::odbc::{ + ColumnData, OdbcBatch, OdbcColumn, OdbcTypeInfo, OdbcValueRef, OdbcValueVec, + }; use crate::type_info::TypeInfo; use odbc_api::DataType; use serde_json::{json, Value}; + use std::sync::Arc; fn create_test_value_text(text: &'static str, data_type: DataType) -> OdbcValueRef<'static> { - OdbcValueRef { + let column = ColumnData { + values: OdbcValueVec::Text(vec![text.to_string()]), type_info: OdbcTypeInfo::new(data_type), - is_null: false, - text: Some(text), - blob: None, - int: None, - float: None, - } + nulls: vec![false], + }; + let column_data = vec![Arc::new(column)]; + let batch = OdbcBatch { + columns: Arc::new([OdbcColumn { + name: "test".to_string(), + type_info: OdbcTypeInfo::new(data_type), + ordinal: 0, + }]), + column_data, + }; + let batch_ptr = Box::leak(Box::new(batch)); + OdbcValueRef::new(batch_ptr, 0, 0) } #[test] diff --git a/sqlx-core/src/odbc/types/str.rs b/sqlx-core/src/odbc/types/str.rs index 32207efb6c..ae568d2341 100644 --- a/sqlx-core/src/odbc/types/str.rs +++ b/sqlx-core/src/odbc/types/str.rs @@ -48,11 +48,11 @@ impl<'q> Encode<'q, Odbc> for &'q str { impl<'r> Decode<'r, Odbc> for String { fn decode(value: OdbcValueRef<'r>) -> Result { - if let Some(text) = value.text { - return Ok(text.to_owned()); + if let Some(text) = value.text() { + return Ok(text.to_string()); } - if let Some(bytes) = value.blob { - return Ok(std::str::from_utf8(bytes)?.to_owned()); + if let Some(bytes) = value.blob() { + return Ok(String::from_utf8(bytes.to_vec())?); } Err("ODBC: cannot decode String".into()) } @@ -60,12 +60,12 @@ impl<'r> Decode<'r, Odbc> for String { impl<'r> Decode<'r, Odbc> for &'r str { fn decode(value: OdbcValueRef<'r>) -> Result { - if let Some(text) = value.text { + if let Some(text) = value.text() { return Ok(text); } - if let Some(bytes) = value.blob { + if let Some(bytes) = value.blob() { return Ok(std::str::from_utf8(bytes)?); } - Err("ODBC: cannot decode &str".into()) + Err(format!("ODBC: cannot decode &str: {:?}", value).into()) } } diff --git a/sqlx-core/src/odbc/types/time.rs b/sqlx-core/src/odbc/types/time.rs index 3d0d0d0d44..3300c6fc82 100644 --- a/sqlx-core/src/odbc/types/time.rs +++ b/sqlx-core/src/odbc/types/time.rs @@ -103,20 +103,20 @@ fn parse_unix_timestamp_as_offset_datetime(timestamp: i64) -> Option Decode<'r, Odbc> for OffsetDateTime { fn decode(value: OdbcValueRef<'r>) -> Result { // Handle numeric timestamps (Unix epoch seconds) first - if let Some(int_val) = value.int { + if let Some(int_val) = value.int() { if let Some(dt) = parse_unix_timestamp_as_offset_datetime(int_val) { return Ok(dt); } } - if let Some(float_val) = value.float { + if let Some(float_val) = value.float::() { if let Some(dt) = parse_unix_timestamp_as_offset_datetime(float_val as i64) { return Ok(dt); } } // Handle text values - if let Some(text) = value.text { + if let Some(text) = value.text() { let trimmed = text.trim(); // Try parsing as ISO-8601 timestamp with timezone if let Ok(dt) = OffsetDateTime::parse( @@ -148,14 +148,14 @@ impl<'r> Decode<'r, Odbc> for OffsetDateTime { impl<'r> Decode<'r, Odbc> for PrimitiveDateTime { fn decode(value: OdbcValueRef<'r>) -> Result { // Handle numeric timestamps (Unix epoch seconds) first - if let Some(int_val) = value.int { + if let Some(int_val) = value.int() { if let Some(offset_dt) = parse_unix_timestamp_as_offset_datetime(int_val) { let utc_dt = offset_dt.to_offset(time::UtcOffset::UTC); return Ok(PrimitiveDateTime::new(utc_dt.date(), utc_dt.time())); } } - if let Some(float_val) = value.float { + if let Some(float_val) = value.float::() { if let Some(offset_dt) = parse_unix_timestamp_as_offset_datetime(float_val as i64) { let utc_dt = offset_dt.to_offset(time::UtcOffset::UTC); return Ok(PrimitiveDateTime::new(utc_dt.date(), utc_dt.time())); @@ -163,7 +163,7 @@ impl<'r> Decode<'r, Odbc> for PrimitiveDateTime { } // Handle text values - if let Some(text) = value.text { + if let Some(text) = value.text() { let trimmed = text.trim(); // Try parsing as ISO-8601 if let Ok(dt) = PrimitiveDateTime::parse( @@ -227,8 +227,21 @@ fn parse_yyyymmdd_text_as_time_date(s: &str) -> Option { impl<'r> Decode<'r, Odbc> for Date { fn decode(value: OdbcValueRef<'r>) -> Result { + // Handle raw ODBC Date values first + if let Some(date_val) = value.date() { + // Convert odbc_api::sys::Date to time::Date + // The ODBC Date structure typically has year, month, day fields + let month = time::Month::try_from(date_val.month as u8) + .map_err(|_| "ODBC: invalid month value")?; + return Ok(Date::from_calendar_date( + date_val.year as i32, + month, + date_val.day as u8, + )?); + } + // Handle numeric YYYYMMDD format first - if let Some(int_val) = value.int { + if let Some(int_val) = value.int() { if let Some(date) = parse_yyyymmdd_as_time_date(int_val) { return Ok(date); } @@ -243,14 +256,14 @@ impl<'r> Decode<'r, Odbc> for Date { } // Handle float values - if let Some(float_val) = value.float { + if let Some(float_val) = value.float::() { if let Some(date) = parse_yyyymmdd_as_time_date(float_val as i64) { return Ok(date); } } // Handle text values - if let Some(text) = value.text { + if let Some(text) = value.text() { let trimmed = text.trim(); if let Some(date) = parse_yyyymmdd_text_as_time_date(trimmed) { return Ok(date); @@ -290,20 +303,20 @@ fn parse_seconds_as_time(seconds: i64) -> Option