diff --git a/config/whitelist.conf b/config/whitelist.conf index f277693fe..be554887d 100644 --- a/config/whitelist.conf +++ b/config/whitelist.conf @@ -162,3 +162,4 @@ organization,google,both,alerts organization,apple,both,alerts organization,twitter,both,alerts domain,markmonitor.com,both,alerts +domain,whois.nic.co,both,alerts diff --git a/docs/usage.md b/docs/usage.md index 69d6c1bef..3f85a9ae5 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -468,28 +468,6 @@ The values for each column are the following: - Ignore alerts: slips reads all the flows, but it just ignores alerting if there is a match. - Ignore flows: the flow will be completely discarded. -### Removing values from the Whitelist - -Whitelisted IoCs can be updated: -1. When you re-start Slips -2. On the fly while running Slips - -If you're updating the whitelist while Slips is running, be careful to use ; to comment out the lines you want to remove from the db -for example, if you have the following line in `whitelist.conf`: - -``` -organization,google,both,alerts -``` - -To be able to remove this whitelist entry while Slips is running, simply change it to - -``` -# organization,google,both,alerts -``` - -Comments starting with `;` are not removed from the database and are treated as user comments. -Comments starting with `#` will cause Slips to attempt to remove that entry from the database. - ## Popup notifications Slips Support displaying popup notifications whenever there's an alert. @@ -622,7 +600,7 @@ Check [rotation section](https://stratospherelinuxips.readthedocs.io/en/develop/ But you can also enable storing a copy of zeek log files in the output directory after the analysis is done by setting ```store_a_copy_of_zeek_files``` to yes, -or while zeek is stil generating log files by setting ```store_zeek_files_in_the_output_dir``` to yes. +or while zeek is still generating log files by setting ```store_zeek_files_in_the_output_dir``` to yes. this option stores a copy of the zeek files present in ```zeek_files/``` the moment slips stops. so this doesn't include deleted zeek logs. 
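For reference, the entries being added to whitelist.conf above follow the four-column CSV format the docs in this diff describe. A minimal sketch of a few entries (the header text and the ip/mac rows are illustrative examples inferred from the entries shown here, not copied from the repo's shipped whitelist.conf):

```
"IoCType","Data","Direction","IgnoreType"
; user comments start with a semicolon
ip,1.2.3.4,both,alerts
domain,example.com,src,flows
organization,google,both,alerts
mac,b1:b1:b1:c1:c2:c3,both,both
```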
diff --git a/install/requirements.txt b/install/requirements.txt index c8d03d28a..09515e21e 100644 --- a/install/requirements.txt +++ b/install/requirements.txt @@ -40,6 +40,7 @@ pre-commit==4.3.0 coverage==7.11.0 netifaces==0.11.0 scapy==2.6.1 +pybloom_live pyyaml pytest-asyncio vulture diff --git a/managers/process_manager.py b/managers/process_manager.py index 169322e62..3366c0168 100644 --- a/managers/process_manager.py +++ b/managers/process_manager.py @@ -39,6 +39,7 @@ from slips_files.common.style import green from slips_files.core.evidence_handler import EvidenceHandler +from slips_files.core.helpers.bloom_filters_manager import BFManager from slips_files.core.input import Input from slips_files.core.output import Output from slips_files.core.profiler import Profiler @@ -111,6 +112,7 @@ def start_profiler_process(self): self.main.args, self.main.conf, self.main.pid, + self.main.bloom_filters_man, is_profiler_done=self.is_profiler_done, profiler_queue=self.profiler_queue, is_profiler_done_event=self.is_profiler_done_event, @@ -134,6 +136,7 @@ def start_evidence_process(self): self.main.args, self.main.conf, self.main.pid, + self.main.bloom_filters_man, ) evidence_process.start() self.main.print( @@ -154,6 +157,7 @@ def start_input_process(self): self.main.args, self.main.conf, self.main.pid, + self.main.bloom_filters_man, is_input_done=self.is_input_done, profiler_queue=self.profiler_queue, input_type=self.main.input_type, @@ -399,6 +403,7 @@ def load_modules(self): self.main.args, self.main.conf, self.main.pid, + self.main.bloom_filters_man, ) module.start() self.main.db.store_pid(module_name, int(module.pid)) @@ -430,17 +435,30 @@ def print_stopped_module(self, module): f"\t{green(module)} \tStopped. " f"" f"{green(modules_left)} left." ) + def init_bloom_filters_manager(self): + """this instance is shared across all slips IModule instances, + because we don't want to re-create the filters once per process, + this way is more memory efficient""" + return BFManager( + self.main.logger, + self.main.args.output, + self.main.redis_port, + self.main.conf, + self.main.pid, + ) + def start_update_manager(self, local_files=False, ti_feeds=False): """ starts the update manager process PS; this function is blocking, slips.py will not start the rest of the - module unless this functionis done + module unless this function's done :kwarg local_files: if true, updates the local ports and org files from disk :kwarg ti_feeds: if true, updates the remote TI feeds. PS: this takes time.
""" try: + bloom_filters_man = getattr(self.main, "bloom_filters_man", None) # only one instance of slips should be able to update ports # and orgs at a time # so this function will only be allowed to run from 1 slips @@ -456,6 +474,7 @@ def start_update_manager(self, local_files=False, ti_feeds=False): self.main.args, self.main.conf, self.main.pid, + bloom_filters_man, ) if local_files: diff --git a/modules/flowalerts/flowalerts.py b/modules/flowalerts/flowalerts.py index 49cb31684..9a5951065 100644 --- a/modules/flowalerts/flowalerts.py +++ b/modules/flowalerts/flowalerts.py @@ -29,7 +29,7 @@ class FlowAlerts(AsyncModule): def init(self): self.subscribe_to_channels() - self.whitelist = Whitelist(self.logger, self.db) + self.whitelist = Whitelist(self.logger, self.db, self.bloom_filters) self.dns = DNS(self.db, flowalerts=self) self.software = Software(self.db, flowalerts=self) self.notice = Notice(self.db, flowalerts=self) diff --git a/modules/ip_info/ip_info.py b/modules/ip_info/ip_info.py index 085349a95..e16aaf51e 100644 --- a/modules/ip_info/ip_info.py +++ b/modules/ip_info/ip_info.py @@ -64,7 +64,7 @@ def init(self): "new_dns": self.c3, "check_jarm_hash": self.c4, } - self.whitelist = Whitelist(self.logger, self.db) + self.whitelist = Whitelist(self.logger, self.db, self.bloom_filters) self.is_running_non_stop: bool = self.db.is_running_non_stop() self.valid_tlds = whois.validTlds() self.is_running_in_ap_mode: bool = ( diff --git a/modules/threat_intelligence/threat_intelligence.py b/modules/threat_intelligence/threat_intelligence.py index 2c3773cf4..0c589b7f1 100644 --- a/modules/threat_intelligence/threat_intelligence.py +++ b/modules/threat_intelligence/threat_intelligence.py @@ -609,6 +609,10 @@ def is_valid_threat_level(self, threat_level): return threat_level in utils.threat_levels def parse_known_fp_hashes(self, fullpath: str): + """ + That file contains known FalsePositives of hashes to reduce the + amount of FP from TI files + """ fp_hashes = {} with open(fullpath) as fps: # skip comments diff --git a/modules/update_manager/update_manager.py b/modules/update_manager/update_manager.py index 96cb1b5ad..4fd7abc1b 100644 --- a/modules/update_manager/update_manager.py +++ b/modules/update_manager/update_manager.py @@ -56,7 +56,7 @@ def init(self): self.loaded_ti_files = 0 # don't store iocs older than 1 week self.interval = 7 - self.whitelist = Whitelist(self.logger, self.db) + self.whitelist = Whitelist(self.logger, self.db, self.bloom_filters) self.slips_logfile = self.db.get_stdfile("stdout") self.org_info_path = "slips_files/organizations_info/" self.path_to_mac_db = "databases/macaddress-db.json" @@ -1484,10 +1484,16 @@ def update_local_whitelist(self): self.whitelist.update() def update_org_files(self): + """ + This func handles organizations whitelist files. 
+ It updates the local IoCs of every supported organization in the db + and initializes the bloom filters + """ for org in utils.supported_orgs: org_ips = os.path.join(self.org_info_path, org) org_asn = os.path.join(self.org_info_path, f"{org}_asn") org_domains = os.path.join(self.org_info_path, f"{org}_domains") + if self.check_if_update_org(org_ips): self.whitelist.parser.load_org_ips(org) @@ -1576,10 +1582,11 @@ def update_online_whitelist(self): # delete the old ones self.db.delete_tranco_whitelist() response = self.responses["tranco_whitelist"] + domains = [] for line in response.text.splitlines(): - domain = line.split(",")[1] - domain.strip() - self.db.store_tranco_whitelisted_domain(domain) + domain = line.split(",")[1].strip() + domains.append(domain) + self.db.store_tranco_whitelisted_domains(domains) self.mark_feed_as_updated("tranco_whitelist") diff --git a/slips/main.py b/slips/main.py index a091b4d22..7a80daefb 100644 --- a/slips/main.py +++ b/slips/main.py @@ -27,6 +27,7 @@ from slips_files.common.slips_utils import utils from slips_files.common.style import green, yellow from slips_files.core.database.database_manager import DBManager +from slips_files.core.helpers.bloom_filters_manager import BFManager from slips_files.core.helpers.checker import Checker @@ -586,6 +587,9 @@ def start(self): # if slips is given a .rdb file, don't load the # modules as we don't need them if not self.args.db: + self.bloom_filters_man: BFManager = ( + self.proc_man.init_bloom_filters_manager() + ) # update local files before starting modules # if wait_for_TI_to_finish is set to true in the config file, # slips will wait untill all TI files are updated before @@ -595,6 +599,10 @@ def start(self): ti_feeds=self.conf.wait_for_TI_to_finish(), ) self.print("Starting modules", 1, 0) + # initialize_filter must be called after the update manager + # is started, and before the modules start. why? 
because + # update manager updates the iocs that the bloom filters need + self.bloom_filters_man.initialize_filter() self.proc_man.load_modules() # give outputprocess time to print all the started modules time.sleep(0.5) diff --git a/slips_files/common/abstracts/imodule.py b/slips_files/common/abstracts/imodule.py index 73182aa61..8798fd9b9 100644 --- a/slips_files/common/abstracts/imodule.py +++ b/slips_files/common/abstracts/imodule.py @@ -11,6 +11,7 @@ Optional, ) from slips_files.common.printer import Printer +from slips_files.core.helpers.bloom_filters_manager import BFManager from slips_files.core.output import Output from slips_files.common.slips_utils import utils from slips_files.core.database.database_manager import DBManager @@ -38,6 +39,7 @@ def __init__( slips_args, conf, ppid: int, + bloom_filters_manager: BFManager, **kwargs, ): Process.__init__(self) @@ -53,6 +55,7 @@ def __init__( # used to tell all slips.py children to stop self.termination_event: Event = termination_event self.logger = logger + self.bloom_filters: BFManager = bloom_filters_manager self.printer = Printer(self.logger, self.name) self.db = DBManager( self.logger, self.output_dir, self.redis_port, self.conf, self.ppid diff --git a/slips_files/core/database/database_manager.py b/slips_files/core/database/database_manager.py index 39bde1c98..236d2085d 100644 --- a/slips_files/core/database/database_manager.py +++ b/slips_files/core/database/database_manager.py @@ -367,8 +367,8 @@ def get_modified_tw(self, *args, **kwargs): def get_field_separator(self, *args, **kwargs): return self.rdb.get_field_separator(*args, **kwargs) - def store_tranco_whitelisted_domain(self, *args, **kwargs): - return self.rdb.store_tranco_whitelisted_domain(*args, **kwargs) + def store_tranco_whitelisted_domains(self, *args, **kwargs): + return self.rdb.store_tranco_whitelisted_domains(*args, **kwargs) def is_whitelisted_tranco_domain(self, *args, **kwargs): return self.rdb.is_whitelisted_tranco_domain(*args, **kwargs) @@ -478,6 +478,9 @@ def get_pids(self, *args, **kwargs): def set_org_info(self, *args, **kwargs): return self.rdb.set_org_info(*args, **kwargs) + def set_org_cidrs(self, *args, **kwargs): + return self.rdb.set_org_cidrs(*args, **kwargs) + def get_org_info(self, *args, **kwargs): return self.rdb.get_org_info(*args, **kwargs) @@ -487,12 +490,21 @@ def get_org_ips(self, *args, **kwargs): def set_whitelist(self, *args, **kwargs): return self.rdb.set_whitelist(*args, **kwargs) - def get_all_whitelist(self, *args, **kwargs): - return self.rdb.get_all_whitelist(*args, **kwargs) - def get_whitelist(self, *args, **kwargs): return self.rdb.get_whitelist(*args, **kwargs) + def is_whitelisted(self, *args, **kwargs): + return self.rdb.is_whitelisted(*args, **kwargs) + + def is_domain_in_org_domains(self, *args, **kwargs): + return self.rdb.is_domain_in_org_domains(*args, **kwargs) + + def is_asn_in_org_asn(self, *args, **kwargs): + return self.rdb.is_asn_in_org_asn(*args, **kwargs) + + def is_ip_in_org_ips(self, *args, **kwargs): + return self.rdb.is_ip_in_org_cidrs(*args, **kwargs) + def has_cached_whitelist(self, *args, **kwargs): return self.rdb.has_cached_whitelist(*args, **kwargs) diff --git a/slips_files/core/database/redis_db/alert_handler.py b/slips_files/core/database/redis_db/alert_handler.py index a8e37f7b6..c9c747f79 100644 --- a/slips_files/core/database/redis_db/alert_handler.py +++ b/slips_files/core/database/redis_db/alert_handler.py @@ -53,7 +53,7 @@ def mark_profile_as_malicious(self, profileid: ProfileID): def 
get_malicious_profiles(self): """returns profiles that generated an alert""" - self.r.smembers(self.constants.MALICIOUS_PROFILES) + return self.r.smembers(self.constants.MALICIOUS_PROFILES) def set_evidence_causing_alert(self, alert: Alert): """ diff --git a/slips_files/core/database/redis_db/constants.py b/slips_files/core/database/redis_db/constants.py index 35b135fbd..0f5ff8f5e 100644 --- a/slips_files/core/database/redis_db/constants.py +++ b/slips_files/core/database/redis_db/constants.py @@ -34,7 +34,6 @@ class Constants: PIDS = "PIDs" MAC = "MAC" MODIFIED_TIMEWINDOWS = "ModifiedTW" - ORG_INFO = "OrgInfo" ACCUMULATED_THREAT_LEVELS = "accumulated_threat_levels" TRANCO_WHITELISTED_DOMAINS = "tranco_whitelisted_domains" WHITELIST = "whitelist" @@ -66,7 +65,7 @@ class Constants: BLOCKED_PROFILES_AND_TWS = "BlockedProfTW" PROFILES = "profiles" NUMBER_OF_ALERTS = "number_of_alerts" - KNOWN_FPS = "known_fps" + KNOWN_FP_MD5_HASHES = "known_fps" WILL_SLIPS_HAVE_MORE_FLOWS = "will_slips_have_more_flows" SUBS_WHO_PROCESSED_MSG = "number_of_subscribers_who_processed_this_msg" FLOWS_ANALYZED_BY_ALL_MODULES_PER_MIN = "flows_analyzed_per_minute" diff --git a/slips_files/core/database/redis_db/database.py b/slips_files/core/database/redis_db/database.py index d91e44c07..8aa0f8dea 100644 --- a/slips_files/core/database/redis_db/database.py +++ b/slips_files/core/database/redis_db/database.py @@ -62,7 +62,6 @@ class RedisDB(IoCHandler, AlertHandler, ProfileHandler, P2PHandler): "new_notice", "new_url", "new_downloaded_file", - "reload_whitelist", "new_service", "new_arp", "new_MAC", @@ -1152,11 +1151,11 @@ def _determine_gw_mac(self, ip, mac, interface: str): return True return False - def get_ip_of_mac(self, MAC): + def get_ip_of_mac(self, mac_addr: str): """ Returns the IP associated with the given MAC in our database """ - return self.r.hget(self.constants.MAC, MAC) + return self.r.hget(self.constants.MAC, mac_addr) def get_modified_tw(self): """Return all the list of modified tw""" @@ -1169,14 +1168,14 @@ def get_field_separator(self): """Return the field separator""" return self.separator - def store_tranco_whitelisted_domain(self, domain): + def store_tranco_whitelisted_domains(self, domains: List[str]): """ - store whitelisted domain from tranco whitelist in the db + store whitelisted domains from tranco whitelist in the db """ # the reason we store tranco whitelisted domains in the cache db # instead of the main db is, we don't want them cleared on every new # instance of slips - self.rcache.sadd(self.constants.TRANCO_WHITELISTED_DOMAINS, domain) + self.rcache.sadd(self.constants.TRANCO_WHITELISTED_DOMAINS, *domains) def is_whitelisted_tranco_domain(self, domain): return self.rcache.sismember( @@ -1552,81 +1551,147 @@ def get_name_of_module_at(self, given_pid): if int(given_pid) == int(pid): return name - def set_org_info(self, org, org_info, info_type): + def set_org_cidrs(self, org, org_ips: Dict[str, List[str]]): """ - store ASN, IP and domains of an org in the db + stores CIDRs of an org in the db :param org: supported orgs are ('google', 'microsoft', 'apple', 'facebook', 'twitter') - : param org_info: a json serialized list of asns or ips or domains - :param info_type: supported types are 'asn', 'domains', 'IPs' + :param org_ips: a dict keyed by the first octet of each CIDR, + with the list of full CIDRs as the value,
+ something like { + '2401': ['2401:fa00::/42', '2401:fa00:4::/48'], + '70': ['70.32.128.0/19', '70.32.136.0/24'] + } """ - # info will be stored in OrgInfo key {'facebook_asn': .., - # 'twitter_domains': ...} - self.rcache.hset( - self.constants.ORG_INFO, f"{org}_{info_type}", org_info - ) + key = f"{org}_IPs" + if isinstance(org_ips, dict): + serializable = {str(k): json.dumps(v) for k, v in org_ips.items()} + self.rcache.hset(key, mapping=serializable) - def get_org_info(self, org, info_type) -> str: + def set_org_info(self, org, org_info: List[str], info_type: str): + """ + store ASN or domains of an org in the db + :param org: supported orgs are ('google', 'microsoft', + 'apple', 'facebook', 'twitter') + :param org_info: a list of ASNs or domains + :param info_type: supported types are 'asn' or 'domains' + NOTE: this function doesn't store org IPs, use set_org_cidrs() + instead """ - get the ASN, IP and domains of an org from the db + # info will be stored in redis SETs like 'facebook_asn', + # 'twitter_domains', etc. + key = f"{org}_{info_type}" + if isinstance(org_info, list): + self.rcache.sadd(key, *org_info) + + def get_org_info(self, org, info_type: str) -> List[str]: + """ + Returns the ASN or domains of an org from the db + :param org: supported orgs are ('google', 'microsoft', 'apple', 'facebook', 'twitter') - :param info_type: supported types are 'asn', 'domains' - returns a json serialized dict with info + :param info_type: supported types are 'asn' or 'domains' + + returns a List[str] of the required info PS: All ASNs returned by this function are uppercase """ - return ( - self.rcache.hget(self.constants.ORG_INFO, f"{org}_{info_type}") - or "[]" - ) + key = f"{org}_{info_type}" + return self.rcache.smembers(key) - def get_org_ips(self, org): - org_info = self.rcache.hget(self.constants.ORG_INFO, f"{org}_IPs") + def is_domain_in_org_domains(self, org: str, domain: str) -> bool: + """ + checks if the given domain is in the org's domains set + :param org: supported orgs are ('google', 'microsoft', 'apple', + 'facebook', 'twitter') + :param domain: domain to check + :return: True if the domain is in the org's domains set, False otherwise + """ + key = f"{org}_domains" + return True if self.rcache.sismember(key, domain) else False - if not org_info: - org_info = {} - return org_info + def is_asn_in_org_asn(self, org: str, asn: str) -> bool: + """ + checks if the given asn is in the org's asns set + :param org: supported orgs are ('google', 'microsoft', 'apple', + 'facebook', 'twitter') + :param asn: asn to check + :return: True if the asn is in the org's asns set, False otherwise + """ + key = f"{org}_asn" + return True if self.rcache.sismember(key, asn) else False - try: - return json.loads(org_info) - except TypeError: - # it's a dict - return org_info + def is_ip_in_org_cidrs( + self, org: str, first_octet: str + ) -> List[str] | None: + """ + checks if the given first octet is among the org's stored octets + :param org: supported orgs are ('google', 'microsoft', 'apple', + 'facebook', 'twitter') + :param first_octet: first octet of the ip to check + :return: a list of cidrs the given ip may belong to, None otherwise + """ + key = f"{org}_IPs" + # cidrs are stored json-serialized in the cache db by set_org_cidrs() + serialized_cidrs = self.rcache.hget(key, first_octet) + return json.loads(serialized_cidrs) if serialized_cidrs else None - def set_whitelist(self, type_, whitelist_dict): + def get_org_ips(self, org: str) -> Dict[str, str]: """ - Store the whitelist_dict in the given key - :param type_: supported types are IPs, domains, macs and organizations - :param whitelist_dict: the dict of IPs,macs, domains or orgs to store + returns Dict[str, str] + keys are
subnet first octets + values are serialized lists of cidrs + e.g. { + '2401': ['2401:fa00::/42', '2401:fa00:4::/48'], + '70': ['70.32.128.0/19', '70.32.136.0/24'] + } """ - self.r.hset( - self.constants.WHITELIST, type_, json.dumps(whitelist_dict) - ) + key = f"{org}_IPs" + org_info = self.rcache.hgetall(key) + return org_info if org_info else {} - def get_all_whitelist(self) -> Optional[Dict[str, dict]]: + def set_whitelist(self, type_, whitelist_dict: Dict[str, Dict[str, str]]): """ - Returns a dict with the following keys from the whitelist - 'mac', 'organizations', 'IPs', 'domains' + Store the whitelist_dict in the given key + :param type_: supported types are IPs, domains, macs and organizations + :param whitelist_dict: the dict of IPs,macs, domains or orgs to store """ - whitelist: Optional[Dict[str, str]] = self.r.hgetall( - self.constants.WHITELIST - ) - if whitelist: - whitelist = {k: json.loads(v) for k, v in whitelist.items()} - return whitelist + key = f"{self.constants.WHITELIST}_{type_}" + # Pre-serialize all values + data = {ioc: json.dumps(info) for ioc, info in whitelist_dict.items()} + # Send all at once + if data: + self.r.hset(key, mapping=data) def get_whitelist(self, key: str) -> dict: """ - Whitelist supports different keys like : IPs domains - and organizations - this function is used to check if we have any of the - above keys whitelisted + Return ALL the whitelisted IoCs of the given key type + Whitelist supports different keys like: "IPs", "domains", + "organizations" or "macs" """ - if whitelist := self.r.hget(self.constants.WHITELIST, key): - return json.loads(whitelist) + key = f"{self.constants.WHITELIST}_{key}" + if whitelist := self.r.hgetall(key): + return whitelist else: return {} + def is_whitelisted(self, ioc: str, type_: str) -> str | None: + """ + Check if a given ioc (IP, domain, or MAC) is whitelisted. + + :param ioc: The ioc to check; IP address, domain, or MAC + :param type_: The type of ioc to check. Supported types: 'IPs', + 'domains', 'macs'. + :return: a serialized dict with the whitelist info of the given ioc + :raises ValueError: If the provided type_ is not supported. + """ + valid_types = {"IPs", "domains", "macs"} + if type_ not in valid_types: + raise ValueError( + f"Unsupported whitelist type: {type_}. " + f"Must be one of {valid_types}." + ) + + key = f"{self.constants.WHITELIST}_{type_}" + return self.r.hget(key, ioc) + def has_cached_whitelist(self) -> bool: return bool(self.r.exists(self.constants.WHITELIST)) diff --git a/slips_files/core/database/redis_db/ioc_handler.py b/slips_files/core/database/redis_db/ioc_handler.py index f17483f05..32ccac28b 100644 --- a/slips_files/core/database/redis_db/ioc_handler.py +++ b/slips_files/core/database/redis_db/ioc_handler.py @@ -168,12 +168,12 @@ def set_ti_feed_info(self, file, data): self.rcache.hset(self.constants.TI_FILES_INFO, file, data) def store_known_fp_md5_hashes(self, fps: Dict[str, List[str]]): - self.rcache.hmset(self.constants.KNOWN_FPS, fps) + self.rcache.hmset(self.constants.KNOWN_FP_MD5_HASHES, fps) def is_known_fp_md5_hash(self, hash: str) -> Optional[str]: """returns the description of the given hash if it is a FP.
and - returns Fals eif the hash is not a FP""" - return self.rcache.hmget(self.constants.KNOWN_FPS, hash) + returns False if the hash is not a FP""" + return self.rcache.hmget(self.constants.KNOWN_FP_MD5_HASHES, hash) def delete_ips_from_ioc_ips(self, ips: List[str]): """ diff --git a/slips_files/core/evidence_handler.py b/slips_files/core/evidence_handler.py index 0d0d278c8..cea250214 100644 --- a/slips_files/core/evidence_handler.py +++ b/slips_files/core/evidence_handler.py @@ -21,6 +21,8 @@ # stratosphere@aic.fel.cvut.cz import json +import multiprocessing +import threading from typing import ( List, Dict, @@ -31,7 +33,6 @@ import sys import os import time -import traceback from slips_files.common.idmefv2 import IDMEFv2 from slips_files.common.style import ( @@ -39,6 +40,7 @@ ) from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils +from slips_files.core.evidence_logger import EvidenceLogger from slips_files.core.helpers.whitelist.whitelist import Whitelist from slips_files.core.helpers.notify import Notify from slips_files.common.abstracts.icore import ICore @@ -64,7 +66,7 @@ class EvidenceHandler(ICore): name = "EvidenceHandler" def init(self): - self.whitelist = Whitelist(self.logger, self.db) + self.whitelist = Whitelist(self.logger, self.db, self.bloom_filters) self.idmefv2 = IDMEFv2(self.logger, self.db) self.separator = self.db.get_separator() self.read_configuration() @@ -103,10 +105,25 @@ def init(self): self.our_ips: List[str] = utils.get_own_ips(ret="List") self.formatter = EvidenceFormatter(self.db, self.args) # thats just a tmp value, this variable will be set and used when - # the - # module is stopping. + # the module is stopping. self.last_msg_received_time = time.time() + # A thread that handles I/O to disk (writing evidence to log files) + self.logger_stop_signal = threading.Event() + self.evidence_logger_q = multiprocessing.Queue() + self.evidence_logger = EvidenceLogger( + stop_signal=self.logger_stop_signal, + evidence_logger_q=self.evidence_logger_q, + logfile=self.logfile, + jsonfile=self.jsonfile, + ) + self.logger_thread = threading.Thread( + target=self.evidence_logger.run_logger_thread, + daemon=True, + name="thread_that_handles_evidence_logging_to_disk", + ) + utils.start_thread(self.logger_thread, self.db) + def read_configuration(self): conf = ConfigParser() self.width: float = conf.get_tw_width_as_float() @@ -148,13 +165,11 @@ def add_alert_to_json_log_file(self, alert: Alert): self.handle_unable_to_log(alert, "Can't convert to IDMEF alert") return - try: - json.dump(idmef_alert, self.jsonfile) - self.jsonfile.write("\n") - except KeyboardInterrupt: - return True - except Exception as e: - self.handle_unable_to_log(alert, e) + to_log = { + "to_log": idmef_alert, + "where": "alerts.json", + } + self.evidence_logger_q.put(to_log) def add_evidence_to_json_log_file( self, @@ -186,29 +201,26 @@ ) } ) - json.dump(idmef_evidence, self.jsonfile) - self.jsonfile.write("\n") + + to_log = { + "to_log": idmef_evidence, + "where": "alerts.json", + } + + self.evidence_logger_q.put(to_log) + except KeyboardInterrupt: return True except Exception as e: self.handle_unable_to_log(evidence, e) - def add_to_log_file(self, data): + def add_to_log_file(self, data: str): """ Add a new evidence line to the alerts.log and other log files if logging is enabled.
""" - try: - # write to alerts.log - self.logfile.write(data) - if not data.endswith("\n"): - self.logfile.write("\n") - self.logfile.flush() - except KeyboardInterrupt: - return True - except Exception: - self.print("Error in add_to_log_file()") - self.print(traceback.format_exc(), 0, 1) + to_log = {"to_log": data, "where": "alerts.log"} + self.evidence_logger_q.put(to_log) def log_alert(self, alert: Alert, blocked=False): """ @@ -239,6 +251,11 @@ def log_alert(self, alert: Alert, blocked=False): self.add_alert_to_json_log_file(alert) def shutdown_gracefully(self): + self.logger_stop_signal.set() + try: + self.logger_thread.join(timeout=5) + except Exception: + pass self.logfile.close() self.jsonfile.close() @@ -549,6 +566,9 @@ def main(self): # reaching this point, now remove evidence from db so # it could be completely ignored self.db.delete_evidence(profileid, twid, evidence.id) + self.print( + f"{self.whitelist.get_bloom_filters_stats()}", 2, 0 + ) continue # convert time to local timezone diff --git a/slips_files/core/evidence_logger.py b/slips_files/core/evidence_logger.py new file mode 100644 index 000000000..1e8a37e52 --- /dev/null +++ b/slips_files/core/evidence_logger.py @@ -0,0 +1,69 @@ +import json +import queue +import threading +import traceback +from typing import TextIO + + +class EvidenceLogger: + def __init__( + self, + stop_signal: threading.Event, + evidence_logger_q: queue.Queue, + logfile: TextIO, + jsonfile: TextIO, + ): + self.stop_signal = stop_signal + self.evidence_logger_q = evidence_logger_q + self.logfile = logfile + self.jsonfile = jsonfile + + def print_to_alerts_logfile(self, data: str): + """ + Add a new evidence line to the alerts.log and other log files if + logging is enabled. + """ + try: + # write to alerts.log + self.logfile.write(data) + if not data.endswith("\n"): + self.logfile.write("\n") + self.logfile.flush() + except KeyboardInterrupt: + return True + except Exception: + self.print("Error in evidence_logger.print_to_alerts_logfile()") + self.print(traceback.format_exc(), 0, 1) + + def print_to_alerts_json(self, idmef_evidence: dict): + try: + json.dump(idmef_evidence, self.jsonfile) + self.jsonfile.write("\n") + except KeyboardInterrupt: + return + except Exception: + return + + def run_logger_thread(self): + """ + runs forever in a loop reveiving msgs from evidence_handler and + logging them to alert.log or alerts.json + to avoid blocking evidence handler when high traffic attacks are + happening, so slips can process evidence faster there while we log + as fast as possible here + """ + while not self.stop_signal.is_set(): + try: + msg = self.evidence_logger_q.get(timeout=1) + except queue.Empty: + continue + except Exception: + continue + + destination = msg["where"] + + if destination == "alerts.log": + self.print_to_alerts_logfile(msg["to_log"]) + + elif destination == "alerts.json": + self.print_to_alerts_json(msg["to_log"]) diff --git a/slips_files/core/helpers/bloom_filters_manager.py b/slips_files/core/helpers/bloom_filters_manager.py new file mode 100644 index 000000000..4b93e6399 --- /dev/null +++ b/slips_files/core/helpers/bloom_filters_manager.py @@ -0,0 +1,86 @@ +# SPDX-FileCopyrightText: 2021 Sebastian Garcia +# SPDX-License-Identifier: GPL-2.0-only +from typing import List, Dict + +from pybloom_live import BloomFilter + +from slips_files.common.slips_utils import utils +from slips_files.core.database.database_manager import DBManager +from slips_files.core.output import Output + + +class BFManager: + def __init__( + self, + 
logger: Output, + output_dir, + redis_port, + conf, + ppid: int, + ): + self.redis_port = redis_port + self.output_dir = output_dir + self.logger = logger + self.conf = conf + # the parent pid of this module, used for starting the db + self.ppid = ppid + self.db = DBManager( + self.logger, self.output_dir, self.redis_port, self.conf, self.ppid + ) + self.org_filters = {} + + def initialize_filter(self): + self._init_whitelisted_iocs_bf() + self._init_whitelisted_orgs_bf() + + def _init_whitelisted_iocs_bf(self): + self.domains = BloomFilter(capacity=10000, error_rate=0.001) + self.ips = BloomFilter(capacity=10000, error_rate=0.001) + self.mac_addrs = BloomFilter(capacity=10000, error_rate=0.001) + self.orgs = BloomFilter(capacity=100, error_rate=0.001) + + for ip in self.db.get_whitelist("IPs"): + self.ips.add(ip) + + for domain in self.db.get_whitelist("domains"): + self.domains.add(domain) + + for org in self.db.get_whitelist("organizations"): + self.orgs.add(org) + + for mac in self.db.get_whitelist("macs"): + self.mac_addrs.add(mac) + + def _init_whitelisted_orgs_bf(self): + """ + Updates the bloom filters with the whitelisted organization + domains, asns, and ips + fills the self.org_filters dict. + is called from update_manager whether or not slips updated its + local org files. + the goal of calling this is to make sure slips has the bloom + filters in mem at all times. + """ + err_rate = 0.01 + for org in utils.supported_orgs: + domains_bloom = BloomFilter(capacity=10000, error_rate=err_rate) + asns_bloom = BloomFilter(capacity=10000, error_rate=err_rate) + cidrs_bloom = BloomFilter(capacity=100, error_rate=err_rate) + + domains: List[str] = self.db.get_org_info(org, "domains") + _ = [domains_bloom.add(domain) for domain in domains] + + asns: List[str] = self.db.get_org_info(org, "asn") + _ = [asns_bloom.add(asn) for asn in asns] + + org_subnets: Dict[str, str] = self.db.get_org_ips(org) + _ = [ + cidrs_bloom.add(first_octet) + for first_octet in org_subnets.keys() + ] + + self.org_filters[org] = { + "domains": domains_bloom, + "asns": asns_bloom, + "first_octets": cidrs_bloom, + } diff --git a/slips_files/core/helpers/filemonitor.py b/slips_files/core/helpers/filemonitor.py index c9726a0b9..3b4130cb6 100644 --- a/slips_files/core/helpers/filemonitor.py +++ b/slips_files/core/helpers/filemonitor.py @@ -74,5 +74,3 @@ def on_modified(self, event): # tell slips to terminate self.db.publish_stop() break - elif "whitelist" in filename: - self.db.publish("reload_whitelist", "reload") diff --git a/slips_files/core/helpers/whitelist/domain_whitelist.py b/slips_files/core/helpers/whitelist/domain_whitelist.py index 1cd086455..9a0476cf4 100644 --- a/slips_files/core/helpers/whitelist/domain_whitelist.py +++ b/slips_files/core/helpers/whitelist/domain_whitelist.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only +import json from typing import List, Dict import tldextract @@ -20,6 +21,9 @@ def name(self): def init(self): self.ip_analyzer = IPAnalyzer(self.db) self.read_configuration() + # for debugging + self.bf_hits = 0 + self.bf_misses = 0 def read_configuration(self): conf = ConfigParser() @@ -76,7 +80,6 @@ def is_whitelisted( # the reason why this function doesnt support the Attacker or # Victim as a parameter directly is that we may call it on other # values. not just attacker and victim domains.
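All the whitelist analyzer hunks below share the check-then-confirm pattern that BFManager enables: consult the in-memory bloom filter first, and only query redis when the filter answers "maybe". A minimal standalone sketch of the idea (the sample IoCs and the db_lookup callable are hypothetical; the BloomFilter usage mirrors the pybloom_live API imported above):

```python
from pybloom_live import BloomFilter

# built once at startup from the whitelisted IoCs stored in redis
bf = BloomFilter(capacity=10000, error_rate=0.001)
for ioc in ("1.2.3.4", "example.com"):  # hypothetical whitelist entries
    bf.add(ioc)

def is_whitelisted(ioc: str, db_lookup) -> bool:
    if ioc not in bf:
        # bloom filters have no false negatives: the ioc is definitely
        # not whitelisted, and we skipped a redis round-trip (a bf hit)
        return False
    # membership can be a false positive (~0.1% at this error_rate), so
    # confirm against the authoritative db (None here means a bf FP)
    return db_lookup(ioc) is not None
```

This is also why each analyzer below counts bf_hits and bf_misses: every miss is one bloom filter false positive that cost an extra db lookup.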
- if not isinstance(domain, str): return False @@ -104,22 +107,29 @@ # domain is in the local whitelist, but the local whitelist # not enabled return False - whitelisted_domains: Dict[str, Dict[str, str]] - whitelisted_domains = self.db.get_whitelist("domains") - if parent_domain in whitelisted_domains: - # did the user say slips should ignore flows or alerts in the - # config file? - whitelist_should_ignore = whitelisted_domains[parent_domain][ - "what_to_ignore" - ] - # did the user say slips should ignore flows/alerts TO or from - # that domain in the config file? - dir_from_whitelist: str = whitelisted_domains[parent_domain][ - "from" - ] - else: + + if parent_domain not in self.manager.bloom_filters.domains: + # definitely not whitelisted + self.bf_hits += 1 return False + domain_info: str | None = self.db.is_whitelisted( + parent_domain, "domains" + ) + if not domain_info: + # bloom filter FP + self.bf_misses += 1 + return False + + self.bf_hits += 1 + domain_info: Dict[str, str] = json.loads(domain_info) + # did the user say slips should ignore flows or alerts in the + # config file? + whitelist_should_ignore = domain_info["what_to_ignore"] + # did the user say slips should ignore flows/alerts TO or from + # that domain in the config file? + dir_from_whitelist: str = domain_info["from"] + # match the direction and whitelist_Type of the given domain to the # ones we have from the whitelist. if not self.match.what_to_ignore( diff --git a/slips_files/core/helpers/whitelist/ip_whitelist.py b/slips_files/core/helpers/whitelist/ip_whitelist.py index 22d2dcd46..1b066d7a3 100644 --- a/slips_files/core/helpers/whitelist/ip_whitelist.py +++ b/slips_files/core/helpers/whitelist/ip_whitelist.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only import ipaddress +import json from typing import List, Dict from slips_files.common.abstracts.iwhitelist_analyzer import IWhitelistAnalyzer @@ -17,6 +18,9 @@ def name(self): def init(self): self.read_configuration() + # for debugging + self.bf_hits = 0 + self.bf_misses = 0 def read_configuration(self): conf = ConfigParser() @@ -52,19 +56,29 @@ if not self.is_valid_ip(ip): return False - whitelisted_ips: Dict[str, dict] = self.db.get_whitelist("IPs") + if ip not in self.manager.bloom_filters.ips: + # definitely not whitelisted + self.bf_hits += 1 + return False - if ip not in whitelisted_ips: + ip_info: str | None = self.db.is_whitelisted(ip, "IPs") + # reaching here means ip is in the bloom filter + if not ip_info: + # bloom filter FP + self.bf_misses += 1 return False + self.bf_hits += 1 + ip_info: Dict[str, str] = json.loads(ip_info) # Check if we should ignore src or dst alerts from this ip # from_ can be: src, dst, both # what_to_ignore can be: alerts or flows or both - whitelist_direction: str = whitelisted_ips[ip]["from"] + whitelist_direction: str = ip_info["from"] if not self.match.direction(direction, whitelist_direction): return False - ignore: str = whitelisted_ips[ip]["what_to_ignore"] + ignore: str = ip_info["what_to_ignore"] if not self.match.what_to_ignore(what_to_ignore, ignore): return False + return True diff --git a/slips_files/core/helpers/whitelist/mac_whitelist.py b/slips_files/core/helpers/whitelist/mac_whitelist.py index 126c6a2cc..442116a24 100644 --- a/slips_files/core/helpers/whitelist/mac_whitelist.py +++ b/slips_files/core/helpers/whitelist/mac_whitelist.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia #
SPDX-License-Identifier: GPL-2.0-only +import json from typing import Dict import validators @@ -20,6 +21,9 @@ def name(self): def init(self): self.ip_analyzer = IPAnalyzer(self.db) self.read_configuration() + # for debugging + self.bf_hits = 0 + self.bf_misses = 0 def read_configuration(self): conf = ConfigParser() @@ -65,15 +69,24 @@ def is_whitelisted( if not self.is_valid_mac(mac): return False - whitelisted_macs: Dict[str, dict] = self.db.get_whitelist("macs") - if mac not in whitelisted_macs: + if mac not in self.manager.bloom_filters.mac_addrs: + # defnitely not whitelisted + self.bf_hits += 1 return False - whitelist_direction: str = whitelisted_macs[mac]["from"] + mac_info: str | None = self.db.is_whitelisted(mac, "macs") + if not mac_info: + self.bf_misses += 1 + return False + + self.bf_hits += 1 + + mac_info: Dict[str, dict] = json.loads(mac_info) + whitelist_direction: str = mac_info["from"] if not self.match.direction(direction, whitelist_direction): return False - whitelist_what_to_ignore: str = whitelisted_macs[mac]["what_to_ignore"] + whitelist_what_to_ignore: str = mac_info["what_to_ignore"] if not self.match.what_to_ignore( what_to_ignore, whitelist_what_to_ignore ): diff --git a/slips_files/core/helpers/whitelist/organization_whitelist.py b/slips_files/core/helpers/whitelist/organization_whitelist.py index c2612238d..093b6bc2a 100644 --- a/slips_files/core/helpers/whitelist/organization_whitelist.py +++ b/slips_files/core/helpers/whitelist/organization_whitelist.py @@ -8,6 +8,8 @@ Union, ) +from pybloom_live import BloomFilter + from slips_files.common.abstracts.iwhitelist_analyzer import IWhitelistAnalyzer from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils @@ -39,19 +41,36 @@ def init(self): self.ip_analyzer = IPAnalyzer(self.db) self.domain_analyzer = DomainAnalyzer(self.db) self.org_info_path = "slips_files/organizations_info/" + self.bloom_filters: Dict[str, Dict[str, BloomFilter]] + self.bloom_filters = self.manager.bloom_filters.org_filters self.read_configuration() + self.whitelisted_orgs: Dict[str, str] = self.db.get_whitelist( + "organizations" + ) + # for debugging + self.bf_hits = 0 + self.bf_misses = 0 def read_configuration(self): conf = ConfigParser() self.enable_local_whitelist: bool = conf.enable_local_whitelist() - def is_domain_in_org(self, domain: str, org: str): + def is_domain_in_org(self, domain: str, org: str) -> bool: """ Checks if the given domains belongs to the given org using the hardcoded org domains in organizations_info/org_domains """ try: - org_domains = json.loads(self.db.get_org_info(org, "domains")) + if domain not in self.bloom_filters[org]["domains"]: + self.bf_hits += 1 + return False + + if self.db.is_domain_in_org_domains(org, domain): + self.bf_hits += 1 + return True + + # match subdomains of all org domains slips knows of + org_domains: List[str] = self.db.get_org_info(org, "domains") flow_tld = self.domain_analyzer.get_tld(domain) for org_domain in org_domains: @@ -60,44 +79,58 @@ def is_domain_in_org(self, domain: str, org: str): if flow_tld != org_domain_tld: continue - # match subdomains too # if org has org.com, and the flow_domain is xyz.org.com # whitelist it if org_domain in domain: + self.bf_hits += 1 return True # if org has xyz.org.com, and the flow_domain is org.com # whitelist it if domain in org_domain: + self.bf_hits += 1 return True + self.bf_misses += 1 + except (KeyError, TypeError): # comes here if the whitelisted org doesn't have domains in # 
slips/organizations_info (not a famous org) # and ip doesn't have asn info. # so we don't know how to link this ip to the whitelisted org! - return False + pass + return False def is_ip_in_org(self, ip: str, org): """ Check if the given ip belongs to the given org """ try: - org_subnets: dict = self.db.get_org_ips(org) - first_octet: str = utils.get_first_octet(ip) if not first_octet: return - ip_obj = ipaddress.ip_address(ip) - # organization IPs are sorted by first octet for faster search - for range_ in org_subnets.get(first_octet, []): - if ip_obj in ipaddress.ip_network(range_): - return True + + if first_octet not in self.bloom_filters[org]["first_octets"]: + self.bf_hits += 1 + return False + + # organization IPs are sorted in the db by first octet for faster + # search + cidrs: List[str] + if cidrs := self.db.is_ip_in_org_ips(org, first_octet): + ip_obj = ipaddress.ip_address(ip) + for cidr in cidrs: + if ip_obj in ipaddress.ip_network(cidr): + self.bf_hits += 1 + return True + except (KeyError, TypeError): # comes here if the whitelisted org doesn't have # info in slips/organizations_info (not a famous org) # and ip doesn't have asn info. pass + + self.bf_misses += 1 return False def is_ip_asn_in_org_asn(self, ip: str, org): @@ -122,13 +155,23 @@ def _is_asn_in_org(self, asn: str, org: str) -> bool: """ if not (asn and asn != "Unknown"): return False + # because all ASN stored in slips organization_info/ are uppercase asn: str = asn.upper() if org.upper() in asn: return True - org_asn: List[str] = json.loads(self.db.get_org_info(org, "asn")) - return asn in org_asn + if asn not in self.bloom_filters[org]["asns"]: + self.bf_hits += 1 + return False + + if self.db.is_asn_in_org_asn(org, asn): + self.bf_hits += 1 + return True + else: + # bloom filter FP + self.bf_misses += 1 + return False def is_whitelisted(self, flow) -> bool: """checks if the given -flow- is whitelisted. not evidence/alerts.""" @@ -190,23 +233,21 @@ def _is_part_of_a_whitelisted_org( :param direction: direction of the given ioc, src or dst? 
:param what_to_ignore: can be "flows" or "alerts" or "both" """ - if ioc_type == IoCType.IP: if utils.is_private_ip(ioc): return False - whitelisted_orgs: Dict[str, dict] = self.db.get_whitelist( - "organizations" - ) - if not whitelisted_orgs: + if not self.whitelisted_orgs: return False - for org in whitelisted_orgs: - dir_from_whitelist = whitelisted_orgs[org]["from"] + for org in self.whitelisted_orgs: + org_info = json.loads(self.whitelisted_orgs[org]) + + dir_from_whitelist = org_info["from"] if not self.match.direction(direction, dir_from_whitelist): continue - whitelist_what_to_ignore = whitelisted_orgs[org]["what_to_ignore"] + whitelist_what_to_ignore = org_info["what_to_ignore"] if not self.match.what_to_ignore( what_to_ignore, whitelist_what_to_ignore ): diff --git a/slips_files/core/helpers/whitelist/whitelist.py b/slips_files/core/helpers/whitelist/whitelist.py index b0e944cdd..7d80c8b6c 100644 --- a/slips_files/core/helpers/whitelist/whitelist.py +++ b/slips_files/core/helpers/whitelist/whitelist.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only from typing import ( - Optional, Dict, List, Union, @@ -10,6 +9,7 @@ from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.printer import Printer +from slips_files.core.helpers.bloom_filters_manager import BFManager from slips_files.core.helpers.whitelist.domain_whitelist import DomainAnalyzer from slips_files.core.helpers.whitelist.ip_whitelist import IPAnalyzer from slips_files.core.helpers.whitelist.mac_whitelist import MACAnalyzer @@ -31,10 +31,11 @@ class Whitelist: name = "Whitelist" - def __init__(self, logger: Output, db): + def __init__(self, logger: Output, db, bloom_filter_manager: BFManager): self.printer = Printer(logger, self.name) self.name = "whitelist" self.db = db + self.bloom_filters: BFManager = bloom_filter_manager self.match = WhitelistMatcher() self.parser = WhitelistParser(self.db, self) self.ip_analyzer = IPAnalyzer(self.db, whitelist_manager=self) @@ -50,7 +51,7 @@ def read_configuration(self): def update(self): """ parses the local whitelist specified in the slips.yaml - and stores the parsed results in the db + and stores the parsed results in the db and in bloom filters """ self.parser.parse() self.db.set_whitelist("IPs", self.parser.whitelisted_ips) @@ -140,33 +141,6 @@ def is_whitelisted_flow(self, flow) -> bool: return self.org_analyzer.is_whitelisted(flow) - def get_all_whitelist(self) -> Optional[Dict[str, dict]]: - """ - returns the whitelisted ips, domains, org from the db - returns a dict with the following keys - 'mac', 'organizations', 'IPs', 'domains' - this function tries to get the whitelist from the db 10 times - """ - whitelist: Dict[str, dict] = self.db.get_all_whitelist() - max_tries = 10 - # if this module is loaded before profilerProcess or before we're - # done processing the whitelist in general - # the database won't return the whitelist - # so we need to try several times until the db returns the - # populated whitelist - # empty dicts evaluate to False - while not bool(whitelist) and max_tries != 0: - # try max 10 times to get the whitelist, if it's still empty - # hen it's not empty by mistake - max_tries -= 1 - whitelist = self.db.get_all_whitelist() - - if max_tries == 0: - # we tried 10 times to get the whitelist, it's probably empty. 
- return - - return whitelist - def is_whitelisted_evidence(self, evidence: Evidence) -> bool: """ Checks if an evidence is whitelisted @@ -243,3 +217,25 @@ return True return False + + def get_bloom_filters_stats(self) -> str: + """ + returns a short summary of the bloom filter hit/miss stats + """ + total_hits = 0 + total_misses = 0 + + for helper in ( + self.ip_analyzer, + self.domain_analyzer, + self.mac_analyzer, + self.org_analyzer, + ): + total_hits += helper.bf_hits + total_misses += helper.bf_misses + + # Bloom filters cannot produce false negatives:D + return ( + f"Number of times bloom filter was accurate (TN + TP):" + f" {total_hits}, FPs: {total_misses}" + ) diff --git a/slips_files/core/helpers/whitelist/whitelist_parser.py b/slips_files/core/helpers/whitelist/whitelist_parser.py index bccfc53be..ea94d3891 100644 --- a/slips_files/core/helpers/whitelist/whitelist_parser.py +++ b/slips_files/core/helpers/whitelist/whitelist_parser.py @@ -1,7 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only import ipaddress -import json import os from typing import TextIO, List, Dict, Optional import validators @@ -18,27 +17,11 @@ def __init__(self, db, manager): # to have access to the print function self.manager = manager self.read_configuration() - self.init_whitelists() - self.org_info_path = "slips_files/organizations_info/" - - def init_whitelists(self): - """ - initializes the dicts we'll be using for storing the parsed - whitelists. - uses existing dicts from the db if found. - """ self.whitelisted_ips = {} self.whitelisted_domains = {} self.whitelisted_orgs = {} self.whitelisted_mac = {} - if self.db.has_cached_whitelist(): - # since this parser can run when the user modifies whitelist.conf - # and not just when the user starts slips - # we need to check if the dicts are already there in the cache db - self.whitelisted_ips = self.db.get_whitelist("IPs") - self.whitelisted_domains = self.db.get_whitelist("domains") - self.whitelisted_orgs = self.db.get_whitelist("organizations") - self.whitelisted_mac = self.db.get_whitelist("mac") + self.org_info_path = "slips_files/organizations_info/" def get_dict_for_storing_data(self, data_type: str): """ @@ -93,9 +76,6 @@ def remove_entry_from_cache_db( cache.pop(entry_to_remove["data"]) return True - def set_number_of_columns(self, line: str) -> None: - self.NUMBER_OF_WHITELIST_COLUMNS: int = len(line.split(",")) - def update_whitelisted_domains(self, domain: str, info: Dict[str, str]): if not utils.is_valid_domain(domain): return @@ -127,7 +107,7 @@ def update_whitelisted_mac_addresses(self, mac: str, info: Dict[str, str]): self.whitelisted_mac[mac] = info def update_whitelisted_ips(self, ip: str, info: Dict[str, str]): - if not (validators.ipv6(ip) or validators.ipv4): + if not (validators.ipv6(ip) or validators.ipv4(ip)): return self.whitelisted_ips[ip] = info @@ -150,7 +130,7 @@ def parse_line(self, line: str) -> Dict[str, str]: def call_handler(self, parsed_line: Dict[str, str]): """ calls the appropriate handler based on the type of data in the - parsed line + given line :param parsed_line: output dict of self.parse_line should have the following keys { type": ..
@@ -196,7 +176,7 @@ def load_org_asn(self, org) -> Optional[List[str]]: line = line.replace("\n", "").strip() org_asn.append(line.upper()) org_asn_file.close() - self.db.set_org_info(org, json.dumps(org_asn), "asn") + self.db.set_org_info(org, org_asn, "asn") return org_asn def load_org_domains(self, org): @@ -220,7 +200,7 @@ def load_org_domains(self, org): domains.append(line.lower()) domain_info.close() - self.db.set_org_info(org, json.dumps(domains), "domains") + self.db.set_org_info(org, domains, "domains") return domains def is_valid_network(self, network: str) -> bool: @@ -268,7 +248,8 @@ def load_org_ips(self, org) -> Optional[Dict[str, List[str]]]: org_subnets[first_octet] = [line] org_info.close() - self.db.set_org_info(org, json.dumps(org_subnets), "IPs") + + self.db.set_org_cidrs(org, org_subnets) return org_subnets def parse(self) -> bool: @@ -282,7 +263,6 @@ def parse(self) -> bool: while line := whitelist.readline(): line_number += 1 if line.startswith('"IoCType"'): - self.set_number_of_columns(line) continue if line.startswith(";"): @@ -304,10 +284,11 @@ def parse(self) -> bool: except Exception: self.manager.print( f"Line {line_number} in whitelist.conf is invalid." - f" Skipping. " + f" Skipping." ) continue self.call_handler(parsed_line) + whitelist.close() return True diff --git a/slips_files/core/profiler.py b/slips_files/core/profiler.py index 25debd14f..599cb90f9 100644 --- a/slips_files/core/profiler.py +++ b/slips_files/core/profiler.py @@ -89,16 +89,13 @@ def init( self.input_type = False self.rec_lines = 0 self.localnet_cache = {} - self.whitelist = Whitelist(self.logger, self.db) + self.whitelist = Whitelist(self.logger, self.db, self.bloom_filters) self.read_configuration() self.symbol = SymbolHandler(self.logger, self.db) # there has to be a timeout or it will wait forever and never # receive a new line self.timeout = 0.0000001 - self.c1 = self.db.subscribe("reload_whitelist") - self.channels = { - "reload_whitelist": self.c1, - } + self.channels = {} # is set by this proc to tell input proc that we are done # processing and it can exit no issue self.is_profiler_done_event = is_profiler_done_event @@ -288,6 +285,7 @@ def add_flow_to_profile(self, flow): # Check if the flow is whitelisted and we should not process it if self.whitelist.is_whitelisted_flow(flow): + self.print(f"{self.whitelist.get_bloom_filters_stats()}", 2, 0) return True # 5th. 
Store the data according to the paremeters @@ -633,8 +631,7 @@ def start_profiler_threads(self): """starts 3 profiler threads for faster processing of the flows""" num_of_profiler_threads = 3 for _ in range(num_of_profiler_threads): - t = threading.Thread(target=self.process_flow) - t.daemon = True + t = threading.Thread(target=self.process_flow, daemon=True) t.start() self.profiler_threads.append(t) @@ -746,16 +743,6 @@ def main(self): # we're using self.should_stop() here instead of while True to be # able to unit test this function:D while not self.should_stop(): - # listen on this channel in case whitelist.conf is changed, - # we need to process the new changes - if self.get_msg("reload_whitelist"): - # if whitelist.conf is edited using pycharm - # a msg will be sent to this channel on every keypress, - # because pycharm saves file automatically - # otherwise this channel will get a msg only when - # whitelist.conf is modified and saved to disk - self.whitelist.update() - msg = self.get_msg_from_input_proc(self.profiler_queue) if not msg: # wait for msgs diff --git a/tests/module_factory.py b/tests/module_factory.py index a50c1d3df..4497d9ed1 100644 --- a/tests/module_factory.py +++ b/tests/module_factory.py @@ -189,45 +189,44 @@ def create_main_obj(self): @patch(MODULE_DB_MANAGER, name="mock_db") def create_http_analyzer_obj(self, mock_db): http_analyzer = HTTPAnalyzer( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) - - # override the self.print function to avoid broken pipes http_analyzer.print = Mock() return http_analyzer @patch(MODULE_DB_MANAGER, name="mock_db") def create_fidesModule_obj(self, mock_db): fm = FidesModule( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) - - # override the self.print function fm.print = Mock() return fm @patch(MODULE_DB_MANAGER, name="mock_db") def create_virustotal_obj(self, mock_db): virustotal = VT( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) virustotal.print = Mock() virustotal.__read_configuration = Mock() @@ -239,18 +238,111 @@ def create_arp_obj(self, mock_db): "modules.arp.arp.ARP.wait_for_arp_scans", return_value=Mock() ): arp = ARP( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) arp.print = Mock() arp.evidence_filter.is_slips_peer = Mock(return_value=False) return arp + def create_checker_obj(self): + mock_main = Mock() + mock_main.args = MagicMock() + mock_main.args.output = "test_output" + mock_main.args.verbose = "0" + mock_main.args.debug = "0" + mock_main.redis_man = Mock() + 
mock_main.terminate_slips = Mock() + mock_main.print_version = Mock() + mock_main.get_input_file_type = Mock() + mock_main.handle_flows_from_stdin = Mock() + mock_main.pid = 12345 + + checker = Checker(mock_main) + return checker + + @patch(MODULE_DB_MANAGER, name="mock_db") + def create_go_director_obj(self, mock_db): + with patch("modules.p2ptrust.utils.utils.send_evaluation_to_go"): + go_director = GoDirector( + logger=self.logger, + trustdb=Mock(spec=TrustDB), + db=mock_db, + storage_name="test_storage", + override_p2p=False, + gopy_channel="test_gopy", + pygo_channel="test_pygo", + p2p_reports_logfile="test_reports.log", + ) + go_director.print = Mock() + return go_director + + @patch(DB_MANAGER, name="mock_db") + def create_daemon_object(self, mock_db): + with ( + patch("slips.daemon.Daemon.read_pidfile", return_type=None), + patch("slips.daemon.Daemon.read_configuration"), + patch("builtins.open", mock_open(read_data=None)), + ): + daemon = Daemon(MagicMock()) + daemon.stderr = "errors.log" + daemon.stdout = "slips.log" + daemon.stdin = "/dev/null" + daemon.logsfile = "slips.log" + daemon.pidfile_dir = "/tmp" + daemon.pidfile = os.path.join(daemon.pidfile_dir, "slips_daemon.lock") + daemon.daemon_start_lock = "slips_daemon_start" + daemon.daemon_stop_lock = "slips_daemon_stop" + return daemon + + @contextmanager + def dummy_acquire_flock(self): + yield + + @patch("sqlite3.connect") + def create_trust_db_obj(self, sqlite_mock): + with ( + patch("slips_files.common.abstracts.isqlite.ISQLite._init_flock"), + patch( + "slips_files.common.abstracts.isqlite.ISQLite._acquire_flock" + ), + ): + trust_db = TrustDB( + logger=self.logger, + db_file=Mock(), + main_pid=Mock(), + drop_tables_on_startup=False, + ) + trust_db.conn = Mock() + trust_db.print = Mock() + trust_db._init_flock = Mock() + trust_db._acquire_flock = MagicMock() + return trust_db + + @patch(MODULE_DB_MANAGER, name="mock_db") + def create_base_model_obj(self, mock_db): + logger = Mock(spec=Output) + trustdb = Mock() + return BaseModel(logger, trustdb, mock_db) + + def create_notify_obj(self): + notify = Notify() + return notify + + def create_ioc_handler_obj(self): + handler = IoCHandler() + handler.r = Mock() + handler.rcache = Mock() + handler.constants = Constants() + handler.channels = Channels() + return handler + @patch(MODULE_DB_MANAGER, name="mock_db") def create_arp_filter_obj(self, mock_db): filter = ARPEvidenceFilter(Mock(), Mock(), mock_db) # conf # args @@ -259,13 +351,14 @@ def create_arp_filter_obj(self, mock_db): @patch(MODULE_DB_MANAGER, name="mock_db") def create_blocking_obj(self, mock_db): blocking = Blocking( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) # override the print function to avoid broken pipes blocking.print = Mock() @@ -295,9 +388,8 @@ def create_flowalerts_obj(self, mock_db): slips_args=Mock(), conf=Mock(), ppid=Mock(), + bloom_filters_manager=Mock(), ) - - # override the self.print function to avoid broken pipes flowalerts.print = Mock() return flowalerts @@ -346,6 +438,21 @@ def create_software_analyzer_obj(self, mock_db): flowalerts = self.create_flowalerts_obj() return Software(flowalerts.db, flowalerts=flowalerts) + @patch(MODULE_DB_MANAGER, name="mock_db") + def create_ip_info_obj(self, mock_db): + ip_info = IPInfo( + 
logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), + ) + ip_info.print = Mock() + return ip_info + @patch(MODULE_DB_MANAGER, name="mock_db") def create_input_obj( self, input_information, input_type, mock_db, line_type=False @@ -359,6 +466,7 @@ def create_input_obj( slips_args=Mock(), conf=Mock(), ppid=Mock(), + bloom_filters_manager=Mock(), is_input_done=Mock(), profiler_queue=self.profiler_queue, input_type=input_type, @@ -379,40 +487,24 @@ def create_input_obj( return input - @patch(MODULE_DB_MANAGER, name="mock_db") - def create_ip_info_obj(self, mock_db): - ip_info = IPInfo( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid - ) - # override the self.print function to avoid broken pipes - ip_info.print = Mock() - return ip_info - @patch(DB_MANAGER, name="mock_db") def create_asn_obj(self, mock_db): return ASN(mock_db) @patch(MODULE_DB_MANAGER, name="mock_db") def create_leak_detector_obj(self, mock_db): - # this file will be used for storing the module output - # and deleted when the tests are done test_pcap = "dataset/test7-malicious.pcap" yara_rules_path = "tests/yara_rules_for_testing/rules/" compiled_yara_rules_path = "tests/yara_rules_for_testing/compiled/" leak_detector = LeakDetector( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) leak_detector.print = Mock() # this is the path containing 1 yara rule for testing, @@ -432,11 +524,11 @@ def create_profiler_obj(self, mock_db): slips_args=Mock(), conf=Mock(), ppid=Mock(), + bloom_filters_manager=Mock(), is_profiler_done=Mock(), profiler_queue=self.input_queue, is_profiler_done_event=Mock(), ) - # override the self.print function to avoid broken pipes profiler.print = Mock() profiler.local_whitelist_path = "tests/test_whitelist.conf" profiler.db = mock_db @@ -469,16 +561,15 @@ def create_utils_obj(self): @patch(MODULE_DB_MANAGER, name="mock_db") def create_threatintel_obj(self, mock_db): threatintel = ThreatIntel( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) - - # override the self.print function to avoid broken pipes threatintel.print = Mock() return threatintel @@ -489,21 +580,26 @@ def create_spamhaus_obj(self, mock_db): @patch(MODULE_DB_MANAGER, name="mock_db") def create_update_manager_obj(self, mock_db): update_manager = UpdateManager( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) - # override the self.print function to avoid broken pipes update_manager.print = Mock() return update_manager @patch(MODULE_DB_MANAGER, name="mock_db") def create_whitelist_obj(self, mock_db): - whitelist = Whitelist(self.logger, mock_db) + 
bloom_filter_manager_mock = Mock() + whitelist = Whitelist( + self.logger, + mock_db, + bloom_filter_manager=bloom_filter_manager_mock, + ) # override the self.print function to avoid broken pipes whitelist.print = Mock() whitelist.whitelist_path = "tests/test_whitelist.conf" @@ -607,122 +703,47 @@ def create_evidence_obj( @patch(MODULE_DB_MANAGER, name="mock_db") def create_network_discovery_obj(self, mock_db): network_discovery = NetworkDiscovery( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) return network_discovery + def create_markov_chain_obj(self): + return Matrix() + @patch(MODULE_DB_MANAGER, name="mock_db") def create_arp_poisoner_obj(self, mock_db): poisoner = ARPPoisoner( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) return poisoner - def create_markov_chain_obj(self): - return Matrix() - - def create_checker_obj(self): - mock_main = Mock() - mock_main.args = MagicMock() - mock_main.args.output = "test_output" - mock_main.args.verbose = "0" - mock_main.args.debug = "0" - mock_main.redis_man = Mock() - mock_main.terminate_slips = Mock() - mock_main.print_version = Mock() - mock_main.get_input_file_type = Mock() - mock_main.handle_flows_from_stdin = Mock() - mock_main.pid = 12345 - - checker = Checker(mock_main) - return checker - - @patch(MODULE_DB_MANAGER, name="mock_db") - def create_go_director_obj(self, mock_db): - with patch("modules.p2ptrust.utils.utils.send_evaluation_to_go"): - go_director = GoDirector( - logger=self.logger, - trustdb=Mock(spec=TrustDB), - db=mock_db, - storage_name="test_storage", - override_p2p=False, - gopy_channel="test_gopy", - pygo_channel="test_pygo", - p2p_reports_logfile="test_reports.log", - ) - go_director.print = Mock() - return go_director - - @patch(DB_MANAGER, name="mock_db") - def create_daemon_object(self, mock_db): - with ( - patch("slips.daemon.Daemon.read_pidfile", return_type=None), - patch("slips.daemon.Daemon.read_configuration"), - patch("builtins.open", mock_open(read_data=None)), - ): - daemon = Daemon(MagicMock()) - daemon.stderr = "errors.log" - daemon.stdout = "slips.log" - daemon.stdin = "/dev/null" - daemon.logsfile = "slips.log" - daemon.pidfile_dir = "/tmp" - daemon.pidfile = os.path.join(daemon.pidfile_dir, "slips_daemon.lock") - daemon.daemon_start_lock = "slips_daemon_start" - daemon.daemon_stop_lock = "slips_daemon_stop" - return daemon - - @contextmanager - def dummy_acquire_flock(self): - yield - - @patch("sqlite3.connect") - def create_trust_db_obj(self, sqlite_mock): - with ( - patch("slips_files.common.abstracts.isqlite.ISQLite._init_flock"), - patch( - "slips_files.common.abstracts.isqlite.ISQLite._acquire_flock" - ), - ): - trust_db = TrustDB( - logger=self.logger, - db_file=Mock(), - main_pid=Mock(), - drop_tables_on_startup=False, - ) - trust_db.conn = Mock() - trust_db.print = Mock() - trust_db._init_flock = Mock() - trust_db._acquire_flock = MagicMock() - return trust_db - @patch(MODULE_DB_MANAGER, name="mock_db") - def create_base_model_obj(self, mock_db): - logger = 
Mock(spec=Output) - trustdb = Mock() - return BaseModel(logger, trustdb, mock_db) - - def create_notify_obj(self): - notify = Notify() - return notify - - def create_ioc_handler_obj(self): - handler = IoCHandler() - handler.r = Mock() - handler.rcache = Mock() - handler.constants = Constants() - handler.channels = Channels() + def create_evidence_handler_obj(self, mock_db): + handler = EvidenceHandler( + logger=Mock(), + output_dir="/tmp", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), + ) + handler.db = mock_db return handler @patch(MODULE_DB_MANAGER, name="mock_db") @@ -737,6 +758,7 @@ def create_cesnet_obj(self, mock_db): Mock(), # args Mock(), # conf Mock(), # ppid + Mock(), # Bloom filter manager ) cesnet.db = mock_db cesnet.wclient = MagicMock() @@ -747,20 +769,6 @@ def create_cesnet_obj(self, mock_db): cesnet.print = MagicMock() return cesnet - @patch(MODULE_DB_MANAGER, name="mock_db") - def create_evidence_handler_obj(self, mock_db): - handler = EvidenceHandler( - logger=Mock(), - output_dir="/tmp", - redis_port=6379, - termination_event=Mock(), - slips_args=Mock(), - conf=Mock(), - ppid=Mock(), - ) - handler.db = mock_db - return handler - @patch(MODULE_DB_MANAGER, name="mock_db") def create_evidence_formatter_obj(self, mock_db): args = Mock() @@ -775,13 +783,14 @@ def create_symbol_handler_obj(self, mock_db): @patch(MODULE_DB_MANAGER, name="mock_db") def create_riskiq_obj(self, mock_db): riskiq = RiskIQ( - self.logger, - "dummy_output_dir", - 6379, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=self.logger, + output_dir="dummy_output_dir", + redis_port=6379, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) riskiq.db = mock_db return riskiq @@ -792,13 +801,14 @@ def create_timeline_object(self, mock_db): output_dir = "/tmp" redis_port = 6379 tl = Timeline( - logger, - output_dir, - redis_port, - Mock(), # termination event - Mock(), # args - Mock(), # conf - Mock(), # ppid + logger=logger, + output_dir=output_dir, + redis_port=redis_port, + termination_event=Mock(), + slips_args=Mock(), + conf=Mock(), + ppid=Mock(), + bloom_filters_manager=Mock(), ) tl.db = mock_db return tl diff --git a/tests/test_evidence_handler.py b/tests/test_evidence_handler.py index af47ea0ce..73498a841 100644 --- a/tests/test_evidence_handler.py +++ b/tests/test_evidence_handler.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: GPL-2.0-only import pytest import os -from unittest.mock import Mock, MagicMock, patch, call +from unittest.mock import Mock, MagicMock, patch from slips_files.core.structures.alerts import Alert from slips_files.core.structures.evidence import ( @@ -185,20 +185,21 @@ def test_clean_file(output_dir, file_to_clean, file_exists): @pytest.mark.parametrize( "data", [ - # testcase1: Basic log entry "Test log entry", - # testcase2: Another log entry "Another log entry", ], ) def test_add_to_log_file(data): evidence_handler = ModuleFactory().create_evidence_handler_obj() - mock_file = Mock() - evidence_handler.logfile = mock_file + evidence_handler.evidence_logger_q.put = Mock() + + # Act evidence_handler.add_to_log_file(data) - assert mock_file.write.call_count == 2 - mock_file.write.assert_has_calls([call(data), call("\n")]) - mock_file.flush.assert_called_once() + + # Assert + evidence_handler.evidence_logger_q.put.assert_called_once_with( + {"to_log": data, "where": "alerts.log"} + ) 
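The two tests above pin down the new logging contract: `EvidenceHandler` no longer writes `alerts.log` or `alerts.json` itself. `add_to_log_file()` and `add_alert_to_json_log_file()` only enqueue a `{"to_log": ..., "where": ...}` record on `evidence_logger_q`, and a separate consumer is expected to drain the queue and do the actual file I/O. Only the producer side is tested here; below is a minimal sketch of a matching consumer, assuming the record shape from the tests and a hypothetical `"stop"` sentinel (the real Slips logger process may differ):

```python
import json
import os


def drain_evidence_logs(q, output_dir: str) -> None:
    """Hypothetical consumer for evidence_logger_q records.

    Appends each queued record to its target file; the "where" key selects
    alerts.log (plain text) or alerts.json (one JSON object per line).
    """
    while True:
        record = q.get()  # blocks until a record arrives
        if record == "stop":  # assumed shutdown sentinel, not from the tests
            break
        path = os.path.join(output_dir, record["where"])
        to_log = record["to_log"]
        if not isinstance(to_log, str):
            # e.g. an IDMEF alert dict destined for alerts.json
            to_log = json.dumps(to_log)
        with open(path, "a") as f:
            f.write(to_log + "\n")
```

Funneling every write through a single queue consumer also keeps multiple Slips processes from interleaving partial lines in the same log file, which is presumably why the direct `write()`/`flush()` calls were removed.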
@pytest.mark.parametrize( @@ -247,11 +248,18 @@ def test_add_alert_to_json_log_file( ) evidence_handler = ModuleFactory().create_evidence_handler_obj() evidence_handler.jsonfile = mock_file - evidence_handler.idmefv2.convert_to_idmef_alert = Mock(return_value=True) - with patch("json.dump") as mock_json_dump: - evidence_handler.add_alert_to_json_log_file(alert) - mock_json_dump.assert_called_once() - mock_file.write.assert_any_call("\n") + evidence_handler.idmefv2.convert_to_idmef_alert = Mock( + return_value="alert_in_idmef_format" + ) + evidence_handler.evidence_logger_q.put = Mock() + + evidence_handler.add_alert_to_json_log_file(alert) + evidence_handler.evidence_logger_q.put.assert_called_once_with( + { + "to_log": "alert_in_idmef_format", + "where": "alerts.json", + } + ) def test_show_popup(): diff --git a/tests/test_process_manager.py b/tests/test_process_manager.py index 765aa73a4..125a198a6 100644 --- a/tests/test_process_manager.py +++ b/tests/test_process_manager.py @@ -34,6 +34,7 @@ def test_start_input_process( process_manager.main.zeek_bro = zeek_or_bro process_manager.main.zeek_dir = zeek_dir process_manager.main.line_type = line_type + process_manager.main.bloom_filters_man = Mock() with patch("managers.process_manager.Input") as mock_input: mock_input_process = Mock() @@ -51,6 +52,7 @@ def test_start_input_process( process_manager.main.args, process_manager.main.conf, process_manager.main.pid, + process_manager.main.bloom_filters_man, is_input_done=process_manager.is_input_done, profiler_queue=process_manager.profiler_queue, input_type=input_type, @@ -394,6 +396,7 @@ def test_print_stopped_module(): def test_start_profiler_process(): process_manager = ModuleFactory().create_process_manager_obj() + process_manager.main.bloom_filters_man = Mock() with patch("managers.process_manager.Profiler") as mock_profiler: mock_profiler_process = Mock() mock_profiler.return_value = mock_profiler_process @@ -410,6 +413,7 @@ def test_start_profiler_process(): process_manager.main.args, process_manager.main.conf, process_manager.main.pid, + process_manager.main.bloom_filters_man, is_profiler_done=process_manager.is_profiler_done, profiler_queue=process_manager.profiler_queue, is_profiler_done_event=process_manager.is_profiler_done_event, @@ -432,6 +436,7 @@ def test_start_profiler_process(): ) def test_start_evidence_process(output_dir, redis_port): process_manager = ModuleFactory().create_process_manager_obj() + process_manager.main.bloom_filters_man = Mock() process_manager.main.args.output = output_dir process_manager.main.redis_port = redis_port @@ -451,6 +456,7 @@ def test_start_evidence_process(output_dir, redis_port): process_manager.main.args, process_manager.main.conf, process_manager.main.pid, + process_manager.main.bloom_filters_man, ) mock_evidence_process.start.assert_called_once() process_manager.main.print.assert_called_once() diff --git a/tests/test_whitelist.py b/tests/test_whitelist.py index 35ce8768a..855c0bb47 100644 --- a/tests/test_whitelist.py +++ b/tests/test_whitelist.py @@ -3,7 +3,7 @@ from tests.module_factory import ModuleFactory import pytest import json -from unittest.mock import MagicMock, patch, Mock +from unittest.mock import MagicMock, patch, Mock, mock_open from slips_files.core.structures.evidence import ( Direction, IoCType, @@ -115,50 +115,132 @@ def test_get_dst_domains_of_flow(flow_type, expected_result): @pytest.mark.parametrize( - "ip, org, org_ips, expected_result", + "ip, org, cidrs, mock_bf_octets, expected_result", [ - ("216.58.192.1", 
"google", {"216": ["216.58.192.0/19"]}, True), - ("8.8.8.8", "cloudflare", {"216": ["216.58.192.0/19"]}, False), - ("8.8.8.8", "google", {}, False), # no org ip info + # Case 1: Bloom filter hit, DB hit + ("216.58.192.1", "google", ["216.58.192.0/19"], ["216"], True), + # Case 2: Bloom filter hit, DB miss + ("8.8.8.8", "cloudflare", [], ["8"], False), + # Case 3: Bloom filter MISS + # The 'ip' starts with "192", but we'll only put "10" in the filter + ("192.168.1.1", "my_org", [], ["10"], False), ], ) -def test_is_ip_in_org( +def test_is_ip_in_org_complete( ip, org, - org_ips, + cidrs, + mock_bf_octets, expected_result, ): whitelist = ModuleFactory().create_whitelist_obj() - whitelist.db.get_org_ips.return_value = org_ips - result = whitelist.org_analyzer.is_ip_in_org(ip, org) + analyzer = whitelist.org_analyzer + analyzer.bloom_filters = {org: {"first_octets": mock_bf_octets}} + + whitelist.db.is_ip_in_org_ips.return_value = cidrs + + result = analyzer.is_ip_in_org(ip, org) assert result == expected_result @pytest.mark.parametrize( - "domain, org, org_domains, expected_result", + "domain, org, mock_bf_domains, mock_db_exact, mock_db_org_list, " + "mock_tld_side_effect, expected_result", [ - ("www.google.com", "google", json.dumps(["google.com"]), True), - ("www.example.com", "google", json.dumps(["google.com"]), None), + # --- Case 1: Bloom Filter MISS --- + # The domain isn't even in the bloom filter. + ("google.com", "google", ["other.com"], None, None, None, False), + # --- Case 2: Bloom Filter HIT, DB Exact Match HIT --- + # BF hits, and db.is_domain_in_org_domains finds it. + ("google.com", "google", ["google.com"], True, None, None, True), + # --- Case 3: Subdomain Match (org_domain IN domain) --- + # 'google.com' (from db) is IN 'ads.google.com' (flow domain) + ( + "ads.google.com", + "google", + ["ads.google.com"], # 1. BF Hit + False, # 2. DB Exact Miss + ["google.com"], # 3. DB Org List + ["google.com", "google.com"], # 4. TLDs match (ads.google.com + # -> google.com, google.com -> google.com) + True, # 5. Expected: True + ), + # --- Case 4: Reverse Subdomain Match (domain IN org_domain) --- + # 'google.com' (flow domain) is IN 'ads.google.com' (from db) ( - "www.google.com", + "google.com", "google", - json.dumps([]), - None, - ), # no org domain info + ["google.com"], # 1. BF Hit + False, # 2. DB Exact Miss + ["ads.google.com"], # 3. DB Org List + ["google.com", "google.com"], # 4. TLDs match + True, # 5. Expected: True + ), + # --- Case 5: TLD Mismatch --- + # TLDs (google.net vs google.com) don't match, so 'continue' is hit. + ( + "google.net", + "google", + ["google.net"], # 1. BF Hit + False, # 2. DB Exact Miss + ["google.com"], # 3. org_domains + ["google.net", "google.com"], # 4. TLDs mismatch + False, # 5. Expected: False + ), + # --- Case 6: No Match (Falls through) --- + # TLDs match, but neither is a substring of the other. + ( + "evil-oogle.com", + "google", + ["evil-google.com"], # 1. BF should Hit + False, # 2. DB Exact Miss + ["google.com"], # 3. org_domains + ["google.com", "google.com"], # 4. TLDs match + False, # 5. 
Expected: False + ), ], ) def test_is_domain_in_org( domain, org, - org_domains, + mock_bf_domains, + mock_db_exact, + mock_db_org_list, + mock_tld_side_effect, expected_result, ): whitelist = ModuleFactory().create_whitelist_obj() - whitelist.db.get_org_info.return_value = org_domains - result = whitelist.org_analyzer.is_domain_in_org(domain, org) + analyzer = whitelist.org_analyzer + + analyzer.bloom_filters = {org: {"domains": mock_bf_domains}} + + whitelist.db.is_domain_in_org_domains.return_value = mock_db_exact + + whitelist.db.get_org_info.return_value = mock_db_org_list + # The first call is for 'domain', the second for 'org_domain' + if mock_tld_side_effect: + analyzer.domain_analyzer.get_tld = MagicMock( + side_effect=mock_tld_side_effect + ) + result = analyzer.is_domain_in_org(domain, org) assert result == expected_result +def test_is_domain_in_org_key_error(): + """ + Tests the 'try...except KeyError' block. + This happens if the 'org' isn't in the bloom_filters dict. + """ + whitelist = ModuleFactory().create_whitelist_obj() + analyzer = whitelist.org_analyzer + analyzer.bloom_filters = {} + # Accessing analyzer.bloom_filters["google"] will raise a KeyError, + # which should be caught and return False. + result = analyzer.is_domain_in_org("google.com", "google") + + assert not result + + @pytest.mark.parametrize( "is_whitelisted_victim, is_whitelisted_attacker, expected_result", [ @@ -188,14 +270,14 @@ def test_is_whitelisted_evidence( "b1:b1:b1:c1:c2:c3", Direction.SRC, False, - {"b1:b1:b1:c1:c2:c3": {"from": "src", "what_to_ignore": "alerts"}}, + {"from": "src", "what_to_ignore": "alerts"}, ), ( "5.6.7.8", "a1:a2:a3:a4:a5:a6", Direction.DST, True, - {"a1:a2:a3:a4:a5:a6": {"from": "dst", "what_to_ignore": "both"}}, + {"from": "dst", "what_to_ignore": "both"}, ), ("9.8.7.6", "c1:c2:c3:c4:c5:c6", Direction.SRC, False, {}), ], @@ -208,8 +290,15 @@ def test_profile_has_whitelisted_mac( whitelisted_macs, ): whitelist = ModuleFactory().create_whitelist_obj() + # act as it is present in the bloom filter + whitelist.bloom_filters.mac_addrs = mac_address + whitelist.db.get_mac_addr_from_profile.return_value = mac_address - whitelist.db.get_whitelist.return_value = whitelisted_macs + if whitelisted_macs: + whitelist.db.is_whitelisted.return_value = json.dumps(whitelisted_macs) + else: + whitelist.db.is_whitelisted.return_value = None + assert ( whitelist.mac_analyzer.profile_has_whitelisted_mac( profile_ip, direction, "both" @@ -237,14 +326,16 @@ def test_matching_direction(direction, whitelist_direction, expected_result): @pytest.mark.parametrize( "ioc_data, expected_result", [ + # Private IP should short-circuit -> False ( { "ioc_type": IoCType.IP, - "value": "1.2.3.4", + "value": "192.168.1.1", "direction": Direction.SRC, }, False, ), + # Domain belonging to whitelisted org -> True ( { "ioc_type": IoCType.DOMAIN, @@ -253,6 +344,7 @@ def test_matching_direction(direction, whitelist_direction, expected_result): }, True, ), + # Public IP not in whitelisted org -> False ( { "ioc_type": IoCType.IP, @@ -263,51 +355,62 @@ def test_matching_direction(direction, whitelist_direction, expected_result): ), ], ) -def test_is_part_of_a_whitelisted_org( - ioc_data, - expected_result, -): +def test_is_part_of_a_whitelisted_org(ioc_data, expected_result): whitelist = ModuleFactory().create_whitelist_obj() - whitelist.db.get_whitelist.return_value = { - "google": {"from": "both", "what_to_ignore": "both"} + whitelist.org_analyzer.whitelisted_orgs = { + "google": json.dumps({"from": "both", 
"what_to_ignore": "both"}) } - whitelist.db.get_org_info.return_value = json.dumps(["1.2.3.4/32"]) - whitelist.db.get_ip_info.return_value = {"asn": {"asnorg": "Google"}} - whitelist.db.get_org_info.return_value = json.dumps(["example.com"]) - # we're mocking either an attacker or a victim obj - mock_ioc = MagicMock() - mock_ioc.value = ioc_data["value"] - mock_ioc.direction = ioc_data["direction"] - mock_ioc.ioc_type = ioc_data["ioc_type"] - assert ( - whitelist.org_analyzer._is_part_of_a_whitelisted_org( - mock_ioc.value, mock_ioc.ioc_type, mock_ioc.direction, "both" - ) - == expected_result + # mock dependent methods + whitelist.org_analyzer.is_domain_in_org = MagicMock(return_value=True) + whitelist.org_analyzer.is_ip_part_of_a_whitelisted_org = MagicMock( + return_value=False ) + whitelist.match = MagicMock() + whitelist.match.direction.return_value = True + whitelist.match.what_to_ignore.return_value = True + + with patch( + "slips_files.core.helpers.whitelist.organization_whitelist." + "utils.is_private_ip", + return_value=False, + ): + result = whitelist.org_analyzer._is_part_of_a_whitelisted_org( + ioc=ioc_data["value"], + ioc_type=ioc_data["ioc_type"], + direction=ioc_data["direction"], + what_to_ignore="both", + ) + + assert result == expected_result + @pytest.mark.parametrize( - "dst_domains, src_domains, whitelisted_domains, expected_result", + "dst_domains, src_domains, whitelisted_domains, " + "is_whitelisted_return_vals, expected_result", [ ( ["dst_domain.net"], ["apple.com"], {"apple.com": {"from": "src", "what_to_ignore": "both"}}, + [False, True], True, ), ( - ["apple.com"], + ["apple.com"], # dst domains, shouldnt be whitelisted ["src.com"], {"apple.com": {"from": "src", "what_to_ignore": "both"}}, + [False, False], False, ), - (["apple.com"], ["src.com"], {}, False), # no whitelist found + (["apple.com"], ["src.com"], {}, [False, False], False), + # no whitelist found ( # no flow domains found [], [], {"apple.com": {"from": "src", "what_to_ignore": "both"}}, + [False, False], False, ), ], @@ -316,22 +419,19 @@ def test_check_if_whitelisted_domains_of_flow( dst_domains, src_domains, whitelisted_domains, + is_whitelisted_return_vals, expected_result, ): whitelist = ModuleFactory().create_whitelist_obj() + whitelist.bloom_filters.domains = list(whitelisted_domains.keys()) whitelist.db.get_whitelist.return_value = whitelisted_domains - whitelist.domain_analyzer.is_domain_in_tranco_list = Mock() - whitelist.domain_analyzer.is_domain_in_tranco_list.return_value = False - - whitelist.domain_analyzer.get_dst_domains_of_flow = Mock() - whitelist.domain_analyzer.get_dst_domains_of_flow.return_value = ( - dst_domains + whitelist.domain_analyzer.get_src_domains_of_flow = Mock( + return_value=src_domains ) - whitelist.domain_analyzer.get_src_domains_of_flow = Mock() - whitelist.domain_analyzer.get_src_domains_of_flow.return_value = ( - src_domains + whitelist.domain_analyzer.is_whitelisted = Mock( + side_effect=is_whitelisted_return_vals ) flow = Mock() @@ -344,6 +444,7 @@ def test_is_whitelisted_domain_not_found(): Test when the domain is not found in the whitelisted domains. 
""" whitelist = ModuleFactory().create_whitelist_obj() + whitelist.bloom_filters.domains = [] whitelist.db.get_whitelist.return_value = {} whitelist.db.is_whitelisted_tranco_domain.return_value = False domain = "nonwhitelisteddomain.com" @@ -379,9 +480,17 @@ def test_read_configuration( ) def test_ip_analyzer_is_whitelisted(ip, what_to_ignore, expected_result): whitelist = ModuleFactory().create_whitelist_obj() - whitelist.db.get_whitelist.return_value = { - "1.2.3.4": {"from": "both", "what_to_ignore": "both"} - } + whitelist.bloom_filters.ips = [ip] # Simulate presence in bloom + # filter, because we wanna test the rest of the logic + + # only this ip is whitelisted + if ip == "1.2.3.4": + whitelist.db.is_whitelisted.return_value = json.dumps( + {"from": "both", "what_to_ignore": "both"} + ) + else: + whitelist.db.is_whitelisted.return_value = None + assert ( whitelist.ip_analyzer.is_whitelisted(ip, Direction.SRC, what_to_ignore) == expected_result @@ -485,47 +594,67 @@ def test_is_whitelisted_entity_victim( @pytest.mark.parametrize( - "org, expected_result", + "org, file_content, expected_result", [ - ("google", ["google.com", "google.co.uk"]), - ("microsoft", ["microsoft.com", "microsoft.net"]), + ( + "google", + "google.com\ngoogle.co.uk\n", + ["google.com", "google.co.uk"], + ), + ( + "microsoft", + "microsoft.com\nmicrosoft.net\n", + ["microsoft.com", "microsoft.net"], + ), ], ) -def test_load_org_domains( - org, - expected_result, -): +def test_load_org_domains(org, file_content, expected_result): whitelist = ModuleFactory().create_whitelist_obj() whitelist.db.set_org_info = MagicMock() - actual_result = whitelist.parser.load_org_domains(org) - for domain in expected_result: - assert domain in actual_result + # Mock the file open for reading org domains + with patch("builtins.open", mock_open(read_data=file_content)): + actual_result = whitelist.parser.load_org_domains(org) - assert len(actual_result) >= len(expected_result) - whitelist.db.set_org_info.assert_called_with( - org, json.dumps(actual_result), "domains" + # Check contents + assert actual_result == expected_result + whitelist.db.set_org_info.assert_called_once_with( + org, expected_result, "domains" ) @pytest.mark.parametrize( - "domain, direction, expected_result", + "domain, direction, is_whitelisted_return, expected_result", [ - ("example.com", Direction.SRC, True), - ("test.example.com", Direction.DST, True), - ("malicious.com", Direction.SRC, False), + ( + "example.com", + Direction.SRC, + {"from": "both", "what_to_ignore": "both"}, + True, + ), + ( + "test.example.com", + Direction.DST, + {"from": "both", "what_to_ignore": "both"}, + True, + ), + ("malicious.com", Direction.SRC, {}, False), ], ) def test_is_domain_whitelisted( domain, direction, + is_whitelisted_return, expected_result, ): whitelist = ModuleFactory().create_whitelist_obj() - whitelist.db.get_whitelist.return_value = { - "example.com": {"from": "both", "what_to_ignore": "both"} - } + whitelist.db.is_whitelisted.return_value = json.dumps( + is_whitelisted_return + ) + whitelist.db.is_whitelisted_tranco_domain.return_value = False + whitelist.bloom_filters.domains = ["example.com"] + for type_ in ("alerts", "flows"): result = whitelist.domain_analyzer.is_whitelisted( domain, direction, type_ @@ -539,36 +668,36 @@ def test_is_domain_whitelisted( ( "8.8.8.8", "google", - json.dumps(["AS6432"]), + ["AS6432"], {"asn": {"number": "AS6432"}}, True, ), ( "1.1.1.1", "cloudflare", - json.dumps(["AS6432"]), + ["AS6432"], {"asn": {"number": "AS6432"}}, True, ), 
( "8.8.8.8", "Google", - json.dumps(["AS15169"]), + ["AS15169"], {"asn": {"number": "AS15169", "asnorg": "Google"}}, True, ), ( "1.1.1.1", "Cloudflare", - json.dumps(["AS13335"]), + ["AS13335"], {"asn": {"number": "AS15169", "asnorg": "Google"}}, False, ), - ("9.9.9.9", "IBM", json.dumps(["AS36459"]), {}, None), + ("9.9.9.9", "IBM", ["AS36459"], {}, False), ( "9.9.9.9", "IBM", - json.dumps(["AS36459"]), + ["AS36459"], {"asn": {"number": "Unknown"}}, False, ), @@ -578,80 +707,15 @@ def test_is_ip_asn_in_org_asn( ip, org, org_asn_info, ip_asn_info, expected_result ): whitelist = ModuleFactory().create_whitelist_obj() - whitelist.db.get_org_info.return_value = org_asn_info + + whitelist.db = MagicMock() whitelist.db.get_ip_info.return_value = ip_asn_info - assert ( - whitelist.org_analyzer.is_ip_asn_in_org_asn(ip, org) == expected_result - ) + whitelist.db.get_org_info.return_value = org_asn_info + ip_asn = ip_asn_info.get("asn", {}).get("number", None) + whitelist.org_analyzer._is_asn_in_org = MagicMock( + return_value=ip_asn in org_asn_info + ) -# TODO for sekhar -# @pytest.mark.parametrize( -# "flow_data, whitelist_data, expected_result", -# [ -# ( # testing_is_whitelisted_flow_with_whitelisted_organization_ -# # but_ip_or_domain_not_whitelisted -# MagicMock(saddr="9.8.7.6", daddr="5.6.7.8", type_="http", host="org.com"), -# {"organizations": {"org": {"from": "both", "what_to_ignore": "flows"}}}, -# False, -# ), -# ( # testing_is_whitelisted_flow_with_non_whitelisted_organizatio -# # n_but_ip_or_domain_whitelisted -# MagicMock( -# saddr="1.2.3.4", -# daddr="5.6.7.8", -# type_="http", -# host="whitelisted.com", -# ), -# {"IPs": {"1.2.3.4": {"from": "src", "what_to_ignore": "flows"}}}, -# False, -# ), -# ( # testing_is_whitelisted_flow_with_whitelisted_source_ip -# MagicMock( -# saddr="1.2.3.4", -# daddr="5.6.7.8", -# type_="http", -# server_name="example.com", -# ), -# {"IPs": {"1.2.3.4": {"from": "src", "what_to_ignore": "flows"}}}, -# False, -# ), -# ( # testing_is_whitelisted_flow_with_both_source_and_destination_ips_whitelisted -# MagicMock(saddr="1.2.3.4", daddr="5.6.7.8", type_="http"), -# { -# "IPs": { -# "1.2.3.4": {"from": "src", "what_to_ignore": "flows"}, -# "5.6.7.8": {"from": "dst", "what_to_ignore": "flows"}, -# } -# }, -# False, -# ), -# ( -# # testing_is_whitelisted_flow_with_whitelisted_mac_address_but_ip_not_whitelisted -# MagicMock( -# saddr="9.8.7.6", -# daddr="1.2.3.4", -# smac="b1:b1:b1:c1:c2:c3", -# dmac="a1:a2:a3:a4:a5:a6", -# type_="http", -# server_name="example.org", -# ), -# { -# "mac": { -# "b1:b1:b1:c1:c2:c3": { -# "from": "src", -# "what_to_ignore": "flows", -# } -# } -# }, -# False, -# ), -# ], -# ) -# def test_is_whitelisted_flow( flow_data, whitelist_data, expected_result): -# """ -# Test the is_whitelisted_flow method with various combinations of flow data and whitelist data. -# """ -# whitelist.db.get_all_whitelist.return_value = whitelist_data -# whitelist = ModuleFactory().create_whitelist_obj() -# assert whitelist.is_whitelisted_flow(flow_data) == expected_result + result = whitelist.org_analyzer.is_ip_asn_in_org_asn(ip, org) + assert result == expected_result
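Taken together, the rewritten whitelist tests encode a two-tier lookup: an in-memory Bloom filter (`whitelist.bloom_filters.ips`, `.domains`, `.mac_addrs`, and the per-org `first_octets`/`domains` filters) is consulted first, and the authoritative Redis calls (`db.is_whitelisted`, `db.is_ip_in_org_ips`, `db.is_domain_in_org_domains`) are only made on a filter hit. Bloom filters can return false positives but never false negatives, so a miss is a definitive "not whitelisted" and skips the DB round-trip entirely, which is the common case for most flows. A minimal sketch of the pattern, using `pybloom_live` and a simplified `db` interface modeled on the mocks in the tests above rather than the exact Slips API:

```python
from pybloom_live import BloomFilter

# Built once at startup from whitelist.conf entries.
ip_filter = BloomFilter(capacity=100_000, error_rate=0.001)
ip_filter.add("1.2.3.4")


def is_ip_whitelisted(ip: str, db) -> bool:
    """Two-tier membership check: Bloom filter first, DB only on a hit."""
    if ip not in ip_filter:
        # No false negatives: a miss means the IP was never whitelisted,
        # so we can return without touching Redis at all.
        return False
    # A hit may be a false positive (rate bounded by error_rate above),
    # so confirm against the authoritative store.
    return db.is_whitelisted(ip) is not None
```

With `error_rate=0.001`, at most roughly 0.1% of non-whitelisted IPs pay for the extra DB lookup; every other flow is resolved by a pure in-process membership test.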