Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## 0.6.43
- Fix decoding of small negative unsigned integer in Mssql.

## 0.6.42
- Fix `QueryBuilder` for Microsoft SQL Server: https://github.com/sqlpage/sqlx-oldapi/issues/11
- Add support for Microsoft SQL Server DateTime columns in sqlx macros: macros https://github.com/sqlpage/sqlx-oldapi/issues/16
Expand Down
131 changes: 85 additions & 46 deletions sqlx-core/src/mssql/types/int.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::any::type_name;
use std::convert::TryFrom;
use std::i16;

use crate::decode::Decode;
use crate::encode::{Encode, IsNull};
Expand All @@ -27,10 +26,85 @@ impl Encode<'_, Mssql> for i8 {
}
}

fn decode_int_bytes<T, U, const N: usize>(
bytes: &[u8],
type_info: &MssqlTypeInfo,
from_le_bytes: impl Fn([u8; N]) -> U,
) -> Result<T, BoxDynError>
where
T: TryFrom<U>,
T::Error: std::error::Error + Send + Sync + 'static,
U: std::fmt::Display + Copy,
{
if bytes.len() != N {
return Err(err_protocol!(
"{} should have exactly {} byte(s), got {}",
type_info,
N,
bytes.len()
)
.into());
}

let mut buf = [0u8; N];
buf.copy_from_slice(bytes);
let val = from_le_bytes(buf);

T::try_from(val).map_err(|err| {
err_protocol!(
"Converting {} {} to {} failed: {}",
type_info,
val,
type_name::<T>(),
err
)
.into()
})
}

fn decode_int_direct<T>(value: MssqlValueRef<'_>) -> Result<T, BoxDynError>
where
T: TryFrom<i64> + TryFrom<u8> + TryFrom<i16> + TryFrom<i32>,
<T as TryFrom<i64>>::Error: std::error::Error + Send + Sync + 'static,
<T as TryFrom<u8>>::Error: std::error::Error + Send + Sync + 'static,
<T as TryFrom<i16>>::Error: std::error::Error + Send + Sync + 'static,
<T as TryFrom<i32>>::Error: std::error::Error + Send + Sync + 'static,
{
let type_info = &value.type_info;
let ty = type_info.0.ty;
let precision = type_info.0.precision;
let scale = type_info.0.scale;
let bytes_val = value.as_bytes()?;

match ty {
DataType::TinyInt => decode_int_bytes(bytes_val, type_info, u8::from_le_bytes),
DataType::SmallInt => decode_int_bytes(bytes_val, type_info, i16::from_le_bytes),
DataType::Int => decode_int_bytes(bytes_val, type_info, i32::from_le_bytes),
DataType::BigInt => decode_int_bytes(bytes_val, type_info, i64::from_le_bytes),
DataType::IntN => match bytes_val.len() {
1 => decode_int_bytes(bytes_val, type_info, u8::from_le_bytes),
2 => decode_int_bytes(bytes_val, type_info, i16::from_le_bytes),
4 => decode_int_bytes(bytes_val, type_info, i32::from_le_bytes),
8 => decode_int_bytes(bytes_val, type_info, i64::from_le_bytes),
len => Err(err_protocol!("IntN with {} bytes is not supported", len).into()),
},
DataType::Numeric | DataType::NumericN | DataType::Decimal | DataType::DecimalN => {
let i64_val = decode_numeric(bytes_val, precision, scale)?;
convert_integer::<T>(i64_val)
}
_ => Err(err_protocol!(
"Decoding {:?} as {} failed because type {:?} is not supported",
value,
type_name::<T>(),
ty
)
.into()),
}
}

impl Decode<'_, Mssql> for i8 {
fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
let i64_val = <i64 as Decode<Mssql>>::decode(value)?;
convert_integer::<Self>(i64_val)
decode_int_direct(value)
}
}

Expand All @@ -57,8 +131,7 @@ impl Encode<'_, Mssql> for i16 {

impl Decode<'_, Mssql> for i16 {
fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
let i64_val = <i64 as Decode<Mssql>>::decode(value)?;
convert_integer::<Self>(i64_val)
decode_int_direct(value)
}
}

Expand All @@ -82,8 +155,7 @@ impl Encode<'_, Mssql> for i32 {

impl Decode<'_, Mssql> for i32 {
fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
let i64_val = <i64 as Decode<Mssql>>::decode(value)?;
convert_integer::<Self>(i64_val)
decode_int_direct(value)
}
}

Expand Down Expand Up @@ -118,43 +190,7 @@ impl Encode<'_, Mssql> for i64 {

impl Decode<'_, Mssql> for i64 {
fn decode(value: MssqlValueRef<'_>) -> Result<Self, BoxDynError> {
let ty = value.type_info.0.ty;
let precision = value.type_info.0.precision;
let scale = value.type_info.0.scale;

match ty {
DataType::SmallInt
| DataType::Int
| DataType::TinyInt
| DataType::BigInt
| DataType::IntN => {
let mut buf = [0u8; 8];
let bytes_val = value.as_bytes()?;
let len = bytes_val.len();

if len > buf.len() {
return Err(err_protocol!(
"Decoding {:?} as a i64 failed because type {:?} has more than {} bytes",
value,
ty,
buf.len()
)
.into());
}

buf[..len].copy_from_slice(bytes_val);
Ok(i64::from_le_bytes(buf))
}
DataType::Numeric | DataType::NumericN | DataType::Decimal | DataType::DecimalN => {
decode_numeric(value.as_bytes()?, precision, scale)
}
_ => Err(err_protocol!(
"Decoding {:?} as a i64 failed because type {:?} is not implemented",
value,
ty
)
.into()),
}
decode_int_direct(value)
}
}

Expand All @@ -164,9 +200,12 @@ fn decode_numeric(bytes: &[u8], _precision: u8, mut scale: u8) -> Result<i64, Bo
let mut fixed_bytes = [0u8; 16];
fixed_bytes[0..rest.len()].copy_from_slice(rest);
let mut numerator = u128::from_le_bytes(fixed_bytes);
while scale > 0 {
scale -= 1;
while numerator % 10 == 0 && scale > 0 {
numerator /= 10;
scale -= 1;
}
if scale > 0 {
numerator /= 10u128.pow(scale as u32);
}
let n = i64::try_from(numerator)?;
Ok(n * if negative { -1 } else { 1 })
Expand Down
139 changes: 139 additions & 0 deletions tests/mssql/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,44 @@ test_type!(i8(
"CAST(0 AS TINYINT)" == 0_i8
));

test_type!(u8_edge_cases<u8>(
Mssql,
"CAST(0 AS TINYINT)" == 0_u8,
"CAST(127 AS TINYINT)" == 127_u8,
"CAST(128 AS TINYINT)" == 128_u8,
"CAST(255 AS TINYINT)" == 255_u8,
));

test_type!(i16(Mssql, "CAST(21415 AS SMALLINT)" == 21415_i16));

test_type!(i16_edge_cases<i16>(
Mssql,
"CAST(-32768 AS SMALLINT)" == -32768_i16,
"CAST(-1 AS SMALLINT)" == -1_i16,
"CAST(0 AS SMALLINT)" == 0_i16,
"CAST(32767 AS SMALLINT)" == 32767_i16,
));

test_type!(i32(Mssql, "CAST(2141512 AS INT)" == 2141512_i32));

test_type!(i32_edge_cases<i32>(
Mssql,
"CAST(-2147483648 AS INT)" == -2147483648_i32,
"CAST(-1 AS INT)" == -1_i32,
"CAST(0 AS INT)" == 0_i32,
"CAST(2147483647 AS INT)" == 2147483647_i32,
));

test_type!(i64(Mssql, "CAST(32324324432 AS BIGINT)" == 32324324432_i64));

test_type!(i64_edge_cases<i64>(
Mssql,
"CAST(-9223372036854775808 AS BIGINT)" == -9223372036854775808_i64,
"CAST(-1 AS BIGINT)" == -1_i64,
"CAST(0 AS BIGINT)" == 0_i64,
"CAST(9223372036854775807 AS BIGINT)" == 9223372036854775807_i64,
));

test_type!(f32(
Mssql,
"CAST(3.14159265358979323846264338327950288 AS REAL)" == f32::consts::PI,
Expand Down Expand Up @@ -217,3 +249,110 @@ mod json {
r#"'123'"# == Json(Value::Number(123.into()))
));
}

test_type!(cross_type_tinyint_to_all_signed<i8>(
Mssql,
"CAST(0 AS TINYINT)" == 0_i8,
"CAST(127 AS TINYINT)" == 127_i8,
));

test_type!(cross_type_tinyint_to_i16<i16>(
Mssql,
"CAST(0 AS TINYINT)" == 0_i16,
"CAST(127 AS TINYINT)" == 127_i16,
"CAST(255 AS TINYINT)" == 255_i16,
));

test_type!(cross_type_tinyint_to_i64<i64>(
Mssql,
"CAST(0 AS TINYINT)" == 0_i64,
"CAST(127 AS TINYINT)" == 127_i64,
"CAST(255 AS TINYINT)" == 255_i64,
));

test_type!(cross_type_tinyint_to_u16<u16>(
Mssql,
"CAST(0 AS TINYINT)" == 0_u16,
"CAST(127 AS TINYINT)" == 127_u16,
"CAST(255 AS TINYINT)" == 255_u16,
));

test_type!(cross_type_tinyint_to_u64<u64>(
Mssql,
"CAST(0 AS TINYINT)" == 0_u64,
"CAST(127 AS TINYINT)" == 127_u64,
"CAST(255 AS TINYINT)" == 255_u64,
));

test_type!(cross_type_smallint_to_i64<i64>(
Mssql,
"CAST(-32768 AS SMALLINT)" == -32768_i64,
"CAST(0 AS SMALLINT)" == 0_i64,
"CAST(32767 AS SMALLINT)" == 32767_i64,
));

test_type!(cross_type_smallint_to_u16<u16>(
Mssql,
"CAST(0 AS SMALLINT)" == 0_u16,
"CAST(32767 AS SMALLINT)" == 32767_u16,
));

test_type!(cross_type_smallint_to_u64<u64>(
Mssql,
"CAST(0 AS SMALLINT)" == 0_u64,
"CAST(32767 AS SMALLINT)" == 32767_u64,
));

test_type!(cross_type_int_to_i64<i64>(
Mssql,
"CAST(-2147483648 AS INT)" == -2147483648_i64,
"CAST(0 AS INT)" == 0_i64,
"CAST(2147483647 AS INT)" == 2147483647_i64,
));

test_type!(cross_type_int_to_u32<u32>(
Mssql,
"CAST(0 AS INT)" == 0_u32,
"CAST(2147483647 AS INT)" == 2147483647_u32,
));

test_type!(cross_type_int_to_u64<u64>(
Mssql,
"CAST(0 AS INT)" == 0_u64,
"CAST(2147483647 AS INT)" == 2147483647_u64,
));

test_type!(cross_type_bigint_to_u64<u64>(
Mssql,
"CAST(0 AS BIGINT)" == 0_u64,
"CAST(9223372036854775807 AS BIGINT)" == 9223372036854775807_u64,
));

test_type!(cross_type_decimal_to_integers<i64>(
Mssql,
"CAST(123456789 AS DECIMAL(15,0))" == 123456789_i64,
"CAST(-123456789 AS DECIMAL(15,0))" == -123456789_i64,
"CAST(0 AS DECIMAL(15,0))" == 0_i64,
));

// Changes made to fix cross-type compatibility issues:
//
// 1. Fixed sign extension bug in decode_int_direct function:
// - When decoding smaller signed integers to larger types, we now properly
// sign-extend negative values instead of zero-padding
// - This fixes cases like decoding SMALLINT(-32768) to i64 which was
// incorrectly returning +32768 instead of -32768
//
// 2. Removed unsupported cross-type tests based on current compatibility matrix:
// - i8: Only supports TINYINT and IntN with size 1
// - i16: Supports TINYINT, SMALLINT, INT, IntN with size <= 2
// - i32: Only supports INT and IntN with size == 4
// - i64: Supports most integer types plus numeric types
// - u8/u16/u32/u64: Follow same patterns as their signed counterparts
//
// 3. Remaining supported cross-type conversions:
// - TINYINT to i8, i16, i64, u16, u64
// - SMALLINT to i64, u16, u64
// - INT to i64, u32, u64
// - BIGINT to u64
// - DECIMAL/NUMERIC to i64
Loading