1- use anyhow:: { anyhow, Result } ;
1+ use anyhow:: { anyhow, Context , Result } ;
22use native_tls:: TlsConnector ;
33use postgres_native_tls:: MakeTlsConnector ;
44use spin_world:: async_trait;
55use spin_world:: spin:: postgres:: postgres:: {
66 self as v3, Column , DbDataType , DbValue , ParameterValue , RowSet ,
77} ;
88use tokio_postgres:: types:: Type ;
9- use tokio_postgres:: { config:: SslMode , types:: ToSql , Row } ;
10- use tokio_postgres:: { Client as TokioClient , NoTls , Socket } ;
9+ use tokio_postgres:: { config:: SslMode , types:: ToSql , NoTls , Row } ;
10+
11+ const CONNECTION_POOL_SIZE : usize = 64 ;
1112
1213#[ async_trait]
13- pub trait Client {
14- async fn build_client ( address : & str ) -> Result < Self >
15- where
16- Self : Sized ;
14+ pub trait ClientFactory : Send + Sync {
15+ type Client : Client + Send + Sync + ' static ;
16+ fn new ( ) -> Self ;
17+ async fn build_client ( & mut self , address : & str ) -> Result < Self :: Client > ;
18+ }
19+
20+ pub struct PooledTokioClientFactory {
21+ pools : std:: collections:: HashMap < String , deadpool_postgres:: Pool > ,
22+ }
23+
24+ #[ async_trait]
25+ impl ClientFactory for PooledTokioClientFactory {
26+ type Client = deadpool_postgres:: Object ;
27+
28+ fn new ( ) -> Self {
29+ Self {
30+ pools : Default :: default ( ) ,
31+ }
32+ }
33+
34+ async fn build_client ( & mut self , address : & str ) -> Result < Self :: Client > {
35+ let pool_entry = self . pools . entry ( address. to_owned ( ) ) ;
36+ let pool = match pool_entry {
37+ std:: collections:: hash_map:: Entry :: Occupied ( entry) => entry. into_mut ( ) ,
38+ std:: collections:: hash_map:: Entry :: Vacant ( entry) => {
39+ let pool = create_connection_pool ( address)
40+ . context ( "establishing PostgreSQL connection pool" ) ?;
41+ entry. insert ( pool)
42+ }
43+ } ;
44+
45+ Ok ( pool. get ( ) . await ?)
46+ }
47+ }
48+
49+ fn create_connection_pool ( address : & str ) -> Result < deadpool_postgres:: Pool > {
50+ let config = address
51+ . parse :: < tokio_postgres:: Config > ( )
52+ . context ( "parsing Postgres connection string" ) ?;
53+
54+ tracing:: debug!( "Build new connection: {}" , address) ;
1755
56+ // TODO: This is slower but safer. Is it the right tradeoff?
57+ // https://docs.rs/deadpool-postgres/latest/deadpool_postgres/enum.RecyclingMethod.html
58+ let mgr_config = deadpool_postgres:: ManagerConfig {
59+ recycling_method : deadpool_postgres:: RecyclingMethod :: Clean ,
60+ } ;
61+
62+ let mgr = if config. get_ssl_mode ( ) == SslMode :: Disable {
63+ deadpool_postgres:: Manager :: from_config ( config, NoTls , mgr_config)
64+ } else {
65+ let builder = TlsConnector :: builder ( ) ;
66+ let connector = MakeTlsConnector :: new ( builder. build ( ) ?) ;
67+ deadpool_postgres:: Manager :: from_config ( config, connector, mgr_config)
68+ } ;
69+
70+ // TODO: what is our max size heuristic? Should this be passed in soe that different
71+ // hosts can manage it according to their needs? Will a plain number suffice for
72+ // sophisticated hosts anyway?
73+ let pool = deadpool_postgres:: Pool :: builder ( mgr)
74+ . max_size ( CONNECTION_POOL_SIZE )
75+ . build ( )
76+ . context ( "building Postgres connection pool" ) ?;
77+
78+ Ok ( pool)
79+ }
80+
81+ #[ async_trait]
82+ pub trait Client {
1883 async fn execute (
1984 & self ,
2085 statement : String ,
@@ -29,28 +94,7 @@ pub trait Client {
2994}
3095
3196#[ async_trait]
32- impl Client for TokioClient {
33- async fn build_client ( address : & str ) -> Result < Self >
34- where
35- Self : Sized ,
36- {
37- let config = address. parse :: < tokio_postgres:: Config > ( ) ?;
38-
39- tracing:: debug!( "Build new connection: {}" , address) ;
40-
41- if config. get_ssl_mode ( ) == SslMode :: Disable {
42- let ( client, connection) = config. connect ( NoTls ) . await ?;
43- spawn_connection ( connection) ;
44- Ok ( client)
45- } else {
46- let builder = TlsConnector :: builder ( ) ;
47- let connector = MakeTlsConnector :: new ( builder. build ( ) ?) ;
48- let ( client, connection) = config. connect ( connector) . await ?;
49- spawn_connection ( connection) ;
50- Ok ( client)
51- }
52- }
53-
97+ impl Client for deadpool_postgres:: Object {
5498 async fn execute (
5599 & self ,
56100 statement : String ,
@@ -67,7 +111,8 @@ impl Client for TokioClient {
67111 . map ( |b| b. as_ref ( ) as & ( dyn ToSql + Sync ) )
68112 . collect ( ) ;
69113
70- self . execute ( & statement, params_refs. as_slice ( ) )
114+ self . as_ref ( )
115+ . execute ( & statement, params_refs. as_slice ( ) )
71116 . await
72117 . map_err ( |e| v3:: Error :: QueryFailed ( format ! ( "{e:?}" ) ) )
73118 }
@@ -89,6 +134,7 @@ impl Client for TokioClient {
89134 . collect ( ) ;
90135
91136 let results = self
137+ . as_ref ( )
92138 . query ( & statement, params_refs. as_slice ( ) )
93139 . await
94140 . map_err ( |e| v3:: Error :: QueryFailed ( format ! ( "{e:?}" ) ) ) ?;
@@ -111,17 +157,6 @@ impl Client for TokioClient {
111157 }
112158}
113159
114- fn spawn_connection < T > ( connection : tokio_postgres:: Connection < Socket , T > )
115- where
116- T : tokio_postgres:: tls:: TlsStream + std:: marker:: Unpin + std:: marker:: Send + ' static ,
117- {
118- tokio:: spawn ( async move {
119- if let Err ( e) = connection. await {
120- tracing:: error!( "Postgres connection error: {}" , e) ;
121- }
122- } ) ;
123- }
124-
125160fn to_sql_parameter ( value : & ParameterValue ) -> Result < Box < dyn ToSql + Send + Sync > > {
126161 match value {
127162 ParameterValue :: Boolean ( v) => Ok ( Box :: new ( * v) ) ,
0 commit comments