diff --git a/demo/cmd/mood/main.go b/demo/cmd/mood/main.go index 555de06584..96f7658c93 100644 --- a/demo/cmd/mood/main.go +++ b/demo/cmd/mood/main.go @@ -3,11 +3,12 @@ package main import ( "context" "fmt" - "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood" "log" "net/http" "os" + "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood" + "github.com/99designs/gqlgen/graphql" "github.com/99designs/gqlgen/graphql/handler/debug" "github.com/99designs/gqlgen/graphql/playground" @@ -31,7 +32,9 @@ func main() { port = defaultPort } - srv := subgraphs.NewDemoServer(mood.NewSchema(nil)) + srv := subgraphs.NewDemoServer(mood.NewSchema(nil, func(name string) string { + return name + })) srv.Use(&debug.Tracer{}) srv.Use(otelgqlgen.Middleware(otelgqlgen.WithCreateSpanFromFields(func(ctx *graphql.FieldContext) bool { diff --git a/demo/pkg/subgraphs/availability/availability.go b/demo/pkg/subgraphs/availability/availability.go index f37ce470a5..412ffd1ac3 100644 --- a/demo/pkg/subgraphs/availability/availability.go +++ b/demo/pkg/subgraphs/availability/availability.go @@ -2,13 +2,13 @@ package availability import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/generated" ) -func NewSchema(pubSubBySourceName map[string]pubsub_datasource.NatsPubSub, pubSubName func(string) string) graphql.ExecutableSchema { +func NewSchema(pubSubBySourceName map[string]nats.AdapterInterface, pubSubName func(string) string) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: pubSubBySourceName, GetPubSubName: pubSubName, diff --git a/demo/pkg/subgraphs/availability/subgraph/resolver.go 
b/demo/pkg/subgraphs/availability/subgraph/resolver.go index 9ac6f82f89..a5461fac82 100644 --- a/demo/pkg/subgraphs/availability/subgraph/resolver.go +++ b/demo/pkg/subgraphs/availability/subgraph/resolver.go @@ -1,7 +1,7 @@ package subgraph import ( - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. @@ -9,6 +9,6 @@ import ( // It serves as dependency injection for your app, add any dependencies you require here. type Resolver struct { - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface GetPubSubName func(string) string } diff --git a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go index 43797ed3dc..7067d88bf1 100644 --- a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go @@ -10,13 +10,13 @@ import ( "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/generated" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/model" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // UpdateAvailability is the resolver for the updateAvailability field. 
func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID int, isAvailable bool) (*model.Employee, error) { storage.Set(employeeID, isAvailable) - err := r.NatsPubSubByProviderID["default"].Publish(ctx, pubsub_datasource.NatsPublishAndRequestEventConfiguration{ + err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID)), }) @@ -24,7 +24,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in if err != nil { return nil, err } - err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, pubsub_datasource.NatsPublishAndRequestEventConfiguration{ + err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID)), }) diff --git a/demo/pkg/subgraphs/countries/countries.go b/demo/pkg/subgraphs/countries/countries.go index 3a1ccb7427..726ef1834a 100644 --- a/demo/pkg/subgraphs/countries/countries.go +++ b/demo/pkg/subgraphs/countries/countries.go @@ -2,13 +2,13 @@ package countries import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/countries/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/countries/subgraph/generated" ) -func NewSchema(pubSubBySourceName map[string]pubsub_datasource.NatsPubSub) graphql.ExecutableSchema { +func NewSchema(pubSubBySourceName map[string]nats.AdapterInterface) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: pubSubBySourceName, }}) diff --git 
a/demo/pkg/subgraphs/countries/subgraph/resolver.go b/demo/pkg/subgraphs/countries/subgraph/resolver.go index 4b235fdec9..e2f350a70a 100644 --- a/demo/pkg/subgraphs/countries/subgraph/resolver.go +++ b/demo/pkg/subgraphs/countries/subgraph/resolver.go @@ -1,8 +1,9 @@ package subgraph import ( - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" "sync" + + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. @@ -11,5 +12,5 @@ import ( type Resolver struct { mux sync.Mutex - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface } diff --git a/demo/pkg/subgraphs/employees/employees.go b/demo/pkg/subgraphs/employees/employees.go index 408737da15..aa8c38b134 100644 --- a/demo/pkg/subgraphs/employees/employees.go +++ b/demo/pkg/subgraphs/employees/employees.go @@ -2,13 +2,12 @@ package employees import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" - "github.com/wundergraph/cosmo/demo/pkg/subgraphs/employees/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/employees/subgraph/generated" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) -func NewSchema(natsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub) graphql.ExecutableSchema { +func NewSchema(natsPubSubByProviderID map[string]nats.AdapterInterface) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: natsPubSubByProviderID, EmployeesData: subgraph.Employees, diff --git a/demo/pkg/subgraphs/employees/subgraph/resolver.go b/demo/pkg/subgraphs/employees/subgraph/resolver.go index f71624ab76..8f416f6390 100644 --- a/demo/pkg/subgraphs/employees/subgraph/resolver.go +++ b/demo/pkg/subgraphs/employees/subgraph/resolver.go @@ -6,7 +6,7 @@ import ( "sync" 
"github.com/wundergraph/cosmo/demo/pkg/subgraphs/employees/subgraph/model" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. @@ -15,7 +15,7 @@ import ( type Resolver struct { mux sync.Mutex - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface EmployeesData []*model.Employee } diff --git a/demo/pkg/subgraphs/family/family.go b/demo/pkg/subgraphs/family/family.go index c55eae3fe4..aa3234ab45 100644 --- a/demo/pkg/subgraphs/family/family.go +++ b/demo/pkg/subgraphs/family/family.go @@ -2,13 +2,13 @@ package family import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/family/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/family/subgraph/generated" ) -func NewSchema(natsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub) graphql.ExecutableSchema { +func NewSchema(natsPubSubByProviderID map[string]nats.AdapterInterface) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: natsPubSubByProviderID, }}) diff --git a/demo/pkg/subgraphs/family/subgraph/resolver.go b/demo/pkg/subgraphs/family/subgraph/resolver.go index f4678ba12e..6cea8bd318 100644 --- a/demo/pkg/subgraphs/family/subgraph/resolver.go +++ b/demo/pkg/subgraphs/family/subgraph/resolver.go @@ -1,7 +1,7 @@ package subgraph import ( - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. 
@@ -9,5 +9,5 @@ import ( // It serves as dependency injection for your app, add any dependencies you require here. type Resolver struct { - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface } diff --git a/demo/pkg/subgraphs/hobbies/hobbies.go b/demo/pkg/subgraphs/hobbies/hobbies.go index 103e8bb43a..cf43f3ddfb 100644 --- a/demo/pkg/subgraphs/hobbies/hobbies.go +++ b/demo/pkg/subgraphs/hobbies/hobbies.go @@ -2,13 +2,13 @@ package hobbies import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/hobbies/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/hobbies/subgraph/generated" ) -func NewSchema(natsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub) graphql.ExecutableSchema { +func NewSchema(natsPubSubByProviderID map[string]nats.AdapterInterface) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: natsPubSubByProviderID, }}) diff --git a/demo/pkg/subgraphs/hobbies/subgraph/resolver.go b/demo/pkg/subgraphs/hobbies/subgraph/resolver.go index dc972ffd31..9206076fc1 100644 --- a/demo/pkg/subgraphs/hobbies/subgraph/resolver.go +++ b/demo/pkg/subgraphs/hobbies/subgraph/resolver.go @@ -1,10 +1,10 @@ package subgraph import ( - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" "reflect" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/hobbies/subgraph/model" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. @@ -12,7 +12,7 @@ import ( // It serves as dependency injection for your app, add any dependencies you require here. 
type Resolver struct { - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface } func (r *Resolver) Employees(hobby model.Hobby) ([]*model.Employee, error) { diff --git a/demo/pkg/subgraphs/mood/mood.go b/demo/pkg/subgraphs/mood/mood.go index 7083a607e7..efbc307be6 100644 --- a/demo/pkg/subgraphs/mood/mood.go +++ b/demo/pkg/subgraphs/mood/mood.go @@ -2,14 +2,15 @@ package mood import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/generated" ) -func NewSchema(natsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub) graphql.ExecutableSchema { +func NewSchema(natsPubSubByProviderID map[string]nats.AdapterInterface, getPubSubName func(string) string) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: natsPubSubByProviderID, + GetPubSubName: getPubSubName, }}) } diff --git a/demo/pkg/subgraphs/mood/subgraph/resolver.go b/demo/pkg/subgraphs/mood/subgraph/resolver.go index 9ac6f82f89..a5461fac82 100644 --- a/demo/pkg/subgraphs/mood/subgraph/resolver.go +++ b/demo/pkg/subgraphs/mood/subgraph/resolver.go @@ -1,7 +1,7 @@ package subgraph import ( - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. @@ -9,6 +9,6 @@ import ( // It serves as dependency injection for your app, add any dependencies you require here. 
type Resolver struct { - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface GetPubSubName func(string) string } diff --git a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go index aab22e4499..8ab7c73941 100644 --- a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go @@ -6,14 +6,42 @@ package subgraph import ( "context" + "fmt" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/generated" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/model" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // UpdateMood is the resolver for the updateMood field. func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood model.Mood) (*model.Employee, error) { storage.Set(employeeID, mood) + myNatsTopic := r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)) + payload := fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID) + if r.NatsPubSubByProviderID["default"] != nil { + err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + Subject: myNatsTopic, + Data: []byte(payload), + }) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("no nats pubsub default provider found") + } + + defaultTopic := r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)) + if r.NatsPubSubByProviderID["my-nats"] != nil { + err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + Subject: defaultTopic, + Data: []byte(payload), + }) + if err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("no nats pubsub my-nats provider found") + } return &model.Employee{ID: employeeID, CurrentMood: mood}, nil } diff --git a/demo/pkg/subgraphs/products/products.go b/demo/pkg/subgraphs/products/products.go index 
f14cc97813..6a2c8c5984 100644 --- a/demo/pkg/subgraphs/products/products.go +++ b/demo/pkg/subgraphs/products/products.go @@ -2,13 +2,13 @@ package products import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/products/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/products/subgraph/generated" ) -func NewSchema(natsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub) graphql.ExecutableSchema { +func NewSchema(natsPubSubByProviderID map[string]nats.AdapterInterface) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: natsPubSubByProviderID, TopSecretFederationFactsData: subgraph.TopSecretFederationFacts, diff --git a/demo/pkg/subgraphs/products/subgraph/resolver.go b/demo/pkg/subgraphs/products/subgraph/resolver.go index c9d610cb04..db8de70257 100644 --- a/demo/pkg/subgraphs/products/subgraph/resolver.go +++ b/demo/pkg/subgraphs/products/subgraph/resolver.go @@ -1,9 +1,10 @@ package subgraph import ( - "github.com/wundergraph/cosmo/demo/pkg/subgraphs/products/subgraph/model" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" "sync" + + "github.com/wundergraph/cosmo/demo/pkg/subgraphs/products/subgraph/model" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. 
@@ -12,6 +13,6 @@ import ( type Resolver struct { mux sync.Mutex - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface TopSecretFederationFactsData []model.TopSecretFact } diff --git a/demo/pkg/subgraphs/products_fg/products.go b/demo/pkg/subgraphs/products_fg/products.go index 155cffe417..73e2082af4 100644 --- a/demo/pkg/subgraphs/products_fg/products.go +++ b/demo/pkg/subgraphs/products_fg/products.go @@ -2,13 +2,13 @@ package products_fg import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/products_fg/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/products_fg/subgraph/generated" ) -func NewSchema(natsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub) graphql.ExecutableSchema { +func NewSchema(natsPubSubByProviderID map[string]nats.AdapterInterface) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: natsPubSubByProviderID, TopSecretFederationFactsData: subgraph.TopSecretFederationFacts, diff --git a/demo/pkg/subgraphs/products_fg/subgraph/resolver.go b/demo/pkg/subgraphs/products_fg/subgraph/resolver.go index f0f8d7059a..78a07d51ce 100644 --- a/demo/pkg/subgraphs/products_fg/subgraph/resolver.go +++ b/demo/pkg/subgraphs/products_fg/subgraph/resolver.go @@ -1,9 +1,10 @@ package subgraph import ( - "github.com/wundergraph/cosmo/demo/pkg/subgraphs/products_fg/subgraph/model" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" "sync" + + "github.com/wundergraph/cosmo/demo/pkg/subgraphs/products_fg/subgraph/model" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. 
@@ -12,6 +13,6 @@ import ( type Resolver struct { mux sync.Mutex - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface TopSecretFederationFactsData []model.TopSecretFact } diff --git a/demo/pkg/subgraphs/subgraphs.go b/demo/pkg/subgraphs/subgraphs.go index abcaf75973..a66199e643 100644 --- a/demo/pkg/subgraphs/subgraphs.go +++ b/demo/pkg/subgraphs/subgraphs.go @@ -22,7 +22,6 @@ import ( "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" natsPubsub "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" "golang.org/x/sync/errgroup" "github.com/wundergraph/cosmo/demo/pkg/injector" @@ -162,7 +161,7 @@ func subgraphHandler(schema graphql.ExecutableSchema) http.Handler { } type SubgraphOptions struct { - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]natsPubsub.AdapterInterface GetPubSubName func(string) string } @@ -195,7 +194,7 @@ func AvailabilityHandler(opts *SubgraphOptions) http.Handler { } func MoodHandler(opts *SubgraphOptions) http.Handler { - return subgraphHandler(mood.NewSchema(opts.NatsPubSubByProviderID)) + return subgraphHandler(mood.NewSchema(opts.NatsPubSubByProviderID, opts.GetPubSubName)) } func CountriesHandler(opts *SubgraphOptions) http.Handler { @@ -207,31 +206,30 @@ func New(ctx context.Context, config *Config) (*Subgraphs, error) { if defaultSourceNameURL := os.Getenv("NATS_URL"); defaultSourceNameURL != "" { url = defaultSourceNameURL } - defaultConnection, err := nats.Connect(url) + + natsPubSubByProviderID := map[string]natsPubsub.AdapterInterface{} + + defaultAdapter, err := natsPubsub.NewAdapter(ctx, zap.NewNop(), url, []nats.Option{}, "hostname", "test") if err != nil { - log.Printf("failed to connect to nats source \"nats\": %v", err) + return nil, fmt.Errorf("failed to create default nats adapter: %w", err) } + 
natsPubSubByProviderID["default"] = defaultAdapter - myNatsConnection, err := nats.Connect(url) + myNatsAdapter, err := natsPubsub.NewAdapter(ctx, zap.NewNop(), url, []nats.Option{}, "hostname", "test") if err != nil { - log.Printf("failed to connect to nats source \"my-nats\": %v", err) + return nil, fmt.Errorf("failed to create my-nats adapter: %w", err) } + natsPubSubByProviderID["my-nats"] = myNatsAdapter - defaultJetStream, err := jetstream.New(defaultConnection) + defaultConnection, err := nats.Connect(url) if err != nil { - return nil, err + log.Printf("failed to connect to nats source \"nats\": %v", err) } - - myNatsJetStream, err := jetstream.New(myNatsConnection) + defaultJetStream, err := jetstream.New(defaultConnection) if err != nil { return nil, err } - natsPubSubByProviderID := map[string]pubsub_datasource.NatsPubSub{ - "default": natsPubsub.NewConnector(zap.NewNop(), defaultConnection, defaultJetStream, "hostname", "test").New(ctx), - "my-nats": natsPubsub.NewConnector(zap.NewNop(), myNatsConnection, myNatsJetStream, "hostname", "test").New(ctx), - } - _, err = defaultJetStream.CreateOrUpdateStream(ctx, jetstream.StreamConfig{ Name: "streamName", Subjects: []string{"employeeUpdated.>"}, @@ -262,7 +260,7 @@ func New(ctx context.Context, config *Config) (*Subgraphs, error) { if srv := newServer("availability", config.EnableDebug, config.Ports.Availability, availability.NewSchema(natsPubSubByProviderID, config.GetPubSubName)); srv != nil { servers = append(servers, srv) } - if srv := newServer("mood", config.EnableDebug, config.Ports.Mood, mood.NewSchema(natsPubSubByProviderID)); srv != nil { + if srv := newServer("mood", config.EnableDebug, config.Ports.Mood, mood.NewSchema(natsPubSubByProviderID, config.GetPubSubName)); srv != nil { servers = append(servers, srv) } if srv := newServer("countries", config.EnableDebug, config.Ports.Countries, countries.NewSchema(natsPubSubByProviderID)); srv != nil { diff --git 
a/demo/pkg/subgraphs/test1/subgraph/resolver.go b/demo/pkg/subgraphs/test1/subgraph/resolver.go index f4678ba12e..6cea8bd318 100644 --- a/demo/pkg/subgraphs/test1/subgraph/resolver.go +++ b/demo/pkg/subgraphs/test1/subgraph/resolver.go @@ -1,7 +1,7 @@ package subgraph import ( - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // This file will not be regenerated automatically. @@ -9,5 +9,5 @@ import ( // It serves as dependency injection for your app, add any dependencies you require here. type Resolver struct { - NatsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub + NatsPubSubByProviderID map[string]nats.AdapterInterface } diff --git a/demo/pkg/subgraphs/test1/test1.go b/demo/pkg/subgraphs/test1/test1.go index 25f00b8ec7..3234af4209 100644 --- a/demo/pkg/subgraphs/test1/test1.go +++ b/demo/pkg/subgraphs/test1/test1.go @@ -2,13 +2,13 @@ package test1 import ( "github.com/99designs/gqlgen/graphql" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/test1/subgraph" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/test1/subgraph/generated" ) -func NewSchema(natsPubSubByProviderID map[string]pubsub_datasource.NatsPubSub) graphql.ExecutableSchema { +func NewSchema(natsPubSubByProviderID map[string]nats.AdapterInterface) graphql.ExecutableSchema { return generated.NewExecutableSchema(generated.Config{Resolvers: &subgraph.Resolver{ NatsPubSubByProviderID: natsPubSubByProviderID, }}) diff --git a/router-tests/events/kafka_events_test.go b/router-tests/events/kafka_events_test.go index a2175dd3b4..40d7041fb1 100644 --- a/router-tests/events/kafka_events_test.go +++ b/router-tests/events/kafka_events_test.go @@ -7,6 +7,7 @@ import ( "encoding/json" "fmt" "net/http" + "strconv" "sync/atomic" "testing" "time" @@ -1065,6 +1066,54 @@ 
func TestKafkaEvents(t *testing.T) { xEnv.WaitForConnectionCount(0, KafkaWaitTimeout) }) }) + + t.Run("mutate", func(t *testing.T) { + t.Parallel() + + topics := []string{"employeeUpdated"} + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + }, func(t *testing.T, xEnv *testenv.Environment) { + ensureTopicExists(t, xEnv, topics...) + + // Send a mutation to trigger the first subscription + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) + + records, err := readKafkaMessages(xEnv, topics[0], 1) + require.NoError(t, err) + require.Equal(t, 1, len(records)) + require.Equal(t, `{"employeeID":3,"update":{"name":"name test"}}`, string(records[0].Value)) + }) + }) + + t.Run("kafka startup and shutdown with wrong broker should not stop router from starting indefinitely", func(t *testing.T) { + t.Parallel() + + listener := testenv.NewWaitingListener(t, time.Second*10) + listener.Start() + defer listener.Close() + + // kafka client is lazy and will not connect to the broker until the first message is produced + // so the router will start even if the kafka connection fails + errRouter := testenv.RunWithError(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + ModifyEventsConfiguration: func(config *config.EventsConfiguration) { + for i := range config.Providers.Kafka { + config.Providers.Kafka[i].Brokers = []string{"localhost:" + strconv.Itoa(listener.Port())} + } + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + t.Log("should be called") + }) + + assert.NoError(t, errRouter) + }) } func TestFlakyKafkaEvents(t *testing.T) { @@ -1245,3 +1294,20 @@ func produceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName stri 
fErr := xEnv.KafkaClient.Flush(ctx) require.NoError(t, fErr) } + +func readKafkaMessages(xEnv *testenv.Environment, topicName string, msgs int) ([]*kgo.Record, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + client, err := kgo.NewClient( + kgo.SeedBrokers(xEnv.GetKafkaSeeds()...), + kgo.ConsumeTopics(xEnv.GetPubSubName(topicName)), + ) + if err != nil { + return nil, err + } + + fetchs := client.PollRecords(ctx, msgs) + + return fetchs.Records(), nil +} diff --git a/router-tests/events/nats_events_test.go b/router-tests/events/nats_events_test.go index b55adcd97a..156f6e9694 100644 --- a/router-tests/events/nats_events_test.go +++ b/router-tests/events/nats_events_test.go @@ -8,6 +8,7 @@ import ( "io" "net/http" "net/url" + "strconv" "sync/atomic" "testing" "time" @@ -139,7 +140,7 @@ func TestNatsEvents(t *testing.T) { natsLogs := xEnv.Observer().FilterMessageSnippet("Nats").All() require.Len(t, natsLogs, 4) providerIDFields := xEnv.Observer().FilterField(zap.String("provider_id", "my-nats")).All() - require.Len(t, providerIDFields, 2) + require.Len(t, providerIDFields, 3) }) }) @@ -1808,6 +1809,34 @@ func TestNatsEvents(t *testing.T) { assert.Eventually(t, completed.Load, NatsWaitTimeout, time.Millisecond*100) }) }) + + t.Run("NATS startup and shutdown with wrong URLs should not stop router from starting indefinitely", func(t *testing.T) { + t.Parallel() + + listener := testenv.NewWaitingListener(t, time.Second*10) + listener.Start() + defer listener.Close() + + errRouter := testenv.RunWithError(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate, + EnableNats: false, + ModifyEventsConfiguration: func(cfg *config.EventsConfiguration) { + url := "nats://127.0.0.1:" + strconv.Itoa(listener.Port()) + natsEventSources := make([]config.NatsEventSource, len(testenv.DemoNatsProviders)) + for _, sourceName := range testenv.DemoNatsProviders { + natsEventSources = 
append(natsEventSources, config.NatsEventSource{ + ID: sourceName, + URL: url, + }) + } + cfg.Providers.Nats = natsEventSources + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + assert.Fail(t, "Should not be called") + }) + + assert.Error(t, errRouter) + }) } func TestFlakyNatsEvents(t *testing.T) { diff --git a/router-tests/mcp_test.go b/router-tests/mcp_test.go index 22f47cebe8..51efd4c987 100644 --- a/router-tests/mcp_test.go +++ b/router-tests/mcp_test.go @@ -209,6 +209,7 @@ func TestMCP(t *testing.T) { t.Run("Execute Query", func(t *testing.T) { t.Run("Execute operation of type query with valid input", func(t *testing.T) { testenv.Run(t, &testenv.Config{ + EnableNats: true, MCP: config.MCPConfiguration{ Enabled: true, }, @@ -265,6 +266,7 @@ func TestMCP(t *testing.T) { t.Run("Execute Mutation", func(t *testing.T) { t.Run("Execute operation of type mutation with valid input", func(t *testing.T) { testenv.Run(t, &testenv.Config{ + EnableNats: true, MCP: config.MCPConfiguration{ Enabled: true, }, diff --git a/router-tests/structured_logging_test.go b/router-tests/structured_logging_test.go index 1bca435aef..b71fa8ae12 100644 --- a/router-tests/structured_logging_test.go +++ b/router-tests/structured_logging_test.go @@ -4,17 +4,19 @@ import ( "bytes" "encoding/json" "fmt" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "go.opentelemetry.io/otel/sdk/metric" - "go.uber.org/zap" - "go.uber.org/zap/zapcore" "math" "net/http" "os" "path/filepath" "testing" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/sdk/metric" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" + + "github.com/stretchr/testify/assert" + "github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/core" "github.com/wundergraph/cosmo/router/pkg/config" @@ -162,11 +164,13 @@ func TestRouterStartLogs(t *testing.T) { }, }, func(t *testing.T, xEnv *testenv.Environment) { logEntries := xEnv.Observer().All() - 
require.Len(t, logEntries, 13) + require.Len(t, logEntries, 15) natsLogs := xEnv.Observer().FilterMessageSnippet("Nats Event source enabled").All() require.Len(t, natsLogs, 4) + natsConnectedLogs := xEnv.Observer().FilterMessageSnippet("NATS connection established").All() + require.Len(t, natsConnectedLogs, 4) providerIDFields := xEnv.Observer().FilterField(zap.String("provider_id", "default")).All() - require.Len(t, providerIDFields, 2) + require.Len(t, providerIDFields, 3) kafkaLogs := xEnv.Observer().FilterMessageSnippet("Kafka Event source enabled").All() require.Len(t, kafkaLogs, 2) playgroundLog := xEnv.Observer().FilterMessage("Serving GraphQL playground") diff --git a/router-tests/testenv/pubsub.go b/router-tests/testenv/pubsub.go index 9d67d9dd73..4011c231d4 100644 --- a/router-tests/testenv/pubsub.go +++ b/router-tests/testenv/pubsub.go @@ -5,24 +5,33 @@ import ( "time" "github.com/nats-io/nats.go" - "github.com/ory/dockertest/v3" - "github.com/twmb/franz-go/pkg/kgo" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" ) -type KafkaData struct { - Client *kgo.Client - Brokers []string - Resource *dockertest.Resource +type NatsParams struct { + Opts []nats.Option + Url string } type NatsData struct { Connections []*nats.Conn + Params []*NatsParams } func setupNatsClients(t testing.TB) (*NatsData, error) { natsData := &NatsData{} - for range demoNatsProviders { + for range DemoNatsProviders { + param := &NatsParams{ + Url: nats.DefaultURL, + Opts: []nats.Option{ + nats.MaxReconnects(10), + nats.ReconnectWait(1 * time.Second), + nats.Timeout(5 * time.Second), + nats.ErrorHandler(func(conn *nats.Conn, subscription *nats.Subscription, err error) { + t.Log(err) + }), + }, + } natsConnection, err := nats.Connect( nats.DefaultURL, nats.MaxReconnects(10), @@ -35,6 +44,8 @@ func setupNatsClients(t testing.TB) (*NatsData, error) { if err != nil { return nil, err } + + natsData.Params = append(natsData.Params, param) natsData.Connections = 
append(natsData.Connections, natsConnection) } return natsData, nil diff --git a/router-tests/testenv/testenv.go b/router-tests/testenv/testenv.go index 637e68156f..e61c22261d 100644 --- a/router-tests/testenv/testenv.go +++ b/router-tests/testenv/testenv.go @@ -40,7 +40,6 @@ import ( "github.com/hashicorp/go-cleanhttp" "github.com/hashicorp/go-retryablehttp" "github.com/nats-io/nats.go" - "github.com/nats-io/nats.go/jetstream" "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" "github.com/twmb/franz-go/pkg/kadm" @@ -54,8 +53,6 @@ import ( "go.uber.org/zap/zaptest/observer" "google.golang.org/protobuf/encoding/protojson" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" - "github.com/wundergraph/cosmo/demo/pkg/subgraphs" "github.com/wundergraph/cosmo/router/core" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" @@ -85,8 +82,8 @@ var ( ConfigWithEdfsNatsJSONTemplate string //go:embed testdata/configWithPlugins.json ConfigWithPluginsJSONTemplate string - demoNatsProviders = []string{natsDefaultSourceName, myNatsProviderID} - demoKafkaProviders = []string{myKafkaProviderID} + DemoNatsProviders = []string{natsDefaultSourceName, myNatsProviderID} + DemoKafkaProviders = []string{myKafkaProviderID} ) func init() { @@ -1266,18 +1263,18 @@ func configureRouter(listenerAddr string, testConfig *Config, routerConfig *node testConfig.ModifySubgraphErrorPropagation(&cfg.SubgraphErrorPropagation) } - natsEventSources := make([]config.NatsEventSource, len(demoNatsProviders)) - kafkaEventSources := make([]config.KafkaEventSource, len(demoKafkaProviders)) + natsEventSources := make([]config.NatsEventSource, len(DemoNatsProviders)) + kafkaEventSources := make([]config.KafkaEventSource, len(DemoKafkaProviders)) if natsData != nil { - for _, sourceName := range demoNatsProviders { + for _, sourceName := range DemoNatsProviders { natsEventSources = append(natsEventSources, 
config.NatsEventSource{ ID: sourceName, URL: nats.DefaultURL, }) } } - for _, sourceName := range demoKafkaProviders { + for _, sourceName := range DemoKafkaProviders { kafkaEventSources = append(kafkaEventSources, config.KafkaEventSource{ ID: sourceName, Brokers: testConfig.KafkaSeeds, @@ -1584,6 +1581,18 @@ func gqlURL(srv *httptest.Server) string { return path } +func ReadAndCheckJSON(t testing.TB, conn *websocket.Conn, v interface{}) (err error) { + _, payload, err := conn.ReadMessage() + if err != nil { + return err + } + if err := json.Unmarshal(payload, &v); err != nil { + t.Logf("Failed to decode WebSocket message. Raw payload: %s", string(payload)) + return err + } + return nil +} + type Environment struct { t testing.TB cfg *Config @@ -1630,6 +1639,10 @@ func (e *Environment) GetPubSubName(name string) string { return e.getPubSubName(name) } +func (e *Environment) GetKafkaSeeds() []string { + return e.cfg.KafkaSeeds +} + func (e *Environment) RouterConfigVersionMain() string { return e.routerConfigVersionMain } @@ -2187,8 +2200,7 @@ func (e *Environment) InitGraphQLWebSocketConnection(header http.Header, query u }) require.NoError(e.t, err) var ack WebSocketMessage - err = conn.ReadJSON(&ack) - require.NoError(e.t, err) + require.NoError(e.t, ReadAndCheckJSON(e.t, conn, &ack)) require.Equal(e.t, "connection_ack", ack.Type) return conn } @@ -2549,7 +2561,7 @@ func WSReadJSON(t testing.TB, conn *websocket.Conn, v interface{}) (err error) { return err } - err = conn.ReadJSON(v) + require.NoError(t, ReadAndCheckJSON(t, conn, v)) // Reset the deadline to prevent future operations from timing out if resetErr := conn.SetReadDeadline(time.Time{}); resetErr != nil { @@ -2665,16 +2677,19 @@ func WSWriteJSON(t testing.TB, conn *websocket.Conn, v interface{}) (err error) func subgraphOptions(ctx context.Context, t testing.TB, logger *zap.Logger, natsData *NatsData, pubSubName func(string) string) *subgraphs.SubgraphOptions { if natsData == nil { return 
&subgraphs.SubgraphOptions{ - NatsPubSubByProviderID: map[string]pubsub_datasource.NatsPubSub{}, + NatsPubSubByProviderID: map[string]pubsubNats.AdapterInterface{}, GetPubSubName: pubSubName, } } - natsPubSubByProviderID := make(map[string]pubsub_datasource.NatsPubSub, len(demoNatsProviders)) - for _, sourceName := range demoNatsProviders { - js, err := jetstream.New(natsData.Connections[0]) + natsPubSubByProviderID := make(map[string]pubsubNats.AdapterInterface, len(DemoNatsProviders)) + for _, sourceName := range DemoNatsProviders { + adapter, err := pubsubNats.NewAdapter(ctx, logger, natsData.Params[0].Url, natsData.Params[0].Opts, "hostname", "listenaddr") require.NoError(t, err) - - natsPubSubByProviderID[sourceName] = pubsubNats.NewConnector(logger, natsData.Connections[0], js, "hostname", "listenaddr").New(ctx) + require.NoError(t, adapter.Startup(ctx)) + t.Cleanup(func() { + require.NoError(t, adapter.Shutdown(context.Background())) + }) + natsPubSubByProviderID[sourceName] = adapter } return &subgraphs.SubgraphOptions{ diff --git a/router-tests/testenv/waitinglistener.go b/router-tests/testenv/waitinglistener.go new file mode 100644 index 0000000000..25aa32c80f --- /dev/null +++ b/router-tests/testenv/waitinglistener.go @@ -0,0 +1,54 @@ +package testenv + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +type WaitingListener struct { + cancel context.CancelFunc + listener *net.Listener + waitTime time.Duration + port int +} + +func (l *WaitingListener) Close() error { + l.cancel() + return (*l.listener).Close() +} + +func (l *WaitingListener) Start() { + go func() { + for { + conn, err := (*l.listener).Accept() + if err != nil { + return + } + time.Sleep(l.waitTime) + conn.Close() + } + }() +} + +func (l *WaitingListener) Port() int { + return l.port +} + +func NewWaitingListener(t *testing.T, waitTime time.Duration) (wl *WaitingListener) { + ctx, cancel := context.WithCancel(context.Background()) + var lc 
net.ListenConfig + listener, err := lc.Listen(ctx, "tcp", "127.0.0.1:0") + require.NoError(t, err) + + wl = &WaitingListener{ + cancel: cancel, + listener: &listener, + waitTime: waitTime, + port: listener.Addr().(*net.TCPAddr).Port, + } + return wl +} diff --git a/router-tests/websocket_test.go b/router-tests/websocket_test.go index f73a119ffc..64ae313215 100644 --- a/router-tests/websocket_test.go +++ b/router-tests/websocket_test.go @@ -2167,6 +2167,9 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { err = testenv.WSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) + if msg.Type == "error" { + t.Logf("unexpected error on read: %s", string(msg.Payload)) + } require.Equal(t, "next", msg.Type) err = json.Unmarshal(msg.Payload, &payload) require.NoError(t, err) @@ -2176,6 +2179,9 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { err = testenv.WSReadJSON(t, conn, &msg) require.NoError(t, err) require.Equal(t, "1", msg.ID) + if msg.Type == "error" { + t.Logf("unexpected error on read: %s", string(msg.Payload)) + } require.Equal(t, "next", msg.Type) err = json.Unmarshal(msg.Payload, &payload) require.NoError(t, err) @@ -2196,6 +2202,9 @@ func expectConnectAndReadCurrentTime(t *testing.T, xEnv *testenv.Environment) { err = testenv.WSReadJSON(t, conn, &complete) require.NoError(t, err) require.Equal(t, "1", complete.ID) + if complete.Type == "error" { + t.Logf("unexpected error on read: %s", string(complete.Payload)) + } require.Equal(t, "complete", complete.Type) err = conn.SetReadDeadline(time.Now().Add(1 * time.Second)) diff --git a/router/core/errors.go b/router/core/errors.go index 47f248f2e7..7f8df34da2 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -12,7 +12,7 @@ import ( rErrors "github.com/wundergraph/cosmo/router/internal/errors" "github.com/wundergraph/cosmo/router/internal/persistedoperation" "github.com/wundergraph/cosmo/router/internal/unique" - 
"github.com/wundergraph/cosmo/router/pkg/pubsub" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" rtrace "github.com/wundergraph/cosmo/router/pkg/trace" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -73,7 +73,7 @@ func getErrorType(err error) errorType { return errorTypeContextTimeout } } - var edfsErr *pubsub.Error + var edfsErr *datasource.Error if errors.As(err, &edfsErr) { return errorTypeEDFS } diff --git a/router/core/executor.go b/router/core/executor.go index 1606d1a640..a911890e4c 100644 --- a/router/core/executor.go +++ b/router/core/executor.go @@ -2,15 +2,10 @@ package core import ( "context" - "crypto/tls" - "errors" "fmt" "net/http" "time" - "github.com/nats-io/nats.go" - "github.com/twmb/franz-go/pkg/kgo" - "github.com/twmb/franz-go/pkg/sasl/plain" "go.uber.org/zap" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" @@ -23,6 +18,7 @@ import ( nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" "github.com/wundergraph/cosmo/router/pkg/config" + pubsub_datasource "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/routerplugin" ) @@ -38,6 +34,7 @@ type ExecutorConfigurationBuilder struct { pluginHost *routerplugin.Host subscriptionClientOptions *SubscriptionClientOptions + instanceData InstanceData } type Executor struct { @@ -56,19 +53,19 @@ type ExecutorBuildOptions struct { EngineConfig *nodev1.EngineConfiguration Subgraphs []*nodev1.Subgraph RouterEngineConfig *RouterEngineConfiguration - PubSubProviders *EnginePubSubProviders Reporter resolve.Reporter ApolloCompatibilityFlags config.ApolloCompatibilityFlags ApolloRouterCompatibilityFlags config.ApolloRouterCompatibilityFlags HeartbeatInterval time.Duration TraceClientRequired bool PluginsEnabled bool + InstanceData InstanceData } -func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts 
*ExecutorBuildOptions) (*Executor, error) { - planConfig, err := b.buildPlannerConfiguration(ctx, opts.EngineConfig, opts.Subgraphs, opts.RouterEngineConfig, opts.PubSubProviders, opts.PluginsEnabled) +func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts *ExecutorBuildOptions) (*Executor, []pubsub_datasource.PubSubProvider, error) { + planConfig, providers, err := b.buildPlannerConfiguration(ctx, opts.EngineConfig, opts.Subgraphs, opts.RouterEngineConfig, opts.PluginsEnabled) if err != nil { - return nil, fmt.Errorf("failed to build planner configuration: %w", err) + return nil, nil, fmt.Errorf("failed to build planner configuration: %w", err) } options := resolve.ResolverOptions{ @@ -132,7 +129,7 @@ func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts *Executor routerSchemaDefinition, report = astparser.ParseGraphqlDocumentString(opts.EngineConfig.GraphqlSchema) if report.HasErrors() { - return nil, fmt.Errorf("failed to parse graphql schema from engine config: %w", report) + return nil, providers, fmt.Errorf("failed to parse graphql schema from engine config: %w", report) } // we need to merge the base schema, it contains the __schema and __type queries, // as well as built-in scalars like Int, String, etc... 
@@ -140,7 +137,7 @@ func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts *Executor // the engine needs to have them defined, otherwise it cannot resolve such fields err = asttransform.MergeDefinitionWithBaseSchema(&routerSchemaDefinition) if err != nil { - return nil, fmt.Errorf("failed to merge graphql schema with base schema: %w", err) + return nil, providers, fmt.Errorf("failed to merge graphql schema with base schema: %w", err) } if clientSchemaStr := opts.EngineConfig.GetGraphqlClientSchema(); clientSchemaStr != "" { @@ -149,11 +146,11 @@ func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts *Executor clientSchema, report := astparser.ParseGraphqlDocumentString(clientSchemaStr) if report.HasErrors() { - return nil, fmt.Errorf("failed to parse graphql client schema from engine config: %w", report) + return nil, providers, fmt.Errorf("failed to parse graphql client schema from engine config: %w", report) } err = asttransform.MergeDefinitionWithBaseSchema(&clientSchema) if err != nil { - return nil, fmt.Errorf("failed to merge graphql client schema with base schema: %w", err) + return nil, providers, fmt.Errorf("failed to merge graphql client schema with base schema: %w", err) } clientSchemaDefinition = &clientSchema } else { @@ -169,7 +166,7 @@ func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts *Executor // datasource is attached to Query.__schema, Query.__type, __Type.fields and __Type.enumValues fields introspectionFactory, err := introspection_datasource.NewIntrospectionConfigFactory(clientSchemaDefinition) if err != nil { - return nil, fmt.Errorf("failed to create introspection config factory: %w", err) + return nil, providers, fmt.Errorf("failed to create introspection config factory: %w", err) } fieldConfigs := introspectionFactory.BuildFieldConfigurations() // we need to add these fields to the config @@ -200,87 +197,15 @@ func (b *ExecutorConfigurationBuilder) Build(ctx context.Context, opts 
*Executor Resolver: resolver, RenameTypeNames: renameTypeNames, TrackUsageInfo: b.trackUsageInfo, - }, nil + }, providers, nil } -func buildNatsOptions(eventSource config.NatsEventSource, logger *zap.Logger) ([]nats.Option, error) { - opts := []nats.Option{ - nats.Name(fmt.Sprintf("cosmo.router.edfs.nats.%s", eventSource.ID)), - nats.ReconnectJitter(500*time.Millisecond, 2*time.Second), - nats.ClosedHandler(func(conn *nats.Conn) { - logger.Info("NATS connection closed", zap.String("provider_id", eventSource.ID), zap.Error(conn.LastError())) - }), - nats.ConnectHandler(func(nc *nats.Conn) { - logger.Info("NATS connection established", zap.String("provider_id", eventSource.ID), zap.String("url", nc.ConnectedUrlRedacted())) - }), - nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { - if err != nil { - logger.Error("NATS disconnected; will attempt to reconnect", zap.Error(err), zap.String("provider_id", eventSource.ID)) - } else { - logger.Info("NATS disconnected", zap.String("provider_id", eventSource.ID)) - } - }), - nats.ErrorHandler(func(conn *nats.Conn, subscription *nats.Subscription, err error) { - if errors.Is(err, nats.ErrSlowConsumer) { - logger.Warn( - "NATS slow consumer detected. Events are being dropped. 
Please consider increasing the buffer size or reducing the number of messages being sent.", - zap.Error(err), - zap.String("provider_id", eventSource.ID), - ) - } else { - logger.Error("NATS error", zap.Error(err)) - } - }), - nats.ReconnectHandler(func(conn *nats.Conn) { - logger.Info("NATS reconnected", zap.String("provider_id", eventSource.ID), zap.String("url", conn.ConnectedUrlRedacted())) - }), - } - - if eventSource.Authentication != nil { - if eventSource.Authentication.Token != nil { - opts = append(opts, nats.Token(*eventSource.Authentication.Token)) - } else if eventSource.Authentication.UserInfo.Username != nil && eventSource.Authentication.UserInfo.Password != nil { - opts = append(opts, nats.UserInfo(*eventSource.Authentication.UserInfo.Username, *eventSource.Authentication.UserInfo.Password)) - } - } - - return opts, nil -} - -// buildKafkaOptions creates a list of kgo.Opt options for the given Kafka event source configuration. -// Only general options like TLS, SASL, etc. are configured here. Specific options like topics, etc. are -// configured in the KafkaPubSub implementation. -func buildKafkaOptions(eventSource config.KafkaEventSource) ([]kgo.Opt, error) { - opts := []kgo.Opt{ - kgo.SeedBrokers(eventSource.Brokers...), - // Ensure proper timeouts are set - kgo.ProduceRequestTimeout(10 * time.Second), - kgo.ConnIdleTimeout(60 * time.Second), - } - - if eventSource.TLS != nil && eventSource.TLS.Enabled { - opts = append(opts, - // Configure TLS. Uses SystemCertPool for RootCAs by default. 
- kgo.DialTLSConfig(new(tls.Config)), - ) - } - - if eventSource.Authentication != nil && eventSource.Authentication.SASLPlain.Username != nil && eventSource.Authentication.SASLPlain.Password != nil { - opts = append(opts, kgo.SASL(plain.Auth{ - User: *eventSource.Authentication.SASLPlain.Username, - Pass: *eventSource.Authentication.SASLPlain.Password, - }.AsMechanism())) - } - - return opts, nil -} - -func (b *ExecutorConfigurationBuilder) buildPlannerConfiguration(ctx context.Context, engineConfig *nodev1.EngineConfiguration, subgraphs []*nodev1.Subgraph, routerEngineCfg *RouterEngineConfiguration, pubSubProviders *EnginePubSubProviders, pluginsEnabled bool) (*plan.Configuration, error) { +func (b *ExecutorConfigurationBuilder) buildPlannerConfiguration(ctx context.Context, engineConfig *nodev1.EngineConfiguration, subgraphs []*nodev1.Subgraph, routerEngineCfg *RouterEngineConfiguration, pluginsEnabled bool) (*plan.Configuration, []pubsub_datasource.PubSubProvider, error) { // this loader is used to take the engine config and create a plan config // the plan config is what the engine uses to turn a GraphQL Request into an execution plan // the plan config is stateful as it carries connection pools and other things - loader := NewLoader(b.trackUsageInfo, NewDefaultFactoryResolver( + loader := NewLoader(ctx, b.trackUsageInfo, NewDefaultFactoryResolver( ctx, b.transportOptions, b.subscriptionClientOptions, @@ -290,14 +215,13 @@ func (b *ExecutorConfigurationBuilder) buildPlannerConfiguration(ctx context.Con b.logger, routerEngineCfg.Execution.EnableSingleFlight, routerEngineCfg.Execution.EnableNetPoll, - pubSubProviders.nats, - pubSubProviders.kafka, - )) + b.instanceData, + ), b.logger) // this generates the plan config using the data source factories from the config package - planConfig, err := loader.Load(engineConfig, subgraphs, routerEngineCfg, pluginsEnabled) + planConfig, providers, err := loader.Load(engineConfig, subgraphs, routerEngineCfg, pluginsEnabled) 
if err != nil { - return nil, fmt.Errorf("failed to load configuration: %w", err) + return nil, nil, fmt.Errorf("failed to load configuration: %w", err) } debug := &routerEngineCfg.Execution.Debug planConfig.Debug = plan.DebugConfiguration{ @@ -313,5 +237,6 @@ func (b *ExecutorConfigurationBuilder) buildPlannerConfiguration(ctx context.Con planConfig.MinifySubgraphOperations = routerEngineCfg.Execution.MinifySubgraphOperations planConfig.EnableOperationNamePropagation = routerEngineCfg.Execution.EnableSubgraphFetchOperationName - return planConfig, nil + + return planConfig, providers, nil } diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index aaccb35599..3812f3a8c5 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -9,9 +9,10 @@ import ( "net/url" "slices" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/argument_templates" - "github.com/buger/jsonparser" + "github.com/wundergraph/cosmo/router/pkg/pubsub" + pubsub_datasource "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/argument_templates" "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/routerplugin" @@ -21,7 +22,6 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" grpcdatasource "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/grpc_datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/staticdatasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" @@ -30,15 +30,22 @@ import ( ) type Loader struct { + ctx context.Context resolver FactoryResolver // includeInfo controls whether additional information like type usage and field usage is included in the plan de includeInfo bool + logger *zap.Logger +} + +type InstanceData struct { + HostName string + 
ListenAddress string } type FactoryResolver interface { ResolveGraphqlFactory(subgraphName string) (plan.PlannerFactory[graphql_datasource.Configuration], error) ResolveStaticFactory() (plan.PlannerFactory[staticdatasource.Configuration], error) - ResolvePubsubFactory() (plan.PlannerFactory[pubsub_datasource.Configuration], error) + InstanceData() InstanceData } type ApiTransportFactory interface { @@ -48,7 +55,6 @@ type ApiTransportFactory interface { type DefaultFactoryResolver struct { static *staticdatasource.Factory[staticdatasource.Configuration] - pubsub *pubsub_datasource.Factory[pubsub_datasource.Configuration] log *zap.Logger engineCtx context.Context @@ -61,6 +67,7 @@ type DefaultFactoryResolver struct { pluginHost *routerplugin.Host factoryLogger abstractlogger.Logger + instanceData InstanceData } func NewDefaultFactoryResolver( @@ -73,8 +80,7 @@ func NewDefaultFactoryResolver( log *zap.Logger, enableSingleFlight bool, enableNetPoll bool, - natsPubSubBySourceID map[string]pubsub_datasource.NatsPubSub, - kafkaPubSubBySourceID map[string]pubsub_datasource.KafkaPubSub, + instanceData InstanceData, ) *DefaultFactoryResolver { transportFactory := NewTransport(transportOptions) @@ -144,7 +150,6 @@ func NewDefaultFactoryResolver( return &DefaultFactoryResolver{ static: &staticdatasource.Factory[staticdatasource.Configuration]{}, - pubsub: pubsub_datasource.NewFactory(ctx, natsPubSubBySourceID, kafkaPubSubBySourceID), log: log, factoryLogger: factoryLogger, engineCtx: ctx, @@ -155,6 +160,7 @@ func NewDefaultFactoryResolver( httpClient: defaultHTTPClient, subgraphHTTPClients: subgraphHTTPClients, pluginHost: pluginHost, + instanceData: instanceData, } } @@ -179,14 +185,16 @@ func (d *DefaultFactoryResolver) ResolveStaticFactory() (factory plan.PlannerFac return d.static, nil } -func (d *DefaultFactoryResolver) ResolvePubsubFactory() (factory plan.PlannerFactory[pubsub_datasource.Configuration], err error) { - return d.pubsub, nil +func (d 
*DefaultFactoryResolver) InstanceData() InstanceData { + return d.instanceData } -func NewLoader(includeInfo bool, resolver FactoryResolver) *Loader { +func NewLoader(ctx context.Context, includeInfo bool, resolver FactoryResolver, logger *zap.Logger) *Loader { return &Loader{ + ctx: ctx, resolver: resolver, includeInfo: includeInfo, + logger: logger, } } @@ -259,7 +267,7 @@ func mapProtoFilterToPlanFilter(input *nodev1.SubscriptionFilterCondition, outpu return nil } -func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nodev1.Subgraph, routerEngineConfig *RouterEngineConfiguration, pluginsEnabled bool) (*plan.Configuration, error) { +func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nodev1.Subgraph, routerEngineConfig *RouterEngineConfiguration, pluginsEnabled bool) (*plan.Configuration, []pubsub_datasource.PubSubProvider, error) { var outConfig plan.Configuration // attach field usage information to the plan outConfig.DefaultFlushIntervalMillis = engineConfig.DefaultFlushInterval @@ -294,6 +302,8 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod }) } + var providers []pubsub_datasource.PubSubProvider + for _, in := range engineConfig.DatasourceConfigurations { var out plan.DataSource @@ -301,7 +311,7 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod case nodev1.DataSourceKind_STATIC: factory, err := l.resolver.ResolveStaticFactory() if err != nil { - return nil, err + return nil, providers, err } out, err = plan.NewDataSourceConfiguration[staticdatasource.Configuration]( @@ -313,7 +323,7 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod }, ) if err != nil { - return nil, fmt.Errorf("error creating data source configuration for data source %s: %w", in.Id, err) + return nil, providers, fmt.Errorf("error creating data source configuration for data source %s: %w", in.Id, err) } case nodev1.DataSourceKind_GRAPHQL: @@ 
-342,7 +352,7 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod graphqlSchema, err := l.LoadInternedString(engineConfig, in.CustomGraphql.GetUpstreamSchema()) if err != nil { - return nil, fmt.Errorf("could not load GraphQL schema for data source %s: %w", in.Id, err) + return nil, providers, fmt.Errorf("could not load GraphQL schema for data source %s: %w", in.Id, err) } var subscriptionUseSSE bool @@ -381,7 +391,7 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod dataSourceRules := FetchURLRules(routerEngineConfig.Headers, subgraphs, subscriptionUrl) forwardedClientHeaders, forwardedClientRegexps, err := PropagatedHeaders(dataSourceRules) if err != nil { - return nil, fmt.Errorf("error parsing header rules for data source %s: %w", in.Id, err) + return nil, providers, fmt.Errorf("error parsing header rules for data source %s: %w", in.Id, err) } schemaConfiguration, err := graphql_datasource.NewSchemaConfiguration( @@ -392,14 +402,14 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod }, ) if err != nil { - return nil, fmt.Errorf("error creating schema configuration for data source %s: %w", in.Id, err) + return nil, providers, fmt.Errorf("error creating schema configuration for data source %s: %w", in.Id, err) } grpcConfig := toGRPCConfiguration(in.CustomGraphql.Grpc, pluginsEnabled) if grpcConfig != nil { grpcConfig.Compiler, err = grpcdatasource.NewProtoCompiler(in.CustomGraphql.Grpc.ProtoSchema, grpcConfig.Mapping) if err != nil { - return nil, fmt.Errorf("error creating proto compiler for data source %s: %w", in.Id, err) + return nil, providers, fmt.Errorf("error creating proto compiler for data source %s: %w", in.Id, err) } } @@ -422,14 +432,14 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod GRPC: grpcConfig, }) if err != nil { - return nil, fmt.Errorf("error creating custom configuration for data source %s: %w", in.Id, err) + 
return nil, providers, fmt.Errorf("error creating custom configuration for data source %s: %w", in.Id, err) } dataSourceName := l.subgraphName(subgraphs, in.Id) factory, err := l.resolver.ResolveGraphqlFactory(dataSourceName) if err != nil { - return nil, err + return nil, providers, err } out, err = plan.NewDataSourceConfigurationWithName[graphql_datasource.Configuration]( @@ -440,83 +450,48 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod customConfiguration, ) if err != nil { - return nil, fmt.Errorf("error creating data source configuration for data source %s: %w", in.Id, err) + return nil, providers, fmt.Errorf("error creating data source configuration for data source %s: %w", in.Id, err) } case nodev1.DataSourceKind_PUBSUB: - var eventConfigurations []pubsub_datasource.EventConfiguration - - for _, eventConfiguration := range in.GetCustomEvents().GetNats() { - eventType, err := pubsub_datasource.EventTypeFromString(eventConfiguration.EngineEventConfiguration.Type.String()) + var err error + + dsMeta := l.dataSourceMetaData(in) + providersFactories := pubsub.GetProviderFactories() + for _, providerFactory := range providersFactories { + provider, err := providerFactory( + l.ctx, + in, + dsMeta, + routerEngineConfig.Events, + l.logger, + l.resolver.InstanceData().HostName, + l.resolver.InstanceData().ListenAddress, + ) if err != nil { - return nil, fmt.Errorf("invalid event type %q for data source %q: %w", eventConfiguration.EngineEventConfiguration.Type.String(), in.Id, err) - } - - var streamConfiguration *pubsub_datasource.NatsStreamConfiguration - if eventConfiguration.StreamConfiguration != nil { - streamConfiguration = &pubsub_datasource.NatsStreamConfiguration{ - Consumer: eventConfiguration.StreamConfiguration.GetConsumerName(), - StreamName: eventConfiguration.StreamConfiguration.GetStreamName(), - ConsumerInactiveThreshold: eventConfiguration.StreamConfiguration.GetConsumerInactiveThreshold(), - } + return nil, 
providers, err } - - eventConfigurations = append(eventConfigurations, pubsub_datasource.EventConfiguration{ - Metadata: &pubsub_datasource.EventMetadata{ - ProviderID: eventConfiguration.EngineEventConfiguration.GetProviderId(), - Type: eventType, - TypeName: eventConfiguration.EngineEventConfiguration.GetTypeName(), - FieldName: eventConfiguration.EngineEventConfiguration.GetFieldName(), - }, - Configuration: &pubsub_datasource.NatsEventConfiguration{ - StreamConfiguration: streamConfiguration, - Subjects: eventConfiguration.GetSubjects(), - }, - }) - } - - for _, eventConfiguration := range in.GetCustomEvents().GetKafka() { - eventType, err := pubsub_datasource.EventTypeFromString(eventConfiguration.EngineEventConfiguration.Type.String()) - if err != nil { - return nil, fmt.Errorf("invalid event type %q for data source %q: %w", eventConfiguration.EngineEventConfiguration.Type.String(), in.Id, err) + if provider != nil { + providers = append(providers, provider) } - - eventConfigurations = append(eventConfigurations, pubsub_datasource.EventConfiguration{ - Metadata: &pubsub_datasource.EventMetadata{ - ProviderID: eventConfiguration.EngineEventConfiguration.GetProviderId(), - Type: eventType, - TypeName: eventConfiguration.EngineEventConfiguration.GetTypeName(), - FieldName: eventConfiguration.EngineEventConfiguration.GetFieldName(), - }, - Configuration: &pubsub_datasource.KafkaEventConfiguration{ - Topics: eventConfiguration.GetTopics(), - }, - }) } - factory, err := l.resolver.ResolvePubsubFactory() - if err != nil { - return nil, err - } - - out, err = plan.NewDataSourceConfiguration[pubsub_datasource.Configuration]( + out, err = plan.NewDataSourceConfiguration( in.Id, - factory, - l.dataSourceMetaData(in), - pubsub_datasource.Configuration{ - Events: eventConfigurations, - }, + pubsub_datasource.NewFactory(l.ctx, routerEngineConfig.Events, providers), + dsMeta, + providers, ) if err != nil { - return nil, fmt.Errorf("error creating data source configuration 
for data source %s: %w", in.Id, err) + return nil, providers, err } default: - return nil, fmt.Errorf("unknown data source type %q", in.Kind) + return nil, providers, fmt.Errorf("unknown data source type %q", in.Kind) } outConfig.DataSources = append(outConfig.DataSources, out) } - return &outConfig, nil + return &outConfig, providers, nil } func (l *Loader) subgraphName(subgraphs []*nodev1.Subgraph, dataSourceID string) string { diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 4c2a01625f..2552542d44 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -21,8 +21,6 @@ import ( "github.com/golang-jwt/jwt/v5" "github.com/klauspost/compress/gzhttp" "github.com/klauspost/compress/gzip" - "github.com/nats-io/nats.go" - "github.com/nats-io/nats.go/jetstream" "go.opentelemetry.io/otel/attribute" otelmetric "go.opentelemetry.io/otel/metric" oteltrace "go.opentelemetry.io/otel/trace" @@ -30,6 +28,7 @@ import ( "go.uber.org/zap" "go.uber.org/zap/zapcore" "golang.org/x/exp/maps" + "golang.org/x/sync/errgroup" "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/common" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" @@ -46,13 +45,10 @@ import ( "github.com/wundergraph/cosmo/router/pkg/logging" rmetric "github.com/wundergraph/cosmo/router/pkg/metric" "github.com/wundergraph/cosmo/router/pkg/otel" - "github.com/wundergraph/cosmo/router/pkg/pubsub" - "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" - pubsubNats "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/routerplugin" "github.com/wundergraph/cosmo/router/pkg/statistics" rtrace "github.com/wundergraph/cosmo/router/pkg/trace" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" ) const ( @@ -67,11 +63,6 @@ type ( HealthChecks() health.Checker } - EnginePubSubProviders struct { - nats 
map[string]pubsub_datasource.NatsPubSub - kafka map[string]pubsub_datasource.KafkaPubSub - } - // graphServer is the swappable implementation of a Graph instance which is an HTTP mux with middlewares. // Everytime a schema is updated, the old graph server is shutdown and a new graph server is created. // For feature flags, a graphql server has multiple mux and is dynamically switched based on the feature flag header or cookie. @@ -80,7 +71,6 @@ type ( *Config context context.Context cancelFunc context.CancelFunc - pubSubProviders *EnginePubSubProviders storageProviders *config.StorageProviders engineStats statistics.EngineStatistics playgroundHandler func(http.Handler) http.Handler @@ -99,8 +89,8 @@ type ( otlpEngineMetrics *rmetric.EngineMetrics prometheusEngineMetrics *rmetric.EngineMetrics connectionMetrics *rmetric.ConnectionMetrics - hostName string - routerListenAddr string + instanceData InstanceData + pubSubProviders []datasource.PubSubProvider traceDialer *TraceDialer pluginHost *routerplugin.Host } @@ -147,11 +137,9 @@ func newGraphServer(ctx context.Context, r *Router, routerConfig *nodev1.RouterC baseRouterConfigVersion: routerConfig.GetVersion(), inFlightRequests: &atomic.Uint64{}, graphMuxList: make([]*graphMux, 0, 1), - routerListenAddr: r.listenAddr, - hostName: r.hostName, - pubSubProviders: &EnginePubSubProviders{ - nats: map[string]pubsub_datasource.NatsPubSub{}, - kafka: map[string]pubsub_datasource.KafkaPubSub{}, + instanceData: InstanceData{ + HostName: r.hostName, + ListenAddress: r.listenAddr, }, storageProviders: &r.storageProviders, } @@ -991,11 +979,6 @@ func (s *graphServer) buildGraphMux(ctx context.Context, SubgraphErrorPropagation: s.subgraphErrorPropagation, } - err = s.buildPubSubConfiguration(ctx, engineConfig, routerEngineConfig) - if err != nil { - return nil, fmt.Errorf("failed to build pubsub configuration: %w", err) - } - // map[string]*http.Transport cannot be coerced into map[string]http.RoundTripper, unfortunately 
subgraphTippers := map[string]http.RoundTripper{} for subgraph, subgraphTransport := range s.subgraphTransports { @@ -1052,24 +1035,29 @@ func (s *graphServer) buildGraphMux(ctx context.Context, }, } - executor, err := ecb.Build( + executor, providers, err := ecb.Build( ctx, &ExecutorBuildOptions{ EngineConfig: engineConfig, Subgraphs: configSubgraphs, RouterEngineConfig: routerEngineConfig, - PubSubProviders: s.pubSubProviders, Reporter: s.engineStats, ApolloCompatibilityFlags: s.apolloCompatibilityFlags, ApolloRouterCompatibilityFlags: s.apolloRouterCompatibilityFlags, HeartbeatInterval: s.multipartHeartbeatInterval, PluginsEnabled: s.plugins.Enabled, + InstanceData: s.instanceData, }, ) if err != nil { return nil, fmt.Errorf("failed to build plan configuration: %w", err) } + s.pubSubProviders = providers + if pubSubStartupErr := s.startupPubSubProviders(ctx); pubSubStartupErr != nil { + return nil, pubSubStartupErr + } + operationProcessor := NewOperationProcessor(OperationProcessorOptions{ Executor: executor, MaxOperationSizeInBytes: int64(s.routerTrafficConfig.MaxRequestBodyBytes), @@ -1383,86 +1371,6 @@ func (s *graphServer) setupPluginHost(ctx context.Context, config *nodev1.Engine return nil } -func (s *graphServer) buildPubSubConfiguration(ctx context.Context, engineConfig *nodev1.EngineConfiguration, routerEngineCfg *RouterEngineConfiguration) error { - datasourceConfigurations := engineConfig.GetDatasourceConfigurations() - for _, datasourceConfiguration := range datasourceConfigurations { - if datasourceConfiguration.CustomEvents == nil { - continue - } - - for _, eventConfiguration := range datasourceConfiguration.GetCustomEvents().GetNats() { - - providerID := eventConfiguration.EngineEventConfiguration.GetProviderId() - // if this source name's provider has already been initiated, do not try to initiate again - _, ok := s.pubSubProviders.nats[providerID] - if ok { - continue - } - - for _, eventSource := range routerEngineCfg.Events.Providers.Nats { 
- if eventSource.ID == eventConfiguration.EngineEventConfiguration.GetProviderId() { - options, err := buildNatsOptions(eventSource, s.logger) - if err != nil { - return fmt.Errorf("failed to build options for Nats provider with ID \"%s\": %w", providerID, err) - } - natsConnection, err := nats.Connect(eventSource.URL, options...) - if err != nil { - return fmt.Errorf("failed to create connection for Nats provider with ID \"%s\": %w", providerID, err) - } - js, err := jetstream.New(natsConnection) - if err != nil { - return err - } - - s.pubSubProviders.nats[providerID] = pubsubNats.NewConnector(s.logger, natsConnection, js, s.hostName, s.routerListenAddr).New(ctx) - - break - } - } - - _, ok = s.pubSubProviders.nats[providerID] - if !ok { - return fmt.Errorf("failed to find Nats provider with ID \"%s\". Ensure the provider definition is part of the config", providerID) - } - } - - for _, eventConfiguration := range datasourceConfiguration.GetCustomEvents().GetKafka() { - - providerID := eventConfiguration.EngineEventConfiguration.GetProviderId() - // if this source name's provider has already been initiated, do not try to initiate again - _, ok := s.pubSubProviders.kafka[providerID] - if ok { - continue - } - - for _, eventSource := range routerEngineCfg.Events.Providers.Kafka { - if eventSource.ID == providerID { - options, err := buildKafkaOptions(eventSource) - if err != nil { - return fmt.Errorf("failed to build options for Kafka provider with ID \"%s\": %w", providerID, err) - } - ps, err := kafka.NewConnector(s.logger, options) - if err != nil { - return fmt.Errorf("failed to create connection for Kafka provider with ID \"%s\": %w", providerID, err) - } - - s.pubSubProviders.kafka[providerID] = ps.New(ctx) - - break - } - } - - _, ok = s.pubSubProviders.kafka[providerID] - if !ok { - return fmt.Errorf("failed to find Kafka provider with ID \"%s\". 
Ensure the provider definition is part of the config", providerID) - } - } - - } - - return nil -} - // wait waits for all in-flight requests to finish. Similar to http.Server.Shutdown we wait in intervals + jitter // to make the shutdown process more efficient. func (s *graphServer) wait(ctx context.Context) error { @@ -1542,24 +1450,8 @@ func (s *graphServer) Shutdown(ctx context.Context) error { } } - if s.pubSubProviders != nil { - - s.logger.Debug("Shutting down pubsub providers") - - for _, pubSub := range s.pubSubProviders.nats { - if p, ok := pubSub.(pubsub.Lifecycle); ok { - if err := p.Shutdown(ctx); err != nil { - finalErr = errors.Join(finalErr, err) - } - } - } - for _, pubSub := range s.pubSubProviders.kafka { - if p, ok := pubSub.(pubsub.Lifecycle); ok { - if err := p.Shutdown(ctx); err != nil { - finalErr = errors.Join(finalErr, err) - } - } - } + if err := s.shutdownPubSubProviders(ctx); err != nil { + finalErr = errors.Join(finalErr, err) } // Shutdown all graphs muxes to release resources @@ -1588,6 +1480,56 @@ func (s *graphServer) Shutdown(ctx context.Context) error { return finalErr } +// startupPubSubProviders starts all pubsub providers +// It returns an error if any of the providers fail to start +// or if some providers takes to long to start +func (s *graphServer) startupPubSubProviders(ctx context.Context) error { + // Default timeout for pubsub provider startup + const defaultStartupTimeout = 5 * time.Second + + return s.providersActionWithTimeout(ctx, func(ctx context.Context, provider datasource.PubSubProvider) error { + return provider.Startup(ctx) + }, defaultStartupTimeout, "pubsub provider startup timed out") +} + +// shutdownPubSubProviders shuts down all pubsub providers +// It returns an error if any of the providers fail to shutdown +// or if some providers takes to long to shutdown +func (s *graphServer) shutdownPubSubProviders(ctx context.Context) error { + // Default timeout for pubsub provider shutdown + const 
defaultShutdownTimeout = 5 * time.Second + + return s.providersActionWithTimeout(ctx, func(ctx context.Context, provider datasource.PubSubProvider) error { + return provider.Shutdown(ctx) + }, defaultShutdownTimeout, "pubsub provider shutdown timed out") +} + +func (s *graphServer) providersActionWithTimeout(ctx context.Context, action func(ctx context.Context, provider datasource.PubSubProvider) error, timeout time.Duration, timeoutMessage string) error { + cancellableCtx, cancel := context.WithCancel(ctx) + defer cancel() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + providersGroup := new(errgroup.Group) + for _, provider := range s.pubSubProviders { + providersGroup.Go(func() error { + actionDone := make(chan error, 1) + go func() { + actionDone <- action(cancellableCtx, provider) + }() + select { + case err := <-actionDone: + return err + case <-timer.C: + return errors.New(timeoutMessage) + } + }) + } + + return providersGroup.Wait() +} + func configureSubgraphOverwrites( engineConfig *nodev1.EngineConfiguration, configSubgraphs []*nodev1.Subgraph, diff --git a/router/core/plan_generator.go b/router/core/plan_generator.go index cb72f141b7..1cc30c2eff 100644 --- a/router/core/plan_generator.go +++ b/router/core/plan_generator.go @@ -9,6 +9,9 @@ import ( log "github.com/jensneuse/abstractlogger" nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" "github.com/wundergraph/graphql-go-tools/v2/pkg/astnormalization" "github.com/wundergraph/graphql-go-tools/v2/pkg/astparser" @@ -16,7 +19,6 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/astvalidation" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" 
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/introspection_datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/postprocess" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -234,8 +236,9 @@ func (pg *PlanGenerator) buildRouterConfig(configFilePath string) (*nodev1.Route } func (pg *PlanGenerator) loadConfiguration(routerConfig *nodev1.RouterConfig, logger *zap.Logger, maxDataSourceCollectorsConcurrency uint) error { - natSources := map[string]pubsub_datasource.NatsPubSub{} - kafkaSources := map[string]pubsub_datasource.KafkaPubSub{} + routerEngineConfig := RouterEngineConfiguration{} + natSources := map[string]*nats.Adapter{} + kafkaSources := map[string]*kafka.Adapter{} for _, ds := range routerConfig.GetEngineConfig().GetDatasourceConfigurations() { if ds.GetKind() != nodev1.DataSourceKind_PUBSUB || ds.GetCustomEvents() == nil { continue @@ -244,16 +247,21 @@ func (pg *PlanGenerator) loadConfiguration(routerConfig *nodev1.RouterConfig, lo providerId := natConfig.GetEngineEventConfiguration().GetProviderId() if _, ok := natSources[providerId]; !ok { natSources[providerId] = nil + routerEngineConfig.Events.Providers.Nats = append(routerEngineConfig.Events.Providers.Nats, config.NatsEventSource{ + ID: providerId, + }) } } for _, kafkaConfig := range ds.GetCustomEvents().GetKafka() { providerId := kafkaConfig.GetEngineEventConfiguration().GetProviderId() if _, ok := kafkaSources[providerId]; !ok { kafkaSources[providerId] = nil + routerEngineConfig.Events.Providers.Kafka = append(routerEngineConfig.Events.Providers.Kafka, config.KafkaEventSource{ + ID: providerId, + }) } } } - pubSubFactory := pubsub_datasource.NewFactory(context.Background(), natSources, kafkaSources) var netPollConfig graphql_datasource.NetPollConfiguration netPollConfig.ApplyDefaults() @@ -266,16 
+274,18 @@ func (pg *PlanGenerator) loadConfiguration(routerConfig *nodev1.RouterConfig, lo graphql_datasource.WithNetPollConfiguration(netPollConfig), ) - loader := NewLoader(false, &DefaultFactoryResolver{ - engineCtx: context.Background(), + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + loader := NewLoader(ctx, false, &DefaultFactoryResolver{ + engineCtx: ctx, httpClient: http.DefaultClient, streamingClient: http.DefaultClient, subscriptionClient: subscriptionClient, - pubsub: pubSubFactory, - }) + }, logger) // this generates the plan configuration using the data source factories from the config package - planConfig, err := loader.Load(routerConfig.GetEngineConfig(), routerConfig.GetSubgraphs(), &RouterEngineConfiguration{}, false) // TODO: configure plugins + planConfig, _, err := loader.Load(routerConfig.GetEngineConfig(), routerConfig.GetSubgraphs(), &routerEngineConfig, false) // TODO: configure plugins if err != nil { return fmt.Errorf("failed to load configuration: %w", err) } diff --git a/router/go.mod b/router/go.mod index c49c462f94..58595de615 100644 --- a/router/go.mod +++ b/router/go.mod @@ -137,6 +137,7 @@ require ( github.com/shoenig/go-m1cpu v0.1.6 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/spf13/cast v1.7.1 // indirect + github.com/stretchr/objx v0.5.2 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect diff --git a/router/go.sum b/router/go.sum index 3b0009745c..f0086ba66b 100644 --- a/router/go.sum +++ b/router/go.sum @@ -250,6 +250,7 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 
+github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= diff --git a/router/pkg/plan_generator/plan_generator.go b/router/pkg/plan_generator/plan_generator.go index 6ac9861c79..8b14858431 100644 --- a/router/pkg/plan_generator/plan_generator.go +++ b/router/pkg/plan_generator/plan_generator.go @@ -116,7 +116,10 @@ func PlanGenerator(ctx context.Context, cfg QueryPlanConfig) error { defer wg.Done() planner, err := pg.GetPlanner() if err != nil { + // if we fail to get the planner, we need to cancel the context to stop the other goroutines + // and return here to stop the current goroutine cancelError(fmt.Errorf("failed to get planner: %v", err)) + return } for { select { diff --git a/router/pkg/plan_generator/plan_generator_test.go b/router/pkg/plan_generator/plan_generator_test.go index 7bae1f21ce..6de540c7a7 100644 --- a/router/pkg/plan_generator/plan_generator_test.go +++ b/router/pkg/plan_generator/plan_generator_test.go @@ -11,6 +11,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "go.uber.org/zap" ) func getTestDataDir() string { @@ -130,6 +131,7 @@ func TestPlanGenerator(t *testing.T) { ExecutionConfig: path.Join(getTestDataDir(), "execution_config", "base.json"), Timeout: "30s", OutputFiles: true, + Logger: zap.NewNop(), } err = PlanGenerator(context.Background(), cfg) diff --git a/router/pkg/pubsub/README.md b/router/pkg/pubsub/README.md new file mode 100644 index 0000000000..ebcc636424 --- /dev/null +++ b/router/pkg/pubsub/README.md @@ -0,0 +1,87 @@ + +# How to add a PubSub Provider + +## Add the data to the router proto + +You need to change the [router proto](../../../proto/wg/cosmo/node/v1/node.proto) as follows. 
+
+Add the provider configuration like the `KafkaEventConfiguration` and then add it as repeated inside the `DataSourceCustomEvents`.
+
+The fields of `KafkaEventConfiguration` will depend on the provider. If the provider uses "channel" as its message grouping mechanism, the field will be called "channels"; if it is "Topic", it will be "topics"; and so on.
+
+After this you will have to compile the proto by launching the command `make generate-go` from the main folder.
+
+
+## Build the PubSub Provider
+
+To build a PubSub provider you need to implement 4 things:
+- `Adapter`
+- `ProviderFactory`
+- `PubSubProvider`
+- `PubSubDataSource`
+
+And then add it inside the `GetProviderFactories` function.
+
+### Adapter
+
+The Adapter contains the logic that actually calls the provider; it usually implements an interface as follows:
+
+```go
+type AdapterInterface interface {
+	Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error
+	Publish(ctx context.Context, event PublishEventConfiguration) error
+	Startup(ctx context.Context) error
+	Shutdown(ctx context.Context) error
+}
+```
+
+The content of `SubscriptionEventConfiguration` and `PublishEventConfiguration` depends on the provider; you can see an example of them in the [kafka implementation](./kafka/pubsub_datasource.go).
+
+
+### ProviderFactory
+
+The `ProviderFactory` is the initial contact point where you receive:
+- `ctx context.Context`, usually passed down to the adapter
+- [`*nodev1.DataSourceConfiguration`](../../gen/proto/wg/cosmo/node/v1/node.pb.go#DataSourceConfiguration), that contains everything you need about the provider data parsed from the schema.
+- `*plan.DataSourceMetadata`, usually not needed
+- [`config.EventsConfiguration`](../config/config.go#EventsConfiguration) that contains the config needed to set up the provider connection
+- `*zap.Logger`
+- `hostName string`, useful if you need to identify the connection based on the local host name
+- `routerListenAddr string`, useful if you need to identify the connection based on different router instances in the same host
+
+The responsibility of the factory is to initialize the PubSubProvider, like in this implementation for an `ExampleProvider`:
+
+You can see an example of the `GetProvider` function in the [kafka implementation](./kafka/provider.go).
+
+### PubSubProvider
+
+So, the `PubSubProvider` already has the provider's Adapter initialized, and it will be called on a `Visitor.EnterField` call from the engine to check whether the `PubSubProvider` matches any `EngineEventConfiguration`.
+
+The responsibility of the `PubSubProvider` is to match the `EngineEventConfiguration` and initialize a `PubSubDataSource` with the matching event and the provider `Adapter`.
+
+You can see an example of the `PubSubProvider` in the [kafka implementation](./kafka/provider.go).
+
+### PubSubDataSource
+
+The `[PubSubDataSource](./datasource/datasource.go)` is the junction between the engine `resolve.DataSource` and the Provider that we are implementing.
+
+You can see an example in [kafka `PubSubDataSource`](./kafka/pubsub_datasource.go).
+
+To complete the `PubSubDataSource` implementation you should also add the engine data source.
+
+So you have to implement the SubscriptionDataSource, a structure that implements all the methods needed by the interface `resolve.SubscriptionDataSource`, like the [kafka implementation](./kafka/engine_datasource.go).
+ +And also, you have to implement the DataSource, a structure that implements all the methods needed by the interface `resolve.DataSource`, like `PublishDataSource` in the [kafka implementation](./kafka/pubsub_datasource.go). + +# How to use the new PubSub Provider + +After you have implemented all the above, you can use your PubSub Provider by adding the following to your router config: + +```yaml +pubsub: + providers: + - name: provider-name + type: new-provider +``` + +But to use it in the schema you will have to work in the [composition](../../../composition) folder. \ No newline at end of file diff --git a/router/pkg/pubsub/datasource/datasource.go b/router/pkg/pubsub/datasource/datasource.go new file mode 100644 index 0000000000..7107a9efe1 --- /dev/null +++ b/router/pkg/pubsub/datasource/datasource.go @@ -0,0 +1,30 @@ +package datasource + +import ( + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +// PubSubDataSource is the interface that all pubsub data sources must implement. +// It serves three main purposes: +// 1. Resolving the data source and subscription data source +// 2. Generating the appropriate input for these data sources +// 3. 
Providing access to the engine event configuration +// +// For detailed implementation guidelines, see: +// https://github.com/wundergraph/cosmo/blob/main/router/pkg/pubsub/README.md +type PubSubDataSource interface { + // ResolveDataSource returns the engine DataSource implementation that contains + // methods which will be called by the Planner when resolving a field + ResolveDataSource() (resolve.DataSource, error) + // ResolveDataSourceInput build the input that will be passed to the engine DataSource + ResolveDataSourceInput(event []byte) (string, error) + // EngineEventConfiguration get the engine event configuration, contains the provider id, type, type name and field name + EngineEventConfiguration() *nodev1.EngineEventConfiguration + // ResolveDataSourceSubscription returns the engine SubscriptionDataSource implementation + // that contains methods to start a subscription, which will be called by the Planner + // when a subscription is initiated + ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) + // ResolveDataSourceSubscriptionInput build the input that will be passed to the engine SubscriptionDataSource + ResolveDataSourceSubscriptionInput() (string, error) +} diff --git a/router/pkg/pubsub/error.go b/router/pkg/pubsub/datasource/error.go similarity index 93% rename from router/pkg/pubsub/error.go rename to router/pkg/pubsub/datasource/error.go index f6220fb7b1..f09b271688 100644 --- a/router/pkg/pubsub/error.go +++ b/router/pkg/pubsub/datasource/error.go @@ -1,4 +1,4 @@ -package pubsub +package datasource type Error struct { Internal error diff --git a/router/pkg/pubsub/datasource/factory.go b/router/pkg/pubsub/datasource/factory.go new file mode 100644 index 0000000000..ce22f9934b --- /dev/null +++ b/router/pkg/pubsub/datasource/factory.go @@ -0,0 +1,38 @@ +package datasource + +import ( + "context" + + "github.com/jensneuse/abstractlogger" + "github.com/wundergraph/cosmo/router/pkg/config" + 
"github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" +) + +func NewFactory(executionContext context.Context, config config.EventsConfiguration, providers []PubSubProvider) *Factory { + return &Factory{ + providers: providers, + executionContext: executionContext, + config: config, + } +} + +type Factory struct { + providers []PubSubProvider + executionContext context.Context + config config.EventsConfiguration +} + +func (f *Factory) Planner(_ abstractlogger.Logger) plan.DataSourcePlanner[[]PubSubProvider] { + return &Planner{ + providers: f.providers, + } +} + +func (f *Factory) Context() context.Context { + return f.executionContext +} + +func (f *Factory) UpstreamSchema(dataSourceConfig plan.DataSourceConfiguration[[]PubSubProvider]) (*ast.Document, bool) { + return nil, false +} diff --git a/router/pkg/pubsub/datasource/planner.go b/router/pkg/pubsub/datasource/planner.go new file mode 100644 index 0000000000..9b5fe1a84a --- /dev/null +++ b/router/pkg/pubsub/datasource/planner.go @@ -0,0 +1,214 @@ +package datasource + +import ( + "fmt" + "strings" + + "github.com/wundergraph/cosmo/router/pkg/pubsub/eventdata" + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/argument_templates" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type Planner struct { + id int + providers []PubSubProvider + pubSubDataSource PubSubDataSource + rootFieldRef int + variables resolve.Variables + visitor *plan.Visitor +} + +func (p *Planner) SetID(id int) { + p.id = id +} + +func (p *Planner) ID() (id int) { + return p.id +} + +func (p *Planner) DownstreamResponseFieldAlias(downstreamFieldRef int) (alias string, exists bool) { + // skip, not required + return +} + +func (p *Planner) DataSourcePlanningBehavior() plan.DataSourcePlanningBehavior { + return 
plan.DataSourcePlanningBehavior{ + MergeAliasedRootNodes: false, + OverrideFieldPathFromAlias: false, + } +} + +func (p *Planner) Register(visitor *plan.Visitor, configuration plan.DataSourceConfiguration[[]PubSubProvider], _ plan.DataSourcePlannerConfiguration) error { + p.visitor = visitor + visitor.Walker.RegisterEnterFieldVisitor(p) + visitor.Walker.RegisterEnterDocumentVisitor(p) + p.providers = configuration.CustomConfiguration() + return nil +} + +func (p *Planner) ConfigureFetch() resolve.FetchConfiguration { + if p.pubSubDataSource == nil { + return resolve.FetchConfiguration{} + } + + var dataSource resolve.DataSource + + dataSource, err := p.pubSubDataSource.ResolveDataSource() + if err != nil { + p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get data source: %w", err)) + return resolve.FetchConfiguration{} + } + + event, err := eventdata.BuildEventDataBytes(p.rootFieldRef, p.visitor.Operation, &p.variables) + if err != nil { + p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source input: %w", err)) + return resolve.FetchConfiguration{} + } + + input, err := p.pubSubDataSource.ResolveDataSourceInput(event) + if err != nil { + p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source input: %w", err)) + return resolve.FetchConfiguration{} + } + + return resolve.FetchConfiguration{ + Input: input, + Variables: p.variables, + DataSource: dataSource, + PostProcessing: resolve.PostProcessingConfiguration{ + MergePath: []string{p.pubSubDataSource.EngineEventConfiguration().GetFieldName()}, + }, + } +} + +func (p *Planner) ConfigureSubscription() plan.SubscriptionConfiguration { + if p.pubSubDataSource == nil { + // p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to configure subscription: event manager is nil")) + return plan.SubscriptionConfiguration{} + } + + dataSource, err := p.pubSubDataSource.ResolveDataSourceSubscription() + if err != nil { + 
p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source subscription: %w", err)) + return plan.SubscriptionConfiguration{} + } + + input, err := p.pubSubDataSource.ResolveDataSourceSubscriptionInput() + if err != nil { + p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source subscription input: %w", err)) + return plan.SubscriptionConfiguration{} + } + + return plan.SubscriptionConfiguration{ + Input: input, + Variables: p.variables, + DataSource: dataSource, + PostProcessing: resolve.PostProcessingConfiguration{ + MergePath: []string{p.pubSubDataSource.EngineEventConfiguration().GetFieldName()}, + }, + } +} + +func (p *Planner) addContextVariableByArgumentRef(argumentRef int, argumentPath []string) (string, error) { + variablePath, err := p.visitor.Operation.VariablePathByArgumentRefAndArgumentPath(argumentRef, argumentPath, p.visitor.Walker.Ancestors[0].Ref) + if err != nil { + return "", err + } + /* The definition is passed as both definition and operation below because getJSONRootType resolves the type + * from the first argument, but finalInputValueTypeRef comes from the definition + */ + contextVariable := &resolve.ContextVariable{ + Path: variablePath, + Renderer: resolve.NewPlainVariableRenderer(), + } + variablePlaceHolder, _ := p.variables.AddVariable(contextVariable) + return variablePlaceHolder, nil +} + +func StringParser(subject string) (string, error) { + matches := argument_templates.ArgumentTemplateRegex.FindAllStringSubmatch(subject, -1) + if len(matches) < 1 { + return subject, nil + } + return "", fmt.Errorf(`subject "%s" is not a valid NATS subject`, subject) +} + +func (p *Planner) extractArgumentTemplate(fieldRef int, template string) (string, error) { + matches := argument_templates.ArgumentTemplateRegex.FindAllStringSubmatch(template, -1) + // If no argument templates are defined, there are only static values + if len(matches) < 1 { + return template, nil + } + fieldNameBytes := 
p.visitor.Operation.FieldNameBytes(fieldRef) + // TODO: handling for interfaces and unions + fieldDefinitionRef, ok := p.visitor.Definition.ObjectTypeDefinitionFieldWithName(p.visitor.Walker.EnclosingTypeDefinition.Ref, fieldNameBytes) + if !ok { + return "", fmt.Errorf(`expected field definition to exist for field "%s"`, fieldNameBytes) + } + templateWithVariableTemplateReplacements := template + for templateNumber, groups := range matches { + // The first group is the whole template; the second is the period delimited argument path + if len(groups) != 2 { + return "", fmt.Errorf(`argument template #%d defined on field "%s" is invalid: expected 2 matching groups but received %d`, templateNumber+1, fieldNameBytes, len(groups)-1) + } + validationResult, err := argument_templates.ValidateArgumentPath(p.visitor.Definition, groups[1], fieldDefinitionRef) + if err != nil { + return "", fmt.Errorf(`argument template #%d defined on field "%s" is invalid: %w`, templateNumber+1, fieldNameBytes, err) + } + argumentNameBytes := []byte(validationResult.ArgumentPath[0]) + argumentRef, ok := p.visitor.Operation.FieldArgument(fieldRef, argumentNameBytes) + if !ok { + return "", fmt.Errorf(`operation field "%s" does not define argument "%s"`, fieldNameBytes, argumentNameBytes) + } + // variablePlaceholder has the form $$0$$, $$1$$, etc. 
+ variablePlaceholder, err := p.addContextVariableByArgumentRef(argumentRef, validationResult.ArgumentPath) + if err != nil { + return "", fmt.Errorf(`failed to retrieve variable placeholder for argument ""%s" defined on operation field "%s": %w`, argumentNameBytes, fieldNameBytes, err) + } + // Replace the template literal with the variable placeholder (and reuse the variable if it already exists) + templateWithVariableTemplateReplacements = strings.ReplaceAll(templateWithVariableTemplateReplacements, groups[0], variablePlaceholder) + } + + return templateWithVariableTemplateReplacements, nil +} + +func (p *Planner) EnterDocument(_, _ *ast.Document) { + p.rootFieldRef = -1 +} + +func (p *Planner) EnterField(ref int) { + if p.rootFieldRef != -1 { + // This is a nested field; nothing needs to be done + return + } + p.rootFieldRef = ref + + fieldName := p.visitor.Operation.FieldNameString(ref) + typeName := p.visitor.Walker.EnclosingTypeDefinition.NameString(p.visitor.Definition) + + extractFn := func(tpl string) (string, error) { + return p.extractArgumentTemplate(ref, tpl) + } + + var pubSubDataSource PubSubDataSource + var err error + + for _, pubSub := range p.providers { + pubSubDataSource, err = pubSub.FindPubSubDataSource(typeName, fieldName, extractFn) + if err != nil { + p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to find event config for type name \"%s\" and field name \"%s\": %w", typeName, fieldName, err)) + return + } + if pubSubDataSource != nil { + break + } + } + + if pubSubDataSource == nil { + p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to find event config for type name \"%s\" and field name \"%s\"", typeName, fieldName)) + } + + p.pubSubDataSource = pubSubDataSource +} diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go new file mode 100644 index 0000000000..fa58f3b30e --- /dev/null +++ b/router/pkg/pubsub/datasource/provider.go @@ -0,0 +1,20 @@ +package datasource + +import ( + 
"context" + + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "go.uber.org/zap" +) + +type ProviderFactory func(ctx context.Context, in *nodev1.DataSourceConfiguration, dsMeta *plan.DataSourceMetadata, config config.EventsConfiguration, logger *zap.Logger, hostName string, routerListenAddr string) (PubSubProvider, error) + +type ArgumentTemplateCallback func(tpl string) (string, error) + +type PubSubProvider interface { + Startup(ctx context.Context) error + Shutdown(ctx context.Context) error + FindPubSubDataSource(typeName string, fieldName string, extractFn ArgumentTemplateCallback) (PubSubDataSource, error) +} diff --git a/router/pkg/pubsub/eventdata/build.go b/router/pkg/pubsub/eventdata/build.go new file mode 100644 index 0000000000..23d6acf933 --- /dev/null +++ b/router/pkg/pubsub/eventdata/build.go @@ -0,0 +1,38 @@ +package eventdata + +import ( + "bytes" + "encoding/json" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/ast" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +func BuildEventDataBytes(ref int, operation *ast.Document, variables *resolve.Variables) ([]byte, error) { + // Collect the field arguments for fetch based operations + fieldArgs := operation.FieldArguments(ref) + var dataBuffer bytes.Buffer + dataBuffer.WriteByte('{') + for i, arg := range fieldArgs { + if i > 0 { + dataBuffer.WriteByte(',') + } + argValue := operation.ArgumentValue(arg) + variableName := operation.VariableValueNameBytes(argValue.Ref) + contextVariable := &resolve.ContextVariable{ + Path: []string{string(variableName)}, + Renderer: resolve.NewPlainVariableRenderer(), + } + variablePlaceHolder, _ := variables.AddVariable(contextVariable) + argumentName := operation.ArgumentNameString(arg) + escapedKey, err := json.Marshal(argumentName) + if err != nil { + return nil, err + } + dataBuffer.Write(escapedKey) + 
dataBuffer.WriteByte(':') + dataBuffer.WriteString(variablePlaceHolder) + } + dataBuffer.WriteByte('}') + return dataBuffer.Bytes(), nil +} diff --git a/router/pkg/pubsub/kafka/kafka.go b/router/pkg/pubsub/kafka/adapter.go similarity index 72% rename from router/pkg/pubsub/kafka/kafka.go rename to router/pkg/pubsub/kafka/adapter.go index b485980ede..7e9d3486c0 100644 --- a/router/pkg/pubsub/kafka/kafka.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -10,65 +10,44 @@ import ( "github.com/twmb/franz-go/pkg/kerr" "github.com/twmb/franz-go/pkg/kgo" - "github.com/wundergraph/cosmo/router/pkg/pubsub" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) var ( - _ pubsub_datasource.KafkaConnector = (*connector)(nil) - _ pubsub_datasource.KafkaPubSub = (*kafkaPubSub)(nil) - _ pubsub.Lifecycle = (*kafkaPubSub)(nil) - errClientClosed = errors.New("client closed") ) -type connector struct { - writeClient *kgo.Client - opts []kgo.Opt - logger *zap.Logger -} - -func NewConnector(logger *zap.Logger, opts []kgo.Opt) (pubsub_datasource.KafkaConnector, error) { - - writeClient, err := kgo.NewClient(append(opts, - // For observability, we set the client ID to "router" - kgo.ClientID("cosmo.router.producer"))..., - ) - if err != nil { - return nil, fmt.Errorf("failed to create write client for Kafka: %w", err) +func NewAdapter(ctx context.Context, logger *zap.Logger, opts []kgo.Opt) (AdapterInterface, error) { + ctx, cancel := context.WithCancel(ctx) + if logger == nil { + logger = zap.NewNop() } - return &connector{ - writeClient: writeClient, - opts: opts, - logger: logger, + return &Adapter{ + ctx: ctx, + logger: logger.With(zap.String("pubsub", "kafka")), + opts: opts, + closeWg: sync.WaitGroup{}, + cancel: cancel, }, nil } -func (c *connector) New(ctx context.Context) pubsub_datasource.KafkaPubSub { - - 
ctx, cancel := context.WithCancel(ctx) - - ps := &kafkaPubSub{ - ctx: ctx, - logger: c.logger.With(zap.String("pubsub", "kafka")), - opts: c.opts, - writeClient: c.writeClient, - closeWg: sync.WaitGroup{}, - cancel: cancel, - } - - return ps +// AdapterInterface defines the interface for Kafka adapter operations +type AdapterInterface interface { + Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error + Publish(ctx context.Context, event PublishEventConfiguration) error + Startup(ctx context.Context) error + Shutdown(ctx context.Context) error } -// kafkaPubSub is a Kafka pubsub implementation. +// Adapter is a Kafka pubsub implementation. // It uses the franz-go Kafka client to consume and produce messages. // The pubsub is stateless and does not store any messages. // It uses a single write client to produce messages and a client per topic to consume messages. // Each client polls the Kafka topic for new records and updates the subscriptions with the new data. -type kafkaPubSub struct { +type Adapter struct { ctx context.Context opts []kgo.Opt logger *zap.Logger @@ -77,8 +56,11 @@ type kafkaPubSub struct { cancel context.CancelFunc } +// Ensure Adapter implements AdapterInterface +var _ AdapterInterface = (*Adapter)(nil) + // topicPoller polls the Kafka topic for new records and calls the updateTriggers function. -func (p *kafkaPubSub) topicPoller(ctx context.Context, client *kgo.Client, updater resolve.SubscriptionUpdater) error { +func (p *Adapter) topicPoller(ctx context.Context, client *kgo.Client, updater resolve.SubscriptionUpdater) error { for { select { case <-p.ctx.Done(): // Close the poller if the application context was canceled @@ -132,7 +114,7 @@ func (p *kafkaPubSub) topicPoller(ctx context.Context, client *kgo.Client, updat // Subscribe subscribes to the given topics and updates the subscription updater. 
// The engine already deduplicates subscriptions with the same topics, stream configuration, extensions, headers, etc. -func (p *kafkaPubSub) Subscribe(ctx context.Context, event pubsub_datasource.KafkaSubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (p *Adapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { log := p.logger.With( zap.String("provider_id", event.ProviderID), @@ -140,8 +122,6 @@ func (p *kafkaPubSub) Subscribe(ctx context.Context, event pubsub_datasource.Kaf zap.Strings("topics", event.Topics), ) - log.Debug("subscribe") - // Create a new client for the topic client, err := kgo.NewClient(append(p.opts, kgo.ConsumeTopics(event.Topics...), @@ -151,6 +131,9 @@ func (p *kafkaPubSub) Subscribe(ctx context.Context, event pubsub_datasource.Kaf kgo.ConsumeResetOffset(kgo.NewOffset().AfterMilli(time.Now().UnixMilli())), // For observability, we set the client ID to "router" kgo.ClientID(fmt.Sprintf("cosmo.router.consumer.%s", strings.Join(event.Topics, "-"))), + // FIXME: the client id should have some unique identifier, like in nats + // What if we have multiple subscriptions for the same topics? + // What if we have more router instances? )...) if err != nil { log.Error("failed to create client", zap.Error(err)) @@ -181,13 +164,17 @@ func (p *kafkaPubSub) Subscribe(ctx context.Context, event pubsub_datasource.Kaf // Publish publishes the given event to the Kafka topic in a non-blocking way. // Publish errors are logged and returned as a pubsub error. // The event is written with a dedicated write client. 
-func (p *kafkaPubSub) Publish(ctx context.Context, event pubsub_datasource.KafkaPublishEventConfiguration) error { +func (p *Adapter) Publish(ctx context.Context, event PublishEventConfiguration) error { log := p.logger.With( zap.String("provider_id", event.ProviderID), zap.String("method", "publish"), zap.String("topic", event.Topic), ) + if p.writeClient == nil { + return datasource.NewError("kafka write client not initialized", nil) + } + log.Debug("publish", zap.ByteString("data", event.Data)) var wg sync.WaitGroup @@ -209,13 +196,29 @@ func (p *kafkaPubSub) Publish(ctx context.Context, event pubsub_datasource.Kafka if pErr != nil { log.Error("publish error", zap.Error(pErr)) - return pubsub.NewError(fmt.Sprintf("error publishing to Kafka topic %s", event.Topic), pErr) + return datasource.NewError(fmt.Sprintf("error publishing to Kafka topic %s", event.Topic), pErr) } return nil } -func (p *kafkaPubSub) Shutdown(ctx context.Context) error { +func (p *Adapter) Startup(ctx context.Context) (err error) { + p.writeClient, err = kgo.NewClient(append(p.opts, + // For observability, we set the client ID to "router" + kgo.ClientID("cosmo.router.producer"))..., + ) + if err != nil { + return err + } + + return +} + +func (p *Adapter) Shutdown(ctx context.Context) error { + + if p.writeClient == nil { + return nil + } err := p.writeClient.Flush(ctx) if err != nil { diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go new file mode 100644 index 0000000000..f31758ffa7 --- /dev/null +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -0,0 +1,70 @@ +package kafka + +import ( + "bytes" + "context" + "encoding/json" + "io" + + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type SubscriptionDataSource struct { + pubSub AdapterInterface +} + +func (s 
*SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "topics") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err +} + +func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { + var subscriptionConfiguration SubscriptionEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return err + } + + return s.pubSub.Subscribe(ctx.Context(), subscriptionConfiguration, updater) +} + +type PublishDataSource struct { + pubSub AdapterInterface +} + +func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { + var publishConfiguration PublishEventConfiguration + err := json.Unmarshal(input, &publishConfiguration) + if err != nil { + return err + } + + if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { + _, err = io.WriteString(out, `{"success": false}`) + return err + } + _, err = io.WriteString(out, `{"success": true}`) + return err +} + +func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { + panic("not implemented") +} diff --git a/router/pkg/pubsub/kafka/engine_datasource_test.go b/router/pkg/pubsub/kafka/engine_datasource_test.go new file mode 100644 index 0000000000..429e2b4643 --- /dev/null +++ b/router/pkg/pubsub/kafka/engine_datasource_test.go @@ -0,0 +1,270 @@ +package kafka + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "testing" + + "github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + 
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +// MockSubscriptionUpdater implements resolve.SubscriptionUpdater +type MockSubscriptionUpdater struct { + mock.Mock +} + +func (m *MockSubscriptionUpdater) Update(data []byte) { + m.Called(data) +} + +func (m *MockSubscriptionUpdater) Close() { + m.Called() +} + +func (m *MockSubscriptionUpdater) Done() { + m.Called() +} + +func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { + tests := []struct { + name string + config PublishEventConfiguration + wantPattern string + }{ + { + name: "simple configuration", + config: PublishEventConfiguration{ + ProviderID: "test-provider", + Topic: "test-topic", + Data: json.RawMessage(`{"message":"hello"}`), + }, + wantPattern: `{"topic":"test-topic", "data": {"message":"hello"}, "providerId":"test-provider"}`, + }, + { + name: "with special characters", + config: PublishEventConfiguration{ + ProviderID: "test-provider-id", + Topic: "topic-with-hyphens", + Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + }, + wantPattern: `{"topic":"topic-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.config.MarshalJSONTemplate() + assert.Equal(t, tt.wantPattern, result) + }) + } +} + +func TestSubscriptionSource_UniqueRequestID(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + expectedError error + }{ + { + name: "valid input", + input: `{"topics":["topic1", "topic2"], "providerId":"test-provider"}`, + expectError: false, + }, + { + name: "missing topics", + input: `{"providerId":"test-provider"}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + { + name: "missing providerId", + input: `{"topics":["topic1", "topic2"]}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + } + + for _, tt := range tests { + 
t.Run(tt.name, func(t *testing.T) { + source := &SubscriptionDataSource{ + pubSub: &mockAdapter{}, + } + ctx := &resolve.Context{} + input := []byte(tt.input) + xxh := xxhash.New() + + err := source.UniqueRequestID(ctx, input, xxh) + + if tt.expectError { + require.Error(t, err) + if tt.expectedError != nil { + // For jsonparser errors, just check if the error message contains the expected text + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } + } else { + require.NoError(t, err) + // Check that the hash has been updated + assert.NotEqual(t, 0, xxh.Sum64()) + } + }) + } +} + +func TestSubscriptionSource_Start(t *testing.T) { + tests := []struct { + name string + input string + mockSetup func(*mockAdapter) + expectError bool + }{ + { + name: "successful subscription", + input: `{"topics":["topic1", "topic2"], "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ + ProviderID: "test-provider", + Topics: []string{"topic1", "topic2"}, + }, mock.Anything).Return(nil) + }, + expectError: false, + }, + { + name: "adapter returns error", + input: `{"topics":["topic1"], "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ + ProviderID: "test-provider", + Topics: []string{"topic1"}, + }, mock.Anything).Return(errors.New("subscription error")) + }, + expectError: true, + }, + { + name: "invalid input json", + input: `{"invalid json":`, + mockSetup: func(m *mockAdapter) {}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockAdapter := new(mockAdapter) + tt.mockSetup(mockAdapter) + + source := &SubscriptionDataSource{ + pubSub: mockAdapter, + } + + // Set up go context + goCtx := context.Background() + + // Create a resolve.Context with the standard context + resolveCtx := &resolve.Context{} + resolveCtx = resolveCtx.WithContext(goCtx) + + // Create a 
proper mock updater + updater := new(MockSubscriptionUpdater) + updater.On("Done").Return() + + input := []byte(tt.input) + err := source.Start(resolveCtx, input, updater) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + mockAdapter.AssertExpectations(t) + }) + } +} + +func TestKafkaPublishDataSource_Load(t *testing.T) { + tests := []struct { + name string + input string + mockSetup func(*mockAdapter) + expectError bool + expectedOutput string + expectPublished bool + }{ + { + name: "successful publish", + input: `{"topic":"test-topic", "data":{"message":"hello"}, "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + return event.ProviderID == "test-provider" && + event.Topic == "test-topic" && + string(event.Data) == `{"message":"hello"}` + })).Return(nil) + }, + expectError: false, + expectedOutput: `{"success": true}`, + expectPublished: true, + }, + { + name: "publish error", + input: `{"topic":"test-topic", "data":{"message":"hello"}, "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) + }, + expectError: false, // The Load method doesn't return the publish error directly + expectedOutput: `{"success": false}`, + expectPublished: true, + }, + { + name: "invalid input json", + input: `{"invalid json":`, + mockSetup: func(m *mockAdapter) {}, + expectError: true, + expectPublished: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockAdapter := new(mockAdapter) + tt.mockSetup(mockAdapter) + + dataSource := &PublishDataSource{ + pubSub: mockAdapter, + } + ctx := context.Background() + input := []byte(tt.input) + out := &bytes.Buffer{} + + err := dataSource.Load(ctx, input, out) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + assert.Equal(t, 
tt.expectedOutput, out.String()) + } + + if tt.expectPublished { + mockAdapter.AssertExpectations(t) + } + }) + } +} + +func TestKafkaPublishDataSource_LoadWithFiles(t *testing.T) { + t.Run("panic on not implemented", func(t *testing.T) { + dataSource := &PublishDataSource{ + pubSub: &mockAdapter{}, + } + + assert.Panics(t, func() { + dataSource.LoadWithFiles(context.Background(), nil, nil, &bytes.Buffer{}) + }) + }) +} diff --git a/router/pkg/pubsub/kafka/provider.go b/router/pkg/pubsub/kafka/provider.go new file mode 100644 index 0000000000..3d444f7923 --- /dev/null +++ b/router/pkg/pubsub/kafka/provider.go @@ -0,0 +1,121 @@ +package kafka + +import ( + "context" + "crypto/tls" + "fmt" + "time" + + "github.com/twmb/franz-go/pkg/kgo" + "github.com/twmb/franz-go/pkg/sasl/plain" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "go.uber.org/zap" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" +) + +// buildKafkaOptions creates a list of kgo.Opt options for the given Kafka event source configuration. +// Only general options like TLS, SASL, etc. are configured here. Specific options like topics, etc. are +// configured in the KafkaPubSub implementation. +func buildKafkaOptions(eventSource config.KafkaEventSource) ([]kgo.Opt, error) { + opts := []kgo.Opt{ + kgo.SeedBrokers(eventSource.Brokers...), + // Ensure proper timeouts are set + kgo.ProduceRequestTimeout(10 * time.Second), + kgo.ConnIdleTimeout(60 * time.Second), + } + + if eventSource.TLS != nil && eventSource.TLS.Enabled { + opts = append(opts, + // Configure TLS. Uses SystemCertPool for RootCAs by default. 
+ kgo.DialTLSConfig(new(tls.Config)), + ) + } + + if eventSource.Authentication != nil && eventSource.Authentication.SASLPlain.Username != nil && eventSource.Authentication.SASLPlain.Password != nil { + opts = append(opts, kgo.SASL(plain.Auth{ + User: *eventSource.Authentication.SASLPlain.Username, + Pass: *eventSource.Authentication.SASLPlain.Password, + }.AsMechanism())) + } + + return opts, nil +} + +func GetProvider(ctx context.Context, in *nodev1.DataSourceConfiguration, dsMeta *plan.DataSourceMetadata, config config.EventsConfiguration, logger *zap.Logger, hostName string, routerListenAddr string) (datasource.PubSubProvider, error) { + providers := make(map[string]AdapterInterface) + definedProviders := make(map[string]bool) + for _, provider := range config.Providers.Kafka { + definedProviders[provider.ID] = true + } + usedProviders := make(map[string]bool) + if kafkaData := in.GetCustomEvents().GetKafka(); kafkaData != nil { + for _, event := range kafkaData { + if !definedProviders[event.EngineEventConfiguration.ProviderId] { + return nil, fmt.Errorf("failed to find Kafka provider with ID %s", event.EngineEventConfiguration.ProviderId) + } + usedProviders[event.EngineEventConfiguration.ProviderId] = true + } + + for _, provider := range config.Providers.Kafka { + if !usedProviders[provider.ID] { + continue + } + options, err := buildKafkaOptions(provider) + if err != nil { + return nil, fmt.Errorf("failed to build options for Kafka provider with ID \"%s\": %w", provider.ID, err) + } + adapter, err := NewAdapter(ctx, logger, options) + if err != nil { + return nil, fmt.Errorf("failed to create adapter for Kafka provider with ID \"%s\": %w", provider.ID, err) + } + providers[provider.ID] = adapter + } + + return &PubSubProvider{ + EventConfiguration: kafkaData, + Logger: logger, + Providers: providers, + }, nil + } + + return nil, nil +} + +type PubSubProvider struct { + EventConfiguration []*nodev1.KafkaEventConfiguration + Logger *zap.Logger + Providers 
map[string]AdapterInterface +} + +func (c *PubSubProvider) FindPubSubDataSource(typeName string, fieldName string, extractFn datasource.ArgumentTemplateCallback) (datasource.PubSubDataSource, error) { + for _, cfg := range c.EventConfiguration { + if cfg.GetEngineEventConfiguration().GetTypeName() == typeName && cfg.GetEngineEventConfiguration().GetFieldName() == fieldName { + return &PubSubDataSource{ + KafkaAdapter: c.Providers[cfg.GetEngineEventConfiguration().GetProviderId()], + EventConfiguration: cfg, + }, nil + } + } + return nil, nil +} + +func (c *PubSubProvider) Startup(ctx context.Context) error { + for _, provider := range c.Providers { + if err := provider.Startup(ctx); err != nil { + return err + } + } + return nil +} + +func (c *PubSubProvider) Shutdown(ctx context.Context) error { + for _, provider := range c.Providers { + if err := provider.Shutdown(ctx); err != nil { + return err + } + } + return nil +} diff --git a/router/pkg/pubsub/kafka/provider_test.go b/router/pkg/pubsub/kafka/provider_test.go new file mode 100644 index 0000000000..52621affe6 --- /dev/null +++ b/router/pkg/pubsub/kafka/provider_test.go @@ -0,0 +1,216 @@ +package kafka + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" +) + +// mockAdapter is a mock of AdapterInterface +type mockAdapter struct { + mock.Mock +} + +func (m *mockAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { + args := m.Called(ctx, event, updater) + return args.Error(0) +} + +func (m *mockAdapter) Publish(ctx context.Context, event 
PublishEventConfiguration) error { + args := m.Called(ctx, event) + return args.Error(0) +} + +func (m *mockAdapter) Startup(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *mockAdapter) Shutdown(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func TestBuildKafkaOptions(t *testing.T) { + t.Run("basic configuration", func(t *testing.T) { + cfg := config.KafkaEventSource{ + Brokers: []string{"localhost:9092"}, + } + + opts, err := buildKafkaOptions(cfg) + require.NoError(t, err) + require.NotEmpty(t, opts) + }) + + t.Run("with TLS", func(t *testing.T) { + enabled := true + cfg := config.KafkaEventSource{ + Brokers: []string{"localhost:9092"}, + TLS: &config.KafkaTLSConfiguration{ + Enabled: enabled, + }, + } + + opts, err := buildKafkaOptions(cfg) + require.NoError(t, err) + require.NotEmpty(t, opts) + // Can't directly check for TLS options, but we can verify more options are present + require.Equal(t, len(opts), 4) + }) + + t.Run("with auth", func(t *testing.T) { + username := "user" + password := "pass" + cfg := config.KafkaEventSource{ + Brokers: []string{"localhost:9092"}, + Authentication: &config.KafkaAuthentication{ + SASLPlain: config.KafkaSASLPlainAuthentication{ + Username: &username, + Password: &password, + }, + }, + } + + opts, err := buildKafkaOptions(cfg) + require.NoError(t, err) + require.NotEmpty(t, opts) + // Can't directly check for SASL options, but we can verify more options are present + require.Greater(t, len(opts), 1) + }) +} + +func TestGetProvider(t *testing.T) { + t.Run("returns nil if no Kafka configuration", func(t *testing.T) { + ctx := context.Background() + in := &nodev1.DataSourceConfiguration{ + CustomEvents: &nodev1.DataSourceCustomEvents{}, + } + + dsMeta := &plan.DataSourceMetadata{} + cfg := config.EventsConfiguration{} + logger := zaptest.NewLogger(t) + + provider, err := GetProvider(ctx, in, dsMeta, cfg, logger, "host", "addr") + require.NoError(t, 
err) + require.Nil(t, provider) + }) + + t.Run("errors if provider not found", func(t *testing.T) { + ctx := context.Background() + in := &nodev1.DataSourceConfiguration{ + CustomEvents: &nodev1.DataSourceCustomEvents{ + Kafka: []*nodev1.KafkaEventConfiguration{ + { + EngineEventConfiguration: &nodev1.EngineEventConfiguration{ + ProviderId: "unknown", + }, + }, + }, + }, + } + + dsMeta := &plan.DataSourceMetadata{} + cfg := config.EventsConfiguration{ + Providers: config.EventProviders{ + Kafka: []config.KafkaEventSource{ + {ID: "provider1", Brokers: []string{"localhost:9092"}}, + }, + }, + } + logger := zaptest.NewLogger(t) + + provider, err := GetProvider(ctx, in, dsMeta, cfg, logger, "host", "addr") + require.Error(t, err) + require.Nil(t, provider) + assert.Contains(t, err.Error(), "failed to find Kafka provider with ID") + }) + + t.Run("creates provider with configured adapters", func(t *testing.T) { + providerId := "test-provider" + + in := &nodev1.DataSourceConfiguration{ + CustomEvents: &nodev1.DataSourceCustomEvents{ + Kafka: []*nodev1.KafkaEventConfiguration{ + { + EngineEventConfiguration: &nodev1.EngineEventConfiguration{ + ProviderId: providerId, + }, + }, + }, + }, + } + + cfg := config.EventsConfiguration{ + Providers: config.EventProviders{ + Kafka: []config.KafkaEventSource{ + {ID: providerId, Brokers: []string{"localhost:9092"}}, + }, + }, + } + + logger := zaptest.NewLogger(t) + + // Create mock adapter for testing + provider, err := GetProvider(context.Background(), in, &plan.DataSourceMetadata{}, cfg, logger, "host", "addr") + require.NoError(t, err) + require.NotNil(t, provider) + + // Check the returned provider + kafkaProvider, ok := provider.(*PubSubProvider) + require.True(t, ok) + assert.NotNil(t, kafkaProvider.Logger) + assert.NotNil(t, kafkaProvider.Providers) + assert.Contains(t, kafkaProvider.Providers, providerId) + }) +} + +func TestPubSubProvider_FindPubSubDataSource(t *testing.T) { + mock := &mockAdapter{} + providerId := 
"test-provider" + typeName := "TestType" + fieldName := "testField" + + provider := &PubSubProvider{ + EventConfiguration: []*nodev1.KafkaEventConfiguration{ + { + EngineEventConfiguration: &nodev1.EngineEventConfiguration{ + TypeName: typeName, + FieldName: fieldName, + ProviderId: providerId, + }, + }, + }, + Logger: zap.NewNop(), + Providers: map[string]AdapterInterface{ + providerId: mock, + }, + } + + t.Run("find matching datasource", func(t *testing.T) { + ds, err := provider.FindPubSubDataSource(typeName, fieldName, nil) + require.NoError(t, err) + require.NotNil(t, ds) + + // Check the returned datasource + kafkaDs, ok := ds.(*PubSubDataSource) + require.True(t, ok) + assert.Equal(t, mock, kafkaDs.KafkaAdapter) + assert.Equal(t, provider.EventConfiguration[0], kafkaDs.EventConfiguration) + }) + + t.Run("return nil if no match", func(t *testing.T) { + ds, err := provider.FindPubSubDataSource("OtherType", fieldName, nil) + require.NoError(t, err) + require.Nil(t, ds) + }) +} diff --git a/router/pkg/pubsub/kafka/pubsub_datasource.go b/router/pkg/pubsub/kafka/pubsub_datasource.go new file mode 100644 index 0000000000..1f749ab474 --- /dev/null +++ b/router/pkg/pubsub/kafka/pubsub_datasource.go @@ -0,0 +1,92 @@ +package kafka + +import ( + "encoding/json" + "fmt" + + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type PubSubDataSource struct { + EventConfiguration *nodev1.KafkaEventConfiguration + KafkaAdapter AdapterInterface +} + +func (c *PubSubDataSource) EngineEventConfiguration() *nodev1.EngineEventConfiguration { + return c.EventConfiguration.GetEngineEventConfiguration() +} + +func (c *PubSubDataSource) ResolveDataSource() (resolve.DataSource, error) { + var dataSource resolve.DataSource + + typeName := c.EventConfiguration.GetEngineEventConfiguration().GetType() + switch typeName { + case nodev1.EventType_PUBLISH: + dataSource = &PublishDataSource{ + pubSub: 
c.KafkaAdapter, + } + default: + return nil, fmt.Errorf("failed to configure fetch: invalid event type \"%s\" for Kafka", typeName.String()) + } + + return dataSource, nil +} + +func (c *PubSubDataSource) ResolveDataSourceInput(event []byte) (string, error) { + topics := c.EventConfiguration.GetTopics() + + if len(topics) != 1 { + return "", fmt.Errorf("publish and request events should define one topic but received %d", len(topics)) + } + + topic := topics[0] + + providerId := c.GetProviderId() + + evtCfg := PublishEventConfiguration{ + ProviderID: providerId, + Topic: topic, + Data: event, + } + + return evtCfg.MarshalJSONTemplate(), nil +} + +func (c *PubSubDataSource) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { + return &SubscriptionDataSource{ + pubSub: c.KafkaAdapter, + }, nil +} + +func (c *PubSubDataSource) ResolveDataSourceSubscriptionInput() (string, error) { + providerId := c.GetProviderId() + evtCfg := SubscriptionEventConfiguration{ + ProviderID: providerId, + Topics: c.EventConfiguration.GetTopics(), + } + object, err := json.Marshal(evtCfg) + if err != nil { + return "", fmt.Errorf("failed to marshal event subscription streamConfiguration") + } + return string(object), nil +} + +func (c *PubSubDataSource) GetProviderId() string { + return c.EventConfiguration.GetEngineEventConfiguration().GetProviderId() +} + +type SubscriptionEventConfiguration struct { + ProviderID string `json:"providerId"` + Topics []string `json:"topics"` +} + +type PublishEventConfiguration struct { + ProviderID string `json:"providerId"` + Topic string `json:"topic"` + Data json.RawMessage `json:"data"` +} + +func (s *PublishEventConfiguration) MarshalJSONTemplate() string { + return fmt.Sprintf(`{"topic":"%s", "data": %s, "providerId":"%s"}`, s.Topic, s.Data, s.ProviderID) +} diff --git a/router/pkg/pubsub/kafka/pubsub_datasource_test.go b/router/pkg/pubsub/kafka/pubsub_datasource_test.go new file mode 100644 index 0000000000..cd10a1eadb --- 
/dev/null +++ b/router/pkg/pubsub/kafka/pubsub_datasource_test.go @@ -0,0 +1,203 @@ +package kafka + +import ( + "bytes" + "context" + "encoding/json" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" +) + +func TestKafkaPubSubDataSource(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, + TypeName: "TestType", + FieldName: "testField", + } + + kafkaCfg := &nodev1.KafkaEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Topics: []string{"test-topic"}, + } + + // Create the data source to test with a real adapter + adapter := &Adapter{} + pubsub := &PubSubDataSource{ + EventConfiguration: kafkaCfg, + KafkaAdapter: adapter, + } + + // Run the standard test suite + pubsubtest.VerifyPubSubDataSourceImplementation(t, pubsub) +} + +// TestPubSubDataSourceWithMockAdapter tests the PubSubDataSource with a mocked adapter +func TestPubSubDataSourceWithMockAdapter(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, + TypeName: "TestType", + FieldName: "testField", + } + + kafkaCfg := &nodev1.KafkaEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Topics: []string{"test-topic"}, + } + + // Create mock adapter + mockAdapter := new(mockAdapter) + + // Configure mock expectations for Publish + mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + return event.ProviderID == "test-provider" && event.Topic == "test-topic" + })).Return(nil) + + // Create the data source with mock adapter + pubsub := &PubSubDataSource{ + EventConfiguration: kafkaCfg, + 
KafkaAdapter: mockAdapter, + } + + // Get the data source + ds, err := pubsub.ResolveDataSource() + require.NoError(t, err) + + // Get the input + input, err := pubsub.ResolveDataSourceInput([]byte(`{"test":"data"}`)) + require.NoError(t, err) + + // Call Load on the data source + out := &bytes.Buffer{} + err = ds.Load(context.Background(), []byte(input), out) + require.NoError(t, err) + require.Equal(t, `{"success": true}`, out.String()) + + // Verify mock expectations + mockAdapter.AssertExpectations(t) +} + +// TestPubSubDataSource_GetResolveDataSource_WrongType tests the PubSubDataSource with a mocked adapter +func TestPubSubDataSource_GetResolveDataSource_WrongType(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_SUBSCRIBE, + TypeName: "TestType", + FieldName: "testField", + } + + kafkaCfg := &nodev1.KafkaEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Topics: []string{"test-topic"}, + } + + // Create mock adapter + mockAdapter := new(mockAdapter) + + // Create the data source with mock adapter + pubsub := &PubSubDataSource{ + EventConfiguration: kafkaCfg, + KafkaAdapter: mockAdapter, + } + + // Get the data source + ds, err := pubsub.ResolveDataSource() + require.Error(t, err) + require.Nil(t, ds) +} + +// TestPubSubDataSource_GetResolveDataSourceInput_MultipleTopics tests the PubSubDataSource with a mocked adapter +func TestPubSubDataSource_GetResolveDataSourceInput_MultipleTopics(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, + TypeName: "TestType", + FieldName: "testField", + } + + kafkaCfg := &nodev1.KafkaEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Topics: []string{"test-topic-1", "test-topic-2"}, + } + + // Create the data source with 
mock adapter + pubsub := &PubSubDataSource{ + EventConfiguration: kafkaCfg, + } + + // Get the input + input, err := pubsub.ResolveDataSourceInput([]byte(`{"test":"data"}`)) + require.Error(t, err) + require.Empty(t, input) +} + +// TestPubSubDataSource_GetResolveDataSourceInput_NoTopics tests the PubSubDataSource with a mocked adapter +func TestPubSubDataSource_GetResolveDataSourceInput_NoTopics(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, + TypeName: "TestType", + FieldName: "testField", + } + + kafkaCfg := &nodev1.KafkaEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Topics: []string{}, + } + + // Create the data source with mock adapter + pubsub := &PubSubDataSource{ + EventConfiguration: kafkaCfg, + } + + // Get the input + input, err := pubsub.ResolveDataSourceInput([]byte(`{"test":"data"}`)) + require.Error(t, err) + require.Empty(t, input) +} + +// TestKafkaPubSubDataSourceMultiTopicSubscription tests only the subscription functionality +// for multiple topics. The publish and resolve datasource tests are skipped since they +// do not support multiple topics. 
+func TestKafkaPubSubDataSourceMultiTopicSubscription(t *testing.T) { + // Create event configuration with multiple topics + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, // Must be PUBLISH as it's the only supported type + TypeName: "TestType", + FieldName: "testField", + } + + kafkaCfg := &nodev1.KafkaEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Topics: []string{"test-topic-1", "test-topic-2"}, + } + + // Create the data source to test with mock adapter + pubsub := &PubSubDataSource{ + EventConfiguration: kafkaCfg, + } + + // Test GetResolveDataSourceSubscriptionInput + subscriptionInput, err := pubsub.ResolveDataSourceSubscriptionInput() + require.NoError(t, err, "Expected no error from GetResolveDataSourceSubscriptionInput") + require.NotEmpty(t, subscriptionInput, "Expected non-empty subscription input") + + // Verify the subscription input contains both topics + var subscriptionConfig SubscriptionEventConfiguration + err = json.Unmarshal([]byte(subscriptionInput), &subscriptionConfig) + require.NoError(t, err, "Expected valid JSON from GetResolveDataSourceSubscriptionInput") + require.Equal(t, 2, len(subscriptionConfig.Topics), "Expected 2 topics in subscription configuration") + require.Equal(t, "test-topic-1", subscriptionConfig.Topics[0], "Expected first topic to be 'test-topic-1'") + require.Equal(t, "test-topic-2", subscriptionConfig.Topics[1], "Expected second topic to be 'test-topic-2'") +} diff --git a/router/pkg/pubsub/nats/nats.go b/router/pkg/pubsub/nats/adapter.go similarity index 60% rename from router/pkg/pubsub/nats/nats.go rename to router/pkg/pubsub/nats/adapter.go index 2e01432d12..a355cb676a 100644 --- a/router/pkg/pubsub/nats/nats.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -4,69 +4,51 @@ import ( "context" "errors" "fmt" + "io" + "sync" + "time" + "github.com/cespare/xxhash/v2" "github.com/nats-io/nats.go" 
"github.com/nats-io/nats.go/jetstream" - "github.com/wundergraph/cosmo/router/pkg/pubsub" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/pubsub_datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" - "io" - "sync" - "time" ) -var ( - _ pubsub_datasource.NatsConnector = (*connector)(nil) - _ pubsub_datasource.NatsPubSub = (*natsPubSub)(nil) - _ pubsub.Lifecycle = (*natsPubSub)(nil) -) - -type connector struct { - conn *nats.Conn - logger *zap.Logger - js jetstream.JetStream - hostName string - routerListenAddr string -} - -func NewConnector(logger *zap.Logger, conn *nats.Conn, js jetstream.JetStream, hostName string, routerListenAddr string) pubsub_datasource.NatsConnector { - return &connector{ - conn: conn, - logger: logger, - js: js, - hostName: hostName, - routerListenAddr: routerListenAddr, - } -} - -func (c *connector) New(ctx context.Context) pubsub_datasource.NatsPubSub { - return &natsPubSub{ - ctx: ctx, - conn: c.conn, - js: c.js, - logger: c.logger.With(zap.String("pubsub", "nats")), - closeWg: sync.WaitGroup{}, - hostName: c.hostName, - routerListenAddr: c.routerListenAddr, - } +// AdapterInterface defines the methods that a NATS adapter should implement +type AdapterInterface interface { + // Subscribe subscribes to the given events and sends updates to the updater + Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error + // Publish publishes the given event to the specified subject + Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error + // Request sends a request to the specified subject and writes the response to the given writer + Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error + // Startup initializes the adapter + Startup(ctx context.Context) error + // Shutdown gracefully shuts down the adapter + 
Shutdown(ctx context.Context) error } -type natsPubSub struct { +// Adapter implements the AdapterInterface for NATS pub/sub +type Adapter struct { ctx context.Context - conn *nats.Conn - logger *zap.Logger + client *nats.Conn js jetstream.JetStream + logger *zap.Logger closeWg sync.WaitGroup hostName string routerListenAddr string + url string + opts []nats.Option + flushTimeout time.Duration } // getInstanceIdentifier returns an identifier for the current instance. // We use the hostname and the address the router is listening on, which should provide a good representation // of what a unique instance is from the perspective of the client that has started a subscription to this instance // and want to restart the subscription after a failure on the client or router side. -func (p *natsPubSub) getInstanceIdentifier() string { +func (p *Adapter) getInstanceIdentifier() string { return fmt.Sprintf("%s-%s", p.hostName, p.routerListenAddr) } @@ -74,7 +56,7 @@ func (p *natsPubSub) getInstanceIdentifier() string { // we need to make sure that the durable consumer name is unique for each instance and subjects to prevent // multiple routers from changing the same consumer, which would lead to message loss and wrong messages delivered // to the subscribers -func (p *natsPubSub) getDurableConsumerName(durableName string, subjects []string) (string, error) { +func (p *Adapter) getDurableConsumerName(durableName string, subjects []string) (string, error) { subjHash := xxhash.New() _, err := subjHash.WriteString(p.getInstanceIdentifier()) if err != nil { @@ -90,13 +72,21 @@ func (p *natsPubSub) getDurableConsumerName(durableName string, subjects []strin return fmt.Sprintf("%s-%x", durableName, subjHash.Sum64()), nil } -func (p *natsPubSub) Subscribe(ctx context.Context, event pubsub_datasource.NatsSubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (p *Adapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater 
resolve.SubscriptionUpdater) error { log := p.logger.With( zap.String("provider_id", event.ProviderID), zap.String("method", "subscribe"), zap.Strings("subjects", event.Subjects), ) + if p.client == nil { + return datasource.NewError("nats client not initialized", nil) + } + + if p.js == nil { + return datasource.NewError("nats jetstream not initialized", nil) + } + if event.StreamConfiguration != nil { durableConsumerName, err := p.getDurableConsumerName(event.StreamConfiguration.Consumer, event.Subjects) if err != nil { @@ -110,10 +100,11 @@ func (p *natsPubSub) Subscribe(ctx context.Context, event pubsub_datasource.Nats if event.StreamConfiguration.ConsumerInactiveThreshold > 0 { consumerConfig.InactiveThreshold = time.Duration(event.StreamConfiguration.ConsumerInactiveThreshold) * time.Second } + consumer, err := p.js.CreateOrUpdateConsumer(ctx, event.StreamConfiguration.StreamName, consumerConfig) if err != nil { log.Error("creating or updating consumer", zap.Error(err)) - return pubsub.NewError(fmt.Sprintf(`failed to create or update consumer for stream "%s"`, event.StreamConfiguration.StreamName), err) + return datasource.NewError(fmt.Sprintf(`failed to create or update consumer for stream "%s"`, event.StreamConfiguration.StreamName), err) } p.closeWg.Add(1) @@ -161,10 +152,10 @@ func (p *natsPubSub) Subscribe(ctx context.Context, event pubsub_datasource.Nats msgChan := make(chan *nats.Msg) subscriptions := make([]*nats.Subscription, len(event.Subjects)) for i, subject := range event.Subjects { - subscription, err := p.conn.ChanSubscribe(subject, msgChan) + subscription, err := p.client.ChanSubscribe(subject, msgChan) if err != nil { log.Error("subscribing to NATS subject", zap.Error(err), zap.String("subscription_subject", subject)) - return pubsub.NewError(fmt.Sprintf(`failed to subscribe to NATS subject "%s"`, subject), err) + return datasource.NewError(fmt.Sprintf(`failed to subscribe to NATS subject "%s"`, subject), err) } subscriptions[i] = 
subscription } @@ -206,37 +197,45 @@ func (p *natsPubSub) Subscribe(ctx context.Context, event pubsub_datasource.Nats return nil } -func (p *natsPubSub) Publish(_ context.Context, event pubsub_datasource.NatsPublishAndRequestEventConfiguration) error { +func (p *Adapter) Publish(_ context.Context, event PublishAndRequestEventConfiguration) error { log := p.logger.With( zap.String("provider_id", event.ProviderID), zap.String("method", "publish"), zap.String("subject", event.Subject), ) + if p.client == nil { + return datasource.NewError("nats client not initialized", nil) + } + log.Debug("publish", zap.ByteString("data", event.Data)) - err := p.conn.Publish(event.Subject, event.Data) + err := p.client.Publish(event.Subject, event.Data) if err != nil { log.Error("publish error", zap.Error(err)) - return pubsub.NewError(fmt.Sprintf("error publishing to NATS subject %s", event.Subject), err) + return datasource.NewError(fmt.Sprintf("error publishing to NATS subject %s", event.Subject), err) } return nil } -func (p *natsPubSub) Request(ctx context.Context, event pubsub_datasource.NatsPublishAndRequestEventConfiguration, w io.Writer) error { +func (p *Adapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error { log := p.logger.With( zap.String("provider_id", event.ProviderID), zap.String("method", "request"), zap.String("subject", event.Subject), ) + if p.client == nil { + return datasource.NewError("nats client not initialized", nil) + } + log.Debug("request", zap.ByteString("data", event.Data)) - msg, err := p.conn.RequestWithContext(ctx, event.Subject, event.Data) + msg, err := p.client.RequestWithContext(ctx, event.Subject, event.Data) if err != nil { log.Error("request error", zap.Error(err)) - return pubsub.NewError(fmt.Sprintf("error requesting from NATS subject %s", event.Subject), err) + return datasource.NewError(fmt.Sprintf("error requesting from NATS subject %s", event.Subject), err) } _, err = w.Write(msg.Data) @@ 
-248,34 +247,75 @@ func (p *natsPubSub) Request(ctx context.Context, event pubsub_datasource.NatsPu return err } -func (p *natsPubSub) flush(ctx context.Context) error { - return p.conn.FlushWithContext(ctx) +func (p *Adapter) flush(ctx context.Context) error { + if p.client == nil { + return nil + } + _, ok := ctx.Deadline() + if !ok { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(ctx, p.flushTimeout) + defer cancel() + } + return p.client.FlushWithContext(ctx) } -func (p *natsPubSub) Shutdown(ctx context.Context) error { +func (p *Adapter) Startup(ctx context.Context) (err error) { + p.client, err = nats.Connect(p.url, p.opts...) + if err != nil { + return err + } + p.js, err = jetstream.New(p.client) + if err != nil { + return err + } + return nil +} - if p.conn.IsClosed() { +func (p *Adapter) Shutdown(ctx context.Context) error { + if p.client == nil { return nil } - var err error + if p.client.IsClosed() { + return nil // Already disconnected or failed to connect + } + + var shutdownErr error fErr := p.flush(ctx) if fErr != nil { - err = errors.Join(err, fErr) + shutdownErr = errors.Join(shutdownErr, fErr) } - drainErr := p.conn.Drain() + drainErr := p.client.Drain() if drainErr != nil { - err = errors.Join(err, drainErr) + shutdownErr = errors.Join(shutdownErr, drainErr) } // Wait for all subscriptions to be closed p.closeWg.Wait() - if err != nil { - return fmt.Errorf("nats pubsub shutdown: %w", err) + if shutdownErr != nil { + return fmt.Errorf("nats pubsub shutdown: %w", shutdownErr) } return nil } + +func NewAdapter(ctx context.Context, logger *zap.Logger, url string, opts []nats.Option, hostName string, routerListenAddr string) (AdapterInterface, error) { + if logger == nil { + logger = zap.NewNop() + } + + return &Adapter{ + ctx: ctx, + logger: logger.With(zap.String("pubsub", "nats")), + closeWg: sync.WaitGroup{}, + hostName: hostName, + routerListenAddr: routerListenAddr, + url: url, + opts: opts, + flushTimeout: 10 * 
time.Second, + }, nil +} diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go new file mode 100644 index 0000000000..c63da7ad0f --- /dev/null +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -0,0 +1,89 @@ +package nats + +import ( + "bytes" + "context" + "encoding/json" + "io" + + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type SubscriptionSource struct { + pubSub AdapterInterface +} + +func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + + val, _, _, err := jsonparser.Get(input, "subjects") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err +} + +func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { + var subscriptionConfiguration SubscriptionEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return err + } + + return s.pubSub.Subscribe(ctx.Context(), subscriptionConfiguration, updater) +} + +type NatsPublishDataSource struct { + pubSub AdapterInterface +} + +func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { + var publishConfiguration PublishAndRequestEventConfiguration + err := json.Unmarshal(input, &publishConfiguration) + if err != nil { + return err + } + + if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { + _, err = io.WriteString(out, `{"success": false}`) + return err + } + _, err = io.WriteString(out, `{"success": true}`) + return err +} + +func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files 
[]*httpclient.FileUpload, out *bytes.Buffer) (err error) { + panic("not implemented") +} + +type NatsRequestDataSource struct { + pubSub AdapterInterface +} + +func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { + var subscriptionConfiguration PublishAndRequestEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return err + } + + return s.pubSub.Request(ctx, subscriptionConfiguration, out) +} + +func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { + panic("not implemented") +} diff --git a/router/pkg/pubsub/nats/engine_datasource_test.go b/router/pkg/pubsub/nats/engine_datasource_test.go new file mode 100644 index 0000000000..89c289926d --- /dev/null +++ b/router/pkg/pubsub/nats/engine_datasource_test.go @@ -0,0 +1,380 @@ +package nats + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "io" + "testing" + + "github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +// MockSubscriptionUpdater implements resolve.SubscriptionUpdater +type MockSubscriptionUpdater struct { + mock.Mock +} + +func (m *MockSubscriptionUpdater) Update(data []byte) { + m.Called(data) +} + +func (m *MockSubscriptionUpdater) Close() { + m.Called() +} + +func (m *MockSubscriptionUpdater) Done() { + m.Called() +} + +func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { + tests := []struct { + name string + config PublishEventConfiguration + wantPattern string + }{ + { + name: "simple configuration", + config: PublishEventConfiguration{ + ProviderID: "test-provider", + Subject: "test-subject", + Data: json.RawMessage(`{"message":"hello"}`), + }, + wantPattern: `{"subject":"test-subject", "data": {"message":"hello"}, 
"providerId":"test-provider"}`, + }, + { + name: "with special characters", + config: PublishEventConfiguration{ + ProviderID: "test-provider-id", + Subject: "subject-with-hyphens", + Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + }, + wantPattern: `{"subject":"subject-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.config.MarshalJSONTemplate() + assert.Equal(t, tt.wantPattern, result) + }) + } +} + +func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { + tests := []struct { + name string + config PublishAndRequestEventConfiguration + wantPattern string + }{ + { + name: "simple configuration", + config: PublishAndRequestEventConfiguration{ + ProviderID: "test-provider", + Subject: "test-subject", + Data: json.RawMessage(`{"message":"hello"}`), + }, + wantPattern: `{"subject":"test-subject", "data": {"message":"hello"}, "providerId":"test-provider"}`, + }, + { + name: "with special characters", + config: PublishAndRequestEventConfiguration{ + ProviderID: "test-provider-id", + Subject: "subject-with-hyphens", + Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + }, + wantPattern: `{"subject":"subject-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.config.MarshalJSONTemplate() + assert.Equal(t, tt.wantPattern, result) + }) + } +} + +func TestSubscriptionSource_UniqueRequestID(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + expectedError error + }{ + { + name: "valid input", + input: `{"subjects":["subject1", "subject2"], "providerId":"test-provider"}`, + expectError: false, + }, + { + name: "missing subjects", + input: `{"providerId":"test-provider"}`, + expectError: true, + expectedError: 
errors.New("Key path not found"), + }, + { + name: "missing providerId", + input: `{"subjects":["subject1", "subject2"]}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + source := &SubscriptionSource{ + pubSub: &mockAdapter{}, + } + ctx := &resolve.Context{} + input := []byte(tt.input) + xxh := xxhash.New() + + err := source.UniqueRequestID(ctx, input, xxh) + + if tt.expectError { + require.Error(t, err) + if tt.expectedError != nil { + // For jsonparser errors, just check if the error message contains the expected text + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } + } else { + require.NoError(t, err) + // Check that the hash has been updated + assert.NotEqual(t, 0, xxh.Sum64()) + } + }) + } +} + +func TestSubscriptionSource_Start(t *testing.T) { + tests := []struct { + name string + input string + mockSetup func(*mockAdapter) + expectError bool + }{ + { + name: "successful subscription", + input: `{"subjects":["subject1", "subject2"], "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ + ProviderID: "test-provider", + Subjects: []string{"subject1", "subject2"}, + }, mock.Anything).Return(nil) + }, + expectError: false, + }, + { + name: "adapter returns error", + input: `{"subjects":["subject1"], "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ + ProviderID: "test-provider", + Subjects: []string{"subject1"}, + }, mock.Anything).Return(errors.New("subscription error")) + }, + expectError: true, + }, + { + name: "invalid input json", + input: `{"invalid json":`, + mockSetup: func(m *mockAdapter) {}, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockAdapter := new(mockAdapter) + tt.mockSetup(mockAdapter) + + source := 
&SubscriptionSource{ + pubSub: mockAdapter, + } + + // Set up go context + goCtx := context.Background() + + // Create a resolve.Context with the standard context + resolveCtx := &resolve.Context{} + resolveCtx = resolveCtx.WithContext(goCtx) + + // Create a proper mock updater + updater := new(MockSubscriptionUpdater) + updater.On("Done").Return() + + input := []byte(tt.input) + err := source.Start(resolveCtx, input, updater) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + } + mockAdapter.AssertExpectations(t) + }) + } +} + +func TestNatsPublishDataSource_Load(t *testing.T) { + tests := []struct { + name string + input string + mockSetup func(*mockAdapter) + expectError bool + expectedOutput string + expectPublished bool + }{ + { + name: "successful publish", + input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + return event.ProviderID == "test-provider" && + event.Subject == "test-subject" && + string(event.Data) == `{"message":"hello"}` + })).Return(nil) + }, + expectError: false, + expectedOutput: `{"success": true}`, + expectPublished: true, + }, + { + name: "publish error", + input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) + }, + expectError: false, // The Load method doesn't return the publish error directly + expectedOutput: `{"success": false}`, + expectPublished: true, + }, + { + name: "invalid input json", + input: `{"invalid json":`, + mockSetup: func(m *mockAdapter) {}, + expectError: true, + expectPublished: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockAdapter := new(mockAdapter) + tt.mockSetup(mockAdapter) + + dataSource 
:= &NatsPublishDataSource{ + pubSub: mockAdapter, + } + + ctx := context.Background() + input := []byte(tt.input) + var out bytes.Buffer + + err := dataSource.Load(ctx, input, &out) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + if tt.expectPublished { + mockAdapter.AssertExpectations(t) + } + if tt.expectedOutput != "" { + assert.Equal(t, tt.expectedOutput, out.String()) + } + } + }) + } +} + +func TestNatsPublishDataSource_LoadWithFiles(t *testing.T) { + dataSource := &NatsPublishDataSource{} + assert.Panics(t, func() { + dataSource.LoadWithFiles(context.Background(), []byte{}, nil, &bytes.Buffer{}) + }, "Expected LoadWithFiles to panic with 'not implemented'") +} + +func TestNatsRequestDataSource_Load(t *testing.T) { + tests := []struct { + name string + input string + mockSetup func(*mockAdapter) + expectError bool + expectedOutput string + }{ + { + name: "successful request", + input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + return event.ProviderID == "test-provider" && + event.Subject == "test-subject" && + string(event.Data) == `{"message":"hello"}` + }), mock.Anything).Run(func(args mock.Arguments) { + // Write response to the output buffer + w := args.Get(2).(io.Writer) + _, _ = w.Write([]byte(`{"response":"success"}`)) + }).Return(nil) + }, + expectError: false, + expectedOutput: `{"response":"success"}`, + }, + { + name: "request error", + input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + mockSetup: func(m *mockAdapter) { + m.On("Request", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("request error")) + }, + expectError: true, + expectedOutput: "", + }, + { + name: "invalid input json", + input: `{"invalid json":`, + mockSetup: func(m *mockAdapter) {}, + expectError: 
true, + expectedOutput: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockAdapter := new(mockAdapter) + tt.mockSetup(mockAdapter) + + dataSource := &NatsRequestDataSource{ + pubSub: mockAdapter, + } + + ctx := context.Background() + input := []byte(tt.input) + var out bytes.Buffer + + err := dataSource.Load(ctx, input, &out) + + if tt.expectError { + require.Error(t, err) + } else { + require.NoError(t, err) + mockAdapter.AssertExpectations(t) + if tt.expectedOutput != "" { + assert.Equal(t, tt.expectedOutput, out.String()) + } + } + }) + } +} + +func TestNatsRequestDataSource_LoadWithFiles(t *testing.T) { + dataSource := &NatsRequestDataSource{} + assert.Panics(t, func() { + dataSource.LoadWithFiles(context.Background(), []byte{}, nil, &bytes.Buffer{}) + }, "Expected LoadWithFiles to panic with 'not implemented'") +} diff --git a/router/pkg/pubsub/nats/provider.go b/router/pkg/pubsub/nats/provider.go new file mode 100644 index 0000000000..a4e9af1248 --- /dev/null +++ b/router/pkg/pubsub/nats/provider.go @@ -0,0 +1,169 @@ +package nats + +import ( + "context" + "errors" + "fmt" + "slices" + "time" + + "github.com/nats-io/nats.go" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "go.uber.org/zap" +) + +func buildNatsOptions(eventSource config.NatsEventSource, logger *zap.Logger) ([]nats.Option, error) { + opts := []nats.Option{ + nats.Name(fmt.Sprintf("cosmo.router.edfs.nats.%s", eventSource.ID)), + nats.ReconnectJitter(500*time.Millisecond, 2*time.Second), + nats.ClosedHandler(func(conn *nats.Conn) { + logger.Info("NATS connection closed", zap.String("provider_id", eventSource.ID), zap.Error(conn.LastError())) + }), + nats.ConnectHandler(func(nc *nats.Conn) { + logger.Info("NATS connection established", zap.String("provider_id", 
eventSource.ID), zap.String("url", nc.ConnectedUrlRedacted())) + }), + nats.DisconnectErrHandler(func(nc *nats.Conn, err error) { + if err != nil { + logger.Error("NATS disconnected; will attempt to reconnect", zap.Error(err), zap.String("provider_id", eventSource.ID)) + } else { + logger.Info("NATS disconnected", zap.String("provider_id", eventSource.ID)) + } + }), + nats.ErrorHandler(func(conn *nats.Conn, subscription *nats.Subscription, err error) { + if errors.Is(err, nats.ErrSlowConsumer) { + logger.Warn( + "NATS slow consumer detected. Events are being dropped. Please consider increasing the buffer size or reducing the number of messages being sent.", + zap.Error(err), + zap.String("provider_id", eventSource.ID), + ) + } else { + logger.Error("NATS error", zap.Error(err)) + } + }), + nats.ReconnectHandler(func(conn *nats.Conn) { + logger.Info("NATS reconnected", zap.String("provider_id", eventSource.ID), zap.String("url", conn.ConnectedUrlRedacted())) + }), + } + + if eventSource.Authentication != nil { + if eventSource.Authentication.Token != nil { + opts = append(opts, nats.Token(*eventSource.Authentication.Token)) + } else if eventSource.Authentication.UserInfo.Username != nil && eventSource.Authentication.UserInfo.Password != nil { + opts = append(opts, nats.UserInfo(*eventSource.Authentication.UserInfo.Username, *eventSource.Authentication.UserInfo.Password)) + } + } + + return opts, nil +} + +func transformEventConfig(cfg *nodev1.NatsEventConfiguration, fn datasource.ArgumentTemplateCallback) (*nodev1.NatsEventConfiguration, error) { + switch v := cfg.GetEngineEventConfiguration().GetType(); v { + case nodev1.EventType_PUBLISH, nodev1.EventType_REQUEST: + extractedSubject, err := fn(cfg.GetSubjects()[0]) + if err != nil { + return cfg, fmt.Errorf("unable to parse subject with id %s", cfg.GetSubjects()[0]) + } + if !isValidNatsSubject(extractedSubject) { + return cfg, fmt.Errorf("invalid subject: %s", extractedSubject) + } + cfg.Subjects = 
[]string{extractedSubject} + case nodev1.EventType_SUBSCRIBE: + extractedSubjects := make([]string, 0, len(cfg.Subjects)) + for _, rawSubject := range cfg.Subjects { + extractedSubject, err := fn(rawSubject) + if err != nil { + return cfg, nil + } + if !isValidNatsSubject(extractedSubject) { + return cfg, fmt.Errorf("invalid subject: %s", extractedSubject) + } + extractedSubjects = append(extractedSubjects, extractedSubject) + } + slices.Sort(extractedSubjects) + cfg.Subjects = extractedSubjects + } + return cfg, nil +} + +type PubSubProvider struct { + EventConfiguration []*nodev1.NatsEventConfiguration + Logger *zap.Logger + Providers map[string]AdapterInterface +} + +func (c *PubSubProvider) FindPubSubDataSource(typeName string, fieldName string, extractFn datasource.ArgumentTemplateCallback) (datasource.PubSubDataSource, error) { + for _, cfg := range c.EventConfiguration { + if cfg.GetEngineEventConfiguration().GetTypeName() == typeName && cfg.GetEngineEventConfiguration().GetFieldName() == fieldName { + transformedCfg, err := transformEventConfig(cfg, extractFn) + if err != nil { + return nil, err + } + return &PubSubDataSource{ + EventConfiguration: transformedCfg, + NatsAdapter: c.Providers[cfg.GetEngineEventConfiguration().GetProviderId()], + }, nil + } + } + return nil, nil +} + +func GetProvider(ctx context.Context, in *nodev1.DataSourceConfiguration, dsMeta *plan.DataSourceMetadata, config config.EventsConfiguration, logger *zap.Logger, hostName string, routerListenAddr string) (datasource.PubSubProvider, error) { + var providers map[string]AdapterInterface + if natsData := in.GetCustomEvents().GetNats(); natsData != nil { + definedProviders := make(map[string]bool) + for _, provider := range config.Providers.Nats { + definedProviders[provider.ID] = true + } + usedProviders := make(map[string]bool) + for _, event := range natsData { + if _, found := definedProviders[event.EngineEventConfiguration.ProviderId]; !found { + return nil, fmt.Errorf("failed to 
find Nats provider with ID %s", event.EngineEventConfiguration.ProviderId) + } + usedProviders[event.EngineEventConfiguration.ProviderId] = true + } + providers = map[string]AdapterInterface{} + for _, provider := range config.Providers.Nats { + if !usedProviders[provider.ID] { + continue + } + options, err := buildNatsOptions(provider, logger) + if err != nil { + return nil, fmt.Errorf("failed to build options for Nats provider with ID \"%s\": %w", provider.ID, err) + } + + adapter, err := NewAdapter(ctx, logger, provider.URL, options, hostName, routerListenAddr) + if err != nil { + return nil, fmt.Errorf("failed to create adapter for Nats provider with ID \"%s\": %w", provider.ID, err) + } + providers[provider.ID] = adapter + } + return &PubSubProvider{ + EventConfiguration: natsData, + Logger: logger, + Providers: providers, + }, nil + } + + return nil, nil +} + +func (c *PubSubProvider) Startup(ctx context.Context) error { + for _, provider := range c.Providers { + if err := provider.Startup(ctx); err != nil { + return err + } + } + return nil +} + +func (c *PubSubProvider) Shutdown(ctx context.Context) error { + for _, provider := range c.Providers { + if err := provider.Shutdown(ctx); err != nil { + return err + } + } + return nil +} diff --git a/router/pkg/pubsub/nats/provider_test.go b/router/pkg/pubsub/nats/provider_test.go new file mode 100644 index 0000000000..9a4ba9f03a --- /dev/null +++ b/router/pkg/pubsub/nats/provider_test.go @@ -0,0 +1,287 @@ +package nats + +import ( + "context" + "io" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/plan" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" + "go.uber.org/zap/zaptest" +) + +// mockAdapter is a mock of 
AdapterInterface +type mockAdapter struct { + mock.Mock +} + +func (m *mockAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { + args := m.Called(ctx, event, updater) + return args.Error(0) +} + +func (m *mockAdapter) Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error { + args := m.Called(ctx, event) + return args.Error(0) +} + +func (m *mockAdapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error { + args := m.Called(ctx, event, w) + return args.Error(0) +} + +func (m *mockAdapter) Startup(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func (m *mockAdapter) Shutdown(ctx context.Context) error { + args := m.Called(ctx) + return args.Error(0) +} + +func TestBuildNatsOptions(t *testing.T) { + t.Run("basic configuration", func(t *testing.T) { + cfg := config.NatsEventSource{ + ID: "test-nats", + URL: "nats://localhost:4222", + } + logger := zaptest.NewLogger(t) + + opts, err := buildNatsOptions(cfg, logger) + require.NoError(t, err) + require.NotEmpty(t, opts) + }) + + t.Run("with token authentication", func(t *testing.T) { + token := "test-token" + cfg := config.NatsEventSource{ + ID: "test-nats", + URL: "nats://localhost:4222", + Authentication: &config.NatsAuthentication{ + NatsTokenBasedAuthentication: config.NatsTokenBasedAuthentication{ + Token: &token, + }, + }, + } + logger := zaptest.NewLogger(t) + + opts, err := buildNatsOptions(cfg, logger) + require.NoError(t, err) + require.NotEmpty(t, opts) + // Can't directly check for token options, but we can verify options are present + require.Greater(t, len(opts), 7) // Basic options (7) + token option + }) + + t.Run("with user/password authentication", func(t *testing.T) { + username := "user" + password := "pass" + cfg := config.NatsEventSource{ + ID: "test-nats", + URL: "nats://localhost:4222", + Authentication: &config.NatsAuthentication{ 
+ UserInfo: config.NatsCredentialsAuthentication{ + Username: &username, + Password: &password, + }, + }, + } + logger := zaptest.NewLogger(t) + + opts, err := buildNatsOptions(cfg, logger) + require.NoError(t, err) + require.NotEmpty(t, opts) + // Can't directly check for auth options, but we can verify options are present + require.Greater(t, len(opts), 7) // Basic options (7) + user info option + }) +} + +func TestTransformEventConfig(t *testing.T) { + t.Run("publish event", func(t *testing.T) { + cfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: &nodev1.EngineEventConfiguration{ + Type: nodev1.EventType_PUBLISH, + }, + Subjects: []string{"original.subject"}, + } + + // Simple transform function that adds "transformed." prefix + transformFn := func(s string) (string, error) { + return "transformed." + s, nil + } + + transformedCfg, err := transformEventConfig(cfg, transformFn) + require.NoError(t, err) + require.Equal(t, []string{"transformed.original.subject"}, transformedCfg.Subjects) + }) + + t.Run("subscribe event", func(t *testing.T) { + cfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: &nodev1.EngineEventConfiguration{ + Type: nodev1.EventType_SUBSCRIBE, + }, + Subjects: []string{"original.subject1", "original.subject2"}, + } + + // Simple transform function that adds "transformed." prefix + transformFn := func(s string) (string, error) { + return "transformed." 
+ s, nil + } + + transformedCfg, err := transformEventConfig(cfg, transformFn) + require.NoError(t, err) + // Since the function sorts the subjects + require.Equal(t, []string{"transformed.original.subject1", "transformed.original.subject2"}, transformedCfg.Subjects) + }) + + t.Run("invalid subject", func(t *testing.T) { + cfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: &nodev1.EngineEventConfiguration{ + Type: nodev1.EventType_PUBLISH, + }, + Subjects: []string{"invalid subject with spaces"}, + } + + transformFn := func(s string) (string, error) { + return s, nil + } + + _, err := transformEventConfig(cfg, transformFn) + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid subject") + }) +} + +func TestGetProvider(t *testing.T) { + t.Run("returns nil if no NATS configuration", func(t *testing.T) { + ctx := context.Background() + in := &nodev1.DataSourceConfiguration{ + CustomEvents: &nodev1.DataSourceCustomEvents{}, + } + + dsMeta := &plan.DataSourceMetadata{} + cfg := config.EventsConfiguration{} + logger := zaptest.NewLogger(t) + + provider, err := GetProvider(ctx, in, dsMeta, cfg, logger, "host", "addr") + require.NoError(t, err) + require.Nil(t, provider) + }) + + t.Run("errors if provider not found", func(t *testing.T) { + ctx := context.Background() + in := &nodev1.DataSourceConfiguration{ + CustomEvents: &nodev1.DataSourceCustomEvents{ + Nats: []*nodev1.NatsEventConfiguration{ + { + EngineEventConfiguration: &nodev1.EngineEventConfiguration{ + ProviderId: "unknown", + }, + }, + }, + }, + } + + dsMeta := &plan.DataSourceMetadata{} + cfg := config.EventsConfiguration{ + Providers: config.EventProviders{ + Nats: []config.NatsEventSource{ + {ID: "provider1", URL: "nats://localhost:4222"}, + }, + }, + } + logger := zaptest.NewLogger(t) + + provider, err := GetProvider(ctx, in, dsMeta, cfg, logger, "host", "addr") + require.Error(t, err) + require.Nil(t, provider) + assert.Contains(t, err.Error(), "failed to find Nats provider with 
ID") + }) +} + +func TestPubSubProvider_FindPubSubDataSource(t *testing.T) { + mockNats := &mockAdapter{} + providerId := "test-provider" + typeName := "TestType" + fieldName := "testField" + + provider := &PubSubProvider{ + EventConfiguration: []*nodev1.NatsEventConfiguration{ + { + EngineEventConfiguration: &nodev1.EngineEventConfiguration{ + TypeName: typeName, + FieldName: fieldName, + ProviderId: providerId, + Type: nodev1.EventType_PUBLISH, + }, + Subjects: []string{"test.subject"}, + }, + }, + Logger: zap.NewNop(), + Providers: map[string]AdapterInterface{ + providerId: mockNats, + }, + } + + t.Run("find matching datasource", func(t *testing.T) { + // Identity transform function + transformFn := func(s string) (string, error) { + return s, nil + } + + ds, err := provider.FindPubSubDataSource(typeName, fieldName, transformFn) + require.NoError(t, err) + require.NotNil(t, ds) + + // Check the returned datasource + natsDs, ok := ds.(*PubSubDataSource) + require.True(t, ok) + assert.Equal(t, mockNats, natsDs.NatsAdapter) + assert.Equal(t, provider.EventConfiguration[0], natsDs.EventConfiguration) + }) + + t.Run("return nil if no match", func(t *testing.T) { + ds, err := provider.FindPubSubDataSource("OtherType", fieldName, nil) + require.NoError(t, err) + require.Nil(t, ds) + }) + + t.Run("handle error in transform function", func(t *testing.T) { + // Function that returns error + errorFn := func(s string) (string, error) { + return "", assert.AnError + } + + ds, err := provider.FindPubSubDataSource(typeName, fieldName, errorFn) + require.Error(t, err) + require.Nil(t, ds) + }) + + t.Run("handle error in transform function", func(t *testing.T) { + // Function that returns error + errorFn := func(s string) (string, error) { + return "", assert.AnError + } + + ds, err := provider.FindPubSubDataSource(typeName, fieldName, errorFn) + require.Error(t, err) + require.Nil(t, ds) + }) + + t.Run("handle error in transform function with invalid subject", func(t 
*testing.T) { + // Function that returns error + errorFn := func(s string) (string, error) { + return " ", nil + } + + ds, err := provider.FindPubSubDataSource(typeName, fieldName, errorFn) + require.Error(t, err) + require.Nil(t, ds) + }) +} diff --git a/router/pkg/pubsub/nats/pubsub_datasource.go b/router/pkg/pubsub/nats/pubsub_datasource.go new file mode 100644 index 0000000000..ffbac39db6 --- /dev/null +++ b/router/pkg/pubsub/nats/pubsub_datasource.go @@ -0,0 +1,122 @@ +package nats + +import ( + "encoding/json" + "fmt" + + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type PubSubDataSource struct { + EventConfiguration *nodev1.NatsEventConfiguration + NatsAdapter AdapterInterface +} + +func (c *PubSubDataSource) EngineEventConfiguration() *nodev1.EngineEventConfiguration { + return c.EventConfiguration.GetEngineEventConfiguration() +} + +func (c *PubSubDataSource) ResolveDataSource() (resolve.DataSource, error) { + var dataSource resolve.DataSource + + typeName := c.EventConfiguration.GetEngineEventConfiguration().GetType() + switch typeName { + case nodev1.EventType_PUBLISH: + dataSource = &NatsPublishDataSource{ + pubSub: c.NatsAdapter, + } + case nodev1.EventType_REQUEST: + dataSource = &NatsRequestDataSource{ + pubSub: c.NatsAdapter, + } + default: + return nil, fmt.Errorf("failed to configure fetch: invalid event type \"%s\" for Nats", typeName.String()) + } + + return dataSource, nil +} + +func (c *PubSubDataSource) ResolveDataSourceInput(event []byte) (string, error) { + subjects := c.EventConfiguration.GetSubjects() + + if len(subjects) != 1 { + return "", fmt.Errorf("publish and request events should define one subject but received %d", len(subjects)) + } + + subject := subjects[0] + + providerId := c.GetProviderId() + + evtCfg := PublishEventConfiguration{ + ProviderID: providerId, + Subject: subject, + Data: event, + } + + return 
evtCfg.MarshalJSONTemplate(), nil +} + +func (c *PubSubDataSource) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { + return &SubscriptionSource{ + pubSub: c.NatsAdapter, + }, nil +} + +func (c *PubSubDataSource) ResolveDataSourceSubscriptionInput() (string, error) { + providerId := c.GetProviderId() + + evtCfg := SubscriptionEventConfiguration{ + ProviderID: providerId, + Subjects: c.EventConfiguration.GetSubjects(), + } + if c.EventConfiguration.StreamConfiguration != nil { + evtCfg.StreamConfiguration = &StreamConfiguration{ + Consumer: c.EventConfiguration.StreamConfiguration.ConsumerName, + StreamName: c.EventConfiguration.StreamConfiguration.StreamName, + ConsumerInactiveThreshold: c.EventConfiguration.StreamConfiguration.ConsumerInactiveThreshold, + } + } + object, err := json.Marshal(evtCfg) + if err != nil { + return "", fmt.Errorf("failed to marshal event subscription streamConfiguration") + } + return string(object), nil +} + +func (c *PubSubDataSource) GetProviderId() string { + return c.EventConfiguration.GetEngineEventConfiguration().GetProviderId() +} + +type StreamConfiguration struct { + Consumer string `json:"consumer"` + ConsumerInactiveThreshold int32 `json:"consumerInactiveThreshold"` + StreamName string `json:"streamName"` +} + +type SubscriptionEventConfiguration struct { + ProviderID string `json:"providerId"` + Subjects []string `json:"subjects"` + StreamConfiguration *StreamConfiguration `json:"streamConfiguration,omitempty"` +} + +type PublishAndRequestEventConfiguration struct { + ProviderID string `json:"providerId"` + Subject string `json:"subject"` + Data json.RawMessage `json:"data"` +} + +func (s *PublishAndRequestEventConfiguration) MarshalJSONTemplate() string { + return fmt.Sprintf(`{"subject":"%s", "data": %s, "providerId":"%s"}`, s.Subject, s.Data, s.ProviderID) +} + +type PublishEventConfiguration struct { + ProviderID string `json:"providerId"` + Subject string `json:"subject"` + Data json.RawMessage 
`json:"data"` +} + +func (s *PublishEventConfiguration) MarshalJSONTemplate() string { + return fmt.Sprintf(`{"subject":"%s", "data": %s, "providerId":"%s"}`, s.Subject, s.Data, s.ProviderID) +} diff --git a/router/pkg/pubsub/nats/pubsub_datasource_test.go b/router/pkg/pubsub/nats/pubsub_datasource_test.go new file mode 100644 index 0000000000..e48639b0c9 --- /dev/null +++ b/router/pkg/pubsub/nats/pubsub_datasource_test.go @@ -0,0 +1,286 @@ +package nats + +import ( + "bytes" + "context" + "encoding/json" + "io" + "testing" + + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + nodev1 "github.com/wundergraph/cosmo/router/gen/proto/wg/cosmo/node/v1" + "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" +) + +func TestNatsPubSubDataSource(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, + TypeName: "TestType", + FieldName: "testField", + } + + natsCfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Subjects: []string{"test-subject"}, + } + + // Create the data source to test with a real adapter + adapter := &Adapter{} + pubsub := &PubSubDataSource{ + EventConfiguration: natsCfg, + NatsAdapter: adapter, + } + + // Run the standard test suite + pubsubtest.VerifyPubSubDataSourceImplementation(t, pubsub) +} + +func TestPubSubDataSourceWithMockAdapter(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, + TypeName: "TestType", + FieldName: "testField", + } + + natsCfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Subjects: []string{"test-subject"}, + } + + // Create mock adapter + mockAdapter := new(mockAdapter) + + // Configure mock expectations for Publish + mockAdapter.On("Publish", 
mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + return event.ProviderID == "test-provider" && event.Subject == "test-subject" + })).Return(nil) + + // Create the data source with mock adapter + pubsub := &PubSubDataSource{ + EventConfiguration: natsCfg, + NatsAdapter: mockAdapter, + } + + // Get the data source + ds, err := pubsub.ResolveDataSource() + require.NoError(t, err) + + // Get the input + input, err := pubsub.ResolveDataSourceInput([]byte(`{"test":"data"}`)) + require.NoError(t, err) + + // Call Load on the data source + out := &bytes.Buffer{} + err = ds.Load(context.Background(), []byte(input), out) + require.NoError(t, err) + require.Equal(t, `{"success": true}`, out.String()) + + // Verify mock expectations + mockAdapter.AssertExpectations(t) +} + +func TestPubSubDataSource_GetResolveDataSource_WrongType(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_SUBSCRIBE, // This is not supported + TypeName: "TestType", + FieldName: "testField", + } + + natsCfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Subjects: []string{"test-subject"}, + } + + // Create mock adapter + mockAdapter := new(mockAdapter) + + // Create the data source with mock adapter + pubsub := &PubSubDataSource{ + EventConfiguration: natsCfg, + NatsAdapter: mockAdapter, + } + + // Get the data source + ds, err := pubsub.ResolveDataSource() + require.Error(t, err) + require.Nil(t, ds) +} + +func TestPubSubDataSource_GetResolveDataSourceInput_MultipleSubjects(t *testing.T) { + // Create event configuration with required fields + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, + TypeName: "TestType", + FieldName: "testField", + } + + natsCfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: 
engineEventConfig,
+		Subjects:                 []string{"test-subject-1", "test-subject-2"},
+	}
+
+	// Create the data source without an adapter; ResolveDataSourceInput does not use one
+	pubsub := &PubSubDataSource{
+		EventConfiguration: natsCfg,
+	}
+
+	// Get the input
+	input, err := pubsub.ResolveDataSourceInput([]byte(`{"test":"data"}`))
+	require.Error(t, err)
+	require.Empty(t, input)
+}
+
+func TestPubSubDataSource_GetResolveDataSourceInput_NoSubjects(t *testing.T) {
+	// Create event configuration with required fields
+	engineEventConfig := &nodev1.EngineEventConfiguration{
+		ProviderId: "test-provider",
+		Type:       nodev1.EventType_PUBLISH,
+		TypeName:   "TestType",
+		FieldName:  "testField",
+	}
+
+	natsCfg := &nodev1.NatsEventConfiguration{
+		EngineEventConfiguration: engineEventConfig,
+		Subjects:                 []string{},
+	}
+
+	// Create the data source without an adapter; ResolveDataSourceInput does not use one
+	pubsub := &PubSubDataSource{
+		EventConfiguration: natsCfg,
+	}
+
+	// Get the input
+	input, err := pubsub.ResolveDataSourceInput([]byte(`{"test":"data"}`))
+	require.Error(t, err)
+	require.Empty(t, input)
+}
+
+func TestNatsPubSubDataSourceMultiSubjectSubscription(t *testing.T) {
+	// Create event configuration with multiple subjects
+	engineEventConfig := &nodev1.EngineEventConfiguration{
+		ProviderId: "test-provider",
+		Type:       nodev1.EventType_PUBLISH, // Type is not used by ResolveDataSourceSubscriptionInput (REQUEST is also valid elsewhere)
+		TypeName:   "TestType",
+		FieldName:  "testField",
+	}
+
+	natsCfg := &nodev1.NatsEventConfiguration{
+		EngineEventConfiguration: engineEventConfig,
+		Subjects:                 []string{"test-subject-1", "test-subject-2"},
+	}
+
+	// Create the data source to test; no adapter is needed for subscription input
+	pubsub := &PubSubDataSource{
+		EventConfiguration: natsCfg,
+	}
+
+	// Test GetResolveDataSourceSubscriptionInput
+	subscriptionInput, err := pubsub.ResolveDataSourceSubscriptionInput()
+	require.NoError(t, err, "Expected no error from GetResolveDataSourceSubscriptionInput")
+	require.NotEmpty(t, subscriptionInput, "Expected non-empty subscription input")
+
+	// Verify the subscription input 
contains both subjects + var subscriptionConfig SubscriptionEventConfiguration + err = json.Unmarshal([]byte(subscriptionInput), &subscriptionConfig) + require.NoError(t, err, "Expected valid JSON from GetResolveDataSourceSubscriptionInput") + require.Equal(t, 2, len(subscriptionConfig.Subjects), "Expected 2 subjects in subscription configuration") + require.Equal(t, "test-subject-1", subscriptionConfig.Subjects[0], "Expected first subject to be 'test-subject-1'") + require.Equal(t, "test-subject-2", subscriptionConfig.Subjects[1], "Expected second subject to be 'test-subject-2'") +} + +func TestNatsPubSubDataSourceWithStreamConfiguration(t *testing.T) { + // Create event configuration with stream configuration + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_PUBLISH, + TypeName: "TestType", + FieldName: "testField", + } + + natsCfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Subjects: []string{"test-subject"}, + StreamConfiguration: &nodev1.NatsStreamConfiguration{ + StreamName: "test-stream", + ConsumerName: "test-consumer", + ConsumerInactiveThreshold: 30, + }, + } + + // Create the data source to test + pubsub := &PubSubDataSource{ + EventConfiguration: natsCfg, + } + + // Test GetResolveDataSourceSubscriptionInput with stream configuration + subscriptionInput, err := pubsub.ResolveDataSourceSubscriptionInput() + require.NoError(t, err, "Expected no error from GetResolveDataSourceSubscriptionInput") + require.NotEmpty(t, subscriptionInput, "Expected non-empty subscription input") + + // Verify the subscription input contains stream configuration + var subscriptionConfig SubscriptionEventConfiguration + err = json.Unmarshal([]byte(subscriptionInput), &subscriptionConfig) + require.NoError(t, err, "Expected valid JSON from GetResolveDataSourceSubscriptionInput") + require.NotNil(t, subscriptionConfig.StreamConfiguration, "Expected non-nil stream configuration") 
+ require.Equal(t, "test-consumer", subscriptionConfig.StreamConfiguration.Consumer, "Expected consumer to be 'test-consumer'") + require.Equal(t, "test-stream", subscriptionConfig.StreamConfiguration.StreamName, "Expected stream name to be 'test-stream'") + require.Equal(t, int32(30), subscriptionConfig.StreamConfiguration.ConsumerInactiveThreshold, "Expected consumer inactive threshold to be 30") +} + +func TestPubSubDataSource_RequestDataSource(t *testing.T) { + // Create event configuration with REQUEST type + engineEventConfig := &nodev1.EngineEventConfiguration{ + ProviderId: "test-provider", + Type: nodev1.EventType_REQUEST, + TypeName: "TestType", + FieldName: "testField", + } + + natsCfg := &nodev1.NatsEventConfiguration{ + EngineEventConfiguration: engineEventConfig, + Subjects: []string{"test-subject"}, + } + + // Create mock adapter + mockAdapter := new(mockAdapter) + + // Configure mock expectations for Request + mockAdapter.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + return event.ProviderID == "test-provider" && event.Subject == "test-subject" + }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { + w := args.Get(2).(io.Writer) + w.Write([]byte(`{"response": "test"}`)) + }) + + // Create the data source with mock adapter + pubsub := &PubSubDataSource{ + EventConfiguration: natsCfg, + NatsAdapter: mockAdapter, + } + + // Get the data source + ds, err := pubsub.ResolveDataSource() + require.NoError(t, err) + require.NotNil(t, ds) + + // Get the input + input, err := pubsub.ResolveDataSourceInput([]byte(`{"test":"data"}`)) + require.NoError(t, err) + + // Call Load on the data source + out := &bytes.Buffer{} + err = ds.Load(context.Background(), []byte(input), out) + require.NoError(t, err) + require.Equal(t, `{"response": "test"}`, out.String()) + + // Verify mock expectations + mockAdapter.AssertExpectations(t) +} diff --git a/router/pkg/pubsub/nats/utils.go 
b/router/pkg/pubsub/nats/utils.go new file mode 100644 index 0000000000..0229f7a41c --- /dev/null +++ b/router/pkg/pubsub/nats/utils.go @@ -0,0 +1,37 @@ +package nats + +import ( + "strings" +) + +const ( + fwc = '>' + tsep = "." +) + +func isValidNatsSubject(subject string) bool { + if subject == "" { + return false + } + sfwc := false + tokens := strings.Split(subject, tsep) + for _, t := range tokens { + length := len(t) + if length == 0 || sfwc { + return false + } + if length > 1 { + if strings.ContainsAny(t, "\t\n\f\r ") { + return false + } + continue + } + switch t[0] { + case fwc: + sfwc = true + case ' ', '\t', '\n', '\r', '\f': + return false + } + } + return true +} diff --git a/router/pkg/pubsub/nats/utils_test.go b/router/pkg/pubsub/nats/utils_test.go new file mode 100644 index 0000000000..9ec92c78cc --- /dev/null +++ b/router/pkg/pubsub/nats/utils_test.go @@ -0,0 +1,83 @@ +package nats + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestIsValidNatsSubject(t *testing.T) { + tests := []struct { + name string + subject string + want bool + }{ + { + name: "empty string", + subject: "", + want: false, + }, + { + name: "simple valid subject", + subject: "test.subject", + want: true, + }, + { + name: "valid subject with wildcard", + subject: "test.>", + want: true, + }, + { + name: "invalid with space", + subject: "test subject", + want: false, + }, + { + name: "invalid with tab", + subject: "test\tsubject", + want: false, + }, + { + name: "invalid with newline", + subject: "test\nsubject", + want: false, + }, + { + name: "invalid with empty token", + subject: "test..subject", + want: false, + }, + { + name: "wildcard not at end", + subject: "test.>.subject", + want: false, + }, + { + name: "contains a space", + subject: " ", + want: false, + }, + { + name: "contains a tab", + subject: "\t", + want: false, + }, + { + name: "contains a newline", + subject: "\n", + want: false, + }, + { + name: "contains a form feed", + subject: 
"\f", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isValidNatsSubject(tt.subject) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/router/pkg/pubsub/pubsub.go b/router/pkg/pubsub/pubsub.go new file mode 100644 index 0000000000..e70cf66f33 --- /dev/null +++ b/router/pkg/pubsub/pubsub.go @@ -0,0 +1,23 @@ +package pubsub + +import ( + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" +) + +var additionalProviders []datasource.ProviderFactory + +// RegisterAdditionalProvider registers an additional PubSub provider +func RegisterAdditionalProvider(provider datasource.ProviderFactory) { + additionalProviders = append(additionalProviders, provider) +} + +// GetProviderFactories returns a list of all PubSub implementations +func GetProviderFactories() []datasource.ProviderFactory { + return append([]datasource.ProviderFactory{ + kafka.GetProvider, + nats.GetProvider, + }, additionalProviders...) 
+} diff --git a/router/pkg/pubsub/pubsubtest/pubsubtest.go b/router/pkg/pubsub/pubsubtest/pubsubtest.go new file mode 100644 index 0000000000..09ad3b25b0 --- /dev/null +++ b/router/pkg/pubsub/pubsubtest/pubsubtest.go @@ -0,0 +1,48 @@ +package pubsubtest + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" +) + +// VerifyPubSubDataSourceImplementation is a common test function to verify any PubSubDataSource implementation +// This function can be used by other packages to test their PubSubDataSource implementations +func VerifyPubSubDataSourceImplementation(t *testing.T, pubSub datasource.PubSubDataSource) { + // Test GetEngineEventConfiguration + engineEventConfig := pubSub.EngineEventConfiguration() + require.NotNil(t, engineEventConfig, "Expected non-nil EngineEventConfiguration") + + // Test GetResolveDataSource + dataSource, err := pubSub.ResolveDataSource() + require.NoError(t, err, "Expected no error from GetResolveDataSource") + require.NotNil(t, dataSource, "Expected non-nil DataSource") + + // Test GetResolveDataSourceInput with sample event data + testEvent := []byte(`{"test":"data"}`) + input, err := pubSub.ResolveDataSourceInput(testEvent) + require.NoError(t, err, "Expected no error from GetResolveDataSourceInput") + assert.NotEmpty(t, input, "Expected non-empty input") + + // Make sure the input is valid JSON + var result interface{} + err = json.Unmarshal([]byte(input), &result) + assert.NoError(t, err, "Expected valid JSON from GetResolveDataSourceInput") + + // Test GetResolveDataSourceSubscription + subscription, err := pubSub.ResolveDataSourceSubscription() + require.NoError(t, err, "Expected no error from GetResolveDataSourceSubscription") + require.NotNil(t, subscription, "Expected non-nil SubscriptionDataSource") + + // Test GetResolveDataSourceSubscriptionInput + subscriptionInput, err := 
pubSub.ResolveDataSourceSubscriptionInput() + require.NoError(t, err, "Expected no error from GetResolveDataSourceSubscriptionInput") + assert.NotEmpty(t, subscriptionInput, "Expected non-empty subscription input") + + // Make sure the subscription input is valid JSON + err = json.Unmarshal([]byte(subscriptionInput), &result) + assert.NoError(t, err, "Expected valid JSON from GetResolveDataSourceSubscriptionInput") +}