Skip to content

Commit 38b356b

Browse files
authored
CL64.3-2-JWT-Replay (#19544)
* make jti required and use it to prevent replay attacks. also, make exp and iat required and enforce 5 minute max token age * emit metrics and add optional field to verify method in JWT * revert changes in http_trigger_handler and change metrics name
1 parent 0a7b72f commit 38b356b

File tree

6 files changed

+361
-11
lines changed

6 files changed

+361
-11
lines changed

core/services/gateway/handlers/capabilities/v2/http_trigger_handler_test.go

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,14 +206,14 @@ func TestHttpTriggerHandler_HandleUserTriggerRequest(t *testing.T) {
206206
Method: gateway_common.MethodWorkflowExecute,
207207
Params: &rawParams,
208208
}
209-
req.Auth = createTestJWTToken(t, req, privateKey)
210-
211209
// First request should succeed
210+
req.Auth = createTestJWTToken(t, req, privateKey)
212211
mockDon.EXPECT().SendToNode(mock.Anything, mock.Anything, mock.Anything).Return(nil).Times(3)
213212
err = handler.HandleUserTriggerRequest(testutils.Context(t), req, callback1, time.Now())
214213
require.NoError(t, err)
215214

216215
// Second request with same ID should fail
216+
req.Auth = createTestJWTToken(t, req, privateKey)
217217
err = handler.HandleUserTriggerRequest(testutils.Context(t), req, callback2, time.Now())
218218
require.Error(t, err)
219219
require.Contains(t, err.Error(), "in-flight request")
@@ -223,6 +223,45 @@ func TestHttpTriggerHandler_HandleUserTriggerRequest(t *testing.T) {
223223
requireUserErrorSent(t, r, jsonrpc.ErrConflict)
224224
})
225225

226+
t.Run("duplicate JWT token and request ID", func(t *testing.T) {
227+
handler, mockDon := createTestTriggerHandler(t)
228+
privateKey := createTestPrivateKey(t)
229+
registerWorkflow(t, handler, workflowID, privateKey)
230+
callback1 := hc.NewCallback()
231+
callback2 := hc.NewCallback()
232+
233+
triggerReq := gateway_common.HTTPTriggerRequest{
234+
Workflow: gateway_common.WorkflowSelector{
235+
WorkflowID: workflowID,
236+
},
237+
Input: []byte(`{"key": "value"}`),
238+
}
239+
reqBytes, err := json.Marshal(triggerReq)
240+
require.NoError(t, err)
241+
242+
rawParams := json.RawMessage(reqBytes)
243+
req := &jsonrpc.Request[json.RawMessage]{
244+
Version: "2.0",
245+
ID: requestID,
246+
Method: gateway_common.MethodWorkflowExecute,
247+
Params: &rawParams,
248+
}
249+
// First request should succeed
250+
req.Auth = createTestJWTToken(t, req, privateKey)
251+
mockDon.EXPECT().SendToNode(mock.Anything, mock.Anything, mock.Anything).Return(nil).Times(3)
252+
err = handler.HandleUserTriggerRequest(testutils.Context(t), req, callback1, time.Now())
253+
require.NoError(t, err)
254+
255+
// Second request with same ID should fail
256+
err = handler.HandleUserTriggerRequest(testutils.Context(t), req, callback2, time.Now())
257+
require.Error(t, err)
258+
require.Contains(t, err.Error(), "token has already been used")
259+
260+
r, err := callback2.Wait(t.Context())
261+
require.NoError(t, err)
262+
requireUserErrorSent(t, r, jsonrpc.ErrInvalidRequest)
263+
})
264+
226265
t.Run("invalid input JSON", func(t *testing.T) {
227266
handler, _ := createTestTriggerHandler(t)
228267
callback := hc.NewCallback()
@@ -380,7 +419,6 @@ func TestHttpTriggerHandler_ReapExpiredCallbacks(t *testing.T) {
380419
Params: &rawParams,
381420
}
382421
privateKey := createTestPrivateKey(t)
383-
req.Auth = createTestJWTToken(t, req, privateKey)
384422
cfg := ServiceConfig{
385423
CleanUpPeriodMs: 100,
386424
MaxTriggerRequestDurationMs: 50,
@@ -389,6 +427,7 @@ func TestHttpTriggerHandler_ReapExpiredCallbacks(t *testing.T) {
389427
registerWorkflow(t, handler, workflowID, privateKey)
390428

391429
t.Run("reap expired callbacks", func(t *testing.T) {
430+
req.Auth = createTestJWTToken(t, req, privateKey)
392431
callback := hc.NewCallback()
393432
mockDon.EXPECT().SendToNode(mock.Anything, mock.Anything, mock.Anything).Return(nil).Times(3)
394433
err = handler.HandleUserTriggerRequest(testutils.Context(t), req, callback, time.Now())
@@ -413,6 +452,7 @@ func TestHttpTriggerHandler_ReapExpiredCallbacks(t *testing.T) {
413452
})
414453

415454
t.Run("keep non-expired callbacks", func(t *testing.T) {
455+
req.Auth = createTestJWTToken(t, req, privateKey)
416456
callback := hc.NewCallback()
417457

418458
mockDon.EXPECT().SendToNode(mock.Anything, mock.Anything, mock.Anything).Return(nil).Times(3)

core/services/gateway/handlers/capabilities/v2/metrics/metrics.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@ type TriggerMetrics struct {
5656
metadataRequestCount metric.Int64Counter
5757
metadataObservationsCleanUpCount metric.Int64Counter
5858
metadataObservationsCount metric.Int64Gauge
59+
jwtCacheSize metric.Int64Gauge
60+
jwtCacheCleanUpCount metric.Int64Counter
5961
}
6062

6163
// Metrics combines all gateway metrics for dependency injection
@@ -328,6 +330,22 @@ func newTriggerMetrics(meter metric.Meter) (*TriggerMetrics, error) {
328330
return nil, fmt.Errorf("failed to create workflow metadata observations count metric: %w", err)
329331
}
330332

333+
m.jwtCacheSize, err = meter.Int64Gauge(
334+
"http_trigger_jwt_cache_size",
335+
metric.WithDescription("Current number of entries in JWT replay protection cache"),
336+
)
337+
if err != nil {
338+
return nil, fmt.Errorf("failed to create HTTP trigger JWT cache size metric: %w", err)
339+
}
340+
341+
m.jwtCacheCleanUpCount, err = meter.Int64Counter(
342+
"http_trigger_jwt_cache_cleanup_count",
343+
metric.WithDescription("Number of JWT replay protection cache entries cleaned up"),
344+
)
345+
if err != nil {
346+
return nil, fmt.Errorf("failed to create HTTP trigger JWT cache cleanup count metric: %w", err)
347+
}
348+
331349
return m, nil
332350
}
333351

@@ -456,3 +474,11 @@ func (m *TriggerMetrics) IncrementMetadataObservationsCleanUpCount(ctx context.C
456474
func (m *TriggerMetrics) RecordMetadataObservationsCount(ctx context.Context, count int64, lggr logger.Logger) {
457475
m.metadataObservationsCount.Record(ctx, count)
458476
}
477+
478+
func (m *TriggerMetrics) RecordJwtCacheSize(ctx context.Context, size int64, lggr logger.Logger) {
479+
m.jwtCacheSize.Record(ctx, size)
480+
}
481+
482+
func (m *TriggerMetrics) IncrementJwtCacheCleanUpCount(ctx context.Context, count int64, lggr logger.Logger) {
483+
m.jwtCacheCleanUpCount.Add(ctx, count)
484+
}

core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler.go

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,13 @@ type workflowReference struct {
2828
workflowTag string
2929
}
3030

31+
// jwtReplayCache manages used JWT IDs to prevent replay attacks
32+
type jwtReplayCache struct {
33+
mu sync.RWMutex
34+
cleanupPeriod time.Duration
35+
cache map[string]time.Time // jti -> timestamp
36+
}
37+
3138
type WorkflowMetadataHandler struct {
3239
services.StateMachine
3340
lggr logger.Logger
@@ -41,6 +48,7 @@ type WorkflowMetadataHandler struct {
4148
donConfig *config.DONConfig
4249
stopCh services.StopChan
4350
metrics *metrics.Metrics
51+
jwtCache *jwtReplayCache // JWT replay protection cache
4452
}
4553

4654
// NewWorkflowMetadataHandler creates a new WorkflowMetadataHandler.
@@ -58,15 +66,22 @@ func NewWorkflowMetadataHandler(lggr logger.Logger, cfg ServiceConfig, don handl
5866
config: cfg,
5967
stopCh: make(services.StopChan),
6068
metrics: metrics,
69+
jwtCache: newJWTReplayCache(time.Duration(cfg.CleanUpPeriodMs) * time.Millisecond),
6170
}
6271
}
6372

6473
func (h *WorkflowMetadataHandler) Authorize(workflowID string, token string, req *jsonrpc.Request[json.RawMessage]) (*gateway.AuthorizedKey, error) {
65-
_, signer, err := utils.VerifyRequestJWT(token, *req)
74+
claims, signer, err := utils.VerifyRequestJWT(token, *req)
6675
if err != nil {
6776
h.lggr.Errorw("Failed to verify JWT", "error", err)
6877
return nil, err
6978
}
79+
80+
if h.jwtCache.isReplay(claims.ID) {
81+
h.lggr.Warnw("JWT token has already been used", "workflowID", workflowID, "signer", signer.Hex(), "jti", claims.ID)
82+
return nil, errors.New("JWT token has already been used. Please generate a new one with new id (jti)")
83+
}
84+
7085
keys, exists := h.authorizedKeys[workflowID]
7186
if !exists {
7287
h.lggr.Errorw("Workflow ID not found in authorized keys", "workflowID", workflowID)
@@ -80,6 +95,8 @@ func (h *WorkflowMetadataHandler) Authorize(workflowID string, token string, req
8095
h.lggr.Errorw("Signer not found in authorized keys", "signer", signer.Hex())
8196
return nil, errors.New("signer not found in authorized keys")
8297
}
98+
h.jwtCache.recordUsage(claims.ID)
99+
83100
return &key, nil
84101
}
85102

@@ -205,6 +222,14 @@ func (h *WorkflowMetadataHandler) Start(ctx context.Context) error {
205222
}
206223
})
207224
h.runTicker(time.Duration(h.config.MetadataAggregationIntervalMs)*time.Millisecond, h.syncMetadata)
225+
226+
h.runTicker(h.jwtCache.cleanupPeriod, func() {
227+
now := time.Now()
228+
expiredCount := h.jwtCache.cleanupOldEntries(now.Add(-h.jwtCache.cleanupPeriod))
229+
h.metrics.Trigger.IncrementJwtCacheCleanUpCount(context.Background(), int64(expiredCount), h.lggr)
230+
h.metrics.Trigger.RecordJwtCacheSize(context.Background(), int64(len(h.jwtCache.cache)), h.lggr)
231+
h.lggr.Debugw("Workflow execution cache cleanup completed", "expired_entries", expiredCount, "remaining_entries", len(h.jwtCache.cache))
232+
})
208233
return nil
209234
})
210235
}
@@ -277,3 +302,39 @@ func (h *WorkflowMetadataHandler) Close() error {
277302
return nil
278303
})
279304
}
305+
306+
func newJWTReplayCache(cleanupPeriod time.Duration) *jwtReplayCache {
307+
return &jwtReplayCache{
308+
cache: make(map[string]time.Time),
309+
cleanupPeriod: cleanupPeriod,
310+
}
311+
}
312+
313+
func (cache *jwtReplayCache) isReplay(jti string) bool {
314+
cache.mu.RLock()
315+
defer cache.mu.RUnlock()
316+
317+
_, exists := cache.cache[jti]
318+
return exists
319+
}
320+
321+
func (cache *jwtReplayCache) recordUsage(jti string) {
322+
cache.mu.Lock()
323+
defer cache.mu.Unlock()
324+
325+
cache.cache[jti] = time.Now()
326+
}
327+
328+
// cleanupOldEntries removes expired entries from the cache
329+
func (cache *jwtReplayCache) cleanupOldEntries(cutoff time.Time) int {
330+
cache.mu.Lock()
331+
defer cache.mu.Unlock()
332+
var expiredCount int
333+
for jti, createdAt := range cache.cache {
334+
if createdAt.Before(cutoff) {
335+
delete(cache.cache, jti)
336+
expiredCount++
337+
}
338+
}
339+
return expiredCount
340+
}

core/services/gateway/handlers/capabilities/v2/workflow_metadata_handler_test.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -941,6 +941,67 @@ func TestWorkflowMetadataHandler_Authorize(t *testing.T) {
941941
require.Contains(t, err.Error(), "JWT digest does not match request digest")
942942
require.Nil(t, key)
943943
})
944+
945+
t.Run("JWT replay protection", func(t *testing.T) {
946+
params := json.RawMessage(`{"test": "data"}`)
947+
req := &jsonrpc.Request[json.RawMessage]{
948+
Version: "2.0",
949+
ID: "test-request-id-replay",
950+
Method: gateway_common.MethodWorkflowExecute,
951+
Params: &params,
952+
}
953+
954+
token, err := utils.CreateRequestJWT(*req)
955+
require.NoError(t, err)
956+
957+
tokenString, err := token.SignedString(privateKey)
958+
require.NoError(t, err)
959+
960+
key, err := handler.Authorize(workflowID, tokenString, req)
961+
require.NoError(t, err)
962+
require.NotNil(t, key)
963+
964+
// Second authorization with same JWT should fail (replay attack)
965+
key, err = handler.Authorize(workflowID, tokenString, req)
966+
require.Error(t, err)
967+
require.Contains(t, err.Error(), "JWT token has already been used. Please generate a new one with new id (jti)")
968+
require.Nil(t, key)
969+
})
970+
971+
t.Run("different JWT IDs should work", func(t *testing.T) {
972+
params := json.RawMessage(`{"test": "data"}`)
973+
req1 := &jsonrpc.Request[json.RawMessage]{
974+
Version: "2.0",
975+
ID: "test-request-id-1",
976+
Method: gateway_common.MethodWorkflowExecute,
977+
Params: &params,
978+
}
979+
980+
req2 := &jsonrpc.Request[json.RawMessage]{
981+
Version: "2.0",
982+
ID: "test-request-id-2",
983+
Method: gateway_common.MethodWorkflowExecute,
984+
Params: &params,
985+
}
986+
987+
token1, err := utils.CreateRequestJWT(*req1)
988+
require.NoError(t, err)
989+
tokenString1, err := token1.SignedString(privateKey)
990+
require.NoError(t, err)
991+
992+
key1, err := handler.Authorize(workflowID, tokenString1, req1)
993+
require.NoError(t, err)
994+
require.NotNil(t, key1)
995+
996+
token2, err := utils.CreateRequestJWT(*req2)
997+
require.NoError(t, err)
998+
tokenString2, err := token2.SignedString(privateKey)
999+
require.NoError(t, err)
1000+
1001+
key2, err := handler.Authorize(workflowID, tokenString2, req2)
1002+
require.NoError(t, err)
1003+
require.NotNil(t, key2)
1004+
})
9441005
}
9451006

9461007
func TestWorkflowMetadataHandler_GetWorkflowID(t *testing.T) {

0 commit comments

Comments
 (0)