From 2f162e0223b8718d669e3600af4cde9a333e2efd Mon Sep 17 00:00:00 2001
From: Henry Barreto
Date: Fri, 22 Aug 2025 16:06:06 -0300
Subject: [PATCH] feat(api): implement namespace-level rate limiting with
 token bucket

Currently, Nginx rate limiting provides global protection but lacks
isolation between namespaces. When multiple namespaces make API requests
simultaneously, they compete for the same global limit.

A token bucket implementation at the application level gives each
namespace its own rate limit state. This means heavy traffic in one
namespace won't affect others, allowing each to handle its own request
patterns independently.

While Nginx's leaky bucket drains requests at a fixed rate, Go's token
bucket works differently: it stores tokens in a burst buffer, consuming
them when new requests arrive and replenishing them at the defined rate.

Combining Nginx with the token bucket provides effective multi-layer
protection: global traffic management at the edge and precise
per-namespace control at the application level.
---
 .env                                |  16 ++
 api/routes/middleware/rate_limit.go | 302 ++++++++++++++++++++++++++++
 api/routes/nsadm_test.go            |  92 +++++++++
 api/routes/routes.go                |  49 ++++-
 api/routes/session_test.go          |   2 +-
 api/server.go                       |  24 +++
 docker-compose.yml                  |   3 +
 7 files changed, 481 insertions(+), 7 deletions(-)
 create mode 100644 api/routes/middleware/rate_limit.go

diff --git a/.env b/.env
index c4df9f92d14..3d23ee1d9a2 100644
--- a/.env
+++ b/.env
@@ -167,3 +167,19 @@ SHELLHUB_API_BURST_SIZE=100
 # Defines the delay strategy for handling bursts of incoming requests.
 # VALUES: nodelay, or the number of requests to delay.
 SHELLHUB_API_BURST_DELAY=nodelay
+
+# Namespace rate limiting configuration
+
+# Enables rate limiting for each namespace.
+# VALUES: true, false
+SHELLHUB_NAMESPACE_RATE_LIMIT=false
+
+# Defines the rate at which tokens are replenished into the bucket for the rate limiter.
+# VALUES: Positive integer
+SHELLHUB_NAMESPACE_RATE_LIMIT_RATE=1000
+
+# Defines the maximum size of the bucket for the rate limiter.
+# Each API request consumes one token from the bucket; tokens are replenished at the rate set by [SHELLHUB_NAMESPACE_RATE_LIMIT_RATE].
+# If the bucket is empty, the request is rejected.
+# VALUES: Positive integer
+SHELLHUB_NAMESPACE_RATE_LIMIT_BURST=1000
diff --git a/api/routes/middleware/rate_limit.go b/api/routes/middleware/rate_limit.go
new file mode 100644
index 00000000000..2ebe74db2ce
--- /dev/null
+++ b/api/routes/middleware/rate_limit.go
@@ -0,0 +1,302 @@
+package middleware
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"sync"
+	"time"
+
+	"github.com/labstack/echo/v4"
+	"github.com/labstack/echo/v4/middleware"
+	"github.com/shellhub-io/shellhub/api/pkg/gateway"
+	"github.com/shellhub-io/shellhub/api/services"
+	"github.com/shellhub-io/shellhub/pkg/models"
+	log "github.com/sirupsen/logrus"
+	"golang.org/x/time/rate"
+)
+
+type Expirable[D any] struct {
+	data     *D
+	duration time.Duration
+	lastSeen *time.Time
+}
+
+// IsExpired checks if the structure is expired.
+func (n *Expirable[D]) IsExpired() bool {
+	if n.lastSeen == nil {
+		return false
+	}
+
+	if n.data == nil {
+		return false
+	}
+
+	return time.Since(*n.lastSeen) > n.duration
+}
+
+func (n *Expirable[D]) Get() *D {
+	if n.IsExpired() {
+		return nil
+	}
+
+	return n.data
+}
+
+const (
+	// DefaultNamespaceCacheDuration is the default duration for which a namespace is cached.
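+	// Expired entries are refreshed on the next request and removed by CleanupExpiredEntries.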
+	DefaultNamespaceCacheDuration = 30 * time.Minute
+	// DefaultNamespaceRateLimit defines the rate at which tokens are replenished into the bucket for the rate limiter.
+	DefaultNamespaceRateLimit = 1000
+	// DefaultNamespaceBurst defines the maximum size of the bucket for the rate limiter.
+	DefaultNamespaceBurst = 1000
+)
+
+func NewNamespaceCached(namespace *models.Namespace, duration time.Duration) *Expirable[models.Namespace] {
+	if duration <= 0 {
+		duration = DefaultNamespaceCacheDuration
+	}
+
+	t := time.Now()
+
+	return &Expirable[models.Namespace]{
+		data:     namespace,
+		duration: duration,
+		lastSeen: &t,
+	}
+}
+
+type NamespaceRateLimitOptions struct {
+	// cacheDuration specifies how long the namespace cache should be valid.
+	cacheDuration time.Duration
+	// rate specifies how many requests per second are allowed.
+	rate int
+	// burst specifies the maximum burst size for the rate limiter.
+	burst int
+}
+
+func DefaultNamespaceRateLimitOptions() *NamespaceRateLimitOptions {
+	return &NamespaceRateLimitOptions{
+		cacheDuration: DefaultNamespaceCacheDuration,
+	}
+}
+
+type NamespaceRateLimitOption func(*NamespaceRateLimitOptions) *NamespaceRateLimitOptions
+
+// NamespaceRateLimitWithCacheDuration sets the duration for which the namespace cache is valid.
+func NamespaceRateLimitWithCacheDuration(duration time.Duration) NamespaceRateLimitOption {
+	return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions {
+		options.cacheDuration = duration
+
+		return options
+	}
+}
+
+// NamespaceRateLimitWithRate sets the rate limit of requests per second for the rate limiter.
+func NamespaceRateLimitWithRate(rate int) NamespaceRateLimitOption {
+	return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions {
+		options.rate = rate
+
+		return options
+	}
+}
+
+// NamespaceRateLimitWithBurst sets the burst size for the rate limiter.
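+// A larger burst lets a namespace absorb short spikes of requests without raising its sustained rate.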
+func NamespaceRateLimitWithBurst(burst int) NamespaceRateLimitOption { + return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions { + options.burst = burst + + return options + } +} + +type NamespaceRateLimit struct { + config *NamespaceRateLimitOptions + + mutex *sync.Mutex + mutexts map[string]*sync.Mutex + + services services.Service + + cached map[string]*Expirable[models.Namespace] + limiters map[string]*rate.Limiter +} + +func NewNamespaceRateLimit(svc any, options ...NamespaceRateLimitOption) *NamespaceRateLimit { + s, _ := svc.(services.Service) + + config := &NamespaceRateLimitOptions{ + cacheDuration: DefaultNamespaceCacheDuration, + rate: DefaultNamespaceRateLimit, + burst: DefaultNamespaceBurst, + } + + for _, option := range options { + config = option(config) + } + + return &NamespaceRateLimit{ + config: config, + + mutex: new(sync.Mutex), + mutexts: make(map[string]*sync.Mutex), + + services: s, + + cached: make(map[string]*Expirable[models.Namespace]), + limiters: make(map[string]*rate.Limiter), + } +} + +// getTenantMutex gets or creates a mutex for the given tenant in a thread-safe way +func (l *NamespaceRateLimit) getTenantMutex(tenant string) *sync.Mutex { + l.mutex.Lock() + defer l.mutex.Unlock() + + mutex, exists := l.mutexts[tenant] + if !exists { + mutex = &sync.Mutex{} + l.mutexts[tenant] = mutex + } + + return mutex +} + +func (l *NamespaceRateLimit) Allow(tenant string) (bool, error) { + if l.services == nil { + log.Warn("rate limiter service is not configured - allowing request") + + return true, nil + } + + if strings.TrimSpace(tenant) == "" { + log.Error("tenant ID cannot be empty") + + return false, fmt.Errorf("tenant ID cannot be empty") + } + + mu := l.getTenantMutex(tenant) + + mu.Lock() + defer mu.Unlock() + + cached, exists := l.cached[tenant] + + needsRefresh := !exists || (cached != nil && cached.IsExpired()) + if needsRefresh { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + namespace, err := l.services.GetNamespace(ctx, tenant) + if err != nil { + log.WithFields(log.Fields{ + "tenant": tenant, + }).WithError(err).Error("failed to fetch namespace for rate limiter") + + return false, fmt.Errorf("failed to fetch namespace: %w", err) + } + + if namespace == nil { + return false, fmt.Errorf("namespace not found for tenant: %s", tenant) + } + + // TODO: We'll increase or decrease the rate dynamically based on the namespace characteristics in the future. 
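+		// For now, every namespace gets its own token bucket created with the configured rate and burst;
+		// it is reused until the cached namespace entry expires.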
+ l.cached[tenant] = NewNamespaceCached(namespace, DefaultNamespaceCacheDuration) + l.limiters[tenant] = rate.NewLimiter(rate.Limit(l.config.rate), l.config.burst) + + log.WithFields(log.Fields{ + "tenant": tenant, + "namespace": namespace.Name, + }).Debug("namespace cache refreshed for rate limiter") + } + + limiter, exists := l.limiters[tenant] + if !exists { + log.WithField("tenant", tenant).Error("rate limiter visitor not found after cache refresh") + + return false, fmt.Errorf("rate limiter not configured for tenant: %s", tenant) + } + + allowed := limiter.Allow() + + log.WithFields(log.Fields{ + "tenant": tenant, + "allowed": allowed, + "tokens": limiter.Tokens(), + }).Debug("rate limiter check completed") + + return allowed, nil +} + +// CleanupExpiredEntries removes expired cache entries (call this periodically) +func (l *NamespaceRateLimit) CleanupExpiredEntries() { + l.mutex.Lock() + defer l.mutex.Unlock() + + for tenant, cached := range l.cached { + if cached != nil && cached.IsExpired() { + delete(l.cached, tenant) + delete(l.limiters, tenant) + delete(l.mutexts, tenant) + + log.WithField("tenant", tenant).Debug("cleaned up expired rate limiter cache entry") + } + } +} + +// SkipperNamespace is a function that checks if the context contains a valid tenant ID. +var SkipperNamespace = func(context echo.Context) bool { + c, ok := context.(*gateway.Context) + if !ok { + log.Error("context is not of type gateway.Context for rate limiting") + + return true + } + + tenant, ok := c.GetTennat() + if !ok || tenant == "" { + log.Error("tenant ID cannot be empty in request context for rate limiting") + + return true + } + + return false +} + +// NewNamespaceRateLimitMiddleware creates a middleware that limits the rate of requests based on the tenant ID +// extracted from the request context. 
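+// Each identified namespace gets its own token bucket from the NamespaceRateLimit store, so heavy traffic in
+// one namespace does not consume another namespace's budget.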
+func NewNamespaceRateLimitMiddleware(service any, options ...NamespaceRateLimitOption) echo.MiddlewareFunc {
+	return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{
+		Skipper: SkipperNamespace,
+		IdentifierExtractor: func(context echo.Context) (string, error) {
+			c, ok := context.(*gateway.Context)
+			if !ok {
+				return "", fmt.Errorf("context is not of type gateway.Context")
+			}
+
+			tenant, ok := c.GetTennat()
+			if !ok || tenant == "" {
+				log.Error("tenant ID cannot be empty in request context for rate limiting")
+
+				return "", fmt.Errorf("tenant ID cannot be empty in request context for rate limiting")
+			}
+
+			return tenant, nil
+		},
+		Store: NewNamespaceRateLimit(service, options...),
+		ErrorHandler: func(c echo.Context, err error) error {
+			return &echo.HTTPError{
+				Code:     middleware.ErrRateLimitExceeded.Code,
+				Message:  middleware.ErrRateLimitExceeded.Message,
+				Internal: err,
+			}
+		},
+		DenyHandler: func(c echo.Context, identifier string, err error) error {
+			return &echo.HTTPError{
+				Code:     middleware.ErrRateLimitExceeded.Code,
+				Message:  middleware.ErrRateLimitExceeded.Message,
+				Internal: err,
+			}
+		},
+	})
+}
diff --git a/api/routes/nsadm_test.go b/api/routes/nsadm_test.go
index e5884426d3a..4d05e4282d2 100644
--- a/api/routes/nsadm_test.go
+++ b/api/routes/nsadm_test.go
@@ -9,6 +9,7 @@ import (
 	"net/http/httptest"
 	"strings"
 	"testing"
+	"time"
 
 	svc "github.com/shellhub-io/shellhub/api/services"
 	"github.com/shellhub-io/shellhub/api/services/mocks"
@@ -458,3 +459,94 @@ func TestHandler_LeaveNamespace(t *testing.T) {
 
 	svcMock.AssertExpectations(t)
 }
+
+func TestNamespaceRateLimit(t *testing.T) {
+	// NOTE: The "delays" in each test case are designed to test the rate limiting behavior, but it is very
+	// timing-dependent and fragile. If the test starts failing randomly, take into consideration that the timing
+	// may be off due to CPU load or other factors.
+	cases := []struct {
+		description         string
+		rate                int
+		burst               int
+		delays              []time.Duration
+		expectedStatusCodes []int
+	}{
+		{
+			description: "Exceed rate limit then restore",
+			rate:        2,
+			burst:       2,
+			delays:      []time.Duration{0, 0, 0, 600 * time.Millisecond, 0},
+			expectedStatusCodes: []int{
+				http.StatusOK,
+				http.StatusOK,
+				http.StatusTooManyRequests,
+				http.StatusOK,
+				http.StatusTooManyRequests,
+			},
+		},
+		{
+			description: "All requests within limit",
+			rate:        10,
+			burst:       5,
+			delays:      []time.Duration{0, 200 * time.Millisecond, 200 * time.Millisecond, 200 * time.Millisecond, 200 * time.Millisecond},
+			expectedStatusCodes: []int{
+				http.StatusOK,
+				http.StatusOK,
+				http.StatusOK,
+				http.StatusOK,
+				http.StatusOK,
+			},
+		},
+	}
+
+	for _, c := range cases {
+		t.Run(c.description, func(t *testing.T) {
+			svcMock := new(mocks.Service)
+
+			namespace := &models.Namespace{
+				Name:     "rate-limit-test",
+				TenantID: "tenant123",
+				Settings: &models.NamespaceSettings{
+					SessionRecord: true,
+				},
+			}
+
+			svcMock.On("GetNamespace", gomock.Anything, "tenant123").Return(namespace, nil).Maybe()
+
+			opts := []Option{
+				WithNamespaceRateLimit(c.rate, c.burst, time.Minute),
+			}
+			e := NewRouter(svcMock, opts...)
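+			// Only requests that get past the limiter reach the service layer, so the ListNamespaces
+			// expectation below is scoped to the number of expected 200 responses.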
+ + okResponses := 0 + for _, status := range c.expectedStatusCodes { + if status == http.StatusOK { + okResponses++ + } + } + + namespaces := []models.Namespace{*namespace} + svcMock.On("ListNamespaces", gomock.Anything, gomock.AnythingOfType("*requests.NamespaceList")).Return(namespaces, 1, nil).Times(okResponses) + + for i, delay := range c.delays { + if delay > 0 { + time.Sleep(delay) + } + + req := httptest.NewRequest(http.MethodGet, "/api/namespaces", nil) + req.Header.Set("Content-Type", "application/json") + req.Header.Set("X-Role", authorizer.RoleOwner.String()) + req.Header.Set("X-Tenant-ID", "tenant123") + rec := httptest.NewRecorder() + + e.ServeHTTP(rec, req) + + assert.Equal(t, c.expectedStatusCodes[i], rec.Result().StatusCode, + "Request %d should have status %d (rate limit: %d req/s, burst: %d)", + i+1, c.expectedStatusCodes[i], c.rate, c.burst) + } + + svcMock.AssertExpectations(t) + }) + } +} diff --git a/api/routes/routes.go b/api/routes/routes.go index b35c75de349..dba1fb8027a 100644 --- a/api/routes/routes.go +++ b/api/routes/routes.go @@ -2,6 +2,7 @@ package routes import ( "net/http" + "time" "github.com/getsentry/sentry-go" "github.com/labstack/echo/v4" @@ -16,16 +17,24 @@ import ( "github.com/shellhub-io/shellhub/pkg/websocket" ) -type DefaultHTTPHandlerConfig struct { +type HandlerConfig struct { // Reporter represents an instance of [*sentry.Client] that should be proper configured to send error messages // from the error handler. If it's nil, the error handler will ignore the Sentry client. Reporter *sentry.Client + + NamespaceRateLimitCacheDuration time.Duration + // NamespaceRateLimit defines if the rate limiter is enabled for namespaces. + NamespaceRateLimit bool + // NamespaceRateLimitRate is the rate limit of requests per second for a namespace. + NamespaceRateLimitRate int + // NamespaceRateLimitBurst is the burst size for the rate limiter. + NamespaceRateLimitBurst int } // DefaultHTTPHandler creates an HTTP handler, using [github.com/labstack/echo/v4] package, with the default // configuration required by ShellHub's services, loading the [github.com/shellhub-io/shellhub/api/pkg/gateway] into // the context, and the service layer. The configuration received controls the error reporter and more. -func DefaultHTTPHandler[S any](service S, cfg *DefaultHTTPHandlerConfig) http.Handler { +func DefaultHTTPHandler[S any](service S, cfg *HandlerConfig) http.Handler { server := echo.New() // Sets the default binder. 
@@ -54,22 +63,41 @@ func DefaultHTTPHandler[S any](service S, cfg *DefaultHTTPHandlerConfig) http.Ha return server } -type Option func(e *echo.Echo, handler *Handler) error +type Option func(e *echo.Echo, handler *Handler, cfg *HandlerConfig) error func WithReporter(reporter *sentry.Client) Option { - return func(e *echo.Echo, _ *Handler) error { + return func(e *echo.Echo, _ *Handler, _ *HandlerConfig) error { e.HTTPErrorHandler = handlers.NewErrors(reporter) return nil } } +func WithNamespaceRateLimit(rate int, burst int, cacheDuration time.Duration) Option { + return func(e *echo.Echo, _ *Handler, cfg *HandlerConfig) error { + cfg.NamespaceRateLimit = true + cfg.NamespaceRateLimitRate = rate + cfg.NamespaceRateLimitBurst = burst + cfg.NamespaceRateLimitCacheDuration = cacheDuration + + return nil + } +} + func NewRouter(service services.Service, opts ...Option) *echo.Echo { - router := DefaultHTTPHandler(service, new(DefaultHTTPHandlerConfig)).(*echo.Echo) + config := &HandlerConfig{ + Reporter: nil, + NamespaceRateLimit: false, + NamespaceRateLimitRate: 1000, + NamespaceRateLimitBurst: 1000, + NamespaceRateLimitCacheDuration: 30 * time.Minute, + } + + router := DefaultHTTPHandler(service, config).(*echo.Echo) handler := NewHandler(service, websocket.NewGorillaWebSocketUpgrader()) for _, opt := range opts { - if err := opt(router, handler); err != nil { + if err := opt(router, handler, config); err != nil { return nil } } @@ -97,6 +125,15 @@ func NewRouter(service services.Service, opts ...Option) *echo.Echo { publicAPI := router.Group("/api") publicAPI.GET(HealthCheckURL, gateway.Handler(handler.EvaluateHealth)) + if config.NamespaceRateLimit { + publicAPI.Use(routesmiddleware.NewNamespaceRateLimitMiddleware( + service, + routesmiddleware.NamespaceRateLimitWithCacheDuration(config.NamespaceRateLimitCacheDuration), + routesmiddleware.NamespaceRateLimitWithRate(config.NamespaceRateLimitRate), + routesmiddleware.NamespaceRateLimitWithBurst(config.NamespaceRateLimitBurst), + )) + } + publicAPI.GET(AuthLocalUserURLV2, gateway.Handler(handler.CreateUserToken)) // TODO: method POST publicAPI.GET(AuthUserTokenPublicURL, gateway.Handler(handler.CreateUserToken), routesmiddleware.BlockAPIKey) // TODO: method POST publicAPI.POST(AuthDeviceURL, gateway.Handler(handler.AuthDevice)) diff --git a/api/routes/session_test.go b/api/routes/session_test.go index 7b8dd0a9419..43445396693 100644 --- a/api/routes/session_test.go +++ b/api/routes/session_test.go @@ -446,7 +446,7 @@ func TestEventSession(t *testing.T) { req.Header.Set("Sec-WebSocket-Version", "13") req.Header.Set("Sec-WebSocket-Key", "test") - e := NewRouter(mock, func(_ *echo.Echo, handler *Handler) error { + e := NewRouter(mock, func(_ *echo.Echo, handler *Handler, _ *HandlerConfig) error { handler.WebSocketUpgrader = webSocketUpgraderMock return nil diff --git a/api/server.go b/api/server.go index aa6a2804236..89791f34248 100644 --- a/api/server.go +++ b/api/server.go @@ -3,6 +3,7 @@ package main import ( "context" "os" + "time" "github.com/getsentry/sentry-go" "github.com/labstack/echo/v4" @@ -52,6 +53,15 @@ type env struct { // GeoipMaxmindLicense is the MaxMind license key for downloading GeoIP databases directly. // This is used as a fallback when GeoipMirror is not configured. GeoipMaxmindLicense string `env:"MAXMIND_LICENSE,default="` + + // NamespaceRateLimit enables or disables rate limiting for namespace-related operations. 
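+	// Set through SHELLHUB_NAMESPACE_RATE_LIMIT, which docker-compose forwards to the container as NAMESPACE_RATE_LIMIT.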
+ NamespaceRateLimit bool `env:"NAMESPACE_RATE_LIMIT,default=false"` + + // NamespaceRateLimitRate defines the rate at which tokens are replenished into the bucket for the rate limiter. + NamespaceRateLimitRate int `env:"NAMESPACE_RATE_LIMIT_RATE,default=1000"` + + // NamespaceRateLimitBurst defines the maximum size of the bucket for the rate limiter. + NamespaceRateLimitBurst int `env:"NAMESPACE_RATE_LIMIT_BURST,default=1000"` } type Server struct { @@ -191,5 +201,19 @@ func (s *Server) routerOptions() ([]routes.Option, error) { opts = append(opts, routes.WithReporter(reporter)) } + if s.env.NamespaceRateLimit { + log.WithFields(log.Fields{ + "rate": s.env.NamespaceRateLimitRate, + "burst": s.env.NamespaceRateLimitBurst, + }).Info("Configuring namespace rate limiting") + + opts = append(opts, routes.WithNamespaceRateLimit( + s.env.NamespaceRateLimitRate, + s.env.NamespaceRateLimitBurst, + // TODO: Receive this value from the environment variable. + 30*time.Minute, + )) + } + return opts, nil } diff --git a/docker-compose.yml b/docker-compose.yml index 1fed9b20622..4487da82d33 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -51,6 +51,9 @@ services: - ASYNQ_UNIQUENESS_TIMEOUT=${SHELLHUB_ASYNQ_UNIQUENESS_TIMEOUT} - REDIS_CACHE_POOL_SIZE=${SHELLHUB_REDIS_CACHE_POOL_SIZE} - MAXIMUM_ACCOUNT_LOCKOUT=${SHELLHUB_MAXIMUM_ACCOUNT_LOCKOUT} + - NAMESPACE_RATE_LIMIT=${SHELLHUB_NAMESPACE_RATE_LIMIT} + - NAMESPACE_RATE_LIMIT_RATE=${SHELLHUB_NAMESPACE_RATE_LIMIT_RATE} + - NAMESPACE_RATE_LIMIT_BURST=${SHELLHUB_NAMESPACE_RATE_LIMIT_BURST} depends_on: - mongo - redis
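
For reference (not part of the diff): the burst-and-replenish behaviour the commit message describes, and that the
first TestNamespaceRateLimit case exercises, comes straight from golang.org/x/time/rate, which the middleware uses
under the hood. A minimal, self-contained sketch; the numbers are illustrative only:

	package main

	import (
		"fmt"
		"time"

		"golang.org/x/time/rate"
	)

	func main() {
		// 2 tokens per second with a bucket (burst) capacity of 2; the bucket starts full.
		limiter := rate.NewLimiter(rate.Limit(2), 2)

		// The first two requests drain the burst; the third finds an empty bucket and is rejected.
		fmt.Println(limiter.Allow(), limiter.Allow()) // true true
		fmt.Println(limiter.Allow())                  // false

		// Tokens are replenished at the configured rate, so after ~600ms roughly one token is back:
		// a single request passes again, and the next one is rejected.
		time.Sleep(600 * time.Millisecond)
		fmt.Println(limiter.Allow()) // true
		fmt.Println(limiter.Allow()) // false
	}

Unlike Nginx's leaky bucket, nothing is queued here: a request either takes a token and proceeds or is rejected
immediately, which is why the middleware maps an empty bucket to 429 Too Many Requests.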