Skip to content

Commit 0409aac

Browse files
authored
Merge pull request #10 from vgarvardt/patch/pgx-v4-support
Updated adapter to support pgx v4
2 parents 5a35396 + 8aef2a7 commit 0409aac

File tree

8 files changed

+207
-80
lines changed

8 files changed

+207
-80
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@ language: go
22
sudo: false
33

44
go:
5-
- "1.11"
65
- "1.12"
6+
- "1.13"
77
- "stable"
88

99
services:

README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,23 @@ The store accepts an adapter interface that interacts with the DB. Adapter and i
1818
package main
1919

2020
import (
21+
"context"
2122
"os"
2223
"time"
2324

24-
"github.com/jackc/pgx"
25+
"github.com/jackc/pgx/v4"
2526
pg "github.com/vgarvardt/go-oauth2-pg"
26-
"github.com/vgarvardt/go-pg-adapter/pgxadapter"
27+
"github.com/vgarvardt/go-pg-adapter/pgx4adapter"
2728
"gopkg.in/oauth2.v3/manage"
2829
)
2930

3031
func main() {
31-
pgxConnConfig, _ := pgx.ParseURI(os.Getenv("DB_URI"))
32-
pgxConn, _ := pgx.Connect(pgxConnConfig)
32+
pgxConn, _ := pgx.Connect(context.TODO(), os.Getenv("DB_URI"))
3333

3434
manager := manage.NewDefaultManager()
3535

3636
// use PostgreSQL token store with pgx.Connection adapter
37-
adapter := pgxadapter.NewConn(pgxConn)
37+
adapter := pgx4adapter.NewConn(pgxConn)
3838
tokenStore, _ := pg.NewTokenStore(adapter, pg.WithTokenStoreGCInterval(time.Minute))
3939
defer tokenStore.Close()
4040

client_store.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
package pg
22

33
import (
4+
"context"
45
"fmt"
56
"log"
67
"os"
78

89
"github.com/json-iterator/go"
9-
"github.com/vgarvardt/go-pg-adapter"
1010
"gopkg.in/oauth2.v3"
1111
"gopkg.in/oauth2.v3/models"
12+
13+
"github.com/vgarvardt/go-pg-adapter"
1214
)
1315

1416
// ClientStore PostgreSQL client store
@@ -53,7 +55,7 @@ func NewClientStore(adapter pgadapter.Adapter, options ...ClientStoreOption) (*C
5355
}
5456

5557
func (s *ClientStore) initTable() error {
56-
return s.adapter.Exec(fmt.Sprintf(`
58+
return s.adapter.Exec(context.Background(), fmt.Sprintf(`
5759
CREATE TABLE IF NOT EXISTS %[1]s (
5860
id TEXT NOT NULL,
5961
secret TEXT NOT NULL,
@@ -77,7 +79,7 @@ func (s *ClientStore) GetByID(id string) (oauth2.ClientInfo, error) {
7779
}
7880

7981
var item ClientStoreItem
80-
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE id = $1", s.tableName), id); err != nil {
82+
if err := s.adapter.SelectOne(context.Background(), &item, fmt.Sprintf("SELECT * FROM %s WHERE id = $1", s.tableName), id); err != nil {
8183
return nil, err
8284
}
8385

@@ -92,6 +94,7 @@ func (s *ClientStore) Create(info oauth2.ClientInfo) error {
9294
}
9395

9496
return s.adapter.Exec(
97+
context.Background(),
9598
fmt.Sprintf("INSERT INTO %s (id, secret, domain, data) VALUES ($1, $2, $3, $4)", s.tableName),
9699
info.GetID(),
97100
info.GetSecret(),

client_store_test.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,21 @@ import (
55
"testing"
66

77
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/mock"
89
"github.com/stretchr/testify/require"
910
)
1011

1112
func TestClientStore_initTable(t *testing.T) {
1213
adapter := new(mockAdapter)
1314

15+
adapter.On("Exec", mock.Anything, mock.Anything, mock.Anything).Return(nil).Run(func(args mock.Arguments) {
16+
query := args.Get(1).(string)
17+
// new line character is the character at position 0
18+
assert.Equal(t, 1, strings.Index(query, "CREATE TABLE IF NOT EXISTS"))
19+
})
20+
1421
_, err := NewClientStore(adapter)
1522
require.NoError(t, err)
1623

17-
assert.Equal(t, 1, len(adapter.execCalls))
18-
assert.Equal(t, 0, len(adapter.selectOneCalls))
19-
20-
// new line character is the character at position 0
21-
assert.Equal(t, 1, strings.Index(adapter.execCalls[0].query, "CREATE TABLE IF NOT EXISTS"))
24+
adapter.AssertExpectations(t)
2225
}

go.mod

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ module github.com/vgarvardt/go-oauth2-pg
33
go 1.12
44

55
require (
6-
github.com/jackc/pgx v3.6.0+incompatible
6+
github.com/jackc/pgx/v4 v4.4.1
77
github.com/jmoiron/sqlx v1.2.0
88
github.com/json-iterator/go v1.1.7
99
github.com/stretchr/testify v1.4.0
10-
github.com/vgarvardt/go-pg-adapter v0.3.0
10+
github.com/vgarvardt/go-pg-adapter v0.4.1
1111
gopkg.in/oauth2.v3 v3.10.1
1212
)

go.sum

Lines changed: 129 additions & 0 deletions
Large diffs are not rendered by default.

token_store.go

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
package pg
22

33
import (
4+
"context"
45
"fmt"
56
"log"
67
"os"
78
"time"
89

910
"github.com/json-iterator/go"
10-
"github.com/vgarvardt/go-pg-adapter"
1111
"gopkg.in/oauth2.v3"
1212
"gopkg.in/oauth2.v3/models"
13+
14+
"github.com/vgarvardt/go-pg-adapter"
1315
)
1416

1517
// TokenStore PostgreSQL token store
@@ -81,7 +83,7 @@ func (s *TokenStore) gc() {
8183
}
8284

8385
func (s *TokenStore) initTable() error {
84-
return s.adapter.Exec(fmt.Sprintf(`
86+
return s.adapter.Exec(context.Background(), fmt.Sprintf(`
8587
CREATE TABLE IF NOT EXISTS %[1]s (
8688
id BIGSERIAL NOT NULL,
8789
created_at TIMESTAMPTZ NOT NULL,
@@ -102,7 +104,7 @@ CREATE INDEX IF NOT EXISTS idx_%[1]s_refresh ON %[1]s (refresh);
102104

103105
func (s *TokenStore) clean() {
104106
now := time.Now()
105-
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE expires_at <= $1", s.tableName), now)
107+
err := s.adapter.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE expires_at <= $1", s.tableName), now)
106108
if err != nil {
107109
s.logger.Printf("Error while cleaning out outdated entities: %+v", err)
108110
}
@@ -134,6 +136,7 @@ func (s *TokenStore) Create(info oauth2.TokenInfo) error {
134136
}
135137

136138
return s.adapter.Exec(
139+
context.Background(),
137140
fmt.Sprintf("INSERT INTO %s (created_at, expires_at, code, access, refresh, data) VALUES ($1, $2, $3, $4, $5, $6)", s.tableName),
138141
item.CreatedAt,
139142
item.ExpiresAt,
@@ -146,7 +149,7 @@ func (s *TokenStore) Create(info oauth2.TokenInfo) error {
146149

147150
// RemoveByCode deletes the authorization code
148151
func (s *TokenStore) RemoveByCode(code string) error {
149-
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE code = $1", s.tableName), code)
152+
err := s.adapter.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE code = $1", s.tableName), code)
150153
if err == pgadapter.ErrNoRows {
151154
return nil
152155
}
@@ -155,7 +158,7 @@ func (s *TokenStore) RemoveByCode(code string) error {
155158

156159
// RemoveByAccess uses the access token to delete the token information
157160
func (s *TokenStore) RemoveByAccess(access string) error {
158-
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE access = $1", s.tableName), access)
161+
err := s.adapter.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE access = $1", s.tableName), access)
159162
if err == pgadapter.ErrNoRows {
160163
return nil
161164
}
@@ -164,7 +167,7 @@ func (s *TokenStore) RemoveByAccess(access string) error {
164167

165168
// RemoveByRefresh uses the refresh token to delete the token information
166169
func (s *TokenStore) RemoveByRefresh(refresh string) error {
167-
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE refresh = $1", s.tableName), refresh)
170+
err := s.adapter.Exec(context.Background(), fmt.Sprintf("DELETE FROM %s WHERE refresh = $1", s.tableName), refresh)
168171
if err == pgadapter.ErrNoRows {
169172
return nil
170173
}
@@ -184,7 +187,7 @@ func (s *TokenStore) GetByCode(code string) (oauth2.TokenInfo, error) {
184187
}
185188

186189
var item TokenStoreItem
187-
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE code = $1", s.tableName), code); err != nil {
190+
if err := s.adapter.SelectOne(context.Background(), &item, fmt.Sprintf("SELECT * FROM %s WHERE code = $1", s.tableName), code); err != nil {
188191
return nil, err
189192
}
190193

@@ -198,7 +201,7 @@ func (s *TokenStore) GetByAccess(access string) (oauth2.TokenInfo, error) {
198201
}
199202

200203
var item TokenStoreItem
201-
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE access = $1", s.tableName), access); err != nil {
204+
if err := s.adapter.SelectOne(context.Background(), &item, fmt.Sprintf("SELECT * FROM %s WHERE access = $1", s.tableName), access); err != nil {
202205
return nil, err
203206
}
204207

@@ -212,7 +215,7 @@ func (s *TokenStore) GetByRefresh(refresh string) (oauth2.TokenInfo, error) {
212215
}
213216

214217
var item TokenStoreItem
215-
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE refresh = $1", s.tableName), refresh); err != nil {
218+
if err := s.adapter.SelectOne(context.Background(), &item, fmt.Sprintf("SELECT * FROM %s WHERE refresh = $1", s.tableName), refresh); err != nil {
216219
return nil, err
217220
}
218221

0 commit comments

Comments
 (0)