diff --git a/agent/graphagent/graph_agent.go b/agent/graphagent/graph_agent.go index ee57d31cd..3789ea6e6 100644 --- a/agent/graphagent/graph_agent.go +++ b/agent/graphagent/graph_agent.go @@ -274,6 +274,10 @@ func (ga *GraphAgent) createInitialState(ctx context.Context, invocation *agent. } // Add parent agent to state so agent nodes can access sub-agents. initialState[graph.StateKeyParentAgent] = ga + // Set checkpoint namespace if not already set. + if ns, ok := initialState[graph.CfgKeyCheckpointNS].(string); !ok || ns == "" { + initialState[graph.CfgKeyCheckpointNS] = ga.name + } return initialState } diff --git a/docs/mkdocs/en/graph.md b/docs/mkdocs/en/graph.md index 21ba439ad..bd02d8151 100644 --- a/docs/mkdocs/en/graph.md +++ b/docs/mkdocs/en/graph.md @@ -2424,6 +2424,7 @@ import ( "trpc.group/trpc-go/trpc-agent-go/agent/graphagent" "trpc.group/trpc-go/trpc-agent-go/graph" "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/sqlite" + "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/redis" "trpc.group/trpc-go/trpc-agent-go/model" ) @@ -2436,6 +2437,22 @@ graphAgent, _ := graphagent.New("workflow", g, // Checkpoints are saved automatically during execution (by default every step) +// Resume from a checkpoint +eventCh, err := r.Run(ctx, userID, sessionID, + model.NewUserMessage("resume"), + agent.WithRuntimeState(map[string]any{ + graph.CfgKeyCheckpointID: "ckpt-123", + }), +) + +// Configure redis checkpoints +redisSaver, _ := redis.NewSaver(redis.WithRedisClientURL("redis://[username:password@]host:port[/database]")) + +graphAgent, _ := graphagent.New("workflow", g, + graphagent.WithCheckpointSaver(redisSaver)) + +// Checkpoints are saved automatically during execution (by default every step) + // Resume from a checkpoint eventCh, err := r.Run(ctx, userID, sessionID, model.NewUserMessage("resume"), diff --git a/docs/mkdocs/zh/graph.md b/docs/mkdocs/zh/graph.md index 90d31f3c2..49ffdf480 100644 --- a/docs/mkdocs/zh/graph.md +++ b/docs/mkdocs/zh/graph.md @@ 
-2393,7 +2393,7 @@ API 参考: ### 检查点与恢复 -为了支持时间旅行与可靠恢复,可以为执行器或 GraphAgent 配置检查点保存器。下面演示使用 SQLite Saver 持久化检查点并从特定检查点恢复。 +为了支持时间旅行与可靠恢复,可以为执行器或 GraphAgent 配置检查点保存器。下面演示使用 SQLite/Redis Saver 持久化检查点并从特定检查点恢复。 ```go import ( @@ -2404,6 +2404,7 @@ import ( "trpc.group/trpc-go/trpc-agent-go/agent/graphagent" "trpc.group/trpc-go/trpc-agent-go/graph" "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/sqlite" + "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/redis" "trpc.group/trpc-go/trpc-agent-go/model" ) @@ -2416,6 +2417,22 @@ graphAgent, _ := graphagent.New("workflow", g, // 执行时自动保存检查点(默认每步保存) +// 从检查点恢复 +eventCh, err := r.Run(ctx, userID, sessionID, + model.NewUserMessage("resume"), + agent.WithRuntimeState(map[string]any{ + graph.CfgKeyCheckpointID: "ckpt-123", + }), +) + +// 配置redis检查点 +redisSaver, _ := redis.NewSaver(redis.WithRedisClientURL("redis://[username:password@]host:port[/database]")) + +graphAgent, _ := graphagent.New("workflow", g, + graphagent.WithCheckpointSaver(redisSaver)) + +// 执行时自动保存检查点(默认每步保存) + // 从检查点恢复 eventCh, err := r.Run(ctx, userID, sessionID, model.NewUserMessage("resume"), diff --git a/examples/go.mod b/examples/go.mod index ad9e00305..6d1ee84ff 100644 --- a/examples/go.mod +++ b/examples/go.mod @@ -4,6 +4,8 @@ go 1.23.0 replace trpc.group/trpc-go/trpc-agent-go => ../ +replace trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/redis => ../graph/checkpoint/redis + require ( github.com/google/uuid v1.6.0 github.com/mattn/go-sqlite3 v1.14.32 @@ -14,13 +16,16 @@ require ( go.uber.org/zap v1.27.0 trpc.group/trpc-go/trpc-a2a-go v0.2.5 trpc.group/trpc-go/trpc-agent-go v0.5.0 + trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/redis v0.0.0-00010101000000-000000000000 trpc.group/trpc-go/trpc-mcp-go v0.0.10 ) require ( github.com/bmatcuk/doublestar/v4 v4.9.1 // indirect github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 // indirect + 
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/getkin/kin-openapi v0.124.0 // indirect github.com/go-logr/logr v1.4.3 // indirect github.com/go-logr/stdr v1.2.2 // indirect @@ -43,6 +48,7 @@ require ( github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 // indirect github.com/panjf2000/ants/v2 v2.10.0 // indirect github.com/perimeterx/marshmallow v1.1.5 // indirect + github.com/redis/go-redis/v9 v9.17.0 // indirect github.com/segmentio/asm v1.2.0 // indirect github.com/spaolacci/murmur3 v1.1.0 // indirect github.com/tidwall/gjson v1.14.4 // indirect @@ -72,4 +78,5 @@ require ( google.golang.org/grpc v1.67.0 // indirect google.golang.org/protobuf v1.34.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + trpc.group/trpc-go/trpc-agent-go/storage/redis v0.5.0 // indirect ) diff --git a/examples/go.sum b/examples/go.sum index d2474a4aa..479789868 100644 --- a/examples/go.sum +++ b/examples/go.sum @@ -1,12 +1,22 @@ +github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= +github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= github.com/bmatcuk/doublestar/v4 v4.9.1 h1:X8jg9rRZmJd4yRy7ZeNDRnM+T3ZfHv15JiBJ/avrEXE= github.com/bmatcuk/doublestar/v4 v4.9.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod 
h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0 h1:NMZiJj8QnKe1LgsbDayM4UoHwbvwDRwnI3hwNaAHRnc= github.com/decred/dcrd/dcrec/secp256k1/v4 v4.4.0/go.mod h1:ZXNYxsqcloTdSy/rNShjYzMhyjf0LaoftYK0p+A3h40= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= github.com/getkin/kin-openapi v0.124.0 h1:VSFNMB9C9rTKBnQ/fpyDU8ytMTr4dWI9QovSKj9kz/M= github.com/getkin/kin-openapi v0.124.0/go.mod h1:wb1aSZA/iWmorQP9KTAS/phLj/t17B5jT7+fS8ed9NM= github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= @@ -69,6 +79,8 @@ github.com/perimeterx/marshmallow v1.1.5 h1:a2LALqQ1BlHM8PZblsDdidgv1mWi1DgC2UmX github.com/perimeterx/marshmallow v1.1.5/go.mod h1:dsXbUu8CRzfYP5a87xpp0xq9S3u0Vchtcl8we9tYaXw= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.17.0 h1:K6E+ZlYN95KSMmZeEQPbU/c++wfmEvfFB17yEAq/VhM= +github.com/redis/go-redis/v9 v9.17.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= github.com/segmentio/asm v1.2.0 h1:9BQrFxC+YOHJlTlHGkTrFWf59nbL3XnCoFLTwDCI7ys= @@ -98,6 +110,8 @@ github.com/ugorji/go/codec v1.2.7 h1:YPXUKf7fYbp/y8xloBqZOw2qaVggbfwMlI8WM3wZUJ0 github.com/ugorji/go/codec v1.2.7/go.mod 
h1:WGN1fab3R1fzQlVQTkfxVtIBhWDRqOviHU95kRgeqEY= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= go.opentelemetry.io/otel v1.38.0 h1:RkfdswUDRimDg0m2Az18RKOsnI8UDzppJAtj01/Ymk8= @@ -160,5 +174,7 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= trpc.group/trpc-go/trpc-a2a-go v0.2.5 h1:X3pAlWD128LaS9TtXsUDZoJWPVuPZDkZKUecKRxmWn4= trpc.group/trpc-go/trpc-a2a-go v0.2.5/go.mod h1:Gtytau9Uoc3oPo/dpHvKit+tQn9Qlk5XFG1RiZTGqfk= +trpc.group/trpc-go/trpc-agent-go/storage/redis v0.5.0 h1:zuElT5t+ESMDvZzXI3rDrzg5FYOc4RPxwdWg+AQX8Ao= +trpc.group/trpc-go/trpc-agent-go/storage/redis v0.5.0/go.mod h1:aDGUkqlbGttFVGlZW2lMg5mAgurHezq5De/gAuLhV5E= trpc.group/trpc-go/trpc-mcp-go v0.0.10 h1:kKPfevmikMojfOgtUBf5SJQ/v6aDugckodgyH1uDu2Q= trpc.group/trpc-go/trpc-mcp-go v0.0.10/go.mod h1:OT6rLglkdaQ17D2T1Y87Y/ckItzdsEldDbw7dHAbGEA= diff --git a/examples/graph/checkpoint/README.md b/examples/graph/checkpoint/README.md index ff1eada40..3b2cd7b13 100644 --- a/examples/graph/checkpoint/README.md +++ b/examples/graph/checkpoint/README.md @@ -59,7 +59,7 @@ go build . 
### Command-Line Flags - `-model`: LLM model to use (default: "deepseek-chat") -- `-storage`: Storage backend - "memory" or "sqlite" (default: "memory") +- `-storage`: Storage backend - "memory", "sqlite", or "redis" (default: "memory") - `-db`: SQLite database file path (default: "checkpoints.db", only used with -storage=sqlite) - `-verbose`: Enable verbose execution output (default: false) @@ -345,4 +345,4 @@ Set these environment variables before running: 1. **SQLite Schema**: The SQLite checkpoint saver may require schema updates if you encounter "no such column: seq" errors. The schema needs to be updated to match the latest checkpoint structure. -2. **Concurrent Access**: The in-memory checkpoint saver uses maps that may have concurrent access issues in high-throughput scenarios. Use proper synchronization if needed. \ No newline at end of file +2. **Concurrent Access**: The in-memory checkpoint saver uses maps that may have concurrent access issues in high-throughput scenarios. Use proper synchronization if needed. 
diff --git a/examples/graph/checkpoint/main.go b/examples/graph/checkpoint/main.go index 4e12c0df6..655bf6c1d 100644 --- a/examples/graph/checkpoint/main.go +++ b/examples/graph/checkpoint/main.go @@ -34,6 +34,7 @@ import ( "trpc.group/trpc-go/trpc-agent-go/event" "trpc.group/trpc-go/trpc-agent-go/graph" checkpointinmemory "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/inmemory" + checkpointredis "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/redis" checkpointsqlite "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/sqlite" agentlog "trpc.group/trpc-go/trpc-agent-go/log" "trpc.group/trpc-go/trpc-agent-go/model" @@ -83,9 +84,11 @@ var ( modelName = flag.String("model", defaultModelName, "Name of the model to use") storage = flag.String("storage", "memory", - "Storage type: 'memory' or 'sqlite'") + "Storage type: 'memory' or 'sqlite' or 'redis'") dbPath = flag.String("db", defaultDBPath, "Path to SQLite database file (only used with -storage=sqlite)") + redisClientURL = flag.String("redis-url", "redis://localhost:6379", + "Redis client URL (only used with -storage=redis)") verbose = flag.Bool("verbose", false, "Enable verbose output") ) @@ -106,10 +109,12 @@ func main() { // Create and run the workflow. 
workflow := &checkpointWorkflow{ - modelName: *modelName, - storageType: *storage, - dbPath: *dbPath, - verbose: *verbose, + modelName: *modelName, + storageType: *storage, + dbPath: *dbPath, + verbose: *verbose, + redisClientURL: *redisClientURL, + currentNamespace: "checkpoint-demo", } if err := workflow.run(); err != nil { log.Fatalf("Workflow failed: %v", err) @@ -121,6 +126,7 @@ type checkpointWorkflow struct { modelName string storageType string dbPath string + redisClientURL string verbose bool logger agentlog.Logger runner runner.Runner @@ -171,6 +177,12 @@ func (w *checkpointWorkflow) setup() error { w.saver = saver case "memory": w.saver = checkpointinmemory.NewSaver() + case "redis": + saver, err := checkpointredis.NewSaver(checkpointredis.WithRedisClientURL(w.redisClientURL)) + if err != nil { + return fmt.Errorf("failed to create Redis saver: %w", err) + } + w.saver = saver default: return fmt.Errorf("unsupported storage type: %s", w.storageType) } @@ -553,7 +565,6 @@ func (w *checkpointWorkflow) startInteractiveMode(ctx context.Context) error { func (w *checkpointWorkflow) runWorkflow(ctx context.Context, lineageID string) error { startTime := time.Now() w.currentLineageID = lineageID - w.currentNamespace = "" // Use empty namespace to align with LangGraph's design w.logger.Infof("Starting workflow execution: lineage_id=%s, namespace=%s", lineageID, w.currentNamespace) @@ -897,6 +908,7 @@ func (w *checkpointWorkflow) listCheckpoints(ctx context.Context, lineageID stri // Create config for the lineage. config := graph.NewCheckpointConfig(lineageID) + config.Namespace = "checkpoint-demo" // List checkpoints with a filter. 
manager := w.graphAgent.Executor().CheckpointManager() diff --git a/examples/graph/interrupt/main.go b/examples/graph/interrupt/main.go index 2d742f4b7..3ac6f407c 100644 --- a/examples/graph/interrupt/main.go +++ b/examples/graph/interrupt/main.go @@ -32,6 +32,7 @@ import ( "trpc.group/trpc-go/trpc-agent-go/agent/graphagent" "trpc.group/trpc-go/trpc-agent-go/graph" checkpointinmemory "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/inmemory" + checkpointredis "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/redis" checkpointsqlite "trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/sqlite" agentlog "trpc.group/trpc-go/trpc-agent-go/log" "trpc.group/trpc-go/trpc-agent-go/model" @@ -46,7 +47,6 @@ const ( defaultAppName = "interrupt-workflow" defaultDBPath = "interrupt-checkpoints.db" defaultLineagePrefix = "interrupt-demo" - defaultNamespace = "" // State keys for the workflow. stateKeyCounter = "counter" @@ -93,9 +93,11 @@ var ( modelName = flag.String("model", defaultModelName, "Name of the model to use") storage = flag.String("storage", "memory", - "Storage type: 'memory' or 'sqlite'") + "Storage type: 'memory' or 'sqlite' or 'redis'") dbPath = flag.String("db", defaultDBPath, "Path to SQLite database file (only used with -storage=sqlite)") + redisClientURL = flag.String("redis-url", "redis://localhost:6379", + "Redis client URL (only used with -storage=redis)") verbose = flag.Bool("verbose", false, "Enable verbose output") interactiveMode = flag.Bool("interactive", true, @@ -109,6 +111,7 @@ type interruptWorkflow struct { modelName string storageType string dbPath string + redisClientURL string verbose bool logger agentlog.Logger runner runner.Runner @@ -153,7 +156,7 @@ func main() { time.Now().Unix()) } workflow.sessionID = fmt.Sprintf("session-%d", time.Now().Unix()) - workflow.currentNamespace = defaultNamespace + workflow.currentNamespace = "interrupt-demo" if err := workflow.run(); err != nil { fmt.Printf("❌ Workflow failed: %v\n", err) @@ -207,6 +210,12 
@@ func (w *interruptWorkflow) setup() error { w.saver = saver case "memory": w.saver = checkpointinmemory.NewSaver() + case "redis": + saver, err := checkpointredis.NewSaver(checkpointredis.WithRedisClientURL(w.redisClientURL)) + if err != nil { + return fmt.Errorf("failed to create Redis saver: %w", err) + } + w.saver = saver default: return fmt.Errorf("unsupported storage type: %s", w.storageType) } @@ -637,7 +646,6 @@ func (w *interruptWorkflow) startInteractiveMode(ctx context.Context) error { func (w *interruptWorkflow) runWorkflow(ctx context.Context, lineageID string, waitForInterrupt bool) error { startTime := time.Now() w.currentLineageID = lineageID - w.currentNamespace = defaultNamespace w.logger.Infof("Starting workflow execution: lineage_id=%s, namespace=%s, wait_for_interrupt=%v", lineageID, w.currentNamespace, waitForInterrupt) @@ -737,7 +745,6 @@ func (w *interruptWorkflow) runWorkflow(ctx context.Context, lineageID string, w // resumeWorkflow resumes execution from a checkpoint. func (w *interruptWorkflow) resumeWorkflow(ctx context.Context, lineageID, checkpointID, userInput string) error { w.currentLineageID = lineageID - w.currentNamespace = defaultNamespace // Check if the lineage exists before attempting resume. 
config := graph.NewCheckpointConfig(lineageID).WithNamespace(w.currentNamespace) diff --git a/graph/checkpoint/redis/go.mod b/graph/checkpoint/redis/go.mod new file mode 100644 index 000000000..4b0856293 --- /dev/null +++ b/graph/checkpoint/redis/go.mod @@ -0,0 +1,49 @@ +module trpc.group/trpc-go/trpc-agent-go/graph/checkpoint/redis + +go 1.21 + +replace ( + trpc.group/trpc-go/trpc-agent-go => ../../../ + trpc.group/trpc-go/trpc-agent-go/storage/redis => ../../../storage/redis +) + +require ( + github.com/alicebob/miniredis/v2 v2.35.0 + github.com/google/uuid v1.6.0 + github.com/redis/go-redis/v9 v9.17.0 + github.com/stretchr/testify v1.10.0 + trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251023030722-7f02b57fd14a + trpc.group/trpc-go/trpc-agent-go v0.0.0-00010101000000-000000000000 + trpc.group/trpc-go/trpc-agent-go/storage/redis v0.5.0 +) + +require ( + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/go-logr/logr v1.4.2 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + github.com/spaolacci/murmur3 v1.1.0 // indirect + github.com/yuin/gopher-lua v1.1.1 // indirect + go.opentelemetry.io/otel v1.29.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.29.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.29.0 // indirect + go.opentelemetry.io/otel/metric v1.29.0 // indirect + go.opentelemetry.io/otel/sdk v1.29.0 // indirect + go.opentelemetry.io/otel/trace v1.29.0 // indirect + go.opentelemetry.io/proto/otlp v1.3.1 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.uber.org/zap v1.27.0 // indirect + golang.org/x/net 
v0.34.0 // indirect + golang.org/x/sys v0.30.0 // indirect + golang.org/x/text v0.21.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20240822170219-fc7c04adadcd // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd // indirect + google.golang.org/grpc v1.65.0 // indirect + google.golang.org/protobuf v1.34.2 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/graph/checkpoint/redis/go.sum b/graph/checkpoint/redis/go.sum new file mode 100644 index 000000000..49674d839 --- /dev/null +++ b/graph/checkpoint/redis/go.sum @@ -0,0 +1,90 @@ +github.com/alicebob/miniredis/v2 v2.35.0 h1:QwLphYqCEAo1eu1TqPRN2jgVMPBweeQcR21jeqDCONI= +github.com/alicebob/miniredis/v2 v2.35.0/go.mod h1:TcL7YfarKPGDAthEtl5NBeHZfeUQj6OXMm/+iu5cLMM= +github.com/bmatcuk/doublestar/v4 v4.9.1 h1:X8jg9rRZmJd4yRy7ZeNDRnM+T3ZfHv15JiBJ/avrEXE= +github.com/bmatcuk/doublestar/v4 v4.9.1/go.mod h1:xBQ8jztBU6kakFMg+8WGxn0c6z1fTSPVIjEY1Wr7jzc= +github.com/bsm/ginkgo/v2 v2.12.0 h1:Ny8MWAHyOepLGlLKYmXG4IEkioBysk6GpaRTLC8zwWs= +github.com/bsm/ginkgo/v2 v2.12.0/go.mod h1:SwYbGRRDovPVboqFv0tPTcG1sN61LM1Z4ARdbAV9g4c= +github.com/bsm/gomega v1.27.10 h1:yeMWxP2pV2fG3FgAODIY8EiRE3dy0aeFYt4l7wh6yKA= +github.com/bsm/gomega v1.27.10/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous 
v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= +github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 h1:asbCHRVmodnJTuQ3qamDwqVOIjwqUPTYmYuemVOx+Ys= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0/go.mod h1:ggCgvZ2r7uOoQjOyu2Y1NhHmEPPzzuhWgcza5M1Ji1I= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/mattn/go-sqlite3 v1.14.32 h1:JD12Ag3oLy1zQA+BNn74xRgaBbdhbNIDYvQUEuuErjs= +github.com/mattn/go-sqlite3 v1.14.32/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.17.0 h1:K6E+ZlYN95KSMmZeEQPbU/c++wfmEvfFB17yEAq/VhM= +github.com/redis/go-redis/v9 v9.17.0/go.mod h1:u410H11HMLoB+TP67dz8rL9s6QW2j76l0//kSOd3370= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= 
+github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= +github.com/spaolacci/murmur3 v1.1.0 h1:7c1g84S4BPRrfL5Xrdp6fOJ206sU9y293DDHaoy0bLI= +github.com/spaolacci/murmur3 v1.1.0/go.mod h1:JwIasOWyU6f++ZhiEuf87xNszmSA2myDM2Kzu9HwQUA= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= +github.com/yuin/gopher-lua v1.1.1/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= +go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= +go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.29.0 h1:dIIDULZJpgdiHz5tXrTgKIMLkus6jEFa7x5SOKcyR7E= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.29.0/go.mod h1:jlRVBe7+Z1wyxFSUs48L6OBQZ5JwH2Hg/Vbl+t9rAgI= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0 h1:nSiV3s7wiCam610XcLbYOmMfJxB9gO4uK3Xgv5gmTgg= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc v1.29.0/go.mod h1:hKn/e/Nmd19/x1gvIHwtOwVWM+VhuITSWip3JUDghj0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.29.0 h1:JAv0Jwtl01UFiyWZEMiJZBiTlv5A50zNs8lsthXqIio= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.29.0/go.mod h1:QNKLmUEAq2QUbPQUfvw4fmv0bgbK7UlOSFCnXyfvSNc= +go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= +go.opentelemetry.io/otel/metric v1.29.0/go.mod h1:auu/QWieFVWx+DmQOUMgj0F8LHWdgalxXqvp7BII/W8= +go.opentelemetry.io/otel/sdk v1.29.0 h1:vkqKjk7gwhS8VaWb0POZKmIEDimRCMsopNYnriHyryo= +go.opentelemetry.io/otel/sdk v1.29.0/go.mod h1:pM8Dx5WKnvxLCb+8lG1PRNIDxu9g9b9g59Qr7hfAAok= +go.opentelemetry.io/otel/sdk/metric v1.29.0 h1:K2CfmJohnRgvZ9UAj2/FhIf/okdWcNdBwe1m8xFXiSY= 
+go.opentelemetry.io/otel/sdk/metric v1.29.0/go.mod h1:6zZLdCl2fkauYoZIOn/soQIDSWFmNSRcICarHfuhNJQ= +go.opentelemetry.io/otel/trace v1.29.0 h1:J/8ZNK4XgR7a21DZUAsbF8pZ5Jcw1VhACmnYt39JTi4= +go.opentelemetry.io/otel/trace v1.29.0/go.mod h1:eHl3w0sp3paPkYstJOmAimxhiFXPg+MMTlEh3nsQgWQ= +go.opentelemetry.io/proto/otlp v1.3.1 h1:TrMUixzpM0yuc/znrFTP9MMRh8trP93mkCiDVeXrui0= +go.opentelemetry.io/proto/otlp v1.3.1/go.mod h1:0X1WI4de4ZsLrrJNLAQbFeLCm3T7yBkR0XqQ7niQU+8= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/sys v0.30.0 h1:QjkSwP/36a20jFYWkSue1YwXzLmsV5Gfq7Eiy72C1uc= +golang.org/x/sys v0.30.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= +google.golang.org/genproto/googleapis/api v0.0.0-20240822170219-fc7c04adadcd h1:BBOTEWLuuEGQy9n1y9MhVJ9Qt0BDu21X8qZs71/uPZo= +google.golang.org/genproto/googleapis/api v0.0.0-20240822170219-fc7c04adadcd/go.mod h1:fO8wJzT2zbQbAjbIoos1285VfEIYKDDY+Dt+WpTkh6g= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd h1:6TEm2ZxXoQmFWFlt1vNxvVOa1Q0dXFQD1m/rYjXmS0E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20240822170219-fc7c04adadcd/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU= +google.golang.org/grpc v1.65.0 
h1:bs/cUb4lp1G5iImFFd3u5ixQzweKizoZJAwBNLR42lc= +google.golang.org/grpc v1.65.0/go.mod h1:WgYC2ypjlB0EiQi6wdKixMqukr6lBc0Vo+oOgjrM5ZQ= +google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= +google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251023030722-7f02b57fd14a h1:dOon6HF2sPRFnhCLEiAeKPc21JHL2eX7UBWjIR8PLaY= +trpc.group/trpc-go/trpc-a2a-go v0.2.5-0.20251023030722-7f02b57fd14a/go.mod h1:Gtytau9Uoc3oPo/dpHvKit+tQn9Qlk5XFG1RiZTGqfk= diff --git a/graph/checkpoint/redis/option.go b/graph/checkpoint/redis/option.go new file mode 100644 index 000000000..de4fd49e5 --- /dev/null +++ b/graph/checkpoint/redis/option.go @@ -0,0 +1,69 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +// Package redis provides Redis-based checkpoint storage implementation +// for graph execution state persistence and recovery. +package redis + +import "time" + +const ( + defaultTTL = time.Hour * 24 * 7 // 7 days +) + +var ( + defaultOptions = Options{ + ttl: defaultTTL, + } +) + +// Options holds the configuration for the redis checkpoint service. +type Options struct { + url string + instanceName string + extraOptions []any + ttl time.Duration +} + +// Option is a functional option that configures the redis checkpoint service. 
+type Option func(*Options) + +// WithRedisClientURL sets the URL from which the redis client is built. +func WithRedisClientURL(url string) Option { + return func(opts *Options) { + opts.url = url + } +} + +// WithRedisInstance uses a redis instance from storage. +// Note: WithRedisClientURL has higher priority than WithRedisInstance. +// If both are specified, WithRedisClientURL will be used. +func WithRedisInstance(instanceName string) Option { + return func(opts *Options) { + opts.instanceName = instanceName + } +} + +// WithExtraOptions appends extra options for the redis checkpoint service. +// This option is mainly used with a customized redis client builder; the values are passed to the builder. +func WithExtraOptions(extraOptions ...any) Option { + return func(opts *Options) { + opts.extraOptions = append(opts.extraOptions, extraOptions...) + } +} + +// WithTTL sets the TTL for the checkpoint data in redis; a non-positive ttl falls back to the default TTL. +func WithTTL(ttl time.Duration) Option { + return func(opts *Options) { + if ttl <= 0 { + ttl = defaultTTL + } + opts.ttl = ttl + } +} diff --git a/graph/checkpoint/redis/option_test.go b/graph/checkpoint/redis/option_test.go new file mode 100644 index 000000000..b656310f4 --- /dev/null +++ b/graph/checkpoint/redis/option_test.go @@ -0,0 +1,110 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +// Package redis provides Redis-based checkpoint storage implementation +// for graph execution state persistence and recovery. 
+package redis + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func TestWithRedisInstance(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "valid instance name", + input: "test-instance", + expected: "test-instance", + }, + { + name: "empty instance name", + input: "", + expected: "", + }, + { + name: "instance name with special characters", + input: "test-instance-123", + expected: "test-instance-123", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := Options{} + WithRedisInstance(tt.input)(&opts) + assert.Equal(t, tt.expected, opts.instanceName) + }) + } +} + +func TestWithExtraOptions(t *testing.T) { + tests := []struct { + name string + input []any + expected []any + }{ + { + name: "single option", + input: []any{"option1"}, + expected: []any{"option1"}, + }, + { + name: "multiple options", + input: []any{"option1", 123, true}, + expected: []any{"option1", 123, true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := Options{} + WithExtraOptions(tt.input...)(&opts) + assert.Equal(t, tt.expected, opts.extraOptions) + }) + } +} + +func TestWithTTL(t *testing.T) { + tests := []struct { + name string + input time.Duration + expected time.Duration + }{ + { + name: "valid TTL", + input: time.Hour * 48, + expected: time.Hour * 48, + }, + { + name: "zero TTL", + input: 0, + expected: defaultTTL, + }, + { + name: "negative TTL", + input: -time.Hour, + expected: defaultTTL, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + opts := Options{} + WithTTL(tt.input)(&opts) + assert.Equal(t, tt.expected, opts.ttl) + }) + } +} diff --git a/graph/checkpoint/redis/saver.go b/graph/checkpoint/redis/saver.go new file mode 100644 index 000000000..cad27895e --- /dev/null +++ b/graph/checkpoint/redis/saver.go @@ -0,0 +1,609 @@ +// +// Tencent is pleased to support the open source community by 
// Redis key prefixes and hash field names used by the saver.
const (
	// Key prefixes, one per kind of redis key.
	keyPrefixCheckpoint   = "ckpt:"       // hash holding one checkpoint
	keyPrefixCheckpointTS = "ckpt_ts:"    // sorted set of checkpoint IDs scored by timestamp
	keyPrefixWrites       = "writes:"     // hash holding the pending writes of one checkpoint
	keyPrefixLineageNS    = "lineage_ns:" // set of namespaces seen for a lineage

	// Field names inside a checkpoint hash.
	// NOTE(review): "lingeage" is a misspelling of "lineage"; the identifier is
	// kept as-is because other code in this package refers to it by this name.
	lingeageIDKey         = "lineage_id"
	checkpointIDKey       = "checkpoint_id"
	checkpointNSKey       = "checkpoint_ns"
	parentCheckpointIDKey = "parent_checkpoint_id"
	tsKey                 = "ts"
	checkpointJSONKey     = "checkpoint_json"
	metadataJSONKey       = "metadata_json"
)

// checkpointKey builds the key of the hash that stores a single checkpoint:
// "ckpt:<lineage>:<namespace>:<checkpoint>".
func checkpointKey(lineageID, checkpointNS, checkpointID string) string {
	return fmt.Sprintf("%s:%s:%s", keyPrefixCheckpoint+lineageID, checkpointNS, checkpointID)
}

// checkpointTSKey builds the key of the timestamp index for a lineage and
// namespace. An empty namespace yields "ckpt_ts:<lineage>" with no trailing
// separator; otherwise the key is "ckpt_ts:<lineage>:<namespace>".
func checkpointTSKey(lineageID, checkpointNS string) string {
	key := keyPrefixCheckpointTS + lineageID
	if checkpointNS != "" {
		key = fmt.Sprintf("%s:%s", key, checkpointNS)
	}
	return key
}

// writesKey builds the key of the hash that stores the pending writes of a
// checkpoint: "writes:<lineage>:<namespace>:<checkpoint>".
func writesKey(lineageID, checkpointNS, checkpointID string) string {
	return fmt.Sprintf("%s:%s:%s", keyPrefixWrites+lineageID, checkpointNS, checkpointID)
}

// lineageNSKey builds the key of the set that records a lineage's namespaces:
// "lineage_ns:<lineage>".
func lineageNSKey(lineageID string) string {
	return keyPrefixLineageNS + lineageID
}

// writeData is the JSON form of one pending write as stored in a writes hash.
type writeData struct {
	TaskID    string `json:"task_id"`
	Idx       int    `json:"idx"`        // position of the write within its request
	Channel   string `json:"channel"`
	ValueJSON []byte `json:"value_json"` // JSON encoding of the written value
	TaskPath  string `json:"task_path"`
	Seq       int64  `json:"seq"` // ordering key used when writes are loaded
}
+type Saver struct { + opts Options + client redis.UniversalClient + once sync.Once // ensure Close is called only once +} + +// NewSaver creates a new saver. +func NewSaver(options ...Option) (*Saver, error) { + opts := defaultOptions + for _, option := range options { + option(&opts) + } + + builderOpts := []storage.ClientBuilderOpt{ + storage.WithClientBuilderURL(opts.url), + storage.WithExtraOptions(opts.extraOptions...), + } + + // if instance name set, and url not set, use instance name to create redis client + if opts.url == "" && opts.instanceName != "" { + var ok bool + if builderOpts, ok = storage.GetRedisInstance(opts.instanceName); !ok { + return nil, fmt.Errorf("redis instance %s not found", opts.instanceName) + } + } + + redisClient, err := storage.GetClientBuilder()(builderOpts...) + if err != nil { + return nil, fmt.Errorf("create redis client from url failed: %w", err) + } + + s := &Saver{ + opts: opts, + client: redisClient, + } + return s, nil +} + +// Get returns the checkpoint for the given config. +func (s *Saver) Get(ctx context.Context, config map[string]any) (*graph.Checkpoint, error) { + t, err := s.GetTuple(ctx, config) + if err != nil { + return nil, err + } + if t == nil { + return nil, nil + } + return t.Checkpoint, nil +} + +// GetTuple returns the checkpoint tuple for the given config. 
+func (s *Saver) GetTuple(ctx context.Context, config map[string]any) (*graph.CheckpointTuple, error) { + lineageID := graph.GetLineageID(config) + checkpointNS := graph.GetNamespace(config) + checkpointID := graph.GetCheckpointID(config) + + if lineageID == "" { + return nil, errors.New("lineage_id is required") + } + + checkpointID, err := s.findCheckpointID(ctx, lineageID, checkpointNS, checkpointID) + if err != nil { + return nil, err + } + if checkpointID == "" { + return nil, nil + } + + checkpointData, err := s.client.HGetAll(ctx, checkpointKey(lineageID, checkpointNS, checkpointID)).Result() + if err != nil { + return nil, fmt.Errorf("get checkpoint data: %w", err) + } + if len(checkpointData) == 0 { + return nil, nil + } + + var ckpt graph.Checkpoint + if err := json.Unmarshal([]byte(checkpointData["checkpoint_json"]), &ckpt); err != nil { + return nil, fmt.Errorf("unmarshal checkpoint: %w", err) + } + + var meta graph.CheckpointMetadata + if err := json.Unmarshal([]byte(checkpointData["metadata_json"]), &meta); err != nil { + return nil, fmt.Errorf("unmarshal metadata: %w", err) + } + + parentID := checkpointData[parentCheckpointIDKey] + ts, err := strconv.ParseInt(checkpointData["ts"], 10, 64) + if err != nil { + return nil, fmt.Errorf("parse timestamp: %w", err) + } + + writes, err := s.loadWrites(ctx, lineageID, checkpointNS, checkpointID) + if err != nil { + return nil, err + } + + var parentCfg map[string]any + if parentID != "" { + parentNS, err := s.findCheckpointNamespace(ctx, lineageID, parentID) + if err != nil { + return nil, err + } + parentCfg = graph.CreateCheckpointConfig(lineageID, parentID, parentNS) + } + + returnCfg := graph.CreateCheckpointConfig(lineageID, checkpointID, checkpointNS) + if ts > 0 { + ckpt.Timestamp = time.Unix(0, ts) + } + + return &graph.CheckpointTuple{ + Config: returnCfg, + Checkpoint: &ckpt, + Metadata: &meta, + ParentConfig: parentCfg, + PendingWrites: writes, + }, nil +} + +func (s *Saver) findCheckpointID(ctx 
context.Context, lineageID, checkpointNS, checkpointID string) (string, error) { + if checkpointID != "" { + return checkpointID, nil + } + // Find a latest checkpoint in the namespace. + key := checkpointTSKey(lineageID, checkpointNS) + members, err := s.client.ZRevRange(ctx, key, 0, 0).Result() + if err != nil { + return "", err + } + if len(members) == 0 { + return "", nil + } + return members[0], nil +} + +// List returns checkpoints for the lineage/namespace, with optional filters. +func (s *Saver) List(ctx context.Context, config map[string]any, filter *graph.CheckpointFilter) ([]*graph.CheckpointTuple, error) { + lineageID := graph.GetLineageID(config) + checkpointNS := graph.GetNamespace(config) + if lineageID == "" { + return nil, errors.New("lineage_id is required") + } + + checkpointIDs, err := s.getCheckpointIDs(ctx, lineageID, checkpointNS, filter) + if err != nil { + return nil, err + } + + var tuples []*graph.CheckpointTuple + for _, checkpointID := range checkpointIDs { + cfg := graph.CreateCheckpointConfig(lineageID, checkpointID, checkpointNS) + tuple, err := s.GetTuple(ctx, cfg) + if err != nil { + return nil, err + } + if tuple == nil { + continue + } + + if filter != nil && len(filter.Metadata) > 0 { + if tuple.Metadata == nil || tuple.Metadata.Extra == nil { + continue + } + matches := true + for key, value := range filter.Metadata { + if tuple.Metadata.Extra[key] != value { + matches = false + break + } + } + if !matches { + continue + } + } + tuples = append(tuples, tuple) + if filter != nil && filter.Limit > 0 && len(tuples) >= filter.Limit { + break + } + } + + return tuples, nil +} + +func (s *Saver) getCheckpointIDs(ctx context.Context, lineageID, checkpointNS string, filter *graph.CheckpointFilter) ([]string, error) { + key := checkpointTSKey(lineageID, checkpointNS) + var members []string + var err error + + if filter != nil && filter.Before != nil { + beforeID := graph.GetCheckpointID(filter.Before) + if beforeID != "" { + 
beforeScore, err := s.getCheckpointScore(ctx, lineageID, checkpointNS, beforeID) + if err != nil { + return nil, err + } + if beforeScore > 0 { + members, err = s.client.ZRangeByScore(ctx, key, &redis.ZRangeBy{ + Min: "0", + Max: fmt.Sprintf("(%d", beforeScore), + }).Result() + } + } + } + + if members == nil { + members, err = s.client.ZRevRange(ctx, key, 0, -1).Result() + } + if err != nil { + return nil, err + } + + var checkpointIDs []string + for _, id := range members { + if id == "" { + log.Warnf("invalid checkpoint id format: %s", id) + continue + } + checkpointIDs = append(checkpointIDs, id) + } + + return checkpointIDs, nil +} + +func (s *Saver) getCheckpointScore(ctx context.Context, lineageID, checkpointNS, checkpointID string) (int64, error) { + key := checkpointTSKey(lineageID, checkpointNS) + score, err := s.client.ZScore(ctx, key, checkpointID).Result() + if err != nil { + return 0, err + } + return int64(score), nil +} + +// Put stores the checkpoint and returns the updated config with checkpoint ID. 
+func (s *Saver) Put(ctx context.Context, req graph.PutRequest) (map[string]any, error) { + if req.Checkpoint == nil { + return nil, errors.New("checkpoint cannot be nil") + } + + lineageID := graph.GetLineageID(req.Config) + checkpointNS := graph.GetNamespace(req.Config) + if lineageID == "" { + return nil, errors.New("lineage_id is required") + } + + checkpointJSON, err := json.Marshal(req.Checkpoint) + if err != nil { + return nil, fmt.Errorf("marshal checkpoint: %w", err) + } + + if req.Metadata == nil { + req.Metadata = &graph.CheckpointMetadata{Source: graph.CheckpointSourceUpdate, Step: 0} + } + metadataJSON, err := json.Marshal(req.Metadata) + if err != nil { + return nil, fmt.Errorf("marshal metadata: %w", err) + } + + pipe := s.client.TxPipeline() + + checkpointID := req.Checkpoint.ID + ts := req.Checkpoint.Timestamp.UnixNano() + if ts <= 0 { + ts = time.Now().UTC().UnixNano() + } + + checkpointKey := checkpointKey(lineageID, checkpointNS, checkpointID) + pipe.HSet(ctx, checkpointKey, + lingeageIDKey, lineageID, + checkpointNSKey, checkpointNS, + checkpointIDKey, checkpointID, + parentCheckpointIDKey, req.Checkpoint.ParentCheckpointID, + tsKey, ts, + checkpointJSONKey, checkpointJSON, + metadataJSONKey, metadataJSON, + ) + pipe.Expire(ctx, checkpointKey, s.opts.ttl) + + tsKey := checkpointTSKey(lineageID, checkpointNS) + pipe.ZAdd(ctx, tsKey, redis.Z{ + Score: float64(ts), + Member: checkpointID, + }) + pipe.Expire(ctx, tsKey, s.opts.ttl) + + nsKey := lineageNSKey(lineageID) + pipe.SAdd(ctx, nsKey, checkpointNS) + pipe.Expire(ctx, nsKey, s.opts.ttl) + + if _, err := pipe.Exec(ctx); err != nil { + return nil, fmt.Errorf("redis transaction failed: %w", err) + } + + return graph.CreateCheckpointConfig(lineageID, checkpointID, checkpointNS), nil +} + +func (s *Saver) PutWrites(ctx context.Context, req graph.PutWritesRequest) error { + lineageID := graph.GetLineageID(req.Config) + checkpointNS := graph.GetNamespace(req.Config) + checkpointID := 
graph.GetCheckpointID(req.Config) + if lineageID == "" || checkpointID == "" { + return errors.New("lineage_id and checkpoint_id are required") + } + + pipe := s.client.Pipeline() + + writeKey := writesKey(lineageID, checkpointNS, checkpointID) + + for idx, w := range req.Writes { + valueJSON, err := json.Marshal(w.Value) + if err != nil { + return fmt.Errorf("marshal write: %w", err) + } + + seq := w.Sequence + if seq == 0 { + seq = int64(idx) + } + + writeData := writeData{ + TaskID: req.TaskID, + Idx: idx, + Channel: w.Channel, + ValueJSON: valueJSON, + TaskPath: req.TaskPath, + Seq: seq, + } + + field := fmt.Sprintf("%s:%d", req.TaskID, idx) + writeJSON, _ := json.Marshal(writeData) + pipe.HSet(ctx, writeKey, field, writeJSON) + } + pipe.Expire(ctx, writeKey, s.opts.ttl) + + _, err := pipe.Exec(ctx) + return err +} + +// PutFull atomically stores a checkpoint with its pending writes in a single transaction. +func (s *Saver) PutFull(ctx context.Context, req graph.PutFullRequest) (map[string]any, error) { + lineageID := graph.GetLineageID(req.Config) + checkpointNS := graph.GetNamespace(req.Config) + if lineageID == "" { + return nil, errors.New("lineage_id is required") + } + if req.Checkpoint == nil { + return nil, errors.New("checkpoint cannot be nil") + } + + checkpointJSON, err := json.Marshal(req.Checkpoint) + if err != nil { + return nil, fmt.Errorf("marshal checkpoint: %w", err) + } + + metadataJSON, err := json.Marshal(req.Metadata) + if err != nil { + return nil, fmt.Errorf("marshal metadata: %w", err) + } + + pipe := s.client.TxPipeline() + + checkpointID := req.Checkpoint.ID + ts := req.Checkpoint.Timestamp.UnixNano() + if ts <= 0 { + ts = time.Now().UTC().UnixNano() + } + + checkpointKey := checkpointKey(lineageID, checkpointNS, checkpointID) + pipe.HSet(ctx, checkpointKey, + lingeageIDKey, lineageID, + checkpointNSKey, checkpointNS, + checkpointIDKey, checkpointID, + parentCheckpointIDKey, req.Checkpoint.ParentCheckpointID, + tsKey, ts, + 
checkpointJSONKey, checkpointJSON, + metadataJSONKey, metadataJSON, + ) + pipe.Expire(ctx, checkpointKey, s.opts.ttl) + + tsKey := checkpointTSKey(lineageID, checkpointNS) + pipe.ZAdd(ctx, tsKey, redis.Z{ + Score: float64(ts), + Member: checkpointID, + }) + pipe.Expire(ctx, tsKey, s.opts.ttl) + + nsKey := lineageNSKey(lineageID) + pipe.SAdd(ctx, nsKey, checkpointNS) + pipe.Expire(ctx, nsKey, s.opts.ttl) + + writeKey := writesKey(lineageID, checkpointNS, checkpointID) + for idx, w := range req.PendingWrites { + valueJSON, err := json.Marshal(w.Value) + if err != nil { + return nil, fmt.Errorf("marshal write value: %w", err) + } + + seq := w.Sequence + if seq == 0 { + seq = time.Now().UnixNano() + } + + writeData := writeData{ + TaskID: w.TaskID, + Idx: idx, + Channel: w.Channel, + ValueJSON: valueJSON, + TaskPath: "", + Seq: seq, + } + + field := fmt.Sprintf("%s:%d", w.TaskID, idx) + writeJSON, err := json.Marshal(writeData) + if err != nil { + return nil, fmt.Errorf("marshal write data: %w", err) + } + pipe.HSet(ctx, writeKey, field, writeJSON) + } + pipe.Expire(ctx, writeKey, s.opts.ttl) + + if _, err := pipe.Exec(ctx); err != nil { + return nil, fmt.Errorf("redis transaction failed: %w", err) + } + + return graph.CreateCheckpointConfig(lineageID, checkpointID, checkpointNS), nil +} + +// DeleteLineage deletes all checkpoints and writes for the lineage. 
+func (s *Saver) DeleteLineage(ctx context.Context, lineageID string) error { + if lineageID == "" { + return errors.New("lineage_id is required") + } + + nsKey := lineageNSKey(lineageID) + namespaces, err := s.client.SMembers(ctx, nsKey).Result() + if err != nil { + return err + } + pipe := s.client.Pipeline() + + for _, ns := range namespaces { + tsKey := checkpointTSKey(lineageID, ns) + members, err := s.client.ZRange(ctx, tsKey, 0, -1).Result() + if err != nil { + continue + } + + for _, member := range members { + checkpointID := member + + ckptKey := checkpointKey(lineageID, ns, checkpointID) + pipe.Del(ctx, ckptKey) + + writeKey := writesKey(lineageID, ns, checkpointID) + pipe.Del(ctx, writeKey) + } + + pipe.Del(ctx, tsKey) + } + + pipe.Del(ctx, nsKey) + + _, err = pipe.Exec(ctx) + return err +} + +func (s *Saver) loadWrites(ctx context.Context, lineageID, checkpointNS, checkpointID string) ([]graph.PendingWrite, error) { + writeKey := writesKey(lineageID, checkpointNS, checkpointID) + writeMap, err := s.client.HGetAll(ctx, writeKey).Result() + if err != nil { + return nil, fmt.Errorf("get writes: %w", err) + } + + var writes []graph.PendingWrite + for _, writeJSON := range writeMap { + var writeData writeData + if err := json.Unmarshal([]byte(writeJSON), &writeData); err != nil { + continue + } + var value any + if err := json.Unmarshal(writeData.ValueJSON, &value); err != nil { + continue + } + + writes = append(writes, graph.PendingWrite{ + TaskID: writeData.TaskID, + Channel: writeData.Channel, + Value: value, + Sequence: writeData.Seq, + }) + } + + sort.Slice(writes, func(i, j int) bool { + return writes[i].Sequence < writes[j].Sequence + }) + + return writes, nil +} + +func (s *Saver) findCheckpointNamespace(ctx context.Context, lineageID, checkpointID string) (string, error) { + if checkpointID == "" || lineageID == "" { + return "", nil + } + + nsKey := lineageNSKey(lineageID) + namespaces, err := s.client.SMembers(ctx, nsKey).Result() + if err != 
nil { + return "", err + } + + for _, ns := range namespaces { + exists, err := s.client.Exists(ctx, checkpointKey(lineageID, ns, checkpointID)).Result() + if err != nil { + continue + } + if exists > 0 { + return ns, nil + } + } + + return "", nil +} + +// Close closes the service. +func (s *Saver) Close() error { + s.once.Do(func() { + // Close redis connection. + if s.client != nil { + s.client.Close() + } + }) + + return nil +} diff --git a/graph/checkpoint/redis/saver_test.go b/graph/checkpoint/redis/saver_test.go new file mode 100644 index 000000000..7716c9001 --- /dev/null +++ b/graph/checkpoint/redis/saver_test.go @@ -0,0 +1,1386 @@ +// +// Tencent is pleased to support the open source community by making trpc-agent-go available. +// +// Copyright (C) 2025 Tencent. All rights reserved. +// +// trpc-agent-go is licensed under the Apache License Version 2.0. +// +// + +// Package redis provides Redis-based checkpoint storage implementation +// for graph execution state persistence and recovery. 
+package redis + +import ( + "context" + "encoding/json" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/google/uuid" + "github.com/redis/go-redis/v9" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "trpc.group/trpc-go/trpc-agent-go/graph" + storage "trpc.group/trpc-go/trpc-agent-go/storage/redis" +) + +func setupTestRedis(t testing.TB) (string, func()) { + mr, err := miniredis.Run() + require.NoError(t, err) + cleanup := func() { + mr.Close() + } + return "redis://" + mr.Addr(), cleanup +} + +func buildRedisClient(t *testing.T, redisURL string) *redis.Client { + opts, err := redis.ParseURL(redisURL) + require.NoError(t, err) + return redis.NewClient(opts) +} + +func TestNewSaverWithRedisInstance_buildSuccess(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + const ( + name = "test-instance" + ) + + defer cleanup() + + storage.RegisterRedisInstance(name, storage.WithClientBuilderURL(redisURL)) + opts, ok := storage.GetRedisInstance(name) + require.True(t, ok, "expected instance to exist") + require.NotEmpty(t, opts, "expected at least one option") + + saver, err := NewSaver(WithRedisInstance(name)) + require.NoError(t, err) + defer saver.Close() +} + +func TestNewSaverWithRedisInstance_buildFailed(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + const ( + name = "test-instance" + ) + + defer cleanup() + + storage.RegisterRedisInstance(name, storage.WithClientBuilderURL(redisURL)) + opts, ok := storage.GetRedisInstance(name) + require.True(t, ok, "expected instance to exist") + require.NotEmpty(t, opts, "expected at least one option") + + saver, err := NewSaver(WithRedisInstance("no-instance")) + require.Error(t, err) + require.Nil(t, saver) +} + +func TestNewSaverWithRedisOption_Error(t *testing.T) { + saver, err := NewSaver(WithRedisClientURL("")) + require.Error(t, err) + require.Nil(t, saver) +} + +func TestRedisCheckpointSaver(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) 
+ defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "test-lineage" + config := graph.CreateCheckpointConfig(lineageID, "", "") + + // Create a checkpoint. + checkpoint := graph.NewCheckpoint( + map[string]any{"counter": 1}, + map[string]int64{"counter": 1}, + map[string]map[string]int64{}, + ) + metadata := graph.NewCheckpointMetadata(graph.CheckpointSourceInput, -1) + + // Store checkpoint. + req := graph.PutRequest{ + Config: config, + Checkpoint: checkpoint, + Metadata: metadata, + NewVersions: map[string]int64{"counter": 1}, + } + updatedConfig, err := saver.Put(ctx, req) + require.NoError(t, err) + + // Verify updated config contains checkpoint ID. + checkpointID := graph.GetCheckpointID(updatedConfig) + assert.NotEmpty(t, checkpointID) + + // Retrieve checkpoint. + retrieved, err := saver.Get(ctx, updatedConfig) + require.NoError(t, err) + require.NotNil(t, retrieved) + + assert.NotEmpty(t, retrieved.ID) + // JSON unmarshaling converts integers to float64, so compare values properly. + assert.Equal(t, len(checkpoint.ChannelValues), len(retrieved.ChannelValues)) + for key, expectedVal := range checkpoint.ChannelValues { + actualVal, exists := retrieved.ChannelValues[key] + assert.True(t, exists, "Key %s should exist", key) + // Compare as float64 since JSON unmarshaling converts numbers to float64. + assert.Equal(t, float64(expectedVal.(int)), actualVal) + } + + // Test retrieving tuple. 
+ tuple, err := saver.GetTuple(ctx, updatedConfig) + require.NoError(t, err) + require.NotNil(t, tuple) + + assert.NotEmpty(t, tuple.Checkpoint.ID) + assert.Equal(t, metadata.Source, tuple.Metadata.Source) + assert.Equal(t, metadata.Step, tuple.Metadata.Step) +} + +func TestRedisCheckpointSaverList(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "test-lineage" + config := graph.CreateCheckpointConfig(lineageID, "", "") + + // Create multiple checkpoints. + for i := 0; i < 3; i++ { + checkpoint := graph.NewCheckpoint( + map[string]any{"step": i}, + map[string]int64{"step": int64(i + 1)}, + map[string]map[string]int64{}, + ) + metadata := graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, i) + + req := graph.PutRequest{ + Config: config, + Checkpoint: checkpoint, + Metadata: metadata, + NewVersions: map[string]int64{"step": int64(i + 1)}, + } + _, err := saver.Put(ctx, req) + require.NoError(t, err) + } + + // List checkpoints. + checkpoints, err := saver.List(ctx, config, nil) + require.NoError(t, err) + assert.Len(t, checkpoints, 3) + + // Test filtering by limit. + filter := &graph.CheckpointFilter{Limit: 2} + limited, err := saver.List(ctx, config, filter) + require.NoError(t, err) + assert.Len(t, limited, 2) +} + +func TestRedisCheckpointSaverWrites(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "test-lineage" + config := graph.CreateCheckpointConfig(lineageID, "", "") + + // Create a checkpoint first. 
+ checkpoint := graph.NewCheckpoint( + map[string]any{"counter": 0}, + map[string]int64{"counter": 1}, + map[string]map[string]int64{}, + ) + metadata := graph.NewCheckpointMetadata(graph.CheckpointSourceInput, -1) + + req := graph.PutRequest{ + Config: config, + Checkpoint: checkpoint, + Metadata: metadata, + NewVersions: map[string]int64{"counter": 1}, + } + updatedConfig, err := saver.Put(ctx, req) + require.NoError(t, err) + + // Store writes. + writes := []graph.PendingWrite{ + {Channel: "counter", Value: 42}, + {Channel: "message", Value: "hello"}, + } + + writeReq := graph.PutWritesRequest{ + Config: updatedConfig, + Writes: writes, + TaskID: "task1", + TaskPath: "", + } + err = saver.PutWrites(ctx, writeReq) + require.NoError(t, err) + + // Retrieve tuple and verify writes. + tuple, err := saver.GetTuple(ctx, updatedConfig) + require.NoError(t, err) + require.NotNil(t, tuple) + + assert.Len(t, tuple.PendingWrites, 2) + assert.Equal(t, "counter", tuple.PendingWrites[0].Channel) + assert.Equal(t, float64(42), tuple.PendingWrites[0].Value) + assert.Equal(t, "message", tuple.PendingWrites[1].Channel) + assert.Equal(t, "hello", tuple.PendingWrites[1].Value) +} + +func TestRedisCheckpointSaverDeleteLineage(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "test-lineage" + config := graph.CreateCheckpointConfig(lineageID, "", "") + + // Create a checkpoint. 
+ checkpoint := graph.NewCheckpoint( + map[string]any{"counter": 42}, + map[string]int64{"counter": 1}, + map[string]map[string]int64{}, + ) + metadata := graph.NewCheckpointMetadata(graph.CheckpointSourceInput, -1) + + req := graph.PutRequest{ + Config: config, + Checkpoint: checkpoint, + Metadata: metadata, + NewVersions: map[string]int64{"counter": 1}, + } + updatedConfig, err := saver.Put(ctx, req) + require.NoError(t, err) + + // Verify checkpoint exists. + retrieved, err := saver.Get(ctx, updatedConfig) + require.NoError(t, err) + assert.NotNil(t, retrieved) + + // Delete lineage. + err = saver.DeleteLineage(ctx, lineageID) + require.NoError(t, err) + + // Verify checkpoint is gone. + retrieved, err = saver.Get(ctx, updatedConfig) + require.NoError(t, err) + assert.Nil(t, retrieved) +} + +func TestRedisCheckpointSaverLatestCheckpoint(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "test-lineage" + config := graph.CreateCheckpointConfig(lineageID, "", "") + + // Create multiple checkpoints. + var checkpointIDs []string + for i := 0; i < 3; i++ { + // Add small delay to ensure different timestamps. + if i > 0 { + time.Sleep(10 * time.Millisecond) + } + checkpoint := graph.NewCheckpoint( + map[string]any{"step": i}, + map[string]int64{"step": int64(i + 1)}, + map[string]map[string]int64{}, + ) + metadata := graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, i) + + req := graph.PutRequest{ + Config: config, + Checkpoint: checkpoint, + Metadata: metadata, + NewVersions: map[string]int64{"step": int64(i + 1)}, + } + updatedConfig, err := saver.Put(ctx, req) + require.NoError(t, err) + + checkpointID := graph.GetCheckpointID(updatedConfig) + checkpointIDs = append(checkpointIDs, checkpointID) + } + + // Get latest checkpoint (should be the last one created). 
+ latest, err := saver.Get(ctx, config) + require.NoError(t, err) + require.NotNil(t, latest) + + // Debug: print what we got + t.Logf("Expected ID: %s, Got ID: %s", checkpointIDs[2], latest.ID) + t.Logf("Expected step: 2, Got step: %v", latest.ChannelValues["step"]) + + // Verify it's the latest checkpoint. + assert.Equal(t, checkpointIDs[2], latest.ID) + assert.Equal(t, float64(2), latest.ChannelValues["step"]) +} + +func TestRedis_GetTuple_EmptyDB_ReturnsNil(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + // No checkpoints inserted yet + cfg := graph.CreateCheckpointConfig("ln-empty", "", "") + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + assert.Nil(t, tup) +} + +func TestRedis_Put_MetadataDefault(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-meta" + ns := "ns" + ck := graph.NewCheckpoint(map[string]any{"a": 1}, map[string]int64{"a": 1}, nil) + // Put with nil metadata should not error + cfg, err := saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", ns), Checkpoint: ck, Metadata: nil, NewVersions: map[string]int64{"a": 1}}) + require.NoError(t, err) + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, tup) + // Metadata should exist with default Source + require.NotNil(t, tup.Metadata) + assert.NotEmpty(t, tup.Metadata.Source) +} + +func TestRedis_PutWrites_SequenceUsed(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + cfg, err := saver.Put(ctx, graph.PutRequest{Config: 
graph.CreateCheckpointConfig("ln-writes", "", "ns"), Checkpoint: graph.NewCheckpoint(map[string]any{}, map[string]int64{}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{}}) + require.NoError(t, err) + + // Provide explicit sequence numbers + writes := []graph.PendingWrite{ + {TaskID: "t", Channel: "x", Value: 1, Sequence: 101}, + {TaskID: "t", Channel: "y", Value: 2, Sequence: 102}, + } + err = saver.PutWrites(ctx, graph.PutWritesRequest{Config: cfg, Writes: writes, TaskID: "t", TaskPath: "p"}) + require.NoError(t, err) + + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, tup) + require.Len(t, tup.PendingWrites, 2) + assert.Equal(t, int64(101), tup.PendingWrites[0].Sequence) + assert.Equal(t, int64(102), tup.PendingWrites[1].Sequence) +} + +func TestRedis_PutFull_SequenceHonored(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-full-seq" + ns := "ns" + ck := graph.NewCheckpoint(map[string]any{"v": 1}, map[string]int64{"v": 1}, nil) + cfg, err := saver.PutFull(ctx, graph.PutFullRequest{Config: graph.CreateCheckpointConfig(lineageID, "", ns), Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"v": 1}, PendingWrites: []graph.PendingWrite{{TaskID: "t1", Channel: "c1", Value: 1, Sequence: 999}}}) + require.NoError(t, err) + + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, tup) + require.Len(t, tup.PendingWrites, 1) + assert.Equal(t, int64(999), tup.PendingWrites[0].Sequence) +} + +func TestRedis_PutFull_SequenceZero_Assigned(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer 
saver.Close() + + ctx := context.Background() + cfg, err := saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln-full0", "", "ns"), Checkpoint: graph.NewCheckpoint(map[string]any{}, map[string]int64{}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{}}) + require.NoError(t, err) + + // Write with Sequence zero should be assigned a non-zero sequence + _, err = saver.PutFull(ctx, graph.PutFullRequest{Config: cfg, Checkpoint: graph.NewCheckpoint(map[string]any{"v": 1}, map[string]int64{"v": 1}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 1), NewVersions: map[string]int64{"v": 1}, PendingWrites: []graph.PendingWrite{{TaskID: "t", Channel: "c", Value: 1, Sequence: 0}}}) + require.NoError(t, err) + + tup, err := saver.GetTuple(ctx, graph.CreateCheckpointConfig("ln-full0", "", "ns")) + require.NoError(t, err) + require.NotNil(t, tup) + require.Len(t, tup.PendingWrites, 1) + // Should be assigned + require.Greater(t, tup.PendingWrites[0].Sequence, int64(0)) +} + +func TestRedis_GetTuple_LatestInNamespace(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-latest-ns" + + ck1 := graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "ns1"), Checkpoint: ck1, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"x": 1}}) + require.NoError(t, err) + time.Sleep(2 * time.Millisecond) + ck2 := graph.NewCheckpoint(map[string]any{"x": 2}, map[string]int64{"x": 2}, nil) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "ns2"), Checkpoint: ck2, Metadata: 
graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 1), NewVersions: map[string]int64{"x": 2}}) + require.NoError(t, err) + + // Latest in ns1 should be ck1, not ns2 + tup, err := saver.GetTuple(ctx, graph.CreateCheckpointConfig(lineageID, "", "ns1")) + require.NoError(t, err) + require.NotNil(t, tup) + assert.Equal(t, ck1.ID, tup.Checkpoint.ID) + assert.Equal(t, "ns1", graph.GetNamespace(tup.Config)) +} + +func TestRedis_Put_TimestampZero_UsesNow(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-ts0" + ns := "ns" + ck := graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil) + // Zero out timestamp to force now assignment path + ck.Timestamp = time.Time{} + cfg, err := saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", ns), Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceUpdate, 0), NewVersions: map[string]int64{"x": 1}}) + require.NoError(t, err) + // Should be retrievable + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, tup) +} + +func TestRedisCheckpointSaverMetadataFilter(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "test-lineage" + config := graph.CreateCheckpointConfig(lineageID, "", "") + + // Create checkpoints with different metadata. 
+ for i := 0; i < 3; i++ { + checkpoint := graph.NewCheckpoint( + map[string]any{"step": i}, + map[string]int64{"step": int64(i + 1)}, + map[string]map[string]int64{}, + ) + metadata := graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, i) + metadata.Extra["type"] = "test" + if i == 1 { + metadata.Extra["special"] = "yes" + } + + req := graph.PutRequest{ + Config: config, + Checkpoint: checkpoint, + Metadata: metadata, + NewVersions: map[string]int64{"step": int64(i + 1)}, + } + _, err := saver.Put(ctx, req) + require.NoError(t, err) + } + + // Filter by metadata: only the i == 1 checkpoint has special=yes. + filter := &graph.CheckpointFilter{} + filter.WithMetadata("special", "yes") + + checkpoints, err := saver.List(ctx, config, filter) + require.NoError(t, err) + assert.Len(t, checkpoints, 1) + // The step was stored as int 1; it is compared as float64 here (presumably because the checkpoint is read back from JSON). + assert.Equal(t, float64(1), checkpoints[0].Checkpoint.ChannelValues["step"]) +} + +// TestRedis_List_MetadataFilter_NoExtraInTuple checks that a record whose metadata JSON has no 'extra' field is skipped by a metadata filter but still listed without one. +func TestRedis_List_MetadataFilter_NoExtraInTuple(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-no-extra" + ns := "ns" + + // Manually insert a checkpoint with metadata JSON missing 'extra' field + ck := graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil) + ckJSON, _ := json.Marshal(ck) + // metadata without Extra + rawMeta := map[string]any{"source": graph.CheckpointSourceInput, "step": 0} + metaJSON, _ := json.Marshal(rawMeta) + db := buildRedisClient(t, redisURL) + pipe := db.TxPipeline() + // From here on the local checkpointKey shadows the package-level helper of the same name. + checkpointKey := checkpointKey(lineageID, ns, ck.ID) + // tsKey inside this HSet call is still the package-level field name; the local tsKey is only declared after the call. + // NOTE(review): lingeageIDKey looks like a misspelling of lineageIDKey in the saver's constants - confirm and rename there. + pipe.HSet(ctx, checkpointKey, + lingeageIDKey, lineageID, + checkpointNSKey, ns, + checkpointIDKey, ck.ID, + tsKey, time.Now().UTC().UnixNano(), + checkpointJSONKey, ckJSON, + metadataJSONKey, metaJSON, + ) + tsKey := checkpointTSKey(lineageID, ns) + pipe.ZAdd(ctx, tsKey, redis.Z{ + Score: float64(time.Now().UTC().UnixNano()), + Member: ck.ID, + }) + nsKey := 
lineageNSKey(lineageID) + pipe.SAdd(ctx, nsKey, ns) + _, err = pipe.Exec(ctx) + // The record is written with raw Redis commands (not saver.Put) so that the stored metadata JSON has no 'extra' field. + require.NoError(t, err) + + // List with metadata filter should exclude this tuple because Extra==nil + filter := &graph.CheckpointFilter{Metadata: map[string]any{"k": "v"}} + tuples, err := saver.List(ctx, graph.CreateCheckpointConfig(lineageID, "", ns), filter) + require.NoError(t, err) + // No tuples should match the metadata filter + require.Equal(t, 0, len(tuples)) + + // Listing without metadata filter should include 1 tuple + tuples2, err := saver.List(ctx, graph.CreateCheckpointConfig(lineageID, "", ns), nil) + require.NoError(t, err) + require.Equal(t, 1, len(tuples2)) +} + +// TestRedisCheckpointSaverClose checks that Close can be called repeatedly without returning an error. +func TestRedisCheckpointSaverClose(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + // Close should not error. + err = saver.Close() + assert.NoError(t, err) + + // Close again should not error. + err = saver.Close() + assert.NoError(t, err) +} + +// NOTE(review): the TestSQLite_ prefix looks like a leftover from the SQLite saver tests; this test runs against the Redis saver - consider renaming it to TestRedis_... +func TestSQLite_GetTuple_ParentNamespaceUnknown_EmptyInParentConfig(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + // Insert a child record that references a non-existent parent ID so that the parent's namespace cannot be resolved and ParentConfig carries an empty namespace. + // The record is written directly to Redis instead of going through Put. 
+ // 1) Create a fake child checkpoint JSON whose parent ID does not exist + child := graph.NewCheckpoint(map[string]any{"v": 10}, map[string]int64{"v": 1}, nil) + child.ParentCheckpointID = "no-such-parent" + childJSON, _ := json.Marshal(child) + metaJSON, _ := json.Marshal(graph.NewCheckpointMetadata(graph.CheckpointSourceFork, 1)) + db := buildRedisClient(t, redisURL) + pipe := db.TxPipeline() + lineageID := "ln-unknown" + ns := "nsX" + checkpointKey := checkpointKey(lineageID, ns, child.ID) + pipe.HSet(ctx, checkpointKey, + lingeageIDKey, lineageID, + checkpointNSKey, ns, + checkpointIDKey, child.ID, + parentCheckpointIDKey, child.ParentCheckpointID, + tsKey, time.Now().UTC().UnixNano(), + checkpointJSONKey, childJSON, + metadataJSONKey, metaJSON, + ) + tsKey := checkpointTSKey(lineageID, ns) + pipe.ZAdd(ctx, tsKey, redis.Z{ + Score: float64(time.Now().UTC().UnixNano()), + Member: child.ID, + }) + nsKey := lineageNSKey(lineageID) + pipe.SAdd(ctx, nsKey, ns) + _, err = pipe.Exec(ctx) + require.NoError(t, err) + + cfg := graph.CreateCheckpointConfig("ln-unknown", child.ID, "nsX") + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, tup) + require.NotNil(t, tup.ParentConfig) + assert.Equal(t, "", graph.GetNamespace(tup.ParentConfig)) + assert.Equal(t, child.ParentCheckpointID, graph.GetCheckpointID(tup.ParentConfig)) +} + +// TestRedis_GetTuple_CrossNamespaceLatestAndByID reads back the latest checkpoint and a checkpoint by ID using a config with an empty namespace. +func TestRedis_GetTuple_CrossNamespaceLatestAndByID(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-cross-ns" + + // Put the first checkpoint. NOTE(review): despite the cfgNS1 name, the namespace passed here is the empty string, so this test never writes to two distinct namespaces. + ck1 := graph.NewCheckpoint(map[string]any{"n": 1}, map[string]int64{"n": 1}, map[string]map[string]int64{}) + cfgNS1 := graph.CreateCheckpointConfig(lineageID, "", "") + _, err = saver.Put(ctx, graph.PutRequest{Config: cfgNS1, Checkpoint: ck1, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), 
NewVersions: map[string]int64{"n": 1}}) + require.NoError(t, err) + + // Small delay to ensure distinct timestamps + time.Sleep(5 * time.Millisecond) + + // Put the second checkpoint. NOTE(review): cfgNS2 also uses the empty namespace, not a namespace called ns2. + ck2 := graph.NewCheckpoint(map[string]any{"n": 2}, map[string]int64{"n": 2}, map[string]map[string]int64{}) + cfgNS2 := graph.CreateCheckpointConfig(lineageID, "", "") + _, err = saver.Put(ctx, graph.PutRequest{Config: cfgNS2, Checkpoint: ck2, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 1), NewVersions: map[string]int64{"n": 2}}) + require.NoError(t, err) + + // Latest lookup with empty ns, empty id + latestCfg := graph.CreateCheckpointConfig(lineageID, "", "") + tuple, err := saver.GetTuple(ctx, latestCfg) + require.NoError(t, err) + require.NotNil(t, tuple) + // Should be the second checkpoint, which was stored last + assert.Equal(t, ck2.ID, tuple.Checkpoint.ID) + assert.Equal(t, "", graph.GetNamespace(tuple.Config)) + + // Lookup by ID with empty ns but specific id + byIDCfg := graph.CreateCheckpointConfig(lineageID, ck1.ID, "") + tuple2, err := saver.GetTuple(ctx, byIDCfg) + require.NoError(t, err) + require.NotNil(t, tuple2) + assert.Equal(t, ck1.ID, tuple2.Checkpoint.ID) + assert.Equal(t, "", graph.GetNamespace(tuple2.Config)) +} + +// TestRedis_Put_DefaultMetadataWhenNil checks that Put accepts nil metadata and that the stored metadata reads back as source "update", step 0. +func TestRedis_Put_DefaultMetadataWhenNil(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-nil-meta" + cfg := graph.CreateCheckpointConfig(lineageID, "", "ns") + + ck := graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, map[string]map[string]int64{}) + // Put with nil metadata should be accepted and default to update/step 0 + updated, err := saver.Put(ctx, graph.PutRequest{Config: cfg, Checkpoint: ck, Metadata: nil, NewVersions: map[string]int64{"x": 1}}) + require.NoError(t, err) + + tup, err := saver.GetTuple(ctx, updated) + 
require.NoError(t, err) + require.NotNil(t, tup) + require.NotNil(t, tup.Metadata) + assert.Equal(t, graph.CheckpointSourceUpdate, tup.Metadata.Source) + assert.Equal(t, 0, tup.Metadata.Step) +} + +func TestRedis_PutWrites_SequenceOrdering(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-seq" + cfg := graph.CreateCheckpointConfig(lineageID, "", "ns") + + ck := graph.NewCheckpoint(map[string]any{"a": 0}, map[string]int64{"a": 1}, map[string]map[string]int64{}) + updated, err := saver.Put(ctx, graph.PutRequest{Config: cfg, Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, -1), NewVersions: map[string]int64{"a": 1}}) + require.NoError(t, err) + + // Deliberately out-of-order sequences; GetTuple should return the writes ordered by Sequence + writes := []graph.PendingWrite{ + {TaskID: "t", Channel: "a", Value: 1, Sequence: 200}, + {TaskID: "t", Channel: "b", Value: 2, Sequence: 100}, + } + err = saver.PutWrites(ctx, graph.PutWritesRequest{Config: updated, Writes: writes, TaskID: "t"}) + require.NoError(t, err) + + tup, err := saver.GetTuple(ctx, updated) + require.NoError(t, err) + require.Len(t, tup.PendingWrites, 2) + // Ordered by Sequence ascending: 100 ("b") before 200 ("a") + assert.Equal(t, int64(100), tup.PendingWrites[0].Sequence) + assert.Equal(t, "b", tup.PendingWrites[0].Channel) + assert.Equal(t, int64(200), tup.PendingWrites[1].Sequence) + assert.Equal(t, "a", tup.PendingWrites[1].Channel) +} + +// TestRedis_PutFull_WithParentAndWrites checks that PutFull stores the child checkpoint, its parent reference and its pending writes together. +func TestRedis_PutFull_WithParentAndWrites(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-putfull" + ns := "ns" + + // Parent checkpoint first + parent := graph.NewCheckpoint(map[string]any{"p": 1}, map[string]int64{"p": 1}, 
map[string]map[string]int64{}) + cfg := graph.CreateCheckpointConfig(lineageID, "", ns) + _, err = saver.Put(ctx, graph.PutRequest{Config: cfg, Checkpoint: parent, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"p": 1}}) + require.NoError(t, err) + + // Child via PutFull; the parent link is taken from child.ParentCheckpointID set below + child := graph.NewCheckpoint(map[string]any{"c": 2}, map[string]int64{"c": 1}, map[string]map[string]int64{}) + child.ParentCheckpointID = parent.ID + + fullCfg, err := saver.PutFull(ctx, graph.PutFullRequest{ + Config: cfg, + Checkpoint: child, + Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 1), + NewVersions: map[string]int64{"c": 1}, + PendingWrites: []graph.PendingWrite{{TaskID: "t1", Channel: "c", Value: 99}}, + }) + require.NoError(t, err) + + tup, err := saver.GetTuple(ctx, fullCfg) + require.NoError(t, err) + require.NotNil(t, tup) + assert.Equal(t, child.ID, tup.Checkpoint.ID) + // Parent in same namespace + require.NotNil(t, tup.ParentConfig) + assert.Equal(t, parent.ID, graph.GetCheckpointID(tup.ParentConfig)) + assert.Equal(t, ns, graph.GetNamespace(tup.ParentConfig)) + // Writes stored; the int 99 is compared as float64 (presumably because the value is read back from JSON) + require.Len(t, tup.PendingWrites, 1) + assert.Equal(t, "c", tup.PendingWrites[0].Channel) + assert.Equal(t, float64(99), tup.PendingWrites[0].Value) +} + +// TestRedis_PutFull_ParentConfig_CrossNamespace checks that ParentConfig reports the namespace the parent was stored in, not the child's namespace. +func TestRedis_PutFull_ParentConfig_CrossNamespace(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-cross-parentcfg" + nsA := "nsA" + nsB := "nsB" + + // Parent in nsA + parent := graph.NewCheckpoint(map[string]any{"p": 1}, map[string]int64{"p": 1}, map[string]map[string]int64{}) + cfgA := graph.CreateCheckpointConfig(lineageID, "", nsA) + _, err = saver.Put(ctx, graph.PutRequest{Config: cfgA, Checkpoint: parent, Metadata: 
graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"p": 1}}) + require.NoError(t, err) + + // Child in nsB with ParentCheckpointID referencing parent in nsA + child := graph.NewCheckpoint(map[string]any{"c": 2}, map[string]int64{"c": 1}, map[string]map[string]int64{}) + child.ParentCheckpointID = parent.ID + cfgB := graph.CreateCheckpointConfig(lineageID, "", nsB) + fullCfg, err := saver.PutFull(ctx, graph.PutFullRequest{Config: cfgB, Checkpoint: child, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceFork, 1), NewVersions: map[string]int64{"c": 1}}) + require.NoError(t, err) + + // Load child tuple and verify ParentConfig points to parent's actual namespace (nsA) + tup, err := saver.GetTuple(ctx, fullCfg) + require.NoError(t, err) + require.NotNil(t, tup) + require.NotNil(t, tup.ParentConfig) + assert.Equal(t, parent.ID, graph.GetCheckpointID(tup.ParentConfig)) + assert.Equal(t, nsA, graph.GetNamespace(tup.ParentConfig)) +} + +// TestRedis_List_WithBeforeAndCrossNamespace checks that a Before filter excludes the referenced checkpoint from List results. +func TestRedis_List_WithBeforeAndCrossNamespace(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-before" + + // Create three checkpoints. NOTE(review): all three are written to nsA, so despite the test name no second namespace is involved. + ck1 := graph.NewCheckpoint(map[string]any{"i": 1}, map[string]int64{"i": 1}, map[string]map[string]int64{}) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "nsA"), Checkpoint: ck1, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"i": 1}}) + require.NoError(t, err) + time.Sleep(5 * time.Millisecond) + ck2 := graph.NewCheckpoint(map[string]any{"i": 2}, map[string]int64{"i": 2}, map[string]map[string]int64{}) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "nsA"), Checkpoint: ck2, Metadata: 
graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 1), NewVersions: map[string]int64{"i": 2}}) + require.NoError(t, err) + time.Sleep(5 * time.Millisecond) + ck3 := graph.NewCheckpoint(map[string]any{"i": 3}, map[string]int64{"i": 3}, map[string]map[string]int64{}) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "nsA"), Checkpoint: ck3, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 2), NewVersions: map[string]int64{"i": 3}}) + require.NoError(t, err) + + // List nsA with a Before(ck3) filter whose own config carries an empty namespace; ck3 must be excluded. + // Be tolerant on size/order across platforms; just ensure ck3 is excluded and ck1/ck2 appear if any. + cfgAll := graph.CreateCheckpointConfig(lineageID, "", "nsA") + filter := graph.NewCheckpointFilter().WithBefore(graph.CreateCheckpointConfig(lineageID, ck3.ID, "")).WithLimit(10) + tuples, err := saver.List(ctx, cfgAll, filter) + require.NoError(t, err) + have3 := false + for _, tu := range tuples { + if tu.Checkpoint.ID == ck3.ID { + have3 = true + } + } + assert.False(t, have3, "ck3 should be excluded by Before filter") + // If results present, they must be among {ck1, ck2} + for _, tu := range tuples { + assert.True(t, tu.Checkpoint.ID == ck1.ID || tu.Checkpoint.ID == ck2.ID) + } + + // Same listing with the Before config scoped to nsA: ck3 must be absent and, if anything is returned, the first tuple is expected to be ck1 + cfgNsA := graph.CreateCheckpointConfig(lineageID, "", "nsA") + filter2 := graph.NewCheckpointFilter().WithBefore(graph.CreateCheckpointConfig(lineageID, ck3.ID, "nsA")) + tuples2, err := saver.List(ctx, cfgNsA, filter2) + require.NoError(t, err) + // Should not include ck3 + for _, tu := range tuples2 { + assert.NotEqual(t, tu.Checkpoint.ID, ck3.ID) + } + if len(tuples2) > 0 { + assert.Equal(t, ck1.ID, tuples2[0].Checkpoint.ID) + } +} + +// TestRedis_List_CrossNamespace_Limit1 checks that List honours Limit: 1 when more than one checkpoint is available. +func TestRedis_List_CrossNamespace_Limit1(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + 
require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + lineageID := "ln-limit" + // three checkpoints: two in ns1 and one in ns2 + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "ns1"), Checkpoint: graph.NewCheckpoint(map[string]any{"i": 1}, map[string]int64{"i": 1}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"i": 1}}) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "ns2"), Checkpoint: graph.NewCheckpoint(map[string]any{"i": 2}, map[string]int64{"i": 2}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 1), NewVersions: map[string]int64{"i": 2}}) + require.NoError(t, err) + time.Sleep(1 * time.Millisecond) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "ns1"), Checkpoint: graph.NewCheckpoint(map[string]any{"i": 3}, map[string]int64{"i": 3}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 2), NewVersions: map[string]int64{"i": 3}}) + require.NoError(t, err) + + // NOTE(review): cfgAll is scoped to ns1, so this lists a single namespace with Limit 1 rather than all namespaces. + cfgAll := graph.CreateCheckpointConfig(lineageID, "", "ns1") + tuples, err := saver.List(ctx, cfgAll, &graph.CheckpointFilter{Limit: 1}) + require.NoError(t, err) + require.Equal(t, 1, len(tuples)) +} + +// TestRedis_List_NamespaceNotExists_ReturnsEmpty checks that listing a lineage/namespace with no data yields an empty result and no error. +func TestRedis_List_NamespaceNotExists_ReturnsEmpty(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + // List in a namespace with no data + tuples, err := saver.List(ctx, graph.CreateCheckpointConfig("ln-empty-ns", "", "nsX"), nil) + require.NoError(t, err) + require.Equal(t, 0, len(tuples)) +} + +// TestRedis_PutFull_NilCheckpoint_Error checks that PutFull rejects a request whose Checkpoint is nil. +func TestRedis_PutFull_NilCheckpoint_Error(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := 
NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + _, err = saver.PutFull(context.Background(), graph.PutFullRequest{Config: graph.CreateCheckpointConfig("ln", "", "ns"), Checkpoint: nil}) + require.Error(t, err) +} + +// TestRedis_Get_MissingLineage_Error checks that Get rejects a config that carries no lineage ID. +func TestRedis_Get_MissingLineage_Error(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + _, err = saver.Get(context.Background(), map[string]any{}) + require.Error(t, err) +} + +// TestRedis_List_MetadataMismatch_ReturnsEmpty checks that a metadata filter whose value differs from the stored one matches nothing. +func TestRedis_List_MetadataMismatch_ReturnsEmpty(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + lineageID := "ln-meta-mismatch" + ck := graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil) + meta := graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 1) + meta.Extra["type"] = "test" + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "ns"), Checkpoint: ck, Metadata: meta, NewVersions: map[string]int64{"x": 1}}) + require.NoError(t, err) + // Mismatched metadata filter should yield no results + tuples, err := saver.List(ctx, graph.CreateCheckpointConfig(lineageID, "", "ns"), &graph.CheckpointFilter{Metadata: map[string]any{"type": "other"}}) + require.NoError(t, err) + require.Equal(t, 0, len(tuples)) +} + +// TestRedis_List_MissingLineage_Error checks that List rejects a config that carries no lineage ID. +func TestRedis_List_MissingLineage_Error(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + _, err = saver.List(context.Background(), map[string]any{}, nil) + require.Error(t, err) +} + +func TestRedis_List_NamespaceWithLimit(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) 
+ defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + lineageID := "ln-ns-limit" + // Two checkpoints in the same namespace; Limit: 1 must return exactly one of them. + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "ns"), Checkpoint: graph.NewCheckpoint(map[string]any{"i": 1}, map[string]int64{"i": 1}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"i": 1}}) + require.NoError(t, err) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "ns"), Checkpoint: graph.NewCheckpoint(map[string]any{"i": 2}, map[string]int64{"i": 2}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceLoop, 1), NewVersions: map[string]int64{"i": 2}}) + require.NoError(t, err) + tuples, err := saver.List(ctx, graph.CreateCheckpointConfig(lineageID, "", "ns"), &graph.CheckpointFilter{Limit: 1}) + require.NoError(t, err) + require.Equal(t, 1, len(tuples)) +} + +// TestRedis_PutFull_NoWrites_Success_NoPendingWrites checks that PutFull without PendingWrites succeeds and the stored tuple has no pending writes. +func TestRedis_PutFull_NoWrites_Success_NoPendingWrites(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + lineageID := "ln-pf-nowrites" + ns := "ns" + ck := graph.NewCheckpoint(map[string]any{"v": 1}, map[string]int64{"v": 1}, nil) + cfg, err := saver.PutFull(ctx, graph.PutFullRequest{Config: graph.CreateCheckpointConfig(lineageID, "", ns), Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"v": 1}}) + require.NoError(t, err) + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, tup) + require.Equal(t, 0, len(tup.PendingWrites)) +} + +func TestRedis_PutWrites_SequenceZero_UsesIndex(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := 
NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + cfg, err := saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln-pw-idx", "", "ns"), Checkpoint: graph.NewCheckpoint(map[string]any{"a": 1}, map[string]int64{"a": 1}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"a": 1}}) + require.NoError(t, err) + // both Sequence=0 -> the stored sequences are expected to be the slice indexes (0 and 1) + err = saver.PutWrites(ctx, graph.PutWritesRequest{Config: cfg, Writes: []graph.PendingWrite{{TaskID: "t", Channel: "c", Value: 1, Sequence: 0}, {TaskID: "t", Channel: "d", Value: 2, Sequence: 0}}}) + require.NoError(t, err) + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.Len(t, tup.PendingWrites, 2) + require.Equal(t, int64(0), tup.PendingWrites[0].Sequence) + require.Equal(t, int64(1), tup.PendingWrites[1].Sequence) +} + +// TestRedis_NoParent_ParentConfigNil checks that a checkpoint stored without a parent is read back with a nil ParentConfig. +func TestRedis_NoParent_ParentConfigNil(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + cfg, err := saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln-nopar", "", "ns"), Checkpoint: graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"x": 1}}) + require.NoError(t, err) + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, tup) + require.Nil(t, tup.ParentConfig) +} + +// TestRedis_findCheckpointNamespace_EmptyArgs checks that empty lineage and checkpoint IDs yield an empty namespace and no error. +func TestRedis_findCheckpointNamespace_EmptyArgs(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ns, err := saver.findCheckpointNamespace(context.Background(), "", "") + 
require.NoError(t, err) + require.Equal(t, "", ns) +} + +func TestRedis_findCheckpointNamespace_NoRows(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + // Insert a checkpoint in nsA so the lineage exists + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln-fc", "", "nsA"), Checkpoint: graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"x": 1}}) + require.NoError(t, err) + // Looking up a checkpoint ID that was never stored must return an empty namespace without an error + ns, err := saver.findCheckpointNamespace(ctx, "ln-fc", "no-such") + require.NoError(t, err) + require.Equal(t, "", ns) +} + +// TestRedis_PutFull_SequenceZero_AssignsTime checks that a pending write passed to PutFull with Sequence 0 is read back with a positive sequence. +func TestRedis_PutFull_SequenceZero_AssignsTime(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + cfg, err := saver.PutFull(ctx, graph.PutFullRequest{ + Config: graph.CreateCheckpointConfig("ln-pf-seq0", "", "ns"), + Checkpoint: graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil), + Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), + NewVersions: map[string]int64{"x": 1}, + PendingWrites: []graph.PendingWrite{{ + TaskID: "t", + Channel: "c", + Value: 1, + Sequence: 0, + }}, + }) + require.NoError(t, err) + tup, err := saver.GetTuple(ctx, cfg) + require.NoError(t, err) + require.NotNil(t, tup) + require.Len(t, tup.PendingWrites, 1) + require.Greater(t, tup.PendingWrites[0].Sequence, int64(0)) +} + +// TestRedis_ErrorCases checks the validation errors returned for missing lineage or checkpoint IDs. +func TestRedis_ErrorCases(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + + // GetTuple 
with missing lineage id should error + _, err = saver.GetTuple(ctx, map[string]any{}) + require.Error(t, err) + assert.Contains(t, err.Error(), "lineage_id is required") + + // Put with missing lineage id should error + _, err = saver.Put(ctx, graph.PutRequest{Config: map[string]any{"configurable": map[string]any{}}, Checkpoint: graph.NewCheckpoint(nil, nil, nil)}) + require.Error(t, err) + assert.Contains(t, err.Error(), "lineage_id is required") + + // PutWrites with a lineage id but an empty checkpoint id should error + err = saver.PutWrites(ctx, graph.PutWritesRequest{Config: graph.CreateCheckpointConfig("ln", "", "")}) + require.Error(t, err) + assert.Contains(t, err.Error(), "lineage_id and checkpoint_id are required") + + // PutFull with missing lineage id should error + _, err = saver.PutFull(ctx, graph.PutFullRequest{Config: map[string]any{"configurable": map[string]any{}}, Checkpoint: graph.NewCheckpoint(nil, nil, nil)}) + require.Error(t, err) + assert.Contains(t, err.Error(), "lineage_id is required") + + // DeleteLineage with empty id should error + err = saver.DeleteLineage(ctx, "") + require.Error(t, err) + assert.Contains(t, err.Error(), "lineage_id is required") +} + +// TestRedis_PutFull_WriteMarshalError checks that PutFull fails when a pending write value cannot be JSON-encoded. +func TestRedis_PutFull_WriteMarshalError(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-marshal" + ns := "ns" + ck := graph.NewCheckpoint(map[string]any{"v": 1}, map[string]int64{"v": 1}, nil) + // The pending write carries a channel value, which is not JSON-marshalable, to force the error + _, err = saver.PutFull(ctx, graph.PutFullRequest{Config: graph.CreateCheckpointConfig(lineageID, "", ns), Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceUpdate, 0), NewVersions: map[string]int64{"v": 1}, PendingWrites: []graph.PendingWrite{{TaskID: "t", Channel: "c", Value: make(chan int)}}}) + require.Error(t, err) +} + +func 
TestRedis_PutFull_WriteMarshalError_checkpoint(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-marshal" + ns := "ns" + ck := graph.NewCheckpoint(map[string]any{"v": 1, "ch": make(chan int)}, map[string]int64{"v": 1}, nil) + // The channel stored in ChannelValues above is not JSON-marshalable, so PutFull must fail + _, err = saver.PutFull(ctx, graph.PutFullRequest{Config: graph.CreateCheckpointConfig(lineageID, "", ns), Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceUpdate, 0), NewVersions: map[string]int64{"v": 1}, PendingWrites: []graph.PendingWrite{{TaskID: "t", Channel: "c", Value: 1}}}) + require.Error(t, err) +} + +// TestRedis_PutFull_checkpoint_ts_isEmpty checks that PutFull accepts a hand-built checkpoint with no Timestamp and returns a config carrying its ID. +func TestRedis_PutFull_checkpoint_ts_isEmpty(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-marshal" + ns := "ns" + ck := &graph.Checkpoint{ + Version: 1, + ID: uuid.New().String(), + ChannelValues: map[string]any{"v": 1}, + ChannelVersions: map[string]int64{"v": 1}, + VersionsSeen: map[string]map[string]int64{}, + } + // Timestamp is left at its zero value; PutFull must still succeed and echo the checkpoint ID in the returned config + cb, err := saver.PutFull(ctx, graph.PutFullRequest{Config: graph.CreateCheckpointConfig(lineageID, "", ns), Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceUpdate, 0), NewVersions: map[string]int64{"v": 1}, PendingWrites: []graph.PendingWrite{{TaskID: "t", Channel: "c", Value: 1}}}) + require.NoError(t, err) + assert.Equal(t, ck.ID, cb[graph.CfgKeyConfigurable].(map[string]any)[graph.CfgKeyCheckpointID]) +} + +// TestRedis_Put_checkpoint_ts_isEmpty is the Put counterpart: a hand-built checkpoint with no Timestamp must be accepted. +func TestRedis_Put_checkpoint_ts_isEmpty(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + 
require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + lineageID := "ln-marshal" + ns := "ns" + ck := &graph.Checkpoint{ + Version: 1, + ID: uuid.New().String(), + ChannelValues: map[string]any{"v": 1}, + ChannelVersions: map[string]int64{"v": 1}, + VersionsSeen: map[string]map[string]int64{}, + } + // Timestamp is left at its zero value; Put must still succeed and echo the checkpoint ID in the returned config + cb, err := saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", ns), Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceUpdate, 0), NewVersions: map[string]int64{"v": 1}}) + require.NoError(t, err) + assert.Equal(t, ck.ID, cb[graph.CfgKeyConfigurable].(map[string]any)[graph.CfgKeyCheckpointID]) +} + +// TestRedis_Close_NilDB_NoPanic checks that Close on a Saver with a nil client returns no error. +func TestRedis_Close_NilDB_NoPanic(t *testing.T) { + s := &Saver{client: nil} + // Close should be no-op + assert.NoError(t, s.Close()) +} + +// TestRedis_Put_NilCheckpoint_Error checks that Put rejects a request whose Checkpoint is nil. +func TestRedis_Put_NilCheckpoint_Error(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln", "", "ns"), Checkpoint: nil}) + require.Error(t, err) +} + +// TestRedis_PutWrites_MarshalError checks that PutWrites fails when a write value cannot be JSON-encoded. +func TestRedis_PutWrites_MarshalError(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + cfg, err := saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln-pw", "", "ns"), Checkpoint: graph.NewCheckpoint(nil, nil, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{}}) + require.NoError(t, err) + // Non-serializable write value to force marshal error + err = saver.PutWrites(ctx, graph.PutWritesRequest{Config: cfg, Writes: []graph.PendingWrite{{TaskID: "t", Channel: "c", 
Value: make(chan int)}}}) + require.Error(t, err) +} + +func TestRedis_findCheckpointNamespace_Found(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + lineageID := "ln-find" + // Insert a parent in nsP, then resolve its namespace from the lineage and checkpoint ID alone + parent := graph.NewCheckpoint(map[string]any{"p": 1}, map[string]int64{"p": 1}, nil) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig(lineageID, "", "nsP"), Checkpoint: parent, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"p": 1}}) + require.NoError(t, err) + ns, err := saver.findCheckpointNamespace(ctx, lineageID, parent.ID) + require.NoError(t, err) + assert.Equal(t, "nsP", ns) +} + +// NOTE(review): despite its name this test only covers the successful NewSaver path; no constructor error is exercised. Either add a failing-option case or rename the test. +func TestRedis_NewSaver_DBError(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() +} + +// TestRedis_Put_CheckpointMarshalError checks that Put fails when a channel value of the checkpoint cannot be JSON-encoded. +func TestRedis_Put_CheckpointMarshalError(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + ck := graph.NewCheckpoint(map[string]any{"bad": make(chan int)}, map[string]int64{}, nil) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln-bad", "", "ns"), Checkpoint: ck, Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceUpdate, 0), NewVersions: map[string]int64{}}) + require.Error(t, err) +} + +// TestRedis_Put_MetadataMarshalError checks that Put fails when the metadata Extra map holds a value that cannot be JSON-encoded. +func TestRedis_Put_MetadataMarshalError(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + ck := graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil) + meta 
:= graph.NewCheckpointMetadata(graph.CheckpointSourceUpdate, 0) + meta.Extra["bad"] = make(chan int) + _, err = saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln-meta-err", "", "ns"), Checkpoint: ck, Metadata: meta, NewVersions: map[string]int64{"x": 1}}) + require.Error(t, err) +} + +func TestRedis_PutFull_MetadataMarshalError(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + + ctx := context.Background() + ck := graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil) + meta := graph.NewCheckpointMetadata(graph.CheckpointSourceUpdate, 0) + // Force marshal error via extra with non-serializable value + meta.Extra["bad"] = make(chan int) + _, err = saver.PutFull(ctx, graph.PutFullRequest{Config: graph.CreateCheckpointConfig("ln-meta-bad", "", "ns"), Checkpoint: ck, Metadata: meta, NewVersions: map[string]int64{"x": 1}}) + require.Error(t, err) +} + +func TestRedis_DeleteLineage_NullValue(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + err = saver.DeleteLineage(context.Background(), "ln-del") + require.NoError(t, err) +} + +func TestRedis_DeleteLineage_SecondExecError(t *testing.T) { + redisURL, cleanup := setupTestRedis(t) + defer cleanup() + + saver, err := NewSaver(WithRedisClientURL(redisURL)) + require.NoError(t, err) + defer saver.Close() + ctx := context.Background() + // Put a checkpoint and a write + cfg, err := saver.Put(ctx, graph.PutRequest{Config: graph.CreateCheckpointConfig("ln-del2", "", "ns"), Checkpoint: graph.NewCheckpoint(map[string]any{"x": 1}, map[string]int64{"x": 1}, nil), Metadata: graph.NewCheckpointMetadata(graph.CheckpointSourceInput, 0), NewVersions: map[string]int64{"x": 1}}) + require.NoError(t, err) + _ = saver.PutWrites(ctx, 
graph.PutWritesRequest{Config: cfg, Writes: []graph.PendingWrite{{TaskID: "t", Channel: "c", Value: 1}}}) + // Drop writes table to force second delete to fail + err = saver.DeleteLineage(ctx, "ln-del2") + require.NoError(t, err) +}