Skip to content

Commit c199df3

Browse files
committed
WIP
1 parent f324f67 commit c199df3

File tree

9 files changed

+204
-19
lines changed

9 files changed

+204
-19
lines changed

LICENSE

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
MIT License
22

3-
Copyright (c) 2021 WorkOS
3+
Copyright (c) 2024 WorkOS
44

55
Permission is hereby granted, free of charge, to any person obtaining a copy
66
of this software and associated documentation files (the "Software"), to deal

requirements-dev.txt

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
flake8
2+
pytest==8.3.2
3+
pytest-asyncio==0.23.8
4+
pytest-cov==5.0.0
5+
six==1.16.0
6+
black==24.4.2
7+
twine==5.1.1
8+
mypy==1.12.0
9+
httpx>=0.27.0

requirements.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
httpx>=0.27.0
2+
pydantic==2.9.2
3+
PyJWT==2.9.0
4+
cryptography==43.0.3

setup.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,11 @@
1010
with open(os.path.join(base_dir, "workos", "__about__.py")) as f:
1111
exec(f.read(), about)
1212

13+
def read_requirements(filename):
14+
with open(filename) as f:
15+
return [line.strip() for line in f
16+
if line.strip() and not line.startswith('#')]
17+
1318
setup(
1419
name=about["__package_name__"],
1520
version=about["__version__"],
@@ -27,19 +32,9 @@
2732
),
2833
zip_safe=False,
2934
license=about["__license__"],
30-
install_requires=["httpx>=0.27.0", "pydantic==2.9.2"],
35+
install_requires=read_requirements("requirements.txt"),
3136
extras_require={
32-
"dev": [
33-
"flake8",
34-
"pytest==8.3.2",
35-
"pytest-asyncio==0.23.8",
36-
"pytest-cov==5.0.0",
37-
"six==1.16.0",
38-
"black==24.4.2",
39-
"twine==5.1.1",
40-
"mypy==1.12.0",
41-
"httpx>=0.27.0",
42-
],
37+
"dev": read_requirements("requirements-dev.txt"),
4338
":python_version<'3.4'": ["enum34"],
4439
},
4540
classifiers=[

workos/session.py

Lines changed: 151 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,155 @@
1+
import json
2+
from typing import Any, Dict, List, Optional, Union
3+
import jwt
4+
from jwt import PyJWKClient
5+
from cryptography.fernet import Fernet
16

2-
from typing import Protocol, Union
7+
from workos.types.user_management.session import (
8+
AuthenticateWithSessionCookieFailureReason,
9+
AuthenticateWithSessionCookieSuccessResponse,
10+
AuthenticateWithSessionCookieErrorResponse,
11+
)
312

4-
from workos.types.user_management.session import AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse
13+
class SessionModule:
14+
def __init__(
15+
self,
16+
*,
17+
user_management: Any,
18+
client_id: str,
19+
session_data: str,
20+
cookie_password: str
21+
) -> None:
22+
# If the cookie password is not provided, throw an error
23+
if cookie_password is None or cookie_password == "":
24+
raise ValueError("cookie_password is required")
525

6-
class SessionModule(Protocol):
26+
self.user_management = user_management
27+
self.client_id = client_id
28+
self.session_data = session_data
29+
self.cookie_password = cookie_password
730

8-
def authenticate(self) -> Union[AuthenticateWithSessionCookieSuccessResponse, AuthenticateWithSessionCookieErrorResponse]:
9-
...
31+
self.jwks = self.create_remote_jwk_set(
32+
self.user_management.get_jwks_url()
33+
)
34+
self.jwk_algorithms = [str(key.Algorithm) for key in self.jwks]
35+
36+
for key in self.jwks:
37+
print("Key properties:", dir(key)) # This will show all available attributes
38+
print("Algorithm:", key.Algorithm)
39+
print("Key type:", key.key_type)
40+
41+
def authenticate(
42+
self,
43+
) -> Union[
44+
AuthenticateWithSessionCookieSuccessResponse,
45+
AuthenticateWithSessionCookieErrorResponse,
46+
]:
47+
if self.session_data is None:
48+
return AuthenticateWithSessionCookieErrorResponse(
49+
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.NO_SESSION_COOKIE_PROVIDED
50+
)
51+
52+
try:
53+
session = self.unseal_data(self.session_data, self.cookie_password)
54+
except Exception:
55+
return AuthenticateWithSessionCookieErrorResponse(
56+
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
57+
)
58+
59+
if not session["access_token"]:
60+
return AuthenticateWithSessionCookieErrorResponse(
61+
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
62+
)
63+
64+
if not self.is_valid_jwt(session["access_token"]):
65+
return AuthenticateWithSessionCookieErrorResponse(
66+
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_JWT
67+
)
68+
69+
decoded = jwt.decode(
70+
session["access_token"], self.jwks, algorithms=self.jwk_algorithms
71+
)
72+
73+
return AuthenticateWithSessionCookieSuccessResponse(
74+
authenticated=True,
75+
session_id=decoded["sid"],
76+
organization_id=decoded["org_id"],
77+
role=decoded["role"],
78+
permissions=decoded["permissions"],
79+
entitlements=decoded["entitlements"],
80+
user=session["user"],
81+
impersonator=session["impersonator"],
82+
reason=None,
83+
)
84+
85+
def refresh(self, options: Optional[Dict[str, Any]] = None) -> Union[
86+
AuthenticateWithSessionCookieSuccessResponse,
87+
AuthenticateWithSessionCookieErrorResponse,
88+
]:
89+
cookie_password = options.get("cookie_password", self.cookie_password)
90+
organization_id = options.get("organization_id", None)
91+
92+
try:
93+
session = self.unseal_data(self.session_data, cookie_password)
94+
except Exception:
95+
return AuthenticateWithSessionCookieErrorResponse(
96+
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
97+
)
98+
99+
if not session["refresh_token"] or not session["user"]:
100+
return AuthenticateWithSessionCookieErrorResponse(
101+
authenticated=False, reason=AuthenticateWithSessionCookieFailureReason.INVALID_SESSION_COOKIE
102+
)
103+
104+
try:
105+
auth_response = self.user_management.authenticate_with_refresh_token(
106+
refresh_token=session["refresh_token"],
107+
organization_id=organization_id,
108+
)
109+
110+
self.session_data = auth_response.sealed_session
111+
self.cookie_password = cookie_password
112+
113+
return AuthenticateWithSessionCookieSuccessResponse(
114+
authenticated=True,
115+
sealed_session=auth_response.sealed_session,
116+
session=auth_response,
117+
reason=None,
118+
)
119+
except Exception as e:
120+
return AuthenticateWithSessionCookieErrorResponse(
121+
authenticated=False, reason=str(e)
122+
)
123+
124+
def get_logout_url(self) -> str:
125+
auth_response = self.authenticate()
126+
127+
if not auth_response["authenticated"]:
128+
raise ValueError(auth_response["reason"])
129+
130+
return self.user_management.get_logout_url(
131+
session_id=auth_response["session_id"]
132+
)
133+
134+
def create_remote_jwk_set(self, url: str) -> List[Dict[str, Any]]:
135+
jwks_client = PyJWKClient(url)
136+
return jwks_client.get_signing_keys()
137+
138+
def is_valid_jwt(self, token: str) -> bool:
139+
try:
140+
jwt.decode(token, self.jwks, algorithms=self.jwk_algorithms)
141+
return True
142+
except jwt.exceptions.InvalidTokenError as error:
143+
print("invalid token", error)
144+
return False
145+
146+
@staticmethod
147+
def seal_data(data: Dict[str, Any], key: str) -> str:
148+
fernet = Fernet(key)
149+
# take the data and encrypt it with the key using fernet
150+
return fernet.encrypt(json.dumps(data).encode())
151+
152+
@staticmethod
153+
def unseal_data(sealed_data: str, key: str) -> Dict[str, Any]:
154+
fernet = Fernet(key)
155+
return json.loads(fernet.decrypt(sealed_data).decode())

workos/types/user_management/authenticate_with_common.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Literal, Union
22
from typing_extensions import TypedDict
3+
from workos.types.user_management.session import SessionConfig
34

45

56
class AuthenticateWithBaseParameters(TypedDict):
@@ -17,6 +18,7 @@ class AuthenticateWithCodeParameters(AuthenticateWithBaseParameters):
1718
code: str
1819
code_verifier: Union[str, None]
1920
grant_type: Literal["authorization_code"]
21+
session: Union[SessionConfig, None]
2022

2123

2224
class AuthenticateWithMagicAuthParameters(AuthenticateWithBaseParameters):

workos/types/user_management/authentication_response.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class AuthenticationResponse(_AuthenticationResponseBase):
2929
impersonator: Optional[Impersonator] = None
3030
organization_id: Optional[str] = None
3131
user: User
32+
sealed_session: Optional[str] = None
3233

3334

3435
class AuthKitAuthenticationResponse(AuthenticationResponse):

workos/types/user_management/session.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import List, Optional
1+
from typing import List, Optional, TypedDict
22
from enum import Enum
33

44
from workos.types.user_management.impersonator import Impersonator
@@ -23,3 +23,6 @@ class AuthenticateWithSessionCookieErrorResponse(WorkOSModel):
2323
authenticated: bool = False
2424
reason: AuthenticateWithSessionCookieFailureReason
2525

26+
class SessionConfig(TypedDict):
27+
seal_session: bool
28+
cookie_password: str

workos/user_management.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional, Protocol, Sequence, Set, Type
22
from workos._client_configuration import ClientConfiguration
3+
from workos.session import SessionModule
34
from workos.types.list_resource import (
45
ListArgs,
56
ListMetadata,
@@ -43,6 +44,7 @@
4344
UsersListFilters,
4445
)
4546
from workos.types.user_management.password_hash_type import PasswordHashType
47+
from workos.types.user_management.session import SessionConfig
4648
from workos.types.user_management.user_management_provider_type import (
4749
UserManagementProviderType,
4850
)
@@ -109,6 +111,18 @@ class UserManagementModule(Protocol):
109111

110112
_client_configuration: ClientConfiguration
111113

114+
def load_sealed_session(self, *, sealed_session: str, cookie_password: str) -> SyncOrAsync[SessionModule]:
115+
"""Load a sealed session and return the session data.
116+
117+
Args:
118+
sealed_session (str): The sealed session data to load.
119+
cookie_password (str): The cookie password to use to decrypt the session data.
120+
121+
Returns:
122+
SessionModule: The session module.
123+
"""
124+
...
125+
112126
def get_user(self, user_id: str) -> SyncOrAsync[User]:
113127
"""Get the details of an existing user.
114128
@@ -804,6 +818,9 @@ def __init__(
804818
self._client_configuration = client_configuration
805819
self._http_client = http_client
806820

821+
def load_sealed_session(self, *, session_data: str, cookie_password: str) -> SessionModule:
822+
return SessionModule(user_management=self, client_id=self._http_client.client_id, session_data=session_data, cookie_password=cookie_password)
823+
807824
def get_user(self, user_id: str) -> User:
808825
response = self._http_client.request(
809826
USER_DETAIL_PATH.format(user_id), method=REQUEST_METHOD_GET
@@ -1013,6 +1030,9 @@ def _authenticate_with(
10131030
json=json,
10141031
)
10151032

1033+
if payload["session"] is not None and payload["session"].get("seal_session") is True:
1034+
response["sealed_session"] = SessionModule.seal_data(response, payload["session"]["cookie_password"])
1035+
10161036
return response_model.model_validate(response)
10171037

10181038
def authenticate_with_password(
@@ -1037,16 +1057,21 @@ def authenticate_with_code(
10371057
self,
10381058
*,
10391059
code: str,
1060+
session: Optional[SessionConfig] = None,
10401061
code_verifier: Optional[str] = None,
10411062
ip_address: Optional[str] = None,
10421063
user_agent: Optional[str] = None,
10431064
) -> AuthKitAuthenticationResponse:
1065+
if session is not None and (session.get("seal_session") is True and session.get("cookie_password") is None or ""):
1066+
raise ValueError("cookie_password is required when sealing session")
1067+
10441068
payload: AuthenticateWithCodeParameters = {
10451069
"code": code,
10461070
"grant_type": "authorization_code",
10471071
"ip_address": ip_address,
10481072
"user_agent": user_agent,
10491073
"code_verifier": code_verifier,
1074+
"session": session,
10501075
}
10511076

10521077
return self._authenticate_with(

0 commit comments

Comments
 (0)