Skip to content

Commit 2f162e0

Browse files
committed
feat(api): implement namespace-level rate limiting with token bucket
Currently, Nginx rate limiting provides global protection but lacks isolation between namespaces. When multiple namespaces make API requests simultaneously, they compete for the same global limit. A token bucket implementation at the application level gives each namespace its own rate limit state. This means heavy traffic in one namespace won't affect others, allowing each to handle its own request patterns independently. While Nginx's leaky bucket drains requests at a fixed rate, Go's token bucket works differently: it stores tokens in a burst buffer, consuming them when new requests arrive and replenishing them at the defined rate. The idea of combining Nginx and a token bucket provides effective multi-layer protection: global traffic management at the edge, while maintaining precise per-namespace control at the application level.
1 parent e96dd7f commit 2f162e0

File tree

7 files changed

+481
-7
lines changed

7 files changed

+481
-7
lines changed

.env

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,3 +167,19 @@ SHELLHUB_API_BURST_SIZE=100
167167
# Defines the delay strategy for handling bursts of incoming requests.
168168
# VALUES: nodelay, or the number of requests to delay.
169169
SHELLHUB_API_BURST_DELAY=nodelay
170+
171+
# Namespace rate limiting configuration
172+
173+
# Enables rate limiting for each namespace.
174+
# VALUES: true, false
175+
SHELLHUB_NAMESPACE_RATE_LIMIT=false
176+
177+
# Defines the rate, in tokens per second, at which tokens are replenished into the bucket for the rate limiter.
178+
# VALUES: Positive integer
179+
SHELLHUB_NAMESPACE_RATE_LIMIT_RATE=1000
180+
181+
# Defines the maximum size of the bucket for the rate limiter.
182+
# Each API request consumes one token from the bucket; tokens are replenished at the rate set by [SHELLHUB_NAMESPACE_RATE_LIMIT_RATE].
183+
# If the bucket is empty, the request is rejected.
184+
# VALUES: Positive integer
185+
SHELLHUB_NAMESPACE_RATE_LIMIT_BURST=1000
Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
package middleware
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
"sync"
8+
"time"
9+
10+
"github.com/labstack/echo/v4"
11+
"github.com/labstack/echo/v4/middleware"
12+
"github.com/shellhub-io/shellhub/api/pkg/gateway"
13+
"github.com/shellhub-io/shellhub/api/services"
14+
"github.com/shellhub-io/shellhub/pkg/models"
15+
log "github.com/sirupsen/logrus"
16+
"golang.org/x/time/rate"
17+
)
18+
19+
// Expirable pairs a value of type D with the instant it was recorded and the
// period after which it must be considered stale.
type Expirable[D any] struct {
	// data is the wrapped value; nil means nothing is stored.
	data *D
	// duration is how long after lastSeen the value stays valid.
	duration time.Duration
	// lastSeen is the reference instant for expiry; nil means no instant was recorded.
	lastSeen *time.Time
}
24+
25+
// IsExpired checks if the structure is expired.
26+
func (n *Expirable[D]) IsExpired() bool {
27+
if n.lastSeen == nil {
28+
return false
29+
}
30+
31+
if n.data == nil {
32+
return false
33+
}
34+
35+
return time.Since(*n.lastSeen) > n.duration
36+
}
37+
38+
func (n *Expirable[D]) Get() *D {
39+
if n.IsExpired() {
40+
return nil
41+
}
42+
43+
return n.data
44+
}
45+
46+
// Defaults applied by the namespace rate limiter when no option overrides them.
const (
	// DefaultNamespaceCacheDuration is the default duration for which a namespace is cached.
	DefaultNamespaceCacheDuration = 30 * time.Minute

	// DefaultNamespaceRateLimit is the default rate at which tokens are replenished into the
	// rate limiter's bucket.
	DefaultNamespaceRateLimit = 1000

	// DefaultNamespaceBurst is the default maximum size of the rate limiter's bucket.
	DefaultNamespaceBurst = 1000
)
54+
55+
func NewNamespaceCached(namespace *models.Namespace, duration time.Duration) *Expirable[models.Namespace] {
56+
if duration <= 0 {
57+
duration = DefaultNamespaceCacheDuration
58+
}
59+
60+
t := time.Now()
61+
62+
return &Expirable[models.Namespace]{
63+
data: namespace,
64+
duration: duration,
65+
lastSeen: &t,
66+
}
67+
}
68+
69+
// NamespaceRateLimitOptions holds the tunable settings of the namespace rate limiter.
type NamespaceRateLimitOptions struct {
	// cacheDuration is how long a cached namespace stays valid.
	cacheDuration time.Duration
	// rate is the number of requests allowed per second.
	rate int
	// burst is the maximum burst size accepted by the rate limiter.
	burst int
}
77+
78+
func DefaultNamespaceRateLimitOptions() *NamespaceRateLimitOptions {
79+
return &NamespaceRateLimitOptions{
80+
cacheDuration: DefaultNamespaceCacheDuration,
81+
}
82+
}
83+
84+
// NamespaceRateLimitOption customizes a NamespaceRateLimitOptions value. It
// receives the current options and returns the options to use from then on.
type NamespaceRateLimitOption func(*NamespaceRateLimitOptions) *NamespaceRateLimitOptions
85+
86+
// NamespaceRateLimitWithCacheDuration sets the duration for which the namespace cache is valid.
87+
func NamespaceRateLimitWithCacheDuration(duration time.Duration) NamespaceRateLimitOption {
88+
return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions {
89+
options.cacheDuration = duration
90+
91+
return options
92+
}
93+
}
94+
95+
// NamespaceRateLimitWithRate sets the rate limit of requests per second for the rate limiter.
96+
func NamespaceRateLimitWithRate(rate int) NamespaceRateLimitOption {
97+
return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions {
98+
options.rate = rate
99+
100+
return options
101+
}
102+
}
103+
104+
// NamespaceRateLimitWithBurst sets the burst size for the rate limiter.
105+
func NamespaceRateLimitWithBurst(burst int) NamespaceRateLimitOption {
106+
return func(options *NamespaceRateLimitOptions) *NamespaceRateLimitOptions {
107+
options.burst = burst
108+
109+
return options
110+
}
111+
}
112+
113+
type NamespaceRateLimit struct {
114+
config *NamespaceRateLimitOptions
115+
116+
mutex *sync.Mutex
117+
mutexts map[string]*sync.Mutex
118+
119+
services services.Service
120+
121+
cached map[string]*Expirable[models.Namespace]
122+
limiters map[string]*rate.Limiter
123+
}
124+
125+
func NewNamespaceRateLimit(svc any, options ...NamespaceRateLimitOption) *NamespaceRateLimit {
126+
s, _ := svc.(services.Service)
127+
128+
config := &NamespaceRateLimitOptions{
129+
cacheDuration: DefaultNamespaceCacheDuration,
130+
rate: DefaultNamespaceRateLimit,
131+
burst: DefaultNamespaceBurst,
132+
}
133+
134+
for _, option := range options {
135+
config = option(config)
136+
}
137+
138+
return &NamespaceRateLimit{
139+
config: config,
140+
141+
mutex: new(sync.Mutex),
142+
mutexts: make(map[string]*sync.Mutex),
143+
144+
services: s,
145+
146+
cached: make(map[string]*Expirable[models.Namespace]),
147+
limiters: make(map[string]*rate.Limiter),
148+
}
149+
}
150+
151+
// getTenantMutex gets or creates a mutex for the given tenant in a thread-safe way
152+
func (l *NamespaceRateLimit) getTenantMutex(tenant string) *sync.Mutex {
153+
l.mutex.Lock()
154+
defer l.mutex.Unlock()
155+
156+
mutex, exists := l.mutexts[tenant]
157+
if !exists {
158+
mutex = &sync.Mutex{}
159+
l.mutexts[tenant] = mutex
160+
}
161+
162+
return mutex
163+
}
164+
165+
func (l *NamespaceRateLimit) Allow(tenant string) (bool, error) {
166+
if l.services == nil {
167+
log.Warn("rate limiter service is not configured - allowing request")
168+
169+
return true, nil
170+
}
171+
172+
if strings.TrimSpace(tenant) == "" {
173+
log.Error("tenant ID cannot be empty")
174+
175+
return false, fmt.Errorf("tenant ID cannot be empty")
176+
}
177+
178+
mu := l.getTenantMutex(tenant)
179+
180+
mu.Lock()
181+
defer mu.Unlock()
182+
183+
cached, exists := l.cached[tenant]
184+
185+
needsRefresh := !exists || (cached != nil && cached.IsExpired())
186+
if needsRefresh {
187+
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
188+
defer cancel()
189+
190+
namespace, err := l.services.GetNamespace(ctx, tenant)
191+
if err != nil {
192+
log.WithFields(log.Fields{
193+
"tenant": tenant,
194+
}).WithError(err).Error("failed to fetch namespace for rate limiter")
195+
196+
return false, fmt.Errorf("failed to fetch namespace: %w", err)
197+
}
198+
199+
if namespace == nil {
200+
return false, fmt.Errorf("namespace not found for tenant: %s", tenant)
201+
}
202+
203+
// TODO: We'll increase or decrease the rate dynamically based on the namespace characteristics in the future.
204+
l.cached[tenant] = NewNamespaceCached(namespace, DefaultNamespaceCacheDuration)
205+
l.limiters[tenant] = rate.NewLimiter(rate.Limit(l.config.rate), l.config.burst)
206+
207+
log.WithFields(log.Fields{
208+
"tenant": tenant,
209+
"namespace": namespace.Name,
210+
}).Debug("namespace cache refreshed for rate limiter")
211+
}
212+
213+
limiter, exists := l.limiters[tenant]
214+
if !exists {
215+
log.WithField("tenant", tenant).Error("rate limiter visitor not found after cache refresh")
216+
217+
return false, fmt.Errorf("rate limiter not configured for tenant: %s", tenant)
218+
}
219+
220+
allowed := limiter.Allow()
221+
222+
log.WithFields(log.Fields{
223+
"tenant": tenant,
224+
"allowed": allowed,
225+
"tokens": limiter.Tokens(),
226+
}).Debug("rate limiter check completed")
227+
228+
return allowed, nil
229+
}
230+
231+
// CleanupExpiredEntries removes expired cache entries (call this periodically)
232+
func (l *NamespaceRateLimit) CleanupExpiredEntries() {
233+
l.mutex.Lock()
234+
defer l.mutex.Unlock()
235+
236+
for tenant, cached := range l.cached {
237+
if cached != nil && cached.IsExpired() {
238+
delete(l.cached, tenant)
239+
delete(l.limiters, tenant)
240+
delete(l.mutexts, tenant)
241+
242+
log.WithField("tenant", tenant).Debug("cleaned up expired rate limiter cache entry")
243+
}
244+
}
245+
}
246+
247+
// SkipperNamespace is a function that checks if the context contains a valid tenant ID.
248+
var SkipperNamespace = func(context echo.Context) bool {
249+
c, ok := context.(*gateway.Context)
250+
if !ok {
251+
log.Error("context is not of type gateway.Context for rate limiting")
252+
253+
return true
254+
}
255+
256+
tenant, ok := c.GetTennat()
257+
if !ok || tenant == "" {
258+
log.Error("tenant ID cannot be empty in request context for rate limiting")
259+
260+
return true
261+
}
262+
263+
return false
264+
}
265+
266+
// NewNamespaceRateLimitMiddleware creates a middleware that limits the rate of requests based on the tenant ID
267+
// extracted from the request context.
268+
func NewNamespaceRateLimitMiddleware(service any, options ...NamespaceRateLimitOption) echo.MiddlewareFunc {
269+
return middleware.RateLimiterWithConfig(middleware.RateLimiterConfig{
270+
Skipper: SkipperNamespace,
271+
IdentifierExtractor: func(context echo.Context) (string, error) {
272+
c, ok := context.(*gateway.Context)
273+
if !ok {
274+
return "", fmt.Errorf("context is not of type gateway.Context")
275+
}
276+
277+
tenant, ok := c.GetTennat()
278+
if !ok || tenant == "" {
279+
log.Error("tenant ID cannot be empty in request context for rate limiting")
280+
281+
return "", fmt.Errorf("tenant ID cannot be empty in request context for rate limiting")
282+
}
283+
284+
return tenant, nil
285+
},
286+
Store: NewNamespaceRateLimit(service, options...),
287+
ErrorHandler: func(c echo.Context, err error) error {
288+
return &echo.HTTPError{
289+
Code: middleware.ErrRateLimitExceeded.Code,
290+
Message: middleware.ErrRateLimitExceeded.Message,
291+
Internal: err,
292+
}
293+
},
294+
DenyHandler: func(c echo.Context, identifier string, err error) error {
295+
return &echo.HTTPError{
296+
Code: middleware.ErrRateLimitExceeded.Code,
297+
Message: middleware.ErrRateLimitExceeded.Message,
298+
Internal: err,
299+
}
300+
},
301+
})
302+
}

0 commit comments

Comments
 (0)