Skip to content

Commit bc34fdb

Browse files
fix: proper support for chat session cors (#1109)
The previous approach was missing support for OPTIONS requests, causing the preflight to break. It was impractical to add support for that due to the location the middleware was attached (after the actual CORS middleware and after routes were matched, thus excluding OPTIONS), so instead I went back to the original approach of integrating with the existing CORS middleware
1 parent b73b92d commit bc34fdb

File tree

7 files changed

+73
-65
lines changed

7 files changed

+73
-65
lines changed

server/cmd/gram/start.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ func newStartCommand() *cli.Command {
539539
ragService := rag.NewToolsetVectorStore(logger, tracerProvider, db, baseChatClient)
540540

541541
mux := goahttp.NewMuxer()
542-
mux.Use(middleware.CORSMiddleware(c.String("environment"), c.String("server-url")))
542+
mux.Use(middleware.CORSMiddleware(c.String("environment"), c.String("server-url"), chatSessionsManager))
543543
mux.Use(middleware.NewHTTPLoggingMiddleware(logger))
544544
mux.Use(customdomains.Middleware(logger, db, c.String("environment"), serverURL))
545545
mux.Use(middleware.SessionMiddleware)

server/internal/chat/impl.go

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,10 @@ func Attach(mux goahttp.Muxer, service *Service) {
9090
endpoints.Use(middleware.MapErrors())
9191
endpoints.Use(middleware.TraceMethods(service.tracer))
9292

93-
chatSessionMiddleware := middleware.ChatSessionMiddleware(service.chatSessions)
94-
9593
server := srv.New(endpoints, mux, goahttp.RequestDecoder, goahttp.ResponseEncoder, nil, nil)
96-
server.Use(chatSessionMiddleware)
9794
srv.Mount(mux, server)
9895

99-
wrappedHandler := chatSessionMiddleware(oops.ErrHandle(service.logger, service.HandleCompletion))
100-
o11y.AttachHandler(mux, "POST", "/chat/completions", wrappedHandler.ServeHTTP)
96+
o11y.AttachHandler(mux, "POST", "/chat/completions", oops.ErrHandle(service.logger, service.HandleCompletion).ServeHTTP)
10197
}
10298

10399
func (s *Service) APIKeyAuth(ctx context.Context, key string, schema *security.APIKeyScheme) (context.Context, error) {
@@ -194,7 +190,7 @@ func (s *Service) ListChats(ctx context.Context, payload *gen.ListChatsPayload)
194190
} else {
195191
chats, err := s.repo.ListChatsForUser(ctx, repo.ListChatsForUserParams{
196192
ProjectID: *authCtx.ProjectID,
197-
UserID: conv.ToPGText(authCtx.UserID),
193+
UserID: conv.ToPGText(authCtx.UserID), // TODO: make this work for external user ids (Elements)
198194
})
199195
if err != nil {
200196
return nil, oops.E(oops.CodeUnexpected, err, "failed to list chats").Log(ctx, s.logger)

server/internal/instances/impl.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,6 @@ func Attach(mux goahttp.Muxer, service *Service) {
124124
endpoints.Use(middleware.TraceMethods(service.tracer))
125125

126126
server := srv.New(endpoints, mux, goahttp.RequestDecoder, goahttp.ResponseEncoder, nil, nil)
127-
server.Use(middleware.ChatSessionMiddleware(service.chatSessions))
128127
srv.Mount(mux, server)
129128

130129
o11y.AttachHandler(mux, "POST", "/rpc/instances.invoke/tool", func(w http.ResponseWriter, r *http.Request) {

server/internal/mcp/impl.go

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ import (
2020
"github.com/go-chi/chi/v5"
2121
"github.com/google/uuid"
2222
"github.com/jackc/pgx/v5/pgxpool"
23-
"github.com/speakeasy-api/gram/server/internal/middleware"
2423
"github.com/speakeasy-api/gram/server/internal/rag"
2524
tm "github.com/speakeasy-api/gram/server/internal/telemetry"
2625
"go.opentelemetry.io/otel/metric"
@@ -175,21 +174,12 @@ func NewService(
175174
}
176175

177176
func Attach(mux goahttp.Muxer, service *Service, metadataService *mcpmetadata.Service) {
178-
chatSessionMiddleware := middleware.ChatSessionMiddleware(service.chatSessionsManager)
179-
180-
// Wraps handler functions with chat session middleware
181-
withMiddleware := func(handler http.Handler) http.Handler {
182-
return chatSessionMiddleware(handler)
183-
}
184-
185-
handler := withMiddleware(oops.ErrHandle(service.logger, service.ServePublic))
186-
o11y.AttachHandler(mux, "POST", "/mcp/{mcpSlug}", handler.ServeHTTP)
187-
188-
o11y.AttachHandler(mux, "GET", "/mcp/{mcpSlug}", withMiddleware(oops.ErrHandle(service.logger, func(w http.ResponseWriter, r *http.Request) error {
177+
o11y.AttachHandler(mux, "POST", "/mcp/{mcpSlug}", oops.ErrHandle(service.logger, service.ServePublic).ServeHTTP)
178+
o11y.AttachHandler(mux, "GET", "/mcp/{mcpSlug}", oops.ErrHandle(service.logger, func(w http.ResponseWriter, r *http.Request) error {
189179
return service.HandleGetServer(w, r, metadataService)
190-
})).ServeHTTP)
191-
o11y.AttachHandler(mux, "GET", "/mcp/{mcpSlug}/install", withMiddleware(oops.ErrHandle(service.logger, metadataService.ServeInstallPage)).ServeHTTP)
192-
o11y.AttachHandler(mux, "POST", "/mcp/{project}/{toolset}/{environment}", withMiddleware(oops.ErrHandle(service.logger, service.ServeAuthenticated)).ServeHTTP)
180+
}).ServeHTTP)
181+
o11y.AttachHandler(mux, "GET", "/mcp/{mcpSlug}/install", oops.ErrHandle(service.logger, metadataService.ServeInstallPage).ServeHTTP)
182+
o11y.AttachHandler(mux, "POST", "/mcp/{project}/{toolset}/{environment}", oops.ErrHandle(service.logger, service.ServeAuthenticated).ServeHTTP)
193183

194184
// OAuth 2.1 Authorization Server Metadata
195185
o11y.AttachHandler(mux, "GET", "/.well-known/oauth-authorization-server/mcp/{mcpSlug}", oops.ErrHandle(service.logger, service.HandleWellKnownOAuthServerMetadata).ServeHTTP)

server/internal/middleware/chat_session.go

Lines changed: 0 additions & 39 deletions
This file was deleted.
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
package middleware
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"slices"
7+
"strings"
8+
9+
"github.com/speakeasy-api/gram/server/internal/auth/chatsessions"
10+
"github.com/speakeasy-api/gram/server/internal/constants"
11+
)
12+
13+
var chatSessionsAllowedRoutes = []string{
14+
"/chat/completions",
15+
"/mcp",
16+
// "/rpc/chat", // TODO: Support listing / creating chats for elements
17+
}
18+
19+
// This isn't practical to do as a proper middleware because it needs to interoperate with the CORSMiddleware which does things like returning early for OPTIONS requests.
20+
// Instead, we combine it with the CORSMiddleware so that all CORS stuff is handled in one place.
21+
func chatSessionsCORS(chatSessionsManager *chatsessions.Manager) func(next http.Handler) http.Handler {
22+
return func(next http.Handler) http.Handler {
23+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24+
if r.Method == http.MethodOptions {
25+
// Slightly non-ideal, but later in the file we validate the origin of the request against the audience claim
26+
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) // Allow the origin of the request for OPTIONS requests because we don't know what origins to allow until we get the token on the actual request
27+
w.WriteHeader(http.StatusNoContent)
28+
return
29+
}
30+
31+
chatSession := r.Header.Get(constants.ChatSessionsTokenHeader)
32+
if chatSession == "" {
33+
next.ServeHTTP(w, r)
34+
return
35+
}
36+
37+
claims, err := chatSessionsManager.ValidateToken(r.Context(), chatSession)
38+
if err != nil {
39+
http.Error(w, "unauthorized", http.StatusUnauthorized)
40+
return
41+
}
42+
43+
// If the request origin is in the allowed origins, set the allowed origin in the context to be used in the CORS middleware
44+
if slices.Contains(claims.Audience, r.Header.Get("Origin")) {
45+
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))
46+
} else {
47+
http.Error(w, fmt.Sprintf("Origin %s does not match audience claim: %s", r.Header.Get("Origin"), strings.Join(claims.Audience, ", ")), http.StatusForbidden)
48+
return
49+
}
50+
51+
next.ServeHTTP(w, r)
52+
})
53+
}
54+
}

server/internal/middleware/cors.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ import (
55
"net/url"
66
"slices"
77
"strings"
8+
9+
"github.com/speakeasy-api/gram/server/internal/auth/chatsessions"
810
)
911

1012
var mcpOpenAccessControlRoutes = []string{
1113
"/.well-known/oauth-authorization-server/mcp",
1214
"/.well-known/oauth-protected-resource/mcp",
1315
}
1416

15-
func CORSMiddleware(env string, serverURL string) func(next http.Handler) http.Handler {
17+
func CORSMiddleware(env string, serverURL string, chatSessionsManager *chatsessions.Manager) func(next http.Handler) http.Handler {
1618
return func(next http.Handler) http.Handler {
1719
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
1820
switch env {
@@ -29,8 +31,6 @@ func CORSMiddleware(env string, serverURL string) func(next http.Handler) http.H
2931
// No CORS headers set for unspecified environments
3032
}
3133

32-
// NOTE: The chatSession middleware may also set the Access-Control-Allow-Origin header for chat sessions.
33-
3434
w.Header().Set("Access-Control-Allow-Methods", "POST, GET, OPTIONS, PUT, DELETE")
3535
w.Header().Set("Access-Control-Allow-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, Authorization, User-Agent, Gram-Session, Gram-Project, Gram-Token, idempotency-key, Gram-Admin-Override, Gram-Chat-ID, Gram-Chat-Session, MCP-Protocol-Version")
3636
w.Header().Set("Access-Control-Expose-Headers", "Accept, Content-Type, Content-Length, Accept-Encoding, x-trace-id, Gram-Session, Gram-Chat-ID, Gram-Chat-Session")
@@ -46,6 +46,14 @@ func CORSMiddleware(env string, serverURL string) func(next http.Handler) http.H
4646
w.Header().Del("Access-Control-Allow-Credentials")
4747
}
4848

49+
// Special CORS handling for chat sessions-enabled routes
50+
if slices.ContainsFunc(chatSessionsAllowedRoutes, func(route string) bool {
51+
return strings.HasPrefix(r.URL.Path, route)
52+
}) {
53+
chatSessionsCORS(chatSessionsManager)(next).ServeHTTP(w, r)
54+
return
55+
}
56+
4957
if r.Method == "OPTIONS" {
5058
w.WriteHeader(http.StatusOK)
5159
return

0 commit comments

Comments
 (0)