From 731f626b3a694f37c578f955b78f253f30984a35 Mon Sep 17 00:00:00 2001 From: Alisdair McDiarmid Date: Mon, 27 Oct 2025 06:19:22 -0700 Subject: [PATCH] sqlite: add optional ConnLogger ConnLogger can be used to log executed statements for a connection. The interface captures events for Begin, Exec, Commit, and Rollback calls. Updates tailscale/corp#33577 Co-authored-by: Anton Tolchanov --- sqlite.go | 66 ++++++++++++++++-- sqlite_test.go | 184 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 244 insertions(+), 6 deletions(-) diff --git a/sqlite.go b/sqlite.go index 7e592b2..73604d8 100644 --- a/sqlite.go +++ b/sqlite.go @@ -127,9 +127,22 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace } } +// ConnectorWithLogger returns a [driver.Connector] for the given connection +// parameters. makeLogger is used to create a [ConnLogger] when [Connect] is +// called. +func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector { + return &connector{ + name: sqliteURI, + tracer: tracer, + makeLogger: makeLogger, + connInitFunc: connInitFunc, + } +} + type connector struct { name string tracer sqliteh.Tracer + makeLogger func() ConnLogger // or nil connInitFunc ConnInitFunc } @@ -152,12 +165,14 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) { } return nil, err } - c := &conn{ db: db, tracer: p.tracer, id: sqliteh.TraceConnID(maxConnID.Add(1)), } + if p.makeLogger != nil { + c.logger = p.makeLogger() + } if p.connInitFunc != nil { if err := p.connInitFunc(ctx, c); err != nil { db.Close() @@ -179,6 +194,7 @@ type conn struct { db sqliteh.DB id sqliteh.TraceConnID tracer sqliteh.Tracer + logger ConnLogger stmts map[string]*stmt // persisted statements txState txState readOnly bool @@ -202,6 +218,7 @@ func (c *conn) Close() error { err := reserr(c.db, "Conn.Close", "", c.db.Close()) return err } + func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { persist := ctx.Value(persistQuery{}) != nil return c.prepare(ctx, query, persist) @@ -341,6 +358,9 @@ func (c *conn) txInit(ctx context.Context) error { return err } } else { + if c.logger != nil { + c.logger.Begin() + } // TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?) // semantics via a context annotation function. if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil { @@ -351,15 +371,16 @@ func (c *conn) txInit(ctx context.Context) error { } func (c *conn) txEnd(ctx context.Context, endStmt string) error { - state, readOnly := c.txState, c.readOnly - c.txState = txStateNone - c.readOnly = false - if state != txStateBegun { + defer func() { + c.txState = txStateNone + c.readOnly = false + }() + if c.txState != txStateBegun { return nil } err := c.execInternal(context.Background(), endStmt) - if readOnly { + if c.readOnly { if err2 := c.execInternal(ctx, "PRAGMA query_only=false"); err == nil { err = err2 } @@ -377,10 +398,14 @@ func (tx *connTx) Commit() error { return ErrClosed } + readonly := tx.conn.readOnly err := tx.conn.txEnd(context.Background(), "COMMIT") if tx.conn.tracer != nil { tx.conn.tracer.Commit(tx.conn.id, err) } + if tx.conn.logger != nil && !readonly { + tx.conn.logger.Commit(err) + } return err } @@ -390,10 +415,14 @@ func (tx *connTx) Rollback() error { return ErrClosed } + readonly := tx.conn.readOnly err := tx.conn.txEnd(context.Background(), "ROLLBACK") if tx.conn.tracer != nil { tx.conn.tracer.Rollback(tx.conn.id, err) } + if tx.conn.logger != nil && !readonly { + tx.conn.logger.Rollback() + } return err } @@ -490,6 +519,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive if err := s.bindAll(args); err != nil { return nil, s.reserr("Stmt.Exec(Bind)", err) } + if s.conn.logger != nil && !s.conn.readOnly { + s.conn.logger.Statement(s.stmt.ExpandedSQL()) + } if ctx.Value(queryCancelKey{}) != nil { done := make(chan struct{}) @@ -1068,3 +1100,25 @@ func WithQueryCancel(ctx context.Context) context.Context { // queryCancelKey is a context key for query context enforcement. type queryCancelKey struct{} + +// ConnLogger is implemented by the caller to support statement-level logging +// for write transactions. Only Exec calls are logged, not Query calls, as this +// is intended as a mechanism to replay failed transactions. +// +// Aside from logging only executed statements, ConnLogger also differs from +// [sqliteh.Tracer] by logging the expanded SQL, instead of the query with +// placeholders. +type ConnLogger interface { + // Begin is called when a writable transaction is opened. + Begin() + + // Statement is called with evaluated SQL when a statement is executed. + Statement(sql string) + + // Commit is called after a commit statement, with the error resulting + // from the attempted commit. + Commit(error) + + // Rollback is called after a rollback statement. + Rollback() +} diff --git a/sqlite_test.go b/sqlite_test.go index 4011d9b..ef88d1f 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -13,6 +13,7 @@ import ( "os" "reflect" "runtime" + "slices" "strconv" "strings" "sync" @@ -1354,3 +1355,186 @@ func TestDisableFunction(t *testing.T) { t.Fatal("Attempting to use the LOWER function after disabling should have failed") } } + +type connLogger struct { + ch chan []string + statements []string + panicOnUse bool +} + +func (cl *connLogger) Begin() { + if cl.panicOnUse { + panic("unexpected connLogger.Begin()") + } + cl.statements = nil +} + +func (cl *connLogger) Statement(s string) { + if cl.panicOnUse { + panic("unexpected connLogger.Statement: " + s) + } + cl.statements = append(cl.statements, s) +} + +func (cl *connLogger) Commit(err error) { + if cl.panicOnUse { + panic("unexpected connLogger.Commit()") + } + if err != nil { + return + } + cl.ch <- cl.statements +} + +func (cl *connLogger) Rollback() { + if cl.panicOnUse { + panic("unexpected connLogger.Rollback()") + } + cl.statements = nil +} + +func TestConnLogger_writable(t *testing.T) { + for _, commit := range []bool{true, false} { + doneStatement := "ROLLBACK" + if commit { + doneStatement = "COMMIT" + } + t.Run(doneStatement, func(t *testing.T) { + ctx := context.Background() + ch := make(chan []string, 1) + txl := connLogger{ch: ch} + makeLogger := func() ConnLogger { return &txl } + db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger)) + configDB(t, db) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil { + t.Fatal(err) + } + if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil { + t.Fatal(err) + } + if _, err := tx.Query("SELECT x FROM T"); err != nil { + t.Fatal(err) + } + done := tx.Rollback + if commit { + done = tx.Commit + } + if err := done(); err != nil { + t.Fatal(err) + } + if !commit { + select { + case got := <-ch: + t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n")) + default: + return + } + } + + want := []string{ + "BEGIN IMMEDIATE", + "CREATE TABLE T (x INTEGER)", + "INSERT INTO T VALUES (1)", + doneStatement, + } + select { + case got := <-ch: + if !slices.Equal(got, want) { + t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n")) + } + default: + t.Fatal("no logged statements after commit") + } + }) + } +} + +func TestConnLogger_commit_error(t *testing.T) { + ctx := context.Background() + ch := make(chan []string, 1) + txl := connLogger{ch: ch} + makeLogger := func() ConnLogger { return &txl } + db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger)) + configDB(t, db) + + if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("CREATE TABLE A (x INTEGER PRIMARY KEY)"); err != nil { + t.Fatal(err) + } + if _, err := db.Exec("CREATE TABLE B (x INTEGER REFERENCES A(X) DEFERRABLE INITIALLY DEFERRED)"); err != nil { + t.Fatal(err) + } + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + if _, err := tx.Exec("INSERT INTO B VALUES (?)", 1); err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err == nil { + t.Fatal("expected Commit to error, but didn't") + } + select { + case got := <-ch: + t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n")) + default: + return + } +} + +func TestConnLogger_read_tx(t *testing.T) { + ctx := context.Background() + ch := make(chan []string, 1) + txl := connLogger{ch: ch} + makeLogger := func() ConnLogger { return &txl } + db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger)) + configDB(t, db) + + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil { + t.Fatal(err) + } + if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + select { + case got := <-ch: + if len(got) == 0 { + t.Errorf("expected logged statements for write tx") + } + default: + t.Errorf("expected logged statements for write tx") + } + + txl.panicOnUse = true + for _, commit := range []bool{true, false} { + rx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + t.Fatal(err) + } + if _, err := rx.Query("SELECT x FROM T"); err != nil { + t.Fatal(err) + } + done := rx.Rollback + if commit { + done = rx.Commit + } + if err := done(); err != nil { + t.Fatal(err) + } + } +}