Skip to content

Commit ad5265b

Browse files
committed
Add Prune() for pruning TTL'd sessions in Postgres store.
1 parent c5ca812 commit ad5265b

File tree

2 files changed

+70
-4
lines changed

2 files changed

+70
-4
lines changed

stores/postgres/postgres.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ type queries struct {
5050
update *sql.Stmt
5151
delete *sql.Stmt
5252
clear *sql.Stmt
53+
prune *sql.Stmt
5354
}
5455

5556
// Store represents redis session store for simple sessions.
@@ -128,6 +129,9 @@ func (s *Store) Get(id, key string) (interface{}, error) {
128129
// preserving the types.
129130
var b []byte
130131
if err := s.q.get.QueryRow(id, s.opt.TTL.Seconds()).Scan(&b); err != nil {
132+
if err == sql.ErrNoRows {
133+
return nil, ErrInvalidSession
134+
}
131135
return nil, err
132136
}
133137

@@ -324,6 +328,13 @@ func (s *Store) Clear(id string) error {
324328
return nil
325329
}
326330

331+
// Prune deletes rows that have exceeded the TTL. This should be run externally periodically (ideally as a separate goroutine)
332+
// at desired intervals, hourly/daily etc. based on the expected volume of sessions.
333+
func (s *Store) Prune() error {
334+
_, err := s.q.prune.Exec(s.opt.TTL.Seconds())
335+
return err
336+
}
337+
327338
func (s *Store) prepareQueries() (*queries, error) {
328339
var (
329340
q = &queries{}
@@ -355,6 +366,11 @@ func (s *Store) prepareQueries() (*queries, error) {
355366
return nil, err
356367
}
357368

369+
q.prune, err = s.db.Prepare(fmt.Sprintf("DELETE FROM %s WHERE created_at <= NOW() - INTERVAL '1 second' * $1", s.opt.Table))
370+
if err != nil {
371+
return nil, err
372+
}
373+
358374
return q, err
359375
}
360376

stores/postgres/postgres_test.go

Lines changed: 54 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,17 @@ import (
88
"log"
99
"os"
1010
"testing"
11+
"time"
1112

1213
_ "github.com/lib/pq"
1314
"github.com/stretchr/testify/assert"
1415
)
1516

17+
const testTable = "sessions"
18+
1619
var (
1720
st *Store
21+
db *sql.DB
1822
randID, _ = generateID(sessionIDLen)
1923
)
2024

@@ -26,18 +30,20 @@ func init() {
2630

2731
p := fmt.Sprintf("host=%s port=%s user=%s password=%s dbname=%s sslmode=disable",
2832
os.Getenv("PG_HOST"), os.Getenv("PG_PORT"), os.Getenv("PG_USER"), os.Getenv("PG_PASSWORD"), os.Getenv("PG_DB"))
29-
db, err := sql.Open("postgres", p)
30-
if err != nil {
33+
if d, err := sql.Open("postgres", p); err != nil {
3134
log.Fatal(err)
35+
} else {
36+
db = d
3237
}
3338

3439
if err := db.Ping(); err != nil {
3540
log.Fatal(err)
3641
}
3742

38-
st, err = New(Opt{}, db)
39-
if err != nil {
43+
if s, err := New(Opt{TTL: time.Second * 2, Table: testTable}, db); err != nil {
4044
log.Fatal(err)
45+
} else {
46+
st = s
4147
}
4248
}
4349

@@ -119,3 +125,47 @@ func TestSet(t *testing.T) {
119125
v, err = st.Get(id, "str")
120126
assert.Error(t, err, ErrFieldNotFound)
121127
}
128+
129+
func TestPrune(t *testing.T) {
130+
// Create a new session.
131+
id, err := st.Create()
132+
assert.NoError(t, err)
133+
assert.NotEmpty(t, id)
134+
135+
// Set value.
136+
assert.NoError(t, st.Set(id, "str", "hello 123"))
137+
assert.NoError(t, st.Commit(id))
138+
139+
// Get value and verify.
140+
v, err := st.Get(id, "str")
141+
assert.NoError(t, err)
142+
assert.Equal(t, v, "hello 123")
143+
144+
// Wait until the 2 sec TTL expires and run prune.
145+
time.Sleep(time.Second * 3)
146+
147+
// Session shouldn't be returned.
148+
_, err = st.Get(id, "str")
149+
assert.ErrorIs(t, err, ErrInvalidSession)
150+
151+
// Create one more session and immediately run prune. Except for this,
152+
// all previous sessions should be gone.
153+
id, err = st.Create()
154+
assert.NoError(t, err)
155+
assert.NoError(t, st.Set(id, "str", "hello 123"))
156+
assert.NoError(t, st.Commit(id))
157+
158+
// Run prune. All previously created sessions should be gone.
159+
assert.NoError(t, st.Prune())
160+
161+
var num int
162+
err = db.QueryRow(fmt.Sprintf("SELECT COUNT(*) FROM %s", testTable)).Scan(&num)
163+
assert.NoError(t, err)
164+
assert.Equal(t, num, 1)
165+
166+
// The last created session shouldn't have been pruned.
167+
v, err = st.Get(id, "str")
168+
assert.NoError(t, err)
169+
assert.Equal(t, v, "hello 123")
170+
171+
}

0 commit comments

Comments
 (0)