Skip to content
Draft
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 49 additions & 23 deletions pkg/controller/modelrouter/modelrouter_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@ package modelrouter
import (
"context"
"fmt"
"slices"
"strconv"
"strings"

appsv1 "k8s.io/api/apps/v1"
apierrors "k8s.io/apimachinery/pkg/api/errors"
Expand All @@ -41,9 +43,13 @@ import (

const (
// TODO (varun): cleanup model related identifiers and establish common consensus
modelHeaderIdentifier = "model"
modelIdentifier = "model.aibrix.ai/name"
modelPortIdentifier = "model.aibrix.ai/port"
modelHeaderIdentifier = "model"
modelIdentifier = "model.aibrix.ai/name"
modelPortIdentifier = "model.aibrix.ai/port"
modelSupportedRoutesIdentifier = "model.aibrix.ai/supported-routes"
modelRouteEmbeddings = "embeddings"
modelRouteChatCompletions = "chat-completions"
modelRouteDefault = modelRouteChatCompletions
// TODO (varun): parameterize it or dynamically resolve it
aibrixEnvoyGateway = "aibrix-eg"
aibrixEnvoyGatewayNamespace = "aibrix-system"
Expand Down Expand Up @@ -107,6 +113,43 @@ func Add(mgr manager.Manager, runtimeConfig config.RuntimeConfig) error {
return err
}

// getSupportedRoutesMatchFromLabelsOrDefault returns the HTTPRouteMatch based on the model route labels value
func getSupportedRoutesMatchFromLabelsOrDefault(labels map[string]string, modelHeaderMatch gatewayv1.HTTPHeaderMatch) []gatewayv1.HTTPRouteMatch {
labelValueToRoutePathPrefix := map[string][]string{
modelRouteEmbeddings: {"/v1/embeddings"},
modelRouteChatCompletions: {"/v1/completions", "/v1/chat/completions"},
}

var pathPrefixes []string
if routesLabelValue, ok := labels[modelSupportedRoutesIdentifier]; ok {
routes := strings.Split(routesLabelValue, ",")
for k, route := range labelValueToRoutePathPrefix {
if slices.Contains(routes, k) {
pathPrefixes = append(pathPrefixes, route...)
}
}
}

// Add the default pathPrefixes if no route defines via labels
if len(pathPrefixes) == 0 {
pathPrefixes = append(pathPrefixes, labelValueToRoutePathPrefix[modelRouteDefault]...)
}

var routesmatch []gatewayv1.HTTPRouteMatch
for _, path := range pathPrefixes {
routesmatch = append(routesmatch, gatewayv1.HTTPRouteMatch{
Path: &gatewayv1.HTTPPathMatch{
Type: ptr.To(gatewayv1.PathMatchPathPrefix),
Value: ptr.To(path),
},
Headers: []gatewayv1.HTTPHeaderMatch{
modelHeaderMatch,
},
})
}
return routesmatch
}

type ModelRouter struct {
client.Client
Scheme *runtime.Scheme
Expand Down Expand Up @@ -192,6 +235,8 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string
Value: modelName,
}

httpRoutesMatch := getSupportedRoutesMatchFromLabelsOrDefault(labels, modelHeaderMatch)

httpRoute := gatewayv1.HTTPRoute{
ObjectMeta: metav1.ObjectMeta{
Name: fmt.Sprintf("%s-router", modelName),
Expand All @@ -208,26 +253,7 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string
},
Rules: []gatewayv1.HTTPRouteRule{
{
Matches: []gatewayv1.HTTPRouteMatch{
{
Path: &gatewayv1.HTTPPathMatch{
Type: ptr.To(gatewayv1.PathMatchPathPrefix),
Value: ptr.To("/v1/completions"),
},
Headers: []gatewayv1.HTTPHeaderMatch{
modelHeaderMatch,
},
},
{
Path: &gatewayv1.HTTPPathMatch{
Type: ptr.To(gatewayv1.PathMatchPathPrefix),
Value: ptr.To("/v1/chat/completions"),
},
Headers: []gatewayv1.HTTPHeaderMatch{
modelHeaderMatch,
},
},
},
Matches: httpRoutesMatch,
BackendRefs: []gatewayv1.HTTPBackendRef{
{
BackendRef: gatewayv1.BackendRef{
Expand Down
75 changes: 75 additions & 0 deletions pkg/plugins/gateway/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,22 @@

import (
"encoding/json"
"fmt"
"strings"

configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/openai/openai-go"
"github.com/openai/openai-go/packages/param"
"github.com/vllm-project/aibrix/pkg/utils"
"k8s.io/klog/v2"
)

// OpenAI has a 2048 size limits for embeddings array inputs
// see https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-input
var maxEmbeddingInputArraySize = 2048

// validateRequestBody validates input by unmarshaling request body into respective openai-golang struct based on requestpath.
// nolint:nakedret
func validateRequestBody(requestID, requestPath string, requestBody []byte, user utils.User) (model, message string, stream bool, errRes *extProcPb.ProcessingResponse) {
Expand Down Expand Up @@ -69,6 +75,24 @@
}
model = completionObj.Model
message = completionObj.Prompt
} else if requestPath == "/v1/embeddings" {
message = "" // prefix_cache algorithms are not relevant for embeddings
var jsonMap map[string]json.RawMessage
if err := json.Unmarshal(requestBody, &jsonMap); err != nil {
klog.ErrorS(err, "error to unmarshal request body", "requestID", requestID, "requestBody", string(requestBody))
errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true")
return
}
embeddingObj := openai.EmbeddingNewParams{}
if err := json.Unmarshal(requestBody, &embeddingObj); err != nil {
klog.ErrorS(err, "error to unmarshal embeddings object", "requestID", requestID, "requestBody", string(requestBody))
errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true")
return
}
model = embeddingObj.Model
if errRes = checkEmbeddingInputSequenceLen(requestID, embeddingObj); errRes != nil {
return
}
} else {
errRes = buildErrorResponse(envoyTypePb.StatusCode_NotImplemented, "unknown request path", HeaderErrorRequestBodyProcessing, "true")
return
Expand Down Expand Up @@ -142,6 +166,57 @@
return builder.String(), nil
}

// getEmbeddingsInputLen returns the len of the embeddings object
func checkEmbeddingInputSequenceLen(requestID string, embeddingObj openai.EmbeddingNewParams) *extProcPb.ProcessingResponse {
inputParam := embeddingObj.Input
var size int
isArrayType := false
switch input := embeddingNewParamsInputUnionAsAny(&inputParam).(type) {
case *string:
size = len(*input)
case *[]string:
size = len(*input)
isArrayType = true
case *[]int64:
size = len(*input)
case *[][]int64:
size = len(*input)
isArrayType = true
default:
// Should never happend, but if input is of an unexpected non-nil type, let's explicitly error log it.

Check failure on line 186 in pkg/plugins/gateway/util.go

View workflow job for this annotation

GitHub Actions / lint

`happend` is a misspelling of `happened` (misspell)
// Size will be 0 in this case, which is then handled by the check below.
if input != nil {
klog.ErrorS(nil, "unhandled embedding input type", "requestID", requestID, "inputType", fmt.Sprintf("%T", input))
}
}

if size == 0 {
klog.ErrorS(nil, "no input in the request body", "requestID", requestID)
return buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "no messages in the request body", HeaderErrorRequestBodyProcessing, "true")
}

if isArrayType && size > maxEmbeddingInputArraySize {
klog.ErrorS(nil, "embeddings content is too large", "requestID", requestID, "size", size)
return buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "embeddings content is too large", HeaderErrorRequestBodyProcessing, "true")
}

return nil
}

// TODO: make asAny method publicly available on OpenAI go
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This TODO suggests making the asAny method publicly available on OpenAI go. It would be helpful to provide more context on why this method needs to be public and the potential benefits it would offer.

func embeddingNewParamsInputUnionAsAny(u *openai.EmbeddingNewParamsInputUnion) any {
if !param.IsOmitted(u.OfString) {
return &u.OfString.Value
} else if !param.IsOmitted(u.OfArrayOfStrings) {
return &u.OfArrayOfStrings
} else if !param.IsOmitted(u.OfArrayOfTokens) {
return &u.OfArrayOfTokens
} else if !param.IsOmitted(u.OfArrayOfTokenArrays) {
return &u.OfArrayOfTokenArrays
}
return nil
}

// generateErrorResponse construct envoy proxy error response
// deprecated: use buildErrorResponse
func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {
Expand Down
Loading