Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/api/v1beta1connect/billing_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@ import (
"context"

"connectrpc.com/connect"
"github.com/raystack/frontier/billing/customer"
"github.com/raystack/frontier/core/event"
"github.com/raystack/frontier/pkg/server/consts"
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
"go.uber.org/zap"
)
Expand All @@ -16,6 +18,12 @@ func (h *ConnectHandler) BillingWebhookCallback(ctx context.Context, request *co
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBillingProviderNotSupported)
}

// Extract Stripe webhook signature from headers and add to context
// This is required for webhook signature verification in the event service
if webhookSignature := request.Header().Get(consts.StripeWebhookSignature); webhookSignature != "" {
ctx = customer.SetStripeWebhookSignatureInContext(ctx, webhookSignature)
}

if err := h.eventService.BillingWebhook(ctx, event.ProviderWebhookEvent{
Name: request.Msg.GetProvider(),
Body: request.Msg.GetBody(),
Expand Down
11 changes: 6 additions & 5 deletions pkg/server/connect_interceptors/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,10 @@ func (i *AuthenticationInterceptor) WrapStreamingHandler(next connect.StreamingH

// authenticationSkipList stores path to skip authentication, by default its enabled for all requests
var authenticationSkipList = map[string]bool{
"/raystack.frontier.v1beta1.FrontierService/ListAuthStrategies": true,
"/raystack.frontier.v1beta1.FrontierService/Authenticate": true,
"/raystack.frontier.v1beta1.FrontierService/AuthCallback": true,
"/raystack.frontier.v1beta1.FrontierService/ListMetaSchemas": true,
"/raystack.frontier.v1beta1.FrontierService/GetMetaSchema": true,
"/raystack.frontier.v1beta1.FrontierService/ListAuthStrategies": true,
"/raystack.frontier.v1beta1.FrontierService/Authenticate": true,
"/raystack.frontier.v1beta1.FrontierService/AuthCallback": true,
"/raystack.frontier.v1beta1.FrontierService/ListMetaSchemas": true,
"/raystack.frontier.v1beta1.FrontierService/GetMetaSchema": true,
"/raystack.frontier.v1beta1.FrontierService/BillingWebhookCallback": true,
}
4 changes: 4 additions & 0 deletions pkg/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ func ServeConnect(ctx context.Context, logger log.Logger, cfg Config, deps api.D
mux := http.NewServeMux()
mux.Handle(frontierPath, frontierHandler)
mux.Handle(adminPath, adminHandler)

// Register webhook bridge handler to allow Stripe to call with provider in path
// This uses frontierHandler which has all interceptors (auth, logging, audit, etc.) applied
mux.HandleFunc("/billing/webhooks/callback/", WebhookBridgeHandler(frontierHandler))
reflector := grpcreflect.NewStaticReflector(
"raystack.frontier.v1beta1.FrontierService",
"raystack.frontier.v1beta1.AdminService") // protoc-gen-connect-go generates package-level constants
Expand Down
97 changes: 97 additions & 0 deletions pkg/server/webhook_bridge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
package server

import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"strings"

"github.com/raystack/frontier/pkg/server/consts"
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
frontierv1beta1connect "github.com/raystack/frontier/proto/v1beta1/frontierv1beta1connect"
)

// WebhookBridgeHandler creates an HTTP handler that bridges raw HTTP webhook requests
// to the ConnectRPC BillingWebhookCallback handler. This is needed because Stripe
// doesn't allow modifying the request body but allows custom URL paths.
// The handler uses the frontierHandler which has all interceptors (auth, logging, audit, etc.) applied.
func WebhookBridgeHandler(frontierHandler http.Handler) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Only accept POST requests
if r.Method != http.MethodPost {
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}

// Extract provider from URL path
// Expected path: /billing/webhooks/callback/{provider}
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
if len(pathParts) < 4 || pathParts[0] != "billing" ||
pathParts[1] != "webhooks" || pathParts[2] != "callback" {
http.Error(w, "invalid path", http.StatusNotFound)
return
}
provider := pathParts[3]
if provider == "" {
http.Error(w, "invalid path", http.StatusNotFound)
return
}

// Read raw request body
body, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusBadRequest)
return
}
defer r.Body.Close()

// Create ConnectRPC request payload
requestPayload := &frontierv1beta1.BillingWebhookCallbackRequest{
Provider: provider,
Body: body,
}

// Encode request as JSON for ConnectRPC
requestJSON, err := json.Marshal(requestPayload)
if err != nil {
http.Error(w, fmt.Sprintf("failed to encode request: %v", err), http.StatusInternalServerError)
return
}

// Create a new HTTP request to the ConnectRPC procedure
connectReq := httptest.NewRequest(
http.MethodPost,
frontierv1beta1connect.FrontierServiceBillingWebhookCallbackProcedure,
bytes.NewReader(requestJSON),
)
connectReq = connectReq.WithContext(r.Context())

// Copy important headers from the original request
if contentType := r.Header.Get("Content-Type"); contentType != "" {
connectReq.Header.Set("X-Original-Content-Type", contentType)
}
// Set ConnectRPC content type
connectReq.Header.Set("Content-Type", "application/json")

// Copy other important headers (auth, request ID, etc.)
headersToProxy := []string{
"Authorization",
"Cookie",
consts.RequestIDHeader,
consts.ProjectRequestKey,
consts.StripeTestClockRequestKey,
consts.StripeWebhookSignature,
}
for _, header := range headersToProxy {
if value := r.Header.Get(header); value != "" {
connectReq.Header.Set(header, value)
}
}

// Forward to the ConnectRPC handler (which has all interceptors)
frontierHandler.ServeHTTP(w, connectReq)
}
}
172 changes: 172 additions & 0 deletions pkg/server/webhook_bridge_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package server

import (
"bytes"
"io"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

// mockHandler is a simple mock HTTP handler for testing
type mockHandler struct {
statusCode int
response []byte
}

func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(m.statusCode)
w.Write(m.response)
}

func TestWebhookBridgeHandler_HTTPMethods(t *testing.T) {
tests := []struct {
name string
method string
path string
body []byte
expectedStatus int
expectedBody string
}{
{
name: "invalid method GET",
method: "GET",
path: "/billing/webhooks/callback/stripe",
body: nil,
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "method not allowed",
},
{
name: "invalid method PUT",
method: "PUT",
path: "/billing/webhooks/callback/stripe",
body: []byte(`{"test":"data"}`),
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "method not allowed",
},
{
name: "invalid method DELETE",
method: "DELETE",
path: "/billing/webhooks/callback/stripe",
body: nil,
expectedStatus: http.StatusMethodNotAllowed,
expectedBody: "method not allowed",
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFrontierHandler := &mockHandler{
statusCode: http.StatusOK,
response: []byte(`{}`),
}

var body io.Reader
if tt.body != nil {
body = bytes.NewReader(tt.body)
}
req := httptest.NewRequest(tt.method, tt.path, body)
rr := httptest.NewRecorder()

bridgeHandler := WebhookBridgeHandler(mockFrontierHandler)
bridgeHandler.ServeHTTP(rr, req)

assert.Equal(t, tt.expectedStatus, rr.Code)
assert.Contains(t, rr.Body.String(), tt.expectedBody)
})
}
}

func TestWebhookBridgeHandler_PathParsing(t *testing.T) {
tests := []struct {
name string
path string
shouldBeValid bool
}{
{
name: "missing provider",
path: "/billing/webhooks/callback",
shouldBeValid: false,
},
{
name: "missing provider with trailing slash",
path: "/billing/webhooks/callback/",
shouldBeValid: false,
},
{
name: "wrong path - missing webhooks",
path: "/billing/wrong/callback/stripe",
shouldBeValid: false,
},
{
name: "wrong path - missing callback",
path: "/billing/webhooks/wrong/stripe",
shouldBeValid: false,
},
{
name: "wrong path - missing billing prefix",
path: "/v1beta1/billing/webhooks/callback/stripe",
shouldBeValid: false,
},
{
name: "completely wrong path",
path: "/v2/api/webhooks",
shouldBeValid: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockFrontierHandler := &mockHandler{
statusCode: http.StatusOK,
response: []byte(`{}`),
}
req := httptest.NewRequest("POST", tt.path, bytes.NewReader([]byte(`{}`)))
rr := httptest.NewRecorder()

bridgeHandler := WebhookBridgeHandler(mockFrontierHandler)
bridgeHandler.ServeHTTP(rr, req)

if tt.shouldBeValid {
// Should not get 404 for path issues (might get other errors from handler logic)
assert.NotEqual(t, http.StatusNotFound, rr.Code, "should not return 404 for valid path")
} else {
// Should get 404 for invalid paths
assert.Equal(t, http.StatusNotFound, rr.Code, "should return 404 for invalid path")
}
})
}
}

func TestWebhookBridgeHandler_SuccessfulRequest(t *testing.T) {
// Mock handler that verifies the request was transformed correctly
mockFrontierHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// Verify the request path was changed to ConnectRPC procedure
assert.Contains(t, r.URL.Path, "BillingWebhookCallback")

// Verify content type
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))

// Read and verify body was encoded as ConnectRPC request
body, err := io.ReadAll(r.Body)
assert.NoError(t, err)
assert.Contains(t, string(body), "provider")
assert.Contains(t, string(body), "body")

// Return success
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write([]byte(`{}`))
})

req := httptest.NewRequest("POST", "/billing/webhooks/callback/stripe", bytes.NewReader([]byte(`{"event":"test"}`)))
req.Header.Set("Stripe-Signature", "test-signature")
rr := httptest.NewRecorder()

bridgeHandler := WebhookBridgeHandler(mockFrontierHandler)
bridgeHandler.ServeHTTP(rr, req)

assert.Equal(t, http.StatusOK, rr.Code)
}
Loading