@@ -211,7 +211,7 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
211
211
. with_context ( || format ! ( "invalid decimal {v}" ) ) ?;
212
212
Ok ( Box :: new ( dec) )
213
213
}
214
- ParameterValue :: Range32 ( ( lower, upper) ) => {
214
+ ParameterValue :: RangeInt32 ( ( lower, upper) ) => {
215
215
let lbound = lower. map ( |( value, kind) | {
216
216
postgres_range:: RangeBound :: new ( value, range_bound_kind ( kind) )
217
217
} ) ;
@@ -221,7 +221,7 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
221
221
let r = postgres_range:: Range :: new ( lbound, ubound) ;
222
222
Ok ( Box :: new ( r) )
223
223
}
224
- ParameterValue :: Range64 ( ( lower, upper) ) => {
224
+ ParameterValue :: RangeInt64 ( ( lower, upper) ) => {
225
225
let lbound = lower. map ( |( value, kind) | {
226
226
postgres_range:: RangeBound :: new ( value, range_bound_kind ( kind) )
227
227
} ) ;
@@ -231,8 +231,48 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
231
231
let r = postgres_range:: Range :: new ( lbound, ubound) ;
232
232
Ok ( Box :: new ( r) )
233
233
}
234
+ ParameterValue :: RangeDecimal ( ( lower, upper) ) => {
235
+ let lbound = match lower {
236
+ None => None ,
237
+ Some ( ( value, kind) ) => {
238
+ let dec = rust_decimal:: Decimal :: from_str_exact ( value)
239
+ . with_context ( || format ! ( "invalid decimal {value}" ) ) ?;
240
+ let dec = RangeableDecimal ( dec) ;
241
+ Some ( postgres_range:: RangeBound :: new (
242
+ dec,
243
+ range_bound_kind ( * kind) ,
244
+ ) )
245
+ }
246
+ } ;
247
+ let ubound = match upper {
248
+ None => None ,
249
+ Some ( ( value, kind) ) => {
250
+ let dec = rust_decimal:: Decimal :: from_str_exact ( value)
251
+ . with_context ( || format ! ( "invalid decimal {value}" ) ) ?;
252
+ let dec = RangeableDecimal ( dec) ;
253
+ Some ( postgres_range:: RangeBound :: new (
254
+ dec,
255
+ range_bound_kind ( * kind) ,
256
+ ) )
257
+ }
258
+ } ;
259
+ let r = postgres_range:: Range :: new ( lbound, ubound) ;
260
+ Ok ( Box :: new ( r) )
261
+ }
234
262
ParameterValue :: ArrayInt32 ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
235
263
ParameterValue :: ArrayInt64 ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
264
+ ParameterValue :: ArrayDecimal ( vs) => {
265
+ let decs = vs
266
+ . iter ( )
267
+ . map ( |v| match v {
268
+ None => Ok ( None ) ,
269
+ Some ( v) => rust_decimal:: Decimal :: from_str_exact ( v)
270
+ . with_context ( || format ! ( "invalid decimal {v}" ) )
271
+ . map ( Some ) ,
272
+ } )
273
+ . collect :: < anyhow:: Result < Vec < _ > > > ( ) ?;
274
+ Ok ( Box :: new ( decs) )
275
+ }
236
276
ParameterValue :: ArrayStr ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
237
277
ParameterValue :: Interval ( v) => Ok ( Box :: new ( Interval ( * v) ) ) ,
238
278
ParameterValue :: DbNull => Ok ( Box :: new ( PgNull ) ) ,
@@ -277,11 +317,14 @@ fn convert_data_type(pg_type: &Type) -> DbDataType {
277
317
Type :: UUID => DbDataType :: Uuid ,
278
318
Type :: JSONB => DbDataType :: Jsonb ,
279
319
Type :: NUMERIC => DbDataType :: Decimal ,
280
- Type :: INT4_RANGE => DbDataType :: Range32 ,
281
- Type :: INT8_RANGE => DbDataType :: Range64 ,
320
+ Type :: INT4_RANGE => DbDataType :: RangeInt32 ,
321
+ Type :: INT8_RANGE => DbDataType :: RangeInt64 ,
322
+ Type :: NUM_RANGE => DbDataType :: RangeDecimal ,
282
323
Type :: INT4_ARRAY => DbDataType :: ArrayInt32 ,
283
324
Type :: INT8_ARRAY => DbDataType :: ArrayInt64 ,
325
+ Type :: NUMERIC_ARRAY => DbDataType :: ArrayDecimal ,
284
326
Type :: TEXT_ARRAY | Type :: VARCHAR_ARRAY | Type :: BPCHAR_ARRAY => DbDataType :: ArrayStr ,
327
+ Type :: INTERVAL => DbDataType :: Interval ,
285
328
_ => {
286
329
tracing:: debug!( "Couldn't convert Postgres type {} to WIT" , pg_type. name( ) , ) ;
287
330
DbDataType :: Other
@@ -406,7 +449,7 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
406
449
Some ( v) => {
407
450
let lower = v. lower ( ) . map ( tuplify_range_bound) ;
408
451
let upper = v. upper ( ) . map ( tuplify_range_bound) ;
409
- DbValue :: Range32 ( ( lower, upper) )
452
+ DbValue :: RangeInt32 ( ( lower, upper) )
410
453
}
411
454
None => DbValue :: DbNull ,
412
455
}
@@ -417,7 +460,22 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
417
460
Some ( v) => {
418
461
let lower = v. lower ( ) . map ( tuplify_range_bound) ;
419
462
let upper = v. upper ( ) . map ( tuplify_range_bound) ;
420
- DbValue :: Range64 ( ( lower, upper) )
463
+ DbValue :: RangeInt64 ( ( lower, upper) )
464
+ }
465
+ None => DbValue :: DbNull ,
466
+ }
467
+ }
468
+ & Type :: NUM_RANGE => {
469
+ let value: Option < postgres_range:: Range < RangeableDecimal > > = row. try_get ( index) ?;
470
+ match value {
471
+ Some ( v) => {
472
+ let lower = v
473
+ . lower ( )
474
+ . map ( |b| tuplify_range_bound_map ( b, |d| d. 0 . to_string ( ) ) ) ;
475
+ let upper = v
476
+ . upper ( )
477
+ . map ( |b| tuplify_range_bound_map ( b, |d| d. 0 . to_string ( ) ) ) ;
478
+ DbValue :: RangeDecimal ( ( lower, upper) )
421
479
}
422
480
None => DbValue :: DbNull ,
423
481
}
@@ -436,6 +494,16 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
436
494
None => DbValue :: DbNull ,
437
495
}
438
496
}
497
+ & Type :: NUMERIC_ARRAY => {
498
+ let value: Option < Vec < Option < rust_decimal:: Decimal > > > = row. try_get ( index) ?;
499
+ match value {
500
+ Some ( v) => {
501
+ let dstrs = v. iter ( ) . map ( |opt| opt. map ( |d| d. to_string ( ) ) ) . collect ( ) ;
502
+ DbValue :: ArrayDecimal ( dstrs)
503
+ }
504
+ None => DbValue :: DbNull ,
505
+ }
506
+ }
439
507
& Type :: TEXT_ARRAY | & Type :: VARCHAR_ARRAY | & Type :: BPCHAR_ARRAY => {
440
508
let value: Option < Vec < Option < String > > > = row. try_get ( index) ?;
441
509
match value {
@@ -468,6 +536,13 @@ fn tuplify_range_bound<S: postgres_range::BoundSided, T: Copy>(
468
536
( value. value , wit_bound_kind ( value. type_ ) )
469
537
}
470
538
539
+ fn tuplify_range_bound_map < S : postgres_range:: BoundSided , T , U > (
540
+ value : & postgres_range:: RangeBound < S , T > ,
541
+ map_fn : impl Fn ( & T ) -> U ,
542
+ ) -> ( U , v4:: RangeBoundKind ) {
543
+ ( map_fn ( & value. value ) , wit_bound_kind ( value. type_ ) )
544
+ }
545
+
471
546
fn wit_bound_kind ( bound_type : postgres_range:: BoundType ) -> v4:: RangeBoundKind {
472
547
match bound_type {
473
548
postgres_range:: BoundType :: Inclusive => v4:: RangeBoundKind :: Inclusive ,
@@ -629,6 +704,56 @@ impl std::fmt::Debug for IntervalLengthError {
629
704
}
630
705
}
631
706
707
+ #[ derive( Clone , Copy , Debug , PartialEq , PartialOrd ) ]
708
+ struct RangeableDecimal ( rust_decimal:: Decimal ) ;
709
+
710
+ impl ToSql for RangeableDecimal {
711
+ tokio_postgres:: types:: to_sql_checked!( ) ;
712
+
713
+ fn to_sql (
714
+ & self ,
715
+ ty : & Type ,
716
+ out : & mut tokio_postgres:: types:: private:: BytesMut ,
717
+ ) -> Result < tokio_postgres:: types:: IsNull , Box < dyn std:: error:: Error + Sync + Send > >
718
+ where
719
+ Self : Sized ,
720
+ {
721
+ self . 0 . to_sql ( ty, out)
722
+ }
723
+
724
+ fn accepts ( ty : & Type ) -> bool
725
+ where
726
+ Self : Sized ,
727
+ {
728
+ <rust_decimal:: Decimal as ToSql >:: accepts ( ty)
729
+ }
730
+ }
731
+
732
+ impl FromSql < ' _ > for RangeableDecimal {
733
+ fn from_sql (
734
+ ty : & Type ,
735
+ raw : & ' _ [ u8 ] ,
736
+ ) -> std:: result:: Result < Self , Box < dyn std:: error:: Error + Sync + Send > > {
737
+ let d = <rust_decimal:: Decimal as FromSql >:: from_sql ( ty, raw) ?;
738
+ Ok ( Self ( d) )
739
+ }
740
+
741
+ fn accepts ( ty : & Type ) -> bool {
742
+ <rust_decimal:: Decimal as FromSql >:: accepts ( ty)
743
+ }
744
+ }
745
+
746
+ impl postgres_range:: Normalizable for RangeableDecimal {
747
+ fn normalize < S > (
748
+ bound : postgres_range:: RangeBound < S , Self > ,
749
+ ) -> postgres_range:: RangeBound < S , Self >
750
+ where
751
+ S : postgres_range:: BoundSided ,
752
+ {
753
+ bound
754
+ }
755
+ }
756
+
632
757
/// Workaround for moka returning Arc<Error> which, although
633
758
/// necessary for concurrency, does not play well with others.
634
759
struct ArcError ( std:: sync:: Arc < anyhow:: Error > ) ;
@@ -649,4 +774,4 @@ impl std::fmt::Display for ArcError {
649
774
fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
650
775
std:: fmt:: Display :: fmt ( & self . 0 , f)
651
776
}
652
- }
777
+ }
0 commit comments