@@ -24,6 +24,7 @@ import (
24
24
extProcPb "github.com/envoyproxy/go-control-plane/envoy/service/ext_proc/v3"
25
25
envoyTypePb "github.com/envoyproxy/go-control-plane/envoy/type/v3"
26
26
"github.com/openai/openai-go"
27
+ "github.com/openai/openai-go/packages/param"
27
28
"github.com/vllm-project/aibrix/pkg/utils"
28
29
"k8s.io/klog/v2"
29
30
)
@@ -69,6 +70,24 @@ func validateRequestBody(requestID, requestPath string, requestBody []byte, user
69
70
}
70
71
model = completionObj .Model
71
72
message = completionObj .Prompt
73
+ } else if requestPath == "/v1/embeddings" {
74
+ message = "" // prefix_cache algorithms are not relevant for embeddings
75
+ var jsonMap map [string ]json.RawMessage
76
+ if err := json .Unmarshal (requestBody , & jsonMap ); err != nil {
77
+ klog .ErrorS (err , "error to unmarshal request body" , "requestID" , requestID , "requestBody" , string (requestBody ))
78
+ errRes = buildErrorResponse (envoyTypePb .StatusCode_BadRequest , "error processing request body" , HeaderErrorRequestBodyProcessing , "true" )
79
+ return
80
+ }
81
+ embeddingObj := openai.EmbeddingNewParams {}
82
+ if err := json .Unmarshal (requestBody , & embeddingObj ); err != nil {
83
+ klog .ErrorS (err , "error to unmarshal embeddings object" , "requestID" , requestID , "requestBody" , string (requestBody ))
84
+ errRes = buildErrorResponse (envoyTypePb .StatusCode_BadRequest , "error processing request body" , HeaderErrorRequestBodyProcessing , "true" )
85
+ return
86
+ }
87
+ model = embeddingObj .Model
88
+ if errRes = checkEmbeddingInputSequenceLen (requestID , embeddingObj ); errRes != nil {
89
+ return
90
+ }
72
91
} else {
73
92
errRes = buildErrorResponse (envoyTypePb .StatusCode_NotImplemented , "unknown request path" , HeaderErrorRequestBodyProcessing , "true" )
74
93
return
@@ -142,6 +161,48 @@ func getChatCompletionsMessage(requestID string, chatCompletionObj openai.ChatCo
142
161
return builder .String (), nil
143
162
}
144
163
164
+ // getEmbeddingsInputLen returns the len of the embeddings object
165
+ func checkEmbeddingInputSequenceLen (requestID string , embeddingObj openai.EmbeddingNewParams ) * extProcPb.ProcessingResponse {
166
+ inputParam := embeddingObj .Input
167
+ var size int
168
+ switch input := embeddingNewParamsInputUnionAsAny (& inputParam ).(type ) {
169
+ case * string :
170
+ size = len (* input )
171
+ case * []string :
172
+ size = len (* input )
173
+ case * []int64 :
174
+ size = len (* input )
175
+ case * [][]int64 :
176
+ size = len (* input )
177
+ default :
178
+ }
179
+
180
+ if size == 0 {
181
+ klog .ErrorS (nil , "no input in the request body" , "requestID" , requestID )
182
+ return buildErrorResponse (envoyTypePb .StatusCode_BadRequest , "no messages in the request body" , HeaderErrorRequestBodyProcessing , "true" )
183
+ }
184
+ if size > 1024 {
185
+ klog .ErrorS (nil , "embeddings content is too large" , "requestID" , requestID , "size" , size )
186
+ return buildErrorResponse (envoyTypePb .StatusCode_BadRequest , "embeddings content is too large" , HeaderErrorRequestBodyProcessing , "true" )
187
+ }
188
+
189
+ return nil
190
+ }
191
+
192
+ // TODO: make asAny method publicly available on OpenAI go
193
+ func embeddingNewParamsInputUnionAsAny (u * openai.EmbeddingNewParamsInputUnion ) any {
194
+ if ! param .IsOmitted (u .OfString ) {
195
+ return & u .OfString .Value
196
+ } else if ! param .IsOmitted (u .OfArrayOfStrings ) {
197
+ return & u .OfArrayOfStrings
198
+ } else if ! param .IsOmitted (u .OfArrayOfTokens ) {
199
+ return & u .OfArrayOfTokens
200
+ } else if ! param .IsOmitted (u .OfArrayOfTokenArrays ) {
201
+ return & u .OfArrayOfTokenArrays
202
+ }
203
+ return nil
204
+ }
205
+
145
206
// generateErrorResponse construct envoy proxy error response
146
207
// deprecated: use buildErrorResponse
147
208
func generateErrorResponse (statusCode envoyTypePb.StatusCode , headers []* configPb.HeaderValueOption , body string ) * extProcPb.ProcessingResponse {
0 commit comments