Skip to content

Commit b4955a8

Browse files
Merge pull request #602 from shutter-network/feat/read-only-endpoint
added middleware to enable/disable read only endpoints
2 parents 0a536b3 + a842c31 commit b4955a8

File tree

10 files changed

+393
-17
lines changed

10 files changed

+393
-17
lines changed

rolling-shutter/keyper/kprapi/kprapi.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ type Config interface {
3333
GetHTTPListenAddress() string
3434
GetAddress() common.Address
3535
GetInstanceID() uint64
36+
GetEnableWriteOperations() bool
3637
}
3738

3839
type Server struct {
@@ -149,6 +150,8 @@ func (srv *Server) waitShutdown(ctx context.Context) error {
149150
func (srv *Server) setupAPIRouter(swagger *openapi3.T) http.Handler {
150151
router := chi.NewRouter()
151152

153+
router.Use(kproapi.ConfigMiddleware(srv.config.GetEnableWriteOperations()))
154+
152155
router.Use(chimiddleware.OapiRequestValidator(swagger))
153156
_ = kproapi.HandlerFromMux(srv, router)
154157

rolling-shutter/keyper/kprconfig/config.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type Config struct {
2323
DatabaseURL string
2424

2525
HTTPEnabled bool
26+
HTTPReadOnly bool
2627
HTTPListenAddress string
2728

2829
P2P *p2p.Config
@@ -59,6 +60,10 @@ func (c *Config) GetHTTPListenAddress() string {
5960
return c.HTTPListenAddress
6061
}
6162

63+
func (c *Config) GetEnableWriteOperations() bool {
64+
return c.HTTPEnabled && !c.HTTPReadOnly
65+
}
66+
6267
func (c *Config) GetMaxNumKeysPerMessage() uint64 {
6368
return c.MaxNumKeysPerMessage
6469
}
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
package kproapi
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
7+
"github.com/getkin/kin-openapi/openapi3"
8+
)
9+
10+
// isReadOnlyEndpoint checks if an endpoint is marked as read-only in the OpenAPI spec.
11+
func isReadOnlyEndpoint(operation *openapi3.Operation) bool {
12+
// Try to get the value directly from the map first.
13+
if val, exists := operation.Extensions["x-read-only"]; exists {
14+
// Handle json.RawMessage case.
15+
if rawMsg, ok := val.(json.RawMessage); ok {
16+
return string(rawMsg) == "true"
17+
}
18+
19+
// Handle direct boolean case.
20+
if boolVal, ok := val.(bool); ok {
21+
return boolVal
22+
}
23+
}
24+
return false
25+
}
26+
27+
// shouldEnableEndpoint determines if an endpoint should be accessible based on its type and configuration.
28+
func shouldEnableEndpoint(operation *openapi3.Operation, enableWriteOperations bool) bool {
29+
if isReadOnlyEndpoint(operation) {
30+
return true
31+
}
32+
return enableWriteOperations
33+
}
34+
35+
// findOperation looks up the OpenAPI operation for the given path and method.
36+
func findOperation(spec *openapi3.T, path string, method string) *openapi3.Operation {
37+
pathItem := spec.Paths.Find(path)
38+
if pathItem == nil {
39+
return nil
40+
}
41+
42+
switch method {
43+
case http.MethodGet:
44+
return pathItem.Get
45+
case http.MethodPost:
46+
return pathItem.Post
47+
case http.MethodPut:
48+
return pathItem.Put
49+
case http.MethodDelete:
50+
return pathItem.Delete
51+
default:
52+
return nil
53+
}
54+
}
55+
56+
// ConfigMiddleware creates a middleware that controls endpoint access based on configuration.
57+
func ConfigMiddleware(enableWriteOperations bool) MiddlewareFunc {
58+
return ConfigMiddlewareWithSpec(enableWriteOperations, GetSwagger)
59+
}
60+
61+
// ConfigMiddlewareWithSpec creates a middleware that controls endpoint access based on configuration.
62+
// This accepts a function to get the spec, making it more testable.
63+
func ConfigMiddlewareWithSpec(enableWriteOperations bool, getSpec func() (*openapi3.T, error)) MiddlewareFunc {
64+
return func(next http.Handler) http.Handler {
65+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
66+
// Load the OpenAPI specification.
67+
spec, err := getSpec()
68+
if err != nil {
69+
http.Error(w, "Internal server error", http.StatusInternalServerError)
70+
return
71+
}
72+
73+
// Find the operation for this request.
74+
operation := findOperation(spec, r.URL.Path, r.Method)
75+
if operation == nil {
76+
http.Error(w, "Endpoint not found", http.StatusNotFound)
77+
return
78+
}
79+
80+
// Check if the endpoint should be accessible.
81+
if !shouldEnableEndpoint(operation, enableWriteOperations) {
82+
http.Error(w, "Endpoint not enabled", http.StatusForbidden)
83+
return
84+
}
85+
86+
next.ServeHTTP(w, r)
87+
})
88+
}
89+
}
Lines changed: 269 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
package kproapi
2+
3+
import (
4+
"encoding/json"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
"github.com/getkin/kin-openapi/openapi3"
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
func TestIsReadOnlyEndpoint(t *testing.T) {
14+
tests := []struct {
15+
name string
16+
operation *openapi3.Operation
17+
want bool
18+
}{
19+
{
20+
name: "read-only with json.RawMessage true",
21+
operation: func() *openapi3.Operation {
22+
op := &openapi3.Operation{}
23+
op.Extensions = map[string]interface{}{
24+
"x-read-only": json.RawMessage("true"),
25+
}
26+
return op
27+
}(),
28+
want: true,
29+
},
30+
{
31+
name: "read-only with json.RawMessage false",
32+
operation: func() *openapi3.Operation {
33+
op := &openapi3.Operation{}
34+
op.Extensions = map[string]interface{}{
35+
"x-read-only": json.RawMessage("false"),
36+
}
37+
return op
38+
}(),
39+
want: false,
40+
},
41+
{
42+
name: "read-only with direct boolean true",
43+
operation: func() *openapi3.Operation {
44+
op := &openapi3.Operation{}
45+
op.Extensions = map[string]interface{}{
46+
"x-read-only": true,
47+
}
48+
return op
49+
}(),
50+
want: true,
51+
},
52+
{
53+
name: "no read-only extension",
54+
operation: func() *openapi3.Operation {
55+
op := &openapi3.Operation{}
56+
op.Extensions = map[string]interface{}{}
57+
return op
58+
}(),
59+
want: false,
60+
},
61+
}
62+
63+
for _, tt := range tests {
64+
t.Run(tt.name, func(t *testing.T) {
65+
got := isReadOnlyEndpoint(tt.operation)
66+
assert.Equal(t, tt.want, got)
67+
})
68+
}
69+
}
70+
71+
func TestShouldEnableEndpoint(t *testing.T) {
72+
tests := []struct {
73+
name string
74+
operation *openapi3.Operation
75+
enableWriteOps bool
76+
want bool
77+
}{
78+
{
79+
name: "read-only endpoint with write ops disabled",
80+
operation: func() *openapi3.Operation {
81+
op := &openapi3.Operation{}
82+
op.Extensions = map[string]interface{}{
83+
"x-read-only": json.RawMessage("true"),
84+
}
85+
return op
86+
}(),
87+
enableWriteOps: false,
88+
want: true,
89+
},
90+
{
91+
name: "write endpoint with write ops enabled",
92+
operation: func() *openapi3.Operation {
93+
op := &openapi3.Operation{}
94+
op.Extensions = map[string]interface{}{}
95+
return op
96+
}(),
97+
enableWriteOps: true,
98+
want: true,
99+
},
100+
{
101+
name: "write endpoint with write ops disabled",
102+
operation: func() *openapi3.Operation {
103+
op := &openapi3.Operation{}
104+
op.Extensions = map[string]interface{}{}
105+
return op
106+
}(),
107+
enableWriteOps: false,
108+
want: false,
109+
},
110+
}
111+
112+
for _, tt := range tests {
113+
t.Run(tt.name, func(t *testing.T) {
114+
got := shouldEnableEndpoint(tt.operation, tt.enableWriteOps)
115+
assert.Equal(t, tt.want, got)
116+
})
117+
}
118+
}
119+
120+
func TestFindOperation(t *testing.T) {
121+
// Create a test spec
122+
spec := &openapi3.T{
123+
Paths: openapi3.Paths{
124+
"/test": &openapi3.PathItem{
125+
Get: &openapi3.Operation{},
126+
Post: &openapi3.Operation{},
127+
Put: &openapi3.Operation{},
128+
Delete: &openapi3.Operation{},
129+
},
130+
},
131+
}
132+
133+
tests := []struct {
134+
name string
135+
path string
136+
method string
137+
want *openapi3.Operation
138+
}{
139+
{
140+
name: "GET operation exists",
141+
path: "/test",
142+
method: http.MethodGet,
143+
want: spec.Paths.Find("/test").Get,
144+
},
145+
{
146+
name: "POST operation exists",
147+
path: "/test",
148+
method: http.MethodPost,
149+
want: spec.Paths.Find("/test").Post,
150+
},
151+
{
152+
name: "PUT operation exists",
153+
path: "/test",
154+
method: http.MethodPut,
155+
want: spec.Paths.Find("/test").Put,
156+
},
157+
{
158+
name: "DELETE operation exists",
159+
path: "/test",
160+
method: http.MethodDelete,
161+
want: spec.Paths.Find("/test").Delete,
162+
},
163+
{
164+
name: "non-existent path",
165+
path: "/nonexistent",
166+
method: http.MethodGet,
167+
want: nil,
168+
},
169+
{
170+
name: "unsupported method",
171+
path: "/test",
172+
method: "PATCH",
173+
want: nil,
174+
},
175+
}
176+
177+
for _, tt := range tests {
178+
t.Run(tt.name, func(t *testing.T) {
179+
got := findOperation(spec, tt.path, tt.method)
180+
assert.Equal(t, tt.want, got)
181+
})
182+
}
183+
}
184+
185+
func TestConfigMiddleware(t *testing.T) {
186+
// Create a test spec with both read-only and write operations
187+
spec := &openapi3.T{
188+
Paths: openapi3.Paths{
189+
"/read": &openapi3.PathItem{
190+
Get: func() *openapi3.Operation {
191+
op := &openapi3.Operation{}
192+
op.Extensions = map[string]interface{}{
193+
"x-read-only": json.RawMessage("true"),
194+
}
195+
return op
196+
}(),
197+
},
198+
"/write": &openapi3.PathItem{
199+
Post: &openapi3.Operation{},
200+
},
201+
},
202+
}
203+
204+
// Create a function that returns our test spec.
205+
//nolint:unparam
206+
getTestSpec := func() (*openapi3.T, error) {
207+
return spec, nil
208+
}
209+
210+
tests := []struct {
211+
name string
212+
enableWriteOps bool
213+
path string
214+
method string
215+
expectedStatus int
216+
}{
217+
{
218+
name: "read-only endpoint with write ops disabled",
219+
enableWriteOps: false,
220+
path: "/read",
221+
method: http.MethodGet,
222+
expectedStatus: http.StatusOK,
223+
},
224+
{
225+
name: "write endpoint with write ops enabled",
226+
enableWriteOps: true,
227+
path: "/write",
228+
method: http.MethodPost,
229+
expectedStatus: http.StatusOK,
230+
},
231+
{
232+
name: "write endpoint with write ops disabled",
233+
enableWriteOps: false,
234+
path: "/write",
235+
method: http.MethodPost,
236+
expectedStatus: http.StatusForbidden,
237+
},
238+
{
239+
name: "non-existent endpoint",
240+
enableWriteOps: true,
241+
path: "/nonexistent",
242+
method: http.MethodGet,
243+
expectedStatus: http.StatusNotFound,
244+
},
245+
}
246+
247+
for _, tt := range tests {
248+
t.Run(tt.name, func(t *testing.T) {
249+
// Create a test handler that always returns 200 OK.
250+
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
251+
w.WriteHeader(http.StatusOK)
252+
})
253+
254+
// Create the middleware with our test spec
255+
middleware := ConfigMiddlewareWithSpec(tt.enableWriteOps, getTestSpec)
256+
handler := middleware(nextHandler)
257+
258+
// Create a test request
259+
req := httptest.NewRequest(tt.method, tt.path, http.NoBody)
260+
w := httptest.NewRecorder()
261+
262+
// Serve the request
263+
handler.ServeHTTP(w, req)
264+
265+
// Check the response
266+
assert.Equal(t, tt.expectedStatus, w.Code)
267+
})
268+
}
269+
}

0 commit comments

Comments
 (0)