Skip to content

Commit ca57704

Browse files
committed
cgosqlite: add feature flag to always copy []bytes returned
When enabled, this ensures that sql.RawBytes values do not actually contain pointers to memory owned by SQLite. Also add some tests that verify the behaviour of ColumnBlob, both with and without this flag set. Updates tailscale/corp#35671 Signed-off-by: Andrew Dunham <[email protected]>
1 parent 8ac0a9c commit ca57704

File tree

2 files changed

+122
-1
lines changed

2 files changed

+122
-1
lines changed

cgosqlite/cgosqlite.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ package cgosqlite
5151
import "C"
5252
import (
5353
"sync"
54+
"sync/atomic"
5455
"time"
5556
"unsafe"
5657

@@ -61,6 +62,20 @@ import (
6162
// avoid the need to allocate new storage in each invocation.
6263
var emptyChar [1]C.char
6364

65+
var alwaysCopyBlob atomic.Bool
66+
67+
// SetAlwaysCopyBlob sets whether [Stmt.ColumnBlob] should copy the blob data
68+
// instead of returning a slice that aliases SQLite's internal memory. This is
69+
// safe to call at runtime; the setting will apply to subsequent calls to
70+
// [Stmt.ColumnBlob].
71+
//
72+
// This was added to help detect misuse of [sql.RawBytes] where we might be
73+
// modifying data internal to SQLite, retaining it after it's no longer valid,
74+
// and so on.
75+
func SetAlwaysCopyBlob(copy bool) {
76+
alwaysCopyBlob.Store(copy)
77+
}
78+
6479
func init() {
6580
C.sqlite3_initialize()
6681
}
@@ -368,7 +383,11 @@ func (stmt *Stmt) ColumnBlob(col int) []byte {
368383
return nil
369384
}
370385
n := int(C.sqlite3_column_bytes(stmt.stmt, C.int(col)))
371-
return unsafe.Slice((*byte)(unsafe.Pointer(res)), n)
386+
slice := unsafe.Slice((*byte)(unsafe.Pointer(res)), n)
387+
if alwaysCopyBlob.Load() {
388+
return append([]byte(nil), slice...)
389+
}
390+
return slice
372391
}
373392

374393
func (stmt *Stmt) ColumnDouble(col int) float64 {

cgosqlite/cgosqlite_test.go

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
package cgosqlite
2+
3+
import (
4+
"bytes"
5+
"path/filepath"
6+
"testing"
7+
8+
"github.com/tailscale/sqlite/sqliteh"
9+
)
10+
11+
func TestColumnBlob(t *testing.T) {
12+
// Run the test with and without the SetAlwaysCopyBlob flag enabled.
13+
cases := []struct {
14+
name string
15+
flag bool
16+
}{
17+
{"off", false},
18+
{"on", true},
19+
}
20+
for _, tt := range cases {
21+
t.Run("SetAlwaysCopyBlob="+tt.name, func(t *testing.T) {
22+
SetAlwaysCopyBlob(tt.flag)
23+
24+
// Open a test database
25+
db, err := Open(filepath.Join(t.TempDir(), "test.db"), sqliteh.OpenFlagsDefault, "")
26+
if err != nil {
27+
t.Fatal(err)
28+
}
29+
defer db.Close()
30+
31+
mustRun := func(sql string) {
32+
t.Helper()
33+
stmt, _, err := db.Prepare(sql, 0)
34+
if err != nil {
35+
t.Fatalf("Prepare %q: %v", sql, err)
36+
}
37+
if _, err := stmt.Step(nil); err != nil {
38+
t.Fatalf("Step: %v", err)
39+
}
40+
if err := stmt.Finalize(); err != nil {
41+
t.Fatalf("Finalize: %v", err)
42+
}
43+
}
44+
45+
mustRun("CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)")
46+
mustRun(`INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`)
47+
mustRun(`INSERT INTO t (id, data) VALUES (2, '')`)
48+
mustRun(`INSERT INTO t (id, data) VALUES (3, NULL)`)
49+
50+
// queryRow runs the given query and returns the *Stmt for the first row.
51+
queryRow := func(t *testing.T, sql string) sqliteh.Stmt {
52+
t.Helper()
53+
stmt, _, err := db.Prepare(sql, 0)
54+
if err != nil {
55+
t.Fatal(err)
56+
}
57+
t.Cleanup(func() {
58+
stmt.Finalize()
59+
})
60+
row, err := stmt.Step(nil)
61+
if err != nil {
62+
t.Fatal(err)
63+
}
64+
if !row {
65+
t.Fatal("expected a row")
66+
}
67+
return stmt
68+
}
69+
70+
t.Run("WithData", func(t *testing.T) {
71+
stmt := queryRow(t, "SELECT data FROM t WHERE id = 1")
72+
data := stmt.ColumnBlob(0)
73+
74+
const want = "HELLOHELLOHELLOHELLOHELLOHELLO99"
75+
if !bytes.Equal(data, []byte(want)) {
76+
t.Fatalf("got %q, want %q", data, want)
77+
}
78+
})
79+
80+
t.Run("EmptyBlob", func(t *testing.T) {
81+
stmt := queryRow(t, "SELECT data FROM t WHERE id = 2")
82+
data := stmt.ColumnBlob(0)
83+
if len(data) != 0 {
84+
t.Fatalf("got %d bytes, want 0 bytes", len(data))
85+
}
86+
87+
// NOTE: it appears that this returns a nil
88+
// slice, not a non-nil empty slice; both are
89+
// valid representations of an empty blob, so
90+
// we're not going to assert on which we get.
91+
})
92+
93+
t.Run("NullBlob", func(t *testing.T) {
94+
stmt := queryRow(t, "SELECT data FROM t WHERE id = 3")
95+
data := stmt.ColumnBlob(0)
96+
if data != nil {
97+
t.Fatalf("got %q, want nil", data)
98+
}
99+
})
100+
})
101+
}
102+
}

0 commit comments

Comments
 (0)