Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 64 additions & 53 deletions pkg/transport/session/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,32 @@
package session

import (
"context"
"fmt"
"sync"
"time"
)

// Session interface
// Session interface defines the contract for all session types
type Session interface {
ID() string
Type() SessionType
CreatedAt() time.Time
UpdatedAt() time.Time
Touch()

// Data and metadata methods
GetData() interface{}
SetData(data interface{})
GetMetadata() map[string]string
SetMetadata(key, value string)
}

// Manager holds sessions with TTL cleanup.
type Manager struct {
sessions sync.Map
ttl time.Duration
stopCh chan struct{}
factory Factory
storage Storage
ttl time.Duration
stopCh chan struct{}
factory Factory
}

// Factory defines a function type for creating new sessions.
Expand Down Expand Up @@ -56,10 +63,10 @@ func NewManager(ttl time.Duration, factory interface{}) *Manager {
}

m := &Manager{
sessions: sync.Map{},
ttl: ttl,
stopCh: make(chan struct{}),
factory: f,
storage: NewLocalStorage(),
ttl: ttl,
stopCh: make(chan struct{}),
factory: f,
}
go m.cleanupRoutine()
return m
Expand All @@ -83,24 +90,28 @@ func NewTypedManager(ttl time.Duration, sessionType SessionType) *Manager {
return NewManager(ttl, factory)
}

// NewManagerWithStorage creates a session manager with a custom storage backend.
func NewManagerWithStorage(ttl time.Duration, factory Factory, storage Storage) *Manager {
m := &Manager{
storage: storage,
ttl: ttl,
stopCh: make(chan struct{}),
factory: factory,
}
go m.cleanupRoutine()
return m
}

func (m *Manager) cleanupRoutine() {
ticker := time.NewTicker(m.ttl / 2)
defer ticker.Stop()
for {
select {
case <-ticker.C:
cutoff := time.Now().Add(-m.ttl)
m.sessions.Range(func(key, val any) bool {
sess, ok := val.(Session)
if !ok {
// Skip invalid value
return true
}
if sess.UpdatedAt().Before(cutoff) {
m.sessions.Delete(key)
}
return true
})
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
_ = m.storage.DeleteExpired(ctx, cutoff)
cancel()
case <-m.stopCh:
return
}
Expand All @@ -113,13 +124,15 @@ func (m *Manager) AddWithID(id string) error {
if id == "" {
return fmt.Errorf("session ID cannot be empty")
}
// Use LoadOrStore: returns existing if already present
session := m.factory(id)
_, loaded := m.sessions.LoadOrStore(id, session)
if loaded {
// Check if session already exists
ctx := context.Background()
if _, err := m.storage.Load(ctx, id); err == nil {
return fmt.Errorf("session ID %q already exists", id)
}
return nil

// Create and store new session
session := m.factory(id)
return m.storage.Store(ctx, session)
}

// AddSession adds an existing session to the manager.
Expand All @@ -132,62 +145,60 @@ func (m *Manager) AddSession(session Session) error {
return fmt.Errorf("session ID cannot be empty")
}

_, loaded := m.sessions.LoadOrStore(session.ID(), session)
if loaded {
// Check if session already exists
ctx := context.Background()
if _, err := m.storage.Load(ctx, session.ID()); err == nil {
return fmt.Errorf("session ID %q already exists", session.ID())
}
return nil

return m.storage.Store(ctx, session)
}

// Get retrieves a session by ID. Returns (session, true) if found,
// and also updates its UpdatedAt timestamp.
func (m *Manager) Get(id string) (Session, bool) {
v, ok := m.sessions.Load(id)
if !ok {
ctx := context.Background()
sess, err := m.storage.Load(ctx, id)
if err != nil {
return nil, false
}
sess, ok := v.(Session)
if !ok {
return nil, false // Invalid session type
}

sess.Touch()
return sess, true
}

// Delete removes a session by ID.
func (m *Manager) Delete(id string) {
m.sessions.Delete(id)
ctx := context.Background()
_ = m.storage.Delete(ctx, id)
}

// Stop stops the cleanup worker.
// Stop stops the cleanup worker and closes the storage backend.
func (m *Manager) Stop() {
close(m.stopCh)
if m.storage != nil {
_ = m.storage.Close()
}
}

// Range calls f sequentially for each key and value present in the map.
// If f returns false, range stops the iteration.
// Note: This only works with LocalStorage backend.
func (m *Manager) Range(f func(key, value interface{}) bool) {
m.sessions.Range(f)
if localStorage, ok := m.storage.(*LocalStorage); ok {
localStorage.Range(f)
}
}

// Count returns the number of active sessions.
// Note: This only works with LocalStorage backend.
func (m *Manager) Count() int {
count := 0
m.sessions.Range(func(_, _ interface{}) bool {
count++
return true
})
return count
if localStorage, ok := m.storage.(*LocalStorage); ok {
return localStorage.Count()
}
return 0
}

func (m *Manager) cleanupExpiredOnce() {
cutoff := time.Now().Add(-m.ttl)
m.sessions.Range(func(key, val any) bool {
sess := val.(Session)
if sess.UpdatedAt().Before(cutoff) {
m.sessions.Delete(key)
}
return true
})
ctx := context.Background()
_ = m.storage.DeleteExpired(ctx, cutoff)
}
21 changes: 21 additions & 0 deletions pkg/transport/session/proxy_session.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,24 @@ func (s *ProxySession) DeleteMetadata(key string) {
defer s.mu.Unlock()
delete(s.metadata, key)
}

// setTimestamps updates the created and updated timestamps.
// This is used internally for deserialization to restore session state.
func (s *ProxySession) setTimestamps(created, updated time.Time) {
s.mu.Lock()
defer s.mu.Unlock()
s.created = created
s.updated = updated
}

// setMetadataMap replaces the entire metadata map.
// This is used internally for deserialization to restore session state.
func (s *ProxySession) setMetadataMap(metadata map[string]string) {
s.mu.Lock()
defer s.mu.Unlock()
if metadata == nil {
s.metadata = make(map[string]string)
} else {
s.metadata = metadata
}
}
110 changes: 110 additions & 0 deletions pkg/transport/session/serialization.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
package session

import (
"encoding/json"
"fmt"
"time"
)

// The following serialization functions are prepared for Phase 4 (Redis/Valkey implementation)
// They are currently unused but will be needed when implementing distributed storage backends.

// sessionData is the JSON representation of a session.
// This structure is used for serializing sessions to/from storage backends.
// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage
type sessionData struct {
ID string `json:"id"`
Type SessionType `json:"type"`
CreatedAt time.Time `json:"created_at"`
UpdatedAt time.Time `json:"updated_at"`
Data json.RawMessage `json:"data,omitempty"`
Metadata map[string]string `json:"metadata,omitempty"`
}

// serializeSession converts a Session to its JSON representation.
// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage
func serializeSession(s Session) ([]byte, error) {
if s == nil {
return nil, fmt.Errorf("cannot serialize nil session")
}

data := sessionData{
ID: s.ID(),
Type: s.Type(),
CreatedAt: s.CreatedAt(),
UpdatedAt: s.UpdatedAt(),
Metadata: s.GetMetadata(),
}

// Handle session-specific data
if sessionData := s.GetData(); sessionData != nil {
jsonData, err := json.Marshal(sessionData)
if err != nil {
return nil, fmt.Errorf("failed to marshal session data: %w", err)
}
data.Data = jsonData
}

return json.Marshal(data)
}

// deserializeSession reconstructs a Session from its JSON representation.
// It creates the appropriate session type based on the Type field.
// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage
func deserializeSession(data []byte) (Session, error) {
if len(data) == 0 {
return nil, fmt.Errorf("cannot deserialize empty data")
}

var sd sessionData
if err := json.Unmarshal(data, &sd); err != nil {
return nil, fmt.Errorf("failed to unmarshal session data: %w", err)
}

// Create appropriate session type using existing constructors
var session Session
switch sd.Type {
case SessionTypeSSE:
// Use existing NewSSESession constructor
sseSession := NewSSESession(sd.ID)
// Update timestamps to match stored values
sseSession.setTimestamps(sd.CreatedAt, sd.UpdatedAt)
// Restore metadata
sseSession.setMetadataMap(sd.Metadata)
// Note: SSE channels and client info will be recreated when reconnected
session = sseSession

case SessionTypeStreamable:
// Use existing NewStreamableSession constructor
sess := NewStreamableSession(sd.ID)
streamSession, ok := sess.(*StreamableSession)
if !ok {
return nil, fmt.Errorf("failed to create StreamableSession")
}
// Update timestamps to match stored values
streamSession.setTimestamps(sd.CreatedAt, sd.UpdatedAt)
// Restore metadata
streamSession.setMetadataMap(sd.Metadata)
session = streamSession

case SessionTypeMCP:
fallthrough
default:
// Use existing NewTypedProxySession constructor
proxySession := NewTypedProxySession(sd.ID, sd.Type)
// Update timestamps to match stored values
proxySession.setTimestamps(sd.CreatedAt, sd.UpdatedAt)
// Restore metadata
proxySession.setMetadataMap(sd.Metadata)
session = proxySession
}

// Restore session-specific data if present
if len(sd.Data) > 0 {
// For now, we store the raw JSON. Session-specific implementations
// can unmarshal this as needed.
session.SetData(sd.Data)
}

return session, nil
}
Loading
Loading