diff --git a/binary.go b/binary.go new file mode 100644 index 0000000..1a6311d --- /dev/null +++ b/binary.go @@ -0,0 +1,185 @@ +// Copyright (c) 2023 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sqlite + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "math" + "reflect" + "sync" + + "github.com/tailscale/sqlite/sqliteh" + "golang.org/x/sys/cpu" +) + +type driverConnRawCall struct { + f func(driverConn any) error + + // results + dc *conn + ok bool +} + +var driverConnRawCallPool = &sync.Pool{ + New: func() any { + c := new(driverConnRawCall) + c.f = func(driverConn any) error { + c.dc, c.ok = driverConn.(*conn) + return nil + } + return c + }, +} + +func getDriverConn(sc SQLConn) (dc *conn, ok bool) { + c := driverConnRawCallPool.Get().(*driverConnRawCall) + defer driverConnRawCallPool.Put(c) + err := sc.Raw(c.f) + if err != nil { + return nil, false + } + return c.dc, c.ok +} + +func QueryBinary(ctx context.Context, sqlconn SQLConn, optScratch []byte, query string, args ...any) (BinaryResults, error) { + c, ok := getDriverConn(sqlconn) + if !ok { + return nil, errors.New("sqlconn is not of expected type") + } + st, err := c.prepare(ctx, query, IsPersist(ctx)) + if err != nil { + return nil, err + } + buf := optScratch + if len(buf) == 0 { + buf = make([]byte, 128) + } + for { + st.stmt.ResetAndClear() + + // Bind args. + for colIdx, a := range args { + rv := reflect.ValueOf(a) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + if err := st.stmt.BindInt64(colIdx+1, rv.Int()); err != nil { + return nil, fmt.Errorf("binding col idx %d to %T (%v): %w", colIdx, a, rv.Int(), err) + } + default: + // TODO(bradfitz): more types, at least strings for stable IDs. 
+ return nil, fmt.Errorf("unsupported arg type %T", a) + } + } + + n, err := st.stmt.StepAllBinary(buf) + if err == nil { + return BinaryResults(buf[:n]), nil + } + if e, ok := err.(sqliteh.BufferSizeTooSmallError); ok { + buf = make([]byte, e.EncodedSize) + continue + } + return nil, err + } +} + +// BinaryResults is the result of QueryBinary. +// +// You should not depend on its specific format and parse it via its methods +// instead. +type BinaryResults []byte + +type BinaryToken struct { + StartRow bool + EndRow bool + EndRows bool + IsInt bool // if so, use Int() method + IsFloat bool // if so, use Float() method + IsNull bool + IsBytes bool + Error bool + + x uint64 + Bytes []byte +} + +func (t *BinaryToken) String() string { + switch { + case t.StartRow: + return "start-row" + case t.EndRow: + return "end-row" + case t.EndRows: + return "end-rows" + case t.IsNull: + return "null" + case t.IsInt: + return fmt.Sprintf("int: %v", t.Int()) + case t.IsFloat: + return fmt.Sprintf("float: %g", t.Float()) + case t.IsBytes: + return fmt.Sprintf("bytes: %q", t.Bytes) + case t.Error: + return "error" + default: + return "unknown" + } +} + +func (t *BinaryToken) Int() int64 { return int64(t.x) } +func (t *BinaryToken) Float() float64 { return math.Float64frombits(t.x) } + +func (r *BinaryResults) Next() BinaryToken { + if len(*r) == 0 { + return BinaryToken{Error: true} + } + first := (*r)[0] + *r = (*r)[1:] + switch first { + default: + return BinaryToken{Error: true} + case '(': + return BinaryToken{StartRow: true} + case ')': + return BinaryToken{EndRow: true} + case 'E': + return BinaryToken{EndRows: true} + case 'n': + return BinaryToken{IsNull: true} + case 'i', 'f': + if len(*r) < 8 { + return BinaryToken{Error: true} + } + t := BinaryToken{IsInt: first == 'i', IsFloat: first == 'f'} + if cpu.IsBigEndian { + t.x = binary.BigEndian.Uint64((*r)[:8]) + } else { + t.x = binary.LittleEndian.Uint64((*r)[:8]) + } + *r = (*r)[8:] + return t + case 'b': + if len(*r) < 8 
{ + return BinaryToken{Error: true} + } + t := BinaryToken{IsBytes: true} + var n int64 + if cpu.IsBigEndian { + n = int64(binary.BigEndian.Uint64((*r)[:8])) + } else { + n = int64(binary.LittleEndian.Uint64((*r)[:8])) + } + *r = (*r)[8:] + if int64(len(*r)) < n { + return BinaryToken{Error: true} + } + t.Bytes = (*r)[:n] + *r = (*r)[n:] + return t + } +} diff --git a/binary_test.go b/binary_test.go new file mode 100644 index 0000000..fb0d61b --- /dev/null +++ b/binary_test.go @@ -0,0 +1,110 @@ +// Copyright (c) 2023 Tailscale Inc & AUTHORS All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package sqlite + +import ( + "context" + "math" + "reflect" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestQueryBinary(t *testing.T) { + ctx := WithPersist(context.Background()) + db := openTestDB(t) + exec(t, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, f REAL, txt TEXT, blb BLOB)") + exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", math.MinInt64, 1.0, "text-a", "blob-a") + exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", -1, -1.0, "text-b", "blob-b") + exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 0, 0, "text-c", "blob-c") + exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 20, 2, "text-d", "blob-d") + exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", math.MaxInt64, nil, "text-e", "blob-e") + exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 42, 0.25, "text-f", nil) + exec(t, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 43, 1.75, "text-g", nil) + + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + buf, err := QueryBinary(ctx, conn, make([]byte, 100), "SELECT * FROM t ORDER BY id") + if err != nil { + t.Fatal(err) + } + t.Logf("Got %d bytes: %q", len(buf), buf) + + var got []string + iter := buf + for len(iter) > 0 { + t := iter.Next() + got = append(got, t.String()) + if t.Error { + break + } + } + want := []string{ + "start-row", "int: -9223372036854775808", "float: 1", 
"bytes: \"text-a\"", "bytes: \"blob-a\"", "end-row", + "start-row", "int: -1", "float: -1", "bytes: \"text-b\"", "bytes: \"blob-b\"", "end-row", + "start-row", "int: 0", "float: 0", "bytes: \"text-c\"", "bytes: \"blob-c\"", "end-row", + "start-row", "int: 20", "float: 2", "bytes: \"text-d\"", "bytes: \"blob-d\"", "end-row", + "start-row", "int: 42", "float: 0.25", "bytes: \"text-f\"", "null", "end-row", + "start-row", "int: 43", "float: 1.75", "bytes: \"text-g\"", "null", "end-row", + "start-row", "int: 9223372036854775807", "null", "bytes: \"text-e\"", "bytes: \"blob-e\"", "end-row", + "end-rows", + } + if !reflect.DeepEqual(got, want) { + t.Errorf("wrong results\n got: %q\nwant: %q\n\ndiff:\n%s", got, want, cmp.Diff(want, got)) + } + + allocs := int(testing.AllocsPerRun(10000, func() { + _, err := QueryBinary(ctx, conn, buf, "SELECT * FROM t") + if err != nil { + t.Fatal(err) + } + })) + const maxAllocs = 5 // as of Go 1.20 + if allocs > maxAllocs { + t.Errorf("allocs = %v; want max %v", allocs, maxAllocs) + } +} + +func BenchmarkQueryBinaryParallel(b *testing.B) { + ctx := WithPersist(context.Background()) + db := openTestDB(b) + exec(b, db, "CREATE TABLE t (id INTEGER PRIMARY KEY, f REAL, txt TEXT, blb BLOB)") + exec(b, db, "INSERT INTO t VALUES (?, ?, ?, ?)", 42, 0.25, "text-f", "some big big big big blob so big like so many bytes even") + + b.ResetTimer() + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + conn, err := db.Conn(ctx) + if err != nil { + b.Error(err) + return + } + + var buf = make([]byte, 250) + + for pb.Next() { + res, err := QueryBinary(ctx, conn, buf, "SELECT id, f, txt, blb FROM t WHERE id=?", 42) + if err != nil { + b.Error(err) + return + } + t := res.Next() + if !t.StartRow { + b.Errorf("didn't get start row; got %v", t) + return + } + t = res.Next() + if t.Int() != 42 { + b.Errorf("got %v; want 42", t) + return + } + } + }) + +} diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 2817384..38e6fce 100644 --- 
a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -54,6 +54,7 @@ package cgosqlite // #include "cgosqlite.h" import "C" import ( + "errors" "sync" "time" "unsafe" @@ -120,6 +121,7 @@ type Stmt struct { // used as scratch space when calling into cgo rowid, changes C.sqlite3_int64 duration C.int64_t + encodedSize C.int } // Open implements sqliteh.OpenFunc. @@ -420,6 +422,23 @@ func (stmt *Stmt) ColumnDeclType(col int) string { return res } +func (stmt *Stmt) StepAllBinary(dstBuf []byte) (n int, err error) { + if len(dstBuf) == 0 { + return 0, errors.New("zero-length buffer to StepAllBinary") + } + ret := C.ts_sqlite_step_all(stmt.stmt.int(), (*C.char)(unsafe.Pointer(&dstBuf[0])), C.int(len(dstBuf)), &stmt.encodedSize) + + if int(stmt.encodedSize) > len(dstBuf) { + return 0, sqliteh.BufferSizeTooSmallError{ + EncodedSize: int(stmt.encodedSize), + } + } + if err := errCode(ret); err != nil { + return 0, err + } + return int(stmt.encodedSize), nil +} + var emptyCStr = C.CString("") func errCode(code C.int) error { return sqliteh.CodeAsError(sqliteh.Code(code)) } diff --git a/cgosqlite/cgosqlite.h b/cgosqlite/cgosqlite.h index 0f9379c..f3d4e78 100644 --- a/cgosqlite/cgosqlite.h +++ b/cgosqlite/cgosqlite.h @@ -145,3 +145,93 @@ static double ts_sqlite3_column_double(handle_sqlite3_stmt stmt, int iCol) { static sqlite3_int64 ts_sqlite3_column_int64(handle_sqlite3_stmt stmt, int iCol) { return sqlite3_column_int64((sqlite3_stmt*)(stmt), iCol); } + +static void stepall_add_byte(char* bufBase, int bufLen, int* pos, int* encodedSize, char b) { + (*encodedSize)++; + if (*encodedSize > bufLen) { + return; + } + bufBase[*pos] = b; + (*pos)++; +} + +static void stepall_add_int64(char* bufBase, int bufLen, int* pos, int* encodedSize, sqlite3_int64 v) { + (*encodedSize) += 8; + if (*encodedSize > bufLen) { + return; + } + for (int i = 0; i < 8; i++) { + bufBase[*pos] = ((char*)&v)[i]; + (*pos)++; + } +} + +static void stepall_add_bytes(char* bufBase, int bufLen, int* pos, 
int* encodedSize, const char* v, int vlen) { + stepall_add_int64(bufBase, bufLen, pos, encodedSize, vlen); + + (*encodedSize) += vlen; + if (*encodedSize > bufLen) { + return; + } + memcpy(bufBase + *pos, v, vlen); + (*pos) += vlen; +} + +static void ts_sqlite_step_all_encode_row(sqlite3_stmt* stmt, char* bufBase, int bufLen, int* pos, int* encodedSize) { + stepall_add_byte(bufBase, bufLen, pos, encodedSize, '('); // start row + int cols = sqlite3_column_count(stmt); + sqlite3_int64 intVal; + double doubleVal; + + for (int col = 0; col < cols; col++) { + int colType = sqlite3_column_type(stmt, col); + switch (colType) { + case SQLITE_INTEGER: + stepall_add_byte(bufBase, bufLen, pos, encodedSize, 'i'); // i for "integer" + intVal = sqlite3_column_int64(stmt, col); + stepall_add_int64(bufBase, bufLen, pos, encodedSize, intVal); + break; + case SQLITE_FLOAT: + stepall_add_byte(bufBase, bufLen, pos, encodedSize, 'f'); // f for "float" + doubleVal = sqlite3_column_double(stmt, col); + stepall_add_int64(bufBase, bufLen, pos, encodedSize, *(sqlite3_int64*)(&doubleVal)); // ala math.Float64bits + break; + case SQLITE_NULL: + stepall_add_byte(bufBase, bufLen, pos, encodedSize, 'n'); // n for "null" + break; + case SQLITE_TEXT: + case SQLITE_BLOB: + stepall_add_byte(bufBase, bufLen, pos, encodedSize, 'b'); // b for "blob" (but also used for TEXT) + stepall_add_bytes(bufBase, bufLen, pos, encodedSize, + (char*) sqlite3_column_text(stmt, col), + sqlite3_column_bytes(stmt, col)); + break; + } + } + + stepall_add_byte(bufBase, bufLen, pos, encodedSize, ')'); // end row +} + +// encodedSize is initialized to zero and counts how much total space would be required, +// even if bufLen is too small. Only a max of bufLen bytes are written to bufBase. 
+static int ts_sqlite_step_all(handle_sqlite3_stmt stmth, char* bufBase, int bufLen, int* encodedSize) { + sqlite3_stmt* stmt = (sqlite3_stmt*)(stmth); + *encodedSize = 0; + if (bufLen < 1) { + return SQLITE_ERROR; + } + int pos = 0; + + while (1) { + int err = sqlite3_step(stmt); + if (err == SQLITE_DONE) { + stepall_add_byte(bufBase, bufLen, &pos, encodedSize, 'E' /* 'E' for End */); + return SQLITE_OK; + } + if (err == SQLITE_ROW) { + ts_sqlite_step_all_encode_row(stmt, bufBase, bufLen, &pos, encodedSize); + } else { + return err; + } + } +} diff --git a/go.mod b/go.mod index 2906a27..3dae4c9 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,8 @@ module github.com/tailscale/sqlite go 1.20 + +require ( + github.com/google/go-cmp v0.5.9 + golang.org/x/sys v0.6.0 +) diff --git a/sqlite.go b/sqlite.go index e049e3c..83e2e82 100644 --- a/sqlite.go +++ b/sqlite.go @@ -184,8 +184,7 @@ func (c *conn) Close() error { return reserr(c.db, "Conn.Close", "", c.db.Close()) } func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - persist := ctx.Value(persistQuery{}) != nil - return c.prepare(ctx, query, persist) + return c.prepare(ctx, query, IsPersist(ctx)) } func (c *conn) prepare(ctx context.Context, query string, persist bool) (s *stmt, err error) { @@ -897,5 +896,10 @@ func WithPersist(ctx context.Context) context.Context { return context.WithValue(ctx, persistQuery{}, persistQuery{}) } +// IsPersist reports whether the context has the Persist key. +func IsPersist(ctx context.Context) bool { + return ctx.Value(persistQuery{}) != nil +} + // persistQuery is used as a context value. 
type persistQuery struct{} diff --git a/sqlite_test.go b/sqlite_test.go index 2070b39..f9713ed 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -78,7 +78,7 @@ type execContexter interface { ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) } -func exec(t *testing.T, db execContexter, query string, args ...any) sql.Result { +func exec(t testing.TB, db execContexter, query string, args ...any) sql.Result { t.Helper() ctx := context.Background() res, err := db.ExecContext(ctx, query, args...) diff --git a/sqliteh/sqliteh.go b/sqliteh/sqliteh.go index 15cbd32..b83785c 100644 --- a/sqliteh/sqliteh.go +++ b/sqliteh/sqliteh.go @@ -7,6 +7,7 @@ package sqliteh import ( "context" + "fmt" "sync" "time" ) @@ -175,6 +176,26 @@ type Stmt interface { // ColumnTableName is sqlite3_column_table_name. // https://sqlite.org/c3ref/column_database_name.html ColumnTableName(col int) string + + // StepAllBinary reads all rows into dstBuf, binary packed. + + // + // It returns how much of dstBuf was populated. + // + // It returns an error of type BufferSizeTooSmallError if dstBuf is too + // small. That error's EncodedSize says how large a buffer would need to be. + // + // The binary encoding written to dstBuf is internal and unversioned; + // callers should verify (with tests) that the format hasn't + // changed since they updated their go.mod deps. + StepAllBinary(dstBuf []byte) (n int, err error) +} + +type BufferSizeTooSmallError struct { + EncodedSize int +} + +func (e BufferSizeTooSmallError) Error() string { + return fmt.Sprintf("buffer size too small; need %d bytes", e.EncodedSize) } // ColumnType are constants for each of the SQLite datatypes.