Skip to content

Commit 6fa10ae

Browse files
committed
refactor(odbc): remove some code duplication
This commit introduces generic helper functions for handling non-nullable and nullable slices in the conversion of AnySlice to OdbcValueVec. The refactor improves code readability and maintainability by reducing redundancy in the conversion logic. Additionally, the OdbcBridge is updated to utilize a helper function for determining buffer lengths, streamlining the buffer description process for various data types.
1 parent 0a66a1b commit 6fa10ae

File tree

2 files changed

+145
-123
lines changed

2 files changed

+145
-123
lines changed

sqlx-core/src/odbc/connection/odbc_bridge.rs

Lines changed: 25 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ use either::Either;
88
use flume::{SendError, Sender};
99
use odbc_api::buffers::{AnySlice, BufferDesc, ColumnarAnyBuffer};
1010
use odbc_api::handles::{AsStatementRef, Nullability, Statement};
11-
use odbc_api::DataType;
1211
use odbc_api::{Cursor, IntoParameter, ResultSetMetadata};
1312
use std::sync::Arc;
1413

@@ -189,31 +188,44 @@ fn map_buffer_desc<C>(
189188
where
190189
C: ResultSetMetadata,
191190
{
191+
use odbc_api::DataType;
192+
192193
let data_type = type_info.data_type();
194+
195+
// Helper function to determine buffer length with fallback
196+
let buffer_length = |length: Option<std::num::NonZeroUsize>| {
197+
if let Some(length) = length {
198+
if length.get() < 255 {
199+
length.get()
200+
} else {
201+
buffer_settings.max_column_size
202+
}
203+
} else {
204+
buffer_settings.max_column_size
205+
}
206+
};
207+
193208
let buffer_desc = match data_type {
209+
// Integer types - all map to I64
194210
DataType::TinyInt | DataType::SmallInt | DataType::Integer | DataType::BigInt => {
195211
BufferDesc::I64 { nullable }
196212
}
213+
// Floating point types
197214
DataType::Real => BufferDesc::F32 { nullable },
198215
DataType::Float { .. } | DataType::Double => BufferDesc::F64 { nullable },
216+
// Bit type
199217
DataType::Bit => BufferDesc::Bit { nullable },
218+
// Date/Time types
200219
DataType::Date => BufferDesc::Date { nullable },
201220
DataType::Time { .. } => BufferDesc::Time { nullable },
202221
DataType::Timestamp { .. } => BufferDesc::Timestamp { nullable },
222+
// Binary types
203223
DataType::Binary { length }
204224
| DataType::Varbinary { length }
205225
| DataType::LongVarbinary { length } => BufferDesc::Binary {
206-
length: if let Some(length) = length {
207-
if length.get() < 255 {
208-
// Some drivers report 255 for max length
209-
length.get()
210-
} else {
211-
buffer_settings.max_column_size
212-
}
213-
} else {
214-
buffer_settings.max_column_size
215-
},
226+
length: buffer_length(length),
216227
},
228+
// Text types
217229
DataType::Char { length }
218230
| DataType::WChar { length }
219231
| DataType::Varchar { length }
@@ -224,16 +236,9 @@ where
224236
column_size: length,
225237
..
226238
} => BufferDesc::Text {
227-
max_str_len: if let Some(length) = length {
228-
if length.get() < 255 {
229-
length.get()
230-
} else {
231-
buffer_settings.max_column_size
232-
}
233-
} else {
234-
buffer_settings.max_column_size
235-
},
239+
max_str_len: buffer_length(length),
236240
},
241+
// Fallback cases
237242
DataType::Unknown => BufferDesc::Text {
238243
max_str_len: buffer_settings.max_column_size,
239244
},

sqlx-core/src/odbc/value.rs

Lines changed: 120 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use crate::odbc::{Odbc, OdbcBatch, OdbcTypeInfo};
22
use crate::value::{Value, ValueRef};
3-
use odbc_api::buffers::AnySlice;
3+
use odbc_api::buffers::{AnySlice, NullableSlice};
44
use odbc_api::sys::NULL_DATA;
55
use std::borrow::Cow;
66
use std::sync::Arc;
@@ -284,65 +284,66 @@ pub enum OdbcValueType {
284284
Timestamp(odbc_api::sys::Timestamp),
285285
}
286286

287+
/// Generic helper function to handle non-nullable slices
288+
fn handle_non_nullable_slice<T: Copy>(
289+
slice: &[T],
290+
constructor: fn(Vec<T>) -> OdbcValueVec,
291+
) -> (OdbcValueVec, Vec<bool>) {
292+
let vec = slice.to_vec();
293+
(constructor(vec), vec![false; slice.len()])
294+
}
295+
296+
/// Generic helper function to handle nullable slices with custom default values
297+
fn handle_nullable_slice<'a, T: Default + Copy>(
298+
slice: NullableSlice<'a, T>,
299+
constructor: fn(Vec<T>) -> OdbcValueVec,
300+
) -> (OdbcValueVec, Vec<bool>) {
301+
let size = slice.size_hint().1.unwrap_or(0);
302+
let mut values = Vec::with_capacity(size);
303+
let mut nulls = Vec::with_capacity(size);
304+
for opt in slice {
305+
values.push(opt.copied().unwrap_or_default());
306+
nulls.push(opt.is_none());
307+
}
308+
(constructor(values), nulls)
309+
}
310+
311+
/// Generic helper function to handle nullable slices with NULL_DATA indicators
312+
fn handle_nullable_with_indicators<T: Default + Copy>(
313+
raw_values: &[T],
314+
indicators: &[isize],
315+
constructor: fn(Vec<T>) -> OdbcValueVec,
316+
) -> (OdbcValueVec, Vec<bool>) {
317+
let nulls = indicators.iter().map(|&ind| ind == NULL_DATA).collect();
318+
(constructor(raw_values.to_vec()), nulls)
319+
}
320+
287321
/// Convert AnySlice to owned OdbcValueVec and nulls vector, preserving original types
288322
pub fn convert_any_slice_to_value_vec(slice: AnySlice<'_>) -> (OdbcValueVec, Vec<bool>) {
289323
match slice {
290-
AnySlice::I8(s) => (OdbcValueVec::TinyInt(s.to_vec()), vec![false; s.len()]),
291-
AnySlice::I16(s) => (OdbcValueVec::SmallInt(s.to_vec()), vec![false; s.len()]),
292-
AnySlice::I32(s) => (OdbcValueVec::Integer(s.to_vec()), vec![false; s.len()]),
293-
AnySlice::I64(s) => (OdbcValueVec::BigInt(s.to_vec()), vec![false; s.len()]),
294-
AnySlice::F32(s) => (OdbcValueVec::Real(s.to_vec()), vec![false; s.len()]),
295-
AnySlice::F64(s) => (OdbcValueVec::Double(s.to_vec()), vec![false; s.len()]),
296-
AnySlice::Bit(s) => (OdbcValueVec::Bit(s.to_vec()), vec![false; s.len()]),
297-
298-
AnySlice::NullableI8(s) => {
299-
let values: Vec<Option<i8>> = s.map(|opt| opt.copied()).collect();
300-
let nulls = values.iter().map(|opt| opt.is_none()).collect();
301-
(
302-
OdbcValueVec::TinyInt(values.into_iter().map(|opt| opt.unwrap_or(0)).collect()),
303-
nulls,
304-
)
305-
}
306-
AnySlice::NullableI16(s) => {
307-
let values: Vec<Option<i16>> = s.map(|opt| opt.copied()).collect();
308-
let nulls = values.iter().map(|opt| opt.is_none()).collect();
309-
(
310-
OdbcValueVec::SmallInt(values.into_iter().map(|opt| opt.unwrap_or(0)).collect()),
311-
nulls,
312-
)
313-
}
314-
AnySlice::NullableI32(s) => {
315-
let values: Vec<Option<i32>> = s.map(|opt| opt.copied()).collect();
316-
let nulls = values.iter().map(|opt| opt.is_none()).collect();
317-
(
318-
OdbcValueVec::Integer(values.into_iter().map(|opt| opt.unwrap_or(0)).collect()),
319-
nulls,
320-
)
321-
}
322-
AnySlice::NullableI64(s) => {
323-
let values: Vec<Option<i64>> = s.map(|opt| opt.copied()).collect();
324-
let nulls = values.iter().map(|opt| opt.is_none()).collect();
325-
(
326-
OdbcValueVec::BigInt(values.into_iter().map(|opt| opt.unwrap_or(0)).collect()),
327-
nulls,
328-
)
329-
}
330-
AnySlice::NullableF32(s) => {
331-
let values: Vec<Option<f32>> = s.map(|opt| opt.copied()).collect();
332-
let nulls = values.iter().map(|opt| opt.is_none()).collect();
333-
(
334-
OdbcValueVec::Real(values.into_iter().map(|opt| opt.unwrap_or(0.0)).collect()),
335-
nulls,
336-
)
337-
}
338-
AnySlice::NullableF64(s) => {
339-
let values: Vec<Option<f64>> = s.map(|opt| opt.copied()).collect();
340-
let nulls = values.iter().map(|opt| opt.is_none()).collect();
341-
(
342-
OdbcValueVec::Double(values.into_iter().map(|opt| opt.unwrap_or(0.0)).collect()),
343-
nulls,
344-
)
345-
}
324+
// Non-nullable integer types
325+
AnySlice::I8(s) => handle_non_nullable_slice(s, OdbcValueVec::TinyInt),
326+
AnySlice::I16(s) => handle_non_nullable_slice(s, OdbcValueVec::SmallInt),
327+
AnySlice::I32(s) => handle_non_nullable_slice(s, OdbcValueVec::Integer),
328+
AnySlice::I64(s) => handle_non_nullable_slice(s, OdbcValueVec::BigInt),
329+
330+
// Non-nullable floating point types
331+
AnySlice::F32(s) => handle_non_nullable_slice(s, OdbcValueVec::Real),
332+
AnySlice::F64(s) => handle_non_nullable_slice(s, OdbcValueVec::Double),
333+
334+
// Non-nullable other types
335+
AnySlice::Bit(s) => handle_non_nullable_slice(s, OdbcValueVec::Bit),
336+
AnySlice::Date(s) => handle_non_nullable_slice(s, OdbcValueVec::Date),
337+
AnySlice::Time(s) => handle_non_nullable_slice(s, OdbcValueVec::Time),
338+
AnySlice::Timestamp(s) => handle_non_nullable_slice(s, OdbcValueVec::Timestamp),
339+
340+
// Nullable integer types
341+
AnySlice::NullableI8(s) => handle_nullable_slice(s, OdbcValueVec::TinyInt),
342+
AnySlice::NullableI16(s) => handle_nullable_slice(s, OdbcValueVec::SmallInt),
343+
AnySlice::NullableI32(s) => handle_nullable_slice(s, OdbcValueVec::Integer),
344+
AnySlice::NullableI64(s) => handle_nullable_slice(s, OdbcValueVec::BigInt),
345+
AnySlice::NullableF32(s) => handle_nullable_slice(s, OdbcValueVec::Real),
346+
AnySlice::NullableF64(s) => handle_nullable_slice(s, OdbcValueVec::Double),
346347
AnySlice::NullableBit(s) => {
347348
let values: Vec<Option<odbc_api::Bit>> = s.map(|opt| opt.copied()).collect();
348349
let nulls = values.iter().map(|opt| opt.is_none()).collect();
@@ -357,6 +358,7 @@ pub fn convert_any_slice_to_value_vec(slice: AnySlice<'_>) -> (OdbcValueVec, Vec
357358
)
358359
}
359360

361+
// Text and binary types (inherently nullable)
360362
AnySlice::Text(s) => {
361363
let texts: Vec<Option<String>> = s
362364
.iter()
@@ -365,7 +367,6 @@ pub fn convert_any_slice_to_value_vec(slice: AnySlice<'_>) -> (OdbcValueVec, Vec
365367
let nulls = texts.iter().map(|opt| opt.is_none()).collect();
366368
(OdbcValueVec::Text(texts), nulls)
367369
}
368-
369370
AnySlice::Binary(s) => {
370371
let binaries: Vec<Option<Vec<u8>>> = s
371372
.iter()
@@ -375,23 +376,18 @@ pub fn convert_any_slice_to_value_vec(slice: AnySlice<'_>) -> (OdbcValueVec, Vec
375376
(OdbcValueVec::Binary(binaries), nulls)
376377
}
377378

378-
AnySlice::Date(s) => (OdbcValueVec::Date(s.to_vec()), vec![false; s.len()]),
379-
AnySlice::Time(s) => (OdbcValueVec::Time(s.to_vec()), vec![false; s.len()]),
380-
AnySlice::Timestamp(s) => (OdbcValueVec::Timestamp(s.to_vec()), vec![false; s.len()]),
379+
// Nullable date/time types with NULL_DATA indicators
381380
AnySlice::NullableDate(s) => {
382381
let (raw_values, indicators) = s.raw_values();
383-
let nulls = indicators.iter().map(|&ind| ind == NULL_DATA).collect();
384-
(OdbcValueVec::Date(raw_values.to_vec()), nulls)
382+
handle_nullable_with_indicators(raw_values, indicators, OdbcValueVec::Date)
385383
}
386384
AnySlice::NullableTime(s) => {
387385
let (raw_values, indicators) = s.raw_values();
388-
let nulls = indicators.iter().map(|&ind| ind == NULL_DATA).collect();
389-
(OdbcValueVec::Time(raw_values.to_vec()), nulls)
386+
handle_nullable_with_indicators(raw_values, indicators, OdbcValueVec::Time)
390387
}
391388
AnySlice::NullableTimestamp(s) => {
392389
let (raw_values, indicators) = s.raw_values();
393-
let nulls = indicators.iter().map(|&ind| ind == NULL_DATA).collect();
394-
(OdbcValueVec::Timestamp(raw_values.to_vec()), nulls)
390+
handle_nullable_with_indicators(raw_values, indicators, OdbcValueVec::Timestamp)
395391
}
396392

397393
_ => panic!("Unsupported AnySlice variant"),
@@ -402,31 +398,36 @@ fn value_vec_is_null(column_data: &ColumnData, row_index: usize) -> bool {
402398
column_data.nulls.get(row_index).copied().unwrap_or(false)
403399
}
404400

401+
macro_rules! impl_get_raw_arm {
402+
($vec:expr, $row_index:expr, $variant:ident, $type:ty) => {
403+
$vec.get($row_index)
404+
.map(|&val| OdbcValueType::$variant(val))
405+
};
406+
($vec:expr, $row_index:expr, $variant:ident, $type:ty, copy) => {
407+
$vec.get($row_index).copied().map(OdbcValueType::$variant)
408+
};
409+
($vec:expr, $row_index:expr, $variant:ident, $type:ty, clone) => {
410+
$vec.get($row_index)
411+
.and_then(|opt| opt.clone().map(OdbcValueType::$variant))
412+
};
413+
}
414+
405415
fn value_vec_get_raw(values: &OdbcValueVec, row_index: usize) -> Option<OdbcValueType> {
406416
match values {
407-
OdbcValueVec::TinyInt(v) => v.get(row_index).map(|&val| OdbcValueType::TinyInt(val)),
408-
OdbcValueVec::SmallInt(v) => v.get(row_index).map(|&val| OdbcValueType::SmallInt(val)),
409-
OdbcValueVec::Integer(v) => v.get(row_index).map(|&val| OdbcValueType::Integer(val)),
410-
OdbcValueVec::BigInt(v) => v.get(row_index).map(|&val| OdbcValueType::BigInt(val)),
411-
OdbcValueVec::Real(v) => v.get(row_index).map(|&val| OdbcValueType::Real(val)),
412-
OdbcValueVec::Double(v) => v.get(row_index).map(|&val| OdbcValueType::Double(val)),
413-
OdbcValueVec::Bit(v) => v.get(row_index).map(|&val| OdbcValueType::Bit(val)),
414-
OdbcValueVec::Text(v) => v
415-
.get(row_index)
416-
.and_then(|opt| opt.clone().map(OdbcValueType::Text)),
417-
OdbcValueVec::Binary(v) => v
418-
.get(row_index)
419-
.and_then(|opt| opt.clone().map(OdbcValueType::Binary)),
420-
OdbcValueVec::Date(raw_values) => {
421-
raw_values.get(row_index).copied().map(OdbcValueType::Date)
422-
}
423-
OdbcValueVec::Time(raw_values) => {
424-
raw_values.get(row_index).copied().map(OdbcValueType::Time)
417+
OdbcValueVec::TinyInt(v) => impl_get_raw_arm!(v, row_index, TinyInt, i8),
418+
OdbcValueVec::SmallInt(v) => impl_get_raw_arm!(v, row_index, SmallInt, i16),
419+
OdbcValueVec::Integer(v) => impl_get_raw_arm!(v, row_index, Integer, i32),
420+
OdbcValueVec::BigInt(v) => impl_get_raw_arm!(v, row_index, BigInt, i64),
421+
OdbcValueVec::Real(v) => impl_get_raw_arm!(v, row_index, Real, f32),
422+
OdbcValueVec::Double(v) => impl_get_raw_arm!(v, row_index, Double, f64),
423+
OdbcValueVec::Bit(v) => impl_get_raw_arm!(v, row_index, Bit, odbc_api::Bit),
424+
OdbcValueVec::Text(v) => impl_get_raw_arm!(v, row_index, Text, Option<String>, clone),
425+
OdbcValueVec::Binary(v) => impl_get_raw_arm!(v, row_index, Binary, Option<Vec<u8>>, clone),
426+
OdbcValueVec::Date(v) => impl_get_raw_arm!(v, row_index, Date, odbc_api::sys::Date, copy),
427+
OdbcValueVec::Time(v) => impl_get_raw_arm!(v, row_index, Time, odbc_api::sys::Time, copy),
428+
OdbcValueVec::Timestamp(v) => {
429+
impl_get_raw_arm!(v, row_index, Timestamp, odbc_api::sys::Timestamp, copy)
425430
}
426-
OdbcValueVec::Timestamp(raw_values) => raw_values
427-
.get(row_index)
428-
.copied()
429-
.map(OdbcValueType::Timestamp),
430431
}
431432
}
432433

@@ -457,20 +458,30 @@ impl<
457458
{
458459
}
459460

461+
macro_rules! impl_int_conversion {
462+
($vec:expr, $row_index:expr, $type:ty) => {
463+
<$type>::try_from(*$vec.get($row_index)?).ok()
464+
};
465+
($vec:expr, $row_index:expr, $type:ty, bit) => {
466+
<$type>::try_from($vec.get($row_index)?.0).ok()
467+
};
468+
($vec:expr, $row_index:expr, $type:ty, text) => {
469+
if let Some(Some(text)) = $vec.get($row_index) {
470+
text.trim().parse().ok()
471+
} else {
472+
None
473+
}
474+
};
475+
}
476+
460477
fn value_vec_int<T: TryFromInt>(values: &OdbcValueVec, row_index: usize) -> Option<T> {
461478
match values {
462-
OdbcValueVec::TinyInt(v) => T::try_from(*v.get(row_index)?).ok(),
463-
OdbcValueVec::SmallInt(v) => T::try_from(*v.get(row_index)?).ok(),
464-
OdbcValueVec::Integer(v) => T::try_from(*v.get(row_index)?).ok(),
465-
OdbcValueVec::BigInt(v) => T::try_from(*v.get(row_index)?).ok(),
466-
OdbcValueVec::Bit(v) => T::try_from(v.get(row_index)?.0).ok(),
467-
OdbcValueVec::Text(v) => {
468-
if let Some(Some(text)) = v.get(row_index) {
469-
text.trim().parse().ok()
470-
} else {
471-
None
472-
}
473-
}
479+
OdbcValueVec::TinyInt(v) => impl_int_conversion!(v, row_index, T),
480+
OdbcValueVec::SmallInt(v) => impl_int_conversion!(v, row_index, T),
481+
OdbcValueVec::Integer(v) => impl_int_conversion!(v, row_index, T),
482+
OdbcValueVec::BigInt(v) => impl_int_conversion!(v, row_index, T),
483+
OdbcValueVec::Bit(v) => impl_int_conversion!(v, row_index, T, bit),
484+
OdbcValueVec::Text(v) => impl_int_conversion!(v, row_index, T, text),
474485
_ => None,
475486
}
476487
}
@@ -479,10 +490,16 @@ pub trait TryFromFloat: TryFrom<f32> + TryFrom<f64> {}
479490

480491
impl<T: TryFrom<f32> + TryFrom<f64>> TryFromFloat for T {}
481492

493+
macro_rules! impl_float_conversion {
494+
($vec:expr, $row_index:expr, $type:ty) => {
495+
<$type>::try_from(*$vec.get($row_index)?).ok()
496+
};
497+
}
498+
482499
fn value_vec_float<T: TryFromFloat>(values: &OdbcValueVec, row_index: usize) -> Option<T> {
483500
match values {
484-
OdbcValueVec::Real(v) => T::try_from(*v.get(row_index)?).ok(),
485-
OdbcValueVec::Double(v) => T::try_from(*v.get(row_index)?).ok(),
501+
OdbcValueVec::Real(v) => impl_float_conversion!(v, row_index, T),
502+
OdbcValueVec::Double(v) => impl_float_conversion!(v, row_index, T),
486503
_ => None,
487504
}
488505
}

0 commit comments

Comments
 (0)