diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 386b8df..8fc1b07 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -57,6 +57,8 @@ import ( "github.com/tailscale/sqlite/sqliteh" ) +var emptyStrPtr = (*C.char)(unsafe.Pointer(unsafe.StringData(""))) + func init() { C.sqlite3_initialize() } @@ -68,31 +70,10 @@ type DB struct { declTypes map[string]string } -// cStmt is a wrapper around an sqlite3 *sqlite3_stmt. Except rather than -// storing it as a pointer, it's stored as uintptr to avoid allocations due to -// poor interactions between cgo's pointer checker and Go's escape analysis. -// -// The ptr method returns the value as a pointer, for call sites that haven't -// yet been optimized or don't need the optimization. This lets us migrate -// incrementally. -// -// See http://go/corp/9919. -type cStmt struct { - v C.handle_sqlite3_stmt -} - -// cStmtFromPtr returns a cStmt from a C pointer. -func cStmtFromPtr(p *C.sqlite3_stmt) cStmt { - return cStmt{v: C.handle_sqlite3_stmt(uintptr(unsafe.Pointer(p)))} -} - -func (h cStmt) int() C.handle_sqlite3_stmt { return h.v } -func (h cStmt) ptr() *C.sqlite3_stmt { return (*C.sqlite3_stmt)(unsafe.Pointer(uintptr(h.v))) } - // Stmt implements sqliteh.Stmt. type Stmt struct { db *DB - stmt cStmt + stmt *C.sqlite3_stmt start C.struct_timespec // used as scratch space when calling into cgo @@ -202,7 +183,7 @@ func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqlite return nil, "", err } remainingQuery = query[len(query)-int(C.strlen(csqlTail)):] - return &Stmt{db: db, stmt: cStmtFromPtr(cstmt)}, remainingQuery, nil + return &Stmt{db: db, stmt: cstmt}, remainingQuery, nil } func (db *DB) DisableFunction(name string, numArgs int) error { @@ -212,7 +193,7 @@ func (db *DB) DisableFunction(name string, numArgs int) error { } func (stmt *Stmt) DBHandle() sqliteh.DB { - cdb := C.sqlite3_db_handle(stmt.stmt.ptr()) + cdb := C.sqlite3_db_handle(stmt.stmt) if cdb != nil { return &DB{db: cdb} } @@ -220,37 +201,37 @@ func (stmt *Stmt) DBHandle() sqliteh.DB { } func (stmt *Stmt) SQL() string { - return C.GoString(C.sqlite3_sql(stmt.stmt.ptr())) + return C.GoString(C.sqlite3_sql(stmt.stmt)) } func (stmt *Stmt) ExpandedSQL() string { // sqlite3_expanded_sql returns a string obtained by sqlite3_malloc, which // must be freed after use. - cstr := C.sqlite3_expanded_sql(stmt.stmt.ptr()) + cstr := C.sqlite3_expanded_sql(stmt.stmt) defer C.sqlite3_free(unsafe.Pointer(cstr)) return C.GoString(cstr) } func (stmt *Stmt) Reset() error { - return errCode(C.sqlite3_reset(stmt.stmt.ptr())) + return errCode(C.sqlite3_reset(stmt.stmt)) } func (stmt *Stmt) Finalize() error { - return errCode(C.sqlite3_finalize(stmt.stmt.ptr())) + return errCode(C.sqlite3_finalize(stmt.stmt)) } func (stmt *Stmt) ClearBindings() error { - return errCode(C.sqlite3_clear_bindings(stmt.stmt.ptr())) + return errCode(C.sqlite3_clear_bindings(stmt.stmt)) } func (stmt *Stmt) ResetAndClear() (time.Duration, error) { if stmt.start != (C.struct_timespec{}) { stmt.duration = 0 - err := errCode(C.reset_and_clear(stmt.stmt.int(), &stmt.start, &stmt.duration)) + err := errCode(C.reset_and_clear(stmt.stmt, &stmt.start, &stmt.duration)) return time.Duration(stmt.duration), err } - if sp := stmt.stmt.int(); sp != 0 { - return 0, errCode(C.reset_and_clear(stmt.stmt.int(), nil, nil)) + if stmt.stmt != nil { + return 0, errCode(C.reset_and_clear(stmt.stmt, nil, nil)) } // The statement was never initialized. This can happen if, for example, the // parser found only comments (so the statement was not empty, but did not @@ -263,11 +244,11 @@ func (stmt *Stmt) StartTimer() { } func (stmt *Stmt) ColumnDatabaseName(col int) string { - return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_database_name(stmt.stmt.ptr(), C.int(col))))) + return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_database_name(stmt.stmt, C.int(col))))) } func (stmt *Stmt) ColumnTableName(col int) string { - return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_table_name(stmt.stmt.ptr(), C.int(col))))) + return C.GoString((*C.char)(unsafe.Pointer(C.sqlite3_column_table_name(stmt.stmt, C.int(col))))) } func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) { @@ -275,7 +256,7 @@ func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) { if len(colType) > 0 { ptr = (*C.char)(unsafe.Pointer(&colType[0])) } - res := C.ts_sqlite3_step(stmt.stmt.int(), ptr, C.int(len(colType))) + res := C.ts_sqlite3_step(stmt.stmt, ptr, C.int(len(colType))) switch res { case C.SQLITE_ROW: return true, nil @@ -288,7 +269,7 @@ func (stmt *Stmt) Step(colType []sqliteh.ColumnType) (row bool, err error) { func (stmt *Stmt) StepResult() (row bool, lastInsertRowID, changes int64, d time.Duration, err error) { stmt.rowid, stmt.changes, stmt.duration = 0, 0, 0 - res := C.step_result(stmt.stmt.int(), &stmt.rowid, &stmt.changes, &stmt.duration) + res := C.step_result(stmt.stmt, &stmt.rowid, &stmt.changes, &stmt.duration) lastInsertRowID = int64(stmt.rowid) changes = int64(stmt.changes) d = time.Duration(stmt.duration) @@ -304,27 +285,27 @@ func (stmt *Stmt) StepResult() (row bool, lastInsertRowID, changes int64, d time } func (stmt *Stmt) BindDouble(col int, val float64) error { - return errCode(C.ts_sqlite3_bind_double(stmt.stmt.int(), C.int(col), C.double(val))) + return errCode(C.sqlite3_bind_double(stmt.stmt, C.int(col), C.double(val))) } func (stmt *Stmt) BindInt64(col int, val int64) error { - return errCode(C.ts_sqlite3_bind_int64(stmt.stmt.int(), C.int(col), C.sqlite3_int64(val))) + return errCode(C.sqlite3_bind_int64(stmt.stmt, C.int(col), C.sqlite3_int64(val))) } func (stmt *Stmt) BindNull(col int) error { - return errCode(C.ts_sqlite3_bind_null(stmt.stmt.int(), C.int(col))) + return errCode(C.sqlite3_bind_null(stmt.stmt, C.int(col))) } func (stmt *Stmt) BindText64(col int, val string) error { if len(val) == 0 { - return errCode(C.bind_text64_empty(stmt.stmt.int(), C.int(col))) + return errCode(C.sqlite3_bind_text64(stmt.stmt, C.int(col), emptyStrPtr, 0, C.SQLITE_STATIC, C.SQLITE_UTF8)) } v := C.CString(val) // freed by sqlite - return errCode(C.bind_text64(stmt.stmt.int(), C.int(col), v, C.sqlite3_uint64(len(val)))) + return errCode(C.sqlite3_bind_text64(stmt.stmt, C.int(col), v, C.sqlite3_uint64(len(val)), (*[0]byte)(C.free), C.SQLITE_UTF8)) } func (stmt *Stmt) BindZeroBlob64(col int, n uint64) error { - return errCode(C.sqlite3_bind_zeroblob64(stmt.stmt.ptr(), C.int(col), C.sqlite3_uint64(n))) + return errCode(C.sqlite3_bind_zeroblob64(stmt.stmt, C.int(col), C.sqlite3_uint64(n))) } func (stmt *Stmt) BindBlob64(col int, val []byte) error { @@ -332,15 +313,15 @@ func (stmt *Stmt) BindBlob64(col int, val []byte) error { if len(val) > 0 { str = (*C.char)(unsafe.Pointer(&val[0])) } - return errCode(C.bind_blob64(stmt.stmt.int(), C.int(col), str, C.sqlite3_uint64(len(val)))) + return errCode(C.sqlite3_bind_blob64(stmt.stmt, C.int(col), unsafe.Pointer(str), C.sqlite3_uint64(len(val)), C.SQLITE_TRANSIENT)) } func (stmt *Stmt) BindParameterCount() int { - return int(C.sqlite3_bind_parameter_count(stmt.stmt.ptr())) + return int(C.sqlite3_bind_parameter_count(stmt.stmt)) } func (stmt *Stmt) BindParameterName(col int) string { - cstr := C.sqlite3_bind_parameter_name(stmt.stmt.ptr(), C.int(col)) + cstr := C.sqlite3_bind_parameter_name(stmt.stmt, C.int(col)) if cstr == nil { return "" } @@ -348,7 +329,7 @@ func (stmt *Stmt) BindParameterName(col int) string { } func (stmt *Stmt) BindParameterIndex(name string) int { - return int(C.bind_parameter_index(stmt.stmt.int(), name)) + return int(C.bind_parameter_index(stmt.stmt, name)) } func (stmt *Stmt) BindParameterIndexSearch(name string) int { @@ -363,16 +344,16 @@ func (stmt *Stmt) BindParameterIndexSearch(name string) int { } func (stmt *Stmt) ColumnCount() int { - return int(C.sqlite3_column_count(stmt.stmt.ptr())) + return int(C.sqlite3_column_count(stmt.stmt)) } func (stmt *Stmt) ColumnName(col int) string { - return C.GoString(C.sqlite3_column_name(stmt.stmt.ptr(), C.int(col))) + return C.GoString(C.sqlite3_column_name(stmt.stmt, C.int(col))) } func (stmt *Stmt) ColumnText(col int) string { - str := (*C.char)(unsafe.Pointer(C.ts_sqlite3_column_text(stmt.stmt.int(), C.int(col)))) - n := C.ts_sqlite3_column_bytes(stmt.stmt.int(), C.int(col)) + str := (*C.char)(unsafe.Pointer(C.sqlite3_column_text(stmt.stmt, C.int(col)))) + n := C.sqlite3_column_bytes(stmt.stmt, C.int(col)) if str == nil || n == 0 { return "" } @@ -380,28 +361,28 @@ func (stmt *Stmt) ColumnText(col int) string { } func (stmt *Stmt) ColumnBlob(col int) []byte { - res := C.ts_sqlite3_column_blob(stmt.stmt.int(), C.int(col)) + res := C.sqlite3_column_blob(stmt.stmt, C.int(col)) if res == nil { return nil } - n := int(C.ts_sqlite3_column_bytes(stmt.stmt.int(), C.int(col))) + n := int(C.sqlite3_column_bytes(stmt.stmt, C.int(col))) return unsafe.Slice((*byte)(unsafe.Pointer(res)), n) } func (stmt *Stmt) ColumnDouble(col int) float64 { - return float64(C.ts_sqlite3_column_double(stmt.stmt.int(), C.int(col))) + return float64(C.sqlite3_column_double(stmt.stmt, C.int(col))) } func (stmt *Stmt) ColumnInt64(col int) int64 { - return int64(C.ts_sqlite3_column_int64(stmt.stmt.int(), C.int(col))) + return int64(C.sqlite3_column_int64(stmt.stmt, C.int(col))) } func (stmt *Stmt) ColumnType(col int) sqliteh.ColumnType { - return sqliteh.ColumnType(C.ts_sqlite3_column_type(stmt.stmt.int(), C.int(col))) + return sqliteh.ColumnType(C.sqlite3_column_type(stmt.stmt, C.int(col))) } func (stmt *Stmt) ColumnDeclType(col int) string { - cstr := C.sqlite3_column_decltype(stmt.stmt.ptr(), C.int(col)) + cstr := C.sqlite3_column_decltype(stmt.stmt, C.int(col)) if cstr == nil { return "" } @@ -418,8 +399,6 @@ func (stmt *Stmt) ColumnDeclType(col int) string { return res } -var emptyCStr = C.CString("") - func errCode(code C.int) error { return sqliteh.CodeAsError(sqliteh.Code(code)) } // internCache contains interned strings. diff --git a/cgosqlite/cgosqlite.h b/cgosqlite/cgosqlite.h index a58c380..5848b2a 100644 --- a/cgosqlite/cgosqlite.h +++ b/cgosqlite/cgosqlite.h @@ -5,41 +5,10 @@ #include #include -// uintptr versions of sqlite3 pointer types, to avoid allocations -// in cgo code. (go/corp/9919) -typedef uintptr_t handle_sqlite3_stmt; // a *sqlite3_stmt -typedef uintptr_t handle_sqlite3; // a *sqlite3 (DB conn) - -// Helper methods to deal with int <-> pointer pain. - -static int bind_text64(handle_sqlite3_stmt stmt, int col, const char* str, sqlite3_uint64 len) { - return sqlite3_bind_text64((sqlite3_stmt*)(stmt), col, str, len, free, SQLITE_UTF8); -} - -static int bind_text64_empty(handle_sqlite3_stmt stmt, int col) { - return sqlite3_bind_text64((sqlite3_stmt*)(stmt), col, "", 0, SQLITE_STATIC, SQLITE_UTF8); -} - -static int bind_blob64(handle_sqlite3_stmt stmt, int col, char* str, sqlite3_uint64 n) { - return sqlite3_bind_blob64((sqlite3_stmt*)(stmt), col, str, n, SQLITE_TRANSIENT); -} - -static int ts_sqlite3_bind_double(handle_sqlite3_stmt stmt, int col, double v) { - return sqlite3_bind_double((sqlite3_stmt*)(stmt), col, v); -} - -static int ts_sqlite3_bind_int64(handle_sqlite3_stmt stmt, int col, sqlite3_int64 v) { - return sqlite3_bind_int64((sqlite3_stmt*)(stmt), col, v); -} - -static int ts_sqlite3_bind_null(handle_sqlite3_stmt stmt, int col) { - return sqlite3_bind_null((sqlite3_stmt*)(stmt), col); -} - // We only need the Go string's memory for the duration of the call, // and the GC pins it for us if we pass the gostring_t to C, so we // do the conversion here instead of with C.CString. -static int bind_parameter_index(handle_sqlite3_stmt stmt, _GoString_ s) { +static int bind_parameter_index(sqlite3_stmt* stmt, _GoString_ s) { size_t n = _GoStringLen(s); const char *p = (const char *)_GoStringPtr(s); @@ -49,7 +18,7 @@ static int bind_parameter_index(handle_sqlite3_stmt stmt, _GoString_ s) { return 0; } memmove(zName, p, n); - return sqlite3_bind_parameter_index((sqlite3_stmt*)(stmt), zName); + return sqlite3_bind_parameter_index(stmt, zName); } static void monotonic_clock_gettime(struct timespec* t) { @@ -65,8 +34,7 @@ static int64_t ns_since(const struct timespec t1) } // step_result combines several cgo calls to save overhead. -static int step_result(handle_sqlite3_stmt stmth, sqlite3_int64* rowid, sqlite3_int64* changes, int64_t* duration_ns) { - sqlite3_stmt* stmt = (sqlite3_stmt*)(stmth); +static int step_result(sqlite3_stmt* stmt, sqlite3_int64* rowid, sqlite3_int64* changes, int64_t* duration_ns) { struct timespec t1; if (duration_ns) { monotonic_clock_gettime(&t1); @@ -84,8 +52,7 @@ static int step_result(handle_sqlite3_stmt stmth, sqlite3_int64* rowid, sqlite3_ } // reset_and_clear combines two cgo calls to save overhead. -static int reset_and_clear(handle_sqlite3_stmt stmth, struct timespec* start, int64_t* duration_ns) { - sqlite3_stmt* stmt = (sqlite3_stmt*)(stmth); +static int reset_and_clear(sqlite3_stmt* stmt, struct timespec* start, int64_t* duration_ns) { int ret = sqlite3_reset(stmt); int ret2 = sqlite3_clear_bindings(stmt); if (duration_ns) { @@ -111,8 +78,7 @@ static void ts_sqlite3_wal_hook_go(sqlite3* db) { sqlite3_wal_hook(db, wal_callback_into_go, 0); } -static int ts_sqlite3_step(handle_sqlite3_stmt stmth, char* outType , int outTypeLen) { - sqlite3_stmt* stmt = (sqlite3_stmt*)(stmth); +static int ts_sqlite3_step(sqlite3_stmt* stmt, char* outType , int outTypeLen) { int res = sqlite3_step(stmt); if (res == SQLITE_ROW && outTypeLen > 0) { int cols = sqlite3_column_count(stmt); @@ -123,30 +89,6 @@ static int ts_sqlite3_step(handle_sqlite3_stmt stmth, char* outType , int outTyp return res; } -static const unsigned char *ts_sqlite3_column_text(handle_sqlite3_stmt stmt, int iCol) { - return sqlite3_column_text((sqlite3_stmt*)(stmt), iCol); -} - -static const unsigned char *ts_sqlite3_column_blob(handle_sqlite3_stmt stmt, int iCol) { - return sqlite3_column_blob((sqlite3_stmt*)(stmt), iCol); -} - -static int ts_sqlite3_column_type(handle_sqlite3_stmt stmt, int iCol) { - return sqlite3_column_type((sqlite3_stmt*)(stmt), iCol); -} - -static int ts_sqlite3_column_bytes(handle_sqlite3_stmt stmt, int iCol) { - return sqlite3_column_bytes((sqlite3_stmt*)(stmt), iCol); -} - -static double ts_sqlite3_column_double(handle_sqlite3_stmt stmt, int iCol) { - return sqlite3_column_double((sqlite3_stmt*)(stmt), iCol); -} - -static sqlite3_int64 ts_sqlite3_column_int64(handle_sqlite3_stmt stmt, int iCol) { - return sqlite3_column_int64((sqlite3_stmt*)(stmt), iCol); -} - static int ts_sqlite3_disable_function(sqlite3 *db, const char *zFunctionName, int nArg) { return sqlite3_create_function(db, zFunctionName, nArg, SQLITE_ANY, NULL, NULL, NULL, NULL); -} \ No newline at end of file +}