Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 72 additions & 90 deletions src/tx/send_queue.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,6 @@ use crate::types::StreamKey;
use std::cell::RefCell;
use std::collections::HashMap;
use std::collections::VecDeque;
use std::ops::AddAssign;
use std::ops::SubAssign;
use std::rc::Rc;
use std::time::Duration;

Expand Down Expand Up @@ -115,10 +113,9 @@ impl Item {
}
}

struct ThresholdWatcher<'a> {
struct ThresholdWatcher {
value: usize,
low_threshold: usize,
low_cb: Box<dyn Fn() + 'a>,
}

fn add_lifecycle_events(events: &Rc<RefCell<dyn EventSink>>, lifecycle_id: &Option<LifecycleId>) {
Expand All @@ -128,57 +125,54 @@ fn add_lifecycle_events(events: &Rc<RefCell<dyn EventSink>>, lifecycle_id: &Opti
}
}

impl<'a> ThresholdWatcher<'a> {
pub fn new(low_threshold: usize, low_cb: impl Fn() + 'a) -> Self {
Self { value: 0, low_threshold, low_cb: Box::new(low_cb) }
impl ThresholdWatcher {
pub fn new(low_threshold: usize) -> Self {
Self { value: 0, low_threshold }
}

pub fn set_low_threshold(&mut self, low_threshold: usize) {
if self.low_threshold < self.value && low_threshold >= self.value {
(self.low_cb)();
}
self.low_threshold = low_threshold;
pub fn add(&mut self, amount: usize) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add a comment here for consistency

self.value += amount;
}
}

impl AddAssign<usize> for ThresholdWatcher<'_> {
fn add_assign(&mut self, rhs: usize) {
self.value += rhs;
// Returns `true` if the threshold was crossed downwards.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

/// To make it a doc comment (since it's pub)

#[must_use]
pub fn sub(&mut self, amount: usize) -> bool {
debug_assert!(self.value >= amount);
let old_value = self.value;
self.value -= amount;
old_value > self.low_threshold && self.value <= self.low_threshold
}
}

impl SubAssign<usize> for ThresholdWatcher<'_> {
fn sub_assign(&mut self, rhs: usize) {
debug_assert!(self.value >= rhs);

let old_value = self.value;
self.value -= rhs;
if old_value > self.low_threshold && self.value <= self.low_threshold {
(self.low_cb)();
}
// Returns `true` if the new threshold is lower than the current value, which means that the
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ditto

// new low threshold was crossed.
#[must_use]
pub fn set_low_threshold(&mut self, low_threshold: usize) -> bool {
let triggered = self.low_threshold < self.value && low_threshold >= self.value;
self.low_threshold = low_threshold;
triggered
}
}

/// Per-stream information.
struct OutgoingStream<'a> {
struct OutgoingStream {
priority: u16,
pause_state: PauseState,
next_unordered_mid: Mid,
next_ordered_mid: Mid,
next_ssn: Ssn,
buffered_amount: ThresholdWatcher<'a>,
buffered_amount: ThresholdWatcher,
items: VecDeque<Item>,
}

impl<'a> OutgoingStream<'a> {
fn new(priority: u16, low_threshold: usize, low_cb: impl Fn() + 'a) -> Self {
impl OutgoingStream {
fn new(priority: u16, low_threshold: usize) -> Self {
Self {
priority,
pause_state: PauseState::NotPaused,
next_unordered_mid: Mid(0),
next_ordered_mid: Mid(0),
next_ssn: Ssn(0),
buffered_amount: ThresholdWatcher::new(low_threshold, low_cb),
buffered_amount: ThresholdWatcher::new(low_threshold),
items: VecDeque::new(),
}
}
Expand All @@ -188,10 +182,10 @@ pub struct SendQueue {
enable_message_interleaving: bool,
default_priority: u16,
default_low_buffered_amount_low_threshold: usize,
buffered_amount: ThresholdWatcher<'static>,
buffered_amount: ThresholdWatcher,
current_message_id: OutgoingMessageId,
scheduler: StreamScheduler,
streams: HashMap<StreamId, OutgoingStream<'static>>,
streams: HashMap<StreamId, OutgoingStream>,
events: Rc<RefCell<dyn EventSink>>,
}

Expand All @@ -201,24 +195,16 @@ impl SendQueue {
options: &Options,
events: Rc<RefCell<dyn EventSink>>,
) -> Self {
let buffered_amount_low_events = Rc::clone(&events);
Self {
enable_message_interleaving: false,
default_priority: options.default_stream_priority,
default_low_buffered_amount_low_threshold: options
.default_stream_buffered_amount_low_threshold,
buffered_amount: ThresholdWatcher::new(
options.total_buffered_amount_low_threshold,
move || {
buffered_amount_low_events
.borrow_mut()
.add(SocketEvent::OnTotalBufferedAmountLow());
},
),
buffered_amount: ThresholdWatcher::new(options.total_buffered_amount_low_threshold),
current_message_id: OutgoingMessageId(0),
streams: HashMap::new(),
scheduler: StreamScheduler::new(max_payload_bytes),
events: Rc::clone(&events),
events,
}
}

Expand Down Expand Up @@ -247,8 +233,12 @@ impl SendQueue {
let item = stream.items.front().unwrap();
if item.attributes.expires_at <= now {
// Oops, this entire message has already expired. Try the next one.
self.buffered_amount -= item.remaining_size;
stream.buffered_amount -= item.remaining_size;
if self.buffered_amount.sub(item.remaining_size) {
self.events.borrow_mut().add(SocketEvent::OnTotalBufferedAmountLow());
}
if stream.buffered_amount.sub(item.remaining_size) {
self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id));
}
add_lifecycle_events(&self.events, &item.attributes.lifecycle_id);
stream.items.pop_front();
let priority = self.enable_message_interleaving.then_some(stream.priority);
Expand All @@ -268,15 +258,8 @@ impl SendQueue {
self.scheduler.peek(usize::MAX).is_some()
}

fn make_stream(
stream_id: StreamId,
priority: u16,
low_threshold: usize,
events: Rc<RefCell<dyn EventSink>>,
) -> OutgoingStream<'static> {
OutgoingStream::new(priority, low_threshold, move || {
events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id));
})
fn make_stream(priority: u16, low_threshold: usize) -> OutgoingStream {
OutgoingStream::new(priority, low_threshold)
}

pub fn add(&mut self, now: SocketTime, message: Message, send_options: &SendOptions) {
Expand All @@ -291,16 +274,14 @@ impl SendQueue {
let stream_id = message.stream_id;
let stream = self.streams.entry(stream_id).or_insert_with(|| {
SendQueue::make_stream(
stream_id,
self.default_priority,
self.default_low_buffered_amount_low_threshold,
Rc::clone(&self.events),
)
});
let message_id = self.current_message_id;
self.current_message_id += 1;
stream.buffered_amount += message.payload.len();
self.buffered_amount += message.payload.len();
stream.buffered_amount.add(message.payload.len());
self.buffered_amount.add(message.payload.len());
stream.items.push_back(Item::new(message_id, message, attributes));
if (stream.pause_state == PauseState::NotPaused
|| stream.pause_state == PauseState::Pending)
Expand Down Expand Up @@ -338,8 +319,12 @@ impl SendQueue {
.get(item.remaining_offset..size + item.remaining_offset)
.unwrap()
.to_vec();
self.buffered_amount -= payload.len();
stream.buffered_amount -= payload.len();
if self.buffered_amount.sub(payload.len()) {
self.events.borrow_mut().add(SocketEvent::OnTotalBufferedAmountLow());
}
if stream.buffered_amount.sub(payload.len()) {
self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id));
}

let data = Data {
stream_key: StreamKey::new(item.attributes.unordered, stream_id),
Expand Down Expand Up @@ -392,8 +377,12 @@ impl SendQueue {
if item.message_id != message_id {
return;
}
self.buffered_amount -= item.remaining_size;
stream.buffered_amount -= item.remaining_size;
if self.buffered_amount.sub(item.remaining_size) {
self.events.borrow_mut().add(SocketEvent::OnTotalBufferedAmountLow());
}
if stream.buffered_amount.sub(item.remaining_size) {
self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id));
}
add_lifecycle_events(&self.events, &item.attributes.lifecycle_id);
stream.items.pop_front();

Expand All @@ -413,10 +402,8 @@ impl SendQueue {
pub fn prepare_reset_stream(&mut self, stream_id: StreamId) {
let stream = self.streams.entry(stream_id).or_insert_with(|| {
SendQueue::make_stream(
stream_id,
self.default_priority,
self.default_low_buffered_amount_low_threshold,
Rc::clone(&self.events),
)
});
if stream.pause_state != PauseState::NotPaused {
Expand All @@ -442,8 +429,12 @@ impl SendQueue {
// will always deliver all the fragments before actually resetting the stream.
stream.items.retain_mut(|i| {
if i.remaining_offset == 0 {
stream.buffered_amount -= i.remaining_size;
self.buffered_amount -= i.remaining_size;
if stream.buffered_amount.sub(i.remaining_size) {
self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id));
}
if self.buffered_amount.sub(i.remaining_size) {
self.events.borrow_mut().add(SocketEvent::OnTotalBufferedAmountLow());
}
add_lifecycle_events(&self.events, &i.attributes.lifecycle_id);
return false;
}
Expand All @@ -465,14 +456,14 @@ impl SendQueue {
}

pub fn get_streams_ready_to_reset(&mut self) -> Vec<StreamId> {
let mut ready: Vec<StreamId> = Vec::new();
self.streams.iter_mut().for_each(|(stream_id, stream)| {
if stream.pause_state == PauseState::Paused {
self.streams
.iter_mut()
.filter(|(_, stream)| stream.pause_state == PauseState::Paused)
.map(|(stream_id, stream)| {
stream.pause_state = PauseState::Resetting;
ready.push(*stream_id);
}
});
ready
*stream_id
})
.collect()
}

pub fn commit_reset_streams(&mut self) {
Expand Down Expand Up @@ -510,8 +501,8 @@ impl SendQueue {
stream.next_ssn = Ssn(0);
if let Some(item) = stream.items.front_mut() {
let item_size = item.message.payload.len();
self.buffered_amount += item_size - item.remaining_size;
stream.buffered_amount += item_size - item.remaining_size;
self.buffered_amount.add(item_size - item.remaining_size);
stream.buffered_amount.add(item_size - item.remaining_size);
item.remaining_offset = 0;
item.remaining_size = item_size;
let priority = self.enable_message_interleaving.then_some(stream.priority);
Expand All @@ -521,10 +512,7 @@ impl SendQueue {
}

pub fn buffered_amount(&self, stream_id: StreamId) -> usize {
match self.streams.get(&stream_id) {
Some(stream) => stream.buffered_amount.value,
None => 0,
}
self.streams.get(&stream_id).map_or(0, |s| s.buffered_amount.value)
}

pub fn total_buffered_amount(&self) -> usize {
Expand All @@ -541,22 +529,20 @@ impl SendQueue {
pub fn set_buffered_amount_low_threshold(&mut self, stream_id: StreamId, threshold: usize) {
let stream = self.streams.entry(stream_id).or_insert_with(|| {
SendQueue::make_stream(
stream_id,
self.default_priority,
self.default_low_buffered_amount_low_threshold,
Rc::clone(&self.events),
)
});
stream.buffered_amount.set_low_threshold(threshold);
if stream.buffered_amount.set_low_threshold(threshold) {
self.events.borrow_mut().add(SocketEvent::OnBufferedAmountLow(stream_id));
}
}

pub fn set_priority(&mut self, stream_id: StreamId, priority: u16) {
let stream = self.streams.entry(stream_id).or_insert_with(|| {
SendQueue::make_stream(
stream_id,
self.default_priority,
self.default_low_buffered_amount_low_threshold,
Rc::clone(&self.events),
)
});
stream.priority = priority;
Expand Down Expand Up @@ -594,12 +580,8 @@ impl SendQueue {
pub(crate) fn restore_from_state(&mut self, state: &SocketHandoverState) {
state.tx.streams.iter().for_each(|s| {
let stream_id = StreamId(s.id);
let mut stream = SendQueue::make_stream(
stream_id,
s.priority,
self.default_low_buffered_amount_low_threshold,
Rc::clone(&self.events),
);
let mut stream =
SendQueue::make_stream(s.priority, self.default_low_buffered_amount_low_threshold);
stream.next_ssn = Ssn(s.next_ssn);
stream.next_unordered_mid = Mid(s.next_unordered_mid);
stream.next_ordered_mid = Mid(s.next_ordered_mid);
Expand Down
Loading