diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 95f104fa4..981448d82 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -41,7 +41,7 @@ repos: rev: v1.35.1 hooks: - id: yamllint - args: ["-d", "{rules: {line-length: {max: 100}}}"] + args: ["-d", "{rules: {line-length: {max: 160}}}"] files: "slips.yaml" - repo: local diff --git a/conftest.py b/conftest.py index b0178d152..d4aeb1c03 100644 --- a/conftest.py +++ b/conftest.py @@ -71,23 +71,24 @@ def profiler_queue(): def flow(): """returns a dummy flow for testing""" return Conn( - "1601998398.945854", - "1234", - "192.168.1.1", - "8.8.8.8", - 5, - "TCP", - "dhcp", - 80, - 88, - 20, - 20, - 20, - 20, - "", - "", - "Established", - "", + starttime="1601998398.945854", + uid="1234", + saddr="192.168.1.1", + daddr="8.8.8.8", + dur=5, + proto="TCP", + appproto="dhcp", + sport=80, + dport=88, + spkts=20, + dpkts=20, + sbytes=20, + dbytes=20, + state="Established", + history="", + smac="", + dmac="", + interface="eth0", ) diff --git a/docs/immune/installing_slips_in_the_rpi.md b/docs/immune/installing_slips_in_the_rpi.md index e256ec30c..d928f8c64 100644 --- a/docs/immune/installing_slips_in_the_rpi.md +++ b/docs/immune/installing_slips_in_the_rpi.md @@ -33,11 +33,12 @@ Meaning it wil kick out the malicious device from the AP. 1. Connect your RPI to your router using an ethernet cable -2. Run your RPI as an access point using [create_ap](https://github.com/oblique/create_ap) +2. Install [linux-wifi-hotspot](https://github.com/lakinduakash/linux-wifi-hotspot/blob/master/src/scripts/README.md) +3. Start the access point (in NAT mode) -`sudo create_ap wlan0 eth0 rpi_wifi mysecurepassword -c 40` + `sudo create_ap wlan0 eth0 rpi_wifi mysecurepassword -c 40` -where `wlan0` is the wifi interface of your RPI, `eth0` is the ethernet interface and `-c 40` is the channel of the access point. 
+where `wlan0` is the wifi interface of your RPI, `eth0` is the ethernet interface and `-c 40` is the channel of the access point. We chose channel 40 because it is a 5GHz channel, which is faster and less crowded than the 2.4GHz channels. @@ -49,13 +50,13 @@ If all goes well you should see `wlan0: AP-ENABLED` in the output of the command Check the [Debugging common AP errors](#debugging-common-ap-errors) section if you have any issues. -3. Run Slips in the RPI using the command below to listen to the traffic from the access point. +4. Run Slips in the RPI using the command below to listen to the traffic from the access point. ```bash ./slips.py -i wlan0 ``` -4. (Optional) If you want to block malicious devices, run Slips with the `-p` parameter. Using this parameter will +5. (Optional) If you want to block malicious devices, run Slips with the `-p` parameter. Using this parameter will block all traffic to and from the malicious device when slips sets an alert. ```bash diff --git a/docs/installation.md b/docs/installation.md index e90681a94..c6cc98925 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -343,9 +343,10 @@ Meaning it wil kick out the malicious device from the AP. 1. Connect your RPI to your router using an ethernet cable -2. Run your RPI as an access point using [create_ap](https://github.com/oblique/create_ap) +2. Install [linux-wifi-hotspot](https://github.com/lakinduakash/linux-wifi-hotspot/blob/master/src/scripts/README.md) +3. Start the access point (in NAT mode) -`sudo create_ap wlan0 eth0 rpi_wifi mysecurepassword -c 40` + `sudo create_ap wlan0 eth0 rpi_wifi mysecurepassword -c 40` where `wlan0` is the wifi interface of your RPI, `eth0` is the ethernet interface and `-c 40` is the channel of the access point. @@ -360,13 +361,13 @@ Check the [Debugging common AP errors](https://stratospherelinuxips.readthedocs. -3. Run Slips in the RPI using the command below to listen to the traffic from the access point. +4. 
Run Slips in the RPI using the command below to listen to the traffic from the access point. ```bash ./slips.py -i wlan0 ``` -4. (Optional) If you want to block malicious devices, run Slips with the `-p` parameter. Using this parameter will +5. (Optional) If you want to block malicious devices, run Slips with the `-p` parameter. Using this parameter will block all traffic to and from the malicious device when slips sets an alert. ```bash diff --git a/install/apt_dependencies.txt b/install/apt_dependencies.txt index a3b8969ed..bcf5363bc 100644 --- a/install/apt_dependencies.txt +++ b/install/apt_dependencies.txt @@ -28,3 +28,4 @@ ca-certificates redis wget npm +iw diff --git a/managers/ap_manager.py b/managers/ap_manager.py new file mode 100644 index 000000000..4b7597796 --- /dev/null +++ b/managers/ap_manager.py @@ -0,0 +1,39 @@ +import subprocess + + +class APManager: + """ + Gets AP info when slips is running as an AP in the RPI + https://stratospherelinuxips.readthedocs.io/en/develop/immune/installing_slips_in_the_rpi.html#protect-your-local-network-with-slips-on-the-rpi + """ + + def __init__(self, main): + self.main = main + + def store_ap_interfaces(self, input_information): + """ + stores the interfaces given with -ap to slips in the db + """ + self.wifi_interface, self.eth_interface = input_information.split(",") + interfaces = { + "wifi_interface": self.wifi_interface, + "ethernet_interface": self.eth_interface, + } + self.main.db.set_ap_info(interfaces) + + def is_ap_running(self): + """returns true if a running AP is detected""" + command = ["iw", "dev"] + try: + result = subprocess.run( + command, capture_output=True, text=True, check=True + ) + lines = result.stdout.splitlines() + for line in lines: + if "type AP" in line: + return True + return False + except subprocess.CalledProcessError: + return False + except FileNotFoundError: + return False diff --git a/managers/host_ip_manager.py b/managers/host_ip_manager.py index 6e05bb8fe..20af9f8b1 100644 
--- a/managers/host_ip_manager.py +++ b/managers/host_ip_manager.py @@ -4,35 +4,63 @@ import netifaces from typing import ( Set, - Optional, List, + Dict, ) +from slips_files.common.slips_utils import utils from slips_files.common.style import green class HostIPManager: def __init__(self, main): self.main = main + self.info_printed = False - def get_host_ip(self) -> Optional[str]: + def _get_default_host_ip(self, interface) -> str | None: + """ + Return the host IP of the default interface (IPv4). + usefull when slips is running using -g and the user didn't supply + an interface, so we need to infer it + """ + try: + # Get the default gateway info (usually includes interface name) + addrs = netifaces.ifaddresses(interface) + # AF_INET is for IPv4 addresses + inet_info = addrs.get(netifaces.AF_INET) + if not inet_info: + return None + + return inet_info[0]["addr"] + except Exception as e: + print(f"Error getting host IP: {e}") + return None + + def _get_host_ips(self) -> Dict[str, str]: """ tries to determine the machine's IP. - uses the intrfaces provided by the user if -i is given, or all - interfaces if not. 
+ uses the intrfaces provided by the user with -i or -ap + returns a dict with {interface_name: host_ip, ..} """ - if not (self.main.args.interface or self.main.args.growing): - # slips is running on a file, we cant determine the host IP - return + if self.main.args.growing: + # -g is used, user didn't supply the interface + # try to get the default interface + interface = utils.infer_used_interface() + if not interface: + return {} + + if default_host_ip := self._get_default_host_ip(interface): + return {interface: default_host_ip} + return {} # we use all interfaces when -g is used, otherwise we use the given # interface interfaces: List[str] = ( [self.main.args.interface] if self.main.args.interface - else netifaces.interfaces() + else self.main.args.access_point.split(",") ) - + found_ips = {} for iface in interfaces: addrs = netifaces.ifaddresses(iface) # check for IPv4 address @@ -41,9 +69,10 @@ def get_host_ip(self) -> Optional[str]: for addr in addrs[netifaces.AF_INET]: ip = addr.get("addr") if ip and not ip.startswith("127."): - return ip + found_ips[iface] = ip + return found_ips - def store_host_ip(self) -> Optional[str]: + def store_host_ip(self) -> Dict[str, str] | None: """ stores the host ip in the db recursively retries to get the host IP online every 10s if not @@ -52,33 +81,43 @@ def store_host_ip(self) -> Optional[str]: if not self.main.db.is_running_non_stop(): return - if host_ip := self.get_host_ip(): - self.main.db.set_host_ip(host_ip) - self.main.print(f"Detected host IP: {green(host_ip)}") - return host_ip + if host_ips := self._get_host_ips(): + for iface, ip in host_ips.items(): + self.main.db.set_host_ip(ip, iface) + if not self.info_printed: + self.main.print( + f"Detected host IP: {green(ip)} for {green(iface)}" + ) + self.info_printed = True + + return host_ips self.main.print("Not Connected to the internet. 
Reconnecting in 10s.") time.sleep(10) self.store_host_ip() def update_host_ip( - self, host_ip: str, modified_profiles: Set[str] - ) -> Optional[str]: + self, host_ips: Dict[str, str], modified_profiles: Set[str] + ) -> Dict[str, str]: """ Is called every 5s for slips to update the host ip when running on an interface we keep track of the host IP. If there was no modified TWs in the host IP, we check if the network was changed. :param modified_profiles: modified profiles since slips start time + :param host_ips: a dict with {interface: host_ip,..} for each + interface slips is monitoring """ if not self.main.db.is_running_non_stop(): return - if host_ip in modified_profiles: - return host_ip - - if latest_host_ip := self.get_host_ip(): - self.main.db.set_host_ip(latest_host_ip) - return latest_host_ip + if host_ips: + res = {} + for iface, ip in host_ips.items(): + if ip in modified_profiles: + res[iface] = ip + if res: + return res - return latest_host_ip + # there was no modified TWs in the host IPs, check if network changed + return self.store_host_ip() diff --git a/managers/process_manager.py b/managers/process_manager.py index f1c77363e..169322e62 100644 --- a/managers/process_manager.py +++ b/managers/process_manager.py @@ -629,7 +629,7 @@ def should_run_non_stop(self) -> bool: return ( self.is_debugger_active() or self.main.input_type in ("stdin", "cyst") - or self.main.is_interface + or self.main.db.is_running_non_stop() ) def shutdown_interactive( diff --git a/modules/arp/arp.py b/modules/arp/arp.py index da556329e..b940c372c 100644 --- a/modules/arp/arp.py +++ b/modules/arp/arp.py @@ -166,7 +166,7 @@ def get_uids(): # node to announce or update its IP to MAC mapping # to the entire network. It shouldn't be marked as an arp scan # Don't detect arp scan from the GW router - if self.db.get_gateway_ip() == flow.saddr: + if self.db.get_gateway_ip(flow.interface) == flow.saddr: return False # What is this? 
@@ -417,8 +417,8 @@ def detect_mitm_arp_attack(self, twid: str, flow): attackers_ip = flow.saddr victims_ip = original_ip - gateway_ip = self.db.get_gateway_ip() - gateway_mac = self.db.get_gateway_mac() + gateway_ip = self.db.get_gateway_ip(flow.interface) + gateway_mac = self.db.get_gateway_mac(flow.interface) if flow.saddr == gateway_ip: saddr = f"The gateway {flow.saddr}" else: diff --git a/modules/arp/filter.py b/modules/arp/filter.py index 5e67e7912..fa3a8449f 100644 --- a/modules/arp/filter.py +++ b/modules/arp/filter.py @@ -28,7 +28,7 @@ def should_discard_evidence(self, ip: str) -> bool: def is_self_defense(self, ip: str): """ - slips uses arp poison to defend itself and th enetwork, + slips uses arp poison to defend itself and the network, check arp_poison.py for more details. goal of this function is to discard evidence about slips doing arp attacks when it's just attacking attackers diff --git a/modules/arp_poisoner/arp_poisoner.py b/modules/arp_poisoner/arp_poisoner.py index 3f313a1b6..969c6f8dc 100644 --- a/modules/arp_poisoner/arp_poisoner.py +++ b/modules/arp_poisoner/arp_poisoner.py @@ -7,7 +7,7 @@ from threading import Lock import json import ipaddress -from typing import Set, Tuple +from typing import Set, Tuple, Dict from scapy.all import ARP, Ether from scapy.sendrecv import sendp, srp import random @@ -53,6 +53,12 @@ def init(self): self._scan_delay = 30 self._last_scan_time = 0 self.last_arp_scan_output = set() + self.ap_info: None | Dict[str, str] = self.db.get_ap_info() + self.is_running_in_ap_mode = True if self.ap_info else False + # contains interfaces as keys and and their GW ip as values + self.gw_ip: Dict[str, str] = {} + # keeps track of which interface were blocked ips attacking on + self.ip_interface_map = {} def log(self, text): """Logs the given text to the blocking log file""" @@ -165,8 +171,16 @@ def _arp_scan(self, interface) -> Set[Tuple[str, str]]: # use the cached output if it's not time to rescan return 
self.last_arp_scan_output - # --retry=0 to avoid redundant retries. - cmd = ["arp-scan", f"--interface={interface}", "--localnet"] + # we are explicitly giving arp-scan the ip to avoid giving docker + # RAW_SOCKET permissions for arp-scan to be able to auto detect the ip + host_ip = self.db.get_host_ip(interface) + cmd = [ + "arp-scan", + f"--interface={interface}", + "--localnet", + f"--arpspa={host_ip}", + ] + try: output = subprocess.check_output(cmd, text=True) except subprocess.CalledProcessError as e: @@ -205,7 +219,9 @@ def _get_mac_using_arp(self, ip) -> str | None: return result[0][1].hwsrc return None - def _isolate_target_from_localnet(self, target_ip: str, fake_mac: str): + def _isolate_target_from_localnet( + self, target_ip: str, fake_mac: str, interface: str + ): """ Tells all the available hosts in the localnet that the target_ip is at fake_mac using unsolicited arp replies. @@ -224,11 +240,10 @@ def _isolate_target_from_localnet(self, target_ip: str, fake_mac: str): # found, FW blocking module handles blocking it through the fw, # plus we need our cache unpoisoned to be able to get the mac of # attackers to poison/reposion them. - all_hosts: Set[Tuple[str, str]] = self._arp_scan(self.args.interface) + all_hosts: Set[Tuple[str, str]] = self._arp_scan(interface) for ip, mac in all_hosts: if ip == target_ip: continue - pkt = Ether(dst=mac) / ARP( op=2, pdst=ip, # which dst ip are we sending this pkt to? 
@@ -239,18 +254,34 @@ def _isolate_target_from_localnet(self, target_ip: str, fake_mac: str): ) sendp(pkt, verbose=0) + def _get_gateway_ip(self, interface: str) -> str | None: + """gets the GW ip using cache, using the DB, or using netifaces""" + if interface in self.gw_ip: + return self.gw_ip[interface] + + gateway_ip: str = self.db.get_gateway_ip( + interface + ) or utils.get_gateway_for_iface(interface) + # cache it for later + self.gw_ip[interface] = gateway_ip + return gateway_ip + def _cut_targets_internet( - self, target_ip: str, target_mac: str, fake_mac: str + self, target_ip: str, target_mac: str, fake_mac: str, interface: str ): """ Cuts the target's internet by telling the target_ip that the gateway is at fake_mac using unsolicited arp reply AND telling the gw that the target is at a fake mac. """ - # in ap mode, this gw ip is the same as our own ip - gateway_ip: str = self.db.get_gateway_ip() + gateway_ip: str = self._get_gateway_ip(interface) - # we use replies, not requests, because we wanna anser ARP requests + if not gateway_ip: + self.print( + f"Unable to cut the internet of attacker at" + f" {target_ip}. Gateway IP is not found." + ) + # we use replies, not requests, because we wanna answer ARP requests # sent to the network instead of waiting for the attacker to answer # them. @@ -264,12 +295,13 @@ def _cut_targets_internet( pdst=target_ip, hwdst=target_mac, ) - sendp(pkt, iface=self.args.interface, verbose=0) + sendp(pkt, iface=interface, verbose=0) # poison the gw, tell it the victim is at a fake mac so traffic # from it wont reach the victim # attacker -> gw: im at a fake mac. 
- gateway_mac = self.db.get_gateway_mac() + gateway_mac = self.db.get_gateway_mac(interface) + pkt = Ether(dst=gateway_mac) / ARP( op=2, psrc=target_ip, @@ -277,7 +309,16 @@ def _cut_targets_internet( pdst=gateway_ip, hwdst=gateway_mac, ) - sendp(pkt, iface=self.args.interface, verbose=0) + sendp(pkt, iface=interface, verbose=0) + + def _get_interface_of_ip(self, ip) -> str | None: + if ip in self.ip_interface_map: + return self.ip_interface_map[ip] + + interface = utils.get_interface_of_ip(ip, self.db, self.args) + if interface: + self.ip_interface_map[ip] = interface + return interface def _attack(self, target_ip: str, first_time=False): """ @@ -300,8 +341,16 @@ def _attack(self, target_ip: str, first_time=False): if not target_mac: return - self._cut_targets_internet(target_ip, target_mac, fake_mac) - self._isolate_target_from_localnet(target_ip, fake_mac) + interface: str | None = self._get_interface_of_ip(target_ip) + if not interface: + self.print( + f"Can't get the interface of {target_ip}. " + f"Poisoning cancelled." + ) + return + + self._cut_targets_internet(target_ip, target_mac, fake_mac, interface) + self._isolate_target_from_localnet(target_ip, fake_mac, interface) # we repoison every 10s, we dont wanna log every 10s. 
if first_time: @@ -315,21 +364,21 @@ def is_broadcast(self, ip_str, net_str) -> bool: except ValueError: return False - def can_poison_ip(self, ip) -> bool: + def can_poison_ip(self, ip, interface: str) -> bool: """ Checks if the ip is in out localnet, isnt the router """ if utils.is_public_ip(ip): return False - localnet = self.db.get_local_network() + localnet = self.db.get_local_network(interface) if ipaddress.ip_address(ip) not in ipaddress.ip_network(localnet): return False if self.is_broadcast(ip, localnet): return False - if ip == self.db.get_gateway_ip(): + if ip == self._get_gateway_ip(interface): return False # no need to check if the ip is in our ips because all our ips are @@ -343,10 +392,12 @@ def main(self): data = json.loads(msg["data"]) ip = data.get("ip") tw: int = data.get("tw") + interface: str = data.get("interface") - if not self.can_poison_ip(ip): + if not self.can_poison_ip(ip, interface): return + self.ip_interface_map[ip] = interface self._attack(ip, first_time=True) # whether this ip is blocked now, or was already blocked, make an diff --git a/modules/blocking/blocking.py b/modules/blocking/blocking.py index 35aaa3b70..4bcefa79b 100644 --- a/modules/blocking/blocking.py +++ b/modules/blocking/blocking.py @@ -16,6 +16,9 @@ from modules.blocking.unblocker import Unblocker +OUTPUT_TO_DEV_NULL = ">/dev/null 2>&1" + + class Blocking(IModule): """Data should be passed to this module as a json encoded python dict, by default this module flushes all slipsBlocking chains before it starts""" @@ -46,6 +49,10 @@ def init(self): open(self.blocking_log_path, "w").close() except FileNotFoundError: pass + self.last_closed_tw = None + + self.ap_info: None | Dict[str, str] = self.db.get_ap_info() + self.is_running_in_ap_mode = True if self.ap_info else False def log(self, text: str): """Logs the given text to the blocking log file""" @@ -89,7 +96,9 @@ def _init_chains_in_firewall(self): # self.delete_iptables_chain() self.print('Executing "sudo iptables -N 
slipsBlocking"', 6, 0) # Add a new chain to iptables - os.system(f"{self.sudo} iptables -N slipsBlocking >/dev/null 2>&1") + os.system( + f"{self.sudo} iptables -N slipsBlocking {OUTPUT_TO_DEV_NULL}" + ) # Check if we're already redirecting to slipsBlocking chain input_chain_rules = self._get_cmd_output( @@ -107,18 +116,18 @@ def _init_chains_in_firewall(self): # FORWARD chains if "slipsBlocking" not in input_chain_rules: os.system( - self.sudo - + " iptables -I INPUT -j slipsBlocking >/dev/null 2>&1" + f"{self.sudo} iptables -I INPUT -j slipsBlocking " + f"{OUTPUT_TO_DEV_NULL}" ) if "slipsBlocking" not in output_chain_rules: os.system( - self.sudo - + " iptables -I OUTPUT -j slipsBlocking >/dev/null 2>&1" + f"{self.sudo} iptables -I OUTPUT -j slipsBlocking " + f"{OUTPUT_TO_DEV_NULL}" ) if "slipsBlocking" not in forward_chain_rules: os.system( - self.sudo - + " iptables -I FORWARD -j slipsBlocking >/dev/null 2>&1" + f"{self.sudo} iptables -I FORWARD -j slipsBlocking" + f" {OUTPUT_TO_DEV_NULL}" ) def _is_ip_already_blocked(self, ip) -> bool: @@ -134,7 +143,9 @@ def _block_ip(self, ip_to_block: str, flags: Dict[str, str]) -> bool: This function determines the user's platform and firewall and calls the appropriate function to add the rules to the used firewall. By default this function blocks all traffic from and to the given ip. 
- return strue if the ip is successfully blocked + and it Blocks private IPs on the given interface, and block public + IPs on all interfaces + returns true if the ip is successfully blocked """ if self.firewall != "iptables": @@ -152,15 +163,25 @@ def _block_ip(self, ip_to_block: str, flags: Dict[str, str]) -> bool: dport = flags.get("dport") sport = flags.get("sport") protocol = flags.get("protocol") + interface = flags.get("interface") # Set the default behaviour to block all traffic from and to an ip if from_ is None and to is None: from_, to = True, True # This dictionary will be used to construct the rule options = { "protocol": f" -p {protocol}" if protocol is not None else "", - "dport": f" --dport {str(dport)}" if dport is not None else "", - "sport": f" --sport {str(sport)}" if sport is not None else "", + "dport": f" --dport {dport}" if dport is not None else "", + "sport": f" --sport {sport}" if sport is not None else "", } + + if utils.is_private_ip(ip_to_block) and interface: + # block all ingoing AND outgoing packet on the given interface + options.update( + { + "interface": f" -i {interface} -o {interface}", + } + ) + blocked = False if from_: # Add rule to block traffic from source ip_to_block (-s) @@ -234,6 +255,7 @@ def main(self): "dport": data.get("dport"), "sport": data.get("sport"), "protocol": data.get("protocol"), + "interface": data.get("interface"), } if block: self._block_ip(ip, flags) diff --git a/modules/blocking/exec_iptables_cmd.py b/modules/blocking/exec_iptables_cmd.py index 704c34a18..ae0b859fa 100644 --- a/modules/blocking/exec_iptables_cmd.py +++ b/modules/blocking/exec_iptables_cmd.py @@ -22,7 +22,7 @@ def exec_iptables_command(sudo: str, action, ip_to_block, flag, options): for key in options.keys(): command += options[key] command += " -j DROP" - # Execute + exit_status = os.system(command) # 0 is the success value diff --git a/modules/flowalerts/conn.py b/modules/flowalerts/conn.py index 1519601c0..3edca73df 100644 --- 
a/modules/flowalerts/conn.py +++ b/modules/flowalerts/conn.py @@ -24,8 +24,6 @@ class Conn(IFlowalertsAnalyzer): def init(self): - # get the default gateway - self.gateway = self.db.get_gateway_ip() self.p2p_daddrs = {} # If 1 flow uploaded this amount of MBs or more, # slips will alert data upload @@ -289,14 +287,15 @@ def check_multiple_reconnection_attempts(self, profileid, twid, flow): self.db.set_reconnections(profileid, twid, current_reconnections) - def is_ignored_ip_data_upload(self, ip): + def is_ignored_ip_data_upload(self, ip, interface: str): """ Ignore the IPs that we shouldn't alert about """ ip_obj = ipaddress.ip_address(ip) + interface = interface or self.db.get_wifi_interface() if ( - ip == self.gateway + ip == self.db.get_gateway_ip(interface) or ip_obj.is_multicast or ip_obj.is_link_local or ip_obj.is_reserved @@ -323,8 +322,9 @@ def get_sent_bytes( daddr = flow["daddr"] sbytes: int = int(flow.get("sbytes", 0)) ts: str = flow.get("starttime", "") + interface: str = flow.get("interface", "") - if self.is_ignored_ip_data_upload(daddr) or not sbytes: + if self.is_ignored_ip_data_upload(daddr, interface) or not sbytes: continue if daddr in bytes_sent: @@ -430,7 +430,7 @@ def check_data_upload(self, profileid, twid, flow): """ if ( not flow.daddr - or self.is_ignored_ip_data_upload(flow.daddr) + or self.is_ignored_ip_data_upload(flow.daddr, flow.interface) or not flow.sbytes ): return False @@ -562,6 +562,11 @@ def check_conn_to_port_0(self, profileid, twid, flow): """ if flow.proto.lower() in ("igmp", "icmp", "ipv6-icmp", "arp"): return + + if ipaddress.ip_address(flow.daddr).is_multicast: + # igmp + return + try: flow.sport = int(flow.sport) flow.dport = int(flow.dport) @@ -747,7 +752,7 @@ def check_different_localnet_usage( if not (validators.ipv4(ip_to_check) and utils.is_private_ip(ip_obj)): return - own_local_network = self.db.get_local_network() + own_local_network = self.db.get_local_network(flow.interface) if not own_local_network: # the 
current local network wasn't set in the db yet # it's impossible to get here becaus ethe localnet is set before @@ -774,7 +779,7 @@ def is_dns_conn(flow): return ( flow.dport == 53 and flow.proto.lower() == "udp" - and flow.daddr == self.db.get_gateway_ip() + and flow.daddr == self.db.get_gateway_ip(flow.interface) ) def is_dhcp_conn(flow): @@ -783,7 +788,7 @@ def is_dhcp_conn(flow): return ( (flow.dport == 67 or flow.dport == 68) and flow.proto.lower() == "udp" - and flow.daddr == self.db.get_gateway_ip() + and flow.daddr == self.db.get_gateway_ip(flow.interface) ) with contextlib.suppress(ValueError): diff --git a/modules/flowalerts/dns.py b/modules/flowalerts/dns.py index bc8a42ce3..6d74c2a27 100644 --- a/modules/flowalerts/dns.py +++ b/modules/flowalerts/dns.py @@ -675,7 +675,7 @@ def check_different_localnet_usage( # outside of localnet return - own_local_network = self.db.get_local_network() + own_local_network = self.db.get_local_network(flow.interface) if not own_local_network: # the current local network wasn't set in the db yet # it's impossible to get here becaus ethe localnet is set before diff --git a/modules/flowalerts/set_evidence.py b/modules/flowalerts/set_evidence.py index 7be9437c4..129905b0f 100644 --- a/modules/flowalerts/set_evidence.py +++ b/modules/flowalerts/set_evidence.py @@ -233,7 +233,8 @@ def different_localnet_usage(self, twid, flow, ip_outside_localnet=""): f"A connection from a private IP ({flow.saddr}) on port " f"{flow.dport}/{flow.proto} " f"outside of the used local network " - f"{self.db.get_local_network()}. To IP: {flow.daddr} " + f"{self.db.get_local_network(flow.interface)}. To IP:" + f" {flow.daddr} " ) else: attacker = Attacker( @@ -251,7 +252,7 @@ def different_localnet_usage(self, twid, flow, ip_outside_localnet=""): f"A connection to a private IP ({flow.daddr}) on port" f" {flow.dport}/{flow.proto} " f"outside of the used local network " - f"{self.db.get_local_network()}. 
" + f"{self.db.get_local_network(flow.interface)}. " f"From IP: {flow.saddr} " ) proto = flow.proto.lower() diff --git a/modules/flowalerts/ssl.py b/modules/flowalerts/ssl.py index e71850e67..5b4b7036b 100644 --- a/modules/flowalerts/ssl.py +++ b/modules/flowalerts/ssl.py @@ -70,15 +70,22 @@ async def check_pastebin_download( def check_self_signed_certs(self, twid, flow): """ - checks the validation status of every a zeek ssl flow for self + checks the validation status of every zeek ssl flow for self signed certs """ - if "self signed" not in flow.validation_status: + if not hasattr(flow, "validation_status"): + # must be a suricata TLS flow return + if "self signed" not in flow.validation_status: + return self.set_evidence.self_signed_certificates(twid, flow) def detect_malicious_ja3(self, twid, flow): + if not (hasattr(flow, "ja3") and hasattr(flow, "ja3s")): + # its a suricata flow + return + if not (flow.ja3 or flow.ja3s): # we don't have info about this flow's ja3 or ja3s fingerprint return @@ -295,8 +302,11 @@ async def will_slips_have_new_incoming_flows(): pass # timeout reached def detect_doh(self, twid, flow): + if not hasattr(flow, "is_DoH"): + return False if not flow.is_DoH: return False + self.set_evidence.doh(twid, flow) self.db.set_ip_info(flow.daddr, {"is_doh_server": True}) diff --git a/modules/flowmldetection/flowmldetection.py b/modules/flowmldetection/flowmldetection.py index 98bb670ab..ffb9bba54 100644 --- a/modules/flowmldetection/flowmldetection.py +++ b/modules/flowmldetection/flowmldetection.py @@ -148,6 +148,7 @@ def process_features(self, dataset): "endtime", "bytes", "flow_source", + "interface", ] for field in to_drop: try: diff --git a/modules/ip_info/ip_info.py b/modules/ip_info/ip_info.py index 2fc48754b..085349a95 100644 --- a/modules/ip_info/ip_info.py +++ b/modules/ip_info/ip_info.py @@ -3,6 +3,8 @@ from typing import ( Union, Optional, + Dict, + List, ) from uuid import uuid4, getnode import datetime @@ -65,7 +67,9 @@ def 
init(self): self.whitelist = Whitelist(self.logger, self.db) self.is_running_non_stop: bool = self.db.is_running_non_stop() self.valid_tlds = whois.validTlds() - self.is_running_in_ap_mode = False + self.is_running_in_ap_mode: bool = ( + True if self.args.access_point else False + ) async def open_dbs(self): """Function to open the different offline databases used in this @@ -327,38 +331,12 @@ async def shutdown_gracefully(self): # GW @staticmethod - def get_gateway_for_iface(iface) -> Optional[str]: - gws = netifaces.gateways() - for family in (netifaces.AF_INET, netifaces.AF_INET6): - if "default" in gws and gws["default"][family]: - gw, gw_iface = gws["default"][family] - if gw_iface == iface: - return gw - return None - - def is_ap_mode_iwconfig(self) -> bool: - """ - check is slips is running as an AP - """ - interface: str = getattr(self.args, "interface", None) - try: - output = subprocess.check_output( - ["iwconfig", interface], text=True, stderr=subprocess.DEVNULL - ) - for line in output.splitlines(): - if "Mode:" in line: - mode = line.split("Mode:")[1].split()[0] - return mode.lower() == "master" - except Exception: - pass - return False - def get_default_gateway(self) -> str: gws = netifaces.gateways() default = gws.get("default", {}) return default.get(netifaces.AF_INET, (None,))[0] - def get_gateway_ip_if_interface(self) -> Optional[str]: + def get_gateway_ip_if_interface(self) -> Dict[str, str] | None: """ returns the gateway ip of the given interface if running on an interface. @@ -369,26 +347,16 @@ def get_gateway_ip_if_interface(self) -> Optional[str]: # only works if running on an interface return - interface: str = getattr(self.args, "interface", None) - if self.is_running_in_ap_mode: - # ok why?? because when slips is running on normal hosts, we want - # the ip of the default gateway which is probably going to be the - # gw of the given interface and in the localnet of that interface. 
- # BUT when Slips is running as an AP, we dont want the ip of the - # default gateway, we want the ip of the AP. because in this case, - # the AP is the gateway of the computers connected to it. we don't - # want the ip of the actual gateway that is probably present in - # another localnet (the eth0). - # return the own IP. because slips is "the gw" for connected clients + interfaces: List[str] = utils.get_all_interfaces(self.args) + + gw_ips = {} + for interface in interfaces: try: - return netifaces.ifaddresses(interface)[netifaces.AF_INET][0][ - "addr" - ] + gw_ip = utils.get_gateway_for_iface(interface) + gw_ips.update({interface: gw_ip}) except KeyError: - return # No IP assigned - else: - # get the gw of the given interface - return self.get_gateway_for_iface(interface) + pass + return gw_ips @staticmethod def get_own_mac() -> str: @@ -400,35 +368,16 @@ def get_own_mac() -> str: ) return mac - def get_gateway_mac(self, gw_ip: str) -> Optional[str]: - """ - Given the gw_ip, this function tries to get the MAC - from arp.log, using ip neigh or from arp tables - PS: returns own MAc address if running in AP mode - """ - # we keep a cache of the macs and their IPs - # In case of a zeek dir or a pcap, - # check if we have the mac of this ip already saved in the db. - if gw_mac := self.db.get_mac_addr_from_profile(f"profile_{gw_ip}"): - gw_mac: Union[str, None] - return gw_mac - - if not self.is_running_non_stop: - # running on pcap or a given zeek file/dir - # no MAC in arp.log (in the db) and can't use arp tables, - # so it's up to the db.is_gw_mac() function to determine the gw mac - # if it's seen associated with a public IP - return - - if self.is_running_in_ap_mode: - # when running in AP mode, we are the GW for the connected - # clients, this makes our mac the GW mac. - # in AP mode, the given gw_ip is our own ip anyway, so it wont - # be found in arp tables or ip neigh command anyway. 
- return self.get_own_mac() + def _get_wifi_interface_if_ap(self) -> str | None: + ap_interfaces: str = self.db.get_wifi_interface() + try: + # we're now sure that we're running in AP mode + wifi_interface = ap_interfaces["wifi_interface"] + except KeyError: + wifi_interface = None + return wifi_interface - # Obtain the MAC address by using the hosts ARP table - # First, try the ip command + def _get_mac_using_ip_neigh(self, gw_ip) -> str | None: try: ip_output = subprocess.run( ["ip", "neigh", "show", gw_ip], @@ -436,20 +385,60 @@ def get_gateway_mac(self, gw_ip: str) -> Optional[str]: check=True, text=True, ).stdout - gw_mac = ip_output.split()[-2] + mac = ip_output.split()[-2] + return mac + except (subprocess.CalledProcessError, IndexError, FileNotFoundError): + return + def _get_mac_using_arp_cache(self, gw_ip) -> str | None: + try: + gw_mac = utils.get_mac_for_ip_using_cache(gw_ip) return gw_mac - except (subprocess.CalledProcessError, IndexError, FileNotFoundError): - # If the ip command doesn't exist or has failed, try using the - # arp command - try: - gw_mac = utils.get_mac_for_ip_using_cache(gw_ip) - return gw_mac - except (subprocess.CalledProcessError, IndexError): - # Could not find the MAC address of gw_ip - return + except (subprocess.CalledProcessError, IndexError): + # Could not find the MAC address of gw_ip + return - return gw_mac + def get_gateway_mac(self, gw_ips: Dict[str, str]) -> Optional[str]: + """ + Given the gw_ips, this function tries to get the MAC + from arp.log, using ip neigh or from arp tables + """ + wifi_interface: str | None = self._get_wifi_interface_if_ap() + + gw_macs = {} + for interface, gw_ip in gw_ips.items(): + # we keep a cache of the macs and their IPs + # In case of a zeek dir or a pcap, + # check if we have the mac of this ip already saved in the db. 
+ if gw_mac := self.db.get_mac_addr_from_profile(f"profile_{gw_ip}"): + gw_mac: Union[str, None] + gw_macs[interface] = gw_mac + continue + + if not self.is_running_non_stop: + # ok now we are running on pcap or a given zeek file/dir + # and we have no MAC in arp.log (in the db) and can't use arp + # tables, so it's up to the db.is_gw_mac() function to + # determine the gw mac if it's seen associated with a + # public IP + continue + + if interface == wifi_interface: + # this interface is the wifi interface of the AP + if own_mac := self.get_own_mac(): + gw_macs[interface] = own_mac + continue + + if gw_mac := self._get_mac_using_ip_neigh(gw_ip): + gw_macs[interface] = gw_mac + continue + + if gw_mac := self._get_mac_using_arp_cache(gw_ip): + gw_macs[interface] = gw_mac + continue + + if gw_macs: + return gw_macs def check_if_we_have_pending_offline_mac_queries(self): """ @@ -551,16 +540,17 @@ def set_evidence_malicious_jarm_hash( def pre_main(self): utils.drop_root_privs_permanently() self.wait_for_dbs() - - self.is_running_in_ap_mode: bool = self.is_ap_mode_iwconfig() # the following method only works when running on an interface - if ip := self.get_gateway_ip_if_interface(): - self.db.set_default_gateway("IP", ip) + if gw_ips := self.get_gateway_ip_if_interface(): + for interface, gw_ip in gw_ips.items(): + self.db.set_default_gateway("IP", gw_ip, interface) + # whether we found the gw ip using dhcp in profiler - # or using ip route using self.get_gateway_ip() + # or using ip route here (self.get_gateway_ip()) # now that it's found, get and store the mac addr of it - if mac := self.get_gateway_mac(ip): - self.db.set_default_gateway("MAC", mac) + if gw_macs := self.get_gateway_mac(gw_ips): + for interface, gw_mac in gw_macs.items(): + self.db.set_default_gateway("MAC", gw_mac, interface) def handle_new_ip(self, ip: str): try: diff --git a/modules/irisModule/irisModule.py b/modules/irisModule/irisModule.py index 2e766e4b2..929b2c17f 100644 --- 
a/modules/irisModule/irisModule.py +++ b/modules/irisModule/irisModule.py @@ -49,6 +49,7 @@ def _iris_configurator(self, config_path: str, redis_port: int): with open(config_path, "r") as file: config = yaml.safe_load(file) + wifi_interface = self.db.get_wifi_interface() # Ensure the Redis section exists and update the port if "Redis" in config: config["Redis"]["Port"] = redis_port @@ -62,12 +63,12 @@ def _iris_configurator(self, config_path: str, redis_port: int): } if "Server" in config: # config["Server"]["Port"] = 9010 - config["Server"]["Host"] = self.db.get_host_ip() + config["Server"]["Host"] = self.db.get_host_ip(wifi_interface) config["Server"]["DhtServerMode"] = "true" else: config["Redis"] = { "Port": 6644, - "Host": self.db.get_host_ip(), + "Host": self.db.get_host_ip(wifi_interface), "DhtServerMode": "true", } diff --git a/modules/threat_intelligence/threat_intelligence.py b/modules/threat_intelligence/threat_intelligence.py index 2a1a6fe94..2c3773cf4 100644 --- a/modules/threat_intelligence/threat_intelligence.py +++ b/modules/threat_intelligence/threat_intelligence.py @@ -1197,13 +1197,17 @@ def is_inbound_traffic(self, ip: str, ip_state: str) -> bool: 2. ip is public 3. 
ip is not our host ip """ - host_ip: str = self.db.get_host_ip() - return ( - "src" in ip_state - and ipaddress.ip_address(ip).is_global - and ip != host_ip - and not utils.is_ip_in_client_ips(ip, self.client_ips) - ) + # if slips was monitoring multiple interfaces, it'd have multiple + # host ips + for host_ip in self.db.get_all_host_ips(): + if ( + "src" in ip_state + and ipaddress.ip_address(ip).is_global + and ip != host_ip + and not utils.is_ip_in_client_ips(ip, self.client_ips) + ): + return True + return False def search_online_for_ip(self, ip: str, ip_state: str): if self.is_inbound_traffic(ip, ip_state): diff --git a/modules/timeline/timeline.py b/modules/timeline/timeline.py index da2c3db87..ace09947e 100644 --- a/modules/timeline/timeline.py +++ b/modules/timeline/timeline.py @@ -31,7 +31,7 @@ def init(self): "new_flow": self.c1, } self.classifier = FlowClassifier() - self.host_ip: str = self.db.get_host_ip() + self.host_ips: List[str] = self.db.get_all_host_ips() def read_configuration(self): conf = ConfigParser() @@ -55,7 +55,7 @@ def is_inbound_traffic(self, flow) -> bool: # slips only detects inbound traffic in the "all" direction return False - return flow.daddr == self.host_ip or utils.is_ip_in_client_ips( + return flow.daddr in self.host_ips or utils.is_ip_in_client_ips( flow.daddr, self.client_ips ) @@ -135,9 +135,7 @@ def process_ssh_altflow(self, alt_flow: dict): return {"info": ssh_activity} def process_altflow(self, profileid, twid, flow) -> dict: - alt_flow: dict = self.db.get_altflow_from_uid( - profileid, twid, flow.uid - ) + alt_flow: dict = self.db.get_altflow_from_uid(flow.uid) altflow_info = {"info": ""} if not alt_flow: diff --git a/slips/main.py b/slips/main.py index cc7f38d0e..0b3104bb4 100644 --- a/slips/main.py +++ b/slips/main.py @@ -16,6 +16,7 @@ import logging from managers.host_ip_manager import HostIPManager +from managers.ap_manager import APManager from managers.metadata_manager import MetadataManager from 
managers.process_manager import ProcessManager from managers.profilers_manager import ProfilersManager @@ -24,7 +25,7 @@ from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.printer import Printer from slips_files.common.slips_utils import utils -from slips_files.common.style import green +from slips_files.common.style import green, yellow from slips_files.core.database.database_manager import DBManager from slips_files.core.helpers.checker import Checker @@ -59,7 +60,7 @@ def __init__(self, testing=False): self.args = self.conf.get_args() self.profilers_manager = ProfilersManager(self) self.pid = os.getpid() - self.checker.check_given_flags() + self.checker.verify_given_flags() self.prepare_locks_dir() if not self.args.stopdaemon: # Check the type of input @@ -67,7 +68,7 @@ def __init__(self, testing=False): self.input_type, self.input_information, self.line_type, - ) = self.checker.check_input_type() + ) = self.checker.get_input_type() # If we need zeek (bro), test if we can run it. 
self.check_zeek_or_bro() self.prepare_output_dir() @@ -76,6 +77,7 @@ def __init__(self, testing=False): self.twid_width = self.conf.get_tw_width() # should be initialised after self.input_type is set self.host_ip_man = HostIPManager(self) + self.ap_manager = APManager(self) def check_zeek_or_bro(self): """ @@ -196,6 +198,7 @@ def prepare_output_dir(self): # self.args.output is the same as self.alerts_default_path self.input_information = os.path.normpath(self.input_information) + self.input_information = self.input_information.replace(",", "_") # now that slips can run several instances, # each created dir will be named after the instance # that created it @@ -449,11 +452,13 @@ def get_slips_error_file(self) -> str: def print_gw_info(self): if self.gw_info_printed: return - if ip := self.db.get_gateway_ip(): - self.print(f"Detected gateway IP: {green(ip)}") - if mac := self.db.get_gateway_mac(): - self.print(f"Detected gateway MAC: {green(mac)}") - self.gw_info_printed = True + + for interface in utils.get_all_interfaces(self.args): + if ip := self.db.get_gateway_ip(interface): + self.print(f"Detected gateway IP: {green(ip)}") + if mac := self.db.get_gateway_mac(interface): + self.print(f"Detected gateway MAC: {green(mac)}") + self.gw_info_printed = True def prepare_locks_dir(self): """ @@ -505,6 +510,19 @@ def start(self): self.print(str(e), 1, 1) self.terminate_slips() + if self.args.access_point: + # is -ap given but no AP running? + if not self.ap_manager.is_ap_running(): + self.print( + "Slips was started with -ap but can't detect a " + "running access point. Please start an access point " + "and restart Slips. Stopping." 
+ ) + self.terminate_slips() + else: + self.print(yellow("Slips is running in AP mode.")) + self.ap_manager.store_ap_interfaces(self.input_information) + self.db.set_input_metadata( { "output_dir": self.args.output, @@ -519,7 +537,7 @@ def start(self): # to be able to use the host IP as analyzer IP in alerts.json # should be after setting the input metadata with "input_type" # TLDR; dont change the order of this line - host_ip = self.host_ip_man.store_host_ip() + host_ips = self.host_ip_man.store_host_ip() self.print( f"Using redis server on port: {green(self.redis_port)}", @@ -634,8 +652,6 @@ def sig_handler(sig, frame): # Don't try to stop slips if it's capturing from # an interface or a growing zeek dir - self.is_interface: bool = self.db.is_running_non_stop() - while not self.proc_man.stop_slips(): # Sleep some time to do routine checks and give time for # more traffic to come @@ -654,7 +670,7 @@ def sig_handler(sig, frame): self.metadata_man.update_slips_stats_in_the_db()[1] ) - self.host_ip_man.update_host_ip(host_ip, modified_profiles) + self.host_ip_man.update_host_ip(host_ips, modified_profiles) except KeyboardInterrupt: # the EINTR error code happens if a signal occurred while diff --git a/slips_files/common/abstracts/iasync_module.py b/slips_files/common/abstracts/iasync_module.py index 1b33c866f..349dea8d7 100644 --- a/slips_files/common/abstracts/iasync_module.py +++ b/slips_files/common/abstracts/iasync_module.py @@ -1,6 +1,7 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only import asyncio +import traceback from asyncio import Task from typing import ( Callable, @@ -45,8 +46,34 @@ def handle_task_exception(self, task: asyncio.Task): except asyncio.CancelledError: return # Task was cancelled, not an error if exception: - self.print(f"Unhandled exception in task: {exception}") - self.print_traceback() + self.print(f"Unhandled exception in task: {exception!r} .. 
") + self.print_traceback_from_exception(exception, task) + + def print_traceback_from_exception( + self, exception: BaseException, task: asyncio.Task + ): + # Try to get the traceback directly from the task + tb = exception.__traceback__ + if tb: + formatted_tb = "".join( + traceback.format_exception(type(exception), exception, tb) + ) + # Get the last traceback line number + last_tb = traceback.extract_tb(tb)[-1] + self.print( + f"Problem in line {last_tb.lineno} of {last_tb.filename}", 0, 1 + ) + self.print(formatted_tb, 0, 1) + else: + # fallback: print stack if no traceback + stack = task.get_stack() + if stack: + formatted_stack = "".join( + traceback.format_list(traceback.extract_stack(stack[-1])) + ) + self.print(f"Task stack:\n{formatted_stack}", 0, 1) + else: + self.print("No traceback or stack available.", 0, 1) async def main(self): ... diff --git a/slips_files/common/flow_classifier.py b/slips_files/common/flow_classifier.py index 14c91d8f8..4b4931628 100644 --- a/slips_files/common/flow_classifier.py +++ b/slips_files/common/flow_classifier.py @@ -60,6 +60,7 @@ def __init__(self): "suricata_http": SuricataHTTP, "suricata_dns": SuricataDNS, "suricata_tls": SuricataTLS, + "suricata_ssl": SuricataTLS, "suricata_files": SuricataFile, "suricata_ssh": SuricataSSH, } diff --git a/slips_files/common/idmefv2.py b/slips_files/common/idmefv2.py index cfd90a39d..aa74aa3e3 100644 --- a/slips_files/common/idmefv2.py +++ b/slips_files/common/idmefv2.py @@ -51,21 +51,24 @@ def __init__(self, logger: Output, db): self.printer = Printer(logger, self.name) self.db = db self.model: str = utils.get_slips_version() - self.analyzer = { - "IP": self.get_host_ip(), + + # the used idmef version + self.version = "2.0.3" + + def _get_analyzer(self, interface): + return { + "IP": self.get_host_ip(interface), "Name": "Slips", "Model": self.model, "Category": ["NIDS"], "Data": ["Flow", "Network"], "Method": ["Heuristic"], } - # the used idmef version - self.version = "2.0.3" - def 
get_host_ip(self) -> str: + def get_host_ip(self, interface: str) -> str: if not self.db.is_running_non_stop(): return DEFAULT_ADDRESS - if host_ip := self.db.get_host_ip(): + if host_ip := self.db.get_host_ip(interface): return host_ip return DEFAULT_ADDRESS @@ -140,7 +143,9 @@ def convert_to_idmef_alert(self, alert: Alert) -> Message: msg.update( { "Version": self.version, - "Analyzer": self.analyzer, + # alerts aren't tied to a specific interface, and alert + # is a combination of evidence from any interface + "Analyzer": self._get_analyzer(None), "Source": [{"IP": alert.profile.ip}], "ID": alert.id, "Status": "Incident", @@ -198,7 +203,7 @@ def convert_to_idmef_event(self, evidence: Evidence) -> Message: msg.update( { "Version": self.version, - "Analyzer": self.analyzer, + "Analyzer": self._get_analyzer(evidence.interface), "Status": IDMEFv2Status.EVIDENCE.value, # that is a uuid4() "ID": evidence.id, diff --git a/slips_files/common/parsers/arg_parser.py b/slips_files/common/parsers/arg_parser.py index 46ae5ec44..eeb903a47 100644 --- a/slips_files/common/parsers/arg_parser.py +++ b/slips_files/common/parsers/arg_parser.py @@ -164,6 +164,15 @@ def parse_arguments(self): required=False, help="Read packets from an interface.", ) + self.add_argument( + "-ap", + "--access-point", + action="store", + required=False, + help="Read packets from two interfaces when Slips is running as " + "an access point. 
the wifi interface should come first (" + "e.g -ap wlan0, eth0).", + ) self.add_argument( "-F", "--pcapfilter", diff --git a/slips_files/common/slips_utils.py b/slips_files/common/slips_utils.py index 8b0fd0b58..34ff026a5 100644 --- a/slips_files/common/slips_utils.py +++ b/slips_files/common/slips_utils.py @@ -189,6 +189,48 @@ def to_dict(self, obj): return obj + def get_interface_of_ip(self, ip, db, args) -> str: + """ + Gets the interface this IP is attacking on + return s None if slips isnt running on an interface + """ + if args.interface: + return args.interface + + if args.access_point: + # we have 2 interfaces, in which interface is the ip_to_block? + for _type, interface in db.get_ap_info().items(): + # _type can be 'wifi_interface' or "ethernet_interface" + local_net: str = db.get_local_network(interface) + ip_obj = ipaddress.ip_address(ip) + if ip_obj in ipaddress.IPv4Network(local_net): + return interface + + def infer_used_interface(self) -> str | None: + """for when the user is using -g and didnt give slips an interface""" + # PS: make sure you neveer run this when slips is given a file or a + # pcap + try: + gateways = netifaces.gateways() + default_gateway = gateways.get("default", {}) + if netifaces.AF_INET not in default_gateway: + return None + + interface = default_gateway[netifaces.AF_INET][1] + return interface + except KeyError: + return + + def get_gateway_for_iface(self, iface: str) -> Optional[str]: + """returns the default gateway for the given interface""" + gws = netifaces.gateways() + for family in (netifaces.AF_INET, netifaces.AF_INET6): + if "default" in gws and gws["default"][family]: + gw, gw_iface = gws["default"][family] + if gw_iface == iface: + return gw + return None + def is_valid_uuid4(self, uuid_string: str) -> bool: """Validate that the given str in UUID4""" try: @@ -437,6 +479,38 @@ def get_human_readable_datetime(self, format=None) -> str: datetime.now(), format or self.alerts_format ) + def get_cidr_of_interface(self, 
interface: str) -> str | None: + try: + addrs = netifaces.ifaddresses(interface) + ipv4_addrs = addrs.get(socket.AF_INET) + + if ipv4_addrs: + for addr_info in ipv4_addrs: + ip = addr_info.get("addr") + netmask = addr_info.get("netmask") + + if ip and netmask: + # Create an interface object from the IP and netmask + iface = ipaddress.ip_interface(f"{ip}/{netmask}") + network_cidr = str(iface.network) + return network_cidr + except Exception: + return + + def get_all_interfaces(self, args) -> List[str]: + """ + returns a list of all interfaces slips is now monitoring + :param args: slips args + """ + if args.interface: + return [args.interface] + if args.access_point: + return args.access_point.split(",") + if args.growing: + return [self.infer_used_interface()] + + return ["default"] + def get_mac_for_ip_using_cache(self, ip: str) -> str | None: """gets the mac of the given local ip using the local arp cache""" try: diff --git a/slips_files/core/database/database_manager.py b/slips_files/core/database/database_manager.py index 59fc1d1c3..c46671e90 100644 --- a/slips_files/core/database/database_manager.py +++ b/slips_files/core/database/database_manager.py @@ -1,5 +1,6 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only +import json import os import shutil import sqlite3 @@ -218,6 +219,12 @@ def is_known_fp_md5_hash(self, *args, **kwargs): def ask_for_ip_info(self, *args, **kwargs): return self.rdb.ask_for_ip_info(*args, **kwargs) + def set_ap_info(self, *args, **kwargs): + return self.rdb.set_ap_info(*args, **kwargs) + + def get_ap_info(self, *args, **kwargs): + return self.rdb.get_ap_info(*args, **kwargs) + @classmethod def discard_obj(cls): """ @@ -439,6 +446,12 @@ def set_reconnections(self, *args, **kwargs): def get_host_ip(self, *args, **kwargs): return self.rdb.get_host_ip(*args, **kwargs) + def get_wifi_interface(self, *args, **kwargs): + return self.rdb.get_wifi_interface(*args, **kwargs) + + def 
get_all_host_ips(self, *args, **kwargs): + return self.rdb.get_all_host_ips(*args, **kwargs) + def set_new_incoming_flows(self, *args, **kwargs): return self.rdb.set_new_incoming_flows(*args, **kwargs) @@ -542,7 +555,28 @@ def get_flows_causing_evidence(self, *args, **kwargs): """returns the list of uids of the flows causing evidence""" return self.rdb.get_flows_causing_evidence(*args, **kwargs) + def _get_evidence_interface(self, evidence: Evidence) -> str | None: + """ + Returns the interface of the first flow of the given evidence + """ + try: + # get any flow uid of this evidence, to get the interface of it + uid = evidence.uid[0] + except KeyError: + # evidence doesnt have a uid? + return + + try: + flow: str = self.get_flow(uid)[uid] + if isinstance(flow, str): + flow: dict = json.loads(flow) + except KeyError: + flow: dict = self.get_altflow_from_uid(uid) + return flow["interface"] if flow else None + def set_evidence(self, evidence: Evidence): + interface: str | None = self._get_evidence_interface(evidence) + setattr(evidence, "interface", interface) evidence_set = self.rdb.set_evidence(evidence) if evidence_set: # an evidence is generated for this profile diff --git a/slips_files/core/database/redis_db/alert_handler.py b/slips_files/core/database/redis_db/alert_handler.py index dc4c8b0ad..c4ead0f5f 100644 --- a/slips_files/core/database/redis_db/alert_handler.py +++ b/slips_files/core/database/redis_db/alert_handler.py @@ -234,7 +234,6 @@ def _get_more_info_about_evidence(self, evidence) -> Evidence: break setattr(evidence, entity_type, entity) - return evidence def set_evidence(self, evidence: Evidence): diff --git a/slips_files/core/database/redis_db/constants.py b/slips_files/core/database/redis_db/constants.py index 249e8dbb4..35b135fbd 100644 --- a/slips_files/core/database/redis_db/constants.py +++ b/slips_files/core/database/redis_db/constants.py @@ -52,6 +52,7 @@ class Constants: SLIPS_START_TIME = "slips_start_time" USED_FTP_PORTS = 
"used_ftp_ports" SLIPS_INTERNAL_TIME = "slips_internal_time" + IS_RUNNING_AS_AP = "is_slips_running_as_an_ap" WARDEN_INFO = "Warden" MODE = "mode" ANALYSIS = "analysis" diff --git a/slips_files/core/database/redis_db/database.py b/slips_files/core/database/redis_db/database.py index ff938cd96..409e8d034 100644 --- a/slips_files/core/database/redis_db/database.py +++ b/slips_files/core/database/redis_db/database.py @@ -690,6 +690,25 @@ def ask_for_ip_info( def get_slips_internal_time(self): return self.r.get(self.constants.SLIPS_INTERNAL_TIME) or 0 + def set_ap_info(self, interfaces: Dict[str, str]): + """the main slips instance call this func for the modules to be + aware that slips is running as an access point""" + return self.r.set( + self.constants.IS_RUNNING_AS_AP, json.dumps(interfaces) + ) + + def get_ap_info(self) -> Dict[str, str] | None: + """returns both AP interfaces or None if slips is not + running in AP mode + returns a dict with {"wifi_interface": , + "ethernet_interface": } + or None if slips is not running as an AP + """ + ap_info = self.r.get(self.constants.IS_RUNNING_AS_AP) + if not ap_info: + return None + return json.loads(ap_info) + def get_redis_keys_len(self) -> int: """returns the length of all keys in the db""" return self.r.dbsize() @@ -707,14 +726,14 @@ def get_equivalent_tws(self, hrs: float) -> int: """ return int(hrs * 3600 / self.width) - def set_local_network(self, cidr): + def set_local_network(self, cidr, interface): """ set the local network used in the db """ - self.r.set(self.constants.LOCAL_NETWORK, cidr) + self.r.hset(self.constants.LOCAL_NETWORK, interface, cidr) - def get_local_network(self): - return self.r.get(self.constants.LOCAL_NETWORK) + def get_local_network(self, interface): + return self.r.hget(self.constants.LOCAL_NETWORK, interface) def get_used_port(self) -> int: return int(self.r.config_get(self.constants.REDIS_USED_PORT)["port"]) @@ -1085,7 +1104,7 @@ def mark_srcip_as_seen_in_connlog(self, ip): """ 
self.r.sadd(self.constants.SRCIPS_SEEN_IN_CONN_LOG, ip) - def _is_gw_mac(self, mac_addr: str) -> bool: + def _is_gw_mac(self, mac_addr: str, interface: str) -> bool: """ Detects the MAC of the gateway if 1 mac is seen assigned to 1 public destination IP @@ -1097,9 +1116,9 @@ def _is_gw_mac(self, mac_addr: str) -> bool: if self._gateway_MAC_found: # gateway MAC already set using this function - return self.get_gateway_mac() == mac_addr + return self.get_gateway_mac(interface) == mac_addr - def _determine_gw_mac(self, ip, mac): + def _determine_gw_mac(self, ip, mac, interface: str): """ sets the gw mac if the given ip is public and is assigned a mc """ @@ -1113,7 +1132,7 @@ def _determine_gw_mac(self, ip, mac): # belongs to it # we are sure this is the gw mac # set it if we don't already have it in the db - self.set_default_gateway(self.constants.MAC, mac) + self.set_default_gateway(self.constants.MAC, mac, interface) # mark the gw mac as found so we don't look for it again self._gateway_MAC_found = True @@ -1300,38 +1319,62 @@ def get_organization_of_port(self, portproto: str): self.constants.ORGANIZATIONS_PORTS, portproto.lower() ) - def add_zeek_file(self, filename): + def add_zeek_file(self, filename, interface): """Add an entry to the list of zeek files""" - self.r.sadd(self.constants.ZEEK_FILES, filename) + self.r.hset(self.constants.ZEEK_FILES, filename, interface) def get_all_zeek_files(self) -> set: """Return all entries from the list of zeek files""" - return self.r.smembers(self.constants.ZEEK_FILES) + return self.r.hgetall(self.constants.ZEEK_FILES) - def get_gateway_ip(self): - return self.r.hget(self.constants.DEFAULT_GATEWAY, "IP") + def _get_gw_info(self, interface: str) -> Dict[str, str] | None: + """ + gets the gw of the given interface, when slips is running on a + file, it uses "default" as the interface + """ + if not interface: + interface = "default" - def get_gateway_mac(self): - return self.r.hget(self.constants.DEFAULT_GATEWAY, 
self.constants.MAC) + gw_info: str = self.r.hget(self.constants.DEFAULT_GATEWAY, interface) + if gw_info: + gw_info: Dict[str, str] = json.loads(gw_info) + return gw_info - def get_gateway_mac_vendor(self): - return self.r.hget(self.constants.DEFAULT_GATEWAY, "Vendor") + def get_gateway_ip(self, interface: str) -> str | None: + if gw_info := self._get_gw_info(interface): + return gw_info.get("IP") - def set_default_gateway(self, address_type: str, address: str): + def get_gateway_mac(self, interface): + if gw_info := self._get_gw_info(interface): + return gw_info.get(self.constants.MAC) + + def get_gateway_mac_vendor(self, interface): + if gw_info := self._get_gw_info(interface): + return gw_info.get("Vendor") + + def set_default_gateway( + self, address_type: str, address: str, interface: str + ): """ :param address_type: can either be 'IP' or 'MAC' :param address: can be ip or mac, but always is a str + :param interface: which interface is the given address the GW to? """ # make sure the IP or mac aren't already set before re-setting if ( - (address_type == "IP" and not self.get_gateway_ip()) + (address_type == "IP" and not self.get_gateway_ip(interface)) or ( address_type == self.constants.MAC - and not self.get_gateway_mac() + and not self.get_gateway_mac(interface) + ) + or ( + address_type == "Vendor" + and not self.get_gateway_mac_vendor(interface) ) - or (address_type == "Vendor" and not self.get_gateway_mac_vendor()) ): - self.r.hset(self.constants.DEFAULT_GATEWAY, address_type, address) + gw_info = json.dumps({address_type: address}) + + self.r.hset(self.constants.DEFAULT_GATEWAY, interface, gw_info) def get_domain_resolution(self, domain) -> List[str]: """ @@ -1381,20 +1424,47 @@ def set_reconnections(self, profileid, twid, data): data = json.dumps(data) self.r.hset(f"{profileid}_{twid}", "Reconnections", str(data)) - def get_host_ip(self) -> Optional[str]: - """returns the latest added host ip""" - host_ip: List[str] = self.r.zrevrange( - "host_ip", 0, 
0, 0, withscores=False - ) + def get_host_ip(self, interface) -> Optional[str]: + """returns the latest added host ip + :param interface: can be an actual interface or "default" + """ + key = f"host_ip_{interface}" + host_ip: List[str] = self.r.zrevrange(key, 0, 0, withscores=False) return host_ip[0] if host_ip else None - def set_host_ip(self, ip): + def get_wifi_interface(self): + """ + returns the wifi interface if running as an AP, and the user + supplied interface if not. + """ + if ap_info := self.get_ap_info(): + return ap_info["wifi_interface"] + + return self.get_interface() + + def get_all_host_ips(self) -> List[str]: + """returns the latest added host ip of all interfaces""" + ip_keys = self.r.scan_iter(match="host_ip_*") + + all_ips: List[str] = [] + for key in ip_keys: + host_ip_list: List[bytes] = self.r.zrevrange( + key, 0, 0, withscores=False + ) + if host_ip_list: + # Decode the bytes to a string before appending + latest_ip: str = host_ip_list[0] + all_ips.append(latest_ip) + return all_ips + + def set_host_ip(self, ip, interface: str): + """Store the IP address of the host in a db. 
There can be more than one""" # stored them in a sorted set to be able to retrieve the latest one # of them as the host ip - host_ips_added = self.r.zcard("host_ip") - self.r.zadd("host_ip", {ip: host_ips_added + 1}) + key = f"host_ip_{interface}" + host_ips_added = self.r.zcard(key) + self.r.zadd(key, {ip: host_ips_added + 1}) def set_asn_cache(self, org: str, asn_range: str, asn_number: str) -> None: """ diff --git a/slips_files/core/database/redis_db/profile_handler.py b/slips_files/core/database/redis_db/profile_handler.py index 9396e3ce9..af18ff2c0 100644 --- a/slips_files/core/database/redis_db/profile_handler.py +++ b/slips_files/core/database/redis_db/profile_handler.py @@ -1247,22 +1247,26 @@ def update_mac_of_profile(self, profileid: str, mac: str): """Add the MAC addr to the given profileid key""" self.r.hset(profileid, self.constants.MAC, mac) - def _should_associate_this_mac_with_this_ip(self, ip, mac) -> bool: + def _should_associate_this_mac_with_this_ip( + self, ip, mac, interface + ) -> bool: return not ( ip == "0.0.0.0" or not mac # sometimes we create profiles with the mac address. # don't save that in MAC hash or validators.mac_address(ip) - or self._is_gw_mac(mac) + or self._is_gw_mac(mac, interface) # we're trying to assign the gw mac to # an ip that isn't the gateway's # this happens bc any public IP probably has the gw MAC # in the zeek logs, so skip - or ip == self.get_gateway_ip() + or ip == self.get_gateway_ip(interface) ) - def add_mac_addr_to_profile(self, profileid: str, mac_addr: str): + def add_mac_addr_to_profile( + self, profileid: str, mac_addr: str, interface: str + ): """ Used to associate the given profile with the given MAC addr. 
stores this info in the 'MAC' key in the db @@ -1276,12 +1280,12 @@ def add_mac_addr_to_profile(self, profileid: str, mac_addr: str): incoming_ip: str = profileid.split("_")[1] if not self._should_associate_this_mac_with_this_ip( - incoming_ip, mac_addr + incoming_ip, mac_addr, interface ): return False # see if this is the gw mac - self._determine_gw_mac(incoming_ip, mac_addr) + self._determine_gw_mac(incoming_ip, mac_addr, interface) # get the ips that belong to this mac cached_ips: Optional[List] = self.r.hmget( diff --git a/slips_files/core/database/sqlite_db/database.py b/slips_files/core/database/sqlite_db/database.py index 77797b5ef..6a3d4a0a0 100644 --- a/slips_files/core/database/sqlite_db/database.py +++ b/slips_files/core/database/sqlite_db/database.py @@ -71,7 +71,7 @@ def get_db_path(self) -> str: """ return self._flows_db - def get_altflow_from_uid(self, profileid, twid, uid) -> dict: + def get_altflow_from_uid(self, uid) -> dict: """Given a uid, get the alternative flow associated with it""" condition = f'uid = "{uid}"' altflow = self.select("altflows", condition=condition) diff --git a/slips_files/core/evidence_handler.py b/slips_files/core/evidence_handler.py index 39d2263e6..9ec213dcf 100644 --- a/slips_files/core/evidence_handler.py +++ b/slips_files/core/evidence_handler.py @@ -101,7 +101,7 @@ def init(self): utils.change_logfiles_ownership(self.jsonfile.name, self.UID, self.GID) # this list will have our local and public ips when using -i self.our_ips: List[str] = utils.get_own_ips(ret="List") - self.formatter = EvidenceFormatter(self.db) + self.formatter = EvidenceFormatter(self.db, self.args) # thats just a tmp value, this variable will be set and used when # the # module is stopping. 
@@ -136,8 +136,8 @@ def clean_file(self, output_dir, file_to_clean): open(logfile_path, "w").close() return open(logfile_path, "a") - def handle_unable_to_log(self): - self.print("Error logging evidence/alert.") + def handle_unable_to_log(self, failed_log, error=None): + self.print(f"Error logging evidence/alert: {error}. {failed_log}.") def add_alert_to_json_log_file(self, alert: Alert): """ @@ -145,7 +145,7 @@ def add_alert_to_json_log_file(self, alert: Alert): """ idmef_alert: dict = self.idmefv2.convert_to_idmef_alert(alert) if not idmef_alert: - self.handle_unable_to_log() + self.handle_unable_to_log(alert, "Can't convert to IDMEF alert") return try: @@ -153,8 +153,8 @@ def add_alert_to_json_log_file(self, alert: Alert): self.jsonfile.write("\n") except KeyboardInterrupt: return True - except Exception: - self.handle_unable_to_log() + except Exception as e: + self.handle_unable_to_log(alert, e) def add_evidence_to_json_log_file( self, @@ -166,7 +166,9 @@ def add_evidence_to_json_log_file( """ idmef_evidence: dict = self.idmefv2.convert_to_idmef_event(evidence) if not idmef_evidence: - self.handle_unable_to_log() + self.handle_unable_to_log( + evidence, "Can't convert to IDMEF evidence" + ) return try: @@ -188,8 +190,8 @@ def add_evidence_to_json_log_file( self.jsonfile.write("\n") except KeyboardInterrupt: return True - except Exception: - self.handle_unable_to_log() + except Exception as e: + self.handle_unable_to_log(evidence, e) def add_to_log_file(self, data): """ @@ -378,7 +380,9 @@ def is_blocking_modules_supported(self) -> bool: ) and blocking_module_enabled def handle_new_alert( - self, alert: Alert, evidence_causing_the_alert: Dict[str, Evidence] + self, + alert: Alert, + evidence_causing_the_alert, ): """ saves alert details in the db and informs exporting modules about it @@ -386,8 +390,8 @@ def handle_new_alert( if a profile already generated an alert in this tw, we send a blocking request (to extend its blocking period), and log the alert in the 
db only, without printing it to cli. + :param evidence_causing_the_alert: Dict[str, Evidence] """ - self.db.set_alert(alert, evidence_causing_the_alert) is_blocked: bool = self.decide_blocking( alert.profile.ip, alert.timewindow @@ -419,7 +423,9 @@ def handle_new_alert( self.log_alert(alert, blocked=is_blocked) def decide_blocking( - self, ip_to_block: str, timewindow: TimeWindow + self, + ip_to_block: str, + timewindow: TimeWindow, ) -> bool: """ Decide whether to block or not and send to the blocking module @@ -444,6 +450,10 @@ def decide_blocking( "ip": ip_to_block, "block": True, "tw": timewindow.number, + # in which localnet is this IP? to which interface does it belong? + "interface": utils.get_interface_of_ip( + ip_to_block, self.db, self.args + ), } blocking_data = json.dumps(blocking_data) self.db.publish("new_blocking", blocking_data) @@ -517,6 +527,7 @@ def main(self): twid: str = str(evidence.timewindow) evidence_type: EvidenceType = evidence.evidence_type timestamp: str = evidence.timestamp + # the database naturally has evidence before they reach # this module. and sometime when this module queries # evidence for a specific timewindow, the db returns all @@ -646,6 +657,11 @@ def main(self): "block": True, "to": True, "from": True, + # in which localnet is this IP? + # to which interface does it belong? 
+ "interface": utils.get_interface_of_ip( + key, self.db, self.args + ), } blocking_data = json.dumps(blocking_data) self.db.publish("new_blocking", blocking_data) diff --git a/slips_files/core/flows/argus.py b/slips_files/core/flows/argus.py index 8a16f7f11..34a7e097f 100644 --- a/slips_files/core/flows/argus.py +++ b/slips_files/core/flows/argus.py @@ -1,12 +1,12 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only from dataclasses import dataclass, field - from slips_files.common.slips_utils import utils +from slips_files.core.flows.base_flow import BaseFlow -@dataclass -class ArgusConn: +@dataclass(kw_only=True) +class ArgusConn(BaseFlow): starttime: str endtime: str dur: str diff --git a/slips_files/core/flows/base_flow.py b/slips_files/core/flows/base_flow.py new file mode 100644 index 000000000..9f9bc9d08 --- /dev/null +++ b/slips_files/core/flows/base_flow.py @@ -0,0 +1,10 @@ +# SPDX-FileCopyrightText: 2021 Sebastian Garcia +# SPDX-License-Identifier: GPL-2.0-only +from dataclasses import dataclass, field + + +@dataclass(kw_only=True) +class BaseFlow: + """A base class for zeek flows, containing common fields.""" + + interface: str = field(default="default") diff --git a/slips_files/core/flows/nfdump.py b/slips_files/core/flows/nfdump.py index 0ab08bcbc..cde474a06 100644 --- a/slips_files/core/flows/nfdump.py +++ b/slips_files/core/flows/nfdump.py @@ -2,10 +2,11 @@ # SPDX-License-Identifier: GPL-2.0-only from dataclasses import dataclass, field from slips_files.common.slips_utils import utils +from slips_files.core.flows.base_flow import BaseFlow -@dataclass -class NfdumpConn: +@dataclass(kw_only=True) +class NfdumpConn(BaseFlow): starttime: str endtime: str diff --git a/slips_files/core/flows/suricata.py b/slips_files/core/flows/suricata.py index 0eaa13582..5e541efa1 100644 --- a/slips_files/core/flows/suricata.py +++ b/slips_files/core/flows/suricata.py @@ -8,6 +8,8 @@ ) from slips_files.common.slips_utils import 
utils +from slips_files.core.flows.base_flow import BaseFlow + # suricata available event_type values: # -flow @@ -34,8 +36,8 @@ def get_total_pkts(flow): return flow.dpkts + flow.spkts -@dataclass -class SuricataFlow: +@dataclass(kw_only=True) +class SuricataFlow(BaseFlow): # A suricata line of flow type usually has 2 components. # 1. flow information # 2. tcp information @@ -86,8 +88,8 @@ def __post_init__(self): self.uid = str(self.uid) -@dataclass -class SuricataHTTP: +@dataclass(kw_only=True) +class SuricataHTTP(BaseFlow): starttime: str uid: str @@ -124,8 +126,8 @@ def __post_init__(self): self.uid = str(self.uid) -@dataclass -class SuricataDNS: +@dataclass(kw_only=True) +class SuricataDNS(BaseFlow): starttime: str uid: str @@ -153,8 +155,8 @@ def __post_init__(self): self.uid = str(self.uid) -@dataclass -class SuricataTLS: +@dataclass(kw_only=True) +class SuricataTLS(BaseFlow): starttime: str uid: str @@ -183,8 +185,8 @@ def __post_init__(self): self.uid = str(self.uid) -@dataclass -class SuricataFile: +@dataclass(kw_only=True) +class SuricataFile(BaseFlow): starttime: str uid: str @@ -212,8 +214,8 @@ def __post_init__(self): self.uid = str(self.uid) -@dataclass -class SuricataSSH: +@dataclass(kw_only=True) +class SuricataSSH(BaseFlow): starttime: str uid: str diff --git a/slips_files/core/flows/zeek.py b/slips_files/core/flows/zeek.py index fed684181..d49c2467e 100644 --- a/slips_files/core/flows/zeek.py +++ b/slips_files/core/flows/zeek.py @@ -13,8 +13,15 @@ from slips_files.common.slips_utils import utils -@dataclass -class Conn: +@dataclass(kw_only=True) +class BaseFlow: + """A base class for zeek flows, containing common fields.""" + + interface: str = field(default="default") + + +@dataclass(kw_only=True) +class Conn(BaseFlow): starttime: str uid: str saddr: str @@ -61,8 +68,8 @@ def __post_init__(self) -> None: self.proto = "" -@dataclass -class DNS: +@dataclass(kw_only=True) +class DNS(BaseFlow): starttime: str uid: str saddr: str @@ -93,8 +100,8 @@ 
def __post_init__(self) -> None: ) -@dataclass -class HTTP: +@dataclass(kw_only=True) +class HTTP(BaseFlow): starttime: str uid: str saddr: str @@ -125,8 +132,8 @@ def __post_init__(self) -> None: pass -@dataclass -class SSL: +@dataclass(kw_only=True) +class SSL(BaseFlow): starttime: str uid: str saddr: str @@ -160,8 +167,8 @@ class SSL: type_: str = "ssl" -@dataclass -class SSH: +@dataclass(kw_only=True) +class SSH(BaseFlow): starttime: float uid: str saddr: str @@ -188,8 +195,8 @@ class SSH: type_: str = "ssh" -@dataclass -class DHCP: +@dataclass(kw_only=True) +class DHCP(BaseFlow): starttime: float uids: List[str] client_addr: str @@ -215,8 +222,8 @@ def __post_init__(self) -> None: self.saddr = self.smac -@dataclass -class FTP: +@dataclass(kw_only=True) +class FTP(BaseFlow): starttime: float uid: str saddr: str @@ -230,8 +237,8 @@ class FTP: type_: str = "ftp" -@dataclass -class SMTP: +@dataclass(kw_only=True) +class SMTP(BaseFlow): starttime: float uid: str saddr: str @@ -245,8 +252,8 @@ class SMTP: type_: str = "smtp" -@dataclass -class Tunnel: +@dataclass(kw_only=True) +class Tunnel(BaseFlow): starttime: str uid: str saddr: str @@ -264,8 +271,8 @@ class Tunnel: type_: str = "tunnel" -@dataclass -class Notice: +@dataclass(kw_only=True) +class Notice(BaseFlow): starttime: str saddr: str daddr: str @@ -311,8 +318,8 @@ def __post_init__(self) -> None: self.dport = self.dport -@dataclass -class Files: +@dataclass(kw_only=True) +class Files(BaseFlow): starttime: str uid: str saddr: str @@ -346,8 +353,8 @@ def __post_init__(self) -> None: self.daddr = daddr -@dataclass -class ARP: +@dataclass(kw_only=True) +class ARP(BaseFlow): starttime: str uid: str saddr: str @@ -382,8 +389,8 @@ class ARP: type_: str = "arp" -@dataclass -class Software: +@dataclass(kw_only=True) +class Software(BaseFlow): starttime: str uid: str saddr: str @@ -407,8 +414,8 @@ def __post_init__(self) -> None: self.http_browser = self.software == "HTTP::BROWSER" -@dataclass -class Weird: 
+@dataclass(kw_only=True) +class Weird(BaseFlow): starttime: str uid: str saddr: str diff --git a/slips_files/core/helpers/checker.py b/slips_files/core/helpers/checker.py index 470e699a0..be7792734 100644 --- a/slips_files/core/helpers/checker.py +++ b/slips_files/core/helpers/checker.py @@ -10,19 +10,22 @@ class Checker: def __init__(self, main): self.main = main - def check_input_type(self) -> tuple: + def get_input_type(self) -> tuple: """ returns line_type, input_type, input_information - supported input types are: + supported input_type values are: interface, argus, suricata, zeek, nfdump, db - supported self.input_information: - given filepath, interface or type of line given in stdin + supported input_information: + given filepath, interface or type of line given in stdin, + comma separated access point interfaces like wlan0,eth0 """ # only defined in stdin lines line_type = False - # -I - if self.main.args.interface: - input_information = self.main.args.interface + # -i or -ap + if self.main.args.interface or self.main.args.access_point: + input_information = ( + self.main.args.interface or self.main.args.access_point + ) input_type = "interface" # return input_type, self.main.input_information return input_type, input_information, line_type @@ -56,37 +59,112 @@ def check_input_type(self) -> tuple: return input_type, input_information, line_type - def check_given_flags(self): - """ - check the flags that don't require starting slips - for example: clear db, clearing the blocking chain, killing all - servers, etc. 
- """ + def _print_help_and_exit(self): + """prints the help msg and shutd down slips""" + self.main.print_version() + arg_parser = self.main.conf.get_parser(help=True) + arg_parser.parse_arguments() + arg_parser.print_help() + self.main.terminate_slips() - if self.main.args.help: - self.main.print_version() - arg_parser = self.main.conf.get_parser(help=True) - arg_parser.parse_arguments() - arg_parser.print_help() - self.main.terminate_slips() - if self.main.args.interface and self.main.args.filepath: - print("Only -i or -f is allowed. Stopping slips.") - self.main.terminate_slips() - return + def _check_mutually_exclusive_flags(self): + """checks if the user provided args that shouldnt be used together""" + mutually_exclusive_flags = [ + self.main.args.interface, # -i + self.main.args.access_point, # -ap + self.main.args.save, # -s + self.main.args.db, # -d + self.main.args.filepath, # -f + self.main.args.input_module, # -im + ] + + # Count how many of the flags are set (True) + mutually_exclusive_flag_count = sum( + bool(flag) for flag in mutually_exclusive_flags + ) - if ( - self.main.args.interface or self.main.args.filepath - ) and self.main.args.input_module: + if mutually_exclusive_flag_count > 1: print( - "You can't use --input-module with -f or -i. Stopping slips." + "Only one of the flags -i, -ap, -s, -d, or -f is allowed. " + "Stopping slips." 
) self.main.terminate_slips() return + def _check_if_root_is_required(self): if (self.main.args.save or self.main.args.db) and os.getuid() != 0: print("Saving and loading the database requires root privileges.") self.main.terminate_slips() return + if ( + self.main.args.interface + and self.main.args.blocking + and os.geteuid() != 0 + ): + # If the user wants to blocks, we need permission to modify + # iptables + print("Run Slips with sudo to use the blocking modules.") + self.main.terminate_slips() + return + + if self.main.args.clearblocking: + if os.geteuid() != 0: + print( + "Slips needs to be run as root to clear the slipsBlocking " + "chain. Stopping." + ) + else: + self.delete_blocking_chain() + self.main.terminate_slips() + return + + def _check_interface_validity(self): + """checks if the given interface/s are valid""" + interfaces = psutil.net_if_addrs().keys() + if self.main.args.interface: + if self.main.args.interface not in interfaces: + print( + f"{self.main.args.interface} is not a valid interface. " + f"Stopping Slips" + ) + self.main.terminate_slips() + return + + if self.main.args.access_point: + for interface in self.main.args.access_point.split(","): + if interface not in interfaces: + print( + f"{interface} is not a valid interface." + f" Stopping Slips" + ) + self.main.terminate_slips() + return + + def _is_slips_running_non_stop(self) -> bool: + """determines if slips is monitoring real time traffic based oin + the giving params""" + return ( + self.main.args.interface + or self.main.args.access_point + or self.main.args.growing + or self.main.args.input_module + ) + + def verify_given_flags(self): + """ + Checks the validity of the given flags. 
+ """ + if self.main.args.help: + self._print_help_and_exit() + + if self.main.args.version: + self.main.print_version() + self.main.terminate_slips() + return + + self._check_mutually_exclusive_flags() + self._check_if_root_is_required() + self._check_interface_validity() if (self.main.args.verbose and int(self.main.args.verbose) > 3) or ( self.main.args.debug and int(self.main.args.debug) > 3 @@ -124,16 +202,6 @@ def check_given_flags(self): ) return - if self.main.args.interface: - interfaces = psutil.net_if_addrs().keys() - if self.main.args.interface not in interfaces: - print( - f"{self.main.args.interface} is not a valid interface. " - f"Stopping Slips" - ) - self.main.terminate_slips() - return - # if we're reading flows from some module other than the input # process, make sure it exists if self.main.args.input_module and not self.input_module_exists( @@ -148,9 +216,10 @@ def check_given_flags(self): return # Clear cache if the parameter was included - if self.main.args.blocking and not self.main.args.interface: + if self.main.args.blocking and not self._is_slips_running_non_stop(): print( - "Blocking is only allowed when running slips using an interface." + "Blocking is only allowed when running slips on real time " + "traffic. (running with -i, -ap, -im, or -g)" ) self.main.terminate_slips() return @@ -161,39 +230,6 @@ def check_given_flags(self): self.main.terminate_slips() return - if self.main.args.version: - self.main.print_version() - self.main.terminate_slips() - return - - if ( - self.main.args.interface - and self.main.args.blocking - and os.geteuid() != 0 - ): - # If the user wants to blocks, we need permission to modify - # iptables - print("Run Slips with sudo to use the blocking modules.") - self.main.terminate_slips() - return - - if self.main.args.clearblocking: - if os.geteuid() != 0: - print( - "Slips needs to be run as root to clear the slipsBlocking " - "chain. Stopping." 
- ) - else: - self.delete_blocking_chain() - self.main.terminate_slips() - return - - # Check if user want to save and load a db at the same time - if self.main.args.save and self.main.args.db: - print("Can't use -s and -d together") - self.main.terminate_slips() - return - def delete_blocking_chain(self): from modules.blocking.slips_chain_manager import ( del_slips_blocking_chain, diff --git a/slips_files/core/helpers/filemonitor.py b/slips_files/core/helpers/filemonitor.py index de9e9694d..c9726a0b9 100644 --- a/slips_files/core/helpers/filemonitor.py +++ b/slips_files/core/helpers/filemonitor.py @@ -27,17 +27,18 @@ class FileEventHandler(RegexMatchingEventHandler): REGEX = [r".*\.log$", r".*\.conf$"] - def __init__(self, dir_to_monitor, input_type, db): + def __init__(self, dir_to_monitor, db, pcap_or_interface): super().__init__(regexes=self.REGEX) self.dir_to_monitor = dir_to_monitor + # name of the pcap or interface zeek is monitoring + self.pcap_or_interface = pcap_or_interface self.db = db - self.input_type = input_type def on_created(self, event): """this will be triggered everytime zeek creates a log file""" filename, ext = os.path.splitext(event.src_path) if "log" in ext: - self.db.add_zeek_file(filename + ext) + self.db.add_zeek_file(filename + ext, self.pcap_or_interface) def on_moved(self, event): """ @@ -60,6 +61,7 @@ def on_modified(self, event): # so if zeek receives a termination signal, # slips would know about it filename, ext = os.path.splitext(event.src_path) + if "reporter" in filename: # check if it's a termination signal # get the exact file name (a ts is appended to it) diff --git a/slips_files/core/helpers/flow_handler.py b/slips_files/core/helpers/flow_handler.py index 2ccb1a7c3..50e2af75a 100644 --- a/slips_files/core/helpers/flow_handler.py +++ b/slips_files/core/helpers/flow_handler.py @@ -131,7 +131,9 @@ def handle_conn(self): # store the original flow as benign in sqlite self.db.add_flow(self.flow, self.profileid, self.twid, 
"benign") - self.db.add_mac_addr_to_profile(self.profileid, self.flow.smac) + self.db.add_mac_addr_to_profile( + self.profileid, self.flow.smac, self.flow.interface + ) if self.running_non_stop: # to avoid publishing duplicate MACs, when running on @@ -170,15 +172,17 @@ def handle_notice(self): # foirst check if the gw ip and mac are set by # profiler.get_gateway_info() or ip_info module gw_ip = False - if not self.db.get_gateway_ip(): + if not self.db.get_gateway_ip(self.flow.interface): # get the gw addr from the msg gw_ip = self.flow.msg.split(": ")[-1].strip() - self.db.set_default_gateway("IP", gw_ip) + self.db.set_default_gateway("IP", gw_ip, self.flow.interface) - if not self.db.get_gateway_mac() and gw_ip: + if not self.db.get_gateway_mac(self.flow.interface) and gw_ip: gw_mac = self.db.get_mac_addr_from_profile(f"profile_{gw_ip}") if gw_mac: - self.db.set_default_gateway("MAC", gw_mac) + self.db.set_default_gateway( + "MAC", gw_mac, self.flow.interface + ) self.db.add_altflow(self.flow, self.profileid, self.twid, "benign") @@ -216,7 +220,9 @@ def handle_dhcp(self): self.flow.saddr, ) - self.db.add_mac_addr_to_profile(self.profileid, self.flow.smac) + self.db.add_mac_addr_to_profile( + self.profileid, self.flow.smac, self.flow.interface + ) if self.flow.server_addr: self.db.store_dhcp_server(self.flow.server_addr) @@ -264,7 +270,9 @@ def handle_arp(self): # send to arp module to_send = json.dumps(to_send) self.db.publish("new_arp", to_send) - self.db.add_mac_addr_to_profile(self.profileid, self.flow.smac) + self.db.add_mac_addr_to_profile( + self.profileid, self.flow.smac, self.flow.interface + ) self.publisher.new_MAC(self.flow.dmac, self.flow.daddr) self.publisher.new_MAC(self.flow.smac, self.flow.saddr) self.db.add_altflow(self.flow, self.profileid, self.twid, "benign") diff --git a/slips_files/core/input.py b/slips_files/core/input.py index 0e4563dd5..2032e90c6 100644 --- a/slips_files/core/input.py +++ b/slips_files/core/input.py @@ -27,6 +27,7 @@ # 
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA. # Contact: eldraco@gmail.com, sebastian.garcia@agents.fel.cvut.cz, stratosphere@aic.fel.cvut.cz from re import split +from typing import Dict, List from watchdog.observers import Observer @@ -36,8 +37,11 @@ from slips_files.common.parsers.config_parser import ConfigParser from slips_files.common.slips_utils import utils import multiprocessing + +from slips_files.common.style import yellow from slips_files.core.helpers.filemonitor import FileEventHandler from slips_files.core.supported_logfiles import SUPPORTED_LOGFILES +from slips_files.core.zeek_cmd_builder import ZeekCommandBuilder class Input(ICore): @@ -80,6 +84,7 @@ def init( self.testing = False # number of lines read self.lines = 0 + self.zeek_pids = [] # create the remover thread self.remover_thread = threading.Thread( @@ -93,9 +98,8 @@ def init( self.timeout = None # zeek rotated files to be deleted after a period of time self.to_be_deleted = [] - self.zeek_thread = threading.Thread( - target=self.run_zeek, daemon=True, name="run_zeek_thread" - ) + self.zeek_threads = [] + # is set by the profiler to tell this proc that we it is done processing # the input process and shut down and close the profiler queue no issue self.is_profiler_done_event = is_profiler_done_event @@ -239,7 +243,7 @@ def get_ts_from_line(self, zeek_line: str): return timestamp, nline - def cache_nxt_line_in_file(self, filename: str): + def cache_nxt_line_in_file(self, filename: str, interface: str): """ reads 1 line of the given file and stores in queue for sending to the profiler :param: full path to the file. 
includes the .log extension @@ -283,7 +287,11 @@ def cache_nxt_line_in_file(self, filename: str): self.file_time[filename] = timestamp # Store the line in the cache - self.cache_lines[filename] = {"type": filename, "data": nline} + self.cache_lines[filename] = { + "type": filename, + "data": nline, + "interface": interface, + } return True def reached_timeout(self) -> bool: @@ -328,7 +336,7 @@ def get_earliest_line(self): # It may happen that we check all the files in the folder, # and there is still no files for us. # To cover this case, just refresh the list of files - self.zeek_files = self.db.get_all_zeek_files() + self.zeek_files: Dict[str, str] = self.db.get_all_zeek_files() return False, False # comes here if we're done with all conn.log flows and it's time to @@ -338,7 +346,7 @@ def get_earliest_line(self): def read_zeek_files(self) -> int: try: - self.zeek_files = self.db.get_all_zeek_files() + self.zeek_files: Dict[str, str] = self.db.get_all_zeek_files() self.open_file_handlers = {} # stores zeek_log_file_name: timestamp of the last flow read from # that file @@ -350,13 +358,13 @@ def read_zeek_files(self) -> int: self.check_if_time_to_del_rotated_files() # Go to all the files generated by Zeek and read 1 # line from each of them - for filename in self.zeek_files: + for filename, interface in self.zeek_files.items(): if utils.is_ignored_zeek_log_file(filename): continue # reads 1 line from the given file and cache it # from in self.cache_lines - self.cache_nxt_line_in_file(filename) + self.cache_nxt_line_in_file(filename, interface) if self.reached_timeout(): break @@ -380,7 +388,7 @@ def read_zeek_files(self) -> int: # Get the new list of files. Since new files may have been created by # Zeek while we were processing them. 
- self.zeek_files = self.db.get_all_zeek_files() + self.zeek_files: Dict[str, str] = self.db.get_all_zeek_files() self.close_all_handles() except KeyboardInterrupt: @@ -422,7 +430,11 @@ def get_flows_number(self, file: str) -> int: return count def read_zeek_folder(self): - # This is the case that a folder full of zeek files is passed with -f + """ + This is the case that a folder full of zeek files is passed with -f + DISCLAIMER: this func does not run when slips is running on an + interface with -i or -ap + """ # wait max 10 seconds before stopping slips if no new flows are read self.bro_timeout = 10 growing_zeek_dir: bool = self.db.is_growing_zeek_dir() @@ -433,7 +445,11 @@ def read_zeek_folder(self): self.bro_timeout = float("inf") self.zeek_dir = self.given_path - self.start_observer() + if self.args.growing: + interface = utils.infer_used_interface() + else: + interface = "default" + self.start_observer(self.zeek_dir, interface) # if 1 file is zeek tabs the rest should be the same if not hasattr(self, "is_zeek_tabs"): @@ -455,7 +471,7 @@ def read_zeek_folder(self): total_flows += self.get_flows_number(full_path) # Add log file to the database - self.db.add_zeek_file(full_path) + self.db.add_zeek_file(full_path, interface) # in testing mode, we only need to read one zeek file to know # that this function is working correctly @@ -600,7 +616,7 @@ def handle_zeek_log_file(self): self.total_flows = total_flows # Add log file to database - self.db.add_zeek_file(self.given_path) + self.db.add_zeek_file(self.given_path, "default") # this timeout is the only thing that # makes the read_zeek_files() return @@ -622,72 +638,136 @@ def handle_nfdump(self): self.mark_self_as_done_processing() return True - def start_observer(self): + def start_observer(self, zeek_dir: str, pcap_or_interface: str): + """ + :param zeek_dir: directory to monitor + """ # Now start the observer of new files. 
We need the observer because Zeek does not create all the files # at once, but when the traffic appears. That means that we need # some process to tell us which files to read in real time when they appear # Get the file eventhandler # We have to set event_handler and event_observer before running zeek. - event_handler = FileEventHandler( - self.zeek_dir, self.input_type, self.db - ) - # Create an observer + event_handler = FileEventHandler(zeek_dir, self.db, pcap_or_interface) + self.event_observer = Observer() # Schedule the observer with the callback on the file handler - self.event_observer.schedule( - event_handler, self.zeek_dir, recursive=True - ) + self.event_observer.schedule(event_handler, zeek_dir, recursive=True) # monitor changes to whitelist self.event_observer.schedule(event_handler, "config/", recursive=True) # Start the observer self.event_observer.start() - def handle_pcap_and_interface(self) -> int: - """Returns the number of zeek lines read""" - - # Create zeek_folder if does not exist. + def handle_pcap_and_interface(self) -> bool: + """ + runs when slips is given a pcap with -f, an interface with -i, + or 2 interfaces with -ap + """ if not os.path.exists(self.zeek_dir): os.makedirs(self.zeek_dir) self.print(f"Storing zeek log files in {self.zeek_dir}") - self.start_observer() if self.input_type == "interface": + # slips is running with -i or -ap # We don't want to stop bro if we read from an interface self.bro_timeout = float("inf") + # format is {interface: zeek_dir_path} + interfaces_to_monitor = {} + if self.args.interface: + interfaces_to_monitor.update( + { + self.args.interface: { + "dir": self.zeek_dir, + "type": "main_interface", + } + } + ) + + elif self.args.access_point: + # slips is running in AP mode, we need to monitor the 2 + # interfaces, wifi and eth. 
+ for _type, interface in self.db.get_ap_info().items(): + # _type can be 'wifi_interface' or "ethernet_interface" + dir_to_store_interface_logs = os.path.join( + self.zeek_dir, interface + ) + interfaces_to_monitor.update( + { + interface: { + "dir": dir_to_store_interface_logs, + "type": _type, + } + } + ) + for interface, interface_info in interfaces_to_monitor.items(): + interface_dir = interface_info["dir"] + if not os.path.exists(interface_dir): + os.makedirs(interface_dir) + + if interface_info["type"] == "ethernet_interface": + cidr = utils.get_cidr_of_interface(interface) + tcpdump_filter = f"dst net {cidr}" + logline = yellow( + f"Zeek is logging incoming traffic only " + f"for interface: {interface}." + ) + self.print(logline) + else: + tcpdump_filter = None + logline = yellow( + f"Zeek is logging all traffic on interface:" + f" {interface}." + ) + self.print(logline) + + self.init_zeek( + interface_dir, interface, tcpdump_filter=tcpdump_filter + ) + elif self.input_type == "pcap": # This is for stopping the inputprocess # if bro does not receive any new line while reading a pcap self.bro_timeout = 30 + self.init_zeek(self.zeek_dir, self.given_path) - zeek_files = os.listdir(self.zeek_dir) + self.lines = self.read_zeek_files() + self.print_lines_read() + self.mark_self_as_done_processing() + self.stop_observer() + return True + + def init_zeek( + self, zeek_dir: str, pcap_or_interface: str, tcpdump_filter=None + ): + """ + :param pcap_or_interface: name of the pcap or interface zeek + is going to run on + + PS: this function contains a call to self.read_zeek_files that + keeps running until slips stops + """ + self.start_observer(zeek_dir, pcap_or_interface) + + zeek_files = os.listdir(zeek_dir) if len(zeek_files) > 0: # First clear the zeek folder of old .log files for f in zeek_files: - os.remove(os.path.join(self.zeek_dir, f)) + os.remove(os.path.join(zeek_dir, f)) - # run zeek - self.zeek_thread.start() + zeek_thread = threading.Thread( + 
target=self.run_zeek, + args=(zeek_dir, pcap_or_interface), + kwargs={"tcpdump_filter": tcpdump_filter}, + daemon=True, + name="run_zeek_thread", + ) + zeek_thread.start() + self.zeek_threads.append(zeek_thread) # Give Zeek some time to generate at least 1 file. time.sleep(3) - self.db.store_pid("Zeek", self.zeek_pid) + self.db.store_pid(f"Zeek_{pcap_or_interface}", self.zeek_pids[-1]) if not hasattr(self, "is_zeek_tabs"): self.is_zeek_tabs = False - self.lines = self.read_zeek_files() - self.print_lines_read() - self.mark_self_as_done_processing() - - connlog_path = os.path.join(self.zeek_dir, "conn.log") - - self.print( - f"Number of zeek generated flows in conn.log: " - f"{self.get_flows_number(connlog_path)}", - 2, - 0, - ) - - self.stop_observer() - return True def stop_observer(self): # Stop the observer @@ -756,99 +836,66 @@ def shutdown_gracefully(self): except Exception: pass try: - self.zeek_thread.join(3) + for zeek_thread in self.zeek_threads: + zeek_thread.join(3) except Exception: pass if hasattr(self, "open_file_handlers"): self.close_all_handles() - if hasattr(self, "zeek_pid"): - # kill zeek manually if it started bc it's detached from this - # process and will never recv the sigint also withoutt this, - # inputproc will never shutdown and will always remain in memory - # causing 1000 bugs in proc_man:shutdown_gracefully() + # kill zeek manually if it started bc it's detached from this + # process and will never recv the sigint. 
+ # also without this, inputproc will never shutdown and will + # always remain in memory causing 1000 bugs in + # proc_man:shutdown_gracefully() + for pid in self.zeek_pids: try: - os.kill(self.zeek_pid, signal.SIGKILL) + os.kill(pid, signal.SIGKILL) except Exception: pass return True - def run_zeek(self): + def _construct_zeek_cmd( + self, pcap_or_interface: str, tcpdump_filter=None + ) -> List[str]: """ - This thread sets the correct zeek parameters and starts zeek + constructs the zeek command based on the user given + pcap/interface/packet filter/etc. """ - - def detach_child(): - """ - Detach zeek from the parent process group(inputprocess), the child(zeek) - will no longer receive signals - """ - # we're doing this to fix zeek rotating on sigint, not when zeek has it's own - # process group, it won't get the signals sent to slips.py - os.setpgrp() - - # rotation is disabled unless it's an interface - rotation = [] - if self.input_type == "interface": - if self.enable_rotation: - # how often to rotate zeek files? taken from slips.yaml - rotation = [ - "-e", - f"redef Log::default_rotation_interval = {self.rotation_period} ;", - ] - bro_parameter = ["-i", self.given_path] - - elif self.input_type == "pcap": - # Find if the pcap file name was absolute or relative - given_path = self.given_path - if not os.path.isabs(self.given_path): - # now the given pcap is relative to slips main dir - # slips can store the zeek logs dir either in the - # output dir (by default in Slips/output/_/zeek_files/), - # or in any dir specified with -o - # construct an abs path from the given path so slips can find the given pcap - # no matter where the zeek dir is placed - given_path = os.path.join(os.getcwd(), self.given_path) - - # using a list of params instead of a str for storing the cmd - # becaus ethe given path may contain spaces - bro_parameter = ["-r", given_path] - - # Run zeek on the pcap or interface. 
The redef is to have json files - zeek_scripts_dir = os.path.join(os.getcwd(), "zeek-scripts") - packet_filter = ( - ["-f ", self.packet_filter] if self.packet_filter else [] + builder = ZeekCommandBuilder( + zeek_or_bro=self.zeek_or_bro, + input_type=self.input_type, + rotation_period=self.rotation_period, + enable_rotation=self.enable_rotation, + tcp_inactivity_timeout=self.tcp_inactivity_timeout, + packet_filter=self.packet_filter, ) - # 'local' is removed from the command because it - # loads policy/protocols/ssl/expiring-certs and - # and policy/protocols/ssl/validate-certs and they have conflicts with our own - # zeek-scripts/expiring-certs and validate-certs - # we have our own copy pf local.zeek in __load__.zeek - command = [self.zeek_or_bro, "-C"] - command += bro_parameter - command += [ - f"tcp_inactivity_timeout={self.tcp_inactivity_timeout}mins", - "tcp_attempt_delay=1min", - zeek_scripts_dir, - ] - command += rotation - command += packet_filter - - self.print(f'Zeek command: {" ".join(command)}', 3, 0) + cmd = builder.build(pcap_or_interface, tcpdump_filter=tcpdump_filter) + return cmd + + def run_zeek(self, zeek_logs_dir, pcap_or_interface, tcpdump_filter=None): + """ + This thread sets the correct zeek parameters and starts zeek + :kwarg tcpdump_filter: optional tcp filter to use when + starting zeek with -f + """ + command = self._construct_zeek_cmd(pcap_or_interface, tcpdump_filter) + str_cmd = " ".join(command) + self.print(f"Zeek command: {str_cmd}", 3, 0) zeek = subprocess.Popen( command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE, - cwd=self.zeek_dir, + cwd=zeek_logs_dir, start_new_session=True, ) # you have to get the pid before communicate() - self.zeek_pid = zeek.pid + self.zeek_pids.append(zeek.pid) out, error = zeek.communicate() if out: diff --git a/slips_files/core/input_profilers/argus.py b/slips_files/core/input_profilers/argus.py index 27ea9d727..14c278efb 100644 --- 
a/slips_files/core/input_profilers/argus.py +++ b/slips_files/core/input_profilers/argus.py @@ -41,23 +41,24 @@ def get_value_of(field_name, default_=False): return default_ self.flow: ArgusConn = ArgusConn( - utils.convert_to_datetime(get_value_of("starttime")), - get_value_of("endtime"), - get_value_of("dur"), - get_value_of("proto"), - get_value_of("appproto"), - get_value_of("saddr"), - get_value_of("sport"), - get_value_of("dir"), - get_value_of("daddr"), - get_value_of("dport"), - get_value_of("state"), - int(get_value_of("pkts")), - int(get_value_of("spkts")), - int(get_value_of("dpkts")), - int(get_value_of("bytes")), - int(get_value_of("sbytes")), - int(get_value_of("dbytes")), + starttime=utils.convert_to_datetime(get_value_of("starttime")), + endtime=get_value_of("endtime"), + dur=get_value_of("dur"), + proto=get_value_of("proto"), + appproto=get_value_of("appproto"), + saddr=get_value_of("saddr"), + sport=get_value_of("sport"), + dir_=get_value_of("dir"), + daddr=get_value_of("daddr"), + dport=get_value_of("dport"), + state=get_value_of("state"), + pkts=int(get_value_of("pkts")), + spkts=int(get_value_of("spkts")), + dpkts=int(get_value_of("dpkts")), + bytes=int(get_value_of("bytes")), + sbytes=int(get_value_of("sbytes")), + dbytes=int(get_value_of("dbytes")), + interface="default", ) return self.flow diff --git a/slips_files/core/input_profilers/nfdump.py b/slips_files/core/input_profilers/nfdump.py index 80d9533d9..f7eca0415 100644 --- a/slips_files/core/input_profilers/nfdump.py +++ b/slips_files/core/input_profilers/nfdump.py @@ -29,19 +29,19 @@ def get_value_at(indx, default_=False): starttime = utils.convert_ts_format(get_value_at(0), "unixtimestamp") endtime = utils.convert_ts_format(get_value_at(1), "unixtimestamp") self.flow: NfdumpConn = NfdumpConn( - starttime, - endtime, - get_value_at(2), - get_value_at(7), - get_value_at(3), - get_value_at(5), - get_value_at(22), - get_value_at(4), - get_value_at(6), - get_value_at(8), - get_value_at(11), 
- get_value_at(13), - get_value_at(12), - get_value_at(14), + starttime=starttime, + endtime=endtime, + dur=get_value_at(2), + proto=get_value_at(7), + saddr=get_value_at(3), + sport=get_value_at(5), + dir_=get_value_at(22), + daddr=get_value_at(4), + dport=get_value_at(6), + state=get_value_at(8), + spkts=get_value_at(11), + dpkts=get_value_at(13), + sbytes=get_value_at(12), + dbytes=get_value_at(14), ) return self.flow diff --git a/slips_files/core/input_profilers/suricata.py b/slips_files/core/input_profilers/suricata.py index ca2eeb592..00998ef13 100644 --- a/slips_files/core/input_profilers/suricata.py +++ b/slips_files/core/input_profilers/suricata.py @@ -47,7 +47,7 @@ def process_line(self, line) -> None: if not line: return - # these fields are common in all suricata lines regardless of the event type + event_type = line["event_type"] flow_id = line["flow_id"] saddr = line["src_ip"] @@ -82,46 +82,49 @@ def get_value_at(field, subfield, default_=False): endtime = utils.convert_ts_format( get_value_at("flow", "end"), "unixtimestamp" ) - self.flow: SuricataFlow = SuricataFlow( - flow_id, - saddr, - sport, - daddr, - dport, - proto, - appproto, - starttime, - endtime, - int(get_value_at("flow", "pkts_toserver", 0)), - int(get_value_at("flow", "pkts_toclient", 0)), - int(get_value_at("flow", "bytes_toserver", 0)), - int(get_value_at("flow", "bytes_toclient", 0)), - get_value_at("flow", "state", ""), + + self.flow = SuricataFlow( + uid=flow_id, + saddr=saddr, + sport=sport, + daddr=daddr, + dport=dport, + proto=proto, + appproto=appproto, + starttime=starttime, + endtime=endtime, + spkts=int(get_value_at("flow", "pkts_toserver", 0)), + dpkts=int(get_value_at("flow", "pkts_toclient", 0)), + sbytes=int(get_value_at("flow", "bytes_toserver", 0)), + dbytes=int(get_value_at("flow", "bytes_toclient", 0)), + state=get_value_at("flow", "state", ""), ) elif event_type == "http": - self.flow: SuricataHTTP = SuricataHTTP( - timestamp, - flow_id, - saddr, - sport, - daddr, - 
dport, - proto, - appproto, - get_value_at("http", "http_method", ""), - get_value_at("http", "hostname", ""), - get_value_at("http", "url", ""), - get_value_at("http", "http_user_agent", ""), - get_value_at("http", "status", ""), - get_value_at("http", "protocol", ""), - int(get_value_at("http", "request_body_len", 0)), - int(get_value_at("http", "length", 0)), + self.flow = SuricataHTTP( + starttime=timestamp, + uid=flow_id, + saddr=saddr, + sport=sport, + daddr=daddr, + dport=dport, + proto=proto, + appproto=appproto, + method=get_value_at("http", "http_method", ""), + host=get_value_at("http", "hostname", ""), + uri=get_value_at("http", "url", ""), + user_agent=get_value_at("http", "http_user_agent", ""), + status_code=get_value_at("http", "status", ""), + version=get_value_at("http", "protocol", ""), + request_body_len=int( + get_value_at("http", "request_body_len", 0) + ), + response_body_len=int(get_value_at("http", "length", 0)), ) elif event_type == "dns": - answers: list = self.get_answers(line) - self.flow: SuricataDNS = SuricataDNS( + answers = self.get_answers(line) + self.flow = SuricataDNS( starttime=timestamp, uid=flow_id, saddr=saddr, @@ -132,55 +135,63 @@ def get_value_at(field, subfield, default_=False): appproto=appproto, query=get_value_at("dns", "rrname", ""), TTLs=get_value_at("dns", "ttl", ""), - qtype_name=get_value_at("qtype_name", "rrtype", ""), + qtype_name=get_value_at("dns", "rrtype", ""), answers=answers, ) elif event_type == "tls": - self.flow: SuricataTLS = SuricataTLS( - timestamp, - flow_id, - saddr, - sport, - daddr, - dport, - proto, - appproto, - get_value_at("tls", "version", ""), - get_value_at("tls", "subject", ""), - get_value_at("tls", "issuerdn", ""), - get_value_at("tls", "sni", ""), - get_value_at("tls", "notbefore", ""), - get_value_at("tls", "notafter", ""), - get_value_at("tls", "sni", ""), + self.flow = SuricataTLS( + starttime=timestamp, + uid=flow_id, + saddr=saddr, + sport=sport, + daddr=daddr, + dport=dport, + 
proto=proto, + appproto=appproto, + sslversion=get_value_at("tls", "version", ""), + subject=get_value_at("tls", "subject", ""), + issuer=get_value_at("tls", "issuerdn", ""), + server_name=get_value_at("tls", "sni", ""), + notbefore=get_value_at("tls", "notbefore", ""), + notafter=get_value_at("tls", "notafter", ""), ) elif event_type == "fileinfo": - self.flow: SuricataFile = SuricataFile( - timestamp, - flow_id, - saddr, - sport, - daddr, - dport, - proto, - appproto, - get_value_at("fileinfo", "size", ""), + self.flow = SuricataFile( + starttime=timestamp, + uid=flow_id, + saddr=saddr, + sport=sport, + daddr=daddr, + dport=dport, + proto=proto, + appproto=appproto, + size=int(get_value_at("fileinfo", "size", 0)), ) + elif event_type == "ssh": - self.flow: SuricataSSH = SuricataSSH( - timestamp, - flow_id, - saddr, - sport, - daddr, - dport, - proto, - appproto, - get_value_at("ssh", "client", {}).get("software_version", ""), - get_value_at("ssh", "client", {}).get("proto_version", ""), - get_value_at("ssh", "server", {}).get("software_version", ""), + self.flow = SuricataSSH( + starttime=timestamp, + uid=flow_id, + saddr=saddr, + sport=sport, + daddr=daddr, + dport=dport, + proto=proto, + appproto=appproto, + client=get_value_at("ssh", "client", {}).get( + "software_version", "" + ), + version=get_value_at("ssh", "client", {}).get( + "proto_version", "" + ), + server=get_value_at("ssh", "server", {}).get( + "software_version", "" + ), ) + else: return False + return self.flow diff --git a/slips_files/core/input_profilers/zeek.py b/slips_files/core/input_profilers/zeek.py index 21b661957..5c715c8ae 100644 --- a/slips_files/core/input_profilers/zeek.py +++ b/slips_files/core/input_profilers/zeek.py @@ -151,6 +151,7 @@ def __init__(self): def process_line(self, new_line: dict): line = new_line["data"] + interface = new_line["interface"] if not isinstance(line, dict): return False @@ -165,7 +166,7 @@ def process_line(self, new_line: dict): else: starttime = "" - 
flow_values = {"starttime": starttime} + flow_values = {"starttime": starttime, "interface": interface} for zeek_field, slips_field in line_map.items(): if not slips_field: diff --git a/slips_files/core/profiler.py b/slips_files/core/profiler.py index 4619f940b..42f1fd006 100644 --- a/slips_files/core/profiler.py +++ b/slips_files/core/profiler.py @@ -28,6 +28,7 @@ List, Union, Optional, + Dict, ) import netifaces @@ -87,7 +88,7 @@ def init( self.timeformat = None self.input_type = False self.rec_lines = 0 - self.is_localnet_set = False + self.localnet_cache = {} self.whitelist = Whitelist(self.logger, self.db) self.read_configuration() self.symbol = SymbolHandler(self.logger, self.db) @@ -101,8 +102,10 @@ def init( # is set by this proc to tell input proc that we are done # processing and it can exit no issue self.is_profiler_done_event = is_profiler_done_event - self.gw_mac = None - self.gw_ip = None + # stores the MAC addresses of the gateway of each interface + # will have interfaces as keys, and MACs as values + self.gw_macs = {} + self.gw_ips = {} self.profiler_threads = [] self.stop_profiler_threads = multiprocessing.Event() # each msg received from inputprocess will be put here, and each one @@ -162,14 +165,14 @@ def get_rev_profile(self, flow): rev_twid: str = self.db.get_timewindow(flow.starttime, rev_profileid) return rev_profileid, rev_twid - def get_gw_ip_using_gw_mac(self) -> Optional[str]: + def get_gw_ip_using_gw_mac(self, gw_mac) -> Optional[str]: """ gets the ip of the given mac from the db prioritizes returning the ipv4. if not found, the function returns the ipv6. or none if both are not found. 
""" # the db returns a serialized list of IPs belonging to this mac - gw_ips: str = self.db.get_ip_of_mac(self.gw_mac) + gw_ips: str = self.db.get_ip_of_mac(gw_mac) if not gw_ips: return @@ -186,14 +189,14 @@ def get_gw_ip_using_gw_mac(self) -> Optional[str]: # all of them are ipv6, return the first return gw_ips[0] - def is_gw_info_detected(self, info_type: str) -> bool: + def is_gw_info_detected(self, info_type: str, interface: str) -> bool: """ checks own attributes and the db for the gw mac/ip :param info_type: can be 'mac' or 'ip' """ info_mapping = { - "mac": ("gw_mac", self.db.get_gateway_mac), - "ip": ("gw_ip", self.db.get_gateway_ip), + "mac": ("gw_macs", self.db.get_gateway_mac), + "ip": ("gw_ips", self.db.get_gateway_ip), } if info_type not in info_mapping: @@ -201,14 +204,15 @@ def is_gw_info_detected(self, info_type: str) -> bool: attr, check_db_method = info_mapping[info_type] - if getattr(self, attr): + # did we get this interface's GW IP/MAC yet? + if interface in getattr(self, attr, {}): # the reason we don't just check the db is we don't want a db # call per each flow return True # did some other module manage to get it? 
- if info := check_db_method(): - setattr(self, attr, info) + if info := check_db_method(interface): + getattr(self, attr, {}).update({interface: info}) return True return False @@ -226,15 +230,19 @@ def get_gateway_info(self, flow): # some suricata flows dont have that, like SuricataFile objs return - gw_mac_found: bool = self.is_gw_info_detected("mac") + gw_mac_found: bool = self.is_gw_info_detected("mac", flow.interface) + if not gw_mac_found: + # we didnt get the MAC of the GW of this flow's interface + # ok consider the GW MAC = any dst MAC of a flow + # going from a private srcip -> a public ip if ( utils.is_private_ip(flow.saddr) and not utils.is_ignored_ip(flow.daddr) and flow.dmac ): - self.gw_mac: str = flow.dmac - self.db.set_default_gateway("MAC", self.gw_mac) + self.gw_macs.update({flow.interface: flow.dmac}) + self.db.set_default_gateway("MAC", flow.dmac, flow.interface) # self.print( # f"MAC address of the gateway detected: " # f"{green(self.gw_mac)}" @@ -242,13 +250,13 @@ def get_gateway_info(self, flow): gw_mac_found = True # we need the mac to be set to be able to find the ip using it - if not self.is_gw_info_detected("ip") and gw_mac_found: - self.gw_ip: Optional[str] = self.get_gw_ip_using_gw_mac() - if self.gw_ip: - self.db.set_default_gateway("IP", self.gw_ip) + if not self.is_gw_info_detected("ip", flow.interface) and gw_mac_found: + gw_ip: Optional[str] = self.get_gw_ip_using_gw_mac(flow.dmac) + if gw_ip: + self.gw_ips[flow.interface] = gw_ip + self.db.set_default_gateway("IP", gw_ip, flow.interface) self.print( - f"IP address of the gateway detected: " - f"{green(self.gw_ip)}" + f"IP address of the gateway detected: " f"{green(gw_ip)}" ) def add_flow_to_profile(self, flow): @@ -430,8 +438,12 @@ def should_set_localnet(self, flow) -> bool: returns true only if the saddr of the current flow is ipv4, private and we don't have the local_net set already """ - if self.is_localnet_set: - return False + if self.db.is_running_non_stop(): + if 
flow.interface in self.localnet_cache: + return False + else: + if "default" in self.localnet_cache: + return False if flow.saddr == "0.0.0.0": return False @@ -519,73 +531,77 @@ def get_private_client_ips( private_clients.append(ip) return private_clients - def get_localnet_of_given_interface(self) -> str | None: + def get_localnet_of_given_interface(self) -> Dict[str, str]: """ returns the local network of the given interface only if slips is running with -i """ - addrs = netifaces.ifaddresses(self.args.interface).get( - netifaces.AF_INET - ) - if not addrs: - return - for addr in addrs: - ip = addr.get("addr") - netmask = addr.get("netmask") - if ip and netmask: - network = ipaddress.IPv4Network( - f"{ip}/{netmask}", strict=False - ) - return str(network) - return None - - def get_local_net(self, flow) -> Optional[str]: - """ - gets the local network from client_ip param in the config file, + local_nets = {} + for interface in utils.get_all_interfaces(self.args): + addrs = netifaces.ifaddresses(interface).get(netifaces.AF_INET) + if not addrs: + return + for addr in addrs: + ip = addr.get("addr") + netmask = addr.get("netmask") + if ip and netmask: + network = ipaddress.IPv4Network( + f"{ip}/{netmask}", strict=False + ) + local_nets[interface] = str(network) + return local_nets + + def get_local_net_of_flow(self, flow) -> Dict[str, str]: + """ + gets the local network from client_ip + param in the config file, or by using the localnetwork of the first private srcip seen in the traffic """ - # For now the local network is only ipv4, but it - # could be ipv6 in the future. Todo. - if self.args.interface: - self.is_localnet_set = True - return self.get_localnet_of_given_interface() - - # slips is running on a file, we either have a client ip or not + local_net = {} + # Reaching this func means slips is running on a file. 
we either + # have a client ip or not private_client_ips: List[ Union[IPv4Network, IPv6Network, IPv4Address, IPv6Address] ] - private_client_ips = self.get_private_client_ips() - - if private_client_ips: + # get_private_client_ips from the config file + if private_client_ips := self.get_private_client_ips(): # does the client ip from the config already have the localnet? for range_ in private_client_ips: if isinstance(range_, IPv4Network) or isinstance( range_, IPv6Network ): - self.is_localnet_set = True - return str(range_) + local_net["default"] = str(range_) + return local_net - # all client ips should belong to the same local network, - # it doesn't make sense to have ips belonging to different - # networks in the config file! - ip: str = str(private_client_ips[0]) - else: - ip: str = flow.saddr + # For now the local network is only ipv4, but it + # could be ipv6 in the future. Todo. + ip: str = flow.saddr + if cidr := utils.get_cidr_of_private_ip(ip): + local_net["default"] = cidr + return local_net - self.is_localnet_set = True - return utils.get_cidr_of_private_ip(ip) + return local_net def handle_setting_local_net(self, flow): """ stores the local network if possible + sets the self.localnet_cache dict """ if not self.should_set_localnet(flow): return - local_net: str = self.get_local_net(flow) - self.print(f"Used local network: {green(local_net)}") - self.db.set_local_network(local_net) + if self.db.is_running_non_stop(): + self.localnet_cache = self.get_localnet_of_given_interface() + else: + self.localnet_cache = self.get_local_net_of_flow(flow) + + for interface, local_net in self.localnet_cache.items(): + self.db.set_local_network(local_net, interface) + to_print = f"Used local network: {green(local_net)}" + if interface != "default": + to_print += f" for interface {green(interface)}." 
+ self.print(to_print) def get_msg_from_input_proc( self, q: multiprocessing.Queue, thread_safe=False @@ -660,7 +676,6 @@ def process_flow(self): line: dict = msg["line"] input_type: str = msg["input_type"] - # TODO who is putting this True here? if line is True: continue @@ -676,7 +691,6 @@ def process_flow(self): flow = self.input_handler_obj.process_line(line) if not flow: continue - self.add_flow_to_profile(flow) self.handle_setting_local_net(flow) self.db.increment_processed_flows() diff --git a/slips_files/core/structures/alerts.py b/slips_files/core/structures/alerts.py index c84e896ac..2fc1bd512 100644 --- a/slips_files/core/structures/alerts.py +++ b/slips_files/core/structures/alerts.py @@ -45,7 +45,7 @@ def normalize(value: float): @dataclass class Alert: profile: ProfileID - # this should have the fields start_Time and end_time set #TODO force it + # this should have the fields start_Time and end_time set timewindow: TimeWindow # the last evidence that triggered this alert last_evidence: Evidence diff --git a/slips_files/core/structures/evidence.py b/slips_files/core/structures/evidence.py index aa0192ba8..b05b45c5f 100644 --- a/slips_files/core/structures/evidence.py +++ b/slips_files/core/structures/evidence.py @@ -265,6 +265,9 @@ class Method(Enum): @dataclass class Evidence: + # IMPORTANT: remember to update dict_to_evidence() function based on the + # field you add to the evidence class, or any class used by the evidence + # class. 
evidence_type: EvidenceType description: str attacker: Attacker @@ -277,6 +280,7 @@ class Evidence: timestamp: str = field( metadata={"validate": lambda x: validate_timestamp(x)} ) + interface: str = field(default="default") victim: Optional[Victim] = field(default=False) proto: Optional[Proto] = field(default=False) dst_port: int = field(default=None) @@ -340,6 +344,7 @@ def dict_to_evidence(evidence: dict) -> Evidence: evidence_attributes = { "evidence_type": EvidenceType[evidence["evidence_type"]], "description": evidence["description"], + "interface": evidence["interface"], "attacker": Attacker(**evidence["attacker"]), "threat_level": ThreatLevel[evidence["threat_level"].upper()], "victim": ( diff --git a/slips_files/core/text_formatters/evidence_formatter.py b/slips_files/core/text_formatters/evidence_formatter.py index 1601e01f3..be4a28420 100644 --- a/slips_files/core/text_formatters/evidence_formatter.py +++ b/slips_files/core/text_formatters/evidence_formatter.py @@ -21,8 +21,10 @@ class EvidenceFormatter: - def __init__(self, db): + def __init__(self, db, args): self.db = db + # args given to slips on startup + self.args = args def get_evidence_to_log( self, evidence: Evidence, flow_datetime: str @@ -109,6 +111,9 @@ def format_evidence_for_printing( evidence: Evidence = self.add_threat_level_to_evidence_description( evidence ) + evidence: Evidence = self.add_interface_to_evidence_description( + evidence + ) evidence_string = self.line_wrap( f"Detected {evidence.description.strip()}" ) @@ -131,6 +136,17 @@ def add_threat_level_to_evidence_description( ) return evidence + def add_interface_to_evidence_description( + self, evidence: Evidence + ) -> Evidence: + + if not self.args.access_point: + # no eed to add the interface if slips is only monitoring one + return evidence + + evidence.description += f" Interface: {evidence.interface}." 
+ return evidence + def get_printable_attacker_and_victim_info( self, evidence: Evidence ) -> str: diff --git a/slips_files/core/zeek_cmd_builder.py b/slips_files/core/zeek_cmd_builder.py new file mode 100644 index 000000000..b91b83c85 --- /dev/null +++ b/slips_files/core/zeek_cmd_builder.py @@ -0,0 +1,111 @@ +import os +from typing import List, Optional + + +class ZeekCommandBuilder: + """ + Builds Zeek (or Bro) command lines based on the given configuration. + """ + + def __init__( + self, + zeek_or_bro: str, + input_type: str, + rotation_period: str, + enable_rotation: bool, + tcp_inactivity_timeout: int, + packet_filter: Optional[str] = None, + ): + self.zeek_or_bro = zeek_or_bro + self.input_type = input_type + self.rotation_period = rotation_period + self.enable_rotation = enable_rotation + self.tcp_inactivity_timeout = tcp_inactivity_timeout + self.packet_filter = packet_filter + + def _get_input_parameter(self, pcap_or_interface: str) -> List[str]: + if self.input_type == "interface": + return ["-i", pcap_or_interface] + + elif self.input_type == "pcap": + pcap = self._get_relative_pcap_path(pcap_or_interface) + # using a list of params instead of a str for storing the cmd + # becaus ethe given path may contain spaces + return ["-r", pcap] + + raise ValueError(f"Unsupported input_type: {self.input_type}") + + def _get_rotation_args(self) -> List[str]: + # rotation is disabled unless it's an interface + if self.input_type == "interface" and self.enable_rotation: + # how often to rotate zeek files? 
taken from slips.yaml + return [ + "-e", + f'"redef Log::default_rotation_interval =' + f' {self.rotation_period} ;"', + ] + return [] + + def _build_packet_filter(self, tcpdump_filter: Optional[str]) -> List[str]: + # build packet filter + # user-defined filter in slips.yaml + packet_filter = ( + ["-f", self.packet_filter] if self.packet_filter else [] + ) + + if tcpdump_filter: + # no need to quote manually; just wrap in parentheses + tcpdump_filter = f"({tcpdump_filter.strip()})" + + if packet_filter: + # combine user-provided and tcpdump filters + combined = f"{self.packet_filter} and {tcpdump_filter}" + packet_filter = ["-f", combined] + else: + packet_filter = ["-f", tcpdump_filter] + + return packet_filter + + def _get_relative_pcap_path(self, pcap: str) -> str: + # Find if the pcap file name was absolute or relative + if not os.path.isabs(pcap): + # now the given pcap is relative to slips main dir + # slips can store the zeek logs dir either in the + # output dir (by default in Slips/output/_/zeek_files/), + # or in any dir specified with -o + # construct an abs path from the given path so slips can find the given pcap + # no matter where the zeek dir is placed + pcap = os.path.join(os.getcwd(), pcap) + return pcap + + def build( + self, pcap_or_interface: str, tcpdump_filter: Optional[str] = None + ) -> List[str]: + """ + constructs the zeek command based on the user given + pcap/interface/packet filter/etc. 
+ """ + bro_parameter = self._get_input_parameter(pcap_or_interface) + rotation = self._get_rotation_args() + packet_filter = self._build_packet_filter(tcpdump_filter) + zeek_scripts_dir = os.path.join(os.getcwd(), "zeek-scripts") + + # 'local' is removed from the command because it + # loads policy/protocols/ssl/expiring-certs and + # and policy/protocols/ssl/validate-certs and they have conflicts + # with our own + # zeek-scripts/expiring-certs and validate-certs + # we have our own copy pf local.zeek in __load__.zeek + command = [ + self.zeek_or_bro, + "-C", + *bro_parameter, + "-e", + f"redef tcp_inactivity_timeout={self.tcp_inactivity_timeout}mins;", + *rotation, + zeek_scripts_dir, + # putting -f last is best practice + *packet_filter, + ] + + return command diff --git a/tests/module_factory.py b/tests/module_factory.py index e5660b95a..a50c1d3df 100644 --- a/tests/module_factory.py +++ b/tests/module_factory.py @@ -763,7 +763,8 @@ def create_evidence_handler_obj(self, mock_db): @patch(MODULE_DB_MANAGER, name="mock_db") def create_evidence_formatter_obj(self, mock_db): - return EvidenceFormatter(mock_db) + args = Mock() + return EvidenceFormatter(mock_db, args) @patch(MODULE_DB_MANAGER, name="mock_db") def create_symbol_handler_obj(self, mock_db): diff --git a/tests/test_arp_poisoner.py b/tests/test_arp_poisoner.py index 8e1c54ce3..ae32255e6 100644 --- a/tests/test_arp_poisoner.py +++ b/tests/test_arp_poisoner.py @@ -53,7 +53,7 @@ def test_can_poison_ip( if is_gw: poisoner.db.get_gateway_ip = MagicMock(return_value=ip) poisoner.is_broadcast = MagicMock(return_value=is_bcast) - assert poisoner.can_poison_ip(ip) == expected + assert poisoner.can_poison_ip(ip, "eth0") == expected def test__arp_scan(poisoner): @@ -122,7 +122,7 @@ def test__cut_targets_internet(poisoner, ip, mac, gw_mac): ), patch.object(poisoner.db, "get_gateway_mac", return_value=gw_mac), ): - poisoner._cut_targets_internet(ip, mac, gw_mac) + poisoner._cut_targets_internet(ip, mac, gw_mac, 
"eth0") assert sendp.call_count == 2 @@ -143,7 +143,7 @@ def test__isolate_target_from_localnet(poisoner): } ) poisoner._isolate_target_from_localnet( - "192.168.1.100", "aa:aa:aa:aa:aa:aa" + "192.168.1.100", "aa:aa:aa:aa:aa:aa", "eth0" ) assert sendp.call_count == 2 diff --git a/tests/test_blocking.py b/tests/test_blocking.py index 8595ec16b..45bbcdf80 100644 --- a/tests/test_blocking.py +++ b/tests/test_blocking.py @@ -127,6 +127,7 @@ def test_main_blocking_logic(block, expected_block_called): "dport": 80, "sport": 12345, "protocol": "tcp", + "interface": "eth0", } msg_block = {"data": json.dumps(blocking_data)} @@ -152,6 +153,7 @@ def test_main_blocking_logic(block, expected_block_called): "dport": 80, "sport": 12345, "protocol": "tcp", + "interface": "eth0", }, ) else: @@ -166,6 +168,7 @@ def test_main_blocking_logic(block, expected_block_called): "dport": 80, "sport": 12345, "protocol": "tcp", + "interface": "eth0", }, ) mock_update.assert_not_called() diff --git a/tests/test_checker.py b/tests/test_checker.py index 08837f2f3..7f7dedc38 100644 --- a/tests/test_checker.py +++ b/tests/test_checker.py @@ -72,7 +72,7 @@ def test_check_given_flags(args, expected_calls, monkeypatch): checker.main.redis_man.check_redis_database.return_value = False checker.input_module_exists = mock.MagicMock(return_value=False) - checker.check_given_flags() + checker.verify_given_flags() for method_name in expected_calls: method = getattr(checker.main, method_name) @@ -118,47 +118,46 @@ def test_check_given_flags_root_user(monkeypatch): ) as mock_delete, mock.patch.object( checker.main, "terminate_slips" ) as mock_term: - checker.check_given_flags() + checker.verify_given_flags() mock_delete.assert_called_once() mock_term.assert_called_once() def test_check_input_type_interface(): - checker = ModuleFactory().create_checker_obj() checker.main.args.interface = "eth0" checker.main.args.filepath = None checker.main.args.db = None checker.main.args.input_module = None - result = 
checker.check_input_type() + result = checker.get_input_type() assert result == ("interface", "eth0", False) def test_check_input_type_db(): - checker = ModuleFactory().create_checker_obj() checker.main.args.interface = None checker.main.args.filepath = None - checker.main.args.db = True + checker.main.args.access_point = None checker.main.args.input_module = None + checker.main.args.db = True checker.main.redis_man.load_db = mock.MagicMock() - result = checker.check_input_type() + result = checker.get_input_type() assert result is None checker.main.redis_man.load_db.assert_called_once() def test_check_input_type_input_module(): - checker = ModuleFactory().create_checker_obj() checker.main.args.interface = None checker.main.args.filepath = None checker.main.args.db = None + checker.main.args.access_point = None checker.main.args.input_module = "zeek" - result = checker.check_input_type() + result = checker.get_input_type() assert result == ("zeek", "input_module", "zeek") @@ -176,6 +175,7 @@ def test_check_input_type_filepath(filepath, is_file, is_dir, expected_result): checker.main.args.interface = None checker.main.args.filepath = filepath checker.main.args.db = None + checker.main.args.access_point = None checker.main.args.input_module = None with mock.patch("os.path.isfile", return_value=is_file), mock.patch( @@ -184,14 +184,14 @@ def test_check_input_type_filepath(filepath, is_file, is_dir, expected_result): checker.main, "get_input_file_type", return_value="mock_type" ): - result = checker.check_input_type() + result = checker.get_input_type() assert result == expected_result def test_check_input_type_stdin(): - checker = ModuleFactory().create_checker_obj() checker.main.args.interface = None + checker.main.args.access_point = None checker.main.args.filepath = "stdin-type" checker.main.args.db = None checker.main.args.input_module = None @@ -204,20 +204,20 @@ def test_check_input_type_stdin(): return_value=("mock_type", "mock_line_type"), ): - result = 
checker.check_input_type() + result = checker.get_input_type() assert result == ("mock_type", "stdin-type", "mock_line_type") def test_check_input_type_no_input(): - checker = ModuleFactory().create_checker_obj() checker.main.args.interface = None + checker.main.args.access_point = None checker.main.args.filepath = None checker.main.args.db = None checker.main.args.input_module = None with pytest.raises(SystemExit) as excinfo: - checker.check_input_type() + checker.get_input_type() assert excinfo.value.code == -1 diff --git a/tests/test_conn.py b/tests/test_conn.py index b2c40d572..6a67f842b 100644 --- a/tests/test_conn.py +++ b/tests/test_conn.py @@ -344,9 +344,11 @@ def test_check_multiple_reconnection_attempts( ) def test_is_ignored_ip_data_upload(ip_address, expected_result): conn = ModuleFactory().create_conn_analyzer_obj() - conn.gateway = "192.168.1.1" + conn.db.get_gateway_ip = Mock(return_value="192.168.1.1") - assert conn.is_ignored_ip_data_upload(ip_address) is expected_result + assert ( + conn.is_ignored_ip_data_upload(ip_address, "eth0") is expected_result + ) @pytest.mark.parametrize( @@ -390,23 +392,24 @@ def test_get_sent_bytes(all_flows, expected_bytes_sent): @pytest.mark.parametrize( - "sbytes, daddr, expected_result, expected_call_count", + "sbytes, ignored_ip, daddr, expected_result, expected_call_count", [ # Testcase1: Exceeds threshold - (100 * 1024 * 1024 + 1, "192.168.1.2", True, 1), + (100 * 1024 * 1024 + 1, False, "192.168.1.2", True, 1), # Testcase2: Below threshold - (10 * 1024 * 1024, "192.168.1.2", False, 0), + (10 * 1024 * 1024, False, "192.168.1.2", False, 0), # Testcase3: Ignored IP - (100 * 1024 * 1024 + 1, "192.168.1.1", False, 0), + (100 * 1024 * 1024 + 1, True, "192.168.1.1", False, 0), ], ) def test_check_data_upload( - mocker, sbytes, daddr, expected_result, expected_call_count + mocker, sbytes, daddr, ignored_ip, expected_result, expected_call_count ): """ Tests the check_data_upload function with various scenarios for data 
upload. """ conn = ModuleFactory().create_conn_analyzer_obj() + conn.is_ignored_ip_data_upload = Mock(return_value=ignored_ip) mock_set_evidence = mocker.patch( "modules.flowalerts.set_evidence.SetEvidenceHelper.data_exfiltration" ) @@ -429,6 +432,7 @@ def test_check_data_upload( dmac="", state="", history="", + interface="eth0", ) assert conn.check_data_upload(profileid, twid, flow) is expected_result assert mock_set_evidence.call_count == expected_call_count diff --git a/tests/test_database.py b/tests/test_database.py index e83b1caaf..3c92b8b42 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -31,23 +31,24 @@ twid = "timewindow1" test_ip = "192.168.1.1" flow = Conn( - "1601998398.945854", - "1234", - test_ip, - "8.8.8.8", - 5, - "TCP", - "dhcp", - 80, - 88, - 20, - 20, - 20, - 20, - "", - "", - "Established", - "", + starttime="1601998398.945854", + uid="1234", + saddr=test_ip, + daddr="8.8.8.8", + dur=5, + proto="TCP", + appproto="dhcp", + sport=80, + dport=88, + spkts=20, + dpkts=20, + sbytes=20, + dbytes=20, + state="", + history="", + smac="Established", + dmac="", + interface="eth0", ) @@ -112,6 +113,7 @@ def test_set_evidence(): victim: Victim = Victim( direction=Direction.DST, ioc_type=IoCType.IP, value="8.8.8.8" ) + db._get_evidence_interface = Mock(return_value="eth0") evidence: Evidence = Evidence( evidence_type=EvidenceType.SSH_SUCCESSFUL, attacker=attacker, @@ -171,12 +173,13 @@ def test_add_mac_addr_with_new_ipv4(): mac_addr = "00:00:5e:00:53:af" db.rdb.is_gw_mac = Mock(return_value=False) + db.rdb._should_associate_this_mac_with_this_ip = Mock(return_value=True) db.r.hget = Mock() db.r.hset = Mock() - db.r.hmget = Mock(return_value=[None]) # No entry initially + db.r.hmget = Mock(return_value=[None]) # simulate adding a new MAC and IPv4 address - assert db.add_mac_addr_to_profile(profileid_ipv4, mac_addr) is True + assert db.add_mac_addr_to_profile(profileid_ipv4, mac_addr, "eth0") is True # Ensure the IP is associated in the 
'MAC' hash db.r.hmget.assert_called_with("MAC", mac_addr) @@ -191,6 +194,7 @@ def test_add_mac_addr_with_existing_ipv4(): ipv4 = "192.168.1.5" mac_addr = "00:00:5e:00:53:af" db.rdb.is_gw_mac = Mock(return_value=False) + db.rdb._should_associate_this_mac_with_this_ip = Mock(return_value=True) db.r.hget = Mock() db.r.hset = Mock() db.r.hmget = Mock(return_value=[json.dumps([ipv4])]) @@ -198,7 +202,7 @@ def test_add_mac_addr_with_existing_ipv4(): new_profile = "profile_192.168.1.6" # try to add a new profile with the same MAC but another IPv4 address - assert db.add_mac_addr_to_profile(new_profile, mac_addr) is False + assert db.add_mac_addr_to_profile(new_profile, mac_addr, "eth0") is False def test_add_mac_addr_with_ipv6_association(): @@ -212,6 +216,7 @@ def test_add_mac_addr_with_ipv6_association(): # mock existing entry with ipv6 db.rdb.is_gw_mac = Mock(return_value=False) + db.rdb._should_associate_this_mac_with_this_ip = Mock(return_value=True) db.rdb.update_mac_of_profile = Mock() db.r.hmget = Mock(return_value=[json.dumps([ipv4])]) db.r.hset = Mock() @@ -220,11 +225,11 @@ def test_add_mac_addr_with_ipv6_association(): ipv6 = "2001:0db8:85a3:0000:0000:8a2e:0370:7334" profile_ipv6 = f"profile_{ipv6}" # try to associate an ipv6 with the same MAC address - assert db.add_mac_addr_to_profile(profile_ipv6, mac_addr) + assert db.add_mac_addr_to_profile(profile_ipv6, mac_addr, "eth0") expected_calls = [ - call(profile_ipv4, mac_addr), # call with ipv4 profile id - call(profile_ipv6, mac_addr), # call with ipv6 profile id + call(profile_ipv4, mac_addr), # call with the ipv4 profileid + call(profile_ipv6, mac_addr), # call with the ipv6 profileid ] db.rdb.update_mac_of_profile.assert_has_calls( expected_calls, any_order=True diff --git a/tests/test_evidence_formatter.py b/tests/test_evidence_formatter.py index dc0717ea2..3ec58af49 100644 --- a/tests/test_evidence_formatter.py +++ b/tests/test_evidence_formatter.py @@ -38,10 +38,10 @@ }, ProfileID("192.168.1.1"), 
TimeWindow(1), - "IP 192.168.1.1 detected as malicious in timewindow 1" + "converted_time IP 192.168.1.1 detected as malicious in timewindow 1" " (start 2023/07/01 12:00:00, stop 2023/07/01 12:05:00) " "given the following evidence:\n" - "\t- Detected Port scan detected threat level: medium.\n", + "\t- Detected Port scan detected threat level: medium. Interface: default.\n", ), # testcase2: Multiple evidence ( @@ -73,11 +73,11 @@ }, ProfileID("192.168.1.1"), TimeWindow(1), - "IP 192.168.1.1 detected as malicious in timewindow 1" + "converted_time IP 192.168.1.1 detected as malicious in timewindow 1" " (start 2023/07/01 12:00:00, stop 2023/07/01 12:05:00) " "given the following evidence:\n" - "\t- Detected Port scan detected threat level: medium.\n" - "\t- Detected Malicious JA3 fingerprint threat level: high.\n", + "\t- Detected Port scan detected threat level: medium. Interface: default.\n" + "\t- Detected Malicious JA3 fingerprint threat level: high. Interface: default.\n", ), ], ) @@ -105,13 +105,12 @@ def test_format_evidence_for_printing( id="123", last_flow_datetime="", ) - formatter.line_wrap = Mock() formatter.line_wrap = lambda x: x result = formatter.format_evidence_for_printing( alert, all_evidence, ) - + # remove colors result = ( result.replace("\033[31m", "") .replace("\033[36m", "") diff --git a/tests/test_evidence_handler.py b/tests/test_evidence_handler.py index 55bd85c34..af47ea0ce 100644 --- a/tests/test_evidence_handler.py +++ b/tests/test_evidence_handler.py @@ -32,7 +32,7 @@ ], ) def test_decide_blocking( - profileid, our_ips, expected_result, expected_publish_call_count + mocker, profileid, our_ips, expected_result, expected_publish_call_count ): evidence_handler = ModuleFactory().create_evidence_handler_obj() evidence_handler.blocking_modules_supported = True @@ -41,7 +41,12 @@ def test_decide_blocking( tw = TimeWindow( 2, "2025-05-09T13:27:45.123456", "2025-05-09T13:27:45.123456" ) + mocker.patch( + 
"slips_files.common.slips_utils.Utils.get_interface_of_ip", + return_value="eth0", + ) result = evidence_handler.decide_blocking(profileid, tw) + assert result == expected_result assert mock_publish.call_count == expected_publish_call_count diff --git a/tests/test_flow_handler.py b/tests/test_flow_handler.py index d60e723e6..ae17d84fe 100644 --- a/tests/test_flow_handler.py +++ b/tests/test_flow_handler.py @@ -142,6 +142,7 @@ def test_handle_conn(flow): flow.daddr = "192.168.1.1" flow.dport = 80 flow.proto = "tcp" + flow.interface = "eth0" mock_symbol = Mock() mock_symbol.compute.return_value = ("A", "B", "C") @@ -182,7 +183,7 @@ def test_handle_conn(flow): flow, flow_handler.profileid, flow_handler.twid, "benign" ) flow_handler.db.add_mac_addr_to_profile.assert_called_with( - flow_handler.profileid, flow.smac + flow_handler.profileid, flow.smac, flow.interface ) if not flow_handler.running_non_stop: flow_handler.publisher.new_MAC.assert_has_calls( @@ -217,6 +218,7 @@ def test_handle_arp(flow): flow.smac = "ff:ee:dd:cc:bb:aa" flow.daddr = "192.168.1.1" flow.saddr = "192.168.1.2" + flow.interface = "eth0" flow_handler.publisher = Mock() flow_handler.handle_arp() @@ -229,7 +231,7 @@ def test_handle_arp(flow): "new_arp", json.dumps(expected_payload) ) flow_handler.db.add_mac_addr_to_profile.assert_called_with( - flow_handler.profileid, flow.smac + flow_handler.profileid, flow.smac, flow.interface ) flow_handler.publisher.new_MAC.assert_has_calls( [call(flow.dmac, flow.daddr), call(flow.smac, flow.saddr)] @@ -281,6 +283,7 @@ def test_handle_notice(flow): flow.note = "Gateway_addr_identified: 192.168.1.1" flow.msg = "Gateway_addr_identified: 192.168.1.1" + flow.interface = "eth0" flow_handler.db.get_gateway_ip.return_value = False flow_handler.db.get_gateway_mac.return_value = False @@ -291,8 +294,10 @@ def test_handle_notice(flow): flow_handler.db.add_out_notice.assert_called_with( flow_handler.profileid, flow_handler.twid, flow ) - 
flow_handler.db.set_default_gateway.assert_any_call("IP", "192.168.1.1") - flow_handler.db.set_default_gateway.assert_any_call("MAC", "xyz") + flow_handler.db.set_default_gateway.assert_any_call( + "IP", "192.168.1.1", "eth0" + ) + flow_handler.db.set_default_gateway.assert_any_call("MAC", "xyz", "eth0") flow_handler.db.add_altflow.assert_called_with( flow, flow_handler.profileid, flow_handler.twid, "benign" ) @@ -308,6 +313,7 @@ def test_handle_dhcp(): client_addr="192.168.1.1", host_name="test-host", requested_addr="192.168.1.4", + interface="eth0", ) flow_handler = ModuleFactory().create_flow_handler_obj(flow) flow_handler.publisher = Mock() @@ -315,7 +321,7 @@ def test_handle_dhcp(): flow_handler.publisher.new_MAC.assert_called_with(flow.smac, flow.saddr) flow_handler.db.add_mac_addr_to_profile.assert_called_with( - flow_handler.profileid, flow.smac + flow_handler.profileid, flow.smac, flow.interface ) flow_handler.db.store_dhcp_server.assert_called_with("192.168.1.2") flow_handler.db.mark_profile_as_dhcp.assert_called_with( diff --git a/tests/test_host_ip_manager.py b/tests/test_host_ip_manager.py index 55a8f3462..f53d1f987 100644 --- a/tests/test_host_ip_manager.py +++ b/tests/test_host_ip_manager.py @@ -1,25 +1,33 @@ # SPDX-FileCopyrightText: 2021 Sebastian Garcia # SPDX-License-Identifier: GPL-2.0-only from unittest.mock import MagicMock, patch, Mock + +import netifaces import pytest + from tests.module_factory import ModuleFactory import sys @pytest.mark.parametrize( - "is_interface, host_ip, modified_profiles, " + "is_interface, host_ips, modified_profiles, " "expected_calls, expected_result", - [ # Testcase1: Should update host IP - (True, "192.168.1.1", set(), 1, "192.168.1.2"), - # Testcase2: Shouldn't update host IP - (True, "192.168.1.1", {"192.168.1.1"}, 0, "192.168.1.1"), - # Testcase3: Shouldn't update host IP (not interface) - (False, "192.168.1.1", set(), 0, None), + [ + # Shouldn't update host IP + ( + True, + {"eth0": "192.168.1.1"}, + 
{"192.168.1.1"}, + 0, + {"eth0": "192.168.1.1"}, + ), + # Shouldn't update host IP (not interface) + (False, {"eth0": "192.168.1.1"}, set(), 0, None), ], ) -def test_update_host_ip( +def test_update_host_ip_shouldnt_update( is_interface, - host_ip, + host_ips, modified_profiles, expected_calls, expected_result, @@ -30,72 +38,138 @@ def test_update_host_ip( host_ip_man.get_host_ip = Mock() host_ip_man.get_host_ip.return_value = "192.168.1.2" host_ip_man.main.db.set_host_ip = MagicMock() - result = host_ip_man.update_host_ip(host_ip, modified_profiles) + host_ip_man.store_host_ip = MagicMock() + result = host_ip_man.update_host_ip(host_ips, modified_profiles) assert result == expected_result assert host_ip_man.get_host_ip.call_count == expected_calls @pytest.mark.parametrize( - "interfaces, ifaddresses, expected", + "is_interface, host_ips, modified_profiles, " "expected_calls", + [ + # Shouldn't update host IP + (True, {"eth0": "192.168.1.1"}, set(), 1) + ], +) +def test_update_host_ip_should_update( + is_interface, + host_ips, + modified_profiles, + expected_calls, +): + host_ip_man = ModuleFactory().create_host_ip_manager_obj() + host_ip_man.main.db.is_running_non_stop.return_value = is_interface + + host_ip_man.get_host_ip = Mock(return_value="192.168.1.2") + host_ip_man.store_host_ip = MagicMock() + + host_ip_man.update_host_ip(host_ips, modified_profiles) + assert host_ip_man.store_host_ip.call_count == expected_calls + + +@pytest.mark.parametrize( + "args_interface,args_access_point,iface_addrs,expected", [ - ( # 2 here is AF_INET - ["lo", "eth0"], - {"lo": {}, "eth0": {2: [{"addr": "192.168.1.10"}]}}, - "192.168.1.10", + # Single interface with valid IPv4 + ( + "eth0", + None, + {netifaces.AF_INET: [{"addr": "192.168.1.10"}]}, + {"eth0": "192.168.1.10"}, + ), + # Only loopback IP -> should be skipped + ( + "lo", + None, + {netifaces.AF_INET: [{"addr": "127.0.0.1"}]}, + {}, ), + # Interface without AF_INET -> skipped ( - ["lo", "eth0"], - { - "lo": {2: 
[{"addr": "127.0.0.1"}]}, - "eth0": {2: [{"addr": "127.0.0.2"}]}, - }, + "eth1", None, + {}, + {}, ), - (["lo"], {"lo": {2: [{"addr": "127.0.0.1"}]}}, None), ], ) -def test_get_host_ip(interfaces, ifaddresses, expected): +@patch("netifaces.ifaddresses") +def test_get_host_ips_single_interface( + mock_ifaddresses, + args_interface, + args_access_point, + iface_addrs, + expected, +): + """Test _get_host_ips for single-interface cases.""" + host_ip_man = ModuleFactory().create_host_ip_manager_obj() + host_ip_man.main.args.growing = None + host_ip_man.main.args.interface = args_interface + host_ip_man.main.args.access_point = args_access_point + + mock_ifaddresses.return_value = iface_addrs + result = host_ip_man._get_host_ips() + + assert result == expected + mock_ifaddresses.assert_called_once_with(args_interface) + + +@patch("netifaces.ifaddresses") +def test_get_host_ips_multiple_interfaces_from_access_point(mock_ifaddresses): + """Test _get_host_ips when using multiple interfaces via --access-point.""" + host_ip_man = ModuleFactory().create_host_ip_manager_obj() + host_ip_man.main.args.interface = None + host_ip_man.main.args.growing = None + host_ip_man.main.args.access_point = "wlan0,eth0" + + def mock_ifaddresses_side_effect(iface): + if iface == "wlan0": + return {netifaces.AF_INET: [{"addr": "10.0.0.5"}]} + elif iface == "eth0": + return {netifaces.AF_INET: [{"addr": "192.168.0.8"}]} + return {} + + mock_ifaddresses.side_effect = mock_ifaddresses_side_effect + + result = host_ip_man._get_host_ips() + assert result == {"wlan0": "10.0.0.5", "eth0": "192.168.0.8"} + + +def test_get_host_ips_growing_zeek_dir(mocker): + """Test _get_host_ips when using multiple interfaces via --access-point.""" host_ip_man = ModuleFactory().create_host_ip_manager_obj() - host_ip_man.main.args.interface = None # simulate not passed, to use all - host_ip_man.main.args.growing = ( - True # simulate -g used, so use all interfaces + host_ip_man.main.args.interface = None + 
host_ip_man.main.args.growing = True + host_ip_man.main.args.access_point = None + host_ip_man._get_default_host_ip = Mock(return_value="10.0.0.5") + + mocker.patch( + "slips_files.common.slips_utils.Utils.infer_used_interface", + return_value="eth0", ) - with patch( - "managers.host_ip_manager.netifaces.interfaces", - return_value=interfaces, - ), patch( - "managers.host_ip_manager.netifaces.ifaddresses", - side_effect=lambda iface: ifaddresses.get(iface, {}), - ), patch( - "managers.host_ip_manager.netifaces.AF_INET", 2 - ): - result = host_ip_man.get_host_ip() - assert result == expected + result = host_ip_man._get_host_ips() + assert result == {"eth0": "10.0.0.5"} @pytest.mark.parametrize( - "running_on_interface, host_ip," - "set_host_ip_side_effect, expected_result", + "running_on_interface, host_ip," "expected_result", [ # testcase1: Running on interface, valid IP - (True, "192.168.1.100", None, "192.168.1.100"), + (True, {"eth0": "192.168.1.100"}, {"eth0": "192.168.1.100"}), # testcase2: Not running on interface - (False, "192.168.1.100", None, None), + (False, {"eth0": "192.168.1.100"}, None), ], ) def test_store_host_ip( running_on_interface, host_ip, - set_host_ip_side_effect, expected_result, ): host_ip_man = ModuleFactory().create_host_ip_manager_obj() host_ip_man.main.db.is_running_non_stop.return_value = running_on_interface - host_ip_man.get_host_ip = MagicMock(return_value=host_ip) - host_ip_man.main.db.set_host_ip = MagicMock( - side_effect=set_host_ip_side_effect - ) + host_ip_man._get_host_ips = MagicMock(return_value=host_ip) + host_ip_man.main.db.set_host_ip = MagicMock() with patch.object(sys, "argv", ["-i"] if running_on_interface else []): with patch("time.sleep"): diff --git a/tests/test_input.py b/tests/test_input.py index 142633b43..0a1944565 100644 --- a/tests/test_input.py +++ b/tests/test_input.py @@ -15,29 +15,31 @@ @pytest.mark.parametrize( "input_type,input_information", - # the pcaps here must have a conn.log when read by zeek 
[("pcap", "dataset/test7-malicious.pcap")], ) -def test_handle_pcap_and_interface(input_type, input_information): - # no need to test interfaces because in that case read_zeek_files runs - # in a loop and never returns +def test_handle_pcap_and_interface(tmp_path, input_type, input_information): input = ModuleFactory().create_input_obj(input_information, input_type) - input.zeek_pid = "False" - input.is_zeek_tabs = False - input.start_observer = Mock() + input.zeek_dir = tmp_path + # Mock attributes and methods used inside the function input.read_zeek_files = Mock() - input.zeek_thread = Mock() - with ( - patch.object(input, "get_flows_number", return_value=500), - patch("time.sleep"), - ): + input.print_lines_read = Mock() + input.mark_self_as_done_processing = Mock() + input.stop_observer = Mock() + input.init_zeek = Mock() + + with patch("os.makedirs"), patch("os.path.exists", return_value=True): assert input.handle_pcap_and_interface() is True - input.zeek_thread.start.assert_called_once() + + # Assert that the expected methods were called + input.init_zeek.assert_called_once_with(input.zeek_dir, input.given_path) input.read_zeek_files.assert_called_once() - input.start_observer.assert_called_once() + input.print_lines_read.assert_called_once() + input.mark_self_as_done_processing.assert_called_once() + input.stop_observer.assert_called_once() - # delete the zeek logs created - shutil.rmtree(input.zeek_dir) + # Clean up any directories created (safe guard) + if os.path.exists(input.zeek_dir): + shutil.rmtree(input.zeek_dir, ignore_errors=True) @pytest.mark.parametrize( @@ -108,7 +110,7 @@ def test_cache_nxt_line_in_file(path: str, is_tabs: str, line_cached: bool): input.file_time = {} input.is_zeek_tabs = is_tabs - assert input.cache_nxt_line_in_file(path) == line_cached + assert input.cache_nxt_line_in_file(path, "eth0") == line_cached if line_cached: assert input.cache_lines[path]["type"] == path assert input.cache_lines[path]["data"] @@ -352,7 +354,8 @@ def 
test_get_file_handle_existing_file(): def test_shutdown_gracefully_all_components_active(): """ - Test shutdown_gracefully when all components (open files, zeek, remover thread) are active. + Test shutdown_gracefully when all components + (open files, zeek, remover thread) are active. """ input_process = ModuleFactory().create_input_obj("", "zeek_log_file") input_process.stop_observer = MagicMock(return_value=True) @@ -362,11 +365,12 @@ def test_shutdown_gracefully_all_components_active(): input_process.zeek_thread = MagicMock() input_process.zeek_thread.start() input_process.open_file_handlers = {"test_file.log": MagicMock()} - input_process.zeek_pid = 123 + input_process.zeek_pids = [123, 321] with patch("os.kill") as mock_kill: assert input_process.shutdown_gracefully() - mock_kill.assert_called_with(input_process.zeek_pid, signal.SIGKILL) + for pid in input_process.zeek_pids: + mock_kill.assert_any_call(pid, signal.SIGKILL) assert input_process.open_file_handlers["test_file.log"].close.called @@ -382,12 +386,12 @@ def test_shutdown_gracefully_no_open_files(): input_process.zeek_thread = MagicMock() input_process.zeek_thread.start() input_process.open_file_handlers = {} - input_process.zeek_pid = os.getpid() + input_process.zeek_pids = [123] with patch("os.kill") as mock_kill: assert input_process.shutdown_gracefully() is True mock_kill.assert_called_once_with( - input_process.zeek_pid, signal.SIGKILL + input_process.zeek_pids[0], signal.SIGKILL ) @@ -401,33 +405,11 @@ def test_shutdown_gracefully_zeek_not_running(): input_process.remover_thread = MagicMock() input_process.remover_thread.start() input_process.open_file_handlers = {"test_file.log": MagicMock()} - input_process.zeek_pid = os.getpid() - - with patch("os.kill") as mock_kill: - assert input_process.shutdown_gracefully() is True - mock_kill.assert_called_once_with( - input_process.zeek_pid, signal.SIGKILL - ) - assert input_process.open_file_handlers["test_file.log"].close.called - - -def 
test_shutdown_gracefully_remover_thread_not_running(): - """ - Test shutdown_gracefully when the remover thread is not running. - """ - input_process = ModuleFactory().create_input_obj("", "zeek_log_file") - input_process.stop_observer = MagicMock(return_value=True) - input_process.stop_queues = MagicMock(return_value=True) - input_process.zeek_thread = MagicMock() - input_process.zeek_thread.start() - input_process.open_file_handlers = {"test_file.log": MagicMock()} - input_process.zeek_pid = os.getpid() + input_process.zeek_pids = [] with patch("os.kill") as mock_kill: assert input_process.shutdown_gracefully() is True - mock_kill.assert_called_once_with( - input_process.zeek_pid, signal.SIGKILL - ) + mock_kill.assert_not_called() assert input_process.open_file_handlers["test_file.log"].close.called diff --git a/tests/test_ip_info.py b/tests/test_ip_info.py index d138f5cb0..e9e8043a0 100644 --- a/tests/test_ip_info.py +++ b/tests/test_ip_info.py @@ -4,7 +4,6 @@ import asyncio -import netifaces from tests.module_factory import ModuleFactory import maxminddb @@ -16,7 +15,6 @@ import json import requests import socket -import subprocess from slips_files.core.structures.evidence import ( ThreatLevel, Evidence, @@ -407,48 +405,58 @@ async def test_shutdown_gracefully( @pytest.mark.parametrize( - "is_running_non_stop, is_running_in_ap_mode, ifaddresses_ret, get_gw_ret, expected", + "gw_return, expected", [ - # ap mode: returns own ip - (True, True, {"addr": "10.0.0.1"}, None, "10.0.0.1"), - (True, True, KeyError, None, None), - # not ap mode: returns gw - (True, False, None, "192.168.1.1", "192.168.1.1"), - (True, False, None, None, None), - # not running - (False, True, {"addr": "10.0.0.1"}, None, None), - (False, False, None, "192.168.1.1", None), + ( + ["10.0.0.1", "192.168.1.1"], + {"wlan0": "10.0.0.1", "eth0": "192.168.1.1"}, + ), + ([None, None], {}), + ( + ["10.0.0.1", None], + { + "wlan0": "10.0.0.1", + }, + ), ], ) -def test_get_gateway_ip_if_interface( - 
mocker, - is_running_non_stop, - is_running_in_ap_mode, - ifaddresses_ret, - get_gw_ret, - expected, +def test_get_gateway_ip_if_interface_args_access_point( + mocker, gw_return, expected +): + ip_info = ModuleFactory().create_ip_info_obj() + ip_info.is_running_non_stop = True + ip_info.is_running_in_ap_mode = True + ip_info.args.access_point = "wlan0,eth0" + mocker.patch( + "slips_files.common.slips_utils.Utils.get_gateway_for_iface", + return_value=gw_return, + ) + + +@pytest.mark.parametrize( + "gw_return, expected", + [ + ("10.0.0.1", {"ap0": "10.0.0.1"}), + (None, {}), + ("192.168.1.1", {"wlan0": "192.168.1.1"}), + ("10.0.0.1", None), + ("192.168.1.1", None), + ], +) +def test_get_gateway_ip_if_interface_args_interface( + mocker, gw_return, expected ): ip_info = ModuleFactory().create_ip_info_obj() - ip_info.is_running_non_stop = is_running_non_stop - ip_info.is_running_in_ap_mode = is_running_in_ap_mode + ip_info.is_running_non_stop = True + ip_info.is_running_in_ap_mode = False ip_info.args.interface = "wlan0" + mocker.patch( + "slips_files.common.slips_utils.Utils.get_gateway_for_iface", + return_value=gw_return, + ) - if is_running_non_stop: - if is_running_in_ap_mode: - if ifaddresses_ret is KeyError: - mocker.patch("netifaces.ifaddresses", side_effect=KeyError) - else: - mocker.patch( - "netifaces.ifaddresses", - return_value={netifaces.AF_INET: [ifaddresses_ret]}, - ) - else: - mocker.patch.object( - ip_info, "get_gateway_for_iface", return_value=get_gw_ret - ) - - result = ip_info.get_gateway_ip_if_interface() - assert result == expected + +# def test_get_gateway_ip_if_interface_args_access_point(): @pytest.mark.parametrize( @@ -526,81 +534,100 @@ def test_check_if_we_have_pending_mac_queries_empty_queue( mock_get_vendor.assert_not_called() +@pytest.fixture +def ip_info(): + """Return a mock-ready instance of the IP info manager.""" + obj = Mock() + obj.db = Mock() + obj._get_wifi_interface_if_ap = Mock(return_value=None) + obj._get_mac_using_ip_neigh 
= Mock(return_value=None) + obj._get_mac_using_arp_cache = Mock(return_value=None) + obj.get_own_mac = Mock(return_value=None) + obj.is_running_non_stop = True + return obj + + @pytest.mark.parametrize( - "gw_ip, cached_mac", + "gw_ips, mac_in_db, expected", [ - ("192.168.1.1", "00:11:22:33:44:55"), + ( + {"eth0": "192.168.1.1"}, + "AA:BB:CC:DD:EE:FF", + {"eth0": "AA:BB:CC:DD:EE:FF"}, + ), + ({"wlan0": "10.0.0.1"}, None, None), # no MAC found ], ) -def test_get_gateway_mac_cached(gw_ip, cached_mac): +def test_mac_found_in_db_or_not(ip_info, gw_ips, mac_in_db, expected): + """Test when MAC exists or not in the DB.""" ip_info = ModuleFactory().create_ip_info_obj() - ip_info.db.get_mac_addr_from_profile.return_value = cached_mac + ip_info.db.get_mac_addr_from_profile.return_value = mac_in_db + ip_info.is_running_non_stop = True + result = ip_info.get_gateway_mac(gw_ips) - result = ip_info.get_gateway_mac(gw_ip) + assert result == expected - assert result == cached_mac - ip_info.db.get_mac_addr_from_profile.assert_called_once_with( - f"profile_{gw_ip}" - ) + +def test_non_stop_false_skips_mac_lookup(): + """Should skip all lookups when not running non-stop.""" + ip_info = ModuleFactory().create_ip_info_obj() + ip_info.is_running_non_stop = False + ip_info.db.get_mac_addr_from_profile = Mock(return_value=None) + ip_info._get_mac_using_ip_neigh = Mock() + ip_info._get_mac_using_arp_cache = Mock() + result = ip_info.get_gateway_mac({"eth0": "192.168.1.1"}) + + assert result is None + ip_info._get_mac_using_ip_neigh.assert_not_called() + ip_info._get_mac_using_arp_cache.assert_not_called() -@pytest.mark.parametrize("gw_ip", ["192.168.0.1"]) -def test_get_gateway_mac_not_found(mocker, gw_ip): +def test_wifi_interface_uses_own_mac(): + """If interface is WiFi AP, use get_own_mac().""" ip_info = ModuleFactory().create_ip_info_obj() + ip_info._get_wifi_interface_if_ap = Mock(return_value="wlan0") + ip_info.get_own_mac = Mock(return_value="AA:AA:AA:AA:AA:AA") + 
ip_info.db.get_mac_addr_from_profile = Mock(return_value=None) + result = ip_info.get_gateway_mac({"wlan0": "10.0.0.1"}) - ip_info.db.get_mac_addr_from_profile.return_value = None - ip_info.is_running_non_stop = True - ip_info.is_running_in_ap_mode = False + assert result == {"wlan0": "AA:AA:AA:AA:AA:AA"} - mocker.patch("sys.argv", ["-i", "eth0"]) - mock_subprocess_run = mocker.patch("subprocess.run") - mock_subprocess_run.side_effect = subprocess.CalledProcessError( - 1, "ip neigh" - ) - mocker.patch( - "slips_files.common.slips_utils.utils.get_mac_for_ip_using_cache", - side_effect=subprocess.CalledProcessError(1, "arp"), - ) +def test_mac_from_ip_neigh(): + """Should use ip neigh if available.""" + ip_info = ModuleFactory().create_ip_info_obj() + ip_info.db.get_mac_addr_from_profile = Mock(return_value=None) + ip_info._get_mac_using_ip_neigh = Mock(return_value="AA:BB:CC:11:22:33") - result = ip_info.get_gateway_mac(gw_ip) + result = ip_info.get_gateway_mac({"eth0": "172.16.0.1"}) - assert result is None - assert mock_subprocess_run.call_count == 1 # only ip neigh is called - ip_info.db.set_default_gateway.assert_not_called() + assert result == {"eth0": "AA:BB:CC:11:22:33"} + ip_info._get_mac_using_ip_neigh.assert_called_once_with("172.16.0.1") -@pytest.mark.parametrize("gw_ip", ["172.16.0.1"]) -def test_get_gateway_mac_ip_command_failure(mocker, gw_ip): +def test_mac_from_arp_cache_if_ip_neigh_fails(): + """Should fallback to ARP cache if ip neigh fails.""" ip_info = ModuleFactory().create_ip_info_obj() + ip_info.db.get_mac_addr_from_profile = Mock(return_value=None) + ip_info._get_mac_using_ip_neigh = Mock(return_value=None) + ip_info._get_mac_using_arp_cache = Mock(return_value="FF:EE:DD:CC:BB:AA") - ip_info.db.get_mac_addr_from_profile.return_value = None - ip_info.is_running_non_stop = True - ip_info.is_running_in_ap_mode = False - - mocker.patch("sys.argv", ["-i", "eth0"]) + result = ip_info.get_gateway_mac({"eth0": "192.168.0.1"}) - mock_subprocess_run = 
mocker.patch("subprocess.run") - mock_subprocess_run.side_effect = subprocess.CalledProcessError( - 1, "ip neigh" - ) + assert result == {"eth0": "FF:EE:DD:CC:BB:AA"} + ip_info._get_mac_using_arp_cache.assert_called_once_with("192.168.0.1") - mocker.patch( - "slips_files.common.slips_utils.utils.get_mac_for_ip_using_cache", - side_effect=subprocess.CalledProcessError(1, "arp"), - ) - result = ip_info.get_gateway_mac(gw_ip) +def test_all_methods_fail_returns_none(): + """If all lookups fail, return None.""" + ip_info = ModuleFactory().create_ip_info_obj() + ip_info.db.get_mac_addr_from_profile = Mock(return_value=None) + ip_info._get_mac_using_ip_neigh = Mock(return_value=None) + ip_info._get_mac_using_arp_cache = Mock(return_value=None) + ip_info.get_own_mac = Mock(return_value=None) + result = ip_info.get_gateway_mac({"eth0": "1.1.1.1"}) assert result is None - assert mock_subprocess_run.call_count == 1 - mock_subprocess_run.assert_called_once_with( - ["ip", "neigh", "show", gw_ip], - capture_output=True, - check=True, - text=True, - ) - ip_info.db.set_default_gateway.assert_not_called() @pytest.mark.parametrize( diff --git a/tests/test_process_manager.py b/tests/test_process_manager.py index 06b4d1b5b..765aa73a4 100644 --- a/tests/test_process_manager.py +++ b/tests/test_process_manager.py @@ -317,7 +317,9 @@ def test_should_run_non_stop( process_manager = ModuleFactory().create_process_manager_obj() process_manager.is_debugger_active = Mock(return_value=debugger_active) process_manager.main.input_type = input_type - process_manager.main.is_interface = is_interface + process_manager.main.db.is_running_non_stop = Mock( + return_value=is_interface + ) assert process_manager.should_run_non_stop() == expected diff --git a/tests/test_profile_handler.py b/tests/test_profile_handler.py index 222c4fdd1..9a93eabc1 100644 --- a/tests/test_profile_handler.py +++ b/tests/test_profile_handler.py @@ -1515,7 +1515,7 @@ def test_add_mac_addr_to_profile_no_existing_mac(): 
handler.r.hmget.return_value = [None] handler.update_mac_of_profile = MagicMock() - result = handler.add_mac_addr_to_profile(profileid, mac_addr) + result = handler.add_mac_addr_to_profile(profileid, mac_addr, "eth0") handler.r.hmget.assert_called_once_with("MAC", mac_addr) handler.r.hset.assert_called_once_with( @@ -1537,7 +1537,7 @@ def test_add_mac_addr_to_profile_existing_mac(): # this should make [incoming_ip in cached_ips] True handler.r.hmget.return_value = [json.dumps([profileid.split("_")[1]])] handler.update_mac_of_profile = MagicMock() - result = handler.add_mac_addr_to_profile(profileid, mac_addr) + result = handler.add_mac_addr_to_profile(profileid, mac_addr, "eth0") assert result is False handler.r.hmget.assert_called_once_with("MAC", mac_addr) diff --git a/tests/test_profiler.py b/tests/test_profiler.py index ec086e4b4..0d68bbeff 100644 --- a/tests/test_profiler.py +++ b/tests/test_profiler.py @@ -166,7 +166,7 @@ def get_zeek_flow(file, flow_type): sample_flow = f.readline().replace("\n", "") sample_flow = json.loads(sample_flow) - sample_flow = {"data": sample_flow, "type": flow_type} + sample_flow = {"data": sample_flow, "type": flow_type, "interface": "eth0"} return sample_flow @@ -210,9 +210,7 @@ def test_process_line( if flow_type == "conn": flow_added = profiler.db.get_flow(flow.uid, twid=twid)[flow.uid] else: - flow_added = profiler.db.get_altflow_from_uid( - profileid, twid, flow.uid - ) + flow_added = profiler.db.get_altflow_from_uid(flow.uid) assert flow_added @@ -375,28 +373,26 @@ def test_convert_starttime_to_epoch_invalid_format(monkeypatch): @pytest.mark.parametrize( - "saddr, is_localnet_set, expected_result", + "saddr, localnet_cache, running_non_stop, expected_result", [ - ("192.168.1.1", False, True), - ("192.168.1.1", True, False), - ("8.8.8.8", False, False), + ("192.168.1.1", {}, True, True), + ("192.168.1.1", {"eth0": "some_ip"}, True, False), + ("8.8.8.8", {"default": "ip"}, False, False), ], ) -def 
test_should_set_localnet(saddr, is_localnet_set, expected_result): +def test_should_set_localnet( + saddr, localnet_cache, running_non_stop, expected_result +): profiler = ModuleFactory().create_profiler_obj() + profiler.db.is_running_non_stop = Mock(return_value=running_non_stop) + flow = Mock() flow.saddr = saddr - profiler.is_localnet_set = is_localnet_set + flow.interface = "eth0" - assert profiler.should_set_localnet(flow) == expected_result + profiler.localnet_cache = localnet_cache - -def test_should_set_localnet_already_set(): - profiler = ModuleFactory().create_profiler_obj() - profiler.is_localnet_set = True - flow = Mock(saddr="1.1.1.1") - result = profiler.should_set_localnet(flow) - assert result is False + assert profiler.should_set_localnet(flow) == expected_result def test_check_for_stop_msg(monkeypatch): @@ -589,6 +585,9 @@ def test_get_local_net(client_ips, saddr, expected_cidr): profiler = ModuleFactory().create_profiler_obj() profiler.args.interface = None + flow = Mock() + flow.saddr = saddr + if not client_ips: with patch.object( profiler, "get_private_client_ips", return_value=client_ips @@ -596,18 +595,14 @@ def test_get_local_net(client_ips, saddr, expected_cidr): "slips_files.common.slips_utils.Utils.get_cidr_of_private_ip", return_value="10.0.0.0/8", ): - flow = Mock() - flow.saddr = saddr - local_net = profiler.get_local_net(flow) + local_net = profiler.get_local_net_of_flow(flow) else: with patch.object( profiler, "get_private_client_ips", return_value=client_ips ): - flow = Mock() - flow.saddr = saddr - local_net = profiler.get_local_net(flow) + local_net = profiler.get_local_net_of_flow(flow) - assert local_net == expected_cidr + assert local_net == {"default": expected_cidr} def test_get_local_net_from_flow(): @@ -621,34 +616,34 @@ def test_get_local_net_from_flow(): ): flow = Mock() flow.saddr = "10.0.0.1" - local_net = profiler.get_local_net(flow) + local_net = profiler.get_local_net_of_flow(flow) - assert local_net == "10.0.0.0/8" 
+ assert local_net == {"default": "10.0.0.0/8"} def test_handle_setting_local_net_when_already_set(): profiler = ModuleFactory().create_profiler_obj() - profiler.is_localnet_set = True + local_net = "192.168.1.0/24" + profiler.should_set_localnet = Mock(return_value=False) + profiler.localnet_cache = {"default": local_net} flow = Mock() profiler.handle_setting_local_net(flow) profiler.db.set_local_network.assert_not_called() -def test_handle_setting_local_net(monkeypatch): +def test_handle_setting_local_net(): profiler = ModuleFactory().create_profiler_obj() + local_net = "192.168.1.0/24" + profiler.should_set_localnet = Mock(return_value=True) + profiler.get_local_net_of_flow = Mock(return_value={"default": local_net}) + profiler.get_local_net = Mock(return_value=local_net) + profiler.db.is_running_non_stop = Mock(return_value=False) + flow = Mock() flow.saddr = "192.168.1.1" - monkeypatch.setattr( - profiler, "should_set_localnet", Mock(return_value=True) - ) - - monkeypatch.setattr( - profiler, "get_local_net", Mock(return_value="192.168.1.0/24") - ) - profiler.handle_setting_local_net(flow) - profiler.db.set_local_network.assert_called_once_with("192.168.1.0/24") + profiler.db.set_local_network.assert_called_once_with(local_net, "default") def test_notify_observers_no_observers(): @@ -709,12 +704,13 @@ def test_get_gateway_info_sets_mac_and_ip( dmac="00:11:22:33:44:55", state="Established", history="", + interface="eth0", ) profiler.get_gateway_info(flow) - profiler.db.set_default_gateway.assert_any_call("MAC", flow.dmac) - profiler.db.set_default_gateway.assert_any_call("IP", "8.8.8.1") + profiler.db.set_default_gateway.assert_any_call("MAC", flow.dmac, "eth0") + profiler.db.set_default_gateway.assert_any_call("IP", "8.8.8.1", "eth0") @patch("slips_files.core.profiler.utils.is_private_ip") @@ -771,25 +767,28 @@ def test_get_gateway_info_mac_detected_but_no_ip(): @pytest.mark.parametrize( "info_type, attr_name, db_method, db_value", [ - ("mac", "gw_mac", 
"get_gateway_mac", "00:1A:2B:3C:4D:5E"), - ("ip", "gw_ip", "get_gateway_ip", "192.168.1.1"), + ("mac", "gw_macs", "get_gateway_mac", "00:1A:2B:3C:4D:5E"), + ("ip", "gw_ips", "get_gateway_ip", "192.168.1.1"), ], ) def test_is_gw_info_detected(info_type, attr_name, db_method, db_value): # create a profiler object using the ModuleFactory profiler = ModuleFactory().create_profiler_obj() - # mock the profiler's database methods and attributes - setattr(profiler, attr_name, None) - getattr(profiler.db, db_method).return_value = db_value + # ensure gw_macs / gw_ips exist as dicts + setattr(profiler, attr_name, {}) + + # mock the db method + mock_method = Mock(return_value=db_value) + setattr(profiler.db, db_method, mock_method) - # test with info_type - result = profiler.is_gw_info_detected(info_type) + # call the function + result = profiler.is_gw_info_detected(info_type, "eth0") - # assertions - assert result - assert getattr(profiler, attr_name) == db_value - getattr(profiler.db, db_method).assert_called_once() + # verify + assert result is True + assert getattr(profiler, attr_name)["eth0"] == db_value + mock_method.assert_called_once_with("eth0") def test_is_gw_info_detected_unsupported_info_type(): @@ -798,26 +797,11 @@ def test_is_gw_info_detected_unsupported_info_type(): # test with an unsupported info_type with pytest.raises(ValueError) as exc_info: - profiler.is_gw_info_detected("unsupported_type") + profiler.is_gw_info_detected("unsupported_type", "eth0") - # assertion assert str(exc_info.value) == "Unsupported info_type: unsupported_type" -def test_is_gw_info_detected_when_attribute_is_already_set(): - profiler = ModuleFactory().create_profiler_obj() - - # set gw_mac attribute to a value - profiler.gw_mac = "00:1A:2B:3C:4D:5E" - - # test with info_type "mac" - result = profiler.is_gw_info_detected("mac") - - # assertions - assert result - assert profiler.gw_mac == "00:1A:2B:3C:4D:5E" - - def test_process_flow_no_msg(): profiler = 
ModuleFactory().create_profiler_obj() profiler.stop_profiler_thread = Mock() diff --git a/tests/test_timeline.py b/tests/test_timeline.py index cec7472dc..92101cc6a 100644 --- a/tests/test_timeline.py +++ b/tests/test_timeline.py @@ -567,7 +567,7 @@ def test_is_inbound_traffic( host_ip, daddr, analysis_direction, expected_result ): timeline = ModuleFactory().create_timeline_object() - timeline.host_ip = host_ip + timeline.host_ips = [host_ip] timeline.analysis_direction = analysis_direction flow = Mock() flow.daddr = daddr diff --git a/zeek-scripts/slips-conf.zeek b/zeek-scripts/slips-conf.zeek index 9be82512a..fdf8178e0 100644 --- a/zeek-scripts/slips-conf.zeek +++ b/zeek-scripts/slips-conf.zeek @@ -1,4 +1,5 @@ redef LogAscii::use_json=T; +redef tcp_attempt_delay=1min; function get_mmdb_path(): string { local curdir = @DIR; @@ -23,4 +24,4 @@ redef mmdb_dir = get_mmdb_path(); # zeek only tracks software for local networks by default to conserve memory. # this setting makes it do software for all networks -redef Software::asset_tracking = ALL_HOSTS; \ No newline at end of file +redef Software::asset_tracking = ALL_HOSTS;