1+ import logging
12import os
23import ssl
3- import logging
4- import jwt
4+
55import grpc
6+ import jwt
7+ import requests
68from aiohttp import hdrs , web
7-
8- from temporalio .api .common .v1 import Payload , Payloads
9- from temporalio .api .cloud .cloudservice .v1 import request_response_pb2 , service_pb2_grpc
109from google .protobuf import json_format
10+ from jwt .algorithms import RSAAlgorithm
11+ from temporalio .api .cloud .cloudservice .v1 import request_response_pb2 , service_pb2_grpc
12+ from temporalio .api .common .v1 import Payload , Payloads
13+
1114from encryption_jwt .codec import EncryptionCodec
1215
13- AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["admin" ]
16+ AUTHORIZED_ACCOUNT_ACCESS_ROLES = ["owner" , " admin" ]
1417AUTHORIZED_NAMESPACE_ACCESS_ROLES = ["read" , "write" , "admin" ]
1518
1619temporal_ops_address = "saas-api.tmprl.cloud:443"
@@ -43,51 +46,101 @@ async def cors_options(req: web.Request) -> web.Response:
4346 return resp
4447
4548 def decryption_authorized (email : str , namespace : str ) -> bool :
46- credentials = grpc .composite_channel_credentials (grpc .ssl_channel_credentials (
47- ), grpc .access_token_call_credentials (os .environ .get ("TEMPORAL_API_KEY" )))
49+ credentials = grpc .composite_channel_credentials (
50+ grpc .ssl_channel_credentials (),
51+ grpc .access_token_call_credentials (os .environ .get ("TEMPORAL_API_KEY" )),
52+ )
4853
4954 with grpc .secure_channel (temporal_ops_address , credentials ) as channel :
5055 client = service_pb2_grpc .CloudServiceStub (channel )
5156 request = request_response_pb2 .GetUsersRequest ()
5257
53- response = client .GetUsers (request , metadata = (
54- ("temporal-cloud-api-version" , os .environ .get ("TEMPORAL_OPS_API_VERSION" )),))
58+ response = client .GetUsers (
59+ request ,
60+ metadata = (
61+ (
62+ "temporal-cloud-api-version" ,
63+ os .environ .get ("TEMPORAL_OPS_API_VERSION" ),
64+ ),
65+ ),
66+ )
5567
56- authorized = False
5768 for user in response .users :
5869 if user .spec .email .lower () == email .lower ():
59- if user .spec .access .account_access .role in AUTHORIZED_ACCOUNT_ACCESS_ROLES :
60- authorized = True
70+ if (
71+ user .spec .access .account_access .role
72+ in AUTHORIZED_ACCOUNT_ACCESS_ROLES
73+ ):
74+ return True
6175 else :
6276 if namespace in user .spec .access .namespace_accesses :
63- if user .spec .access .namespace_accesses [namespace ].permission in AUTHORIZED_NAMESPACE_ACCESS_ROLES :
64- authorized = True
77+ if (
78+ user .spec .access .namespace_accesses [
79+ namespace
80+ ].permission
81+ in AUTHORIZED_NAMESPACE_ACCESS_ROLES
82+ ):
83+ return True
6584
66- return authorized
85+ return False
6786
6887 def make_handler (fn : str ):
6988 async def handler (req : web .Request ):
70- # Read payloads as JSON
71- assert req .content_type == "application/json"
72- payloads = json_format .Parse (await req .read (), Payloads ())
73-
74- # Extract the email from the JWT.
75- auth_header = req .headers .get ("Authorization" )
7689 namespace = req .headers .get ("x-namespace" )
90+ auth_header = req .headers .get ("Authorization" )
7791 _bearer , encoded = auth_header .split (" " )
78- decoded = jwt .decode (encoded , options = {"verify_signature" : False })
7992
80- # Use the email to determine if the payload should be decrypted.
81- authorized = decryption_authorized (decoded ["https://saas-api.tmprl.cloud/user/email" ], namespace )
93+ # Extract the kid from the Auth header
94+ jwt_dict = jwt .get_unverified_header (encoded )
95+ kid = jwt_dict ["kid" ]
96+ algorithm = jwt_dict ["alg" ]
97+
98+ # Fetch Temporal Cloud JWKS
99+ jwks_url = "https://login.tmprl.cloud/.well-known/jwks.json"
100+ jwks = requests .get (jwks_url ).json ()
101+
102+ # Extract Temporal Cloud's public key
103+ public_key = None
104+ for key in jwks ["keys" ]:
105+ if key ["kid" ] == kid :
106+ # Convert JWKS key to PEM format
107+ public_key = RSAAlgorithm .from_jwk (key )
108+ break
109+
110+ if public_key is None :
111+ raise ValueError ("Public key not found in JWKS" )
112+
113+ # Decode the jwt, verifying against Temporal Cloud's public key
114+ decoded = jwt .decode (
115+ encoded ,
116+ public_key ,
117+ algorithms = [algorithm ],
118+ audience = [
119+ "https://saas-api.tmprl.cloud" ,
120+ "https://prod-tmprl.us.auth0.com/userinfo" ,
121+ ],
122+ )
123+
124+ # Use the email to determine if the user is authorized to decrypt the payload
125+ authorized = decryption_authorized (
126+ decoded ["https://saas-api.tmprl.cloud/user/email" ], namespace
127+ )
128+
82129 if authorized :
130+ # Read payloads as JSON
131+ assert req .content_type == "application/json"
132+ payloads = json_format .Parse (await req .read (), Payloads ())
83133 encryptionCodec = EncryptionCodec (namespace )
84- payloads = Payloads (payloads = await getattr (encryptionCodec , fn )(payloads .payloads ))
134+ payloads = Payloads (
135+ payloads = await getattr (encryptionCodec , fn )(payloads .payloads )
136+ )
85137
86138 # Apply CORS and return JSON
87139 resp = await cors_options (req )
88140 resp .content_type = "application/json"
89141 resp .text = json_format .MessageToJson (payloads )
90142 return resp
143+
91144 return handler
92145
93146 # Build app
@@ -97,8 +150,8 @@ async def handler(req: web.Request):
97150 logger = logging .getLogger (__name__ )
98151 app .add_routes (
99152 [
100- web .post ("/encode" , make_handler (' encode' )),
101- web .post ("/decode" , make_handler (' decode' )),
153+ web .post ("/encode" , make_handler (" encode" )),
154+ web .post ("/decode" , make_handler (" decode" )),
102155 web .options ("/decode" , cors_options ),
103156 ]
104157 )
@@ -112,8 +165,10 @@ async def handler(req: web.Request):
112165 if os .environ .get ("SSL_PEM" ) and os .environ .get ("SSL_KEY" ):
113166 ssl_context = ssl .create_default_context (ssl .Purpose .CLIENT_AUTH )
114167 ssl_context .check_hostname = False
115- ssl_context .load_cert_chain (os .environ .get (
116- "SSL_PEM" ), os .environ .get ("SSL_KEY" ))
168+ ssl_context .load_cert_chain (
169+ os .environ .get ("SSL_PEM" ), os .environ .get ("SSL_KEY" )
170+ )
117171
118- web .run_app (build_codec_server (), host = "0.0.0.0" ,
119- port = 8081 , ssl_context = ssl_context )
172+ web .run_app (
173+ build_codec_server (), host = "0.0.0.0" , port = 8081 , ssl_context = ssl_context
174+ )
0 commit comments