@@ -4,12 +4,15 @@ use futures_util::{future::BoxFuture, stream::BoxStream};
44use postgres:: { Client , NoTls } ;
55use r2d2_postgres:: PostgresConnectionManager ;
66use sqlx:: { postgres:: PgPoolOptions , Executor } ;
7- use std:: { sync:: Arc , time:: Duration } ;
7+ use std:: {
8+ ops:: { Deref , DerefMut } ,
9+ sync:: Arc ,
10+ time:: Duration ,
11+ } ;
812use tokio:: runtime:: Runtime ;
913use tracing:: debug;
1014
1115pub type PoolClient = r2d2:: PooledConnection < PostgresConnectionManager < NoTls > > ;
12- pub type AsyncPoolClient = sqlx:: pool:: PoolConnection < sqlx:: postgres:: Postgres > ;
1316
1417const DEFAULT_SCHEMA : & str = "public" ;
1518
@@ -20,14 +23,15 @@ pub struct Pool {
2023 #[ cfg( not( test) ) ]
2124 pool : r2d2:: Pool < PostgresConnectionManager < NoTls > > ,
2225 async_pool : sqlx:: PgPool ,
26+ runtime : Arc < Runtime > ,
2327 metrics : Arc < InstanceMetrics > ,
2428 max_size : u32 ,
2529}
2630
2731impl Pool {
2832 pub fn new (
2933 config : & Config ,
30- runtime : & Runtime ,
34+ runtime : Arc < Runtime > ,
3135 metrics : Arc < InstanceMetrics > ,
3236 ) -> Result < Pool , PoolError > {
3337 debug ! (
@@ -39,7 +43,7 @@ impl Pool {
3943 #[ cfg( test) ]
4044 pub ( crate ) fn new_with_schema (
4145 config : & Config ,
42- runtime : & Runtime ,
46+ runtime : Arc < Runtime > ,
4347 metrics : Arc < InstanceMetrics > ,
4448 schema : & str ,
4549 ) -> Result < Pool , PoolError > {
@@ -48,7 +52,7 @@ impl Pool {
4852
4953 fn new_inner (
5054 config : & Config ,
51- runtime : & Runtime ,
55+ runtime : Arc < Runtime > ,
5256 metrics : Arc < InstanceMetrics > ,
5357 schema : & str ,
5458 ) -> Result < Pool , PoolError > {
@@ -109,6 +113,7 @@ impl Pool {
109113 pool,
110114 async_pool,
111115 metrics,
116+ runtime,
112117 max_size : config. max_legacy_pool_size + config. max_pool_size ,
113118 } )
114119 }
@@ -139,7 +144,10 @@ impl Pool {
139144
140145 pub async fn get_async ( & self ) -> Result < AsyncPoolClient , PoolError > {
141146 match self . async_pool . acquire ( ) . await {
142- Ok ( conn) => Ok ( conn) ,
147+ Ok ( conn) => Ok ( AsyncPoolClient {
148+ inner : Some ( conn) ,
149+ runtime : self . runtime . clone ( ) ,
150+ } ) ,
143151 Err ( err) => {
144152 self . metrics . failed_db_connections . inc ( ) ;
145153 Err ( PoolError :: AsyncClientError ( err) )
@@ -222,6 +230,36 @@ where
222230 }
223231}
224232
233+ /// we wrap `sqlx::PoolConnection` so we can drop it in a sync context
234+ /// and enter the runtime.
235+ /// Otherwise dropping the PoolConnection will panic because it can't spawn a task.
236+ #[ derive( Debug ) ]
237+ pub struct AsyncPoolClient {
238+ inner : Option < sqlx:: pool:: PoolConnection < sqlx:: postgres:: Postgres > > ,
239+ runtime : Arc < Runtime > ,
240+ }
241+
242+ impl Deref for AsyncPoolClient {
243+ type Target = sqlx:: PgConnection ;
244+
245+ fn deref ( & self ) -> & Self :: Target {
246+ self . inner . as_ref ( ) . unwrap ( )
247+ }
248+ }
249+
250+ impl DerefMut for AsyncPoolClient {
251+ fn deref_mut ( & mut self ) -> & mut Self :: Target {
252+ self . inner . as_mut ( ) . unwrap ( )
253+ }
254+ }
255+
256+ impl Drop for AsyncPoolClient {
257+ fn drop ( & mut self ) {
258+ let _guard = self . runtime . enter ( ) ;
259+ drop ( self . inner . take ( ) )
260+ }
261+ }
262+
225263#[ derive( Debug ) ]
226264struct SetSchema {
227265 schema : String ,
0 commit comments