Skip to content

Commit d76e2cd

Browse files
authored
feat: add webhook bridge handler for Stripe callbacks (#1334)
1 parent ec3b49f commit d76e2cd

File tree

5 files changed

+287
-5
lines changed

5 files changed

+287
-5
lines changed

internal/api/v1beta1connect/billing_webhook.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@ import (
44
"context"
55

66
"connectrpc.com/connect"
7+
"github.com/raystack/frontier/billing/customer"
78
"github.com/raystack/frontier/core/event"
9+
"github.com/raystack/frontier/pkg/server/consts"
810
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
911
"go.uber.org/zap"
1012
)
@@ -16,6 +18,12 @@ func (h *ConnectHandler) BillingWebhookCallback(ctx context.Context, request *co
1618
return nil, connect.NewError(connect.CodeInvalidArgument, ErrBillingProviderNotSupported)
1719
}
1820

21+
// Extract Stripe webhook signature from headers and add to context
22+
// This is required for webhook signature verification in the event service
23+
if webhookSignature := request.Header().Get(consts.StripeWebhookSignature); webhookSignature != "" {
24+
ctx = customer.SetStripeWebhookSignatureInContext(ctx, webhookSignature)
25+
}
26+
1927
if err := h.eventService.BillingWebhook(ctx, event.ProviderWebhookEvent{
2028
Name: request.Msg.GetProvider(),
2129
Body: request.Msg.GetBody(),

pkg/server/connect_interceptors/authentication.go

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,10 @@ func (i *AuthenticationInterceptor) WrapStreamingHandler(next connect.StreamingH
9191

9292
// authenticationSkipList stores path to skip authentication, by default its enabled for all requests
9393
var authenticationSkipList = map[string]bool{
94-
"/raystack.frontier.v1beta1.FrontierService/ListAuthStrategies": true,
95-
"/raystack.frontier.v1beta1.FrontierService/Authenticate": true,
96-
"/raystack.frontier.v1beta1.FrontierService/AuthCallback": true,
97-
"/raystack.frontier.v1beta1.FrontierService/ListMetaSchemas": true,
98-
"/raystack.frontier.v1beta1.FrontierService/GetMetaSchema": true,
94+
"/raystack.frontier.v1beta1.FrontierService/ListAuthStrategies": true,
95+
"/raystack.frontier.v1beta1.FrontierService/Authenticate": true,
96+
"/raystack.frontier.v1beta1.FrontierService/AuthCallback": true,
97+
"/raystack.frontier.v1beta1.FrontierService/ListMetaSchemas": true,
98+
"/raystack.frontier.v1beta1.FrontierService/GetMetaSchema": true,
99+
"/raystack.frontier.v1beta1.FrontierService/BillingWebhookCallback": true,
99100
}

pkg/server/server.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,10 @@ func ServeConnect(ctx context.Context, logger log.Logger, cfg Config, deps api.D
216216
mux := http.NewServeMux()
217217
mux.Handle(frontierPath, frontierHandler)
218218
mux.Handle(adminPath, adminHandler)
219+
220+
// Register webhook bridge handler to allow Stripe to call with provider in path
221+
// This uses frontierHandler which has all interceptors (auth, logging, audit, etc.) applied
222+
mux.HandleFunc("/billing/webhooks/callback/", WebhookBridgeHandler(frontierHandler))
219223
reflector := grpcreflect.NewStaticReflector(
220224
"raystack.frontier.v1beta1.FrontierService",
221225
"raystack.frontier.v1beta1.AdminService") // protoc-gen-connect-go generates package-level constants

pkg/server/webhook_bridge.go

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package server
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/http/httptest"
10+
"strings"
11+
12+
"github.com/raystack/frontier/pkg/server/consts"
13+
frontierv1beta1 "github.com/raystack/frontier/proto/v1beta1"
14+
frontierv1beta1connect "github.com/raystack/frontier/proto/v1beta1/frontierv1beta1connect"
15+
)
16+
17+
// WebhookBridgeHandler creates an HTTP handler that bridges raw HTTP webhook requests
18+
// to the ConnectRPC BillingWebhookCallback handler. This is needed because Stripe
19+
// doesn't allow modifying the request body but allows custom URL paths.
20+
// The handler uses the frontierHandler which has all interceptors (auth, logging, audit, etc.) applied.
21+
func WebhookBridgeHandler(frontierHandler http.Handler) http.HandlerFunc {
22+
return func(w http.ResponseWriter, r *http.Request) {
23+
// Only accept POST requests
24+
if r.Method != http.MethodPost {
25+
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
26+
return
27+
}
28+
29+
// Extract provider from URL path
30+
// Expected path: /billing/webhooks/callback/{provider}
31+
pathParts := strings.Split(strings.Trim(r.URL.Path, "/"), "/")
32+
if len(pathParts) < 4 || pathParts[0] != "billing" ||
33+
pathParts[1] != "webhooks" || pathParts[2] != "callback" {
34+
http.Error(w, "invalid path", http.StatusNotFound)
35+
return
36+
}
37+
provider := pathParts[3]
38+
if provider == "" {
39+
http.Error(w, "invalid path", http.StatusNotFound)
40+
return
41+
}
42+
43+
// Read raw request body
44+
body, err := io.ReadAll(r.Body)
45+
if err != nil {
46+
http.Error(w, fmt.Sprintf("failed to read request body: %v", err), http.StatusBadRequest)
47+
return
48+
}
49+
defer r.Body.Close()
50+
51+
// Create ConnectRPC request payload
52+
requestPayload := &frontierv1beta1.BillingWebhookCallbackRequest{
53+
Provider: provider,
54+
Body: body,
55+
}
56+
57+
// Encode request as JSON for ConnectRPC
58+
requestJSON, err := json.Marshal(requestPayload)
59+
if err != nil {
60+
http.Error(w, fmt.Sprintf("failed to encode request: %v", err), http.StatusInternalServerError)
61+
return
62+
}
63+
64+
// Create a new HTTP request to the ConnectRPC procedure
65+
connectReq := httptest.NewRequest(
66+
http.MethodPost,
67+
frontierv1beta1connect.FrontierServiceBillingWebhookCallbackProcedure,
68+
bytes.NewReader(requestJSON),
69+
)
70+
connectReq = connectReq.WithContext(r.Context())
71+
72+
// Copy important headers from the original request
73+
if contentType := r.Header.Get("Content-Type"); contentType != "" {
74+
connectReq.Header.Set("X-Original-Content-Type", contentType)
75+
}
76+
// Set ConnectRPC content type
77+
connectReq.Header.Set("Content-Type", "application/json")
78+
79+
// Copy other important headers (auth, request ID, etc.)
80+
headersToProxy := []string{
81+
"Authorization",
82+
"Cookie",
83+
consts.RequestIDHeader,
84+
consts.ProjectRequestKey,
85+
consts.StripeTestClockRequestKey,
86+
consts.StripeWebhookSignature,
87+
}
88+
for _, header := range headersToProxy {
89+
if value := r.Header.Get(header); value != "" {
90+
connectReq.Header.Set(header, value)
91+
}
92+
}
93+
94+
// Forward to the ConnectRPC handler (which has all interceptors)
95+
frontierHandler.ServeHTTP(w, connectReq)
96+
}
97+
}

pkg/server/webhook_bridge_test.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
package server
2+
3+
import (
4+
"bytes"
5+
"io"
6+
"net/http"
7+
"net/http/httptest"
8+
"testing"
9+
10+
"github.com/stretchr/testify/assert"
11+
)
12+
13+
// mockHandler is a simple mock HTTP handler for testing
14+
type mockHandler struct {
15+
statusCode int
16+
response []byte
17+
}
18+
19+
func (m *mockHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
20+
w.WriteHeader(m.statusCode)
21+
w.Write(m.response)
22+
}
23+
24+
func TestWebhookBridgeHandler_HTTPMethods(t *testing.T) {
25+
tests := []struct {
26+
name string
27+
method string
28+
path string
29+
body []byte
30+
expectedStatus int
31+
expectedBody string
32+
}{
33+
{
34+
name: "invalid method GET",
35+
method: "GET",
36+
path: "/billing/webhooks/callback/stripe",
37+
body: nil,
38+
expectedStatus: http.StatusMethodNotAllowed,
39+
expectedBody: "method not allowed",
40+
},
41+
{
42+
name: "invalid method PUT",
43+
method: "PUT",
44+
path: "/billing/webhooks/callback/stripe",
45+
body: []byte(`{"test":"data"}`),
46+
expectedStatus: http.StatusMethodNotAllowed,
47+
expectedBody: "method not allowed",
48+
},
49+
{
50+
name: "invalid method DELETE",
51+
method: "DELETE",
52+
path: "/billing/webhooks/callback/stripe",
53+
body: nil,
54+
expectedStatus: http.StatusMethodNotAllowed,
55+
expectedBody: "method not allowed",
56+
},
57+
}
58+
59+
for _, tt := range tests {
60+
t.Run(tt.name, func(t *testing.T) {
61+
mockFrontierHandler := &mockHandler{
62+
statusCode: http.StatusOK,
63+
response: []byte(`{}`),
64+
}
65+
66+
var body io.Reader
67+
if tt.body != nil {
68+
body = bytes.NewReader(tt.body)
69+
}
70+
req := httptest.NewRequest(tt.method, tt.path, body)
71+
rr := httptest.NewRecorder()
72+
73+
bridgeHandler := WebhookBridgeHandler(mockFrontierHandler)
74+
bridgeHandler.ServeHTTP(rr, req)
75+
76+
assert.Equal(t, tt.expectedStatus, rr.Code)
77+
assert.Contains(t, rr.Body.String(), tt.expectedBody)
78+
})
79+
}
80+
}
81+
82+
func TestWebhookBridgeHandler_PathParsing(t *testing.T) {
83+
tests := []struct {
84+
name string
85+
path string
86+
shouldBeValid bool
87+
}{
88+
{
89+
name: "missing provider",
90+
path: "/billing/webhooks/callback",
91+
shouldBeValid: false,
92+
},
93+
{
94+
name: "missing provider with trailing slash",
95+
path: "/billing/webhooks/callback/",
96+
shouldBeValid: false,
97+
},
98+
{
99+
name: "wrong path - missing webhooks",
100+
path: "/billing/wrong/callback/stripe",
101+
shouldBeValid: false,
102+
},
103+
{
104+
name: "wrong path - missing callback",
105+
path: "/billing/webhooks/wrong/stripe",
106+
shouldBeValid: false,
107+
},
108+
{
109+
name: "wrong path - missing billing prefix",
110+
path: "/v1beta1/billing/webhooks/callback/stripe",
111+
shouldBeValid: false,
112+
},
113+
{
114+
name: "completely wrong path",
115+
path: "/v2/api/webhooks",
116+
shouldBeValid: false,
117+
},
118+
}
119+
120+
for _, tt := range tests {
121+
t.Run(tt.name, func(t *testing.T) {
122+
mockFrontierHandler := &mockHandler{
123+
statusCode: http.StatusOK,
124+
response: []byte(`{}`),
125+
}
126+
req := httptest.NewRequest("POST", tt.path, bytes.NewReader([]byte(`{}`)))
127+
rr := httptest.NewRecorder()
128+
129+
bridgeHandler := WebhookBridgeHandler(mockFrontierHandler)
130+
bridgeHandler.ServeHTTP(rr, req)
131+
132+
if tt.shouldBeValid {
133+
// Should not get 404 for path issues (might get other errors from handler logic)
134+
assert.NotEqual(t, http.StatusNotFound, rr.Code, "should not return 404 for valid path")
135+
} else {
136+
// Should get 404 for invalid paths
137+
assert.Equal(t, http.StatusNotFound, rr.Code, "should return 404 for invalid path")
138+
}
139+
})
140+
}
141+
}
142+
143+
func TestWebhookBridgeHandler_SuccessfulRequest(t *testing.T) {
144+
// Mock handler that verifies the request was transformed correctly
145+
mockFrontierHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
146+
// Verify the request path was changed to ConnectRPC procedure
147+
assert.Contains(t, r.URL.Path, "BillingWebhookCallback")
148+
149+
// Verify content type
150+
assert.Equal(t, "application/json", r.Header.Get("Content-Type"))
151+
152+
// Read and verify body was encoded as ConnectRPC request
153+
body, err := io.ReadAll(r.Body)
154+
assert.NoError(t, err)
155+
assert.Contains(t, string(body), "provider")
156+
assert.Contains(t, string(body), "body")
157+
158+
// Return success
159+
w.Header().Set("Content-Type", "application/json")
160+
w.WriteHeader(http.StatusOK)
161+
w.Write([]byte(`{}`))
162+
})
163+
164+
req := httptest.NewRequest("POST", "/billing/webhooks/callback/stripe", bytes.NewReader([]byte(`{"event":"test"}`)))
165+
req.Header.Set("Stripe-Signature", "test-signature")
166+
rr := httptest.NewRecorder()
167+
168+
bridgeHandler := WebhookBridgeHandler(mockFrontierHandler)
169+
bridgeHandler.ServeHTTP(rr, req)
170+
171+
assert.Equal(t, http.StatusOK, rr.Code)
172+
}

0 commit comments

Comments
 (0)