diff --git a/internal/api/v1beta1connect/billing_webhook.go b/internal/api/v1beta1connect/billing_webhook.go index ced3b1a96..0c6a33487 100644 --- a/internal/api/v1beta1connect/billing_webhook.go +++ b/internal/api/v1beta1connect/billing_webhook.go @@ -4,7 +4,9 @@ import ( "context" "connectrpc.com/connect" + "github.com/raystack/frontier/billing/customer" "github.com/raystack/frontier/core/event" + "github.com/raystack/frontier/pkg/server/consts" frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1" "go.uber.org/zap" ) @@ -16,6 +18,12 @@ func (h *ConnectHandler) BillingWebhookCallback(ctx context.Context, request *co return nil, connect.NewError(connect.CodeInvalidArgument, ErrBillingProviderNotSupported) } + // Extract Stripe webhook signature from headers and add to context + // This is required for webhook signature verification in the event service + if webhookSignature := request.Header().Get(consts.StripeWebhookSignature); webhookSignature != "" { + ctx = customer.SetStripeWebhookSignatureInContext(ctx, webhookSignature) + } + if err := h.eventService.BillingWebhook(ctx, event.ProviderWebhookEvent{ Name: request.Msg.GetProvider(), Body: request.Msg.GetBody(), diff --git a/pkg/server/connect_interceptors/authentication.go b/pkg/server/connect_interceptors/authentication.go index beb5111ce..67cfc8dcf 100644 --- a/pkg/server/connect_interceptors/authentication.go +++ b/pkg/server/connect_interceptors/authentication.go @@ -91,9 +91,10 @@ func (i *AuthenticationInterceptor) WrapStreamingHandler(next connect.StreamingH // authenticationSkipList stores path to skip authentication, by default its enabled for all requests var authenticationSkipList = map[string]bool{ - "/raystack.frontier.v1beta1.FrontierService/ListAuthStrategies": true, - "/raystack.frontier.v1beta1.FrontierService/Authenticate": true, - "/raystack.frontier.v1beta1.FrontierService/AuthCallback": true, - "/raystack.frontier.v1beta1.FrontierService/ListMetaSchemas": true, - "/raystack.frontier.v1beta1.FrontierService/GetMetaSchema": true, + "/raystack.frontier.v1beta1.FrontierService/ListAuthStrategies": true, + "/raystack.frontier.v1beta1.FrontierService/Authenticate": true, + "/raystack.frontier.v1beta1.FrontierService/AuthCallback": true, + "/raystack.frontier.v1beta1.FrontierService/ListMetaSchemas": true, + "/raystack.frontier.v1beta1.FrontierService/GetMetaSchema": true, + "/raystack.frontier.v1beta1.FrontierService/BillingWebhookCallback": true, } diff --git a/pkg/server/server.go b/pkg/server/server.go index a7a5a1f39..30a5a3122 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -216,6 +216,10 @@ func ServeConnect(ctx context.Context, logger log.Logger, cfg Config, deps api.D mux := http.NewServeMux() mux.Handle(frontierPath, frontierHandler) mux.Handle(adminPath, adminHandler) + + // Register webhook bridge handler to allow Stripe to call with provider in path + // This uses frontierHandler which has all interceptors (auth, logging, audit, etc.) applied + mux.HandleFunc("/billing/webhooks/callback/", WebhookBridgeHandler(frontierHandler)) reflector := grpcreflect.NewStaticReflector( "raystack.frontier.v1beta1.FrontierService", "raystack.frontier.v1beta1.AdminService") // protoc-gen-connect-go generates package-level constants diff --git a/pkg/server/webhook_bridge.go b/pkg/server/webhook_bridge.go new file mode 100644 index 000000000..2da023efb --- /dev/null +++ b/pkg/server/webhook_bridge.go @@ -0,0 +1,97 @@ +package server + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + "strings" + + "github.com/raystack/frontier/pkg/server/consts" + frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1" + frontierv1beta1connect "github.com/raystack/frontier/proto/v1beta1/frontierv1beta1connect" +) + +// WebhookBridgeHandler creates an HTTP handler that bridges raw HTTP webhook requests +// to the ConnectRPC BillingWebhookCallback handler. This is needed because Stripe +// doesn't allow modifying the request body but allows custom URL paths. +// The handler uses the frontierHandler which has all interceptors (auth, logging, audit, etc.) applied. +func WebhookBridgeHandler(frontierHandler http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + // Only accept POST requests + if r.Method != http.MethodPost { + http.Error(w, "method not allowed", http.StatusMethodNotAllowed) + return + } + + // Extract provider from URL path + // Expected path: /billing/webhooks/callback/{provider} + pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/") + if len(pathParts) < 4 || pathParts[0] != "billing" || + pathParts[1] != "webhooks" || pathParts[2] != "callback" { + http.Error(w, "invalid path", http.StatusNotFound) + return + } + provider := pathParts[3] + if provider == "" { + http.Error(w, "invalid path", http.StatusNotFound) + return + } + + // Read raw request body + body, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusBadRequest) + return + } + defer r.Body.Close() + + // Create ConnectRPC request payload + requestPayload := &frontierv1beta1.BillingWebhookCallbackRequest{ + Provider: provider, + Body: body, + } + + // Encode request as JSON for ConnectRPC + requestJSON, err := json.Marshal(requestPayload) + if err != nil { + http.Error(w, fmt.Sprintf("failed to encode request: %v", err), http.StatusInternalServerError) + return + } + + // Create a new HTTP request to the ConnectRPC procedure + connectReq := httptest.NewRequest( + http.MethodPost, + frontierv1beta1connect.FrontierServiceBillingWebhookCallbackProcedure, + bytes.NewReader(requestJSON), + ) + connectReq = connectReq.WithContext(r.Context()) + + // Copy important headers from the original request + if contentType := r.Header.Get("Content-Type"); contentType != "" { + connectReq.Header.Set("X-Original-Content-Type", contentType) + } + // Set ConnectRPC content type + connectReq.Header.Set("Content-Type", "application/json") + + // Copy other important headers (auth, request ID, etc.) + headersToProxy := []string{ + "Authorization", + "Cookie", + consts.RequestIDHeader, + consts.ProjectRequestKey, + consts.StripeTestClockRequestKey, + consts.StripeWebhookSignature, + } + for _, header := range headersToProxy { + if value := r.Header.Get(header); value != "" { + connectReq.Header.Set(header, value) + } + } + + // Forward to the ConnectRPC handler (which has all interceptors) + frontierHandler.ServeHTTP(w, connectReq) + } +} diff --git a/pkg/server/webhook_bridge_test.go b/pkg/server/webhook_bridge_test.go new file mode 100644 index 000000000..717b46931 --- /dev/null +++ b/pkg/server/webhook_bridge_test.go @@ -0,0 +1,172 @@ +package server + +import ( + "bytes" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" +) + +// mockHandler is a simple mock HTTP handler for testing +type mockHandler struct { + statusCode int + response []byte +} + +func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(m.statusCode) + w.Write(m.response) +} + +func TestWebhookBridgeHandler_HTTPMethods(t *testing.T) { + tests := []struct { + name string + method string + path string + body []byte + expectedStatus int + expectedBody string + }{ + { + name: "invalid method GET", + method: "GET", + path: "/billing/webhooks/callback/stripe", + body: nil, + expectedStatus: http.StatusMethodNotAllowed, + expectedBody: "method not allowed", + }, + { + name: "invalid method PUT", + method: "PUT", + path: "/billing/webhooks/callback/stripe", + body: []byte(`{"test":"data"}`), + expectedStatus: http.StatusMethodNotAllowed, + expectedBody: "method not allowed", + }, + { + name: "invalid method DELETE", + method: "DELETE", + path: "/billing/webhooks/callback/stripe", + body: nil, + expectedStatus: http.StatusMethodNotAllowed, + expectedBody: "method not allowed", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockFrontierHandler := &mockHandler{ + statusCode: http.StatusOK, + response: []byte(`{}`), + } + + var body io.Reader + if tt.body != nil { + body = bytes.NewReader(tt.body) + } + req := httptest.NewRequest(tt.method, tt.path, body) + rr := httptest.NewRecorder() + + bridgeHandler := WebhookBridgeHandler(mockFrontierHandler) + bridgeHandler.ServeHTTP(rr, req) + + assert.Equal(t, tt.expectedStatus, rr.Code) + assert.Contains(t, rr.Body.String(), tt.expectedBody) + }) + } +} + +func TestWebhookBridgeHandler_PathParsing(t *testing.T) { + tests := []struct { + name string + path string + shouldBeValid bool + }{ + { + name: "missing provider", + path: "/billing/webhooks/callback", + shouldBeValid: false, + }, + { + name: "missing provider with trailing slash", + path: "/billing/webhooks/callback/", + shouldBeValid: false, + }, + { + name: "wrong path - missing webhooks", + path: "/billing/wrong/callback/stripe", + shouldBeValid: false, + }, + { + name: "wrong path - missing callback", + path: "/billing/webhooks/wrong/stripe", + shouldBeValid: false, + }, + { + name: "wrong path - missing billing prefix", + path: "/v1beta1/billing/webhooks/callback/stripe", + shouldBeValid: false, + }, + { + name: "completely wrong path", + path: "/v2/api/webhooks", + shouldBeValid: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockFrontierHandler := &mockHandler{ + statusCode: http.StatusOK, + response: []byte(`{}`), + } + req := httptest.NewRequest("POST", tt.path, bytes.NewReader([]byte(`{}`))) + rr := httptest.NewRecorder() + + bridgeHandler := WebhookBridgeHandler(mockFrontierHandler) + bridgeHandler.ServeHTTP(rr, req) + + if tt.shouldBeValid { + // Should not get 404 for path issues (might get other errors from handler logic) + assert.NotEqual(t, http.StatusNotFound, rr.Code, "should not return 404 for valid path") + } else { + // Should get 404 for invalid paths + assert.Equal(t, http.StatusNotFound, rr.Code, "should return 404 for invalid path") + } + }) + } +} + +func TestWebhookBridgeHandler_SuccessfulRequest(t *testing.T) { + // Mock handler that verifies the request was transformed correctly + mockFrontierHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify the request path was changed to ConnectRPC procedure + assert.Contains(t, r.URL.Path, "BillingWebhookCallback") + + // Verify content type + assert.Equal(t, "application/json", r.Header.Get("Content-Type")) + + // Read and verify body was encoded as ConnectRPC request + body, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.Contains(t, string(body), "provider") + assert.Contains(t, string(body), "body") + + // Return success + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write([]byte(`{}`)) + }) + + req := httptest.NewRequest("POST", "/billing/webhooks/callback/stripe", bytes.NewReader([]byte(`{"event":"test"}`))) + req.Header.Set("Stripe-Signature", "test-signature") + rr := httptest.NewRecorder() + + bridgeHandler := WebhookBridgeHandler(mockFrontierHandler) + bridgeHandler.ServeHTTP(rr, req) + + assert.Equal(t, http.StatusOK, rr.Code) +}