@@ -4,6 +4,7 @@ use crate::{Command, Connection, Db, Frame, Shutdown};
4
4
use bytes:: Bytes ;
5
5
use tokio:: select;
6
6
use tokio:: stream:: { StreamExt , StreamMap } ;
7
+ use tokio:: sync:: broadcast;
7
8
8
9
/// Subscribes the client to one or more channels.
9
10
///
@@ -112,21 +113,8 @@ impl Subscribe {
112
113
// `self.channels` is used to track additional channels to subscribe
113
114
// to. When new `SUBSCRIBE` commands are received during the
114
115
// execution of `apply`, the new channels are pushed onto this vec.
115
- for channel in self . channels . drain ( ..) {
116
- // Build response frame to respond to the client with.
117
- let mut response = Frame :: array ( ) ;
118
- response. push_bulk ( Bytes :: from_static ( b"subscribe" ) ) ;
119
- response. push_bulk ( Bytes :: copy_from_slice ( channel. as_bytes ( ) ) ) ;
120
- response. push_int ( subscriptions. len ( ) . saturating_add ( 1 ) as u64 ) ;
121
-
122
- // Subscribe to channel
123
- let rx = db. subscribe ( channel. clone ( ) ) ;
124
-
125
- // Track subscription in this client's subscription set.
126
- subscriptions. insert ( channel, rx) ;
127
-
128
- // Respond with the successful subscription
129
- dst. write_frame ( & response) . await ?;
116
+ for channel_name in self . channels . drain ( ..) {
117
+ subscribe_to_channel ( channel_name, & mut subscriptions, db, dst) . await ?;
130
118
}
131
119
132
120
// Wait for one of the following to happen:
@@ -136,7 +124,7 @@ impl Subscribe {
136
124
// - A server shutdown signal.
137
125
select ! {
138
126
// Receive messages from subscribed channels
139
- Some ( ( channel , msg) ) = subscriptions. next( ) => {
127
+ Some ( ( channel_name , msg) ) = subscriptions. next( ) => {
140
128
use tokio:: sync:: broadcast:: RecvError ;
141
129
142
130
let msg = match msg {
@@ -145,60 +133,22 @@ impl Subscribe {
145
133
Err ( RecvError :: Closed ) => unreachable!( ) ,
146
134
} ;
147
135
148
- let mut response = Frame :: array( ) ;
149
- response. push_bulk( Bytes :: from_static( b"message" ) ) ;
150
- response. push_bulk( Bytes :: copy_from_slice( channel. as_bytes( ) ) ) ;
151
- response. push_bulk( msg) ;
152
-
153
- dst. write_frame( & response) . await ?;
136
+ dst. write_frame( & make_message_frame( channel_name, msg) ) . await ?;
154
137
}
155
138
res = dst. read_frame( ) => {
156
139
let frame = match res? {
157
140
Some ( frame) => frame,
158
- // How to handle remote client closing write half?
141
+ // This happens if the remote client has disconnected.
159
142
None => return Ok ( ( ) )
160
143
} ;
161
144
162
- // A command has been received from the client.
163
- //
164
- // Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted
165
- // in this context.
166
- match Command :: from_frame( frame) ? {
167
- Command :: Subscribe ( subscribe) => {
168
- // Subscribe to the channels on next iteration
169
- self . channels. extend( subscribe. channels. into_iter( ) ) ;
170
- }
171
- Command :: Unsubscribe ( mut unsubscribe) => {
172
- // If no channels are specified, this requests
173
- // unsubscribing from **all** channels. To implement
174
- // this, the `unsubscribe.channels` vec is populated
175
- // with the list of channels currently subscribed
176
- // to.
177
- if unsubscribe. channels. is_empty( ) {
178
- unsubscribe. channels = subscriptions
179
- . keys( )
180
- . map( |channel| channel. to_string( ) )
181
- . collect( ) ;
182
- }
183
-
184
- for channel in unsubscribe. channels. drain( ..) {
185
- subscriptions. remove( & channel) ;
186
-
187
- let mut response = Frame :: array( ) ;
188
- response. push_bulk( Bytes :: from_static( b"unsubscribe" ) ) ;
189
- response. push_bulk( Bytes :: copy_from_slice( channel. as_bytes( ) ) ) ;
190
- response. push_int( subscriptions. len( ) as u64 ) ;
191
-
192
- dst. write_frame( & response) . await ?;
193
- }
194
- }
195
- command => {
196
- let cmd = Unknown :: new( command. get_name( ) ) ;
197
- cmd. apply( dst) . await ?;
198
- }
199
- }
145
+ handle_command(
146
+ frame,
147
+ & mut self . channels,
148
+ & mut subscriptions,
149
+ dst,
150
+ ) . await ?;
200
151
}
201
- // Receive additional commands from the client
202
152
_ = shutdown. recv( ) => {
203
153
return Ok ( ( ) ) ;
204
154
}
@@ -220,6 +170,106 @@ impl Subscribe {
220
170
}
221
171
}
222
172
173
+ async fn subscribe_to_channel (
174
+ channel_name : String ,
175
+ subscriptions : & mut StreamMap < String , broadcast:: Receiver < Bytes > > ,
176
+ db : & Db ,
177
+ dst : & mut Connection ,
178
+ ) -> crate :: Result < ( ) > {
179
+ // Subscribe to the channel.
180
+ let rx = db. subscribe ( channel_name. clone ( ) ) ;
181
+
182
+ // Track subscription in this client's subscription set.
183
+ subscriptions. insert ( channel_name. clone ( ) , rx) ;
184
+
185
+ // Respond with the successful subscription
186
+ let response = make_subscribe_frame ( channel_name, subscriptions. len ( ) ) ;
187
+ dst. write_frame ( & response) . await ?;
188
+
189
+ Ok ( ( ) )
190
+ }
191
+
192
+ /// Handle a command received while inside `Subscribe::apply`. Only subscribe
193
+ /// and unsubscribe commands are permitted in this context.
194
+ ///
195
+ /// Any new subscriptions are appended to `subscribe_to` instead of modifying
196
+ /// `subscriptions`.
197
+ async fn handle_command (
198
+ frame : Frame ,
199
+ subscribe_to : & mut Vec < String > ,
200
+ subscriptions : & mut StreamMap < String , broadcast:: Receiver < Bytes > > ,
201
+ dst : & mut Connection ,
202
+ ) -> crate :: Result < ( ) > {
203
+ // A command has been received from the client.
204
+ //
205
+ // Only `SUBSCRIBE` and `UNSUBSCRIBE` commands are permitted
206
+ // in this context.
207
+ match Command :: from_frame ( frame) ? {
208
+ Command :: Subscribe ( subscribe) => {
209
+ // The `apply` method will subscribe to the channels we add to this
210
+ // vector.
211
+ subscribe_to. extend ( subscribe. channels . into_iter ( ) ) ;
212
+ }
213
+ Command :: Unsubscribe ( mut unsubscribe) => {
214
+ // If no channels are specified, this requests unsubscribing from
215
+ // **all** channels. To implement this, the `unsubscribe.channels`
216
+ // vec is populated with the list of channels currently subscribed
217
+ // to.
218
+ if unsubscribe. channels . is_empty ( ) {
219
+ unsubscribe. channels = subscriptions
220
+ . keys ( )
221
+ . map ( |channel_name| channel_name. to_string ( ) )
222
+ . collect ( ) ;
223
+ }
224
+
225
+ for channel_name in unsubscribe. channels {
226
+ subscriptions. remove ( & channel_name) ;
227
+
228
+ let response = make_unsubscribe_frame ( channel_name, subscriptions. len ( ) ) ;
229
+ dst. write_frame ( & response) . await ?;
230
+ }
231
+ }
232
+ command => {
233
+ let cmd = Unknown :: new ( command. get_name ( ) ) ;
234
+ cmd. apply ( dst) . await ?;
235
+ }
236
+ }
237
+ Ok ( ( ) )
238
+ }
239
+
240
+ /// Creates the response to a subcribe request.
241
+ ///
242
+ /// All of these functions take the `channel_name` as a `String` instead of
243
+ /// a `&str` since `Bytes::from` can reuse the allocation in the `String`, and
244
+ /// taking a `&str` would require copying the data. This allows the caller to
245
+ /// decide whether to clone the channel name or not.
246
+ fn make_subscribe_frame ( channel_name : String , num_subs : usize ) -> Frame {
247
+ let mut response = Frame :: array ( ) ;
248
+ response. push_bulk ( Bytes :: from_static ( b"subscribe" ) ) ;
249
+ response. push_bulk ( Bytes :: from ( channel_name) ) ;
250
+ response. push_int ( num_subs as u64 ) ;
251
+ response
252
+ }
253
+
254
+ /// Creates the response to an unsubcribe request.
255
+ fn make_unsubscribe_frame ( channel_name : String , num_subs : usize ) -> Frame {
256
+ let mut response = Frame :: array ( ) ;
257
+ response. push_bulk ( Bytes :: from_static ( b"unsubscribe" ) ) ;
258
+ response. push_bulk ( Bytes :: from ( channel_name) ) ;
259
+ response. push_int ( num_subs as u64 ) ;
260
+ response
261
+ }
262
+
263
+ /// Creates a message informing the client about a new message on a channel that
264
+ /// the client subscribes to.
265
+ fn make_message_frame ( channel_name : String , msg : Bytes ) -> Frame {
266
+ let mut response = Frame :: array ( ) ;
267
+ response. push_bulk ( Bytes :: from_static ( b"message" ) ) ;
268
+ response. push_bulk ( Bytes :: from ( channel_name) ) ;
269
+ response. push_bulk ( msg) ;
270
+ response
271
+ }
272
+
223
273
impl Unsubscribe {
224
274
/// Create a new `Unsubscribe` command with the given `channels`.
225
275
pub ( crate ) fn new ( channels : & [ String ] ) -> Unsubscribe {
0 commit comments