Skip to content

Commit 3834ea0

Browse files
committed
Extract non-generic code from base::Sender::send() and base::Receiver::recv()
- send_chunks_from_stream: streaming chunk send task extracted - connect_ports_and_spawn: port connection + callback spawning extracted - feed_recv_chunks: chunk feeding loop extracted from recv() - receive_and_connect_ports: port acceptance + callback spawning extracted
1 parent 103347b commit 3834ea0

File tree

2 files changed

+174
-143
lines changed

2 files changed

+174
-143
lines changed

remoc/src/rch/base/receiver.rs

Lines changed: 107 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,87 @@ impl<T, Codec> fmt::Debug for Receiver<T, Codec> {
194194
}
195195
}
196196

197+
/// Result of feeding received chunks to the deserialization thread.
198+
enum FeedChunksResult {
199+
/// All chunks have been fed successfully.
200+
Done,
201+
/// The current send operation was cancelled; restart.
202+
Cancelled,
203+
/// The chmux connection failed.
204+
ChMux,
205+
/// Maximum item size was exceeded.
206+
MaxItemSizeExceeded,
207+
}
208+
209+
/// Non-generic: feed received data chunks from chmux to the deserialization thread.
210+
///
211+
/// Separated from the generic `Receiver::recv` to avoid monomorphization.
212+
#[inline(never)]
213+
async fn feed_recv_chunks(
214+
receiver: &mut chmux::Receiver,
215+
tx: &tokio::sync::mpsc::Sender<Result<Bytes, ()>>,
216+
total: &mut usize,
217+
max_item_size: usize,
218+
) -> FeedChunksResult {
219+
loop {
220+
let tx_permit = match tx.reserve().await {
221+
Ok(tx_permit) => tx_permit,
222+
_ => return FeedChunksResult::Done,
223+
};
224+
225+
match receiver.recv_chunk().await {
226+
Ok(Some(chunk)) => {
227+
*total += chunk.remaining();
228+
if *total > max_item_size {
229+
return FeedChunksResult::MaxItemSizeExceeded;
230+
}
231+
tx_permit.send(Ok(chunk));
232+
}
233+
Ok(None) => return FeedChunksResult::Done,
234+
Err(RecvChunkError::Cancelled) => return FeedChunksResult::Cancelled,
235+
Err(RecvChunkError::ChMux) => return FeedChunksResult::ChMux,
236+
}
237+
}
238+
}
239+
240+
/// Non-generic: receive port requests and spawn callbacks from deserialized objects.
241+
///
242+
/// Separated from the generic `Receiver::recv` to avoid monomorphization.
243+
///
244+
/// Returns `Ok(None)` on success, `Ok(Some(received))` if the received message
245+
/// was not port requests (requires restart), or `Err` on failure.
246+
#[inline(never)]
247+
async fn receive_and_connect_ports(
248+
receiver: &mut chmux::Receiver,
249+
pds: &mut PortDeserializer,
250+
default_max_ports: usize,
251+
) -> Result<Option<Option<Received>>, RecvError> {
252+
if !pds.expected.is_empty() {
253+
receiver.set_max_ports(pds.expected.len() + default_max_ports);
254+
255+
let requests = match receiver.recv_any().await? {
256+
Some(chmux::Received::Requests(requests)) => requests,
257+
other => return Ok(Some(other)),
258+
};
259+
260+
for request in requests {
261+
if let Some((local_port, callback)) = pds.expected.remove(&request.id()) {
262+
exec::spawn(callback(local_port, request).in_current_span());
263+
}
264+
}
265+
266+
if !pds.expected.is_empty() {
267+
return Err(RecvError::MissingPorts(pds.expected.keys().copied().collect()));
268+
}
269+
}
270+
271+
for task in pds.tasks.drain(..) {
272+
exec::spawn(task.in_current_span());
273+
}
274+
275+
Ok(None)
276+
}
277+
197278
enum DataSource<T> {
198279
None,
199280
Buffered(Option<chmux::DataBuf>),
@@ -293,46 +374,20 @@ where
293374

294375
// Observe deserialization of streamed data.
295376
DataSource::Streamed { tx, task, total } => {
296-
enum FeedError {
297-
RecvChunkError(RecvChunkError),
298-
MaxItemSizeExceeded,
299-
}
300-
301-
// Feed received data chunks to deserialization thread.
302-
if let Some(tx) = &tx {
303-
let res = loop {
304-
let tx_permit = match tx.reserve().await {
305-
Ok(tx_permit) => tx_permit,
306-
_ => {
307-
break Ok(());
308-
}
309-
};
310-
311-
match self.receiver.recv_chunk().await {
312-
Ok(Some(chunk)) => {
313-
*total += chunk.remaining();
314-
if *total > self.max_item_size {
315-
break Err(FeedError::MaxItemSizeExceeded);
316-
}
317-
318-
tx_permit.send(Ok(chunk));
319-
}
320-
Ok(None) => break Ok(()),
321-
Err(err) => break Err(FeedError::RecvChunkError(err)),
322-
}
323-
};
324-
325-
match res {
326-
Ok(()) => (),
327-
Err(FeedError::RecvChunkError(RecvChunkError::Cancelled)) => {
377+
// Feed received data chunks to deserialization thread (non-generic).
378+
if let Some(tx_ref) = &tx {
379+
match feed_recv_chunks(&mut self.receiver, tx_ref, total, self.max_item_size).await
380+
{
381+
FeedChunksResult::Done => (),
382+
FeedChunksResult::Cancelled => {
328383
self.data = DataSource::None;
329384
continue 'restart;
330385
}
331-
Err(FeedError::RecvChunkError(RecvChunkError::ChMux)) => {
386+
FeedChunksResult::ChMux => {
332387
self.data = DataSource::None;
333388
return Err(RecvError::Receive(chmux::RecvError::ChMux));
334389
}
335-
Err(FeedError::MaxItemSizeExceeded) => {
390+
FeedChunksResult::MaxItemSizeExceeded => {
336391
self.data = DataSource::None;
337392
return Err(RecvError::MaxItemSizeExceeded);
338393
}
@@ -365,47 +420,26 @@ where
365420
}
366421
}
367422

368-
// Connect received ports.
423+
// Connect received ports (non-generic).
369424
let pds = self.port_deser.as_mut().unwrap();
370-
if !pds.expected.is_empty() {
371-
// Set port limit.
372-
//
373-
// Allow the reception of additional ports for forward compatibility,
374-
// i.e. our deserializer may use an older version of the struct
375-
// which is missing some ports that the remote endpoint sent.
376-
self.receiver.set_max_ports(pds.expected.len() + self.default_max_ports.unwrap());
377-
378-
// Receive port requests from chmux.
379-
let requests = match self.receiver.recv_any().await? {
380-
Some(chmux::Received::Requests(requests)) => requests,
381-
other => {
382-
// Current send operation has been aborted and this is data from
383-
// next send operation, so we restart.
384-
self.recved = Some(other);
385-
self.data = DataSource::None;
386-
self.item = None;
387-
self.port_deser = None;
388-
continue 'restart;
389-
}
390-
};
391-
392-
// Call port callbacks from received objects, ignoring superfluous requests for
393-
// forward compatibility.
394-
for request in requests {
395-
if let Some((local_port, callback)) = pds.expected.remove(&request.id()) {
396-
exec::spawn(callback(local_port, request).in_current_span());
397-
}
398-
}
399-
400-
// But error on ports that we expect but that are missing.
401-
if !pds.expected.is_empty() {
402-
return Err(RecvError::MissingPorts(pds.expected.keys().copied().collect()));
425+
match receive_and_connect_ports(
426+
&mut self.receiver,
427+
pds,
428+
self.default_max_ports.unwrap(),
429+
)
430+
.await
431+
{
432+
Ok(None) => (),
433+
Ok(Some(other)) => {
434+
// Current send operation has been aborted and this is data from
435+
// next send operation, so we restart.
436+
self.recved = Some(other);
437+
self.data = DataSource::None;
438+
self.item = None;
439+
self.port_deser = None;
440+
continue 'restart;
403441
}
404-
}
405-
406-
// Spawn registered tasks.
407-
for task in pds.tasks.drain(..) {
408-
exec::spawn(task.in_current_span());
442+
Err(err) => return Err(err),
409443
}
410444

411445
return Ok(Some(self.item.take().unwrap()));

remoc/src/rch/base/sender.rs

Lines changed: 67 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,61 @@ impl PortSerializer {
226226
}
227227
}
228228

229+
/// Non-generic: stream serialized chunks over chmux.
230+
///
231+
/// Separated from the generic `Sender::send` to avoid monomorphization.
232+
#[inline(never)]
233+
async fn send_chunks_from_stream(
234+
sender: &mut chmux::Sender,
235+
mut rx: tokio::sync::mpsc::Receiver<BytesMut>,
236+
max_item_size: usize,
237+
) -> Result<(), SendErrorKind> {
238+
let mut sc = sender.send_chunks();
239+
let mut total = 0;
240+
while let Some(chunk) = rx.recv().await {
241+
total += chunk.len();
242+
if total > max_item_size {
243+
return Err(SendErrorKind::MaxItemSizeExceeded);
244+
}
245+
sc = sc.send(chunk.freeze()).await.map_err(SendErrorKind::Send)?;
246+
}
247+
sc.finish().await.map_err(SendErrorKind::Send)
248+
}
249+
250+
/// Non-generic: connect ports obtained during serialization and spawn callbacks.
251+
///
252+
/// Separated from the generic `Sender::send` to avoid monomorphization.
253+
#[inline(never)]
254+
async fn connect_ports_and_spawn(
255+
sender: &mut chmux::Sender,
256+
ps: PortSerializer,
257+
) -> Result<(), chmux::SendError> {
258+
let PortSerializer { requests, tasks, .. } = ps;
259+
260+
let mut ports = Vec::new();
261+
let mut callbacks = Vec::new();
262+
for (port, callback) in requests {
263+
ports.push(PortReq::new(port));
264+
callbacks.push(callback);
265+
}
266+
267+
let connects = if ports.is_empty() {
268+
Vec::new()
269+
} else {
270+
sender.connect(ports, true).await?
271+
};
272+
273+
for (callback, connect) in callbacks.into_iter().zip(connects.into_iter()) {
274+
exec::spawn(callback(connect).in_current_span());
275+
}
276+
277+
for task in tasks {
278+
exec::spawn(task.in_current_span());
279+
}
280+
281+
Ok(())
282+
}
283+
229284
/// Sends arbitrary values to a remote endpoint.
230285
///
231286
/// Values may be or contain any channel from this crate.
@@ -365,54 +420,24 @@ where
365420

366421
None => {
367422
// Stream data while serializing.
368-
let (tx, mut rx) = tokio::sync::mpsc::channel(BIG_DATA_CHUNK_QUEUE);
369-
let ser_task = Self::serialize_streaming(
370-
self.sender.port_allocator(),
371-
self.sender.storage(),
372-
item,
373-
tx,
374-
self.sender.chunk_size(),
375-
);
376-
377-
enum SendTaskError {
378-
SendError(chmux::SendError),
379-
MaxItemSizeExceeded,
380-
}
381-
382-
let mut sc = self.sender.send_chunks();
383-
let max_item_size = self.max_item_size;
384-
let send_task = async move {
385-
let mut total = 0;
386-
while let Some(chunk) = rx.recv().await {
387-
total += chunk.len();
388-
if total > max_item_size {
389-
return Err(SendTaskError::MaxItemSizeExceeded);
390-
}
391-
392-
sc = sc.send(chunk.freeze()).await.map_err(SendTaskError::SendError)?;
393-
}
394-
Ok(sc)
395-
};
423+
let (tx, rx) = tokio::sync::mpsc::channel(BIG_DATA_CHUNK_QUEUE);
424+
let allocator = self.sender.port_allocator();
425+
let storage = self.sender.storage();
426+
let chunk_size = self.sender.chunk_size();
427+
let ser_task = Self::serialize_streaming(allocator, storage, item, tx, chunk_size);
428+
let send_task = send_chunks_from_stream(&mut self.sender, rx, self.max_item_size);
396429

397430
match tokio::join!(ser_task, send_task) {
398-
(Ok((item, ps, size)), Ok(sc)) => {
399-
if let Err(err) = sc.finish().await {
400-
return Err(SendError::new(SendErrorKind::Send(err), item));
401-
}
402-
431+
(Ok((item, ps, size)), Ok(())) => {
403432
if size <= self.sender.max_data_size() {
404433
self.big_data = (self.big_data - 1).max(-BIG_DATA_LIMIT);
405434
}
406435

407436
(item, ps)
408437
}
409-
(Ok((item, _, _)), Err(err)) | (Err((_, item)), Err(err)) => {
438+
(Ok((item, _, _)), Err(kind)) | (Err((_, item)), Err(kind)) => {
410439
// When sending fails, the serialization task will either finish
411440
// or fail due to rx being dropped.
412-
let kind = match err {
413-
SendTaskError::SendError(err) => SendErrorKind::Send(err),
414-
SendTaskError::MaxItemSizeExceeded => SendErrorKind::MaxItemSizeExceeded,
415-
};
416441
return Err(SendError::new(kind, item));
417442
}
418443
(Err((err, item)), _) => {
@@ -424,42 +449,14 @@ where
424449
}
425450
};
426451

427-
let PortSerializer { requests, tasks, .. } = ps;
428-
429-
// Extract ports and connect callbacks.
430-
let mut ports = Vec::new();
431-
let mut callbacks = Vec::new();
432-
for (port, callback) in requests {
433-
ports.push(PortReq::new(port));
434-
callbacks.push(callback);
452+
// Connect ports obtained during serialization (non-generic).
453+
if let Err(err) = connect_ports_and_spawn(&mut self.sender, ps).await {
454+
return Err(SendError::new(SendErrorKind::Send(err), item));
435455
}
436456

437-
// Request connecting chmux ports.
438-
let connects = if ports.is_empty() {
439-
Vec::new()
440-
} else {
441-
match self.sender.connect(ports, true).await {
442-
Ok(connects) => connects,
443-
Err(err) => return Err(SendError::new(SendErrorKind::Send(err), item)),
444-
}
445-
};
446-
447-
// Ensure that item is dropped before calling connection callbacks.
457+
// Ensure that item is dropped before connection callbacks run.
448458
drop(item);
449459

450-
// Call callbacks of BaseSenders and BaseReceivers with obtained
451-
// chmux connect requests.
452-
//
453-
// We have to spawn a task for this to ensure cancellation safety.
454-
for (callback, connect) in callbacks.into_iter().zip(connects.into_iter()) {
455-
exec::spawn(callback(connect).in_current_span());
456-
}
457-
458-
// Spawn registered tasks.
459-
for task in tasks {
460-
exec::spawn(task.in_current_span());
461-
}
462-
463460
Ok(())
464461
}
465462

0 commit comments

Comments
 (0)