@@ -160,6 +160,18 @@ async def verify_id_token_from_jwks_endpoint_and_get_payload(
160160 raise err
161161
162162
163+ def merge_into_dict (src : Dict [str , Any ], dest : Dict [str , Any ]) -> Dict [str , Any ]:
164+ res = dest .copy ()
165+ for k , v in src .items ():
166+ if v is None :
167+ if k in res :
168+ del res [k ]
169+ else :
170+ res [k ] = v
171+
172+ return res
173+
174+
163175class GenericProvider (Provider ):
164176 def __init__ (self , config : ProviderConfig ):
165177 super ().__init__ (config .third_party_id )
@@ -288,12 +300,7 @@ async def exchange_auth_code_for_oauth_tokens(
288300 access_token_params ["code_verifier" ] = redirect_uri_info .pkce_code_verifier
289301
290302 if self .config .token_endpoint_body_params is not None :
291- for k , v in self .config .token_endpoint_body_params :
292- if v is None :
293- if k in access_token_params :
294- del access_token_params [k ]
295- else :
296- access_token_params [k ] = v
303+ access_token_params = merge_into_dict (self .config .token_endpoint_body_params , access_token_params )
297304
298305 # Transformation needed for dev keys BEGIN
299306 if is_using_oauth_development_client_id (self .config .client_id ):
@@ -336,20 +343,10 @@ async def get_user_info(
336343
337344 if self .config .user_info_endpoint is not None :
338345 if self .config .user_info_endpoint_headers is not None :
339- for k , v in self .config .user_info_endpoint_headers .items ():
340- if v is None :
341- if k in headers :
342- del headers [k ]
343- else :
344- headers [k ] = v
346+ headers = merge_into_dict (self .config .user_info_endpoint_headers , headers )
345347
346348 if self .config .user_info_endpoint_query_params is not None :
347- for k , v in self .config .user_info_endpoint_query_params .items ():
348- if v is None :
349- if k in query_params :
350- del query_params [k ]
351- else :
352- query_params [k ] = v
349+ query_params = merge_into_dict (self .config .user_info_endpoint_query_params , query_params )
353350
354351 raw_user_info_from_provider .from_user_info_api = await do_get_request (
355352 self .config .user_info_endpoint , query_params , headers
@@ -367,7 +364,7 @@ async def get_user_info(
367364
368365
369366def NewProvider (
370- input : ProviderInput ,
367+ input : ProviderInput , # pylint: disable=redefined-builtin
371368 base_class : Callable [[ProviderConfig ], Provider ] = GenericProvider ,
372369) -> Provider :
373370 provider_instance = base_class (input .config )
0 commit comments