Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion cgosqlite/cgosqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ package cgosqlite
import "C"
import (
"sync"
"sync/atomic"
"time"
"unsafe"

Expand All @@ -61,6 +62,20 @@ import (
// avoid the need to allocate new storage in each invocation.
var emptyChar [1]C.char

var alwaysCopyBlob atomic.Bool

// SetAlwaysCopyBlob sets whether [Stmt.ColumnBlob] should copy the blob data
// instead of returning a slice that aliases SQLite's internal memory. This is
// safe to call at runtime; the setting will apply to subsequent calls to
// [Stmt.ColumnBlob].
//
// This was added to help detect misuse of [sql.RawBytes] where we might be
// modifying data internal to SQLite, retaining it after it's no longer valid,
// and so on.
func SetAlwaysCopyBlob(copy bool) {
alwaysCopyBlob.Store(copy)
}

func init() {
C.sqlite3_initialize()
}
Expand Down Expand Up @@ -368,7 +383,11 @@ func (stmt *Stmt) ColumnBlob(col int) []byte {
return nil
}
n := int(C.sqlite3_column_bytes(stmt.stmt, C.int(col)))
return unsafe.Slice((*byte)(unsafe.Pointer(res)), n)
slice := unsafe.Slice((*byte)(unsafe.Pointer(res)), n)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm trying to remember/confirm what unsafe.Slice's behavior is when n == 0.

Maybe we don't want that pointer in the slice header even a 0 len and 0 cap?

Worth a if n == 0 { return nil } above this?

Oh, no, that'd change it to map to a NULL column, huh?

So I guess this is correct.

(sorry, thinking aloud)

if alwaysCopyBlob.Load() {
return append([]byte(nil), slice...)
}
return slice
}

func (stmt *Stmt) ColumnDouble(col int) float64 {
Expand Down
102 changes: 102 additions & 0 deletions cgosqlite/cgosqlite_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
package cgosqlite

import (
"bytes"
"path/filepath"
"testing"

"github.com/tailscale/sqlite/sqliteh"
)

func TestColumnBlob(t *testing.T) {
// Run the test with and without the SetAlwaysCopyBlob flag enabled.
cases := []struct {
name string
flag bool
}{
{"off", false},
{"on", true},
}
for _, tt := range cases {
t.Run("SetAlwaysCopyBlob="+tt.name, func(t *testing.T) {
SetAlwaysCopyBlob(tt.flag)

// Open a test database
db, err := Open(filepath.Join(t.TempDir(), "test.db"), sqliteh.OpenFlagsDefault, "")
if err != nil {
t.Fatal(err)
}
defer db.Close()

mustRun := func(sql string) {
t.Helper()
stmt, _, err := db.Prepare(sql, 0)
if err != nil {
t.Fatalf("Prepare %q: %v", sql, err)
}
if _, err := stmt.Step(nil); err != nil {
t.Fatalf("Step: %v", err)
}
if err := stmt.Finalize(); err != nil {
t.Fatalf("Finalize: %v", err)
}
}

mustRun("CREATE TABLE t (id INTEGER PRIMARY KEY, data BLOB)")
mustRun(`INSERT INTO t (id, data) VALUES (1, 'HELLOHELLOHELLOHELLOHELLOHELLO99')`)
mustRun(`INSERT INTO t (id, data) VALUES (2, '')`)
mustRun(`INSERT INTO t (id, data) VALUES (3, NULL)`)

// queryRow runs the given query and returns the *Stmt for the first row.
queryRow := func(t *testing.T, sql string) sqliteh.Stmt {
t.Helper()
stmt, _, err := db.Prepare(sql, 0)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
stmt.Finalize()
})
row, err := stmt.Step(nil)
if err != nil {
t.Fatal(err)
}
if !row {
t.Fatal("expected a row")
}
return stmt
}

t.Run("WithData", func(t *testing.T) {
stmt := queryRow(t, "SELECT data FROM t WHERE id = 1")
data := stmt.ColumnBlob(0)

const want = "HELLOHELLOHELLOHELLOHELLOHELLO99"
if !bytes.Equal(data, []byte(want)) {
t.Fatalf("got %q, want %q", data, want)
}
})

t.Run("EmptyBlob", func(t *testing.T) {
stmt := queryRow(t, "SELECT data FROM t WHERE id = 2")
data := stmt.ColumnBlob(0)
if len(data) != 0 {
t.Fatalf("got %d bytes, want 0 bytes", len(data))
}

// NOTE: it appears that this returns a nil
// slice, not a non-nil empty slice; both are
// valid representations of an empty blob, so
// we're not going to assert on which we get.
})

t.Run("NullBlob", func(t *testing.T) {
stmt := queryRow(t, "SELECT data FROM t WHERE id = 3")
data := stmt.ColumnBlob(0)
if data != nil {
t.Fatalf("got %q, want nil", data)
}
})
})
}
}
Loading