Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 60 additions & 6 deletions sqlite.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,22 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace
}
}

// ConnectorWithLogger returns a [driver.Connector] for the given connection
// parameters. makeLogger is used to create a [ConnLogger] when [Connect] is
// called.
func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document when makeLogger is called (i.e., "when Connect is called")?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return &connector{
name: sqliteURI,
tracer: tracer,
makeLogger: makeLogger,
connInitFunc: connInitFunc,
}
}

type connector struct {
name string
tracer sqliteh.Tracer
makeLogger func() ConnLogger // or nil
connInitFunc ConnInitFunc
}

Expand All @@ -152,12 +165,14 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
}
return nil, err
}

c := &conn{
db: db,
tracer: p.tracer,
id: sqliteh.TraceConnID(maxConnID.Add(1)),
}
if p.makeLogger != nil {
c.logger = p.makeLogger()
}
if p.connInitFunc != nil {
if err := p.connInitFunc(ctx, c); err != nil {
db.Close()
Expand All @@ -179,6 +194,7 @@ type conn struct {
db sqliteh.DB
id sqliteh.TraceConnID
tracer sqliteh.Tracer
logger ConnLogger
stmts map[string]*stmt // persisted statements
txState txState
readOnly bool
Expand All @@ -202,6 +218,7 @@ func (c *conn) Close() error {
err := reserr(c.db, "Conn.Close", "", c.db.Close())
return err
}

func (c *conn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
persist := ctx.Value(persistQuery{}) != nil
return c.prepare(ctx, query, persist)
Expand Down Expand Up @@ -341,6 +358,9 @@ func (c *conn) txInit(ctx context.Context) error {
return err
}
} else {
if c.logger != nil {
c.logger.Begin()
}
// TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?)
// semantics via a context annotation function.
if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil {
Expand All @@ -351,15 +371,16 @@ func (c *conn) txInit(ctx context.Context) error {
}

func (c *conn) txEnd(ctx context.Context, endStmt string) error {
state, readOnly := c.txState, c.readOnly
c.txState = txStateNone
c.readOnly = false
if state != txStateBegun {
defer func() {
c.txState = txStateNone
c.readOnly = false
}()
if c.txState != txStateBegun {
return nil
}

err := c.execInternal(context.Background(), endStmt)
if readOnly {
if c.readOnly {
if err2 := c.execInternal(ctx, "PRAGMA query_only=false"); err == nil {
err = err2
}
Expand All @@ -377,10 +398,14 @@ func (tx *connTx) Commit() error {
return ErrClosed
}

readonly := tx.conn.readOnly
err := tx.conn.txEnd(context.Background(), "COMMIT")
if tx.conn.tracer != nil {
tx.conn.tracer.Commit(tx.conn.id, err)
}
Comment on lines 403 to 405
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This Tracer interface is what I was remembering with its type TraceConnID it plumbs through.

It's kinda weird having both Tracer and Logger that are such overlaps. Oh well.

I guess Tracer has the SQL with placeholders and Logger has it without? Might be worth documenting on both the relationship between the two.

if tx.conn.logger != nil && !readonly {
tx.conn.logger.Commit(err)
}
return err
}

Expand All @@ -390,10 +415,14 @@ func (tx *connTx) Rollback() error {
return ErrClosed
}

readonly := tx.conn.readOnly
err := tx.conn.txEnd(context.Background(), "ROLLBACK")
if tx.conn.tracer != nil {
tx.conn.tracer.Rollback(tx.conn.id, err)
}
if tx.conn.logger != nil && !readonly {
tx.conn.logger.Rollback()
}
return err
}

Expand Down Expand Up @@ -490,6 +519,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
if err := s.bindAll(args); err != nil {
return nil, s.reserr("Stmt.Exec(Bind)", err)
}
if s.conn.logger != nil && !s.conn.readOnly {
s.conn.logger.Statement(s.stmt.ExpandedSQL())
}

if ctx.Value(queryCancelKey{}) != nil {
done := make(chan struct{})
Expand Down Expand Up @@ -1068,3 +1100,25 @@ func WithQueryCancel(ctx context.Context) context.Context {

// queryCancelKey is a context key for query context enforcement.
type queryCancelKey struct{}

// ConnLogger is implemented by the caller to support statement-level logging
// for write transactions. Only Exec calls are logged, not Query calls, as this
// is intended as a mechanism to replay failed transactions.
//
// Aside from logging only executed statements, ConnLogger also differs from
// [sqliteh.Tracer] by logging the expanded SQL, instead of the query with
// placeholders.
type ConnLogger interface {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably want some kind of Close() method to ensure the file is flushed to disk when we exit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we want Close here, although of course you're right that we need to flush the logs to disk on exit. In the corresponding PR, I've added a Close handler which is attached to the database shutdown process. Thanks for spotting it!

// Begin is called when a writable transaction is opened.
Begin()

// Statement is called with evaluated SQL when a statement is executed.
Statement(sql string)

// Commit is called after a commit statement, with the error resulting
// from the attempted commit.
Commit(error)

// Rollback is called after a rollback statement.
Rollback()
}
184 changes: 184 additions & 0 deletions sqlite_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os"
"reflect"
"runtime"
"slices"
"strconv"
"strings"
"sync"
Expand Down Expand Up @@ -1354,3 +1355,186 @@ func TestDisableFunction(t *testing.T) {
t.Fatal("Attempting to use the LOWER function after disabling should have failed")
}
}

type connLogger struct {
ch chan []string
statements []string
panicOnUse bool
}

func (cl *connLogger) Begin() {
if cl.panicOnUse {
panic("unexpected connLogger.Begin()")
}
cl.statements = nil
}

func (cl *connLogger) Statement(s string) {
if cl.panicOnUse {
panic("unexpected connLogger.Statement: " + s)
}
cl.statements = append(cl.statements, s)
}

func (cl *connLogger) Commit(err error) {
if cl.panicOnUse {
panic("unexpected connLogger.Commit()")
}
if err != nil {
return
}
cl.ch <- cl.statements
}

func (cl *connLogger) Rollback() {
if cl.panicOnUse {
panic("unexpected connLogger.Rollback()")
}
cl.statements = nil
}

func TestConnLogger_writable(t *testing.T) {
for _, commit := range []bool{true, false} {
doneStatement := "ROLLBACK"
if commit {
doneStatement = "COMMIT"
}
t.Run(doneStatement, func(t *testing.T) {
ctx := context.Background()
ch := make(chan []string, 1)
txl := connLogger{ch: ch}
makeLogger := func() ConnLogger { return &txl }
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
configDB(t, db)

tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatal(err)
}
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
t.Fatal(err)
}
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
t.Fatal(err)
}
if _, err := tx.Query("SELECT x FROM T"); err != nil {
t.Fatal(err)
}
done := tx.Rollback
if commit {
done = tx.Commit
}
if err := done(); err != nil {
t.Fatal(err)
}
if !commit {
select {
case got := <-ch:
t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n"))
default:
return
}
}

want := []string{
"BEGIN IMMEDIATE",
"CREATE TABLE T (x INTEGER)",
"INSERT INTO T VALUES (1)",
doneStatement,
}
select {
case got := <-ch:
if !slices.Equal(got, want) {
t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n"))
}
default:
t.Fatal("no logged statements after commit")
}
})
}
}

func TestConnLogger_commit_error(t *testing.T) {
ctx := context.Background()
ch := make(chan []string, 1)
txl := connLogger{ch: ch}
makeLogger := func() ConnLogger { return &txl }
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
configDB(t, db)

if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
t.Fatal(err)
}
if _, err := db.Exec("CREATE TABLE A (x INTEGER PRIMARY KEY)"); err != nil {
t.Fatal(err)
}
if _, err := db.Exec("CREATE TABLE B (x INTEGER REFERENCES A(X) DEFERRABLE INITIALLY DEFERRED)"); err != nil {
t.Fatal(err)
}

tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatal(err)
}
if _, err := tx.Exec("INSERT INTO B VALUES (?)", 1); err != nil {
t.Fatal(err)
}
if err := tx.Commit(); err == nil {
t.Fatal("expected Commit to error, but didn't")
}
select {
case got := <-ch:
t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n"))
default:
return
}
}

func TestConnLogger_read_tx(t *testing.T) {
ctx := context.Background()
ch := make(chan []string, 1)
txl := connLogger{ch: ch}
makeLogger := func() ConnLogger { return &txl }
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
configDB(t, db)

tx, err := db.BeginTx(ctx, nil)
if err != nil {
t.Fatal(err)
}
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
t.Fatal(err)
}
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
t.Fatal(err)
}
if err := tx.Commit(); err != nil {
t.Fatal(err)
}
select {
case got := <-ch:
if len(got) == 0 {
t.Errorf("expected logged statements for write tx")
}
default:
t.Errorf("expected logged statements for write tx")
}

txl.panicOnUse = true
for _, commit := range []bool{true, false} {
rx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
t.Fatal(err)
}
if _, err := rx.Query("SELECT x FROM T"); err != nil {
t.Fatal(err)
}
done := rx.Rollback
if commit {
done = rx.Commit
}
if err := done(); err != nil {
t.Fatal(err)
}
}
}
Loading