diff --git a/cgosqlite/cgosqlite.go b/cgosqlite/cgosqlite.go index 3fc1dd8..5e391c4 100644 --- a/cgosqlite/cgosqlite.go +++ b/cgosqlite/cgosqlite.go @@ -159,6 +159,56 @@ func (db *DB) TxnState(schema string) sqliteh.TxnState { return sqliteh.TxnState(C.sqlite3_txn_state(db.db, cSchema)) } +func (db *DB) BackupInit(dstSchema string, src sqliteh.DB, srcSchema string) (sqliteh.Backup, error) { + var cDstSchema, cSrcSchema *C.char + if dstSchema != "" { + cDstSchema = C.CString(dstSchema) + defer C.free(unsafe.Pointer(cDstSchema)) + } + if srcSchema != "" { + cSrcSchema = C.CString(srcSchema) + defer C.free(unsafe.Pointer(cSrcSchema)) + } + + b := C.sqlite3_backup_init(db.db, cDstSchema, src.(*DB).db, cSrcSchema) + if b == nil { + // sqlite3_backup_init docs tell us the error is on the dst DB. + return nil, sqliteh.ErrCode(db.ExtendedErrCode()) + } + return &backup{backup: b}, nil +} + +type backup struct { + backup *C.sqlite3_backup +} + +func (b *backup) Step(numPages int) (more bool, remaining, pageCount int, err error) { + res := C.sqlite3_backup_step(b.backup, C.int(numPages)) + + // It is not safe to call remaining and pagecount concurrently with step, so + // instead just return them each time. 
+ remaining = int(C.sqlite3_backup_remaining(b.backup)) + pageCount = int(C.sqlite3_backup_pagecount(b.backup)) + + more = true + switch res { + case C.SQLITE_OK, C.SQLITE_BUSY, C.SQLITE_LOCKED: + more = true + default: + more = false + } + + return more, remaining, pageCount, errCode(res) +} + +func (b *backup) Finish() error { + res := C.sqlite3_backup_finish(b.backup) + if res == C.SQLITE_OK { + return nil + } + return errCode(res) +} + func (db *DB) Prepare(query string, prepFlags sqliteh.PrepareFlags) (stmt sqliteh.Stmt, remainingQuery string, err error) { csql := C.CString(query) defer C.free(unsafe.Pointer(csql)) diff --git a/examples/backup/backup.go b/examples/backup/backup.go new file mode 100644 index 0000000..c60dc79 --- /dev/null +++ b/examples/backup/backup.go @@ -0,0 +1,181 @@ +package main + +import ( + "context" + "database/sql" + "database/sql/driver" + "log" + "runtime" + "sync" + "sync/atomic" + "time" + + "github.com/tailscale/sqlite" +) + +var ( + // The way that SQLite3 backups work is that they restart if the database is + // ever updated from a different context than the one backing up, however + // modifications made using the same context as the backup can be observed + // without restarting. The application must just ensure that the connection is + // either performing queries, or performing a backup step at any given time. + // The backup step size can be tuned by the application to appropriately + // share time between the writer and backup operations. 
+ mu sync.Mutex + conn *sql.Conn + + inserted atomic.Int64 + + walMode = func(ctx context.Context, conn driver.ConnPrepareContext) error { + return sqlite.ExecScript(conn.(sqlite.SQLConn), "PRAGMA journal_mode=WAL;") + } +) + +func main() { + ctx, cancelAndWait := withCancelWait(context.Background()) + db := sql.OpenDB(sqlite.Connector("file:/tmp/example.db", walMode, nil)) + defer db.Close() + + var err error + conn, err = db.Conn(context.Background()) + must(err) + defer conn.Close() + + must(initSchema(ctx)) + + go fill(ctx) + + log.Printf("sleeping for 10 seconds to populate the table") + time.Sleep(10 * time.Second) + log.Printf("inserted: %d", inserted.Load()) + + backup(ctx) + + cancelAndWait() +} + +func backup(ctx context.Context) { + bdb := sql.OpenDB(sqlite.Connector("file:/tmp/example-backup.db", walMode, nil)) + defer bdb.Close() + bConn, err := bdb.Conn(ctx) + must(err) + defer bConn.Close() + + log.Printf("backing up") + b, err := sqlite.NewBackup(bConn, "main", conn, "main") + must(err) + + var ( + more bool = true + remaining int + pageCount int + ) + + for more { + mu.Lock() + more, remaining, pageCount, err = b.Step(1024) + mu.Unlock() + if err != nil { + // fatal errors are returned by finish too + break + } + log.Printf("remaining=%5d pageCount=%5d (inserted: %5d)", remaining, pageCount, inserted.Load()) + time.Sleep(time.Millisecond) + } + log.Printf("backup steps done") + must(b.Finish()) + log.Printf("backup finished") +} + +func fill(ctx context.Context) { + defer done(ctx) + for alive(ctx) { + mu.Lock() + _, err := conn.ExecContext(ctx, "INSERT INTO foo (data) VALUES ('never gunna back you up, never gunna take you down, never gunna alter schema and hurt you');") + inserted.Add(1) + mu.Unlock() + must(err) + } +} + +func initSchema(ctx context.Context) error { + mu.Lock() + defer mu.Unlock() + _, err := conn.ExecContext(ctx, ` + CREATE TABLE IF NOT EXISTS foo ( + id INTEGER PRIMARY KEY, + data TEXT + ); + `) + return err +} + +func must(err 
error) { + _, file, no, _ := runtime.Caller(1) + if err != nil { + log.Fatalf("%s:%d %#v", file, no, err) + } +} + +var wgKey = &struct{}{} + +type waitCtx struct { + context.Context + wg *sync.WaitGroup +} + +func (c *waitCtx) Done() <-chan struct{} { + return c.Context.Done() +} + +func (c *waitCtx) Err() error { + return c.Context.Err() +} + +func (c *waitCtx) Deadline() (deadline time.Time, ok bool) { + return c.Context.Deadline() +} + +func (c *waitCtx) Value(key interface{}) interface{} { + if key == wgKey { + return c.wg + } + return c.Context.Value(key) +} + +var _ context.Context = &waitCtx{} + +func withWait(ctx context.Context) context.Context { + wg, ok := ctx.Value(wgKey).(*sync.WaitGroup) + if !ok { + wg = &sync.WaitGroup{} + ctx = &waitCtx{ctx, wg} + } + wg.Add(1) + return ctx +} + +func alive(ctx context.Context) bool { + select { + case <-ctx.Done(): + return false + default: + return true + } +} + +func wait(ctx context.Context) { + ctx.Value(wgKey).(*sync.WaitGroup).Wait() +} + +func done(ctx context.Context) { + ctx.Value(wgKey).(*sync.WaitGroup).Done() +} + +func withCancelWait(ctx context.Context) (context.Context, context.CancelFunc) { + ctx, cancel := context.WithCancel(withWait(ctx)) + return ctx, func() { + cancel() + wait(ctx) + } +} diff --git a/sqlite.go b/sqlite.go index e06872e..36ca935 100644 --- a/sqlite.go +++ b/sqlite.go @@ -123,20 +123,25 @@ type connector struct { connInitFunc ConnInitFunc } +func errWithMsg(db sqliteh.DB, err error, loc string) error { + if ec, ok := err.(sqliteh.ErrCode); ok { + e := &Error{ + Code: sqliteh.Code(ec), + Loc: loc, + } + if db != nil { + e.Msg = db.ErrMsg() + } + return e + } + return err +} + func (p *connector) Driver() driver.Driver { return drv{} } func (p *connector) Connect(ctx context.Context) (driver.Conn, error) { db, err := Open(p.name, sqliteh.OpenFlagsDefault, "") if err != nil { - if ec, ok := err.(sqliteh.ErrCode); ok { - e := &Error{ - Code: sqliteh.Code(ec), - Loc: "Open", - } - 
if db != nil { - e.Msg = db.ErrMsg() - } - err = e - } + err = errWithMsg(db, err, "Open") if db != nil { db.Close() } @@ -838,3 +843,85 @@ func WithPersist(ctx context.Context) context.Context { // persistQuery is used as a context value. type persistQuery struct{} + +// DB executes fn with the sqliteh.DB underlying sqlconn. +func DB(sqlconn SQLConn, fn func(sqliteh.DB) error) error { + return sqlconn.Raw(func(driverConn interface{}) error { + c, ok := driverConn.(*conn) + if !ok { + return fmt.Errorf("sqlite.DB: sql.Conn is not the sqlite driver: %T", driverConn) + } + return fn(c.db) + }) +} + +// Backup holds an in-progress backup context. +type Backup struct { + backup sqliteh.Backup + src sqliteh.DB + dst sqliteh.DB +} + +// NewBackup starts a new backup operation that will read from the SQLite +// database srcConn, schema srcSchema, and write to the database dstConn, schema +// dstSchema. +// The database owned by dstConn will be locked for the duration, and must not +// be modified by other connections or processes. +// The database owned by srcConn will be read-locked during each call to Step, +// but can otherwise be used normally. Applications must arrange to ensure that +// there is mutual exclusion between queries and step calls on the source +// connection. +// If a different connection alters the source database during the backup, Step +// will restart the backup process from the beginning, however if the source +// connection alters the database, the backup can continue and will include +// pages affected by the concurrent transactions. +// Finish must be called on the returned backup object in order to free +// resources consumed by the backup operation, even if errors occur during steps +// of the backup process. Finish can also be called any time that Step is not +// running in order to abort the backup. 
+func NewBackup(dstConn SQLConn, dstSchema string, srcConn SQLConn, srcSchema string) (*Backup, error) { + var b Backup + err := DB(dstConn, func(dst sqliteh.DB) error { + return DB(srcConn, func(src sqliteh.DB) error { + var err error + b.src = src + b.dst = dst + b.backup, err = dst.BackupInit(dstSchema, src, srcSchema) + return errWithMsg(dst, err, "Backup") + }) + }) + return &b, err +} + +// Step makes incremental progress toward a complete online backup. It performs +// at most numPages of copies from the source database to the target database. +// +// Step may be called in between other queries on the source connection, so as +// to concurrently service traffic, however Step must not be called in +// parallel with other queries on the source connection. +// +// Step may return more=true and non-fatal errors of either SQLITE_BUSY or +// SQLITE_LOCKED, however either of these errors being returned likely indicates +// that an external writer has modified the source database and there will be a +// side effect that the backup will restart from the beginning on the next call +// to Step. +// +// Progress is reported by the `remaining` and `pageCount` values. remaining is +// the number of pages left to copy, and pageCount is the current number of +// pages in total that must be copied. pageCount may change in size due to +// writes that occur during the backup process. +func (b *Backup) Step(numPages int) (more bool, remaining, pageCount int, err error) { + more, remaining, pageCount, err = b.backup.Step(numPages) + err = errWithMsg(b.dst, err, "Step") + return +} + +// Finish frees up the backup object and any resources it consumes. It must be +// called even if errors occurred calling Step. If Step reports a fatal error, +// Finish will also return the same error. Finish can be called at any time to +// abort a backup operation early. 
+func (b *Backup) Finish() error { + err := b.backup.Finish() + err = errWithMsg(b.dst, err, "Finish") + return err +} diff --git a/sqlite_test.go b/sqlite_test.go index 24ed223..dbc760c 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -779,6 +779,82 @@ func TestAttachOrderingDeadlock(t *testing.T) { } } +func TestBackup(t *testing.T) { + src, err := sql.Open("sqlite3", "file:src?mode=memory") + if err != nil { + t.Fatal(err) + } + defer src.Close() + dst, err := sql.Open("sqlite3", "file:dst?mode=memory") + if err != nil { + t.Fatal(err) + } + defer dst.Close() + ctx := context.Background() + + srcConn, err := src.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer srcConn.Close() + err = ExecScript(srcConn, ` + ATTACH 'file:src2?mode=memory' AS src2; + CREATE TABLE t1 (c); + INSERT INTO t1 VALUES ('a'); + CREATE TABLE src2.t2 (c); + INSERT INTO src2.t2 VALUES ('b'); + `) + if err != nil { + t.Fatal(err) + } + + dstConn, err := dst.Conn(ctx) + if err != nil { + t.Fatal(err) + } + defer dstConn.Close() + + var backup = func(dstConn *sql.Conn, dstName string, srcConn *sql.Conn, srcName string) error { + t.Helper() + b, err := NewBackup(dstConn, dstName, srcConn, srcName) + if err != nil { + return err + } + var ( + more = true + remaining int + pageCount int + ) + for more { + more, remaining, pageCount, err = b.Step(1024) + t.Logf("backup step: more=%v, remaining=%d, pageCount=%d", more, remaining, pageCount) + if err != nil { + t.Errorf("backup step: %v", err) + more = false + } + } + return b.Finish() + } + + if err := backup(dstConn, "main", srcConn, "main"); err != nil { + t.Fatal(err) + } + if _, err := dstConn.ExecContext(ctx, "ATTACH 'file:dst2?mode=memory' AS dst2;"); err != nil { + t.Fatal(err) + } + var count int + if err := dstConn.QueryRowContext(ctx, "SELECT count(*) FROM t1").Scan(&count); err != nil || count != 1 { + t.Fatalf("err=%v, count=%d", err, count) + } + if err := backup(dstConn, "dst2", srcConn, "src2"); err != nil { + t.Fatal(err) + 
} + count = 0 + if err := dstConn.QueryRowContext(ctx, "SELECT count(*) FROM dst2.t2").Scan(&count); err != nil || count != 1 { + t.Fatalf("err=%v, count=%d", err, count) + } +} + func BenchmarkPersist(b *testing.B) { ctx := context.Background() db := openTestDB(b) diff --git a/sqliteh/sqliteh.go b/sqliteh/sqliteh.go index 168417c..e74bb4d 100644 --- a/sqliteh/sqliteh.go +++ b/sqliteh/sqliteh.go @@ -53,6 +53,9 @@ type DB interface { AutoCheckpoint(n int) error // TxnState is sqlite3_txn_state. TxnState(schema string) TxnState + // BackupInit is sqlite3_backup_init, this DB is the destination. + // https://www.sqlite.org/c3ref/backup_finish.html#sqlite3backupinit + BackupInit(dstSchema string, src DB, srcSchema string) (Backup, error) } // Stmt is an sqlite3_stmt* database connection object. @@ -165,6 +168,15 @@ type Stmt interface { ColumnTableName(col int) string } +// Backup is an sqlite3_backup object. +// https://www.sqlite.org/c3ref/backup_finish.html +type Backup interface { + // Step is called repeatedly to transfer data between the two DBs. + Step(numPages int) (more bool, remaining, pageCount int, err error) + // Finish releases all resources associated with the Backup. + Finish() error +} + // ColumnType are constants for each of the SQLite datatypes. // https://www.sqlite.org/c3ref/c_blob.html type ColumnType int