4
4
// SPDX-License-Identifier: AGPL-3.0-only
5
5
// Please see LICENSE in the repository root for full details.
6
6
7
- use std:: sync:: LazyLock ;
7
+ use std:: { collections :: BTreeSet , sync:: LazyLock } ;
8
8
9
9
use axum:: { Json , extract:: State , http:: HeaderValue , response:: IntoResponse } ;
10
10
use hyper:: { HeaderMap , StatusCode } ;
@@ -24,7 +24,7 @@ use mas_storage::{
24
24
use oauth2_types:: {
25
25
errors:: { ClientError , ClientErrorCode } ,
26
26
requests:: { IntrospectionRequest , IntrospectionResponse } ,
27
- scope:: ScopeToken ,
27
+ scope:: { Scope , ScopeToken } ,
28
28
} ;
29
29
use opentelemetry:: { Key , KeyValue , metrics:: Counter } ;
30
30
use thiserror:: Error ;
@@ -190,9 +190,33 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
190
190
device_id : None ,
191
191
} ;
192
192
193
- const API_SCOPE : ScopeToken = ScopeToken :: from_static ( "urn:matrix:org.matrix.msc2967.client:api:*" ) ;
193
+ const UNSTABLE_API_SCOPE : ScopeToken =
194
+ ScopeToken :: from_static ( "urn:matrix:org.matrix.msc2967.client:api:*" ) ;
195
+ const STABLE_API_SCOPE : ScopeToken = ScopeToken :: from_static ( "urn:matrix:client:api:*" ) ;
194
196
const SYNAPSE_ADMIN_SCOPE : ScopeToken = ScopeToken :: from_static ( "urn:synapse:admin:*" ) ;
195
197
198
+ /// Normalize a scope by adding the stable and unstable API scopes equivalents
199
+ /// if missing
200
+ fn normalize_scope ( mut scope : Scope ) -> Scope {
201
+ // Here we abuse the fact that the scope is a BTreeSet to not care about
202
+ // duplicates
203
+ let mut to_add = BTreeSet :: new ( ) ;
204
+ for token in & * scope {
205
+ if token == & STABLE_API_SCOPE {
206
+ to_add. insert ( UNSTABLE_API_SCOPE ) ;
207
+ } else if token == & UNSTABLE_API_SCOPE {
208
+ to_add. insert ( STABLE_API_SCOPE ) ;
209
+ } else if let Some ( device) = Device :: from_scope_token ( token) {
210
+ let tokens = device
211
+ . to_scope_token ( )
212
+ . expect ( "from/to scope token rountrip should never fail" ) ;
213
+ to_add. extend ( tokens) ;
214
+ }
215
+ }
216
+ scope. append ( & mut to_add) ;
217
+ scope
218
+ }
219
+
196
220
#[ tracing:: instrument(
197
221
name = "handlers.oauth2.introspection.post" ,
198
222
fields( client. id = client_authorization. client_id( ) ) ,
@@ -311,9 +335,11 @@ pub(crate) async fn post(
311
335
] ,
312
336
) ;
313
337
338
+ let scope = normalize_scope ( session. scope ) ;
339
+
314
340
IntrospectionResponse {
315
341
active : true ,
316
- scope : Some ( session . scope ) ,
342
+ scope : Some ( scope) ,
317
343
client_id : Some ( session. client_id . to_string ( ) ) ,
318
344
username,
319
345
token_type : Some ( OAuthTokenTypeHint :: AccessToken ) ,
@@ -382,9 +408,11 @@ pub(crate) async fn post(
382
408
] ,
383
409
) ;
384
410
411
+ let scope = normalize_scope ( session. scope ) ;
412
+
385
413
IntrospectionResponse {
386
414
active : true ,
387
- scope : Some ( session . scope ) ,
415
+ scope : Some ( scope) ,
388
416
client_id : Some ( session. client_id . to_string ( ) ) ,
389
417
username,
390
418
token_type : Some ( OAuthTokenTypeHint :: RefreshToken ) ,
@@ -446,9 +474,9 @@ pub(crate) async fn post(
446
474
. transpose ( ) ?
447
475
} ;
448
476
449
- let scope = [ API_SCOPE ]
477
+ let scope = [ STABLE_API_SCOPE , UNSTABLE_API_SCOPE ]
450
478
. into_iter ( )
451
- . chain ( device_scope_opt)
479
+ . chain ( device_scope_opt. into_iter ( ) . flatten ( ) )
452
480
. chain ( synapse_admin_scope_opt)
453
481
. collect ( ) ;
454
482
@@ -530,9 +558,9 @@ pub(crate) async fn post(
530
558
. transpose ( ) ?
531
559
} ;
532
560
533
- let scope = [ API_SCOPE ]
561
+ let scope = [ STABLE_API_SCOPE , UNSTABLE_API_SCOPE ]
534
562
. into_iter ( )
535
- . chain ( device_scope_opt)
563
+ . chain ( device_scope_opt. into_iter ( ) . flatten ( ) )
536
564
. chain ( synapse_admin_scope_opt)
537
565
. collect ( ) ;
538
566
@@ -879,7 +907,7 @@ mod tests {
879
907
let refresh_token = response[ "refresh_token" ] . as_str ( ) . unwrap ( ) ;
880
908
let device_id = response[ "device_id" ] . as_str ( ) . unwrap ( ) ;
881
909
let expected_scope: Scope =
882
- format ! ( "urn:matrix:org.matrix.msc2967.client:api:* urn:matrix:org.matrix.msc2967.client:device:{device_id}" )
910
+ format ! ( "urn:matrix:org.matrix.msc2967.client:api:* urn:matrix:org.matrix.msc2967.client:device:{device_id} urn:matrix:client:api:* urn:matrix:client:device:{device_id} " )
883
911
. parse ( )
884
912
. unwrap ( ) ;
885
913
@@ -912,7 +940,7 @@ mod tests {
912
940
assert_eq ! ( response. token_type, Some ( OAuthTokenTypeHint :: AccessToken ) ) ;
913
941
assert_eq ! (
914
942
response. scope. map( |s| s. to_string( ) ) ,
915
- Some ( "urn:matrix:org.matrix.msc2967.client:api:*" . to_owned( ) )
943
+ Some ( "urn:matrix:client:api:* urn:matrix: org.matrix.msc2967.client:api:*" . to_owned( ) )
916
944
) ;
917
945
assert_eq ! ( response. device_id. as_deref( ) , Some ( device_id) ) ;
918
946
0 commit comments