|
1 | 1 | package apiserversdk
|
2 | 2 |
|
3 | 3 | import (
|
| 4 | + "bytes" |
4 | 5 | "context"
|
5 | 6 | "errors"
|
| 7 | + "io" |
6 | 8 | "net"
|
7 | 9 | "net/http"
|
8 | 10 | "path/filepath"
|
| 11 | + "strings" |
9 | 12 | "sync/atomic"
|
10 | 13 | "testing"
|
11 | 14 | "time"
|
@@ -325,3 +328,154 @@ var _ = Describe("kuberay service", Ordered, func() {
|
325 | 328 | })
|
326 | 329 | })
|
327 | 330 | })
|
| 331 | + |
| 332 | +var _ = Describe("retryRoundTripper", func() { |
| 333 | + It("should not retry on successful status OK", func() { |
| 334 | + var attempts int32 |
| 335 | + mock := &mockRoundTripper{ |
| 336 | + fn: func(_ *http.Request) (*http.Response, error) { |
| 337 | + atomic.AddInt32(&attempts, 1) |
| 338 | + return &http.Response{ /* Always return OK status */ |
| 339 | + StatusCode: http.StatusOK, |
| 340 | + Body: io.NopCloser(strings.NewReader("OK")), |
| 341 | + }, nil |
| 342 | + }, |
| 343 | + } |
| 344 | + retrier := newRetryRoundTripper(mock) |
| 345 | + req, err := http.NewRequest(http.MethodGet, "http://test", nil) |
| 346 | + Expect(err).ToNot(HaveOccurred()) |
| 347 | + resp, err := retrier.RoundTrip(req) |
| 348 | + Expect(err).ToNot(HaveOccurred()) |
| 349 | + Expect(resp.StatusCode).To(Equal(http.StatusOK)) |
| 350 | + Expect(attempts).To(Equal(int32(1))) |
| 351 | + }) |
| 352 | + |
| 353 | + It("should retry failed requests and eventually succeed", func() { |
| 354 | + const maxFailure = 2 |
| 355 | + var attempts int32 |
| 356 | + mock := &mockRoundTripper{ |
| 357 | + fn: func(_ *http.Request) (*http.Response, error) { |
| 358 | + count := atomic.AddInt32(&attempts, 1) |
| 359 | + if count <= maxFailure { |
| 360 | + return &http.Response{ |
| 361 | + StatusCode: http.StatusInternalServerError, |
| 362 | + Body: io.NopCloser(strings.NewReader("internal error")), |
| 363 | + }, nil |
| 364 | + } |
| 365 | + return &http.Response{ |
| 366 | + StatusCode: http.StatusOK, |
| 367 | + Body: io.NopCloser(strings.NewReader("ok")), |
| 368 | + }, nil |
| 369 | + }, |
| 370 | + } |
| 371 | + retrier := newRetryRoundTripper(mock) |
| 372 | + req, err := http.NewRequest(http.MethodGet, "http://test", nil) |
| 373 | + Expect(err).ToNot(HaveOccurred()) |
| 374 | + resp, err := retrier.RoundTrip(req) |
| 375 | + Expect(err).ToNot(HaveOccurred()) |
| 376 | + Expect(resp.StatusCode).To(Equal(http.StatusOK)) |
| 377 | + Expect(attempts).To(Equal(int32(maxFailure + 1))) |
| 378 | + }) |
| 379 | + |
| 380 | + It("Retries exceed maximum retry counts", func() { |
| 381 | + var attempts int32 |
| 382 | + mock := &mockRoundTripper{ |
| 383 | + fn: func(_ *http.Request) (*http.Response, error) { |
| 384 | + atomic.AddInt32(&attempts, 1) |
| 385 | + return &http.Response{ /* Always return retriable status */ |
| 386 | + StatusCode: http.StatusInternalServerError, |
| 387 | + Body: io.NopCloser(strings.NewReader("internal error")), |
| 388 | + }, nil |
| 389 | + }, |
| 390 | + } |
| 391 | + retrier := newRetryRoundTripper(mock) |
| 392 | + req, err := http.NewRequest(http.MethodGet, "http://test", nil) |
| 393 | + Expect(err).ToNot(HaveOccurred()) |
| 394 | + resp, err := retrier.RoundTrip(req) |
| 395 | + Expect(err).ToNot(HaveOccurred()) |
| 396 | + Expect(resp.StatusCode).To(Equal(http.StatusInternalServerError)) |
| 397 | + Expect(attempts).To(Equal(int32(HTTPClientDefaultMaxRetry + 1))) |
| 398 | + }) |
| 399 | + |
| 400 | + It("Retries on request with body", func() { |
| 401 | + const testBody = "test-body" |
| 402 | + const maxFailure = 2 |
| 403 | + var attempts int32 |
| 404 | + mock := &mockRoundTripper{ |
| 405 | + fn: func(req *http.Request) (*http.Response, error) { |
| 406 | + count := atomic.AddInt32(&attempts, 1) |
| 407 | + reqBody, err := io.ReadAll(req.Body) |
| 408 | + Expect(err).ToNot(HaveOccurred()) |
| 409 | + Expect(string(reqBody)).To(Equal(testBody)) |
| 410 | + |
| 411 | + if count <= maxFailure { |
| 412 | + return &http.Response{ |
| 413 | + StatusCode: http.StatusInternalServerError, |
| 414 | + Body: io.NopCloser(strings.NewReader("internal error")), |
| 415 | + }, nil |
| 416 | + } |
| 417 | + return &http.Response{ |
| 418 | + StatusCode: http.StatusOK, |
| 419 | + Body: io.NopCloser(strings.NewReader("ok")), |
| 420 | + }, nil |
| 421 | + }, |
| 422 | + } |
| 423 | + retrier := newRetryRoundTripper(mock) |
| 424 | + body := bytes.NewBufferString(testBody) |
| 425 | + req, err := http.NewRequest(http.MethodPost, "http://test", body) |
| 426 | + Expect(err).ToNot(HaveOccurred()) |
| 427 | + resp, err := retrier.RoundTrip(req) |
| 428 | + Expect(err).ToNot(HaveOccurred()) |
| 429 | + Expect(resp.StatusCode).To(Equal(http.StatusOK)) |
| 430 | + Expect(attempts).To(Equal(int32(maxFailure + 1))) |
| 431 | + }) |
| 432 | + |
| 433 | + It("should not retry on non-retriable status", func() { |
| 434 | + var attempts int32 |
| 435 | + mock := &mockRoundTripper{ |
| 436 | + fn: func(_ *http.Request) (*http.Response, error) { |
| 437 | + atomic.AddInt32(&attempts, 1) |
| 438 | + return &http.Response{ /* Always return non-retriable status */ |
| 439 | + StatusCode: http.StatusNotFound, |
| 440 | + Body: io.NopCloser(strings.NewReader("Not Found")), |
| 441 | + }, nil |
| 442 | + }, |
| 443 | + } |
| 444 | + retrier := newRetryRoundTripper(mock) |
| 445 | + req, err := http.NewRequest(http.MethodGet, "http://test", nil) |
| 446 | + Expect(err).ToNot(HaveOccurred()) |
| 447 | + resp, err := retrier.RoundTrip(req) |
| 448 | + Expect(err).ToNot(HaveOccurred()) |
| 449 | + Expect(resp.StatusCode).To(Equal(http.StatusNotFound)) |
| 450 | + Expect(attempts).To(Equal(int32(1))) |
| 451 | + }) |
| 452 | + |
| 453 | + It("should respect context timeout and stop retrying", func() { |
| 454 | + mock := &mockRoundTripper{ |
| 455 | + fn: func(_ *http.Request) (*http.Response, error) { |
| 456 | + time.Sleep(100 * time.Millisecond) |
| 457 | + return &http.Response{ |
| 458 | + StatusCode: http.StatusInternalServerError, |
| 459 | + Body: io.NopCloser(strings.NewReader("internal error")), |
| 460 | + }, nil |
| 461 | + }, |
| 462 | + } |
| 463 | + retrier := newRetryRoundTripper(mock) |
| 464 | + ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond) |
| 465 | + defer cancel() |
| 466 | + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://test", nil) |
| 467 | + Expect(err).ToNot(HaveOccurred()) |
| 468 | + resp, err := retrier.RoundTrip(req) |
| 469 | + Expect(err).To(HaveOccurred()) |
| 470 | + Expect(err.Error()).To(ContainSubstring("retry timeout exceeded context deadline")) |
| 471 | + Expect(resp).ToNot(BeNil()) |
| 472 | + }) |
| 473 | +}) |
| 474 | + |
| 475 | +type mockRoundTripper struct { |
| 476 | + fn func(*http.Request) (*http.Response, error) |
| 477 | +} |
| 478 | + |
| 479 | +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { |
| 480 | + return m.fn(req) |
| 481 | +} |
0 commit comments