11use anyhow:: { Context , Result } ;
2- use redis:: { aio:: MultiplexedConnection , parse_redis_url, AsyncCommands , Client , RedisError } ;
2+ use redis:: { aio:: ConnectionManager , parse_redis_url, AsyncCommands , Client , RedisError } ;
33use spin_core:: async_trait;
44use spin_factor_key_value:: { log_error, Cas , Error , Store , StoreManager , SwapError } ;
5- use std:: ops:: DerefMut ;
65use std:: sync:: Arc ;
7- use tokio:: sync:: { Mutex , OnceCell } ;
6+ use tokio:: sync:: OnceCell ;
87use url:: Url ;
98
109pub struct KeyValueRedis {
1110 database_url : Url ,
12- connection : OnceCell < Arc < Mutex < MultiplexedConnection > > > ,
11+ connection : OnceCell < ConnectionManager > ,
1312}
1413
1514impl KeyValueRedis {
@@ -30,10 +29,8 @@ impl StoreManager for KeyValueRedis {
3029 . connection
3130 . get_or_try_init ( || async {
3231 Client :: open ( self . database_url . clone ( ) ) ?
33- . get_multiplexed_async_connection ( )
32+ . get_connection_manager ( )
3433 . await
35- . map ( Mutex :: new)
36- . map ( Arc :: new)
3734 } )
3835 . await
3936 . map_err ( log_error) ?;
@@ -55,90 +52,69 @@ impl StoreManager for KeyValueRedis {
5552}
5653
5754struct RedisStore {
58- connection : Arc < Mutex < MultiplexedConnection > > ,
55+ connection : ConnectionManager ,
5956 database_url : Url ,
6057}
6158
6259struct CompareAndSwap {
6360 key : String ,
64- connection : Arc < Mutex < MultiplexedConnection > > ,
61+ connection : ConnectionManager ,
6562 bucket_rep : u32 ,
6663}
6764
6865#[ async_trait]
6966impl Store for RedisStore {
67+ async fn after_open ( & self ) -> Result < ( ) , Error > {
68+ if let Err ( _error) = self . connection . clone ( ) . ping :: < ( ) > ( ) . await {
69+ // If an IO error happens, ConnectionManager will start reconnection in the background
70+ // so we do not take any action and just pray re-connection will be successful.
71+ }
72+ Ok ( ( ) )
73+ }
74+
7075 async fn get ( & self , key : & str ) -> Result < Option < Vec < u8 > > , Error > {
71- let mut conn = self . connection . lock ( ) . await ;
72- conn. get ( key) . await . map_err ( log_error)
76+ self . connection . clone ( ) . get ( key) . await . map_err ( log_error)
7377 }
7478
7579 async fn set ( & self , key : & str , value : & [ u8 ] ) -> Result < ( ) , Error > {
7680 self . connection
77- . lock ( )
78- . await
81+ . clone ( )
7982 . set ( key, value)
8083 . await
8184 . map_err ( log_error)
8285 }
8386
8487 async fn delete ( & self , key : & str ) -> Result < ( ) , Error > {
85- self . connection
86- . lock ( )
87- . await
88- . del ( key)
89- . await
90- . map_err ( log_error)
88+ self . connection . clone ( ) . del ( key) . await . map_err ( log_error)
9189 }
9290
9391 async fn exists ( & self , key : & str ) -> Result < bool , Error > {
94- self . connection
95- . lock ( )
96- . await
97- . exists ( key)
98- . await
99- . map_err ( log_error)
92+ self . connection . clone ( ) . exists ( key) . await . map_err ( log_error)
10093 }
10194
10295 async fn get_keys ( & self ) -> Result < Vec < String > , Error > {
103- self . connection
104- . lock ( )
105- . await
106- . keys ( "*" )
107- . await
108- . map_err ( log_error)
96+ self . connection . clone ( ) . keys ( "*" ) . await . map_err ( log_error)
10997 }
11098
11199 async fn get_many ( & self , keys : Vec < String > ) -> Result < Vec < ( String , Option < Vec < u8 > > ) > , Error > {
112- self . connection
113- . lock ( )
114- . await
115- . keys ( keys)
116- . await
117- . map_err ( log_error)
100+ self . connection . clone ( ) . keys ( keys) . await . map_err ( log_error)
118101 }
119102
120103 async fn set_many ( & self , key_values : Vec < ( String , Vec < u8 > ) > ) -> Result < ( ) , Error > {
121104 self . connection
122- . lock ( )
123- . await
105+ . clone ( )
124106 . mset ( & key_values)
125107 . await
126108 . map_err ( log_error)
127109 }
128110
129111 async fn delete_many ( & self , keys : Vec < String > ) -> Result < ( ) , Error > {
130- self . connection
131- . lock ( )
132- . await
133- . del ( keys)
134- . await
135- . map_err ( log_error)
112+ self . connection . clone ( ) . del ( keys) . await . map_err ( log_error)
136113 }
137114
138115 async fn increment ( & self , key : String , delta : i64 ) -> Result < i64 , Error > {
139116 self . connection
140- . lock ( )
141- . await
117+ . clone ( )
142118 . incr ( key, delta)
143119 . await
144120 . map_err ( log_error)
@@ -154,10 +130,8 @@ impl Store for RedisStore {
154130 ) -> Result < Arc < dyn Cas > , Error > {
155131 let cx = Client :: open ( self . database_url . clone ( ) )
156132 . map_err ( log_error) ?
157- . get_multiplexed_async_connection ( )
133+ . get_connection_manager ( )
158134 . await
159- . map ( Mutex :: new)
160- . map ( Arc :: new)
161135 . map_err ( log_error) ?;
162136
163137 Ok ( Arc :: new ( CompareAndSwap {
@@ -175,12 +149,11 @@ impl Cas for CompareAndSwap {
175149 async fn current ( & self ) -> Result < Option < Vec < u8 > > , Error > {
176150 redis:: cmd ( "WATCH" )
177151 . arg ( & self . key )
178- . exec_async ( self . connection . lock ( ) . await . deref_mut ( ) )
152+ . exec_async ( & mut self . connection . clone ( ) )
179153 . await
180154 . map_err ( log_error) ?;
181155 self . connection
182- . lock ( )
183- . await
156+ . clone ( )
184157 . get ( & self . key )
185158 . await
186159 . map_err ( log_error)
@@ -194,12 +167,12 @@ impl Cas for CompareAndSwap {
194167 let res: Result < ( ) , RedisError > = transaction
195168 . atomic ( )
196169 . set ( & self . key , value)
197- . query_async ( self . connection . lock ( ) . await . deref_mut ( ) )
170+ . query_async ( & mut self . connection . clone ( ) )
198171 . await ;
199172
200173 redis:: cmd ( "UNWATCH" )
201174 . arg ( & self . key )
202- . exec_async ( self . connection . lock ( ) . await . deref_mut ( ) )
175+ . exec_async ( & mut self . connection . clone ( ) )
203176 . await
204177 . map_err ( |err| SwapError :: CasFailed ( format ! ( "{err:?}" ) ) ) ?;
205178
0 commit comments