Skip to content

Commit 03caea8

Browse files
authored
Add support for multi part chat messages (and gpt-4-vision-preview model) (#580)
* Add support for multi part chat messages. OpenAI has recently introduced a new model called gpt-4-vision-preview, which now supports images as input. The chat completion endpoint accepts multi-part chat messages, where the content can be an array of structs in addition to the usual string format. This commit introduces new structures and constants to represent different types of content parts. It also implements the json.Marshaler and json.Unmarshaler interfaces on ChatCompletionMessage. * Add ImageURLDetail and ChatMessagePartType types * Optimize ChatCompletionMessage deserialization * Add ErrContentFieldsMisused error
1 parent 7260991 commit 03caea8

File tree

2 files changed

+192
-2
lines changed

2 files changed

+192
-2
lines changed

chat.go

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package openai
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
67
"net/http"
78
)
@@ -20,6 +21,7 @@ const chatCompletionsSuffix = "/chat/completions"
2021
// Sentinel errors for the chat completion API; compare with errors.Is.
var (
2122
// ErrChatCompletionInvalidModel signals, per its message, that the model is
// not supported by the chat method and CreateCompletion should be used.
ErrChatCompletionInvalidModel = errors.New("this model is not supported with this method, please use CreateCompletion client method instead") //nolint:lll
2223
// ErrChatCompletionStreamNotSupported signals, per its message, that
// streaming requires CreateChatCompletionStream instead.
ErrChatCompletionStreamNotSupported = errors.New("streaming is not supported with this method, please use CreateChatCompletionStream") //nolint:lll
24+
// ErrContentFieldsMisused is returned by ChatCompletionMessage.MarshalJSON
// when Content is non-empty and MultiContent is non-nil at the same time.
ErrContentFieldsMisused = errors.New("can't use both Content and MultiContent properties simultaneously")
2325
)
2426

2527
type Hate struct {
@@ -51,9 +53,36 @@ type PromptAnnotation struct {
5153
ContentFilterResults ContentFilterResults `json:"content_filter_results,omitempty"`
5254
}
5355

56+
// ImageURLDetail is the string type of ChatMessageImageURL.Detail, which is
// serialized under the "detail" JSON key.
type ImageURLDetail string
57+
58+
// Known ImageURLDetail values.
const (
59+
ImageURLDetailHigh ImageURLDetail = "high"
60+
ImageURLDetailLow ImageURLDetail = "low"
61+
ImageURLDetailAuto ImageURLDetail = "auto"
62+
)
63+
64+
// ChatMessageImageURL is the value stored in ChatMessagePart.ImageURL and
// serialized under that part's "image_url" key.
type ChatMessageImageURL struct {
65+
// URL is serialized as "url"; it is omitted from the JSON when empty.
URL string `json:"url,omitempty"`
66+
// Detail is serialized as "detail"; it is omitted from the JSON when empty.
Detail ImageURLDetail `json:"detail,omitempty"`
67+
}
68+
69+
// ChatMessagePartType is the string type of ChatMessagePart.Type, which is
// serialized under the "type" JSON key.
type ChatMessagePartType string
70+
71+
// Known ChatMessagePartType values.
const (
72+
ChatMessagePartTypeText ChatMessagePartType = "text"
73+
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
74+
)
75+
76+
// ChatMessagePart is one element of ChatCompletionMessage.MultiContent; the
// slice of parts is what gets serialized as the message's "content" array.
type ChatMessagePart struct {
77+
// Type is serialized as "type" (see the ChatMessagePartType constants).
Type ChatMessagePartType `json:"type,omitempty"`
78+
// Text is serialized as "text"; presumably only meaningful for text parts.
Text string `json:"text,omitempty"`
79+
// ImageURL is serialized as "image_url" and omitted when nil; presumably
// only meaningful for image_url parts.
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
80+
}
81+
5482
type ChatCompletionMessage struct {
55-
Role string `json:"role"`
56-
Content string `json:"content"`
83+
Role string `json:"role"`
84+
Content string `json:"content"`
85+
MultiContent []ChatMessagePart
5786

5887
// This property isn't in the official documentation, but it's in
5988
// the documentation for the official library for python:
@@ -70,6 +99,64 @@ type ChatCompletionMessage struct {
7099
ToolCallID string `json:"tool_call_id,omitempty"`
71100
}
72101

102+
func (m ChatCompletionMessage) MarshalJSON() ([]byte, error) {
103+
if m.Content != "" && m.MultiContent != nil {
104+
return nil, ErrContentFieldsMisused
105+
}
106+
if len(m.MultiContent) > 0 {
107+
msg := struct {
108+
Role string `json:"role"`
109+
Content string `json:"-"`
110+
MultiContent []ChatMessagePart `json:"content,omitempty"`
111+
Name string `json:"name,omitempty"`
112+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
113+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
114+
ToolCallID string `json:"tool_call_id,omitempty"`
115+
}(m)
116+
return json.Marshal(msg)
117+
}
118+
msg := struct {
119+
Role string `json:"role"`
120+
Content string `json:"content"`
121+
MultiContent []ChatMessagePart `json:"-"`
122+
Name string `json:"name,omitempty"`
123+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
124+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
125+
ToolCallID string `json:"tool_call_id,omitempty"`
126+
}(m)
127+
return json.Marshal(msg)
128+
}
129+
130+
func (m *ChatCompletionMessage) UnmarshalJSON(bs []byte) error {
131+
msg := struct {
132+
Role string `json:"role"`
133+
Content string `json:"content"`
134+
MultiContent []ChatMessagePart
135+
Name string `json:"name,omitempty"`
136+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
137+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
138+
ToolCallID string `json:"tool_call_id,omitempty"`
139+
}{}
140+
if err := json.Unmarshal(bs, &msg); err == nil {
141+
*m = ChatCompletionMessage(msg)
142+
return nil
143+
}
144+
multiMsg := struct {
145+
Role string `json:"role"`
146+
Content string
147+
MultiContent []ChatMessagePart `json:"content"`
148+
Name string `json:"name,omitempty"`
149+
FunctionCall *FunctionCall `json:"function_call,omitempty"`
150+
ToolCalls []ToolCall `json:"tool_calls,omitempty"`
151+
ToolCallID string `json:"tool_call_id,omitempty"`
152+
}{}
153+
if err := json.Unmarshal(bs, &multiMsg); err != nil {
154+
return err
155+
}
156+
*m = ChatCompletionMessage(multiMsg)
157+
return nil
158+
}
159+
73160
type ToolCall struct {
74161
// Index is not nil only in chat completion chunk object
75162
Index *int `json:"index,omitempty"`

chat_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package openai_test
33
import (
44
"context"
55
"encoding/json"
6+
"errors"
67
"fmt"
78
"io"
89
"net/http"
@@ -296,6 +297,108 @@ func TestAzureChatCompletions(t *testing.T) {
296297
checks.NoError(t, err, "CreateAzureChatCompletion error")
297298
}
298299

300+
func TestMultipartChatCompletions(t *testing.T) {
301+
client, server, teardown := setupAzureTestServer()
302+
defer teardown()
303+
server.RegisterHandler("/openai/deployments/*", handleChatCompletionEndpoint)
304+
305+
_, err := client.CreateChatCompletion(context.Background(), openai.ChatCompletionRequest{
306+
MaxTokens: 5,
307+
Model: openai.GPT3Dot5Turbo,
308+
Messages: []openai.ChatCompletionMessage{
309+
{
310+
Role: openai.ChatMessageRoleUser,
311+
MultiContent: []openai.ChatMessagePart{
312+
{
313+
Type: openai.ChatMessagePartTypeText,
314+
Text: "Hello!",
315+
},
316+
{
317+
Type: openai.ChatMessagePartTypeImageURL,
318+
ImageURL: &openai.ChatMessageImageURL{
319+
URL: "URL",
320+
Detail: openai.ImageURLDetailLow,
321+
},
322+
},
323+
},
324+
},
325+
},
326+
})
327+
checks.NoError(t, err, "CreateAzureChatCompletion error")
328+
}
329+
330+
func TestMultipartChatMessageSerialization(t *testing.T) {
331+
jsonText := `[{"role":"system","content":"system-message"},` +
332+
`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
333+
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`
334+
335+
var msgs []openai.ChatCompletionMessage
336+
err := json.Unmarshal([]byte(jsonText), &msgs)
337+
if err != nil {
338+
t.Fatalf("Expected no error: %s", err)
339+
}
340+
if len(msgs) != 2 {
341+
t.Errorf("unexpected number of messages")
342+
}
343+
if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
344+
t.Errorf("invalid user message: %v", msgs[0])
345+
}
346+
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
347+
t.Errorf("invalid user message")
348+
}
349+
parts := msgs[1].MultiContent
350+
if parts[0].Type != "text" || parts[0].Text != "nice-text" {
351+
t.Errorf("invalid text part: %v", parts[0])
352+
}
353+
if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
354+
t.Errorf("invalid image_url part")
355+
}
356+
357+
s, err := json.Marshal(msgs)
358+
if err != nil {
359+
t.Fatalf("Expected no error: %s", err)
360+
}
361+
res := strings.ReplaceAll(string(s), " ", "")
362+
if res != jsonText {
363+
t.Fatalf("invalid message: %s", string(s))
364+
}
365+
366+
invalidMsg := []openai.ChatCompletionMessage{
367+
{
368+
Role: "user",
369+
Content: "some-text",
370+
MultiContent: []openai.ChatMessagePart{
371+
{
372+
Type: "text",
373+
Text: "nice-text",
374+
},
375+
},
376+
},
377+
}
378+
_, err = json.Marshal(invalidMsg)
379+
if !errors.Is(err, openai.ErrContentFieldsMisused) {
380+
t.Fatalf("Expected error: %s", err)
381+
}
382+
383+
err = json.Unmarshal([]byte(`["not-a-message"]`), &msgs)
384+
if err == nil {
385+
t.Fatalf("Expected error")
386+
}
387+
388+
emptyMultiContentMsg := openai.ChatCompletionMessage{
389+
Role: "user",
390+
MultiContent: []openai.ChatMessagePart{},
391+
}
392+
s, err = json.Marshal(emptyMultiContentMsg)
393+
if err != nil {
394+
t.Fatalf("Unexpected error")
395+
}
396+
res = strings.ReplaceAll(string(s), " ", "")
397+
if res != `{"role":"user","content":""}` {
398+
t.Fatalf("invalid message: %s", string(s))
399+
}
400+
}
401+
299402
// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
300403
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
301404
var err error

0 commit comments

Comments
 (0)