Skip to content

Commit 82f8652

Browse files
committed
add first unit tests
1 parent d4e985d commit 82f8652

File tree

1 file changed

+147
-12
lines changed

1 file changed

+147
-12
lines changed

superset/stackable/patches/4.0.2/001-opa-integration.patch

Lines changed: 147 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1-
diff --git a/superset/security/OpaSupersetSecurityManager.py b/superset/security/OpaSupersetSecurityManager.py
1+
diff --git a/superset/security/opa_manager.py b/superset/security/opa_manager.py
22
new file mode 100644
3-
index 0000000000..a435572d0e
3+
index 000000000..771377a2b
44
--- /dev/null
5-
+++ b/superset/security/OpaSupersetSecurityManager.py
6-
@@ -0,0 +1,73 @@
5+
+++ b/superset/security/opa_manager.py
6+
@@ -0,0 +1,75 @@
7+
+import logging
8+
+
79
+from typing import List, Optional, Tuple
810
+from http.client import HTTPException
911
+from opa_client.opa import OpaClient
@@ -15,7 +17,7 @@ index 0000000000..a435572d0e
1517
+ User,
1618
+)
1719
+
18-
+import logging
20+
+
1921
+class OpaSupersetSecurityManager(SupersetSecurityManager):
2022
+ def get_user_roles(self, user: Optional[User] = None) -> List[Role]:
2123
+ if not user:
@@ -27,10 +29,10 @@ index 0000000000..a435572d0e
2729
+ logging.info(f'OPA returned roles: {opa_role_names}')
2830
+
2931
+ opa_roles = set(map(self.resolve_role, opa_role_names))
30-
+ logging.info(f'found Roles in Database: {opa_roles}')
32+
+ logging.info(f'Resolved OPA Roles in Database: {opa_roles}')
3133
+ # Ensure that in case of a bad or no response from OPA each user will have at least one role.
32-
+ if opa_roles == {None} or opa_roles == []:
33-
+ opa_roles.add(default_role)
34+
+ if opa_roles == {None} or opa_roles == set():
35+
+ opa_roles = {default_role}
3436
+
3537
+ if set(user.roles) != opa_roles:
3638
+ logging.info(f'Found diff in {user.roles} vs. {opa_roles}')
@@ -47,7 +49,7 @@ index 0000000000..a435572d0e
4749
+ :returns: A list of role names or an empty list if an exception during the connection to OPA
4850
+ is encountered or if OPA didn't return a list.
4951
+ """
50-
+ host, port, tls = self.resolve_opa_endpoint()
52+
+ host, port, tls = self.resolve_opa_base_path()
5153
+ client = OpaClient(host = host, port = port, ssl = tls)
5254
+ try:
5355
+ response = client.query_rule(
@@ -65,9 +67,9 @@ index 0000000000..a435572d0e
6567
+ return roles
6668
+
6769
+
68-
+ def resolve_opa_endpoint(self) -> Tuple[str, int, bool]:
69-
+ opa_endpoint = current_app.config.get('STACKABLE_OPA_BASE_URL')
70-
+ [protocol, host, port] = opa_endpoint.split(":")
70+
+ def resolve_opa_base_path(self) -> Tuple[str, int, bool]:
71+
+ opa_base_path = current_app.config.get('STACKABLE_OPA_BASE_URL')
72+
+ [protocol, host, port] = opa_base_path.split(":")
7173
+ return host.lstrip('/'), int(port.rstrip('/')), protocol == 'https'
7274
+
7375
+
@@ -77,3 +79,136 @@ index 0000000000..a435572d0e
7779
+ logging.info(f'Creating role {role_name} as it doesn\'t already exist.')
7880
+ self.add_role(role_name)
7981
+ return self.find_role(role_name)
82+
diff --git a/tests/unit_tests/security/opa_manager_test.py b/tests/unit_tests/security/opa_manager_test.py
83+
new file mode 100644
84+
index 000000000..978d1ca1c
85+
--- /dev/null
86+
+++ b/tests/unit_tests/security/opa_manager_test.py
87+
@@ -0,0 +1,127 @@
88+
+# Licensed to the Apache Software Foundation (ASF) under one
89+
+# or more contributor license agreements. See the NOTICE file
90+
+# distributed with this work for additional information
91+
+# regarding copyright ownership. The ASF licenses this file
92+
+# to you under the Apache License, Version 2.0 (the
93+
+# "License"); you may not use this file except in compliance
94+
+# with the License. You may obtain a copy of the License at
95+
+#
96+
+# http://www.apache.org/licenses/LICENSE-2.0
97+
+#
98+
+# Unless required by applicable law or agreed to in writing,
99+
+# software distributed under the License is distributed on an
100+
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
101+
+# KIND, either express or implied. See the License for the
102+
+# specific language governing permissions and limitations
103+
+# under the License.
104+
+
105+
+# pylint: disable=invalid-name, unused-argument, redefined-outer-name
106+
+
107+
+import pytest
108+
+from flask_appbuilder.security.sqla.models import Role, User
109+
+from flask import current_app
110+
+from pytest_mock import MockFixture
111+
+
112+
+from superset.extensions import appbuilder
113+
+from superset.security.opa_manager import OpaSupersetSecurityManager
114+
+
115+
+
116+
+def test_opa_security_manager(app_context: None) -> None:
117+
+ """
118+
+ Test that the OPA security manager can be built.
119+
+ """
120+
+ sm = OpaSupersetSecurityManager(appbuilder)
121+
+ assert sm
122+
+
123+
+
124+
+@pytest.fixture
125+
+def user() -> User:
126+
+ """
127+
+ Return a user.
128+
+ """
129+
+ user = User()
130+
+ user.id = 1234
131+
+ user.first_name = 'mock'
132+
+ user.last_name = 'mock'
133+
+ user.username = 'mock'
134+
+ user.email = '[email protected]'
135+
+
136+
+ return user
137+
+
138+
+
139+
+def test_add_roles(
140+
+ mocker: MockFixture,
141+
+ app_context: None,
142+
+ user: User,
143+
+) -> None:
144+
+ """
145+
+ Test that roles are correctly added to a user.
146+
+ """
147+
+ sm = OpaSupersetSecurityManager(appbuilder)
148+
+ mocker.patch('flask_appbuilder.security.sqla.manager.SecurityManager.update_user', return_value = True)
149+
+
150+
+ opa_roles = ['Test1', 'Test2', 'Test3']
151+
+ mocker.patch('superset.security.opa_manager.OpaSupersetSecurityManager.get_opa_user_roles', return_value = opa_roles)
152+
+ assert set(sm.get_user_roles(user)) == set(map(sm.resolve_role, opa_roles))
153+
+
154+
+
155+
+def test_change_roles(
156+
+ mocker: MockFixture,
157+
+ app_context: None,
158+
+ user: User,
159+
+) -> None:
160+
+ """
161+
+ Test that roles are correcty changed on a user.
162+
+ """
163+
+ sm = OpaSupersetSecurityManager(appbuilder)
164+
+ mocker.patch('flask_appbuilder.security.sqla.manager.SecurityManager.update_user', return_value = True)
165+
+
166+
+ user_roles = ['Test1', 'Test2', 'Test3']
167+
+ opa_roles = ['Test4']
168+
+ user.roles = list(map(sm.resolve_role, user_roles))
169+
+ mocker.patch('superset.security.opa_manager.OpaSupersetSecurityManager.get_opa_user_roles', return_value = opa_roles)
170+
+ assert set(sm.get_user_roles(user)) == set(map(sm.resolve_role, opa_roles))
171+
+
172+
+
173+
+def test_no_roles(
174+
+ mocker: MockFixture,
175+
+ app_context: None,
176+
+ user: User,
177+
+) -> None:
178+
+ """
179+
+ Test that only the default role is assigned if OPA returns no roles.
180+
+ """
181+
+ sm = OpaSupersetSecurityManager(appbuilder)
182+
+ mocker.patch('flask_appbuilder.security.sqla.manager.SecurityManager.update_user', return_value = True)
183+
+
184+
+ opa_roles = []
185+
+ mocker.patch('superset.security.opa_manager.OpaSupersetSecurityManager.get_opa_user_roles', return_value = opa_roles)
186+
+ default_role = sm.resolve_role('Public')
187+
+ assert set(sm.get_user_roles(user)) == {default_role}
188+
+
189+
+
190+
+def test_roles_not_created(
191+
+ mocker: MockFixture,
192+
+ app_context: None,
193+
+ user: User,
194+
+) -> None:
195+
+ """
196+
+ Test that only the default role is assigned if a new role can't be created in the DB.
197+
+ """
198+
+ sm = OpaSupersetSecurityManager(appbuilder)
199+
+ mocker.patch('flask_appbuilder.security.sqla.manager.SecurityManager.update_user', return_value = True)
200+
+
201+
+ opa_roles = ['Test5', 'Test6']
202+
+ mocker.patch('superset.security.opa_manager.OpaSupersetSecurityManager.get_opa_user_roles', return_value = opa_roles)
203+
+ mocker.patch('superset.security.opa_manager.OpaSupersetSecurityManager.add_role', return_value = None)
204+
+ default_role = sm.resolve_role('Public')
205+
+ assert set(sm.get_user_roles(user)) == {default_role}
206+
+
207+
+
208+
+def test_resolve_opa_base_path(
209+
+ mocker: MockFixture,
210+
+ app_context: None,
211+
+) -> None:
212+
+ sm = OpaSupersetSecurityManager(appbuilder)
213+
+ mocker.patch('flask.current_app.config.get', return_value = 'http://opa-instance:8081')
214+
+ assert sm.resolve_opa_base_path() == ('opa-instance', 8081, False)

0 commit comments

Comments
 (0)