diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 04a1f14..58d44bb 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -50,6 +50,8 @@ package cgosqlite */ import "C" import ( + "bytes" + "runtime" "sync" "sync/atomic" "time" @@ -76,6 +78,27 @@ func SetAlwaysCopyBlob(copy bool) { alwaysCopyBlob.Store(copy) } +var columnBlobModifiedHook atomic.Pointer[func(query string)] + +// SetColumnBlobModifiedHook sets a function to be called (in a new goroutine) +// whenever a []byte returned from [Stmt.ColumnBlob] is detected as modified. +// The hook receives the SQL query that created the statement. +// +// Setting a non-nil hook enables verification by attaching a cleanup function +// to each returned slice that compares the final contents against the +// original. Pass nil to disable. +// +// As a necessary side effect, this function causes [Stmt.ColumnBlob] to always +// copy the blob data, to ensure that the comparison in the cleanup function is +// valid, similar to SetAlwaysCopyBlob. 
+func SetColumnBlobModifiedHook(hook func(query string)) { + if hook == nil { + columnBlobModifiedHook.Store(nil) + } else { + columnBlobModifiedHook.Store(&hook) + } +} + func init() { C.sqlite3_initialize() } @@ -92,6 +115,7 @@ type Stmt struct { db *DB stmt *C.sqlite3_stmt start C.struct_timespec + query string // original query, stored for columnBlobModifiedHook // used as scratch space when calling into cgo rowid, changes C.sqlite3_int64 @@ -200,7 +224,7 @@ func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqlite return nil, "", err } remainingQuery = query[len(query)-int(C.strlen(csqlTail)):] - return &Stmt{db: db, stmt: cstmt}, remainingQuery, nil + return &Stmt{db: db, stmt: cstmt, query: query}, remainingQuery, nil } func (db *DB) DisableFunction(name string, numArgs int) error { @@ -377,6 +401,23 @@ func (stmt *Stmt) ColumnText(col int) string { return C.GoStringN(str, n) } +// blobCheckArg is the argument passed to the cleanup function for verifying +// that the slice returned from ColumnBlob was not modified. +// +// TODO: We use uintptr instead of []byte to avoid keeping the slice alive +// (which would prevent the cleanup from running); is that right? NOTE(review): confirm the memory behind ptr is still valid to read when the cleanup runs, since the slice is unreachable by then. +type blobCheckArg struct { + original []byte // copy of original data + ptr uintptr // pointer to first byte of slice + len int // length of slice + query string // SQL query that produced the blob + hook func(query string) // hook to call if modified +} + +// blobCheckHook, if non-nil, is called after each blob check Cleanup function +// executes. This allows deterministic tests. +var blobCheckHook func() + func (stmt *Stmt) ColumnBlob(col int) []byte { res := C.sqlite3_column_blob(stmt.stmt, C.int(col)) if res == nil { @@ -384,9 +425,38 @@ func (stmt *Stmt) ColumnBlob(col int) []byte { } n := int(C.sqlite3_column_bytes(stmt.stmt, C.int(col))) slice := unsafe.Slice((*byte)(unsafe.Pointer(res)), n) - if alwaysCopyBlob.Load() { - return append([]byte(nil), slice...) 
+ + // In addition to copying if the alwaysCopyBlob flag is set, also copy + // if there is a columnBlobModifiedHook set. This is because a + // runtime.AddCleanup callback executes at some indeterminate time in + // the future, after the point at which SQLite might have reused the + // underlying memory. Copying now ensures that the comparison in the + // cleanup function is valid. + hookPtr := columnBlobModifiedHook.Load() + if alwaysCopyBlob.Load() || hookPtr != nil { + slice = append([]byte(nil), slice...) } + + if hookPtr != nil && n > 0 { + arg := blobCheckArg{ + original: bytes.Clone(slice), + ptr: uintptr(unsafe.Pointer(&slice[0])), + len: n, + query: stmt.query, + hook: *hookPtr, + } + runtime.AddCleanup(&slice[0], func(a blobCheckArg) { + current := unsafe.Slice((*byte)(unsafe.Pointer(a.ptr)), a.len) + if !bytes.Equal(current, a.original) { + go a.hook(a.query) + } + + if blobCheckHook != nil { + blobCheckHook() + } + }, arg) + } + return slice } diff --git a/cgosqlite/cgosqlite_test.go b/cgosqlite/cgosqlite_test.go index 9b4f94f..520b949 100644 --- a/cgosqlite/cgosqlite_test.go +++ b/cgosqlite/cgosqlite_test.go @@ -3,7 +3,11 @@ package cgosqlite import ( "bytes" "path/filepath" + "runtime" + "sync" + "sync/atomic" "testing" + "time" "github.com/tailscale/sqlite/sqliteh" ) @@ -28,47 +32,13 @@ func TestColumnBlob(t *testing.T) { } defer db.Close() - mustRun := func(sql string) { - t.Helper() - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - t.Fatalf("Prepare %q: %v", sql, err) - } - if _, err := stmt.Step(nil); err != nil { - t.Fatalf("Step: %v", err) - } - if err := stmt.Finalize(); err != nil { - t.Fatalf("Finalize: %v", err) - } - } - - mustRun("CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)") - mustRun(`INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`) - mustRun(`INSERT INTO t (id, data) VALUES (2, '')`) - mustRun(`INSERT INTO t (id, data) VALUES (3, NULL)`) - - // queryRow runs the given query and returns the *Stmt for 
the first row. - queryRow := func(t *testing.T, sql string) sqliteh.Stmt { - t.Helper() - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { - stmt.Finalize() - }) - row, err := stmt.Step(nil) - if err != nil { - t.Fatal(err) - } - if !row { - t.Fatal("expected a row") - } - return stmt - } + mustRun(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)") + mustRun(t, db, `INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`) + mustRun(t, db, `INSERT INTO t (id, data) VALUES (2, '')`) + mustRun(t, db, `INSERT INTO t (id, data) VALUES (3, NULL)`) t.Run("WithData", func(t *testing.T) { - stmt := queryRow(t, "SELECT data FROM t WHERE id = 1") + stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 1") data := stmt.ColumnBlob(0) const want = "HELLOHELLOHELLOHELLOHELLOHELLO99" @@ -78,7 +48,7 @@ func TestColumnBlob(t *testing.T) { }) t.Run("EmptyBlob", func(t *testing.T) { - stmt := queryRow(t, "SELECT data FROM t WHERE id = 2") + stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 2") data := stmt.ColumnBlob(0) if len(data) != 0 { t.Fatalf("got %d bytes, want 0 bytes", len(data)) @@ -91,7 +61,7 @@ func TestColumnBlob(t *testing.T) { }) t.Run("NullBlob", func(t *testing.T) { - stmt := queryRow(t, "SELECT data FROM t WHERE id = 3") + stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 3") data := stmt.ColumnBlob(0) if data != nil { t.Fatalf("got %q, want nil", data) @@ -100,3 +70,196 @@ func TestColumnBlob(t *testing.T) { }) } } + +func TestColumnBlobModifiedHook(t *testing.T) { + // Disable the "always copy blob" option to test just the hook behavior + SetAlwaysCopyBlob(false) + + // Write to this channel every time a cleanup function executes, so we + // can ensure they've run. 
+ checkRun := make(chan struct{}, 10_000) // high enough to never block + blobCheckHook = func() { + checkRun <- struct{}{} + } + t.Cleanup(func() { + blobCheckHook = nil + }) + + // waitForCleanup waits for one cleanup to run. + waitForCleanup := func() { + timedOut := time.After(10 * time.Second) + for { + runtime.GC() + runtime.Gosched() + + select { + case <-checkRun: + return + case <-t.Context().Done(): + t.Fatal("test context done while waiting for cleanup") + case <-timedOut: + t.Fatal("timeout waiting for cleanup") + case <-time.After(10 * time.Millisecond): + // retry + } + } + } + + // Open a test database + db, err := Open(filepath.Join(t.TempDir(), "test.db"), sqliteh.OpenFlagsDefault, "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + // Create a table with some blob data + mustRun(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)") + + // Use a blob larger than 16 bytes to avoid tiny object optimization which + // can prevent cleanups from running (as mentioned in the documentation + // for [runtime.AddCleanup]). 
+ mustRun(t, db, "INSERT INTO t (id, data) VALUES (1, CAST('HELLOHELLOHELLOHELLOHELLOHELLO99' AS BLOB))") + + const testQuery = "SELECT data FROM t WHERE id = 1" + + t.Run("UnmodifiedSliceDoesNotCallHook", func(t *testing.T) { + var hookCalls atomic.Int64 + SetColumnBlobModifiedHook(func(query string) { + hookCalls.Add(1) + }) + defer SetColumnBlobModifiedHook(nil) + + func() { + stmt := queryRow(t, db, testQuery) + data := stmt.ColumnBlob(0) + if len(data) != 32 { + t.Fatalf("got len %d, want 32", len(data)) + } + + // Don't modify data, just let it go out of scope + runtime.KeepAlive(data) + }() + + waitForCleanup() + if got := hookCalls.Load(); got != 0 { + t.Errorf("hook called %d times, want 0", got) + } + }) + + t.Run("ModifiedSliceCallsHook", func(t *testing.T) { + var ( + hookCalls atomic.Int64 + receivedQuery atomic.Pointer[string] + + calledOnce sync.Once + called = make(chan struct{}) + ) + SetColumnBlobModifiedHook(func(query string) { + hookCalls.Add(1) + receivedQuery.Store(&query) + calledOnce.Do(func() { close(called) }) + }) + defer SetColumnBlobModifiedHook(nil) + + func() { + stmt := queryRow(t, db, testQuery) + data := stmt.ColumnBlob(0) + if len(data) != 32 { + t.Fatalf("got len %d, want 32", len(data)) + } + + // Modify the data to trigger our hook. + data[0] = byte((int(data[0]) + 1) % 256) + + runtime.KeepAlive(data) + }() + + waitForCleanup() + <-called // need to synchronize separately since it's in another goroutine + + if got := hookCalls.Load(); got != 1 { + t.Errorf("hook called %d times, want 1", got) + } + if q := receivedQuery.Load(); q == nil || *q != testQuery { + got := "" + if q != nil { + got = *q + } + t.Errorf("hook received query %q, want %q", got, testQuery) + } + }) + + t.Run("NilHook", func(t *testing.T) { + SetColumnBlobModifiedHook(nil) + + // Ensure we start with an empty channel. 
+ drain: + for { + select { + case <-checkRun: + default: + break drain + } + } + + func() { + stmt := queryRow(t, db, testQuery) + data := stmt.ColumnBlob(0) + if len(data) != 32 { + t.Fatalf("got len %d, want 32", len(data)) + } + + data[0] = 'Y' + + runtime.KeepAlive(data) + }() + + // Spin for a bit to try and trigger any cleanups to be executed. + for i := 0; i < 10; i++ { + runtime.GC() + runtime.Gosched() + time.Sleep(10 * time.Millisecond) + } + + // We expect nothing in the channel, as no hook is set. + select { + case <-checkRun: + t.Fatal("unexpected cleanup hook call") + default: + } + }) +} + +func mustRun(t *testing.T, db sqliteh.DB, sql string) { + t.Helper() + stmt, _, err := db.Prepare(sql, 0) + if err != nil { + t.Fatalf("Prepare %q: %v", sql, err) + } + if _, err := stmt.Step(nil); err != nil { + t.Fatalf("Step: %v", err) + } + if err := stmt.Finalize(); err != nil { + t.Fatalf("Finalize: %v", err) + } +} + +// queryRow runs the given query and returns the *Stmt for the first row. +func queryRow(t *testing.T, db sqliteh.DB, sql string) sqliteh.Stmt { + t.Helper() + stmt, _, err := db.Prepare(sql, 0) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + stmt.Finalize() + }) + row, err := stmt.Step(nil) + if err != nil { + t.Fatal(err) + } + if !row { + t.Fatal("expected a row") + } + return stmt +} diff --git a/go.mod b/go.mod index f3ecd7d..3573c78 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,3 @@ module github.com/tailscale/sqlite -go 1.21 +go 1.24