1+ import json
2+ from typing import Any , Dict , List , Optional , Union
3+ import jwt
4+ from jwt import PyJWKClient
5+ from cryptography .fernet import Fernet
16
2- from typing import Protocol , Union
7+ from workos .types .user_management .session import (
8+ AuthenticateWithSessionCookieFailureReason ,
9+ AuthenticateWithSessionCookieSuccessResponse ,
10+ AuthenticateWithSessionCookieErrorResponse ,
11+ )
312
4- from workos .types .user_management .session import AuthenticateWithSessionCookieSuccessResponse , AuthenticateWithSessionCookieErrorResponse
13+ class SessionModule :
14+ def __init__ (
15+ self ,
16+ * ,
17+ user_management : Any ,
18+ client_id : str ,
19+ session_data : str ,
20+ cookie_password : str
21+ ) -> None :
22+ # If the cookie password is not provided, throw an error
23+ if cookie_password is None or cookie_password == "" :
24+ raise ValueError ("cookie_password is required" )
525
6- class SessionModule (Protocol ):
26+ self .user_management = user_management
27+ self .client_id = client_id
28+ self .session_data = session_data
29+ self .cookie_password = cookie_password
730
8- def authenticate (self ) -> Union [AuthenticateWithSessionCookieSuccessResponse , AuthenticateWithSessionCookieErrorResponse ]:
9- ...
31+ self .jwks = self .create_remote_jwk_set (
32+ self .user_management .get_jwks_url ()
33+ )
34+ self .jwk_algorithms = [str (key .Algorithm ) for key in self .jwks ]
35+
36+ for key in self .jwks :
37+ print ("Key properties:" , dir (key )) # This will show all available attributes
38+ print ("Algorithm:" , key .Algorithm )
39+ print ("Key type:" , key .key_type )
40+
41+ def authenticate (
42+ self ,
43+ ) -> Union [
44+ AuthenticateWithSessionCookieSuccessResponse ,
45+ AuthenticateWithSessionCookieErrorResponse ,
46+ ]:
47+ if self .session_data is None :
48+ return AuthenticateWithSessionCookieErrorResponse (
49+ authenticated = False , reason = AuthenticateWithSessionCookieFailureReason .NO_SESSION_COOKIE_PROVIDED
50+ )
51+
52+ try :
53+ session = self .unseal_data (self .session_data , self .cookie_password )
54+ except Exception :
55+ return AuthenticateWithSessionCookieErrorResponse (
56+ authenticated = False , reason = AuthenticateWithSessionCookieFailureReason .INVALID_SESSION_COOKIE
57+ )
58+
59+ if not session ["access_token" ]:
60+ return AuthenticateWithSessionCookieErrorResponse (
61+ authenticated = False , reason = AuthenticateWithSessionCookieFailureReason .INVALID_SESSION_COOKIE
62+ )
63+
64+ if not self .is_valid_jwt (session ["access_token" ]):
65+ return AuthenticateWithSessionCookieErrorResponse (
66+ authenticated = False , reason = AuthenticateWithSessionCookieFailureReason .INVALID_JWT
67+ )
68+
69+ decoded = jwt .decode (
70+ session ["access_token" ], self .jwks , algorithms = self .jwk_algorithms
71+ )
72+
73+ return AuthenticateWithSessionCookieSuccessResponse (
74+ authenticated = True ,
75+ session_id = decoded ["sid" ],
76+ organization_id = decoded ["org_id" ],
77+ role = decoded ["role" ],
78+ permissions = decoded ["permissions" ],
79+ entitlements = decoded ["entitlements" ],
80+ user = session ["user" ],
81+ impersonator = session ["impersonator" ],
82+ reason = None ,
83+ )
84+
85+ def refresh (self , options : Optional [Dict [str , Any ]] = None ) -> Union [
86+ AuthenticateWithSessionCookieSuccessResponse ,
87+ AuthenticateWithSessionCookieErrorResponse ,
88+ ]:
89+ cookie_password = options .get ("cookie_password" , self .cookie_password )
90+ organization_id = options .get ("organization_id" , None )
91+
92+ try :
93+ session = self .unseal_data (self .session_data , cookie_password )
94+ except Exception :
95+ return AuthenticateWithSessionCookieErrorResponse (
96+ authenticated = False , reason = AuthenticateWithSessionCookieFailureReason .INVALID_SESSION_COOKIE
97+ )
98+
99+ if not session ["refresh_token" ] or not session ["user" ]:
100+ return AuthenticateWithSessionCookieErrorResponse (
101+ authenticated = False , reason = AuthenticateWithSessionCookieFailureReason .INVALID_SESSION_COOKIE
102+ )
103+
104+ try :
105+ auth_response = self .user_management .authenticate_with_refresh_token (
106+ refresh_token = session ["refresh_token" ],
107+ organization_id = organization_id ,
108+ )
109+
110+ self .session_data = auth_response .sealed_session
111+ self .cookie_password = cookie_password
112+
113+ return AuthenticateWithSessionCookieSuccessResponse (
114+ authenticated = True ,
115+ sealed_session = auth_response .sealed_session ,
116+ session = auth_response ,
117+ reason = None ,
118+ )
119+ except Exception as e :
120+ return AuthenticateWithSessionCookieErrorResponse (
121+ authenticated = False , reason = str (e )
122+ )
123+
124+ def get_logout_url (self ) -> str :
125+ auth_response = self .authenticate ()
126+
127+ if not auth_response ["authenticated" ]:
128+ raise ValueError (auth_response ["reason" ])
129+
130+ return self .user_management .get_logout_url (
131+ session_id = auth_response ["session_id" ]
132+ )
133+
134+ def create_remote_jwk_set (self , url : str ) -> List [Dict [str , Any ]]:
135+ jwks_client = PyJWKClient (url )
136+ return jwks_client .get_signing_keys ()
137+
138+ def is_valid_jwt (self , token : str ) -> bool :
139+ try :
140+ jwt .decode (token , self .jwks , algorithms = self .jwk_algorithms )
141+ return True
142+ except jwt .exceptions .InvalidTokenError as error :
143+ print ("invalid token" , error )
144+ return False
145+
146+ @staticmethod
147+ def seal_data (data : Dict [str , Any ], key : str ) -> str :
148+ fernet = Fernet (key )
149+ # take the data and encrypt it with the key using fernet
150+ return fernet .encrypt (json .dumps (data ).encode ())
151+
152+ @staticmethod
153+ def unseal_data (sealed_data : str , key : str ) -> Dict [str , Any ]:
154+ fernet = Fernet (key )
155+ return json .loads (fernet .decrypt (sealed_data ).decode ())
0 commit comments