diff --git a/pkg/transport/proxy/httpsse/http_proxy.go b/pkg/transport/proxy/httpsse/http_proxy.go index 0ccaa1963..f5eef0122 100644 --- a/pkg/transport/proxy/httpsse/http_proxy.go +++ b/pkg/transport/proxy/httpsse/http_proxy.go @@ -206,7 +206,9 @@ func (p *HTTPSSEProxy) Stop(ctx context.Context) error { // Stop the session manager cleanup routine if p.sessionManager != nil { - p.sessionManager.Stop() + if err := p.sessionManager.Stop(); err != nil { + logger.Errorf("Failed to stop session manager: %v", err) + } } // Disconnect all active sessions @@ -466,7 +468,9 @@ func (p *HTTPSSEProxy) removeClient(clientID string) { } // Remove the session from the manager - p.sessionManager.Delete(clientID) + if err := p.sessionManager.Delete(clientID); err != nil { + logger.Debugf("Failed to delete session %s: %v", clientID, err) + } // Clean up closed clients map periodically (prevent memory leak) p.closedClientsMutex.Lock() diff --git a/pkg/transport/proxy/streamable/streamable_proxy.go b/pkg/transport/proxy/streamable/streamable_proxy.go index ee4eb0cb1..1daf4355f 100644 --- a/pkg/transport/proxy/streamable/streamable_proxy.go +++ b/pkg/transport/proxy/streamable/streamable_proxy.go @@ -117,7 +117,9 @@ func (p *HTTPProxy) Stop(ctx context.Context) error { // Stop session manager cleanup and disconnect sessions if p.sessionManager != nil { - p.sessionManager.Stop() + if err := p.sessionManager.Stop(); err != nil { + logger.Errorf("Failed to stop session manager: %v", err) + } p.sessionManager.Range(func(_, value interface{}) bool { if ss, ok := value.(*session.StreamableSession); ok { ss.Disconnect() @@ -202,7 +204,9 @@ func (p *HTTPProxy) handleDelete(w http.ResponseWriter, r *http.Request) { writeHTTPError(w, http.StatusNotFound, "session not found") return } - p.sessionManager.Delete(sessID) + if err := p.sessionManager.Delete(sessID); err != nil { + logger.Debugf("Failed to delete session %s: %v", sessID, err) + } w.WriteHeader(http.StatusNoContent) } diff --git a/pkg/transport/session/manager.go b/pkg/transport/session/manager.go index 66ecc8971..d2e38f5b6 100644 --- a/pkg/transport/session/manager.go +++ b/pkg/transport/session/manager.go @@ -2,25 +2,34 @@ package session import ( + "context" "fmt" - "sync" "time" + + "github.com/stacklok/toolhive/pkg/logger" ) -// Session interface +// Session interface defines the contract for all session types type Session interface { ID() string + Type() SessionType CreatedAt() time.Time UpdatedAt() time.Time Touch() + + // Data and metadata methods + GetData() interface{} + SetData(data interface{}) + GetMetadata() map[string]string + SetMetadata(key, value string) } // Manager holds sessions with TTL cleanup. type Manager struct { - sessions sync.Map - ttl time.Duration - stopCh chan struct{} - factory Factory + storage Storage + ttl time.Duration + stopCh chan struct{} + factory Factory } // Factory defines a function type for creating new sessions. @@ -56,10 +65,10 @@ func NewManager(ttl time.Duration, factory interface{}) *Manager { } m := &Manager{ - sessions: sync.Map{}, - ttl: ttl, - stopCh: make(chan struct{}), - factory: f, + storage: NewLocalStorage(), + ttl: ttl, + stopCh: make(chan struct{}), + factory: f, } go m.cleanupRoutine() return m @@ -83,6 +92,18 @@ func NewTypedManager(ttl time.Duration, sessionType SessionType) *Manager { return NewManager(ttl, factory) } +// NewManagerWithStorage creates a session manager with a custom storage backend. +func NewManagerWithStorage(ttl time.Duration, factory Factory, storage Storage) *Manager { + m := &Manager{ + storage: storage, + ttl: ttl, + stopCh: make(chan struct{}), + factory: factory, + } + go m.cleanupRoutine() + return m +} + func (m *Manager) cleanupRoutine() { ticker := time.NewTicker(m.ttl / 2) defer ticker.Stop() @@ -90,17 +111,11 @@ func (m *Manager) cleanupRoutine() { select { case <-ticker.C: cutoff := time.Now().Add(-m.ttl) - m.sessions.Range(func(key, val any) bool { - sess, ok := val.(Session) - if !ok { - // Skip invalid value - return true - } - if sess.UpdatedAt().Before(cutoff) { - m.sessions.Delete(key) - } - return true - }) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + if err := m.storage.DeleteExpired(ctx, cutoff); err != nil { + logger.Errorf("Failed to delete expired sessions: %v", err) + } + cancel() case <-m.stopCh: return } @@ -113,13 +128,17 @@ func (m *Manager) AddWithID(id string) error { if id == "" { return fmt.Errorf("session ID cannot be empty") } - // Use LoadOrStore: returns existing if already present - session := m.factory(id) - _, loaded := m.sessions.LoadOrStore(id, session) - if loaded { + // Check if session already exists + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if _, err := m.storage.Load(ctx, id); err == nil { return fmt.Errorf("session ID %q already exists", id) } - return nil + + // Create and store new session + session := m.factory(id) + return m.storage.Store(ctx, session) } // AddSession adds an existing session to the manager. @@ -132,62 +151,85 @@ func (m *Manager) AddSession(session Session) error { return fmt.Errorf("session ID cannot be empty") } - _, loaded := m.sessions.LoadOrStore(session.ID(), session) - if loaded { + // Check if session already exists + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + if _, err := m.storage.Load(ctx, session.ID()); err == nil { return fmt.Errorf("session ID %q already exists", session.ID()) } - return nil + + return m.storage.Store(ctx, session) } // Get retrieves a session by ID. Returns (session, true) if found, // and also updates its UpdatedAt timestamp. func (m *Manager) Get(id string) (Session, bool) { - v, ok := m.sessions.Load(id) - if !ok { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + sess, err := m.storage.Load(ctx, id) + if err != nil { return nil, false } - sess, ok := v.(Session) - if !ok { - return nil, false // Invalid session type - } - + // Touch the session to update its timestamp sess.Touch() return sess, true } // Delete removes a session by ID. -func (m *Manager) Delete(id string) { - m.sessions.Delete(id) +// Returns an error if the deletion fails. +func (m *Manager) Delete(id string) error { + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + return m.storage.Delete(ctx, id) } -// Stop stops the cleanup worker. -func (m *Manager) Stop() { +// Stop stops the cleanup worker and closes the storage backend. +// Returns an error if closing the storage backend fails. +func (m *Manager) Stop() error { close(m.stopCh) + if m.storage != nil { + return m.storage.Close() + } + return nil } // Range calls f sequentially for each key and value present in the map. // If f returns false, range stops the iteration. +// +// Note: This method only works with LocalStorage backend. It will silently +// do nothing with other storage backends. Range is not part of the Storage +// interface because it's not feasible for distributed storage backends like +// Redis where iterating all keys can be prohibitively expensive or impractical. +// +// For distributed storage, consider using more targeted queries or maintaining +// a separate index of session IDs. func (m *Manager) Range(f func(key, value interface{}) bool) { - m.sessions.Range(f) + if localStorage, ok := m.storage.(*LocalStorage); ok { + localStorage.Range(f) + } } // Count returns the number of active sessions. +// +// Note: This method only works with LocalStorage backend and returns 0 for +// other storage backends. Count is not part of the Storage interface because +// it's not feasible for distributed storage backends like Redis where counting +// all keys can be prohibitively expensive. +// +// For distributed storage, consider maintaining a counter or using approximate +// count mechanisms provided by the storage backend. func (m *Manager) Count() int { - count := 0 - m.sessions.Range(func(_, _ interface{}) bool { - count++ - return true - }) - return count + if localStorage, ok := m.storage.(*LocalStorage); ok { + return localStorage.Count() + } + return 0 } -func (m *Manager) cleanupExpiredOnce() { +func (m *Manager) cleanupExpiredOnce() error { cutoff := time.Now().Add(-m.ttl) - m.sessions.Range(func(key, val any) bool { - sess := val.(Session) - if sess.UpdatedAt().Before(cutoff) { - m.sessions.Delete(key) - } - return true - }) + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + return m.storage.DeleteExpired(ctx, cutoff) } diff --git a/pkg/transport/session/proxy_session.go b/pkg/transport/session/proxy_session.go index 57a86bd86..7b91d7221 100644 --- a/pkg/transport/session/proxy_session.go +++ b/pkg/transport/session/proxy_session.go @@ -134,3 +134,24 @@ func (s *ProxySession) DeleteMetadata(key string) { defer s.mu.Unlock() delete(s.metadata, key) } + +// setTimestamps updates the created and updated timestamps. +// This is used internally for deserialization to restore session state. +func (s *ProxySession) setTimestamps(created, updated time.Time) { + s.mu.Lock() + defer s.mu.Unlock() + s.created = created + s.updated = updated +} + +// setMetadataMap replaces the entire metadata map. +// This is used internally for deserialization to restore session state. +func (s *ProxySession) setMetadataMap(metadata map[string]string) { + s.mu.Lock() + defer s.mu.Unlock() + if metadata == nil { + s.metadata = make(map[string]string) + } else { + s.metadata = metadata + } +} diff --git a/pkg/transport/session/serialization.go b/pkg/transport/session/serialization.go new file mode 100644 index 000000000..dd214a570 --- /dev/null +++ b/pkg/transport/session/serialization.go @@ -0,0 +1,110 @@ +package session + +import ( + "encoding/json" + "fmt" + "time" +) + +// The following serialization functions are prepared for Phase 4 (Redis/Valkey implementation) +// They are currently unused but will be needed when implementing distributed storage backends. + +// sessionData is the JSON representation of a session. +// This structure is used for serializing sessions to/from storage backends. +// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage +type sessionData struct { + ID string `json:"id"` + Type SessionType `json:"type"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` + Data json.RawMessage `json:"data,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// serializeSession converts a Session to its JSON representation. +// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage +func serializeSession(s Session) ([]byte, error) { + if s == nil { + return nil, fmt.Errorf("cannot serialize nil session") + } + + data := sessionData{ + ID: s.ID(), + Type: s.Type(), + CreatedAt: s.CreatedAt(), + UpdatedAt: s.UpdatedAt(), + Metadata: s.GetMetadata(), + } + + // Handle session-specific data + if sessionData := s.GetData(); sessionData != nil { + jsonData, err := json.Marshal(sessionData) + if err != nil { + return nil, fmt.Errorf("failed to marshal session data: %w", err) + } + data.Data = jsonData + } + + return json.Marshal(data) +} + +// deserializeSession reconstructs a Session from its JSON representation. +// It creates the appropriate session type based on the Type field. +// nolint:unused // Will be used in Phase 4 for Redis/Valkey storage +func deserializeSession(data []byte) (Session, error) { + if len(data) == 0 { + return nil, fmt.Errorf("cannot deserialize empty data") + } + + var sd sessionData + if err := json.Unmarshal(data, &sd); err != nil { + return nil, fmt.Errorf("failed to unmarshal session data: %w", err) + } + + // Create appropriate session type using existing constructors + var session Session + switch sd.Type { + case SessionTypeSSE: + // Use existing NewSSESession constructor + sseSession := NewSSESession(sd.ID) + // Update timestamps to match stored values + sseSession.setTimestamps(sd.CreatedAt, sd.UpdatedAt) + // Restore metadata + sseSession.setMetadataMap(sd.Metadata) + // Note: SSE channels and client info will be recreated when reconnected + session = sseSession + + case SessionTypeStreamable: + // Use existing NewStreamableSession constructor + sess := NewStreamableSession(sd.ID) + streamSession, ok := sess.(*StreamableSession) + if !ok { + return nil, fmt.Errorf("failed to create StreamableSession") + } + // Update timestamps to match stored values + streamSession.setTimestamps(sd.CreatedAt, sd.UpdatedAt) + // Restore metadata + streamSession.setMetadataMap(sd.Metadata) + session = streamSession + + case SessionTypeMCP: + fallthrough + default: + // Use existing NewTypedProxySession constructor + proxySession := NewTypedProxySession(sd.ID, sd.Type) + // Update timestamps to match stored values + proxySession.setTimestamps(sd.CreatedAt, sd.UpdatedAt) + // Restore metadata + proxySession.setMetadataMap(sd.Metadata) + session = proxySession + } + + // Restore session-specific data if present + if len(sd.Data) > 0 { + // For now, we store the raw JSON. Session-specific implementations + // can unmarshal this as needed. + session.SetData(sd.Data) + } + + return session, nil +} diff --git a/pkg/transport/session/serialization_test.go b/pkg/transport/session/serialization_test.go new file mode 100644 index 000000000..9dc0946ca --- /dev/null +++ b/pkg/transport/session/serialization_test.go @@ -0,0 +1,232 @@ +package session + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestSerialization tests the serialization and deserialization functions +func TestSerialization(t *testing.T) { + t.Parallel() + t.Run("Serialize and Deserialize ProxySession", func(t *testing.T) { + t.Parallel() + // Create a session with metadata + original := NewProxySession("test-proxy-1") + original.SetMetadata("key1", "value1") + original.SetMetadata("key2", "value2") + original.SetData(map[string]interface{}{"custom": "data"}) + + // Serialize + data, err := serializeSession(original) + require.NoError(t, err) + assert.NotEmpty(t, data) + + // Verify JSON structure + var jsonData map[string]interface{} + err = json.Unmarshal(data, &jsonData) + require.NoError(t, err) + assert.Equal(t, "test-proxy-1", jsonData["id"]) + assert.Equal(t, string(SessionTypeMCP), jsonData["type"]) + + // Deserialize + restored, err := deserializeSession(data) + require.NoError(t, err) + assert.NotNil(t, restored) + + // Verify restored session + assert.Equal(t, original.ID(), restored.ID()) + assert.Equal(t, original.Type(), restored.Type()) + + // Check metadata + metadata := restored.GetMetadata() + assert.Equal(t, "value1", metadata["key1"]) + assert.Equal(t, "value2", metadata["key2"]) + }) + + t.Run("Serialize and Deserialize SSESession", func(t *testing.T) { + t.Parallel() + // Create an SSE session + original := NewSSESession("test-sse-1") + original.SetMetadata("client", "browser") + original.SetMetadata("version", "1.0") + + // Serialize + data, err := serializeSession(original) + require.NoError(t, err) + assert.NotEmpty(t, data) + + // Deserialize + restored, err := deserializeSession(data) + require.NoError(t, err) + assert.NotNil(t, restored) + + // Verify it's an SSE session + assert.Equal(t, SessionTypeSSE, restored.Type()) + assert.Equal(t, "test-sse-1", restored.ID()) + + // Check it's the right type + sseSession, ok := restored.(*SSESession) + assert.True(t, ok) + assert.NotNil(t, sseSession.MessageCh) + + // Check metadata + metadata := restored.GetMetadata() + assert.Equal(t, "browser", metadata["client"]) + assert.Equal(t, "1.0", metadata["version"]) + }) + + t.Run("Serialize and Deserialize StreamableSession", func(t *testing.T) { + t.Parallel() + // Create a streamable session + original := NewStreamableSession("test-stream-1") + original.SetMetadata("protocol", "http") + + // Serialize + data, err := serializeSession(original) + require.NoError(t, err) + assert.NotEmpty(t, data) + + // Deserialize + restored, err := deserializeSession(data) + require.NoError(t, err) + assert.NotNil(t, restored) + + // Verify it's a streamable session + assert.Equal(t, SessionTypeStreamable, restored.Type()) + assert.Equal(t, "test-stream-1", restored.ID()) + + // Check metadata + metadata := restored.GetMetadata() + assert.Equal(t, "http", metadata["protocol"]) + }) + + t.Run("Serialize nil session", func(t *testing.T) { + t.Parallel() + data, err := serializeSession(nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nil session") + assert.Nil(t, data) + }) + + t.Run("Deserialize empty data", func(t *testing.T) { + t.Parallel() + session, err := deserializeSession([]byte{}) + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty data") + assert.Nil(t, session) + }) + + t.Run("Deserialize invalid JSON", func(t *testing.T) { + t.Parallel() + session, err := deserializeSession([]byte("not json")) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unmarshal") + assert.Nil(t, session) + }) + + t.Run("Preserve timestamps", func(t *testing.T) { + t.Parallel() + // Create a session with specific timestamps + original := NewProxySession("test-time-1") + createdAt := original.CreatedAt() + + // Wait a bit and touch to update the timestamp + time.Sleep(10 * time.Millisecond) + original.Touch() + updatedAt := original.UpdatedAt() + + // Serialize + data, err := serializeSession(original) + require.NoError(t, err) + + // Deserialize + restored, err := deserializeSession(data) + require.NoError(t, err) + + // Timestamps should be preserved + assert.Equal(t, createdAt.Unix(), restored.CreatedAt().Unix()) + assert.Equal(t, updatedAt.Unix(), restored.UpdatedAt().Unix()) + }) + + t.Run("Handle session with no metadata", func(t *testing.T) { + t.Parallel() + // Create a session without metadata + original := NewProxySession("test-no-meta") + + // Serialize + data, err := serializeSession(original) + require.NoError(t, err) + + // Deserialize + restored, err := deserializeSession(data) + require.NoError(t, err) + + // Metadata should be empty but not nil + metadata := restored.GetMetadata() + assert.NotNil(t, metadata) + assert.Len(t, metadata, 0) + }) + + t.Run("Handle complex data in session", func(t *testing.T) { + t.Parallel() + // Create a session with complex data + original := NewProxySession("test-complex") + complexData := map[string]interface{}{ + "string": "value", + "number": 42, + "bool": true, + "nested": map[string]interface{}{ + "key": "value", + }, + } + original.SetData(complexData) + + // Serialize + data, err := serializeSession(original) + require.NoError(t, err) + + // Deserialize + restored, err := deserializeSession(data) + require.NoError(t, err) + + // Data should be preserved as JSON + restoredData := restored.GetData() + assert.NotNil(t, restoredData) + + // The data will be stored as json.RawMessage + if rawData, ok := restoredData.(json.RawMessage); ok { + var parsed map[string]interface{} + err = json.Unmarshal(rawData, &parsed) + require.NoError(t, err) + assert.Equal(t, "value", parsed["string"]) + assert.Equal(t, float64(42), parsed["number"]) // JSON numbers are float64 + assert.Equal(t, true, parsed["bool"]) + } + }) + + t.Run("Unknown session type defaults to ProxySession", func(t *testing.T) { + t.Parallel() + // Create JSON with unknown session type + jsonData := `{ + "id": "unknown-1", + "type": "unknown", + "created_at": "2024-01-01T00:00:00Z", + "updated_at": "2024-01-01T00:00:00Z" + }` + + // Deserialize + restored, err := deserializeSession([]byte(jsonData)) + require.NoError(t, err) + assert.NotNil(t, restored) + + // Should be a ProxySession with the unknown type + assert.Equal(t, SessionType("unknown"), restored.Type()) + proxySession, ok := restored.(*ProxySession) + assert.True(t, ok) + assert.NotNil(t, proxySession) + }) +} diff --git a/pkg/transport/session/storage.go b/pkg/transport/session/storage.go new file mode 100644 index 000000000..661f6697a --- /dev/null +++ b/pkg/transport/session/storage.go @@ -0,0 +1,34 @@ +// Package session provides session management with pluggable storage backends. +package session + +import ( + "context" + "time" +) + +// Storage defines the minimal interface for session storage backends. +// This interface is designed to be simple and efficient, supporting both +// local in-memory storage and distributed storage backends like Redis/Valkey. +type Storage interface { + // Store creates or updates a session in the storage backend. + // If the session already exists, it will be overwritten. + Store(ctx context.Context, session Session) error + + // Load retrieves a session by ID from the storage backend. + // Returns ErrSessionNotFound if the session doesn't exist. + // Note: This does not automatically touch the session. Callers should + // explicitly call Touch() on the returned session if they want to update its timestamp. + Load(ctx context.Context, id string) (Session, error) + + // Delete removes a session from the storage backend. + // It is not an error if the session doesn't exist. + Delete(ctx context.Context, id string) error + + // DeleteExpired removes all sessions that haven't been updated since the given time. + // This is used by the cleanup routine to remove stale sessions. + DeleteExpired(ctx context.Context, before time.Time) error + + // Close performs cleanup of the storage backend. + // For local storage, this clears all sessions. For remote storage, it closes connections. + Close() error +} diff --git a/pkg/transport/session/storage_local.go b/pkg/transport/session/storage_local.go new file mode 100644 index 000000000..601c58c6d --- /dev/null +++ b/pkg/transport/session/storage_local.go @@ -0,0 +1,125 @@ +package session + +import ( + "context" + "fmt" + "sync" + "time" +) + +// LocalStorage implements the Storage interface using an in-memory sync.Map. +// This is the default storage backend for single-instance deployments. +type LocalStorage struct { + sessions sync.Map +} + +// NewLocalStorage creates a new local in-memory storage backend. +func NewLocalStorage() *LocalStorage { + return &LocalStorage{} +} + +// Store saves a session to the local storage. +// For local storage, we store the session object directly without serialization. +func (s *LocalStorage) Store(_ context.Context, session Session) error { + if session == nil { + return fmt.Errorf("cannot store nil session") + } + if session.ID() == "" { + return fmt.Errorf("cannot store session with empty ID") + } + + s.sessions.Store(session.ID(), session) + return nil +} + +// Load retrieves a session from local storage. +func (s *LocalStorage) Load(_ context.Context, id string) (Session, error) { + if id == "" { + return nil, fmt.Errorf("cannot load session with empty ID") + } + + val, ok := s.sessions.Load(id) + if !ok { + return nil, ErrSessionNotFound + } + + session, ok := val.(Session) + if !ok { + return nil, fmt.Errorf("invalid session type in storage") + } + + return session, nil +} + +// Delete removes a session from local storage. +func (s *LocalStorage) Delete(_ context.Context, id string) error { + if id == "" { + return fmt.Errorf("cannot delete session with empty ID") + } + + s.sessions.Delete(id) + return nil +} + +// DeleteExpired removes all sessions that haven't been updated since the given time. +func (s *LocalStorage) DeleteExpired(ctx context.Context, before time.Time) error { + var toDelete []string + + // First pass: collect IDs of expired sessions + s.sessions.Range(func(key, val any) bool { + // Check for context cancellation + select { + case <-ctx.Done(): + return false + default: + } + + if session, ok := val.(Session); ok { + if session.UpdatedAt().Before(before) { + if id, ok := key.(string); ok { + toDelete = append(toDelete, id) + } + } + } + return true + }) + + // Second pass: delete expired sessions + for _, id := range toDelete { + s.sessions.Delete(id) + } + + return nil +} + +// Close clears all sessions from local storage. +func (s *LocalStorage) Close() error { + // Collect keys first to avoid modifying map during iteration + var toDelete []any + s.sessions.Range(func(key, _ any) bool { + toDelete = append(toDelete, key) + return true + }) + // Clear all sessions + for _, key := range toDelete { + s.sessions.Delete(key) + } + return nil +} + +// Count returns the number of sessions in storage. +// This is a helper method not part of the Storage interface. +func (s *LocalStorage) Count() int { + count := 0 + s.sessions.Range(func(_, _ interface{}) bool { + count++ + return true + }) + return count +} + +// Range iterates over all sessions in storage. +// This is a helper method not part of the Storage interface. +func (s *LocalStorage) Range(f func(key, value interface{}) bool) { + s.sessions.Range(f) +} diff --git a/pkg/transport/session/storage_test.go b/pkg/transport/session/storage_test.go new file mode 100644 index 000000000..ee33f1836 --- /dev/null +++ b/pkg/transport/session/storage_test.go @@ -0,0 +1,506 @@ +package session + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// TestLocalStorage tests the LocalStorage implementation +func TestLocalStorage(t *testing.T) { + t.Parallel() + t.Run("Store and Load", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + // Create a test session + session := NewProxySession("test-id-1") + session.SetMetadata("key1", "value1") + + // Store the session + ctx := context.Background() + err := storage.Store(ctx, session) + require.NoError(t, err) + + // Load the session + loaded, err := storage.Load(ctx, "test-id-1") + require.NoError(t, err) + assert.NotNil(t, loaded) + assert.Equal(t, "test-id-1", loaded.ID()) + assert.Equal(t, SessionTypeMCP, loaded.Type()) + + // Check metadata was preserved + metadata := loaded.GetMetadata() + assert.Equal(t, "value1", metadata["key1"]) + }) + + t.Run("Store nil session", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + ctx := context.Background() + err := storage.Store(ctx, nil) + assert.Error(t, err) + assert.Contains(t, err.Error(), "nil session") + }) + + t.Run("Store session with empty ID", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + session := &ProxySession{} // Empty ID + ctx := context.Background() + err := storage.Store(ctx, session) + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty ID") + }) + + t.Run("Load non-existent session", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + ctx := context.Background() + loaded, err := storage.Load(ctx, "non-existent") + assert.Error(t, err) + assert.Equal(t, ErrSessionNotFound, err) + assert.Nil(t, loaded) + }) + + t.Run("Load with empty ID", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + ctx := context.Background() + loaded, err := storage.Load(ctx, "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty ID") + assert.Nil(t, loaded) + }) + + t.Run("Delete session", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + // Store a session + session := NewProxySession("test-id-2") + ctx := context.Background() + err := storage.Store(ctx, session) + require.NoError(t, err) + + // Verify it exists + loaded, err := storage.Load(ctx, "test-id-2") + require.NoError(t, err) + assert.NotNil(t, loaded) + + // Delete it + err = storage.Delete(ctx, "test-id-2") + require.NoError(t, err) + + // Verify it's gone + loaded, err = storage.Load(ctx, "test-id-2") + assert.Error(t, err) + assert.Equal(t, ErrSessionNotFound, err) + assert.Nil(t, loaded) + }) + + t.Run("Delete non-existent session", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + ctx := context.Background() + // Should not error when deleting non-existent session + err := storage.Delete(ctx, "non-existent") + assert.NoError(t, err) + }) + + t.Run("Delete with empty ID", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + ctx := context.Background() + err := storage.Delete(ctx, "") + assert.Error(t, err) + assert.Contains(t, err.Error(), "empty ID") + }) + + t.Run("DeleteExpired", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + ctx := context.Background() + + // Create sessions with different update times + oldSession := NewProxySession("old-session") + newSession := NewProxySession("new-session") + + // Store both sessions + err := storage.Store(ctx, oldSession) + require.NoError(t, err) + err = storage.Store(ctx, newSession) + require.NoError(t, err) + + // Manually set the old session's updated time to the past + oldSession.updated = time.Now().Add(-2 * time.Hour) + + // Store the old session again with the old timestamp + err = storage.Store(ctx, oldSession) + require.NoError(t, err) + + // Delete sessions older than 1 hour + cutoff := time.Now().Add(-1 * time.Hour) + err = storage.DeleteExpired(ctx, cutoff) + require.NoError(t, err) + + // Old session should be gone + _, err = storage.Load(ctx, "old-session") + assert.Equal(t, ErrSessionNotFound, err) + + // New session should still exist + loaded, err := storage.Load(ctx, "new-session") + assert.NoError(t, err) + assert.NotNil(t, loaded) + }) + + t.Run("Load does not auto-touch", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + // Create and store a session + session := NewProxySession("test-id-3") + originalUpdated := session.UpdatedAt() + + ctx := context.Background() + err := storage.Store(ctx, session) + require.NoError(t, err) + + // Wait a bit to ensure time difference + time.Sleep(10 * time.Millisecond) + + // Load the session (should NOT auto-touch) + loaded, err := storage.Load(ctx, "test-id-3") + require.NoError(t, err) + + // Updated time should be the same (not auto-touched) + assert.Equal(t, originalUpdated, loaded.UpdatedAt()) + + // But manual Touch should update the time + loaded.Touch() + assert.True(t, loaded.UpdatedAt().After(originalUpdated)) + }) + + t.Run("Count helper method", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + ctx := context.Background() + + // Initially empty + assert.Equal(t, 0, storage.Count()) + + // Add sessions + for i := 0; i < 5; i++ { + session := NewProxySession(fmt.Sprintf("session-%d", i)) + err := storage.Store(ctx, session) + require.NoError(t, err) + } + + // Should have 5 sessions + assert.Equal(t, 5, storage.Count()) + + // Delete one + err := storage.Delete(ctx, "session-0") + require.NoError(t, err) + + // Should have 4 sessions + assert.Equal(t, 4, storage.Count()) + }) + + t.Run("Range helper method", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + ctx := context.Background() + + // Add some sessions + ids := []string{"alpha", "beta", "gamma"} + for _, id := range ids { + session := NewProxySession(id) + err := storage.Store(ctx, session) + require.NoError(t, err) + } + + // Use Range to collect all IDs + var collected []string + storage.Range(func(key, _ interface{}) bool { + if id, ok := key.(string); ok { + collected = append(collected, id) + } + return true + }) + + // Should have all IDs + assert.Len(t, collected, 3) + for _, id := range ids { + assert.Contains(t, collected, id) + } + }) + + t.Run("Close clears all sessions", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + + ctx := context.Background() + + // Add some sessions + for i := 0; i < 3; i++ { + session := NewProxySession(fmt.Sprintf("session-%d", i)) + err := storage.Store(ctx, session) + require.NoError(t, err) + } + + // Should have sessions + assert.Equal(t, 3, storage.Count()) + + // Close storage + err := storage.Close() + require.NoError(t, err) + + // Should have no sessions + assert.Equal(t, 0, storage.Count()) + }) + + t.Run("Context cancellation in DeleteExpired", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + // Create a cancelled context + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // DeleteExpired should handle cancelled context gracefully + err := storage.DeleteExpired(ctx, time.Now()) + // Should not error, just stop early + assert.NoError(t, err) + }) +} + +// TestManagerWithStorage tests the Manager with the Storage interface +func TestManagerWithStorage(t *testing.T) { + t.Parallel() + t.Run("Manager with LocalStorage", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + factory := func(id string) Session { + return NewProxySession(id) + } + + manager := NewManagerWithStorage(30*time.Minute, factory, storage) + defer manager.Stop() + + // Add a session + err := manager.AddWithID("test-session-1") + require.NoError(t, err) + + // Get the session + session, found := manager.Get("test-session-1") + assert.True(t, found) + assert.NotNil(t, session) + assert.Equal(t, "test-session-1", session.ID()) + + // Delete the session + manager.Delete("test-session-1") + + // Should not be found + session, found = manager.Get("test-session-1") + assert.False(t, found) + assert.Nil(t, session) + }) + + t.Run("Manager with custom factory", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + factory := func(id string) Session { + // Create SSE sessions by default + return NewSSESession(id) + } + + manager := NewManagerWithStorage(30*time.Minute, factory, storage) + defer manager.Stop() + + // Add a session + err := manager.AddWithID("sse-session-1") + require.NoError(t, err) + + // Get the session + session, found := manager.Get("sse-session-1") + assert.True(t, found) + assert.NotNil(t, session) + assert.Equal(t, SessionTypeSSE, session.Type()) + }) + + t.Run("Manager AddSession method", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + factory := func(id string) Session { + return NewProxySession(id) + } + + manager := NewManagerWithStorage(30*time.Minute, factory, storage) + defer manager.Stop() + + // Create a custom session + customSession := NewTypedProxySession("custom-1", SessionTypeStreamable) + customSession.SetMetadata("custom", "metadata") + + // Add the custom session + err := manager.AddSession(customSession) + require.NoError(t, err) + + // Get the session + session, found := manager.Get("custom-1") + assert.True(t, found) + assert.NotNil(t, session) + assert.Equal(t, SessionTypeStreamable, session.Type()) + + metadata := session.GetMetadata() + assert.Equal(t, "metadata", metadata["custom"]) + }) + + t.Run("Manager Count with LocalStorage", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + factory := func(id string) Session { + return NewProxySession(id) + } + + manager := NewManagerWithStorage(30*time.Minute, factory, storage) + defer manager.Stop() + + // Initially empty + assert.Equal(t, 0, manager.Count()) + + // Add sessions + for i := 0; i < 3; i++ { + err := manager.AddWithID(fmt.Sprintf("session-%d", i)) + require.NoError(t, err) + } + + // Should have 3 sessions + assert.Equal(t, 3, manager.Count()) + }) + + t.Run("Manager Range with LocalStorage", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + factory := func(id string) Session { + return NewProxySession(id) + } + + manager := NewManagerWithStorage(30*time.Minute, factory, storage) + defer manager.Stop() + + // Add sessions + ids := []string{"one", "two", "three"} + for _, id := range ids { + err := manager.AddWithID(id) + require.NoError(t, err) + } + + // Use Range to collect all IDs + var collected []string + manager.Range(func(key, _ interface{}) bool { + if id, ok := key.(string); ok { + collected = append(collected, id) + } + return true + }) + + // Should have all IDs + assert.Len(t, collected, 3) + for _, id := range ids { + assert.Contains(t, collected, id) + } + }) +} + +// TestSessionTypes tests different session type implementations +func TestSessionTypes(t *testing.T) { + t.Parallel() + t.Run("ProxySession with Storage", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + session := NewProxySession("proxy-1") + session.SetMetadata("env", "production") + session.SetData(map[string]string{"key": "value"}) + + ctx := context.Background() + err := storage.Store(ctx, session) + require.NoError(t, err) + + loaded, err := storage.Load(ctx, "proxy-1") + require.NoError(t, err) + assert.Equal(t, SessionTypeMCP, loaded.Type()) + + metadata := loaded.GetMetadata() + assert.Equal(t, "production", metadata["env"]) + }) + + t.Run("SSESession with Storage", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + session := NewSSESession("sse-1") + session.SetMetadata("client", "browser") + + ctx := context.Background() + err := storage.Store(ctx, session) + require.NoError(t, err) + + loaded, err := storage.Load(ctx, "sse-1") + require.NoError(t, err) + assert.Equal(t, SessionTypeSSE, loaded.Type()) + + metadata := loaded.GetMetadata() + assert.Equal(t, "browser", metadata["client"]) + }) + + t.Run("StreamableSession with Storage", func(t *testing.T) { + t.Parallel() + storage := NewLocalStorage() + defer storage.Close() + + session := NewStreamableSession("stream-1") + session.SetMetadata("protocol", "http") + + ctx := context.Background() + err := storage.Store(ctx, session) + require.NoError(t, err) + + loaded, err := storage.Load(ctx, "stream-1") + require.NoError(t, err) + assert.Equal(t, SessionTypeStreamable, loaded.Type()) + + metadata := loaded.GetMetadata() + assert.Equal(t, "http", metadata["protocol"]) + }) +} diff --git a/pkg/transport/session/streamable_session.go b/pkg/transport/session/streamable_session.go index 5418c500a..56b611160 100644 --- a/pkg/transport/session/streamable_session.go +++ b/pkg/transport/session/streamable_session.go @@ -17,7 +17,7 @@ type StreamableSession struct { // NewStreamableSession constructs a new streamable session with buffered channels func NewStreamableSession(id string) Session { return &StreamableSession{ - ProxySession: &ProxySession{id: id}, + ProxySession: NewTypedProxySession(id, SessionTypeStreamable), MessageCh: make(chan jsonrpc2.Message, 100), ResponseCh: make(chan jsonrpc2.Message, 100), }