11package pg
22
3- import (
4- "errors"
5- "fmt"
6- "log"
7- "os"
8- "time"
9-
10- "github.com/json-iterator/go"
11- "gopkg.in/oauth2.v3"
12- "gopkg.in/oauth2.v3/models"
13- )
3+ import "errors"
144
155// ErrNoRows is the driver-agnostic error returned when no record is found
166var ErrNoRows = errors .New ("sql: no rows in result set" )
@@ -25,207 +15,3 @@ type Adapter interface {
2515type Logger interface {
2616 Printf (format string , v ... interface {})
2717}
28-
29- // Store mysql token store
30- type Store struct {
31- adapter Adapter
32- tableName string
33- logger Logger
34-
35- gcDisabled bool
36- gcInterval time.Duration
37- ticker * time.Ticker
38-
39- initTableDisabled bool
40- }
41-
42- // StoreItem data item
43- type StoreItem struct {
44- ID int64 `db:"id"`
45- CreatedAt time.Time `db:"created_at"`
46- ExpiresAt time.Time `db:"expires_at"`
47- Code string `db:"code"`
48- Access string `db:"access"`
49- Refresh string `db:"refresh"`
50- Data []byte `db:"data"`
51- }
52-
53- // NewStore creates PostgreSQL store instance
54- func NewStore (adapter Adapter , options ... Option ) (* Store , error ) {
55- store := & Store {
56- adapter : adapter ,
57- tableName : "oauth2_token" ,
58- logger : log .New (os .Stderr , "[OAUTH2-PG-ERROR]" , log .LstdFlags ),
59- gcInterval : 10 * time .Minute ,
60- }
61-
62- for _ , o := range options {
63- o (store )
64- }
65-
66- var err error
67- if ! store .initTableDisabled {
68- err = store .initTable ()
69- }
70-
71- if err != nil {
72- return store , err
73- }
74-
75- if ! store .gcDisabled {
76- store .ticker = time .NewTicker (store .gcInterval )
77- go store .gc ()
78- }
79-
80- return store , err
81- }
82-
83- // Close close the store
84- func (s * Store ) Close () error {
85- if ! s .gcDisabled {
86- s .ticker .Stop ()
87- }
88- return nil
89- }
90-
91- func (s * Store ) gc () {
92- for range s .ticker .C {
93- s .clean ()
94- }
95- }
96-
97- func (s * Store ) initTable () error {
98- return s .adapter .Exec (fmt .Sprintf (`
99- CREATE TABLE IF NOT EXISTS %[1]s (
100- id BIGSERIAL NOT NULL,
101- created_at TIMESTAMPTZ NOT NULL,
102- expires_at TIMESTAMPTZ NOT NULL,
103- code TEXT NOT NULL,
104- access TEXT NOT NULL,
105- refresh TEXT NOT NULL,
106- data JSONB NOT NULL,
107- CONSTRAINT %[1]s_pkey PRIMARY KEY (id)
108- );
109-
110- CREATE INDEX IF NOT EXISTS idx_%[1]s_expires_at ON %[1]s (expires_at);
111- CREATE INDEX IF NOT EXISTS idx_%[1]s_code ON %[1]s (code);
112- CREATE INDEX IF NOT EXISTS idx_%[1]s_access ON %[1]s (access);
113- CREATE INDEX IF NOT EXISTS idx_%[1]s_refresh ON %[1]s (refresh);
114- ` , s .tableName ))
115- }
116-
117- func (s * Store ) clean () {
118- now := time .Now ()
119- err := s .adapter .Exec (fmt .Sprintf ("DELETE FROM %s WHERE expires_at <= $1" , s .tableName ), now )
120- if err != nil {
121- s .logger .Printf ("Error while cleaning out outdated entities: %+v" , err )
122- }
123- }
124-
125- // Create create and store the new token information
126- func (s * Store ) Create (info oauth2.TokenInfo ) error {
127- buf , err := jsoniter .Marshal (info )
128- if err != nil {
129- return err
130- }
131-
132- item := & StoreItem {
133- Data : buf ,
134- CreatedAt : time .Now (),
135- }
136-
137- if code := info .GetCode (); code != "" {
138- item .Code = code
139- item .ExpiresAt = info .GetCodeCreateAt ().Add (info .GetCodeExpiresIn ())
140- } else {
141- item .Access = info .GetAccess ()
142- item .ExpiresAt = info .GetAccessCreateAt ().Add (info .GetAccessExpiresIn ())
143-
144- if refresh := info .GetRefresh (); refresh != "" {
145- item .Refresh = info .GetRefresh ()
146- item .ExpiresAt = info .GetRefreshCreateAt ().Add (info .GetRefreshExpiresIn ())
147- }
148- }
149-
150- return s .adapter .Exec (
151- fmt .Sprintf ("INSERT INTO %s (created_at, expires_at, code, access, refresh, data) VALUES ($1, $2, $3, $4, $5, $6)" , s .tableName ),
152- item .CreatedAt ,
153- item .ExpiresAt ,
154- item .Code ,
155- item .Access ,
156- item .Refresh ,
157- item .Data ,
158- )
159- }
160-
161- // RemoveByCode delete the authorization code
162- func (s * Store ) RemoveByCode (code string ) error {
163- err := s .adapter .Exec (fmt .Sprintf ("DELETE FROM %s WHERE code = $1" , s .tableName ), code )
164- if err == ErrNoRows {
165- return nil
166- }
167- return err
168- }
169-
170- // RemoveByAccess use the access token to delete the token information
171- func (s * Store ) RemoveByAccess (access string ) error {
172- err := s .adapter .Exec (fmt .Sprintf ("DELETE FROM %s WHERE access = $1" , s .tableName ), access )
173- if err == ErrNoRows {
174- return nil
175- }
176- return err
177- }
178-
179- // RemoveByRefresh use the refresh token to delete the token information
180- func (s * Store ) RemoveByRefresh (refresh string ) error {
181- err := s .adapter .Exec (fmt .Sprintf ("DELETE FROM %s WHERE refresh = $1" , s .tableName ), refresh )
182- if err == ErrNoRows {
183- return nil
184- }
185- return err
186- }
187-
188- func (s * Store ) toTokenInfo (data []byte ) (oauth2.TokenInfo , error ) {
189- var tm models.Token
190- err := jsoniter .Unmarshal (data , & tm )
191- return & tm , err
192- }
193-
194- // GetByCode use the authorization code for token information data
195- func (s * Store ) GetByCode (code string ) (oauth2.TokenInfo , error ) {
196- if code == "" {
197- return nil , nil
198- }
199-
200- var item StoreItem
201- if err := s .adapter .SelectOne (& item , fmt .Sprintf ("SELECT * FROM %s WHERE code = $1" , s .tableName ), code ); err != nil {
202- return nil , err
203- }
204- return s .toTokenInfo (item .Data )
205- }
206-
207- // GetByAccess use the access token for token information data
208- func (s * Store ) GetByAccess (access string ) (oauth2.TokenInfo , error ) {
209- if access == "" {
210- return nil , nil
211- }
212-
213- var item StoreItem
214- if err := s .adapter .SelectOne (& item , fmt .Sprintf ("SELECT * FROM %s WHERE access = $1" , s .tableName ), access ); err != nil {
215- return nil , err
216- }
217- return s .toTokenInfo (item .Data )
218- }
219-
220- // GetByRefresh use the refresh token for token information data
221- func (s * Store ) GetByRefresh (refresh string ) (oauth2.TokenInfo , error ) {
222- if refresh == "" {
223- return nil , nil
224- }
225-
226- var item StoreItem
227- if err := s .adapter .SelectOne (& item , fmt .Sprintf ("SELECT * FROM %s WHERE refresh = $1" , s .tableName ), refresh ); err != nil {
228- return nil , err
229- }
230- return s .toTokenInfo (item .Data )
231- }
0 commit comments