Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,26 @@ type ChatMessageImageURL struct {
Detail ImageURLDetail `json:"detail,omitempty"`
}

// ChatMessagePartFile is a placeholder for file parts in chat messages.
type ChatMessageFile struct {
FileID string `json:"file_id,omitempty"`
FileName string `json:"filename,omitempty"`
FileData string `json:"file_data,omitempty"` // Base64 encoded file data
}

type ChatMessagePartType string

const (
ChatMessagePartTypeText ChatMessagePartType = "text"
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
ChatMessagePartTypeFile ChatMessagePartType = "file"
)

type ChatMessagePart struct {
Type ChatMessagePartType `json:"type,omitempty"`
Text string `json:"text,omitempty"`
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
File *ChatMessageFile `json:"file,omitempty"`
}

type ChatCompletionMessage struct {
Expand Down
114 changes: 112 additions & 2 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,14 @@ func TestMultipartChatCompletions(t *testing.T) {
Detail: openai.ImageURLDetailLow,
},
},
{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-123",
FileName: "test.txt",
FileData: "dGVzdCBmaWxlIGNvbnRlbnQ=", // base64 encoded "test file content"
},
},
},
},
},
Expand All @@ -687,7 +695,8 @@ func TestMultipartChatCompletions(t *testing.T) {
func TestMultipartChatMessageSerialization(t *testing.T) {
jsonText := `[{"role":"system","content":"system-message"},` +
`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}},` +
`{"type":"file","file":{"file_id":"file-123","filename":"test.txt","file_data":"dGVzdA=="}}]}]`

var msgs []openai.ChatCompletionMessage
err := json.Unmarshal([]byte(jsonText), &msgs)
Expand All @@ -700,7 +709,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
t.Errorf("invalid user message: %v", msgs[0])
}
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 3 {
t.Errorf("invalid user message")
}
parts := msgs[1].MultiContent
Expand All @@ -710,6 +719,10 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
t.Errorf("invalid image_url part")
}
if parts[2].Type != "file" || parts[2].File.FileID != "file-123" ||
parts[2].File.FileName != "test.txt" || parts[2].File.FileData != "dGVzdA==" {
t.Errorf("invalid file part: %v", parts[2])
}

s, err := json.Marshal(msgs)
if err != nil {
Expand Down Expand Up @@ -756,6 +769,103 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
}
}

func TestChatMessageFile(t *testing.T) {
// Test file part with FileID
filePart := openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-abc123",
},
}

// Test serialization
data, err := json.Marshal(filePart)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

expected := `{"type":"file","file":{"file_id":"file-abc123"}}`
result := strings.ReplaceAll(string(data), " ", "")
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}

// Test deserialization
var parsedPart openai.ChatMessagePart
err = json.Unmarshal(data, &parsedPart)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

if parsedPart.Type != openai.ChatMessagePartTypeFile {
t.Errorf("Expected type %s, got %s", openai.ChatMessagePartTypeFile, parsedPart.Type)
}
if parsedPart.File == nil {
t.Fatal("Expected File to be non-nil")
}
if parsedPart.File.FileID != "file-abc123" {
t.Errorf("Expected FileID %s, got %s", "file-abc123", parsedPart.File.FileID)
}

// Test file part with all fields
filePartComplete := openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-xyz789",
FileName: "document.pdf",
FileData: "JVBERi0xLjQK", // base64 for "%PDF-1.4\n"
},
}

data, err = json.Marshal(filePartComplete)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

expected = `{"type":"file","file":{"file_id":"file-xyz789","filename":"document.pdf","file_data":"JVBERi0xLjQK"}}`
result = strings.ReplaceAll(string(data), " ", "")
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}

// Test deserialization of complete file part
var parsedCompleteFile openai.ChatMessagePart
err = json.Unmarshal(data, &parsedCompleteFile)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

if parsedCompleteFile.File.FileID != "file-xyz789" {
t.Errorf("Expected FileID %s, got %s", "file-xyz789", parsedCompleteFile.File.FileID)
}
if parsedCompleteFile.File.FileName != "document.pdf" {
t.Errorf("Expected FileName %s, got %s", "document.pdf", parsedCompleteFile.File.FileName)
}
if parsedCompleteFile.File.FileData != "JVBERi0xLjQK" {
t.Errorf("Expected FileData %s, got %s", "JVBERi0xLjQK", parsedCompleteFile.File.FileData)
}
}

func TestChatMessagePartTypeConstants(t *testing.T) {
// Test that the new file constant is properly defined
if openai.ChatMessagePartTypeFile != "file" {
t.Errorf("Expected ChatMessagePartTypeFile to be 'file', got %s", openai.ChatMessagePartTypeFile)
}

// Test all part type constants
expectedTypes := map[openai.ChatMessagePartType]string{
openai.ChatMessagePartTypeText: "text",
openai.ChatMessagePartTypeImageURL: "image_url",
openai.ChatMessagePartTypeFile: "file",
}

for constant, expected := range expectedTypes {
if string(constant) != expected {
t.Errorf("Expected %s to be %s, got %s", constant, expected, string(constant))
}
}
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
Expand Down
Loading