11use byteorder:: { ByteOrder , LittleEndian } ;
2+ use std:: any:: type_name;
23
34use crate :: decode:: Decode ;
45use crate :: encode:: { Encode , IsNull } ;
@@ -27,7 +28,7 @@ impl Encode<'_, Mssql> for i8 {
2728
2829impl Decode < ' _ , Mssql > for i8 {
2930 fn decode ( value : MssqlValueRef < ' _ > ) -> Result < Self , BoxDynError > {
30- Ok ( i8 :: from_le_bytes ( value. as_bytes ( ) ? [ 0 .. 1 ] . try_into ( ) ? ) )
31+ decode_integer ( value)
3132 }
3233}
3334
@@ -37,7 +38,10 @@ impl Type<Mssql> for i16 {
3738 }
3839
3940 fn compatible ( ty : & MssqlTypeInfo ) -> bool {
40- matches ! ( ty. 0 . ty, DataType :: SmallInt | DataType :: IntN ) && ty. 0 . size == 2
41+ matches ! (
42+ ty. 0 . ty,
43+ DataType :: TinyInt | DataType :: SmallInt | DataType :: Int | DataType :: IntN
44+ ) && ty. 0 . size <= 2
4145 }
4246}
4347
@@ -51,7 +55,7 @@ impl Encode<'_, Mssql> for i16 {
5155
5256impl Decode < ' _ , Mssql > for i16 {
5357 fn decode ( value : MssqlValueRef < ' _ > ) -> Result < Self , BoxDynError > {
54- Ok ( LittleEndian :: read_i16 ( value. as_bytes ( ) ? ) )
58+ decode_integer ( value)
5559 }
5660}
5761
@@ -75,7 +79,7 @@ impl Encode<'_, Mssql> for i32 {
7579
7680impl Decode < ' _ , Mssql > for i32 {
7781 fn decode ( value : MssqlValueRef < ' _ > ) -> Result < Self , BoxDynError > {
78- Ok ( LittleEndian :: read_i32 ( value. as_bytes ( ) ? ) )
82+ decode_integer ( value)
7983 }
8084}
8185
@@ -110,30 +114,7 @@ impl Encode<'_, Mssql> for i64 {
110114
111115impl Decode < ' _ , Mssql > for i64 {
112116 fn decode ( value : MssqlValueRef < ' _ > ) -> Result < Self , BoxDynError > {
113- let ty = value. type_info . 0 . ty ;
114- let precision = value. type_info . 0 . precision ;
115- let scale = value. type_info . 0 . scale ;
116- match ty {
117- DataType :: SmallInt
118- | DataType :: Int
119- | DataType :: TinyInt
120- | DataType :: BigInt
121- | DataType :: IntN => {
122- let mut buf = [ 0u8 ; 8 ] ;
123- let bytes_val = value. as_bytes ( ) ?;
124- buf[ ..bytes_val. len ( ) ] . copy_from_slice ( bytes_val) ;
125- Ok ( i64:: from_le_bytes ( buf) )
126- }
127- DataType :: Numeric | DataType :: NumericN | DataType :: Decimal | DataType :: DecimalN => {
128- decode_numeric ( value. as_bytes ( ) ?, precision, scale)
129- }
130- _ => Err ( err_protocol ! (
131- "Decoding {:?} as a float failed because type {:?} is not implemented" ,
132- value,
133- ty
134- )
135- . into ( ) ) ,
136- }
117+ decode_integer ( value)
137118 }
138119}
139120
@@ -150,3 +131,70 @@ fn decode_numeric(bytes: &[u8], _precision: u8, mut scale: u8) -> Result<i64, Bo
150131 let n = i64:: try_from ( numerator) ?;
151132 Ok ( n * if negative { -1 } else { 1 } )
152133}
134+
135+ fn decode_integer < T > ( value : MssqlValueRef < ' _ > ) -> Result < T , BoxDynError >
136+ where
137+ T : TryFrom < i64 > ,
138+ T :: Error : std:: error:: Error + Send + Sync + ' static ,
139+ {
140+ let ty = value. type_info . 0 . ty ;
141+ let precision = value. type_info . 0 . precision ;
142+ let scale = value. type_info . 0 . scale ;
143+
144+ let type_name = type_name :: < T > ( ) ;
145+
146+ match ty {
147+ DataType :: SmallInt
148+ | DataType :: Int
149+ | DataType :: TinyInt
150+ | DataType :: BigInt
151+ | DataType :: IntN => {
152+ let mut buf = [ 0u8 ; 8 ] ;
153+ let bytes_val = value. as_bytes ( ) ?;
154+ let len = bytes_val. len ( ) ;
155+
156+ if len > buf. len ( ) {
157+ return Err ( err_protocol ! (
158+ "Decoding {:?} as a {} failed because type {:?} has more than {} bytes" ,
159+ value,
160+ type_name,
161+ ty,
162+ buf. len( )
163+ )
164+ . into ( ) ) ;
165+ }
166+
167+ buf[ ..len] . copy_from_slice ( & bytes_val) ;
168+
169+ let i64_val = i64:: from_le_bytes ( buf) ;
170+ T :: try_from ( i64_val) . map_err ( |_| {
171+ err_protocol ! (
172+ "Decoding {:?} as a {} failed because value {} is out of range" ,
173+ value,
174+ type_name,
175+ i64_val
176+ )
177+ . into ( )
178+ } )
179+ }
180+ DataType :: Numeric | DataType :: NumericN | DataType :: Decimal | DataType :: DecimalN => {
181+ let n = decode_numeric ( value. as_bytes ( ) ?, precision, scale) ?;
182+ T :: try_from ( n) . map_err ( |_| {
183+ err_protocol ! (
184+ "Decoding {:?} as a {} failed because value {} is out of range" ,
185+ value,
186+ type_name,
187+ n
188+ )
189+ . into ( )
190+ } )
191+ }
192+ _ => Err ( err_protocol ! (
193+ "Decoding {:?} as a {} failed because type {:?} is not implemented" ,
194+ value,
195+ type_name,
196+ ty
197+ )
198+ . into ( ) ) ,
199+ }
200+ }
0 commit comments