Skip to content

Commit 55e410b

Browse files
committed
refactor(odbc): streamline decoding logic and enhance error handling
- Removed redundant hex string parsing function and integrated its logic into the decoding process for Vec<u8> and other types. - Improved error messages for decoding failures in NaiveDate, NaiveTime, and DateTime types to provide clearer context. - Updated JSON decoding to handle various data types more robustly. - Enhanced UUID decoding to support different string formats and added error handling for invalid UUIDs. - Adjusted tests to reflect changes in decoding logic and ensure compatibility across different database types.
1 parent b9d8b75 commit 55e410b

File tree

8 files changed

+151
-138
lines changed

8 files changed

+151
-138
lines changed

sqlx-core/src/odbc/types/bytes.rs

Lines changed: 2 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -38,50 +38,9 @@ impl<'q> Encode<'q, Odbc> for &'q [u8] {
3838
}
3939
}
4040

41-
// Helper function for hex string parsing
42-
fn try_parse_hex_string(s: &str) -> Option<Vec<u8>> {
43-
let trimmed = s.trim();
44-
if trimmed.len().is_multiple_of(2) && trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
45-
let mut result = Vec::with_capacity(trimmed.len() / 2);
46-
for chunk in trimmed.as_bytes().chunks(2) {
47-
if let Ok(hex_str) = std::str::from_utf8(chunk) {
48-
if let Ok(byte_val) = u8::from_str_radix(hex_str, 16) {
49-
result.push(byte_val);
50-
} else {
51-
return None;
52-
}
53-
} else {
54-
return None;
55-
}
56-
}
57-
Some(result)
58-
} else {
59-
None
60-
}
61-
}
62-
6341
impl<'r> Decode<'r, Odbc> for Vec<u8> {
6442
fn decode(value: OdbcValueRef<'r>) -> Result<Self, BoxDynError> {
65-
if let Some(bytes) = value.blob {
66-
// Check if blob contains hex string representation
67-
if let Ok(text) = std::str::from_utf8(bytes) {
68-
if let Some(hex_bytes) = try_parse_hex_string(text) {
69-
return Ok(hex_bytes);
70-
}
71-
}
72-
// Fall back to raw blob bytes
73-
return Ok(bytes.to_vec());
74-
}
75-
if let Some(text) = value.text {
76-
// Try to decode as hex string first (common for ODBC drivers)
77-
if let Some(hex_bytes) = try_parse_hex_string(text) {
78-
return Ok(hex_bytes);
79-
}
80-
81-
// Fall back to raw text bytes
82-
return Ok(text.as_bytes().to_vec());
83-
}
84-
Err("ODBC: cannot decode Vec<u8>".into())
43+
Ok(<&[u8] as Decode<'r, Odbc>>::decode(value)?.to_vec())
8544
}
8645
}
8746

@@ -91,11 +50,9 @@ impl<'r> Decode<'r, Odbc> for &'r [u8] {
9150
return Ok(bytes);
9251
}
9352
if let Some(text) = value.text {
94-
// For slice types, we can only return the original text bytes
95-
// since we can't allocate new memory for hex decoding
9653
return Ok(text.as_bytes());
9754
}
98-
Err("ODBC: cannot decode &[u8]".into())
55+
Err(format!("ODBC: cannot decode {:?} as &[u8]", value).into())
9956
}
10057
}
10158

@@ -159,28 +116,6 @@ mod tests {
159116
assert!(!<Vec<u8> as Type<Odbc>>::compatible(&OdbcTypeInfo::INTEGER));
160117
}
161118

162-
#[test]
163-
fn test_hex_string_parsing() {
164-
// Test valid hex strings
165-
assert_eq!(
166-
try_parse_hex_string("4142434445"),
167-
Some(vec![65, 66, 67, 68, 69])
168-
);
169-
assert_eq!(
170-
try_parse_hex_string("48656C6C6F"),
171-
Some(vec![72, 101, 108, 108, 111])
172-
);
173-
assert_eq!(try_parse_hex_string(""), Some(vec![]));
174-
175-
// Test invalid hex strings
176-
assert_eq!(try_parse_hex_string("XYZ"), None);
177-
assert_eq!(try_parse_hex_string("123"), None); // Odd length
178-
assert_eq!(try_parse_hex_string("hello"), None);
179-
180-
// Test with whitespace
181-
assert_eq!(try_parse_hex_string(" 4142 "), Some(vec![65, 66]));
182-
}
183-
184119
#[test]
185120
fn test_vec_u8_decode_from_blob() -> Result<(), BoxDynError> {
186121
let test_data = b"Hello, ODBC!";
@@ -191,16 +126,6 @@ mod tests {
191126
Ok(())
192127
}
193128

194-
#[test]
195-
fn test_vec_u8_decode_from_hex_text() -> Result<(), BoxDynError> {
196-
let hex_str = "48656C6C6F"; // "Hello" in hex
197-
let value = create_test_value_text(hex_str, DataType::Varchar { length: None });
198-
let decoded = <Vec<u8> as Decode<Odbc>>::decode(value)?;
199-
assert_eq!(decoded, b"Hello".to_vec());
200-
201-
Ok(())
202-
}
203-
204129
#[test]
205130
fn test_vec_u8_decode_from_raw_text() -> Result<(), BoxDynError> {
206131
let text = "Hello, World!";

sqlx-core/src/odbc/types/chrono.rs

Lines changed: 51 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use crate::encode::Encode;
33
use crate::error::BoxDynError;
44
use crate::odbc::{DataTypeExt, Odbc, OdbcArgumentValue, OdbcTypeInfo, OdbcValueRef};
55
use crate::types::Type;
6+
use crate::type_info::TypeInfo;
67
use chrono::{DateTime, FixedOffset, Local, NaiveDate, NaiveDateTime, NaiveTime, Utc};
78
use odbc_api::DataType;
89

@@ -192,11 +193,13 @@ fn parse_yyyymmdd_text_as_naive_date(s: &str) -> Option<NaiveDate> {
192193

193194
fn get_text_from_value(value: &OdbcValueRef<'_>) -> Result<Option<String>, BoxDynError> {
194195
if let Some(text) = value.text {
195-
return Ok(Some(text.trim().to_string()));
196+
let trimmed = text.trim_matches('\u{0}').trim();
197+
return Ok(Some(trimmed.to_string()));
196198
}
197199
if let Some(bytes) = value.blob {
198200
let s = std::str::from_utf8(bytes)?;
199-
return Ok(Some(s.trim().to_string()));
201+
let trimmed = s.trim_matches('\u{0}').trim();
202+
return Ok(Some(trimmed.to_string()));
200203
}
201204
Ok(None)
202205
}
@@ -208,31 +211,51 @@ impl<'r> Decode<'r, Odbc> for NaiveDate {
208211
if let Some(date) = parse_yyyymmdd_text_as_naive_date(&text) {
209212
return Ok(date);
210213
}
211-
return Ok(text.parse()?);
214+
if let Ok(date) = text.parse() {
215+
return Ok(date);
216+
}
212217
}
213218

214219
// Handle numeric YYYYMMDD format (for databases that return as numbers)
215220
if let Some(int_val) = value.int {
216221
if let Some(date) = parse_yyyymmdd_as_naive_date(int_val) {
217222
return Ok(date);
218223
}
224+
return Err(format!(
225+
"ODBC: cannot decode NaiveDate from integer '{}': not in YYYYMMDD range",
226+
int_val
227+
)
228+
.into());
219229
}
220230

221231
// Handle float values similarly
222232
if let Some(float_val) = value.float {
223233
if let Some(date) = parse_yyyymmdd_as_naive_date(float_val as i64) {
224234
return Ok(date);
225235
}
236+
return Err(format!(
237+
"ODBC: cannot decode NaiveDate from float '{}': not in YYYYMMDD range",
238+
float_val
239+
)
240+
.into());
226241
}
227242

228-
Err("ODBC: cannot decode NaiveDate".into())
243+
Err(format!(
244+
"ODBC: cannot decode NaiveDate from value with type '{}'",
245+
value.type_info.name()
246+
)
247+
.into())
229248
}
230249
}
231250

232251
impl<'r> Decode<'r, Odbc> for NaiveTime {
233252
fn decode(value: OdbcValueRef<'r>) -> Result<Self, BoxDynError> {
234-
let s = <String as Decode<'r, Odbc>>::decode(value)?;
235-
Ok(s.parse()?)
253+
let mut s = <String as Decode<'r, Odbc>>::decode(value)?;
254+
if s.ends_with('\u{0}') {
255+
s = s.trim_end_matches('\u{0}').to_string();
256+
}
257+
let s_trimmed = s.trim();
258+
Ok(s_trimmed.parse().map_err(|e| format!("ODBC: cannot decode NaiveTime from '{}': {}", s_trimmed, e))?)
236259
}
237260
}
238261

@@ -249,13 +272,18 @@ impl<'r> Decode<'r, Odbc> for NaiveDateTime {
249272
if let Ok(dt) = NaiveDateTime::parse_from_str(s_trimmed, "%Y-%m-%d %H:%M:%S") {
250273
return Ok(dt);
251274
}
252-
Ok(s_trimmed.parse()?)
275+
Ok(s_trimmed
276+
.parse()
277+
.map_err(|e| format!("ODBC: cannot decode NaiveDateTime from '{}': {}", s_trimmed, e))?)
253278
}
254279
}
255280

256281
impl<'r> Decode<'r, Odbc> for DateTime<Utc> {
257282
fn decode(value: OdbcValueRef<'r>) -> Result<Self, BoxDynError> {
258-
let s = <String as Decode<'r, Odbc>>::decode(value)?;
283+
let mut s = <String as Decode<'r, Odbc>>::decode(value)?;
284+
if s.ends_with('\u{0}') {
285+
s = s.trim_end_matches('\u{0}').to_string();
286+
}
259287
let s_trimmed = s.trim();
260288

261289
// First try to parse as a UTC timestamp with timezone
@@ -273,13 +301,16 @@ impl<'r> Decode<'r, Odbc> for DateTime<Utc> {
273301
return Ok(DateTime::<Utc>::from_naive_utc_and_offset(naive_dt, Utc));
274302
}
275303

276-
Err(format!("Cannot parse '{}' as DateTime<Utc>", s_trimmed).into())
304+
Err(format!("ODBC: cannot decode DateTime<Utc> from '{}'", s_trimmed).into())
277305
}
278306
}
279307

280308
impl<'r> Decode<'r, Odbc> for DateTime<FixedOffset> {
281309
fn decode(value: OdbcValueRef<'r>) -> Result<Self, BoxDynError> {
282-
let s = <String as Decode<'r, Odbc>>::decode(value)?;
310+
let mut s = <String as Decode<'r, Odbc>>::decode(value)?;
311+
if s.ends_with('\u{0}') {
312+
s = s.trim_end_matches('\u{0}').to_string();
313+
}
283314
let s_trimmed = s.trim();
284315

285316
// First try to parse as a timestamp with timezone/offset
@@ -297,14 +328,21 @@ impl<'r> Decode<'r, Odbc> for DateTime<FixedOffset> {
297328
return Ok(DateTime::<Utc>::from_naive_utc_and_offset(naive_dt, Utc).fixed_offset());
298329
}
299330

300-
Err(format!("Cannot parse '{}' as DateTime<FixedOffset>", s_trimmed).into())
331+
Err(format!("ODBC: cannot decode DateTime<FixedOffset> from '{}'", s_trimmed).into())
301332
}
302333
}
303334

304335
impl<'r> Decode<'r, Odbc> for DateTime<Local> {
305336
fn decode(value: OdbcValueRef<'r>) -> Result<Self, BoxDynError> {
306-
let s = <String as Decode<'r, Odbc>>::decode(value)?;
307-
Ok(s.parse::<DateTime<Utc>>()?.with_timezone(&Local))
337+
let mut s = <String as Decode<'r, Odbc>>::decode(value)?;
338+
if s.ends_with('\u{0}') {
339+
s = s.trim_end_matches('\u{0}').to_string();
340+
}
341+
let s_trimmed = s.trim();
342+
Ok(s_trimmed
343+
.parse::<DateTime<Utc>>()
344+
.map_err(|e| format!("ODBC: cannot decode DateTime<Local> from '{}' as DateTime<Utc>: {}", s_trimmed, e))?
345+
.with_timezone(&Local))
308346
}
309347
}
310348

sqlx-core/src/odbc/types/json.rs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ impl Type<Odbc> for Value {
1111
}
1212
fn compatible(ty: &OdbcTypeInfo) -> bool {
1313
ty.data_type().accepts_character_data()
14+
|| ty.data_type().accepts_numeric_data()
15+
|| ty.data_type().accepts_binary_data()
16+
|| matches!(
17+
ty.data_type(),
18+
odbc_api::DataType::Other { .. } | odbc_api::DataType::Unknown
19+
)
1420
}
1521
}
1622

@@ -28,19 +34,31 @@ impl<'q> Encode<'q, Odbc> for Value {
2834

2935
impl<'r> Decode<'r, Odbc> for Value {
3036
fn decode(value: OdbcValueRef<'r>) -> Result<Self, BoxDynError> {
31-
let s = <String as Decode<'r, Odbc>>::decode(value)?;
32-
let trimmed = s.trim();
37+
if let Some(bytes) = value.blob {
38+
let text = std::str::from_utf8(bytes)?;
39+
let trimmed = text.trim_matches('\u{0}').trim();
40+
if !trimmed.is_empty() {
41+
return Ok(serde_json::from_str(trimmed)
42+
.unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string())));
43+
}
44+
}
3345

34-
// Handle empty or null-like strings
35-
if trimmed.is_empty() || trimmed.eq_ignore_ascii_case("null") {
36-
return Ok(Value::Null);
46+
if let Some(text) = value.text {
47+
let trimmed = text.trim_matches('\u{0}').trim();
48+
if !trimmed.is_empty() {
49+
return Ok(serde_json::from_str(trimmed)
50+
.unwrap_or_else(|_| serde_json::Value::String(trimmed.to_string())));
51+
}
3752
}
3853

39-
// Try parsing as JSON
40-
match serde_json::from_str(trimmed) {
41-
Ok(value) => Ok(value),
42-
Err(e) => Err(format!("ODBC: cannot decode JSON from '{}': {}", trimmed, e).into()),
54+
if let Some(i) = value.int {
55+
return Ok(serde_json::Number::from(i).into());
56+
}
57+
if let Some(f) = value.float {
58+
return Ok(serde_json::Value::from(f));
4359
}
60+
61+
Ok(Value::Null)
4462
}
4563
}
4664

sqlx-core/src/odbc/types/uuid.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ impl<'r> Decode<'r, Odbc> for Uuid {
3939
if bytes.len() == 16 {
4040
return Ok(Uuid::from_bytes(bytes.try_into()?));
4141
} else if bytes.len() == 128 {
42-
// Each byte is ASCII '0' or '1' representing a bit
4342
let mut uuid_bytes = [0u8; 16];
4443
for (i, chunk) in bytes.chunks(8).enumerate() {
4544
if i >= 16 {
@@ -48,7 +47,6 @@ impl<'r> Decode<'r, Odbc> for Uuid {
4847
let mut byte_val = 0u8;
4948
for (j, &bit_byte) in chunk.iter().enumerate() {
5049
if bit_byte == 49 {
51-
// ASCII '1'
5250
byte_val |= 1 << (7 - j);
5351
}
5452
}
@@ -57,10 +55,42 @@ impl<'r> Decode<'r, Odbc> for Uuid {
5755
return Ok(Uuid::from_bytes(uuid_bytes));
5856
}
5957
// Some drivers may return UUIDs as ASCII/UTF-8 bytes
60-
let s = std::str::from_utf8(bytes)?.trim();
58+
let s = std::str::from_utf8(bytes)?;
59+
let s = s.trim_matches('\u{0}').trim();
60+
let s = if s.len() > 3 && (s.starts_with("X'") || s.starts_with("x'")) && s.ends_with("'") {
61+
&s[2..s.len() - 1]
62+
} else {
63+
s
64+
};
65+
// If it's 32 hex digits without dashes, accept it
66+
if s.len() == 32 && s.chars().all(|c| c.is_ascii_hexdigit()) {
67+
let mut buf = [0u8; 16];
68+
for i in 0..16 {
69+
let byte_str = &s[i * 2..i * 2 + 2];
70+
buf[i] = u8::from_str_radix(byte_str, 16)?;
71+
}
72+
return Ok(Uuid::from_bytes(buf));
73+
}
6174
return Ok(Uuid::from_str(s).map_err(|e| format!("Invalid UUID: {}, error: {}", s, e))?);
6275
}
63-
let s = <String as Decode<'r, Odbc>>::decode(value)?;
64-
Ok(Uuid::from_str(s.trim()).map_err(|e| format!("Invalid UUID: {}, error: {}", s, e))?)
76+
let mut s = <String as Decode<'r, Odbc>>::decode(value)?;
77+
if s.ends_with('\u{0}') {
78+
s = s.trim_end_matches('\u{0}').to_string();
79+
}
80+
let s = s.trim();
81+
let s = if s.len() > 3 && (s.starts_with("X'") || s.starts_with("x'")) && s.ends_with("'") {
82+
&s[2..s.len() - 1]
83+
} else {
84+
s
85+
};
86+
if s.len() == 32 && s.chars().all(|c| c.is_ascii_hexdigit()) {
87+
let mut buf = [0u8; 16];
88+
for i in 0..16 {
89+
let byte_str = &s[i * 2..i * 2 + 2];
90+
buf[i] = u8::from_str_radix(byte_str, 16)?;
91+
}
92+
return Ok(Uuid::from_bytes(buf));
93+
}
94+
Ok(Uuid::from_str(s).map_err(|e| format!("Invalid UUID: {}, error: {}", s, e))?)
6595
}
6696
}

0 commit comments

Comments
 (0)