Skip to content

Commit 3cf7752

Browse files
authored
Merge pull request #3 from sqliteai/feature/payload-approval-callback
Feature/payload approval callback
2 parents 329730e + 1d459b8 commit 3cf7752

File tree

6 files changed

+370
-50
lines changed

6 files changed

+370
-50
lines changed

src/cloudsync.c

Lines changed: 110 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,14 @@ SQLITE_EXTENSION_INIT1
6464
#define APIEXPORT
6565
#endif
6666

67-
#define CLOUDSYNC_DEFAULT_ALGO "cls"
68-
#define CLOUDSYNC_INIT_NTABLES 128
69-
#define CLOUDSYNC_VALUE_NOTSET -1
70-
#define CLOUDSYNC_MIN_DB_VERSION 0
71-
#define CLOUDSYNC_PAYLOAD_MINBUF_SIZE 512*1024
72-
#define CLOUDSYNC_PAYLOAD_VERSION 1
73-
#define CLOUDSYNC_PAYLOAD_SIGNATURE 'CLSY'
74-
#define CLOUDSYNC_PK_INDEX_DBVERSION 5
75-
#define CLOUDSYNC_PK_INDEX_SEQ 8
67+
#define CLOUDSYNC_DEFAULT_ALGO "cls"
68+
#define CLOUDSYNC_INIT_NTABLES 128
69+
#define CLOUDSYNC_VALUE_NOTSET -1
70+
#define CLOUDSYNC_MIN_DB_VERSION 0
71+
#define CLOUDSYNC_PAYLOAD_MINBUF_SIZE 512*1024
72+
#define CLOUDSYNC_PAYLOAD_VERSION 1
73+
#define CLOUDSYNC_PAYLOAD_SIGNATURE 'CLSY'
74+
#define CLOUDSYNC_PAYLOAD_APPLY_CALLBACK_KEY "cloudsync_payload_apply_callback"
7675

7776
#ifndef MAX
7877
#define MAX(a, b) (((a)>(b))?(a):(b))
@@ -93,12 +92,17 @@ typedef enum {
9392
CLOUDSYNC_STMT_VALUE_CHANGED = 1,
9493
} CLOUDSYNC_STMT_VALUE;
9594

96-
typedef struct {
97-
sqlite3_stmt *vm;
98-
int64_t dbversion;
99-
int64_t seq;
100-
int64_t tmp_dbversion;
101-
} cloudsync_pk_decode_bind_context;
95+
typedef enum {
96+
CLOUDSYNC_PK_INDEX_TBL = 0,
97+
CLOUDSYNC_PK_INDEX_PK = 1,
98+
CLOUDSYNC_PK_INDEX_COLNAME = 2,
99+
CLOUDSYNC_PK_INDEX_COLVALUE = 3,
100+
CLOUDSYNC_PK_INDEX_COLVERSION = 4,
101+
CLOUDSYNC_PK_INDEX_DBVERSION = 5,
102+
CLOUDSYNC_PK_INDEX_SITEID = 6,
103+
CLOUDSYNC_PK_INDEX_CL = 7,
104+
CLOUDSYNC_PK_INDEX_SEQ = 8
105+
} CLOUDSYNC_PK_INDEX;
102106

103107
typedef struct {
104108
sqlite3_context *context;
@@ -1923,36 +1927,67 @@ void cloudsync_network_encode_final (sqlite3_context *context) {
19231927
if (!use_uncompressed_buffer) cloudsync_memory_free(buffer);
19241928
}
19251929

1930+
cloudsync_payload_apply_callback_t cloudsync_get_payload_apply_callback(sqlite3 *db) {
1931+
return sqlite3_get_clientdata(db, CLOUDSYNC_PAYLOAD_APPLY_CALLBACK_KEY);
1932+
}
1933+
1934+
void cloudsync_set_payload_apply_callback(sqlite3 *db, cloudsync_payload_apply_callback_t callback) {
1935+
sqlite3_set_clientdata(db, CLOUDSYNC_PAYLOAD_APPLY_CALLBACK_KEY, (void*)callback, NULL);
1936+
}
1937+
19261938
int cloudsync_pk_decode_bind_callback (void *xdata, int index, int type, int64_t ival, double dval, char *pval) {
19271939
cloudsync_pk_decode_bind_context *decode_context = (cloudsync_pk_decode_bind_context*)xdata;
19281940
int rc = pk_decode_bind_callback(decode_context->vm, index, type, ival, dval, pval);
19291941

1930-
if (rc == SQLITE_OK && type == SQLITE_INTEGER) {
1942+
if (rc == SQLITE_OK) {
19311943
// the dbversion index is smaller than seq index, so it is processed first
19321944
// when processing the dbversion column: save the value to the tmp_dbversion field
19331945
// when processing the seq column: update the dbversion and seq fields only if the current dbversion is greater than the last max value
19341946
switch (index) {
1947+
case CLOUDSYNC_PK_INDEX_TBL:
1948+
if (type == SQLITE_TEXT) {
1949+
decode_context->tbl = pval;
1950+
decode_context->tbl_len = ival;
1951+
}
1952+
break;
1953+
case CLOUDSYNC_PK_INDEX_PK:
1954+
if (type == SQLITE_BLOB) {
1955+
decode_context->pk = pval;
1956+
decode_context->pk_len = ival;
1957+
}
1958+
break;
1959+
case CLOUDSYNC_PK_INDEX_COLNAME:
1960+
if (type == SQLITE_TEXT) {
1961+
decode_context->col_name = pval;
1962+
decode_context->col_name_len = ival;
1963+
}
1964+
break;
1965+
case CLOUDSYNC_PK_INDEX_COLVERSION:
1966+
if (type == SQLITE_INTEGER) decode_context->col_version = ival;
1967+
break;
19351968
case CLOUDSYNC_PK_INDEX_DBVERSION:
1936-
decode_context->tmp_dbversion = ival;
1969+
if (type == SQLITE_INTEGER) decode_context->db_version = ival;
19371970
break;
1938-
case CLOUDSYNC_PK_INDEX_SEQ:
1939-
// when the dbversion field is incremented the seq val must be updated too
1940-
// because the current decode_context->seq field refers to the previous dbversion
1941-
if (decode_context->tmp_dbversion > decode_context->dbversion) {
1942-
decode_context->dbversion = decode_context->tmp_dbversion;
1943-
decode_context->seq = ival;
1944-
} else if (decode_context->tmp_dbversion == decode_context->dbversion) {
1945-
decode_context->seq = MAX(decode_context->seq, ival);
1971+
case CLOUDSYNC_PK_INDEX_SITEID:
1972+
if (type == SQLITE_BLOB) {
1973+
decode_context->site_id = pval;
1974+
decode_context->site_id_len = ival;
19461975
}
1947-
// reset the tmp_dbversion value before processing the next row
1948-
decode_context->tmp_dbversion = 0;
1976+
break;
1977+
case CLOUDSYNC_PK_INDEX_CL:
1978+
if (type == SQLITE_INTEGER) decode_context->cl = ival;
1979+
break;
1980+
case CLOUDSYNC_PK_INDEX_SEQ:
1981+
if (type == SQLITE_INTEGER) decode_context->seq = ival;
19491982
break;
19501983
}
19511984
}
19521985

19531986
return rc;
19541987
}
19551988

1989+
// #ifndef CLOUDSYNC_OMIT_RLS_VALIDATION
1990+
19561991
int cloudsync_payload_apply (sqlite3_context *context, const char *payload, int blen) {
19571992
// decode header
19581993
cloudsync_network_header header;
@@ -2016,31 +2051,48 @@ int cloudsync_payload_apply (sqlite3_context *context, const char *payload, int
20162051
uint32_t nrows = header.nrows;
20172052
int dbversion = dbutils_settings_get_int_value(db, CLOUDSYNC_KEY_CHECK_DBVERSION);
20182053
int seq = dbutils_settings_get_int_value(db, CLOUDSYNC_KEY_CHECK_SEQ);
2019-
cloudsync_pk_decode_bind_context xdata = {.vm = vm, .dbversion = dbversion, .seq = seq, .tmp_dbversion = 0};
2054+
cloudsync_pk_decode_bind_context decoded_context = {.vm = vm};
2055+
void *payload_apply_xdata = NULL;
2056+
cloudsync_payload_apply_callback_t payload_apply_callback = cloudsync_get_payload_apply_callback(db);
2057+
20202058
for (uint32_t i=0; i<nrows; ++i) {
20212059
size_t seek = 0;
2022-
pk_decode((char *)buffer, blen, ncols, &seek, cloudsync_pk_decode_bind_callback, &xdata);
2060+
pk_decode((char *)buffer, blen, ncols, &seek, cloudsync_pk_decode_bind_callback, &decoded_context);
20232061
// n is the pk_decode return value, I don't think I should assert here because in any case the next sqlite3_step would fail
20242062
// assert(n == ncols);
20252063

2026-
rc = sqlite3_step(vm);
2027-
if (rc != SQLITE_DONE) break;
2064+
bool approved = true;
2065+
if (payload_apply_callback) approved = payload_apply_callback(&payload_apply_xdata, &decoded_context, db, data, CLOUDSYNC_PAYLOAD_APPLY_WILL_APPLY, SQLITE_OK);
2066+
2067+
if (approved) {
2068+
rc = sqlite3_step(vm);
2069+
if (rc != SQLITE_DONE) {
2070+
// don't "break;", the error can be due to a RLS policy.
2071+
// in case of error we try to apply the following changes
2072+
printf("cloudsync_payload_apply error in step: (%d) %s\n", rc, sqlite3_errmsg(db));
2073+
}
2074+
}
2075+
2076+
if (payload_apply_callback) payload_apply_callback(&payload_apply_xdata, &decoded_context, db, data, CLOUDSYNC_PAYLOAD_APPLY_DID_APPLY, rc);
20282077

20292078
buffer += seek;
20302079
blen -= seek;
20312080
stmt_reset(vm);
20322081
}
2033-
2082+
2083+
if (payload_apply_callback) payload_apply_callback(&payload_apply_xdata, &decoded_context, db, data, CLOUDSYNC_PAYLOAD_APPLY_CLEANUP, rc);
2084+
20342085
if (rc == SQLITE_DONE) rc = SQLITE_OK;
20352086
if (rc == SQLITE_OK) {
20362087
char buf[256];
2037-
if (xdata.dbversion != dbversion) {
2038-
snprintf(buf, sizeof(buf), "%lld", xdata.dbversion);
2088+
if (decoded_context.db_version >= dbversion) {
2089+
snprintf(buf, sizeof(buf), "%lld", decoded_context.db_version);
20392090
dbutils_settings_set_key_value(db, context, CLOUDSYNC_KEY_CHECK_DBVERSION, buf);
2040-
}
2041-
if (xdata.seq != seq) {
2042-
snprintf(buf, sizeof(buf), "%lld", xdata.seq);
2043-
dbutils_settings_set_key_value(db, context, CLOUDSYNC_KEY_CHECK_SEQ, buf);
2091+
2092+
if (decoded_context.seq != seq) {
2093+
snprintf(buf, sizeof(buf), "%lld", decoded_context.seq);
2094+
dbutils_settings_set_key_value(db, context, CLOUDSYNC_KEY_CHECK_SEQ, buf);
2095+
}
20442096
}
20452097
}
20462098

@@ -2061,6 +2113,26 @@ int cloudsync_payload_apply (sqlite3_context *context, const char *payload, int
20612113
return nrows;
20622114
}
20632115

2116+
sqlite3_stmt *cloudsync_col_value_stmt (sqlite3 *db, cloudsync_context *data, const char *tbl_name, bool *persistent) {
2117+
sqlite3_stmt *vm;
2118+
2119+
cloudsync_table_context *table = table_lookup(data, tbl_name, false);
2120+
char *col_name = NULL;
2121+
if (table->ncols > 0) {
2122+
col_name = table->col_name[0];
2123+
// retrieve col_value precompiled statement
2124+
vm = table_column_lookup(table, col_name, false, NULL);
2125+
*persistent = true;
2126+
} else {
2127+
char *sql = table_build_value_sql(db, table, "*");
2128+
sqlite3_prepare_v2(db, sql, -1, &vm, NULL);
2129+
cloudsync_memory_free(sql);
2130+
*persistent = false;
2131+
}
2132+
2133+
return vm;
2134+
}
2135+
20642136
void cloudsync_network_decode (sqlite3_context *context, int argc, sqlite3_value **argv) {
20652137
DEBUG_FUNCTION("cloudsync_network_decode");
20662138
//debug_values(argc, argv);

src/cloudsync.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,30 @@
2121
#define CLOUDSYNC_RLS_RESTRICTED_VALUE "__[RLS]__"
2222
#define CLOUDSYNC_DISABLE_ROWIDONLY_TABLES 1
2323

24+
typedef enum {
25+
CLOUDSYNC_PAYLOAD_APPLY_WILL_APPLY = 1,
26+
CLOUDSYNC_PAYLOAD_APPLY_DID_APPLY = 2,
27+
CLOUDSYNC_PAYLOAD_APPLY_CLEANUP = 3
28+
} CLOUDSYNC_PAYLOAD_APPLY_STEPS;
29+
30+
typedef struct {
31+
sqlite3_stmt *vm;
32+
char *tbl;
33+
int64_t tbl_len;
34+
const void *pk;
35+
int64_t pk_len;
36+
char *col_name;
37+
int64_t col_name_len;
38+
int64_t col_version;
39+
int64_t db_version;
40+
const void *site_id;
41+
int64_t site_id_len;
42+
int64_t cl;
43+
int64_t seq;
44+
} cloudsync_pk_decode_bind_context;
45+
2446
typedef struct cloudsync_context cloudsync_context;
47+
typedef bool (*cloudsync_payload_apply_callback_t)(void **xdata, cloudsync_pk_decode_bind_context *decoded_change, sqlite3 *db, cloudsync_context *data, int step, int rc);
2548

2649
int sqlite3_cloudsync_init (sqlite3 *db, char **pzErrMsg, const sqlite3_api_routines *pApi);
2750
bool cloudsync_config_exists (sqlite3 *db);
@@ -32,5 +55,7 @@ void cloudsync_sync_table_key (cloudsync_context *data, const char *table, const
3255
void *cloudsync_get_auxdata (sqlite3_context *context);
3356
void cloudsync_set_auxdata (sqlite3_context *context, void *xdata);
3457
int cloudsync_payload_apply (sqlite3_context *context, const char *payload, int blen);
58+
void cloudsync_set_payload_apply_callback(sqlite3 *db, cloudsync_payload_apply_callback_t callback);
59+
sqlite3_stmt *cloudsync_col_value_stmt (sqlite3 *db, cloudsync_context *data, const char *tbl_name, bool *persistent);
3560

3661
#endif

src/dbutils.c

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -416,6 +416,15 @@ bool dbutils_table_sanity_check (sqlite3 *db, sqlite3_context *context, const ch
416416
}
417417
}
418418

419+
// check for columns declared as NOT NULL without a DEFAULT value.
420+
// Otherwise, col_merge_stmt would fail if changes to other columns are inserted first.
421+
sql = sqlite3_snprintf((int)blen, buffer, "SELECT count(*) FROM pragma_table_info('%w') WHERE pk=0 AND \"notnull\"=1 AND \"dflt_value\" IS NULL;", name);
422+
sqlite3_int64 count3 = dbutils_int_select(db, sql);
423+
if (count3 > 0) {
424+
dbutils_context_result_error(context, "All non-primary key columns declared as NOT NULL must have a DEFAULT value. (table %s)", name);
425+
return false;
426+
}
427+
419428
return true;
420429
}
421430

src/utils.c

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,10 +126,9 @@ void *cloudsync_memory_zeroalloc (uint64_t size) {
126126
return ptr;
127127
}
128128

129-
char *cloudsync_string_dup (const char *str, bool lowercase) {
129+
char *cloudsync_string_ndup (const char *str, size_t len, bool lowercase) {
130130
if (str == NULL) return NULL;
131131

132-
size_t len = strlen(str);
133132
char *s = (char *)cloudsync_memory_alloc((sqlite3_uint64)(len + 1));
134133
if (!s) return NULL;
135134

@@ -148,6 +147,20 @@ char *cloudsync_string_dup (const char *str, bool lowercase) {
148147
return s;
149148
}
150149

150+
char *cloudsync_string_dup (const char *str, bool lowercase) {
151+
if (str == NULL) return NULL;
152+
153+
size_t len = strlen(str);
154+
return cloudsync_string_ndup(str, len, lowercase);
155+
}
156+
157+
int cloudsync_blob_compare(const char *blob1, size_t size1, const char *blob2, size_t size2) {
158+
if (size1 != size2) {
159+
return (int)(size1 - size2); // Blobs are different if sizes are different
160+
}
161+
return memcmp(blob1, blob2, size1); // Use memcmp for byte-by-byte comparison
162+
}
163+
151164
void cloudsync_rowid_decode (sqlite3_int64 rowid, sqlite3_int64 *db_version, sqlite3_int64 *seq) {
152165
// use unsigned 64-bit integer for intermediate calculations
153166
// when db_version is large enough, it can cause overflow, leading to negative values

src/utils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,10 @@ char *cloudsync_string_replace_prefix(const char *input, char *prefix, char *rep
129129
uint64_t fnv1a_hash(const char *data, size_t len);
130130

131131
void *cloudsync_memory_zeroalloc (uint64_t size);
132+
char *cloudsync_string_ndup (const char *str, size_t len, bool lowercase);
132133
char *cloudsync_string_dup (const char *str, bool lowercase);
134+
int cloudsync_blob_compare(const char *blob1, size_t size1, const char *blob2, size_t size2);
135+
133136
void cloudsync_rowid_decode (sqlite3_int64 rowid, sqlite3_int64 *db_version, sqlite3_int64 *seq);
134137

135138
#endif

0 commit comments

Comments
 (0)