@@ -60,62 +60,87 @@ func StoreFrom(to, from GenericStorer) {
6060// logs it.
6161type LoggerFunc func (string , ... interface {})
6262
63- // debugProxy is a database proxy that logs all SQL statements executed.
64- type debugProxy struct {
63+ func defaultLogger (message string , args ... interface {}) {
64+ log .Printf ("%s, args: %v" , message , args )
65+ }
66+
67+ // basicLogger is a database runner that logs all SQL statements executed.
68+ type basicLogger struct {
6569 logger LoggerFunc
66- proxy squirrel.DBProxy
70+ runner squirrel.BaseRunner
6771}
6872
69- func defaultLogger (message string , args ... interface {}) {
70- log .Printf ("%s, args: %v" , message , args )
73+ // basicLogger is a database runner that logs all SQL statements executed.
74+ type proxyLogger struct {
75+ basicLogger
7176}
7277
73- func (p * debugProxy ) Exec (query string , args ... interface {}) (sql.Result , error ) {
78+ func (p * basicLogger ) Exec (query string , args ... interface {}) (sql.Result , error ) {
7479 p .logger (fmt .Sprintf ("kallax: Exec: %s" , query ), args ... )
75- return p .proxy .Exec (query , args ... )
80+ return p .runner .Exec (query , args ... )
7681}
7782
78- func (p * debugProxy ) Query (query string , args ... interface {}) (* sql.Rows , error ) {
83+ func (p * basicLogger ) Query (query string , args ... interface {}) (* sql.Rows , error ) {
7984 p .logger (fmt .Sprintf ("kallax: Query: %s" , query ), args ... )
80- return p .proxy .Query (query , args ... )
85+ return p .runner .Query (query , args ... )
8186}
8287
83- func (p * debugProxy ) QueryRow (query string , args ... interface {}) squirrel.RowScanner {
84- p .logger (fmt .Sprintf ("kallax: QueryRow: %s" , query ), args ... )
85- return p .proxy .QueryRow (query , args ... )
88+ func (p * proxyLogger ) QueryRow (query string , args ... interface {}) squirrel.RowScanner {
89+ p .basicLogger .logger (fmt .Sprintf ("kallax: QueryRow: %s" , query ), args ... )
90+ if queryRower , ok := p .basicLogger .runner .(squirrel.QueryRower ); ok {
91+ return queryRower .QueryRow (query , args ... )
92+ } else {
93+ panic ("Called proxyLogger with a runner which doesn't implement QueryRower" )
94+ }
8695}
8796
88- func (p * debugProxy ) Prepare (query string ) (* sql.Stmt , error ) {
89- p .logger (fmt .Sprintf ("kallax: Prepare: %s" , query ))
90- return p .proxy .Prepare (query )
97+ func (p * proxyLogger ) Prepare (query string ) (* sql.Stmt , error ) {
98+ // If chained runner is a proxy, run Prepare(). Otherwise, noop.
99+ if preparer , ok := p .basicLogger .runner .(squirrel.Preparer ); ok {
100+ p .basicLogger .logger (fmt .Sprintf ("kallax: Prepare: %s" , query ))
101+ return preparer .Prepare (query )
102+ } else {
103+ panic ("Called proxyLogger with a runner which doesn't implement QueryRower" )
104+ }
91105}
92106
93107// Store is a structure capable of retrieving records from a concrete table in
94108// the database.
95109type Store struct {
96- builder squirrel.StatementBuilderType
97- db * sql.DB
98- proxy squirrel.DBProxy
110+ db interface {
111+ squirrel.BaseRunner
112+ squirrel.PreparerContext
113+ }
114+ runner squirrel.BaseRunner
115+ useCacher bool
116+ logger LoggerFunc
99117}
100118
101119// NewStore returns a new Store instance.
102120func NewStore (db * sql.DB ) * Store {
103- proxy := squirrel .NewStmtCacher (db )
104- builder := squirrel .StatementBuilder .PlaceholderFormat (squirrel .Dollar ).RunWith (proxy )
105- return & Store {
106- db : db ,
107- proxy : proxy ,
108- builder : builder ,
109- }
121+ return (& Store {
122+ db : db ,
123+ useCacher : true ,
124+ }).init ()
110125}
111126
112- func newStoreWithTransaction (tx * sql.Tx ) * Store {
113- proxy := squirrel .NewStmtCacher (tx )
114- builder := squirrel .StatementBuilder .PlaceholderFormat (squirrel .Dollar ).RunWith (proxy )
115- return & Store {
116- proxy : proxy ,
117- builder : builder ,
127+ // init initializes the store runner with debugging or caching, and returns itself for chainability
128+ func (s * Store ) init () * Store {
129+ s .runner = s .db
130+
131+ if s .useCacher {
132+ s .runner = squirrel .NewStmtCacher (s .db )
133+ }
134+
135+ if s .logger != nil && ! s .useCacher {
136+ // Use BasicLogger as wrapper
137+ s .runner = & basicLogger {s .logger , s .runner }
138+ } else if s .logger != nil && s .useCacher {
139+ // We're using a proxy (cacher), so use proxyLogger instead
140+ s .runner = & proxyLogger {basicLogger {s .logger , s .runner }}
118141 }
142+
143+ return s
119144}
120145
121146// Debug returns a new store that will print all SQL statements to stdout using
@@ -127,12 +152,29 @@ func (s *Store) Debug() *Store {
127152// DebugWith returns a new store that will print all SQL statements using the
128153// given logger function.
129154func (s * Store ) DebugWith (logger LoggerFunc ) * Store {
130- proxy := & debugProxy {logger , s .proxy }
131- return & Store {
132- builder : s .builder .RunWith (proxy ),
133- db : s .db ,
134- proxy : proxy ,
135- }
155+ return (& Store {
156+ db : s .db ,
157+ useCacher : s .useCacher ,
158+ logger : logger ,
159+ }).init ()
160+ }
161+
162+ // DisableCacher turns off prepared statements.
163+ func (s * Store ) DisableCacher () * Store {
164+ return (& Store {
165+ db : s .db ,
166+ logger : s .logger ,
167+ useCacher : false ,
168+ }).init ()
169+ }
170+
171+ // EnableCacher turns on prepared statements. This is the default.
172+ func (s * Store ) EnableCacher () * Store {
173+ return (& Store {
174+ db : s .db ,
175+ logger : s .logger ,
176+ useCacher : true ,
177+ }).init ()
136178}
137179
138180// Insert insert the given record in the table, returns error if no-new
@@ -192,9 +234,20 @@ func (s *Store) Insert(schema Schema, record Record) error {
192234 }
193235
194236 query .WriteString (fmt .Sprintf (" RETURNING %s" , schema .ID ().String ()))
195- err = s .proxy .QueryRow (query .String (), values ... ).Scan (pk )
237+ //err = s.runner.QueryRow(query.String(), values...).Scan(pk)
238+ rows , err := s .runner .Query (query .String (), values ... )
239+ if err != nil {
240+ return err
241+ }
242+ if rows .Next () {
243+ err = rows .Scan (pk )
244+ rows .Close ()
245+ if err != nil {
246+ return err
247+ }
248+ }
196249 } else {
197- _ , err = s .proxy .Exec (query .String (), values ... )
250+ _ , err = s .runner .Exec (query .String (), values ... )
198251 }
199252
200253 if err != nil {
@@ -255,7 +308,7 @@ func (s *Store) Update(schema Schema, record Record, cols ...SchemaField) (int64
255308 query .WriteRune ('=' )
256309 query .WriteString (fmt .Sprintf ("$%d" , len (columnNames )+ 1 ))
257310
258- result , err := s .proxy .Exec (query .String (), append (values , record .GetID ())... )
311+ result , err := s .runner .Exec (query .String (), append (values , record .GetID ())... )
259312 if err != nil {
260313 return 0 , err
261314 }
@@ -300,7 +353,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
300353 query .WriteString (schema .ID ().String ())
301354 query .WriteString ("=$1" )
302355
303- _ , err := s .proxy .Exec (query .String (), record .GetID ())
356+ _ , err := s .runner .Exec (query .String (), record .GetID ())
304357 return err
305358}
306359
@@ -309,7 +362,7 @@ func (s *Store) Delete(schema Schema, record Record) error {
309362// WARNING: A result set created from a raw query can only be scanned using the
310363// RawScan method of ResultSet, instead of Scan.
311364func (s * Store ) RawQuery (sql string , params ... interface {}) (ResultSet , error ) {
312- rows , err := s .proxy .Query (sql , params ... )
365+ rows , err := s .runner .Query (sql , params ... )
313366 if err != nil {
314367 return nil , err
315368 }
@@ -320,7 +373,7 @@ func (s *Store) RawQuery(sql string, params ...interface{}) (ResultSet, error) {
320373// RawExec executes a raw SQL query with the given parameters and returns
321374// the number of affected rows.
322375func (s * Store ) RawExec (sql string , params ... interface {}) (int64 , error ) {
323- result , err := s .proxy .Exec (sql , params ... )
376+ result , err := s .runner .Exec (sql , params ... )
324377 if err != nil {
325378 return 0 , err
326379 }
@@ -332,7 +385,7 @@ func (s *Store) RawExec(sql string, params ...interface{}) (int64, error) {
332385func (s * Store ) Find (q Query ) (ResultSet , error ) {
333386 rels := q .getRelationships ()
334387 if containsRelationshipOfType (rels , OneToMany ) {
335- return NewBatchingResultSet (newBatchQueryRunner (q .Schema (), s .proxy , q )), nil
388+ return NewBatchingResultSet (newBatchQueryRunner (q .Schema (), s .runner , q )), nil
336389 }
337390
338391 columns , builder := q .compile ()
@@ -344,7 +397,7 @@ func (s *Store) Find(q Query) (ResultSet, error) {
344397 builder = builder .Limit (limit )
345398 }
346399
347- rows , err := builder .RunWith (s .proxy ).Query ()
400+ rows , err := builder .RunWith (s .runner ).Query ()
348401 if err != nil {
349402 return nil , err
350403 }
@@ -379,7 +432,7 @@ func (s *Store) Reload(schema Schema, record Record) error {
379432 q .Limit (1 )
380433 columns , builder := q .compile ()
381434
382- rows , err := builder .RunWith (s .proxy ).Query ()
435+ rows , err := builder .RunWith (s .runner ).Query ()
383436 if err != nil {
384437 return err
385438 }
@@ -399,7 +452,7 @@ func (s *Store) Count(q Query) (count int64, err error) {
399452 _ , queryBuilder := q .compile ()
400453 builder := builder .Set (queryBuilder , "Columns" , nil ).(squirrel.SelectBuilder )
401454 err = builder .Column (fmt .Sprintf ("COUNT(%s)" , all .QualifiedName (q .Schema ()))).
402- RunWith (s .proxy ).
455+ RunWith (s .runner ).
403456 QueryRow ().
404457 Scan (& count )
405458 return
@@ -423,16 +476,26 @@ func (s *Store) MustCount(q Query) int64 {
423476// If a transaction is already opened in this store, instead of opening a new
424477// one, the other will be reused.
425478func (s * Store ) Transaction (callback func (* Store ) error ) error {
426- if s .db == nil {
479+ var tx * sql.Tx
480+ var err error
481+ if db , ok := s .db .(* sql.DB ); ok {
482+ // db is *sql.DB, not *sql.Tx
483+ tx , err = db .Begin ()
484+ if err != nil {
485+ return fmt .Errorf ("kallax: can't open transaction: %s" , err )
486+ }
487+ } else {
488+ // store is already holding a transaction
427489 return callback (s )
428490 }
429491
430- tx , err := s .db .Begin ()
431- if err != nil {
432- return fmt .Errorf ("kallax: can't open transaction: %s" , err )
433- }
492+ txStore := (& Store {
493+ db : tx ,
494+ logger : s .logger ,
495+ useCacher : true ,
496+ }).init ()
434497
435- if err := callback (newStoreWithTransaction ( tx ) ); err != nil {
498+ if err := callback (txStore ); err != nil {
436499 if err := tx .Rollback (); err != nil {
437500 return fmt .Errorf ("kallax: unable to rollback transaction: %s" , err )
438501 }
0 commit comments