Skip to content

Commit b536d1b

Browse files
author
Siva M
authored
feat [sc 106247]: Add a caching layer to prevent SDK from sending the same payloads too frequently upstream (#193)
* cache middleware implementation
1 parent dc63c46 commit b536d1b

File tree

4 files changed

+417
-3
lines changed

4 files changed

+417
-3
lines changed

pkg/apiserver/server.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ func Start(params APIServerParams) {
5555
authRouter := r.NewRoute().Subrouter()
5656
authRouter.Use(handlers.RequireValidLicenseIDMiddleware)
5757

58+
cacheHandler := handlers.CacheMiddleware(handlers.NewCache(), handlers.CacheMiddlewareDefaultTTL)
59+
cachedRouter := r.NewRoute().Subrouter()
60+
cachedRouter.Use(cacheHandler)
61+
5862
r.HandleFunc("/healthz", handlers.Healthz)
5963

6064
// license
@@ -66,9 +70,9 @@ func Start(params APIServerParams) {
6670
r.HandleFunc("/api/v1/app/info", handlers.GetCurrentAppInfo).Methods("GET")
6771
r.HandleFunc("/api/v1/app/updates", handlers.GetAppUpdates).Methods("GET")
6872
r.HandleFunc("/api/v1/app/history", handlers.GetAppHistory).Methods("GET")
69-
r.HandleFunc("/api/v1/app/custom-metrics", handlers.SendCustomAppMetrics).Methods("POST", "PATCH")
70-
r.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.DeleteCustomAppMetricsKey).Methods("DELETE")
71-
r.HandleFunc("/api/v1/app/instance-tags", handlers.SendAppInstanceTags).Methods("POST")
73+
cachedRouter.HandleFunc("/api/v1/app/custom-metrics", handlers.SendCustomAppMetrics).Methods("POST", "PATCH")
74+
cachedRouter.HandleFunc("/api/v1/app/custom-metrics/{key}", handlers.DeleteCustomAppMetricsKey).Methods("DELETE")
75+
cachedRouter.HandleFunc("/api/v1/app/instance-tags", handlers.SendAppInstanceTags).Methods("POST")
7276

7377
// integration
7478
r.HandleFunc("/api/v1/integration/mock-data", handlers.EnforceMockAccess(handlers.PostIntegrationMockData)).Methods("POST")

pkg/handlers/middleware.go

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
package handlers
22

33
import (
4+
"bytes"
5+
"crypto/sha256"
6+
"encoding/json"
7+
"fmt"
8+
"io"
49
"net/http"
10+
"reflect"
11+
"sync"
12+
"time"
513

14+
"github.com/gorilla/mux"
15+
"github.com/pkg/errors"
616
"github.com/replicatedhq/replicated-sdk/pkg/handlers/types"
17+
"github.com/replicatedhq/replicated-sdk/pkg/logger"
718
"github.com/replicatedhq/replicated-sdk/pkg/store"
819
)
920

@@ -44,3 +55,134 @@ func RequireValidLicenseIDMiddleware(next http.Handler) http.Handler {
4455
next.ServeHTTP(w, r)
4556
})
4657
}
58+
59+
// Code for the cache middleware
60+
type CacheEntry struct {
61+
RequestBody []byte
62+
ResponseBody []byte
63+
StatusCode int
64+
Expiry time.Time
65+
}
66+
67+
type cache struct {
68+
store map[string]CacheEntry
69+
mu sync.RWMutex
70+
}
71+
72+
func NewCache() *cache {
73+
return &cache{
74+
store: map[string]CacheEntry{},
75+
}
76+
}
77+
78+
func (c *cache) Get(key string) (CacheEntry, bool) {
79+
c.mu.RLock()
80+
defer c.mu.RUnlock()
81+
82+
entry, found := c.store[key]
83+
if !found || time.Now().After(entry.Expiry) {
84+
return CacheEntry{}, false
85+
}
86+
return entry, true
87+
}
88+
89+
func (c *cache) Set(key string, entry CacheEntry, duration time.Duration) {
90+
c.mu.Lock()
91+
defer c.mu.Unlock()
92+
93+
// Clean up expired entries
94+
for k, v := range c.store {
95+
if time.Now().After(v.Expiry) {
96+
delete(c.store, k)
97+
}
98+
}
99+
100+
entry.Expiry = time.Now().Add(duration)
101+
c.store[key] = entry
102+
}
103+
104+
type responseRecorder struct {
105+
http.ResponseWriter
106+
Body *bytes.Buffer
107+
StatusCode int
108+
}
109+
110+
func (r *responseRecorder) WriteHeader(code int) {
111+
r.StatusCode = code
112+
r.ResponseWriter.WriteHeader(code)
113+
}
114+
115+
func (r *responseRecorder) Write(b []byte) (int, error) {
116+
r.Body.Write(b)
117+
return r.ResponseWriter.Write(b)
118+
}
119+
120+
const CacheMiddlewareDefaultTTL = 1 * time.Minute
121+
122+
func CacheMiddleware(cache *cache, duration time.Duration) mux.MiddlewareFunc {
123+
return func(next http.Handler) http.Handler {
124+
return cacheMiddleware(next, cache, duration)
125+
}
126+
}
127+
128+
func cacheMiddleware(next http.Handler, cache *cache, duration time.Duration) http.HandlerFunc {
129+
return func(w http.ResponseWriter, r *http.Request) {
130+
body, err := io.ReadAll(r.Body)
131+
if err != nil {
132+
logger.Error(errors.Wrap(err, "cache middleware - failed to read request body"))
133+
http.Error(w, "cache middleware: unable to read request body", http.StatusInternalServerError)
134+
return
135+
}
136+
r.Body = io.NopCloser(bytes.NewBuffer(body))
137+
138+
hash := sha256.Sum256([]byte(r.Method + "::" + r.URL.Path + "::" + r.URL.Query().Encode()))
139+
key := fmt.Sprintf("%x", hash)
140+
141+
if entry, found := cache.Get(key); found && IsSamePayload(entry.RequestBody, body) {
142+
logger.Infof("cache middleware: serving cached payload for method: %s path: %s ttl: %s ", r.Method, r.URL.Path, time.Until(entry.Expiry).Round(time.Second).String())
143+
w.Header().Set("X-Replicated-Rate-Limited", "true")
144+
JSONCached(w, entry.StatusCode, json.RawMessage(entry.ResponseBody))
145+
return
146+
}
147+
148+
recorder := &responseRecorder{ResponseWriter: w, Body: &bytes.Buffer{}}
149+
next.ServeHTTP(recorder, r)
150+
151+
// Save only successful responses in the cache
152+
if recorder.StatusCode < 200 || recorder.StatusCode >= 300 {
153+
return
154+
}
155+
156+
cache.Set(key, CacheEntry{
157+
StatusCode: recorder.StatusCode,
158+
RequestBody: body,
159+
ResponseBody: recorder.Body.Bytes(),
160+
}, duration)
161+
162+
}
163+
}
164+
165+
func IsSamePayload(a, b []byte) bool {
166+
if len(a) == 0 && len(b) == 0 {
167+
return true
168+
}
169+
170+
if len(a) == 0 {
171+
a = []byte(`{}`)
172+
}
173+
174+
if len(b) == 0 {
175+
b = []byte(`{}`)
176+
}
177+
178+
var aPayload, bPayload map[string]interface{}
179+
if err := json.Unmarshal(a, &aPayload); err != nil {
180+
logger.Error(errors.Wrap(err, "failed to unmarshal payload A"))
181+
return false
182+
}
183+
if err := json.Unmarshal(b, &bPayload); err != nil {
184+
logger.Error(errors.Wrap(err, "failed to unmarshal payload B"))
185+
return false
186+
}
187+
return reflect.DeepEqual(aPayload, bPayload)
188+
}

0 commit comments

Comments
 (0)