Skip to content

Commit 378e65a

Browse files
committed
Code cleanup and more tests
1 parent 207d9fe commit 378e65a

File tree

3 files changed

+75
-35
lines changed

3 files changed

+75
-35
lines changed

parrot/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,4 +32,4 @@ make bench # Benchmark
3232

3333
```sh
3434
make goreleaser # Uses goreleaser to build binaries and docker containers
35-
```
35+
```

parrot/parrot.go

Lines changed: 29 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ import (
2323
"github.com/rs/zerolog/hlog"
2424
)
2525

26+
const (
27+
healthRoute = "/health"
28+
routesRoute = "/routes"
29+
recordRoute = "/record"
30+
)
31+
2632
// Route holds information about the mock route configuration
2733
type Route struct {
2834
// Method is the HTTP method to match
@@ -215,10 +221,10 @@ func Wake(options ...ServerOption) (*Server, error) {
215221

216222
mux := http.NewServeMux()
217223
// TODO: Add a route to enable registering recorders
218-
mux.HandleFunc("/routes", p.registerHandler)
219-
mux.HandleFunc("/record", p.recordHandler)
220-
mux.HandleFunc("/health", p.healthHandler)
221-
mux.HandleFunc("/", p.routesHandler)
224+
mux.HandleFunc(routesRoute, p.routeHandler)
225+
mux.HandleFunc(recordRoute, p.recordHandler)
226+
mux.HandleFunc(healthRoute, p.healthHandler)
227+
mux.HandleFunc("/", p.dynamicHandler)
222228

223229
p.server = &http.Server{
224230
ReadHeaderTimeout: 5 * time.Second,
@@ -329,7 +335,7 @@ func (p *Server) Record(recorderURL string) error {
329335
if recorderURL == "" {
330336
return ErrNoRecorderURL
331337
}
332-
_, err := url.Parse(recorderURL)
338+
_, err := url.ParseRequestURI(recorderURL)
333339
if err != nil {
334340
return ErrInvalidRecorderURL
335341
}
@@ -390,25 +396,19 @@ func (p *Server) Routes() []*Route {
390396
return routes
391397
}
392398

393-
// registerHandler handles registering, unregistering, and querying routes
394-
func (p *Server) registerHandler(w http.ResponseWriter, r *http.Request) {
399+
// routeHandler handles registering, unregistering, and querying routes
400+
func (p *Server) routeHandler(w http.ResponseWriter, r *http.Request) {
395401
routesLogger := zerolog.Ctx(r.Context())
396402
if r.Method == http.MethodDelete {
397-
var routeRequest *Route
398-
if err := json.NewDecoder(r.Body).Decode(&routeRequest); err != nil {
403+
var route *Route
404+
if err := json.NewDecoder(r.Body).Decode(&route); err != nil {
399405
http.Error(w, "Invalid request body", http.StatusBadRequest)
400406
routesLogger.Debug().Err(err).Msg("Failed to decode request body")
401407
return
402408
}
403409
defer r.Body.Close()
404410

405-
if routeRequest.ID() == "" {
406-
http.Error(w, "Route ID required", http.StatusBadRequest)
407-
routesLogger.Debug().Msg("No Route ID provided")
408-
return
409-
}
410-
411-
err := p.Delete(routeRequest.ID())
411+
err := p.Delete(route.ID())
412412
if err != nil {
413413
http.Error(w, err.Error(), http.StatusBadRequest)
414414
routesLogger.Debug().Err(err).Msg("Failed to unregister route")
@@ -417,8 +417,8 @@ func (p *Server) registerHandler(w http.ResponseWriter, r *http.Request) {
417417

418418
w.WriteHeader(http.StatusNoContent)
419419
routesLogger.Info().
420-
Str("Route ID", routeRequest.ID()).
421-
Msg("Route unregistered")
420+
Str("Route ID", route.ID()).
421+
Msg("Route deleted")
422422
return
423423
}
424424

@@ -431,13 +431,6 @@ func (p *Server) registerHandler(w http.ResponseWriter, r *http.Request) {
431431
}
432432
defer r.Body.Close()
433433

434-
if route.Method == "" || route.Path == "" {
435-
err := errors.New("Method and path are required")
436-
http.Error(w, err.Error(), http.StatusBadRequest)
437-
routesLogger.Debug().Err(err).Msg("Method and path are required")
438-
return
439-
}
440-
441434
err := p.Register(route)
442435
if err != nil {
443436
http.Error(w, err.Error(), http.StatusBadRequest)
@@ -450,7 +443,8 @@ func (p *Server) registerHandler(w http.ResponseWriter, r *http.Request) {
450443
}
451444

452445
if r.Method == http.MethodGet {
453-
jsonRoutes, err := json.Marshal(p.Routes())
446+
routes := p.Routes()
447+
jsonRoutes, err := json.Marshal(routes)
454448
if err != nil {
455449
http.Error(w, "Failed to marshal routes", http.StatusInternalServerError)
456450
routesLogger.Debug().Err(err).Msg("Failed to marshal routes")
@@ -464,16 +458,16 @@ func (p *Server) registerHandler(w http.ResponseWriter, r *http.Request) {
464458
return
465459
}
466460

467-
routesLogger.Debug().Msg("Returned routes")
461+
routesLogger.Debug().Int("Count", len(routes)).Msg("Returned routes")
468462
return
469463
}
470464

471465
http.Error(w, "Invalid method", http.StatusMethodNotAllowed)
472466
routesLogger.Debug().Msg("Invalid method")
473467
}
474468

475-
// routesHandler handles all incoming requests and responds based on the registered routes.
476-
func (p *Server) routesHandler(w http.ResponseWriter, r *http.Request) {
469+
// dynamicHandler handles all incoming requests and responds based on the registered routes.
470+
func (p *Server) dynamicHandler(w http.ResponseWriter, r *http.Request) {
477471
p.routesMu.RLock()
478472
route, exists := p.routes[r.Method+":"+r.URL.Path]
479473
p.routesMu.RUnlock()
@@ -741,16 +735,19 @@ var pathRegex = regexp.MustCompile(`^\/[a-zA-Z0-9\-._~%!$&'()*+,;=:@\/]*$`)
741735

742736
func isValidPath(path string) bool {
743737
switch path {
744-
case "", "/", "//", "/register", "/health", "/.", "/..":
738+
case "", "/", "//", healthRoute, recordRoute, routesRoute, "/.", "/..":
745739
return false
746740
}
747741
if !strings.HasPrefix(path, "/") {
748742
return false
749743
}
750-
if strings.HasPrefix(path, "/register") {
744+
if strings.HasPrefix(path, recordRoute) {
745+
return false
746+
}
747+
if strings.HasPrefix(path, healthRoute) {
751748
return false
752749
}
753-
if strings.HasPrefix(path, "/health") {
750+
if strings.HasPrefix(path, routesRoute) {
754751
return false
755752
}
756753
return pathRegex.MatchString(path)

parrot/parrot_test.go

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ func TestIsValidPath(t *testing.T) {
170170
valid: true,
171171
},
172172
{
173-
name: "no register",
174-
paths: []string{"/register", "/register/", "/register//", "/register/other_stuff"},
173+
name: "no protected paths",
174+
paths: []string{healthRoute, routesRoute, recordRoute, fmt.Sprintf("%s/%s", routesRoute, "route-id"), fmt.Sprintf("%s/%s", healthRoute, "recorder-id"), fmt.Sprintf("%s/%s", recordRoute, "recorder-id")},
175175
valid: false,
176176
},
177177
{
@@ -343,6 +343,30 @@ func TestBadRegisterRoute(t *testing.T) {
343343
ResponseStatusCode: http.StatusOK,
344344
},
345345
},
346+
{
347+
name: "too many responses",
348+
err: ErrOnlyOneResponse,
349+
route: &Route{
350+
Method: http.MethodGet,
351+
Path: "/hello",
352+
ResponseBody: map[string]any{"message": "Squawk"},
353+
Handler: func(w http.ResponseWriter, r *http.Request) {
354+
w.WriteHeader(http.StatusOK)
355+
_, _ = w.Write([]byte("Squawk"))
356+
},
357+
ResponseStatusCode: http.StatusOK,
358+
},
359+
},
360+
{
361+
name: "bad JSON",
362+
err: ErrResponseMarshal,
363+
route: &Route{
364+
Method: http.MethodGet,
365+
Path: "/json",
366+
ResponseBody: map[string]any{"message": make(chan int)},
367+
ResponseStatusCode: http.StatusOK,
368+
},
369+
},
346370
}
347371

348372
for _, tc := range testCases {
@@ -356,6 +380,18 @@ func TestBadRegisterRoute(t *testing.T) {
356380
}
357381
}
358382

383+
func TestBadRecorder(t *testing.T) {
384+
t.Parallel()
385+
386+
p := newParrot(t)
387+
388+
err := p.Record("")
389+
require.ErrorIs(t, err, ErrNoRecorderURL, "expected error recording parrot")
390+
391+
err = p.Record("invalid url")
392+
require.ErrorIs(t, err, ErrInvalidRecorderURL, "expected error recording parrot")
393+
}
394+
359395
func TestUnregisteredRoute(t *testing.T) {
360396
t.Parallel()
361397

@@ -395,6 +431,10 @@ func TestDelete(t *testing.T) {
395431
resp, err = p.Call(route.Method, route.Path)
396432
require.NoError(t, err, "error calling parrot")
397433
assert.Equal(t, http.StatusNotFound, resp.StatusCode())
434+
435+
// Try to delete the route again
436+
err = p.Delete(route.ID())
437+
require.ErrorIs(t, err, ErrRouteNotFound, "expected error deleting route")
398438
}
399439

400440
func TestSaveLoad(t *testing.T) {
@@ -477,6 +517,9 @@ func TestShutDown(t *testing.T) {
477517
})
478518
require.ErrorIs(t, err, ErrServerShutdown, "expected error registering route after shutdown")
479519

520+
err = p.Delete("route-id")
521+
require.ErrorIs(t, err, ErrServerShutdown, "expected error deleting route after shutdown")
522+
480523
err = p.Shutdown(context.Background())
481524
require.ErrorIs(t, err, ErrServerShutdown, "expected error shutting down parrot after shutdown")
482525
}

0 commit comments

Comments
 (0)