diff --git a/tests/api.py b/tests/api.py index 0080312..bd8835d 100644 --- a/tests/api.py +++ b/tests/api.py @@ -46,6 +46,19 @@ def test_search_indicators(self, mock_post): }, ) + @patch("yeti.api.requests.Session.post") + def test_get_multiple_indicators(self, mock_post): + mock_response = MagicMock() + mock_response.content = b'{"indicators": [{"name": "test"}]}' + mock_post.return_value = mock_response + + result = self.api.get_multiple_indicators(["test"]) + self.assertEqual(result, [{"name": "test"}]) + mock_post.assert_called_with( + "http://fake-url/api/v2/indicators/get/multiple", + json={"names": ["test"], "count": 100, "page": 0}, + ) + @patch("yeti.api.requests.Session.post") def test_search_entities(self, mock_post): mock_response = MagicMock() @@ -65,6 +78,19 @@ def test_search_entities(self, mock_post): }, ) + @patch("yeti.api.requests.Session.post") + def test_get_multiple_entities(self, mock_post): + mock_response = MagicMock() + mock_response.content = b'{"entities": [{"name": "test_entity"}]}' + mock_post.return_value = mock_response + + result = self.api.get_multiple_entities(["test_entity"]) + self.assertEqual(result, [{"name": "test_entity"}]) + mock_post.assert_called_with( + "http://fake-url/api/v2/entities/get/multiple", + json={"names": ["test_entity"], "count": 100, "page": 0}, + ) + @patch("yeti.api.requests.Session.post") def test_search_observables(self, mock_post): mock_response = MagicMock() @@ -159,6 +185,19 @@ def test_search_dfiq(self, mock_post): }, ) + @patch("yeti.api.requests.Session.post") + def test_get_multiple_dfiq(self, mock_post): + mock_response = MagicMock() + mock_response.content = b'{"dfiq": [{"name": "test_dfiq"}]}' + mock_post.return_value = mock_response + + result = self.api.get_multiple_dfiq(["test_dfiq"]) + self.assertEqual(result, [{"name": "test_dfiq"}]) + mock_post.assert_called_with( + "http://fake-url/api/v2/dfiq/get/multiple", + json={"names": ["test_dfiq"], "count": 100, "page": 0}, + ) + @patch("yeti.api.requests.Session.post") def test_new_dfiq_from_yaml(self, mock_post): mock_response = MagicMock() diff --git a/tests/e2e.py b/tests/e2e.py index 6cbfe84..6cd5eb6 100644 --- a/tests/e2e.py +++ b/tests/e2e.py @@ -44,6 +44,46 @@ def test_auth_refresh(self): self.api.search_indicators(name="test") + def test_search_entities(self): + self.api.auth_api_key(os.getenv("YETI_API_KEY")) + self.api.new_entity( + { + "name": "testSearch", + "type": "malware", + "description": "test", + }, + tags=["testtag"], + ) + time.sleep(5) + result = self.api.search_entities(name="testSear", description="tes") + self.assertEqual(len(result), 1, result) + self.assertEqual(result[0]["name"], "testSearch") + self.assertEqual(result[0]["tags"][0]["name"], "testtag") + + def test_get_multiple_entities(self): + self.api.auth_api_key(os.getenv("YETI_API_KEY")) + self.api.new_entity( + { + "name": "testGet1", + "type": "malware", + "description": "test", + }, + tags=["testtag1"], + ) + self.api.new_entity( + { + "name": "testGet2", + "type": "malware", + "description": "test", + }, + tags=["testtag2"], + ) + time.sleep(5) + entities = self.api.get_multiple_entities(["testGet1", "testGet2"]) + self.assertEqual(len(entities), 2) + names = [entity["name"] for entity in entities] + self.assertCountEqual(names, ["testGet1", "testGet2"]) + def test_search_indicators(self): self.api.auth_api_key(os.getenv("YETI_API_KEY")) self.api.new_indicator( @@ -83,6 +123,34 @@ def test_find_indicator(self): self.assertEqual(indicator["pattern"], "test[0-9]") self.assertEqual(indicator["tags"][0]["name"], "testtag") + def test_get_multiple_indicators(self): + self.api.auth_api_key(os.getenv("YETI_API_KEY")) + self.api.new_indicator( + { + "name": "testGet1", + "type": "regex", + "description": "test", + "pattern": "test[0-9]", + "diamond": "victim", + }, + tags=["testtag1"], + ) + self.api.new_indicator( + { + "name": "testGet2", + "type": "regex", + "description": "test", + "pattern": "test[0-9]", + "diamond": "victim", + }, + tags=["testtag2"], + ) + time.sleep(5) + indicators = self.api.get_multiple_indicators(["testGet1", "testGet2"]) + self.assertEqual(len(indicators), 2) + names = [indicator["name"] for indicator in indicators] + self.assertCountEqual(names, ["testGet1", "testGet2"]) + def test_link_objects(self): self.api.auth_api_key(os.getenv("YETI_API_KEY")) indicator = self.api.new_indicator( diff --git a/yeti/api.py b/yeti/api.py index e26b9e8..389a053 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -253,6 +253,25 @@ def search_indicators( ) return json.loads(response)["indicators"] + def get_multiple_indicators( + self, names: list[str], count: int = 100, page: int = 0 + ) -> list[YetiObject]: + """Gets a list of indicators by name. + + Args: + names: The list of indicator names to retrieve. + count: The number of results to return (default is 100). + page: The page of results to return (default is 0, which means the first page). + + Returns: + A list of dicts representing the indicators. + """ + params = {"names": names, "count": count, "page": page} + response = self.do_request( + "POST", f"{self._url_root}/api/v2/indicators/get/multiple", json_data=params + ) + return json.loads(response)["indicators"] + def find_entity(self, name: str, type: str) -> YetiObject | None: """Finds an entity in Yeti by name. @@ -316,6 +335,25 @@ def search_entities( ) return json.loads(response)["entities"] + def get_multiple_entities( + self, names: list[str], count: int = 100, page: int = 0 + ) -> list[YetiObject]: + """Gets a list of entities by name. + + Args: + names: The list of entity names to retrieve. + count: The number of results to return (default is 100). + page: The page of results to return (default is 0, which means the first page). + + Returns: + A list of dicts representing the entities. + """ + params = {"names": names, "count": count, "page": page} + response = self.do_request( + "POST", f"{self._url_root}/api/v2/entities/get/multiple", json_data=params + ) + return json.loads(response)["entities"] + def find_observable(self, value: str, type: str) -> YetiObject | None: """Finds an observable in Yeti by value and type. @@ -594,6 +632,25 @@ def search_dfiq( ) return json.loads(response)["dfiq"] + def get_multiple_dfiq( + self, names: list[str], count: int = 100, page: int = 0 + ) -> list[YetiObject]: + """Gets a list of DFIQ objects by name. + + Args: + names: The list of DFIQ names to retrieve. + count: The number of results to return (default is 100). + page: The page of results to return (default is 0, which means the first page). + + Returns: + A list of dicts representing the DFIQ objects. + """ + params = {"names": names, "count": count, "page": page} + response = self.do_request( + "POST", f"{self._url_root}/api/v2/dfiq/get/multiple", json_data=params + ) + return json.loads(response)["dfiq"] + def new_dfiq_from_yaml(self, dfiq_type: str, dfiq_yaml: str) -> YetiObject: """Creates a new DFIQ object in Yeti from a YAML string.""" params = { @@ -776,6 +833,25 @@ def search_tags(self, name: str, count: int = 100, page: int = 0): ) return json.loads(response)["tags"] + def get_multiple_tags( + self, names: list[str], count: int = 100, page: int = 0 + ) -> list[dict[str, Any]]: + """Gets a list of tags by name. + + Args: + names: The list of tag names to retrieve. + count: The number of results to return (default is 100). + page: The page of results to return (default is 0, which means the first page). + + Returns: + A list of dicts representing the tags. + """ + params = {"names": names, "count": count, "page": page} + response = self.do_request( + "POST", f"{self._url_root}/api/v2/tags/get/multiple", json_data=params + ) + return json.loads(response)["tags"] + def link_objects( self, source: YetiObject,