@@ -41,13 +41,15 @@ def __init__(
4141
4242
4343class Provider :
44- def __init__ (self , id : str ): # pylint: disable=redefined-builtin
44+ def __init__ (
45+ self , id : str , config : ProviderConfigForClient
46+ ): # pylint: disable=redefined-builtin
4547 self .id = id
46- self .config = ProviderConfigForClientType ( "temp" )
48+ self .config = config
4749
4850 async def get_config_for_client_type ( # pylint: disable=no-self-use
4951 self , client_type : Optional [str ], user_context : Dict [str , Any ]
50- ) -> ProviderConfigForClientType :
52+ ) -> ProviderConfigForClient :
5153 _ = client_type
5254 __ = user_context
5355 raise NotImplementedError ()
@@ -110,60 +112,6 @@ def to_json(self) -> Dict[str, Any]:
110112 return {k : v for k , v in res .items () if v is not None }
111113
112114
113- class ProviderConfigForClientType :
114- def __init__ (
115- self ,
116- client_id : str ,
117- client_secret : Optional [str ] = None ,
118- scope : Optional [List [str ]] = None ,
119- force_pkce : bool = False ,
120- additional_config : Optional [Dict [str , Any ]] = None ,
121- name : Optional [str ] = None ,
122- authorization_endpoint : Optional [str ] = None ,
123- authorization_endpoint_query_params : Optional [
124- Dict [str , Union [str , None ]]
125- ] = None ,
126- token_endpoint : Optional [str ] = None ,
127- token_endpoint_body_params : Optional [Dict [str , Union [str , None ]]] = None ,
128- user_info_endpoint : Optional [str ] = None ,
129- user_info_endpoint_query_params : Optional [Dict [str , Union [str , None ]]] = None ,
130- user_info_endpoint_headers : Optional [Dict [str , Union [str , None ]]] = None ,
131- jwks_uri : Optional [str ] = None ,
132- oidc_discovery_endpoint : Optional [str ] = None ,
133- user_info_map : Optional [UserInfoMap ] = None ,
134- require_email : bool = True ,
135- generate_fake_email : Optional [
136- Callable [[str , str , Dict [str , Any ]], Awaitable [str ]]
137- ] = None ,
138- validate_id_token_payload : Optional [
139- Callable [
140- [Dict [str , Any ], ProviderConfigForClientType , Dict [str , Any ]],
141- Awaitable [None ],
142- ]
143- ] = None ,
144- ):
145- self .client_id = client_id
146- self .client_secret = client_secret
147- self .scope = scope
148- self .force_pkce = force_pkce
149- self .additional_config = additional_config
150-
151- self .name = name
152- self .authorization_endpoint = authorization_endpoint
153- self .authorization_endpoint_query_params = authorization_endpoint_query_params
154- self .token_endpoint = token_endpoint
155- self .token_endpoint_body_params = token_endpoint_body_params
156- self .user_info_endpoint = user_info_endpoint
157- self .user_info_endpoint_query_params = user_info_endpoint_query_params
158- self .user_info_endpoint_headers = user_info_endpoint_headers
159- self .jwks_uri = jwks_uri
160- self .oidc_discovery_endpoint = oidc_discovery_endpoint
161- self .user_info_map = user_info_map
162- self .require_email = require_email
163- self .validate_id_token_payload = validate_id_token_payload
164- self .generate_fake_email = generate_fake_email
165-
166-
167115class UserFields :
168116 def __init__ (
169117 self ,
@@ -201,12 +149,11 @@ def to_json(self) -> Dict[str, Any]:
201149 }
202150
203151
204- class ProviderConfig :
152+ class CommonProviderConfig :
205153 def __init__ (
206154 self ,
207155 third_party_id : str ,
208156 name : Optional [str ] = None ,
209- clients : Optional [List [ProviderClientConfig ]] = None ,
210157 authorization_endpoint : Optional [str ] = None ,
211158 authorization_endpoint_query_params : Optional [
212159 Dict [str , Union [str , None ]]
@@ -222,7 +169,7 @@ def __init__(
222169 require_email : bool = True ,
223170 validate_id_token_payload : Optional [
224171 Callable [
225- [Dict [str , Any ], ProviderConfigForClientType , Dict [str , Any ]],
172+ [Dict [str , Any ], ProviderConfigForClient , Dict [str , Any ]],
226173 Awaitable [None ],
227174 ]
228175 ] = None ,
@@ -232,7 +179,6 @@ def __init__(
232179 ):
233180 self .third_party_id = third_party_id
234181 self .name = name
235- self .clients = clients
236182 self .authorization_endpoint = authorization_endpoint
237183 self .authorization_endpoint_query_params = authorization_endpoint_query_params
238184 self .token_endpoint = token_endpoint
@@ -251,9 +197,6 @@ def to_json(self) -> Dict[str, Any]:
251197 res = {
252198 "thirdPartyId" : self .third_party_id ,
253199 "name" : self .name ,
254- "clients" : [c .to_json () for c in self .clients ]
255- if self .clients is not None
256- else [],
257200 "authorizationEndpoint" : self .authorization_endpoint ,
258201 "authorizationEndpointQueryParams" : self .authorization_endpoint_query_params ,
259202 "tokenEndpoint" : self .token_endpoint ,
@@ -272,6 +215,132 @@ def to_json(self) -> Dict[str, Any]:
272215 return {k : v for k , v in res .items () if v is not None }
273216
274217
218+ class ProviderConfigForClient (ProviderClientConfig , CommonProviderConfig ):
219+ def __init__ (
220+ self ,
221+ # ProviderClientConfig:
222+ client_id : str ,
223+ client_secret : Optional [str ] = None ,
224+ client_type : Optional [str ] = None ,
225+ scope : Optional [List [str ]] = None ,
226+ force_pkce : bool = False ,
227+ additional_config : Optional [Dict [str , Any ]] = None ,
228+ # CommonProviderConfig:
229+ name : Optional [str ] = None ,
230+ authorization_endpoint : Optional [str ] = None ,
231+ authorization_endpoint_query_params : Optional [
232+ Dict [str , Union [str , None ]]
233+ ] = None ,
234+ token_endpoint : Optional [str ] = None ,
235+ token_endpoint_body_params : Optional [Dict [str , Union [str , None ]]] = None ,
236+ user_info_endpoint : Optional [str ] = None ,
237+ user_info_endpoint_query_params : Optional [Dict [str , Union [str , None ]]] = None ,
238+ user_info_endpoint_headers : Optional [Dict [str , Union [str , None ]]] = None ,
239+ jwks_uri : Optional [str ] = None ,
240+ oidc_discovery_endpoint : Optional [str ] = None ,
241+ user_info_map : Optional [UserInfoMap ] = None ,
242+ require_email : bool = True ,
243+ validate_id_token_payload : Optional [
244+ Callable [
245+ [Dict [str , Any ], ProviderConfigForClient , Dict [str , Any ]],
246+ Awaitable [None ],
247+ ]
248+ ] = None ,
249+ generate_fake_email : Optional [
250+ Callable [[str , str , Dict [str , Any ]], Awaitable [str ]]
251+ ] = None ,
252+ ):
253+ ProviderClientConfig .__init__ (
254+ self ,
255+ client_id ,
256+ client_secret ,
257+ client_type ,
258+ scope ,
259+ force_pkce ,
260+ additional_config ,
261+ )
262+ CommonProviderConfig .__init__ (
263+ self ,
264+ "temp" ,
265+ name ,
266+ authorization_endpoint ,
267+ authorization_endpoint_query_params ,
268+ token_endpoint ,
269+ token_endpoint_body_params ,
270+ user_info_endpoint ,
271+ user_info_endpoint_query_params ,
272+ user_info_endpoint_headers ,
273+ jwks_uri ,
274+ oidc_discovery_endpoint ,
275+ user_info_map ,
276+ require_email ,
277+ validate_id_token_payload ,
278+ generate_fake_email ,
279+ )
280+
281+ def to_json (self ) -> Dict [str , Any ]:
282+ d1 = ProviderClientConfig .to_json (self )
283+ d2 = CommonProviderConfig .to_json (self )
284+ return {** d1 , ** d2 }
285+
286+
287+ class ProviderConfig (CommonProviderConfig ):
288+ def __init__ (
289+ self ,
290+ third_party_id : str ,
291+ name : Optional [str ] = None ,
292+ clients : Optional [List [ProviderClientConfig ]] = None ,
293+ authorization_endpoint : Optional [str ] = None ,
294+ authorization_endpoint_query_params : Optional [
295+ Dict [str , Union [str , None ]]
296+ ] = None ,
297+ token_endpoint : Optional [str ] = None ,
298+ token_endpoint_body_params : Optional [Dict [str , Union [str , None ]]] = None ,
299+ user_info_endpoint : Optional [str ] = None ,
300+ user_info_endpoint_query_params : Optional [Dict [str , Union [str , None ]]] = None ,
301+ user_info_endpoint_headers : Optional [Dict [str , Union [str , None ]]] = None ,
302+ jwks_uri : Optional [str ] = None ,
303+ oidc_discovery_endpoint : Optional [str ] = None ,
304+ user_info_map : Optional [UserInfoMap ] = None ,
305+ require_email : bool = True ,
306+ validate_id_token_payload : Optional [
307+ Callable [
308+ [Dict [str , Any ], ProviderConfigForClient , Dict [str , Any ]],
309+ Awaitable [None ],
310+ ]
311+ ] = None ,
312+ generate_fake_email : Optional [
313+ Callable [[str , str , Dict [str , Any ]], Awaitable [str ]]
314+ ] = None ,
315+ ):
316+ super ().__init__ (
317+ third_party_id ,
318+ name ,
319+ authorization_endpoint ,
320+ authorization_endpoint_query_params ,
321+ token_endpoint ,
322+ token_endpoint_body_params ,
323+ user_info_endpoint ,
324+ user_info_endpoint_query_params ,
325+ user_info_endpoint_headers ,
326+ jwks_uri ,
327+ oidc_discovery_endpoint ,
328+ user_info_map ,
329+ require_email ,
330+ validate_id_token_payload ,
331+ generate_fake_email ,
332+ )
333+ self .clients = clients
334+
335+ def to_json (self ) -> Dict [str , Any ]:
336+ d = CommonProviderConfig .to_json (self )
337+
338+ if self .clients is not None :
339+ d ["clients" ] = [c .to_json () for c in self .clients ]
340+
341+ return d
342+
343+
275344class ProviderInput :
276345 def __init__ (
277346 self ,
0 commit comments