Skip to content

Commit e767da5

Browse files
authored
refactor: use tagged union for union types in openai schema if possible (envoyproxy#1182)
**Description** changes: 1 refector `ChatCompletionMessageParamUnion` to use same tagged union implementation as others. This should be able to make the definition more explicit. 2 add `ChatCompletionContentPartFileParam` inside `ChatCompletionContentPartUserUnionParam`, also changed the names of other fields to make the naming consistent **Related Issues/PRs (if applicable)** follow up on envoyproxy#1178 Signed-off-by: yxia216 <[email protected]>
1 parent dba58df commit e767da5

File tree

16 files changed

+508
-502
lines changed

16 files changed

+508
-502
lines changed

internal/apischema/openai/openai.go

Lines changed: 90 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,15 @@ type ChatCompletionContentPartTextType string
6666
// ChatCompletionContentPartImageType The type of the content part.
6767
type ChatCompletionContentPartImageType string
6868

69+
// ChatCompletionContentPartImageType The type of the content part.
70+
type ChatCompletionContentPartFileType string
71+
6972
const (
7073
ChatCompletionContentPartTextTypeText ChatCompletionContentPartTextType = "text"
7174
ChatCompletionContentPartRefusalTypeRefusal ChatCompletionContentPartRefusalType = "refusal"
7275
ChatCompletionContentPartInputAudioTypeInputAudio ChatCompletionContentPartInputAudioType = "input_audio"
7376
ChatCompletionContentPartImageTypeImageURL ChatCompletionContentPartImageType = "image_url"
77+
ChatCompletionContentPartFileTypeFile ChatCompletionContentPartFileType = "file"
7478
)
7579

7680
// ChatCompletionContentPartTextParam Learn about
@@ -134,12 +138,32 @@ type ChatCompletionContentPartImageParam struct {
134138
Type ChatCompletionContentPartImageType `json:"type"`
135139
}
136140

141+
type ChatCompletionContentPartFileFileParam struct {
142+
// The base64 encoded file data, used when passing the file to the model as a
143+
// string.
144+
FileData string `json:"file_data,omitzero"`
145+
// The ID of an uploaded file to use as input.
146+
FileID string `json:"file_id,omitzero"`
147+
// The name of the file, used when passing the file to the model as a string.
148+
Filename string `json:"filename,omitzero"`
149+
}
150+
151+
// ChatCompletionContentPartFileParam .
152+
type ChatCompletionContentPartFileParam struct {
153+
File ChatCompletionContentPartFileFileParam `json:"file,omitzero"`
154+
// The type of the content part. Always `file`.
155+
//
156+
// This field can be elided, and will marshal its zero value as "file".
157+
Type ChatCompletionContentPartFileType `json:"type"`
158+
}
159+
137160
// ChatCompletionContentPartUserUnionParam Learn about
138161
// [text inputs](https://platform.openai.com/docs/guides/text-generation).
139162
type ChatCompletionContentPartUserUnionParam struct {
140-
TextContent *ChatCompletionContentPartTextParam
141-
InputAudioContent *ChatCompletionContentPartInputAudioParam
142-
ImageContent *ChatCompletionContentPartImageParam
163+
OfText *ChatCompletionContentPartTextParam `json:",omitzero,inline"`
164+
OfInputAudio *ChatCompletionContentPartInputAudioParam `json:",omitzero,inline"`
165+
OfImageURL *ChatCompletionContentPartImageParam `json:",omitzero,inline"`
166+
OfFile *ChatCompletionContentPartFileParam `json:",omitzero,inline"`
143167
}
144168

145169
func (c *ChatCompletionContentPartUserUnionParam) UnmarshalJSON(data []byte) error {
@@ -157,34 +181,41 @@ func (c *ChatCompletionContentPartUserUnionParam) UnmarshalJSON(data []byte) err
157181
if err := json.Unmarshal(data, &textContent); err != nil {
158182
return err
159183
}
160-
c.TextContent = &textContent
184+
c.OfText = &textContent
161185
case string(ChatCompletionContentPartInputAudioTypeInputAudio):
162186
var audioContent ChatCompletionContentPartInputAudioParam
163187
if err := json.Unmarshal(data, &audioContent); err != nil {
164188
return err
165189
}
166-
c.InputAudioContent = &audioContent
190+
c.OfInputAudio = &audioContent
167191
case string(ChatCompletionContentPartImageTypeImageURL):
168192
var imageContent ChatCompletionContentPartImageParam
169193
if err := json.Unmarshal(data, &imageContent); err != nil {
170194
return err
171195
}
172-
c.ImageContent = &imageContent
196+
c.OfImageURL = &imageContent
197+
case string(ChatCompletionContentPartFileTypeFile):
198+
var fileContent ChatCompletionContentPartFileParam
199+
if err := json.Unmarshal(data, &fileContent); err != nil {
200+
return err
201+
}
202+
c.OfFile = &fileContent
203+
173204
default:
174205
return fmt.Errorf("unknown ChatCompletionContentPartUnionParam type: %v", contentType)
175206
}
176207
return nil
177208
}
178209

179210
func (c ChatCompletionContentPartUserUnionParam) MarshalJSON() ([]byte, error) {
180-
if c.TextContent != nil {
181-
return json.Marshal(c.TextContent)
211+
if c.OfText != nil {
212+
return json.Marshal(c.OfText)
182213
}
183-
if c.InputAudioContent != nil {
184-
return json.Marshal(c.InputAudioContent)
214+
if c.OfInputAudio != nil {
215+
return json.Marshal(c.OfInputAudio)
185216
}
186-
if c.ImageContent != nil {
187-
return json.Marshal(c.ImageContent)
217+
if c.OfImageURL != nil {
218+
return json.Marshal(c.OfImageURL)
188219
}
189220
return nil, errors.New("no content to marshal")
190221
}
@@ -284,9 +315,13 @@ func (s StringOrUserRoleContentUnion) MarshalJSON() ([]byte, error) {
284315
return json.Marshal(s.Value)
285316
}
286317

318+
// Function message is deprecated and we do not allow it.
287319
type ChatCompletionMessageParamUnion struct {
288-
Value any
289-
Type string
320+
OfDeveloper *ChatCompletionDeveloperMessageParam `json:",omitzero,inline"`
321+
OfSystem *ChatCompletionSystemMessageParam `json:",omitzero,inline"`
322+
OfUser *ChatCompletionUserMessageParam `json:",omitzero,inline"`
323+
OfAssistant *ChatCompletionAssistantMessageParam `json:",omitzero,inline"`
324+
OfTool *ChatCompletionToolMessageParam `json:",omitzero,inline"`
290325
}
291326

292327
func (c *ChatCompletionMessageParamUnion) UnmarshalJSON(data []byte) error {
@@ -304,44 +339,55 @@ func (c *ChatCompletionMessageParamUnion) UnmarshalJSON(data []byte) error {
304339
if err := json.Unmarshal(data, &userMessage); err != nil {
305340
return err
306341
}
307-
c.Value = userMessage
308-
c.Type = ChatMessageRoleUser
342+
c.OfUser = &userMessage
309343
case ChatMessageRoleAssistant:
310344
var assistantMessage ChatCompletionAssistantMessageParam
311345
if err := json.Unmarshal(data, &assistantMessage); err != nil {
312346
return err
313347
}
314-
c.Value = assistantMessage
315-
c.Type = ChatMessageRoleAssistant
348+
c.OfAssistant = &assistantMessage
316349
case ChatMessageRoleSystem:
317350
var systemMessage ChatCompletionSystemMessageParam
318351
if err := json.Unmarshal(data, &systemMessage); err != nil {
319352
return err
320353
}
321-
c.Value = systemMessage
322-
c.Type = ChatMessageRoleSystem
354+
c.OfSystem = &systemMessage
323355
case ChatMessageRoleDeveloper:
324356
var developerMessage ChatCompletionDeveloperMessageParam
325357
if err := json.Unmarshal(data, &developerMessage); err != nil {
326358
return err
327359
}
328-
c.Value = developerMessage
329-
c.Type = ChatMessageRoleDeveloper
360+
c.OfDeveloper = &developerMessage
330361
case ChatMessageRoleTool:
331362
var toolMessage ChatCompletionToolMessageParam
332363
if err := json.Unmarshal(data, &toolMessage); err != nil {
333364
return err
334365
}
335-
c.Value = toolMessage
336-
c.Type = ChatMessageRoleTool
366+
c.OfTool = &toolMessage
337367
default:
338368
return fmt.Errorf("unknown ChatCompletionMessageParam type: %v", role)
339369
}
340370
return nil
341371
}
342372

343373
func (c ChatCompletionMessageParamUnion) MarshalJSON() ([]byte, error) {
344-
return json.Marshal(c.Value)
374+
if c.OfUser != nil {
375+
return json.Marshal(c.OfUser)
376+
}
377+
if c.OfAssistant != nil {
378+
return json.Marshal(c.OfAssistant)
379+
}
380+
if c.OfSystem != nil {
381+
return json.Marshal(c.OfSystem)
382+
}
383+
if c.OfDeveloper != nil {
384+
return json.Marshal(c.OfDeveloper)
385+
}
386+
if c.OfTool != nil {
387+
return json.Marshal(c.OfTool)
388+
}
389+
390+
return nil, errors.New("no message to marshal")
345391
}
346392

347393
// ChatCompletionUserMessageParam Messages sent by an end user, containing prompts or additional context
@@ -465,6 +511,25 @@ type ChatCompletionMessageToolCallParam struct {
465511
Type ChatCompletionMessageToolCallType `json:"type,omitempty"`
466512
}
467513

514+
// extractMessageRole extracts role from OpenAI message union types.
515+
func (c ChatCompletionMessageParamUnion) ExtractMessgaeRole() string {
516+
switch {
517+
case c.OfDeveloper != nil:
518+
return c.OfDeveloper.Role
519+
case c.OfSystem != nil:
520+
return c.OfSystem.Role
521+
case c.OfAssistant != nil:
522+
return c.OfAssistant.Role
523+
case c.OfTool != nil:
524+
return c.OfTool.Role
525+
case c.OfUser != nil:
526+
return c.OfUser.Role
527+
// Add other cases here for any other message types in the union.
528+
default:
529+
return "[unknown message type]"
530+
}
531+
}
532+
468533
type ChatCompletionResponseFormatType string
469534

470535
// Constants for the different response formats.

0 commit comments

Comments
 (0)