22Custom Auth manager for Airflow 
33""" 
44
5- from  typing  import  override 
65from  airflow .auth .managers .base_auth_manager  import  ResourceMethod 
76from  airflow .auth .managers .models .base_user  import  BaseUser 
87from  airflow .auth .managers .models .resource_details  import  (
1817from  airflow .providers .fab .auth_manager .fab_auth_manager  import  FabAuthManager 
1918from  airflow .utils .log .logging_mixin  import  LoggingMixin 
2019from  cachetools  import  TTLCache , cachedmethod 
20+ from  typing  import  override 
2121import  json 
2222import  requests 
2323
2424class  OpaInput :
25+     """ 
26+     Wrapper for the OPA input structure which is hashable so that it can be cached 
27+     """ 
2528
2629    def  __init__ (self , input : dict ) ->  None :
2730        self .input  =  input 
@@ -42,42 +45,71 @@ class OpaFabAuthManager(FabAuthManager, LoggingMixin):
4245    Agent 
4346    """ 
4447
48+     AUTH_OPA_CACHE_MAXSIZE_DEFAULT = 1000 
49+     AUTH_OPA_CACHE_TTL_IN_SEC_DEFAULT = 30 
50+     AUTH_OPA_REQUEST_URL_DEFAULT = 'http://opa:8081/v1/data/airflow' 
51+     AUTH_OPA_REQUEST_TIMEOUT_DEFAULT = 10 
52+ 
4553    def  init (self ) ->  None :
46-         """Run operations when Airflow is initializing.""" 
54+         """ 
55+         Run operations when Airflow is initializing. 
56+         """ 
57+ 
4758        super ().init ()
48-         self ._init_config ()
4959
5060        config  =  self .appbuilder .get_app .config 
5161        self .opa_cache  =  TTLCache (
52-             maxsize = config .get ("AUTH_OPA_CACHE_MAXSIZE" ),
53-             ttl = config .get ("AUTH_OPA_CACHE_TTL_IN_SEC" ),
62+             maxsize = config .get (
63+                 'AUTH_OPA_CACHE_MAXSIZE' ,
64+                 self .AUTH_OPA_CACHE_MAXSIZE_DEFAULT 
65+             ),
66+             ttl = config .get (
67+                 'AUTH_OPA_CACHE_TTL_IN_SEC' ,
68+                 self .AUTH_OPA_CACHE_TTL_IN_SEC_DEFAULT 
69+             ),
5470        )
5571        self .opa_session  =  requests .Session ()
5672
57-     def  _init_config (self ):
58-         config  =  self .appbuilder .get_app .config 
59-         config .setdefault ('AUTH_OPA_CACHE_MAXSIZE' , 1000 )
60-         config .setdefault ("AUTH_OPA_CACHE_TTL_IN_SEC" , 30 )
61-         config .setdefault ("AUTH_OPA_REQUEST_URL" , "http://opa:8081/v1/data/airflow" )
62-         config .setdefault ("AUTH_OPA_REQUEST_TIMEOUT" , 10 )
63- 
6473    def  call_opa (self , url : str , json : dict , timeout : int ) ->  requests .Response :
74+         """ 
75+         Send a POST request to OPA. 
76+ 
77+         This function can be overriden in tests. 
78+ 
79+         :param url: URL for the OPA rule 
80+         :param json: json to send in the body 
81+         """ 
82+ 
6583        return  self .opa_session .post (url = url , json = json , timeout = timeout )
6684
6785    @cachedmethod (lambda  self : self .opa_cache ) 
6886    def  _is_authorized_in_opa (self , endpoint : str , input : OpaInput ) ->  bool :
87+         """ 
88+         Forward an authorization request to OPA. 
89+ 
90+         :param endpoint: the OPA rule 
91+         :param input: the input structure for OPA 
92+         """ 
93+ 
94+         self .log .debug ("Forward authorization request to OPA" )
95+ 
6996        config  =  self .appbuilder .get_app .config 
70-         opa_url  =  config .get ("AUTH_OPA_REQUEST_URL" )
97+         opa_url  =  config .get (
98+             'AUTH_OPA_REQUEST_URL' ,
99+             self .AUTH_OPA_REQUEST_URL_DEFAULT 
100+         )
71101        try :
72102            response  =  self .call_opa (
73103                f'{ opa_url }  /{ endpoint }  ' ,
74104                json = input .to_dict (),
75-                 timeout = config .get ("AUTH_OPA_REQUEST_TIMEOUT" )
105+                 timeout = config .get (
106+                     'AUTH_OPA_REQUEST_TIMEOUT' ,
107+                     self .AUTH_OPA_REQUEST_TIMEOUT_DEFAULT 
108+                 )
76109            )
77-             result  =  response .json ().get ("result" )
78-             return  result  ==  True 
110+             return  response .json ().get ('result' )
79111        except  Exception  as  e :
80-             self .log .error (f "Request to OPA failed" , exc_info = e )
112+             self .log .error ("Request to OPA failed" , exc_info = e )
81113            return  False 
82114
83115    @override  
@@ -98,7 +130,7 @@ def is_authorized_configuration(
98130            current user 
99131        """ 
100132
101-         self .log .info ( "Forward  is_authorized_configuration to OPA " )
133+         self .log .debug ( "Check  is_authorized_configuration" )
102134
103135        if  not  user :
104136            user  =  self .get_user ()
@@ -141,7 +173,7 @@ def is_authorized_connection(
141173            current user 
142174        """ 
143175
144-         self .log .info ( "Forward  is_authorized_connection to OPA " )
176+         self .log .debug ( "Check  is_authorized_connection" )
145177
146178        if  not  user :
147179            user  =  self .get_user ()
@@ -187,7 +219,7 @@ def is_authorized_dag(
187219            current user 
188220        """ 
189221
190-         self .log .info ( "Forward  is_authorized_dag to OPA " )
222+         self .log .debug ( "Check  is_authorized_dag" )
191223
192224        if  not  user :
193225            user  =  self .get_user ()
@@ -236,7 +268,7 @@ def is_authorized_dataset(
236268            current user 
237269        """ 
238270
239-         self .log .info ( "Forward  is_authorized_dataset to OPA " )
271+         self .log .debug ( "Check  is_authorized_dataset" )
240272
241273        if  not  user :
242274            user  =  self .get_user ()
@@ -279,7 +311,7 @@ def is_authorized_pool(
279311            current user 
280312        """ 
281313
282-         self .log .info ( "Forward  is_authorized_pool to OPA " )
314+         self .log .debug ( "Check  is_authorized_pool" )
283315
284316        if  not  user :
285317            user  =  self .get_user ()
@@ -322,7 +354,7 @@ def is_authorized_variable(
322354            current user 
323355        """ 
324356
325-         self .log .info ( "Forward  is_authorized_variable to OPA " )
357+         self .log .debug ( "Check  is_authorized_variable" )
326358
327359        if  not  user :
328360            user  =  self .get_user ()
@@ -363,7 +395,7 @@ def is_authorized_view(
363395            current user 
364396        """ 
365397
366-         self .log .info ( "Forward  is_authorized_view to OPA " )
398+         self .log .debug ( "Check  is_authorized_view" )
367399
368400        if  not  user :
369401            user  =  self .get_user ()
@@ -405,7 +437,7 @@ def is_authorized_custom_view(
405437            current user 
406438        """ 
407439
408-         self .log .info ( "Forward  is_authorized_custom_view to OPA " )
440+         self .log .debug ( "Check  is_authorized_custom_view" )
409441
410442        if  not  user :
411443            user  =  self .get_user ()
0 commit comments