Skip to content

Commit faf9a06

Browse files
committed
Eliminate races
1 parent b4cbf64 commit faf9a06

File tree

4 files changed

+80
-72
lines changed

4 files changed

+80
-72
lines changed

parrot/examples_test.go

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ func ExampleServer_Register_internal() {
3636
ResponseStatusCode: http.StatusOK,
3737
}
3838

39+
waitForParrotServerInternal(p, time.Second) // Wait for the parrot server to start
40+
3941
// Register the route with the parrot instance
4042
err = p.Register(route)
4143
if err != nil {
@@ -86,7 +88,7 @@ func ExampleServer_Register_external() {
8688
client := resty.New()
8789
client.SetBaseURL(fmt.Sprintf("http://localhost:%d", port)) // The URL of the parrot server
8890

89-
waitForParrotServer(client, time.Second) // Wait for the parrot server to start
91+
waitForParrotServerExternal(client, time.Second) // Wait for the parrot server to start
9092

9193
// Register a new route /test that will return a 200 status code with a text/plain response body of "Squawk"
9294
route := &parrot.Route{
@@ -158,6 +160,8 @@ func ExampleRecorder_internal() {
158160
panic(err)
159161
}
160162

163+
waitForParrotServerInternal(p, time.Second) // Wait for the parrot server to start
164+
161165
// Register the recorder with the parrot instance
162166
err = p.Record(recorder.URL())
163167
if err != nil {
@@ -225,7 +229,7 @@ func ExampleRecorder_external() {
225229
client := resty.New()
226230
client.SetBaseURL(fmt.Sprintf("http://localhost:%d", port)) // The URL of the parrot server
227231

228-
waitForParrotServer(client, time.Second) // Wait for the parrot server to start
232+
waitForParrotServerExternal(client, time.Second) // Wait for the parrot server to start
229233

230234
// Register a new route /test that will return a 200 status code with a text/plain response body of "Squawk"
231235
route := &parrot.Route{
@@ -290,8 +294,8 @@ func ExampleRecorder_external() {
290294
// Squawk
291295
}
292296

293-
// waitForParrotServer checks the parrot server health endpoint until it returns a 200 status code or the timeout is reached
294-
func waitForParrotServer(client *resty.Client, timeoutDur time.Duration) {
297+
// waitForParrotServerExternal checks the parrot server health endpoint until it returns a 200 status code or the timeout is reached
298+
func waitForParrotServerExternal(client *resty.Client, timeoutDur time.Duration) {
295299
ticker := time.NewTicker(50 * time.Millisecond)
296300
defer ticker.Stop()
297301
timeout := time.NewTimer(timeoutDur)
@@ -310,3 +314,19 @@ func waitForParrotServer(client *resty.Client, timeoutDur time.Duration) {
310314
}
311315
}
312316
}
317+
318+
func waitForParrotServerInternal(p *parrot.Server, timeoutDur time.Duration) {
319+
ticker := time.NewTicker(50 * time.Millisecond)
320+
defer ticker.Stop()
321+
timeout := time.NewTimer(timeoutDur)
322+
for { // Wait for the parrot server to start
323+
select {
324+
case <-ticker.C:
325+
if err := p.Healthy(); err == nil {
326+
return
327+
}
328+
case <-timeout.C:
329+
panic("timeout waiting for parrot server to start")
330+
}
331+
}
332+
}

parrot/parrot.go

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"strconv"
1616
"strings"
1717
"sync"
18+
"sync/atomic"
1819
"time"
1920

2021
"github.com/go-chi/chi"
@@ -67,13 +68,12 @@ type Server struct {
6768
recordersMu sync.RWMutex
6869

6970
// Save and shutdown
70-
shutDown bool
71+
shutDown atomic.Bool
7172
shutDownChan chan struct{}
7273
shutDownOnce sync.Once
7374
saveFileName string
7475

7576
// Logging
76-
useCustomLogger bool
7777
logFileName string
7878
logFile *os.File
7979
logLevel zerolog.Level
@@ -113,33 +113,32 @@ func Wake(options ...ServerOption) (*Server, error) {
113113
}
114114
}
115115

116+
// Setup logger
116117
var err error
117118
p.logFile, err = os.Create(p.logFileName)
118119
if err != nil {
119120
return nil, fmt.Errorf("failed to create log file: %w", err)
120121
}
121122

122-
if !p.useCustomLogger { // Build default logger
123-
var writers []io.Writer
123+
var writers []io.Writer
124124

125-
zerolog.TimeFieldFormat = "2006-01-02T15:04:05.000"
126-
if !p.disableConsoleLogs {
127-
if p.jsonLogs {
128-
writers = append(writers, os.Stderr)
129-
} else {
130-
consoleOut := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02T15:04:05.000"}
131-
writers = append(writers, consoleOut)
132-
}
133-
}
134-
135-
if p.logFile != nil {
136-
writers = append(writers, p.logFile)
125+
if !p.disableConsoleLogs {
126+
if p.jsonLogs {
127+
writers = append(writers, os.Stderr)
128+
} else {
129+
consoleOut := zerolog.ConsoleWriter{Out: os.Stderr, TimeFormat: "2006-01-02T15:04:05.000"}
130+
writers = append(writers, consoleOut)
137131
}
132+
}
138133

139-
multiWriter := zerolog.MultiLevelWriter(writers...)
140-
p.log = zerolog.New(multiWriter).Level(p.logLevel).With().Timestamp().Logger()
134+
if p.logFile != nil {
135+
writers = append(writers, p.logFile)
141136
}
142137

138+
multiWriter := zerolog.MultiLevelWriter(writers...)
139+
p.log = zerolog.New(multiWriter).Level(p.logLevel).With().Timestamp().Logger()
140+
141+
// Setup server
143142
listener, err := net.Listen("tcp", fmt.Sprintf(":%d", p.port))
144143
if err != nil {
145144
return nil, fmt.Errorf("failed to start listener: %w", err)
@@ -155,6 +154,7 @@ func Wake(options ...ServerOption) (*Server, error) {
155154
return nil, fmt.Errorf("failed to parse port: %w", err)
156155
}
157156

157+
// Initialize router
158158
p.router.Get(HealthRoute, p.healthHandlerGET)
159159

160160
p.router.Get(RoutesRoute, p.routesHandlerGET)
@@ -182,7 +182,7 @@ func Wake(options ...ServerOption) (*Server, error) {
182182
// run starts the parrot server
183183
func (p *Server) run(listener net.Listener) {
184184
defer func() {
185-
p.shutDown = true
185+
p.shutDown.Store(true)
186186
if err := p.save(); err != nil {
187187
p.log.Error().Err(err).Msg("Failed to save routes")
188188
}
@@ -241,7 +241,7 @@ func (p *Server) routeCallHandler(route *Route) http.HandlerFunc {
241241

242242
// Healthy checks if the parrot server is healthy
243243
func (p *Server) Healthy() error {
244-
if p.shutDown {
244+
if p.shutDown.Load() {
245245
return ErrServerShutdown
246246
}
247247

@@ -288,7 +288,7 @@ func (p *Server) healthHandlerGET(w http.ResponseWriter, r *http.Request) {
288288

289289
// Shutdown gracefully shuts down the parrot server
290290
func (p *Server) Shutdown(ctx context.Context) error {
291-
if p.shutDown {
291+
if p.shutDown.Load() {
292292
return ErrServerShutdown
293293
}
294294

@@ -308,7 +308,7 @@ func (p *Server) Address() string {
308308

309309
// Register adds a new route to the parrot
310310
func (p *Server) Register(route *Route) error {
311-
if p.shutDown {
311+
if p.shutDown.Load() {
312312
return ErrServerShutdown
313313
}
314314
if route == nil {
@@ -372,7 +372,7 @@ func (p *Server) routesHandlerPOST(w http.ResponseWriter, r *http.Request) {
372372

373373
// Record registers a new recorder with the parrot. All incoming requests to the parrot will be sent to the recorder.
374374
func (p *Server) Record(recorderURL string) error {
375-
if p.shutDown {
375+
if p.shutDown.Load() {
376376
return ErrServerShutdown
377377
}
378378

@@ -426,7 +426,7 @@ func (p *Server) recorderHandlerPOST(w http.ResponseWriter, r *http.Request) {
426426

427427
// Recorders returns the URLs of all registered recorders
428428
func (p *Server) Recorders() []string {
429-
if p.shutDown {
429+
if p.shutDown.Load() {
430430
return nil
431431
}
432432

@@ -493,14 +493,14 @@ func (p *Server) routesHandlerDELETE(w http.ResponseWriter, r *http.Request) {
493493

494494
// Call makes a request to the parrot server
495495
func (p *Server) Call(method, path string) (*resty.Response, error) {
496-
if p.shutDown {
496+
if p.shutDown.Load() {
497497
return nil, ErrServerShutdown
498498
}
499499
return p.client.R().Execute(method, "http://"+filepath.Join(p.Address(), path))
500500
}
501501

502502
func (p *Server) Routes() []*Route {
503-
if p.shutDown {
503+
if p.shutDown.Load() {
504504
return nil
505505
}
506506

parrot/parrot_options.go

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,6 @@ func WithLogLevel(level zerolog.Level) ServerOption {
2828
}
2929
}
3030

31-
// WithLogger sets the logger for the ParrotServer
32-
func WithLogger(l zerolog.Logger) ServerOption {
33-
return func(s *Server) error {
34-
s.log = l
35-
s.useCustomLogger = true
36-
return nil
37-
}
38-
}
39-
4031
// WithJSONLogs sets the logger to output JSON logs
4132
func WithJSONLogs() ServerOption {
4233
return func(s *Server) error {

parrot/parrot_test.go

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package parrot
22

33
import (
4-
"bytes"
54
"context"
65
"encoding/json"
76
"flag"
@@ -33,6 +32,36 @@ func TestMain(m *testing.M) {
3332
os.Exit(m.Run())
3433
}
3534

35+
func TestHealthy(t *testing.T) {
36+
t.Parallel()
37+
38+
p := newParrot(t)
39+
40+
healthCount := 0
41+
targetCount := 3
42+
43+
ticker := time.NewTicker(time.Millisecond * 10)
44+
timeout := time.NewTimer(time.Second)
45+
t.Cleanup(func() {
46+
ticker.Stop()
47+
timeout.Stop()
48+
})
49+
50+
for {
51+
select {
52+
case <-ticker.C:
53+
if err := p.Healthy(); err == nil {
54+
healthCount++
55+
}
56+
if healthCount >= targetCount {
57+
return
58+
}
59+
case <-timeout.C:
60+
require.GreaterOrEqual(t, targetCount, healthCount, "parrot never became healthy")
61+
}
62+
}
63+
}
64+
3665
func TestRegisterRoutes(t *testing.T) {
3766
t.Parallel()
3867

@@ -519,38 +548,6 @@ func TestShutDown(t *testing.T) {
519548
require.ErrorIs(t, err, ErrServerShutdown, "expected error shutting down parrot after shutdown")
520549
}
521550

522-
func TestCustomLogger(t *testing.T) {
523-
t.Parallel()
524-
525-
logBuffer := new(bytes.Buffer)
526-
testLogger := zerolog.New(logBuffer)
527-
528-
fileName := t.Name() + ".json"
529-
p, err := Wake(WithSaveFile(fileName), WithLogLevel(zerolog.DebugLevel), WithLogger(testLogger))
530-
require.NoError(t, err, "error waking parrot")
531-
t.Cleanup(func() {
532-
err := p.Shutdown(context.Background())
533-
assert.NoError(t, err, "error shutting down parrot")
534-
p.WaitShutdown() // Wait for shutdown to complete
535-
os.Remove(fileName)
536-
})
537-
538-
route := &Route{
539-
Method: http.MethodGet,
540-
Path: "/hello",
541-
RawResponseBody: "Squawk",
542-
ResponseStatusCode: http.StatusOK,
543-
}
544-
545-
err = p.Register(route)
546-
require.NoError(t, err, "error registering route")
547-
548-
_, err = p.Call(route.Method, route.Path)
549-
require.NoError(t, err, "error calling parrot")
550-
551-
require.Contains(t, logBuffer.String(), route.ID(), "expected log buffer to contain route call")
552-
}
553-
554551
func TestJSONLogger(t *testing.T) {
555552
t.Parallel()
556553

0 commit comments

Comments
 (0)