diff --git a/.env b/.env index c4df9f92d14..3d23ee1d9a2 100644 --- a/.env +++ b/.env @@ -167,3 +167,19 @@ SHELLHUB_API_BURST_SIZE=100 # Defines the delay strategy for handling bursts of incoming requests. # VALUES: nodelay, or the number of requests to delay. SHELLHUB_API_BURST_DELAY=nodelay + +# Namespace rate limiting configuration + +# Enables rate limiting for each namespace. +# VALUES: true, false +SHELLHUB_NAMESPACE_RATE_LIMIT=false + +# Defines the rate at which tokens are replenished into the bucket for the rate limiter. +# VALUES: Positive integer +SHELLHUB_NAMESPACE_RATE_LIMIT_RATE=1000 + +# Defines the maximum size of the bucket for the rate limiter. +# Each API request consumes one token from the bucket, that are replenished at the [SHELLHUB_NAMESPACE_RATE_LIMIT_RATE]. +# If the bucket is empty, the request is rejected. +# VALUES: Positive integer +SHELLHUB_NAMESPACE_RATE_LIMIT_BURST=1000 diff --git a/api/routes/middleware/rate_limit.go b/api/routes/middleware/rate_limit.go new file mode 100644 index 00000000000..2ebe74db2ce --- /dev/null +++ b/api/routes/middleware/rate_limit.go @@ -0,0 +1,302 @@ +package middleware + +import ( + "context" + "fmt" + "strings" + "sync" + "time" + + "github.com/labstack/echo/v4" + "github.com/labstack/echo/v4/middleware" + "github.com/shellhub-io/shellhub/api/pkg/gateway" + "github.com/shellhub-io/shellhub/api/services" + "github.com/shellhub-io/shellhub/pkg/models" + log "github.com/sirupsen/logrus" + "golang.org/x/time/rate" +) + +type Expirable[D any] struct { + data *D + duration time.Duration + lastSeen *time.Time +} + +// IsExpired checks if the structure is expired. +func (n *Expirable[D]) IsExpired() bool { + if n.lastSeen == nil { + return false + } + + if n.data == nil { + return false + } + + return time.Since(*n.lastSeen) > n.duration +} + +func (n *Expirable[D]) Get() *D { + if n.IsExpired() { + return nil + } + + return n.data +} + +const ( + // DefaultNamespaceCacheDuration is the default duration for which a namespace is cached. + DefaultNamespaceCacheDuration = 30 * time.Minute + // DefaultNamespaceRateLimit defines the rate at which tokens are replenished into the bucket for the rate limiter. + DefaultNamespaceRateLimit = 1000 + // DefaultNamespaceBurst defines the maximun size of the bucket for the rate limiter. + DefaultNamespaceBurst = 1000 +) + +func NewNamespaceCached(namespace *models.Namespace, duration time.Duration) *Expirable[models.Namespace] { + if duration <= 0 { + duration = DefaultNamespaceCacheDuration + } + + t := time.Now() + + return &Expirable[models.Namespace]{ + data: namespace, + duration: duration, + lastSeen: &t, + } +} + +type NamespaceRateLimitOptions struct { + // cacheDuration specifies how long the namespace cache should be valid. + cacheDuration time.Duration + // rate specify how many requests per second are allowed. + rate int + // burst specifies the maximum burst size for the rate limiter. + burst int +} + +func DefaultNamespaceRateLimitOptions() *NamespaceRateLimitOptions { + return &NamespaceRateLimitOptions{ + cacheDuration: DefaultNamespaceCacheDuration, + } +} + +type NamespaceRateLimitOption func(*NamespaceRateLimitOptions) *NamespaceRateLimitOptions + +// NamespaceRateLimitWithCacheDuration sets the duration for which the namespace cache is valid. +func NamespaceRateLimitWithCacheDuration(duration time.Duration) NamespaceRateLimitOption { + return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions { + options.cacheDuration = duration + + return options + } +} + +// NamespaceRateLimitWithRate sets the rate limit of requests per second for the rate limiter. +func NamespaceRateLimitWithRate(rate int) NamespaceRateLimitOption { + return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions { + options.rate = rate + + return options + } +} + +// NamespaceRateLimitWithBurst sets the burst size for the rate limiter. +func NamespaceRateLimitWithBurst(burst int) NamespaceRateLimitOption { + return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions { + options.burst = burst + + return options + } +} + +type NamespaceRateLimit struct { + config *NamespaceRateLimitOptions + + mutex *sync.Mutex + mutexts map[string]*sync.Mutex + + services services.Service + + cached map[string]*Expirable[models.Namespace] + limiters map[string]*rate.Limiter +} + +func NewNamespaceRateLimit(svc any, options ...NamespaceRateLimitOption) *NamespaceRateLimit { + s, _ := svc.(services.Service) + + config := &NamespaceRateLimitOptions{ + cacheDuration: DefaultNamespaceCacheDuration, + rate: DefaultNamespaceRateLimit, + burst: DefaultNamespaceBurst, + } + + for _, option := range options { + config = option(config) + } + + return &NamespaceRateLimit{ + config: config, + + mutex: new(sync.Mutex), + mutexts: make(map[string]*sync.Mutex), + + services: s, + + cached: make(map[string]*Expirable[models.Namespace]), + limiters: make(map[string]*rate.Limiter), + } +} + +// getTenantMutex gets or creates a mutex for the given tenant in a thread-safe way +func (l *NamespaceRateLimit) getTenantMutex(tenant string) *sync.Mutex { + l.mutex.Lock() + defer l.mutex.Unlock() + + mutex, exists := l.mutexts[tenant] + if !exists { + mutex = &sync.Mutex{} + l.mutexts[tenant] = mutex + } + + return mutex +} + +func (l *NamespaceRateLimit) Allow(tenant string) (bool, error) { + if l.services == nil { + log.Warn("rate limiter service is not configured - allowing request") + + return true, nil + } + + if strings.TrimSpace(tenant) == "" { + log.Error("tenant ID cannot be empty") + + return false, fmt.Errorf("tenant ID cannot be empty") + } + + mu := l.getTenantMutex(tenant) + + mu.Lock() + defer mu.Unlock() + + cached, exists := l.cached[tenant] + + needsRefresh := !exists || (cached != nil && cached.IsExpired()) + if needsRefresh { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + namespace, err := l.services.GetNamespace(ctx, tenant) + if err != nil { + log.WithFields(log.Fields{ + "tenant": tenant, + }).WithError(err).Error("failed to fetch namespace for rate limiter") + + return false, fmt.Errorf("failed to fetch namespace: %w", err) + } + + if namespace == nil { + return false, fmt.Errorf("namespace not found for tenant: %s", tenant) + } + + // TODO: We'll increase or decrease the rate dynamically based on the namespace characteristics in the future. + l.cached[tenant] = NewNamespaceCached(namespace, DefaultNamespaceCacheDuration) + l.limiters[tenant] = rate.NewLimiter(rate.Limit(l.config.rate), l.config.burst) + + log.WithFields(log.Fields{ + "tenant": tenant, + "namespace": namespace.Name, + }).Debug("namespace cache refreshed for rate limiter") + } + + limiter, exists := l.limiters[tenant] + if !exists { + log.WithField("tenant", tenant).Error("rate limiter visitor not found after cache refresh") + + return false, fmt.Errorf("rate limiter not configured for tenant: %s", tenant) + } + + allowed := limiter.Allow() + + log.WithFields(log.Fields{ + "tenant": tenant, + "allowed": allowed, + "tokens": limiter.Tokens(), + }).Debug("rate limiter check completed") + + return allowed, nil +} + +// CleanupExpiredEntries removes expired cache entries (call this periodically) +func (l *NamespaceRateLimit) CleanupExpiredEntries() { + l.mutex.Lock() + defer l.mutex.Unlock() + + for tenant, cached := range l.cached { + if cached != nil && cached.IsExpired() { + delete(l.cached, tenant) + delete(l.limiters, tenant) + delete(l.mutexts, tenant) + + log.WithField("tenant", tenant).Debug("cleaned up expired rate limiter cache entry") + } + } +} + +// SkipperNamespace is a function that checks if the context contains a valid tenant ID. +var SkipperNamespace = func(context echo.Context) bool { + c, ok := context.(*gateway.Context) + if !ok { + log.Error("context is not of type gateway.Context for rate limiting") + + return true + } + + tenant, ok := c.GetTennat() + if !ok || tenant == "" { + log.Error("tenant ID cannot be empty in request context for rate limiting") + + return true + } + + return false +} + +// NewNamespaceRateLimitMiddleware creates a middleware that limits the rate of requests based on the tenant ID +// extracted from the request context. +func NewNamespaceRateLimitMiddleware(service any, options ...NamespaceRateLimitOption) echo.MiddlewareFunc { + return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{ + Skipper: SkipperNamespace, + IdentifierExtractor: func(context echo.Context) (string, error) { + c, ok := context.(*gateway.Context) + if !ok { + return "", fmt.Errorf("context is not of type gateway.Context") + } + + tenant, ok := c.GetTennat() + if !ok || tenant == "" { + log.Error("tenant ID cannot be empty in request context for rate limiting") + + return "", fmt.Errorf("tenant ID cannot be empty in request context for rate limiting") + } + + return tenant, nil + }, + Store: NewNamespaceRateLimit(service, options...), + ErrorHandler: func(c echo.Context, err error) error { + return &echo.HTTPError{ + Code: middleware.ErrRateLimitExceeded.Code, + Message: middleware.ErrRateLimitExceeded.Message, + Internal: err, + } + }, + DenyHandler: func(c echo.Context, identifier string, err error) error { + return &echo.HTTPError{ + Code: middleware.ErrRateLimitExceeded.Code, + Message: middleware.ErrRateLimitExceeded.Message, + Internal: err, + } + }, + }) +} diff --git a/api/routes/nsadm_test.go b/api/routes/nsadm_test.go index e5884426d3a..4d05e4282d2 100644 --- a/api/routes/nsadm_test.go +++ b/api/routes/nsadm_test.go @@ -9,6 +9,7 @@ import ( "net/http/httptest" "strings" "testing" + "time" svc "github.com/shellhub-io/shellhub/api/services" "github.com/shellhub-io/shellhub/api/services/mocks" @@ -458,3 +459,94 @@ func TestHandler_LeaveNamespace(t *testing.T) { svcMock.AssertExpectations(t) } + +func TestNamespaceRateLimit(t *testing.T) { + // NOTE: The "delays" in each test case are designed to test the rate limiting behavior, but is it very + // timing-dependent and fragile. If the test starts failing randomly, take into consideration that the timing may be + // off due to CPU load or other factors. + cases := []struct { + description string + rate int + burst int + delays []time.Duration + expectedStatusCodes []int + }{ + { + description: "Exceed rate limit then restore", + rate: 2, + burst: 2, + delays: []time.Duration{0, 0, 0, 600 * time.Millisecond, 0}, + expectedStatusCodes: []int{ + http.StatusOK, + http.StatusOK, + http.StatusTooManyRequests, + http.StatusOK, + http.StatusTooManyRequests, + }, + }, + { + description: "All requests within limit", + rate: 10, + burst: 5, + delays: []time.Duration{0, 200 * time.Millisecond, 200 * time.Millisecond, 200 * time.Millisecond, 200 * time.Millisecond}, + expectedStatusCodes: []int{ + http.StatusOK, + http.StatusOK, + http.StatusOK, + http.StatusOK, + http.StatusOK, + }, + }, + } + + for _, c := range cases { + t.Run(c.description, func(t *testing.T) { + svcMock := new(mocks.Service) + + namespace := &models.Namespace{ + Name: "rate-limit-test", + TenantID: "tenant123", + Settings: &models.NamespaceSettings{ + SessionRecord: true, + }, + } + + svcMock.On("GetNamespace", gomock.Anything, "tenant123").Return(namespace, nil).Maybe() + + opts := []Option{ + WithNamespaceRateLimit(c.rate, c.burst, time.Minute), + } + e := NewRouter(svcMock, opts...) + + okResponses := 0 + for _, status := range c.expectedStatusCodes { + if status == http.StatusOK { + okResponses++ + } + } + + namespaces := []models.Namespace{*namespace} + svcMock.On("ListNamespaces", gomock.Anything, gomock.AnythingOfType("*requests.NamespaceList")).Return(namespaces, 1, nil).Times(okResponses) + + for i, delay := range c.delays { + if delay > 0 { + time.Sleep(delay) + } + + req := httptest.NewRequest(http.MethodGet, "/api/namespaces", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Role", authorizer.RoleOwner.String()) + req.Header.Set("X-Tenant-ID", "tenant123") + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, c.expectedStatusCodes[i], rec.Result().StatusCode, + "Request %d should have status %d (rate limit: %d req/s, burst: %d)", + i+1, c.expectedStatusCodes[i], c.rate, c.burst) + } + + svcMock.AssertExpectations(t) + }) + } +} diff --git a/api/routes/routes.go b/api/routes/routes.go index b35c75de349..dba1fb8027a 100644 --- a/api/routes/routes.go +++ b/api/routes/routes.go @@ -2,6 +2,7 @@ package routes import ( "net/http" + "time" "github.com/getsentry/sentry-go" "github.com/labstack/echo/v4" @@ -16,16 +17,24 @@ import ( "github.com/shellhub-io/shellhub/pkg/websocket" ) -type DefaultHTTPHandlerConfig struct { +type HandlerConfig struct { // Reporter represents an instance of [*sentry.Client] that should be proper configured to send error messages // from the error handler. If it's nil, the error handler will ignore the Sentry client. Reporter *sentry.Client + + NamespaceRateLimitCacheDuration time.Duration + // NamespaceRateLimit defines if the rate limiter is enabled for namespaces. + NamespaceRateLimit bool + // NamespaceRateLimitRate is the rate limit of requests per second for a namespace. + NamespaceRateLimitRate int + // NamespaceRateLimitBurst is the burst size for the rate limiter. + NamespaceRateLimitBurst int } // DefaultHTTPHandler creates an HTTP handler, using [github.com/labstack/echo/v4] package, with the default // configuration required by ShellHub's services, loading the [github.com/shellhub-io/shellhub/api/pkg/gateway] into // the context, and the service layer. The configuration received controls the error reporter and more. -func DefaultHTTPHandler[S any](service S, cfg *DefaultHTTPHandlerConfig) http.Handler { +func DefaultHTTPHandler[S any](service S, cfg *HandlerConfig) http.Handler { server := echo.New() // Sets the default binder. @@ -54,22 +63,41 @@ func DefaultHTTPHandler[S any](service S, cfg *DefaultHTTPHandlerConfig) http.Ha return server } -type Option func(e *echo.Echo, handler *Handler) error +type Option func(e *echo.Echo, handler *Handler, cfg *HandlerConfig) error func WithReporter(reporter *sentry.Client) Option { - return func(e *echo.Echo, _ *Handler) error { + return func(e *echo.Echo, _ *Handler, _ *HandlerConfig) error { e.HTTPErrorHandler = handlers.NewErrors(reporter) return nil } } +func WithNamespaceRateLimit(rate int, burst int, cacheDuration time.Duration) Option { + return func(e *echo.Echo, _ *Handler, cfg *HandlerConfig) error { + cfg.NamespaceRateLimit = true + cfg.NamespaceRateLimitRate = rate + cfg.NamespaceRateLimitBurst = burst + cfg.NamespaceRateLimitCacheDuration = cacheDuration + + return nil + } +} + func NewRouter(service services.Service, opts ...Option) *echo.Echo { - router := DefaultHTTPHandler(service, new(DefaultHTTPHandlerConfig)).(*echo.Echo) + config := &HandlerConfig{ + Reporter: nil, + NamespaceRateLimit: false, + NamespaceRateLimitRate: 1000, + NamespaceRateLimitBurst: 1000, + NamespaceRateLimitCacheDuration: 30 * time.Minute, + } + + router := DefaultHTTPHandler(service, config).(*echo.Echo) handler := NewHandler(service, websocket.NewGorillaWebSocketUpgrader()) for _, opt := range opts { - if err := opt(router, handler); err != nil { + if err := opt(router, handler, config); err != nil { return nil } } @@ -97,6 +125,15 @@ func NewRouter(service services.Service, opts ...Option) *echo.Echo { publicAPI := router.Group("/api") publicAPI.GET(HealthCheckURL, gateway.Handler(handler.EvaluateHealth)) + if config.NamespaceRateLimit { + publicAPI.Use(routesmiddleware.NewNamespaceRateLimitMiddleware( + service, + routesmiddleware.NamespaceRateLimitWithCacheDuration(config.NamespaceRateLimitCacheDuration), + routesmiddleware.NamespaceRateLimitWithRate(config.NamespaceRateLimitRate), + routesmiddleware.NamespaceRateLimitWithBurst(config.NamespaceRateLimitBurst), + )) + } + publicAPI.GET(AuthLocalUserURLV2, gateway.Handler(handler.CreateUserToken)) // TODO: method POST publicAPI.GET(AuthUserTokenPublicURL, gateway.Handler(handler.CreateUserToken), routesmiddleware.BlockAPIKey) // TODO: method POST publicAPI.POST(AuthDeviceURL, gateway.Handler(handler.AuthDevice)) diff --git a/api/routes/session_test.go b/api/routes/session_test.go index 7b8dd0a9419..43445396693 100644 --- a/api/routes/session_test.go +++ b/api/routes/session_test.go @@ -446,7 +446,7 @@ func TestEventSession(t *testing.T) { req.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Key", "test") - e := NewRouter(mock, func(_ *echo.Echo, handler *Handler) error { + e := NewRouter(mock, func(_ *echo.Echo, handler *Handler, _ *HandlerConfig) error { handler.WebSocketUpgrader = webSocketUpgraderMock return nil diff --git a/api/server.go b/api/server.go index aa6a2804236..89791f34248 100644 --- a/api/server.go +++ b/api/server.go @@ -3,6 +3,7 @@ package main import ( "context" "os" + "time" "github.com/getsentry/sentry-go" "github.com/labstack/echo/v4" @@ -52,6 +53,15 @@ type env struct { // GeoipMaxmindLicense is the MaxMind license key for downloading GeoIP databases directly. // This is used as a fallback when GeoipMirror is not configured. GeoipMaxmindLicense string `env:"MAXMIND_LICENSE,default="` + + // NamespaceRateLimit enables or disables rate limiting for namespace-related operations. + NamespaceRateLimit bool `env:"NAMESPACE_RATE_LIMIT,default=false"` + + // NamespaceRateLimitRate defines the rate at which tokens are replenished into the bucket for the rate limiter. + NamespaceRateLimitRate int `env:"NAMESPACE_RATE_LIMIT_RATE,default=1000"` + + // NamespaceRateLimitBurst defines the maximum size of the bucket for the rate limiter. + NamespaceRateLimitBurst int `env:"NAMESPACE_RATE_LIMIT_BURST,default=1000"` } type Server struct { @@ -191,5 +201,19 @@ func (s *Server) routerOptions() ([]routes.Option, error) { opts = append(opts, routes.WithReporter(reporter)) } + if s.env.NamespaceRateLimit { + log.WithFields(log.Fields{ + "rate": s.env.NamespaceRateLimitRate, + "burst": s.env.NamespaceRateLimitBurst, + }).Info("Configuring namespace rate limiting") + + opts = append(opts, routes.WithNamespaceRateLimit( + s.env.NamespaceRateLimitRate, + s.env.NamespaceRateLimitBurst, + // TODO: Receive this value from the environment variable. + 30*time.Minute, + )) + } + return opts, nil } diff --git a/docker-compose.yml b/docker-compose.yml index 1fed9b20622..4487da82d33 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -51,6 +51,9 @@ services: - ASYNQ_UNIQUENESS_TIMEOUT=${SHELLHUB_ASYNQ_UNIQUENESS_TIMEOUT} - REDIS_CACHE_POOL_SIZE=${SHELLHUB_REDIS_CACHE_POOL_SIZE} - MAXIMUM_ACCOUNT_LOCKOUT=${SHELLHUB_MAXIMUM_ACCOUNT_LOCKOUT} + - NAMESPACE_RATE_LIMIT=${SHELLHUB_NAMESPACE_RATE_LIMIT} + - NAMESPACE_RATE_LIMIT_RATE=${SHELLHUB_NAMESPACE_RATE_LIMIT_RATE} + - NAMESPACE_RATE_LIMIT_BURST=${SHELLHUB_NAMESPACE_RATE_LIMIT_BURST} depends_on: - mongo - redis