Skip to content

Commit 38edf74

Browse files
committed
add the cloudsync_payload_apply_callback to support RLS verification while applying cloudsync changes on the centralized node
1 parent 329730e commit 38edf74

File tree

4 files changed

+142
-33
lines changed

4 files changed

+142
-33
lines changed

src/cloudsync.c

Lines changed: 99 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@ SQLITE_EXTENSION_INIT1
7171
#define CLOUDSYNC_PAYLOAD_MINBUF_SIZE 512*1024
7272
#define CLOUDSYNC_PAYLOAD_VERSION 1
7373
#define CLOUDSYNC_PAYLOAD_SIGNATURE 'CLSY'
74-
#define CLOUDSYNC_PK_INDEX_DBVERSION 5
75-
#define CLOUDSYNC_PK_INDEX_SEQ 8
7674

7775
#ifndef MAX
7876
#define MAX(a, b) (((a)>(b))?(a):(b))
@@ -93,12 +91,17 @@ typedef enum {
9391
CLOUDSYNC_STMT_VALUE_CHANGED = 1,
9492
} CLOUDSYNC_STMT_VALUE;
9593

96-
typedef struct {
97-
sqlite3_stmt *vm;
98-
int64_t dbversion;
99-
int64_t seq;
100-
int64_t tmp_dbversion;
101-
} cloudsync_pk_decode_bind_context;
94+
typedef enum {
95+
CLOUDSYNC_PK_INDEX_TBL = 0,
96+
CLOUDSYNC_PK_INDEX_PK = 1,
97+
CLOUDSYNC_PK_INDEX_COLNAME = 2,
98+
CLOUDSYNC_PK_INDEX_COLVALUE = 3,
99+
CLOUDSYNC_PK_INDEX_COLVERSION = 4,
100+
CLOUDSYNC_PK_INDEX_DBVERSION = 5,
101+
CLOUDSYNC_PK_INDEX_SITEID = 6,
102+
CLOUDSYNC_PK_INDEX_CL = 7,
103+
CLOUDSYNC_PK_INDEX_SEQ = 8
104+
} CLOUDSYNC_PK_INDEX;
102105

103106
typedef struct {
104107
sqlite3_context *context;
@@ -215,6 +218,8 @@ int db_version_rebuild_stmt (sqlite3 *db, cloudsync_context *data);
215218
int cloudsync_load_siteid (sqlite3 *db, cloudsync_context *data);
216219
int local_mark_insert_or_update_meta (sqlite3 *db, cloudsync_table_context *table, const char *pk, size_t pklen, const char *col_name, sqlite3_int64 db_version, int seq);
217220

221+
static cloudsync_payload_apply_callback_t payload_apply_callback;
222+
218223
// MARK: - STMT Utils -
219224

220225
CLOUDSYNC_STMT_VALUE stmt_execute (sqlite3_stmt *stmt, cloudsync_context *data) {
@@ -1927,32 +1932,55 @@ int cloudsync_pk_decode_bind_callback (void *xdata, int index, int type, int64_t
19271932
cloudsync_pk_decode_bind_context *decode_context = (cloudsync_pk_decode_bind_context*)xdata;
19281933
int rc = pk_decode_bind_callback(decode_context->vm, index, type, ival, dval, pval);
19291934

1930-
if (rc == SQLITE_OK && type == SQLITE_INTEGER) {
1935+
if (rc == SQLITE_OK) {
19311936
// the dbversion index is smaller than seq index, so it is processed first
19321937
// when processing the dbversion column: save the value to the tmp_dbversion field
19331938
// when processing the seq column: update the dbversion and seq fields only if the current dbversion is greater than the last max value
19341939
switch (index) {
1940+
case CLOUDSYNC_PK_INDEX_TBL:
1941+
if (type == SQLITE_TEXT) {
1942+
decode_context->tbl = pval;
1943+
decode_context->tbl_len = ival;
1944+
}
1945+
break;
1946+
case CLOUDSYNC_PK_INDEX_PK:
1947+
if (type == SQLITE_BLOB) {
1948+
decode_context->pk = pval;
1949+
decode_context->pk_len = ival;
1950+
}
1951+
break;
1952+
case CLOUDSYNC_PK_INDEX_COLNAME:
1953+
if (type == SQLITE_TEXT) {
1954+
decode_context->col_name = pval;
1955+
decode_context->col_name_len = ival;
1956+
}
1957+
break;
1958+
case CLOUDSYNC_PK_INDEX_COLVERSION:
1959+
if (type == SQLITE_INTEGER) decode_context->col_version = ival;
1960+
break;
19351961
case CLOUDSYNC_PK_INDEX_DBVERSION:
1936-
decode_context->tmp_dbversion = ival;
1962+
if (type == SQLITE_INTEGER) decode_context->db_version = ival;
19371963
break;
1938-
case CLOUDSYNC_PK_INDEX_SEQ:
1939-
// when the dbversion field is incremented the seq val must be updated too
1940-
// because the current decode_context->seq field refers to the previous dbversion
1941-
if (decode_context->tmp_dbversion > decode_context->dbversion) {
1942-
decode_context->dbversion = decode_context->tmp_dbversion;
1943-
decode_context->seq = ival;
1944-
} else if (decode_context->tmp_dbversion == decode_context->dbversion) {
1945-
decode_context->seq = MAX(decode_context->seq, ival);
1964+
case CLOUDSYNC_PK_INDEX_SITEID:
1965+
if (type == SQLITE_BLOB) {
1966+
decode_context->site_id = pval;
1967+
decode_context->site_id_len = ival;
19461968
}
1947-
// reset the tmp_dbversion value before processing the next row
1948-
decode_context->tmp_dbversion = 0;
1969+
break;
1970+
case CLOUDSYNC_PK_INDEX_CL:
1971+
if (type == SQLITE_INTEGER) decode_context->cl = ival;
1972+
break;
1973+
case CLOUDSYNC_PK_INDEX_SEQ:
1974+
if (type == SQLITE_INTEGER) decode_context->seq = ival;
19491975
break;
19501976
}
19511977
}
19521978

19531979
return rc;
19541980
}
19551981

1982+
// #ifndef CLOUDSYNC_OMIT_RLS_VALIDATION
1983+
19561984
int cloudsync_payload_apply (sqlite3_context *context, const char *payload, int blen) {
19571985
// decode header
19581986
cloudsync_network_header header;
@@ -2016,31 +2044,47 @@ int cloudsync_payload_apply (sqlite3_context *context, const char *payload, int
20162044
uint32_t nrows = header.nrows;
20172045
int dbversion = dbutils_settings_get_int_value(db, CLOUDSYNC_KEY_CHECK_DBVERSION);
20182046
int seq = dbutils_settings_get_int_value(db, CLOUDSYNC_KEY_CHECK_SEQ);
2019-
cloudsync_pk_decode_bind_context xdata = {.vm = vm, .dbversion = dbversion, .seq = seq, .tmp_dbversion = 0};
2047+
cloudsync_pk_decode_bind_context decoded_context = {.vm = vm};
2048+
void *payload_apply_xdata = NULL;
2049+
20202050
for (uint32_t i=0; i<nrows; ++i) {
20212051
size_t seek = 0;
2022-
pk_decode((char *)buffer, blen, ncols, &seek, cloudsync_pk_decode_bind_callback, &xdata);
2052+
pk_decode((char *)buffer, blen, ncols, &seek, cloudsync_pk_decode_bind_callback, &decoded_context);
20232053
// n is the pk_decode return value, I don't think I should assert here because in any case the next sqlite3_step would fail
20242054
// assert(n == ncols);
20252055

2026-
rc = sqlite3_step(vm);
2027-
if (rc != SQLITE_DONE) break;
2056+
bool approved = true;
2057+
if (payload_apply_callback) approved = payload_apply_callback(&payload_apply_xdata, &decoded_context, db, data, CLOUDSYNC_PAYLOAD_APPLY_WILL_APPLY, SQLITE_OK);
2058+
2059+
if (approved) {
2060+
rc = sqlite3_step(vm);
2061+
if (rc != SQLITE_DONE) {
2062+
// don't "break;", the error can be due to a RLS policy.
2063+
// in case of error we try to apply the following changes
2064+
printf("cloudsync_payload_apply error in step: (%d) %s\n", rc, sqlite3_errmsg(db));
2065+
}
2066+
}
2067+
2068+
if (payload_apply_callback) payload_apply_callback(&payload_apply_xdata, &decoded_context, db, data, CLOUDSYNC_PAYLOAD_APPLY_DID_APPLY, rc);
20282069

20292070
buffer += seek;
20302071
blen -= seek;
20312072
stmt_reset(vm);
20322073
}
2033-
2074+
2075+
if (payload_apply_callback) payload_apply_callback(&payload_apply_xdata, &decoded_context, db, data, CLOUDSYNC_PAYLOAD_APPLY_CLEANUP, rc);
2076+
20342077
if (rc == SQLITE_DONE) rc = SQLITE_OK;
20352078
if (rc == SQLITE_OK) {
20362079
char buf[256];
2037-
if (xdata.dbversion != dbversion) {
2038-
snprintf(buf, sizeof(buf), "%lld", xdata.dbversion);
2080+
if (decoded_context.db_version >= dbversion) {
2081+
snprintf(buf, sizeof(buf), "%lld", decoded_context.db_version);
20392082
dbutils_settings_set_key_value(db, context, CLOUDSYNC_KEY_CHECK_DBVERSION, buf);
2040-
}
2041-
if (xdata.seq != seq) {
2042-
snprintf(buf, sizeof(buf), "%lld", xdata.seq);
2043-
dbutils_settings_set_key_value(db, context, CLOUDSYNC_KEY_CHECK_SEQ, buf);
2083+
2084+
if (decoded_context.seq != seq) {
2085+
snprintf(buf, sizeof(buf), "%lld", decoded_context.seq);
2086+
dbutils_settings_set_key_value(db, context, CLOUDSYNC_KEY_CHECK_SEQ, buf);
2087+
}
20442088
}
20452089
}
20462090

@@ -2061,6 +2105,30 @@ int cloudsync_payload_apply (sqlite3_context *context, const char *payload, int
20612105
return nrows;
20622106
}
20632107

2108+
void cloudsync_payload_apply_callback(cloudsync_payload_apply_callback_t callback) {
2109+
payload_apply_callback = callback;
2110+
}
2111+
2112+
sqlite3_stmt *cloudsync_col_value_stmt (sqlite3 *db, cloudsync_context *data, const char *tbl_name, bool *persistent) {
2113+
sqlite3_stmt *vm;
2114+
2115+
cloudsync_table_context *table = table_lookup(data, tbl_name, false);
2116+
char *col_name = NULL;
2117+
if (table->ncols > 0) {
2118+
col_name = table->col_name[0];
2119+
// retrieve col_value precompiled statement
2120+
vm = table_column_lookup(table, col_name, false, NULL);
2121+
*persistent = true;
2122+
} else {
2123+
char *sql = table_build_value_sql(db, table, "*");
2124+
sqlite3_prepare_v2(db, sql, -1, &vm, NULL);
2125+
cloudsync_memory_free(sql);
2126+
*persistent = false;
2127+
}
2128+
2129+
return vm;
2130+
}
2131+
20642132
void cloudsync_network_decode (sqlite3_context *context, int argc, sqlite3_value **argv) {
20652133
DEBUG_FUNCTION("cloudsync_network_decode");
20662134
//debug_values(argc, argv);

src/cloudsync.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,30 @@
2121
#define CLOUDSYNC_RLS_RESTRICTED_VALUE "__[RLS]__"
2222
#define CLOUDSYNC_DISABLE_ROWIDONLY_TABLES 1
2323

24+
typedef enum {
25+
CLOUDSYNC_PAYLOAD_APPLY_WILL_APPLY = 1,
26+
CLOUDSYNC_PAYLOAD_APPLY_DID_APPLY = 2,
27+
CLOUDSYNC_PAYLOAD_APPLY_CLEANUP = 3
28+
} CLOUDSYNC_PAYLOAD_APPLY_STEPS;
29+
30+
typedef struct {
31+
sqlite3_stmt *vm;
32+
char *tbl;
33+
int64_t tbl_len;
34+
const void *pk;
35+
int64_t pk_len;
36+
char *col_name;
37+
int64_t col_name_len;
38+
int64_t col_version;
39+
int64_t db_version;
40+
const void *site_id;
41+
int64_t site_id_len;
42+
int64_t cl;
43+
int64_t seq;
44+
} cloudsync_pk_decode_bind_context;
45+
2446
typedef struct cloudsync_context cloudsync_context;
47+
typedef bool (*cloudsync_payload_apply_callback_t)(void **xdata, cloudsync_pk_decode_bind_context *decoded_change, sqlite3 *db, cloudsync_context *data, int step, int rc);
2548

2649
int sqlite3_cloudsync_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
2750
bool cloudsync_config_exists (sqlite3 *db);
@@ -32,5 +55,7 @@ void cloudsync_sync_table_key (cloudsync_context *data, const char *table, const
3255
void *cloudsync_get_auxdata (sqlite3_context *context);
3356
void cloudsync_set_auxdata (sqlite3_context *context, void *xdata);
3457
int cloudsync_payload_apply (sqlite3_context *context, const char *payload, int blen);
58+
void cloudsync_payload_apply_callback(cloudsync_payload_apply_callback_t callback);
59+
sqlite3_stmt *cloudsync_col_value_stmt (sqlite3 *db, cloudsync_context *data, const char *tbl_name, bool *persistent);
3560

3661
#endif

src/utils.c

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,9 @@ void *cloudsync_memory_zeroalloc (uint64_t size) {
126126
return ptr;
127127
}
128128

129-
char *cloudsync_string_dup (const char *str, bool lowercase) {
129+
char *cloudsync_string_ndup (const char *str, size_t len, bool lowercase) {
130130
if (str == NULL) return NULL;
131131

132-
size_t len = strlen(str);
133132
char *s = (char *)cloudsync_memory_alloc((sqlite3_uint64)(len + 1));
134133
if (!s) return NULL;
135134

@@ -148,6 +147,20 @@ char *cloudsync_string_dup (const char *str, bool lowercase) {
148147
return s;
149148
}
150149

150+
char *cloudsync_string_dup (const char *str, bool lowercase) {
151+
if (str == NULL) return NULL;
152+
153+
size_t len = strlen(str);
154+
return cloudsync_string_ndup(str, len, lowercase);
155+
}
156+
157+
int cloudsync_blob_compare(const char *blob1, size_t size1, const char *blob2, size_t size2) {
158+
if (size1 != size2) {
159+
return (int)(size1 - size2); // Blobs are different if sizes are different
160+
}
161+
return memcmp(blob1, blob2, size1); // Use memcmp for byte-by-byte comparison
162+
}
163+
151164
void cloudsync_rowid_decode (sqlite3_int64 rowid, sqlite3_int64 *db_version, sqlite3_int64 *seq) {
152165
// use unsigned 64-bit integer for intermediate calculations
153166
// when db_version is large enough, it can cause overflow, leading to negative values

src/utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ char *cloudsync_string_replace_prefix(const char *input, char *prefix, char *rep
129129
uint64_t fnv1a_hash(const char *data, size_t len);
130130

131131
void *cloudsync_memory_zeroalloc (uint64_t size);
132+
char *cloudsync_string_ndup (const char *str, size_t len, bool lowercase);
132133
char *cloudsync_string_dup (const char *str, bool lowercase);
134+
int cloudsync_blob_compare(const char *blob1, size_t size1, const char *blob2, size_t size2);
135+
133136
void cloudsync_rowid_decode (sqlite3_int64 rowid, sqlite3_int64 *db_version, sqlite3_int64 *seq);
134137

135138
#endif

0 commit comments

Comments
 (0)