@@ -15,6 +15,7 @@ import (
1515 "strconv"
1616 "strings"
1717 "sync"
18+ "sync/atomic"
1819 "time"
1920
2021 "github.com/go-chi/chi"
@@ -67,13 +68,12 @@ type Server struct {
6768 recordersMu sync.RWMutex
6869
6970 // Save and shutdown
70- shutDown bool
71+ shutDown atomic. Bool
7172 shutDownChan chan struct {}
7273 shutDownOnce sync.Once
7374 saveFileName string
7475
7576 // Logging
76- useCustomLogger bool
7777 logFileName string
7878 logFile * os.File
7979 logLevel zerolog.Level
@@ -113,33 +113,32 @@ func Wake(options ...ServerOption) (*Server, error) {
113113 }
114114 }
115115
116+ // Setup logger
116117 var err error
117118 p .logFile , err = os .Create (p .logFileName )
118119 if err != nil {
119120 return nil , fmt .Errorf ("failed to create log file: %w" , err )
120121 }
121122
122- if ! p .useCustomLogger { // Build default logger
123- var writers []io.Writer
123+ var writers []io.Writer
124124
125- zerolog .TimeFieldFormat = "2006-01-02T15:04:05.000"
126- if ! p .disableConsoleLogs {
127- if p .jsonLogs {
128- writers = append (writers , os .Stderr )
129- } else {
130- consoleOut := zerolog.ConsoleWriter {Out : os .Stderr , TimeFormat : "2006-01-02T15:04:05.000" }
131- writers = append (writers , consoleOut )
132- }
133- }
134-
135- if p .logFile != nil {
136- writers = append (writers , p .logFile )
125+ if ! p .disableConsoleLogs {
126+ if p .jsonLogs {
127+ writers = append (writers , os .Stderr )
128+ } else {
129+ consoleOut := zerolog.ConsoleWriter {Out : os .Stderr , TimeFormat : "2006-01-02T15:04:05.000" }
130+ writers = append (writers , consoleOut )
137131 }
132+ }
138133
139- multiWriter := zerolog . MultiLevelWriter ( writers ... )
140- p . log = zerolog . New ( multiWriter ). Level ( p . logLevel ). With (). Timestamp (). Logger ( )
134+ if p . logFile != nil {
135+ writers = append ( writers , p . logFile )
141136 }
142137
138+ multiWriter := zerolog .MultiLevelWriter (writers ... )
139+ p .log = zerolog .New (multiWriter ).Level (p .logLevel ).With ().Timestamp ().Logger ()
140+
141+ // Setup server
143142 listener , err := net .Listen ("tcp" , fmt .Sprintf (":%d" , p .port ))
144143 if err != nil {
145144 return nil , fmt .Errorf ("failed to start listener: %w" , err )
@@ -155,6 +154,7 @@ func Wake(options ...ServerOption) (*Server, error) {
155154 return nil , fmt .Errorf ("failed to parse port: %w" , err )
156155 }
157156
157+ // Initialize router
158158 p .router .Get (HealthRoute , p .healthHandlerGET )
159159
160160 p .router .Get (RoutesRoute , p .routesHandlerGET )
@@ -182,7 +182,7 @@ func Wake(options ...ServerOption) (*Server, error) {
182182// run starts the parrot server
183183func (p * Server ) run (listener net.Listener ) {
184184 defer func () {
185- p .shutDown = true
185+ p .shutDown . Store ( true )
186186 if err := p .save (); err != nil {
187187 p .log .Error ().Err (err ).Msg ("Failed to save routes" )
188188 }
@@ -241,7 +241,7 @@ func (p *Server) routeCallHandler(route *Route) http.HandlerFunc {
241241
242242// Healthy checks if the parrot server is healthy
243243func (p * Server ) Healthy () error {
244- if p .shutDown {
244+ if p .shutDown . Load () {
245245 return ErrServerShutdown
246246 }
247247
@@ -288,7 +288,7 @@ func (p *Server) healthHandlerGET(w http.ResponseWriter, r *http.Request) {
288288
289289// Shutdown gracefully shuts down the parrot server
290290func (p * Server ) Shutdown (ctx context.Context ) error {
291- if p .shutDown {
291+ if p .shutDown . Load () {
292292 return ErrServerShutdown
293293 }
294294
@@ -308,7 +308,7 @@ func (p *Server) Address() string {
308308
309309// Register adds a new route to the parrot
310310func (p * Server ) Register (route * Route ) error {
311- if p .shutDown {
311+ if p .shutDown . Load () {
312312 return ErrServerShutdown
313313 }
314314 if route == nil {
@@ -372,7 +372,7 @@ func (p *Server) routesHandlerPOST(w http.ResponseWriter, r *http.Request) {
372372
373373// Record registers a new recorder with the parrot. All incoming requests to the parrot will be sent to the recorder.
374374func (p * Server ) Record (recorderURL string ) error {
375- if p .shutDown {
375+ if p .shutDown . Load () {
376376 return ErrServerShutdown
377377 }
378378
@@ -426,7 +426,7 @@ func (p *Server) recorderHandlerPOST(w http.ResponseWriter, r *http.Request) {
426426
427427// Recorders returns the URLs of all registered recorders
428428func (p * Server ) Recorders () []string {
429- if p .shutDown {
429+ if p .shutDown . Load () {
430430 return nil
431431 }
432432
@@ -493,14 +493,14 @@ func (p *Server) routesHandlerDELETE(w http.ResponseWriter, r *http.Request) {
493493
494494// Call makes a request to the parrot server
495495func (p * Server ) Call (method , path string ) (* resty.Response , error ) {
496- if p .shutDown {
496+ if p .shutDown . Load () {
497497 return nil , ErrServerShutdown
498498 }
499499 return p .client .R ().Execute (method , "http://" + filepath .Join (p .Address (), path ))
500500}
501501
502502func (p * Server ) Routes () []* Route {
503- if p .shutDown {
503+ if p .shutDown . Load () {
504504 return nil
505505 }
506506
0 commit comments