diff --git a/extra/redisotel/config.go b/extra/redisotel/config.go index 6d90abfd0..62b3c9bc2 100644 --- a/extra/redisotel/config.go +++ b/extra/redisotel/config.go @@ -1,6 +1,9 @@ package redisotel import ( + "strings" + + "github.com/redis/go-redis/v9" "go.opentelemetry.io/otel" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/metric" @@ -21,6 +24,7 @@ type config struct { dbStmtEnabled bool callerEnabled bool + filter func(cmd redis.Cmder) bool // Metrics options. @@ -124,6 +128,37 @@ func WithCallerEnabled(on bool) TracingOption { }) } +// WithCommandFilter allows filtering of commands when tracing to omit commands that may have sensitive details like +// passwords. +func WithCommandFilter(filter func(cmd redis.Cmder) bool) TracingOption { + return tracingOption(func(conf *config) { + conf.filter = filter + }) +} + +func BasicCommandFilter(cmd redis.Cmder) bool { + if strings.ToLower(cmd.Name()) == "auth" { + return true + } + + if strings.ToLower(cmd.Name()) == "hello" { + if len(cmd.Args()) < 3 { + return false + } + + arg, exists := cmd.Args()[2].(string) + if !exists { + return false + } + + if strings.ToLower(arg) == "auth" { + return true + } + } + + return false +} + //------------------------------------------------------------------------------ type MetricsOption interface { diff --git a/extra/redisotel/tracing.go b/extra/redisotel/tracing.go index 40df5a202..5c91710c6 100644 --- a/extra/redisotel/tracing.go +++ b/extra/redisotel/tracing.go @@ -102,6 +102,12 @@ func (th *tracingHook) DialHook(hook redis.DialHook) redis.DialHook { func (th *tracingHook) ProcessHook(hook redis.ProcessHook) redis.ProcessHook { return func(ctx context.Context, cmd redis.Cmder) error { + // Check if the command should be filtered out + if th.conf.filter != nil && th.conf.filter(cmd) { + // If so, just call the next hook + return hook(ctx, cmd) + } + attrs := make([]attribute.KeyValue, 0, 8) if th.conf.callerEnabled { fn, file, line := funcFileLine("github.com/redis/go-redis") diff --git a/extra/redisotel/tracing_test.go b/extra/redisotel/tracing_test.go index a3e3ccc62..0ae70c2d8 100644 --- a/extra/redisotel/tracing_test.go +++ b/extra/redisotel/tracing_test.go @@ -95,6 +95,138 @@ func TestWithoutCaller(t *testing.T) { } } +func TestWithCommandFilter(t *testing.T) { + + t.Run("filter out ping command", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandFilter(func(cmd redis.Cmder) bool { + return cmd.Name() == "ping" + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmd := redis.NewCmd(ctx, "ping") + defer span.End() + + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis-test" || innerSpan.Name() == "ping" { + t.Fatalf("ping command should not be traced") + } + + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("do not filter ping command", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandFilter(func(cmd redis.Cmder) bool { + return false // never filter + }), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmd := redis.NewCmd(ctx, "ping") + defer span.End() + + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "ping" { + t.Fatalf("ping command should be traced") + } + + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("auth command filtered with basic command filter", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandFilter(BasicCommandFilter), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmd := redis.NewCmd(ctx, "auth", "test-password") + defer span.End() + + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis-test" || innerSpan.Name() == "auth" { + t.Fatalf("auth command should not be traced by default") + } + + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("hello command filtered with basic command filter when sensitive", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandFilter(BasicCommandFilter), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmd := redis.NewCmd(ctx, "hello", 3, "AUTH", "test-user", "test-password") + defer span.End() + + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "redis-test" || innerSpan.Name() == "hello" { + t.Fatalf("auth command should not be traced by default") + } + + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) + + t.Run("hello command not filtered with basic command filter when not sensitive", func(t *testing.T) { + provider := sdktrace.NewTracerProvider() + hook := newTracingHook( + "", + WithTracerProvider(provider), + WithCommandFilter(BasicCommandFilter), + ) + ctx, span := provider.Tracer("redis-test").Start(context.TODO(), "redis-test") + cmd := redis.NewCmd(ctx, "hello", 3) + defer span.End() + + processHook := hook.ProcessHook(func(ctx context.Context, cmd redis.Cmder) error { + innerSpan := trace.SpanFromContext(ctx).(sdktrace.ReadOnlySpan) + if innerSpan.Name() != "hello" { + t.Fatalf("hello command should be traced") + } + + return nil + }) + err := processHook(ctx, cmd) + if err != nil { + t.Fatal(err) + } + }) +} + func TestTracingHook_DialHook(t *testing.T) { imsb := tracetest.NewInMemoryExporter() provider := sdktrace.NewTracerProvider(sdktrace.WithSyncer(imsb))