Skip to content

Commit 0ecbe78

Browse files
committed
chore(pegboard): consolidate to single subscriber per gateway
1 parent 911fd90 commit 0ecbe78

File tree

33 files changed

+1783
-1663
lines changed

33 files changed

+1783
-1663
lines changed

Cargo.lock

Lines changed: 5 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/common/universalpubsub/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ async-trait.workspace = true
1212
base64.workspace = true
1313
deadpool-postgres.workspace = true
1414
futures-util.workspace = true
15-
moka.workspace = true
1615
rivet-error.workspace = true
1716
rivet-ups-protocol.workspace = true
1817
serde_json.workspace = true
@@ -21,6 +20,7 @@ serde.workspace = true
2120
sha2.workspace = true
2221
tokio-postgres.workspace = true
2322
tokio.workspace = true
23+
tokio-util.workspace = true
2424
tracing.workspace = true
2525
uuid.workspace = true
2626

packages/common/universalpubsub/src/driver/postgres/mod.rs

Lines changed: 73 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1+
use std::collections::HashMap;
12
use std::hash::{DefaultHasher, Hash, Hasher};
2-
use std::sync::Arc;
3+
use std::sync::{Arc, Mutex};
34

45
use anyhow::*;
56
use async_trait::async_trait;
67
use base64::Engine;
78
use base64::engine::general_purpose::STANDARD_NO_PAD as BASE64;
89
use deadpool_postgres::{Config, ManagerConfig, Pool, PoolConfig, RecyclingMethod, Runtime};
910
use futures_util::future::poll_fn;
10-
use moka::future::Cache;
1111
use tokio_postgres::{AsyncMessage, NoTls};
1212
use tracing::Instrument;
1313

@@ -18,6 +18,15 @@ use crate::pubsub::DriverOutput;
1818
struct Subscription {
1919
// Channel to send messages to this subscription
2020
tx: tokio::sync::broadcast::Sender<Vec<u8>>,
21+
// Cancellation token shared by all subscribers of this subject
22+
token: tokio_util::sync::CancellationToken,
23+
}
24+
25+
impl Subscription {
26+
fn new(tx: tokio::sync::broadcast::Sender<Vec<u8>>) -> Self {
27+
let token = tokio_util::sync::CancellationToken::new();
28+
Self { tx, token }
29+
}
2130
}
2231

2332
/// > In the default configuration it must be shorter than 8000 bytes
@@ -40,7 +49,7 @@ pub const POSTGRES_MAX_MESSAGE_SIZE: usize =
4049
pub struct PostgresDriver {
4150
pool: Arc<Pool>,
4251
client: Arc<tokio_postgres::Client>,
43-
subscriptions: Cache<String, Subscription>,
52+
subscriptions: Arc<Mutex<HashMap<String, Subscription>>>,
4453
}
4554

4655
impl PostgresDriver {
@@ -65,8 +74,8 @@ impl PostgresDriver {
6574
.context("failed to create postgres pool")?;
6675
tracing::debug!("postgres pool created successfully");
6776

68-
let subscriptions: Cache<String, Subscription> =
69-
Cache::builder().initial_capacity(5).build();
77+
let subscriptions: Arc<Mutex<HashMap<String, Subscription>>> =
78+
Arc::new(Mutex::new(HashMap::new()));
7079
let subscriptions2 = subscriptions.clone();
7180

7281
let (client, mut conn) = tokio_postgres::connect(&conn_str, tokio_postgres::NoTls).await?;
@@ -75,7 +84,9 @@ impl PostgresDriver {
7584
loop {
7685
match poll_fn(|cx| conn.poll_message(cx)).await {
7786
Some(std::result::Result::Ok(AsyncMessage::Notification(note))) => {
78-
if let Some(sub) = subscriptions2.get(note.channel()).await {
87+
if let Some(sub) =
88+
subscriptions2.lock().unwrap().get(note.channel()).cloned()
89+
{
7990
let bytes = match BASE64.decode(note.payload()) {
8091
std::result::Result::Ok(b) => b,
8192
std::result::Result::Err(err) => {
@@ -121,7 +132,7 @@ impl PostgresDriver {
121132
#[async_trait]
122133
impl PubSubDriver for PostgresDriver {
123134
async fn subscribe(&self, subject: &str) -> Result<SubscriberDriverHandle> {
124-
// TODO: To match NATS implementation, LIST must be pipelined (i.e. wait for the command
135+
// TODO: To match NATS implementation, LISTEN must be pipelined (i.e. wait for the command
125136
// to reach the server, but not wait for it to respond). However, this has to ensure that
126137
// NOTIFY & LISTEN are called on the same connection (not diff connections in a pool) or
127138
// else there will be race conditions where messages might be published before
@@ -135,33 +146,57 @@ impl PubSubDriver for PostgresDriver {
135146
let hashed = self.hash_subject(subject);
136147

137148
// Check if we already have a subscription for this channel
138-
let rx = if let Some(existing_sub) = self.subscriptions.get(&hashed).await {
139-
// Reuse the existing broadcast channel
140-
existing_sub.tx.subscribe()
141-
} else {
142-
// Create a new broadcast channel for this subject
143-
let (tx, rx) = tokio::sync::broadcast::channel(1024);
144-
let subscription = Subscription { tx: tx.clone() };
145-
146-
// Register subscription
147-
self.subscriptions
148-
.insert(hashed.clone(), subscription)
149-
.await;
150-
151-
// Execute LISTEN command on the async client (for receiving notifications)
152-
// This only needs to be done once per channel
153-
let span = tracing::trace_span!("pg_listen");
154-
self.client
155-
.execute(&format!("LISTEN \"{hashed}\""), &[])
156-
.instrument(span)
157-
.await?;
158-
159-
rx
160-
};
149+
let (rx, drop_guard) =
150+
if let Some(existing_sub) = self.subscriptions.lock().unwrap().get(&hashed).cloned() {
151+
// Reuse the existing broadcast channel
152+
let rx = existing_sub.tx.subscribe();
153+
let drop_guard = existing_sub.token.clone().drop_guard();
154+
(rx, drop_guard)
155+
} else {
156+
// Create a new broadcast channel for this subject
157+
let (tx, rx) = tokio::sync::broadcast::channel(1024);
158+
let subscription = Subscription::new(tx.clone());
159+
160+
// Register subscription
161+
self.subscriptions
162+
.lock()
163+
.unwrap()
164+
.insert(hashed.clone(), subscription.clone());
165+
166+
// Execute LISTEN command on the async client (for receiving notifications)
167+
// This only needs to be done once per channel
168+
let span = tracing::trace_span!("pg_listen");
169+
self.client
170+
.execute(&format!("LISTEN \"{hashed}\""), &[])
171+
.instrument(span)
172+
.await?;
173+
174+
// Spawn a single cleanup task for this subscription waiting on its token
175+
let driver = self.clone();
176+
let hashed_clone = hashed.clone();
177+
let tx_clone = tx.clone();
178+
let token_clone = subscription.token.clone();
179+
tokio::spawn(async move {
180+
token_clone.cancelled().await;
181+
if tx_clone.receiver_count() == 0 {
182+
let sql = format!("UNLISTEN \"{}\"", hashed_clone);
183+
if let Err(err) = driver.client.execute(sql.as_str(), &[]).await {
184+
tracing::warn!(?err, %hashed_clone, "failed to UNLISTEN channel");
185+
} else {
186+
tracing::trace!(%hashed_clone, "unlistened channel");
187+
}
188+
driver.subscriptions.lock().unwrap().remove(&hashed_clone);
189+
}
190+
});
191+
192+
let drop_guard = subscription.token.clone().drop_guard();
193+
(rx, drop_guard)
194+
};
161195

162196
Ok(Box::new(PostgresSubscriber {
163197
subject: subject.to_string(),
164-
rx,
198+
rx: Some(rx),
199+
_drop_guard: drop_guard,
165200
}))
166201
}
167202

@@ -191,13 +226,18 @@ impl PubSubDriver for PostgresDriver {
191226

192227
pub struct PostgresSubscriber {
193228
subject: String,
194-
rx: tokio::sync::broadcast::Receiver<Vec<u8>>,
229+
rx: Option<tokio::sync::broadcast::Receiver<Vec<u8>>>,
230+
_drop_guard: tokio_util::sync::DropGuard,
195231
}
196232

197233
#[async_trait]
198234
impl SubscriberDriver for PostgresSubscriber {
199235
async fn next(&mut self) -> Result<DriverOutput> {
200-
match self.rx.recv().await {
236+
let rx = match self.rx.as_mut() {
237+
Some(rx) => rx,
238+
None => return Ok(DriverOutput::Unsubscribed),
239+
};
240+
match rx.recv().await {
201241
std::result::Result::Ok(payload) => Ok(DriverOutput::Message {
202242
subject: self.subject.clone(),
203243
payload,

packages/core/guard/server/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ hyper-tungstenite.workspace = true
2020
tower.workspace = true
2121
udb-util.workspace = true
2222
universaldb.workspace = true
23+
universalpubsub.workspace = true
2324
futures.workspace = true
2425
# TODO: Make this use workspace version
2526
hyper = "1.6.0"

packages/core/guard/server/src/lib.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub mod cache;
55
pub mod errors;
66
pub mod middleware;
77
pub mod routing;
8+
pub mod shared_state;
89
pub mod tls;
910

1011
#[tracing::instrument(skip_all)]
@@ -26,8 +27,12 @@ pub async fn start(config: rivet_config::Config, pools: rivet_pools::Pools) -> R
2627
tracing::warn!("crypto provider already installed in this process");
2728
}
2829

30+
// Share shared context
31+
let shared_state = shared_state::SharedState::new(ctx.ups()?);
32+
shared_state.start().await?;
33+
2934
// Create handlers
30-
let routing_fn = routing::create_routing_function(ctx.clone());
35+
let routing_fn = routing::create_routing_function(ctx.clone(), shared_state.clone());
3136
let cache_key_fn = cache::create_cache_key_function(ctx.clone());
3237
let middleware_fn = middleware::create_middleware_function(ctx.clone());
3338
let cert_resolver = tls::create_cert_resolver(&ctx).await?;

packages/core/guard/server/src/routing/mod.rs

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use gas::prelude::*;
55
use hyper::header::HeaderName;
66
use rivet_guard_core::RoutingFn;
77

8-
use crate::errors;
8+
use crate::{errors, shared_state::SharedState};
99

1010
mod api_peer;
1111
mod api_public;
@@ -17,13 +17,14 @@ pub(crate) const X_RIVET_TARGET: HeaderName = HeaderName::from_static("x-rivet-t
1717

1818
/// Creates the main routing function that handles all incoming requests
1919
#[tracing::instrument(skip_all)]
20-
pub fn create_routing_function(ctx: StandaloneCtx) -> RoutingFn {
20+
pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) -> RoutingFn {
2121
Arc::new(
2222
move |hostname: &str,
2323
path: &str,
2424
port_type: rivet_guard_core::proxy_service::PortType,
2525
headers: &hyper::HeaderMap| {
2626
let ctx = ctx.clone();
27+
let shared_state = shared_state.clone();
2728

2829
Box::pin(
2930
async move {
@@ -41,9 +42,15 @@ pub fn create_routing_function(ctx: StandaloneCtx) -> RoutingFn {
4142
return Ok(routing_output);
4243
}
4344

44-
if let Some(routing_output) =
45-
pegboard_gateway::route_request(&ctx, target, host, path, headers)
46-
.await?
45+
if let Some(routing_output) = pegboard_gateway::route_request(
46+
&ctx,
47+
&shared_state,
48+
target,
49+
host,
50+
path,
51+
headers,
52+
)
53+
.await?
4754
{
4855
return Ok(routing_output);
4956
}

packages/core/guard/server/src/routing/pegboard_gateway.rs

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ use hyper::header::HeaderName;
66
use rivet_guard_core::proxy_service::{RouteConfig, RouteTarget, RoutingOutput, RoutingTimeout};
77
use udb_util::{SERIALIZABLE, TxnExt};
88

9-
use crate::errors;
9+
use crate::{errors, shared_state::SharedState};
1010

1111
const ACTOR_READY_TIMEOUT: Duration = Duration::from_secs(10);
1212
pub const X_RIVET_ACTOR: HeaderName = HeaderName::from_static("x-rivet-actor");
@@ -16,6 +16,7 @@ pub const X_RIVET_PORT: HeaderName = HeaderName::from_static("x-rivet-port");
1616
#[tracing::instrument(skip_all)]
1717
pub async fn route_request(
1818
ctx: &StandaloneCtx,
19+
shared_state: &SharedState,
1920
target: &str,
2021
_host: &str,
2122
path: &str,
@@ -73,7 +74,7 @@ pub async fn route_request(
7374
let port_name = port_name.to_str()?;
7475

7576
// Lookup actor
76-
find_actor(ctx, actor_id, port_name, path).await
77+
find_actor(ctx, shared_state, actor_id, port_name, path).await
7778
}
7879

7980
struct FoundActor {
@@ -86,6 +87,7 @@ struct FoundActor {
8687
#[tracing::instrument(skip_all, fields(%actor_id, %port_name, %path))]
8788
async fn find_actor(
8889
ctx: &StandaloneCtx,
90+
shared_state: &SharedState,
8991
actor_id: Id,
9092
port_name: &str,
9193
path: &str,
@@ -158,10 +160,10 @@ async fn find_actor(
158160
actor_ids: vec![actor_id],
159161
});
160162
let res = tokio::time::timeout(Duration::from_secs(5), get_runner_fut).await??;
161-
let runner_info = res.actors.into_iter().next().filter(|x| x.is_connectable);
163+
let actor = res.actors.into_iter().next().filter(|x| x.is_connectable);
162164

163-
let runner_id = if let Some(runner_info) = runner_info {
164-
runner_info.runner_id
165+
let runner_id = if let Some(actor) = actor {
166+
actor.runner_id
165167
} else {
166168
tracing::info!(?actor_id, "waiting for actor to become ready");
167169

@@ -185,11 +187,23 @@ async fn find_actor(
185187

186188
tracing::debug!(?actor_id, ?runner_id, "actor ready");
187189

190+
// Get runner key from runner_id
191+
let runner_key = ctx
192+
.udb()?
193+
.run(|tx, _mc| async move {
194+
let txs = tx.subspace(pegboard::keys::subspace());
195+
let key_key = pegboard::keys::runner::KeyKey::new(runner_id);
196+
txs.read_opt(&key_key, SERIALIZABLE).await
197+
})
198+
.await?
199+
.context("runner key not found")?;
200+
188201
// Return pegboard-gateway instance
189202
let gateway = pegboard_gateway::PegboardGateway::new(
190203
ctx.clone(),
204+
shared_state.pegboard_gateway.clone(),
191205
actor_id,
192-
runner_id,
206+
runner_key,
193207
port_name.to_string(),
194208
);
195209
Ok(Some(RoutingOutput::CustomServe(std::sync::Arc::new(

packages/core/guard/server/src/routing/pegboard_tunnel.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,10 @@ pub async fn route_request(
1212
_host: &str,
1313
_path: &str,
1414
) -> Result<Option<RoutingOutput>> {
15-
// Check target
1615
if target != "tunnel" {
1716
return Ok(None);
1817
}
1918

20-
// Create pegboard-tunnel service instance
21-
let tunnel = pegboard_tunnel::PegboardTunnelCustomServe::new(ctx.clone()).await?;
22-
19+
let tunnel = pegboard_tunnel::PegboardTunnelCustomServe::new(ctx.clone());
2320
Ok(Some(RoutingOutput::CustomServe(Arc::new(tunnel))))
2421
}

0 commit comments

Comments
 (0)