Skip to content

Commit f02645e

Browse files
authored
Merge pull request #1112 from stratosphereips/alya/use-lru-cache
Optimize MAC OUIs and malicious domains lookups
2 parents 9bb16a5 + 34ecab9 commit f02645e

File tree

5 files changed

+143
-62
lines changed

5 files changed

+143
-62
lines changed

modules/ip_info/ip_info.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -198,30 +198,33 @@ def get_vendor_online(self, mac_addr):
198198
):
199199
return False
200200

201+
@staticmethod
201202
@lru_cache(maxsize=700)
203+
def _get_vendor_offline_cached(oui, mac_db_content):
204+
"""
205+
Static helper to perform the actual lookup based on OUI and cached content.
206+
"""
207+
for line in mac_db_content:
208+
if oui in line:
209+
line = json.loads(line)
210+
return line["vendorName"]
211+
return False
212+
202213
def get_vendor_offline(self, mac_addr, profileid):
    """
    Gets vendor from Slips' offline database at databases/macaddr-db.json.

    If the mac db isn't loaded yet, the query is queued in
    self.pending_mac_queries so the update manager can answer it later,
    and False is returned.
    """
    mac_db = getattr(self, "mac_db", None)
    if mac_db is None:
        # when update manager is done updating the mac db, we should ask
        # the db for all these pending queries
        self.pending_mac_queries.put((mac_addr, profileid))
        return False

    # the first 8 chars of a MAC address are its OUI (vendor prefix)
    oui = mac_addr[:8].upper()
    mac_db.seek(0)
    # tuple() makes the file content hashable so the lru_cache'd
    # helper can use it as part of its cache key
    return self._get_vendor_offline_cached(oui, tuple(mac_db.readlines()))
225228

226229
def get_vendor(self, mac_addr: str, profileid: str) -> dict:
227230
"""

modules/threat_intelligence/threat_intelligence.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1429,16 +1429,6 @@ def is_malicious_url(self, url, uid, timestamp, daddr, profileid, twid):
14291429
"""Determines if a URL is considered malicious by querying online threat
14301430
intelligence sources.
14311431
1432-
Parameters:
1433-
- url (str): The URL to check.
1434-
- uid (str): Unique identifier for the network flow.
1435-
- timestamp (str): Timestamp when the network flow occurred.
1436-
- daddr (str): Destination IP address in the network flow.
1437-
- profileid (str): Identifier of the profile associated
1438-
with the network flow.
1439-
- twid (str): Time window identifier for when the network
1440-
flow occurred.
1441-
14421432
Returns:
14431433
- None: The function does not return a value but triggers
14441434
evidence creation if the URL is found to be malicious.
@@ -1633,18 +1623,6 @@ def is_malicious_domain(
16331623
malicious, it records an evidence entry and marks the
16341624
domain in the database.
16351625
1636-
Parameters:
1637-
domain (str): The domain name to be evaluated for
1638-
malicious activity.
1639-
uid (str): Unique identifier of the network flow
1640-
associated with this domain query.
1641-
timestamp (str): Timestamp when the domain query
1642-
was observed.
1643-
profileid (str): Identifier of the network profile
1644-
that initiated the domain query.
1645-
twid (str): Time window identifier during which the
1646-
domain query occurred.
1647-
16481626
Returns:
16491627
bool: False if the domain is ignored or not found in the
16501628
offline threat intelligence data, indicating no further action
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from collections import defaultdict
from typing import (
    Dict,
    Tuple,
    Optional,
)


class TrieNode:
    """One node of the domain trie; children are keyed by domain label."""

    def __init__(self):
        # missing children are created on demand as fresh nodes
        self.children = defaultdict(TrieNode)
        # True when a stored domain terminates at this node
        self.is_end_of_word = False
        # info attached to the domain that ends here, if any
        self.domain_info = None


class Trie:
    """Reversed domain-parts trie for exact and parent-domain lookups."""

    def __init__(self):
        self.root = TrieNode()

    def insert(self, domain: str, domain_info: str):
        """Insert a domain into the trie (using domain parts not chars)."""
        current = self.root
        # walk labels right-to-left so e.g. "com" -> "google" nodes are
        # shared between google.com and mail.google.com
        for label in reversed(domain.split(".")):
            current = current.children[label]
        current.is_end_of_word = True
        current.domain_info = domain_info

    def search(self, domain: str) -> Tuple[bool, Optional[Dict[str, str]]]:
        """
        Check if this domain, or a parent domain of it, was inserted
        (using domain parts instead of characters).
        Returns a tuple (found, domain_info).
        """
        current = self.root
        # reversed labels mirror the insertion order
        for label in reversed(domain.split(".")):
            child = current.children.get(label)
            if child is None:
                return False, None
            if child.is_end_of_word:
                # a stored domain is a suffix of the queried one
                return True, child.domain_info
            current = child
        return False, None

slips_files/core/database/redis_db/database.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,8 @@ def __new__(
131131
)
132132

133133
cls._instances[cls.redis_port] = super().__new__(cls)
134+
super().__init__(cls)
135+
134136
# By default the slips internal time is
135137
# 0 until we receive something
136138
cls.set_slips_internal_time(0)

slips_files/core/database/redis_db/ioc_handler.py

Lines changed: 79 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,12 @@
77
Optional,
88
)
99

10+
from slips_files.common.data_structures.trie import Trie
11+
12+
# for future developers, remember to call _invalidate_trie_cache() on every
# change to the self.constants.IOC_DOMAINS key, or slips will keep using an
# invalid cache to look up malicious domains
15+
1016

1117
class IoCHandler:
1218
"""
@@ -17,6 +23,38 @@ class IoCHandler:
1723

1824
name = "DB"
1925

26+
def __init__(self):
    """Initialize the lazily-built cache used for faster domain lookups."""
    # no trie yet; it is built on the first subdomain lookup and
    # dropped whenever the IOC_DOMAINS key changes
    self.is_trie_cached = False
    self.trie = None
30+
31+
def _build_trie(self):
    """Retrieve domains from Redis and construct the trie."""
    trie = Trie()
    # one hgetall here instead of one redis roundtrip per lookup later
    ioc_domains: Dict[str, str] = self.rcache.hgetall(
        self.constants.IOC_DOMAINS
    )
    for domain, serialized_info in ioc_domains.items():
        # serialized_info is a json string like:
        # {"description": "['hack''malware''phishing']",
        #  "source": "OCD-Datalake-russia-ukraine_IOCs-ALL.csv",
        #  "threat_level": "medium",
        #  "tags": ["Russia-UkraineIoCs"]}
        trie.insert(domain, json.loads(serialized_info))
    self.trie = trie
    self.is_trie_cached = True
49+
50+
def _invalidate_trie_cache(self):
    """
    Drop the cached trie.
    Called whenever the IOC_DOMAINS key is updated, so lookups never
    serve stale blacklist data; the next lookup rebuilds the trie.
    """
    self.is_trie_cached = False
    self.trie = None
57+
2058
def set_loaded_ti_files(self, number_of_loaded_files: int):
2159
"""
2260
Stores the number of successfully loaded TI files
@@ -43,6 +81,7 @@ def delete_feed_entries(self, url: str):
4381
if feed_to_delete in domain_description["source"]:
4482
# this entry has the given feed as source, delete it
4583
self.rcache.hdel(self.constants.IOC_DOMAINS, domain)
84+
self._invalidate_trie_cache()
4685

4786
# get all IPs that are read from TI files in our db
4887
ioc_ips = self.rcache.hgetall(self.constants.IOC_IPS)
@@ -139,6 +178,7 @@ def delete_domains_from_ioc_domains(self, domains: List[str]):
139178
Delete old domains from IoC
140179
"""
141180
self.rcache.hdel(self.constants.IOC_DOMAINS, *domains)
181+
self._invalidate_trie_cache()
142182

143183
def add_ips_to_ioc(self, ips_and_description: Dict[str, str]) -> None:
144184
"""
@@ -164,6 +204,7 @@ def add_domains_to_ioc(self, domains_and_description: dict) -> None:
164204
self.rcache.hmset(
165205
self.constants.IOC_DOMAINS, domains_and_description
166206
)
207+
self._invalidate_trie_cache()
167208

168209
def add_ip_range_to_ioc(self, malicious_ip_ranges: dict) -> None:
169210
"""
@@ -239,43 +280,53 @@ def is_blacklisted_ssl(self, sha1):
239280
info = self.rcache.hmget(self.constants.IOC_SSL, sha1)[0]
240281
return False if info is None else info
241282

283+
def _match_exact_domain(self, domain: str) -> Optional[Dict[str, str]]:
    """Return blacklist info for exactly this domain, if blacklisted.
    Only the exact given domain is checked, no subdomains."""
    serialized = self.rcache.hget(self.constants.IOC_DOMAINS, domain)
    return json.loads(serialized) if serialized else None
292+
293+
def _match_subdomain(self, domain: str) -> Optional[Dict[str, str]]:
    """
    Check if any blacklisted domain is a part of the given domain.
    Uses a cached trie for optimization.
    """
    # Rather than pulling the huge IOC_DOMAINS hash from the db on
    # every single domain lookup, we fetch it once, cache it in an
    # in-memory trie, and answer lookups from there. Whenever a
    # domain is added/removed in the db the cache is invalidated and
    # the next lookup rebuilds the trie.
    if not self.is_trie_cached:
        self._build_trie()

    found, domain_info = self.trie.search(domain)
    return domain_info if found else None
311+
242312
def is_blacklisted_domain(
    self, domain: str
) -> Tuple[Union[Dict[str, str], bool], bool]:
    """
    Check if the given domain or its subdomain is blacklisted.

    Returns a tuple (description, is_subdomain):
    - description: dict describing the matched (sub)domain, or False
      when no match was found.
    - is_subdomain: True when a parent domain of the given one
      matched, False when the exact domain matched (or nothing did).

    Note: this function always returns a 2-tuple; the previous
    annotation claiming a bare bool could be returned was wrong
    (the no-match path returns ``(False, False)``).
    """
    # an exact match takes precedence over a parent-domain match
    if match := self._match_exact_domain(domain):
        return match, False

    # e.g. we contacted images.google.com and google.com is blacklisted
    if match := self._match_subdomain(domain):
        return match, True

    return False, False
279330

280331
def get_all_blacklisted_ip_ranges(self) -> dict:
281332
"""

0 commit comments

Comments
 (0)