diff --git a/chat.go b/chat.go index 0bb2e98ee..f364c4b7e 100644 --- a/chat.go +++ b/chat.go @@ -81,17 +81,26 @@ type ChatMessageImageURL struct { Detail ImageURLDetail `json:"detail,omitempty"` } +// ChatMessageFile is a placeholder for file parts in chat messages. +type ChatMessageFile struct { + FileID string `json:"file_id,omitempty"` + FileName string `json:"filename,omitempty"` + FileData string `json:"file_data,omitempty"` // Base64 encoded file data +} + type ChatMessagePartType string const ( ChatMessagePartTypeText ChatMessagePartType = "text" ChatMessagePartTypeImageURL ChatMessagePartType = "image_url" + ChatMessagePartTypeFile ChatMessagePartType = "file" ) type ChatMessagePart struct { Type ChatMessagePartType `json:"type,omitempty"` Text string `json:"text,omitempty"` ImageURL *ChatMessageImageURL `json:"image_url,omitempty"` + File *ChatMessageFile `json:"file,omitempty"` } type ChatCompletionMessage struct { diff --git a/chat_test.go b/chat_test.go index 172ce0740..ad9e77315 100644 --- a/chat_test.go +++ b/chat_test.go @@ -677,6 +677,14 @@ func TestMultipartChatCompletions(t *testing.T) { Detail: openai.ImageURLDetailLow, }, }, + { + Type: openai.ChatMessagePartTypeFile, + File: &openai.ChatMessageFile{ + FileID: "file-123", + FileName: "test.txt", + FileData: "dGVzdCBmaWxlIGNvbnRlbnQ=", // base64 encoded "test file content" + }, + }, }, }, }, @@ -687,7 +695,8 @@ func TestMultipartChatCompletions(t *testing.T) { func TestMultipartChatMessageSerialization(t *testing.T) { jsonText := `[{"role":"system","content":"system-message"},` + `{"role":"user","content":[{"type":"text","text":"nice-text"},` + - `{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]` + `{"type":"image_url","image_url":{"url":"URL","detail":"high"}},` + + `{"type":"file","file":{"file_id":"file-123","filename":"test.txt","file_data":"dGVzdA=="}}]}]` var msgs []openai.ChatCompletionMessage err := json.Unmarshal([]byte(jsonText), &msgs) @@ -700,7 +709,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) { if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil { t.Errorf("invalid user message: %v", msgs[0]) } - if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 { + if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 3 { t.Errorf("invalid user message") } parts := msgs[1].MultiContent @@ -710,6 +719,10 @@ func TestMultipartChatMessageSerialization(t *testing.T) { if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" { t.Errorf("invalid image_url part") } + if parts[2].Type != "file" || parts[2].File.FileID != "file-123" || + parts[2].File.FileName != "test.txt" || parts[2].File.FileData != "dGVzdA==" { + t.Errorf("invalid file part: %v", parts[2]) + } s, err := json.Marshal(msgs) if err != nil { @@ -756,6 +769,103 @@ func TestMultipartChatMessageSerialization(t *testing.T) { } } +func TestChatMessageFile(t *testing.T) { + // Test file part with FileID + filePart := openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeFile, + File: &openai.ChatMessageFile{ + FileID: "file-abc123", + }, + } + + // Test serialization + data, err := json.Marshal(filePart) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + + expected := `{"type":"file","file":{"file_id":"file-abc123"}}` + result := strings.ReplaceAll(string(data), " ", "") + if result != expected { + t.Errorf("Expected %s, got %s", expected, result) + } + + // Test deserialization + var parsedPart openai.ChatMessagePart + err = json.Unmarshal(data, &parsedPart) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + + if parsedPart.Type != openai.ChatMessagePartTypeFile { + t.Errorf("Expected type %s, got %s", openai.ChatMessagePartTypeFile, parsedPart.Type) + } + if parsedPart.File == nil { + t.Fatal("Expected File to be non-nil") + } + if parsedPart.File.FileID != "file-abc123" { + t.Errorf("Expected FileID %s, got %s", "file-abc123", parsedPart.File.FileID) + } + + // Test file part with all fields + filePartComplete := openai.ChatMessagePart{ + Type: openai.ChatMessagePartTypeFile, + File: &openai.ChatMessageFile{ + FileID: "file-xyz789", + FileName: "document.pdf", + FileData: "JVBERi0xLjQK", // base64 for "%PDF-1.4\n" + }, + } + + data, err = json.Marshal(filePartComplete) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + + expected = `{"type":"file","file":{"file_id":"file-xyz789","filename":"document.pdf","file_data":"JVBERi0xLjQK"}}` + result = strings.ReplaceAll(string(data), " ", "") + if result != expected { + t.Errorf("Expected %s, got %s", expected, result) + } + + // Test deserialization of complete file part + var parsedCompleteFile openai.ChatMessagePart + err = json.Unmarshal(data, &parsedCompleteFile) + if err != nil { + t.Fatalf("Expected no error: %s", err) + } + + if parsedCompleteFile.File.FileID != "file-xyz789" { + t.Errorf("Expected FileID %s, got %s", "file-xyz789", parsedCompleteFile.File.FileID) + } + if parsedCompleteFile.File.FileName != "document.pdf" { + t.Errorf("Expected FileName %s, got %s", "document.pdf", parsedCompleteFile.File.FileName) + } + if parsedCompleteFile.File.FileData != "JVBERi0xLjQK" { + t.Errorf("Expected FileData %s, got %s", "JVBERi0xLjQK", parsedCompleteFile.File.FileData) + } +} + +func TestChatMessagePartTypeConstants(t *testing.T) { + // Test that the new file constant is properly defined + if openai.ChatMessagePartTypeFile != "file" { + t.Errorf("Expected ChatMessagePartTypeFile to be 'file', got %s", openai.ChatMessagePartTypeFile) + } + + // Test all part type constants + expectedTypes := map[openai.ChatMessagePartType]string{ + openai.ChatMessagePartTypeText: "text", + openai.ChatMessagePartTypeImageURL: "image_url", + openai.ChatMessagePartTypeFile: "file", + } + + for constant, expected := range expectedTypes { + if string(constant) != expected { + t.Errorf("Expected %s to be %s, got %s", constant, expected, string(constant)) + } + } +} + // handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server. func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) { var err error