Skip to content

Commit 0a8f15d

Browse files
committed
Add multiple roles support
1 parent f9126a0 commit 0a8f15d

File tree

5 files changed

+140
-20
lines changed

5 files changed

+140
-20
lines changed

tests/test_session.py

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
import pytest
1+
import concurrent.futures
2+
from datetime import datetime, timezone
23
from unittest.mock import AsyncMock, Mock, patch
4+
35
import jwt
4-
from datetime import datetime, timezone
5-
import concurrent.futures
6+
import pytest
7+
from cryptography.hazmat.primitives import serialization
8+
from cryptography.hazmat.primitives.asymmetric import rsa
69

710
from tests.conftest import with_jwks_mock
811
from workos.session import AsyncSession, Session, _get_jwks_client
@@ -16,9 +19,6 @@
1619
RefreshWithSessionCookieSuccessResponse,
1720
)
1821

19-
from cryptography.hazmat.primitives import serialization
20-
from cryptography.hazmat.primitives.asymmetric import rsa
21-
2222

2323
class SessionFixtures:
2424
@pytest.fixture(autouse=True)
@@ -48,6 +48,7 @@ def session_constants(self):
4848
"sid": "session_123",
4949
"org_id": "organization_123",
5050
"role": "admin",
51+
"roles": ["admin"],
5152
"permissions": ["read"],
5253
"entitlements": ["feature_1"],
5354
"exp": int(current_datetime.timestamp()) + 3600,
@@ -215,6 +216,75 @@ def test_authenticate_success(self, session_constants, mock_user_management):
215216
"sid": session_constants["SESSION_ID"],
216217
"org_id": session_constants["ORGANIZATION_ID"],
217218
"role": "admin",
219+
"roles": ["admin"],
220+
"permissions": ["read"],
221+
"entitlements": ["feature_1"],
222+
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
223+
"iat": int(datetime.now(timezone.utc).timestamp()),
224+
},
225+
session_constants["PRIVATE_KEY"],
226+
algorithm="RS256",
227+
),
228+
"user": {
229+
"object": "user",
230+
"id": session_constants["USER_ID"],
231+
"email": "[email protected]",
232+
"email_verified": True,
233+
"created_at": session_constants["CURRENT_TIMESTAMP"],
234+
"updated_at": session_constants["CURRENT_TIMESTAMP"],
235+
},
236+
"impersonator": None,
237+
}
238+
239+
# Mock the JWT payload that would be decoded
240+
mock_jwt_payload = {
241+
"sid": session_constants["SESSION_ID"],
242+
"org_id": session_constants["ORGANIZATION_ID"],
243+
"role": "admin",
244+
"roles": ["admin"],
245+
"permissions": ["read"],
246+
"entitlements": ["feature_1"],
247+
}
248+
249+
with patch.object(Session, "unseal_data", return_value=mock_session), patch(
250+
"jwt.decode", return_value=mock_jwt_payload
251+
), patch.object(
252+
session.jwks,
253+
"get_signing_key_from_jwt",
254+
return_value=Mock(key=session_constants["PUBLIC_KEY"]),
255+
):
256+
response = session.authenticate()
257+
258+
assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse)
259+
assert response.authenticated is True
260+
assert response.session_id == session_constants["SESSION_ID"]
261+
assert response.organization_id == session_constants["ORGANIZATION_ID"]
262+
assert response.role == "admin"
263+
assert response.roles == ["admin"]
264+
assert response.permissions == ["read"]
265+
assert response.entitlements == ["feature_1"]
266+
assert response.user.id == session_constants["USER_ID"]
267+
assert response.impersonator is None
268+
269+
@with_jwks_mock
270+
def test_authenticate_success_with_roles(
271+
self, session_constants, mock_user_management
272+
):
273+
session = Session(
274+
user_management=mock_user_management,
275+
client_id=session_constants["CLIENT_ID"],
276+
session_data=session_constants["SESSION_DATA"],
277+
cookie_password=session_constants["COOKIE_PASSWORD"],
278+
)
279+
280+
# Mock the session data that would be unsealed
281+
mock_session = {
282+
"access_token": jwt.encode(
283+
{
284+
"sid": session_constants["SESSION_ID"],
285+
"org_id": session_constants["ORGANIZATION_ID"],
286+
"role": "admin",
287+
"roles": ["admin", "member"],
218288
"permissions": ["read"],
219289
"entitlements": ["feature_1"],
220290
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
@@ -239,6 +309,7 @@ def test_authenticate_success(self, session_constants, mock_user_management):
239309
"sid": session_constants["SESSION_ID"],
240310
"org_id": session_constants["ORGANIZATION_ID"],
241311
"role": "admin",
312+
"roles": ["admin", "member"],
242313
"permissions": ["read"],
243314
"entitlements": ["feature_1"],
244315
}
@@ -257,6 +328,7 @@ def test_authenticate_success(self, session_constants, mock_user_management):
257328
assert response.session_id == session_constants["SESSION_ID"]
258329
assert response.organization_id == session_constants["ORGANIZATION_ID"]
259330
assert response.role == "admin"
331+
assert response.roles == ["admin", "member"]
260332
assert response.permissions == ["read"]
261333
assert response.entitlements == ["feature_1"]
262334
assert response.user.id == session_constants["USER_ID"]
@@ -335,6 +407,7 @@ def test_refresh_success(self, session_constants, mock_user_management):
335407
"sid": session_constants["SESSION_ID"],
336408
"org_id": session_constants["ORGANIZATION_ID"],
337409
"role": "admin",
410+
"roles": ["admin"],
338411
"permissions": ["read"],
339412
"entitlements": ["feature_1"],
340413
},
@@ -435,6 +508,7 @@ async def test_refresh_success(self, session_constants, mock_user_management):
435508
"sid": session_constants["SESSION_ID"],
436509
"org_id": session_constants["ORGANIZATION_ID"],
437510
"role": "admin",
511+
"roles": ["admin"],
438512
"permissions": ["read"],
439513
"entitlements": ["feature_1"],
440514
},

workos/session.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def authenticate(
102102
session_id=decoded["sid"],
103103
organization_id=decoded.get("org_id", None),
104104
role=decoded.get("role", None),
105+
roles=decoded.get("roles", None),
105106
permissions=decoded.get("permissions", None),
106107
entitlements=decoded.get("entitlements", None),
107108
user=session["user"],
@@ -229,6 +230,7 @@ def refresh(
229230
session_id=decoded["sid"],
230231
organization_id=decoded.get("org_id", None),
231232
role=decoded.get("role", None),
233+
roles=decoded.get("roles", None),
232234
permissions=decoded.get("permissions", None),
233235
entitlements=decoded.get("entitlements", None),
234236
user=auth_response.user,
@@ -319,6 +321,7 @@ async def refresh(
319321
session_id=decoded["sid"],
320322
organization_id=decoded.get("org_id", None),
321323
role=decoded.get("role", None),
324+
roles=decoded.get("roles", None),
322325
permissions=decoded.get("permissions", None),
323326
entitlements=decoded.get("entitlements", None),
324327
user=auth_response.user,

workos/types/user_management/organization_membership.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Literal
1+
from typing import Literal, Sequence, Optional
22
from typing_extensions import TypedDict
33

44
from workos.types.workos_model import WorkOSModel
@@ -19,6 +19,7 @@ class OrganizationMembership(WorkOSModel):
1919
user_id: str
2020
organization_id: str
2121
role: OrganizationMembershipRole
22+
roles: Optional[Sequence[OrganizationMembershipRole]] = None
2223
status: LiteralOrUntyped[OrganizationMembershipStatus]
2324
created_at: str
2425
updated_at: str

workos/types/user_management/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import Optional, Sequence, TypedDict, Union
21
from enum import Enum
2+
from typing import Optional, Sequence, TypedDict, Union
3+
34
from typing_extensions import Literal
5+
46
from workos.types.user_management.impersonator import Impersonator
57
from workos.types.user_management.user import User
68
from workos.types.workos_model import WorkOSModel
@@ -17,6 +19,7 @@ class AuthenticateWithSessionCookieSuccessResponse(WorkOSModel):
1719
session_id: str
1820
organization_id: Optional[str] = None
1921
role: Optional[str] = None
22+
roles: Optional[Sequence[str]] = None
2023
permissions: Optional[Sequence[str]] = None
2124
user: User
2225
impersonator: Optional[Impersonator] = None

workos/user_management.py

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -245,30 +245,47 @@ def delete_user(self, user_id: str) -> SyncOrAsync[None]:
245245
...
246246

247247
def create_organization_membership(
248-
self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None
248+
self,
249+
*,
250+
user_id: str,
251+
organization_id: str,
252+
role_slug: Optional[str] = None,
253+
role_slugs: Optional[Sequence[str]] = None,
249254
) -> SyncOrAsync[OrganizationMembership]:
250255
"""Create a new OrganizationMembership for the given Organization and User.
251256
252257
Kwargs:
253-
user_id: The Unique ID of the User.
254-
organization_id: The Unique ID of the Organization to which the user belongs to.
255-
role_slug: The Unique Slug of the Role to which to grant to this membership.
256-
If no slug is passed in, the default role will be granted.(Optional)
258+
user_id: The unique ID of the User.
259+
organization_id: The unique ID of the Organization to which the user belongs to.
260+
role_slug: The unique slug of the role to grant to this membership.(Optional)
261+
role_slugs: The unique slugs of the roles to grant to this membership.(Optional)
262+
263+
Note:
264+
role_slug and role_slugs are mutually exclusive. If neither is provided,
265+
the user will be assigned the organization's default role.
257266
258267
Returns:
259268
OrganizationMembership: Created OrganizationMembership response from WorkOS.
260269
"""
261270
...
262271

263272
def update_organization_membership(
264-
self, *, organization_membership_id: str, role_slug: Optional[str] = None
273+
self,
274+
*,
275+
organization_membership_id: str,
276+
role_slug: Optional[str] = None,
277+
role_slugs: Optional[Sequence[str]] = None,
265278
) -> SyncOrAsync[OrganizationMembership]:
266279
"""Updates an OrganizationMembership for the given id.
267280
268281
Args:
269282
organization_membership_id (str): The unique ID of the Organization Membership.
270-
role_slug: The Unique Slug of the Role to which to grant to this membership.
271-
If no slug is passed in, it will not be changed (Optional)
283+
role_slug: The unique slug of the role to grant to this membership.(Optional)
284+
role_slugs: The unique slugs of the roles to grant to this membership.(Optional)
285+
286+
Note:
287+
role_slug and role_slugs are mutually exclusive. If neither is provided,
288+
the user will be assigned the organization's default role.
272289
273290
Returns:
274291
OrganizationMembership: Updated OrganizationMembership response from WorkOS.
@@ -988,12 +1005,18 @@ def delete_user(self, user_id: str) -> None:
9881005
)
9891006

9901007
def create_organization_membership(
991-
self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None
1008+
self,
1009+
*,
1010+
user_id: str,
1011+
organization_id: str,
1012+
role_slug: Optional[str] = None,
1013+
role_slugs: Optional[Sequence[str]] = None,
9921014
) -> OrganizationMembership:
9931015
json = {
9941016
"user_id": user_id,
9951017
"organization_id": organization_id,
9961018
"role_slug": role_slug,
1019+
"role_slugs": role_slugs,
9971020
}
9981021

9991022
response = self._http_client.request(
@@ -1003,10 +1026,15 @@ def create_organization_membership(
10031026
return OrganizationMembership.model_validate(response)
10041027

10051028
def update_organization_membership(
1006-
self, *, organization_membership_id: str, role_slug: Optional[str] = None
1029+
self,
1030+
*,
1031+
organization_membership_id: str,
1032+
role_slug: Optional[str] = None,
1033+
role_slugs: Optional[Sequence[str]] = None,
10071034
) -> OrganizationMembership:
10081035
json = {
10091036
"role_slug": role_slug,
1037+
"role_slugs": role_slugs,
10101038
}
10111039

10121040
response = self._http_client.request(
@@ -1614,12 +1642,18 @@ async def delete_user(self, user_id: str) -> None:
16141642
)
16151643

16161644
async def create_organization_membership(
1617-
self, *, user_id: str, organization_id: str, role_slug: Optional[str] = None
1645+
self,
1646+
*,
1647+
user_id: str,
1648+
organization_id: str,
1649+
role_slug: Optional[str] = None,
1650+
role_slugs: Optional[Sequence[str]] = None,
16181651
) -> OrganizationMembership:
16191652
json = {
16201653
"user_id": user_id,
16211654
"organization_id": organization_id,
16221655
"role_slug": role_slug,
1656+
"role_slugs": role_slugs,
16231657
}
16241658

16251659
response = await self._http_client.request(
@@ -1629,10 +1663,15 @@ async def create_organization_membership(
16291663
return OrganizationMembership.model_validate(response)
16301664

16311665
async def update_organization_membership(
1632-
self, *, organization_membership_id: str, role_slug: Optional[str] = None
1666+
self,
1667+
*,
1668+
organization_membership_id: str,
1669+
role_slug: Optional[str] = None,
1670+
role_slugs: Optional[Sequence[str]] = None,
16331671
) -> OrganizationMembership:
16341672
json = {
16351673
"role_slug": role_slug,
1674+
"role_slugs": role_slugs,
16361675
}
16371676

16381677
response = await self._http_client.request(

0 commit comments

Comments
 (0)