Skip to content

Commit 2fccb2e

Browse files
authored
Merge pull request #30 from signnow/fix/cors-expose-headers-preflight
Expose Mcp-Session-Id header in CORS preflight responses
2 parents 90e7646 + 9d7e4d1 commit 2fccb2e

File tree

8 files changed

+221
-115
lines changed

8 files changed

+221
-115
lines changed

src/signnow_client/models/folders_lite.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,11 @@ def _normalize_folder_type_value(value: Any) -> Any:
4444
return value
4545

4646

47+
def _normalize_to_unknown(_: object) -> str:
48+
"""Always return 'unknown' — used as BeforeValidator for UnknownFolderDocLite.type."""
49+
return "unknown"
50+
51+
4752
def _normalize_roles(value: Any) -> list[str] | None:
4853
"""Normalize roles to list[str] format.
4954
@@ -171,7 +176,6 @@ class DocumentGroupDocumentLite(SNBaseModel):
171176
DocTypeTemplate = Annotated[Literal["template"], BeforeValidator(_normalize_folder_type_value)]
172177
DocTypeDocGroup = Annotated[Literal["document-group"], BeforeValidator(_normalize_folder_type_value)]
173178
DocTypeDgt = Annotated[Literal["dgt"], BeforeValidator(_normalize_folder_type_value)]
174-
DocTypeUnknown = Annotated[Literal["unknown"], BeforeValidator(_normalize_folder_type_value)]
175179

176180

177181
class DocumentItemLite(SNBaseModel):
@@ -272,7 +276,7 @@ class DocumentGroupTemplateItemLite(SNBaseModel):
272276
class UnknownFolderDocLite(SNBaseModel):
273277
"""Fallback model for folder items with unknown or missing type/entity_type."""
274278

275-
type: DocTypeUnknown = Field(..., validation_alias=AliasChoices("type", "entity_type"))
279+
type: Annotated[Literal["unknown"], BeforeValidator(_normalize_to_unknown)] = Field("unknown", validation_alias=AliasChoices("type", "entity_type"))
276280

277281
id: str
278282
user_id: str | None = None

src/sn_mcp_server/app.py

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from collections.abc import Sequence
2+
from typing import Any
3+
14
from fastmcp.server.http import create_sse_app
25
from starlette.applications import Starlette
36
from starlette.middleware.cors import CORSMiddleware
47
from starlette.routing import Mount, Route
8+
from starlette.types import ASGIApp
59

610
from .auth import (
711
BearerJWTASGIMiddleware,
@@ -11,6 +15,20 @@
1115
from .config import load_settings
1216

1317

18+
class _CORSMiddlewareWithExposeInPreflight(CORSMiddleware):
19+
"""
20+
Starlette's CORSMiddleware adds Access-Control-Expose-Headers only to
21+
actual (non-preflight) responses. This subclass also injects it into
22+
the preflight response headers so that clients such as Claude's MCP
23+
web client can read Mcp-Session-Id after the OPTIONS handshake.
24+
"""
25+
26+
def __init__(self, app: ASGIApp, expose_headers: Sequence[str] = (), **kwargs: Any) -> None: # noqa: ANN401
27+
super().__init__(app, expose_headers=expose_headers, **kwargs)
28+
if expose_headers:
29+
self.preflight_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)
30+
31+
1432
def create_http_app() -> Starlette:
1533
"""Create and configure Starlette HTTP application with MCP endpoints"""
1634
# ============= CONFIG =============
@@ -41,11 +59,11 @@ def create_http_app() -> Starlette:
4159

4260
# CORS for browser clients (inspector) - BEFORE BearerJWTMiddleware
4361
app.add_middleware(
44-
CORSMiddleware,
62+
_CORSMiddlewareWithExposeInPreflight,
4563
allow_origins=["*"],
4664
allow_methods=["*"],
4765
allow_headers=["*"],
48-
expose_headers=["*"],
66+
expose_headers=["Mcp-Session-Id"],
4967
allow_credentials=True,
5068
)
5169

src/sn_mcp_server/tools/list_templates.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ async def _list_all_templates(ctx: Context, token: str, client: SignNowAPIClient
7878
roles=role_names,
7979
)
8080
)
81-
except (ValueError, KeyError, AttributeError):
81+
except Exception: # noqa: S110
8282
# Skip root folder if it can't be accessed
8383
pass
8484

@@ -120,9 +120,8 @@ async def _list_all_templates(ctx: Context, token: str, client: SignNowAPIClient
120120
roles=role_names,
121121
)
122122
)
123-
except (ValueError, KeyError, AttributeError):
124-
# Skip folders that can't be accessed
125-
# Log specific error types but continue processing other folders
123+
except Exception: # noqa: S112
124+
# Skip folders that can't be accessed; continue processing remaining folders
126125
continue
127126

128127
# Get template groups

tests/unit/signnow_client/test_invite_from_alias.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,17 @@
1212

1313

1414
class _DummyDocumentClient(DocumentClientMixin):
15-
def __init__(self):
15+
def __init__(self) -> None:
1616
self.last_post = None
1717

18-
def _post(self, url: str, headers=None, data=None, json_data=None, validate_model=None):
18+
def _post( # noqa: ANN202
19+
self,
20+
url: str,
21+
headers: dict | None = None,
22+
data: dict | None = None,
23+
json_data: dict | None = None,
24+
validate_model: type | None = None,
25+
) -> object:
1926
self.last_post = {"url": url, "headers": headers, "data": data, "json_data": json_data, "validate_model": validate_model}
2027
if validate_model is None:
2128
return None
@@ -27,7 +34,7 @@ def _post(self, url: str, headers=None, data=None, json_data=None, validate_mode
2734
return validate_model.model_validate({})
2835

2936

30-
def test_create_document_field_invite_request_accepts_from_field_name_and_serializes_alias():
37+
def test_create_document_field_invite_request_accepts_from_field_name_and_serializes_alias() -> None:
3138
req = CreateDocumentFieldInviteRequest(document_id="doc123", to=[], from_="sample-apps@signnow.com")
3239
assert req.from_ == "sample-apps@signnow.com"
3340

@@ -36,7 +43,7 @@ def test_create_document_field_invite_request_accepts_from_field_name_and_serial
3643
assert "from_" not in dumped
3744

3845

39-
def test_create_document_freeform_invite_request_accepts_from_field_name_and_serializes_alias():
46+
def test_create_document_freeform_invite_request_accepts_from_field_name_and_serializes_alias() -> None:
4047
req = CreateDocumentFreeformInviteRequest(to="signer@example.com", from_="sample-apps@signnow.com")
4148
assert req.from_ == "sample-apps@signnow.com"
4249

@@ -45,19 +52,19 @@ def test_create_document_freeform_invite_request_accepts_from_field_name_and_ser
4552
assert "from_" not in dumped
4653

4754

48-
def test_client_document_field_invite_uses_by_alias_true_when_dumping():
55+
def test_client_document_field_invite_uses_by_alias_true_when_dumping() -> None:
4956
client = _DummyDocumentClient()
5057
request_data = Mock()
5158
request_data.model_dump.return_value = {"from": "sample-apps@signnow.com"}
5259

53-
client.create_document_field_invite(token="t", document_id="doc123", request_data=request_data)
60+
client.create_document_field_invite(token="t", document_id="doc123", request_data=request_data) # noqa: S106
5461
request_data.model_dump.assert_called_once_with(exclude_none=True, by_alias=True)
5562

5663

57-
def test_client_document_freeform_invite_uses_by_alias_true_when_dumping():
64+
def test_client_document_freeform_invite_uses_by_alias_true_when_dumping() -> None:
5865
client = _DummyDocumentClient()
5966
request_data = Mock()
6067
request_data.model_dump.return_value = {"from": "sample-apps@signnow.com"}
6168

62-
client.create_document_freeform_invite(token="t", document_id="doc123", request_data=request_data)
69+
client.create_document_freeform_invite(token="t", document_id="doc123", request_data=request_data) # noqa: S106
6370
request_data.model_dump.assert_called_once_with(exclude_none=True, by_alias=True)
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
"""Unit tests for _CORSMiddlewareWithExposeInPreflight in app module."""
2+
3+
from starlette.applications import Starlette
4+
from starlette.requests import Request
5+
from starlette.responses import PlainTextResponse
6+
from starlette.testclient import TestClient
7+
8+
from sn_mcp_server.app import _CORSMiddlewareWithExposeInPreflight
9+
10+
_ORIGIN = "https://claude.ai"
11+
_PREFLIGHT_HEADERS = {
12+
"Origin": _ORIGIN,
13+
"Access-Control-Request-Method": "POST",
14+
"Access-Control-Request-Headers": "content-type,authorization,mcp-session-id",
15+
}
16+
17+
18+
def _make_app(expose_headers: list[str]) -> TestClient:
19+
async def _handler(request: Request) -> PlainTextResponse:
20+
return PlainTextResponse("OK")
21+
22+
app = Starlette()
23+
app.add_middleware(
24+
_CORSMiddlewareWithExposeInPreflight,
25+
allow_origins=["*"],
26+
allow_methods=["*"],
27+
allow_headers=["*"],
28+
expose_headers=expose_headers,
29+
allow_credentials=True,
30+
)
31+
app.add_route("/mcp", _handler, methods=["OPTIONS", "POST", "GET"])
32+
return TestClient(app, raise_server_exceptions=True)
33+
34+
35+
class TestCORSExposeHeadersPreflight:
36+
"""Regression tests for Access-Control-Expose-Headers in OPTIONS preflight responses."""
37+
38+
def test_preflight_includes_expose_header(self) -> None:
39+
"""OPTIONS response must carry Access-Control-Expose-Headers: Mcp-Session-Id."""
40+
client = _make_app(["Mcp-Session-Id"])
41+
resp = client.options("/mcp", headers=_PREFLIGHT_HEADERS)
42+
assert resp.status_code == 200
43+
assert resp.headers.get("access-control-expose-headers") == "Mcp-Session-Id"
44+
45+
def test_actual_response_includes_expose_header(self) -> None:
46+
"""Non-preflight POST response must also carry Access-Control-Expose-Headers."""
47+
client = _make_app(["Mcp-Session-Id"])
48+
resp = client.post("/mcp", headers={"Origin": _ORIGIN})
49+
assert resp.headers.get("access-control-expose-headers") == "Mcp-Session-Id"
50+
51+
def test_expose_header_absent_when_empty_list(self) -> None:
52+
"""When expose_headers=[], Access-Control-Expose-Headers must not be present."""
53+
client = _make_app([])
54+
resp = client.options("/mcp", headers=_PREFLIGHT_HEADERS)
55+
assert "access-control-expose-headers" not in resp.headers
56+
57+
def test_preflight_echoes_request_origin_with_credentials(self) -> None:
58+
"""With allow_credentials=True the preflight must echo the exact request origin."""
59+
client = _make_app(["Mcp-Session-Id"])
60+
resp = client.options("/mcp", headers=_PREFLIGHT_HEADERS)
61+
assert resp.headers.get("access-control-allow-origin") == _ORIGIN
62+
assert resp.headers.get("access-control-allow-credentials") == "true"
63+
64+
def test_preflight_multiple_expose_headers(self) -> None:
65+
"""Multiple expose_headers values are joined with ', ' in the preflight."""
66+
client = _make_app(["Mcp-Session-Id", "X-Custom-Header"])
67+
resp = client.options("/mcp", headers=_PREFLIGHT_HEADERS)
68+
expose = resp.headers.get("access-control-expose-headers", "")
69+
assert "Mcp-Session-Id" in expose
70+
assert "X-Custom-Header" in expose
71+
72+
def test_preflight_returns_200(self) -> None:
73+
"""Preflight response status must be 200 OK."""
74+
client = _make_app(["Mcp-Session-Id"])
75+
resp = client.options("/mcp", headers=_PREFLIGHT_HEADERS)
76+
assert resp.status_code == 200
77+
78+
def test_max_age_present_in_preflight(self) -> None:
79+
"""access-control-max-age must be set in preflight response."""
80+
client = _make_app(["Mcp-Session-Id"])
81+
resp = client.options("/mcp", headers=_PREFLIGHT_HEADERS)
82+
assert resp.headers.get("access-control-max-age") == "600"

tests/unit/sn_mcp_server/tools/test_expiration.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,13 @@
55
import pytest
66

77
from signnow_client.models.document_groups import DocumentGroupV2FieldInvite
8-
98
from sn_mcp_server.tools.models import SimplifiedInviteParticipant
109

1110

1211
class TestExpirationHandling:
1312
"""Test cases for optional expiration_time and expiration_days."""
1413

15-
def test_simplified_invite_participant_from_field_invite_with_expiration(self):
14+
def test_simplified_invite_participant_from_field_invite_with_expiration(self) -> None:
1615
"""Test SimplifiedInviteParticipant.from_document_group_v2_field_invite with expiration."""
1716
now = 1500000000
1817
field_invite = DocumentGroupV2FieldInvite(
@@ -23,7 +22,7 @@ def test_simplified_invite_participant_from_field_invite_with_expiration(self):
2322
expiration_time=2000000000, # Future expiration
2423
expiration_days=30,
2524
signer_email="test@example.com",
26-
password_protected="0",
25+
password_protected="0", # noqa: S106
2726
email_statuses=[],
2827
)
2928
participant = SimplifiedInviteParticipant.from_document_group_v2_field_invite(field_invite, now)
@@ -32,7 +31,7 @@ def test_simplified_invite_participant_from_field_invite_with_expiration(self):
3231
assert participant.expired is False # Not expired yet
3332
assert participant.status == "pending"
3433

35-
def test_simplified_invite_participant_from_field_invite_without_expiration(self):
34+
def test_simplified_invite_participant_from_field_invite_without_expiration(self) -> None:
3635
"""Test SimplifiedInviteParticipant.from_document_group_v2_field_invite without expiration."""
3736
now = 1500000000
3837
field_invite = DocumentGroupV2FieldInvite(
@@ -43,7 +42,7 @@ def test_simplified_invite_participant_from_field_invite_without_expiration(self
4342
expiration_time=None, # No expiration
4443
expiration_days=None,
4544
signer_email="test@example.com",
46-
password_protected="0",
45+
password_protected="0", # noqa: S106
4746
email_statuses=[],
4847
)
4948
participant = SimplifiedInviteParticipant.from_document_group_v2_field_invite(field_invite, now)
@@ -52,7 +51,7 @@ def test_simplified_invite_participant_from_field_invite_without_expiration(self
5251
assert participant.expired is False # No expiration means not expired
5352
assert participant.status == "pending"
5453

55-
def test_simplified_invite_participant_from_field_invite_expired(self):
54+
def test_simplified_invite_participant_from_field_invite_expired(self) -> None:
5655
"""Test SimplifiedInviteParticipant.from_document_group_v2_field_invite with expired invite."""
5756
now = 2500000000 # After expiration
5857
field_invite = DocumentGroupV2FieldInvite(
@@ -63,7 +62,7 @@ def test_simplified_invite_participant_from_field_invite_expired(self):
6362
expiration_time=2000000000, # Past expiration
6463
expiration_days=30,
6564
signer_email="test@example.com",
66-
password_protected="0",
65+
password_protected="0", # noqa: S106
6766
email_statuses=[],
6867
)
6968
participant = SimplifiedInviteParticipant.from_document_group_v2_field_invite(field_invite, now)
@@ -72,7 +71,7 @@ def test_simplified_invite_participant_from_field_invite_expired(self):
7271
assert participant.expired is True # Expired
7372
assert participant.status == "expired"
7473

75-
def test_simplified_invite_participant_from_field_invite_expired_status(self):
74+
def test_simplified_invite_participant_from_field_invite_expired_status(self) -> None:
7675
"""Test SimplifiedInviteParticipant.from_document_group_v2_field_invite with expired status."""
7776
now = 1500000000
7877
field_invite = DocumentGroupV2FieldInvite(
@@ -83,7 +82,7 @@ def test_simplified_invite_participant_from_field_invite_expired_status(self):
8382
expiration_time=2000000000,
8483
expiration_days=30,
8584
signer_email="test@example.com",
86-
password_protected="0",
85+
password_protected="0", # noqa: S106
8786
email_statuses=[],
8887
)
8988
participant = SimplifiedInviteParticipant.from_document_group_v2_field_invite(field_invite, now)
@@ -99,7 +98,7 @@ def test_simplified_invite_participant_from_field_invite_expired_status(self):
9998
("signed", 2000000000, 2500000000, False), # Past expiration but signed status (not in PENDING)
10099
],
101100
)
102-
def test_check_expired(self, status: str, expires_at: int | None, now: int, expected: bool):
101+
def test_check_expired(self, status: str, expires_at: int | None, now: int, expected: bool) -> None:
103102
"""Test check_expired method with various status, expiration, and time combinations."""
104103
result = SimplifiedInviteParticipant.check_expired(status, expires_at, now)
105104
assert result == expected

0 commit comments

Comments
 (0)