Skip to content

Commit 6fff47f

Browse files
committed
Updated adapter to support pgx v4
1 parent 5a35396 commit 6fff47f

File tree

7 files changed

+171
-63
lines changed

7 files changed

+171
-63
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ language: go
22
sudo: false
33

44
go:
5-
- "1.11"
65
- "1.12"
6+
- "1.13"
77
- "stable"
88

99
services:

client_store.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
package pg
22

33
import (
4+
"context"
45
"fmt"
56
"log"
67
"os"
78

89
"github.com/json-iterator/go"
9-
"github.com/vgarvardt/go-pg-adapter"
1010
"gopkg.in/oauth2.v3"
1111
"gopkg.in/oauth2.v3/models"
12+
13+
"github.com/vgarvardt/go-pg-adapter"
1214
)
1315

1416
// ClientStore PostgreSQL client store
@@ -53,7 +55,7 @@ func NewClientStore(adapter pgadapter.Adapter, options ...ClientStoreOption) (*C
5355
}
5456

5557
func (s *ClientStore) initTable() error {
56-
return s.adapter.Exec(fmt.Sprintf(`
58+
return s.adapter.Exec(context.Background(), fmt.Sprintf(`
5759
CREATE TABLE IF NOT EXISTS %[1]s (
5860
id TEXT NOT NULL,
5961
secret TEXT NOT NULL,
@@ -77,7 +79,7 @@ func (s *ClientStore) GetByID(id string) (oauth2.ClientInfo, error) {
7779
}
7880

7981
var item ClientStoreItem
80-
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE id = $1", s.tableName), id); err != nil {
82+
if err := s.adapter.SelectOne(context.Background(), &item, fmt.Sprintf("SELECT * FROM %s WHERE id = $1", s.tableName), id); err != nil {
8183
return nil, err
8284
}
8385

@@ -92,6 +94,7 @@ func (s *ClientStore) Create(info oauth2.ClientInfo) error {
9294
}
9395

9496
return s.adapter.Exec(
97+
context.Background(),
9598
fmt.Sprintf("INSERT INTO %s (id, secret, domain, data) VALUES ($1, $2, $3, $4)", s.tableName),
9699
info.GetID(),
97100
info.GetSecret(),

client_store_test.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,21 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/mock"
89
"github.com/stretchr/testify/require"
910
)
1011

1112
func TestClientStore_initTable(t *testing.T) {
1213
adapter := new(mockAdapter)
1314

15+
adapter.On("Exec", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
16+
query := args.Get(1).(string)
17+
// new line character is the character at position 0
18+
assert.Equal(t, 1, strings.Index(query, "CREATE TABLE IF NOT EXISTS"))
19+
})
20+
1421
_, err := NewClientStore(adapter)
1522
require.NoError(t, err)
1623

17-
assert.Equal(t, 1, len(adapter.execCalls))
18-
assert.Equal(t, 0, len(adapter.selectOneCalls))
19-
20-
// new line character is the character at position 0
21-
assert.Equal(t, 1, strings.Index(adapter.execCalls[0].query, "CREATE TABLE IF NOT EXISTS"))
24+
adapter.AssertExpectations(t)
2225
}

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ module github.com/vgarvardt/go-oauth2-pg
33
go 1.12
44

55
require (
6-
github.com/jackc/pgx v3.6.0+incompatible
6+
github.com/jackc/pgx v3.6.2+incompatible
77
github.com/jmoiron/sqlx v1.2.0
88
github.com/json-iterator/go v1.1.7
99
github.com/stretchr/testify v1.4.0
10-
github.com/vgarvardt/go-pg-adapter v0.3.0
10+
github.com/vgarvardt/go-pg-adapter v0.4.0
1111
gopkg.in/oauth2.v3 v3.10.1
1212
)

go.sum

Lines changed: 115 additions & 0 deletions
Large diffs are not rendered by default.

token_store.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
package pg
22

33
import (
4+
"context"
45
"fmt"
56
"log"
67
"os"
78
"time"
89

910
"github.com/json-iterator/go"
10-
"github.com/vgarvardt/go-pg-adapter"
1111
"gopkg.in/oauth2.v3"
1212
"gopkg.in/oauth2.v3/models"
13+
14+
"github.com/vgarvardt/go-pg-adapter"
1315
)
1416

1517
// TokenStore PostgreSQL token store
@@ -81,7 +83,7 @@ func (s *TokenStore) gc() {
8183
}
8284

8385
func (s *TokenStore) initTable() error {
84-
return s.adapter.Exec(fmt.Sprintf(`
86+
return s.adapter.Exec(context.Background(), fmt.Sprintf(`
8587
CREATE TABLE IF NOT EXISTS %[1]s (
8688
id BIGSERIAL NOT NULL,
8789
created_at TIMESTAMPTZ NOT NULL,
@@ -102,7 +104,7 @@ CREATE INDEX IF NOT EXISTS idx_%[1]s_refresh ON %[1]s (refresh);
102104

103105
func (s *TokenStore) clean() {
104106
now := time.Now()
105-
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE expires_at <= $1", s.tableName), now)
107+
err := s.adapter.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE expires_at <= $1", s.tableName), now)
106108
if err != nil {
107109
s.logger.Printf("Error while cleaning out outdated entities: %+v", err)
108110
}
@@ -134,6 +136,7 @@ func (s *TokenStore) Create(info oauth2.TokenInfo) error {
134136
}
135137

136138
return s.adapter.Exec(
139+
context.Background(),
137140
fmt.Sprintf("INSERT INTO %s (created_at, expires_at, code, access, refresh, data) VALUES ($1, $2, $3, $4, $5, $6)", s.tableName),
138141
item.CreatedAt,
139142
item.ExpiresAt,
@@ -146,7 +149,7 @@ func (s *TokenStore) Create(info oauth2.TokenInfo) error {
146149

147150
// RemoveByCode deletes the authorization code
148151
func (s *TokenStore) RemoveByCode(code string) error {
149-
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE code = $1", s.tableName), code)
152+
err := s.adapter.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE code = $1", s.tableName), code)
150153
if err == pgadapter.ErrNoRows {
151154
return nil
152155
}
@@ -155,7 +158,7 @@ func (s *TokenStore) RemoveByCode(code string) error {
155158

156159
// RemoveByAccess uses the access token to delete the token information
157160
func (s *TokenStore) RemoveByAccess(access string) error {
158-
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE access = $1", s.tableName), access)
161+
err := s.adapter.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE access = $1", s.tableName), access)
159162
if err == pgadapter.ErrNoRows {
160163
return nil
161164
}
@@ -164,7 +167,7 @@ func (s *TokenStore) RemoveByAccess(access string) error {
164167

165168
// RemoveByRefresh uses the refresh token to delete the token information
166169
func (s *TokenStore) RemoveByRefresh(refresh string) error {
167-
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE refresh = $1", s.tableName), refresh)
170+
err := s.adapter.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE refresh = $1", s.tableName), refresh)
168171
if err == pgadapter.ErrNoRows {
169172
return nil
170173
}
@@ -184,7 +187,7 @@ func (s *TokenStore) GetByCode(code string) (oauth2.TokenInfo, error) {
184187
}
185188

186189
var item TokenStoreItem
187-
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE code = $1", s.tableName), code); err != nil {
190+
if err := s.adapter.SelectOne(context.Background(), &item, fmt.Sprintf("SELECT * FROM %s WHERE code = $1", s.tableName), code); err != nil {
188191
return nil, err
189192
}
190193

@@ -198,7 +201,7 @@ func (s *TokenStore) GetByAccess(access string) (oauth2.TokenInfo, error) {
198201
}
199202

200203
var item TokenStoreItem
201-
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE access = $1", s.tableName), access); err != nil {
204+
if err := s.adapter.SelectOne(context.Background(), &item, fmt.Sprintf("SELECT * FROM %s WHERE access = $1", s.tableName), access); err != nil {
202205
return nil, err
203206
}
204207

@@ -212,7 +215,7 @@ func (s *TokenStore) GetByRefresh(refresh string) (oauth2.TokenInfo, error) {
212215
}
213216

214217
var item TokenStoreItem
215-
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE refresh = $1", s.tableName), refresh); err != nil {
218+
if err := s.adapter.SelectOne(context.Background(), &item, fmt.Sprintf("SELECT * FROM %s WHERE refresh = $1", s.tableName), refresh); err != nil {
216219
return nil, err
217220
}
218221

token_store_test.go

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package pg
22

33
import (
4+
"context"
45
"database/sql"
56
"fmt"
67
"os"
@@ -12,11 +13,13 @@ import (
1213
_ "github.com/jackc/pgx/stdlib"
1314
"github.com/jmoiron/sqlx"
1415
"github.com/stretchr/testify/assert"
16+
"github.com/stretchr/testify/mock"
1517
"github.com/stretchr/testify/require"
18+
"gopkg.in/oauth2.v3/models"
19+
1620
"github.com/vgarvardt/go-pg-adapter"
17-
"github.com/vgarvardt/go-pg-adapter/pgxadapter"
21+
"github.com/vgarvardt/go-pg-adapter/pgx3adapter"
1822
"github.com/vgarvardt/go-pg-adapter/sqladapter"
19-
"gopkg.in/oauth2.v3/models"
2023
)
2124

2225
var uri string
@@ -55,59 +58,46 @@ func (l *memoryLogger) Log(level pgx.LogLevel, msg string, data map[string]inter
5558
}{level: level, msg: msg, data: data})
5659
}
5760

58-
type queryCall struct {
59-
query string
60-
args []interface{}
61-
}
62-
6361
type mockAdapter struct {
64-
execCalls []queryCall
65-
selectOneCalls []queryCall
66-
67-
execCallback func(query string, args ...interface{}) error
68-
selectCallback func(dst interface{}, query string, args ...interface{}) error
62+
mock.Mock
6963
}
7064

71-
func (a *mockAdapter) Exec(query string, args ...interface{}) error {
72-
a.execCalls = append(a.execCalls, queryCall{query: query, args: args})
73-
74-
if a.execCallback != nil {
75-
return a.execCallback(query, args...)
76-
}
77-
78-
return nil
65+
func (m *mockAdapter) Exec(ctx context.Context, query string, args ...interface{}) error {
66+
mArgs := m.Called(ctx, query, args)
67+
return mArgs.Error(0)
7968
}
8069

81-
func (a *mockAdapter) SelectOne(dst interface{}, query string, args ...interface{}) error {
82-
a.selectOneCalls = append(a.selectOneCalls, queryCall{query: query, args: args})
83-
84-
if a.selectCallback != nil {
85-
return a.selectCallback(dst, query, args...)
86-
}
87-
88-
return nil
70+
func (m *mockAdapter) SelectOne(ctx context.Context, dst interface{}, query string, args ...interface{}) error {
71+
mArgs := m.Called(ctx, dst, query, args)
72+
return mArgs.Error(0)
8973
}
9074

9175
func TestTokenStore_initTable(t *testing.T) {
9276
adapter := new(mockAdapter)
9377

78+
adapter.On("Exec", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
79+
query := args.Get(1).(string)
80+
// new line character is the character at position 0
81+
assert.Equal(t, 1, strings.Index(query, "CREATE TABLE IF NOT EXISTS"))
82+
})
83+
9484
store, err := NewTokenStore(adapter, WithTokenStoreGCDisabled())
9585
require.NoError(t, err)
9686

9787
defer func() {
9888
assert.NoError(t, store.Close())
9989
}()
100-
101-
assert.Equal(t, 1, len(adapter.execCalls))
102-
assert.Equal(t, 0, len(adapter.selectOneCalls))
103-
104-
// new line character is the character at position 0
105-
assert.Equal(t, 1, strings.Index(adapter.execCalls[0].query, "CREATE TABLE IF NOT EXISTS"))
10690
}
10791

10892
func TestTokenStore_gc(t *testing.T) {
10993
adapter := new(mockAdapter)
11094

95+
adapter.On("Exec", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
96+
query := args.Get(1).(string)
97+
// new line character is the character at position 0
98+
assert.Equal(t, 0, strings.Index(query, "DELETE FROM"))
99+
})
100+
111101
store, err := NewTokenStore(adapter, WithTokenStoreInitTableDisabled(), WithTokenStoreGCInterval(time.Second))
112102
require.NoError(t, err)
113103

@@ -118,13 +108,7 @@ func TestTokenStore_gc(t *testing.T) {
118108
time.Sleep(5 * time.Second)
119109

120110
// in 5 seconds we should have 4-5 gc calls
121-
assert.True(t, 3 < len(adapter.execCalls))
122-
assert.True(t, 5 >= len(adapter.execCalls))
123-
assert.Equal(t, 0, len(adapter.selectOneCalls))
124-
125-
for i := range adapter.execCalls {
126-
assert.Equal(t, 0, strings.Index(adapter.execCalls[i].query, "DELETE FROM"))
127-
}
111+
adapter.AssertNumberOfCalls(t, "Exec", 4)
128112
}
129113

130114
func generateTokenTableName() string {
@@ -150,7 +134,7 @@ func TestPGXConn(t *testing.T) {
150134
assert.NoError(t, pgxConn.Close())
151135
}()
152136

153-
adapter := pgxadapter.NewConn(pgxConn)
137+
adapter := pgx3adapter.NewConn(pgxConn)
154138

155139
tokenStore, err := NewTokenStore(
156140
adapter,
@@ -189,7 +173,7 @@ func TestPGXConnPool(t *testing.T) {
189173

190174
defer pgXConnPool.Close()
191175

192-
adapter := pgxadapter.NewConnPool(pgXConnPool)
176+
adapter := pgx3adapter.NewConnPool(pgXConnPool)
193177

194178
tokenStore, err := NewTokenStore(
195179
adapter,

0 commit comments

Comments
 (0)