18 commits
fe74623  [feat] Add Basic Session Cache Structure (#633)  (SleepyLGod, Jul 30, 2025)
b4377fe  Merge branch 'main' into feature/add-autellix-logic  (SleepyLGod, Jul 30, 2025)
ead1afa  [Fix] Enhance logic of sharded and mutex-based session cache with tests  (SleepyLGod, Jul 30, 2025)
8ac640b  [feat] initialize AIBrix scheduler plugin logic  (SleepyLGod, Aug 1, 2025)
7d22b0c  Merge branch 'main' into feature/add-autellix-logic  (SleepyLGod, Aug 1, 2025)
5445671  Merge branch 'main' into feature/add-autellix-logic  (SleepyLGod, Aug 6, 2025)
b6fc67f  Merge branch 'vllm-project:main' into feature/add-autellix-logic  (SleepyLGod, Aug 12, 2025)
b295e85  feat: implement high-throughput lock-free scheduler and add benchmark…  (SleepyLGod, Aug 12, 2025)
c2404dc  feat: integrate scheduler and state machine for request processing  (SleepyLGod, Aug 28, 2025)
25f4cd3  feat: enhance scheduler with load awareness and batch size smoothing  (SleepyLGod, Aug 28, 2025)
24b3a32  feat: integrate scheduler with cache and enhance state machine tests  (SleepyLGod, Aug 28, 2025)
7af94f2  Merge branch 'main' into feature/add-autellix-logic  (SleepyLGod, Aug 28, 2025)
10eb267  feat: enhance pod watcher and scheduling logic with detailed filtering  (SleepyLGod, Sep 4, 2025)
abc31d0  feat: implement headers-only session ID extraction and disable body p…  (SleepyLGod, Sep 4, 2025)
3b89b01  feat: Add session interface and modify the shard session cache version.  (SleepyLGod, Oct 7, 2025)
c6d8906  feat: Add legacy processing mode support and update routing logic in …  (SleepyLGod, Oct 8, 2025)
6640936  feat: Enhance session ID validation by sending 400 Bad Request respon…  (SleepyLGod, Oct 8, 2025)
2f61b1a  feat: Adjust logging levels for message processing and request header…  (SleepyLGod, Oct 8, 2025)
96 changes: 96 additions & 0 deletions pkg/plugins/gateway/gateway.go
@@ -37,13 +37,48 @@ import (
"github.com/vllm-project/aibrix/pkg/metrics"
routing "github.com/vllm-project/aibrix/pkg/plugins/gateway/algorithms"
"github.com/vllm-project/aibrix/pkg/plugins/gateway/ratelimiter"
"github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler"
"github.com/vllm-project/aibrix/pkg/plugins/gateway/scheduler/sessioninfo"
"github.com/vllm-project/aibrix/pkg/types"
"github.com/vllm-project/aibrix/pkg/utils"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
gatewayv1 "sigs.k8s.io/gateway-api/apis/v1"
gatewayapi "sigs.k8s.io/gateway-api/pkg/client/clientset/versioned"
)

// requestState represents the state of a single request processing flow
type requestState int

const (
stateAwaitingHeaders requestState = iota
stateAwaitingBody
stateAwaitingDecision
stateForwarding
stateDone
)

// perRequestState holds all the state for a single Process() invocation
type perRequestState struct {
currentState requestState
sessionID string
requestID string
user utils.User
rpm int64
model string
routerCtx *types.RoutingContext
stream bool
traceTerm int64
completed bool
isRespError bool
respErrorCode int

// For timing and scheduling
requestStartTime time.Time
submissionTime time.Time
dispatchTime time.Time // When scheduler granted permission
schedulingDecision *scheduler.Decision
}
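
A minimal sketch of how these states could be advanced as a request moves through ProcessStateMachine; the helper methods below are hypothetical (only the fields and constants defined above come from this PR), not the actual implementation:

    // Hypothetical transition helpers for perRequestState (illustration only).
    func (st *perRequestState) onHeadersReceived() {
        st.currentState = stateAwaitingBody
    }

    func (st *perRequestState) onBodyReceived() {
        st.submissionTime = time.Now() // request handed to the scheduler
        st.currentState = stateAwaitingDecision
    }

    func (st *perRequestState) onDecision(d *scheduler.Decision) {
        st.schedulingDecision = d
        st.dispatchTime = time.Now() // scheduler granted permission
        st.currentState = stateForwarding
    }

    func (st *perRequestState) onCompleted() {
        st.completed = true
        st.currentState = stateDone
    }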

const (
defaultAIBrixNamespace = "aibrix-system"
)
@@ -56,6 +91,17 @@ type Server struct {
requestCountTracker map[string]int
cache cache.Cache
metricsServer *metrics.Server

// Scheduler and session management
scheduler scheduler.Scheduler
sessionCache *sessioninfo.MutexSessionCache

// Cleanup function for session cache
sessionCleanupStop func()

// useLegacyMode controls whether to use legacy processing mode (default: false)
// When true, scheduler-based routing is disabled and legacy routing is used
useLegacyMode bool
}

func NewServer(redisClient *redis.Client, client kubernetes.Interface, gatewayClient gatewayapi.Interface) *Server {
@@ -68,6 +114,13 @@ func NewServer(redisClient *redis.Client, client kubernetes.Interface, gatewayCl
// Initialize the routers
routing.Init()

// Initialize session cache and scheduler
sessionCache := sessioninfo.NewMutexSessionCache()
sched := scheduler.NewScheduler(client, sessionCache, c)

// Start session cleanup routine (cleanup every 5 minutes, timeout after 30 minutes)
sessionCleanupStop := sessionCache.StartCleanupRoutine(5*time.Minute, 30*time.Minute)

return &Server{
redisClient: redisClient,
ratelimiter: r,
@@ -76,10 +129,35 @@
requestCountTracker: map[string]int{},
cache: c,
metricsServer: nil,
scheduler: sched,
sessionCache: sessionCache,
sessionCleanupStop: sessionCleanupStop,
useLegacyMode: false, // Default to state machine mode
}
}
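
The sessioninfo implementation is not part of this diff; a hedged sketch of what a cleanup routine with the signature used above could look like (ExpireOlderThan is an assumed helper, not the real sessioninfo API):

    // Hypothetical sketch: scan the cache every interval and evict sessions
    // idle longer than maxIdle; the returned function stops the goroutine.
    func (c *MutexSessionCache) StartCleanupRoutine(interval, maxIdle time.Duration) func() {
        ticker := time.NewTicker(interval)
        done := make(chan struct{})
        go func() {
            defer ticker.Stop()
            for {
                select {
                case <-ticker.C:
                    c.ExpireOlderThan(maxIdle) // assumed eviction helper
                case <-done:
                    return
                }
            }
        }()
        // The returned stop function is what the server keeps as sessionCleanupStop.
        return func() { close(done) }
    }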

// SetLegacyMode enables or disables legacy processing mode
// This should be called before starting to process requests
func (s *Server) SetLegacyMode(enabled bool) {
s.useLegacyMode = enabled
if enabled {
klog.InfoS("legacy mode enabled - scheduler-based routing will be disabled")
} else {
klog.InfoS("state machine mode enabled - using scheduler-based routing")
}
}

// Process delegates to the appropriate implementation based on the useLegacyMode flag
func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
if s.useLegacyMode {
klog.InfoS("using legacy processing mode")
return s.ProcessLegacy(srv)
}
return s.ProcessStateMachine(srv)
}
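
A hedged usage sketch of this dispatch; the variable names, environment variable, and gRPC registration below are assumptions about the surrounding wiring, not code from this PR:

    // Hypothetical wiring: SetLegacyMode must be called before Process starts
    // handling ext_proc streams.
    srv := NewServer(redisClient, k8sClient, gatewayClient)
    if os.Getenv("AIBRIX_GATEWAY_LEGACY_MODE") == "true" { // env var name is an assumption
        srv.SetLegacyMode(true) // fall back to the original routing path
    }
    extProcPb.RegisterExternalProcessorServer(grpcServer, srv)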

// ProcessLegacy is the original implementation, used when legacy mode is enabled via SetLegacyMode
func (s *Server) ProcessLegacy(srv extProcPb.ExternalProcessor_ProcessServer) error {
var user utils.User
var rpm, traceTerm int64
var respErrorCode int
@@ -227,11 +305,29 @@ func (s *Server) StartMetricsServer(addr string) error {
}

func (s *Server) Shutdown() {
klog.InfoS("Starting graceful shutdown of Gateway Server")

// Stop scheduler first to prevent new jobs
if s.scheduler != nil {
klog.InfoS("Stopping scheduler")
s.scheduler.Stop()
}

// Stop session cache cleanup routine
if s.sessionCleanupStop != nil {
klog.InfoS("Stopping session cache cleanup routine")
s.sessionCleanupStop()
}

// Stop metrics server
if s.metricsServer != nil {
klog.InfoS("Stopping metrics server")
if err := s.metricsServer.Stop(); err != nil {
klog.ErrorS(err, "Error stopping metrics server")
}
}

klog.InfoS("Gateway Server shutdown complete")
}

func (s *Server) responseErrorProcessing(ctx context.Context, resp *extProcPb.ProcessingResponse, respErrorCode int,
18 changes: 17 additions & 1 deletion pkg/plugins/gateway/gateway_req_body.go
@@ -34,7 +34,14 @@ import (
func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, int64) {
var term int64 // Identify the trace window

routingCtx, _ := ctx.(*types.RoutingContext)
routingCtx, ok := ctx.(*types.RoutingContext)
if !ok || routingCtx == nil {
klog.ErrorS(nil, "CRITICAL: context is not RoutingContext or is nil", "requestID", requestID, "contextType", fmt.Sprintf("%T", ctx))
return generateErrorResponse(envoyTypePb.StatusCode_InternalServerError,
[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
Key: HeaderErrorRouting, RawValue: []byte("true")}}},
"internal routing context error"), "", nil, false, term
}
requestPath := routingCtx.ReqPath
routingAlgorithm := routingCtx.Algorithm

@@ -66,6 +73,15 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, req *e
fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, term
}

// Check if we're in legacy mode or state machine mode
if !s.useLegacyMode {
// State machine mode: defer routing to the scheduler
// Just validate the model exists and return nil to let Process handle scheduling
klog.InfoS("request body processed, deferring to scheduler", "requestID", requestID, "requestPath", requestPath, "model", model, "stream", stream)
return nil, model, routingCtx, stream, term
}

// Legacy routing logic (when useLegacyMode is true)
headers := []*configPb.HeaderValueOption{}
if routingAlgorithm == routing.RouterNotSet {
if err := s.validateHTTPRouteStatus(ctx, model); err != nil {
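
In state machine mode the nil response above means placement happens back in the Process loop rather than here; a rough sketch of that hand-off, in which s.scheduler.Schedule is an assumed method name (only scheduler.Decision is visible in this diff):

    // Hypothetical caller-side hand-off once HandleRequestBody defers routing.
    resp, model, routingCtx, stream, term := s.HandleRequestBody(routingCtx, requestID, req, user)
    if resp == nil && !s.useLegacyMode {
        state.model, state.stream, state.traceTerm = model, stream, term
        state.currentState = stateAwaitingDecision
        if decision, err := s.scheduler.Schedule(routingCtx); err == nil { // assumed API
            state.schedulingDecision = decision
            state.dispatchTime = time.Now()
            state.currentState = stateForwarding
        }
    }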
169 changes: 169 additions & 0 deletions pkg/plugins/gateway/gateway_req_body_test.go
@@ -628,9 +628,11 @@ func Test_handleRequestBody(t *testing.T) {
mockHTTP.On("Get", mock.Anything, "test-model-router", mock.Anything).Return(route, nil)

// Create server with mock cache
// Enable legacy mode for these tests since they test HandleRequestBody's direct routing behavior
server := &Server{
cache: mockCache,
gatewayClient: mockGW,
useLegacyMode: true, // These tests expect direct routing behavior (legacy mode)
}

// Create request for the test case
@@ -661,3 +663,170 @@
})
}
}

// Test_handleRequestBody_StateMachine tests HandleRequestBody in state machine mode (default mode)
// In state machine mode, HandleRequestBody should return nil to defer routing to the scheduler
func Test_handleRequestBody_StateMachine(t *testing.T) {
// Initialize routing algorithms
routingalgorithms.Init()

tests := []struct {
name string
requestBody string
user utils.User
routingAlgo types.RoutingAlgorithm
mockSetup func(*MockCache)
expectNil bool // Should response be nil?
expectModel string
expectTerm int64
}{
{
name: "state machine mode - valid model should return nil",
requestBody: `{"model": "test-model", "messages": [{"role": "user", "content": "test"}]}`,
user: utils.User{
Name: "test-user",
},
routingAlgo: "random",
mockSetup: func(mockCache *MockCache) {
mockCache.On("HasModel", "test-model").Return(true)
podList := &utils.PodArray{
Pods: []*v1.Pod{
{
Status: v1.PodStatus{
PodIP: "1.2.3.4",
Conditions: []v1.PodCondition{{Type: v1.PodReady, Status: v1.ConditionTrue}},
},
},
},
}
mockCache.On("ListPodsByModel", "test-model").Return(podList, nil)
// Note: AddRequestCount should NOT be called in state machine mode
},
expectNil: true,
expectModel: "test-model",
expectTerm: 0, // No AddRequestCount call
},
{
name: "state machine mode - model not found should return error",
requestBody: `{"model": "unknown-model", "messages": [{"role": "user", "content": "test"}]}`,
user: utils.User{
Name: "test-user",
},
routingAlgo: "",
mockSetup: func(mockCache *MockCache) {
mockCache.On("HasModel", "unknown-model").Return(false)
},
expectNil: false, // Error response should not be nil
expectModel: "unknown-model",
expectTerm: 0,
},
{
name: "state machine mode - no ready pods should return error",
requestBody: `{"model": "test-model", "messages": [{"role": "user", "content": "test"}]}`,
user: utils.User{
Name: "test-user",
},
routingAlgo: "",
mockSetup: func(mockCache *MockCache) {
mockCache.On("HasModel", "test-model").Return(true)
podList := &utils.PodArray{
Pods: []*v1.Pod{
{
Status: v1.PodStatus{
PodIP: "1.2.3.4",
Conditions: []v1.PodCondition{
{
Type: v1.PodReady,
Status: v1.ConditionFalse, // Not ready
},
},
},
},
},
}
mockCache.On("ListPodsByModel", "test-model").Return(podList, nil)
},
expectNil: false, // Error response should not be nil
expectModel: "test-model",
expectTerm: 0,
},
}

for _, tt := range tests {
t.Run(tt.name, func(subtest *testing.T) {
// Initialize mock cache
mockCache := &MockCache{Cache: cache.NewForTest()}
if tt.mockSetup != nil {
tt.mockSetup(mockCache)
}

mockGW := &MockGatewayClient{}
mockGWv1 := &MockGatewayV1Client{}
mockHTTP := &MockHTTPRouteClient{}

mockGW.On("GatewayV1").Return(mockGWv1)
mockGWv1.On("HTTPRoutes", "aibrix-system").Return(mockHTTP)

route := &gatewayv1.HTTPRoute{
Status: gatewayv1.HTTPRouteStatus{
RouteStatus: gatewayv1.RouteStatus{
Parents: []gatewayv1.RouteParentStatus{{
Conditions: []metav1.Condition{{
Type: string(gatewayv1.RouteConditionAccepted),
Reason: string(gatewayv1.RouteReasonAccepted),
Status: metav1.ConditionTrue,
}, {
Type: string(gatewayv1.RouteConditionResolvedRefs),
Reason: string(gatewayv1.RouteReasonResolvedRefs),
Status: metav1.ConditionTrue,
}},
}},
},
},
}
mockHTTP.On("Get", mock.Anything, "test-model-router", mock.Anything).Return(route, nil)

// Create server in STATE MACHINE mode (useLegacyMode = false, which is the default)
server := &Server{
cache: mockCache,
gatewayClient: mockGW,
useLegacyMode: false, // State machine mode
}

// Create request
req := &extProcPb.ProcessingRequest{
Request: &extProcPb.ProcessingRequest_RequestBody{
RequestBody: &extProcPb.HttpBody{
Body: []byte(tt.requestBody),
},
},
}

// Call HandleRequestBody
routingCtx := types.NewRoutingContext(context.Background(), tt.routingAlgo, tt.expectModel, "", "test-request-id", tt.user.Name)
routingCtx.ReqPath = "/v1/chat/completions"
resp, model, returnedCtx, stream, term := server.HandleRequestBody(
routingCtx,
"test-request-id",
req,
tt.user,
)

// Validate response
if tt.expectNil {
assert.Nil(subtest, resp, "response should be nil in state machine mode for valid requests")
} else {
assert.NotNil(subtest, resp, "response should not be nil for error cases")
}

assert.Equal(subtest, tt.expectModel, model)
assert.Equal(subtest, tt.expectTerm, term)
assert.NotNil(subtest, returnedCtx)
assert.Equal(subtest, tt.expectModel, returnedCtx.Model)
assert.False(subtest, stream) // These test cases don't use streaming

// Verify all mock expectations were met
mockCache.AssertExpectations(subtest)
})
}
}
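
These table-driven cases can be run in isolation with the standard tooling, e.g. go test ./pkg/plugins/gateway -run Test_handleRequestBody_StateMachine.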