11package pg
22
33import (
4+ "context"
45 "database/sql"
56 "fmt"
67 "os"
@@ -12,11 +13,13 @@ import (
1213 _ "github.com/jackc/pgx/stdlib"
1314 "github.com/jmoiron/sqlx"
1415 "github.com/stretchr/testify/assert"
16+ "github.com/stretchr/testify/mock"
1517 "github.com/stretchr/testify/require"
18+ "gopkg.in/oauth2.v3/models"
19+
1620 "github.com/vgarvardt/go-pg-adapter"
17- "github.com/vgarvardt/go-pg-adapter/pgxadapter "
21+ "github.com/vgarvardt/go-pg-adapter/pgx3adapter "
1822 "github.com/vgarvardt/go-pg-adapter/sqladapter"
19- "gopkg.in/oauth2.v3/models"
2023)
2124
2225var uri string
@@ -55,59 +58,46 @@ func (l *memoryLogger) Log(level pgx.LogLevel, msg string, data map[string]inter
5558 }{level : level , msg : msg , data : data })
5659}
5760
58- type queryCall struct {
59- query string
60- args []interface {}
61- }
62-
6361type mockAdapter struct {
64- execCalls []queryCall
65- selectOneCalls []queryCall
66-
67- execCallback func (query string , args ... interface {}) error
68- selectCallback func (dst interface {}, query string , args ... interface {}) error
62+ mock.Mock
6963}
7064
71- func (a * mockAdapter ) Exec (query string , args ... interface {}) error {
72- a .execCalls = append (a .execCalls , queryCall {query : query , args : args })
73-
74- if a .execCallback != nil {
75- return a .execCallback (query , args ... )
76- }
77-
78- return nil
65+ func (m * mockAdapter ) Exec (ctx context.Context , query string , args ... interface {}) error {
66+ mArgs := m .Called (ctx , query , args )
67+ return mArgs .Error (0 )
7968}
8069
81- func (a * mockAdapter ) SelectOne (dst interface {}, query string , args ... interface {}) error {
82- a .selectOneCalls = append (a .selectOneCalls , queryCall {query : query , args : args })
83-
84- if a .selectCallback != nil {
85- return a .selectCallback (dst , query , args ... )
86- }
87-
88- return nil
70+ func (m * mockAdapter ) SelectOne (ctx context.Context , dst interface {}, query string , args ... interface {}) error {
71+ mArgs := m .Called (ctx , dst , query , args )
72+ return mArgs .Error (0 )
8973}
9074
9175func TestTokenStore_initTable (t * testing.T ) {
9276 adapter := new (mockAdapter )
9377
78+ adapter .On ("Exec" , mock .Anything , mock .Anything , mock .Anything ).Return (nil ).Run (func (args mock.Arguments ) {
79+ query := args .Get (1 ).(string )
80+ // new line character is the character at position 0
81+ assert .Equal (t , 1 , strings .Index (query , "CREATE TABLE IF NOT EXISTS" ))
82+ })
83+
9484 store , err := NewTokenStore (adapter , WithTokenStoreGCDisabled ())
9585 require .NoError (t , err )
9686
9787 defer func () {
9888 assert .NoError (t , store .Close ())
9989 }()
100-
101- assert .Equal (t , 1 , len (adapter .execCalls ))
102- assert .Equal (t , 0 , len (adapter .selectOneCalls ))
103-
104- // new line character is the character at position 0
105- assert .Equal (t , 1 , strings .Index (adapter .execCalls [0 ].query , "CREATE TABLE IF NOT EXISTS" ))
10690}
10791
10892func TestTokenStore_gc (t * testing.T ) {
10993 adapter := new (mockAdapter )
11094
95+ adapter .On ("Exec" , mock .Anything , mock .Anything , mock .Anything ).Return (nil ).Run (func (args mock.Arguments ) {
96+ query := args .Get (1 ).(string )
97+ // new line character is the character at position 0
98+ assert .Equal (t , 0 , strings .Index (query , "DELETE FROM" ))
99+ })
100+
111101 store , err := NewTokenStore (adapter , WithTokenStoreInitTableDisabled (), WithTokenStoreGCInterval (time .Second ))
112102 require .NoError (t , err )
113103
@@ -118,13 +108,7 @@ func TestTokenStore_gc(t *testing.T) {
118108 time .Sleep (5 * time .Second )
119109
120110 // in 5 seconds we should have 4-5 gc calls
121- assert .True (t , 3 < len (adapter .execCalls ))
122- assert .True (t , 5 >= len (adapter .execCalls ))
123- assert .Equal (t , 0 , len (adapter .selectOneCalls ))
124-
125- for i := range adapter .execCalls {
126- assert .Equal (t , 0 , strings .Index (adapter .execCalls [i ].query , "DELETE FROM" ))
127- }
111+ adapter .AssertNumberOfCalls (t , "Exec" , 4 )
128112}
129113
130114func generateTokenTableName () string {
@@ -150,7 +134,7 @@ func TestPGXConn(t *testing.T) {
150134 assert .NoError (t , pgxConn .Close ())
151135 }()
152136
153- adapter := pgxadapter .NewConn (pgxConn )
137+ adapter := pgx3adapter .NewConn (pgxConn )
154138
155139 tokenStore , err := NewTokenStore (
156140 adapter ,
@@ -189,7 +173,7 @@ func TestPGXConnPool(t *testing.T) {
189173
190174 defer pgXConnPool .Close ()
191175
192- adapter := pgxadapter .NewConnPool (pgXConnPool )
176+ adapter := pgx3adapter .NewConnPool (pgXConnPool )
193177
194178 tokenStore , err := NewTokenStore (
195179 adapter ,
0 commit comments