Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions cmd/dev/seed/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ func seedLocal(ctx context.Context, cmd *cli.Command) error {
Region: "local",
UsageLimiter: nil,
KeyCache: nil,
QuotaCache: nil,
})
if err != nil {
return fmt.Errorf("failed to create key service: %w", err)
Expand Down Expand Up @@ -217,13 +218,17 @@ func seedLocal(ctx context.Context, cmd *cli.Command) error {
AuditLogsRetentionDays: 30,
LogsRetentionDays: 7,
Team: false,
RatelimitApiLimit: sql.NullInt32{}, //nolint:exhaustruct
RatelimitApiDuration: sql.NullInt32{}, //nolint:exhaustruct
},
{
WorkspaceID: rootWorkspaceID,
RequestsPerMonth: 150000,
AuditLogsRetentionDays: 30,
LogsRetentionDays: 7,
Team: false,
RatelimitApiLimit: sql.NullInt32{}, //nolint:exhaustruct
RatelimitApiDuration: sql.NullInt32{}, //nolint:exhaustruct
},
})
if err != nil {
Expand Down
1 change: 1 addition & 0 deletions cmd/dev/seed/verifications.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ func seedVerifications(ctx context.Context, cmd *cli.Command) error {
Region: "test",
UsageLimiter: nil,
KeyCache: nil,
QuotaCache: nil,
})
if err != nil {
return fmt.Errorf("failed to create key service: %w", err)
Expand Down
19 changes: 19 additions & 0 deletions internal/services/caches/caches.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ type Caches struct {
// Keys are string (api_id) and values are db.FindKeyAuthsByIdsRow (has both KeyAuthID and ApiID).
ApiToKeyAuthRow cache.Cache[cache.ScopedKey, db.FindKeyAuthsByIdsRow]

// WorkspaceQuota caches workspace quota lookups by workspace ID.
// Keys are string (workspace ID) and values are db.Quotas.
WorkspaceQuota cache.Cache[string, db.Quotas]

// dispatcher handles routing of invalidation events to all caches in this process.
// This is not exported as it's an internal implementation detail.
dispatcher *clustering.InvalidationDispatcher
Expand Down Expand Up @@ -259,6 +263,20 @@ func New(config Config) (Caches, error) {
return Caches{}, err
}

workspaceQuota, err := createCache(
cache.Config[string, db.Quotas]{
Fresh: time.Minute,
Stale: 24 * time.Hour,
MaxSize: 100_000,
Resource: "workspace_quota",
Clock: config.Clock,
},
stringKeyOpts,
)
if err != nil {
return Caches{}, err
}

initialized = true
return Caches{
RatelimitNamespace: middleware.WithTracing(ratelimitNamespace),
Expand All @@ -267,6 +285,7 @@ func New(config Config) (Caches, error) {
ClickhouseSetting: middleware.WithTracing(clickhouseSetting),
KeyAuthToApiRow: middleware.WithTracing(keyAuthToApiRow),
ApiToKeyAuthRow: middleware.WithTracing(apiToKeyAuthRow),
WorkspaceQuota: middleware.WithTracing(workspaceQuota),
dispatcher: dispatcher,
}, nil
}
3 changes: 2 additions & 1 deletion internal/services/keys/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ go_library(
"status.go",
"validation.go",
"verifier.go",
"workspace_ratelimit.go",
],
importpath = "github.com/unkeyed/unkey/internal/services/keys",
visibility = ["//:__subpackages__"],
Expand Down Expand Up @@ -41,7 +42,7 @@ go_library(

go_test(
name = "keys_test",
size = "small",
size = "large",
srcs = [
"create_test.go",
"get_test.go",
Expand Down
4 changes: 4 additions & 0 deletions internal/services/keys/get.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ func (s *service) GetRootKey(ctx context.Context, sess *zen.Session) (*KeyVerifi
key.AuthorizedWorkspaceID = key.Key.ForWorkspaceID.String
sess.WorkspaceID = key.AuthorizedWorkspaceID

if err := s.checkWorkspaceRateLimit(ctx, sess); err != nil {
return nil, log, err
}

logger.Set(ctx, slog.Group("auth",
slog.String("workspace_id", key.AuthorizedWorkspaceID),
slog.String("root_key_id", key.Key.ID),
Expand Down
7 changes: 6 additions & 1 deletion internal/services/keys/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ type Config struct {
Region string // Geographic region identifier
UsageLimiter usagelimiter.Service // Redis Counter for usage limiting

KeyCache cache.Cache[string, db.CachedKeyData] // Cache for key lookups with pre-parsed data
KeyCache cache.Cache[string, db.CachedKeyData] // Cache for key lookups with pre-parsed data
QuotaCache cache.Cache[string, db.Quotas] // Cache for workspace quota lookups
}

type service struct {
Expand All @@ -31,6 +32,9 @@ type service struct {

// hash -> cached key data (includes pre-parsed IP whitelist)
keyCache cache.Cache[string, db.CachedKeyData]

// workspace_id -> quota (for workspace rate limiting)
quotaCache cache.Cache[string, db.Quotas]
}

// New creates a new keys service instance with the provided configuration.
Expand All @@ -44,6 +48,7 @@ func New(config Config) (*service, error) {
clickhouse: config.Clickhouse,
region: config.Region,
keyCache: config.KeyCache,
quotaCache: config.QuotaCache,
}, nil
}

Expand Down
95 changes: 95 additions & 0 deletions internal/services/keys/workspace_ratelimit.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package keys

import (
"context"
"fmt"
"strconv"
"time"

"github.com/unkeyed/unkey/internal/services/caches"
"github.com/unkeyed/unkey/internal/services/ratelimit"
"github.com/unkeyed/unkey/pkg/codes"
"github.com/unkeyed/unkey/pkg/db"
"github.com/unkeyed/unkey/pkg/fault"
"github.com/unkeyed/unkey/pkg/logger"
"github.com/unkeyed/unkey/pkg/zen"
)

const workspaceRatelimitNamespace = "workspace.ratelimit"

// checkWorkspaceRateLimit enforces per-workspace API rate limiting based on
// the quota table.
//
// NULL limit/duration = unlimited (no rate limiting configured).
// 0 limit = zero requests allowed.
// On any internal error (cache miss, rate limiter failure) the check fails
// open to avoid blocking legitimate traffic.
func (s *service) checkWorkspaceRateLimit(ctx context.Context, sess *zen.Session) error {

quota, _, err := s.quotaCache.SWR(ctx, sess.AuthorizedWorkspaceID(), func(ctx context.Context) (db.Quotas, error) {
return db.Query.FindQuotaByWorkspaceID(ctx, s.db.RO(), sess.AuthorizedWorkspaceID())
}, caches.DefaultFindFirstOp)
if err != nil {
logger.Error("workspace rate limit: failed to load quota",
"workspace_id", sess.AuthorizedWorkspaceID(),
"error", err.Error(),
)
return nil // fail open
}

// NULL = unlimited, no rate limiting configured
if !quota.RatelimitApiLimit.Valid || !quota.RatelimitApiDuration.Valid {
return nil
}

limit := quota.RatelimitApiLimit.Int32
duration := time.Duration(quota.RatelimitApiDuration.Int32) * time.Millisecond

// 0 = explicitly blocked, no requests allowed
if limit == 0 || duration == 0 {
return fault.New("workspace rate limit exceeded",
fault.Code(codes.User.TooManyRequests.WorkspaceRateLimited.URN()),
fault.Internal("workspace rate limit is zero"),
fault.Public(
fmt.Sprintf("This workspace has exceeded its API rate limit of %d/%s. Please try again later.", limit, duration.String()),
),
)
}

resp, err := s.rateLimiter.Ratelimit(ctx, ratelimit.RatelimitRequest{
Name: workspaceRatelimitNamespace,
Identifier: sess.AuthorizedWorkspaceID(),
Limit: int64(limit),
Duration: duration,
Cost: 1,
Time: time.Time{}, //nolint:exhaustruct // use ratelimiter's clock
})
if err != nil {
logger.Error("workspace rate limit: ratelimiter error",
"workspace_id", sess.AuthorizedWorkspaceID(),
"error", err.Error(),
)
return nil // fail open
}

// Set standard rate limit headers (IETF draft-ietf-httpapi-ratelimit-headers)
resetSeconds := max(int64(time.Until(resp.Reset).Seconds()), 0)

sess.AddHeader("RateLimit-Limit", strconv.FormatInt(resp.Limit, 10))
sess.AddHeader("RateLimit-Remaining", strconv.FormatInt(resp.Remaining, 10))
sess.AddHeader("RateLimit-Reset", strconv.FormatInt(resetSeconds, 10))

if !resp.Success {
sess.AddHeader("Retry-After", strconv.FormatInt(resetSeconds, 10))

return fault.New("workspace rate limit exceeded",
fault.Code(codes.User.TooManyRequests.WorkspaceRateLimited.URN()),
fault.Internal("workspace rate limit exceeded"),
fault.Public(
fmt.Sprintf("This workspace has exceeded its API rate limit of %d/%s. Please try again later.", limit, duration.String()),
),
)
}

return nil
}
13 changes: 13 additions & 0 deletions internal/services/ratelimit/namespace/BUILD.bazel
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
load("@rules_go//go:def.bzl", "go_library")

go_library(
name = "namespace",
srcs = [
"parse.go",
],
importpath = "github.com/unkeyed/unkey/internal/services/ratelimit/namespace",
visibility = ["//:__subpackages__"],
deps = [
"//pkg/db",
],
)
66 changes: 66 additions & 0 deletions internal/services/ratelimit/namespace/parse.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package namespace

import (
"encoding/json"
"strings"

"github.com/unkeyed/unkey/pkg/db"
)

// ParseNamespaceRow converts a raw DB row into a FindRatelimitNamespace with parsed overrides.
func ParseNamespaceRow(row db.FindRatelimitNamespaceRow) db.FindRatelimitNamespace {
result := db.FindRatelimitNamespace{
ID: row.ID,
WorkspaceID: row.WorkspaceID,
Name: row.Name,
CreatedAtM: row.CreatedAtM,
UpdatedAtM: row.UpdatedAtM,
DeletedAtM: row.DeletedAtM,
DirectOverrides: make(map[string]db.FindRatelimitNamespaceLimitOverride),
WildcardOverrides: make([]db.FindRatelimitNamespaceLimitOverride, 0),
}

overrides := make([]db.FindRatelimitNamespaceLimitOverride, 0)
if overrideBytes, ok := row.Overrides.([]byte); ok && overrideBytes != nil {
if unmarshalErr := json.Unmarshal(overrideBytes, &overrides); unmarshalErr != nil {
return result
}
}

for _, override := range overrides {
result.DirectOverrides[override.Identifier] = override
if strings.Contains(override.Identifier, "*") {
result.WildcardOverrides = append(result.WildcardOverrides, override)
}
}

return result
}

// RowToNamespace converts a FindManyRatelimitNamespacesRow to FindRatelimitNamespace.
func RowToNamespace(row db.FindManyRatelimitNamespacesRow) db.FindRatelimitNamespace {
result := db.FindRatelimitNamespace{
ID: row.ID,
WorkspaceID: row.WorkspaceID,
Name: row.Name,
CreatedAtM: row.CreatedAtM,
UpdatedAtM: row.UpdatedAtM,
DeletedAtM: row.DeletedAtM,
DirectOverrides: make(map[string]db.FindRatelimitNamespaceLimitOverride),
WildcardOverrides: make([]db.FindRatelimitNamespaceLimitOverride, 0),
}

overrides, err := db.UnmarshalNullableJSONTo[[]db.FindRatelimitNamespaceLimitOverride](row.Overrides)
if err != nil {
return result
}

for _, override := range overrides {
result.DirectOverrides[override.Identifier] = override
if strings.Contains(override.Identifier, "*") {
result.WildcardOverrides = append(result.WildcardOverrides, override)
}
}

return result
}
2 changes: 2 additions & 0 deletions pkg/codes/constants_gen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion pkg/codes/user_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ type userUnprocessableEntity struct {
type userTooManyRequests struct {
// QueryQuotaExceeded indicates the workspace has exceeded their query quota for the current window.
QueryQuotaExceeded Code
// WorkspaceRateLimited indicates the workspace has exceeded its API rate limit for the current window.
WorkspaceRateLimited Code
}

// UserErrors defines all user-related errors in the Unkey system.
Expand Down Expand Up @@ -76,6 +78,7 @@ var User = UserErrors{
QueryRowsLimitExceeded: Code{SystemUser, CategoryUserUnprocessableEntity, "query_rows_limit_exceeded"},
},
TooManyRequests: userTooManyRequests{
QueryQuotaExceeded: Code{SystemUser, CategoryUserTooManyRequests, "query_quota_exceeded"},
QueryQuotaExceeded: Code{SystemUser, CategoryUserTooManyRequests, "query_quota_exceeded"},
WorkspaceRateLimited: Code{SystemUser, CategoryUserTooManyRequests, "workspace_rate_limited"},
},
}
10 changes: 7 additions & 3 deletions pkg/db/bulk_quota_upsert.sql_generated.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading