diff --git a/pkg/controller/modelrouter/modelrouter_controller.go b/pkg/controller/modelrouter/modelrouter_controller.go
index 1e6e10611..a57da2523 100644
--- a/pkg/controller/modelrouter/modelrouter_controller.go
+++ b/pkg/controller/modelrouter/modelrouter_controller.go
@@ -19,7 +19,9 @@ package modelrouter
 import (
 	"context"
 	"fmt"
+	"slices"
 	"strconv"
+	"strings"
 
 	appsv1 "k8s.io/api/apps/v1"
 	apierrors "k8s.io/apimachinery/pkg/api/errors"
@@ -35,15 +37,17 @@ import (
 	modelv1alpha1 "github.com/vllm-project/aibrix/api/model/v1alpha1"
 	orchestrationv1alpha1 "github.com/vllm-project/aibrix/api/orchestration/v1alpha1"
 	"github.com/vllm-project/aibrix/pkg/config"
+	aibrixgateway "github.com/vllm-project/aibrix/pkg/plugins/gateway"
 	gatewayv1 "sigs.k8s.io/gateway-api/apis/v1"
 	gatewayv1beta1 "sigs.k8s.io/gateway-api/apis/v1beta1"
 )
 
 const (
 	// TODO (varun): cleanup model related identifiers and establish common consensus
-	modelHeaderIdentifier = "model"
-	modelIdentifier       = "model.aibrix.ai/name"
-	modelPortIdentifier   = "model.aibrix.ai/port"
+	modelHeaderIdentifier               = "model"
+	modelIdentifier                     = "model.aibrix.ai/name"
+	modelPortIdentifier                 = "model.aibrix.ai/port"
+	modelSupportedRequestTypeIdentifier = "model.aibrix.ai/supported-request-types"
 	// TODO (varun): parameterize it or dynamically resolve it
 	aibrixEnvoyGateway          = "aibrix-eg"
 	aibrixEnvoyGatewayNamespace = "aibrix-system"
@@ -51,6 +55,16 @@ const (
 	defaultModelServingPort = 8000
 )
 
+var (
+	requestTypeIdentifierToSupportedRoutePathPrefix = map[string][]string{
+		string(aibrixgateway.OpenAiRequestEmbeddingsType):      {string(aibrixgateway.OpenAiRequestEmbeddingsPath)},
+		string(aibrixgateway.OpenAiRequestChatCompletionsType): {string(aibrixgateway.OpenAiRequestCompletionsPath), string(aibrixgateway.OpenAiRequestChatCompletionsPath)},
+		string(aibrixgateway.OpenAiRequestCompletionsType):     {string(aibrixgateway.OpenAiRequestCompletionsPath), string(aibrixgateway.OpenAiRequestChatCompletionsPath)},
+	}
+
+	defaultSupportedRequestType = string(aibrixgateway.OpenAiRequestChatCompletionsType)
+)
+
 //+kubebuilder:rbac:groups=apps,resources=deployments,verbs=get;list;watch;create;update;patch;delete
 //+kubebuilder:rbac:groups=orchestration.aibrix.ai,resources=rayclusterfleets,verbs=get;list;watch;create;update;patch;delete
 //+kubebuilder:rbac:groups=gateway.networking.k8s.io,resources=httproutes,verbs=get;list;watch;create;update;patch;delete
@@ -107,6 +121,38 @@ func Add(mgr manager.Manager, runtimeConfig config.RuntimeConfig) error {
 	return err
 }
 
+// getSupportedRoutesMatchFromLabelsOrDefault returns the HTTPRouteMatches derived from the model's
+// supported-request-types label, falling back to the default request type if the label is absent
+func getSupportedRoutesMatchFromLabelsOrDefault(labels map[string]string, modelHeaderMatch gatewayv1.HTTPHeaderMatch) []gatewayv1.HTTPRouteMatch {
+	var pathPrefixes []string
+	if routesLabelValue, ok := labels[modelSupportedRequestTypeIdentifier]; ok {
+		routesIdentifier := strings.Split(routesLabelValue, ",")
+		for id, paths := range requestTypeIdentifierToSupportedRoutePathPrefix {
+			if slices.Contains(routesIdentifier, id) {
+				pathPrefixes = append(pathPrefixes, paths...)
+			}
+		}
+	}
+
+	// Fall back to the default path prefixes if no request types were defined via labels
+	if len(pathPrefixes) == 0 {
+		pathPrefixes = append(pathPrefixes, requestTypeIdentifierToSupportedRoutePathPrefix[defaultSupportedRequestType]...)
+	}
+
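+	// One HTTPRouteMatch is built per path prefix below; each match also
+	// requires the model header, so multiple models can share the same paths.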
+	var routesMatch []gatewayv1.HTTPRouteMatch
+	for _, path := range pathPrefixes {
+		routesMatch = append(routesMatch, gatewayv1.HTTPRouteMatch{
+			Path: &gatewayv1.HTTPPathMatch{
+				Type:  ptr.To(gatewayv1.PathMatchPathPrefix),
+				Value: ptr.To(path),
+			},
+			Headers: []gatewayv1.HTTPHeaderMatch{
+				modelHeaderMatch,
+			},
+		})
+	}
+	return routesMatch
+}
+
 type ModelRouter struct {
 	client.Client
 	Scheme *runtime.Scheme
@@ -192,6 +238,8 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string
 		Value: modelName,
 	}
 
+	httpRoutesMatch := getSupportedRoutesMatchFromLabelsOrDefault(labels, modelHeaderMatch)
+
 	httpRoute := gatewayv1.HTTPRoute{
 		ObjectMeta: metav1.ObjectMeta{
 			Name:      fmt.Sprintf("%s-router", modelName),
@@ -208,26 +256,7 @@ func (m *ModelRouter) createHTTPRoute(namespace string, labels map[string]string
 			},
 			Rules: []gatewayv1.HTTPRouteRule{
 				{
-					Matches: []gatewayv1.HTTPRouteMatch{
-						{
-							Path: &gatewayv1.HTTPPathMatch{
-								Type:  ptr.To(gatewayv1.PathMatchPathPrefix),
-								Value: ptr.To("/v1/completions"),
-							},
-							Headers: []gatewayv1.HTTPHeaderMatch{
-								modelHeaderMatch,
-							},
-						},
-						{
-							Path: &gatewayv1.HTTPPathMatch{
-								Type:  ptr.To(gatewayv1.PathMatchPathPrefix),
-								Value: ptr.To("/v1/chat/completions"),
-							},
-							Headers: []gatewayv1.HTTPHeaderMatch{
-								modelHeaderMatch,
-							},
-						},
-					},
+					Matches: httpRoutesMatch,
 					BackendRefs: []gatewayv1.HTTPBackendRef{
 						{
 							BackendRef: gatewayv1.BackendRef{
diff --git a/pkg/plugins/gateway/gateway.go b/pkg/plugins/gateway/gateway.go
index 608ebc36a..3f322a2e2 100644
--- a/pkg/plugins/gateway/gateway.go
+++ b/pkg/plugins/gateway/gateway.go
@@ -82,6 +82,7 @@ func (s *Server) Process(srv extProcPb.ExternalProcessor_ProcessServer) error {
 	var respErrorCode int
 	var model string
 	var requestPath string
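+	// requestType is resolved while handling the request body and reused to
+	// pick the matching response-body handler.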
%s", errMsg, httprouteErr.Error()) diff --git a/pkg/plugins/gateway/gateway_req_body.go b/pkg/plugins/gateway/gateway_req_body.go index d0bacb10c..552f59bce 100644 --- a/pkg/plugins/gateway/gateway_req_body.go +++ b/pkg/plugins/gateway/gateway_req_body.go @@ -31,14 +31,17 @@ import ( ) func (s *Server) HandleRequestBody(ctx context.Context, requestID string, requestPath string, req *extProcPb.ProcessingRequest, - user utils.User, routingAlgorithm types.RoutingAlgorithm) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, int64) { + user utils.User, routingAlgorithm types.RoutingAlgorithm, +) (*extProcPb.ProcessingResponse, string, *types.RoutingContext, bool, OpenAiRequestType, int64) { var routingCtx *types.RoutingContext var term int64 // Identify the trace window + requestType := NewOpenAiRequestTypeFromPath(requestPath) + body := req.Request.(*extProcPb.ProcessingRequest_RequestBody) - model, message, stream, errRes := validateRequestBody(requestID, requestPath, body.RequestBody.GetBody(), user) + model, message, stream, errRes := validateRequestBody(requestID, requestType, body.RequestBody.GetBody(), user) if errRes != nil { - return errRes, model, routingCtx, stream, term + return errRes, model, routingCtx, stream, requestType, term } // early reject the request if model doesn't exist. @@ -46,8 +49,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques klog.ErrorS(nil, "model doesn't exist in cache, probably wrong model name", "requestID", requestID, "model", model) return generateErrorResponse(envoyTypePb.StatusCode_BadRequest, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: HeaderErrorNoModelBackends, RawValue: []byte(model)}}}, - fmt.Sprintf("model %s does not exist", model)), model, routingCtx, stream, term + Key: HeaderErrorNoModelBackends, RawValue: []byte(model), + }}}, + fmt.Sprintf("model %s does not exist", model)), model, routingCtx, stream, requestType, term } // early reject if no pods are ready to accept request for a model @@ -56,8 +60,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques klog.ErrorS(err, "no ready pod available", "requestID", requestID, "model", model) return generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: HeaderErrorNoModelBackends, RawValue: []byte("true")}}}, - fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, term + Key: HeaderErrorNoModelBackends, RawValue: []byte("true"), + }}}, + fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, requestType, term } routingCtx = types.NewRoutingContext(ctx, routingAlgorithm, model, message, requestID, user.Name) @@ -72,8 +77,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques return generateErrorResponse( envoyTypePb.StatusCode_ServiceUnavailable, []*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{ - Key: HeaderErrorRouting, RawValue: []byte("true")}}}, - "error on selecting target pod"), model, routingCtx, stream, term + Key: HeaderErrorRouting, RawValue: []byte("true"), + }}}, + "error on selecting target pod"), model, routingCtx, stream, requestType, term } headers = buildEnvoyProxyHeaders(headers, HeaderRoutingStrategy, string(routingAlgorithm), @@ -93,5 +99,5 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques }, }, }, - }, model, routingCtx, stream, term + }, model, routingCtx, 
+	requestType := NewOpenAiRequestTypeFromPath(requestPath)
+
 	body := req.Request.(*extProcPb.ProcessingRequest_RequestBody)
-	model, message, stream, errRes := validateRequestBody(requestID, requestPath, body.RequestBody.GetBody(), user)
+	model, message, stream, errRes := validateRequestBody(requestID, requestType, body.RequestBody.GetBody(), user)
 	if errRes != nil {
-		return errRes, model, routingCtx, stream, term
+		return errRes, model, routingCtx, stream, requestType, term
 	}
 
 	// early reject the request if model doesn't exist.
@@ -46,8 +49,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques
 		klog.ErrorS(nil, "model doesn't exist in cache, probably wrong model name", "requestID", requestID, "model", model)
 		return generateErrorResponse(envoyTypePb.StatusCode_BadRequest,
 			[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
-				Key: HeaderErrorNoModelBackends, RawValue: []byte(model)}}},
-			fmt.Sprintf("model %s does not exist", model)), model, routingCtx, stream, term
+				Key: HeaderErrorNoModelBackends, RawValue: []byte(model),
+			}}},
+			fmt.Sprintf("model %s does not exist", model)), model, routingCtx, stream, requestType, term
 	}
 
 	// early reject if no pods are ready to accept request for a model
@@ -56,8 +60,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques
 		klog.ErrorS(err, "no ready pod available", "requestID", requestID, "model", model)
 		return generateErrorResponse(envoyTypePb.StatusCode_ServiceUnavailable,
 			[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
-				Key: HeaderErrorNoModelBackends, RawValue: []byte("true")}}},
-			fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, term
+				Key: HeaderErrorNoModelBackends, RawValue: []byte("true"),
+			}}},
+			fmt.Sprintf("error on getting pods for model %s", model)), model, routingCtx, stream, requestType, term
 	}
 
 	routingCtx = types.NewRoutingContext(ctx, routingAlgorithm, model, message, requestID, user.Name)
@@ -72,8 +77,9 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques
 		return generateErrorResponse(
 			envoyTypePb.StatusCode_ServiceUnavailable,
 			[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
-				Key: HeaderErrorRouting, RawValue: []byte("true")}}},
-			"error on selecting target pod"), model, routingCtx, stream, term
+				Key: HeaderErrorRouting, RawValue: []byte("true"),
+			}}},
+			"error on selecting target pod"), model, routingCtx, stream, requestType, term
 	}
 
 	headers = buildEnvoyProxyHeaders(headers, HeaderRoutingStrategy, string(routingAlgorithm),
@@ -93,5 +99,5 @@ func (s *Server) HandleRequestBody(ctx context.Context, requestID string, reques
 				},
 			},
 		},
-	}, model, routingCtx, stream, term
+	}, model, routingCtx, stream, requestType, term
 }
diff --git a/pkg/plugins/gateway/gateway_rsp_body.go b/pkg/plugins/gateway/gateway_rsp_body.go
index d8bcf4634..6eeb70105 100644
--- a/pkg/plugins/gateway/gateway_rsp_body.go
+++ b/pkg/plugins/gateway/gateway_rsp_body.go
@@ -35,9 +35,26 @@ import (
 	"github.com/vllm-project/aibrix/pkg/utils"
 )
 
-func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
+func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *extProcPb.ProcessingRequest, requestType OpenAiRequestType, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
 	b := req.Request.(*extProcPb.ProcessingRequest_ResponseBody)
 
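+	// Dispatch on the request type captured during request-body handling:
+	// completions and chat completions share a handler, embeddings get their own.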
+	switch requestType {
+	case OpenAiRequestChatCompletionsType, OpenAiRequestCompletionsType:
+		return s.handleChatCompletionsResponseBody(ctx, requestID, b, user, rpm, model, stream, traceTerm, hasCompleted)
+	case OpenAiRequestEmbeddingsType:
+		return s.handleEmbeddingsResponseBody(ctx, requestID, b, user, rpm, model, false, traceTerm, hasCompleted)
+	default:
+		// all other OpenAI request types (e.g. audio, image, ...) are not supported yet
+		return generateErrorResponse(
+			envoyTypePb.StatusCode_NotImplemented,
+			[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
+				Key: HeaderErrorResponseUnknown, RawValue: []byte("true"),
+			}}},
+			"request type not supported"), true
+	}
+}
+
+func (s *Server) handleChatCompletionsResponseBody(ctx context.Context, requestID string, b *extProcPb.ProcessingRequest_ResponseBody, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
 	var res openai.ChatCompletion
 	var usage openai.CompletionUsage
 	var promptTokens, completionTokens int64
@@ -203,3 +220,143 @@ func (s *Server) HandleResponseBody(ctx context.Context, requestID string, req *
 		},
 	}, complete
 }
+
+func (s *Server) handleEmbeddingsResponseBody(ctx context.Context, requestID string, b *extProcPb.ProcessingRequest_ResponseBody, user utils.User, rpm int64, model string, stream bool, traceTerm int64, hasCompleted bool) (*extProcPb.ProcessingResponse, bool) {
+	var res openai.CreateEmbeddingResponse
+	var usage openai.CreateEmbeddingResponseUsage
+	var promptTokens, completionTokens int64
+	var headers []*configPb.HeaderValueOption
+	complete := hasCompleted
+	routerCtx, _ := ctx.(*types.RoutingContext)
+
+	defer func() {
+		// Wrapped in a function to delay the evaluation of parameters. Using complete makes sure DoneRequestTrace is only called once per request.
+		if !hasCompleted && complete {
+			s.cache.DoneRequestTrace(routerCtx, requestID, model, promptTokens, completionTokens, traceTerm)
+			if routerCtx != nil {
+				routerCtx.Delete()
+			}
+		}
+	}()
+
+	// Use request ID as a key to store the per-request buffer
+	// Retrieve or create buffer
+	buf, _ := requestBuffers.LoadOrStore(requestID, &bytes.Buffer{})
+	buffer := buf.(*bytes.Buffer)
+	// Append data to per-request buffer
+	buffer.Write(b.ResponseBody.Body)
+
+	if !b.ResponseBody.EndOfStream {
+		// Partial data received; wait for more chunks and just return a common response here.
+		return &extProcPb.ProcessingResponse{
+			Response: &extProcPb.ProcessingResponse_ResponseBody{
+				ResponseBody: &extProcPb.BodyResponse{
+					Response: &extProcPb.CommonResponse{},
+				},
+			},
+		}, complete
+	}
+
+	// Last part received, process the full response
+	finalBody := buffer.Bytes()
+	// Clean up the buffer after final processing
+	requestBuffers.Delete(requestID)
+
+	if err := json.Unmarshal(finalBody, &res); err != nil {
+		klog.ErrorS(err, "error to unmarshal response", "requestID", requestID, "responseBody", string(b.ResponseBody.GetBody()))
+		complete = true
+		return generateErrorResponse(
+			envoyTypePb.StatusCode_InternalServerError,
+			[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
+				Key: HeaderErrorResponseUnmarshal, RawValue: []byte("true"),
+			}}},
+			err.Error()), complete
+	} else if len(res.Model) == 0 {
+		msg := ErrorUnknownResponse.Error()
+		responseBodyContent := string(b.ResponseBody.GetBody())
+		if len(responseBodyContent) != 0 {
+			msg = responseBodyContent
+		}
+		klog.ErrorS(err, "unexpected response", "requestID", requestID, "responseBody", responseBodyContent)
+		complete = true
+		return generateErrorResponse(
+			envoyTypePb.StatusCode_InternalServerError,
+			[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
+				Key: HeaderErrorResponseUnknown, RawValue: []byte("true"),
+			}}},
+			msg), complete
+	}
+	// Do not overwrite model, res can be empty.
+	usage = res.Usage
+
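+	// A non-zero TotalTokens marks the complete usage payload; only then are
+	// the user's TPM counter and the rate-limit headers updated.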
+	var requestEnd string
+	if usage.TotalTokens != 0 {
+		complete = true
+		// Update promptTokens and completionTokens
+		promptTokens = usage.PromptTokens
+		completionTokens = 0 // no completion tokens in an embeddings request
+		// Count tokens per user.
+		if user.Name != "" {
+			tpm, err := s.ratelimiter.Incr(ctx, fmt.Sprintf("%v_TPM_CURRENT", user.Name), res.Usage.TotalTokens)
+			if err != nil {
+				return generateErrorResponse(
+					envoyTypePb.StatusCode_InternalServerError,
+					[]*configPb.HeaderValueOption{{Header: &configPb.HeaderValue{
+						Key: HeaderErrorIncrTPM, RawValue: []byte("true"),
+					}}},
+					err.Error()), complete
+			}
+
+			headers = append(headers,
+				&configPb.HeaderValueOption{
+					Header: &configPb.HeaderValue{
+						Key:      HeaderUpdateRPM,
+						RawValue: []byte(fmt.Sprintf("%d", rpm)),
+					},
+				},
+				&configPb.HeaderValueOption{
+					Header: &configPb.HeaderValue{
+						Key:      HeaderUpdateTPM,
+						RawValue: []byte(fmt.Sprintf("%d", tpm)),
+					},
+				},
+			)
+			requestEnd = fmt.Sprintf(requestEnd+"rpm: %d, tpm: %d, ", rpm, tpm)
+		}
+
+		if routerCtx != nil && routerCtx.HasRouted() {
+			targetPodIP := routerCtx.TargetAddress()
+			headers = append(headers,
+				&configPb.HeaderValueOption{
+					Header: &configPb.HeaderValue{
+						Key:      HeaderTargetPod,
+						RawValue: []byte(targetPodIP),
+					},
+				},
+				&configPb.HeaderValueOption{
+					Header: &configPb.HeaderValue{
+						Key:      HeaderRequestID,
+						RawValue: []byte(requestID),
+					},
+				},
+			)
+			requestEnd = fmt.Sprintf(requestEnd+"targetPod: %s", targetPodIP)
+		}
+
+		klog.Infof("request end, requestID: %s - %s", requestID, requestEnd)
+	} else if b.ResponseBody.EndOfStream {
+		complete = true
+	}
+
+	return &extProcPb.ProcessingResponse{
+		Response: &extProcPb.ProcessingResponse_ResponseBody{
+			ResponseBody: &extProcPb.BodyResponse{
+				Response: &extProcPb.CommonResponse{
+					HeaderMutation: &extProcPb.HeaderMutation{
+						SetHeaders: headers,
+					},
+				},
+			},
+		},
+	}, complete
+}
diff --git a/pkg/plugins/gateway/openai_utils.go b/pkg/plugins/gateway/openai_utils.go
new file mode 100644
index 000000000..735e7d64d
--- /dev/null
+++ b/pkg/plugins/gateway/openai_utils.go
@@ -0,0 +1,46 @@
+/*
+Copyright 2024 The Aibrix Team.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+*/
+
+package gateway
+
+type (
+	OpenAiRequestType string
+	OpenAiRequestPath string
+)
+
+var (
+	OpenAiRequestChatCompletionsType OpenAiRequestType = "chat-completions"
+	OpenAiRequestCompletionsType     OpenAiRequestType = "completions"
+	OpenAiRequestEmbeddingsType      OpenAiRequestType = "embeddings"
+	OpenAiRequestUnknownType         OpenAiRequestType = "unknown"
+	OpenAiRequestChatCompletionsPath OpenAiRequestPath = "/v1/chat/completions"
+	OpenAiRequestCompletionsPath     OpenAiRequestPath = "/v1/completions"
+	OpenAiRequestEmbeddingsPath      OpenAiRequestPath = "/v1/embeddings"
+	OpenAiRequestUnknownPath         OpenAiRequestPath = ""
+)
+
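+// NewOpenAiRequestTypeFromPath maps an OpenAI API request path to its request
+// type, returning OpenAiRequestUnknownType for unrecognized paths.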
+func NewOpenAiRequestTypeFromPath(path string) OpenAiRequestType {
+	requestType := OpenAiRequestUnknownType
+	switch path {
+	case string(OpenAiRequestCompletionsPath):
+		requestType = OpenAiRequestCompletionsType
+	case string(OpenAiRequestChatCompletionsPath):
+		requestType = OpenAiRequestChatCompletionsType
+	case string(OpenAiRequestEmbeddingsPath):
+		requestType = OpenAiRequestEmbeddingsType
+	}
+	return requestType
+}
diff --git a/pkg/plugins/gateway/util.go b/pkg/plugins/gateway/util.go
index 0e57ea540..28770537f 100644
--- a/pkg/plugins/gateway/util.go
+++ b/pkg/plugins/gateway/util.go
@@ -18,21 +18,29 @@ package gateway
 
 import (
 	"encoding/json"
+	"fmt"
 	"strings"
 
 	configPb "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
 	extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
 	envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
 	"github.com/openai/openai-go"
+	"github.com/openai/openai-go/packages/param"
 	"github.com/vllm-project/aibrix/pkg/utils"
 	"k8s.io/klog/v2"
 )
 
-// validateRequestBody validates input by unmarshaling request body into respective openai-golang struct based on requestpath.
+// OpenAI enforces a 2048-item size limit on embeddings array inputs,
+// see https://platform.openai.com/docs/api-reference/embeddings/create#embeddings-create-input
+var maxEmbeddingInputArraySize = 2048
+
+// validateRequestBody validates input by unmarshaling the request body into the respective openai-golang struct based on requestType.
 // nolint:nakedret
-func validateRequestBody(requestID, requestPath string, requestBody []byte, user utils.User) (model, message string, stream bool, errRes *extProcPb.ProcessingResponse) {
+func validateRequestBody(requestID string, requestType OpenAiRequestType, requestBody []byte, user utils.User) (model, message string, stream bool, errRes *extProcPb.ProcessingResponse) {
 	var streamOptions openai.ChatCompletionStreamOptionsParam
-	if requestPath == "/v1/chat/completions" {
+	switch requestType {
+
+	case OpenAiRequestChatCompletionsType:
 		var jsonMap map[string]json.RawMessage
 		if err := json.Unmarshal(requestBody, &jsonMap); err != nil {
 			klog.ErrorS(err, "error to unmarshal request body", "requestID", requestID, "requestBody", string(requestBody))
@@ -53,7 +61,8 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
 		if errRes = validateStreamOptions(requestID, user, &stream, streamOptions, jsonMap); errRes != nil {
 			return
 		}
-	} else if requestPath == "/v1/completions" {
+
+	case OpenAiRequestCompletionsType:
 		// openai.CompletionsNewParams does not support json unmarshal for CompletionNewParamsPromptUnion in release v0.1.0-beta.10
 		// once supported, input request will be directly unmarshal into openai.CompletionsNewParams
 		type Completion struct {
@@ -69,12 +78,31 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
 		}
 		model = completionObj.Model
 		message = completionObj.Prompt
-	} else {
+
+	case OpenAiRequestEmbeddingsType:
+		message = "" // prefix_cache algorithms are not relevant for embeddings
+		var jsonMap map[string]json.RawMessage
+		if err := json.Unmarshal(requestBody, &jsonMap); err != nil {
+			klog.ErrorS(err, "error to unmarshal request body", "requestID", requestID, "requestBody", string(requestBody))
+			errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true")
+			return
+		}
+		embeddingObj := openai.EmbeddingNewParams{}
+		if err := json.Unmarshal(requestBody, &embeddingObj); err != nil {
+			klog.ErrorS(err, "error to unmarshal embeddings object", "requestID", requestID, "requestBody", string(requestBody))
+			errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true")
+			return
+		}
+		model = embeddingObj.Model
+		if errRes = checkEmbeddingInputSequenceLen(requestID, embeddingObj); errRes != nil {
+			return
+		}
+	case OpenAiRequestUnknownType:
 		errRes = buildErrorResponse(envoyTypePb.StatusCode_NotImplemented, "unknown request path", HeaderErrorRequestBodyProcessing, "true")
 		return
 	}
 
-	klog.V(4).InfoS("validateRequestBody", "requestID", requestID, "requestPath", requestPath, "model", model, "message", message, "stream", stream, "streamOptions", streamOptions)
+	klog.V(4).InfoS("validateRequestBody", "requestID", requestID, "requestType", requestType, "model", model, "message", message, "stream", stream, "streamOptions", streamOptions)
 	return
 }
 
@@ -142,6 +170,57 @@ func getChatCompletionsMessage(requestID string, chatCompletionObj openai.ChatCo
 	return builder.String(), nil
 }
 
+// checkEmbeddingInputSequenceLen validates the size of the embeddings input
+func checkEmbeddingInputSequenceLen(requestID string, embeddingObj openai.EmbeddingNewParams) *extProcPb.ProcessingResponse {
+	inputParam := embeddingObj.Input
+	var size int
+	isArrayType := false
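+	// The input union resolves to one of four concrete shapes; only the
+	// array-of-strings and array-of-token-arrays shapes are checked against
+	// the array size limit below.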
+	switch input := embeddingNewParamsInputUnionAsAny(&inputParam).(type) {
+	case *string:
+		size = len(*input)
+	case *[]string:
+		size = len(*input)
+		isArrayType = true
+	case *[]int64:
+		size = len(*input)
+	case *[][]int64:
+		size = len(*input)
+		isArrayType = true
+	default:
+		// Should never happen, but if input is of an unexpected non-nil type, explicitly log an error.
+		// Size will be 0 in this case, which is then handled by the check below.
+		if input != nil {
+			klog.ErrorS(nil, "unhandled embedding input type", "requestID", requestID, "inputType", fmt.Sprintf("%T", input))
+		}
+	}
+
+	if size == 0 {
+		klog.ErrorS(nil, "no input in the request body", "requestID", requestID)
+		return buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "no input in the request body", HeaderErrorRequestBodyProcessing, "true")
+	}
+
+	if isArrayType && size > maxEmbeddingInputArraySize {
+		klog.ErrorS(nil, "embeddings content is too large", "requestID", requestID, "size", size)
+		return buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "embeddings content is too large", HeaderErrorRequestBodyProcessing, "true")
+	}
+
+	return nil
+}
+
+// TODO: make the asAny method publicly available in openai-go
+func embeddingNewParamsInputUnionAsAny(u *openai.EmbeddingNewParamsInputUnion) any {
+	if !param.IsOmitted(u.OfString) {
+		return &u.OfString.Value
+	} else if !param.IsOmitted(u.OfArrayOfStrings) {
+		return &u.OfArrayOfStrings
+	} else if !param.IsOmitted(u.OfArrayOfTokens) {
+		return &u.OfArrayOfTokens
+	} else if !param.IsOmitted(u.OfArrayOfTokenArrays) {
+		return &u.OfArrayOfTokenArrays
+	}
+	return nil
+}
+
 // generateErrorResponse construct envoy proxy error response
 // deprecated: use buildErrorResponse
 func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {
diff --git a/pkg/plugins/gateway/util_test.go b/pkg/plugins/gateway/util_test.go
index dacb89215..4570b5122 100644
--- a/pkg/plugins/gateway/util_test.go
+++ b/pkg/plugins/gateway/util_test.go
@@ -28,7 +28,7 @@ import (
 func Test_ValidateRequestBody(t *testing.T) {
 	testCases := []struct {
 		message     string
-		requestPath string
+		requestType OpenAiRequestType
 		requestBody []byte
 		model       string
 		messages    string
@@ -38,30 +38,30 @@ func Test_ValidateRequestBody(t *testing.T) {
 	}{
 		{
 			message:     "unknown path",
-			requestPath: "/v1/unknown",
+			requestType: OpenAiRequestUnknownType,
 			statusCode:  envoyTypePb.StatusCode_NotImplemented,
 		},
 		{
 			message:     "/v1/chat/completions json unmarhsal error",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte("bad_request"),
 			statusCode:  envoyTypePb.StatusCode_BadRequest,
 		},
 		{
 			message:     "/v1/chat/completions json unmarhsal ChatCompletionsNewParams",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": 1}`),
 			statusCode:  envoyTypePb.StatusCode_BadRequest,
 		},
 		{
 			message:     "/v1/chat/completions json unmarhsal no messages",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b"}`),
 			statusCode:  envoyTypePb.StatusCode_BadRequest,
 		},
 		{
 			message:     "/v1/chat/completions json unmarhsal valid messages",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": "say this is test"}]}`),
 			model:       "llama2-7b",
 			messages:    "this is system say this is test",
@@ -69,13 +69,13 @@ func Test_ValidateRequestBody(t *testing.T) {
 			statusCode:  envoyTypePb.StatusCode_OK,
 		},
 		{
 			message:     "/v1/chat/completions json unmarhsal invalid messages with complex content",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": {"type": "text", "text": "say this is test", "complex": make(chan int)}}]}`),
 			statusCode:  envoyTypePb.StatusCode_BadRequest,
 		},
 		{
 			message:     "/v1/chat/completions json unmarhsal valid messages with complex content",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": [{"type": "text", "text": "say this is test"}, {"type": "text", "text": "say this is test"}]}]}`),
 			model:       "llama2-7b",
 			messages:    "this is system [{\"text\":\"say this is test\",\"type\":\"text\"},{\"text\":\"say this is test\",\"type\":\"text\"}]",
@@ -83,7 +83,7 @@ func Test_ValidateRequestBody(t *testing.T) {
 		},
 		{
 			message:     "/v1/chat/completions json unmarhsal valid messages with stop string param",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": "say this is test"}], "stop": "stop"}`),
 			model:       "llama2-7b",
 			messages:    "this is system say this is test",
@@ -91,7 +91,7 @@ func Test_ValidateRequestBody(t *testing.T) {
 		},
 		{
 			message:     "/v1/chat/completions json unmarhsal valid messages with stop array param",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": "say this is test"}], "stop": ["stop"]}`),
 			model:       "llama2-7b",
 			messages:    "this is system say this is test",
@@ -99,13 +99,13 @@ func Test_ValidateRequestBody(t *testing.T) {
 		},
 		{
 			message:     "/v1/chat/completions json unmarshal invalid stream bool",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "stream": "true", "messages": [{"role": "system", "content": "this is system"}]}`),
 			statusCode:  envoyTypePb.StatusCode_BadRequest,
 		},
 		{
 			message:     "/v1/chat/completions json unmarshal stream options is null",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			user:        utils.User{Tpm: 1},
 			requestBody: []byte(`{"model": "llama2-7b", "stream": true, "messages": [{"role": "system", "content": "this is system"}]}`),
 			statusCode:  envoyTypePb.StatusCode_BadRequest,
@@ -113,20 +113,20 @@ func Test_ValidateRequestBody(t *testing.T) {
 		{
 			message:     "/v1/chat/completions stream_options.include_usage == false with user.TPM >= 1 is NOT OK",
 			user:        utils.User{Tpm: 1},
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "stream": true, "stream_options": {"include_usage": false}, "messages": [{"role": "system", "content": "this is system"}]}`),
 			statusCode:  envoyTypePb.StatusCode_BadRequest,
 		},
 		{
 			message:     "/v1/chat/completions stream_options.include_usage == false with user.TPM == 0 is OK",
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "stream": true, "stream_options": {"include_usage": false}, "messages": [{"role": "system", "content": "this is system"}]}`),
 			statusCode:  envoyTypePb.StatusCode_OK,
 		},
 		{
 			message:     "/v1/chat/completions valid request body",
 			user:        utils.User{Tpm: 1},
-			requestPath: "/v1/chat/completions",
+			requestType: OpenAiRequestChatCompletionsType,
 			requestBody: []byte(`{"model": "llama2-7b", "stream": true, "stream_options": {"include_usage": true}, "messages": [{"role": "system", "content": "this is system"},{"role": "user", "content": "say this is test"}]}`),
 			stream:      true,
 			model:       "llama2-7b",
@@ -136,7 +136,7 @@ func Test_ValidateRequestBody(t *testing.T) {
 	}
 
 	for _, tt := range testCases {
-		model, messages, stream, errRes := validateRequestBody("1", tt.requestPath, tt.requestBody, tt.user)
+		model, messages, stream, errRes := validateRequestBody("1", tt.requestType, tt.requestBody, tt.user)
 		if tt.statusCode == 200 {
 			assert.Equal(t, (*extProcPb.ProcessingResponse)(nil), errRes, tt.message)