Skip to content

Commit cbbd657

Browse files
authored
Global search (#1257)
1 parent de98400 commit cbbd657

File tree

4 files changed

+233
-12
lines changed

4 files changed

+233
-12
lines changed

core/database_arango.py

Lines changed: 54 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -180,8 +180,7 @@ def create_analyzers(self):
180180
self.db.create_analyzer(
181181
name="norm",
182182
analyzer_type="norm",
183-
properties={"locale": "en", "accent": False, "case": "lower"},
184-
features=[],
183+
properties={"locale": "en.utf-8", "accent": False, "case": "lower"},
185184
)
186185

187186
def refresh_views(self):
@@ -306,7 +305,6 @@ def create_views(self):
306305
link_definitions[view_target] = {
307306
"analyzers": ["identity", "norm"],
308307
"includeAllFields": True,
309-
"storedValues": [{"fields": ["name", "tags", "type"]}],
310308
"trackListPositions": False,
311309
}
312310

@@ -333,6 +331,20 @@ def create_views(self):
333331
except Exception:
334332
pass
335333

334+
for target in link_definitions:
335+
del link_definitions[target]["analyzers"]
336+
link_definitions[target]["analyzers"] = []
337+
link_definitions[target]["includeAllFields"] = False
338+
link_definitions[target]["fields"] = {
339+
"tags": {"fields": {"name": {"analyzers": ["identity", "norm"]}}},
340+
"dfiq_tags": {"analyzers": ["identity", "norm"]},
341+
"type": {"analyzers": ["identity", "norm"]},
342+
"root_type": {"analyzers": ["identity", "norm"]},
343+
"value": {"analyzers": ["identity", "norm"]},
344+
"name": {"analyzers": ["identity", "norm"]},
345+
"created": {"analyzers": ["identity", "norm"]},
346+
}
347+
336348
self.db.create_arangosearch_view(
337349
name="all_objects_view",
338350
properties={
@@ -343,6 +355,7 @@ def create_views(self):
343355
{"field": "created", "direction": "desc"},
344356
{"field": "value", "direction": "asc"},
345357
{"field": "name", "direction": "asc"},
358+
{"field": "tags.name", "direction": "asc"},
346359
],
347360
},
348361
)
@@ -439,6 +452,7 @@ class ArangoYetiConnector(AbstractYetiConnector):
439452
"""Yeti connector for an ArangoDB backend."""
440453

441454
_db = db
455+
_collection_name: str | None = None
442456

443457
def __init__(self):
444458
self._arango_id = None
@@ -1016,6 +1030,7 @@ def filter(
10161030
offset: Skip this many objects when querying the DB.
10171031
count: How many objects after `offset` to return.
10181032
sorting: A list of (order, ascending) fields to sort by.
1033+
aliases: A list of (alias, type) tuples to use for filtering.
10191034
graph_queries: A list of (name, graph, direction, field) tuples to
10201035
query the graph with.
10211036
wildcard: whether all values should be interpreted as wildcard searches.
@@ -1026,11 +1041,19 @@ def filter(
10261041
"""
10271042
cls._get_collection()
10281043
colname = cls._collection_name
1044+
if colname is None:
1045+
colname = "all_objects_view"
10291046
conditions = []
10301047
filter_conditions = [] # used for clauses that are not supported by arangosearch
10311048
sorts = []
10321049

10331050
using_view = False
1051+
generic_query = False
1052+
1053+
if colname == "all_objects_view":
1054+
generic_query = True
1055+
using_view = True
1056+
10341057
if (
10351058
query_args
10361059
and colname in ("observables", "entities", "indicators", "dfiq")
@@ -1085,9 +1108,13 @@ def filter(
10851108
aql_args[f"arg{i}_key"] = key
10861109
elif key == "tags":
10871110
if using_view:
1088-
conditions.append(f"@arg{i}_value ALL IN o.tags.name")
1111+
conditions.append(
1112+
f"(FOR t in @arg{i}_value RETURN LOWER(t)) ALL IN o.tags.name"
1113+
)
10891114
else:
1090-
conditions.append(f"@arg{i}_value ALL IN o.tags[*].name")
1115+
conditions.append(
1116+
f"(FOR t in @arg{i}_value RETURN LOWER(t)) ALL IN o.tags[*].name"
1117+
)
10911118
elif key in ("created", "modified", "tags.expires"):
10921119
# Value is a string, we're checking the first character.
10931120
operator = value[0]
@@ -1114,11 +1141,16 @@ def filter(
11141141
key_conditions = [f"REGEX_TEST(o.@arg{i}_key, @arg{i}_value, true)"]
11151142

11161143
for alias, alias_type in aliases:
1117-
if (
1118-
alias_type in {"text", "option"}
1119-
or alias_type == "list"
1120-
and using_view
1121-
):
1144+
if alias == "tags":
1145+
if using_view:
1146+
key_conditions.append(
1147+
f"ANALYZER(LIKE(o.tags.name, LOWER(@arg{i}_value)), 'norm')"
1148+
)
1149+
else:
1150+
key_conditions.append(
1151+
f"LOWER(@arg{i}_value) IN o.tags[*].name"
1152+
)
1153+
if alias_type in {"text", "option", "list"} and using_view:
11221154
if using_view and not using_regex:
11231155
key_conditions.append(
11241156
f"ANALYZER(LIKE(o.{alias}, LOWER(@arg{i}_value)), 'norm')"
@@ -1232,7 +1264,14 @@ def filter(
12321264
results = []
12331265
for doc in documents:
12341266
doc["__id"] = doc.pop("_key")
1235-
results.append(cls.load(doc))
1267+
if not generic_query:
1268+
results.append(cls.load(doc))
1269+
else:
1270+
# Generic objects are not loaded, they are returned as dicts.
1271+
doc["id"] = doc.pop("__id")
1272+
del doc["_id"]
1273+
del doc["_rev"]
1274+
results.append(doc)
12361275
total = stats.get("fullCount", len(results))
12371276
return results, total or 0
12381277

@@ -1321,7 +1360,10 @@ def _get_collection(cls):
13211360
Returns:
13221361
The ArangoDB collection corresponding to the object class.
13231362
"""
1324-
return cls._db.collection(cls._collection_name)
1363+
if cls._collection_name is not None:
1364+
return cls._db.collection(cls._collection_name)
1365+
else:
1366+
return "all_objects_view"
13251367

13261368

13271369
def tagged_observables_export(cls, args):

core/web/apiv2/search.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from fastapi import APIRouter, Request
2+
from pydantic import BaseModel, ConfigDict
3+
4+
from core.database_arango import ArangoYetiConnector
5+
6+
# API endpoints
7+
router = APIRouter()
8+
9+
10+
class SearchRequest(BaseModel):
    """Request body accepted by the search endpoint.

    Unknown keys are rejected (`extra="forbid"`) rather than silently ignored.
    """

    model_config = ConfigDict(extra="forbid")

    # Field name -> value to match; forwarded as-is to the connector's filter().
    query: dict[str, str | int | list] = {}
    # (field, ascending) pairs forwarded as the `sorting` argument.
    sorting: list[tuple[str, bool]] = []
    # (alias, type) pairs forwarded as the `aliases` argument.
    filter_aliases: list[tuple[str, str]] = []
    # Page size: maximum number of results returned per request.
    count: int = 50
    # Zero-based page index; the endpoint computes the offset as page * count.
    page: int = 0
20+
21+
22+
class SearchResponse(BaseModel):
    """Response body returned by the search endpoint."""

    # Matching documents, returned as plain dicts rather than typed models.
    results: list[dict]
    # Total reported by the connector's filter(); not limited to this page's
    # length when the backend reports a full count.
    total: int = 0
27+
28+
29+
@router.post("/")
def search(httpreq: Request, request: SearchRequest) -> SearchResponse:
    """Run a search through the generic `ArangoYetiConnector.filter`.

    Args:
        httpreq: The incoming HTTP request; `httpreq.state.user` is passed to
            the connector as the requesting user.
        request: The validated search parameters (query, sorting, aliases and
            pagination).

    Returns:
        A SearchResponse holding the requested page of results and the total
        reported by the connector.
    """
    # Pagination is page-based in the API but offset-based in the connector.
    offset = request.page * request.count
    results, total = ArangoYetiConnector.filter(
        request.query,
        sorting=request.sorting,
        aliases=request.filter_aliases,
        count=request.count,
        offset=offset,
        links_count=True,
        user=httpreq.state.user,
    )
    return SearchResponse(results=results, total=total)

core/web/webapp.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
indicators,
2020
observables,
2121
rbac,
22+
search,
2223
system,
2324
tag,
2425
tasks,
@@ -40,6 +41,13 @@
4041

4142
api_router.include_router(audit.router, prefix="/audit", tags=["audit"])
4243

44+
api_router.include_router(
45+
search.router,
46+
prefix="/search",
47+
tags=["search"],
48+
dependencies=[Depends(auth.get_current_active_user)],
49+
)
50+
4351
api_router.include_router(
4452
observables.router,
4553
prefix="/observables",

tests/apiv2/search.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
import logging
2+
import sys
3+
import time
4+
import unittest
5+
6+
from fastapi.testclient import TestClient
7+
8+
from core import database_arango
9+
from core.schemas import dfiq, entity, indicator, observable
10+
from core.schemas.user import UserSensitive
11+
from core.web import webapp
12+
13+
client = TestClient(webapp.app)
14+
15+
16+
class searchTest(unittest.TestCase):
    """End-to-end tests for the `/api/v2/search` endpoint."""

    def setUp(self) -> None:
        """Reset the test database, authenticate, and create fixture objects."""
        logging.disable(sys.maxsize)
        database_arango.db.connect(database="yeti_test")
        database_arango.db.truncate()

        user = UserSensitive(username="test", password="test", enabled=True).save()
        apikey = user.create_api_key("default")
        token_data = client.post(
            "/api/v2/auth/api-token", headers={"x-yeti-apikey": apikey}
        ).json()
        client.headers = {"Authorization": "Bearer " + token_data["access_token"]}

        entity.Malware(
            name="test_malware",
            description="Test malware entity",
            type="malware",
        ).save()
        m2 = entity.Malware(
            name="tagged_malware",
            description="malware entity 2",
            type="malware",
        ).save()
        m2.tag("tagged")
        m2.tag("global")

        r = indicator.Regex(
            name="test_regex_global",
            description="Test regex indicator",
            type="regex",
            pattern="^test.*",
            diamond="victim",
        ).save()
        o = observable.Hostname(
            description="Test hostname observable",
            type="hostname",
            value="test.tomchop.me",
        ).save()
        dfiq.DFIQScenario(
            name="test_dfiq",
            description="Test DFIQ",
            dfiq_tags=["tagged", "global"],
            dfiq_version="1.0.1",
            dfiq_yaml="name: test_dfiq\nversion: 1.0.1\ndescription: Test DFIQ",
        ).save()
        r.tag("tagged")
        o.tag(["tagged", "global"])
        # NOTE(review): presumably gives the search views time to index the
        # objects saved above before the tests query them -- confirm, and
        # consider an explicit view refresh instead of a fixed delay.
        time.sleep(5)

    def _search(self, params: dict, expected_count: int) -> tuple[list[str], dict]:
        """POST `params` to the search endpoint and validate the response.

        Args:
            params: JSON body to send.
            expected_count: Number of results the response must contain.

        Returns:
            A tuple of (name or value of each result, full decoded response).
        """
        response = client.post("/api/v2/search", json=params)
        self.assertEqual(response.status_code, 200, response.text)
        data = response.json()
        self.assertIn("results", data)
        self.assertEqual(len(data["results"]), expected_count, data)
        # Results are plain dicts; fall back to "value" when "name" is absent.
        names_or_values = [
            r["name"] if "name" in r else r["value"] for r in data["results"]
        ]
        return names_or_values, data

    def test_search_name_or_value(self) -> None:
        """Test global search by name or value."""
        params = {"query": {"name": "test"}, "filter_aliases": [["value", "text"]]}
        names_or_values, data = self._search(params, expected_count=4)
        self.assertCountEqual(
            names_or_values,
            ["test_malware", "test_regex_global", "test.tomchop.me", "test_dfiq"],
            data,
        )

    def test_search_tag(self) -> None:
        """Test global search by tag."""
        params = {
            "query": {"tags": ["tagged"]},
        }
        names_or_values, data = self._search(params, expected_count=3)
        self.assertCountEqual(
            names_or_values,
            [
                "test_regex_global",
                "test.tomchop.me",
                "tagged_malware",
            ],
            data,
        )

    def test_search_more_fields(self) -> None:
        """Test global search with more fields."""
        params = {
            "query": {"name": "global"},
            "filter_aliases": [
                ["value", "text"],
                ["tags", ""],
                ["dfiq_tags", "list"],
            ],
        }
        names_or_values, data = self._search(params, expected_count=4)
        self.assertCountEqual(
            names_or_values,
            ["test_regex_global", "test.tomchop.me", "tagged_malware", "test_dfiq"],
            data,
        )

0 commit comments

Comments
 (0)