Skip to content

Commit cbbbf37

Browse files
committed
feat: validate request body for embeddings requests
Signed-off-by: Guillaume Calmettes <[email protected]>
1 parent 2341260 commit cbbbf37

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

pkg/plugins/gateway/util.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
2525
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
2626
"github.com/openai/openai-go"
27+
"github.com/openai/openai-go/packages/param"
2728
"github.com/vllm-project/aibrix/pkg/utils"
2829
"k8s.io/klog/v2"
2930
)
@@ -69,6 +70,24 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
6970
}
7071
model = completionObj.Model
7172
message = completionObj.Prompt
73+
} else if requestPath == "/v1/embeddings" {
74+
message = "" // prefix_cache algorithms are not relevant for embeddings
75+
var jsonMap map[string]json.RawMessage
76+
if err := json.Unmarshal(requestBody, &jsonMap); err != nil {
77+
klog.ErrorS(err, "error to unmarshal request body", "requestID", requestID, "requestBody", string(requestBody))
78+
errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true")
79+
return
80+
}
81+
embeddingObj := openai.EmbeddingNewParams{}
82+
if err := json.Unmarshal(requestBody, &embeddingObj); err != nil {
83+
klog.ErrorS(err, "error to unmarshal embeddings object", "requestID", requestID, "requestBody", string(requestBody))
84+
errRes = buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "error processing request body", HeaderErrorRequestBodyProcessing, "true")
85+
return
86+
}
87+
model = embeddingObj.Model
88+
if errRes = checkEmbeddingInputSequenceLen(requestID, embeddingObj); errRes != nil {
89+
return
90+
}
7291
} else {
7392
errRes = buildErrorResponse(envoyTypePb.StatusCode_NotImplemented, "unknown request path", HeaderErrorRequestBodyProcessing, "true")
7493
return
@@ -142,6 +161,48 @@ func getChatCompletionsMessage(requestID string, chatCompletionObj openai.ChatCo
142161
return builder.String(), nil
143162
}
144163

164+
// getEmbeddingsInputLen returns the len of the embeddings object
165+
func checkEmbeddingInputSequenceLen(requestID string, embeddingObj openai.EmbeddingNewParams) *extProcPb.ProcessingResponse {
166+
inputParam := embeddingObj.Input
167+
var size int
168+
switch input := embeddingNewParamsInputUnionAsAny(&inputParam).(type) {
169+
case *string:
170+
size = len(*input)
171+
case *[]string:
172+
size = len(*input)
173+
case *[]int64:
174+
size = len(*input)
175+
case *[][]int64:
176+
size = len(*input)
177+
default:
178+
}
179+
180+
if size == 0 {
181+
klog.ErrorS(nil, "no input in the request body", "requestID", requestID)
182+
return buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "no messages in the request body", HeaderErrorRequestBodyProcessing, "true")
183+
}
184+
if size > 1024 {
185+
klog.ErrorS(nil, "embeddings content is too large", "requestID", requestID, "size", size)
186+
return buildErrorResponse(envoyTypePb.StatusCode_BadRequest, "embeddings content is too large", HeaderErrorRequestBodyProcessing, "true")
187+
}
188+
189+
return nil
190+
}
191+
192+
// TODO: make asAny method publicly available on OpenAI go
193+
func embeddingNewParamsInputUnionAsAny(u *openai.EmbeddingNewParamsInputUnion) any {
194+
if !param.IsOmitted(u.OfString) {
195+
return &u.OfString.Value
196+
} else if !param.IsOmitted(u.OfArrayOfStrings) {
197+
return &u.OfArrayOfStrings
198+
} else if !param.IsOmitted(u.OfArrayOfTokens) {
199+
return &u.OfArrayOfTokens
200+
} else if !param.IsOmitted(u.OfArrayOfTokenArrays) {
201+
return &u.OfArrayOfTokenArrays
202+
}
203+
return nil
204+
}
205+
145206
// generateErrorResponse construct envoy proxy error response
146207
// deprecated: use buildErrorResponse
147208
func generateErrorResponse(statusCode envoyTypePb.StatusCode, headers []*configPb.HeaderValueOption, body string) *extProcPb.ProcessingResponse {

0 commit comments

Comments
 (0)