diff --git a/_examples/advanced-generic/gzip_pass_through_test.go b/_examples/advanced-generic/gzip_pass_through_test.go index 7d37fb9..ea7afd1 100644 --- a/_examples/advanced-generic/gzip_pass_through_test.go +++ b/_examples/advanced-generic/gzip_pass_through_test.go @@ -4,47 +4,42 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) func Test_directGzip(t *testing.T) { r := NewRouter() - req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/gzip-pass-through") + rc.Request.Header.Set("Accept-Encoding", "gzip") - req.Header.Set("Accept-Encoding", "gzip") - rw := httptest.NewRecorder() - - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "330epditz19z", rw.Header().Get("Etag")) - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Equal(t, "abc", rw.Header().Get("X-Header")) - assert.Less(t, len(rw.Body.Bytes()), 500) + r.ServeHTTP(rc, rc) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "330epditz19z", string(rc.Response.Header.Peek("Etag"))) + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, "abc", string(rc.Response.Header.Peek("X-Header"))) + assert.Less(t, len(rc.Response.Body()), 500) } func Test_noDirectGzip(t *testing.T) { r := NewRouter() - req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through?plainStruct=1", nil) - require.NoError(t, err) - - req.Header.Set("Accept-Encoding", "gzip") - rw := httptest.NewRecorder() + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/gzip-pass-through?plainStruct=1") + rc.Request.Header.Set("Accept-Encoding", "gzip") - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "", rw.Header().Get("Etag")) // No ETag for dynamic compression. - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Equal(t, "cba", rw.Header().Get("X-Header")) - assert.Less(t, len(rw.Body.Bytes()), 1000) // Worse compression for better speed. + r.ServeHTTP(rc, rc) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "", string(rc.Response.Header.Peek("Etag"))) // No ETag for dynamic compression. + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, "cba", string(rc.Response.Header.Peek("X-Header"))) + assert.Less(t, len(rc.Response.Body()), 1000) // Worse compression for better speed. } func Test_directGzip_perf(t *testing.T) { @@ -52,14 +47,14 @@ func Test_directGzip_perf(t *testing.T) { if httptestbench.RaceDetectorEnabled { assert.Less(t, res.Extra["B:rcvd/op"], 700.0) - assert.Less(t, res.Extra["B:sent/op"], 104.0) - assert.Less(t, res.AllocsPerOp(), int64(60)) - assert.Less(t, res.AllocedBytesPerOp(), int64(8500)) + assert.Less(t, res.Extra["B:sent/op"], 110.0) + assert.Less(t, res.AllocsPerOp(), int64(30)) + assert.Less(t, res.AllocedBytesPerOp(), int64(4500)) } else { assert.Less(t, res.Extra["B:rcvd/op"], 700.0) - assert.Less(t, res.Extra["B:sent/op"], 104.0) - assert.Less(t, res.AllocsPerOp(), int64(45)) - assert.Less(t, res.AllocedBytesPerOp(), int64(4100)) + assert.Less(t, res.Extra["B:sent/op"], 105.0) + assert.Less(t, res.AllocsPerOp(), int64(17)) + assert.Less(t, res.AllocedBytesPerOp(), int64(1100)) } } @@ -69,7 +64,7 @@ func Test_directGzip_perf(t *testing.T) { func Benchmark_directGzip(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -86,7 +81,7 @@ func Benchmark_directGzip(b *testing.B) { func Benchmark_directGzipHead(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -105,7 +100,7 @@ func Benchmark_directGzipHead(b *testing.B) { func Benchmark_noDirectGzip(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -123,7 +118,7 @@ func Benchmark_noDirectGzip(b *testing.B) { func Benchmark_directGzip_decode(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -140,7 +135,7 @@ func Benchmark_directGzip_decode(b *testing.B) { func Benchmark_noDirectGzip_decode(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced-generic/json_body_manual.go b/_examples/advanced-generic/json_body_manual.go index 037e091..707dac9 100644 --- a/_examples/advanced-generic/json_body_manual.go +++ b/_examples/advanced-generic/json_body_manual.go @@ -7,14 +7,12 @@ import ( "encoding/json" "errors" "fmt" - "io" - "log" - "net/http" - "github.com/go-chi/chi/v5" + "github.com/swaggest/fchi" "github.com/swaggest/jsonschema-go" "github.com/swaggest/rest/request" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) func jsonBodyManual() usecase.Interactor { @@ -55,28 +53,25 @@ type inputWithJSON struct { var _ request.Loader = &inputWithJSON{} -func (i *inputWithJSON) LoadFromHTTPRequest(r *http.Request) (err error) { - defer func() { - if err := r.Body.Close(); err != nil { - log.Printf("failed to close request body: %s", err.Error()) - } - }() - - b, err := io.ReadAll(r.Body) - if err != nil { - return fmt.Errorf("failed to read request body: %w", err) +func (i *inputWithJSON) LoadFromFastHTTPRequest(rc *fasthttp.RequestCtx) (err error) { + if err = json.Unmarshal(rc.Request.Body(), i); err != nil { + return fmt.Errorf("failed to unmarshal request body: %w", err) } - if err = json.Unmarshal(b, i); err != nil { - return fmt.Errorf("failsed to unmarshal request body: %w", err) - } + i.Header = string(rc.Request.Header.Peek("X-Header")) - i.Header = r.Header.Get("X-Header") - if err := i.Query.UnmarshalText([]byte(r.URL.Query().Get("in_query"))); err != nil { - return fmt.Errorf("failed to decode in_query %q: %w", r.URL.Query().Get("in_query"), err) + rc.Request.URI().QueryArgs().VisitAll(func(key, value []byte) { + if string(key) == "in_query" { + if err = i.Query.UnmarshalText(value); err != nil { + err = fmt.Errorf("failed to decode in_query %q: %w", string(value), err) + } + } + }) + if err != nil { + return err } - if routeCtx := chi.RouteContext(r.Context()); routeCtx != nil { + if routeCtx := fchi.RouteContext(rc); routeCtx != nil { i.Path = routeCtx.URLParam("in-path") } else { return errors.New("missing path params in context") diff --git a/_examples/advanced-generic/json_body_manual_test.go b/_examples/advanced-generic/json_body_manual_test.go index 96f952f..fc5cb90 100644 --- a/_examples/advanced-generic/json_body_manual_test.go +++ b/_examples/advanced-generic/json_body_manual_test.go @@ -4,18 +4,18 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) -// Benchmark_jsonBodyManual-12 125672 8542 ns/op 208.0 B:rcvd/op 195.0 B:sent/op 117048 rps 4523 B/op 49 allocs/op. +// Benchmark_jsonBodyManual-12 147058 8812 ns/op 226.0 B:rcvd/op 195.0 B:sent/op 113469 rps 728 B/op 18 allocs/op. func Benchmark_jsonBodyManual(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced-generic/json_body_test.go b/_examples/advanced-generic/json_body_test.go index 55e9f0b..f39172e 100644 --- a/_examples/advanced-generic/json_body_test.go +++ b/_examples/advanced-generic/json_body_test.go @@ -4,18 +4,18 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) -// Benchmark_jsonBody-12 96762 12042 ns/op 208.0 B:rcvd/op 188.0 B:sent/op 83033 rps 10312 B/op 100 allocs/op. +// Benchmark_jsonBody-12 68124 17828 ns/op 226.0 B:rcvd/op 188.0 B:sent/op 56083 rps 6864 B/op 85 allocs/op. func Benchmark_jsonBody(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced-generic/json_body_validation_test.go b/_examples/advanced-generic/json_body_validation_test.go index 69e7548..e37c343 100644 --- a/_examples/advanced-generic/json_body_validation_test.go +++ b/_examples/advanced-generic/json_body_validation_test.go @@ -4,10 +4,10 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) @@ -15,7 +15,7 @@ import ( func Benchmark_jsonBodyValidation(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced-generic/main.go b/_examples/advanced-generic/main.go index 4677001..39aa32f 100644 --- a/_examples/advanced-generic/main.go +++ b/_examples/advanced-generic/main.go @@ -4,12 +4,14 @@ package main import ( "log" - "net/http" + + "github.com/swaggest/fchi" + "github.com/valyala/fasthttp" ) func main() { log.Println("http://localhost:8011/docs") - if err := http.ListenAndServe(":8011", NewRouter()); err != nil { + if err := fasthttp.ListenAndServe(":8011", fchi.RequestHandler(NewRouter())); err != nil { log.Fatal(err) } } diff --git a/_examples/advanced-generic/output_headers_test.go b/_examples/advanced-generic/output_headers_test.go index 1736619..d4edeed 100644 --- a/_examples/advanced-generic/output_headers_test.go +++ b/_examples/advanced-generic/output_headers_test.go @@ -5,13 +5,13 @@ package main import ( "io/ioutil" "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/swaggest/assertjson" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) @@ -19,7 +19,7 @@ import ( func Benchmark_outputHeaders(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -33,7 +33,7 @@ func Benchmark_outputHeaders(b *testing.B) { func Test_outputHeaders(t *testing.T) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() resp, err := http.Get(srv.URL + "/output-headers") diff --git a/_examples/advanced-generic/request_response_mapping_test.go b/_examples/advanced-generic/request_response_mapping_test.go index 0c85192..4378d8a 100644 --- a/_examples/advanced-generic/request_response_mapping_test.go +++ b/_examples/advanced-generic/request_response_mapping_test.go @@ -6,19 +6,19 @@ import ( "bytes" "io/ioutil" "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) func Test_requestResponseMapping(t *testing.T) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() req, err := http.NewRequest(http.MethodPost, srv.URL+"/req-resp-mapping", @@ -44,7 +44,7 @@ func Test_requestResponseMapping(t *testing.T) { func Benchmark_requestResponseMapping(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced-generic/router.go b/_examples/advanced-generic/router.go index 1cfe0c9..a54b016 100644 --- a/_examples/advanced-generic/router.go +++ b/_examples/advanced-generic/router.go @@ -9,7 +9,7 @@ import ( "reflect" "strings" - "github.com/rs/cors" + "github.com/swaggest/fchi" "github.com/swaggest/jsonschema-go" "github.com/swaggest/openapi-go/openapi3" "github.com/swaggest/rest" @@ -20,7 +20,7 @@ import ( swgui "github.com/swaggest/swgui/v4emb" ) -func NewRouter() http.Handler { +func NewRouter() fchi.Handler { s := web.DefaultService() s.OpenAPI.Info.Title = "Advanced Example" @@ -56,8 +56,8 @@ func NewRouter() http.Handler { s.OpenAPICollector.CombineErrors = "anyOf" s.Wrap( - // Example middleware to set up custom error responses and disable response validation for particular handlers. - func(handler http.Handler) http.Handler { + // Example middleware to set up custom error responses. + func(handler fchi.Handler) fchi.Handler { var h *nethttp.Handler if nethttp.HandlerAs(handler, &h) { h.MakeErrResp = func(ctx context.Context, err error) (int, interface{}) { @@ -87,10 +87,6 @@ func NewRouter() http.Handler { return handler }, - // Example middleware to set up CORS headers. - // See https://pkg.go.dev/github.com/rs/cors for more details. - cors.AllowAll().Handler, - // Response validator setup. // // It might be a good idea to disable this middleware in production to save performance, diff --git a/_examples/advanced-generic/router_test.go b/_examples/advanced-generic/router_test.go index 55ae603..73ff85b 100644 --- a/_examples/advanced-generic/router_test.go +++ b/_examples/advanced-generic/router_test.go @@ -6,32 +6,30 @@ import ( "encoding/json" "io/ioutil" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/swaggest/assertjson" + "github.com/valyala/fasthttp" ) func TestNewRouter(t *testing.T) { r := NewRouter() - req, err := http.NewRequest(http.MethodGet, "/docs/openapi.json", nil) - require.NoError(t, err) - - rw := httptest.NewRecorder() + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/docs/openapi.json") - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) + r.ServeHTTP(rc, rc) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) - actualSchema, err := assertjson.MarshalIndentCompact(json.RawMessage(rw.Body.Bytes()), "", " ", 120) + actualSchema, err := assertjson.MarshalIndentCompact(json.RawMessage(rc.Response.Body()), "", " ", 120) require.NoError(t, err) expectedSchema, err := ioutil.ReadFile("_testdata/openapi.json") require.NoError(t, err) - if !assertjson.Equal(t, expectedSchema, rw.Body.Bytes(), string(actualSchema)) { + if !assertjson.Equal(t, expectedSchema, rc.Response.Body(), string(actualSchema)) { require.NoError(t, ioutil.WriteFile("_testdata/openapi_last_run.json", actualSchema, 0o600)) } } diff --git a/_examples/advanced-generic/validation_test.go b/_examples/advanced-generic/validation_test.go index 227ea0f..767522c 100644 --- a/_examples/advanced-generic/validation_test.go +++ b/_examples/advanced-generic/validation_test.go @@ -4,10 +4,10 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) @@ -16,7 +16,7 @@ import ( func Benchmark_validation(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -33,7 +33,7 @@ func Benchmark_validation(b *testing.B) { func Benchmark_noValidation(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced/gzip_pass_through_test.go b/_examples/advanced/gzip_pass_through_test.go index f7b4c9a..d73009c 100644 --- a/_examples/advanced/gzip_pass_through_test.go +++ b/_examples/advanced/gzip_pass_through_test.go @@ -2,62 +2,57 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) func Test_directGzip(t *testing.T) { r := NewRouter() - req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/gzip-pass-through") + rc.Request.Header.Set("Accept-Encoding", "gzip") - req.Header.Set("Accept-Encoding", "gzip") - rw := httptest.NewRecorder() - - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "330epditz19z", rw.Header().Get("Etag")) - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Equal(t, "abc", rw.Header().Get("X-Header")) - assert.Less(t, len(rw.Body.Bytes()), 500) + r.ServeHTTP(rc, rc) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "330epditz19z", string(rc.Response.Header.Peek("Etag"))) + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, "abc", string(rc.Response.Header.Peek("X-Header"))) + assert.Less(t, len(rc.Response.Body()), 500) } func Test_noDirectGzip(t *testing.T) { r := NewRouter() - req, err := http.NewRequest(http.MethodGet, "/gzip-pass-through?plainStruct=1", nil) - require.NoError(t, err) - - req.Header.Set("Accept-Encoding", "gzip") - rw := httptest.NewRecorder() + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/gzip-pass-through?plainStruct=1") + rc.Request.Header.Set("Accept-Encoding", "gzip") - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "", rw.Header().Get("Etag")) // No ETag for dynamic compression. - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Equal(t, "cba", rw.Header().Get("X-Header")) - assert.Less(t, len(rw.Body.Bytes()), 1000) // Worse compression for better speed. + r.ServeHTTP(rc, rc) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "", string(rc.Response.Header.Peek("Etag"))) // No ETag for dynamic compression. + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, "cba", string(rc.Response.Header.Peek("X-Header"))) + assert.Less(t, len(rc.Response.Body()), 1000) // Worse compression for better speed. } func Test_directGzip_perf(t *testing.T) { res := testing.Benchmark(Benchmark_directGzip) if httptestbench.RaceDetectorEnabled { - assert.Less(t, res.Extra["B:rcvd/op"], 640.0) - assert.Less(t, res.Extra["B:sent/op"], 104.0) - assert.Less(t, res.AllocsPerOp(), int64(60)) - assert.Less(t, res.AllocedBytesPerOp(), int64(9000)) + assert.Less(t, res.Extra["B:rcvd/op"], 660.0) + assert.Less(t, res.Extra["B:sent/op"], 105.0) + assert.Less(t, res.AllocsPerOp(), int64(30)) + assert.Less(t, res.AllocedBytesPerOp(), int64(4500)) } else { - assert.Less(t, res.Extra["B:rcvd/op"], 640.0) - assert.Less(t, res.Extra["B:sent/op"], 104.0) - assert.Less(t, res.AllocsPerOp(), int64(45)) - assert.Less(t, res.AllocedBytesPerOp(), int64(4000)) + assert.Less(t, res.Extra["B:rcvd/op"], 660.0) + assert.Less(t, res.Extra["B:sent/op"], 105.0) + assert.Less(t, res.AllocsPerOp(), int64(17)) + assert.Less(t, res.AllocedBytesPerOp(), int64(1100)) } } @@ -67,7 +62,7 @@ func Test_directGzip_perf(t *testing.T) { func Benchmark_directGzip(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -84,7 +79,7 @@ func Benchmark_directGzip(b *testing.B) { func Benchmark_directGzipHead(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -103,7 +98,7 @@ func Benchmark_directGzipHead(b *testing.B) { func Benchmark_noDirectGzip(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -121,7 +116,7 @@ func Benchmark_noDirectGzip(b *testing.B) { func Benchmark_directGzip_decode(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -138,7 +133,7 @@ func Benchmark_directGzip_decode(b *testing.B) { func Benchmark_noDirectGzip_decode(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced/json_body_test.go b/_examples/advanced/json_body_test.go index 5f0c1df..311a9fa 100644 --- a/_examples/advanced/json_body_test.go +++ b/_examples/advanced/json_body_test.go @@ -2,10 +2,10 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) @@ -14,7 +14,7 @@ import ( func Benchmark_jsonBody(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced/json_body_validation_test.go b/_examples/advanced/json_body_validation_test.go index c7a07e4..c36d3f5 100644 --- a/_examples/advanced/json_body_validation_test.go +++ b/_examples/advanced/json_body_validation_test.go @@ -2,10 +2,10 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) @@ -13,7 +13,7 @@ import ( func Benchmark_jsonBodyValidation(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced/main.go b/_examples/advanced/main.go index 325e033..7bed6cd 100644 --- a/_examples/advanced/main.go +++ b/_examples/advanced/main.go @@ -2,12 +2,14 @@ package main import ( "log" - "net/http" + + "github.com/swaggest/fchi" + "github.com/valyala/fasthttp" ) func main() { log.Println("http://localhost:8011/docs") - if err := http.ListenAndServe(":8011", NewRouter()); err != nil { + if err := fasthttp.ListenAndServe(":8011", fchi.RequestHandler(NewRouter())); err != nil { log.Fatal(err) } } diff --git a/_examples/advanced/output_headers_test.go b/_examples/advanced/output_headers_test.go index ed3fbc4..38b7e0f 100644 --- a/_examples/advanced/output_headers_test.go +++ b/_examples/advanced/output_headers_test.go @@ -2,10 +2,10 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) @@ -13,7 +13,7 @@ import ( func Benchmark_outputHeaders(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced/request_response_mapping_test.go b/_examples/advanced/request_response_mapping_test.go index 5db99a5..3e45dc9 100644 --- a/_examples/advanced/request_response_mapping_test.go +++ b/_examples/advanced/request_response_mapping_test.go @@ -4,19 +4,19 @@ import ( "bytes" "io/ioutil" "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) func Test_requestResponseMapping(t *testing.T) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() req, err := http.NewRequest(http.MethodPost, srv.URL+"/req-resp-mapping", @@ -42,7 +42,7 @@ func Test_requestResponseMapping(t *testing.T) { func Benchmark_requestResponseMapping(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/advanced/router.go b/_examples/advanced/router.go index 3622d98..a77688e 100644 --- a/_examples/advanced/router.go +++ b/_examples/advanced/router.go @@ -6,6 +6,7 @@ import ( "net/http" "reflect" + "github.com/swaggest/fchi" "github.com/swaggest/jsonschema-go" "github.com/swaggest/openapi-go/openapi3" "github.com/swaggest/rest" @@ -16,7 +17,7 @@ import ( swgui "github.com/swaggest/swgui/v4emb" ) -func NewRouter() http.Handler { +func NewRouter() fchi.Handler { s := web.DefaultService() s.OpenAPI.Info.Title = "Advanced Example" @@ -57,7 +58,7 @@ func NewRouter() http.Handler { gzip.Middleware, // Response compression with support for direct gzip pass through. // Example middleware to setup custom error responses. - func(handler http.Handler) http.Handler { + func(handler fchi.Handler) fchi.Handler { var h *nethttp.Handler if nethttp.HandlerAs(handler, &h) { h.MakeErrResp = func(ctx context.Context, err error) (int, interface{}) { diff --git a/_examples/advanced/router_test.go b/_examples/advanced/router_test.go index 94c334f..59c05fb 100644 --- a/_examples/advanced/router_test.go +++ b/_examples/advanced/router_test.go @@ -4,32 +4,30 @@ import ( "encoding/json" "io/ioutil" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/swaggest/assertjson" + "github.com/valyala/fasthttp" ) func TestNewRouter(t *testing.T) { r := NewRouter() - req, err := http.NewRequest(http.MethodGet, "/docs/openapi.json", nil) - require.NoError(t, err) - - rw := httptest.NewRecorder() + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/docs/openapi.json") - r.ServeHTTP(rw, req) - assert.Equal(t, http.StatusOK, rw.Code) + r.ServeHTTP(rc, rc) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) - actualSchema, err := assertjson.MarshalIndentCompact(json.RawMessage(rw.Body.Bytes()), "", " ", 120) + actualSchema, err := assertjson.MarshalIndentCompact(json.RawMessage(rc.Response.Body()), "", " ", 120) require.NoError(t, err) expectedSchema, err := ioutil.ReadFile("_testdata/openapi.json") require.NoError(t, err) - if !assertjson.Equal(t, expectedSchema, rw.Body.Bytes(), string(actualSchema)) { + if !assertjson.Equal(t, expectedSchema, rc.Response.Body(), string(actualSchema)) { require.NoError(t, ioutil.WriteFile("_testdata/openapi_last_run.json", actualSchema, 0o600)) } } diff --git a/_examples/advanced/validation_test.go b/_examples/advanced/validation_test.go index 5a181e3..a3871e1 100644 --- a/_examples/advanced/validation_test.go +++ b/_examples/advanced/validation_test.go @@ -2,10 +2,10 @@ package main import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" + "github.com/swaggest/fchi" "github.com/valyala/fasthttp" ) @@ -14,7 +14,7 @@ import ( func Benchmark_validation(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { @@ -31,7 +31,7 @@ func Benchmark_validation(b *testing.B) { func Benchmark_noValidation(b *testing.B) { r := NewRouter() - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() httptestbench.RoundTrip(b, 50, func(i int, req *fasthttp.Request) { diff --git a/_examples/basic/main.go b/_examples/basic/main.go index 95683ba..f377fdf 100644 --- a/_examples/basic/main.go +++ b/_examples/basic/main.go @@ -5,14 +5,15 @@ import ( "errors" "fmt" "log" - "net/http" "time" + "github.com/swaggest/fchi" "github.com/swaggest/rest/response/gzip" "github.com/swaggest/rest/web" swgui "github.com/swaggest/swgui/v4emb" "github.com/swaggest/usecase" "github.com/swaggest/usecase/status" + "github.com/valyala/fasthttp" ) func main() { @@ -77,7 +78,7 @@ func main() { // Start server. log.Println("http://localhost:8011/docs") - if err := http.ListenAndServe(":8011", s); err != nil { + if err := fasthttp.ListenAndServe(":8011", fchi.RequestHandler(s)); err != nil { log.Fatal(err) } } diff --git a/_examples/generic/main.go b/_examples/generic/main.go index e7c0b56..8f9fc9d 100644 --- a/_examples/generic/main.go +++ b/_examples/generic/main.go @@ -8,14 +8,15 @@ import ( "errors" "fmt" "log" - "net/http" "time" + "github.com/swaggest/fchi" "github.com/swaggest/rest/response/gzip" "github.com/swaggest/rest/web" swgui "github.com/swaggest/swgui/v4emb" "github.com/swaggest/usecase" "github.com/swaggest/usecase/status" + "github.com/valyala/fasthttp" ) func main() { @@ -82,7 +83,7 @@ func main() { // Start server. log.Println("http://localhost:8011/docs") - if err := http.ListenAndServe(":8011", s); err != nil { + if err := fasthttp.ListenAndServe(":8011", fchi.RequestHandler(s)); err != nil { log.Fatal(err) } } diff --git a/_examples/go.mod b/_examples/go.mod index 1255339..dc41593 100644 --- a/_examples/go.mod +++ b/_examples/go.mod @@ -9,20 +9,18 @@ require ( github.com/bool64/dev v0.2.16 github.com/bool64/httpmock v0.1.1 github.com/bool64/httptestbench v0.1.3 - github.com/go-chi/chi/v5 v5.0.7 github.com/kelseyhightower/envconfig v1.4.0 - github.com/rs/cors v1.8.2 github.com/stretchr/testify v1.7.2 github.com/swaggest/assertjson v1.7.0 + github.com/swaggest/fchi v1.1.0 github.com/swaggest/jsonschema-go v0.3.35 github.com/swaggest/openapi-go v0.2.18 + github.com/swaggest/rest v0.0.0-00010101000000-000000000000 github.com/swaggest/swgui v1.4.5 github.com/swaggest/usecase v1.1.3 - github.com/valyala/fasthttp v1.35.0 + github.com/valyala/fasthttp v1.37.0 ) -require github.com/swaggest/rest v0.0.0-00010101000000-000000000000 - require ( github.com/andybalholm/brotli v1.0.4 // indirect github.com/bool64/shared v0.1.4 // indirect @@ -30,7 +28,7 @@ require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.5.4 // indirect github.com/iancoleman/orderedmap v0.2.0 // indirect - github.com/klauspost/compress v1.15.1 // indirect + github.com/klauspost/compress v1.15.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/santhosh-tekuri/jsonschema/v3 v3.1.0 // indirect github.com/sergi/go-diff v1.2.0 // indirect diff --git a/_examples/go.sum b/_examples/go.sum index 9549507..4e2349c 100644 --- a/_examples/go.sum +++ b/_examples/go.sum @@ -21,15 +21,13 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= -github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= -github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/iancoleman/orderedmap v0.2.0 h1:sq1N/TFpYH++aViPcaKjys3bDClUEU7s5B+z6jq8pNA= github.com/iancoleman/orderedmap v0.2.0/go.mod h1:N0Wam8K1arqPXNWjMo21EXnBPOPp36vB07FNRdD2geA= github.com/kelseyhightower/envconfig v1.4.0 h1:Im6hONhd3pLkfDFsbRgu68RDNkGF1r3dvMUtDTo2cv8= github.com/kelseyhightower/envconfig v1.4.0/go.mod h1:cccZRl6mQpaq41TPp5QxidR+Sa3axMbJDNb//FQX6Gg= github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= -github.com/klauspost/compress v1.15.1 h1:y9FcTHGyrebwfP0ZZqFiaxTaiDnUrGkJkI+f583BL1A= -github.com/klauspost/compress v1.15.1/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.15.6 h1:6D9PcO8QWu0JyaQ2zUMmu16T1T+zjjEpP91guRsvDfY= +github.com/klauspost/compress v1.15.6/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -42,8 +40,6 @@ github.com/onsi/ginkgo v1.15.2 h1:l77YT15o814C2qVL47NOyjV/6RbaP7kKdrvZnxQ3Org= github.com/onsi/gomega v1.11.0 h1:+CqWgvj0OZycCaqclBD1pxKHAU+tOkHmQIWvDHq2aug= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rs/cors v1.8.2 h1:KCooALfAYGs415Cwu5ABvv9n9509fSiG5SQJn/AQo4U= -github.com/rs/cors v1.8.2/go.mod h1:XyqrcTp5zjWr1wsJ8PIRZssZ8b/WMcMf71DJnit4EMU= github.com/santhosh-tekuri/jsonschema/v3 v3.1.0 h1:levPcBfnazlA1CyCMC3asL/QLZkq9pa8tQZOH513zQw= github.com/santhosh-tekuri/jsonschema/v3 v3.1.0/go.mod h1:8kzK2TC0k0YjOForaAHdNEa7ik0fokNa2k30BKJ/W7Y= github.com/sergi/go-diff v1.2.0 h1:XU+rvMAioB0UC3q1MFrIQy4Vo5/4VsRDQQXHsEya6xQ= @@ -54,6 +50,8 @@ github.com/stretchr/testify v1.7.2 h1:4jaiDzPyXQvSd7D0EjG45355tLlV3VOECpq10pLC+8 github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1FQKckRals= github.com/swaggest/assertjson v1.7.0 h1:SKw5Rn0LQs6UvmGrIdaKQbMR1R3ncXm5KNon+QJ7jtw= github.com/swaggest/assertjson v1.7.0/go.mod h1:vxMJMehbSVJd+dDWFCKv3QRZKNTpy/ktZKTz9LOEDng= +github.com/swaggest/fchi v1.1.0 h1:FbtaW8lJ2EhazxrP4zZZRSCdRn033jG2zlrCj7ERzEg= +github.com/swaggest/fchi v1.1.0/go.mod h1:2lAwaNkyHw0OSUOZTRP51Fs2P27cAA4VFujyAPIhbOA= github.com/swaggest/form/v5 v5.0.1 h1:YQH0REX7iMKhtoVPWXREZgbt50VYXNCKK61psnD8Fgo= github.com/swaggest/form/v5 v5.0.1/go.mod h1:vdnaSTze7cxVKhWiCabrfm1YeLwWLpb9P941Gxv4FnA= github.com/swaggest/jsonschema-go v0.3.35 h1:LW5DC0WgR5YdQXyTRc5e8gLdKT0wkACg4aVJyaseU+4= @@ -68,8 +66,8 @@ github.com/swaggest/usecase v1.1.3 h1:SGnmV07jyDhdg+gqEAv/NNc8R18JQqJUs4wVq+LWa5 github.com/swaggest/usecase v1.1.3/go.mod h1:gLSjsqHiDmHOIf081asqys7UUndFrTWrfNa2opxEt7k= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= -github.com/valyala/fasthttp v1.35.0 h1:wwkR8mZn2NbigFsaw2Zj5r+xkmzjbrA/lyTmiSlal/Y= -github.com/valyala/fasthttp v1.35.0/go.mod h1:t/G+3rLek+CyY9bnIE+YlMRddxVAAGjhxndDB4i4C0I= +github.com/valyala/fasthttp v1.37.0 h1:7WHCyI7EAkQMVmrfBhWTCOaeROb1aCBiTopx63LkMbE= +github.com/valyala/fasthttp v1.37.0/go.mod h1:t/G+3rLek+CyY9bnIE+YlMRddxVAAGjhxndDB4i4C0I= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/vearutop/statigz v1.1.5 h1:qWvRgXFsseWVTFCkIvwHQPpaLNf9WI0+dDJE7I9432o= github.com/vearutop/statigz v1.1.5/go.mod h1:czAv7iXgPv/s+xsgXpVEhhD0NSOQ4wZPgmM/n7LANDI= diff --git a/_examples/task-api/cmd/task-api/main.go b/_examples/task-api/cmd/task-api/main.go index b1fcf89..566fdb6 100644 --- a/_examples/task-api/cmd/task-api/main.go +++ b/_examples/task-api/cmd/task-api/main.go @@ -3,12 +3,14 @@ package main import ( "fmt" "log" - "net/http" + "time" "github.com/kelseyhightower/envconfig" + "github.com/swaggest/fchi" "github.com/swaggest/rest/_examples/task-api/internal/infra" "github.com/swaggest/rest/_examples/task-api/internal/infra/nethttp" "github.com/swaggest/rest/_examples/task-api/internal/infra/service" + "github.com/valyala/fasthttp" ) func main() { @@ -26,20 +28,24 @@ func main() { l.EnableGracefulShutdown() // Initialize HTTP server. - srv := http.Server{Addr: fmt.Sprintf(":%d", cfg.HTTPPort), Handler: nethttp.NewRouter(l)} + srv := fasthttp.Server{ + ReadTimeout: 9 * time.Second, + IdleTimeout: 9 * time.Second, + Handler: fchi.RequestHandler(nethttp.NewRouter(l)), + } // Start HTTP server. log.Printf("starting HTTP server at http://localhost:%d/docs\n", cfg.HTTPPort) go func() { - err := srv.ListenAndServe() + err := srv.ListenAndServe(fmt.Sprintf(":%d", cfg.HTTPPort)) if err != nil { log.Fatal(err) } }() // Wait for termination signal and HTTP shutdown finished. - err := l.WaitToShutdownHTTP(&srv, "http") + err := l.WaitToShutdownFastHTTP(&srv, "http") if err != nil { log.Fatal(err) } diff --git a/_examples/task-api/internal/infra/nethttp/benchmark_test.go b/_examples/task-api/internal/infra/nethttp/benchmark_test.go index 79c9132..2ecde95 100644 --- a/_examples/task-api/internal/infra/nethttp/benchmark_test.go +++ b/_examples/task-api/internal/infra/nethttp/benchmark_test.go @@ -5,12 +5,12 @@ import ( "io/ioutil" "log" "net/http" - "net/http/httptest" "testing" "github.com/bool64/httptestbench" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" "github.com/swaggest/rest/_examples/task-api/internal/domain/task" "github.com/swaggest/rest/_examples/task-api/internal/infra" "github.com/swaggest/rest/_examples/task-api/internal/infra/nethttp" @@ -28,7 +28,7 @@ func Benchmark_notFoundSrv(b *testing.B) { l := infra.NewServiceLocator(service.Config{}) defer l.Close() - srv := httptest.NewServer(nethttp.NewRouter(l)) + srv := fchi.NewTestServer(nethttp.NewRouter(l)) defer srv.Close() httptestbench.RoundTrip(b, 50, @@ -51,7 +51,7 @@ func Benchmark_ok(b *testing.B) { l := infra.NewServiceLocator(service.Config{}) defer l.Close() - srv := httptest.NewServer(nethttp.NewRouter(l)) + srv := fchi.NewTestServer(nethttp.NewRouter(l)) defer srv.Close() _, err := l.TaskCreator().Create(context.Background(), task.Value{Goal: "victory!"}) @@ -78,7 +78,7 @@ func Benchmark_invalidBody(b *testing.B) { defer l.Close() r := nethttp.NewRouter(l) - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) tt, err := l.TaskCreator().Create(context.Background(), task.Value{Goal: "win"}) require.NoError(b, err) diff --git a/_examples/task-api/internal/infra/nethttp/integration_test.go b/_examples/task-api/internal/infra/nethttp/integration_test.go index 9aa7a59..3bd9523 100644 --- a/_examples/task-api/internal/infra/nethttp/integration_test.go +++ b/_examples/task-api/internal/infra/nethttp/integration_test.go @@ -2,11 +2,11 @@ package nethttp_test import ( "net/http" - "net/http/httptest" "testing" "github.com/bool64/httpmock" "github.com/stretchr/testify/assert" + "github.com/swaggest/fchi" "github.com/swaggest/rest/_examples/task-api/internal/infra" "github.com/swaggest/rest/_examples/task-api/internal/infra/nethttp" "github.com/swaggest/rest/_examples/task-api/internal/infra/service" @@ -16,7 +16,7 @@ func Test_taskLifeSpan(t *testing.T) { l := infra.NewServiceLocator(service.Config{}) defer l.Close() - srv := httptest.NewServer(nethttp.NewRouter(l)) + srv := fchi.NewTestServer(nethttp.NewRouter(l)) defer srv.Close() rc := httpmock.NewClient(srv.URL) diff --git a/_examples/task-api/internal/infra/nethttp/router.go b/_examples/task-api/internal/infra/nethttp/router.go index 8994e7d..1a89c8c 100644 --- a/_examples/task-api/internal/infra/nethttp/router.go +++ b/_examples/task-api/internal/infra/nethttp/router.go @@ -4,8 +4,8 @@ import ( "net/http" "time" - "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" + "github.com/swaggest/fchi" + "github.com/swaggest/fchi/middleware" "github.com/swaggest/openapi-go/openapi3" "github.com/swaggest/rest" "github.com/swaggest/rest/_examples/task-api/internal/infra/log" @@ -21,13 +21,13 @@ import ( ) // NewRouter creates HTTP router. -func NewRouter(locator *service.Locator) http.Handler { +func NewRouter(locator *service.Locator) fchi.Handler { apiSchema := schema.NewOpenAPICollector() validatorFactory := jsonschema.NewFactory(apiSchema, apiSchema) decoderFactory := request.NewDecoderFactory() decoderFactory.SetDecoderFunc(rest.ParamInPath, chirouter.PathToURLValues) - r := chirouter.NewWrapper(chi.NewRouter()) + r := chirouter.NewWrapper(fchi.NewRouter()) r.Wrap( middleware.Recoverer, // Panic recovery. @@ -51,13 +51,13 @@ func NewRouter(locator *service.Locator) http.Handler { } // Unrestricted access. - r.Route("/dev", func(r chi.Router) { + r.Route("/dev", func(r fchi.Router) { r.Use(nethttp.AnnotateOpenAPI(apiSchema, func(op *openapi3.Operation) error { op.Tags = []string{"Dev Mode"} return nil })) - r.Group(func(r chi.Router) { + r.Group(func(r fchi.Router) { r.Method(http.MethodPost, "/tasks", nethttp.NewHandler(usecase.CreateTask(locator), nethttp.SuccessStatus(http.StatusCreated))) r.Method(http.MethodPut, "/tasks/{id}", nethttp.NewHandler(usecase.UpdateTask(locator), ff)) @@ -68,8 +68,8 @@ func NewRouter(locator *service.Locator) http.Handler { }) // Endpoints with admin access. - r.Route("/admin", func(r chi.Router) { - r.Group(func(r chi.Router) { + r.Route("/admin", func(r fchi.Router) { + r.Group(func(r fchi.Router) { r.Use(nethttp.AnnotateOpenAPI(apiSchema, func(op *openapi3.Operation) error { op.Tags = []string{"Admin Mode"} @@ -81,8 +81,8 @@ func NewRouter(locator *service.Locator) http.Handler { }) // Endpoints with user access. - r.Route("/user", func(r chi.Router) { - r.Group(func(r chi.Router) { + r.Route("/user", func(r fchi.Router) { + r.Group(func(r fchi.Router) { r.Use(userAuth, nethttp.HTTPBasicSecurityMiddleware(apiSchema, "User", "User access")) r.Method(http.MethodPost, "/tasks", nethttp.NewHandler(usecase.CreateTask(locator), nethttp.SuccessStatus(http.StatusCreated))) @@ -91,8 +91,8 @@ func NewRouter(locator *service.Locator) http.Handler { // Swagger UI endpoint at /docs. r.Method(http.MethodGet, "/docs/openapi.json", apiSchema) - r.Mount("/docs", swgui.NewHandler(apiSchema.Reflector().Spec.Info.Title, - "/docs/openapi.json", "/docs")) + r.Mount("/docs", fchi.Adapt(swgui.NewHandler(apiSchema.Reflector().Spec.Info.Title, + "/docs/openapi.json", "/docs"))) r.Mount("/debug", middleware.Profiler()) diff --git a/_examples/task-api/pkg/graceful/http.go b/_examples/task-api/pkg/graceful/http.go index f82a38f..645ebad 100644 --- a/_examples/task-api/pkg/graceful/http.go +++ b/_examples/task-api/pkg/graceful/http.go @@ -1,12 +1,14 @@ package graceful import ( - "context" - "net/http" + "fmt" + "time" + + "github.com/valyala/fasthttp" ) -// WaitToShutdownHTTP synchronously waits for shutdown signal and shutdowns http server. -func (s *Shutdown) WaitToShutdownHTTP(server *http.Server, subscriber string) error { +// WaitToShutdownFastHTTP synchronously waits for shutdown signal and shutdowns fasthttp server. +func (s *Shutdown) WaitToShutdownFastHTTP(server *fasthttp.Server, subscriber string) error { shutdown, done := s.ShutdownSignal(subscriber) <-shutdown @@ -19,12 +21,18 @@ func (s *Shutdown) WaitToShutdownHTTP(server *http.Server, subscriber string) er timeout = DefaultTimeout } - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - err := server.Shutdown(ctx) + fs := make(chan error) - close(done) + go func() { + fs <- server.Shutdown() + }() - return err + select { + case <-time.After(timeout): + close(done) + return fmt.Errorf("failed to gracefully shutdown fasthttp server in %s", timeout.String()) + case err := <-fs: + close(done) + return err + } } diff --git a/chirouter/path_decoder.go b/chirouter/path_decoder.go index 258213c..fc46d06 100644 --- a/chirouter/path_decoder.go +++ b/chirouter/path_decoder.go @@ -1,24 +1,23 @@ package chirouter import ( - "net/http" "net/url" - "github.com/go-chi/chi/v5" + "github.com/swaggest/fchi" + "github.com/valyala/fasthttp" ) // PathToURLValues is a decoder function for parameters in path. -func PathToURLValues(r *http.Request) (url.Values, error) { // nolint:unparam // Matches request.DecoderFactory.SetDecoderFunc. - if routeCtx := chi.RouteContext(r.Context()); routeCtx != nil { - params := make(url.Values, len(routeCtx.URLParams.Keys)) - +func PathToURLValues(rc *fasthttp.RequestCtx, params url.Values) error { // nolint:unparam // Matches request.DecoderFactory.SetDecoderFunc. + if routeCtx := fchi.RouteContext(rc); routeCtx != nil { for i, key := range routeCtx.URLParams.Keys { value := routeCtx.URLParams.Values[i] + params[key] = []string{value} } - return params, nil + return nil } - return nil, nil + return nil } diff --git a/chirouter/wrapper.go b/chirouter/wrapper.go index b101192..a245199 100644 --- a/chirouter/wrapper.go +++ b/chirouter/wrapper.go @@ -4,12 +4,12 @@ import ( "net/http" "strings" - "github.com/go-chi/chi/v5" + "github.com/swaggest/fchi" "github.com/swaggest/rest/nethttp" ) // NewWrapper creates router wrapper to upgrade middlewares processing. -func NewWrapper(r chi.Router) *Wrapper { +func NewWrapper(r fchi.Router) *Wrapper { return &Wrapper{ Router: r, } @@ -19,17 +19,17 @@ func NewWrapper(r chi.Router) *Wrapper { // // Middlewares can call nethttp.HandlerAs to inspect wrapped handlers. type Wrapper struct { - chi.Router + fchi.Router name string basePattern string - middlewares []func(http.Handler) http.Handler - wraps []func(http.Handler) http.Handler + middlewares []func(fchi.Handler) fchi.Handler + wraps []func(fchi.Handler) fchi.Handler } -var _ chi.Router = &Wrapper{} +var _ fchi.Router = &Wrapper{} -func (r *Wrapper) copy(router chi.Router, pattern string) *Wrapper { +func (r *Wrapper) copy(router fchi.Router, pattern string) *Wrapper { return &Wrapper{ Router: router, name: r.name, @@ -45,18 +45,18 @@ func (r *Wrapper) copy(router chi.Router, pattern string) *Wrapper { // Wraps can leverage nethttp.HandlerAs to inspect and access deeper layers. // For most cases Wrap can be safely used instead of Use, Use is mandatory for middlewares // that affect routing (such as middleware.StripSlashes for example). -func (r *Wrapper) Wrap(wraps ...func(handler http.Handler) http.Handler) { +func (r *Wrapper) Wrap(wraps ...func(handler fchi.Handler) fchi.Handler) { r.wraps = append(r.wraps, wraps...) } // Use appends one of more middlewares onto the Router stack. -func (r *Wrapper) Use(middlewares ...func(http.Handler) http.Handler) { +func (r *Wrapper) Use(middlewares ...func(fchi.Handler) fchi.Handler) { r.Router.Use(middlewares...) r.middlewares = append(r.middlewares, middlewares...) } // With adds inline middlewares for an endpoint handler. -func (r Wrapper) With(middlewares ...func(http.Handler) http.Handler) chi.Router { +func (r Wrapper) With(middlewares ...func(fchi.Handler) fchi.Handler) fchi.Router { c := r.copy(r.Router.With(middlewares...), "") c.middlewares = append(c.middlewares, middlewares...) @@ -64,7 +64,7 @@ func (r Wrapper) With(middlewares ...func(http.Handler) http.Handler) chi.Router } // Group adds a new inline-router along the current routing path, with a fresh middleware stack for the inline-router. -func (r *Wrapper) Group(fn func(r chi.Router)) chi.Router { +func (r *Wrapper) Group(fn func(r fchi.Router)) fchi.Router { im := r.With() if fn != nil { @@ -75,8 +75,8 @@ func (r *Wrapper) Group(fn func(r chi.Router)) chi.Router { } // Route mounts a sub-router along a `basePattern` string. -func (r *Wrapper) Route(pattern string, fn func(r chi.Router)) chi.Router { - subRouter := r.copy(chi.NewRouter(), pattern) +func (r *Wrapper) Route(pattern string, fn func(r fchi.Router)) fchi.Router { + subRouter := r.copy(fchi.NewRouter(), pattern) if fn != nil { fn(subRouter) @@ -87,86 +87,81 @@ func (r *Wrapper) Route(pattern string, fn func(r chi.Router)) chi.Router { return subRouter } -// Mount attaches another http.Handler along "./basePattern/*". -func (r *Wrapper) Mount(pattern string, h http.Handler) { +// Mount attaches another Handler along "./basePattern/*". +func (r *Wrapper) Mount(pattern string, h fchi.Handler) { h = r.prepareHandler("", pattern, h) r.captureHandler(h) r.Router.Mount(pattern, h) } // Handle adds routes for `basePattern` that matches all HTTP methods. -func (r *Wrapper) Handle(pattern string, h http.Handler) { +func (r *Wrapper) Handle(pattern string, h fchi.Handler) { h = r.prepareHandler("", pattern, h) r.captureHandler(h) r.Router.Handle(pattern, h) } // Method adds routes for `basePattern` that matches the `method` HTTP method. -func (r *Wrapper) Method(method, pattern string, h http.Handler) { +func (r *Wrapper) Method(method, pattern string, h fchi.Handler) { h = r.prepareHandler(method, pattern, h) r.captureHandler(h) r.Router.Method(method, pattern, h) } -// MethodFunc adds the route `pattern` that matches `method` http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) MethodFunc(method, pattern string, handlerFn http.HandlerFunc) { - r.Method(method, pattern, handlerFn) +// Connect adds the route `pattern` that matches a CONNECT http method to execute the `h` fchi.Handler. +func (r *Wrapper) Connect(pattern string, h fchi.Handler) { + r.Method(http.MethodConnect, pattern, h) } -// Connect adds the route `pattern` that matches a CONNECT http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Connect(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodConnect, pattern, handlerFn) +// Delete adds the route `pattern` that matches a DELETE http method to execute the `h` fchi.Handler. +func (r *Wrapper) Delete(pattern string, h fchi.Handler) { + r.Method(http.MethodDelete, pattern, h) } -// Delete adds the route `pattern` that matches a DELETE http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Delete(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodDelete, pattern, handlerFn) +// Get adds the route `pattern` that matches a GET http method to execute the `h` fchi.Handler. +func (r *Wrapper) Get(pattern string, h fchi.Handler) { + r.Method(http.MethodGet, pattern, h) } -// Get adds the route `pattern` that matches a GET http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Get(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodGet, pattern, handlerFn) +// Head adds the route `pattern` that matches a HEAD http method to execute the `h` fchi.Handler. +func (r *Wrapper) Head(pattern string, h fchi.Handler) { + r.Method(http.MethodHead, pattern, h) } -// Head adds the route `pattern` that matches a HEAD http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Head(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodHead, pattern, handlerFn) +// Options adds the route `pattern` that matches a OPTIONS http method to execute the `h` fchi.Handler. +func (r *Wrapper) Options(pattern string, h fchi.Handler) { + r.Method(http.MethodOptions, pattern, h) } -// Options adds the route `pattern` that matches a OPTIONS http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Options(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodOptions, pattern, handlerFn) +// Patch adds the route `pattern` that matches a PATCH http method to execute the `h` fchi.Handler. +func (r *Wrapper) Patch(pattern string, h fchi.Handler) { + r.Method(http.MethodPatch, pattern, h) } -// Patch adds the route `pattern` that matches a PATCH http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Patch(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodPatch, pattern, handlerFn) +// Post adds the route `pattern` that matches a POST http method to execute the `h` fchi.Handler. +func (r *Wrapper) Post(pattern string, h fchi.Handler) { + r.Method(http.MethodPost, pattern, h) } -// Post adds the route `pattern` that matches a POST http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Post(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodPost, pattern, handlerFn) +// Put adds the route `pattern` that matches a PUT http method to execute the `h` fchi.Handler. +func (r *Wrapper) Put(pattern string, h fchi.Handler) { + r.Method(http.MethodPut, pattern, h) } -// Put adds the route `pattern` that matches a PUT http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Put(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodPut, pattern, handlerFn) -} - -// Trace adds the route `pattern` that matches a TRACE http method to execute the `handlerFn` http.HandlerFunc. -func (r *Wrapper) Trace(pattern string, handlerFn http.HandlerFunc) { - r.Method(http.MethodTrace, pattern, handlerFn) +// Trace adds the route `pattern` that matches a TRACE http method to execute the `h` fchi.Handler. +func (r *Wrapper) Trace(pattern string, h fchi.Handler) { + r.Method(http.MethodTrace, pattern, h) } func (r *Wrapper) resolvePattern(pattern string) string { return r.basePattern + strings.ReplaceAll(pattern, "/*/", "/") } -func (r *Wrapper) captureHandler(h http.Handler) { +func (r *Wrapper) captureHandler(h fchi.Handler) { nethttp.WrapHandler(h, r.middlewares...) } -func (r *Wrapper) prepareHandler(method, pattern string, h http.Handler) http.Handler { +func (r *Wrapper) prepareHandler(method, pattern string, h fchi.Handler) fchi.Handler { mw := r.wraps mw = append(mw, nethttp.HandlerWithRouteMiddleware(method, r.resolvePattern(pattern))) h = nethttp.WrapHandler(h, mw...) diff --git a/chirouter/wrapper_test.go b/chirouter/wrapper_test.go index 01bf740..65c0dba 100644 --- a/chirouter/wrapper_test.go +++ b/chirouter/wrapper_test.go @@ -1,57 +1,57 @@ package chirouter_test import ( + "context" "net/http" - "net/http/httptest" "net/url" "testing" - "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" + "github.com/swaggest/fchi/middleware" "github.com/swaggest/rest" "github.com/swaggest/rest/chirouter" "github.com/swaggest/rest/nethttp" + "github.com/valyala/fasthttp" ) type HandlerWithFoo struct { - http.Handler + fchi.Handler } func (h HandlerWithFoo) Foo() {} type HandlerWithBar struct { - http.Handler + fchi.Handler } -func (h HandlerWithFoo) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - if _, err := rw.Write([]byte("foo")); err != nil { +func (h HandlerWithFoo) ServeHTTP(ctx context.Context, rc *fasthttp.RequestCtx) { + if _, err := rc.Write([]byte("foo")); err != nil { panic(err) } - h.Handler.ServeHTTP(rw, r) + h.Handler.ServeHTTP(ctx, rc) } func (h HandlerWithBar) Bar() {} -func (h HandlerWithBar) ServeHTTP(rw http.ResponseWriter, r *http.Request) { - h.Handler.ServeHTTP(rw, r) +func (h HandlerWithBar) ServeHTTP(ctx context.Context, rc *fasthttp.RequestCtx) { + h.Handler.ServeHTTP(ctx, rc) - if _, err := rw.Write([]byte("bar")); err != nil { + if _, err := rc.Write([]byte("bar")); err != nil { panic(err) } } func TestNewWrapper(t *testing.T) { - r := chirouter.NewWrapper(chi.NewRouter()).With(func(handler http.Handler) http.Handler { - return http.HandlerFunc(handler.ServeHTTP) + r := chirouter.NewWrapper(fchi.NewRouter()).With(func(handler fchi.Handler) fchi.Handler { + return fchi.HandlerFunc(handler.ServeHTTP) }) handlersCnt := 0 totalCnt := 0 - mw := func(handler http.Handler) http.Handler { + mw := func(handler fchi.Handler) fchi.Handler { var ( withRoute rest.HandlerWithRoute bar interface{ Bar() } @@ -72,11 +72,12 @@ func TestNewWrapper(t *testing.T) { r.Use(mw) - r.Group(func(r chi.Router) { + r.Group(func(r fchi.Router) { r.Method(http.MethodPost, "/baz/{id}/", - HandlerWithFoo{Handler: http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - val, err := chirouter.PathToURLValues(request) + HandlerWithFoo{Handler: fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + val := make(url.Values) + err := chirouter.PathToURLValues(rc, val) assert.NoError(t, err) assert.Equal(t, url.Values{"id": []string{"123"}}, val) })}, @@ -84,37 +85,37 @@ func TestNewWrapper(t *testing.T) { }) r.Mount("/mount", - HandlerWithFoo{Handler: http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {})}, + HandlerWithFoo{Handler: fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})}, ) - r.Route("/deeper", func(r chi.Router) { - r.Use(func(handler http.Handler) http.Handler { + r.Route("/deeper", func(r fchi.Router) { + r.Use(func(handler fchi.Handler) fchi.Handler { return HandlerWithFoo{Handler: handler} }) - r.Get("/foo", func(writer http.ResponseWriter, request *http.Request) {}) - r.Head("/foo", func(writer http.ResponseWriter, request *http.Request) {}) - r.Post("/foo", func(writer http.ResponseWriter, request *http.Request) {}) - r.Put("/foo", func(writer http.ResponseWriter, request *http.Request) {}) - r.Trace("/foo", func(writer http.ResponseWriter, request *http.Request) {}) - r.Connect("/foo", func(writer http.ResponseWriter, request *http.Request) {}) - r.Options("/foo", func(writer http.ResponseWriter, request *http.Request) {}) - r.Patch("/foo", func(writer http.ResponseWriter, request *http.Request) {}) - r.Delete("/foo", func(writer http.ResponseWriter, request *http.Request) {}) + r.Get("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) + r.Head("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) + r.Post("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) + r.Put("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) + r.Trace("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) + r.Connect("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) + r.Options("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) + r.Patch("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) + r.Delete("/foo", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) - r.MethodFunc(http.MethodGet, "/cuux", func(writer http.ResponseWriter, request *http.Request) {}) + r.Method(http.MethodGet, "/cuux", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) - r.Handle("/bar", http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {})) + r.Handle("/bar", fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) {})) }) for _, u := range []string{"/baz/123/", "/deeper/foo", "/mount/abc"} { - req, err := http.NewRequest(http.MethodPost, u, nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI(u) + rc.Request.Header.SetMethod(http.MethodPost) - rw := httptest.NewRecorder() - r.ServeHTTP(rw, req) + r.ServeHTTP(rc, rc) - assert.Equal(t, "foobar", rw.Body.String(), u) + assert.Equal(t, "foobar", string(rc.Response.Body()), u) } assert.Equal(t, 14, handlersCnt) @@ -124,50 +125,50 @@ func TestNewWrapper(t *testing.T) { func TestWrapper_Use_precedence(t *testing.T) { var log []string - // Vanilla chi router. - cr := chi.NewRouter() + // Vanilla fchi router. + cr := fchi.NewRouter() cr.Use( - func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + func(handler fchi.Handler) fchi.Handler { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { log = append(log, "cmw1 before") - handler.ServeHTTP(writer, request) + handler.ServeHTTP(ctx, rc) log = append(log, "cmw1 after") }) }, - func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + func(handler fchi.Handler) fchi.Handler { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { log = append(log, "cmw2 before") - handler.ServeHTTP(writer, request) + handler.ServeHTTP(ctx, rc) log = append(log, "cmw2 after") }) }, ) // Wrapped chi router. - wr := chirouter.NewWrapper(chi.NewRouter()) + wr := chirouter.NewWrapper(fchi.NewRouter()) wr.Use( - func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + func(handler fchi.Handler) fchi.Handler { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { log = append(log, "wmw1 before") - handler.ServeHTTP(writer, request) + handler.ServeHTTP(ctx, rc) log = append(log, "wmw1 after") }) }, - func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + func(handler fchi.Handler) fchi.Handler { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { log = append(log, "wmw2 before") - handler.ServeHTTP(writer, request) + handler.ServeHTTP(ctx, rc) log = append(log, "wmw2 after") }) }, ) - req, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") - h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + h := fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { log = append(log, "h") }) @@ -175,8 +176,8 @@ func TestWrapper_Use_precedence(t *testing.T) { cr.Method(http.MethodGet, "/", h) wr.Method(http.MethodGet, "/", h) - cr.ServeHTTP(nil, req) - wr.ServeHTTP(nil, req) + cr.ServeHTTP(rc, rc) + wr.ServeHTTP(rc, rc) assert.Equal(t, []string{ "cmw1 before", "cmw2 before", "h", "cmw2 after", "cmw1 after", "wmw1 before", "wmw2 before", "h", "wmw2 after", "wmw1 after", @@ -200,36 +201,34 @@ func TestWrapper_Use_precedence(t *testing.T) { func TestWrapper_Use_StripSlashes(t *testing.T) { var log []string - r := chi.NewRouter() + r := fchi.NewRouter() r.Use( middleware.StripSlashes, - func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - handler.ServeHTTP(writer, request) + func(handler fchi.Handler) fchi.Handler { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + handler.ServeHTTP(ctx, rc) }) }, ) // Wrapped chi router. - wr := chirouter.NewWrapper(chi.NewRouter()) + wr := chirouter.NewWrapper(fchi.NewRouter()) wr.Use( middleware.StripSlashes, - func(handler http.Handler) http.Handler { - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - handler.ServeHTTP(writer, request) + func(handler fchi.Handler) fchi.Handler { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + handler.ServeHTTP(ctx, rc) }) }, ) - req, err := http.NewRequest(http.MethodGet, "/foo/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/foo/") - rw := httptest.NewRecorder() - - h := http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - if _, err := writer.Write([]byte("OK")); err != nil { + h := fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + if _, err := rc.Write([]byte("OK")); err != nil { log = append(log, err.Error()) } @@ -237,18 +236,18 @@ func TestWrapper_Use_StripSlashes(t *testing.T) { }) r.Method(http.MethodGet, "/foo", h) - r.ServeHTTP(rw, req) + r.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "OK", rw.Body.String()) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "OK", string(rc.Response.Body())) - rw = httptest.NewRecorder() + rc.Response = fasthttp.Response{} wr.Method(http.MethodGet, "/foo", h) - wr.ServeHTTP(rw, req) + wr.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusOK, rw.Code) - assert.Equal(t, "OK", rw.Body.String()) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "OK", string(rc.Response.Body())) assert.Equal(t, []string{ "h", "h", diff --git a/go.mod b/go.mod index e85617c..23d9ec4 100644 --- a/go.mod +++ b/go.mod @@ -4,32 +4,32 @@ go 1.17 require ( github.com/bool64/dev v0.2.16 - github.com/bool64/httpmock v0.1.1 - github.com/bool64/shared v0.1.4 github.com/cespare/xxhash/v2 v2.1.2 - github.com/go-chi/chi/v5 v5.0.7 github.com/santhosh-tekuri/jsonschema/v3 v3.1.0 github.com/stretchr/testify v1.7.2 github.com/swaggest/assertjson v1.7.0 + github.com/swaggest/fchi v1.1.0 github.com/swaggest/form/v5 v5.0.1 github.com/swaggest/jsonschema-go v0.3.35 github.com/swaggest/openapi-go v0.2.18 github.com/swaggest/refl v1.0.2 github.com/swaggest/usecase v1.1.3 + github.com/valyala/fasthttp v1.37.0 ) require ( + github.com/andybalholm/brotli v1.0.4 // indirect + github.com/bool64/shared v0.1.4 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/iancoleman/orderedmap v0.2.0 // indirect + github.com/klauspost/compress v1.15.6 // indirect github.com/mattn/go-isatty v0.0.14 // indirect github.com/nxadm/tail v1.4.8 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/sergi/go-diff v1.2.0 // indirect - github.com/yosuke-furukawa/json5 v0.1.2-0.20201207051438-cf7bb3f354ff // indirect + github.com/valyala/bytebufferpool v1.0.0 // indirect github.com/yudai/gojsondiff v1.0.0 // indirect github.com/yudai/golcs v0.0.0-20170316035057-ecda9a501e82 // indirect - golang.org/x/net v0.0.0-20211105192438-b53810dc28af // indirect - golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e // indirect gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 69c8f3f..4dfebe3 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY= +github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig= github.com/bool64/dev v0.1.25/go.mod h1:cTHiTDNc8EewrQPy3p1obNilpMpdmlUesDkFTF2zRWU= github.com/bool64/dev v0.1.35/go.mod h1:cTHiTDNc8EewrQPy3p1obNilpMpdmlUesDkFTF2zRWU= github.com/bool64/dev v0.1.41/go.mod h1:cTHiTDNc8EewrQPy3p1obNilpMpdmlUesDkFTF2zRWU= @@ -6,8 +8,6 @@ github.com/bool64/dev v0.2.10/go.mod h1:/csLrm+4oDSsKJRIVS0mrywAonLnYKFG8RvGT7Jh github.com/bool64/dev v0.2.12/go.mod h1:/csLrm+4oDSsKJRIVS0mrywAonLnYKFG8RvGT7Jh9b8= github.com/bool64/dev v0.2.16 h1:ZlybgWWXmHGMojqIjDrtl5QF6jmE4hNeojE00nioVk0= github.com/bool64/dev v0.2.16/go.mod h1:/csLrm+4oDSsKJRIVS0mrywAonLnYKFG8RvGT7Jh9b8= -github.com/bool64/httpmock v0.1.1 h1:jpqM0S8efvJfN7Uy5fBUJKu2C640/ZS0yboxpeyVwm0= -github.com/bool64/httpmock v0.1.1/go.mod h1:Ju7xrs8gVyxANbgIxoxX4Pkj1uHygzPEpGEnfqct+gA= github.com/bool64/shared v0.1.3/go.mod h1:RF1p1Oi29ofgOvinBpetbF5mceOUP3kpMkvLbWOmtm0= github.com/bool64/shared v0.1.4 h1:zwtb1dl2QzDa9TJOq2jzDTdb5IPf9XlxTGKN8cySWT0= github.com/bool64/shared v0.1.4/go.mod h1:ryGjsnQFh6BnEXClfVlEJrzjwzat7CmA8PNS5E+jPp0= @@ -19,8 +19,6 @@ github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= -github.com/go-chi/chi/v5 v5.0.7 h1:rDTPXLDHGATaeHvVlLcR4Qe0zftYethFucbjVQ1PxU8= -github.com/go-chi/chi/v5 v5.0.7/go.mod h1:DslCQbL2OYiznFReuXYUmQ2hGd1aDpCnlMNITLSKoi8= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.4.0-rc.1/go.mod h1:ceaxUfeHdC40wWswd/P6IGgMaK3YpKi5j83Wpe3EHw8= github.com/golang/protobuf v1.4.0-rc.1.0.20200221234624-67d41d38c208/go.mod h1:xKAWHe0F5eneWXFV3EuXVDTCmh+JuBKY0li0aMyXATA= @@ -36,6 +34,10 @@ github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpO github.com/iancoleman/orderedmap v0.2.0 h1:sq1N/TFpYH++aViPcaKjys3bDClUEU7s5B+z6jq8pNA= github.com/iancoleman/orderedmap v0.2.0/go.mod h1:N0Wam8K1arqPXNWjMo21EXnBPOPp36vB07FNRdD2geA= github.com/k0kubun/colorstring v0.0.0-20150214042306-9440f1994b88/go.mod h1:3w7q1U84EfirKl04SVQ/s7nPm1ZPhiXd34z40TNz36k= +github.com/klauspost/compress v1.15.0/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/compress v1.15.5/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= +github.com/klauspost/compress v1.15.6 h1:6D9PcO8QWu0JyaQ2zUMmu16T1T+zjjEpP91guRsvDfY= +github.com/klauspost/compress v1.15.6/go.mod h1:PhcZ0MbTNciWF3rruxRgKxI5NkcHHrHUDtV4Yw2GlzU= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -72,6 +74,8 @@ github.com/stretchr/testify v1.7.2/go.mod h1:R6va5+xMeoiuVRoj+gSkQ7d3FALtqAAGI1F github.com/swaggest/assertjson v1.6.8/go.mod h1:Euf0upn9Vlaf1/llYHTs+Kx5K3vVbpMbsZhth7zlN7M= github.com/swaggest/assertjson v1.7.0 h1:SKw5Rn0LQs6UvmGrIdaKQbMR1R3ncXm5KNon+QJ7jtw= github.com/swaggest/assertjson v1.7.0/go.mod h1:vxMJMehbSVJd+dDWFCKv3QRZKNTpy/ktZKTz9LOEDng= +github.com/swaggest/fchi v1.1.0 h1:FbtaW8lJ2EhazxrP4zZZRSCdRn033jG2zlrCj7ERzEg= +github.com/swaggest/fchi v1.1.0/go.mod h1:2lAwaNkyHw0OSUOZTRP51Fs2P27cAA4VFujyAPIhbOA= github.com/swaggest/form/v5 v5.0.1 h1:YQH0REX7iMKhtoVPWXREZgbt50VYXNCKK61psnD8Fgo= github.com/swaggest/form/v5 v5.0.1/go.mod h1:vdnaSTze7cxVKhWiCabrfm1YeLwWLpb9P941Gxv4FnA= github.com/swaggest/jsonschema-go v0.3.34/go.mod h1:JAF1nm+uIaMOXktuQepmkiRcgQ5yJk4Ccwx9HVt2cXw= @@ -83,7 +87,11 @@ github.com/swaggest/refl v1.0.2 h1:VmP8smuDS1EzUPn31++TzMi13CAaVJdlWpIxzj0up88= github.com/swaggest/refl v1.0.2/go.mod h1:DoiPoBJPYHU6Z9fIA6zXQ9uI6VRL6M8BFX5YFT+ym9g= github.com/swaggest/usecase v1.1.3 h1:SGnmV07jyDhdg+gqEAv/NNc8R18JQqJUs4wVq+LWa5g= github.com/swaggest/usecase v1.1.3/go.mod h1:gLSjsqHiDmHOIf081asqys7UUndFrTWrfNa2opxEt7k= -github.com/yosuke-furukawa/json5 v0.1.2-0.20201207051438-cf7bb3f354ff h1:7YqG491bE4vstXRz1lD38rbSgbXnirvROz1lZiOnPO8= +github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +github.com/valyala/fasthttp v1.37.0 h1:7WHCyI7EAkQMVmrfBhWTCOaeROb1aCBiTopx63LkMbE= +github.com/valyala/fasthttp v1.37.0/go.mod h1:t/G+3rLek+CyY9bnIE+YlMRddxVAAGjhxndDB4i4C0I= +github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= github.com/yosuke-furukawa/json5 v0.1.2-0.20201207051438-cf7bb3f354ff/go.mod h1:sw49aWDqNdRJ6DYUtIQiaA3xyj2IL9tjeNYmX2ixwcU= github.com/yudai/gojsondiff v1.0.0 h1:27cbfqXLVEJ1o8I6v3y9lg8Ydm53EKqHXAOMxEGlCOA= github.com/yudai/gojsondiff v1.0.0/go.mod h1:AY32+k2cwILAkW1fbgxQ5mUmMiZFgLIV+FBNExI05xg= @@ -95,6 +103,7 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= @@ -102,8 +111,9 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200520004742-59133d7f0dd7/go.mod h1:qpuaurCH72eLCgpAm/N6yyVIVM9cpaDIP3A8BGJEC5A= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20201202161906-c7110b5ffcbb/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= -golang.org/x/net v0.0.0-20211105192438-b53810dc28af h1:SMeNJG/vclJ5wyBBd4xupMsSJIHTd1coW9g7q6KOjmY= -golang.org/x/net v0.0.0-20211105192438-b53810dc28af/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f h1:oA4XRj0qtSt8Yo1Zms0CUlsT3KG69V2UGQWPBxujDmc= +golang.org/x/net v0.0.0-20220225172249-27dd8689420f/go.mod h1:CfG3xpIq0wQ8r1q4Su4UZFWDARRcnwPjda9FqA0JpMk= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -120,14 +130,18 @@ golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210112080510-489259a85091/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e h1:WUoyKPm6nCo1BnNUvPGnFG3T5DUVem42yDJZZ4CNxMA= -golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9 h1:nhht2DYV/Sn3qOayu8lM+cU1ii9sTLUeBQwQQfUHtrs= +golang.org/x/sys v0.0.0-20220227234510-4e6760a101f9/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6 h1:aRYxNxv6iGQlyVaZmk6ZgYEDa+Jg18DxebPSrd6bg1M= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= diff --git a/gzip/container.go b/gzip/container.go index 36186c7..ba34b9a 100644 --- a/gzip/container.go +++ b/gzip/container.go @@ -9,6 +9,7 @@ import ( "strconv" "github.com/cespare/xxhash/v2" + "github.com/valyala/fasthttp" ) // Writer writes gzip data into suitable stream or returns 0, nil. @@ -26,10 +27,12 @@ type JSONContainer struct { // // Bytes are unpacked if response writer does not support direct gzip writing. func WriteCompressedBytes(compressed []byte, w io.Writer) (int, error) { - if gw, ok := w.(Writer); ok { - n, err := gw.GzipWrite(compressed) - if n != 0 { - return n, err + if rc, ok := w.(*fasthttp.RequestCtx); ok { + if rc.Request.Header.HasAcceptEncoding("gzip") { + rc.Request.Header.Del("Accept-Encoding") + rc.Response.Header.Set("Content-Encoding", "gzip") + + return rc.Write(compressed) } } diff --git a/gzip/container_test.go b/gzip/container_test.go index 89fa113..6fd86d5 100644 --- a/gzip/container_test.go +++ b/gzip/container_test.go @@ -4,7 +4,6 @@ import ( "context" "encoding/json" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -15,6 +14,7 @@ import ( "github.com/swaggest/rest/response" gzip2 "github.com/swaggest/rest/response/gzip" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) func TestWriteJSON(t *testing.T) { @@ -58,36 +58,34 @@ func TestWriteJSON(t *testing.T) { h := nethttp.NewHandler(u) h.SetResponseEncoder(&response.Encoder{}) - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) - - w := httptest.NewRecorder() + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") - r.Header.Set("Accept-Encoding", "deflate, gzip") - gzip2.Middleware(h).ServeHTTP(w, r) + rc.Request.Header.Set("Accept-Encoding", "deflate, gzip") + gzip2.Middleware(h).ServeHTTP(rc, rc) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "gzip", w.Header().Get("Content-Encoding")) - assert.Equal(t, "1ofolk6sr5j4r", w.Header().Get("Etag")) - assert.Equal(t, cont.GzipCompressedJSON(), w.Body.Bytes()) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, "1ofolk6sr5j4r", string(rc.Response.Header.Peek("Etag"))) + assert.Equal(t, cont.GzipCompressedJSON(), rc.Response.Body()) - w = httptest.NewRecorder() + rc.Request.Header.Del("Accept-Encoding") + rc.Response = fasthttp.Response{} + gzip2.Middleware(h).ServeHTTP(rc, rc) - r.Header.Del("Accept-Encoding") - gzip2.Middleware(h).ServeHTTP(w, r) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, "1ofolk6sr5j4r", string(rc.Response.Header.Peek("Etag"))) + assert.Equal(t, append(vj, '\n'), rc.Response.Body()) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "", w.Header().Get("Content-Encoding")) - assert.Equal(t, "1ofolk6sr5j4r", w.Header().Get("Etag")) - assert.Equal(t, append(vj, '\n'), w.Body.Bytes()) - - w = httptest.NewRecorder() ur = v - r.Header.Set("Accept-Encoding", "deflate, gzip") - gzip2.Middleware(h).ServeHTTP(w, r) + rc.Request.Header.Set("Accept-Encoding", "deflate, gzip") + rc.Response = fasthttp.Response{} + + gzip2.Middleware(h).ServeHTTP(rc, rc) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "gzip", w.Header().Get("Content-Encoding")) - assert.Equal(t, "", w.Header().Get("Etag")) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, "", string(rc.Response.Header.Peek("Etag"))) } diff --git a/jsonschema/validator_test.go b/jsonschema/validator_test.go index 992a057..8425a66 100644 --- a/jsonschema/validator_test.go +++ b/jsonschema/validator_test.go @@ -1,7 +1,6 @@ package jsonschema_test import ( - "context" "fmt" "net/http" "testing" @@ -12,6 +11,7 @@ import ( "github.com/swaggest/rest/jsonschema" "github.com/swaggest/rest/openapi" "github.com/swaggest/rest/request" + "github.com/valyala/fasthttp" ) // BenchmarkRequestValidator_ValidateRequestData-4 634356 1761 ns/op 2496 B/op 8 allocs/op. @@ -103,9 +103,8 @@ func TestNullableTime(t *testing.T) { } func TestValidator_ForbidUnknownParams(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?foo=bar&baz=1", nil) - assert.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/?foo=bar&baz=1") type input struct { Foo string `query:"foo"` @@ -119,7 +118,7 @@ func TestValidator_ForbidUnknownParams(t *testing.T) { validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, in, nil) - err = dec.Decode(req, in, validator) + err := dec.Decode(rc, in, validator) assert.Equal(t, rest.ValidationErrors{"query:baz": []string{"unknown parameter with value 1"}}, err, fmt.Sprintf("%#v", err)) } diff --git a/nethttp/example_test.go b/nethttp/example_test.go index 531eb0f..86d36dd 100644 --- a/nethttp/example_test.go +++ b/nethttp/example_test.go @@ -1,20 +1,22 @@ package nethttp_test import ( + "bytes" "context" "net/http" - "github.com/go-chi/chi/v5" + "github.com/swaggest/fchi" "github.com/swaggest/openapi-go/openapi3" "github.com/swaggest/rest/chirouter" "github.com/swaggest/rest/nethttp" "github.com/swaggest/rest/openapi" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) func ExampleSecurityMiddleware() { // Create router. - r := chirouter.NewWrapper(chi.NewRouter()) + r := chirouter.NewWrapper(fchi.NewRouter()) // Init API documentation schema. apiSchema := &openapi.Collector{} @@ -25,15 +27,15 @@ func ExampleSecurityMiddleware() { ) // Configure an actual security middleware. - serviceTokenAuth := func(h http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) { - if req.Header.Get("Authorization") != "" { - http.Error(w, "Authentication failed.", http.StatusUnauthorized) + serviceTokenAuth := func(h fchi.Handler) fchi.Handler { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + if !bytes.Equal(rc.Request.Header.Peek("Authorization"), []byte("")) { + fchi.Error(rc, "Authentication failed.", http.StatusUnauthorized) return } - h.ServeHTTP(w, req) + h.ServeHTTP(ctx, rc) }) } diff --git a/nethttp/handler.go b/nethttp/handler.go index ae85745..972100e 100644 --- a/nethttp/handler.go +++ b/nethttp/handler.go @@ -2,16 +2,17 @@ package nethttp import ( "context" - "log" "net/http" "reflect" + "github.com/swaggest/fchi" "github.com/swaggest/rest" "github.com/swaggest/usecase" "github.com/swaggest/usecase/status" + "github.com/valyala/fasthttp" ) -var _ http.Handler = &Handler{} +var _ fchi.Handler = &Handler{} // NewHandler creates use case http handler. func NewHandler(useCase usecase.Interactor, options ...func(h *Handler)) *Handler { @@ -52,7 +53,7 @@ type Handler struct { rest.HandlerTrait // HandleErrResponse allows control of error response processing. - HandleErrResponse func(w http.ResponseWriter, r *http.Request, err error) + HandleErrResponse func(ctx context.Context, rc *fasthttp.RequestCtx, err error) // requestDecoder maps data from http.Request into structured Go input value. requestDecoder RequestDecoder @@ -82,13 +83,13 @@ func (h *Handler) SetRequestDecoder(requestDecoder RequestDecoder) { h.requestDecoder = requestDecoder } -func (h *Handler) decodeRequest(r *http.Request) (interface{}, error) { +func (h *Handler) decodeRequest(rc *fasthttp.RequestCtx) (interface{}, error) { if h.requestDecoder == nil { panic("request decoder is not initialized, please use SetRequestDecoder") } iv := reflect.New(h.inputBufferType) - err := h.requestDecoder.Decode(r, iv.Interface(), h.ReqValidator) + err := h.requestDecoder.Decode(rc, iv.Interface(), h.ReqValidator) if !h.inputIsPtr { return iv.Elem().Interface(), err @@ -98,7 +99,7 @@ func (h *Handler) decodeRequest(r *http.Request) (interface{}, error) { } // ServeHTTP serves http inputPort with use case interactor. -func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { +func (h *Handler) ServeHTTP(ctx context.Context, rc *fasthttp.RequestCtx) { var ( input, output interface{} err error @@ -108,74 +109,64 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { panic("response encoder is not initialized, please use SetResponseEncoder") } - output = h.responseEncoder.MakeOutput(w, h.HandlerTrait) + output = h.responseEncoder.MakeOutput(rc, h.HandlerTrait) if h.inputBufferType != nil { - input, err = h.decodeRequest(r) - - if r.MultipartForm != nil { - defer closeMultipartForm(r) - } + input, err = h.decodeRequest(rc) if err != nil { - h.handleDecodeError(w, r, err, input, output) + h.handleDecodeError(ctx, rc, err, input, output) return } } - err = h.useCase.Interact(r.Context(), input, output) + err = h.useCase.Interact(ctx, input, output) if err != nil { - h.handleErrResponse(w, r, err) + h.handleErrResponse(ctx, rc, err) return } - h.responseEncoder.WriteSuccessfulResponse(w, r, output, h.HandlerTrait) + h.responseEncoder.WriteSuccessfulResponse(rc, output, h.HandlerTrait) } -func (h *Handler) handleErrResponseDefault(w http.ResponseWriter, r *http.Request, err error) { +func (h *Handler) handleErrResponseDefault(ctx context.Context, rc *fasthttp.RequestCtx, err error) { var ( code int er interface{} ) if h.MakeErrResp != nil { - code, er = h.MakeErrResp(r.Context(), err) + code, er = h.MakeErrResp(ctx, err) } else { code, er = rest.Err(err) } - h.responseEncoder.WriteErrResponse(w, r, code, er) + h.responseEncoder.WriteErrResponse(rc, code, er) } -func (h *Handler) handleErrResponse(w http.ResponseWriter, r *http.Request, err error) { +func (h *Handler) handleErrResponse(ctx context.Context, rc *fasthttp.RequestCtx, err error) { if h.HandleErrResponse != nil { - h.HandleErrResponse(w, r, err) + h.HandleErrResponse(ctx, rc, err) return } - h.handleErrResponseDefault(w, r, err) -} - -func closeMultipartForm(r *http.Request) { - if err := r.MultipartForm.RemoveAll(); err != nil { - log.Println(err) - } + h.handleErrResponseDefault(ctx, rc, err) } type decodeErrCtxKey struct{} -func (h *Handler) handleDecodeError(w http.ResponseWriter, r *http.Request, err error, input, output interface{}) { +func (h *Handler) handleDecodeError(ctx context.Context, rc *fasthttp.RequestCtx, err error, input, output interface{}) { err = status.Wrap(err, status.InvalidArgument) if h.failingUseCase != nil { - err = h.failingUseCase.Interact(context.WithValue(r.Context(), decodeErrCtxKey{}, err), input, output) + err = h.failingUseCase.Interact(context.WithValue(ctx, decodeErrCtxKey{}, err), input, output) } - h.handleErrResponse(w, r, err) + h.handleErrResponse(ctx, rc, err) } func (h *Handler) setupInputBuffer() { @@ -213,7 +204,7 @@ func (h *Handler) setupOutputBuffer() { } type handlerWithRoute struct { - http.Handler + fchi.Handler method string pathPattern string } @@ -227,8 +218,8 @@ func (h handlerWithRoute) RoutePattern() string { } // HandlerWithRouteMiddleware wraps handler with routing information. -func HandlerWithRouteMiddleware(method, pathPattern string) func(http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { +func HandlerWithRouteMiddleware(method, pathPattern string) func(fchi.Handler) fchi.Handler { + return func(handler fchi.Handler) fchi.Handler { return handlerWithRoute{ Handler: handler, pathPattern: pathPattern, @@ -239,18 +230,17 @@ func HandlerWithRouteMiddleware(method, pathPattern string) func(http.Handler) h // RequestDecoder maps data from http.Request into structured Go input value. type RequestDecoder interface { - Decode(r *http.Request, input interface{}, validator rest.Validator) error + Decode(rc *fasthttp.RequestCtx, input interface{}, validator rest.Validator) error } // ResponseEncoder writes data from use case output/error into http.ResponseWriter. type ResponseEncoder interface { - WriteErrResponse(w http.ResponseWriter, r *http.Request, statusCode int, response interface{}) + WriteErrResponse(rc *fasthttp.RequestCtx, statusCode int, response interface{}) WriteSuccessfulResponse( - w http.ResponseWriter, - r *http.Request, + rc *fasthttp.RequestCtx, output interface{}, ht rest.HandlerTrait, ) SetupOutput(output interface{}, ht *rest.HandlerTrait) - MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interface{} + MakeOutput(rc *fasthttp.RequestCtx, ht rest.HandlerTrait) interface{} } diff --git a/nethttp/handler_test.go b/nethttp/handler_test.go index 415f167..b0bbed4 100644 --- a/nethttp/handler_test.go +++ b/nethttp/handler_test.go @@ -4,17 +4,17 @@ import ( "context" "errors" "net/http" - "net/http/httptest" - "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" "github.com/swaggest/rest" "github.com/swaggest/rest/nethttp" "github.com/swaggest/rest/request" "github.com/swaggest/rest/response" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) type Input struct { @@ -50,8 +50,8 @@ func TestHandler_ServeHTTP(t *testing.T) { return nil }) - req, err := http.NewRequest(http.MethodGet, "/test", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/test") validatorCalled := false h := nethttp.NewHandler(u, @@ -69,8 +69,8 @@ func TestHandler_ServeHTTP(t *testing.T) { h.SetResponseEncoder(&response.Encoder{}) h.SetRequestDecoder(request.DecoderFunc( - func(r *http.Request, input interface{}, validator rest.Validator) error { - assert.Equal(t, req, r) + func(r *fasthttp.RequestCtx, input interface{}, validator rest.Validator) error { + assert.Equal(t, rc, r) in, ok := input.(*Input) require.True(t, ok) require.NotNil(t, in) @@ -94,11 +94,10 @@ func TestHandler_ServeHTTP(t *testing.T) { }) }))(h) - rw := httptest.NewRecorder() - hh.ServeHTTP(rw, req) + hh.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusAccepted, rw.Code) - assert.Equal(t, `{"value":"abc"}`+"\n", rw.Body.String()) + assert.Equal(t, http.StatusAccepted, rc.Response.StatusCode()) + assert.Equal(t, `{"value":"abc"}`+"\n", string(rc.Response.Body())) assert.True(t, validatorCalled) assert.True(t, umwCalled) } @@ -119,12 +118,12 @@ func TestHandler_ServeHTTP_decodeErr(t *testing.T) { return nil }) - req, err := http.NewRequest(http.MethodGet, "/test", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/test") uh := nethttp.NewHandler(u) uh.SetRequestDecoder(request.DecoderFunc( - func(r *http.Request, input interface{}, validator rest.Validator) error { + func(r *fasthttp.RequestCtx, input interface{}, validator rest.Validator) error { return errors.New("failed to decode request") }, )) @@ -139,12 +138,11 @@ func TestHandler_ServeHTTP_decodeErr(t *testing.T) { }) }))(uh) - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + h.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusBadRequest, rw.Code) + assert.Equal(t, http.StatusBadRequest, rc.Response.StatusCode()) assert.Equal(t, `{"status":"INVALID_ARGUMENT","error":"invalid argument: failed to decode request"}`+"\n", - rw.Body.String()) + string(rc.Response.Body())) assert.True(t, umwCalled) } @@ -163,14 +161,13 @@ func TestHandler_ServeHTTP_emptyPorts(t *testing.T) { h := nethttp.NewHandler(u) h.SetResponseEncoder(&response.Encoder{}) - req, err := http.NewRequest(http.MethodGet, "/test", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/test") - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + h.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusNoContent, rw.Code) - assert.Equal(t, "", rw.Body.String()) + assert.Equal(t, http.StatusNoContent, rc.Response.StatusCode()) + assert.Equal(t, "", string(rc.Response.Body())) } func TestHandler_ServeHTTP_customErrResp(t *testing.T) { @@ -196,21 +193,20 @@ func TestHandler_ServeHTTP_customErrResp(t *testing.T) { } h.SetResponseEncoder(&response.Encoder{}) - req, err := http.NewRequest(http.MethodGet, "/test", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/test") - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + h.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusExpectationFailed, rw.Code) - assert.Equal(t, `{"custom":"use case failed"}`+"\n", rw.Body.String()) + assert.Equal(t, http.StatusExpectationFailed, rc.Response.StatusCode()) + assert.Equal(t, `{"custom":"use case failed"}`+"\n", string(rc.Response.Body())) } func TestHandlerWithRouteMiddleware(t *testing.T) { called := false - var h http.Handler - h = http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + var h fchi.Handler + h = fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { called = true }) @@ -220,7 +216,7 @@ func TestHandlerWithRouteMiddleware(t *testing.T) { assert.Equal(t, http.MethodPost, hr.RouteMethod()) assert.Equal(t, "/test/", hr.RoutePattern()) - h.ServeHTTP(nil, nil) + h.ServeHTTP(context.Background(), nil) assert.True(t, called) } @@ -252,14 +248,14 @@ func TestHandler_ServeHTTP_getWithBody(t *testing.T) { h.SetRequestDecoder(request.NewDecoderFactory().MakeDecoder(http.MethodGet, new(reqWithBody), nil)) h.SetResponseEncoder(&response.Encoder{}) - req, err := http.NewRequest(http.MethodGet, "/test", strings.NewReader(`{"id":123}`)) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/test") + rc.Request.SetBody([]byte(`{"id":123}`)) - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + h.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusNoContent, rw.Code) - assert.Equal(t, ``, rw.Body.String()) + assert.Equal(t, http.StatusNoContent, rc.Response.StatusCode()) + assert.Equal(t, ``, string(rc.Response.Body())) } func TestHandler_ServeHTTP_customMapping(t *testing.T) { @@ -289,14 +285,13 @@ func TestHandler_ServeHTTP_customMapping(t *testing.T) { response.EncoderMiddleware, ) - req, err := http.NewRequest(http.MethodGet, "/test?ident=123", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/test?ident=123") - rw := httptest.NewRecorder() - h.ServeHTTP(rw, req) + h.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusNoContent, rw.Code) - assert.Equal(t, "", rw.Body.String()) + assert.Equal(t, http.StatusNoContent, rc.Response.StatusCode()) + assert.Equal(t, "", string(rc.Response.Body())) } func TestOptionsMiddleware(t *testing.T) { @@ -314,21 +309,20 @@ func TestOptionsMiddleware(t *testing.T) { var loggedErr error - rw := httptest.NewRecorder() - req, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") oh := nethttp.OptionsMiddleware(func(h *nethttp.Handler) { handleErrResponse := h.HandleErrResponse - h.HandleErrResponse = func(w http.ResponseWriter, r *http.Request, err error) { - assert.Equal(t, req, r) + h.HandleErrResponse = func(ctx context.Context, r *fasthttp.RequestCtx, err error) { + assert.Equal(t, rc, r) loggedErr = err - handleErrResponse(w, r, err) + handleErrResponse(ctx, r, err) } })(h) - oh.ServeHTTP(rw, req) + oh.ServeHTTP(rc, rc) assert.EqualError(t, loggedErr, "failed") - assert.Equal(t, `{"foo":"failed"}`+"\n", rw.Body.String()) + assert.Equal(t, `{"foo":"failed"}`+"\n", string(rc.Response.Body())) } diff --git a/nethttp/openapi.go b/nethttp/openapi.go index 9d946e2..4bfef3a 100644 --- a/nethttp/openapi.go +++ b/nethttp/openapi.go @@ -3,14 +3,15 @@ package nethttp import ( "net/http" + "github.com/swaggest/fchi" "github.com/swaggest/openapi-go/openapi3" "github.com/swaggest/rest" "github.com/swaggest/rest/openapi" ) // OpenAPIMiddleware reads info and adds validation to handler. -func OpenAPIMiddleware(s *openapi.Collector) func(http.Handler) http.Handler { - return func(h http.Handler) http.Handler { +func OpenAPIMiddleware(s *openapi.Collector) func(fchi.Handler) fchi.Handler { + return func(h fchi.Handler) fchi.Handler { var ( withRoute rest.HandlerWithRoute handler *Handler @@ -41,7 +42,7 @@ func SecurityMiddleware( name string, scheme openapi3.SecurityScheme, options ...func(*MiddlewareConfig), -) func(http.Handler) http.Handler { +) func(fchi.Handler) fchi.Handler { c.Reflector().SpecEns().ComponentsEns().SecuritySchemesEns().WithMapOfSecuritySchemeOrRefValuesItem( name, openapi3.SecuritySchemeOrRef{ @@ -63,7 +64,7 @@ func HTTPBasicSecurityMiddleware( c *openapi.Collector, name, description string, options ...func(*MiddlewareConfig), -) func(http.Handler) http.Handler { +) func(fchi.Handler) fchi.Handler { hss := openapi3.HTTPSecurityScheme{} hss.WithScheme("basic") @@ -82,7 +83,7 @@ func HTTPBearerSecurityMiddleware( c *openapi.Collector, name, description, bearerFormat string, options ...func(*MiddlewareConfig), -) func(http.Handler) http.Handler { +) func(fchi.Handler) fchi.Handler { hss := openapi3.HTTPSecurityScheme{} hss.WithScheme("bearer") @@ -104,8 +105,8 @@ func HTTPBearerSecurityMiddleware( func AnnotateOpenAPI( s *openapi.Collector, setup ...func(op *openapi3.Operation) error, -) func(http.Handler) http.Handler { - return func(next http.Handler) http.Handler { +) func(fchi.Handler) fchi.Handler { + return func(next fchi.Handler) fchi.Handler { var withRoute rest.HandlerWithRoute if HandlerAs(next, &withRoute) { @@ -137,7 +138,7 @@ type MiddlewareConfig struct { ResponseStatus int } -func securityMiddleware(s *openapi.Collector, name string, cfg MiddlewareConfig) func(http.Handler) http.Handler { +func securityMiddleware(s *openapi.Collector, name string, cfg MiddlewareConfig) func(fchi.Handler) fchi.Handler { return AnnotateOpenAPI(s, func(op *openapi3.Operation) error { op.Security = append(op.Security, map[string][]string{name: {}}) diff --git a/nethttp/options.go b/nethttp/options.go index f2fcf2c..d74cb95 100644 --- a/nethttp/options.go +++ b/nethttp/options.go @@ -1,17 +1,17 @@ package nethttp import ( - "net/http" "reflect" + "github.com/swaggest/fchi" "github.com/swaggest/openapi-go/openapi3" "github.com/swaggest/refl" "github.com/swaggest/rest" ) // OptionsMiddleware applies options to encountered nethttp.Handler. -func OptionsMiddleware(options ...func(h *Handler)) func(h http.Handler) http.Handler { - return func(h http.Handler) http.Handler { +func OptionsMiddleware(options ...func(h *Handler)) func(h fchi.Handler) fchi.Handler { + return func(h fchi.Handler) fchi.Handler { var rh *Handler if HandlerAs(h, &rh) { diff --git a/nethttp/usecase.go b/nethttp/usecase.go index 2999cb0..389bbe3 100644 --- a/nethttp/usecase.go +++ b/nethttp/usecase.go @@ -2,14 +2,14 @@ package nethttp import ( "context" - "net/http" + "github.com/swaggest/fchi" "github.com/swaggest/usecase" ) // UseCaseMiddlewares applies use case middlewares to Handler. -func UseCaseMiddlewares(mw ...usecase.Middleware) func(http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { +func UseCaseMiddlewares(mw ...usecase.Middleware) func(fchi.Handler) fchi.Handler { + return func(handler fchi.Handler) fchi.Handler { var uh *Handler if !HandlerAs(handler, &uh) { return handler diff --git a/nethttp/wrap.go b/nethttp/wrap.go index 9772260..8cc5f50 100644 --- a/nethttp/wrap.go +++ b/nethttp/wrap.go @@ -1,17 +1,18 @@ package nethttp import ( - "net/http" "reflect" "runtime" + + "github.com/swaggest/fchi" ) -// WrapHandler wraps http.Handler with an unwrappable middleware. +// WrapHandler wraps fchi.Handler with an unwrappable middleware. // // Wrapping order is reversed, e.g. if you call WrapHandler(h, mw1, mw2, mw3) middlewares will be // invoked in order of mw1(mw2(mw3(h))), mw3 first and mw1 last. So that request processing is first // affected by mw1. -func WrapHandler(h http.Handler, mw ...func(http.Handler) http.Handler) http.Handler { +func WrapHandler(h fchi.Handler, mw ...func(fchi.Handler) fchi.Handler) fchi.Handler { for i := len(mw) - 1; i >= 0; i-- { w := mw[i](h) if w == nil { @@ -28,15 +29,15 @@ func WrapHandler(h http.Handler, mw ...func(http.Handler) http.Handler) http.Han return h } -// HandlerAs finds the first http.Handler in http.Handler's chain that matches target, and if so, sets -// target to that http.Handler value and returns true. +// HandlerAs finds the first fchi.Handler in fchi.Handler's chain that matches target, and if so, sets +// target to that fchi.Handler value and returns true. // -// An http.Handler matches target if the http.Handler's concrete value is assignable to the value +// An fchi.Handler matches target if the fchi.Handler's concrete value is assignable to the value // pointed to by target. // // HandlerAs will panic if target is not a non-nil pointer to either a type that implements -// http.Handler, or to any interface type. -func HandlerAs(handler http.Handler, target interface{}) bool { +// fchi.Handler, or to any interface type. +func HandlerAs(handler fchi.Handler, target interface{}) bool { if target == nil { panic("target cannot be nil") } @@ -49,7 +50,7 @@ func HandlerAs(handler http.Handler, target interface{}) bool { } if e := typ.Elem(); e.Kind() != reflect.Interface && !e.Implements(handlerType) { - panic("*target must be interface or implement http.Handler") + panic("*target must be interface or implement fchi.Handler") } targetType := typ.Elem() @@ -81,11 +82,11 @@ func HandlerAs(handler http.Handler, target interface{}) bool { return false } -var handlerType = reflect.TypeOf((*http.Handler)(nil)).Elem() +var handlerType = reflect.TypeOf((*fchi.Handler)(nil)).Elem() type wrappedHandler struct { - http.Handler - wrapped http.Handler + fchi.Handler + wrapped fchi.Handler mwName string } diff --git a/nethttp/wrap_test.go b/nethttp/wrap_test.go index 0a4e6b6..c34ca8e 100644 --- a/nethttp/wrap_test.go +++ b/nethttp/wrap_test.go @@ -1,50 +1,52 @@ package nethttp_test import ( - "net/http" + "context" "testing" "github.com/stretchr/testify/assert" + "github.com/swaggest/fchi" "github.com/swaggest/rest/nethttp" + "github.com/valyala/fasthttp" ) func TestWrapHandler(t *testing.T) { var flow []string h := nethttp.WrapHandler( - http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { flow = append(flow, "handler") }), - func(handler http.Handler) http.Handler { + func(handler fchi.Handler) fchi.Handler { flow = append(flow, "mw1 registered") - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { flow = append(flow, "mw1 before") - handler.ServeHTTP(writer, request) + handler.ServeHTTP(ctx, rc) flow = append(flow, "mw1 after") }) }, - func(handler http.Handler) http.Handler { + func(handler fchi.Handler) fchi.Handler { flow = append(flow, "mw2 registered") - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { flow = append(flow, "mw2 before") - handler.ServeHTTP(writer, request) + handler.ServeHTTP(ctx, rc) flow = append(flow, "mw2 after") }) }, - func(handler http.Handler) http.Handler { + func(handler fchi.Handler) fchi.Handler { flow = append(flow, "mw3 registered") - return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { flow = append(flow, "mw3 before") - handler.ServeHTTP(writer, request) + handler.ServeHTTP(ctx, rc) flow = append(flow, "mw3 after") }) }, ) - h.ServeHTTP(nil, nil) + h.ServeHTTP(context.Background(), nil) assert.Equal(t, []string{ "mw3 registered", "mw2 registered", "mw1 registered", diff --git a/openapi/collector.go b/openapi/collector.go index 41f6a40..388548a 100644 --- a/openapi/collector.go +++ b/openapi/collector.go @@ -10,10 +10,12 @@ import ( "strconv" "sync" + "github.com/swaggest/fchi" "github.com/swaggest/jsonschema-go" "github.com/swaggest/openapi-go/openapi3" "github.com/swaggest/rest" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) // Collector extracts OpenAPI documentation from HTTP handler and underlying use case interactor. @@ -506,19 +508,19 @@ func (c *Collector) ProvideResponseJSONSchemas( return nil } -func (c *Collector) ServeHTTP(rw http.ResponseWriter, _ *http.Request) { +func (c *Collector) ServeHTTP(_ context.Context, rc *fasthttp.RequestCtx) { c.mu.Lock() defer c.mu.Unlock() document, err := json.MarshalIndent(c.Reflector().Spec, "", " ") if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + fchi.Error(rc, err.Error(), http.StatusInternalServerError) } - rw.Header().Set("Content-Type", "application/json; charset=utf8") + rc.Response.Header.Set("Content-Type", "application/json; charset=utf8") - _, err = rw.Write(document) + _, err = rc.Write(document) if err != nil { - http.Error(rw, err.Error(), http.StatusInternalServerError) + fchi.Error(rc, err.Error(), http.StatusInternalServerError) } } diff --git a/openapi/collector_test.go b/openapi/collector_test.go index db1dcf6..8a470e0 100644 --- a/openapi/collector_test.go +++ b/openapi/collector_test.go @@ -6,7 +6,6 @@ import ( "errors" "mime/multipart" "net/http" - "net/http/httptest" "testing" "time" @@ -20,6 +19,7 @@ import ( "github.com/swaggest/rest/openapi" "github.com/swaggest/usecase" "github.com/swaggest/usecase/status" + "github.com/valyala/fasthttp" ) var _ rest.JSONSchemaValidator = validatorMock{} @@ -100,10 +100,10 @@ func TestCollector_Collect(t *testing.T) { j, err := json.MarshalIndent(c.Reflector().Spec, "", " ") require.NoError(t, err) - rw := httptest.NewRecorder() - c.ServeHTTP(rw, nil) + rc := &fasthttp.RequestCtx{} + c.ServeHTTP(rc, rc) - assertjson.Equal(t, j, rw.Body.Bytes()) + assertjson.Equal(t, j, rc.Response.Body()) val := validatorMock{ AddSchemaFunc: func(in rest.ParamIn, name string, schemaData []byte, required bool) error { diff --git a/request/b2s_safe.go b/request/b2s_safe.go new file mode 100644 index 0000000..76f6ce0 --- /dev/null +++ b/request/b2s_safe.go @@ -0,0 +1,8 @@ +//go:build appengine +// +build appengine + +package request + +func b2s(b []byte) string { + return string(b) +} diff --git a/request/b2s_unsafe.go b/request/b2s_unsafe.go new file mode 100644 index 0000000..3caa2ee --- /dev/null +++ b/request/b2s_unsafe.go @@ -0,0 +1,66 @@ +//go:build !appengine +// +build !appengine + +package request + +import ( + "reflect" + "unsafe" +) + +/* +Copyright (c) 2009 The Go Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +*/ + +// b2s returns a string that refers to the data backing the slice s. +// +// The caller must ensure that the contents of the slice are never again +// mutated, and that its memory either is managed by the Go garbage collector or +// remains valid for the remainder of this process's lifetime. +// +// Programs that use b2s should be tested under the race detector to flag +// erroneous mutations. +// +// Programs that have been adequately tested and shown to be safe may be +// recompiled with the "unsafe" tag to significantly reduce the overhead of this +// function, at the cost of reduced safety checks. Programs built under the race +// detector always have safety checks enabled, even when the "unsafe" tag is +// set. +// +// Copied from https://github.com/bcmills/unsafeslice/blob/v0.2.0/unsafeslice.go#L143. +func b2s(b []byte) string { + p := unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&b)).Data) //nolint:gosec + + var s string + hdr := (*reflect.StringHeader)(unsafe.Pointer(&s)) // nolint:gosec + hdr.Data = uintptr(p) + hdr.Len = len(b) + + return s +} diff --git a/request/decoder.go b/request/decoder.go index 0494aed..b051039 100644 --- a/request/decoder.go +++ b/request/decoder.go @@ -1,13 +1,13 @@ package request import ( - "net/http" "net/url" - "strings" + "sync" "github.com/swaggest/form/v5" "github.com/swaggest/rest" "github.com/swaggest/rest/nethttp" + "github.com/valyala/fasthttp" ) type ( @@ -15,11 +15,11 @@ type ( // // Implement this interface on a pointer to your input structure to disable automatic request mapping. Loader interface { - LoadFromHTTPRequest(r *http.Request) error + LoadFromFastHTTPRequest(rc *fasthttp.RequestCtx) error } - decoderFunc func(r *http.Request) (url.Values, error) - valueDecoderFunc func(r *http.Request, v interface{}, validator rest.Validator) error + decoderFunc func(rc *fasthttp.RequestCtx, v url.Values) error + valueDecoderFunc func(rc *fasthttp.RequestCtx, v interface{}, validator rest.Validator) error ) func decodeValidate(d *form.Decoder, v interface{}, p url.Values, in rest.ParamIn, val rest.Validator) error { @@ -49,9 +49,24 @@ func decodeValidate(d *form.Decoder, v interface{}, p url.Values, in rest.ParamI return val.ValidateData(in, goValues) } +var valuesPool = &sync.Pool{ + New: func() interface{} { + return make(url.Values) + }, +} + func makeDecoder(in rest.ParamIn, formDecoder *form.Decoder, decoderFunc decoderFunc) valueDecoderFunc { - return func(r *http.Request, v interface{}, validator rest.Validator) error { - values, err := decoderFunc(r) + return func(rc *fasthttp.RequestCtx, v interface{}, validator rest.Validator) error { + values := valuesPool.Get().(url.Values) // nolint:errcheck + for k := range values { + delete(values, k) + } + + defer func() { + valuesPool.Put(values) + }() + + err := decoderFunc(rc, values) if err != nil { return err } @@ -73,13 +88,13 @@ type decoder struct { var _ nethttp.RequestDecoder = &decoder{} // Decode populates and validates input with data from http request. -func (d *decoder) Decode(r *http.Request, input interface{}, validator rest.Validator) error { +func (d *decoder) Decode(rc *fasthttp.RequestCtx, input interface{}, validator rest.Validator) error { if i, ok := input.(Loader); ok { - return i.LoadFromHTTPRequest(r) + return i.LoadFromFastHTTPRequest(rc) } for i, decode := range d.decoders { - err := decode(r, input, validator) + err := decode(rc, input, validator) if err != nil { // nolint:errorlint // Error is not wrapped, type assertion is more performant. if de, ok := err.(form.DecodeErrors); ok { @@ -98,40 +113,48 @@ func (d *decoder) Decode(r *http.Request, input interface{}, validator rest.Vali return nil } -const defaultMaxMemory = 32 << 20 // 32 MB +func formDataToURLValues(rc *fasthttp.RequestCtx, params url.Values) error { + args := rc.Request.PostArgs() -func formDataToURLValues(r *http.Request) (url.Values, error) { - if r.ContentLength == 0 { - return nil, nil + if args.Len() == 0 { + return nil } - if strings.HasPrefix(r.Header.Get("Content-Type"), "multipart/form-data") { - err := r.ParseMultipartForm(defaultMaxMemory) - if err != nil { - return nil, err - } - } else if err := r.ParseForm(); err != nil { - return nil, err - } + args.VisitAll(func(key, value []byte) { + k := b2s(key) + v := params[k] + params[k] = append(v, b2s(value)) + }) - return r.PostForm, nil + return nil } -func headerToURLValues(r *http.Request) (url.Values, error) { - return url.Values(r.Header), nil -} +func headerToURLValues(rc *fasthttp.RequestCtx, params url.Values) error { + rc.Request.Header.VisitAll(func(key, value []byte) { + k := b2s(key) + v := params[k] + params[k] = append(v, b2s(value)) + }) -func queryToURLValues(r *http.Request) (url.Values, error) { - return r.URL.Query(), nil + return nil } -func cookiesToURLValues(r *http.Request) (url.Values, error) { - cookies := r.Cookies() - params := make(url.Values, len(cookies)) +func queryToURLValues(rc *fasthttp.RequestCtx, params url.Values) error { + rc.Request.URI().QueryArgs().VisitAll(func(key, value []byte) { + k := b2s(key) + v := params[k] + params[k] = append(v, b2s(value)) + }) - for _, c := range cookies { - params[c.Name] = []string{c.Value} - } + return nil +} - return params, nil +func cookiesToURLValues(rc *fasthttp.RequestCtx, params url.Values) error { + rc.Request.Header.VisitAllCookie(func(key, value []byte) { + k := b2s(key) + v := params[k] + params[k] = append(v, b2s(value)) + }) + + return nil } diff --git a/request/decoder_test.go b/request/decoder_test.go index 9070d6e..1020772 100644 --- a/request/decoder_test.go +++ b/request/decoder_test.go @@ -1,11 +1,9 @@ package request_test import ( - "context" "fmt" "net/http" "net/url" - "strings" "testing" "time" @@ -16,9 +14,22 @@ import ( "github.com/swaggest/rest/jsonschema" "github.com/swaggest/rest/openapi" "github.com/swaggest/rest/request" + "github.com/valyala/fasthttp" ) +// BenchmarkDecoder_Decode-12 1410783 797.9 ns/op 866 B/op 10 allocs/op. +// BenchmarkDecoder_Decode-12 2104834 599.5 ns/op 65 B/op 6 allocs/op +// BenchmarkDecoder_Decode-12 1999123 568.2 ns/op 65 B/op 6 allocs/op + // BenchmarkDecoder_Decode-4 1314788 857 ns/op 448 B/op 4 allocs/op. +// --- net/http +// BenchmarkDecoder_Decode-16 2276893 453.3 ns/op 440 B/op 4 allocs/op. +// --- fasthttp +// BenchmarkDecoder_Decode-16 2615832 455.2 ns/op 65 B/op 6 allocs/op. +// unsafe b2s +// BenchmarkDecoder_Decode-16 2797910 403.6 ns/op 56 B/op 3 allocs/op. +// append v +// BenchmarkDecoder_Decode-16 2776867 429.0 ns/op 56 B/op 3 allocs/op. func BenchmarkDecoder_Decode(b *testing.B) { df := request.NewDecoderFactory() @@ -27,10 +38,10 @@ func BenchmarkDecoder_Decode(b *testing.B) { H int `header:"X-H"` } - r, err := http.NewRequest(http.MethodGet, "/?q=abc", nil) - require.NoError(b, err) + rc := fasthttp.RequestCtx{} - r.Header.Set("X-H", "123") + rc.Request.SetRequestURI("/?q=abc") + rc.Request.Header.Set("X-H", "123") d := df.MakeDecoder(http.MethodGet, new(req), nil) @@ -40,7 +51,7 @@ func BenchmarkDecoder_Decode(b *testing.B) { for i := 0; i < b.N; i++ { rr := new(req) - err = d.Decode(r, rr, nil) + err := d.Decode(&rc, rr, nil) if err != nil { b.Fail() } @@ -74,31 +85,26 @@ type reqJSONTest struct { } func TestDecoder_Decode(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=abc", - strings.NewReader(url.Values{"inFormData": []string{"def"}}.Encode())) - assert.NoError(t, err) - - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("X-In-hEaDeR", "123") - - c := http.Cookie{ - Name: "in_cookie", - Value: "jkl", - } - - req.AddCookie(&c) + rc := &fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/?in_query=abc") + rc.Request.SetBody([]byte(url.Values{"inFormData": []string{"def"}}.Encode())) + rc.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rc.Request.Header.Set("X-In-hEaDeR", "123") + rc.Request.Header.SetCookie("in_cookie", "jkl") df := request.NewDecoderFactory() - df.SetDecoderFunc(rest.ParamInPath, func(r *http.Request) (url.Values, error) { - assert.Equal(t, req, r) + df.SetDecoderFunc(rest.ParamInPath, func(r *fasthttp.RequestCtx, params url.Values) error { + assert.Equal(t, rc, r) + params["in_path"] = []string{"mno"} - return url.Values{"in_path": []string{"mno"}}, nil + return nil }) input := new(reqTest) dec := df.MakeDecoder(http.MethodPost, input, nil) - assert.NoError(t, dec.Decode(req, input, nil)) + assert.NoError(t, dec.Decode(rc, input, nil)) assert.Equal(t, "abc", input.Query) assert.Equal(t, "def", input.FormData) assert.Equal(t, 123, input.Header) @@ -116,7 +122,7 @@ func TestDecoder_Decode(t *testing.T) { rest.ParamInFormData: {"FormData": "inFormData"}, }) - assert.NoError(t, decCM.Decode(req, inputCM, nil)) + assert.NoError(t, decCM.Decode(rc, inputCM, nil)) assert.Equal(t, "abc", inputCM.Query) assert.Equal(t, "def", inputCM.FormData) assert.Equal(t, 123, inputCM.Header) @@ -126,23 +132,19 @@ func TestDecoder_Decode(t *testing.T) { // BenchmarkDecoderFunc_Decode-4 440503 2525 ns/op 1513 B/op 12 allocs/op. func BenchmarkDecoderFunc_Decode(b *testing.B) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=abc", - strings.NewReader(url.Values{"inFormData": []string{"def"}}.Encode())) - assert.NoError(b, err) - - req.Header.Add("Content-Type", "application/x-www-form-urlencoded") - req.Header.Set("X-In-Header", "123") - - c := http.Cookie{ - Name: "in_cookie", - Value: "jkl", - } - - req.AddCookie(&c) + rc := &fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/?in_query=abc") + rc.Request.SetBody([]byte(url.Values{"inFormData": []string{"def"}}.Encode())) + rc.Request.Header.Set("Content-Type", "application/x-www-form-urlencoded") + rc.Request.Header.Set("X-In-hEaDeR", "123") + rc.Request.Header.SetCookie("in_cookie", "jkl") df := request.NewDecoderFactory() - df.SetDecoderFunc(rest.ParamInPath, func(r *http.Request) (url.Values, error) { - return url.Values{"in_path": []string{"mno"}}, nil + df.SetDecoderFunc(rest.ParamInPath, func(r *fasthttp.RequestCtx, params url.Values) error { + params["in_path"] = []string{"mno"} + + return nil }) dec := df.MakeDecoder(http.MethodPost, new(reqTest), nil) @@ -153,7 +155,7 @@ func BenchmarkDecoderFunc_Decode(b *testing.B) { for i := 0; i < b.N; i++ { input := new(reqTest) - err := dec.Decode(req, input, nil) + err := dec.Decode(rc, input, nil) if err != nil { b.Fail() } @@ -165,49 +167,53 @@ func BenchmarkDecoderFunc_Decode(b *testing.B) { } func TestDecoder_Decode_required(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/", nil) - assert.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/") input := new(reqTest) dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodPost, input, nil) - err = dec.Decode(req, input, validator) + err := dec.Decode(rc, input, validator) assert.Equal(t, rest.ValidationErrors{"header:X-In-HeAdEr": []string{"missing value"}}, err) } func TestDecoder_Decode_json(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=cba", - strings.NewReader(`{"bodyOne":"abc", "bodyTwo": [1,2,3]}`)) - assert.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/?in_query=cba") + rc.Request.SetBody([]byte(`{"bodyOne":"abc", "bodyTwo": [1,2,3]}`)) input := new(reqJSONTest) dec := request.NewDecoderFactory().MakeDecoder(http.MethodPost, input, nil) validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodPost, input, nil) - assert.NoError(t, dec.Decode(req, input, validator)) + assert.NoError(t, dec.Decode(rc, input, validator)) assert.Equal(t, "cba", input.Query) assert.Equal(t, "abc", input.BodyOne) assert.Equal(t, []int{1, 2, 3}, input.BodyTwo) - req, err = http.NewRequestWithContext(context.Background(), http.MethodPost, "/", - strings.NewReader(`{"bodyTwo":[1]}`)) - assert.NoError(t, err) + rc = &fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/") + rc.Request.SetBody([]byte(`{"bodyTwo":[1]}`)) - err = dec.Decode(req, input, validator) + err := dec.Decode(rc, input, validator) assert.Equal(t, rest.ValidationErrors{"body": []string{ "#: validation failed", "#: missing properties: \"bodyOne\"", "#/bodyTwo: minimum 2 items allowed, but found 1 items", }}, err) - req, err = http.NewRequestWithContext(context.Background(), http.MethodPost, "/", - strings.NewReader(`{"bodyOne":"abc", "bodyTwo":[1]}`)) - assert.NoError(t, err) + rc = &fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/") + rc.Request.SetBody([]byte(`{"bodyOne":"abc", "bodyTwo":[1]}`)) - err = dec.Decode(req, input, validator) + err = dec.Decode(rc, input, validator) assert.Error(t, err) assert.Equal(t, rest.ValidationErrors{"body": []string{"#/bodyTwo: minimum 2 items allowed, but found 1 items"}}, err) } @@ -223,35 +229,32 @@ func BenchmarkDecoder_Decode_json(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - req, err := http.NewRequestWithContext(context.Background(), http.MethodPost, "/?in_query=cba", - strings.NewReader(`{"bodyOne":"abc", "bodyTwo": [1,2,3]}`)) - if err != nil { - b.Fail() - } + rc := fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/?in_query=cba") + rc.Request.SetBody([]byte(`{"bodyOne":"abc", "bodyTwo": [1,2,3]}`)) - err = dec.Decode(req, input, validator) + err := dec.Decode(&rc, input, validator) if err != nil { b.Fail() } - req, err = http.NewRequestWithContext(context.Background(), http.MethodPost, "/", - strings.NewReader(`{"bodyTwo":[1]}`)) - if err != nil { - b.Fail() - } + rc = fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/") + rc.Request.SetBody([]byte(`{"bodyTwo":[1]}`)) - err = dec.Decode(req, input, validator) + err = dec.Decode(&rc, input, validator) if err == nil { b.Fail() } - req, err = http.NewRequestWithContext(context.Background(), http.MethodPost, "/", - strings.NewReader(`{"bodyOne":"abc", "bodyTwo":[1]}`)) - if err != nil { - b.Fail() - } + rc = fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/") + rc.Request.SetBody([]byte(`{"bodyOne":"abc", "bodyTwo":[1]}`)) - err = dec.Decode(req, input, validator) + err = dec.Decode(&rc, input, validator) if err == nil { b.Fail() } @@ -259,9 +262,7 @@ func BenchmarkDecoder_Decode_json(b *testing.B) { } func TestDecoder_Decode_queryObject(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?in_query[1]=1.0&in_query[2]=2.1&in_query[3]=0", nil) - assert.NoError(t, err) + rc := req("/?in_query[1]=1.0&in_query[2]=2.1&in_query[3]=0") df := request.NewDecoderFactory() @@ -270,14 +271,12 @@ func TestDecoder_Decode_queryObject(t *testing.T) { }) dec := df.MakeDecoder(http.MethodGet, input, nil) - assert.NoError(t, dec.Decode(req, input, nil)) + assert.NoError(t, dec.Decode(rc, input, nil)) assert.Equal(t, map[int]float64{1: 1, 2: 2.1, 3: 0}, input.InQuery) - req, err = http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?in_query[1]=1.0&in_query[2]=2.1&in_query[c]=0", nil) - assert.NoError(t, err) + rc = req("/?in_query[1]=1.0&in_query[2]=2.1&in_query[c]=0") - err = dec.Decode(req, input, nil) + err := dec.Decode(rc, input, nil) assert.Error(t, err) assert.Equal(t, rest.RequestErrors{"query:in_query": []string{ "#: invalid integer value 'c' type 'int' namespace 'in_query'", @@ -286,13 +285,9 @@ func TestDecoder_Decode_queryObject(t *testing.T) { // BenchmarkDecoder_Decode_queryObject-4 170670 6104 ns/op 2000 B/op 36 allocs/op. func BenchmarkDecoder_Decode_queryObject(b *testing.B) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?in_query[1]=1.0&in_query[2]=2.1&in_query[3]=0", nil) - assert.NoError(b, err) + rc := req("/?in_query[1]=1.0&in_query[2]=2.1&in_query[3]=0") - req2, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?in_query[1]=1.0&in_query[2]=2.1&in_query[c]=0", nil) - assert.NoError(b, err) + rc2 := req("/?in_query[1]=1.0&in_query[2]=2.1&in_query[c]=0") df := request.NewDecoderFactory() @@ -305,12 +300,12 @@ func BenchmarkDecoder_Decode_queryObject(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { - err = dec.Decode(req, input, nil) + err := dec.Decode(rc, input, nil) if err != nil { b.Fail() } - err = dec.Decode(req2, input, nil) + err = dec.Decode(rc2, input, nil) if err == nil { b.Fail() } @@ -328,11 +323,10 @@ func TestDecoder_Decode_jsonParam(t *testing.T) { df := request.NewDecoderFactory() dec := df.MakeDecoder(http.MethodGet, new(inp), nil) - req, err := http.NewRequest(http.MethodGet, "/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D", nil) - require.NoError(t, err) + rc := req("/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D") v := new(inp) - require.NoError(t, dec.Decode(req, v, nil)) + require.NoError(t, dec.Decode(rc, v, nil)) assert.Equal(t, 123, v.Filter.A) assert.Equal(t, "abc", v.Filter.B) @@ -350,8 +344,7 @@ func BenchmarkDecoder_Decode_jsonParam(b *testing.B) { df := request.NewDecoderFactory() dec := df.MakeDecoder(http.MethodGet, new(inp), nil) - req, err := http.NewRequest(http.MethodGet, "/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D", nil) - require.NoError(b, err) + rc := req("/?filter=%7B%22a%22%3A123%2C%22b%22%3A%22abc%22%7D") v := new(inp) @@ -359,7 +352,7 @@ func BenchmarkDecoder_Decode_jsonParam(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - err := dec.Decode(req, v, nil) + err := dec.Decode(rc, v, nil) if err != nil { b.Fail() } @@ -369,19 +362,18 @@ func BenchmarkDecoder_Decode_jsonParam(b *testing.B) { } func TestDecoder_Decode_error(t *testing.T) { - type req struct { + type reqs struct { Q int `default:"100" query:"q"` } df := request.NewDecoderFactory() df.ApplyDefaults = true - d := df.MakeDecoder(http.MethodGet, new(req), nil) - r, err := http.NewRequest(http.MethodGet, "?q=undefined", nil) - require.NoError(t, err) + d := df.MakeDecoder(http.MethodGet, new(reqs), nil) + rc := req("?q=undefined") - in := new(req) - err = d.Decode(r, in, nil) + in := new(reqs) + err := d.Decode(rc, in, nil) assert.EqualError(t, err, "bad request") assert.Equal(t, rest.RequestErrors{"query:q": []string{ "#: invalid integer value 'undefined' type 'int' namespace 'q'", @@ -389,9 +381,7 @@ func TestDecoder_Decode_error(t *testing.T) { } func TestDecoder_Decode_dateTime(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?time=2020-04-04T00:00:00Z&date=2020-04-04", nil) - assert.NoError(t, err) + rc := req("/?time=2020-04-04T00:00:00Z&date=2020-04-04") type reqTest struct { Time time.Time `query:"time"` @@ -403,7 +393,7 @@ func TestDecoder_Decode_dateTime(t *testing.T) { validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, input, nil) - err = dec.Decode(req, input, validator) + err := dec.Decode(rc, input, validator) assert.NoError(t, err, fmt.Sprintf("%v", err)) } @@ -411,23 +401,21 @@ type inputWithLoader struct { Time time.Time `query:"time"` Date jschema.Date `query:"date"` - load func(r *http.Request) error + load func(rc *fasthttp.RequestCtx) error } -func (i *inputWithLoader) LoadFromHTTPRequest(r *http.Request) error { +func (i *inputWithLoader) LoadFromFastHTTPRequest(r *fasthttp.RequestCtx) error { return i.load(r) } func TestDecoder_Decode_manualLoader(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?time=2020-04-04T00:00:00Z&date=2020-04-04", nil) - assert.NoError(t, err) + rc := req("/?time=2020-04-04T00:00:00Z&date=2020-04-04") input := new(inputWithLoader) loadTriggered := false - input.load = func(r *http.Request) error { - assert.Equal(t, "/?time=2020-04-04T00:00:00Z&date=2020-04-04", r.URL.String()) + input.load = func(rc *fasthttp.RequestCtx) error { + assert.Equal(t, "/?time=2020-04-04T00:00:00Z&date=2020-04-04", string(rc.Request.RequestURI())) loadTriggered = true @@ -438,16 +426,14 @@ func TestDecoder_Decode_manualLoader(t *testing.T) { validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, input, nil) - err = dec.Decode(req, input, validator) + err := dec.Decode(rc, input, validator) assert.NoError(t, err, fmt.Sprintf("%v", err)) assert.True(t, loadTriggered) assert.True(t, input.Time.IsZero()) } func TestDecoder_Decode_unknownParams(t *testing.T) { - req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, - "/?foo=1&bar=1&bar=2&baz&quux=123", nil) - assert.NoError(t, err) + rc := req("/?foo=1&bar=1&bar=2&baz&quux=123") type input struct { Foo string `query:"foo"` @@ -463,7 +449,32 @@ func TestDecoder_Decode_unknownParams(t *testing.T) { validator := jsonschema.NewFactory(&openapi.Collector{}, &openapi.Collector{}). MakeRequestValidator(http.MethodGet, in, nil) - err = dec.Decode(req, in, validator) + err := dec.Decode(rc, in, validator) assert.Equal(t, rest.ValidationErrors{"query:quux": []string{"unknown parameter with value 123"}}, err, fmt.Sprintf("%#v", err)) } + +func req(url string) *fasthttp.RequestCtx { + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI(url) + + return rc +} + +func TestDecoder_Decode_multi(t *testing.T) { + rc := req("/?foo=1&foo=2&foo=3") + + type input struct { + Foo1 int `query:"foo"` + Foo2 []int `query:"foo"` + } + + in := new(input) + dec := request.NewDecoderFactory().MakeDecoder(http.MethodGet, in, nil) + + err := dec.Decode(rc, in, nil) + assert.NoError(t, err) + + assert.Equal(t, 1, in.Foo1) + assert.Equal(t, []int{1, 2, 3}, in.Foo2) +} diff --git a/request/factory.go b/request/factory.go index 7205882..70654a4 100644 --- a/request/factory.go +++ b/request/factory.go @@ -15,6 +15,7 @@ import ( "github.com/swaggest/refl" "github.com/swaggest/rest" "github.com/swaggest/rest/nethttp" + "github.com/valyala/fasthttp" ) var _ DecoderMaker = &DecoderFactory{} @@ -68,7 +69,7 @@ func NewDecoderFactory() *DecoderFactory { } // SetDecoderFunc adds custom decoder function for values of particular field tag name. -func (df *DecoderFactory) SetDecoderFunc(tagName rest.ParamIn, d func(r *http.Request) (url.Values, error)) { +func (df *DecoderFactory) SetDecoderFunc(tagName rest.ParamIn, d func(rc *fasthttp.RequestCtx, v url.Values) error) { if df.decoderFunctions == nil { df.decoderFunctions = make(map[rest.ParamIn]decoderFunc) } @@ -238,7 +239,7 @@ func (df *DecoderFactory) makeDefaultDecoder(input interface{}, m *decoder) { dec := df.defaultValDecoder - m.decoders = append(m.decoders, func(r *http.Request, v interface{}, validator rest.Validator) error { + m.decoders = append(m.decoders, func(rc *fasthttp.RequestCtx, v interface{}, validator rest.Validator) error { return dec.Decode(v, defaults) }) m.in = append(m.in, defaultTag) diff --git a/request/factory_test.go b/request/factory_test.go index d7e05db..eb48c74 100644 --- a/request/factory_test.go +++ b/request/factory_test.go @@ -11,49 +11,48 @@ import ( "github.com/stretchr/testify/require" "github.com/swaggest/rest" "github.com/swaggest/rest/request" + "github.com/valyala/fasthttp" ) func TestDecoderFactory_SetDecoderFunc(t *testing.T) { df := request.NewDecoderFactory() - df.SetDecoderFunc("jwt", func(r *http.Request) (url.Values, error) { - ah := r.Header.Get("Authorization") + df.SetDecoderFunc("jwt", func(rc *fasthttp.RequestCtx, params url.Values) error { + ah := string(rc.Request.Header.Peek("Authorization")) if ah == "" || len(ah) < 8 || strings.ToLower(ah[0:7]) != "bearer " { - return nil, nil + return nil } var m map[string]json.RawMessage err := json.Unmarshal([]byte(ah[7:]), &m) if err != nil { - return nil, err + return err } - res := make(url.Values) for k, v := range m { if len(v) > 2 && v[0] == '"' && v[len(v)-1] == '"' { v = v[1 : len(v)-1] } - res[k] = []string{string(v)} + params[k] = []string{string(v)} } - return res, err + return err }) - type req struct { + type reqs struct { Q string `query:"q"` Name string `jwt:"name"` Iat int `jwt:"iat"` Sub string `jwt:"sub"` } - r, err := http.NewRequest(http.MethodGet, "/?q=abc", nil) - require.NoError(t, err) + rc := req("/?q=abc") - r.Header.Add("Authorization", `Bearer {"sub":"1234567890","name":"John Doe","iat": 1516239022}`) + rc.Request.Header.Add("Authorization", `Bearer {"sub":"1234567890","name":"John Doe","iat": 1516239022}`) - d := df.MakeDecoder(http.MethodGet, new(req), nil) + d := df.MakeDecoder(http.MethodGet, new(reqs), nil) - rr := new(req) - require.NoError(t, d.Decode(r, rr, nil)) + rr := new(reqs) + require.NoError(t, d.Decode(rc, rr, nil)) assert.Equal(t, "John Doe", rr.Name) assert.Equal(t, "1234567890", rr.Sub) @@ -64,10 +63,10 @@ func TestDecoderFactory_SetDecoderFunc(t *testing.T) { // BenchmarkDecoderFactory_SetDecoderFunc-4 577378 1994 ns/op 1024 B/op 16 allocs/op. func BenchmarkDecoderFactory_SetDecoderFunc(b *testing.B) { df := request.NewDecoderFactory() - df.SetDecoderFunc("jwt", func(r *http.Request) (url.Values, error) { - ah := r.Header.Get("Authorization") + df.SetDecoderFunc("jwt", func(r *fasthttp.RequestCtx, params url.Values) error { + ah := string(r.Request.Header.Peek("Authorization")) if ah == "" || len(ah) < 8 || strings.ToLower(ah[0:7]) != "bearer " { - return nil, nil + return nil } // Pretending json.Unmarshal has passed to improve benchmark relevancy. @@ -77,38 +76,36 @@ func BenchmarkDecoderFactory_SetDecoderFunc(b *testing.B) { "iat": []byte(`1516239022`), } - res := make(url.Values) for k, v := range m { if len(v) > 2 && v[0] == '"' && v[len(v)-1] == '"' { v = v[1 : len(v)-1] } - res[k] = []string{string(v)} + params[k] = []string{string(v)} } - return res, nil + return nil }) - type req struct { + type reqs struct { Q string `query:"q"` Name string `jwt:"name"` Iat int `jwt:"iat"` Sub string `jwt:"sub"` } - r, err := http.NewRequest(http.MethodGet, "/?q=abc", nil) - require.NoError(b, err) + rc := req("/?q=abc") - r.Header.Add("Authorization", `Bearer {"sub":"1234567890","name":"John Doe","iat": 1516239022}`) + rc.Request.Header.Add("Authorization", `Bearer {"sub":"1234567890","name":"John Doe","iat": 1516239022}`) - d := df.MakeDecoder(http.MethodGet, new(req), nil) + d := df.MakeDecoder(http.MethodGet, new(reqs), nil) b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - rr := new(req) + rr := new(reqs) - err = d.Decode(r, rr, nil) + err := d.Decode(rc, rr, nil) if err != nil { b.Fail() } @@ -128,24 +125,23 @@ func TestDecoderFactory_MakeDecoder_default(t *testing.T) { dec := df.MakeDecoder(http.MethodPost, new(MyInput), nil) assert.NotNil(t, dec) - req, err := http.NewRequest(http.MethodPost, "/", nil) - require.NoError(t, err) + rc := req("/") + rc.Request.Header.SetMethod(http.MethodPost) i := new(MyInput) - err = dec.Decode(req, i, nil) + err := dec.Decode(rc, i, nil) assert.NoError(t, err) assert.Equal(t, "foo", i.Name) assert.Equal(t, 123, i.ID) - req, err = http.NewRequest(http.MethodPost, "/?id=321", nil) - require.NoError(t, err) - - req.Header.Set("X-Name", "bar") + rc = req("/?id=321") + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.Header.Set("X-Name", "bar") i = new(MyInput) - err = dec.Decode(req, i, nil) + err = dec.Decode(rc, i, nil) assert.NoError(t, err) assert.Equal(t, "bar", i.Name) assert.Equal(t, 321, i.ID) @@ -186,24 +182,23 @@ func TestDecoderFactory_MakeDecoder_customMapping(t *testing.T) { dec := df.MakeDecoder(http.MethodPost, new(MyInput), customMapping) assert.NotNil(t, dec) - req, err := http.NewRequest(http.MethodPost, "/", nil) - require.NoError(t, err) + rc := req("/") + rc.Request.Header.SetMethod(http.MethodPost) i := new(MyInput) - err = dec.Decode(req, i, nil) + err := dec.Decode(rc, i, nil) assert.NoError(t, err) assert.Equal(t, "foo", i.Name) assert.Equal(t, 123, i.ID) - req, err = http.NewRequest(http.MethodPost, "/?id=321", nil) - require.NoError(t, err) - - req.Header.Set("X-Name", "bar") + rc = req("/?id=321") + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.Header.Set("X-Name", "bar") i = new(MyInput) - err = dec.Decode(req, i, nil) + err = dec.Decode(rc, i, nil) assert.NoError(t, err) assert.Equal(t, "bar", i.Name) assert.Equal(t, 321, i.ID) diff --git a/request/file.go b/request/file.go index 0bd2e38..6ba2423 100644 --- a/request/file.go +++ b/request/file.go @@ -8,6 +8,7 @@ import ( "reflect" "github.com/swaggest/rest" + "github.com/valyala/fasthttp" ) var ( @@ -17,13 +18,13 @@ var ( multipartFileHeadersType = reflect.TypeOf(([]*multipart.FileHeader)(nil)) ) -func decodeFiles(r *http.Request, input interface{}, _ rest.Validator) error { +func decodeFiles(rc *fasthttp.RequestCtx, input interface{}, _ rest.Validator) error { v := reflect.ValueOf(input) - return decodeFilesInStruct(r, v) + return decodeFilesInStruct(rc, v) } -func decodeFilesInStruct(r *http.Request, v reflect.Value) error { +func decodeFilesInStruct(rc *fasthttp.RequestCtx, v reflect.Value) error { for v.Kind() == reflect.Ptr { v = v.Elem() } @@ -39,7 +40,7 @@ func decodeFilesInStruct(r *http.Request, v reflect.Value) error { if field.Type == multipartFileType || field.Type == multipartFileHeaderType || field.Type == multipartFilesType || field.Type == multipartFileHeadersType { - err := setFile(r, field, v.Field(i)) + err := setFile(rc, field, v.Field(i)) if err != nil { return err } @@ -48,7 +49,7 @@ func decodeFilesInStruct(r *http.Request, v reflect.Value) error { } if field.Anonymous { - if err := decodeFilesInStruct(r, v.Field(i)); err != nil { + if err := decodeFilesInStruct(rc, v.Field(i)); err != nil { return err } } @@ -57,7 +58,8 @@ func decodeFilesInStruct(r *http.Request, v reflect.Value) error { return nil } -func setFile(r *http.Request, field reflect.StructField, v reflect.Value) error { +// nolint:funlen // Maybe later. +func setFile(rc *fasthttp.RequestCtx, field reflect.StructField, v reflect.Value) error { name := "" if tag := field.Tag.Get(fileTag); tag != "" && tag != "-" { name = tag @@ -69,7 +71,7 @@ func setFile(r *http.Request, field reflect.StructField, v reflect.Value) error return nil } - file, header, err := r.FormFile(name) + header, err := rc.FormFile(name) if err != nil { if errors.Is(err, http.ErrMissingFile) { if field.Tag.Get("required") == "true" { @@ -81,6 +83,11 @@ func setFile(r *http.Request, field reflect.StructField, v reflect.Value) error } if field.Type == multipartFileType { + file, err := header.Open() + if err != nil { + return fmt.Errorf("failed to open file %q from request: %w", name, err) + } + v.Set(reflect.ValueOf(file)) } @@ -89,9 +96,14 @@ func setFile(r *http.Request, field reflect.StructField, v reflect.Value) error } if field.Type == multipartFilesType { - res := make([]multipart.File, 0, len(r.MultipartForm.File[name])) + mf, err := rc.MultipartForm() + if err != nil { + return fmt.Errorf("failed to get multipart form from request: %w", err) + } - for _, h := range r.MultipartForm.File[name] { + res := make([]multipart.File, 0, len(mf.File[name])) + + for _, h := range mf.File[name] { f, err := h.Open() if err != nil { return fmt.Errorf("failed to open uploaded file %s (%s): %w", name, h.Filename, err) @@ -104,7 +116,12 @@ func setFile(r *http.Request, field reflect.StructField, v reflect.Value) error } if field.Type == multipartFileHeadersType { - v.Set(reflect.ValueOf(r.MultipartForm.File[name])) + mf, err := rc.MultipartForm() + if err != nil { + return fmt.Errorf("failed to get multipart form from request: %w", err) + } + + v.Set(reflect.ValueOf(mf.File[name])) } return nil diff --git a/request/file_test.go b/request/file_test.go index 035101e..00aaa48 100644 --- a/request/file_test.go +++ b/request/file_test.go @@ -3,15 +3,15 @@ package request_test import ( "bytes" "context" + "errors" "io/ioutil" "mime/multipart" "net/http" "net/http/httptest" "testing" - "github.com/go-chi/chi/v5" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" "github.com/swaggest/rest" "github.com/swaggest/rest/chirouter" "github.com/swaggest/rest/jsonschema" @@ -34,7 +34,7 @@ type fileReqTest struct { } func TestMapper_Decode_fileUploadTag(t *testing.T) { - r := chirouter.NewWrapper(chi.NewRouter()) + r := chirouter.NewWrapper(fchi.NewRouter()) apiSchema := openapi.Collector{} decoderFactory := request.NewDecoderFactory() validatorFactory := jsonschema.NewFactory(&apiSchema, &apiSchema) @@ -66,8 +66,13 @@ func TestMapper_Decode_fileUploadTag(t *testing.T) { assert.NoError(t, in.Upload.Close()) assert.Equal(t, "Hello!", string(content)) - require.Len(t, in.Uploads, 2) - require.Len(t, in.UploadsHeaders, 2) + assert.Len(t, in.Uploads, 2) + assert.Len(t, in.UploadsHeaders, 2) + + if !assert.Len(t, in.Uploads, 2) || assert.Len(t, in.UploadsHeaders, 2) { + return errors.New("missing uploads") + } + assert.Equal(t, "my1.csv", in.UploadsHeaders[0].Filename) assert.Equal(t, int64(7), in.UploadsHeaders[0].Size) assert.Equal(t, "my2.csv", in.UploadsHeaders[1].Filename) @@ -89,7 +94,7 @@ func TestMapper_Decode_fileUploadTag(t *testing.T) { h := nethttp.NewHandler(u) r.Method(http.MethodPost, "/receive", h) - srv := httptest.NewServer(r) + srv := fchi.NewTestServer(r) defer srv.Close() var b bytes.Buffer @@ -119,7 +124,7 @@ func TestMapper_Decode_fileUploadTag(t *testing.T) { hreq.RequestURI = "" hreq.Header.Set("Content-Type", w.FormDataContentType()) - resp, err := srv.Client().Do(hreq) + resp, err := http.DefaultTransport.RoundTrip(hreq) assert.NoError(t, err) assert.NoError(t, resp.Body.Close()) } diff --git a/request/jsonbody.go b/request/jsonbody.go index d17d5b4..ff89bff 100644 --- a/request/jsonbody.go +++ b/request/jsonbody.go @@ -3,20 +3,14 @@ package request import ( "bytes" "encoding/json" + "errors" "fmt" "io" - "net/http" - "sync" "github.com/swaggest/rest" + "github.com/valyala/fasthttp" ) -var bufPool = sync.Pool{ - New: func() interface{} { - return bytes.NewBuffer(nil) - }, -} - func readJSON(rd io.Reader, v interface{}) error { d := json.NewDecoder(rd) @@ -24,32 +18,23 @@ func readJSON(rd io.Reader, v interface{}) error { } func decodeJSONBody(readJSON func(rd io.Reader, v interface{}) error) valueDecoderFunc { - return func(r *http.Request, input interface{}, validator rest.Validator) error { - if r.ContentLength == 0 { - return ErrMissingRequestBody + return func(rc *fasthttp.RequestCtx, input interface{}, validator rest.Validator) error { + if len(rc.Request.Body()) == 0 { + return errors.New("missing request body to decode json") } - contentType := r.Header.Get("Content-Type") - if contentType != "" { - if len(contentType) < 16 || contentType[0:16] != "application/json" { // allow 'application/json;charset=UTF-8' + contentType := rc.Request.Header.ContentType() + if len(contentType) > 0 { + if len(contentType) < 16 || !bytes.Equal(contentType[0:16], []byte("application/json")) { // allow 'application/json;charset=UTF-8' return fmt.Errorf("%w, received: %s", ErrJSONExpected, contentType) } } - var ( - rd io.Reader = r.Body - b *bytes.Buffer - ) + b := rc.Request.Body() validate := validator != nil && validator.HasConstraints(rest.ParamInBody) - if validate { - b = bufPool.Get().(*bytes.Buffer) // nolint:errcheck // bufPool is configured to provide *bytes.Buffer. - defer bufPool.Put(b) - - b.Reset() - rd = io.TeeReader(r.Body, b) - } + rd := bytes.NewReader(b) err := readJSON(rd, &input) if err != nil { @@ -57,7 +42,7 @@ func decodeJSONBody(readJSON func(rd io.Reader, v interface{}) error) valueDecod } if validator != nil && validate { - err = validator.ValidateJSONBody(b.Bytes()) + err = validator.ValidateJSONBody(b) if err != nil { return err } diff --git a/request/jsonbody_test.go b/request/jsonbody_test.go index 14db399..46bca5b 100644 --- a/request/jsonbody_test.go +++ b/request/jsonbody_test.go @@ -1,22 +1,21 @@ package request // nolint:testpackage import ( - "bytes" "errors" - "io" "net/http" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/swaggest/rest" + "github.com/valyala/fasthttp" ) func Test_decodeJSONBody(t *testing.T) { - createBody := bytes.NewReader( - []byte(`{"amount": 123,"customerId": "248df4b7-aa70-47b8-a036-33ac447e668d","type": "withdraw"}`)) - createReq, err := http.NewRequest(http.MethodPost, "/US/order/348df4b7-aa70-47b8-a036-33ac447e668d", createBody) - assert.NoError(t, err) + createBody := []byte(`{"amount": 123,"customerId": "248df4b7-aa70-47b8-a036-33ac447e668d","type": "withdraw"}`) + rc := fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("/US/order/348df4b7-aa70-47b8-a036-33ac447e668d") + rc.Request.SetBody(createBody) type Input struct { Amount int `json:"amount"` @@ -25,7 +24,7 @@ func Test_decodeJSONBody(t *testing.T) { } i := Input{} - assert.NoError(t, decodeJSONBody(readJSON)(createReq, &i, nil)) + assert.NoError(t, decodeJSONBody(readJSON)(&rc, &i, nil)) assert.Equal(t, 123, i.Amount) assert.Equal(t, "248df4b7-aa70-47b8-a036-33ac447e668d", i.CustomerID) assert.Equal(t, "withdraw", i.Type) @@ -35,58 +34,65 @@ func Test_decodeJSONBody(t *testing.T) { }) i = Input{} - _, err = createBody.Seek(0, io.SeekStart) - assert.NoError(t, err) - assert.NoError(t, decodeJSONBody(readJSON)(createReq, &i, vl)) + assert.NoError(t, decodeJSONBody(readJSON)(&rc, &i, vl)) assert.Equal(t, 123, i.Amount) assert.Equal(t, "248df4b7-aa70-47b8-a036-33ac447e668d", i.CustomerID) assert.Equal(t, "withdraw", i.Type) } func Test_decodeJSONBody_emptyBody(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", nil) - require.NoError(t, err) + rc := fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("any") var i []int - err = decodeJSONBody(readJSON)(req, &i, nil) - assert.EqualError(t, err, "missing request body") + err := decodeJSONBody(readJSON)(&rc, &i, nil) + assert.EqualError(t, err, "missing request body to decode json") } func Test_decodeJSONBody_badContentType(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("123")) - require.NoError(t, err) - req.Header.Set("Content-Type", "text/plain") + rc := fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("any") + rc.Request.SetBody([]byte("123")) + rc.Request.Header.Set("Content-Type", "text/plain") var i []int - err = decodeJSONBody(readJSON)(req, &i, nil) + err := decodeJSONBody(readJSON)(&rc, &i, nil) assert.EqualError(t, err, "request with application/json content type expected, received: text/plain") } func Test_decodeJSONBody_decodeFailed(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("abc")) - require.NoError(t, err) + rc := fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("any") + rc.Request.SetBody([]byte("abc")) var i []int - err = decodeJSONBody(readJSON)(req, &i, nil) + err := decodeJSONBody(readJSON)(&rc, &i, nil) assert.Error(t, err) } func Test_decodeJSONBody_unmarshalFailed(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("123")) - require.NoError(t, err) + rc := fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("any") + rc.Request.SetBody([]byte("123")) var i []int - err = decodeJSONBody(readJSON)(req, &i, nil) + err := decodeJSONBody(readJSON)(&rc, &i, nil) assert.EqualError(t, err, "failed to decode json: json: cannot unmarshal number into Go value of type []int") } func Test_decodeJSONBody_validateFailed(t *testing.T) { - req, err := http.NewRequest(http.MethodPost, "any", bytes.NewBufferString("[123]")) - require.NoError(t, err) + rc := fasthttp.RequestCtx{} + rc.Request.Header.SetMethod(http.MethodPost) + rc.Request.SetRequestURI("any") + rc.Request.SetBody([]byte("[123]")) var i []int @@ -94,6 +100,6 @@ func Test_decodeJSONBody_validateFailed(t *testing.T) { return errors.New("failed") }) - err = decodeJSONBody(readJSON)(req, &i, vl) + err := decodeJSONBody(readJSON)(&rc, &i, vl) assert.EqualError(t, err, "failed") } diff --git a/request/middleware.go b/request/middleware.go index 1df668a..c463875 100644 --- a/request/middleware.go +++ b/request/middleware.go @@ -1,11 +1,11 @@ package request import ( - "net/http" - + "github.com/swaggest/fchi" "github.com/swaggest/rest" "github.com/swaggest/rest/nethttp" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) type requestDecoderSetter interface { @@ -17,8 +17,8 @@ type requestMapping interface { } // DecoderMiddleware sets up request decoder in suitable handlers. -func DecoderMiddleware(factory DecoderMaker) func(http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { +func DecoderMiddleware(factory DecoderMaker) func(fchi.Handler) fchi.Handler { + return func(handler fchi.Handler) fchi.Handler { var ( withRoute rest.HandlerWithRoute withUseCase rest.HandlerWithUseCase @@ -55,8 +55,8 @@ type withRestHandler interface { } // ValidatorMiddleware sets up request validator in suitable handlers. -func ValidatorMiddleware(factory rest.RequestValidatorFactory) func(http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { +func ValidatorMiddleware(factory rest.RequestValidatorFactory) func(fchi.Handler) fchi.Handler { + return func(handler fchi.Handler) fchi.Handler { var ( withRoute rest.HandlerWithRoute withUseCase rest.HandlerWithUseCase @@ -83,11 +83,11 @@ func ValidatorMiddleware(factory rest.RequestValidatorFactory) func(http.Handler var _ nethttp.RequestDecoder = DecoderFunc(nil) // DecoderFunc implements RequestDecoder with a func. -type DecoderFunc func(r *http.Request, input interface{}, validator rest.Validator) error +type DecoderFunc func(rc *fasthttp.RequestCtx, input interface{}, validator rest.Validator) error // Decode implements RequestDecoder. -func (df DecoderFunc) Decode(r *http.Request, input interface{}, validator rest.Validator) error { - return df(r, input, validator) +func (df DecoderFunc) Decode(rc *fasthttp.RequestCtx, input interface{}, validator rest.Validator) error { + return df(rc, input, validator) } // DecoderMaker creates request decoder for particular structured Go input value. diff --git a/response/encoder.go b/response/encoder.go index 0b16e4f..3727880 100644 --- a/response/encoder.go +++ b/response/encoder.go @@ -4,16 +4,19 @@ import ( "bytes" "encoding/json" "fmt" + "io" "net/http" "reflect" "strconv" "sync" + "github.com/swaggest/fchi" "github.com/swaggest/form/v5" "github.com/swaggest/refl" "github.com/swaggest/rest" "github.com/swaggest/usecase" "github.com/swaggest/usecase/status" + "github.com/valyala/fasthttp" ) // Encoder prepares and writes http response. @@ -126,8 +129,7 @@ var jsonEncoderPool = sync.Pool{ } func (h *Encoder) writeJSONResponse( - w http.ResponseWriter, - r *http.Request, + rc *fasthttp.RequestCtx, v interface{}, ht rest.HandlerTrait, ) { @@ -135,12 +137,14 @@ func (h *Encoder) writeJSONResponse( ht.SuccessContentType = "application/json; charset=utf-8" } + hd := &rc.Response.Header + if jw, ok := v.(rest.JSONWriterTo); ok { - w.Header().Set("Content-Type", ht.SuccessContentType) + hd.Set("Content-Type", ht.SuccessContentType) - _, err := jw.JSONWriteTo(w) + _, err := jw.JSONWriteTo(rc) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + fchi.Error(rc, err.Error(), http.StatusInternalServerError) return } @@ -155,7 +159,7 @@ func (h *Encoder) writeJSONResponse( err := e.enc.Encode(v) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + fchi.Error(rc, err.Error(), http.StatusInternalServerError) return } @@ -164,30 +168,30 @@ func (h *Encoder) writeJSONResponse( err = ht.RespValidator.ValidateJSONBody(e.buf.Bytes()) if err != nil { code, er := rest.Err(status.Wrap(fmt.Errorf("bad response: %w", err), status.Internal)) - h.WriteErrResponse(w, r, code, er) + h.WriteErrResponse(rc, code, er) return } } - w.Header().Set("Content-Length", strconv.Itoa(e.buf.Len())) - w.Header().Set("Content-Type", ht.SuccessContentType) - w.WriteHeader(ht.SuccessStatus) + hd.Set("Content-Length", strconv.Itoa(e.buf.Len())) + hd.Set("Content-Type", ht.SuccessContentType) + rc.Response.SetStatusCode(ht.SuccessStatus) - if r.Method == http.MethodHead { + if bytes.Equal(rc.Method(), []byte(http.MethodHead)) { return } - _, err = w.Write(e.buf.Bytes()) + _, err = rc.Write(e.buf.Bytes()) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + fchi.Error(rc, err.Error(), http.StatusInternalServerError) return } } // WriteErrResponse encodes and writes error to response. -func (h *Encoder) WriteErrResponse(w http.ResponseWriter, r *http.Request, statusCode int, response interface{}) { +func (h *Encoder) WriteErrResponse(rc *fasthttp.RequestCtx, statusCode int, response interface{}) { contentType := "application/json; charset=utf-8" e := jsonEncoderPool.Get().(*jsonEncoder) // nolint:errcheck @@ -197,22 +201,23 @@ func (h *Encoder) WriteErrResponse(w http.ResponseWriter, r *http.Request, statu err := e.enc.Encode(response) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + fchi.Error(rc, err.Error(), http.StatusInternalServerError) return } - w.Header().Set("Content-Length", strconv.Itoa(e.buf.Len())) - w.Header().Set("Content-Type", contentType) - w.WriteHeader(statusCode) + hd := &rc.Response.Header + hd.Set("Content-Length", strconv.Itoa(e.buf.Len())) + hd.Set("Content-Type", contentType) + rc.Response.SetStatusCode(statusCode) - if r.Method == http.MethodHead { + if bytes.Equal(rc.Method(), []byte(fasthttp.MethodHead)) { return } - _, err = w.Write(e.buf.Bytes()) + _, err = rc.Write(e.buf.Bytes()) if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + fchi.Error(rc, err.Error(), fasthttp.StatusInternalServerError) return } @@ -220,8 +225,7 @@ func (h *Encoder) WriteErrResponse(w http.ResponseWriter, r *http.Request, statu // WriteSuccessfulResponse encodes and writes successful output of use case interactor to http response. func (h *Encoder) WriteSuccessfulResponse( - w http.ResponseWriter, - r *http.Request, + rc *fasthttp.RequestCtx, output interface{}, ht rest.HandlerTrait, ) { @@ -232,11 +236,11 @@ func (h *Encoder) WriteSuccessfulResponse( if etagged, ok := output.(rest.ETagged); ok { etag := etagged.ETag() if etag != "" { - w.Header().Set("Etag", etag) + rc.Response.Header.Set("Etag", etag) } } - if h.outputHeadersEncoder != nil && !h.whiteHeader(w, r, output, ht) { + if h.outputHeadersEncoder != nil && !h.whiteHeader(rc, output, ht) { return } @@ -256,16 +260,16 @@ func (h *Encoder) WriteSuccessfulResponse( if skipRendering { if ht.SuccessStatus != http.StatusOK { - w.WriteHeader(ht.SuccessStatus) + rc.Response.SetStatusCode(ht.SuccessStatus) } return } - h.writeJSONResponse(w, r, output, ht) + h.writeJSONResponse(rc, output, ht) } -func (h *Encoder) whiteHeader(w http.ResponseWriter, r *http.Request, output interface{}, ht rest.HandlerTrait) bool { +func (h *Encoder) whiteHeader(rc *fasthttp.RequestCtx, output interface{}, ht rest.HandlerTrait) bool { var headerValues map[string]interface{} if ht.RespValidator != nil { headerValues = make(map[string]interface{}) @@ -274,7 +278,7 @@ func (h *Encoder) whiteHeader(w http.ResponseWriter, r *http.Request, output int headers, err := h.outputHeadersEncoder.Encode(output, headerValues) if err != nil { code, er := rest.Err(err) - h.WriteErrResponse(w, r, code, er) + h.WriteErrResponse(rc, code, er) return false } @@ -283,15 +287,17 @@ func (h *Encoder) whiteHeader(w http.ResponseWriter, r *http.Request, output int err = ht.RespValidator.ValidateData(rest.ParamInHeader, headerValues) if err != nil { code, er := rest.Err(status.Wrap(fmt.Errorf("bad response: %w", err), status.Internal)) - h.WriteErrResponse(w, r, code, er) + h.WriteErrResponse(rc, code, er) return false } } + hd := &rc.Response.Header + for header, val := range headers { if len(val) == 1 { - w.Header().Set(header, val[0]) + hd.Set(header, val[0]) } } @@ -299,7 +305,7 @@ func (h *Encoder) whiteHeader(w http.ResponseWriter, r *http.Request, output int } // MakeOutput instantiates a value for use case output port. -func (h *Encoder) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interface{} { +func (h *Encoder) MakeOutput(rc *fasthttp.RequestCtx, ht rest.HandlerTrait) interface{} { if h.outputBufferType == nil { return nil } @@ -310,13 +316,14 @@ func (h *Encoder) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interf if withWriter, ok := output.(usecase.OutputWithWriter); ok { if h.outputHeadersEncoder != nil || ht.SuccessContentType != "" { withWriter.SetWriter(&writerWithHeaders{ - ResponseWriter: w, + Writer: rc, + rc: rc, responseWriter: h, trait: ht, output: output, }) } else { - withWriter.SetWriter(w) + withWriter.SetWriter(rc) } } } @@ -325,7 +332,8 @@ func (h *Encoder) MakeOutput(w http.ResponseWriter, ht rest.HandlerTrait) interf } type writerWithHeaders struct { - http.ResponseWriter + io.Writer + rc *fasthttp.RequestCtx responseWriter *Encoder trait rest.HandlerTrait @@ -345,7 +353,7 @@ func (w *writerWithHeaders) setHeaders() error { for header, val := range headers { if len(val) == 1 { - w.Header().Set(header, val[0]) + w.rc.Response.Header.Set(header, val[0]) } } @@ -359,11 +367,11 @@ func (w *writerWithHeaders) Write(data []byte) (int, error) { } if w.trait.SuccessContentType != "" { - w.Header().Set("Content-Type", w.trait.SuccessContentType) + w.rc.Response.Header.Set("Content-Type", w.trait.SuccessContentType) } w.headersSet = true } - return w.ResponseWriter.Write(data) + return w.rc.Write(data) } diff --git a/response/encoder_test.go b/response/encoder_test.go index 704d356..5cb4b43 100644 --- a/response/encoder_test.go +++ b/response/encoder_test.go @@ -2,7 +2,6 @@ package response_test import ( "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -11,6 +10,7 @@ import ( "github.com/swaggest/rest/jsonschema" "github.com/swaggest/rest/response" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) func TestEncoder_SetupOutput(t *testing.T) { @@ -37,11 +37,10 @@ func TestEncoder_SetupOutput(t *testing.T) { e.SetupOutput(new(outputPort), &ht) - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") - w := httptest.NewRecorder() - output := e.MakeOutput(w, ht) + output := e.MakeOutput(rc, ht) out, ok := output.(*outputPort) assert.True(t, ok) @@ -49,31 +48,31 @@ func TestEncoder_SetupOutput(t *testing.T) { out.Name = "Jane" out.Items = []string{"one", "two", "three"} - e.WriteSuccessfulResponse(w, r, output, ht) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "Jane", w.Header().Get("X-Name")) - assert.Equal(t, "application/x-vnd-json", w.Header().Get("Content-Type")) - assert.Equal(t, "32", w.Header().Get("Content-Length")) - assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String()) + e.WriteSuccessfulResponse(rc, output, ht) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "Jane", string(rc.Response.Header.Peek("X-Name"))) + assert.Equal(t, "application/x-vnd-json", string(rc.Response.Header.Peek("Content-Type"))) + assert.Equal(t, "32", string(rc.Response.Header.Peek("Content-Length"))) + assert.Equal(t, `{"items":["one","two","three"]}`+"\n", string(rc.Response.Body())) - w = httptest.NewRecorder() - e.WriteErrResponse(w, r, http.StatusExpectationFailed, rest.ErrResponse{ + rc.Response = fasthttp.Response{} + e.WriteErrResponse(rc, http.StatusExpectationFailed, rest.ErrResponse{ ErrorText: "failed", }) - assert.Equal(t, http.StatusExpectationFailed, w.Code) - assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) - assert.Equal(t, "19", w.Header().Get("Content-Length")) - assert.Equal(t, `{"error":"failed"}`+"\n", w.Body.String()) + assert.Equal(t, http.StatusExpectationFailed, rc.Response.StatusCode()) + assert.Equal(t, "application/json; charset=utf-8", string(rc.Response.Header.Peek("Content-Type"))) + assert.Equal(t, "19", string(rc.Response.Header.Peek("Content-Length"))) + assert.Equal(t, `{"error":"failed"}`+"\n", string(rc.Response.Body())) out.Name = "Ja" - w = httptest.NewRecorder() - e.WriteSuccessfulResponse(w, r, output, ht) - assert.Equal(t, http.StatusInternalServerError, w.Code) - assert.Equal(t, "", w.Header().Get("X-Name")) - assert.Equal(t, "application/json; charset=utf-8", w.Header().Get("Content-Type")) - assert.Equal(t, "140", w.Header().Get("Content-Length")) + rc.Response = fasthttp.Response{} + e.WriteSuccessfulResponse(rc, output, ht) + assert.Equal(t, http.StatusInternalServerError, rc.Response.StatusCode()) + assert.Equal(t, "", string(rc.Response.Header.Peek("X-Name"))) + assert.Equal(t, "application/json; charset=utf-8", string(rc.Response.Header.Peek("Content-Type"))) + assert.Equal(t, "140", string(rc.Response.Header.Peek("Content-Length"))) assert.Equal(t, `{"status":"INTERNAL","error":"internal: bad response: validation failed",`+ - `"context":{"header:X-Name":["#: length must be >= 3, but got 2"]}}`+"\n", w.Body.String()) + `"context":{"header:X-Name":["#: length must be >= 3, but got 2"]}}`+"\n", string(rc.Response.Body())) } func TestEncoder_SetupOutput_withWriter(t *testing.T) { @@ -90,25 +89,24 @@ func TestEncoder_SetupOutput_withWriter(t *testing.T) { e.SetupOutput(new(outputPort), &ht) - w := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") - output := e.MakeOutput(w, ht) + output := e.MakeOutput(rc, ht) out, ok := output.(*outputPort) assert.True(t, ok) out.Name = "Jane" - _, err = out.Write([]byte("1,2,3")) + _, err := out.Write([]byte("1,2,3")) require.NoError(t, err) - e.WriteSuccessfulResponse(w, r, output, ht) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "application/x-vnd-foo", w.Header().Get("Content-Type")) - assert.Equal(t, "1,2,3", w.Body.String()) - assert.Equal(t, "Jane", w.Header().Get("X-Name")) + e.WriteSuccessfulResponse(rc, output, ht) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "application/x-vnd-foo", string(rc.Response.Header.Peek("Content-Type"))) + assert.Equal(t, "1,2,3", string(rc.Response.Body())) + assert.Equal(t, "Jane", string(rc.Response.Header.Peek("X-Name"))) } func TestEncoder_SetupOutput_withWriterContentType(t *testing.T) { @@ -124,22 +122,21 @@ func TestEncoder_SetupOutput_withWriterContentType(t *testing.T) { e.SetupOutput(new(outputPort), &ht) - w := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") - output := e.MakeOutput(w, ht) + output := e.MakeOutput(rc, ht) out, ok := output.(*outputPort) assert.True(t, ok) - _, err = out.Write([]byte("1,2,3")) + _, err := out.Write([]byte("1,2,3")) require.NoError(t, err) - e.WriteSuccessfulResponse(w, r, output, ht) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "application/x-vnd-foo", w.Header().Get("Content-Type")) - assert.Equal(t, "1,2,3", w.Body.String()) + e.WriteSuccessfulResponse(rc, output, ht) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "application/x-vnd-foo", string(rc.Response.Header.Peek("Content-Type"))) + assert.Equal(t, "1,2,3", string(rc.Response.Body())) } func TestEncoder_SetupOutput_nonPtr(t *testing.T) { @@ -166,11 +163,10 @@ func TestEncoder_SetupOutput_nonPtr(t *testing.T) { e.SetupOutput(outputPort{}, &ht) - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") - w := httptest.NewRecorder() - output := e.MakeOutput(w, ht) + output := e.MakeOutput(rc, ht) out, ok := output.(*outputPort) assert.True(t, ok) @@ -178,10 +174,10 @@ func TestEncoder_SetupOutput_nonPtr(t *testing.T) { out.Name = "Jane" out.Items = []string{"one", "two", "three"} - e.WriteSuccessfulResponse(w, r, output, ht) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, "Jane", w.Header().Get("X-Name")) - assert.Equal(t, "application/x-vnd-json", w.Header().Get("Content-Type")) - assert.Equal(t, "32", w.Header().Get("Content-Length")) - assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String()) + e.WriteSuccessfulResponse(rc, output, ht) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, "Jane", string(rc.Response.Header.Peek("X-Name"))) + assert.Equal(t, "application/x-vnd-json", string(rc.Response.Header.Peek("Content-Type"))) + assert.Equal(t, "32", string(rc.Response.Header.Peek("Content-Length"))) + assert.Equal(t, `{"items":["one","two","three"]}`+"\n", string(rc.Response.Body())) } diff --git a/response/gzip/middleware.go b/response/gzip/middleware.go index 5ddd621..6846eaf 100644 --- a/response/gzip/middleware.go +++ b/response/gzip/middleware.go @@ -1,234 +1,19 @@ package gzip import ( - "bufio" - "compress/flate" - "compress/gzip" - "fmt" - "io" - "net/http" - "strings" - "sync" + "context" - gz "github.com/swaggest/rest/gzip" -) - -const ( - contentTypeHeader = "Content-Type" - contentLengthHeader = "Content-Length" - contentEncodingHeader = "Content-Encoding" - acceptEncodingHeader = "Accept-Encoding" - - defaultBufferSize = 8 * 1024 + "github.com/swaggest/fchi" + "github.com/valyala/fasthttp" ) // Middleware enables gzip compression of handler response for requests that accept gzip encoding. -func Middleware(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w = maybeGzipResponseWriter(w, r) - if grw, ok := w.(*gzipResponseWriter); ok { - defer func() { - err := grw.Close() - if err != nil { - panic(fmt.Sprintf("BUG: cannot close gzip writer: %s", err)) - } - }() - } +func Middleware(next fchi.Handler) fchi.Handler { + f := fasthttp.CompressHandlerLevel(func(rc *fasthttp.RequestCtx) { + next.ServeHTTP(rc, rc) + }, fasthttp.CompressBestSpeed) - next.ServeHTTP(w, r) + return fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + f(rc) }) } - -var ( - gzipWriterPool sync.Pool - bufWriterPool sync.Pool -) - -func getGzipWriter(w io.Writer) *gzip.Writer { - v := gzipWriterPool.Get() - if v == nil { - zw, err := gzip.NewWriterLevel(w, flate.BestSpeed) - if err != nil { - panic(fmt.Sprintf("BUG: cannot create gzip writer: %s", err)) - } - - return zw - } - - // nolint:errcheck // OK to panic here. - zw := v.(*gzip.Writer) - - zw.Reset(w) - - return zw -} - -func getBufWriter(w io.Writer) *bufio.Writer { - v := bufWriterPool.Get() - if v == nil { - return bufio.NewWriterSize(w, defaultBufferSize) - } - - // nolint:errcheck // OK to panic here. - bw := v.(*bufio.Writer) - - bw.Reset(w) - - return bw -} - -func maybeGzipResponseWriter(w http.ResponseWriter, r *http.Request) http.ResponseWriter { - ae := r.Header.Get(acceptEncodingHeader) - if ae == "" { - return w - } - - ae = strings.ToLower(ae) - - if n := strings.Index(ae, "gzip"); n < 0 { - return w - } - - zrw := &gzipResponseWriter{ - ResponseWriter: w, - } - - return zrw -} - -type gzipResponseWriter struct { - http.ResponseWriter - gzipWriter *gzip.Writer - bufWriter *bufio.Writer - - expectCompressedBytes bool - headersWritten bool - disableCompression bool -} - -var _ gz.Writer = &gzipResponseWriter{} - -func (rw *gzipResponseWriter) GzipWrite(data []byte) (int, error) { - if rw.headersWritten { - return 0, nil - } - - rw.expectCompressedBytes = true - - return rw.Write(data) -} - -func (rw *gzipResponseWriter) writeHeader(statusCode int) { - if rw.headersWritten { - return - } - - if statusCode == http.StatusNoContent || - statusCode == http.StatusNotModified || - (statusCode >= http.StatusContinue && statusCode < http.StatusOK) { - rw.disableCompression = true - } - - h := rw.Header() - - if h.Get(contentEncodingHeader) != "" || rw.disableCompression { - // The request handler disabled gzip encoding. - // Send uncompressed response body. - rw.disableCompression = true - } else { - h.Set(contentEncodingHeader, "gzip") - - if !rw.expectCompressedBytes { - rw.gzipWriter = getGzipWriter(rw.ResponseWriter) - rw.bufWriter = getBufWriter(rw.gzipWriter) - } - - h.Del(contentLengthHeader) - - if h.Get(contentTypeHeader) == "" { - // Disable auto-detection of content-type, since it - // is incorrectly detected after the compression. - h.Set(contentTypeHeader, "text/html") - } - } - - rw.ResponseWriter.WriteHeader(statusCode) - rw.headersWritten = true -} - -func (rw *gzipResponseWriter) Write(p []byte) (int, error) { - if !rw.headersWritten { - rw.writeHeader(http.StatusOK) - } - - if rw.disableCompression || rw.expectCompressedBytes { - return rw.ResponseWriter.Write(p) - } - - return rw.bufWriter.Write(p) -} - -func (rw *gzipResponseWriter) WriteHeader(statusCode int) { - rw.writeHeader(statusCode) -} - -func isTrivialNetworkError(err error) bool { - s := err.Error() - if strings.Contains(s, "broken pipe") || strings.Contains(s, "reset by peer") { - return true - } - - return false -} - -// Flush implements http.Flusher. -func (rw *gzipResponseWriter) Flush() { - if rw.bufWriter == nil || rw.gzipWriter == nil { - return - } - - if err := rw.bufWriter.Flush(); err != nil && !isTrivialNetworkError(err) { - panic(fmt.Sprintf("BUG: cannot flush bufio.Writer: %s", err)) - } - - if err := rw.gzipWriter.Flush(); err != nil && !isTrivialNetworkError(err) { - panic(fmt.Sprintf("BUG: cannot flush gzip.Writer: %s", err)) - } - - if fw, ok := rw.ResponseWriter.(http.Flusher); ok { - fw.Flush() - } -} - -// Close flushes and closes response. -func (rw *gzipResponseWriter) Close() error { - if !rw.headersWritten { - rw.disableCompression = true - - return nil - } - - if rw.bufWriter == nil || rw.gzipWriter == nil { - return nil - } - - rw.Flush() - - err := rw.gzipWriter.Close() - - putBufWriter(rw.bufWriter) - rw.bufWriter = nil - - putGzipWriter(rw.gzipWriter) - rw.gzipWriter = nil - - return err -} - -func putGzipWriter(zw *gzip.Writer) { - gzipWriterPool.Put(zw) -} - -func putBufWriter(bw *bufio.Writer) { - bufWriterPool.Put(bw) -} diff --git a/response/gzip/middleware_test.go b/response/gzip/middleware_test.go index ffd99d2..1d53631 100644 --- a/response/gzip/middleware_test.go +++ b/response/gzip/middleware_test.go @@ -3,62 +3,64 @@ package gzip_test import ( "bytes" gz "compress/gzip" + "context" "io/ioutil" "net/http" - "net/http/httptest" "strings" "sync" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/swaggest/fchi" gzip2 "github.com/swaggest/rest/gzip" "github.com/swaggest/rest/response/gzip" + "github.com/valyala/fasthttp" ) func TestMiddleware(t *testing.T) { resp := []byte(strings.Repeat("A", 10000) + "!!!") - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - _, err := rw.Write(resp) + h := gzip.Middleware(fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + _, err := rc.Write(resp) assert.NoError(t, err) })) - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "gzip, deflate, br") - require.NoError(t, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") - - h.ServeHTTP(rw, r) + h.ServeHTTP(rc, rc) - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. - assert.Equal(t, resp, gzipDecode(t, rw.Body.Bytes())) + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Less(t, len(rc.Response.Body()), len(resp)) // Response is compressed. + assert.Equal(t, resp, gzipDecode(t, rc.Response.Body())) - rw = httptest.NewRecorder() - h.ServeHTTP(rw, r) + rc = &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "gzip, deflate, br") - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. - assert.Equal(t, resp, gzipDecode(t, rw.Body.Bytes())) + h.ServeHTTP(rc, rc) - rw = httptest.NewRecorder() + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Less(t, len(rc.Response.Body()), len(resp)) // Response is compressed. + assert.Equal(t, resp, gzipDecode(t, rc.Response.Body())) - r.Header.Set("Accept-Encoding", "deflate, br") - h.ServeHTTP(rw, r) + rc = &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "br") + h.ServeHTTP(rc, rc) - assert.Equal(t, "", rw.Header().Get("Content-Encoding")) - assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. - assert.Equal(t, resp, rw.Body.Bytes()) + assert.Equal(t, "", string(rc.Response.Header.Peek("Content-Encoding"))) + require.Equal(t, len(rc.Response.Body()), len(resp)) // Response is not compressed. + assert.Equal(t, resp, rc.Response.Body()) - rw = httptest.NewRecorder() + rc = &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + h.ServeHTTP(rc, rc) - r.Header.Del("Accept-Encoding") - h.ServeHTTP(rw, r) - - assert.Equal(t, "", rw.Header().Get("Content-Encoding")) - assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. - assert.Equal(t, resp, rw.Body.Bytes()) + assert.Equal(t, "", string(rc.Response.Header.Peek("Content-Encoding"))) + require.Equal(t, len(rc.Response.Body()), len(resp)) // Response is not compressed. + assert.Equal(t, resp, rc.Response.Body()) } // BenchmarkMiddleware measures performance of handler with compression. @@ -67,22 +69,22 @@ func TestMiddleware(t *testing.T) { // BenchmarkMiddleware-12 108810 9619 ns/op 1223 B/op 11 allocs/op. func BenchmarkMiddleware(b *testing.B) { resp := []byte(strings.Repeat("A", 10000) + "!!!") - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - _, err := rw.Write(resp) + h := gzip.Middleware(fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + _, err := rc.Write(resp) assert.NoError(b, err) })) - r, err := http.NewRequest(http.MethodGet, "/", nil) - - require.NoError(b, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "gzip, deflate, br") b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - rw := httptest.NewRecorder() - h.ServeHTTP(rw, r) + rc.Response = fasthttp.Response{} + + h.ServeHTTP(rc, rc) } } @@ -92,35 +94,34 @@ func BenchmarkMiddleware(b *testing.B) { // BenchmarkMiddleware_control-4 214824 5945 ns/op 11184 B/op 9 allocs/op. func BenchmarkMiddleware_control(b *testing.B) { resp := []byte(strings.Repeat("A", 10000) + "!!!") - h := http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - _, err := rw.Write(resp) + h := fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + _, err := rc.Write(resp) assert.NoError(b, err) }) - r, err := http.NewRequest(http.MethodGet, "/", nil) - - require.NoError(b, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "gzip, deflate, br") b.ReportAllocs() b.ResetTimer() for i := 0; i < b.N; i++ { - rw := httptest.NewRecorder() - h.ServeHTTP(rw, r) + rc.Response = fasthttp.Response{} + h.ServeHTTP(rc, rc) } } func TestMiddleware_concurrency(t *testing.T) { resp := []byte(strings.Repeat("A", 10000) + "!!!") respGz := gzipEncode(t, resp) - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - _, err := rw.Write(resp) + h := gzip.Middleware(fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + _, err := rc.Write(resp) assert.NoError(t, err) })) - hg := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - _, err := gzip2.WriteCompressedBytes(respGz, rw) + hg := gzip.Middleware(fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + _, err := gzip2.WriteCompressedBytes(respGz, rc) assert.NoError(t, err) })) @@ -132,25 +133,23 @@ func TestMiddleware_concurrency(t *testing.T) { go func() { defer wg.Done() - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - - require.NoError(t, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "gzip, deflate, br") - h.ServeHTTP(rw, r) + h.ServeHTTP(rc, rc) - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. - assert.Equal(t, resp, gzipDecode(t, rw.Body.Bytes())) + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Less(t, len(rc.Response.Body()), len(resp)) // Response is compressed. + assert.Equal(t, resp, gzipDecode(t, rc.Response.Body())) - rw = httptest.NewRecorder() + rc.Response = fasthttp.Response{} - hg.ServeHTTP(rw, r) + hg.ServeHTTP(rc, rc) - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. - assert.Equal(t, respGz, rw.Body.Bytes()) + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Less(t, len(rc.Response.Body()), len(resp)) // Response is compressed. + assert.True(t, bytes.Equal(resp, gzipDecode(t, rc.Response.Body()))) }() } @@ -161,63 +160,57 @@ func TestGzipResponseWriter_ExpectCompressedBytes(t *testing.T) { resp := []byte(strings.Repeat("A", 10000) + "!!!") respGz := gzipEncode(t, resp) - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - _, err := gzip2.WriteCompressedBytes(respGz, rw) + h := gzip.Middleware(fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + _, err := gzip2.WriteCompressedBytes(respGz, rc) assert.NoError(t, err) })) - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "gzip, deflate, br") - require.NoError(t, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") - - h.ServeHTTP(rw, r) + h.ServeHTTP(rc, rc) - assert.Equal(t, "gzip", rw.Header().Get("Content-Encoding")) - assert.Less(t, rw.Body.Len(), len(resp)) // Response is compressed. - assert.Equal(t, respGz, rw.Body.Bytes()) + assert.Equal(t, "gzip", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Less(t, len(rc.Response.Body()), len(resp)) // Response is compressed. + assert.Equal(t, respGz, rc.Response.Body()) } func TestMiddleware_skipContentEncoding(t *testing.T) { resp := []byte(strings.Repeat("A", 10000) + "!!!") - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.Header().Set("Content-Encoding", "br") - _, err := rw.Write(resp) + h := gzip.Middleware(fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + rc.Response.Header.Set("Content-Encoding", "br") + _, err := rc.Write(resp) assert.NoError(t, err) })) - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - - require.NoError(t, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "gzip, deflate, br") - h.ServeHTTP(rw, r) + h.ServeHTTP(rc, rc) - assert.Equal(t, "br", rw.Header().Get("Content-Encoding")) - assert.Equal(t, rw.Body.Len(), len(resp)) // Response is not compressed. - assert.Equal(t, resp, rw.Body.Bytes()) + assert.Equal(t, "br", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, len(rc.Response.Body()), len(resp)) // Response is not compressed. + assert.Equal(t, resp, rc.Response.Body()) } func TestMiddleware_noContent(t *testing.T) { - h := gzip.Middleware(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - rw.WriteHeader(http.StatusNoContent) + h := gzip.Middleware(fchi.HandlerFunc(func(ctx context.Context, rc *fasthttp.RequestCtx) { + rc.Response.SetStatusCode(http.StatusNoContent) // Second call does not hurt. - rw.WriteHeader(http.StatusNoContent) + rc.Response.SetStatusCode(http.StatusNoContent) })) - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - - require.NoError(t, err) - r.Header.Set("Accept-Encoding", "gzip, deflate, br") + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") + rc.Request.Header.Set("Accept-Encoding", "gzip, deflate, br") - h.ServeHTTP(rw, r) + h.ServeHTTP(rc, rc) - assert.Equal(t, "", rw.Header().Get("Content-Encoding")) - assert.Equal(t, rw.Body.Len(), 0) + assert.Equal(t, "", string(rc.Response.Header.Peek("Content-Encoding"))) + assert.Equal(t, len(rc.Response.Body()), 0) } func gzipEncode(t *testing.T, data []byte) []byte { diff --git a/response/middleware.go b/response/middleware.go index 792da1d..559c55a 100644 --- a/response/middleware.go +++ b/response/middleware.go @@ -1,8 +1,7 @@ package response import ( - "net/http" - + "github.com/swaggest/fchi" "github.com/swaggest/rest" "github.com/swaggest/rest/nethttp" "github.com/swaggest/usecase" @@ -12,8 +11,8 @@ type responseEncoderSetter interface { SetResponseEncoder(responseWriter nethttp.ResponseEncoder) } -// EncoderMiddleware instruments qualifying http.Handler with Encoder. -func EncoderMiddleware(handler http.Handler) http.Handler { +// EncoderMiddleware instruments qualifying fchi.Handler with Encoder. +func EncoderMiddleware(handler fchi.Handler) fchi.Handler { var ( withUseCase rest.HandlerWithUseCase setResponseEncoder responseEncoderSetter diff --git a/response/middleware_test.go b/response/middleware_test.go index 7da317a..0d4dc96 100644 --- a/response/middleware_test.go +++ b/response/middleware_test.go @@ -3,14 +3,13 @@ package response_test import ( "context" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" "github.com/swaggest/rest/nethttp" "github.com/swaggest/rest/response" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) func TestEncoderMiddleware(t *testing.T) { @@ -34,12 +33,11 @@ func TestEncoderMiddleware(t *testing.T) { h := nethttp.NewHandler(u) - w := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") - response.EncoderMiddleware(h).ServeHTTP(w, r) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, `{"items":["one","two","three"]}`+"\n", w.Body.String()) - assert.Equal(t, "Jane", w.Header().Get("X-Name")) + response.EncoderMiddleware(h).ServeHTTP(rc, rc) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assert.Equal(t, `{"items":["one","two","three"]}`+"\n", string(rc.Response.Body())) + assert.Equal(t, "Jane", string(rc.Response.Header.Peek("X-Name"))) } diff --git a/response/validator.go b/response/validator.go index 3e5a274..6155c3e 100644 --- a/response/validator.go +++ b/response/validator.go @@ -3,6 +3,7 @@ package response import ( "net/http" + "github.com/swaggest/fchi" "github.com/swaggest/rest" "github.com/swaggest/rest/nethttp" "github.com/swaggest/usecase" @@ -13,8 +14,8 @@ type withRestHandler interface { } // ValidatorMiddleware sets up response validator in suitable handlers. -func ValidatorMiddleware(factory rest.ResponseValidatorFactory) func(http.Handler) http.Handler { - return func(handler http.Handler) http.Handler { +func ValidatorMiddleware(factory rest.ResponseValidatorFactory) func(fchi.Handler) fchi.Handler { + return func(handler fchi.Handler) fchi.Handler { var ( withUseCase rest.HandlerWithUseCase handlerTrait withRestHandler diff --git a/response/validator_test.go b/response/validator_test.go index 8a56f3b..b576107 100644 --- a/response/validator_test.go +++ b/response/validator_test.go @@ -3,7 +3,6 @@ package response_test import ( "context" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -13,6 +12,7 @@ import ( "github.com/swaggest/rest/openapi" "github.com/swaggest/rest/response" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) func TestValidatorMiddleware(t *testing.T) { @@ -47,20 +47,19 @@ func TestValidatorMiddleware(t *testing.T) { validatorFactory := jsonschema.NewFactory(apiSchema, apiSchema) wh := nethttp.WrapHandler(h, response.EncoderMiddleware, response.ValidatorMiddleware(validatorFactory)) - w := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "/", nil) - require.NoError(t, err) + rc := &fasthttp.RequestCtx{} + rc.Request.SetRequestURI("/") - wh.ServeHTTP(w, r) - assert.Equal(t, http.StatusInternalServerError, w.Code) + wh.ServeHTTP(rc, rc) + assert.Equal(t, http.StatusInternalServerError, rc.Response.StatusCode()) assert.Equal(t, `{"status":"INTERNAL","error":"internal: bad response: validation failed",`+ - `"context":{"header:X-Name":["#: length must be >= 3, but got 2"]}}`+"\n", w.Body.String()) + `"context":{"header:X-Name":["#: length must be >= 3, but got 2"]}}`+"\n", string(rc.Response.Body())) invalidOut.Name = "Jane" - w = httptest.NewRecorder() + rc.Response = fasthttp.Response{} - wh.ServeHTTP(w, r) - assert.Equal(t, http.StatusInternalServerError, w.Code) + wh.ServeHTTP(rc, rc) + assert.Equal(t, http.StatusInternalServerError, rc.Response.StatusCode()) assert.Equal(t, `{"status":"INTERNAL","error":"internal: bad response: validation failed",`+ - `"context":{"body":["#/items: minimum 3 items allowed, but found 1 items"]}}`+"\n", w.Body.String()) + `"context":{"body":["#/items: minimum 3 items allowed, but found 1 items"]}}`+"\n", string(rc.Response.Body())) } diff --git a/resttest/client.go b/resttest/client.go deleted file mode 100644 index 9d52f47..0000000 --- a/resttest/client.go +++ /dev/null @@ -1,17 +0,0 @@ -package resttest - -import ( - "github.com/bool64/httpmock" -) - -// Client keeps state of expectations. -// -// Deprecated: please use httpmock.Client. -type Client = httpmock.Client - -// NewClient creates client instance, baseURL may be empty if Client.SetBaseURL is used later. -// -// Deprecated: please use httpmock.NewClient. -func NewClient(baseURL string) *Client { - return httpmock.NewClient(baseURL) -} diff --git a/resttest/client_test.go b/resttest/client_test.go deleted file mode 100644 index 702b162..0000000 --- a/resttest/client_test.go +++ /dev/null @@ -1,106 +0,0 @@ -package resttest_test - -import ( - "io/ioutil" - "net/http" - "net/http/httptest" - "sync/atomic" - "testing" - - "github.com/bool64/httpmock" - "github.com/bool64/shared" - "github.com/stretchr/testify/assert" -) - -func TestNewClient(t *testing.T) { - cnt := int64(0) - srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/foo?q=1", r.URL.String()) - b, err := ioutil.ReadAll(r.Body) - assert.NoError(t, err) - assert.Equal(t, `{"foo":"bar"}`, string(b)) - assert.Equal(t, "application/json", r.Header.Get("Content-Type")) - assert.Equal(t, "abc", r.Header.Get("X-Header")) - assert.Equal(t, "def", r.Header.Get("X-Custom")) - - c, err := r.Cookie("c1") - assert.NoError(t, err) - assert.Equal(t, "1", c.Value) - - c, err = r.Cookie("c2") - assert.NoError(t, err) - assert.Equal(t, "2", c.Value) - - c, err = r.Cookie("foo") - assert.NoError(t, err) - assert.Equal(t, "bar", c.Value) - - ncnt := atomic.AddInt64(&cnt, 1) - rw.Header().Set("Content-Type", "application/json") - if ncnt > 1 { - rw.WriteHeader(http.StatusConflict) - _, err := rw.Write([]byte(`{"error":"conflict"}`)) - assert.NoError(t, err) - } else { - rw.WriteHeader(http.StatusAccepted) - _, err := rw.Write([]byte(`{"bar":"foo", "dyn": "abc"}`)) - assert.NoError(t, err) - } - })) - - defer srv.Close() - - vars := &shared.Vars{} - - c := httpmock.NewClient(srv.URL) - c.JSONComparer.Vars = vars - c.ConcurrencyLevel = 50 - c.Headers = map[string]string{ - "X-Header": "abc", - } - c.Cookies = map[string]string{ - "foo": "bar", - "c1": "to-be-overridden", - } - - c.Reset(). - WithMethod(http.MethodPost). - WithHeader("X-Custom", "def"). - WithContentType("application/json"). - WithBody([]byte(`{"foo":"bar"}`)). - WithCookie("c1", "1"). - WithCookie("c2", "2"). - WithURI("/foo?q=1"). - Concurrently() - - assert.NoError(t, c.ExpectResponseStatus(http.StatusAccepted)) - assert.NoError(t, c.ExpectResponseBody([]byte(`{"bar":"foo","dyn":"$var1"}`))) - assert.NoError(t, c.ExpectResponseHeader("Content-Type", "application/json")) - assert.NoError(t, c.ExpectOtherResponsesStatus(http.StatusConflict)) - assert.NoError(t, c.ExpectOtherResponsesBody([]byte(`{"error":"conflict"}`))) - assert.NoError(t, c.ExpectOtherResponsesHeader("Content-Type", "application/json")) - assert.NoError(t, c.CheckUnexpectedOtherResponses()) - assert.EqualError(t, c.ExpectNoOtherResponses(), "unexpected response status, expected: 202 (Accepted), received: 409 (Conflict)") - - val, found := vars.Get("$var1") - assert.True(t, found) - assert.Equal(t, "abc", val) -} - -func TestNewClient_failedExpectation(t *testing.T) { - srv := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { - _, err := writer.Write([]byte(`{"bar":"foo"}`)) - assert.NoError(t, err) - })) - defer srv.Close() - c := httpmock.NewClient(srv.URL) - - c.OnBodyMismatch = func(received []byte) { - assert.Equal(t, `{"bar":"foo"}`, string(received)) - println(received) - } - - c.WithURI("/") - assert.EqualError(t, c.ExpectResponseBody([]byte(`{"foo":"bar}"`)), - "unexpected body, expected: {\"foo\":\"bar}\", received: {\"bar\":\"foo\"}") -} diff --git a/resttest/doc.go b/resttest/doc.go deleted file mode 100644 index f2ca23f..0000000 --- a/resttest/doc.go +++ /dev/null @@ -1,4 +0,0 @@ -// Package resttest provides utilities to test REST API. -// -// Deprecated: please use github.com/bool64/httpmock. -package resttest diff --git a/resttest/example_test.go b/resttest/example_test.go deleted file mode 100644 index 466aa4f..0000000 --- a/resttest/example_test.go +++ /dev/null @@ -1,64 +0,0 @@ -package resttest_test - -import ( - "fmt" - "net/http" - - "github.com/bool64/httpmock" -) - -func ExampleNewClient() { - // Prepare server mock. - sm, url := httpmock.NewServer() - defer sm.Close() - - // This example shows Client and ServerMock working together for sake of portability. - // In real-world scenarios Client would complement real server or ServerMock would complement real HTTP client. - - // Set successful expectation for first request out of concurrent batch. - exp := httpmock.Expectation{ - Method: http.MethodPost, - RequestURI: "/foo?q=1", - RequestHeader: map[string]string{ - "X-Custom": "def", - "X-Header": "abc", - "Content-Type": "application/json", - }, - RequestBody: []byte(`{"foo":"bar"}`), - Status: http.StatusAccepted, - ResponseBody: []byte(`{"bar":"foo"}`), - } - sm.Expect(exp) - - // Set failing expectation for other requests of concurrent batch. - exp.Status = http.StatusConflict - exp.ResponseBody = []byte(`{"error":"conflict"}`) - exp.Unlimited = true - sm.Expect(exp) - - // Prepare client request. - c := httpmock.NewClient(url) - c.ConcurrencyLevel = 50 - c.Headers = map[string]string{ - "X-Header": "abc", - } - - c.Reset(). - WithMethod(http.MethodPost). - WithHeader("X-Custom", "def"). - WithContentType("application/json"). - WithBody([]byte(`{"foo":"bar"}`)). - WithURI("/foo?q=1"). - Concurrently() - - // Check expectations errors. - fmt.Println( - c.ExpectResponseStatus(http.StatusAccepted), - c.ExpectResponseBody([]byte(`{"bar":"foo"}`)), - c.ExpectOtherResponsesStatus(http.StatusConflict), - c.ExpectOtherResponsesBody([]byte(`{"error":"conflict"}`)), - ) - - // Output: - // -} diff --git a/resttest/server.go b/resttest/server.go deleted file mode 100644 index b013792..0000000 --- a/resttest/server.go +++ /dev/null @@ -1,20 +0,0 @@ -package resttest - -import ( - "github.com/bool64/httpmock" -) - -// Expectation describes expected request and defines response. -// -// Deprecated: please use httpmock.Expectation. -type Expectation = httpmock.Expectation - -// ServerMock serves predefined response for predefined request. -type ServerMock = httpmock.Server - -// NewServerMock creates mocked server. -// -// Deprecated: please use httpmock.NewServer. -func NewServerMock() (*ServerMock, string) { - return httpmock.NewServer() -} diff --git a/resttest/server_test.go b/resttest/server_test.go deleted file mode 100644 index 3d4d6aa..0000000 --- a/resttest/server_test.go +++ /dev/null @@ -1,382 +0,0 @@ -package resttest_test - -import ( - "bytes" - "io" - "io/ioutil" - "net/http" - "strings" - "sync" - "testing" - - "github.com/bool64/httpmock" - "github.com/bool64/shared" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func assertRoundTrip(t *testing.T, baseURL string, expectation httpmock.Expectation) { - t.Helper() - - var bodyReader io.Reader - - if expectation.RequestBody != nil { - bodyReader = bytes.NewReader(expectation.RequestBody) - } - - req, err := http.NewRequest(expectation.Method, baseURL+expectation.RequestURI, bodyReader) - require.NoError(t, err) - - for k, v := range expectation.RequestHeader { - req.Header.Set(k, v) - } - - for n, v := range expectation.RequestCookie { - req.AddCookie(&http.Cookie{Name: n, Value: v}) - } - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, resp.Body.Close()) - require.NoError(t, err) - - if expectation.Status == 0 { - expectation.Status = http.StatusOK - } - - assert.Equal(t, expectation.Status, resp.StatusCode) - assert.Equal(t, string(expectation.ResponseBody), string(body)) - - // Asserting default for successful responses. - if resp.StatusCode != http.StatusInternalServerError { - assert.Equal(t, "application/json", resp.Header.Get("Content-Type")) - } - - if len(expectation.ResponseHeader) > 0 { - for k, v := range expectation.ResponseHeader { - assert.Equal(t, v, resp.Header.Get(k)) - } - } -} - -func TestServerMock_ServeHTTP(t *testing.T) { - // Creating REST service mock. - mock, baseURL := httpmock.NewServer() - defer mock.Close() - - mock.OnBodyMismatch = func(received []byte) { - assert.Equal(t, `{"foo":"bar"}`, string(received)) - } - - mock.DefaultResponseHeaders = map[string]string{ - "Content-Type": "application/json", - } - - // Requesting mock without expectations fails. - assertRoundTrip(t, baseURL, httpmock.Expectation{ - RequestURI: "/test?test=test", - Status: http.StatusInternalServerError, - ResponseBody: []byte("unexpected request received: GET /test?test=test"), - }) - - // Requesting mock without expectations fails. - assertRoundTrip(t, baseURL, httpmock.Expectation{ - RequestURI: "/test?test=test", - Status: http.StatusInternalServerError, - RequestBody: []byte(`{"foo":"bar"}`), - ResponseBody: []byte("unexpected request received: GET /test?test=test, body:\n{\"foo\":\"bar\"}"), - }) - - // Setting expectations for first request. - exp1 := httpmock.Expectation{ - Method: http.MethodPost, - RequestURI: "/test?test=test", - RequestHeader: map[string]string{"Authorization": "Bearer token"}, - RequestCookie: map[string]string{"c1": "v1", "c2": "v2"}, - RequestBody: []byte(`{"request":"body"}`), - - Status: http.StatusCreated, - ResponseBody: []byte(`{"response":"body"}`), - } - mock.Expect(exp1) - - // Setting expectations for second request. - exp2 := httpmock.Expectation{ - Method: http.MethodPost, - RequestURI: "/test?test=test", - RequestBody: []byte(`not a JSON`), - - ResponseHeader: map[string]string{ - "X-Foo": "bar", - }, - ResponseBody: []byte(`{"response":"body2"}`), - } - mock.Expect(exp2) - - // Sending first request. - assertRoundTrip(t, baseURL, exp1) - - // Expectations were not met yet. - assert.EqualError(t, mock.ExpectationsWereMet(), - "there are remaining expectations that were not met: POST /test?test=test") - - // Sending second request. - assertRoundTrip(t, baseURL, exp2) - - // Expectations were met. - assert.NoError(t, mock.ExpectationsWereMet()) - - // Requesting mock without expectations fails. - assertRoundTrip(t, baseURL, httpmock.Expectation{ - RequestURI: "/test?test=test", - Status: http.StatusInternalServerError, - ResponseBody: []byte("unexpected request received: GET /test?test=test"), - }) -} - -func TestServerMock_ServeHTTP_error(t *testing.T) { - // Creating REST service mock. - mock, baseURL := httpmock.NewServer() - defer mock.Close() - - mock.OnBodyMismatch = func(received []byte) { - assert.Equal(t, `{"request":"body"}`, string(received)) - } - - // Setting expectations for first request. - mock.Expect(httpmock.Expectation{ - Method: http.MethodPost, - RequestURI: "/test?test=test", - RequestHeader: map[string]string{"X-Foo": "bar"}, - RequestBody: []byte(`{"foo":"bar"}`), - }) - - // Sending request with wrong uri. - req, err := http.NewRequest(http.MethodPost, baseURL+"/wrong-uri", bytes.NewReader([]byte(`{"request":"body"}`))) - require.NoError(t, err) - req.Header.Set("X-Foo", "bar") - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - respBody, err := ioutil.ReadAll(resp.Body) - require.NoError(t, resp.Body.Close()) - require.NoError(t, err) - - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) - assert.Equal(t, `request uri "/test?test=test" expected, "/wrong-uri" received`, string(respBody)) - - // Sending request with wrong method. - req, err = http.NewRequest(http.MethodGet, baseURL+"/test?test=test", bytes.NewReader([]byte(`{"request":"body"}`))) - require.NoError(t, err) - req.Header.Set("X-Foo", "bar") - - resp, err = http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - respBody, err = ioutil.ReadAll(resp.Body) - require.NoError(t, resp.Body.Close()) - require.NoError(t, err) - - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) - assert.Equal(t, `method "POST" expected, "GET" received`, string(respBody)) - - // Sending request with wrong header. - req, err = http.NewRequest(http.MethodPost, baseURL+"/test?test=test", bytes.NewReader([]byte(`{"request":"body"}`))) - require.NoError(t, err) - req.Header.Set("X-Foo", "space") - - resp, err = http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - respBody, err = ioutil.ReadAll(resp.Body) - require.NoError(t, resp.Body.Close()) - require.NoError(t, err) - - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) - assert.Equal(t, `header "X-Foo" with value "bar" expected, "space" received`, string(respBody)) - - // Sending request with wrong body. - req, err = http.NewRequest(http.MethodPost, baseURL+"/test?test=test", bytes.NewReader([]byte(`{"request":"body"}`))) - require.NoError(t, err) - req.Header.Set("X-Foo", "bar") - - resp, err = http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - respBody, err = ioutil.ReadAll(resp.Body) - require.NoError(t, resp.Body.Close()) - require.NoError(t, err) - - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) - assert.Equal(t, `unexpected request body: not equal: - { -- "foo": "bar" -+ "request": "body" - } -`, string(respBody)) -} - -func TestServerMock_ServeHTTP_concurrency(t *testing.T) { - // Creating REST service mock. - mock, url := httpmock.NewServer() - defer mock.Close() - - n := 50 - - for i := 0; i < n; i++ { - // Setting expectations for first request. - mock.Expect(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/test?test=test", - ResponseBody: []byte("body"), - }) - } - - wg := sync.WaitGroup{} - wg.Add(n) - - for i := 0; i < n; i++ { - go func() { - defer wg.Done() - - // Sending request with wrong header. - req, err := http.NewRequest(http.MethodGet, url+"/test?test=test", nil) - require.NoError(t, err) - req.Header.Set("X-Foo", "space") - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - respBody, err := ioutil.ReadAll(resp.Body) - require.NoError(t, resp.Body.Close()) - require.NoError(t, err) - - assert.Equal(t, http.StatusOK, resp.StatusCode) - assert.Equal(t, `body`, string(respBody)) - }() - } - - wg.Wait() - assert.NoError(t, mock.ExpectationsWereMet()) -} - -func TestServerMock_ResetExpectations(t *testing.T) { - // Creating REST service mock. - mock, _ := httpmock.NewServer() - defer mock.Close() - - mock.Expect(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/test?test=test", - ResponseBody: []byte("body"), - }) - - mock.ExpectAsync(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/test-async?test=test", - ResponseBody: []byte("body"), - }) - - assert.Error(t, mock.ExpectationsWereMet()) - mock.ResetExpectations() - assert.NoError(t, mock.ExpectationsWereMet()) -} - -func TestServerMock_vars(t *testing.T) { - sm, url := httpmock.NewServer() - sm.JSONComparer.Vars = &shared.Vars{} - sm.Expect(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/", - RequestBody: []byte(`{"foo":"bar","dyn":"$var1"}`), - ResponseBody: []byte(`{"bar":"foo","dynEcho":"$var1"}`), - }) - - req, err := http.NewRequest(http.MethodGet, url+"/", strings.NewReader(`{"foo":"bar","dyn":"abc"}`)) - require.NoError(t, err) - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - require.NoError(t, resp.Body.Close()) - - assert.Equal(t, `{"bar":"foo","dynEcho":"abc"}`, string(body)) -} - -func TestServerMock_ExpectAsync(t *testing.T) { - sm, url := httpmock.NewServer() - sm.Expect(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/", - ResponseBody: []byte(`{"bar":"foo"}`), - }) - sm.ExpectAsync(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/async1", - ResponseBody: []byte(`{"bar":"async1"}`), - }) - sm.ExpectAsync(httpmock.Expectation{ - Method: http.MethodGet, - RequestURI: "/async2", - ResponseBody: []byte(`{"bar":"async2"}`), - Unlimited: true, - }) - - wg := sync.WaitGroup{} - wg.Add(2) - - go func() { - defer wg.Done() - - req, err := http.NewRequest(http.MethodGet, url+"/async1", nil) - require.NoError(t, err) - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - require.NoError(t, resp.Body.Close()) - assert.Equal(t, `{"bar":"async1"}`, string(body)) - }() - - go func() { - defer wg.Done() - - for i := 0; i < 50; i++ { - req, err := http.NewRequest(http.MethodGet, url+"/async2", nil) - require.NoError(t, err) - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - require.NoError(t, resp.Body.Close()) - assert.Equal(t, `{"bar":"async2"}`, string(body)) - } - }() - - req, err := http.NewRequest(http.MethodGet, url+"/", nil) - require.NoError(t, err) - - resp, err := http.DefaultTransport.RoundTrip(req) - require.NoError(t, err) - - body, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - require.NoError(t, resp.Body.Close()) - assert.Equal(t, `{"bar":"foo"}`, string(body)) - - wg.Wait() - assert.NoError(t, sm.ExpectationsWereMet()) -} diff --git a/route.go b/route.go index 08d19fc..6fdc1fe 100644 --- a/route.go +++ b/route.go @@ -9,7 +9,7 @@ type HandlerWithUseCase interface { UseCase() usecase.Interactor } -// HandlerWithRoute is a http.Handler with routing information. +// HandlerWithRoute is a fchi.Handler with routing information. type HandlerWithRoute interface { // RouteMethod returns http method of action. RouteMethod() string diff --git a/web/example_test.go b/web/example_test.go index 9fc65ec..abce485 100644 --- a/web/example_test.go +++ b/web/example_test.go @@ -5,9 +5,11 @@ import ( "log" "net/http" + "github.com/swaggest/fchi" "github.com/swaggest/rest/nethttp" "github.com/swaggest/rest/web" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) // album represents data about a record album. @@ -41,7 +43,7 @@ func ExampleDefaultService() { log.Println("Starting service at http://localhost:8080") - if err := http.ListenAndServe("localhost:8080", service); err != nil { + if err := fasthttp.ListenAndServe("localhost:8080", fchi.RequestHandler(service)); err != nil { log.Fatal(err) } } diff --git a/web/service.go b/web/service.go index 45c0406..ca69d41 100644 --- a/web/service.go +++ b/web/service.go @@ -5,8 +5,8 @@ import ( "net/http" "strings" - "github.com/go-chi/chi/v5" - "github.com/go-chi/chi/v5/middleware" + "github.com/swaggest/fchi" + "github.com/swaggest/fchi/middleware" "github.com/swaggest/openapi-go/openapi3" "github.com/swaggest/rest" "github.com/swaggest/rest/chirouter" @@ -39,7 +39,7 @@ func DefaultService(options ...func(s *Service, initialized bool)) *Service { } if s.Wrapper == nil { - s.Wrapper = chirouter.NewWrapper(chi.NewRouter()) + s.Wrapper = chirouter.NewWrapper(fchi.NewRouter()) } if s.DecoderFactory == nil { @@ -77,7 +77,7 @@ func DefaultService(options ...func(s *Service, initialized bool)) *Service { type Service struct { *chirouter.Wrapper - PanicRecoveryMiddleware func(handler http.Handler) http.Handler // Default is middleware.Recoverer. + PanicRecoveryMiddleware func(handler fchi.Handler) fchi.Handler // Default is middleware.Recoverer. OpenAPI *openapi3.Spec OpenAPICollector *openapi.Collector DecoderFactory *request.DecoderFactory @@ -141,5 +141,5 @@ func (s *Service) Trace(pattern string, uc usecase.Interactor, options ...func(h func (s *Service) Docs(pattern string, swgui func(title, schemaURL, basePath string) http.Handler) { pattern = strings.TrimRight(pattern, "/") s.Method(http.MethodGet, pattern+"/openapi.json", s.OpenAPICollector) - s.Mount(pattern, swgui(s.OpenAPI.Info.Title, pattern+"/openapi.json", pattern)) + s.Mount(pattern, fchi.Adapt(swgui(s.OpenAPI.Info.Title, pattern+"/openapi.json", pattern))) } diff --git a/web/service_test.go b/web/service_test.go index d8ad725..39e0023 100644 --- a/web/service_test.go +++ b/web/service_test.go @@ -5,7 +5,6 @@ import ( "fmt" "io/ioutil" "net/http" - "net/http/httptest" "testing" "github.com/stretchr/testify/assert" @@ -14,6 +13,7 @@ import ( "github.com/swaggest/rest/nethttp" "github.com/swaggest/rest/web" "github.com/swaggest/usecase" + "github.com/valyala/fasthttp" ) type albumID struct { @@ -59,13 +59,15 @@ func TestDefaultService(t *testing.T) { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {}) }) - rw := httptest.NewRecorder() - r, err := http.NewRequest(http.MethodGet, "http://localhost/docs/openapi.json", nil) - require.NoError(t, err) - service.ServeHTTP(rw, r) + rc := &fasthttp.RequestCtx{ + Request: fasthttp.Request{}, + Response: fasthttp.Response{}, + } + rc.Request.SetRequestURI("http://localhost/docs/openapi.json") + service.ServeHTTP(rc, rc) - assert.Equal(t, http.StatusOK, rw.Code) - assertjson.EqualMarshal(t, rw.Body.Bytes(), service.OpenAPI) + assert.Equal(t, http.StatusOK, rc.Response.StatusCode()) + assertjson.EqualMarshal(t, rc.Response.Body(), service.OpenAPI) expected, err := ioutil.ReadFile("_testdata/openapi.json") require.NoError(t, err)