From d0079aed4082abfe0d2028c1e9f1e951fe5e4f4d Mon Sep 17 00:00:00 2001 From: Alessandro Pagnin Date: Fri, 10 Oct 2025 12:13:05 +0200 Subject: [PATCH 01/44] feat: add SubscriptionOnStartHandler (#2059) Co-authored-by: Ludwig Bedacht Co-authored-by: StarpTech Co-authored-by: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> --- adr/cosmo-streams-v1.md | 408 +++++++ .../availability/subgraph/schema.resolvers.go | 4 +- .../mood/subgraph/schema.resolvers.go | 4 +- rfc/cosmo-streams-v1.md | 1079 +++++++++++++++++ router-tests/go.mod | 4 +- router-tests/go.sum | 2 + .../modules/start-subscription/module.go | 61 + .../modules/start_subscription_test.go | 664 ++++++++++ router/.mockery.yml | 3 +- router/core/errors.go | 5 + router/core/executor.go | 4 +- router/core/factoryresolver.go | 28 +- router/core/graph_server.go | 1 + router/core/graphql_handler.go | 16 + router/core/plan_generator.go | 2 +- router/core/router.go | 4 + router/core/router_config.go | 5 + router/core/subscriptions_modules.go | 188 +++ router/go.mod | 4 +- router/go.sum | 4 +- router/pkg/pubsub/datasource/datasource.go | 10 +- router/pkg/pubsub/datasource/factory.go | 12 +- router/pkg/pubsub/datasource/mocks.go | 271 ++++- router/pkg/pubsub/datasource/planner.go | 1 + router/pkg/pubsub/datasource/provider.go | 46 +- .../pkg/pubsub/datasource/pubsubprovider.go | 8 +- .../pubsub/datasource/pubsubprovider_test.go | 8 +- .../datasource/subscription_datasource.go | 72 ++ .../subscription_datasource_test.go | 327 +++++ .../datasource/subscription_event_updater.go | 34 + router/pkg/pubsub/kafka/adapter.go | 57 +- router/pkg/pubsub/kafka/engine_datasource.go | 97 +- .../pubsub/kafka/engine_datasource_factory.go | 42 +- .../kafka/engine_datasource_factory_test.go | 60 +- .../pubsub/kafka/engine_datasource_test.go | 156 +-- router/pkg/pubsub/kafka/mocks.go | 22 +- router/pkg/pubsub/nats/adapter.go | 68 +- router/pkg/pubsub/nats/engine_datasource.go | 94 +- 
.../pubsub/nats/engine_datasource_factory.go | 42 +- .../nats/engine_datasource_factory_test.go | 61 +- .../pkg/pubsub/nats/engine_datasource_test.go | 190 +-- router/pkg/pubsub/nats/mocks.go | 22 +- router/pkg/pubsub/pubsub.go | 25 +- router/pkg/pubsub/pubsub_test.go | 17 +- router/pkg/pubsub/redis/adapter.go | 35 +- router/pkg/pubsub/redis/engine_datasource.go | 95 +- .../pubsub/redis/engine_datasource_factory.go | 40 +- .../redis/engine_datasource_factory_test.go | 60 +- .../pubsub/redis/engine_datasource_test.go | 145 +-- router/pkg/pubsub/redis/mocks.go | 22 +- 50 files changed, 3861 insertions(+), 768 deletions(-) create mode 100644 adr/cosmo-streams-v1.md create mode 100644 rfc/cosmo-streams-v1.md create mode 100644 router-tests/modules/start-subscription/module.go create mode 100644 router-tests/modules/start_subscription_test.go create mode 100644 router/core/subscriptions_modules.go create mode 100644 router/pkg/pubsub/datasource/subscription_datasource.go create mode 100644 router/pkg/pubsub/datasource/subscription_datasource_test.go create mode 100644 router/pkg/pubsub/datasource/subscription_event_updater.go diff --git a/adr/cosmo-streams-v1.md b/adr/cosmo-streams-v1.md new file mode 100644 index 0000000000..436dafe45b --- /dev/null +++ b/adr/cosmo-streams-v1.md @@ -0,0 +1,408 @@ +--- +title: "Cosmo Streams v1" +author: Alessandro Pagnin +date: 2025-07-16 +status: Accepted +--- + +# ADR - Cosmo Streams V1 + +- **Author:** Alessandro Pagnin +- **Date:** 2025-07-16 +- **Status:** Accepted +- **RFC:** ../rfcs/cosmo-streams-v1.md + +## Abstract +This ADR describes new hooks that will be added to the router to support more customizable stream behavior. +The goal is to allow developers to customize the cosmo streams behavior. + +## Decision +The following interfaces will extend the existing logic in the custom modules. +These provide additional control over subscriptions by providing hooks, which are invoked during specific events. 
+ +- `SubscriptionOnStartHandler`: Called once at subscription start. +- `StreamBatchEventHook`: Called each time a batch of events is received from the provider. +- `StreamPublishEventHook`: Called each time a batch of events is going to be sent to the provider. + +```go +// STRUCTURES TO BE ADDED TO PUBSUB PACKAGE +type ProviderType string +const ( + ProviderTypeNats ProviderType = "nats" + ProviderTypeKafka ProviderType = "kafka" + ProviderTypeRedis ProviderType = "redis" +) + +// StreamHookError is used to customize the error messages and the behavior +type StreamHookError struct { + HttpError core.HttpError + CloseSubscription bool +} + +// OperationContext already exists, we just have to add the Variables() method +type OperationContext interface { + Name() string + // the variables are currently not available, so we need to expose them here + Variables() *astjson.Value +} + +// each provider will have its own event type with custom fields +// the StreamEvent interface is used to allow the hooks system to be provider-agnostic +// there could be common fields in future, but for now we don't need them +type StreamEvent interface {} + +// SubscriptionEventConfiguration is the common interface for the subscription event configuration +type SubscriptionEventConfiguration interface { + ProviderID() string + ProviderType() string + // the root field name of the subscription in the schema + RootFieldName() string +} + +// PublishEventConfiguration is the common interface for the publish event configuration +type PublishEventConfiguration interface { + ProviderID() string + ProviderType() string + // the root field name of the mutation in the schema + RootFieldName() string +} + +type SubscriptionOnStartHookContext interface { + // Request is the original request received by the router. 
+ Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) + SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration + // WriteEvent writes an event to the stream of the current subscription + // It returns true if the event was written to the stream, false if the event was dropped + WriteEvent(event datasource.StreamEvent) bool +} + +type SubscriptionOnStartHandler interface { + // SubscriptionOnStart is called once at subscription start + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. + SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +type StreamBatchEventHookContext interface { + // the request context + RequestContext() RequestContext + // the subscription event configuration + SubscriptionEventConfiguration() SubscriptionEventConfiguration +} + +type StreamBatchEventHook interface { + // OnStreamEvents is called each time a batch of events is received from the provider + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. 
+ OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +type StreamPublishEventHookContext interface { + // the request context + RequestContext() RequestContext + // the publish event configuration + PublishEventConfiguration() PublishEventConfiguration +} + +type StreamPublishEventHook interface { + // OnPublishEvents is called each time a batch of events is going to be sent to the provider + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. + OnPublishEvents(ctx StreamPublishEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} +``` + +## Example Use Cases + +- **Authorization**: Implementing authorization checks at the start of subscriptions +- **Initial message**: Sending an initial message to clients upon subscription start +- **Data mapping**: Transforming events data from the format that could be used by the external system to/from Federation compatible Router events +- **Event filtering**: Filtering events using custom logic + +## Backwards Compatibility + +The new hooks can be integrated in the router in a fully backwards compatible way. + +When the new module system will be released, the Cosmo Streams hooks: +- will be moved to the `core/hooks.go` file +- will be added to the `hookRegistry` +- will be initialized in the `coreModuleHooks.initCoreModuleHooks` + + +# Example Modules + +__All examples are pseudocode and not tested, but they are as close as possible to the final implementation__ + +## Filter and remap events + +This example will show how to filter the events based on the client's scopes and remapping the messages as they are expected from the `Employee` type. + +### 1. Add a subscription to the cosmo streams graphql schema + +The developer will start by adding a subscription to the cosmo streams graphql schema. + +```graphql +type Subscription { + employeeUpdates: Employee! 
@edfs__natsSubscribe(subjects: ["employeeUpdates"], providerId: "my-nats") +} + +type Employee @key(fields: "id", resolvable: false) { + id: Int! @external +} +``` +After publishing the schema, the developer will need to add the module to the cosmo streams engine. + +### 2. Write the custom module + +The developer will need to write the custom module that will be used to subscribe to the `employeeUpdates` subject and filter the events based on the client's scopes and remapping the messages as they are expected from the `Employee` type. + +```go +package mymodule + +import ( + "encoding/json" + "slices" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" +) + +func init() { + // Register your module here and it will be loaded at router start + core.RegisterModule(&MyModule{}) +} + +type MyModule struct {} + +func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []core.StreamEvent) ([]core.StreamEvent, error) { + // check if the provider is nats + if ctx.StreamContext().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the provider id is the one expected by the module + if ctx.StreamContext().ProviderID() != "my-nats" { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "employeeUpdates" { + return events, nil + } + + // check if the client is authenticated + if ctx.RequestContext().Authentication() == nil { + // if the client is not authenticated, return no events + return events, nil + } + + // check if the client is allowed to subscribe to the stream + clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + if !found { + return events, fmt.Errorf("client is not allowed to subscribe to the stream") + } + + newEvents := make([]core.StreamEvent, 0, 
len(events)) + + for _, evt := range events { + natsEvent, ok := evt.(*nats.NatsEvent); + if !ok { + newEvents = append(newEvents, evt) + continue + } + + // decode the event data coming from the provider + var dataReceived struct { + EmployeeId string `json:"EmployeeId"` + OtherField string `json:"OtherField"` + } + err := json.Unmarshal(natsEvent.Data, &dataReceived) + if err != nil { + return events, fmt.Errorf("error unmarshalling data: %w", err) + } + + // filter the events based on the client's scopes + if !slices.Contains(clientAllowedEntitiesIds, dataReceived.EmployeeId) { + continue + } + + // prepare the data to send to the client + var dataToSend struct { + Id string `json:"id"` + TypeName string `json:"__typename"` + } + dataToSend.Id = dataReceived.EmployeeId + dataToSend.TypeName = "Employee" + + // marshal the data to send to the client + dataToSendMarshalled, err := json.Marshal(dataToSend) + if err != nil { + return events, fmt.Errorf("error marshalling data: %w", err) + } + + // create the new event + newEvent := &nats.NatsEvent{ + Data: dataToSendMarshalled, + Metadata: natsEvent.Metadata, + } + newEvents = append(newEvents, newEvent) + } + return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + +// Interface guards +var ( + _ core.StreamBatchEventHook = (*MyModule)(nil) +) +``` + +### 3. Add the provider configuration to the cosmo router +```yaml +version: "1" + +events: + providers: + nats: + - id: my-nats + url: "nats://localhost:4222" +``` + +## Check authorization at subscription start + +This example will show how to check the authorization at subscription start. + +### 1. Add a subscription to the cosmo streams graphql schema + +The developer will start by adding a subscription to the cosmo streams graphql schema. + +```graphql +type Subscription { + employeeUpdates: Employee! 
@edfs__natsSubscribe(subjects: ["employeeUpdates"], providerId: "my-nats") +} + +type Employee @key(fields: "id", resolvable: false) { + id: Int! @external +} +``` +After publishing the schema, the developer will need to add the module to the cosmo streams engine. + +### 2. Write the custom module + +The developer will need to write the custom module that will be used to check the authorization at subscription start. + +```go +package mymodule + +import ( + "encoding/json" + "slices" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" +) + +func init() { + // Register your module here and it will be loaded at router start + core.RegisterModule(&MyModule{}) +} + +type MyModule struct {} + +func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error { + // check if the provider is nats + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return nil + } + + // check if the provider id is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" { + return nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "employeeUpdates" { + return nil + } + + // check if the client is authenticated + if ctx.Authentication() == nil { + // if the client is not authenticated, return an error + return &StreamHookError{ + HttpError: core.HttpError{ + Code: http.StatusUnauthorized, + Message: "client is not authenticated", + }, + CloseSubscription: true, + } + } + + // check if the client is allowed to subscribe to the stream + clientAllowedEntitiesIds, found := ctx.Authentication().Claims()["readEmployee"] + if !found { + return &StreamHookError{ + HttpError: core.HttpError{ + Code: http.StatusForbidden, + Message: "client is not allowed to read employees", + }, + CloseSubscription: true, + } + } + + 
return nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + +// Interface guards +var ( + _ core.SubscriptionOnStartHandler = (*MyModule)(nil) +) +``` + +### 3. Add the provider configuration to the cosmo router +```yaml +version: "1" + +events: + providers: + nats: + - id: my-nats + url: "nats://localhost:4222" +``` + +### 4. Build the cosmo router with the custom module + +Build and run the router with the custom module added. + +# Outlook + +## Using AsyncAPI for Event Data Structure + +We could use AsyncAPI specifications to define the event data structure and generate the Go structs automatically. This would make the development of custom modules easier and more maintainable. +We could also generate the AsyncAPI specification from the schema and the events data, to make it easier for external systems to use the events published by cosmo streams engine. + +## Generate hooks from AsyncAPI specifications + +Building on the AsyncAPI integration, we could allow the user to define their streams using AsyncAPI and generate fully typesafe hooks with all events structures generated from the AsyncAPI specification. 
\ No newline at end of file diff --git a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go index 3473ad212d..6abb2c062e 100644 --- a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go @@ -18,7 +18,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in storage.Set(employeeID, isAvailable) err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), - Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID)), + Event: nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))}, }) if err != nil { @@ -26,7 +26,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in } err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), - Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID)), + Event: nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))}, }) if err != nil { diff --git a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go index 2f8ea33149..82a0a7e9f2 100644 --- a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go @@ -21,7 +21,7 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood if r.NatsPubSubByProviderID["default"] != nil { err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: myNatsTopic, - Data: []byte(payload), + Event: nats.Event{Data: []byte(payload)}, }) if err != nil { return nil, err @@ -34,7 +34,7 @@ func (r 
*mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood if r.NatsPubSubByProviderID["my-nats"] != nil { err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ Subject: defaultTopic, - Data: []byte(payload), + Event: nats.Event{Data: []byte(payload)}, }) if err != nil { return nil, err diff --git a/rfc/cosmo-streams-v1.md b/rfc/cosmo-streams-v1.md new file mode 100644 index 0000000000..2a7cd761f1 --- /dev/null +++ b/rfc/cosmo-streams-v1.md @@ -0,0 +1,1079 @@ +# RFC Cosmo Streams V1 + +Based on customer feedback, we've identified the need for more customizable stream behavior. The key areas for customization include: +- **Authorization**: Implementing authorization checks at the start of subscriptions +- **Initial message**: Sending an initial message to clients upon subscription start +- **Data mapping**: Transforming events data from the format that could be used by the external system to/from Federation compatible Router events +- **Event filtering**: Filtering events using custom logic + +Let's explore how we can address each of these requirements. + +## Authorization + +To support authorization, we need a hook that enables the following key decisions: +- Whether the client or user is authorized to initiate the subscription +- Which topics the client is permitted to subscribe to +- Whether the client is allowed to consume an event from the stream (covered by the Event Filtering hook) + +Additionally, a similar mechanism is required for non-stream subscriptions, allowing: +- Custom JWT validation logic (e.g., expiration checks, signature verification, secret handling) +- The ability to reject unauthenticated or unauthorized requests and close the subscription accordingly + +We already allow some customization using `RouterOnRequestHandler`, but it has no access to the stream data. To access this data, we need to add a new hook that will be called immediately before the subscription starts. 
+ +### Example: Check if the client is allowed to subscribe to the stream + +```go +// the interfaces/structs are reported partially to make the example more readable +// the full new interfaces/structs are available in the appendix 1 + +// This is the new hook that will be called once at subscription start +type SubscriptionOnStartHandler interface { + SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +// already defined in the provider package +type NatsSubscriptionEventConfiguration struct { + ProviderID string `json:"providerId"` + Subjects []string `json:"subjects"` + StreamConfiguration *StreamConfiguration `json:"streamConfiguration,omitempty"` +} + +type StreamHookError struct { + HttpError core.HttpError + CloseSubscription bool +} + +type MyModule struct {} + +// This is a custom function that will be used to check if the client is allowed to subscribe to the stream +func customCheckIfClientIsAllowedToSubscribe(ctx SubscriptionOnStartHookContext) bool { + // check if the field name is the one expected by the module + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdates" { + return true + } + + // get the specific configuration for the provider to make more advanced checks + cfg, ok := ctx.SubscriptionEventConfiguration().(*NatsSubscriptionEventConfiguration) + if !ok { + return true + } + + providerId := cfg.ProviderID + auth := ctx.RequestContext().Authentication() + + // add checks here on client authentication scopes, provider ID, etc. 
+ + return false +} + +// This is the new hook that will be called once at subscription start +func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error { + // check if the client is allowed to subscribe to the stream + if !customCheckIfClientIsAllowedToSubscribe(ctx) { + // if not, return an error to prevent the subscription from starting + return StreamHookError{ + HttpError: core.NewHttpGraphqlError( + "you should be an admin to subscribe to this or only subscribe to public subscriptions!", + "UNAUTHORIZED", + http.StatusUnauthorized, + ), CloseSubscription: true, + } + } + return nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Proposal + +Add a new hook to the subscription lifecycle, `SubscriptionOnStartHandler`, that will be called once at subscription start. + +The hook arguments are: +* `ctx SubscriptionOnStartHookContext`: The subscription context, which contains the request context and, optionally, the subscription event configuration, and a method to emit the event to the stream + +`RequestContext` already exists and requires no changes, but `SubscriptionEventConfiguration` is new. + +The hook should return an error if the client is not allowed to subscribe to the stream, preventing the subscription from starting. +The hook should return `nil` if the client is allowed to subscribe to the stream, allowing the subscription to proceed. + +The hook can return a `StreamHookError` to customize the error messages and the behavior on the subscription. + +I evaluated the possibility of adding the `SubscriptionContext` to the request context and using it within one of the existing hooks, but it would be difficult to build the subscription context without executing the pubsub code. + +The `SubscriptionEventConfiguration()` contains the subscription configuration as used by the provider. 
This allows the hooks system to be provider-agnostic, so adding a new provider will not require changes to the hooks system. To use specific fields, the hook can cast the configuration to the specific type for the provider. +The `WriteEvent()` method is new and allows the hook to emit the event to the stream. + +## Initial Message + +When starting a subscription, the client sends a query to the server containing the operation name and variables. The client must then wait for the broker to send the initial message. This waiting period can lead to a poor user experience, as the client cannot display anything until the initial message is received. To address this, we can emit an initial message on subscription start. + +To emit an initial message on subscription start, we need access to the stream context (to get the provider type and ID) and the query that the client sent. The variables are particularly important, as they allow the module to use them in the initial message. For example, if someone starts a subscription with employee ID 100 as a variable, the custom module can include that ID in the initial message. 
+ +### Example + +```go +// the interfaces/structs are reported partially to make the example more readable +// the full new interfaces/structs are available in the appendix 1 + +// This is the new hook that will be called once at stream start +type SubscriptionOnStartHandler interface { + SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +// each provider will have its own event type that implements the StreamEvent interface +type NatsEvent struct { + Data json.RawMessage + Metadata map[string]string +} + +type MyModule struct {} + +// This is the new hook that will be called once at subscription start +func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error { + // get the operation name and variables that we need + opName := ctx.RequestContext().Operation().Name() + opVarId := ctx.RequestContext().Operation().Variables().GetInt("id") + + // check if the provider ID is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-provider-id" { + return nil + } + + //check if the provider type is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return nil + } + + // check if the operation name is the one expected by the module + if opName == "employeeSub" { + // create the event to emit using the operation variables + evt := &NatsEvent{ + Data: []byte(fmt.Sprintf("{\"id\": \"%d\", \"__typename\": \"Employee\"}", opVarId)), + Metadata: map[string]string{ + "entity-id": fmt.Sprintf("%d", opVarId), + }, + } + // emit the event to the stream, that will be received only by the client that subscribed to the stream + ctx.WriteEvent(evt) + } + return nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Proposal + +Using the new `SubscriptionOnStart` hook that we introduced for the previous 
requirement, we can emit the initial message on subscription start. We will also need access to operation variables, which are currently not available in the request context. + +To emit the message, I propose adding a new method to the stream context, `WriteEvent`, which will emit the event to the stream at the lowest level. The message will pass through all hooks, making it behave like any other event received from the provider. The message will be received only by the client that subscribed to the stream, and not by the other clients that subscribed to the same stream. + +The `StreamEvent` contains the data as used by the provider. This allows the hooks system to be provider-agnostic, so adding a new provider will not require changes to the hooks system. To use events, the hook has to cast the event to the specific type for the provider. + +This change will require adding a new type in each provider package to represent the event with additional fields (metadata, etc.). This is a significant change, but it is necessary to support additional data in events, anyway, even if we don't expose them to the custom modules. + +Emitting the initial message with this hook ensures that the client will receive the message before the first event from the provider is received. + +## Data Mapping + +The current approach for emitting and reading data from the stream is not flexible enough. We need to be able to map data from an external format to the internal format, and vice versa. + +Also, different providers can have different additional fields other than the message body. + +As an example: +- NATS provider can have a `Metadata` field +- Kafka provider can have a `Headers` and `Key` fields + +And this additional fields could be an important part of integrating with external systems. 
+ +### Example 1: Rewrite the event received from the provider to a format that is usable by Cosmo streams + +```go +// the interfaces/structs are reported partially to make the example more readable +// the full new interfaces/structs are available in the appendix 1 + +// each provider will have its own event type that implements the StreamEvent interface +type NatsEvent struct { + Data json.RawMessage + Metadata map[string]string +} +type KafkaEvent struct { + Key []byte + Data json.RawMessage + Headers map[[]byte][]byte +} + +// StreamBatchEventHook processes a batch of inbound stream events +// +// Return: +// - empty slice: drop all events. +// - non-empty slice: emit those events (can grow, shrink, or reorder the batch). +// err != nil: abort the subscription with an error. +type StreamBatchEventHook interface { + OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +type MyModule struct {} + +// This is the new hook that will be called each time a batch of events is received from the provider +func (m *MyModule) OnStreamEvents( + ctx StreamBatchEventHookContext, + events []StreamEvent, +) ([]StreamEvent, error) { + // check if the provider ID is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-provider-id" { + return events, nil + } + + // check if the provider type is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "topic-with-internal-data-format" { + return events, nil + } + + // create a new slice of events that we will return with the events with the new format + newEvents := make([]StreamEvent, 0, len(events)) + for _, evt := range events { + // check if the event is the one expected by the 
module + if natsEvent, ok := evt.(*NatsEvent); ok { + // here you can umarshal the old data and map it to the new format + // for example: + // var dataReceived struct { + // EmployeeName string `json:"EmployeeName"` + // } + // err := json.Unmarshal(natsEvent.Data, &dataReceived) + + // if we have to extract the data from the metadata fields, we can do it like this: + entityId := natsEvent.Metadata["entity-id"] + entityType := natsEvent.Metadata["entity-type"] + // and prepare the new event with the data inside + newDataFormat, _ := json.Marshal(map[string]string{ + "id": entityId, + "name": dataReceived.EmployeeName, + "__typename": entityType, + }) + + // create the new event + newEvent := &NatsEvent{ + Data: newDataFormat, + Metadata: natsEvent.Metadata, + } + + // or for Kafka we would have something like: + // newEvent := &KafkaEvent{ + // Key: kafkaEvent.Key, + // Data: newDataFormat, + // Headers: kafkaEvent.Headers, + // } + + // add the new event to the slice of events to return + newEvents = append(newEvents, newEvent) + continue + } + // add the original event to the slice of events to return + newEvents = append(newEvents, evt) + } + + return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Example 2: Rewrite the event before emitting it to the provider to a format that is usable by external systems + +```go +// the interfaces/structs are reported partially to make the example more readable +// the full new interfaces/structs are available in the appendix 1 + +// StreamPublishEventHook processes a batch of outbound stream events +// +// Return: +// - empty slice: drop all events. +// - non-empty slice: emit those events (can grow, shrink, or reorder the batch). +// err != nil: abort the subscription with an error. 
+type StreamPublishEventHook interface { + OnPublishEvents(ctx StreamPublishEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +// each provider will have its own event type that implements the StreamEvent interface +type NatsEvent struct { + Data json.RawMessage + Metadata map[string]string +} + +type MyModule struct {} + +// This is the new hook that will be called each time a batch of events is going to be sent to the provider +func (m *MyModule) OnPublishEvents( + ctx StreamPublishEventHookContext, + events []StreamEvent, +) ([]StreamEvent, error) { + // check if the provider ID is the one expected by the module + if ctx.PublishEventConfiguration().ProviderID() != "my-provider-id" { + return events, nil + } + + // check if the provider type is the one expected by the module + if ctx.PublishEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.PublishEventConfiguration().(*nats.PublishAndRequestEventConfiguration) + if natsConfig.Subject != "topic-with-internal-data-format" { + return events, nil + } + + // create a new slice of events that we will return with the events with the new format + newEvents := make([]StreamEvent, 0, len(events)) + for _, evt := range events { + // check if the event is the one expected by the module + if natsEvent, ok := evt.(*NatsEvent); ok { + // here you can umarshal the old data and map it to the new format + // for example: + // var dataReceived struct { + // EmployeeId string `json:"EmployeeId"` + // } + // err := json.Unmarshal(natsEvent.Data, &dataReceived) + + // create the new event + newEvent := &NatsEvent{ + Data: dataToSendMarshalled, + Metadata: map[string]string{ + "entity-id": dataReceived.Id, + "entity-domain": "employee", + }, + } + + // add the new event to the slice of events to return + newEvents = append(newEvents, newEvent) + continue + } + newEvents = append(newEvents, evt) + } 
+ return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Proposal + +Add two new hooks to the stream lifecycle: `StreamBatchEventHook` and `StreamPublishEventHook`. +The `StreamBatchEventHook` will be called each time a batch of events is received from the provider, making it possible to rewrite, filter or split the event data to a format usable within Cosmo streams. +The `StreamPublishEventHook` will be called each time a batch of events is going to be sent to the provider, making it possible to rewrite, filter or split the event data to a format usable by external systems. + +The hook arguments of `StreamBatchEventHook` are: +* `ctx StreamBatchEventHookContext`: The stream context, which contains the provider ID and the subscription configuration +* `events []StreamEvent`: The events received from the provider + +The hook will return a new slice of events that will be used to emit the events to the client. +The hook will also return an error if one of the events cannot be processed, preventing the events from being processed. + +The hook arguments of `StreamPublishEventHook` are: +* `ctx StreamPublishEventHookContext`: The stream context, which contains the provider ID and the publish configuration +* `events []StreamEvent`: The events that are going to be sent to the provider + +The hook will return a new slice of events that will be used to emit the events to the provider. +The hook will also return an error if one of the events cannot be processed, preventing the events from being processed. + +#### Do we need two new hooks? + +Another possible solution for mapping outward data would be to use the existing middleware hooks `RouterOnRequestHandler` or `RouterMiddlewareHandler` to intercept the mutation, access the stream context, and emit the event to the stream. 
However, this would require exposing a stream context in the request lifecycle, which is difficult. It would also require coordination to ensure that an event emitted on the stream is sent only after the subscription starts.
+
+Additionally, this solution is not usable on the subscription side of streams:
+- The middleware hook is linked to the request lifecycle, making it difficult to use it to rewrite event data
+- When we use the streams feature internally, we will still need to provide a way to rewrite event data, requiring a new hook in the subscription lifecycle
+
+Therefore, I believe the best solution is to add new hooks to the stream lifecycle.
+
+## Event Filtering
+
+We need to allow customers to filter events based on custom logic. We currently only provide declarative filters, which are quite limited.
+The event filtering hook will also be useful to implement the authorization logic at the events level.
+
+### Example: Filter events based on stream configuration and client's scopes
+
+```go
+// the interfaces/structs are reported partially to make the example more readable
+// the full new interfaces/structs are available in the appendix 1
+
+// StreamBatchEventHook processes a batch of inbound stream events.
+//
+// Return:
+// - empty slice: drop all events.
+// - non-empty slice: emit those events (can grow, shrink, or reorder the batch).
+// err != nil: abort the subscription with an error.
+type StreamBatchEventHook interface { + OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +// each provider will have its own event type that implements the StreamEvent interface +type NatsEvent struct { + Data json.RawMessage + Metadata map[string]string +} + +type MyModule struct {} + +// This is the new hook that will be called each time a batch of events is received from the provider +func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) { + // check if the provider ID is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-provider-id" { + return events, nil + } + + // check if the provider type is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "topic-with-internal-data-format" { + return events, nil + } + + // create a new slice of events that we will return with the events that are allowed to be received by the client + newEvents := make([]StreamEvent, 0, len(events)) + + + if ctx.RequestContext().Authentication() == nil { + // if the client is not authenticated, return no events + return newEvents, nil + } + + // get the client's allowed entities IDs + clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + if !found { + // if the client doesn't have allowed entities IDs, return the original events + return newEvents, nil + } + + for _, evt := range events { + // check if the event is the one expected by the module + if natsEvent, ok := evt.(*NatsEvent); ok { + // check the entity ID in the metadata + idHeader, ok := natsEvent.Metadata["entity-id"] + if !ok { + continue + } + // check if 
the entity ID is in the client's allowed entities IDs + if slices.Contains(clientAllowedEntitiesIds, idHeader) { + // add the event to the slice of events to return because the client is allowed to receive it + newEvents = append(newEvents, evt) + } + } + } + return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} +``` + +### Proposal + +We can use the new `StreamBatchEventHook` to filter events based on the stream configuration and the client's scopes. + +The hook arguments are: +* `ctx StreamBatchEventHookContext`: The stream context, which contains the ID of the stream and the request context +* `events []StreamEvent`: The events received from the provider or the events that are going to be sent to the provider + +The hook will return a new slice of events that will be used to emit the events to the client or to the provider. +The hook will also return an error if one of the events cannot be processed, preventing the event from being processed. + +## Architecture + +With this proposal, we will add two new hooks to stream lifecycles and other hooks to the subscription lifecycle. + +### Subscription Lifecycle +``` +Start subscription + │ + └─▶ core.SubscriptionOnStartHandler (Early return, Custom Authentication Logic) + │ + └─▶ "Subscription started" +``` + +### Stream Lifecycle + +``` +One or more batched events are received from the provider + │ + └─▶ core.StreamBatchEventHook (Data mapping, Filtering) + │ + └─▶ "Deliver events to client" + +One or more batched events are published to the provider + │ + └─▶ core.StreamPublishEventHook (Data mapping, Filtering) + │ + └─▶ "Send event to provider" +``` + +### Data Flow + +We will need to change the format of the event data sent within the router. 
Today we use the data that will be sent to the provider directly, but we will need to add a structure where we can include additional fields (metadata, etc.) in the event. + +## Implementation Details + +The implementation of this solution will only require changes in the Cosmo repository, without any changes to the engine. This implementation will not require additional changes to the hooks structures each time a new provider is added. + +## Considerations and Risks + +- All hooks could be called in parallel, so we need to handle concurrency carefully +- All hook implementations could raise a panic, so we need to implement proper error handling +- Especially the casting of the event to the specific type for the provider could raise a panic if the event is not of the expected type and the developer is not using the type check +- We should add metrics to track how much time is spent in each hook, to help customers identify slow hooks + +## Development workflow of subscription with custom modules + +Lets build an example of how the development workflow would look like for a developer that want to add a custom module to the cosmo streams engine. The idea is to build a module that will be used to subscribe to the `employeeUpdates` subject and filter the events based on the client's scopes and remapping the messages as they are expected from the `Employee` type. + +I'll show the workflow for a developer that wants to customize the subscription, but the same workflow can be applied to the mutation. + +### Add a subscription to the cosmo streams graphql schema + +The developer will start by adding a subscription to the cosmo streams graphql schema. +```graphql +type Subscription { + employeeUpdates(): Employee! @edfs__natsSubscribe(subjects: ["employeeUpdates"], providerId: "my-nats") +} + +type Employee @key(fields: "id", resolvable: false) { + id: Int! @external +} +``` +After publishing the schema, the developer will need to add the module to the cosmo streams engine. 
+ +### 2. Write the custom module + +The developer will need to write the custom module that will be used to subscribe to the `employeeUpdates` subject and filter the events based on the client's scopes and remapping the messages as they are expected from the `Employee` type. + +```go +package mymodule + +import ( + "encoding/json" + "slices" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" +) + +func init() { + // Register your module here and it will be loaded at router start + core.RegisterModule(&MyModule{}) +} + +type MyModule struct {} + +func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []core.StreamEvent) ([]core.StreamEvent, error) { + // check if the provider is nats + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + // check if the provider id is the one expected by the module + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" { + return events, nil + } + + // check if the subject is the one expected by the module + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "employeeUpdates" { + return events, nil + } + + // check if the client is authenticated + if ctx.RequestContext().Authentication() == nil { + // if the client is not authenticated, return no events + return events, nil + } + + // check if the client is allowed to subscribe to the stream + clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + if !found { + return events, fmt.Errorf("client is not allowed to subscribe to the stream") + } + + newEvents := make([]core.StreamEvent, 0, len(events)) + + for _, evt := range events { + natsEvent, ok := evt.(*nats.NatsEvent); + if !ok { + newEvents = append(newEvents, evt) + continue + } + + // decode the event data coming from the provider + var dataReceived struct { + 
EmployeeId string `json:"EmployeeId"` + OtherField string `json:"OtherField"` + } + err := json.Unmarshal(natsEvent.Data, &dataReceived) + if err != nil { + return events, fmt.Errorf("error unmarshalling data: %w", err) + } + + // filter the events based on the client's scopes + if !slices.Contains(clientAllowedEntitiesIds, dataReceived.EmployeeId) { + continue + } + + // prepare the data to send to the client + var dataToSend struct { + Id string `json:"id"` + TypeName string `json:"__typename"` + } + dataToSend.Id = dataReceived.EmployeeId + dataToSend.TypeName = "Employee" + + // marshal the data to send to the client + dataToSendMarshalled, err := json.Marshal(dataToSend) + if err != nil { + return events, fmt.Errorf("error marshalling data: %w", err) + } + + // create the new event + newEvent := &nats.NatsEvent{ + Data: dataToSendMarshalled, + Metadata: natsEvent.Metadata, + } + newEvents = append(newEvents, newEvent) + } + return newEvents, nil +} + +func (m *MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + +// Interface guards +var ( + _ core.StreamBatchEventHook = (*MyModule)(nil) +) +``` + +### 3. Add the provider configuration to the cosmo router +```yaml +version: "1" + +events: + providers: + nats: + - id: my-nats + url: "nats://localhost:4222" +``` + +### 4. Build the cosmo router with the custom module + +Build and run the router with the custom module added. 
+ +## Appendix 1, new data structures + +```go +// NEW HOOKS + +// SubscriptionOnStartHandler is a hook that is called once at subscription start +// it is used to validate if the client is allowed to subscribe to the stream +// if returns an error, the subscription will not start +type SubscriptionOnStartHandler interface { + SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +// StreamBatchEventHook processes a batch of inbound stream events +// +// Return: +// - empty slice: drop all events. +// - non-empty slice: emit those events (can grow, shrink, or reorder the batch). +// err != nil: abort the subscription with an error. +type StreamBatchEventHook interface { + OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +// StreamPublishEventHook processes a batch of outbound stream events +// +// Return: +// - empty slice: drop all events. +// - non-empty slice: emit those events (can grow, shrink, or reorder the batch). +// err != nil: abort the subscription with an error. 
+type StreamPublishEventHook interface { + OnPublishEvents(ctx StreamPublishEventHookContext, events []StreamEvent) ([]StreamEvent, error) +} + +// NEW INTERFACES +type SubscriptionEventConfiguration interface { + ProviderID() string + ProviderType() string + RootFieldName() string // the root field name of the subscription in the schema +} + +type PublishEventConfiguration interface { + ProviderID() string + ProviderType() string + RootFieldName() string // the root field name of the mutation in the schema +} + +type StreamEvent interface {} + +type StreamBatchEventHookContext interface { + RequestContext() RequestContext + SubscriptionEventConfiguration() SubscriptionEventConfiguration +} + +type StreamPublishEventHookContext interface { + RequestContext() RequestContext + PublishEventConfiguration() PublishEventConfiguration +} + +type SubscriptionOnStartHookContext interface { + RequestContext() RequestContext + SubscriptionEventConfiguration() SubscriptionEventConfiguration + WriteEvent(event core.StreamEvent) +} + +// ALREADY EXISTING INTERFACES THAT WILL BE UPDATED +type OperationContext interface { + Name() string + // the variables are currently not available, so we need to add them here + Variables() *astjson.Value +} + +// NEW STRUCTURES +// StreamHookError is used to customize the error messages and the behavior +type StreamHookError struct { + HttpError core.HttpError + CloseSubscription bool +} + +func (e StreamHookError) Error() string { + return e.HttpError.Message() +} + +// STRUCTURES TO BE ADDED TO PUBSUB PACKAGE +type ProviderType string +const ( + ProviderTypeNats ProviderType = "nats" + ProviderTypeKafka ProviderType = "kafka" + ProviderTypeRedis ProviderType = "redis" +} + +``` + +## Appendix 2, Using AsyncAPI for Event Data Structure + +As a side note, it is important to find ways to document the data that is arriving and going out of the cosmo streams engine. 
This could allow some automatic code generation starting from the schema and the events data. +As an example, we are going to explore how AsyncAPI could be used to generate the data structures for the custom modules and assure the messages format. + +### Example: AsyncAPI Integration for Custom Module Development + +We propose integrating AsyncAPI specifications with Cosmo streams to generate type-safe Go structs that can be used in custom modules. This would significantly improve the developer experience by providing: + +1. **Type Safety**: Generated structs prevent runtime errors from incorrect field access +2. **Documentation**: AsyncAPI specs serve as living documentation for event schemas +3. **Code Generation**: Automatic generation of Go structs from AsyncAPI specifications +4. **IDE Support**: Better autocomplete and error detection in development environments + +### AsyncAPI Specification Example + +So if we have as an example the following AsyncAPI specification: + +```yaml +# employee-events.asyncapi.yaml +asyncapi: 3.0.0 +info: + title: Employee Events API + version: 1.0.0 + description: Events related to employee updates in the system + +channels: + externalSystemEmployeeUpdates: + messages: + EmployeeUpdated: + $ref: '#/components/messages/EmployeeUpdated' + +components: + messages: + ExternalSystemEmployeeUpdated: + name: ExternalSystemEmployeeUpdated + title: External System Employee Updated Event + summary: Sent when an employee is updated in the external system + contentType: application/json + payload: + $ref: '#/components/schemas/ExternalSystemEmployeeFormat' + + schemas: + ExternalSystemEmployeeFormat: + type: object + description: Employee data as received from external systems + properties: + EmployeeId: + type: string + description: Unique identifier for the employee + EmployeeName: + type: string + description: Full name of the employee + EmployeeEmail: + type: string + format: email + description: Email address of the employee + 
OtherField: + type: string + description: Additional field from external system + required: + - EmployeeId + - EmployeeName + - EmployeeEmail +``` + +### Code Generation Workflow + +We could provide a CLI command to WGC to generate the Go structs from AsyncAPI specifications: + +```bash +# Generate Go structs from AsyncAPI spec +wgc streams generate -i employee-events.asyncapi.yaml -o ./generated/events.go -p events +``` + +Before generating the code, we could add to the data that cosmo streams is expecting to receive and send. +```yaml +# cosmo-streams-events.asyncapi.yaml +asyncapi: 3.0.0 +info: + title: Cosmo Streams Employee Events API + version: 1.0.0 + +channels: + cosmoStreamsEmployeeUpdates: + messages: + CosmoStreamsEmployeeUpdated: + $ref: '#/components/messages/CosmoStreamsEmployeeUpdated' + +components: + messages: + CosmoStreamsEmployeeUpdated: + name: CosmoStreamsEmployeeUpdated + title: Cosmo Streams Employee Updated Event + summary: Event published when updating an employee in the cosmo streams + contentType: application/json + payload: + $ref: '#/components/schemas/EmployeeInternalFormat' + + schemas: + CosmoStreamsEmployeeUpdated: + type: object + description: Employee data as used internally by Cosmo streams + properties: + id: + type: string + description: Unique identifier for the employee + name: + type: string + description: Full name of the employee + email: + type: string + format: email + description: Email address of the employee + required: + - id + - __typename +``` + +This command would be a wrapper around asyncapi modelina, and with some additional logic to extract the internal events format from the schema SDL. 
+ +This would generate a second async api specification and Go code like: + +```go +// generated/events.go +package events + +import ( + "encoding/json" + "time" +) + +// ExternalSystemEmployeeUpdated represents employee data as received from external systems +type ExternalSystemEmployeeUpdated struct { + EmployeeId string `json:"EmployeeId"` + EmployeeName string `json:"EmployeeName"` + EmployeeEmail string `json:"EmployeeEmail"` + OtherField string `json:"OtherField"` +} + +// EmployeeInternalFormat represents employee data as used internally by Cosmo streams +type CosmoStreamsEmployeeUpdated struct { + Id string `json:"id"` + Name string `json:"name"` + Email string `json:"email"` +} +``` + +We could than encourage the developers to add conversions in a file in the same package of the generated file, like so: + +```go +// generated/events.go +package events + +import ( + "encoding/json" + "time" +) + +func ExternalSystemEmployeeUpdatedToCosmoStreamsEmployeeUpdated(e *ExternalSystemEmployeeUpdated) *CosmoStreamsEmployeeUpdated { + return &CosmoStreamsEmployeeUpdated{ + Id: e.EmployeeId, + Name: e.EmployeeName, + Email: e.EmployeeEmail, + } +} + +``` + +Also, external systems could use the generated async api specification to generate the code for the events that they are sending/receiving to/from cosmo streams. 
+ +### Enhanced Custom Module Development + +With generated structs, the custom module code becomes more maintainable and type-safe: + +```go +package mymodule + +import ( + "encoding/json" + "fmt" + "slices" + + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" + "your-project/generated/genevents" +) + +type MyModule struct {} + +func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []core.StreamEvent) ([]core.StreamEvent, error) { + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + return events, nil + } + + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" { + return events, nil + } + + natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) + if natsConfig.Subjects[0] != "employeeUpdates" { + return events, nil + } + + clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + if !found { + return events, fmt.Errorf("client is not allowed to subscribe to the stream") + } + + for _, evt := range events { + natsEvent, ok := evt.(*nats.NatsEvent); + if !ok { + newEvents = append(newEvents, evt) + continue + } + + // Use generated struct for type-safe deserialization + var dataReceived genevents.ExternalSystemEmployeeUpdated + err := json.Unmarshal(natsEvent.Data, &dataReceived) + if err != nil { + return events, fmt.Errorf("error unmarshalling data: %w", err) + } + + // Convert to internal format using generated method + dataToSend := genevents.ExternalSystemEmployeeUpdatedToCosmoStreamsEmployeeUpdated(&dataReceived) + + // Marshal using the generated struct + dataToSendMarshalled, err := json.Marshal(dataToSend) + if err != nil { + return events, fmt.Errorf("error marshalling data: %w", err) + } + + // Create new event + newEvent := &nats.NatsEvent{ + Data: dataToSendMarshalled, + } + newEvents = append(newEvents, newEvent) + } + return newEvents, nil +} + +func (m 
*MyModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + ID: myModuleID, + Priority: 1, + New: func() core.Module { + return &MyModule{} + }, + } +} + +var _ core.StreamBatchEventHook = (*MyModule)(nil) +``` + +### Considerations + +The developers would need to regenerate the code each time the AsyncAPI specification changes or the schema SDL changes. + +### Outlook + +In a second step, we could: +- allow the user to define their streams using AsyncAPI +- generate fully typesafe hooks with all events structures generated from the AsyncAPI specification \ No newline at end of file diff --git a/router-tests/go.mod b/router-tests/go.mod index ffe6d65377..479d44590c 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -27,7 +27,7 @@ require ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e github.com/wundergraph/cosmo/router v0.0.0-20250912064154-106e871ee32e github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel/sdk v1.36.0 go.opentelemetry.io/otel/sdk/metric v1.36.0 @@ -209,7 +209,7 @@ replace ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects => ../demo/pkg/subgraphs/projects github.com/wundergraph/cosmo/router => ../router github.com/wundergraph/cosmo/router-plugin => ../router-plugin -// github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 +//github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 ) replace github.com/hashicorp/consul/sdk => github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 diff --git a/router-tests/go.sum b/router-tests/go.sum index c6ec48801f..947f5ba76c 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -354,6 +354,8 @@ github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 
h1:EzfKHQoT github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 h1:VCfCX/xmpBGQLhTHJMHLugzJrXJk/smjLRAEruCI0HY= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb h1:stBTAle5FyytsTNxYeCwNzYlyhKzlS4he6f7/y6O3qE= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router-tests/modules/start-subscription/module.go b/router-tests/modules/start-subscription/module.go new file mode 100644 index 0000000000..fd5a9e0088 --- /dev/null +++ b/router-tests/modules/start-subscription/module.go @@ -0,0 +1,61 @@ +package start_subscription + +import ( + "net/http" + + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/core" +) + +const myModuleID = "startSubscriptionModule" + +type StartSubscriptionModule struct { + Logger *zap.Logger + Callback func(ctx core.SubscriptionOnStartHookContext) error + CallbackOnOriginResponse func(response *http.Response, ctx core.RequestContext) *http.Response +} + +func (m *StartSubscriptionModule) Provision(ctx *core.ModuleContext) error { + // Assign the logger to the module for non-request related logging + m.Logger = ctx.Logger + + return nil +} + +func (m *StartSubscriptionModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHookContext) error { + + m.Logger.Info("SubscriptionOnStart Hook has been run") + + if m.Callback != nil { + return 
m.Callback(ctx) + } + + return nil +} + +func (m *StartSubscriptionModule) OnOriginResponse(response *http.Response, ctx core.RequestContext) *http.Response { + if m.CallbackOnOriginResponse != nil { + return m.CallbackOnOriginResponse(response, ctx) + } + + return response +} + +func (m *StartSubscriptionModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &StartSubscriptionModule{} + }, + } +} + +// Interface guard +var ( + _ core.SubscriptionOnStartHandler = (*StartSubscriptionModule)(nil) + _ core.EnginePostOriginHandler = (*StartSubscriptionModule)(nil) +) diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go new file mode 100644 index 0000000000..ad286d54ef --- /dev/null +++ b/router-tests/modules/start_subscription_test.go @@ -0,0 +1,664 @@ +package module_test + +import ( + "errors" + "net/http" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + start_subscription "github.com/wundergraph/cosmo/router-tests/modules/start-subscription" + "go.uber.org/zap/zapcore" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" +) + +func TestStartSubscriptionHook(t *testing.T) { + t.Parallel() + + t.Run("Test StartSubscription hook is called", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + 
RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 3, + } + subscriptionOneID, err := client.Subscribe(&subscriptionOne, vars, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test StartSubscription write event works", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { + return nil + } + ctx.WriteEvent(&kafka.Event{ + Key: []byte("1"), + Data: []byte(`{"id": 1, "__typename": "Employee"}`), + }) + return nil + }, + }, + }, + } + + 
testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 3, + } + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, vars, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart 
Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test StartSubscription with close to true", func(t *testing.T) { + t.Parallel() + + callbackCalled := make(chan bool) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + callbackCalled <- true + return core.NewStreamHookError(nil, "subscription closed", http.StatusOK, "") + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 3, + } + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionArgsCh := make(chan kafkaSubscriptionArgs, 1) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, vars, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + 
<-callbackCalled + xEnv.WaitForSubscriptionCount(0, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + + require.Len(t, subscriptionArgsCh, 1) + subscriptionArgs := <-subscriptionArgsCh + require.Error(t, subscriptionArgs.errValue) + require.Empty(t, subscriptionArgs.dataValue) + }) + }) + + t.Run("Test StartSubscription write event sends event only to the subscription", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + employeeId := ctx.Operation().Variables().GetInt64("employeeID") + if employeeId != 1 { + return nil + } + ctx.WriteEvent(&kafka.Event{ + Data: []byte(`{"id": 1, "__typename": "Employee"}`), + }) + return nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscription struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 3, + } + vars2 
:= map[string]interface{}{ + "employeeID": 1, + } + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionOneArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscription, vars, func(dataValue []byte, errValue error) error { + subscriptionOneArgsCh <- kafkaSubscriptionArgs{ + dataValue: []byte{}, + errValue: errors.New("should not be called"), + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + subscriptionTwoArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionTwoID, err := client.Subscribe(&subscription, vars2, func(dataValue []byte, errValue error) error { + subscriptionTwoArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionTwoID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(2, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionTwoArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 2) + t.Cleanup(func() { + require.Len(t, subscriptionOneArgsCh, 0) + }) + }) + }) + + t.Run("Test StartSubscription error is propagated to the client", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx 
core.SubscriptionOnStartHookContext) error { + return core.NewStreamHookError(errors.New("test error"), "test error", http.StatusLoopDetected, http.StatusText(http.StatusLoopDetected)) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscription struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: $employeeID)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "employeeID": 1, + } + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionOneArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscription, vars, func(dataValue []byte, errValue error) error { + subscriptionOneArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + // Wait for the subscription to be closed + xEnv.WaitForSubscriptionCount(0, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + var graphqlErrs graphql.Errors + require.ErrorAs(t, args.errValue, &graphqlErrs) + statusCode, ok := graphqlErrs[0].Extensions["statusCode"].(float64) + require.True(t, ok, "statusCode is 
not a float64") + require.Equal(t, http.StatusLoopDetected, int(statusCode)) + require.Equal(t, http.StatusText(http.StatusLoopDetected), graphqlErrs[0].Extensions["code"]) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + t.Cleanup(func() { + require.Len(t, subscriptionOneArgsCh, 0) + }) + }) + }) + + t.Run("Test StartSubscription hook is called for engine subscription", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + var subscriptionCountEmp struct { + CountEmp int `graphql:"countEmp(max: $max, intervalMilliseconds: $interval)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "max": 1, + "interval": 200, + } + subscriptionOneID, err := client.Subscribe(&subscriptionCountEmp, vars, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, 
err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test StartSubscription hook is called for engine subscription and write event works", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + ctx.WriteEvent(&core.EngineEvent{ + Data: []byte(`{"data":{"countEmp":1000}}`), + }) + return nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + + var subscriptionCountEmp struct { + CountEmp int `graphql:"countEmp(max: $max, intervalMilliseconds: $interval)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + "max": 0, + "interval": 0, + } + + type subscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionOneArgsCh := make(chan subscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionCountEmp, vars, func(dataValue []byte, errValue error) error { + subscriptionOneArgsCh <- subscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, time.Second*10) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args subscriptionArgs) { + 
require.NoError(t, args.errValue) + require.JSONEq(t, `{"countEmp": 1000}`, string(args.dataValue)) + }) + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args subscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"countEmp": 0}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test StartSubscription hook is called, return StreamHookError, response on OnOriginResponse should still be set", func(t *testing.T) { + t.Parallel() + originResponseCalled := make(chan *http.Response, 1) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "startSubscriptionModule": start_subscription.StartSubscriptionModule{ + Callback: func(ctx core.SubscriptionOnStartHookContext) error { + return core.NewStreamHookError(errors.New("subscription closed"), "subscription closed", http.StatusOK, "NotFound") + }, + CallbackOnOriginResponse: func(response *http.Response, ctx core.RequestContext) *http.Response { + originResponseCalled <- response + return response + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&start_subscription.StartSubscriptionModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + var subscriptionCountEmp struct { + CountEmp int `graphql:"countEmp(max: $max, intervalMilliseconds: $interval)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + vars := map[string]interface{}{ + 
"max": 0, + "interval": 0, + } + + type subscriptionArgs struct { + dataValue []byte + errValue error + } + subscriptionOneArgsCh := make(chan subscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionCountEmp, vars, func(dataValue []byte, errValue error) error { + subscriptionOneArgsCh <- subscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args subscriptionArgs) { + require.Error(t, args.errValue) + require.Empty(t, args.dataValue) + }) + + testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + require.Empty(t, originResponseCalled) + + requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) +} diff --git a/router/.mockery.yml b/router/.mockery.yml index c835d3af50..558bca2185 100644 --- a/router/.mockery.yml +++ b/router/.mockery.yml @@ -13,10 +13,11 @@ template-schema: '{{.Template}}.schema.json' packages: github.com/wundergraph/cosmo/router/pkg/pubsub/datasource: interfaces: - ProviderLifecycle: + Lifecycle: ProviderBuilder: EngineDataSourceFactory: Provider: + SubscriptionEventUpdater: github.com/wundergraph/cosmo/router/pkg/pubsub/nats: interfaces: Adapter: diff --git a/router/core/errors.go b/router/core/errors.go index 7f8df34da2..44e05f327b 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -35,6 +35,7 @@ const ( errorTypeInvalidWsSubprotocol errorTypeEDFSInvalidMessage errorTypeMergeResult + errorTypeStreamHookError ) type ( @@ -89,6 +90,10 @@ func getErrorType(err error) errorType { if errors.As(err, &mergeResultErr) { return errorTypeMergeResult } + var 
streamHookErr *StreamHookError + if errors.As(err, &streamHookErr) { + return errorTypeStreamHookError + } return errorTypeUnknown } diff --git a/router/core/executor.go b/router/core/executor.go index 437634ea8c..e29ed7682b 100644 --- a/router/core/executor.go +++ b/router/core/executor.go @@ -35,6 +35,8 @@ type ExecutorConfigurationBuilder struct { subscriptionClientOptions *SubscriptionClientOptions instanceData InstanceData + + subscriptionHooks subscriptionHooks } type Executor struct { @@ -216,7 +218,7 @@ func (b *ExecutorConfigurationBuilder) buildPlannerConfiguration(ctx context.Con routerEngineCfg.Execution.EnableSingleFlight, routerEngineCfg.Execution.EnableNetPoll, b.instanceData, - ), b.logger) + ), b.logger, b.subscriptionHooks) // this generates the plan config using the data source factories from the config package planConfig, providers, err := loader.Load(engineConfig, subgraphs, routerEngineCfg, pluginsEnabled) diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index b73742a91d..70f3e7917c 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -31,8 +31,9 @@ import ( ) type Loader struct { - ctx context.Context - resolver FactoryResolver + ctx context.Context + resolver FactoryResolver + subscriptionHooks subscriptionHooks // includeInfo controls whether additional information like type usage and field usage is included in the plan de includeInfo bool logger *zap.Logger @@ -190,12 +191,13 @@ func (d *DefaultFactoryResolver) InstanceData() InstanceData { return d.instanceData } -func NewLoader(ctx context.Context, includeInfo bool, resolver FactoryResolver, logger *zap.Logger) *Loader { +func NewLoader(ctx context.Context, includeInfo bool, resolver FactoryResolver, logger *zap.Logger, subscriptionHooks subscriptionHooks) *Loader { return &Loader{ - ctx: ctx, - resolver: resolver, - includeInfo: includeInfo, - logger: logger, + ctx: ctx, + resolver: resolver, + includeInfo: includeInfo, + 
logger: logger, + subscriptionHooks: subscriptionHooks, } } @@ -416,6 +418,10 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod } } + subscriptionOnStartFns := make([]graphql_datasource.SubscriptionOnStartFn, len(l.subscriptionHooks.onStart)) + for i, fn := range l.subscriptionHooks.onStart { + subscriptionOnStartFns[i] = NewEngineSubscriptionOnStartHook(fn) + } customConfiguration, err := graphql_datasource.NewConfiguration(graphql_datasource.ConfigurationInput{ Fetch: &graphql_datasource.FetchConfiguration{ URL: fetchUrl, @@ -429,6 +435,7 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod ForwardedClientHeaderNames: forwardedClientHeaders, ForwardedClientHeaderRegularExpressions: forwardedClientRegexps, WsSubProtocol: wsSubprotocol, + StartupHooks: subscriptionOnStartFns, }, SchemaConfiguration: schemaConfiguration, CustomScalarTypeFields: customScalarTypeFields, @@ -470,6 +477,10 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod } } + subscriptionOnStartFns := make([]pubsub_datasource.SubscriptionOnStartFn, len(l.subscriptionHooks.onStart)) + for i, fn := range l.subscriptionHooks.onStart { + subscriptionOnStartFns[i] = NewPubSubSubscriptionOnStartHook(fn) + } factoryProviders, factoryDataSources, err := pubsub.BuildProvidersAndDataSources( l.ctx, routerEngineConfig.Events, @@ -478,6 +489,9 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod pubSubDS, l.resolver.InstanceData().HostName, l.resolver.InstanceData().ListenAddress, + pubsub.Hooks{ + SubscriptionOnStart: subscriptionOnStartFns, + }, ) if err != nil { return nil, providers, err diff --git a/router/core/graph_server.go b/router/core/graph_server.go index 95f66c0ac5..c1330f77f5 100644 --- a/router/core/graph_server.go +++ b/router/core/graph_server.go @@ -1189,6 +1189,7 @@ func (s *graphServer) buildGraphMux( EnableTraceClient: enableTraceClient, CircuitBreaker: 
s.circuitBreakerManager, }, + subscriptionHooks: s.subscriptionHooks, } executor, providers, err := ecb.Build( diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index c494fff4ce..f387d73e6c 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -400,6 +400,22 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } + case errorTypeStreamHookError: + var streamHookErr *StreamHookError + if !errors.As(err, &streamHookErr) { + response.Errors[0].Message = "Internal server error" + return + } + response.Errors[0].Message = streamHookErr.Message() + if streamHookErr.Code() != "" || streamHookErr.StatusCode() != 0 { + response.Errors[0].Extensions = &Extensions{ + Code: streamHookErr.Code(), + StatusCode: streamHookErr.StatusCode(), + } + } + if isHttpResponseWriter { + httpWriter.WriteHeader(streamHookErr.StatusCode()) + } } if ctx.TracingOptions.Enable && ctx.TracingOptions.IncludeTraceOutputInResponseExtensions { diff --git a/router/core/plan_generator.go b/router/core/plan_generator.go index 1026fbe592..4c265a67d5 100644 --- a/router/core/plan_generator.go +++ b/router/core/plan_generator.go @@ -323,7 +323,7 @@ func (pg *PlanGenerator) loadConfiguration(routerConfig *nodev1.RouterConfig, lo httpClient: http.DefaultClient, streamingClient: http.DefaultClient, subscriptionClient: subscriptionClient, - }, logger) + }, logger, subscriptionHooks{}) // this generates the plan configuration using the data source factories from the config package planConfig, _, err := loader.Load(routerConfig.GetEngineConfig(), routerConfig.GetSubgraphs(), &routerEngineConfig, false) // TODO: configure plugins diff --git a/router/core/router.go b/router/core/router.go index 3432528a7a..9f07bd723a 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -666,6 +666,10 @@ func (r *Router) initModules(ctx context.Context) 
error { } } + if handler, ok := moduleInstance.(SubscriptionOnStartHandler); ok { + r.subscriptionHooks.onStart = append(r.subscriptionHooks.onStart, handler.SubscriptionOnStart) + } + r.modules = append(r.modules, moduleInstance) r.logger.Info("Module registered", diff --git a/router/core/router_config.go b/router/core/router_config.go index 89d99f2ce1..ac4f26d4c7 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -25,6 +25,10 @@ import ( "go.uber.org/zap" ) +type subscriptionHooks struct { + onStart []func(ctx SubscriptionOnStartHookContext) error +} + type Config struct { clusterName string instanceID string @@ -118,6 +122,7 @@ type Config struct { mcp config.MCPConfiguration plugins config.PluginsConfiguration tracingAttributes []config.CustomAttribute + subscriptionHooks subscriptionHooks } // Usage returns an anonymized version of the config for usage tracking diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go new file mode 100644 index 0000000000..505bbfc44f --- /dev/null +++ b/router/core/subscriptions_modules.go @@ -0,0 +1,188 @@ +package core + +import ( + "net/http" + + "github.com/wundergraph/cosmo/router/pkg/authentication" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/graphql_datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" +) + +// StreamHookError is used to customize the error messages and the behavior +type StreamHookError struct { + err error + message string + statusCode int + code string +} + +func (e *StreamHookError) Error() string { + if e.err != nil { + return e.err.Error() + } + return e.message +} + +func (e *StreamHookError) Message() string { + return e.message +} + +func (e *StreamHookError) StatusCode() int { + return e.statusCode +} + +func (e *StreamHookError) Code() string { + return e.code +} + +func NewStreamHookError(err error, 
message string, statusCode int, code string) *StreamHookError { + return &StreamHookError{ + err: err, + message: message, + statusCode: statusCode, + code: code, + } +} + +type SubscriptionOnStartHookContext interface { + // Request is the original request received by the router. + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) + SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration + // WriteEvent writes an event to the stream of the current subscription + // It returns true if the event was written to the stream, false if the event was dropped + WriteEvent(event datasource.StreamEvent) bool +} + +type pubSubSubscriptionOnStartHookContext struct { + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication + subscriptionEventConfiguration datasource.SubscriptionEventConfiguration + writeEventHook func(data []byte) +} + +func (c *pubSubSubscriptionOnStartHookContext) Request() *http.Request { + return c.request +} + +func (c *pubSubSubscriptionOnStartHookContext) Logger() *zap.Logger { + return c.logger +} + +func (c *pubSubSubscriptionOnStartHookContext) Operation() OperationContext { + return c.operation +} + +func (c *pubSubSubscriptionOnStartHookContext) Authentication() authentication.Authentication { + return c.authentication +} + +func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { + return c.subscriptionEventConfiguration +} + +func (c *pubSubSubscriptionOnStartHookContext) WriteEvent(event datasource.StreamEvent) bool { + c.writeEventHook(event.GetData()) + + return true 
+} + +// EngineEvent is the event used to write to the engine subscription +type EngineEvent struct { + Data []byte +} + +func (e *EngineEvent) GetData() []byte { + return e.Data +} + +type engineSubscriptionOnStartHookContext struct { + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication + writeEventHook func(data []byte) +} + +func (c *engineSubscriptionOnStartHookContext) Request() *http.Request { + return c.request +} + +func (c *engineSubscriptionOnStartHookContext) Logger() *zap.Logger { + return c.logger +} + +func (c *engineSubscriptionOnStartHookContext) Operation() OperationContext { + return c.operation +} + +func (c *engineSubscriptionOnStartHookContext) Authentication() authentication.Authentication { + return c.authentication +} + +func (c *engineSubscriptionOnStartHookContext) WriteEvent(event datasource.StreamEvent) bool { + c.writeEventHook(event.GetData()) + + return true +} + +func (c *engineSubscriptionOnStartHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { + return nil +} + +type SubscriptionOnStartHandler interface { + // SubscriptionOnStart is called once at subscription start + // The error is propagated to the client. 
+ SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error +} + +// NewPubSubSubscriptionOnStartHook converts a SubscriptionOnStartHandler to a pubsub.SubscriptionOnStartFn +func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext) error) datasource.SubscriptionOnStartFn { + if fn == nil { + return nil + } + + return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration) error { + requestContext := getRequestContext(resolveCtx.Context) + hookCtx := &pubSubSubscriptionOnStartHookContext{ + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + subscriptionEventConfiguration: subConf, + writeEventHook: resolveCtx.Updater, + } + + return fn(hookCtx) + } +} + +// NewEngineSubscriptionOnStartHook converts a SubscriptionOnStartHandler to a graphql_datasource.SubscriptionOnStartFn +func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext) error) graphql_datasource.SubscriptionOnStartFn { + if fn == nil { + return nil + } + + return func(resolveCtx resolve.StartupHookContext, input []byte) error { + requestContext := getRequestContext(resolveCtx.Context) + hookCtx := &engineSubscriptionOnStartHookContext{ + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + writeEventHook: resolveCtx.Updater, + } + + return fn(hookCtx) + } +} diff --git a/router/go.mod b/router/go.mod index 180c9f51d0..82ff4f2e73 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb // Do not upgrade, it renames attributes we rely on 
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 @@ -196,4 +196,4 @@ replace ( // Remember you can use Go workspaces to avoid using replace directives in multiple go.mod files // Use what is best for your personal workflow. See CONTRIBUTING.md for more information -// replace github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 +//replace github.com/wundergraph/graphql-go-tools/v2 => ../../graphql-go-tools/v2 diff --git a/router/go.sum b/router/go.sum index 0263992f20..1a0bc0afe5 100644 --- a/router/go.sum +++ b/router/go.sum @@ -321,8 +321,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 h1:VCfCX/xmpBGQLhTHJMHLugzJrXJk/smjLRAEruCI0HY= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb h1:stBTAle5FyytsTNxYeCwNzYlyhKzlS4he6f7/y6O3qE= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= diff --git a/router/pkg/pubsub/datasource/datasource.go b/router/pkg/pubsub/datasource/datasource.go index 2f08b97074..3a3018b745 100644 --- 
a/router/pkg/pubsub/datasource/datasource.go +++ b/router/pkg/pubsub/datasource/datasource.go @@ -1,9 +1,17 @@ package datasource import ( + "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) +type SubscriptionDataSource interface { + SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) + Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error + UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) (err error) + SetSubscriptionOnStartFns(fns ...SubscriptionOnStartFn) +} + // EngineDataSourceFactory is the interface that all pubsub data sources must implement. // It serves three main purposes: // 1. Resolving the data source and subscription data source @@ -23,7 +31,7 @@ type EngineDataSourceFactory interface { // ResolveDataSourceSubscription returns the engine SubscriptionDataSource implementation // that contains methods to start a subscription, which will be called by the Planner // when a subscription is initiated - ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) + ResolveDataSourceSubscription() (SubscriptionDataSource, error) // ResolveDataSourceSubscriptionInput build the input that will be passed to the engine SubscriptionDataSource ResolveDataSourceSubscriptionInput() (string, error) // TransformEventData allows the data source to transform the event data using the extractFn diff --git a/router/pkg/pubsub/datasource/factory.go b/router/pkg/pubsub/datasource/factory.go index cbceb1a651..5c42161776 100644 --- a/router/pkg/pubsub/datasource/factory.go +++ b/router/pkg/pubsub/datasource/factory.go @@ -9,14 +9,16 @@ import ( ) type PlannerConfig[PB ProviderBuilder[P, E], P any, E any] struct { - ProviderBuilder PB - Event E + ProviderBuilder PB + Event E + SubscriptionOnStartFns []SubscriptionOnStartFn } -func NewPlannerConfig[PB ProviderBuilder[P, E], P any, E any](providerBuilder PB, event E) *PlannerConfig[PB, 
P, E] { +func NewPlannerConfig[PB ProviderBuilder[P, E], P any, E any](providerBuilder PB, event E, subscriptionOnStartFns []SubscriptionOnStartFn) *PlannerConfig[PB, P, E] { return &PlannerConfig[PB, P, E]{ - ProviderBuilder: providerBuilder, - Event: event, + ProviderBuilder: providerBuilder, + Event: event, + SubscriptionOnStartFns: subscriptionOnStartFns, } } diff --git a/router/pkg/pubsub/datasource/mocks.go b/router/pkg/pubsub/datasource/mocks.go index a6bbb19e18..861beb3987 100644 --- a/router/pkg/pubsub/datasource/mocks.go +++ b/router/pkg/pubsub/datasource/mocks.go @@ -198,23 +198,23 @@ func (_c *MockEngineDataSourceFactory_ResolveDataSourceInput_Call) RunAndReturn( } // ResolveDataSourceSubscription provides a mock function for the type MockEngineDataSourceFactory -func (_mock *MockEngineDataSourceFactory) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { +func (_mock *MockEngineDataSourceFactory) ResolveDataSourceSubscription() (SubscriptionDataSource, error) { ret := _mock.Called() if len(ret) == 0 { panic("no return value specified for ResolveDataSourceSubscription") } - var r0 resolve.SubscriptionDataSource + var r0 SubscriptionDataSource var r1 error - if returnFunc, ok := ret.Get(0).(func() (resolve.SubscriptionDataSource, error)); ok { + if returnFunc, ok := ret.Get(0).(func() (SubscriptionDataSource, error)); ok { return returnFunc() } - if returnFunc, ok := ret.Get(0).(func() resolve.SubscriptionDataSource); ok { + if returnFunc, ok := ret.Get(0).(func() SubscriptionDataSource); ok { r0 = returnFunc() } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(resolve.SubscriptionDataSource) + r0 = ret.Get(0).(SubscriptionDataSource) } } if returnFunc, ok := ret.Get(1).(func() error); ok { @@ -242,12 +242,12 @@ func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) Run(ru return _c } -func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) Return(subscriptionDataSource 
resolve.SubscriptionDataSource, err error) *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call { +func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) Return(subscriptionDataSource SubscriptionDataSource, err error) *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call { _c.Call.Return(subscriptionDataSource, err) return _c } -func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) RunAndReturn(run func() (resolve.SubscriptionDataSource, error)) *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call { +func (_c *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call) RunAndReturn(run func() (SubscriptionDataSource, error)) *MockEngineDataSourceFactory_ResolveDataSourceSubscription_Call { _c.Call.Return(run) return _c } @@ -356,13 +356,13 @@ func (_c *MockEngineDataSourceFactory_TransformEventData_Call) RunAndReturn(run return _c } -// NewMockProviderLifecycle creates a new instance of MockProviderLifecycle. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +// NewMockLifecycle creates a new instance of MockLifecycle. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. // The first argument is typically a *testing.T value. 
-func NewMockProviderLifecycle(t interface { +func NewMockLifecycle(t interface { mock.TestingT Cleanup(func()) -}) *MockProviderLifecycle { - mock := &MockProviderLifecycle{} +}) *MockLifecycle { + mock := &MockLifecycle{} mock.Mock.Test(t) t.Cleanup(func() { mock.AssertExpectations(t) }) @@ -370,21 +370,21 @@ func NewMockProviderLifecycle(t interface { return mock } -// MockProviderLifecycle is an autogenerated mock type for the ProviderLifecycle type -type MockProviderLifecycle struct { +// MockLifecycle is an autogenerated mock type for the Lifecycle type +type MockLifecycle struct { mock.Mock } -type MockProviderLifecycle_Expecter struct { +type MockLifecycle_Expecter struct { mock *mock.Mock } -func (_m *MockProviderLifecycle) EXPECT() *MockProviderLifecycle_Expecter { - return &MockProviderLifecycle_Expecter{mock: &_m.Mock} +func (_m *MockLifecycle) EXPECT() *MockLifecycle_Expecter { + return &MockLifecycle_Expecter{mock: &_m.Mock} } -// Shutdown provides a mock function for the type MockProviderLifecycle -func (_mock *MockProviderLifecycle) Shutdown(ctx context.Context) error { +// Shutdown provides a mock function for the type MockLifecycle +func (_mock *MockLifecycle) Shutdown(ctx context.Context) error { ret := _mock.Called(ctx) if len(ret) == 0 { @@ -400,18 +400,18 @@ func (_mock *MockProviderLifecycle) Shutdown(ctx context.Context) error { return r0 } -// MockProviderLifecycle_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown' -type MockProviderLifecycle_Shutdown_Call struct { +// MockLifecycle_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown' +type MockLifecycle_Shutdown_Call struct { *mock.Call } // Shutdown is a helper method to define mock.On call // - ctx context.Context -func (_e *MockProviderLifecycle_Expecter) Shutdown(ctx interface{}) *MockProviderLifecycle_Shutdown_Call { - return &MockProviderLifecycle_Shutdown_Call{Call: 
_e.mock.On("Shutdown", ctx)} +func (_e *MockLifecycle_Expecter) Shutdown(ctx interface{}) *MockLifecycle_Shutdown_Call { + return &MockLifecycle_Shutdown_Call{Call: _e.mock.On("Shutdown", ctx)} } -func (_c *MockProviderLifecycle_Shutdown_Call) Run(run func(ctx context.Context)) *MockProviderLifecycle_Shutdown_Call { +func (_c *MockLifecycle_Shutdown_Call) Run(run func(ctx context.Context)) *MockLifecycle_Shutdown_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -424,18 +424,18 @@ func (_c *MockProviderLifecycle_Shutdown_Call) Run(run func(ctx context.Context) return _c } -func (_c *MockProviderLifecycle_Shutdown_Call) Return(err error) *MockProviderLifecycle_Shutdown_Call { +func (_c *MockLifecycle_Shutdown_Call) Return(err error) *MockLifecycle_Shutdown_Call { _c.Call.Return(err) return _c } -func (_c *MockProviderLifecycle_Shutdown_Call) RunAndReturn(run func(ctx context.Context) error) *MockProviderLifecycle_Shutdown_Call { +func (_c *MockLifecycle_Shutdown_Call) RunAndReturn(run func(ctx context.Context) error) *MockLifecycle_Shutdown_Call { _c.Call.Return(run) return _c } -// Startup provides a mock function for the type MockProviderLifecycle -func (_mock *MockProviderLifecycle) Startup(ctx context.Context) error { +// Startup provides a mock function for the type MockLifecycle +func (_mock *MockLifecycle) Startup(ctx context.Context) error { ret := _mock.Called(ctx) if len(ret) == 0 { @@ -451,18 +451,18 @@ func (_mock *MockProviderLifecycle) Startup(ctx context.Context) error { return r0 } -// MockProviderLifecycle_Startup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Startup' -type MockProviderLifecycle_Startup_Call struct { +// MockLifecycle_Startup_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Startup' +type MockLifecycle_Startup_Call struct { *mock.Call } // Startup is a helper method to define mock.On call // - ctx 
context.Context -func (_e *MockProviderLifecycle_Expecter) Startup(ctx interface{}) *MockProviderLifecycle_Startup_Call { - return &MockProviderLifecycle_Startup_Call{Call: _e.mock.On("Startup", ctx)} +func (_e *MockLifecycle_Expecter) Startup(ctx interface{}) *MockLifecycle_Startup_Call { + return &MockLifecycle_Startup_Call{Call: _e.mock.On("Startup", ctx)} } -func (_c *MockProviderLifecycle_Startup_Call) Run(run func(ctx context.Context)) *MockProviderLifecycle_Startup_Call { +func (_c *MockLifecycle_Startup_Call) Run(run func(ctx context.Context)) *MockLifecycle_Startup_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -475,12 +475,12 @@ func (_c *MockProviderLifecycle_Startup_Call) Run(run func(ctx context.Context)) return _c } -func (_c *MockProviderLifecycle_Startup_Call) Return(err error) *MockProviderLifecycle_Startup_Call { +func (_c *MockLifecycle_Startup_Call) Return(err error) *MockLifecycle_Startup_Call { _c.Call.Return(err) return _c } -func (_c *MockProviderLifecycle_Startup_Call) RunAndReturn(run func(ctx context.Context) error) *MockProviderLifecycle_Startup_Call { +func (_c *MockLifecycle_Startup_Call) RunAndReturn(run func(ctx context.Context) error) *MockLifecycle_Startup_Call { _c.Call.Return(run) return _c } @@ -658,6 +658,69 @@ func (_c *MockProvider_Startup_Call) RunAndReturn(run func(ctx context.Context) return _c } +// Subscribe provides a mock function for the type MockProvider +func (_mock *MockProvider) Subscribe(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error { + ret := _mock.Called(ctx, cfg, updater) + + if len(ret) == 0 { + panic("no return value specified for Subscribe") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, SubscriptionEventConfiguration, SubscriptionEventUpdater) error); ok { + r0 = returnFunc(ctx, cfg, updater) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockProvider_Subscribe_Call is 
a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe' +type MockProvider_Subscribe_Call struct { + *mock.Call +} + +// Subscribe is a helper method to define mock.On call +// - ctx context.Context +// - cfg SubscriptionEventConfiguration +// - updater SubscriptionEventUpdater +func (_e *MockProvider_Expecter) Subscribe(ctx interface{}, cfg interface{}, updater interface{}) *MockProvider_Subscribe_Call { + return &MockProvider_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, cfg, updater)} +} + +func (_c *MockProvider_Subscribe_Call) Run(run func(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater)) *MockProvider_Subscribe_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 SubscriptionEventConfiguration + if args[1] != nil { + arg1 = args[1].(SubscriptionEventConfiguration) + } + var arg2 SubscriptionEventUpdater + if args[2] != nil { + arg2 = args[2].(SubscriptionEventUpdater) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockProvider_Subscribe_Call) Return(err error) *MockProvider_Subscribe_Call { + _c.Call.Return(err) + return _c +} + +func (_c *MockProvider_Subscribe_Call) RunAndReturn(run func(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error) *MockProvider_Subscribe_Call { + _c.Call.Return(run) + return _c +} + // TypeID provides a mock function for the type MockProvider func (_mock *MockProvider) TypeID() string { ret := _mock.Called() @@ -896,3 +959,143 @@ func (_c *MockProviderBuilder_TypeID_Call[P, E]) RunAndReturn(run func() string) _c.Call.Return(run) return _c } + +// NewMockSubscriptionEventUpdater creates a new instance of MockSubscriptionEventUpdater. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+// The first argument is typically a *testing.T value. +func NewMockSubscriptionEventUpdater(t interface { + mock.TestingT + Cleanup(func()) +}) *MockSubscriptionEventUpdater { + mock := &MockSubscriptionEventUpdater{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} + +// MockSubscriptionEventUpdater is an autogenerated mock type for the SubscriptionEventUpdater type +type MockSubscriptionEventUpdater struct { + mock.Mock +} + +type MockSubscriptionEventUpdater_Expecter struct { + mock *mock.Mock +} + +func (_m *MockSubscriptionEventUpdater) EXPECT() *MockSubscriptionEventUpdater_Expecter { + return &MockSubscriptionEventUpdater_Expecter{mock: &_m.Mock} +} + +// Close provides a mock function for the type MockSubscriptionEventUpdater +func (_mock *MockSubscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { + _mock.Called(kind) + return +} + +// MockSubscriptionEventUpdater_Close_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Close' +type MockSubscriptionEventUpdater_Close_Call struct { + *mock.Call +} + +// Close is a helper method to define mock.On call +// - kind resolve.SubscriptionCloseKind +func (_e *MockSubscriptionEventUpdater_Expecter) Close(kind interface{}) *MockSubscriptionEventUpdater_Close_Call { + return &MockSubscriptionEventUpdater_Close_Call{Call: _e.mock.On("Close", kind)} +} + +func (_c *MockSubscriptionEventUpdater_Close_Call) Run(run func(kind resolve.SubscriptionCloseKind)) *MockSubscriptionEventUpdater_Close_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 resolve.SubscriptionCloseKind + if args[0] != nil { + arg0 = args[0].(resolve.SubscriptionCloseKind) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockSubscriptionEventUpdater_Close_Call) Return() *MockSubscriptionEventUpdater_Close_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionEventUpdater_Close_Call) RunAndReturn(run func(kind 
resolve.SubscriptionCloseKind)) *MockSubscriptionEventUpdater_Close_Call { + _c.Run(run) + return _c +} + +// Complete provides a mock function for the type MockSubscriptionEventUpdater +func (_mock *MockSubscriptionEventUpdater) Complete() { + _mock.Called() + return +} + +// MockSubscriptionEventUpdater_Complete_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Complete' +type MockSubscriptionEventUpdater_Complete_Call struct { + *mock.Call +} + +// Complete is a helper method to define mock.On call +func (_e *MockSubscriptionEventUpdater_Expecter) Complete() *MockSubscriptionEventUpdater_Complete_Call { + return &MockSubscriptionEventUpdater_Complete_Call{Call: _e.mock.On("Complete")} +} + +func (_c *MockSubscriptionEventUpdater_Complete_Call) Run(run func()) *MockSubscriptionEventUpdater_Complete_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + }) + return _c +} + +func (_c *MockSubscriptionEventUpdater_Complete_Call) Return() *MockSubscriptionEventUpdater_Complete_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionEventUpdater_Complete_Call) RunAndReturn(run func()) *MockSubscriptionEventUpdater_Complete_Call { + _c.Run(run) + return _c +} + +// Update provides a mock function for the type MockSubscriptionEventUpdater +func (_mock *MockSubscriptionEventUpdater) Update(event StreamEvent) { + _mock.Called(event) + return +} + +// MockSubscriptionEventUpdater_Update_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Update' +type MockSubscriptionEventUpdater_Update_Call struct { + *mock.Call +} + +// Update is a helper method to define mock.On call +// - event StreamEvent +func (_e *MockSubscriptionEventUpdater_Expecter) Update(event interface{}) *MockSubscriptionEventUpdater_Update_Call { + return &MockSubscriptionEventUpdater_Update_Call{Call: _e.mock.On("Update", event)} +} + +func (_c *MockSubscriptionEventUpdater_Update_Call) Run(run func(event 
StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 StreamEvent + if args[0] != nil { + arg0 = args[0].(StreamEvent) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockSubscriptionEventUpdater_Update_Call) Return() *MockSubscriptionEventUpdater_Update_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionEventUpdater_Update_Call) RunAndReturn(run func(event StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { + _c.Run(run) + return _c +} diff --git a/router/pkg/pubsub/datasource/planner.go b/router/pkg/pubsub/datasource/planner.go index e3b54a92ec..a480f8270e 100644 --- a/router/pkg/pubsub/datasource/planner.go +++ b/router/pkg/pubsub/datasource/planner.go @@ -109,6 +109,7 @@ func (p *Planner[PB, P, E]) ConfigureSubscription() plan.SubscriptionConfigurati p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source subscription: %w", err)) return plan.SubscriptionConfiguration{} } + dataSource.SetSubscriptionOnStartFns(p.config.SubscriptionOnStartFns...) 
input, err := pubSubDataSource.ResolveDataSourceSubscriptionInput() if err != nil { diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go index d9138630ca..33cac33782 100644 --- a/router/pkg/pubsub/datasource/provider.go +++ b/router/pkg/pubsub/datasource/provider.go @@ -4,20 +4,30 @@ import ( "context" "github.com/wundergraph/cosmo/router/pkg/metric" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) type ArgumentTemplateCallback func(tpl string) (string, error) -type ProviderLifecycle interface { +// Lifecycle is the interface that the provider must implement +// to allow the router to start and stop the provider +type Lifecycle interface { // Startup is the method called when the provider is started Startup(ctx context.Context) error // Shutdown is the method called when the provider is shut down Shutdown(ctx context.Context) error } +// Adapter is the interface that the provider must implement +// to implement the basic functionality +type Adapter interface { + Lifecycle + Subscribe(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error +} + // Provider is the interface that the PubSub provider must implement type Provider interface { - ProviderLifecycle + Adapter // ID Get the provider ID as specified in the configuration ID() string // TypeID Get the provider type id (e.g. 
"kafka", "nats") @@ -34,6 +44,38 @@ type ProviderBuilder[P, E any] interface { BuildEngineDataSourceFactory(data E) (EngineDataSourceFactory, error) } +// ProviderType represents the type of pubsub provider +type ProviderType string + +const ( + ProviderTypeNats ProviderType = "nats" + ProviderTypeKafka ProviderType = "kafka" + ProviderTypeRedis ProviderType = "redis" +) + +// StreamEvent is a generic interface for all stream events +// Each provider will have its own event type that implements this interface +// there could be other common fields in the future, but for now we only have data +type StreamEvent interface { + GetData() []byte +} + +type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration) error + +// SubscriptionEventConfiguration is the interface that all subscription event configurations must implement +type SubscriptionEventConfiguration interface { + ProviderID() string + ProviderType() ProviderType + RootFieldName() string // the root field name of the subscription in the schema +} + +// PublishEventConfiguration is the interface that all publish event configurations must implement +type PublishEventConfiguration interface { + ProviderID() string + ProviderType() ProviderType + RootFieldName() string // the root field name of the mutation in the schema +} + type ProviderOpts struct { StreamMetricStore metric.StreamMetricStore } diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index 9e1223d950..84561b06db 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -9,7 +9,7 @@ import ( type PubSubProvider struct { id string typeID string - Adapter ProviderLifecycle + Adapter Adapter Logger *zap.Logger } @@ -35,7 +35,11 @@ func (p *PubSubProvider) Shutdown(ctx context.Context) error { return nil } -func NewPubSubProvider(id string, typeID string, adapter ProviderLifecycle, logger 
*zap.Logger) *PubSubProvider { +func (p *PubSubProvider) Subscribe(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error { + return p.Adapter.Subscribe(ctx, cfg, updater) +} + +func NewPubSubProvider(id string, typeID string, adapter Adapter, logger *zap.Logger) *PubSubProvider { return &PubSubProvider{ id: id, typeID: typeID, diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 6579b62072..134bfbd6bb 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -10,7 +10,7 @@ import ( ) func TestProvider_Startup_Success(t *testing.T) { - mockAdapter := NewMockProviderLifecycle(t) + mockAdapter := NewMockProvider(t) mockAdapter.On("Startup", mock.Anything).Return(nil) provider := PubSubProvider{ @@ -22,7 +22,7 @@ func TestProvider_Startup_Success(t *testing.T) { } func TestProvider_Startup_Error(t *testing.T) { - mockAdapter := NewMockProviderLifecycle(t) + mockAdapter := NewMockProvider(t) mockAdapter.On("Startup", mock.Anything).Return(errors.New("connect error")) provider := PubSubProvider{ @@ -34,7 +34,7 @@ func TestProvider_Startup_Error(t *testing.T) { } func TestProvider_Shutdown_Success(t *testing.T) { - mockAdapter := NewMockProviderLifecycle(t) + mockAdapter := NewMockProvider(t) mockAdapter.On("Shutdown", mock.Anything).Return(nil) provider := PubSubProvider{ @@ -46,7 +46,7 @@ func TestProvider_Shutdown_Success(t *testing.T) { } func TestProvider_Shutdown_Error(t *testing.T) { - mockAdapter := NewMockProviderLifecycle(t) + mockAdapter := NewMockProvider(t) mockAdapter.On("Shutdown", mock.Anything).Return(errors.New("close error")) provider := PubSubProvider{ diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go new file mode 100644 index 0000000000..e5c9c26ab6 --- /dev/null +++ 
b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -0,0 +1,72 @@ +package datasource + +import ( + "encoding/json" + "errors" + + "github.com/cespare/xxhash/v2" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error + +// PubSubSubscriptionDataSource is a data source for handling subscriptions using a Pub/Sub mechanism. +// It implements the SubscriptionDataSource interface and HookableSubscriptionDataSource +type PubSubSubscriptionDataSource[C SubscriptionEventConfiguration] struct { + pubSub Adapter + uniqueRequestID uniqueRequestIdFn + subscriptionOnStartFns []SubscriptionOnStartFn +} + +func (s *PubSubSubscriptionDataSource[C]) SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) { + var subscriptionConfiguration C + err := json.Unmarshal(input, &subscriptionConfiguration) + return subscriptionConfiguration, err +} + +func (s *PubSubSubscriptionDataSource[C]) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return s.uniqueRequestID(ctx, input, xxh) +} + +func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { + subConf, err := s.SubscriptionEventConfiguration(input) + if err != nil { + return err + } + + conf, ok := subConf.(C) + if !ok { + return errors.New("invalid subscription configuration") + } + + return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(updater)) +} + +func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) { + for _, fn := range s.subscriptionOnStartFns { + conf, errConf := s.SubscriptionEventConfiguration(input) + if errConf != nil { + return errConf + } + err = fn(ctx, conf) + if err != nil { + return err + } + } + + return nil +} + +func (s *PubSubSubscriptionDataSource[C]) SetSubscriptionOnStartFns(fns 
...SubscriptionOnStartFn) { + s.subscriptionOnStartFns = append(s.subscriptionOnStartFns, fns...) +} + +var _ SubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) +var _ resolve.HookableSubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) + +func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Adapter, uniqueRequestIdFn uniqueRequestIdFn) *PubSubSubscriptionDataSource[C] { + return &PubSubSubscriptionDataSource[C]{ + pubSub: pubSub, + uniqueRequestID: uniqueRequestIdFn, + } +} diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go new file mode 100644 index 0000000000..a9170d7edd --- /dev/null +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -0,0 +1,327 @@ +package datasource + +import ( + "context" + "encoding/json" + "errors" + "testing" + + "github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +// testSubscriptionEventConfiguration implements SubscriptionEventConfiguration for testing +type testSubscriptionEventConfiguration struct { + Topic string `json:"topic"` + Subject string `json:"subject"` +} + +func (t testSubscriptionEventConfiguration) ProviderID() string { + return "test-provider" +} + +func (t testSubscriptionEventConfiguration) ProviderType() ProviderType { + return ProviderTypeNats +} + +func (t testSubscriptionEventConfiguration) RootFieldName() string { + return "testSubscription" +} + +func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, 
uniqueRequestIDFn) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + result, err := dataSource.SubscriptionEventConfiguration(input) + assert.NoError(t, err) + assert.NotNil(t, result) + + typedResult, ok := result.(testSubscriptionEventConfiguration) + assert.True(t, ok) + assert.Equal(t, "test-topic", typedResult.Topic) + assert.Equal(t, "test-subject", typedResult.Subject) +} + +func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_InvalidJSON(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + invalidInput := []byte(`{"invalid": json}`) + result, err := dataSource.SubscriptionEventConfiguration(invalidInput) + assert.Error(t, err) + assert.Equal(t, testSubscriptionEventConfiguration{}, result) +} + +func TestPubSubSubscriptionDataSource_UniqueRequestID_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + ctx := &resolve.Context{} + input := []byte(`{"test": "data"}`) + xxh := xxhash.New() + + err := dataSource.UniqueRequestID(ctx, input, xxh) + assert.NoError(t, err) +} + +func TestPubSubSubscriptionDataSource_UniqueRequestID_Error(t *testing.T) { + mockAdapter := NewMockProvider(t) + expectedError := errors.New("unique ID generation error") + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return expectedError + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, 
uniqueRequestIDFn) + + ctx := &resolve.Context{} + input := []byte(`{"test": "data"}`) + xxh := xxhash.New() + + err := dataSource.UniqueRequestID(ctx, input, xxh) + assert.Error(t, err) + assert.Equal(t, expectedError, err) +} + +func TestPubSubSubscriptionDataSource_Start_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.NewContext(context.Background()) + mockUpdater := NewMockSubscriptionUpdater(t) + + mockAdapter.On("Subscribe", ctx.Context(), testConfig, mock.AnythingOfType("*datasource.subscriptionEventUpdater")).Return(nil) + + err = dataSource.Start(ctx, input, mockUpdater) + assert.NoError(t, err) + mockAdapter.AssertExpectations(t) +} + +func TestPubSubSubscriptionDataSource_Start_NoConfiguration(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + invalidInput := []byte(`{"invalid": json}`) + ctx := resolve.NewContext(context.Background()) + mockUpdater := NewMockSubscriptionUpdater(t) + + err := dataSource.Start(ctx, invalidInput, mockUpdater) + assert.Error(t, err) + assert.Contains(t, err.Error(), "invalid character 'j' looking for beginning of value") +} + +func TestPubSubSubscriptionDataSource_Start_SubscribeError(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := 
NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.NewContext(context.Background()) + mockUpdater := NewMockSubscriptionUpdater(t) + expectedError := errors.New("subscription error") + + mockAdapter.On("Subscribe", ctx.Context(), testConfig, mock.AnythingOfType("*datasource.subscriptionEventUpdater")).Return(expectedError) + + err = dataSource.Start(ctx, input, mockUpdater) + assert.Error(t, err) + assert.Equal(t, expectedError, err) + mockAdapter.AssertExpectations(t) +} + +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err = dataSource.SubscriptionOnStart(ctx, input) + assert.NoError(t, err) +} + +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + // Add subscription start hooks + hook1Called := false + hook2Called := false + + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook1Called = true + return nil + } + 
+ hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook2Called = true + return nil + } + + dataSource.SetSubscriptionOnStartFns(hook1, hook2) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err = dataSource.SubscriptionOnStart(ctx, input) + assert.NoError(t, err) + assert.True(t, hook1Called) + assert.True(t, hook2Called) +} + +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + expectedError := errors.New("hook error") + // Add hook that returns an error + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + return expectedError + } + + dataSource.SetSubscriptionOnStartFns(hook) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err = dataSource.SubscriptionOnStart(ctx, input) + assert.Error(t, err) + assert.Equal(t, expectedError, err) +} + +func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + // Initially should have no 
hooks + assert.Len(t, dataSource.subscriptionOnStartFns, 0) + + // Add hooks + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + return nil + } + hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + return nil + } + + dataSource.SetSubscriptionOnStartFns(hook1) + assert.Len(t, dataSource.subscriptionOnStartFns, 1) + + dataSource.SetSubscriptionOnStartFns(hook2) + assert.Len(t, dataSource.subscriptionOnStartFns, 2) +} + +func TestNewPubSubSubscriptionDataSource(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + assert.NotNil(t, dataSource) + assert.Equal(t, mockAdapter, dataSource.pubSub) + assert.NotNil(t, dataSource.uniqueRequestID) + assert.Empty(t, dataSource.subscriptionOnStartFns) +} + +func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + + // Test that it implements SubscriptionDataSource interface + var _ SubscriptionDataSource = dataSource + + // Test that it implements HookableSubscriptionDataSource interface + var _ resolve.HookableSubscriptionDataSource = dataSource +} diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go new file mode 100644 index 0000000000..9332d10f7a --- /dev/null +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -0,0 +1,34 @@ +package datasource + +import "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + +// 
SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface +// that provides a way to send the event struct instead of the raw data +// It is used to give access to the event additional fields to the hooks. +type SubscriptionEventUpdater interface { + Update(event StreamEvent) + Complete() + Close(kind resolve.SubscriptionCloseKind) +} + +type subscriptionEventUpdater struct { + eventUpdater resolve.SubscriptionUpdater +} + +func (h *subscriptionEventUpdater) Update(event StreamEvent) { + h.eventUpdater.Update(event.GetData()) +} + +func (h *subscriptionEventUpdater) Complete() { + h.eventUpdater.Complete() +} + +func (h *subscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { + h.eventUpdater.Close(kind) +} + +func NewSubscriptionEventUpdater(eventUpdater resolve.SubscriptionUpdater) SubscriptionEventUpdater { + return &subscriptionEventUpdater{ + eventUpdater: eventUpdater, + } +} diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index e11993b668..fa906370ab 100644 --- a/router/pkg/pubsub/kafka/adapter.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -13,7 +13,6 @@ import ( "github.com/twmb/franz-go/pkg/kerr" "github.com/twmb/franz-go/pkg/kgo" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -28,7 +27,7 @@ const ( // Adapter defines the interface for Kafka adapter operations type Adapter interface { - Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error + Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error Publish(ctx context.Context, event PublishEventConfiguration) error Startup(ctx context.Context) error Shutdown(ctx context.Context) error @@ -54,7 +53,7 @@ type PollerOpts struct { } // topicPoller polls the Kafka topic for new records and calls the 
updateTriggers function. -func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, updater resolve.SubscriptionUpdater, pollerOpts PollerOpts) error { +func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, updater datasource.SubscriptionEventUpdater, pollerOpts PollerOpts) error { for { select { case <-p.ctx.Done(): // Close the poller if the application context was canceled @@ -100,13 +99,25 @@ func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, u r := iter.Next() p.logger.Debug("subscription update", zap.String("topic", r.Topic), zap.ByteString("data", r.Value)) + + headers := make(map[string][]byte) + for _, header := range r.Headers { + headers[header.Key] = header.Value + } + p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ ProviderId: pollerOpts.providerId, StreamOperationName: kafkaReceive, ProviderType: metric.ProviderTypeKafka, DestinationName: r.Topic, }) - updater.Update(r.Value) + + updater.Update(&Event{ + Data: r.Value, + Headers: headers, + Key: r.Key, + }) + } } } @@ -114,23 +125,27 @@ func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, u // Subscribe subscribes to the given topics and updates the subscription updater. // The engine already deduplicates subscriptions with the same topics, stream configuration, extensions, headers, etc. 
-func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + subConf, ok := conf.(*SubscriptionEventConfiguration) + if !ok { + return datasource.NewError("invalid event type for Kafka adapter", nil) + } log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", subConf.ProviderID()), zap.String("method", "subscribe"), - zap.Strings("topics", event.Topics), + zap.Strings("topics", subConf.Topics), ) // Create a new client for the topic client, err := kgo.NewClient(append(p.opts, - kgo.ConsumeTopics(event.Topics...), + kgo.ConsumeTopics(subConf.Topics...), // We want to consume the events produced after the first subscription was created // Messages are shared among all subscriptions, therefore old events are not redelivered // This replicates a stateless publish-subscribe model kgo.ConsumeResetOffset(kgo.NewOffset().AfterMilli(time.Now().UnixMilli())), // For observability, we set the client ID to "router" - kgo.ClientID(fmt.Sprintf("cosmo.router.consumer.%s", strings.Join(event.Topics, "-"))), + kgo.ClientID(fmt.Sprintf("cosmo.router.consumer.%s", strings.Join(subConf.Topics, "-"))), // FIXME: the client id should have some unique identifier, like in nats // What if we have multiple subscriptions for the same topics? // What if we have more router instances? 
@@ -146,7 +161,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent defer p.closeWg.Done() - err := p.topicPoller(ctx, client, updater, PollerOpts{providerId: event.ProviderID}) + err := p.topicPoller(ctx, client, updater, PollerOpts{providerId: conf.ProviderID()}) if err != nil { if errors.Is(err, errClientClosed) || errors.Is(err, context.Canceled) { log.Debug("poller canceled", zap.Error(err)) @@ -166,7 +181,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent // The event is written with a dedicated write client. func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", event.ProviderID()), zap.String("method", "publish"), zap.String("topic", event.Topic), ) @@ -175,16 +190,26 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu return datasource.NewError("kafka write client not initialized", nil) } - log.Debug("publish", zap.ByteString("data", event.Data)) + log.Debug("publish", zap.ByteString("data", event.Event.Data)) var wg sync.WaitGroup wg.Add(1) var pErr error + headers := make([]kgo.RecordHeader, 0, len(event.Event.Headers)) + for key, value := range event.Event.Headers { + headers = append(headers, kgo.RecordHeader{ + Key: key, + Value: value, + }) + } + p.writeClient.Produce(ctx, &kgo.Record{ - Topic: event.Topic, - Value: event.Data, + Key: event.Event.Key, + Topic: event.Topic, + Value: event.Event.Data, + Headers: headers, }, func(record *kgo.Record, err error) { defer wg.Done() if err != nil { @@ -198,7 +223,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu log.Error("publish error", zap.Error(pErr)) // failure emission: include error.type generic p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), 
StreamOperationName: kafkaProduce, ProviderType: metric.ProviderTypeKafka, ErrorType: "publish_error", @@ -208,7 +233,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: kafkaProduce, ProviderType: metric.ProviderTypeKafka, DestinationName: event.Topic, diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 7b82a766b0..723c0d0bd0 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -7,61 +7,78 @@ import ( "fmt" "io" - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) +// Event represents an event from Kafka +type Event struct { + Key []byte `json:"key"` + Data json.RawMessage `json:"data"` + Headers map[string][]byte `json:"headers"` +} + +func (e *Event) GetData() []byte { + return e.Data +} + type SubscriptionEventConfiguration struct { - ProviderID string `json:"providerId"` - Topics []string `json:"topics"` + Provider string `json:"providerId"` + Topics []string `json:"topics"` + FieldName string `json:"rootFieldName"` } -type PublishEventConfiguration struct { - ProviderID string `json:"providerId"` - Topic string `json:"topic"` - Data json.RawMessage `json:"data"` +// ProviderID returns the provider ID +func (s *SubscriptionEventConfiguration) ProviderID() string { + return s.Provider } -func (s *PublishEventConfiguration) MarshalJSONTemplate() string { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. 
{"id":$$0$$,"update":$$1$$} - return fmt.Sprintf(`{"topic":"%s", "data": %s, "providerId":"%s"}`, s.Topic, s.Data, s.ProviderID) +// ProviderType returns the provider type +func (s *SubscriptionEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeKafka } -type SubscriptionDataSource struct { - pubSub Adapter +// RootFieldName returns the root field name +func (s *SubscriptionEventConfiguration) RootFieldName() string { + return s.FieldName } -func (s *SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - val, _, _, err := jsonparser.Get(input, "topics") - if err != nil { - return err - } +type PublishEventConfiguration struct { + Provider string `json:"providerId"` + Topic string `json:"topic"` + Event Event `json:"event"` + FieldName string `json:"rootFieldName"` +} - _, err = xxh.Write(val) - if err != nil { - return err - } +// ProviderID returns the provider ID +func (p *PublishEventConfiguration) ProviderID() string { + return p.Provider +} - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } +// ProviderType returns the provider type +func (p *PublishEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeKafka +} - _, err = xxh.Write(val) - return err +// RootFieldName returns the root field name +func (p *PublishEventConfiguration) RootFieldName() string { + return p.FieldName } -func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { - var subscriptionConfiguration SubscriptionEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) +func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { + // The content of the data field could be not valid JSON, so we can't use json.Marshal + // e.g. 
{"id":$$0$$,"update":$$1$$} + headers := s.Event.Headers + if headers == nil { + headers = make(map[string][]byte) + } + + headersBytes, err := json.Marshal(headers) if err != nil { - return err + return "", err } - return s.pubSub.Subscribe(ctx.Context(), subscriptionConfiguration, updater) + return fmt.Sprintf(`{"topic":"%s", "event": {"data": %s, "key": "%s", "headers": %s}, "providerId":"%s"}`, s.Topic, s.Event.Data, s.Event.Key, headersBytes, s.ProviderID()), nil } type PublishDataSource struct { @@ -70,8 +87,7 @@ type PublishDataSource struct { func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var publishConfiguration PublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) - if err != nil { + if err := json.Unmarshal(input, &publishConfiguration); err != nil { return err } @@ -79,10 +95,15 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B _, err = io.WriteString(out, `{"success": false}`) return err } - _, err = io.WriteString(out, `{"success": true}`) + _, err := io.WriteString(out, `{"success": true}`) return err } func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { panic("not implemented") } + +// Interface compliance checks +var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) +var _ datasource.PublishEventConfiguration = (*PublishEventConfiguration)(nil) +var _ datasource.StreamEvent = (*Event)(nil) diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory.go b/router/pkg/pubsub/kafka/engine_datasource_factory.go index d360f02f26..30507bc13b 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory.go @@ -4,6 +4,8 @@ import ( "encoding/json" "fmt" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" 
"github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -49,24 +51,44 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri } evtCfg := PublishEventConfiguration{ - ProviderID: c.providerId, - Topic: c.topics[0], - Data: eventData, + Provider: c.providerId, + Topic: c.topics[0], + Event: Event{Data: eventData}, + FieldName: c.fieldName, } - return evtCfg.MarshalJSONTemplate(), nil + return evtCfg.MarshalJSONTemplate() } -func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { - return &SubscriptionDataSource{ - pubSub: c.KafkaAdapter, - }, nil +func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.KafkaAdapter, + func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "topics") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err + }), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { evtCfg := SubscriptionEventConfiguration{ - ProviderID: c.providerId, - Topics: c.topics, + Provider: c.providerId, + Topics: c.topics, + FieldName: c.fieldName, } object, err := json.Marshal(evtCfg) if err != nil { diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory_test.go b/router/pkg/pubsub/kafka/engine_datasource_factory_test.go index 254359a4bc..0b4ea9c59c 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory_test.go @@ -4,11 +4,15 @@ import ( "bytes" "context" "encoding/json" + "errors" "testing" + 
"github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestKafkaEngineDataSourceFactory(t *testing.T) { @@ -33,7 +37,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Configure mock expectations for Publish mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { - return event.ProviderID == "test-provider" && event.Topic == "test-topic" + return event.ProviderID() == "test-provider" && event.Topic == "test-topic" })).Return(nil) // Create the data source with mock adapter @@ -137,3 +141,57 @@ func TestKafkaEngineDataSourceFactoryMultiTopicSubscription(t *testing.T) { require.Equal(t, "test-topic-1", subscriptionConfig.Topics[0], "Expected first topic to be 'test-topic-1'") require.Equal(t, "test-topic-2", subscriptionConfig.Topics[1], "Expected second topic to be 'test-topic-2'") } + +func TestKafkaEngineDataSourceFactory_UniqueRequestID(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + expectedError error + }{ + { + name: "valid input", + input: `{"topics":["topic1", "topic2"], "providerId":"test-provider"}`, + expectError: false, + }, + { + name: "missing topics", + input: `{"providerId":"test-provider"}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + { + name: "missing providerId", + input: `{"topics":["topic1", "topic2"]}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := &EngineDataSourceFactory{ + KafkaAdapter: NewMockAdapter(t), + } + source, err := factory.ResolveDataSourceSubscription() + require.NoError(t, err) + ctx := &resolve.Context{} + input := []byte(tt.input) + xxh := xxhash.New() + + 
err = source.UniqueRequestID(ctx, input, xxh) + + if tt.expectError { + require.Error(t, err) + if tt.expectedError != nil { + // For jsonparser errors, just check if the error message contains the expected text + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } + } else { + require.NoError(t, err) + // Check that the hash has been updated + assert.NotEqual(t, 0, xxh.Sum64()) + } + }) + } +} diff --git a/router/pkg/pubsub/kafka/engine_datasource_test.go b/router/pkg/pubsub/kafka/engine_datasource_test.go index 0ad92aeb20..eed485b246 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_test.go @@ -7,12 +7,9 @@ import ( "errors" "testing" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { @@ -24,145 +21,46 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: PublishEventConfiguration{ - ProviderID: "test-provider", - Topic: "test-topic", - Data: json.RawMessage(`{"message":"hello"}`), + Provider: "test-provider", + Topic: "test-topic", + Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, }, - wantPattern: `{"topic":"test-topic", "data": {"message":"hello"}, "providerId":"test-provider"}`, + wantPattern: `{"topic":"test-topic", "event": {"data": {"message":"hello"}, "key": "", "headers": {}}, "providerId":"test-provider"}`, }, { name: "with special characters", config: PublishEventConfiguration{ - ProviderID: "test-provider-id", - Topic: "topic-with-hyphens", - Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: Event{Data: json.RawMessage(`{"message":"special 
\"quotes\" here"}`)}, }, - wantPattern: `{"topic":"topic-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}, "key": "", "headers": {}}, "providerId":"test-provider-id"}`, }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.config.MarshalJSONTemplate() - assert.Equal(t, tt.wantPattern, result) - }) - } -} - -func TestSubscriptionSource_UniqueRequestID(t *testing.T) { - tests := []struct { - name string - input string - expectError bool - expectedError error - }{ - { - name: "valid input", - input: `{"topics":["topic1", "topic2"], "providerId":"test-provider"}`, - expectError: false, - }, - { - name: "missing topics", - input: `{"providerId":"test-provider"}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - { - name: "missing providerId", - input: `{"topics":["topic1", "topic2"]}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - source := &SubscriptionDataSource{ - pubSub: NewMockAdapter(t), - } - ctx := &resolve.Context{} - input := []byte(tt.input) - xxh := xxhash.New() - - err := source.UniqueRequestID(ctx, input, xxh) - - if tt.expectError { - require.Error(t, err) - if tt.expectedError != nil { - // For jsonparser errors, just check if the error message contains the expected text - assert.Contains(t, err.Error(), tt.expectedError.Error()) - } - } else { - require.NoError(t, err) - // Check that the hash has been updated - assert.NotEqual(t, 0, xxh.Sum64()) - } - }) - } -} - -func TestSubscriptionSource_Start(t *testing.T) { - tests := []struct { - name string - input string - mockSetup func(*MockAdapter, *datasource.MockSubscriptionUpdater) - expectError bool - }{ { - name: "successful subscription", - input: `{"topics":["topic1", "topic2"], 
"providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Topics: []string{"topic1", "topic2"}, - }, mock.Anything).Return(nil) + name: "with key", + config: PublishEventConfiguration{ + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: Event{Key: []byte("blablabla"), Data: json.RawMessage(`{}`)}, }, - expectError: false, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "blablabla", "headers": {}}, "providerId":"test-provider-id"}`, }, { - name: "adapter returns error", - input: `{"topics":["topic1"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Topics: []string{"topic1"}, - }, mock.Anything).Return(errors.New("subscription error")) + name: "with headers", + config: PublishEventConfiguration{ + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: Event{Headers: map[string][]byte{"key": []byte(`blablabla`)}, Data: json.RawMessage(`{}`)}, }, - expectError: true, - }, - { - name: "invalid input json", - input: `{"invalid json":`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) {}, - expectError: true, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "", "headers": {"key":"YmxhYmxhYmxh"}}, "providerId":"test-provider-id"}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) - updater := datasource.NewMockSubscriptionUpdater(t) - tt.mockSetup(mockAdapter, updater) - - source := &SubscriptionDataSource{ - pubSub: mockAdapter, - } - - // Set up go context - goCtx := context.Background() - - // Create a resolve.Context with the standard context - resolveCtx := &resolve.Context{} - 
resolveCtx = resolveCtx.WithContext(goCtx) - - input := []byte(tt.input) - err := source.Start(resolveCtx, input, updater) - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } + result, err := tt.config.MarshalJSONTemplate() + assert.NoError(t, err) + assert.Equal(t, tt.wantPattern, result) }) } } @@ -178,12 +76,12 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { }{ { name: "successful publish", - input: `{"topic":"test-topic", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"topic":"test-topic", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { - return event.ProviderID == "test-provider" && + return event.ProviderID() == "test-provider" && event.Topic == "test-topic" && - string(event.Data) == `{"message":"hello"}` + string(event.Event.Data) == `{"message":"hello"}` })).Return(nil) }, expectError: false, @@ -192,7 +90,7 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { }, { name: "publish error", - input: `{"topic":"test-topic", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"topic":"test-topic", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) }, diff --git a/router/pkg/pubsub/kafka/mocks.go b/router/pkg/pubsub/kafka/mocks.go index f39aee8b4e..08faa08eb2 100644 --- a/router/pkg/pubsub/kafka/mocks.go +++ b/router/pkg/pubsub/kafka/mocks.go @@ -8,7 +8,7 @@ import ( "context" mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) // NewMockAdapter creates a new instance of MockAdapter. 
It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. @@ -198,7 +198,7 @@ func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) e } // Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { ret := _mock.Called(ctx, event, updater) if len(ret) == 0 { @@ -206,7 +206,7 @@ func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, SubscriptionEventConfiguration, resolve.SubscriptionUpdater) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { r0 = returnFunc(ctx, event, updater) } else { r0 = ret.Error(0) @@ -221,25 +221,25 @@ type MockAdapter_Subscribe_Call struct { // Subscribe is a helper method to define mock.On call // - ctx context.Context -// - event SubscriptionEventConfiguration -// - updater resolve.SubscriptionUpdater +// - event datasource.SubscriptionEventConfiguration +// - updater datasource.SubscriptionEventUpdater func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} } -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater)) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) 
*MockAdapter_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 SubscriptionEventConfiguration + var arg1 datasource.SubscriptionEventConfiguration if args[1] != nil { - arg1 = args[1].(SubscriptionEventConfiguration) + arg1 = args[1].(datasource.SubscriptionEventConfiguration) } - var arg2 resolve.SubscriptionUpdater + var arg2 datasource.SubscriptionEventUpdater if args[2] != nil { - arg2 = args[2].(resolve.SubscriptionUpdater) + arg2 = args[2].(datasource.SubscriptionEventUpdater) } run( arg0, @@ -255,7 +255,7 @@ func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_C return _c } -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { _c.Call.Return(run) return _c } diff --git a/router/pkg/pubsub/nats/adapter.go b/router/pkg/pubsub/nats/adapter.go index d10f8cf93d..dcba74a03b 100644 --- a/router/pkg/pubsub/nats/adapter.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -14,7 +14,6 @@ import ( "github.com/nats-io/nats.go" "github.com/nats-io/nats.go/jetstream" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -27,7 +26,7 @@ const ( // Adapter defines the methods that a NATS adapter should implement type Adapter interface { // Subscribe subscribes to the given events and sends updates to the updater - Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error + Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater 
datasource.SubscriptionEventUpdater) error // Publish publishes the given event to the specified subject Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error // Request sends a request to the specified subject and writes the response to the given writer @@ -81,11 +80,15 @@ func (p *ProviderAdapter) getDurableConsumerName(durableName string, subjects [] return fmt.Sprintf("%s-%x", durableName, subjHash.Sum64()), nil } -func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + subConf, ok := conf.(*SubscriptionEventConfiguration) + if !ok { + return datasource.NewError("invalid event type for NATS adapter", nil) + } log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", subConf.ProviderID()), zap.String("method", "subscribe"), - zap.Strings("subjects", event.Subjects), + zap.Strings("subjects", subConf.Subjects), ) if p.client == nil { @@ -96,24 +99,24 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent return datasource.NewError("nats jetstream not initialized", nil) } - if event.StreamConfiguration != nil { - durableConsumerName, err := p.getDurableConsumerName(event.StreamConfiguration.Consumer, event.Subjects) + if subConf.StreamConfiguration != nil { + durableConsumerName, err := p.getDurableConsumerName(subConf.StreamConfiguration.Consumer, subConf.Subjects) if err != nil { return err } consumerConfig := jetstream.ConsumerConfig{ Durable: durableConsumerName, - FilterSubjects: event.Subjects, + FilterSubjects: subConf.Subjects, } // Durable consumers are removed automatically only if the InactiveThreshold value is set - if event.StreamConfiguration.ConsumerInactiveThreshold > 0 { - consumerConfig.InactiveThreshold = 
time.Duration(event.StreamConfiguration.ConsumerInactiveThreshold) * time.Second + if subConf.StreamConfiguration.ConsumerInactiveThreshold > 0 { + consumerConfig.InactiveThreshold = time.Duration(subConf.StreamConfiguration.ConsumerInactiveThreshold) * time.Second } - consumer, err := p.js.CreateOrUpdateConsumer(ctx, event.StreamConfiguration.StreamName, consumerConfig) + consumer, err := p.js.CreateOrUpdateConsumer(ctx, subConf.StreamConfiguration.StreamName, consumerConfig) if err != nil { log.Error("creating or updating consumer", zap.Error(err)) - return datasource.NewError(fmt.Sprintf(`failed to create or update consumer for stream "%s"`, event.StreamConfiguration.StreamName), err) + return datasource.NewError(fmt.Sprintf(`failed to create or update consumer for stream "%s"`, subConf.StreamConfiguration.StreamName), err) } p.closeWg.Add(1) @@ -142,12 +145,16 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent log.Debug("subscription update", zap.String("message_subject", msg.Subject()), zap.ByteString("data", msg.Data())) p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: conf.ProviderID(), StreamOperationName: natsReceive, ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject(), }) - updater.Update(msg.Data()) + + updater.Update(&Event{ + Data: msg.Data(), + Headers: msg.Headers(), + }) // Acknowledge the message after it has been processed ackErr := msg.Ack() @@ -165,8 +172,8 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } msgChan := make(chan *nats.Msg) - subscriptions := make([]*nats.Subscription, len(event.Subjects)) - for i, subject := range event.Subjects { + subscriptions := make([]*nats.Subscription, len(subConf.Subjects)) + for i, subject := range subConf.Subjects { subscription, err := p.client.ChanSubscribe(subject, msgChan) if err != nil { log.Error("subscribing to NATS subject", zap.Error(err), 
zap.String("subscription_subject", subject)) @@ -184,13 +191,18 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent select { case msg := <-msgChan: log.Debug("subscription update", zap.String("message_subject", msg.Subject), zap.ByteString("data", msg.Data)) + p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: conf.ProviderID(), StreamOperationName: natsReceive, ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject, }) - updater.Update(msg.Data) + + updater.Update(&Event{ + Data: msg.Data, + Headers: msg.Header, + }) case <-p.ctx.Done(): // When the application context is done, we stop the subscriptions for _, subscription := range subscriptions { @@ -220,7 +232,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error { log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", event.ProviderID()), zap.String("method", "publish"), zap.String("subject", event.Subject), ) @@ -229,13 +241,13 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEv return datasource.NewError("nats client not initialized", nil) } - log.Debug("publish", zap.ByteString("data", event.Data)) + log.Debug("publish", zap.ByteString("data", event.Event.Data)) - err := p.client.Publish(event.Subject, event.Data) + err := p.client.Publish(event.Subject, event.Event.Data) if err != nil { log.Error("publish error", zap.Error(err)) p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: natsPublish, ProviderType: metric.ProviderTypeNats, ErrorType: "publish_error", @@ -244,7 +256,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEv return datasource.NewError(fmt.Sprintf("error publishing to 
NATS subject %s", event.Subject), err) } else { p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: natsPublish, ProviderType: metric.ProviderTypeNats, DestinationName: event.Subject, @@ -256,7 +268,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEv func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error { log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", event.ProviderID()), zap.String("method", "request"), zap.String("subject", event.Subject), ) @@ -265,13 +277,13 @@ func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEv return datasource.NewError("nats client not initialized", nil) } - log.Debug("request", zap.ByteString("data", event.Data)) + log.Debug("request", zap.ByteString("data", event.Event.Data)) - msg, err := p.client.RequestWithContext(ctx, event.Subject, event.Data) + msg, err := p.client.RequestWithContext(ctx, event.Subject, event.Event.Data) if err != nil { log.Error("request error", zap.Error(err)) p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: natsRequest, ProviderType: metric.ProviderTypeNats, ErrorType: "request_error", @@ -281,7 +293,7 @@ func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEv } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: natsRequest, ProviderType: metric.ProviderTypeNats, DestinationName: event.Subject, diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index ffc23ca838..0fa41e5480 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -7,12 +7,20 @@ 
import ( "fmt" "io" - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) +// Event represents an event from NATS +type Event struct { + Data json.RawMessage `json:"data"` + Headers map[string][]string `json:"headers"` +} + +func (e *Event) GetData() []byte { + return e.Data +} + type StreamConfiguration struct { Consumer string `json:"consumer"` ConsumerInactiveThreshold int32 `json:"consumerInactiveThreshold"` @@ -20,56 +28,53 @@ type StreamConfiguration struct { } type SubscriptionEventConfiguration struct { - ProviderID string `json:"providerId"` + Provider string `json:"providerId"` Subjects []string `json:"subjects"` StreamConfiguration *StreamConfiguration `json:"streamConfiguration,omitempty"` + FieldName string `json:"rootFieldName"` } -type PublishAndRequestEventConfiguration struct { - ProviderID string `json:"providerId"` - Subject string `json:"subject"` - Data json.RawMessage `json:"data"` +// ProviderID returns the provider ID +func (s *SubscriptionEventConfiguration) ProviderID() string { + return s.Provider } -func (s *PublishAndRequestEventConfiguration) MarshalJSONTemplate() string { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. 
{"id":$$0$$,"update":$$1$$} - return fmt.Sprintf(`{"subject":"%s", "data": %s, "providerId":"%s"}`, s.Subject, s.Data, s.ProviderID) +// ProviderType returns the provider type +func (s *SubscriptionEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeNats } -type SubscriptionSource struct { - pubSub Adapter +// RootFieldName returns the root field name +func (s *SubscriptionEventConfiguration) RootFieldName() string { + return s.FieldName } -func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - - val, _, _, err := jsonparser.Get(input, "subjects") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } +type PublishAndRequestEventConfiguration struct { + Provider string `json:"providerId"` + Subject string `json:"subject"` + Event Event `json:"event"` + FieldName string `json:"rootFieldName"` +} - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } +// ProviderID returns the provider ID +func (p *PublishAndRequestEventConfiguration) ProviderID() string { + return p.Provider +} - _, err = xxh.Write(val) - return err +// ProviderType returns the provider type +func (p *PublishAndRequestEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeNats } -func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { - var subscriptionConfiguration SubscriptionEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { - return err - } +// RootFieldName returns the root field name +func (p *PublishAndRequestEventConfiguration) RootFieldName() string { + return p.FieldName +} - return s.pubSub.Subscribe(ctx.Context(), subscriptionConfiguration, updater) +func (p *PublishAndRequestEventConfiguration) MarshalJSONTemplate() (string, error) { + // The content of the data field could 
be not valid JSON, so we can't use json.Marshal + // e.g. {"id":$$0$$,"update":$$1$$} + return fmt.Sprintf(`{"subject":"%s", "event": {"data": %s}, "providerId":"%s"}`, p.Subject, p.Event.Data, p.ProviderID()), nil } type NatsPublishDataSource struct { @@ -78,8 +83,7 @@ type NatsPublishDataSource struct { func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var publishConfiguration PublishAndRequestEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) - if err != nil { + if err := json.Unmarshal(input, &publishConfiguration); err != nil { return err } @@ -87,7 +91,7 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *byt _, err = io.WriteString(out, `{"success": false}`) return err } - _, err = io.WriteString(out, `{"success": true}`) + _, err := io.WriteString(out, `{"success": true}`) return err } @@ -101,8 +105,7 @@ type NatsRequestDataSource struct { func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var subscriptionConfiguration PublishAndRequestEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { + if err := json.Unmarshal(input, &subscriptionConfiguration); err != nil { return err } @@ -112,3 +115,8 @@ func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *byt func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { panic("not implemented") } + +// Interface compliance checks +var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) +var _ datasource.PublishEventConfiguration = (*PublishAndRequestEventConfiguration)(nil) +var _ datasource.StreamEvent = (*Event)(nil) diff --git a/router/pkg/pubsub/nats/engine_datasource_factory.go b/router/pkg/pubsub/nats/engine_datasource_factory.go index 48fd2849f7..36d3932e0d 100644 --- 
a/router/pkg/pubsub/nats/engine_datasource_factory.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory.go @@ -5,6 +5,8 @@ import ( "fmt" "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -63,24 +65,44 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri subject := c.subjects[0] evtCfg := PublishAndRequestEventConfiguration{ - ProviderID: c.providerId, - Subject: subject, - Data: eventData, + Provider: c.providerId, + Subject: subject, + Event: Event{Data: eventData}, + FieldName: c.fieldName, } - return evtCfg.MarshalJSONTemplate(), nil + return evtCfg.MarshalJSONTemplate() } -func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { - return &SubscriptionSource{ - pubSub: c.NatsAdapter, - }, nil +func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.NatsAdapter, + func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "subjects") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err + }), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { evtCfg := SubscriptionEventConfiguration{ - ProviderID: c.providerId, - Subjects: c.subjects, + Provider: c.providerId, + Subjects: c.subjects, + FieldName: c.fieldName, } if c.withStreamConfiguration { evtCfg.StreamConfiguration = &StreamConfiguration{ diff --git a/router/pkg/pubsub/nats/engine_datasource_factory_test.go b/router/pkg/pubsub/nats/engine_datasource_factory_test.go 
index 57426ad34c..a94c8d5941 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory_test.go @@ -4,13 +4,16 @@ import ( "bytes" "context" "encoding/json" + "errors" "io" "testing" + "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestNatsEngineDataSourceFactory(t *testing.T) { @@ -34,7 +37,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Configure mock expectations for Publish mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { - return event.ProviderID == "test-provider" && event.Subject == "test-subject" + return event.ProviderID() == "test-provider" && event.Subject == "test-subject" })).Return(nil) // Create the data source with mock adapter @@ -167,7 +170,7 @@ func TestEngineDataSourceFactory_RequestDataSource(t *testing.T) { // Configure mock expectations for Request mockAdapter.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { - return event.ProviderID == "test-provider" && event.Subject == "test-subject" + return event.ProviderID() == "test-provider" && event.Subject == "test-subject" }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { w := args.Get(2).(io.Writer) w.Write([]byte(`{"response": "test"}`)) @@ -253,3 +256,57 @@ func TestTransformEventConfig(t *testing.T) { assert.Contains(t, err.Error(), "invalid subject") }) } + +func TestNatsEngineDataSourceFactory_UniqueRequestID(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + expectedError error + }{ + { + name: "valid input", + input: `{"subjects":["subject1", "subject2"], "providerId":"test-provider"}`, + expectError: false, + }, + { + 
name: "missing subjects", + input: `{"providerId":"test-provider"}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + { + name: "missing providerId", + input: `{"subjects":["subject1", "subject2"]}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + factory := &EngineDataSourceFactory{ + NatsAdapter: NewMockAdapter(t), + } + source, err := factory.ResolveDataSourceSubscription() + require.NoError(t, err) + ctx := &resolve.Context{} + input := []byte(tt.input) + xxh := xxhash.New() + + err = source.UniqueRequestID(ctx, input, xxh) + + if tt.expectError { + require.Error(t, err) + if tt.expectedError != nil { + // For jsonparser errors, just check if the error message contains the expected text + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } + } else { + require.NoError(t, err) + // Check that the hash has been updated + assert.NotEqual(t, 0, xxh.Sum64()) + } + }) + } +} diff --git a/router/pkg/pubsub/nats/engine_datasource_test.go b/router/pkg/pubsub/nats/engine_datasource_test.go index da21d4de88..5d060d2c0d 100644 --- a/router/pkg/pubsub/nats/engine_datasource_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_test.go @@ -8,48 +8,11 @@ import ( "io" "testing" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { - tests := []struct { - name string - config PublishAndRequestEventConfiguration - wantPattern string - }{ - { - name: "simple configuration", - config: PublishAndRequestEventConfiguration{ - ProviderID: "test-provider", - Subject: "test-subject", - Data: json.RawMessage(`{"message":"hello"}`), - }, - wantPattern: 
`{"subject":"test-subject", "data": {"message":"hello"}, "providerId":"test-provider"}`, - }, - { - name: "with special characters", - config: PublishAndRequestEventConfiguration{ - ProviderID: "test-provider-id", - Subject: "subject-with-hyphens", - Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), - }, - wantPattern: `{"subject":"subject-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := tt.config.MarshalJSONTemplate() - assert.Equal(t, tt.wantPattern, result) - }) - } -} - func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { tests := []struct { name string @@ -59,149 +22,32 @@ func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: PublishAndRequestEventConfiguration{ - ProviderID: "test-provider", - Subject: "test-subject", - Data: json.RawMessage(`{"message":"hello"}`), + Provider: "test-provider", + Subject: "test-subject", + Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, }, - wantPattern: `{"subject":"test-subject", "data": {"message":"hello"}, "providerId":"test-provider"}`, + wantPattern: `{"subject":"test-subject", "event": {"data": {"message":"hello"}}, "providerId":"test-provider"}`, }, { name: "with special characters", config: PublishAndRequestEventConfiguration{ - ProviderID: "test-provider-id", - Subject: "subject-with-hyphens", - Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + Provider: "test-provider-id", + Subject: "subject-with-hyphens", + Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, }, - wantPattern: `{"subject":"subject-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + wantPattern: `{"subject":"subject-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, 
"providerId":"test-provider-id"}`, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := tt.config.MarshalJSONTemplate() + result, err := tt.config.MarshalJSONTemplate() + assert.NoError(t, err) assert.Equal(t, tt.wantPattern, result) }) } } -func TestSubscriptionSource_UniqueRequestID(t *testing.T) { - tests := []struct { - name string - input string - expectError bool - expectedError error - }{ - { - name: "valid input", - input: `{"subjects":["subject1", "subject2"], "providerId":"test-provider"}`, - expectError: false, - }, - { - name: "missing subjects", - input: `{"providerId":"test-provider"}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - { - name: "missing providerId", - input: `{"subjects":["subject1", "subject2"]}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - source := &SubscriptionSource{ - pubSub: NewMockAdapter(t), - } - ctx := &resolve.Context{} - input := []byte(tt.input) - xxh := xxhash.New() - - err := source.UniqueRequestID(ctx, input, xxh) - - if tt.expectError { - require.Error(t, err) - if tt.expectedError != nil { - // For jsonparser errors, just check if the error message contains the expected text - assert.Contains(t, err.Error(), tt.expectedError.Error()) - } - } else { - require.NoError(t, err) - // Check that the hash has been updated - assert.NotEqual(t, 0, xxh.Sum64()) - } - }) - } -} - -func TestSubscriptionSource_Start(t *testing.T) { - tests := []struct { - name string - input string - mockSetup func(*MockAdapter, *datasource.MockSubscriptionUpdater) - expectError bool - }{ - { - name: "successful subscription", - input: `{"subjects":["subject1", "subject2"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - 
Subjects: []string{"subject1", "subject2"}, - }, mock.Anything).Return(nil) - }, - expectError: false, - }, - { - name: "adapter returns error", - input: `{"subjects":["subject1"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Subjects: []string{"subject1"}, - }, mock.Anything).Return(errors.New("subscription error")) - }, - expectError: true, - }, - { - name: "invalid input json", - input: `{"invalid json":`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) {}, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) - updater := datasource.NewMockSubscriptionUpdater(t) - tt.mockSetup(mockAdapter, updater) - - source := &SubscriptionSource{ - pubSub: mockAdapter, - } - - // Set up go context - goCtx := context.Background() - - // Create a resolve.Context with the standard context - resolveCtx := &resolve.Context{} - resolveCtx = resolveCtx.WithContext(goCtx) - - input := []byte(tt.input) - err := source.Start(resolveCtx, input, updater) - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - }) - } -} - func TestNatsPublishDataSource_Load(t *testing.T) { tests := []struct { name string @@ -213,12 +59,12 @@ func TestNatsPublishDataSource_Load(t *testing.T) { }{ { name: "successful publish", - input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { - return event.ProviderID == "test-provider" && + return event.ProviderID() == "test-provider" && event.Subject == "test-subject" && - 
string(event.Data) == `{"message":"hello"}` + string(event.Event.Data) == `{"message":"hello"}` })).Return(nil) }, expectError: false, @@ -227,7 +73,7 @@ func TestNatsPublishDataSource_Load(t *testing.T) { }, { name: "publish error", - input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) }, @@ -288,12 +134,12 @@ func TestNatsRequestDataSource_Load(t *testing.T) { }{ { name: "successful request", - input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { - return event.ProviderID == "test-provider" && + return event.ProviderID() == "test-provider" && event.Subject == "test-subject" && - string(event.Data) == `{"message":"hello"}` + string(event.Event.Data) == `{"message":"hello"}` }), mock.Anything).Run(func(args mock.Arguments) { // Write response to the output buffer w := args.Get(2).(io.Writer) @@ -305,7 +151,7 @@ func TestNatsRequestDataSource_Load(t *testing.T) { }, { name: "request error", - input: `{"subject":"test-subject", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Request", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("request error")) }, diff --git a/router/pkg/pubsub/nats/mocks.go b/router/pkg/pubsub/nats/mocks.go index de49c6ae7e..0bc3ada5f0 100644 --- a/router/pkg/pubsub/nats/mocks.go +++ b/router/pkg/pubsub/nats/mocks.go @@ -9,7 +9,7 
@@ import ( "io" mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) // NewMockAdapter creates a new instance of MockAdapter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. @@ -262,7 +262,7 @@ func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) e } // Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { ret := _mock.Called(ctx, event, updater) if len(ret) == 0 { @@ -270,7 +270,7 @@ func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, SubscriptionEventConfiguration, resolve.SubscriptionUpdater) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { r0 = returnFunc(ctx, event, updater) } else { r0 = ret.Error(0) @@ -285,25 +285,25 @@ type MockAdapter_Subscribe_Call struct { // Subscribe is a helper method to define mock.On call // - ctx context.Context -// - event SubscriptionEventConfiguration -// - updater resolve.SubscriptionUpdater +// - event datasource.SubscriptionEventConfiguration +// - updater datasource.SubscriptionEventUpdater func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} } -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event SubscriptionEventConfiguration, 
updater resolve.SubscriptionUpdater)) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 SubscriptionEventConfiguration + var arg1 datasource.SubscriptionEventConfiguration if args[1] != nil { - arg1 = args[1].(SubscriptionEventConfiguration) + arg1 = args[1].(datasource.SubscriptionEventConfiguration) } - var arg2 resolve.SubscriptionUpdater + var arg2 datasource.SubscriptionEventUpdater if args[2] != nil { - arg2 = args[2].(resolve.SubscriptionUpdater) + arg2 = args[2].(datasource.SubscriptionEventUpdater) } run( arg0, @@ -319,7 +319,7 @@ func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_C return _c } -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { _c.Call.Return(run) return _c } diff --git a/router/pkg/pubsub/pubsub.go b/router/pkg/pubsub/pubsub.go index b92aaad6f7..085de71a0e 100644 --- a/router/pkg/pubsub/pubsub.go +++ b/router/pkg/pubsub/pubsub.go @@ -51,9 +51,23 @@ func (e *ProviderNotDefinedError) Error() string { return fmt.Sprintf("%s provider with ID %s is not defined", e.ProviderTypeID, e.ProviderID) } +// Hooks contains hooks for the pubsub providers and data sources +type Hooks struct { + SubscriptionOnStart []pubsub_datasource.SubscriptionOnStartFn +} + // BuildProvidersAndDataSources is a generic function that builds providers and data sources for the given // 
EventsConfiguration and DataSourceConfigurationWithMetadata -func BuildProvidersAndDataSources(ctx context.Context, config config.EventsConfiguration, store metric.StreamMetricStore, logger *zap.Logger, dsConfs []DataSourceConfigurationWithMetadata, hostName string, routerListenAddr string) ([]pubsub_datasource.Provider, []plan.DataSource, error) { +func BuildProvidersAndDataSources( + ctx context.Context, + config config.EventsConfiguration, + store metric.StreamMetricStore, + logger *zap.Logger, + dsConfs []DataSourceConfigurationWithMetadata, + hostName string, + routerListenAddr string, + hooks Hooks, +) ([]pubsub_datasource.Provider, []plan.DataSource, error) { if store == nil { store = metric.NewNoopStreamMetricStore() } @@ -70,7 +84,7 @@ func BuildProvidersAndDataSources(ctx context.Context, config config.EventsConfi events: dsConf.Configuration.GetCustomEvents().GetKafka(), }) } - kafkaPubSubProviders, kafkaOuts, err := build(ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents, store) + kafkaPubSubProviders, kafkaOuts, err := build(ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents, store, hooks) if err != nil { return nil, nil, err } @@ -86,7 +100,7 @@ func BuildProvidersAndDataSources(ctx context.Context, config config.EventsConfi events: dsConf.Configuration.GetCustomEvents().GetNats(), }) } - natsPubSubProviders, natsOuts, err := build(ctx, natsBuilder, config.Providers.Nats, natsDsConfsWithEvents, store) + natsPubSubProviders, natsOuts, err := build(ctx, natsBuilder, config.Providers.Nats, natsDsConfsWithEvents, store, hooks) if err != nil { return nil, nil, err } @@ -102,7 +116,7 @@ func BuildProvidersAndDataSources(ctx context.Context, config config.EventsConfi events: dsConf.Configuration.GetCustomEvents().GetRedis(), }) } - redisPubSubProviders, redisOuts, err := build(ctx, redisBuilder, config.Providers.Redis, redisDsConfsWithEvents, store) + redisPubSubProviders, redisOuts, err := build(ctx, redisBuilder, 
config.Providers.Redis, redisDsConfsWithEvents, store, hooks) if err != nil { return nil, nil, err } @@ -118,6 +132,7 @@ func build[P GetID, E GetEngineEventConfiguration]( providersData []P, dsConfs []dsConfAndEvents[E], store metric.StreamMetricStore, + hooks Hooks, ) ([]pubsub_datasource.Provider, []plan.DataSource, error) { var pubSubProviders []pubsub_datasource.Provider var outs []plan.DataSource @@ -161,7 +176,7 @@ func build[P GetID, E GetEngineEventConfiguration]( // build data sources for each event for _, dsConf := range dsConfs { for i, event := range dsConf.events { - plannerConfig := pubsub_datasource.NewPlannerConfig(builder, event) + plannerConfig := pubsub_datasource.NewPlannerConfig(builder, event, hooks.SubscriptionOnStart) out, err := plan.NewDataSourceConfiguration( dsConf.dsConf.Configuration.Id+"-"+builder.TypeID()+"-"+strconv.Itoa(i), pubsub_datasource.NewPlannerFactory(ctx, plannerConfig), diff --git a/router/pkg/pubsub/pubsub_test.go b/router/pkg/pubsub/pubsub_test.go index 2173e46c3d..976980b4ff 100644 --- a/router/pkg/pubsub/pubsub_test.go +++ b/router/pkg/pubsub/pubsub_test.go @@ -3,9 +3,10 @@ package pubsub import ( "context" "errors" + "testing" + "github.com/stretchr/testify/mock" rmetric "github.com/wundergraph/cosmo/router/pkg/metric" - "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -67,7 +68,7 @@ func TestBuild_OK(t *testing.T) { // ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore()) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) // Assertions assert.NoError(t, err) @@ -123,7 +124,7 @@ func TestBuild_ProviderError(t *testing.T) { mockBuilder.On("BuildProvider", natsEventSources[0], mock.Anything).Return(nil, errors.New("provider error")) // Execute the 
function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore()) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) // Assertions assert.Error(t, err) @@ -178,7 +179,7 @@ func TestBuild_ShouldGetAnErrorIfProviderIsNotDefined(t *testing.T) { mockBuilder.On("TypeID").Return("nats") // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore()) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) // Assertions assert.Error(t, err) @@ -242,7 +243,7 @@ func TestBuild_ShouldNotInitializeProviderIfNotUsed(t *testing.T) { Return(mockPubSubUsedProvider, nil) // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore()) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) // Assertions assert.NoError(t, err) @@ -293,7 +294,7 @@ func TestBuildProvidersAndDataSources_Nats_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr") + }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) // Assertions assert.NoError(t, err) @@ -346,7 +347,7 @@ func TestBuildProvidersAndDataSources_Kafka_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr") + }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) // Assertions assert.NoError(t, err) @@ -399,7 +400,7 @@ func TestBuildProvidersAndDataSources_Redis_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr") + }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) // Assertions assert.NoError(t, err) diff --git a/router/pkg/pubsub/redis/adapter.go 
b/router/pkg/pubsub/redis/adapter.go index 8de962d2b6..5cb0055a36 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -9,7 +9,6 @@ import ( rd "github.com/wundergraph/cosmo/router/internal/rediscloser" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -21,7 +20,7 @@ const ( // Adapter defines the methods that a Redis adapter should implement type Adapter interface { // Subscribe subscribes to the given events and sends updates to the updater - Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error + Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error // Publish publishes the given event to the specified channel Publish(ctx context.Context, event PublishEventConfiguration) error // Startup initializes the adapter @@ -94,19 +93,23 @@ func (p *ProviderAdapter) Shutdown(ctx context.Context) error { return p.conn.Close() } -func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + subConf, ok := conf.(*SubscriptionEventConfiguration) + if !ok { + return datasource.NewError("invalid event type for Kafka adapter", nil) + } log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", subConf.ProviderID()), zap.String("method", "subscribe"), - zap.Strings("channels", event.Channels), + zap.Strings("channels", subConf.Channels), ) - sub := p.conn.PSubscribe(ctx, event.Channels...) + sub := p.conn.PSubscribe(ctx, subConf.Channels...) msgChan := sub.Channel() cleanup := func() { - err := sub.PUnsubscribe(ctx, event.Channels...) 
+ err := sub.PUnsubscribe(ctx, subConf.Channels...) if err != nil { - log.Error(fmt.Sprintf("error unsubscribing from redis for topics %v", event.Channels), zap.Error(err)) + log.Error(fmt.Sprintf("error unsubscribing from redis for topics %v", subConf.Channels), zap.Error(err)) } } @@ -128,12 +131,14 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } log.Debug("subscription update", zap.String("message_channel", msg.Channel), zap.String("data", msg.Payload)) p.streamMetricStore.Consume(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: conf.ProviderID(), StreamOperationName: redisReceive, ProviderType: metric.ProviderTypeRedis, DestinationName: msg.Channel, }) - updater.Update([]byte(msg.Payload)) + updater.Update(&Event{ + Data: []byte(msg.Payload), + }) case <-p.ctx.Done(): // When the application context is done, we stop the subscription if it is not already done log.Debug("application context done, stopping subscription") @@ -153,14 +158,14 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, event SubscriptionEvent func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { log := p.logger.With( - zap.String("provider_id", event.ProviderID), + zap.String("provider_id", event.ProviderID()), zap.String("method", "publish"), zap.String("channel", event.Channel), ) - log.Debug("publish", zap.ByteString("data", event.Data)) + log.Debug("publish", zap.ByteString("data", event.Event.Data)) - data, dataErr := event.Data.MarshalJSON() + data, dataErr := event.Event.Data.MarshalJSON() if dataErr != nil { log.Error("error marshalling data", zap.Error(dataErr)) return datasource.NewError("error marshalling data", dataErr) @@ -172,7 +177,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu if intCmd.Err() != nil { log.Error("publish error", zap.Error(intCmd.Err())) p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: 
event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: redisPublish, ProviderType: metric.ProviderTypeRedis, ErrorType: "publish_error", @@ -182,7 +187,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfigu } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID, + ProviderId: event.ProviderID(), StreamOperationName: redisPublish, ProviderType: metric.ProviderTypeRedis, DestinationName: event.Channel, diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index d24a4fb959..3a685fe9b0 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -7,69 +7,66 @@ import ( "fmt" "io" - "github.com/buger/jsonparser" - "github.com/cespare/xxhash/v2" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// SubscriptionEventConfiguration contains configuration for subscription events -type SubscriptionEventConfiguration struct { - ProviderID string `json:"providerId"` - Channels []string `json:"channels"` +// Event represents an event from Redis +type Event struct { + Data json.RawMessage `json:"data"` } -// PublishEventConfiguration contains configuration for publish events -type PublishEventConfiguration struct { - ProviderID string `json:"providerId"` - Channel string `json:"channel"` - Data json.RawMessage `json:"data"` +func (e *Event) GetData() []byte { + return e.Data } -func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { - return fmt.Sprintf(`{"channel":"%s", "data": %s, "providerId":"%s"}`, s.Channel, s.Data, s.ProviderID), nil +// SubscriptionEventConfiguration contains configuration for subscription events +type SubscriptionEventConfiguration struct { + Provider string `json:"providerId"` + Channels []string 
`json:"channels"` + FieldName string `json:"rootFieldName"` } -// SubscriptionDataSource implements resolve.SubscriptionDataSource for Redis -type SubscriptionDataSource struct { - pubSub Adapter +// ProviderID returns the provider ID +func (s *SubscriptionEventConfiguration) ProviderID() string { + return s.Provider } -// UniqueRequestID computes a unique ID for the subscription request -func (s *SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - val, _, _, err := jsonparser.Get(input, "channels") - if err != nil { - return err - } +// ProviderType returns the provider type +func (s *SubscriptionEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeRedis +} - _, err = xxh.Write(val) - if err != nil { - return err - } +// RootFieldName returns the root field name +func (s *SubscriptionEventConfiguration) RootFieldName() string { + return s.FieldName +} - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } +// PublishEventConfiguration contains configuration for publish events +type PublishEventConfiguration struct { + Provider string `json:"providerId"` + Channel string `json:"channel"` + Event Event `json:"event"` + FieldName string `json:"rootFieldName"` +} - _, err = xxh.Write(val) - return err +// ProviderID returns the provider ID +func (p *PublishEventConfiguration) ProviderID() string { + return p.Provider } -// Start starts the subscription -func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error { - var subscriptionConfiguration SubscriptionEventConfiguration - err := json.Unmarshal(input, &subscriptionConfiguration) - if err != nil { - return err - } +// ProviderType returns the provider type +func (p *PublishEventConfiguration) ProviderType() datasource.ProviderType { + return datasource.ProviderTypeRedis +} - return s.pubSub.Subscribe(ctx.Context(), 
subscriptionConfiguration, updater) +// RootFieldName returns the root field name +func (p *PublishEventConfiguration) RootFieldName() string { + return p.FieldName } -// LoadInitialData implements the interface method (not used for this subscription type) -func (s *SubscriptionDataSource) LoadInitialData(ctx context.Context) (initial []byte, err error) { - return nil, nil +func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { + return fmt.Sprintf(`{"channel":"%s", "event": {"data": %s}, "providerId":"%s"}`, s.Channel, s.Event.Data, s.ProviderID()), nil } // PublishDataSource implements resolve.DataSource for Redis publishing @@ -80,8 +77,7 @@ type PublishDataSource struct { // Load processes a request to publish to Redis func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { var publishConfiguration PublishEventConfiguration - err := json.Unmarshal(input, &publishConfiguration) - if err != nil { + if err := json.Unmarshal(input, &publishConfiguration); err != nil { return err } @@ -89,7 +85,7 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B _, err = io.WriteString(out, `{"success": false}`) return err } - _, err = io.WriteString(out, `{"success": true}`) + _, err := io.WriteString(out, `{"success": true}`) return err } @@ -97,3 +93,8 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { panic("not implemented") } + +// Interface compliance checks +var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) +var _ datasource.PublishEventConfiguration = (*PublishEventConfiguration)(nil) +var _ datasource.StreamEvent = (*Event)(nil) diff --git a/router/pkg/pubsub/redis/engine_datasource_factory.go b/router/pkg/pubsub/redis/engine_datasource_factory.go index 
c5383ff16a..bce913e54e 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory.go @@ -5,6 +5,8 @@ import ( "fmt" "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -59,26 +61,46 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri providerId := c.providerId evtCfg := PublishEventConfiguration{ - ProviderID: providerId, - Channel: channel, - Data: eventData, + Provider: providerId, + Channel: channel, + Event: Event{Data: eventData}, + FieldName: c.fieldName, } return evtCfg.MarshalJSONTemplate() } // ResolveDataSourceSubscription returns the subscription data source -func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (resolve.SubscriptionDataSource, error) { - return &SubscriptionDataSource{ - pubSub: c.RedisAdapter, - }, nil +func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.RedisAdapter, + func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "channels") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err + }), nil } // ResolveDataSourceSubscriptionInput builds the input for the subscription data source func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { evtCfg := SubscriptionEventConfiguration{ - ProviderID: c.providerId, - Channels: c.channels, + Provider: c.providerId, + Channels: c.channels, + FieldName: c.fieldName, } object, err := json.Marshal(evtCfg) if err != nil { diff --git 
a/router/pkg/pubsub/redis/engine_datasource_factory_test.go b/router/pkg/pubsub/redis/engine_datasource_factory_test.go index 0c1344048a..f96691583d 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory_test.go @@ -4,11 +4,15 @@ import ( "bytes" "context" "encoding/json" + "errors" "testing" + "github.com/cespare/xxhash/v2" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestRedisEngineDataSourceFactory(t *testing.T) { @@ -33,7 +37,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Configure mock expectations for Publish mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { - return event.ProviderID == "test-provider" && event.Channel == "test-channel" + return event.ProviderID() == "test-provider" && event.Channel == "test-channel" })).Return(nil) // Create the data source with mock adapter @@ -176,3 +180,57 @@ func TestTransformEventConfig(t *testing.T) { require.Equal(t, []string{"transformed.original.subject1", "transformed.original.subject2"}, cfg.channels) }) } + +func TestRedisEngineDataSourceFactory_UniqueRequestID(t *testing.T) { + tests := []struct { + name string + input string + expectError bool + expectedError error + }{ + { + name: "valid input", + input: `{"channels":["channel1", "channel2"], "providerId":"test-provider"}`, + expectError: false, + }, + { + name: "missing channels", + input: `{"providerId":"test-provider"}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + { + name: "missing providerId", + input: `{"channels":["channel1", "channel2"]}`, + expectError: true, + expectedError: errors.New("Key path not found"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t 
*testing.T) { + factory := &EngineDataSourceFactory{ + RedisAdapter: NewMockAdapter(t), + } + source, err := factory.ResolveDataSourceSubscription() + require.NoError(t, err) + ctx := &resolve.Context{} + input := []byte(tt.input) + xxh := xxhash.New() + + err = source.UniqueRequestID(ctx, input, xxh) + + if tt.expectError { + require.Error(t, err) + if tt.expectedError != nil { + // For jsonparser errors, just check if the error message contains the expected text + assert.Contains(t, err.Error(), tt.expectedError.Error()) + } + } else { + require.NoError(t, err) + // Check that the hash has been updated + assert.NotEqual(t, 0, xxh.Sum64()) + } + }) + } +} diff --git a/router/pkg/pubsub/redis/engine_datasource_test.go b/router/pkg/pubsub/redis/engine_datasource_test.go index 7c47d47cc6..74b7d564d7 100644 --- a/router/pkg/pubsub/redis/engine_datasource_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_test.go @@ -7,12 +7,9 @@ import ( "errors" "testing" - "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { @@ -24,20 +21,20 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: PublishEventConfiguration{ - ProviderID: "test-provider", - Channel: "test-channel", - Data: json.RawMessage(`{"message":"hello"}`), + Provider: "test-provider", + Channel: "test-channel", + Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, }, - wantPattern: `{"channel":"test-channel", "data": {"message":"hello"}, "providerId":"test-provider"}`, + wantPattern: `{"channel":"test-channel", "event": {"data": {"message":"hello"}}, "providerId":"test-provider"}`, }, { name: "with special characters", config: PublishEventConfiguration{ - 
ProviderID: "test-provider-id", - Channel: "channel-with-hyphens", - Data: json.RawMessage(`{"message":"special \"quotes\" here"}`), + Provider: "test-provider-id", + Channel: "channel-with-hyphens", + Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, }, - wantPattern: `{"channel":"channel-with-hyphens", "data": {"message":"special \"quotes\" here"}, "providerId":"test-provider-id"}`, + wantPattern: `{"channel":"channel-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id"}`, }, } @@ -50,124 +47,6 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { } } -func TestSubscriptionSource_UniqueRequestID(t *testing.T) { - tests := []struct { - name string - input string - expectError bool - expectedError error - }{ - { - name: "valid input", - input: `{"channels":["channel1", "channel2"], "providerId":"test-provider"}`, - expectError: false, - }, - { - name: "missing channels", - input: `{"providerId":"test-provider"}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - { - name: "missing providerId", - input: `{"channels":["channel1", "channel2"]}`, - expectError: true, - expectedError: errors.New("Key path not found"), - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - source := &SubscriptionDataSource{ - pubSub: NewMockAdapter(t), - } - ctx := &resolve.Context{} - input := []byte(tt.input) - xxh := xxhash.New() - - err := source.UniqueRequestID(ctx, input, xxh) - - if tt.expectError { - require.Error(t, err) - if tt.expectedError != nil { - // For jsonparser errors, just check if the error message contains the expected text - assert.Contains(t, err.Error(), tt.expectedError.Error()) - } - } else { - require.NoError(t, err) - // Check that the hash has been updated - assert.NotEqual(t, 0, xxh.Sum64()) - } - }) - } -} - -func TestSubscriptionSource_Start(t *testing.T) { - tests := []struct { - name string - input 
string - mockSetup func(*MockAdapter, *datasource.MockSubscriptionUpdater) - expectError bool - }{ - { - name: "successful subscription", - input: `{"channels":["channel1", "channel2"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Channels: []string{"channel1", "channel2"}, - }, mock.Anything).Return(nil) - }, - expectError: false, - }, - { - name: "adapter returns error", - input: `{"channels":["channel1"], "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) { - m.On("Subscribe", mock.Anything, SubscriptionEventConfiguration{ - ProviderID: "test-provider", - Channels: []string{"channel1"}, - }, mock.Anything).Return(errors.New("subscription error")) - }, - expectError: true, - }, - { - name: "invalid input json", - input: `{"invalid json":`, - mockSetup: func(m *MockAdapter, updater *datasource.MockSubscriptionUpdater) {}, - expectError: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) - updater := datasource.NewMockSubscriptionUpdater(t) - tt.mockSetup(mockAdapter, updater) - - source := &SubscriptionDataSource{ - pubSub: mockAdapter, - } - - // Set up go context - goCtx := context.Background() - - // Create a resolve.Context with the standard context - resolveCtx := &resolve.Context{} - resolveCtx = resolveCtx.WithContext(goCtx) - - input := []byte(tt.input) - err := source.Start(resolveCtx, input, updater) - - if tt.expectError { - require.Error(t, err) - } else { - require.NoError(t, err) - } - }) - } -} - func TestRedisPublishDataSource_Load(t *testing.T) { tests := []struct { name string @@ -179,12 +58,12 @@ func TestRedisPublishDataSource_Load(t *testing.T) { }{ { name: "successful publish", - input: `{"channel":"test-channel", "data":{"message":"hello"}, 
"providerId":"test-provider"}`, + input: `{"channel":"test-channel", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { - return event.ProviderID == "test-provider" && + return event.ProviderID() == "test-provider" && event.Channel == "test-channel" && - string(event.Data) == `{"message":"hello"}` + string(event.Event.Data) == `{"message":"hello"}` })).Return(nil) }, expectError: false, @@ -193,7 +72,7 @@ func TestRedisPublishDataSource_Load(t *testing.T) { }, { name: "publish error", - input: `{"channel":"test-channel", "data":{"message":"hello"}, "providerId":"test-provider"}`, + input: `{"channel":"test-channel", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) }, diff --git a/router/pkg/pubsub/redis/mocks.go b/router/pkg/pubsub/redis/mocks.go index 603a5dd548..6f6938cdd0 100644 --- a/router/pkg/pubsub/redis/mocks.go +++ b/router/pkg/pubsub/redis/mocks.go @@ -8,7 +8,7 @@ import ( "context" mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) // NewMockAdapter creates a new instance of MockAdapter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
@@ -198,7 +198,7 @@ func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) e } // Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error { +func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { ret := _mock.Called(ctx, event, updater) if len(ret) == 0 { @@ -206,7 +206,7 @@ func (_mock *MockAdapter) Subscribe(ctx context.Context, event SubscriptionEvent } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, SubscriptionEventConfiguration, resolve.SubscriptionUpdater) error); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { r0 = returnFunc(ctx, event, updater) } else { r0 = ret.Error(0) @@ -221,25 +221,25 @@ type MockAdapter_Subscribe_Call struct { // Subscribe is a helper method to define mock.On call // - ctx context.Context -// - event SubscriptionEventConfiguration -// - updater resolve.SubscriptionUpdater +// - event datasource.SubscriptionEventConfiguration +// - updater datasource.SubscriptionEventUpdater func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} } -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater)) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { 
arg0 = args[0].(context.Context) } - var arg1 SubscriptionEventConfiguration + var arg1 datasource.SubscriptionEventConfiguration if args[1] != nil { - arg1 = args[1].(SubscriptionEventConfiguration) + arg1 = args[1].(datasource.SubscriptionEventConfiguration) } - var arg2 resolve.SubscriptionUpdater + var arg2 datasource.SubscriptionEventUpdater if args[2] != nil { - arg2 = args[2].(resolve.SubscriptionUpdater) + arg2 = args[2].(datasource.SubscriptionEventUpdater) } run( arg0, @@ -255,7 +255,7 @@ func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_C return _c } -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event SubscriptionEventConfiguration, updater resolve.SubscriptionUpdater) error) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { _c.Call.Return(run) return _c } From b594b1c8cb5b859e3a58cf59951451a05b81f128 Mon Sep 17 00:00:00 2001 From: Alessandro Pagnin Date: Fri, 10 Oct 2025 13:59:30 +0200 Subject: [PATCH 02/44] feat: batch and stream hooks (#2087) Co-authored-by: Ludwig Bedacht Co-authored-by: StarpTech Co-authored-by: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> --- adr/cosmo-streams-v1.md | 129 ++-- .../availability/subgraph/schema.resolvers.go | 16 +- .../mood/subgraph/schema.resolvers.go | 11 +- router-tests/events/events_config_test.go | 2 +- router-tests/events/kafka_events_test.go | 102 ++- router-tests/events/nats_events_test.go | 2 +- router-tests/events/redis_events_test.go | 75 +-- .../events/{event_helpers.go => utils.go} | 88 ++- router-tests/go.mod | 2 +- router-tests/go.sum | 6 +- .../modules/start-subscription/module.go | 4 +- .../modules/start_subscription_test.go | 18 +- router-tests/modules/stream-publish/module.go | 49 ++ 
router-tests/modules/stream-receive/module.go | 49 ++ router-tests/modules/stream_publish_test.go | 315 +++++++++ router-tests/modules/stream_receive_test.go | 521 +++++++++++++++ .../modules/streams_hooks_combined_test.go | 149 +++++ .../prometheus_stream_metrics_test.go | 6 +- router-tests/telemetry/stream_metrics_test.go | 6 +- router/.mockery.yml | 6 - router/core/errors.go | 8 +- router/core/factoryresolver.go | 15 +- router/core/graphql_handler.go | 16 +- router/core/router.go | 8 + router/core/router_config.go | 5 +- router/core/subscriptions_modules.go | 215 ++++-- router/go.mod | 2 +- router/go.sum | 4 +- router/pkg/pubsub/datasource/datasource.go | 2 +- router/pkg/pubsub/datasource/factory.go | 16 +- router/pkg/pubsub/datasource/hooks.go | 20 + router/pkg/pubsub/datasource/mocks.go | 191 +++++- router/pkg/pubsub/datasource/mocks_resolve.go | 140 ++++ router/pkg/pubsub/datasource/planner.go | 6 +- router/pkg/pubsub/datasource/provider.go | 9 +- .../pkg/pubsub/datasource/pubsubprovider.go | 43 ++ .../pubsub/datasource/pubsubprovider_test.go | 426 +++++++++++- .../datasource/subscription_datasource.go | 22 +- .../subscription_datasource_test.go | 84 ++- .../datasource/subscription_event_updater.go | 100 ++- .../subscription_event_updater_test.go | 627 ++++++++++++++++++ router/pkg/pubsub/kafka/adapter.go | 112 ++-- router/pkg/pubsub/kafka/engine_datasource.go | 121 +++- .../pubsub/kafka/engine_datasource_factory.go | 8 +- .../kafka/engine_datasource_factory_test.go | 12 +- .../pubsub/kafka/engine_datasource_test.go | 63 +- router/pkg/pubsub/kafka/mocks.go | 261 -------- router/pkg/pubsub/kafka/provider_builder.go | 21 +- router/pkg/pubsub/nats/adapter.go | 135 ++-- router/pkg/pubsub/nats/engine_datasource.go | 118 +++- .../pubsub/nats/engine_datasource_factory.go | 11 +- .../nats/engine_datasource_factory_test.go | 16 +- .../pkg/pubsub/nats/engine_datasource_test.go | 56 +- router/pkg/pubsub/nats/mocks.go | 76 ++- router/pkg/pubsub/nats/provider_builder.go | 
20 +- router/pkg/pubsub/pubsub.go | 42 +- router/pkg/pubsub/pubsub_test.go | 22 +- router/pkg/pubsub/redis/adapter.go | 88 +-- router/pkg/pubsub/redis/engine_datasource.go | 100 ++- .../pubsub/redis/engine_datasource_factory.go | 10 +- .../redis/engine_datasource_factory_test.go | 12 +- .../pubsub/redis/engine_datasource_test.go | 51 +- router/pkg/pubsub/redis/mocks.go | 261 -------- router/pkg/pubsub/redis/provider_builder.go | 12 +- 64 files changed, 3886 insertions(+), 1257 deletions(-) rename router-tests/events/{event_helpers.go => utils.go} (55%) create mode 100644 router-tests/modules/stream-publish/module.go create mode 100644 router-tests/modules/stream-receive/module.go create mode 100644 router-tests/modules/stream_publish_test.go create mode 100644 router-tests/modules/stream_receive_test.go create mode 100644 router-tests/modules/streams_hooks_combined_test.go create mode 100644 router/pkg/pubsub/datasource/hooks.go create mode 100644 router/pkg/pubsub/datasource/subscription_event_updater_test.go delete mode 100644 router/pkg/pubsub/kafka/mocks.go delete mode 100644 router/pkg/pubsub/redis/mocks.go diff --git a/adr/cosmo-streams-v1.md b/adr/cosmo-streams-v1.md index 436dafe45b..21b035ff0b 100644 --- a/adr/cosmo-streams-v1.md +++ b/adr/cosmo-streams-v1.md @@ -21,24 +21,18 @@ The following interfaces will extend the existing logic in the custom modules. These provide additional control over subscriptions by providing hooks, which are invoked during specific events. - `SubscriptionOnStartHandler`: Called once at subscription start. -- `StreamBatchEventHook`: Called each time a batch of events is received from the provider. -- `StreamPublishEventHook`: Called each time a batch of events is going to be sent to the provider. +- `StreamReceiveEventHandler`: Triggered for each client/subscription when a batch of events is received from the provider, prior to delivery. 
+- `StreamPublishEventHandler`: Called each time a batch of events is going to be sent to the provider. ```go // STRUCTURES TO BE ADDED TO PUBSUB PACKAGE type ProviderType string const ( - ProviderTypeNats ProviderType = "nats" + ProviderTypeNats ProviderType = "nats" ProviderTypeKafka ProviderType = "kafka" ProviderTypeRedis ProviderType = "redis" } -// StreamHookError is used to customize the error messages and the behavior -type StreamHookError struct { - HttpError core.HttpError - CloseSubscription bool -} - // OperationContext already exists, we just have to add the Variables() method type OperationContext interface { Name() string @@ -48,8 +42,9 @@ type OperationContext interface { // each provider will have its own event type with custom fields // the StreamEvent interface is used to allow the hooks system to be provider-agnostic -// there could be common fields in future, but for now we don't need them -type StreamEvent interface {} +type StreamEvent interface { + GetData() []byte +} // SubscriptionEventConfiguration is the common interface for the subscription event configuration type SubscriptionEventConfiguration interface { @@ -67,7 +62,7 @@ type PublishEventConfiguration interface { RootFieldName() string } -type SubscriptionOnStartHookContext interface { +type SubscriptionOnStartHandlerContext interface { // Request is the original request received by the router. Request() *http.Request // Logger is the logger for the request @@ -85,34 +80,48 @@ type SubscriptionOnStartHookContext interface { type SubscriptionOnStartHandler interface { // OnSubscriptionOnStart is called once at subscription start - // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. 
- SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error + // Returning an error will result in a GraphQL error being returned to the client + SubscriptionOnStart(ctx SubscriptionOnStartHandlerContext) error } -type StreamBatchEventHookContext interface { - // the request context - RequestContext() RequestContext - // the subscription event configuration +type StreamReceiveEventHandlerContext interface { + // Request is the initial client request that started the subscription + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration is the subscription event configuration SubscriptionEventConfiguration() SubscriptionEventConfiguration } -type StreamBatchEventHook interface { - // OnStreamEvents is called each time a batch of events is received from the provider - // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. - OnStreamEvents(ctx StreamBatchEventHookContext, events []StreamEvent) ([]StreamEvent, error) +type StreamReceiveEventHandler interface { + // OnReceiveEvents is called each time a batch of events is received from the provider before delivering them to the client + // So for a single batch of events received from the provider, this hook will be called one time for each active subscription. + // It is important to optimize the logic inside this hook to avoid performance issues. 
+ // Returning an error will result in a GraphQL error being returned to the client + OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []StreamEvent) ([]StreamEvent, error) } -type StreamPublishEventHookContext interface { - // the request context - RequestContext() RequestContext - // the publish event configuration +type StreamPublishEventHandlerContext interface { + // Request is the original request received by the router. + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // PublishEventConfiguration is the publish event configuration PublishEventConfiguration() PublishEventConfiguration } -type StreamPublishEventHook interface { +type StreamPublishEventHandler interface { // OnPublishEvents is called each time a batch of events is going to be sent to the provider - // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a StreamHookError. - OnPublishEvents(ctx StreamPublishEventHookContext, events []StreamEvent) ([]StreamEvent, error) + // Returning an error will result in an error being returned and the client will see the mutation failing + OnPublishEvents(ctx StreamPublishEventHandlerContext, events []StreamEvent) ([]StreamEvent, error) } ``` @@ -154,7 +163,7 @@ type Employee @key(fields: "id", resolvable: false) { id: Int! @external } ``` -After publishing the schema, the developer will need to add the module to the cosmo streams engine. +After publishing the schema, the developer will need to add the module to the cosmo router. ### 2. 
Write the custom module @@ -177,39 +186,38 @@ func init() { type MyModule struct {} -func (m *MyModule) OnStreamEvents(ctx StreamBatchEventHookContext, events []core.StreamEvent) ([]core.StreamEvent, error) { +func (m *MyModule) OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []core.StreamEvent) ([]core.StreamEvent, error) { // check if the provider is nats - if ctx.StreamContext().ProviderType() != pubsub.ProviderTypeNats { + if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { return events, nil } // check if the provider id is the one expected by the module - if ctx.StreamContext().ProviderID() != "my-nats" { + if ctx.SubscriptionEventConfiguration().ProviderID() != "my-nats" { return events, nil } - // check if the subject is the one expected by the module - natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) - if natsConfig.Subjects[0] != "employeeUpdates" { - return events, nil - } + // check if the subscription is the one expected by the module + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdates" { + return events, nil + } + + newEvents := make([]core.StreamEvent, 0, len(events)) // check if the client is authenticated - if ctx.RequestContext().Authentication() == nil { + if ctx.Authentication() == nil { // if the client is not authenticated, return no events - return events, nil + return newEvents, nil } // check if the client is allowed to subscribe to the stream - clientAllowedEntitiesIds, found := ctx.RequestContext().Authentication().Claims()["allowedEntitiesIds"] + clientAllowedEntitiesIds, found := ctx.Authentication().Claims()["allowedEntitiesIds"] if !found { - return events, fmt.Errorf("client is not allowed to subscribe to the stream") + return newEvents, fmt.Errorf("client is not allowed to subscribe to the stream") } - newEvents := make([]core.StreamEvent, 0, len(events)) - for _, evt := range events { - natsEvent, ok := 
evt.(*nats.NatsEvent); + natsEvent, ok := evt.(*nats.NatsEvent) if !ok { newEvents = append(newEvents, evt) continue @@ -266,7 +274,7 @@ func (m *MyModule) Module() core.ModuleInfo { // Interface guards var ( - _ core.StreamBatchEventHook = (*MyModule)(nil) + _ core.StreamReceiveEventHandler = (*MyModule)(nil) ) ``` @@ -321,7 +329,7 @@ func init() { type MyModule struct {} -func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error { +func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHandlerContext) error { // check if the provider is nats if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { return nil @@ -332,20 +340,17 @@ func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error return nil } - // check if the subject is the one expected by the module - natsConfig := ctx.SubscriptionEventConfiguration().(*nats.SubscriptionEventConfiguration) - if natsConfig.Subjects[0] != "employeeUpdates" { - return nil - } + // check if the subscription is the one expected by the module + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdates" { + return nil + } // check if the client is authenticated if ctx.Authentication() == nil { // if the client is not authenticated, return an error - return &StreamHookError{ - HttpError: core.HttpError{ - Code: http.StatusUnauthorized, - Message: "client is not authenticated", - }, + return &core.HttpError{ + Code: http.StatusUnauthorized, + Message: "client is not authenticated", CloseSubscription: true, } } @@ -353,11 +358,9 @@ func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error // check if the client is allowed to subscribe to the stream clientAllowedEntitiesIds, found := ctx.Authentication().Claims()["readEmployee"] if !found { - return &StreamHookError{ - HttpError: core.HttpError{ - Code: http.StatusForbidden, - Message: "client is not allowed to read employees", - }, + return &core.HttpError{ + 
Code: http.StatusForbidden, + Message: "client is not allowed to read employees", CloseSubscription: true, } } @@ -405,4 +408,4 @@ We could also generate the AsyncAPI specification from the schema and the events ## Generate hooks from AsyncAPI specifications -Building on the AsyncAPI integration, we could allow the user to define their streams using AsyncAPI and generate fully typesafe hooks with all events structures generated from the AsyncAPI specification. \ No newline at end of file +Building on the AsyncAPI integration, we could allow the user to define their streams using AsyncAPI and generate fully typesafe hooks with all events structures generated from the AsyncAPI specification. diff --git a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go index 6abb2c062e..8e52ec96c5 100644 --- a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go @@ -10,24 +10,28 @@ import ( "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/generated" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/availability/subgraph/model" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) // UpdateAvailability is the resolver for the updateAvailability field. 
func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID int, isAvailable bool) (*model.Employee, error) { storage.Set(employeeID, isAvailable) - err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + conf := &nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), - Event: nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))}, - }) + } + evt := &nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} + err := r.NatsPubSubByProviderID["default"].Publish(ctx, conf, []datasource.StreamEvent{evt}) if err != nil { return nil, err } - err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + + conf2 := &nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), - Event: nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))}, - }) + } + evt2 := &nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} + err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, conf2, []datasource.StreamEvent{evt2}) if err != nil { return nil, err diff --git a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go index 82a0a7e9f2..b9b426593c 100644 --- a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go @@ -10,6 +10,7 @@ import ( "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/generated" "github.com/wundergraph/cosmo/demo/pkg/subgraphs/mood/subgraph/model" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" ) @@ -19,10 +20,9 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood myNatsTopic := 
r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)) payload := fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID) if r.NatsPubSubByProviderID["default"] != nil { - err := r.NatsPubSubByProviderID["default"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + err := r.NatsPubSubByProviderID["default"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ Subject: myNatsTopic, - Event: nats.Event{Data: []byte(payload)}, - }) + }, []datasource.StreamEvent{&nats.Event{Data: []byte(payload)}}) if err != nil { return nil, err } @@ -32,10 +32,9 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood defaultTopic := r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)) if r.NatsPubSubByProviderID["my-nats"] != nil { - err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, nats.PublishAndRequestEventConfiguration{ + err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ Subject: defaultTopic, - Event: nats.Event{Data: []byte(payload)}, - }) + }, []datasource.StreamEvent{&nats.Event{Data: []byte(payload)}}) if err != nil { return nil, err } diff --git a/router-tests/events/events_config_test.go b/router-tests/events/events_config_test.go index f7e0739e1c..50d19dbaed 100644 --- a/router-tests/events/events_config_test.go +++ b/router-tests/events/events_config_test.go @@ -1,4 +1,4 @@ -package events +package events_test import ( "testing" diff --git a/router-tests/events/kafka_events_test.go b/router-tests/events/kafka_events_test.go index 3ad51a592c..37e54109a8 100644 --- a/router-tests/events/kafka_events_test.go +++ b/router-tests/events/kafka_events_test.go @@ -1,9 +1,8 @@ -package events +package events_test import ( "bufio" "bytes" - "context" "encoding/json" "fmt" "net/http" @@ -11,13 +10,13 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/wundergraph/cosmo/router/core" - "github.com/hasura/go-graphql-client" + 
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/twmb/franz-go/pkg/kgo" + + "github.com/wundergraph/cosmo/router-tests/events" "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" "github.com/wundergraph/cosmo/router/pkg/config" ) @@ -74,7 +73,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -107,7 +106,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -130,7 +129,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) 
var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -164,23 +163,23 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], ``) // Empty message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], ``) // Empty message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.ErrorContains(t, args.errValue, "Invalid message received") }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.ErrorContains(t, args.errValue, "Cannot return null for non-nullable field 'Subscription.employeeUpdatedMyKafka.id'.") }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, 
args.errValue) require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) @@ -204,7 +203,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -248,7 +247,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(2, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -277,7 +276,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) 
var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -321,7 +320,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(2, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -333,7 +332,7 @@ func TestKafkaEvents(t *testing.T) { require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) }) - ProduceKafkaMessage(t, xEnv, topics[1], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[1], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -366,7 +365,7 @@ func TestKafkaEvents(t *testing.T) { engineExecutionConfiguration.WebSocketClientReadTimeout = time.Millisecond * 100 }, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) 
var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -399,7 +398,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) @@ -431,7 +430,7 @@ func TestKafkaEvents(t *testing.T) { core.WithSubscriptionHeartbeatInterval(subscriptionHeartbeatInterval), }, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) @@ -447,10 +446,10 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) assertKafkaMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) assertKafkaMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdatedMyKafka\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") }) }) @@ -497,7 +496,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: 
testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) @@ -530,7 +529,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, responseCh, func(t *testing.T, response struct { response *http.Response @@ -562,7 +561,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) subscribePayload := []byte(`{"query":"subscription { employeeUpdatedMyKafka(employeeID: 1) { id details { forename surname } }}"}`) @@ -595,7 +594,7 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, KafkaWaitTimeout, responseCh, func(t *testing.T, resp struct { response *http.Response @@ -672,7 +671,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) 
+ events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) type subscriptionPayload struct { Data struct { @@ -713,7 +712,7 @@ func TestKafkaEvents(t *testing.T) { // Events 1, 2, 11, and 12 should be included for i := uint32(1); i < 13; i++ { - ProduceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) if i == 1 || i == 2 || i == 11 || i == 12 { conn.SetReadDeadline(time.Now().Add(KafkaWaitTimeout)) gErr := conn.ReadJSON(&msg) @@ -739,7 +738,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) type subscriptionPayload struct { Data struct { @@ -780,7 +779,7 @@ func TestKafkaEvents(t *testing.T) { // Events 1, 2, 11, and 12 should be included for i := uint32(1); i < 13; i++ { - ProduceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) if i == 1 || i == 2 || i == 11 || i == 12 { conn.SetReadDeadline(time.Now().Add(KafkaWaitTimeout)) gErr := conn.ReadJSON(&msg) @@ -806,7 +805,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) 
type subscriptionPayload struct { Data struct { @@ -835,10 +834,10 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) // The message should be ignored because "1" does not equal 1 - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id":1}`) // This message should be delivered because it matches the filter - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":12}`) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id":12}`) conn.SetReadDeadline(time.Now().Add(KafkaWaitTimeout)) readErr := conn.ReadJSON(&msg) require.NoError(t, readErr) @@ -861,7 +860,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) 
var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -894,23 +893,23 @@ func TestKafkaEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, KafkaWaitTimeout) - ProduceKafkaMessage(t, xEnv, topics[0], `{asas`) // Invalid message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{asas`) // Invalid message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.ErrorContains(t, args.errValue, "Invalid message received") }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) // Correct message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id":1}`) // Correct message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // Missing entity = Resolver error testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.ErrorContains(t, args.errValue, "Cannot return null for non-nullable field 'Subscription.employeeUpdatedMyKafka.id'.") }) - ProduceKafkaMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message testenv.AwaitChannelWithT(t, KafkaWaitTimeout, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { require.NoError(t, args.errValue) 
require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) @@ -932,7 +931,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) // Send a mutation to trigger the first subscription resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ @@ -940,7 +939,7 @@ func TestKafkaEvents(t *testing.T) { }) require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) - records, err := readKafkaMessages(xEnv, topics[0], 1) + records, err := events.ReadKafkaMessages(xEnv, KafkaWaitTimeout, topics[0], 1) require.NoError(t, err) require.Equal(t, 1, len(records)) require.Equal(t, `{"employeeID":3,"update":{"name":"name test"}}`, string(records[0].Value)) @@ -980,7 +979,7 @@ func TestKafkaEvents(t *testing.T) { RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, EnableKafka: true, }, func(t *testing.T, xEnv *testenv.Environment) { - EnsureTopicExists(t, xEnv, topics...) + events.KafkaEnsureTopicExists(t, xEnv, KafkaWaitTimeout, topics...) 
type subscriptionPayload struct { Data struct { @@ -1024,7 +1023,7 @@ func TestKafkaEvents(t *testing.T) { // Events 1, 3, 4, 7, and 11 should be included for i := int(MsgCount); i > 0; i-- { - ProduceKafkaMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + events.ProduceKafkaMessage(t, xEnv, KafkaWaitTimeout, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) if i == 1 || i == 3 || i == 4 || i == 7 || i == 11 { conn.SetReadDeadline(time.Now().Add(KafkaWaitTimeout)) jsonErr := conn.ReadJSON(&msg) @@ -1041,20 +1040,3 @@ func TestKafkaEvents(t *testing.T) { }) }) } - -func readKafkaMessages(xEnv *testenv.Environment, topicName string, msgs int) ([]*kgo.Record, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - client, err := kgo.NewClient( - kgo.SeedBrokers(xEnv.GetKafkaSeeds()...), - kgo.ConsumeTopics(xEnv.GetPubSubName(topicName)), - ) - if err != nil { - return nil, err - } - - fetchs := client.PollRecords(ctx, msgs) - - return fetchs.Records(), nil -} diff --git a/router-tests/events/nats_events_test.go b/router-tests/events/nats_events_test.go index 9e1558db24..0add1e361e 100644 --- a/router-tests/events/nats_events_test.go +++ b/router-tests/events/nats_events_test.go @@ -1,4 +1,4 @@ -package events +package events_test import ( "bufio" diff --git a/router-tests/events/redis_events_test.go b/router-tests/events/redis_events_test.go index f6c9e54d13..2980ae61d5 100644 --- a/router-tests/events/redis_events_test.go +++ b/router-tests/events/redis_events_test.go @@ -1,22 +1,20 @@ -package events +package events_test import ( "bufio" "bytes" - "context" "encoding/json" "fmt" "net/http" - "net/url" "testing" "time" - "github.com/redis/go-redis/v9" "github.com/stretchr/testify/assert" "github.com/wundergraph/cosmo/router/core" "github.com/hasura/go-graphql-client" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/events" 
"github.com/wundergraph/cosmo/router-tests/testenv" "github.com/wundergraph/cosmo/router/pkg/config" ) @@ -104,7 +102,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // process the message select { @@ -170,7 +168,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce an empty message - ProduceRedisMessage(t, xEnv, topics[0], ``) + events.ProduceRedisMessage(t, xEnv, topics[0], ``) // process the message select { case subscriptionArgs := <-subscriptionArgsCh: @@ -181,7 +179,7 @@ func TestRedisEvents(t *testing.T) { t.Fatal("timeout waiting for first message error") } - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // Correct message select { case subscriptionArgs := <-subscriptionArgsCh: require.NoError(t, subscriptionArgs.errValue) @@ -191,7 +189,7 @@ func TestRedisEvents(t *testing.T) { } // Missing entity = Resolver error - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) select { case subscriptionArgs := <-subscriptionArgsCh: var gqlErr graphql.Errors @@ -202,7 +200,7 @@ func TestRedisEvents(t *testing.T) { } // Correct message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) select { case subscriptionArgs := <-subscriptionArgsCh: require.NoError(t, 
subscriptionArgs.errValue) @@ -273,7 +271,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(2, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the first subscription select { @@ -354,7 +352,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(2, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the first subscription select { @@ -375,7 +373,7 @@ func TestRedisEvents(t *testing.T) { } // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 2,"update":{"name":"foo"}}`) // read the message from the first subscription select { @@ -451,7 +449,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the subscription select { @@ -509,12 +507,12 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the subscription assertRedisMultipartValueEventually(t, 
reader, "{\"payload\":{\"data\":{\"employeeUpdates\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message from the subscription assertRedisMultipartValueEventually(t, reader, "{\"payload\":{\"data\":{\"employeeUpdates\":{\"id\":1,\"details\":{\"forename\":\"Jens\",\"surname\":\"Neuse\"}}}}}") }) @@ -590,7 +588,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message so that the subscription is triggered - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // get the client response var clientRet struct { @@ -663,7 +661,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message so that the subscription is triggered - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // get the client response var clientRet struct { @@ -792,7 +790,7 @@ func TestRedisEvents(t *testing.T) { // Events 1, 3, 4, 7, and 11 should be included for i := MsgCount; i > 0; i-- { - ProduceRedisMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) + events.ProduceRedisMessage(t, xEnv, topics[0], fmt.Sprintf(`{"__typename":"Employee","id":%d}`, i)) if i == 11 || i == 7 || i == 4 || i == 3 || i == 1 { gErr := conn.ReadJSON(&msg) @@ -853,7 +851,7 @@ func TestRedisEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce an invalid message - 
ProduceRedisMessage(t, xEnv, topics[0], `{asas`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{asas`) // get the client response select { case args := <-subscriptionOneArgsCh: @@ -865,7 +863,7 @@ func TestRedisEvents(t *testing.T) { } // produce a correct message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id":1}`) // get the client response select { case args := <-subscriptionOneArgsCh: @@ -876,7 +874,7 @@ func TestRedisEvents(t *testing.T) { } // produce a message with a missing entity - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","update":{"name":"foo"}}`) // get the client response select { case args := <-subscriptionOneArgsCh: @@ -888,7 +886,7 @@ func TestRedisEvents(t *testing.T) { } // produce a correct message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // get the client response select { case args := <-subscriptionOneArgsCh: @@ -920,7 +918,7 @@ func TestRedisEvents(t *testing.T) { NoRetryClient: true, }, func(t *testing.T, xEnv *testenv.Environment) { // start reading the messages from the channel - msgCh, err := readRedisMessages(t, xEnv, channels[0]) + msgCh, err := events.ReadRedisMessages(t, xEnv, channels[0]) require.NoError(t, err) // send a mutation to trigger the first subscription @@ -991,7 +989,7 @@ func TestRedisClusterEvents(t *testing.T) { xEnv.WaitForSubscriptionCount(1, RedisWaitTimeout) // produce a message - ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceRedisMessage(t, xEnv, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) // read the message 
select { @@ -1026,7 +1024,7 @@ func TestRedisClusterEvents(t *testing.T) { NoRetryClient: true, }, func(t *testing.T, xEnv *testenv.Environment) { // start reading the messages from the channel - msgCh, err := readRedisMessages(t, xEnv, channels[0]) + msgCh, err := events.ReadRedisMessages(t, xEnv, channels[0]) require.NoError(t, err) // send a mutation to produce a message @@ -1046,30 +1044,3 @@ func TestRedisClusterEvents(t *testing.T) { }) } - -func readRedisMessages(t *testing.T, xEnv *testenv.Environment, channelName string) (<-chan *redis.Message, error) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() - - parsedURL, err := url.Parse(xEnv.RedisHosts[0]) - if err != nil { - return nil, err - } - var redisConn redis.UniversalClient - if !xEnv.RedisWithClusterMode { - redisConn = redis.NewClient(&redis.Options{ - Addr: parsedURL.Host, - }) - } else { - redisConn = redis.NewClusterClient(&redis.ClusterOptions{ - Addrs: []string{parsedURL.Host}, - }) - } - sub := redisConn.Subscribe(ctx, xEnv.GetPubSubName(channelName)) - t.Cleanup(func() { - sub.Close() - redisConn.Close() - }) - - return sub.Channel(), nil -} diff --git a/router-tests/events/event_helpers.go b/router-tests/events/utils.go similarity index 55% rename from router-tests/events/event_helpers.go rename to router-tests/events/utils.go index 48d97e90c4..b8619c3368 100644 --- a/router-tests/events/event_helpers.go +++ b/router-tests/events/utils.go @@ -2,19 +2,37 @@ package events import ( "context" + "net/url" + "testing" + "time" + "github.com/redis/go-redis/v9" "github.com/stretchr/testify/require" "github.com/twmb/franz-go/pkg/kgo" "github.com/wundergraph/cosmo/router-tests/testenv" - "net/url" - "testing" - "time" ) -const waitTimeout = time.Second * 30 +func KafkaEnsureTopicExists(t *testing.T, xEnv *testenv.Environment, timeout time.Duration, topics ...string) { + // Delete topic for idempotency + deleteCtx, cancel := 
context.WithTimeout(context.Background(), timeout) + defer cancel() + prefixedTopics := make([]string, 0, len(topics)) + for _, topic := range topics { + prefixedTopics = append(prefixedTopics, xEnv.GetPubSubName(topic)) + } -func ProduceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName string, message string) { - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + _, err := xEnv.KafkaAdminClient.DeleteTopics(deleteCtx, prefixedTopics...) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + _, err = xEnv.KafkaAdminClient.CreateTopics(ctx, 1, 1, nil, prefixedTopics...) + require.NoError(t, err) +} + +func ProduceKafkaMessage(t *testing.T, xEnv *testenv.Environment, timeout time.Duration, topicName string, message string) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() pErrCh := make(chan error) @@ -26,7 +44,7 @@ func ProduceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName stri pErrCh <- err }) - testenv.AwaitChannelWithT(t, waitTimeout, pErrCh, func(t *testing.T, pErr error) { + testenv.AwaitChannelWithT(t, timeout, pErrCh, func(t *testing.T, pErr error) { require.NoError(t, pErr) }) @@ -34,23 +52,22 @@ func ProduceKafkaMessage(t *testing.T, xEnv *testenv.Environment, topicName stri require.NoError(t, fErr) } -func EnsureTopicExists(t *testing.T, xEnv *testenv.Environment, topics ...string) { - // Delete topic for idempotency - deleteCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second) +func ReadKafkaMessages(xEnv *testenv.Environment, timeout time.Duration, topicName string, msgs int) ([]*kgo.Record, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() - prefixedTopics := make([]string, 0, len(topics)) - for _, topic := range topics { - prefixedTopics = append(prefixedTopics, xEnv.GetPubSubName(topic)) - } - _, err := 
xEnv.KafkaAdminClient.DeleteTopics(deleteCtx, prefixedTopics...) - require.NoError(t, err) + client, err := kgo.NewClient( + kgo.SeedBrokers(xEnv.GetKafkaSeeds()...), + kgo.ConsumeTopics(xEnv.GetPubSubName(topicName)), + ) + if err != nil { + return nil, err + } + defer client.Close() - ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) - defer cancel() + fetchs := client.PollRecords(ctx, msgs) - _, err = xEnv.KafkaAdminClient.CreateTopics(ctx, 1, 1, nil, prefixedTopics...) - require.NoError(t, err) + return fetchs.Records(), nil } func ProduceRedisMessage(t *testing.T, xEnv *testenv.Environment, topicName string, message string) { @@ -72,10 +89,33 @@ func ProduceRedisMessage(t *testing.T, xEnv *testenv.Environment, topicName stri }) } - defer func() { - _ = redisConn.Close() - }() - intCmd := redisConn.Publish(ctx, xEnv.GetPubSubName(topicName), message) require.NoError(t, intCmd.Err()) } + +func ReadRedisMessages(t *testing.T, xEnv *testenv.Environment, channelName string) (<-chan *redis.Message, error) { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + parsedURL, err := url.Parse(xEnv.RedisHosts[0]) + if err != nil { + return nil, err + } + var redisConn redis.UniversalClient + if !xEnv.RedisWithClusterMode { + redisConn = redis.NewClient(&redis.Options{ + Addr: parsedURL.Host, + }) + } else { + redisConn = redis.NewClusterClient(&redis.ClusterOptions{ + Addrs: []string{parsedURL.Host}, + }) + } + sub := redisConn.Subscribe(ctx, xEnv.GetPubSubName(channelName)) + t.Cleanup(func() { + sub.Close() + redisConn.Close() + }) + + return sub.Channel(), nil +} diff --git a/router-tests/go.mod b/router-tests/go.mod index 479d44590c..69af587c0b 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -27,7 +27,7 @@ require ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e github.com/wundergraph/cosmo/router v0.0.0-20250912064154-106e871ee32e 
github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259 go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel/sdk v1.36.0 go.opentelemetry.io/otel/sdk/metric v1.36.0 diff --git a/router-tests/go.sum b/router-tests/go.sum index 947f5ba76c..07c3dbd780 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -352,10 +352,8 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoTjFDDcgaECCCR2aTePqMu9QBmPbyhqIYOhV0= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229 h1:VCfCX/xmpBGQLhTHJMHLugzJrXJk/smjLRAEruCI0HY= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb h1:stBTAle5FyytsTNxYeCwNzYlyhKzlS4he6f7/y6O3qE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259 h1:PhKYGyTBFM0JIihHLQa6tD5Al6GVFIPuJxi2T+DEiB0= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4= github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= 
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router-tests/modules/start-subscription/module.go b/router-tests/modules/start-subscription/module.go index fd5a9e0088..ffa94ef1f0 100644 --- a/router-tests/modules/start-subscription/module.go +++ b/router-tests/modules/start-subscription/module.go @@ -12,7 +12,7 @@ const myModuleID = "startSubscriptionModule" type StartSubscriptionModule struct { Logger *zap.Logger - Callback func(ctx core.SubscriptionOnStartHookContext) error + Callback func(ctx core.SubscriptionOnStartHandlerContext) error CallbackOnOriginResponse func(response *http.Response, ctx core.RequestContext) *http.Response } @@ -23,7 +23,7 @@ func (m *StartSubscriptionModule) Provision(ctx *core.ModuleContext) error { return nil } -func (m *StartSubscriptionModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHookContext) error { +func (m *StartSubscriptionModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHandlerContext) error { m.Logger.Info("SubscriptionOnStart Hook has been run") diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index ad286d54ef..b9d5e2f0ac 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -90,7 +90,7 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { return nil } @@ -179,9 +179,9 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx 
core.SubscriptionOnStartHookContext) error { + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { callbackCalled <- true - return core.NewStreamHookError(nil, "subscription closed", http.StatusOK, "") + return core.NewHttpGraphqlError("subscription closed", http.StatusText(http.StatusOK), http.StatusOK) }, }, }, @@ -261,7 +261,7 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { employeeId := ctx.Operation().Variables().GetInt64("employeeID") if employeeId != 1 { return nil @@ -365,8 +365,8 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { - return core.NewStreamHookError(errors.New("test error"), "test error", http.StatusLoopDetected, http.StatusText(http.StatusLoopDetected)) + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + return core.NewHttpGraphqlError("test error", http.StatusText(http.StatusLoopDetected), http.StatusLoopDetected) }, }, }, @@ -509,7 +509,7 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHookContext) error { + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { ctx.WriteEvent(&core.EngineEvent{ Data: []byte(`{"data":{"countEmp":1000}}`), }) @@ -593,8 +593,8 @@ func TestStartSubscriptionHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx 
core.SubscriptionOnStartHookContext) error { - return core.NewStreamHookError(errors.New("subscription closed"), "subscription closed", http.StatusOK, "NotFound") + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + return core.NewHttpGraphqlError("subscription closed", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }, CallbackOnOriginResponse: func(response *http.Response, ctx core.RequestContext) *http.Response { originResponseCalled <- response diff --git a/router-tests/modules/stream-publish/module.go b/router-tests/modules/stream-publish/module.go new file mode 100644 index 0000000000..e5553058ea --- /dev/null +++ b/router-tests/modules/stream-publish/module.go @@ -0,0 +1,49 @@ +package publish + +import ( + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" +) + +const myModuleID = "publishModule" + +type PublishModule struct { + Logger *zap.Logger + Callback func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) +} + +func (m *PublishModule) Provision(ctx *core.ModuleContext) error { + // Assign the logger to the module for non-request related logging + m.Logger = ctx.Logger + + return nil +} + +func (m *PublishModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + m.Logger.Info("Publish Hook has been run") + + if m.Callback != nil { + return m.Callback(ctx, events) + } + + return events, nil +} + +func (m *PublishModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &PublishModule{} + }, + } +} + +// Interface guard +var ( + _ core.StreamPublishEventHandler = (*PublishModule)(nil) +) diff --git 
a/router-tests/modules/stream-receive/module.go b/router-tests/modules/stream-receive/module.go new file mode 100644 index 0000000000..640218ad00 --- /dev/null +++ b/router-tests/modules/stream-receive/module.go @@ -0,0 +1,49 @@ +package batch + +import ( + "go.uber.org/zap" + + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" +) + +const myModuleID = "streamReceiveModule" + +type StreamReceiveModule struct { + Logger *zap.Logger + Callback func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) +} + +func (m *StreamReceiveModule) Provision(ctx *core.ModuleContext) error { + // Assign the logger to the module for non-request related logging + m.Logger = ctx.Logger + + return nil +} + +func (m *StreamReceiveModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + m.Logger.Info("Stream Hook has been run") + + if m.Callback != nil { + return m.Callback(ctx, events) + } + + return events, nil +} + +func (m *StreamReceiveModule) Module() core.ModuleInfo { + return core.ModuleInfo{ + // This is the ID of your module, it must be unique + ID: myModuleID, + // The priority of your module, lower numbers are executed first + Priority: 1, + New: func() core.Module { + return &StreamReceiveModule{} + }, + } +} + +// Interface guard +var ( + _ core.StreamReceiveEventHandler = (*StreamReceiveModule)(nil) +) diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go new file mode 100644 index 0000000000..6fb7485dc3 --- /dev/null +++ b/router-tests/modules/stream_publish_test.go @@ -0,0 +1,315 @@ +package module_test + +import ( + "encoding/json" + "net/http" + "strconv" + "testing" + "time" + + "go.uber.org/zap/zapcore" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + 
"github.com/wundergraph/cosmo/router-tests/events" + stream_publish "github.com/wundergraph/cosmo/router-tests/modules/stream-publish" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" +) + +func TestPublishHook(t *testing.T) { + t.Parallel() + + t.Run("Test Publish hook is called", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test Publish kafka hook allows to set headers", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Headers["x-test"] = []byte("test") + } + + return events, 
nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) + require.NoError(t, err) + require.Len(t, records, 1) + header := records[0].Headers[0] + require.Equal(t, "x-test", header.Key) + require.Equal(t, []byte("test"), header.Value) + }) + }) + + t.Run("Test kafka publish error is returned and messages sent", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: 
true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data": {"updateEmployeeMyKafka": {"success": false}}}`, resOne.Body) + require.Equal(t, resOne.Response.StatusCode, 200) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") + assert.Len(t, requestLog2.All(), 1) + + records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) + require.NoError(t, err) + require.Len(t, records, 1) + }) + }) + + t.Run("Test nats publish error is returned and messages sent", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsNatsJSONTemplate, + EnableNats: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + firstSub, err := xEnv.NatsConnectionDefault.SubscribeSync(xEnv.GetPubSubName("employeeUpdatedMyNats.3")) + require.NoError(t, err) + t.Cleanup(func() { + _ = firstSub.Unsubscribe() + }) + 
require.NoError(t, xEnv.NatsConnectionDefault.Flush()) + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation UpdateEmployeeNats($update: UpdateEmployeeInput!) { + updateEmployeeMyNats(id: 3, update: $update) {success} + }`, + Variables: json.RawMessage(`{"update":{"name":"Stefan Avramovic","email":"avramovic@wundergraph.com"}}`), + }) + assert.JSONEq(t, `{"data": {"updateEmployeeMyNats": {"success": false}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") + assert.Len(t, requestLog2.All(), 1) + + msgOne, err := firstSub.NextMsg(5 * time.Second) + require.NoError(t, err) + require.Equal(t, xEnv.GetPubSubName("employeeUpdatedMyNats.3"), msgOne.Subject) + require.Equal(t, `{"id":3,"update":{"name":"Stefan Avramovic","email":"avramovic@wundergraph.com"}}`, string(msgOne.Data)) + require.NoError(t, err) + }) + }) + + t.Run("Test redis publish error is returned and messages sent", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsRedisJSONTemplate, + EnableRedis: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + records, err := events.ReadRedisMessages(t, xEnv, 
"employeeUpdatedMyRedis") + require.NoError(t, err) + + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyRedis(id: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data": {"updateEmployeeMyRedis": {"success": false}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") + assert.Len(t, requestLog2.All(), 1) + + require.Len(t, records, 1) + }) + }) + + t.Run("Test kafka module publish with argument in header", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { + return events, nil + } + + employeeID := ctx.Operation().Variables().GetInt("employeeID") + + newEvents := []datasource.StreamEvent{} + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + if evt.Headers == nil { + evt.Headers = map[string][]byte{} + } + evt.Headers["x-employee-id"] = []byte(strconv.Itoa(employeeID)) + newEvents = append(newEvents, event) + } + return newEvents, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") + resOne := 
xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation UpdateEmployeeKafka($employeeID: Int!) { updateEmployeeMyKafka(employeeID: $employeeID, update: {name: "name test"}) { success } }`, + Variables: json.RawMessage(`{"employeeID": 3}`), + }) + require.JSONEq(t, `{"data": {"updateEmployeeMyKafka": {"success": true}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) + require.NoError(t, err) + require.Len(t, records, 1) + header := records[0].Headers[0] + require.Equal(t, "x-employee-id", header.Key) + require.Equal(t, []byte("3"), header.Value) + }) + }) +} diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go new file mode 100644 index 0000000000..a1658dc35c --- /dev/null +++ b/router-tests/modules/stream_receive_test.go @@ -0,0 +1,521 @@ +package module_test + +import ( + "errors" + "net/http" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + integration "github.com/wundergraph/cosmo/router-tests" + "github.com/wundergraph/cosmo/router-tests/events" + "github.com/wundergraph/cosmo/router-tests/jwks" + stream_receive "github.com/wundergraph/cosmo/router-tests/modules/stream-receive" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/authentication" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" + "go.uber.org/zap" + "go.uber.org/zap/zapcore" +) + +func TestReceiveHook(t *testing.T) { + t.Parallel() + + const Timeout = time.Second * 10 + + type kafkaSubscriptionArgs struct { + dataValue []byte + errValue error + } + + t.Run("Test Receive 
hook is called", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{}, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) 
+ + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test Receive hook could change events", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + } + + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) 
+ + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test Receive hook change events of one of multiple subscriptions", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + if 
ctx.Authentication() == nil { + return events, nil + } + if val, ok := ctx.Authentication().Claims()["sub"]; !ok || val != "user-2" { + return events, nil + } + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + } + + return events, nil + }, + }, + }, + } + + authServer, err := jwks.NewServer(t) + require.NoError(t, err) + defer authServer.Close() + + JwksName := "my-jwks-server" + + tokenDecoder, _ := authentication.NewJwksTokenDecoder(integration.NewContextWithCancel(t), zap.NewNop(), []authentication.JWKSConfig{{ + URL: authServer.JWKSURL(), + RefreshInterval: time.Second * 5, + }}) + jwksOpts := authentication.HttpHeaderAuthenticatorOptions{ + Name: JwksName, + TokenDecoder: tokenDecoder, + } + + authenticator, err := authentication.NewHttpHeaderAuthenticator(jwksOpts) + require.NoError(t, err) + authenticators := []authentication.Authenticator{authenticator} + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + core.WithAccessController(core.NewAccessController(authenticators, false)), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) 
+ + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + token, err := authServer.Token(map[string]interface{}{ + "sub": "user-2", + }) + require.NoError(t, err) + + headers := http.Header{ + "Authorization": []string{"Bearer " + token}, + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + client2 := graphql.NewSubscriptionClient(surl) + client2.WithWebSocketOptions(graphql.WebsocketOptions{ + HTTPHeader: headers, + }) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + subscriptionArgsCh2 := make(chan kafkaSubscriptionArgs) + subscriptionTwoID, err := client2.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh2 <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionTwoID) + + clientRunCh2 := make(chan error) + go func() { + clientRunCh2 <- client2.Run() + }() + + xEnv.WaitForSubscriptionCount(2, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + assert.JSONEq(t, 
`{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh2, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + assert.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(args.dataValue)) + }) + + unSub1Err := client.Unsubscribe(subscriptionOneID) + require.NoError(t, unSub1Err) + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + unSub2Err := client2.Unsubscribe(subscriptionTwoID) + require.NoError(t, unSub2Err) + require.NoError(t, client2.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh2, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 2) + }) + }) + + t.Run("Test Receive hook can access custom header", func(t *testing.T) { + t.Parallel() + + customHeader := http.CanonicalHeaderKey("X-Custom-Header") + + cfg := config.Config{ + Graph: config.Graph{}, + + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + if val, ok := ctx.Request().Header[customHeader]; !ok || val[0] != "Test" { + return events, nil + } + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + } + + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: 
[]core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + headers := http.Header{ + customHeader: []string{"Test"}, + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + client.WithWebSocketOptions(graphql.WebsocketOptions{ + HTTPHeader: headers, + }) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + assert.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":3,"details":{"forename":"Stefan","surname":"Avram"}}}`, string(args.dataValue)) + }) + + unSub1Err := client.Unsubscribe(subscriptionOneID) + require.NoError(t, unSub1Err) + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err 
error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) + + t.Run("Test Batch hook error should close Kafka clients and subscriptions", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return nil, errors.New("test error from streamevents hook") + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) 
+ + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + // Wait for server to close the subscription connection + xEnv.WaitForSubscriptionCount(0, Timeout) + + // Verify that client.Run() completed when server closed the connection + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "client should have completed when server closed connection") + + xEnv.WaitForTriggerCount(0, Timeout) + }) + }) +} diff --git a/router-tests/modules/streams_hooks_combined_test.go b/router-tests/modules/streams_hooks_combined_test.go new file mode 100644 index 0000000000..78639dd052 --- /dev/null +++ b/router-tests/modules/streams_hooks_combined_test.go @@ -0,0 +1,149 @@ +package module + +import ( + "encoding/json" + "testing" + "time" + + "github.com/hasura/go-graphql-client" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router-tests/events" + stream_publish "github.com/wundergraph/cosmo/router-tests/modules/stream-publish" + stream_receive 
"github.com/wundergraph/cosmo/router-tests/modules/stream-receive" + "github.com/wundergraph/cosmo/router-tests/testenv" + "github.com/wundergraph/cosmo/router/core" + "github.com/wundergraph/cosmo/router/pkg/config" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" + "go.uber.org/zap/zapcore" +) + +func TestStreamsHooksCombined(t *testing.T) { + t.Parallel() + + t.Run("Test kafka modules can depend on each other", func(t *testing.T) { + t.Parallel() + + type event struct { + data []byte + err error + } + + const Timeout = time.Second * 10 + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + + if string(evt.Headers["x-publishModule"]) == "i_was_here" { + evt.Data = []byte(`{"__typename":"Employee","id": 2,"update":{"name":"irrelevant"}}`) + } + } + + return events, nil + }, + }, + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { + return events, nil + } + + for _, event := range events { + evt, ok := event.(*kafka.Event) + if !ok { + continue + } + evt.Headers["x-publishModule"] = []byte("i_was_here") + } + + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}, &stream_receive.StreamReceiveModule{}), + }, + LogObservation: 
testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + // start a subscriber + var subscriptionPayload struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionEventsChan := make(chan event) + subscriptionID, err := client.Subscribe(&subscriptionPayload, nil, func(dataValue []byte, errValue error) error { + subscriptionEventsChan <- event{ + data: dataValue, + err: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionID) + + clientRunChan := make(chan error) + go func() { + clientRunChan <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + // publish a message to broker via mutation + // and let publish hook modify the message + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation UpdateEmployeeKafka($employeeID: Int!) 
{ updateEmployeeMyKafka(employeeID: $employeeID, update: {name: "name test"}) { success } }`, + Variables: json.RawMessage(`{"employeeID": 3}`), + }) + require.JSONEq(t, `{"data": {"updateEmployeeMyKafka": {"success": true}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + // wait for the message to be received by the subscriber + testenv.AwaitChannelWithT(t, Timeout, subscriptionEventsChan, func(t *testing.T, args event) { + require.NoError(t, args.err) + // verify that the stream batch hook modified the message, + // which it only does if the publish hook was run before it + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":2,"details":{"forename":"Dustin","surname":"Deus"}}}`, string(args.data)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunChan, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog = xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + }) + }) +} diff --git a/router-tests/prometheus_stream_metrics_test.go b/router-tests/prometheus_stream_metrics_test.go index 30fa87fe16..ac6d23d767 100644 --- a/router-tests/prometheus_stream_metrics_test.go +++ b/router-tests/prometheus_stream_metrics_test.go @@ -44,7 +44,7 @@ func TestFlakyEventMetrics(t *testing.T) { EnablePrometheusStreamMetrics: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.EnsureTopicExists(t, xEnv, "employeeUpdated") + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`}) xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`}) @@ -91,7 +91,7 @@ func 
TestFlakyEventMetrics(t *testing.T) { EnablePrometheusStreamMetrics: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.EnsureTopicExists(t, xEnv, topic) + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topic) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -115,7 +115,7 @@ func TestFlakyEventMetrics(t *testing.T) { go func() { clientRunCh <- client.Run() }() xEnv.WaitForSubscriptionCount(1, WaitTimeout) - events.ProduceKafkaMessage(t, xEnv, topic, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, time.Second, topic, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, WaitTimeout, subscriptionArgsCh, func(t *testing.T, args subscriptionArgs) { require.NoError(t, args.errValue) diff --git a/router-tests/telemetry/stream_metrics_test.go b/router-tests/telemetry/stream_metrics_test.go index 72ac7c654f..bac9aee748 100644 --- a/router-tests/telemetry/stream_metrics_test.go +++ b/router-tests/telemetry/stream_metrics_test.go @@ -45,7 +45,7 @@ func TestFlakyEventMetrics(t *testing.T) { EnableOTLPStreamMetrics: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.EnsureTopicExists(t, xEnv, "employeeUpdated") + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`}) xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`}) @@ -96,7 +96,7 @@ func TestFlakyEventMetrics(t *testing.T) { EnableOTLPStreamMetrics: true, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.EnsureTopicExists(t, xEnv, topic) + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topic) var subscriptionOne struct { employeeUpdatedMyKafka struct { @@ -120,7 +120,7 @@ func TestFlakyEventMetrics(t 
*testing.T) { go func() { clientRunCh <- client.Run() }() xEnv.WaitForSubscriptionCount(1, WaitTimeout) - events.ProduceKafkaMessage(t, xEnv, topic, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + events.ProduceKafkaMessage(t, xEnv, time.Second, topic, `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) testenv.AwaitChannelWithT(t, WaitTimeout, subscriptionArgsCh, func(t *testing.T, args subscriptionArgs) { require.NoError(t, args.errValue) diff --git a/router/.mockery.yml b/router/.mockery.yml index 558bca2185..436ed0eb14 100644 --- a/router/.mockery.yml +++ b/router/.mockery.yml @@ -21,12 +21,6 @@ packages: github.com/wundergraph/cosmo/router/pkg/pubsub/nats: interfaces: Adapter: - github.com/wundergraph/cosmo/router/pkg/pubsub/kafka: - interfaces: - Adapter: - github.com/wundergraph/cosmo/router/pkg/pubsub/redis: - interfaces: - Adapter: github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve: config: dir: 'pkg/pubsub/datasource' diff --git a/router/core/errors.go b/router/core/errors.go index 44e05f327b..2ce688bbef 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -35,7 +35,7 @@ const ( errorTypeInvalidWsSubprotocol errorTypeEDFSInvalidMessage errorTypeMergeResult - errorTypeStreamHookError + errorTypeHttpError ) type ( @@ -90,9 +90,9 @@ func getErrorType(err error) errorType { if errors.As(err, &mergeResultErr) { return errorTypeMergeResult } - var streamHookErr *StreamHookError - if errors.As(err, &streamHookErr) { - return errorTypeStreamHookError + var httpError *httpGraphqlError + if errors.As(err, &httpError) { + return errorTypeHttpError } return errorTypeUnknown } diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index 70f3e7917c..d7c72fe579 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -481,6 +481,17 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod for i, fn := range l.subscriptionHooks.onStart { 
subscriptionOnStartFns[i] = NewPubSubSubscriptionOnStartHook(fn) } + + onPublishEventsFns := make([]pubsub_datasource.OnPublishEventsFn, len(l.subscriptionHooks.onPublishEvents)) + for i, fn := range l.subscriptionHooks.onPublishEvents { + onPublishEventsFns[i] = NewPubSubOnPublishEventsHook(fn) + } + + onReceiveEventsFns := make([]pubsub_datasource.OnReceiveEventsFn, len(l.subscriptionHooks.onReceiveEvents)) + for i, fn := range l.subscriptionHooks.onReceiveEvents { + onReceiveEventsFns[i] = NewPubSubOnReceiveEventsHook(fn) + } + factoryProviders, factoryDataSources, err := pubsub.BuildProvidersAndDataSources( l.ctx, routerEngineConfig.Events, @@ -489,8 +500,10 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod pubSubDS, l.resolver.InstanceData().HostName, l.resolver.InstanceData().ListenAddress, - pubsub.Hooks{ + pubsub_datasource.Hooks{ SubscriptionOnStart: subscriptionOnStartFns, + OnReceiveEvents: onReceiveEventsFns, + OnPublishEvents: onPublishEventsFns, }, ) if err != nil { diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index f387d73e6c..845b8bdac0 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -400,21 +400,21 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } - case errorTypeStreamHookError: - var streamHookErr *StreamHookError - if !errors.As(err, &streamHookErr) { + case errorTypeHttpError: + var httpErr *httpGraphqlError + if !errors.As(err, &httpErr) { response.Errors[0].Message = "Internal server error" return } - response.Errors[0].Message = streamHookErr.Message() - if streamHookErr.Code() != "" || streamHookErr.StatusCode() != 0 { + response.Errors[0].Message = httpErr.Message() + if httpErr.ExtensionCode() != "" || httpErr.StatusCode() != 0 { response.Errors[0].Extensions = &Extensions{ - Code: streamHookErr.Code(), - StatusCode: 
streamHookErr.StatusCode(), + Code: httpErr.ExtensionCode(), + StatusCode: httpErr.StatusCode(), } } if isHttpResponseWriter { - httpWriter.WriteHeader(streamHookErr.StatusCode()) + httpWriter.WriteHeader(httpErr.StatusCode()) } } diff --git a/router/core/router.go b/router/core/router.go index 9f07bd723a..04303023f0 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -670,6 +670,14 @@ func (r *Router) initModules(ctx context.Context) error { r.subscriptionHooks.onStart = append(r.subscriptionHooks.onStart, handler.SubscriptionOnStart) } + if handler, ok := moduleInstance.(StreamPublishEventHandler); ok { + r.subscriptionHooks.onPublishEvents = append(r.subscriptionHooks.onPublishEvents, handler.OnPublishEvents) + } + + if handler, ok := moduleInstance.(StreamReceiveEventHandler); ok { + r.subscriptionHooks.onReceiveEvents = append(r.subscriptionHooks.onReceiveEvents, handler.OnReceiveEvents) + } + r.modules = append(r.modules, moduleInstance) r.logger.Info("Module registered", diff --git a/router/core/router_config.go b/router/core/router_config.go index ac4f26d4c7..3e282d3c65 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -17,6 +17,7 @@ import ( "github.com/wundergraph/cosmo/router/pkg/health" "github.com/wundergraph/cosmo/router/pkg/mcpserver" rmetric "github.com/wundergraph/cosmo/router/pkg/metric" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" rtrace "github.com/wundergraph/cosmo/router/pkg/trace" "go.opentelemetry.io/otel/propagation" sdkmetric "go.opentelemetry.io/otel/sdk/metric" @@ -26,7 +27,9 @@ import ( ) type subscriptionHooks struct { - onStart []func(ctx SubscriptionOnStartHookContext) error + onStart []func(ctx SubscriptionOnStartHandlerContext) error + onPublishEvents []func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) 
([]datasource.StreamEvent, error) } type Config struct { diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index 505bbfc44f..e3279c811d 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -1,7 +1,9 @@ package core import ( + "context" "net/http" + "slices" "github.com/wundergraph/cosmo/router/pkg/authentication" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" @@ -10,43 +12,7 @@ import ( "go.uber.org/zap" ) -// StreamHookError is used to customize the error messages and the behavior -type StreamHookError struct { - err error - message string - statusCode int - code string -} - -func (e *StreamHookError) Error() string { - if e.err != nil { - return e.err.Error() - } - return e.message -} - -func (e *StreamHookError) Message() string { - return e.message -} - -func (e *StreamHookError) StatusCode() int { - return e.statusCode -} - -func (e *StreamHookError) Code() string { - return e.code -} - -func NewStreamHookError(err error, message string, statusCode int, code string) *StreamHookError { - return &StreamHookError{ - err: err, - message: message, - statusCode: statusCode, - code: code, - } -} - -type SubscriptionOnStartHookContext interface { +type SubscriptionOnStartHandlerContext interface { // Request is the original request received by the router. 
Request() *http.Request // Logger is the logger for the request @@ -62,11 +28,39 @@ type SubscriptionOnStartHookContext interface { WriteEvent(event datasource.StreamEvent) bool } -type pubSubSubscriptionOnStartHookContext struct { - request *http.Request +type pubSubPublishEventHookContext struct { + request *http.Request logger *zap.Logger operation OperationContext authentication authentication.Authentication + publishEventConfiguration datasource.PublishEventConfiguration +} + +func (c *pubSubPublishEventHookContext) Request() *http.Request { + return c.request +} + +func (c *pubSubPublishEventHookContext) Logger() *zap.Logger { + return c.logger +} + +func (c *pubSubPublishEventHookContext) Operation() OperationContext { + return c.operation +} + +func (c *pubSubPublishEventHookContext) Authentication() authentication.Authentication { + return c.authentication +} + +func (c *pubSubPublishEventHookContext) PublishEventConfiguration() datasource.PublishEventConfiguration { + return c.publishEventConfiguration +} + +type pubSubSubscriptionOnStartHookContext struct { + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication subscriptionEventConfiguration datasource.SubscriptionEventConfiguration writeEventHook func(data []byte) } @@ -106,11 +100,17 @@ func (e *EngineEvent) GetData() []byte { return e.Data } +func (e *EngineEvent) Clone() datasource.StreamEvent { + return &EngineEvent{ + Data: slices.Clone(e.Data), + } +} + type engineSubscriptionOnStartHookContext struct { - request *http.Request - logger *zap.Logger - operation OperationContext - authentication authentication.Authentication + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication writeEventHook func(data []byte) } @@ -143,11 +143,11 @@ func (c *engineSubscriptionOnStartHookContext) SubscriptionEventConfiguration() type SubscriptionOnStartHandler interface { // 
SubscriptionOnStart is called once at subscription start // The error is propagated to the client. - SubscriptionOnStart(ctx SubscriptionOnStartHookContext) error + SubscriptionOnStart(ctx SubscriptionOnStartHandlerContext) error } // NewPubSubSubscriptionOnStartHook converts a SubscriptionOnStartHandler to a pubsub.SubscriptionOnStartFn -func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext) error) datasource.SubscriptionOnStartFn { +func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerContext) error) datasource.SubscriptionOnStartFn { if fn == nil { return nil } @@ -155,10 +155,10 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration) error { requestContext := getRequestContext(resolveCtx.Context) hookCtx := &pubSubSubscriptionOnStartHookContext{ - request: requestContext.Request(), - logger: requestContext.Logger(), - operation: requestContext.Operation(), - authentication: requestContext.Authentication(), + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, writeEventHook: resolveCtx.Updater, } @@ -168,7 +168,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext } // NewEngineSubscriptionOnStartHook converts a SubscriptionOnStartHandler to a graphql_datasource.SubscriptionOnStartFn -func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext) error) graphql_datasource.SubscriptionOnStartFn { +func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerContext) error) graphql_datasource.SubscriptionOnStartFn { if fn == nil { return nil } @@ -176,9 +176,9 @@ func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext return func(resolveCtx 
resolve.StartupHookContext, input []byte) error { requestContext := getRequestContext(resolveCtx.Context) hookCtx := &engineSubscriptionOnStartHookContext{ - request: requestContext.Request(), - logger: requestContext.Logger(), - operation: requestContext.Operation(), + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), authentication: requestContext.Authentication(), writeEventHook: resolveCtx.Updater, } @@ -186,3 +186,112 @@ func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHookContext return fn(hookCtx) } } + +type StreamReceiveEventHandlerContext interface { + // Request is the initial client request that started the subscription + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration the subscription event configuration + SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration +} + +type StreamReceiveEventHandler interface { + // OnReceiveEvents is called each time a batch of events is received from the provider before delivering them to the + // client. So for a single batch of events received from the provider, this hook will be called one time for each + // active subscription. + // It is important to optimize the logic inside this hook to avoid performance issues. + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a + // StreamHookError. + OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) +} + +type StreamPublishEventHandlerContext interface { + // Request is the original request received by the router. 
+ Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // PublishEventConfiguration the publish event configuration + PublishEventConfiguration() datasource.PublishEventConfiguration +} + +type StreamPublishEventHandler interface { + // OnPublishEvents is called each time a batch of events is going to be sent to the provider + // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a + // StreamHookError. + OnPublishEvents(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) +} + +func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error)) datasource.OnPublishEventsFn { + if fn == nil { + return nil + } + + return func(ctx context.Context, pubConf datasource.PublishEventConfiguration, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + requestContext := getRequestContext(ctx) + hookCtx := &pubSubPublishEventHookContext{ + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + publishEventConfiguration: pubConf, + } + + return fn(hookCtx, evts) + } +} + +type pubSubStreamReceiveEventHookContext struct { + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication + subscriptionEventConfiguration datasource.SubscriptionEventConfiguration +} + +func (c *pubSubStreamReceiveEventHookContext) Request() *http.Request { + return c.request +} + +func (c *pubSubStreamReceiveEventHookContext) Logger() *zap.Logger { + return c.logger +} + +func (c *pubSubStreamReceiveEventHookContext) 
Operation() OperationContext { + return c.operation +} + +func (c *pubSubStreamReceiveEventHookContext) Authentication() authentication.Authentication { + return c.authentication +} + +func (c *pubSubStreamReceiveEventHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { + return c.subscriptionEventConfiguration +} + +func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error)) datasource.OnReceiveEventsFn { + if fn == nil { + return nil + } + + return func(ctx context.Context, subConf datasource.SubscriptionEventConfiguration, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + requestContext := getRequestContext(ctx) + hookCtx := &pubSubStreamReceiveEventHookContext{ + request: requestContext.Request(), + logger: requestContext.Logger(), + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + subscriptionEventConfiguration: subConf, + } + + return fn(hookCtx, evts) + } +} diff --git a/router/go.mod b/router/go.mod index 82ff4f2e73..35b04f0033 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259 // Do not upgrade, it renames attributes we rely on go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 diff --git a/router/go.sum b/router/go.sum index 1a0bc0afe5..09d82bed26 100644 --- a/router/go.sum +++ b/router/go.sum @@ -321,8 +321,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= 
github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb h1:stBTAle5FyytsTNxYeCwNzYlyhKzlS4he6f7/y6O3qE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20250930144208-ddc652f78bbb/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259 h1:PhKYGyTBFM0JIihHLQa6tD5Al6GVFIPuJxi2T+DEiB0= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.229.0.20251001132016-1d6b66867259/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M= diff --git a/router/pkg/pubsub/datasource/datasource.go b/router/pkg/pubsub/datasource/datasource.go index 3a3018b745..b186041388 100644 --- a/router/pkg/pubsub/datasource/datasource.go +++ b/router/pkg/pubsub/datasource/datasource.go @@ -9,7 +9,7 @@ type SubscriptionDataSource interface { SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) Start(ctx *resolve.Context, input []byte, updater resolve.SubscriptionUpdater) error UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) (err error) - SetSubscriptionOnStartFns(fns ...SubscriptionOnStartFn) + SetHooks(hooks Hooks) } // EngineDataSourceFactory is the interface that all pubsub data sources must implement. 
diff --git a/router/pkg/pubsub/datasource/factory.go b/router/pkg/pubsub/datasource/factory.go index 5c42161776..ae25e6dbcf 100644 --- a/router/pkg/pubsub/datasource/factory.go +++ b/router/pkg/pubsub/datasource/factory.go @@ -9,16 +9,18 @@ import ( ) type PlannerConfig[PB ProviderBuilder[P, E], P any, E any] struct { - ProviderBuilder PB - Event E - SubscriptionOnStartFns []SubscriptionOnStartFn + Providers map[string]Provider + ProviderBuilder PB + Event E + Hooks Hooks } -func NewPlannerConfig[PB ProviderBuilder[P, E], P any, E any](providerBuilder PB, event E, subscriptionOnStartFns []SubscriptionOnStartFn) *PlannerConfig[PB, P, E] { +func NewPlannerConfig[PB ProviderBuilder[P, E], P any, E any](providerBuilder PB, event E, providers map[string]Provider, hooks Hooks) *PlannerConfig[PB, P, E] { return &PlannerConfig[PB, P, E]{ - ProviderBuilder: providerBuilder, - Event: event, - SubscriptionOnStartFns: subscriptionOnStartFns, + Providers: providers, + ProviderBuilder: providerBuilder, + Event: event, + Hooks: hooks, } } diff --git a/router/pkg/pubsub/datasource/hooks.go b/router/pkg/pubsub/datasource/hooks.go new file mode 100644 index 0000000000..abab8b8ef1 --- /dev/null +++ b/router/pkg/pubsub/datasource/hooks.go @@ -0,0 +1,20 @@ +package datasource + +import ( + "context" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +) + +type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration) error + +type OnPublishEventsFn func(ctx context.Context, pubConf PublishEventConfiguration, evts []StreamEvent) ([]StreamEvent, error) + +type OnReceiveEventsFn func(ctx context.Context, subConf SubscriptionEventConfiguration, evts []StreamEvent) ([]StreamEvent, error) + +// Hooks contains hooks for the pubsub providers and data sources +type Hooks struct { + SubscriptionOnStart []SubscriptionOnStartFn + OnReceiveEvents []OnReceiveEventsFn + OnPublishEvents []OnPublishEventsFn +} diff --git 
a/router/pkg/pubsub/datasource/mocks.go b/router/pkg/pubsub/datasource/mocks.go index 861beb3987..3c56f09919 100644 --- a/router/pkg/pubsub/datasource/mocks.go +++ b/router/pkg/pubsub/datasource/mocks.go @@ -556,6 +556,109 @@ func (_c *MockProvider_ID_Call) RunAndReturn(run func() string) *MockProvider_ID return _c } +// Publish provides a mock function for the type MockProvider +func (_mock *MockProvider) Publish(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error { + ret := _mock.Called(ctx, cfg, events) + + if len(ret) == 0 { + panic("no return value specified for Publish") + } + + var r0 error + if returnFunc, ok := ret.Get(0).(func(context.Context, PublishEventConfiguration, []StreamEvent) error); ok { + r0 = returnFunc(ctx, cfg, events) + } else { + r0 = ret.Error(0) + } + return r0 +} + +// MockProvider_Publish_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Publish' +type MockProvider_Publish_Call struct { + *mock.Call +} + +// Publish is a helper method to define mock.On call +// - ctx context.Context +// - cfg PublishEventConfiguration +// - events []StreamEvent +func (_e *MockProvider_Expecter) Publish(ctx interface{}, cfg interface{}, events interface{}) *MockProvider_Publish_Call { + return &MockProvider_Publish_Call{Call: _e.mock.On("Publish", ctx, cfg, events)} +} + +func (_c *MockProvider_Publish_Call) Run(run func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent)) *MockProvider_Publish_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 PublishEventConfiguration + if args[1] != nil { + arg1 = args[1].(PublishEventConfiguration) + } + var arg2 []StreamEvent + if args[2] != nil { + arg2 = args[2].([]StreamEvent) + } + run( + arg0, + arg1, + arg2, + ) + }) + return _c +} + +func (_c *MockProvider_Publish_Call) Return(err error) *MockProvider_Publish_Call { + 
_c.Call.Return(err) + return _c +} + +func (_c *MockProvider_Publish_Call) RunAndReturn(run func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error) *MockProvider_Publish_Call { + _c.Call.Return(run) + return _c +} + +// SetHooks provides a mock function for the type MockProvider +func (_mock *MockProvider) SetHooks(hooks Hooks) { + _mock.Called(hooks) + return +} + +// MockProvider_SetHooks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHooks' +type MockProvider_SetHooks_Call struct { + *mock.Call +} + +// SetHooks is a helper method to define mock.On call +// - hooks Hooks +func (_e *MockProvider_Expecter) SetHooks(hooks interface{}) *MockProvider_SetHooks_Call { + return &MockProvider_SetHooks_Call{Call: _e.mock.On("SetHooks", hooks)} +} + +func (_c *MockProvider_SetHooks_Call) Run(run func(hooks Hooks)) *MockProvider_SetHooks_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 Hooks + if args[0] != nil { + arg0 = args[0].(Hooks) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockProvider_SetHooks_Call) Return() *MockProvider_SetHooks_Call { + _c.Call.Return() + return _c +} + +func (_c *MockProvider_SetHooks_Call) RunAndReturn(run func(hooks Hooks)) *MockProvider_SetHooks_Call { + _c.Run(run) + return _c +} + // Shutdown provides a mock function for the type MockProvider func (_mock *MockProvider) Shutdown(ctx context.Context) error { ret := _mock.Called(ctx) @@ -793,8 +896,8 @@ func (_m *MockProviderBuilder[P, E]) EXPECT() *MockProviderBuilder_Expecter[P, E } // BuildEngineDataSourceFactory provides a mock function for the type MockProviderBuilder -func (_mock *MockProviderBuilder[P, E]) BuildEngineDataSourceFactory(data E) (EngineDataSourceFactory, error) { - ret := _mock.Called(data) +func (_mock *MockProviderBuilder[P, E]) BuildEngineDataSourceFactory(data E, providers map[string]Provider) (EngineDataSourceFactory, error) { + ret := _mock.Called(data, providers) if 
len(ret) == 0 { panic("no return value specified for BuildEngineDataSourceFactory") @@ -802,18 +905,18 @@ func (_mock *MockProviderBuilder[P, E]) BuildEngineDataSourceFactory(data E) (En var r0 EngineDataSourceFactory var r1 error - if returnFunc, ok := ret.Get(0).(func(E) (EngineDataSourceFactory, error)); ok { - return returnFunc(data) + if returnFunc, ok := ret.Get(0).(func(E, map[string]Provider) (EngineDataSourceFactory, error)); ok { + return returnFunc(data, providers) } - if returnFunc, ok := ret.Get(0).(func(E) EngineDataSourceFactory); ok { - r0 = returnFunc(data) + if returnFunc, ok := ret.Get(0).(func(E, map[string]Provider) EngineDataSourceFactory); ok { + r0 = returnFunc(data, providers) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(EngineDataSourceFactory) } } - if returnFunc, ok := ret.Get(1).(func(E) error); ok { - r1 = returnFunc(data) + if returnFunc, ok := ret.Get(1).(func(E, map[string]Provider) error); ok { + r1 = returnFunc(data, providers) } else { r1 = ret.Error(1) } @@ -827,18 +930,24 @@ type MockProviderBuilder_BuildEngineDataSourceFactory_Call[P any, E any] struct // BuildEngineDataSourceFactory is a helper method to define mock.On call // - data E -func (_e *MockProviderBuilder_Expecter[P, E]) BuildEngineDataSourceFactory(data interface{}) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { - return &MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]{Call: _e.mock.On("BuildEngineDataSourceFactory", data)} +// - providers map[string]Provider +func (_e *MockProviderBuilder_Expecter[P, E]) BuildEngineDataSourceFactory(data interface{}, providers interface{}) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { + return &MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]{Call: _e.mock.On("BuildEngineDataSourceFactory", data, providers)} } -func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) Run(run func(data E)) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { +func 
(_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) Run(run func(data E, providers map[string]Provider)) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { _c.Call.Run(func(args mock.Arguments) { var arg0 E if args[0] != nil { arg0 = args[0].(E) } + var arg1 map[string]Provider + if args[1] != nil { + arg1 = args[1].(map[string]Provider) + } run( arg0, + arg1, ) }) return _c @@ -849,7 +958,7 @@ func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) Return(en return _c } -func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) RunAndReturn(run func(data E) (EngineDataSourceFactory, error)) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { +func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) RunAndReturn(run func(data E, providers map[string]Provider) (EngineDataSourceFactory, error)) *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E] { _c.Call.Return(run) return _c } @@ -1060,9 +1169,49 @@ func (_c *MockSubscriptionEventUpdater_Complete_Call) RunAndReturn(run func()) * return _c } +// SetHooks provides a mock function for the type MockSubscriptionEventUpdater +func (_mock *MockSubscriptionEventUpdater) SetHooks(hooks Hooks) { + _mock.Called(hooks) + return +} + +// MockSubscriptionEventUpdater_SetHooks_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'SetHooks' +type MockSubscriptionEventUpdater_SetHooks_Call struct { + *mock.Call +} + +// SetHooks is a helper method to define mock.On call +// - hooks Hooks +func (_e *MockSubscriptionEventUpdater_Expecter) SetHooks(hooks interface{}) *MockSubscriptionEventUpdater_SetHooks_Call { + return &MockSubscriptionEventUpdater_SetHooks_Call{Call: _e.mock.On("SetHooks", hooks)} +} + +func (_c *MockSubscriptionEventUpdater_SetHooks_Call) Run(run func(hooks Hooks)) *MockSubscriptionEventUpdater_SetHooks_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 Hooks + if args[0] != nil { 
+ arg0 = args[0].(Hooks) + } + run( + arg0, + ) + }) + return _c +} + +func (_c *MockSubscriptionEventUpdater_SetHooks_Call) Return() *MockSubscriptionEventUpdater_SetHooks_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionEventUpdater_SetHooks_Call) RunAndReturn(run func(hooks Hooks)) *MockSubscriptionEventUpdater_SetHooks_Call { + _c.Run(run) + return _c +} + // Update provides a mock function for the type MockSubscriptionEventUpdater -func (_mock *MockSubscriptionEventUpdater) Update(event StreamEvent) { - _mock.Called(event) +func (_mock *MockSubscriptionEventUpdater) Update(events []StreamEvent) { + _mock.Called(events) return } @@ -1072,16 +1221,16 @@ type MockSubscriptionEventUpdater_Update_Call struct { } // Update is a helper method to define mock.On call -// - event StreamEvent -func (_e *MockSubscriptionEventUpdater_Expecter) Update(event interface{}) *MockSubscriptionEventUpdater_Update_Call { - return &MockSubscriptionEventUpdater_Update_Call{Call: _e.mock.On("Update", event)} +// - events []StreamEvent +func (_e *MockSubscriptionEventUpdater_Expecter) Update(events interface{}) *MockSubscriptionEventUpdater_Update_Call { + return &MockSubscriptionEventUpdater_Update_Call{Call: _e.mock.On("Update", events)} } -func (_c *MockSubscriptionEventUpdater_Update_Call) Run(run func(event StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { +func (_c *MockSubscriptionEventUpdater_Update_Call) Run(run func(events []StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { _c.Call.Run(func(args mock.Arguments) { - var arg0 StreamEvent + var arg0 []StreamEvent if args[0] != nil { - arg0 = args[0].(StreamEvent) + arg0 = args[0].([]StreamEvent) } run( arg0, @@ -1095,7 +1244,7 @@ func (_c *MockSubscriptionEventUpdater_Update_Call) Return() *MockSubscriptionEv return _c } -func (_c *MockSubscriptionEventUpdater_Update_Call) RunAndReturn(run func(event StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { +func (_c 
*MockSubscriptionEventUpdater_Update_Call) RunAndReturn(run func(events []StreamEvent)) *MockSubscriptionEventUpdater_Update_Call { _c.Run(run) return _c } diff --git a/router/pkg/pubsub/datasource/mocks_resolve.go b/router/pkg/pubsub/datasource/mocks_resolve.go index 3efc24b405..19bad89c16 100644 --- a/router/pkg/pubsub/datasource/mocks_resolve.go +++ b/router/pkg/pubsub/datasource/mocks_resolve.go @@ -5,6 +5,8 @@ package datasource import ( + "context" + mock "github.com/stretchr/testify/mock" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -76,6 +78,52 @@ func (_c *MockSubscriptionUpdater_Close_Call) RunAndReturn(run func(kind resolve return _c } +// CloseSubscription provides a mock function for the type MockSubscriptionUpdater +func (_mock *MockSubscriptionUpdater) CloseSubscription(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier) { + _mock.Called(kind, id) + return +} + +// MockSubscriptionUpdater_CloseSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'CloseSubscription' +type MockSubscriptionUpdater_CloseSubscription_Call struct { + *mock.Call +} + +// CloseSubscription is a helper method to define mock.On call +// - kind resolve.SubscriptionCloseKind +// - id resolve.SubscriptionIdentifier +func (_e *MockSubscriptionUpdater_Expecter) CloseSubscription(kind interface{}, id interface{}) *MockSubscriptionUpdater_CloseSubscription_Call { + return &MockSubscriptionUpdater_CloseSubscription_Call{Call: _e.mock.On("CloseSubscription", kind, id)} +} + +func (_c *MockSubscriptionUpdater_CloseSubscription_Call) Run(run func(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier)) *MockSubscriptionUpdater_CloseSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 resolve.SubscriptionCloseKind + if args[0] != nil { + arg0 = args[0].(resolve.SubscriptionCloseKind) + } + var arg1 resolve.SubscriptionIdentifier + if args[1] != nil { + arg1 = 
args[1].(resolve.SubscriptionIdentifier) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockSubscriptionUpdater_CloseSubscription_Call) Return() *MockSubscriptionUpdater_CloseSubscription_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionUpdater_CloseSubscription_Call) RunAndReturn(run func(kind resolve.SubscriptionCloseKind, id resolve.SubscriptionIdentifier)) *MockSubscriptionUpdater_CloseSubscription_Call { + _c.Run(run) + return _c +} + // Complete provides a mock function for the type MockSubscriptionUpdater func (_mock *MockSubscriptionUpdater) Complete() { _mock.Called() @@ -109,6 +157,52 @@ func (_c *MockSubscriptionUpdater_Complete_Call) RunAndReturn(run func()) *MockS return _c } +// Subscriptions provides a mock function for the type MockSubscriptionUpdater +func (_mock *MockSubscriptionUpdater) Subscriptions() map[context.Context]resolve.SubscriptionIdentifier { + ret := _mock.Called() + + if len(ret) == 0 { + panic("no return value specified for Subscriptions") + } + + var r0 map[context.Context]resolve.SubscriptionIdentifier + if returnFunc, ok := ret.Get(0).(func() map[context.Context]resolve.SubscriptionIdentifier); ok { + r0 = returnFunc() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(map[context.Context]resolve.SubscriptionIdentifier) + } + } + return r0 +} + +// MockSubscriptionUpdater_Subscriptions_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscriptions' +type MockSubscriptionUpdater_Subscriptions_Call struct { + *mock.Call +} + +// Subscriptions is a helper method to define mock.On call +func (_e *MockSubscriptionUpdater_Expecter) Subscriptions() *MockSubscriptionUpdater_Subscriptions_Call { + return &MockSubscriptionUpdater_Subscriptions_Call{Call: _e.mock.On("Subscriptions")} +} + +func (_c *MockSubscriptionUpdater_Subscriptions_Call) Run(run func()) *MockSubscriptionUpdater_Subscriptions_Call { + _c.Call.Run(func(args mock.Arguments) { + run() + 
}) + return _c +} + +func (_c *MockSubscriptionUpdater_Subscriptions_Call) Return(contextToSubscriptionIdentifier map[context.Context]resolve.SubscriptionIdentifier) *MockSubscriptionUpdater_Subscriptions_Call { + _c.Call.Return(contextToSubscriptionIdentifier) + return _c +} + +func (_c *MockSubscriptionUpdater_Subscriptions_Call) RunAndReturn(run func() map[context.Context]resolve.SubscriptionIdentifier) *MockSubscriptionUpdater_Subscriptions_Call { + _c.Call.Return(run) + return _c +} + // Update provides a mock function for the type MockSubscriptionUpdater func (_mock *MockSubscriptionUpdater) Update(data []byte) { _mock.Called(data) @@ -148,3 +242,49 @@ func (_c *MockSubscriptionUpdater_Update_Call) RunAndReturn(run func(data []byte _c.Run(run) return _c } + +// UpdateSubscription provides a mock function for the type MockSubscriptionUpdater +func (_mock *MockSubscriptionUpdater) UpdateSubscription(id resolve.SubscriptionIdentifier, data []byte) { + _mock.Called(id, data) + return +} + +// MockSubscriptionUpdater_UpdateSubscription_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'UpdateSubscription' +type MockSubscriptionUpdater_UpdateSubscription_Call struct { + *mock.Call +} + +// UpdateSubscription is a helper method to define mock.On call +// - id resolve.SubscriptionIdentifier +// - data []byte +func (_e *MockSubscriptionUpdater_Expecter) UpdateSubscription(id interface{}, data interface{}) *MockSubscriptionUpdater_UpdateSubscription_Call { + return &MockSubscriptionUpdater_UpdateSubscription_Call{Call: _e.mock.On("UpdateSubscription", id, data)} +} + +func (_c *MockSubscriptionUpdater_UpdateSubscription_Call) Run(run func(id resolve.SubscriptionIdentifier, data []byte)) *MockSubscriptionUpdater_UpdateSubscription_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 resolve.SubscriptionIdentifier + if args[0] != nil { + arg0 = args[0].(resolve.SubscriptionIdentifier) + } + var arg1 []byte + if args[1] != 
nil { + arg1 = args[1].([]byte) + } + run( + arg0, + arg1, + ) + }) + return _c +} + +func (_c *MockSubscriptionUpdater_UpdateSubscription_Call) Return() *MockSubscriptionUpdater_UpdateSubscription_Call { + _c.Call.Return() + return _c +} + +func (_c *MockSubscriptionUpdater_UpdateSubscription_Call) RunAndReturn(run func(id resolve.SubscriptionIdentifier, data []byte)) *MockSubscriptionUpdater_UpdateSubscription_Call { + _c.Run(run) + return _c +} diff --git a/router/pkg/pubsub/datasource/planner.go b/router/pkg/pubsub/datasource/planner.go index a480f8270e..f0378caa88 100644 --- a/router/pkg/pubsub/datasource/planner.go +++ b/router/pkg/pubsub/datasource/planner.go @@ -48,7 +48,7 @@ func (p *Planner[PB, P, E]) ConfigureFetch() resolve.FetchConfiguration { return resolve.FetchConfiguration{} } - pubSubDataSource, err := p.config.ProviderBuilder.BuildEngineDataSourceFactory(p.config.Event) + pubSubDataSource, err := p.config.ProviderBuilder.BuildEngineDataSourceFactory(p.config.Event, p.config.Providers) if err != nil { p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to build data source: %w", err)) return resolve.FetchConfiguration{} @@ -93,7 +93,7 @@ func (p *Planner[PB, P, E]) ConfigureSubscription() plan.SubscriptionConfigurati return plan.SubscriptionConfiguration{} } - pubSubDataSource, err := p.config.ProviderBuilder.BuildEngineDataSourceFactory(p.config.Event) + pubSubDataSource, err := p.config.ProviderBuilder.BuildEngineDataSourceFactory(p.config.Event, p.config.Providers) if err != nil { p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source subscription: %w", err)) return plan.SubscriptionConfiguration{} @@ -109,7 +109,7 @@ func (p *Planner[PB, P, E]) ConfigureSubscription() plan.SubscriptionConfigurati p.visitor.Walker.StopWithInternalErr(fmt.Errorf("failed to get resolve data source subscription: %w", err)) return plan.SubscriptionConfiguration{} } - 
dataSource.SetSubscriptionOnStartFns(p.config.SubscriptionOnStartFns...) + dataSource.SetHooks(p.config.Hooks) input, err := pubSubDataSource.ResolveDataSourceSubscriptionInput() if err != nil { diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go index 33cac33782..57bbb70ed7 100644 --- a/router/pkg/pubsub/datasource/provider.go +++ b/router/pkg/pubsub/datasource/provider.go @@ -4,7 +4,6 @@ import ( "context" "github.com/wundergraph/cosmo/router/pkg/metric" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) type ArgumentTemplateCallback func(tpl string) (string, error) @@ -23,6 +22,7 @@ type Lifecycle interface { type Adapter interface { Lifecycle Subscribe(ctx context.Context, cfg SubscriptionEventConfiguration, updater SubscriptionEventUpdater) error + Publish(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error } // Provider is the interface that the PubSub provider must implement @@ -32,6 +32,8 @@ type Provider interface { ID() string // TypeID Get the provider type id (e.g. "kafka", "nats") TypeID() string + // SetHooks Set the hooks + SetHooks(Hooks) } // ProviderBuilder is the interface that the provider builder must implement. 
@@ -41,7 +43,7 @@ type ProviderBuilder[P, E any] interface { // BuildProvider Build the provider and the adapter BuildProvider(options P, providerOpts ProviderOpts) (Provider, error) // BuildEngineDataSourceFactory Build the data source for the given provider and event configuration - BuildEngineDataSourceFactory(data E) (EngineDataSourceFactory, error) + BuildEngineDataSourceFactory(data E, providers map[string]Provider) (EngineDataSourceFactory, error) } // ProviderType represents the type of pubsub provider @@ -58,10 +60,9 @@ const ( // there could be other common fields in the future, but for now we only have data type StreamEvent interface { GetData() []byte + Clone() StreamEvent } -type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration) error - // SubscriptionEventConfiguration is the interface that all subscription event configurations must implement type SubscriptionEventConfiguration interface { ProviderID() string diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index 84561b06db..e234ebfb73 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -11,6 +11,21 @@ type PubSubProvider struct { typeID string Adapter Adapter Logger *zap.Logger + hooks Hooks +} + +// applyPublishEventHooks processes events through a chain of hook functions +// Each hook receives the result from the previous hook, creating a proper middleware pipeline +func applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, hooks []OnPublishEventsFn) ([]StreamEvent, error) { + currentEvents := events + for _, hook := range hooks { + var err error + currentEvents, err = hook(ctx, cfg, currentEvents) + if err != nil { + return currentEvents, err + } + } + return currentEvents, nil } func (p *PubSubProvider) ID() string { @@ -39,6 +54,34 @@ func (p *PubSubProvider) Subscribe(ctx 
context.Context, cfg SubscriptionEventCon return p.Adapter.Subscribe(ctx, cfg, updater) } +func (p *PubSubProvider) Publish(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error { + if len(p.hooks.OnPublishEvents) == 0 { + return p.Adapter.Publish(ctx, cfg, events) + } + + processedEvents, hooksErr := applyPublishEventHooks(ctx, cfg, events, p.hooks.OnPublishEvents) + if hooksErr != nil { + p.Logger.Error( + "error applying publish event hooks", + zap.Error(hooksErr), + zap.String("provider_id", cfg.ProviderID()), + zap.String("provider_type_id", string(cfg.ProviderType())), + zap.String("field_name", cfg.RootFieldName()), + ) + } + + errPublish := p.Adapter.Publish(ctx, cfg, processedEvents) + if errPublish != nil { + return errPublish + } + + return hooksErr +} + +func (p *PubSubProvider) SetHooks(hooks Hooks) { + p.hooks = hooks +} + func NewPubSubProvider(id string, typeID string, adapter Adapter, logger *zap.Logger) *PubSubProvider { return &PubSubProvider{ id: id, diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 134bfbd6bb..6ef41c56a5 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -1,14 +1,67 @@ package datasource import ( + "bytes" "context" "errors" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "go.uber.org/zap" ) +// Test helper types +type testEvent struct { + data []byte +} + +func (e *testEvent) GetData() []byte { + return e.data +} + +func (e *testEvent) Clone() StreamEvent { + return &testEvent{ + data: bytes.Clone(e.data), + } +} + +type testSubscriptionConfig struct { + providerID string + providerType ProviderType + fieldName string +} + +func (c *testSubscriptionConfig) ProviderID() string { + return c.providerID +} + +func (c *testSubscriptionConfig) ProviderType() ProviderType { + return c.providerType +} + +func (c 
*testSubscriptionConfig) RootFieldName() string { + return c.fieldName +} + +type testPublishConfig struct { + providerID string + providerType ProviderType + fieldName string +} + +func (c *testPublishConfig) ProviderID() string { + return c.providerID +} + +func (c *testPublishConfig) ProviderType() ProviderType { + return c.providerType +} + +func (c *testPublishConfig) RootFieldName() string { + return c.fieldName +} + func TestProvider_Startup_Success(t *testing.T) { mockAdapter := NewMockProvider(t) mockAdapter.On("Startup", mock.Anything).Return(nil) @@ -57,18 +110,375 @@ func TestProvider_Shutdown_Error(t *testing.T) { assert.Error(t, err) } -func TestProvider_ID(t *testing.T) { - const testID = "test-id" +func TestProvider_Subscribe_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + mockUpdater := NewMockSubscriptionEventUpdater(t) + config := &testSubscriptionConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + mockAdapter.On("Subscribe", mock.Anything, config, mockUpdater).Return(nil) + provider := PubSubProvider{ - id: testID, + Adapter: mockAdapter, } - assert.Equal(t, testID, provider.ID()) + err := provider.Subscribe(context.Background(), config, mockUpdater) + + assert.NoError(t, err) } -func TestProvider_TypeID(t *testing.T) { - const providerTypeID = "test-type-id" +func TestProvider_Subscribe_Error(t *testing.T) { + mockAdapter := NewMockProvider(t) + mockUpdater := NewMockSubscriptionEventUpdater(t) + config := &testSubscriptionConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + expectedError := errors.New("subscription error") + + mockAdapter.On("Subscribe", mock.Anything, config, mockUpdater).Return(expectedError) + provider := PubSubProvider{ - typeID: providerTypeID, + Adapter: mockAdapter, } - assert.Equal(t, providerTypeID, provider.TypeID()) + err := provider.Subscribe(context.Background(), config, mockUpdater) + + 
assert.Error(t, err) + assert.Equal(t, expectedError, err) +} + +func TestProvider_Publish_NoHooks_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data 1")}, + &testEvent{data: []byte("test data 2")}, + } + + mockAdapter.On("Publish", mock.Anything, config, events).Return(nil) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{}, // No hooks + } + err := provider.Publish(context.Background(), config, events) + + assert.NoError(t, err) +} + +func TestProvider_Publish_NoHooks_Error(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + expectedError := errors.New("publish error") + + mockAdapter.On("Publish", mock.Anything, config, events).Return(expectedError) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{}, // No hooks + } + err := provider.Publish(context.Background(), config, events) + + assert.Error(t, err) + assert.Equal(t, expectedError, err) +} + +func TestProvider_Publish_WithHooks_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original data")}, + } + modifiedEvents := []StreamEvent{ + &testEvent{data: []byte("modified data")}, + } + + // Define hook that modifies events + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return modifiedEvents, nil + } + + mockAdapter.On("Publish", mock.Anything, config, modifiedEvents).Return(nil) + + provider := 
PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{testHook}, + }, + } + err := provider.Publish(context.Background(), config, originalEvents) + + assert.NoError(t, err) +} + +func TestProvider_Publish_WithHooks_HookError(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + hookError := errors.New("hook processing error") + + // Define hook that returns an error + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + mockAdapter.On("Publish", mock.Anything, config, []StreamEvent(nil)).Return(nil) + + // Should call Publish on adapter also if hook fails + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{testHook}, + }, + Logger: zap.NewNop(), + } + err := provider.Publish(context.Background(), config, events) + + assert.Error(t, err) + assert.Equal(t, hookError, err) +} + +func TestProvider_Publish_WithHooks_AdapterError(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original data")}, + } + processedEvents := []StreamEvent{ + &testEvent{data: []byte("processed data")}, + } + adapterError := errors.New("adapter publish error") + + // Define hook that processes events successfully + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return processedEvents, nil + } + + mockAdapter.On("Publish", mock.Anything, config, processedEvents).Return(adapterError) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: 
Hooks{ + OnPublishEvents: []OnPublishEventsFn{testHook}, + }, + } + err := provider.Publish(context.Background(), config, originalEvents) + + assert.Error(t, err) + assert.Equal(t, adapterError, err) +} + +func TestProvider_Publish_WithMultipleHooks_Success(t *testing.T) { + mockAdapter := NewMockProvider(t) + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + + // Chain of hooks that modify the data + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, nil + } + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("modified by hook2")}}, nil + } + + mockAdapter.On("Publish", mock.Anything, config, mock.MatchedBy(func(events []StreamEvent) bool { + return len(events) == 1 && string(events[0].GetData()) == "modified by hook2" + })).Return(nil) + + provider := PubSubProvider{ + Adapter: mockAdapter, + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook1, hook2}, + }, + } + err := provider.Publish(context.Background(), config, originalEvents) + + assert.NoError(t, err) +} + +func TestProvider_SetHooks(t *testing.T) { + provider := &PubSubProvider{} + + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return events, nil + } + + hooks := Hooks{ + OnPublishEvents: []OnPublishEventsFn{testHook}, + } + + provider.SetHooks(hooks) + + assert.Equal(t, hooks, provider.hooks) +} + +func TestNewPubSubProvider(t *testing.T) { + mockAdapter := NewMockProvider(t) + logger := zap.NewNop() + id := "test-provider-id" + typeID := "test-type-id" + + provider := NewPubSubProvider(id, typeID, mockAdapter, logger) + + 
assert.NotNil(t, provider) + assert.Equal(t, id, provider.ID()) + assert.Equal(t, typeID, provider.TypeID()) + assert.Equal(t, mockAdapter, provider.Adapter) + assert.Equal(t, logger, provider.Logger) + assert.Empty(t, provider.hooks.OnPublishEvents) +} + +func TestApplyPublishEventHooks_NoHooks(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{}) + + assert.NoError(t, err) + assert.Equal(t, originalEvents, result) +} + +func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + modifiedEvents := []StreamEvent{ + &testEvent{data: []byte("modified")}, + } + + hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return modifiedEvents, nil + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook}) + + assert.NoError(t, err) + assert.Equal(t, modifiedEvents, result) +} + +func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + hookError := errors.New("hook processing failed") + + hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, 
[]OnPublishEventsFn{hook}) + + assert.Error(t, err) + assert.Equal(t, hookError, err) + assert.Nil(t, result) +} + +func TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + } + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("step2")}}, nil + } + hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("final")}}, nil + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) + + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "final", string(result[0].GetData())) +} + +func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + middleHookError := errors.New("middle hook failed") + + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + } + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, middleHookError + } + hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) 
([]StreamEvent, error) { + return []StreamEvent{&testEvent{data: []byte("final")}}, nil + } + + result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) + + assert.Error(t, err) + assert.Equal(t, middleHookError, err) + assert.Nil(t, result) } diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index e5c9c26ab6..16ec03171a 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -6,6 +6,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error @@ -13,9 +14,10 @@ type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Dige // PubSubSubscriptionDataSource is a data source for handling subscriptions using a Pub/Sub mechanism. 
// It implements the SubscriptionDataSource interface and HookableSubscriptionDataSource type PubSubSubscriptionDataSource[C SubscriptionEventConfiguration] struct { - pubSub Adapter - uniqueRequestID uniqueRequestIdFn - subscriptionOnStartFns []SubscriptionOnStartFn + pubSub Adapter + uniqueRequestID uniqueRequestIdFn + hooks Hooks + logger *zap.Logger } func (s *PubSubSubscriptionDataSource[C]) SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) { @@ -39,11 +41,11 @@ func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []by return errors.New("invalid subscription configuration") } - return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(updater)) + return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(conf, s.hooks, updater, s.logger)) } func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) { - for _, fn := range s.subscriptionOnStartFns { + for _, fn := range s.hooks.SubscriptionOnStart { conf, errConf := s.SubscriptionEventConfiguration(input) if errConf != nil { return err @@ -57,16 +59,20 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.Startu return nil } -func (s *PubSubSubscriptionDataSource[C]) SetSubscriptionOnStartFns(fns ...SubscriptionOnStartFn) { - s.subscriptionOnStartFns = append(s.subscriptionOnStartFns, fns...) 
+func (s *PubSubSubscriptionDataSource[C]) SetHooks(hooks Hooks) { + s.hooks = hooks } var _ SubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) var _ resolve.HookableSubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) -func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Adapter, uniqueRequestIdFn uniqueRequestIdFn) *PubSubSubscriptionDataSource[C] { +func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Adapter, uniqueRequestIdFn uniqueRequestIdFn, logger *zap.Logger) *PubSubSubscriptionDataSource[C] { + if logger == nil { + logger = zap.NewNop() + } return &PubSubSubscriptionDataSource[C]{ pubSub: pubSub, uniqueRequestID: uniqueRequestIdFn, + logger: logger, } } diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index a9170d7edd..c82b339faa 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) // testSubscriptionEventConfiguration implements SubscriptionEventConfiguration for testing @@ -36,7 +37,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_Success(t * return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -61,7 +62,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_InvalidJSON return nil } - dataSource := 
NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) invalidInput := []byte(`{"invalid": json}`) result, err := dataSource.SubscriptionEventConfiguration(invalidInput) @@ -75,7 +76,7 @@ func TestPubSubSubscriptionDataSource_UniqueRequestID_Success(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) ctx := &resolve.Context{} input := []byte(`{"test": "data"}`) @@ -92,7 +93,7 @@ func TestPubSubSubscriptionDataSource_UniqueRequestID_Error(t *testing.T) { return expectedError } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) ctx := &resolve.Context{} input := []byte(`{"test": "data"}`) @@ -109,7 +110,7 @@ func TestPubSubSubscriptionDataSource_Start_Success(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -134,7 +135,7 @@ func TestPubSubSubscriptionDataSource_Start_NoConfiguration(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) invalidInput := []byte(`{"invalid": 
json}`) ctx := resolve.NewContext(context.Background()) @@ -151,7 +152,7 @@ func TestPubSubSubscriptionDataSource_Start_SubscribeError(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -178,7 +179,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_Success(t *testing.T) return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -202,7 +203,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) // Add subscription start hooks hook1Called := false @@ -218,7 +219,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T return nil } - dataSource.SetSubscriptionOnStartFns(hook1, hook2) + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook1, hook2}, + }) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -238,13 +241,46 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T assert.True(t, hook2Called) } +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsClose(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh 
*xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + + // Add hook that returns close=true + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + return nil + } + + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + }) + + testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + errSubStart := dataSource.SubscriptionOnStart(ctx, input) + assert.NoError(t, errSubStart) +} + func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *testing.T) { mockAdapter := NewMockProvider(t) uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) expectedError := errors.New("hook error") // Add hook that returns an error @@ -252,7 +288,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *te return expectedError } - dataSource.SetSubscriptionOnStartFns(hook) + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + }) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -277,10 +315,10 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := 
NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) // Initially should have no hooks - assert.Len(t, dataSource.subscriptionOnStartFns, 0) + assert.Len(t, dataSource.hooks.SubscriptionOnStart, 0) // Add hooks hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { @@ -290,11 +328,15 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { return nil } - dataSource.SetSubscriptionOnStartFns(hook1) - assert.Len(t, dataSource.subscriptionOnStartFns, 1) + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook1}, + }) + assert.Len(t, dataSource.hooks.SubscriptionOnStart, 1) - dataSource.SetSubscriptionOnStartFns(hook2) - assert.Len(t, dataSource.subscriptionOnStartFns, 2) + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook2}, + }) + assert.Len(t, dataSource.hooks.SubscriptionOnStart, 1) } func TestNewPubSubSubscriptionDataSource(t *testing.T) { @@ -303,12 +345,12 @@ func TestNewPubSubSubscriptionDataSource(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) assert.NotNil(t, dataSource) assert.Equal(t, mockAdapter, dataSource.pubSub) assert.NotNil(t, dataSource.uniqueRequestID) - assert.Empty(t, dataSource.subscriptionOnStartFns) + assert.Empty(t, dataSource.hooks.SubscriptionOnStart) } func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { @@ -317,7 +359,7 @@ func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, 
uniqueRequestIDFn, zap.NewNop()) // Test that it implements SubscriptionDataSource interface var _ SubscriptionDataSource = dataSource diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 9332d10f7a..95289bb313 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -1,34 +1,112 @@ package datasource -import "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" +import ( + "context" + + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" +) // SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface // that provides a way to send the event struct instead of the raw data // It is used to give access to the event additional fields to the hooks. type SubscriptionEventUpdater interface { - Update(event StreamEvent) + Update(events []StreamEvent) Complete() Close(kind resolve.SubscriptionCloseKind) + SetHooks(hooks Hooks) } type subscriptionEventUpdater struct { - eventUpdater resolve.SubscriptionUpdater + eventUpdater resolve.SubscriptionUpdater + subscriptionEventConfiguration SubscriptionEventConfiguration + hooks Hooks + logger *zap.Logger +} + +func (s *subscriptionEventUpdater) Update(events []StreamEvent) { + if len(s.hooks.OnReceiveEvents) == 0 { + for _, event := range events { + s.eventUpdater.Update(event.GetData()) + } + return + } + // If there are hooks, we should apply them separated for each subscription + for ctx, subId := range s.eventUpdater.Subscriptions() { + processedEvents, err := applyStreamEventHooks( + ctx, + s.subscriptionEventConfiguration, + events, + s.hooks.OnReceiveEvents, + ) + // updates the events even if the hooks fail + // if a hook doesn't want to send the events, it should return no events! 
+ for _, event := range processedEvents { + s.eventUpdater.UpdateSubscription(subId, event.GetData()) + } + if err != nil { + // For all errors, log them + if s.logger != nil { + s.logger.Error( + "An error occurred while processing stream events hooks", + zap.Error(err), + zap.String("provider_type", string(s.subscriptionEventConfiguration.ProviderType())), + zap.String("provider_id", s.subscriptionEventConfiguration.ProviderID()), + zap.String("field_name", s.subscriptionEventConfiguration.RootFieldName()), + ) + } + // Always close the subscription when a hook reports an error to avoid inconsistent state. + s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindNormal, subId) + } + } } -func (h *subscriptionEventUpdater) Update(event StreamEvent) { - h.eventUpdater.Update(event.GetData()) +func (s *subscriptionEventUpdater) Complete() { + s.eventUpdater.Complete() } -func (h *subscriptionEventUpdater) Complete() { - h.eventUpdater.Complete() +func (s *subscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { + s.eventUpdater.Close(kind) } -func (h *subscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { - h.eventUpdater.Close(kind) +func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { + s.hooks = hooks +} + +// applyStreamEventHooks processes events through a chain of hook functions +// Each hook receives the result from the previous hook, creating a proper middleware pipeline +func applyStreamEventHooks( + ctx context.Context, + cfg SubscriptionEventConfiguration, + events []StreamEvent, + hooks []OnReceiveEventsFn) ([]StreamEvent, error) { + // Copy the events to avoid modifying the original slice + currentEvents := make([]StreamEvent, len(events)) + for i, event := range events { + currentEvents[i] = event.Clone() + } + // Apply each hook in sequence, passing the result of one as the input to the next + // If any hook returns an error, stop processing and return the error + for _, hook := range hooks { + var err error + 
currentEvents, err = hook(ctx, cfg, currentEvents) + if err != nil { + return currentEvents, err + } + } + return currentEvents, nil } -func NewSubscriptionEventUpdater(eventUpdater resolve.SubscriptionUpdater) SubscriptionEventUpdater { +func NewSubscriptionEventUpdater( + cfg SubscriptionEventConfiguration, + hooks Hooks, + eventUpdater resolve.SubscriptionUpdater, + logger *zap.Logger, +) SubscriptionEventUpdater { return &subscriptionEventUpdater{ - eventUpdater: eventUpdater, + subscriptionEventConfiguration: cfg, + hooks: hooks, + eventUpdater: eventUpdater, + logger: logger, } } diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go new file mode 100644 index 0000000000..79fd140a51 --- /dev/null +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -0,0 +1,627 @@ +package datasource + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" + "go.uber.org/zap/zaptest/observer" +) + +// Test helper type for subscription event configuration +type testSubscriptionEventConfig struct { + providerID string + providerType ProviderType + fieldName string +} + +func (c *testSubscriptionEventConfig) ProviderID() string { + return c.providerID +} + +func (c *testSubscriptionEventConfig) ProviderType() ProviderType { + return c.providerType +} + +func (c *testSubscriptionEventConfig) RootFieldName() string { + return c.fieldName +} + +type receivedHooksArgs struct { + events []StreamEvent + cfg SubscriptionEventConfiguration +} + +func TestSubscriptionEventUpdater_Update_NoHooks(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data 1")}, 
+ &testEvent{data: []byte("test data 2")}, + } + + // Expect calls to Update for each event + mockUpdater.On("Update", []byte("test data 1")).Return() + mockUpdater.On("Update", []byte("test data 2")).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, // No hooks + } + + updater.Update(events) +} + +func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original data")}, + } + modifiedEvents := []StreamEvent{ + &testEvent{data: []byte("modified data")}, + } + + // Create wrapper function for the mock + receivedArgs := make(chan receivedHooksArgs, 1) + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs <- receivedHooksArgs{events: events, cfg: cfg} + return modifiedEvents, nil + } + + // Expect call to UpdateSubscription with modified data + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("UpdateSubscription", subId, []byte("modified data")).Return() + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, + } + + updater.Update(originalEvents) + + select { + case receivedArgs := <-receivedArgs: + assert.Equal(t, originalEvents, receivedArgs.events) + assert.Equal(t, config, receivedArgs.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } +} + +func 
TestSubscriptionEventUpdater_UpdateSubscriptions_WithHooks_Error(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + hookError := errors.New("hook processing error") + + // Define hook that returns an error + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + // Expect call to UpdateSubscription with modified data + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + // Should not call Update or UpdateSubscription on eventUpdater since hook fails + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, + } + + updater.Update(events) + + // Assert that Update and UpdateSubscription were not called on the eventUpdater + mockUpdater.AssertNotCalled(t, "Update") + mockUpdater.AssertNotCalled(t, "UpdateSubscription") + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) +} + +func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + + // Chain of hooks that modify the data + receivedArgs1 := make(chan receivedHooksArgs, 1) + hook1 := func(ctx context.Context, 
cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, nil + } + + receivedArgs2 := make(chan receivedHooksArgs, 1) + hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("modified by hook2")}}, nil + } + + // Expect call to UpdateSubscription with modified data + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("UpdateSubscription", subId, []byte("modified by hook2")).Return() + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{hook1, hook2}, + }, + } + + updater.Update(originalEvents) + + select { + case receivedArgs1 := <-receivedArgs1: + assert.Equal(t, originalEvents, receivedArgs1.events) + assert.Equal(t, config, receivedArgs1.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + select { + case receivedArgs2 := <-receivedArgs2: + assert.Equal(t, []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, receivedArgs2.events) + assert.Equal(t, config, receivedArgs2.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } +} + +func TestSubscriptionEventUpdater_Complete(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + mockUpdater.On("Complete").Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + 
subscriptionEventConfiguration: config, + hooks: Hooks{}, + } + + updater.Complete() +} + +func TestSubscriptionEventUpdater_Close(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + closeKind := resolve.SubscriptionCloseKindNormal + + mockUpdater.On("Close", closeKind).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, + } + + updater.Close(closeKind) +} + +func TestSubscriptionEventUpdater_SetHooks(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return events, nil + } + + hooks := Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + } + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, + } + + updater.SetHooks(hooks) + + assert.Equal(t, hooks, updater.hooks) +} + +func TestNewSubscriptionEventUpdater(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return events, nil + } + + hooks := Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + } + + updater := NewSubscriptionEventUpdater(config, hooks, mockUpdater, zap.NewNop()) + + assert.NotNil(t, updater) + + // Type assertion to access private fields for testing + var concreteUpdater *subscriptionEventUpdater + assert.IsType(t, concreteUpdater, 
updater) + concreteUpdater = updater.(*subscriptionEventUpdater) + assert.Equal(t, config, concreteUpdater.subscriptionEventConfiguration) + assert.Equal(t, hooks, concreteUpdater.hooks) + assert.Equal(t, mockUpdater, concreteUpdater.eventUpdater) +} + +func TestApplyStreamEventHooks_NoHooks(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{}) + + assert.NoError(t, err) + assert.Equal(t, originalEvents, result) +} + +func TestApplyStreamEventHooks_SingleHook_Success(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + modifiedEvents := []StreamEvent{ + &testEvent{data: []byte("modified")}, + } + + hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return modifiedEvents, nil + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + + assert.NoError(t, err) + assert.Equal(t, modifiedEvents, result) +} + +func TestApplyStreamEventHooks_SingleHook_Error(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + hookError := errors.New("hook processing failed") + + hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return nil, hookError + } + + result, err := applyStreamEventHooks(ctx, 
config, originalEvents, []OnReceiveEventsFn{hook}) + + assert.Error(t, err) + assert.Equal(t, hookError, err) + assert.Nil(t, result) +} + +func TestApplyStreamEventHooks_MultipleHooks_Success(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + + receivedArgs1 := make(chan receivedHooksArgs, 1) + hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + } + receivedArgs2 := make(chan receivedHooksArgs, 1) + hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("step2")}}, nil + } + receivedArgs3 := make(chan receivedHooksArgs, 1) + hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs3 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("final")}}, nil + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + + select { + case receivedArgs1 := <-receivedArgs1: + assert.Equal(t, originalEvents, receivedArgs1.events) + assert.Equal(t, config, receivedArgs1.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + select { + case receivedArgs2 := <-receivedArgs2: + assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step1")}}, receivedArgs2.events) + assert.Equal(t, config, receivedArgs2.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + select { + case 
receivedArgs3 := <-receivedArgs3: + assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step2")}}, receivedArgs3.events) + assert.Equal(t, config, receivedArgs3.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + assert.NoError(t, err) + assert.Len(t, result, 1) + assert.Equal(t, "final", string(result[0].GetData())) +} + +func TestApplyStreamEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { + ctx := context.Background() + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{data: []byte("original")}, + } + middleHookError := errors.New("middle hook failed") + + receivedArgs1 := make(chan receivedHooksArgs, 1) + hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + } + receivedArgs2 := make(chan receivedHooksArgs, 1) + hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} + return nil, middleHookError + } + receivedArgs3 := make(chan receivedHooksArgs, 1) + hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + receivedArgs3 <- receivedHooksArgs{events: events, cfg: cfg} + return []StreamEvent{&testEvent{data: []byte("final")}}, nil + } + + result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + + assert.Error(t, err) + assert.Equal(t, middleHookError, err) + assert.Nil(t, result) + + select { + case receivedArgs1 := <-receivedArgs1: + assert.Equal(t, originalEvents, receivedArgs1.events) + assert.Equal(t, config, receivedArgs1.cfg) + case <-time.After(1 * 
time.Second): + t.Fatal("timeout waiting for events") + } + + select { + case receivedArgs2 := <-receivedArgs2: + assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step1")}}, receivedArgs2.events) + assert.Equal(t, config, receivedArgs2.cfg) + case <-time.After(1 * time.Second): + t.Fatal("timeout waiting for events") + } + + assert.Empty(t, receivedArgs3) +} + +// Test the updateEvents method indirectly through Update method +func TestSubscriptionEventUpdater_UpdateEvents_EmptyEvents(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{} // Empty events + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, // No hooks + } + + updater.Update(events) + + // No calls to Update should be made for empty events + mockUpdater.AssertNotCalled(t, "Update") +} + +func TestSubscriptionEventUpdater_Close_WithDifferentCloseKinds(t *testing.T) { + testCases := []struct { + name string + closeKind resolve.SubscriptionCloseKind + }{ + {"Normal", resolve.SubscriptionCloseKindNormal}, + {"DownstreamServiceError", resolve.SubscriptionCloseKindDownstreamServiceError}, + {"GoingAway", resolve.SubscriptionCloseKindGoingAway}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + + mockUpdater.On("Close", tc.closeKind).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, + } + + updater.Close(tc.closeKind) + }) + } +} + +func TestSubscriptionEventUpdater_UpdateSubscription_WithHookError_ClosesSubscription(t *testing.T) { + testCases := []struct { 
+ name string + hookError error + }{ + { + name: "generic error", + hookError: errors.New("subscription should close"), + }, + { + name: "error implementing CloseSubscription false", + hookError: errors.New("subscription should still close"), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + return events, tc.hookError + } + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, + } + + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("UpdateSubscription", subId, []byte("test data")).Return() + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + updater.Update(events) + + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) + }) + } +} + +func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWritesError(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{data: []byte("test data")}, + } + hookError := errors.New("hook processing error") + + // Define hook that returns an error + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) 
([]StreamEvent, error) { + return nil, hookError + } + + zCore, logObserver := observer.New(zap.InfoLevel) + logger := zap.New(zCore) + + // Test with a real zap logger to verify error logging behavior + // The logger.Error() call should be executed when an error occurs + updater := NewSubscriptionEventUpdater(config, Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, mockUpdater, logger) + + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + updater.Update(events) + + // Assert that Update was not called on the eventUpdater + mockUpdater.AssertNotCalled(t, "UpdateSubscription") + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) + + msgs := logObserver.FilterMessageSnippet("An error occurred while processing stream events hooks").TakeAll() + assert.Equal(t, 1, len(msgs)) +} diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index fa906370ab..7f61a242b9 100644 --- a/router/pkg/pubsub/kafka/adapter.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -13,6 +13,7 @@ import ( "github.com/twmb/franz-go/pkg/kerr" "github.com/twmb/franz-go/pkg/kgo" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -20,19 +21,14 @@ var ( errClientClosed = errors.New("client closed") ) +// Ensure ProviderAdapter implements Adapter +var _ datasource.Adapter = (*ProviderAdapter)(nil) + const ( kafkaReceive = "receive" kafkaProduce = "produce" ) -// Adapter defines the interface for Kafka adapter operations -type Adapter interface { - Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error - 
Publish(ctx context.Context, event PublishEventConfiguration) error - Startup(ctx context.Context) error - Shutdown(ctx context.Context) error -} - // ProviderAdapter is a Kafka pubsub implementation. // It uses the franz-go Kafka client to consume and produce messages. // The pubsub is stateless and does not store any messages. @@ -112,12 +108,11 @@ func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, u DestinationName: r.Topic, }) - updater.Update(&Event{ + updater.Update([]datasource.StreamEvent{&Event{ Data: r.Value, Headers: headers, Key: r.Key, - }) - + }}) } } } @@ -132,7 +127,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri } log := p.logger.With( - zap.String("provider_id", subConf.ProviderID()), + zap.String("provider_id", conf.ProviderID()), zap.String("method", "subscribe"), zap.Strings("topics", subConf.Topics), ) @@ -159,15 +154,24 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri go func() { - defer p.closeWg.Done() + defer func() { + client.Close() + updater.Close(resolve.SubscriptionCloseKindNormal) + p.closeWg.Done() + }() err := p.topicPoller(ctx, client, updater, PollerOpts{providerId: conf.ProviderID()}) if err != nil { if errors.Is(err, errClientClosed) || errors.Is(err, context.Canceled) { log.Debug("poller canceled", zap.Error(err)) } else { - log.Error("poller error", zap.Error(err)) - + log.Error( + "poller error", + zap.Error(err), + zap.String("provider_id", conf.ProviderID()), + zap.String("provider_type", string(conf.ProviderType())), + zap.String("field_name", conf.RootFieldName()), + ) } return } @@ -176,67 +180,85 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri return nil } -// Publish publishes the given event to the Kafka topic in a non-blocking way. +// Publish publishes the given events to the Kafka topic in a non-blocking way. // Publish errors are logged and returned as a pubsub error. 
-// The event is written with a dedicated write client. -func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { +// The events are written with a dedicated write client. +func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { + pubConf, ok := conf.(*PublishEventConfiguration) + if !ok { + return datasource.NewError("invalid event type for Kafka adapter", nil) + } + log := p.logger.With( - zap.String("provider_id", event.ProviderID()), + zap.String("provider_id", conf.ProviderID()), zap.String("method", "publish"), - zap.String("topic", event.Topic), + zap.String("topic", pubConf.Topic), ) if p.writeClient == nil { return datasource.NewError("kafka write client not initialized", nil) } - log.Debug("publish", zap.ByteString("data", event.Event.Data)) + if len(events) == 0 { + return nil + } + + log.Debug("publish", zap.Int("event_count", len(events))) var wg sync.WaitGroup - wg.Add(1) + wg.Add(len(events)) var pErr error + var errMutex sync.Mutex - headers := make([]kgo.RecordHeader, 0, len(event.Event.Headers)) - for key, value := range event.Event.Headers { - headers = append(headers, kgo.RecordHeader{ - Key: key, - Value: value, - }) - } + for _, streamEvent := range events { + kafkaEvent, ok := streamEvent.(*Event) + if !ok { + return datasource.NewError("invalid event type for Kafka adapter", nil) + } - p.writeClient.Produce(ctx, &kgo.Record{ - Key: event.Event.Key, - Topic: event.Topic, - Value: event.Event.Data, - Headers: headers, - }, func(record *kgo.Record, err error) { - defer wg.Done() - if err != nil { - pErr = err + headers := make([]kgo.RecordHeader, 0, len(kafkaEvent.Headers)) + for key, value := range kafkaEvent.Headers { + headers = append(headers, kgo.RecordHeader{ + Key: key, + Value: value, + }) } - }) + + p.writeClient.Produce(ctx, &kgo.Record{ + Key: kafkaEvent.Key, + Topic: pubConf.Topic, + Value: kafkaEvent.Data, + 
Headers: headers, + }, func(record *kgo.Record, err error) { + defer wg.Done() + if err != nil { + errMutex.Lock() + pErr = err + errMutex.Unlock() + } + }) + } wg.Wait() if pErr != nil { log.Error("publish error", zap.Error(pErr)) - // failure emission: include error.type generic p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: pubConf.ProviderID(), StreamOperationName: kafkaProduce, ProviderType: metric.ProviderTypeKafka, ErrorType: "publish_error", - DestinationName: event.Topic, + DestinationName: pubConf.Topic, }) - return datasource.NewError(fmt.Sprintf("error publishing to Kafka topic %s", event.Topic), pErr) + return datasource.NewError(fmt.Sprintf("error publishing to Kafka topic %s", pubConf.Topic), pErr) } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: pubConf.ProviderID(), StreamOperationName: kafkaProduce, ProviderType: metric.ProviderTypeKafka, - DestinationName: event.Topic, + DestinationName: pubConf.Topic, }) return nil } diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 723c0d0bd0..00a38023ea 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -6,9 +6,13 @@ import ( "encoding/json" "fmt" "io" + "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) // Event represents an event from Kafka @@ -22,6 +26,18 @@ func (e *Event) GetData() []byte { return e.Data } +func (e *Event) Clone() datasource.StreamEvent { + e2 := *e + e2.Data = slices.Clone(e.Data) + e2.Headers = make(map[string][]byte, len(e.Headers)) + for k, v := range e.Headers { + e2.Headers[k] = slices.Clone(v) + } + return &e2 +} + +// 
SubscriptionEventConfiguration is a public type that is used to allow access to custom fields +// of the provider type SubscriptionEventConfiguration struct { Provider string `json:"providerId"` Topics []string `json:"topics"` @@ -43,13 +59,47 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } -type PublishEventConfiguration struct { +// publishData is a private type that is used to pass data from the engine to the provider +type publishData struct { Provider string `json:"providerId"` Topic string `json:"topic"` Event Event `json:"event"` FieldName string `json:"rootFieldName"` } +// PublishEventConfiguration returns the publish event configuration from the publishData type +func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { + return &PublishEventConfiguration{ + Provider: p.Provider, + Topic: p.Topic, + FieldName: p.FieldName, + } +} + +func (p *publishData) MarshalJSONTemplate() (string, error) { + // The content of the data field could be not valid JSON, so we can't use json.Marshal + // e.g. 
{"id":$$0$$,"update":$$1$$} + headers := p.Event.Headers + if headers == nil { + headers = make(map[string][]byte) + } + + headersBytes, err := json.Marshal(headers) + if err != nil { + return "", err + } + + return fmt.Sprintf(`{"topic":"%s", "event": {"data": %s, "key": "%s", "headers": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Topic, p.Event.Data, p.Event.Key, headersBytes, p.Provider, p.FieldName), nil +} + +// PublishEventConfiguration is a public type that is used to allow access to custom fields +// of the provider +type PublishEventConfiguration struct { + Provider string `json:"providerId"` + Topic string `json:"topic"` + FieldName string `json:"rootFieldName"` +} + // ProviderID returns the provider ID func (p *PublishEventConfiguration) ProviderID() string { return p.Provider @@ -65,38 +115,73 @@ func (p *PublishEventConfiguration) RootFieldName() string { return p.FieldName } -func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. 
{"id":$$0$$,"update":$$1$$} - headers := s.Event.Headers - if headers == nil { - headers = make(map[string][]byte) +type SubscriptionDataSource struct { + pubSub datasource.Adapter +} + +func (s *SubscriptionDataSource) SubscriptionEventConfiguration(input []byte) datasource.SubscriptionEventConfiguration { + var subscriptionConfiguration SubscriptionEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return nil } + return &subscriptionConfiguration +} - headersBytes, err := json.Marshal(headers) +func (s *SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "topics") if err != nil { - return "", err + return err } - return fmt.Sprintf(`{"topic":"%s", "event": {"data": %s, "key": "%s", "headers": %s}, "providerId":"%s"}`, s.Topic, s.Event.Data, s.Event.Key, headersBytes, s.ProviderID()), nil + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err +} + +func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater datasource.SubscriptionEventUpdater) error { + subConf := s.SubscriptionEventConfiguration(input) + if subConf == nil { + return fmt.Errorf("no subscription configuration found") + } + + conf, ok := subConf.(*SubscriptionEventConfiguration) + if !ok { + return fmt.Errorf("invalid subscription configuration") + } + + return s.pubSub.Subscribe(ctx.Context(), conf, updater) } type PublishDataSource struct { - pubSub Adapter + pubSub datasource.Adapter } func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var publishConfiguration PublishEventConfiguration - if err := json.Unmarshal(input, &publishConfiguration); err != nil { + var publishData publishData + if err := json.Unmarshal(input, &publishData); err != nil { return 
err } - if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error + _, errWrite := io.WriteString(out, `{"success": false}`) + return errWrite } - _, err := io.WriteString(out, `{"success": true}`) - return err + _, errWrite := io.WriteString(out, `{"success": true}`) + if errWrite != nil { + return errWrite + } + return nil } func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) (err error) { diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory.go b/router/pkg/pubsub/kafka/engine_datasource_factory.go index 30507bc13b..d89eb408b0 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory.go @@ -8,6 +8,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) type EventType int @@ -22,8 +23,9 @@ type EngineDataSourceFactory struct { eventType EventType topics []string providerId string + logger *zap.Logger - KafkaAdapter Adapter + KafkaAdapter datasource.Adapter } func (c *EngineDataSourceFactory) GetFieldName() string { @@ -50,7 +52,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri return "", fmt.Errorf("publish events should define one topic but received %d", len(c.topics)) } - evtCfg := PublishEventConfiguration{ + evtCfg := publishData{ Provider: c.providerId, Topic: c.topics[0], Event: Event{Data: eventData}, @@ -81,7 +83,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.Su _, err = 
xxh.Write(val) return err - }), nil + }, c.logger), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory_test.go b/router/pkg/pubsub/kafka/engine_datasource_factory_test.go index 0b4ea9c59c..5ceab4ae69 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory_test.go @@ -5,12 +5,14 @@ import ( "context" "encoding/json" "errors" + "strings" "testing" "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -33,11 +35,13 @@ func TestKafkaEngineDataSourceFactory(t *testing.T) { // TestEngineDataSourceFactoryWithMockAdapter tests the EngineDataSourceFactory with a mocked adapter func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Create mock adapter - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) // Configure mock expectations for Publish - mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishEventConfiguration) bool { return event.ProviderID() == "test-provider" && event.Topic == "test-topic" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"test":"data"}`) })).Return(nil) // Create the data source with mock adapter @@ -67,7 +71,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // TestEngineDataSourceFactory_GetResolveDataSource_WrongType tests the EngineDataSourceFactory with a mocked adapter func 
TestEngineDataSourceFactory_GetResolveDataSource_WrongType(t *testing.T) { // Create mock adapter - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) // Create the data source with mock adapter pubsub := &EngineDataSourceFactory{ @@ -171,7 +175,7 @@ func TestKafkaEngineDataSourceFactory_UniqueRequestID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { factory := &EngineDataSourceFactory{ - KafkaAdapter: NewMockAdapter(t), + KafkaAdapter: datasource.NewMockProvider(t), } source, err := factory.ResolveDataSourceSubscription() require.NoError(t, err) diff --git a/router/pkg/pubsub/kafka/engine_datasource_test.go b/router/pkg/pubsub/kafka/engine_datasource_test.go index eed485b246..846203d6e0 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_test.go @@ -5,54 +5,60 @@ import ( "context" "encoding/json" "errors" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) -func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { +func TestPublishData_MarshalJSONTemplate(t *testing.T) { tests := []struct { name string - config PublishEventConfiguration + config publishData wantPattern string }{ { name: "simple configuration", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider", Topic: "test-topic", Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + FieldName: "test-field", }, - wantPattern: `{"topic":"test-topic", "event": {"data": {"message":"hello"}, "key": "", "headers": {}}, "providerId":"test-provider"}`, + wantPattern: `{"topic":"test-topic", "event": {"data": {"message":"hello"}, "key": "", "headers": {}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, }, { name: "with special characters", - config: PublishEventConfiguration{ + config: 
publishData{ Provider: "test-provider-id", Topic: "topic-with-hyphens", Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + FieldName: "test-field", }, - wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}, "key": "", "headers": {}}, "providerId":"test-provider-id"}`, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}, "key": "", "headers": {}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, { name: "with key", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Topic: "topic-with-hyphens", Event: Event{Key: []byte("blablabla"), Data: json.RawMessage(`{}`)}, + FieldName: "test-field", }, - wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "blablabla", "headers": {}}, "providerId":"test-provider-id"}`, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "blablabla", "headers": {}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, { name: "with headers", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Topic: "topic-with-hyphens", Event: Event{Headers: map[string][]byte{"key": []byte(`blablabla`)}, Data: json.RawMessage(`{}`)}, + FieldName: "test-field", }, - wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "", "headers": {"key":"YmxhYmxhYmxh"}}, "providerId":"test-provider-id"}`, + wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "", "headers": {"key":"YmxhYmxhYmxh"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, } @@ -65,11 +71,27 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { } } +func TestPublishData_PublishEventConfiguration(t *testing.T) { + data := publishData{ + Provider: "test-provider", + Topic: "test-topic", + FieldName: "test-field", + } + + evtCfg := 
&PublishEventConfiguration{ + Provider: data.Provider, + Topic: data.Topic, + FieldName: data.FieldName, + } + + assert.Equal(t, evtCfg, data.PublishEventConfiguration()) +} + func TestKafkaPublishDataSource_Load(t *testing.T) { tests := []struct { name string input string - mockSetup func(*MockAdapter) + mockSetup func(*datasource.MockProvider) expectError bool expectedOutput string expectPublished bool @@ -77,11 +99,12 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { { name: "successful publish", input: `{"topic":"test-topic", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + mockSetup: func(m *datasource.MockProvider) { + m.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishEventConfiguration) bool { return event.ProviderID() == "test-provider" && - event.Topic == "test-topic" && - string(event.Event.Data) == `{"message":"hello"}` + event.Topic == "test-topic" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"message":"hello"}`) })).Return(nil) }, expectError: false, @@ -91,8 +114,8 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { { name: "publish error", input: `{"topic":"test-topic", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) + mockSetup: func(m *datasource.MockProvider) { + m.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("publish error")) }, expectError: false, // The Load method doesn't return the publish error directly expectedOutput: `{"success": false}`, @@ -101,7 +124,7 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { { name: "invalid input json", input: `{"invalid json":`, - mockSetup: func(m 
*MockAdapter) {}, + mockSetup: func(m *datasource.MockProvider) {}, expectError: true, expectPublished: false, }, @@ -109,7 +132,7 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) tt.mockSetup(mockAdapter) dataSource := &PublishDataSource{ @@ -134,7 +157,7 @@ func TestKafkaPublishDataSource_Load(t *testing.T) { func TestKafkaPublishDataSource_LoadWithFiles(t *testing.T) { t.Run("panic on not implemented", func(t *testing.T) { dataSource := &PublishDataSource{ - pubSub: NewMockAdapter(t), + pubSub: datasource.NewMockProvider(t), } assert.Panics(t, func() { diff --git a/router/pkg/pubsub/kafka/mocks.go b/router/pkg/pubsub/kafka/mocks.go deleted file mode 100644 index 08faa08eb2..0000000000 --- a/router/pkg/pubsub/kafka/mocks.go +++ /dev/null @@ -1,261 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package kafka - -import ( - "context" - - mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" -) - -// NewMockAdapter creates a new instance of MockAdapter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. 
-func NewMockAdapter(t interface { - mock.TestingT - Cleanup(func()) -}) *MockAdapter { - mock := &MockAdapter{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockAdapter is an autogenerated mock type for the Adapter type -type MockAdapter struct { - mock.Mock -} - -type MockAdapter_Expecter struct { - mock *mock.Mock -} - -func (_m *MockAdapter) EXPECT() *MockAdapter_Expecter { - return &MockAdapter_Expecter{mock: &_m.Mock} -} - -// Publish provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { - ret := _mock.Called(ctx, event) - - if len(ret) == 0 { - panic("no return value specified for Publish") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, PublishEventConfiguration) error); ok { - r0 = returnFunc(ctx, event) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Publish_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Publish' -type MockAdapter_Publish_Call struct { - *mock.Call -} - -// Publish is a helper method to define mock.On call -// - ctx context.Context -// - event PublishEventConfiguration -func (_e *MockAdapter_Expecter) Publish(ctx interface{}, event interface{}) *MockAdapter_Publish_Call { - return &MockAdapter_Publish_Call{Call: _e.mock.On("Publish", ctx, event)} -} - -func (_c *MockAdapter_Publish_Call) Run(run func(ctx context.Context, event PublishEventConfiguration)) *MockAdapter_Publish_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 PublishEventConfiguration - if args[1] != nil { - arg1 = args[1].(PublishEventConfiguration) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockAdapter_Publish_Call) Return(err error) *MockAdapter_Publish_Call { - _c.Call.Return(err) - return _c -} - -func (_c 
*MockAdapter_Publish_Call) RunAndReturn(run func(ctx context.Context, event PublishEventConfiguration) error) *MockAdapter_Publish_Call { - _c.Call.Return(run) - return _c -} - -// Shutdown provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Shutdown(ctx context.Context) error { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Shutdown") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown' -type MockAdapter_Shutdown_Call struct { - *mock.Call -} - -// Shutdown is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockAdapter_Expecter) Shutdown(ctx interface{}) *MockAdapter_Shutdown_Call { - return &MockAdapter_Shutdown_Call{Call: _e.mock.On("Shutdown", ctx)} -} - -func (_c *MockAdapter_Shutdown_Call) Run(run func(ctx context.Context)) *MockAdapter_Shutdown_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockAdapter_Shutdown_Call) Return(err error) *MockAdapter_Shutdown_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Shutdown_Call) RunAndReturn(run func(ctx context.Context) error) *MockAdapter_Shutdown_Call { - _c.Call.Return(run) - return _c -} - -// Startup provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Startup(ctx context.Context) error { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Startup") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Startup_Call is a *mock.Call that 
shadows Run/Return methods with type explicit version for method 'Startup' -type MockAdapter_Startup_Call struct { - *mock.Call -} - -// Startup is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockAdapter_Expecter) Startup(ctx interface{}) *MockAdapter_Startup_Call { - return &MockAdapter_Startup_Call{Call: _e.mock.On("Startup", ctx)} -} - -func (_c *MockAdapter_Startup_Call) Run(run func(ctx context.Context)) *MockAdapter_Startup_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockAdapter_Startup_Call) Return(err error) *MockAdapter_Startup_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) error) *MockAdapter_Startup_Call { - _c.Call.Return(run) - return _c -} - -// Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { - ret := _mock.Called(ctx, event, updater) - - if len(ret) == 0 { - panic("no return value specified for Subscribe") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { - r0 = returnFunc(ctx, event, updater) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Subscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe' -type MockAdapter_Subscribe_Call struct { - *mock.Call -} - -// Subscribe is a helper method to define mock.On call -// - ctx context.Context -// - event datasource.SubscriptionEventConfiguration -// - updater datasource.SubscriptionEventUpdater -func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) 
*MockAdapter_Subscribe_Call { - return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} -} - -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 datasource.SubscriptionEventConfiguration - if args[1] != nil { - arg1 = args[1].(datasource.SubscriptionEventConfiguration) - } - var arg2 datasource.SubscriptionEventUpdater - if args[2] != nil { - arg2 = args[2].(datasource.SubscriptionEventUpdater) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { - _c.Call.Return(run) - return _c -} diff --git a/router/pkg/pubsub/kafka/provider_builder.go b/router/pkg/pubsub/kafka/provider_builder.go index c88cf814c2..c69a458eba 100644 --- a/router/pkg/pubsub/kafka/provider_builder.go +++ b/router/pkg/pubsub/kafka/provider_builder.go @@ -23,16 +23,15 @@ type ProviderBuilder struct { logger *zap.Logger hostName string routerListenAddr string - adapters map[string]Adapter } func (p *ProviderBuilder) TypeID() string { return providerTypeID } -func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.KafkaEventConfiguration) (datasource.EngineDataSourceFactory, error) { +func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.KafkaEventConfiguration, providers map[string]datasource.Provider) (datasource.EngineDataSourceFactory, error) { providerId := data.GetEngineEventConfiguration().GetProviderId() - adapter, 
ok := p.adapters[providerId] + provider, ok := providers[providerId] if !ok { return nil, fmt.Errorf("failed to get adapter for provider %s with ID %s", p.TypeID(), providerId) } @@ -52,18 +51,17 @@ func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.KafkaEventCo eventType: eventType, topics: data.GetTopics(), providerId: providerId, - KafkaAdapter: adapter, + KafkaAdapter: provider, + logger: p.logger, }, nil } func (p *ProviderBuilder) BuildProvider(provider config.KafkaEventSource, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { - adapter, pubSubProvider, err := buildProvider(p.ctx, provider, p.logger, providerOpts) + pubSubProvider, err := buildProvider(p.ctx, provider, p.logger, providerOpts) if err != nil { return nil, err } - p.adapters[provider.ID] = adapter - return pubSubProvider, nil } @@ -150,18 +148,18 @@ func buildKafkaOptions(eventSource config.KafkaEventSource, logger *zap.Logger) return opts, nil } -func buildProvider(ctx context.Context, provider config.KafkaEventSource, logger *zap.Logger, providerOpts datasource.ProviderOpts) (Adapter, datasource.Provider, error) { +func buildProvider(ctx context.Context, provider config.KafkaEventSource, logger *zap.Logger, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { kafkaOpts, err := buildKafkaOptions(provider, logger) if err != nil { - return nil, nil, fmt.Errorf("failed to build options for Kafka provider with ID \"%s\": %w", provider.ID, err) + return nil, fmt.Errorf("failed to build options for Kafka provider with ID \"%s\": %w", provider.ID, err) } adapter, err := NewProviderAdapter(ctx, logger, kafkaOpts, providerOpts) if err != nil { - return nil, nil, fmt.Errorf("failed to create adapter for Kafka provider with ID \"%s\": %w", provider.ID, err) + return nil, fmt.Errorf("failed to create adapter for Kafka provider with ID \"%s\": %w", provider.ID, err) } pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, 
logger) - return adapter, pubSubProvider, nil + return pubSubProvider, nil } func NewProviderBuilder( @@ -175,6 +173,5 @@ func NewProviderBuilder( logger: logger, hostName: hostName, routerListenAddr: routerListenAddr, - adapters: make(map[string]Adapter), } } diff --git a/router/pkg/pubsub/nats/adapter.go b/router/pkg/pubsub/nats/adapter.go index dcba74a03b..def1d19f81 100644 --- a/router/pkg/pubsub/nats/adapter.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -25,18 +25,14 @@ const ( // Adapter defines the methods that a NATS adapter should implement type Adapter interface { - // Subscribe subscribes to the given events and sends updates to the updater - Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error - // Publish publishes the given event to the specified subject - Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error + datasource.Adapter // Request sends a request to the specified subject and writes the response to the given writer - Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error - // Startup initializes the adapter - Startup(ctx context.Context) error - // Shutdown gracefully shuts down the adapter - Shutdown(ctx context.Context) error + Request(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer) error } +// Ensure ProviderAdapter implements ProviderSubscriptionHooks +var _ datasource.Adapter = (*ProviderAdapter)(nil) + // ProviderAdapter implements the AdapterInterface for NATS pub/sub type ProviderAdapter struct { ctx context.Context @@ -80,11 +76,12 @@ func (p *ProviderAdapter) getDurableConsumerName(durableName string, subjects [] return fmt.Sprintf("%s-%x", durableName, subjHash.Sum64()), nil } -func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { - 
subConf, ok := conf.(*SubscriptionEventConfiguration) +func (p *ProviderAdapter) Subscribe(ctx context.Context, cfg datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + subConf, ok := cfg.(*SubscriptionEventConfiguration) if !ok { - return datasource.NewError("invalid event type for Kafka adapter", nil) + return datasource.NewError("subscription event not support by nats provider", nil) } + log := p.logger.With( zap.String("provider_id", subConf.ProviderID()), zap.String("method", "subscribe"), @@ -145,16 +142,16 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri log.Debug("subscription update", zap.String("message_subject", msg.Subject()), zap.ByteString("data", msg.Data())) p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ - ProviderId: conf.ProviderID(), + ProviderId: subConf.ProviderID(), StreamOperationName: natsReceive, ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject(), }) - updater.Update(&Event{ + updater.Update([]datasource.StreamEvent{&Event{ Data: msg.Data(), Headers: msg.Headers(), - }) + }}) // Acknowledge the message after it has been processed ackErr := msg.Ack() @@ -191,18 +188,16 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri select { case msg := <-msgChan: log.Debug("subscription update", zap.String("message_subject", msg.Subject), zap.ByteString("data", msg.Data)) - p.streamMetricStore.Consume(p.ctx, metric.StreamsEvent{ - ProviderId: conf.ProviderID(), + ProviderId: subConf.ProviderID(), StreamOperationName: natsReceive, ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject, }) - - updater.Update(&Event{ + updater.Update([]datasource.StreamEvent{&Event{ Data: msg.Data, Headers: msg.Header, - }) + }}) case <-p.ctx.Done(): // When the application context is done, we stop the subscriptions for _, subscription := range subscriptions { @@ -230,73 +225,107 @@ func (p *ProviderAdapter) Subscribe(ctx 
context.Context, conf datasource.Subscri return nil } -func (p *ProviderAdapter) Publish(ctx context.Context, event PublishAndRequestEventConfiguration) error { +func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { + pubConf, ok := conf.(*PublishAndRequestEventConfiguration) + if !ok { + return datasource.NewError("publish event not support by nats provider", nil) + } + log := p.logger.With( - zap.String("provider_id", event.ProviderID()), + zap.String("provider_id", pubConf.ProviderID()), zap.String("method", "publish"), - zap.String("subject", event.Subject), + zap.String("subject", pubConf.Subject), ) if p.client == nil { return datasource.NewError("nats client not initialized", nil) } - log.Debug("publish", zap.ByteString("data", event.Event.Data)) + log.Debug("publish", zap.Int("event_count", len(events))) - err := p.client.Publish(event.Subject, event.Event.Data) - if err != nil { - log.Error("publish error", zap.Error(err)) - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), - StreamOperationName: natsPublish, - ProviderType: metric.ProviderTypeNats, - ErrorType: "publish_error", - DestinationName: event.Subject, - }) - return datasource.NewError(fmt.Sprintf("error publishing to NATS subject %s", event.Subject), err) - } else { - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), - StreamOperationName: natsPublish, - ProviderType: metric.ProviderTypeNats, - DestinationName: event.Subject, - }) + for _, streamEvent := range events { + natsEvent, ok := streamEvent.(*Event) + if !ok { + return datasource.NewError("invalid event type for NATS adapter", nil) + } + + err := p.client.Publish(pubConf.Subject, natsEvent.Data) + if err != nil { + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: natsPublish, + ProviderType: metric.ProviderTypeNats, 
+ ErrorType: "publish_error", + DestinationName: pubConf.Subject, + }) + log.Error( + "publish error", + zap.Error(err), + zap.String("provider_id", pubConf.ProviderID()), + zap.String("provider_type", string(pubConf.ProviderType())), + zap.String("field_name", pubConf.RootFieldName()), + ) + return datasource.NewError(fmt.Sprintf("error publishing to NATS subject %s", pubConf.Subject), err) + } } + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: natsPublish, + ProviderType: metric.ProviderTypeNats, + DestinationName: pubConf.Subject, + }) + return nil } -func (p *ProviderAdapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error { +func (p *ProviderAdapter) Request(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer) error { + reqConf, ok := cfg.(*PublishAndRequestEventConfiguration) + if !ok { + return datasource.NewError("publish event not support by nats provider", nil) + } + log := p.logger.With( - zap.String("provider_id", event.ProviderID()), + zap.String("provider_id", cfg.ProviderID()), zap.String("method", "request"), - zap.String("subject", event.Subject), + zap.String("subject", reqConf.Subject), ) if p.client == nil { return datasource.NewError("nats client not initialized", nil) } - log.Debug("request", zap.ByteString("data", event.Event.Data)) + natsEvent, ok := event.(*Event) + if !ok { + return datasource.NewError("invalid event type for NATS adapter", nil) + } + + log.Debug("request", zap.ByteString("data", natsEvent.Data)) - msg, err := p.client.RequestWithContext(ctx, event.Subject, event.Event.Data) + msg, err := p.client.RequestWithContext(ctx, reqConf.Subject, natsEvent.Data) if err != nil { - log.Error("request error", zap.Error(err)) + log.Error( + "request error", + zap.Error(err), + zap.String("provider_id", reqConf.ProviderID()), + zap.String("provider_type", 
string(reqConf.ProviderType())), + zap.String("field_name", reqConf.RootFieldName()), + ) p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: reqConf.ProviderID(), StreamOperationName: natsRequest, ProviderType: metric.ProviderTypeNats, ErrorType: "request_error", - DestinationName: event.Subject, + DestinationName: reqConf.Subject, }) - return datasource.NewError(fmt.Sprintf("error requesting from NATS subject %s", event.Subject), err) + return datasource.NewError(fmt.Sprintf("error requesting from NATS subject %s", reqConf.Subject), err) } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: reqConf.ProviderID(), StreamOperationName: natsRequest, ProviderType: metric.ProviderTypeNats, - DestinationName: event.Subject, + DestinationName: reqConf.Subject, }) // We don't collect metrics on err here as it's an error related to the writer diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index 0fa41e5480..3b2014a71a 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -6,9 +6,13 @@ import ( "encoding/json" "fmt" "io" + "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) // Event represents an event from NATS @@ -21,6 +25,16 @@ func (e *Event) GetData() []byte { return e.Data } +func (e *Event) Clone() datasource.StreamEvent { + e2 := *e + e2.Data = slices.Clone(e.Data) + e2.Headers = make(map[string][]string, len(e.Headers)) + for k, v := range e.Headers { + e2.Headers[k] = slices.Clone(v) + } + return &e2 +} + type StreamConfiguration struct { Consumer string `json:"consumer"` ConsumerInactiveThreshold int32 
`json:"consumerInactiveThreshold"` @@ -49,13 +63,34 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } -type PublishAndRequestEventConfiguration struct { +// publishData is a private type that is used to pass data from the engine to the provider +type publishData struct { Provider string `json:"providerId"` Subject string `json:"subject"` Event Event `json:"event"` FieldName string `json:"rootFieldName"` } +func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { + return &PublishAndRequestEventConfiguration{ + Provider: p.Provider, + Subject: p.Subject, + FieldName: p.FieldName, + } +} + +func (p *publishData) MarshalJSONTemplate() (string, error) { + // The content of the data field could be not valid JSON, so we can't use json.Marshal + // e.g. {"id":$$0$$,"update":$$1$$} + return fmt.Sprintf(`{"subject":"%s", "event": {"data": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Subject, p.Event.Data, p.Provider, p.FieldName), nil +} + +type PublishAndRequestEventConfiguration struct { + Provider string `json:"providerId"` + Subject string `json:"subject"` + FieldName string `json:"rootFieldName"` +} + // ProviderID returns the provider ID func (p *PublishAndRequestEventConfiguration) ProviderID() string { return p.Provider @@ -71,25 +106,68 @@ func (p *PublishAndRequestEventConfiguration) RootFieldName() string { return p.FieldName } -func (p *PublishAndRequestEventConfiguration) MarshalJSONTemplate() (string, error) { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. 
{"id":$$0$$,"update":$$1$$} - return fmt.Sprintf(`{"subject":"%s", "event": {"data": %s}, "providerId":"%s"}`, p.Subject, p.Event.Data, p.ProviderID()), nil +type SubscriptionSource struct { + pubSub datasource.Adapter +} + +func (s *SubscriptionSource) SubscriptionEventConfiguration(input []byte) datasource.SubscriptionEventConfiguration { + var subscriptionConfiguration SubscriptionEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return nil + } + return &subscriptionConfiguration +} + +func (s *SubscriptionSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + + val, _, _, err := jsonparser.Get(input, "subjects") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err +} + +func (s *SubscriptionSource) Start(ctx *resolve.Context, input []byte, updater datasource.SubscriptionEventUpdater) error { + subConf := s.SubscriptionEventConfiguration(input) + if subConf == nil { + return fmt.Errorf("no subscription configuration found") + } + + conf, ok := subConf.(*SubscriptionEventConfiguration) + if !ok { + return fmt.Errorf("invalid subscription configuration") + } + + return s.pubSub.Subscribe(ctx.Context(), conf, updater) } type NatsPublishDataSource struct { - pubSub Adapter + pubSub datasource.Adapter } func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var publishConfiguration PublishAndRequestEventConfiguration - if err := json.Unmarshal(input, &publishConfiguration); err != nil { + var publishData publishData + if err := json.Unmarshal(input, &publishData); err != nil { return err } - if err := s.pubSub.Publish(ctx, publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + if err := s.pubSub.Publish(ctx, 
publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error + _, errWrite := io.WriteString(out, `{"success": false}`) + return errWrite } _, err := io.WriteString(out, `{"success": true}`) return err @@ -100,16 +178,26 @@ func (s *NatsPublishDataSource) LoadWithFiles(ctx context.Context, input []byte, } type NatsRequestDataSource struct { - pubSub Adapter + pubSub datasource.Adapter } func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var subscriptionConfiguration PublishAndRequestEventConfiguration - if err := json.Unmarshal(input, &subscriptionConfiguration); err != nil { + var publishData publishData + if err := json.Unmarshal(input, &publishData); err != nil { return err } - return s.pubSub.Request(ctx, subscriptionConfiguration, out) + providerBase, ok := s.pubSub.(*datasource.PubSubProvider) + if !ok { + return fmt.Errorf("adapter for provider %s is not of the right type", publishData.Provider) + } + + adapter, ok := providerBase.Adapter.(Adapter) + if !ok { + return fmt.Errorf("adapter for provider %s is not of the right type", publishData.Provider) + } + + return adapter.Request(ctx, publishData.PublishEventConfiguration(), &publishData.Event, out) } func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { diff --git a/router/pkg/pubsub/nats/engine_datasource_factory.go b/router/pkg/pubsub/nats/engine_datasource_factory.go index 36d3932e0d..d88d25b868 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory.go @@ -8,8 +8,8 @@ import ( "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - 
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) type EventType int @@ -21,12 +21,13 @@ const ( ) type EngineDataSourceFactory struct { - NatsAdapter Adapter + NatsAdapter datasource.Adapter fieldName string eventType EventType subjects []string providerId string + logger *zap.Logger withStreamConfiguration bool consumerName string @@ -64,11 +65,11 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri subject := c.subjects[0] - evtCfg := PublishAndRequestEventConfiguration{ + evtCfg := publishData{ Provider: c.providerId, Subject: subject, - Event: Event{Data: eventData}, FieldName: c.fieldName, + Event: Event{Data: eventData}, } return evtCfg.MarshalJSONTemplate() @@ -95,7 +96,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.Su _, err = xxh.Write(val) return err - }), nil + }, c.logger), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { diff --git a/router/pkg/pubsub/nats/engine_datasource_factory_test.go b/router/pkg/pubsub/nats/engine_datasource_factory_test.go index a94c8d5941..053ff0d702 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory_test.go @@ -6,14 +6,17 @@ import ( "encoding/json" "errors" "io" + "strings" "testing" "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) func TestNatsEngineDataSourceFactory(t *testing.T) { @@ -36,8 +39,10 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { mockAdapter := NewMockAdapter(t) // Configure mock expectations for Publish - mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event 
PublishAndRequestEventConfiguration) bool { + mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { return event.ProviderID() == "test-provider" && event.Subject == "test-subject" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"test":"data"}`) })).Return(nil) // Create the data source with mock adapter @@ -167,12 +172,15 @@ func TestNatsEngineDataSourceFactoryWithStreamConfiguration(t *testing.T) { func TestEngineDataSourceFactory_RequestDataSource(t *testing.T) { // Create mock adapter mockAdapter := NewMockAdapter(t) + provider := datasource.NewPubSubProvider("test-provider", "nats", mockAdapter, zap.NewNop()) // Configure mock expectations for Request - mockAdapter.On("Request", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + mockAdapter.On("Request", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { return event.ProviderID() == "test-provider" && event.Subject == "test-subject" + }), mock.MatchedBy(func(event datasource.StreamEvent) bool { + return event != nil && strings.EqualFold(string(event.GetData()), `{"test":"data"}`) }), mock.Anything).Return(nil).Run(func(args mock.Arguments) { - w := args.Get(2).(io.Writer) + w := args.Get(3).(io.Writer) w.Write([]byte(`{"response": "test"}`)) }) @@ -182,7 +190,7 @@ func TestEngineDataSourceFactory_RequestDataSource(t *testing.T) { eventType: EventTypeRequest, subjects: []string{"test-subject"}, fieldName: "testField", - NatsAdapter: mockAdapter, + NatsAdapter: provider, } // Get the data source diff --git a/router/pkg/pubsub/nats/engine_datasource_test.go b/router/pkg/pubsub/nats/engine_datasource_test.go index 5d060d2c0d..8665f42181 100644 --- a/router/pkg/pubsub/nats/engine_datasource_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_test.go @@ -6,36 +6,41 @@ import ( "encoding/json" 
"errors" "io" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" + "go.uber.org/zap" ) func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { tests := []struct { name string - config PublishAndRequestEventConfiguration + config publishData wantPattern string }{ { name: "simple configuration", - config: PublishAndRequestEventConfiguration{ + config: publishData{ Provider: "test-provider", Subject: "test-subject", Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + FieldName: "test-field", }, - wantPattern: `{"subject":"test-subject", "event": {"data": {"message":"hello"}}, "providerId":"test-provider"}`, + wantPattern: `{"subject":"test-subject", "event": {"data": {"message":"hello"}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, }, { name: "with special characters", - config: PublishAndRequestEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Subject: "subject-with-hyphens", Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + FieldName: "test-field", }, - wantPattern: `{"subject":"subject-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id"}`, + wantPattern: `{"subject":"subject-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, } @@ -43,11 +48,27 @@ func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result, err := tt.config.MarshalJSONTemplate() assert.NoError(t, err) - assert.Equal(t, tt.wantPattern, result) + assert.Equal(t, tt.wantPattern, string(result)) }) } } +func TestPublishData_PublishEventConfiguration(t *testing.T) { + data := publishData{ + Provider: "test-provider", + Subject: "test-subject", + FieldName: 
"test-field", + } + + evtCfg := &PublishAndRequestEventConfiguration{ + Provider: data.Provider, + Subject: data.Subject, + FieldName: data.FieldName, + } + + assert.Equal(t, evtCfg, data.PublishEventConfiguration()) +} + func TestNatsPublishDataSource_Load(t *testing.T) { tests := []struct { name string @@ -61,10 +82,11 @@ func TestNatsPublishDataSource_Load(t *testing.T) { name: "successful publish", input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishAndRequestEventConfiguration) bool { + m.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { return event.ProviderID() == "test-provider" && - event.Subject == "test-subject" && - string(event.Event.Data) == `{"message":"hello"}` + event.Subject == "test-subject" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"message":"hello"}`) })).Return(nil) }, expectError: false, @@ -75,7 +97,7 @@ func TestNatsPublishDataSource_Load(t *testing.T) { name: "publish error", input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) + m.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("publish error")) }, expectError: false, // The Load method doesn't return the publish error directly expectedOutput: `{"success": false}`, @@ -136,13 +158,14 @@ func TestNatsRequestDataSource_Load(t *testing.T) { name: "successful request", input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { - m.On("Request", mock.Anything, mock.MatchedBy(func(event 
PublishAndRequestEventConfiguration) bool { + m.On("Request", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { return event.ProviderID() == "test-provider" && - event.Subject == "test-subject" && - string(event.Event.Data) == `{"message":"hello"}` + event.Subject == "test-subject" + }), mock.MatchedBy(func(event datasource.StreamEvent) bool { + return event != nil && strings.EqualFold(string(event.GetData()), `{"message":"hello"}`) }), mock.Anything).Run(func(args mock.Arguments) { // Write response to the output buffer - w := args.Get(2).(io.Writer) + w := args.Get(3).(io.Writer) _, _ = w.Write([]byte(`{"response":"success"}`)) }).Return(nil) }, @@ -153,7 +176,7 @@ func TestNatsRequestDataSource_Load(t *testing.T) { name: "request error", input: `{"subject":"test-subject", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, mockSetup: func(m *MockAdapter) { - m.On("Request", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("request error")) + m.On("Request", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(errors.New("request error")) }, expectError: true, expectedOutput: "", @@ -170,10 +193,11 @@ func TestNatsRequestDataSource_Load(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockAdapter := NewMockAdapter(t) + provider := datasource.NewPubSubProvider("test-provider", "nats", mockAdapter, zap.NewNop()) tt.mockSetup(mockAdapter) dataSource := &NatsRequestDataSource{ - pubSub: mockAdapter, + pubSub: provider, } ctx := context.Background() diff --git a/router/pkg/pubsub/nats/mocks.go b/router/pkg/pubsub/nats/mocks.go index 0bc3ada5f0..cfe1a57d95 100644 --- a/router/pkg/pubsub/nats/mocks.go +++ b/router/pkg/pubsub/nats/mocks.go @@ -40,16 +40,16 @@ func (_m *MockAdapter) EXPECT() *MockAdapter_Expecter { } // Publish provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Publish(ctx context.Context, event 
PublishAndRequestEventConfiguration) error { - ret := _mock.Called(ctx, event) +func (_mock *MockAdapter) Publish(ctx context.Context, cfg datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { + ret := _mock.Called(ctx, cfg, events) if len(ret) == 0 { panic("no return value specified for Publish") } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, PublishAndRequestEventConfiguration) error); ok { - r0 = returnFunc(ctx, event) + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.PublishEventConfiguration, []datasource.StreamEvent) error); ok { + r0 = returnFunc(ctx, cfg, events) } else { r0 = ret.Error(0) } @@ -63,24 +63,30 @@ type MockAdapter_Publish_Call struct { // Publish is a helper method to define mock.On call // - ctx context.Context -// - event PublishAndRequestEventConfiguration -func (_e *MockAdapter_Expecter) Publish(ctx interface{}, event interface{}) *MockAdapter_Publish_Call { - return &MockAdapter_Publish_Call{Call: _e.mock.On("Publish", ctx, event)} +// - cfg datasource.PublishEventConfiguration +// - events []datasource.StreamEvent +func (_e *MockAdapter_Expecter) Publish(ctx interface{}, cfg interface{}, events interface{}) *MockAdapter_Publish_Call { + return &MockAdapter_Publish_Call{Call: _e.mock.On("Publish", ctx, cfg, events)} } -func (_c *MockAdapter_Publish_Call) Run(run func(ctx context.Context, event PublishAndRequestEventConfiguration)) *MockAdapter_Publish_Call { +func (_c *MockAdapter_Publish_Call) Run(run func(ctx context.Context, cfg datasource.PublishEventConfiguration, events []datasource.StreamEvent)) *MockAdapter_Publish_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 PublishAndRequestEventConfiguration + var arg1 datasource.PublishEventConfiguration if args[1] != nil { - arg1 = args[1].(PublishAndRequestEventConfiguration) + arg1 = args[1].(datasource.PublishEventConfiguration) + 
} + var arg2 []datasource.StreamEvent + if args[2] != nil { + arg2 = args[2].([]datasource.StreamEvent) } run( arg0, arg1, + arg2, ) }) return _c @@ -91,22 +97,22 @@ func (_c *MockAdapter_Publish_Call) Return(err error) *MockAdapter_Publish_Call return _c } -func (_c *MockAdapter_Publish_Call) RunAndReturn(run func(ctx context.Context, event PublishAndRequestEventConfiguration) error) *MockAdapter_Publish_Call { +func (_c *MockAdapter_Publish_Call) RunAndReturn(run func(ctx context.Context, cfg datasource.PublishEventConfiguration, events []datasource.StreamEvent) error) *MockAdapter_Publish_Call { _c.Call.Return(run) return _c } // Request provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Request(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error { - ret := _mock.Called(ctx, event, w) +func (_mock *MockAdapter) Request(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer) error { + ret := _mock.Called(ctx, cfg, event, w) if len(ret) == 0 { panic("no return value specified for Request") } var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, PublishAndRequestEventConfiguration, io.Writer) error); ok { - r0 = returnFunc(ctx, event, w) + if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.PublishEventConfiguration, datasource.StreamEvent, io.Writer) error); ok { + r0 = returnFunc(ctx, cfg, event, w) } else { r0 = ret.Error(0) } @@ -120,30 +126,36 @@ type MockAdapter_Request_Call struct { // Request is a helper method to define mock.On call // - ctx context.Context -// - event PublishAndRequestEventConfiguration +// - cfg datasource.PublishEventConfiguration +// - event datasource.StreamEvent // - w io.Writer -func (_e *MockAdapter_Expecter) Request(ctx interface{}, event interface{}, w interface{}) *MockAdapter_Request_Call { - return &MockAdapter_Request_Call{Call: _e.mock.On("Request", ctx, event, w)} +func (_e 
*MockAdapter_Expecter) Request(ctx interface{}, cfg interface{}, event interface{}, w interface{}) *MockAdapter_Request_Call { + return &MockAdapter_Request_Call{Call: _e.mock.On("Request", ctx, cfg, event, w)} } -func (_c *MockAdapter_Request_Call) Run(run func(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer)) *MockAdapter_Request_Call { +func (_c *MockAdapter_Request_Call) Run(run func(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer)) *MockAdapter_Request_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { arg0 = args[0].(context.Context) } - var arg1 PublishAndRequestEventConfiguration + var arg1 datasource.PublishEventConfiguration if args[1] != nil { - arg1 = args[1].(PublishAndRequestEventConfiguration) + arg1 = args[1].(datasource.PublishEventConfiguration) } - var arg2 io.Writer + var arg2 datasource.StreamEvent if args[2] != nil { - arg2 = args[2].(io.Writer) + arg2 = args[2].(datasource.StreamEvent) + } + var arg3 io.Writer + if args[3] != nil { + arg3 = args[3].(io.Writer) } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -154,7 +166,7 @@ func (_c *MockAdapter_Request_Call) Return(err error) *MockAdapter_Request_Call return _c } -func (_c *MockAdapter_Request_Call) RunAndReturn(run func(ctx context.Context, event PublishAndRequestEventConfiguration, w io.Writer) error) *MockAdapter_Request_Call { +func (_c *MockAdapter_Request_Call) RunAndReturn(run func(ctx context.Context, cfg datasource.PublishEventConfiguration, event datasource.StreamEvent, w io.Writer) error) *MockAdapter_Request_Call { _c.Call.Return(run) return _c } @@ -262,8 +274,8 @@ func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) e } // Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater 
datasource.SubscriptionEventUpdater) error { - ret := _mock.Called(ctx, event, updater) +func (_mock *MockAdapter) Subscribe(ctx context.Context, cfg datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { + ret := _mock.Called(ctx, cfg, updater) if len(ret) == 0 { panic("no return value specified for Subscribe") @@ -271,7 +283,7 @@ func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.Subscr var r0 error if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { - r0 = returnFunc(ctx, event, updater) + r0 = returnFunc(ctx, cfg, updater) } else { r0 = ret.Error(0) } @@ -285,13 +297,13 @@ type MockAdapter_Subscribe_Call struct { // Subscribe is a helper method to define mock.On call // - ctx context.Context -// - event datasource.SubscriptionEventConfiguration +// - cfg datasource.SubscriptionEventConfiguration // - updater datasource.SubscriptionEventUpdater -func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) *MockAdapter_Subscribe_Call { - return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} +func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, cfg interface{}, updater interface{}) *MockAdapter_Subscribe_Call { + return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, cfg, updater)} } -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, cfg datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -319,7 +331,7 @@ func (_c *MockAdapter_Subscribe_Call) 
Return(err error) *MockAdapter_Subscribe_C return _c } -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { +func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, cfg datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { _c.Call.Return(run) return _c } diff --git a/router/pkg/pubsub/nats/provider_builder.go b/router/pkg/pubsub/nats/provider_builder.go index e3ba5f7cb0..2b07c4217a 100644 --- a/router/pkg/pubsub/nats/provider_builder.go +++ b/router/pkg/pubsub/nats/provider_builder.go @@ -20,16 +20,15 @@ type ProviderBuilder struct { logger *zap.Logger hostName string routerListenAddr string - adapters map[string]Adapter } func (p *ProviderBuilder) TypeID() string { return providerTypeID } -func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.NatsEventConfiguration) (datasource.EngineDataSourceFactory, error) { +func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.NatsEventConfiguration, providers map[string]datasource.Provider) (datasource.EngineDataSourceFactory, error) { providerId := data.GetEngineEventConfiguration().GetProviderId() - adapter, ok := p.adapters[providerId] + provider, ok := providers[providerId] if !ok { return nil, fmt.Errorf("failed to get adapter for provider %s with ID %s", p.TypeID(), providerId) } @@ -46,12 +45,13 @@ func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.NatsEventCon return nil, fmt.Errorf("unsupported event type: %s", data.GetEngineEventConfiguration().GetType()) } dataSourceFactory := &EngineDataSourceFactory{ - NatsAdapter: adapter, + NatsAdapter: provider, fieldName: data.GetEngineEventConfiguration().GetFieldName(), eventType: eventType, subjects: data.GetSubjects(), providerId: providerId, withStreamConfiguration: 
data.GetStreamConfiguration() != nil, + logger: p.logger, } if data.GetStreamConfiguration() != nil { @@ -65,11 +65,10 @@ func (p *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.NatsEventCon } func (p *ProviderBuilder) BuildProvider(provider config.NatsEventSource, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { - adapter, pubSubProvider, err := buildProvider(p.ctx, provider, p.logger, p.hostName, p.routerListenAddr, providerOpts) + pubSubProvider, err := buildProvider(p.ctx, provider, p.logger, p.hostName, p.routerListenAddr, providerOpts) if err != nil { return nil, err } - p.adapters[provider.ID] = adapter return pubSubProvider, nil } @@ -118,18 +117,18 @@ func buildNatsOptions(eventSource config.NatsEventSource, logger *zap.Logger) ([ return opts, nil } -func buildProvider(ctx context.Context, provider config.NatsEventSource, logger *zap.Logger, hostName string, routerListenAddr string, providerOpts datasource.ProviderOpts) (Adapter, datasource.Provider, error) { +func buildProvider(ctx context.Context, provider config.NatsEventSource, logger *zap.Logger, hostName string, routerListenAddr string, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { options, err := buildNatsOptions(provider, logger) if err != nil { - return nil, nil, fmt.Errorf("failed to build options for Nats provider with ID \"%s\": %w", provider.ID, err) + return nil, fmt.Errorf("failed to build options for Nats provider with ID \"%s\": %w", provider.ID, err) } adapter, err := NewAdapter(ctx, logger, provider.URL, options, hostName, routerListenAddr, providerOpts) if err != nil { - return nil, nil, fmt.Errorf("failed to create adapter for Nats provider with ID \"%s\": %w", provider.ID, err) + return nil, fmt.Errorf("failed to create adapter for Nats provider with ID \"%s\": %w", provider.ID, err) } pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, logger) - return adapter, pubSubProvider, nil + return 
pubSubProvider, nil } func NewProviderBuilder( @@ -143,6 +142,5 @@ func NewProviderBuilder( logger: logger, hostName: hostName, routerListenAddr: routerListenAddr, - adapters: make(map[string]Adapter), } } diff --git a/router/pkg/pubsub/pubsub.go b/router/pkg/pubsub/pubsub.go index 085de71a0e..19c908712e 100644 --- a/router/pkg/pubsub/pubsub.go +++ b/router/pkg/pubsub/pubsub.go @@ -51,11 +51,6 @@ func (e *ProviderNotDefinedError) Error() string { return fmt.Sprintf("%s provider with ID %s is not defined", e.ProviderTypeID, e.ProviderID) } -// Hooks contains hooks for the pubsub providers and data sources -type Hooks struct { - SubscriptionOnStart []pubsub_datasource.SubscriptionOnStartFn -} - // BuildProvidersAndDataSources is a generic function that builds providers and data sources for the given // EventsConfiguration and DataSourceConfigurationWithMetadata func BuildProvidersAndDataSources( @@ -66,7 +61,7 @@ func BuildProvidersAndDataSources( dsConfs []DataSourceConfigurationWithMetadata, hostName string, routerListenAddr string, - hooks Hooks, + hooks pubsub_datasource.Hooks, ) ([]pubsub_datasource.Provider, []plan.DataSource, error) { if store == nil { store = metric.NewNoopStreamMetricStore() @@ -88,7 +83,9 @@ func BuildProvidersAndDataSources( if err != nil { return nil, nil, err } - pubSubProviders = append(pubSubProviders, kafkaPubSubProviders...) + for _, provider := range kafkaPubSubProviders { + pubSubProviders = append(pubSubProviders, provider) + } outs = append(outs, kafkaOuts...) // initialize NATS providers and data sources @@ -104,7 +101,9 @@ func BuildProvidersAndDataSources( if err != nil { return nil, nil, err } - pubSubProviders = append(pubSubProviders, natsPubSubProviders...) + for _, provider := range natsPubSubProviders { + pubSubProviders = append(pubSubProviders, provider) + } outs = append(outs, natsOuts...) 
// initialize Redis providers and data sources @@ -120,7 +119,9 @@ func BuildProvidersAndDataSources( if err != nil { return nil, nil, err } - pubSubProviders = append(pubSubProviders, redisPubSubProviders...) + for _, provider := range redisPubSubProviders { + pubSubProviders = append(pubSubProviders, provider) + } outs = append(outs, redisOuts...) return pubSubProviders, outs, nil @@ -129,12 +130,11 @@ func BuildProvidersAndDataSources( func build[P GetID, E GetEngineEventConfiguration]( ctx context.Context, builder pubsub_datasource.ProviderBuilder[P, E], - providersData []P, - dsConfs []dsConfAndEvents[E], + providersData []P, dsConfs []dsConfAndEvents[E], store metric.StreamMetricStore, - hooks Hooks, -) ([]pubsub_datasource.Provider, []plan.DataSource, error) { - var pubSubProviders []pubsub_datasource.Provider + hooks pubsub_datasource.Hooks, +) (map[string]pubsub_datasource.Provider, []plan.DataSource, error) { + pubSubProviders := make(map[string]pubsub_datasource.Provider) var outs []plan.DataSource // check used providers @@ -148,7 +148,6 @@ func build[P GetID, E GetEngineEventConfiguration]( } // initialize providers if used - providerIds := []string{} for _, providerData := range providersData { if !slices.Contains(usedProviderIds, providerData.GetID()) { continue @@ -159,13 +158,13 @@ func build[P GetID, E GetEngineEventConfiguration]( if err != nil { return nil, nil, err } - pubSubProviders = append(pubSubProviders, provider) - providerIds = append(providerIds, provider.ID()) + provider.SetHooks(hooks) + pubSubProviders[provider.ID()] = provider } // check if all used providers are initialized for _, providerId := range usedProviderIds { - if !slices.Contains(providerIds, providerId) { + if _, ok := pubSubProviders[providerId]; !ok { return pubSubProviders, nil, &ProviderNotDefinedError{ ProviderID: providerId, ProviderTypeID: builder.TypeID(), @@ -176,7 +175,12 @@ func build[P GetID, E GetEngineEventConfiguration]( // build data sources for each 
event for _, dsConf := range dsConfs { for i, event := range dsConf.events { - plannerConfig := pubsub_datasource.NewPlannerConfig(builder, event, hooks.SubscriptionOnStart) + plannerConfig := pubsub_datasource.NewPlannerConfig( + builder, + event, + pubSubProviders, + hooks, + ) out, err := plan.NewDataSourceConfiguration( dsConf.dsConf.Configuration.Id+"-"+builder.TypeID()+"-"+strconv.Itoa(i), pubsub_datasource.NewPlannerFactory(ctx, plannerConfig), diff --git a/router/pkg/pubsub/pubsub_test.go b/router/pkg/pubsub/pubsub_test.go index 976980b4ff..39444689ac 100644 --- a/router/pkg/pubsub/pubsub_test.go +++ b/router/pkg/pubsub/pubsub_test.go @@ -62,13 +62,17 @@ func TestBuild_OK(t *testing.T) { } mockPubSubProvider.On("ID").Return("provider-1") + mockPubSubProvider.On("SetHooks", datasource.Hooks{ + OnReceiveEvents: []datasource.OnReceiveEventsFn(nil), + OnPublishEvents: []datasource.OnPublishEventsFn(nil), + }).Return(nil) mockBuilder.On("TypeID").Return("nats") mockBuilder.On("BuildProvider", natsEventSources[0]).Return(mockPubSubProvider, nil) // ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), datasource.Hooks{}) // Assertions assert.NoError(t, err) @@ -124,7 +128,7 @@ func TestBuild_ProviderError(t *testing.T) { mockBuilder.On("BuildProvider", natsEventSources[0], mock.Anything).Return(nil, errors.New("provider error")) // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), datasource.Hooks{}) // Assertions assert.Error(t, err) @@ -179,7 +183,7 @@ func 
TestBuild_ShouldGetAnErrorIfProviderIsNotDefined(t *testing.T) { mockBuilder.On("TypeID").Return("nats") // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), datasource.Hooks{}) // Assertions assert.Error(t, err) @@ -237,13 +241,17 @@ func TestBuild_ShouldNotInitializeProviderIfNotUsed(t *testing.T) { } mockPubSubUsedProvider.On("ID").Return("provider-2") + mockPubSubUsedProvider.On("SetHooks", datasource.Hooks{ + OnReceiveEvents: []datasource.OnReceiveEventsFn(nil), + OnPublishEvents: []datasource.OnPublishEventsFn(nil), + }).Return(nil) mockBuilder.On("TypeID").Return("nats") mockBuilder.On("BuildProvider", natsEventSources[1], mock.Anything). Return(mockPubSubUsedProvider, nil) // Execute the function - providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), Hooks{}) + providers, dataSources, err := build(ctx, mockBuilder, natsEventSources, dsConfs, rmetric.NewNoopStreamMetricStore(), datasource.Hooks{}) // Assertions assert.NoError(t, err) @@ -294,7 +302,7 @@ func TestBuildProvidersAndDataSources_Nats_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) + }, nil, zap.NewNop(), dsConfs, "host", "addr", datasource.Hooks{}) // Assertions assert.NoError(t, err) @@ -347,7 +355,7 @@ func TestBuildProvidersAndDataSources_Kafka_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) + }, nil, zap.NewNop(), dsConfs, "host", "addr", datasource.Hooks{}) // Assertions assert.NoError(t, err) @@ -400,7 +408,7 @@ func TestBuildProvidersAndDataSources_Redis_OK(t *testing.T) { {ID: "provider-1"}, }, }, - }, nil, zap.NewNop(), dsConfs, "host", "addr", Hooks{}) + }, nil, zap.NewNop(), dsConfs, "host", 
"addr", datasource.Hooks{}) // Assertions assert.NoError(t, err) diff --git a/router/pkg/pubsub/redis/adapter.go b/router/pkg/pubsub/redis/adapter.go index 5cb0055a36..8c65bc3413 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -17,19 +17,10 @@ const ( redisReceive = "receive" ) -// Adapter defines the methods that a Redis adapter should implement -type Adapter interface { - // Subscribe subscribes to the given events and sends updates to the updater - Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error - // Publish publishes the given event to the specified channel - Publish(ctx context.Context, event PublishEventConfiguration) error - // Startup initializes the adapter - Startup(ctx context.Context) error - // Shutdown gracefully shuts down the adapter - Shutdown(ctx context.Context) error -} +// Ensure ProviderAdapter implements ProviderSubscriptionHooks +var _ datasource.Adapter = (*ProviderAdapter)(nil) -func NewProviderAdapter(ctx context.Context, logger *zap.Logger, urls []string, clusterEnabled bool, opts datasource.ProviderOpts) Adapter { +func NewProviderAdapter(ctx context.Context, logger *zap.Logger, urls []string, clusterEnabled bool, opts datasource.ProviderOpts) datasource.Adapter { ctx, cancel := context.WithCancel(ctx) if logger == nil { logger = zap.NewNop() @@ -96,10 +87,11 @@ func (p *ProviderAdapter) Shutdown(ctx context.Context) error { func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { subConf, ok := conf.(*SubscriptionEventConfiguration) if !ok { - return datasource.NewError("invalid event type for Kafka adapter", nil) + return datasource.NewError("subscription event not support by redis provider", nil) } + log := p.logger.With( - zap.String("provider_id", subConf.ProviderID()), + zap.String("provider_id", 
conf.ProviderID()), zap.String("method", "subscribe"), zap.Strings("channels", subConf.Channels), ) @@ -136,9 +128,9 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri ProviderType: metric.ProviderTypeRedis, DestinationName: msg.Channel, }) - updater.Update(&Event{ + updater.Update([]datasource.StreamEvent{&Event{ Data: []byte(msg.Payload), - }) + }}) case <-p.ctx.Done(): // When the application context is done, we stop the subscription if it is not already done log.Debug("application context done, stopping subscription") @@ -156,41 +148,59 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri return nil } -func (p *ProviderAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { +func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { + pubConf, ok := conf.(*PublishEventConfiguration) + if !ok { + return datasource.NewError("publish event not support by redis provider", nil) + } + log := p.logger.With( - zap.String("provider_id", event.ProviderID()), + zap.String("provider_id", conf.ProviderID()), zap.String("method", "publish"), - zap.String("channel", event.Channel), + zap.String("channel", pubConf.Channel), ) - log.Debug("publish", zap.ByteString("data", event.Event.Data)) - - data, dataErr := event.Event.Data.MarshalJSON() - if dataErr != nil { - log.Error("error marshalling data", zap.Error(dataErr)) - return datasource.NewError("error marshalling data", dataErr) - } if p.conn == nil { return datasource.NewError("redis connection not initialized", nil) } - intCmd := p.conn.Publish(ctx, event.Channel, data) - if intCmd.Err() != nil { - log.Error("publish error", zap.Error(intCmd.Err())) - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), - StreamOperationName: redisPublish, - ProviderType: metric.ProviderTypeRedis, - ErrorType: "publish_error", - 
DestinationName: event.Channel, - }) - return datasource.NewError(fmt.Sprintf("error publishing to Redis PubSub channel %s", event.Channel), intCmd.Err()) + + if len(events) == 0 { + return nil + } + + log.Debug("publish", zap.Int("event_count", len(events))) + + for _, streamEvent := range events { + redisEvent, ok := streamEvent.(*Event) + if !ok { + return datasource.NewError("invalid event type for Redis adapter", nil) + } + + data, dataErr := redisEvent.Data.MarshalJSON() + if dataErr != nil { + log.Error("error marshalling data", zap.Error(dataErr)) + return datasource.NewError("error marshalling data", dataErr) + } + + intCmd := p.conn.Publish(ctx, pubConf.Channel, data) + if intCmd.Err() != nil { + log.Error("publish error", zap.Error(intCmd.Err())) + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: redisPublish, + ProviderType: metric.ProviderTypeRedis, + ErrorType: "publish_error", + DestinationName: pubConf.Channel, + }) + return datasource.NewError(fmt.Sprintf("error publishing to Redis PubSub channel %s", pubConf.Channel), intCmd.Err()) + } } p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: event.ProviderID(), + ProviderId: pubConf.ProviderID(), StreamOperationName: redisPublish, ProviderType: metric.ProviderTypeRedis, - DestinationName: event.Channel, + DestinationName: pubConf.Channel, }) return nil } diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index 3a685fe9b0..e796b60e66 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -6,9 +6,13 @@ import ( "encoding/json" "fmt" "io" + "slices" + "github.com/buger/jsonparser" + "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" + "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) 
// Event represents an event from Redis @@ -20,6 +24,12 @@ func (e *Event) GetData() []byte { return e.Data } +func (e *Event) Clone() datasource.StreamEvent { + return &Event{ + Data: slices.Clone(e.Data), + } +} + // SubscriptionEventConfiguration contains configuration for subscription events type SubscriptionEventConfiguration struct { Provider string `json:"providerId"` @@ -42,11 +52,31 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { return s.FieldName } +// publishData is a private type that is used to pass data from the engine to the provider + +type publishData struct { + Provider string `json:"providerId"` + Channel string `json:"channel"` + Event Event `json:"event"` + FieldName string `json:"rootFieldName"` +} + +func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { + return &PublishEventConfiguration{ + Provider: p.Provider, + Channel: p.Channel, + FieldName: p.FieldName, + } +} + +func (p *publishData) MarshalJSONTemplate() (string, error) { + return fmt.Sprintf(`{"channel":"%s", "event": {"data": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Channel, p.Event.Data, p.Provider, p.FieldName), nil +} + // PublishEventConfiguration contains configuration for publish events type PublishEventConfiguration struct { Provider string `json:"providerId"` Channel string `json:"channel"` - Event Event `json:"event"` FieldName string `json:"rootFieldName"` } @@ -65,25 +95,77 @@ func (p *PublishEventConfiguration) RootFieldName() string { return p.FieldName } -func (s *PublishEventConfiguration) MarshalJSONTemplate() (string, error) { - return fmt.Sprintf(`{"channel":"%s", "event": {"data": %s}, "providerId":"%s"}`, s.Channel, s.Event.Data, s.ProviderID()), nil +// SubscriptionDataSource implements resolve.SubscriptionDataSource for Redis +type SubscriptionDataSource struct { + pubSub datasource.Adapter +} + +func (s *SubscriptionDataSource) SubscriptionEventConfiguration(input []byte) 
datasource.SubscriptionEventConfiguration { + var subscriptionConfiguration SubscriptionEventConfiguration + err := json.Unmarshal(input, &subscriptionConfiguration) + if err != nil { + return nil + } + return &subscriptionConfiguration +} + +// UniqueRequestID computes a unique ID for the subscription request +func (s *SubscriptionDataSource) UniqueRequestID(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "channels") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { + return err + } + + _, err = xxh.Write(val) + return err +} + +// Start starts the subscription +func (s *SubscriptionDataSource) Start(ctx *resolve.Context, input []byte, updater datasource.SubscriptionEventUpdater) error { + subConf := s.SubscriptionEventConfiguration(input) + if subConf == nil { + return fmt.Errorf("no subscription configuration found") + } + + conf, ok := subConf.(*SubscriptionEventConfiguration) + if !ok { + return fmt.Errorf("invalid subscription configuration") + } + + return s.pubSub.Subscribe(ctx.Context(), conf, updater) +} + +// LoadInitialData implements the interface method (not used for this subscription type) +func (s *SubscriptionDataSource) LoadInitialData(ctx context.Context) (initial []byte, err error) { + return nil, nil } // PublishDataSource implements resolve.DataSource for Redis publishing type PublishDataSource struct { - pubSub Adapter + pubSub datasource.Adapter } // Load processes a request to publish to Redis func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.Buffer) error { - var publishConfiguration PublishEventConfiguration - if err := json.Unmarshal(input, &publishConfiguration); err != nil { + var publishData publishData + if err := json.Unmarshal(input, &publishData); err != nil { return err } - if err := s.pubSub.Publish(ctx, 
publishConfiguration); err != nil { - _, err = io.WriteString(out, `{"success": false}`) - return err + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error + _, errWrite := io.WriteString(out, `{"success": false}`) + return errWrite } _, err := io.WriteString(out, `{"success": true}`) return err diff --git a/router/pkg/pubsub/redis/engine_datasource_factory.go b/router/pkg/pubsub/redis/engine_datasource_factory.go index bce913e54e..46f22e29b9 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory.go @@ -9,6 +9,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" + "go.uber.org/zap" ) type EventType int @@ -20,12 +21,13 @@ const ( // EngineDataSourceFactory implements the datasource.EngineDataSourceFactory interface for Redis type EngineDataSourceFactory struct { - RedisAdapter Adapter + RedisAdapter datasource.Adapter fieldName string eventType EventType channels []string providerId string + logger *zap.Logger } func (c *EngineDataSourceFactory) GetFieldName() string { @@ -60,11 +62,11 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri channel := channels[0] providerId := c.providerId - evtCfg := PublishEventConfiguration{ + evtCfg := publishData{ Provider: providerId, Channel: channel, - Event: Event{Data: eventData}, FieldName: c.fieldName, + Event: Event{Data: eventData}, } return evtCfg.MarshalJSONTemplate() @@ -92,7 +94,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.Su _, err = xxh.Write(val) return err - }), nil + }, c.logger), nil } // ResolveDataSourceSubscriptionInput builds the input for the subscription 
data source diff --git a/router/pkg/pubsub/redis/engine_datasource_factory_test.go b/router/pkg/pubsub/redis/engine_datasource_factory_test.go index f96691583d..7dc4ade017 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory_test.go @@ -5,12 +5,14 @@ import ( "context" "encoding/json" "errors" + "strings" "testing" "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/cosmo/router/pkg/pubsub/pubsubtest" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -33,11 +35,13 @@ func TestRedisEngineDataSourceFactory(t *testing.T) { // TestEngineDataSourceFactoryWithMockAdapter tests the EngineDataSourceFactory with a mocked adapter func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // Create mock adapter - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) // Configure mock expectations for Publish - mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + mockAdapter.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishEventConfiguration) bool { return event.ProviderID() == "test-provider" && event.Channel == "test-channel" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"test":"data"}`) })).Return(nil) // Create the data source with mock adapter @@ -67,7 +71,7 @@ func TestEngineDataSourceFactoryWithMockAdapter(t *testing.T) { // TestEngineDataSourceFactory_GetResolveDataSource_WrongType tests the EngineDataSourceFactory with a mocked adapter func TestEngineDataSourceFactory_GetResolveDataSource_WrongType(t *testing.T) { // Create mock adapter - mockAdapter := NewMockAdapter(t) + mockAdapter := 
datasource.NewMockProvider(t) // Create the data source with mock adapter pubsub := &EngineDataSourceFactory{ @@ -210,7 +214,7 @@ func TestRedisEngineDataSourceFactory_UniqueRequestID(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { factory := &EngineDataSourceFactory{ - RedisAdapter: NewMockAdapter(t), + RedisAdapter: datasource.NewMockProvider(t), } source, err := factory.ResolveDataSourceSubscription() require.NoError(t, err) diff --git a/router/pkg/pubsub/redis/engine_datasource_test.go b/router/pkg/pubsub/redis/engine_datasource_test.go index 74b7d564d7..b322c8a60c 100644 --- a/router/pkg/pubsub/redis/engine_datasource_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_test.go @@ -5,36 +5,40 @@ import ( "context" "encoding/json" "errors" + "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { tests := []struct { name string - config PublishEventConfiguration + config publishData wantPattern string }{ { name: "simple configuration", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider", Channel: "test-channel", Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + FieldName: "test-field", }, - wantPattern: `{"channel":"test-channel", "event": {"data": {"message":"hello"}}, "providerId":"test-provider"}`, + wantPattern: `{"channel":"test-channel", "event": {"data": {"message":"hello"}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, }, { name: "with special characters", - config: PublishEventConfiguration{ + config: publishData{ Provider: "test-provider-id", Channel: "channel-with-hyphens", Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + FieldName: "test-field", }, - wantPattern: `{"channel":"channel-with-hyphens", "event": {"data": 
{"message":"special \"quotes\" here"}}, "providerId":"test-provider-id"}`, + wantPattern: `{"channel":"channel-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, }, } @@ -47,11 +51,27 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { } } +func TestPublishData_PublishEventConfiguration(t *testing.T) { + data := publishData{ + Provider: "test-provider", + Channel: "test-channel", + FieldName: "test-field", + } + + evtCfg := &PublishEventConfiguration{ + Provider: data.Provider, + Channel: data.Channel, + FieldName: data.FieldName, + } + + assert.Equal(t, evtCfg, data.PublishEventConfiguration()) +} + func TestRedisPublishDataSource_Load(t *testing.T) { tests := []struct { name string input string - mockSetup func(*MockAdapter) + mockSetup func(*datasource.MockProvider) expectError bool expectedOutput string expectPublished bool @@ -59,11 +79,12 @@ func TestRedisPublishDataSource_Load(t *testing.T) { { name: "successful publish", input: `{"channel":"test-channel", "event": {"data":{"message":"hello"}}, "providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.MatchedBy(func(event PublishEventConfiguration) bool { + mockSetup: func(m *datasource.MockProvider) { + m.On("Publish", mock.Anything, mock.MatchedBy(func(event *PublishEventConfiguration) bool { return event.ProviderID() == "test-provider" && - event.Channel == "test-channel" && - string(event.Event.Data) == `{"message":"hello"}` + event.Channel == "test-channel" + }), mock.MatchedBy(func(events []datasource.StreamEvent) bool { + return len(events) == 1 && strings.EqualFold(string(events[0].GetData()), `{"message":"hello"}`) })).Return(nil) }, expectError: false, @@ -73,8 +94,8 @@ func TestRedisPublishDataSource_Load(t *testing.T) { { name: "publish error", input: `{"channel":"test-channel", "event": {"data":{"message":"hello"}}, 
"providerId":"test-provider"}`, - mockSetup: func(m *MockAdapter) { - m.On("Publish", mock.Anything, mock.Anything).Return(errors.New("publish error")) + mockSetup: func(m *datasource.MockProvider) { + m.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(errors.New("publish error")) }, expectError: false, // The Load method doesn't return the publish error directly expectedOutput: `{"success": false}`, @@ -83,7 +104,7 @@ func TestRedisPublishDataSource_Load(t *testing.T) { { name: "invalid input json", input: `{"invalid json":`, - mockSetup: func(m *MockAdapter) {}, + mockSetup: func(m *datasource.MockProvider) {}, expectError: true, expectPublished: false, }, @@ -91,7 +112,7 @@ func TestRedisPublishDataSource_Load(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - mockAdapter := NewMockAdapter(t) + mockAdapter := datasource.NewMockProvider(t) tt.mockSetup(mockAdapter) dataSource := &PublishDataSource{ @@ -116,7 +137,7 @@ func TestRedisPublishDataSource_Load(t *testing.T) { func TestRedisPublishDataSource_LoadWithFiles(t *testing.T) { t.Run("panic on not implemented", func(t *testing.T) { dataSource := &PublishDataSource{ - pubSub: NewMockAdapter(t), + pubSub: datasource.NewMockProvider(t), } assert.Panics(t, func() { diff --git a/router/pkg/pubsub/redis/mocks.go b/router/pkg/pubsub/redis/mocks.go deleted file mode 100644 index 6f6938cdd0..0000000000 --- a/router/pkg/pubsub/redis/mocks.go +++ /dev/null @@ -1,261 +0,0 @@ -// Code generated by mockery; DO NOT EDIT. -// github.com/vektra/mockery -// template: testify - -package redis - -import ( - "context" - - mock "github.com/stretchr/testify/mock" - "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" -) - -// NewMockAdapter creates a new instance of MockAdapter. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. -// The first argument is typically a *testing.T value. 
-func NewMockAdapter(t interface { - mock.TestingT - Cleanup(func()) -}) *MockAdapter { - mock := &MockAdapter{} - mock.Mock.Test(t) - - t.Cleanup(func() { mock.AssertExpectations(t) }) - - return mock -} - -// MockAdapter is an autogenerated mock type for the Adapter type -type MockAdapter struct { - mock.Mock -} - -type MockAdapter_Expecter struct { - mock *mock.Mock -} - -func (_m *MockAdapter) EXPECT() *MockAdapter_Expecter { - return &MockAdapter_Expecter{mock: &_m.Mock} -} - -// Publish provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Publish(ctx context.Context, event PublishEventConfiguration) error { - ret := _mock.Called(ctx, event) - - if len(ret) == 0 { - panic("no return value specified for Publish") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, PublishEventConfiguration) error); ok { - r0 = returnFunc(ctx, event) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Publish_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Publish' -type MockAdapter_Publish_Call struct { - *mock.Call -} - -// Publish is a helper method to define mock.On call -// - ctx context.Context -// - event PublishEventConfiguration -func (_e *MockAdapter_Expecter) Publish(ctx interface{}, event interface{}) *MockAdapter_Publish_Call { - return &MockAdapter_Publish_Call{Call: _e.mock.On("Publish", ctx, event)} -} - -func (_c *MockAdapter_Publish_Call) Run(run func(ctx context.Context, event PublishEventConfiguration)) *MockAdapter_Publish_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 PublishEventConfiguration - if args[1] != nil { - arg1 = args[1].(PublishEventConfiguration) - } - run( - arg0, - arg1, - ) - }) - return _c -} - -func (_c *MockAdapter_Publish_Call) Return(err error) *MockAdapter_Publish_Call { - _c.Call.Return(err) - return _c -} - -func (_c 
*MockAdapter_Publish_Call) RunAndReturn(run func(ctx context.Context, event PublishEventConfiguration) error) *MockAdapter_Publish_Call { - _c.Call.Return(run) - return _c -} - -// Shutdown provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Shutdown(ctx context.Context) error { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Shutdown") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Shutdown_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Shutdown' -type MockAdapter_Shutdown_Call struct { - *mock.Call -} - -// Shutdown is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockAdapter_Expecter) Shutdown(ctx interface{}) *MockAdapter_Shutdown_Call { - return &MockAdapter_Shutdown_Call{Call: _e.mock.On("Shutdown", ctx)} -} - -func (_c *MockAdapter_Shutdown_Call) Run(run func(ctx context.Context)) *MockAdapter_Shutdown_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockAdapter_Shutdown_Call) Return(err error) *MockAdapter_Shutdown_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Shutdown_Call) RunAndReturn(run func(ctx context.Context) error) *MockAdapter_Shutdown_Call { - _c.Call.Return(run) - return _c -} - -// Startup provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Startup(ctx context.Context) error { - ret := _mock.Called(ctx) - - if len(ret) == 0 { - panic("no return value specified for Startup") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context) error); ok { - r0 = returnFunc(ctx) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Startup_Call is a *mock.Call that 
shadows Run/Return methods with type explicit version for method 'Startup' -type MockAdapter_Startup_Call struct { - *mock.Call -} - -// Startup is a helper method to define mock.On call -// - ctx context.Context -func (_e *MockAdapter_Expecter) Startup(ctx interface{}) *MockAdapter_Startup_Call { - return &MockAdapter_Startup_Call{Call: _e.mock.On("Startup", ctx)} -} - -func (_c *MockAdapter_Startup_Call) Run(run func(ctx context.Context)) *MockAdapter_Startup_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - run( - arg0, - ) - }) - return _c -} - -func (_c *MockAdapter_Startup_Call) Return(err error) *MockAdapter_Startup_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Startup_Call) RunAndReturn(run func(ctx context.Context) error) *MockAdapter_Startup_Call { - _c.Call.Return(run) - return _c -} - -// Subscribe provides a mock function for the type MockAdapter -func (_mock *MockAdapter) Subscribe(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { - ret := _mock.Called(ctx, event, updater) - - if len(ret) == 0 { - panic("no return value specified for Subscribe") - } - - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, datasource.SubscriptionEventConfiguration, datasource.SubscriptionEventUpdater) error); ok { - r0 = returnFunc(ctx, event, updater) - } else { - r0 = ret.Error(0) - } - return r0 -} - -// MockAdapter_Subscribe_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'Subscribe' -type MockAdapter_Subscribe_Call struct { - *mock.Call -} - -// Subscribe is a helper method to define mock.On call -// - ctx context.Context -// - event datasource.SubscriptionEventConfiguration -// - updater datasource.SubscriptionEventUpdater -func (_e *MockAdapter_Expecter) Subscribe(ctx interface{}, event interface{}, updater interface{}) 
*MockAdapter_Subscribe_Call { - return &MockAdapter_Subscribe_Call{Call: _e.mock.On("Subscribe", ctx, event, updater)} -} - -func (_c *MockAdapter_Subscribe_Call) Run(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater)) *MockAdapter_Subscribe_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 datasource.SubscriptionEventConfiguration - if args[1] != nil { - arg1 = args[1].(datasource.SubscriptionEventConfiguration) - } - var arg2 datasource.SubscriptionEventUpdater - if args[2] != nil { - arg2 = args[2].(datasource.SubscriptionEventUpdater) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockAdapter_Subscribe_Call) Return(err error) *MockAdapter_Subscribe_Call { - _c.Call.Return(err) - return _c -} - -func (_c *MockAdapter_Subscribe_Call) RunAndReturn(run func(ctx context.Context, event datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error) *MockAdapter_Subscribe_Call { - _c.Call.Return(run) - return _c -} diff --git a/router/pkg/pubsub/redis/provider_builder.go b/router/pkg/pubsub/redis/provider_builder.go index 46340934bd..f8814b7d42 100644 --- a/router/pkg/pubsub/redis/provider_builder.go +++ b/router/pkg/pubsub/redis/provider_builder.go @@ -18,7 +18,6 @@ type ProviderBuilder struct { logger *zap.Logger hostName string routerListenAddr string - adapters map[string]Adapter } // NewProviderBuilder creates a new Redis PubSub provider builder @@ -33,7 +32,6 @@ func NewProviderBuilder( logger: logger, hostName: hostName, routerListenAddr: routerListenAddr, - adapters: make(map[string]Adapter), } } @@ -43,8 +41,12 @@ func (b *ProviderBuilder) TypeID() string { } // DataSource creates a Redis PubSub data source for the given event configuration -func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventConfiguration) 
(datasource.EngineDataSourceFactory, error) { +func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventConfiguration, providers map[string]datasource.Provider) (datasource.EngineDataSourceFactory, error) { providerId := data.GetEngineEventConfiguration().GetProviderId() + provider, ok := providers[providerId] + if !ok { + return nil, fmt.Errorf("failed to get adapter for provider %s with ID %s", b.TypeID(), providerId) + } var eventType EventType switch data.GetEngineEventConfiguration().GetType() { @@ -57,11 +59,12 @@ func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventCo } return &EngineDataSourceFactory{ - RedisAdapter: b.adapters[providerId], fieldName: data.GetEngineEventConfiguration().GetFieldName(), eventType: eventType, channels: data.GetChannels(), providerId: providerId, + RedisAdapter: provider, + logger: b.logger, }, nil } @@ -69,7 +72,6 @@ func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventCo func (b *ProviderBuilder) BuildProvider(provider config.RedisEventSource, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { adapter := NewProviderAdapter(b.ctx, b.logger, provider.URLs, provider.ClusterEnabled, providerOpts) pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, b.logger) - b.adapters[provider.ID] = adapter return pubSubProvider, nil } From c498de7622fccc51f8a69b43f8da4543e3368573 Mon Sep 17 00:00:00 2001 From: Dominik <23359034+dkorittki@users.noreply.github.com> Date: Thu, 23 Oct 2025 14:52:10 +0000 Subject: [PATCH 03/44] fix: async handler execution (#2288) --- router-tests/modules/stream_receive_test.go | 376 +++++++++++++++++- router/core/factoryresolver.go | 7 +- router/core/router.go | 11 + router/core/router_config.go | 7 +- router/core/supervisor_instance.go | 8 +- router/demo.config.yaml | 2 +- router/pkg/config/config.go | 7 +- router/pkg/config/config.schema.json | 13 + router/pkg/config/fixtures/full.yaml | 2 + 
.../pkg/config/testdata/config_defaults.json | 3 + router/pkg/config/testdata/config_full.json | 3 + router/pkg/pubsub/datasource/hooks.go | 7 +- .../datasource/subscription_event_updater.go | 123 ++++-- .../subscription_event_updater_test.go | 26 +- 14 files changed, 538 insertions(+), 57 deletions(-) diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index a1658dc35c..21f62a9b53 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -2,7 +2,9 @@ package module_test import ( "errors" + "fmt" "net/http" + "sync/atomic" "testing" "time" @@ -196,7 +198,6 @@ func TestReceiveHook(t *testing.T) { cfg := config.Config{ Graph: config.Graph{}, - Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { @@ -353,7 +354,6 @@ func TestReceiveHook(t *testing.T) { cfg := config.Config{ Graph: config.Graph{}, - Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { @@ -518,4 +518,376 @@ func TestReceiveHook(t *testing.T) { xEnv.WaitForTriggerCount(0, Timeout) }) }) + + t.Run("Test error deduplication with multiple subscriptions", func(t *testing.T) { + t.Parallel() + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return nil, errors.New("deduplicated error") + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ 
+ core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.ErrorLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) + + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + + // Create 3 subscriptions that will all receive the same error + clients := make([]*graphql.SubscriptionClient, 3) + clientRunChs := make([]chan error, 3) + + for i := range 3 { + clients[i] = graphql.NewSubscriptionClient(surl) + clientRunChs[i] = make(chan error) + + subscriptionID, err := clients[i].Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionID) + + go func() { + clientRunChs[i] <- clients[i].Run() + }() + } + + // Wait for all subscriptions to be established + xEnv.WaitForSubscriptionCount(3, Timeout) + + // Produce a message that will trigger the error in all handlers + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + // Wait for all subscriptions to be closed due to the error + xEnv.WaitForSubscriptionCount(0, Timeout) + + // Verify all clients completed + for i := 0; i < 3; i++ { + testenv.AwaitChannelWithT(t, Timeout, clientRunChs[i], func(t *testing.T, err error) { + require.NoError(t, err) + }, "client should have completed when server closed connection") + } + + xEnv.WaitForTriggerCount(0, Timeout) + + // Verify error deduplication: should see only one error log entry + errorLogs := 
xEnv.Observer().FilterMessage("some handlers have thrown an error") + assert.Len(t, errorLogs.All(), 1, "should have exactly one deduplicated error log entry") + + // Verify the error log contains the correct error message and count + if len(errorLogs.All()) > 0 { + logEntry := errorLogs.All()[0] + fields := logEntry.ContextMap() + + assert.Equal(t, "deduplicated error", fields["error"], "error message should match") + assert.Equal(t, int64(3), fields["amount_handlers"], "should count all 3 handlers that threw the error") + } + }) + }) + + t.Run("Test unique error messages are all logged", func(t *testing.T) { + t.Parallel() + + var errorCounter atomic.Int32 + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + count := errorCounter.Add(1) + return nil, fmt.Errorf("unique error %d", count) + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.ErrorLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) 
+ + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + + // Create 3 subscriptions that will each receive a unique error + clients := make([]*graphql.SubscriptionClient, 3) + clientRunChs := make([]chan error, 3) + + for i := range 3 { + clients[i] = graphql.NewSubscriptionClient(surl) + clientRunChs[i] = make(chan error) + + subscriptionID, err := clients[i].Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionID) + + go func() { + clientRunChs[i] <- clients[i].Run() + }() + } + + // Wait for all subscriptions to be established + xEnv.WaitForSubscriptionCount(3, Timeout) + + // Produce a message that will trigger a unique error in each handler + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + // Wait for all subscriptions to be closed due to the error + xEnv.WaitForSubscriptionCount(0, Timeout) + + // Verify all clients completed + for i := range 3 { + testenv.AwaitChannelWithT(t, Timeout, clientRunChs[i], func(t *testing.T, err error) { + require.NoError(t, err) + }, "client should have completed when server closed connection") + } + + xEnv.WaitForTriggerCount(0, Timeout) + + // Verify no deduplication: should see three error log entries (one for each unique error) + errorLogs := xEnv.Observer().FilterMessage("some handlers have thrown an error") + assert.Len(t, errorLogs.All(), 3, "should have three separate error log entries for unique errors") + + // Verify each error log contains a unique error message and count of 1 + if len(errorLogs.All()) == 3 { + var errorMessages []string + for _, logEntry := range errorLogs.All() 
{ + fields := logEntry.ContextMap() + errorMsg, ok := fields["error"].(string) + require.True(t, ok, "error field should be a string") + + // Check that error message is unique (starts with "unique error") + assert.Contains(t, errorMsg, "unique error", "error message should contain 'unique error'") + assert.NotContains(t, errorMessages, errorMsg, "each error message should be unique") + errorMessages = append(errorMessages, errorMsg) + + // Each unique error should have been thrown by exactly 1 handler + assert.Equal(t, int64(1), fields["amount_handlers"], "each unique error should have amount_handlers = 1") + } + + // Verify we got exactly 3 unique error messages + assert.Len(t, errorMessages, 3, "should have exactly 3 unique error messages") + } + }) + }) + + t.Run("Test concurrent handler execution works", func(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + maxConcurrent int + numSubscribers int + }{ + { + name: "1 concurrent handler", + maxConcurrent: 1, + numSubscribers: 5, + }, + { + name: "2 concurrent handlers", + maxConcurrent: 2, + numSubscribers: 10, + }, + { + name: "10 concurrent handlers", + maxConcurrent: 10, + numSubscribers: 20, + }, + { + name: "20 concurrent handlers", + maxConcurrent: 20, + numSubscribers: 40, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var ( + currentHandlers atomic.Int32 + maxCurrentHandlers atomic.Int32 + finishedHandlers atomic.Int32 + ) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + currentHandlers.Add(1) + + // wait for other handlers in the batch + for { + current := currentHandlers.Load() + max := maxCurrentHandlers.Load() + + if current > max { + maxCurrentHandlers.CompareAndSwap(max, current) + } + + if 
current >= int32(tc.maxConcurrent) { + // wait to see if the updater spawns too many concurrent handlers + deadline := time.Now().Add(300 * time.Millisecond) + for time.Now().Before(deadline) { + if currentHandlers.Load() > int32(tc.maxConcurrent) { + break + } + } + break + } + + // Let handlers continue if we never reach a batch size = tc.maxConcurrent + // because there are not enough remaining subscribers to be updated. + remainingSubs := tc.numSubscribers - int(finishedHandlers.Load()) + if remainingSubs < tc.maxConcurrent { + break + } + } + + currentHandlers.Add(-1) + finishedHandlers.Add(1) + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + core.WithSubscriptionHooks(config.SubscriptionHooksConfiguration{ + MaxConcurrentEventReceiveHandlers: tc.maxConcurrent, + }), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) 
+ + var subscriptionQuery struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + + clients := make([]*graphql.SubscriptionClient, tc.numSubscribers) + clientRunChs := make([]chan error, tc.numSubscribers) + subscriptionArgsChs := make([]chan kafkaSubscriptionArgs, tc.numSubscribers) + + for i := range tc.numSubscribers { + clients[i] = graphql.NewSubscriptionClient(surl) + clientRunChs[i] = make(chan error) + subscriptionArgsChs[i] = make(chan kafkaSubscriptionArgs, 1) + + idx := i + subscriptionID, err := clients[i].Subscribe(&subscriptionQuery, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsChs[idx] <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionID) + + go func(i int) { + clientRunChs[i] <- clients[i].Run() + }(i) + } + + xEnv.WaitForSubscriptionCount(uint64(tc.numSubscribers), Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + // Collect events from all subscribers + for i := 0; i < tc.numSubscribers; i++ { + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsChs[i], func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + } + + // Close all clients + for i := 0; i < tc.numSubscribers; i++ { + require.NoError(t, clients[i].Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunChs[i], func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + } + + for i := range subscriptionArgsChs { + 
close(subscriptionArgsChs[i]) + } + + assert.Equal(t, int32(tc.maxConcurrent), maxCurrentHandlers.Load(), "amount of concurrent handlers not what was expected") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), tc.numSubscribers) + }) + }) + } + }) } diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index d7c72fe579..d155c6c5b7 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -501,9 +501,10 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod l.resolver.InstanceData().HostName, l.resolver.InstanceData().ListenAddress, pubsub_datasource.Hooks{ - SubscriptionOnStart: subscriptionOnStartFns, - OnReceiveEvents: onReceiveEventsFns, - OnPublishEvents: onPublishEventsFns, + SubscriptionOnStart: subscriptionOnStartFns, + OnReceiveEvents: onReceiveEventsFns, + OnPublishEvents: onPublishEventsFns, + MaxConcurrentOnReceiveHandlers: l.subscriptionHooks.maxConcurrentOnReceiveHooks, }, ) if err != nil { diff --git a/router/core/router.go b/router/core/router.go index 919ac49c29..1dfb038f20 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -244,6 +244,11 @@ func NewRouter(opts ...Option) (*Router, error) { r.metricConfig = rmetric.DefaultConfig(Version) } + // Default value for maxConcurrentOnReceiveHooks + if r.subscriptionHooks.maxConcurrentOnReceiveHooks == 0 { + r.subscriptionHooks.maxConcurrentOnReceiveHooks = 100 + } + if r.corsOptions == nil { r.corsOptions = CorsDefaultOptions() } @@ -2122,6 +2127,12 @@ func WithDemoMode(demoMode bool) Option { } } +func WithSubscriptionHooks(cfg config.SubscriptionHooksConfiguration) Option { + return func(r *Router) { + r.subscriptionHooks.maxConcurrentOnReceiveHooks = cfg.MaxConcurrentEventReceiveHandlers + } +} + type ProxyFunc func(req *http.Request) (*url.URL, error) func newHTTPTransport(opts *TransportRequestOptions, proxy ProxyFunc, traceDialer *TraceDialer, 
subgraph string) *http.Transport { diff --git a/router/core/router_config.go b/router/core/router_config.go index 3e282d3c65..ba24294104 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -27,9 +27,10 @@ import ( ) type subscriptionHooks struct { - onStart []func(ctx SubscriptionOnStartHandlerContext) error - onPublishEvents []func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) - onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + onStart []func(ctx SubscriptionOnStartHandlerContext) error + onPublishEvents []func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + maxConcurrentOnReceiveHooks int } type Config struct { diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 01f1daaef1..3f5cc37d7e 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -3,6 +3,10 @@ package core import ( "context" "fmt" + "net/http" + "os" + "strings" + "github.com/KimMachineGun/automemlimit/memlimit" "github.com/dustin/go-humanize" "github.com/wundergraph/cosmo/router/pkg/authentication" @@ -13,9 +17,6 @@ import ( "go.uber.org/automaxprocs/maxprocs" "go.uber.org/zap" "go.uber.org/zap/zapcore" - "net/http" - "os" - "strings" ) // newRouter creates a new router instance. 
@@ -251,6 +252,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config) []Option { WithMCP(config.MCP), WithPlugins(config.Plugins), WithDemoMode(config.DemoMode), + WithSubscriptionHooks(config.Events.SubscriptionHooks), } return options diff --git a/router/demo.config.yaml b/router/demo.config.yaml index 9a72e31de2..2a081e74be 100644 --- a/router/demo.config.yaml +++ b/router/demo.config.yaml @@ -19,4 +19,4 @@ events: redis: - id: my-redis urls: - - "redis://localhost:6379/2" \ No newline at end of file + - "redis://localhost:6379/2" diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index 73e8f85e28..8d52b4c4a3 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -639,7 +639,12 @@ type EventProviders struct { } type EventsConfiguration struct { - Providers EventProviders `yaml:"providers,omitempty"` + Providers EventProviders `yaml:"providers,omitempty"` + SubscriptionHooks SubscriptionHooksConfiguration `yaml:"subscription_hooks,omitempty"` +} + +type SubscriptionHooksConfiguration struct { + MaxConcurrentEventReceiveHandlers int `yaml:"max_concurrent_event_receive_handlers" envDefault:"100"` } type Cluster struct { diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index 528e0e1ce7..4992e504a8 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2273,6 +2273,19 @@ } } } + }, + "subscription_hooks": { + "type": "object", + "description": "Configuration for subscription custom modules that are executed when events are received from a broker.", + "additionalProperties": false, + "properties": { + "max_concurrent_event_receive_handlers": { + "type": "integer", + "description": "The maximum number of concurrent event receive handlers. 
This controls the concurrency of the OnReceiveEvents custom modules.", + "minimum": 1, + "default": 100 + } + } } } }, diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index a43691cc12..d010d457d7 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -325,6 +325,8 @@ events: urls: - 'redis://localhost:6379/11' cluster_enabled: true + subscription_hooks: + max_concurrent_event_receive_handlers: 100 engine: enable_single_flight: true diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index c14af6023c..714acd66a9 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -289,6 +289,9 @@ "Nats": null, "Kafka": null, "Redis": null + }, + "SubscriptionHooks": { + "MaxConcurrentEventReceiveHandlers": 100 } }, "CacheWarmup": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index d2a5695072..003883b338 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -635,6 +635,9 @@ "ClusterEnabled": true } ] + }, + "SubscriptionHooks": { + "MaxConcurrentEventReceiveHandlers": 100 } }, "CacheWarmup": { diff --git a/router/pkg/pubsub/datasource/hooks.go b/router/pkg/pubsub/datasource/hooks.go index abab8b8ef1..e07fc7f81a 100644 --- a/router/pkg/pubsub/datasource/hooks.go +++ b/router/pkg/pubsub/datasource/hooks.go @@ -14,7 +14,8 @@ type OnReceiveEventsFn func(ctx context.Context, subConf SubscriptionEventConfig // Hooks contains hooks for the pubsub providers and data sources type Hooks struct { - SubscriptionOnStart []SubscriptionOnStartFn - OnReceiveEvents []OnReceiveEventsFn - OnPublishEvents []OnPublishEventsFn + SubscriptionOnStart []SubscriptionOnStartFn + OnReceiveEvents []OnReceiveEventsFn + OnPublishEvents []OnPublishEventsFn + MaxConcurrentOnReceiveHandlers int } diff 
--git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 95289bb313..b0ef4dbd71 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -2,6 +2,7 @@ package datasource import ( "context" + "sync" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" @@ -31,34 +32,30 @@ func (s *subscriptionEventUpdater) Update(events []StreamEvent) { } return } - // If there are hooks, we should apply them separated for each subscription - for ctx, subId := range s.eventUpdater.Subscriptions() { - processedEvents, err := applyStreamEventHooks( - ctx, - s.subscriptionEventConfiguration, - events, - s.hooks.OnReceiveEvents, - ) - // updates the events even if the hooks fail - // if a hook doesn't want to send the events, it should return no events! - for _, event := range processedEvents { - s.eventUpdater.UpdateSubscription(subId, event.GetData()) - } - if err != nil { - // For all errors, log them - if s.logger != nil { - s.logger.Error( - "An error occurred while processing stream events hooks", - zap.Error(err), - zap.String("provider_type", string(s.subscriptionEventConfiguration.ProviderType())), - zap.String("provider_id", s.subscriptionEventConfiguration.ProviderID()), - zap.String("field_name", s.subscriptionEventConfiguration.RootFieldName()), - ) - } - // Always close the subscription when a hook reports an error to avoid inconsistent state. 
- s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindNormal, subId) - } + + subscriptions := s.eventUpdater.Subscriptions() + limit := max(s.hooks.MaxConcurrentOnReceiveHandlers, 1) + semaphore := make(chan struct{}, limit) + wg := sync.WaitGroup{} + errCh := make(chan error, len(subscriptions)) + + for ctx, subId := range subscriptions { + semaphore <- struct{}{} // Acquire a slot + eventsCopy := copyEvents(events) + wg.Add(1) + go s.updateSubscription(ctx, &wg, errCh, semaphore, subId, eventsCopy) } + + doneLogging := make(chan struct{}) + go func() { + s.deduplicateAndLogErrors(errCh, len(subscriptions)) + doneLogging <- struct{}{} + }() + + wg.Wait() + close(semaphore) + close(errCh) + <-doneLogging } func (s *subscriptionEventUpdater) Complete() { @@ -73,9 +70,9 @@ func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { s.hooks = hooks } -// applyStreamEventHooks processes events through a chain of hook functions +// applyReceiveEventHooks processes events through a chain of hook functions // Each hook receives the result from the previous hook, creating a proper middleware pipeline -func applyStreamEventHooks( +func applyReceiveEventHooks( ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent, @@ -97,6 +94,74 @@ func applyStreamEventHooks( return currentEvents, nil } +func copyEvents(in []StreamEvent) []StreamEvent { + out := make([]StreamEvent, len(in)) + for i := range in { + out[i] = in[i].Clone() + } + return out +} + +func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *sync.WaitGroup, errCh chan error, semaphore chan struct{}, subID resolve.SubscriptionIdentifier, events []StreamEvent) { + defer wg.Done() + defer func() { + <-semaphore // Release the slot when done + }() + + hooks := s.hooks.OnReceiveEvents + + // modify events with hooks + var err error + for i := range hooks { + events, err = hooks[i](ctx, s.subscriptionEventConfiguration, events) + if err != nil { + errCh <- err + } + } 
+ + // send events to the subscription, + // regardless if there was an error during hook processing. + // If no events should be sent, hook must return no events. + for _, event := range events { + s.eventUpdater.UpdateSubscription(subID, event.GetData()) + } + + // In case there was an error we close the affected subscription. + if err != nil { + s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindNormal, subID) + } +} + +// deduplicateAndLogErrors collects errors from errCh +// and deduplicates them based on their err.Error() value. +// Afterwards it uses s.logger to log the message. +func (s *subscriptionEventUpdater) deduplicateAndLogErrors(errCh chan error, size int) { + if s.logger == nil { + return + } + + errs := make(map[string]int, size) + for err := range errCh { + amount, found := errs[err.Error()] + if found { + errs[err.Error()] = amount + 1 + continue + } + errs[err.Error()] = 1 + } + + for err, amount := range errs { + s.logger.Error( + "some handlers have thrown an error", + zap.String("error", err), + zap.Int("amount_handlers", amount), + zap.String("provider_type", string(s.subscriptionEventConfiguration.ProviderType())), + zap.String("provider_id", s.subscriptionEventConfiguration.ProviderID()), + zap.String("field_name", s.subscriptionEventConfiguration.RootFieldName()), + ) + } +} + func NewSubscriptionEventUpdater( cfg SubscriptionEventConfiguration, hooks Hooks, diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 79fd140a51..d5ba1fcd90 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -302,7 +302,7 @@ func TestNewSubscriptionEventUpdater(t *testing.T) { assert.Equal(t, mockUpdater, concreteUpdater.eventUpdater) } -func TestApplyStreamEventHooks_NoHooks(t *testing.T) { +func TestApplyReceiveEventHooks_NoHooks(t *testing.T) { ctx := 
context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -313,13 +313,13 @@ func TestApplyStreamEventHooks_NoHooks(t *testing.T) { &testEvent{data: []byte("test data")}, } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{}) assert.NoError(t, err) assert.Equal(t, originalEvents, result) } -func TestApplyStreamEventHooks_SingleHook_Success(t *testing.T) { +func TestApplyReceiveEventHooks_SingleHook_Success(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -337,13 +337,13 @@ func TestApplyStreamEventHooks_SingleHook_Success(t *testing.T) { return modifiedEvents, nil } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) assert.NoError(t, err) assert.Equal(t, modifiedEvents, result) } -func TestApplyStreamEventHooks_SingleHook_Error(t *testing.T) { +func TestApplyReceiveEventHooks_SingleHook_Error(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -359,14 +359,14 @@ func TestApplyStreamEventHooks_SingleHook_Error(t *testing.T) { return nil, hookError } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) assert.Error(t, err) assert.Equal(t, hookError, err) assert.Nil(t, result) } -func TestApplyStreamEventHooks_MultipleHooks_Success(t *testing.T) { +func TestApplyReceiveEventHooks_MultipleHooks_Success(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -393,7 +393,7 @@ func TestApplyStreamEventHooks_MultipleHooks_Success(t *testing.T) { 
return []StreamEvent{&testEvent{data: []byte("final")}}, nil } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) select { case receivedArgs1 := <-receivedArgs1: @@ -424,7 +424,7 @@ func TestApplyStreamEventHooks_MultipleHooks_Success(t *testing.T) { assert.Equal(t, "final", string(result[0].GetData())) } -func TestApplyStreamEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { +func TestApplyReceiveEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { ctx := context.Background() config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -452,7 +452,7 @@ func TestApplyStreamEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { return []StreamEvent{&testEvent{data: []byte("final")}}, nil } - result, err := applyStreamEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) assert.Error(t, err) assert.Equal(t, middleHookError, err) @@ -622,6 +622,8 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWrite mockUpdater.AssertNotCalled(t, "UpdateSubscription") mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) - msgs := logObserver.FilterMessageSnippet("An error occurred while processing stream events hooks").TakeAll() - assert.Equal(t, 1, len(msgs)) + // log error messages for hooks are written async, we need to wait for them to be written + assert.Eventually(t, func() bool { + return len(logObserver.FilterMessageSnippet("some handlers have thrown an error").TakeAll()) == 1 + }, time.Second, 10*time.Millisecond, "expected one deduplicated error log") } From 4c7687d76d18729e6105a468901998b3197ba280 Mon Sep 17 00:00:00 2001 From: Dominik Korittki 
<23359034+dkorittki@users.noreply.github.com> Date: Thu, 30 Oct 2025 08:44:38 +0000 Subject: [PATCH 04/44] feat: distinguish between read / write events (#2304) Co-authored-by: Alessandro Pagnin --- .../availability/subgraph/schema.resolvers.go | 4 +- .../mood/subgraph/schema.resolvers.go | 4 +- .../modules/start_subscription_test.go | 12 +- router-tests/modules/stream-publish/module.go | 4 +- router-tests/modules/stream-receive/module.go | 4 +- router-tests/modules/stream_publish_test.go | 100 ++++-- router-tests/modules/stream_receive_test.go | 62 ++-- .../modules/streams_hooks_combined_test.go | 30 +- router/core/router_config.go | 4 +- router/core/subscriptions_modules.go | 62 ++-- router/pkg/pubsub/datasource/provider.go | 44 ++- .../pubsub/datasource/pubsubprovider_test.go | 72 +++-- .../datasource/subscription_event_updater.go | 35 +- .../subscription_event_updater_test.go | 298 +++++++++++------- router/pkg/pubsub/kafka/adapter.go | 16 +- router/pkg/pubsub/kafka/engine_datasource.go | 65 +++- .../pubsub/kafka/engine_datasource_factory.go | 2 +- .../pubsub/kafka/engine_datasource_test.go | 24 +- router/pkg/pubsub/nats/adapter.go | 24 +- router/pkg/pubsub/nats/engine_datasource.go | 77 ++++- .../pubsub/nats/engine_datasource_factory.go | 2 +- .../pkg/pubsub/nats/engine_datasource_test.go | 12 +- router/pkg/pubsub/redis/adapter.go | 10 +- router/pkg/pubsub/redis/engine_datasource.go | 47 ++- .../pubsub/redis/engine_datasource_factory.go | 2 +- .../pubsub/redis/engine_datasource_test.go | 12 +- 26 files changed, 657 insertions(+), 371 deletions(-) diff --git a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go index 8e52ec96c5..97ef578631 100644 --- a/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/availability/subgraph/schema.resolvers.go @@ -20,7 +20,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in conf := 
&nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdated.%d", employeeID)), } - evt := &nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} + evt := &nats.MutableEvent{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} err := r.NatsPubSubByProviderID["default"].Publish(ctx, conf, []datasource.StreamEvent{evt}) if err != nil { @@ -30,7 +30,7 @@ func (r *mutationResolver) UpdateAvailability(ctx context.Context, employeeID in conf2 := &nats.PublishAndRequestEventConfiguration{ Subject: r.GetPubSubName(fmt.Sprintf("employeeUpdatedMyNats.%d", employeeID)), } - evt2 := &nats.Event{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} + evt2 := &nats.MutableEvent{Data: []byte(fmt.Sprintf(`{"id":%d,"__typename": "Employee"}`, employeeID))} err = r.NatsPubSubByProviderID["my-nats"].Publish(ctx, conf2, []datasource.StreamEvent{evt2}) if err != nil { diff --git a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go index b9b426593c..8941ac7ac1 100644 --- a/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go +++ b/demo/pkg/subgraphs/mood/subgraph/schema.resolvers.go @@ -22,7 +22,7 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood if r.NatsPubSubByProviderID["default"] != nil { err := r.NatsPubSubByProviderID["default"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ Subject: myNatsTopic, - }, []datasource.StreamEvent{&nats.Event{Data: []byte(payload)}}) + }, []datasource.StreamEvent{(&nats.MutableEvent{Data: []byte(payload)})}) if err != nil { return nil, err } @@ -34,7 +34,7 @@ func (r *mutationResolver) UpdateMood(ctx context.Context, employeeID int, mood if r.NatsPubSubByProviderID["my-nats"] != nil { err := r.NatsPubSubByProviderID["my-nats"].Publish(ctx, &nats.PublishAndRequestEventConfiguration{ Subject: defaultTopic, - }, 
[]datasource.StreamEvent{&nats.Event{Data: []byte(payload)}}) + }, []datasource.StreamEvent{(&nats.MutableEvent{Data: []byte(payload)})}) if err != nil { return nil, err } diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index b9d5e2f0ac..aaee9e27db 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -94,10 +94,10 @@ func TestStartSubscriptionHook(t *testing.T) { if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { return nil } - ctx.WriteEvent(&kafka.Event{ + ctx.WriteEvent((&kafka.MutableEvent{ Key: []byte("1"), Data: []byte(`{"id": 1, "__typename": "Employee"}`), - }) + })) return nil }, }, @@ -266,9 +266,9 @@ func TestStartSubscriptionHook(t *testing.T) { if employeeId != 1 { return nil } - ctx.WriteEvent(&kafka.Event{ + ctx.WriteEvent((&kafka.MutableEvent{ Data: []byte(`{"id": 1, "__typename": "Employee"}`), - }) + })) return nil }, }, @@ -510,9 +510,7 @@ func TestStartSubscriptionHook(t *testing.T) { Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - ctx.WriteEvent(&core.EngineEvent{ - Data: []byte(`{"data":{"countEmp":1000}}`), - }) + ctx.WriteEvent(core.MutableEngineEvent([]byte(`{"data":{"countEmp":1000}}`))) return nil }, }, diff --git a/router-tests/modules/stream-publish/module.go b/router-tests/modules/stream-publish/module.go index e5553058ea..ef5c24277b 100644 --- a/router-tests/modules/stream-publish/module.go +++ b/router-tests/modules/stream-publish/module.go @@ -11,7 +11,7 @@ const myModuleID = "publishModule" type PublishModule struct { Logger *zap.Logger - Callback func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + Callback func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) 
(datasource.StreamEvents, error) } func (m *PublishModule) Provision(ctx *core.ModuleContext) error { @@ -21,7 +21,7 @@ func (m *PublishModule) Provision(ctx *core.ModuleContext) error { return nil } -func (m *PublishModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { +func (m *PublishModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { m.Logger.Info("Publish Hook has been run") if m.Callback != nil { diff --git a/router-tests/modules/stream-receive/module.go b/router-tests/modules/stream-receive/module.go index 640218ad00..51d2b22a33 100644 --- a/router-tests/modules/stream-receive/module.go +++ b/router-tests/modules/stream-receive/module.go @@ -11,7 +11,7 @@ const myModuleID = "streamReceiveModule" type StreamReceiveModule struct { Logger *zap.Logger - Callback func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + Callback func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) } func (m *StreamReceiveModule) Provision(ctx *core.ModuleContext) error { @@ -21,7 +21,7 @@ func (m *StreamReceiveModule) Provision(ctx *core.ModuleContext) error { return nil } -func (m *StreamReceiveModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { +func (m *StreamReceiveModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { m.Logger.Info("Stream Hook has been run") if m.Callback != nil { diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index 6fb7485dc3..ddaf982029 100644 --- a/router-tests/modules/stream_publish_test.go +++ b/router-tests/modules/stream_publish_test.go @@ -4,6 +4,7 @@ 
import ( "encoding/json" "net/http" "strconv" + "sync/atomic" "testing" "time" @@ -23,6 +24,53 @@ import ( func TestPublishHook(t *testing.T) { t.Parallel() + t.Run("Test Publish hook can't assert to mutable types", func(t *testing.T) { + t.Parallel() + + var taPossible atomic.Bool + taPossible.Store(true) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + for _, evt := range events.All() { + _, ok := evt.(datasource.MutableStreamEvent) + if !ok { + taPossible.Store(false) + } + } + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_publish.PublishModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ + Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, + }) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) + + requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") + assert.Len(t, requestLog.All(), 1) + + assert.False(t, taPossible.Load(), "invalid type assertion was possible") + }) + }) + t.Run("Test Publish hook is called", func(t *testing.T) { t.Parallel() @@ -55,25 +103,13 @@ func TestPublishHook(t *testing.T) { }) }) - t.Run("Test Publish kafka hook allows to set headers", func(t *testing.T) { + t.Run("Test Publish hook is called with mutable event", func(t *testing.T) { t.Parallel() cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ 
- "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - for _, event := range events { - evt, ok := event.(*kafka.Event) - if !ok { - continue - } - evt.Headers["x-test"] = []byte("test") - } - - return events, nil - }, - }, + "publishModule": stream_publish.PublishModule{}, }, } @@ -89,21 +125,13 @@ func TestPublishHook(t *testing.T) { LogLevel: zapcore.InfoLevel, }, }, func(t *testing.T, xEnv *testenv.Environment) { - events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, }) - require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") assert.Len(t, requestLog.All(), 1) - - records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) - require.NoError(t, err) - require.Len(t, records, 1) - header := records[0].Headers[0] - require.Equal(t, "x-test", header.Key) - require.Equal(t, []byte("test"), header.Value) }) }) @@ -114,7 +142,7 @@ func TestPublishHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }, }, @@ -159,7 +187,7 @@ func TestPublishHook(t *testing.T) { Graph: config.Graph{}, Modules: 
map[string]interface{}{ "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }, }, @@ -213,7 +241,7 @@ func TestPublishHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) }, }, @@ -257,26 +285,28 @@ func TestPublishHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { return events, nil } employeeID := ctx.Operation().Variables().GetInt("employeeID") - newEvents := []datasource.StreamEvent{} - for _, event := range events { - evt, ok := event.(*kafka.Event) + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + newEvt, ok := event.Clone().(*kafka.MutableEvent) if !ok { continue } - if evt.Headers == nil { - evt.Headers = map[string][]byte{} + newEvt.SetData([]byte(`{"__typename":"Employee","id": 
3,"update":{"name":"foo"}}`)) + if newEvt.Headers == nil { + newEvt.Headers = map[string][]byte{} } - evt.Headers["x-employee-id"] = []byte(strconv.Itoa(employeeID)) - newEvents = append(newEvents, event) + newEvt.Headers["x-employee-id"] = []byte(strconv.Itoa(employeeID)) + newEvents = append(newEvents, newEvt) } - return newEvents, nil + + return datasource.NewStreamEvents(newEvents), nil }, }, }, diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index 21f62a9b53..a30efd23f1 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -20,7 +20,6 @@ import ( "github.com/wundergraph/cosmo/router/pkg/authentication" "github.com/wundergraph/cosmo/router/pkg/config" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/cosmo/router/pkg/pubsub/kafka" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) @@ -115,16 +114,15 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - for _, event := range events { - evt, ok := event.(*kafka.Event) - if !ok { - continue - } - evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, }, @@ -200,22 +198,22 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: 
map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { if ctx.Authentication() == nil { return events, nil } if val, ok := ctx.Authentication().Claims()["sub"]; !ok || val != "user-2" { return events, nil } - for _, event := range events { - evt, ok := event.(*kafka.Event) - if !ok { - continue - } - evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, }, @@ -356,19 +354,19 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { if val, ok := ctx.Request().Header[customHeader]; !ok || val[0] != "Test" { return events, nil } - for _, event := range events { - evt, ok := event.(*kafka.Event) - if !ok { - continue - } - evt.Data = []byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`) + + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) } - return events, 
nil + return datasource.NewStreamEvents(newEvents), nil }, }, }, @@ -452,8 +450,8 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - return nil, errors.New("test error from streamevents hook") + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return datasource.NewStreamEvents(nil), errors.New("test error from streamevents hook") }, }, }, @@ -526,8 +524,8 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - return nil, errors.New("deduplicated error") + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return datasource.NewStreamEvents(nil), errors.New("deduplicated error") }, }, }, @@ -621,9 +619,9 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { count := errorCounter.Add(1) - return nil, fmt.Errorf("unique error %d", count) + return datasource.NewStreamEvents(nil), fmt.Errorf("unique error %d", count) }, }, }, @@ -764,7 +762,7 @@ func TestReceiveHook(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: 
func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { currentHandlers.Add(1) // wait for other handlers in the batch diff --git a/router-tests/modules/streams_hooks_combined_test.go b/router-tests/modules/streams_hooks_combined_test.go index 78639dd052..47a25b48c6 100644 --- a/router-tests/modules/streams_hooks_combined_test.go +++ b/router-tests/modules/streams_hooks_combined_test.go @@ -36,36 +36,42 @@ func TestStreamsHooksCombined(t *testing.T) { Graph: config.Graph{}, Modules: map[string]interface{}{ "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - for _, event := range events { - evt, ok := event.(*kafka.Event) + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + newEvt, ok := event.Clone().(*kafka.MutableEvent) if !ok { continue } - - if string(evt.Headers["x-publishModule"]) == "i_was_here" { - evt.Data = []byte(`{"__typename":"Employee","id": 2,"update":{"name":"irrelevant"}}`) + if string(newEvt.Headers["x-publishModule"]) == "i_was_here" { + newEvt.SetData([]byte(`{"__typename":"Employee","id": 2,"update":{"name":"irrelevant"}}`)) } + newEvents = append(newEvents, newEvt) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { if 
ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { return events, nil } - for _, event := range events { - evt, ok := event.(*kafka.Event) + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + newEvt, ok := event.Clone().(*kafka.MutableEvent) if !ok { continue } - evt.Headers["x-publishModule"] = []byte("i_was_here") + if newEvt.Headers == nil { + newEvt.Headers = make(map[string][]byte) + } + newEvt.Headers["x-publishModule"] = []byte("i_was_here") + newEvents = append(newEvents, newEvt) } - return events, nil + return datasource.NewStreamEvents(newEvents), nil }, }, }, diff --git a/router/core/router_config.go b/router/core/router_config.go index ba24294104..c921687f66 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -28,8 +28,8 @@ import ( type subscriptionHooks struct { onStart []func(ctx SubscriptionOnStartHandlerContext) error - onPublishEvents []func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) - onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + onPublishEvents []func(ctx StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) + onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) maxConcurrentOnReceiveHooks int } diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index e3279c811d..8f5ca8490e 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -91,19 +91,35 @@ func (c *pubSubSubscriptionOnStartHookContext) WriteEvent(event datasource.Strea return true } +type MutableEngineEvent []byte + +func (e MutableEngineEvent) GetData() []byte { + return e +} + +func (e MutableEngineEvent) SetData(data []byte) { + copy(e, data) +} + 
+func (e MutableEngineEvent) Clone() datasource.MutableStreamEvent { + return slices.Clone(e) +} + // EngineEvent is the event used to write to the engine subscription type EngineEvent struct { - Data []byte + data MutableEngineEvent } func (e *EngineEvent) GetData() []byte { - return e.Data + return e.data } -func (e *EngineEvent) Clone() datasource.StreamEvent { - return &EngineEvent{ - Data: slices.Clone(e.Data), - } +func (e *EngineEvent) WriteCopy() datasource.MutableStreamEvent { + return e.data.Clone() +} + +func (e *EngineEvent) Clone() datasource.MutableStreamEvent { + return slices.Clone(e.data) } type engineSubscriptionOnStartHookContext struct { @@ -201,13 +217,15 @@ type StreamReceiveEventHandlerContext interface { } type StreamReceiveEventHandler interface { - // OnReceiveEvents is called each time a batch of events is received from the provider before delivering them to the - // client. So for a single batch of events received from the provider, this hook will be called one time for each - // active subscription. - // It is important to optimize the logic inside this hook to avoid performance issues. - // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a - // StreamHookError. - OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + // OnReceiveEvents is called whenever a batch of events is received from a provider, + // before delivering them to clients. + // The hook will be called once for each active subscription, therefore it is adviced to + // avoid resource heavy computation or blocking tasks whenever possible. + // The events argument contains all events from a batch and is shared between + // all active subscribers of these events. + // Use events.All() to iterate through them and event.Clone() to create mutable copies, when needed. 
+ // Returning an error will result in the subscription being closed and the error being logged. + OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) } type StreamPublishEventHandlerContext interface { @@ -224,13 +242,15 @@ type StreamPublishEventHandlerContext interface { } type StreamPublishEventHandler interface { - // OnPublishEvents is called each time a batch of events is going to be sent to the provider + // OnPublishEvents is called each time a batch of events is going to be sent to a provider. + // The events argument contains all events from a batch. + // Use events.All() to iterate through them and event.Clone() to create mutable copies, when needed. // Returning an error will result in a GraphQL error being returned to the client, could be customized returning a // StreamHookError. - OnPublishEvents(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error) + OnPublishEvents(ctx StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) } -func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, events []datasource.StreamEvent) ([]datasource.StreamEvent, error)) datasource.OnPublishEventsFn { +func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error)) datasource.OnPublishEventsFn { if fn == nil { return nil } @@ -245,7 +265,9 @@ func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, publishEventConfiguration: pubConf, } - return fn(hookCtx, evts) + newEvts, err := fn(hookCtx, datasource.NewStreamEvents(evts)) + + return newEvts.Unsafe(), err } } @@ -277,7 +299,7 @@ func (c *pubSubStreamReceiveEventHookContext) SubscriptionEventConfiguration() d return c.subscriptionEventConfiguration } -func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, 
events []datasource.StreamEvent) ([]datasource.StreamEvent, error)) datasource.OnReceiveEventsFn { +func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error)) datasource.OnReceiveEventsFn { if fn == nil { return nil } @@ -291,7 +313,7 @@ func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, } - - return fn(hookCtx, evts) + newEvts, err := fn(hookCtx, datasource.NewStreamEvents(evts)) + return newEvts.Unsafe(), err } } diff --git a/router/pkg/pubsub/datasource/provider.go b/router/pkg/pubsub/datasource/provider.go index 57bbb70ed7..fd02ffccf6 100644 --- a/router/pkg/pubsub/datasource/provider.go +++ b/router/pkg/pubsub/datasource/provider.go @@ -2,6 +2,8 @@ package datasource import ( "context" + "iter" + "slices" "github.com/wundergraph/cosmo/router/pkg/metric" ) @@ -46,7 +48,7 @@ type ProviderBuilder[P, E any] interface { BuildEngineDataSourceFactory(data E, providers map[string]Provider) (EngineDataSourceFactory, error) } -// ProviderType represents the type of pubsub provider +// ProviderType represents the type of pubsub provider. type ProviderType string const ( @@ -55,12 +57,44 @@ const ( ProviderTypeRedis ProviderType = "redis" ) -// StreamEvent is a generic interface for all stream events -// Each provider will have its own event type that implements this interface -// there could be other common fields in the future, but for now we only have data +// StreamEvents is a list of stream events coming from or going to event providers. +type StreamEvents struct { + evts []StreamEvent +} + +// All is an iterator, which can be used to iterate through all events. +func (e StreamEvents) All() iter.Seq2[int, StreamEvent] { + return slices.All(e.evts) +} + +// Len returns the number of events. 
+func (e StreamEvents) Len() int { + return len(e.evts) +} + +// Unsafe returns the underlying slice of stream events. +// This slice is not thread safe and should not be modified directly. +func (e StreamEvents) Unsafe() []StreamEvent { + return e.evts +} + +func NewStreamEvents(evts []StreamEvent) StreamEvents { + return StreamEvents{evts: evts} +} + +// A StreamEvent is a single event coming from or going to an event provider. type StreamEvent interface { + // GetData returns the payload data of the event. GetData() []byte - Clone() StreamEvent + // Clone returns a mutable copy of the event. + Clone() MutableStreamEvent +} + +// A MutableStreamEvent is a stream event that can be modified. +type MutableStreamEvent interface { + StreamEvent + // SetData sets the data of the event. + SetData([]byte) } // SubscriptionEventConfiguration is the interface that all subscription event configurations must implement diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 6ef41c56a5..939d15ad3f 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -1,9 +1,9 @@ package datasource import ( - "bytes" "context" "errors" + "slices" "testing" "github.com/stretchr/testify/assert" @@ -12,18 +12,32 @@ import ( ) // Test helper types +type mutableTestEvent []byte + +func (e mutableTestEvent) Clone() MutableStreamEvent { + var evt mutableTestEvent = make([]byte, len(e)) + copy(evt, e) + return evt +} + +func (e mutableTestEvent) GetData() []byte { + return e +} + +func (e mutableTestEvent) SetData(data []byte) { + copy(e, data) +} + type testEvent struct { - data []byte + evt mutableTestEvent } func (e *testEvent) GetData() []byte { - return e.data + return slices.Clone(e.evt.GetData()) } -func (e *testEvent) Clone() StreamEvent { - return &testEvent{ - data: bytes.Clone(e.data), - } +func (e *testEvent) Clone() MutableStreamEvent { + return 
e.evt.Clone() } type testSubscriptionConfig struct { @@ -158,8 +172,8 @@ func TestProvider_Publish_NoHooks_Success(t *testing.T) { fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data 1")}, - &testEvent{data: []byte("test data 2")}, + &testEvent{mutableTestEvent("test data 1")}, + &testEvent{mutableTestEvent("test data 2")}, } mockAdapter.On("Publish", mock.Anything, config, events).Return(nil) @@ -181,7 +195,7 @@ func TestProvider_Publish_NoHooks_Error(t *testing.T) { fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } expectedError := errors.New("publish error") @@ -205,10 +219,10 @@ func TestProvider_Publish_WithHooks_Success(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original data")}, + &testEvent{mutableTestEvent("original data")}, } modifiedEvents := []StreamEvent{ - &testEvent{data: []byte("modified data")}, + &testEvent{mutableTestEvent("modified data")}, } // Define hook that modifies events @@ -237,7 +251,7 @@ func TestProvider_Publish_WithHooks_HookError(t *testing.T) { fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } hookError := errors.New("hook processing error") @@ -270,10 +284,10 @@ func TestProvider_Publish_WithHooks_AdapterError(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original data")}, + &testEvent{mutableTestEvent("original data")}, } processedEvents := []StreamEvent{ - &testEvent{data: []byte("processed data")}, + &testEvent{mutableTestEvent("processed data")}, } adapterError := errors.New("adapter publish error") @@ -304,15 +318,15 @@ func TestProvider_Publish_WithMultipleHooks_Success(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + 
&testEvent{mutableTestEvent("original")}, } // Chain of hooks that modify the data hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, nil } hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("modified by hook2")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook2")}}, nil } mockAdapter.On("Publish", mock.Anything, config, mock.MatchedBy(func(events []StreamEvent) bool { @@ -370,7 +384,7 @@ func TestApplyPublishEventHooks_NoHooks(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{}) @@ -387,10 +401,10 @@ func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } modifiedEvents := []StreamEvent{ - &testEvent{data: []byte("modified")}, + &testEvent{mutableTestEvent("modified")}, } hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { @@ -411,7 +425,7 @@ func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } hookError := errors.New("hook processing failed") @@ -434,17 +448,17 @@ func TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + 
&testEvent{mutableTestEvent("original")}, } hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("step1")}}, nil } hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("step2")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("step2")}}, nil } hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("final")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) @@ -462,18 +476,18 @@ func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } middleHookError := errors.New("middle hook failed") hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("step1")}}, nil } hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { return nil, middleHookError } hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return []StreamEvent{&testEvent{data: []byte("final")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go 
b/router/pkg/pubsub/datasource/subscription_event_updater.go index b0ef4dbd71..1c920f7222 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -41,9 +41,8 @@ func (s *subscriptionEventUpdater) Update(events []StreamEvent) { for ctx, subId := range subscriptions { semaphore <- struct{}{} // Acquire a slot - eventsCopy := copyEvents(events) wg.Add(1) - go s.updateSubscription(ctx, &wg, errCh, semaphore, subId, eventsCopy) + go s.updateSubscription(ctx, &wg, errCh, semaphore, subId, events) } doneLogging := make(chan struct{}) @@ -70,38 +69,6 @@ func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { s.hooks = hooks } -// applyReceiveEventHooks processes events through a chain of hook functions -// Each hook receives the result from the previous hook, creating a proper middleware pipeline -func applyReceiveEventHooks( - ctx context.Context, - cfg SubscriptionEventConfiguration, - events []StreamEvent, - hooks []OnReceiveEventsFn) ([]StreamEvent, error) { - // Copy the events to avoid modifying the original slice - currentEvents := make([]StreamEvent, len(events)) - for i, event := range events { - currentEvents[i] = event.Clone() - } - // Apply each hook in sequence, passing the result of one as the input to the next - // If any hook returns an error, stop processing and return the error - for _, hook := range hooks { - var err error - currentEvents, err = hook(ctx, cfg, currentEvents) - if err != nil { - return currentEvents, err - } - } - return currentEvents, nil -} - -func copyEvents(in []StreamEvent) []StreamEvent { - out := make([]StreamEvent, len(in)) - for i := range in { - out[i] = in[i].Clone() - } - return out -} - func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *sync.WaitGroup, errCh chan error, semaphore chan struct{}, subID resolve.SubscriptionIdentifier, events []StreamEvent) { defer wg.Done() defer func() { diff --git 
a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index d5ba1fcd90..693d9d5da0 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -3,6 +3,7 @@ package datasource import ( "context" "errors" + "sync" "testing" "time" @@ -44,8 +45,8 @@ func TestSubscriptionEventUpdater_Update_NoHooks(t *testing.T) { fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data 1")}, - &testEvent{data: []byte("test data 2")}, + &testEvent{mutableTestEvent("test data 1")}, + &testEvent{mutableTestEvent("test data 2")}, } // Expect calls to Update for each event @@ -69,10 +70,10 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testin fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original data")}, + &testEvent{mutableTestEvent("original data")}, } modifiedEvents := []StreamEvent{ - &testEvent{data: []byte("modified data")}, + &testEvent{mutableTestEvent("modified data")}, } // Create wrapper function for the mock @@ -116,7 +117,7 @@ func TestSubscriptionEventUpdater_UpdateSubscriptions_WithHooks_Error(t *testing fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } hookError := errors.New("hook processing error") @@ -157,20 +158,20 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } // Chain of hooks that modify the data receivedArgs1 := make(chan receivedHooksArgs, 1) hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} - return 
[]StreamEvent{&testEvent{data: []byte("modified by hook1")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, nil } receivedArgs2 := make(chan receivedHooksArgs, 1) hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("modified by hook2")}}, nil + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook2")}}, nil } // Expect call to UpdateSubscription with modified data @@ -200,7 +201,7 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) select { case receivedArgs2 := <-receivedArgs2: - assert.Equal(t, []StreamEvent{&testEvent{data: []byte("modified by hook1")}}, receivedArgs2.events) + assert.Equal(t, []StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, receivedArgs2.events) assert.Equal(t, config, receivedArgs2.cfg) case <-time.After(1 * time.Second): t.Fatal("timeout waiting for events") @@ -302,179 +303,260 @@ func TestNewSubscriptionEventUpdater(t *testing.T) { assert.Equal(t, mockUpdater, concreteUpdater.eventUpdater) } -func TestApplyReceiveEventHooks_NoHooks(t *testing.T) { - ctx := context.Background() +func TestSubscriptionEventUpdater_Update_PassthroughWithNoHooks(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", providerType: ProviderTypeNats, fieldName: "testField", } - originalEvents := []StreamEvent{ - &testEvent{data: []byte("test data")}, + events := []StreamEvent{ + &testEvent{mutableTestEvent("event data 1")}, + &testEvent{mutableTestEvent("event data 2")}, + &testEvent{mutableTestEvent("event data 3")}, } - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{}) + // With no hooks, Update should call the underlying eventUpdater.Update for each event + mockUpdater.On("Update", 
[]byte("event data 1")).Return() + mockUpdater.On("Update", []byte("event data 2")).Return() + mockUpdater.On("Update", []byte("event data 3")).Return() - assert.NoError(t, err) - assert.Equal(t, originalEvents, result) + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{}, // No hooks + } + + updater.Update(events) + + // Verify all events were passed through without modification + mockUpdater.AssertCalled(t, "Update", []byte("event data 1")) + mockUpdater.AssertCalled(t, "Update", []byte("event data 2")) + mockUpdater.AssertCalled(t, "Update", []byte("event data 3")) + mockUpdater.AssertNumberOfCalls(t, "Update", 3) } -func TestApplyReceiveEventHooks_SingleHook_Success(t *testing.T) { - ctx := context.Background() +func TestSubscriptionEventUpdater_Update_WithSingleHookModification(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", providerType: ProviderTypeNats, fieldName: "testField", } originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, - } - modifiedEvents := []StreamEvent{ - &testEvent{data: []byte("modified")}, + &testEvent{mutableTestEvent("original data 1")}, + &testEvent{mutableTestEvent("original data 2")}, } + // Hook that modifies events by adding a prefix hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + modifiedEvents := make([]StreamEvent, len(events)) + for i, event := range events { + modifiedData := "modified: " + string(event.GetData()) + modifiedEvents[i] = &testEvent{mutableTestEvent(modifiedData)} + } return modifiedEvents, nil } - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + 
context.Background(): subId, + }) + + // With hooks, UpdateSubscription should be called with modified data + mockUpdater.On("UpdateSubscription", subId, []byte("modified: original data 1")).Return() + mockUpdater.On("UpdateSubscription", subId, []byte("modified: original data 2")).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{hook}, + }, + } - assert.NoError(t, err) - assert.Equal(t, modifiedEvents, result) + updater.Update(originalEvents) + + // Verify modified events were sent to UpdateSubscription, not the original events + mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("modified: original data 1")) + mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("modified: original data 2")) + mockUpdater.AssertNumberOfCalls(t, "UpdateSubscription", 2) + // Update should NOT be called when hooks are present + mockUpdater.AssertNotCalled(t, "Update") } -func TestApplyReceiveEventHooks_SingleHook_Error(t *testing.T) { - ctx := context.Background() +func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscriptionAndLogsError(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", providerType: ProviderTypeNats, fieldName: "testField", } - originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, + events := []StreamEvent{ + &testEvent{mutableTestEvent("test data")}, } hookError := errors.New("hook processing failed") + // Hook that returns an error hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - return nil, hookError + // Return the events but also return an error + return events, hookError } - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook}) + // Set up logger with observer to verify error logging 
+ zCore, logObserver := observer.New(zap.InfoLevel) + logger := zap.New(zCore) + + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + // Events are still sent even when hook returns error + mockUpdater.On("UpdateSubscription", subId, []byte("test data")).Return() + // Subscription should be closed due to the error + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + updater := NewSubscriptionEventUpdater(config, Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{hook}, + }, mockUpdater, logger) + + updater.Update(events) - assert.Error(t, err) - assert.Equal(t, hookError, err) - assert.Nil(t, result) + // Verify events were still sent despite the error + mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("test data")) + // Verify subscription was closed due to the error + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) + // Update should NOT be called when hooks are present + mockUpdater.AssertNotCalled(t, "Update") + + // Verify error was logged (logging happens asynchronously) + assert.Eventually(t, func() bool { + logs := logObserver.FilterMessageSnippet("some handlers have thrown an error").TakeAll() + if len(logs) != 1 { + return false + } + // Verify the logged error message contains our error + return logs[0].ContextMap()["error"] == hookError.Error() + }, time.Second, 10*time.Millisecond, "expected error to be logged") } -func TestApplyReceiveEventHooks_MultipleHooks_Success(t *testing.T) { - ctx := context.Background() +func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", providerType: ProviderTypeNats, fieldName: "testField", } originalEvents := []StreamEvent{ - 
&testEvent{data: []byte("original")}, + &testEvent{mutableTestEvent("original")}, } + // Track what each hook receives and when it's called + hookCallOrder := make([]int, 0, 3) + var mu sync.Mutex + + // Hook 1: Adds "step1: " prefix receivedArgs1 := make(chan receivedHooksArgs, 1) hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + mu.Lock() + hookCallOrder = append(hookCallOrder, 1) + mu.Unlock() receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("step1")}}, nil + modifiedEvents := make([]StreamEvent, len(events)) + for i, event := range events { + modifiedData := "step1: " + string(event.GetData()) + modifiedEvents[i] = &testEvent{mutableTestEvent(modifiedData)} + } + return modifiedEvents, nil } + + // Hook 2: Adds "step2: " prefix receivedArgs2 := make(chan receivedHooksArgs, 1) hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + mu.Lock() + hookCallOrder = append(hookCallOrder, 2) + mu.Unlock() receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("step2")}}, nil + modifiedEvents := make([]StreamEvent, len(events)) + for i, event := range events { + modifiedData := "step2: " + string(event.GetData()) + modifiedEvents[i] = &testEvent{mutableTestEvent(modifiedData)} + } + return modifiedEvents, nil } + + // Hook 3: Adds "step3: " prefix receivedArgs3 := make(chan receivedHooksArgs, 1) hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + mu.Lock() + hookCallOrder = append(hookCallOrder, 3) + mu.Unlock() receivedArgs3 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("final")}}, nil + modifiedEvents := make([]StreamEvent, len(events)) + for i, event := range events { + modifiedData := "step3: " + 
string(event.GetData()) + modifiedEvents[i] = &testEvent{mutableTestEvent(modifiedData)} + } + return modifiedEvents, nil } - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + // Final modified data should have all three transformations applied + mockUpdater.On("UpdateSubscription", subId, []byte("step3: step2: step1: original")).Return() - select { - case receivedArgs1 := <-receivedArgs1: - assert.Equal(t, originalEvents, receivedArgs1.events) - assert.Equal(t, config, receivedArgs1.cfg) - case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{hook1, hook2, hook3}, + }, } - select { - case receivedArgs2 := <-receivedArgs2: - assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step1")}}, receivedArgs2.events) - assert.Equal(t, config, receivedArgs2.cfg) - case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") - } + updater.Update(originalEvents) + // Verify hook 1 received original events select { - case receivedArgs3 := <-receivedArgs3: - assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step2")}}, receivedArgs3.events) - assert.Equal(t, config, receivedArgs3.cfg) + case args1 := <-receivedArgs1: + assert.Equal(t, originalEvents, args1.events, "Hook 1 should receive original events") + assert.Equal(t, config, args1.cfg) + assert.Len(t, args1.events, 1) + assert.Equal(t, "original", string(args1.events[0].GetData())) case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") + t.Fatal("timeout waiting for hook 1") } - assert.NoError(t, err) - assert.Len(t, result, 1) - 
assert.Equal(t, "final", string(result[0].GetData())) -} - -func TestApplyReceiveEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { - ctx := context.Background() - config := &testSubscriptionEventConfig{ - providerID: "test-provider", - providerType: ProviderTypeNats, - fieldName: "testField", - } - originalEvents := []StreamEvent{ - &testEvent{data: []byte("original")}, - } - middleHookError := errors.New("middle hook failed") - - receivedArgs1 := make(chan receivedHooksArgs, 1) - hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("step1")}}, nil - } - receivedArgs2 := make(chan receivedHooksArgs, 1) - hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} - return nil, middleHookError - } - receivedArgs3 := make(chan receivedHooksArgs, 1) - hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { - receivedArgs3 <- receivedHooksArgs{events: events, cfg: cfg} - return []StreamEvent{&testEvent{data: []byte("final")}}, nil - } - - result, err := applyReceiveEventHooks(ctx, config, originalEvents, []OnReceiveEventsFn{hook1, hook2, hook3}) - - assert.Error(t, err) - assert.Equal(t, middleHookError, err) - assert.Nil(t, result) - + // Verify hook 2 received events modified by hook 1 select { - case receivedArgs1 := <-receivedArgs1: - assert.Equal(t, originalEvents, receivedArgs1.events) - assert.Equal(t, config, receivedArgs1.cfg) + case args2 := <-receivedArgs2: + assert.Equal(t, config, args2.cfg) + assert.Len(t, args2.events, 1) + assert.Equal(t, "step1: original", string(args2.events[0].GetData()), "Hook 2 should receive output from hook 1") case <-time.After(1 * time.Second): - t.Fatal("timeout 
waiting for events") + t.Fatal("timeout waiting for hook 2") } + // Verify hook 3 received events modified by hook 2 select { - case receivedArgs2 := <-receivedArgs2: - assert.Equal(t, []StreamEvent{&testEvent{data: []byte("step1")}}, receivedArgs2.events) - assert.Equal(t, config, receivedArgs2.cfg) + case args3 := <-receivedArgs3: + assert.Equal(t, config, args3.cfg) + assert.Len(t, args3.events, 1) + assert.Equal(t, "step2: step1: original", string(args3.events[0].GetData()), "Hook 3 should receive output from hook 2") case <-time.After(1 * time.Second): - t.Fatal("timeout waiting for events") + t.Fatal("timeout waiting for hook 3") } - assert.Empty(t, receivedArgs3) + // Verify hooks were called in correct order + mu.Lock() + assert.Equal(t, []int{1, 2, 3}, hookCallOrder, "Hooks should be called in order") + mu.Unlock() + + // Verify final modified events were sent to UpdateSubscription + mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("step3: step2: step1: original")) + mockUpdater.AssertNumberOfCalls(t, "UpdateSubscription", 1) + mockUpdater.AssertNotCalled(t, "Update") } // Test the updateEvents method indirectly through Update method @@ -555,7 +637,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHookError_ClosesSubscri fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { @@ -592,7 +674,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWrite fieldName: "testField", } events := []StreamEvent{ - &testEvent{data: []byte("test data")}, + &testEvent{mutableTestEvent("test data")}, } hookError := errors.New("hook processing error") diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index 7f61a242b9..fcd1cc0c70 100644 --- a/router/pkg/pubsub/kafka/adapter.go +++ 
b/router/pkg/pubsub/kafka/adapter.go @@ -108,11 +108,15 @@ func (p *ProviderAdapter) topicPoller(ctx context.Context, client *kgo.Client, u DestinationName: r.Topic, }) - updater.Update([]datasource.StreamEvent{&Event{ - Data: r.Value, - Headers: headers, - Key: r.Key, - }}) + updater.Update([]datasource.StreamEvent{ + &Event{ + evt: &MutableEvent{ + Data: r.Value, + Headers: headers, + Key: r.Key, + }, + }, + }) } } } @@ -212,7 +216,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv var errMutex sync.Mutex for _, streamEvent := range events { - kafkaEvent, ok := streamEvent.(*Event) + kafkaEvent, ok := streamEvent.Clone().(*MutableEvent) if !ok { return datasource.NewError("invalid event type for Kafka adapter", nil) } diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 00a38023ea..9d48fd0db0 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -15,18 +15,66 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// Event represents an event from Kafka +// Event implements datasource.StreamEvent type Event struct { + evt *MutableEvent +} + +func (e *Event) GetData() []byte { + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.Data) +} + +func (e *Event) GetKey() []byte { + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.Key) +} + +func (e *Event) GetHeaders() map[string][]byte { + if e.evt == nil { + return nil + } + return cloneHeaders(e.evt.Headers) +} + +func (e Event) Clone() datasource.MutableStreamEvent { + return e.evt.Clone() +} + +func cloneHeaders(src map[string][]byte) map[string][]byte { + if src == nil { + return nil + } + dst := make(map[string][]byte, len(src)) + for k, v := range src { + dst[k] = slices.Clone(v) + } + return dst +} + +// MutableEvent implements datasource.MutableEvent +type MutableEvent struct { Key []byte `json:"key"` Data 
json.RawMessage `json:"data"` Headers map[string][]byte `json:"headers"` } -func (e *Event) GetData() []byte { +func (e *MutableEvent) GetData() []byte { return e.Data } -func (e *Event) Clone() datasource.StreamEvent { +func (e *MutableEvent) SetData(data []byte) { + if e == nil { + return + } + e.Data = data +} + +func (e *MutableEvent) Clone() datasource.MutableStreamEvent { e2 := *e e2.Data = slices.Clone(e.Data) e2.Headers = make(map[string][]byte, len(e.Headers)) @@ -61,10 +109,10 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { - Provider string `json:"providerId"` - Topic string `json:"topic"` - Event Event `json:"event"` - FieldName string `json:"rootFieldName"` + Provider string `json:"providerId"` + Topic string `json:"topic"` + Event MutableEvent `json:"event"` + FieldName string `json:"rootFieldName"` } // PublishEventConfiguration returns the publish event configuration from the publishData type @@ -172,7 +220,7 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B return err } - if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&Event{&publishData.Event}}); err != nil { // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error _, errWrite := io.WriteString(out, `{"success": false}`) return errWrite @@ -192,3 +240,4 @@ func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, fil var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) var _ datasource.PublishEventConfiguration = (*PublishEventConfiguration)(nil) var _ datasource.StreamEvent = (*Event)(nil) +var _ datasource.MutableStreamEvent 
= (*MutableEvent)(nil) diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory.go b/router/pkg/pubsub/kafka/engine_datasource_factory.go index d89eb408b0..b4e1356714 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory.go @@ -55,7 +55,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri evtCfg := publishData{ Provider: c.providerId, Topic: c.topics[0], - Event: Event{Data: eventData}, + Event: MutableEvent{Data: eventData}, FieldName: c.fieldName, } diff --git a/router/pkg/pubsub/kafka/engine_datasource_test.go b/router/pkg/pubsub/kafka/engine_datasource_test.go index 846203d6e0..5fb5808173 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_test.go +++ b/router/pkg/pubsub/kafka/engine_datasource_test.go @@ -23,9 +23,9 @@ func TestPublishData_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: publishData{ - Provider: "test-provider", - Topic: "test-topic", - Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + Provider: "test-provider", + Topic: "test-topic", + Event: MutableEvent{Data: json.RawMessage(`{"message":"hello"}`)}, FieldName: "test-field", }, wantPattern: `{"topic":"test-topic", "event": {"data": {"message":"hello"}, "key": "", "headers": {}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, @@ -33,9 +33,9 @@ func TestPublishData_MarshalJSONTemplate(t *testing.T) { { name: "with special characters", config: publishData{ - Provider: "test-provider-id", - Topic: "topic-with-hyphens", - Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: MutableEvent{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, FieldName: "test-field", }, wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}, "key": "", "headers": {}}, "providerId":"test-provider-id", 
"rootFieldName":"test-field"}`, @@ -43,9 +43,9 @@ func TestPublishData_MarshalJSONTemplate(t *testing.T) { { name: "with key", config: publishData{ - Provider: "test-provider-id", - Topic: "topic-with-hyphens", - Event: Event{Key: []byte("blablabla"), Data: json.RawMessage(`{}`)}, + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: MutableEvent{Key: []byte("blablabla"), Data: json.RawMessage(`{}`)}, FieldName: "test-field", }, wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "blablabla", "headers": {}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, @@ -53,9 +53,9 @@ func TestPublishData_MarshalJSONTemplate(t *testing.T) { { name: "with headers", config: publishData{ - Provider: "test-provider-id", - Topic: "topic-with-hyphens", - Event: Event{Headers: map[string][]byte{"key": []byte(`blablabla`)}, Data: json.RawMessage(`{}`)}, + Provider: "test-provider-id", + Topic: "topic-with-hyphens", + Event: MutableEvent{Headers: map[string][]byte{"key": []byte(`blablabla`)}, Data: json.RawMessage(`{}`)}, FieldName: "test-field", }, wantPattern: `{"topic":"topic-with-hyphens", "event": {"data": {}, "key": "", "headers": {"key":"YmxhYmxhYmxh"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, diff --git a/router/pkg/pubsub/nats/adapter.go b/router/pkg/pubsub/nats/adapter.go index 13628db1f6..e32368c658 100644 --- a/router/pkg/pubsub/nats/adapter.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -149,10 +149,12 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, cfg datasource.Subscrip DestinationName: msg.Subject(), }) - updater.Update([]datasource.StreamEvent{&Event{ - Data: msg.Data(), - Headers: msg.Headers(), - }}) + updater.Update([]datasource.StreamEvent{ + Event{evt: &MutableEvent{ + Data: msg.Data(), + Headers: map[string][]string(msg.Headers()), + }}, + }) // Acknowledge the message after it has been processed ackErr := msg.Ack() @@ -195,10 +197,12 @@ func (p *ProviderAdapter) 
Subscribe(ctx context.Context, cfg datasource.Subscrip ProviderType: metric.ProviderTypeNats, DestinationName: msg.Subject, }) - updater.Update([]datasource.StreamEvent{&Event{ - Data: msg.Data, - Headers: msg.Header, - }}) + updater.Update([]datasource.StreamEvent{ + Event{evt: &MutableEvent{ + Data: msg.Data, + Headers: map[string][]string(msg.Header), + }}, + }) case <-p.ctx.Done(): // When the application context is done, we stop the subscriptions for _, subscription := range subscriptions { @@ -245,7 +249,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv log.Debug("publish", zap.Int("event_count", len(events))) for _, streamEvent := range events { - natsEvent, ok := streamEvent.(*Event) + natsEvent, ok := streamEvent.Clone().(*MutableEvent) if !ok { return datasource.NewError("invalid event type for NATS adapter", nil) } @@ -296,7 +300,7 @@ func (p *ProviderAdapter) Request(ctx context.Context, cfg datasource.PublishEve return datasource.NewError("nats client not initialized", nil) } - natsEvent, ok := event.(*Event) + natsEvent, ok := event.Clone().(*MutableEvent) if !ok { return datasource.NewError("invalid event type for NATS adapter", nil) } diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index 3b2014a71a..f0b8c7b57a 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -15,24 +15,70 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// Event represents an event from NATS type Event struct { + evt *MutableEvent +} + +func (e Event) GetData() []byte { + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.Data) +} + +func (e Event) GetHeaders() map[string][]string { + if e.evt == nil || e.evt.Headers == nil { + return nil + } + return cloneHeaders(e.evt.Headers) +} + +func (e Event) Clone() datasource.MutableStreamEvent { + return e.evt.Clone() +} + +type MutableEvent struct 
{ Data json.RawMessage `json:"data"` Headers map[string][]string `json:"headers"` } -func (e *Event) GetData() []byte { +func (e *MutableEvent) GetData() []byte { + if e == nil { + return nil + } return e.Data } -func (e *Event) Clone() datasource.StreamEvent { - e2 := *e - e2.Data = slices.Clone(e.Data) - e2.Headers = make(map[string][]string, len(e.Headers)) - for k, v := range e.Headers { - e2.Headers[k] = slices.Clone(v) +func (e *MutableEvent) SetData(data []byte) { + if e == nil { + return } - return &e2 + e.Data = slices.Clone(data) +} + +func (e *MutableEvent) Clone() datasource.MutableStreamEvent { + if e == nil { + return nil + } + return &MutableEvent{ + Data: slices.Clone(e.Data), + Headers: cloneHeaders(e.Headers), + } +} + +func (e *MutableEvent) ToStreamEvent() datasource.StreamEvent { + return &Event{evt: e} +} + +func cloneHeaders(src map[string][]string) map[string][]string { + if src == nil { + return nil + } + dst := make(map[string][]string, len(src)) + for k, v := range src { + dst[k] = slices.Clone(v) + } + return dst } type StreamConfiguration struct { @@ -65,10 +111,10 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { - Provider string `json:"providerId"` - Subject string `json:"subject"` - Event Event `json:"event"` - FieldName string `json:"rootFieldName"` + Provider string `json:"providerId"` + Subject string `json:"subject"` + Event MutableEvent `json:"event"` + FieldName string `json:"rootFieldName"` } func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { @@ -164,7 +210,7 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *byt return err } - if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + if err := s.pubSub.Publish(ctx, 
publishData.PublishEventConfiguration(), []datasource.StreamEvent{Event{evt: &publishData.Event}}); err != nil { // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error _, errWrite := io.WriteString(out, `{"success": false}`) return errWrite @@ -197,7 +243,7 @@ func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *byt return fmt.Errorf("adapter for provider %s is not of the right type", publishData.Provider) } - return adapter.Request(ctx, publishData.PublishEventConfiguration(), &publishData.Event, out) + return adapter.Request(ctx, publishData.PublishEventConfiguration(), Event{evt: &publishData.Event}, out) } func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { @@ -208,3 +254,4 @@ func (s *NatsRequestDataSource) LoadWithFiles(ctx context.Context, input []byte, var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) var _ datasource.PublishEventConfiguration = (*PublishAndRequestEventConfiguration)(nil) var _ datasource.StreamEvent = (*Event)(nil) +var _ datasource.MutableStreamEvent = (*MutableEvent)(nil) diff --git a/router/pkg/pubsub/nats/engine_datasource_factory.go b/router/pkg/pubsub/nats/engine_datasource_factory.go index d88d25b868..f4006448dd 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory.go @@ -69,7 +69,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri Provider: c.providerId, Subject: subject, FieldName: c.fieldName, - Event: Event{Data: eventData}, + Event: MutableEvent{Data: eventData}, } return evtCfg.MarshalJSONTemplate() diff --git a/router/pkg/pubsub/nats/engine_datasource_test.go b/router/pkg/pubsub/nats/engine_datasource_test.go index 8665f42181..183179c083 100644 --- a/router/pkg/pubsub/nats/engine_datasource_test.go +++ 
b/router/pkg/pubsub/nats/engine_datasource_test.go @@ -25,9 +25,9 @@ func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "simple configuration", config: publishData{ - Provider: "test-provider", - Subject: "test-subject", - Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + Provider: "test-provider", + Subject: "test-subject", + Event: MutableEvent{Data: json.RawMessage(`{"message":"hello"}`)}, FieldName: "test-field", }, wantPattern: `{"subject":"test-subject", "event": {"data": {"message":"hello"}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, @@ -35,9 +35,9 @@ func TestPublishAndRequestEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "with special characters", config: publishData{ - Provider: "test-provider-id", - Subject: "subject-with-hyphens", - Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + Provider: "test-provider-id", + Subject: "subject-with-hyphens", + Event: MutableEvent{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, FieldName: "test-field", }, wantPattern: `{"subject":"subject-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, diff --git a/router/pkg/pubsub/redis/adapter.go b/router/pkg/pubsub/redis/adapter.go index 8c65bc3413..8c056fe6c1 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -128,9 +128,11 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri ProviderType: metric.ProviderTypeRedis, DestinationName: msg.Channel, }) - updater.Update([]datasource.StreamEvent{&Event{ - Data: []byte(msg.Payload), - }}) + updater.Update([]datasource.StreamEvent{ + Event{evt: &MutableEvent{ + Data: []byte(msg.Payload), + }}, + }) case <-p.ctx.Done(): // When the application context is done, we stop the subscription if it is not already done log.Debug("application context done, 
stopping subscription") @@ -171,7 +173,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv log.Debug("publish", zap.Int("event_count", len(events))) for _, streamEvent := range events { - redisEvent, ok := streamEvent.(*Event) + redisEvent, ok := streamEvent.Clone().(*MutableEvent) if !ok { return datasource.NewError("invalid event type for Redis adapter", nil) } diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index e796b60e66..56f00a4841 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -15,17 +15,45 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -// Event represents an event from Redis type Event struct { + evt *MutableEvent +} + +func (e Event) GetData() []byte { + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.Data) +} + +func (e Event) Clone() datasource.MutableStreamEvent { + return e.evt.Clone() +} + +type MutableEvent struct { Data json.RawMessage `json:"data"` } -func (e *Event) GetData() []byte { +func (e *MutableEvent) GetData() []byte { + if e == nil { + return nil + } return e.Data } -func (e *Event) Clone() datasource.StreamEvent { - return &Event{ +func (e *MutableEvent) SetData(data []byte) { + if e == nil { + return + } + e.Data = data +} + +func (e *MutableEvent) Clone() datasource.MutableStreamEvent { + if e == nil { + return nil + } + + return &MutableEvent{ Data: slices.Clone(e.Data), } } @@ -55,10 +83,10 @@ func (s *SubscriptionEventConfiguration) RootFieldName() string { // publishData is a private type that is used to pass data from the engine to the provider type publishData struct { - Provider string `json:"providerId"` - Channel string `json:"channel"` - Event Event `json:"event"` - FieldName string `json:"rootFieldName"` + Provider string `json:"providerId"` + Channel string `json:"channel"` + Event MutableEvent `json:"event"` + FieldName 
string `json:"rootFieldName"` } func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfiguration { @@ -162,7 +190,7 @@ func (s *PublishDataSource) Load(ctx context.Context, input []byte, out *bytes.B return err } - if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&publishData.Event}); err != nil { + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{Event{evt: &publishData.Event}}); err != nil { // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error _, errWrite := io.WriteString(out, `{"success": false}`) return errWrite @@ -180,3 +208,4 @@ func (s *PublishDataSource) LoadWithFiles(ctx context.Context, input []byte, fil var _ datasource.SubscriptionEventConfiguration = (*SubscriptionEventConfiguration)(nil) var _ datasource.PublishEventConfiguration = (*PublishEventConfiguration)(nil) var _ datasource.StreamEvent = (*Event)(nil) +var _ datasource.MutableStreamEvent = (*MutableEvent)(nil) diff --git a/router/pkg/pubsub/redis/engine_datasource_factory.go b/router/pkg/pubsub/redis/engine_datasource_factory.go index 46f22e29b9..1e9f9866e4 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory.go @@ -66,7 +66,7 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri Provider: providerId, Channel: channel, FieldName: c.fieldName, - Event: Event{Data: eventData}, + Event: MutableEvent{Data: eventData}, } return evtCfg.MarshalJSONTemplate() diff --git a/router/pkg/pubsub/redis/engine_datasource_test.go b/router/pkg/pubsub/redis/engine_datasource_test.go index b322c8a60c..cc59d240f3 100644 --- a/router/pkg/pubsub/redis/engine_datasource_test.go +++ b/router/pkg/pubsub/redis/engine_datasource_test.go @@ -23,9 +23,9 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: 
"simple configuration", config: publishData{ - Provider: "test-provider", - Channel: "test-channel", - Event: Event{Data: json.RawMessage(`{"message":"hello"}`)}, + Provider: "test-provider", + Channel: "test-channel", + Event: MutableEvent{Data: json.RawMessage(`{"message":"hello"}`)}, FieldName: "test-field", }, wantPattern: `{"channel":"test-channel", "event": {"data": {"message":"hello"}}, "providerId":"test-provider", "rootFieldName":"test-field"}`, @@ -33,9 +33,9 @@ func TestPublishEventConfiguration_MarshalJSONTemplate(t *testing.T) { { name: "with special characters", config: publishData{ - Provider: "test-provider-id", - Channel: "channel-with-hyphens", - Event: Event{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, + Provider: "test-provider-id", + Channel: "channel-with-hyphens", + Event: MutableEvent{Data: json.RawMessage(`{"message":"special \"quotes\" here"}`)}, FieldName: "test-field", }, wantPattern: `{"channel":"channel-with-hyphens", "event": {"data": {"message":"special \"quotes\" here"}}, "providerId":"test-provider-id", "rootFieldName":"test-field"}`, From d61e56bff0d6920d5bc4460e71c808bffd22b6f8 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Thu, 30 Oct 2025 17:12:57 +0000 Subject: [PATCH 05/44] feat(router): provide an agnostic way of creating events in hooks (#2306) Co-authored-by: Alessandro Pagnin --- .../modules/start_subscription_test.go | 8 +-- router/core/subscriptions_modules.go | 67 ++++++++++++++----- router/pkg/pubsub/datasource/hooks.go | 6 +- .../pkg/pubsub/datasource/pubsubprovider.go | 28 ++++---- .../pubsub/datasource/pubsubprovider_test.go | 57 ++++++++++------ .../datasource/subscription_datasource.go | 10 ++- .../subscription_datasource_test.go | 55 +++++++++------ .../datasource/subscription_event_updater.go | 5 +- .../subscription_event_updater_test.go | 42 +++++++----- .../pubsub/kafka/engine_datasource_factory.go | 46 +++++++------ 
router/pkg/pubsub/kafka/provider_builder.go | 8 ++- .../pubsub/nats/engine_datasource_factory.go | 42 +++++++----- .../nats/engine_datasource_factory_test.go | 7 +- .../pkg/pubsub/nats/engine_datasource_test.go | 2 +- router/pkg/pubsub/nats/provider_builder.go | 8 ++- .../pubsub/redis/engine_datasource_factory.go | 42 +++++++----- router/pkg/pubsub/redis/provider_builder.go | 6 +- 17 files changed, 280 insertions(+), 159 deletions(-) diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index aaee9e27db..de738215f4 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -266,9 +266,8 @@ func TestStartSubscriptionHook(t *testing.T) { if employeeId != 1 { return nil } - ctx.WriteEvent((&kafka.MutableEvent{ - Data: []byte(`{"id": 1, "__typename": "Employee"}`), - })) + evt := ctx.NewEvent([]byte(`{"id": 1, "__typename": "Employee"}`)) + ctx.WriteEvent(evt) return nil }, }, @@ -510,7 +509,8 @@ func TestStartSubscriptionHook(t *testing.T) { Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - ctx.WriteEvent(core.MutableEngineEvent([]byte(`{"data":{"countEmp":1000}}`))) + evt := ctx.NewEvent([]byte(`{"data":{"countEmp":1000}}`)) + ctx.WriteEvent(evt) return nil }, }, diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index 8f5ca8490e..bcfaaae114 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -26,6 +26,8 @@ type SubscriptionOnStartHandlerContext interface { // WriteEvent writes an event to the stream of the current subscription // It returns true if the event was written to the stream, false if the event was dropped WriteEvent(event datasource.StreamEvent) bool + // NewEvent creates a new event that can be used in the subscription. 
+ NewEvent(data []byte) datasource.MutableStreamEvent } type pubSubPublishEventHookContext struct { @@ -34,6 +36,7 @@ type pubSubPublishEventHookContext struct { operation OperationContext authentication authentication.Authentication publishEventConfiguration datasource.PublishEventConfiguration + eventBuilder datasource.EventBuilderFn } func (c *pubSubPublishEventHookContext) Request() *http.Request { @@ -56,6 +59,10 @@ func (c *pubSubPublishEventHookContext) PublishEventConfiguration() datasource.P return c.publishEventConfiguration } +func (c *pubSubPublishEventHookContext) NewEvent(data []byte) datasource.MutableStreamEvent { + return c.eventBuilder(data) +} + type pubSubSubscriptionOnStartHookContext struct { request *http.Request logger *zap.Logger @@ -63,6 +70,7 @@ type pubSubSubscriptionOnStartHookContext struct { authentication authentication.Authentication subscriptionEventConfiguration datasource.SubscriptionEventConfiguration writeEventHook func(data []byte) + eventBuilder datasource.EventBuilderFn } func (c *pubSubSubscriptionOnStartHookContext) Request() *http.Request { @@ -91,35 +99,44 @@ func (c *pubSubSubscriptionOnStartHookContext) WriteEvent(event datasource.Strea return true } -type MutableEngineEvent []byte +func (c *pubSubSubscriptionOnStartHookContext) NewEvent(data []byte) datasource.MutableStreamEvent { + return c.eventBuilder(data) +} + +// MutableEngineEvent is comparable to EngineEvent, but is mutable. 
+type MutableEngineEvent struct { + data []byte +} -func (e MutableEngineEvent) GetData() []byte { - return e +func (e *MutableEngineEvent) GetData() []byte { + return e.data } -func (e MutableEngineEvent) SetData(data []byte) { - copy(e, data) +func (e *MutableEngineEvent) SetData(data []byte) { + e.data = data } -func (e MutableEngineEvent) Clone() datasource.MutableStreamEvent { - return slices.Clone(e) +func (e *MutableEngineEvent) Clone() datasource.MutableStreamEvent { + return &MutableEngineEvent{data: slices.Clone(e.data)} } // EngineEvent is the event used to write to the engine subscription type EngineEvent struct { - data MutableEngineEvent + evt *MutableEngineEvent } func (e *EngineEvent) GetData() []byte { - return e.data -} - -func (e *EngineEvent) WriteCopy() datasource.MutableStreamEvent { - return e.data.Clone() + if e.evt == nil { + return nil + } + return slices.Clone(e.evt.data) } func (e *EngineEvent) Clone() datasource.MutableStreamEvent { - return slices.Clone(e.data) + if e.evt == nil { + return &MutableEngineEvent{} + } + return e.evt.Clone() } type engineSubscriptionOnStartHookContext struct { @@ -152,6 +169,10 @@ func (c *engineSubscriptionOnStartHookContext) WriteEvent(event datasource.Strea return true } +func (c *engineSubscriptionOnStartHookContext) NewEvent(data []byte) datasource.MutableStreamEvent { + return &MutableEngineEvent{data: data} +} + func (c *engineSubscriptionOnStartHookContext) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration { return nil } @@ -168,7 +189,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont return nil } - return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration) error { + return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn) error { requestContext := getRequestContext(resolveCtx.Context) hookCtx := 
&pubSubSubscriptionOnStartHookContext{ request: requestContext.Request(), @@ -177,6 +198,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, writeEventHook: resolveCtx.Updater, + eventBuilder: eventBuilder, } return fn(hookCtx) @@ -214,6 +236,8 @@ type StreamReceiveEventHandlerContext interface { Authentication() authentication.Authentication // SubscriptionEventConfiguration the subscription event configuration SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration + // NewEvent creates a new event that can be used in the subscription. + NewEvent(data []byte) datasource.MutableStreamEvent } type StreamReceiveEventHandler interface { @@ -239,6 +263,8 @@ type StreamPublishEventHandlerContext interface { Authentication() authentication.Authentication // PublishEventConfiguration the publish event configuration PublishEventConfiguration() datasource.PublishEventConfiguration + // NewEvent creates a new event that can be used when publishing. 
+ NewEvent(data []byte) datasource.MutableStreamEvent } type StreamPublishEventHandler interface { @@ -255,7 +281,7 @@ func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, return nil } - return func(ctx context.Context, pubConf datasource.PublishEventConfiguration, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return func(ctx context.Context, pubConf datasource.PublishEventConfiguration, evts []datasource.StreamEvent, eventBuilder datasource.EventBuilderFn) ([]datasource.StreamEvent, error) { requestContext := getRequestContext(ctx) hookCtx := &pubSubPublishEventHookContext{ request: requestContext.Request(), @@ -263,6 +289,7 @@ func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, operation: requestContext.Operation(), authentication: requestContext.Authentication(), publishEventConfiguration: pubConf, + eventBuilder: eventBuilder, } newEvts, err := fn(hookCtx, datasource.NewStreamEvents(evts)) @@ -277,6 +304,7 @@ type pubSubStreamReceiveEventHookContext struct { operation OperationContext authentication authentication.Authentication subscriptionEventConfiguration datasource.SubscriptionEventConfiguration + eventBuilder datasource.EventBuilderFn } func (c *pubSubStreamReceiveEventHookContext) Request() *http.Request { @@ -299,12 +327,16 @@ func (c *pubSubStreamReceiveEventHookContext) SubscriptionEventConfiguration() d return c.subscriptionEventConfiguration } +func (c *pubSubStreamReceiveEventHookContext) NewEvent(data []byte) datasource.MutableStreamEvent { + return c.eventBuilder(data) +} + func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error)) datasource.OnReceiveEventsFn { if fn == nil { return nil } - return func(ctx context.Context, subConf datasource.SubscriptionEventConfiguration, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + return func(ctx context.Context, subConf 
datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { requestContext := getRequestContext(ctx) hookCtx := &pubSubStreamReceiveEventHookContext{ request: requestContext.Request(), @@ -312,6 +344,7 @@ func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, operation: requestContext.Operation(), authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, + eventBuilder: eventBuilder, } newEvts, err := fn(hookCtx, datasource.NewStreamEvents(evts)) return newEvts.Unsafe(), err diff --git a/router/pkg/pubsub/datasource/hooks.go b/router/pkg/pubsub/datasource/hooks.go index e07fc7f81a..a2e53e7183 100644 --- a/router/pkg/pubsub/datasource/hooks.go +++ b/router/pkg/pubsub/datasource/hooks.go @@ -6,11 +6,11 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) -type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration) error +type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error -type OnPublishEventsFn func(ctx context.Context, pubConf PublishEventConfiguration, evts []StreamEvent) ([]StreamEvent, error) +type OnPublishEventsFn func(ctx context.Context, pubConf PublishEventConfiguration, evts []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) -type OnReceiveEventsFn func(ctx context.Context, subConf SubscriptionEventConfiguration, evts []StreamEvent) ([]StreamEvent, error) +type OnReceiveEventsFn func(ctx context.Context, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn, evts []StreamEvent) ([]StreamEvent, error) // Hooks contains hooks for the pubsub providers and data sources type Hooks struct { diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index e234ebfb73..2a898b6ce3 100644 
--- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -7,20 +7,21 @@ import ( ) type PubSubProvider struct { - id string - typeID string - Adapter Adapter - Logger *zap.Logger - hooks Hooks + id string + typeID string + Adapter Adapter + Logger *zap.Logger + hooks Hooks + eventBuilder EventBuilderFn } // applyPublishEventHooks processes events through a chain of hook functions // Each hook receives the result from the previous hook, creating a proper middleware pipeline -func applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, hooks []OnPublishEventsFn) ([]StreamEvent, error) { +func applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn, hooks []OnPublishEventsFn) ([]StreamEvent, error) { currentEvents := events for _, hook := range hooks { var err error - currentEvents, err = hook(ctx, cfg, currentEvents) + currentEvents, err = hook(ctx, cfg, currentEvents, eventBuilder) if err != nil { return currentEvents, err } @@ -59,7 +60,7 @@ func (p *PubSubProvider) Publish(ctx context.Context, cfg PublishEventConfigurat return p.Adapter.Publish(ctx, cfg, events) } - processedEvents, hooksErr := applyPublishEventHooks(ctx, cfg, events, p.hooks.OnPublishEvents) + processedEvents, hooksErr := applyPublishEventHooks(ctx, cfg, events, p.eventBuilder, p.hooks.OnPublishEvents) if hooksErr != nil { p.Logger.Error( "error applying publish event hooks", @@ -82,11 +83,12 @@ func (p *PubSubProvider) SetHooks(hooks Hooks) { p.hooks = hooks } -func NewPubSubProvider(id string, typeID string, adapter Adapter, logger *zap.Logger) *PubSubProvider { +func NewPubSubProvider(id string, typeID string, adapter Adapter, logger *zap.Logger, eventBuilder EventBuilderFn) *PubSubProvider { return &PubSubProvider{ - id: id, - typeID: typeID, - Adapter: adapter, - Logger: logger, + id: id, + typeID: typeID, + Adapter: adapter, + 
Logger: logger, + eventBuilder: eventBuilder, } } diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 939d15ad3f..590297f689 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -76,6 +76,11 @@ func (c *testPublishConfig) RootFieldName() string { return c.fieldName } +// testPubSubEventBuilder is a reusable event builder for tests +func testPubSubEventBuilder(data []byte) MutableStreamEvent { + return mutableTestEvent(data) +} + func TestProvider_Startup_Success(t *testing.T) { mockAdapter := NewMockProvider(t) mockAdapter.On("Startup", mock.Anything).Return(nil) @@ -225,8 +230,13 @@ func TestProvider_Publish_WithHooks_Success(t *testing.T) { &testEvent{mutableTestEvent("modified data")}, } + var eventBuilderExists bool + // Define hook that modifies events - testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { + if eventBuilder != nil { + eventBuilderExists = true + } return modifiedEvents, nil } @@ -237,10 +247,12 @@ func TestProvider_Publish_WithHooks_Success(t *testing.T) { hooks: Hooks{ OnPublishEvents: []OnPublishEventsFn{testHook}, }, + eventBuilder: testPubSubEventBuilder, } err := provider.Publish(context.Background(), config, originalEvents) assert.NoError(t, err) + assert.True(t, eventBuilderExists) } func TestProvider_Publish_WithHooks_HookError(t *testing.T) { @@ -256,7 +268,7 @@ func TestProvider_Publish_WithHooks_HookError(t *testing.T) { hookError := errors.New("hook processing error") // Define hook that returns an error - testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg 
PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return nil, hookError } @@ -268,7 +280,8 @@ func TestProvider_Publish_WithHooks_HookError(t *testing.T) { hooks: Hooks{ OnPublishEvents: []OnPublishEventsFn{testHook}, }, - Logger: zap.NewNop(), + Logger: zap.NewNop(), + eventBuilder: testPubSubEventBuilder, } err := provider.Publish(context.Background(), config, events) @@ -292,7 +305,7 @@ func TestProvider_Publish_WithHooks_AdapterError(t *testing.T) { adapterError := errors.New("adapter publish error") // Define hook that processes events successfully - testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return processedEvents, nil } @@ -303,6 +316,7 @@ func TestProvider_Publish_WithHooks_AdapterError(t *testing.T) { hooks: Hooks{ OnPublishEvents: []OnPublishEventsFn{testHook}, }, + eventBuilder: testPubSubEventBuilder, } err := provider.Publish(context.Background(), config, originalEvents) @@ -322,10 +336,10 @@ func TestProvider_Publish_WithMultipleHooks_Success(t *testing.T) { } // Chain of hooks that modify the data - hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return []StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, nil } - hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return []StreamEvent{&testEvent{mutableTestEvent("modified by hook2")}}, nil } @@ -338,6 +352,7 @@ 
func TestProvider_Publish_WithMultipleHooks_Success(t *testing.T) { hooks: Hooks{ OnPublishEvents: []OnPublishEventsFn{hook1, hook2}, }, + eventBuilder: testPubSubEventBuilder, } err := provider.Publish(context.Background(), config, originalEvents) @@ -347,7 +362,7 @@ func TestProvider_Publish_WithMultipleHooks_Success(t *testing.T) { func TestProvider_SetHooks(t *testing.T) { provider := &PubSubProvider{} - testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return events, nil } @@ -366,7 +381,7 @@ func TestNewPubSubProvider(t *testing.T) { id := "test-provider-id" typeID := "test-type-id" - provider := NewPubSubProvider(id, typeID, mockAdapter, logger) + provider := NewPubSubProvider(id, typeID, mockAdapter, logger, testPubSubEventBuilder) assert.NotNil(t, provider) assert.Equal(t, id, provider.ID()) @@ -387,7 +402,7 @@ func TestApplyPublishEventHooks_NoHooks(t *testing.T) { &testEvent{mutableTestEvent("test data")}, } - result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{}) + result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{}) assert.NoError(t, err) assert.Equal(t, originalEvents, result) @@ -407,11 +422,11 @@ func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) { &testEvent{mutableTestEvent("modified")}, } - hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return modifiedEvents, nil } - result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook}) + result, err := applyPublishEventHooks(ctx, config, 
originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook}) assert.NoError(t, err) assert.Equal(t, modifiedEvents, result) @@ -429,11 +444,11 @@ func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) { } hookError := errors.New("hook processing failed") - hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return nil, hookError } - result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook}) + result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook}) assert.Error(t, err) assert.Equal(t, hookError, err) @@ -451,17 +466,17 @@ func TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) { &testEvent{mutableTestEvent("original")}, } - hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return []StreamEvent{&testEvent{mutableTestEvent("step1")}}, nil } - hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return []StreamEvent{&testEvent{mutableTestEvent("step2")}}, nil } - hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } - result, err := applyPublishEventHooks(ctx, config, 
originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) + result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook1, hook2, hook3}) assert.NoError(t, err) assert.Len(t, result, 1) @@ -480,17 +495,17 @@ func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { } middleHookError := errors.New("middle hook failed") - hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook1 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return []StreamEvent{&testEvent{mutableTestEvent("step1")}}, nil } - hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook2 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return nil, middleHookError } - hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook3 := func(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } - result, err := applyPublishEventHooks(ctx, config, originalEvents, []OnPublishEventsFn{hook1, hook2, hook3}) + result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook1, hook2, hook3}) assert.Error(t, err) assert.Equal(t, middleHookError, err) diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index 16ec03171a..fb35054bd5 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -11,6 +11,8 @@ import ( type uniqueRequestIdFn func(ctx 
*resolve.Context, input []byte, xxh *xxhash.Digest) error +type EventBuilderFn func(data []byte) MutableStreamEvent + // PubSubSubscriptionDataSource is a data source for handling subscriptions using a Pub/Sub mechanism. // It implements the SubscriptionDataSource interface and HookableSubscriptionDataSource type PubSubSubscriptionDataSource[C SubscriptionEventConfiguration] struct { @@ -18,6 +20,7 @@ type PubSubSubscriptionDataSource[C SubscriptionEventConfiguration] struct { uniqueRequestID uniqueRequestIdFn hooks Hooks logger *zap.Logger + eventBuilder EventBuilderFn } func (s *PubSubSubscriptionDataSource[C]) SubscriptionEventConfiguration(input []byte) (SubscriptionEventConfiguration, error) { @@ -41,7 +44,7 @@ func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []by return errors.New("invalid subscription configuration") } - return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(conf, s.hooks, updater, s.logger)) + return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(conf, s.hooks, updater, s.logger, s.eventBuilder)) } func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) { @@ -50,7 +53,7 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.Startu if errConf != nil { return err } - err = fn(ctx, conf) + err = fn(ctx, conf, s.eventBuilder) if err != nil { return err } @@ -66,7 +69,7 @@ func (s *PubSubSubscriptionDataSource[C]) SetHooks(hooks Hooks) { var _ SubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) var _ resolve.HookableSubscriptionDataSource = (*PubSubSubscriptionDataSource[SubscriptionEventConfiguration])(nil) -func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Adapter, uniqueRequestIdFn uniqueRequestIdFn, logger *zap.Logger) *PubSubSubscriptionDataSource[C] { +func NewPubSubSubscriptionDataSource[C 
SubscriptionEventConfiguration](pubSub Adapter, uniqueRequestIdFn uniqueRequestIdFn, logger *zap.Logger, eventBuilder EventBuilderFn) *PubSubSubscriptionDataSource[C] { if logger == nil { logger = zap.NewNop() } @@ -74,5 +77,6 @@ func NewPubSubSubscriptionDataSource[C SubscriptionEventConfiguration](pubSub Ad pubSub: pubSub, uniqueRequestID: uniqueRequestIdFn, logger: logger, + eventBuilder: eventBuilder, } } diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index c82b339faa..8bba79b259 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -31,13 +31,18 @@ func (t testSubscriptionEventConfiguration) RootFieldName() string { return "testSubscription" } +// testSubscriptionDataSourceEventBuilder is a reusable event builder for tests +func testSubscriptionDataSourceEventBuilder(data []byte) MutableStreamEvent { + return mutableTestEvent(data) +} + func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_Success(t *testing.T) { mockAdapter := NewMockProvider(t) uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -62,7 +67,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionEventConfiguration_InvalidJSON return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, 
zap.NewNop(), testSubscriptionDataSourceEventBuilder) invalidInput := []byte(`{"invalid": json}`) result, err := dataSource.SubscriptionEventConfiguration(invalidInput) @@ -76,7 +81,7 @@ func TestPubSubSubscriptionDataSource_UniqueRequestID_Success(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) ctx := &resolve.Context{} input := []byte(`{"test": "data"}`) @@ -93,7 +98,7 @@ func TestPubSubSubscriptionDataSource_UniqueRequestID_Error(t *testing.T) { return expectedError } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) ctx := &resolve.Context{} input := []byte(`{"test": "data"}`) @@ -110,7 +115,7 @@ func TestPubSubSubscriptionDataSource_Start_Success(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -135,7 +140,7 @@ func TestPubSubSubscriptionDataSource_Start_NoConfiguration(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) 
invalidInput := []byte(`{"invalid": json}`) ctx := resolve.NewContext(context.Background()) @@ -152,7 +157,7 @@ func TestPubSubSubscriptionDataSource_Start_SubscribeError(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -179,7 +184,7 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_Success(t *testing.T) return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) testConfig := testSubscriptionEventConfiguration{ Topic: "test-topic", @@ -203,19 +208,27 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Add subscription start hooks hook1Called := false hook2Called := false + hook1EventBuilderExists := false + hook2EventBuilderExists := false - hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { hook1Called = true + if eventBuilder != nil { + hook1EventBuilderExists = true + } return nil } - hook2 := func(ctx resolve.StartupHookContext, config 
SubscriptionEventConfiguration) error { + hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { hook2Called = true + if eventBuilder != nil { + hook2EventBuilderExists = true + } return nil } @@ -239,6 +252,8 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T assert.NoError(t, err) assert.True(t, hook1Called) assert.True(t, hook2Called) + assert.True(t, hook1EventBuilderExists) + assert.True(t, hook2EventBuilderExists) } func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsClose(t *testing.T) { @@ -247,10 +262,10 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsClose(t *te return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Add hook that returns close=true - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { return nil } @@ -280,11 +295,11 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *te return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) expectedError := errors.New("hook error") // Add hook that returns an error - hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) 
error { return expectedError } @@ -315,16 +330,16 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Initially should have no hooks assert.Len(t, dataSource.hooks.SubscriptionOnStart, 0) // Add hooks - hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { return nil } - hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration) error { + hook2 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { return nil } @@ -345,7 +360,7 @@ func TestNewPubSubSubscriptionDataSource(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) assert.NotNil(t, dataSource) assert.Equal(t, mockAdapter, dataSource.pubSub) @@ -359,7 +374,7 @@ func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { return nil } - dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop()) + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Test that it implements SubscriptionDataSource interface var _ SubscriptionDataSource = dataSource diff --git 
a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 1c920f7222..5c1ee69ac6 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -23,6 +23,7 @@ type subscriptionEventUpdater struct { subscriptionEventConfiguration SubscriptionEventConfiguration hooks Hooks logger *zap.Logger + eventBuilder EventBuilderFn } func (s *subscriptionEventUpdater) Update(events []StreamEvent) { @@ -80,7 +81,7 @@ func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *s // modify events with hooks var err error for i := range hooks { - events, err = hooks[i](ctx, s.subscriptionEventConfiguration, events) + events, err = hooks[i](ctx, s.subscriptionEventConfiguration, s.eventBuilder, events) if err != nil { errCh <- err } @@ -134,11 +135,13 @@ func NewSubscriptionEventUpdater( hooks Hooks, eventUpdater resolve.SubscriptionUpdater, logger *zap.Logger, + eventBuilder EventBuilderFn, ) SubscriptionEventUpdater { return &subscriptionEventUpdater{ subscriptionEventConfiguration: cfg, hooks: hooks, eventUpdater: eventUpdater, logger: logger, + eventBuilder: eventBuilder, } } diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 693d9d5da0..283c624310 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -37,6 +37,11 @@ type receivedHooksArgs struct { cfg SubscriptionEventConfiguration } +// testEventBuilder is a reusable event builder for tests +func testEventBuilder(data []byte) MutableStreamEvent { + return mutableTestEvent(data) +} + func TestSubscriptionEventUpdater_Update_NoHooks(t *testing.T) { mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ @@ -78,7 +83,7 @@ func 
TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testin // Create wrapper function for the mock receivedArgs := make(chan receivedHooksArgs, 1) - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { receivedArgs <- receivedHooksArgs{events: events, cfg: cfg} return modifiedEvents, nil } @@ -96,6 +101,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testin hooks: Hooks{ OnReceiveEvents: []OnReceiveEventsFn{testHook}, }, + eventBuilder: testEventBuilder, } updater.Update(originalEvents) @@ -122,7 +128,7 @@ func TestSubscriptionEventUpdater_UpdateSubscriptions_WithHooks_Error(t *testing hookError := errors.New("hook processing error") // Define hook that returns an error - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return nil, hookError } @@ -140,6 +146,7 @@ func TestSubscriptionEventUpdater_UpdateSubscriptions_WithHooks_Error(t *testing hooks: Hooks{ OnReceiveEvents: []OnReceiveEventsFn{testHook}, }, + eventBuilder: testEventBuilder, } updater.Update(events) @@ -163,13 +170,13 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) // Chain of hooks that modify the data receivedArgs1 := make(chan receivedHooksArgs, 1) - hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} return 
[]StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, nil } receivedArgs2 := make(chan receivedHooksArgs, 1) - hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} return []StreamEvent{&testEvent{mutableTestEvent("modified by hook2")}}, nil } @@ -187,6 +194,7 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) hooks: Hooks{ OnReceiveEvents: []OnReceiveEventsFn{hook1, hook2}, }, + eventBuilder: testEventBuilder, } updater.Update(originalEvents) @@ -255,7 +263,7 @@ func TestSubscriptionEventUpdater_SetHooks(t *testing.T) { fieldName: "testField", } - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return events, nil } @@ -267,6 +275,7 @@ func TestSubscriptionEventUpdater_SetHooks(t *testing.T) { eventUpdater: mockUpdater, subscriptionEventConfiguration: config, hooks: Hooks{}, + eventBuilder: testEventBuilder, } updater.SetHooks(hooks) @@ -282,7 +291,7 @@ func TestNewSubscriptionEventUpdater(t *testing.T) { fieldName: "testField", } - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return events, nil } @@ -290,7 +299,7 @@ func TestNewSubscriptionEventUpdater(t *testing.T) { OnReceiveEvents: []OnReceiveEventsFn{testHook}, } - updater := NewSubscriptionEventUpdater(config, hooks, mockUpdater, zap.NewNop()) + 
updater := NewSubscriptionEventUpdater(config, hooks, mockUpdater, zap.NewNop(), testEventBuilder) assert.NotNil(t, updater) @@ -349,7 +358,7 @@ func TestSubscriptionEventUpdater_Update_WithSingleHookModification(t *testing.T } // Hook that modifies events by adding a prefix - hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { modifiedEvents := make([]StreamEvent, len(events)) for i, event := range events { modifiedData := "modified: " + string(event.GetData()) @@ -398,7 +407,7 @@ func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscriptionA hookError := errors.New("hook processing failed") // Hook that returns an error - hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { // Return the events but also return an error return events, hookError } @@ -418,7 +427,7 @@ func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscriptionA updater := NewSubscriptionEventUpdater(config, Hooks{ OnReceiveEvents: []OnReceiveEventsFn{hook}, - }, mockUpdater, logger) + }, mockUpdater, logger, testEventBuilder) updater.Update(events) @@ -457,7 +466,7 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) // Hook 1: Adds "step1: " prefix receivedArgs1 := make(chan receivedHooksArgs, 1) - hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { mu.Lock() hookCallOrder = append(hookCallOrder, 1) 
mu.Unlock() @@ -472,7 +481,7 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) // Hook 2: Adds "step2: " prefix receivedArgs2 := make(chan receivedHooksArgs, 1) - hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { mu.Lock() hookCallOrder = append(hookCallOrder, 2) mu.Unlock() @@ -487,7 +496,7 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) // Hook 3: Adds "step3: " prefix receivedArgs3 := make(chan receivedHooksArgs, 1) - hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { mu.Lock() hookCallOrder = append(hookCallOrder, 3) mu.Unlock() @@ -640,7 +649,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHookError_ClosesSubscri &testEvent{mutableTestEvent("test data")}, } - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return events, tc.hookError } @@ -650,6 +659,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHookError_ClosesSubscri hooks: Hooks{ OnReceiveEvents: []OnReceiveEventsFn{testHook}, }, + eventBuilder: testEventBuilder, } subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} @@ -679,7 +689,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWrite hookError := errors.New("hook processing error") // Define hook that returns an error - testHook := func(ctx context.Context, 
cfg SubscriptionEventConfiguration, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return nil, hookError } @@ -690,7 +700,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWrite // The logger.Error() call should be executed when an error occurs updater := NewSubscriptionEventUpdater(config, Hooks{ OnReceiveEvents: []OnReceiveEventsFn{testHook}, - }, mockUpdater, logger) + }, mockUpdater, logger, testEventBuilder) subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ diff --git a/router/pkg/pubsub/kafka/engine_datasource_factory.go b/router/pkg/pubsub/kafka/engine_datasource_factory.go index b4e1356714..e6bebecbfc 100644 --- a/router/pkg/pubsub/kafka/engine_datasource_factory.go +++ b/router/pkg/pubsub/kafka/engine_datasource_factory.go @@ -63,27 +63,33 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri } func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { - return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( - c.KafkaAdapter, - func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - val, _, _, err := jsonparser.Get(input, "topics") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } - - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } - - _, err = xxh.Write(val) + uniqueRequestIdFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "topics") + if err != nil { + return err + } + + _, err = xxh.Write(val) + if err != nil { + return err + } + + val, _, _, err = jsonparser.Get(input, "providerId") 
+ if err != nil { return err - }, c.logger), nil + } + + _, err = xxh.Write(val) + return err + } + + eventCreateFn := func(data []byte) datasource.MutableStreamEvent { + return &MutableEvent{Data: data} + } + + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.KafkaAdapter, uniqueRequestIdFn, c.logger, eventCreateFn, + ), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { diff --git a/router/pkg/pubsub/kafka/provider_builder.go b/router/pkg/pubsub/kafka/provider_builder.go index c69a458eba..29dc51d579 100644 --- a/router/pkg/pubsub/kafka/provider_builder.go +++ b/router/pkg/pubsub/kafka/provider_builder.go @@ -153,11 +153,17 @@ func buildProvider(ctx context.Context, provider config.KafkaEventSource, logger if err != nil { return nil, fmt.Errorf("failed to build options for Kafka provider with ID \"%s\": %w", provider.ID, err) } + adapter, err := NewProviderAdapter(ctx, logger, kafkaOpts, providerOpts) if err != nil { return nil, fmt.Errorf("failed to create adapter for Kafka provider with ID \"%s\": %w", provider.ID, err) } - pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, logger) + + eventBuilder := func(data []byte) datasource.MutableStreamEvent { + return &MutableEvent{Data: data} + } + + pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, logger, eventBuilder) return pubSubProvider, nil } diff --git a/router/pkg/pubsub/nats/engine_datasource_factory.go b/router/pkg/pubsub/nats/engine_datasource_factory.go index f4006448dd..8aa644814b 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory.go @@ -76,27 +76,33 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri } func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { - return 
datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( - c.NatsAdapter, - func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - val, _, _, err := jsonparser.Get(input, "subjects") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } + uniqueRequestIdFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "subjects") + if err != nil { + return err + } - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } + _, err = xxh.Write(val) + if err != nil { + return err + } - _, err = xxh.Write(val) + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { return err - }, c.logger), nil + } + + _, err = xxh.Write(val) + return err + } + + createEventFn := func(data []byte) datasource.MutableStreamEvent { + return &MutableEvent{Data: data} + } + + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.NatsAdapter, + uniqueRequestIdFn, c.logger, createEventFn), nil } func (c *EngineDataSourceFactory) ResolveDataSourceSubscriptionInput() (string, error) { diff --git a/router/pkg/pubsub/nats/engine_datasource_factory_test.go b/router/pkg/pubsub/nats/engine_datasource_factory_test.go index 053ff0d702..d6017404a7 100644 --- a/router/pkg/pubsub/nats/engine_datasource_factory_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_factory_test.go @@ -19,6 +19,11 @@ import ( "go.uber.org/zap" ) +// testNatsEventBuilder is a reusable event builder for tests +func testNatsEventBuilder(data []byte) datasource.MutableStreamEvent { + return &MutableEvent{Data: data} +} + func TestNatsEngineDataSourceFactory(t *testing.T) { // Create the data source to test with a real adapter adapter := &ProviderAdapter{} @@ -172,7 +177,7 @@ func TestNatsEngineDataSourceFactoryWithStreamConfiguration(t *testing.T) { func TestEngineDataSourceFactory_RequestDataSource(t *testing.T) { 
// Create mock adapter mockAdapter := NewMockAdapter(t) - provider := datasource.NewPubSubProvider("test-provider", "nats", mockAdapter, zap.NewNop()) + provider := datasource.NewPubSubProvider("test-provider", "nats", mockAdapter, zap.NewNop(), testNatsEventBuilder) // Configure mock expectations for Request mockAdapter.On("Request", mock.Anything, mock.MatchedBy(func(event *PublishAndRequestEventConfiguration) bool { diff --git a/router/pkg/pubsub/nats/engine_datasource_test.go b/router/pkg/pubsub/nats/engine_datasource_test.go index 183179c083..1f0818d305 100644 --- a/router/pkg/pubsub/nats/engine_datasource_test.go +++ b/router/pkg/pubsub/nats/engine_datasource_test.go @@ -193,7 +193,7 @@ func TestNatsRequestDataSource_Load(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockAdapter := NewMockAdapter(t) - provider := datasource.NewPubSubProvider("test-provider", "nats", mockAdapter, zap.NewNop()) + provider := datasource.NewPubSubProvider("test-provider", "nats", mockAdapter, zap.NewNop(), testNatsEventBuilder) tt.mockSetup(mockAdapter) dataSource := &NatsRequestDataSource{ diff --git a/router/pkg/pubsub/nats/provider_builder.go b/router/pkg/pubsub/nats/provider_builder.go index 2b07c4217a..d8314305ae 100644 --- a/router/pkg/pubsub/nats/provider_builder.go +++ b/router/pkg/pubsub/nats/provider_builder.go @@ -122,11 +122,17 @@ func buildProvider(ctx context.Context, provider config.NatsEventSource, logger if err != nil { return nil, fmt.Errorf("failed to build options for Nats provider with ID \"%s\": %w", provider.ID, err) } + adapter, err := NewAdapter(ctx, logger, provider.URL, options, hostName, routerListenAddr, providerOpts) if err != nil { return nil, fmt.Errorf("failed to create adapter for Nats provider with ID \"%s\": %w", provider.ID, err) } - pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, logger) + + eventBuilder := func(data []byte) datasource.MutableStreamEvent { + return 
&MutableEvent{Data: data} + } + + pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, logger, eventBuilder) return pubSubProvider, nil } diff --git a/router/pkg/pubsub/redis/engine_datasource_factory.go b/router/pkg/pubsub/redis/engine_datasource_factory.go index 1e9f9866e4..357cc5aff1 100644 --- a/router/pkg/pubsub/redis/engine_datasource_factory.go +++ b/router/pkg/pubsub/redis/engine_datasource_factory.go @@ -74,27 +74,33 @@ func (c *EngineDataSourceFactory) ResolveDataSourceInput(eventData []byte) (stri // ResolveDataSourceSubscription returns the subscription data source func (c *EngineDataSourceFactory) ResolveDataSourceSubscription() (datasource.SubscriptionDataSource, error) { - return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( - c.RedisAdapter, - func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { - val, _, _, err := jsonparser.Get(input, "channels") - if err != nil { - return err - } - - _, err = xxh.Write(val) - if err != nil { - return err - } + uniqueRequestIdFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + val, _, _, err := jsonparser.Get(input, "channels") + if err != nil { + return err + } - val, _, _, err = jsonparser.Get(input, "providerId") - if err != nil { - return err - } + _, err = xxh.Write(val) + if err != nil { + return err + } - _, err = xxh.Write(val) + val, _, _, err = jsonparser.Get(input, "providerId") + if err != nil { return err - }, c.logger), nil + } + + _, err = xxh.Write(val) + return err + } + + eventCreateFn := func(data []byte) datasource.MutableStreamEvent { + return &MutableEvent{Data: data} + } + + return datasource.NewPubSubSubscriptionDataSource[*SubscriptionEventConfiguration]( + c.RedisAdapter, uniqueRequestIdFn, c.logger, eventCreateFn, + ), nil } // ResolveDataSourceSubscriptionInput builds the input for the subscription data source diff --git a/router/pkg/pubsub/redis/provider_builder.go 
b/router/pkg/pubsub/redis/provider_builder.go index f8814b7d42..457dd49b22 100644 --- a/router/pkg/pubsub/redis/provider_builder.go +++ b/router/pkg/pubsub/redis/provider_builder.go @@ -71,7 +71,11 @@ func (b *ProviderBuilder) BuildEngineDataSourceFactory(data *nodev1.RedisEventCo // Providers returns the Redis PubSub providers for the given provider IDs func (b *ProviderBuilder) BuildProvider(provider config.RedisEventSource, providerOpts datasource.ProviderOpts) (datasource.Provider, error) { adapter := NewProviderAdapter(b.ctx, b.logger, provider.URLs, provider.ClusterEnabled, providerOpts) - pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, b.logger) + eventBuilder := func(data []byte) datasource.MutableStreamEvent { + return &MutableEvent{Data: data} + } + + pubSubProvider := datasource.NewPubSubProvider(provider.ID, providerTypeID, adapter, b.logger, eventBuilder) return pubSubProvider, nil } From 7020faf29d1b2a35b0e7564957c2594da21582fc Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Thu, 30 Oct 2025 18:50:13 +0000 Subject: [PATCH 06/44] fix(router): recover from panics in hooks (#2311) --- .../pkg/pubsub/datasource/pubsubprovider.go | 48 +++++--- .../pubsub/datasource/pubsubprovider_test.go | 107 +++++++++++++++++- .../datasource/subscription_datasource.go | 18 +++ .../subscription_datasource_test.go | 69 +++++++++++ .../datasource/subscription_event_updater.go | 17 ++- .../subscription_event_updater_test.go | 78 +++++++++++++ 6 files changed, 317 insertions(+), 20 deletions(-) diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index 2a898b6ce3..3697229182 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -2,8 +2,10 @@ package datasource import ( "context" + "fmt" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) type PubSubProvider struct { @@ -17,12 
+19,39 @@ type PubSubProvider struct { // applyPublishEventHooks processes events through a chain of hook functions // Each hook receives the result from the previous hook, creating a proper middleware pipeline -func applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn, hooks []OnPublishEventsFn) ([]StreamEvent, error) { - currentEvents := events - for _, hook := range hooks { +func (p *PubSubProvider) applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) (currentEvents []StreamEvent, err error) { + defer func() { + if r := recover(); r != nil { + if p.Logger != nil { + p.Logger. + WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)). + Error("[Recovery from handler panic]", + zap.Any("error", r), + ) + } + + switch v := r.(type) { + case error: + err = v + default: + err = fmt.Errorf("%v", r) + } + } + }() + + currentEvents = events + for _, hook := range p.hooks.OnPublishEvents { var err error - currentEvents, err = hook(ctx, cfg, currentEvents, eventBuilder) + currentEvents, err = hook(ctx, cfg, currentEvents, p.eventBuilder) if err != nil { + p.Logger.Error( + "error applying publish event hooks", + zap.Error(err), + zap.String("provider_id", cfg.ProviderID()), + zap.String("provider_type_id", string(cfg.ProviderType())), + zap.String("field_name", cfg.RootFieldName()), + ) + return currentEvents, err } } @@ -60,16 +89,7 @@ func (p *PubSubProvider) Publish(ctx context.Context, cfg PublishEventConfigurat return p.Adapter.Publish(ctx, cfg, events) } - processedEvents, hooksErr := applyPublishEventHooks(ctx, cfg, events, p.eventBuilder, p.hooks.OnPublishEvents) - if hooksErr != nil { - p.Logger.Error( - "error applying publish event hooks", - zap.Error(hooksErr), - zap.String("provider_id", cfg.ProviderID()), - zap.String("provider_type_id", string(cfg.ProviderType())), - zap.String("field_name", cfg.RootFieldName()), - ) - } + processedEvents, hooksErr 
:= p.applyPublishEventHooks(ctx, cfg, events) errPublish := p.Adapter.Publish(ctx, cfg, processedEvents) if errPublish != nil { diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 590297f689..e623de93b0 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -401,8 +401,14 @@ func TestApplyPublishEventHooks_NoHooks(t *testing.T) { originalEvents := []StreamEvent{ &testEvent{mutableTestEvent("test data")}, } + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{}, + }, + } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{}) + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.NoError(t, err) assert.Equal(t, originalEvents, result) @@ -426,7 +432,14 @@ func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) { return modifiedEvents, nil } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook}) + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.NoError(t, err) assert.Equal(t, modifiedEvents, result) @@ -448,7 +461,14 @@ func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) { return nil, hookError } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook}) + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.Error(t, err) assert.Equal(t, hookError, err) @@ -476,7 +496,14 @@ func 
TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) { return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook1, hook2, hook3}) + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook1, hook2, hook3}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.NoError(t, err) assert.Len(t, result, 1) @@ -505,9 +532,79 @@ func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { return []StreamEvent{&testEvent{mutableTestEvent("final")}}, nil } - result, err := applyPublishEventHooks(ctx, config, originalEvents, testPubSubEventBuilder, []OnPublishEventsFn{hook1, hook2, hook3}) + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook1, hook2, hook3}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) assert.Error(t, err) assert.Equal(t, middleHookError, err) assert.Nil(t, result) } + +func TestApplyPublishEventHooks_PanicRecovery(t *testing.T) { + panicErr := errors.New("panic error") + + tests := []struct { + name string + panicValue any + expectedErr error + expectedErrText string + }{ + { + name: "error type", + panicValue: panicErr, + expectedErr: panicErr, + }, + { + name: "string type", + panicValue: "panic string message", + expectedErrText: "panic string message", + }, + { + name: "other type", + panicValue: 42, + expectedErrText: "42", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + config := &testPublishConfig{ + providerID: "test-provider", + providerType: ProviderTypeKafka, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{mutableTestEvent("original")}, + } + + hook := func(ctx context.Context, cfg 
PublishEventConfiguration, events []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) { + panic(tt.panicValue) + } + + provider := &PubSubProvider{ + Logger: zap.NewNop(), + hooks: Hooks{ + OnPublishEvents: []OnPublishEventsFn{hook}, + }, + } + + result, err := provider.applyPublishEventHooks(ctx, config, originalEvents) + + assert.Error(t, err) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + } + if tt.expectedErrText != "" { + assert.Contains(t, err.Error(), tt.expectedErrText) + } + assert.Equal(t, originalEvents, result) + }) + } +} diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index fb35054bd5..c625af9c33 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -3,10 +3,12 @@ package datasource import ( "encoding/json" "errors" + "fmt" "github.com/cespare/xxhash/v2" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) type uniqueRequestIdFn func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error @@ -48,6 +50,22 @@ func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []by } func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) { + defer func() { + if r := recover(); r != nil { + s.logger. + WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)). 
+ Error("[Recovery from handler panic]", + zap.Any("error", r), + ) + switch v := r.(type) { + case error: + err = v + default: + err = fmt.Errorf("%v", r) + } + } + }() + for _, fn := range s.hooks.SubscriptionOnStart { conf, errConf := s.SubscriptionEventConfiguration(input) if errConf != nil { diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index 8bba79b259..a292f4b0f4 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -382,3 +382,72 @@ func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { // Test that it implements HookableSubscriptionDataSource interface var _ resolve.HookableSubscriptionDataSource = dataSource } + +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_PanicRecovery(t *testing.T) { + panicErr := errors.New("panic error") + + tests := []struct { + name string + panicValue any + expectedErr error + expectedErrText string + }{ + { + name: "error type", + panicValue: panicErr, + expectedErr: panicErr, + }, + { + name: "string type", + panicValue: "panic string message", + expectedErrText: "panic string message", + }, + { + name: "other type", + panicValue: 42, + expectedErrText: "42", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) + + // Add hook that panics + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + panic(tt.panicValue) + } + + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + }) + + 
testConfig := testSubscriptionEventConfiguration{ + Topic: "test-topic", + Subject: "test-subject", + } + input, err := json.Marshal(testConfig) + assert.NoError(t, err) + + hookCtx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err = dataSource.SubscriptionOnStart(hookCtx, input) + + assert.Error(t, err) + if tt.expectedErr != nil { + assert.Equal(t, tt.expectedErr, err) + } + if tt.expectedErrText != "" { + assert.Contains(t, err.Error(), tt.expectedErrText) + } + }) + } +} diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 5c1ee69ac6..615354ba1a 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -6,6 +6,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" + "go.uber.org/zap/zapcore" ) // SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface @@ -73,7 +74,10 @@ func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *sync.WaitGroup, errCh chan error, semaphore chan struct{}, subID resolve.SubscriptionIdentifier, events []StreamEvent) { defer wg.Done() defer func() { - <-semaphore // Release the slot when done + if r := recover(); r != nil { + s.recoverPanic(subID, r) + } + <-semaphore // release the slot when done }() hooks := s.hooks.OnReceiveEvents @@ -100,6 +104,17 @@ func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *s } } +func (s *subscriptionEventUpdater) recoverPanic(subID resolve.SubscriptionIdentifier, err any) { + s.logger. + WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)). 
+ Error("[Recovery from handler panic]", + zap.Int64("subscription_id", subID.SubscriptionID), + zap.Any("error", err), + ) + + s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindDownstreamServiceError, subID) +} + // deduplicateAndLogErrors collects errors from errCh // and deduplicates them based on their err.Error() value. // Afterwards it uses s.logger to log the message. diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 283c624310..2c0295dd1c 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -719,3 +719,81 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWrite return len(logObserver.FilterMessageSnippet("some handlers have thrown an error").TakeAll()) == 1 }, time.Second, 10*time.Millisecond, "expected one deduplicated error log") } + +func TestSubscriptionEventUpdater_OnReceiveEvents_PanicRecovery(t *testing.T) { + panicErr := errors.New("panic error") + + tests := []struct { + name string + panicValue any + }{ + { + name: "error type", + panicValue: panicErr, + }, + { + name: "string type", + panicValue: "panic string message", + }, + { + name: "other type", + panicValue: 42, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + core, logObserver := observer.New(zap.InfoLevel) + logger := zap.New(core) + + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + events := []StreamEvent{ + &testEvent{mutableTestEvent("test data")}, + } + + // Create hook that panics + testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + panic(tt.panicValue) + } + + subId := 
resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindDownstreamServiceError, subId).Return() + + updater := &subscriptionEventUpdater{ + eventUpdater: mockUpdater, + subscriptionEventConfiguration: config, + hooks: Hooks{ + OnReceiveEvents: []OnReceiveEventsFn{testHook}, + }, + logger: logger, + } + + updater.Update(events) + + // Wait for async processing to complete and assert panic was logged + assert.Eventually(t, func() bool { + logs := logObserver.FilterMessage("[Recovery from handler panic]").All() + return len(logs) == 1 + }, 10*time.Millisecond, time.Millisecond, "expected panic recovery log") + + // Assert that subscription was closed due to panic + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindDownstreamServiceError, subId) + mockUpdater.AssertNotCalled(t, "UpdateSubscription") + + // Assert that panic was logged with correct details + logs := logObserver.FilterMessage("[Recovery from handler panic]").All() + assert.Len(t, logs, 1) + assert.Equal(t, zap.ErrorLevel, logs[0].Level) + assert.Equal(t, int64(1), logs[0].ContextMap()["subscription_id"]) + assert.NotNil(t, logs[0].ContextMap()["error"]) + }) + } +} From e6169f3516def45097e7d09ff28a616a3d7dc2e2 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 31 Oct 2025 12:08:22 +0000 Subject: [PATCH 07/44] fix(router): ignore nil events (#2315) --- router/pkg/pubsub/datasource/pubsubprovider.go | 5 +++++ router/pkg/pubsub/datasource/pubsubprovider_test.go | 10 +++++++++- .../pubsub/datasource/subscription_event_updater.go | 5 +++++ .../datasource/subscription_event_updater_test.go | 1 + 4 files changed, 20 insertions(+), 1 deletion(-) diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go 
b/router/pkg/pubsub/datasource/pubsubprovider.go index 3697229182..e20f1ace2b 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -3,6 +3,7 @@ package datasource import ( "context" "fmt" + "slices" "go.uber.org/zap" "go.uber.org/zap/zapcore" @@ -43,6 +44,10 @@ func (p *PubSubProvider) applyPublishEventHooks(ctx context.Context, cfg Publish for _, hook := range p.hooks.OnPublishEvents { var err error currentEvents, err = hook(ctx, cfg, currentEvents, p.eventBuilder) + currentEvents = slices.DeleteFunc(currentEvents, func(event StreamEvent) bool { + return event == nil + }) + if err != nil { p.Logger.Error( "error applying publish event hooks", diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index e623de93b0..0bf12e7f60 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -225,9 +225,17 @@ func TestProvider_Publish_WithHooks_Success(t *testing.T) { } originalEvents := []StreamEvent{ &testEvent{mutableTestEvent("original data")}, + &testEvent{mutableTestEvent("original data 2")}, } modifiedEvents := []StreamEvent{ &testEvent{mutableTestEvent("modified data")}, + nil, // should be ignored by publisher + &testEvent{mutableTestEvent("modified data 2")}, + nil, // should be ignored by publisher + } + expectedEvents := []StreamEvent{ + &testEvent{mutableTestEvent("modified data")}, + &testEvent{mutableTestEvent("modified data 2")}, } var eventBuilderExists bool @@ -240,7 +248,7 @@ func TestProvider_Publish_WithHooks_Success(t *testing.T) { return modifiedEvents, nil } - mockAdapter.On("Publish", mock.Anything, config, modifiedEvents).Return(nil) + mockAdapter.On("Publish", mock.Anything, config, expectedEvents).Return(nil) provider := PubSubProvider{ Adapter: mockAdapter, diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go 
b/router/pkg/pubsub/datasource/subscription_event_updater.go index 615354ba1a..5ed4a6c837 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -2,6 +2,7 @@ package datasource import ( "context" + "slices" "sync" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -86,6 +87,10 @@ func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *s var err error for i := range hooks { events, err = hooks[i](ctx, s.subscriptionEventConfiguration, s.eventBuilder, events) + events = slices.DeleteFunc(events, func(event StreamEvent) bool { + return event == nil + }) + if err != nil { errCh <- err } diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 2c0295dd1c..1b6c1bd3a7 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -79,6 +79,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testin } modifiedEvents := []StreamEvent{ &testEvent{mutableTestEvent("modified data")}, + nil, // this should simply be ignored } // Create wrapper function for the mock From a4df54987fa1f838b10e20b270357b8dab5b8758 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 31 Oct 2025 12:20:48 +0000 Subject: [PATCH 08/44] chore: update Cosmo Streams ADR (#2314) --- adr/cosmo-streams-v1.md | 155 ++++++++++++++++++++++++++-------------- 1 file changed, 100 insertions(+), 55 deletions(-) diff --git a/adr/cosmo-streams-v1.md b/adr/cosmo-streams-v1.md index 21b035ff0b..dfacf34c40 100644 --- a/adr/cosmo-streams-v1.md +++ b/adr/cosmo-streams-v1.md @@ -7,7 +7,7 @@ status: Accepted # ADR - Cosmo Streams V1 -- **Author:** Alessandro Pagnin +- **Author:** Alessandro Pagnin, Dominik Korittki - **Date:** 2025-07-16 - 
**Status:** Accepted - **RFC:** ../rfcs/cosmo-streams-v1.md @@ -17,7 +17,7 @@ This ADR describes new hooks that will be added to the router to support more cu The goal is to allow developers to customize the cosmo streams behavior. ## Decision -The following interfaces will extend the existing logic in the custom modules. +The following interfaces will extend the existing logic in custom modules. These provide additional control over subscriptions by providing hooks, which are invoked during specific events. - `SubscriptionOnStartHandler`: Called once at subscription start. @@ -25,7 +25,7 @@ These provide additional control over subscriptions by providing hooks, which ar - `StreamPublishEventHandler`: Called each time a batch of events is going to be sent to the provider. ```go -// STRUCTURES TO BE ADDED TO PUBSUB PACKAGE +// STRUCTURES TO BE ADDED TO PUBSUB/DATASOURCE PACKAGE type ProviderType string const ( ProviderTypeNats ProviderType = "nats" @@ -33,23 +33,44 @@ const ( ProviderTypeRedis ProviderType = "redis" } -// OperationContext already exists, we just have to add the Variables() method +// OperationContext provides information about the GraphQL operation type OperationContext interface { Name() string - // the variables are currently not available, so we need to expose them here Variables() *astjson.Value } -// each provider will have its own event type with custom fields -// the StreamEvent interface is used to allow the hooks system to be provider-agnostic +// StreamEvents is a wrapper around a list of stream events providing safe iteration +type StreamEvents struct { + evts []StreamEvent +} + +func (e StreamEvents) All() iter.Seq2[int, StreamEvent] // iterator for all events +func (e StreamEvents) Len() int // returns the number of events +func (e StreamEvents) Unsafe() []StreamEvent // returns the underlying slice + +func NewStreamEvents(evts []StreamEvent) StreamEvents + +// StreamEvent is a generic immutable event. 
+// Every provider will have it's distinct implementation with additionals fields. +// Common to all providers is that their events have a payload. type StreamEvent interface { + // GetData returns a copy of payload data of the event GetData() []byte + // Clone returns a mutable copy of the event + Clone() MutableStreamEvent +} + +// MutableStreamEvent is a StreamEvent that can be modified. +type MutableStreamEvent interface { + StreamEvent + // SetData sets the payload data for this event + SetData([]byte) } // SubscriptionEventConfiguration is the common interface for the subscription event configuration type SubscriptionEventConfiguration interface { ProviderID() string - ProviderType() string + ProviderType() ProviderType // the root field name of the subscription in the schema RootFieldName() string } @@ -57,7 +78,7 @@ type SubscriptionEventConfiguration interface { // PublishEventConfiguration is the common interface for the publish event configuration type PublishEventConfiguration interface { ProviderID() string - ProviderType() string + ProviderType() ProviderType // the root field name of the mutation in the schema RootFieldName() string } @@ -76,6 +97,8 @@ type SubscriptionOnStartHandlerContext interface { // WriteEvent writes an event to the stream of the current subscription // It returns true if the event was written to the stream, false if the event was dropped WriteEvent(event datasource.StreamEvent) bool + // NewEvent creates a new event that can be used in the subscription. + NewEvent(data []byte) datasource.MutableStreamEvent } type SubscriptionOnStartHandler interface { @@ -95,14 +118,20 @@ type StreamReceiveEventHandlerContext interface { Authentication() authentication.Authentication // SubscriptionEventConfiguration is the subscription event configuration SubscriptionEventConfiguration() SubscriptionEventConfiguration + // NewEvent creates a new event that can be used in the subscription. 
+ NewEvent(data []byte) datasource.MutableStreamEvent } type StreamReceiveEventHandler interface { - // OnReceiveEvents is called each time a batch of events is received from the provider before delivering them to the client - // So for a single batch of events received from the provider, this hook will be called one time for each active subscription. - // It is important to optimize the logic inside this hook to avoid performance issues. - // Returning an error will result in a GraphQL error being returned to the client - OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []StreamEvent) ([]StreamEvent, error) + // OnReceiveEvents is called whenever a batch of events is received from a provider, + // before delivering them to clients. + // The hook will be called once for each active subscription, therefore it is advised to + // avoid resource heavy computation or blocking tasks whenever possible. + // The events argument contains all events from a batch and is shared between + // all active subscribers of these events. + // Use events.All() to iterate through them and event.Clone() to create mutable copies, when needed. + // Returning an error will result in the subscription being closed and the error being logged. + OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events StreamEvents) (StreamEvents, error) } type StreamPublishEventHandlerContext interface { @@ -116,21 +145,40 @@ type StreamPublishEventHandlerContext interface { Authentication() authentication.Authentication // PublishEventConfiguration is the publish event configuration PublishEventConfiguration() PublishEventConfiguration + // NewEvent creates a new event that can be used in the subscription. 
+ NewEvent(data []byte) datasource.MutableStreamEvent } type StreamPublishEventHandler interface { - // OnPublishEvents is called each time a batch of events is going to be sent to the provider - // Returning an error will result in an error being returned and the client will see the mutation failing - OnPublishEvents(ctx StreamPublishEventHandlerContext, events []StreamEvent) ([]StreamEvent, error) + // OnPublishEvents is called each time a batch of events is going to be sent to a provider. + // The events argument contains all events from a batch. + // Use events.All() to iterate through them and event.Clone() to create mutable copies, when needed. + // Returning an error will result in a GraphQL error being returned to the client. + OnPublishEvents(ctx StreamPublishEventHandlerContext, events StreamEvents) (StreamEvents, error) } ``` +## Immutable vs Mutable events + +The design of `StreamEvent` and `MutableStreamEvent` interfaces addresses a critical performance and safety trade-off in the event handling system. When events are received from a provider, they are typically delivered to multiple active subscriptions simultaneously. The `OnReceiveEvents` handler is called once for each active subscription, meaning the same batch of events needs to be processed by multiple handlers concurrently. + +The primary design challenge was avoiding unnecessary memory allocations and data copying while maintaining safety guarantees. If we automatically created a deep copy of all events before each handler invocation, the performance cost would be significant, especially under high load with many active subscriptions. However, if we simply passed mutable references to all handlers, we would risk handlers inadvertently modifying shared event data, causing unexpected behavior for other subscribers processing the same events. + +The current solution leverages immutability as the default behavior with explicit opt-in mutability. 
The `StreamEvent` interface is designed to be immutable: the `GetData()` method returns a copy of the payload data, ensuring that read operations are safe by default. When a handler needs to modify an event, it must explicitly call the `Clone()` method to obtain a `MutableStreamEvent`. This creates a conscious decision point where developers understand they are creating a new copy that can be safely modified without affecting other subscriptions. + +The `MutableStreamEvent` interface extends `StreamEvent` and adds the `SetData()` method, allowing modifications only on explicitly cloned copies. This design pattern ensures that: +1. Handlers that only read event data incur no copying overhead +2. Multiple subscriptions can safely share the same underlying event data +3. Modifications are isolated to the specific subscription that cloned the event +4. The API makes the performance implications of cloning explicit and intentional + ## Example Use Cases - **Authorization**: Implementing authorization checks at the start of subscriptions - **Initial message**: Sending an initial message to clients upon subscription start - **Data mapping**: Transforming events data from the format that could be used by the external system to/from Federation compatible Router events - **Event filtering**: Filtering events using custom logic +- **Event creation**: Creating new events from scratch using `ctx.NewEvent(data)` method available in all handler contexts ## Backwards Compatibility @@ -144,7 +192,7 @@ When the new module system will be released, the Cosmo Streams hooks: # Example Modules -__All examples are pseudocode and not tested, but they are as close as possible to the final implementation__ +__All examples reflect the current implementation and match the actual API__ ## Filter and remap events @@ -174,9 +222,10 @@ package mymodule import ( "encoding/json" + "fmt" "slices" "github.com/wundergraph/cosmo/router/core" - "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" + 
"github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) func init() { @@ -186,9 +235,9 @@ func init() { type MyModule struct {} -func (m *MyModule) OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events []core.StreamEvent) ([]core.StreamEvent, error) { +func (m *MyModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { // check if the provider is nats - if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + if ctx.SubscriptionEventConfiguration().ProviderType() != datasource.ProviderTypeNats { return events, nil } @@ -202,33 +251,33 @@ func (m *MyModule) OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events return events, nil } - newEvents := make([]core.StreamEvent, 0, len(events)) + newEvents := make([]datasource.StreamEvent, 0, events.Len()) // check if the client is authenticated if ctx.Authentication() == nil { // if the client is not authenticated, return no events - return newEvents, nil + return datasource.NewStreamEvents(newEvents), nil } // check if the client is allowed to subscribe to the stream - clientAllowedEntitiesIds, found := ctx.Authentication().Claims()["allowedEntitiesIds"] + allowedEntitiesIdsRaw, found := ctx.Authentication().Claims()["allowedEntitiesIds"] if !found { - return newEvents, fmt.Errorf("client is not allowed to subscribe to the stream") + return datasource.NewStreamEvents(newEvents), fmt.Errorf("client is not allowed to subscribe to the stream") + } + + // type assert to string slice + clientAllowedEntitiesIds, ok := allowedEntitiesIdsRaw.([]string) + if !ok { + return datasource.NewStreamEvents(newEvents), fmt.Errorf("allowedEntitiesIds claim is not a string slice") } - for _, evt := range events { - natsEvent, ok := evt.(*nats.NatsEvent) - if !ok { - newEvents = append(newEvents, evt) - continue - } - + for _, evt := range events.All() { // decode the event data coming from the provider var 
dataReceived struct { EmployeeId string `json:"EmployeeId"` OtherField string `json:"OtherField"` } - err := json.Unmarshal(natsEvent.Data, &dataReceived) + err := json.Unmarshal(evt.GetData(), &dataReceived) if err != nil { return events, fmt.Errorf("error unmarshalling data: %w", err) } @@ -252,14 +301,11 @@ func (m *MyModule) OnReceiveEvents(ctx StreamReceiveEventHandlerContext, events return events, fmt.Errorf("error marshalling data: %w", err) } - // create the new event - newEvent := &nats.NatsEvent{ - Data: dataToSendMarshalled, - Metadata: natsEvent.Metadata, - } + // create a new event using the context's NewEvent method + newEvent := ctx.NewEvent(dataToSendMarshalled) newEvents = append(newEvents, newEvent) } - return newEvents, nil + return datasource.NewStreamEvents(newEvents), nil } func (m *MyModule) Module() core.ModuleInfo { @@ -316,10 +362,9 @@ The developer will need to write the custom module that will be used to check th package mymodule import ( - "encoding/json" - "slices" + "net/http" "github.com/wundergraph/cosmo/router/core" - "github.com/wundergraph/cosmo/router/pkg/pubsub/nats" + "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" ) func init() { @@ -329,9 +374,9 @@ func init() { type MyModule struct {} -func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHandlerContext) error { +func (m *MyModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHandlerContext) error { // check if the provider is nats - if ctx.SubscriptionEventConfiguration().ProviderType() != pubsub.ProviderTypeNats { + if ctx.SubscriptionEventConfiguration().ProviderType() != datasource.ProviderTypeNats { return nil } @@ -348,21 +393,21 @@ func (m *MyModule) SubscriptionOnStart(ctx SubscriptionOnStartHandlerContext) er // check if the client is authenticated if ctx.Authentication() == nil { // if the client is not authenticated, return an error - return &core.HttpError{ - Code: http.StatusUnauthorized, - Message: "client is not authenticated", - 
CloseSubscription: true, - } + return core.NewHttpGraphqlError( + "client is not authenticated", + http.StatusText(http.StatusUnauthorized), + http.StatusUnauthorized, + ) } // check if the client is allowed to subscribe to the stream - clientAllowedEntitiesIds, found := ctx.Authentication().Claims()["readEmployee"] + _, found := ctx.Authentication().Claims()["readEmployee"] if !found { - return &core.HttpError{ - Code: http.StatusForbidden, - Message: "client is not allowed to read employees", - CloseSubscription: true, - } + return core.NewHttpGraphqlError( + "client is not allowed to read employees", + http.StatusText(http.StatusForbidden), + http.StatusForbidden, + ) } return nil From 405f64c88dbc4358ed7e0d5f2b63c8b2157b308d Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Tue, 11 Nov 2025 17:10:00 +0100 Subject: [PATCH 09/44] chore: go mod tidy --- router-tests/go.sum | 2 -- router/go.sum | 20 ++++++++++---------- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/router-tests/go.sum b/router-tests/go.sum index e829cdb26c..e2af6054f8 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -354,8 +354,6 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoTjFDDcgaECCCR2aTePqMu9QBmPbyhqIYOhV0= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.230.0.20251016135804-06f55f15daa7 h1:lB6ZcFpspUAQu9myScxjFZW+iWbXy44tyEzTMXCu/uw= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.230.0.20251016135804-06f55f15daa7/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/wundergraph/graphql-go-tools/v2 
v2.0.0-rc.237.0.20251110152155-423a60c6a33e h1:246mrdmTHRIsW9yVQjFKQlAgvw+sNES1FymnVjJ7r/Q= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.237.0.20251110152155-423a60c6a33e/go.mod h1:ErOQH1ki2+SZB8JjpTyGVnoBpg5picIyjvuWQJP4abg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= diff --git a/router/go.sum b/router/go.sum index 6f844fa6b0..da5335fbcb 100644 --- a/router/go.sum +++ b/router/go.sum @@ -1,7 +1,7 @@ connectrpc.com/connect v1.16.2 h1:ybd6y+ls7GOlb7Bh5C8+ghA6SvCBajHwxssO2CGFjqE= connectrpc.com/connect v1.16.2/go.mod h1:n2kgwskMHXC+lVqb18wngEpF95ldBHXjZYJussz5FRc= -github.com/99designs/gqlgen v0.17.45 h1:bH0AH67vIJo8JKNKPJP+pOPpQhZeuVRQLf53dKIpDik= -github.com/99designs/gqlgen v0.17.45/go.mod h1:Bas0XQ+Jiu/Xm5E33jC8sES3G+iC2esHBMXcq0fUPs0= +github.com/99designs/gqlgen v0.17.76 h1:YsJBcfACWmXWU2t1yCjoGdOmqcTfOFpjbLAE443fmYI= +github.com/99designs/gqlgen v0.17.76/go.mod h1:miiU+PkAnTIDKMQ1BseUOIVeQHoiwYDZGCswoxl7xec= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/KimMachineGun/automemlimit v0.6.1 h1:ILa9j1onAAMadBsyyUJv5cack8Y1WT26yLj/V+ulKp8= github.com/KimMachineGun/automemlimit v0.6.1/go.mod h1:T7xYht7B8r6AG/AqFcUdc7fzd2bIdBKmepfP2S1svPY= @@ -9,8 +9,8 @@ github.com/MicahParks/jwkset v0.11.0 h1:yc0zG+jCvZpWgFDFmvs8/8jqqVBG9oyIbmBtmjOh github.com/MicahParks/jwkset v0.11.0/go.mod h1:U2oRhRaLgDCLjtpGL2GseNKGmZtLs/3O7p+OZaL5vo0= github.com/MicahParks/keyfunc/v3 v3.6.2 h1:82rre60MKw4r117ew5/T4m1AphgkpCOYry0RPbFUY3w= github.com/MicahParks/keyfunc/v3 v3.6.2/go.mod h1:z66bkCviwqfg2YUp+Jcc/xRE9IXLcMq6DrgV/+Htru0= -github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8= -github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo= +github.com/agnivade/levenshtein v1.2.1 h1:EHBY3UOn1gwdy/VbFwgo4cxecRznFk7fKWN1KOX7eoM= +github.com/agnivade/levenshtein v1.2.1/go.mod 
h1:QVVI16kDrtSuwcpd0p1+xMC6Z/VfhtCyDIjcwga4/DU= github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 h1:uvdUDbHQHO85qeSydJtItA4T55Pw6BtAejd0APRJOCE= github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= github.com/alicebob/miniredis/v2 v2.34.0 h1:mBFWMaJSNL9RwdGRyEDoAAv8OQc5UlEhLDQggTglU/0= @@ -94,6 +94,8 @@ github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY= github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0= github.com/go-redis/redis_rate/v10 v10.0.1 h1:calPxi7tVlxojKunJwQ72kwfozdy25RjA0bCj1h0MUo= github.com/go-redis/redis_rate/v10 v10.0.1/go.mod h1:EMiuO9+cjRkR7UvdvwMO7vbgqJkltQHtwbdIQvaBKIU= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gobwas/httphead v0.1.0 h1:exrUm0f4YX0L7EBwZHuCF4GDp8aJfVeBrlLQrs6NqWU= github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= github.com/gobwas/pool v0.2.1 h1:xfeeEhW7pwmX8nuLVlqbzVc7udMDrwetjEv+TZIz1og= @@ -270,8 +272,8 @@ github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnj github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/sosodev/duration v1.2.0 h1:pqK/FLSjsAADWY74SyWDCjOcd5l7H8GSnnOGEB9A1Us= -github.com/sosodev/duration v1.2.0/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERAikUR6SDg= +github.com/sosodev/duration v1.3.1 h1:qtHBDMQ6lvMQsL15g4aopM4HEfOaYuhWBw3NPTtlqq4= +github.com/sosodev/duration v1.3.1/go.mod h1:RQIBBX0+fMLc/D9+Jb/fwvVmo0eZvDDEERAikUR6SDg= github.com/spf13/cast v1.7.1 h1:cuNEagBQEHWN1FnbGEjCXL2szYEXqfJPbP2HNUaca9Y= github.com/spf13/cast 
v1.7.1/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -314,14 +316,12 @@ github.com/twmb/franz-go/pkg/kmsg v1.7.0 h1:a457IbvezYfA5UkiBvyV3zj0Is3y1i8EJgqj github.com/twmb/franz-go/pkg/kmsg v1.7.0/go.mod h1:se9Mjdt0Nwzc9lnjJ0HyDtLyBnaBDAd7pCje47OhSyw= github.com/vbatts/tar-split v0.12.1 h1:CqKoORW7BUWBe7UL/iqTVvkTBOF8UvOMKOIZykxnnbo= github.com/vbatts/tar-split v0.12.1/go.mod h1:eF6B6i6ftWQcDqEn3/iGFRFRo8cBIMSJVOpnNdfTMFA= -github.com/vektah/gqlparser/v2 v2.5.14 h1:dzLq75BJe03jjQm6n56PdH1oweB8ana42wj7E4jRy70= -github.com/vektah/gqlparser/v2 v2.5.14/go.mod h1:WQQjFc+I1YIzoPvZBhUQX7waZgg3pMLi0r8KymvAE2w= +github.com/vektah/gqlparser/v2 v2.5.30 h1:EqLwGAFLIzt1wpx1IPpY67DwUujF1OfzgEyDsLrN6kE= +github.com/vektah/gqlparser/v2 v2.5.30/go.mod h1:D1/VCZtV3LPnQrcPBeR/q5jkSQIPti0uYCP/RI0gIeo= github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/fJgbpc= github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.230.0.20251016135804-06f55f15daa7 h1:lB6ZcFpspUAQu9myScxjFZW+iWbXy44tyEzTMXCu/uw= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.230.0.20251016135804-06f55f15daa7/go.mod h1:g1IFIylu5Fd9pKjzq0mDvpaKhEB/vkwLAIbGdX2djXU= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.237.0.20251110152155-423a60c6a33e h1:246mrdmTHRIsW9yVQjFKQlAgvw+sNES1FymnVjJ7r/Q= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.237.0.20251110152155-423a60c6a33e/go.mod h1:ErOQH1ki2+SZB8JjpTyGVnoBpg5picIyjvuWQJP4abg= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= From b3bc7afca30d020fa4091601ee14ace2832fbc13 
Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Thu, 13 Nov 2025 17:14:28 +0100 Subject: [PATCH 10/44] feat(router): add a timeout for on_receive_event hooks (#2329) --- router-tests/modules/stream_receive_test.go | 330 +++++++----------- router/core/factoryresolver.go | 31 +- router/core/router.go | 18 +- router/core/router_config.go | 21 +- router/core/subscriptions_modules.go | 66 +++- router/demo.config.yaml | 2 +- router/pkg/config/config.go | 7 +- router/pkg/config/config.schema.json | 22 +- router/pkg/config/fixtures/full.yaml | 4 +- .../pkg/config/testdata/config_defaults.json | 5 +- router/pkg/config/testdata/config_full.json | 5 +- router/pkg/pubsub/datasource/hooks.go | 27 +- .../pkg/pubsub/datasource/pubsubprovider.go | 4 +- .../pubsub/datasource/pubsubprovider_test.go | 46 ++- .../datasource/subscription_datasource.go | 11 +- .../subscription_datasource_test.go | 32 +- .../datasource/subscription_event_updater.go | 101 +++--- .../subscription_event_updater_test.go | 308 ++++++++-------- router/pkg/pubsub/pubsub_test.go | 8 +- 19 files changed, 557 insertions(+), 491 deletions(-) diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index 23477f3925..325600915d 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -1,8 +1,8 @@ package module_test import ( + "encoding/json" "errors" - "fmt" "net/http" "sync/atomic" "testing" @@ -522,207 +522,6 @@ func TestReceiveHook(t *testing.T) { }) }) - t.Run("Test error deduplication with multiple subscriptions", func(t *testing.T) { - t.Parallel() - - cfg := config.Config{ - Graph: config.Graph{}, - Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return datasource.NewStreamEvents(nil), 
errors.New("deduplicated error") - }, - }, - }, - } - - testenv.Run(t, &testenv.Config{ - RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, - EnableKafka: true, - RouterOptions: []core.Option{ - core.WithModulesConfig(cfg.Modules), - core.WithCustomModules(&stream_receive.StreamReceiveModule{}), - }, - LogObservation: testenv.LogObservationConfig{ - Enabled: true, - LogLevel: zapcore.ErrorLevel, - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - topics := []string{"employeeUpdated"} - events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) - - var subscriptionOne struct { - employeeUpdatedMyKafka struct { - ID float64 `graphql:"id"` - Details struct { - Forename string `graphql:"forename"` - Surname string `graphql:"surname"` - } `graphql:"details"` - } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` - } - - surl := xEnv.GraphQLWebSocketSubscriptionURL() - - // Create 3 subscriptions that will all receive the same error - clients := make([]*graphql.SubscriptionClient, 3) - clientRunChs := make([]chan error, 3) - - for i := range 3 { - clients[i] = graphql.NewSubscriptionClient(surl) - clientRunChs[i] = make(chan error) - - subscriptionID, err := clients[i].Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { - return nil - }) - require.NoError(t, err) - require.NotEmpty(t, subscriptionID) - - go func() { - clientRunChs[i] <- clients[i].Run() - }() - } - - // Wait for all subscriptions to be established - xEnv.WaitForSubscriptionCount(3, Timeout) - - // Produce a message that will trigger the error in all handlers - events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - - // Wait for all subscriptions to be closed due to the error - xEnv.WaitForSubscriptionCount(0, Timeout) - - // Verify all clients completed - for i := 0; i < 3; i++ { - testenv.AwaitChannelWithT(t, Timeout, clientRunChs[i], func(t *testing.T, err error) { - 
require.NoError(t, err) - }, "client should have completed when server closed connection") - } - - xEnv.WaitForTriggerCount(0, Timeout) - - // Verify error deduplication: should see only one error log entry - errorLogs := xEnv.Observer().FilterMessage("some handlers have thrown an error") - assert.Len(t, errorLogs.All(), 1, "should have exactly one deduplicated error log entry") - - // Verify the error log contains the correct error message and count - if len(errorLogs.All()) > 0 { - logEntry := errorLogs.All()[0] - fields := logEntry.ContextMap() - - assert.Equal(t, "deduplicated error", fields["error"], "error message should match") - assert.Equal(t, int64(3), fields["amount_handlers"], "should count all 3 handlers that threw the error") - } - }) - }) - - t.Run("Test unique error messages are all logged", func(t *testing.T) { - t.Parallel() - - var errorCounter atomic.Int32 - - cfg := config.Config{ - Graph: config.Graph{}, - Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - count := errorCounter.Add(1) - return datasource.NewStreamEvents(nil), fmt.Errorf("unique error %d", count) - }, - }, - }, - } - - testenv.Run(t, &testenv.Config{ - RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, - EnableKafka: true, - RouterOptions: []core.Option{ - core.WithModulesConfig(cfg.Modules), - core.WithCustomModules(&stream_receive.StreamReceiveModule{}), - }, - LogObservation: testenv.LogObservationConfig{ - Enabled: true, - LogLevel: zapcore.ErrorLevel, - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - topics := []string{"employeeUpdated"} - events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) 
- - var subscriptionOne struct { - employeeUpdatedMyKafka struct { - ID float64 `graphql:"id"` - Details struct { - Forename string `graphql:"forename"` - Surname string `graphql:"surname"` - } `graphql:"details"` - } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` - } - - surl := xEnv.GraphQLWebSocketSubscriptionURL() - - // Create 3 subscriptions that will each receive a unique error - clients := make([]*graphql.SubscriptionClient, 3) - clientRunChs := make([]chan error, 3) - - for i := range 3 { - clients[i] = graphql.NewSubscriptionClient(surl) - clientRunChs[i] = make(chan error) - - subscriptionID, err := clients[i].Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { - return nil - }) - require.NoError(t, err) - require.NotEmpty(t, subscriptionID) - - go func() { - clientRunChs[i] <- clients[i].Run() - }() - } - - // Wait for all subscriptions to be established - xEnv.WaitForSubscriptionCount(3, Timeout) - - // Produce a message that will trigger a unique error in each handler - events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) - - // Wait for all subscriptions to be closed due to the error - xEnv.WaitForSubscriptionCount(0, Timeout) - - // Verify all clients completed - for i := range 3 { - testenv.AwaitChannelWithT(t, Timeout, clientRunChs[i], func(t *testing.T, err error) { - require.NoError(t, err) - }, "client should have completed when server closed connection") - } - - xEnv.WaitForTriggerCount(0, Timeout) - - // Verify no deduplication: should see three error log entries (one for each unique error) - errorLogs := xEnv.Observer().FilterMessage("some handlers have thrown an error") - assert.Len(t, errorLogs.All(), 3, "should have three separate error log entries for unique errors") - - // Verify each error log contains a unique error message and count of 1 - if len(errorLogs.All()) == 3 { - var errorMessages []string - for _, logEntry := range errorLogs.All() 
{ - fields := logEntry.ContextMap() - errorMsg, ok := fields["error"].(string) - require.True(t, ok, "error field should be a string") - - // Check that error message is unique (starts with "unique error") - assert.Contains(t, errorMsg, "unique error", "error message should contain 'unique error'") - assert.NotContains(t, errorMessages, errorMsg, "each error message should be unique") - errorMessages = append(errorMessages, errorMsg) - - // Each unique error should have been thrown by exactly 1 handler - assert.Equal(t, int64(1), fields["amount_handlers"], "each unique error should have amount_handlers = 1") - } - - // Verify we got exactly 3 unique error messages - assert.Len(t, errorMessages, 3, "should have exactly 3 unique error messages") - } - }) - }) - t.Run("Test concurrent handler execution works", func(t *testing.T) { t.Parallel() @@ -813,7 +612,9 @@ func TestReceiveHook(t *testing.T) { core.WithModulesConfig(cfg.Modules), core.WithCustomModules(&stream_receive.StreamReceiveModule{}), core.WithSubscriptionHooks(config.SubscriptionHooksConfiguration{ - MaxConcurrentEventReceiveHandlers: tc.maxConcurrent, + OnReceiveEvents: config.OnReceiveEventsConfiguration{ + MaxConcurrentHandlers: tc.maxConcurrent, + }, }), }, LogObservation: testenv.LogObservationConfig{ @@ -893,4 +694,127 @@ func TestReceiveHook(t *testing.T) { }) } }) + + t.Run("Test timeout mechanism allows out-of-order event delivery", func(t *testing.T) { + t.Parallel() + + // One subscriber receives three consecutive events. + // The first event's hook is delayed, exceeding the timeout. + // The second and third events' hooks process immediately without delay. + // Because the first hook exceeds the timeout, the system abandons waiting for it + // and processes the second and third events. + // The first event will be delivered later when its hook finally completes. + // This should result in event order [2, 3, 1] at the client. 
+ + hookDelay := 500 * time.Millisecond + hookTimeout := 100 * time.Millisecond + + var callCount atomic.Int32 + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + // Only the first call should delay + if callCount.Add(1) == 1 { + time.Sleep(hookDelay) + } + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + core.WithSubscriptionHooks(config.SubscriptionHooksConfiguration{ + OnReceiveEvents: config.OnReceiveEventsConfiguration{ + MaxConcurrentHandlers: 3, + HandlerTimeout: hookTimeout, + }, + }), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) 
+ + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs, 3) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"first"}}`) + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 2,"update":{"name":"second"}}`) + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 3,"update":{"name":"third"}}`) + + // Collect all 3 events + receivedIDs := make([]float64, 0, 3) + for i := 0; i < 3; i++ { + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + + var response struct { + EmployeeUpdatedMyKafka struct { + ID float64 `json:"id"` + } `json:"employeeUpdatedMyKafka"` + } + err := json.Unmarshal(args.dataValue, &response) + require.NoError(t, err) + receivedIDs = append(receivedIDs, response.EmployeeUpdatedMyKafka.ID) + }) + } + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + // Verify events arrived out of order: event 1 
should be the last one to arrive + assert.ElementsMatch(t, []float64{1, 2, 3}, receivedIDs, "expected to receive all events") + assert.Equal(t, float64(1), receivedIDs[len(receivedIDs)-1], "expected the delayed event to arrive last") + assert.NotEqual(t, float64(1), receivedIDs[0], "expected at least one later event to arrive before the delayed one") + + timeoutLog := xEnv.Observer().FilterMessage("Timeout exceeded during subscription updates, events may arrive out of order") + assert.Len(t, timeoutLog.All(), 1, "expected timeout warning to be logged") + + // Verify all hooks were executed + hookLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, hookLog.All(), 3) + }) + }) } diff --git a/router/core/factoryresolver.go b/router/core/factoryresolver.go index 36ea68d8e4..d8cfecc283 100644 --- a/router/core/factoryresolver.go +++ b/router/core/factoryresolver.go @@ -418,8 +418,8 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod } } - subscriptionOnStartFns := make([]graphql_datasource.SubscriptionOnStartFn, len(l.subscriptionHooks.onStart)) - for i, fn := range l.subscriptionHooks.onStart { + subscriptionOnStartFns := make([]graphql_datasource.SubscriptionOnStartFn, len(l.subscriptionHooks.onStart.handlers)) + for i, fn := range l.subscriptionHooks.onStart.handlers { subscriptionOnStartFns[i] = NewEngineSubscriptionOnStartHook(fn) } customConfiguration, err := graphql_datasource.NewConfiguration(graphql_datasource.ConfigurationInput{ @@ -477,18 +477,18 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod } } - subscriptionOnStartFns := make([]pubsub_datasource.SubscriptionOnStartFn, len(l.subscriptionHooks.onStart)) - for i, fn := range l.subscriptionHooks.onStart { + subscriptionOnStartFns := make([]pubsub_datasource.SubscriptionOnStartFn, len(l.subscriptionHooks.onStart.handlers)) + for i, fn := range l.subscriptionHooks.onStart.handlers { subscriptionOnStartFns[i] = 
NewPubSubSubscriptionOnStartHook(fn) } - onPublishEventsFns := make([]pubsub_datasource.OnPublishEventsFn, len(l.subscriptionHooks.onPublishEvents)) - for i, fn := range l.subscriptionHooks.onPublishEvents { + onPublishEventsFns := make([]pubsub_datasource.OnPublishEventsFn, len(l.subscriptionHooks.onPublishEvents.handlers)) + for i, fn := range l.subscriptionHooks.onPublishEvents.handlers { onPublishEventsFns[i] = NewPubSubOnPublishEventsHook(fn) } - onReceiveEventsFns := make([]pubsub_datasource.OnReceiveEventsFn, len(l.subscriptionHooks.onReceiveEvents)) - for i, fn := range l.subscriptionHooks.onReceiveEvents { + onReceiveEventsFns := make([]pubsub_datasource.OnReceiveEventsFn, len(l.subscriptionHooks.onReceiveEvents.handlers)) + for i, fn := range l.subscriptionHooks.onReceiveEvents.handlers { onReceiveEventsFns[i] = NewPubSubOnReceiveEventsHook(fn) } @@ -501,10 +501,17 @@ func (l *Loader) Load(engineConfig *nodev1.EngineConfiguration, subgraphs []*nod l.resolver.InstanceData().HostName, l.resolver.InstanceData().ListenAddress, pubsub_datasource.Hooks{ - SubscriptionOnStart: subscriptionOnStartFns, - OnReceiveEvents: onReceiveEventsFns, - OnPublishEvents: onPublishEventsFns, - MaxConcurrentOnReceiveHandlers: l.subscriptionHooks.maxConcurrentOnReceiveHooks, + SubscriptionOnStart: pubsub_datasource.SubscriptionOnStartHooks{ + Handlers: subscriptionOnStartFns, + }, + OnPublishEvents: pubsub_datasource.OnPublishEventsHooks{ + Handlers: onPublishEventsFns, + }, + OnReceiveEvents: pubsub_datasource.OnReceiveEventsHooks{ + Handlers: onReceiveEventsFns, + MaxConcurrentHandlers: l.subscriptionHooks.onReceiveEvents.maxConcurrentHandlers, + Timeout: l.subscriptionHooks.onReceiveEvents.timeout, + }, }, ) if err != nil { diff --git a/router/core/router.go b/router/core/router.go index 8ebac78a5b..20b0b33ae6 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -253,9 +253,12 @@ func NewRouter(opts ...Option) (*Router, error) { r.metricConfig = 
rmetric.DefaultConfig(Version) } - // Default value for maxConcurrentOnReceiveHooks - if r.subscriptionHooks.maxConcurrentOnReceiveHooks == 0 { - r.subscriptionHooks.maxConcurrentOnReceiveHooks = 100 + if r.subscriptionHooks.onReceiveEvents.maxConcurrentHandlers == 0 { + r.subscriptionHooks.onReceiveEvents.maxConcurrentHandlers = 100 + } + + if r.subscriptionHooks.onReceiveEvents.timeout == 0 { + r.subscriptionHooks.onReceiveEvents.timeout = 5 * time.Second } if r.corsOptions == nil { @@ -681,15 +684,15 @@ func (r *Router) initModules(ctx context.Context) error { } if handler, ok := moduleInstance.(SubscriptionOnStartHandler); ok { - r.subscriptionHooks.onStart = append(r.subscriptionHooks.onStart, handler.SubscriptionOnStart) + r.subscriptionHooks.onStart.handlers = append(r.subscriptionHooks.onStart.handlers, handler.SubscriptionOnStart) } if handler, ok := moduleInstance.(StreamPublishEventHandler); ok { - r.subscriptionHooks.onPublishEvents = append(r.subscriptionHooks.onPublishEvents, handler.OnPublishEvents) + r.subscriptionHooks.onPublishEvents.handlers = append(r.subscriptionHooks.onPublishEvents.handlers, handler.OnPublishEvents) } if handler, ok := moduleInstance.(StreamReceiveEventHandler); ok { - r.subscriptionHooks.onReceiveEvents = append(r.subscriptionHooks.onReceiveEvents, handler.OnReceiveEvents) + r.subscriptionHooks.onReceiveEvents.handlers = append(r.subscriptionHooks.onReceiveEvents.handlers, handler.OnReceiveEvents) } r.modules = append(r.modules, moduleInstance) @@ -2139,7 +2142,8 @@ func WithDemoMode(demoMode bool) Option { func WithSubscriptionHooks(cfg config.SubscriptionHooksConfiguration) Option { return func(r *Router) { - r.subscriptionHooks.maxConcurrentOnReceiveHooks = cfg.MaxConcurrentEventReceiveHandlers + r.subscriptionHooks.onReceiveEvents.maxConcurrentHandlers = cfg.OnReceiveEvents.MaxConcurrentHandlers + r.subscriptionHooks.onReceiveEvents.timeout = cfg.OnReceiveEvents.HandlerTimeout } } diff --git 
a/router/core/router_config.go b/router/core/router_config.go index b0027b3e8a..8616946eac 100644 --- a/router/core/router_config.go +++ b/router/core/router_config.go @@ -27,10 +27,23 @@ import ( ) type subscriptionHooks struct { - onStart []func(ctx SubscriptionOnStartHandlerContext) error - onPublishEvents []func(ctx StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) - onReceiveEvents []func(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) - maxConcurrentOnReceiveHooks int + onStart onStartHooks + onPublishEvents onPublishEventsHooks + onReceiveEvents onReceiveEventsHooks +} + +type onStartHooks struct { + handlers []func(ctx SubscriptionOnStartHandlerContext) error +} + +type onPublishEventsHooks struct { + handlers []func(ctx StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) +} + +type onReceiveEventsHooks struct { + handlers []func(ctx StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) + maxConcurrentHandlers int + timeout time.Duration } type Config struct { diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index bcfaaae114..90ca6cf636 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -191,9 +191,22 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont return func(resolveCtx resolve.StartupHookContext, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn) error { requestContext := getRequestContext(resolveCtx.Context) + + logger := requestContext.Logger() + if logger != nil { + logger = logger.With(zap.String("component", "pubsub_subscription_on_start_hook")) + if subConf != nil { + logger = logger.With( + zap.String("provider_id", subConf.ProviderID()), + zap.String("provider_type", 
string(subConf.ProviderType())), + zap.String("field_name", subConf.RootFieldName()), + ) + } + } + hookCtx := &pubSubSubscriptionOnStartHookContext{ request: requestContext.Request(), - logger: requestContext.Logger(), + logger: logger, operation: requestContext.Operation(), authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, @@ -213,9 +226,15 @@ func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont return func(resolveCtx resolve.StartupHookContext, input []byte) error { requestContext := getRequestContext(resolveCtx.Context) + + logger := requestContext.Logger() + if logger != nil { + logger = logger.With(zap.String("component", "engine_subscription_on_start_hook")) + } + hookCtx := &engineSubscriptionOnStartHookContext{ request: requestContext.Request(), - logger: requestContext.Logger(), + logger: logger, operation: requestContext.Operation(), authentication: requestContext.Authentication(), writeEventHook: resolveCtx.Updater, @@ -226,6 +245,9 @@ func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont } type StreamReceiveEventHandlerContext interface { + // Context is a context for handlers. + // If it is cancelled, the handler should stop processing. 
+ Context() context.Context // Request is the initial client request that started the subscription Request() *http.Request // Logger is the logger for the request @@ -283,9 +305,22 @@ func NewPubSubOnPublishEventsHook(fn func(ctx StreamPublishEventHandlerContext, return func(ctx context.Context, pubConf datasource.PublishEventConfiguration, evts []datasource.StreamEvent, eventBuilder datasource.EventBuilderFn) ([]datasource.StreamEvent, error) { requestContext := getRequestContext(ctx) + + logger := requestContext.Logger() + if logger != nil { + logger = logger.With(zap.String("component", "on_publish_events_hook")) + if pubConf != nil { + logger = logger.With( + zap.String("provider_id", pubConf.ProviderID()), + zap.String("provider_type", string(pubConf.ProviderType())), + zap.String("field_name", pubConf.RootFieldName()), + ) + } + } + hookCtx := &pubSubPublishEventHookContext{ request: requestContext.Request(), - logger: requestContext.Logger(), + logger: logger, operation: requestContext.Operation(), authentication: requestContext.Authentication(), publishEventConfiguration: pubConf, @@ -305,6 +340,11 @@ type pubSubStreamReceiveEventHookContext struct { authentication authentication.Authentication subscriptionEventConfiguration datasource.SubscriptionEventConfiguration eventBuilder datasource.EventBuilderFn + context context.Context +} + +func (c *pubSubStreamReceiveEventHookContext) Context() context.Context { + return c.context } func (c *pubSubStreamReceiveEventHookContext) Request() *http.Request { @@ -336,15 +376,29 @@ func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, return nil } - return func(ctx context.Context, subConf datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { - requestContext := getRequestContext(ctx) + return func(subscriptionCtx context.Context, updaterCtx context.Context, subConf 
datasource.SubscriptionEventConfiguration, eventBuilder datasource.EventBuilderFn, evts []datasource.StreamEvent) ([]datasource.StreamEvent, error) { + requestContext := getRequestContext(subscriptionCtx) + + logger := requestContext.Logger() + if logger != nil { + logger = logger.With(zap.String("component", "on_receive_events_hook")) + if subConf != nil { + logger = logger.With( + zap.String("provider_id", subConf.ProviderID()), + zap.String("provider_type", string(subConf.ProviderType())), + zap.String("field_name", subConf.RootFieldName()), + ) + } + } + hookCtx := &pubSubStreamReceiveEventHookContext{ request: requestContext.Request(), - logger: requestContext.Logger(), + logger: logger, operation: requestContext.Operation(), authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, eventBuilder: eventBuilder, + context: updaterCtx, } newEvts, err := fn(hookCtx, datasource.NewStreamEvents(evts)) return newEvts.Unsafe(), err diff --git a/router/demo.config.yaml b/router/demo.config.yaml index 2a081e74be..9a72e31de2 100644 --- a/router/demo.config.yaml +++ b/router/demo.config.yaml @@ -19,4 +19,4 @@ events: redis: - id: my-redis urls: - - "redis://localhost:6379/2" + - "redis://localhost:6379/2" \ No newline at end of file diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index f386b27156..bb8c910982 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -646,7 +646,12 @@ type EventsConfiguration struct { } type SubscriptionHooksConfiguration struct { - MaxConcurrentEventReceiveHandlers int `yaml:"max_concurrent_event_receive_handlers" envDefault:"100"` + OnReceiveEvents OnReceiveEventsConfiguration `yaml:"on_receive_events"` +} + +type OnReceiveEventsConfiguration struct { + MaxConcurrentHandlers int `yaml:"max_concurrent_handlers" envDefault:"100"` + HandlerTimeout time.Duration `yaml:"handler_timeout" envDefault:"5s"` } type Cluster struct { diff --git 
a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index f800e07245..a2326c86ec 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2310,11 +2310,23 @@ "description": "Configuration for subscription custom modules that are executed when events are received from a broker.", "additionalProperties": false, "properties": { - "max_concurrent_event_receive_handlers": { - "type": "integer", - "description": "The maximum number of concurrent event receive handlers. This controls the concurrency of the OnReceiveEvents custom modules.", - "minimum": 1, - "default": 100 + "on_receive_events": { + "type": "object", + "description": "Configuration for the OnReceiveEvents hook that is called when events are received from a broker.", + "additionalProperties": false, + "properties": { + "max_concurrent_handlers": { + "type": "integer", + "description": "The maximum number of concurrent event receive handlers. This controls the concurrency of the OnReceiveEvents custom modules.", + "minimum": 1, + "default": 100 + }, + "handler_timeout": { + "type": "string", + "description": "The amount of time that OnReceiveEvents handlers can run in total for a single batch of events. 
Specify as a duration string (e.g., '5s', '1m', '500ms').", + "default": "5s" + } + } } } } diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index ee5b7d8ef8..3f79ed83df 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -331,7 +331,9 @@ events: - 'redis://localhost:6379/11' cluster_enabled: true subscription_hooks: - max_concurrent_event_receive_handlers: 100 + on_receive_events: + max_concurrent_handlers: 100 + handler_timeout: 5s engine: enable_single_flight: true diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 40ce94c15a..7655cc7007 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -297,7 +297,10 @@ "Redis": null }, "SubscriptionHooks": { - "MaxConcurrentEventReceiveHandlers": 100 + "OnReceiveEvents": { + "MaxConcurrentHandlers": 100, + "HandlerTimeout": 5000000000 + } } }, "CacheWarmup": { diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 5a09b35205..2731dee760 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -643,7 +643,10 @@ ] }, "SubscriptionHooks": { - "MaxConcurrentEventReceiveHandlers": 100 + "OnReceiveEvents": { + "MaxConcurrentHandlers": 100, + "HandlerTimeout": 5000000000 + } } }, "CacheWarmup": { diff --git a/router/pkg/pubsub/datasource/hooks.go b/router/pkg/pubsub/datasource/hooks.go index a2e53e7183..a262058463 100644 --- a/router/pkg/pubsub/datasource/hooks.go +++ b/router/pkg/pubsub/datasource/hooks.go @@ -2,6 +2,7 @@ package datasource import ( "context" + "time" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" ) @@ -10,12 +11,28 @@ type SubscriptionOnStartFn func(ctx resolve.StartupHookContext, subConf Subscrip type OnPublishEventsFn func(ctx context.Context, pubConf PublishEventConfiguration, 
evts []StreamEvent, eventBuilder EventBuilderFn) ([]StreamEvent, error) -type OnReceiveEventsFn func(ctx context.Context, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn, evts []StreamEvent) ([]StreamEvent, error) +type OnReceiveEventsFn func(subscriptionCtx context.Context, updaterCtx context.Context, subConf SubscriptionEventConfiguration, eventBuilder EventBuilderFn, evts []StreamEvent) ([]StreamEvent, error) // Hooks contains hooks for the pubsub providers and data sources type Hooks struct { - SubscriptionOnStart []SubscriptionOnStartFn - OnReceiveEvents []OnReceiveEventsFn - OnPublishEvents []OnPublishEventsFn - MaxConcurrentOnReceiveHandlers int + SubscriptionOnStart SubscriptionOnStartHooks + OnPublishEvents OnPublishEventsHooks + OnReceiveEvents OnReceiveEventsHooks +} + +// SubscriptionOnStartHooks contains hooks with settings for subscription starts +type SubscriptionOnStartHooks struct { + Handlers []SubscriptionOnStartFn +} + +// OnPublishEventsHooks contains hooks with settings for event publishing +type OnPublishEventsHooks struct { + Handlers []OnPublishEventsFn +} + +// OnReceiveEventsHooks contains hooks with settings for event receiving +type OnReceiveEventsHooks struct { + Handlers []OnReceiveEventsFn + MaxConcurrentHandlers int + Timeout time.Duration } diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index e20f1ace2b..1920bc2b46 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -41,7 +41,7 @@ func (p *PubSubProvider) applyPublishEventHooks(ctx context.Context, cfg Publish }() currentEvents = events - for _, hook := range p.hooks.OnPublishEvents { + for _, hook := range p.hooks.OnPublishEvents.Handlers { var err error currentEvents, err = hook(ctx, cfg, currentEvents, p.eventBuilder) currentEvents = slices.DeleteFunc(currentEvents, func(event StreamEvent) bool { @@ -90,7 +90,7 @@ func (p 
*PubSubProvider) Subscribe(ctx context.Context, cfg SubscriptionEventCon } func (p *PubSubProvider) Publish(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) error { - if len(p.hooks.OnPublishEvents) == 0 { + if len(p.hooks.OnPublishEvents.Handlers) == 0 { return p.Adapter.Publish(ctx, cfg, events) } diff --git a/router/pkg/pubsub/datasource/pubsubprovider_test.go b/router/pkg/pubsub/datasource/pubsubprovider_test.go index 0bf12e7f60..b956ab38f0 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider_test.go +++ b/router/pkg/pubsub/datasource/pubsubprovider_test.go @@ -253,7 +253,9 @@ func TestProvider_Publish_WithHooks_Success(t *testing.T) { provider := PubSubProvider{ Adapter: mockAdapter, hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{testHook}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{testHook}, + }, }, eventBuilder: testPubSubEventBuilder, } @@ -286,7 +288,9 @@ func TestProvider_Publish_WithHooks_HookError(t *testing.T) { provider := PubSubProvider{ Adapter: mockAdapter, hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{testHook}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{testHook}, + }, }, Logger: zap.NewNop(), eventBuilder: testPubSubEventBuilder, @@ -322,7 +326,9 @@ func TestProvider_Publish_WithHooks_AdapterError(t *testing.T) { provider := PubSubProvider{ Adapter: mockAdapter, hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{testHook}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{testHook}, + }, }, eventBuilder: testPubSubEventBuilder, } @@ -358,7 +364,9 @@ func TestProvider_Publish_WithMultipleHooks_Success(t *testing.T) { provider := PubSubProvider{ Adapter: mockAdapter, hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{hook1, hook2}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{hook1, hook2}, + }, }, eventBuilder: testPubSubEventBuilder, } @@ -375,7 +383,9 @@ func TestProvider_SetHooks(t 
*testing.T) { } hooks := Hooks{ - OnPublishEvents: []OnPublishEventsFn{testHook}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{testHook}, + }, } provider.SetHooks(hooks) @@ -396,7 +406,7 @@ func TestNewPubSubProvider(t *testing.T) { assert.Equal(t, typeID, provider.TypeID()) assert.Equal(t, mockAdapter, provider.Adapter) assert.Equal(t, logger, provider.Logger) - assert.Empty(t, provider.hooks.OnPublishEvents) + assert.Empty(t, provider.hooks.OnPublishEvents.Handlers) } func TestApplyPublishEventHooks_NoHooks(t *testing.T) { @@ -412,7 +422,9 @@ func TestApplyPublishEventHooks_NoHooks(t *testing.T) { provider := &PubSubProvider{ Logger: zap.NewNop(), hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{}, + }, }, } @@ -443,7 +455,9 @@ func TestApplyPublishEventHooks_SingleHook_Success(t *testing.T) { provider := &PubSubProvider{ Logger: zap.NewNop(), hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{hook}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{hook}, + }, }, } @@ -472,7 +486,9 @@ func TestApplyPublishEventHooks_SingleHook_Error(t *testing.T) { provider := &PubSubProvider{ Logger: zap.NewNop(), hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{hook}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{hook}, + }, }, } @@ -507,7 +523,9 @@ func TestApplyPublishEventHooks_MultipleHooks_Success(t *testing.T) { provider := &PubSubProvider{ Logger: zap.NewNop(), hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{hook1, hook2, hook3}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{hook1, hook2, hook3}, + }, }, } @@ -543,7 +561,9 @@ func TestApplyPublishEventHooks_MultipleHooks_MiddleHookError(t *testing.T) { provider := &PubSubProvider{ Logger: zap.NewNop(), hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{hook1, hook2, hook3}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: 
[]OnPublishEventsFn{hook1, hook2, hook3}, + }, }, } @@ -599,7 +619,9 @@ func TestApplyPublishEventHooks_PanicRecovery(t *testing.T) { provider := &PubSubProvider{ Logger: zap.NewNop(), hooks: Hooks{ - OnPublishEvents: []OnPublishEventsFn{hook}, + OnPublishEvents: OnPublishEventsHooks{ + Handlers: []OnPublishEventsFn{hook}, + }, }, } diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index c625af9c33..9285d6cfb7 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -46,7 +46,14 @@ func (s *PubSubSubscriptionDataSource[C]) Start(ctx *resolve.Context, input []by return errors.New("invalid subscription configuration") } - return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(conf, s.hooks, updater, s.logger, s.eventBuilder)) + logger := s.logger.With( + zap.String("component", "subscription_event_updater"), + zap.String("provider_id", conf.ProviderID()), + zap.String("provider_type", string(conf.ProviderType())), + zap.String("field_name", conf.RootFieldName()), + ) + + return s.pubSub.Subscribe(ctx.Context(), conf, NewSubscriptionEventUpdater(conf, s.hooks, updater, logger, s.eventBuilder)) } func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.StartupHookContext, input []byte) (err error) { @@ -66,7 +73,7 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.Startu } }() - for _, fn := range s.hooks.SubscriptionOnStart { + for _, fn := range s.hooks.SubscriptionOnStart.Handlers { conf, errConf := s.SubscriptionEventConfiguration(input) if errConf != nil { return err diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index a292f4b0f4..6e2d957a07 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ 
b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -233,7 +233,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_WithHooks(t *testing.T } dataSource.SetHooks(Hooks{ - SubscriptionOnStart: []SubscriptionOnStartFn{hook1, hook2}, + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook1, hook2}, + }, }) testConfig := testSubscriptionEventConfiguration{ @@ -270,7 +272,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsClose(t *te } dataSource.SetHooks(Hooks{ - SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook}, + }, }) testConfig := testSubscriptionEventConfiguration{ @@ -304,7 +308,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_HookReturnsError(t *te } dataSource.SetHooks(Hooks{ - SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook}, + }, }) testConfig := testSubscriptionEventConfiguration{ @@ -333,7 +339,7 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) // Initially should have no hooks - assert.Len(t, dataSource.hooks.SubscriptionOnStart, 0) + assert.Len(t, dataSource.hooks.SubscriptionOnStart.Handlers, 0) // Add hooks hook1 := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { @@ -344,14 +350,18 @@ func TestPubSubSubscriptionDataSource_SetSubscriptionOnStartFns(t *testing.T) { } dataSource.SetHooks(Hooks{ - SubscriptionOnStart: []SubscriptionOnStartFn{hook1}, + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook1}, + }, }) - assert.Len(t, dataSource.hooks.SubscriptionOnStart, 1) + 
assert.Len(t, dataSource.hooks.SubscriptionOnStart.Handlers, 1) dataSource.SetHooks(Hooks{ - SubscriptionOnStart: []SubscriptionOnStartFn{hook2}, + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook2}, + }, }) - assert.Len(t, dataSource.hooks.SubscriptionOnStart, 1) + assert.Len(t, dataSource.hooks.SubscriptionOnStart.Handlers, 1) } func TestNewPubSubSubscriptionDataSource(t *testing.T) { @@ -365,7 +375,7 @@ func TestNewPubSubSubscriptionDataSource(t *testing.T) { assert.NotNil(t, dataSource) assert.Equal(t, mockAdapter, dataSource.pubSub) assert.NotNil(t, dataSource.uniqueRequestID) - assert.Empty(t, dataSource.hooks.SubscriptionOnStart) + assert.Empty(t, dataSource.hooks.SubscriptionOnStart.Handlers) } func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { @@ -424,7 +434,9 @@ func TestPubSubSubscriptionDataSource_SubscriptionOnStart_PanicRecovery(t *testi } dataSource.SetHooks(Hooks{ - SubscriptionOnStart: []SubscriptionOnStartFn{hook}, + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook}, + }, }) testConfig := testSubscriptionEventConfiguration{ diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 5ed4a6c837..0dd1060faf 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -4,12 +4,18 @@ import ( "context" "slices" "sync" + "time" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" "go.uber.org/zap/zapcore" ) +const ( + timeoutGracePeriod = 50 * time.Millisecond + defaultTimeout = 5 * time.Second +) + // SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface // that provides a way to send the event struct instead of the raw data // It is used to give access to the event additional fields to the hooks. 
@@ -26,10 +32,12 @@ type subscriptionEventUpdater struct { hooks Hooks logger *zap.Logger eventBuilder EventBuilderFn + semaphore chan struct{} + timeout time.Duration } func (s *subscriptionEventUpdater) Update(events []StreamEvent) { - if len(s.hooks.OnReceiveEvents) == 0 { + if len(s.hooks.OnReceiveEvents.Handlers) == 0 { for _, event := range events { s.eventUpdater.Update(event.GetData()) } @@ -37,27 +45,37 @@ func (s *subscriptionEventUpdater) Update(events []StreamEvent) { } subscriptions := s.eventUpdater.Subscriptions() - limit := max(s.hooks.MaxConcurrentOnReceiveHandlers, 1) - semaphore := make(chan struct{}, limit) wg := sync.WaitGroup{} - errCh := make(chan error, len(subscriptions)) + updaterCtx, cancel := context.WithDeadline(context.Background(), time.Now().Add(s.timeout)) + defer cancel() - for ctx, subId := range subscriptions { - semaphore <- struct{}{} // Acquire a slot - wg.Add(1) - go s.updateSubscription(ctx, &wg, errCh, semaphore, subId, events) - } + done := make(chan struct{}) - doneLogging := make(chan struct{}) go func() { - s.deduplicateAndLogErrors(errCh, len(subscriptions)) - doneLogging <- struct{}{} + for subCtx, subId := range subscriptions { + s.semaphore <- struct{}{} // Acquire slot, blocks if all slots are taken + wg.Add(1) + go s.updateSubscription(subCtx, updaterCtx, &wg, subId, events) + } + + wg.Wait() + close(done) }() - wg.Wait() - close(semaphore) - close(errCh) - <-doneLogging + select { + case <-done: + s.logger.Debug("All subscription updates completed") + // All subscriptions completed successfully + case <-time.After(s.timeout + timeoutGracePeriod): + // Timeout exceeded, some subscription updates may still be running. + // We can't stop them but we will also not wait for them, basically abandoning them. + // They will continue to hold their semaphore slots until they complete, + // which means the next Update() call will have fewer available slots. 
+ // Also since we will process the next batch of events while having abandoned updaters, + // those updaters might eventually push their events to the subscription late, + // which means events might arrive out of order. + s.logger.Warn("Timeout exceeded during subscription updates, events may arrive out of order") + } } func (s *subscriptionEventUpdater) Complete() { @@ -66,34 +84,31 @@ func (s *subscriptionEventUpdater) Complete() { func (s *subscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { s.eventUpdater.Close(kind) + close(s.semaphore) } func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { s.hooks = hooks } -func (s *subscriptionEventUpdater) updateSubscription(ctx context.Context, wg *sync.WaitGroup, errCh chan error, semaphore chan struct{}, subID resolve.SubscriptionIdentifier, events []StreamEvent) { +func (s *subscriptionEventUpdater) updateSubscription(subscriptionCtx context.Context, updaterCtx context.Context, wg *sync.WaitGroup, subID resolve.SubscriptionIdentifier, events []StreamEvent) { defer wg.Done() defer func() { if r := recover(); r != nil { s.recoverPanic(subID, r) } - <-semaphore // release the slot when done + <-s.semaphore // release the slot when done }() - hooks := s.hooks.OnReceiveEvents + hooks := s.hooks.OnReceiveEvents.Handlers // modify events with hooks var err error for i := range hooks { - events, err = hooks[i](ctx, s.subscriptionEventConfiguration, s.eventBuilder, events) + events, err = hooks[i](subscriptionCtx, updaterCtx, s.subscriptionEventConfiguration, s.eventBuilder, events) events = slices.DeleteFunc(events, func(event StreamEvent) bool { return event == nil }) - - if err != nil { - errCh <- err - } } // send events to the subscription, @@ -120,36 +135,6 @@ func (s *subscriptionEventUpdater) recoverPanic(subID resolve.SubscriptionIdenti s.eventUpdater.CloseSubscription(resolve.SubscriptionCloseKindDownstreamServiceError, subID) } -// deduplicateAndLogErrors collects errors from errCh -// and 
deduplicates them based on their err.Error() value. -// Afterwards it uses s.logger to log the message. -func (s *subscriptionEventUpdater) deduplicateAndLogErrors(errCh chan error, size int) { - if s.logger == nil { - return - } - - errs := make(map[string]int, size) - for err := range errCh { - amount, found := errs[err.Error()] - if found { - errs[err.Error()] = amount + 1 - continue - } - errs[err.Error()] = 1 - } - - for err, amount := range errs { - s.logger.Error( - "some handlers have thrown an error", - zap.String("error", err), - zap.Int("amount_handlers", amount), - zap.String("provider_type", string(s.subscriptionEventConfiguration.ProviderType())), - zap.String("provider_id", s.subscriptionEventConfiguration.ProviderID()), - zap.String("field_name", s.subscriptionEventConfiguration.RootFieldName()), - ) - } -} - func NewSubscriptionEventUpdater( cfg SubscriptionEventConfiguration, hooks Hooks, @@ -157,11 +142,19 @@ func NewSubscriptionEventUpdater( logger *zap.Logger, eventBuilder EventBuilderFn, ) SubscriptionEventUpdater { + limit := max(hooks.OnReceiveEvents.MaxConcurrentHandlers, 1) + timeout := hooks.OnReceiveEvents.Timeout + if timeout == 0 { + timeout = defaultTimeout + } + return &subscriptionEventUpdater{ subscriptionEventConfiguration: cfg, hooks: hooks, eventUpdater: eventUpdater, logger: logger, eventBuilder: eventBuilder, + semaphore: make(chan struct{}, limit), + timeout: timeout, } } diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index 1b6c1bd3a7..e8fceabf2e 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -8,6 +8,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" "go.uber.org/zap/zaptest/observer" @@ -58,11 +59,13 @@ 
func TestSubscriptionEventUpdater_Update_NoHooks(t *testing.T) { mockUpdater.On("Update", []byte("test data 1")).Return() mockUpdater.On("Update", []byte("test data 2")).Return() - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{}, // No hooks - } + updater := NewSubscriptionEventUpdater( + config, + Hooks{}, // No hooks + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Update(events) } @@ -84,7 +87,7 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testin // Create wrapper function for the mock receivedArgs := make(chan receivedHooksArgs, 1) - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { receivedArgs <- receivedHooksArgs{events: events, cfg: cfg} return modifiedEvents, nil } @@ -96,14 +99,17 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Success(t *testin context.Background(): subId, }) - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{testHook}, + updater := NewSubscriptionEventUpdater( + config, + Hooks{ + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{testHook}, + }, }, - eventBuilder: testEventBuilder, - } + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Update(originalEvents) @@ -129,7 +135,7 @@ func TestSubscriptionEventUpdater_UpdateSubscriptions_WithHooks_Error(t *testing hookError := errors.New("hook processing error") // Define hook that returns an error - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) 
([]StreamEvent, error) { + testHook := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return nil, hookError } @@ -141,14 +147,17 @@ func TestSubscriptionEventUpdater_UpdateSubscriptions_WithHooks_Error(t *testing mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() // Should not call Update or UpdateSubscription on eventUpdater since hook fails - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{testHook}, + updater := NewSubscriptionEventUpdater( + config, + Hooks{ + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{testHook}, + }, }, - eventBuilder: testEventBuilder, - } + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Update(events) @@ -171,13 +180,13 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) // Chain of hooks that modify the data receivedArgs1 := make(chan receivedHooksArgs, 1) - hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook1 := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { receivedArgs1 <- receivedHooksArgs{events: events, cfg: cfg} return []StreamEvent{&testEvent{mutableTestEvent("modified by hook1")}}, nil } receivedArgs2 := make(chan receivedHooksArgs, 1) - hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook2 := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { 
receivedArgs2 <- receivedHooksArgs{events: events, cfg: cfg} return []StreamEvent{&testEvent{mutableTestEvent("modified by hook2")}}, nil } @@ -189,14 +198,17 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) context.Background(): subId, }) - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{hook1, hook2}, + updater := NewSubscriptionEventUpdater( + config, + Hooks{ + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{hook1, hook2}, + }, }, - eventBuilder: testEventBuilder, - } + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Update(originalEvents) @@ -227,11 +239,13 @@ func TestSubscriptionEventUpdater_Complete(t *testing.T) { mockUpdater.On("Complete").Return() - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{}, - } + updater := NewSubscriptionEventUpdater( + config, + Hooks{}, + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Complete() } @@ -247,11 +261,13 @@ func TestSubscriptionEventUpdater_Close(t *testing.T) { mockUpdater.On("Close", closeKind).Return() - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{}, - } + updater := NewSubscriptionEventUpdater( + config, + Hooks{}, + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Close(closeKind) } @@ -264,24 +280,30 @@ func TestSubscriptionEventUpdater_SetHooks(t *testing.T) { fieldName: "testField", } - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return events, nil } hooks := 
Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{testHook}, + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{testHook}, + }, } - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{}, - eventBuilder: testEventBuilder, - } + updater := NewSubscriptionEventUpdater( + config, + Hooks{}, + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.SetHooks(hooks) - assert.Equal(t, hooks, updater.hooks) + // Type assert to access internal fields for testing + concreteUpdater, ok := updater.(*subscriptionEventUpdater) + require.True(t, ok) + assert.Equal(t, hooks, concreteUpdater.hooks) } func TestNewSubscriptionEventUpdater(t *testing.T) { @@ -292,12 +314,14 @@ func TestNewSubscriptionEventUpdater(t *testing.T) { fieldName: "testField", } - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return events, nil } hooks := Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{testHook}, + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{testHook}, + }, } updater := NewSubscriptionEventUpdater(config, hooks, mockUpdater, zap.NewNop(), testEventBuilder) @@ -331,11 +355,13 @@ func TestSubscriptionEventUpdater_Update_PassthroughWithNoHooks(t *testing.T) { mockUpdater.On("Update", []byte("event data 2")).Return() mockUpdater.On("Update", []byte("event data 3")).Return() - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{}, // No hooks - } + updater := NewSubscriptionEventUpdater( + config, + Hooks{}, // No hooks + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Update(events) @@ -359,7 +385,7 @@ func 
TestSubscriptionEventUpdater_Update_WithSingleHookModification(t *testing.T } // Hook that modifies events by adding a prefix - hook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { modifiedEvents := make([]StreamEvent, len(events)) for i, event := range events { modifiedData := "modified: " + string(event.GetData()) @@ -377,13 +403,17 @@ func TestSubscriptionEventUpdater_Update_WithSingleHookModification(t *testing.T mockUpdater.On("UpdateSubscription", subId, []byte("modified: original data 1")).Return() mockUpdater.On("UpdateSubscription", subId, []byte("modified: original data 2")).Return() - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{hook}, + updater := NewSubscriptionEventUpdater( + config, + Hooks{ + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{hook}, + }, }, - } + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Update(originalEvents) @@ -395,7 +425,7 @@ func TestSubscriptionEventUpdater_Update_WithSingleHookModification(t *testing.T mockUpdater.AssertNotCalled(t, "Update") } -func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscriptionAndLogsError(t *testing.T) { +func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscription(t *testing.T) { mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ providerID: "test-provider", @@ -408,15 +438,11 @@ func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscriptionA hookError := errors.New("hook processing failed") // Hook that returns an error - hook := func(ctx context.Context, cfg 
SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { // Return the events but also return an error return events, hookError } - // Set up logger with observer to verify error logging - zCore, logObserver := observer.New(zap.InfoLevel) - logger := zap.New(zCore) - subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ context.Background(): subId, @@ -427,8 +453,10 @@ func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscriptionA mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() updater := NewSubscriptionEventUpdater(config, Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{hook}, - }, mockUpdater, logger, testEventBuilder) + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{hook}, + }, + }, mockUpdater, zap.NewNop(), testEventBuilder) updater.Update(events) @@ -438,16 +466,6 @@ func TestSubscriptionEventUpdater_Update_WithSingleHookError_ClosesSubscriptionA mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) // Update should NOT be called when hooks are present mockUpdater.AssertNotCalled(t, "Update") - - // Verify error was logged (logging happens asynchronously) - assert.Eventually(t, func() bool { - logs := logObserver.FilterMessageSnippet("some handlers have thrown an error").TakeAll() - if len(logs) != 1 { - return false - } - // Verify the logged error message contains our error - return logs[0].ContextMap()["error"] == hookError.Error() - }, time.Second, 10*time.Millisecond, "expected error to be logged") } func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) { @@ -467,7 +485,7 
@@ func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) // Hook 1: Adds "step1: " prefix receivedArgs1 := make(chan receivedHooksArgs, 1) - hook1 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook1 := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { mu.Lock() hookCallOrder = append(hookCallOrder, 1) mu.Unlock() @@ -482,7 +500,7 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) // Hook 2: Adds "step2: " prefix receivedArgs2 := make(chan receivedHooksArgs, 1) - hook2 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook2 := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { mu.Lock() hookCallOrder = append(hookCallOrder, 2) mu.Unlock() @@ -497,7 +515,7 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) // Hook 3: Adds "step3: " prefix receivedArgs3 := make(chan receivedHooksArgs, 1) - hook3 := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook3 := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { mu.Lock() hookCallOrder = append(hookCallOrder, 3) mu.Unlock() @@ -517,13 +535,17 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooksChaining(t *testing.T) // Final modified data should have all three transformations applied mockUpdater.On("UpdateSubscription", subId, []byte("step3: step2: step1: original")).Return() - updater := 
&subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{hook1, hook2, hook3}, + updater := NewSubscriptionEventUpdater( + config, + Hooks{ + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{hook1, hook2, hook3}, + }, }, - } + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Update(originalEvents) @@ -579,11 +601,13 @@ func TestSubscriptionEventUpdater_UpdateEvents_EmptyEvents(t *testing.T) { } events := []StreamEvent{} // Empty events - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{}, // No hooks - } + updater := NewSubscriptionEventUpdater( + config, + Hooks{}, // No hooks + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Update(events) @@ -612,11 +636,13 @@ func TestSubscriptionEventUpdater_Close_WithDifferentCloseKinds(t *testing.T) { mockUpdater.On("Close", tc.closeKind).Return() - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{}, - } + updater := NewSubscriptionEventUpdater( + config, + Hooks{}, + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) updater.Close(tc.closeKind) }) @@ -650,18 +676,21 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHookError_ClosesSubscri &testEvent{mutableTestEvent("test data")}, } - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { return events, tc.hookError } - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{testHook}, + updater 
:= NewSubscriptionEventUpdater( + config, + Hooks{ + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{testHook}, + }, }, - eventBuilder: testEventBuilder, - } + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} mockUpdater.On("UpdateSubscription", subId, []byte("test data")).Return() @@ -677,50 +706,6 @@ func TestSubscriptionEventUpdater_UpdateSubscription_WithHookError_ClosesSubscri } } -func TestSubscriptionEventUpdater_UpdateSubscription_WithHooks_Error_LoggerWritesError(t *testing.T) { - mockUpdater := NewMockSubscriptionUpdater(t) - config := &testSubscriptionEventConfig{ - providerID: "test-provider", - providerType: ProviderTypeNats, - fieldName: "testField", - } - events := []StreamEvent{ - &testEvent{mutableTestEvent("test data")}, - } - hookError := errors.New("hook processing error") - - // Define hook that returns an error - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { - return nil, hookError - } - - zCore, logObserver := observer.New(zap.InfoLevel) - logger := zap.New(zCore) - - // Test with a real zap logger to verify error logging behavior - // The logger.Error() call should be executed when an error occurs - updater := NewSubscriptionEventUpdater(config, Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{testHook}, - }, mockUpdater, logger, testEventBuilder) - - subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} - mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ - context.Background(): subId, - }) - mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() - - updater.Update(events) - - // Assert that Update was not called on the eventUpdater - mockUpdater.AssertNotCalled(t, "UpdateSubscription") - mockUpdater.AssertCalled(t, "CloseSubscription", 
resolve.SubscriptionCloseKindNormal, subId) - - // log error messages for hooks are written async, we need to wait for them to be written - assert.Eventually(t, func() bool { - return len(logObserver.FilterMessageSnippet("some handlers have thrown an error").TakeAll()) == 1 - }, time.Second, 10*time.Millisecond, "expected one deduplicated error log") -} - func TestSubscriptionEventUpdater_OnReceiveEvents_PanicRecovery(t *testing.T) { panicErr := errors.New("panic error") @@ -758,7 +743,7 @@ func TestSubscriptionEventUpdater_OnReceiveEvents_PanicRecovery(t *testing.T) { } // Create hook that panics - testHook := func(ctx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + testHook := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { panic(tt.panicValue) } @@ -768,14 +753,17 @@ func TestSubscriptionEventUpdater_OnReceiveEvents_PanicRecovery(t *testing.T) { }) mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindDownstreamServiceError, subId).Return() - updater := &subscriptionEventUpdater{ - eventUpdater: mockUpdater, - subscriptionEventConfiguration: config, - hooks: Hooks{ - OnReceiveEvents: []OnReceiveEventsFn{testHook}, + updater := NewSubscriptionEventUpdater( + config, + Hooks{ + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{testHook}, + }, }, - logger: logger, - } + mockUpdater, + logger, + testEventBuilder, + ) updater.Update(events) diff --git a/router/pkg/pubsub/pubsub_test.go b/router/pkg/pubsub/pubsub_test.go index 39444689ac..f0b99ec303 100644 --- a/router/pkg/pubsub/pubsub_test.go +++ b/router/pkg/pubsub/pubsub_test.go @@ -63,8 +63,8 @@ func TestBuild_OK(t *testing.T) { mockPubSubProvider.On("ID").Return("provider-1") mockPubSubProvider.On("SetHooks", datasource.Hooks{ - OnReceiveEvents: 
[]datasource.OnReceiveEventsFn(nil), - OnPublishEvents: []datasource.OnPublishEventsFn(nil), + OnReceiveEvents: datasource.OnReceiveEventsHooks{Handlers: []datasource.OnReceiveEventsFn(nil)}, + OnPublishEvents: datasource.OnPublishEventsHooks{Handlers: []datasource.OnPublishEventsFn(nil)}, }).Return(nil) mockBuilder.On("TypeID").Return("nats") @@ -242,8 +242,8 @@ func TestBuild_ShouldNotInitializeProviderIfNotUsed(t *testing.T) { mockPubSubUsedProvider.On("ID").Return("provider-2") mockPubSubUsedProvider.On("SetHooks", datasource.Hooks{ - OnReceiveEvents: []datasource.OnReceiveEventsFn(nil), - OnPublishEvents: []datasource.OnPublishEventsFn(nil), + OnReceiveEvents: datasource.OnReceiveEventsHooks{Handlers: []datasource.OnReceiveEventsFn(nil)}, + OnPublishEvents: datasource.OnPublishEventsHooks{Handlers: []datasource.OnPublishEventsFn(nil)}, }).Return(nil) mockBuilder.On("TypeID").Return("nats") From 2955b6278f29f266ca12a564b428de1c2db2c210 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 14 Nov 2025 12:00:55 +0100 Subject: [PATCH 11/44] fix(router): return correct error when creating event config fails (Cosmo Streams) (#2333) --- .../datasource/subscription_datasource.go | 4 +-- .../subscription_datasource_test.go | 35 +++++++++++++++++++ 2 files changed, 37 insertions(+), 2 deletions(-) diff --git a/router/pkg/pubsub/datasource/subscription_datasource.go b/router/pkg/pubsub/datasource/subscription_datasource.go index 9285d6cfb7..652644e7da 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource.go +++ b/router/pkg/pubsub/datasource/subscription_datasource.go @@ -74,8 +74,8 @@ func (s *PubSubSubscriptionDataSource[C]) SubscriptionOnStart(ctx resolve.Startu }() for _, fn := range s.hooks.SubscriptionOnStart.Handlers { - conf, errConf := s.SubscriptionEventConfiguration(input) - if errConf != nil { + conf, err := s.SubscriptionEventConfiguration(input) + if err != nil { return err } err = fn(ctx, conf, 
s.eventBuilder) diff --git a/router/pkg/pubsub/datasource/subscription_datasource_test.go b/router/pkg/pubsub/datasource/subscription_datasource_test.go index 6e2d957a07..3e4ccf4f7b 100644 --- a/router/pkg/pubsub/datasource/subscription_datasource_test.go +++ b/router/pkg/pubsub/datasource/subscription_datasource_test.go @@ -9,6 +9,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -393,6 +394,40 @@ func TestPubSubSubscriptionDataSource_InterfaceCompliance(t *testing.T) { var _ resolve.HookableSubscriptionDataSource = dataSource } +func TestPubSubSubscriptionDataSource_SubscriptionOnStart_InvalidEventConfigInput(t *testing.T) { + mockAdapter := NewMockProvider(t) + uniqueRequestIDFn := func(ctx *resolve.Context, input []byte, xxh *xxhash.Digest) error { + return nil + } + + dataSource := NewPubSubSubscriptionDataSource[testSubscriptionEventConfiguration](mockAdapter, uniqueRequestIDFn, zap.NewNop(), testSubscriptionDataSourceEventBuilder) + + hookCalled := false + hook := func(ctx resolve.StartupHookContext, config SubscriptionEventConfiguration, eventBuilder EventBuilderFn) error { + hookCalled = true + return nil + } + + dataSource.SetHooks(Hooks{ + SubscriptionOnStart: SubscriptionOnStartHooks{ + Handlers: []SubscriptionOnStartFn{hook}, + }, + }) + + invalidInput := []byte(`{"invalid": json}`) + + ctx := resolve.StartupHookContext{ + Context: context.Background(), + Updater: func(data []byte) {}, + } + + err := dataSource.SubscriptionOnStart(ctx, invalidInput) + + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid character 'j' looking for beginning of value") + assert.False(t, hookCalled) +} + func TestPubSubSubscriptionDataSource_SubscriptionOnStart_PanicRecovery(t *testing.T) { panicErr := errors.New("panic error") From 
34f020b718ddcdf9b1ddac7f7de8d45d28587a9c Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Sun, 16 Nov 2025 11:33:47 +0100 Subject: [PATCH 12/44] chore: ensure nop logger is set when no logger is passed --- router/pkg/pubsub/datasource/pubsubprovider.go | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index 1920bc2b46..aa21c02fc8 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -23,13 +23,9 @@ type PubSubProvider struct { func (p *PubSubProvider) applyPublishEventHooks(ctx context.Context, cfg PublishEventConfiguration, events []StreamEvent) (currentEvents []StreamEvent, err error) { defer func() { if r := recover(); r != nil { - if p.Logger != nil { - p.Logger. - WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)). - Error("[Recovery from handler panic]", - zap.Any("error", r), - ) - } + p.Logger. + WithOptions(zap.AddStacktrace(zapcore.ErrorLevel)). + Error("[Recovery from handler panic]", zap.Any("error", r)) switch v := r.(type) { case error: @@ -109,6 +105,10 @@ func (p *PubSubProvider) SetHooks(hooks Hooks) { } func NewPubSubProvider(id string, typeID string, adapter Adapter, logger *zap.Logger, eventBuilder EventBuilderFn) *PubSubProvider { + if logger == nil { + logger = zap.NewNop() + } + return &PubSubProvider{ id: id, typeID: typeID, From b1b9df3da65eccb5ef1d8f7840d386d57d9510a8 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Sun, 16 Nov 2025 12:16:39 +0100 Subject: [PATCH 13/44] fix: prevent potential race condition The updateSubscription method is running once per subscription client in separate goroutines. The events slice passed into this function is shared between all routines.
The hook execution in line 107 can either return the original events slice or a new deep copy, if events got modified by a hook. After a hook call returned we cleaned up the events slice from nil elements, thus modifying it. If the hook call prior to that returned the original events slice instead of a deep copy, we caused a race condition, since we operate on the original slice with its underlying shared array. To prevent this I removed the cleanup of nil elements and instead filter for these elements in the only place where they could hurt: The subscription update call to the engine. This way the updateSubscription function does not write to the slice at all, only reads. The only writes could happen in the hook call and our hook interface design should prevent modifying the original slice and instead do a deep copy first. --- router/pkg/pubsub/datasource/subscription_event_updater.go | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 0dd1060faf..3b663f1b6d 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -2,7 +2,6 @@ package datasource import ( "context" - "slices" "sync" "time" @@ -106,15 +105,15 @@ func (s *subscriptionEventUpdater) updateSubscription(subscriptionCtx context.Co var err error for i := range hooks { events, err = hooks[i](subscriptionCtx, updaterCtx, s.subscriptionEventConfiguration, s.eventBuilder, events) - events = slices.DeleteFunc(events, func(event StreamEvent) bool { - return event == nil - }) } // send events to the subscription, // regardless if there was an error during hook processing. // If no events should be sent, hook must return no events. 
for _, event := range events { + if event == nil { + continue + } s.eventUpdater.UpdateSubscription(subID, event.GetData()) } From c3d33dadc2494c1ff41e44d2ea62d0b12d2367d1 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Sun, 16 Nov 2025 12:30:59 +0100 Subject: [PATCH 14/44] fix: remove unnecessary semaphore channel close It might not be guaranteed that there are no more subscription updaters running when the event updater Close() function is called. There is no logic waiting for the channel to close and since the garbage collector cleans up orphaned channels, there is no actual need to close the channel here. --- router/pkg/pubsub/datasource/subscription_event_updater.go | 1 - 1 file changed, 1 deletion(-) diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 3b663f1b6d..36b1aa5593 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -83,7 +83,6 @@ func (s *subscriptionEventUpdater) Complete() { func (s *subscriptionEventUpdater) Close(kind resolve.SubscriptionCloseKind) { s.eventUpdater.Close(kind) - close(s.semaphore) } func (s *subscriptionEventUpdater) SetHooks(hooks Hooks) { From b311cef76fd1649b563f75a7ed6ffaed8efacb72 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Sun, 16 Nov 2025 13:56:31 +0100 Subject: [PATCH 15/44] fix: escape input when marshalling json template The MarshalJSONTemplate method creates json-like output data, but it's not actually JSON-conformant. It contains placeholders with dollar signs as delimiters. For this reason we cannot use a json marshaller to render the output and resorted to creating the string manually. This could become a security problem because we do not really validate the input we use to create the output. 
Hypothetically an attacker could escape in that function and inject custom fields into the output. To circumvent this we marshal as many fields as we can. The only remaining field p.Event.Data does not contain user-input data, so we are safe not escaping it. --- router/pkg/pubsub/kafka/engine_datasource.go | 51 ++++++++++++++++++-- router/pkg/pubsub/nats/engine_datasource.go | 40 +++++++++++++-- router/pkg/pubsub/redis/engine_datasource.go | 38 ++++++++++++++- 3 files changed, 121 insertions(+), 8 deletions(-) diff --git a/router/pkg/pubsub/kafka/engine_datasource.go b/router/pkg/pubsub/kafka/engine_datasource.go index 9d48fd0db0..d3ebb1fc5c 100644 --- a/router/pkg/pubsub/kafka/engine_datasource.go +++ b/router/pkg/pubsub/kafka/engine_datasource.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "slices" + "strings" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" + goccyjson "github.com/goccy/go-json" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -125,19 +127,60 @@ func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfigu } func (p *publishData) MarshalJSONTemplate() (string, error) { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. {"id":$$0$$,"update":$$1$$} + // The content of p.Event.Data contains template placeholders like $$0$$, $$1$$ + // which are not valid JSON. We can't use json.Marshal for these parts. + // Instead, we use json.Marshal for the safe parts (headers, topic, providerId, rootFieldName, key) + // and manually construct the final JSON string. 
+ headers := p.Event.Headers if headers == nil { headers = make(map[string][]byte) } - headersBytes, err := json.Marshal(headers) + var builder strings.Builder + builder.Grow(256 + len(p.Event.Data)) + + builder.WriteString(`{"topic":`) + topicBytes, err := goccyjson.Marshal(p.Topic) + if err != nil { + return "", err + } + builder.Write(topicBytes) + + builder.WriteString(`, "event": {"data": `) + builder.Write(p.Event.Data) + + builder.WriteString(`, "key": `) + keyBytes, err := goccyjson.Marshal(string(p.Event.Key)) + if err != nil { + return "", err + } + builder.Write(keyBytes) + + builder.WriteString(`, "headers": `) + headersBytes, err := goccyjson.Marshal(headers) + if err != nil { + return "", err + } + builder.Write(headersBytes) + + builder.WriteString(`}, "providerId":`) + providerBytes, err := goccyjson.Marshal(p.Provider) if err != nil { return "", err } + builder.Write(providerBytes) + + builder.WriteString(`, "rootFieldName":`) + rootFieldNameBytes, err := goccyjson.Marshal(p.FieldName) + if err != nil { + return "", err + } + builder.Write(rootFieldNameBytes) + + builder.WriteString(`}`) - return fmt.Sprintf(`{"topic":"%s", "event": {"data": %s, "key": "%s", "headers": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Topic, p.Event.Data, p.Event.Key, headersBytes, p.Provider, p.FieldName), nil + return builder.String(), nil } // PublishEventConfiguration is a public type that is used to allow access to custom fields diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index f0b8c7b57a..8d1eb6a1e1 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "slices" + "strings" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" + goccyjson "github.com/goccy/go-json" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" 
"github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -126,9 +128,41 @@ func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfigu } func (p *publishData) MarshalJSONTemplate() (string, error) { - // The content of the data field could be not valid JSON, so we can't use json.Marshal - // e.g. {"id":$$0$$,"update":$$1$$} - return fmt.Sprintf(`{"subject":"%s", "event": {"data": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Subject, p.Event.Data, p.Provider, p.FieldName), nil + // The content of p.Event.Data containa template placeholders like $$0$$, $$1$$ + // which are not valid JSON. We can't use json.Marshal for these parts. + // Instead, we use json.Marshal for the safe parts (subject, providerId, rootFieldName) + // and manually construct the final JSON string. + + var builder strings.Builder + builder.Grow(256 + len(p.Event.Data)) + + builder.WriteString(`{"subject":`) + topicBytes, err := goccyjson.Marshal(p.Subject) + if err != nil { + return "", err + } + builder.Write(topicBytes) + + builder.WriteString(`, "event": {"data": `) + builder.Write(p.Event.Data) + + builder.WriteString(`}, "providerId":`) + providerBytes, err := goccyjson.Marshal(p.Provider) + if err != nil { + return "", err + } + builder.Write(providerBytes) + + builder.WriteString(`, "rootFieldName":`) + rootFieldNameBytes, err := goccyjson.Marshal(p.FieldName) + if err != nil { + return "", err + } + builder.Write(rootFieldNameBytes) + + builder.WriteString(`}`) + + return builder.String(), nil } type PublishAndRequestEventConfiguration struct { diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index 56f00a4841..929aab5739 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -7,9 +7,11 @@ import ( "fmt" "io" "slices" + "strings" "github.com/buger/jsonparser" "github.com/cespare/xxhash/v2" + goccyjson "github.com/goccy/go-json" 
"github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/datasource/httpclient" "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" @@ -98,7 +100,41 @@ func (p *publishData) PublishEventConfiguration() datasource.PublishEventConfigu } func (p *publishData) MarshalJSONTemplate() (string, error) { - return fmt.Sprintf(`{"channel":"%s", "event": {"data": %s}, "providerId":"%s", "rootFieldName":"%s"}`, p.Channel, p.Event.Data, p.Provider, p.FieldName), nil + // The content of p.Event.Data containa template placeholders like $$0$$, $$1$$ + // which are not valid JSON. We can't use json.Marshal for these parts. + // Instead, we use json.Marshal for the safe parts (subject, providerId, rootFieldName) + // and manually construct the final JSON string. + + var builder strings.Builder + builder.Grow(256 + len(p.Event.Data)) + + builder.WriteString(`{"channel":`) + topicBytes, err := goccyjson.Marshal(p.Channel) + if err != nil { + return "", err + } + builder.Write(topicBytes) + + builder.WriteString(`, "event": {"data": `) + builder.Write(p.Event.Data) + + builder.WriteString(`}, "providerId":`) + providerBytes, err := goccyjson.Marshal(p.Provider) + if err != nil { + return "", err + } + builder.Write(providerBytes) + + builder.WriteString(`, "rootFieldName":`) + rootFieldNameBytes, err := goccyjson.Marshal(p.FieldName) + if err != nil { + return "", err + } + builder.Write(rootFieldNameBytes) + + builder.WriteString(`}`) + + return builder.String(), nil } // PublishEventConfiguration contains configuration for publish events From 538c0ae118e7a9e5d857c158155fdae17e132950 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Sun, 16 Nov 2025 14:13:03 +0100 Subject: [PATCH 16/44] fix: stop processing hooks on error If a hook returns an error we stop processing the hook and use that error for further error handling. 
--- .../datasource/subscription_event_updater.go | 3 + .../subscription_event_updater_test.go | 80 +++++++++++++++++++ 2 files changed, 83 insertions(+) diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 36b1aa5593..f165dfe922 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -104,6 +104,9 @@ func (s *subscriptionEventUpdater) updateSubscription(subscriptionCtx context.Co var err error for i := range hooks { events, err = hooks[i](subscriptionCtx, updaterCtx, s.subscriptionEventConfiguration, s.eventBuilder, events) + if err != nil { + break + } } // send events to the subscription, diff --git a/router/pkg/pubsub/datasource/subscription_event_updater_test.go b/router/pkg/pubsub/datasource/subscription_event_updater_test.go index e8fceabf2e..2c7f6bac3e 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater_test.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater_test.go @@ -4,6 +4,7 @@ import ( "context" "errors" "sync" + "sync/atomic" "testing" "time" @@ -229,6 +230,85 @@ func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Success(t *testing.T) } } +func TestSubscriptionEventUpdater_Update_WithMultipleHooks_Error(t *testing.T) { + mockUpdater := NewMockSubscriptionUpdater(t) + config := &testSubscriptionEventConfig{ + providerID: "test-provider", + providerType: ProviderTypeNats, + fieldName: "testField", + } + originalEvents := []StreamEvent{ + &testEvent{mutableTestEvent("original data")}, + } + hookError := errors.New("first hook error") + + var hook1Called, hook2Called, hook3Called atomic.Bool + + // Hook 1: Returns an error + hook1 := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook1Called.Store(true) + // Return the original 
events but with an error + return events, hookError + } + + // Hook 2: Should not be called since hook1 returned an error + hook2 := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook2Called.Store(true) + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook2")}}, nil + } + + // Hook 3: Should not be called since hook1 returned an error + hook3 := func(subCtx context.Context, updaterCtx context.Context, cfg SubscriptionEventConfiguration, eventBuilder EventBuilderFn, events []StreamEvent) ([]StreamEvent, error) { + hook3Called.Store(true) + return []StreamEvent{&testEvent{mutableTestEvent("modified by hook3")}}, nil + } + + subId := resolve.SubscriptionIdentifier{ConnectionID: 1, SubscriptionID: 1} + mockUpdater.On("Subscriptions").Return(map[context.Context]resolve.SubscriptionIdentifier{ + context.Background(): subId, + }) + // Events from hook1 should still be sent despite the error + mockUpdater.On("UpdateSubscription", subId, []byte("original data")).Return() + // Subscription should be closed due to the error from hook1 + mockUpdater.On("CloseSubscription", resolve.SubscriptionCloseKindNormal, subId).Return() + + updater := NewSubscriptionEventUpdater( + config, + Hooks{ + OnReceiveEvents: OnReceiveEventsHooks{ + Handlers: []OnReceiveEventsFn{hook1, hook2, hook3}, + }, + }, + mockUpdater, + zap.NewNop(), + testEventBuilder, + ) + + updater.Update(originalEvents) + + // Verify hook1 was called + assert.Eventually(t, func() bool { + return hook1Called.Load() + }, 1*time.Second, 10*time.Millisecond, "hook1 should have been called") + + // Verify hook2 was NOT called + assert.Never(t, func() bool { + return hook2Called.Load() + }, 100*time.Millisecond, 10*time.Millisecond, "hook2 should not have been called after hook1 returned an error") + + // Verify hook3 was NOT called + assert.Never(t, func() bool { + return 
hook3Called.Load() + }, 100*time.Millisecond, 10*time.Millisecond, "hook3 should not have been called after hook1 returned an error") + + // Verify events from hook1 were still sent + mockUpdater.AssertCalled(t, "UpdateSubscription", subId, []byte("original data")) + // Verify subscription was closed due to hook1's error + mockUpdater.AssertCalled(t, "CloseSubscription", resolve.SubscriptionCloseKindNormal, subId) + // Verify Update was not called (since hooks are present) + mockUpdater.AssertNotCalled(t, "Update") +} + func TestSubscriptionEventUpdater_Complete(t *testing.T) { mockUpdater := NewMockSubscriptionUpdater(t) config := &testSubscriptionEventConfig{ From 81b7eaeed4f3a30e9e550929bb1ed01425b91dd3 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Mon, 17 Nov 2025 09:53:00 +0100 Subject: [PATCH 17/44] chore: fix test name + add nil logger checks --- router-tests/modules/start-subscription/module.go | 5 +++-- router-tests/modules/start_subscription_test.go | 2 +- router-tests/modules/stream-publish/module.go | 4 +++- router-tests/modules/stream-receive/module.go | 4 +++- 4 files changed, 10 insertions(+), 5 deletions(-) diff --git a/router-tests/modules/start-subscription/module.go b/router-tests/modules/start-subscription/module.go index ffa94ef1f0..a7d7706a88 100644 --- a/router-tests/modules/start-subscription/module.go +++ b/router-tests/modules/start-subscription/module.go @@ -24,8 +24,9 @@ func (m *StartSubscriptionModule) Provision(ctx *core.ModuleContext) error { } func (m *StartSubscriptionModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHandlerContext) error { - - m.Logger.Info("SubscriptionOnStart Hook has been run") + if m.Logger != nil { + m.Logger.Info("SubscriptionOnStart Hook has been run") + } if m.Callback != nil { return m.Callback(ctx) diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index de738215f4..0d9ee9c054 100644 
--- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -583,7 +583,7 @@ func TestStartSubscriptionHook(t *testing.T) { }) }) - t.Run("Test StartSubscription hook is called, return StreamHookError, response on OnOriginResponse should still be set", func(t *testing.T) { + t.Run("Test when StartSubscription hook returns an error, the OnOriginResponse hook is not called", func(t *testing.T) { t.Parallel() originResponseCalled := make(chan *http.Response, 1) diff --git a/router-tests/modules/stream-publish/module.go b/router-tests/modules/stream-publish/module.go index ef5c24277b..c59df2ffe6 100644 --- a/router-tests/modules/stream-publish/module.go +++ b/router-tests/modules/stream-publish/module.go @@ -22,7 +22,9 @@ func (m *PublishModule) Provision(ctx *core.ModuleContext) error { } func (m *PublishModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - m.Logger.Info("Publish Hook has been run") + if m.Logger != nil { + m.Logger.Info("Publish Hook has been run") + } if m.Callback != nil { return m.Callback(ctx, events) diff --git a/router-tests/modules/stream-receive/module.go b/router-tests/modules/stream-receive/module.go index 51d2b22a33..cdb680d015 100644 --- a/router-tests/modules/stream-receive/module.go +++ b/router-tests/modules/stream-receive/module.go @@ -22,7 +22,9 @@ func (m *StreamReceiveModule) Provision(ctx *core.ModuleContext) error { } func (m *StreamReceiveModule) OnReceiveEvents(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - m.Logger.Info("Stream Hook has been run") + if m.Logger != nil { + m.Logger.Info("Stream Hook has been run") + } if m.Callback != nil { return m.Callback(ctx, events) From a0ed7e5e237eef85760168a27ccdcddfcf859d6a Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Mon, 17 Nov 2025 
10:08:31 +0100 Subject: [PATCH 18/44] fix: wrong mock in tests --- router/pkg/pubsub/datasource/mocks.go | 28 ++++++++++++++++----------- router/pkg/pubsub/pubsub_test.go | 2 +- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/router/pkg/pubsub/datasource/mocks.go b/router/pkg/pubsub/datasource/mocks.go index 3c56f09919..69b5a7eab8 100644 --- a/router/pkg/pubsub/datasource/mocks.go +++ b/router/pkg/pubsub/datasource/mocks.go @@ -965,7 +965,7 @@ func (_c *MockProviderBuilder_BuildEngineDataSourceFactory_Call[P, E]) RunAndRet // BuildProvider provides a mock function for the type MockProviderBuilder func (_mock *MockProviderBuilder[P, E]) BuildProvider(options P, providerOpts ProviderOpts) (Provider, error) { - ret := _mock.Called(options) + ret := _mock.Called(options, providerOpts) if len(ret) == 0 { panic("no return value specified for BuildProvider") @@ -973,18 +973,18 @@ func (_mock *MockProviderBuilder[P, E]) BuildProvider(options P, providerOpts Pr var r0 Provider var r1 error - if returnFunc, ok := ret.Get(0).(func(P) (Provider, error)); ok { - return returnFunc(options) + if returnFunc, ok := ret.Get(0).(func(P, ProviderOpts) (Provider, error)); ok { + return returnFunc(options, providerOpts) } - if returnFunc, ok := ret.Get(0).(func(P) Provider); ok { - r0 = returnFunc(options) + if returnFunc, ok := ret.Get(0).(func(P, ProviderOpts) Provider); ok { + r0 = returnFunc(options, providerOpts) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(Provider) } } - if returnFunc, ok := ret.Get(1).(func(P) error); ok { - r1 = returnFunc(options) + if returnFunc, ok := ret.Get(1).(func(P, ProviderOpts) error); ok { + r1 = returnFunc(options, providerOpts) } else { r1 = ret.Error(1) } @@ -998,18 +998,24 @@ type MockProviderBuilder_BuildProvider_Call[P any, E any] struct { // BuildProvider is a helper method to define mock.On call // - options P -func (_e *MockProviderBuilder_Expecter[P, E]) BuildProvider(options interface{}) 
*MockProviderBuilder_BuildProvider_Call[P, E] { - return &MockProviderBuilder_BuildProvider_Call[P, E]{Call: _e.mock.On("BuildProvider", options)} +// - providerOpts ProviderOpts +func (_e *MockProviderBuilder_Expecter[P, E]) BuildProvider(options interface{}, providerOpts interface{}) *MockProviderBuilder_BuildProvider_Call[P, E] { + return &MockProviderBuilder_BuildProvider_Call[P, E]{Call: _e.mock.On("BuildProvider", options, providerOpts)} } -func (_c *MockProviderBuilder_BuildProvider_Call[P, E]) Run(run func(options P)) *MockProviderBuilder_BuildProvider_Call[P, E] { +func (_c *MockProviderBuilder_BuildProvider_Call[P, E]) Run(run func(options P, providerOpts ProviderOpts)) *MockProviderBuilder_BuildProvider_Call[P, E] { _c.Call.Run(func(args mock.Arguments) { var arg0 P if args[0] != nil { arg0 = args[0].(P) } + var arg1 ProviderOpts + if args[1] != nil { + arg1 = args[1].(ProviderOpts) + } run( arg0, + arg1, ) }) return _c @@ -1020,7 +1026,7 @@ func (_c *MockProviderBuilder_BuildProvider_Call[P, E]) Return(provider Provider return _c } -func (_c *MockProviderBuilder_BuildProvider_Call[P, E]) RunAndReturn(run func(options P) (Provider, error)) *MockProviderBuilder_BuildProvider_Call[P, E] { +func (_c *MockProviderBuilder_BuildProvider_Call[P, E]) RunAndReturn(run func(options P, providerOpts ProviderOpts) (Provider, error)) *MockProviderBuilder_BuildProvider_Call[P, E] { _c.Call.Return(run) return _c } diff --git a/router/pkg/pubsub/pubsub_test.go b/router/pkg/pubsub/pubsub_test.go index f0b99ec303..2d430440e2 100644 --- a/router/pkg/pubsub/pubsub_test.go +++ b/router/pkg/pubsub/pubsub_test.go @@ -68,7 +68,7 @@ func TestBuild_OK(t *testing.T) { }).Return(nil) mockBuilder.On("TypeID").Return("nats") - mockBuilder.On("BuildProvider", natsEventSources[0]).Return(mockPubSubProvider, nil) + mockBuilder.On("BuildProvider", natsEventSources[0], mock.Anything).Return(mockPubSubProvider, nil) // ctx, kafkaBuilder, config.Providers.Kafka, kafkaDsConfsWithEvents // 
Execute the function From ef85d0a740e97d26bb48e42ef53de426d0801581 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Mon, 17 Nov 2025 14:11:49 +0100 Subject: [PATCH 19/44] chore: remove obsolete return from mock + rename test package --- router-tests/modules/stream-publish/module.go | 2 +- router-tests/modules/stream-receive/module.go | 2 +- router/pkg/pubsub/pubsub_test.go | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/router-tests/modules/stream-publish/module.go b/router-tests/modules/stream-publish/module.go index c59df2ffe6..15c70b21bf 100644 --- a/router-tests/modules/stream-publish/module.go +++ b/router-tests/modules/stream-publish/module.go @@ -1,4 +1,4 @@ -package publish +package stream_publish import ( "go.uber.org/zap" diff --git a/router-tests/modules/stream-receive/module.go b/router-tests/modules/stream-receive/module.go index cdb680d015..08fdb390d9 100644 --- a/router-tests/modules/stream-receive/module.go +++ b/router-tests/modules/stream-receive/module.go @@ -1,4 +1,4 @@ -package batch +package stream_receive import ( "go.uber.org/zap" diff --git a/router/pkg/pubsub/pubsub_test.go b/router/pkg/pubsub/pubsub_test.go index 2d430440e2..e89c94e885 100644 --- a/router/pkg/pubsub/pubsub_test.go +++ b/router/pkg/pubsub/pubsub_test.go @@ -65,7 +65,7 @@ func TestBuild_OK(t *testing.T) { mockPubSubProvider.On("SetHooks", datasource.Hooks{ OnReceiveEvents: datasource.OnReceiveEventsHooks{Handlers: []datasource.OnReceiveEventsFn(nil)}, OnPublishEvents: datasource.OnPublishEventsHooks{Handlers: []datasource.OnPublishEventsFn(nil)}, - }).Return(nil) + }) mockBuilder.On("TypeID").Return("nats") mockBuilder.On("BuildProvider", natsEventSources[0], mock.Anything).Return(mockPubSubProvider, nil) @@ -244,7 +244,7 @@ func TestBuild_ShouldNotInitializeProviderIfNotUsed(t *testing.T) { mockPubSubUsedProvider.On("SetHooks", datasource.Hooks{ OnReceiveEvents: 
datasource.OnReceiveEventsHooks{Handlers: []datasource.OnReceiveEventsFn(nil)}, OnPublishEvents: datasource.OnPublishEventsHooks{Handlers: []datasource.OnPublishEventsFn(nil)}, - }).Return(nil) + }) mockBuilder.On("TypeID").Return("nats") mockBuilder.On("BuildProvider", natsEventSources[1], mock.Anything). From 736ef093475cb1baf457f09ffdeff996c8d712db Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Tue, 18 Nov 2025 20:37:44 +0100 Subject: [PATCH 20/44] chore: add description to router-tests + delete obsolete test --- .../modules/start_subscription_test.go | 37 ++++- router-tests/modules/stream_publish_test.go | 71 +++++---- router-tests/modules/stream_receive_test.go | 148 +++++++++++++++++- 3 files changed, 214 insertions(+), 42 deletions(-) diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index 0d9ee9c054..4f283ce070 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -24,6 +24,10 @@ func TestStartSubscriptionHook(t *testing.T) { t.Run("Test StartSubscription hook is called", func(t *testing.T) { t.Parallel() + // This test verifies that the OnStartSubscription hook is invoked when a client initiates a subscription. + // It confirms the basic integration of the start subscription module by checking for the expected log message, + // ensuring the hook is called at the right moment in the subscription lifecycle. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -86,6 +90,10 @@ func TestStartSubscriptionHook(t *testing.T) { t.Run("Test StartSubscription write event works", func(t *testing.T) { t.Parallel() + // This test verifies that the OnStartSubscription hook can emit a custom event to the subscription + // using WriteEvent(). 
It tests that a synthetic event injected by the hook is properly delivered + // to the client when the subscription starts, allowing for initialization data or welcome messages. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -170,9 +178,13 @@ func TestStartSubscriptionHook(t *testing.T) { }) }) - t.Run("Test StartSubscription with close to true", func(t *testing.T) { + t.Run("Test StartSubscription closes client connection when hook returns an error", func(t *testing.T) { t.Parallel() + // This test verifies that when the OnStartSubscription hook returns an error, the subscription + // is closed and the error is propagated to the client. It ensures that hooks can prevent + // subscriptions from starting by returning an error, which triggers proper cleanup. + callbackCalled := make(chan bool) cfg := config.Config{ @@ -254,9 +266,13 @@ func TestStartSubscriptionHook(t *testing.T) { }) }) - t.Run("Test StartSubscription write event sends event only to the subscription", func(t *testing.T) { + t.Run("Test event emitted byStartSubscription sends event only to the client that triggered the hook", func(t *testing.T) { t.Parallel() + // This test verifies that WriteEvent() in the OnStartSubscription hook sends events only to the specific + // subscription that triggered the hook, not to other subscriptions. It tests with multiple subscriptions + // to ensure event isolation and that hooks can target individual clients based on their context. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -360,6 +376,10 @@ func TestStartSubscriptionHook(t *testing.T) { t.Run("Test StartSubscription error is propagated to the client", func(t *testing.T) { t.Parallel() + // This test verifies that errors returned by the OnStartSubscription hook are properly propagated to the client + // with correct HTTP status codes and error messages. 
It ensures clients receive detailed error information + // including custom status codes when a subscription is rejected by the hook. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -448,6 +468,10 @@ func TestStartSubscriptionHook(t *testing.T) { t.Run("Test StartSubscription hook is called for engine subscription", func(t *testing.T) { t.Parallel() + // This test verifies that the OnStartSubscription hook is called for engine-based subscriptions + // (subscriptions resolved by the router's execution engine, not event-driven sources like Kafka). + // It ensures the hook works uniformly across different subscription types. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -504,6 +528,10 @@ func TestStartSubscriptionHook(t *testing.T) { t.Run("Test StartSubscription hook is called for engine subscription and write event works", func(t *testing.T) { t.Parallel() + // This test verifies that WriteEvent() works for engine-based subscriptions, allowing hooks to inject + // custom events even for subscriptions that don't use event-driven sources. It tests that the synthetic + // event is delivered first, followed by the normal engine-generated subscription data. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -585,6 +613,11 @@ func TestStartSubscriptionHook(t *testing.T) { t.Run("Test when StartSubscription hook returns an error, the OnOriginResponse hook is not called", func(t *testing.T) { t.Parallel() + + // This test verifies that when the OnStartSubscription hook returns an error, subsequent hooks like + // OnOriginResponse are not executed. It ensures proper hook chain short-circuiting when errors occur, + // preventing unnecessary processing after a subscription has been rejected. 
+ originResponseCalled := make(chan *http.Response, 1) cfg := config.Config{ diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index ddaf982029..09757eae0b 100644 --- a/router-tests/modules/stream_publish_test.go +++ b/router-tests/modules/stream_publish_test.go @@ -27,6 +27,16 @@ func TestPublishHook(t *testing.T) { t.Run("Test Publish hook can't assert to mutable types", func(t *testing.T) { t.Parallel() + // This test verifies that regular StreamEvents cannot be type-asserted to MutableStreamEvent. + // By default events are immutable in Cosmo Streams hooks, because it is not guaranteed they aren't + // shared with other goroutines. + // The only acceptable way to get mutable events is to do a deep copy inside the hook by invoking + // event.Clone(), which returns a mutable copy of the event. If a type assertion would be successful + // it means the hook developer would have an event of type MutableEvent, but the deep copy never happened. + // Note: It's not as important in the OnPublishEvent hook, because events are isolated between hook calls. + // It's rather important in the OnReceiveEvent hook but both hooks share the same behaviour for consistency reasons + // and that's why we test it here as well. 
+ var taPossible atomic.Bool taPossible.Store(true) @@ -74,42 +84,18 @@ t.Run("Test Publish hook is called", func(t *testing.T) { t.Parallel() - cfg := config.Config{ - Graph: config.Graph{}, - Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{}, - }, - } - - testenv.Run(t, &testenv.Config{ - RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, - EnableKafka: true, - RouterOptions: []core.Option{ - core.WithModulesConfig(cfg.Modules), - core.WithCustomModules(&stream_publish.PublishModule{}), - }, - LogObservation: testenv.LogObservationConfig{ - Enabled: true, - LogLevel: zapcore.InfoLevel, - }, - }, func(t *testing.T, xEnv *testenv.Environment) { - resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ - Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, - }) - require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) - - requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") - assert.Len(t, requestLog.All(), 1) - }) - }) - - t.Run("Test Publish hook is called with mutable event", func(t *testing.T) { - t.Parallel() + // This test verifies that the publish hook is invoked when a mutation with a Kafka publish is executed. + // It confirms the hook has been called by checking a log message, which is written by the custom module + // used in these tests right before the actual hook is being called. 
cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{}, + "publishModule": stream_publish.PublishModule{ + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return events, nil + }, + }, }, } @@ -138,6 +124,11 @@ t.Run("Test kafka publish error is returned and messages sent", func(t *testing.T) { t.Parallel() + // This test verifies that when the publish hook returns events and an error, + // the error is properly logged but the messages are still sent to Kafka. + // It ensures that hook errors don't prevent message delivery if the hook developer + // wants to do so. If he does not want this he must not return events. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -183,6 +174,11 @@ t.Run("Test nats publish error is returned and messages sent", func(t *testing.T) { t.Parallel() + // This test verifies that when the publish hook returns an error for NATS events, + // the error is properly logged but the messages are still sent to NATS. + // It ensures that hook errors don't prevent message delivery for NATS if the hook developer wants to do so. + // If he does not want this he must not return events. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -237,6 +233,11 @@ t.Run("Test redis publish error is returned and messages sent", func(t *testing.T) { t.Parallel() + // This test verifies that when the publish hook returns an error for Redis events, + // the error is properly logged but the messages are still sent to Redis (non-blocking behavior). + // It ensures that hook errors don't prevent message delivery for Redis if the hook developer wants to do so. + // If he does not want this he must not return events. 
+ cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -281,6 +282,12 @@ func TestPublishHook(t *testing.T) { t.Run("Test kafka module publish with argument in header", func(t *testing.T) { t.Parallel() + // This test verifies that the publish hook can modify Kafka events by cloning them, + // changing the event data, and adding custom headers. It tests the ability to access + // operation variables and inject them as headers into Kafka messages. + // The test ensures that concrete event types can be used and their + // distinct broker features (like headers for Kafka) are accessible for hook developers. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index 325600915d..6efa850fe0 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -37,6 +37,10 @@ func TestReceiveHook(t *testing.T) { t.Run("Test Receive hook is called", func(t *testing.T) { t.Parallel() + // This test verifies that the receive hook is invoked when events are received from Kafka. + // It confirms the hook is called by checking for the expected log message + // and that subscription events are properly delivered to the client. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -110,6 +114,11 @@ func TestReceiveHook(t *testing.T) { t.Run("Test Receive hook could change events", func(t *testing.T) { t.Parallel() + // This test verifies that the receive hook can modify events by cloning them first, so they become mutable, + // and then changing their data. This is the only way to get mutable events, because by default events are immutable. + // It tests that the modified events are properly delivered to subscribers with the updated data, + // demonstrating that hooks can transform stream events before they reach clients. 
+ cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -191,9 +200,109 @@ func TestReceiveHook(t *testing.T) { }) }) + t.Run("Test hook can't assert to mutable types", func(t *testing.T) { + t.Parallel() + + // This test verifies that regular StreamEvents cannot be type-asserted to MutableStreamEvent. + // By default events are immutable in Cosmo Streams hooks, because it is not garantueed they aren't + // shared with other goroutines. + // The only acceptable way to get mutable events is to do a deep copy inside the hook by invoking + // event.Clone(), which returns a mutable copy of the event. If a type assertion would be successful + // it means the hook developer would have an event of type MutableEvent, but the deep copy never happened. + + var taPossible atomic.Bool + taPossible.Store(true) + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": stream_receive.StreamReceiveModule{ + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + for _, evt := range events.All() { + _, ok := evt.(datasource.MutableStreamEvent) + if !ok { + taPossible.Store(false) + } + } + return events, nil + }, + }, + }, + } + + testenv.Run(t, &testenv.Config{ + RouterConfigJSONTemplate: testenv.ConfigWithEdfsKafkaJSONTemplate, + EnableKafka: true, + RouterOptions: []core.Option{ + core.WithModulesConfig(cfg.Modules), + core.WithCustomModules(&stream_receive.StreamReceiveModule{}), + }, + LogObservation: testenv.LogObservationConfig{ + Enabled: true, + LogLevel: zapcore.InfoLevel, + }, + }, func(t *testing.T, xEnv *testenv.Environment) { + topics := []string{"employeeUpdated"} + events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) 
+ + var subscriptionOne struct { + employeeUpdatedMyKafka struct { + ID float64 `graphql:"id"` + Details struct { + Forename string `graphql:"forename"` + Surname string `graphql:"surname"` + } `graphql:"details"` + } `graphql:"employeeUpdatedMyKafka(employeeID: 3)"` + } + + surl := xEnv.GraphQLWebSocketSubscriptionURL() + client := graphql.NewSubscriptionClient(surl) + + subscriptionArgsCh := make(chan kafkaSubscriptionArgs) + subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionArgsCh <- kafkaSubscriptionArgs{ + dataValue: dataValue, + errValue: errValue, + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, subscriptionOneID) + + clientRunCh := make(chan error) + go func() { + clientRunCh <- client.Run() + }() + + xEnv.WaitForSubscriptionCount(1, Timeout) + + events.ProduceKafkaMessage(t, xEnv, Timeout, topics[0], `{"__typename":"Employee","id": 1,"update":{"name":"foo"}}`) + + testenv.AwaitChannelWithT(t, Timeout, subscriptionArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { + require.NoError(t, args.errValue) + require.JSONEq(t, `{"employeeUpdatedMyKafka":{"id":1,"details":{"forename":"Jens","surname":"Neuse"}}}`, string(args.dataValue)) + }) + + require.NoError(t, client.Close()) + testenv.AwaitChannelWithT(t, Timeout, clientRunCh, func(t *testing.T, err error) { + require.NoError(t, err) + }, "unable to close client before timeout") + + requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") + assert.Len(t, requestLog.All(), 1) + + assert.False(t, taPossible.Load(), "invalid type assertion was possible") + }) + }) + t.Run("Test Receive hook change events of one of multiple subscriptions", func(t *testing.T) { t.Parallel() + // This test verifies that the receive hook can selectively modify events for specific subscriptions + // based on the clients authentication context. 
It tests that when multiple clients are subscribed, the hook can + // access JWT claims of individual clients and modify events only for authenticated users with specific claims, + // while leaving events for other clients unchanged. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -353,6 +462,10 @@ func TestReceiveHook(t *testing.T) { t.Run("Test Receive hook can access custom header", func(t *testing.T) { t.Parallel() + // This test verifies that the receive hook can access custom HTTP headers from the WebSocket connection. + // It tests that hooks can read headers sent during subscription initialization and use them to + // conditionally modify events, enabling header-based event transformation logic. + customHeader := http.CanonicalHeaderKey("X-Custom-Header") cfg := config.Config{ @@ -451,6 +564,10 @@ func TestReceiveHook(t *testing.T) { t.Run("Test Batch hook error should close Kafka clients and subscriptions", func(t *testing.T) { t.Parallel() + // This test verifies that when the receive hook returns an error, the router properly closes + // the subscription connection and cleans up Kafka clients. It ensures that hook errors trigger + // graceful shutdown of the subscription to prevent resource leaks or stuck connections. + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ @@ -525,6 +642,11 @@ func TestReceiveHook(t *testing.T) { t.Run("Test concurrent handler execution works", func(t *testing.T) { t.Parallel() + // This test verifies that the MaxConcurrentHandlers configuration properly limits the number of + // receive hooks executing simultaneously. It tests various concurrency levels (1, 2, 10, 20 handlers) + // with multiple clients to ensure the router respects the concurrency limit and never exceeds it, + // even under load with many active clients. 
+ testCases := []struct { name string maxConcurrent int @@ -569,7 +691,7 @@ func TestReceiveHook(t *testing.T) { Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { currentHandlers.Add(1) - // wait for other handlers in the batch + // Wait for other hooks in the same client update batch to start. for { current := currentHandlers.Load() max := maxCurrentHandlers.Load() @@ -579,7 +701,8 @@ func TestReceiveHook(t *testing.T) { } if current >= int32(tc.maxConcurrent) { - // wait to see if the updater spawns too many concurrent handlers + // wait to see if the subscription-updater spawns too many concurrent hooks, + // i.e. exceeding the number of configured max concurrent hooks. deadline := time.Now().Add(300 * time.Millisecond) for time.Now().Before(deadline) { if currentHandlers.Load() > int32(tc.maxConcurrent) { @@ -589,8 +712,14 @@ func TestReceiveHook(t *testing.T) { break } - // Let handlers continue if we never reach a batch size = tc.maxConcurrent - // because there are not enough remaining subscribers to be updated. + // Let hooks continue if we never reach an updater batch size = tc.maxConcurrent + // because there are not enough remaining clients to be updated. + // i.e. it could be the last round of updates: + // 100 clients, now in comes a new event from broker, max concurrent hooks = 30. + // First round: 30 hooks run, 70 remaining. + // Second round: 30 hooks run, 40 remaining. + // Third round: 30 hooks run, 10 remaining. + // Fourth round: 10 hooks run, then we end up here because remainingSubs < tc.maxConcurrent. remainingSubs := tc.numSubscribers - int(finishedHandlers.Load()) if remainingSubs < tc.maxConcurrent { break @@ -699,12 +828,15 @@ func TestReceiveHook(t *testing.T) { t.Parallel() // One subscriber receives three consecutive events. - // The first event's hook is delayed, exceeding the timeout. 
+ // The first event's hook is delayed, exceeding the configurable hook timeout. // The second and third events' hooks process immediately without delay. - // Because the first hook exceeds the timeout, the system abandons waiting for it - // and processes the second and third events. + // Because the first hook exceeds the timeout, the subscription-updater gives up waiting for it + // and proceeds to process the second and third events immediately. // The first event will be delivered later when its hook finally completes. - // This should result in event order [2, 3, 1] at the client. + // This should result in the first event being delivered last. + // + // Delivering events out of order is a tradeoff to ensure that hooks do not block the subscription-updater for too long. + // We try to keep the order but once the timeout is exceeded we need to move on and it's no longer guaranteed. hookDelay := 500 * time.Millisecond hookTimeout := 100 * time.Millisecond From ef0b63834601c77c17c8c10c878a46b47dd978da Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Tue, 18 Nov 2025 20:43:30 +0100 Subject: [PATCH 21/44] chore: Use builtin error instead of inappropriate graphql-error The core.NewHttpGraphqlError() method returns an error, which does not fit the OnPublishEvent hook very well. Instead we return a generic error, which results in the same thing. The error handling does not behave differently because of it. 
--- router-tests/modules/stream_publish_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index 09757eae0b..52299031bf 100644 --- a/router-tests/modules/stream_publish_test.go +++ b/router-tests/modules/stream_publish_test.go @@ -2,7 +2,7 @@ package module_test import ( "encoding/json" - "net/http" + "errors" "strconv" "sync/atomic" "testing" @@ -134,7 +134,7 @@ func TestPublishHook(t *testing.T) { Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return events, errors.New("test") }, }, }, @@ -184,7 +184,7 @@ func TestPublishHook(t *testing.T) { Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return events, errors.New("test") }, }, }, @@ -243,7 +243,7 @@ func TestPublishHook(t *testing.T) { Modules: map[string]interface{}{ "publishModule": stream_publish.PublishModule{ Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, core.NewHttpGraphqlError("test", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return events, errors.New("test") }, }, }, From 8dcd1f4f45dc94c81aadf6874d7245b6feafc097 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Tue, 18 Nov 2025 21:09:06 +0100 Subject: [PATCH 22/44] fix: remove 
unwanted logging of errors from OnPublishEvent hook If this hook returns an error it is not wanted that the error is logged. Other Cosmo Streams hooks don't do it. The reason is we want the hook developer to decide what log level to use for errors, or if he doesn't want to log an error at all. This error log handling is already in place for the other two hooks but somehow was forgotten on this one. --- router-tests/modules/stream_publish_test.go | 25 +++++++------------ .../pkg/pubsub/datasource/pubsubprovider.go | 8 ------ 2 files changed, 9 insertions(+), 24 deletions(-) diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index 52299031bf..39e26f124d 100644 --- a/router-tests/modules/stream_publish_test.go +++ b/router-tests/modules/stream_publish_test.go @@ -69,10 +69,11 @@ func TestPublishHook(t *testing.T) { LogLevel: zapcore.InfoLevel, }, }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, }) - require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) + require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") assert.Len(t, requestLog.All(), 1) @@ -111,10 +112,11 @@ func TestPublishHook(t *testing.T) { LogLevel: zapcore.InfoLevel, }, }, func(t *testing.T, xEnv *testenv.Environment) { + events.KafkaEnsureTopicExists(t, xEnv, time.Second, "employeeUpdated") resOne := xEnv.MakeGraphQLRequestOK(testenv.GraphQLRequest{ Query: `mutation { updateEmployeeMyKafka(employeeID: 3, update: {name: "name test"}) { success } }`, }) - require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":false}}}`, resOne.Body) + require.JSONEq(t, 
`{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") assert.Len(t, requestLog.All(), 1) @@ -124,8 +126,8 @@ func TestPublishHook(t *testing.T) { t.Run("Test kafka publish error is returned and messages sent", func(t *testing.T) { t.Parallel() - // This test verifies that when the publish hook returns events and an error, - // the error is properly logged but the messages are still sent to Kafka. + // This test verifies that when the publish hook returns events and an error + // but the messages are still sent to Kafka. // It ensures that hook errors don't prevent message delivery if the hook developer // wants to do so. If he does not want this he must no return events. @@ -162,9 +164,6 @@ func TestPublishHook(t *testing.T) { requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") assert.Len(t, requestLog.All(), 1) - requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") - assert.Len(t, requestLog2.All(), 1) - records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) require.NoError(t, err) require.Len(t, records, 1) @@ -175,7 +174,7 @@ func TestPublishHook(t *testing.T) { t.Parallel() // This test verifies that when the publish hook returns an error for NATS events, - // the error is properly logged but the messages are still sent to NATS. + // but the messages are still sent to NATS. // It ensures that hook errors don't prevent message delivery for NATS if the hook developer wants to do so. // If he does not want this he must no return events. 
@@ -219,9 +218,6 @@ func TestPublishHook(t *testing.T) { requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") assert.Len(t, requestLog.All(), 1) - requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") - assert.Len(t, requestLog2.All(), 1) - msgOne, err := firstSub.NextMsg(5 * time.Second) require.NoError(t, err) require.Equal(t, xEnv.GetPubSubName("employeeUpdatedMyNats.3"), msgOne.Subject) @@ -233,8 +229,8 @@ func TestPublishHook(t *testing.T) { t.Run("Test redis publish error is returned and messages sent", func(t *testing.T) { t.Parallel() - // This test verifies that when the publish hook returns an error for Redis events, - // the error is properly logged but the messages are still sent to Redis (non-blocking behavior). + // This test verifies that when the publish hook returns an error for Redis events + // but the messages are still sent to Redis (non-blocking behavior). // It ensures that hook errors don't prevent message delivery for Redis if the hook developer wants to do so. // If he does not want this he must no return events. 
@@ -272,9 +268,6 @@ func TestPublishHook(t *testing.T) { requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") assert.Len(t, requestLog.All(), 1) - requestLog2 := xEnv.Observer().FilterMessage("error applying publish event hooks") - assert.Len(t, requestLog2.All(), 1) - require.Len(t, records, 1) }) }) diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index aa21c02fc8..ca6eda1d3b 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -45,14 +45,6 @@ func (p *PubSubProvider) applyPublishEventHooks(ctx context.Context, cfg Publish }) if err != nil { - p.Logger.Error( - "error applying publish event hooks", - zap.Error(err), - zap.String("provider_id", cfg.ProviderID()), - zap.String("provider_type_id", string(cfg.ProviderType())), - zap.String("field_name", cfg.RootFieldName()), - ) - return currentEvents, err } } From f07f0776ce82b05d1b10b78fc67d64b0af45aacf Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Tue, 18 Nov 2025 21:26:26 +0100 Subject: [PATCH 23/44] fix: filter nil events at end of hook processing This makes the OnPublishEvent hook be more consistent in nil event filtering with the OnReceiveEvent hook. 
--- router/pkg/pubsub/datasource/pubsubprovider.go | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/router/pkg/pubsub/datasource/pubsubprovider.go b/router/pkg/pubsub/datasource/pubsubprovider.go index ca6eda1d3b..74ff9ff9bb 100644 --- a/router/pkg/pubsub/datasource/pubsubprovider.go +++ b/router/pkg/pubsub/datasource/pubsubprovider.go @@ -37,18 +37,19 @@ func (p *PubSubProvider) applyPublishEventHooks(ctx context.Context, cfg Publish }() currentEvents = events + for _, hook := range p.hooks.OnPublishEvents.Handlers { - var err error currentEvents, err = hook(ctx, cfg, currentEvents, p.eventBuilder) - currentEvents = slices.DeleteFunc(currentEvents, func(event StreamEvent) bool { - return event == nil - }) - if err != nil { - return currentEvents, err + break } } - return currentEvents, nil + + currentEvents = slices.DeleteFunc(currentEvents, func(event StreamEvent) bool { + return event == nil + }) + + return currentEvents, err } func (p *PubSubProvider) ID() string { From f93145f34cbfa3ac59a49ab921cf26b5a3396a33 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Tue, 18 Nov 2025 18:01:43 +0100 Subject: [PATCH 24/44] chore: reintroduce StreamHandlerError core.NewHttpGraphqlError() is not really the best way to describe an error in hooks processing. We have to set http specific codes and status texts, which does not really fit well with subscriptions. I created a new error for this called StreamHookError, which lets you pass a message. It can be returned from a hook: return core.StreamHandlerError{Message: "my hook error"} and it gets sent to the subscription client {"id":"1","type":"error","payload":[{"message":"my hook error"}]} Afterwards we close the connection. This behaviour remains unchanged. 
--- adr/cosmo-streams-v1.md | 16 ++++++---------- router-tests/modules/start_subscription_test.go | 16 +++++++--------- router/core/errors.go | 5 +++++ router/core/graphql_handler.go | 14 ++++++++++++++ router/core/subscriptions_modules.go | 13 +++++++++++++ 5 files changed, 45 insertions(+), 19 deletions(-) diff --git a/adr/cosmo-streams-v1.md b/adr/cosmo-streams-v1.md index dfacf34c40..f764639a50 100644 --- a/adr/cosmo-streams-v1.md +++ b/adr/cosmo-streams-v1.md @@ -393,21 +393,17 @@ func (m *MyModule) SubscriptionOnStart(ctx core.SubscriptionOnStartHandlerContex // check if the client is authenticated if ctx.Authentication() == nil { // if the client is not authenticated, return an error - return core.NewHttpGraphqlError( - "client is not authenticated", - http.StatusText(http.StatusUnauthorized), - http.StatusUnauthorized, - ) + return &core.StreamHandlerError{ + Message: "client is not authenticated", + } } // check if the client is allowed to subscribe to the stream _, found := ctx.Authentication().Claims()["readEmployee"] if !found { - return core.NewHttpGraphqlError( - "client is not allowed to read employees", - http.StatusText(http.StatusForbidden), - http.StatusForbidden, - ) + return &core.StreamHandlerError{ + Message: "client is not allowed to read employees", + } } return nil diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index 4f283ce070..5b21a43317 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -193,7 +193,7 @@ func TestStartSubscriptionHook(t *testing.T) { "startSubscriptionModule": start_subscription.StartSubscriptionModule{ Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { callbackCalled <- true - return core.NewHttpGraphqlError("subscription closed", http.StatusText(http.StatusOK), http.StatusOK) + return &core.StreamHandlerError{Message: "my custom error"} }, }, }, @@ -385,7 +385,7 @@ func 
TestStartSubscriptionHook(t *testing.T) { Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - return core.NewHttpGraphqlError("test error", http.StatusText(http.StatusLoopDetected), http.StatusLoopDetected) + return &core.StreamHandlerError{Message: "test error"} }, }, }, @@ -441,14 +441,12 @@ func TestStartSubscriptionHook(t *testing.T) { // Wait for the subscription to be closed xEnv.WaitForSubscriptionCount(0, time.Second*10) + expectedError := graphql.Errors{graphql.Error{Message: "test error"}} testenv.AwaitChannelWithT(t, time.Second*10, subscriptionOneArgsCh, func(t *testing.T, args kafkaSubscriptionArgs) { - var graphqlErrs graphql.Errors - require.ErrorAs(t, args.errValue, &graphqlErrs) - statusCode, ok := graphqlErrs[0].Extensions["statusCode"].(float64) - require.True(t, ok, "statusCode is not a float64") - require.Equal(t, http.StatusLoopDetected, int(statusCode)) - require.Equal(t, http.StatusText(http.StatusLoopDetected), graphqlErrs[0].Extensions["code"]) + var actualError graphql.Errors + require.ErrorAs(t, args.errValue, &actualError) + assert.Equal(t, expectedError, actualError) }) require.NoError(t, client.Close()) @@ -625,7 +623,7 @@ func TestStartSubscriptionHook(t *testing.T) { Modules: map[string]interface{}{ "startSubscriptionModule": start_subscription.StartSubscriptionModule{ Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - return core.NewHttpGraphqlError("subscription closed", http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return &core.StreamHandlerError{Message: "hook error"} }, CallbackOnOriginResponse: func(response *http.Response, ctx core.RequestContext) *http.Response { originResponseCalled <- response diff --git a/router/core/errors.go b/router/core/errors.go index 2ce688bbef..203353516d 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -36,6 
+36,7 @@ const ( errorTypeEDFSInvalidMessage errorTypeMergeResult errorTypeHttpError + errorTypeEDFSHookError ) type ( @@ -78,6 +79,10 @@ func getErrorType(err error) errorType { if errors.As(err, &edfsErr) { return errorTypeEDFS } + var edfsHookErr *StreamHandlerError + if errors.As(err, &edfsHookErr) { + return errorTypeEDFSHookError + } var invalidWsSubprotocolErr graphql_datasource.InvalidWsSubprotocolError if errors.As(err, &invalidWsSubprotocolErr) { return errorTypeInvalidWsSubprotocol diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index 845b8bdac0..d078e84065 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -390,6 +390,20 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } + case errorTypeEDFSHookError: + var errStreamHandlerError *StreamHandlerError + if !errors.As(err, &errStreamHandlerError) { + response.Errors[0].Message = "Internal server error" + // We could set response.Errors[0].Extensions, too + if isHttpResponseWriter { + httpWriter.WriteHeader(http.StatusInternalServerError) + } + return + } + response.Errors[0].Message = errStreamHandlerError.Message + if isHttpResponseWriter { + httpWriter.WriteHeader(http.StatusOK) + } case errorTypeInvalidWsSubprotocol: response.Errors[0].Message = fmt.Sprintf("Invalid Subprotocol error: %s or configure the subprotocol to be used using `wgc subgraph update` command.", err.Error()) if isHttpResponseWriter { diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index 90ca6cf636..dcc5abd4dc 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -404,3 +404,16 @@ func NewPubSubOnReceiveEventsHook(fn func(ctx StreamReceiveEventHandlerContext, return newEvts.Unsafe(), err } } + +// StreamHandlerError writes an error event with Reason to a subscription client and 
closes the +// websocket connection with code 1000 (Normal closure). +// It can returned from methods of the core.SubscriptionOnStartHandler interface. +type StreamHandlerError struct { + // The message for this error. + Message string +} + +// Error returns the reason of this error. +func (e *StreamHandlerError) Error() string { + return e.Message +} From 65d5d5561fc645aa07cb0aea8836c7d5d8c03c94 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Wed, 19 Nov 2025 12:45:24 +0100 Subject: [PATCH 25/44] chore: rename WriteEvent to EmitLocalEvent --- .../modules/start_subscription_test.go | 6 +- router/core/subscriptions_modules.go | 75 ++++++++++++++----- 2 files changed, 59 insertions(+), 22 deletions(-) diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index 5b21a43317..78fdea24ba 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -102,7 +102,7 @@ func TestStartSubscriptionHook(t *testing.T) { if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { return nil } - ctx.WriteEvent((&kafka.MutableEvent{ + ctx.EmitLocalEvent((&kafka.MutableEvent{ Key: []byte("1"), Data: []byte(`{"id": 1, "__typename": "Employee"}`), })) @@ -283,7 +283,7 @@ func TestStartSubscriptionHook(t *testing.T) { return nil } evt := ctx.NewEvent([]byte(`{"id": 1, "__typename": "Employee"}`)) - ctx.WriteEvent(evt) + ctx.EmitLocalEvent(evt) return nil }, }, @@ -536,7 +536,7 @@ func TestStartSubscriptionHook(t *testing.T) { "startSubscriptionModule": start_subscription.StartSubscriptionModule{ Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { evt := ctx.NewEvent([]byte(`{"data":{"countEmp":1000}}`)) - ctx.WriteEvent(evt) + ctx.EmitLocalEvent(evt) return nil }, }, diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index dcc5abd4dc..da767da118 
100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -23,10 +23,32 @@ type SubscriptionOnStartHandlerContext interface { Authentication() authentication.Authentication // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration - // WriteEvent writes an event to the stream of the current subscription - // It returns true if the event was written to the stream, false if the event was dropped - WriteEvent(event datasource.StreamEvent) bool + // EmitLocalEvent sends an event directly to the subscription stream of the + // currently connected client. + // + // This method triggers the router to resolve the client's operation and emit + // the resulting data as a stream event. The event exists only within the + // router; it is not forwarded to any message broker. + // + // The event is delivered exclusively to the client associated with the current + // handler execution. No other subscriptions are affected. + // + // The method returns true if the event was successfully emitted, or false if + // it was dropped. + EmitLocalEvent(event datasource.StreamEvent) bool // NewEvent creates a new event that can be used in the subscription. + // + // The data parameter must contain valid JSON bytes. The format depends on the subscription type. + // + // For event-driven subscriptions (Cosmo Streams / EDFS), the data should contain: + // __typename : The name of the schema entity, which is expected to be returned to the client. + // {keyName} : The key of the entity as configured on the schema via @key directive. + // Example usage: ctx.NewEvent([]byte(`{"__typename": "Employee", "id": 1}`)) + // + // For normal subscriptions, you need to provide the complete GraphQL response structure. 
+ // Example usage: ctx.NewEvent([]byte(`{"data": {"fieldName": value}}`)) + // + // You can use EmitLocalEvent to emit this event to subscriptions. NewEvent(data []byte) datasource.MutableStreamEvent } @@ -69,7 +91,7 @@ type pubSubSubscriptionOnStartHookContext struct { operation OperationContext authentication authentication.Authentication subscriptionEventConfiguration datasource.SubscriptionEventConfiguration - writeEventHook func(data []byte) + emitLocalEventFn func(data []byte) eventBuilder datasource.EventBuilderFn } @@ -93,8 +115,8 @@ func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() return c.subscriptionEventConfiguration } -func (c *pubSubSubscriptionOnStartHookContext) WriteEvent(event datasource.StreamEvent) bool { - c.writeEventHook(event.GetData()) +func (c *pubSubSubscriptionOnStartHookContext) EmitLocalEvent(event datasource.StreamEvent) bool { + c.emitLocalEventFn(event.GetData()) return true } @@ -140,11 +162,11 @@ func (e *EngineEvent) Clone() datasource.MutableStreamEvent { } type engineSubscriptionOnStartHookContext struct { - request *http.Request - logger *zap.Logger - operation OperationContext - authentication authentication.Authentication - writeEventHook func(data []byte) + request *http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication + emitLocalEventFn func(data []byte) } func (c *engineSubscriptionOnStartHookContext) Request() *http.Request { @@ -163,8 +185,8 @@ func (c *engineSubscriptionOnStartHookContext) Authentication() authentication.A return c.authentication } -func (c *engineSubscriptionOnStartHookContext) WriteEvent(event datasource.StreamEvent) bool { - c.writeEventHook(event.GetData()) +func (c *engineSubscriptionOnStartHookContext) EmitLocalEvent(event datasource.StreamEvent) bool { + c.emitLocalEventFn(event.GetData()) return true } @@ -210,7 +232,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont 
operation: requestContext.Operation(), authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, - writeEventHook: resolveCtx.Updater, + emitLocalEventFn: resolveCtx.Updater, eventBuilder: eventBuilder, } @@ -233,11 +255,11 @@ func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont } hookCtx := &engineSubscriptionOnStartHookContext{ - request: requestContext.Request(), - logger: logger, - operation: requestContext.Operation(), - authentication: requestContext.Authentication(), - writeEventHook: resolveCtx.Updater, + request: requestContext.Request(), + logger: logger, + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + emitLocalEventFn: resolveCtx.Updater, } return fn(hookCtx) @@ -259,6 +281,13 @@ type StreamReceiveEventHandlerContext interface { // SubscriptionEventConfiguration the subscription event configuration SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration // NewEvent creates a new event that can be used in the subscription. + // + // The data parameter must contain valid JSON bytes representing the raw event payload + // from your message broker (Kafka, NATS, etc.). The JSON must have properly quoted + // property names and must include the __typename field required by GraphQL. + // For example: []byte(`{"__typename": "Employee", "id": 1, "update": {"name": "John"}}`). + // + // This method is typically used in OnReceiveEvents hooks to create new or modified events. NewEvent(data []byte) datasource.MutableStreamEvent } @@ -286,6 +315,14 @@ type StreamPublishEventHandlerContext interface { // PublishEventConfiguration the publish event configuration PublishEventConfiguration() datasource.PublishEventConfiguration // NewEvent creates a new event that can be used in the subscription. 
+ // + // The data parameter must contain valid JSON bytes representing the event payload + // that will be sent to your message broker (Kafka, NATS, etc.). The JSON must have + // properly quoted property names and must include the __typename field required by GraphQL. + // For example: []byte(`{"__typename": "Employee", "id": 1, "update": {"name": "John"}}`). + // + // This method is typically used in OnPublishEvents hooks to create new or modified events + // before they are sent to the message broker. NewEvent(data []byte) datasource.MutableStreamEvent } From f1442f13deae3bee95b6fca060539ebc3a51f73e Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Wed, 19 Nov 2025 13:07:24 +0100 Subject: [PATCH 26/44] chore: rename config events.subscription_hooks to events.handlers --- router-tests/modules/stream_receive_test.go | 4 ++-- router/core/router.go | 2 +- router/core/supervisor_instance.go | 2 +- router/pkg/config/config.go | 6 +++--- router/pkg/config/config.schema.json | 4 ++-- router/pkg/config/fixtures/full.yaml | 2 +- router/pkg/config/testdata/config_defaults.json | 2 +- router/pkg/config/testdata/config_full.json | 2 +- 8 files changed, 12 insertions(+), 12 deletions(-) diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index 6efa850fe0..264da9a329 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -740,7 +740,7 @@ func TestReceiveHook(t *testing.T) { RouterOptions: []core.Option{ core.WithModulesConfig(cfg.Modules), core.WithCustomModules(&stream_receive.StreamReceiveModule{}), - core.WithSubscriptionHooks(config.SubscriptionHooksConfiguration{ + core.WithStreamsHandlerConfiguration(config.StreamsHandlerConfiguration{ OnReceiveEvents: config.OnReceiveEventsConfiguration{ MaxConcurrentHandlers: tc.maxConcurrent, }, @@ -864,7 +864,7 @@ func TestReceiveHook(t *testing.T) { RouterOptions: []core.Option{ 
core.WithModulesConfig(cfg.Modules), core.WithCustomModules(&stream_receive.StreamReceiveModule{}), - core.WithSubscriptionHooks(config.SubscriptionHooksConfiguration{ + core.WithStreamsHandlerConfiguration(config.StreamsHandlerConfiguration{ OnReceiveEvents: config.OnReceiveEventsConfiguration{ MaxConcurrentHandlers: 3, HandlerTimeout: hookTimeout, diff --git a/router/core/router.go b/router/core/router.go index 20b0b33ae6..68fd8b803a 100644 --- a/router/core/router.go +++ b/router/core/router.go @@ -2140,7 +2140,7 @@ func WithDemoMode(demoMode bool) Option { } } -func WithSubscriptionHooks(cfg config.SubscriptionHooksConfiguration) Option { +func WithStreamsHandlerConfiguration(cfg config.StreamsHandlerConfiguration) Option { return func(r *Router) { r.subscriptionHooks.onReceiveEvents.maxConcurrentHandlers = cfg.OnReceiveEvents.MaxConcurrentHandlers r.subscriptionHooks.onReceiveEvents.timeout = cfg.OnReceiveEvents.HandlerTimeout diff --git a/router/core/supervisor_instance.go b/router/core/supervisor_instance.go index 4ae022632a..5879b525ad 100644 --- a/router/core/supervisor_instance.go +++ b/router/core/supervisor_instance.go @@ -271,7 +271,7 @@ func optionsFromResources(logger *zap.Logger, config *config.Config) []Option { WithMCP(config.MCP), WithPlugins(config.Plugins), WithDemoMode(config.DemoMode), - WithSubscriptionHooks(config.Events.SubscriptionHooks), + WithStreamsHandlerConfiguration(config.Events.Handlers), } return options diff --git a/router/pkg/config/config.go b/router/pkg/config/config.go index bb8c910982..d036a96f67 100644 --- a/router/pkg/config/config.go +++ b/router/pkg/config/config.go @@ -641,11 +641,11 @@ type EventProviders struct { } type EventsConfiguration struct { - Providers EventProviders `yaml:"providers,omitempty"` - SubscriptionHooks SubscriptionHooksConfiguration `yaml:"subscription_hooks,omitempty"` + Providers EventProviders `yaml:"providers,omitempty"` + Handlers StreamsHandlerConfiguration `yaml:"handlers,omitempty"` } 
-type SubscriptionHooksConfiguration struct { +type StreamsHandlerConfiguration struct { OnReceiveEvents OnReceiveEventsConfiguration `yaml:"on_receive_events"` } diff --git a/router/pkg/config/config.schema.json b/router/pkg/config/config.schema.json index a2326c86ec..68b6f9c343 100644 --- a/router/pkg/config/config.schema.json +++ b/router/pkg/config/config.schema.json @@ -2305,9 +2305,9 @@ } } }, - "subscription_hooks": { + "handlers": { "type": "object", - "description": "Configuration for subscription custom modules that are executed when events are received from a broker.", + "description": "Configuration for Cosmo Streams / EDFS custom modules", "additionalProperties": false, "properties": { "on_receive_events": { diff --git a/router/pkg/config/fixtures/full.yaml b/router/pkg/config/fixtures/full.yaml index 3f79ed83df..1a09ef7972 100644 --- a/router/pkg/config/fixtures/full.yaml +++ b/router/pkg/config/fixtures/full.yaml @@ -330,7 +330,7 @@ events: urls: - 'redis://localhost:6379/11' cluster_enabled: true - subscription_hooks: + handlers: on_receive_events: max_concurrent_handlers: 100 handler_timeout: 5s diff --git a/router/pkg/config/testdata/config_defaults.json b/router/pkg/config/testdata/config_defaults.json index 7655cc7007..3b832ae149 100644 --- a/router/pkg/config/testdata/config_defaults.json +++ b/router/pkg/config/testdata/config_defaults.json @@ -296,7 +296,7 @@ "Kafka": null, "Redis": null }, - "SubscriptionHooks": { + "Handlers": { "OnReceiveEvents": { "MaxConcurrentHandlers": 100, "HandlerTimeout": 5000000000 diff --git a/router/pkg/config/testdata/config_full.json b/router/pkg/config/testdata/config_full.json index 2731dee760..46f3be6b1e 100644 --- a/router/pkg/config/testdata/config_full.json +++ b/router/pkg/config/testdata/config_full.json @@ -642,7 +642,7 @@ } ] }, - "SubscriptionHooks": { + "Handlers": { "OnReceiveEvents": { "MaxConcurrentHandlers": 100, "HandlerTimeout": 5000000000 From e1bf8cd4f66cb3d831da78dfd93271844841811e Mon Sep 
17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Wed, 19 Nov 2025 13:52:36 +0100 Subject: [PATCH 27/44] chore: use counter to count hook calls --- .../modules/start-subscription/module.go | 7 + .../modules/start_subscription_test.go | 155 +++++----- router-tests/modules/stream-publish/module.go | 11 +- router-tests/modules/stream-receive/module.go | 11 +- router-tests/modules/stream_publish_test.go | 148 +++++----- router-tests/modules/stream_receive_test.go | 271 ++++++++++-------- 6 files changed, 336 insertions(+), 267 deletions(-) diff --git a/router-tests/modules/start-subscription/module.go b/router-tests/modules/start-subscription/module.go index a7d7706a88..db55e1603d 100644 --- a/router-tests/modules/start-subscription/module.go +++ b/router-tests/modules/start-subscription/module.go @@ -2,6 +2,7 @@ package start_subscription import ( "net/http" + "sync/atomic" "go.uber.org/zap" @@ -14,6 +15,7 @@ type StartSubscriptionModule struct { Logger *zap.Logger Callback func(ctx core.SubscriptionOnStartHandlerContext) error CallbackOnOriginResponse func(response *http.Response, ctx core.RequestContext) *http.Response + HookCallCount *atomic.Int32 // Counter to track how many times the hook is called } func (m *StartSubscriptionModule) Provision(ctx *core.ModuleContext) error { @@ -28,6 +30,11 @@ func (m *StartSubscriptionModule) SubscriptionOnStart(ctx core.SubscriptionOnSta m.Logger.Info("SubscriptionOnStart Hook has been run") } + // Increment the hook call counter + if m.HookCallCount != nil { + m.HookCallCount.Add(1) + } + if m.Callback != nil { return m.Callback(ctx) } diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index 78fdea24ba..1621ac5e42 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -3,6 +3,7 @@ package module_test import ( "errors" "net/http" + "sync/atomic" "testing" "time" @@ 
-28,10 +29,14 @@ func TestStartSubscriptionHook(t *testing.T) { // It confirms the basic integration of the start subscription module by checking for the expected log message, // ensuring the hook is called at the right moment in the subscription lifecycle. + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "startSubscriptionModule": start_subscription.StartSubscriptionModule{}, + "startSubscriptionModule": customModule, }, } @@ -82,8 +87,7 @@ func TestStartSubscriptionHook(t *testing.T) { }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) @@ -94,21 +98,24 @@ func TestStartSubscriptionHook(t *testing.T) { // using WriteEvent(). It tests that a synthetic event injected by the hook is properly delivered // to the client when the subscription starts, allowing for initialization data or welcome messages. 
+ customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { + return nil + } + ctx.EmitLocalEvent((&kafka.MutableEvent{ + Key: []byte("1"), + Data: []byte(`{"id": 1, "__typename": "Employee"}`), + })) + return nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { - return nil - } - ctx.EmitLocalEvent((&kafka.MutableEvent{ - Key: []byte("1"), - Data: []byte(`{"id": 1, "__typename": "Employee"}`), - })) - return nil - }, - }, + "startSubscriptionModule": customModule, }, } @@ -173,8 +180,7 @@ func TestStartSubscriptionHook(t *testing.T) { }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) @@ -187,15 +193,18 @@ func TestStartSubscriptionHook(t *testing.T) { callbackCalled := make(chan bool) + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + callbackCalled <- true + return &core.StreamHandlerError{Message: "my custom error"} + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - callbackCalled <- true - return &core.StreamHandlerError{Message: "my custom error"} - }, - }, + "startSubscriptionModule": customModule, }, } @@ -256,8 +265,7 @@ func 
TestStartSubscriptionHook(t *testing.T) { }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) require.Len(t, subscriptionArgsCh, 1) subscriptionArgs := <-subscriptionArgsCh @@ -273,20 +281,23 @@ func TestStartSubscriptionHook(t *testing.T) { // subscription that triggered the hook, not to other subscriptions. It tests with multiple subscriptions // to ensure event isolation and that hooks can target individual clients based on their context. + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + employeeId := ctx.Operation().Variables().GetInt64("employeeID") + if employeeId != 1 { + return nil + } + evt := ctx.NewEvent([]byte(`{"id": 1, "__typename": "Employee"}`)) + ctx.EmitLocalEvent(evt) + return nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - employeeId := ctx.Operation().Variables().GetInt64("employeeID") - if employeeId != 1 { - return nil - } - evt := ctx.NewEvent([]byte(`{"id": 1, "__typename": "Employee"}`)) - ctx.EmitLocalEvent(evt) - return nil - }, - }, + "startSubscriptionModule": customModule, }, } @@ -365,8 +376,7 @@ func TestStartSubscriptionHook(t *testing.T) { }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") - assert.Len(t, requestLog.All(), 2) + assert.Equal(t, int32(2), customModule.HookCallCount.Load()) t.Cleanup(func() { require.Len(t, subscriptionOneArgsCh, 0) }) @@ -380,14 +390,17 @@ func TestStartSubscriptionHook(t *testing.T) { // with correct HTTP status codes and error messages. 
It ensures clients receive detailed error information // including custom status codes when a subscription is rejected by the hook. + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + return &core.StreamHandlerError{Message: "test error"} + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - return &core.StreamHandlerError{Message: "test error"} - }, - }, + "startSubscriptionModule": customModule, }, } @@ -455,8 +468,7 @@ func TestStartSubscriptionHook(t *testing.T) { }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) t.Cleanup(func() { require.Len(t, subscriptionOneArgsCh, 0) }) @@ -470,10 +482,14 @@ func TestStartSubscriptionHook(t *testing.T) { // (subscriptions resolved by the router's execution engine, not event-driven sources like Kafka). // It ensures the hook works uniformly across different subscription types. 
+ customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "startSubscriptionModule": start_subscription.StartSubscriptionModule{}, + "startSubscriptionModule": customModule, }, } @@ -518,8 +534,7 @@ func TestStartSubscriptionHook(t *testing.T) { }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) @@ -530,16 +545,19 @@ func TestStartSubscriptionHook(t *testing.T) { // custom events even for subscriptions that don't use event-driven sources. It tests that the synthetic // event is delivered first, followed by the normal engine-generated subscription data. + customModule := &start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + evt := ctx.NewEvent([]byte(`{"data":{"countEmp":1000}}`)) + ctx.EmitLocalEvent(evt) + return nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - evt := ctx.NewEvent([]byte(`{"data":{"countEmp":1000}}`)) - ctx.EmitLocalEvent(evt) - return nil - }, - }, + "startSubscriptionModule": customModule, }, } @@ -604,8 +622,7 @@ func TestStartSubscriptionHook(t *testing.T) { }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) @@ -618,18 +635,21 @@ func TestStartSubscriptionHook(t *testing.T) { originResponseCalled := make(chan *http.Response, 1) + customModule := 
&start_subscription.StartSubscriptionModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { + return &core.StreamHandlerError{Message: "hook error"} + }, + CallbackOnOriginResponse: func(response *http.Response, ctx core.RequestContext) *http.Response { + originResponseCalled <- response + return response + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "startSubscriptionModule": start_subscription.StartSubscriptionModule{ - Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { - return &core.StreamHandlerError{Message: "hook error"} - }, - CallbackOnOriginResponse: func(response *http.Response, ctx core.RequestContext) *http.Response { - originResponseCalled <- response - return response - }, - }, + "startSubscriptionModule": customModule, }, } @@ -686,8 +706,7 @@ func TestStartSubscriptionHook(t *testing.T) { require.Empty(t, originResponseCalled) - requestLog := xEnv.Observer().FilterMessage("SubscriptionOnStart Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) } diff --git a/router-tests/modules/stream-publish/module.go b/router-tests/modules/stream-publish/module.go index 15c70b21bf..f7c73409cf 100644 --- a/router-tests/modules/stream-publish/module.go +++ b/router-tests/modules/stream-publish/module.go @@ -1,6 +1,8 @@ package stream_publish import ( + "sync/atomic" + "go.uber.org/zap" "github.com/wundergraph/cosmo/router/core" @@ -10,8 +12,9 @@ import ( const myModuleID = "publishModule" type PublishModule struct { - Logger *zap.Logger - Callback func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) + Logger *zap.Logger + HookCallCount *atomic.Int32 // Counter to track how many times the hook is called + Callback func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, 
error) } func (m *PublishModule) Provision(ctx *core.ModuleContext) error { @@ -26,6 +29,10 @@ func (m *PublishModule) OnPublishEvents(ctx core.StreamPublishEventHandlerContex m.Logger.Info("Publish Hook has been run") } + if m.HookCallCount != nil { + m.HookCallCount.Add(1) + } + if m.Callback != nil { return m.Callback(ctx, events) } diff --git a/router-tests/modules/stream-receive/module.go b/router-tests/modules/stream-receive/module.go index 08fdb390d9..926555d241 100644 --- a/router-tests/modules/stream-receive/module.go +++ b/router-tests/modules/stream-receive/module.go @@ -1,6 +1,8 @@ package stream_receive import ( + "sync/atomic" + "go.uber.org/zap" "github.com/wundergraph/cosmo/router/core" @@ -10,8 +12,9 @@ import ( const myModuleID = "streamReceiveModule" type StreamReceiveModule struct { - Logger *zap.Logger - Callback func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) + Logger *zap.Logger + Callback func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) + HookCallCount *atomic.Int32 // Counter to track how many times the hook is called } func (m *StreamReceiveModule) Provision(ctx *core.ModuleContext) error { @@ -26,6 +29,10 @@ func (m *StreamReceiveModule) OnReceiveEvents(ctx core.StreamReceiveEventHandler m.Logger.Info("Stream Hook has been run") } + if m.HookCallCount != nil { + m.HookCallCount.Add(1) + } + if m.Callback != nil { return m.Callback(ctx, events) } diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index 39e26f124d..f22cce8254 100644 --- a/router-tests/modules/stream_publish_test.go +++ b/router-tests/modules/stream_publish_test.go @@ -40,20 +40,23 @@ func TestPublishHook(t *testing.T) { var taPossible atomic.Bool taPossible.Store(true) + customModule := stream_publish.PublishModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx 
core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + for _, evt := range events.All() { + _, ok := evt.(datasource.MutableStreamEvent) + if !ok { + taPossible.Store(false) + } + } + return events, nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - for _, evt := range events.All() { - _, ok := evt.(datasource.MutableStreamEvent) - if !ok { - taPossible.Store(false) - } - } - return events, nil - }, - }, + "publishModule": customModule, }, } @@ -75,8 +78,7 @@ func TestPublishHook(t *testing.T) { }) require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) - requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) assert.False(t, taPossible.Load(), "invalid type assertion was possible") }) @@ -89,14 +91,17 @@ func TestPublishHook(t *testing.T) { // It confirms the hook as been called by checking a log message, which is written by the custom module // used in these tests right before the actual hook is being called. 
+ customModule := stream_publish.PublishModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return events, nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, nil - }, - }, + "publishModule": customModule, }, } @@ -118,8 +123,7 @@ func TestPublishHook(t *testing.T) { }) require.JSONEq(t, `{"data":{"updateEmployeeMyKafka":{"success":true}}}`, resOne.Body) - requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) @@ -131,14 +135,17 @@ func TestPublishHook(t *testing.T) { // It ensures that hook errors don't prevent message delivery if the hook developer // wants to do so. If he does not want this he must no return events. 
+ customModule := stream_publish.PublishModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return events, errors.New("test") + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, errors.New("test") - }, - }, + "publishModule": customModule, }, } @@ -161,8 +168,7 @@ func TestPublishHook(t *testing.T) { require.JSONEq(t, `{"data": {"updateEmployeeMyKafka": {"success": false}}}`, resOne.Body) require.Equal(t, resOne.Response.StatusCode, 200) - requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) require.NoError(t, err) @@ -178,14 +184,17 @@ func TestPublishHook(t *testing.T) { // It ensures that hook errors don't prevent message delivery for NATS if the hook developer wants to do so. // If he does not want this he must no return events. 
+ customModule := stream_publish.PublishModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return events, errors.New("test") + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, errors.New("test") - }, - }, + "publishModule": customModule, }, } @@ -215,8 +224,7 @@ func TestPublishHook(t *testing.T) { }) assert.JSONEq(t, `{"data": {"updateEmployeeMyNats": {"success": false}}}`, resOne.Body) - requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) msgOne, err := firstSub.NextMsg(5 * time.Second) require.NoError(t, err) @@ -234,14 +242,17 @@ func TestPublishHook(t *testing.T) { // It ensures that hook errors don't prevent message delivery for Redis if the hook developer wants to do so. // If he does not want this he must no return events. 
+ customModule := stream_publish.PublishModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return events, errors.New("test") + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return events, errors.New("test") - }, - }, + "publishModule": customModule, }, } @@ -265,8 +276,7 @@ func TestPublishHook(t *testing.T) { }) require.JSONEq(t, `{"data": {"updateEmployeeMyRedis": {"success": false}}}`, resOne.Body) - requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) require.Len(t, records, 1) }) @@ -281,34 +291,37 @@ func TestPublishHook(t *testing.T) { // The test ensures that concrete event types can be used and their // distinct broker features (like headers for Kafka) are accessible for hook developers. 
+ customModule := stream_publish.PublishModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { + return events, nil + } + + employeeID := ctx.Operation().Variables().GetInt("employeeID") + + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + newEvt, ok := event.Clone().(*kafka.MutableEvent) + if !ok { + continue + } + newEvt.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + if newEvt.Headers == nil { + newEvt.Headers = map[string][]byte{} + } + newEvt.Headers["x-employee-id"] = []byte(strconv.Itoa(employeeID)) + newEvents = append(newEvents, newEvt) + } + + return datasource.NewStreamEvents(newEvents), nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "publishModule": stream_publish.PublishModule{ - Callback: func(ctx core.StreamPublishEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - if ctx.PublishEventConfiguration().RootFieldName() != "updateEmployeeMyKafka" { - return events, nil - } - - employeeID := ctx.Operation().Variables().GetInt("employeeID") - - newEvents := make([]datasource.StreamEvent, 0, events.Len()) - for _, event := range events.All() { - newEvt, ok := event.Clone().(*kafka.MutableEvent) - if !ok { - continue - } - newEvt.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) - if newEvt.Headers == nil { - newEvt.Headers = map[string][]byte{} - } - newEvt.Headers["x-employee-id"] = []byte(strconv.Itoa(employeeID)) - newEvents = append(newEvents, newEvt) - } - - return datasource.NewStreamEvents(newEvents), nil - }, - }, + "publishModule": customModule, }, } @@ -331,8 +344,7 @@ func TestPublishHook(t *testing.T) { }) require.JSONEq(t, `{"data": {"updateEmployeeMyKafka": 
{"success": true}}}`, resOne.Body) - requestLog := xEnv.Observer().FilterMessage("Publish Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) records, err := events.ReadKafkaMessages(xEnv, time.Second, "employeeUpdated", 1) require.NoError(t, err) diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index 264da9a329..383f3d9b9d 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -41,10 +41,14 @@ func TestReceiveHook(t *testing.T) { // It confirms the hook is called by checking for the expected log message // and that subscription events are properly delivered to the client. + customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{}, + "streamReceiveModule": customModule, }, } @@ -106,8 +110,7 @@ func TestReceiveHook(t *testing.T) { require.NoError(t, err) }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) @@ -119,21 +122,24 @@ func TestReceiveHook(t *testing.T) { // It tests that the modified events are properly delivered to subscribers with the updated data, // demonstrating that hooks can transform stream events before they reach clients. 
+ customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) + } + + return datasource.NewStreamEvents(newEvents), nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - newEvents := make([]datasource.StreamEvent, 0, events.Len()) - for _, event := range events.All() { - eventCopy := event.Clone() - eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) - newEvents = append(newEvents, eventCopy) - } - - return datasource.NewStreamEvents(newEvents), nil - }, - }, + "streamReceiveModule": customModule, }, } @@ -195,8 +201,7 @@ func TestReceiveHook(t *testing.T) { require.NoError(t, err) }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) @@ -213,20 +218,23 @@ func TestReceiveHook(t *testing.T) { var taPossible atomic.Bool taPossible.Store(true) + customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + for _, evt := range events.All() { + _, ok := evt.(datasource.MutableStreamEvent) + if !ok { + taPossible.Store(false) + } + } + return events, nil + }, + } + cfg := config.Config{ 
Graph: config.Graph{}, Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - for _, evt := range events.All() { - _, ok := evt.(datasource.MutableStreamEvent) - if !ok { - taPossible.Store(false) - } - } - return events, nil - }, - }, + "streamReceiveModule": customModule, }, } @@ -288,8 +296,7 @@ func TestReceiveHook(t *testing.T) { require.NoError(t, err) }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) assert.False(t, taPossible.Load(), "invalid type assertion was possible") }) @@ -303,28 +310,31 @@ func TestReceiveHook(t *testing.T) { // access JWT claims of individual clients and modify events only for authenticated users with specific claims, // while leaving events for other clients unchanged. 
+ customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + if ctx.Authentication() == nil { + return events, nil + } + if val, ok := ctx.Authentication().Claims()["sub"]; !ok || val != "user-2" { + return events, nil + } + + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) + } + + return datasource.NewStreamEvents(newEvents), nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - if ctx.Authentication() == nil { - return events, nil - } - if val, ok := ctx.Authentication().Claims()["sub"]; !ok || val != "user-2" { - return events, nil - } - - newEvents := make([]datasource.StreamEvent, 0, events.Len()) - for _, event := range events.All() { - eventCopy := event.Clone() - eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) - newEvents = append(newEvents, eventCopy) - } - - return datasource.NewStreamEvents(newEvents), nil - }, - }, + "streamReceiveModule": customModule, }, } @@ -454,8 +464,7 @@ func TestReceiveHook(t *testing.T) { require.NoError(t, err) }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") - assert.Len(t, requestLog.All(), 2) + assert.Equal(t, int32(2), customModule.HookCallCount.Load()) }) }) @@ -467,26 +476,28 @@ func TestReceiveHook(t *testing.T) { // conditionally modify events, enabling header-based event transformation logic. 
customHeader := http.CanonicalHeaderKey("X-Custom-Header") + customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + if val, ok := ctx.Request().Header[customHeader]; !ok || val[0] != "Test" { + return events, nil + } + + newEvents := make([]datasource.StreamEvent, 0, events.Len()) + for _, event := range events.All() { + eventCopy := event.Clone() + eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) + newEvents = append(newEvents, eventCopy) + } + + return datasource.NewStreamEvents(newEvents), nil + }, + } cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - if val, ok := ctx.Request().Header[customHeader]; !ok || val[0] != "Test" { - return events, nil - } - - newEvents := make([]datasource.StreamEvent, 0, events.Len()) - for _, event := range events.All() { - eventCopy := event.Clone() - eventCopy.SetData([]byte(`{"__typename":"Employee","id": 3,"update":{"name":"foo"}}`)) - newEvents = append(newEvents, eventCopy) - } - - return datasource.NewStreamEvents(newEvents), nil - }, - }, + "streamReceiveModule": customModule, }, } @@ -556,8 +567,7 @@ func TestReceiveHook(t *testing.T) { require.NoError(t, err) }, "unable to close client before timeout") - requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") - assert.Len(t, requestLog.All(), 1) + assert.Equal(t, int32(1), customModule.HookCallCount.Load()) }) }) @@ -568,14 +578,17 @@ func TestReceiveHook(t *testing.T) { // the subscription connection and cleans up Kafka clients. 
It ensures that hook errors trigger // graceful shutdown of the subscription to prevent resource leaks or stuck connections. + customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + return datasource.NewStreamEvents(nil), errors.New("test error from streamevents hook") + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - return datasource.NewStreamEvents(nil), errors.New("test error from streamevents hook") - }, - }, + "streamReceiveModule": customModule, }, } @@ -684,53 +697,56 @@ func TestReceiveHook(t *testing.T) { finishedHandlers atomic.Int32 ) - cfg := config.Config{ - Graph: config.Graph{}, - Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - currentHandlers.Add(1) - - // Wait for other hooks in the same client update batch to start. - for { - current := currentHandlers.Load() - max := maxCurrentHandlers.Load() + customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + currentHandlers.Add(1) - if current > max { - maxCurrentHandlers.CompareAndSwap(max, current) - } + // Wait for other hooks in the same client update batch to start. + for { + current := currentHandlers.Load() + max := maxCurrentHandlers.Load() - if current >= int32(tc.maxConcurrent) { - // wait to see if the subscription-updater spawns too many concurrent hooks, - // i.e. 
exceeding the number of configured max concurrent hooks. - deadline := time.Now().Add(300 * time.Millisecond) - for time.Now().Before(deadline) { - if currentHandlers.Load() > int32(tc.maxConcurrent) { - break - } - } - break - } + if current > max { + maxCurrentHandlers.CompareAndSwap(max, current) + } - // Let hooks continue if we never reach a updater batch size = tc.maxConcurrent - // because there are not enough remaining clients to be updated. - // i.e. it could be the last round of updates: - // 100 clients, now in comes a new event from broker, max concurrent hooks = 30. - // First round: 30 hooks run, 70 remaining. - // Second round: 30 hooks run, 40 remaining. - // Third round: 30 hooks run, 10 remaining. - // Fourth round: 10 hooks run, then we end up here because remainingSubs < tc.maxConcurrent. - remainingSubs := tc.numSubscribers - int(finishedHandlers.Load()) - if remainingSubs < tc.maxConcurrent { + if current >= int32(tc.maxConcurrent) { + // wait to see if the subscription-updater spawns too many concurrent hooks, + // i.e. exceeding the number of configured max concurrent hooks. + deadline := time.Now().Add(300 * time.Millisecond) + for time.Now().Before(deadline) { + if currentHandlers.Load() > int32(tc.maxConcurrent) { break } } + break + } - currentHandlers.Add(-1) - finishedHandlers.Add(1) - return events, nil - }, - }, + // Let hooks continue if we never reach a updater batch size = tc.maxConcurrent + // because there are not enough remaining clients to be updated. + // i.e. it could be the last round of updates: + // 100 clients, now in comes a new event from broker, max concurrent hooks = 30. + // First round: 30 hooks run, 70 remaining. + // Second round: 30 hooks run, 40 remaining. + // Third round: 30 hooks run, 10 remaining. + // Fourth round: 10 hooks run, then we end up here because remainingSubs < tc.maxConcurrent. 
+ remainingSubs := tc.numSubscribers - int(finishedHandlers.Load()) + if remainingSubs < tc.maxConcurrent { + break + } + } + + currentHandlers.Add(-1) + finishedHandlers.Add(1) + return events, nil + }, + } + + cfg := config.Config{ + Graph: config.Graph{}, + Modules: map[string]interface{}{ + "streamReceiveModule": customModule, }, } @@ -817,8 +833,7 @@ func TestReceiveHook(t *testing.T) { assert.Equal(t, int32(tc.maxConcurrent), maxCurrentHandlers.Load(), "amount of concurrent handlers not what was expected") - requestLog := xEnv.Observer().FilterMessage("Stream Hook has been run") - assert.Len(t, requestLog.All(), tc.numSubscribers) + assert.Equal(t, int32(tc.numSubscribers), customModule.HookCallCount.Load()) }) }) } @@ -843,18 +858,21 @@ func TestReceiveHook(t *testing.T) { var callCount atomic.Int32 + customModule := stream_receive.StreamReceiveModule{ + HookCallCount: &atomic.Int32{}, + Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { + // Only the first call should delay + if callCount.Add(1) == 1 { + time.Sleep(hookDelay) + } + return events, nil + }, + } + cfg := config.Config{ Graph: config.Graph{}, Modules: map[string]interface{}{ - "streamReceiveModule": stream_receive.StreamReceiveModule{ - Callback: func(ctx core.StreamReceiveEventHandlerContext, events datasource.StreamEvents) (datasource.StreamEvents, error) { - // Only the first call should delay - if callCount.Add(1) == 1 { - time.Sleep(hookDelay) - } - return events, nil - }, - }, + "streamReceiveModule": customModule, }, } @@ -945,8 +963,7 @@ func TestReceiveHook(t *testing.T) { assert.Len(t, timeoutLog.All(), 1, "expected timeout warning to be logged") // Verify all hooks were executed - hookLog := xEnv.Observer().FilterMessage("Stream Hook has been run") - assert.Len(t, hookLog.All(), 3) + assert.Equal(t, int32(3), customModule.HookCallCount.Load()) }) }) } From 544bf72f2fe9471d3e17b9f4ada91efe1b801b43 Mon Sep 17 
00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Wed, 19 Nov 2025 14:53:57 +0100 Subject: [PATCH 28/44] chore: removed unused case of switch + remove early return Removed a switch case, which is not reachable anymore. Also removed an early return in case a type assertion did not work. This way we ensure the response is actually sent either way. --- router/core/graphql_handler.go | 27 +++++---------------------- 1 file changed, 5 insertions(+), 22 deletions(-) diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index d078e84065..22ffb45498 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -394,15 +394,14 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv var errStreamHandlerError *StreamHandlerError if !errors.As(err, &errStreamHandlerError) { response.Errors[0].Message = "Internal server error" - // We could set response.Errors[0].Extensions, too if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } - return - } - response.Errors[0].Message = errStreamHandlerError.Message - if isHttpResponseWriter { - httpWriter.WriteHeader(http.StatusOK) + } else { + response.Errors[0].Message = errStreamHandlerError.Message + if isHttpResponseWriter { + httpWriter.WriteHeader(http.StatusOK) + } } case errorTypeInvalidWsSubprotocol: response.Errors[0].Message = fmt.Sprintf("Invalid Subprotocol error: %s or configure the subprotocol to be used using `wgc subgraph update` command.", err.Error()) @@ -414,22 +413,6 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } - case errorTypeHttpError: - var httpErr *httpGraphqlError - if !errors.As(err, &httpErr) { - response.Errors[0].Message = "Internal server error" - return - } - response.Errors[0].Message = httpErr.Message() - if httpErr.ExtensionCode() != "" || 
httpErr.StatusCode() != 0 { - response.Errors[0].Extensions = &Extensions{ - Code: httpErr.ExtensionCode(), - StatusCode: httpErr.StatusCode(), - } - } - if isHttpResponseWriter { - httpWriter.WriteHeader(httpErr.StatusCode()) - } } if ctx.TracingOptions.Enable && ctx.TracingOptions.IncludeTraceOutputInResponseExtensions { From 1cabe5d69a32b267b7660298c5b2b328a84b3421 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:03:25 +0100 Subject: [PATCH 29/44] fix: add timeout waiting for channel in test --- router-tests/modules/start_subscription_test.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index 1621ac5e42..9b9d1c2c4b 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -257,7 +257,9 @@ func TestStartSubscriptionHook(t *testing.T) { }() xEnv.WaitForSubscriptionCount(1, time.Second*10) - <-callbackCalled + testenv.AwaitChannelWithT(t, 10*time.Second, callbackCalled, func(t *testing.T, called bool) { + require.True(t, called) + }, "StartSubscription callback was not invoked") xEnv.WaitForSubscriptionCount(0, time.Second*10) testenv.AwaitChannelWithT(t, time.Second*10, clientRunCh, func(t *testing.T, err error) { From d2946a1a1f6de508977508b7a00f9b483fbdabd4 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:11:51 +0100 Subject: [PATCH 30/44] fix: avoid using Len on channel in test --- router-tests/modules/stream_publish_test.go | 6 +++++- router-tests/modules/stream_receive_test.go | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/router-tests/modules/stream_publish_test.go b/router-tests/modules/stream_publish_test.go index f22cce8254..35d9e16766 100644 --- a/router-tests/modules/stream_publish_test.go +++ 
b/router-tests/modules/stream_publish_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/redis/go-redis/v9" "go.uber.org/zap/zapcore" "github.com/stretchr/testify/assert" @@ -278,7 +279,10 @@ func TestPublishHook(t *testing.T) { assert.Equal(t, int32(1), customModule.HookCallCount.Load()) - require.Len(t, records, 1) + testenv.AwaitChannelWithT(t, 5*time.Second, records, func(t *testing.T, msg *redis.Message) { + require.NotNil(t, msg, "expected to receive a redis message") + require.Equal(t, xEnv.GetPubSubName("employeeUpdatedMyRedis"), msg.Channel) + }) }) }) diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index 383f3d9b9d..45e59b1b71 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -378,7 +378,7 @@ func TestReceiveHook(t *testing.T) { topics := []string{"employeeUpdated"} events.KafkaEnsureTopicExists(t, xEnv, time.Second, topics...) - var subscriptionOne struct { + var subscriptionQuery struct { employeeUpdatedMyKafka struct { ID float64 `graphql:"id"` Details struct { @@ -405,7 +405,7 @@ func TestReceiveHook(t *testing.T) { }) subscriptionArgsCh := make(chan kafkaSubscriptionArgs) - subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionOneID, err := client.Subscribe(&subscriptionQuery, nil, func(dataValue []byte, errValue error) error { subscriptionArgsCh <- kafkaSubscriptionArgs{ dataValue: dataValue, errValue: errValue, @@ -421,7 +421,7 @@ func TestReceiveHook(t *testing.T) { }() subscriptionArgsCh2 := make(chan kafkaSubscriptionArgs) - subscriptionTwoID, err := client2.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { + subscriptionTwoID, err := client2.Subscribe(&subscriptionQuery, nil, func(dataValue []byte, errValue error) error { subscriptionArgsCh2 <- kafkaSubscriptionArgs{ dataValue: dataValue, errValue: errValue, From 
5d0c63019a4a5b2547b54ed3c6fa8dda570a2f8c Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:16:59 +0100 Subject: [PATCH 31/44] chore: remove unused channel from test --- router-tests/modules/stream_receive_test.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index 45e59b1b71..f78fda9e0c 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -620,12 +620,7 @@ func TestReceiveHook(t *testing.T) { surl := xEnv.GraphQLWebSocketSubscriptionURL() client := graphql.NewSubscriptionClient(surl) - subscriptionArgsCh := make(chan kafkaSubscriptionArgs) subscriptionOneID, err := client.Subscribe(&subscriptionOne, nil, func(dataValue []byte, errValue error) error { - subscriptionArgsCh <- kafkaSubscriptionArgs{ - dataValue: dataValue, - errValue: errValue, - } return nil }) require.NoError(t, err) From cdcbf934692ffbecd20cd39b0a99991121b98ddd Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Wed, 19 Nov 2025 15:54:11 +0100 Subject: [PATCH 32/44] chore: remove unused constant and a corresponding check --- router/core/errors.go | 5 ----- 1 file changed, 5 deletions(-) diff --git a/router/core/errors.go b/router/core/errors.go index 203353516d..9a0a99418a 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -35,7 +35,6 @@ const ( errorTypeInvalidWsSubprotocol errorTypeEDFSInvalidMessage errorTypeMergeResult - errorTypeHttpError errorTypeEDFSHookError ) @@ -95,10 +94,6 @@ func getErrorType(err error) errorType { if errors.As(err, &mergeResultErr) { return errorTypeMergeResult } - var httpError *httpGraphqlError - if errors.As(err, &httpError) { - return errorTypeHttpError - } return errorTypeUnknown } From 288bf528ba07a377929a95354b5e0ce2f756582f Mon Sep 17 00:00:00 2001 From: Dominik Korittki 
<23359034+dkorittki@users.noreply.github.com> Date: Thu, 20 Nov 2025 11:20:44 +0100 Subject: [PATCH 33/44] chore: update Cosmo Streams ADR --- adr/cosmo-streams-v1.md | 118 +++++++++++++++++++++++++++------------- 1 file changed, 79 insertions(+), 39 deletions(-) diff --git a/adr/cosmo-streams-v1.md b/adr/cosmo-streams-v1.md index f764639a50..03265d9c04 100644 --- a/adr/cosmo-streams-v1.md +++ b/adr/cosmo-streams-v1.md @@ -84,21 +84,43 @@ type PublishEventConfiguration interface { } type SubscriptionOnStartHandlerContext interface { - // Request is the original request received by the router. - Request() *http.Request - // Logger is the logger for the request - Logger() *zap.Logger - // Operation is the GraphQL operation - Operation() OperationContext - // Authentication is the authentication for the request - Authentication() authentication.Authentication - // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) - SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration - // WriteEvent writes an event to the stream of the current subscription - // It returns true if the event was written to the stream, false if the event was dropped - WriteEvent(event datasource.StreamEvent) bool - // NewEvent creates a new event that can be used in the subscription. - NewEvent(data []byte) datasource.MutableStreamEvent + // Request is the original request received by the router. 
+ Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) + SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration + // EmitLocalEvent sends an event directly to the subscription stream of the + // currently connected client. + // + // This method triggers the router to resolve the client's operation and emit + // the resulting data as a stream event. The event exists only within the + // router; it is not forwarded to any message broker. + // + // The event is delivered exclusively to the client associated with the current + // handler execution. No other subscriptions are affected. + // + // The method returns true if the event was successfully emitted, or false if + // it was dropped. + EmitLocalEvent(event datasource.StreamEvent) bool + // NewEvent creates a new event that can be used in the subscription. + // + // The data parameter must contain valid JSON bytes. The format depends on the subscription type. + // + // For event-driven subscriptions (Cosmo Streams / EDFS), the data should contain: + // __typename : The name of the schema entity, which is expected to be returned to the client. + // {keyName} : The key of the entity as configured on the schema via @key directive. + // Example usage: ctx.NewEvent([]byte(`{"__typename": "Employee", "id": 1}`)) + // + // For normal subscriptions, you need to provide the complete GraphQL response structure. + // Example usage: ctx.NewEvent([]byte(`{"data": {"fieldName": value}}`)) + // + // You can use EmitLocalEvent to emit this event to subscriptions. 
+ NewEvent(data []byte) datasource.MutableStreamEvent } type SubscriptionOnStartHandler interface { @@ -108,18 +130,28 @@ type SubscriptionOnStartHandler interface { } type StreamReceiveEventHandlerContext interface { - // Request is the initial client request that started the subscription - Request() *http.Request - // Logger is the logger for the request - Logger() *zap.Logger - // Operation is the GraphQL operation - Operation() OperationContext - // Authentication is the authentication for the request - Authentication() authentication.Authentication - // SubscriptionEventConfiguration is the subscription event configuration - SubscriptionEventConfiguration() SubscriptionEventConfiguration - // NewEvent creates a new event that can be used in the subscription. - NewEvent(data []byte) datasource.MutableStreamEvent + // Context is a context for handlers. + // If it is cancelled, the handler should stop processing. + Context() context.Context + // Request is the initial client request that started the subscription + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // SubscriptionEventConfiguration the subscription event configuration + SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration + // NewEvent creates a new event that can be used in the subscription. + // + // The data parameter must contain valid JSON bytes representing the raw event payload + // from your message broker (Kafka, NATS, etc.). The JSON must have properly quoted + // property names and must include the __typename field required by GraphQL. + // For example: []byte(`{"__typename": "Employee", "id": 1, "update": {"name": "John"}}`). + // + // This method is typically used in OnReceiveEvents hooks to create new or modified events. 
+ NewEvent(data []byte) datasource.MutableStreamEvent } type StreamReceiveEventHandler interface { @@ -135,18 +167,26 @@ type StreamReceiveEventHandler interface { } type StreamPublishEventHandlerContext interface { - // Request is the original request received by the router. - Request() *http.Request - // Logger is the logger for the request - Logger() *zap.Logger - // Operation is the GraphQL operation - Operation() OperationContext - // Authentication is the authentication for the request - Authentication() authentication.Authentication - // PublishEventConfiguration is the publish event configuration - PublishEventConfiguration() PublishEventConfiguration - // NewEvent creates a new event that can be used in the subscription. - NewEvent(data []byte) datasource.MutableStreamEvent + // Request is the original request received by the router. + Request() *http.Request + // Logger is the logger for the request + Logger() *zap.Logger + // Operation is the GraphQL operation + Operation() OperationContext + // Authentication is the authentication for the request + Authentication() authentication.Authentication + // PublishEventConfiguration the publish event configuration + PublishEventConfiguration() datasource.PublishEventConfiguration + // NewEvent creates a new event that can be used in the subscription. + // + // The data parameter must contain valid JSON bytes representing the event payload + // that will be sent to your message broker (Kafka, NATS, etc.). The JSON must have + // properly quoted property names and must include the __typename field required by GraphQL. + // For example: []byte(`{"__typename": "Employee", "id": 1, "update": {"name": "John"}}`). + // + // This method is typically used in OnPublishEvents hooks to create new or modified events + // before they are sent to the message broker. 
+ NewEvent(data []byte) datasource.MutableStreamEvent } type StreamPublishEventHandler interface { From 551ef84e253d614ef8c1d33ff3969da00c451857 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Thu, 20 Nov 2025 13:59:20 +0100 Subject: [PATCH 34/44] fix: do not close updater when poller closes This actually introduces a bug, where occasionally subscription clients will get disconnected after receiving messages when we deal with lots of connecting and disconnecting clients. It depends on two contexts being regularly checked inside the Kafka topic poller and if one of them cancels, the connections would get closed. These contexts are meant to handle the lifecycle of the Kafka adapter, not subscriptions. So we should not close the subscription updater when the poller is canceled. After this commit it behaves like the router behaves before Cosmo Streams. All of this can be reevaluated, maybe there is room for improvement. But for now its better not to change this bevaviour. 
--- router/pkg/pubsub/kafka/adapter.go | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index fcd1cc0c70..dd0f08a492 100644 --- a/router/pkg/pubsub/kafka/adapter.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -13,7 +13,6 @@ import ( "github.com/twmb/franz-go/pkg/kerr" "github.com/twmb/franz-go/pkg/kgo" "github.com/wundergraph/cosmo/router/pkg/pubsub/datasource" - "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" ) @@ -158,11 +157,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri go func() { - defer func() { - client.Close() - updater.Close(resolve.SubscriptionCloseKindNormal) - p.closeWg.Done() - }() + defer p.closeWg.Done() err := p.topicPoller(ctx, client, updater, PollerOpts{providerId: conf.ProviderID()}) if err != nil { From cd086fc7cb50a07e6f53ab9dde60b55704b67716 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 11:42:53 +0100 Subject: [PATCH 35/44] chore: go mod tidy --- router-tests/go.sum | 2 -- router/go.sum | 2 -- 2 files changed, 4 deletions(-) diff --git a/router-tests/go.sum b/router-tests/go.sum index 2be4b9515d..fe3848e28b 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -354,8 +354,6 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoTjFDDcgaECCCR2aTePqMu9QBmPbyhqIYOhV0= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.237.0.20251110152155-423a60c6a33e h1:246mrdmTHRIsW9yVQjFKQlAgvw+sNES1FymnVjJ7r/Q= -github.com/wundergraph/graphql-go-tools/v2 
v2.0.0-rc.237.0.20251110152155-423a60c6a33e/go.mod h1:ErOQH1ki2+SZB8JjpTyGVnoBpg5picIyjvuWQJP4abg= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde h1:odFMNVd6midgkv+yOfK8WufK+0yANceTJcu7KaIuhs0= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= diff --git a/router/go.sum b/router/go.sum index 80d356e604..6d78fc037a 100644 --- a/router/go.sum +++ b/router/go.sum @@ -322,8 +322,6 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.237.0.20251110152155-423a60c6a33e h1:246mrdmTHRIsW9yVQjFKQlAgvw+sNES1FymnVjJ7r/Q= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.237.0.20251110152155-423a60c6a33e/go.mod h1:ErOQH1ki2+SZB8JjpTyGVnoBpg5picIyjvuWQJP4abg= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde h1:odFMNVd6midgkv+yOfK8WufK+0yANceTJcu7KaIuhs0= github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= From 20770e1f97cbd6deab52cbafa47083d740c15ca7 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 18:02:25 +0100 Subject: [PATCH 36/44] chore: rename EmitLocalEvent to EmitEvent --- adr/cosmo-streams-v1.md | 6 +-- .../modules/start_subscription_test.go 
| 6 +-- router/core/subscriptions_modules.go | 38 +++++++++---------- 3 files changed, 25 insertions(+), 25 deletions(-) diff --git a/adr/cosmo-streams-v1.md b/adr/cosmo-streams-v1.md index 03265d9c04..163a8e760b 100644 --- a/adr/cosmo-streams-v1.md +++ b/adr/cosmo-streams-v1.md @@ -94,7 +94,7 @@ type SubscriptionOnStartHandlerContext interface { Authentication() authentication.Authentication // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration - // EmitLocalEvent sends an event directly to the subscription stream of the + // EmitEvent sends an event directly to the subscription stream of the // currently connected client. // // This method triggers the router to resolve the client's operation and emit @@ -106,7 +106,7 @@ type SubscriptionOnStartHandlerContext interface { // // The method returns true if the event was successfully emitted, or false if // it was dropped. - EmitLocalEvent(event datasource.StreamEvent) bool + EmitEvent(event datasource.StreamEvent) bool // NewEvent creates a new event that can be used in the subscription. // // The data parameter must contain valid JSON bytes. The format depends on the subscription type. @@ -119,7 +119,7 @@ type SubscriptionOnStartHandlerContext interface { // For normal subscriptions, you need to provide the complete GraphQL response structure. // Example usage: ctx.NewEvent([]byte(`{"data": {"fieldName": value}}`)) // - // You can use EmitLocalEvent to emit this event to subscriptions. + // You can use EmitEvent to emit this event to subscriptions. 
NewEvent(data []byte) datasource.MutableStreamEvent } diff --git a/router-tests/modules/start_subscription_test.go b/router-tests/modules/start_subscription_test.go index 9b9d1c2c4b..b517c287f1 100644 --- a/router-tests/modules/start_subscription_test.go +++ b/router-tests/modules/start_subscription_test.go @@ -104,7 +104,7 @@ func TestStartSubscriptionHook(t *testing.T) { if ctx.SubscriptionEventConfiguration().RootFieldName() != "employeeUpdatedMyKafka" { return nil } - ctx.EmitLocalEvent((&kafka.MutableEvent{ + ctx.EmitEvent((&kafka.MutableEvent{ Key: []byte("1"), Data: []byte(`{"id": 1, "__typename": "Employee"}`), })) @@ -291,7 +291,7 @@ func TestStartSubscriptionHook(t *testing.T) { return nil } evt := ctx.NewEvent([]byte(`{"id": 1, "__typename": "Employee"}`)) - ctx.EmitLocalEvent(evt) + ctx.EmitEvent(evt) return nil }, } @@ -551,7 +551,7 @@ func TestStartSubscriptionHook(t *testing.T) { HookCallCount: &atomic.Int32{}, Callback: func(ctx core.SubscriptionOnStartHandlerContext) error { evt := ctx.NewEvent([]byte(`{"data":{"countEmp":1000}}`)) - ctx.EmitLocalEvent(evt) + ctx.EmitEvent(evt) return nil }, } diff --git a/router/core/subscriptions_modules.go b/router/core/subscriptions_modules.go index da767da118..c01d1fa348 100644 --- a/router/core/subscriptions_modules.go +++ b/router/core/subscriptions_modules.go @@ -23,7 +23,7 @@ type SubscriptionOnStartHandlerContext interface { Authentication() authentication.Authentication // SubscriptionEventConfiguration is the subscription event configuration (will return nil for engine subscription) SubscriptionEventConfiguration() datasource.SubscriptionEventConfiguration - // EmitLocalEvent sends an event directly to the subscription stream of the + // EmitEvent sends an event directly to the subscription stream of the // currently connected client. 
// // This method triggers the router to resolve the client's operation and emit @@ -35,7 +35,7 @@ type SubscriptionOnStartHandlerContext interface { // // The method returns true if the event was successfully emitted, or false if // it was dropped. - EmitLocalEvent(event datasource.StreamEvent) bool + EmitEvent(event datasource.StreamEvent) bool // NewEvent creates a new event that can be used in the subscription. // // The data parameter must contain valid JSON bytes. The format depends on the subscription type. @@ -48,7 +48,7 @@ type SubscriptionOnStartHandlerContext interface { // For normal subscriptions, you need to provide the complete GraphQL response structure. // Example usage: ctx.NewEvent([]byte(`{"data": {"fieldName": value}}`)) // - // You can use EmitLocalEvent to emit this event to subscriptions. + // You can use EmitEvent to emit this event to subscriptions. NewEvent(data []byte) datasource.MutableStreamEvent } @@ -91,7 +91,7 @@ type pubSubSubscriptionOnStartHookContext struct { operation OperationContext authentication authentication.Authentication subscriptionEventConfiguration datasource.SubscriptionEventConfiguration - emitLocalEventFn func(data []byte) + emitEventFn func(data []byte) eventBuilder datasource.EventBuilderFn } @@ -115,8 +115,8 @@ func (c *pubSubSubscriptionOnStartHookContext) SubscriptionEventConfiguration() return c.subscriptionEventConfiguration } -func (c *pubSubSubscriptionOnStartHookContext) EmitLocalEvent(event datasource.StreamEvent) bool { - c.emitLocalEventFn(event.GetData()) +func (c *pubSubSubscriptionOnStartHookContext) EmitEvent(event datasource.StreamEvent) bool { + c.emitEventFn(event.GetData()) return true } @@ -162,11 +162,11 @@ func (e *EngineEvent) Clone() datasource.MutableStreamEvent { } type engineSubscriptionOnStartHookContext struct { - request *http.Request - logger *zap.Logger - operation OperationContext - authentication authentication.Authentication - emitLocalEventFn func(data []byte) + request 
*http.Request + logger *zap.Logger + operation OperationContext + authentication authentication.Authentication + emitEventFn func(data []byte) } func (c *engineSubscriptionOnStartHookContext) Request() *http.Request { @@ -185,8 +185,8 @@ func (c *engineSubscriptionOnStartHookContext) Authentication() authentication.A return c.authentication } -func (c *engineSubscriptionOnStartHookContext) EmitLocalEvent(event datasource.StreamEvent) bool { - c.emitLocalEventFn(event.GetData()) +func (c *engineSubscriptionOnStartHookContext) EmitEvent(event datasource.StreamEvent) bool { + c.emitEventFn(event.GetData()) return true } @@ -232,7 +232,7 @@ func NewPubSubSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont operation: requestContext.Operation(), authentication: requestContext.Authentication(), subscriptionEventConfiguration: subConf, - emitLocalEventFn: resolveCtx.Updater, + emitEventFn: resolveCtx.Updater, eventBuilder: eventBuilder, } @@ -255,11 +255,11 @@ func NewEngineSubscriptionOnStartHook(fn func(ctx SubscriptionOnStartHandlerCont } hookCtx := &engineSubscriptionOnStartHookContext{ - request: requestContext.Request(), - logger: logger, - operation: requestContext.Operation(), - authentication: requestContext.Authentication(), - emitLocalEventFn: resolveCtx.Updater, + request: requestContext.Request(), + logger: logger, + operation: requestContext.Operation(), + authentication: requestContext.Authentication(), + emitEventFn: resolveCtx.Updater, } return fn(hookCtx) From 435c0f276b2e1c58657786eb0f68512482c1be36 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 18:38:46 +0100 Subject: [PATCH 37/44] chore: use type assert instead of copy --- router/pkg/pubsub/kafka/adapter.go | 25 ++++++++++++++++++------- 1 file changed, 18 insertions(+), 7 deletions(-) diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index dd0f08a492..9dddfe5770 100644 --- 
a/router/pkg/pubsub/kafka/adapter.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -211,13 +211,13 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv var errMutex sync.Mutex for _, streamEvent := range events { - kafkaEvent, ok := streamEvent.Clone().(*MutableEvent) - if !ok { - return datasource.NewError("invalid event type for Kafka adapter", nil) + evt, err := castToMutableEvent(streamEvent) + if err != nil { + return datasource.NewError(err.Error(), nil) } - headers := make([]kgo.RecordHeader, 0, len(kafkaEvent.Headers)) - for key, value := range kafkaEvent.Headers { + headers := make([]kgo.RecordHeader, 0, len(evt.Headers)) + for key, value := range evt.Headers { headers = append(headers, kgo.RecordHeader{ Key: key, Value: value, @@ -225,9 +225,9 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv } p.writeClient.Produce(ctx, &kgo.Record{ - Key: kafkaEvent.Key, + Key: evt.Key, Topic: pubConf.Topic, - Value: kafkaEvent.Data, + Value: evt.Data, Headers: headers, }, func(record *kgo.Record, err error) { defer wg.Done() @@ -322,3 +322,14 @@ func NewProviderAdapter(ctx context.Context, logger *zap.Logger, opts []kgo.Opt, streamMetricStore: store, }, nil } + +func castToMutableEvent(event datasource.StreamEvent) (*MutableEvent, error) { + switch evt := event.(type) { + case *Event: + return evt.evt, nil + case *MutableEvent: + return evt, nil + default: + return nil, errors.New("invalid event type for Kafka adapter") + } +} From 0f131a0392fe92a5de4f3ef04cad338c7b2656ea Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 20:08:57 +0100 Subject: [PATCH 38/44] fix: handle partial errors when publishing events --- router/pkg/pubsub/kafka/adapter.go | 39 ++++++++++++++------- router/pkg/pubsub/nats/adapter.go | 55 ++++++++++++++++++------------ router/pkg/pubsub/redis/adapter.go | 55 ++++++++++++++++++++---------- 3 files changed, 97 
insertions(+), 52 deletions(-) diff --git a/router/pkg/pubsub/kafka/adapter.go b/router/pkg/pubsub/kafka/adapter.go index 9dddfe5770..70dfc1c112 100644 --- a/router/pkg/pubsub/kafka/adapter.go +++ b/router/pkg/pubsub/kafka/adapter.go @@ -207,13 +207,17 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv var wg sync.WaitGroup wg.Add(len(events)) - var pErr error + var errs []error var errMutex sync.Mutex for _, streamEvent := range events { evt, err := castToMutableEvent(streamEvent) if err != nil { - return datasource.NewError(err.Error(), nil) + wg.Done() + errMutex.Lock() + errs = append(errs, err) + errMutex.Unlock() + continue } headers := make([]kgo.RecordHeader, 0, len(evt.Headers)) @@ -233,7 +237,7 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv defer wg.Done() if err != nil { errMutex.Lock() - pErr = err + errs = append(errs, err) errMutex.Unlock() } }) @@ -241,8 +245,17 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv wg.Wait() - if pErr != nil { - log.Error("publish error", zap.Error(pErr)) + // Produce metrics for all failed and successfully published events + successCount := len(events) - len(errs) + for range successCount { + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: kafkaProduce, + ProviderType: metric.ProviderTypeKafka, + DestinationName: pubConf.Topic, + }) + } + for range len(errs) { p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ ProviderId: pubConf.ProviderID(), StreamOperationName: kafkaProduce, @@ -250,15 +263,17 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv ErrorType: "publish_error", DestinationName: pubConf.Topic, }) - return datasource.NewError(fmt.Sprintf("error publishing to Kafka topic %s", pubConf.Topic), pErr) } - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: pubConf.ProviderID(), - 
StreamOperationName: kafkaProduce, - ProviderType: metric.ProviderTypeKafka, - DestinationName: pubConf.Topic, - }) + // Log all errors, if any, as a single entry and return error + if len(errs) > 0 { + combinedErr := errors.Join(errs...) + log.Error("publish errors", zap.Error(combinedErr), zap.Int("failed_count", len(errs)), zap.Int("total_count", len(events))) + return datasource.NewError( + fmt.Sprintf("error publishing %d/%d events to Kafka topic %s", len(errs), len(events), pubConf.Topic), combinedErr, + ) + } + return nil } diff --git a/router/pkg/pubsub/nats/adapter.go b/router/pkg/pubsub/nats/adapter.go index e32368c658..fbb9dc3fe6 100644 --- a/router/pkg/pubsub/nats/adapter.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -248,38 +248,49 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv log.Debug("publish", zap.Int("event_count", len(events))) + var errs []error + for _, streamEvent := range events { natsEvent, ok := streamEvent.Clone().(*MutableEvent) if !ok { - return datasource.NewError("invalid event type for NATS adapter", nil) + errs = append(errs, errors.New("invalid event type for NATS adapter")) + continue } err := p.client.Publish(pubConf.Subject, natsEvent.Data) if err != nil { - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: pubConf.ProviderID(), - StreamOperationName: natsPublish, - ProviderType: metric.ProviderTypeNats, - ErrorType: "publish_error", - DestinationName: pubConf.Subject, - }) - log.Error( - "publish error", - zap.Error(err), - zap.String("provider_id", pubConf.ProviderID()), - zap.String("provider_type", string(pubConf.ProviderType())), - zap.String("field_name", pubConf.RootFieldName()), - ) - return datasource.NewError(fmt.Sprintf("error publishing to NATS subject %s", pubConf.Subject), err) + errs = append(errs, err) } } - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: pubConf.ProviderID(), - StreamOperationName: natsPublish, - ProviderType: 
metric.ProviderTypeNats, - DestinationName: pubConf.Subject, - }) + // Produce metrics for all failed and successfully published events + successCount := len(events) - len(errs) + for range successCount { + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: natsPublish, + ProviderType: metric.ProviderTypeNats, + DestinationName: pubConf.Subject, + }) + } + for range len(errs) { + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: natsPublish, + ProviderType: metric.ProviderTypeNats, + ErrorType: "publish_error", + DestinationName: pubConf.Subject, + }) + } + + // Collect and return all errors if any occurred + if len(errs) > 0 { + combinedErr := errors.Join(errs...) + log.Error("publish errors", zap.Error(combinedErr), zap.Int("failed_count", len(errs)), zap.Int("total_count", len(events))) + return datasource.NewError( + fmt.Sprintf("error publishing %d/%d events to NATS subject %s", len(errs), len(events), pubConf.Subject), combinedErr, + ) + } return nil } diff --git a/router/pkg/pubsub/redis/adapter.go b/router/pkg/pubsub/redis/adapter.go index 8c056fe6c1..b1e90e4f33 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -2,6 +2,7 @@ package redis import ( "context" + "errors" "fmt" "sync" @@ -172,37 +173,55 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv log.Debug("publish", zap.Int("event_count", len(events))) + var errs []error + for _, streamEvent := range events { redisEvent, ok := streamEvent.Clone().(*MutableEvent) if !ok { - return datasource.NewError("invalid event type for Redis adapter", nil) + errs = append(errs, errors.New("invalid event type for Redis adapter")) + continue } data, dataErr := redisEvent.Data.MarshalJSON() if dataErr != nil { - log.Error("error marshalling data", zap.Error(dataErr)) - return datasource.NewError("error marshalling data", 
dataErr) + errs = append(errs, fmt.Errorf("error marshalling data: %w", dataErr)) + continue } intCmd := p.conn.Publish(ctx, pubConf.Channel, data) if intCmd.Err() != nil { - log.Error("publish error", zap.Error(intCmd.Err())) - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: pubConf.ProviderID(), - StreamOperationName: redisPublish, - ProviderType: metric.ProviderTypeRedis, - ErrorType: "publish_error", - DestinationName: pubConf.Channel, - }) - return datasource.NewError(fmt.Sprintf("error publishing to Redis PubSub channel %s", pubConf.Channel), intCmd.Err()) + errs = append(errs, intCmd.Err()) } } - p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ - ProviderId: pubConf.ProviderID(), - StreamOperationName: redisPublish, - ProviderType: metric.ProviderTypeRedis, - DestinationName: pubConf.Channel, - }) + // Produce metrics for all failed and successfully published events + successCount := len(events) - len(errs) + for range successCount { + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: redisPublish, + ProviderType: metric.ProviderTypeRedis, + DestinationName: pubConf.Channel, + }) + } + for range len(errs) { + p.streamMetricStore.Produce(ctx, metric.StreamsEvent{ + ProviderId: pubConf.ProviderID(), + StreamOperationName: redisPublish, + ProviderType: metric.ProviderTypeRedis, + ErrorType: "publish_error", + DestinationName: pubConf.Channel, + }) + } + + // Collect and return all errors if any occurred + if len(errs) > 0 { + combinedErr := errors.Join(errs...) 
+ log.Error("publish errors", zap.Error(combinedErr), zap.Int("failed_count", len(errs)), zap.Int("total_count", len(events))) + return datasource.NewError( + fmt.Sprintf("error publishing %d/%d events to Redis PubSub channel %s", len(errs), len(events), pubConf.Channel), combinedErr, + ) + } + return nil } From 5c3e8343c74ac539a504850693f927f7cf123868 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 20:33:48 +0100 Subject: [PATCH 39/44] chore: rename edfs to streams error --- router/core/errors.go | 8 ++++---- router/core/graphql_handler.go | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/router/core/errors.go b/router/core/errors.go index 9a0a99418a..a5c27ef763 100644 --- a/router/core/errors.go +++ b/router/core/errors.go @@ -35,7 +35,7 @@ const ( errorTypeInvalidWsSubprotocol errorTypeEDFSInvalidMessage errorTypeMergeResult - errorTypeEDFSHookError + errorTypeStreamsHandlerError ) type ( @@ -78,9 +78,9 @@ func getErrorType(err error) errorType { if errors.As(err, &edfsErr) { return errorTypeEDFS } - var edfsHookErr *StreamHandlerError - if errors.As(err, &edfsHookErr) { - return errorTypeEDFSHookError + var streamsHandlerErr *StreamHandlerError + if errors.As(err, &streamsHandlerErr) { + return errorTypeStreamsHandlerError } var invalidWsSubprotocolErr graphql_datasource.InvalidWsSubprotocolError if errors.As(err, &invalidWsSubprotocolErr) { diff --git a/router/core/graphql_handler.go b/router/core/graphql_handler.go index 22ffb45498..29a8772ca8 100644 --- a/router/core/graphql_handler.go +++ b/router/core/graphql_handler.go @@ -390,7 +390,7 @@ func (h *GraphQLHandler) WriteError(ctx *resolve.Context, err error, res *resolv if isHttpResponseWriter { httpWriter.WriteHeader(http.StatusInternalServerError) } - case errorTypeEDFSHookError: + case errorTypeStreamsHandlerError: var errStreamHandlerError *StreamHandlerError if !errors.As(err, &errStreamHandlerError) { 
response.Errors[0].Message = "Internal server error" From 7f91fb321d6f24ab884d87fd5e9dab4c9ac0bc Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 20:57:15 +0100 Subject: [PATCH 40/44] chore: use context instead of separate timer --- router-tests/modules/stream_receive_test.go | 4 +++- .../pubsub/datasource/subscription_event_updater.go | 11 +++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/router-tests/modules/stream_receive_test.go b/router-tests/modules/stream_receive_test.go index f78fda9e0c..c71d3019bb 100644 --- a/router-tests/modules/stream_receive_test.go +++ b/router-tests/modules/stream_receive_test.go @@ -954,7 +954,9 @@ func TestReceiveHook(t *testing.T) { assert.Equal(t, float64(1), receivedIDs[len(receivedIDs)-1], "expected the delayed event to arrive last") assert.NotEqual(t, float64(1), receivedIDs[0], "expected at least one later event to arrive before the delayed one") - timeoutLog := xEnv.Observer().FilterMessage("Timeout exceeded during subscription updates, events may arrive out of order") + timeoutLog := xEnv.Observer().FilterMessage("Subscription update timeout exceeded because handler execution took too long. " + + "Consider increasing events.handler.on_receive_events.handler_timeout and/or max_concurrent_handlers or reduce handler execution time."
+ + "Events may arrive out of order.") assert.Len(t, timeoutLog.All(), 1, "expected timeout warning to be logged") // Verify all hooks were executed diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index f165dfe922..09d6d14fcf 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -10,10 +10,7 @@ import ( "go.uber.org/zap/zapcore" ) -const ( - timeoutGracePeriod = 50 * time.Millisecond - defaultTimeout = 5 * time.Second -) +const defaultTimeout = 5 * time.Second // SubscriptionEventUpdater is a wrapper around the SubscriptionUpdater interface // that provides a way to send the event struct instead of the raw data @@ -65,7 +62,7 @@ func (s *subscriptionEventUpdater) Update(events []StreamEvent) { case <-done: s.logger.Debug("All subscription updates completed") // All subscriptions completed successfully - case <-time.After(s.timeout + timeoutGracePeriod): + case <-updaterCtx.Done(): // Timeout exceeded, some subscription updates may still be running. // We can't stop them but we will also not wait for them, basically abandoning them. // They will continue to hold their semaphore slots until they complete, @@ -73,7 +70,9 @@ func (s *subscriptionEventUpdater) Update(events []StreamEvent) { // Also since we will process the next batch of events while having abandoned updaters, // those updaters might eventually push their events to the subscription late, // which means events might arrive out of order. - s.logger.Warn("Timeout exceeded during subscription updates, events may arrive out of order") + s.logger.Warn("Subscription update timeout exceeded because handler execution took too long. " + + "Consider increasing events.handler.on_receive_events.handler_timeout and/or max_concurrent_handlers or reduce handler execution time." 
+ + "Events may arrive out of order.") } } From bb03b5f4b69839e67f71a575833deb95a190ee14 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 21:05:29 +0100 Subject: [PATCH 41/44] chore: use semaphore package instead of channel --- .../pubsub/datasource/subscription_event_updater.go | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/router/pkg/pubsub/datasource/subscription_event_updater.go b/router/pkg/pubsub/datasource/subscription_event_updater.go index 09d6d14fcf..df2224448d 100644 --- a/router/pkg/pubsub/datasource/subscription_event_updater.go +++ b/router/pkg/pubsub/datasource/subscription_event_updater.go @@ -8,6 +8,7 @@ import ( "github.com/wundergraph/graphql-go-tools/v2/pkg/engine/resolve" "go.uber.org/zap" "go.uber.org/zap/zapcore" + "golang.org/x/sync/semaphore" ) const defaultTimeout = 5 * time.Second @@ -28,7 +29,7 @@ type subscriptionEventUpdater struct { hooks Hooks logger *zap.Logger eventBuilder EventBuilderFn - semaphore chan struct{} + semaphore *semaphore.Weighted timeout time.Duration } @@ -49,7 +50,10 @@ func (s *subscriptionEventUpdater) Update(events []StreamEvent) { go func() { for subCtx, subId := range subscriptions { - s.semaphore <- struct{}{} // Acquire slot, blocks if all slots are taken + if err := s.semaphore.Acquire(updaterCtx, 1); err != nil { + // Context cancelled or timed out, stop acquiring + break + } wg.Add(1) go s.updateSubscription(subCtx, updaterCtx, &wg, subId, events) } @@ -94,7 +98,7 @@ func (s *subscriptionEventUpdater) updateSubscription(subscriptionCtx context.Co if r := recover(); r != nil { s.recoverPanic(subID, r) } - <-s.semaphore // release the slot when done + s.semaphore.Release(1) }() hooks := s.hooks.OnReceiveEvents.Handlers @@ -154,7 +158,7 @@ func NewSubscriptionEventUpdater( eventUpdater: eventUpdater, logger: logger, eventBuilder: eventBuilder, - semaphore: make(chan struct{}, limit), + semaphore: 
semaphore.NewWeighted(int64(limit)), timeout: timeout, } } From ef45209dfa459062ee4a20e02160fabb5f413bd6 Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 21:19:12 +0100 Subject: [PATCH 42/44] fix: get rid of cloning events while publishing for redis+nats as well --- router/pkg/pubsub/nats/adapter.go | 23 +++++++++++++++----- router/pkg/pubsub/nats/engine_datasource.go | 8 +++---- router/pkg/pubsub/redis/adapter.go | 19 ++++++++++++---- router/pkg/pubsub/redis/engine_datasource.go | 6 ++--- 4 files changed, 39 insertions(+), 17 deletions(-) diff --git a/router/pkg/pubsub/nats/adapter.go b/router/pkg/pubsub/nats/adapter.go index fbb9dc3fe6..e4cd4cd470 100644 --- a/router/pkg/pubsub/nats/adapter.go +++ b/router/pkg/pubsub/nats/adapter.go @@ -150,7 +150,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, cfg datasource.Subscrip }) updater.Update([]datasource.StreamEvent{ - Event{evt: &MutableEvent{ + &Event{evt: &MutableEvent{ Data: msg.Data(), Headers: map[string][]string(msg.Headers()), }}, @@ -198,7 +198,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, cfg datasource.Subscrip DestinationName: msg.Subject, }) updater.Update([]datasource.StreamEvent{ - Event{evt: &MutableEvent{ + &Event{evt: &MutableEvent{ Data: msg.Data, Headers: map[string][]string(msg.Header), }}, @@ -251,13 +251,13 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv var errs []error for _, streamEvent := range events { - natsEvent, ok := streamEvent.Clone().(*MutableEvent) - if !ok { - errs = append(errs, errors.New("invalid event type for NATS adapter")) + natsEvent, err := castToMutableEvent(streamEvent) + if err != nil { + errs = append(errs, err) continue } - err := p.client.Publish(pubConf.Subject, natsEvent.Data) + err = p.client.Publish(pubConf.Subject, natsEvent.Data) if err != nil { errs = append(errs, err) } @@ -441,3 +441,14 @@ func NewAdapter(ctx context.Context, 
logger *zap.Logger, url string, opts []nats streamMetricStore: store, }, nil } + +func castToMutableEvent(event datasource.StreamEvent) (*MutableEvent, error) { + switch evt := event.(type) { + case *Event: + return evt.evt, nil + case *MutableEvent: + return evt, nil + default: + return nil, errors.New("invalid event type for NATS adapter") + } +} diff --git a/router/pkg/pubsub/nats/engine_datasource.go b/router/pkg/pubsub/nats/engine_datasource.go index 8d1eb6a1e1..809f2aad43 100644 --- a/router/pkg/pubsub/nats/engine_datasource.go +++ b/router/pkg/pubsub/nats/engine_datasource.go @@ -21,14 +21,14 @@ type Event struct { evt *MutableEvent } -func (e Event) GetData() []byte { +func (e *Event) GetData() []byte { if e.evt == nil { return nil } return slices.Clone(e.evt.Data) } -func (e Event) GetHeaders() map[string][]string { +func (e *Event) GetHeaders() map[string][]string { if e.evt == nil || e.evt.Headers == nil { return nil } @@ -244,7 +244,7 @@ func (s *NatsPublishDataSource) Load(ctx context.Context, input []byte, out *byt return err } - if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{Event{evt: &publishData.Event}}); err != nil { + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&Event{evt: &publishData.Event}}); err != nil { // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error _, errWrite := io.WriteString(out, `{"success": false}`) return errWrite @@ -277,7 +277,7 @@ func (s *NatsRequestDataSource) Load(ctx context.Context, input []byte, out *byt return fmt.Errorf("adapter for provider %s is not of the right type", publishData.Provider) } - return adapter.Request(ctx, publishData.PublishEventConfiguration(), Event{evt: &publishData.Event}, out) + return adapter.Request(ctx, publishData.PublishEventConfiguration(), &Event{evt: &publishData.Event}, out) } func (s *NatsRequestDataSource) 
LoadWithFiles(ctx context.Context, input []byte, files []*httpclient.FileUpload, out *bytes.Buffer) error { diff --git a/router/pkg/pubsub/redis/adapter.go b/router/pkg/pubsub/redis/adapter.go index b1e90e4f33..89fb39f9ca 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -130,7 +130,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri DestinationName: msg.Channel, }) updater.Update([]datasource.StreamEvent{ - Event{evt: &MutableEvent{ + &Event{evt: &MutableEvent{ Data: []byte(msg.Payload), }}, }) @@ -176,9 +176,9 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv var errs []error for _, streamEvent := range events { - redisEvent, ok := streamEvent.Clone().(*MutableEvent) - if !ok { - errs = append(errs, errors.New("invalid event type for Redis adapter")) + redisEvent, err := castToMutableEvent(streamEvent) + if err != nil { + errs = append(errs, err) continue } @@ -225,3 +225,14 @@ func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEv return nil } + +func castToMutableEvent(event datasource.StreamEvent) (*MutableEvent, error) { + switch evt := event.(type) { + case *Event: + return evt.evt, nil + case *MutableEvent: + return evt, nil + default: + return nil, errors.New("invalid event type for Redis adapter") + } +} diff --git a/router/pkg/pubsub/redis/engine_datasource.go b/router/pkg/pubsub/redis/engine_datasource.go index 929aab5739..41151df9f0 100644 --- a/router/pkg/pubsub/redis/engine_datasource.go +++ b/router/pkg/pubsub/redis/engine_datasource.go @@ -21,14 +21,14 @@ type Event struct { evt *MutableEvent } -func (e Event) GetData() []byte { +func (e *Event) GetData() []byte { if e.evt == nil { return nil } return slices.Clone(e.evt.Data) } -func (e Event) Clone() datasource.MutableStreamEvent { +func (e *Event) Clone() datasource.MutableStreamEvent { return e.evt.Clone() } @@ -226,7 +226,7 @@ func (s *PublishDataSource) 
Load(ctx context.Context, input []byte, out *bytes.B return err } - if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{Event{evt: &publishData.Event}}); err != nil { + if err := s.pubSub.Publish(ctx, publishData.PublishEventConfiguration(), []datasource.StreamEvent{&Event{evt: &publishData.Event}}); err != nil { // err will not be returned but only logged inside PubSubProvider.Publish to avoid a "unable to fetch from subgraph" error _, errWrite := io.WriteString(out, `{"success": false}`) return errWrite From 8e79ea009fdd493a4a92fb75fd57ecf8d461d98f Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 21:47:09 +0100 Subject: [PATCH 43/44] chore: fix typo in error message --- router/pkg/pubsub/redis/adapter.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/router/pkg/pubsub/redis/adapter.go b/router/pkg/pubsub/redis/adapter.go index 89fb39f9ca..699670f1f7 100644 --- a/router/pkg/pubsub/redis/adapter.go +++ b/router/pkg/pubsub/redis/adapter.go @@ -88,7 +88,7 @@ func (p *ProviderAdapter) Shutdown(ctx context.Context) error { func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.SubscriptionEventConfiguration, updater datasource.SubscriptionEventUpdater) error { subConf, ok := conf.(*SubscriptionEventConfiguration) if !ok { - return datasource.NewError("subscription event not support by redis provider", nil) + return datasource.NewError("subscription event not supported by redis provider", nil) } log := p.logger.With( @@ -154,7 +154,7 @@ func (p *ProviderAdapter) Subscribe(ctx context.Context, conf datasource.Subscri func (p *ProviderAdapter) Publish(ctx context.Context, conf datasource.PublishEventConfiguration, events []datasource.StreamEvent) error { pubConf, ok := conf.(*PublishEventConfiguration) if !ok { - return datasource.NewError("publish event not support by redis provider", nil) + return 
datasource.NewError("publish event not supported by redis provider", nil) } log := p.logger.With( From da8f4a6b77b7ac49ed86e8cb2423d9c245e77bee Mon Sep 17 00:00:00 2001 From: Dominik Korittki <23359034+dkorittki@users.noreply.github.com> Date: Fri, 21 Nov 2025 22:29:25 +0100 Subject: [PATCH 44/44] chore: update graphql-go-tools in router + router-tests --- router-tests/go.mod | 2 +- router-tests/go.sum | 4 ++-- router/go.mod | 2 +- router/go.sum | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/router-tests/go.mod b/router-tests/go.mod index 330adb6a1d..9db6270ebf 100644 --- a/router-tests/go.mod +++ b/router-tests/go.mod @@ -27,7 +27,7 @@ require ( github.com/wundergraph/cosmo/demo/pkg/subgraphs/projects v0.0.0-20250715110703-10f2e5f9c79e github.com/wundergraph/cosmo/router v0.0.0-20251030234733-8ed574a0296f github.com/wundergraph/cosmo/router-plugin v0.0.0-20250808194725-de123ba1c65e - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.239 go.opentelemetry.io/otel v1.36.0 go.opentelemetry.io/otel/sdk v1.36.0 go.opentelemetry.io/otel/sdk/metric v1.36.0 diff --git a/router-tests/go.sum b/router-tests/go.sum index fe3848e28b..ea870f8261 100644 --- a/router-tests/go.sum +++ b/router-tests/go.sum @@ -354,8 +354,8 @@ github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTB github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301 h1:EzfKHQoTjFDDcgaECCCR2aTePqMu9QBmPbyhqIYOhV0= github.com/wundergraph/consul/sdk v0.0.0-20250204115147-ed842a8fd301/go.mod h1:wxI0Nak5dI5RvJuzGyiEK4nZj0O9X+Aw6U0tC1wPKq0= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde h1:odFMNVd6midgkv+yOfK8WufK+0yANceTJcu7KaIuhs0= -github.com/wundergraph/graphql-go-tools/v2 
v2.0.0-rc.238.0.20251120113218-a4f189176bde/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.239 h1:wh8qTtVS4Wr/dJ/s162hAvCPsaZ1VOdmmg82QhNcGBE= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.239/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342 h1:FnBeRrxr7OU4VvAzt5X7s6266i6cSVkkFPS0TuXWbIg= github.com/xrash/smetrics v0.0.0-20250705151800-55b8f293f342/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= diff --git a/router/go.mod b/router/go.mod index 655ebd130e..37c6641bfe 100644 --- a/router/go.mod +++ b/router/go.mod @@ -31,7 +31,7 @@ require ( github.com/tidwall/gjson v1.18.0 github.com/tidwall/sjson v1.2.5 github.com/twmb/franz-go v1.16.1 - github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde + github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.239 // Do not upgrade, it renames attributes we rely on go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.58.0 go.opentelemetry.io/contrib/propagators/b3 v1.23.0 diff --git a/router/go.sum b/router/go.sum index 6d78fc037a..4693a892d0 100644 --- a/router/go.sum +++ b/router/go.sum @@ -322,8 +322,8 @@ github.com/wk8/go-ordered-map/v2 v2.1.8 h1:5h/BUHu93oj4gIdvHHHGsScSTMijfx5PeYkE/ github.com/wk8/go-ordered-map/v2 v2.1.8/go.mod h1:5nJHM5DyteebpVlHnWMV0rPz6Zp7+xBAnxjb1X5vnTw= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083 h1:8/D7f8gKxTBjW+SZK4mhxTTBVpxcqeBgWF1Rfmltbfk= github.com/wundergraph/astjson v0.0.0-20250106123708-be463c97e083/go.mod h1:eOTL6acwctsN4F3b7YE+eE2t8zcJ/doLm9sZzsxxxrE= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde h1:odFMNVd6midgkv+yOfK8WufK+0yANceTJcu7KaIuhs0= -github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.238.0.20251120113218-a4f189176bde/go.mod 
h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.239 h1:wh8qTtVS4Wr/dJ/s162hAvCPsaZ1VOdmmg82QhNcGBE= +github.com/wundergraph/graphql-go-tools/v2 v2.0.0-rc.239/go.mod h1:mX25ASEQiKamxaFSK6NZihh0oDCigIuzro30up4mFH4= github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4= github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4= github.com/yuin/gopher-lua v1.1.1 h1:kYKnWBjvbNP4XLT3+bPEwAXJx262OhaHDWDVOPjL46M=