Skip to content

Commit 7154d04

Browse files
committed
Merge branch 'devel'
2 parents 2a16129 + 75a6119 commit 7154d04

File tree

6 files changed

+140
-58
lines changed

6 files changed

+140
-58
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,8 @@ The following types can be used as primary key:
166166
* [`uuid.UUID`](https://godoc.org/github.com/satori/go.uuid#UUID)
167167
* [`kallax.ULID`](https://godoc.org/github.com/src-d/go-kallax/#ULID): this is a type kallax provides that implements a lexically sortable UUID. You can store it as `uuid` like any other UUID, but internally it's an ULID and you will be able to sort lexically by it.
168168

169+
Due to how sql mapping works, pointers to `uuid.UUID` and `kallax.ULID` are not set to `nil` if they appear as `NULL` in the database, but to [`uuid.Nil`](https://godoc.org/github.com/satori/go.uuid#pkg-variables). Using pointers to UUIDs is discouraged for this reason.
170+
169171
If you need another type as primary key, feel free to open a pull request implementing that.
170172

171173
**Known limitations**

batcher.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ type batchQueryRunner struct {
1414
q Query
1515
oneToOneRels []Relationship
1616
oneToManyRels []Relationship
17-
db squirrel.DBProxy
17+
db squirrel.BaseRunner
1818
builder squirrel.SelectBuilder
1919
total int
2020
eof bool
@@ -24,7 +24,7 @@ type batchQueryRunner struct {
2424

2525
var errNoMoreRows = errors.New("kallax: there are no more rows in the result set")
2626

27-
func newBatchQueryRunner(schema Schema, db squirrel.DBProxy, q Query) *batchQueryRunner {
27+
func newBatchQueryRunner(schema Schema, db squirrel.BaseRunner, q Query) *batchQueryRunner {
2828
cols, builder := q.compile()
2929
var (
3030
oneToOneRels []Relationship

batcher_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func TestBatcherLimit(t *testing.T) {
5454
q.BatchSize(2)
5555
q.Limit(5)
5656
r.NoError(q.AddRelation(RelSchema, "rels", OneToMany, Eq(f("foo"), "1")))
57-
runner := newBatchQueryRunner(ModelSchema, store.proxy, q)
57+
runner := newBatchQueryRunner(ModelSchema, store.runner, q)
5858
rs := NewBatchingResultSet(runner)
5959

6060
var count int
@@ -91,7 +91,7 @@ func TestBatcherNoExtraQueryIfLessThanLimit(t *testing.T) {
9191
var queries int
9292
proxy := store.DebugWith(func(_ string, _ ...interface{}) {
9393
queries++
94-
}).proxy
94+
}).runner
9595
runner := newBatchQueryRunner(ModelSchema, proxy, q)
9696
rs := NewBatchingResultSet(runner)
9797

@@ -130,7 +130,7 @@ func TestBatcherNoExtraQueryIfLessThanBatchSize(t *testing.T) {
130130
var queries int
131131
proxy := store.DebugWith(func(_ string, _ ...interface{}) {
132132
queries++
133-
}).proxy
133+
}).runner
134134
runner := newBatchQueryRunner(ModelSchema, proxy, q)
135135
rs := NewBatchingResultSet(runner)
136136

query.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,13 @@ func (q *BaseQuery) String() string {
284284
return sql
285285
}
286286

287+
// ToSql returns the SQL generated by the query, the query arguments, and
288+
// any error returned during the compile process.
289+
func (q *BaseQuery) ToSql() (string, []interface{}, error) {
290+
_, builder := q.compile()
291+
return builder.ToSql()
292+
}
293+
287294
// ColumnOrder represents a column name with its order.
288295
type ColumnOrder interface {
289296
// ToSql returns the SQL representation of the column with its order.

query_test.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,16 @@ func (s *QuerySuite) TestString() {
8282
s.Equal("SELECT __model.foo FROM model __model", s.q.String())
8383
}
8484

85+
func (s *QuerySuite) TestToSql() {
86+
s.q.Select(f("foo"))
87+
s.q.Where(Eq(f("foo"), 5))
88+
s.q.Where(Eq(f("bar"), "baz"))
89+
sql, args, err := s.q.ToSql()
90+
s.Equal("SELECT __model.foo FROM model __model WHERE __model.foo = $1 AND __model.bar = $2", sql)
91+
s.Equal([]interface{}{5, "baz"}, args)
92+
s.Equal(err, nil)
93+
}
94+
8595
func (s *QuerySuite) TestAddRelation() {
8696
s.Nil(s.q.AddRelation(RelSchema, "rel", OneToOne, nil))
8797
s.Equal("SELECT __model.id, __model.name, __model.email, __model.age, __rel_rel.id, __rel_rel.model_id, __rel_rel.foo FROM model __model LEFT JOIN rel __rel_rel ON (__rel_rel.model_id = __model.id)", s.q.String())

store.go

Lines changed: 116 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -60,62 +60,87 @@ func StoreFrom(to, from GenericStorer) {
6060
// logs it.
6161
type LoggerFunc func(string, ...interface{})
6262

63-
// debugProxy is a database proxy that logs all SQL statements executed.
64-
type debugProxy struct {
63+
func defaultLogger(message string, args ...interface{}) {
64+
log.Printf("%s, args: %v", message, args)
65+
}
66+
67+
// basicLogger is a database runner that logs all SQL statements executed.
68+
type basicLogger struct {
6569
logger LoggerFunc
66-
proxy squirrel.DBProxy
70+
runner squirrel.BaseRunner
6771
}
6872

69-
func defaultLogger(message string, args ...interface{}) {
70-
log.Printf("%s, args: %v", message, args)
73+
// basicLogger is a database runner that logs all SQL statements executed.
74+
type proxyLogger struct {
75+
basicLogger
7176
}
7277

73-
func (p *debugProxy) Exec(query string, args ...interface{}) (sql.Result, error) {
78+
func (p *basicLogger) Exec(query string, args ...interface{}) (sql.Result, error) {
7479
p.logger(fmt.Sprintf("kallax: Exec: %s", query), args...)
75-
return p.proxy.Exec(query, args...)
80+
return p.runner.Exec(query, args...)
7681
}
7782

78-
func (p *debugProxy) Query(query string, args ...interface{}) (*sql.Rows, error) {
83+
func (p *basicLogger) Query(query string, args ...interface{}) (*sql.Rows, error) {
7984
p.logger(fmt.Sprintf("kallax: Query: %s", query), args...)
80-
return p.proxy.Query(query, args...)
85+
return p.runner.Query(query, args...)
8186
}
8287

83-
func (p *debugProxy) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
84-
p.logger(fmt.Sprintf("kallax: QueryRow: %s", query), args...)
85-
return p.proxy.QueryRow(query, args...)
88+
func (p *proxyLogger) QueryRow(query string, args ...interface{}) squirrel.RowScanner {
89+
p.basicLogger.logger(fmt.Sprintf("kallax: QueryRow: %s", query), args...)
90+
if queryRower, ok := p.basicLogger.runner.(squirrel.QueryRower); ok {
91+
return queryRower.QueryRow(query, args...)
92+
} else {
93+
panic("Called proxyLogger with a runner which doesn't implement QueryRower")
94+
}
8695
}
8796

88-
func (p *debugProxy) Prepare(query string) (*sql.Stmt, error) {
89-
p.logger(fmt.Sprintf("kallax: Prepare: %s", query))
90-
return p.proxy.Prepare(query)
97+
func (p *proxyLogger) Prepare(query string) (*sql.Stmt, error) {
98+
// If chained runner is a proxy, run Prepare(). Otherwise, noop.
99+
if preparer, ok := p.basicLogger.runner.(squirrel.Preparer); ok {
100+
p.basicLogger.logger(fmt.Sprintf("kallax: Prepare: %s", query))
101+
return preparer.Prepare(query)
102+
} else {
103+
panic("Called proxyLogger with a runner which doesn't implement QueryRower")
104+
}
91105
}
92106

93107
// Store is a structure capable of retrieving records from a concrete table in
94108
// the database.
95109
type Store struct {
96-
builder squirrel.StatementBuilderType
97-
db *sql.DB
98-
proxy squirrel.DBProxy
110+
db interface {
111+
squirrel.BaseRunner
112+
squirrel.PreparerContext
113+
}
114+
runner squirrel.BaseRunner
115+
useCacher bool
116+
logger LoggerFunc
99117
}
100118

101119
// NewStore returns a new Store instance.
102120
func NewStore(db *sql.DB) *Store {
103-
proxy := squirrel.NewStmtCacher(db)
104-
builder := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar).RunWith(proxy)
105-
return &Store{
106-
db: db,
107-
proxy: proxy,
108-
builder: builder,
109-
}
121+
return (&Store{
122+
db: db,
123+
useCacher: true,
124+
}).init()
110125
}
111126

112-
func newStoreWithTransaction(tx *sql.Tx) *Store {
113-
proxy := squirrel.NewStmtCacher(tx)
114-
builder := squirrel.StatementBuilder.PlaceholderFormat(squirrel.Dollar).RunWith(proxy)
115-
return &Store{
116-
proxy: proxy,
117-
builder: builder,
127+
// init initializes the store runner with debugging or caching, and returns itself for chainability
128+
func (s *Store) init() *Store {
129+
s.runner = s.db
130+
131+
if s.useCacher {
132+
s.runner = squirrel.NewStmtCacher(s.db)
133+
}
134+
135+
if s.logger != nil && !s.useCacher {
136+
// Use BasicLogger as wrapper
137+
s.runner = &basicLogger{s.logger, s.runner}
138+
} else if s.logger != nil && s.useCacher {
139+
// We're using a proxy (cacher), so use proxyLogger instead
140+
s.runner = &proxyLogger{basicLogger{s.logger, s.runner}}
118141
}
142+
143+
return s
119144
}
120145

121146
// Debug returns a new store that will print all SQL statements to stdout using
@@ -127,12 +152,29 @@ func (s *Store) Debug() *Store {
127152
// DebugWith returns a new store that will print all SQL statements using the
128153
// given logger function.
129154
func (s *Store) DebugWith(logger LoggerFunc) *Store {
130-
proxy := &debugProxy{logger, s.proxy}
131-
return &Store{
132-
builder: s.builder.RunWith(proxy),
133-
db: s.db,
134-
proxy: proxy,
135-
}
155+
return (&Store{
156+
db: s.db,
157+
useCacher: s.useCacher,
158+
logger: logger,
159+
}).init()
160+
}
161+
162+
// DisableCacher turns off prepared statements.
163+
func (s *Store) DisableCacher() *Store {
164+
return (&Store{
165+
db: s.db,
166+
logger: s.logger,
167+
useCacher: false,
168+
}).init()
169+
}
170+
171+
// EnableCacher turns on prepared statements. This is the default.
172+
func (s *Store) EnableCacher() *Store {
173+
return (&Store{
174+
db: s.db,
175+
logger: s.logger,
176+
useCacher: true,
177+
}).init()
136178
}
137179

138180
// Insert insert the given record in the table, returns error if no-new
@@ -192,9 +234,20 @@ func (s *Store) Insert(schema Schema, record Record) error {
192234
}
193235

194236
query.WriteString(fmt.Sprintf(" RETURNING %s", schema.ID().String()))
195-
err = s.proxy.QueryRow(query.String(), values...).Scan(pk)
237+
//err = s.runner.QueryRow(query.String(), values...).Scan(pk)
238+
rows, err := s.runner.Query(query.String(), values...)
239+
if err != nil {
240+
return err
241+
}
242+
if rows.Next() {
243+
err = rows.Scan(pk)
244+
rows.Close()
245+
if err != nil {
246+
return err
247+
}
248+
}
196249
} else {
197-
_, err = s.proxy.Exec(query.String(), values...)
250+
_, err = s.runner.Exec(query.String(), values...)
198251
}
199252

200253
if err != nil {
@@ -255,7 +308,7 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64
255308
query.WriteRune('=')
256309
query.WriteString(fmt.Sprintf("$%d", len(columnNames)+1))
257310

258-
result, err := s.proxy.Exec(query.String(), append(values, record.GetID())...)
311+
result, err := s.runner.Exec(query.String(), append(values, record.GetID())...)
259312
if err != nil {
260313
return 0, err
261314
}
@@ -300,7 +353,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
300353
query.WriteString(schema.ID().String())
301354
query.WriteString("=$1")
302355

303-
_, err := s.proxy.Exec(query.String(), record.GetID())
356+
_, err := s.runner.Exec(query.String(), record.GetID())
304357
return err
305358
}
306359

@@ -309,7 +362,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
309362
// WARNING: A result set created from a raw query can only be scanned using the
310363
// RawScan method of ResultSet, instead of Scan.
311364
func (s *Store) RawQuery(sql string, params ...interface{}) (ResultSet, error) {
312-
rows, err := s.proxy.Query(sql, params...)
365+
rows, err := s.runner.Query(sql, params...)
313366
if err != nil {
314367
return nil, err
315368
}
@@ -320,7 +373,7 @@ func (s *Store) RawQuery(sql string, params ...interface{}) (ResultSet, error) {
320373
// RawExec executes a raw SQL query with the given parameters and returns
321374
// the number of affected rows.
322375
func (s *Store) RawExec(sql string, params ...interface{}) (int64, error) {
323-
result, err := s.proxy.Exec(sql, params...)
376+
result, err := s.runner.Exec(sql, params...)
324377
if err != nil {
325378
return 0, err
326379
}
@@ -332,7 +385,7 @@ func (s *Store) RawExec(sql string, params ...interface{}) (int64, error) {
332385
func (s *Store) Find(q Query) (ResultSet, error) {
333386
rels := q.getRelationships()
334387
if containsRelationshipOfType(rels, OneToMany) {
335-
return NewBatchingResultSet(newBatchQueryRunner(q.Schema(), s.proxy, q)), nil
388+
return NewBatchingResultSet(newBatchQueryRunner(q.Schema(), s.runner, q)), nil
336389
}
337390

338391
columns, builder := q.compile()
@@ -344,7 +397,7 @@ func (s *Store) Find(q Query) (ResultSet, error) {
344397
builder = builder.Limit(limit)
345398
}
346399

347-
rows, err := builder.RunWith(s.proxy).Query()
400+
rows, err := builder.RunWith(s.runner).Query()
348401
if err != nil {
349402
return nil, err
350403
}
@@ -379,7 +432,7 @@ func (s *Store) Reload(schema Schema, record Record) error {
379432
q.Limit(1)
380433
columns, builder := q.compile()
381434

382-
rows, err := builder.RunWith(s.proxy).Query()
435+
rows, err := builder.RunWith(s.runner).Query()
383436
if err != nil {
384437
return err
385438
}
@@ -399,7 +452,7 @@ func (s *Store) Count(q Query) (count int64, err error) {
399452
_, queryBuilder := q.compile()
400453
builder := builder.Set(queryBuilder, "Columns", nil).(squirrel.SelectBuilder)
401454
err = builder.Column(fmt.Sprintf("COUNT(%s)", all.QualifiedName(q.Schema()))).
402-
RunWith(s.proxy).
455+
RunWith(s.runner).
403456
QueryRow().
404457
Scan(&count)
405458
return
@@ -423,16 +476,26 @@ func (s *Store) MustCount(q Query) int64 {
423476
// If a transaction is already opened in this store, instead of opening a new
424477
// one, the other will be reused.
425478
func (s *Store) Transaction(callback func(*Store) error) error {
426-
if s.db == nil {
479+
var tx *sql.Tx
480+
var err error
481+
if db, ok := s.db.(*sql.DB); ok {
482+
// db is *sql.DB, not *sql.Tx
483+
tx, err = db.Begin()
484+
if err != nil {
485+
return fmt.Errorf("kallax: can't open transaction: %s", err)
486+
}
487+
} else {
488+
// store is already holding a transaction
427489
return callback(s)
428490
}
429491

430-
tx, err := s.db.Begin()
431-
if err != nil {
432-
return fmt.Errorf("kallax: can't open transaction: %s", err)
433-
}
492+
txStore := (&Store{
493+
db: tx,
494+
logger: s.logger,
495+
useCacher: true,
496+
}).init()
434497

435-
if err := callback(newStoreWithTransaction(tx)); err != nil {
498+
if err := callback(txStore); err != nil {
436499
if err := tx.Rollback(); err != nil {
437500
return fmt.Errorf("kallax: unable to rollback transaction: %s", err)
438501
}

0 commit comments

Comments
 (0)