Skip to content

Add file part support to chat message structure #1056

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,26 @@ type ChatMessageImageURL struct {
Detail ImageURLDetail `json:"detail,omitempty"`
}

// ChatMessageFile is a placeholder for file parts in chat messages.
type ChatMessageFile struct {
FileID string `json:"file_id,omitempty"`
FileName string `json:"filename,omitempty"`
FileData string `json:"file_data,omitempty"` // Base64 encoded file data
}

type ChatMessagePartType string

const (
ChatMessagePartTypeText ChatMessagePartType = "text"
ChatMessagePartTypeImageURL ChatMessagePartType = "image_url"
ChatMessagePartTypeFile ChatMessagePartType = "file"
)

type ChatMessagePart struct {
Type ChatMessagePartType `json:"type,omitempty"`
Text string `json:"text,omitempty"`
ImageURL *ChatMessageImageURL `json:"image_url,omitempty"`
File *ChatMessageFile `json:"file,omitempty"`
}

type ChatCompletionMessage struct {
Expand Down
114 changes: 112 additions & 2 deletions chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -677,6 +677,14 @@ func TestMultipartChatCompletions(t *testing.T) {
Detail: openai.ImageURLDetailLow,
},
},
{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-123",
FileName: "test.txt",
FileData: "dGVzdCBmaWxlIGNvbnRlbnQ=", // base64 encoded "test file content"
},
},
},
},
},
Expand All @@ -687,7 +695,8 @@ func TestMultipartChatCompletions(t *testing.T) {
func TestMultipartChatMessageSerialization(t *testing.T) {
jsonText := `[{"role":"system","content":"system-message"},` +
`{"role":"user","content":[{"type":"text","text":"nice-text"},` +
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}}]}]`
`{"type":"image_url","image_url":{"url":"URL","detail":"high"}},` +
`{"type":"file","file":{"file_id":"file-123","filename":"test.txt","file_data":"dGVzdA=="}}]}]`

var msgs []openai.ChatCompletionMessage
err := json.Unmarshal([]byte(jsonText), &msgs)
Expand All @@ -700,7 +709,7 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
if msgs[0].Role != "system" || msgs[0].Content != "system-message" || msgs[0].MultiContent != nil {
t.Errorf("invalid user message: %v", msgs[0])
}
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 2 {
if msgs[1].Role != "user" || msgs[1].Content != "" || len(msgs[1].MultiContent) != 3 {
t.Errorf("invalid user message")
}
parts := msgs[1].MultiContent
Expand All @@ -710,6 +719,10 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
if parts[1].Type != "image_url" || parts[1].ImageURL.URL != "URL" || parts[1].ImageURL.Detail != "high" {
t.Errorf("invalid image_url part")
}
if parts[2].Type != "file" || parts[2].File.FileID != "file-123" ||
parts[2].File.FileName != "test.txt" || parts[2].File.FileData != "dGVzdA==" {
t.Errorf("invalid file part: %v", parts[2])
}

s, err := json.Marshal(msgs)
if err != nil {
Expand Down Expand Up @@ -756,6 +769,103 @@ func TestMultipartChatMessageSerialization(t *testing.T) {
}
}

func TestChatMessageFile(t *testing.T) {
// Test file part with FileID
filePart := openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-abc123",
},
}

// Test serialization
data, err := json.Marshal(filePart)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

expected := `{"type":"file","file":{"file_id":"file-abc123"}}`
result := strings.ReplaceAll(string(data), " ", "")
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}

// Test deserialization
var parsedPart openai.ChatMessagePart
err = json.Unmarshal(data, &parsedPart)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

if parsedPart.Type != openai.ChatMessagePartTypeFile {
t.Errorf("Expected type %s, got %s", openai.ChatMessagePartTypeFile, parsedPart.Type)
}
if parsedPart.File == nil {
t.Fatal("Expected File to be non-nil")
}
if parsedPart.File.FileID != "file-abc123" {
t.Errorf("Expected FileID %s, got %s", "file-abc123", parsedPart.File.FileID)
}

// Test file part with all fields
filePartComplete := openai.ChatMessagePart{
Type: openai.ChatMessagePartTypeFile,
File: &openai.ChatMessageFile{
FileID: "file-xyz789",
FileName: "document.pdf",
FileData: "JVBERi0xLjQK", // base64 for "%PDF-1.4\n"
},
}

data, err = json.Marshal(filePartComplete)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

expected = `{"type":"file","file":{"file_id":"file-xyz789","filename":"document.pdf","file_data":"JVBERi0xLjQK"}}`
result = strings.ReplaceAll(string(data), " ", "")
if result != expected {
t.Errorf("Expected %s, got %s", expected, result)
}

// Test deserialization of complete file part
var parsedCompleteFile openai.ChatMessagePart
err = json.Unmarshal(data, &parsedCompleteFile)
if err != nil {
t.Fatalf("Expected no error: %s", err)
}

if parsedCompleteFile.File.FileID != "file-xyz789" {
t.Errorf("Expected FileID %s, got %s", "file-xyz789", parsedCompleteFile.File.FileID)
}
if parsedCompleteFile.File.FileName != "document.pdf" {
t.Errorf("Expected FileName %s, got %s", "document.pdf", parsedCompleteFile.File.FileName)
}
if parsedCompleteFile.File.FileData != "JVBERi0xLjQK" {
t.Errorf("Expected FileData %s, got %s", "JVBERi0xLjQK", parsedCompleteFile.File.FileData)
}
}

func TestChatMessagePartTypeConstants(t *testing.T) {
// Test that the new file constant is properly defined
if openai.ChatMessagePartTypeFile != "file" {
t.Errorf("Expected ChatMessagePartTypeFile to be 'file', got %s", openai.ChatMessagePartTypeFile)
}

// Test all part type constants
expectedTypes := map[openai.ChatMessagePartType]string{
openai.ChatMessagePartTypeText: "text",
openai.ChatMessagePartTypeImageURL: "image_url",
openai.ChatMessagePartTypeFile: "file",
}

for constant, expected := range expectedTypes {
if string(constant) != expected {
t.Errorf("Expected %s to be %s, got %s", constant, expected, string(constant))
}
}
}

// handleChatCompletionEndpoint Handles the ChatGPT completion endpoint by the test server.
func handleChatCompletionEndpoint(w http.ResponseWriter, r *http.Request) {
var err error
Expand Down
Loading