diff --git a/tests/api.py b/tests/api.py index 2fa8332..a44eaf6 100644 --- a/tests/api.py +++ b/tests/api.py @@ -29,11 +29,21 @@ def test_search_indicators(self, mock_post): mock_response.content = b'{"indicators": [{"name": "test"}]}' mock_post.return_value = mock_response - result = self.api.search_indicators(name="test") + result = self.api.search_indicators( + name="test", description="test_description", tags=["tag1"] + ) self.assertEqual(result, [{"name": "test"}]) mock_post.assert_called_with( "http://fake-url/api/v2/indicators/search", - json={"query": {"name": "test"}, "count": 0}, + json={ + "query": { + "name": "test", + "description": "test_description", + "tags": ["tag1"], + }, + "count": 100, + "page": 0, + }, ) @patch("yeti.api.requests.Session.post") @@ -42,11 +52,17 @@ def test_search_entities(self, mock_post): mock_response.content = b'{"entities": [{"name": "test_entity"}]}' mock_post.return_value = mock_response - result = self.api.search_entities(name="test_entity") + result = self.api.search_entities( + name="test_entity", description="test_description" + ) self.assertEqual(result, [{"name": "test_entity"}]) mock_post.assert_called_with( "http://fake-url/api/v2/entities/search", - json={"query": {"name": "test_entity"}, "count": 0}, + json={ + "query": {"name": "test_entity", "description": "test_description"}, + "count": 100, + "page": 0, + }, ) @patch("yeti.api.requests.Session.post") @@ -55,11 +71,15 @@ def test_search_observables(self, mock_post): mock_response.content = b'{"observables": [{"value": "test_value"}]}' mock_post.return_value = mock_response - result = self.api.search_observables(value="test_value") + result = self.api.search_observables(value="test_value", tags=["tag1"]) self.assertEqual(result, [{"value": "test_value"}]) mock_post.assert_called_with( "http://fake-url/api/v2/observables/search", - json={"query": {"value": "test_value"}, "count": 0}, + json={ + "query": {"value": "test_value", "tags": ["tag1"]}, + "count": 100, 
+ "page": 0, + }, ) @patch("yeti.api.requests.Session.post") @@ -121,11 +141,22 @@ def test_search_dfiq(self, mock_post): mock_response.content = b'{"dfiq": [{"name": "test_dfiq"}]}' mock_post.return_value = mock_response - result = self.api.search_dfiq(name="test_dfiq") + result = self.api.search_dfiq( + name="test_dfiq", dfiq_yaml="yaml_content", dfiq_tags=["tag1"] + ) self.assertEqual(result, [{"name": "test_dfiq"}]) mock_post.assert_called_with( "http://fake-url/api/v2/dfiq/search", - json={"query": {"name": "test_dfiq"}, "count": 0}, + json={ + "query": { + "name": "test_dfiq", + "dfiq_yaml": "yaml_content", + "dfiq_tags": ["tag1"], + }, + "count": 100, + "filter_aliases": [["dfiq_tags", "list"], ["dfiq_id", "text"]], + "page": 0, + }, ) @patch("yeti.api.requests.Session.post") @@ -275,14 +306,15 @@ def test_search_graph(self, mock_post): mock_response.content = b'{"graph": "data"}' mock_post.return_value = mock_response - result = self.api.search_graph("source", "graph", ["type"]) + result = self.api.search_graph("source", ["type"]) self.assertEqual(result, {"graph": "data"}) mock_post.assert_called_with( "http://fake-url/api/v2/graph/search", json={ - "count": 0, + "count": 50, + "page": 0, "source": "source", - "graph": "graph", + "graph": "links", "min_hops": 1, "max_hops": 1, "direction": "outbound", diff --git a/tests/e2e.py b/tests/e2e.py index 936ec7c..bac3ac4 100644 --- a/tests/e2e.py +++ b/tests/e2e.py @@ -48,7 +48,6 @@ def test_auth_refresh(self): self.api.search_indicators(name="test") def test_search_indicators(self): - self.api.auth_api_key(os.getenv("YETI_API_KEY")) self.api.auth_api_key(os.getenv("YETI_API_KEY")) self.api.new_indicator( { @@ -61,13 +60,14 @@ def test_search_indicators(self): tags=["testtag"], ) time.sleep(5) - result = self.api.search_indicators(name="testSear") + result = self.api.search_indicators( + name="testSear", description="tes", tags=["testtag"] + ) self.assertEqual(len(result), 1, result) 
self.assertEqual(result[0]["name"], "testSearch") self.assertEqual(result[0]["tags"][0]["name"], "testtag") def test_find_indicator(self): - self.api.auth_api_key(os.getenv("YETI_API_KEY")) self.api.auth_api_key(os.getenv("YETI_API_KEY")) self.api.new_indicator( { @@ -85,3 +85,39 @@ def test_find_indicator(self): self.assertEqual(indicator["name"], "testGet") self.assertEqual(indicator["pattern"], "test[0-9]") self.assertEqual(indicator["tags"][0]["name"], "testtag") + + def test_link_objects(self): + self.api.auth_api_key(os.getenv("YETI_API_KEY")) + indicator = self.api.new_indicator( + { + "name": "testLink", + "type": "regex", + "description": "test", + "pattern": "test[0-9]", + "diamond": "victim", + } + ) + malware = self.api.new_entity( + { + "name": "testMalware", + "type": "malware", + "description": "test", + } + ) + self.api.link_objects( + source=indicator, + target=malware, + link_type="indicates", + description="test link", + ) + + # get neighbors + neighbors = self.api.search_graph( + f'indicator/{indicator["id"]}', + target_types=["malware"], + include_original=False, + ) + self.assertEqual(len(neighbors["vertices"]), 1) + self.assertEqual( + neighbors["vertices"][f'entities/{malware["id"]}']["name"], "testMalware" + ) \ No newline at end of file diff --git a/yeti/api.py b/yeti/api.py index ba375ed..c879645 100644 --- a/yeti/api.py +++ b/yeti/api.py @@ -21,6 +21,38 @@ API_TOKEN_ENDPOINT = "/api/v2/auth/api-token" +SUPPORTED_IOC_TYPES = [ + "generic", + "ipv6", + "ipv4", + "hostname", + "url", + "file", + "sha256", + "md5", + "sha1", + "asn", + "wallet", + "certificate", + "cidr", + "mac_address", + "command_line", + "registry_key", + "imphash", + "tlsh", + "ssdeep", + "email", + "path", + "container_image", + "docker_image", + "user_agent", + "user_account", + "iban", + "bic", + "auth_secret", +] + + # typedef for a Yeti Objects YetiObject = dict[str, Any] YetiLinkObject = dict[str, Any] @@ -177,25 +209,29 @@ def search_indicators( name: str | None 
= None, indicator_type: str | None = None, pattern: str | None = None, + description: str | None = None, tags: list[str] | None = None, + count: int = 100, + page: int = 0, ) -> list[YetiObject]: """Searches for an indicator in Yeti. - One of name or pattern must be provided. + One of name, indicator_type, pattern, description, or tags must be provided. Args: name: The name of the indicator to search for. indicator_type: The type of the indicator to search for. pattern: The pattern of the indicator to search for. + description: The description of the indicator to search for. (substring match) tags: The tags of the indicator to search for. Returns: The response from the API; a list of dicts representing indicators. """ - if not any([name, indicator_type, pattern, tags]): + if not any([name, indicator_type, pattern, description, tags]): raise ValueError( - "You must provide one of name, indicator_type, pattern, or tags." + "You must provide one of name, indicator_type, pattern, description, or tags." ) query = {} @@ -203,11 +239,13 @@ def search_indicators( query["name"] = name if pattern: query["pattern"] = pattern + if description: + query["description"] = description if indicator_type: query["type"] = indicator_type if tags: query["tags"] = tags - params = {"query": query, "count": 0} + params = {"query": query, "count": count, "page": page} response = self.do_request( "POST", f"{self._url_root}/api/v2/indicators/search", @@ -237,8 +275,40 @@ def find_entity(self, name: str, type: str) -> YetiObject | None: raise return json.loads(response) - def search_entities(self, name: str) -> list[YetiObject]: - params = {"query": {"name": name}, "count": 0} + def search_entities( + self, + name: str | None = None, + entity_type: str | None = None, + description: str | None = None, + count: int = 100, + page: int = 0, + ) -> list[YetiObject]: + """Searches for entities in Yeti. + + One of name, type, or description must be provided. 
+ + Args: + name: The name of the entity to search for (substring match). + entity_type: The type of the entity to search for. + description: The description of the entity to search for. (substring match) + count: The number of results to return (default is 100). + page: The page of results to return (default is 0, which means the first page). + + Returns: + The response from the API; a list of dicts representing entities. + """ + if not any([name, entity_type, description]): + raise ValueError("You must provide one of name, type, or description.") + + query = {} + if name: + query["name"] = name + if entity_type: + query["type"] = entity_type + if description: + query["description"] = description + + params = {"query": query, "count": count, "page": page} response = self.do_request( "POST", f"{self._url_root}/api/v2/entities/search", @@ -268,8 +338,60 @@ def find_observable(self, value: str, type: str) -> YetiObject | None: raise return json.loads(response) - def search_observables(self, value: str) -> list[YetiObject]: - """Searches for an observable in Yeti. + def match_observables( + self, + observables: list[str], + add_tags: list[str] | None = None, + regex_match: bool = False, + add_type: str = "guess", + fetch_neighbors: bool = True, + add_unknown: bool = False, + ): + """Matches a list of observables against the Yeti data graph. + + This is a more complex method than `search_observables`, as it will + obtain information on entities related to the observables, matching + indicators, and bloom filter hits. + + Args: + observables: The list of observable values to match. + add_tags: Optional. The tags to add to the matched observables. + regex_match: Whether to use regex matching (default is False). + add_type: Optional. The type to add to the matched observables. + Default is "guess", which will try to guess the type based on the + observable value. + fetch_neighbors: Whether to fetch neighbors of the matched observables + (default is True). 
+ add_unknown: Whether to add unknown observables (default is False). + + Returns: + The response from the API; a dict with 'entities' (entities related + to the observables), 'observables' (with the relationship to their + entities), 'known' (list of known observables), 'matches' (for + observables that matched an indicator), and 'unknown' (set of + unknown observables). + """ + params = { + "observables": observables, + "add_tags": add_tags or [], + "regex_match": regex_match, + "add_type": add_type, + "fetch_neighbors": fetch_neighbors, + "add_unknown": add_unknown, + } + response = self.do_request( + "POST", f"{self._url_root}/api/v2/graph/match", json_data=params + ) + return json.loads(response) + + def search_observables( + self, + value: str, + count: int = 100, + page: int = 0, + tags: list[str] | None = None, + ) -> list[YetiObject]: + """Searches for observables in Yeti. Args: value: The value of the observable to search for. @@ -277,7 +399,11 @@ def search_observables(self, value: str) -> list[YetiObject]: Returns: The response from the API; a dict representing the observable. """ - params = {"query": {"value": value}, "count": 0} + query = {"value": value} + if tags: + query["tags"] = tags + params = {"query": query, "count": count, "page": page} + response = self.do_request( "POST", f"{self._url_root}/api/v2/observables/search", json_data=params ) @@ -423,21 +549,46 @@ def find_dfiq(self, name: str, dfiq_type: str) -> YetiObject | None: raise return json.loads(response) - def search_dfiq(self, name: str, dfiq_type: str | None = None) -> list[YetiObject]: + def search_dfiq( + self, + name: str, + dfiq_type: str | None = None, + dfiq_yaml: str | None = None, + dfiq_tags: list[str] | None = None, + count: int = 100, + page: int = 0, + ) -> list[YetiObject]: """Searches for a DFIQ in Yeti. Args: - name: The name of the DFIQ object to search for, e.g. "Suspicious DNS + name: The name of the DFIQ object to search for, e.g. "Suspicious DNS Query." 
- dfiq_type: The type of the DFIQ object to search for, e.g. "scenario". + dfiq_type: The type of the DFIQ object to search for, e.g. "scenario". + dfiq_yaml: The YAML content of the DFIQ object to search for. + dfiq_tags: The tags of the DFIQ object to search for. + count: The number of results to return (default is 100). + page: The page of results to return (default is 0, which means the first page). Returns: - The response from the API; a dict representing the DFIQ object. + The response from the API; a dict representing the DFIQ object. """ - query = {"name": name} + query = { + "name": name, + } + + if dfiq_yaml: + query["dfiq_yaml"] = dfiq_yaml + if dfiq_tags: + query["dfiq_tags"] = dfiq_tags + + params = { + "query": query, + "count": count, + "page": page, + "filter_aliases": [["dfiq_tags", "list"], ["dfiq_id", "text"]], + } if dfiq_type: - query["type"] = dfiq_type - params = {"query": query, "count": 0} + params["type"] = dfiq_type response = self.do_request( "POST", f"{self._url_root}/api/v2/dfiq/search", json_data=params ) @@ -556,7 +707,7 @@ def add_observables_bulk( Args: observables: The list of observables to add. Dictionaries should have a - 'value' (str) and a 'type' (str) key. See TACO_TYPE_MAPPING for a list + 'value' (str) and a 'type' (str) key. See SUPPORTED_IOC_TYPES for a list of supported types. tags: The tags to associate with all observables. @@ -623,12 +774,14 @@ def link_objects( def search_graph( self, source: str, - graph: str, target_types: list[str], + graph: str = "links", min_hops: int = 1, max_hops: int = 1, direction: str = "outbound", include_original: bool = True, + count: int = 50, + page: int = 0, ) -> dict[str, Any]: """Searches the graph for objects related to a given object. @@ -637,20 +790,23 @@ def search_graph( for details. Args: - source: The ID of the source object (as provided by Yeti) in the format - "/", such as 'dfiq/id'. - graph: The graph to search, such as 'links'. 
- target_types: The types of objects to search for. - min_hops: The minimum number of hops to search. - max_hops: The maximum number of hops to search. - direction: The direction to search. - include_original: Whether to include the source object in the results. + source: The ID of the source object (as provided by Yeti) in the format + "/", such as 'dfiq/12345'. + target_types: The types of objects to search for. + min_hops: The minimum number of hops to search. + max_hops: The maximum number of hops to search. + direction: The direction to search. "inbound" or "outbound" or "both". + include_original: Whether to include the source object in the results. + count: The number of results to return (default is 50). + page: The page of results to return (default is 0, which means the first page). Returns: - The response from the API; a dict representing the graph. + The response from the API; a dict representing the graph. If the number + of results is lower than the count, the search is complete. """ params = { - "count": 0, + "count": count, + "page": page, "source": source, "graph": graph, "min_hops": min_hops,