Skip to content

Commit 1a123fe

Browse files
authored
Add image variation implementation and fix #149 (#153)
* Compatible with the situation where the mask is empty in CreateEditImage. * Fix the test for the unnecessary removal of the mask.png file. * add image variation implementation * fix image variation bugs * fix ci-lint problem with max line character limit * add offitial doc link
1 parent f4a6a99 commit 1a123fe

File tree

2 files changed

+155
-5
lines changed

2 files changed

+155
-5
lines changed

image.go

Lines changed: 51 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -86,20 +86,65 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
8686
return
8787
}
8888

89-
// mask
90-
mask, err := writer.CreateFormFile("mask", request.Mask.Name())
89+
// mask, it is optional
90+
if request.Mask != nil {
91+
mask, err2 := writer.CreateFormFile("mask", request.Mask.Name())
92+
if err2 != nil {
93+
return
94+
}
95+
_, err = io.Copy(mask, request.Mask)
96+
if err != nil {
97+
return
98+
}
99+
}
100+
101+
err = writer.WriteField("prompt", request.Prompt)
91102
if err != nil {
92103
return
93104
}
94-
_, err = io.Copy(mask, request.Mask)
105+
err = writer.WriteField("n", strconv.Itoa(request.N))
106+
if err != nil {
107+
return
108+
}
109+
err = writer.WriteField("size", request.Size)
110+
if err != nil {
111+
return
112+
}
113+
writer.Close()
114+
urlSuffix := "/images/edits"
115+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
95116
if err != nil {
96117
return
97118
}
98119

99-
err = writer.WriteField("prompt", request.Prompt)
120+
req.Header.Set("Content-Type", writer.FormDataContentType())
121+
err = c.sendRequest(req, &response)
122+
return
123+
}
124+
125+
// ImageVariRequest represents the request structure for the image API.
126+
type ImageVariRequest struct {
127+
Image *os.File `json:"image,omitempty"`
128+
N int `json:"n,omitempty"`
129+
Size string `json:"size,omitempty"`
130+
}
131+
132+
// CreateVariImage - API call to create an image variation. This is the main endpoint of the DALL-E API.
133+
// Use abbreviations(vari for variation) because ci-lint has a single-line length limit ...
134+
func (c *Client) CreateVariImage(ctx context.Context, request ImageVariRequest) (response ImageResponse, err error) {
135+
body := &bytes.Buffer{}
136+
writer := multipart.NewWriter(body)
137+
138+
// image
139+
image, err := writer.CreateFormFile("image", request.Image.Name())
140+
if err != nil {
141+
return
142+
}
143+
_, err = io.Copy(image, request.Image)
100144
if err != nil {
101145
return
102146
}
147+
103148
err = writer.WriteField("n", strconv.Itoa(request.N))
104149
if err != nil {
105150
return
@@ -109,7 +154,8 @@ func (c *Client) CreateEditImage(ctx context.Context, request ImageEditRequest)
109154
return
110155
}
111156
writer.Close()
112-
urlSuffix := "/images/edits"
157+
//https://platform.openai.com/docs/api-reference/images/create-variation
158+
urlSuffix := "/images/variations"
113159
req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.fullURL(urlSuffix), body)
114160
if err != nil {
115161
return

image_test.go

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,43 @@ func TestImageEdit(t *testing.T) {
132132
}
133133
}
134134

135+
func TestImageEditWithoutMask(t *testing.T) {
136+
server := test.NewTestServer()
137+
server.RegisterHandler("/v1/images/edits", handleEditImageEndpoint)
138+
// create the test server
139+
var err error
140+
ts := server.OpenAITestServer()
141+
ts.Start()
142+
defer ts.Close()
143+
144+
config := DefaultConfig(test.GetTestToken())
145+
config.BaseURL = ts.URL + "/v1"
146+
client := NewClientWithConfig(config)
147+
ctx := context.Background()
148+
149+
origin, err := os.Create("image.png")
150+
if err != nil {
151+
t.Error("open origin file error")
152+
return
153+
}
154+
155+
defer func() {
156+
origin.Close()
157+
os.Remove("image.png")
158+
}()
159+
160+
req := ImageEditRequest{
161+
Image: origin,
162+
Prompt: "There is a turtle in the pool",
163+
N: 3,
164+
Size: CreateImageSize1024x1024,
165+
}
166+
_, err = client.CreateEditImage(ctx, req)
167+
if err != nil {
168+
t.Fatalf("CreateImage error: %v", err)
169+
}
170+
}
171+
135172
// handleEditImageEndpoint Handles the images endpoint by the test server.
136173
func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
137174
var resBytes []byte
@@ -162,3 +199,70 @@ func handleEditImageEndpoint(w http.ResponseWriter, r *http.Request) {
162199
resBytes, _ = json.Marshal(responses)
163200
fmt.Fprintln(w, string(resBytes))
164201
}
202+
203+
func TestImageVariation(t *testing.T) {
204+
server := test.NewTestServer()
205+
server.RegisterHandler("/v1/images/variations", handleVariateImageEndpoint)
206+
// create the test server
207+
var err error
208+
ts := server.OpenAITestServer()
209+
ts.Start()
210+
defer ts.Close()
211+
212+
config := DefaultConfig(test.GetTestToken())
213+
config.BaseURL = ts.URL + "/v1"
214+
client := NewClientWithConfig(config)
215+
ctx := context.Background()
216+
217+
origin, err := os.Create("image.png")
218+
if err != nil {
219+
t.Error("open origin file error")
220+
return
221+
}
222+
223+
defer func() {
224+
origin.Close()
225+
os.Remove("image.png")
226+
}()
227+
228+
req := ImageVariRequest{
229+
Image: origin,
230+
N: 3,
231+
Size: CreateImageSize1024x1024,
232+
}
233+
_, err = client.CreateVariImage(ctx, req)
234+
if err != nil {
235+
t.Fatalf("CreateImage error: %v", err)
236+
}
237+
}
238+
239+
// handleVariateImageEndpoint Handles the images endpoint by the test server.
240+
func handleVariateImageEndpoint(w http.ResponseWriter, r *http.Request) {
241+
var resBytes []byte
242+
243+
// imagess only accepts POST requests
244+
if r.Method != "POST" {
245+
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
246+
}
247+
248+
responses := ImageResponse{
249+
Created: time.Now().Unix(),
250+
Data: []ImageResponseDataInner{
251+
{
252+
URL: "test-url1",
253+
B64JSON: "",
254+
},
255+
{
256+
URL: "test-url2",
257+
B64JSON: "",
258+
},
259+
{
260+
URL: "test-url3",
261+
B64JSON: "",
262+
},
263+
},
264+
}
265+
266+
resBytes, _ = json.Marshal(responses)
267+
fmt.Fprintln(w, string(resBytes))
268+
}

0 commit comments

Comments
 (0)