@@ -211,7 +211,7 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
211211 . with_context ( || format ! ( "invalid decimal {v}" ) ) ?;
212212 Ok ( Box :: new ( dec) )
213213 }
214- ParameterValue :: Range32 ( ( lower, upper) ) => {
214+ ParameterValue :: RangeInt32 ( ( lower, upper) ) => {
215215 let lbound = lower. map ( |( value, kind) | {
216216 postgres_range:: RangeBound :: new ( value, range_bound_kind ( kind) )
217217 } ) ;
@@ -221,7 +221,7 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
221221 let r = postgres_range:: Range :: new ( lbound, ubound) ;
222222 Ok ( Box :: new ( r) )
223223 }
224- ParameterValue :: Range64 ( ( lower, upper) ) => {
224+ ParameterValue :: RangeInt64 ( ( lower, upper) ) => {
225225 let lbound = lower. map ( |( value, kind) | {
226226 postgres_range:: RangeBound :: new ( value, range_bound_kind ( kind) )
227227 } ) ;
@@ -231,8 +231,48 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
231231 let r = postgres_range:: Range :: new ( lbound, ubound) ;
232232 Ok ( Box :: new ( r) )
233233 }
234+ ParameterValue :: RangeDecimal ( ( lower, upper) ) => {
235+ let lbound = match lower {
236+ None => None ,
237+ Some ( ( value, kind) ) => {
238+ let dec = rust_decimal:: Decimal :: from_str_exact ( value)
239+ . with_context ( || format ! ( "invalid decimal {value}" ) ) ?;
240+ let dec = RangeableDecimal ( dec) ;
241+ Some ( postgres_range:: RangeBound :: new (
242+ dec,
243+ range_bound_kind ( * kind) ,
244+ ) )
245+ }
246+ } ;
247+ let ubound = match upper {
248+ None => None ,
249+ Some ( ( value, kind) ) => {
250+ let dec = rust_decimal:: Decimal :: from_str_exact ( value)
251+ . with_context ( || format ! ( "invalid decimal {value}" ) ) ?;
252+ let dec = RangeableDecimal ( dec) ;
253+ Some ( postgres_range:: RangeBound :: new (
254+ dec,
255+ range_bound_kind ( * kind) ,
256+ ) )
257+ }
258+ } ;
259+ let r = postgres_range:: Range :: new ( lbound, ubound) ;
260+ Ok ( Box :: new ( r) )
261+ }
234262 ParameterValue :: ArrayInt32 ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
235263 ParameterValue :: ArrayInt64 ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
264+ ParameterValue :: ArrayDecimal ( vs) => {
265+ let decs = vs
266+ . iter ( )
267+ . map ( |v| match v {
268+ None => Ok ( None ) ,
269+ Some ( v) => rust_decimal:: Decimal :: from_str_exact ( v)
270+ . with_context ( || format ! ( "invalid decimal {v}" ) )
271+ . map ( Some ) ,
272+ } )
273+ . collect :: < anyhow:: Result < Vec < _ > > > ( ) ?;
274+ Ok ( Box :: new ( decs) )
275+ }
236276 ParameterValue :: ArrayStr ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
237277 ParameterValue :: Interval ( v) => Ok ( Box :: new ( Interval ( * v) ) ) ,
238278 ParameterValue :: DbNull => Ok ( Box :: new ( PgNull ) ) ,
@@ -277,11 +317,14 @@ fn convert_data_type(pg_type: &Type) -> DbDataType {
277317 Type :: UUID => DbDataType :: Uuid ,
278318 Type :: JSONB => DbDataType :: Jsonb ,
279319 Type :: NUMERIC => DbDataType :: Decimal ,
280- Type :: INT4_RANGE => DbDataType :: Range32 ,
281- Type :: INT8_RANGE => DbDataType :: Range64 ,
320+ Type :: INT4_RANGE => DbDataType :: RangeInt32 ,
321+ Type :: INT8_RANGE => DbDataType :: RangeInt64 ,
322+ Type :: NUM_RANGE => DbDataType :: RangeDecimal ,
282323 Type :: INT4_ARRAY => DbDataType :: ArrayInt32 ,
283324 Type :: INT8_ARRAY => DbDataType :: ArrayInt64 ,
325+ Type :: NUMERIC_ARRAY => DbDataType :: ArrayDecimal ,
284326 Type :: TEXT_ARRAY | Type :: VARCHAR_ARRAY | Type :: BPCHAR_ARRAY => DbDataType :: ArrayStr ,
327+ Type :: INTERVAL => DbDataType :: Interval ,
285328 _ => {
286329 tracing:: debug!( "Couldn't convert Postgres type {} to WIT" , pg_type. name( ) , ) ;
287330 DbDataType :: Other
@@ -406,7 +449,7 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
406449 Some ( v) => {
407450 let lower = v. lower ( ) . map ( tuplify_range_bound) ;
408451 let upper = v. upper ( ) . map ( tuplify_range_bound) ;
409- DbValue :: Range32 ( ( lower, upper) )
452+ DbValue :: RangeInt32 ( ( lower, upper) )
410453 }
411454 None => DbValue :: DbNull ,
412455 }
@@ -417,7 +460,22 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
417460 Some ( v) => {
418461 let lower = v. lower ( ) . map ( tuplify_range_bound) ;
419462 let upper = v. upper ( ) . map ( tuplify_range_bound) ;
420- DbValue :: Range64 ( ( lower, upper) )
463+ DbValue :: RangeInt64 ( ( lower, upper) )
464+ }
465+ None => DbValue :: DbNull ,
466+ }
467+ }
468+ & Type :: NUM_RANGE => {
469+ let value: Option < postgres_range:: Range < RangeableDecimal > > = row. try_get ( index) ?;
470+ match value {
471+ Some ( v) => {
472+ let lower = v
473+ . lower ( )
474+ . map ( |b| tuplify_range_bound_map ( b, |d| d. 0 . to_string ( ) ) ) ;
475+ let upper = v
476+ . upper ( )
477+ . map ( |b| tuplify_range_bound_map ( b, |d| d. 0 . to_string ( ) ) ) ;
478+ DbValue :: RangeDecimal ( ( lower, upper) )
421479 }
422480 None => DbValue :: DbNull ,
423481 }
@@ -436,6 +494,16 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
436494 None => DbValue :: DbNull ,
437495 }
438496 }
497+ & Type :: NUMERIC_ARRAY => {
498+ let value: Option < Vec < Option < rust_decimal:: Decimal > > > = row. try_get ( index) ?;
499+ match value {
500+ Some ( v) => {
501+ let dstrs = v. iter ( ) . map ( |opt| opt. map ( |d| d. to_string ( ) ) ) . collect ( ) ;
502+ DbValue :: ArrayDecimal ( dstrs)
503+ }
504+ None => DbValue :: DbNull ,
505+ }
506+ }
439507 & Type :: TEXT_ARRAY | & Type :: VARCHAR_ARRAY | & Type :: BPCHAR_ARRAY => {
440508 let value: Option < Vec < Option < String > > > = row. try_get ( index) ?;
441509 match value {
@@ -468,6 +536,13 @@ fn tuplify_range_bound<S: postgres_range::BoundSided, T: Copy>(
468536 ( value. value , wit_bound_kind ( value. type_ ) )
469537}
470538
539+ fn tuplify_range_bound_map < S : postgres_range:: BoundSided , T , U > (
540+ value : & postgres_range:: RangeBound < S , T > ,
541+ map_fn : impl Fn ( & T ) -> U ,
542+ ) -> ( U , v4:: RangeBoundKind ) {
543+ ( map_fn ( & value. value ) , wit_bound_kind ( value. type_ ) )
544+ }
545+
471546fn wit_bound_kind ( bound_type : postgres_range:: BoundType ) -> v4:: RangeBoundKind {
472547 match bound_type {
473548 postgres_range:: BoundType :: Inclusive => v4:: RangeBoundKind :: Inclusive ,
@@ -629,6 +704,56 @@ impl std::fmt::Debug for IntervalLengthError {
629704 }
630705}
631706
707+ #[ derive( Clone , Copy , Debug , PartialEq , PartialOrd ) ]
708+ struct RangeableDecimal ( rust_decimal:: Decimal ) ;
709+
710+ impl ToSql for RangeableDecimal {
711+ tokio_postgres:: types:: to_sql_checked!( ) ;
712+
713+ fn to_sql (
714+ & self ,
715+ ty : & Type ,
716+ out : & mut tokio_postgres:: types:: private:: BytesMut ,
717+ ) -> Result < tokio_postgres:: types:: IsNull , Box < dyn std:: error:: Error + Sync + Send > >
718+ where
719+ Self : Sized ,
720+ {
721+ self . 0 . to_sql ( ty, out)
722+ }
723+
724+ fn accepts ( ty : & Type ) -> bool
725+ where
726+ Self : Sized ,
727+ {
728+ <rust_decimal:: Decimal as ToSql >:: accepts ( ty)
729+ }
730+ }
731+
732+ impl FromSql < ' _ > for RangeableDecimal {
733+ fn from_sql (
734+ ty : & Type ,
735+ raw : & ' _ [ u8 ] ,
736+ ) -> std:: result:: Result < Self , Box < dyn std:: error:: Error + Sync + Send > > {
737+ let d = <rust_decimal:: Decimal as FromSql >:: from_sql ( ty, raw) ?;
738+ Ok ( Self ( d) )
739+ }
740+
741+ fn accepts ( ty : & Type ) -> bool {
742+ <rust_decimal:: Decimal as FromSql >:: accepts ( ty)
743+ }
744+ }
745+
746+ impl postgres_range:: Normalizable for RangeableDecimal {
747+ fn normalize < S > (
748+ bound : postgres_range:: RangeBound < S , Self > ,
749+ ) -> postgres_range:: RangeBound < S , Self >
750+ where
751+ S : postgres_range:: BoundSided ,
752+ {
753+ bound
754+ }
755+ }
756+
632757/// Workaround for moka returning Arc<Error> which, although
633758/// necessary for concurrency, does not play well with others.
634759struct ArcError ( std:: sync:: Arc < anyhow:: Error > ) ;
@@ -649,4 +774,4 @@ impl std::fmt::Display for ArcError {
649774 fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
650775 std:: fmt:: Display :: fmt ( & self . 0 , f)
651776 }
652- }
777+ }
0 commit comments