Skip to content

Commit 2f8f4aa

Browse files
committed
feat: Add pluggable storage backend for session management
Refactors the session management system to use a pluggable storage interface, enabling future support for distributed storage backends like Redis/Valkey while maintaining backward compatibility. What Changed - Introduced a Storage interface that abstracts session persistence - Refactored Manager to use the Storage interface instead of directly using sync.Map - Created LocalStorage implementation that maintains the existing in-memory behavior - Added JSON serialization support for sessions to enable future network storage - Extended Session interface with Type() and metadata methods that were already implemented in concrete types Why The previous implementation was tightly coupled to in-memory storage, making it impossible to share sessions across multiple ToolHive instances. This refactoring enables: - Horizontal scaling with shared session state - Session persistence across restarts - Future Redis/Valkey backend support without breaking changes Testing Added comprehensive unit tests covering: - LocalStorage implementation - Session serialization/deserialization - Manager with pluggable storage - All existing session types (ProxySession, SSESession, StreamableSession) All tests pass and the implementation maintains full backward compatibility. Signed-off-by: Juan Antonio Osorio <[email protected]>
1 parent f796fda commit 2f8f4aa

File tree

6 files changed

+1071
-53
lines changed

6 files changed

+1071
-53
lines changed

pkg/transport/session/manager.go

Lines changed: 64 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -2,25 +2,32 @@
22
package session
33

44
import (
5+
"context"
56
"fmt"
6-
"sync"
77
"time"
88
)
99

10-
// Session interface
10+
// Session interface defines the contract for all session types
1111
type Session interface {
1212
ID() string
13+
Type() SessionType
1314
CreatedAt() time.Time
1415
UpdatedAt() time.Time
1516
Touch()
17+
18+
// Data and metadata methods
19+
GetData() interface{}
20+
SetData(data interface{})
21+
GetMetadata() map[string]string
22+
SetMetadata(key, value string)
1623
}
1724

1825
// Manager holds sessions with TTL cleanup.
1926
type Manager struct {
20-
sessions sync.Map
21-
ttl time.Duration
22-
stopCh chan struct{}
23-
factory Factory
27+
storage Storage
28+
ttl time.Duration
29+
stopCh chan struct{}
30+
factory Factory
2431
}
2532

2633
// Factory defines a function type for creating new sessions.
@@ -56,10 +63,10 @@ func NewManager(ttl time.Duration, factory interface{}) *Manager {
5663
}
5764

5865
m := &Manager{
59-
sessions: sync.Map{},
60-
ttl: ttl,
61-
stopCh: make(chan struct{}),
62-
factory: f,
66+
storage: NewLocalStorage(),
67+
ttl: ttl,
68+
stopCh: make(chan struct{}),
69+
factory: f,
6370
}
6471
go m.cleanupRoutine()
6572
return m
@@ -83,24 +90,28 @@ func NewTypedManager(ttl time.Duration, sessionType SessionType) *Manager {
8390
return NewManager(ttl, factory)
8491
}
8592

93+
// NewManagerWithStorage creates a session manager with a custom storage backend.
94+
func NewManagerWithStorage(ttl time.Duration, factory Factory, storage Storage) *Manager {
95+
m := &Manager{
96+
storage: storage,
97+
ttl: ttl,
98+
stopCh: make(chan struct{}),
99+
factory: factory,
100+
}
101+
go m.cleanupRoutine()
102+
return m
103+
}
104+
86105
func (m *Manager) cleanupRoutine() {
87106
ticker := time.NewTicker(m.ttl / 2)
88107
defer ticker.Stop()
89108
for {
90109
select {
91110
case <-ticker.C:
92111
cutoff := time.Now().Add(-m.ttl)
93-
m.sessions.Range(func(key, val any) bool {
94-
sess, ok := val.(Session)
95-
if !ok {
96-
// Skip invalid value
97-
return true
98-
}
99-
if sess.UpdatedAt().Before(cutoff) {
100-
m.sessions.Delete(key)
101-
}
102-
return true
103-
})
112+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
113+
_ = m.storage.DeleteExpired(ctx, cutoff)
114+
cancel()
104115
case <-m.stopCh:
105116
return
106117
}
@@ -113,13 +124,15 @@ func (m *Manager) AddWithID(id string) error {
113124
if id == "" {
114125
return fmt.Errorf("session ID cannot be empty")
115126
}
116-
// Use LoadOrStore: returns existing if already present
117-
session := m.factory(id)
118-
_, loaded := m.sessions.LoadOrStore(id, session)
119-
if loaded {
127+
// Check if session already exists
128+
ctx := context.Background()
129+
if _, err := m.storage.Load(ctx, id); err == nil {
120130
return fmt.Errorf("session ID %q already exists", id)
121131
}
122-
return nil
132+
133+
// Create and store new session
134+
session := m.factory(id)
135+
return m.storage.Store(ctx, session)
123136
}
124137

125138
// AddSession adds an existing session to the manager.
@@ -132,62 +145,60 @@ func (m *Manager) AddSession(session Session) error {
132145
return fmt.Errorf("session ID cannot be empty")
133146
}
134147

135-
_, loaded := m.sessions.LoadOrStore(session.ID(), session)
136-
if loaded {
148+
// Check if session already exists
149+
ctx := context.Background()
150+
if _, err := m.storage.Load(ctx, session.ID()); err == nil {
137151
return fmt.Errorf("session ID %q already exists", session.ID())
138152
}
139-
return nil
153+
154+
return m.storage.Store(ctx, session)
140155
}
141156

142157
// Get retrieves a session by ID. Returns (session, true) if found,
143158
// and also updates its UpdatedAt timestamp.
144159
func (m *Manager) Get(id string) (Session, bool) {
145-
v, ok := m.sessions.Load(id)
146-
if !ok {
160+
ctx := context.Background()
161+
sess, err := m.storage.Load(ctx, id)
162+
if err != nil {
147163
return nil, false
148164
}
149-
sess, ok := v.(Session)
150-
if !ok {
151-
return nil, false // Invalid session type
152-
}
153-
154-
sess.Touch()
155165
return sess, true
156166
}
157167

158168
// Delete removes a session by ID.
159169
func (m *Manager) Delete(id string) {
160-
m.sessions.Delete(id)
170+
ctx := context.Background()
171+
_ = m.storage.Delete(ctx, id)
161172
}
162173

163-
// Stop stops the cleanup worker.
174+
// Stop stops the cleanup worker and closes the storage backend.
164175
func (m *Manager) Stop() {
165176
close(m.stopCh)
177+
if m.storage != nil {
178+
_ = m.storage.Close()
179+
}
166180
}
167181

168182
// Range calls f sequentially for each key and value present in the map.
169183
// If f returns false, range stops the iteration.
184+
// Note: This only works with LocalStorage backend.
170185
func (m *Manager) Range(f func(key, value interface{}) bool) {
171-
m.sessions.Range(f)
186+
if localStorage, ok := m.storage.(*LocalStorage); ok {
187+
localStorage.Range(f)
188+
}
172189
}
173190

174191
// Count returns the number of active sessions.
192+
// Note: This only works with LocalStorage backend.
175193
func (m *Manager) Count() int {
176-
count := 0
177-
m.sessions.Range(func(_, _ interface{}) bool {
178-
count++
179-
return true
180-
})
181-
return count
194+
if localStorage, ok := m.storage.(*LocalStorage); ok {
195+
return localStorage.Count()
196+
}
197+
return 0
182198
}
183199

184200
func (m *Manager) cleanupExpiredOnce() {
185201
cutoff := time.Now().Add(-m.ttl)
186-
m.sessions.Range(func(key, val any) bool {
187-
sess := val.(Session)
188-
if sess.UpdatedAt().Before(cutoff) {
189-
m.sessions.Delete(key)
190-
}
191-
return true
192-
})
202+
ctx := context.Background()
203+
_ = m.storage.DeleteExpired(ctx, cutoff)
193204
}
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
package session
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"time"
7+
)
8+
9+
// The following serialization functions are prepared for Phase 4 (Redis/Valkey implementation)
10+
// They are currently unused but will be needed when implementing distributed storage backends.
11+
12+
// sessionData is the JSON representation of a session.
13+
// This structure is used for serializing sessions to/from storage backends.
14+
// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage
15+
type sessionData struct {
16+
ID string `json:"id"`
17+
Type SessionType `json:"type"`
18+
CreatedAt time.Time `json:"created_at"`
19+
UpdatedAt time.Time `json:"updated_at"`
20+
Data json.RawMessage `json:"data,omitempty"`
21+
Metadata map[string]string `json:"metadata,omitempty"`
22+
}
23+
24+
// serializeSession converts a Session to its JSON representation.
25+
// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage
26+
func serializeSession(s Session) ([]byte, error) {
27+
if s == nil {
28+
return nil, fmt.Errorf("cannot serialize nil session")
29+
}
30+
31+
data := sessionData{
32+
ID: s.ID(),
33+
Type: s.Type(),
34+
CreatedAt: s.CreatedAt(),
35+
UpdatedAt: s.UpdatedAt(),
36+
Metadata: s.GetMetadata(),
37+
}
38+
39+
// Handle session-specific data
40+
if sessionData := s.GetData(); sessionData != nil {
41+
jsonData, err := json.Marshal(sessionData)
42+
if err != nil {
43+
return nil, fmt.Errorf("failed to marshal session data: %w", err)
44+
}
45+
data.Data = jsonData
46+
}
47+
48+
return json.Marshal(data)
49+
}
50+
51+
// deserializeSession reconstructs a Session from its JSON representation.
52+
// It creates the appropriate session type based on the Type field.
53+
// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage
54+
func deserializeSession(data []byte) (Session, error) {
55+
if len(data) == 0 {
56+
return nil, fmt.Errorf("cannot deserialize empty data")
57+
}
58+
59+
var sd sessionData
60+
if err := json.Unmarshal(data, &sd); err != nil {
61+
return nil, fmt.Errorf("failed to unmarshal session data: %w", err)
62+
}
63+
64+
// Create appropriate session type using existing constructors
65+
var session Session
66+
switch sd.Type {
67+
case SessionTypeSSE:
68+
// Use existing NewSSESession constructor
69+
sseSession := NewSSESession(sd.ID)
70+
// Update timestamps to match stored values
71+
sseSession.created = sd.CreatedAt
72+
sseSession.updated = sd.UpdatedAt
73+
// Restore metadata
74+
sseSession.metadata = sd.Metadata
75+
if sseSession.metadata == nil {
76+
sseSession.metadata = make(map[string]string)
77+
}
78+
// Note: SSE channels and client info will be recreated when reconnected
79+
session = sseSession
80+
81+
case SessionTypeStreamable:
82+
// Use existing NewStreamableSession constructor
83+
streamSession := NewStreamableSession(sd.ID).(*StreamableSession)
84+
// Update timestamps to match stored values
85+
streamSession.created = sd.CreatedAt
86+
streamSession.updated = sd.UpdatedAt
87+
streamSession.sessType = SessionTypeStreamable
88+
// Restore metadata
89+
streamSession.metadata = sd.Metadata
90+
if streamSession.metadata == nil {
91+
streamSession.metadata = make(map[string]string)
92+
}
93+
session = streamSession
94+
95+
case SessionTypeMCP:
96+
fallthrough
97+
default:
98+
// Use existing NewTypedProxySession constructor
99+
proxySession := NewTypedProxySession(sd.ID, sd.Type)
100+
// Update timestamps to match stored values
101+
proxySession.created = sd.CreatedAt
102+
proxySession.updated = sd.UpdatedAt
103+
// Restore metadata
104+
proxySession.metadata = sd.Metadata
105+
if proxySession.metadata == nil {
106+
proxySession.metadata = make(map[string]string)
107+
}
108+
session = proxySession
109+
}
110+
111+
// Restore session-specific data if present
112+
if len(sd.Data) > 0 {
113+
// For now, we store the raw JSON. Session-specific implementations
114+
// can unmarshal this as needed.
115+
session.SetData(sd.Data)
116+
}
117+
118+
return session, nil
119+
}

0 commit comments

Comments
 (0)