diff --git a/client.go b/client.go index cef375348..3d34434bc 100644 --- a/client.go +++ b/client.go @@ -250,6 +250,8 @@ var azureDeploymentsEndpoints = []string{ "/audio/translations", "/audio/speech", "/images/generations", + "/images/edits", + "/images/variations", } // fullURL returns full URL for request. diff --git a/image.go b/image.go index 72077ce41..b38f6b0cf 100644 --- a/image.go +++ b/image.go @@ -3,6 +3,7 @@ package openai import ( "bytes" "context" + "fmt" "io" "net/http" "strconv" @@ -69,6 +70,10 @@ const ( CreateImageOutputFormatWEBP = "webp" ) +const ( + minFileTypeLength = 6 // "image/" +) + // ImageRequest represents the request structure for the image API. type ImageRequest struct { Prompt string `json:"prompt,omitempty"` @@ -134,15 +139,15 @@ func (c *Client) CreateImage(ctx context.Context, request ImageRequest) (respons // ImageEditRequest represents the request structure for the image API. type ImageEditRequest struct { - Image io.Reader `json:"image,omitempty"` - Mask io.Reader `json:"mask,omitempty"` - Prompt string `json:"prompt,omitempty"` - Model string `json:"model,omitempty"` - N int `json:"n,omitempty"` - Size string `json:"size,omitempty"` - ResponseFormat string `json:"response_format,omitempty"` - Quality string `json:"quality,omitempty"` - User string `json:"user,omitempty"` + Images []io.Reader `json:"images,omitempty"` + Mask io.Reader `json:"mask,omitempty"` + Prompt string `json:"prompt,omitempty"` + Model string `json:"model,omitempty"` + N int `json:"n,omitempty"` + Size string `json:"size,omitempty"` + ResponseFormat string `json:"response_format,omitempty"` + Quality string `json:"quality,omitempty"` + User string `json:"user,omitempty"` } // CreateEditImage - API call to create an image. This is the main endpoint of the DALL-E API. @@ -150,10 +155,25 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) body := &bytes.Buffer{} builder := c.createFormBuilder(body) - // image, filename is not required - err = builder.CreateFormFileReader("image", request.Image, "") - if err != nil { - return + for i, img := range request.Images { + // judge file type + var data []byte + data, err = io.ReadAll(img) + if err != nil { + return + } + fileType := http.DetectContentType(data) + if len(fileType) < minFileTypeLength { + err = fmt.Errorf("invalid file type: %s", fileType) + return + } + // get file extension + ext := fileType[minFileTypeLength:] + filename := fmt.Sprintf("%d.%s", i, ext) + err = builder.CreateFormFileReader("image", img, filename) + if err != nil { + return + } } // mask, it is optional @@ -180,9 +200,11 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest) return } - err = builder.WriteField("response_format", request.ResponseFormat) - if err != nil { - return + if request.ResponseFormat != "" { + err = builder.WriteField("response_format", request.ResponseFormat) + if err != nil { + return + } } err = builder.Close() diff --git a/image_api_test.go b/image_api_test.go index f6057b77d..1041b9430 100644 --- a/image_api_test.go +++ b/image_api_test.go @@ -100,7 +100,7 @@ func TestImageEdit(t *testing.T) { defer mask.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ - Image: origin, + Images: []io.Reader{origin}, Mask: mask, Prompt: "There is a turtle in the pool", N: 3, @@ -122,7 +122,7 @@ func TestImageEditWithoutMask(t *testing.T) { defer origin.Close() _, err = client.CreateEditImage(context.Background(), openai.ImageEditRequest{ - Image: origin, + Images: []io.Reader{origin}, Prompt: "There is a turtle in the pool", N: 3, Size: openai.CreateImageSize1024x1024,