diff --git a/client/base_client.go b/client/base_client.go index 6e84a19cd..85475be8a 100644 --- a/client/base_client.go +++ b/client/base_client.go @@ -1,6 +1,7 @@ package client import ( + "context" "net/http" "net/url" "time" @@ -14,3 +15,9 @@ type BaseClient interface { SetOauth(auth OAuth) OAuth() OAuth } + +type BaseClientWithContext interface { + BaseClient + SendRequestWithContext(ctx context.Context, method string, rawURL string, data url.Values, + headers map[string]interface{}, body ...byte) (*http.Response, error) +} diff --git a/client/client.go b/client/client.go index d4b58ace9..342307c2d 100644 --- a/client/client.go +++ b/client/client.go @@ -140,6 +140,12 @@ var userAgentOnce sync.Once func (c *Client) SendRequest(method string, rawURL string, data url.Values, headers map[string]interface{}, body ...byte) (*http.Response, error) { + return c.SendRequestWithContext(context.Background(), method, rawURL, data, headers, body...) +} + +func (c *Client) SendRequestWithContext(ctx context.Context, method string, rawURL string, data url.Values, + headers map[string]interface{}, body ...byte) (*http.Response, error) { + contentType := extractContentTypeHeader(headers) u, err := url.Parse(rawURL) @@ -167,7 +173,7 @@ func (c *Client) SendRequest(method string, rawURL string, data url.Values, //data is already processed and information will be added to u(the url) in the //previous step. Now body will solely contain json payload if contentType == jsonContentType { - req, err = http.NewRequest(method, u.String(), bytes.NewBuffer(body)) + req, err = http.NewRequestWithContext(ctx, method, u.String(), bytes.NewBuffer(body)) if err != nil { return nil, err } @@ -177,7 +183,7 @@ func (c *Client) SendRequest(method string, rawURL string, data url.Values, if method == http.MethodPost || method == http.MethodPut || method == http.MethodPatch { valueReader = strings.NewReader(data.Encode()) } - req, err = http.NewRequestWithContext(context.Background(), method, u.String(), valueReader) + req, err = http.NewRequestWithContext(ctx, method, u.String(), valueReader) if err != nil { return nil, err } @@ -203,7 +209,7 @@ func (c *Client) SendRequest(method string, rawURL string, data url.Values, } if c.OAuth() != nil { oauth := c.OAuth() - token, _ := c.OAuth().GetAccessToken(context.TODO()) + token, _ := c.OAuth().GetAccessToken(ctx) if token != "" { req.Header.Add("Authorization", "Bearer "+token) } diff --git a/client/page_util.go b/client/page_util.go index a9b722f71..a9f936aba 100644 --- a/client/page_util.go +++ b/client/page_util.go @@ -1,6 +1,7 @@ package client import ( + "context" "encoding/json" "fmt" "strings" @@ -25,13 +26,17 @@ func ReadLimits(pageSize *int, limit *int) int { } } -func GetNext(baseUrl string, response interface{}, getNextPage func(nextPageUri string) (interface{}, error)) (interface{}, error) { +func GetNext(baseUrl string, response interface{}, getNextPage func(ctx context.Context, nextPageUri string) (interface{}, error)) (interface{}, error) { + return GetNextWithContext(context.Background(), baseUrl, response, getNextPage) +} + +func GetNextWithContext(ctx context.Context, baseUrl string, response interface{}, getNextPage func(ctx context.Context, nextPageUri string) (interface{}, error)) (interface{}, error) { nextPageUrl, err := getNextPageUrl(baseUrl, response) if err != nil { return nil, err } - return getNextPage(nextPageUrl) + return getNextPage(ctx, nextPageUrl) } func toMap(s interface{}) (map[string]interface{}, error) { diff --git a/client/page_util_test.go b/client/page_util_test.go index ebfd85509..7da2af5cd 100644 --- a/client/page_util_test.go +++ b/client/page_util_test.go @@ -2,6 +2,7 @@ package client import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -140,7 +141,7 @@ type testMessage struct { To *string `json:"to,omitempty"` } -func getSomething(nextPageUrl string) (interface{}, error) { +func getSomething(ctx context.Context, nextPageUrl string) (interface{}, error) { return nextPageUrl, nil } @@ -151,11 +152,11 @@ func TestPageUtil_GetNext(t *testing.T) { ps := &testResponse{} _ = json.NewDecoder(response.Body).Decode(ps) - nextPageUrl, err := GetNext(baseUrl, ps, getSomething) + nextPageUrl, err := GetNextWithContext(context.Background(), baseUrl, ps, getSomething) assert.Equal(t, "https://api.twilio.com/2010-04-01/Accounts/ACXX/Messages.json?From=9999999999&PageNumber=&To=4444444444&PageSize=2&Page=1&PageToken=PASMXX", nextPageUrl) assert.Nil(t, err) - nextPageUrl, err = GetNext(baseUrl, nil, getSomething) + nextPageUrl, err = GetNextWithContext(context.Background(), baseUrl, nil, getSomething) assert.Empty(t, nextPageUrl) assert.Nil(t, err) } diff --git a/client/request_handler.go b/client/request_handler.go index 5c0fdcb8e..d57038e6b 100644 --- a/client/request_handler.go +++ b/client/request_handler.go @@ -2,6 +2,7 @@ package client import ( + "context" "net/http" "net/url" "os" @@ -9,25 +10,34 @@ import ( ) type RequestHandler struct { - Client BaseClient - Edge string - Region string + Client BaseClient + Edge string + Region string + clientWithContext BaseClientWithContext } func NewRequestHandler(client BaseClient) *RequestHandler { + // If the base client supports context, add it to the request handler. + // Otherwise we leave it nil and the base client will be used. + clientWithContext, _ := client.(BaseClientWithContext) + return &RequestHandler{ - Client: client, - Edge: os.Getenv("TWILIO_EDGE"), - Region: os.Getenv("TWILIO_REGION"), + Client: client, + Edge: os.Getenv("TWILIO_EDGE"), + Region: os.Getenv("TWILIO_REGION"), + clientWithContext: clientWithContext, } } -func (c *RequestHandler) sendRequest(method string, rawURL string, data url.Values, +func (c *RequestHandler) sendRequest(ctx context.Context, method string, rawURL string, data url.Values, headers map[string]interface{}, body ...byte) (*http.Response, error) { parsedURL, err := c.BuildUrl(rawURL) if err != nil { return nil, err } + if c.clientWithContext != nil { + return c.clientWithContext.SendRequestWithContext(ctx, method, parsedURL, data, headers, body...) + } return c.Client.SendRequest(method, parsedURL, data, headers, body...) } @@ -83,21 +93,41 @@ func (c *RequestHandler) BuildUrl(rawURL string) (string, error) { } func (c *RequestHandler) Post(path string, bodyData url.Values, headers map[string]interface{}, body ...byte) (*http.Response, error) { - return c.sendRequest(http.MethodPost, path, bodyData, headers, body...) + return c.PostWithContext(context.Background(), path, bodyData, headers, body...) +} + +func (c *RequestHandler) PostWithContext(ctx context.Context, path string, bodyData url.Values, headers map[string]interface{}, body ...byte) (*http.Response, error) { + return c.clientWithContext.SendRequestWithContext(ctx, http.MethodPost, path, bodyData, headers, body...) } func (c *RequestHandler) Put(path string, bodyData url.Values, headers map[string]interface{}, body ...byte) (*http.Response, error) { - return c.sendRequest(http.MethodPut, path, bodyData, headers, body...) + return c.PutWithContext(context.Background(), path, bodyData, headers, body...) +} + +func (c *RequestHandler) PutWithContext(ctx context.Context, path string, bodyData url.Values, headers map[string]interface{}, body ...byte) (*http.Response, error) { + return c.clientWithContext.SendRequestWithContext(ctx, http.MethodPut, path, bodyData, headers, body...) } func (c *RequestHandler) Patch(path string, bodyData url.Values, headers map[string]interface{}, body ...byte) (*http.Response, error) { - return c.sendRequest(http.MethodPatch, path, bodyData, headers, body...) + return c.PatchWithContext(context.Background(), path, bodyData, headers, body...) +} + +func (c *RequestHandler) PatchWithContext(ctx context.Context, path string, bodyData url.Values, headers map[string]interface{}, body ...byte) (*http.Response, error) { + return c.clientWithContext.SendRequestWithContext(ctx, http.MethodPatch, path, bodyData, headers, body...) } func (c *RequestHandler) Get(path string, queryData url.Values, headers map[string]interface{}) (*http.Response, error) { - return c.sendRequest(http.MethodGet, path, queryData, headers) + return c.GetWithContext(context.Background(), path, queryData, headers) +} + +func (c *RequestHandler) GetWithContext(ctx context.Context, path string, queryData url.Values, headers map[string]interface{}) (*http.Response, error) { + return c.clientWithContext.SendRequestWithContext(ctx, http.MethodGet, path, queryData, headers) } func (c *RequestHandler) Delete(path string, queryData url.Values, headers map[string]interface{}) (*http.Response, error) { - return c.sendRequest(http.MethodDelete, path, queryData, headers) + return c.DeleteWithContext(context.Background(), path, queryData, headers) +} + +func (c *RequestHandler) DeleteWithContext(ctx context.Context, path string, queryData url.Values, headers map[string]interface{}) (*http.Response, error) { + return c.clientWithContext.SendRequestWithContext(ctx, http.MethodDelete, path, queryData, headers) } diff --git a/client/request_handler_test.go b/client/request_handler_test.go index 1756ae521..de831102f 100644 --- a/client/request_handler_test.go +++ b/client/request_handler_test.go @@ -1,6 +1,7 @@ package client_test import ( + "context" "errors" "net/http" "net/http/httptest" @@ -83,7 +84,7 @@ func TestRequestHandler_SendGetRequest(t *testing.T) { defer errorServer.Close() requestHandler := NewRequestHandler("user", "pass") - resp, err := requestHandler.Get(errorServer.URL, nil, nil) //nolint:bodyclose + resp, err := requestHandler.GetWithContext(context.Background(), errorServer.URL, nil, nil) //nolint:bodyclose twilioError := err.(*client.TwilioRestError) assert.Nil(t, resp) assert.Equal(t, 400, twilioError.Status) @@ -108,7 +109,7 @@ func TestRequestHandler_SendPostRequest(t *testing.T) { defer errorServer.Close() requestHandler := NewRequestHandler("user", "pass") - resp, err := requestHandler.Post(errorServer.URL, nil, nil) //nolint:bodyclose + resp, err := requestHandler.PostWithContext(context.Background(), errorServer.URL, nil, nil) //nolint:bodyclose twilioError := err.(*client.TwilioRestError) assert.Nil(t, resp) assert.Equal(t, 400, twilioError.Status) diff --git a/cluster_test.go b/cluster_test.go index 479d4da04..5557b5a25 100644 --- a/cluster_test.go +++ b/cluster_test.go @@ -4,6 +4,7 @@ package twilio import ( + "context" "os" "testing" @@ -15,6 +16,7 @@ import ( EventsV1 "github.com/twilio/twilio-go/rest/events/v1" "github.com/stretchr/testify/assert" + IamV1 "github.com/twilio/twilio-go/rest/iam/v1" ) @@ -268,3 +270,26 @@ func TestOrgsScimUerList(t *testing.T) { assert.Nil(t, err) assert.NotNil(t, users) } + +func TestSendingATextWithContext(t *testing.T) { + params := &Api.CreateMessageParams{} + params.SetTo(to) + params.SetFrom(from) + params.SetBody("Hello there") + + resp, err := testClient.Api.CreateMessageWithContext(context.Background(), params) + assert.Nil(t, err) + assert.NotNil(t, resp) + assert.Equal(t, "Hello there", *resp.Body) + assert.Equal(t, from, *resp.From) + assert.Equal(t, to, *resp.To) +} + +func TestOrgsAccountsListWithContext(t *testing.T) { + listAccounts, err := orgsClient.PreviewIamOrganization.ListOrganizationAccountsWithContext(context.Background(), orgSid, &PreviewIam.ListOrganizationAccountsParams{}) + assert.Nil(t, err) + assert.NotNil(t, listAccounts) + accounts, err := orgsClient.PreviewIamOrganization.FetchOrganizationAccountWithContext(context.Background(), orgSid, &PreviewIam.FetchOrganizationAccountParams{PathAccountSid: &accountSidOrgs}) + assert.Nil(t, err) + assert.NotNil(t, accounts) +} diff --git a/oauth.go b/oauth.go index 83c4205c6..096f408ec 100644 --- a/oauth.go +++ b/oauth.go @@ -100,7 +100,7 @@ func (a *APIOAuth) GetAccessToken(ctx context.Context) (string, error) { SetClientId(a.creds.ClientId). SetClientSecret(a.creds.ClientSecret) a.iamService.RequestHandler().Client.SetOauth(nil) // set oauth to nil to make no-auth request - token, err := a.iamService.CreateToken(params) + token, err := a.iamService.CreateTokenWithContext(ctx, params) if err == nil { a.tokenAuth = TokenAuth{ Token: *token.AccessToken,