Skip to content

Commit fc5597f

Browse files
authored
server: implement key expiration (#13)
1 parent 49b848b commit fc5597f

File tree

12 files changed

+348
-220
lines changed

12 files changed

+348
-220
lines changed

Cargo.lock

Lines changed: 116 additions & 128 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

src/bin/server.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ pub async fn main() -> Result<()> {
77
// enable logging
88
// see https://docs.rs/tracing for more info
99
tracing_subscriber::fmt::try_init().map_err(|e| anyhow!("{:?}", e))?;
10-
10+
1111
let cli = Cli::parse();
1212
let port = cli.port.unwrap_or(DEFAULT_PORT.to_string());
1313
server::run(&port).await

src/cmd/get.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::{Connection, Frame, Kv, Parse, ParseError};
1+
use crate::{Connection, Frame, Db, Parse, ParseError};
22

33
use std::io;
44
use tracing::{debug, instrument};
@@ -27,8 +27,8 @@ impl Get {
2727
}
2828

2929
#[instrument]
30-
pub(crate) async fn apply(self, kv: &Kv, dst: &mut Connection) -> io::Result<()> {
31-
let response = if let Some(value) = kv.get(&self.key) {
30+
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> io::Result<()> {
31+
let response = if let Some(value) = db.get(&self.key) {
3232
Frame::Bulk(value)
3333
} else {
3434
Frame::Null

src/cmd/mod.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub use set::Set;
1010
mod subscribe;
1111
pub use subscribe::{Subscribe, Unsubscribe};
1212

13-
use crate::{Connection, Frame, Kv, Parse, ParseError, Shutdown};
13+
use crate::{Connection, Frame, Db, Parse, ParseError, Shutdown};
1414

1515
use std::io;
1616

@@ -44,17 +44,17 @@ impl Command {
4444

4545
pub(crate) async fn apply(
4646
self,
47-
kv: &Kv,
47+
db: &Db,
4848
dst: &mut Connection,
4949
shutdown: &mut Shutdown,
5050
) -> io::Result<()> {
5151
use Command::*;
5252

5353
match self {
54-
Get(cmd) => cmd.apply(kv, dst).await,
55-
Publish(cmd) => cmd.apply(kv, dst).await,
56-
Set(cmd) => cmd.apply(kv, dst).await,
57-
Subscribe(cmd) => cmd.apply(kv, dst, shutdown).await,
54+
Get(cmd) => cmd.apply(db, dst).await,
55+
Publish(cmd) => cmd.apply(db, dst).await,
56+
Set(cmd) => cmd.apply(db, dst).await,
57+
Subscribe(cmd) => cmd.apply(db, dst, shutdown).await,
5858
// `Unsubscribe` cannot be applied. It may only be received from the
5959
// context of a `Subscribe` command.
6060
Unsubscribe(_) => unimplemented!(),

src/cmd/publish.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use crate::{Connection, Frame, Kv, Parse, ParseError};
1+
use crate::{Connection, Frame, Db, Parse, ParseError};
22

33
use bytes::Bytes;
44
use std::io;
@@ -17,9 +17,9 @@ impl Publish {
1717
Ok(Publish { channel, message })
1818
}
1919

20-
pub(crate) async fn apply(self, kv: &Kv, dst: &mut Connection) -> io::Result<()> {
20+
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> io::Result<()> {
2121
// Set the value
22-
let num_subscribers = kv.publish(&self.channel, self.message);
22+
let num_subscribers = db.publish(&self.channel, self.message);
2323

2424
let response = Frame::Integer(num_subscribers as u64);
2525
dst.write_frame(&response).await

src/cmd/set.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::cmd::{Parse, ParseError};
2-
use crate::{Connection, Frame, Kv};
2+
use crate::{Connection, Frame, Db};
33

44
use bytes::Bytes;
55
use std::io;
@@ -42,9 +42,9 @@ impl Set {
4242
}
4343

4444
#[instrument]
45-
pub(crate) async fn apply(self, kv: &Kv, dst: &mut Connection) -> io::Result<()> {
45+
pub(crate) async fn apply(self, db: &Db, dst: &mut Connection) -> io::Result<()> {
4646
// Set the value
47-
kv.set(self.key, self.value, self.expire);
47+
db.set(self.key, self.value, self.expire);
4848

4949
let response = Frame::Simple("OK".to_string());
5050
debug!(?response);

src/cmd/subscribe.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use crate::cmd::{Parse, ParseError};
2-
use crate::{Command, Connection, Frame, Kv, Shutdown};
2+
use crate::{Command, Connection, Frame, Db, Shutdown};
33

44
use bytes::Bytes;
55
use std::io;
@@ -45,7 +45,7 @@ impl Subscribe {
4545
/// [here]: https://redis.io/topics/pubsub
4646
pub(crate) async fn apply(
4747
mut self,
48-
kv: &Kv,
48+
db: &Db,
4949
dst: &mut Connection,
5050
shutdown: &mut Shutdown,
5151
) -> io::Result<()> {
@@ -71,7 +71,7 @@ impl Subscribe {
7171
response.push_bulk(Bytes::copy_from_slice(channel.as_bytes()));
7272

7373
// Subscribe to channel
74-
let rx = kv.subscribe(channel.clone());
74+
let rx = db.subscribe(channel.clone());
7575

7676
// Track subscription in this client's subscription set.
7777
subscriptions.insert(channel, rx);

src/db.rs

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
use tokio::sync::{broadcast, Notify};
2+
use tokio::time::{self, Duration, Instant};
3+
4+
use bytes::Bytes;
5+
use std::collections::{BTreeMap, HashMap};
6+
use std::sync::{Arc, Mutex};
7+
8+
#[derive(Debug, Clone)]
9+
pub(crate) struct Db {
10+
shared: Arc<Shared>,
11+
}
12+
13+
#[derive(Debug)]
14+
struct Shared {
15+
state: Mutex<State>,
16+
17+
/// Notifies the task handling entry expiration
18+
expire_task: Notify,
19+
}
20+
21+
#[derive(Debug)]
22+
struct State {
23+
/// The key-value data
24+
entries: HashMap<String, Entry>,
25+
26+
/// The pub/sub key-space
27+
pub_sub: HashMap<String, broadcast::Sender<Bytes>>,
28+
29+
/// Tracks key TTLs.
30+
expirations: BTreeMap<(Instant, u64), String>,
31+
32+
/// Identifier to use for the next expiration.
33+
next_id: u64,
34+
}
35+
36+
/// Entry in the key-value store
37+
#[derive(Debug)]
38+
struct Entry {
39+
/// Uniquely identifies this entry.
40+
id: u64,
41+
42+
/// Stored data
43+
data: Bytes,
44+
45+
/// Instant at which the entry expires and should be removed from the
46+
/// database.
47+
expires_at: Option<Instant>,
48+
}
49+
50+
impl Db {
51+
pub(crate) fn new() -> Db {
52+
let shared = Arc::new(Shared {
53+
state: Mutex::new(State {
54+
entries: HashMap::new(),
55+
pub_sub: HashMap::new(),
56+
expirations: BTreeMap::new(),
57+
next_id: 0,
58+
}),
59+
expire_task: Notify::new(),
60+
});
61+
62+
// Start the background task.
63+
tokio::spawn(purge_expired_tasks(shared.clone()));
64+
65+
Db { shared }
66+
}
67+
68+
pub(crate) fn get(&self, key: &str) -> Option<Bytes> {
69+
let state = self.shared.state.lock().unwrap();
70+
state.entries.get(key).map(|entry| entry.data.clone())
71+
}
72+
73+
pub(crate) fn set(&self, key: String, value: Bytes, expire: Option<Duration>) {
74+
let mut state = self.shared.state.lock().unwrap();
75+
76+
// Get and increment the next insertion ID.
77+
let id = state.next_id;
78+
state.next_id += 1;
79+
80+
// By default, no notification is needed
81+
let mut notify = false;
82+
83+
let expires_at = expire.map(|duration| {
84+
let when = Instant::now() + duration;
85+
86+
// Only notify the worker task if the newly inserted expiration is the
87+
// **next** key to evict. In this case, the worker needs to be woken up
88+
// to update its state.
89+
notify = state.next_expiration()
90+
.map(|expiration| expiration > when)
91+
.unwrap_or(true);
92+
93+
state.expirations.insert((when, id), key.clone());
94+
when
95+
});
96+
97+
// Insert the entry.
98+
let prev = state.entries.insert(key, Entry {
99+
id,
100+
data: value,
101+
expires_at,
102+
});
103+
104+
if let Some(prev) = prev {
105+
if let Some(when) = prev.expires_at {
106+
// clear expiration
107+
state.expirations.remove(&(when, prev.id));
108+
}
109+
}
110+
111+
drop(state);
112+
113+
if notify {
114+
self.shared.expire_task.notify();
115+
}
116+
}
117+
118+
pub(crate) fn subscribe(&self, key: String) -> broadcast::Receiver<Bytes> {
119+
use std::collections::hash_map::Entry;
120+
121+
let mut state = self.shared.state.lock().unwrap();
122+
123+
match state.pub_sub.entry(key) {
124+
Entry::Occupied(e) => e.get().subscribe(),
125+
Entry::Vacant(e) => {
126+
let (tx, rx) = broadcast::channel(1028);
127+
e.insert(tx);
128+
rx
129+
}
130+
}
131+
}
132+
133+
/// Publish a message to the channel. Returns the number of subscribers
134+
/// listening on the channel.
135+
pub(crate) fn publish(&self, key: &str, value: Bytes) -> usize {
136+
let state = self.shared.state.lock().unwrap();
137+
138+
state
139+
.pub_sub
140+
.get(key)
141+
// On a successful message send on the broadcast channel, the number
142+
// of subscribers is returned. An error indicates there are no
143+
// receivers, in which case, `0` should be returned.
144+
.map(|tx| tx.send(value).unwrap_or(0))
145+
// If there is no entry for the channel key, then there are no
146+
// subscribers. In this case, return `0`.
147+
.unwrap_or(0)
148+
}
149+
}
150+
151+
impl Shared {
152+
fn purge_expired_keys(&self) -> Option<Instant> {
153+
let mut state = self.state.lock().unwrap();
154+
155+
// This is needed to make the borrow checker happy. In short, `lock()`
156+
// returns a `MutexGuard` and not a `&mut State`. The borrow checker is
157+
// not able to see "through" the mutex guard and determine that it is
158+
// safe to access both `state.expirations` and `state.entries` mutably,
159+
// so we get a "real" mutable reference to `State` outside of the loop.
160+
let state = &mut *state;
161+
162+
// Find all keys scheduled to expire **before** now.
163+
let now = Instant::now();
164+
165+
while let Some((&(when, id), key)) = state.expirations.iter().next() {
166+
if when > now {
167+
// Done purging, `when` is the instant at which the next key
168+
// expires. The worker task will wait until this instant.
169+
return Some(when);
170+
}
171+
172+
// The key expired, remove it
173+
state.entries.remove(key);
174+
state.expirations.remove(&(when, id));
175+
}
176+
177+
None
178+
}
179+
}
180+
181+
impl State {
182+
fn next_expiration(&self) -> Option<Instant> {
183+
self.expirations.keys().next().map(|expiration| expiration.0)
184+
}
185+
}
186+
187+
async fn purge_expired_tasks(shared: Arc<Shared>) {
188+
loop {
189+
// Purge all keys that are expired. The function returns the instant at
190+
// which the **next** key will expire. The worker should wait until the
191+
// instant has passed then purge again.
192+
if let Some(when) = shared.purge_expired_keys() {
193+
tokio::select! {
194+
_ = time::delay_until(when) => {}
195+
_ = shared.expire_task.notified() => {}
196+
}
197+
} else {
198+
shared.expire_task.notified().await;
199+
}
200+
}
201+
}

src/kv.rs

Lines changed: 0 additions & 65 deletions
This file was deleted.

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,8 @@ use conn::Connection;
1111
mod frame;
1212
use frame::Frame;
1313

14-
mod kv;
15-
use kv::Kv;
14+
mod db;
15+
use db::Db;
1616

1717
mod parse;
1818
use parse::{Parse, ParseError};

0 commit comments

Comments
 (0)