Skip to content

Commit 3c8711f

Browse files
jeromegntantaman
authored andcommitted
fix a mistake where we were only using the current DB's site_id instead of the current site_id value in the clock table for comparison
1 parent cfc2d58 commit 3c8711f

File tree

3 files changed

+94
-7
lines changed

3 files changed

+94
-7
lines changed

core/rs/core/src/changes_vtab_write.rs

Lines changed: 39 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,45 @@ fn did_cid_win(
9595
reset_cached_stmt(col_val_stmt.stmt)?;
9696
if ret == 0 && unsafe { (*ext_data).mergeEqualValues == 1 } {
9797
// values are the same (ret == 0) and the option to tie break on site_id is true
98-
ret = unsafe {
99-
let my_site_id = core::slice::from_raw_parts((*ext_data).siteId, 16);
100-
insert_site_id.cmp(my_site_id) as c_int
101-
};
98+
let col_site_id_stmt_ref = tbl_info.get_col_site_id_stmt(db)?;
99+
let col_site_id_stmt = col_site_id_stmt_ref.as_ref().ok_or(ResultCode::ERROR)?;
100+
101+
let bind_result = col_site_id_stmt.bind_int64(1, key);
102+
if let Err(rc) = bind_result {
103+
reset_cached_stmt(col_site_id_stmt.stmt)?;
104+
return Err(rc);
105+
}
106+
if let Err(rc) = col_site_id_stmt.bind_text(2, col_name, sqlite::Destructor::STATIC)
107+
{
108+
reset_cached_stmt(col_site_id_stmt.stmt)?;
109+
return Err(rc);
110+
}
111+
112+
match col_site_id_stmt.step() {
113+
Ok(ResultCode::ROW) => {
114+
let local_site_id = col_site_id_stmt.column_blob(0)?;
115+
ret = insert_site_id.cmp(local_site_id) as c_int;
116+
117+
// reset the stmt after, we're accessing a slice in-memory
118+
reset_cached_stmt(col_site_id_stmt.stmt)?;
119+
}
120+
Ok(ResultCode::DONE) => {
121+
reset_cached_stmt(col_site_id_stmt.stmt)?;
122+
let err = CString::new(format!(
123+
"could not find site_id for previous change, cr-sqlite clock table might be corrupt for tbl {}",
124+
insert_tbl
125+
))?;
126+
unsafe { *errmsg = err.into_raw() };
127+
return Err(ResultCode::ERROR);
128+
}
129+
Ok(rc) | Err(rc) => {
130+
reset_cached_stmt(col_site_id_stmt.stmt)?;
131+
let err =
132+
CString::new("Bad return code when selecting local column site_id")?;
133+
unsafe { *errmsg = err.into_raw() };
134+
return Err(rc);
135+
}
136+
}
102137
}
103138
return Ok(ret > 0);
104139
}

core/rs/core/src/tableinfo.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ pub struct TableInfo {
4545
set_winner_clock_stmt: RefCell<Option<ManagedStmt>>,
4646
local_cl_stmt: RefCell<Option<ManagedStmt>>,
4747
col_version_stmt: RefCell<Option<ManagedStmt>>,
48+
col_site_id_stmt: RefCell<Option<ManagedStmt>>,
4849
merge_pk_only_insert_stmt: RefCell<Option<ManagedStmt>>,
4950
merge_delete_stmt: RefCell<Option<ManagedStmt>>,
5051
merge_delete_drop_clocks_stmt: RefCell<Option<ManagedStmt>>,
@@ -337,6 +338,21 @@ impl TableInfo {
337338
Ok(self.col_version_stmt.try_borrow()?)
338339
}
339340

341+
pub fn get_col_site_id_stmt(
342+
&self,
343+
db: *mut sqlite3,
344+
) -> Result<Ref<Option<ManagedStmt>>, ResultCode> {
345+
if self.col_site_id_stmt.try_borrow()?.is_none() {
346+
let sql = format!(
347+
"SELECT site_id FROM crsql_site_id WHERE ordinal = (SELECT site_id FROM \"{table_name}__crsql_clock\" WHERE key = ? AND col_name = ?)",
348+
table_name = crate::util::escape_ident(&self.tbl_name),
349+
);
350+
let ret = db.prepare_v3(&sql, sqlite::PREPARE_PERSISTENT)?;
351+
*self.col_site_id_stmt.try_borrow_mut()? = Some(ret);
352+
}
353+
Ok(self.col_site_id_stmt.try_borrow()?)
354+
}
355+
340356
pub fn get_merge_pk_only_insert_stmt(
341357
&self,
342358
db: *mut sqlite3,
@@ -871,6 +887,7 @@ pub fn pull_table_info(
871887
set_winner_clock_stmt: RefCell::new(None),
872888
local_cl_stmt: RefCell::new(None),
873889
col_version_stmt: RefCell::new(None),
890+
col_site_id_stmt: RefCell::new(None),
874891

875892
select_key_stmt: RefCell::new(None),
876893
insert_key_stmt: RefCell::new(None),

py/correctness/tests/test_sync.py

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,7 @@ def make_dbs():
383383
def test_merge_same_w_tie_breaker():
384384
db1 = create_basic_db()
385385
db2 = create_basic_db()
386+
db3 = create_basic_db()
386387

387388
db1.execute("INSERT INTO foo (a,b) VALUES (1,2);")
388389
db1.execute("SELECT crsql_config_set('merge-equal-values', 1);")
@@ -392,13 +393,47 @@ def test_merge_same_w_tie_breaker():
392393
db2.execute("SELECT crsql_config_set('merge-equal-values', 1);")
393394
db2.commit()
394395

396+
db3.execute("INSERT INTO foo (a,b) VALUES (1,2);")
397+
db3.execute("SELECT crsql_config_set('merge-equal-values', 1);")
398+
db3.commit()
399+
395400
sync_left_to_right(db1, db2, 0)
396-
changes12 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id FROM crsql_changes").fetchall()
401+
changes2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
397402

398403
sync_left_to_right(db2, db1, 0)
399-
changes21 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id FROM crsql_changes").fetchall()
404+
changes1 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
405+
406+
sync_left_to_right(db2, db3, 0)
407+
changes3 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
408+
409+
# check that everything by db_version is the same
410+
assert (changes2[:-6] == changes1[:-6] == changes3[:-6])
400411

401-
assert (changes12 == changes21)
412+
# Test that we're stable / do not loop when we tie break equal values
413+
414+
sync_left_to_right(db2, db1, 0)
415+
changes1_2 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
416+
sync_left_to_right(db3, db2, 0)
417+
changes2_2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
418+
sync_left_to_right(db1, db3, 0)
419+
changes3_2 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
420+
421+
# everything should stay the same, including db_version
422+
assert (changes1 == changes1_2)
423+
assert (changes2 == changes2_2)
424+
assert (changes3 == changes3_2)
425+
426+
sync_left_to_right(db3, db1, 0)
427+
changes1_2 = db1.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
428+
sync_left_to_right(db1, db2, 0)
429+
changes2_2 = db2.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
430+
sync_left_to_right(db2, db3, 0)
431+
changes3_2 = db3.execute("SELECT \"table\", pk, cid, val, col_version, site_id, db_version FROM crsql_changes").fetchall()
432+
433+
# everything should stay the same, including db_version
434+
assert (changes1 == changes1_2)
435+
assert (changes2 == changes2_2)
436+
assert (changes3 == changes3_2)
402437

403438

404439
def test_merge_matching_clocks_lesser_value():

0 commit comments

Comments
 (0)