diff --git a/src/tx/send_queue.rs b/src/tx/send_queue.rs index f3ecb25..d3c1611 100644 --- a/src/tx/send_queue.rs +++ b/src/tx/send_queue.rs @@ -33,8 +33,6 @@ use crate::types::StreamKey; use std::cell::RefCell; use std::collections::HashMap; use std::collections::VecDeque; -use std::ops::AddAssign; -use std::ops::SubAssign; use std::rc::Rc; use std::time::Duration; @@ -115,10 +113,9 @@ impl Item { } } -struct ThresholdWatcher<'a> { +struct ThresholdWatcher { value: usize, low_threshold: usize, - low_cb: Box, } fn add_lifecycle_events(events: &Rc>, lifecycle_id: &Option) { @@ -128,57 +125,54 @@ fn add_lifecycle_events(events: &Rc>, lifecycle_id: &Opti } } -impl<'a> ThresholdWatcher<'a> { - pub fn new(low_threshold: usize, low_cb: impl Fn() + 'a) -> Self { - Self { value: 0, low_threshold, low_cb: Box::new(low_cb) } +impl ThresholdWatcher { + pub fn new(low_threshold: usize) -> Self { + Self { value: 0, low_threshold } } - pub fn set_low_threshold(&mut self, low_threshold: usize) { - if self.low_threshold < self.value && low_threshold >= self.value { - (self.low_cb)(); - } - self.low_threshold = low_threshold; + pub fn add(&mut self, amount: usize) { + self.value += amount; } -} -impl AddAssign for ThresholdWatcher<'_> { - fn add_assign(&mut self, rhs: usize) { - self.value += rhs; + // Returns `true` if the threshold was crossed downwards. + #[must_use] + pub fn sub(&mut self, amount: usize) -> bool { + debug_assert!(self.value >= amount); + let old_value = self.value; + self.value -= amount; + old_value > self.low_threshold && self.value <= self.low_threshold } -} - -impl SubAssign for ThresholdWatcher<'_> { - fn sub_assign(&mut self, rhs: usize) { - debug_assert!(self.value >= rhs); - let old_value = self.value; - self.value -= rhs; - if old_value > self.low_threshold && self.value <= self.low_threshold { - (self.low_cb)(); - } + // Returns `true` if the new threshold is lower than the current value, which means that the + // new low threshold was crossed. + #[must_use] + pub fn set_low_threshold(&mut self, low_threshold: usize) -> bool { + let triggered = self.low_threshold < self.value && low_threshold >= self.value; + self.low_threshold = low_threshold; + triggered } } /// Per-stream information. -struct OutgoingStream<'a> { +struct OutgoingStream { priority: u16, pause_state: PauseState, next_unordered_mid: Mid, next_ordered_mid: Mid, next_ssn: Ssn, - buffered_amount: ThresholdWatcher<'a>, + buffered_amount: ThresholdWatcher, items: VecDeque, } -impl<'a> OutgoingStream<'a> { - fn new(priority: u16, low_threshold: usize, low_cb: impl Fn() + 'a) -> Self { +impl OutgoingStream { + fn new(priority: u16, low_threshold: usize) -> Self { Self { priority, pause_state: PauseState::NotPaused, next_unordered_mid: Mid(0), next_ordered_mid: Mid(0), next_ssn: Ssn(0), - buffered_amount: ThresholdWatcher::new(low_threshold, low_cb), + buffered_amount: ThresholdWatcher::new(low_threshold), items: VecDeque::new(), } } @@ -188,10 +182,10 @@ pub struct SendQueue { enable_message_interleaving: bool, default_priority: u16, default_low_buffered_amount_low_threshold: usize, - buffered_amount: ThresholdWatcher<'static>, + buffered_amount: ThresholdWatcher, current_message_id: OutgoingMessageId, scheduler: StreamScheduler, - streams: HashMap>, + streams: HashMap, events: Rc>, } @@ -201,24 +195,16 @@ impl SendQueue { options: &Options, events: Rc>, ) -> Self { - let buffered_amount_low_events = Rc::clone(&events); Self { enable_message_interleaving: false, default_priority: options.default_stream_priority, default_low_buffered_amount_low_threshold: options .default_stream_buffered_amount_low_threshold, - buffered_amount: ThresholdWatcher::new( - options.total_buffered_amount_low_threshold, - move || { - buffered_amount_low_events - .borrow_mut() - .add(SocketEvent::OnTotalBufferedAmountLow()); - }, - ), + buffered_amount: ThresholdWatcher::new(options.total_buffered_amount_low_threshold), current_message_id: OutgoingMessageId(0), streams: HashMap::new(), scheduler: StreamScheduler::new(max_payload_bytes), - events: Rc::clone(&events), + events, } } @@ -247,8 +233,12 @@ impl SendQueue { let item = stream.items.front().unwrap(); if item.attributes.expires_at <= now { // Oops, this entire message has already expired. Try the next one. - self.buffered_amount -= item.remaining_size; - stream.buffered_amount -= item.remaining_size; + if self.buffered_amount.sub(item.remaining_size) { + self.events.borrow_mut().add(SocketEvent::OnTotalBufferedAmountLow()); + } + if stream.buffered_amount.sub(item.remaining_size) { + self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id)); + } add_lifecycle_events(&self.events, &item.attributes.lifecycle_id); stream.items.pop_front(); let priority = self.enable_message_interleaving.then_some(stream.priority); @@ -268,15 +258,8 @@ impl SendQueue { self.scheduler.peek(usize::MAX).is_some() } - fn make_stream( - stream_id: StreamId, - priority: u16, - low_threshold: usize, - events: Rc>, - ) -> OutgoingStream<'static> { - OutgoingStream::new(priority, low_threshold, move || { - events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id)); - }) + fn make_stream(priority: u16, low_threshold: usize) -> OutgoingStream { + OutgoingStream::new(priority, low_threshold) } pub fn add(&mut self, now: SocketTime, message: Message, send_options: &SendOptions) { @@ -291,16 +274,14 @@ impl SendQueue { let stream_id = message.stream_id; let stream = self.streams.entry(stream_id).or_insert_with(|| { SendQueue::make_stream( - stream_id, self.default_priority, self.default_low_buffered_amount_low_threshold, - Rc::clone(&self.events), ) }); let message_id = self.current_message_id; self.current_message_id += 1; - stream.buffered_amount += message.payload.len(); - self.buffered_amount += message.payload.len(); + stream.buffered_amount.add(message.payload.len()); + self.buffered_amount.add(message.payload.len()); stream.items.push_back(Item::new(message_id, message, attributes)); if (stream.pause_state == PauseState::NotPaused || stream.pause_state == PauseState::Pending) @@ -338,8 +319,12 @@ impl SendQueue { .get(item.remaining_offset..size + item.remaining_offset) .unwrap() .to_vec(); - self.buffered_amount -= payload.len(); - stream.buffered_amount -= payload.len(); + if self.buffered_amount.sub(payload.len()) { + self.events.borrow_mut().add(SocketEvent::OnTotalBufferedAmountLow()); + } + if stream.buffered_amount.sub(payload.len()) { + self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id)); + } let data = Data { stream_key: StreamKey::new(item.attributes.unordered, stream_id), @@ -392,8 +377,12 @@ impl SendQueue { if item.message_id != message_id { return; } - self.buffered_amount -= item.remaining_size; - stream.buffered_amount -= item.remaining_size; + if self.buffered_amount.sub(item.remaining_size) { + self.events.borrow_mut().add(SocketEvent::OnTotalBufferedAmountLow()); + } + if stream.buffered_amount.sub(item.remaining_size) { + self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id)); + } add_lifecycle_events(&self.events, &item.attributes.lifecycle_id); stream.items.pop_front(); @@ -413,10 +402,8 @@ impl SendQueue { pub fn prepare_reset_stream(&mut self, stream_id: StreamId) { let stream = self.streams.entry(stream_id).or_insert_with(|| { SendQueue::make_stream( - stream_id, self.default_priority, self.default_low_buffered_amount_low_threshold, - Rc::clone(&self.events), ) }); if stream.pause_state != PauseState::NotPaused { @@ -442,8 +429,12 @@ impl SendQueue { // will always deliver all the fragments before actually resetting the stream. stream.items.retain_mut(|i| { if i.remaining_offset == 0 { - stream.buffered_amount -= i.remaining_size; - self.buffered_amount -= i.remaining_size; + if stream.buffered_amount.sub(i.remaining_size) { + self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id)); + } + if self.buffered_amount.sub(i.remaining_size) { + self.events.borrow_mut().add(SocketEvent::OnTotalBufferedAmountLow()); + } add_lifecycle_events(&self.events, &i.attributes.lifecycle_id); return false; } @@ -465,14 +456,14 @@ impl SendQueue { } pub fn get_streams_ready_to_reset(&mut self) -> Vec { - let mut ready: Vec = Vec::new(); - self.streams.iter_mut().for_each(|(stream_id, stream)| { - if stream.pause_state == PauseState::Paused { + self.streams + .iter_mut() + .filter(|(_, stream)| stream.pause_state == PauseState::Paused) + .map(|(stream_id, stream)| { stream.pause_state = PauseState::Resetting; - ready.push(*stream_id); - } - }); - ready + *stream_id + }) + .collect() } pub fn commit_reset_streams(&mut self) { @@ -510,8 +501,8 @@ impl SendQueue { stream.next_ssn = Ssn(0); if let Some(item) = stream.items.front_mut() { let item_size = item.message.payload.len(); - self.buffered_amount += item_size - item.remaining_size; - stream.buffered_amount += item_size - item.remaining_size; + self.buffered_amount.add(item_size - item.remaining_size); + stream.buffered_amount.add(item_size - item.remaining_size); item.remaining_offset = 0; item.remaining_size = item_size; let priority = self.enable_message_interleaving.then_some(stream.priority); @@ -521,10 +512,7 @@ impl SendQueue { } pub fn buffered_amount(&self, stream_id: StreamId) -> usize { - match self.streams.get(&stream_id) { - Some(stream) => stream.buffered_amount.value, - None => 0, - } + self.streams.get(&stream_id).map_or(0, |s| s.buffered_amount.value) } pub fn total_buffered_amount(&self) -> usize { @@ -541,22 +529,20 @@ impl SendQueue { pub fn set_buffered_amount_low_threshold(&mut self, stream_id: StreamId, threshold: usize) { let stream = self.streams.entry(stream_id).or_insert_with(|| { SendQueue::make_stream( - stream_id, self.default_priority, self.default_low_buffered_amount_low_threshold, - Rc::clone(&self.events), ) }); - stream.buffered_amount.set_low_threshold(threshold); + if stream.buffered_amount.set_low_threshold(threshold) { + self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id)); + } } pub fn set_priority(&mut self, stream_id: StreamId, priority: u16) { let stream = self.streams.entry(stream_id).or_insert_with(|| { SendQueue::make_stream( - stream_id, self.default_priority, self.default_low_buffered_amount_low_threshold, - Rc::clone(&self.events), ) }); stream.priority = priority; @@ -594,12 +580,8 @@ impl SendQueue { pub(crate) fn restore_from_state(&mut self, state: &SocketHandoverState) { state.tx.streams.iter().for_each(|s| { let stream_id = StreamId(s.id); - let mut stream = SendQueue::make_stream( - stream_id, - s.priority, - self.default_low_buffered_amount_low_threshold, - Rc::clone(&self.events), - ); + let mut stream = + SendQueue::make_stream(s.priority, self.default_low_buffered_amount_low_threshold); stream.next_ssn = Ssn(s.next_ssn); stream.next_unordered_mid = Mid(s.next_unordered_mid); stream.next_ordered_mid = Mid(s.next_ordered_mid);