diff --git a/ci/run_tests/run_tests.py b/ci/run_tests/run_tests.py index df4f3fe373..b606708202 100644 --- a/ci/run_tests/run_tests.py +++ b/ci/run_tests/run_tests.py @@ -3,22 +3,9 @@ import os import subprocess import argparse -from contextlib import contextmanager from scapy.all import get_if_addr -PCAP_FILE_PATH = os.path.join("Tests", "Pcap++Test", "PcapExamples", "example.pcap") - - -@contextmanager -def tcp_replay_worker(interface: str, tcpreplay_dir: str): - tcpreplay_proc = subprocess.Popen( - ["tcpreplay", "-i", interface, "--mbps=10", "-l", "0", PCAP_FILE_PATH], - cwd=tcpreplay_dir, - ) - try: - yield tcpreplay_proc - finally: - tcpreplay_proc.kill() +from tcp_replay_utils import TcpReplay, PCAP_FILE_PATH def run_packet_tests(args: list[str], use_sudo: bool): @@ -31,11 +18,13 @@ def run_packet_tests(args: list[str], use_sudo: bool): raise RuntimeError(f"Error while executing Packet++ tests: {completed_process}") -def run_pcap_tests(interface: str, tcpreplay_dir: str, args: list[str], use_sudo: bool): +def run_pcap_tests( + interface: str, tcp_replay: TcpReplay, args: list[str], use_sudo: bool +): ip_address = get_if_addr(interface) print(f"IP address is: {ip_address}") - with tcp_replay_worker(interface, tcpreplay_dir): + with tcp_replay.replay(interface, PCAP_FILE_PATH): cmd_line = ["sudo"] if use_sudo else [] cmd_line += [os.path.join("Bin", "Pcap++Test"), "-i", ip_address, *args] @@ -84,9 +73,11 @@ def main(): run_packet_tests(args.packet_test_args.split(), args.use_sudo) if "pcap" in args.test_suites: + tcp_replay = TcpReplay(args.tcpreplay_dir) + run_pcap_tests( args.interface, - args.tcpreplay_dir, + tcp_replay, args.pcap_test_args.split(), args.use_sudo, ) diff --git a/ci/run_tests/run_tests_windows.py b/ci/run_tests/run_tests_windows.py index 532241bfeb..894f1f8d4a 100644 --- a/ci/run_tests/run_tests_windows.py +++ b/ci/run_tests/run_tests_windows.py @@ -4,10 +4,9 @@ import scapy.arch.windows from ipaddress import IPv4Address +from tcp_replay_utils import TcpReplay, PCAP_FILE_PATH + TCPREPLAY_PATH = "tcpreplay-4.4.1-win" -PCAP_FILE_PATH = os.path.abspath( - os.path.join("Tests", "Pcap++Test", "PcapExamples", "example.pcap") -) def validate_ipv4_address(address): @@ -28,38 +27,99 @@ def get_ip_by_guid(guid): return None -def find_interface(): - completed_process = subprocess.run( - ["tcpreplay.exe", "--listnics"], +def find_interface(tcp_replay: TcpReplay): + nic_devices = tcp_replay.get_nic_list() + + for device in nic_devices: + nic_guid = device.lstrip("\\Device\\NPF_") + ip_address = get_ip_by_guid(nic_guid) + + if ip_address and not ip_address.startswith("169.254"): + completed_process = subprocess.run( + ["curl", "--interface", ip_address, "www.google.com"], + capture_output=True, + shell=True, + ) + if completed_process.returncode == 0: + return device, ip_address + + return None, None + + +def run_packet_tests(): + return subprocess.run( + os.path.join("Bin", "Packet++Test"), + cwd=os.path.join("Tests", "Packet++Test"), shell=True, - capture_output=True, - cwd=TCPREPLAY_PATH, + check=True, # Raise exception if the worker returns in non-zero status code ) - if completed_process.returncode != 0: - print('Error executing "tcpreplay.exe --listnics"!') - exit(1) - raw_nics_output = completed_process.stdout.decode("utf-8") - for row in raw_nics_output.split("\n")[2:]: - columns = row.split("\t") - if len(columns) > 1 and columns[1].startswith("\\Device\\NPF_"): - interface = columns[1] - try: - nic_guid = interface.lstrip("\\Device\\NPF_") - ip_address = get_ip_by_guid(nic_guid) - if ip_address.startswith("169.254"): - continue - completed_process = subprocess.run( - ["curl", "--interface", ip_address, "www.google.com"], - capture_output=True, - shell=True, - ) - if completed_process.returncode != 0: - continue - return interface, ip_address - except Exception: - pass - return None, None + +def run_packet_coverage(): + return subprocess.run( + [ + "OpenCppCoverage.exe", + "--verbose", + "--sources", + "Packet++", + "--sources", + "Pcap++", + "--sources", + "Common++", + "--excluded_sources", + "Tests", + "--export_type", + "cobertura:Packet++Coverage.xml", + "--", + os.path.join("Bin", "Packet++Test"), + ], + cwd=os.path.join("Tests", "Packet++Test"), + shell=True, + check=True, # Raise exception if the worker returns in non-zero status code + ) + + +def run_pcap_tests(ip_address: str, skip_tests: list[str]): + return subprocess.run( + [ + os.path.join("Bin", "Pcap++Test"), + "-i", + ip_address, + "-x", + ";".join(skip_tests), + ], + cwd=os.path.join("Tests", "Pcap++Test"), + shell=True, + check=True, # Raise exception if the worker returns in non-zero status code + ) + + +def run_pcap_coverage(ip_address: str, skip_tests: list[str]): + return subprocess.run( + [ + "OpenCppCoverage.exe", + "--verbose", + "--sources", + "Packet++", + "--sources", + "Pcap++", + "--sources", + "Common++", + "--excluded_sources", + "Tests", + "--export_type", + "cobertura:Pcap++Coverage.xml", + "--", + os.path.join("Bin", "Pcap++Test"), + "-i", + ip_address, + "-x", + ";".join(skip_tests), + ], + cwd=os.path.join("Tests", "Pcap++Test"), + shell=True, + check=True, # Raise exception if the worker returns in non-zero status code + ) def main(): @@ -81,93 +141,26 @@ def main(): ) args = parser.parse_args() - tcpreplay_interface, ip_address = find_interface() + if args.coverage: + run_packet_coverage() + else: + run_packet_tests() + + tcp_replay = TcpReplay(TCPREPLAY_PATH) + + tcpreplay_interface, ip_address = find_interface(tcp_replay) if not tcpreplay_interface or not ip_address: print("Cannot find an interface to run tests on!") exit(1) - print(f"Interface is {tcpreplay_interface} and IP address is {ip_address}") - - try: - tcpreplay_cmd = ( - f'tcpreplay.exe -i "{tcpreplay_interface}" --mbps=10 -l 0 {PCAP_FILE_PATH}' - ) - tcpreplay_proc = subprocess.Popen(tcpreplay_cmd, shell=True, cwd=TCPREPLAY_PATH) - if args.coverage: - completed_process = subprocess.run( - [ - "OpenCppCoverage.exe", - "--verbose", - "--sources", - "Packet++", - "--sources", - "Pcap++", - "--sources", - "Common++", - "--excluded_sources", - "Tests", - "--export_type", - "cobertura:Packet++Coverage.xml", - "--", - os.path.join("Bin", "Packet++Test"), - ], - cwd=os.path.join("Tests", "Packet++Test"), - shell=True, - ) - else: - completed_process = subprocess.run( - os.path.join("Bin", "Packet++Test"), - cwd=os.path.join("Tests", "Packet++Test"), - shell=True, - ) - if completed_process.returncode != 0: - print("Error while executing Packet++ tests: " + str(completed_process)) - exit(completed_process.returncode) + print(f"Interface is {tcpreplay_interface} and IP address is {ip_address}") - skip_tests = ["TestRemoteCapture"] + args.skip_tests + skip_tests = ["TestRemoteCapture"] + args.skip_tests + with tcp_replay.replay(tcpreplay_interface, PCAP_FILE_PATH): if args.coverage: - completed_process = subprocess.run( - [ - "OpenCppCoverage.exe", - "--verbose", - "--sources", - "Packet++", - "--sources", - "Pcap++", - "--sources", - "Common++", - "--excluded_sources", - "Tests", - "--export_type", - "cobertura:Pcap++Coverage.xml", - "--", - os.path.join("Bin", "Pcap++Test"), - "-i", - ip_address, - "-x", - ";".join(skip_tests), - ], - cwd=os.path.join("Tests", "Pcap++Test"), - shell=True, - ) + run_pcap_coverage(ip_address, skip_tests) else: - completed_process = subprocess.run( - [ - os.path.join("Bin", "Pcap++Test"), - "-i", - ip_address, - "-x", - ";".join(skip_tests), - ], - cwd=os.path.join("Tests", "Pcap++Test"), - shell=True, - ) - if completed_process.returncode != 0: - print("Error while executing Pcap++ tests: " + str(completed_process)) - exit(completed_process.returncode) - - finally: - subprocess.call(["taskkill", "/F", "/T", "/PID", str(tcpreplay_proc.pid)]) + run_pcap_tests(ip_address, skip_tests) if __name__ == "__main__": diff --git a/ci/run_tests/tcp_replay_utils.py b/ci/run_tests/tcp_replay_utils.py new file mode 100644 index 0000000000..370fcf4c56 --- /dev/null +++ b/ci/run_tests/tcp_replay_utils.py @@ -0,0 +1,87 @@ +from __future__ import annotations + +import subprocess +import sys +from collections.abc import Generator +from dataclasses import dataclass +from contextlib import contextmanager +from pathlib import Path + +PCAP_FILE_PATH = Path("Tests", "Pcap++Test", "PcapExamples", "example.pcap").absolute() + + +@dataclass +class TcpReplayTask: + """A replay task that holds the tcpreplay instance and the subprocess procedure.""" + + replay: TcpReplay + procedure: subprocess.Popen + + +class TcpReplay: + def __init__(self, tcpreplay_dir: str | None = None): + """ + A wrapper class for managing tcpreplay operations. + + :param tcpreplay_dir: Directory where tcpreplay is located. If None, assumes tcpreplay is in the system PATH. + """ + if tcpreplay_dir is None: + self.executable = Path("tcpreplay") + else: + self.executable = Path(tcpreplay_dir) / "tcpreplay" + + if sys.platform == "win32": + self.executable = self.executable.with_suffix(".exe") + + # Checking for executable existence does not work if it's in PATH + subprocess.run([self.executable, "--version"], capture_output=True, check=True) + + @contextmanager + def replay( + self, interface: str, pcap_file: Path + ) -> Generator[TcpReplayTask, None, None]: + """ + Context manager that starts tcpreplay and yields a TcpReplayTask. + + :param interface: Network interface to use for replaying packets. + :param pcap_file: Path to the pcap file to replay. + """ + cmd = [self.executable, "-i", interface, "--mbps=10", "-l", "0", str(pcap_file)] + proc = subprocess.Popen(cmd) + try: + yield TcpReplayTask(replay=self, procedure=proc) + finally: + self._kill_process(proc) + + def get_nic_list(self): + """ + Get the list of network interfaces using tcpreplay. Only works on Windows. + + :return: List of network interface names. + """ + if sys.platform != "win32": + # We don't use it on non-Windows platforms yet. + raise RuntimeError("This method is only supported on Windows!") + + completed_process = subprocess.run( + [self.executable, "--listnics"], + capture_output=True, + ) + if completed_process.returncode != 0: + raise RuntimeError('Error executing "tcpreplay --listnics"!') + + raw_nics_output = completed_process.stdout.decode("utf-8") + nics = [] + for row in raw_nics_output.split("\n")[2:]: + columns = row.split("\t") + if len(columns) > 1 and columns[1].startswith("\\Device\\NPF_"): + nics.append(columns[1]) + return nics + + @staticmethod + def _kill_process(proc: subprocess.Popen) -> None: + if sys.platform == "win32": + # Use taskkill to kill the process and its children + subprocess.call(["taskkill", "/F", "/T", "/PID", str(proc.pid)]) + else: + proc.kill()