diff --git a/managers/metadata_manager.py b/managers/metadata_manager.py
index ae5bec310..49202b123 100644
--- a/managers/metadata_manager.py
+++ b/managers/metadata_manager.py
@@ -128,7 +128,7 @@ def update_slips_stats_in_the_db(self) -> Tuple[int, Set[str]]:
         """
         updates the number of processed ips, slips internal time, and
         modified tws so far in the db
         """
-        slips_internal_time = float(self.main.db.getSlipsInternalTime()) + 1
+        slips_internal_time = float(self.main.db.get_slips_internal_time()) + 1

         # Get the amount of modified profiles since we last checked
         # this is the modification time of the last timewindow
diff --git a/managers/redis_manager.py b/managers/redis_manager.py
index 6d33437be..34a67ced2 100644
--- a/managers/redis_manager.py
+++ b/managers/redis_manager.py
@@ -13,7 +13,7 @@


 class RedisManager:
-    open_servers_pids: Dict[int, int]
+    open_servers_pids: Dict[int, dict]

     def __init__(self, main):
         self.main = main
@@ -240,19 +240,19 @@ def get_pid_of_redis_server(self, port: int) -> int:
         return False

     @staticmethod
-    def is_comment(line: str) -> True:
+    def is_comment(line: str) -> bool:
        """returns true if the given line is a comment"""
         return (line.startswith("#") or line.startswith("Date")) or len(
             line
         ) < 3

-    def get_open_redis_servers(self) -> Dict[int, int]:
+    def get_open_redis_servers(self) -> Dict[int, dict]:
         """
         fills and returns self.open_servers_PIDs
         with PIDs and ports of the redis servers started by slips
         read from running_slips.info.txt
         """
-        self.open_servers_pids = {}
+        self.open_servers_pids: Dict[int, dict] = {}
         try:
             with open(self.running_logfile, "r") as f:
                 for line in f.read().splitlines():
@@ -263,8 +263,29 @@ def get_open_redis_servers(self) -> Dict[int, int]:
                     line = line.split(",")

                     try:
-                        pid, port = int(line[3]), int(line[2])
-                        self.open_servers_pids[pid] = port
+                        (
+                            timestamp,
+                            file_or_interface,
+                            port,
+                            pid,
+                            zeek_dir,
+                            output_dir,
+                            slips_pid,
+                            is_daemon,
+                            save_the_db,
+                        ) = line
+
+                        self.open_servers_pids[int(pid)] = {
+                            "timestamp": timestamp,
+                            "file_or_interface": file_or_interface,
+                            "port": port,
+                            "pid": pid,
+                            "zeek_dir": zeek_dir,
+                            "output_dir": output_dir,
+                            "slips_pid": slips_pid,
+                            "is_daemon": is_daemon,
+                            "save_the_db": save_the_db,
+                        }
                     except ValueError:
                         # sometimes slips can't get the server pid and logs "False"
                         # in the logfile instead of the PID
@@ -379,7 +400,8 @@ def flush_redis_server(self, pid: int = None, port: int = None):
         if not hasattr(self, "open_servers_PIDs"):
             self.get_open_redis_servers()

-        port: int = self.open_servers_pids.get(pid, False)
+        pid_info: Dict[str, str] = self.open_servers_pids.get(pid, {})
+        port: int = pid_info.get("port", False)
         if not port:
             # try to get the port using a cmd
             port: int = self.get_port_of_redis_server(pid)
diff --git a/modules/cesnet/cesnet.py b/modules/cesnet/cesnet.py
index 401bdfc2b..ec5c967a5 100644
--- a/modules/cesnet/cesnet.py
+++ b/modules/cesnet/cesnet.py
@@ -251,7 +251,7 @@ def import_alerts(self):

             src_ips.update({srcip: json.dumps(event_info)})

-        self.db.add_ips_to_IoC(src_ips)
+        self.db.add_ips_to_ioc(src_ips)

     def pre_main(self):
         utils.drop_root_privs()
diff --git a/modules/flowalerts/conn.py b/modules/flowalerts/conn.py
index bd0c4e4ad..d76de3388 100644
--- a/modules/flowalerts/conn.py
+++ b/modules/flowalerts/conn.py
@@ -227,7 +227,7 @@ def check_multiple_reconnection_attempts(self, profileid, twid, flow):
                 # reset the reconnection attempts of this src->dst
                 current_reconnections[key] = (0, [])

-        self.db.setReconnections(profileid, twid, current_reconnections)
+        self.db.set_reconnections(profileid, twid, current_reconnections)

     def is_ignored_ip_data_upload(self, ip):
         """
diff --git a/modules/threat_intelligence/threat_intelligence.py b/modules/threat_intelligence/threat_intelligence.py
index 14973e6dd..b88421634 100644
--- a/modules/threat_intelligence/threat_intelligence.py
+++ b/modules/threat_intelligence/threat_intelligence.py
@@ -693,11 +693,11 @@ def parse_local_ti_file(self, ti_file_path: str) -> bool:
             )

         # Add all loaded malicious ips to the database
-        self.db.add_ips_to_IoC(malicious_ips)
+        self.db.add_ips_to_ioc(malicious_ips)
         # Add all loaded malicious domains to the database
-        self.db.add_domains_to_IoC(malicious_domains)
-        self.db.add_ip_range_to_IoC(malicious_ip_ranges)
-        self.db.add_asn_to_IoC(malicious_asns)
+        self.db.add_domains_to_ioc(malicious_domains)
+        self.db.add_ip_range_to_ioc(malicious_ip_ranges)
+        self.db.add_asn_to_ioc(malicious_asns)
         return True

     def __delete_old_source_ips(self, file):
@@ -724,7 +724,7 @@ def __delete_old_source_ips(self, file):
                 if data["source"] == file:
                     old_data.append(ip)
         if old_data:
-            self.db.delete_ips_from_IoC_ips(old_data)
+            self.db.delete_ips_from_ioc_ips(old_data)

     def __delete_old_source_domains(self, file):
         """Deletes all domain indicators of compromise (IoCs) associated with a specific
@@ -748,7 +748,7 @@ def __delete_old_source_domains(self, file):
                 if data["source"] == file:
                     old_data.append(domain)
         if old_data:
-            self.db.delete_domains_from_IoC_domains(old_data)
+            self.db.delete_domains_from_ioc_domains(old_data)

     def __delete_old_source_data_from_database(self, data_file):
         """Deletes old indicators of compromise (IoCs) associated with a specific source
@@ -837,7 +837,7 @@ def parse_ja3_file(self, path):
                 }
             )
         # Add all loaded JA3 to the database
-        self.db.add_ja3_to_IoC(ja3_dict)
+        self.db.add_ja3_to_ioc(ja3_dict)
         return True

     def parse_jarm_file(self, path):
@@ -901,7 +901,7 @@ def parse_jarm_file(self, path):
                     "threat_level": threat_level,
                 }
             )
-        self.db.add_jarm_to_IoC(jarm_dict)
+        self.db.add_jarm_to_ioc(jarm_dict)
         return True

     def should_update_local_ti_file(self, path_to_local_ti_file: str) -> bool:
@@ -1206,7 +1206,7 @@ def ip_has_blacklisted_asn(
         if not asn:
             return

-        if asn_info := self.db.is_blacklisted_ASN(asn):
+        if asn_info := self.db.is_blacklisted_asn(asn):
             asn_info = json.loads(asn_info)
             self.set_evidence_malicious_asn(
                 ip,
@@ -1359,7 +1359,7 @@ def is_malicious_ip(
             # not malicious
             return False

-        self.db.add_ips_to_IoC({ip: json.dumps(ip_info)})
+        self.db.add_ips_to_ioc({ip: json.dumps(ip_info)})
         if is_dns_response:
             self.set_evidence_malicious_ip_in_dns_response(
                 ip,
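The reworked RedisManager.get_open_redis_servers earlier in this patch replaces the old two-field unpack with a nine-field record per redis server. A minimal, self-contained sketch of that parsing logic, assuming the comma-separated field order shown in the tuple unpack above; the helper name and the sample line are illustrative, not part of the codebase:

from typing import Dict

# Field order assumed from the tuple unpack in get_open_redis_servers().
FIELDS = [
    "timestamp", "file_or_interface", "port", "pid", "zeek_dir",
    "output_dir", "slips_pid", "is_daemon", "save_the_db",
]

def parse_server_line(line: str) -> Dict[int, dict]:
    """Parse one running_slips.info.txt entry into {pid: info}.

    Raises ValueError when the PID column holds "False" or the field
    count is off, which is what the except ValueError branch in the
    patch skips over.
    """
    values = line.split(",")
    if len(values) != len(FIELDS):
        raise ValueError(f"expected {len(FIELDS)} fields, got {len(values)}")
    info = dict(zip(FIELDS, values))
    return {int(info["pid"]): info}

# Hypothetical entry; real lines are written by slips itself.
print(parse_server_line(
    "2024/01/01-10:00:00,dataset.pcap,6379,12345,zeek/,output/,111,False,True"
))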
diff --git a/modules/update_manager/update_manager.py b/modules/update_manager/update_manager.py
index ca1dcb810..5f9a59f0d 100644
--- a/modules/update_manager/update_manager.py
+++ b/modules/update_manager/update_manager.py
@@ -569,7 +569,7 @@ def parse_ssl_feed(self, url, full_path):
                 )
                 continue
         # Add all loaded malicious sha1 to the database
-        self.db.add_ssl_sha1_to_IoC(malicious_ssl_certs)
+        self.db.add_ssl_sha1_to_ioc(malicious_ssl_certs)
         return True

     async def update_TI_file(self, link_to_download: str) -> bool:
@@ -693,7 +693,7 @@ def update_riskiq_feed(self):
                         "source": url,
                     }
                 )
-            self.db.add_domains_to_IoC(malicious_domains_dict)
+            self.db.add_domains_to_ioc(malicious_domains_dict)
         except KeyError:
             self.print(
                 f'RiskIQ returned: {response["message"]}. Update Cancelled.',
@@ -852,7 +852,7 @@ def parse_ja3_feed(self, url, ja3_feed_path: str) -> bool:
                     continue

             # Add all loaded malicious ja3 to the database
-            self.db.add_ja3_to_IoC(malicious_ja3_dict)
+            self.db.add_ja3_to_ioc(malicious_ja3_dict)
             return True

         except Exception:
@@ -895,7 +895,7 @@ def parse_json_ti_feed(self, link_to_download, ti_file_path: str) -> bool:
                     }
                 )

-            self.db.add_ips_to_IoC(malicious_ips_dict)
+            self.db.add_ips_to_ioc(malicious_ips_dict)
             return True

         if "hole.cert.pl" in link_to_download:
@@ -932,7 +932,7 @@ def parse_json_ti_feed(self, link_to_download, ti_file_path: str) -> bool:
                         "tags": tags,
                     }
                 )
-            self.db.add_domains_to_IoC(malicious_domains_dict)
+            self.db.add_domains_to_ioc(malicious_domains_dict)
             return True

     def get_description_column_index(self, header):
@@ -1386,9 +1386,9 @@ def parse_ti_feed(self, feed_link: str, ti_file_path: str) -> bool:
                 ti_file_name: str = ti_file_path.split("/")[-1]
                 handlers[data_type](ioc, ti_file_name, feed_link, description)

-            self.db.add_ips_to_IoC(self.malicious_ips_dict)
-            self.db.add_domains_to_IoC(self.malicious_domains_dict)
-            self.db.add_ip_range_to_IoC(self.malicious_ip_ranges)
+            self.db.add_ips_to_ioc(self.malicious_ips_dict)
+            self.db.add_domains_to_ioc(self.malicious_domains_dict)
+            self.db.add_ip_range_to_ioc(self.malicious_ip_ranges)
             feed.close()
             return True
diff --git a/slips_files/core/database/database_manager.py b/slips_files/core/database/database_manager.py
index 0860c97cc..dbb4e6c0b 100644
--- a/slips_files/core/database/database_manager.py
+++ b/slips_files/core/database/database_manager.py
@@ -98,7 +98,8 @@ def ask_for_ip_info(self, *args, **kwargs):
     @classmethod
     def discard_obj(cls):
         """
-        when connecting on multiple ports, this dbmanager since it's a singelton
+        when connecting on multiple ports, this dbmanager, since it's a
+        singleton,
         returns the same instance of the already used db
         to fix this, we call this function every time we find a used db
         that slips should connect to
@@ -111,12 +112,15 @@ def update_times_contacted(self, *args, **kwargs):
     def update_ip_info(self, *args, **kwargs):
         return self.rdb.update_ip_info(*args, **kwargs)

-    def getSlipsInternalTime(self, *args, **kwargs):
-        return self.rdb.getSlipsInternalTime(*args, **kwargs)
+    def get_slips_internal_time(self, *args, **kwargs):
+        return self.rdb.get_slips_internal_time(*args, **kwargs)

     def mark_profile_as_malicious(self, *args, **kwargs):
         return self.rdb.mark_profile_as_malicious(*args, **kwargs)

+    def get_malicious_profiles(self, *args, **kwargs):
+        return self.rdb.get_malicious_profiles(*args, **kwargs)
+
     def get_asn_info(self, *args, **kwargs):
         return self.rdb.get_asn_info(*args, **kwargs)
@@ -189,8 +193,8 @@ def set_dns_resolution(self, *args, **kwargs):
     def set_domain_resolution(self, *args, **kwargs):
         return self.rdb.set_domain_resolution(*args, **kwargs)

-    def get_redis_server_PID(self, *args, **kwargs):
-        return self.rdb.get_redis_server_PID(*args, **kwargs)
+    def get_redis_server_pid(self, *args, **kwargs):
+        return self.rdb.get_redis_server_pid(*args, **kwargs)

     def set_slips_mode(self, *args, **kwargs):
         return self.rdb.set_slips_mode(*args, **kwargs)
@@ -282,8 +286,8 @@ def get_gateway_ip(self, *args, **kwargs):
     def get_gateway_mac(self, *args, **kwargs):
         return self.rdb.get_gateway_mac(*args, **kwargs)

-    def get_gateway_MAC_Vendor(self, *args, **kwargs):
-        return self.rdb.get_gateway_MAC_Vendor(*args, **kwargs)
+    def get_gateway_mac_vendor(self, *args, **kwargs):
+        return self.rdb.get_gateway_mac_vendor(*args, **kwargs)

     def set_default_gateway(self, *args, **kwargs):
         return self.rdb.set_default_gateway(*args, **kwargs)
@@ -303,8 +307,8 @@ def get_passive_dns(self, *args, **kwargs):
     def get_reconnections_for_tw(self, *args, **kwargs):
         return self.rdb.get_reconnections_for_tw(*args, **kwargs)

-    def setReconnections(self, *args, **kwargs):
-        return self.rdb.setReconnections(*args, **kwargs)
+    def set_reconnections(self, *args, **kwargs):
+        return self.rdb.set_reconnections(*args, **kwargs)

     def get_host_ip(self, *args, **kwargs):
         return self.rdb.get_host_ip(*args, **kwargs)
@@ -330,8 +334,8 @@ def set_org_info(self, *args, **kwargs):
     def get_org_info(self, *args, **kwargs):
         return self.rdb.get_org_info(*args, **kwargs)

-    def get_org_IPs(self, *args, **kwargs):
-        return self.rdb.get_org_IPs(*args, **kwargs)
+    def get_org_ips(self, *args, **kwargs):
+        return self.rdb.get_org_ips(*args, **kwargs)

     def set_whitelist(self, *args, **kwargs):
         return self.rdb.set_whitelist(*args, **kwargs)
@@ -348,6 +352,9 @@ def has_cached_whitelist(self, *args, **kwargs):
     def is_doh_server(self, *args, **kwargs):
         return self.rdb.is_doh_server(*args, **kwargs)

+    def get_analysis_info(self, *args, **kwargs):
+        return self.rdb.get_analysis_info(*args, **kwargs)
+
     def store_dhcp_server(self, *args, **kwargs):
         return self.rdb.store_dhcp_server(*args, **kwargs)
@@ -387,8 +394,8 @@ def set_evidence_causing_alert(self, *args, **kwargs):
     def get_evidence_causing_alert(self, *args, **kwargs):
         return self.rdb.get_evidence_causing_alert(*args, **kwargs)

-    def get_evidence_by_ID(self, *args, **kwargs):
-        return self.rdb.get_evidence_by_ID(*args, **kwargs)
+    def get_evidence_by_id(self, *args, **kwargs):
+        return self.rdb.get_evidence_by_id(*args, **kwargs)

     def is_detection_disabled(self, *args, **kwargs):
         return self.rdb.is_detection_disabled(*args, **kwargs)
@@ -460,12 +467,6 @@ def set_loaded_ti_files(self, *args, **kwargs):
     def get_loaded_ti_feeds(self, *args, **kwargs):
         return self.rdb.get_loaded_ti_feeds(*args, **kwargs)

-    def mark_as_analyzed_by_ti_module(self, *args, **kwargs):
-        return self.rdb.mark_as_analyzed_by_ti_module(*args, **kwargs)
-
-    def get_ti_queue_size(self, *args, **kwargs):
-        return self.rdb.get_ti_queue_size(*args, **kwargs)
-
     def set_cyst_enabled(self, *args, **kwargs):
         return self.rdb.set_cyst_enabled(*args, **kwargs)
@@ -475,35 +476,35 @@ def is_cyst_enabled(self, *args, **kwargs):
     def give_threat_intelligence(self, *args, **kwargs):
         return self.rdb.give_threat_intelligence(*args, **kwargs)

-    def delete_ips_from_IoC_ips(self, *args, **kwargs):
-        return self.rdb.delete_ips_from_IoC_ips(*args, **kwargs)
+    def delete_ips_from_ioc_ips(self, *args, **kwargs):
+        return self.rdb.delete_ips_from_ioc_ips(*args, **kwargs)

-    def delete_domains_from_IoC_domains(self, *args, **kwargs):
-        return self.rdb.delete_domains_from_IoC_domains(*args, **kwargs)
+    def delete_domains_from_ioc_domains(self, *args, **kwargs):
+        return self.rdb.delete_domains_from_ioc_domains(*args, **kwargs)

-    def add_ips_to_IoC(self, *args, **kwargs):
-        return self.rdb.add_ips_to_IoC(*args, **kwargs)
+    def add_ips_to_ioc(self, *args, **kwargs):
+        return self.rdb.add_ips_to_ioc(*args, **kwargs)

-    def add_domains_to_IoC(self, *args, **kwargs):
-        return self.rdb.add_domains_to_IoC(*args, **kwargs)
+    def add_domains_to_ioc(self, *args, **kwargs):
+        return self.rdb.add_domains_to_ioc(*args, **kwargs)

-    def add_ip_range_to_IoC(self, *args, **kwargs):
-        return self.rdb.add_ip_range_to_IoC(*args, **kwargs)
+    def add_ip_range_to_ioc(self, *args, **kwargs):
+        return self.rdb.add_ip_range_to_ioc(*args, **kwargs)

-    def add_asn_to_IoC(self, *args, **kwargs):
-        return self.rdb.add_asn_to_IoC(*args, **kwargs)
+    def add_asn_to_ioc(self, *args, **kwargs):
+        return self.rdb.add_asn_to_ioc(*args, **kwargs)

-    def is_blacklisted_ASN(self, *args, **kwargs):
-        return self.rdb.is_blacklisted_ASN(*args, **kwargs)
+    def is_blacklisted_asn(self, *args, **kwargs):
+        return self.rdb.is_blacklisted_asn(*args, **kwargs)

-    def add_ja3_to_IoC(self, *args, **kwargs):
-        return self.rdb.add_ja3_to_IoC(*args, **kwargs)
+    def add_ja3_to_ioc(self, *args, **kwargs):
+        return self.rdb.add_ja3_to_ioc(*args, **kwargs)

-    def add_jarm_to_IoC(self, *args, **kwargs):
-        return self.rdb.add_jarm_to_IoC(*args, **kwargs)
+    def add_jarm_to_ioc(self, *args, **kwargs):
+        return self.rdb.add_jarm_to_ioc(*args, **kwargs)

-    def add_ssl_sha1_to_IoC(self, *args, **kwargs):
-        return self.rdb.add_ssl_sha1_to_IoC(*args, **kwargs)
+    def add_ssl_sha1_to_ioc(self, *args, **kwargs):
+        return self.rdb.add_ssl_sha1_to_ioc(*args, **kwargs)

     def get_all_blacklisted_ip_ranges(self, *args, **kwargs):
         return self.rdb.get_all_blacklisted_ip_ranges(*args, **kwargs)
@@ -627,17 +628,20 @@ def get_all_contacted_ips_in_profileid_twid(self, *args, **kwargs):
     def mark_profile_and_timewindow_as_blocked(self, *args, **kwargs):
         return self.rdb.mark_profile_and_timewindow_as_blocked(*args, **kwargs)

-    def getBlockedProfTW(self, *args, **kwargs):
-        return self.rdb.getBlockedProfTW(*args, **kwargs)
+    def get_blocked_timewindows_of_profile(self, *args, **kwargs):
+        return self.rdb.get_blocked_timewindows_of_profile(*args, **kwargs)
+
+    def get_blocked_profiles_and_timewindows(self, *args, **kwargs):
+        return self.rdb.get_blocked_profiles_and_timewindows(*args, **kwargs)

     def get_used_redis_port(self):
         return self.rdb.get_used_port()

-    def checkBlockedProfTW(self, *args, **kwargs):
-        return self.rdb.checkBlockedProfTW(*args, **kwargs)
+    def is_blocked_profile_and_tw(self, *args, **kwargs):
+        return self.rdb.is_blocked_profile_and_tw(*args, **kwargs)

-    def wasProfileTWModified(self, *args, **kwargs):
-        return self.rdb.wasProfileTWModified(*args, **kwargs)
+    def was_profile_and_tw_modified(self, *args, **kwargs):
+        return self.rdb.was_profile_and_tw_modified(*args, **kwargs)

     def add_software_to_profile(self, *args, **kwargs):
         return self.rdb.add_software_to_profile(*args, **kwargs)
@@ -666,10 +670,13 @@ def get_profileid_from_ip(self, *args, **kwargs):
     def get_first_flow_time(self, *args, **kwargs):
         return self.rdb.get_first_flow_time(*args, **kwargs)

-    def getProfiles(self, *args, **kwargs):
-        return self.rdb.getProfiles(*args, **kwargs)
+    def get_profiles(self, *args, **kwargs):
+        return self.rdb.get_profiles(*args, **kwargs)
+
+    def get_number_of_alerts_so_far(self, *args, **kwargs):
+        return self.rdb.get_number_of_alerts_so_far(*args, **kwargs)

-    def getTWsfromProfile(self, *args, **kwargs):
+    def get_tws_from_profile(self, *args, **kwargs):
         return self.rdb.get_tws_from_profile(*args, **kwargs)

     def get_number_of_tws_in_profile(self, *args, **kwargs):
@@ -792,6 +799,9 @@ def add_timeline_line(self, *args, **kwargs):
     def get_timeline_last_lines(self, *args, **kwargs):
         return self.rdb.get_timeline_last_lines(*args, **kwargs)

+    def get_profiled_tw_timeline(self, *args, **kwargs):
+        return self.rdb.get_profiled_tw_timeline(*args, **kwargs)
+
     def mark_profile_as_gateway(self, *args, **kwargs):
         return self.rdb.mark_profile_as_gateway(*args, **kwargs)
@@ -840,9 +850,6 @@ def get_malicious_label(self):
     def init_tables(self, *args, **kwargs):
         return self.sqlite.init_tables(*args, **kwargs)

-    def _init_db(self, *args, **kwargs):
-        return self.sqlite._init_db(*args, **kwargs)
-
     def create_table(self, *args, **kwargs):
         return self.sqlite.create_table(*args, **kwargs)
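database_manager.py is a long list of one-line wrappers that forward every call to the underlying redis or sqlite handler. A compact sketch of that facade pattern, reduced to its essentials (the class and method names below are illustrative, not the actual slips classes):

class RedisBackend:
    def get_profiles(self):
        return {"profile_10.0.0.1"}

class DBFacade:
    """Single entry point that hides which backend serves each call,
    mirroring how DatabaseManager forwards to self.rdb / self.sqlite."""

    def __init__(self):
        self.rdb = RedisBackend()

    def get_profiles(self, *args, **kwargs):
        # one thin wrapper per public method keeps the facade greppable
        return self.rdb.get_profiles(*args, **kwargs)

db = DBFacade()
print(db.get_profiles())

Explicit wrappers rather than a dynamic __getattr__ keep the public API discoverable, and make a rename like getProfiles to get_profiles a mechanical two-line change per method, which is exactly the shape of most hunks in this file.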
diff --git a/slips_files/core/database/redis_db/alert_handler.py b/slips_files/core/database/redis_db/alert_handler.py
index 87475745d..e23654aa4 100644
--- a/slips_files/core/database/redis_db/alert_handler.py
+++ b/slips_files/core/database/redis_db/alert_handler.py
@@ -44,7 +44,11 @@ def increment_attack_counter(

     def mark_profile_as_malicious(self, profileid: ProfileID):
         """keeps track of profiles that generated an alert"""
-        self.r.sadd("malicious_profiles", str(profileid))
+        self.r.sadd(self.constants.MALICIOUS_PROFILES, str(profileid))
+
+    def get_malicious_profiles(self):
+        """returns profiles that generated an alert"""
+        return self.r.smembers(self.constants.MALICIOUS_PROFILES)

     def set_evidence_causing_alert(self, alert: Alert):
         """
@@ -74,7 +78,10 @@ def set_evidence_causing_alert(self, alert: Alert):
             "alerts",
             profileid_twid_alerts,
         )
-        self.r.incr("number_of_alerts", 1)
+        self.r.incr(self.constants.NUMBER_OF_ALERTS, 1)
+
+    def get_number_of_alerts_so_far(self):
+        return self.r.get(self.constants.NUMBER_OF_ALERTS)

     def get_evidence_causing_alert(
         self, profileid, twid, alert_id: str
@@ -89,7 +96,7 @@ def get_evidence_causing_alert(
             return alerts.get(alert_id, False)
         return False

-    def get_evidence_by_ID(self, profileid: str, twid: str, evidence_id: str):
+    def get_evidence_by_id(self, profileid: str, twid: str, evidence_id: str):
         evidence: Dict[str, dict] = self.get_twid_evidence(profileid, twid)
         if not evidence:
             return False
@@ -107,11 +114,15 @@ def is_detection_disabled(self, evidence_type: EvidenceType):
         """
         return str(evidence_type) in self.disabled_detections

-    def set_flow_causing_evidence(self, uids: list, evidence_ID):
-        self.r.hset("flows_causing_evidence", evidence_ID, json.dumps(uids))
+    def set_flow_causing_evidence(self, uids: list, evidence_id):
+        self.r.hset(
+            self.constants.FLOWS_CAUSING_EVIDENCE,
+            evidence_id,
+            json.dumps(uids),
+        )

-    def get_flows_causing_evidence(self, evidence_ID) -> list:
-        uids = self.r.hget("flows_causing_evidence", evidence_ID)
+    def get_flows_causing_evidence(self, evidence_id) -> list:
+        uids = self.r.hget(self.constants.FLOWS_CAUSING_EVIDENCE, evidence_id)
         return json.loads(uids) if uids else []

     def get_victim(self, profileid, attacker):
@@ -190,7 +201,7 @@ def set_evidence(self, evidence: Evidence):
         # to the db
         if not evidence_exists:
             self.r.hset(evidence_hash, evidence.id, evidence_to_send)
-            self.r.incr("number_of_evidence", 1)
+            self.r.incr(self.constants.NUMBER_OF_EVIDENCE, 1)
             self.publish("evidence_added", evidence_to_send)

         # an evidence is generated for this profile
@@ -221,24 +232,24 @@ def set_alert(self, alert: Alert):
             "profileid": str(alert.profile),
             "twid": str(alert.timewindow),
         }
-        self.publish("new_alert", json.dumps(alert_details))
+        self.publish(self.channels.NEW_ALERT, json.dumps(alert_details))

     def init_evidence_number(self):
         """used when the db starts to initialize number of evidence
         generated by slips"""
-        self.r.set("number_of_evidence", 0)
+        self.r.set(self.constants.NUMBER_OF_EVIDENCE, 0)

     def get_evidence_number(self):
-        return self.r.get("number_of_evidence")
+        return self.r.get(self.constants.NUMBER_OF_EVIDENCE)

     def mark_evidence_as_processed(self, evidence_id: str):
         """
         If an evidence was processed by the evidenceprocess, mark it in
         the db
         """
-        self.r.sadd("processed_evidence", evidence_id)
+        self.r.sadd(self.constants.PROCESSED_EVIDENCE, evidence_id)

     def is_evidence_processed(self, evidence_id: str) -> bool:
-        return self.r.sismember("processed_evidence", evidence_id)
+        return self.r.sismember(self.constants.PROCESSED_EVIDENCE, evidence_id)

     def delete_evidence(self, profileid, twid, evidence_id: str):
         """
@@ -248,7 +259,7 @@ def delete_evidence(self, profileid, twid, evidence_id: str):
         # which means that any evidence passed to this function
         # can never be a part of a past alert
         self.r.hdel(f"{profileid}_{twid}_evidence", evidence_id)
-        self.r.incr("number_of_evidence", -1)
+        self.r.incr(self.constants.NUMBER_OF_EVIDENCE, -1)

     def cache_whitelisted_evidence_id(self, evidence_id: str):
         """
@@ -256,15 +267,18 @@ def cache_whitelisted_evidence_id(self, evidence_id: str):
         alerts later
         """
         # without this function, slips gets the stored evidence id from the db,
-        # before deleteEvidence is called, so we need to keep track of whitelisted evidence ids
-        self.r.sadd("whitelisted_evidence", evidence_id)
+        # before delete_evidence is called, so we need to keep track of
+        # whitelisted evidence ids
+        self.r.sadd(self.constants.WHITELISTED_EVIDENCE, evidence_id)

     def is_whitelisted_evidence(self, evidence_id):
         """
         Check if we have the evidence ID as whitelisted in the db to
         avoid showing it in alerts
         """
-        return self.r.sismember("whitelisted_evidence", evidence_id)
+        return self.r.sismember(
+            self.constants.WHITELISTED_EVIDENCE, evidence_id
+        )

     def remove_whitelisted_evidence(self, all_evidence: dict) -> dict:
         """
@@ -314,7 +328,7 @@ def get_accumulated_threat_level(self, profileid: str, twid: str) -> float:
         returns the accumulated_threat_lvl or 0 if it's not there
         """
         accumulated_threat_lvl = self.r.zscore(
-            "accumulated_threat_levels", f"{profileid}_{twid}"
+            self.constants.ACCUMULATED_THREAT_LEVELS, f"{profileid}_{twid}"
         )
         return accumulated_threat_lvl or 0
@@ -330,7 +344,7 @@ def update_accumulated_threat_level(
         """
         return self.r.zincrby(
-            "accumulated_threat_levels",
+            self.constants.ACCUMULATED_THREAT_LEVELS,
             update_val,
             f"{profileid}_{twid}",
         )
@@ -342,7 +356,7 @@ def _set_accumulated_threat_level(
     ):
         profile_twid = f"{alert.profile}_{alert.timewindow}"
         self.r.zadd(
-            "accumulated_threat_levels",
+            self.constants.ACCUMULATED_THREAT_LEVELS,
             {profile_twid: accumulated_threat_lvl},
         )
diff --git a/slips_files/core/database/redis_db/constants.py b/slips_files/core/database/redis_db/constants.py
index 8562a8717..45d2aad62 100644
--- a/slips_files/core/database/redis_db/constants.py
+++ b/slips_files/core/database/redis_db/constants.py
@@ -18,8 +18,53 @@ class Constants:
     DOMAINS_INFO = "DomainsInfo"
     IPS_INFO = "IPsInfo"
     PROCESSED_FLOWS = "processed_flows_so_far"
+    MALICIOUS_PROFILES = "malicious_profiles"
+    FLOWS_CAUSING_EVIDENCE = "flows_causing_evidence"
+    PROCESSED_EVIDENCE = "processed_evidence"
+    NUMBER_OF_EVIDENCE = "number_of_evidence"
+    WHITELISTED_EVIDENCE = "whitelisted_evidence"
+    SRCIPS_SEEN_IN_CONN_LOG = "srcips_seen_in_connlog"
+    PASSIVE_DNS = "passiveDNS"
+    DNS_RESOLUTION = "DNSresolution"
+    RESOLVED_DOMAINS = "ResolvedDomains"
+    DOMAINS_RESOLVED = "DomainsResolved"
+    CACHED_ASN = "cached_asn"
+    PIDS = "PIDs"
+    MAC = "MAC"
+    MODIFIED_TIMEWINDOWS = "ModifiedTW"
+    ORG_INFO = "OrgInfo"
+    ACCUMULATED_THREAT_LEVELS = "accumulated_threat_levels"
+    TRANCO_WHITELISTED_DOMAINS = "tranco_whitelisted_domains"
+    WHITELIST = "whitelist"
+    GROWING_ZEEK_DIR = "growing_zeek_dir"
+    DHCP_SERVERS = "DHCP_servers"
+    LABELS = "labels"
+    MSGS_PUBLISHED_AT_RUNTIME = "msgs_published_at_runtime"
+    ZEEK_FILES = "zeekfiles"
+    DEFAULT_GATEWAY = "default_gateway"
+    IS_CYST_ENABLED = "is_cyst_enabled"
+    LOCAL_NETWORK = "local_network"
+    ZEEK_PATH = "zeek_path"
+    P2P_REPORTS = "p2p_reports"
+    ORGANIZATIONS_PORTS = "organization_port"
+    SLIPS_START_TIME = "slips_start_time"
+    USED_FTP_PORTS = "used_ftp_ports"
+    SLIPS_INTERNAL_TIME = "slips_internal_time"
+    WARDEN_INFO = "Warden"
+    MODE = "mode"
+    ANALYSIS = "analysis"
+    LOGGED_CONNECTION_ERR = "logged_connection_error"
+    P2P_RECEIVED_BLAME_REPORTS = "p2p-received-blame-reports"
+    MULTICAST_ADDRESS = "multiAddress"
+    PORT_INFO = "portinfo"
+    DHCP_FLOWS = "DHCP_flows"
+    REDIS_USED_PORT = "port"
+    BLOCKED_PROFILES_AND_TWS = "BlockedProfTW"
+    PROFILES = "profiles"
+    NUMBER_OF_ALERTS = "number_of_alerts"
     KNOWN_FPS = "known_fps"


 class Channels:
     DNS_INFO_CHANGE = "dns_info_change"
+    NEW_ALERT = "new_alert"
We need to delete some temp data - cls.r.delete("zeekfiles") + cls.r.delete(cls.constants.ZEEK_FILES) return True, "" except RuntimeError as err: return False, str(err) @@ -333,7 +333,7 @@ def connect_to_redis_server(cls) -> Tuple[bool, str]: @classmethod def close_redis_server(cls, redis_port): - if server_pid := cls.get_redis_server_PID(redis_port): + if server_pid := cls.get_redis_server_pid(redis_port): os.kill(int(server_pid), signal.SIGKILL) @classmethod @@ -361,17 +361,17 @@ def change_redis_limits(cls, client: redis.StrictRedis): def _set_slips_start_time(cls): """store the time slips started (datetime obj)""" now = time.time() - cls.r.set("slips_start_time", now) + cls.r.set(cls.constants.SLIPS_START_TIME, now) def publish(self, channel, msg): """Publish a msg in the given channel""" # keeps track of how many msgs were published in the given channel - self.r.hincrby("msgs_published_at_runtime", channel, 1) + self.r.hincrby(self.constants.MSGS_PUBLISHED_AT_RUNTIME, channel, 1) self.r.publish(channel, msg) def get_msgs_published_in_channel(self, channel: str) -> int: """returns the number of msgs published in a channel""" - return self.r.hget("msgs_published_at_runtime", channel) + return self.r.hget(self.constants.MSGS_PUBLISHED_AT_RUNTIME, channel) def subscribe(self, channel: str, ignore_subscribe_messages=True): """Subscribe to channel""" @@ -395,8 +395,10 @@ def publish_stop(self): def get_message(self, channel, timeout=0.0000001): """ - Wrapper for redis' get_message() to be able to handle redis.exceptions.ConnectionError - notice: there has to be a timeout or the channel will wait forever and never receive a new msg + Wrapper for redis' get_message() to be able to handle + redis.exceptions.ConnectionError + notice: there has to be a timeout or the channel will wait forever + and never receive a new msg """ try: return channel.get_message(timeout=timeout) @@ -506,18 +508,18 @@ def ask_for_ip_info( data_to_send.update({"cache_age": cache_age, "ip": str(ip)}) self.publish("p2p_data_request", json.dumps(data_to_send)) - def getSlipsInternalTime(self): - return self.r.get("slips_internal_time") or 0 + def get_slips_internal_time(self): + return self.r.get(self.constants.SLIPS_INTERNAL_TIME) or 0 def get_redis_keys_len(self) -> int: """returns the length of all keys in the db""" return self.r.dbsize() def set_cyst_enabled(self): - return self.r.set("is_cyst_enabled", "yes") + return self.r.set(self.constants.IS_CYST_ENABLED, "yes") def is_cyst_enabled(self): - return self.r.get("is_cyst_enabled") + return self.r.get(self.constants.IS_CYST_ENABLED) def get_equivalent_tws(self, hrs: float) -> int: """ @@ -530,28 +532,30 @@ def set_local_network(self, cidr): """ set the local network used in the db """ - self.r.set("local_network", cidr) + self.r.set(self.constants.LOCAL_NETWORK, cidr) def get_local_network(self): - return self.r.get("local_network") + return self.r.get(self.constants.LOCAL_NETWORK) - def get_used_port(self): - return int(self.r.config_get("port")["port"]) + def get_used_port(self) -> int: + return int(self.r.config_get(self.constants.REDIS_USED_PORT)["port"]) def get_label_count(self, label): """ :param label: malicious or normal """ - return self.r.zscore("labels", label) + return self.r.zscore(self.constants.LABELS, label) def get_enabled_modules(self) -> List[str]: """ Returns a list of the loaded/enabled modules """ - return self.r.hkeys("PIDs") + return self.r.hkeys(self.constants.PIDS) def get_disabled_modules(self) -> List[str]: - if disabled_modules := 
self.r.hget("analysis", "disabled_modules"): + if disabled_modules := self.r.hget( + self.constants.ANALYSIS, "disabled_modules" + ): return json.loads(disabled_modules) else: return {} @@ -561,37 +565,39 @@ def set_input_metadata(self, info: dict): sets name, size, analysis dates, and zeek_dir in the db """ for info, val in info.items(): - self.r.hset("analysis", info, val) + self.r.hset(self.constants.ANALYSIS, info, val) def get_zeek_output_dir(self): """ gets zeek output dir from the db """ - return self.r.hget("analysis", "zeek_dir") + return self.r.hget(self.constants.ANALYSIS, "zeek_dir") def get_input_file(self): """ gets zeek output dir from the db """ - return self.r.hget("analysis", "name") + return self.r.hget(self.constants.ANALYSIS, "name") def get_commit(self): """ gets the currently used commit from the db """ - return self.r.hget("analysis", "commit") + return self.r.hget(self.constants.ANALYSIS, "commit") def get_branch(self): """ gets the currently used branch from the db """ - return self.r.hget("analysis", "branch") + return self.r.hget(self.constants.ANALYSIS, "branch") def get_evidence_detection_threshold(self): """ gets the currently used evidence_detection_threshold from the db """ - return self.r.hget("analysis", "evidence_detection_threshold") + return self.r.hget( + self.constants.ANALYSIS, "evidence_detection_threshold" + ) def get_input_type(self) -> str: """ @@ -600,13 +606,13 @@ def get_input_type(self) -> str: "zeek_log_file", "zeek_folder", "stdin", "nfdump", "binetflow", "suricata" """ - return self.r.hget("analysis", "input_type") + return self.r.hget(self.constants.ANALYSIS, "input_type") def get_output_dir(self): """ returns the currently used output dir """ - return self.r.hget("analysis", "output_dir") + return self.r.hget(self.constants.ANALYSIS, "output_dir") def set_ip_info(self, ip: str, to_store: dict): """ @@ -646,8 +652,9 @@ def get_p2p_reports_about_ip(self, ip) -> dict: """ returns a dict of all p2p past reports about the given ip """ - # p2p_reports key is basically { ip: { reporter1: [report1, report2, report3]} } - if reports := self.rcache.hget("p2p_reports", ip): + # p2p_reports key is basically + # { ip: { reporter1: [report1, report2, report3]} } + if reports := self.rcache.hget(self.constants.P2P_REPORTS, ip): return json.loads(reports) return {} @@ -694,7 +701,9 @@ def store_p2p_report(self, ip: str, report_data: dict): # no old reports about this ip report_data = {reporter: [report_data]} - self.rcache.hset("p2p_reports", ip, json.dumps(report_data)) + self.rcache.hset( + self.constants.P2P_REPORTS, ip, json.dumps(report_data) + ) def get_dns_resolution(self, ip): """ @@ -706,7 +715,7 @@ def get_dns_resolution(self, ip): If not resolved, returns {} this function is called for every IP in the timeline of kalipso """ - if ip_info := self.r.hget("DNSresolution", ip): + if ip_info := self.r.hget(self.constants.DNS_RESOLUTION, ip): ip_info = json.loads(ip_info) # return a dict with 'ts' 'uid' 'domains' about this IP return ip_info @@ -732,12 +741,13 @@ def is_ip_resolved(self, ip, hrs): return False def delete_dns_resolution(self, ip): - self.r.hdel("DNSresolution", ip) + self.r.hdel(self.constants.DNS_RESOLUTION, ip) def should_store_resolution( self, query: str, answers: list, qtype_name: str ): - # don't store queries ending with arpa as dns resolutions, they're reverse dns + # don't store queries ending with arpa as dns resolutions, + # they're reverse dns # only store type A and AAAA for ipv4 and ipv6 if ( qtype_name not in 
["AAAA", "A"] @@ -746,7 +756,8 @@ def should_store_resolution( ): return False - # sometimes adservers are resolved to 0.0.0.0 or "127.0.0.1" to block the domain. + # sometimes adservers are resolved to 0.0.0.0 or "127.0.0.1" to + # block the domain. # don't store this as a valid dns resolution if query != "localhost": for answer in answers: @@ -831,9 +842,9 @@ def set_dns_resolution( ip_info = json.dumps(ip_info) # we store ALL dns resolutions seen since starting slips # store with the IP as the key - self.r.hset("DNSresolution", answer, ip_info) + self.r.hset(self.constants.DNS_RESOLUTION, answer, ip_info) # store with the domain as the key: - self.r.hset("ResolvedDomains", domains[0], answer) + self.r.hset(self.constants.RESOLVED_DOMAINS, domains[0], answer) # these ips will be associated with the query in our db ips_to_add.append(answer) @@ -858,10 +869,10 @@ def set_domain_resolution(self, domain, ips): """ stores all the resolved domains with their ips in the db """ - self.r.hset("DomainsResolved", domain, json.dumps(ips)) + self.r.hset(self.constants.DOMAINS_RESOLVED, domain, json.dumps(ips)) @staticmethod - def get_redis_server_PID(redis_port): + def get_redis_server_pid(redis_port): """ get the PID of the redis server started on the given redis_port retrns the pid @@ -879,14 +890,14 @@ def set_slips_mode(self, slips_mode): function to store the current mode (daemonized/interactive) in the db """ - self.r.set("mode", slips_mode) + self.r.set(self.constants.MODE, slips_mode) def get_slips_mode(self): """ function to get the current mode (daemonized/interactive) in the db """ - self.r.get("mode") + self.r.get(self.constants.MODE) def get_modified_ips_in_the_last_tw(self): """ @@ -894,14 +905,14 @@ def get_modified_ips_in_the_last_tw(self): used for printing running stats in slips.py or outputprocess """ if modified_ips := self.r.hget( - "analysis", "modified_ips_in_the_last_tw" + self.constants.ANALYSIS, "modified_ips_in_the_last_tw" ): return modified_ips else: return 0 def is_connection_error_logged(self): - return bool(self.r.get("logged_connection_error")) + return bool(self.r.get(self.constants.LOGGED_CONNECTION_ERR)) def mark_connection_error_as_logged(self): """ @@ -909,18 +920,19 @@ def mark_connection_error_as_logged(self): every module from logging it to slips.log and the console, set this variable in the db """ - self.r.set("logged_connection_error", "True") + self.r.set(self.constants.LOGGED_CONNECTION_ERR, "True") def was_ip_seen_in_connlog_before(self, ip) -> bool: """ returns true if this is not the first flow slip sees of the given ip """ # we store every source address seen in a conn.log flow in this key - # if the source address is not stored in this key, it means we may have seen it - # but not in conn.log yet + # if the source address is not stored in this key, it means we may + # have seen it but not in conn.log yet - # if the ip's not in the following key, then its the first flow seen of this ip - return self.r.sismember("srcips_seen_in_connlog", ip) + # if the ip's not in the following key, then its the first flow + # seen of this ip + return self.r.sismember(self.constants.SRCIPS_SEEN_IN_CONN_LOG, ip) def mark_srcip_as_seen_in_connlog(self, ip): """ @@ -929,7 +941,7 @@ def mark_srcip_as_seen_in_connlog(self, ip): if an ip is not present in this set, it means we may have seen it but not in conn.log """ - self.r.sadd("srcips_seen_in_connlog", ip) + self.r.sadd(self.constants.SRCIPS_SEEN_IN_CONN_LOG, ip) def is_gw_mac(self, mac_addr: str, ip: str) -> bool: """ @@ 
-952,7 +964,7 @@ def is_gw_mac(self, mac_addr: str, ip: str) -> bool: # now we're given a public ip and a MAC that's supposedly belongs to it # we are sure this is the gw mac # set it if we don't already have it in the db - self.set_default_gateway("MAC", mac_addr) + self.set_default_gateway(self.constants.MAC, mac_addr) # mark the gw mac as found so we don't look for it again self._gateway_MAC_found = True @@ -962,11 +974,13 @@ def get_ip_of_mac(self, MAC): """ Returns the IP associated with the given MAC in our database """ - return self.r.hget("MAC", MAC) + return self.r.hget(self.constants.MAC, MAC) def get_modified_tw(self): """Return all the list of modified tw""" - data = self.r.zrange("ModifiedTW", 0, -1, withscores=True) + data = self.r.zrange( + self.constants.MODIFIED_TIMEWINDOWS, 0, -1, withscores=True + ) return data or [] def get_field_separator(self): @@ -978,22 +992,25 @@ def store_tranco_whitelisted_domain(self, domain): store whitelisted domain from tranco whitelist in the db """ # the reason we store tranco whitelisted domains in the cache db - # instead of the main db is, we don't want them cleared on every new instance of slips - self.rcache.sadd("tranco_whitelisted_domains", domain) + # instead of the main db is, we don't want them cleared on every new + # instance of slips + self.rcache.sadd(self.constants.TRANCO_WHITELISTED_DOMAINS, domain) def is_whitelisted_tranco_domain(self, domain): - return self.rcache.sismember("tranco_whitelisted_domains", domain) + return self.rcache.sismember( + self.constants.TRANCO_WHITELISTED_DOMAINS, domain + ) def set_growing_zeek_dir(self): """ Mark a dir as growing so it can be treated like the zeek logs generated by an interface """ - self.r.set("growing_zeek_dir", "yes") + self.r.set(self.constants.GROWING_ZEEK_DIR, "yes") def is_growing_zeek_dir(self): """Did slips mark the given dir as growing?""" - return "yes" in str(self.r.get("growing_zeek_dir")) + return "yes" in str(self.r.get(self.constants.GROWING_ZEEK_DIR)) def get_asn_info(self, ip: str) -> Optional[Dict[str, str]]: """ @@ -1057,38 +1074,38 @@ def get_multiaddr(self): """ this is can only be called when p2p is enabled, this value is set by p2p pigeon """ - return self.r.get("multiAddress") + return self.r.get(self.constants.MULTICAST_ADDRESS) def get_labels(self): """ Return the amount of each label so far in the DB Used to know how many labels are available during training """ - return self.r.zrange("labels", 0, -1, withscores=True) + return self.r.zrange(self.constants.LABELS, 0, -1, withscores=True) def set_port_info(self, portproto: str, name): """ Save in the DB a port with its description :param portproto: portnumber + / + protocol """ - self.rcache.hset("portinfo", portproto, name) + self.rcache.hset(self.constants.PORT_INFO, portproto, name) def get_port_info(self, portproto: str): """ Retrieve the name of a port :param portproto: portnumber + / + protocol """ - return self.rcache.hget("portinfo", portproto) + return self.rcache.hget(self.constants.PORT_INFO, portproto) def set_ftp_port(self, port): """ Stores the used ftp port in our main db (not the cache like set_port_info) """ - self.r.lpush("used_ftp_ports", str(port)) + self.r.lpush(self.constants.USED_FTP_PORTS, str(port)) def is_ftp_port(self, port): # get all used ftp ports - used_ftp_ports = self.r.lrange("used_ftp_ports", 0, -1) + used_ftp_ports = self.r.lrange(self.constants.USED_FTP_PORTS, 0, -1) # check if the given port is used as ftp port return str(port) in used_ftp_ports @@ -1109,7 +1126,9 @@ 
def set_organization_of_port(self, organization, ip: str, portproto: str): org_info = {"org_name": [organization], "ip": [ip]} org_info = json.dumps(org_info) - self.rcache.hset("organization_port", portproto, org_info) + self.rcache.hset( + self.constants.ORGANIZATIONS_PORTS, portproto, org_info + ) def get_organization_of_port(self, portproto: str): """ @@ -1118,24 +1137,26 @@ def get_organization_of_port(self, portproto: str): """ # this key is used to store the ports the are known to be used # by certain organizations - return self.rcache.hget("organization_port", portproto.lower()) + return self.rcache.hget( + self.constants.ORGANIZATIONS_PORTS, portproto.lower() + ) def add_zeek_file(self, filename): """Add an entry to the list of zeek files""" - self.r.sadd("zeekfiles", filename) + self.r.sadd(self.constants.ZEEK_FILES, filename) def get_all_zeek_files(self) -> set: """Return all entries from the list of zeek files""" - return self.r.smembers("zeekfiles") + return self.r.smembers(self.constants.ZEEK_FILES) def get_gateway_ip(self): - return self.r.hget("default_gateway", "IP") + return self.r.hget(self.constants.DEFAULT_GATEWAY, "IP") def get_gateway_mac(self): - return self.r.hget("default_gateway", "MAC") + return self.r.hget(self.constants.DEFAULT_GATEWAY, self.constants.MAC) - def get_gateway_MAC_Vendor(self): - return self.r.hget("default_gateway", "Vendor") + def get_gateway_mac_vendor(self): + return self.r.hget(self.constants.DEFAULT_GATEWAY, "Vendor") def set_default_gateway(self, address_type: str, address: str): """ @@ -1145,20 +1166,23 @@ def set_default_gateway(self, address_type: str, address: str): # make sure the IP or mac aren't already set before re-setting if ( (address_type == "IP" and not self.get_gateway_ip()) - or (address_type == "MAC" and not self.get_gateway_mac()) - or (address_type == "Vendor" and not self.get_gateway_MAC_Vendor()) + or ( + address_type == self.constants.MAC + and not self.get_gateway_mac() + ) + or (address_type == "Vendor" and not self.get_gateway_mac_vendor()) ): - self.r.hset("default_gateway", address_type, address) + self.r.hset(self.constants.DEFAULT_GATEWAY, address_type, address) def get_domain_resolution(self, domain) -> List[str]: """ Returns the IPs resolved by this domain """ - ips = self.r.hget("DomainsResolved", domain) + ips = self.r.hget(self.constants.DOMAINS_RESOLVED, domain) return json.loads(ips) if ips else [] def get_all_dns_resolutions(self): - dns_resolutions = self.r.hgetall("DNSresolution") + dns_resolutions = self.r.hgetall(self.constants.DNS_RESOLUTION) return dns_resolutions or [] def is_running_non_stop(self) -> bool: @@ -1176,13 +1200,13 @@ def set_passive_dns(self, ip, data): """ if data: data = json.dumps(data) - self.rcache.hset("passiveDNS", ip, data) + self.rcache.hset(self.constants.PASSIVE_DNS, ip, data) def get_passive_dns(self, ip): """ Gets passive DNS from the db """ - if data := self.rcache.hget("passiveDNS", ip): + if data := self.rcache.hget(self.constants.PASSIVE_DNS, ip): return json.loads(data) else: return False @@ -1193,7 +1217,7 @@ def get_reconnections_for_tw(self, profileid, twid): data = json.loads(data) if data else {} return data - def setReconnections(self, profileid, twid, data): + def set_reconnections(self, profileid, twid, data): """Set the reconnections for this TW for this Profile""" data = json.dumps(data) self.r.hset(f"{profileid}_{twid}", "Reconnections", str(data)) @@ -1244,10 +1268,14 @@ def set_asn_cache(self, org: str, asn_range: str, asn_number: str) -> None: # starts 
with the same first octet cached_asn: dict = json.loads(cached_asn) cached_asn.update(range_info) - self.rcache.hset("cached_asn", first_octet, json.dumps(cached_asn)) + self.rcache.hset( + self.constants.CACHED_ASN, first_octet, json.dumps(cached_asn) + ) else: # first time storing a range starting with the same first octet - self.rcache.hset("cached_asn", first_octet, json.dumps(range_info)) + self.rcache.hset( + self.constants.CACHED_ASN, first_octet, json.dumps(range_info) + ) def get_asn_cache(self, first_octet=False): """ @@ -1255,9 +1283,9 @@ def get_asn_cache(self, first_octet=False): Returns cached asn of ip if present, or False. """ if first_octet: - return self.rcache.hget("cached_asn", first_octet) - else: - return self.rcache.hgetall("cached_asn") + return self.rcache.hget(self.constants.CACHED_ASN, first_octet) + + return self.rcache.hgetall(self.constants.CACHED_ASN) def store_pid(self, process: str, pid: int): """ @@ -1265,14 +1293,14 @@ def store_pid(self, process: str, pid: int): :param pid: int :param process: module name, str """ - self.r.hset("PIDs", process, pid) + self.r.hset(self.constants.PIDS, process, pid) def get_pids(self) -> dict: """returns a dict with module names as keys and PIDs as values""" - return self.r.hgetall("PIDs") + return self.r.hgetall(self.constants.PIDS) def get_pid_of(self, module_name: str): - pid = self.r.hget("PIDs", module_name) + pid = self.r.hget(self.constants.PIDS, module_name) return int(pid) if pid else None def get_name_of_module_at(self, given_pid): @@ -1291,7 +1319,9 @@ def set_org_info(self, org, org_info, info_type): """ # info will be stored in OrgInfo key {'facebook_asn': .., # 'twitter_domains': ...} - self.rcache.hset("OrgInfo", f"{org}_{info_type}", org_info) + self.rcache.hset( + self.constants.ORG_INFO, f"{org}_{info_type}", org_info + ) def get_org_info(self, org, info_type) -> str: """ @@ -1302,10 +1332,13 @@ def get_org_info(self, org, info_type) -> str: returns a json serialized dict with info PS: All ASNs returned by this function are uppercase """ - return self.rcache.hget("OrgInfo", f"{org}_{info_type}") or "[]" + return ( + self.rcache.hget(self.constants.ORG_INFO, f"{org}_{info_type}") + or "[]" + ) - def get_org_IPs(self, org): - org_info = self.rcache.hget("OrgInfo", f"{org}_IPs") + def get_org_ips(self, org): + org_info = self.rcache.hget(self.constants.ORG_INFO, f"{org}_IPs") if not org_info: org_info = {} @@ -1323,14 +1356,18 @@ def set_whitelist(self, type_, whitelist_dict): :param type_: supported types are IPs, domains, macs and organizations :param whitelist_dict: the dict of IPs,macs, domains or orgs to store """ - self.r.hset("whitelist", type_, json.dumps(whitelist_dict)) + self.r.hset( + self.constants.WHITELIST, type_, json.dumps(whitelist_dict) + ) def get_all_whitelist(self) -> Optional[Dict[str, dict]]: """ Returns a dict with the following keys from the whitelist 'mac', 'organizations', 'IPs', 'domains' """ - whitelist: Optional[Dict[str, str]] = self.r.hgetall("whitelist") + whitelist: Optional[Dict[str, str]] = self.r.hgetall( + self.constants.WHITELIST + ) if whitelist: whitelist = {k: json.loads(v) for k, v in whitelist.items()} return whitelist @@ -1342,13 +1379,13 @@ def get_whitelist(self, key: str) -> dict: this function is used to check if we have any of the above keys whitelisted """ - if whitelist := self.r.hget("whitelist", key): + if whitelist := self.r.hget(self.constants.WHITELIST, key): return json.loads(whitelist) else: return {} def has_cached_whitelist(self) -> bool: - return 
bool(self.r.exists("whitelist")) + return bool(self.r.exists(self.constants.WHITELIST)) def store_dhcp_server(self, server_addr): """ @@ -1361,9 +1398,9 @@ def store_dhcp_server(self, server_addr): # not a valid ip skip return False # make sure the server isn't there before adding - dhcp_servers = self.r.lrange("DHCP_servers", 0, -1) + dhcp_servers = self.r.lrange(self.constants.DHCP_SERVERS, 0, -1) if server_addr not in dhcp_servers: - self.r.lpush("DHCP_servers", server_addr) + self.r.lpush(self.constants.DHCP_SERVERS, server_addr) def save(self, backup_file): """ @@ -1446,13 +1483,13 @@ def set_last_warden_poll_time(self, time): """ :param time: epoch """ - self.r.hset("Warden", "poll", time) + self.r.hset(self.constants.WARDEN_INFO, "poll", time) def get_last_warden_poll_time(self): """ returns epoch time of last poll """ - time = self.r.hget("Warden", "poll") + time = self.r.hget(self.constants.WARDEN_INFO, "poll") time = float(time) if time else float("-inf") return time @@ -1484,16 +1521,18 @@ def store_blame_report(self, ip, network_evaluation): 'confidence': .., 'ts': ..} taken from a blame report """ - self.rcache.hset("p2p-received-blame-reports", ip, network_evaluation) + self.rcache.hset( + self.constants.P2P_RECEIVED_BLAME_REPORTS, ip, network_evaluation + ) def store_zeek_path(self, path): """used to store the path of zeek log files slips is currently using""" - self.r.set("zeek_path", path) + self.r.set(self.constants.ZEEK_PATH, path) def get_zeek_path(self) -> str: """return the path of zeek log files slips is currently using""" - return self.r.get("zeek_path") + return self.r.get(self.constants.ZEEK_PATH) def increment_processed_flows(self): return self.r.incr(self.constants.PROCESSED_FLOWS, 1) diff --git a/slips_files/core/database/redis_db/ioc_handler.py b/slips_files/core/database/redis_db/ioc_handler.py index d6aeafb4c..35a19ce69 100644 --- a/slips_files/core/database/redis_db/ioc_handler.py +++ b/slips_files/core/database/redis_db/ioc_handler.py @@ -128,19 +128,19 @@ def is_known_fp_md5_hash(self, hash: str) -> Optional[str]: returns Fals eif the hash is not a FP""" return self.rcache.hmget(self.constants.KNOWN_FPS, hash) - def delete_ips_from_IoC_ips(self, ips: List[str]): + def delete_ips_from_ioc_ips(self, ips: List[str]): """ Delete the given IPs from IoC """ self.rcache.hdel(self.constants.IOC_IPS, *ips) - def delete_domains_from_IoC_domains(self, domains: List[str]): + def delete_domains_from_ioc_domains(self, domains: List[str]): """ Delete old domains from IoC """ self.rcache.hdel(self.constants.IOC_DOMAINS, *domains) - def add_ips_to_IoC(self, ips_and_description: Dict[str, str]) -> None: + def add_ips_to_ioc(self, ips_and_description: Dict[str, str]) -> None: """ Store a group of IPs in the db as they were obtained from an IoC source :param ips_and_description: is {ip: json.dumps{'source':.., @@ -152,7 +152,7 @@ def add_ips_to_IoC(self, ips_and_description: Dict[str, str]) -> None: if ips_and_description: self.rcache.hmset(self.constants.IOC_IPS, ips_and_description) - def add_domains_to_IoC(self, domains_and_description: dict) -> None: + def add_domains_to_ioc(self, domains_and_description: dict) -> None: """ Store a group of domains in the db as they were obtained from an IoC source @@ -165,7 +165,7 @@ def add_domains_to_IoC(self, domains_and_description: dict) -> None: self.constants.IOC_DOMAINS, domains_and_description ) - def add_ip_range_to_IoC(self, malicious_ip_ranges: dict) -> None: + def add_ip_range_to_ioc(self, malicious_ip_ranges: dict) -> 
None: """ Store a group of IP ranges in the db as they were obtained from an IoC source :param malicious_ip_ranges: is @@ -177,7 +177,7 @@ def add_ip_range_to_IoC(self, malicious_ip_ranges: dict) -> None: self.constants.IOC_IP_RANGES, malicious_ip_ranges ) - def add_asn_to_IoC(self, blacklisted_ASNs: dict): + def add_asn_to_ioc(self, blacklisted_ASNs: dict): """ Store a group of ASN in the db as they were obtained from an IoC source :param blacklisted_ASNs: is @@ -187,7 +187,7 @@ def add_asn_to_IoC(self, blacklisted_ASNs: dict): if blacklisted_ASNs: self.rcache.hmset(self.constants.IOC_ASN, blacklisted_ASNs) - def add_ja3_to_IoC(self, ja3: dict) -> None: + def add_ja3_to_ioc(self, ja3: dict) -> None: """ Store the malicious ja3 iocs in the db :param ja3: {ja3: {'source':..,'tags':.., @@ -196,7 +196,7 @@ def add_ja3_to_IoC(self, ja3: dict) -> None: """ self.rcache.hmset(self.constants.IOC_JA3, ja3) - def add_jarm_to_IoC(self, jarm: dict) -> None: + def add_jarm_to_ioc(self, jarm: dict) -> None: """ Store the malicious jarm iocs in the db :param jarm: {jarm: {'source':..,'tags':.., @@ -204,7 +204,7 @@ def add_jarm_to_IoC(self, jarm: dict) -> None: """ self.rcache.hmset(self.constants.IOC_JARM, jarm) - def add_ssl_sha1_to_IoC(self, malicious_ssl_certs): + def add_ssl_sha1_to_ioc(self, malicious_ssl_certs): """ Store a group of ssl fingerprints in the db :param malicious_ssl_certs: {sha1: {'source':..,'tags':.., @@ -212,7 +212,7 @@ def add_ssl_sha1_to_IoC(self, malicious_ssl_certs): """ self.rcache.hmset(self.constants.IOC_SSL, malicious_ssl_certs) - def is_blacklisted_ASN(self, asn) -> bool: + def is_blacklisted_asn(self, asn) -> bool: return self.rcache.hget(self.constants.IOC_ASN, asn) def is_blacklisted_jarm(self, jarm_hash: str): diff --git a/slips_files/core/database/redis_db/profile_handler.py b/slips_files/core/database/redis_db/profile_handler.py index d454c33f0..5d19f9174 100644 --- a/slips_files/core/database/redis_db/profile_handler.py +++ b/slips_files/core/database/redis_db/profile_handler.py @@ -41,7 +41,9 @@ def get_dhcp_flows(self, profileid, twid) -> list: """ returns a dict of dhcp flows that happened in this profileid and twid """ - if flows := self.r.hget("DHCP_flows", f"{profileid}_{twid}"): + if flows := self.r.hget( + self.constants.DHCP_FLOWS, f"{profileid}_{twid}" + ): return json.loads(flows) def set_dhcp_flow(self, profileid, twid, requested_addr, uid): @@ -53,10 +55,16 @@ def set_dhcp_flow(self, profileid, twid, requested_addr, uid): # we already have flows in this twid, update them cached_flows.update(flow) self.r.hset( - "DHCP_flows", f"{profileid}_{twid}", json.dumps(cached_flows) + self.constants.DHCP_FLOWS, + f"{profileid}_{twid}", + json.dumps(cached_flows), ) else: - self.r.hset("DHCP_flows", f"{profileid}_{twid}", json.dumps(flow)) + self.r.hset( + self.constants.DHCP_FLOWS, + f"{profileid}_{twid}", + json.dumps(flow), + ) def get_timewindow(self, flowtime, profileid): """ @@ -88,7 +96,9 @@ def get_timewindow(self, flowtime, profileid): tw_start = float(flowtime - (31536000 * 100)) tw_number: int = 1 else: - starttime_of_first_tw: str = self.r.hget("analysis", "file_start") + starttime_of_first_tw: str = self.r.hget( + self.constants.ANALYSIS, "file_start" + ) if starttime_of_first_tw: starttime_of_first_tw = float(starttime_of_first_tw) tw_number: int = ( @@ -116,8 +126,10 @@ def add_out_http( ): """ Store in the DB a http request - All the type of flows that are not netflows are stored in a separate hash ordered by uid. 
- The idea is that from the uid of a netflow, you can access which other type of info is related to that uid + All the type of flows that are not netflows are stored in a separate + hash ordered by uid. + The idea is that from the uid of a netflow, you can access which other + type of info is related to that uid """ # Convert to json string http_flow = { @@ -531,7 +543,8 @@ def get_data_from_profile_tw( except Exception: exception_line = sys.exc_info()[2].tb_lineno self.print( - f"Error in getDataFromProfileTW database.py line {exception_line}", + f"Error in getDataFromProfileTW database.py line " + f"{exception_line}", 0, 1, ) @@ -725,26 +738,36 @@ def mark_profile_and_timewindow_as_blocked(self, profileid, twid): a profile is only blocked if it was blocked using the user's firewall, not if it just generated an alert """ - tws = self.getBlockedProfTW(profileid) + tws = self.get_blocked_timewindows_of_profile(profileid) tws.append(twid) - self.r.hset("BlockedProfTW", profileid, json.dumps(tws)) + self.r.hset( + self.constants.BLOCKED_PROFILES_AND_TWS, profileid, json.dumps(tws) + ) - def getBlockedProfTW(self, profileid): + def get_blocked_timewindows_of_profile(self, profileid): """Return all the list of blocked tws""" - if tws := self.r.hget("BlockedProfTW", profileid): + if tws := self.r.hget( + self.constants.BLOCKED_PROFILES_AND_TWS, profileid + ): return json.loads(tws) return [] - def checkBlockedProfTW(self, profileid, twid): + def get_blocked_profiles_and_timewindows(self): + return self.r.hgetall(self.constants.BLOCKED_PROFILES_AND_TWS) + + def is_blocked_profile_and_tw(self, profileid, twid): """ Check if profile and timewindow is blocked """ - profile_tws = self.getBlockedProfTW(profileid) + profile_tws = self.get_blocked_timewindows_of_profile(profileid) return twid in profile_tws - def wasProfileTWModified(self, profileid, twid): + def was_profile_and_tw_modified(self, profileid, twid): """Retrieve from the db if this TW of this profile was modified""" - data = self.r.zrank("ModifiedTW", profileid + self.separator + twid) + data = self.r.zrank( + self.constants.MODIFIED_TIMEWINDOWS, + profileid + self.separator + twid, + ) return bool(data) def add_flow( @@ -760,7 +783,7 @@ def add_flow( The profileid is the main profile that this flow is related too. """ if label: - self.r.zincrby("labels", 1, label) + self.r.zincrby(self.constants.LABELS, 1, label) to_send = { "profileid": profileid, @@ -815,7 +838,10 @@ def get_total_flows(self): """ gets total flows to process from the db """ - return self.r.hget("analysis", "total_flows") + return self.r.hget(self.constants.ANALYSIS, "total_flows") + + def get_analysis_info(self): + return self.r.hgetall(self.constants.ANALYSIS) def add_out_ssh( self, @@ -854,7 +880,8 @@ def add_out_notice( twid, flow, ): - """ " Send notice.log data to new_notice channel to look for self-signed certificates""" + """Send notice.log data to new_notice channel to look for + self-signed certificates""" to_send = { "profileid": profileid, "twid": twid, @@ -909,9 +936,9 @@ def add_out_ssl(self, profileid, twid, flow): else: sni_ipdata = [] - SNI_port = {"server_name": flow.server_name, "dport": flow.dport} + sni_port = {"server_name": flow.server_name, "dport": flow.dport} # We do not want any duplicates. 
@@ -116,8 +126,10 @@ def add_out_http(
     ):
         """
         Store in the DB a http request
-        All the type of flows that are not netflows are stored in a separate hash ordered by uid.
-        The idea is that from the uid of a netflow, you can access which other type of info is related to that uid
+        All the types of flows that are not netflows are stored in a separate
+        hash ordered by uid.
+        The idea is that from the uid of a netflow, you can access which other
+        type of info is related to that uid
         """
         # Convert to json string
         http_flow = {
@@ -531,7 +543,8 @@ def get_data_from_profile_tw(
         except Exception:
             exception_line = sys.exc_info()[2].tb_lineno
             self.print(
-                f"Error in getDataFromProfileTW database.py line {exception_line}",
+                f"Error in getDataFromProfileTW database.py line "
+                f"{exception_line}",
                 0,
                 1,
             )
@@ -725,26 +738,36 @@ def mark_profile_and_timewindow_as_blocked(self, profileid, twid):
         a profile is only blocked if it was blocked using the user's
         firewall, not if it just generated an alert
         """
-        tws = self.getBlockedProfTW(profileid)
+        tws = self.get_blocked_timewindows_of_profile(profileid)
         tws.append(twid)
-        self.r.hset("BlockedProfTW", profileid, json.dumps(tws))
+        self.r.hset(
+            self.constants.BLOCKED_PROFILES_AND_TWS, profileid, json.dumps(tws)
+        )

-    def getBlockedProfTW(self, profileid):
+    def get_blocked_timewindows_of_profile(self, profileid):
         """Return all the list of blocked tws"""
-        if tws := self.r.hget("BlockedProfTW", profileid):
+        if tws := self.r.hget(
+            self.constants.BLOCKED_PROFILES_AND_TWS, profileid
+        ):
             return json.loads(tws)
         return []

-    def checkBlockedProfTW(self, profileid, twid):
+    def get_blocked_profiles_and_timewindows(self):
+        return self.r.hgetall(self.constants.BLOCKED_PROFILES_AND_TWS)
+
+    def is_blocked_profile_and_tw(self, profileid, twid):
         """
         Check if profile and timewindow is blocked
         """
-        profile_tws = self.getBlockedProfTW(profileid)
+        profile_tws = self.get_blocked_timewindows_of_profile(profileid)
         return twid in profile_tws

-    def wasProfileTWModified(self, profileid, twid):
+    def was_profile_and_tw_modified(self, profileid, twid):
         """Retrieve from the db if this TW of this profile was modified"""
-        data = self.r.zrank("ModifiedTW", profileid + self.separator + twid)
+        data = self.r.zrank(
+            self.constants.MODIFIED_TIMEWINDOWS,
+            profileid + self.separator + twid,
+        )
         return bool(data)

     def add_flow(
@@ -760,7 +783,7 @@ def add_flow(
         The profileid is the main profile that this flow is related to.
         """
         if label:
-            self.r.zincrby("labels", 1, label)
+            self.r.zincrby(self.constants.LABELS, 1, label)

         to_send = {
             "profileid": profileid,
@@ -815,7 +838,10 @@ def get_total_flows(self):
         """
         gets total flows to process from the db
         """
-        return self.r.hget("analysis", "total_flows")
+        return self.r.hget(self.constants.ANALYSIS, "total_flows")
+
+    def get_analysis_info(self):
+        return self.r.hgetall(self.constants.ANALYSIS)

     def add_out_ssh(
         self,
@@ -854,7 +880,8 @@ def add_out_notice(
         twid,
         flow,
     ):
-        """ " Send notice.log data to new_notice channel to look for self-signed certificates"""
+        """Send notice.log data to new_notice channel to look for
+        self-signed certificates"""
         to_send = {
             "profileid": profileid,
             "twid": twid,
@@ -909,9 +936,9 @@ def add_out_ssl(self, profileid, twid, flow):
         else:
             sni_ipdata = []

-        SNI_port = {"server_name": flow.server_name, "dport": flow.dport}
+        sni_port = {"server_name": flow.server_name, "dport": flow.dport}
         # We do not want any duplicates.
-        if SNI_port not in sni_ipdata:
+        if sni_port not in sni_ipdata:
             # Verify that the SNI is equal to any of the domains in the DNS
             # resolution
             # only add this SNI to our db if it has a DNS resolution
@@ -920,9 +947,9 @@ def add_out_ssl(self, profileid, twid, flow):
             # 'uid':..}}
             for ip, resolution in dns_resolutions.items():
                 resolution = json.loads(resolution)
-                if SNI_port["server_name"] in resolution["domains"]:
+                if sni_port["server_name"] in resolution["domains"]:
                     # add SNI to our db as it has a DNS resolution
-                    sni_ipdata.append(SNI_port)
+                    sni_ipdata.append(sni_port)
                     self.set_ip_info(flow.daddr, {"SNI": sni_ipdata})
                     break
@@ -933,7 +960,7 @@ def get_profileid_from_ip(self, ip: str) -> Optional[str]:
         """
         try:
             profileid = f"profile_{ip}"
-            if self.r.sismember("profiles", profileid):
+            if self.r.sismember(self.constants.PROFILES, profileid):
                 return profileid
             return False
         except redis.exceptions.ResponseError as inst:
@@ -941,9 +968,9 @@ def get_profileid_from_ip(self, ip: str) -> Optional[str]:
             self.print(type(inst), 0, 1)
             self.print(inst, 0, 1)

-    def getProfiles(self):
+    def get_profiles(self):
         """Get a list of all the profiles"""
-        profiles = self.r.smembers("profiles")
+        profiles = self.r.smembers(self.constants.PROFILES)
         return profiles if profiles != set() else {}

     def get_tws_from_profile(self, profileid):
@@ -1003,12 +1030,16 @@ def get_t2_for_profile_tw(self, profileid, twid, tupleid, tuple_key: str):

     def has_profile(self, profileid):
         """Check if we have the given profile"""
-        return self.r.sismember("profiles", profileid) if profileid else False
+        return (
+            self.r.sismember(self.constants.PROFILES, profileid)
+            if profileid
+            else False
+        )

     def get_profiles_len(self) -> int:
         """Return the amount of profiles. Redis should be faster than
         python to do this count"""
-        profiles_n = self.r.scard("profiles")
+        profiles_n = self.r.scard(self.constants.PROFILES)
         return 0 if not profiles_n else int(profiles_n)

     def get_last_twid_of_profile(self, profileid: str) -> Tuple[str, float]:
@@ -1121,7 +1152,10 @@ def get_modified_tw_since_time(
         # the score of each tw is the ts it was last updated
         # this ts is not network time, it is local time
         data = self.r.zrangebyscore(
-            "ModifiedTW", time, float("+inf"), withscores=True
+            self.constants.MODIFIED_TIMEWINDOWS,
+            time,
+            float("+inf"),
+            withscores=True,
         )
         return data or []
@@ -1195,7 +1229,7 @@ def set_mac_vendor_to_profile(

     def update_mac_of_profile(self, profileid: str, mac: str):
         """Add the MAC addr to the given profileid key"""
-        self.r.hset(profileid, "MAC", mac)
+        self.r.hset(profileid, self.constants.MAC, mac)

     def add_mac_addr_to_profile(self, profileid: str, mac_addr: str):
         """
@@ -1229,11 +1263,13 @@ def add_mac_addr_to_profile(self, profileid: str, mac_addr: str):
             return False

         # get the ips that belong to this mac
-        cached_ips: Optional[List] = self.r.hmget("MAC", mac_addr)[0]
+        cached_ips: Optional[List] = self.r.hmget(
+            self.constants.MAC, mac_addr
+        )[0]
         if not cached_ips:
             # no mac info stored for profileid
             ip = json.dumps([incoming_ip])
-            self.r.hset("MAC", mac_addr, ip)
+            self.r.hset(self.constants.MAC, mac_addr, ip)

             # now that it's decided that this mac belongs to this profileid
             # store the mac in the profileid's key in the db
@@ -1293,7 +1329,7 @@ def add_mac_addr_to_profile(self, profileid: str, mac_addr: str):
         # add the incoming ip to the list of ips that belong to this mac
         cached_ips.add(incoming_ip)
         cached_ips = json.dumps(list(cached_ips))
-        self.r.hset("MAC", mac_addr, cached_ips)
+        self.r.hset(self.constants.MAC, mac_addr, cached_ips)

         self.update_mac_of_profile(profileid, mac_addr)
         self.update_mac_of_profile(f"profile_{found_ip}", mac_addr)
@@ -1306,7 +1342,7 @@ def get_mac_addr_from_profile(self, profileid: dict) -> Union[str, None]:
         returns the info from the profileid key.
         """

-        return self.r.hget(profileid, "MAC")
+        return self.r.hget(profileid, self.constants.MAC)

     def add_user_agent_to_profile(self, profileid, user_agent: dict):
         """
@@ -1393,7 +1429,7 @@ def mark_profile_as_dhcp(self, profileid):
         self.r.hset(profileid, "dhcp", "true")

     def get_first_flow_time(self) -> Optional[str]:
-        return self.r.hget("analysis", "file_start")
+        return self.r.hget(self.constants.ANALYSIS, "file_start")

     def add_profile(self, profileid, starttime):
         """
@@ -1403,12 +1439,12 @@ def add_profile(self, profileid, starttime):
         and individual hashmaps for each profile (like a table)
         """
         try:
-            if self.r.sismember("profiles", profileid):
+            if self.r.sismember(self.constants.PROFILES, profileid):
                 # we already have this profile
                 return False

             # Add the profile to the index. The index is called 'profiles'
-            self.r.sadd("profiles", str(profileid))
+            self.r.sadd(self.constants.PROFILES, str(profileid))
             # Create the hashmap with the profileid.
             # The hashmap of each profile is named with the profileid
             # Add the start time of profile
@@ -1451,7 +1487,7 @@ def check_tw_to_close(self, close_all=False):
         were modified with the slips internal time
         """

-        sit = self.getSlipsInternalTime()
+        sit = self.get_slips_internal_time()

         # for each modified profile
         modification_time = float(sit) - self.width
@@ -1460,7 +1496,10 @@ def check_tw_to_close(self, close_all=False):
             modification_time = float("inf")

         profiles_tws_to_close = self.r.zrangebyscore(
-            "ModifiedTW", 0, modification_time, withscores=True
+            self.constants.MODIFIED_TIMEWINDOWS,
+            0,
+            modification_time,
+            withscores=True,
         )

         for profile_tw_to_close in profiles_tws_to_close:
@@ -1483,7 +1522,7 @@ def mark_profile_tw_as_closed(self, profileid_tw):
         Mark the TW as closed so tools can work on its data
         """
         self.r.sadd("ClosedTW", profileid_tw)
-        self.r.zrem("ModifiedTW", profileid_tw)
+        self.r.zrem(self.constants.MODIFIED_TIMEWINDOWS, profileid_tw)
         self.publish("tw_closed", profileid_tw)

     def mark_profile_tw_as_modified(self, profileid, twid, timestamp):
@@ -1498,7 +1537,7 @@ def mark_profile_tw_as_modified(self, profileid, twid, timestamp):
         """
         timestamp = time.time()
         data = {f"{profileid}{self.separator}{twid}": float(timestamp)}
-        self.r.zadd("ModifiedTW", data)
+        self.r.zadd(self.constants.MODIFIED_TIMEWINDOWS, data)
         self.publish("tw_modified", f"{profileid}:{twid}")
         # Check if we should close some TW
         self.check_tw_to_close()
@@ -1674,6 +1713,9 @@ def get_timeline_last_lines(
         data = self.r.zrange(key, first_index, last_index - 1)
         return data, last_index

+    def get_profiled_tw_timeline(self, profileid, timewindow):
+        return self.r.zrange(f"{profileid}_{timewindow}_timeline", 0, -1)
+
     def mark_profile_as_gateway(self, profileid):
         """
         Used to mark this profile as dhcp server
diff --git a/slips_files/core/evidencehandler.py b/slips_files/core/evidencehandler.py
index d7b5b487a..d351f997d 100644
--- a/slips_files/core/evidencehandler.py
+++ b/slips_files/core/evidencehandler.py
@@ -670,7 +670,7 @@ def main(self):
                     # if the profile was already blocked in
                     # this twid, we shouldn't alert
-                    profile_already_blocked = self.db.checkBlockedProfTW(
+                    profile_already_blocked = self.db.is_blocked_profile_and_tw(
                         profileid, twid
                     )
                     # This is the part to detect if the accumulated
diff --git a/slips_files/core/helpers/checker.py b/slips_files/core/helpers/checker.py
index 1f662d7a1..3dced31c7 100644
--- a/slips_files/core/helpers/checker.py
+++ b/slips_files/core/helpers/checker.py
@@ -108,7 +108,8 @@ def check_given_flags(self):
             )
             self.main.terminate_slips()

-        # if we're reading flows from some module other than the input process, make sure it exists
+        # if we're reading flows from some module other than the input
+        # process, make sure it exists
         if self.main.args.input_module and not self.input_module_exists(
             self.main.args.input_module
         ):
@@ -138,14 +139,16 @@ def check_given_flags(self):
             and self.main.args.blocking
             and os.geteuid() != 0
         ):
-            # If the user wants to blocks, we need permission to modify iptables
+            # If the user wants to block, we need permission to modify
+            # iptables
             print("Run Slips with sudo to enable the blocking module.")
             self.main.terminate_slips()

         if self.main.args.clearblocking:
             if os.geteuid() != 0:
                 print(
-                    "Slips needs to be run as root to clear the slipsBlocking chain. Stopping."
+                    "Slips needs to be run as root to clear the slipsBlocking "
+                    "chain. Stopping."
                 )
             else:
                 self.delete_blocking_chain()
diff --git a/slips_files/core/helpers/whitelist/organization_whitelist.py b/slips_files/core/helpers/whitelist/organization_whitelist.py
index 8e46a0a1d..f274fd471 100644
--- a/slips_files/core/helpers/whitelist/organization_whitelist.py
+++ b/slips_files/core/helpers/whitelist/organization_whitelist.py
@@ -68,7 +68,7 @@ def is_ip_in_org(self, ip: str, org):
         Check if the given ip belongs to the given org
         """
         try:
-            org_subnets: dict = self.db.get_org_IPs(org)
+            org_subnets: dict = self.db.get_org_ips(org)

             first_octet: str = utils.get_first_octet(ip)
             if not first_octet:
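get_org_ips(), the renamed helper this hunk now calls, hands back the organization's subnets bucketed by first octet, so is_ip_in_org() only has to test the handful of ranges that could possibly contain the IP. A self-contained sketch of that two-step lookup — the sample subnets are invented:

    import ipaddress

    # shape returned by get_org_ips(): {first_octet: [subnets]}
    org_subnets = {"8": ["8.8.8.0/24"], "142": ["142.250.0.0/15"]}

    def is_ip_in_org(ip: str) -> bool:
        first_octet = ip.split(".")[0]
        addr = ipaddress.ip_address(ip)
        # only the subnets sharing the IP's first octet are checked
        return any(
            addr in ipaddress.ip_network(subnet)
            for subnet in org_subnets.get(first_octet, [])
        )

    print(is_ip_in_org("8.8.8.8"))  # True
    print(is_ip_in_org("1.2.3.4"))  # False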
+ "RedisDB._set_redis_options", return_value=Mock(), ): db = DBManager( @@ -619,7 +620,9 @@ def create_riskiq_obj(self, mock_db): return riskiq def create_alert_handler_obj(self): - return AlertHandler() + alert_handler = AlertHandler() + alert_handler.constants = Constants() + return alert_handler @patch(MODULE_DB_MANAGER, name="mock_db") def create_timeline_object(self, mock_db): diff --git a/tests/test_cesnet.py b/tests/test_cesnet.py index b6e960804..a9ce49bb1 100644 --- a/tests/test_cesnet.py +++ b/tests/test_cesnet.py @@ -113,13 +113,13 @@ def test_import_alerts(events, expected_output): cesnet.wclient = MagicMock() cesnet.wclient.getEvents = MagicMock(return_value=events) cesnet.db = MagicMock() - cesnet.db.add_ips_to_IoC = MagicMock() + cesnet.db.add_ips_to_ioc = MagicMock() cesnet.print = MagicMock() cesnet.import_alerts() - assert cesnet.db.add_ips_to_IoC.call_count == 1 + assert cesnet.db.add_ips_to_ioc.call_count == 1 - src_ips = cesnet.db.add_ips_to_IoC.call_args[0][0] + src_ips = cesnet.db.add_ips_to_ioc.call_args[0][0] assert len(src_ips) == expected_output diff --git a/tests/test_database.py b/tests/test_database.py index 91d737b37..47efbc20a 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -48,20 +48,10 @@ "", ) -random_port = 6379 - - -def get_random_port(): - global random_port - random_port += 1 - return random_port - def test_getProfileIdFromIP(): """unit test for add_profile and getProfileIdFromIP""" - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6380, flush_db=True) # add a profile db.add_profile("profile_192.168.1.1", "00:00") @@ -72,9 +62,7 @@ def test_getProfileIdFromIP(): def test_timewindows(): """unit tests for addNewTW , getLastTWforProfile and getFirstTWforProfile""" - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6381, flush_db=True) profileid = "profile_192.168.1.1" # add a profile db.add_profile(profileid, "00:00") @@ -87,9 +75,7 @@ def test_timewindows(): def test_add_ips(): - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6382, flush_db=True) # add a profile db.add_profile(profileid, "00:00") # add a tw to that profile @@ -101,9 +87,7 @@ def test_add_ips(): def test_add_port(): - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6383, flush_db=True) new_flow = flow new_flow.state = "Not Established" db.add_port(profileid, twid, flow, "Server", "Dst") @@ -114,9 +98,7 @@ def test_add_port(): def test_set_evidence(): - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6384, flush_db=True) attacker: Attacker = Attacker( direction=Direction.SRC, attacker_type=IoCType.IP, value=test_ip ) @@ -148,9 +130,7 @@ def test_set_evidence(): def test_setInfoForDomains(): """tests setInfoForDomains, setNewDomain and getDomainData""" - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6385, flush_db=True) domain = "www.google.com" domain_data = {"threatintelligence": "sample data"} db.set_info_for_domains(domain, domain_data) @@ -161,9 +141,7 @@ def test_setInfoForDomains(): def test_subscribe(): - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + 
db = ModuleFactory().create_db_manager_obj(6386, flush_db=True) # invalid channel assert db.subscribe("invalid_channel") is False # valid channel, shoud return a pubsub object @@ -172,9 +150,7 @@ def test_subscribe(): def test_profile_moddule_labels(): """tests set and get_profile_module_label""" - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6387, flush_db=True) module_label = "malicious" module_name = "test" db.set_profile_module_label(profileid, module_name, module_label) @@ -187,9 +163,7 @@ def test_add_mac_addr_with_new_ipv4(): """ adding an ipv4 to no cached ip """ - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6388, flush_db=True) ipv4 = "192.168.1.5" profileid_ipv4 = f"profile_{ipv4}" mac_addr = "00:00:5e:00:53:af" @@ -211,9 +185,7 @@ def test_add_mac_addr_with_existing_ipv4(): """ adding an ipv4 to a cached ipv4 """ - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6389, flush_db=True) ipv4 = "192.168.1.5" mac_addr = "00:00:5e:00:53:af" db.rdb.is_gw_mac = Mock(return_value=False) @@ -231,9 +203,7 @@ def test_add_mac_addr_with_ipv6_association(): """ adding an ipv6 to a cached ipv4 """ - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6390, flush_db=True) ipv4 = "192.168.1.5" profile_ipv4 = "profile_192.168.1.5" mac_addr = "00:00:5e:00:53:af" @@ -260,9 +230,7 @@ def test_add_mac_addr_with_ipv6_association(): def test_get_the_other_ip_version(): - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6391, flush_db=True) # profileid is ipv4 ipv6 = "2001:0db8:85a3:0000:0000:8a2e:0370:7334" db.set_ipv6_of_profile(profileid, ipv6) @@ -290,9 +258,7 @@ def test_get_the_other_ip_version(): ], ) def test_add_tuple(tupleid: str, symbol, expected_direction, role, flow): - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6392, flush_db=True) db.add_tuple(profileid, twid, tupleid, symbol, role, flow) assert symbol[0] in db.r.hget( f"profile_{flow.saddr}_{twid}", expected_direction @@ -310,9 +276,7 @@ def test_add_tuple(tupleid: str, symbol, expected_direction, role, flow): def test_update_max_threat_level( max_threat_level, cur_threat_level, expected_max ): - db = ModuleFactory().create_db_manager_obj( - get_random_port(), flush_db=True - ) + db = ModuleFactory().create_db_manager_obj(6393, flush_db=True) db.set_max_threat_level(profileid, max_threat_level) assert ( db.update_max_threat_level(profileid, cur_threat_level) == expected_max diff --git a/tests/test_redis_manager.py b/tests/test_redis_manager.py index f3c2581c2..e4d410d23 100644 --- a/tests/test_redis_manager.py +++ b/tests/test_redis_manager.py @@ -410,9 +410,23 @@ def test_remove_server_from_log( [ # Testcase 1: Normal case with multiple servers ( - "# Comment\nDate,File,Port,PID\n2024-01-01,file1," - "32768,1000\n2024-01-02,file2,32769,2000\n", - {1000: 32768, 2000: 32769}, + "Date, File or interface, Used port, Server PID, Output Zeek Dir, " + "Logs Dir, Slips PID, Is Daemon, Save the DB" + "\n2024/11/25 15:11:50.571184,dataset/test6-malicious.suricata.json," + "32768,16408,dir/zeek_files,dir,16398,False,False", + { + "16408": { + "file_or_interface": 
"dataset/test6-malicious.suricata.json", + "is_daemon": "False", + "output_dir": "dir", + "pid": "16408", + "port": "32768", + "save_the_db": "False", + "slips_pid": "16398", + "timestamp": "2024/11/25 15:11:50.571184", + "zeek_dir": "dir/zeek_files", + }, + }, ), # Testcase 2: Empty file ("", {}), diff --git a/tests/test_threat_intelligence.py b/tests/test_threat_intelligence.py index c15882aa3..a16287ac7 100644 --- a/tests/test_threat_intelligence.py +++ b/tests/test_threat_intelligence.py @@ -414,7 +414,7 @@ def test_delete_old_source_ips_with_deletions( threatintel = ModuleFactory().create_threatintel_obj() threatintel.db.get_all_blacklisted_ips.return_value = mock_ioc_data threatintel._ThreatIntel__delete_old_source_ips(file_to_delete) - threatintel.db.delete_ips_from_IoC_ips.assert_called_once_with( + threatintel.db.delete_ips_from_ioc_ips.assert_called_once_with( expected_deleted_ips ) @@ -440,7 +440,7 @@ def test_delete_old_source_ips_no_deletions(mock_ioc_data, file_to_delete): threatintel = ModuleFactory().create_threatintel_obj() threatintel.db.get_all_blacklisted_ips.return_value = mock_ioc_data threatintel._ThreatIntel__delete_old_source_ips(file_to_delete) - threatintel.db.delete_ips_from_IoC_ips.assert_not_called() + threatintel.db.delete_ips_from_ioc_ips.assert_not_called() @pytest.mark.parametrize( @@ -487,7 +487,7 @@ def test_delete_old_source_domains( threatintel.db.get_all_blacklisted_domains.return_value = domains_in_ioc threatintel._ThreatIntel__delete_old_source_domains(file_to_delete) assert ( - threatintel.db.delete_domains_from_IoC_domains.call_count + threatintel.db.delete_domains_from_ioc_domains.call_count == expected_calls ) @@ -567,11 +567,11 @@ def test_delete_old_source_data_from_database( threatintel._ThreatIntel__delete_old_source_data_from_database(data_file) assert ( - threatintel.db.delete_ips_from_IoC_ips.call_count + threatintel.db.delete_ips_from_ioc_ips.call_count == expected_delete_ips_calls ) assert ( - threatintel.db.delete_domains_from_IoC_domains.call_count + threatintel.db.delete_domains_from_ioc_domains.call_count == expected_delete_domains_calls ) @@ -1491,7 +1491,7 @@ def test_ip_has_blacklisted_asn( profileid = "profile_127.0.0.1" twid = "timewindow1" threatintel.db.get_ip_info.return_value = {"asn": {"number": asn}} - threatintel.db.is_blacklisted_ASN.return_value = asn_info + threatintel.db.is_blacklisted_asn.return_value = asn_info threatintel.ip_has_blacklisted_asn( ip_address, uid, timestamp, profileid, twid ) diff --git a/tests/test_update_file_manager.py b/tests/test_update_file_manager.py index 4820ce9a5..139383f1f 100644 --- a/tests/test_update_file_manager.py +++ b/tests/test_update_file_manager.py @@ -366,7 +366,7 @@ def test_update_riskiq_feed( } mocker.patch("requests.get", return_value=mock_response) result = update_manager.update_riskiq_feed() - update_manager.db.add_domains_to_IoC.assert_called_once_with( + update_manager.db.add_domains_to_ioc.assert_called_once_with( { "malicious.com": json.dumps( { @@ -397,7 +397,7 @@ def test_update_riskiq_feed_invalid_api_key( result = update_manager.update_riskiq_feed() assert result is False - update_manager.db.add_domains_to_IoC.assert_not_called() + update_manager.db.add_domains_to_ioc.assert_not_called() update_manager.db.set_ti_feed_info.assert_not_called() @@ -415,7 +415,7 @@ def test_update_riskiq_feed_request_exception( result = update_manager.update_riskiq_feed() assert result is False - update_manager.db.add_domains_to_IoC.assert_not_called() + 
update_manager.db.add_domains_to_ioc.assert_not_called() update_manager.db.set_ti_feed_info.assert_not_called() @@ -612,7 +612,7 @@ def test_parse_ti_feed_valid_data( result = update_manager.parse_ti_feed( "https://example.com/test.txt", "test.txt" ) - update_manager.db.add_ips_to_IoC.assert_any_call( + update_manager.db.add_ips_to_ioc.assert_any_call( { "1.2.3.4": '{"description": "Test description", ' '"source": "test.txt", ' @@ -620,7 +620,7 @@ def test_parse_ti_feed_valid_data( '"tags": ["tag3"]}' } ) - update_manager.db.add_domains_to_IoC.assert_any_call( + update_manager.db.add_domains_to_ioc.assert_any_call( { "example.com": '{"description": "Another description",' ' "source": "test.txt",' @@ -647,8 +647,8 @@ def test_parse_ti_feed_invalid_data(mocker, tmp_path): result = update_manager.parse_ti_feed( "https://example.com/invalid.txt", str(tmp_path / "invalid.txt") ) - update_manager.db.add_ips_to_IoC.assert_not_called() - update_manager.db.add_domains_to_IoC.assert_not_called() + update_manager.db.add_ips_to_ioc.assert_not_called() + update_manager.db.add_domains_to_ioc.assert_not_called() assert result is False @@ -783,7 +783,7 @@ def test_parse_ssl_feed_valid_data(mocker, tmp_path): str(tmp_path / "test_ssl_feed.csv"), ) - update_manager.db.add_ssl_sha1_to_IoC.assert_called_once_with( + update_manager.db.add_ssl_sha1_to_ioc.assert_called_once_with( { "aaabbbcccdddeeeeffff00001111222233334444": json.dumps( { @@ -818,5 +818,5 @@ def test_parse_ssl_feed_no_valid_fingerprints(mocker, tmp_path): str(tmp_path / "test_ssl_feed.csv"), ) - update_manager.db.add_ssl_sha1_to_IoC.assert_not_called() + update_manager.db.add_ssl_sha1_to_ioc.assert_not_called() assert result is False diff --git a/tests/test_whitelist.py b/tests/test_whitelist.py index c90164a52..2883e88c1 100644 --- a/tests/test_whitelist.py +++ b/tests/test_whitelist.py @@ -127,7 +127,7 @@ def test_is_ip_in_org( expected_result, ): whitelist = ModuleFactory().create_whitelist_obj() - whitelist.db.get_org_IPs.return_value = org_ips + whitelist.db.get_org_ips.return_value = org_ips result = whitelist.org_analyzer.is_ip_in_org(ip, org) assert result == expected_result diff --git a/webinterface/analysis/analysis.py b/webinterface/analysis/analysis.py index 61e9f0cf2..9317b9f49 100644 --- a/webinterface/analysis/analysis.py +++ b/webinterface/analysis/analysis.py @@ -3,7 +3,7 @@ import json from collections import defaultdict from typing import Dict, List -from ..database.database import __database__ +from ..database.database import db from slips_files.common.slips_utils import utils analysis = Blueprint( @@ -25,7 +25,7 @@ def ts_to_date(ts, seconds=False): def get_all_tw_with_ts(profileid): - tws = __database__.db.zrange(f"tws{profileid}", 0, -1, withscores=True) + tws = db.get_tws_from_profile(profileid) dict_tws = defaultdict(dict) for tw_tuple in tws: @@ -56,10 +56,9 @@ def get_ip_info(ip): "ref_file": "-", "com_file": "-", } - if ip_info := __database__.cachedb.hget("IPsInfo", ip): - ip_info = json.loads(ip_info) - # Hardcoded decapsulation due to the complexity of data in side. Ex: {"asn":{"asnorg": "CESNET", "timestamp": 0.001}} - + if ip_info := db.get_ip_info(ip): + # Hardcoded decapsulation due to the complexity of data inside. + # Ex: {"asn":{"asnorg": "CESNET", "timestamp": 0.001}} # set geocountry geocountry = ip_info.get("geocountry", "-") @@ -125,23 +124,15 @@ def set_profile_tws(): Blocked are highligted in red. 
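A pattern worth noting across the test changes above: the db is a MagicMock, so each rename only needs the mocked attribute and its assertion updated in lockstep, with no live Redis involved. A tiny self-contained illustration of that idiom (toy class, not an actual Slips test):

    from unittest.mock import MagicMock

    class OrgAnalyzer:
        def __init__(self, db):
            self.db = db

        def is_ip_in_org(self, ip, org):
            # only the call pattern matters here, not real subnet logic
            return bool(self.db.get_org_ips(org))

    db = MagicMock()
    db.get_org_ips.return_value = {"8": ["8.8.8.0/24"]}
    assert OrgAnalyzer(db).is_ip_in_org("8.8.8.8", "google") is True
    db.get_org_ips.assert_called_once_with("google")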
diff --git a/webinterface/analysis/analysis.py b/webinterface/analysis/analysis.py
index 61e9f0cf2..9317b9f49 100644
--- a/webinterface/analysis/analysis.py
+++ b/webinterface/analysis/analysis.py
@@ -3,7 +3,7 @@
 import json
 from collections import defaultdict
 from typing import Dict, List
-from ..database.database import __database__
+from ..database.database import db
 from slips_files.common.slips_utils import utils

 analysis = Blueprint(
@@ -25,7 +25,7 @@ def ts_to_date(ts, seconds=False):

 def get_all_tw_with_ts(profileid):
-    tws = __database__.db.zrange(f"tws{profileid}", 0, -1, withscores=True)
+    tws = db.get_tws_from_profile(profileid)
     dict_tws = defaultdict(dict)

     for tw_tuple in tws:
@@ -56,10 +56,9 @@ def get_ip_info(ip):
         "ref_file": "-",
         "com_file": "-",
     }
-    if ip_info := __database__.cachedb.hget("IPsInfo", ip):
-        ip_info = json.loads(ip_info)
-        # Hardcoded decapsulation due to the complexity of data in side. Ex: {"asn":{"asnorg": "CESNET", "timestamp": 0.001}}
-
+    if ip_info := db.get_ip_info(ip):
+        # Hardcoded decapsulation due to the complexity of data inside.
+        # Ex: {"asn":{"asnorg": "CESNET", "timestamp": 0.001}}
         # set geocountry
         geocountry = ip_info.get("geocountry", "-")
@@ -125,23 +124,15 @@ def set_profile_tws():
     Blocked are highlighted in red.
     :return: (profile, [tw, blocked], blocked)
     """
+    data = {}
-    profiles_dict = {}
-    # Fetch profiles
-    profiles = __database__.db.smembers("profiles")
+    profiles = db.get_profiles()
+    blocked_profiles = db.get_malicious_profiles()

     for profileid in profiles:
-        profile_word, profile_ip = profileid.split("_")
-        profiles_dict[profile_ip] = False
-
-    if blocked_profiles := __database__.db.smembers("malicious_profiles"):
-        for profile in blocked_profiles:
-            blocked_ip = profile.split("_")[-1]
-            profiles_dict[blocked_ip] = True
+        blocked: bool = profileid in blocked_profiles
+        profile_ip = profileid.split("_")[-1]
+        data.update({"profile": profile_ip, "blocked": blocked})

-    data = [
-        {"profile": profile_ip, "blocked": blocked_state}
-        for profile_ip, blocked_state in profiles_dict.items()
-    ]
     return {"data": data}
@@ -159,22 +150,21 @@ def set_ip_info(ip):
     return {"data": data}

-@analysis.route("/tws/<profileid>")
-def set_tws(profileid):
+@analysis.route("/tws/<ip>")
+def set_tws(ip):
     """
     Set timewindows for selected profile
-    :param profileid: ip of the profile
+    :param ip: ip of the profile
     :return:
     """
     # Fetch all profile TWs
-    tws: Dict[str, dict] = get_all_tw_with_ts(f"profile_{profileid}")
+    profileid = f"profile_{ip}"
+    tws: Dict[str, dict] = get_all_tw_with_ts(profileid)

     blocked_tws: List[str] = []
     for tw_id, twid_details in tws.items():
-        is_blocked: bool = __database__.db.hget(
-            f"profile_{profileid}_{tw_id}", "alerts"
-        )
+        is_blocked: bool = db.get_profileid_twid_alerts(profileid, tw_id)
         if is_blocked:
             blocked_tws.append(tw_id)
@@ -192,18 +182,17 @@ def set_tws(profileid):
     return {"data": data}

-@analysis.route("/intuples/<profile>/<timewindow>")
-def set_intuples(profile, timewindow):
+@analysis.route("/intuples/<ip>/<timewindow>")
+def set_intuples(ip, timewindow):
     """
     Set intuples of a chosen profile and timewindow.
-    :param profile: active profile
+    :param ip: ip of active profile
     :param timewindow: active timewindow
     :return: (tuple, string, ip_info)
     """
     data = []
-    if intuples := __database__.db.hget(
-        f"profile_{profile}_{timewindow}", "InTuples"
-    ):
+    profileid = f"profile_{ip}"
+    if intuples := db.get_intuples_from_profile_tw(profileid, timewindow):
         intuples = json.loads(intuples)
         for key, value in intuples.items():
             ip, port, protocol = key.split("-")
@@ -216,21 +205,19 @@ def set_intuples(profile, timewindow):
     return {"data": data}

-@analysis.route("/outtuples/<profile>/<timewindow>")
-def set_outtuples(profile, timewindow):
+@analysis.route("/outtuples/<ip>/<timewindow>")
+def set_outtuples(ip, timewindow):
     """
     Set outtuples of a chosen profile and timewindow.
-    :param profile: active profile
+    :param ip: ip of active profile
     :param timewindow: active timewindow
     :return: (tuple, key, ip_info)
     """
     data = []
-    if outtuples := __database__.db.hget(
-        f"profile_{profile}_{timewindow}", "OutTuples"
-    ):
+    profileid = f"profile_{ip}"
+    if outtuples := db.get_outtuples_from_profile_tw(profileid, timewindow):
         outtuples = json.loads(outtuples)
-
         for key, value in outtuples.items():
             ip, port, protocol = key.split("-")
             ip_info = get_ip_info(ip)
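All of the rewritten endpoints above follow the same Flask idiom: angle-bracket converters in the route bind URL segments directly to the view's parameters. A toy standalone example of the pattern (not the web interface's actual blueprint wiring):

    from flask import Flask

    app = Flask(__name__)

    @app.route("/outtuples/<ip>/<timewindow>")
    def outtuples(ip: str, timewindow: str):
        # GET /outtuples/8.8.8.8/timewindow1 binds ip="8.8.8.8",
        # timewindow="timewindow1"
        return {"profileid": f"profile_{ip}", "timewindow": timewindow}

    if __name__ == "__main__":
        app.run()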
@@ -241,15 +228,16 @@ def set_outtuples(profile, timewindow):
     return {"data": data}

-@analysis.route("/timeline_flows/<profile>/<timewindow>")
-def set_timeline_flows(profile, timewindow):
+@analysis.route("/timeline_flows/<ip>/<timewindow>")
+def set_timeline_flows(ip, timewindow):
     """
     Set timeline flows of a chosen profile and timewindow.
     :return: list of timeline flows as set initially in database
     """
     data = []
-    if timeline_flows := __database__.db.hgetall(
-        f"profile_{profile}_{timewindow}_flows"
+    profileid = f"profile_{ip}"
+    if timeline_flows := db.get_all_flows_in_profileid_twid(
+        profileid, timewindow
     ):
         for key, value in timeline_flows.items():
             value = json.loads(value)
@@ -268,9 +256,9 @@ def set_timeline_flows(profile, timewindow):
     return {"data": data}

-@analysis.route("/timeline/<profile>/<timewindow>")
+@analysis.route("/timeline/<ip>/<timewindow>")
 def set_timeline(
-    profile,
+    ip,
     timewindow,
 ):
     """
@@ -278,10 +266,8 @@ def set_timeline(
     :return: list of timeline as set initially in database
     """
     data = []
-
-    if timeline := __database__.db.zrange(
-        f"profile_{profile}_{timewindow}_timeline", 0, -1
-    ):
+    profileid = f"profile_{ip}"
+    if timeline := db.get_profiled_tw_timeline(profileid, timewindow):
         for flow in timeline:
             flow = json.loads(flow)
@@ -310,21 +296,18 @@ def set_timeline(
     return {"data": data}

-@analysis.route("/alerts/<profile>/<timewindow>")
-def set_alerts(profile, timewindow):
+@analysis.route("/alerts/<ip>/<timewindow>")
+def set_alerts(ip, timewindow):
     """
     Set alerts for chosen profile and timewindow
     """
     data = []
-    profile = f"profile_{profile}"
-    if alerts := __database__.db.hget("alerts", profile):
-        alerts = json.loads(alerts)
+    profile = f"profile_{ip}"
+    if alerts := db.get_profileid_twid_alerts(profile, timewindow):
         alerts_tw = alerts.get(timewindow, {})
         tws = get_all_tw_with_ts(profile)

-        evidence: Dict[str, str] = __database__.db.hgetall(
-            f"{profile}_{timewindow}_evidence"
-        )
+        evidence: Dict[str, str] = db.get_twid_evidence(profile, timewindow)

         for alert_id, evidence_id_list in alerts_tw.items():
             evidence_count = len(evidence_id_list)
@@ -349,46 +332,43 @@ def set_alerts(profile, timewindow):
     return {"data": data}

-@analysis.route("/evidence/<profile>/<timewindow>/<alert_id>")
-def set_evidence(profile, timewindow, alert_id):
+@analysis.route("/evidence/<ip>/<timewindow>/<alert_id>")
+def set_evidence(ip, timewindow, alert_id: str):
     """
-    Set evidence table for the pressed alert in chosem profile and timewindow
+    Set evidence table for the pressed alert in chosen profile and timewindow
     """
     data = []
-    if alerts := __database__.db.hget("alerts", f"profile_{profile}"):
-        alerts = json.loads(alerts)
-        alerts_tw = alerts[timewindow]
-        # get the list of evidence that were part of this alert
-        evidence_ids: List[str] = alerts_tw[alert_id]
-
-        profileid = f"profile_{profile}"
-        evidence: Dict[str, str] = __database__.db.hgetall(
-            f"{profileid}_{timewindow}_evidence"
-        )
+    profileid = f"profile_{ip}"
+    # get the list of evidence that were part of this alert
+    evidence_ids: List[str] = db.get_evidence_causing_alert(
+        profileid, timewindow, alert_id
+    )
+    if evidence_ids:
         for evidence_id in evidence_ids:
-            temp_evidence = json.loads(evidence[evidence_id])
-            data.append(temp_evidence)
+            # get the actual evidence represented by the id
+            evidence: Dict[str, str] = db.get_evidence_by_id(
+                profileid, timewindow, evidence_id
+            )
+            data.append(evidence)
     return {"data": data}

-@analysis.route("/evidence/<profile>/<timewindow>/")
-def set_evidence_general(profile: str, timewindow: str):
+@analysis.route("/evidence/<ip>/<timewindow>/")
+def set_evidence_general(ip: str, timewindow: str):
     """
     Set an analysis tag with general evidence
-    :param profile: the ip
+    :param ip: the ip of the profile
     :param timewindow: timewindow
     :return: {"data": data} where data is a list of evidence
     """
     data = []
-    profile = f"profile_{profile}"
-
-    evidence: Dict[str, str] = __database__.db.hgetall(
-        f"{profile}_{timewindow}_evidence"
-    )
+    profile = f"profile_{ip}"
+    evidence: Dict[str, str] = db.get_twid_evidence(profile, timewindow)

     if evidence:
         for evidence_details in evidence.values():
+            evidence_details: str
             evidence_details: dict = json.loads(evidence_details)
             data.append(evidence_details)
     return {"data": data}
diff --git a/webinterface/app.py b/webinterface/app.py
index 6f96c177c..f744f9d99 100644
--- a/webinterface/app.py
+++ b/webinterface/app.py
@@ -1,12 +1,12 @@
-from flask import Flask, render_template, redirect, url_for, current_app
+from flask import Flask, render_template, redirect, url_for
 from slips_files.common.parsers.config_parser import ConfigParser
-from .database.database import __database__
+from .database.database import db
 from .database.signals import message_sent
 from .analysis.analysis import analysis
 from .general.general import general
 from .documentation.documentation import documentation
-from .utils import read_db_file
+from .utils import get_open_redis_ports_in_order

 def create_app():
@@ -20,7 +20,11 @@ def create_app():

     @app.route("/redis")
     def read_redis_port():
-        res = read_db_file()
+        """
+        is called when changing the db from the button at the top right
+        prints the available redis dbs and ports for the user to choose from
+        """
+        res = get_open_redis_ports_in_order()
         return {"data": res}
@@ -31,9 +35,12 @@ def index():

     @app.route("/db/<new_port>")
     def get_post_javascript_data(new_port):
-        message_sent.send(
-            current_app._get_current_object(), port=int(new_port), dbnumber=0
-        )
+        """
+        is called when the user chooses another db to connect to from the
+        button at the top right (from /redis)
+        should send a msg to update_db() in database.py
+        """
+        message_sent.send(int(new_port))
         return redirect(url_for("index"))
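The send() call above drives the db switch through a blinker-style signal: the receiver in database.py gets the port as the positional sender argument. A toy reproduction of that handshake, assuming signals.py defines message_sent as a plain blinker Signal:

    from blinker import Signal

    message_sent = Signal()

    @message_sent.connect
    def update_db(port):
        # blinker passes send()'s first argument as the sender; here the
        # sender *is* the chosen redis port
        print(f"reconnecting to the redis server on port {port}")

    message_sent.send(6380)  # prints: reconnecting to the redis server on port 6380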
""" - info = __database__.db.hgetall("analysis") + info = db.get_analysis_info() - profiles = __database__.db.smembers("profiles") + profiles = db.get_profiles() info["num_profiles"] = len(profiles) if profiles else 0 - alerts_number = __database__.db.get("number_of_alerts") + alerts_number = db.get_number_of_alerts_so_far() info["num_alerts"] = int(alerts_number) if alerts_number else 0 return info @@ -55,9 +62,6 @@ def set_pcap_info(): if __name__ == "__main__": app.register_blueprint(analysis, url_prefix="/analysis") - app.register_blueprint(general, url_prefix="/general") - app.register_blueprint(documentation, url_prefix="/documentation") - app.run(host="0.0.0.0", port=ConfigParser().web_interface_port) diff --git a/webinterface/database/database.py b/webinterface/database/database.py index 96c9a03c4..5ba7ffc03 100644 --- a/webinterface/database/database.py +++ b/webinterface/database/database.py @@ -1,46 +1,68 @@ -import redis +from typing import ( + Dict, + Optional, +) +import os + +from slips_files.core.database.database_manager import DBManager +from slips_files.core.output import Output from .signals import message_sent -from webinterface.utils import * +from webinterface.utils import ( + get_open_redis_ports_in_order, + get_open_redis_servers, +) class Database(object): + """ + connects to the latest opened redis server on init + """ + def __init__(self): - self.db = self.init_db() - self.cachedb = self.connect_to_database( - port=6379, db_number=1 - ) # default cache - - def set_db(self, port, db_number): - self.db = self.connect_to_database(port, db_number) - - def set_cachedb(self, port, db_number): - self.cachedb = self.connect_to_database(port, db_number) - - def init_db(self): - available_dbs = read_db_file() - port, db_number = 6379, 0 - - if len(available_dbs) >= 1: - port = available_dbs[-1]["redis_port"] - - return self.connect_to_database(port, db_number) - - def connect_to_database(self, port=6379, db_number=0): - return redis.StrictRedis( - host="localhost", - port=port, - db=db_number, - charset="utf-8", - socket_keepalive=True, - retry_on_timeout=True, - decode_responses=True, - health_check_interval=30, + # connect to the db manager + self.db: DBManager = self.get_db_manager_obj() + + def set_db(self, port): + """changes the redis db we're connected to""" + self.db = self.get_db_manager_obj(port) + + def get_db_manager_obj(self, port: int = False) -> Optional[DBManager]: + """ + Connects to redis db through the DBManager + connects to the latest opened redis server if no port is given + """ + if not port: + # connect to the last opened port if no port is chosen by the + # user + last_opened_port = get_open_redis_ports_in_order()[-1][ + "redis_port" + ] + port = last_opened_port + + dbs: Dict[int, dict] = get_open_redis_servers() + output_dir = dbs[str(port)]["output_dir"] + logger = Output( + stdout=os.path.join(output_dir, "slips.log"), + stderr=os.path.join(output_dir, "errors.log"), + slips_logfile=os.path.join(output_dir, "slips.log"), ) + try: + return DBManager( + logger, + output_dir, + port, + start_redis_server=False, + ) + except RuntimeError: + return -__database__ = Database() +db_obj = Database() +db: DBManager = db_obj.db @message_sent.connect -def update_db(app, port, dbnumber): - __database__.set_db(port, dbnumber) +def update_db(port): + """is called when the user changes the used redis server from the web + interface""" + db_obj.set_db(port) diff --git a/webinterface/general/general.py b/webinterface/general/general.py index 
5c612bd5e..53ca9efd5 100644 --- a/webinterface/general/general.py +++ b/webinterface/general/general.py @@ -1,6 +1,8 @@ from flask import Blueprint from flask import render_template -from ..database.database import __database__ + + +from ..database.database import db general = Blueprint( "general", @@ -17,15 +19,15 @@ def index(): @general.route("/blockedProfileTWs") -def setBlockedProfileTWs(): +def set_blocked_profiles_and_tws(): """ Function to set blocked profiles and tws """ - blockedProfileTWs = __database__.db.hgetall("BlockedProfTW") + blocked_profiles_and_tws = db.get_blocked_profiles_and_timewindows() data = [] - if blockedProfileTWs: - for profile, tws in blockedProfileTWs.items(): + if blocked_profiles_and_tws: + for profile, tws in blocked_profiles_and_tws.items(): data.append({"blocked": profile + str(tws)}) return { diff --git a/webinterface/utils.py b/webinterface/utils.py index 94616b3f6..1b4747772 100644 --- a/webinterface/utils.py +++ b/webinterface/utils.py @@ -1,9 +1,13 @@ import os +from typing import ( + Dict, + List, +) -def read_db_file(): +def get_open_redis_ports_in_order() -> List[Dict[str, str]]: available_db = [] - file_path = "../running_slips_info.txt" + file_path = "running_slips_info.txt" if os.path.exists(file_path): with open(file_path) as file: @@ -20,3 +24,69 @@ def read_db_file(): ) return available_db + + +def is_comment(line: str) -> bool: + """returns true if the given line is a comment""" + return (line.startswith("#") or line.startswith("Date")) or len(line) < 3 + + +def get_open_redis_servers() -> Dict[int, dict]: + """ + returns the opened redis servers read from running_slips.info.txt + returns the following dict: {port: { + "timestamp": ..., + "file_or_interface": ..., + "port": ..., + "pid": ..., + "zeek_dir": ..., + "output_dir": ..., + "slips_pid": ..., + "is_daemon": ..., + "save_the_db": ..., + }} + """ + running_logfile = "running_slips_info.txt" + open_servers: Dict[int, dict] = {} + try: + with open(running_logfile) as f: + for line in f.read().splitlines(): + if is_comment(line): + continue + + line = line.split(",") + + try: + ( + timestamp, + file_or_interface, + port, + pid, + zeek_dir, + output_dir, + slips_pid, + is_daemon, + save_the_db, + ) = line + + open_servers[port] = { + "timestamp": timestamp, + "file_or_interface": file_or_interface, + "port": port, + "pid": pid, + "zeek_dir": zeek_dir, + "output_dir": output_dir, + "slips_pid": slips_pid, + "is_daemon": is_daemon, + "save_the_db": save_the_db, + } + except ValueError: + # sometimes slips can't get the server pid and logs + # "False" in the logfile instead of the PID + # there's nothing we can do about it + pass + + return open_servers + + except FileNotFoundError: + return {}
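Taken together, the two helpers give the web interface everything it needs to pick a server: the ordered list for "latest wins", and the per-port metadata for output paths. A usage sketch — run from the Slips root so running_slips_info.txt resolves, and note that both helpers key everything as strings:

    from webinterface.utils import (
        get_open_redis_ports_in_order,
        get_open_redis_servers,
    )

    ports = get_open_redis_ports_in_order()
    if ports:
        # connect to the most recently started server, as Database() does
        last_port: str = ports[-1]["redis_port"]
        server_info: dict = get_open_redis_servers().get(last_port, {})
        print(last_port, server_info.get("output_dir"))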