@@ -42,6 +42,81 @@ def exchange_auth_code(self, authorization_code, request):
4242 adfs_response = response .json ()
4343 return adfs_response
4444
45+ def get_obo_access_token (self , access_token ):
46+ """
47+ Gets an On Behalf Of (OBO) access token, which is required to make queries against MS Graph
48+
49+ Args:
50+ access_token (str): Original authorization access token from the user
51+
52+ Returns:
53+ obo_access_token (str): OBO access token that can be used with the MS Graph API
54+ """
55+ logger .debug ("Getting OBO access token: %s" , provider_config .token_endpoint )
56+ data = {
57+ "grant_type" : "urn:ietf:params:oauth:grant-type:jwt-bearer" ,
58+ "client_id" : settings .CLIENT_ID ,
59+ "client_secret" : settings .CLIENT_SECRET ,
60+ "assertion" : access_token ,
61+ "requested_token_use" : "on_behalf_of" ,
62+ }
63+ if provider_config .token_endpoint .endswith ("/v2.0/token" ):
64+ data ["scope" ] = 'GroupMember.Read.All'
65+ else :
66+ data ["resource" ] = 'https://graph.microsoft.com'
67+
68+ response = provider_config .session .get (provider_config .token_endpoint , data = data , timeout = settings .TIMEOUT )
69+ # 200 = valid token received
70+ # 400 = 'something' is wrong in our request
71+ if response .status_code == 400 :
72+ logger .error ("ADFS server returned an error: %s" , response .json ()["error_description" ])
73+ raise PermissionDenied
74+
75+ if response .status_code != 200 :
76+ logger .error ("Unexpected ADFS response: %s" , response .content .decode ())
77+ raise PermissionDenied
78+
79+ obo_access_token = response .json ()["access_token" ]
80+ logger .debug ("Received OBO access token: %s" , obo_access_token )
81+ return obo_access_token
82+
83+ def get_group_memberships_from_ms_graph (self , obo_access_token ):
84+ """
85+ Looks up a users group membership from the MS Graph API
86+
87+ Args:
88+ obo_access_token (str): Access token obtained from the OBO authorization endpoint
89+
90+ Returns:
91+ claim_groups (list): List of the users group memberships
92+ """
93+ graph_url = "https://{}/v1.0/me/transitiveMemberOf/microsoft.graph.group" .format (
94+ provider_config .msgraph_endpoint
95+ )
96+ headers = {"Authorization" : "Bearer {}" .format (obo_access_token )}
97+ response = provider_config .session .get (graph_url , headers = headers , timeout = settings .TIMEOUT )
98+ # 200 = valid token received
99+ # 400 = 'something' is wrong in our request
100+ if response .status_code in [400 , 401 ]:
101+ logger .error ("MS Graph server returned an error: %s" , response .json ()["message" ])
102+ raise PermissionDenied
103+
104+ if response .status_code != 200 :
105+ logger .error ("Unexpected MS Graph response: %s" , response .content .decode ())
106+ raise PermissionDenied
107+
108+ claim_groups = []
109+ for group_data in response .json ()["value" ]:
110+ if group_data ["displayName" ] is None :
111+ logger .error (
112+ "The application does not have the required permission to read user groups from "
113+ "MS Graph (GroupMember.Read.All)"
114+ )
115+ raise PermissionDenied
116+
117+ claim_groups .append (group_data ["displayName" ])
118+ return claim_groups
119+
45120 def validate_access_token (self , access_token ):
46121 for idx , key in enumerate (provider_config .signing_keys ):
47122 try :
@@ -100,10 +175,11 @@ def process_access_token(self, access_token, adfs_response=None):
100175 if not claims :
101176 raise PermissionDenied
102177
178+ groups = self .process_user_groups (claims , access_token )
103179 user = self .create_user (claims )
104180 self .update_user_attributes (user , claims )
105- self .update_user_groups (user , claims )
106- self .update_user_flags (user , claims )
181+ self .update_user_groups (user , groups )
182+ self .update_user_flags (user , claims , groups )
107183
108184 signals .post_authenticate .send (
109185 sender = self ,
@@ -116,6 +192,41 @@ def process_access_token(self, access_token, adfs_response=None):
116192 user .save ()
117193 return user
118194
195+ def process_user_groups (self , claims , access_token ):
196+ """
197+ Checks the user groups are in the claim or pulls them from MS Graph if
198+ applicable
199+
200+ Args:
201+ claims (dict): claims from the access token
202+ access_token (str): Used to make an OBO authentication request if
203+ groups must be obtained from Microsoft Graph
204+
205+ Returns:
206+ groups (list): Groups the user is a member of, taken from the access token or MS Graph
207+ """
208+ groups = []
209+ if settings .GROUPS_CLAIM is None :
210+ logger .debug ("No group claim has been configured" )
211+ return groups
212+
213+ if settings .GROUPS_CLAIM in claims :
214+ groups = claims [settings .GROUPS_CLAIM ]
215+ if not isinstance (groups , list ):
216+ groups = [groups , ]
217+ elif (
218+ settings .TENANT_ID != "adfs"
219+ and "_claim_names" in claims
220+ and settings .GROUPS_CLAIM in claims ["_claim_names" ]
221+ ):
222+ obo_access_token = self .get_obo_access_token (access_token )
223+ groups = self .get_group_memberships_from_ms_graph (obo_access_token )
224+ else :
225+ logger .debug ("The configured groups claim %s was not found in the access token" ,
226+ settings .GROUPS_CLAIM )
227+
228+ return groups
229+
119230 def create_user (self , claims ):
120231 """
121232 Create the user if it doesn't exist yet
@@ -201,26 +312,18 @@ def update_user_attributes(self, user, claims, claim_mapping=None):
201312 msg = "Model '{}' has no field named '{}'. Check ADFS claims mapping."
202313 raise ImproperlyConfigured (msg .format (user ._meta .model_name , field ))
203314
204- def update_user_groups (self , user , claims ):
315+ def update_user_groups (self , user , claim_groups ):
205316 """
206317 Updates user group memberships based on the GROUPS_CLAIM setting.
207318
208319 Args:
209320 user (django.contrib.auth.models.User): User model instance
210- claims (dict ): Claims from the access token
321+ claim_groups (list ): User groups from the access token / MS Graph
211322 """
212323 if settings .GROUPS_CLAIM is not None :
213324 # Update the user's group memberships
214325 django_groups = [group .name for group in user .groups .all ()]
215326
216- if settings .GROUPS_CLAIM in claims :
217- claim_groups = claims [settings .GROUPS_CLAIM ]
218- if not isinstance (claim_groups , list ):
219- claim_groups = [claim_groups , ]
220- else :
221- logger .debug ("The configured groups claim '%s' was not found in the access token" ,
222- settings .GROUPS_CLAIM )
223- claim_groups = []
224327 if sorted (claim_groups ) != sorted (django_groups ):
225328 existing_groups = list (Group .objects .filter (name__in = claim_groups ).iterator ())
226329 existing_group_names = frozenset (group .name for group in existing_groups )
@@ -241,29 +344,22 @@ def update_user_groups(self, user, claims):
241344 pass
242345 user .groups .set (existing_groups + new_groups )
243346
244- def update_user_flags (self , user , claims ):
347+ def update_user_flags (self , user , claims , claim_groups ):
245348 """
246349 Updates user boolean attributes based on the BOOLEAN_CLAIM_MAPPING setting.
247350
248351 Args:
249352 user (django.contrib.auth.models.User): User model instance
250353 claims (dict): Claims from the access token
354+ claim_groups (list): User groups from the access token / MS Graph
251355 """
252356 if settings .GROUPS_CLAIM is not None :
253- if settings .GROUPS_CLAIM in claims :
254- access_token_groups = claims [settings .GROUPS_CLAIM ]
255- if not isinstance (access_token_groups , list ):
256- access_token_groups = [access_token_groups , ]
257- else :
258- logger .debug ("The configured group claim was not found in the access token" )
259- access_token_groups = []
260-
261357 for flag , group in settings .GROUP_TO_FLAG_MAPPING .items ():
262358 if hasattr (user , flag ):
263359 if not isinstance (group , list ):
264360 group = [group ]
265361
266- if any (group_list_item in access_token_groups for group_list_item in group ):
362+ if any (group_list_item in claim_groups for group_list_item in group ):
267363 value = True
268364 else :
269365 value = False
0 commit comments