diff --git a/CHANGELOG.md b/CHANGELOG.md
index 44b421a23..a194bde7e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,11 +1,16 @@
-- 1.1.4 (Nov 29th, 2024)
+1.1.4.1 (Dec 3rd, 2024)
+- Fix abstract class starting with the rest of the modules.
+- Fix the updating of the MAC vendors database used in slips.
+- Improve MAC vendor offline lookups.
+
+1.1.4 (Nov 29th, 2024)
 - Fix changing the used database in the web interface.
 - Reduce false positive evidence about malicious downloaded files.
 - Fix datetime errors when running on interface
 - Improve the detection of "DNS without connection".
 - Add support for a light Slips docker image.
 
-- 1.1.3 (October 30th, 2024)
+1.1.3 (October 30th, 2024)
 - Enhanced Slips shutdown process for smoother operations.
 - Optimized resource management in Slips, resolving issues with lingering threads in memory.
 - Remove the progress bar; Slips now provides regular statistical updates.
@@ -19,7 +24,7 @@
 - Enhance logging of IDMEF errors.
 - Resolve issues with the accumulated threat level reported in alerts.json.
 
-- 1.1.2 (September 30th, 2024)
+1.1.2 (September 30th, 2024)
 - Add a relation between related evidence in alerts.json
 - Better unit tests. Thanks to @Sekhar-Kumar-Dash
 - Discontinued MacOS m1 docker images, P2p images, and slips dependencies image.
@@ -42,7 +47,7 @@
 - Update python dependencies.
 - Better handling of problems connecting to Redis database.
 
-- 1.1 (July 2024)
+1.1 (July 2024)
 - Update Python version to 3.10.12 and all python libraries used by Slips.
 - Update nodejs and zeek.
 - Improve the stopping of Slips. Modules now have more time to process flows.
@@ -54,7 +59,7 @@
 - Horizontal port scan detection improvements.
 
 
-- 1.0.15 (June 2024)
+1.0.15 (June 2024)
 - Add a Parameter to export strato letters to re-train the RNN model.
 - Better organization of flowalerts module by splitting it into many specialized files.
 - Better unit tests. thanks to @Sekhar-Kumar-Dash
@@ -70,7 +75,7 @@
 - The port of the web interface is now configurable in slips.conf
 
 
-- 1.0.14 (May 2024)
+1.0.14 (May 2024)
 - Improve whitelists. better matching of ASNs, domains, and organizations.
 - Whitelist Microsoft, Apple, Twitter, Facebook and Google alerts by default to reduce false positives.
 - Better unit tests. thanks to @Sekhar-Kumar-Dash
@@ -79,7 +84,7 @@
 - Add more info to metadata/info.txt for each run.
 
 
-- 1.0.13 (April 2024)
+1.0.13 (April 2024)
 - Whitelist alerts to all organizations by default to reduce false positives.
 - Improve and compress Slips Docker images. thanks to @verovaleros
 - Improve CI and add pre-commit hooks.
@@ -90,7 +95,7 @@
 - Better unit tests. thanks to @Sekhar-Kumar-Dash
 - Fix problems stopping the daemon.
 
-- 1.0.12 (March 2024)
+1.0.12 (March 2024)
 - Add an option to specify the current client IP in slips.conf to help avoid false positives.
 - Better handling of URLhaus threat intelligence.
 - Change how slips determines the local network of the current client IP.
@@ -107,7 +112,7 @@
 - Use the latest Redis and NodeJS version in all docker images.
 
 
-- 1.0.11 (February 2024)
+1.0.11 (February 2024)
 - Improve the logging of evidence in alerts.json and alerts.log.
 - Optimize the storing of evidence in the Redis database.
 - Fix problem of missing evidence, now all evidence is logged correctly.
@@ -117,7 +122,7 @@
 - Fix problem closing the progress bar.
 - Fix problem releasing the terminal when Slips is done.
 
-- 1.0.10 (January 2024)
+1.0.10 (January 2024)
 - Faster ensembling of evidence.
 - Log accumulated threat levels of each evidence in alerts.json.
 - Better handling of the termination of the progress bar.
@@ -239,7 +244,7 @@
 - Fix caching ASN ranges
 - Code optimizations
 
-- 1.0.1 (Jan 2023)
+1.0.1 (Jan 2023)
 - fix FP horizontal portscans caused by zeek flipping connections
 - Fix Duplicate evidence in multiple alerts
 - Fix FP urlhaus detetcions, now we use it to check urls only, not domains.
diff --git a/VERSION b/VERSION
index 65087b4f5..55bbf1581 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.1.4
+1.1.4.1
diff --git a/docker/light/Dockerfile b/docker/light/Dockerfile
index 61b8a56f3..895b32fa6 100644
--- a/docker/light/Dockerfile
+++ b/docker/light/Dockerfile
@@ -4,7 +4,7 @@ ENV DEBIAN_FRONTEND=noninteractive
 # Blocking module requirement to avoid using sudo
 ENV IS_IN_A_DOCKER_CONTAINER=True
 # destionation dir for slips inside the container
-ENV SLIPS_DIR=/StratosphereLinuxIPs
+ENV SLIPS_DIR=/StratosphereLinuxIPS
 
 # use bash instead of sh
 SHELL ["/bin/bash", "-c"]
diff --git a/modules/ip_info/ip_info.py b/modules/ip_info/ip_info.py
index ddf628f5e..283bb28de 100644
--- a/modules/ip_info/ip_info.py
+++ b/modules/ip_info/ip_info.py
@@ -14,13 +14,14 @@
 import time
 import asyncio
 import multiprocessing
+from functools import lru_cache
 
 from modules.ip_info.jarm import JARM
 from slips_files.common.flow_classifier import FlowClassifier
 from slips_files.core.helpers.whitelist.whitelist import Whitelist
 from .asn_info import ASN
-from slips_files.common.abstracts.module import AsyncModule
+from slips_files.common.abstracts.async_module import AsyncModule
 from slips_files.common.slips_utils import utils
 from slips_files.core.structures.evidence import (
     Evidence,
@@ -93,14 +94,25 @@ async def open_dbs(self):
         self.reading_mac_db_task = asyncio.create_task(self.read_mac_db())
 
     async def read_mac_db(self):
+        """
+        waits 10 mins for the update manager to download the mac db and
+        opens it for reading. retries opening every 10s
+        """
+        trials = 0
         while True:
+            if trials >= 60:
+                # that's 10 mins of waiting for the macdb (600s)
+                # dont wait forever
+                return
+
             try:
                 self.mac_db = open("databases/macaddress-db.json", "r")
                 return True
             except OSError:
                 # update manager hasn't downloaded it yet
                 try:
-                    time.sleep(3)
+                    time.sleep(10)
+                    trials += 1
                 except KeyboardInterrupt:
                     return False
@@ -186,6 +198,7 @@ def get_vendor_online(self, mac_addr):
         ):
             return False
 
+    @lru_cache(maxsize=700)
     def get_vendor_offline(self, mac_addr, profileid):
         """
         Gets vendor from Slips' offline database databases/macaddr-db.json
diff --git a/modules/update_manager/update_manager.py b/modules/update_manager/update_manager.py
index 5f9a59f0d..c728fbc3d 100644
--- a/modules/update_manager/update_manager.py
+++ b/modules/update_manager/update_manager.py
@@ -37,8 +37,8 @@ def init(self):
             self.update_period, self.update_ti_files
         )
         # Timer to update the MAC db
-        # when update_ti_files is called, it decides what exactly to update, the mac db,
-        # online whitelist OT online ti files.
+        # when update_ti_files is called, it decides what exactly to
+        # update, the mac db, online whitelist or online ti files.
         self.mac_db_update_manager = InfiniteTimer(
             self.mac_db_update_period, self.update_ti_files
         )
@@ -54,6 +54,7 @@ def init(self):
         self.whitelist = Whitelist(self.logger, self.db)
         self.slips_logfile = self.db.get_stdfile("stdout")
         self.org_info_path = "slips_files/organizations_info/"
+        self.path_to_mac_db = "databases/macaddress-db.json"
         # if any keyword of the following is present in a line
         # then this line should be ignored by slips
         # either a not supported ioc type or a header line etc.
@@ -80,20 +81,20 @@ def init(self):
         self.responses = {}
 
     def read_configuration(self):
-        def read_riskiq_creds(RiskIQ_credentials_path):
+        def read_riskiq_creds(risk_iq_credentials_path):
             self.riskiq_email = None
             self.riskiq_key = None
-            if not RiskIQ_credentials_path:
+            if not risk_iq_credentials_path:
                 return
 
-            RiskIQ_credentials_path = os.path.join(
-                os.getcwd(), RiskIQ_credentials_path
+            risk_iq_credentials_path = os.path.join(
+                os.getcwd(), risk_iq_credentials_path
             )
-            if not os.path.exists(RiskIQ_credentials_path):
+            if not os.path.exists(risk_iq_credentials_path):
                 return
 
-            with open(RiskIQ_credentials_path, "r") as f:
+            with open(risk_iq_credentials_path, "r") as f:
                 self.riskiq_email = f.readline().replace("\n", "")
                 self.riskiq_key = f.readline().replace("\n", "")
@@ -113,8 +114,8 @@ def read_riskiq_creds(RiskIQ_credentials_path):
         self.ssl_feeds_path = conf.ssl_feeds()
         self.ssl_feeds = self.get_feed_details(self.ssl_feeds_path)
 
-        RiskIQ_credentials_path = conf.RiskIQ_credentials_path()
-        read_riskiq_creds(RiskIQ_credentials_path)
+        risk_iq_credentials_path = conf.RiskIQ_credentials_path()
+        read_riskiq_creds(risk_iq_credentials_path)
 
         self.riskiq_update_period = conf.riskiq_update_period()
         self.mac_db_update_period = conf.mac_db_update_period()
@@ -127,7 +128,8 @@ def read_riskiq_creds(RiskIQ_credentials_path):
 
     def get_feed_details(self, feeds_path):
         """
-        Parse links, threat level and tags from the feeds_path file and return a dict with feed info
+        Parse links, threat level and tags from the feeds_path file and return
+        a dict with feed info
         """
         try:
             with open(feeds_path, "r") as feeds_file:
@@ -238,7 +240,8 @@ def read_ports_info(self, ports_info_filepath) -> int:
 
             except IndexError:
                 self.print(
-                    f"Invalid line: {line} line number: {line_number} in {ports_info_filepath}. Skipping.",
+                    f"Invalid line: {line} line number: "
+                    f"{line_number} in {ports_info_filepath}. Skipping.",
                     0,
                     1,
                 )
@@ -265,7 +268,7 @@ def update_local_file(self, file_path) -> bool:
 
             # Store the new hash of file in the database
             file_info = {"hash": self.new_hash}
-            self.db.set_ti_feed_info(file_path, file_info)
+            self.mark_feed_as_updated(file_path, extra_info=file_info)
 
             return True
         except OSError:
@@ -274,7 +277,8 @@ def update_local_file(self, file_path) -> bool:
     def check_if_update_local_file(self, file_path: str) -> bool:
         """
         Decides whether to update or not based on the file hash.
-        Used for local files that are updated if the contents of the file hash changed
+        Used for local files that are updated if the contents of the file
+        hash changed
         for example: files in slips_files/ports_info
         """
 
@@ -294,17 +298,14 @@ def check_if_update_local_file(self, file_path: str) -> bool:
             # The 2 hashes are identical. File is up to date.
             return False
 
-    def check_if_update_online_whitelist(self) -> bool:
+    def should_update_online_whitelist(self) -> bool:
         """
         Decides whether to update or not based on the update period
         Used for online whitelist specified in slips.conf
         """
-        # Get the last time this file was updated
-        ti_file_info = self.db.get_ti_feed_info("tranco_whitelist")
-        last_update = ti_file_info.get("time", float("-inf"))
-
-        now = time.time()
-        if last_update + self.online_whitelist_update_period > now:
+        if not self.did_update_period_pass(
+            self.online_whitelist_update_period, "tranco_whitelist"
+        ):
             # update period hasnt passed yet
             return False
 
@@ -315,8 +316,6 @@ def check_if_update_online_whitelist(self) -> bool:
         if not response:
             return False
 
-        # update the timestamp in the db
-        self.db.set_ti_feed_info("tranco_whitelist", {"time": time.time()})
         self.responses["tranco_whitelist"] = response
         return True
@@ -333,13 +332,19 @@ def download_file(self, file_to_download):
             else:
                 return response
         except requests.exceptions.ReadTimeout:
-            error = f"Timeout reached while downloading the file {file_to_download}. Aborting."
+            error = (
+                f"Timeout reached while downloading the file "
+                f"{file_to_download}. Aborting."
+            )
 
         except (
             requests.exceptions.ConnectionError,
             requests.exceptions.ChunkedEncodingError,
         ):
-            error = f"Connection error while downloading the file {file_to_download}. Aborting."
+            error = (
+                f"Connection error while downloading the file "
+                f"{file_to_download}. Aborting."
+            )
 
         if error:
             self.print(error, 0, 1)
@@ -353,17 +358,49 @@ def get_last_modified(self, response) -> str:
         """
         return response.headers.get("Last-Modified", False)
 
-    def check_if_update(self, file_to_download: str, update_period) -> bool:
+    def is_mac_db_file_on_disk(self) -> bool:
+        """checks if the mac db is present in databases/"""
+        return os.path.isfile(self.path_to_mac_db)
+
+    def did_update_period_pass(self, period, file) -> bool:
+        """
+        checks if the given period passed since the last time we
+        updated the given file
+        """
+        # Get the last time this file was updated
+        ti_file_info: dict = self.db.get_ti_feed_info(file)
+        last_update = ti_file_info.get("time", float("-inf"))
+        return last_update + period <= time.time()
+
+    def mark_feed_as_updated(self, feed, extra_info: dict = {}):
+        """
+        sets the time we're done updating the feed in the db and increases
+        the number of loaded ti feeds
+        :param feed: name or link of the updated feed
+        :param extra_info: to store about the update of the given feed in
+        the db
+        e.g. last-modified, e-tag, hash etc
+        """
+        now = time.time()
+        # update the time we last checked this file for update
+        self.db.set_feed_last_update_time(feed, now)
+
+        extra_info.update({"time": now})
+        self.db.set_ti_feed_info(feed, extra_info)
+
+        self.loaded_ti_files += 1
+
+    def should_update(self, file_to_download: str, update_period) -> bool:
         """
         Decides whether to update or not based on the update period and
         e-tag. Used for remote files that are updated periodically
+        the response will be stored in self.responses if the file is old
+        and needs to be updated
 
         :param file_to_download: url that contains the file to download
+        :param update_period: after how many seconds do we need to update
+        this file?
""" - # the response will be stored in self.responses if the file is old and needs to be updated - # Get the last time this file was updated - ti_file_info: dict = self.db.get_ti_feed_info(file_to_download) - last_update = ti_file_info.get("time", float("-inf")) - if last_update + update_period > time.time(): + if not self.did_update_period_pass(update_period, file_to_download): # Update period hasn't passed yet, but the file is in our db self.loaded_ti_files += 1 return False @@ -381,12 +418,6 @@ def check_if_update(self, file_to_download: str, update_period) -> bool: if not response: return False - if "maclookup" in file_to_download: - # no need to check the e-tag - # we always need to download this file for slips to get info about MACs - self.responses["mac_db"] = response - return True - # Get the E-TAG of this file to compare with current files ti_file_info: dict = self.db.get_ti_feed_info(file_to_download) old_e_tag = ti_file_info.get("e-tag", "") @@ -400,7 +431,8 @@ def check_if_update(self, file_to_download: str, update_period) -> bool: if not new_last_modified: self.log( - f"Error updating {file_to_download}. Doesn't have an e-tag or Last-Modified field." + f"Error updating {file_to_download}." + f" Doesn't have an e-tag or Last-Modified field." ) return False @@ -409,11 +441,7 @@ def check_if_update(self, file_to_download: str, update_period) -> bool: self.responses[file_to_download] = response return True else: - # update the time we last checked this file for update - self.db.set_feed_last_update_time( - file_to_download, time.time() - ) - self.loaded_ti_files += 1 + self.mark_feed_as_updated(file_to_download) return False if old_e_tag != new_e_tag: @@ -424,13 +452,10 @@ def check_if_update(self, file_to_download: str, update_period) -> bool: else: # old_e_tag == new_e_tag - # update period passed but the file hasnt changed on the server, no need to update + # update period passed but the file hasnt changed on the + # server, no need to update # Store the update time like we downloaded it anyway - # Store the new etag and time of file in the database - self.db.set_feed_last_update_time( - file_to_download, time.time() - ) - self.loaded_ti_files += 1 + self.mark_feed_as_updated(file_to_download) return False except Exception: @@ -466,13 +491,16 @@ def parse_ssl_feed(self, url, full_path): while True: line = ssl_feed.readline() if line.startswith("# Listingdate"): - # looks like the line that contains column names, search where is the description column + # looks like the line that contains column names, + # search where is the description column for column in line.split(","): - # Listingreason is the description column in abuse.ch Suricata SSL Fingerprint Blacklist + # Listingreason is the description column in + # abuse.ch Suricata SSL Fingerprint Blacklist if "Listingreason" in column.lower(): description_column = line.split(",").index(column) if not line.startswith("#"): - # break while statement if it is not a comment (i.e. does not start with #) or a header line + # break while statement if it is not a comment (i.e. + # does not start with #) or a header line break # Find in which column is the ssl fingerprint in this file @@ -503,7 +531,8 @@ def parse_ssl_feed(self, url, full_path): if sha1_column is None: # can't find a column that contains an ioc self.print( - f"Error while reading the ssl file {full_path}. Could not find a column with sha1 info", + f"Error while reading the ssl file {full_path}. 
" + f"Could not find a column with sha1 info", 0, 1, ) @@ -544,7 +573,8 @@ def parse_ssl_feed(self, url, full_path): ) except IndexError: self.print( - f"IndexError Description column: {description_column}. Line: {line}" + f"IndexError Description column: " + f"{description_column}. Line: {line}" ) # self.print('\tRead Data {}: {}'.format(sha1, description)) @@ -565,14 +595,15 @@ def parse_ssl_feed(self, url, full_path): ) else: self.log( - f"The data {data} is not valid. It was found in {filename}." + f"The data {data} is not valid. It was found in " + f"{filename}." ) continue # Add all loaded malicious sha1 to the database self.db.add_ssl_sha1_to_ioc(malicious_ssl_certs) return True - async def update_TI_file(self, link_to_download: str) -> bool: + async def update_ti_file(self, link_to_download: str) -> bool: """ Update remote TI files, JA3 feeds and SSL feeds by writing them to disk and parsing them @@ -634,12 +665,10 @@ async def update_TI_file(self, link_to_download: str) -> bool: "time": time.time(), "Last-Modified": self.get_last_modified(response), } - self.db.set_ti_feed_info(link_to_download, file_info) - + self.mark_feed_as_updated(link_to_download, extra_info=file_info) self.log( - f"Successfully updated in DB the remote file {link_to_download}" + f"Successfully updated the remote file {link_to_download}" ) - self.loaded_ti_files += 1 # done parsing the file, delete it from disk try: @@ -696,20 +725,20 @@ def update_riskiq_feed(self): self.db.add_domains_to_ioc(malicious_domains_dict) except KeyError: self.print( - f'RiskIQ returned: {response["message"]}. Update Cancelled.', + f'RiskIQ returned: {response["message"]}. ' + f"Update Cancelled.", 0, 1, ) return False - # update the timestamp in the db - malicious_file_info = {"time": time.time()} - self.db.set_ti_feed_info("riskiq_domains", malicious_file_info) + self.mark_feed_as_updated("riskiq_domains") self.log("Successfully updated RiskIQ domains.") return True except Exception as e: self.log( - "An error occurred while updating RiskIQ domains. Updating was aborted." + "An error occurred while updating RiskIQ domains. " + "Updating was aborted." ) self.print("An error occurred while updating RiskIQ feed.", 0, 1) self.print(f"Error: {e}", 0, 1) @@ -731,15 +760,18 @@ def parse_ja3_feed(self, url, ja3_feed_path: str) -> bool: while True: line = ja3_feed.readline() if line.startswith("# ja3_md5"): - # looks like the line that contains column names, search where is the description column + # looks like the line that contains column names, + # search where is the description column for column in line.split(","): - # Listingreason is the description column in abuse.ch Suricata JA3 Fingerprint Blacklist + # Listingreason is the description column in + # abuse.ch Suricata JA3 Fingerprint Blacklist if "Listingreason" in column.lower(): description_column = line.split(",").index( column ) if not line.startswith("#"): - # break while statement if it is not a comment (i.e. does not startwith #) or a header line + # break while statement if it is not a comment + # (i.e. does not startwith #) or a header line break # Find in which column is the ja3 fingerprint in this file @@ -862,7 +894,8 @@ def parse_ja3_feed(self, url, ja3_feed_path: str) -> bool: def parse_json_ti_feed(self, link_to_download, ti_file_path: str) -> bool: """ - Slips has 2 json TI feeds that are parsed differently. hole.cert.pl and rstcloud + Slips has 2 json TI feeds that are parsed differently. 
hole.cert.pl + and rstcloud """ # to support https://hole.cert.pl/domains/domains.json tags = self.url_feeds[link_to_download]["tags"] @@ -874,7 +907,8 @@ def parse_json_ti_feed(self, link_to_download, ti_file_path: str) -> bool: malicious_ips_dict = {} with open(ti_file_path) as feed: self.print( - f"Reading next lines in the file {ti_file_path} for IoC", + f"Reading next lines in the file " + f"{ti_file_path} for IoC", 3, 0, ) @@ -902,7 +936,8 @@ def parse_json_ti_feed(self, link_to_download, ti_file_path: str) -> bool: malicious_domains_dict = {} with open(ti_file_path) as feed: self.print( - f"Reading next lines in the file {ti_file_path} for IoC", + f"Reading next lines in the file {ti_file_path}" + f" for IoC", 3, 0, ) @@ -937,7 +972,8 @@ def parse_json_ti_feed(self, link_to_download, ti_file_path: str) -> bool: def get_description_column_index(self, header): """ - Given the first line of a TI file (header line), try to get the index of the description column + Given the first line of a TI file (header line), try to get the index + of the description column """ description_keywords = ( "desc", @@ -1071,7 +1107,8 @@ def extract_ioc_from_line( def add_to_ip_ctr(self, ip, blacklist): """ keep track of how many times an ip was there in all blacklists - :param blacklist: t make sure we don't count the ip twice in the same blacklist + :param blacklist: t make sure we don't count the ip twice in the + same blacklist """ blacklist = os.path.basename(blacklist) if ip in self.ips_ctr and blacklist not in self.ips_ctr["blacklists"]: @@ -1085,7 +1122,8 @@ def is_valid_ti_file(self, ti_file_path: str) -> bool: try: filesize = os.path.getsize(ti_file_path) except FileNotFoundError: - # happens in integration tests, another instance of slips deleted the file + # happens in integration tests, another instance of slips + # deleted the file return False if filesize == 0: @@ -1441,7 +1479,7 @@ def update_org_files(self): info = { "hash": utils.get_sha256_hash(file), } - self.db.set_ti_feed_info(file, info) + self.mark_feed_as_updated(file, info) def update_ports_info(self): for file in os.listdir("slips_files/ports_info"): @@ -1459,8 +1497,10 @@ def update_ports_info(self): def print_duplicate_ip_summary(self): if not self.first_time_reading_files: - # when we parse ti files for the first time, we have the info to print the summary - # when the ti files are already updated, from a previous run, we don't + # when we parse ti files for the first time, we have the info to + # print the summary + # when the ti files are already updated, from a previous run, + # we don't return ips_in_1_bl = 0 @@ -1486,14 +1526,13 @@ def print_duplicate_ip_summary(self): def update_mac_db(self): """ - Updates the mac db using the response stored in self.response + Updates the mac db using the response stored in self.responses """ response = self.responses["mac_db"] if response.status_code != 200: return False self.log("Updating the MAC database.") - path_to_mac_db = "databases/macaddress-db.json" # write to file the info as 1 json per line mac_info = ( @@ -1501,15 +1540,16 @@ def update_mac_db(self): .replace("[", "") .replace(",{", "\n{") ) - with open(path_to_mac_db, "w") as mac_db: + with open(self.path_to_mac_db, "w") as mac_db: mac_db.write(mac_info) - self.db.set_ti_feed_info(self.mac_db_link, {"time": time.time()}) + self.mark_feed_as_updated(self.mac_db_link) return True def update_online_whitelist(self): """ - Updates online tranco whitelist defined in slips.conf online_whitelist key + Updates online tranco 
whitelist defined in slips.yaml + online_whitelist key """ response = self.responses["tranco_whitelist"] # write to the file so we don't store the 10k domains in memory @@ -1526,11 +1566,47 @@ def update_online_whitelist(self): self.db.store_tranco_whitelisted_domain(domain) os.remove(online_whitelist_download_path) + self.mark_feed_as_updated("tranco_whitelist") + + def download_mac_db(self): + """ + saves the mac db response to self.responses + """ + response = self.download_file(self.mac_db_link) + if not response: + return False + + self.responses["mac_db"] = response + return True + + def should_update_mac_db(self) -> bool: + """ + checks whether or not slips should download the mac db based on + its availability on disk and the update period + + the response will be stored in self.responses if the file is old + and needs to be updated + """ + if not self.is_mac_db_file_on_disk(): + # whether the period passed or not, the db needs to be + # re-downloaded + return self.download_mac_db() + + if not self.did_update_period_pass( + self.mac_db_update_period, self.mac_db_link + ): + # Update period hasn't passed yet, the file is on disk and + # up to date + self.loaded_ti_files += 1 + return False + + return self.download_mac_db() async def update(self) -> bool: """ Main function. It tries to update the TI files from a remote server - we update different types of files remote TI files, remote JA3 feeds, RiskIQ domains and local slips files + we update different types of files remote TI files, remote JA3 feeds, + RiskIQ domains and local slips files """ if self.update_period <= 0: # User does not want to update the malicious IP list. @@ -1545,12 +1621,10 @@ async def update(self) -> bool: try: self.log("Checking if we need to download TI files.") - if self.check_if_update( - self.mac_db_link, self.mac_db_update_period - ): + if self.should_update_mac_db(): self.update_mac_db() - if self.check_if_update_online_whitelist(): + if self.should_update_online_whitelist(): self.update_online_whitelist() ############### Update remote TI files ################ @@ -1562,7 +1636,7 @@ async def update(self) -> bool: files_to_download.update(self.ssl_feeds) for file_to_download in files_to_download: - if self.check_if_update(file_to_download, self.update_period): + if self.should_update(file_to_download, self.update_period): # failed to get the response, either a server problem # or the file is up to date so the response isn't needed # either way __check_if_update handles the error printing @@ -1570,18 +1644,18 @@ async def update(self) -> bool: # this run wasn't started with existing ti files in the db self.first_time_reading_files = True - # every function call to update_TI_file is now running concurrently instead of serially - # so when a server's taking a while to give us the TI feed, we proceed - # to download the next file instead of being idle + # every function call to update_TI_file is now running + # concurrently instead of serially + # so when a server's taking a while to give us the TI + # feed, we proceed to download the next file instead of + # being idle task = asyncio.create_task( - self.update_TI_file(file_to_download) + self.update_ti_file(file_to_download) ) ####################################################### # in case of riskiq files, we don't have a link for them in ti_files, We update these files using their API # check if we have a username and api key and a week has passed since we last updated - if self.check_if_update( - "riskiq_domains", self.riskiq_update_period - ): + if 
self.should_update("riskiq_domains", self.riskiq_update_period): self.update_riskiq_feed() # wait for all TI files to update diff --git a/slips_files/common/abstracts/async_module.py b/slips_files/common/abstracts/async_module.py index 1487b4f54..72fa81395 100644 --- a/slips_files/common/abstracts/async_module.py +++ b/slips_files/common/abstracts/async_module.py @@ -24,7 +24,8 @@ async def shutdown_gracefully(self): async def run_main(self): return await self.main() - def run_async_function(self, func: Callable): + @staticmethod + def run_async_function(func: Callable): loop = asyncio.get_event_loop() return loop.run_until_complete(func()) @@ -44,7 +45,6 @@ def run(self): self.print_traceback() return - keyboard_int_ctr = 0 while True: try: if self.should_stop(): @@ -59,10 +59,10 @@ def run(self): return except KeyboardInterrupt: - keyboard_int_ctr += 1 - if keyboard_int_ctr >= 2: + self.keyboard_int_ctr += 1 + if self.keyboard_int_ctr >= 2: # on the second ctrl+c Slips immediately stop - return + return True # on the first ctrl + C keep looping until the should_stop() # returns true continue diff --git a/slips_files/common/abstracts/module.py b/slips_files/common/abstracts/module.py index e561232d2..25eb5ca83 100644 --- a/slips_files/common/abstracts/module.py +++ b/slips_files/common/abstracts/module.py @@ -1,4 +1,3 @@ -import asyncio import sys import traceback import warnings @@ -7,7 +6,6 @@ from typing import ( Dict, Optional, - Callable, ) from slips_files.common.printer import Printer from slips_files.core.output import Output @@ -180,68 +178,3 @@ def run(self): except Exception: self.print_traceback() return - - -class AsyncModule(IModule, ABC, Process): - """ - An abstract class for asynchronous slips modules - """ - - name = "abstract class" - - def __init__(self, *args, **kwargs): - IModule.__init__(self, *args, **kwargs) - - def init(self, **kwargs): ... - - async def main(self): ... 
-
-    async def shutdown_gracefully(self):
-        """Implement the async shutdown logic here"""
-        pass
-
-    async def run_main(self):
-        return await self.main()
-
-    @staticmethod
-    def run_async_function(func: Callable):
-        loop = asyncio.get_event_loop()
-        return loop.run_until_complete(func())
-
-    def run(self):
-        try:
-            error: bool = self.pre_main()
-            if error or self.should_stop():
-                self.run_async_function(self.shutdown_gracefully)
-                return
-        except KeyboardInterrupt:
-            self.run_async_function(self.shutdown_gracefully)
-            return
-        except Exception:
-            self.print_traceback()
-            return
-
-        while True:
-            try:
-                if self.should_stop():
-                    self.run_async_function(self.shutdown_gracefully)
-                    return
-
-                # if a module's main() returns 1, it means there's an
-                # error and it needs to stop immediately
-                error: bool = self.run_async_function(self.run_main)
-                if error:
-                    self.run_async_function(self.shutdown_gracefully)
-                    return
-
-            except KeyboardInterrupt:
-                self.keyboard_int_ctr += 1
-                if self.keyboard_int_ctr >= 2:
-                    # on the second ctrl+c Slips immediately stop
-                    return True
-                # on the first ctrl + C keep looping until the should_stop()
-                # returns true
-                continue
-            except Exception:
-                self.print_traceback()
-                return
diff --git a/tests/test_update_file_manager.py b/tests/test_update_file_manager.py
index 139383f1f..69a9bb974 100644
--- a/tests/test_update_file_manager.py
+++ b/tests/test_update_file_manager.py
@@ -13,7 +13,7 @@ def test_check_if_update_based_on_update_period():
     update_manager.db.get_ti_feed_info.return_value = {"time": float("inf")}
     url = "abc.com/x"
     # update period hasn't passed
-    assert update_manager.check_if_update(url, float("inf")) is False
+    assert update_manager.should_update(url, float("inf")) is False
 
 
 def test_check_if_update_based_on_e_tag(mocker):
@@ -28,7 +28,7 @@ def test_check_if_update_based_on_e_tag(mocker):
     mock_requests.return_value.status_code = 200
     mock_requests.return_value.headers = {"ETag": "1234"}
     mock_requests.return_value.text = ""
-    assert update_manager.check_if_update(url, float("-inf")) is False
+    assert update_manager.should_update(url, float("-inf")) is False
 
     # period passed, etag different
     etag = "1111"
@@ -38,7 +38,7 @@ def test_check_if_update_based_on_e_tag(mocker):
     mock_requests.return_value.status_code = 200
    mock_requests.return_value.headers = {"ETag": "2222"}
     mock_requests.return_value.text = ""
-    assert update_manager.check_if_update(url, float("-inf")) is True
+    assert update_manager.should_update(url, float("-inf")) is True
 
 
 def test_check_if_update_based_on_last_modified(
@@ -56,7 +56,7 @@ def test_check_if_update_based_on_last_modified(
     mock_requests.return_value.headers = {"Last-Modified": 10.0}
     mock_requests.return_value.text = ""
 
-    assert update_manager.check_if_update(url, float("-inf")) is False
+    assert update_manager.should_update(url, float("-inf")) is False
 
     # period passed, no etag, last modified changed
     url = "google.com/photos"
@@ -67,7 +67,7 @@ def test_check_if_update_based_on_last_modified(
     mock_requests.return_value.headers = {"Last-Modified": 11}
     mock_requests.return_value.text = ""
 
-    assert update_manager.check_if_update(url, float("-inf")) is True
+    assert update_manager.should_update(url, float("-inf")) is True
 
 
 @pytest.mark.parametrize(
@@ -250,28 +250,29 @@ def test_update_local_file(
     update_manager = ModuleFactory().create_update_manager_obj()
     update_manager.new_hash = "test_hash"
     mocker.patch("builtins.open", mock_open(read_data=test_data))
-    result = update_manager.update_local_file(str(tmp_path / file_name))
+    now = 1678887000.0
+    with patch("time.time", return_value=now):
+        result = update_manager.update_local_file(str(tmp_path / file_name))
 
     update_manager.db.set_ti_feed_info.assert_called_once_with(
-        str(tmp_path / file_name), {"hash": "test_hash"}
+        str(tmp_path / file_name), {"hash": "test_hash", "time": now}
     )
     assert result is True
 
 
-def test_check_if_update_online_whitelist_download_updated(
-    mocker,
-):
+def test_check_if_update_online_whitelist_download_updated():
     """Update period passed, download succeeds."""
     update_manager = ModuleFactory().create_update_manager_obj()
+    update_manager.download_file = Mock()
     update_manager.db.get_ti_feed_info.return_value = {"time": 0}
     update_manager.online_whitelist = "https://example.com/whitelist.txt"
     update_manager.download_file = Mock(return_value=Mock(status_code=200))
 
-    result = update_manager.check_if_update_online_whitelist()
+    result = update_manager.should_update_online_whitelist()
 
     assert result is True
-    update_manager.db.set_ti_feed_info.assert_called_once_with(
-        "tranco_whitelist", {"time": mocker.ANY}
+    update_manager.download_file.assert_called_once_with(
+        update_manager.online_whitelist
     )
     assert "tranco_whitelist" in update_manager.responses
 
@@ -281,7 +282,7 @@ def test_check_if_update_online_whitelist_not_updated():
     update_manager = ModuleFactory().create_update_manager_obj()
     update_manager.online_whitelist = "https://example.com/whitelist.txt"
     update_manager.db.get_ti_feed_info.return_value = {"time": time.time()}
-    result = update_manager.check_if_update_online_whitelist()
+    result = update_manager.should_update_online_whitelist()
 
     assert result is False
     update_manager.db.set_ti_feed_info.assert_not_called()
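
For context on the helper pair this diff introduces: did_update_period_pass() and mark_feed_as_updated() replace the open-coded get_ti_feed_info / set_ti_feed_info timestamp bookkeeping that was previously repeated in every update path, reading and writing the same "time" key. Below is a minimal, self-contained sketch of that contract; the in-memory FeedInfoStore is a hypothetical stand-in for Slips' Redis-backed db (not the real API), and the sketch deliberately copies extra_info into a fresh dict where the patch uses a shared `extra_info: dict = {}` default.

import time


class FeedInfoStore:
    """Hypothetical in-memory stand-in for Slips' Redis-backed feed-info db."""

    def __init__(self):
        self._info = {}

    def get_ti_feed_info(self, feed: str) -> dict:
        return self._info.get(feed, {})

    def set_ti_feed_info(self, feed: str, info: dict):
        self._info.setdefault(feed, {}).update(info)


def did_update_period_pass(db: FeedInfoStore, period: float, feed: str) -> bool:
    # a feed with no recorded "time" defaults to -inf, i.e. always stale
    last_update = db.get_ti_feed_info(feed).get("time", float("-inf"))
    return last_update + period <= time.time()


def mark_feed_as_updated(db: FeedInfoStore, feed: str, extra_info: dict = None):
    # copy into a fresh dict so callers never share mutable state
    info = dict(extra_info or {})
    info["time"] = time.time()
    db.set_ti_feed_info(feed, info)


db = FeedInfoStore()
# never updated -> stale, no matter how long the period is
assert did_update_period_pass(db, period=3600, feed="tranco_whitelist")
mark_feed_as_updated(db, "tranco_whitelist", {"e-tag": "abc123"})
# just updated -> fresh until the period elapses again
assert not did_update_period_pass(db, period=3600, feed="tranco_whitelist")

Keeping the staleness check and the timestamp write in one read/write pair is what lets should_update(), should_update_online_whitelist(), and should_update_mac_db() share the same freshness logic instead of each reimplementing it.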