Skip to content

Commit 864a029

Browse files
authored
Fix wal api ffi (#2128)
Basically, right now, embedded replica will fail if WAL reached size of more than `4120` frame (familiar number, huh?). The issue is that in the Rust API we pass `frame.len()` as a `frame_no` and 0 as a `frame.len()` 🤪 The failure happens because `wal_insert` C code will check frame for conflict in case when `if (iFrame <= mxFrame)` and then try to allocate `nBuf - 24` bytes (where `nBuf` equals to `0` due to bug) This PR fixes this bug and also added simple test to check the API.
2 parents df497ca + 253daa6 commit 864a029

File tree

2 files changed

+70
-11
lines changed

2 files changed

+70
-11
lines changed

libsql/src/local/connection.rs

Lines changed: 68 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ impl Connection {
7676
// disabled so that we can sync our changes back to a remote
7777
// server.
7878
conn.query("PRAGMA journal_mode = WAL", Params::None)?;
79-
unsafe {
80-
ffi::libsql_wal_disable_checkpoint(conn.raw);
81-
}
79+
conn.wal_disable_checkpoint()?;
8280
}
8381
Ok(conn)
8482
}
@@ -554,6 +552,16 @@ impl Connection {
554552
Ok(buf)
555553
}
556554

555+
fn wal_disable_checkpoint(&self) -> Result<()> {
556+
let rc = unsafe { libsql_sys::ffi::libsql_wal_disable_checkpoint(self.handle()) };
557+
if rc != 0 {
558+
return Err(crate::errors::Error::SqliteFailure(
559+
rc as std::ffi::c_int,
560+
format!("wal_disable_checkpoint failed"),
561+
));
562+
}
563+
Ok(())
564+
}
557565
fn wal_insert_begin(&self) -> Result<()> {
558566
let rc = unsafe { libsql_sys::ffi::libsql_wal_insert_begin(self.handle()) };
559567
if rc != 0 {
@@ -576,14 +584,14 @@ impl Connection {
576584
Ok(())
577585
}
578586

579-
fn wal_insert_frame(&self, frame: &[u8]) -> Result<()> {
587+
fn wal_insert_frame(&self, frame_no: u32, frame: &[u8]) -> Result<()> {
580588
let mut conflict = 0i32;
581589
let rc = unsafe {
582590
libsql_sys::ffi::libsql_wal_insert_frame(
583591
self.handle(),
584-
frame.len() as u32,
592+
frame_no,
585593
frame.as_ptr() as *mut std::ffi::c_void,
586-
0,
594+
frame.len() as u32,
587595
&mut conflict,
588596
)
589597
};
@@ -658,13 +666,13 @@ unsafe extern "C" fn authorizer_callback(
658666

659667
pub(crate) struct WalInsertHandle<'a> {
660668
conn: &'a Connection,
661-
in_session: RefCell<bool>
669+
in_session: RefCell<bool>,
662670
}
663671

664672
impl WalInsertHandle<'_> {
665-
pub fn insert(&self, frame: &[u8]) -> Result<()> {
673+
pub fn insert_at(&self, frame_no: u32, frame: &[u8]) -> Result<()> {
666674
assert!(*self.in_session.borrow());
667-
self.conn.wal_insert_frame(frame)
675+
self.conn.wal_insert_frame(frame_no, frame)
668676
}
669677

670678
pub fn begin(&self) -> Result<()> {
@@ -698,3 +706,54 @@ impl fmt::Debug for Connection {
698706
f.debug_struct("Connection").finish()
699707
}
700708
}
709+
710+
#[cfg(test)]
711+
mod tests {
712+
use crate::{
713+
local::{Connection, Database},
714+
params::Params,
715+
OpenFlags,
716+
};
717+
718+
#[tokio::test]
719+
pub async fn test_kek() {
720+
let temp_dir = tempfile::tempdir().unwrap();
721+
let path1 = temp_dir.path().join("local1.db");
722+
let db1 = Database::new(path1.to_str().unwrap().to_string(), OpenFlags::default());
723+
let conn1 = Connection::connect(&db1).unwrap();
724+
conn1
725+
.query("PRAGMA journal_mode = WAL", Params::None)
726+
.unwrap();
727+
conn1.wal_disable_checkpoint().unwrap();
728+
729+
let path2 = temp_dir.path().join("local2.db");
730+
let db2 = Database::new(path2.to_str().unwrap().to_string(), OpenFlags::default());
731+
let conn2 = Connection::connect(&db2).unwrap();
732+
conn2
733+
.query("PRAGMA journal_mode = WAL", Params::None)
734+
.unwrap();
735+
conn2.wal_disable_checkpoint().unwrap();
736+
737+
conn1.execute("CREATE TABLE t(x)", Params::None).unwrap();
738+
const CNT: usize = 32;
739+
for _ in 0..CNT {
740+
conn1
741+
.execute(
742+
"INSERT INTO t VALUES (randomblob(1024 * 1024))",
743+
Params::None,
744+
)
745+
.unwrap();
746+
}
747+
let handle = conn2.wal_insert_handle().unwrap();
748+
let frame_count = conn1.wal_frame_count();
749+
for frame_no in 0..frame_count {
750+
let frame = conn1.wal_get_frame(frame_no + 1, 4096).unwrap();
751+
handle.insert_at(frame_no as u32 + 1, &frame).unwrap();
752+
}
753+
let result = conn2.query("SELECT COUNT(*) FROM t", Params::None).unwrap();
754+
let row = result.unwrap().next().unwrap().unwrap();
755+
let column = row.get_value(0).unwrap();
756+
let cnt = *column.as_integer().unwrap();
757+
assert_eq!(cnt, 32 as i64);
758+
}
759+
}

libsql/src/sync.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -974,8 +974,8 @@ pub async fn try_pull(
974974
);
975975
return Err(SyncError::InvalidPullFrameBytes(frames.len()).into());
976976
}
977-
for chunk in frames.chunks(FRAME_SIZE) {
978-
let r = insert_handle.insert(&chunk);
977+
for (i, chunk) in frames.chunks(FRAME_SIZE).enumerate() {
978+
let r = insert_handle.insert_at(frame_no + i as u32, &chunk);
979979
if let Err(e) = r {
980980
tracing::error!(
981981
"insert error (frame= {}) : {:?}",

0 commit comments

Comments
 (0)