Skip to content

Commit 5303da4

Browse files
committed
libsql: add offline sync unit tests
1 parent 6934c65 commit 5303da4

File tree

2 files changed

+328
-58
lines changed

2 files changed

+328
-58
lines changed

libsql/src/sync.rs

Lines changed: 3 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ use hyper::Body;
99
use tokio::io::AsyncWriteExt as _;
1010
use uuid::Uuid;
1111

12+
#[cfg(test)]
13+
mod test;
14+
1215
const METADATA_VERSION: u32 = 0;
1316

1417
const DEFAULT_MAX_RETRIES: usize = 5;
@@ -310,61 +313,3 @@ async fn atomic_write<P: AsRef<Path>>(path: P, data: &[u8]) -> Result<()> {
310313

311314
Ok(())
312315
}
313-
314-
// TODO(lucio): for the tests to work we need proper error handling which
315-
// will be done in follow up.
316-
#[cfg(test)]
317-
mod tests {
318-
use super::*;
319-
320-
#[test]
321-
#[ignore]
322-
fn test_hash_verification() {
323-
let mut metadata = MetadataJson {
324-
hash: 0,
325-
version: 1,
326-
durable_frame_num: 100,
327-
generation: 5,
328-
};
329-
330-
assert!(metadata.verify_hash().is_err());
331-
332-
metadata.set_hash();
333-
334-
assert!(metadata.verify_hash().is_ok());
335-
}
336-
337-
#[test]
338-
#[ignore]
339-
fn test_hash_tampering() {
340-
let mut metadata = MetadataJson {
341-
hash: 0,
342-
version: 1,
343-
durable_frame_num: 100,
344-
generation: 5,
345-
};
346-
347-
// Create metadata with hash
348-
metadata.set_hash();
349-
350-
// Tamper with a field
351-
metadata.version = 2;
352-
353-
// Verify should fail
354-
assert!(metadata.verify_hash().is_err());
355-
356-
metadata.version = 1;
357-
metadata.generation = 42;
358-
359-
assert!(metadata.verify_hash().is_err());
360-
361-
metadata.generation = 5;
362-
metadata.durable_frame_num = 42;
363-
364-
assert!(metadata.verify_hash().is_err());
365-
366-
metadata.durable_frame_num = 100;
367-
368-
assert!(metadata.verify_hash().is_ok());
369-
}
370-
}

libsql/src/sync/test.rs

Lines changed: 325 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,325 @@
1+
use super::*;
2+
use crate::util::Socket;
3+
use std::pin::Pin;
4+
use std::sync::atomic::{AtomicU32, Ordering};
5+
use std::sync::Arc;
6+
use std::task::{Context, Poll};
7+
use tempfile::tempdir;
8+
use tokio::io::{duplex, AsyncRead, AsyncWrite, DuplexStream};
9+
use tower::Service;
10+
11+
#[tokio::test]
12+
async fn test_sync_context_push_frame() {
13+
let server = MockServer::start();
14+
let temp_dir = tempdir().unwrap();
15+
let db_path = temp_dir.path().join("test.db");
16+
17+
let sync_ctx = SyncContext::new(
18+
server.connector(),
19+
db_path.to_str().unwrap().to_string(),
20+
server.url(),
21+
None,
22+
)
23+
.await
24+
.unwrap();
25+
26+
let frame = Bytes::from("test frame data");
27+
let mut sync_ctx = sync_ctx;
28+
29+
// Push a frame and verify the response
30+
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
31+
assert_eq!(durable_frame, 1); // First frame should return max_frame_no = 1
32+
33+
// Verify internal state was updated
34+
assert_eq!(sync_ctx.durable_frame_num(), 1);
35+
assert_eq!(sync_ctx.generation(), 1);
36+
assert_eq!(server.frame_count(), 1);
37+
}
38+
39+
#[tokio::test]
40+
async fn test_sync_context_with_auth() {
41+
let server = MockServer::start();
42+
let temp_dir = tempdir().unwrap();
43+
let db_path = temp_dir.path().join("test.db");
44+
45+
let sync_ctx = SyncContext::new(
46+
server.connector(),
47+
db_path.to_str().unwrap().to_string(),
48+
server.url(),
49+
Some("test_token".to_string()),
50+
)
51+
.await
52+
.unwrap();
53+
54+
let frame = Bytes::from("test frame with auth");
55+
let mut sync_ctx = sync_ctx;
56+
57+
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
58+
assert_eq!(durable_frame, 1);
59+
assert_eq!(server.frame_count(), 1);
60+
}
61+
62+
#[tokio::test]
63+
async fn test_sync_context_multiple_frames() {
64+
let server = MockServer::start();
65+
let temp_dir = tempdir().unwrap();
66+
let db_path = temp_dir.path().join("test.db");
67+
68+
let sync_ctx = SyncContext::new(
69+
server.connector(),
70+
db_path.to_str().unwrap().to_string(),
71+
server.url(),
72+
None,
73+
)
74+
.await
75+
.unwrap();
76+
77+
let mut sync_ctx = sync_ctx;
78+
79+
// Push multiple frames and verify incrementing frame numbers
80+
for i in 0..3 {
81+
let frame = Bytes::from(format!("frame data {}", i));
82+
let durable_frame = sync_ctx.push_one_frame(frame, 1, i).await.unwrap();
83+
assert_eq!(durable_frame, i + 1);
84+
assert_eq!(sync_ctx.durable_frame_num(), i + 1);
85+
assert_eq!(server.frame_count(), i + 1);
86+
}
87+
}
88+
89+
#[tokio::test]
90+
async fn test_sync_context_corrupted_metadata() {
91+
let server = MockServer::start();
92+
let temp_dir = tempdir().unwrap();
93+
let db_path = temp_dir.path().join("test.db");
94+
95+
// Create initial sync context and push a frame
96+
let sync_ctx = SyncContext::new(
97+
server.connector(),
98+
db_path.to_str().unwrap().to_string(),
99+
server.url(),
100+
None,
101+
)
102+
.await
103+
.unwrap();
104+
105+
let mut sync_ctx = sync_ctx;
106+
let frame = Bytes::from("test frame data");
107+
let durable_frame = sync_ctx.push_one_frame(frame, 1, 0).await.unwrap();
108+
assert_eq!(durable_frame, 1);
109+
assert_eq!(server.frame_count(), 1);
110+
111+
// Update metadata path to use -info instead of .meta
112+
let metadata_path = format!("{}-info", db_path.to_str().unwrap());
113+
std::fs::write(&metadata_path, b"invalid json data").unwrap();
114+
115+
// Create new sync context with corrupted metadata
116+
let sync_ctx = SyncContext::new(
117+
server.connector(),
118+
db_path.to_str().unwrap().to_string(),
119+
server.url(),
120+
None,
121+
)
122+
.await
123+
.unwrap();
124+
125+
// Verify that the context was reset to default values
126+
assert_eq!(sync_ctx.durable_frame_num(), 0);
127+
assert_eq!(sync_ctx.generation(), 1);
128+
}
129+
130+
#[test]
131+
fn test_hash_verification() {
132+
let mut metadata = MetadataJson {
133+
hash: 0,
134+
version: 1,
135+
durable_frame_num: 100,
136+
generation: 5,
137+
};
138+
139+
assert!(metadata.verify_hash().is_err());
140+
141+
metadata.set_hash();
142+
143+
assert!(metadata.verify_hash().is_ok());
144+
}
145+
146+
#[test]
147+
fn test_hash_tampering() {
148+
let mut metadata = MetadataJson {
149+
hash: 0,
150+
version: 1,
151+
durable_frame_num: 100,
152+
generation: 5,
153+
};
154+
155+
// Create metadata with hash
156+
metadata.set_hash();
157+
158+
// Tamper with a field
159+
metadata.version = 2;
160+
161+
// Verify should fail
162+
assert!(metadata.verify_hash().is_err());
163+
164+
metadata.version = 1;
165+
metadata.generation = 42;
166+
167+
assert!(metadata.verify_hash().is_err());
168+
169+
metadata.generation = 5;
170+
metadata.durable_frame_num = 42;
171+
172+
assert!(metadata.verify_hash().is_err());
173+
174+
metadata.durable_frame_num = 100;
175+
176+
assert!(metadata.verify_hash().is_ok());
177+
}
178+
179+
// Mock connector service that implements tower::Service
180+
#[derive(Clone)]
181+
struct MockConnector {
182+
tx: tokio::sync::mpsc::Sender<DuplexStream>,
183+
}
184+
185+
impl Service<http::Uri> for MockConnector {
186+
type Response = Box<dyn Socket>;
187+
type Error = Box<dyn std::error::Error + Send + Sync>;
188+
type Future = Pin<
189+
Box<
190+
dyn std::future::Future<Output = std::result::Result<Self::Response, Self::Error>>
191+
+ Send,
192+
>,
193+
>;
194+
195+
fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<std::result::Result<(), Self::Error>> {
196+
Poll::Ready(Ok(()))
197+
}
198+
199+
fn call(&mut self, _: http::Uri) -> Self::Future {
200+
let (stream, server_stream) = duplex(1024);
201+
let _ = self.tx.try_send(server_stream);
202+
let conn = MockConnection { stream };
203+
Box::pin(std::future::ready(Ok(Box::new(conn) as Box<dyn Socket>)))
204+
}
205+
}
206+
207+
struct MockServer {
208+
url: String,
209+
frame_count: Arc<AtomicU32>,
210+
connector: ConnectorService,
211+
}
212+
213+
impl MockServer {
214+
fn start() -> Self {
215+
let frame_count = Arc::new(AtomicU32::new(0));
216+
217+
// Create the mock connector with Some(client_stream)
218+
let (tx, mut rx) = tokio::sync::mpsc::channel(1);
219+
let mock_connector = MockConnector { tx };
220+
let connector = ConnectorService::new(mock_connector);
221+
222+
let server = Self {
223+
url: "http://mock.server".to_string(),
224+
frame_count: frame_count.clone(),
225+
connector,
226+
};
227+
228+
// Spawn the server handler
229+
let frame_count_clone = frame_count.clone();
230+
231+
tokio::spawn(async move {
232+
while let Some(server_stream) = rx.recv().await {
233+
let frame_count_clone = frame_count_clone.clone();
234+
235+
tokio::spawn(async move {
236+
use hyper::server::conn::Http;
237+
use hyper::service::service_fn;
238+
239+
let frame_count_clone = frame_count_clone.clone();
240+
let service = service_fn(move |req: http::Request<Body>| {
241+
let frame_count = frame_count_clone.clone();
242+
async move {
243+
let current_count = frame_count.fetch_add(1, Ordering::SeqCst);
244+
245+
if req.uri().path().contains("/sync/") {
246+
let response = serde_json::json!({
247+
"max_frame_no": current_count + 1
248+
});
249+
250+
Ok::<_, hyper::Error>(
251+
http::Response::builder()
252+
.status(200)
253+
.body(Body::from(response.to_string()))
254+
.unwrap(),
255+
)
256+
} else {
257+
Ok(http::Response::builder()
258+
.status(404)
259+
.body(Body::empty())
260+
.unwrap())
261+
}
262+
}
263+
});
264+
265+
if let Err(e) = Http::new().serve_connection(server_stream, service).await {
266+
eprintln!("Error serving connection: {}", e);
267+
}
268+
});
269+
}
270+
});
271+
272+
server
273+
}
274+
275+
fn connector(&self) -> ConnectorService {
276+
self.connector.clone()
277+
}
278+
279+
fn url(&self) -> String {
280+
self.url.clone()
281+
}
282+
283+
fn frame_count(&self) -> u32 {
284+
self.frame_count.load(Ordering::SeqCst)
285+
}
286+
}
287+
288+
// Mock connection that implements the Socket trait
289+
struct MockConnection {
290+
stream: DuplexStream,
291+
}
292+
293+
impl AsyncRead for MockConnection {
294+
fn poll_read(
295+
mut self: Pin<&mut Self>,
296+
cx: &mut Context<'_>,
297+
buf: &mut tokio::io::ReadBuf<'_>,
298+
) -> Poll<std::io::Result<()>> {
299+
Pin::new(&mut self.stream).poll_read(cx, buf)
300+
}
301+
}
302+
303+
impl AsyncWrite for MockConnection {
304+
fn poll_write(
305+
mut self: Pin<&mut Self>,
306+
cx: &mut Context<'_>,
307+
buf: &[u8],
308+
) -> Poll<std::io::Result<usize>> {
309+
Pin::new(&mut self.stream).poll_write(cx, buf)
310+
}
311+
312+
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
313+
Pin::new(&mut self.stream).poll_flush(cx)
314+
}
315+
316+
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
317+
Pin::new(&mut self.stream).poll_shutdown(cx)
318+
}
319+
}
320+
321+
impl hyper::client::connect::Connection for MockConnection {
322+
fn connected(&self) -> hyper::client::connect::Connected {
323+
hyper::client::connect::Connected::new()
324+
}
325+
}

0 commit comments

Comments
 (0)