Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions pkg/kthena-router/accesslog/logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,18 @@ func (l *accessLoggerImpl) formatText(entry *AccessLogEntry) (string, error) {
if entry.ModelServer != "" {
line += fmt.Sprintf(" model_server=%s", entry.ModelServer)
}

// Add Gateway API fields
if entry.Gateway != "" {
line += fmt.Sprintf(" gateway=%s", entry.Gateway)
}
if entry.HTTPRoute != "" {
line += fmt.Sprintf(" http_route=%s", entry.HTTPRoute)
}
if entry.InferencePool != "" {
line += fmt.Sprintf(" inference_pool=%s", entry.InferencePool)
Comment on lines +175 to +183
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you move these after modelRoute and ModelServer above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

}
Comment on lines +176 to +184

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This block of code, along with other parts of the function, repeatedly concatenates strings using +=. This can be inefficient as it may lead to multiple memory allocations for each concatenation. For better performance and readability, consider refactoring the formatText function to use a strings.Builder.

Here's an example of how the whole function could be improved:

import (
	"fmt"
	"strings"
	"time"
)

func (l *accessLoggerImpl) formatText(entry *AccessLogEntry) (string, error) {
	var sb strings.Builder
	timestamp := entry.Timestamp.Format(time.RFC3339Nano)

	fmt.Fprintf(&sb, `[%s] "%s %s %s" %d`,
		timestamp, entry.Method, entry.Path, entry.Protocol,
		entry.StatusCode)

	if entry.Error != nil {
		fmt.Fprintf(&sb, " error=%s:%s", entry.Error.Type, entry.Error.Message)
	}

	appendKV := func(k, v string) {
		if v != "" {
			fmt.Fprintf(&sb, " %s=%s", k, v)
		}
	}

	appendKV("model_name", entry.ModelName)
	appendKV("model_route", entry.ModelRoute)
	appendKV("model_server", entry.ModelServer)
	appendKV("selected_pod", entry.SelectedPod)
	appendKV("request_id", entry.RequestID)
	appendKV("gateway", entry.Gateway)
	appendKV("http_route", entry.HTTPRoute)
	appendKV("inference_pool", entry.InferencePool)

	if entry.InputTokens > 0 || entry.OutputTokens > 0 {
		fmt.Fprintf(&sb, " tokens=%d/%d", entry.InputTokens, entry.OutputTokens)
	}

	fmt.Fprintf(&sb, " timings=%dms(%d+%d+%d)",
		entry.DurationTotal,
		entry.DurationRequestProcessing,
		entry.DurationUpstreamProcessing,
		entry.DurationResponseProcessing)

	return sb.String(), nil
}

This approach builds the string more efficiently in a single buffer and is more maintainable.


if entry.SelectedPod != "" {
line += fmt.Sprintf(" selected_pod=%s", entry.SelectedPod)
}
Expand Down
29 changes: 29 additions & 0 deletions pkg/kthena-router/accesslog/logger_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ func TestAccessLogEntry_ToJSON(t *testing.T) {
ModelServer: "default/llama2-server",
SelectedPod: "llama2-deployment-5f7b8c9d-xk2p4",
RequestID: "test-request-id",
Gateway: "default/test-gateway",
HTTPRoute: "default/test-httproute",
InferencePool: "default/test-inferencepool",
InputTokens: 150,
OutputTokens: 75,
DurationTotal: 2350,
Expand Down Expand Up @@ -68,6 +71,9 @@ func TestAccessLogEntry_ToJSON(t *testing.T) {
assert.Equal(t, "default/llama2-route-v1", parsed["model_route"])
assert.Equal(t, "default/llama2-server", parsed["model_server"])
assert.Equal(t, "llama2-deployment-5f7b8c9d-xk2p4", parsed["selected_pod"])
assert.Equal(t, "default/test-gateway", parsed["gateway"])
assert.Equal(t, "default/test-httproute", parsed["http_route"])
assert.Equal(t, "default/test-inferencepool", parsed["inference_pool"])
assert.Equal(t, float64(150), parsed["input_tokens"])
assert.Equal(t, float64(75), parsed["output_tokens"])

Expand All @@ -90,6 +96,9 @@ func TestAccessLogEntry_ToText(t *testing.T) {
ModelServer: "default/llama2-server",
SelectedPod: "llama2-deployment-5f7b8c9d-xk2p4",
RequestID: "test-request-id",
Gateway: "default/test-gateway",
HTTPRoute: "default/test-httproute",
InferencePool: "default/test-inferencepool",
InputTokens: 150,
OutputTokens: 75,
DurationTotal: 2350,
Expand Down Expand Up @@ -118,6 +127,9 @@ func TestAccessLogEntry_ToText(t *testing.T) {
`model_server=default/llama2-server`,
`selected_pod=llama2-deployment-5f7b8c9d-xk2p4`,
`request_id=test-request-id`,
`gateway=default/test-gateway`,
`http_route=default/test-httproute`,
`inference_pool=default/test-inferencepool`,
`tokens=150/75`,
`timings=2350ms(45+2180+5)`,
}
Expand All @@ -139,6 +151,9 @@ func TestAccessLogEntry_WithError(t *testing.T) {
Message: "Model inference timeout after 30s",
},
ModelName: "llama2-7b",
Gateway: "default/test-gateway",
HTTPRoute: "default/test-httproute",
InferencePool: "default/test-inferencepool",
DurationTotal: 100,
DurationRequestProcessing: 50,
DurationUpstreamProcessing: 0,
Expand All @@ -163,12 +178,18 @@ func TestAccessLogEntry_WithError(t *testing.T) {
errorInfo := parsed["error"].(map[string]interface{})
assert.Equal(t, "timeout", errorInfo["type"])
assert.Equal(t, "Model inference timeout after 30s", errorInfo["message"])
assert.Equal(t, "default/test-gateway", parsed["gateway"])
assert.Equal(t, "default/test-httproute", parsed["http_route"])
assert.Equal(t, "default/test-inferencepool", parsed["inference_pool"])

// Test text format
config.Format = FormatText
output, err = logger.formatText(entry)
require.NoError(t, err)
assert.Contains(t, output, "error=timeout:Model inference timeout after 30s")
assert.Contains(t, output, "gateway=default/test-gateway")
assert.Contains(t, output, "http_route=default/test-httproute")
assert.Contains(t, output, "inference_pool=default/test-inferencepool")
}

func TestAccessLogContext_Lifecycle(t *testing.T) {
Expand Down Expand Up @@ -203,6 +224,11 @@ func TestAccessLogContext_Lifecycle(t *testing.T) {
assert.Equal(t, "rate_limit", ctx.Error.Type)
assert.Equal(t, "Too many requests", ctx.Error.Message)

// Set Gateway API info
ctx.Gateway = "default/test-gateway"
ctx.HTTPRoute = "default/test-httproute"
ctx.InferencePool = "default/test-inferencepool"

// Mark timing phases
time.Sleep(1 * time.Millisecond) // Ensure time difference
ctx.MarkRequestProcessingEnd()
Expand Down Expand Up @@ -230,6 +256,9 @@ func TestAccessLogContext_Lifecycle(t *testing.T) {
assert.Greater(t, entry.DurationTotal, int64(0))
assert.NotNil(t, entry.Error)
assert.Equal(t, "rate_limit", entry.Error.Type)
assert.Equal(t, "default/test-gateway", entry.Gateway)
assert.Equal(t, "default/test-httproute", entry.HTTPRoute)
assert.Equal(t, "default/test-inferencepool", entry.InferencePool)
}

func TestNoopAccessLogger(t *testing.T) {
Expand Down
10 changes: 10 additions & 0 deletions pkg/kthena-router/accesslog/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,13 @@ func MarkResponseProcessingEnd(c *gin.Context) {
ctx.MarkResponseProcessingEnd()
}
}

// SetGatewayAPIInfo sets Gateway API information in the access log context
// gateway, httpRoute, and inferencePool should be in namespace/name format
func SetGatewayAPIInfo(c *gin.Context, gateway, httpRoute, inferencePool string) {
if ctx := GetAccessLogContext(c); ctx != nil {
ctx.Gateway = gateway
ctx.HTTPRoute = httpRoute
ctx.InferencePool = inferencePool
}
}
13 changes: 13 additions & 0 deletions pkg/kthena-router/accesslog/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@ type AccessLogEntry struct {
SelectedPod string `json:"selected_pod,omitempty"`
RequestID string `json:"request_id,omitempty"`

// Gateway API information
Gateway string `json:"gateway,omitempty"`
HTTPRoute string `json:"http_route,omitempty"`
InferencePool string `json:"inference_pool,omitempty"`

// Token information
InputTokens int `json:"input_tokens,omitempty"`
OutputTokens int `json:"output_tokens,omitempty"`
Expand Down Expand Up @@ -69,6 +74,11 @@ type AccessLogContext struct {
ModelServer string
SelectedPod string

// Gateway API information
Gateway string
HTTPRoute string
InferencePool string

// Token counts
InputTokens int
OutputTokens int
Expand Down Expand Up @@ -189,6 +199,9 @@ func (ctx *AccessLogContext) ToAccessLogEntry(statusCode int) *AccessLogEntry {
RequestID: ctx.RequestID,
InputTokens: ctx.InputTokens,
OutputTokens: ctx.OutputTokens,
Gateway: ctx.Gateway,
HTTPRoute: ctx.HTTPRoute,
InferencePool: ctx.InferencePool,
DurationTotal: total,
DurationRequestProcessing: requestProcessing,
DurationUpstreamProcessing: upstreamProcessing,
Expand Down
38 changes: 37 additions & 1 deletion pkg/kthena-router/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ import (

const (
// Context keys for gin context
GatewayKey = "gatewayKey"
GatewayKey = "gatewayKey"
HTTPRouteKey = "httpRouteName"
InferencePoolKey = "inferencePoolName"
)

func getEnvBool(key string, fallback bool) bool {
Expand Down Expand Up @@ -394,6 +396,25 @@ func (r *Router) doLoadbalance(c *gin.Context, modelRequest ModelRequest) {
c.Set("modelRouteName", modelRouteName)
}

// Set Gateway API info from context
var gatewayKeyForLog, httpRouteKey, inferencePoolKey string
if key, exists := c.Get(GatewayKey); exists {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: In my opinion, the operations of L393-L407 are similar; can they be abstracted into a function?

if k, ok := key.(string); ok {
gatewayKeyForLog = k
}
}
if httpRouteName, exists := c.Get(HTTPRouteKey); exists {
if name, ok := httpRouteName.(types.NamespacedName); ok {
httpRouteKey = fmt.Sprintf("%s/%s", name.Namespace, name.Name)
}
}
if inferencePoolName, exists := c.Get(InferencePoolKey); exists {
if name, ok := inferencePoolName.(types.NamespacedName); ok {
inferencePoolKey = fmt.Sprintf("%s/%s", name.Namespace, name.Name)
}
}
accesslog.SetGatewayAPIInfo(c, gatewayKeyForLog, httpRouteKey, inferencePoolKey)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@YaoZengzeng We support mix using gateway httpRoute and modelServer, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First attempt to match ModelRoute → ModelServer; if that fails, fall back to HTTPRoute → InferencePool. Both share the same gatewayKey context.
FYI:

if err == nil && strings.HasPrefix(c.Request.URL.Path, "/v1/") {
// Regular ModelServer request
// step 3: Find pods and model server details
klog.V(4).Infof("modelServer is %v, is_lora: %v", modelServerName, isLora)
pods, modelServer, err = r.getPodsAndServer(modelServerName)
if err != nil || len(pods) == 0 {
klog.Errorf("failed to get pods and model server: %v, %v", modelServerName, err)
accesslog.SetError(c, "pod_discovery", fmt.Sprintf("can't find model server: %v", modelServerName))
c.AbortWithStatusJSON(http.StatusNotFound, fmt.Sprintf("can't find model server: %v", modelServerName))
return
}
model := modelServer.Spec.Model
if model != nil && !isLora {
modelRequest["model"] = *model
}
port = modelServer.Spec.WorkloadPort.Port
} else if matched, inferencePoolName := r.handleHTTPRoute(c, gatewayKey); matched {
// If ModelRoute is not matched, try to match HTTPRoute


if len(ctx.BestPods) > 0 && ctx.BestPods[0].Pod != nil {
selectedPod := ctx.BestPods[0].Pod.Name
accesslog.SetRequestRouting(c, modelRouteName, modelServerFullName, selectedPod)
Expand Down Expand Up @@ -515,6 +536,18 @@ func (r *Router) handleHTTPRoute(c *gin.Context, gatewayKey string) (bool, types
return false, types.NamespacedName{}
}

// Store Gateway key in context for access log
if gatewayKey != "" {
c.Set(GatewayKey, gatewayKey)
}

// Store HTTPRoute name in context for access log
httpRouteName := types.NamespacedName{
Namespace: matchedRoute.Namespace,
Name: matchedRoute.Name,
}
c.Set(HTTPRouteKey, httpRouteName)

// Store the matched prefix in context for URL rewriting
if matchedPrefix != "" {
c.Set("matchedPrefix", matchedPrefix)
Expand Down Expand Up @@ -548,6 +581,9 @@ func (r *Router) handleHTTPRoute(c *gin.Context, gatewayKey string) (bool, types
return false, types.NamespacedName{}
}

// Store InferencePool name in context for access log
c.Set(InferencePoolKey, inferencePoolName)

// Apply HTTPURLRewriteFilter if present
if matchedRule != nil && matchedRule.Filters != nil {
for _, filter := range matchedRule.Filters {
Expand Down
111 changes: 111 additions & 0 deletions pkg/kthena-router/router/router_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ import (
"github.com/agiledragon/gomonkey/v2"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"istio.io/istio/pkg/util/sets"
corev1 "k8s.io/api/core/v1"
v1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/apimachinery/pkg/runtime"
"k8s.io/apimachinery/pkg/types"
"k8s.io/klog/v2"
gatewayv1 "sigs.k8s.io/gateway-api/apis/v1"

aiv1alpha1 "github.com/volcano-sh/kthena/pkg/apis/networking/v1alpha1"
"github.com/volcano-sh/kthena/pkg/kthena-router/accesslog"
Expand Down Expand Up @@ -426,6 +428,115 @@ func TestRouter_HandlerFunc_ScheduleFailure(t *testing.T) {
assert.Contains(t, w.Body.String(), "can't schedule to target pod")
}

func TestRouter_HandlerFunc_AccessLogRoutingInfo(t *testing.T) {
// 1. Setup backend mock
backendHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
fmt.Fprint(w, `{"id":"response-id"}`)
})
router, store, backend := setupTestRouter(backendHandler)
defer backend.Close()

backendURL, _ := url.Parse(backend.URL)
backendIP := backendURL.Hostname()
backendPort, _ := strconv.Atoi(backendURL.Port())

// 2. Populate store
modelServer := &aiv1alpha1.ModelServer{
ObjectMeta: v1.ObjectMeta{Name: "ms-1", Namespace: "default"},
Spec: aiv1alpha1.ModelServerSpec{
Model: func(s string) *string { return &s }("test-model-base"),
WorkloadPort: aiv1alpha1.WorkloadPort{Port: int32(backendPort)},
InferenceEngine: "vLLM",
},
}
pod1 := &corev1.Pod{
ObjectMeta: v1.ObjectMeta{Name: "pod-1", Namespace: "default"},
Status: corev1.PodStatus{PodIP: backendIP, Phase: corev1.PodRunning},
}
// Create Gateway and add to store (required for ModelRoute with parentRefs to match)
gateway := &gatewayv1.Gateway{
ObjectMeta: v1.ObjectMeta{Name: "test-gateway", Namespace: "default"},
Spec: gatewayv1.GatewaySpec{
Listeners: []gatewayv1.Listener{
{
Name: "http",
Port: 80,
Protocol: gatewayv1.HTTPProtocolType,
},
},
},
}
gatewayKey := "default/test-gateway"
store.AddOrUpdateGateway(gateway)

// Create ModelRoute with parentRefs pointing to the Gateway
gatewayKind := gatewayv1.Kind("Gateway")
modelRoute := &aiv1alpha1.ModelRoute{
ObjectMeta: v1.ObjectMeta{Name: "mr-1", Namespace: "default"},
Spec: aiv1alpha1.ModelRouteSpec{
ModelName: "test-model",
ParentRefs: []gatewayv1.ParentReference{
{
Name: gatewayv1.ObjectName("test-gateway"),
Kind: &gatewayKind,
Group: func() *gatewayv1.Group { g := gatewayv1.Group("gateway.networking.k8s.io"); return &g }(),
},
},
Rules: []*aiv1alpha1.Rule{
{
TargetModels: []*aiv1alpha1.TargetModel{
{ModelServerName: "ms-1"},
},
},
},
},
}

store.AddOrUpdateModelServer(modelServer, sets.New(types.NamespacedName{Name: "pod-1", Namespace: "default"}))
store.AddOrUpdatePod(pod1, []*aiv1alpha1.ModelServer{modelServer})
store.AddOrUpdateModelRoute(modelRoute)

// 3. Create request with Gateway API info in context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)

reqBody := `{"model": "test-model", "prompt": "hello"}`
c.Request, _ = http.NewRequest("POST", "/v1/chat/completions", bytes.NewBufferString(reqBody))
c.Request.Header.Set("Content-Type", "application/json")

// Set Gateway API info in gin.Context
// We set GatewayKey because ModelRoute has parentRefs that match the Gateway
c.Set(GatewayKey, gatewayKey)
c.Set(HTTPRouteKey, types.NamespacedName{Namespace: "default", Name: "test-httproute"})
c.Set(InferencePoolKey, types.NamespacedName{Namespace: "default", Name: "test-inferencepool"})

// 4. Execute access log middleware first to create access log context
router.AccessLog()(c)

// 5. Execute handler
router.HandlerFunc()(c)

// 6. Verify request succeeded
assert.Equal(t, http.StatusOK, w.Code, "Request should succeed to test routing info setting")

// 7. Verify access log context has all routing information
accessCtx := accesslog.GetAccessLogContext(c)
require.NotNil(t, accessCtx, "Access log context should be set")

// Verify AI-specific routing information
assert.Equal(t, "test-model", accessCtx.ModelName, "ModelName should be set from request")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually ModelRoute/ModelServer and HTTPRoute/InferencePool are mutually exclusive; they can't be set simultaneously.

This is also the key point of the test.

Example ref: https://github.com/volcano-sh/kthena/blob/main/test/e2e/router/gateway-inference-extension/e2e_test.go#L143

And it will be easy to test in E2E.

assert.Equal(t, "default/mr-1", accessCtx.ModelRoute, "ModelRoute should be set from matched ModelRoute")
assert.Equal(t, "default/ms-1", accessCtx.ModelServer, "ModelServer should be set from matched ModelServer")
assert.Equal(t, "pod-1", accessCtx.SelectedPod, "SelectedPod should be set from scheduled pod")
assert.Equal(t, accessCtx.RequestID, c.Request.Header.Get("x-request-id"), "RequestID should match request header")

// Verify Gateway API information
assert.Equal(t, gatewayKey, accessCtx.Gateway, "Gateway should be set from gin.Context GatewayKey")
assert.Equal(t, "default/test-httproute", accessCtx.HTTPRoute, "HTTPRoute should be set from gin.Context")
assert.Equal(t, "default/test-inferencepool", accessCtx.InferencePool, "InferencePool should be set from gin.Context")
}

func TestAccessLogConfigurationFromEnv(t *testing.T) {
// Save original environment variables
originalEnabled := os.Getenv("ACCESS_LOG_ENABLED")
Expand Down
Loading
Loading