From ca57704edbd2ff51339570e9f8fb3387d558d9f7 Mon Sep 17 00:00:00 2001 From: Andrew Dunham Date: Wed, 7 Jan 2026 11:45:50 -0500 Subject: [PATCH] cgosqlite: add feature flag to always copy []bytes returned When enabled, this ensures that sql.RawBytes values do not actually contain pointers to memory owned by SQLite. Also add some tests that verify the behaviour of ColumnBlob, both with and without this flag set. Updates tailscale/corp#35671 Signed-off-by: Andrew Dunham --- cgosqlite/cgosqlite.go | 21 +++++++- cgosqlite/cgosqlite_test.go | 102 ++++++++++++++++++++++++++++++++++++ 2 files changed, 122 insertions(+), 1 deletion(-) create mode 100644 cgosqlite/cgosqlite_test.go diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index eff9a12..04a1f14 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -51,6 +51,7 @@ package cgosqlite import "C" import ( "sync" + "sync/atomic" "time" "unsafe" @@ -61,6 +62,20 @@ import ( // avoid the need to allocate new storage in each invocation. var emptyChar [1]C.char +var alwaysCopyBlob atomic.Bool + +// SetAlwaysCopyBlob sets whether [Stmt.ColumnBlob] should copy the blob data +// instead of returning a slice that aliases SQLite's internal memory. This is +// safe to call at runtime; the setting will apply to subsequent calls to +// [Stmt.ColumnBlob]. +// +// This was added to help detect misuse of [sql.RawBytes] where we might be +// modifying data internal to SQLite, retaining it after it's no longer valid, +// and so on. +func SetAlwaysCopyBlob(copy bool) { + alwaysCopyBlob.Store(copy) +} + func init() { C.sqlite3_initialize() } @@ -368,7 +383,11 @@ func (stmt *Stmt) ColumnBlob(col int) []byte { return nil } n := int(C.sqlite3_column_bytes(stmt.stmt, C.int(col))) - return unsafe.Slice((*byte)(unsafe.Pointer(res)), n) + slice := unsafe.Slice((*byte)(unsafe.Pointer(res)), n) + if alwaysCopyBlob.Load() { + return append([]byte(nil), slice...) + } + return slice } func (stmt *Stmt) ColumnDouble(col int) float64 { diff --git a/cgosqlite/cgosqlite_test.go b/cgosqlite/cgosqlite_test.go new file mode 100644 index 0000000..9b4f94f --- /dev/null +++ b/cgosqlite/cgosqlite_test.go @@ -0,0 +1,102 @@ +package cgosqlite + +import ( + "bytes" + "path/filepath" + "testing" + + "github.com/tailscale/sqlite/sqliteh" +) + +func TestColumnBlob(t *testing.T) { + // Run the test with and without the SetAlwaysCopyBlob flag enabled. + cases := []struct { + name string + flag bool + }{ + {"off", false}, + {"on", true}, + } + for _, tt := range cases { + t.Run("SetAlwaysCopyBlob="+tt.name, func(t *testing.T) { + SetAlwaysCopyBlob(tt.flag) + + // Open a test database + db, err := Open(filepath.Join(t.TempDir(), "test.db"), sqliteh.OpenFlagsDefault, "") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + mustRun := func(sql string) { + t.Helper() + stmt, _, err := db.Prepare(sql, 0) + if err != nil { + t.Fatalf("Prepare %q: %v", sql, err) + } + if _, err := stmt.Step(nil); err != nil { + t.Fatalf("Step: %v", err) + } + if err := stmt.Finalize(); err != nil { + t.Fatalf("Finalize: %v", err) + } + } + + mustRun("CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)") + mustRun(`INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`) + mustRun(`INSERT INTO t (id, data) VALUES (2, '')`) + mustRun(`INSERT INTO t (id, data) VALUES (3, NULL)`) + + // queryRow runs the given query and returns the *Stmt for the first row. + queryRow := func(t *testing.T, sql string) sqliteh.Stmt { + t.Helper() + stmt, _, err := db.Prepare(sql, 0) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { + stmt.Finalize() + }) + row, err := stmt.Step(nil) + if err != nil { + t.Fatal(err) + } + if !row { + t.Fatal("expected a row") + } + return stmt + } + + t.Run("WithData", func(t *testing.T) { + stmt := queryRow(t, "SELECT data FROM t WHERE id = 1") + data := stmt.ColumnBlob(0) + + const want = "HELLOHELLOHELLOHELLOHELLOHELLO99" + if !bytes.Equal(data, []byte(want)) { + t.Fatalf("got %q, want %q", data, want) + } + }) + + t.Run("EmptyBlob", func(t *testing.T) { + stmt := queryRow(t, "SELECT data FROM t WHERE id = 2") + data := stmt.ColumnBlob(0) + if len(data) != 0 { + t.Fatalf("got %d bytes, want 0 bytes", len(data)) + } + + // NOTE: it appears that this returns a nil + // slice, not a non-nil empty slice; both are + // valid representations of an empty blob, so + // we're not going to assert on which we get. + }) + + t.Run("NullBlob", func(t *testing.T) { + stmt := queryRow(t, "SELECT data FROM t WHERE id = 3") + data := stmt.ColumnBlob(0) + if data != nil { + t.Fatalf("got %q, want nil", data) + } + }) + }) + } +}