@@ -5,8 +5,8 @@ use spin_world::async_trait;
5
5
use spin_world:: spin:: postgres4_0_0:: postgres:: {
6
6
self as v4, Column , DbDataType , DbValue , ParameterValue , RowSet ,
7
7
} ;
8
- use tokio_postgres:: types:: Type ;
9
- use tokio_postgres:: { config:: SslMode , types :: ToSql , NoTls , Row } ;
8
+ use tokio_postgres:: types:: { FromSql , ToSql , Type } ;
9
+ use tokio_postgres:: { config:: SslMode , NoTls , Row } ;
10
10
11
11
/// Max connections in a given address' connection pool
12
12
const CONNECTION_POOL_SIZE : usize = 64 ;
@@ -234,6 +234,7 @@ fn to_sql_parameter(value: &ParameterValue) -> Result<Box<dyn ToSql + Send + Syn
234
234
ParameterValue :: ArrayInt32 ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
235
235
ParameterValue :: ArrayInt64 ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
236
236
ParameterValue :: ArrayStr ( vs) => Ok ( Box :: new ( vs. to_owned ( ) ) ) ,
237
+ ParameterValue :: Interval ( v) => Ok ( Box :: new ( Interval ( * v) ) ) ,
237
238
ParameterValue :: DbNull => Ok ( Box :: new ( PgNull ) ) ,
238
239
}
239
240
}
@@ -442,6 +443,13 @@ fn convert_entry(row: &Row, index: usize) -> anyhow::Result<DbValue> {
442
443
None => DbValue :: DbNull ,
443
444
}
444
445
}
446
+ & Type :: INTERVAL => {
447
+ let value: Option < Interval > = row. try_get ( index) ?;
448
+ match value {
449
+ Some ( v) => DbValue :: Interval ( v. 0 ) ,
450
+ None => DbValue :: DbNull ,
451
+ }
452
+ }
445
453
t => {
446
454
tracing:: debug!(
447
455
"Couldn't convert Postgres type {} in column {}" ,
@@ -544,6 +552,83 @@ impl std::fmt::Debug for PgNull {
544
552
}
545
553
}
546
554
555
+ #[ derive( Debug ) ]
556
+ struct Interval ( v4:: Interval ) ;
557
+
558
+ impl ToSql for Interval {
559
+ tokio_postgres:: types:: to_sql_checked!( ) ;
560
+
561
+ fn to_sql (
562
+ & self ,
563
+ _ty : & Type ,
564
+ out : & mut tokio_postgres:: types:: private:: BytesMut ,
565
+ ) -> Result < tokio_postgres:: types:: IsNull , Box < dyn std:: error:: Error + Sync + Send > >
566
+ where
567
+ Self : Sized ,
568
+ {
569
+ use bytes:: BufMut ;
570
+
571
+ out. put_i64 ( self . 0 . micros ) ;
572
+ out. put_i32 ( self . 0 . days ) ;
573
+ out. put_i32 ( self . 0 . months ) ;
574
+
575
+ Ok ( tokio_postgres:: types:: IsNull :: No )
576
+ }
577
+
578
+ fn accepts ( ty : & Type ) -> bool
579
+ where
580
+ Self : Sized ,
581
+ {
582
+ matches ! ( ty, & Type :: INTERVAL )
583
+ }
584
+ }
585
+
586
+ impl FromSql < ' _ > for Interval {
587
+ fn from_sql (
588
+ _ty : & Type ,
589
+ raw : & ' _ [ u8 ] ,
590
+ ) -> std:: result:: Result < Self , Box < dyn std:: error:: Error + Sync + Send > > {
591
+ const EXPECTED_LEN : usize = size_of :: < i64 > ( ) + size_of :: < i32 > ( ) + size_of :: < i32 > ( ) ;
592
+
593
+ if raw. len ( ) != EXPECTED_LEN {
594
+ return Err ( Box :: new ( IntervalLengthError ) ) ;
595
+ }
596
+
597
+ let ( micro_bytes, rest) = raw. split_at ( size_of :: < i64 > ( ) ) ;
598
+ let ( day_bytes, rest) = rest. split_at ( size_of :: < i32 > ( ) ) ;
599
+ let month_bytes = rest;
600
+ let months = i32:: from_be_bytes ( month_bytes. try_into ( ) . unwrap ( ) ) ;
601
+ let days = i32:: from_be_bytes ( day_bytes. try_into ( ) . unwrap ( ) ) ;
602
+ let micros = i64:: from_be_bytes ( micro_bytes. try_into ( ) . unwrap ( ) ) ;
603
+
604
+ Ok ( Self ( v4:: Interval {
605
+ micros,
606
+ days,
607
+ months,
608
+ } ) )
609
+ }
610
+
611
+ fn accepts ( ty : & Type ) -> bool {
612
+ matches ! ( ty, & Type :: INTERVAL )
613
+ }
614
+ }
615
+
616
+ struct IntervalLengthError ;
617
+
618
+ impl std:: error:: Error for IntervalLengthError { }
619
+
620
+ impl std:: fmt:: Display for IntervalLengthError {
621
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
622
+ f. write_str ( "unexpected binary format for Postgres INTERVAL" )
623
+ }
624
+ }
625
+
626
+ impl std:: fmt:: Debug for IntervalLengthError {
627
+ fn fmt ( & self , f : & mut std:: fmt:: Formatter < ' _ > ) -> std:: fmt:: Result {
628
+ std:: fmt:: Display :: fmt ( self , f)
629
+ }
630
+ }
631
+
547
632
/// Workaround for moka returning Arc<Error> which, although
548
633
/// necessary for concurrency, does not play well with others.
549
634
struct ArcError ( std:: sync:: Arc < anyhow:: Error > ) ;
0 commit comments