Skip to content

Commit 42c9e28

Browse files
authored
Add search_multiple endpoints (#1258)
1 parent cbbd657 commit 42c9e28

File tree

8 files changed

+171
-0
lines changed

8 files changed

+171
-0
lines changed

core/web/apiv2/dfiq.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,17 @@ class DFIQSearchRequest(BaseModel):
5353
page: int = 0
5454

5555

56+
class DFIQGetMultipleRequest(BaseModel):
57+
model_config = ConfigDict(extra="forbid")
58+
59+
names: list[str] = []
60+
type: dfiq.DFIQType | None = None
61+
sorting: list[tuple[str, bool]] = []
62+
filter_aliases: list[tuple[str, str]] = []
63+
count: int = 50
64+
page: int = 0
65+
66+
5667
class DFIQSearchResponse(BaseModel):
5768
model_config = ConfigDict(extra="forbid")
5869

@@ -420,3 +431,22 @@ def search(httpreq: Request, request: DFIQSearchRequest) -> DFIQSearchResponse:
420431
user=httpreq.state.user,
421432
)
422433
return DFIQSearchResponse(dfiq=dfiq_objects, total=total)
434+
435+
436+
@router.post("/get/multiple")
437+
def get_multiple(
438+
httpreq: Request, request: DFIQGetMultipleRequest
439+
) -> DFIQSearchResponse:
440+
"""Gets multiple DFIQ objects by name."""
441+
query = {"name__in": request.names}
442+
if request.type:
443+
query["type"] = request.type
444+
dfiq_objects, total = dfiq.DFIQBase.filter(
445+
query_args=query,
446+
offset=request.page * request.count,
447+
count=request.count,
448+
sorting=request.sorting,
449+
aliases=request.filter_aliases,
450+
user=httpreq.state.user,
451+
)
452+
return DFIQSearchResponse(dfiq=dfiq_objects, total=total)

core/web/apiv2/entities.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,17 @@ class EntitySearchRequest(BaseModel):
3333
page: int = 0
3434

3535

36+
class EntityMultipleGetRequest(BaseModel):
37+
model_config = ConfigDict(extra="forbid")
38+
39+
names: list[str] = []
40+
type: EntityType | None = None
41+
sorting: list[tuple[str, bool]] = []
42+
filter_aliases: list[tuple[str, str]] = []
43+
count: int = 50
44+
page: int = 0
45+
46+
3647
class EntitySearchResponse(BaseModel):
3748
model_config = ConfigDict(extra="forbid")
3849

@@ -191,6 +202,26 @@ def search(httpreq: Request, request: EntitySearchRequest) -> EntitySearchRespon
191202
return response
192203

193204

205+
@router.post("/get/multiple")
206+
def get_multiple(
207+
httpreq: Request, request: EntityMultipleGetRequest
208+
) -> EntitySearchResponse:
209+
"""Gets multiple entities by name."""
210+
query = {"name__in": request.names}
211+
if request.type:
212+
query["type"] = request.type
213+
entities, total = Entity.filter(
214+
query_args=query,
215+
offset=request.page * request.count,
216+
count=request.count,
217+
sorting=request.sorting,
218+
aliases=request.filter_aliases,
219+
links_count=True,
220+
user=httpreq.state.user,
221+
)
222+
return EntitySearchResponse(entities=entities, total=total)
223+
224+
194225
@router.post("/tag")
195226
@rbac.permission_on_ids(roles.Permission.WRITE)
196227
def tag(httpreq: Request, request: EntityTagRequest) -> EntityTagResponse:

core/web/apiv2/indicators.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,17 @@ class IndicatorSearchRequest(BaseModel):
4343
page: int = 0
4444

4545

46+
class IndicatorMultipleGetRequest(BaseModel):
47+
model_config = ConfigDict(extra="forbid")
48+
49+
names: list[str] = []
50+
type: IndicatorType | None = None
51+
sorting: list[tuple[str, bool]] = []
52+
filter_aliases: list[tuple[str, str]] = []
53+
count: int = 50
54+
page: int = 0
55+
56+
4657
class IndicatorSearchResponse(BaseModel):
4758
model_config = ConfigDict(extra="forbid")
4859

@@ -237,6 +248,25 @@ def search(
237248
return IndicatorSearchResponse(indicators=indicators, total=total)
238249

239250

251+
@router.post("/get/multiple")
252+
def get_multiple(
253+
httpreq: Request, request: IndicatorMultipleGetRequest
254+
) -> IndicatorSearchResponse:
255+
"""Gets multiple indicators by name."""
256+
query = {"name__in": request.names}
257+
if request.type:
258+
query["type"] = request.type
259+
indicators, total = Indicator.filter(
260+
query_args=query,
261+
offset=request.page * request.count,
262+
count=request.count,
263+
sorting=request.sorting,
264+
aliases=request.filter_aliases,
265+
user=httpreq.state.user,
266+
)
267+
return IndicatorSearchResponse(indicators=indicators, total=total)
268+
269+
240270
@router.post("/tag")
241271
@rbac.permission_on_ids(roles.Permission.WRITE)
242272
def tag(httpreq: Request, request: IndicatorTagRequest) -> IndicatorTagResponse:

core/web/apiv2/tag.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,14 @@ class TagSearchRequest(BaseModel):
3333
page: int
3434

3535

36+
class TagMultipleGetRequest(BaseModel):
37+
model_config = ConfigDict(extra="forbid")
38+
39+
names: list[str] = []
40+
count: int
41+
page: int
42+
43+
3644
class TagSearchResponse(BaseModel):
3745
model_config = ConfigDict(extra="forbid")
3846

@@ -109,6 +117,16 @@ def search(request: TagSearchRequest) -> TagSearchResponse:
109117
return TagSearchResponse(tags=tags, total=total)
110118

111119

120+
@router.post("/get/multiple")
121+
def get_multiple(request: TagMultipleGetRequest) -> TagSearchResponse:
122+
"""Gets multiple Tags by name."""
123+
request_args = {"name__in": request.names}
124+
count = request.count
125+
page = request.page
126+
tags, total = Tag.filter(request_args, offset=page * count, count=count)
127+
return TagSearchResponse(tags=tags, total=total)
128+
129+
112130
@router.delete("/{tag_id}")
113131
def delete(tag_id: str) -> None:
114132
"""Deletes a Tag."""

tests/apiv2/dfiq.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -678,3 +678,27 @@ def test_to_archive(self):
678678
self.assertIn("semi_private_question", content)
679679
self.assertIn("public_approach", content)
680680
self.assertIn("internal_approach", content)
681+
682+
def test_get_multiple(self):
683+
with open("tests/dfiq_test_data/S1003.yaml", "r") as f:
684+
yaml_string = f.read()
685+
686+
response = client.post(
687+
"/api/v2/dfiq/from_yaml",
688+
json={
689+
"dfiq_yaml": yaml_string,
690+
"dfiq_type": dfiq.DFIQType.scenario,
691+
},
692+
)
693+
data = response.json()
694+
self.assertEqual(response.status_code, 200, data)
695+
696+
response = client.post(
697+
"/api/v2/dfiq/get/multiple",
698+
json={"names": ["scenario1"], "page": 0, "count": 10},
699+
)
700+
data = response.json()
701+
self.assertEqual(response.status_code, 200, data)
702+
self.assertEqual(len(data["dfiq"]), 1)
703+
self.assertEqual(data["total"], 1)
704+
self.assertEqual(data["dfiq"][0]["name"], "scenario1")

tests/apiv2/entities.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,17 @@ def test_search_entities(self):
110110
self.assertEqual(len(data["entities"][0]["tags"]), 1, data)
111111
self.assertIn("ta1", [tag["name"] for tag in data["entities"][0]["tags"]])
112112

113+
def test_get_multiple_entities(self):
114+
response = client.post(
115+
"/api/v2/entities/get/multiple",
116+
json={"names": ["ta1", "bears", "foo"], "page": 0, "count": 10},
117+
)
118+
data = response.json()
119+
self.assertEqual(response.status_code, 200, data)
120+
self.assertEqual(len(data["entities"]), 2)
121+
names = [e["name"] for e in data["entities"]]
122+
self.assertCountEqual(names, ["ta1", "bears"])
123+
113124
def test_search_entities_tagged(self):
114125
response = client.post(
115126
"/api/v2/entities/search",

tests/apiv2/indicators.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,17 @@ def test_search_indicators(self):
122122
self.assertEqual(len(data["indicators"][0]["tags"]), 1)
123123
self.assertIn("hextag", [tag["name"] for tag in data["indicators"][0]["tags"]])
124124

125+
def test_get_multiple_indicators(self):
126+
response = client.post(
127+
"/api/v2/indicators/get/multiple",
128+
json={"names": ["hex", "localhost", "foo"], "page": 0, "count": 10},
129+
)
130+
self.assertEqual(response.status_code, 200)
131+
data = response.json()
132+
self.assertEqual(len(data["indicators"]), 2)
133+
names = [indicator["name"] for indicator in data["indicators"]]
134+
self.assertCountEqual(names, ["hex", "localhost"])
135+
125136
def test_search_indicators_tagged(self):
126137
response = client.post(
127138
"/api/v2/indicators/search",

tests/apiv2/tags.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,22 @@ def test_tag_search(self):
7373
data = response.json()
7474
self.assertEqual(len(data), 2)
7575

76+
def test_tag_get_multiple(self):
77+
response = client.post("/api/v2/tags/", json={"name": "tag2-test"})
78+
response = client.post("/api/v2/tags/", json={"name": "tag3-test"})
79+
self.assertEqual(response.status_code, 200)
80+
81+
response = client.post(
82+
"/api/v2/tags/get/multiple",
83+
json={"names": ["tag1", "tag2-test"], "page": 0, "count": 10},
84+
)
85+
self.assertEqual(response.status_code, 200)
86+
data = response.json()
87+
self.assertEqual(len(data["tags"]), 2)
88+
self.assertEqual(data["total"], 2)
89+
self.assertEqual(data["tags"][0]["name"], "tag1")
90+
self.assertEqual(data["tags"][1]["name"], "tag2-test")
91+
7692
def test_tag_delete(self):
7793
response = client.delete(f"/api/v2/tags/{self.tag.id}")
7894
self.assertEqual(response.status_code, 200)

0 commit comments

Comments
 (0)