Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 37 additions & 58 deletions cgosqlite/cgosqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ import (
"github.com/tailscale/sqlite/sqliteh"
)

var emptyStrPtr = (*C.char)(unsafe.Pointer(unsafe.StringData("")))

func init() {
C.sqlite3_initialize()
}
Expand All @@ -68,31 +70,10 @@ type DB struct {
declTypes map[string]string
}

// cStmt is a wrapper around an sqlite3 *sqlite3_stmt. Except rather than
// storing it as a pointer, it's stored as uintptr to avoid allocations due to
// poor interactions between cgo's pointer checker and Go's escape analysis.
//
// The ptr method returns the value as a pointer, for call sites that haven't
// yet been optimized or don't need the optimization. This lets us migrate
// incrementally.
//
// See http://go/corp/9919.
type cStmt struct {
v C.handle_sqlite3_stmt
}

// cStmtFromPtr returns a cStmt from a C pointer.
func cStmtFromPtr(p *C.sqlite3_stmt) cStmt {
return cStmt{v: C.handle_sqlite3_stmt(uintptr(unsafe.Pointer(p)))}
}

func (h cStmt) int() C.handle_sqlite3_stmt { return h.v }
func (h cStmt) ptr() *C.sqlite3_stmt { return (*C.sqlite3_stmt)(unsafe.Pointer(uintptr(h.v))) }

// Stmt implements sqliteh.Stmt.
type Stmt struct {
db *DB
stmt cStmt
stmt *C.sqlite3_stmt
start C.struct_timespec

// used as scratch space when calling into cgo
Expand Down Expand Up @@ -202,7 +183,7 @@ func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqlite
return nil, "", err
}
remainingQuery = query[len(query)-int(C.strlen(csqlTail)):]
return &Stmt{db: db, stmt: cStmtFromPtr(cstmt)}, remainingQuery, nil
return &Stmt{db: db, stmt: cstmt}, remainingQuery, nil
}

func (db *DB) DisableFunction(name string, numArgs int) error {
Expand All @@ -212,45 +193,45 @@ func (db *DB) DisableFunction(name string, numArgs int) error {
}

func (stmt *Stmt) DBHandle() sqliteh.DB {
cdb := C.sqlite3_db_handle(stmt.stmt.ptr())
cdb := C.sqlite3_db_handle(stmt.stmt)
if cdb != nil {
return &DB{db: cdb}
}
return nil
}

func (stmt *Stmt) SQL() string {
return C.GoString(C.sqlite3_sql(stmt.stmt.ptr()))
return C.GoString(C.sqlite3_sql(stmt.stmt))
}

func (stmt *Stmt) ExpandedSQL() string {
// sqlite3_expanded_sql returns a string obtained by sqlite3_malloc, which
// must be freed after use.
cstr := C.sqlite3_expanded_sql(stmt.stmt.ptr())
cstr := C.sqlite3_expanded_sql(stmt.stmt)
defer C.sqlite3_free(unsafe.Pointer(cstr))
return C.GoString(cstr)
}

func (stmt *Stmt) Reset() error {
return errCode(C.sqlite3_reset(stmt.stmt.ptr()))
return errCode(C.sqlite3_reset(stmt.stmt))
}

func (stmt *Stmt) Finalize() error {
return errCode(C.sqlite3_finalize(stmt.stmt.ptr()))
return errCode(C.sqlite3_finalize(stmt.stmt))
}

func (stmt *Stmt) ClearBindings() error {
return errCode(C.sqlite3_clear_bindings(stmt.stmt.ptr()))
return errCode(C.sqlite3_clear_bindings(stmt.stmt))
}

func (stmt *Stmt) ResetAndClear() (time.Duration, error) {
if stmt.start != (C.struct_timespec{}) {
stmt.duration = 0
err := errCode(C.reset_and_clear(stmt.stmt.int(), &stmt.start, &stmt.duration))
err := errCode(C.reset_and_clear(stmt.stmt, &stmt.start, &stmt.duration))
return time.Duration(stmt.duration), err
}
if sp := stmt.stmt.int(); sp != 0 {
return 0, errCode(C.reset_and_clear(stmt.stmt.int(), nil, nil))
if stmt.stmt != nil {
return 0, errCode(C.reset_and_clear(stmt.stmt, nil, nil))
}
// The statement was never initialized. This can happen if, for example, the
// parser found only comments (so the statement was not empty, but did not
Expand All @@ -263,19 +244,19 @@ func (stmt *Stmt) StartTimer() {
}

func (stmt *Stmt) ColumnDatabaseName(col int) string {
return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_database_name(stmt.stmt.ptr(), C.int(col)))))
return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_database_name(stmt.stmt, C.int(col)))))
}

func (stmt *Stmt) ColumnTableName(col int) string {
return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_table_name(stmt.stmt.ptr(), C.int(col)))))
return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_table_name(stmt.stmt, C.int(col)))))
}

func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) {
var ptr *C.char
if len(colType) > 0 {
ptr = (*C.char)(unsafe.Pointer(&colType[0]))
}
res := C.ts_sqlite3_step(stmt.stmt.int(), ptr, C.int(len(colType)))
res := C.ts_sqlite3_step(stmt.stmt, ptr, C.int(len(colType)))
switch res {
case C.SQLITE_ROW:
return true, nil
Expand All @@ -288,7 +269,7 @@ func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) {

func (stmt *Stmt) StepResult() (row bool, lastInsertRowID, changes int64, d time.Duration, err error) {
stmt.rowid, stmt.changes, stmt.duration = 0, 0, 0
res := C.step_result(stmt.stmt.int(), &stmt.rowid, &stmt.changes, &stmt.duration)
res := C.step_result(stmt.stmt, &stmt.rowid, &stmt.changes, &stmt.duration)
lastInsertRowID = int64(stmt.rowid)
changes = int64(stmt.changes)
d = time.Duration(stmt.duration)
Expand All @@ -304,51 +285,51 @@ func (stmt *Stmt) StepResult() (row bool, lastInsertRowID, changes int64, d time
}

func (stmt *Stmt) BindDouble(col int, val float64) error {
return errCode(C.ts_sqlite3_bind_double(stmt.stmt.int(), C.int(col), C.double(val)))
return errCode(C.sqlite3_bind_double(stmt.stmt, C.int(col), C.double(val)))
}

func (stmt *Stmt) BindInt64(col int, val int64) error {
return errCode(C.ts_sqlite3_bind_int64(stmt.stmt.int(), C.int(col), C.sqlite3_int64(val)))
return errCode(C.sqlite3_bind_int64(stmt.stmt, C.int(col), C.sqlite3_int64(val)))
}

func (stmt *Stmt) BindNull(col int) error {
return errCode(C.ts_sqlite3_bind_null(stmt.stmt.int(), C.int(col)))
return errCode(C.sqlite3_bind_null(stmt.stmt, C.int(col)))
}

func (stmt *Stmt) BindText64(col int, val string) error {
if len(val) == 0 {
return errCode(C.bind_text64_empty(stmt.stmt.int(), C.int(col)))
return errCode(C.sqlite3_bind_text64(stmt.stmt, C.int(col), emptyStrPtr, 0, C.SQLITE_STATIC, C.SQLITE_UTF8))
}
v := C.CString(val) // freed by sqlite
return errCode(C.bind_text64(stmt.stmt.int(), C.int(col), v, C.sqlite3_uint64(len(val))))
return errCode(C.sqlite3_bind_text64(stmt.stmt, C.int(col), v, C.sqlite3_uint64(len(val)), (*[0]byte)(C.free), C.SQLITE_UTF8))
}

func (stmt *Stmt) BindZeroBlob64(col int, n uint64) error {
return errCode(C.sqlite3_bind_zeroblob64(stmt.stmt.ptr(), C.int(col), C.sqlite3_uint64(n)))
return errCode(C.sqlite3_bind_zeroblob64(stmt.stmt, C.int(col), C.sqlite3_uint64(n)))
}

func (stmt *Stmt) BindBlob64(col int, val []byte) error {
var str *C.char
if len(val) > 0 {
str = (*C.char)(unsafe.Pointer(&val[0]))
}
return errCode(C.bind_blob64(stmt.stmt.int(), C.int(col), str, C.sqlite3_uint64(len(val))))
return errCode(C.sqlite3_bind_blob64(stmt.stmt, C.int(col), unsafe.Pointer(str), C.sqlite3_uint64(len(val)), C.SQLITE_TRANSIENT))
}

func (stmt *Stmt) BindParameterCount() int {
return int(C.sqlite3_bind_parameter_count(stmt.stmt.ptr()))
return int(C.sqlite3_bind_parameter_count(stmt.stmt))
}

func (stmt *Stmt) BindParameterName(col int) string {
cstr := C.sqlite3_bind_parameter_name(stmt.stmt.ptr(), C.int(col))
cstr := C.sqlite3_bind_parameter_name(stmt.stmt, C.int(col))
if cstr == nil {
return ""
}
return C.GoString(cstr)
}

func (stmt *Stmt) BindParameterIndex(name string) int {
return int(C.bind_parameter_index(stmt.stmt.int(), name))
return int(C.bind_parameter_index(stmt.stmt, name))
}

func (stmt *Stmt) BindParameterIndexSearch(name string) int {
Expand All @@ -363,45 +344,45 @@ func (stmt *Stmt) BindParameterIndexSearch(name string) int {
}

func (stmt *Stmt) ColumnCount() int {
return int(C.sqlite3_column_count(stmt.stmt.ptr()))
return int(C.sqlite3_column_count(stmt.stmt))
}

func (stmt *Stmt) ColumnName(col int) string {
return C.GoString(C.sqlite3_column_name(stmt.stmt.ptr(), C.int(col)))
return C.GoString(C.sqlite3_column_name(stmt.stmt, C.int(col)))
}

func (stmt *Stmt) ColumnText(col int) string {
str := (*C.char)(unsafe.Pointer(C.ts_sqlite3_column_text(stmt.stmt.int(), C.int(col))))
n := C.ts_sqlite3_column_bytes(stmt.stmt.int(), C.int(col))
str := (*C.char)(unsafe.Pointer(C.sqlite3_column_text(stmt.stmt, C.int(col))))
n := C.sqlite3_column_bytes(stmt.stmt, C.int(col))
if str == nil || n == 0 {
return ""
}
return C.GoStringN(str, n)
}

func (stmt *Stmt) ColumnBlob(col int) []byte {
res := C.ts_sqlite3_column_blob(stmt.stmt.int(), C.int(col))
res := C.sqlite3_column_blob(stmt.stmt, C.int(col))
if res == nil {
return nil
}
n := int(C.ts_sqlite3_column_bytes(stmt.stmt.int(), C.int(col)))
n := int(C.sqlite3_column_bytes(stmt.stmt, C.int(col)))
return unsafe.Slice((*byte)(unsafe.Pointer(res)), n)
}

func (stmt *Stmt) ColumnDouble(col int) float64 {
return float64(C.ts_sqlite3_column_double(stmt.stmt.int(), C.int(col)))
return float64(C.sqlite3_column_double(stmt.stmt, C.int(col)))
}

func (stmt *Stmt) ColumnInt64(col int) int64 {
return int64(C.ts_sqlite3_column_int64(stmt.stmt.int(), C.int(col)))
return int64(C.sqlite3_column_int64(stmt.stmt, C.int(col)))
}

func (stmt *Stmt) ColumnType(col int) sqliteh.ColumnType {
return sqliteh.ColumnType(C.ts_sqlite3_column_type(stmt.stmt.int(), C.int(col)))
return sqliteh.ColumnType(C.sqlite3_column_type(stmt.stmt, C.int(col)))
}

func (stmt *Stmt) ColumnDeclType(col int) string {
cstr := C.sqlite3_column_decltype(stmt.stmt.ptr(), C.int(col))
cstr := C.sqlite3_column_decltype(stmt.stmt, C.int(col))
if cstr == nil {
return ""
}
Expand All @@ -418,8 +399,6 @@ func (stmt *Stmt) ColumnDeclType(col int) string {
return res
}

var emptyCStr = C.CString("")

func errCode(code C.int) error { return sqliteh.CodeAsError(sqliteh.Code(code)) }

// internCache contains interned strings.
Expand Down
70 changes: 6 additions & 64 deletions cgosqlite/cgosqlite.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,41 +5,10 @@
#include <string.h>
#include <time.h>

// uintptr versions of sqlite3 pointer types, to avoid allocations
// in cgo code. (go/corp/9919)
typedef uintptr_t handle_sqlite3_stmt; // a *sqlite3_stmt
typedef uintptr_t handle_sqlite3; // a *sqlite3 (DB conn)

// Helper methods to deal with int <-> pointer pain.

static int bind_text64(handle_sqlite3_stmt stmt, int col, const char* str, sqlite3_uint64 len) {
return sqlite3_bind_text64((sqlite3_stmt*)(stmt), col, str, len, free, SQLITE_UTF8);
}

static int bind_text64_empty(handle_sqlite3_stmt stmt, int col) {
return sqlite3_bind_text64((sqlite3_stmt*)(stmt), col, "", 0, SQLITE_STATIC, SQLITE_UTF8);
}

static int bind_blob64(handle_sqlite3_stmt stmt, int col, char* str, sqlite3_uint64 n) {
return sqlite3_bind_blob64((sqlite3_stmt*)(stmt), col, str, n, SQLITE_TRANSIENT);
}

static int ts_sqlite3_bind_double(handle_sqlite3_stmt stmt, int col, double v) {
return sqlite3_bind_double((sqlite3_stmt*)(stmt), col, v);
}

static int ts_sqlite3_bind_int64(handle_sqlite3_stmt stmt, int col, sqlite3_int64 v) {
return sqlite3_bind_int64((sqlite3_stmt*)(stmt), col, v);
}

static int ts_sqlite3_bind_null(handle_sqlite3_stmt stmt, int col) {
return sqlite3_bind_null((sqlite3_stmt*)(stmt), col);
}

// We only need the Go string's memory for the duration of the call,
// and the GC pins it for us if we pass the gostring_t to C, so we
// do the conversion here instead of with C.CString.
static int bind_parameter_index(handle_sqlite3_stmt stmt, _GoString_ s) {
static int bind_parameter_index(sqlite3_stmt* stmt, _GoString_ s) {
size_t n = _GoStringLen(s);
const char *p = (const char *)_GoStringPtr(s);

Expand All @@ -49,7 +18,7 @@ static int bind_parameter_index(handle_sqlite3_stmt stmt, _GoString_ s) {
return 0;
}
memmove(zName, p, n);
return sqlite3_bind_parameter_index((sqlite3_stmt*)(stmt), zName);
return sqlite3_bind_parameter_index(stmt, zName);
}

static void monotonic_clock_gettime(struct timespec* t) {
Expand All @@ -65,8 +34,7 @@ static int64_t ns_since(const struct timespec t1)
}

// step_result combines several cgo calls to save overhead.
static int step_result(handle_sqlite3_stmt stmth, sqlite3_int64* rowid, sqlite3_int64* changes, int64_t* duration_ns) {
sqlite3_stmt* stmt = (sqlite3_stmt*)(stmth);
static int step_result(sqlite3_stmt* stmt, sqlite3_int64* rowid, sqlite3_int64* changes, int64_t* duration_ns) {
struct timespec t1;
if (duration_ns) {
monotonic_clock_gettime(&t1);
Expand All @@ -84,8 +52,7 @@ static int step_result(handle_sqlite3_stmt stmth, sqlite3_int64* rowid, sqlite3_
}

// reset_and_clear combines two cgo calls to save overhead.
static int reset_and_clear(handle_sqlite3_stmt stmth, struct timespec* start, int64_t* duration_ns) {
sqlite3_stmt* stmt = (sqlite3_stmt*)(stmth);
static int reset_and_clear(sqlite3_stmt* stmt, struct timespec* start, int64_t* duration_ns) {
int ret = sqlite3_reset(stmt);
int ret2 = sqlite3_clear_bindings(stmt);
if (duration_ns) {
Expand All @@ -111,8 +78,7 @@ static void ts_sqlite3_wal_hook_go(sqlite3* db) {
sqlite3_wal_hook(db, wal_callback_into_go, 0);
}

static int ts_sqlite3_step(handle_sqlite3_stmt stmth, char* outType , int outTypeLen) {
sqlite3_stmt* stmt = (sqlite3_stmt*)(stmth);
static int ts_sqlite3_step(sqlite3_stmt* stmt, char* outType , int outTypeLen) {
int res = sqlite3_step(stmt);
if (res == SQLITE_ROW && outTypeLen > 0) {
int cols = sqlite3_column_count(stmt);
Expand All @@ -123,30 +89,6 @@ static int ts_sqlite3_step(handle_sqlite3_stmt stmth, char* outType , int outTyp
return res;
}

static const unsigned char *ts_sqlite3_column_text(handle_sqlite3_stmt stmt, int iCol) {
return sqlite3_column_text((sqlite3_stmt*)(stmt), iCol);
}

static const unsigned char *ts_sqlite3_column_blob(handle_sqlite3_stmt stmt, int iCol) {
return sqlite3_column_blob((sqlite3_stmt*)(stmt), iCol);
}

static int ts_sqlite3_column_type(handle_sqlite3_stmt stmt, int iCol) {
return sqlite3_column_type((sqlite3_stmt*)(stmt), iCol);
}

static int ts_sqlite3_column_bytes(handle_sqlite3_stmt stmt, int iCol) {
return sqlite3_column_bytes((sqlite3_stmt*)(stmt), iCol);
}

static double ts_sqlite3_column_double(handle_sqlite3_stmt stmt, int iCol) {
return sqlite3_column_double((sqlite3_stmt*)(stmt), iCol);
}

static sqlite3_int64 ts_sqlite3_column_int64(handle_sqlite3_stmt stmt, int iCol) {
return sqlite3_column_int64((sqlite3_stmt*)(stmt), iCol);
}

static int ts_sqlite3_disable_function(sqlite3 *db, const char *zFunctionName, int nArg) {
return sqlite3_create_function(db, zFunctionName, nArg, SQLITE_ANY, NULL, NULL, NULL, NULL);
}
}
Loading