Skip to content

Commit 7572674

Browse files
committed
sqlite: add optional ConnLogger
ConnLogger can be used to log executed statements for a connection. The interface captures events for Begin, Exec, Commit, and Rollback calls. Updates tailscale/corp#33577
1 parent 3a6395a commit 7572674

File tree

2 files changed

+159
-1
lines changed

2 files changed

+159
-1
lines changed

sqlite.go

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,19 @@ func Connector(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Trace
127127
}
128128
}
129129

130+
func ConnectorWithLogger(sqliteURI string, connInitFunc ConnInitFunc, tracer sqliteh.Tracer, makeLogger func() ConnLogger) driver.Connector {
131+
return &connector{
132+
name: sqliteURI,
133+
tracer: tracer,
134+
makeLogger: makeLogger,
135+
connInitFunc: connInitFunc,
136+
}
137+
}
138+
130139
type connector struct {
131140
name string
132141
tracer sqliteh.Tracer
142+
makeLogger func() ConnLogger
133143
connInitFunc ConnInitFunc
134144
}
135145

@@ -152,12 +162,14 @@ func (p *connector) Connect(ctx context.Context) (driver.Conn, error) {
152162
}
153163
return nil, err
154164
}
155-
156165
c := &conn{
157166
db: db,
158167
tracer: p.tracer,
159168
id: sqliteh.TraceConnID(maxConnID.Add(1)),
160169
}
170+
if p.makeLogger != nil {
171+
c.logger = p.makeLogger()
172+
}
161173
if p.connInitFunc != nil {
162174
if err := p.connInitFunc(ctx, c); err != nil {
163175
db.Close()
@@ -179,6 +191,7 @@ type conn struct {
179191
db sqliteh.DB
180192
id sqliteh.TraceConnID
181193
tracer sqliteh.Tracer
194+
logger ConnLogger
182195
stmts map[string]*stmt // persisted statements
183196
txState txState
184197
readOnly bool
@@ -341,6 +354,9 @@ func (c *conn) txInit(ctx context.Context) error {
341354
return err
342355
}
343356
} else {
357+
if c.logger != nil {
358+
c.logger.Begin()
359+
}
344360
// TODO(crawshaw): offer BEGIN DEFERRED (and BEGIN CONCURRENT?)
345361
// semantics via a context annotation function.
346362
if err := c.execInternal(ctx, "BEGIN IMMEDIATE"); err != nil {
@@ -381,6 +397,9 @@ func (tx *connTx) Commit() error {
381397
if tx.conn.tracer != nil {
382398
tx.conn.tracer.Commit(tx.conn.id, err)
383399
}
400+
if tx.conn.logger != nil && err == nil {
401+
tx.conn.logger.Commit()
402+
}
384403
return err
385404
}
386405

@@ -394,6 +413,9 @@ func (tx *connTx) Rollback() error {
394413
if tx.conn.tracer != nil {
395414
tx.conn.tracer.Rollback(tx.conn.id, err)
396415
}
416+
if tx.conn.logger != nil {
417+
tx.conn.logger.Rollback()
418+
}
397419
return err
398420
}
399421

@@ -490,6 +512,9 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
490512
if err := s.bindAll(args); err != nil {
491513
return nil, s.reserr("Stmt.Exec(Bind)", err)
492514
}
515+
if s.conn.logger != nil {
516+
s.conn.logger.Statement(s.stmt.ExpandedSQL())
517+
}
493518

494519
if ctx.Value(queryCancelKey{}) != nil {
495520
done := make(chan struct{})
@@ -1068,3 +1093,20 @@ func WithQueryCancel(ctx context.Context) context.Context {
10681093

10691094
// queryCancelKey is a context key for query context enforcement.
10701095
type queryCancelKey struct{}
1096+
1097+
// ConnLogger is implemented by the caller to support statement-level logging for
1098+
// write transactions. Only Exec calls are logged, not Query calls, as this is
1099+
// intended as a mechanism to replay failed transactions.
1100+
type ConnLogger interface {
1101+
// Begin is called when a writable transaction is opened.
1102+
Begin()
1103+
1104+
// Statement is called with evaluated SQL when a statement is executed.
1105+
Statement(sql string)
1106+
1107+
// Commit is called when a transaction successfully commits.
1108+
Commit()
1109+
1110+
// Rollback is called when a transaction is rolled back.
1111+
Rollback()
1112+
}

sqlite_test.go

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"os"
1414
"reflect"
1515
"runtime"
16+
"slices"
1617
"strconv"
1718
"strings"
1819
"sync"
@@ -1354,3 +1355,118 @@ func TestDisableFunction(t *testing.T) {
13541355
t.Fatal("Attempting to use the LOWER function after disabling should have failed")
13551356
}
13561357
}
1358+
1359+
type connLogger struct {
1360+
ch chan []string
1361+
statements []string
1362+
}
1363+
1364+
func (cl *connLogger) Begin() {
1365+
cl.statements = nil
1366+
}
1367+
1368+
func (cl *connLogger) Statement(s string) {
1369+
cl.statements = append(cl.statements, s)
1370+
}
1371+
1372+
func (cl *connLogger) Commit() {
1373+
cl.ch <- cl.statements
1374+
}
1375+
1376+
func (cl *connLogger) Rollback() {
1377+
cl.statements = nil
1378+
}
1379+
1380+
func TestConnLogger_writable(t *testing.T) {
1381+
for _, commit := range []bool{true, false} {
1382+
doneStatement := "ROLLBACK"
1383+
if commit {
1384+
doneStatement = "COMMIT"
1385+
}
1386+
t.Run(doneStatement, func(t *testing.T) {
1387+
ctx := context.Background()
1388+
ch := make(chan []string, 1)
1389+
txl := connLogger{ch: ch}
1390+
makeLogger := func() ConnLogger { return &txl }
1391+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1392+
configDB(t, db)
1393+
1394+
tx, err := db.BeginTx(ctx, nil)
1395+
if err != nil {
1396+
t.Fatal(err)
1397+
}
1398+
if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil {
1399+
t.Fatal(err)
1400+
}
1401+
if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil {
1402+
t.Fatal(err)
1403+
}
1404+
done := tx.Rollback
1405+
if commit {
1406+
done = tx.Commit
1407+
}
1408+
if err := done(); err != nil {
1409+
t.Fatal(err)
1410+
}
1411+
if !commit {
1412+
select {
1413+
case got := <-ch:
1414+
t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n"))
1415+
default:
1416+
return
1417+
}
1418+
}
1419+
1420+
want := []string{
1421+
"BEGIN IMMEDIATE",
1422+
"CREATE TABLE T (x INTEGER)",
1423+
"INSERT INTO T VALUES (1)",
1424+
doneStatement,
1425+
}
1426+
select {
1427+
case got := <-ch:
1428+
if !slices.Equal(got, want) {
1429+
t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n"))
1430+
}
1431+
default:
1432+
t.Fatal("no logged statements after commit")
1433+
}
1434+
})
1435+
}
1436+
}
1437+
1438+
func TestConnLogger_commit_error_retry(t *testing.T) {
1439+
ctx := context.Background()
1440+
ch := make(chan []string, 1)
1441+
txl := connLogger{ch: ch}
1442+
makeLogger := func() ConnLogger { return &txl }
1443+
db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger))
1444+
configDB(t, db)
1445+
1446+
if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
1447+
t.Fatal(err)
1448+
}
1449+
if _, err := db.Exec("CREATE TABLE A (x INTEGER PRIMARY KEY)"); err != nil {
1450+
t.Fatal(err)
1451+
}
1452+
if _, err := db.Exec("CREATE TABLE B (x INTEGER REFERENCES A(X) DEFERRABLE INITIALLY DEFERRED)"); err != nil {
1453+
t.Fatal(err)
1454+
}
1455+
1456+
tx, err := db.BeginTx(ctx, nil)
1457+
if err != nil {
1458+
t.Fatal(err)
1459+
}
1460+
if _, err := tx.Exec("INSERT INTO B VALUES (?)", 1); err != nil {
1461+
t.Fatal(err)
1462+
}
1463+
if err := tx.Commit(); err == nil {
1464+
t.Fatal("expected Commit to error, but didn't")
1465+
}
1466+
select {
1467+
case got := <-ch:
1468+
t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n"))
1469+
default:
1470+
return
1471+
}
1472+
}

0 commit comments

Comments
 (0)