Skip to content

Commit d7c3316

Browse files
committed
Added client store
1 parent 0c3ca76 commit d7c3316

File tree

4 files changed

+182
-7
lines changed

4 files changed

+182
-7
lines changed

client_store.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,100 @@
11
package pg
22

3+
import (
4+
"fmt"
5+
"github.com/json-iterator/go"
6+
"gopkg.in/oauth2.v3/models"
7+
"log"
8+
"os"
9+
10+
"gopkg.in/oauth2.v3"
11+
)
12+
13+
// ClientStore PostgreSQL client store
14+
type ClientStore struct {
15+
adapter Adapter
16+
tableName string
17+
logger Logger
18+
19+
initTableDisabled bool
20+
}
21+
22+
// ClientStoreItem data item
23+
type ClientStoreItem struct {
24+
ID string `db:"id"`
25+
Secret string `db:"secret"`
26+
Domain string `db:"domain"`
27+
Data []byte `db:"data"`
28+
}
29+
30+
// NewClientStore creates PostgreSQL store instance
31+
func NewClientStore(adapter Adapter, options ...ClientStoreOption) (*ClientStore, error) {
32+
store := &ClientStore{
33+
adapter: adapter,
34+
tableName: "oauth2_clients",
35+
logger: log.New(os.Stderr, "[OAUTH2-PG-ERROR]", log.LstdFlags),
36+
}
37+
38+
for _, o := range options {
39+
o(store)
40+
}
41+
42+
var err error
43+
if !store.initTableDisabled {
44+
err = store.initTable()
45+
}
46+
47+
if err != nil {
48+
return store, err
49+
}
50+
51+
return store, err
52+
}
53+
54+
func (s *ClientStore) initTable() error {
55+
return s.adapter.Exec(fmt.Sprintf(`
56+
CREATE TABLE IF NOT EXISTS %[1]s (
57+
id TEXT NOT NULL,
58+
secret TEXT NOT NULL,
59+
domain TEXT NOT NULL,
60+
data JSONB NOT NULL,
61+
CONSTRAINT %[1]s_pkey PRIMARY KEY (id)
62+
);
63+
`, s.tableName))
64+
}
65+
66+
func (s *ClientStore) toClientInfo(data []byte) (oauth2.ClientInfo, error) {
67+
var cm models.Client
68+
err := jsoniter.Unmarshal(data, &cm)
69+
return &cm, err
70+
}
71+
72+
// GetByID retrieves and returns client information by id
73+
func (s *ClientStore) GetByID(id string) (oauth2.ClientInfo, error) {
74+
if id == "" {
75+
return nil, nil
76+
}
77+
78+
var item ClientStoreItem
79+
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE id = $1", s.tableName), id); err != nil {
80+
return nil, err
81+
}
82+
83+
return s.toClientInfo(item.Data)
84+
}
85+
86+
// Create creates and stores the new client information
87+
func (s *ClientStore) Create(info oauth2.ClientInfo) error {
88+
data, err := jsoniter.Marshal(info)
89+
if err != nil {
90+
return err
91+
}
92+
93+
return s.adapter.Exec(
94+
fmt.Sprintf("INSERT INTO %s (id, secret, domain, data) VALUES ($1, $2, $3, $4)", s.tableName),
95+
info.GetID(),
96+
info.GetSecret(),
97+
info.GetDomain(),
98+
data,
99+
)
100+
}

client_store_options.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
package pg
2+
3+
// ClientStoreOption is the configuration options type for client store
4+
type ClientStoreOption func(s *ClientStore)
5+
6+
// WithClientStoreTableName returns option that sets client store table name
7+
func WithClientStoreTableName(tableName string) ClientStoreOption {
8+
return func(s *ClientStore) {
9+
s.tableName = tableName
10+
}
11+
}
12+
13+
// WithClientStoreLogger returns option that sets client store logger implementation
14+
func WithClientStoreLogger(logger Logger) ClientStoreOption {
15+
return func(s *ClientStore) {
16+
s.logger = logger
17+
}
18+
}
19+
20+
// WithClientStoreInitTableDisabled returns option that disables table creation on client store instantiation
21+
func WithClientStoreInitTableDisabled() ClientStoreOption {
22+
return func(s *ClientStore) {
23+
s.initTableDisabled = true
24+
}
25+
}

client_store_options_test.go

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
package pg
2+
3+
import (
4+
"testing"
5+
"time"
6+
7+
"github.com/stretchr/testify/assert"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestWithClientStoreInitTableDisabled(t *testing.T) {
12+
store, err := NewClientStore(nil, WithClientStoreInitTableDisabled())
13+
require.NoError(t, err)
14+
assert.True(t, store.initTableDisabled)
15+
}
16+
17+
func TestWithClientStoreTableName(t *testing.T) {
18+
randomName := time.Now().String()
19+
20+
store, err := NewClientStore(nil, WithClientStoreTableName(randomName), WithClientStoreInitTableDisabled())
21+
require.NoError(t, err)
22+
assert.Equal(t, randomName, store.tableName)
23+
}
24+
25+
func TestWithClientStoreLogger(t *testing.T) {
26+
l := new(memoryLogger)
27+
28+
store, err := NewClientStore(nil, WithClientStoreLogger(l), WithClientStoreInitTableDisabled())
29+
require.NoError(t, err)
30+
31+
store.logger.Printf("log1", 1, "2", "333")
32+
store.logger.Printf("log2", 12, "22")
33+
34+
require.Equal(t, 2, len(l.formats))
35+
require.Equal(t, 2, len(l.args))
36+
37+
assert.Equal(t, "log1", l.formats[0])
38+
assert.Equal(t, "log2", l.formats[1])
39+
40+
require.Equal(t, 3, len(l.args[0]))
41+
require.Equal(t, 2, len(l.args[1]))
42+
43+
assert.Equal(t, 1, l.args[0][0])
44+
assert.Equal(t, "2", l.args[0][1])
45+
assert.Equal(t, "333", l.args[0][2])
46+
47+
assert.Equal(t, 12, l.args[1][0])
48+
assert.Equal(t, "22", l.args[1][1])
49+
}

token_store.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ func (s *TokenStore) clean() {
107107
}
108108
}
109109

110-
// Create create and store the new token information
110+
// Create creates and stores the new token information
111111
func (s *TokenStore) Create(info oauth2.TokenInfo) error {
112112
buf, err := jsoniter.Marshal(info)
113113
if err != nil {
@@ -143,7 +143,7 @@ func (s *TokenStore) Create(info oauth2.TokenInfo) error {
143143
)
144144
}
145145

146-
// RemoveByCode delete the authorization code
146+
// RemoveByCode deletes the authorization code
147147
func (s *TokenStore) RemoveByCode(code string) error {
148148
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE code = $1", s.tableName), code)
149149
if err == ErrNoRows {
@@ -152,7 +152,7 @@ func (s *TokenStore) RemoveByCode(code string) error {
152152
return err
153153
}
154154

155-
// RemoveByAccess use the access token to delete the token information
155+
// RemoveByAccess uses the access token to delete the token information
156156
func (s *TokenStore) RemoveByAccess(access string) error {
157157
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE access = $1", s.tableName), access)
158158
if err == ErrNoRows {
@@ -161,7 +161,7 @@ func (s *TokenStore) RemoveByAccess(access string) error {
161161
return err
162162
}
163163

164-
// RemoveByRefresh use the refresh token to delete the token information
164+
// RemoveByRefresh uses the refresh token to delete the token information
165165
func (s *TokenStore) RemoveByRefresh(refresh string) error {
166166
err := s.adapter.Exec(fmt.Sprintf("DELETE FROM %s WHERE refresh = $1", s.tableName), refresh)
167167
if err == ErrNoRows {
@@ -176,7 +176,7 @@ func (s *TokenStore) toTokenInfo(data []byte) (oauth2.TokenInfo, error) {
176176
return &tm, err
177177
}
178178

179-
// GetByCode use the authorization code for token information data
179+
// GetByCode uses the authorization code for token information data
180180
func (s *TokenStore) GetByCode(code string) (oauth2.TokenInfo, error) {
181181
if code == "" {
182182
return nil, nil
@@ -186,10 +186,11 @@ func (s *TokenStore) GetByCode(code string) (oauth2.TokenInfo, error) {
186186
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE code = $1", s.tableName), code); err != nil {
187187
return nil, err
188188
}
189+
189190
return s.toTokenInfo(item.Data)
190191
}
191192

192-
// GetByAccess use the access token for token information data
193+
// GetByAccess uses the access token for token information data
193194
func (s *TokenStore) GetByAccess(access string) (oauth2.TokenInfo, error) {
194195
if access == "" {
195196
return nil, nil
@@ -199,10 +200,11 @@ func (s *TokenStore) GetByAccess(access string) (oauth2.TokenInfo, error) {
199200
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE access = $1", s.tableName), access); err != nil {
200201
return nil, err
201202
}
203+
202204
return s.toTokenInfo(item.Data)
203205
}
204206

205-
// GetByRefresh use the refresh token for token information data
207+
// GetByRefresh uses the refresh token for token information data
206208
func (s *TokenStore) GetByRefresh(refresh string) (oauth2.TokenInfo, error) {
207209
if refresh == "" {
208210
return nil, nil
@@ -212,5 +214,6 @@ func (s *TokenStore) GetByRefresh(refresh string) (oauth2.TokenInfo, error) {
212214
if err := s.adapter.SelectOne(&item, fmt.Sprintf("SELECT * FROM %s WHERE refresh = $1", s.tableName), refresh); err != nil {
213215
return nil, err
214216
}
217+
215218
return s.toTokenInfo(item.Data)
216219
}

0 commit comments

Comments
 (0)