Skip to content

Commit de358b8

Browse files
authored
Extract workflow id interceptor (#8834)
## What changed? Extracting workflow id during redirection based on interface ## Why? Configurable checks if needed based on workflow id ## How did you test it? - [ ] built - [ ] run locally and tested manually - [ ] covered by existing tests - [ ] added new unit test(s) - [ ] added new functional test(s) ``` michaely520 ~/projects/temporal % TEMPORAL_TEST_LOG_LEVEL=debug go test -tags test_dep -v ./tests/xdc -run TestStreamBasedReplicationTestSuite/EnableTransitionHistory/TestCloseTransferTaskAckedReplication -timeout 5m -count=1 | grep "workflow ID extraction" 2025-12-17T08:29:04.827-0800 debug workflow ID extraction: adding workflow ID to context {"cluster-name": "active_odmrd", "host": "127.0.0.1:63201", "operation": "StartWorkflowExecution", "wf-id": "test-replication-40c67c0f-1730-4b61-9a84-fde179771032", "logging-call-at": "/Users/michaely520/projects/temporal/common/rpc/interceptor/workflow_id_interceptor.go:129"} 2025-12-17T08:29:04.837-0800 debug workflow ID extraction: adding workflow ID to context {"cluster-name": "active_odmrd", "host": "127.0.0.1:63201", "operation": "RespondWorkflowTaskCompleted", "wf-id": "test-replication-40c67c0f-1730-4b61-9a84-fde179771032", "logging-call-at": "/Users/michaely520/projects/temporal/common/rpc/interceptor/workflow_id_interceptor.go:129"} 2025-12-17T08:29:04.940-0800 debug workflow ID extraction: adding workflow ID to context {"cluster-name": "active_odmrd", "host": "127.0.0.1:63201", "operation": "DescribeWorkflowExecution", "wf-id": "test-replication-40c67c0f-1730-4b61-9a84-fde179771032", "logging-call-at": "/Users/michaely520/projects/temporal/common/rpc/interceptor/workflow_id_interceptor.go:129"} 2025-12-17T08:29:09.949-0800 debug workflow ID extraction: adding workflow ID to context {"cluster-name": "standby_odmrd", "host": "127.0.0.1:63210", "operation": "DescribeWorkflowExecution", "wf-id": "test-replication-40c67c0f-1730-4b61-9a84-fde179771032", "logging-call-at": "/Users/michaely520/projects/temporal/common/rpc/interceptor/workflow_id_interceptor.go:129"} 2025-12-17T08:29:12.061-0800 debug workflow ID extraction: adding workflow ID to context {"cluster-name": "standby_odmrd", "host": "127.0.0.1:63210", "operation": "DescribeWorkflowExecution", "wf-id": "test-replication-40c67c0f-1730-4b61-9a84-fde179771032", "logging-call-at": "/Users/michaely520/projects/temporal/common/rpc/interceptor/workflow_id_interceptor.go:129"} ```
1 parent 1c40a68 commit de358b8

File tree

8 files changed

+941
-45
lines changed

8 files changed

+941
-45
lines changed
Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
package interceptor
2+
3+
import (
4+
"context"
5+
"strings"
6+
7+
commonpb "go.temporal.io/api/common/v1"
8+
"go.temporal.io/api/workflowservice/v1"
9+
"go.temporal.io/server/common/api"
10+
"go.temporal.io/server/common/namespace"
11+
"go.temporal.io/server/common/tasktoken"
12+
)
13+
14+
type BusinessIDExtractor struct {
15+
serializer tasktoken.Serializer
16+
}
17+
18+
func NewBusinessIDExtractor() BusinessIDExtractor {
19+
return BusinessIDExtractor{
20+
serializer: *tasktoken.NewSerializer(),
21+
}
22+
}
23+
24+
// WorkflowServiceExtractor returns a BusinessIDExtractorFunc that extracts business ID
25+
// from WorkflowService API requests using the provided BusinessIDExtractor.
26+
func WorkflowServiceExtractor(extractor BusinessIDExtractor) BusinessIDExtractorFunc {
27+
return func(_ context.Context, req any, fullMethod string) string {
28+
// Only process WorkflowService APIs
29+
if !strings.HasPrefix(fullMethod, api.WorkflowServicePrefix) {
30+
return ""
31+
}
32+
33+
methodName := api.MethodName(fullMethod)
34+
pattern, hasPattern := methodToPattern[methodName]
35+
if !hasPattern {
36+
return ""
37+
}
38+
39+
return extractor.Extract(req, pattern)
40+
}
41+
}
42+
43+
// Interfaces for extracting business ID from different request types.
44+
type (
45+
workflowIDGetter interface {
46+
GetWorkflowId() string
47+
}
48+
49+
workflowExecutionGetter interface {
50+
GetWorkflowExecution() *commonpb.WorkflowExecution
51+
}
52+
53+
executionGetter interface {
54+
GetExecution() *commonpb.WorkflowExecution
55+
}
56+
57+
taskTokenGetter interface {
58+
GetTaskToken() []byte
59+
}
60+
)
61+
62+
// Extract extracts business ID from the request using the specified pattern.
63+
// Returns the business ID or namespace.EmptyBusinessID if not found.
64+
func (e BusinessIDExtractor) Extract(req any, pattern BusinessIDPattern) string {
65+
if req == nil {
66+
return namespace.EmptyBusinessID
67+
}
68+
69+
//nolint:revive // identical-switch-branches: PatternNone and default both fall through intentionally
70+
switch pattern {
71+
case PatternWorkflowID:
72+
if getter, ok := req.(workflowIDGetter); ok {
73+
return getter.GetWorkflowId()
74+
}
75+
76+
case PatternWorkflowExecution:
77+
if getter, ok := req.(workflowExecutionGetter); ok {
78+
if exec := getter.GetWorkflowExecution(); exec != nil {
79+
return exec.GetWorkflowId()
80+
}
81+
}
82+
83+
case PatternExecution:
84+
if getter, ok := req.(executionGetter); ok {
85+
if exec := getter.GetExecution(); exec != nil {
86+
return exec.GetWorkflowId()
87+
}
88+
}
89+
90+
case PatternTaskToken:
91+
if getter, ok := req.(taskTokenGetter); ok {
92+
if tokenBytes := getter.GetTaskToken(); len(tokenBytes) > 0 {
93+
if taskToken, err := e.serializer.Deserialize(tokenBytes); err == nil {
94+
return taskToken.GetWorkflowId()
95+
}
96+
}
97+
}
98+
99+
case PatternMultiOperation:
100+
return e.extractMultiOperation(req)
101+
102+
case PatternNone:
103+
// No extraction needed
104+
105+
default:
106+
// Unknown pattern
107+
}
108+
109+
return namespace.EmptyBusinessID
110+
}
111+
112+
// extractMultiOperation extracts business ID from ExecuteMultiOperationRequest.
113+
// The business ID is extracted from the first operation's StartWorkflow request.
114+
func (e BusinessIDExtractor) extractMultiOperation(req any) string {
115+
multiOpReq, ok := req.(*workflowservice.ExecuteMultiOperationRequest)
116+
if !ok {
117+
return namespace.EmptyBusinessID
118+
}
119+
120+
ops := multiOpReq.GetOperations()
121+
if len(ops) == 0 {
122+
return namespace.EmptyBusinessID
123+
}
124+
125+
firstOp := ops[0]
126+
if firstOp == nil {
127+
return namespace.EmptyBusinessID
128+
}
129+
130+
startWorkflow := firstOp.GetStartWorkflow()
131+
if startWorkflow == nil {
132+
// First operation is not StartWorkflow - try to get from UpdateWorkflow
133+
updateWorkflow := firstOp.GetUpdateWorkflow()
134+
if updateWorkflow != nil && updateWorkflow.GetWorkflowExecution() != nil {
135+
return updateWorkflow.GetWorkflowExecution().GetWorkflowId()
136+
}
137+
return namespace.EmptyBusinessID
138+
}
139+
140+
return startWorkflow.GetWorkflowId()
141+
}
Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
package interceptor
2+
3+
import (
4+
"context"
5+
6+
"go.temporal.io/server/common/log"
7+
"go.temporal.io/server/common/log/tag"
8+
"go.temporal.io/server/common/namespace"
9+
"google.golang.org/grpc"
10+
)
11+
12+
type (
13+
businessIDContextKey struct{}
14+
15+
// BusinessIDPattern defines the expected interface pattern for extracting business ID
16+
BusinessIDPattern int
17+
18+
// BusinessIDExtractorFunc extracts business ID from a request.
19+
// Returns empty string if this extractor doesn't handle the request.
20+
BusinessIDExtractorFunc func(ctx context.Context, req any, fullMethod string) string
21+
22+
// BusinessIDInterceptor extracts business ID from requests and adds it to context.
23+
// It iterates through a list of extractor functions until one returns a non-empty business ID.
24+
BusinessIDInterceptor struct {
25+
extractors []BusinessIDExtractorFunc
26+
logger log.Logger
27+
}
28+
)
29+
30+
var businessIDCtxKey = businessIDContextKey{}
31+
32+
const (
33+
// PatternNone indicates no business ID extraction is needed
34+
PatternNone BusinessIDPattern = iota
35+
// PatternWorkflowID indicates extraction via GetWorkflowId() method
36+
PatternWorkflowID
37+
// PatternWorkflowExecution indicates extraction via GetWorkflowExecution().GetWorkflowId()
38+
PatternWorkflowExecution
39+
// PatternExecution indicates extraction via GetExecution().GetWorkflowId()
40+
PatternExecution
41+
// PatternTaskToken indicates extraction via deserializing GetTaskToken()
42+
PatternTaskToken
43+
// PatternMultiOperation indicates extraction from ExecuteMultiOperationRequest
44+
PatternMultiOperation
45+
)
46+
47+
// methodToPattern maps API method names to their expected business ID extraction pattern.
48+
// Methods not in this map are treated as PatternNone (no business ID extraction needed).
49+
var methodToPattern = map[string]BusinessIDPattern{
50+
// Pattern: GetWorkflowId() - direct WorkflowId field
51+
"StartWorkflowExecution": PatternWorkflowID,
52+
"SignalWithStartWorkflowExecution": PatternWorkflowID,
53+
"PauseWorkflowExecution": PatternWorkflowID,
54+
"UnpauseWorkflowExecution": PatternWorkflowID,
55+
"RecordActivityTaskHeartbeatById": PatternWorkflowID,
56+
"RespondActivityTaskCompletedById": PatternWorkflowID,
57+
"RespondActivityTaskCanceledById": PatternWorkflowID,
58+
"RespondActivityTaskFailedById": PatternWorkflowID,
59+
60+
// Pattern: GetWorkflowExecution().GetWorkflowId()
61+
"DeleteWorkflowExecution": PatternWorkflowExecution,
62+
"RequestCancelWorkflowExecution": PatternWorkflowExecution,
63+
"ResetWorkflowExecution": PatternWorkflowExecution,
64+
"SignalWorkflowExecution": PatternWorkflowExecution,
65+
"TerminateWorkflowExecution": PatternWorkflowExecution,
66+
"UpdateWorkflowExecution": PatternWorkflowExecution,
67+
"UpdateWorkflowExecutionOptions": PatternWorkflowExecution,
68+
69+
// Pattern: GetExecution().GetWorkflowId()
70+
"DescribeWorkflowExecution": PatternExecution,
71+
"GetWorkflowExecutionHistory": PatternExecution,
72+
"GetWorkflowExecutionHistoryReverse": PatternExecution,
73+
"QueryWorkflow": PatternExecution,
74+
"ResetStickyTaskQueue": PatternExecution,
75+
"ResetActivity": PatternExecution,
76+
"PauseActivity": PatternExecution,
77+
"UnpauseActivity": PatternExecution,
78+
"UpdateActivityOptions": PatternExecution,
79+
"TriggerWorkflowRule": PatternExecution,
80+
81+
// Pattern: TaskToken deserialization
82+
"RecordActivityTaskHeartbeat": PatternTaskToken,
83+
"RespondActivityTaskCompleted": PatternTaskToken,
84+
"RespondActivityTaskCanceled": PatternTaskToken,
85+
"RespondActivityTaskFailed": PatternTaskToken,
86+
"RespondWorkflowTaskCompleted": PatternTaskToken,
87+
"RespondWorkflowTaskFailed": PatternTaskToken,
88+
89+
// Pattern: ExecuteMultiOperation special handling
90+
"ExecuteMultiOperation": PatternMultiOperation,
91+
}
92+
93+
// NewBusinessIDInterceptor creates a new BusinessIDInterceptor with the given extractor functions.
94+
// Extractors are called in order until one returns a non-empty business ID.
95+
func NewBusinessIDInterceptor(
96+
extractors []BusinessIDExtractorFunc,
97+
logger log.Logger,
98+
) *BusinessIDInterceptor {
99+
return &BusinessIDInterceptor{
100+
extractors: extractors,
101+
logger: logger,
102+
}
103+
}
104+
105+
// WithExtractors returns a new interceptor with additional extractors prepended.
106+
// The new extractors will be tried before the existing ones.
107+
func (i *BusinessIDInterceptor) WithExtractors(extractors ...BusinessIDExtractorFunc) *BusinessIDInterceptor {
108+
return &BusinessIDInterceptor{
109+
extractors: append(extractors, i.extractors...),
110+
logger: i.logger,
111+
}
112+
}
113+
114+
var _ grpc.UnaryServerInterceptor = (*BusinessIDInterceptor)(nil).Intercept
115+
116+
// Intercept extracts business ID from the request and adds it to the context.
117+
// It tries each extractor in order until one returns a non-empty business ID.
118+
func (i *BusinessIDInterceptor) Intercept(
119+
ctx context.Context,
120+
req any,
121+
info *grpc.UnaryServerInfo,
122+
handler grpc.UnaryHandler,
123+
) (any, error) {
124+
// Try each extractor until one returns a non-empty businessID
125+
for _, extractor := range i.extractors {
126+
if businessID := extractor(ctx, req, info.FullMethod); businessID != "" {
127+
i.logger.Debug("business ID extraction: adding business ID to context",
128+
tag.WorkflowID(businessID),
129+
tag.NewStringTag("grpc-method", info.FullMethod),
130+
)
131+
ctx = AddBusinessIDToContext(ctx, businessID)
132+
break
133+
}
134+
}
135+
136+
return handler(ctx, req)
137+
}
138+
139+
// AddBusinessIDToContext adds the business ID to the context
140+
func AddBusinessIDToContext(ctx context.Context, businessID string) context.Context {
141+
return context.WithValue(ctx, businessIDCtxKey, businessID)
142+
}
143+
144+
// GetBusinessIDFromContext retrieves the business ID from the context.
145+
// Returns namespace.EmptyBusinessID if not found.
146+
func GetBusinessIDFromContext(ctx context.Context) string {
147+
if businessID, ok := ctx.Value(businessIDCtxKey).(string); ok {
148+
return businessID
149+
}
150+
return namespace.EmptyBusinessID
151+
}

0 commit comments

Comments
 (0)