Skip to content

Commit fd9df76

Browse files
authored
Merge pull request #197 from haydenhoang/pr-193
add test for PR 193, fix nil pointer error when request is a GET
2 parents 93a27de + f46a7e8 commit fd9df76

File tree

2 files changed

+72
-0
lines changed

2 files changed

+72
-0
lines changed

typesense/api_call.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package typesense
22

33
import (
4+
"bytes"
45
"context"
56
"errors"
7+
"io"
68
"net/http"
79
"net/url"
810
"time"
@@ -67,12 +69,32 @@ func (a *APICall) Do(req *http.Request) (*http.Response, error) {
6769

6870
var lastResponse *http.Response
6971
var lastError error
72+
var bodyBytes []byte
73+
74+
if req.GetBody != nil {
75+
// Store body in case we need to retry
76+
reqBody, err := req.GetBody()
77+
if err != nil {
78+
return nil, err
79+
}
80+
defer reqBody.Close()
81+
82+
bodyBytes, err = io.ReadAll(reqBody)
83+
if err != nil {
84+
return nil, err
85+
}
86+
}
7087

7188
for numTries := 0; numTries < a.numRetriesPerRequest; numTries++ {
7289
node := a.getNextNode()
7390

7491
replaceRequestHostname(req, node.url)
7592

93+
if bodyBytes != nil {
94+
// Create a new io.ReadCloser for each retry
95+
req.Body = io.NopCloser(bytes.NewReader(bodyBytes))
96+
}
97+
7698
response, err := a.client.Do(req)
7799

78100
// return early if request is aborted

typesense/api_call_test.go

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
package typesense
22

33
import (
4+
"bytes"
45
"context"
6+
"io"
57
"net/http"
68
"net/http/httptest"
79
"testing"
@@ -392,3 +394,51 @@ func TestApiCallCanAbortRequest(t *testing.T) {
392394
assert.Nil(t, res)
393395
assert.Equal(t, requestURLHistory, serverURLs[:3])
394396
}
397+
398+
func TestApiCallRetryWithRequestBody(t *testing.T) {
399+
400+
requestURLHistory := make([]string, 0, 2)
401+
402+
servers, serverURLs := instantiateServers([]serverHandler{
403+
func(w http.ResponseWriter, r *http.Request) {
404+
appendHistory(&requestURLHistory, r)
405+
w.WriteHeader(501)
406+
},
407+
func(w http.ResponseWriter, r *http.Request) {
408+
appendHistory(&requestURLHistory, r)
409+
data := r.Body
410+
bodyBytes, _ := io.ReadAll(data)
411+
assert.Equal(t, string(bodyBytes), "body data")
412+
w.WriteHeader(201)
413+
},
414+
func(w http.ResponseWriter, r *http.Request) {
415+
appendHistory(&requestURLHistory, r)
416+
data := r.Body
417+
bodyBytes, _ := io.ReadAll(data)
418+
assert.Equal(t, string(bodyBytes), "body data")
419+
w.WriteHeader(203)
420+
},
421+
})
422+
for _, server := range servers {
423+
defer server.Close()
424+
}
425+
426+
apiCall := newAPICall(
427+
&ClientConfig{
428+
Nodes: serverURLs,
429+
ConnectionTimeout: 5 * time.Second,
430+
},
431+
)
432+
req, err := http.NewRequest(http.MethodPost, "http://example.com", bytes.NewBuffer([]byte("body data")))
433+
assert.NoError(t, err)
434+
435+
res, err := apiCall.Do(req)
436+
assert.NoError(t, err)
437+
assert.Equal(t, 201, res.StatusCode)
438+
439+
res2, err2 := apiCall.Do(req)
440+
assert.NoError(t, err2)
441+
assert.Equal(t, 203, res2.StatusCode)
442+
443+
assert.Equal(t, serverURLs, requestURLHistory)
444+
}

0 commit comments

Comments
 (0)