From 6f89bc3163c0069a0f0deb512100d266de49bad6 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 6 Apr 2026 11:35:17 +0200 Subject: [PATCH 01/15] Canonicalize code --- .../error_response.go | 10 +- .../form_or_json.go | 16 +- .../gzip_pass_through.go | 82 +-- .../gzip_pass_through_test.go | 182 ++--- .../html_response.go | 36 +- .../html_response_test.go | 30 +- .../json_body_manual.go | 14 +- .../json_map_body.go | 20 +- .../json_slice_body.go | 20 +- .../request_response_mapping_test.go | 34 +- .../request_text_body.go | 32 +- .../validation_test.go | 12 +- _examples/advanced/dynamic_schema.go | 70 +- _examples/advanced/error_response.go | 10 +- _examples/advanced/gzip_pass_through.go | 82 +-- _examples/advanced/gzip_pass_through_test.go | 182 ++--- _examples/advanced/json_map_body.go | 20 +- _examples/advanced/json_slice_body.go | 20 +- .../advanced/request_response_mapping_test.go | 34 +- _examples/advanced/validation_test.go | 12 +- _examples/gingonic/main.go | 112 +-- _examples/jwtauth/main.go | 22 +- _examples/mount/main.go | 50 +- _examples/multi-api/main.go | 24 +- .../task-api/internal/domain/task/entity.go | 34 +- .../task-api/internal/domain/task/service.go | 14 +- .../internal/infra/nethttp/benchmark_test.go | 64 +- .../internal/infra/repository/task.go | 140 ++-- .../internal/infra/service/provider.go | 10 +- .../task-api/internal/usecase/finish_task.go | 8 +- .../task-api/internal/usecase/update_task.go | 10 +- _examples/task-api/pkg/graceful/shutdown.go | 58 +- chirouter/wrapper.go | 220 +++--- chirouter/wrapper_test.go | 249 ++++--- dev_test.go | 2 +- error.go | 124 ++-- error_test.go | 36 +- gorillamux/collector.go | 40 +- gorillamux/collector_test.go | 44 +- gorillamux/example_openapi_collector_test.go | 140 ++-- gzip/container.go | 128 ++-- jsonschema/validator.go | 116 +-- jsonschema/validator_test.go | 62 +- nethttp/handler.go | 184 ++--- nethttp/handler_test.go | 204 +++--- nethttp/openapi.go | 200 +++--- nethttp/options.go | 62 +- nethttp/wrap.go | 50 +- nethttp/wrap_test.go | 12 +- nethttp/wrapper.go | 12 +- openapi/collector.go | 588 ++++++++-------- openapi/collector_test.go | 378 +++++----- request.go | 24 +- request/decoder.go | 175 ++--- request/decoder_test.go | 528 +++++++------- request/factory.go | 266 +++---- request/factory_test.go | 256 +++---- request/file.go | 12 +- request/file_test.go | 24 +- request/jsonbody.go | 12 +- request/jsonbody_test.go | 71 +- request/middleware.go | 28 +- response/encoder.go | 666 +++++++++--------- response/encoder_test.go | 266 +++---- response/gzip/middleware.go | 214 +++--- response/gzip/middleware_test.go | 210 +++--- response/middleware.go | 8 +- response/validator.go | 8 +- resttest/client.go | 10 +- resttest/server.go | 14 +- resttest/server_test.go | 304 ++++---- route.go | 10 +- trait.go | 88 +-- validator.go | 80 +-- web/example_test.go | 40 +- web/service.go | 132 ++-- web/service_test.go | 28 +- 77 files changed, 3895 insertions(+), 3894 deletions(-) diff --git a/_examples/advanced-generic-openapi31/error_response.go b/_examples/advanced-generic-openapi31/error_response.go index 0b3dae8..59107bd 100644 --- a/_examples/advanced-generic-openapi31/error_response.go +++ b/_examples/advanced-generic-openapi31/error_response.go @@ -11,11 +11,6 @@ import ( "github.com/swaggest/usecase/status" ) -type customErr struct { - Message string `json:"msg"` - Details map[string]interface{} `json:"details,omitempty"` -} - func errorResponse() usecase.Interactor { type errType struct { Type string `query:"type" enum:"ok,invalid_argument,conflict" required:"true"` @@ -56,3 +51,8 @@ type anotherErr struct { func (anotherErr) Error() string { return "foo happened" } + +type customErr struct { + Message string `json:"msg"` + Details map[string]interface{} `json:"details,omitempty"` +} diff --git a/_examples/advanced-generic-openapi31/form_or_json.go b/_examples/advanced-generic-openapi31/form_or_json.go index b450c00..c8ec1a2 100644 --- a/_examples/advanced-generic-openapi31/form_or_json.go +++ b/_examples/advanced-generic-openapi31/form_or_json.go @@ -6,14 +6,6 @@ import ( "github.com/swaggest/usecase" ) -type formOrJSONInput struct { - Field1 string `json:"field1" formData:"field1" required:"true"` - Field2 int `json:"field2" formData:"field2" required:"true"` - Field3 string `path:"path" required:"true"` -} - -func (formOrJSONInput) ForceJSONRequestBody() {} - func formOrJSON() usecase.Interactor { type formOrJSONOutput struct { F1 string `json:"f1"` @@ -34,3 +26,11 @@ func formOrJSON() usecase.Interactor { return u } + +type formOrJSONInput struct { + Field1 string `json:"field1" formData:"field1" required:"true"` + Field2 int `json:"field2" formData:"field2" required:"true"` + Field3 string `path:"path" required:"true"` +} + +func (formOrJSONInput) ForceJSONRequestBody() {} diff --git a/_examples/advanced-generic-openapi31/gzip_pass_through.go b/_examples/advanced-generic-openapi31/gzip_pass_through.go index 5d9aa26..e7f573a 100644 --- a/_examples/advanced-generic-openapi31/gzip_pass_through.go +++ b/_examples/advanced-generic-openapi31/gzip_pass_through.go @@ -9,47 +9,6 @@ import ( "github.com/swaggest/usecase" ) -type gzipPassThroughInput struct { - PlainStruct bool `query:"plainStruct" description:"Output plain structure instead of gzip container."` - CountItems bool `query:"countItems" description:"Invokes internal decoding of compressed data."` -} - -// gzipPassThroughOutput defers data to an accessor function instead of using struct directly. -// This is necessary to allow containers that can data in binary wire-friendly format. -type gzipPassThroughOutput interface { - // Data should be accessed though an accessor to allow container interface. - gzipPassThroughStruct() gzipPassThroughStruct -} - -// gzipPassThroughStruct represents the actual structure that is held in the container -// and implements gzipPassThroughOutput to be directly useful in output. -type gzipPassThroughStruct struct { - Header string `header:"X-Header" json:"-"` - ID int `json:"id"` - Text []string `json:"text"` -} - -func (d gzipPassThroughStruct) gzipPassThroughStruct() gzipPassThroughStruct { - return d -} - -// gzipPassThroughContainer is wrapping gzip.JSONContainer and implements gzipPassThroughOutput. -type gzipPassThroughContainer struct { - Header string `header:"X-Header" json:"-"` - gzip.JSONContainer -} - -func (dc gzipPassThroughContainer) gzipPassThroughStruct() gzipPassThroughStruct { - var p gzipPassThroughStruct - - err := dc.UnpackJSON(&p) - if err != nil { - panic(err) - } - - return p -} - func directGzip() usecase.Interactor { // Prepare moderately big JSON, resulting JSON payload is ~67KB. rawData := gzipPassThroughStruct{ @@ -91,3 +50,44 @@ func directGzip() usecase.Interactor { return u } + +// gzipPassThroughContainer is wrapping gzip.JSONContainer and implements gzipPassThroughOutput. +type gzipPassThroughContainer struct { + Header string `header:"X-Header" json:"-"` + gzip.JSONContainer +} + +func (dc gzipPassThroughContainer) gzipPassThroughStruct() gzipPassThroughStruct { + var p gzipPassThroughStruct + + err := dc.UnpackJSON(&p) + if err != nil { + panic(err) + } + + return p +} + +type gzipPassThroughInput struct { + PlainStruct bool `query:"plainStruct" description:"Output plain structure instead of gzip container."` + CountItems bool `query:"countItems" description:"Invokes internal decoding of compressed data."` +} + +// gzipPassThroughOutput defers data to an accessor function instead of using struct directly. +// This is necessary to allow containers that can data in binary wire-friendly format. +type gzipPassThroughOutput interface { + // Data should be accessed though an accessor to allow container interface. + gzipPassThroughStruct() gzipPassThroughStruct +} + +// gzipPassThroughStruct represents the actual structure that is held in the container +// and implements gzipPassThroughOutput to be directly useful in output. +type gzipPassThroughStruct struct { + Header string `header:"X-Header" json:"-"` + ID int `json:"id"` + Text []string `json:"text"` +} + +func (d gzipPassThroughStruct) gzipPassThroughStruct() gzipPassThroughStruct { + return d +} diff --git a/_examples/advanced-generic-openapi31/gzip_pass_through_test.go b/_examples/advanced-generic-openapi31/gzip_pass_through_test.go index 7bb822f..e20a730 100644 --- a/_examples/advanced-generic-openapi31/gzip_pass_through_test.go +++ b/_examples/advanced-generic-openapi31/gzip_pass_through_test.go @@ -14,85 +14,28 @@ import ( "github.com/valyala/fasthttp" ) -func Test_directGzip(t *testing.T) { +// Direct gzip enabled. +// Benchmark_directGzip-4 48037 24474 ns/op 624 B:rcvd/op 103 B:sent/op 40860 rps 3499 B/op 36 allocs/op. +// Benchmark_directGzip-4 45792 26102 ns/op 624 B:rcvd/op 103 B:sent/op 38278 rps 3063 B/op 33 allocs/op. +func Benchmark_directGzip(b *testing.B) { r := NewRouter() - req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through", nil) - require.NoError(t, err) - - req.Header.Set("Accept-Encoding", "gzip") - - rw := httptest.NewRecorder() - - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "330epditz19z", rw.Header().Get("Etag")) - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Equal(t, "abc", rw.Header().Get("X-Header")) - assert.Less(t, len(rw.Body.Bytes()), 500) -} - -func Test_directGzip_HEAD(t *testing.T) { - srv := httptest.NewServer(NewRouter()) + srv := httptest.NewServer(r) defer srv.Close() - req, err := http.NewRequest(http.MethodHead, srv.URL+"/gzip-pass-through", nil) - require.NoError(t, err) - - req.Header.Set("Accept-Encoding", "gzip") - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - body, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.NoError(t, resp.Body.Close()) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "330epditz19z", resp.Header.Get("Etag")) - assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) - assert.Equal(t, "abc", resp.Header.Get("X-Header")) - assert.Empty(t, body) -} - -func Test_noDirectGzip(t *testing.T) { - r := NewRouter() - - req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through?plainStruct=1", nil) - require.NoError(t, err) - - req.Header.Set("Accept-Encoding", "gzip") - - rw := httptest.NewRecorder() - - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "", rw.Header().Get("Etag")) // No ETag for dynamic compression. - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Equal(t, "cba", rw.Header().Get("X-Header")) - assert.Less(t, len(rw.Body.Bytes()), 1000) // Worse compression for better speed. -} - -func Test_directGzip_perf(t *testing.T) { - res := testing.Benchmark(Benchmark_directGzip) - - if httptestbench.RaceDetectorEnabled { - assert.Less(t, res.Extra["B:rcvd/op"], 700.0) - assert.Less(t, res.Extra["B:sent/op"], 104.0) - assert.Less(t, res.AllocsPerOp(), int64(65)) - assert.Less(t, res.AllocedBytesPerOp(), int64(8800)) - } else { - assert.Less(t, res.Extra["B:rcvd/op"], 700.0) - assert.Less(t, res.Extra["B:sent/op"], 104.0) - assert.Less(t, res.AllocsPerOp(), int64(45)) - assert.Less(t, res.AllocedBytesPerOp(), int64(4200)) - } + httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { + req.Header.Set("Accept-Encoding", "gzip") + req.SetRequestURI(srv.URL + "/gzip-pass-through") + }, func(i int, resp *fasthttp.Response) bool { + return resp.StatusCode() == http.StatusOK + }) } -// Direct gzip enabled. -// Benchmark_directGzip-4 48037 24474 ns/op 624 B:rcvd/op 103 B:sent/op 40860 rps 3499 B/op 36 allocs/op. -// Benchmark_directGzip-4 45792 26102 ns/op 624 B:rcvd/op 103 B:sent/op 38278 rps 3063 B/op 33 allocs/op. -func Benchmark_directGzip(b *testing.B) { +// Direct gzip enabled, payload is unmarshaled and decompressed for every request in usecase body. +// Unmarshaling large JSON payloads can be much more expensive than explicitly creating them from Go values. +// Benchmark_directGzip_decode-4 2018 499755 ns/op 624 B:rcvd/op 116 B:sent/op 2001 rps 403967 B/op 496 allocs/op. +// Benchmark_directGzip_decode-4 2085 526586 ns/op 624 B:rcvd/op 116 B:sent/op 1899 rps 403600 B/op 493 allocs/op. +func Benchmark_directGzip_decode(b *testing.B) { r := NewRouter() srv := httptest.NewServer(r) @@ -100,7 +43,7 @@ func Benchmark_directGzip(b *testing.B) { httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.Header.Set("Accept-Encoding", "gzip") - req.SetRequestURI(srv.URL + "/gzip-pass-through") + req.SetRequestURI(srv.URL + "/gzip-pass-through?countItems=1") }, func(i int, resp *fasthttp.Response) bool { return resp.StatusCode() == http.StatusOK }) @@ -142,11 +85,10 @@ func Benchmark_noDirectGzip(b *testing.B) { }) } -// Direct gzip enabled, payload is unmarshaled and decompressed for every request in usecase body. -// Unmarshaling large JSON payloads can be much more expensive than explicitly creating them from Go values. -// Benchmark_directGzip_decode-4 2018 499755 ns/op 624 B:rcvd/op 116 B:sent/op 2001 rps 403967 B/op 496 allocs/op. -// Benchmark_directGzip_decode-4 2085 526586 ns/op 624 B:rcvd/op 116 B:sent/op 1899 rps 403600 B/op 493 allocs/op. -func Benchmark_directGzip_decode(b *testing.B) { +// Direct gzip disabled. +// Benchmark_noDirectGzip_decode-4 7603 142173 ns/op 1029 B:rcvd/op 130 B:sent/op 7034 rps 5122 B/op 43 allocs/op. +// Benchmark_noDirectGzip_decode-4 5836 198000 ns/op 1029 B:rcvd/op 130 B:sent/op 5051 rps 5371 B/op 42 allocs/op. +func Benchmark_noDirectGzip_decode(b *testing.B) { r := NewRouter() srv := httptest.NewServer(r) @@ -154,25 +96,83 @@ func Benchmark_directGzip_decode(b *testing.B) { httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.Header.Set("Accept-Encoding", "gzip") - req.SetRequestURI(srv.URL + "/gzip-pass-through?countItems=1") + req.SetRequestURI(srv.URL + "/gzip-pass-through?plainStruct=1&countItems=1") }, func(i int, resp *fasthttp.Response) bool { return resp.StatusCode() == http.StatusOK }) } -// Direct gzip disabled. -// Benchmark_noDirectGzip_decode-4 7603 142173 ns/op 1029 B:rcvd/op 130 B:sent/op 7034 rps 5122 B/op 43 allocs/op. -// Benchmark_noDirectGzip_decode-4 5836 198000 ns/op 1029 B:rcvd/op 130 B:sent/op 5051 rps 5371 B/op 42 allocs/op. -func Benchmark_noDirectGzip_decode(b *testing.B) { +func Test_directGzip(t *testing.T) { r := NewRouter() - srv := httptest.NewServer(r) + req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through", nil) + require.NoError(t, err) + + req.Header.Set("Accept-Encoding", "gzip") + + rw := httptest.NewRecorder() + + r.ServeHTTP(rw, req) + assert.Equal(t, http.StatusOK, rw.Code) + assert.Equal(t, "330epditz19z", rw.Header().Get("Etag")) + assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) + assert.Equal(t, "abc", rw.Header().Get("X-Header")) + assert.Less(t, len(rw.Body.Bytes()), 500) +} + +func Test_directGzip_HEAD(t *testing.T) { + srv := httptest.NewServer(NewRouter()) defer srv.Close() - httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { - req.Header.Set("Accept-Encoding", "gzip") - req.SetRequestURI(srv.URL + "/gzip-pass-through?plainStruct=1&countItems=1") - }, func(i int, resp *fasthttp.Response) bool { - return resp.StatusCode() == http.StatusOK - }) + req, err := http.NewRequest(http.MethodHead, srv.URL+"/gzip-pass-through", nil) + require.NoError(t, err) + + req.Header.Set("Accept-Encoding", "gzip") + + resp, err := http.DefaultTransport.RoundTrip(req) + require.NoError(t, err) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.NoError(t, resp.Body.Close()) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "330epditz19z", resp.Header.Get("Etag")) + assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) + assert.Equal(t, "abc", resp.Header.Get("X-Header")) + assert.Empty(t, body) +} + +func Test_directGzip_perf(t *testing.T) { + res := testing.Benchmark(Benchmark_directGzip) + + if httptestbench.RaceDetectorEnabled { + assert.Less(t, res.Extra["B:rcvd/op"], 700.0) + assert.Less(t, res.Extra["B:sent/op"], 104.0) + assert.Less(t, res.AllocsPerOp(), int64(65)) + assert.Less(t, res.AllocedBytesPerOp(), int64(8800)) + } else { + assert.Less(t, res.Extra["B:rcvd/op"], 700.0) + assert.Less(t, res.Extra["B:sent/op"], 104.0) + assert.Less(t, res.AllocsPerOp(), int64(45)) + assert.Less(t, res.AllocedBytesPerOp(), int64(4200)) + } +} + +func Test_noDirectGzip(t *testing.T) { + r := NewRouter() + + req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through?plainStruct=1", nil) + require.NoError(t, err) + + req.Header.Set("Accept-Encoding", "gzip") + + rw := httptest.NewRecorder() + + r.ServeHTTP(rw, req) + assert.Equal(t, http.StatusOK, rw.Code) + assert.Equal(t, "", rw.Header().Get("Etag")) // No ETag for dynamic compression. + assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) + assert.Equal(t, "cba", rw.Header().Get("X-Header")) + assert.Less(t, len(rw.Body.Bytes()), 1000) // Worse compression for better speed. } diff --git a/_examples/advanced-generic-openapi31/html_response.go b/_examples/advanced-generic-openapi31/html_response.go index bdf6242..79e5b03 100644 --- a/_examples/advanced-generic-openapi31/html_response.go +++ b/_examples/advanced-generic-openapi31/html_response.go @@ -10,24 +10,6 @@ import ( "github.com/swaggest/usecase" ) -type htmlResponseOutput struct { - ID int - Filter string - Title string - Items []string - AntiHeader bool `header:"X-Anti-Header"` - - writer io.Writer -} - -func (o *htmlResponseOutput) SetWriter(w io.Writer) { - o.writer = w -} - -func (o *htmlResponseOutput) Render(tmpl *template.Template) error { - return tmpl.Execute(o.writer, o) -} - func htmlResponse() usecase.Interactor { type htmlResponseInput struct { ID int `path:"id"` @@ -68,3 +50,21 @@ func htmlResponse() usecase.Interactor { return u } + +type htmlResponseOutput struct { + ID int + Filter string + Title string + Items []string + AntiHeader bool `header:"X-Anti-Header"` + + writer io.Writer +} + +func (o *htmlResponseOutput) Render(tmpl *template.Template) error { + return tmpl.Execute(o.writer, o) +} + +func (o *htmlResponseOutput) SetWriter(w io.Writer) { + o.writer = w +} diff --git a/_examples/advanced-generic-openapi31/html_response_test.go b/_examples/advanced-generic-openapi31/html_response_test.go index 6a15646..375d7ad 100644 --- a/_examples/advanced-generic-openapi31/html_response_test.go +++ b/_examples/advanced-generic-openapi31/html_response_test.go @@ -12,6 +12,21 @@ import ( "github.com/valyala/fasthttp" ) +// Benchmark_htmlResponse-12 89209 12348 ns/op 0.3801 50%:ms 1.119 90%:ms 2.553 99%:ms 3.877 99.9%:ms 370.0 B:rcvd/op 108.0 B:sent/op 80973 rps 8279 B/op 144 allocs/op. +func Benchmark_htmlResponse(b *testing.B) { + r := NewRouter() + + srv := httptest.NewServer(r) + defer srv.Close() + + httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { + req.SetRequestURI(srv.URL + "/html-response/123?filter=feel") + req.Header.Set("X-Header", "true") + }, func(i int, resp *fasthttp.Response) bool { + return resp.StatusCode() == http.StatusOK + }) +} + func Test_htmlResponse(t *testing.T) { r := NewRouter() @@ -61,18 +76,3 @@ func Test_htmlResponse_HEAD(t *testing.T) { assert.Equal(t, "text/html", resp.Header.Get("Content-Type")) assert.Empty(t, body) } - -// Benchmark_htmlResponse-12 89209 12348 ns/op 0.3801 50%:ms 1.119 90%:ms 2.553 99%:ms 3.877 99.9%:ms 370.0 B:rcvd/op 108.0 B:sent/op 80973 rps 8279 B/op 144 allocs/op. -func Benchmark_htmlResponse(b *testing.B) { - r := NewRouter() - - srv := httptest.NewServer(r) - defer srv.Close() - - httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { - req.SetRequestURI(srv.URL + "/html-response/123?filter=feel") - req.Header.Set("X-Header", "true") - }, func(i int, resp *fasthttp.Response) bool { - return resp.StatusCode() == http.StatusOK - }) -} diff --git a/_examples/advanced-generic-openapi31/json_body_manual.go b/_examples/advanced-generic-openapi31/json_body_manual.go index c16dd9f..4c90ea8 100644 --- a/_examples/advanced-generic-openapi31/json_body_manual.go +++ b/_examples/advanced-generic-openapi31/json_body_manual.go @@ -17,6 +17,13 @@ import ( "github.com/swaggest/usecase" ) +type JSONPayload struct { + ID int `json:"id"` + Name string `json:"name"` +} + +var _ request.Loader = &inputWithJSON{} + func jsonBodyManual() usecase.Interactor { type outputWithJSON struct { Header string `json:"inHeader"` @@ -41,11 +48,6 @@ func jsonBodyManual() usecase.Interactor { return u } -type JSONPayload struct { - ID int `json:"id"` - Name string `json:"name"` -} - type inputWithJSON struct { Header string `header:"X-Header" description:"Simple scalar value in header."` Query jsonschema.Date `query:"in_query" description:"Simple scalar value in query."` @@ -54,8 +56,6 @@ type inputWithJSON struct { JSONPayload } -var _ request.Loader = &inputWithJSON{} - func (i *inputWithJSON) LoadFromHTTPRequest(r *http.Request) (err error) { defer func() { if err := r.Body.Close(); err != nil { diff --git a/_examples/advanced-generic-openapi31/json_map_body.go b/_examples/advanced-generic-openapi31/json_map_body.go index fdac060..9d5955e 100644 --- a/_examples/advanced-generic-openapi31/json_map_body.go +++ b/_examples/advanced-generic-openapi31/json_map_body.go @@ -11,16 +11,6 @@ import ( type JSONMapPayload map[string]float64 -type jsonMapReq struct { - Header string `header:"X-Header" description:"Simple scalar value in header."` - Query int `query:"in_query" description:"Simple scalar value in query."` - JSONMapPayload -} - -func (j *jsonMapReq) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &j.JSONMapPayload) -} - func jsonMapBody() usecase.Interactor { type jsonOutput struct { Header string `json:"inHeader"` @@ -41,3 +31,13 @@ func jsonMapBody() usecase.Interactor { return u } + +type jsonMapReq struct { + Header string `header:"X-Header" description:"Simple scalar value in header."` + Query int `query:"in_query" description:"Simple scalar value in query."` + JSONMapPayload +} + +func (j *jsonMapReq) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &j.JSONMapPayload) +} diff --git a/_examples/advanced-generic-openapi31/json_slice_body.go b/_examples/advanced-generic-openapi31/json_slice_body.go index 13be45c..f5c97c3 100644 --- a/_examples/advanced-generic-openapi31/json_slice_body.go +++ b/_examples/advanced-generic-openapi31/json_slice_body.go @@ -12,16 +12,6 @@ import ( // JSONSlicePayload is an example non-scalar type without `json` tags. type JSONSlicePayload []int -type jsonSliceReq struct { - Header string `header:"X-Header" description:"Simple scalar value in header."` - Query int `query:"in_query" description:"Simple scalar value in query."` - JSONSlicePayload -} - -func (j *jsonSliceReq) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &j.JSONSlicePayload) -} - func jsonSliceBody() usecase.Interactor { type jsonOutput struct { Header string `json:"inHeader"` @@ -42,3 +32,13 @@ func jsonSliceBody() usecase.Interactor { return u } + +type jsonSliceReq struct { + Header string `header:"X-Header" description:"Simple scalar value in header."` + Query int `query:"in_query" description:"Simple scalar value in query."` + JSONSlicePayload +} + +func (j *jsonSliceReq) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &j.JSONSlicePayload) +} diff --git a/_examples/advanced-generic-openapi31/request_response_mapping_test.go b/_examples/advanced-generic-openapi31/request_response_mapping_test.go index 0c85192..421dd82 100644 --- a/_examples/advanced-generic-openapi31/request_response_mapping_test.go +++ b/_examples/advanced-generic-openapi31/request_response_mapping_test.go @@ -15,6 +15,23 @@ import ( "github.com/valyala/fasthttp" ) +func Benchmark_requestResponseMapping(b *testing.B) { + r := NewRouter() + + srv := httptest.NewServer(r) + defer srv.Close() + + httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { + req.Header.SetMethod(http.MethodPost) + req.SetRequestURI(srv.URL + "/req-resp-mapping") + req.Header.Set("X-Header", "abc") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBody([]byte(`val2=3`)) + }, func(i int, resp *fasthttp.Response) bool { + return resp.StatusCode() == http.StatusNoContent + }) +} + func Test_requestResponseMapping(t *testing.T) { r := NewRouter() @@ -40,20 +57,3 @@ func Test_requestResponseMapping(t *testing.T) { assert.Equal(t, "abc", resp.Header.Get("X-Value-1")) assert.Equal(t, "3", resp.Header.Get("X-Value-2")) } - -func Benchmark_requestResponseMapping(b *testing.B) { - r := NewRouter() - - srv := httptest.NewServer(r) - defer srv.Close() - - httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { - req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(srv.URL + "/req-resp-mapping") - req.Header.Set("X-Header", "abc") - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBody([]byte(`val2=3`)) - }, func(i int, resp *fasthttp.Response) bool { - return resp.StatusCode() == http.StatusNoContent - }) -} diff --git a/_examples/advanced-generic-openapi31/request_text_body.go b/_examples/advanced-generic-openapi31/request_text_body.go index 38b8240..2ac2fb2 100644 --- a/_examples/advanced-generic-openapi31/request_text_body.go +++ b/_examples/advanced-generic-openapi31/request_text_body.go @@ -8,22 +8,6 @@ import ( "github.com/swaggest/usecase" ) -type textReqBodyInput struct { - Path string `path:"path"` - Query int `query:"query"` - text []byte - err error -} - -func (c *textReqBodyInput) SetRequest(r *http.Request) { - c.text, c.err = io.ReadAll(r.Body) - clErr := r.Body.Close() - - if c.err == nil { - c.err = clErr - } -} - func textReqBody() usecase.Interactor { type output struct { Path string `json:"path"` @@ -67,3 +51,19 @@ func textReqBodyPtr() usecase.Interactor { return u } + +type textReqBodyInput struct { + Path string `path:"path"` + Query int `query:"query"` + text []byte + err error +} + +func (c *textReqBodyInput) SetRequest(r *http.Request) { + c.text, c.err = io.ReadAll(r.Body) + clErr := r.Body.Close() + + if c.err == nil { + c.err = clErr + } +} diff --git a/_examples/advanced-generic-openapi31/validation_test.go b/_examples/advanced-generic-openapi31/validation_test.go index 227ea0f..442b81b 100644 --- a/_examples/advanced-generic-openapi31/validation_test.go +++ b/_examples/advanced-generic-openapi31/validation_test.go @@ -11,9 +11,7 @@ import ( "github.com/valyala/fasthttp" ) -// Benchmark_validation-4 18979 53012 ns/op 197 B:rcvd/op 170 B:sent/op 18861 rps 14817 B/op 131 allocs/op. -// Benchmark_validation-4 17665 58243 ns/op 177 B:rcvd/op 170 B:sent/op 17161 rps 16349 B/op 132 allocs/op. -func Benchmark_validation(b *testing.B) { +func Benchmark_noValidation(b *testing.B) { r := NewRouter() srv := httptest.NewServer(r) @@ -21,7 +19,7 @@ func Benchmark_validation(b *testing.B) { httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(srv.URL + "/validation?q=true") + req.SetRequestURI(srv.URL + "/no-validation?q=true") req.Header.Set("X-Input", "12") req.Header.Set("Content-Type", "application/json") req.SetBody([]byte(`{"data":{"value":"abc"}}`)) @@ -30,7 +28,9 @@ func Benchmark_validation(b *testing.B) { }) } -func Benchmark_noValidation(b *testing.B) { +// Benchmark_validation-4 18979 53012 ns/op 197 B:rcvd/op 170 B:sent/op 18861 rps 14817 B/op 131 allocs/op. +// Benchmark_validation-4 17665 58243 ns/op 177 B:rcvd/op 170 B:sent/op 17161 rps 16349 B/op 132 allocs/op. +func Benchmark_validation(b *testing.B) { r := NewRouter() srv := httptest.NewServer(r) @@ -38,7 +38,7 @@ func Benchmark_noValidation(b *testing.B) { httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(srv.URL + "/no-validation?q=true") + req.SetRequestURI(srv.URL + "/validation?q=true") req.Header.Set("X-Input", "12") req.Header.Set("Content-Type", "application/json") req.SetBody([]byte(`{"data":{"value":"abc"}}`)) diff --git a/_examples/advanced/dynamic_schema.go b/_examples/advanced/dynamic_schema.go index bcf574f..18bc97c 100644 --- a/_examples/advanced/dynamic_schema.go +++ b/_examples/advanced/dynamic_schema.go @@ -13,41 +13,6 @@ import ( "github.com/swaggest/usecase/status" ) -type dynamicInput struct { - jsonschema.Struct - request.EmbeddedSetter - - // Type is a static field example. - Type string `query:"type"` -} - -type dynamicOutput struct { - // Embedded jsonschema.Struct exposes dynamic fields for documentation. - jsonschema.Struct - - jsonFields map[string]interface{} - headerFields map[string]string - - // Status is a static field example. - Status string `json:"status"` -} - -func (o dynamicOutput) SetupResponseHeader(h http.Header) { - for k, v := range o.headerFields { - h.Set(k, v) - } -} - -func (o dynamicOutput) MarshalJSON() ([]byte, error) { - if o.jsonFields == nil { - o.jsonFields = map[string]interface{}{} - } - - o.jsonFields["status"] = o.Status - - return json.Marshal(o.jsonFields) -} - func dynamicSchema() usecase.Interactor { dynIn := dynamicInput{} dynIn.DefName = "DynIn123" @@ -94,3 +59,38 @@ func dynamicSchema() usecase.Interactor { return u } + +type dynamicInput struct { + jsonschema.Struct + request.EmbeddedSetter + + // Type is a static field example. + Type string `query:"type"` +} + +type dynamicOutput struct { + // Embedded jsonschema.Struct exposes dynamic fields for documentation. + jsonschema.Struct + + jsonFields map[string]interface{} + headerFields map[string]string + + // Status is a static field example. + Status string `json:"status"` +} + +func (o dynamicOutput) MarshalJSON() ([]byte, error) { + if o.jsonFields == nil { + o.jsonFields = map[string]interface{}{} + } + + o.jsonFields["status"] = o.Status + + return json.Marshal(o.jsonFields) +} + +func (o dynamicOutput) SetupResponseHeader(h http.Header) { + for k, v := range o.headerFields { + h.Set(k, v) + } +} diff --git a/_examples/advanced/error_response.go b/_examples/advanced/error_response.go index f872d39..cda03da 100644 --- a/_examples/advanced/error_response.go +++ b/_examples/advanced/error_response.go @@ -9,11 +9,6 @@ import ( "github.com/swaggest/usecase/status" ) -type customErr struct { - Message string `json:"msg"` - Details map[string]interface{} `json:"details,omitempty"` -} - func errorResponse() usecase.Interactor { type errType struct { Type string `query:"type" enum:"ok,invalid_argument,conflict" required:"true"` @@ -57,3 +52,8 @@ type anotherErr struct { func (anotherErr) Error() string { return "foo happened" } + +type customErr struct { + Message string `json:"msg"` + Details map[string]interface{} `json:"details,omitempty"` +} diff --git a/_examples/advanced/gzip_pass_through.go b/_examples/advanced/gzip_pass_through.go index 8611277..cb6765e 100644 --- a/_examples/advanced/gzip_pass_through.go +++ b/_examples/advanced/gzip_pass_through.go @@ -7,47 +7,6 @@ import ( "github.com/swaggest/usecase" ) -type gzipPassThroughInput struct { - PlainStruct bool `query:"plainStruct" description:"Output plain structure instead of gzip container."` - CountItems bool `query:"countItems" description:"Invokes internal decoding of compressed data."` -} - -// gzipPassThroughOutput defers data to an accessor function instead of using struct directly. -// This is necessary to allow containers that can data in binary wire-friendly format. -type gzipPassThroughOutput interface { - // Data should be accessed though an accessor to allow container interface. - gzipPassThroughStruct() gzipPassThroughStruct -} - -// gzipPassThroughStruct represents the actual structure that is held in the container -// and implements gzipPassThroughOutput to be directly useful in output. -type gzipPassThroughStruct struct { - Header string `header:"X-Header" json:"-"` - ID int `json:"id"` - Text []string `json:"text"` -} - -func (d gzipPassThroughStruct) gzipPassThroughStruct() gzipPassThroughStruct { - return d -} - -// gzipPassThroughContainer is wrapping gzip.JSONContainer and implements gzipPassThroughOutput. -type gzipPassThroughContainer struct { - Header string `header:"X-Header" json:"-"` - gzip.JSONContainer -} - -func (dc gzipPassThroughContainer) gzipPassThroughStruct() gzipPassThroughStruct { - var p gzipPassThroughStruct - - err := dc.UnpackJSON(&p) - if err != nil { - panic(err) - } - - return p -} - func directGzip() usecase.Interactor { // Prepare moderately big JSON, resulting JSON payload is ~67KB. rawData := gzipPassThroughStruct{ @@ -93,3 +52,44 @@ func directGzip() usecase.Interactor { return u } + +// gzipPassThroughContainer is wrapping gzip.JSONContainer and implements gzipPassThroughOutput. +type gzipPassThroughContainer struct { + Header string `header:"X-Header" json:"-"` + gzip.JSONContainer +} + +func (dc gzipPassThroughContainer) gzipPassThroughStruct() gzipPassThroughStruct { + var p gzipPassThroughStruct + + err := dc.UnpackJSON(&p) + if err != nil { + panic(err) + } + + return p +} + +type gzipPassThroughInput struct { + PlainStruct bool `query:"plainStruct" description:"Output plain structure instead of gzip container."` + CountItems bool `query:"countItems" description:"Invokes internal decoding of compressed data."` +} + +// gzipPassThroughOutput defers data to an accessor function instead of using struct directly. +// This is necessary to allow containers that can data in binary wire-friendly format. +type gzipPassThroughOutput interface { + // Data should be accessed though an accessor to allow container interface. + gzipPassThroughStruct() gzipPassThroughStruct +} + +// gzipPassThroughStruct represents the actual structure that is held in the container +// and implements gzipPassThroughOutput to be directly useful in output. +type gzipPassThroughStruct struct { + Header string `header:"X-Header" json:"-"` + ID int `json:"id"` + Text []string `json:"text"` +} + +func (d gzipPassThroughStruct) gzipPassThroughStruct() gzipPassThroughStruct { + return d +} diff --git a/_examples/advanced/gzip_pass_through_test.go b/_examples/advanced/gzip_pass_through_test.go index f33ad81..bdfe2e7 100644 --- a/_examples/advanced/gzip_pass_through_test.go +++ b/_examples/advanced/gzip_pass_through_test.go @@ -12,85 +12,28 @@ import ( "github.com/valyala/fasthttp" ) -func Test_directGzip(t *testing.T) { +// Direct gzip enabled. +// Benchmark_directGzip-4 48037 24474 ns/op 624 B:rcvd/op 103 B:sent/op 40860 rps 3499 B/op 36 allocs/op. +// Benchmark_directGzip-4 45792 26102 ns/op 624 B:rcvd/op 103 B:sent/op 38278 rps 3063 B/op 33 allocs/op. +func Benchmark_directGzip(b *testing.B) { r := NewRouter() - req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through", nil) - require.NoError(t, err) - - req.Header.Set("Accept-Encoding", "gzip") - - rw := httptest.NewRecorder() - - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "330epditz19z", rw.Header().Get("Etag")) - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Equal(t, "abc", rw.Header().Get("X-Header")) - assert.Less(t, len(rw.Body.Bytes()), 500) -} - -func Test_directGzip_HEAD(t *testing.T) { - srv := httptest.NewServer(NewRouter()) + srv := httptest.NewServer(r) defer srv.Close() - req, err := http.NewRequest(http.MethodHead, srv.URL+"/gzip-pass-through", nil) - require.NoError(t, err) - - req.Header.Set("Accept-Encoding", "gzip") - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - body, err := io.ReadAll(resp.Body) - assert.NoError(t, err) - assert.NoError(t, resp.Body.Close()) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, "330epditz19z", resp.Header.Get("Etag")) - assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) - assert.Equal(t, "abc", resp.Header.Get("X-Header")) - assert.Empty(t, body) -} - -func Test_noDirectGzip(t *testing.T) { - r := NewRouter() - - req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through?plainStruct=1", nil) - require.NoError(t, err) - - req.Header.Set("Accept-Encoding", "gzip") - - rw := httptest.NewRecorder() - - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "", rw.Header().Get("Etag")) // No ETag for dynamic compression. - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Equal(t, "cba", rw.Header().Get("X-Header")) - assert.Less(t, len(rw.Body.Bytes()), 1000) // Worse compression for better speed. -} - -func Test_directGzip_perf(t *testing.T) { - res := testing.Benchmark(Benchmark_directGzip) - - if httptestbench.RaceDetectorEnabled { - assert.Less(t, res.Extra["B:rcvd/op"], 640.0) - assert.Less(t, res.Extra["B:sent/op"], 104.0) - assert.Less(t, res.AllocsPerOp(), int64(60)) - assert.Less(t, res.AllocedBytesPerOp(), int64(9000)) - } else { - assert.Less(t, res.Extra["B:rcvd/op"], 640.0) - assert.Less(t, res.Extra["B:sent/op"], 104.0) - assert.Less(t, res.AllocsPerOp(), int64(45)) - assert.Less(t, res.AllocedBytesPerOp(), int64(4200)) - } + httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { + req.Header.Set("Accept-Encoding", "gzip") + req.SetRequestURI(srv.URL + "/gzip-pass-through") + }, func(i int, resp *fasthttp.Response) bool { + return resp.StatusCode() == http.StatusOK + }) } -// Direct gzip enabled. -// Benchmark_directGzip-4 48037 24474 ns/op 624 B:rcvd/op 103 B:sent/op 40860 rps 3499 B/op 36 allocs/op. -// Benchmark_directGzip-4 45792 26102 ns/op 624 B:rcvd/op 103 B:sent/op 38278 rps 3063 B/op 33 allocs/op. -func Benchmark_directGzip(b *testing.B) { +// Direct gzip enabled, payload is unmarshaled and decompressed for every request in usecase body. +// Unmarshaling large JSON payloads can be much more expensive than explicitly creating them from Go values. +// Benchmark_directGzip_decode-4 2018 499755 ns/op 624 B:rcvd/op 116 B:sent/op 2001 rps 403967 B/op 496 allocs/op. +// Benchmark_directGzip_decode-4 2085 526586 ns/op 624 B:rcvd/op 116 B:sent/op 1899 rps 403600 B/op 493 allocs/op. +func Benchmark_directGzip_decode(b *testing.B) { r := NewRouter() srv := httptest.NewServer(r) @@ -98,7 +41,7 @@ func Benchmark_directGzip(b *testing.B) { httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.Header.Set("Accept-Encoding", "gzip") - req.SetRequestURI(srv.URL + "/gzip-pass-through") + req.SetRequestURI(srv.URL + "/gzip-pass-through?countItems=1") }, func(i int, resp *fasthttp.Response) bool { return resp.StatusCode() == http.StatusOK }) @@ -140,11 +83,10 @@ func Benchmark_noDirectGzip(b *testing.B) { }) } -// Direct gzip enabled, payload is unmarshaled and decompressed for every request in usecase body. -// Unmarshaling large JSON payloads can be much more expensive than explicitly creating them from Go values. -// Benchmark_directGzip_decode-4 2018 499755 ns/op 624 B:rcvd/op 116 B:sent/op 2001 rps 403967 B/op 496 allocs/op. -// Benchmark_directGzip_decode-4 2085 526586 ns/op 624 B:rcvd/op 116 B:sent/op 1899 rps 403600 B/op 493 allocs/op. -func Benchmark_directGzip_decode(b *testing.B) { +// Direct gzip disabled. +// Benchmark_noDirectGzip_decode-4 7603 142173 ns/op 1029 B:rcvd/op 130 B:sent/op 7034 rps 5122 B/op 43 allocs/op. +// Benchmark_noDirectGzip_decode-4 5836 198000 ns/op 1029 B:rcvd/op 130 B:sent/op 5051 rps 5371 B/op 42 allocs/op. +func Benchmark_noDirectGzip_decode(b *testing.B) { r := NewRouter() srv := httptest.NewServer(r) @@ -152,25 +94,83 @@ func Benchmark_directGzip_decode(b *testing.B) { httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.Header.Set("Accept-Encoding", "gzip") - req.SetRequestURI(srv.URL + "/gzip-pass-through?countItems=1") + req.SetRequestURI(srv.URL + "/gzip-pass-through?plainStruct=1&countItems=1") }, func(i int, resp *fasthttp.Response) bool { return resp.StatusCode() == http.StatusOK }) } -// Direct gzip disabled. -// Benchmark_noDirectGzip_decode-4 7603 142173 ns/op 1029 B:rcvd/op 130 B:sent/op 7034 rps 5122 B/op 43 allocs/op. -// Benchmark_noDirectGzip_decode-4 5836 198000 ns/op 1029 B:rcvd/op 130 B:sent/op 5051 rps 5371 B/op 42 allocs/op. -func Benchmark_noDirectGzip_decode(b *testing.B) { +func Test_directGzip(t *testing.T) { r := NewRouter() - srv := httptest.NewServer(r) + req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through", nil) + require.NoError(t, err) + + req.Header.Set("Accept-Encoding", "gzip") + + rw := httptest.NewRecorder() + + r.ServeHTTP(rw, req) + assert.Equal(t, http.StatusOK, rw.Code) + assert.Equal(t, "330epditz19z", rw.Header().Get("Etag")) + assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) + assert.Equal(t, "abc", rw.Header().Get("X-Header")) + assert.Less(t, len(rw.Body.Bytes()), 500) +} + +func Test_directGzip_HEAD(t *testing.T) { + srv := httptest.NewServer(NewRouter()) defer srv.Close() - httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { - req.Header.Set("Accept-Encoding", "gzip") - req.SetRequestURI(srv.URL + "/gzip-pass-through?plainStruct=1&countItems=1") - }, func(i int, resp *fasthttp.Response) bool { - return resp.StatusCode() == http.StatusOK - }) + req, err := http.NewRequest(http.MethodHead, srv.URL+"/gzip-pass-through", nil) + require.NoError(t, err) + + req.Header.Set("Accept-Encoding", "gzip") + + resp, err := http.DefaultTransport.RoundTrip(req) + require.NoError(t, err) + + body, err := io.ReadAll(resp.Body) + assert.NoError(t, err) + assert.NoError(t, resp.Body.Close()) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "330epditz19z", resp.Header.Get("Etag")) + assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) + assert.Equal(t, "abc", resp.Header.Get("X-Header")) + assert.Empty(t, body) +} + +func Test_directGzip_perf(t *testing.T) { + res := testing.Benchmark(Benchmark_directGzip) + + if httptestbench.RaceDetectorEnabled { + assert.Less(t, res.Extra["B:rcvd/op"], 640.0) + assert.Less(t, res.Extra["B:sent/op"], 104.0) + assert.Less(t, res.AllocsPerOp(), int64(60)) + assert.Less(t, res.AllocedBytesPerOp(), int64(9000)) + } else { + assert.Less(t, res.Extra["B:rcvd/op"], 640.0) + assert.Less(t, res.Extra["B:sent/op"], 104.0) + assert.Less(t, res.AllocsPerOp(), int64(45)) + assert.Less(t, res.AllocedBytesPerOp(), int64(4200)) + } +} + +func Test_noDirectGzip(t *testing.T) { + r := NewRouter() + + req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through?plainStruct=1", nil) + require.NoError(t, err) + + req.Header.Set("Accept-Encoding", "gzip") + + rw := httptest.NewRecorder() + + r.ServeHTTP(rw, req) + assert.Equal(t, http.StatusOK, rw.Code) + assert.Equal(t, "", rw.Header().Get("Etag")) // No ETag for dynamic compression. + assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) + assert.Equal(t, "cba", rw.Header().Get("X-Header")) + assert.Less(t, len(rw.Body.Bytes()), 1000) // Worse compression for better speed. } diff --git a/_examples/advanced/json_map_body.go b/_examples/advanced/json_map_body.go index 5e80d7e..9c23c41 100644 --- a/_examples/advanced/json_map_body.go +++ b/_examples/advanced/json_map_body.go @@ -9,16 +9,6 @@ import ( type JSONMapPayload map[string]float64 -type jsonMapReq struct { - Header string `header:"X-Header" description:"Simple scalar value in header."` - Query int `query:"in_query" description:"Simple scalar value in query."` - JSONMapPayload -} - -func (j *jsonMapReq) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &j.JSONMapPayload) -} - func jsonMapBody() usecase.Interactor { type jsonOutput struct { Header string `json:"inHeader"` @@ -43,3 +33,13 @@ func jsonMapBody() usecase.Interactor { return u } + +type jsonMapReq struct { + Header string `header:"X-Header" description:"Simple scalar value in header."` + Query int `query:"in_query" description:"Simple scalar value in query."` + JSONMapPayload +} + +func (j *jsonMapReq) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &j.JSONMapPayload) +} diff --git a/_examples/advanced/json_slice_body.go b/_examples/advanced/json_slice_body.go index d85b0b7..eb82c54 100644 --- a/_examples/advanced/json_slice_body.go +++ b/_examples/advanced/json_slice_body.go @@ -9,16 +9,6 @@ import ( type JSONSlicePayload []int -type jsonSliceReq struct { - Header string `header:"X-Header" description:"Simple scalar value in header."` - Query int `query:"in_query" description:"Simple scalar value in query."` - JSONSlicePayload -} - -func (j *jsonSliceReq) UnmarshalJSON(data []byte) error { - return json.Unmarshal(data, &j.JSONSlicePayload) -} - func jsonSliceBody() usecase.Interactor { type jsonOutput struct { Header string `json:"inHeader"` @@ -43,3 +33,13 @@ func jsonSliceBody() usecase.Interactor { return u } + +type jsonSliceReq struct { + Header string `header:"X-Header" description:"Simple scalar value in header."` + Query int `query:"in_query" description:"Simple scalar value in query."` + JSONSlicePayload +} + +func (j *jsonSliceReq) UnmarshalJSON(data []byte) error { + return json.Unmarshal(data, &j.JSONSlicePayload) +} diff --git a/_examples/advanced/request_response_mapping_test.go b/_examples/advanced/request_response_mapping_test.go index 5db99a5..b23efbd 100644 --- a/_examples/advanced/request_response_mapping_test.go +++ b/_examples/advanced/request_response_mapping_test.go @@ -13,6 +13,23 @@ import ( "github.com/valyala/fasthttp" ) +func Benchmark_requestResponseMapping(b *testing.B) { + r := NewRouter() + + srv := httptest.NewServer(r) + defer srv.Close() + + httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { + req.Header.SetMethod(http.MethodPost) + req.SetRequestURI(srv.URL + "/req-resp-mapping") + req.Header.Set("X-Header", "abc") + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.SetBody([]byte(`val2=3`)) + }, func(i int, resp *fasthttp.Response) bool { + return resp.StatusCode() == http.StatusNoContent + }) +} + func Test_requestResponseMapping(t *testing.T) { r := NewRouter() @@ -38,20 +55,3 @@ func Test_requestResponseMapping(t *testing.T) { assert.Equal(t, "abc", resp.Header.Get("X-Value-1")) assert.Equal(t, "3", resp.Header.Get("X-Value-2")) } - -func Benchmark_requestResponseMapping(b *testing.B) { - r := NewRouter() - - srv := httptest.NewServer(r) - defer srv.Close() - - httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { - req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(srv.URL + "/req-resp-mapping") - req.Header.Set("X-Header", "abc") - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - req.SetBody([]byte(`val2=3`)) - }, func(i int, resp *fasthttp.Response) bool { - return resp.StatusCode() == http.StatusNoContent - }) -} diff --git a/_examples/advanced/validation_test.go b/_examples/advanced/validation_test.go index d71d7b5..1d6b34c 100644 --- a/_examples/advanced/validation_test.go +++ b/_examples/advanced/validation_test.go @@ -14,9 +14,7 @@ import ( "github.com/valyala/fasthttp" ) -// Benchmark_validation-4 18979 53012 ns/op 197 B:rcvd/op 170 B:sent/op 18861 rps 14817 B/op 131 allocs/op. -// Benchmark_validation-4 17665 58243 ns/op 177 B:rcvd/op 170 B:sent/op 17161 rps 16349 B/op 132 allocs/op. -func Benchmark_validation(b *testing.B) { +func Benchmark_noValidation(b *testing.B) { r := NewRouter() srv := httptest.NewServer(r) @@ -24,7 +22,7 @@ func Benchmark_validation(b *testing.B) { httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(srv.URL + "/validation?q=true") + req.SetRequestURI(srv.URL + "/no-validation?q=true") req.Header.Set("X-Input", "12") req.Header.Set("Content-Type", "application/json") req.SetBody([]byte(`{"data":{"value":"abc"}}`)) @@ -33,7 +31,9 @@ func Benchmark_validation(b *testing.B) { }) } -func Benchmark_noValidation(b *testing.B) { +// Benchmark_validation-4 18979 53012 ns/op 197 B:rcvd/op 170 B:sent/op 18861 rps 14817 B/op 131 allocs/op. +// Benchmark_validation-4 17665 58243 ns/op 177 B:rcvd/op 170 B:sent/op 17161 rps 16349 B/op 132 allocs/op. +func Benchmark_validation(b *testing.B) { r := NewRouter() srv := httptest.NewServer(r) @@ -41,7 +41,7 @@ func Benchmark_noValidation(b *testing.B) { httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.Header.SetMethod(http.MethodPost) - req.SetRequestURI(srv.URL + "/no-validation?q=true") + req.SetRequestURI(srv.URL + "/validation?q=true") req.Header.Set("X-Input", "12") req.Header.Set("Content-Type", "application/json") req.SetBody([]byte(`{"data":{"value":"abc"}}`)) diff --git a/_examples/gingonic/main.go b/_examples/gingonic/main.go index a5355c2..7de0bda 100644 --- a/_examples/gingonic/main.go +++ b/_examples/gingonic/main.go @@ -14,14 +14,27 @@ import ( "github.com/swaggest/openapi-go/openapi3" ) -func OpenAPICtx(c *gin.Context) openapi.OperationContext { - if oc, ok := c.Get("openapiContext"); ok { - if oc, ok := oc.(openapi.OperationContext); ok { - return oc - } +func main() { + router := gin.Default() + router.GET("/albums", getAlbums) + router.GET("/albums/:id", getAlbumByID) + router.POST("/albums", postAlbums) + + refl := openapi3.NewReflector() + refl.SpecSchema().SetTitle("Albums API") + refl.SpecSchema().SetVersion("v1.2.3") + refl.SpecSchema().SetDescription("This services keeps track of albums.") + + if err := OpenAPICollect(refl, router.Routes()); err != nil { + fmt.Println(err.Error()) } - return nil + y, _ := refl.Spec.MarshalYAML() + + os.WriteFile("openapi.yaml", y, 0o600) + fmt.Println(string(y)) + + router.Run("localhost:8080") } func OpenAPICollect(refl openapi.Reflector, routes gin.RoutesInfo) error { @@ -85,12 +98,14 @@ func OpenAPICollect(refl openapi.Reflector, routes gin.RoutesInfo) error { return nil } -// album represents data about a record album. -type album struct { - ID string `json:"id"` - Title string `json:"title"` - Artist string `json:"artist"` - Price float64 `json:"price"` +func OpenAPICtx(c *gin.Context) openapi.OperationContext { + if oc, ok := c.Get("openapiContext"); ok { + if oc, ok := oc.(openapi.OperationContext); ok { + return oc + } + } + + return nil } // albums slice to seed record album data. @@ -100,27 +115,32 @@ var albums = []album{ {ID: "3", Title: "Sarah Vaughan and Clifford Brown", Artist: "Sarah Vaughan", Price: 39.99}, } -func main() { - router := gin.Default() - router.GET("/albums", getAlbums) - router.GET("/albums/:id", getAlbumByID) - router.POST("/albums", postAlbums) - - refl := openapi3.NewReflector() - refl.SpecSchema().SetTitle("Albums API") - refl.SpecSchema().SetVersion("v1.2.3") - refl.SpecSchema().SetDescription("This services keeps track of albums.") - - if err := OpenAPICollect(refl, router.Routes()); err != nil { - fmt.Println(err.Error()) +// getAlbumByID locates the album whose ID value matches the id +// parameter sent by the client, then returns that album as a response. +func getAlbumByID(c *gin.Context) { + if oc := OpenAPICtx(c); oc != nil { + oc.SetSummary("Get album") + oc.SetTags("Albums") + oc.AddReqStructure(struct { + ID string `path:"id"` + }{}) + oc.AddRespStructure(album{}) + oc.AddRespStructure(struct { + Message string `json:"message"` + }{}, openapi.WithHTTPStatus(http.StatusNotFound)) } - y, _ := refl.Spec.MarshalYAML() - - os.WriteFile("openapi.yaml", y, 0o600) - fmt.Println(string(y)) + id := c.Param("id") - router.Run("localhost:8080") + // Loop through the list of albums, looking for + // an album whose ID value matches the parameter. + for _, a := range albums { + if a.ID == id { + c.JSON(http.StatusOK, a) + return + } + } + c.JSON(http.StatusNotFound, gin.H{"message": "album not found"}) } // getAlbums responds with the list of all albums as JSON. @@ -156,30 +176,10 @@ func postAlbums(c *gin.Context) { c.JSON(http.StatusCreated, newAlbum) } -// getAlbumByID locates the album whose ID value matches the id -// parameter sent by the client, then returns that album as a response. -func getAlbumByID(c *gin.Context) { - if oc := OpenAPICtx(c); oc != nil { - oc.SetSummary("Get album") - oc.SetTags("Albums") - oc.AddReqStructure(struct { - ID string `path:"id"` - }{}) - oc.AddRespStructure(album{}) - oc.AddRespStructure(struct { - Message string `json:"message"` - }{}, openapi.WithHTTPStatus(http.StatusNotFound)) - } - - id := c.Param("id") - - // Loop through the list of albums, looking for - // an album whose ID value matches the parameter. - for _, a := range albums { - if a.ID == id { - c.JSON(http.StatusOK, a) - return - } - } - c.JSON(http.StatusNotFound, gin.H{"message": "album not found"}) +// album represents data about a record album. +type album struct { + ID string `json:"id"` + Title string `json:"title"` + Artist string `json:"artist"` + Price float64 `json:"price"` } diff --git a/_examples/jwtauth/main.go b/_examples/jwtauth/main.go index f464394..66e96ac 100644 --- a/_examples/jwtauth/main.go +++ b/_examples/jwtauth/main.go @@ -72,17 +72,6 @@ import ( "github.com/swaggest/usecase" ) -var tokenAuth *jwtauth.JWTAuth - -func init() { - tokenAuth = jwtauth.New("HS256", []byte("secret"), nil) - - // For debugging/example purposes, we generate and print - // a sample jwt token with claims `user_id:123` here: - _, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123}) - fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) -} - func main() { addr := "localhost:3333" fmt.Printf("Starting server on http://%v\n", addr) @@ -104,6 +93,17 @@ func Get() usecase.Interactor { return u } +var tokenAuth *jwtauth.JWTAuth + +func init() { + tokenAuth = jwtauth.New("HS256", []byte("secret"), nil) + + // For debugging/example purposes, we generate and print + // a sample jwt token with claims `user_id:123` here: + _, tokenString, _ := tokenAuth.Encode(map[string]interface{}{"user_id": 123}) + fmt.Printf("DEBUG: a sample jwt is %s\n\n", tokenString) +} + func router() http.Handler { s := web.NewService(openapi31.NewReflector()) diff --git a/_examples/mount/main.go b/_examples/mount/main.go index 0279ac6..e5108c7 100644 --- a/_examples/mount/main.go +++ b/_examples/mount/main.go @@ -15,26 +15,11 @@ import ( "github.com/swaggest/usecase" ) -func mul() usecase.Interactor { - return usecase.NewInteractor(func(ctx context.Context, input []int, output *int) error { - *output = 1 - - for _, v := range input { - *output *= v - } - - return nil - }) -} - -func sum() usecase.Interactor { - return usecase.NewInteractor(func(ctx context.Context, input []int, output *int) error { - for _, v := range input { - *output += v - } - - return nil - }) +func main() { + fmt.Println("Swagger UI at http://localhost:8010/api/docs.") + if err := http.ListenAndServe("localhost:8010", service()); err != nil { + log.Fatal(err) + } } func service() *web.Service { @@ -78,9 +63,24 @@ func service() *web.Service { return s } -func main() { - fmt.Println("Swagger UI at http://localhost:8010/api/docs.") - if err := http.ListenAndServe("localhost:8010", service()); err != nil { - log.Fatal(err) - } +func mul() usecase.Interactor { + return usecase.NewInteractor(func(ctx context.Context, input []int, output *int) error { + *output = 1 + + for _, v := range input { + *output *= v + } + + return nil + }) +} + +func sum() usecase.Interactor { + return usecase.NewInteractor(func(ctx context.Context, input []int, output *int) error { + for _, v := range input { + *output += v + } + + return nil + }) } diff --git a/_examples/multi-api/main.go b/_examples/multi-api/main.go index f5bd763..866bd14 100644 --- a/_examples/multi-api/main.go +++ b/_examples/multi-api/main.go @@ -83,18 +83,6 @@ func service() *web.Service { return s } -func specHandler(s openapi.SpecSchema) http.Handler { - j, err := json.Marshal(s) - if err != nil { - panic(err) - } - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - _, _ = w.Write(j) - }) -} - func mul() usecase.Interactor { return usecase.NewInteractor(func(ctx context.Context, input []int, output *int) error { *output = 1 @@ -107,6 +95,18 @@ func mul() usecase.Interactor { }) } +func specHandler(s openapi.SpecSchema) http.Handler { + j, err := json.Marshal(s) + if err != nil { + panic(err) + } + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write(j) + }) +} + func sum() usecase.Interactor { return usecase.NewInteractor(func(ctx context.Context, input []int, output *int) error { for _, v := range input { diff --git a/_examples/task-api/internal/domain/task/entity.go b/_examples/task-api/internal/domain/task/entity.go index 2c6f39e..607c41d 100644 --- a/_examples/task-api/internal/domain/task/entity.go +++ b/_examples/task-api/internal/domain/task/entity.go @@ -6,9 +6,6 @@ import ( "github.com/swaggest/jsonschema-go" ) -// Status describes task state. -type Status string - // Available task statuses. const ( Active = Status("") @@ -17,7 +14,22 @@ const ( Expired = Status("expired") ) -var _ jsonschema.Exposer = Status("") +// Entity is an identified task entity. +type Entity struct { + Identity + Value + CreatedAt time.Time `json:"createdAt"` + Status Status `json:"status,omitempty"` + ClosedAt *time.Time `json:"closedAt,omitempty"` +} + +// Identity identifies task. +type Identity struct { + ID int `json:"id"` +} + +// Status describes task state. +type Status string // JSONSchema exposes Status JSON schema, implements jsonschema.Exposer. func (Status) JSONSchema() (jsonschema.Schema, error) { @@ -31,22 +43,10 @@ func (Status) JSONSchema() (jsonschema.Schema, error) { return s, nil } -// Identity identifies task. -type Identity struct { - ID int `json:"id"` -} - // Value is a task value. type Value struct { Goal string `json:"goal" minLength:"1" required:"true"` Deadline *time.Time `json:"deadline,omitempty"` } -// Entity is an identified task entity. -type Entity struct { - Identity - Value - CreatedAt time.Time `json:"createdAt"` - Status Status `json:"status,omitempty"` - ClosedAt *time.Time `json:"closedAt,omitempty"` -} +var _ jsonschema.Exposer = Status("") diff --git a/_examples/task-api/internal/domain/task/service.go b/_examples/task-api/internal/domain/task/service.go index 59174cf..6f84e8d 100644 --- a/_examples/task-api/internal/domain/task/service.go +++ b/_examples/task-api/internal/domain/task/service.go @@ -8,9 +8,10 @@ type Creator interface { Create(context.Context, Value) (Entity, error) } -// Updater updates tasks. -type Updater interface { - Update(context.Context, Identity, Value) error +// Finder finds tasks. +type Finder interface { + Find(context.Context) []Entity + FindByID(context.Context, Identity) (Entity, error) } // Finisher closes tasks. @@ -19,8 +20,7 @@ type Finisher interface { Finish(context.Context, Identity) error } -// Finder finds tasks. -type Finder interface { - Find(context.Context) []Entity - FindByID(context.Context, Identity) (Entity, error) +// Updater updates tasks. +type Updater interface { + Update(context.Context, Identity, Value) error } diff --git a/_examples/task-api/internal/infra/nethttp/benchmark_test.go b/_examples/task-api/internal/infra/nethttp/benchmark_test.go index 79c9132..fd0a7dd 100644 --- a/_examples/task-api/internal/infra/nethttp/benchmark_test.go +++ b/_examples/task-api/internal/infra/nethttp/benchmark_test.go @@ -18,34 +18,43 @@ import ( "github.com/valyala/fasthttp" ) -// Benchmark_notFoundSrv-4 31236 37106 ns/op 26927 RPS 8408 B/op 76 allocs/op. -// Benchmark_notFoundSrv-4 33241 33620 ns/op 29745 RPS 5796 B/op 65 allocs/op. -// Benchmark_notFoundSrv-4 33656 35653 ns/op 336 B:rcvd/op 74.0 B:sent/op 28048 rps 5813 B/op 65 allocs/op. -// Benchmark_notFoundSrv-4 32262 36431 ns/op 337 B:rcvd/op 74.0 B:sent/op 27449 rps 5769 B/op 64 allocs/op. -func Benchmark_notFoundSrv(b *testing.B) { +// Benchmark_invalidBody-4 23670 46677 ns/op 21424 RPS 13111 B/op 132 allocs/op. +// Benchmark_invalidBody-4 23838 46156 ns/op 21666 RPS 9724 B/op 111 allocs/op. +// Benchmark_invalidBody-4 23589 60475 ns/op 439 B:rcvd/op 137 B:sent/op 16531 rps 9781 B/op 111 allocs/op. +// Benchmark_invalidBody-4 18458 54945 ns/op 435 B:rcvd/op 137 B:sent/op 18200 rps 9634 B/op 110 allocs/op +func Benchmark_invalidBody(b *testing.B) { log.SetOutput(ioutil.Discard) l := infra.NewServiceLocator(service.Config{}) defer l.Close() - srv := httptest.NewServer(nethttp.NewRouter(l)) - defer srv.Close() + r := nethttp.NewRouter(l) + srv := httptest.NewServer(r) + + tt, err := l.TaskCreator().Create(context.Background(), task.Value{Goal: "win"}) + require.NoError(b, err) + assert.Equal(b, 1, tt.ID) + + body := []byte(`{"goal":""}`) httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { + req.Header.SetMethod(http.MethodPut) + req.Header.SetContentType("application/json") req.SetRequestURI(srv.URL + "/dev/tasks/1") + req.SetBody(body) }, func(i int, resp *fasthttp.Response) bool { - return resp.StatusCode() == http.StatusNotFound + return resp.StatusCode() == http.StatusBadRequest }, ) } -// Benchmark_ok-4 28002 36993 ns/op 27027 RPS 8539 B/op 75 allocs/op. -// Benchmark_ok-4 35078 34293 ns/op 29156 RPS 5729 B/op 61 allocs/op. -// Benchmark_ok-4 33270 36366 ns/op 360 B:rcvd/op 74.0 B:sent/op 27498 rps 5730 B/op 61 allocs/op. -// Benchmark_ok-4 32761 37317 ns/op 362 B:rcvd/op 74.0 B:sent/op 26797 rps 5673 B/op 60 allocs/op. -func Benchmark_ok(b *testing.B) { +// Benchmark_notFoundSrv-4 31236 37106 ns/op 26927 RPS 8408 B/op 76 allocs/op. +// Benchmark_notFoundSrv-4 33241 33620 ns/op 29745 RPS 5796 B/op 65 allocs/op. +// Benchmark_notFoundSrv-4 33656 35653 ns/op 336 B:rcvd/op 74.0 B:sent/op 28048 rps 5813 B/op 65 allocs/op. +// Benchmark_notFoundSrv-4 32262 36431 ns/op 337 B:rcvd/op 74.0 B:sent/op 27449 rps 5769 B/op 64 allocs/op. +func Benchmark_notFoundSrv(b *testing.B) { log.SetOutput(ioutil.Discard) l := infra.NewServiceLocator(service.Config{}) @@ -54,47 +63,38 @@ func Benchmark_ok(b *testing.B) { srv := httptest.NewServer(nethttp.NewRouter(l)) defer srv.Close() - _, err := l.TaskCreator().Create(context.Background(), task.Value{Goal: "victory!"}) - require.NoError(b, err) - httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { req.SetRequestURI(srv.URL + "/dev/tasks/1") }, func(i int, resp *fasthttp.Response) bool { - return resp.StatusCode() == http.StatusOK + return resp.StatusCode() == http.StatusNotFound }, ) } -// Benchmark_invalidBody-4 23670 46677 ns/op 21424 RPS 13111 B/op 132 allocs/op. -// Benchmark_invalidBody-4 23838 46156 ns/op 21666 RPS 9724 B/op 111 allocs/op. -// Benchmark_invalidBody-4 23589 60475 ns/op 439 B:rcvd/op 137 B:sent/op 16531 rps 9781 B/op 111 allocs/op. -// Benchmark_invalidBody-4 18458 54945 ns/op 435 B:rcvd/op 137 B:sent/op 18200 rps 9634 B/op 110 allocs/op -func Benchmark_invalidBody(b *testing.B) { +// Benchmark_ok-4 28002 36993 ns/op 27027 RPS 8539 B/op 75 allocs/op. +// Benchmark_ok-4 35078 34293 ns/op 29156 RPS 5729 B/op 61 allocs/op. +// Benchmark_ok-4 33270 36366 ns/op 360 B:rcvd/op 74.0 B:sent/op 27498 rps 5730 B/op 61 allocs/op. +// Benchmark_ok-4 32761 37317 ns/op 362 B:rcvd/op 74.0 B:sent/op 26797 rps 5673 B/op 60 allocs/op. +func Benchmark_ok(b *testing.B) { log.SetOutput(ioutil.Discard) l := infra.NewServiceLocator(service.Config{}) defer l.Close() - r := nethttp.NewRouter(l) - srv := httptest.NewServer(r) + srv := httptest.NewServer(nethttp.NewRouter(l)) + defer srv.Close() - tt, err := l.TaskCreator().Create(context.Background(), task.Value{Goal: "win"}) + _, err := l.TaskCreator().Create(context.Background(), task.Value{Goal: "victory!"}) require.NoError(b, err) - assert.Equal(b, 1, tt.ID) - - body := []byte(`{"goal":""}`) httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { - req.Header.SetMethod(http.MethodPut) - req.Header.SetContentType("application/json") req.SetRequestURI(srv.URL + "/dev/tasks/1") - req.SetBody(body) }, func(i int, resp *fasthttp.Response) bool { - return resp.StatusCode() == http.StatusBadRequest + return resp.StatusCode() == http.StatusOK }, ) } diff --git a/_examples/task-api/internal/infra/repository/task.go b/_examples/task-api/internal/infra/repository/task.go index 250d4a7..3cb0e9c 100644 --- a/_examples/task-api/internal/infra/repository/task.go +++ b/_examples/task-api/internal/infra/repository/task.go @@ -20,34 +20,41 @@ type Task struct { list map[task.Identity]task.Entity } -// TaskUpdater is a service provider. -func (tr *Task) TaskUpdater() task.Updater { - return tr +// Cancel closes task as canceled. +func (tr *Task) Cancel(ctx context.Context, identity task.Identity) error { + return tr.close(identity, task.Canceled) } -// Update updates task value by identity. -func (tr *Task) Update(_ context.Context, identity task.Identity, value task.Value) error { +// Create creates a new task. +func (tr *Task) Create(ctx context.Context, value task.Value) (task.Entity, error) { tr.mu.Lock() defer tr.mu.Unlock() - t, found := tr.list[identity] - if !found { - return status.NotFound + for _, t := range tr.list { + if t.Value.Goal == value.Goal { + return task.Entity{}, usecase.Error{ + StatusCode: status.AlreadyExists, + Context: map[string]interface{}{ + "task": t, + }, + Value: errors.New("task with same goal already exists"), + } + } } - if t.ClosedAt != nil { - return status.Wrap(errors.New("task is already closed"), status.FailedPrecondition) + tr.lastID++ + + if tr.list == nil { + tr.list = make(map[task.Identity]task.Entity, 1) } + t := task.Entity{} t.Value = value - tr.list[identity] = t - - return nil -} + t.ID = tr.lastID + t.CreatedAt = time.Now() + tr.list[t.Identity] = t -// TaskFinder is a service provider. -func (tr *Task) TaskFinder() task.Finder { - return tr + return t, nil } // Find finds all tasks. @@ -80,39 +87,27 @@ func (tr *Task) FindByID(ctx context.Context, identity task.Identity) (task.Enti return t, nil } -// TaskFinisher is a service provider. -func (tr *Task) TaskFinisher() task.Finisher { - return tr -} - // Finish closes task as done. func (tr *Task) Finish(ctx context.Context, identity task.Identity) error { return tr.close(identity, task.Done) } -// Cancel closes task as canceled. -func (tr *Task) Cancel(ctx context.Context, identity task.Identity) error { - return tr.close(identity, task.Canceled) -} - -func (tr *Task) close(identity task.Identity, st task.Status) error { +// FinishExpired closes expired tasks. +func (tr *Task) FinishExpired(_ context.Context) error { tr.mu.Lock() defer tr.mu.Unlock() - t, found := tr.list[identity] - if !found { - return status.NotFound - } + now := time.Now() - if t.ClosedAt != nil { - return status.Wrap(errors.New("task is already closed"), status.FailedPrecondition) + for _, t := range tr.list { + if t.Deadline != nil && now.After(*t.Deadline) { + err := tr.close(t.Identity, task.Expired) + if err != nil { + return err + } + } } - now := time.Now() - t.ClosedAt = &now - t.Status = st - tr.list[t.Identity] = t - return nil } @@ -121,53 +116,58 @@ func (tr *Task) TaskCreator() task.Creator { return tr } -// Create creates a new task. -func (tr *Task) Create(ctx context.Context, value task.Value) (task.Entity, error) { +// TaskFinder is a service provider. +func (tr *Task) TaskFinder() task.Finder { + return tr +} + +// TaskFinisher is a service provider. +func (tr *Task) TaskFinisher() task.Finisher { + return tr +} + +// TaskUpdater is a service provider. +func (tr *Task) TaskUpdater() task.Updater { + return tr +} + +// Update updates task value by identity. +func (tr *Task) Update(_ context.Context, identity task.Identity, value task.Value) error { tr.mu.Lock() defer tr.mu.Unlock() - for _, t := range tr.list { - if t.Value.Goal == value.Goal { - return task.Entity{}, usecase.Error{ - StatusCode: status.AlreadyExists, - Context: map[string]interface{}{ - "task": t, - }, - Value: errors.New("task with same goal already exists"), - } - } + t, found := tr.list[identity] + if !found { + return status.NotFound } - tr.lastID++ - - if tr.list == nil { - tr.list = make(map[task.Identity]task.Entity, 1) + if t.ClosedAt != nil { + return status.Wrap(errors.New("task is already closed"), status.FailedPrecondition) } - t := task.Entity{} t.Value = value - t.ID = tr.lastID - t.CreatedAt = time.Now() - tr.list[t.Identity] = t + tr.list[identity] = t - return t, nil + return nil } -// FinishExpired closes expired tasks. -func (tr *Task) FinishExpired(_ context.Context) error { +func (tr *Task) close(identity task.Identity, st task.Status) error { tr.mu.Lock() defer tr.mu.Unlock() - now := time.Now() + t, found := tr.list[identity] + if !found { + return status.NotFound + } - for _, t := range tr.list { - if t.Deadline != nil && now.After(*t.Deadline) { - err := tr.close(t.Identity, task.Expired) - if err != nil { - return err - } - } + if t.ClosedAt != nil { + return status.Wrap(errors.New("task is already closed"), status.FailedPrecondition) } + now := time.Now() + t.ClosedAt = &now + t.Status = st + tr.list[t.Identity] = t + return nil } diff --git a/_examples/task-api/internal/infra/service/provider.go b/_examples/task-api/internal/infra/service/provider.go index cec3477..a745138 100644 --- a/_examples/task-api/internal/infra/service/provider.go +++ b/_examples/task-api/internal/infra/service/provider.go @@ -7,11 +7,6 @@ type TaskCreatorProvider interface { TaskCreator() task.Creator } -// TaskUpdaterProvider is a service locator provider. -type TaskUpdaterProvider interface { - TaskUpdater() task.Updater -} - // TaskFinderProvider is a service locator provider. type TaskFinderProvider interface { TaskFinder() task.Finder @@ -21,3 +16,8 @@ type TaskFinderProvider interface { type TaskFinisherProvider interface { TaskFinisher() task.Finisher } + +// TaskUpdaterProvider is a service locator provider. +type TaskUpdaterProvider interface { + TaskUpdater() task.Updater +} diff --git a/_examples/task-api/internal/usecase/finish_task.go b/_examples/task-api/internal/usecase/finish_task.go index 6ddb41a..04e7323 100644 --- a/_examples/task-api/internal/usecase/finish_task.go +++ b/_examples/task-api/internal/usecase/finish_task.go @@ -8,10 +8,6 @@ import ( "github.com/swaggest/usecase/status" ) -type finishTaskDeps interface { - TaskFinisher() task.Finisher -} - // FinishTask creates usecase interactor. func FinishTask(deps finishTaskDeps) usecase.IOInteractor { u := usecase.NewIOI(new(task.Identity), nil, func(ctx context.Context, input, _ interface{}) error { @@ -34,3 +30,7 @@ func FinishTask(deps finishTaskDeps) usecase.IOInteractor { return u } + +type finishTaskDeps interface { + TaskFinisher() task.Finisher +} diff --git a/_examples/task-api/internal/usecase/update_task.go b/_examples/task-api/internal/usecase/update_task.go index 30f18e4..ad9ee29 100644 --- a/_examples/task-api/internal/usecase/update_task.go +++ b/_examples/task-api/internal/usecase/update_task.go @@ -8,11 +8,6 @@ import ( "github.com/swaggest/usecase/status" ) -type updateTask struct { - task.Identity `json:"-"` - task.Value -} - // UpdateTask creates usecase interactor. func UpdateTask( deps interface { @@ -38,3 +33,8 @@ func UpdateTask( return u } + +type updateTask struct { + task.Identity `json:"-"` + task.Value +} diff --git a/_examples/task-api/pkg/graceful/shutdown.go b/_examples/task-api/pkg/graceful/shutdown.go index 7198c46..71bb148 100644 --- a/_examples/task-api/pkg/graceful/shutdown.go +++ b/_examples/task-api/pkg/graceful/shutdown.go @@ -32,15 +32,6 @@ func (s *Shutdown) Close() { } } -// Wait blocks until shutdown. -func (s *Shutdown) Wait() error { - if s.shutdownSignal != nil { - <-s.shutdownSignal - } - - return s.shutdown() -} - // EnableGracefulShutdown schedules service locator termination SIGTERM or SIGINT. func (s *Shutdown) EnableGracefulShutdown() { s.mu.Lock() @@ -66,6 +57,35 @@ func (s *Shutdown) EnableGracefulShutdown() { } } +// ShutdownSignal returns a channel that is closed when service locator is closed or os shutdownSignal is received and +// a confirmation channel that should be closed once subscriber has finished the shutdown. +func (s *Shutdown) ShutdownSignal(subscriber string) (shutdown <-chan struct{}, done chan<- struct{}) { + s.mu.Lock() + defer s.mu.Unlock() + + if s.subscribers == nil { + s.subscribers = make(map[string]chan struct{}) + } + + if d, ok := s.subscribers[subscriber]; ok { + return s.shutdownSignal, d + } + + d := make(chan struct{}, 1) + s.subscribers[subscriber] = d + + return s.shutdownSignal, d +} + +// Wait blocks until shutdown. +func (s *Shutdown) Wait() error { + if s.shutdownSignal != nil { + <-s.shutdownSignal + } + + return s.shutdown() +} + func (s *Shutdown) shutdown() error { s.mu.Lock() defer s.mu.Unlock() @@ -88,23 +108,3 @@ func (s *Shutdown) shutdown() error { return nil } - -// ShutdownSignal returns a channel that is closed when service locator is closed or os shutdownSignal is received and -// a confirmation channel that should be closed once subscriber has finished the shutdown. -func (s *Shutdown) ShutdownSignal(subscriber string) (shutdown <-chan struct{}, done chan<- struct{}) { - s.mu.Lock() - defer s.mu.Unlock() - - if s.subscribers == nil { - s.subscribers = make(map[string]chan struct{}) - } - - if d, ok := s.subscribers[subscriber]; ok { - return s.shutdownSignal, d - } - - d := make(chan struct{}, 1) - s.subscribers[subscriber] = d - - return s.shutdownSignal, d -} diff --git a/chirouter/wrapper.go b/chirouter/wrapper.go index 5fe11ce..63820e8 100644 --- a/chirouter/wrapper.go +++ b/chirouter/wrapper.go @@ -29,60 +29,19 @@ type Wrapper struct { handlers []http.Handler } -var _ chi.Router = &Wrapper{} - -func (r *Wrapper) copy(router chi.Router, pattern string) *Wrapper { - return &Wrapper{ - Router: router, - name: r.name, - basePattern: r.basePattern + pattern, - middlewares: r.middlewares, - wraps: r.wraps, - } -} - -// Wrap appends one or more wrappers that will be applied to handler before adding to Router. -// It is different from middleware in the sense that it is handler-centric, rather than request-centric. -// Wraps are invoked once for each added handler, they are not invoked for http requests. -// Wraps can leverage nethttp.HandlerAs to inspect and access deeper layers. -// For most cases Wrap can be safely used instead of Use, Use is mandatory for middlewares -// that affect routing (such as middleware.StripSlashes for example). -func (r *Wrapper) Wrap(wraps ...func(handler http.Handler) http.Handler) { - r.wraps = append(r.wraps, wraps...) +// Connect adds the route `pattern` that matches a CONNECT http method to execute the `handlerFn` http.HandlerFunc. +func (r *Wrapper) Connect(pattern string, handlerFn http.HandlerFunc) { + r.Method(http.MethodConnect, pattern, handlerFn) } -// Use appends one of more middlewares onto the Router stack. -func (r *Wrapper) Use(middlewares ...func(http.Handler) http.Handler) { - var mws []func(http.Handler) http.Handler - - for _, mw := range middlewares { - if nethttp.MiddlewareIsWrapper(mw) { - r.wraps = append(r.wraps, mw) - } else { - mws = append(mws, mw) - } - } - - r.Router.Use(mws...) - r.middlewares = append(r.middlewares, mws...) +// Delete adds the route `pattern` that matches a DELETE http method to execute the `handlerFn` http.HandlerFunc. +func (r *Wrapper) Delete(pattern string, handlerFn http.HandlerFunc) { + r.Method(http.MethodDelete, pattern, handlerFn) } -// With adds inline middlewares for an endpoint handler. -func (r Wrapper) With(middlewares ...func(http.Handler) http.Handler) chi.Router { - var mws, ws []func(http.Handler) http.Handler - - for _, mw := range middlewares { - if nethttp.MiddlewareIsWrapper(mw) { - ws = append(ws, mw) - } else { - mws = append(mws, mw) - } - } - - c := r.copy(r.Router.With(mws...), "") - c.wraps = append(c.wraps, ws...) - - return c +// Get adds the route `pattern` that matches a GET http method to execute the `handlerFn` http.HandlerFunc. +func (r *Wrapper) Get(pattern string, handlerFn http.HandlerFunc) { + r.Method(http.MethodGet, pattern, handlerFn) } // Group adds a new inline-router along the current routing path, with a fresh middleware stack for the inline-router. @@ -96,17 +55,37 @@ func (r *Wrapper) Group(fn func(r chi.Router)) chi.Router { return im } -// Route mounts a sub-router along a `basePattern` string. -func (r *Wrapper) Route(pattern string, fn func(r chi.Router)) chi.Router { - subRouter := r.copy(chi.NewRouter(), pattern) +// Handle adds routes for `basePattern` that matches all HTTP methods. +func (r *Wrapper) Handle(pattern string, h http.Handler) { + h = r.prepareHandler("", pattern, h) + r.captureHandler(h) + r.Router.Handle(pattern, h) +} - if fn != nil { - fn(subRouter) - } +// HandlerFunc prepares handler and returns its function. +// +// Can be used as input for NotFound, MethodNotAllowed. +func (r *Wrapper) HandlerFunc(h http.Handler) http.HandlerFunc { + h = nethttp.WrapHandler(h, r.wraps...) - r.Router.Mount(pattern, subRouter) + return h.ServeHTTP +} - return subRouter +// Head adds the route `pattern` that matches a HEAD http method to execute the `handlerFn` http.HandlerFunc. +func (r *Wrapper) Head(pattern string, handlerFn http.HandlerFunc) { + r.Method(http.MethodHead, pattern, handlerFn) +} + +// Method adds routes for `basePattern` that matches the `method` HTTP method. +func (r *Wrapper) Method(method, pattern string, h http.Handler) { + h = r.prepareHandler(method, pattern, h) + r.captureHandler(h) + r.Router.Method(method, pattern, h) +} + +// MethodFunc adds the route `pattern` that matches `method` http method to execute the `handlerFn` http.HandlerFunc. +func (r *Wrapper) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) { + r.Method(method, pattern, handlerFn) } // Mount attaches another http.Handler along "./basePattern/*". @@ -135,45 +114,6 @@ func (r *Wrapper) Mount(pattern string, h http.Handler) { r.Router.Mount(pattern, h) } -// Handle adds routes for `basePattern` that matches all HTTP methods. -func (r *Wrapper) Handle(pattern string, h http.Handler) { - h = r.prepareHandler("", pattern, h) - r.captureHandler(h) - r.Router.Handle(pattern, h) -} - -// Method adds routes for `basePattern` that matches the `method` HTTP method. -func (r *Wrapper) Method(method, pattern string, h http.Handler) { - h = r.prepareHandler(method, pattern, h) - r.captureHandler(h) - r.Router.Method(method, pattern, h) -} - -// MethodFunc adds the route `pattern` that matches `method` http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) { - r.Method(method, pattern, handlerFn) -} - -// Connect adds the route `pattern` that matches a CONNECT http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Connect(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodConnect, pattern, handlerFn) -} - -// Delete adds the route `pattern` that matches a DELETE http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Delete(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodDelete, pattern, handlerFn) -} - -// Get adds the route `pattern` that matches a GET http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Get(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodGet, pattern, handlerFn) -} - -// Head adds the route `pattern` that matches a HEAD http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Head(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodHead, pattern, handlerFn) -} - // Options adds the route `pattern` that matches a OPTIONS http method to execute the `handlerFn` http.HandlerFunc. func (r *Wrapper) Options(pattern string, handlerFn http.HandlerFunc) { r.Method(http.MethodOptions, pattern, handlerFn) @@ -194,28 +134,90 @@ func (r *Wrapper) Put(pattern string, handlerFn http.HandlerFunc) { r.Method(http.MethodPut, pattern, handlerFn) } +// Route mounts a sub-router along a `basePattern` string. +func (r *Wrapper) Route(pattern string, fn func(r chi.Router)) chi.Router { + subRouter := r.copy(chi.NewRouter(), pattern) + + if fn != nil { + fn(subRouter) + } + + r.Router.Mount(pattern, subRouter) + + return subRouter +} + // Trace adds the route `pattern` that matches a TRACE http method to execute the `handlerFn` http.HandlerFunc. func (r *Wrapper) Trace(pattern string, handlerFn http.HandlerFunc) { r.Method(http.MethodTrace, pattern, handlerFn) } -// HandlerFunc prepares handler and returns its function. -// -// Can be used as input for NotFound, MethodNotAllowed. -func (r *Wrapper) HandlerFunc(h http.Handler) http.HandlerFunc { - h = nethttp.WrapHandler(h, r.wraps...) +// Use appends one of more middlewares onto the Router stack. +func (r *Wrapper) Use(middlewares ...func(http.Handler) http.Handler) { + var mws []func(http.Handler) http.Handler - return h.ServeHTTP + for _, mw := range middlewares { + if nethttp.MiddlewareIsWrapper(mw) { + r.wraps = append(r.wraps, mw) + } else { + mws = append(mws, mw) + } + } + + r.Router.Use(mws...) + r.middlewares = append(r.middlewares, mws...) } -func (r *Wrapper) resolvePattern(pattern string) string { - return r.basePattern + strings.ReplaceAll(pattern, "/*/", "/") +// With adds inline middlewares for an endpoint handler. +func (r Wrapper) With(middlewares ...func(http.Handler) http.Handler) chi.Router { + var mws, ws []func(http.Handler) http.Handler + + for _, mw := range middlewares { + if nethttp.MiddlewareIsWrapper(mw) { + ws = append(ws, mw) + } else { + mws = append(mws, mw) + } + } + + c := r.copy(r.Router.With(mws...), "") + c.wraps = append(c.wraps, ws...) + + return c +} + +// Wrap appends one or more wrappers that will be applied to handler before adding to Router. +// It is different from middleware in the sense that it is handler-centric, rather than request-centric. +// Wraps are invoked once for each added handler, they are not invoked for http requests. +// Wraps can leverage nethttp.HandlerAs to inspect and access deeper layers. +// For most cases Wrap can be safely used instead of Use, Use is mandatory for middlewares +// that affect routing (such as middleware.StripSlashes for example). +func (r *Wrapper) Wrap(wraps ...func(handler http.Handler) http.Handler) { + r.wraps = append(r.wraps, wraps...) } func (r *Wrapper) captureHandler(h http.Handler) { nethttp.WrapHandler(h, r.middlewares...) } +func (r *Wrapper) copy(router chi.Router, pattern string) *Wrapper { + return &Wrapper{ + Router: router, + name: r.name, + basePattern: r.basePattern + pattern, + middlewares: r.middlewares, + wraps: r.wraps, + } +} + +func (r *Wrapper) handlersWithRoute() []http.Handler { + return r.handlers +} + +func (r *Wrapper) handlerWraps() []func(http.Handler) http.Handler { + return r.wraps +} + func (r *Wrapper) prepareHandler(method, pattern string, h http.Handler) http.Handler { mw := nethttp.HandlerWithRouteMiddleware(method, r.resolvePattern(pattern)) h = nethttp.WrapHandler(h, mw) @@ -225,10 +227,8 @@ func (r *Wrapper) prepareHandler(method, pattern string, h http.Handler) http.Ha return h } -func (r *Wrapper) handlersWithRoute() []http.Handler { - return r.handlers +func (r *Wrapper) resolvePattern(pattern string) string { + return r.basePattern + strings.ReplaceAll(pattern, "/*/", "/") } -func (r *Wrapper) handlerWraps() []func(http.Handler) http.Handler { - return r.wraps -} +var _ chi.Router = &Wrapper{} diff --git a/chirouter/wrapper_test.go b/chirouter/wrapper_test.go index 4c4e76a..b6445bc 100644 --- a/chirouter/wrapper_test.go +++ b/chirouter/wrapper_test.go @@ -22,34 +22,6 @@ import ( "github.com/swaggest/usecase" ) -type HandlerWithFoo struct { - http.Handler -} - -func (h HandlerWithFoo) Foo() {} - -type HandlerWithBar struct { - http.Handler -} - -func (h HandlerWithFoo) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - if _, err := rw.Write([]byte("foo")); err != nil { - panic(err) - } - - h.Handler.ServeHTTP(rw, r) -} - -func (h HandlerWithBar) Bar() {} - -func (h HandlerWithBar) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - h.Handler.ServeHTTP(rw, r) - - if _, err := rw.Write([]byte("bar")); err != nil { - panic(err) - } -} - func TestNewWrapper(t *testing.T) { w := chirouter.NewWrapper(chi.NewRouter()) r := w.With(func(handler http.Handler) http.Handler { @@ -131,6 +103,102 @@ func TestNewWrapper(t *testing.T) { assert.Equal(t, 22, totalCnt) } +func TestWrapper_Mount(t *testing.T) { + service := web.NewService(openapi3.NewReflector()) + service.OpenAPISchema().SetTitle("Security and Mount Example") + + apiV1 := web.NewService(openapi3.NewReflector()) + + apiV1.Wrap( + middleware.BasicAuth("Admin Access", map[string]string{"admin": "admin"}), + nethttp.HTTPBasicSecurityMiddleware(service.OpenAPICollector, "Admin", "Admin access"), + ) + + apiV1.Post("/sum", usecase.NewIOI(new([]int), new(int), func(_ context.Context, _, _ interface{}) error { + return errors.New("oops") + })) + + service.Mount("/api/v1", apiV1) + + // Blanket handler, for example to serve static content. + service.Mount("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, err := w.Write([]byte("blanket handler got a request: " + r.URL.String())) + assert.NoError(t, err) + })) + + req, err := http.NewRequest(http.MethodGet, "/foo", nil) + require.NoError(t, err) + + rw := httptest.NewRecorder() + service.ServeHTTP(rw, req) + + assert.Equal(t, "blanket handler got a request: /foo", rw.Body.String()) + + req, err = http.NewRequest(http.MethodPost, "/api/v1/sum", bytes.NewBufferString(`[1,2,3]`)) + require.NoError(t, err) + + rw = httptest.NewRecorder() + + service.ServeHTTP(rw, req) + assert.Equal(t, http.StatusUnauthorized, rw.Code) + + req.Header.Set("Authorization", "Basic YWRtaW46YWRtaW4=") + + rw = httptest.NewRecorder() + + service.ServeHTTP(rw, req) + assert.Equal(t, `{"error":"oops"}`+"\n", rw.Body.String()) + + assertjson.EqualMarshal(t, []byte(`{ + "openapi":"3.0.3","info":{"title":"Security and Mount Example","version":""}, + "paths":{ + "/api/v1/sum":{ + "post":{ + "summary":"Test Wrapper _ Mount", + "operationId":"rest/chirouter_test.TestWrapper_Mount", + "requestBody":{ + "content":{ + "application/json":{ + "schema":{"type":"array","items":{"type":"integer"},"nullable":true} + } + } + }, + "responses":{ + "200":{ + "description":"OK", + "content":{"application/json":{"schema":{"type":"integer"}}} + }, + "401":{ + "description":"Unauthorized", + "content":{ + "application/json":{"schema":{"$ref":"#/components/schemas/RestErrResponse"}} + } + } + }, + "security":[{"Admin":[]}] + } + } + }, + "components":{ + "schemas":{ + "RestErrResponse":{ + "type":"object", + "properties":{ + "code":{"type":"integer","description":"Application-specific error code."}, + "context":{ + "type":"object","additionalProperties":{}, + "description":"Application context." + }, + "error":{"type":"string","description":"Error message."}, + "status":{"type":"string","description":"Status text."} + } + } + }, + "securitySchemes":{"Admin":{"type":"http","scheme":"basic","description":"Admin access"}} + } + }`), service.OpenAPISchema()) +} + func TestWrapper_Use_precedence(t *testing.T) { var log []string @@ -211,7 +279,6 @@ func TestWrapper_Use_precedence(t *testing.T) { // after route match. // For the use case of StripSlashes that would result in not found, because middleware was // invoked AFTER route matching, not BEFORE. - // Solution to this problem was passing middlewares to Router as is, the problem however is // that Router does not allow unwrapping handlers (that is the purpose of Wrapper) to introspect // or augment handlers. @@ -273,102 +340,6 @@ func TestWrapper_Use_StripSlashes(t *testing.T) { }, log) } -func TestWrapper_Mount(t *testing.T) { - service := web.NewService(openapi3.NewReflector()) - service.OpenAPISchema().SetTitle("Security and Mount Example") - - apiV1 := web.NewService(openapi3.NewReflector()) - - apiV1.Wrap( - middleware.BasicAuth("Admin Access", map[string]string{"admin": "admin"}), - nethttp.HTTPBasicSecurityMiddleware(service.OpenAPICollector, "Admin", "Admin access"), - ) - - apiV1.Post("/sum", usecase.NewIOI(new([]int), new(int), func(_ context.Context, _, _ interface{}) error { - return errors.New("oops") - })) - - service.Mount("/api/v1", apiV1) - - // Blanket handler, for example to serve static content. - service.Mount("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - _, err := w.Write([]byte("blanket handler got a request: " + r.URL.String())) - assert.NoError(t, err) - })) - - req, err := http.NewRequest(http.MethodGet, "/foo", nil) - require.NoError(t, err) - - rw := httptest.NewRecorder() - service.ServeHTTP(rw, req) - - assert.Equal(t, "blanket handler got a request: /foo", rw.Body.String()) - - req, err = http.NewRequest(http.MethodPost, "/api/v1/sum", bytes.NewBufferString(`[1,2,3]`)) - require.NoError(t, err) - - rw = httptest.NewRecorder() - - service.ServeHTTP(rw, req) - assert.Equal(t, http.StatusUnauthorized, rw.Code) - - req.Header.Set("Authorization", "Basic YWRtaW46YWRtaW4=") - - rw = httptest.NewRecorder() - - service.ServeHTTP(rw, req) - assert.Equal(t, `{"error":"oops"}`+"\n", rw.Body.String()) - - assertjson.EqualMarshal(t, []byte(`{ - "openapi":"3.0.3","info":{"title":"Security and Mount Example","version":""}, - "paths":{ - "/api/v1/sum":{ - "post":{ - "summary":"Test Wrapper _ Mount", - "operationId":"rest/chirouter_test.TestWrapper_Mount", - "requestBody":{ - "content":{ - "application/json":{ - "schema":{"type":"array","items":{"type":"integer"},"nullable":true} - } - } - }, - "responses":{ - "200":{ - "description":"OK", - "content":{"application/json":{"schema":{"type":"integer"}}} - }, - "401":{ - "description":"Unauthorized", - "content":{ - "application/json":{"schema":{"$ref":"#/components/schemas/RestErrResponse"}} - } - } - }, - "security":[{"Admin":[]}] - } - } - }, - "components":{ - "schemas":{ - "RestErrResponse":{ - "type":"object", - "properties":{ - "code":{"type":"integer","description":"Application-specific error code."}, - "context":{ - "type":"object","additionalProperties":{}, - "description":"Application context." - }, - "error":{"type":"string","description":"Error message."}, - "status":{"type":"string","description":"Status text."} - } - } - }, - "securitySchemes":{"Admin":{"type":"http","scheme":"basic","description":"Admin access"}} - } - }`), service.OpenAPISchema()) -} - func TestWrapper_With(t *testing.T) { wrapperCalled := 0 wrapperFound := 0 @@ -401,3 +372,31 @@ func TestWrapper_With(t *testing.T) { assert.Equal(t, 2, wrapperFound) assert.Equal(t, 5, notWrapperCalled) // 2 wrapper checks, 2 chi chains, 1 capture handler. } + +type HandlerWithBar struct { + http.Handler +} + +func (h HandlerWithBar) Bar() {} + +func (h HandlerWithBar) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + h.Handler.ServeHTTP(rw, r) + + if _, err := rw.Write([]byte("bar")); err != nil { + panic(err) + } +} + +type HandlerWithFoo struct { + http.Handler +} + +func (h HandlerWithFoo) Foo() {} + +func (h HandlerWithFoo) ServeHTTP(rw http.ResponseWriter, r *http.Request) { + if _, err := rw.Write([]byte("foo")); err != nil { + panic(err) + } + + h.Handler.ServeHTTP(rw, r) +} diff --git a/dev_test.go b/dev_test.go index 0cf488a..5dd5047 100644 --- a/dev_test.go +++ b/dev_test.go @@ -1,3 +1,3 @@ package rest_test -import _ "github.com/bool64/dev" // Include CI/Dev scripts to project. +import _ "github.com/bool64/dev" diff --git a/error.go b/error.go index 53c4998..3bb9fd6 100644 --- a/error.go +++ b/error.go @@ -7,43 +7,6 @@ import ( "github.com/swaggest/usecase/status" ) -// HTTPCodeAsError exposes HTTP status code as use case error that can be translated to response status. -type HTTPCodeAsError int - -// Error return HTTP status text. -func (c HTTPCodeAsError) Error() string { - return http.StatusText(int(c)) -} - -// HTTPStatus returns HTTP status code. -func (c HTTPCodeAsError) HTTPStatus() int { - return int(c) -} - -// ErrWithHTTPStatus exposes HTTP status code. -type ErrWithHTTPStatus interface { - error - HTTPStatus() int -} - -// ErrWithFields exposes structured context of error. -type ErrWithFields interface { - error - Fields() map[string]interface{} -} - -// ErrWithAppCode exposes application error code. -type ErrWithAppCode interface { - error - AppErrCode() int -} - -// ErrWithCanonicalStatus exposes canonical status code. -type ErrWithCanonicalStatus interface { - error - Status() status.Code -} - // Err creates HTTP status code and ErrResponse for error. // // You can use it with use case status code: @@ -92,31 +55,6 @@ func Err(err error) (int, ErrResponse) { return er.httpStatusCode, er } -// ErrResponse is HTTP error response body. -type ErrResponse struct { - StatusText string `json:"status,omitempty" description:"Status text."` - AppCode int `json:"code,omitempty" description:"Application-specific error code."` - ErrorText string `json:"error,omitempty" description:"Error message."` - Context map[string]interface{} `json:"context,omitempty" description:"Application context."` - - err error // Original error. - httpStatusCode int // HTTP response status code. -} - -// Error implements error. -func (e ErrResponse) Error() string { - if e.ErrorText != "" { - return e.ErrorText - } - - return e.StatusText -} - -// Unwrap returns parent error. -func (e ErrResponse) Unwrap() error { - return e.err -} - // HTTPStatusFromCanonicalCode returns http status accordingly to use case status code. func HTTPStatusFromCanonicalCode(c status.Code) int { switch c { @@ -159,3 +97,65 @@ func HTTPStatusFromCanonicalCode(c status.Code) int { return http.StatusInternalServerError } + +// ErrResponse is HTTP error response body. +type ErrResponse struct { + StatusText string `json:"status,omitempty" description:"Status text."` + AppCode int `json:"code,omitempty" description:"Application-specific error code."` + ErrorText string `json:"error,omitempty" description:"Error message."` + Context map[string]interface{} `json:"context,omitempty" description:"Application context."` + + err error // Original error. + httpStatusCode int // HTTP response status code. +} + +// Error implements error. +func (e ErrResponse) Error() string { + if e.ErrorText != "" { + return e.ErrorText + } + + return e.StatusText +} + +// Unwrap returns parent error. +func (e ErrResponse) Unwrap() error { + return e.err +} + +// ErrWithAppCode exposes application error code. +type ErrWithAppCode interface { + error + AppErrCode() int +} + +// ErrWithCanonicalStatus exposes canonical status code. +type ErrWithCanonicalStatus interface { + error + Status() status.Code +} + +// ErrWithFields exposes structured context of error. +type ErrWithFields interface { + error + Fields() map[string]interface{} +} + +// ErrWithHTTPStatus exposes HTTP status code. +type ErrWithHTTPStatus interface { + error + HTTPStatus() int +} + +// HTTPCodeAsError exposes HTTP status code as use case error that can be translated to response status. +type HTTPCodeAsError int + +// Error return HTTP status text. +func (c HTTPCodeAsError) Error() string { + return http.StatusText(int(c)) +} + +// HTTPStatus returns HTTP status code. +func (c HTTPCodeAsError) HTTPStatus() int { + return int(c) +} diff --git a/error_test.go b/error_test.go index 512667d..b499018 100644 --- a/error_test.go +++ b/error_test.go @@ -12,24 +12,6 @@ import ( "github.com/swaggest/usecase/status" ) -func TestHTTPStatusFromCanonicalCode(t *testing.T) { - maxStatusCode := 17 - for i := 0; i <= maxStatusCode; i++ { - s := status.Code(i) - assert.NotEmpty(t, rest.HTTPStatusFromCanonicalCode(s)) - } -} - -type errWithHTTPStatus int - -func (e errWithHTTPStatus) Error() string { - return "failed very much" -} - -func (e errWithHTTPStatus) HTTPStatus() int { - return int(e) -} - func TestErr(t *testing.T) { err := usecase.Error{ StatusCode: status.InvalidArgument, @@ -88,3 +70,21 @@ func TestErr(t *testing.T) { assert.NoError(t, er) }) } + +func TestHTTPStatusFromCanonicalCode(t *testing.T) { + maxStatusCode := 17 + for i := 0; i <= maxStatusCode; i++ { + s := status.Code(i) + assert.NotEmpty(t, rest.HTTPStatusFromCanonicalCode(s)) + } +} + +type errWithHTTPStatus int + +func (e errWithHTTPStatus) Error() string { + return "failed very much" +} + +func (e errWithHTTPStatus) HTTPStatus() int { + return int(e) +} diff --git a/gorillamux/collector.go b/gorillamux/collector.go index d72509c..c056c91 100644 --- a/gorillamux/collector.go +++ b/gorillamux/collector.go @@ -11,6 +11,19 @@ import ( "github.com/swaggest/rest/openapi" ) +// NewOpenAPICollector creates route walker for gorilla/mux, that collects OpenAPI operations. +func NewOpenAPICollector(r oapi.Reflector) *OpenAPICollector { + c := openapi.NewCollector(r) + + return &OpenAPICollector{ + Collector: c, + DefaultMethods: []string{ + http.MethodHead, http.MethodGet, http.MethodPost, + http.MethodPut, http.MethodPatch, http.MethodDelete, + }, + } +} + // OpenAPICollector is a wrapper for openapi.Collector tailored to walk gorilla/mux router. type OpenAPICollector struct { // Collector is an actual OpenAPI collector. @@ -28,26 +41,6 @@ type OpenAPICollector struct { Host string } -// NewOpenAPICollector creates route walker for gorilla/mux, that collects OpenAPI operations. -func NewOpenAPICollector(r oapi.Reflector) *OpenAPICollector { - c := openapi.NewCollector(r) - - return &OpenAPICollector{ - Collector: c, - DefaultMethods: []string{ - http.MethodHead, http.MethodGet, http.MethodPost, - http.MethodPut, http.MethodPatch, http.MethodDelete, - }, - } -} - -// OpenAPIPreparer defines http.Handler with OpenAPI information. -type OpenAPIPreparer interface { - SetupOpenAPIOperation(oc oapi.OperationContext) error -} - -type preparerFunc func(oc oapi.OperationContext) error - // Walker walks route tree and collects OpenAPI information. func (dc *OpenAPICollector) Walker(route *mux.Route, _ *mux.Router, _ []*mux.Route) error { handler := route.GetHandler() @@ -133,3 +126,10 @@ func (dc *OpenAPICollector) collect(method, path string, preparer preparerFunc) return nil } } + +// OpenAPIPreparer defines http.Handler with OpenAPI information. +type OpenAPIPreparer interface { + SetupOpenAPIOperation(oc oapi.OperationContext) error +} + +type preparerFunc func(oc oapi.OperationContext) error diff --git a/gorillamux/collector_test.go b/gorillamux/collector_test.go index a61a48a..a20a278 100644 --- a/gorillamux/collector_test.go +++ b/gorillamux/collector_test.go @@ -13,28 +13,6 @@ import ( "github.com/swaggest/usecase" ) -type structuredHandler struct { - usecase.Info - usecase.WithInput - usecase.WithOutput -} - -func (s structuredHandler) SetupOpenAPIOperation(oc openapi.OperationContext) error { - oc.AddReqStructure(s.Input) - oc.AddRespStructure(s.Output) - - return nil -} - -func newStructuredHandler(setup func(h *structuredHandler)) structuredHandler { - h := structuredHandler{} - setup(&h) - - return h -} - -func (s structuredHandler) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {} - func TestOpenAPICollector_Walker(t *testing.T) { r := mux.NewRouter() @@ -233,3 +211,25 @@ func TestOpenAPICollector_Walker(t *testing.T) { } }`, rf.Spec) } + +func newStructuredHandler(setup func(h *structuredHandler)) structuredHandler { + h := structuredHandler{} + setup(&h) + + return h +} + +type structuredHandler struct { + usecase.Info + usecase.WithInput + usecase.WithOutput +} + +func (s structuredHandler) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {} + +func (s structuredHandler) SetupOpenAPIOperation(oc openapi.OperationContext) error { + oc.AddReqStructure(s.Input) + oc.AddRespStructure(s.Output) + + return nil +} diff --git a/gorillamux/example_openapi_collector_test.go b/gorillamux/example_openapi_collector_test.go index 1078466..1ac525e 100644 --- a/gorillamux/example_openapi_collector_test.go +++ b/gorillamux/example_openapi_collector_test.go @@ -15,76 +15,6 @@ import ( "github.com/swaggest/rest/request" ) -// Define request structure for your HTTP handler. -type myRequest struct { - Query1 int `query:"query1"` - Path1 string `path:"path1"` - Path2 int `path:"path2"` - Header1 float64 `header:"X-Header-1"` - FormData1 bool `formData:"formData1"` - FormData2 string `formData:"formData2"` -} - -type myResp struct { - Sum float64 `json:"sum"` - Concat string `json:"concat"` -} - -func newMyHandler() *myHandler { - decoderFactory := request.NewDecoderFactory() - decoderFactory.ApplyDefaults = true - decoderFactory.SetDecoderFunc(rest.ParamInPath, gorillamux.PathToURLValues) - - return &myHandler{ - dec: decoderFactory.MakeDecoder(http.MethodPost, myRequest{}, nil), - } -} - -type myHandler struct { - // Automated request decoding is not required to collect OpenAPI schema, - // but it is good to have to establish a single source of truth and to simplify request reading. - dec nethttp.RequestDecoder -} - -func (m *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var in myRequest - - if err := m.dec.Decode(r, &in, nil); err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - - return - } - - // Serve request. - out := myResp{ - Sum: in.Header1 + float64(in.Path2) + float64(in.Query1), - Concat: in.Path1 + in.FormData2 + strconv.FormatBool(in.FormData1), - } - - j, err := json.Marshal(out) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - - return - } - - _, _ = w.Write(j) -} - -// SetupOpenAPIOperation declares OpenAPI schema for the handler. -func (m *myHandler) SetupOpenAPIOperation(oc openapi.OperationContext) error { - oc.SetTags("My Tag") - oc.SetSummary("My Summary") - oc.SetDescription("This endpoint aggregates request in structured way.") - - oc.AddReqStructure(myRequest{}) - oc.AddRespStructure(myResp{}) - oc.AddRespStructure(nil, openapi.WithContentType("text/html"), openapi.WithHTTPStatus(http.StatusBadRequest)) - oc.AddRespStructure(nil, openapi.WithContentType("text/html"), openapi.WithHTTPStatus(http.StatusInternalServerError)) - - return nil -} - func ExampleNewOpenAPICollector() { // Your router does not need special instrumentation. router := mux.NewRouter() @@ -193,3 +123,73 @@ func ExampleNewOpenAPICollector() { // type: number // type: object } + +func newMyHandler() *myHandler { + decoderFactory := request.NewDecoderFactory() + decoderFactory.ApplyDefaults = true + decoderFactory.SetDecoderFunc(rest.ParamInPath, gorillamux.PathToURLValues) + + return &myHandler{ + dec: decoderFactory.MakeDecoder(http.MethodPost, myRequest{}, nil), + } +} + +type myHandler struct { + // Automated request decoding is not required to collect OpenAPI schema, + // but it is good to have to establish a single source of truth and to simplify request reading. + dec nethttp.RequestDecoder +} + +func (m *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + var in myRequest + + if err := m.dec.Decode(r, &in, nil); err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + + return + } + + // Serve request. + out := myResp{ + Sum: in.Header1 + float64(in.Path2) + float64(in.Query1), + Concat: in.Path1 + in.FormData2 + strconv.FormatBool(in.FormData1), + } + + j, err := json.Marshal(out) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + + return + } + + _, _ = w.Write(j) +} + +// SetupOpenAPIOperation declares OpenAPI schema for the handler. +func (m *myHandler) SetupOpenAPIOperation(oc openapi.OperationContext) error { + oc.SetTags("My Tag") + oc.SetSummary("My Summary") + oc.SetDescription("This endpoint aggregates request in structured way.") + + oc.AddReqStructure(myRequest{}) + oc.AddRespStructure(myResp{}) + oc.AddRespStructure(nil, openapi.WithContentType("text/html"), openapi.WithHTTPStatus(http.StatusBadRequest)) + oc.AddRespStructure(nil, openapi.WithContentType("text/html"), openapi.WithHTTPStatus(http.StatusInternalServerError)) + + return nil +} + +// Define request structure for your HTTP handler. +type myRequest struct { + Query1 int `query:"query1"` + Path1 string `path:"path1"` + Path2 int `path:"path2"` + Header1 float64 `header:"X-Header-1"` + FormData1 bool `formData:"formData1"` + FormData2 string `formData:"formData2"` +} + +type myResp struct { + Sum float64 `json:"sum"` + Concat string `json:"concat"` +} diff --git a/gzip/container.go b/gzip/container.go index 9c05a50..dc8c6ff 100644 --- a/gzip/container.go +++ b/gzip/container.go @@ -11,15 +11,47 @@ import ( "github.com/cespare/xxhash/v2" ) -// Writer writes gzip data into suitable stream or returns 0, nil. -type Writer interface { - GzipWrite(d []byte) (int, error) +// MarshalJSON encodes Go value as JSON and compresses result with gzip. +func MarshalJSON(v interface{}) ([]byte, error) { + b := bytes.Buffer{} + w := gzip.NewWriter(&b) + + enc := json.NewEncoder(w) + + err := enc.Encode(v) + if err != nil { + return nil, err + } + + err = w.Close() + if err != nil { + return nil, err + } + + // Copying result slice to reduce dynamic capacity. + res := make([]byte, len(b.Bytes())) + copy(res, b.Bytes()) + + return res, nil } -// JSONContainer contains compressed JSON. -type JSONContainer struct { - gz []byte - hash string +// UnmarshalJSON decodes compressed JSON bytes into a Go value. +func UnmarshalJSON(data []byte, v interface{}) error { + b := bytes.NewReader(data) + + r, err := gzip.NewReader(b) + if err != nil { + return err + } + + dec := json.NewDecoder(r) + + err = dec.Decode(v) + if err != nil { + return err + } + + return r.Close() } // WriteCompressedBytes writes compressed bytes to response. @@ -44,22 +76,15 @@ func WriteCompressedBytes(compressed []byte, w io.Writer) (int, error) { return int(n), err } -// UnpackJSON unmarshals data from JSON container into a Go value. -func (jc JSONContainer) UnpackJSON(v interface{}) error { - return UnmarshalJSON(jc.gz, v) +// JSONContainer contains compressed JSON. +type JSONContainer struct { + gz []byte + hash string } -// PackJSON puts Go value in JSON container. -func (jc *JSONContainer) PackJSON(v interface{}) error { - res, err := MarshalJSON(v) - if err != nil { - return err - } - - jc.gz = res - jc.hash = strconv.FormatUint(xxhash.Sum64(res), 36) - - return nil +// ETag returns hash of compressed bytes. +func (jc JSONContainer) ETag() string { + return jc.hash } // GzipCompressedJSON returns JSON compressed with gzip. @@ -67,6 +92,11 @@ func (jc JSONContainer) GzipCompressedJSON() []byte { return jc.gz } +// JSONWriteTo writes JSON payload to writer. +func (jc JSONContainer) JSONWriteTo(w io.Writer) (int, error) { + return WriteCompressedBytes(jc.gz, w) +} + // MarshalJSON returns uncompressed JSON. func (jc JSONContainer) MarshalJSON() (j []byte, err error) { b := bytes.NewReader(jc.gz) @@ -86,55 +116,25 @@ func (jc JSONContainer) MarshalJSON() (j []byte, err error) { return ioutil.ReadAll(r) } -// ETag returns hash of compressed bytes. -func (jc JSONContainer) ETag() string { - return jc.hash -} - -// MarshalJSON encodes Go value as JSON and compresses result with gzip. -func MarshalJSON(v interface{}) ([]byte, error) { - b := bytes.Buffer{} - w := gzip.NewWriter(&b) - - enc := json.NewEncoder(w) - - err := enc.Encode(v) - if err != nil { - return nil, err - } - - err = w.Close() - if err != nil { - return nil, err - } - - // Copying result slice to reduce dynamic capacity. - res := make([]byte, len(b.Bytes())) - copy(res, b.Bytes()) - - return res, nil -} - -// UnmarshalJSON decodes compressed JSON bytes into a Go value. -func UnmarshalJSON(data []byte, v interface{}) error { - b := bytes.NewReader(data) - - r, err := gzip.NewReader(b) +// PackJSON puts Go value in JSON container. +func (jc *JSONContainer) PackJSON(v interface{}) error { + res, err := MarshalJSON(v) if err != nil { return err } - dec := json.NewDecoder(r) + jc.gz = res + jc.hash = strconv.FormatUint(xxhash.Sum64(res), 36) - err = dec.Decode(v) - if err != nil { - return err - } + return nil +} - return r.Close() +// UnpackJSON unmarshals data from JSON container into a Go value. +func (jc JSONContainer) UnpackJSON(v interface{}) error { + return UnmarshalJSON(jc.gz, v) } -// JSONWriteTo writes JSON payload to writer. -func (jc JSONContainer) JSONWriteTo(w io.Writer) (int, error) { - return WriteCompressedBytes(jc.gz, w) +// Writer writes gzip data into suitable stream or returns 0, nil. +type Writer interface { + GzipWrite(d []byte) (int, error) } diff --git a/jsonschema/validator.go b/jsonschema/validator.go index 74e5d1a..23112f1 100644 --- a/jsonschema/validator.go +++ b/jsonschema/validator.go @@ -11,18 +11,6 @@ import ( "github.com/swaggest/rest" ) -var _ rest.Validator = &Validator{} - -// Validator is a JSON Schema based validator. -type Validator struct { - // JSONMarshal controls custom marshaler, nil value enables "encoding/json". - JSONMarshal func(interface{}) ([]byte, error) - - inNamedSchemas map[rest.ParamIn]map[string]*jsonschema.Schema - inRequired map[rest.ParamIn][]string - forbidUnknown map[rest.ParamIn]bool -} - // NewFactory creates new validator factory. func NewFactory( requestSchemas rest.RequestJSONSchemaProvider, @@ -88,13 +76,14 @@ func (f Factory) MakeResponseValidator( return &v } -// ForbidUnknownParams configures if unknown parameters should be forbidden. -func (v *Validator) ForbidUnknownParams(in rest.ParamIn, forbidden bool) { - if v.forbidUnknown == nil { - v.forbidUnknown = make(map[rest.ParamIn]bool) - } +// Validator is a JSON Schema based validator. +type Validator struct { + // JSONMarshal controls custom marshaler, nil value enables "encoding/json". + JSONMarshal func(interface{}) ([]byte, error) - v.forbidUnknown[in] = forbidden + inNamedSchemas map[rest.ParamIn]map[string]*jsonschema.Schema + inRequired map[rest.ParamIn][]string + forbidUnknown map[rest.ParamIn]bool } // AddSchema registers schema for validation. @@ -144,48 +133,13 @@ func (v *Validator) AddSchema(in rest.ParamIn, name string, jsonSchema []byte, r return nil } -func (v *Validator) checkRequired(in rest.ParamIn, namedData map[string]interface{}) []string { - required := v.inRequired[in] - - if len(required) == 0 { - return nil - } - - var missing []string - - for _, name := range v.inRequired[in] { - if _, ok := namedData[name]; !ok { - missing = append(missing, name) - } - } - - return missing -} - -// ValidateJSONBody performs validation of JSON body. -func (v *Validator) ValidateJSONBody(jsonBody []byte) error { - name := "body" - - schema, found := v.inNamedSchemas[rest.ParamInBody][name] - if !found || schema == nil { - return nil - } - - err := schema.Validate(bytes.NewBuffer(jsonBody)) - if err == nil { - return nil - } - - errs := make(rest.ValidationErrors, 1) - - //nolint:errorlint // Error is not wrapped, type assertion is more performant. - if ve, ok := err.(*jsonschema.ValidationError); ok { - errs[name] = appendError(errs[name], ve) - } else { - errs[name] = append(errs[name], err.Error()) +// ForbidUnknownParams configures if unknown parameters should be forbidden. +func (v *Validator) ForbidUnknownParams(in rest.ParamIn, forbidden bool) { + if v.forbidUnknown == nil { + v.forbidUnknown = make(map[rest.ParamIn]bool) } - return errs + v.forbidUnknown[in] = forbidden } // HasConstraints indicates if there are validation rules for parameter location. @@ -255,6 +209,52 @@ func (v *Validator) ValidateData(in rest.ParamIn, namedData map[string]interface return nil } +// ValidateJSONBody performs validation of JSON body. +func (v *Validator) ValidateJSONBody(jsonBody []byte) error { + name := "body" + + schema, found := v.inNamedSchemas[rest.ParamInBody][name] + if !found || schema == nil { + return nil + } + + err := schema.Validate(bytes.NewBuffer(jsonBody)) + if err == nil { + return nil + } + + errs := make(rest.ValidationErrors, 1) + + //nolint:errorlint // Error is not wrapped, type assertion is more performant. + if ve, ok := err.(*jsonschema.ValidationError); ok { + errs[name] = appendError(errs[name], ve) + } else { + errs[name] = append(errs[name], err.Error()) + } + + return errs +} + +func (v *Validator) checkRequired(in rest.ParamIn, namedData map[string]interface{}) []string { + required := v.inRequired[in] + + if len(required) == 0 { + return nil + } + + var missing []string + + for _, name := range v.inRequired[in] { + if _, ok := namedData[name]; !ok { + missing = append(missing, name) + } + } + + return missing +} + +var _ rest.Validator = &Validator{} + func appendError(errorMessages []string, err *jsonschema.ValidationError) []string { errorMessages = append(errorMessages, err.InstancePtr+": "+err.Message) for _, ec := range err.Causes { diff --git a/jsonschema/validator_test.go b/jsonschema/validator_test.go index 1b30880..6216302 100644 --- a/jsonschema/validator_test.go +++ b/jsonschema/validator_test.go @@ -36,6 +36,37 @@ func BenchmarkRequestValidator_ValidateRequestData(b *testing.B) { } } +func TestFactory_MakeResponseValidator(t *testing.T) { + validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). + MakeResponseValidator(http.StatusOK, "application/json", new(struct { + Name string `json:"name" minLength:"1"` + Trace string `maxLength:"3"` + }), map[string]string{ + "Trace": "x-TrAcE", + }) + + assert.NoError(t, validator.ValidateJSONBody([]byte(`{"name":"John"}`))) + assert.Error(t, validator.ValidateJSONBody([]byte(`{"name":""}`))) // minLength:"1" violated. + assert.NoError(t, validator.ValidateData(rest.ParamInHeader, map[string]interface{}{ + "X-Trace": "abc", + })) + assert.Error(t, validator.ValidateData(rest.ParamInHeader, map[string]interface{}{ + "X-Trace": "abcd", // maxLength:"3" violated. + })) +} + +func TestNullableTime(t *testing.T) { + type request struct { + ExpiryDate *time.Time `json:"expiryDate"` + } + + validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). + MakeRequestValidator(http.MethodPost, new(request), nil) + err := validator.ValidateJSONBody([]byte(`{"expiryDate":null}`)) + + assert.NoError(t, err, "%+v", err) +} + func TestRequestValidator_ValidateData(t *testing.T) { validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodPost, new(struct { @@ -71,37 +102,6 @@ func TestRequestValidator_ValidateData(t *testing.T) { assert.Equal(t, err, rest.ValidationErrors{"formData:inFormData": []string{"#: length must be >= 3, but got 2"}}) } -func TestFactory_MakeResponseValidator(t *testing.T) { - validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). - MakeResponseValidator(http.StatusOK, "application/json", new(struct { - Name string `json:"name" minLength:"1"` - Trace string `maxLength:"3"` - }), map[string]string{ - "Trace": "x-TrAcE", - }) - - assert.NoError(t, validator.ValidateJSONBody([]byte(`{"name":"John"}`))) - assert.Error(t, validator.ValidateJSONBody([]byte(`{"name":""}`))) // minLength:"1" violated. - assert.NoError(t, validator.ValidateData(rest.ParamInHeader, map[string]interface{}{ - "X-Trace": "abc", - })) - assert.Error(t, validator.ValidateData(rest.ParamInHeader, map[string]interface{}{ - "X-Trace": "abcd", // maxLength:"3" violated. - })) -} - -func TestNullableTime(t *testing.T) { - type request struct { - ExpiryDate *time.Time `json:"expiryDate"` - } - - validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). - MakeRequestValidator(http.MethodPost, new(request), nil) - err := validator.ValidateJSONBody([]byte(`{"expiryDate":null}`)) - - assert.NoError(t, err, "%+v", err) -} - func TestValidator_ForbidUnknownParams(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/?foo=bar&baz=1", nil) diff --git a/nethttp/handler.go b/nethttp/handler.go index 9b081b0..954c5be 100644 --- a/nethttp/handler.go +++ b/nethttp/handler.go @@ -11,7 +11,20 @@ import ( "github.com/swaggest/usecase/status" ) -var _ http.Handler = &Handler{} +// HandlerWithRouteMiddleware wraps handler with routing information. +func HandlerWithRouteMiddleware(method, pathPattern string) func(http.Handler) http.Handler { + return func(handler http.Handler) http.Handler { + if IsWrapperChecker(handler) { + return handler + } + + return handlerWithRoute{ + Handler: handler, + pathPattern: pathPattern, + method: method, + } + } +} // NewHandler creates use case http handler. func NewHandler(useCase usecase.Interactor, options ...func(h *Handler)) *Handler { @@ -33,19 +46,6 @@ func NewHandler(useCase usecase.Interactor, options ...func(h *Handler)) *Handle return h } -// UseCase returns use case interactor. -func (h *Handler) UseCase() usecase.Interactor { - return h.useCase -} - -// SetUseCase prepares handler for a use case. -func (h *Handler) SetUseCase(useCase usecase.Interactor) { - h.useCase = useCase - - h.setupInputBuffer() - h.setupOutputBuffer() -} - // Handler is a use case http handler with documentation and inputPort validation. // // Please use NewHandler to create instance. @@ -71,33 +71,6 @@ type Handler struct { responseEncoder ResponseEncoder } -// SetResponseEncoder sets response encoder. -func (h *Handler) SetResponseEncoder(responseEncoder ResponseEncoder) { - h.responseEncoder = responseEncoder - - h.setupOutputBuffer() -} - -// SetRequestDecoder sets request decoder. -func (h *Handler) SetRequestDecoder(requestDecoder RequestDecoder) { - h.requestDecoder = requestDecoder -} - -func (h *Handler) decodeRequest(r *http.Request) (interface{}, error) { - if h.requestDecoder == nil { - panic("request decoder is not initialized, please use SetRequestDecoder") - } - - iv := reflect.New(h.inputBufferType) - err := h.requestDecoder.Decode(r, iv.Interface(), h.ReqValidator) - - if !h.inputIsPtr { - return iv.Elem().Interface(), err - } - - return iv.Interface(), err -} - // ServeHTTP serves http inputPort with use case interactor. func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { var ( @@ -135,38 +108,45 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.responseEncoder.WriteSuccessfulResponse(w, r, output, h.HandlerTrait) } -func (h *Handler) handleErrResponseDefault(w http.ResponseWriter, r *http.Request, err error) { - var ( - code int - er interface{} - ) +// SetRequestDecoder sets request decoder. +func (h *Handler) SetRequestDecoder(requestDecoder RequestDecoder) { + h.requestDecoder = requestDecoder +} - if h.MakeErrResp != nil { - code, er = h.MakeErrResp(r.Context(), err) - } else { - code, er = rest.Err(err) - } +// SetResponseEncoder sets response encoder. +func (h *Handler) SetResponseEncoder(responseEncoder ResponseEncoder) { + h.responseEncoder = responseEncoder - h.responseEncoder.WriteErrResponse(w, r, code, er) + h.setupOutputBuffer() } -func (h *Handler) handleErrResponse(w http.ResponseWriter, r *http.Request, err error) { - if h.HandleErrResponse != nil { - h.HandleErrResponse(w, r, err) +// SetUseCase prepares handler for a use case. +func (h *Handler) SetUseCase(useCase usecase.Interactor) { + h.useCase = useCase - return - } + h.setupInputBuffer() + h.setupOutputBuffer() +} - h.handleErrResponseDefault(w, r, err) +// UseCase returns use case interactor. +func (h *Handler) UseCase() usecase.Interactor { + return h.useCase } -func closeMultipartForm(r *http.Request) { - if err := r.MultipartForm.RemoveAll(); err != nil { - log.Println(err) +func (h *Handler) decodeRequest(r *http.Request) (interface{}, error) { + if h.requestDecoder == nil { + panic("request decoder is not initialized, please use SetRequestDecoder") } -} -type decodeErrCtxKey struct{} + iv := reflect.New(h.inputBufferType) + err := h.requestDecoder.Decode(r, iv.Interface(), h.ReqValidator) + + if !h.inputIsPtr { + return iv.Elem().Interface(), err + } + + return iv.Interface(), err +} func (h *Handler) handleDecodeError(w http.ResponseWriter, r *http.Request, err error, input, output interface{}) { err = status.Wrap(err, status.InvalidArgument) @@ -178,6 +158,31 @@ func (h *Handler) handleDecodeError(w http.ResponseWriter, r *http.Request, err h.handleErrResponse(w, r, err) } +func (h *Handler) handleErrResponse(w http.ResponseWriter, r *http.Request, err error) { + if h.HandleErrResponse != nil { + h.HandleErrResponse(w, r, err) + + return + } + + h.handleErrResponseDefault(w, r, err) +} + +func (h *Handler) handleErrResponseDefault(w http.ResponseWriter, r *http.Request, err error) { + var ( + code int + er interface{} + ) + + if h.MakeErrResp != nil { + code, er = h.MakeErrResp(r.Context(), err) + } else { + code, er = rest.Err(err) + } + + h.responseEncoder.WriteErrResponse(w, r, code, er) +} + func (h *Handler) setupInputBuffer() { h.inputBufferType = nil @@ -212,35 +217,6 @@ func (h *Handler) setupOutputBuffer() { } } -type handlerWithRoute struct { - http.Handler - method string - pathPattern string -} - -func (h handlerWithRoute) RouteMethod() string { - return h.method -} - -func (h handlerWithRoute) RoutePattern() string { - return h.pathPattern -} - -// HandlerWithRouteMiddleware wraps handler with routing information. -func HandlerWithRouteMiddleware(method, pathPattern string) func(http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { - if IsWrapperChecker(handler) { - return handler - } - - return handlerWithRoute{ - Handler: handler, - pathPattern: pathPattern, - method: method, - } - } -} - // RequestDecoder maps data from http.Request into structured Go input value. type RequestDecoder interface { // Decode fills input with data from request, input should be a pointer. @@ -259,3 +235,27 @@ type ResponseEncoder interface { SetupOutput(output interface{}, ht *rest.HandlerTrait) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interface{} } + +var _ http.Handler = &Handler{} + +func closeMultipartForm(r *http.Request) { + if err := r.MultipartForm.RemoveAll(); err != nil { + log.Println(err) + } +} + +type decodeErrCtxKey struct{} + +type handlerWithRoute struct { + http.Handler + method string + pathPattern string +} + +func (h handlerWithRoute) RouteMethod() string { + return h.method +} + +func (h handlerWithRoute) RoutePattern() string { + return h.pathPattern +} diff --git a/nethttp/handler_test.go b/nethttp/handler_test.go index feab71b..311a8ab 100644 --- a/nethttp/handler_test.go +++ b/nethttp/handler_test.go @@ -17,14 +17,6 @@ import ( "github.com/swaggest/usecase" ) -type Input struct { - ID int -} - -type Output struct { - Value string `json:"value"` -} - func TestHandler_ServeHTTP(t *testing.T) { u := &struct { usecase.Interactor @@ -108,6 +100,82 @@ func TestHandler_ServeHTTP(t *testing.T) { assert.True(t, umwCalled) } +func TestHandler_ServeHTTP_customErrResp(t *testing.T) { + u := struct { + usecase.Interactor + usecase.OutputWithNoContent + }{} + + u.Interactor = usecase.Interact(func(_ context.Context, input, output interface{}) error { + assert.Nil(t, input) + assert.Nil(t, output) + + return errors.New("use case failed") + }) + + h := nethttp.NewHandler(u) + h.MakeErrResp = func(_ context.Context, err error) (int, interface{}) { + return http.StatusExpectationFailed, struct { + Custom string `json:"custom"` + }{ + Custom: err.Error(), + } + } + h.SetResponseEncoder(&response.Encoder{}) + + req, err := http.NewRequest(http.MethodGet, "/test", nil) + require.NoError(t, err) + + rw := httptest.NewRecorder() + h.ServeHTTP(rw, req) + + assert.Equal(t, http.StatusExpectationFailed, rw.Code) + assert.Equal(t, `{"custom":"use case failed"}`+"\n", rw.Body.String()) +} + +func TestHandler_ServeHTTP_customMapping(t *testing.T) { + u := &struct { + usecase.Interactor + usecase.WithInput + usecase.OutputWithNoContent + }{} + + u.Input = new(Input) + u.Interactor = usecase.Interact(func(_ context.Context, input, _ interface{}) error { + in, ok := input.(*Input) + assert.True(t, ok) + assert.Equal(t, 123, in.ID) + + return nil + }) + + uh := nethttp.NewHandler(u) + uh.ReqMapping = rest.RequestMapping{ + rest.ParamInQuery: map[string]string{"ID": "ident"}, + } + + ws := []func(handler http.Handler) http.Handler{ + request.DecoderMiddleware(request.NewDecoderFactory()), + nethttp.HandlerWithRouteMiddleware(http.MethodGet, "/test"), + response.EncoderMiddleware, + } + + h := nethttp.WrapHandler(uh, ws...) + + for i, w := range ws { + assert.True(t, nethttp.MiddlewareIsWrapper(w), i) + } + + req, err := http.NewRequest(http.MethodGet, "/test?ident=123", nil) + require.NoError(t, err) + + rw := httptest.NewRecorder() + h.ServeHTTP(rw, req) + + assert.Equal(t, http.StatusNoContent, rw.Code) + assert.Equal(t, "", rw.Body.String()) +} + func TestHandler_ServeHTTP_decodeErr(t *testing.T) { u := &struct { usecase.Interactor @@ -178,63 +246,6 @@ func TestHandler_ServeHTTP_emptyPorts(t *testing.T) { assert.Equal(t, "", rw.Body.String()) } -func TestHandler_ServeHTTP_customErrResp(t *testing.T) { - u := struct { - usecase.Interactor - usecase.OutputWithNoContent - }{} - - u.Interactor = usecase.Interact(func(_ context.Context, input, output interface{}) error { - assert.Nil(t, input) - assert.Nil(t, output) - - return errors.New("use case failed") - }) - - h := nethttp.NewHandler(u) - h.MakeErrResp = func(_ context.Context, err error) (int, interface{}) { - return http.StatusExpectationFailed, struct { - Custom string `json:"custom"` - }{ - Custom: err.Error(), - } - } - h.SetResponseEncoder(&response.Encoder{}) - - req, err := http.NewRequest(http.MethodGet, "/test", nil) - require.NoError(t, err) - - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) - - assert.Equal(t, http.StatusExpectationFailed, rw.Code) - assert.Equal(t, `{"custom":"use case failed"}`+"\n", rw.Body.String()) -} - -func TestHandlerWithRouteMiddleware(t *testing.T) { - called := false - - var h http.Handler - h = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { - called = true - }) - - h = nethttp.HandlerWithRouteMiddleware(http.MethodPost, "/test/")(h) - hr, ok := h.(rest.HandlerWithRoute) - require.True(t, ok) - assert.Equal(t, http.MethodPost, hr.RouteMethod()) - assert.Equal(t, "/test/", hr.RoutePattern()) - - h.ServeHTTP(nil, nil) - assert.True(t, called) -} - -type reqWithBody struct { - ID int `json:"id"` -} - -func (*reqWithBody) ForceRequestBody() {} - func TestHandler_ServeHTTP_getWithBody(t *testing.T) { u := struct { usecase.Interactor @@ -267,47 +278,22 @@ func TestHandler_ServeHTTP_getWithBody(t *testing.T) { assert.Equal(t, ``, rw.Body.String()) } -func TestHandler_ServeHTTP_customMapping(t *testing.T) { - u := &struct { - usecase.Interactor - usecase.WithInput - usecase.OutputWithNoContent - }{} - - u.Input = new(Input) - u.Interactor = usecase.Interact(func(_ context.Context, input, _ interface{}) error { - in, ok := input.(*Input) - assert.True(t, ok) - assert.Equal(t, 123, in.ID) +func TestHandlerWithRouteMiddleware(t *testing.T) { + called := false - return nil + var h http.Handler + h = http.HandlerFunc(func(_ http.ResponseWriter, _ *http.Request) { + called = true }) - uh := nethttp.NewHandler(u) - uh.ReqMapping = rest.RequestMapping{ - rest.ParamInQuery: map[string]string{"ID": "ident"}, - } - - ws := []func(handler http.Handler) http.Handler{ - request.DecoderMiddleware(request.NewDecoderFactory()), - nethttp.HandlerWithRouteMiddleware(http.MethodGet, "/test"), - response.EncoderMiddleware, - } - - h := nethttp.WrapHandler(uh, ws...) - - for i, w := range ws { - assert.True(t, nethttp.MiddlewareIsWrapper(w), i) - } - - req, err := http.NewRequest(http.MethodGet, "/test?ident=123", nil) - require.NoError(t, err) - - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + h = nethttp.HandlerWithRouteMiddleware(http.MethodPost, "/test/")(h) + hr, ok := h.(rest.HandlerWithRoute) + require.True(t, ok) + assert.Equal(t, http.MethodPost, hr.RouteMethod()) + assert.Equal(t, "/test/", hr.RoutePattern()) - assert.Equal(t, http.StatusNoContent, rw.Code) - assert.Equal(t, "", rw.Body.String()) + h.ServeHTTP(nil, nil) + assert.True(t, called) } func TestOptionsMiddleware(t *testing.T) { @@ -344,3 +330,17 @@ func TestOptionsMiddleware(t *testing.T) { assert.EqualError(t, loggedErr, "failed") assert.Equal(t, `{"foo":"failed"}`+"\n", rw.Body.String()) } + +type Input struct { + ID int +} + +type Output struct { + Value string `json:"value"` +} + +type reqWithBody struct { + ID int `json:"id"` +} + +func (*reqWithBody) ForceRequestBody() {} diff --git a/nethttp/openapi.go b/nethttp/openapi.go index 1c083ec..ce5e331 100644 --- a/nethttp/openapi.go +++ b/nethttp/openapi.go @@ -9,79 +9,49 @@ import ( "github.com/swaggest/rest/openapi" ) -// OpenAPIMiddleware reads info and adds validation to handler. -func OpenAPIMiddleware(s *openapi.Collector) func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - if IsWrapperChecker(h) { - return h - } - - var ( - withRoute rest.HandlerWithRoute - handler *Handler - ) - - if !HandlerAs(h, &withRoute) || !HandlerAs(h, &handler) { - return h +// AnnotateOpenAPI applies OpenAPI annotation to relevant handlers. +// +// Deprecated: use OpenAPIAnnotationsMiddleware. +func AnnotateOpenAPI( + s *openapi.Collector, + setup ...func(op *openapi3.Operation) error, +) func(http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + if IsWrapperChecker(next) { + return next } - var methods []string - - method := withRoute.RouteMethod() - - if method == "" { - methods = []string{"get", "put", "post", "delete", "options", "head", "patch", "trace"} - } else { - methods = []string{method} - } + var withRoute rest.HandlerWithRoute - for _, m := range methods { - err := s.CollectUseCase( - m, + if HandlerAs(next, &withRoute) { + s.Annotate( + withRoute.RouteMethod(), withRoute.RoutePattern(), - handler.UseCase(), - handler.HandlerTrait, + setup..., ) - if err != nil { - panic(err) - } } - return h + return next } } -// AuthMiddleware creates middleware to expose security scheme. -func AuthMiddleware( +// APIKeySecurityMiddleware creates middleware to expose API Key security schema. +func APIKeySecurityMiddleware( c *openapi.Collector, - name string, + name string, fieldName string, fieldIn oapi.In, description string, options ...func(*MiddlewareConfig), ) func(http.Handler) http.Handler { - cfg := MiddlewareConfig{} - - for _, o := range options { - o(&cfg) - } + c.SpecSchema().SetAPIKeySecurity(name, fieldName, fieldIn, description) - return securityMiddleware(c, name, cfg) + return AuthMiddleware(c, name, options...) } -// SecurityMiddleware creates middleware to expose security scheme. -// -// Deprecated: use AuthMiddleware. -func SecurityMiddleware( +// AuthMiddleware creates middleware to expose security scheme. +func AuthMiddleware( c *openapi.Collector, name string, - scheme openapi3.SecurityScheme, options ...func(*MiddlewareConfig), ) func(http.Handler) http.Handler { - c.Reflector().SpecEns().ComponentsEns().SecuritySchemesEns().WithMapOfSecuritySchemeOrRefValuesItem( - name, - openapi3.SecuritySchemeOrRef{ - SecurityScheme: &scheme, - }, - ) - cfg := MiddlewareConfig{} for _, o := range options { @@ -91,17 +61,6 @@ func SecurityMiddleware( return securityMiddleware(c, name, cfg) } -// APIKeySecurityMiddleware creates middleware to expose API Key security schema. -func APIKeySecurityMiddleware( - c *openapi.Collector, - name string, fieldName string, fieldIn oapi.In, description string, - options ...func(*MiddlewareConfig), -) func(http.Handler) http.Handler { - c.SpecSchema().SetAPIKeySecurity(name, fieldName, fieldIn, description) - - return AuthMiddleware(c, name, options...) -} - // HTTPBasicSecurityMiddleware creates middleware to expose HTTP Basic security schema. func HTTPBasicSecurityMiddleware( c *openapi.Collector, @@ -124,12 +83,10 @@ func HTTPBearerSecurityMiddleware( return AuthMiddleware(c, name, options...) } -// AnnotateOpenAPI applies OpenAPI annotation to relevant handlers. -// -// Deprecated: use OpenAPIAnnotationsMiddleware. -func AnnotateOpenAPI( +// OpenAPIAnnotationsMiddleware applies OpenAPI annotations to handlers. +func OpenAPIAnnotationsMiddleware( s *openapi.Collector, - setup ...func(op *openapi3.Operation) error, + annotations ...func(oc oapi.OperationContext) error, ) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { if IsWrapperChecker(next) { @@ -139,10 +96,13 @@ func AnnotateOpenAPI( var withRoute rest.HandlerWithRoute if HandlerAs(next, &withRoute) { - s.Annotate( - withRoute.RouteMethod(), - withRoute.RoutePattern(), - setup..., + method := withRoute.RouteMethod() + pattern := withRoute.RoutePattern() + + s.AnnotateOperation( + method, + pattern, + annotations..., ) } @@ -150,6 +110,73 @@ func AnnotateOpenAPI( } } +// OpenAPIMiddleware reads info and adds validation to handler. +func OpenAPIMiddleware(s *openapi.Collector) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + if IsWrapperChecker(h) { + return h + } + + var ( + withRoute rest.HandlerWithRoute + handler *Handler + ) + + if !HandlerAs(h, &withRoute) || !HandlerAs(h, &handler) { + return h + } + + var methods []string + + method := withRoute.RouteMethod() + + if method == "" { + methods = []string{"get", "put", "post", "delete", "options", "head", "patch", "trace"} + } else { + methods = []string{method} + } + + for _, m := range methods { + err := s.CollectUseCase( + m, + withRoute.RoutePattern(), + handler.UseCase(), + handler.HandlerTrait, + ) + if err != nil { + panic(err) + } + } + + return h + } +} + +// SecurityMiddleware creates middleware to expose security scheme. +// +// Deprecated: use AuthMiddleware. +func SecurityMiddleware( + c *openapi.Collector, + name string, + scheme openapi3.SecurityScheme, + options ...func(*MiddlewareConfig), +) func(http.Handler) http.Handler { + c.Reflector().SpecEns().ComponentsEns().SecuritySchemesEns().WithMapOfSecuritySchemeOrRefValuesItem( + name, + openapi3.SecuritySchemeOrRef{ + SecurityScheme: &scheme, + }, + ) + + cfg := MiddlewareConfig{} + + for _, o := range options { + o(&cfg) + } + + return securityMiddleware(c, name, cfg) +} + // SecurityResponse is a security middleware option to customize response structure and status. func SecurityResponse(structure interface{}, httpStatus int) func(config *MiddlewareConfig) { return func(config *MiddlewareConfig) { @@ -167,33 +194,6 @@ type MiddlewareConfig struct { ResponseStatus int } -// OpenAPIAnnotationsMiddleware applies OpenAPI annotations to handlers. -func OpenAPIAnnotationsMiddleware( - s *openapi.Collector, - annotations ...func(oc oapi.OperationContext) error, -) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - if IsWrapperChecker(next) { - return next - } - - var withRoute rest.HandlerWithRoute - - if HandlerAs(next, &withRoute) { - method := withRoute.RouteMethod() - pattern := withRoute.RoutePattern() - - s.AnnotateOperation( - method, - pattern, - annotations..., - ) - } - - return next - } -} - func securityMiddleware(s *openapi.Collector, name string, cfg MiddlewareConfig) func(http.Handler) http.Handler { return OpenAPIAnnotationsMiddleware(s, func(oc oapi.OperationContext) error { oc.AddSecurity(name) diff --git a/nethttp/options.go b/nethttp/options.go index c18e1bc..600e488 100644 --- a/nethttp/options.go +++ b/nethttp/options.go @@ -10,23 +10,6 @@ import ( "github.com/swaggest/rest" ) -// OptionsMiddleware applies options to encountered nethttp.Handler. -func OptionsMiddleware(options ...func(h *Handler)) func(h http.Handler) http.Handler { - return func(h http.Handler) http.Handler { - var rh *Handler - - if HandlerAs(h, &rh) { - rh.options = append(rh.options, options...) - - for _, option := range options { - option(rh) - } - } - - return h - } -} - // AnnotateOpenAPIOperation allows customization of OpenAPI operation, that is reflected from the Handler. func AnnotateOpenAPIOperation(annotations ...func(oc openapi.OperationContext) error) func(h *Handler) { return func(h *Handler) { @@ -53,6 +36,23 @@ func AnnotateOperation(annotations ...func(operation *openapi3.Operation) error) } } +// OptionsMiddleware applies options to encountered nethttp.Handler. +func OptionsMiddleware(options ...func(h *Handler)) func(h http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + var rh *Handler + + if HandlerAs(h, &rh) { + rh.options = append(rh.options, options...) + + for _, option := range options { + option(rh) + } + } + + return h + } +} + // RequestBodyContent enables string request body with content type (e.g. text/plain). func RequestBodyContent(contentType string) func(h *Handler) { return func(h *Handler) { @@ -66,20 +66,6 @@ func RequestBodyContent(contentType string) func(h *Handler) { } } -// SuccessfulResponseContentType sets Content-Type of successful response. -func SuccessfulResponseContentType(contentType string) func(h *Handler) { - return func(h *Handler) { - h.SuccessContentType = contentType - } -} - -// SuccessStatus sets status code of successful response. -func SuccessStatus(status int) func(h *Handler) { - return func(h *Handler) { - h.SuccessStatus = status - } -} - // RequestMapping creates rest.RequestMapping from struct tags. // // This can be used to decouple mapping from usecase input with additional struct. @@ -133,3 +119,17 @@ func ResponseHeaderMapping(v interface{}) func(h *Handler) { } } } + +// SuccessfulResponseContentType sets Content-Type of successful response. +func SuccessfulResponseContentType(contentType string) func(h *Handler) { + return func(h *Handler) { + h.SuccessContentType = contentType + } +} + +// SuccessStatus sets status code of successful response. +func SuccessStatus(status int) func(h *Handler) { + return func(h *Handler) { + h.SuccessStatus = status + } +} diff --git a/nethttp/wrap.go b/nethttp/wrap.go index c9205cb..a872a1a 100644 --- a/nethttp/wrap.go +++ b/nethttp/wrap.go @@ -6,31 +6,6 @@ import ( "runtime" ) -// WrapHandler wraps http.Handler with an unwrappable middleware. -// -// Wrapping order is reversed, e.g. if you call WrapHandler(h, mw1, mw2, mw3) middlewares will be -// invoked in order of mw1(mw2(mw3(h))), mw3 first and mw1 last. So that request processing is first -// affected by mw1. -func WrapHandler(h http.Handler, mw ...func(http.Handler) http.Handler) http.Handler { - for i := len(mw) - 1; i >= 0; i-- { - w := mw[i](h) - if w == nil { - panic("nil handler returned from middleware: " + runtime.FuncForPC(reflect.ValueOf(mw[i]).Pointer()).Name()) - } - - fp := reflect.ValueOf(mw[i]).Pointer() - mwName := runtime.FuncForPC(fp).Name() - - h = &wrappedHandler{ - Handler: w, - wrapped: h, - mwName: mwName, - } - } - - return h -} - // HandlerAs finds the first http.Handler in http.Handler's chain that matches target, and if so, sets // target to that http.Handler value and returns true. // @@ -84,6 +59,31 @@ func HandlerAs(handler http.Handler, target interface{}) bool { return false } +// WrapHandler wraps http.Handler with an unwrappable middleware. +// +// Wrapping order is reversed, e.g. if you call WrapHandler(h, mw1, mw2, mw3) middlewares will be +// invoked in order of mw1(mw2(mw3(h))), mw3 first and mw1 last. So that request processing is first +// affected by mw1. +func WrapHandler(h http.Handler, mw ...func(http.Handler) http.Handler) http.Handler { + for i := len(mw) - 1; i >= 0; i-- { + w := mw[i](h) + if w == nil { + panic("nil handler returned from middleware: " + runtime.FuncForPC(reflect.ValueOf(mw[i]).Pointer()).Name()) + } + + fp := reflect.ValueOf(mw[i]).Pointer() + mwName := runtime.FuncForPC(fp).Name() + + h = &wrappedHandler{ + Handler: w, + wrapped: h, + mwName: mwName, + } + } + + return h +} + var handlerType = reflect.TypeOf((*http.Handler)(nil)).Elem() type wrappedHandler struct { diff --git a/nethttp/wrap_test.go b/nethttp/wrap_test.go index 41b9081..0904b37 100644 --- a/nethttp/wrap_test.go +++ b/nethttp/wrap_test.go @@ -8,6 +8,12 @@ import ( "github.com/swaggest/rest/nethttp" ) +func TestHandlerAs_nil(t *testing.T) { + var uh *nethttp.Handler + + assert.False(t, nethttp.HandlerAs(nil, &uh)) +} + func TestWrapHandler(t *testing.T) { var flow []string @@ -59,9 +65,3 @@ func TestWrapHandler(t *testing.T) { "mw3 after", "mw2 after", "mw1 after", }, flow) } - -func TestHandlerAs_nil(t *testing.T) { - var uh *nethttp.Handler - - assert.False(t, nethttp.HandlerAs(nil, &uh)) -} diff --git a/nethttp/wrapper.go b/nethttp/wrapper.go index f1da04f..1db105e 100644 --- a/nethttp/wrapper.go +++ b/nethttp/wrapper.go @@ -2,12 +2,6 @@ package nethttp import "net/http" -type wrapperChecker struct { - found bool -} - -func (*wrapperChecker) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {} - // IsWrapperChecker is a hack to mark middleware as a handler wrapper. // See chirouter.Wrapper Wrap() documentation for more details on the difference. // @@ -30,3 +24,9 @@ func MiddlewareIsWrapper(mw func(h http.Handler) http.Handler) bool { return wm.found } + +type wrapperChecker struct { + found bool +} + +func (*wrapperChecker) ServeHTTP(_ http.ResponseWriter, _ *http.Request) {} diff --git a/openapi/collector.go b/openapi/collector.go index 3c16579..9a61b9a 100644 --- a/openapi/collector.go +++ b/openapi/collector.go @@ -17,6 +17,19 @@ import ( "github.com/swaggest/usecase" ) +// NewCollector creates an instance of OpenAPI Collector. +func NewCollector(r openapi.Reflector) *Collector { + c := &Collector{ + ref: r, + } + + if r3, ok := r.(*openapi3.Reflector); ok { + c.gen = r3 + } + + return c +} + // Collector extracts OpenAPI documentation from HTTP handler and underlying use case interactor. type Collector struct { mu sync.Mutex @@ -44,46 +57,6 @@ type Collector struct { operationIDs map[string]bool } -// NewCollector creates an instance of OpenAPI Collector. -func NewCollector(r openapi.Reflector) *Collector { - c := &Collector{ - ref: r, - } - - if r3, ok := r.(*openapi3.Reflector); ok { - c.gen = r3 - } - - return c -} - -// SpecSchema returns OpenAPI specification schema. -func (c *Collector) SpecSchema() openapi.SpecSchema { - return c.Refl().SpecSchema() -} - -// Refl returns OpenAPI reflector. -func (c *Collector) Refl() openapi.Reflector { - if c.ref != nil { - return c.ref - } - - return c.Reflector() -} - -// Reflector is an accessor to OpenAPI Reflector instance. -func (c *Collector) Reflector() *openapi3.Reflector { - if c.ref != nil && c.gen == nil { - panic(fmt.Sprintf("conflicting OpenAPI reflector supplied: %T", c.ref)) - } - - if c.gen == nil { - c.gen = openapi3.NewReflector() - } - - return c.gen -} - // Annotate adds OpenAPI operation configuration that is applied during collection. // // Deprecated: use AnnotateOperation. @@ -111,13 +84,61 @@ func (c *Collector) AnnotateOperation(method, pattern string, setup ...func(oc o c.ocAnnotations[method+pattern] = append(c.ocAnnotations[method+pattern], setup...) } -// HasAnnotation indicates if there is at least one annotation registered for this operation. -func (c *Collector) HasAnnotation(method, pattern string) bool { - if len(c.ocAnnotations[method+pattern]) > 0 { - return true +// Collect adds use case handler to documentation. +// +// Deprecated: use CollectUseCase. +func (c *Collector) Collect( + method, pattern string, + u usecase.Interactor, + h rest.HandlerTrait, + annotations ...func(*openapi3.Operation) error, +) (err error) { + c.mu.Lock() + defer c.mu.Unlock() + + defer func() { + if err != nil { + err = fmt.Errorf("reflect API schema for %s %s: %w", method, pattern, err) + } + }() + + reflector := c.Refl() + + oc, err := reflector.NewOperationContext(method, pattern) + if err != nil { + return err } - return len(c.ocAnnotations[pattern]) > 0 + c.setupInput(oc, u, h) + c.setupOutput(oc, u, h) + c.processUseCase(oc, u, h) + + for _, setup := range c.ocAnnotations[method+pattern] { + err = setup(oc) + if err != nil { + return err + } + } + + if o3, ok := oc.(openapi3.OperationExposer); ok { + op := o3.Operation() + + for _, setup := range c.annotations[method+pattern] { + err = setup(op) + if err != nil { + return err + } + } + + for _, setup := range annotations { + err = setup(op) + if err != nil { + return err + } + } + } + + return reflector.AddOperation(oc) } // CollectOperation prepares and adds OpenAPI operation. @@ -219,210 +240,158 @@ func (c *Collector) CollectUseCase( return reflector.AddOperation(oc) } -// Collect adds use case handler to documentation. -// -// Deprecated: use CollectUseCase. -func (c *Collector) Collect( - method, pattern string, - u usecase.Interactor, - h rest.HandlerTrait, - annotations ...func(*openapi3.Operation) error, -) (err error) { - c.mu.Lock() - defer c.mu.Unlock() - - defer func() { - if err != nil { - err = fmt.Errorf("reflect API schema for %s %s: %w", method, pattern, err) - } - }() - - reflector := c.Refl() - - oc, err := reflector.NewOperationContext(method, pattern) - if err != nil { - return err +// HasAnnotation indicates if there is at least one annotation registered for this operation. +func (c *Collector) HasAnnotation(method, pattern string) bool { + if len(c.ocAnnotations[method+pattern]) > 0 { + return true } - c.setupInput(oc, u, h) - c.setupOutput(oc, u, h) - c.processUseCase(oc, u, h) + return len(c.ocAnnotations[pattern]) > 0 +} - for _, setup := range c.ocAnnotations[method+pattern] { - err = setup(oc) - if err != nil { - return err - } - } +// ProvideRequestJSONSchemas provides JSON Schemas for request structure. +func (c *Collector) ProvideRequestJSONSchemas( + method string, + input interface{}, + mapping rest.RequestMapping, + validator rest.JSONSchemaValidator, +) error { + cu := openapi.ContentUnit{} + cu.Structure = input + setFieldMapping(&cu, mapping) - if o3, ok := oc.(openapi3.OperationExposer); ok { - op := o3.Operation() + r := c.Refl() - for _, setup := range c.annotations[method+pattern] { - err = setup(op) - if err != nil { - return err - } + err := r.WalkRequestJSONSchemas(method, cu, c.jsonSchemaCallback(validator, r), func(oc openapi.OperationContext) { + fv, ok := validator.(unknownFieldsValidator) + if !ok { + return } - for _, setup := range annotations { - err = setup(op) - if err != nil { - return err + for _, in := range []openapi.In{openapi.InQuery, openapi.InCookie, openapi.InHeader} { + if oc.UnknownParamsAreForbidden(in) { + fv.ForbidUnknownParams(rest.ParamIn(in), true) } } - } + }) - return reflector.AddOperation(oc) + return err } -func (c *Collector) setupOutput(oc openapi.OperationContext, u usecase.Interactor, h rest.HandlerTrait) { - var ( - hasOutput usecase.HasOutputPort - status = http.StatusOK - noContent bool - output interface{} - contentType = h.SuccessContentType - ) - - if usecase.As(u, &hasOutput) { - output = hasOutput.OutputPort() +// ProvideResponseJSONSchemas provides JSON schemas for response structure. +func (c *Collector) ProvideResponseJSONSchemas( + statusCode int, + contentType string, + output interface{}, + headerMapping map[string]string, + validator rest.JSONSchemaValidator, +) error { + cu := openapi.ContentUnit{} + cu.Structure = output + cu.SetFieldMapping(openapi.InHeader, headerMapping) + cu.ContentType = contentType + cu.HTTPStatus = statusCode - if rest.OutputHasNoContent(output) { - status = http.StatusNoContent - noContent = true - } - } else { - status = http.StatusNoContent - noContent = true + if cu.ContentType == "" { + cu.ContentType = c.DefaultSuccessResponseContentType } - if !noContent && contentType == "" { - contentType = c.DefaultSuccessResponseContentType - } + r := c.Refl() + err := r.WalkResponseJSONSchemas(cu, c.jsonSchemaCallback(validator, r), nil) - if oc.Method() == http.MethodHead { - output = nil - } + return err +} - setupCU := func(cu *openapi.ContentUnit) { - cu.ContentType = contentType - cu.SetFieldMapping(openapi.InHeader, h.RespHeaderMapping) +// Refl returns OpenAPI reflector. +func (c *Collector) Refl() openapi.Reflector { + if c.ref != nil { + return c.ref } - if outputWithStatus, ok := output.(rest.OutputWithHTTPStatus); ok { - for _, status := range outputWithStatus.ExpectedHTTPStatuses() { - oc.AddRespStructure(output, func(cu *openapi.ContentUnit) { - cu.HTTPStatus = status - setupCU(cu) - }) - } - } else { - if h.SuccessStatus != 0 { - status = h.SuccessStatus - } + return c.Reflector() +} - oc.AddRespStructure(output, func(cu *openapi.ContentUnit) { - cu.HTTPStatus = status - setupCU(cu) - }) +// Reflector is an accessor to OpenAPI Reflector instance. +func (c *Collector) Reflector() *openapi3.Reflector { + if c.ref != nil && c.gen == nil { + panic(fmt.Sprintf("conflicting OpenAPI reflector supplied: %T", c.ref)) } -} -func (c *Collector) setupInput(oc openapi.OperationContext, u usecase.Interactor, h rest.HandlerTrait) { - var hasInput usecase.HasInputPort - - if usecase.As(u, &hasInput) { - oc.AddReqStructure(hasInput.InputPort(), func(cu *openapi.ContentUnit) { - setFieldMapping(cu, h.ReqMapping) - }) + if c.gen == nil { + c.gen = openapi3.NewReflector() } -} -func setFieldMapping(cu *openapi.ContentUnit, mapping rest.RequestMapping) { - if mapping != nil { - cu.SetFieldMapping(openapi.InQuery, mapping[rest.ParamInQuery]) - cu.SetFieldMapping(openapi.InPath, mapping[rest.ParamInPath]) - cu.SetFieldMapping(openapi.InHeader, mapping[rest.ParamInHeader]) - cu.SetFieldMapping(openapi.InCookie, mapping[rest.ParamInCookie]) - cu.SetFieldMapping(openapi.InFormData, mapping[rest.ParamInFormData]) - } + return c.gen } -func (c *Collector) processUseCase(oc openapi.OperationContext, u usecase.Interactor, h rest.HandlerTrait) { - var ( - hasName usecase.HasName - hasTitle usecase.HasTitle - hasDescription usecase.HasDescription - hasTags usecase.HasTags - hasDeprecated usecase.HasIsDeprecated - ) +func (c *Collector) ServeHTTP(rw http.ResponseWriter, _ *http.Request) { + c.mu.Lock() + defer c.mu.Unlock() - if usecase.As(u, &hasName) { - id := hasName.Name() + document, err := json.MarshalIndent(c.SpecSchema(), "", " ") + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } - if id != "" { - if c.operationIDs == nil { - c.operationIDs = make(map[string]bool) - } + rw.Header().Set("Content-Type", "application/json") - idSuf := id - suf := 1 + _, err = rw.Write(document) + if err != nil { + http.Error(rw, err.Error(), http.StatusInternalServerError) + } +} - for c.operationIDs[idSuf] { - suf++ - idSuf = id + strconv.Itoa(suf) - } +// SpecSchema returns OpenAPI specification schema. +func (c *Collector) SpecSchema() openapi.SpecSchema { + return c.Refl().SpecSchema() +} - c.operationIDs[idSuf] = true +func (c *Collector) combineOCErrors(oc openapi.OperationContext, statusCodes []int, errsByCode map[int][]interface{}) { + for _, statusCode := range statusCodes { + errResps := errsByCode[statusCode] - oc.SetID(idSuf) + if len(errResps) == 1 || c.CombineErrors == "" { + c.setOCJSONResponse(oc, errResps[0], statusCode) + } else { + switch c.CombineErrors { + case "oneOf": + c.setOCJSONResponse(oc, jsonschema.OneOf(errResps...), statusCode) + case "anyOf": + c.setOCJSONResponse(oc, jsonschema.AnyOf(errResps...), statusCode) + default: + panic("oneOf/anyOf expected for openapi.Collector.CombineErrors, " + + c.CombineErrors + " received") + } } } +} - if usecase.As(u, &hasTitle) { - title := hasTitle.Title() - - if title != "" { - oc.SetSummary(hasTitle.Title()) +func (c *Collector) jsonSchemaCallback(validator rest.JSONSchemaValidator, r openapi.Reflector) openapi.JSONSchemaCallback { + return func(in openapi.In, paramName string, schema *jsonschema.SchemaOrBool, required bool) error { + loc := string(in) + "." + paramName + if loc == "body.body" { + loc = "body" } - } - if usecase.As(u, &hasTags) { - tags := hasTags.Tags() + if schema == nil || schema.IsTrivial(r.ResolveJSONSchemaRef) { + if err := validator.AddSchema(rest.ParamIn(in), paramName, nil, required); err != nil { + return fmt.Errorf("add validation schema %s: %w", loc, err) + } - if len(tags) > 0 { - oc.SetTags(hasTags.Tags()...) + return nil } - } - if usecase.As(u, &hasDescription) { - desc := hasDescription.Description() - - if desc != "" { - oc.SetDescription(hasDescription.Description()) + schemaData, err := schema.JSONSchemaBytes() + if err != nil { + return fmt.Errorf("marshal schema %s: %w", loc, err) } - } - - if usecase.As(u, &hasDeprecated) && hasDeprecated.IsDeprecated() { - oc.SetIsDeprecated(true) - } - - c.processOCExpectedErrors(oc, u, h) -} -func (c *Collector) setOCJSONResponse(oc openapi.OperationContext, output interface{}, statusCode int) { - oc.AddRespStructure(output, func(cu *openapi.ContentUnit) { - cu.HTTPStatus = statusCode - - if described, ok := output.(jsonschema.Described); ok { - cu.Description = described.Description() + if err = validator.AddSchema(rest.ParamIn(in), paramName, schemaData, required); err != nil { + return fmt.Errorf("add validation schema %s: %w", loc, err) } - if output != nil { - cu.ContentType = c.DefaultErrorResponseContentType - } - }) + return nil + } } func (c *Collector) processOCExpectedErrors(oc openapi.OperationContext, u usecase.Interactor, h rest.HandlerTrait) { @@ -477,124 +446,155 @@ func (c *Collector) processOCExpectedErrors(oc openapi.OperationContext, u useca c.combineOCErrors(oc, statusCodes, errsByCode) } -func (c *Collector) combineOCErrors(oc openapi.OperationContext, statusCodes []int, errsByCode map[int][]interface{}) { - for _, statusCode := range statusCodes { - errResps := errsByCode[statusCode] +func (c *Collector) processUseCase(oc openapi.OperationContext, u usecase.Interactor, h rest.HandlerTrait) { + var ( + hasName usecase.HasName + hasTitle usecase.HasTitle + hasDescription usecase.HasDescription + hasTags usecase.HasTags + hasDeprecated usecase.HasIsDeprecated + ) - if len(errResps) == 1 || c.CombineErrors == "" { - c.setOCJSONResponse(oc, errResps[0], statusCode) - } else { - switch c.CombineErrors { - case "oneOf": - c.setOCJSONResponse(oc, jsonschema.OneOf(errResps...), statusCode) - case "anyOf": - c.setOCJSONResponse(oc, jsonschema.AnyOf(errResps...), statusCode) - default: - panic("oneOf/anyOf expected for openapi.Collector.CombineErrors, " + - c.CombineErrors + " received") + if usecase.As(u, &hasName) { + id := hasName.Name() + + if id != "" { + if c.operationIDs == nil { + c.operationIDs = make(map[string]bool) } - } - } -} -type unknownFieldsValidator interface { - ForbidUnknownParams(in rest.ParamIn, forbidden bool) -} + idSuf := id + suf := 1 -// ProvideRequestJSONSchemas provides JSON Schemas for request structure. -func (c *Collector) ProvideRequestJSONSchemas( - method string, - input interface{}, - mapping rest.RequestMapping, - validator rest.JSONSchemaValidator, -) error { - cu := openapi.ContentUnit{} - cu.Structure = input - setFieldMapping(&cu, mapping) + for c.operationIDs[idSuf] { + suf++ + idSuf = id + strconv.Itoa(suf) + } - r := c.Refl() + c.operationIDs[idSuf] = true - err := r.WalkRequestJSONSchemas(method, cu, c.jsonSchemaCallback(validator, r), func(oc openapi.OperationContext) { - fv, ok := validator.(unknownFieldsValidator) - if !ok { - return + oc.SetID(idSuf) } + } - for _, in := range []openapi.In{openapi.InQuery, openapi.InCookie, openapi.InHeader} { - if oc.UnknownParamsAreForbidden(in) { - fv.ForbidUnknownParams(rest.ParamIn(in), true) - } + if usecase.As(u, &hasTitle) { + title := hasTitle.Title() + + if title != "" { + oc.SetSummary(hasTitle.Title()) } - }) + } - return err -} + if usecase.As(u, &hasTags) { + tags := hasTags.Tags() -// ProvideResponseJSONSchemas provides JSON schemas for response structure. -func (c *Collector) ProvideResponseJSONSchemas( - statusCode int, - contentType string, - output interface{}, - headerMapping map[string]string, - validator rest.JSONSchemaValidator, -) error { - cu := openapi.ContentUnit{} - cu.Structure = output - cu.SetFieldMapping(openapi.InHeader, headerMapping) - cu.ContentType = contentType - cu.HTTPStatus = statusCode + if len(tags) > 0 { + oc.SetTags(hasTags.Tags()...) + } + } - if cu.ContentType == "" { - cu.ContentType = c.DefaultSuccessResponseContentType + if usecase.As(u, &hasDescription) { + desc := hasDescription.Description() + + if desc != "" { + oc.SetDescription(hasDescription.Description()) + } } - r := c.Refl() - err := r.WalkResponseJSONSchemas(cu, c.jsonSchemaCallback(validator, r), nil) + if usecase.As(u, &hasDeprecated) && hasDeprecated.IsDeprecated() { + oc.SetIsDeprecated(true) + } - return err + c.processOCExpectedErrors(oc, u, h) } -func (c *Collector) jsonSchemaCallback(validator rest.JSONSchemaValidator, r openapi.Reflector) openapi.JSONSchemaCallback { - return func(in openapi.In, paramName string, schema *jsonschema.SchemaOrBool, required bool) error { - loc := string(in) + "." + paramName - if loc == "body.body" { - loc = "body" - } - - if schema == nil || schema.IsTrivial(r.ResolveJSONSchemaRef) { - if err := validator.AddSchema(rest.ParamIn(in), paramName, nil, required); err != nil { - return fmt.Errorf("add validation schema %s: %w", loc, err) - } +func (c *Collector) setOCJSONResponse(oc openapi.OperationContext, output interface{}, statusCode int) { + oc.AddRespStructure(output, func(cu *openapi.ContentUnit) { + cu.HTTPStatus = statusCode - return nil + if described, ok := output.(jsonschema.Described); ok { + cu.Description = described.Description() } - schemaData, err := schema.JSONSchemaBytes() - if err != nil { - return fmt.Errorf("marshal schema %s: %w", loc, err) + if output != nil { + cu.ContentType = c.DefaultErrorResponseContentType } + }) +} - if err = validator.AddSchema(rest.ParamIn(in), paramName, schemaData, required); err != nil { - return fmt.Errorf("add validation schema %s: %w", loc, err) - } +func (c *Collector) setupInput(oc openapi.OperationContext, u usecase.Interactor, h rest.HandlerTrait) { + var hasInput usecase.HasInputPort - return nil + if usecase.As(u, &hasInput) { + oc.AddReqStructure(hasInput.InputPort(), func(cu *openapi.ContentUnit) { + setFieldMapping(cu, h.ReqMapping) + }) } } -func (c *Collector) ServeHTTP(rw http.ResponseWriter, _ *http.Request) { - c.mu.Lock() - defer c.mu.Unlock() +func (c *Collector) setupOutput(oc openapi.OperationContext, u usecase.Interactor, h rest.HandlerTrait) { + var ( + hasOutput usecase.HasOutputPort + status = http.StatusOK + noContent bool + output interface{} + contentType = h.SuccessContentType + ) - document, err := json.MarshalIndent(c.SpecSchema(), "", " ") - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + if usecase.As(u, &hasOutput) { + output = hasOutput.OutputPort() + + if rest.OutputHasNoContent(output) { + status = http.StatusNoContent + noContent = true + } + } else { + status = http.StatusNoContent + noContent = true } - rw.Header().Set("Content-Type", "application/json") + if !noContent && contentType == "" { + contentType = c.DefaultSuccessResponseContentType + } - _, err = rw.Write(document) - if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + if oc.Method() == http.MethodHead { + output = nil } + + setupCU := func(cu *openapi.ContentUnit) { + cu.ContentType = contentType + cu.SetFieldMapping(openapi.InHeader, h.RespHeaderMapping) + } + + if outputWithStatus, ok := output.(rest.OutputWithHTTPStatus); ok { + for _, status := range outputWithStatus.ExpectedHTTPStatuses() { + oc.AddRespStructure(output, func(cu *openapi.ContentUnit) { + cu.HTTPStatus = status + setupCU(cu) + }) + } + } else { + if h.SuccessStatus != 0 { + status = h.SuccessStatus + } + + oc.AddRespStructure(output, func(cu *openapi.ContentUnit) { + cu.HTTPStatus = status + setupCU(cu) + }) + } +} + +func setFieldMapping(cu *openapi.ContentUnit, mapping rest.RequestMapping) { + if mapping != nil { + cu.SetFieldMapping(openapi.InQuery, mapping[rest.ParamInQuery]) + cu.SetFieldMapping(openapi.InPath, mapping[rest.ParamInPath]) + cu.SetFieldMapping(openapi.InHeader, mapping[rest.ParamInHeader]) + cu.SetFieldMapping(openapi.InCookie, mapping[rest.ParamInCookie]) + cu.SetFieldMapping(openapi.InFormData, mapping[rest.ParamInFormData]) + } +} + +type unknownFieldsValidator interface { + ForbidUnknownParams(in rest.ParamIn, forbidden bool) } diff --git a/openapi/collector_test.go b/openapi/collector_test.go index cbdc566..c82fdf6 100644 --- a/openapi/collector_test.go +++ b/openapi/collector_test.go @@ -22,31 +22,6 @@ import ( "github.com/swaggest/usecase/status" ) -var _ rest.JSONSchemaValidator = validatorMock{} - -type validatorMock struct { - ValidateDataFunc func(in rest.ParamIn, namedData map[string]interface{}) error - ValidateJSONBodyFunc func(jsonBody []byte) error - HasConstraintsFunc func(in rest.ParamIn) bool - AddSchemaFunc func(in rest.ParamIn, name string, schemaData []byte, required bool) error -} - -func (v validatorMock) ValidateData(in rest.ParamIn, namedData map[string]interface{}) error { - return v.ValidateDataFunc(in, namedData) -} - -func (v validatorMock) ValidateJSONBody(jsonBody []byte) error { - return v.ValidateJSONBodyFunc(jsonBody) -} - -func (v validatorMock) HasConstraints(in rest.ParamIn) bool { - return v.HasConstraintsFunc(in) -} - -func (v validatorMock) AddSchema(in rest.ParamIn, name string, schemaData []byte, required bool) error { - return v.AddSchemaFunc(in, name, schemaData, required) -} - func TestCollector_Collect(t *testing.T) { c := openapi.Collector{ BasePath: "http://example.com/", @@ -113,131 +88,6 @@ func TestCollector_Collect(t *testing.T) { assert.NoError(t, c.ProvideResponseJSONSchemas(http.StatusOK, "application/json", new(output), nil, val)) } -func TestCollector_Collect_requestMapping(t *testing.T) { - type input struct { - InHeader string `minLength:"2"` - InQuery jschema.Date - InCookie *time.Time - InFormData time.Time - InPath bool - InFile multipart.File - } - - u := usecase.IOInteractor{} - - u.SetTitle("Title") - u.SetName("name") - u.SetIsDeprecated(true) - u.Input = new(input) - - mapping := rest.RequestMapping{ - rest.ParamInFormData: map[string]string{"InFormData": "in_form_data", "InFile": "upload"}, - rest.ParamInCookie: map[string]string{"InCookie": "in_cookie"}, - rest.ParamInQuery: map[string]string{"InQuery": "in_query"}, - rest.ParamInHeader: map[string]string{"InHeader": "X-In-Header"}, - rest.ParamInPath: map[string]string{"InPath": "in-path"}, - } - - h := rest.HandlerTrait{ - ReqMapping: mapping, - } - - collector := openapi.Collector{} - - require.NoError(t, collector.CollectUseCase(http.MethodPost, "/test/{in-path}", u, h)) - require.NoError(t, collector.CollectUseCase(http.MethodPut, "/test/{in-path}", u, h)) - - assertjson.EqMarshal(t, `{ - "openapi":"3.0.3","info":{"title":"","version":""}, - "paths":{ - "/test/{in-path}":{ - "post":{ - "summary":"Title","operationId":"name", - "parameters":[ - { - "name":"in_query","in":"query", - "schema":{"type":"string","format":"date"} - }, - { - "name":"in-path","in":"path","required":true, - "schema":{"type":"boolean"} - }, - { - "name":"in_cookie","in":"cookie", - "schema":{"type":"string","format":"date-time","nullable":true} - }, - { - "name":"X-In-Header","in":"header", - "schema":{"minLength":2,"type":"string"} - } - ], - "requestBody":{ - "content":{ - "multipart/form-data":{"schema":{"$ref":"#/components/schemas/OpenapiTestInput"}} - } - }, - "responses":{"204":{"description":"No Content"}},"deprecated":true - }, - "put":{ - "summary":"Title","operationId":"name2", - "parameters":[ - { - "name":"in_query","in":"query", - "schema":{"type":"string","format":"date"} - }, - { - "name":"in-path","in":"path","required":true, - "schema":{"type":"boolean"} - }, - { - "name":"in_cookie","in":"cookie", - "schema":{"type":"string","format":"date-time","nullable":true} - }, - { - "name":"X-In-Header","in":"header", - "schema":{"minLength":2,"type":"string"} - } - ], - "requestBody":{ - "content":{ - "multipart/form-data":{"schema":{"$ref":"#/components/schemas/OpenapiTestInput"}} - } - }, - "responses":{"204":{"description":"No Content"}},"deprecated":true - } - } - }, - "components":{ - "schemas":{ - "MultipartFile":{"type":"string","format":"binary"}, - "OpenapiTestInput":{ - "type":"object", - "properties":{ - "in_form_data":{"type":"string","format":"date-time"}, - "upload":{"$ref":"#/components/schemas/MultipartFile"} - } - } - } - } - }`, collector.SpecSchema()) - - val := validatorMock{ - AddSchemaFunc: func(_ rest.ParamIn, _ string, _ []byte, _ bool) error { - return nil - }, - } - assert.NoError(t, collector.ProvideRequestJSONSchemas(http.MethodPost, new(input), mapping, val)) -} - -// anotherErr is another custom error. -type anotherErr struct { - Foo int `json:"foo"` -} - -func (anotherErr) Error() string { - return "foo happened" -} - func TestCollector_Collect_CombineErrors(t *testing.T) { u := usecase.IOInteractor{} @@ -320,17 +170,56 @@ func TestCollector_Collect_CombineErrors(t *testing.T) { }`, collector.SpecSchema()) } -// Output that implements OutputWithHTTPStatus interface. -type outputWithHTTPStatuses struct { - Number int `json:"number"` -} +func TestCollector_Collect_head_no_response(t *testing.T) { + c := openapi.Collector{} + u := usecase.IOInteractor{} -func (outputWithHTTPStatuses) HTTPStatus() int { - return http.StatusCreated -} + type resp struct { + Foo string `json:"foo"` + Bar string `header:"X-Bar"` + } -func (outputWithHTTPStatuses) ExpectedHTTPStatuses() []int { - return []int{http.StatusCreated, http.StatusOK} + u.Output = new(resp) + + require.NoError(t, c.CollectUseCase(http.MethodHead, "/foo", u, rest.HandlerTrait{ + ReqValidator: &jsonschema.Validator{}, + })) + + require.NoError(t, c.CollectUseCase(http.MethodGet, "/foo", u, rest.HandlerTrait{ + ReqValidator: &jsonschema.Validator{}, + })) + + assertjson.EqMarshal(t, `{ + "openapi":"3.0.3","info":{"title":"","version":""}, + "paths":{ + "/foo":{ + "get":{ + "responses":{ + "200":{ + "description":"OK", + "headers":{"X-Bar":{"style":"simple","schema":{"type":"string"}}}, + "content":{ + "application/json":{"schema":{"$ref":"#/components/schemas/OpenapiTestResp"}} + } + } + } + }, + "head":{ + "responses":{ + "200":{ + "description":"OK", + "headers":{"X-Bar":{"style":"simple","schema":{"type":"string"}}} + } + } + } + } + }, + "components":{ + "schemas":{ + "OpenapiTestResp":{"type":"object","properties":{"foo":{"type":"string"}}} + } + } + }`, c.SpecSchema()) } func TestCollector_Collect_multipleHttpStatuses(t *testing.T) { @@ -462,54 +351,165 @@ func TestCollector_Collect_queryObject(t *testing.T) { }`, c.SpecSchema()) } -func TestCollector_Collect_head_no_response(t *testing.T) { - c := openapi.Collector{} +func TestCollector_Collect_requestMapping(t *testing.T) { + type input struct { + InHeader string `minLength:"2"` + InQuery jschema.Date + InCookie *time.Time + InFormData time.Time + InPath bool + InFile multipart.File + } + u := usecase.IOInteractor{} - type resp struct { - Foo string `json:"foo"` - Bar string `header:"X-Bar"` + u.SetTitle("Title") + u.SetName("name") + u.SetIsDeprecated(true) + u.Input = new(input) + + mapping := rest.RequestMapping{ + rest.ParamInFormData: map[string]string{"InFormData": "in_form_data", "InFile": "upload"}, + rest.ParamInCookie: map[string]string{"InCookie": "in_cookie"}, + rest.ParamInQuery: map[string]string{"InQuery": "in_query"}, + rest.ParamInHeader: map[string]string{"InHeader": "X-In-Header"}, + rest.ParamInPath: map[string]string{"InPath": "in-path"}, } - u.Output = new(resp) + h := rest.HandlerTrait{ + ReqMapping: mapping, + } - require.NoError(t, c.CollectUseCase(http.MethodHead, "/foo", u, rest.HandlerTrait{ - ReqValidator: &jsonschema.Validator{}, - })) + collector := openapi.Collector{} - require.NoError(t, c.CollectUseCase(http.MethodGet, "/foo", u, rest.HandlerTrait{ - ReqValidator: &jsonschema.Validator{}, - })) + require.NoError(t, collector.CollectUseCase(http.MethodPost, "/test/{in-path}", u, h)) + require.NoError(t, collector.CollectUseCase(http.MethodPut, "/test/{in-path}", u, h)) assertjson.EqMarshal(t, `{ "openapi":"3.0.3","info":{"title":"","version":""}, "paths":{ - "/foo":{ - "get":{ - "responses":{ - "200":{ - "description":"OK", - "headers":{"X-Bar":{"style":"simple","schema":{"type":"string"}}}, - "content":{ - "application/json":{"schema":{"$ref":"#/components/schemas/OpenapiTestResp"}} - } + "/test/{in-path}":{ + "post":{ + "summary":"Title","operationId":"name", + "parameters":[ + { + "name":"in_query","in":"query", + "schema":{"type":"string","format":"date"} + }, + { + "name":"in-path","in":"path","required":true, + "schema":{"type":"boolean"} + }, + { + "name":"in_cookie","in":"cookie", + "schema":{"type":"string","format":"date-time","nullable":true} + }, + { + "name":"X-In-Header","in":"header", + "schema":{"minLength":2,"type":"string"} } - } + ], + "requestBody":{ + "content":{ + "multipart/form-data":{"schema":{"$ref":"#/components/schemas/OpenapiTestInput"}} + } + }, + "responses":{"204":{"description":"No Content"}},"deprecated":true }, - "head":{ - "responses":{ - "200":{ - "description":"OK", - "headers":{"X-Bar":{"style":"simple","schema":{"type":"string"}}} + "put":{ + "summary":"Title","operationId":"name2", + "parameters":[ + { + "name":"in_query","in":"query", + "schema":{"type":"string","format":"date"} + }, + { + "name":"in-path","in":"path","required":true, + "schema":{"type":"boolean"} + }, + { + "name":"in_cookie","in":"cookie", + "schema":{"type":"string","format":"date-time","nullable":true} + }, + { + "name":"X-In-Header","in":"header", + "schema":{"minLength":2,"type":"string"} } - } + ], + "requestBody":{ + "content":{ + "multipart/form-data":{"schema":{"$ref":"#/components/schemas/OpenapiTestInput"}} + } + }, + "responses":{"204":{"description":"No Content"}},"deprecated":true } } }, "components":{ "schemas":{ - "OpenapiTestResp":{"type":"object","properties":{"foo":{"type":"string"}}} + "MultipartFile":{"type":"string","format":"binary"}, + "OpenapiTestInput":{ + "type":"object", + "properties":{ + "in_form_data":{"type":"string","format":"date-time"}, + "upload":{"$ref":"#/components/schemas/MultipartFile"} + } + } } } - }`, c.SpecSchema()) + }`, collector.SpecSchema()) + + val := validatorMock{ + AddSchemaFunc: func(_ rest.ParamIn, _ string, _ []byte, _ bool) error { + return nil + }, + } + assert.NoError(t, collector.ProvideRequestJSONSchemas(http.MethodPost, new(input), mapping, val)) +} + +var _ rest.JSONSchemaValidator = validatorMock{} + +// anotherErr is another custom error. +type anotherErr struct { + Foo int `json:"foo"` +} + +func (anotherErr) Error() string { + return "foo happened" +} + +// Output that implements OutputWithHTTPStatus interface. +type outputWithHTTPStatuses struct { + Number int `json:"number"` +} + +func (outputWithHTTPStatuses) ExpectedHTTPStatuses() []int { + return []int{http.StatusCreated, http.StatusOK} +} + +func (outputWithHTTPStatuses) HTTPStatus() int { + return http.StatusCreated +} + +type validatorMock struct { + ValidateDataFunc func(in rest.ParamIn, namedData map[string]interface{}) error + ValidateJSONBodyFunc func(jsonBody []byte) error + HasConstraintsFunc func(in rest.ParamIn) bool + AddSchemaFunc func(in rest.ParamIn, name string, schemaData []byte, required bool) error +} + +func (v validatorMock) AddSchema(in rest.ParamIn, name string, schemaData []byte, required bool) error { + return v.AddSchemaFunc(in, name, schemaData, required) +} + +func (v validatorMock) HasConstraints(in rest.ParamIn) bool { + return v.HasConstraintsFunc(in) +} + +func (v validatorMock) ValidateData(in rest.ParamIn, namedData map[string]interface{}) error { + return v.ValidateDataFunc(in, namedData) +} + +func (v validatorMock) ValidateJSONBody(jsonBody []byte) error { + return v.ValidateJSONBodyFunc(jsonBody) } diff --git a/request.go b/request.go index da6de43..7475e70 100644 --- a/request.go +++ b/request.go @@ -1,8 +1,5 @@ package rest -// ParamIn defines parameter location. -type ParamIn string - const ( // ParamInPath indicates path parameters, such as `/users/{id}`. ParamInPath = ParamIn("path") @@ -24,15 +21,8 @@ const ( ParamInHeader = ParamIn("header") ) -// RequestMapping describes how decoded request should be applied to container struct. -// -// It is defined as a map by parameter location. -// Each item is a map with struct field name as key and decoded field name as value. -// -// Example: -// -// map[rest.ParamIn]map[string]string{rest.ParamInQuery:map[string]string{"ID": "id", "FirstName": "first-name"}} -type RequestMapping map[ParamIn]map[string]string +// ParamIn defines parameter location. +type ParamIn string // RequestErrors is a list of validation or decoding errors. // @@ -54,3 +44,13 @@ func (re RequestErrors) Fields() map[string]interface{} { return res } + +// RequestMapping describes how decoded request should be applied to container struct. +// +// It is defined as a map by parameter location. +// Each item is a map with struct field name as key and decoded field name as value. +// +// Example: +// +// map[rest.ParamIn]map[string]string{rest.ParamInQuery:map[string]string{"ID": "id", "FirstName": "first-name"}} +type RequestMapping map[ParamIn]map[string]string diff --git a/request/decoder.go b/request/decoder.go index 54c6262..224fee4 100644 --- a/request/decoder.go +++ b/request/decoder.go @@ -11,6 +11,21 @@ import ( "github.com/swaggest/rest/nethttp" ) +// EmbeddedSetter can capture *http.Resuest in your input structure. +type EmbeddedSetter struct { + r *http.Request +} + +// Request is an accessor. +func (e *EmbeddedSetter) Request() *http.Request { + return e.r +} + +// SetRequest implements Setter. +func (e *EmbeddedSetter) SetRequest(r *http.Request) { + e.r = r +} + type ( // Loader loads data from http.Request. // @@ -30,19 +45,28 @@ type ( valueDecoderFunc func(r *http.Request, v interface{}, validator rest.Validator) error ) -// EmbeddedSetter can capture *http.Resuest in your input structure. -type EmbeddedSetter struct { - r *http.Request -} +const defaultMaxMemory = 32 << 20 // 32 MB -// SetRequest implements Setter. -func (e *EmbeddedSetter) SetRequest(r *http.Request) { - e.r = r -} +var _ nethttp.RequestDecoder = &decoder{} -// Request is an accessor. -func (e *EmbeddedSetter) Request() *http.Request { - return e.r +func makeDecoder(in rest.ParamIn, formDecoder *form.Decoder, decoderFunc decoderFunc) valueDecoderFunc { + return func(r *http.Request, v interface{}, validator rest.Validator) error { + ct := r.Header.Get("Content-Type") + if in == rest.ParamInFormData && ct != "" && !strings.HasPrefix(ct, "multipart/form-data") && ct != "application/x-www-form-urlencoded" { + return nil + } + + values, err := decoderFunc(r) + if err != nil { + return err + } + + if validator != nil { + return decodeValidate(formDecoder, v, values, in, validator) + } + + return formDecoder.Decode(v, values) + } } func decodeValidate(d *form.Decoder, v interface{}, p url.Values, in rest.ParamIn, val rest.Validator) error { @@ -76,24 +100,65 @@ func decodeValidate(d *form.Decoder, v interface{}, p url.Values, in rest.ParamI return val.ValidateData(in, goValues) } -func makeDecoder(in rest.ParamIn, formDecoder *form.Decoder, decoderFunc decoderFunc) valueDecoderFunc { - return func(r *http.Request, v interface{}, validator rest.Validator) error { - ct := r.Header.Get("Content-Type") - if in == rest.ParamInFormData && ct != "" && !strings.HasPrefix(ct, "multipart/form-data") && ct != "application/x-www-form-urlencoded" { - return nil - } +func contentTypeBodyToURLValues(r *http.Request) (url.Values, error) { + b, err := ioutil.ReadAll(r.Body) + if err != nil { + return nil, err + } - values, err := decoderFunc(r) + return url.Values{ + r.Header.Get("Content-Type"): []string{string(b)}, + }, nil +} + +func cookiesToURLValues(r *http.Request) (url.Values, error) { + cookies := r.Cookies() + params := make(url.Values, len(cookies)) + + for _, c := range cookies { + params[c.Name] = []string{c.Value} + } + + return params, nil +} + +// 32 MB +func formDataToURLValues(r *http.Request) (url.Values, error) { + if r.ContentLength == 0 { + return nil, nil + } + + if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { + err := r.ParseMultipartForm(defaultMaxMemory) if err != nil { - return err + return nil, err } + } else if err := r.ParseForm(); err != nil { + return nil, err + } - if validator != nil { - return decodeValidate(formDecoder, v, values, in, validator) - } + return r.PostForm, nil +} - return formDecoder.Decode(v, values) +func formToURLValues(r *http.Request) (url.Values, error) { + if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { + err := r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, err + } + } else if err := r.ParseForm(); err != nil { + return nil, err } + + return r.Form, nil +} + +func headerToURLValues(r *http.Request) (url.Values, error) { + return url.Values(r.Header), nil +} + +func queryToURLValues(r *http.Request) (url.Values, error) { + return r.URL.Query(), nil } // decoder extracts Go value from *http.Request. @@ -104,8 +169,6 @@ type decoder struct { isReqSetter bool } -var _ nethttp.RequestDecoder = &decoder{} - // Decode populates and validates input with data from http request. func (d *decoder) Decode(r *http.Request, input interface{}, validator rest.Validator) error { if d.isReqSetter { @@ -139,65 +202,3 @@ func (d *decoder) Decode(r *http.Request, input interface{}, validator rest.Vali return nil } - -const defaultMaxMemory = 32 << 20 // 32 MB - -func formDataToURLValues(r *http.Request) (url.Values, error) { - if r.ContentLength == 0 { - return nil, nil - } - - if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { - err := r.ParseMultipartForm(defaultMaxMemory) - if err != nil { - return nil, err - } - } else if err := r.ParseForm(); err != nil { - return nil, err - } - - return r.PostForm, nil -} - -func headerToURLValues(r *http.Request) (url.Values, error) { - return url.Values(r.Header), nil -} - -func queryToURLValues(r *http.Request) (url.Values, error) { - return r.URL.Query(), nil -} - -func formToURLValues(r *http.Request) (url.Values, error) { - if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { - err := r.ParseMultipartForm(defaultMaxMemory) - if err != nil { - return nil, err - } - } else if err := r.ParseForm(); err != nil { - return nil, err - } - - return r.Form, nil -} - -func cookiesToURLValues(r *http.Request) (url.Values, error) { - cookies := r.Cookies() - params := make(url.Values, len(cookies)) - - for _, c := range cookies { - params[c.Name] = []string{c.Value} - } - - return params, nil -} - -func contentTypeBodyToURLValues(r *http.Request) (url.Values, error) { - b, err := ioutil.ReadAll(r.Body) - if err != nil { - return nil, err - } - - return url.Values{ - r.Header.Get("Content-Type"): []string{string(b)}, - }, nil -} diff --git a/request/decoder_test.go b/request/decoder_test.go index 4c3cc75..4bb263e 100644 --- a/request/decoder_test.go +++ b/request/decoder_test.go @@ -48,185 +48,6 @@ func BenchmarkDecoder_Decode(b *testing.B) { } } -type reqTest struct { - Header int `header:"X-In-HeAdEr" required:"true"` // Headers are mapped using canonical names. - Cookie string `cookie:"in_cookie"` - Query string `query:"in_query"` - Path string `path:"in_path"` - FormData string `formData:"inFormData"` -} - -type reqTestCustomMapping struct { - reqEmbedding - Query string - Path string - FormData string -} - -type reqEmbedding struct { - Header int `required:"true"` - Cookie string -} - -type reqJSONTest struct { - Query string `query:"in_query"` - BodyOne string `json:"bodyOne" required:"true"` - BodyTwo []int `json:"bodyTwo" minItems:"2"` -} - -func TestDecoder_Decode(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=abc", - strings.NewReader(url.Values{"inFormData": []string{"def"}}.Encode())) - assert.NoError(t, err) - - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("X-In-hEaDeR", "123") - - c := http.Cookie{ - Name: "in_cookie", - Value: "jkl", - } - - req.AddCookie(&c) - - df := request.NewDecoderFactory() - df.SetDecoderFunc(rest.ParamInPath, func(r *http.Request) (url.Values, error) { - assert.Equal(t, req, r) - - return url.Values{"in_path": []string{"mno"}}, nil - }) - - input := new(reqTest) - dec := df.MakeDecoder(http.MethodPost, input, nil) - - assert.NoError(t, dec.Decode(req, input, nil)) - assert.Equal(t, "abc", input.Query) - assert.Equal(t, "def", input.FormData) - assert.Equal(t, 123, input.Header) - assert.Equal(t, "jkl", input.Cookie) - assert.Equal(t, "mno", input.Path) - - inputCM := new(reqTestCustomMapping) - decCM := df.MakeDecoder(http.MethodPost, input, map[rest.ParamIn]map[string]string{ - rest.ParamInHeader: { - "Header": "X-In-HeAdEr", // Headers are mapped using canonical names. - }, - rest.ParamInCookie: {"Cookie": "in_cookie"}, - rest.ParamInQuery: {"Query": "in_query"}, - rest.ParamInPath: {"Path": "in_path"}, - rest.ParamInFormData: {"FormData": "inFormData"}, - }) - - assert.NoError(t, decCM.Decode(req, inputCM, nil)) - assert.Equal(t, "abc", inputCM.Query) - assert.Equal(t, "def", inputCM.FormData) - assert.Equal(t, 123, inputCM.Header) - assert.Equal(t, "jkl", inputCM.Cookie) - assert.Equal(t, "mno", inputCM.Path) -} - -// BenchmarkDecoderFunc_Decode-4 440503 2525 ns/op 1513 B/op 12 allocs/op. -func BenchmarkDecoderFunc_Decode(b *testing.B) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=abc", - strings.NewReader(url.Values{"inFormData": []string{"def"}}.Encode())) - assert.NoError(b, err) - - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("X-In-Header", "123") - - c := http.Cookie{ - Name: "in_cookie", - Value: "jkl", - } - - req.AddCookie(&c) - - df := request.NewDecoderFactory() - df.SetDecoderFunc(rest.ParamInPath, func(_ *http.Request) (url.Values, error) { - return url.Values{"in_path": []string{"mno"}}, nil - }) - - dec := df.MakeDecoder(http.MethodPost, new(reqTest), nil) - - b.ResetTimer() - b.ReportAllocs() - - for i := 0; i < b.N; i++ { - input := new(reqTest) - - err := dec.Decode(req, input, nil) - if err != nil { - b.Fail() - } - - if input.Header != 123 { - b.Fail() - } - } -} - -func TestDecoder_Decode_required(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/", nil) - assert.NoError(t, err) - - input := new(reqTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) - validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). - MakeRequestValidator(http.MethodPost, input, nil) - - err = dec.Decode(req, input, validator) - assert.Equal(t, rest.ValidationErrors{"header:X-In-Header": []string{"missing value"}}, err) -} - -func TestDecoder_Decode_required_header_case(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/", nil) - req.Header.Set("x-In-heAdEr", "123") - assert.NoError(t, err) - - input := new(reqTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) - validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). - MakeRequestValidator(http.MethodPost, input, nil) - - err = dec.Decode(req, input, validator) - assert.NoError(t, err) -} - -func TestDecoder_Decode_json(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=cba", - strings.NewReader(`{"bodyOne":"abc", "bodyTwo": [1,2,3]}`)) - assert.NoError(t, err) - - input := new(reqJSONTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) - validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). - MakeRequestValidator(http.MethodPost, input, nil) - - assert.NoError(t, dec.Decode(req, input, validator)) - assert.Equal(t, "cba", input.Query) - assert.Equal(t, "abc", input.BodyOne) - assert.Equal(t, []int{1, 2, 3}, input.BodyTwo) - - req, err = http.NewRequestWithContext(context.Background(), http.MethodPost, "/", - strings.NewReader(`{"bodyTwo":[1]}`)) - assert.NoError(t, err) - - err = dec.Decode(req, input, validator) - assert.Equal(t, rest.ValidationErrors{"body": []string{ - "#: validation failed", - "#: missing properties: \"bodyOne\"", - "#/bodyTwo: minimum 2 items allowed, but found 1 items", - }}, err) - - req, err = http.NewRequestWithContext(context.Background(), http.MethodPost, "/", - strings.NewReader(`{"bodyOne":"abc", "bodyTwo":[1]}`)) - assert.NoError(t, err) - - err = dec.Decode(req, input, validator) - assert.Error(t, err) - assert.Equal(t, rest.ValidationErrors{"body": []string{"#/bodyTwo: minimum 2 items allowed, but found 1 items"}}, err) -} - // BenchmarkDecoder_Decode_json-4 36660 29688 ns/op 12310 B/op 169 allocs/op. func BenchmarkDecoder_Decode_json(b *testing.B) { input := new(reqJSONTest) @@ -273,30 +94,35 @@ func BenchmarkDecoder_Decode_json(b *testing.B) { } } -func TestDecoder_Decode_queryObject(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?in_query[1]=1.0&in_query[2]=2.1&in_query[3]=0", nil) - assert.NoError(t, err) +// BenchmarkDecoder_Decode_jsonParam-4 525867 2306 ns/op 752 B/op 12 allocs/op. +func BenchmarkDecoder_Decode_jsonParam(b *testing.B) { + type inp struct { + Filter struct { + A int `json:"a"` + B string `json:"b"` + } `query:"filter"` + } df := request.NewDecoderFactory() + dec := df.MakeDecoder(http.MethodGet, new(inp), nil) - input := new(struct { - InQuery map[int]float64 `query:"in_query"` - }) - dec := df.MakeDecoder(http.MethodGet, input, nil) + req, err := http.NewRequest(http.MethodGet, "/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D", nil) + require.NoError(b, err) - assert.NoError(t, dec.Decode(req, input, nil)) - assert.Equal(t, map[int]float64{1: 1, 2: 2.1, 3: 0}, input.InQuery) + v := new(inp) - req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?in_query[1]=1.0&in_query[2]=2.1&in_query[c]=0", nil) - assert.NoError(t, err) + b.ReportAllocs() + b.ResetTimer() - err = dec.Decode(req, input, nil) - assert.Error(t, err) - assert.Equal(t, rest.RequestErrors{"query:in_query": []string{ - "#: invalid integer value 'c' type 'int' namespace 'in_query'", - }}, err) + for i := 0; i < b.N; i++ { + err := dec.Decode(req, v, nil) + if err != nil { + b.Fail() + } + } + + assert.Equal(b, 123, v.Filter.A) + assert.Equal(b, "abc", v.Filter.B) } // BenchmarkDecoder_Decode_queryObject-4 170670 6104 ns/op 2000 B/op 36 allocs/op. @@ -332,56 +158,114 @@ func BenchmarkDecoder_Decode_queryObject(b *testing.B) { } } -func TestDecoder_Decode_jsonParam(t *testing.T) { - type inp struct { - Filter struct { - A int `json:"a"` - B string `json:"b"` - } `query:"filter"` +// BenchmarkDecoderFunc_Decode-4 440503 2525 ns/op 1513 B/op 12 allocs/op. +func BenchmarkDecoderFunc_Decode(b *testing.B) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=abc", + strings.NewReader(url.Values{"inFormData": []string{"def"}}.Encode())) + assert.NoError(b, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-In-Header", "123") + + c := http.Cookie{ + Name: "in_cookie", + Value: "jkl", } + req.AddCookie(&c) + df := request.NewDecoderFactory() - dec := df.MakeDecoder(http.MethodGet, new(inp), nil) + df.SetDecoderFunc(rest.ParamInPath, func(_ *http.Request) (url.Values, error) { + return url.Values{"in_path": []string{"mno"}}, nil + }) - req, err := http.NewRequest(http.MethodGet, "/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D", nil) - require.NoError(t, err) + dec := df.MakeDecoder(http.MethodPost, new(reqTest), nil) - v := new(inp) - require.NoError(t, dec.Decode(req, v, nil)) + b.ResetTimer() + b.ReportAllocs() - assert.Equal(t, 123, v.Filter.A) - assert.Equal(t, "abc", v.Filter.B) + for i := 0; i < b.N; i++ { + input := new(reqTest) + + err := dec.Decode(req, input, nil) + if err != nil { + b.Fail() + } + + if input.Header != 123 { + b.Fail() + } + } } -// BenchmarkDecoder_Decode_jsonParam-4 525867 2306 ns/op 752 B/op 12 allocs/op. -func BenchmarkDecoder_Decode_jsonParam(b *testing.B) { - type inp struct { - Filter struct { - A int `json:"a"` - B string `json:"b"` - } `query:"filter"` +func TestDecoder_Decode(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=abc", + strings.NewReader(url.Values{"inFormData": []string{"def"}}.Encode())) + assert.NoError(t, err) + + req.Header.Add("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set("X-In-hEaDeR", "123") + + c := http.Cookie{ + Name: "in_cookie", + Value: "jkl", } + req.AddCookie(&c) + df := request.NewDecoderFactory() - dec := df.MakeDecoder(http.MethodGet, new(inp), nil) + df.SetDecoderFunc(rest.ParamInPath, func(r *http.Request) (url.Values, error) { + assert.Equal(t, req, r) - req, err := http.NewRequest(http.MethodGet, "/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D", nil) - require.NoError(b, err) + return url.Values{"in_path": []string{"mno"}}, nil + }) - v := new(inp) + input := new(reqTest) + dec := df.MakeDecoder(http.MethodPost, input, nil) - b.ReportAllocs() - b.ResetTimer() + assert.NoError(t, dec.Decode(req, input, nil)) + assert.Equal(t, "abc", input.Query) + assert.Equal(t, "def", input.FormData) + assert.Equal(t, 123, input.Header) + assert.Equal(t, "jkl", input.Cookie) + assert.Equal(t, "mno", input.Path) - for i := 0; i < b.N; i++ { - err := dec.Decode(req, v, nil) - if err != nil { - b.Fail() - } + inputCM := new(reqTestCustomMapping) + decCM := df.MakeDecoder(http.MethodPost, input, map[rest.ParamIn]map[string]string{ + rest.ParamInHeader: { + "Header": "X-In-HeAdEr", // Headers are mapped using canonical names. + }, + rest.ParamInCookie: {"Cookie": "in_cookie"}, + rest.ParamInQuery: {"Query": "in_query"}, + rest.ParamInPath: {"Path": "in_path"}, + rest.ParamInFormData: {"FormData": "inFormData"}, + }) + + assert.NoError(t, decCM.Decode(req, inputCM, nil)) + assert.Equal(t, "abc", inputCM.Query) + assert.Equal(t, "def", inputCM.FormData) + assert.Equal(t, 123, inputCM.Header) + assert.Equal(t, "jkl", inputCM.Cookie) + assert.Equal(t, "mno", inputCM.Path) +} + +func TestDecoder_Decode_dateTime(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, + "/?time=2020-04-04T00:00:00Z&date=2020-04-04", nil) + assert.NoError(t, err) + + type reqTest struct { + Time time.Time `query:"time"` + Date jschema.Date `query:"date"` } - assert.Equal(b, 123, v.Filter.A) - assert.Equal(b, "abc", v.Filter.B) + input := new(reqTest) + dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) + validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). + MakeRequestValidator(http.MethodGet, input, nil) + + err = dec.Decode(req, input, validator) + assert.NoError(t, err) } func TestDecoder_Decode_error(t *testing.T) { @@ -404,45 +288,60 @@ func TestDecoder_Decode_error(t *testing.T) { }}, err) } -func TestDecoder_Decode_dateTime(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?time=2020-04-04T00:00:00Z&date=2020-04-04", nil) +func TestDecoder_Decode_json(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=cba", + strings.NewReader(`{"bodyOne":"abc", "bodyTwo": [1,2,3]}`)) assert.NoError(t, err) - type reqTest struct { - Time time.Time `query:"time"` - Date jschema.Date `query:"date"` - } - - input := new(reqTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) + input := new(reqJSONTest) + dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). - MakeRequestValidator(http.MethodGet, input, nil) + MakeRequestValidator(http.MethodPost, input, nil) - err = dec.Decode(req, input, validator) + assert.NoError(t, dec.Decode(req, input, validator)) + assert.Equal(t, "cba", input.Query) + assert.Equal(t, "abc", input.BodyOne) + assert.Equal(t, []int{1, 2, 3}, input.BodyTwo) + + req, err = http.NewRequestWithContext(context.Background(), http.MethodPost, "/", + strings.NewReader(`{"bodyTwo":[1]}`)) assert.NoError(t, err) -} -type inputWithLoader struct { - Time time.Time `query:"time"` - Date jschema.Date `query:"date"` + err = dec.Decode(req, input, validator) + assert.Equal(t, rest.ValidationErrors{"body": []string{ + "#: validation failed", + "#: missing properties: \"bodyOne\"", + "#/bodyTwo: minimum 2 items allowed, but found 1 items", + }}, err) - load func(r *http.Request) error -} + req, err = http.NewRequestWithContext(context.Background(), http.MethodPost, "/", + strings.NewReader(`{"bodyOne":"abc", "bodyTwo":[1]}`)) + assert.NoError(t, err) -func (i *inputWithLoader) LoadFromHTTPRequest(r *http.Request) error { - return i.load(r) + err = dec.Decode(req, input, validator) + assert.Error(t, err) + assert.Equal(t, rest.ValidationErrors{"body": []string{"#/bodyTwo: minimum 2 items allowed, but found 1 items"}}, err) } -type inputWithSetter struct { - Time time.Time `query:"time"` - Date jschema.Date `query:"date"` +func TestDecoder_Decode_jsonParam(t *testing.T) { + type inp struct { + Filter struct { + A int `json:"a"` + B string `json:"b"` + } `query:"filter"` + } - r *http.Request -} + df := request.NewDecoderFactory() + dec := df.MakeDecoder(http.MethodGet, new(inp), nil) -func (i *inputWithSetter) SetRequest(r *http.Request) { - i.r = r + req, err := http.NewRequest(http.MethodGet, "/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D", nil) + require.NoError(t, err) + + v := new(inp) + require.NoError(t, dec.Decode(req, v, nil)) + + assert.Equal(t, 123, v.Filter.A) + assert.Equal(t, "abc", v.Filter.B) } func TestDecoder_Decode_manualLoader_ptr(t *testing.T) { @@ -497,6 +396,59 @@ func TestDecoder_Decode_manualLoader_val(t *testing.T) { assert.True(t, input.Time.IsZero()) } +func TestDecoder_Decode_queryObject(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, + "/?in_query[1]=1.0&in_query[2]=2.1&in_query[3]=0", nil) + assert.NoError(t, err) + + df := request.NewDecoderFactory() + + input := new(struct { + InQuery map[int]float64 `query:"in_query"` + }) + dec := df.MakeDecoder(http.MethodGet, input, nil) + + assert.NoError(t, dec.Decode(req, input, nil)) + assert.Equal(t, map[int]float64{1: 1, 2: 2.1, 3: 0}, input.InQuery) + + req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, + "/?in_query[1]=1.0&in_query[2]=2.1&in_query[c]=0", nil) + assert.NoError(t, err) + + err = dec.Decode(req, input, nil) + assert.Error(t, err) + assert.Equal(t, rest.RequestErrors{"query:in_query": []string{ + "#: invalid integer value 'c' type 'int' namespace 'in_query'", + }}, err) +} + +func TestDecoder_Decode_required(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/", nil) + assert.NoError(t, err) + + input := new(reqTest) + dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) + validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). + MakeRequestValidator(http.MethodPost, input, nil) + + err = dec.Decode(req, input, validator) + assert.Equal(t, rest.ValidationErrors{"header:X-In-Header": []string{"missing value"}}, err) +} + +func TestDecoder_Decode_required_header_case(t *testing.T) { + req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/", nil) + req.Header.Set("x-In-heAdEr", "123") + assert.NoError(t, err) + + input := new(reqTest) + dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) + validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). + MakeRequestValidator(http.MethodPost, input, nil) + + err = dec.Decode(req, input, validator) + assert.NoError(t, err) +} + func TestDecoder_Decode_setter_ptr(t *testing.T) { req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, "/?time=2020-04-04T00:00:00Z&date=2020-04-04", nil) @@ -568,13 +520,6 @@ func TestDecoderFactory_MakeDecoder_default_unexported(t *testing.T) { assert.NotNil(t, dec) } -type formOrJSONInput struct { - Field1 string `json:"field1" formData:"field1" required:"true"` - Field2 int `json:"field2" formData:"field2" required:"true"` -} - -func (formOrJSONInput) ForceJSONRequestBody() {} - func TestDecoderFactory_MakeDecoder_formOrJSON(t *testing.T) { var in formOrJSONInput @@ -603,3 +548,58 @@ func TestDecoderFactory_MakeDecoder_formOrJSON(t *testing.T) { assert.Equal(t, "abc", in.Field1) assert.Equal(t, 123, in.Field2) } + +type formOrJSONInput struct { + Field1 string `json:"field1" formData:"field1" required:"true"` + Field2 int `json:"field2" formData:"field2" required:"true"` +} + +func (formOrJSONInput) ForceJSONRequestBody() {} + +type inputWithLoader struct { + Time time.Time `query:"time"` + Date jschema.Date `query:"date"` + + load func(r *http.Request) error +} + +func (i *inputWithLoader) LoadFromHTTPRequest(r *http.Request) error { + return i.load(r) +} + +type inputWithSetter struct { + Time time.Time `query:"time"` + Date jschema.Date `query:"date"` + + r *http.Request +} + +func (i *inputWithSetter) SetRequest(r *http.Request) { + i.r = r +} + +type reqEmbedding struct { + Header int `required:"true"` + Cookie string +} + +type reqJSONTest struct { + Query string `query:"in_query"` + BodyOne string `json:"bodyOne" required:"true"` + BodyTwo []int `json:"bodyTwo" minItems:"2"` +} + +type reqTest struct { + Header int `header:"X-In-HeAdEr" required:"true"` // Headers are mapped using canonical names. + Cookie string `cookie:"in_cookie"` + Query string `query:"in_query"` + Path string `path:"in_path"` + FormData string `formData:"inFormData"` +} + +type reqTestCustomMapping struct { + reqEmbedding + Query string + Path string + FormData string +} diff --git a/request/factory.go b/request/factory.go index 4702230..845abcb 100644 --- a/request/factory.go +++ b/request/factory.go @@ -18,41 +18,6 @@ import ( "github.com/swaggest/rest/nethttp" ) -var _ DecoderMaker = &DecoderFactory{} - -const ( - defaultTag = "default" - jsonTag = "json" - fileTag = "file" - formDataTag = "formData" -) - -// DecoderFactory decodes http requests. -// -// Please use NewDecoderFactory to create instance. -type DecoderFactory struct { - // ApplyDefaults enables default value assignment for fields missing explicit value in request. - // Default value is retrieved from `default` field tag. - ApplyDefaults bool - - // JSONReader allows custom JSON decoder for request body. - // If not set encoding/json.Decoder is used. - JSONReader func(rd io.Reader, v interface{}) error - - // JSONSchemaReflector is optional, it is called to infer "default" values. - JSONSchemaReflector *jsonschema.Reflector - - formDecoders map[rest.ParamIn]*form.Decoder - decoderFunctions map[rest.ParamIn]decoderFunc - defaultValDecoder *form.Decoder - customDecoders []customDecoder -} - -type customDecoder struct { - types []interface{} - fn form.DecodeFunc -} - // NewDecoderFactory creates request decoder factory. func NewDecoderFactory() *DecoderFactory { df := DecoderFactory{} @@ -75,23 +40,25 @@ func NewDecoderFactory() *DecoderFactory { return &df } -// SetDecoderFunc adds custom decoder function for values of particular field tag name. -func (df *DecoderFactory) SetDecoderFunc(tagName rest.ParamIn, d func(r *http.Request) (url.Values, error)) { - if df.decoderFunctions == nil { - df.decoderFunctions = make(map[rest.ParamIn]decoderFunc) - } +// DecoderFactory decodes http requests. +// +// Please use NewDecoderFactory to create instance. +type DecoderFactory struct { + // ApplyDefaults enables default value assignment for fields missing explicit value in request. + // Default value is retrieved from `default` field tag. + ApplyDefaults bool - if df.formDecoders == nil { - df.formDecoders = make(map[rest.ParamIn]*form.Decoder) - } + // JSONReader allows custom JSON decoder for request body. + // If not set encoding/json.Decoder is used. + JSONReader func(rd io.Reader, v interface{}) error - df.decoderFunctions[tagName] = d - dec := form.NewDecoder() - dec.SetNamespacePrefix("[") - dec.SetNamespaceSuffix("]") - dec.SetTagName(string(tagName)) - dec.SetMode(form.ModeExplicit) - df.formDecoders[tagName] = dec + // JSONSchemaReflector is optional, it is called to infer "default" values. + JSONSchemaReflector *jsonschema.Reflector + + formDecoders map[rest.ParamIn]*form.Decoder + decoderFunctions map[rest.ParamIn]decoderFunc + defaultValDecoder *form.Decoder + customDecoders []customDecoder } // MakeDecoder creates request.RequestDecoder for a http method and request structure. @@ -159,68 +126,37 @@ func (df *DecoderFactory) MakeDecoder( return &d } -func initDecoder(input interface{}) decoder { - d := decoder{ - decoders: make([]valueDecoderFunc, 0), - in: make([]rest.ParamIn, 0), +// RegisterFunc adds custom type handling. +func (df *DecoderFactory) RegisterFunc(fn form.DecodeFunc, types ...interface{}) { + for _, fd := range df.formDecoders { + fd.RegisterFunc(fn, types...) } - loader := reflect.TypeOf((*Loader)(nil)).Elem() - d.isReqLoader = reflect.TypeOf(input).Implements(loader) || - reflect.New(reflect.TypeOf(input)).Type().Implements(loader) - - setter := reflect.TypeOf((*Setter)(nil)).Elem() - d.isReqSetter = reflect.TypeOf(input).Implements(setter) || - reflect.New(reflect.TypeOf(input)).Type().Implements(setter) + df.defaultValDecoder.RegisterFunc(fn, types...) - return d + df.customDecoders = append(df.customDecoders, customDecoder{ + fn: fn, + types: types, + }) } -func (df *DecoderFactory) prepareCustomMapping(input interface{}, customMapping rest.RequestMapping) rest.RequestMapping { - // Copy custom mapping to avoid mutability issues on original map. - cm := make(rest.RequestMapping, len(customMapping)) - for k, v := range customMapping { - cm[k] = v - } - - // Move header names to custom mapping and/or apply canonical form to match net/http request decoder. - if hdm, exists := cm[rest.ParamInHeader]; !exists && refl.HasTaggedFields(input, string(rest.ParamInHeader)) { - hdm = make(map[string]string) - - refl.WalkTaggedFields(reflect.ValueOf(input), func(_ reflect.Value, sf reflect.StructField, tag string) { - hdm[sf.Name] = http.CanonicalHeaderKey(tag) - }, string(rest.ParamInHeader)) - - cm[rest.ParamInHeader] = hdm - } else if exists { - for k, v := range hdm { - hdm[k] = http.CanonicalHeaderKey(v) - } - } - - fields := make(map[string]bool) - - refl.WalkTaggedFields(reflect.ValueOf(input), func(_ reflect.Value, sf reflect.StructField, _ string) { - fields[sf.Name] = true - }, "") - - // Check if there are non-existent fields in mapping. - var nonExistent []string - - for _, items := range cm { - for k := range items { - if _, exists := fields[k]; !exists { - nonExistent = append(nonExistent, k) - } - } +// SetDecoderFunc adds custom decoder function for values of particular field tag name. +func (df *DecoderFactory) SetDecoderFunc(tagName rest.ParamIn, d func(r *http.Request) (url.Values, error)) { + if df.decoderFunctions == nil { + df.decoderFunctions = make(map[rest.ParamIn]decoderFunc) } - if len(nonExistent) > 0 { - sort.Strings(nonExistent) - panic("non existent fields in mapping: " + strings.Join(nonExistent, ", ")) + if df.formDecoders == nil { + df.formDecoders = make(map[rest.ParamIn]*form.Decoder) } - return cm + df.decoderFunctions[tagName] = d + dec := form.NewDecoder() + dec.SetNamespacePrefix("[") + dec.SetNamespaceSuffix("]") + dec.SetTagName(string(tagName)) + dec.SetMode(form.ModeExplicit) + df.formDecoders[tagName] = dec } // jsonParams configures custom decoding for parameters with JSON struct values. @@ -259,6 +195,37 @@ func (df *DecoderFactory) jsonParams(formDecoder *form.Decoder, in rest.ParamIn, }, string(in)) } +func (df *DecoderFactory) makeCustomMappingDecoder(customMapping rest.RequestMapping, m *decoder) { + for in, mapping := range customMapping { + dec := form.NewDecoder() + dec.SetNamespacePrefix("[") + dec.SetNamespaceSuffix("]") + dec.SetTagName(string(in)) + + // Copy mapping to avoid mutability. + mm := make(map[string]string, len(mapping)) + for k, v := range mapping { + mm[k] = v + } + + dec.RegisterTagNameFunc(func(field reflect.StructField) string { + n := mm[field.Name] + if n == "" && !field.Anonymous { + return "-" + } + + return n + }) + + for _, c := range df.customDecoders { + dec.RegisterFunc(c.fn, c.types...) + } + + m.decoders = append(m.decoders, makeDecoder(in, dec, df.decoderFunctions[in])) + m.in = append(m.in, in) + } +} + func (df *DecoderFactory) makeDefaultDecoder(input interface{}, m *decoder) { defaults := url.Values{} @@ -316,47 +283,80 @@ func (df *DecoderFactory) makeDefaultDecoder(input interface{}, m *decoder) { m.in = append(m.in, defaultTag) } -func (df *DecoderFactory) makeCustomMappingDecoder(customMapping rest.RequestMapping, m *decoder) { - for in, mapping := range customMapping { - dec := form.NewDecoder() - dec.SetNamespacePrefix("[") - dec.SetNamespaceSuffix("]") - dec.SetTagName(string(in)) +func (df *DecoderFactory) prepareCustomMapping(input interface{}, customMapping rest.RequestMapping) rest.RequestMapping { + // Copy custom mapping to avoid mutability issues on original map. + cm := make(rest.RequestMapping, len(customMapping)) + for k, v := range customMapping { + cm[k] = v + } - // Copy mapping to avoid mutability. - mm := make(map[string]string, len(mapping)) - for k, v := range mapping { - mm[k] = v + // Move header names to custom mapping and/or apply canonical form to match net/http request decoder. + if hdm, exists := cm[rest.ParamInHeader]; !exists && refl.HasTaggedFields(input, string(rest.ParamInHeader)) { + hdm = make(map[string]string) + + refl.WalkTaggedFields(reflect.ValueOf(input), func(_ reflect.Value, sf reflect.StructField, tag string) { + hdm[sf.Name] = http.CanonicalHeaderKey(tag) + }, string(rest.ParamInHeader)) + + cm[rest.ParamInHeader] = hdm + } else if exists { + for k, v := range hdm { + hdm[k] = http.CanonicalHeaderKey(v) } + } - dec.RegisterTagNameFunc(func(field reflect.StructField) string { - n := mm[field.Name] - if n == "" && !field.Anonymous { - return "-" - } + fields := make(map[string]bool) - return n - }) + refl.WalkTaggedFields(reflect.ValueOf(input), func(_ reflect.Value, sf reflect.StructField, _ string) { + fields[sf.Name] = true + }, "") - for _, c := range df.customDecoders { - dec.RegisterFunc(c.fn, c.types...) + // Check if there are non-existent fields in mapping. + var nonExistent []string + + for _, items := range cm { + for k := range items { + if _, exists := fields[k]; !exists { + nonExistent = append(nonExistent, k) + } } + } - m.decoders = append(m.decoders, makeDecoder(in, dec, df.decoderFunctions[in])) - m.in = append(m.in, in) + if len(nonExistent) > 0 { + sort.Strings(nonExistent) + panic("non existent fields in mapping: " + strings.Join(nonExistent, ", ")) } + + return cm } -// RegisterFunc adds custom type handling. -func (df *DecoderFactory) RegisterFunc(fn form.DecodeFunc, types ...interface{}) { - for _, fd := range df.formDecoders { - fd.RegisterFunc(fn, types...) +const ( + defaultTag = "default" + jsonTag = "json" + fileTag = "file" + formDataTag = "formData" +) + +var _ DecoderMaker = &DecoderFactory{} + +func initDecoder(input interface{}) decoder { + d := decoder{ + decoders: make([]valueDecoderFunc, 0), + in: make([]rest.ParamIn, 0), } - df.defaultValDecoder.RegisterFunc(fn, types...) + loader := reflect.TypeOf((*Loader)(nil)).Elem() + d.isReqLoader = reflect.TypeOf(input).Implements(loader) || + reflect.New(reflect.TypeOf(input)).Type().Implements(loader) - df.customDecoders = append(df.customDecoders, customDecoder{ - fn: fn, - types: types, - }) + setter := reflect.TypeOf((*Setter)(nil)).Elem() + d.isReqSetter = reflect.TypeOf(input).Implements(setter) || + reflect.New(reflect.TypeOf(input)).Type().Implements(setter) + + return d +} + +type customDecoder struct { + types []interface{} + fn form.DecodeFunc } diff --git a/request/factory_test.go b/request/factory_test.go index fdf6eb8..8023590 100644 --- a/request/factory_test.go +++ b/request/factory_test.go @@ -16,57 +16,6 @@ import ( "github.com/swaggest/rest/request" ) -func TestDecoderFactory_SetDecoderFunc(t *testing.T) { - df := request.NewDecoderFactory() - df.SetDecoderFunc("jwt", func(r *http.Request) (url.Values, error) { - ah := r.Header.Get("Authorization") - if ah == "" || len(ah) < 8 || strings.ToLower(ah[0:7]) != "bearer " { - return nil, nil - } - - var m map[string]json.RawMessage - - err := json.Unmarshal([]byte(ah[7:]), &m) - if err != nil { - return nil, err - } - - res := make(url.Values) - - for k, v := range m { - if len(v) > 2 && v[0] == '"' && v[len(v)-1] == '"' { - v = v[1 : len(v)-1] - } - - res[k] = []string{string(v)} - } - - return res, err - }) - - type req struct { - Q string `query:"q"` - Name string `jwt:"name"` - Iat int `jwt:"iat"` - Sub string `jwt:"sub"` - } - - r, err := http.NewRequest(http.MethodGet, "/?q=abc", nil) - require.NoError(t, err) - - r.Header.Add("Authorization", `Bearer {"sub":"1234567890","name":"John Doe","iat": 1516239022}`) - - d := df.MakeDecoder(http.MethodGet, new(req), nil) - - rr := new(req) - require.NoError(t, d.Decode(r, rr, nil)) - - assert.Equal(t, "John Doe", rr.Name) - assert.Equal(t, "1234567890", rr.Sub) - assert.Equal(t, 1516239022, rr.Iat) - assert.Equal(t, "abc", rr.Q) -} - // BenchmarkDecoderFactory_SetDecoderFunc-4 577378 1994 ns/op 1024 B/op 16 allocs/op. func BenchmarkDecoderFactory_SetDecoderFunc(b *testing.B) { df := request.NewDecoderFactory() @@ -123,32 +72,21 @@ func BenchmarkDecoderFactory_SetDecoderFunc(b *testing.B) { } } -func TestDecoderFactory_MakeDecoder_default(t *testing.T) { - type Embed struct { - Baz bool `query:"baz" default:"true"` - } - - type DeeplyEmbedded struct { - Embed - } - +func TestDecoderFactory_MakeDecoder_customMapping(t *testing.T) { type MyInput struct { - ID int `query:"id" default:"123"` - Name string `header:"X-Name" default:"foo"` - Deeper struct { - Foo string `query:"foo" default:"abc"` - EvenDeeper struct { - Bar float64 `query:"bar" default:"1.23"` - } `query:"even_deeper"` - } `query:"deeper"` - *DeeplyEmbedded - unexported bool `query:"unexported"` // This field is skipped because it is unexported. + ID int `default:"123"` + Name string `default:"foo"` } df := request.NewDecoderFactory() df.ApplyDefaults = true - dec := df.MakeDecoder(http.MethodPost, new(MyInput), nil) + customMapping := rest.RequestMapping{ + rest.ParamInQuery: map[string]string{"ID": "id"}, + rest.ParamInHeader: map[string]string{"Name": "X-Name"}, + } + + dec := df.MakeDecoder(http.MethodPost, new(MyInput), customMapping) assert.NotNil(t, dec) req, err := http.NewRequest(http.MethodPost, "/", nil) @@ -160,15 +98,8 @@ func TestDecoderFactory_MakeDecoder_default(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "foo", i.Name) assert.Equal(t, 123, i.ID) - assert.Equal(t, "abc", i.Deeper.Foo) - assert.Equal(t, 1.23, i.Deeper.EvenDeeper.Bar) - assert.Equal(t, true, i.Baz) - req, err = http.NewRequest( - http.MethodPost, - "/?id=321&deeper[foo]=def&deeper[even_deeper][bar]=3.21&baz=false", - nil, - ) + req, err = http.NewRequest(http.MethodPost, "/?id=321", nil) require.NoError(t, err) req.Header.Set("X-Name", "bar") @@ -179,44 +110,34 @@ func TestDecoderFactory_MakeDecoder_default(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "bar", i.Name) assert.Equal(t, 321, i.ID) - assert.Equal(t, "def", i.Deeper.Foo) - assert.Equal(t, 3.21, i.Deeper.EvenDeeper.Bar) - assert.Equal(t, false, i.Baz) } -func TestDecoderFactory_MakeDecoder_invalidMapping(t *testing.T) { - assert.PanicsWithValue(t, "non existent fields in mapping: ID2, WrongName", func() { - type MyInput struct { - ID int `default:"123"` - Name string `default:"foo"` - } - - df := request.NewDecoderFactory() - - customMapping := rest.RequestMapping{ - rest.ParamInQuery: map[string]string{"ID2": "id"}, - rest.ParamInHeader: map[string]string{"WrongName": "X-Name"}, - } +func TestDecoderFactory_MakeDecoder_default(t *testing.T) { + type Embed struct { + Baz bool `query:"baz" default:"true"` + } - _ = df.MakeDecoder(http.MethodPost, new(MyInput), customMapping) - }) -} + type DeeplyEmbedded struct { + Embed + } -func TestDecoderFactory_MakeDecoder_customMapping(t *testing.T) { type MyInput struct { - ID int `default:"123"` - Name string `default:"foo"` + ID int `query:"id" default:"123"` + Name string `header:"X-Name" default:"foo"` + Deeper struct { + Foo string `query:"foo" default:"abc"` + EvenDeeper struct { + Bar float64 `query:"bar" default:"1.23"` + } `query:"even_deeper"` + } `query:"deeper"` + *DeeplyEmbedded + unexported bool `query:"unexported"` // This field is skipped because it is unexported. } df := request.NewDecoderFactory() df.ApplyDefaults = true - customMapping := rest.RequestMapping{ - rest.ParamInQuery: map[string]string{"ID": "id"}, - rest.ParamInHeader: map[string]string{"Name": "X-Name"}, - } - - dec := df.MakeDecoder(http.MethodPost, new(MyInput), customMapping) + dec := df.MakeDecoder(http.MethodPost, new(MyInput), nil) assert.NotNil(t, dec) req, err := http.NewRequest(http.MethodPost, "/", nil) @@ -228,8 +149,15 @@ func TestDecoderFactory_MakeDecoder_customMapping(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "foo", i.Name) assert.Equal(t, 123, i.ID) + assert.Equal(t, "abc", i.Deeper.Foo) + assert.Equal(t, 1.23, i.Deeper.EvenDeeper.Bar) + assert.Equal(t, true, i.Baz) - req, err = http.NewRequest(http.MethodPost, "/?id=321", nil) + req, err = http.NewRequest( + http.MethodPost, + "/?id=321&deeper[foo]=def&deeper[even_deeper][bar]=3.21&baz=false", + nil, + ) require.NoError(t, err) req.Header.Set("X-Name", "bar") @@ -240,6 +168,9 @@ func TestDecoderFactory_MakeDecoder_customMapping(t *testing.T) { assert.NoError(t, err) assert.Equal(t, "bar", i.Name) assert.Equal(t, 321, i.ID) + assert.Equal(t, "def", i.Deeper.Foo) + assert.Equal(t, 3.21, i.Deeper.EvenDeeper.Bar) + assert.Equal(t, false, i.Baz) } func TestDecoderFactory_MakeDecoder_header_case_sensitivity(t *testing.T) { @@ -272,35 +203,73 @@ func TestDecoderFactory_MakeDecoder_header_case_sensitivity(t *testing.T) { assert.Equal(t, "hello!", v.D) } -type defaultFromSchema string +func TestDecoderFactory_MakeDecoder_invalidMapping(t *testing.T) { + assert.PanicsWithValue(t, "non existent fields in mapping: ID2, WrongName", func() { + type MyInput struct { + ID int `default:"123"` + Name string `default:"foo"` + } -func (d *defaultFromSchema) PrepareJSONSchema(schema *jsonschema.Schema) error { - schema.WithDefault(enum1) - schema.WithTitle("Value with default from schema") + df := request.NewDecoderFactory() - return nil + customMapping := rest.RequestMapping{ + rest.ParamInQuery: map[string]string{"ID2": "id"}, + rest.ParamInHeader: map[string]string{"WrongName": "X-Name"}, + } + + _ = df.MakeDecoder(http.MethodPost, new(MyInput), customMapping) + }) } -type defaultFromSchemaVal string +func TestDecoderFactory_SetDecoderFunc(t *testing.T) { + df := request.NewDecoderFactory() + df.SetDecoderFunc("jwt", func(r *http.Request) (url.Values, error) { + ah := r.Header.Get("Authorization") + if ah == "" || len(ah) < 8 || strings.ToLower(ah[0:7]) != "bearer " { + return nil, nil + } -func (d defaultFromSchemaVal) PrepareJSONSchema(schema *jsonschema.Schema) error { - schema.WithDefault(enum1) - schema.WithTitle("Value with default from schema") + var m map[string]json.RawMessage - return nil -} + err := json.Unmarshal([]byte(ah[7:]), &m) + if err != nil { + return nil, err + } -const ( - enum1 = "all" - enum2 = "none" -) + res := make(url.Values) -func (d *defaultFromSchema) Enum() []interface{} { - return []interface{}{enum1, enum2} -} + for k, v := range m { + if len(v) > 2 && v[0] == '"' && v[len(v)-1] == '"' { + v = v[1 : len(v)-1] + } -func (d defaultFromSchemaVal) Enum() []interface{} { - return []interface{}{enum1, enum2} + res[k] = []string{string(v)} + } + + return res, err + }) + + type req struct { + Q string `query:"q"` + Name string `jwt:"name"` + Iat int `jwt:"iat"` + Sub string `jwt:"sub"` + } + + r, err := http.NewRequest(http.MethodGet, "/?q=abc", nil) + require.NoError(t, err) + + r.Header.Add("Authorization", `Bearer {"sub":"1234567890","name":"John Doe","iat": 1516239022}`) + + d := df.MakeDecoder(http.MethodGet, new(req), nil) + + rr := new(req) + require.NoError(t, d.Decode(r, rr, nil)) + + assert.Equal(t, "John Doe", rr.Name) + assert.Equal(t, "1234567890", rr.Sub) + assert.Equal(t, 1516239022, rr.Iat) + assert.Equal(t, "abc", rr.Q) } func TestNewDecoderFactory_default(t *testing.T) { @@ -367,3 +336,34 @@ func TestNewDecoderFactory_requestBody(t *testing.T) { assert.Equal(t, "hello,world", input.CSVBody) assert.Empty(t, input.TextBody) } + +const ( + enum1 = "all" + enum2 = "none" +) + +type defaultFromSchema string + +func (d *defaultFromSchema) Enum() []interface{} { + return []interface{}{enum1, enum2} +} + +func (d *defaultFromSchema) PrepareJSONSchema(schema *jsonschema.Schema) error { + schema.WithDefault(enum1) + schema.WithTitle("Value with default from schema") + + return nil +} + +type defaultFromSchemaVal string + +func (d defaultFromSchemaVal) Enum() []interface{} { + return []interface{}{enum1, enum2} +} + +func (d defaultFromSchemaVal) PrepareJSONSchema(schema *jsonschema.Schema) error { + schema.WithDefault(enum1) + schema.WithTitle("Value with default from schema") + + return nil +} diff --git a/request/file.go b/request/file.go index 049f4b0..505f0b9 100644 --- a/request/file.go +++ b/request/file.go @@ -17,12 +17,6 @@ var ( multipartFileHeadersType = reflect.TypeOf(([]*multipart.FileHeader)(nil)) ) -func decodeFiles(r *http.Request, input interface{}, _ rest.Validator) error { - v := reflect.ValueOf(input) - - return decodeFilesInStruct(r, v) -} - func decodeFilesInStruct(r *http.Request, v reflect.Value) error { for v.Kind() == reflect.Ptr { v = v.Elem() @@ -111,3 +105,9 @@ func setFile(r *http.Request, field reflect.StructField, v reflect.Value) error return nil } + +func decodeFiles(r *http.Request, input interface{}, _ rest.Validator) error { + v := reflect.ValueOf(input) + + return decodeFilesInStruct(r, v) +} diff --git a/request/file_test.go b/request/file_test.go index 77d9d40..01a51dd 100644 --- a/request/file_test.go +++ b/request/file_test.go @@ -24,18 +24,6 @@ import ( "github.com/swaggest/usecase" ) -type ReqEmb struct { - Simple string `formData:"simple"` - UploadHeader *multipart.FileHeader `formData:"upload"` - UploadsHeaders []*multipart.FileHeader `formData:"uploads"` -} - -type fileReqTest struct { - ReqEmb - Upload multipart.File `file:"upload"` - Uploads []multipart.File `formData:"uploads"` -} - func TestDecoder_Decode_fileUploadOptional(t *testing.T) { u := usecase.NewIOI(new(ReqEmb), nil, func(_ context.Context, _, _ interface{}) error { return nil @@ -162,3 +150,15 @@ func TestDecoder_Decode_fileUploadTag(t *testing.T) { assert.NoError(t, err) assert.NoError(t, resp.Body.Close()) } + +type ReqEmb struct { + Simple string `formData:"simple"` + UploadHeader *multipart.FileHeader `formData:"upload"` + UploadsHeaders []*multipart.FileHeader `formData:"uploads"` +} + +type fileReqTest struct { + ReqEmb + Upload multipart.File `file:"upload"` + Uploads []multipart.File `formData:"uploads"` +} diff --git a/request/jsonbody.go b/request/jsonbody.go index 788cab3..d747903 100644 --- a/request/jsonbody.go +++ b/request/jsonbody.go @@ -18,12 +18,6 @@ var bufPool = sync.Pool{ }, } -func readJSON(rd io.Reader, v interface{}) error { - d := json.NewDecoder(rd) - - return d.Decode(v) -} - func decodeJSONBody(readJSON func(rd io.Reader, v interface{}) error, tolerateFormData bool) valueDecoderFunc { return func(r *http.Request, input interface{}, validator rest.Validator) error { if r.ContentLength == 0 { @@ -82,3 +76,9 @@ func checkJSONBodyContentType(contentType string, tolerateFormData bool) (ret bo return false, nil } + +func readJSON(rd io.Reader, v interface{}) error { + d := json.NewDecoder(rd) + + return d.Decode(v) +} diff --git a/request/jsonbody_test.go b/request/jsonbody_test.go index 3380b62..f5b7a45 100644 --- a/request/jsonbody_test.go +++ b/request/jsonbody_test.go @@ -1,4 +1,4 @@ -package request //nolint:testpackage +package request import ( "bytes" @@ -12,6 +12,7 @@ import ( "github.com/swaggest/rest" ) +//nolint:testpackage func Test_decodeJSONBody(t *testing.T) { createBody := bytes.NewReader( []byte(`{"amount": 123,"customerId": "248df4b7-aa70-47b8-a036-33ac447e668d","type": "withdraw"}`)) @@ -43,16 +44,6 @@ func Test_decodeJSONBody(t *testing.T) { assert.Equal(t, "withdraw", i.Type) } -func Test_decodeJSONBody_emptyBody(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", nil) - require.NoError(t, err) - - var i []int - - err = decodeJSONBody(readJSON, false)(req, &i, nil) - assert.EqualError(t, err, "missing request body") -} - func Test_decodeJSONBody_badContentType(t *testing.T) { req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("123")) require.NoError(t, err) @@ -64,38 +55,38 @@ func Test_decodeJSONBody_badContentType(t *testing.T) { assert.EqualError(t, err, "request with application/json content type expected, received: text/plain") } -func Test_decodeJSONBody_decodeFailed(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("abc")) +func Test_decodeJSONBody_charset(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString(`{"amount": 123}`)) require.NoError(t, err) + req.Header.Set("Content-Type", "application/json;charset=utf-8") - var i []int + type Input struct { + Amount int `json:"amount" formData:"amount"` + } - err = decodeJSONBody(readJSON, false)(req, &i, nil) - assert.Error(t, err) + i := Input{} + + assert.NoError(t, decodeJSONBody(readJSON, false)(req, &i, nil)) } -func Test_decodeJSONBody_unmarshalFailed(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("123")) +func Test_decodeJSONBody_decodeFailed(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("abc")) require.NoError(t, err) var i []int err = decodeJSONBody(readJSON, false)(req, &i, nil) - assert.EqualError(t, err, "failed to decode json: json: cannot unmarshal number into Go value of type []int") + assert.Error(t, err) } -func Test_decodeJSONBody_validateFailed(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("[123]")) +func Test_decodeJSONBody_emptyBody(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "any", nil) require.NoError(t, err) var i []int - vl := rest.ValidatorFunc(func(_ rest.ParamIn, _ map[string]interface{}) error { - return errors.New("failed") - }) - - err = decodeJSONBody(readJSON, false)(req, &i, vl) - assert.EqualError(t, err, "failed") + err = decodeJSONBody(readJSON, false)(req, &i, nil) + assert.EqualError(t, err, "missing request body") } func Test_decodeJSONBody_tolerateFormData(t *testing.T) { @@ -118,16 +109,26 @@ func Test_decodeJSONBody_tolerateFormData(t *testing.T) { assert.Empty(t, i.Type) } -func Test_decodeJSONBody_charset(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString(`{"amount": 123}`)) +func Test_decodeJSONBody_unmarshalFailed(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("123")) require.NoError(t, err) - req.Header.Set("Content-Type", "application/json;charset=utf-8") - type Input struct { - Amount int `json:"amount" formData:"amount"` - } + var i []int - i := Input{} + err = decodeJSONBody(readJSON, false)(req, &i, nil) + assert.EqualError(t, err, "failed to decode json: json: cannot unmarshal number into Go value of type []int") +} - assert.NoError(t, decodeJSONBody(readJSON, false)(req, &i, nil)) +func Test_decodeJSONBody_validateFailed(t *testing.T) { + req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("[123]")) + require.NoError(t, err) + + var i []int + + vl := rest.ValidatorFunc(func(_ rest.ParamIn, _ map[string]interface{}) error { + return errors.New("failed") + }) + + err = decodeJSONBody(readJSON, false)(req, &i, vl) + assert.EqualError(t, err, "failed") } diff --git a/request/middleware.go b/request/middleware.go index 4d53d15..dcf0da7 100644 --- a/request/middleware.go +++ b/request/middleware.go @@ -8,14 +8,6 @@ import ( "github.com/swaggest/usecase" ) -type requestDecoderSetter interface { - SetRequestDecoder(rd nethttp.RequestDecoder) -} - -type requestMapping interface { - RequestMapping() rest.RequestMapping -} - // DecoderMiddleware sets up request decoder in suitable handlers. func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler { return func(handler http.Handler) http.Handler { @@ -57,10 +49,6 @@ func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler { } } -type withRestHandler interface { - RestHandler() *rest.HandlerTrait -} - // ValidatorMiddleware sets up request validator in suitable handlers. func ValidatorMiddleware(factory rest.RequestValidatorFactory) func(http.Handler) http.Handler { return func(handler http.Handler) http.Handler { @@ -91,8 +79,6 @@ func ValidatorMiddleware(factory rest.RequestValidatorFactory) func(http.Handler } } -var _ nethttp.RequestDecoder = DecoderFunc(nil) - // DecoderFunc implements RequestDecoder with a func. type DecoderFunc func(r *http.Request, input interface{}, validator rest.Validator) error @@ -105,3 +91,17 @@ func (df DecoderFunc) Decode(r *http.Request, input interface{}, validator rest. type DecoderMaker interface { MakeDecoder(method string, input interface{}, customMapping rest.RequestMapping) nethttp.RequestDecoder } + +var _ nethttp.RequestDecoder = DecoderFunc(nil) + +type requestDecoderSetter interface { + SetRequestDecoder(rd nethttp.RequestDecoder) +} + +type requestMapping interface { + RequestMapping() rest.RequestMapping +} + +type withRestHandler interface { + RestHandler() *rest.HandlerTrait +} diff --git a/response/encoder.go b/response/encoder.go index 926376d..f6656e0 100644 --- a/response/encoder.go +++ b/response/encoder.go @@ -17,14 +17,28 @@ import ( "github.com/swaggest/usecase/status" ) -type ( - // Setter captures original http.ResponseWriter. - // - // Implement this interface on a pointer to your output structure to get access to http.ResponseWriter. - Setter interface { - SetResponseWriter(rw http.ResponseWriter) - } -) +// DefaultErrorResponseContentType is a package-level variable set to +// default error response content type. +var DefaultErrorResponseContentType = "application/json" + +// DefaultSuccessResponseContentType is a package-level variable set to +// default success response content type. +var DefaultSuccessResponseContentType = "application/json" + +// EmbeddedSetter can capture http.ResponseWriter in your output structure. +type EmbeddedSetter struct { + rw http.ResponseWriter +} + +// ResponseWriter is an accessor. +func (e *EmbeddedSetter) ResponseWriter() http.ResponseWriter { + return e.rw +} + +// SetResponseWriter implements Setter. +func (e *EmbeddedSetter) SetResponseWriter(rw http.ResponseWriter) { + e.rw = rw +} // Encoder prepares and writes http response. type Encoder struct { @@ -45,80 +59,183 @@ type Encoder struct { dynamicNoContent bool } -type noContent interface { - // NoContent controls whether status 204 should be used in response to current request. - NoContent() bool -} +// MakeOutput instantiates a value for use case output port. +func (h *Encoder) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interface{} { + if h.outputBufferType == nil { + return nil + } -type outputWithHeadersSetup interface { - // SetupResponseHeader gives access to response headers of current request. - SetupResponseHeader(h http.Header) + output := reflect.New(h.outputBufferType).Interface() + + if h.outputWithWriter { + if withWriter, ok := output.(usecase.OutputWithWriter); ok { + if h.outputHeadersEncoder != nil || ht.SuccessContentType != "" { + withWriter.SetWriter(&writerWithHeaders{ + ResponseWriter: w, + responseWriter: h, + trait: ht, + output: output, + }) + } else { + withWriter.SetWriter(w) + } + } + } + + if h.dynamicSetter { + if setter, ok := output.(Setter); ok { + setter.SetResponseWriter(w) + } + } + + return output } -// DefaultSuccessResponseContentType is a package-level variable set to -// default success response content type. -var DefaultSuccessResponseContentType = "application/json" +// SetupOutput configures encoder with and instance of use case output. +func (h *Encoder) SetupOutput(output interface{}, ht *rest.HandlerTrait) { + h.outputBufferType = reflect.TypeOf(output) + h.outputHeadersEncoder = nil + h.skipRendering = true -// DefaultErrorResponseContentType is a package-level variable set to -// default error response content type. -var DefaultErrorResponseContentType = "application/json" + if output == nil { + return + } -// addressable makes a pointer from a non-pointer values. -func addressable(output interface{}) interface{} { - if reflect.ValueOf(output).Kind() != reflect.Ptr { - o := reflect.New(reflect.TypeOf(output)) - o.Elem().Set(reflect.ValueOf(output)) + output = addressable(output) - output = o.Interface() + h.unwrapInterface = reflect.ValueOf(output).Elem().Kind() == reflect.Interface + + if _, ok := output.(outputWithHeadersSetup); ok || h.unwrapInterface { + h.dynamicWithHeadersSetup = true } - return output + if _, ok := output.(Setter); ok || h.unwrapInterface { + h.dynamicSetter = true + } + + if _, ok := output.(rest.ETagged); ok || h.unwrapInterface { + h.dynamicETagged = true + } + + if _, ok := output.(noContent); ok || h.unwrapInterface { + h.dynamicNoContent = true + } + + h.setupHeadersEncoder(output, ht) + h.setupCookiesEncoder(output, ht) + h.setupContentTypeBodyEncoder(output) + + if h.outputBufferType.Kind() == reflect.Ptr { + h.outputBufferType = h.outputBufferType.Elem() + } + + if !rest.OutputHasNoContent(output) { + h.skipRendering = false + } + + if _, ok := output.(usecase.OutputWithWriter); ok { + h.skipRendering = true + h.outputWithWriter = true + } + + if ht.SuccessStatus != 0 { + return + } + + ht.SuccessStatus = h.successStatus(output) } -func (h *Encoder) setupHeadersEncoder(output interface{}, ht *rest.HandlerTrait) { - // Enable dynamic headers check in interface mode. - if h.unwrapInterface { - enc := form.NewEncoder() - enc.SetMode(form.ModeExplicit) - enc.SetTagName(string(rest.ParamInHeader)) +// WriteErrResponse encodes and writes error to response. +func (h *Encoder) WriteErrResponse(w http.ResponseWriter, r *http.Request, statusCode int, response interface{}) { + e := jsonEncoderPool.Get().(*jsonEncoder) //nolint:errcheck - h.outputHeadersEncoder = enc + e.buf.Reset() + defer jsonEncoderPool.Put(e) + + err := e.enc.Encode(response) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - respHeaderMapping := ht.RespHeaderMapping - if len(respHeaderMapping) == 0 && refl.HasTaggedFields(output, string(rest.ParamInHeader)) { - respHeaderMapping = make(map[string]string) + // Skip statuses that do not allow response body (1xx, 204, 304). + if !(statusCode < http.StatusOK || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified) { + w.Header().Set("Content-Length", strconv.Itoa(e.buf.Len())) - refl.WalkTaggedFields(reflect.ValueOf(output), func(_ reflect.Value, sf reflect.StructField, _ string) { - // Converting name to canonical form, while keeping omitempty and any other options. - t := sf.Tag.Get(string(rest.ParamInHeader)) - parts := strings.Split(t, ",") - parts[0] = http.CanonicalHeaderKey(parts[0]) - t = strings.Join(parts, ",") + contentType := DefaultErrorResponseContentType + w.Header().Set("Content-Type", contentType) + } - respHeaderMapping[sf.Name] = t - }, string(rest.ParamInHeader)) + w.WriteHeader(statusCode) + + if r.Method == http.MethodHead { + return } - if len(respHeaderMapping) > 0 { - enc := form.NewEncoder() - enc.SetMode(form.ModeExplicit) - enc.RegisterTagNameFunc(func(field reflect.StructField) string { - if name, ok := respHeaderMapping[field.Name]; ok { - return name + _, err = w.Write(e.buf.Bytes()) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + + return + } +} + +// WriteSuccessfulResponse encodes and writes successful output of use case interactor to http response. +func (h *Encoder) WriteSuccessfulResponse( + w http.ResponseWriter, + r *http.Request, + output interface{}, + ht rest.HandlerTrait, +) { + if h.unwrapInterface { + output = reflect.ValueOf(output).Elem().Interface() + } + + if h.dynamicETagged { + if etagged, ok := output.(rest.ETagged); ok { + etag := etagged.ETag() + if etag != "" { + w.Header().Set("Etag", etag) } + } + } - if field.Anonymous { - return "" + if !h.writeHeader(w, r, output, ht) { + return + } + + if !h.writeCookies(w, r, output, ht) { + return + } + + if h.outputContentTypeBodyEncoder != nil && h.writeRawResponse(w, r, output, ht) { + return + } + + skipRendering := h.skipRendering + if !skipRendering && h.dynamicNoContent { + if nc, ok := output.(noContent); ok { + skipRendering = nc.NoContent() + if skipRendering && ht.SuccessStatus == 0 { + ht.SuccessStatus = http.StatusNoContent } + } + } - return "-" - }) + if ht.SuccessStatus == 0 { + ht.SuccessStatus = h.successStatus(output) + } - h.outputHeadersEncoder = enc + if skipRendering { + if !h.outputWithWriter && !h.dynamicSetter && ht.SuccessStatus != http.StatusOK { + w.WriteHeader(ht.SuccessStatus) + } + + return } + + h.writeJSONResponse(w, r, output, ht) } func (h *Encoder) setupContentTypeBodyEncoder(output interface{}) { @@ -229,88 +346,146 @@ func (h *Encoder) setupCookiesEncoder(output interface{}, ht *rest.HandlerTrait) } } -// SetupOutput configures encoder with and instance of use case output. -func (h *Encoder) SetupOutput(output interface{}, ht *rest.HandlerTrait) { - h.outputBufferType = reflect.TypeOf(output) - h.outputHeadersEncoder = nil - h.skipRendering = true +func (h *Encoder) setupHeadersEncoder(output interface{}, ht *rest.HandlerTrait) { + // Enable dynamic headers check in interface mode. + if h.unwrapInterface { + enc := form.NewEncoder() + enc.SetMode(form.ModeExplicit) + enc.SetTagName(string(rest.ParamInHeader)) + + h.outputHeadersEncoder = enc - if output == nil { return } - output = addressable(output) + respHeaderMapping := ht.RespHeaderMapping + if len(respHeaderMapping) == 0 && refl.HasTaggedFields(output, string(rest.ParamInHeader)) { + respHeaderMapping = make(map[string]string) - h.unwrapInterface = reflect.ValueOf(output).Elem().Kind() == reflect.Interface + refl.WalkTaggedFields(reflect.ValueOf(output), func(_ reflect.Value, sf reflect.StructField, _ string) { + // Converting name to canonical form, while keeping omitempty and any other options. + t := sf.Tag.Get(string(rest.ParamInHeader)) + parts := strings.Split(t, ",") + parts[0] = http.CanonicalHeaderKey(parts[0]) + t = strings.Join(parts, ",") - if _, ok := output.(outputWithHeadersSetup); ok || h.unwrapInterface { - h.dynamicWithHeadersSetup = true + respHeaderMapping[sf.Name] = t + }, string(rest.ParamInHeader)) } - if _, ok := output.(Setter); ok || h.unwrapInterface { - h.dynamicSetter = true + if len(respHeaderMapping) > 0 { + enc := form.NewEncoder() + enc.SetMode(form.ModeExplicit) + enc.RegisterTagNameFunc(func(field reflect.StructField) string { + if name, ok := respHeaderMapping[field.Name]; ok { + return name + } + + if field.Anonymous { + return "" + } + + return "-" + }) + + h.outputHeadersEncoder = enc } +} - if _, ok := output.(rest.ETagged); ok || h.unwrapInterface { - h.dynamicETagged = true +func (h *Encoder) successStatus(output interface{}) int { + if outputWithStatus, ok := output.(rest.OutputWithHTTPStatus); ok { + return outputWithStatus.HTTPStatus() } - if _, ok := output.(noContent); ok || h.unwrapInterface { - h.dynamicNoContent = true + if h.skipRendering && !h.outputWithWriter { + return http.StatusNoContent } - h.setupHeadersEncoder(output, ht) - h.setupCookiesEncoder(output, ht) - h.setupContentTypeBodyEncoder(output) + return http.StatusOK +} - if h.outputBufferType.Kind() == reflect.Ptr { - h.outputBufferType = h.outputBufferType.Elem() +func (h *Encoder) writeCookies(w http.ResponseWriter, r *http.Request, output interface{}, ht rest.HandlerTrait) bool { + if h.outputCookiesEncoder == nil { + return true } - if !rest.OutputHasNoContent(output) { - h.skipRendering = false - } + cookies, err := h.outputCookiesEncoder.Encode(output, nil) + if err != nil { + h.writeError(err, w, r, ht) - if _, ok := output.(usecase.OutputWithWriter); ok { - h.skipRendering = true - h.outputWithWriter = true + return false } - if ht.SuccessStatus != 0 { - return + if h.outputCookieBase != nil { + for _, c := range h.outputCookieBase { + if val, ok := cookies[c.Name]; ok && len(val) == 1 && val[0] != "" { + c := c + c.Value = val[0] + + http.SetCookie(w, &c) + } + } + } else { + for cookie, val := range cookies { + c := http.Cookie{} + c.Name = cookie + c.Value = val[0] + + http.SetCookie(w, &c) + } } - ht.SuccessStatus = h.successStatus(output) + return true } -func (h *Encoder) successStatus(output interface{}) int { - if outputWithStatus, ok := output.(rest.OutputWithHTTPStatus); ok { - return outputWithStatus.HTTPStatus() +func (h *Encoder) writeError(err error, w http.ResponseWriter, r *http.Request, ht rest.HandlerTrait) { + if ht.MakeErrResp != nil { + code, er := ht.MakeErrResp(r.Context(), err) + h.WriteErrResponse(w, r, code, er) + } else { + code, er := rest.Err(err) + h.WriteErrResponse(w, r, code, er) } +} - if h.skipRendering && !h.outputWithWriter { - return http.StatusNoContent +func (h *Encoder) writeHeader(w http.ResponseWriter, r *http.Request, output interface{}, ht rest.HandlerTrait) bool { + if h.dynamicWithHeadersSetup { + if sh, ok := output.(outputWithHeadersSetup); ok { + sh.SetupResponseHeader(w.Header()) + } } - return http.StatusOK -} + if h.outputHeadersEncoder == nil { + return true + } -type jsonEncoder struct { - enc *json.Encoder - buf *bytes.Buffer -} + var goValues map[string]interface{} + if ht.RespValidator != nil { + goValues = make(map[string]interface{}) + } -var jsonEncoderPool = sync.Pool{ - New: func() interface{} { - buf := bytes.NewBuffer(nil) - enc := json.NewEncoder(buf) - enc.SetEscapeHTML(false) + headers, err := h.outputHeadersEncoder.Encode(output, goValues) + if err != nil { + h.writeError(err, w, r, ht) - return &jsonEncoder{ - enc: enc, - buf: buf, + return false + } + + if ht.RespValidator != nil { + if err := ht.RespValidator.ValidateData(rest.ParamInHeader, goValues); err != nil { + h.writeError(status.Wrap(fmt.Errorf("bad response: %w", err), status.Internal), w, r, ht) + + return false } - }, + } + + for header, val := range headers { + if len(val) == 1 { + w.Header().Set(header, val[0]) + } + } + + return true } func (h *Encoder) writeJSONResponse( @@ -373,149 +548,6 @@ func (h *Encoder) writeJSONResponse( } } -// WriteErrResponse encodes and writes error to response. -func (h *Encoder) WriteErrResponse(w http.ResponseWriter, r *http.Request, statusCode int, response interface{}) { - e := jsonEncoderPool.Get().(*jsonEncoder) //nolint:errcheck - - e.buf.Reset() - defer jsonEncoderPool.Put(e) - - err := e.enc.Encode(response) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - - return - } - - // Skip statuses that do not allow response body (1xx, 204, 304). - if !(statusCode < http.StatusOK || statusCode == http.StatusNoContent || statusCode == http.StatusNotModified) { - w.Header().Set("Content-Length", strconv.Itoa(e.buf.Len())) - - contentType := DefaultErrorResponseContentType - w.Header().Set("Content-Type", contentType) - } - - w.WriteHeader(statusCode) - - if r.Method == http.MethodHead { - return - } - - _, err = w.Write(e.buf.Bytes()) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - - return - } -} - -// WriteSuccessfulResponse encodes and writes successful output of use case interactor to http response. -func (h *Encoder) WriteSuccessfulResponse( - w http.ResponseWriter, - r *http.Request, - output interface{}, - ht rest.HandlerTrait, -) { - if h.unwrapInterface { - output = reflect.ValueOf(output).Elem().Interface() - } - - if h.dynamicETagged { - if etagged, ok := output.(rest.ETagged); ok { - etag := etagged.ETag() - if etag != "" { - w.Header().Set("Etag", etag) - } - } - } - - if !h.writeHeader(w, r, output, ht) { - return - } - - if !h.writeCookies(w, r, output, ht) { - return - } - - if h.outputContentTypeBodyEncoder != nil && h.writeRawResponse(w, r, output, ht) { - return - } - - skipRendering := h.skipRendering - if !skipRendering && h.dynamicNoContent { - if nc, ok := output.(noContent); ok { - skipRendering = nc.NoContent() - if skipRendering && ht.SuccessStatus == 0 { - ht.SuccessStatus = http.StatusNoContent - } - } - } - - if ht.SuccessStatus == 0 { - ht.SuccessStatus = h.successStatus(output) - } - - if skipRendering { - if !h.outputWithWriter && !h.dynamicSetter && ht.SuccessStatus != http.StatusOK { - w.WriteHeader(ht.SuccessStatus) - } - - return - } - - h.writeJSONResponse(w, r, output, ht) -} - -func (h *Encoder) writeError(err error, w http.ResponseWriter, r *http.Request, ht rest.HandlerTrait) { - if ht.MakeErrResp != nil { - code, er := ht.MakeErrResp(r.Context(), err) - h.WriteErrResponse(w, r, code, er) - } else { - code, er := rest.Err(err) - h.WriteErrResponse(w, r, code, er) - } -} - -func (h *Encoder) writeHeader(w http.ResponseWriter, r *http.Request, output interface{}, ht rest.HandlerTrait) bool { - if h.dynamicWithHeadersSetup { - if sh, ok := output.(outputWithHeadersSetup); ok { - sh.SetupResponseHeader(w.Header()) - } - } - - if h.outputHeadersEncoder == nil { - return true - } - - var goValues map[string]interface{} - if ht.RespValidator != nil { - goValues = make(map[string]interface{}) - } - - headers, err := h.outputHeadersEncoder.Encode(output, goValues) - if err != nil { - h.writeError(err, w, r, ht) - - return false - } - - if ht.RespValidator != nil { - if err := ht.RespValidator.ValidateData(rest.ParamInHeader, goValues); err != nil { - h.writeError(status.Wrap(fmt.Errorf("bad response: %w", err), status.Internal), w, r, ht) - - return false - } - } - - for header, val := range headers { - if len(val) == 1 { - w.Header().Set(header, val[0]) - } - } - - return true -} - func (h *Encoder) writeRawResponse(w http.ResponseWriter, r *http.Request, output interface{}, ht rest.HandlerTrait) bool { values, err := h.outputContentTypeBodyEncoder.Encode(output) if err != nil { @@ -539,70 +571,53 @@ func (h *Encoder) writeRawResponse(w http.ResponseWriter, r *http.Request, outpu return false } -func (h *Encoder) writeCookies(w http.ResponseWriter, r *http.Request, output interface{}, ht rest.HandlerTrait) bool { - if h.outputCookiesEncoder == nil { - return true - } - - cookies, err := h.outputCookiesEncoder.Encode(output, nil) - if err != nil { - h.writeError(err, w, r, ht) - - return false +type ( + // Setter captures original http.ResponseWriter. + // + // Implement this interface on a pointer to your output structure to get access to http.ResponseWriter. + Setter interface { + SetResponseWriter(rw http.ResponseWriter) } +) - if h.outputCookieBase != nil { - for _, c := range h.outputCookieBase { - if val, ok := cookies[c.Name]; ok && len(val) == 1 && val[0] != "" { - c := c - c.Value = val[0] - - http.SetCookie(w, &c) - } - } - } else { - for cookie, val := range cookies { - c := http.Cookie{} - c.Name = cookie - c.Value = val[0] +var jsonEncoderPool = sync.Pool{ + New: func() interface{} { + buf := bytes.NewBuffer(nil) + enc := json.NewEncoder(buf) + enc.SetEscapeHTML(false) - http.SetCookie(w, &c) + return &jsonEncoder{ + enc: enc, + buf: buf, } - } - - return true + }, } -// MakeOutput instantiates a value for use case output port. -func (h *Encoder) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interface{} { - if h.outputBufferType == nil { - return nil +// addressable makes a pointer from a non-pointer values. +func addressable(output interface{}) interface{} { + if reflect.ValueOf(output).Kind() != reflect.Ptr { + o := reflect.New(reflect.TypeOf(output)) + o.Elem().Set(reflect.ValueOf(output)) + + output = o.Interface() } - output := reflect.New(h.outputBufferType).Interface() + return output +} - if h.outputWithWriter { - if withWriter, ok := output.(usecase.OutputWithWriter); ok { - if h.outputHeadersEncoder != nil || ht.SuccessContentType != "" { - withWriter.SetWriter(&writerWithHeaders{ - ResponseWriter: w, - responseWriter: h, - trait: ht, - output: output, - }) - } else { - withWriter.SetWriter(w) - } - } - } +type jsonEncoder struct { + enc *json.Encoder + buf *bytes.Buffer +} - if h.dynamicSetter { - if setter, ok := output.(Setter); ok { - setter.SetResponseWriter(w) - } - } +type noContent interface { + // NoContent controls whether status 204 should be used in response to current request. + NoContent() bool +} - return output +type outputWithHeadersSetup interface { + // SetupResponseHeader gives access to response headers of current request. + SetupResponseHeader(h http.Header) } type writerWithHeaders struct { @@ -614,25 +629,6 @@ type writerWithHeaders struct { headersSet bool } -func (w *writerWithHeaders) setHeaders() error { - if w.responseWriter.outputHeadersEncoder == nil { - return nil - } - - headers, err := w.responseWriter.outputHeadersEncoder.Encode(w.output) - if err != nil { - return err - } - - for header, val := range headers { - if len(val) == 1 { - w.Header().Set(header, val[0]) - } - } - - return err -} - func (w *writerWithHeaders) Write(data []byte) (int, error) { if !w.headersSet { if err := w.setHeaders(); err != nil { @@ -649,17 +645,21 @@ func (w *writerWithHeaders) Write(data []byte) (int, error) { return w.ResponseWriter.Write(data) } -// EmbeddedSetter can capture http.ResponseWriter in your output structure. -type EmbeddedSetter struct { - rw http.ResponseWriter -} +func (w *writerWithHeaders) setHeaders() error { + if w.responseWriter.outputHeadersEncoder == nil { + return nil + } -// SetResponseWriter implements Setter. -func (e *EmbeddedSetter) SetResponseWriter(rw http.ResponseWriter) { - e.rw = rw -} + headers, err := w.responseWriter.outputHeadersEncoder.Encode(w.output) + if err != nil { + return err + } -// ResponseWriter is an accessor. -func (e *EmbeddedSetter) ResponseWriter() http.ResponseWriter { - return e.rw + for header, val := range headers { + if len(val) == 1 { + w.Header().Set(header, val[0]) + } + } + + return err } diff --git a/response/encoder_test.go b/response/encoder_test.go index 6b1f27f..a6e6027 100644 --- a/response/encoder_test.go +++ b/response/encoder_test.go @@ -13,7 +13,7 @@ import ( "github.com/swaggest/usecase" ) -func TestEncoder_SetupOutput(t *testing.T) { +func TestEmbeddedSetter_SetResponseWriter(t *testing.T) { e := response.Encoder{} type EmbeddedHeader struct { @@ -22,6 +22,7 @@ func TestEncoder_SetupOutput(t *testing.T) { } type outputPort struct { + response.EmbeddedSetter EmbeddedHeader Name string `header:"X-Name" json:"-"` Items []string `json:"items"` @@ -73,99 +74,50 @@ func TestEncoder_SetupOutput(t *testing.T) { assert.Equal(t, "application/x-vnd-json", w.Header().Get("Content-Type")) assert.Equal(t, "32", w.Header().Get("Content-Length")) assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String()) - - w = httptest.NewRecorder() - e.WriteErrResponse(w, r, http.StatusExpectationFailed, rest.ErrResponse{ - ErrorText: "failed", - }) - assert.Equal(t, http.StatusExpectationFailed, w.Code) - assert.Equal(t, "application/json", w.Header().Get("Content-Type")) - assert.Equal(t, "19", w.Header().Get("Content-Length")) - assert.Equal(t, `{"error":"failed"}`+"\n", w.Body.String()) - - out.Name = "Ja" - w = httptest.NewRecorder() - e.WriteSuccessfulResponse(w, r, output, ht) - assert.Equal(t, http.StatusInternalServerError, w.Code) - assert.Equal(t, "", w.Header().Get("X-Name")) - assert.Equal(t, "application/json", w.Header().Get("Content-Type")) - assert.Equal(t, "140", w.Header().Get("Content-Length")) - assert.Equal(t, `{"status":"INTERNAL","error":"internal: bad response: validation failed",`+ - `"context":{"header:X-Name":["#: length must be >= 3, but got 2"]}}`+"\n", w.Body.String()) + assert.Equal(t, w, out.ResponseWriter()) } -func TestEncoder_SetupOutput_withWriter(t *testing.T) { - e := response.Encoder{} - - ht := rest.HandlerTrait{ - SuccessContentType: "application/x-vnd-foo", +func TestEncoder_contentTypeRaw(t *testing.T) { + type Resp struct { + TextBody string `contentType:"text/plain"` + CSVBody string `contentType:"text/csv"` } - type outputPort struct { - Name string `header:"X-Name" json:"-"` - usecase.OutputWithEmbeddedWriter - } + ht := rest.HandlerTrait{} - e.SetupOutput(new(outputPort), &ht) + e := response.Encoder{} + e.SetupOutput(Resp{}, &ht) w := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - output := e.MakeOutput(w, ht) + v := e.MakeOutput(w, ht) - out, ok := output.(*outputPort) + re, ok := v.(*Resp) assert.True(t, ok) - out.Name = "Jane" + re.CSVBody = "hello,world" - _, err = out.Write([]byte("1,2,3")) - require.NoError(t, err) + e.WriteSuccessfulResponse(w, nil, v, ht) - e.WriteSuccessfulResponse(w, r, output, ht) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "application/x-vnd-foo", w.Header().Get("Content-Type")) - assert.Equal(t, "1,2,3", w.Body.String()) - assert.Equal(t, "Jane", w.Header().Get("X-Name")) + assert.Equal(t, "hello,world", w.Body.String()) + assert.Equal(t, "text/csv", w.Header().Get("Content-Type")) } -func TestEncoder_SetupOutput_withWriterContentType(t *testing.T) { +func TestEncoder_SetupOutput(t *testing.T) { e := response.Encoder{} - ht := rest.HandlerTrait{ - SuccessContentType: "application/x-vnd-foo", - } - - type outputPort struct { - usecase.OutputWithEmbeddedWriter + type EmbeddedHeader struct { + Foo int `header:"X-Foo" json:"-"` + Bar string `cookie:"bar" json:"-"` } - e.SetupOutput(new(outputPort), &ht) - - w := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - - output := e.MakeOutput(w, ht) - - out, ok := output.(*outputPort) - assert.True(t, ok) - - _, err = out.Write([]byte("1,2,3")) - require.NoError(t, err) - - e.WriteSuccessfulResponse(w, r, output, ht) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "application/x-vnd-foo", w.Header().Get("Content-Type")) - assert.Equal(t, "1,2,3", w.Body.String()) -} - -func TestEncoder_SetupOutput_nonPtr(t *testing.T) { - e := response.Encoder{} - type outputPort struct { - Name string `header:"X-Name" json:"-"` - Items []string `json:"items"` + EmbeddedHeader + Name string `header:"X-Name" json:"-"` + Items []string `json:"items"` + Cookie int `cookie:"coo,httponly,path=/foo" json:"-"` + Cookie2 bool `cookie:"coo2,httponly,secure,samesite=lax,path=/foo,max-age=86400" json:"-"` } ht := rest.HandlerTrait{ @@ -193,28 +145,44 @@ func TestEncoder_SetupOutput_nonPtr(t *testing.T) { out, ok := output.(*outputPort) assert.True(t, ok) + out.Foo = 321 + out.Bar = "baz" out.Name = "Jane" out.Items = []string{"one", "two", "three"} + out.Cookie = 123 + out.Cookie2 = true e.WriteSuccessfulResponse(w, r, output, ht) assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "Jane", w.Header().Get("X-Name")) + assert.Equal(t, "321", w.Header().Get("X-Foo")) + assert.Equal(t, []string{ + "bar=baz", + "coo=123; Path=/foo; HttpOnly", + "coo2=true; Path=/foo; Max-Age=86400; HttpOnly; Secure; SameSite=Lax", + }, w.Header()["Set-Cookie"]) assert.Equal(t, "application/x-vnd-json", w.Header().Get("Content-Type")) assert.Equal(t, "32", w.Header().Get("Content-Length")) assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String()) -} -// Output that implements OutputWithHTTPStatus interface. -type outputWithHTTPStatuses struct { - Number int `json:"number"` -} - -func (outputWithHTTPStatuses) HTTPStatus() int { - return http.StatusCreated -} + w = httptest.NewRecorder() + e.WriteErrResponse(w, r, http.StatusExpectationFailed, rest.ErrResponse{ + ErrorText: "failed", + }) + assert.Equal(t, http.StatusExpectationFailed, w.Code) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + assert.Equal(t, "19", w.Header().Get("Content-Length")) + assert.Equal(t, `{"error":"failed"}`+"\n", w.Body.String()) -func (outputWithHTTPStatuses) ExpectedHTTPStatuses() []int { - return []int{http.StatusCreated, http.StatusOK} + out.Name = "Ja" + w = httptest.NewRecorder() + e.WriteSuccessfulResponse(w, r, output, ht) + assert.Equal(t, http.StatusInternalServerError, w.Code) + assert.Equal(t, "", w.Header().Get("X-Name")) + assert.Equal(t, "application/json", w.Header().Get("Content-Type")) + assert.Equal(t, "140", w.Header().Get("Content-Length")) + assert.Equal(t, `{"status":"INTERNAL","error":"internal: bad response: validation failed",`+ + `"context":{"header:X-Name":["#: length must be >= 3, but got 2"]}}`+"\n", w.Body.String()) } func TestEncoder_SetupOutput_httpStatus(t *testing.T) { @@ -224,34 +192,12 @@ func TestEncoder_SetupOutput_httpStatus(t *testing.T) { assert.Equal(t, http.StatusCreated, ht.SuccessStatus) } -func TestEncoder_Writer_httpStatus(t *testing.T) { - e := response.Encoder{} - e.SetupOutput(outputWithHTTPStatuses{}, &rest.HandlerTrait{}) - - r, err := http.NewRequest(http.MethodPost, "/", nil) - require.NoError(t, err) - - w := httptest.NewRecorder() - output := e.MakeOutput(w, rest.HandlerTrait{}) - e.WriteSuccessfulResponse(w, r, output, rest.HandlerTrait{}) - assert.Equal(t, http.StatusCreated, w.Code) -} - -func TestEmbeddedSetter_SetResponseWriter(t *testing.T) { +func TestEncoder_SetupOutput_nonPtr(t *testing.T) { e := response.Encoder{} - type EmbeddedHeader struct { - Foo int `header:"X-Foo" json:"-"` - Bar string `cookie:"bar" json:"-"` - } - type outputPort struct { - response.EmbeddedSetter - EmbeddedHeader - Name string `header:"X-Name" json:"-"` - Items []string `json:"items"` - Cookie int `cookie:"coo,httponly,path=/foo" json:"-"` - Cookie2 bool `cookie:"coo2,httponly,secure,samesite=lax,path=/foo,max-age=86400" json:"-"` + Name string `header:"X-Name" json:"-"` + Items []string `json:"items"` } ht := rest.HandlerTrait{ @@ -279,51 +225,105 @@ func TestEmbeddedSetter_SetResponseWriter(t *testing.T) { out, ok := output.(*outputPort) assert.True(t, ok) - out.Foo = 321 - out.Bar = "baz" out.Name = "Jane" out.Items = []string{"one", "two", "three"} - out.Cookie = 123 - out.Cookie2 = true e.WriteSuccessfulResponse(w, r, output, ht) assert.Equal(t, http.StatusOK, w.Code) assert.Equal(t, "Jane", w.Header().Get("X-Name")) - assert.Equal(t, "321", w.Header().Get("X-Foo")) - assert.Equal(t, []string{ - "bar=baz", - "coo=123; Path=/foo; HttpOnly", - "coo2=true; Path=/foo; Max-Age=86400; HttpOnly; Secure; SameSite=Lax", - }, w.Header()["Set-Cookie"]) assert.Equal(t, "application/x-vnd-json", w.Header().Get("Content-Type")) assert.Equal(t, "32", w.Header().Get("Content-Length")) assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String()) - assert.Equal(t, w, out.ResponseWriter()) } -func TestEncoder_contentTypeRaw(t *testing.T) { - type Resp struct { - TextBody string `contentType:"text/plain"` - CSVBody string `contentType:"text/csv"` +func TestEncoder_SetupOutput_withWriter(t *testing.T) { + e := response.Encoder{} + + ht := rest.HandlerTrait{ + SuccessContentType: "application/x-vnd-foo", } - ht := rest.HandlerTrait{} + type outputPort struct { + Name string `header:"X-Name" json:"-"` + usecase.OutputWithEmbeddedWriter + } - e := response.Encoder{} - e.SetupOutput(Resp{}, &ht) + e.SetupOutput(new(outputPort), &ht) w := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) - v := e.MakeOutput(w, ht) + output := e.MakeOutput(w, ht) - re, ok := v.(*Resp) + out, ok := output.(*outputPort) assert.True(t, ok) - re.CSVBody = "hello,world" + out.Name = "Jane" - e.WriteSuccessfulResponse(w, nil, v, ht) + _, err = out.Write([]byte("1,2,3")) + require.NoError(t, err) + e.WriteSuccessfulResponse(w, r, output, ht) assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "hello,world", w.Body.String()) - assert.Equal(t, "text/csv", w.Header().Get("Content-Type")) + assert.Equal(t, "application/x-vnd-foo", w.Header().Get("Content-Type")) + assert.Equal(t, "1,2,3", w.Body.String()) + assert.Equal(t, "Jane", w.Header().Get("X-Name")) +} + +func TestEncoder_SetupOutput_withWriterContentType(t *testing.T) { + e := response.Encoder{} + + ht := rest.HandlerTrait{ + SuccessContentType: "application/x-vnd-foo", + } + + type outputPort struct { + usecase.OutputWithEmbeddedWriter + } + + e.SetupOutput(new(outputPort), &ht) + + w := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, "/", nil) + require.NoError(t, err) + + output := e.MakeOutput(w, ht) + + out, ok := output.(*outputPort) + assert.True(t, ok) + + _, err = out.Write([]byte("1,2,3")) + require.NoError(t, err) + + e.WriteSuccessfulResponse(w, r, output, ht) + assert.Equal(t, http.StatusOK, w.Code) + assert.Equal(t, "application/x-vnd-foo", w.Header().Get("Content-Type")) + assert.Equal(t, "1,2,3", w.Body.String()) +} + +func TestEncoder_Writer_httpStatus(t *testing.T) { + e := response.Encoder{} + e.SetupOutput(outputWithHTTPStatuses{}, &rest.HandlerTrait{}) + + r, err := http.NewRequest(http.MethodPost, "/", nil) + require.NoError(t, err) + + w := httptest.NewRecorder() + output := e.MakeOutput(w, rest.HandlerTrait{}) + e.WriteSuccessfulResponse(w, r, output, rest.HandlerTrait{}) + assert.Equal(t, http.StatusCreated, w.Code) +} + +// Output that implements OutputWithHTTPStatus interface. +type outputWithHTTPStatuses struct { + Number int `json:"number"` +} + +func (outputWithHTTPStatuses) ExpectedHTTPStatuses() []int { + return []int{http.StatusCreated, http.StatusOK} +} + +func (outputWithHTTPStatuses) HTTPStatus() int { + return http.StatusCreated } diff --git a/response/gzip/middleware.go b/response/gzip/middleware.go index ac942fb..f877e1e 100644 --- a/response/gzip/middleware.go +++ b/response/gzip/middleware.go @@ -16,15 +16,6 @@ import ( gz "github.com/swaggest/rest/gzip" ) -const ( - contentTypeHeader = "Content-Type" - contentLengthHeader = "Content-Length" - contentEncodingHeader = "Content-Encoding" - acceptEncodingHeader = "Accept-Encoding" - - defaultBufferSize = 8 * 1024 -) - // Middleware enables gzip compression of handler response for requests that accept gzip encoding. func Middleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -42,11 +33,47 @@ func Middleware(next http.Handler) http.Handler { }) } +const ( + contentTypeHeader = "Content-Type" + contentLengthHeader = "Content-Length" + contentEncodingHeader = "Content-Encoding" + acceptEncodingHeader = "Accept-Encoding" + + defaultBufferSize = 8 * 1024 +) + +var ( + _ gz.Writer = &gzipResponseWriter{} + _ gz.Writer = &gzipResponseWriterHijacker{} + + _ http.ResponseWriter = &gzipResponseWriter{} + _ http.ResponseWriter = &gzipResponseWriterHijacker{} + + _ http.Flusher = &gzipResponseWriter{} + _ http.Flusher = &gzipResponseWriterHijacker{} + + _ http.Hijacker = &gzipResponseWriterHijacker{} +) + var ( gzipWriterPool sync.Pool bufWriterPool sync.Pool ) +func getBufWriter(w io.Writer) *bufio.Writer { + v := bufWriterPool.Get() + if v == nil { + return bufio.NewWriterSize(w, defaultBufferSize) + } + + //nolint:errcheck // OK to panic here. + bw := v.(*bufio.Writer) + + bw.Reset(w) + + return bw +} + func getGzipWriter(w io.Writer) *gzip.Writer { v := gzipWriterPool.Get() if v == nil { @@ -66,18 +93,13 @@ func getGzipWriter(w io.Writer) *gzip.Writer { return zw } -func getBufWriter(w io.Writer) *bufio.Writer { - v := bufWriterPool.Get() - if v == nil { - return bufio.NewWriterSize(w, defaultBufferSize) +func isTrivialNetworkError(err error) bool { + s := err.Error() + if strings.Contains(s, "broken pipe") || strings.Contains(s, "reset by peer") { + return true } - //nolint:errcheck // OK to panic here. - bw := v.(*bufio.Writer) - - bw.Reset(w) - - return bw + return false } func maybeGzipResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter { @@ -110,6 +132,14 @@ func maybeGzipResponseWriter(w http.ResponseWriter, r *http.Request) http.Respon return zrw } +func putBufWriter(bw *bufio.Writer) { + bufWriterPool.Put(bw) +} + +func putGzipWriter(zw *gzip.Writer) { + gzipWriterPool.Put(zw) +} + type gzipResponseWriter struct { http.ResponseWriter gzipWriter *gzip.Writer @@ -120,27 +150,49 @@ type gzipResponseWriter struct { disableCompression bool } -type gzipResponseWriterHijacker struct { - gzipResponseWriter - hijacker http.Hijacker -} +// Close flushes and closes response. +func (rw *gzipResponseWriter) Close() error { + if !rw.headersWritten { + rw.disableCompression = true -func (rw *gzipResponseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return rw.hijacker.Hijack() + return nil + } + + if rw.bufWriter == nil || rw.gzipWriter == nil { + return nil + } + + rw.Flush() + + err := rw.gzipWriter.Close() + + putBufWriter(rw.bufWriter) + rw.bufWriter = nil + + putGzipWriter(rw.gzipWriter) + rw.gzipWriter = nil + + return err } -var ( - _ gz.Writer = &gzipResponseWriter{} - _ gz.Writer = &gzipResponseWriterHijacker{} +// Flush implements http.Flusher. +func (rw *gzipResponseWriter) Flush() { + if rw.bufWriter == nil || rw.gzipWriter == nil { + return + } - _ http.ResponseWriter = &gzipResponseWriter{} - _ http.ResponseWriter = &gzipResponseWriterHijacker{} + if err := rw.bufWriter.Flush(); err != nil && !isTrivialNetworkError(err) { + panic(fmt.Sprintf("BUG: cannot flush bufio.Writer: %s", err)) + } - _ http.Flusher = &gzipResponseWriter{} - _ http.Flusher = &gzipResponseWriterHijacker{} + if err := rw.gzipWriter.Flush(); err != nil && !isTrivialNetworkError(err) { + panic(fmt.Sprintf("BUG: cannot flush gzip.Writer: %s", err)) + } - _ http.Hijacker = &gzipResponseWriterHijacker{} -) + if fw, ok := rw.ResponseWriter.(http.Flusher); ok { + fw.Flush() + } +} func (rw *gzipResponseWriter) GzipWrite(data []byte) (int, error) { if rw.headersWritten { @@ -152,6 +204,22 @@ func (rw *gzipResponseWriter) GzipWrite(data []byte) (int, error) { return rw.Write(data) } +func (rw *gzipResponseWriter) Write(p []byte) (int, error) { + if !rw.headersWritten { + rw.writeHeader(http.StatusOK) + } + + if rw.disableCompression || rw.expectCompressedBytes { + return rw.ResponseWriter.Write(p) + } + + return rw.bufWriter.Write(p) +} + +func (rw *gzipResponseWriter) WriteHeader(statusCode int) { + rw.writeHeader(statusCode) +} + func (rw *gzipResponseWriter) writeHeader(statusCode int) { //nolint:funlen if rw.headersWritten { return @@ -248,79 +316,11 @@ func (rw *gzipResponseWriter) writeHeader(statusCode int) { //nolint:funlen rw.headersWritten = true } -func (rw *gzipResponseWriter) Write(p []byte) (int, error) { - if !rw.headersWritten { - rw.writeHeader(http.StatusOK) - } - - if rw.disableCompression || rw.expectCompressedBytes { - return rw.ResponseWriter.Write(p) - } - - return rw.bufWriter.Write(p) -} - -func (rw *gzipResponseWriter) WriteHeader(statusCode int) { - rw.writeHeader(statusCode) -} - -func isTrivialNetworkError(err error) bool { - s := err.Error() - if strings.Contains(s, "broken pipe") || strings.Contains(s, "reset by peer") { - return true - } - - return false -} - -// Flush implements http.Flusher. -func (rw *gzipResponseWriter) Flush() { - if rw.bufWriter == nil || rw.gzipWriter == nil { - return - } - - if err := rw.bufWriter.Flush(); err != nil && !isTrivialNetworkError(err) { - panic(fmt.Sprintf("BUG: cannot flush bufio.Writer: %s", err)) - } - - if err := rw.gzipWriter.Flush(); err != nil && !isTrivialNetworkError(err) { - panic(fmt.Sprintf("BUG: cannot flush gzip.Writer: %s", err)) - } - - if fw, ok := rw.ResponseWriter.(http.Flusher); ok { - fw.Flush() - } -} - -// Close flushes and closes response. -func (rw *gzipResponseWriter) Close() error { - if !rw.headersWritten { - rw.disableCompression = true - - return nil - } - - if rw.bufWriter == nil || rw.gzipWriter == nil { - return nil - } - - rw.Flush() - - err := rw.gzipWriter.Close() - - putBufWriter(rw.bufWriter) - rw.bufWriter = nil - - putGzipWriter(rw.gzipWriter) - rw.gzipWriter = nil - - return err -} - -func putGzipWriter(zw *gzip.Writer) { - gzipWriterPool.Put(zw) +type gzipResponseWriterHijacker struct { + gzipResponseWriter + hijacker http.Hijacker } -func putBufWriter(bw *bufio.Writer) { - bufWriterPool.Put(bw) +func (rw *gzipResponseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { + return rw.hijacker.Hijack() } diff --git a/response/gzip/middleware_test.go b/response/gzip/middleware_test.go index 69c28ab..aac89eb 100644 --- a/response/gzip/middleware_test.go +++ b/response/gzip/middleware_test.go @@ -16,51 +16,6 @@ import ( "github.com/swaggest/rest/response/gzip" ) -func TestMiddleware(t *testing.T) { - resp := []byte(strings.Repeat("A", 10000) + "!!!") - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { - _, err := rw.Write(resp) - assert.NoError(t, err) - })) - - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - - require.NoError(t, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") - - h.ServeHTTP(rw, r) - - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. - assert.Equal(t, resp, gzipDecode(t, rw.Body.Bytes())) - - rw = httptest.NewRecorder() - h.ServeHTTP(rw, r) - - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. - assert.Equal(t, resp, gzipDecode(t, rw.Body.Bytes())) - - rw = httptest.NewRecorder() - - r.Header.Set("Accept-Encoding", "deflate, br") - h.ServeHTTP(rw, r) - - assert.Equal(t, "", rw.Header().Get("Content-Encoding")) - assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. - assert.Equal(t, resp, rw.Body.Bytes()) - - rw = httptest.NewRecorder() - - r.Header.Del("Accept-Encoding") - h.ServeHTTP(rw, r) - - assert.Equal(t, "", rw.Header().Get("Content-Encoding")) - assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. - assert.Equal(t, resp, rw.Body.Bytes()) -} - // BenchmarkMiddleware measures performance of handler with compression. // // Sample result: @@ -111,6 +66,73 @@ func BenchmarkMiddleware_control(b *testing.B) { } } +func TestGzipResponseWriter_ExpectCompressedBytes(t *testing.T) { + resp := []byte(strings.Repeat("A", 10000) + "!!!") + respGz := gzipEncode(t, resp) + + h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, err := gzip2.WriteCompressedBytes(respGz, rw) + assert.NoError(t, err) + })) + + rw := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, "/", nil) + + require.NoError(t, err) + r.Header.Set("Accept-Encoding", "gzip, deflate, br") + + h.ServeHTTP(rw, r) + + assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) + assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. + assert.Equal(t, respGz, rw.Body.Bytes()) +} + +func TestMiddleware(t *testing.T) { + resp := []byte(strings.Repeat("A", 10000) + "!!!") + h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + _, err := rw.Write(resp) + assert.NoError(t, err) + })) + + rw := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, "/", nil) + + require.NoError(t, err) + r.Header.Set("Accept-Encoding", "gzip, deflate, br") + + h.ServeHTTP(rw, r) + + assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) + assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. + assert.Equal(t, resp, gzipDecode(t, rw.Body.Bytes())) + + rw = httptest.NewRecorder() + h.ServeHTTP(rw, r) + + assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) + assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. + assert.Equal(t, resp, gzipDecode(t, rw.Body.Bytes())) + + rw = httptest.NewRecorder() + + r.Header.Set("Accept-Encoding", "deflate, br") + h.ServeHTTP(rw, r) + + assert.Equal(t, "", rw.Header().Get("Content-Encoding")) + assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. + assert.Equal(t, resp, rw.Body.Bytes()) + + rw = httptest.NewRecorder() + + r.Header.Del("Accept-Encoding") + h.ServeHTTP(rw, r) + + assert.Equal(t, "", rw.Header().Get("Content-Encoding")) + assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. + assert.Equal(t, resp, rw.Body.Bytes()) +} + func TestMiddleware_concurrency(t *testing.T) { resp := []byte(strings.Repeat("A", 10000) + "!!!") respGz := gzipEncode(t, resp) @@ -157,47 +179,35 @@ func TestMiddleware_concurrency(t *testing.T) { wg.Wait() } -func TestGzipResponseWriter_ExpectCompressedBytes(t *testing.T) { - resp := []byte(strings.Repeat("A", 10000) + "!!!") - respGz := gzipEncode(t, resp) - +func TestMiddleware_hijacker(t *testing.T) { + rb := []byte(strings.Repeat("A", 10000) + "!!!") h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { - _, err := gzip2.WriteCompressedBytes(respGz, rw) + _, ok := rw.(http.Hijacker) + require.True(t, ok) + + _, err := rw.Write(rb) assert.NoError(t, err) })) - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) + srv := httptest.NewServer(h) + defer srv.Close() + + r, err := http.NewRequest(http.MethodGet, srv.URL+"/", nil) require.NoError(t, err) r.Header.Set("Accept-Encoding", "gzip, deflate, br") - h.ServeHTTP(rw, r) - - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. - assert.Equal(t, respGz, rw.Body.Bytes()) -} - -func TestMiddleware_skipContentEncoding(t *testing.T) { - resp := []byte(strings.Repeat("A", 10000) + "!!!") - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { - rw.Header().Set("Content-Encoding", "br") - _, err := rw.Write(resp) - assert.NoError(t, err) - })) - - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) + resp, err := http.DefaultTransport.RoundTrip(r) + require.NoError(t, err) + body, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") - h.ServeHTTP(rw, r) + require.NoError(t, resp.Body.Close()) - assert.Equal(t, "br", rw.Header().Get("Content-Encoding")) - assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. - assert.Equal(t, resp, rw.Body.Bytes()) + assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) + assert.Less(t, len(body), len(rb)) // Response is compressed. + assert.Equal(t, rb, gzipDecode(t, body)) } func TestMiddleware_noContent(t *testing.T) { @@ -220,18 +230,25 @@ func TestMiddleware_noContent(t *testing.T) { assert.Equal(t, rw.Body.Len(), 0) } -func gzipEncode(t *testing.T, data []byte) []byte { - t.Helper() +func TestMiddleware_skipContentEncoding(t *testing.T) { + resp := []byte(strings.Repeat("A", 10000) + "!!!") + h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { + rw.Header().Set("Content-Encoding", "br") + _, err := rw.Write(resp) + assert.NoError(t, err) + })) - b := bytes.Buffer{} - w := gz.NewWriter(&b) + rw := httptest.NewRecorder() + r, err := http.NewRequest(http.MethodGet, "/", nil) - _, err := w.Write(data) require.NoError(t, err) + r.Header.Set("Accept-Encoding", "gzip, deflate, br") - require.NoError(t, w.Close()) + h.ServeHTTP(rw, r) - return b.Bytes() + assert.Equal(t, "br", rw.Header().Get("Content-Encoding")) + assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. + assert.Equal(t, resp, rw.Body.Bytes()) } func gzipDecode(t *testing.T, data []byte) []byte { @@ -250,33 +267,16 @@ func gzipDecode(t *testing.T, data []byte) []byte { return j } -func TestMiddleware_hijacker(t *testing.T) { - rb := []byte(strings.Repeat("A", 10000) + "!!!") - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, _ *http.Request) { - _, ok := rw.(http.Hijacker) - require.True(t, ok) - - _, err := rw.Write(rb) - assert.NoError(t, err) - })) - - srv := httptest.NewServer(h) - defer srv.Close() - - r, err := http.NewRequest(http.MethodGet, srv.URL+"/", nil) - - require.NoError(t, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") +func gzipEncode(t *testing.T, data []byte) []byte { + t.Helper() - resp, err := http.DefaultTransport.RoundTrip(r) - require.NoError(t, err) + b := bytes.Buffer{} + w := gz.NewWriter(&b) - body, err := ioutil.ReadAll(resp.Body) + _, err := w.Write(data) require.NoError(t, err) - require.NoError(t, resp.Body.Close()) + require.NoError(t, w.Close()) - assert.Equal(t, "gzip", resp.Header.Get("Content-Encoding")) - assert.Less(t, len(body), len(rb)) // Response is compressed. - assert.Equal(t, rb, gzipDecode(t, body)) + return b.Bytes() } diff --git a/response/middleware.go b/response/middleware.go index 3fb0918..901c17c 100644 --- a/response/middleware.go +++ b/response/middleware.go @@ -8,10 +8,6 @@ import ( "github.com/swaggest/usecase" ) -type responseEncoderSetter interface { - SetResponseEncoder(responseWriter nethttp.ResponseEncoder) -} - // EncoderMiddleware instruments qualifying http.Handler with Encoder. func EncoderMiddleware(handler http.Handler) http.Handler { if nethttp.IsWrapperChecker(handler) { @@ -41,3 +37,7 @@ func EncoderMiddleware(handler http.Handler) http.Handler { return handler } + +type responseEncoderSetter interface { + SetResponseEncoder(responseWriter nethttp.ResponseEncoder) +} diff --git a/response/validator.go b/response/validator.go index f109447..1f5f429 100644 --- a/response/validator.go +++ b/response/validator.go @@ -8,10 +8,6 @@ import ( "github.com/swaggest/usecase" ) -type withRestHandler interface { - RestHandler() *rest.HandlerTrait -} - // ValidatorMiddleware sets up response validator in suitable handlers. func ValidatorMiddleware(factory rest.ResponseValidatorFactory) func(http.Handler) http.Handler { return func(handler http.Handler) http.Handler { @@ -49,3 +45,7 @@ func ValidatorMiddleware(factory rest.ResponseValidatorFactory) func(http.Handle return handler } } + +type withRestHandler interface { + RestHandler() *rest.HandlerTrait +} diff --git a/resttest/client.go b/resttest/client.go index 9d52f47..4e20359 100644 --- a/resttest/client.go +++ b/resttest/client.go @@ -4,14 +4,14 @@ import ( "github.com/bool64/httpmock" ) -// Client keeps state of expectations. -// -// Deprecated: please use httpmock.Client. -type Client = httpmock.Client - // NewClient creates client instance, baseURL may be empty if Client.SetBaseURL is used later. // // Deprecated: please use httpmock.NewClient. func NewClient(baseURL string) *Client { return httpmock.NewClient(baseURL) } + +// Client keeps state of expectations. +// +// Deprecated: please use httpmock.Client. +type Client = httpmock.Client diff --git a/resttest/server.go b/resttest/server.go index b013792..a47fb0a 100644 --- a/resttest/server.go +++ b/resttest/server.go @@ -4,6 +4,13 @@ import ( "github.com/bool64/httpmock" ) +// NewServerMock creates mocked server. +// +// Deprecated: please use httpmock.NewServer. +func NewServerMock() (*ServerMock, string) { + return httpmock.NewServer() +} + // Expectation describes expected request and defines response. // // Deprecated: please use httpmock.Expectation. @@ -11,10 +18,3 @@ type Expectation = httpmock.Expectation // ServerMock serves predefined response for predefined request. type ServerMock = httpmock.Server - -// NewServerMock creates mocked server. -// -// Deprecated: please use httpmock.NewServer. -func NewServerMock() (*ServerMock, string) { - return httpmock.NewServer() -} diff --git a/resttest/server_test.go b/resttest/server_test.go index 3d4d6aa..ebe8d4b 100644 --- a/resttest/server_test.go +++ b/resttest/server_test.go @@ -15,50 +15,98 @@ import ( "github.com/stretchr/testify/require" ) -func assertRoundTrip(t *testing.T, baseURL string, expectation httpmock.Expectation) { - t.Helper() +func TestServerMock_ExpectAsync(t *testing.T) { + sm, url := httpmock.NewServer() + sm.Expect(httpmock.Expectation{ + Method: http.MethodGet, + RequestURI: "/", + ResponseBody: []byte(`{"bar":"foo"}`), + }) + sm.ExpectAsync(httpmock.Expectation{ + Method: http.MethodGet, + RequestURI: "/async1", + ResponseBody: []byte(`{"bar":"async1"}`), + }) + sm.ExpectAsync(httpmock.Expectation{ + Method: http.MethodGet, + RequestURI: "/async2", + ResponseBody: []byte(`{"bar":"async2"}`), + Unlimited: true, + }) - var bodyReader io.Reader + wg := sync.WaitGroup{} + wg.Add(2) - if expectation.RequestBody != nil { - bodyReader = bytes.NewReader(expectation.RequestBody) - } + go func() { + defer wg.Done() - req, err := http.NewRequest(expectation.Method, baseURL+expectation.RequestURI, bodyReader) - require.NoError(t, err) + req, err := http.NewRequest(http.MethodGet, url+"/async1", nil) + require.NoError(t, err) - for k, v := range expectation.RequestHeader { - req.Header.Set(k, v) - } + resp, err := http.DefaultTransport.RoundTrip(req) + require.NoError(t, err) - for n, v := range expectation.RequestCookie { - req.AddCookie(&http.Cookie{Name: n, Value: v}) - } + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + require.NoError(t, resp.Body.Close()) + assert.Equal(t, `{"bar":"async1"}`, string(body)) + }() + + go func() { + defer wg.Done() + + for i := 0; i < 50; i++ { + req, err := http.NewRequest(http.MethodGet, url+"/async2", nil) + require.NoError(t, err) + + resp, err := http.DefaultTransport.RoundTrip(req) + require.NoError(t, err) + + body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + require.NoError(t, resp.Body.Close()) + assert.Equal(t, `{"bar":"async2"}`, string(body)) + } + }() + + req, err := http.NewRequest(http.MethodGet, url+"/", nil) + require.NoError(t, err) resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, resp.Body.Close()) require.NoError(t, err) - if expectation.Status == 0 { - expectation.Status = http.StatusOK - } + require.NoError(t, resp.Body.Close()) + assert.Equal(t, `{"bar":"foo"}`, string(body)) - assert.Equal(t, expectation.Status, resp.StatusCode) - assert.Equal(t, string(expectation.ResponseBody), string(body)) + wg.Wait() + assert.NoError(t, sm.ExpectationsWereMet()) +} - // Asserting default for successful responses. - if resp.StatusCode != http.StatusInternalServerError { - assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) - } +func TestServerMock_ResetExpectations(t *testing.T) { + // Creating REST service mock. + mock, _ := httpmock.NewServer() + defer mock.Close() - if len(expectation.ResponseHeader) > 0 { - for k, v := range expectation.ResponseHeader { - assert.Equal(t, v, resp.Header.Get(k)) - } - } + mock.Expect(httpmock.Expectation{ + Method: http.MethodGet, + RequestURI: "/test?test=test", + ResponseBody: []byte("body"), + }) + + mock.ExpectAsync(httpmock.Expectation{ + Method: http.MethodGet, + RequestURI: "/test-async?test=test", + ResponseBody: []byte("body"), + }) + + assert.Error(t, mock.ExpectationsWereMet()) + mock.ResetExpectations() + assert.NoError(t, mock.ExpectationsWereMet()) } func TestServerMock_ServeHTTP(t *testing.T) { @@ -136,6 +184,50 @@ func TestServerMock_ServeHTTP(t *testing.T) { }) } +func TestServerMock_ServeHTTP_concurrency(t *testing.T) { + // Creating REST service mock. + mock, url := httpmock.NewServer() + defer mock.Close() + + n := 50 + + for i := 0; i < n; i++ { + // Setting expectations for first request. + mock.Expect(httpmock.Expectation{ + Method: http.MethodGet, + RequestURI: "/test?test=test", + ResponseBody: []byte("body"), + }) + } + + wg := sync.WaitGroup{} + wg.Add(n) + + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + + // Sending request with wrong header. + req, err := http.NewRequest(http.MethodGet, url+"/test?test=test", nil) + require.NoError(t, err) + req.Header.Set("X-Foo", "space") + + resp, err := http.DefaultTransport.RoundTrip(req) + require.NoError(t, err) + + respBody, err := ioutil.ReadAll(resp.Body) + require.NoError(t, resp.Body.Close()) + require.NoError(t, err) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, `body`, string(respBody)) + }() + } + + wg.Wait() + assert.NoError(t, mock.ExpectationsWereMet()) +} + func TestServerMock_ServeHTTP_error(t *testing.T) { // Creating REST service mock. mock, baseURL := httpmock.NewServer() @@ -219,72 +311,6 @@ func TestServerMock_ServeHTTP_error(t *testing.T) { `, string(respBody)) } -func TestServerMock_ServeHTTP_concurrency(t *testing.T) { - // Creating REST service mock. - mock, url := httpmock.NewServer() - defer mock.Close() - - n := 50 - - for i := 0; i < n; i++ { - // Setting expectations for first request. - mock.Expect(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/test?test=test", - ResponseBody: []byte("body"), - }) - } - - wg := sync.WaitGroup{} - wg.Add(n) - - for i := 0; i < n; i++ { - go func() { - defer wg.Done() - - // Sending request with wrong header. - req, err := http.NewRequest(http.MethodGet, url+"/test?test=test", nil) - require.NoError(t, err) - req.Header.Set("X-Foo", "space") - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - respBody, err := ioutil.ReadAll(resp.Body) - require.NoError(t, resp.Body.Close()) - require.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, `body`, string(respBody)) - }() - } - - wg.Wait() - assert.NoError(t, mock.ExpectationsWereMet()) -} - -func TestServerMock_ResetExpectations(t *testing.T) { - // Creating REST service mock. - mock, _ := httpmock.NewServer() - defer mock.Close() - - mock.Expect(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/test?test=test", - ResponseBody: []byte("body"), - }) - - mock.ExpectAsync(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/test-async?test=test", - ResponseBody: []byte("body"), - }) - - assert.Error(t, mock.ExpectationsWereMet()) - mock.ResetExpectations() - assert.NoError(t, mock.ExpectationsWereMet()) -} - func TestServerMock_vars(t *testing.T) { sm, url := httpmock.NewServer() sm.JSONComparer.Vars = &shared.Vars{} @@ -309,74 +335,48 @@ func TestServerMock_vars(t *testing.T) { assert.Equal(t, `{"bar":"foo","dynEcho":"abc"}`, string(body)) } -func TestServerMock_ExpectAsync(t *testing.T) { - sm, url := httpmock.NewServer() - sm.Expect(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/", - ResponseBody: []byte(`{"bar":"foo"}`), - }) - sm.ExpectAsync(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/async1", - ResponseBody: []byte(`{"bar":"async1"}`), - }) - sm.ExpectAsync(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/async2", - ResponseBody: []byte(`{"bar":"async2"}`), - Unlimited: true, - }) - - wg := sync.WaitGroup{} - wg.Add(2) - - go func() { - defer wg.Done() - - req, err := http.NewRequest(http.MethodGet, url+"/async1", nil) - require.NoError(t, err) - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - require.NoError(t, resp.Body.Close()) - assert.Equal(t, `{"bar":"async1"}`, string(body)) - }() - - go func() { - defer wg.Done() +func assertRoundTrip(t *testing.T, baseURL string, expectation httpmock.Expectation) { + t.Helper() - for i := 0; i < 50; i++ { - req, err := http.NewRequest(http.MethodGet, url+"/async2", nil) - require.NoError(t, err) + var bodyReader io.Reader - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) + if expectation.RequestBody != nil { + bodyReader = bytes.NewReader(expectation.RequestBody) + } - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) + req, err := http.NewRequest(expectation.Method, baseURL+expectation.RequestURI, bodyReader) + require.NoError(t, err) - require.NoError(t, resp.Body.Close()) - assert.Equal(t, `{"bar":"async2"}`, string(body)) - } - }() + for k, v := range expectation.RequestHeader { + req.Header.Set(k, v) + } - req, err := http.NewRequest(http.MethodGet, url+"/", nil) - require.NoError(t, err) + for n, v := range expectation.RequestCookie { + req.AddCookie(&http.Cookie{Name: n, Value: v}) + } resp, err := http.DefaultTransport.RoundTrip(req) require.NoError(t, err) body, err := ioutil.ReadAll(resp.Body) + require.NoError(t, resp.Body.Close()) require.NoError(t, err) - require.NoError(t, resp.Body.Close()) - assert.Equal(t, `{"bar":"foo"}`, string(body)) + if expectation.Status == 0 { + expectation.Status = http.StatusOK + } - wg.Wait() - assert.NoError(t, sm.ExpectationsWereMet()) + assert.Equal(t, expectation.Status, resp.StatusCode) + assert.Equal(t, string(expectation.ResponseBody), string(body)) + + // Asserting default for successful responses. + if resp.StatusCode != http.StatusInternalServerError { + assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) + } + + if len(expectation.ResponseHeader) > 0 { + for k, v := range expectation.ResponseHeader { + assert.Equal(t, v, resp.Header.Get(k)) + } + } } diff --git a/route.go b/route.go index 08d19fc..6d9e821 100644 --- a/route.go +++ b/route.go @@ -4,11 +4,6 @@ import ( "github.com/swaggest/usecase" ) -// HandlerWithUseCase exposes usecase. -type HandlerWithUseCase interface { - UseCase() usecase.Interactor -} - // HandlerWithRoute is a http.Handler with routing information. type HandlerWithRoute interface { // RouteMethod returns http method of action. @@ -17,3 +12,8 @@ type HandlerWithRoute interface { // RoutePattern returns http path pattern of action. RoutePattern() string } + +// HandlerWithUseCase exposes usecase. +type HandlerWithUseCase interface { + UseCase() usecase.Interactor +} diff --git a/trait.go b/trait.go index 27f95ab..7d8e47c 100644 --- a/trait.go +++ b/trait.go @@ -12,6 +12,47 @@ import ( "github.com/swaggest/usecase" ) +// OutputHasNoContent indicates if output does not seem to have any content body to render in response. +func OutputHasNoContent(output interface{}) bool { + if output == nil { + return true + } + + _, withWriter := output.(usecase.OutputWithWriter) + _, noContent := output.(usecase.OutputWithNoContent) + + rv := reflect.ValueOf(output) + + kind := rv.Kind() + elemKind := reflect.Invalid + + if kind == reflect.Ptr { + elemKind = rv.Elem().Kind() + } + + hasJSONTaggedFields := refl.HasTaggedFields(output, "json") + hasContentTypeTaggedFields := refl.HasTaggedFields(output, "contentType") + isSliceOrMap := refl.IsSliceOrMap(output) + hasEmbeddedSliceOrMap := refl.FindEmbeddedSliceOrMap(output) != nil + isJSONMarshaler := refl.As(output, new(json.Marshaler)) + isPtrToInterface := elemKind == reflect.Interface + isScalar := refl.IsScalar(output) + + if withWriter || + noContent || + hasJSONTaggedFields || + hasContentTypeTaggedFields || + isSliceOrMap || + hasEmbeddedSliceOrMap || + isJSONMarshaler || + isPtrToInterface || + isScalar { + return false + } + + return true +} + // HandlerTrait controls basic behavior of rest handler. type HandlerTrait struct { // SuccessStatus is an HTTP status code to set on successful use case interaction. @@ -48,53 +89,12 @@ type HandlerTrait struct { OpenAPIAnnotations []func(oc openapi.OperationContext) error } -// RestHandler is an accessor. -func (h *HandlerTrait) RestHandler() *HandlerTrait { - return h -} - // RequestMapping returns custom mapping for request decoder. func (h *HandlerTrait) RequestMapping() RequestMapping { return h.ReqMapping } -// OutputHasNoContent indicates if output does not seem to have any content body to render in response. -func OutputHasNoContent(output interface{}) bool { - if output == nil { - return true - } - - _, withWriter := output.(usecase.OutputWithWriter) - _, noContent := output.(usecase.OutputWithNoContent) - - rv := reflect.ValueOf(output) - - kind := rv.Kind() - elemKind := reflect.Invalid - - if kind == reflect.Ptr { - elemKind = rv.Elem().Kind() - } - - hasJSONTaggedFields := refl.HasTaggedFields(output, "json") - hasContentTypeTaggedFields := refl.HasTaggedFields(output, "contentType") - isSliceOrMap := refl.IsSliceOrMap(output) - hasEmbeddedSliceOrMap := refl.FindEmbeddedSliceOrMap(output) != nil - isJSONMarshaler := refl.As(output, new(json.Marshaler)) - isPtrToInterface := elemKind == reflect.Interface - isScalar := refl.IsScalar(output) - - if withWriter || - noContent || - hasJSONTaggedFields || - hasContentTypeTaggedFields || - isSliceOrMap || - hasEmbeddedSliceOrMap || - isJSONMarshaler || - isPtrToInterface || - isScalar { - return false - } - - return true +// RestHandler is an accessor. +func (h *HandlerTrait) RestHandler() *HandlerTrait { + return h } diff --git a/validator.go b/validator.go index 713bfff..f4980ea 100644 --- a/validator.go +++ b/validator.go @@ -2,34 +2,12 @@ package rest import "encoding/json" -// Validator validates a map of decoded data. -type Validator interface { - // ValidateData validates decoded request/response data and returns error in case of invalid data. - ValidateData(in ParamIn, namedData map[string]interface{}) error - - // ValidateJSONBody validates JSON encoded body and returns error in case of invalid data. - ValidateJSONBody(jsonBody []byte) error - - // HasConstraints indicates if there are validation rules for parameter location. - HasConstraints(in ParamIn) bool -} - -// ValidatorFunc implements Validator with a func. -type ValidatorFunc func(in ParamIn, namedData map[string]interface{}) error - -// ValidateData implements Validator. -func (v ValidatorFunc) ValidateData(in ParamIn, namedData map[string]interface{}) error { - return v(in, namedData) -} - -// HasConstraints indicates if there are validation rules for parameter location. -func (v ValidatorFunc) HasConstraints(_ ParamIn) bool { - return true -} +// JSONSchemaValidator defines JSON schema validator. +type JSONSchemaValidator interface { + Validator -// ValidateJSONBody implements Validator. -func (v ValidatorFunc) ValidateJSONBody(body []byte) error { - return v(ParamInBody, map[string]interface{}{"body": json.RawMessage(body)}) + // AddSchema accepts JSON schema for a request parameter or response value. + AddSchema(in ParamIn, name string, schemaData []byte, required bool) error } // RequestJSONSchemaProvider provides request JSON Schemas. @@ -42,6 +20,11 @@ type RequestJSONSchemaProvider interface { ) error } +// RequestValidatorFactory creates request validator for particular structured Go input value. +type RequestValidatorFactory interface { + MakeRequestValidator(method string, input interface{}, mapping RequestMapping) Validator +} + // ResponseJSONSchemaProvider provides response JSON Schemas. type ResponseJSONSchemaProvider interface { ProvideResponseJSONSchemas( @@ -53,19 +36,6 @@ type ResponseJSONSchemaProvider interface { ) error } -// JSONSchemaValidator defines JSON schema validator. -type JSONSchemaValidator interface { - Validator - - // AddSchema accepts JSON schema for a request parameter or response value. - AddSchema(in ParamIn, name string, schemaData []byte, required bool) error -} - -// RequestValidatorFactory creates request validator for particular structured Go input value. -type RequestValidatorFactory interface { - MakeRequestValidator(method string, input interface{}, mapping RequestMapping) Validator -} - // ResponseValidatorFactory creates response validator for particular structured Go output value. type ResponseValidatorFactory interface { MakeResponseValidator( @@ -96,3 +66,33 @@ func (re ValidationErrors) Fields() map[string]interface{} { return res } + +// Validator validates a map of decoded data. +type Validator interface { + // ValidateData validates decoded request/response data and returns error in case of invalid data. + ValidateData(in ParamIn, namedData map[string]interface{}) error + + // ValidateJSONBody validates JSON encoded body and returns error in case of invalid data. + ValidateJSONBody(jsonBody []byte) error + + // HasConstraints indicates if there are validation rules for parameter location. + HasConstraints(in ParamIn) bool +} + +// ValidatorFunc implements Validator with a func. +type ValidatorFunc func(in ParamIn, namedData map[string]interface{}) error + +// HasConstraints indicates if there are validation rules for parameter location. +func (v ValidatorFunc) HasConstraints(_ ParamIn) bool { + return true +} + +// ValidateData implements Validator. +func (v ValidatorFunc) ValidateData(in ParamIn, namedData map[string]interface{}) error { + return v(in, namedData) +} + +// ValidateJSONBody implements Validator. +func (v ValidatorFunc) ValidateJSONBody(body []byte) error { + return v(ParamInBody, map[string]interface{}{"body": json.RawMessage(body)}) +} diff --git a/web/example_test.go b/web/example_test.go index 4f5f2d9..cc64510 100644 --- a/web/example_test.go +++ b/web/example_test.go @@ -12,26 +12,6 @@ import ( "github.com/swaggest/usecase" ) -// album represents data about a record album. -type album struct { - ID int `json:"id"` - Title string `json:"title"` - Artist string `json:"artist"` - Price float64 `json:"price"` - Locale string `query:"locale"` -} - -func postAlbums() usecase.Interactor { - u := usecase.NewIOI(new(album), new(album), func(ctx context.Context, input, output interface{}) error { - log.Println("Creating album") - - return nil - }) - u.SetTags("Album") - - return u -} - func ExampleDefaultService() { // Service initializes router with required middlewares. service := web.NewService(openapi3.NewReflector()) @@ -59,3 +39,23 @@ func ExampleDefaultService() { log.Fatal(err) } } + +func postAlbums() usecase.Interactor { + u := usecase.NewIOI(new(album), new(album), func(ctx context.Context, input, output interface{}) error { + log.Println("Creating album") + + return nil + }) + u.SetTags("Album") + + return u +} + +// album represents data about a record album. +type album struct { + ID int `json:"id"` + Title string `json:"title"` + Artist string `json:"artist"` + Price float64 `json:"price"` + Locale string `query:"locale"` +} diff --git a/web/service.go b/web/service.go index f8038fe..6678239 100644 --- a/web/service.go +++ b/web/service.go @@ -19,6 +19,29 @@ import ( "github.com/swaggest/usecase" ) +// DefaultService initializes router and other basic components of web service. +// +// Provided functional options are invoked twice, before and after initialization. +// +// Deprecated: use NewService. +func DefaultService(options ...func(s *Service, initialized bool)) *Service { + s := NewService(openapi3.NewReflector(), func(s *Service) { + for _, o := range options { + o(s, false) + } + }) + + if r3, ok := s.OpenAPIReflector().(*openapi3.Reflector); ok && s.OpenAPI == nil { + s.OpenAPI = r3.Spec + } + + for _, o := range options { + o(s, true) + } + + return s +} + // NewService initializes router and other basic components of web service. func NewService(refl oapi.Reflector, options ...func(s *Service)) *Service { s := Service{} @@ -69,29 +92,6 @@ func NewService(refl oapi.Reflector, options ...func(s *Service)) *Service { return &s } -// DefaultService initializes router and other basic components of web service. -// -// Provided functional options are invoked twice, before and after initialization. -// -// Deprecated: use NewService. -func DefaultService(options ...func(s *Service, initialized bool)) *Service { - s := NewService(openapi3.NewReflector(), func(s *Service) { - for _, o := range options { - o(s, false) - } - }) - - if r3, ok := s.OpenAPIReflector().(*openapi3.Reflector); ok && s.OpenAPI == nil { - s.OpenAPI = r3.Spec - } - - for _, o := range options { - o(s, true) - } - - return s -} - // Service keeps instrumented router and documentation collector. type Service struct { *chirouter.Wrapper @@ -113,23 +113,32 @@ type Service struct { AddHeadToGet bool } -// OpenAPISchema returns OpenAPI schema. -// -// Returned value can be type asserted to *openapi3.Spec, *openapi31.Spec or marshaled. -func (s *Service) OpenAPISchema() oapi.SpecSchema { - return s.OpenAPICollector.SpecSchema() -} - -// OpenAPIReflector returns OpenAPI structure reflector for customizations. -func (s *Service) OpenAPIReflector() oapi.Reflector { - return s.OpenAPICollector.Refl() -} - // Delete adds the route `pattern` that matches a DELETE http method to invoke use case interactor. func (s *Service) Delete(pattern string, uc usecase.Interactor, options ...func(h *nethttp.Handler)) { s.Method(http.MethodDelete, pattern, nethttp.NewHandler(uc, options...)) } +// Docs adds the route `pattern` that serves API documentation with Swagger UI. +// +// Swagger UI should be provided by `swgui` handler constructor, you can use one of these functions +// +// github.com/swaggest/swgui/v5emb.New +// github.com/swaggest/swgui/v5cdn.New +// github.com/swaggest/swgui/v5.New +// github.com/swaggest/swgui/v4emb.New +// github.com/swaggest/swgui/v4cdn.New +// github.com/swaggest/swgui/v4.New +// github.com/swaggest/swgui/v3emb.New +// github.com/swaggest/swgui/v3cdn.New +// github.com/swaggest/swgui/v3.New +// +// or create your own. +func (s *Service) Docs(pattern string, swgui func(title, schemaURL, basePath string) http.Handler) { + pattern = strings.TrimRight(pattern, "/") + s.Method(http.MethodGet, pattern+"/openapi.json", s.OpenAPICollector) + s.Mount(pattern, swgui(s.OpenAPISchema().Title(), pattern+"/openapi.json", pattern)) +} + // Get adds the route `pattern` that matches a GET http method to invoke use case interactor. // // If Service.AddHeadToGet is enabled, it also adds a HEAD method. @@ -158,6 +167,28 @@ func (s *Service) HeadGet(pattern string, uc usecase.Interactor, options ...func s.Method(http.MethodHead, pattern, h) } +// OnMethodNotAllowed registers usecase interactor as a handler for method not allowed conditions. +func (s *Service) OnMethodNotAllowed(uc usecase.Interactor, options ...func(h *nethttp.Handler)) { + s.MethodNotAllowed(s.HandlerFunc(nethttp.NewHandler(uc, options...))) +} + +// OnNotFound registers usecase interactor as a handler for not found conditions. +func (s *Service) OnNotFound(uc usecase.Interactor, options ...func(h *nethttp.Handler)) { + s.NotFound(s.HandlerFunc(nethttp.NewHandler(uc, options...))) +} + +// OpenAPIReflector returns OpenAPI structure reflector for customizations. +func (s *Service) OpenAPIReflector() oapi.Reflector { + return s.OpenAPICollector.Refl() +} + +// OpenAPISchema returns OpenAPI schema. +// +// Returned value can be type asserted to *openapi3.Spec, *openapi31.Spec or marshaled. +func (s *Service) OpenAPISchema() oapi.SpecSchema { + return s.OpenAPICollector.SpecSchema() +} + // Options adds the route `pattern` that matches a OPTIONS http method to invoke use case interactor. func (s *Service) Options(pattern string, uc usecase.Interactor, options ...func(h *nethttp.Handler)) { s.Method(http.MethodOptions, pattern, nethttp.NewHandler(uc, options...)) @@ -182,34 +213,3 @@ func (s *Service) Put(pattern string, uc usecase.Interactor, options ...func(h * func (s *Service) Trace(pattern string, uc usecase.Interactor, options ...func(h *nethttp.Handler)) { s.Method(http.MethodTrace, pattern, nethttp.NewHandler(uc, options...)) } - -// OnNotFound registers usecase interactor as a handler for not found conditions. -func (s *Service) OnNotFound(uc usecase.Interactor, options ...func(h *nethttp.Handler)) { - s.NotFound(s.HandlerFunc(nethttp.NewHandler(uc, options...))) -} - -// OnMethodNotAllowed registers usecase interactor as a handler for method not allowed conditions. -func (s *Service) OnMethodNotAllowed(uc usecase.Interactor, options ...func(h *nethttp.Handler)) { - s.MethodNotAllowed(s.HandlerFunc(nethttp.NewHandler(uc, options...))) -} - -// Docs adds the route `pattern` that serves API documentation with Swagger UI. -// -// Swagger UI should be provided by `swgui` handler constructor, you can use one of these functions -// -// github.com/swaggest/swgui/v5emb.New -// github.com/swaggest/swgui/v5cdn.New -// github.com/swaggest/swgui/v5.New -// github.com/swaggest/swgui/v4emb.New -// github.com/swaggest/swgui/v4cdn.New -// github.com/swaggest/swgui/v4.New -// github.com/swaggest/swgui/v3emb.New -// github.com/swaggest/swgui/v3cdn.New -// github.com/swaggest/swgui/v3.New -// -// or create your own. -func (s *Service) Docs(pattern string, swgui func(title, schemaURL, basePath string) http.Handler) { - pattern = strings.TrimRight(pattern, "/") - s.Method(http.MethodGet, pattern+"/openapi.json", s.OpenAPICollector) - s.Mount(pattern, swgui(s.OpenAPISchema().Title(), pattern+"/openapi.json", pattern)) -} diff --git a/web/service_test.go b/web/service_test.go index 22ebd2e..bf46838 100644 --- a/web/service_test.go +++ b/web/service_test.go @@ -16,20 +16,6 @@ import ( "github.com/swaggest/usecase" ) -type albumID struct { - ID int `path:"id"` - Locale string `query:"locale"` -} - -func albumByID() usecase.Interactor { - u := usecase.NewIOI(new(albumID), new(album), func(_ context.Context, _, _ interface{}) error { - return nil - }) - u.SetTags("Album") - - return u -} - func TestDefaultService(t *testing.T) { var l []string @@ -92,3 +78,17 @@ func TestDefaultService(t *testing.T) { assert.Equal(t, []string{"one", "two"}, l) } + +func albumByID() usecase.Interactor { + u := usecase.NewIOI(new(albumID), new(album), func(_ context.Context, _, _ interface{}) error { + return nil + }) + u.SetTags("Album") + + return u +} + +type albumID struct { + ID int `path:"id"` + Locale string `query:"locale"` +} From 45dfcd14c42d513cfbebffdd7efa30c91620479b Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 6 Apr 2026 18:47:08 +0200 Subject: [PATCH 02/15] Test gocan --- .github/ISSUE_TEMPLATE/feature_request.md | 2 +- .github/workflows/gocan.yml | 21 +++++++++ .github/workflows/golangci-lint.yml | 2 +- .../json_body_manual.go | 4 +- _examples/advanced/dynamic_schema.go | 4 +- chirouter/example_test.go | 4 +- gorillamux/example_openapi_collector_test.go | 4 +- gorillamux/example_test.go | 4 +- jsonschema/validator_test.go | 4 +- nethttp/handler_test.go | 10 ++--- {request => requestaaaaaa}/decoder.go | 3 +- {request => requestaaaaaa}/decoder_test.go | 44 +++++++++---------- {request => requestaaaaaa}/doc.go | 2 +- {request => requestaaaaaa}/example_test.go | 10 ++--- {request => requestaaaaaa}/factory.go | 2 +- {request => requestaaaaaa}/factory_test.go | 20 ++++----- {request => requestaaaaaa}/file.go | 5 ++- {request => requestaaaaaa}/file_test.go | 10 ++--- {request => requestaaaaaa}/jsonbody.go | 7 +-- {request => requestaaaaaa}/jsonbody_test.go | 2 +- {request => requestaaaaaa}/middleware.go | 2 +- {request => requestaaaaaa}/reflect.go | 2 +- {request => requestaaaaaa/reqerr}/error.go | 2 +- web/service.go | 16 +++---- 24 files changed, 105 insertions(+), 81 deletions(-) create mode 100644 .github/workflows/gocan.yml rename {request => requestaaaaaa}/decoder.go (99%) rename {request => requestaaaaaa}/decoder_test.go (92%) rename {request => requestaaaaaa}/doc.go (76%) rename {request => requestaaaaaa}/example_test.go (80%) rename {request => requestaaaaaa}/factory.go (99%) rename {request => requestaaaaaa}/factory_test.go (95%) rename {request => requestaaaaaa}/file.go (94%) rename {request => requestaaaaaa}/file_test.go (94%) rename {request => requestaaaaaa}/jsonbody.go (89%) rename {request => requestaaaaaa}/jsonbody_test.go (99%) rename {request => requestaaaaaa}/middleware.go (99%) rename {request => requestaaaaaa}/reflect.go (95%) rename {request => requestaaaaaa/reqerr}/error.go (94%) diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md index 5e3babc..36df94a 100644 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -1,5 +1,5 @@ --- -name: Feature request +name: Feature requestaaaaaa about: Suggest an idea for this project title: '' labels: '' diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml new file mode 100644 index 0000000..109a3e8 --- /dev/null +++ b/.github/workflows/gocan.yml @@ -0,0 +1,21 @@ +name: gocan +on: + pull_request: + +# Cancel the workflow in progress in newer build is about to start. +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + cloc: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Annotate layout-only changes + run: | + curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.1/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz + gocan_hash=$(git hash-object ./gocan) + [ "$gocan_hash" == "d3484d7cae5f9314243ee932dbabd3e6ba4aed63" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) + ./gocan -gh-annotate diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index f435fe8..776a573 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -35,7 +35,7 @@ jobs: # Optional: golangci-lint command line arguments. # args: --issues-exit-code=0 - # Optional: show only new issues if it's a pull request. The default value is `false`. + # Optional: show only new issues if it's a pull requestaaaaaa. The default value is `false`. # only-new-issues: true # Optional: if set to true then the action will use pre-installed Go. diff --git a/_examples/advanced-generic-openapi31/json_body_manual.go b/_examples/advanced-generic-openapi31/json_body_manual.go index 4c90ea8..fb09d1e 100644 --- a/_examples/advanced-generic-openapi31/json_body_manual.go +++ b/_examples/advanced-generic-openapi31/json_body_manual.go @@ -13,7 +13,7 @@ import ( "github.com/go-chi/chi/v5" "github.com/swaggest/jsonschema-go" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" "github.com/swaggest/usecase" ) @@ -22,7 +22,7 @@ type JSONPayload struct { Name string `json:"name"` } -var _ request.Loader = &inputWithJSON{} +var _ requestaaaaaa.Loader = &inputWithJSON{} func jsonBodyManual() usecase.Interactor { type outputWithJSON struct { diff --git a/_examples/advanced/dynamic_schema.go b/_examples/advanced/dynamic_schema.go index 18bc97c..9aa0b76 100644 --- a/_examples/advanced/dynamic_schema.go +++ b/_examples/advanced/dynamic_schema.go @@ -8,7 +8,7 @@ import ( "github.com/bool64/ctxd" "github.com/swaggest/jsonschema-go" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" "github.com/swaggest/usecase" "github.com/swaggest/usecase/status" ) @@ -62,7 +62,7 @@ func dynamicSchema() usecase.Interactor { type dynamicInput struct { jsonschema.Struct - request.EmbeddedSetter + requestaaaaaa.EmbeddedSetter // Type is a static field example. Type string `query:"type"` diff --git a/chirouter/example_test.go b/chirouter/example_test.go index 9ebb211..e05d984 100644 --- a/chirouter/example_test.go +++ b/chirouter/example_test.go @@ -9,13 +9,13 @@ import ( "github.com/go-chi/chi/v5" "github.com/swaggest/rest" "github.com/swaggest/rest/chirouter" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" ) func ExamplePathToURLValues() { // Instantiate decoder factory with gorillamux.PathToURLValues. // Single factory can be used to create multiple request decoders. - decoderFactory := request.NewDecoderFactory() + decoderFactory := requestaaaaaa.NewDecoderFactory() decoderFactory.ApplyDefaults = true decoderFactory.SetDecoderFunc(rest.ParamInPath, chirouter.PathToURLValues) diff --git a/gorillamux/example_openapi_collector_test.go b/gorillamux/example_openapi_collector_test.go index 1ac525e..86e9423 100644 --- a/gorillamux/example_openapi_collector_test.go +++ b/gorillamux/example_openapi_collector_test.go @@ -12,7 +12,7 @@ import ( "github.com/swaggest/rest" "github.com/swaggest/rest/gorillamux" "github.com/swaggest/rest/nethttp" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" ) func ExampleNewOpenAPICollector() { @@ -125,7 +125,7 @@ func ExampleNewOpenAPICollector() { } func newMyHandler() *myHandler { - decoderFactory := request.NewDecoderFactory() + decoderFactory := requestaaaaaa.NewDecoderFactory() decoderFactory.ApplyDefaults = true decoderFactory.SetDecoderFunc(rest.ParamInPath, gorillamux.PathToURLValues) diff --git a/gorillamux/example_test.go b/gorillamux/example_test.go index 068be32..e78a307 100644 --- a/gorillamux/example_test.go +++ b/gorillamux/example_test.go @@ -9,13 +9,13 @@ import ( "github.com/gorilla/mux" "github.com/swaggest/rest" "github.com/swaggest/rest/gorillamux" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" ) func ExamplePathToURLValues() { // Instantiate decoder factory with gorillamux.PathToURLValues. // Single factory can be used to create multiple request decoders. - decoderFactory := request.NewDecoderFactory() + decoderFactory := requestaaaaaa.NewDecoderFactory() decoderFactory.ApplyDefaults = true decoderFactory.SetDecoderFunc(rest.ParamInPath, gorillamux.PathToURLValues) diff --git a/jsonschema/validator_test.go b/jsonschema/validator_test.go index 6216302..a857eb9 100644 --- a/jsonschema/validator_test.go +++ b/jsonschema/validator_test.go @@ -11,7 +11,7 @@ import ( "github.com/swaggest/rest" "github.com/swaggest/rest/jsonschema" "github.com/swaggest/rest/openapi" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" ) // BenchmarkRequestValidator_ValidateRequestData-4 634356 1761 ns/op 2496 B/op 8 allocs/op. @@ -115,7 +115,7 @@ func TestValidator_ForbidUnknownParams(t *testing.T) { in := new(input) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, in, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodGet, in, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, in, nil) diff --git a/nethttp/handler_test.go b/nethttp/handler_test.go index 311a8ab..296c292 100644 --- a/nethttp/handler_test.go +++ b/nethttp/handler_test.go @@ -12,7 +12,7 @@ import ( "github.com/stretchr/testify/require" "github.com/swaggest/rest" "github.com/swaggest/rest/nethttp" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" "github.com/swaggest/rest/response" "github.com/swaggest/usecase" ) @@ -60,7 +60,7 @@ func TestHandler_ServeHTTP(t *testing.T) { ) h.SetResponseEncoder(&response.Encoder{}) - h.SetRequestDecoder(request.DecoderFunc( + h.SetRequestDecoder(requestaaaaaa.DecoderFunc( func(r *http.Request, input interface{}, validator rest.Validator) error { assert.Equal(t, req, r) @@ -155,7 +155,7 @@ func TestHandler_ServeHTTP_customMapping(t *testing.T) { } ws := []func(handler http.Handler) http.Handler{ - request.DecoderMiddleware(request.NewDecoderFactory()), + requestaaaaaa.DecoderMiddleware(requestaaaaaa.NewDecoderFactory()), nethttp.HandlerWithRouteMiddleware(http.MethodGet, "/test"), response.EncoderMiddleware, } @@ -196,7 +196,7 @@ func TestHandler_ServeHTTP_decodeErr(t *testing.T) { require.NoError(t, err) uh := nethttp.NewHandler(u) - uh.SetRequestDecoder(request.DecoderFunc( + uh.SetRequestDecoder(requestaaaaaa.DecoderFunc( func(_ *http.Request, _ interface{}, _ rest.Validator) error { return errors.New("failed to decode request") }, @@ -265,7 +265,7 @@ func TestHandler_ServeHTTP_getWithBody(t *testing.T) { }) h := nethttp.NewHandler(u) - h.SetRequestDecoder(request.NewDecoderFactory().MakeDecoder(http.MethodGet, new(reqWithBody), nil)) + h.SetRequestDecoder(requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodGet, new(reqWithBody), nil)) h.SetResponseEncoder(&response.Encoder{}) req, err := http.NewRequest(http.MethodGet, "/test", strings.NewReader(`{"id":123}`)) diff --git a/request/decoder.go b/requestaaaaaa/decoder.go similarity index 99% rename from request/decoder.go rename to requestaaaaaa/decoder.go index 224fee4..1a7f16a 100644 --- a/request/decoder.go +++ b/requestaaaaaa/decoder.go @@ -1,4 +1,4 @@ -package request +package requestaaaaaa import ( "io/ioutil" @@ -47,6 +47,7 @@ type ( const defaultMaxMemory = 32 << 20 // 32 MB +// 32 MB var _ nethttp.RequestDecoder = &decoder{} func makeDecoder(in rest.ParamIn, formDecoder *form.Decoder, decoderFunc decoderFunc) valueDecoderFunc { diff --git a/request/decoder_test.go b/requestaaaaaa/decoder_test.go similarity index 92% rename from request/decoder_test.go rename to requestaaaaaa/decoder_test.go index 4bb263e..220c9d3 100644 --- a/request/decoder_test.go +++ b/requestaaaaaa/decoder_test.go @@ -1,4 +1,4 @@ -package request_test +package requestaaaaaa_test import ( "bytes" @@ -16,12 +16,12 @@ import ( "github.com/swaggest/rest" "github.com/swaggest/rest/jsonschema" "github.com/swaggest/rest/openapi" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" ) // BenchmarkDecoder_Decode-4 1314788 857 ns/op 448 B/op 4 allocs/op. func BenchmarkDecoder_Decode(b *testing.B) { - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() type req struct { Q string `query:"q"` @@ -51,7 +51,7 @@ func BenchmarkDecoder_Decode(b *testing.B) { // BenchmarkDecoder_Decode_json-4 36660 29688 ns/op 12310 B/op 169 allocs/op. func BenchmarkDecoder_Decode_json(b *testing.B) { input := new(reqJSONTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodPost, input, nil) @@ -103,7 +103,7 @@ func BenchmarkDecoder_Decode_jsonParam(b *testing.B) { } `query:"filter"` } - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() dec := df.MakeDecoder(http.MethodGet, new(inp), nil) req, err := http.NewRequest(http.MethodGet, "/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D", nil) @@ -135,7 +135,7 @@ func BenchmarkDecoder_Decode_queryObject(b *testing.B) { "/?in_query[1]=1.0&in_query[2]=2.1&in_query[c]=0", nil) assert.NoError(b, err) - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() input := new(struct { InQuery map[int]float64 `query:"in_query"` @@ -174,7 +174,7 @@ func BenchmarkDecoderFunc_Decode(b *testing.B) { req.AddCookie(&c) - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() df.SetDecoderFunc(rest.ParamInPath, func(_ *http.Request) (url.Values, error) { return url.Values{"in_path": []string{"mno"}}, nil }) @@ -213,7 +213,7 @@ func TestDecoder_Decode(t *testing.T) { req.AddCookie(&c) - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() df.SetDecoderFunc(rest.ParamInPath, func(r *http.Request) (url.Values, error) { assert.Equal(t, req, r) @@ -260,7 +260,7 @@ func TestDecoder_Decode_dateTime(t *testing.T) { } input := new(reqTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, input, nil) @@ -273,7 +273,7 @@ func TestDecoder_Decode_error(t *testing.T) { Q int `default:"100" query:"q"` } - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() df.ApplyDefaults = true d := df.MakeDecoder(http.MethodGet, new(req), nil) @@ -294,7 +294,7 @@ func TestDecoder_Decode_json(t *testing.T) { assert.NoError(t, err) input := new(reqJSONTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodPost, input, nil) @@ -331,7 +331,7 @@ func TestDecoder_Decode_jsonParam(t *testing.T) { } `query:"filter"` } - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() dec := df.MakeDecoder(http.MethodGet, new(inp), nil) req, err := http.NewRequest(http.MethodGet, "/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D", nil) @@ -360,7 +360,7 @@ func TestDecoder_Decode_manualLoader_ptr(t *testing.T) { return nil } - dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, input, nil) @@ -386,7 +386,7 @@ func TestDecoder_Decode_manualLoader_val(t *testing.T) { return nil } - dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, input, nil) @@ -401,7 +401,7 @@ func TestDecoder_Decode_queryObject(t *testing.T) { "/?in_query[1]=1.0&in_query[2]=2.1&in_query[3]=0", nil) assert.NoError(t, err) - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() input := new(struct { InQuery map[int]float64 `query:"in_query"` @@ -427,7 +427,7 @@ func TestDecoder_Decode_required(t *testing.T) { assert.NoError(t, err) input := new(reqTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodPost, input, nil) @@ -441,7 +441,7 @@ func TestDecoder_Decode_required_header_case(t *testing.T) { assert.NoError(t, err) input := new(reqTest) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodPost, input, nil) @@ -456,7 +456,7 @@ func TestDecoder_Decode_setter_ptr(t *testing.T) { input := new(inputWithSetter) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, input, nil) @@ -473,7 +473,7 @@ func TestDecoder_Decode_setter_val(t *testing.T) { input := inputWithSetter{} - dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodGet, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, input, nil) @@ -498,7 +498,7 @@ func TestDecoder_Decode_unknownParams(t *testing.T) { in := new(input) - dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, in, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodGet, in, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, in, nil) @@ -513,7 +513,7 @@ func TestDecoderFactory_MakeDecoder_default_unexported(t *testing.T) { req *http.Request } - f := request.NewDecoderFactory() + f := requestaaaaaa.NewDecoderFactory() f.ApplyDefaults = true dec := f.MakeDecoder(http.MethodGet, showImageInput{}, nil) @@ -523,7 +523,7 @@ func TestDecoderFactory_MakeDecoder_default_unexported(t *testing.T) { func TestDecoderFactory_MakeDecoder_formOrJSON(t *testing.T) { var in formOrJSONInput - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, in, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodPost, in, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodPost, in, nil) diff --git a/request/doc.go b/requestaaaaaa/doc.go similarity index 76% rename from request/doc.go rename to requestaaaaaa/doc.go index 5008a28..7d6a150 100644 --- a/request/doc.go +++ b/requestaaaaaa/doc.go @@ -1,2 +1,2 @@ // Package request implements reflection-based net/http request decoder. -package request +package requestaaaaaa diff --git a/request/example_test.go b/requestaaaaaa/example_test.go similarity index 80% rename from request/example_test.go rename to requestaaaaaa/example_test.go index eb56955..ff92583 100644 --- a/request/example_test.go +++ b/requestaaaaaa/example_test.go @@ -1,10 +1,10 @@ -package request_test +package requestaaaaaa_test import ( "fmt" "net/http" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" ) func ExampleDecoder_Decode() { @@ -15,7 +15,7 @@ func ExampleDecoder_Decode() { } // A decoder for particular structure, can be reused for multiple HTTP requests. - myDecoder := request.NewDecoderFactory().MakeDecoder(http.MethodPost, new(MyRequest), nil) + myDecoder := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodPost, new(MyRequest), nil) // Request and response writer from ServeHTTP. var ( @@ -33,7 +33,7 @@ func ExampleDecoder_Decode() { func ExampleEmbeddedSetter_Request() { type MyRequest struct { - request.EmbeddedSetter + requestaaaaaa.EmbeddedSetter Foo int `header:"X-Foo"` Bar string `formData:"bar"` @@ -41,7 +41,7 @@ func ExampleEmbeddedSetter_Request() { } // A decoder for particular structure, can be reused for multiple HTTP requests. - myDecoder := request.NewDecoderFactory().MakeDecoder(http.MethodPost, new(MyRequest), nil) + myDecoder := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodPost, new(MyRequest), nil) // Request and response writer from ServeHTTP. var ( diff --git a/request/factory.go b/requestaaaaaa/factory.go similarity index 99% rename from request/factory.go rename to requestaaaaaa/factory.go index 845abcb..30fd0c2 100644 --- a/request/factory.go +++ b/requestaaaaaa/factory.go @@ -1,4 +1,4 @@ -package request +package requestaaaaaa import ( "bytes" diff --git a/request/factory_test.go b/requestaaaaaa/factory_test.go similarity index 95% rename from request/factory_test.go rename to requestaaaaaa/factory_test.go index 8023590..153f379 100644 --- a/request/factory_test.go +++ b/requestaaaaaa/factory_test.go @@ -1,4 +1,4 @@ -package request_test +package requestaaaaaa_test import ( "bytes" @@ -13,12 +13,12 @@ import ( "github.com/stretchr/testify/require" "github.com/swaggest/jsonschema-go" "github.com/swaggest/rest" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" ) // BenchmarkDecoderFactory_SetDecoderFunc-4 577378 1994 ns/op 1024 B/op 16 allocs/op. func BenchmarkDecoderFactory_SetDecoderFunc(b *testing.B) { - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() df.SetDecoderFunc("jwt", func(r *http.Request) (url.Values, error) { ah := r.Header.Get("Authorization") if ah == "" || len(ah) < 8 || strings.ToLower(ah[0:7]) != "bearer " { @@ -78,7 +78,7 @@ func TestDecoderFactory_MakeDecoder_customMapping(t *testing.T) { Name string `default:"foo"` } - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() df.ApplyDefaults = true customMapping := rest.RequestMapping{ @@ -134,7 +134,7 @@ func TestDecoderFactory_MakeDecoder_default(t *testing.T) { unexported bool `query:"unexported"` // This field is skipped because it is unexported. } - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() df.ApplyDefaults = true dec := df.MakeDecoder(http.MethodPost, new(MyInput), nil) @@ -174,7 +174,7 @@ func TestDecoderFactory_MakeDecoder_default(t *testing.T) { } func TestDecoderFactory_MakeDecoder_header_case_sensitivity(t *testing.T) { - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() type input struct { A string `header:"x-one-two-three" required:"true"` @@ -210,7 +210,7 @@ func TestDecoderFactory_MakeDecoder_invalidMapping(t *testing.T) { Name string `default:"foo"` } - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() customMapping := rest.RequestMapping{ rest.ParamInQuery: map[string]string{"ID2": "id"}, @@ -222,7 +222,7 @@ func TestDecoderFactory_MakeDecoder_invalidMapping(t *testing.T) { } func TestDecoderFactory_SetDecoderFunc(t *testing.T) { - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() df.SetDecoderFunc("jwt", func(r *http.Request) (url.Values, error) { ah := r.Header.Get("Authorization") if ah == "" || len(ah) < 8 || strings.ToLower(ah[0:7]) != "bearer " { @@ -284,7 +284,7 @@ func TestNewDecoderFactory_default(t *testing.T) { DefaultedTagVal defaultFromSchemaVal `query:"dtv" default:"none"` } - df := request.NewDecoderFactory() + df := requestaaaaaa.NewDecoderFactory() df.ApplyDefaults = true df.JSONSchemaReflector = &jsonschema.Reflector{} @@ -318,7 +318,7 @@ func TestNewDecoderFactory_requestBody(t *testing.T) { req.Header.Set("Content-Type", "text/plain") var input Req - dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) + dec := requestaaaaaa.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) require.NoError(t, dec.Decode(req, &input, nil)) diff --git a/request/file.go b/requestaaaaaa/file.go similarity index 94% rename from request/file.go rename to requestaaaaaa/file.go index 505f0b9..b2576bb 100644 --- a/request/file.go +++ b/requestaaaaaa/file.go @@ -1,4 +1,4 @@ -package request +package requestaaaaaa import ( "errors" @@ -8,6 +8,7 @@ import ( "reflect" "github.com/swaggest/rest" + "github.com/swaggest/rest/requestaaaaaa/reqerr" ) var ( @@ -67,7 +68,7 @@ func setFile(r *http.Request, field reflect.StructField, v reflect.Value) error if err != nil { if errors.Is(err, http.ErrMissingFile) { if field.Tag.Get("required") == "true" { - return fmt.Errorf("%w: %q", ErrMissingRequiredFile, name) + return fmt.Errorf("%w: %q", reqerr.ErrMissingRequiredFile, name) } return nil diff --git a/request/file_test.go b/requestaaaaaa/file_test.go similarity index 94% rename from request/file_test.go rename to requestaaaaaa/file_test.go index 01a51dd..5d5feaa 100644 --- a/request/file_test.go +++ b/requestaaaaaa/file_test.go @@ -1,4 +1,4 @@ -package request_test +package requestaaaaaa_test import ( "bytes" @@ -18,7 +18,7 @@ import ( "github.com/swaggest/rest/jsonschema" "github.com/swaggest/rest/nethttp" "github.com/swaggest/rest/openapi" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" "github.com/swaggest/rest/response" "github.com/swaggest/rest/web" "github.com/swaggest/usecase" @@ -52,15 +52,15 @@ func TestDecoder_Decode_fileUploadOptional(t *testing.T) { func TestDecoder_Decode_fileUploadTag(t *testing.T) { r := chirouter.NewWrapper(chi.NewRouter()) apiSchema := openapi.NewCollector(openapi3.NewReflector()) - decoderFactory := request.NewDecoderFactory() + decoderFactory := requestaaaaaa.NewDecoderFactory() validatorFactory := jsonschema.NewFactory(apiSchema, apiSchema) decoderFactory.SetDecoderFunc(rest.ParamInPath, chirouter.PathToURLValues) ws := []func(handler http.Handler) http.Handler{ nethttp.OpenAPIMiddleware(apiSchema), - request.DecoderMiddleware(decoderFactory), - request.ValidatorMiddleware(validatorFactory), + requestaaaaaa.DecoderMiddleware(decoderFactory), + requestaaaaaa.ValidatorMiddleware(validatorFactory), response.EncoderMiddleware, } diff --git a/request/jsonbody.go b/requestaaaaaa/jsonbody.go similarity index 89% rename from request/jsonbody.go rename to requestaaaaaa/jsonbody.go index d747903..eca7615 100644 --- a/request/jsonbody.go +++ b/requestaaaaaa/jsonbody.go @@ -1,4 +1,4 @@ -package request +package requestaaaaaa import ( "bytes" @@ -10,6 +10,7 @@ import ( "sync" "github.com/swaggest/rest" + "github.com/swaggest/rest/requestaaaaaa/reqerr" ) var bufPool = sync.Pool{ @@ -21,7 +22,7 @@ var bufPool = sync.Pool{ func decodeJSONBody(readJSON func(rd io.Reader, v interface{}) error, tolerateFormData bool) valueDecoderFunc { return func(r *http.Request, input interface{}, validator rest.Validator) error { if r.ContentLength == 0 { - return ErrMissingRequestBody + return reqerr.ErrMissingRequestBody } if ret, err := checkJSONBodyContentType(r.Header.Get("Content-Type"), tolerateFormData); err != nil { @@ -71,7 +72,7 @@ func checkJSONBodyContentType(contentType string, tolerateFormData bool) (ret bo return true, nil } - return true, fmt.Errorf("%w, received: %s", ErrJSONExpected, contentType) + return true, fmt.Errorf("%w, received: %s", reqerr.ErrJSONExpected, contentType) } return false, nil diff --git a/request/jsonbody_test.go b/requestaaaaaa/jsonbody_test.go similarity index 99% rename from request/jsonbody_test.go rename to requestaaaaaa/jsonbody_test.go index f5b7a45..3c65ab7 100644 --- a/request/jsonbody_test.go +++ b/requestaaaaaa/jsonbody_test.go @@ -1,4 +1,4 @@ -package request +package requestaaaaaa import ( "bytes" diff --git a/request/middleware.go b/requestaaaaaa/middleware.go similarity index 99% rename from request/middleware.go rename to requestaaaaaa/middleware.go index dcf0da7..26115ec 100644 --- a/request/middleware.go +++ b/requestaaaaaa/middleware.go @@ -1,4 +1,4 @@ -package request +package requestaaaaaa import ( "net/http" diff --git a/request/reflect.go b/requestaaaaaa/reflect.go similarity index 95% rename from request/reflect.go rename to requestaaaaaa/reflect.go index f82f5d5..7e6e6e6 100644 --- a/request/reflect.go +++ b/requestaaaaaa/reflect.go @@ -1,4 +1,4 @@ -package request +package requestaaaaaa import ( "reflect" diff --git a/request/error.go b/requestaaaaaa/reqerr/error.go similarity index 94% rename from request/error.go rename to requestaaaaaa/reqerr/error.go index 9eb873c..4da2101 100644 --- a/request/error.go +++ b/requestaaaaaa/reqerr/error.go @@ -1,4 +1,4 @@ -package request +package reqerr import "errors" diff --git a/web/service.go b/web/service.go index 6678239..23301be 100644 --- a/web/service.go +++ b/web/service.go @@ -14,7 +14,7 @@ import ( "github.com/swaggest/rest/jsonschema" "github.com/swaggest/rest/nethttp" "github.com/swaggest/rest/openapi" - "github.com/swaggest/rest/request" + "github.com/swaggest/rest/requestaaaaaa" "github.com/swaggest/rest/response" "github.com/swaggest/usecase" ) @@ -65,7 +65,7 @@ func NewService(refl oapi.Reflector, options ...func(s *Service)) *Service { } if s.DecoderFactory == nil { - decoderFactory := request.NewDecoderFactory() + decoderFactory := requestaaaaaa.NewDecoderFactory() decoderFactory.ApplyDefaults = true decoderFactory.JSONSchemaReflector = s.OpenAPICollector.Refl().JSONSchemaReflector() decoderFactory.SetDecoderFunc(rest.ParamInPath, chirouter.PathToURLValues) @@ -82,11 +82,11 @@ func NewService(refl oapi.Reflector, options ...func(s *Service)) *Service { // Setup middlewares. s.Wrap( - s.PanicRecoveryMiddleware, // Panic recovery. - nethttp.OpenAPIMiddleware(s.OpenAPICollector), // Documentation collector. - request.DecoderMiddleware(s.DecoderFactory), // Request decoder setup. - request.ValidatorMiddleware(validatorFactory), // Request validator setup. - response.EncoderMiddleware, // Response encoder setup. + s.PanicRecoveryMiddleware, // Panic recovery. + nethttp.OpenAPIMiddleware(s.OpenAPICollector), // Documentation collector. + requestaaaaaa.DecoderMiddleware(s.DecoderFactory), // Request decoder setup. + requestaaaaaa.ValidatorMiddleware(validatorFactory), // Request validator setup. + response.EncoderMiddleware, // Response encoder setup. ) return &s @@ -102,7 +102,7 @@ type Service struct { OpenAPI *openapi3.Spec OpenAPICollector *openapi.Collector - DecoderFactory *request.DecoderFactory + DecoderFactory *requestaaaaaa.DecoderFactory // Response validation is not enabled by default for its less justifiable performance impact. // This field is populated so that response.ValidatorMiddleware(s.ResponseValidatorFactory) can be From 4abe0f6270f6aa4f841ed1173a3be0230a093a6c Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 6 Apr 2026 18:49:34 +0200 Subject: [PATCH 03/15] Test gocan --- .github/workflows/gocan.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 109a3e8..b3565dc 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -18,4 +18,4 @@ jobs: curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.1/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) [ "$gocan_hash" == "d3484d7cae5f9314243ee932dbabd3e6ba4aed63" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) - ./gocan -gh-annotate + ./gocan -gh-annotate -gh-base "${{ github.event.pull_request.base.sha }}" From bb756134ecb36c3a99218c0c6967bfa14d52b06e Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 6 Apr 2026 19:00:51 +0200 Subject: [PATCH 04/15] Test gocan --- .github/workflows/gocan.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index b3565dc..19bc366 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -18,4 +18,4 @@ jobs: curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.1/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) [ "$gocan_hash" == "d3484d7cae5f9314243ee932dbabd3e6ba4aed63" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) - ./gocan -gh-annotate -gh-base "${{ github.event.pull_request.base.sha }}" + ./gocan -gh-annotate -gh-base "${{ github.event.pull_request.base.sha }}" -gh-head ${{ github.event.pull_request.head.sha }} From 8fd98c94acab621becca2bc8125c49f97fc1a744 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 6 Apr 2026 19:08:54 +0200 Subject: [PATCH 05/15] Test gocan --- .github/workflows/gocan.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 19bc366..b920b1a 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -15,7 +15,7 @@ jobs: uses: actions/checkout@v4 - name: Annotate layout-only changes run: | - curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.1/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz + curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.2/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) - [ "$gocan_hash" == "d3484d7cae5f9314243ee932dbabd3e6ba4aed63" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) + [ "$gocan_hash" == "0da18e8ea60c2addf699e3301aca6c103b2ae657" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) ./gocan -gh-annotate -gh-base "${{ github.event.pull_request.base.sha }}" -gh-head ${{ github.event.pull_request.head.sha }} From 46d1a167faed5c21c84d13f2306dfa01cf099508 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 6 Apr 2026 19:14:14 +0200 Subject: [PATCH 06/15] Test gocan --- .github/workflows/gocan.yml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index b920b1a..457da33 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -8,7 +8,7 @@ concurrency: cancel-in-progress: true jobs: - cloc: + gocan: runs-on: ubuntu-latest steps: - name: Checkout code @@ -18,4 +18,6 @@ jobs: curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.2/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) [ "$gocan_hash" == "0da18e8ea60c2addf699e3301aca6c103b2ae657" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) + git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} + git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} ./gocan -gh-annotate -gh-base "${{ github.event.pull_request.base.sha }}" -gh-head ${{ github.event.pull_request.head.sha }} From 9be1cec16a3b7006546043fbc47155d52fbd7165 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 6 Apr 2026 19:52:51 +0200 Subject: [PATCH 07/15] Test gocan --- .github/workflows/gocan.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 457da33..4ae4ceb 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -15,9 +15,9 @@ jobs: uses: actions/checkout@v4 - name: Annotate layout-only changes run: | - curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.2/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz + curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.3/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) - [ "$gocan_hash" == "0da18e8ea60c2addf699e3301aca6c103b2ae657" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) + [ "$gocan_hash" == "cfcb9c334ed3e887a802f1fde35fbf7f5af2d8e8" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} ./gocan -gh-annotate -gh-base "${{ github.event.pull_request.base.sha }}" -gh-head ${{ github.event.pull_request.head.sha }} From dbd6fb1c061b5444be72de990abe4e8dcd1e0791 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Mon, 6 Apr 2026 20:02:49 +0200 Subject: [PATCH 08/15] Test gocan --- .github/workflows/gocan.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 4ae4ceb..6a3b61a 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -18,6 +18,6 @@ jobs: curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.3/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) [ "$gocan_hash" == "cfcb9c334ed3e887a802f1fde35fbf7f5af2d8e8" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) - git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} - git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} + #git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} + #git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} ./gocan -gh-annotate -gh-base "${{ github.event.pull_request.base.sha }}" -gh-head ${{ github.event.pull_request.head.sha }} From 9ba86f0264018d3bb5595dd235d873291b98de67 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Tue, 7 Apr 2026 10:02:39 +0200 Subject: [PATCH 09/15] Test gocan --- .github/workflows/gocan.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 6a3b61a..0874fb8 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -2,7 +2,7 @@ name: gocan on: pull_request: -# Cancel the workflow in progress in newer build is about to start. +# Cancel the workflow in progress if a newer build is about to start. concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true @@ -18,6 +18,6 @@ jobs: curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.3/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) [ "$gocan_hash" == "cfcb9c334ed3e887a802f1fde35fbf7f5af2d8e8" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) - #git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} + git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} #git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} - ./gocan -gh-annotate -gh-base "${{ github.event.pull_request.base.sha }}" -gh-head ${{ github.event.pull_request.head.sha }} + ./gocan -gh-annotate -gh-base ${{ github.event.pull_request.base.sha }} -gh-head ${{ github.event.pull_request.head.sha }} From 5fa39adf798f6f4ca285a5ab7de4704fec39814b Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Tue, 7 Apr 2026 10:06:11 +0200 Subject: [PATCH 10/15] Test gocan --- .github/workflows/gocan.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 0874fb8..af5dded 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -12,12 +12,12 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout code - uses: actions/checkout@v4 + uses: actions/checkout@v6 - name: Annotate layout-only changes run: | curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.3/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) [ "$gocan_hash" == "cfcb9c334ed3e887a802f1fde35fbf7f5af2d8e8" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} - #git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} + git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} ./gocan -gh-annotate -gh-base ${{ github.event.pull_request.base.sha }} -gh-head ${{ github.event.pull_request.head.sha }} From 34e359ec389377f80b0625308676494ce23e61ba Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Tue, 7 Apr 2026 11:24:56 +0200 Subject: [PATCH 11/15] Test gocan --- .github/workflows/gocan.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index af5dded..11f9269 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -15,9 +15,9 @@ jobs: uses: actions/checkout@v6 - name: Annotate layout-only changes run: | - curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.3/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz + curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.4/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) - [ "$gocan_hash" == "cfcb9c334ed3e887a802f1fde35fbf7f5af2d8e8" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) + [ "$gocan_hash" == "b4e7572cd799ba302202c1b058e9259c56b4e2c7" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} ./gocan -gh-annotate -gh-base ${{ github.event.pull_request.base.sha }} -gh-head ${{ github.event.pull_request.head.sha }} From 097eeaf90fd3b97b648d90f6dd7843369ed0cdc7 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Tue, 7 Apr 2026 16:41:48 +0200 Subject: [PATCH 12/15] Test gocan --- .github/workflows/gocan.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 11f9269..34d5d46 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -15,9 +15,9 @@ jobs: uses: actions/checkout@v6 - name: Annotate layout-only changes run: | - curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.4/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz + curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.5/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) - [ "$gocan_hash" == "b4e7572cd799ba302202c1b058e9259c56b4e2c7" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) + [ "$gocan_hash" == "98c1ea372abf02d6c8583689c28c8589c442674f" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} - ./gocan -gh-annotate -gh-base ${{ github.event.pull_request.base.sha }} -gh-head ${{ github.event.pull_request.head.sha }} + ./gocan -gh-checks -gh-base ${{ github.event.pull_request.base.sha }} -gh-head ${{ github.event.pull_request.head.sha }} From 7f81717c65f581d337ec3c18c04e675f060ef40f Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Tue, 7 Apr 2026 19:56:00 +0200 Subject: [PATCH 13/15] Test gocan --- .github/workflows/gocan.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 34d5d46..74c2047 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -20,4 +20,4 @@ jobs: [ "$gocan_hash" == "98c1ea372abf02d6c8583689c28c8589c442674f" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} - ./gocan -gh-checks -gh-base ${{ github.event.pull_request.base.sha }} -gh-head ${{ github.event.pull_request.head.sha }} + GITHUB_TOKEN=${{ secrets.GITHUB_TOKEN }} ./gocan -gh-checks -gh-base ${{ github.event.pull_request.base.sha }} -gh-head ${{ github.event.pull_request.head.sha }} From a8c6384e692b85933cadaf630c8d80123d9231d9 Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Tue, 7 Apr 2026 20:17:40 +0200 Subject: [PATCH 14/15] Test gocan --- .github/workflows/gocan.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 74c2047..570d80a 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -15,9 +15,9 @@ jobs: uses: actions/checkout@v6 - name: Annotate layout-only changes run: | - curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.5/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz + curl -sLO https://github.com/vearutop/gocan/releases/download/v0.0.6/linux_amd64.tar.gz && tar xf linux_amd64.tar.gz gocan_hash=$(git hash-object ./gocan) - [ "$gocan_hash" == "98c1ea372abf02d6c8583689c28c8589c442674f" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) + [ "$gocan_hash" == "3ff3d1bd62d8a1c459ed68ca9dd997afcac1cfcc" ] || (echo "::error::unexpected hash for gocan, possible tampering: $gocan_hash" && exit 1) git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.base.sha }} git fetch --no-tags --depth=1 origin ${{ github.event.pull_request.head.sha }} GITHUB_TOKEN=${{ secrets.GITHUB_TOKEN }} ./gocan -gh-checks -gh-base ${{ github.event.pull_request.base.sha }} -gh-head ${{ github.event.pull_request.head.sha }} From 3c65707b381e515b82a1b506ce9857890c297bde Mon Sep 17 00:00:00 2001 From: Viacheslav Poturaev Date: Tue, 7 Apr 2026 20:32:06 +0200 Subject: [PATCH 15/15] Test gocan --- .github/workflows/gocan.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/gocan.yml b/.github/workflows/gocan.yml index 570d80a..97086e5 100644 --- a/.github/workflows/gocan.yml +++ b/.github/workflows/gocan.yml @@ -1,4 +1,4 @@ -name: gocan +name: layout-only changes on: pull_request: