76 changes: 73 additions & 3 deletions cgosqlite/cgosqlite.go
@@ -50,6 +50,8 @@ package cgosqlite
*/
import "C"
import (
"bytes"
"runtime"
"sync"
"sync/atomic"
"time"
@@ -76,6 +78,27 @@ func SetAlwaysCopyBlob(copy bool) {
alwaysCopyBlob.Store(copy)
}

var columnBlobModifiedHook atomic.Pointer[func(query string)]

// SetColumnBlobModifiedHook sets a function to be called (in a new goroutine)
// whenever a []byte returned from [Stmt.ColumnBlob] is detected as modified.
// The hook receives the SQL query that created the statement.
//
// Setting a non-nil hook enables verification by attaching a cleanup function
// to each returned slice that compares the final contents against the
// original. Pass nil to disable.
//
// As a necessary side effect, this function causes [Stmt.ColumnBlob] to always
// copy the blob data, to ensure that the comparison in the cleanup function is
// valid, similar to SetAlwaysCopyBlob.
func SetColumnBlobModifiedHook(hook func(query string)) {
if hook == nil {
columnBlobModifiedHook.Store(nil)
} else {
columnBlobModifiedHook.Store(&hook)
}
}
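For illustration, a minimal sketch of how an application might install the hook at startup. The import path is inferred from this repository's layout, and the log.Printf wording is an assumption, not part of this change:

	package main

	import (
		"log"

		"github.com/tailscale/sqlite/cgosqlite"
	)

	func main() {
		// Report any caller that mutates a slice returned from Stmt.ColumnBlob.
		cgosqlite.SetColumnBlobModifiedHook(func(query string) {
			log.Printf("blob returned for query %q was modified by its caller", query)
		})
		// ... open the database and run queries as usual ...
	}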

func init() {
C.sqlite3_initialize()
}
@@ -92,6 +115,7 @@ type Stmt struct {
db *DB
stmt *C.sqlite3_stmt
start C.struct_timespec
query string // original query, stored for columnBlobModifiedHook

// used as scratch space when calling into cgo
rowid, changes C.sqlite3_int64
@@ -200,7 +224,7 @@ func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqlite
return nil, "", err
}
remainingQuery = query[len(query)-int(C.strlen(csqlTail)):]
return &Stmt{db: db, stmt: cstmt}, remainingQuery, nil
return &Stmt{db: db, stmt: cstmt, query: query}, remainingQuery, nil
}

func (db *DB) DisableFunction(name string, numArgs int) error {
@@ -377,16 +401,62 @@ func (stmt *Stmt) ColumnText(col int) string {
return C.GoStringN(str, n)
}

// blobCheckArg is the argument passed to the cleanup function for verifying
// that the slice returned from ColumnBlob was not modified.
//
// TODO: We use uintptr instead of []byte to avoid keeping the slice alive
// (which would prevent the cleanup from running); is that right?
type blobCheckArg struct {
original []byte // copy of original data
Review comment (Member): Or instead of this bytes.Clone, maybe store a https://pkg.go.dev/hash/fnv#New64a hash of the original data? (A sketch of this alternative follows the struct below.)

ptr uintptr // pointer to first byte of slice
len int // length of slice
query string // SQL query that produced the blob
hook func(query string) // hook to call if modified
}
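A rough sketch of the alternative suggested in the review comment above: store an FNV-64a digest of the original bytes instead of a full copy, and compare digests in the cleanup. The type and helper names below are hypothetical; the change as written keeps the bytes.Clone copy:

	// blobCheckArgHashed is a hypothetical variant of blobCheckArg that
	// stores an FNV-64a digest of the original bytes instead of a copy.
	type blobCheckArgHashed struct {
		originalSum uint64 // FNV-64a of the blob as handed to the caller
		ptr         uintptr
		len         int
		query       string
		hook        func(query string)
	}

	// fnvSum64a returns the FNV-64a hash of b.
	func fnvSum64a(b []byte) uint64 {
		h := fnv.New64a() // import "hash/fnv"
		h.Write(b)
		return h.Sum64()
	}

	// Inside the runtime.AddCleanup callback, the comparison would become:
	//
	//	current := unsafe.Slice((*byte)(unsafe.Pointer(a.ptr)), a.len)
	//	if fnvSum64a(current) != a.originalSum {
	//		go a.hook(a.query)
	//	}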

// blobCheckHook, if non-nil, is called after each blob check Cleanup function
// executes. This allows deterministic tests.
var blobCheckHook func()

func (stmt *Stmt) ColumnBlob(col int) []byte {
res := C.sqlite3_column_blob(stmt.stmt, C.int(col))
if res == nil {
return nil
}
n := int(C.sqlite3_column_bytes(stmt.stmt, C.int(col)))
slice := unsafe.Slice((*byte)(unsafe.Pointer(res)), n)
if alwaysCopyBlob.Load() {
return append([]byte(nil), slice...)

// In addition to copying if the alwaysCopyBlob flag is set, also copy
// if there is a columnBlobModifiedHook set. This is because a
// runtime.AddCleanup callback executes at some indeterminate time in
// the future, after the point at which SQLite might have reused the
// underlying memory. Copying now ensures that the comparison in the
// cleanup function is valid.
hookPtr := columnBlobModifiedHook.Load()
if alwaysCopyBlob.Load() || hookPtr != nil {
slice = append([]byte(nil), slice...)
}

if hookPtr != nil && n > 0 {
arg := blobCheckArg{
original: bytes.Clone(slice),
ptr: uintptr(unsafe.Pointer(&slice[0])),
len: n,
query: stmt.query,
hook: *hookPtr,
}
runtime.AddCleanup(&slice[0], func(a blobCheckArg) {
current := unsafe.Slice((*byte)(unsafe.Pointer(a.ptr)), a.len)
if !bytes.Equal(current, a.original) {
go a.hook(a.query)
}

if blobCheckHook != nil {
blobCheckHook()
}
}, arg)
}

return slice
}

245 changes: 204 additions & 41 deletions cgosqlite/cgosqlite_test.go
@@ -3,7 +3,11 @@ package cgosqlite
import (
"bytes"
"path/filepath"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"

"github.com/tailscale/sqlite/sqliteh"
)
@@ -28,47 +32,13 @@ func TestColumnBlob(t *testing.T) {
}
defer db.Close()

mustRun := func(sql string) {
t.Helper()
stmt, _, err := db.Prepare(sql, 0)
if err != nil {
t.Fatalf("Prepare %q: %v", sql, err)
}
if _, err := stmt.Step(nil); err != nil {
t.Fatalf("Step: %v", err)
}
if err := stmt.Finalize(); err != nil {
t.Fatalf("Finalize: %v", err)
}
}

mustRun("CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)")
mustRun(`INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`)
mustRun(`INSERT INTO t (id, data) VALUES (2, '')`)
mustRun(`INSERT INTO t (id, data) VALUES (3, NULL)`)

// queryRow runs the given query and returns the *Stmt for the first row.
queryRow := func(t *testing.T, sql string) sqliteh.Stmt {
t.Helper()
stmt, _, err := db.Prepare(sql, 0)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
stmt.Finalize()
})
row, err := stmt.Step(nil)
if err != nil {
t.Fatal(err)
}
if !row {
t.Fatal("expected a row")
}
return stmt
}
mustRun(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)")
mustRun(t, db, `INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`)
mustRun(t, db, `INSERT INTO t (id, data) VALUES (2, '')`)
mustRun(t, db, `INSERT INTO t (id, data) VALUES (3, NULL)`)

t.Run("WithData", func(t *testing.T) {
stmt := queryRow(t, "SELECT data FROM t WHERE id = 1")
stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 1")
data := stmt.ColumnBlob(0)

const want = "HELLOHELLOHELLOHELLOHELLOHELLO99"
@@ -78,7 +48,7 @@
})

t.Run("EmptyBlob", func(t *testing.T) {
stmt := queryRow(t, "SELECT data FROM t WHERE id = 2")
stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 2")
data := stmt.ColumnBlob(0)
if len(data) != 0 {
t.Fatalf("got %d bytes, want 0 bytes", len(data))
@@ -91,7 +61,7 @@
})

t.Run("NullBlob", func(t *testing.T) {
stmt := queryRow(t, "SELECT data FROM t WHERE id = 3")
stmt := queryRow(t, db, "SELECT data FROM t WHERE id = 3")
data := stmt.ColumnBlob(0)
if data != nil {
t.Fatalf("got %q, want nil", data)
@@ -100,3 +70,196 @@
})
}
}

func TestColumnBlobModifiedHook(t *testing.T) {
// Disable the "always copy blob" option to test just the hook behavior
SetAlwaysCopyBlob(false)

// Write to this channel every time a cleanup function executes, so we
// can ensure they've run.
checkRun := make(chan struct{}, 10_000) // high enough to never block
blobCheckHook = func() {
checkRun <- struct{}{}
Review comment (@bradfitz, Member, Jan 7, 2026):

select {
case checkRun <- struct{}{}:
default:
   panic("checkRun unexpectedly full")
}

}
t.Cleanup(func() {
blobCheckHook = nil
})

// waitForCleanup waits for one cleanup to run.
waitForCleanup := func() {
timedOut := time.After(10 * time.Second)
for {
runtime.GC()
runtime.Gosched()

select {
case <-checkRun:
return
case <-t.Context().Done():
t.Fatal("test context done while waiting for cleanup")
case <-timedOut:
t.Fatal("timeout waiting for cleanup")
case <-time.After(10 * time.Millisecond):
// retry
}
}
}

// Open a test database
db, err := Open(filepath.Join(t.TempDir(), "test.db"), sqliteh.OpenFlagsDefault, "")
if err != nil {
t.Fatal(err)
}
defer db.Close()

// Create a table with some blob data
mustRun(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)")

// Use a blob larger than 16 bytes to avoid the tiny object optimization, which
// can prevent cleanups from running (as mentioned in the documentation
// for [runtime.AddCleanup]).
mustRun(t, db, "INSERT INTO t (id, data) VALUES (1, CAST('HELLOHELLOHELLOHELLOHELLOHELLO99' AS BLOB))")

const testQuery = "SELECT data FROM t WHERE id = 1"

t.Run("UnmodifiedSliceDoesNotCallHook", func(t *testing.T) {
var hookCalls atomic.Int64
SetColumnBlobModifiedHook(func(query string) {
hookCalls.Add(1)
})
defer SetColumnBlobModifiedHook(nil)

func() {
stmt := queryRow(t, db, testQuery)
data := stmt.ColumnBlob(0)
if len(data) != 32 {
t.Fatalf("got len %d, want 32", len(data))
}

// Don't modify data, just let it go out of scope
runtime.KeepAlive(data)
}()

waitForCleanup()
if got := hookCalls.Load(); got != 0 {
t.Errorf("hook called %d times, want 0", got)
}
})

t.Run("ModifiedSliceCallsHook", func(t *testing.T) {
var (
hookCalls atomic.Int64
receivedQuery atomic.Pointer[string]

calledOnce sync.Once
called = make(chan struct{})
)
SetColumnBlobModifiedHook(func(query string) {
hookCalls.Add(1)
receivedQuery.Store(&query)
calledOnce.Do(func() { close(called) })
})
defer SetColumnBlobModifiedHook(nil)

func() {
stmt := queryRow(t, db, testQuery)
data := stmt.ColumnBlob(0)
if len(data) != 32 {
t.Fatalf("got len %d, want 32", len(data))
}

// Modify the data to trigger our hook.
data[0] = byte((int(data[0]) + 1) % 256)

runtime.KeepAlive(data)
}()

waitForCleanup()
<-called // need to synchronize separately since it's in another goroutine

if got := hookCalls.Load(); got != 1 {
t.Errorf("hook called %d times, want 1", got)
}
if q := receivedQuery.Load(); q == nil || *q != testQuery {
got := ""
if q != nil {
got = *q
}
t.Errorf("hook received query %q, want %q", got, testQuery)
}
})

t.Run("NilHook", func(t *testing.T) {
SetColumnBlobModifiedHook(nil)

// Ensure we start with an empty channel.
drain:
for {
select {
case <-checkRun:
default:
break drain
}
}

func() {
stmt := queryRow(t, db, testQuery)
data := stmt.ColumnBlob(0)
if len(data) != 32 {
t.Fatalf("got len %d, want 32", len(data))
}

data[0] = 'Y'

runtime.KeepAlive(data)
}()

// Spin for a bit to encourage any pending cleanups to run.
for i := 0; i < 10; i++ {
runtime.GC()
runtime.Gosched()
time.Sleep(10 * time.Millisecond)
}

// We expect nothing in the channel, as no hook is set.
select {
case <-checkRun:
t.Fatal("unexpected cleanup hook call")
default:
}
})
}

func mustRun(t *testing.T, db sqliteh.DB, sql string) {
t.Helper()
stmt, _, err := db.Prepare(sql, 0)
if err != nil {
t.Fatalf("Prepare %q: %v", sql, err)
}
if _, err := stmt.Step(nil); err != nil {
t.Fatalf("Step: %v", err)
}
if err := stmt.Finalize(); err != nil {
t.Fatalf("Finalize: %v", err)
}
}

// queryRow runs the given query and returns the *Stmt for the first row.
func queryRow(t *testing.T, db sqliteh.DB, sql string) sqliteh.Stmt {
t.Helper()
stmt, _, err := db.Prepare(sql, 0)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
stmt.Finalize()
})
row, err := stmt.Step(nil)
if err != nil {
t.Fatal(err)
}
if !row {
t.Fatal("expected a row")
}
return stmt
}