Skip to content

Commit 34bab88

Browse files
committed
save freshly generated requestID to ctx
1 parent 1e7821f commit 34bab88

File tree

5 files changed

+62
-11
lines changed

5 files changed

+62
-11
lines changed

internal/audit/audit_event.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,11 @@ func GRPCCallAuditEvent(
153153
err error,
154154
) *GRPCCallEvent {
155155
s, r := getStatus(inProgress, err)
156+
requestID, _ := grpcinfo.GetRequestID(ctx)
156157
return &GRPCCallEvent{
157158
GenericAuditFields: GenericAuditFields{
158159
ID: uuid.New().String(),
159-
IdempotencyKey: grpcinfo.GetRequestID(ctx),
160+
IdempotencyKey: requestID,
160161
TraceID: formatTraceID(grpcinfo.GetTraceID(ctx)),
161162
Service: "ydbcp",
162163
SpecVersion: "1.0",

internal/audit/audit_event_test.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import (
1515
"os"
1616
"testing"
1717
"time"
18+
"ydbcp/internal/server/grpcinfo"
1819
"ydbcp/internal/types"
1920
pb "ydbcp/pkg/proto/ydbcp/v1alpha1"
2021
)
@@ -382,3 +383,20 @@ func (m *mockAddr) Network() string {
382383
func (m *mockAddr) String() string {
383384
return m.address
384385
}
386+
387+
func TestWithGRPCInfo(t *testing.T) {
388+
ctx := context.Background()
389+
ctx = grpcinfo.WithGRPCInfo(ctx)
390+
SetAuditFieldsForRequest(
391+
ctx, &AuditFields{
392+
ContainerID: "container-1",
393+
Database: "db-1",
394+
},
395+
)
396+
397+
requestID, _ := grpcinfo.GetRequestID(ctx)
398+
fields := GetAuditFieldsForRequest(requestID)
399+
require.NotNil(t, fields)
400+
require.Equal(t, "container-1", fields.ContainerID)
401+
require.Equal(t, "db-1", fields.Database)
402+
}

internal/audit/audit_interceptor.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ var (
2222
)
2323

2424
func SetAuditFieldsForRequest(ctx context.Context, fields *AuditFields) {
25-
containerStore.Store(grpcinfo.GetRequestID(ctx), fields)
25+
requestID, _ := grpcinfo.GetRequestID(ctx)
26+
containerStore.Store(requestID, fields)
2627
}
2728

2829
func GetAuditFieldsForRequest(requestID string) *AuditFields {
@@ -42,8 +43,6 @@ func NewAuditGRPCInterceptor(provider auth.AuthProvider) grpc.UnaryServerInterce
4243
ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler,
4344
) (interface{}, error) {
4445
ctx = grpcinfo.WithGRPCInfo(ctx)
45-
requestID := grpcinfo.GetRequestID(ctx)
46-
ctx = grpcinfo.SetRequestID(ctx, requestID)
4746
subject, _ := authHelper.Authenticate(ctx, provider)
4847
token, _ := authHelper.GetMaskedToken(ctx, provider)
4948
pm, ok := req.(proto.Message)
@@ -55,6 +54,7 @@ func NewAuditGRPCInterceptor(provider auth.AuthProvider) grpc.UnaryServerInterce
5554
)
5655
}
5756
response, grpcErr := handler(ctx, req)
57+
requestID, _ := grpcinfo.GetRequestID(ctx)
5858
fields := GetAuditFieldsForRequest(requestID)
5959
defer ClearAuditFieldsForRequest(requestID)
6060
ReportGRPCCallEnd(ctx, info.FullMethod, subject, fields.ContainerID, fields.Database, token, grpcErr)

internal/server/grpcinfo/grpcinfo.go

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,17 +27,17 @@ func SetRequestID(ctx context.Context, id string) context.Context {
2727
return context.WithValue(ctx, ctxKeyRequestID{}, id)
2828
}
2929

30-
func GetRequestID(ctx context.Context) string {
30+
func GetRequestID(ctx context.Context) (string, bool) {
3131
if id, ok := ctx.Value(ctxKeyRequestID{}).(string); ok {
32-
return id
32+
return id, false
3333
}
3434
for _, key := range []string{"RequestID", "RequestId", "request-id", "request_id"} {
3535
val := getFromCtx(ctx, key)
3636
if val != nil {
37-
return *val
37+
return *val, false
3838
}
3939
}
40-
return uuid.New().String()
40+
return uuid.New().String(), true
4141
}
4242

4343
func GetTraceID(ctx context.Context) *string {
@@ -76,12 +76,15 @@ func WithGRPCInfo(ctx context.Context) context.Context {
7676
if method, ok := grpc.Method(ctx); ok {
7777
ctx = xlog.With(ctx, zap.String("GRPCMethod", method))
7878
}
79-
requestID := GetRequestID(ctx)
79+
requestID, newID := GetRequestID(ctx)
8080
ctx = xlog.With(ctx, zap.String("RequestID", requestID))
8181
err := grpc.SendHeader(ctx, metadata.Pairs("X-Request-ID", requestID))
8282
if err != nil {
8383
xlog.Error(ctx, "failed to set X-Request-ID header", zap.Error(err))
8484
}
8585
xlog.Debug(ctx, "New grpc request")
86+
if newID {
87+
ctx = SetRequestID(ctx, requestID)
88+
}
8689
return ctx
8790
}

internal/server/grpcinfo/grpcinfo_test.go

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,11 @@ package grpcinfo
22

33
import (
44
"context"
5-
"testing"
6-
75
"github.com/stretchr/testify/assert"
6+
"github.com/stretchr/testify/require"
87
"google.golang.org/grpc/metadata"
8+
"testing"
9+
"ydbcp/internal/util/xlog"
910
)
1011

1112
func TestGetRemoteAddressChain(t *testing.T) {
@@ -44,3 +45,31 @@ func TestGetTraceID(t *testing.T) {
4445
ctx = context.WithValue(ctx, "trace_id", traceID)
4546
assert.Equal(t, traceID, *GetTraceID(ctx))
4647
}
48+
49+
func TestRequestID(t *testing.T) {
50+
ctx := context.Background()
51+
52+
id, generated := GetRequestID(ctx)
53+
require.True(t, generated)
54+
55+
ctx = SetRequestID(ctx, id)
56+
57+
id2, generated := GetRequestID(ctx)
58+
require.False(t, generated)
59+
require.Equal(t, id, id2)
60+
}
61+
62+
func TestWithGRPCInfo(t *testing.T) {
63+
logger, err := xlog.SetupLogging("DEBUG")
64+
require.NoError(t, err)
65+
xlog.SetInternalLogger(logger)
66+
67+
ctx := context.Background()
68+
ctx = WithGRPCInfo(ctx)
69+
70+
id, generated := GetRequestID(ctx)
71+
require.False(t, generated)
72+
id2, generated := GetRequestID(ctx)
73+
require.False(t, generated)
74+
require.Equal(t, id, id2)
75+
}

0 commit comments

Comments
 (0)