Skip to content
Open
20 changes: 11 additions & 9 deletions core/database_arango.py
Original file line number Diff line number Diff line change
Expand Up @@ -1190,17 +1190,19 @@ def filter(
aql_args["offset"] = offset
aql_args["count"] = count

# TODO: Interpolate this query
graph_query_string = ""
for name, graph, direction, field in graph_queries:
field_aggregation = "||".join([f"v.{field}" for field in field.split("|")])
graph_query_string += f"\nLET {name} = (FOR v, e in 1..1 {direction} o {graph} RETURN {{ [{field_aggregation}]: e }})"

acl_query = ""
neighbor_acl_filter = ""
if user and RBAC_ENABLED and not user.admin:
acl_query = "LET acl = FIRST(FOR v, e, p in 1..2 inbound o acls FILTER v.username == @username RETURN true) or false"
neighbor_acl_filter = "LET neighbor_acl = FIRST(FOR aclv IN 1..2 INBOUND v acls FILTER aclv.username == @username RETURN true) OR false\n FILTER neighbor_acl"
aql_args["username"] = user.username

# TODO: Interpolate this query
graph_query_string = ""
for name, graph, direction, field in graph_queries:
field_aggregation = "||".join([f"v.{field}" for field in field.split("|")])
graph_query_string += f"\nLET {name} = (FOR v, e in 1..1 {direction} o {graph} \n {neighbor_acl_filter}\n RETURN {{ [{field_aggregation}]: e }})"

filter_string = ""
if filter_conditions:
filter_string = f"FILTER {' AND '.join(filter_conditions)}"
Expand All @@ -1222,11 +1224,11 @@ def filter(
aql_string = f"""
FOR o IN @@collection
{aql_search}
{links_count_query}
{graph_query_string}
{acl_query}
{filter_string}
{acl_query}
{acl_filter}
{links_count_query}
{graph_query_string}
{aql_sort}
{limit}
"""
Expand Down
16 changes: 16 additions & 0 deletions test_query.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
user_admin = False
RBAC_ENABLED = True
user = True

neighbor_acl_filter = ""
if user and RBAC_ENABLED and not user_admin:
neighbor_acl_filter = "FILTER FIRST(FOR aclv IN 1..2 INBOUND v acls FILTER aclv.username == @username RETURN true) OR false"

graph_queries = [("links", "links", "inbound", "name")]

graph_query_string = ""
for name, graph, direction, field in graph_queries:
field_aggregation = "||".join([f"v.{field}" for field in field.split("|")])
graph_query_string += f"\nLET {name} = (FOR v, e in 1..1 {direction} o {graph} {neighbor_acl_filter} RETURN {{ [{field_aggregation}]: e }})"

print(graph_query_string)
28 changes: 28 additions & 0 deletions tests/schemas/rbac.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,31 @@ def test_get_acls(self):
self.assertEqual(len(self.entity1._acls), 2)
self.assertIn(self.group1.name, self.entity1._acls)
self.assertIn(self.user1.username, self.entity1._acls)

def test_filter_entities_with_graph_queries_respects_acls(self):
"""Test that filter() with graph_queries takes user ACLs on linked objects into account"""
self.user1.link_to_acl(self.entity1, roles.Role.READER)

graph_queries = [("links", "links", "inbound", "name|value")]

# User 1 has access to entity1, but NO access to observable1 (which is linked to entity1)
entities, total = entity.Entity.filter(
{"name": "malware1"}, user=self.user1, graph_queries=graph_queries
)
self.assertEqual(len(entities), 1)
# Without access, the 'links' attribute should either not be present or be empty
if hasattr(entities[0], "links"):
self.assertEqual(len(entities[0].links), 0)

# Grant access to observable1
self.user1.link_to_acl(self.observable1, roles.Role.READER)

entities, total = entity.Entity.filter(
{"name": "malware1"}, user=self.user1, graph_queries=graph_queries
)
self.assertEqual(len(entities), 1)
# With access, the 'links' attribute should be present and populated
self.assertTrue(hasattr(entities[0], "links"))
self.assertIsInstance(entities[0].links, dict)
self.assertEqual(len(entities[0].links), 1)
self.assertIn("test.com", entities[0].links)
Loading