From ad7d535ef88613fdfcc8ffa864e85eb4ad752c3c Mon Sep 17 00:00:00 2001 From: Vasundhara Volam Date: Mon, 27 Oct 2025 21:23:25 +0000 Subject: [PATCH 1/8] Add module state trasition definitions --- sonic_platform_base/module_base.py | 260 ++++++++++- tests/module_base_test.py | 695 +++++++++++++++++++---------- 2 files changed, 702 insertions(+), 253 deletions(-) diff --git a/sonic_platform_base/module_base.py b/sonic_platform_base/module_base.py index bd333571a..187323ae7 100644 --- a/sonic_platform_base/module_base.py +++ b/sonic_platform_base/module_base.py @@ -8,6 +8,7 @@ import sys import os import fcntl +import time from . import device_base import json import threading @@ -28,6 +29,7 @@ class ModuleBase(device_base.DeviceBase): DEVICE_TYPE = "module" PCI_OPERATION_LOCK_FILE_PATH = "/var/lock/{}_pci.lock" SENSORD_OPERATION_LOCK_FILE_PATH = "/var/lock/sensord.lock" + TRANSITION_OPERATION_LOCK_FILE_PATH = "/var/lock/{}_transition.lock" # Possible card types for modular chassis MODULE_TYPE_SUPERVISOR = "SUPERVISOR" @@ -85,7 +87,7 @@ def __init__(self): self._thermal_list = [] self._voltage_sensor_list = [] self._current_sensor_list = [] - self.state_db_connector = None + self.state_db = self.initialize_state_db() self.pci_bus_info = None # List of SfpBase-derived objects representing all sfps @@ -96,6 +98,18 @@ def __init__(self): # visibile in PCI domain on the module self._asic_list = [] + def initialize_state_db(self): + """ + Initializes and returns the state database connector. + + Returns: + A database connector object for the state database. + """ + if self.state_db is None: + import sonic_py_common.daemon_base as daemon_base + self.state_db = daemon_base.db_connect("STATE_DB") + return self.state_db + @contextlib.contextmanager def _file_operation_lock(self, lock_file_path): """Common file-based lock for operations using flock""" @@ -120,6 +134,13 @@ def _sensord_operation_lock(self): with self._file_operation_lock(lock_file_path): yield + @contextlib.contextmanager + def _transition_operation_lock(self): + """File-based lock for module state transition operations using flock""" + lock_file_path = self.TRANSITION_OPERATION_LOCK_FILE_PATH.format(self.get_name()) + with self._file_operation_lock(lock_file_path): + yield + def get_base_mac(self): """ Retrieves the base MAC address for the module @@ -342,16 +363,12 @@ def pci_entry_state_db(self, pcie_string, operation): RuntimeError: If state database connection fails """ try: - # Do not use import if swsscommon is not needed - import swsscommon PCIE_DETACH_INFO_TABLE_KEY = PCIE_DETACH_INFO_TABLE+"|"+pcie_string - if not self.state_db_connector: - self.state_db_connector = swsscommon.swsscommon.DBConnector("STATE_DB", 0) if operation == PCIE_OPERATION_ATTACHING: - self.state_db_connector.delete(PCIE_DETACH_INFO_TABLE_KEY) + self.state_db.delete(PCIE_DETACH_INFO_TABLE_KEY) return - self.state_db_connector.hset(PCIE_DETACH_INFO_TABLE_KEY, "bus_info", pcie_string) - self.state_db_connector.hset(PCIE_DETACH_INFO_TABLE_KEY, "dpu_state", operation) + self.state_db.hset(PCIE_DETACH_INFO_TABLE_KEY, "bus_info", pcie_string) + self.state_db.hset(PCIE_DETACH_INFO_TABLE_KEY, "dpu_state", operation) except Exception as e: sys.stderr.write("Failed to write pcie bus info to state database: {}\n".format(str(e))) @@ -391,6 +408,233 @@ def pci_reattach(self): """ raise NotImplementedError + def set_admin_state_gracefully(self, up): + """ + Request to keep the module in administratively up/down state with graceful shutdown. + + This function is designed for SmartSwitch platforms to ensure a graceful shutdown + of the module when transitioning to the admin-down state. + + For non-SmartSwitch platforms, use the standard set_admin_state() method. + + Args: + up: A boolean, True to set the admin-state to UP. False to set the + admin-state to DOWN. + Returns: + bool: True if the request has been issued successfully, False if not + """ + module_name = self.get_name() + # Set the module state to administratively up. + if up: + if not self.set_module_state_transition(module_name, "startup"): + sys.stderr.write("Failed to set module state transition for admin state UP\n") + return False + + admin_status = self.set_admin_state(True) + + # This is only valid on platforms which have pci_rescan sensord changes required. If it is not implemented, + # there are no actions taken during this function execution. + if not self.module_post_startup(): + sys.stderr.write("module_post_startup() failed\n") + + if not self.clear_module_state_transition(module_name): + sys.stderr.write("Failed to clear module state transition for admin state UP\n") + + return admin_status + else: + # Initiate graceful shutdown before setting admin state to down. + if not self.set_module_state_transition(module_name, "shutdown"): + sys.stderr.write("Failed to set module state transition for admin state DOWN\n") + return False + + # This is only valid on platforms which have pci_detach and sensord changes required. If it is not implemented, + # there are no actions taken during this function execution. + if not self.module_pre_shutdown(): + sys.stderr.write("module_pre_shutdown() failed\n") + + if not self._graceful_shutdown_handler(): + sys.stderr.write("Graceful shutdown handler failed or timed out for module: {}\n".format(module_name)) + # Proceeding with admin down even if graceful shutdown fails. + + admin_status = self.set_admin_state(False) + + if not self.clear_module_state_transition(module_name): + sys.stderr.write("Failed to clear module state transition for admin state DOWN\n") + + return admin_status + + ############################################## + # Smartswitch module helpers + ############################################## + _TRANSITION_TIMEOUT_DEFAULTS = { + "startup": 300, # 5 mins + "shutdown": 180, # 3 mins + "reboot": 240, # 4 mins + } + + _TRANSITION_TIMEOUTS_CACHE = None + + def _load_transition_timeouts(self) -> dict: + """ + Loads module state transition timeouts from /usr/share/sonic/platform/platform.json if present, + otherwise fall back to _TRANSITION_TIMEOUT_DEFAULTS. + + Reads the following keys from the JSON file: + - dpu_startup_timeout + - dpu_shutdown_timeout + - dpu_reboot_timeout + Returns: + dict: A dictionary with transition types as keys and their corresponding timeouts + in seconds as values. + """ + if ModuleBase._TRANSITION_TIMEOUTS_CACHE is not None: + return ModuleBase._TRANSITION_TIMEOUTS_CACHE + + timeouts = self._TRANSITION_TIMEOUT_DEFAULTS.copy() + platform_json_path = "/usr/share/sonic/platform/platform.json" + + try: + if os.path.exists(platform_json_path): + with open(platform_json_path, 'r') as f: + platform_data = json.load(f) + timeouts["startup"] = int(platform_data.get("dpu_startup_timeout", timeouts["startup"])) + timeouts["shutdown"] = int(platform_data.get("dpu_shutdown_timeout", timeouts["shutdown"])) + timeouts["reboot"] = int(platform_data.get("dpu_reboot_timeout", timeouts["reboot"])) + except Exception as e: + sys.stderr.write("Error loading transition timeouts from {}: {}\n".format(platform_json_path, str(e))) + + ModuleBase._TRANSITION_TIMEOUTS_CACHE = timeouts + return timeouts + + def _graceful_shutdown_handler(self): + """ + Initiates a graceful shutdown of the module invoked with state transition in progress. + + Returns: + bool: True if the shutdown process was initiated successfully, False otherwise. + """ + module_name = self.get_name() + + shutdown_timeout = self._load_transition_timeouts().get("shutdown", 180) + end_time = time.time() + shutdown_timeout + interval = 5 # seconds + + while time.time() <= end_time: + # (a) External completion: transition flag cleared by external process + if not self.get_module_state_transition(module_name): + return True + + time.sleep(interval) + + # (b) Timeout completion: proceed with shutdown after timeout + if time.time() >= end_time: + self.clear_module_state_transition(module_name) + sys.stderr.write("Shutdown timeout reached for module: {}. Proceeding with shutdown.\n".format(module_name)) + return True + + return False + + # ############################################################ + # Centralized APIs for CHASSIS_MODULE_TABLE transition flags + # ############################################################ + def set_module_state_transition(self, module_name, transition_type): + """ + Sets the module state transition flag 'state_transition_in_progress ' in the CHASSIS_MODULE_TABLE. + + Args: + db: The database connector. + module_name: The name of the module. + transition_type: The type of transition (e.g., "startup", "shutdown", "reset"). + Returns: + bool: Returns True if the flag is successfully set. + + If the flag was already set but has timed out, the function resets the flag and returns True. + Returns False in all other cases. + """ + transition_type = transition_type.lower() + if transition_type not in self._TRANSITION_TIMEOUT_DEFAULTS.keys(): + sys.stderr.write("Invalid transition type: {}\n".format(transition_type)) + return False + + module_name = module_name.upper() + module_key = "CHASSIS_MODULE_TABLE|" + module_name + db = self.state_db + + with self._transition_operation_lock(): + try: + current_flag = db.hget(module_key, "state_transition_in_progress") + if current_flag is None: + # Flag not set, set it now + db.hset(module_key, "state_transition_in_progress", "True") + db.hset(module_key, "transition_type", transition_type) + db.hset(module_key, "transition_start_time", str(int(time.time()))) + return True + else: + # Flag already set, check for timeout + start_time_str = db.hget(module_key, "transition_start_time") + if start_time_str is None: + sys.stderr.write("Missing start time for transition flag on module: {}\n".format(module_name)) + return False + + start_time = int(start_time_str) + current_time = int(time.time()) + timeout = self._load_transition_timeouts().get(transition_type, 0) + if current_time - start_time > timeout: + # Timeout occurred, reset the flag + db.hset(module_key, "state_transition_in_progress", "True") + db.hset(module_key, "transition_type", transition_type) + db.hset(module_key, "transition_start_time", str(current_time)) + return True + else: + # Still within timeout period + sys.stderr.write("Transition already in progress for module: {}\n".format(module_name)) + return False + except Exception as e: + sys.stderr.write("Error setting transition flag for module {}: {}\n".format(module_name, str(e))) + return False + + def clear_module_state_transition(self, module_name): + """ + Clears the module state transition flag 'state_transition_in_progress ' in the CHASSIS_MODULE_TABLE. + + Args: + db: The database connector. + module_name: The name of the module. + Returns: + bool: Returns True if the flag is successfully cleared, False otherwise. + """ + module_name = module_name.upper() + module_key = "CHASSIS_MODULE_TABLE|" + module_name + + with self._transition_operation_lock(): + try: + self.state_db.hdel(module_key, "state_transition_in_progress") + self.state_db.hdel(module_key, "transition_type") + self.state_db.hdel(module_key, "transition_start_time") + return True + except Exception as e: + sys.stderr.write("Error clearing transition flag for module {}: {}\n".format(module_name, str(e))) + return False + + def get_module_state_transition(self, module_name): + """ + Retrieves the module state transition flag 'state_transition_in_progress ' from the CHASSIS_MODULE_TABLE. + + Args: + db: The database connector. + module_name: The name of the module. + Returns: + bool: Returns True if the flag is set, False otherwise. + """ + module_name = module_name.upper() + module_key = "CHASSIS_MODULE_TABLE|" + module_name + + try: + current_flag = self.state_db.hget(module_key, "state_transition_in_progress") + return current_flag == "True" + except Exception as e: + return False + ############################################## # Component methods ############################################## diff --git a/tests/module_base_test.py b/tests/module_base_test.py index 025849e9f..c78bde477 100644 --- a/tests/module_base_test.py +++ b/tests/module_base_test.py @@ -1,264 +1,469 @@ -from sonic_platform_base.module_base import ModuleBase -import pytest +# Unit tests for sonic_platform_base.module_base.ModuleBase import json -import os -import fcntl -from unittest.mock import patch, MagicMock, call -from io import StringIO -import shutil - -class MockFile: - def __init__(self, data=None): - self.data = data - self.written_data = None - self.closed = False - self.fileno_called = False - - def __enter__(self): - return self +import time +from unittest.mock import MagicMock, patch, call +import pytest - def __exit__(self, *args): - self.closed = True +from sonic_platform_base.module_base import ModuleBase - def read(self): - return self.data - def write(self, data): - self.written_data = data +class MockFile: + """Minimal file-like object with a stable fileno() for flock tests.""" + def __init__(self, data=""): + self._data = data + self._closed = False + self.fileno_called = False + def __enter__(self): return self + def __exit__(self, *a): self._closed = True + def read(self): return self._data + def write(self, d): self._data = d def fileno(self): self.fileno_called = True return 123 class TestModuleBase: - - def test_module_base(self): - module = ModuleBase() - not_implemented_methods = [ - [module.get_dpu_id], - [module.get_reboot_cause], - [module.get_state_info], - [module.get_pci_bus_info], - [module.pci_detach], - [module.pci_reattach], - ] - - for method in not_implemented_methods: - exception_raised = False - try: - func = method[0] - args = method[1:] - func(*args) - except NotImplementedError: - exception_raised = True - - assert exception_raised - - def test_sensors(self): - module = ModuleBase() - assert(module.get_num_voltage_sensors() == 0) - assert(module.get_all_voltage_sensors() == []) - assert(module.get_voltage_sensor(0) == None) - module._voltage_sensor_list = ["s1"] - assert(module.get_all_voltage_sensors() == ["s1"]) - assert(module.get_voltage_sensor(0) == "s1") - assert(module.get_num_current_sensors() == 0) - assert(module.get_all_current_sensors() == []) - assert(module.get_current_sensor(0) == None) - module._current_sensor_list = ["s1"] - assert(module.get_all_current_sensors() == ["s1"]) - assert(module.get_current_sensor(0) == "s1") - + # ------------------------------------------------------------------ Setup -- + def setup_method(self): + # Prevent real DB connection during ModuleBase __init__ + self._db_patcher = patch("sonic_py_common.daemon_base.db_connect", lambda *a, **k: None) + self._db_patcher.start() + self.module = ModuleBase() + + def teardown_method(self): + self._db_patcher.stop() + + # ------------------------------------------------------ Not Implemented API -- + @pytest.mark.parametrize( + "method_name", + ["get_dpu_id", "get_reboot_cause", "get_state_info", "get_pci_bus_info", "pci_detach", "pci_reattach"], + ) + def test_not_implemented_methods_raise(self, method_name): + with pytest.raises(NotImplementedError): + getattr(self.module, method_name)() + + # -------------------------------------------------------------- Sensors API -- + def test_sensors_api(self): + assert self.module.get_num_voltage_sensors() == 0 + assert self.module.get_all_voltage_sensors() == [] + assert self.module.get_voltage_sensor(0) is None + assert self.module.get_num_current_sensors() == 0 + assert self.module.get_all_current_sensors() == [] + assert self.module.get_current_sensor(0) is None + + self.module._voltage_sensor_list = ["s1"] + self.module._current_sensor_list = ["s1"] + assert self.module.get_all_voltage_sensors() == ["s1"] + assert self.module.get_voltage_sensor(0) == "s1" + assert self.module.get_all_current_sensors() == ["s1"] + assert self.module.get_current_sensor(0) == "s1" + + # --------------------------------------------------------- PCI state in DB -- def test_pci_entry_state_db(self): - module = ModuleBase() - mock_connector = MagicMock() - module.state_db_connector = mock_connector + db = MagicMock() + self.module.state_db = db - module.pci_entry_state_db("0000:00:00.0", "detaching") - mock_connector.hset.assert_has_calls([ + self.module.pci_entry_state_db("0000:00:00.0", "detaching") + db.hset.assert_has_calls([ call("PCIE_DETACH_INFO|0000:00:00.0", "bus_info", "0000:00:00.0"), - call("PCIE_DETACH_INFO|0000:00:00.0", "dpu_state", "detaching") + call("PCIE_DETACH_INFO|0000:00:00.0", "dpu_state", "detaching"), ]) - module.pci_entry_state_db("0000:00:00.0", "attaching") - mock_connector.delete.assert_called_with("PCIE_DETACH_INFO|0000:00:00.0") - - mock_connector.hset.side_effect = Exception("DB Error") - module.pci_entry_state_db("0000:00:00.0", "detaching") - - def test_file_operation_lock(self): - module = ModuleBase() - mock_file = MockFile() - - with patch('builtins.open', return_value=mock_file) as mock_file_open, \ - patch('fcntl.flock') as mock_flock, \ - patch('os.makedirs') as mock_makedirs: - - with module._file_operation_lock("/var/lock/test.lock"): - mock_flock.assert_called_with(123, fcntl.LOCK_EX) - - mock_flock.assert_has_calls([ - call(123, fcntl.LOCK_EX), - call(123, fcntl.LOCK_UN) - ]) - assert mock_file.fileno_called - - def test_pci_operation_lock(self): - module = ModuleBase() - mock_file = MockFile() - - with patch('builtins.open', return_value=mock_file) as mock_file_open, \ - patch('fcntl.flock') as mock_flock, \ - patch.object(module, 'get_name', return_value="DPU0"), \ - patch('os.makedirs') as mock_makedirs: - - with module._pci_operation_lock(): - mock_flock.assert_called_with(123, fcntl.LOCK_EX) - - mock_flock.assert_has_calls([ - call(123, fcntl.LOCK_EX), - call(123, fcntl.LOCK_UN) - ]) - assert mock_file.fileno_called - - def test_sensord_operation_lock(self): - module = ModuleBase() - mock_file = MockFile() - - with patch('builtins.open', return_value=mock_file) as mock_file_open, \ - patch('fcntl.flock') as mock_flock, \ - patch.object(module, 'get_name', return_value="DPU0"), \ - patch('os.makedirs') as mock_makedirs: - - with module._sensord_operation_lock(): - mock_flock.assert_called_with(123, fcntl.LOCK_EX) - - mock_flock.assert_has_calls([ - call(123, fcntl.LOCK_EX), - call(123, fcntl.LOCK_UN) - ]) - assert mock_file.fileno_called - - def test_handle_pci_removal(self): - module = ModuleBase() - - with patch.object(module, 'get_pci_bus_info', return_value=["0000:00:00.0"]), \ - patch.object(module, 'pci_entry_state_db') as mock_db, \ - patch.object(module, 'pci_detach', return_value=True), \ - patch.object(module, '_pci_operation_lock') as mock_lock, \ - patch.object(module, 'get_name', return_value="DPU0"): - assert module.handle_pci_removal() is True - mock_db.assert_called_with("0000:00:00.0", "detaching") - mock_lock.assert_called_once() - - with patch.object(module, 'get_pci_bus_info', side_effect=Exception()): - assert module.handle_pci_removal() is False - - def test_handle_pci_rescan(self): - module = ModuleBase() - - with patch.object(module, 'get_pci_bus_info', return_value=["0000:00:00.0"]), \ - patch.object(module, 'pci_entry_state_db') as mock_db, \ - patch.object(module, 'pci_reattach', return_value=True), \ - patch.object(module, '_pci_operation_lock') as mock_lock, \ - patch.object(module, 'get_name', return_value="DPU0"): - assert module.handle_pci_rescan() is True - mock_db.assert_called_with("0000:00:00.0", "attaching") - mock_lock.assert_called_once() - - with patch.object(module, 'get_pci_bus_info', side_effect=Exception()): - assert module.handle_pci_rescan() is False - + self.module.pci_entry_state_db("0000:00:00.0", "attaching") + db.delete.assert_called_with("PCIE_DETACH_INFO|0000:00:00.0") + + db.hset.side_effect = Exception("DB Error") + self.module.pci_entry_state_db("0000:00:00.0", "detaching") # should not raise + + # -------------------------------------------------------------- File locks -- + @pytest.mark.parametrize( + "lock_method_name, extra", + [ + ("_file_operation_lock", {"lock_path": "/var/lock/test.lock"}), + ("_pci_operation_lock", {}), + ("_sensord_operation_lock", {}), + ("_transition_operation_lock", {}), + ], + ) + def test_lock_contexts(self, lock_method_name, extra): + mf = MockFile() + with patch("builtins.open", return_value=mf), \ + patch("fcntl.flock") as pflock, \ + patch("os.makedirs"), \ + patch.object(self.module, "get_name", return_value="DPU0"): + lock_ctx = getattr(self.module, lock_method_name) + if "lock_path" in extra: + with lock_ctx(extra["lock_path"]): + pass + else: + with lock_ctx(): + pass + + import fcntl + pflock.assert_has_calls([call(123, fcntl.LOCK_EX), call(123, fcntl.LOCK_UN)]) + assert mf.fileno_called + + # ---------------------------------------------------------- PCI operations -- + def test_handle_pci_removal_success(self): + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "get_pci_bus_info", return_value=["0000:00:00.0"]), \ + patch.object(self.module, "pci_entry_state_db") as mdb, \ + patch.object(self.module, "pci_detach", return_value=True), \ + patch.object(self.module, "_pci_operation_lock"): + assert self.module.handle_pci_removal() is True + mdb.assert_called_with("0000:00:00.0", "detaching") + + def test_handle_pci_removal_error(self): + with patch.object(self.module, "get_pci_bus_info", side_effect=Exception("boom")): + assert self.module.handle_pci_removal() is False + + def test_handle_pci_rescan_success(self): + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "get_pci_bus_info", return_value=["0000:00:00.0"]), \ + patch.object(self.module, "pci_entry_state_db") as mdb, \ + patch.object(self.module, "pci_reattach", return_value=True), \ + patch.object(self.module, "_pci_operation_lock"): + assert self.module.handle_pci_rescan() is True + mdb.assert_called_with("0000:00:00.0", "attaching") + + def test_handle_pci_rescan_error(self): + with patch.object(self.module, "get_pci_bus_info", side_effect=Exception("boom")): + assert self.module.handle_pci_rescan() is False + + # ---------------------------------------------------------- Sensor actions -- def test_handle_sensor_removal(self): - module = ModuleBase() - - with patch.object(module, 'get_name', return_value="DPU0"), \ - patch('os.path.exists', return_value=True), \ - patch('shutil.copy2') as mock_copy, \ - patch('os.system') as mock_system, \ - patch.object(module, '_sensord_operation_lock') as mock_lock: - assert module.handle_sensor_removal() is True - mock_copy.assert_called_once_with("/usr/share/sonic/platform/module_sensors_ignore_conf/ignore_sensors_DPU0.conf", - "/etc/sensors.d/ignore_sensors_DPU0.conf") - mock_system.assert_called_once_with("service sensord restart") - mock_lock.assert_called_once() - - with patch.object(module, 'get_name', return_value="DPU0"), \ - patch('os.path.exists', return_value=False), \ - patch('shutil.copy2') as mock_copy, \ - patch('os.system') as mock_system, \ - patch.object(module, '_sensord_operation_lock') as mock_lock: - assert module.handle_sensor_removal() is True - mock_copy.assert_not_called() - mock_system.assert_not_called() - mock_lock.assert_not_called() - - with patch.object(module, 'get_name', return_value="DPU0"), \ - patch('os.path.exists', return_value=True), \ - patch('shutil.copy2', side_effect=Exception("Copy failed")): - assert module.handle_sensor_removal() is False + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch("os.path.exists", return_value=True), \ + patch("shutil.copy2") as mcopy, \ + patch("os.system") as msys, \ + patch.object(self.module, "_sensord_operation_lock"): + assert self.module.handle_sensor_removal() is True + mcopy.assert_called_once_with( + "/usr/share/sonic/platform/module_sensors_ignore_conf/ignore_sensors_DPU0.conf", + "/etc/sensors.d/ignore_sensors_DPU0.conf", + ) + msys.assert_called_once_with("service sensord restart") + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch("os.path.exists", return_value=False), \ + patch("shutil.copy2") as mcopy, \ + patch("os.system") as msys, \ + patch.object(self.module, "_sensord_operation_lock") as mlock: + assert self.module.handle_sensor_removal() is True + mcopy.assert_not_called() + msys.assert_not_called() + mlock.assert_not_called() + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch("os.path.exists", return_value=True), \ + patch("shutil.copy2", side_effect=Exception("copy fail")): + assert self.module.handle_sensor_removal() is False def test_handle_sensor_addition(self): - module = ModuleBase() - - with patch.object(module, 'get_name', return_value="DPU0"), \ - patch('os.path.exists', return_value=True), \ - patch('os.remove') as mock_remove, \ - patch('os.system') as mock_system, \ - patch.object(module, '_sensord_operation_lock') as mock_lock: - assert module.handle_sensor_addition() is True - mock_remove.assert_called_once_with("/etc/sensors.d/ignore_sensors_DPU0.conf") - mock_system.assert_called_once_with("service sensord restart") - mock_lock.assert_called_once() - - with patch.object(module, 'get_name', return_value="DPU0"), \ - patch('os.path.exists', return_value=False), \ - patch('os.remove') as mock_remove, \ - patch('os.system') as mock_system, \ - patch.object(module, '_sensord_operation_lock') as mock_lock: - assert module.handle_sensor_addition() is True - mock_remove.assert_not_called() - mock_system.assert_not_called() - mock_lock.assert_not_called() - - with patch.object(module, 'get_name', return_value="DPU0"), \ - patch('os.path.exists', return_value=True), \ - patch('os.remove', side_effect=Exception("Remove failed")): - assert module.handle_sensor_addition() is False - - def test_module_pre_shutdown(self): - module = ModuleBase() - - # Test successful case - with patch.object(module, 'handle_pci_removal', return_value=True), \ - patch.object(module, 'handle_sensor_removal', return_value=True): - assert module.module_pre_shutdown() is True - - # Test PCI removal failure - with patch.object(module, 'handle_pci_removal', return_value=False), \ - patch.object(module, 'handle_sensor_removal', return_value=True): - assert module.module_pre_shutdown() is False - - # Test sensor removal failure - with patch.object(module, 'handle_pci_removal', return_value=True), \ - patch.object(module, 'handle_sensor_removal', return_value=False): - assert module.module_pre_shutdown() is False - - def test_module_post_startup(self): - module = ModuleBase() - - # Test successful case - with patch.object(module, 'handle_pci_rescan', return_value=True), \ - patch.object(module, 'handle_sensor_addition', return_value=True): - assert module.module_post_startup() is True - - # Test PCI rescan failure - with patch.object(module, 'handle_pci_rescan', return_value=False), \ - patch.object(module, 'handle_sensor_addition', return_value=True): - assert module.module_post_startup() is False - - # Test sensor addition failure - with patch.object(module, 'handle_pci_rescan', return_value=True), \ - patch.object(module, 'handle_sensor_addition', return_value=False): - assert module.module_post_startup() is False + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch("os.path.exists", return_value=True), \ + patch("os.remove") as mrm, \ + patch("os.system") as msys, \ + patch.object(self.module, "_sensord_operation_lock"): + assert self.module.handle_sensor_addition() is True + mrm.assert_called_once_with("/etc/sensors.d/ignore_sensors_DPU0.conf") + msys.assert_called_once_with("service sensord restart") + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch("os.path.exists", return_value=False), \ + patch("os.remove") as mrm, \ + patch("os.system") as msys, \ + patch.object(self.module, "_sensord_operation_lock") as mlock: + assert self.module.handle_sensor_addition() is True + mrm.assert_not_called() + msys.assert_not_called() + mlock.assert_not_called() + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch("os.path.exists", return_value=True), \ + patch("os.remove", side_effect=Exception("rm fail")): + assert self.module.handle_sensor_addition() is False + + # ------------------------------------------------ Pre-shutdown/Post-startup -- + @pytest.mark.parametrize( + "pci_ok,sensor_ok,expected", + [(True, True, True), (False, True, False), (True, False, False)], + ) + def test_module_pre_shutdown(self, pci_ok, sensor_ok, expected): + with patch.object(self.module, "handle_pci_removal", return_value=pci_ok), \ + patch.object(self.module, "handle_sensor_removal", return_value=sensor_ok): + assert self.module.module_pre_shutdown() is expected + + @pytest.mark.parametrize( + "pci_ok,sensor_ok,expected", + [(True, True, True), (False, True, False), (True, False, False)], + ) + def test_module_post_startup(self, pci_ok, sensor_ok, expected): + with patch.object(self.module, "handle_pci_rescan", return_value=pci_ok), \ + patch.object(self.module, "handle_sensor_addition", return_value=sensor_ok): + assert self.module.module_post_startup() is expected + + # -------------------------------------- set_admin_state_gracefully paths -- + @pytest.mark.parametrize("admin_up", [True, False]) + def test_set_admin_state_gracefully_success(self, admin_up): + db = MagicMock() + self.module.state_db = db + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=True), \ + patch.object(self.module, "clear_module_state_transition", return_value=True), \ + patch.object(self.module, "set_admin_state", return_value=True) as mset: + if admin_up: + with patch.object(self.module, "module_post_startup", return_value=True): + assert self.module.set_admin_state_gracefully(True) is True + mset.assert_called_once_with(True) + else: + with patch.object(self.module, "module_pre_shutdown", return_value=True), \ + patch.object(self.module, "_graceful_shutdown_handler", return_value=True): + assert self.module.set_admin_state_gracefully(False) is True + mset.assert_called_once_with(False) + + def test_set_admin_state_gracefully_transition_fail(self, capsys): + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=False): + assert self.module.set_admin_state_gracefully(True) is False + assert "Failed to set module state transition for admin state UP" in capsys.readouterr().err + + def test_set_admin_state_gracefully_post_startup_warn(self, capsys): + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=True), \ + patch.object(self.module, "clear_module_state_transition", return_value=True), \ + patch.object(self.module, "set_admin_state", return_value=True), \ + patch.object(self.module, "module_post_startup", return_value=False): + assert self.module.set_admin_state_gracefully(True) is True + assert "module_post_startup() failed" in capsys.readouterr().err + + def test_set_admin_state_gracefully_pre_shutdown_warn(self, capsys): + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=True), \ + patch.object(self.module, "clear_module_state_transition", return_value=True), \ + patch.object(self.module, "set_admin_state", return_value=True), \ + patch.object(self.module, "module_pre_shutdown", return_value=False), \ + patch.object(self.module, "_graceful_shutdown_handler", return_value=True): + assert self.module.set_admin_state_gracefully(False) is True + assert "module_pre_shutdown() failed" in capsys.readouterr().err + + # ----------------------------------------------------- Timeouts loading ---- + def test_load_transition_timeouts_defaults(self): + ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + with patch("os.path.exists", return_value=False): + assert self.module._load_transition_timeouts() == {"startup": 300, "shutdown": 180, "reboot": 240} + + def test_load_transition_timeouts_custom(self): + ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + data = {"dpu_startup_timeout": 600, "dpu_shutdown_timeout": 360, "dpu_reboot_timeout": 480} + mf = MockFile(json.dumps(data)) + with patch("os.path.exists", return_value=True), patch("builtins.open", return_value=mf): + assert self.module._load_transition_timeouts() == {"startup": 600, "shutdown": 360, "reboot": 480} + + def test_load_transition_timeouts_partial(self): + ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + mf = MockFile(json.dumps({"dpu_startup_timeout": 500})) + with patch("os.path.exists", return_value=True), patch("builtins.open", return_value=mf): + assert self.module._load_transition_timeouts() == {"startup": 500, "shutdown": 180, "reboot": 240} + + def test_load_transition_timeouts_error(self): + ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + with patch("os.path.exists", return_value=True), \ + patch("builtins.open", side_effect=Exception("read error")): + assert self.module._load_transition_timeouts() == {"startup": 300, "shutdown": 180, "reboot": 240} + + def test_load_transition_timeouts_cache(self): + ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + with patch("os.path.exists", return_value=False) as pexists: + t1 = self.module._load_transition_timeouts() + t2 = self.module._load_transition_timeouts() + assert t1 == t2 + pexists.assert_called_once() + ModuleBase._TRANSITION_TIMEOUTS_CACHE = None + + # -------------------------------------- Graceful shutdown wait-loop -------- + def test_graceful_shutdown_handler_external_completion(self): + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"shutdown": 180}), \ + patch.object(self.module, "get_module_state_transition", side_effect=[True, False]), \ + patch("time.sleep") as ms, \ + patch("time.time", side_effect=[1000, 1000, 1005, 1005]): + assert self.module._graceful_shutdown_handler() is True + ms.assert_called_once_with(5) + + def test_graceful_shutdown_handler_timeout(self, capsys): + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"shutdown": 10}), \ + patch.object(self.module, "get_module_state_transition", return_value=True), \ + patch("time.sleep") as ms, \ + patch("time.time", side_effect=[1000, 1005, 1005, 1010, 1010]): + assert self.module._graceful_shutdown_handler() is True + ms.assert_called_with(5) + assert "Shutdown timeout reached for module: DPU0" in capsys.readouterr().err + + def test_graceful_shutdown_handler_immediate_past_end(self): + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"shutdown": 10}), \ + patch("time.sleep"), \ + patch("time.time", side_effect=[1000, 1020, 1020]): + assert self.module._graceful_shutdown_handler() is False + + # -------------------------------- set/get/clear transition flags ----------- + def _key(self, mod="DPU0"): + return f"CHASSIS_MODULE_TABLE|{mod}" + + def test_set_module_state_transition_happy(self): + db = MagicMock() + self.module.state_db = db + db.hget.return_value = None + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"), \ + patch("time.time", return_value=1000): + assert self.module.set_module_state_transition("dpu0", "startup") is True + db.hset.assert_has_calls([ + call(self._key("DPU0"), "state_transition_in_progress", "True"), + call(self._key("DPU0"), "transition_type", "startup"), + call(self._key("DPU0"), "transition_start_time", "1000"), + ]) + + def test_set_module_state_transition_within_timeout(self, capsys): + db = MagicMock() + self.module.state_db = db + db.hget.side_effect = ["True", "950"] + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"startup": 300}), \ + patch("time.time", return_value=1000): + assert self.module.set_module_state_transition("dpu0", "startup") is False + assert "Transition already in progress" in capsys.readouterr().err + db.hset.assert_not_called() + + @pytest.mark.parametrize("elapsed,timeout,expected", [(400, 300, True), (150, 300, False)]) + def test_set_module_state_transition_timeout_behavior(self, elapsed, timeout, expected): + db = MagicMock() + self.module.state_db = db + db.hget.side_effect = ["True", str(1000 - elapsed)] + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"startup": timeout}), \ + patch("time.time", return_value=1000): + assert self.module.set_module_state_transition("dpu0", "startup") is expected + + def test_set_module_state_transition_input_validation(self, capsys): + db = MagicMock() + self.module.state_db = db + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module.set_module_state_transition("dpu0", "invalid") is False + assert "Invalid transition type: invalid" in capsys.readouterr().err + + def test_set_module_state_transition_missing_start_time(self, capsys): + db = MagicMock() + self.module.state_db = db + db.hget.side_effect = ["True", None] + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module.set_module_state_transition("dpu0", "startup") is False + assert "Missing start time" in capsys.readouterr().err + + def test_set_module_state_transition_db_errors(self, capsys): + db = MagicMock() + self.module.state_db = db + + db.hget.side_effect = Exception("DB Error") + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module.set_module_state_transition("dpu0", "startup") is False + assert "Error setting transition flag for module DPU0: DB Error" in capsys.readouterr().err + + db.hget.side_effect = None + db.hget.return_value = None + db.hset.side_effect = Exception("DB Error") + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"), \ + patch("time.time", return_value=1000): + assert self.module.set_module_state_transition("dpu0", "startup") is False + + @pytest.mark.parametrize("tt", ["startup", "shutdown", "reboot"]) + def test_set_module_state_transition_types(self, tt): + db = MagicMock() + self.module.state_db = db + db.hget.return_value = None + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"), \ + patch("time.time", return_value=1000): + assert self.module.set_module_state_transition("dpu0", tt) is True + db.hset.assert_any_call(self._key("DPU0"), "transition_type", tt) + + # ---------------------------------------------------------- clear / get ---- + def test_clear_module_state_transition(self): + db = MagicMock() + self.module.state_db = db + with patch.object(self.module, "_transition_operation_lock"), \ + patch.object(self.module, "get_name", return_value="DPU0"): + assert self.module.clear_module_state_transition("dpu0") is True + db.hdel.assert_has_calls([ + call(self._key("DPU0"), "state_transition_in_progress"), + call(self._key("DPU0"), "transition_type"), + call(self._key("DPU0"), "transition_start_time"), + ]) + + def test_clear_module_state_transition_db_error(self, capsys): + db = MagicMock() + self.module.state_db = db + db.hdel.side_effect = Exception("DB Error") + with patch.object(self.module, "_transition_operation_lock"), \ + patch.object(self.module, "get_name", return_value="DPU0"): + assert self.module.clear_module_state_transition("dpu0") is False + assert "Error clearing transition flag for module DPU0: DB Error" in capsys.readouterr().err + + @pytest.mark.parametrize("mod", ["DPU0", "LINE-CARD1", "SUPERVISOR0", "FABRIC-CARD0"]) + def test_clear_module_state_transition_various_modules(self, mod): + db = MagicMock() + self.module.state_db = db + with patch.object(self.module, "_transition_operation_lock"), \ + patch.object(self.module, "get_name", return_value="DPU0"): + assert self.module.clear_module_state_transition(mod.lower()) is True + db.hdel.assert_any_call(self._key(mod), "state_transition_in_progress") + + @pytest.mark.parametrize("ret,expected", [("True", True), (None, False), ("False", False), ("weird", False)]) + def test_get_module_state_transition(self, ret, expected): + db = MagicMock() + self.module.state_db = db + db.hget.return_value = ret + assert self.module.get_module_state_transition("dpu0") is expected + db.hget.assert_called_with(self._key("DPU0"), "state_transition_in_progress") + + def test_get_module_state_transition_db_error(self, capsys): + db = MagicMock() + self.module.state_db = db + db.hget.side_effect = Exception("DB Error") + assert self.module.get_module_state_transition("dpu0") is False + + @pytest.mark.parametrize("mod", ["DPU0", "LINE-CARD1", "SUPERVISOR0", "FABRIC-CARD0"]) + def test_get_module_state_transition_various_modules(self, mod): + db = MagicMock() + self.module.state_db = db + db.hget.return_value = "True" + assert self.module.get_module_state_transition(mod.lower()) is True + db.hget.assert_called_with(self._key(mod), "state_transition_in_progress") + + # ---------------------------------- Edge timeout semantics coverage -------- + @pytest.mark.parametrize( + "timeouts,hget_vals,now,expected", + [ + ({"startup": 0, "shutdown": 0, "reboot": 0}, ["True", str(int(time.time()))], time.time() + 1, True), + ({"startup": 999999999, "shutdown": 999999999, "reboot": 999999999}, ["True", "1"], 1_000_000, False), + ({"startup": -1, "shutdown": -1, "reboot": -1}, ["True", str(int(time.time()))], time.time() + 1, True), + ], + ) + def test_transition_timeout_edge_cases(self, timeouts, hget_vals, now, expected): + db = MagicMock() + self.module.state_db = db + db.hget.side_effect = hget_vals + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"), \ + patch.object(self.module, "_load_transition_timeouts", return_value=timeouts), \ + patch("time.time", return_value=now): + assert self.module.set_module_state_transition("dpu0", "startup") is expected From 1aac4809589cc1e158fb70139f35964338ecd0f8 Mon Sep 17 00:00:00 2001 From: Vasundhara Volam Date: Sat, 1 Nov 2025 00:17:18 +0000 Subject: [PATCH 2/8] Initialize state_db to None --- sonic_platform_base/module_base.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sonic_platform_base/module_base.py b/sonic_platform_base/module_base.py index 187323ae7..3ef6eb14b 100644 --- a/sonic_platform_base/module_base.py +++ b/sonic_platform_base/module_base.py @@ -87,7 +87,7 @@ def __init__(self): self._thermal_list = [] self._voltage_sensor_list = [] self._current_sensor_list = [] - self.state_db = self.initialize_state_db() + self.state_db = None self.pci_bus_info = None # List of SfpBase-derived objects representing all sfps @@ -97,6 +97,8 @@ def __init__(self): # List of ASIC-derived objects representing all ASICs # visibile in PCI domain on the module self._asic_list = [] + + self.state_db = self.initialize_state_db() def initialize_state_db(self): """ From b94cc2eb8909d8589d799a0ce9e05ca24af509ea Mon Sep 17 00:00:00 2001 From: Vasundhara Volam Date: Sat, 1 Nov 2025 00:36:27 +0000 Subject: [PATCH 3/8] Fix extra spaces --- sonic_platform_base/module_base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/sonic_platform_base/module_base.py b/sonic_platform_base/module_base.py index 3ef6eb14b..ed15509db 100644 --- a/sonic_platform_base/module_base.py +++ b/sonic_platform_base/module_base.py @@ -541,7 +541,7 @@ def _graceful_shutdown_handler(self): # ############################################################ def set_module_state_transition(self, module_name, transition_type): """ - Sets the module state transition flag 'state_transition_in_progress ' in the CHASSIS_MODULE_TABLE. + Sets the module state transition flag 'state_transition_in_progress' in the CHASSIS_MODULE_TABLE. Args: db: The database connector. @@ -597,7 +597,7 @@ def set_module_state_transition(self, module_name, transition_type): def clear_module_state_transition(self, module_name): """ - Clears the module state transition flag 'state_transition_in_progress ' in the CHASSIS_MODULE_TABLE. + Clears the module state transition flag 'state_transition_in_progress' in the CHASSIS_MODULE_TABLE. Args: db: The database connector. @@ -620,7 +620,7 @@ def clear_module_state_transition(self, module_name): def get_module_state_transition(self, module_name): """ - Retrieves the module state transition flag 'state_transition_in_progress ' from the CHASSIS_MODULE_TABLE. + Retrieves the module state transition flag 'state_transition_in_progress' from the CHASSIS_MODULE_TABLE. Args: db: The database connector. From d67a33fc2084d66cfd0de807217582bc935d692f Mon Sep 17 00:00:00 2001 From: Vasundhara Volam Date: Sat, 1 Nov 2025 05:45:36 +0000 Subject: [PATCH 4/8] Increase coverage for unit tests --- tests/module_base_test.py | 69 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) diff --git a/tests/module_base_test.py b/tests/module_base_test.py index c78bde477..3855b77d0 100644 --- a/tests/module_base_test.py +++ b/tests/module_base_test.py @@ -248,6 +248,75 @@ def test_set_admin_state_gracefully_pre_shutdown_warn(self, capsys): assert self.module.set_admin_state_gracefully(False) is True assert "module_pre_shutdown() failed" in capsys.readouterr().err + def test_set_admin_state_gracefully_clear_transition_fail_up(self, capsys): + """Test clear_module_state_transition failure for admin UP path (lines 442-443)""" + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=True), \ + patch.object(self.module, "clear_module_state_transition", return_value=False), \ + patch.object(self.module, "set_admin_state", return_value=True), \ + patch.object(self.module, "module_post_startup", return_value=True): + assert self.module.set_admin_state_gracefully(True) is True + assert "Failed to clear module state transition for admin state UP" in capsys.readouterr().err + + def test_set_admin_state_gracefully_clear_transition_fail_down(self, capsys): + """Test clear_module_state_transition failure for admin DOWN path (lines 463-464)""" + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=True), \ + patch.object(self.module, "clear_module_state_transition", return_value=False), \ + patch.object(self.module, "set_admin_state", return_value=True), \ + patch.object(self.module, "module_pre_shutdown", return_value=True), \ + patch.object(self.module, "_graceful_shutdown_handler", return_value=True): + assert self.module.set_admin_state_gracefully(False) is True + assert "Failed to clear module state transition for admin state DOWN" in capsys.readouterr().err + + def test_set_admin_state_gracefully_set_transition_fail_down(self, capsys): + """Test set_module_state_transition failure for admin DOWN path (lines 448-450)""" + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=False): + assert self.module.set_admin_state_gracefully(False) is False + assert "Failed to set module state transition for admin state DOWN" in capsys.readouterr().err + + def test_set_admin_state_gracefully_graceful_shutdown_fail(self, capsys): + """Test graceful shutdown handler failure/timeout (lines 456-458)""" + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=True), \ + patch.object(self.module, "clear_module_state_transition", return_value=True), \ + patch.object(self.module, "set_admin_state", return_value=True), \ + patch.object(self.module, "module_pre_shutdown", return_value=True), \ + patch.object(self.module, "_graceful_shutdown_handler", return_value=False): + assert self.module.set_admin_state_gracefully(False) is True + assert "Graceful shutdown handler failed or timed out for module: DPU0" in capsys.readouterr().err + + def test_set_admin_state_gracefully_all_failures_up_path(self, capsys): + """Test multiple failure scenarios in the UP path for maximum coverage""" + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=True), \ + patch.object(self.module, "clear_module_state_transition", return_value=False), \ + patch.object(self.module, "set_admin_state", return_value=True), \ + patch.object(self.module, "module_post_startup", return_value=False): + result = self.module.set_admin_state_gracefully(True) + assert result is True # Method continues despite failures + + captured = capsys.readouterr().err + assert "module_post_startup() failed" in captured + assert "Failed to clear module state transition for admin state UP" in captured + + def test_set_admin_state_gracefully_all_failures_down_path(self, capsys): + """Test multiple failure scenarios in the DOWN path for maximum coverage""" + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "set_module_state_transition", return_value=True), \ + patch.object(self.module, "clear_module_state_transition", return_value=False), \ + patch.object(self.module, "set_admin_state", return_value=True), \ + patch.object(self.module, "module_pre_shutdown", return_value=False), \ + patch.object(self.module, "_graceful_shutdown_handler", return_value=False): + result = self.module.set_admin_state_gracefully(False) + assert result is True # Method continues despite failures + + captured = capsys.readouterr().err + assert "module_pre_shutdown() failed" in captured + assert "Graceful shutdown handler failed or timed out for module: DPU0" in captured + assert "Failed to clear module state transition for admin state DOWN" in captured + # ----------------------------------------------------- Timeouts loading ---- def test_load_transition_timeouts_defaults(self): ModuleBase._TRANSITION_TIMEOUTS_CACHE = None From 4a1bf8834c5858fc4d66a272afbade726a8d41a3 Mon Sep 17 00:00:00 2001 From: Vasundhara Volam Date: Wed, 5 Nov 2025 22:21:00 +0000 Subject: [PATCH 5/8] Introduce gnoi_halt_in_progress field for tracking gnoi shutdown progress --- sonic_platform_base/module_base.py | 89 ++++++++---- tests/module_base_test.py | 214 ++++++++++++++++++++++++++--- 2 files changed, 257 insertions(+), 46 deletions(-) diff --git a/sonic_platform_base/module_base.py b/sonic_platform_base/module_base.py index 27ecdf580..9a1717bd4 100644 --- a/sonic_platform_base/module_base.py +++ b/sonic_platform_base/module_base.py @@ -468,12 +468,13 @@ def set_admin_state_gracefully(self, up): return admin_status ############################################## - # Smartswitch module helpers + # Smartswitch module helpers (Referenced only in module_base.py) ############################################## _TRANSITION_TIMEOUT_DEFAULTS = { - "startup": 300, # 5 mins - "shutdown": 180, # 3 mins - "reboot": 240, # 4 mins + "startup": 300, # 5 mins + "shutdown": 180, # 3 mins + "reboot": 240, # 4 mins + "halt_services": 60 # 1 min } _TRANSITION_TIMEOUTS_CACHE = None @@ -487,6 +488,7 @@ def _load_transition_timeouts(self) -> dict: - dpu_startup_timeout - dpu_shutdown_timeout - dpu_reboot_timeout + - dpu_halt_services_timeout Returns: dict: A dictionary with transition types as keys and their corresponding timeouts in seconds as values. @@ -504,12 +506,48 @@ def _load_transition_timeouts(self) -> dict: timeouts["startup"] = int(platform_data.get("dpu_startup_timeout", timeouts["startup"])) timeouts["shutdown"] = int(platform_data.get("dpu_shutdown_timeout", timeouts["shutdown"])) timeouts["reboot"] = int(platform_data.get("dpu_reboot_timeout", timeouts["reboot"])) + timeouts["halt_services"] = int(platform_data.get("dpu_halt_services_timeout", + timeouts["halt_services"])) except Exception as e: sys.stderr.write("Error loading transition timeouts from {}: {}\n".format(platform_json_path, str(e))) ModuleBase._TRANSITION_TIMEOUTS_CACHE = timeouts return timeouts + def _get_module_gnoi_halt_in_progress(self): + """ + Checks if the GNOI halt operation is in progress for the module. + + Returns: + bool: True if the GNOI halt operation is in progress, False otherwise. + """ + module_name = self.get_name() + module_key = "CHASSIS_MODULE_TABLE|" + module_name + + with self._transition_operation_lock(): + try: + gnoi_halt_flag = self.state_db.hget(module_key, "gnoi_halt_in_progress") + return gnoi_halt_flag == "True" + except Exception as e: + return False + + def _clear_module_gnoi_halt_in_progress(self): + """ + Clears the GNOI halt operation flag for the module. + + Returns: + bool: True if the flag is successfully cleared, False otherwise. + """ + module_name = self.get_name() + module_key = "CHASSIS_MODULE_TABLE|" + module_name + + with self._transition_operation_lock(): + try: + self.state_db.hdel(module_key, "gnoi_halt_in_progress") + return True + except Exception as e: + return False + def _graceful_shutdown_handler(self): """ Initiates a graceful shutdown of the module invoked with state transition in progress. @@ -519,20 +557,20 @@ def _graceful_shutdown_handler(self): """ module_name = self.get_name() - shutdown_timeout = self._load_transition_timeouts().get("shutdown", 180) - end_time = time.time() + shutdown_timeout + halt_timeout = self._load_transition_timeouts().get("halt_services", 60) + end_time = time.time() + halt_timeout interval = 5 # seconds while time.time() <= end_time: - # (a) External completion: transition flag cleared by external process - if not self.get_module_state_transition(module_name): + # (a) External completion: gnoi_halt_in_progress flag cleared by external process + if not self._get_module_gnoi_halt_in_progress(): return True time.sleep(interval) - # (b) Timeout completion: proceed with shutdown after timeout + # (b) Timeout completion: proceed with shutdown after halt_services timeout if time.time() >= end_time: - self.clear_module_state_transition(module_name) + self._clear_module_gnoi_halt_in_progress() sys.stderr.write("Shutdown timeout reached for module: {}. Proceeding with shutdown.\n".format(module_name)) return True @@ -543,10 +581,9 @@ def _graceful_shutdown_handler(self): # ############################################################ def set_module_state_transition(self, module_name, transition_type): """ - Sets the module state transition flag 'state_transition_in_progress' in the CHASSIS_MODULE_TABLE. + Sets the module state transition flag 'transition_in_progress' and corresponding fields in the CHASSIS_MODULE_TABLE. Args: - db: The database connector. module_name: The name of the module. transition_type: The type of transition (e.g., "startup", "shutdown", "reset"). Returns: @@ -566,11 +603,14 @@ def set_module_state_transition(self, module_name, transition_type): with self._transition_operation_lock(): try: - current_flag = db.hget(module_key, "state_transition_in_progress") + current_flag = db.hget(module_key, "transition_in_progress") if current_flag is None: # Flag not set, set it now - db.hset(module_key, "state_transition_in_progress", "True") + db.hset(module_key, "transition_in_progress", "True") db.hset(module_key, "transition_type", transition_type) + # If transition_type is 'shutdown', set the gnoi_halt_in_progress flag + if transition_type == "shutdown": + db.hset(module_key, "gnoi_halt_in_progress", "True") db.hset(module_key, "transition_start_time", str(int(time.time()))) return True else: @@ -585,7 +625,7 @@ def set_module_state_transition(self, module_name, transition_type): timeout = self._load_transition_timeouts().get(transition_type, 0) if current_time - start_time > timeout: # Timeout occurred, reset the flag - db.hset(module_key, "state_transition_in_progress", "True") + db.hset(module_key, "transition_in_progress", "True") db.hset(module_key, "transition_type", transition_type) db.hset(module_key, "transition_start_time", str(current_time)) return True @@ -599,10 +639,9 @@ def set_module_state_transition(self, module_name, transition_type): def clear_module_state_transition(self, module_name): """ - Clears the module state transition flag 'state_transition_in_progress' in the CHASSIS_MODULE_TABLE. + Clears the module state transition flag 'transition_in_progress' and corresponding fields in the CHASSIS_MODULE_TABLE. Args: - db: The database connector. module_name: The name of the module. Returns: bool: Returns True if the flag is successfully cleared, False otherwise. @@ -612,7 +651,7 @@ def clear_module_state_transition(self, module_name): with self._transition_operation_lock(): try: - self.state_db.hdel(module_key, "state_transition_in_progress") + self.state_db.hdel(module_key, "transition_in_progress") self.state_db.hdel(module_key, "transition_type") self.state_db.hdel(module_key, "transition_start_time") return True @@ -622,10 +661,9 @@ def clear_module_state_transition(self, module_name): def get_module_state_transition(self, module_name): """ - Retrieves the module state transition flag 'state_transition_in_progress' from the CHASSIS_MODULE_TABLE. + Retrieves the module state transition flag 'transition_in_progress' from the CHASSIS_MODULE_TABLE. Args: - db: The database connector. module_name: The name of the module. Returns: bool: Returns True if the flag is set, False otherwise. @@ -633,11 +671,12 @@ def get_module_state_transition(self, module_name): module_name = module_name.upper() module_key = "CHASSIS_MODULE_TABLE|" + module_name - try: - current_flag = self.state_db.hget(module_key, "state_transition_in_progress") - return current_flag == "True" - except Exception as e: - return False + with self._transition_operation_lock(): + try: + current_flag = self.state_db.hget(module_key, "transition_in_progress") + return current_flag == "True" + except Exception as e: + return False ############################################## # Component methods diff --git a/tests/module_base_test.py b/tests/module_base_test.py index 60b3863a3..64ef467ef 100644 --- a/tests/module_base_test.py +++ b/tests/module_base_test.py @@ -436,26 +436,51 @@ def test_set_admin_state_gracefully_all_failures_down_path(self, capsys): def test_load_transition_timeouts_defaults(self): ModuleBase._TRANSITION_TIMEOUTS_CACHE = None with patch("os.path.exists", return_value=False): - assert self.module._load_transition_timeouts() == {"startup": 300, "shutdown": 180, "reboot": 240} + assert self.module._load_transition_timeouts() == { + "startup": 300, + "shutdown": 180, + "reboot": 240, + "halt_services": 60 + } def test_load_transition_timeouts_custom(self): ModuleBase._TRANSITION_TIMEOUTS_CACHE = None - data = {"dpu_startup_timeout": 600, "dpu_shutdown_timeout": 360, "dpu_reboot_timeout": 480} + data = { + "dpu_startup_timeout": 600, + "dpu_shutdown_timeout": 360, + "dpu_reboot_timeout": 480, + "dpu_halt_services_timeout": 120 + } mf = MockFile(json.dumps(data)) with patch("os.path.exists", return_value=True), patch("builtins.open", return_value=mf): - assert self.module._load_transition_timeouts() == {"startup": 600, "shutdown": 360, "reboot": 480} + assert self.module._load_transition_timeouts() == { + "startup": 600, + "shutdown": 360, + "reboot": 480, + "halt_services": 120 + } def test_load_transition_timeouts_partial(self): ModuleBase._TRANSITION_TIMEOUTS_CACHE = None mf = MockFile(json.dumps({"dpu_startup_timeout": 500})) with patch("os.path.exists", return_value=True), patch("builtins.open", return_value=mf): - assert self.module._load_transition_timeouts() == {"startup": 500, "shutdown": 180, "reboot": 240} + assert self.module._load_transition_timeouts() == { + "startup": 500, + "shutdown": 180, + "reboot": 240, + "halt_services": 60 + } def test_load_transition_timeouts_error(self): ModuleBase._TRANSITION_TIMEOUTS_CACHE = None with patch("os.path.exists", return_value=True), \ patch("builtins.open", side_effect=Exception("read error")): - assert self.module._load_transition_timeouts() == {"startup": 300, "shutdown": 180, "reboot": 240} + assert self.module._load_transition_timeouts() == { + "startup": 300, + "shutdown": 180, + "reboot": 240, + "halt_services": 60 + } def test_load_transition_timeouts_cache(self): ModuleBase._TRANSITION_TIMEOUTS_CACHE = None @@ -468,30 +493,174 @@ def test_load_transition_timeouts_cache(self): # -------------------------------------- Graceful shutdown wait-loop -------- def test_graceful_shutdown_handler_external_completion(self): + """Test graceful shutdown when external process clears gnoi_halt_in_progress flag""" + db = MagicMock() + self.module.state_db = db + # First call: flag is set, second call: flag is cleared + db.hget.side_effect = ["True", None] + with patch.object(self.module, "get_name", return_value="DPU0"), \ - patch.object(self.module, "_load_transition_timeouts", return_value={"shutdown": 180}), \ - patch.object(self.module, "get_module_state_transition", side_effect=[True, False]), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 60}), \ patch("time.sleep") as ms, \ patch("time.time", side_effect=[1000, 1000, 1005, 1005]): assert self.module._graceful_shutdown_handler() is True ms.assert_called_once_with(5) + # Verify we checked the flag twice + assert db.hget.call_count == 2 def test_graceful_shutdown_handler_timeout(self, capsys): + """Test graceful shutdown when timeout is reached""" + db = MagicMock() + self.module.state_db = db + # Flag remains set throughout + db.hget.return_value = "True" + with patch.object(self.module, "get_name", return_value="DPU0"), \ - patch.object(self.module, "_load_transition_timeouts", return_value={"shutdown": 10}), \ - patch.object(self.module, "get_module_state_transition", return_value=True), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 10}), \ patch("time.sleep") as ms, \ - patch("time.time", side_effect=[1000, 1005, 1005, 1010, 1010]): + patch("time.time", side_effect=[1000, 1000, 1005, 1005, 1010, 1010, 1015]): assert self.module._graceful_shutdown_handler() is True + # Verify sleep was called ms.assert_called_with(5) - assert "Shutdown timeout reached for module: DPU0" in capsys.readouterr().err + # Verify flag was cleared after timeout + db.hdel.assert_called_once_with("CHASSIS_MODULE_TABLE|DPU0", "gnoi_halt_in_progress") + + assert "Shutdown timeout reached for module: DPU0. Proceeding with shutdown." in capsys.readouterr().err def test_graceful_shutdown_handler_immediate_past_end(self): + """Test when current time is already past end time""" + db = MagicMock() + self.module.state_db = db + db.hget.return_value = "True" + with patch.object(self.module, "get_name", return_value="DPU0"), \ - patch.object(self.module, "_load_transition_timeouts", return_value={"shutdown": 10}), \ - patch("time.sleep"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 10}), \ + patch("time.sleep") as ms, \ patch("time.time", side_effect=[1000, 1020, 1020]): + # Loop condition fails immediately, returns False assert self.module._graceful_shutdown_handler() is False + ms.assert_not_called() + + def test_graceful_shutdown_handler_custom_timeout(self): + """Test graceful shutdown with custom halt_services timeout""" + db = MagicMock() + self.module.state_db = db + db.hget.side_effect = ["True", None] + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 120}), \ + patch("time.sleep"), \ + patch("time.time", side_effect=[1000, 1000, 1005, 1005]): + assert self.module._graceful_shutdown_handler() is True + + # ---------------------------------- GNOI halt flag operations -------------- + def test_get_module_gnoi_halt_in_progress_true(self): + """Test getting gnoi_halt_in_progress flag when it's set to True""" + db = MagicMock() + self.module.state_db = db + db.hget.return_value = "True" + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._get_module_gnoi_halt_in_progress() is True + db.hget.assert_called_once_with("CHASSIS_MODULE_TABLE|DPU0", "gnoi_halt_in_progress") + + def test_get_module_gnoi_halt_in_progress_false(self): + """Test getting gnoi_halt_in_progress flag when it's not set or False""" + db = MagicMock() + self.module.state_db = db + + for value in [None, "False", "false", "", "0"]: + db.hget.return_value = value + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._get_module_gnoi_halt_in_progress() is False + + def test_get_module_gnoi_halt_in_progress_db_error(self): + """Test getting gnoi_halt_in_progress flag when database error occurs""" + db = MagicMock() + self.module.state_db = db + db.hget.side_effect = Exception("DB Error") + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._get_module_gnoi_halt_in_progress() is False + + def test_clear_module_gnoi_halt_in_progress_success(self): + """Test clearing gnoi_halt_in_progress flag successfully""" + db = MagicMock() + self.module.state_db = db + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._clear_module_gnoi_halt_in_progress() is True + db.hdel.assert_called_once_with("CHASSIS_MODULE_TABLE|DPU0", "gnoi_halt_in_progress") + + def test_clear_module_gnoi_halt_in_progress_db_error(self): + """Test clearing gnoi_halt_in_progress flag when database error occurs""" + db = MagicMock() + self.module.state_db = db + db.hdel.side_effect = Exception("DB Error") + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._clear_module_gnoi_halt_in_progress() is False + + @pytest.mark.parametrize("module_name", ["DPU0", "DPU1", "LINE-CARD0", "SUPERVISOR0"]) + def test_get_module_gnoi_halt_in_progress_various_modules(self, module_name): + """Test getting gnoi_halt_in_progress flag for various module types""" + db = MagicMock() + self.module.state_db = db + db.hget.return_value = "True" + + with patch.object(self.module, "get_name", return_value=module_name), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._get_module_gnoi_halt_in_progress() is True + db.hget.assert_called_with(f"CHASSIS_MODULE_TABLE|{module_name}", "gnoi_halt_in_progress") + + @pytest.mark.parametrize("module_name", ["DPU0", "DPU1", "LINE-CARD0", "SUPERVISOR0"]) + def test_clear_module_gnoi_halt_in_progress_various_modules(self, module_name): + """Test clearing gnoi_halt_in_progress flag for various module types""" + db = MagicMock() + self.module.state_db = db + + with patch.object(self.module, "get_name", return_value=module_name), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._clear_module_gnoi_halt_in_progress() is True + db.hdel.assert_called_with(f"CHASSIS_MODULE_TABLE|{module_name}", "gnoi_halt_in_progress") + + def test_graceful_shutdown_handler_multiple_checks_before_clear(self): + """Test graceful shutdown checks flag multiple times before clearing on timeout""" + db = MagicMock() + self.module.state_db = db + # Flag remains set for 3 checks, then timeout + db.hget.return_value = "True" + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 15}), \ + patch("time.sleep"), \ + patch("time.time", side_effect=[1000, 1000, 1005, 1005, 1010, 1010, 1015, 1015, 1020]): + assert self.module._graceful_shutdown_handler() is True + # Should check flag at least 3 times before timeout + assert db.hget.call_count >= 3 + db.hdel.assert_called_once() + + def test_graceful_shutdown_handler_flag_cleared_mid_loop(self): + """Test graceful shutdown when flag is cleared after several iterations""" + db = MagicMock() + self.module.state_db = db + # Flag set for first 2 checks, then cleared + db.hget.side_effect = ["True", "True", None] + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 60}), \ + patch("time.sleep") as ms, \ + patch("time.time", side_effect=[1000, 1000, 1005, 1005, 1010, 1010, 1015]): + assert self.module._graceful_shutdown_handler() is True + # Should have slept twice before flag was cleared + assert ms.call_count == 2 + # Should not clear flag when external process cleared it + db.hdel.assert_not_called() # -------------------------------- set/get/clear transition flags ----------- def _key(self, mod="DPU0"): @@ -506,7 +675,7 @@ def test_set_module_state_transition_happy(self): patch("time.time", return_value=1000): assert self.module.set_module_state_transition("dpu0", "startup") is True db.hset.assert_has_calls([ - call(self._key("DPU0"), "state_transition_in_progress", "True"), + call(self._key("DPU0"), "transition_in_progress", "True"), call(self._key("DPU0"), "transition_type", "startup"), call(self._key("DPU0"), "transition_start_time", "1000"), ]) @@ -588,7 +757,7 @@ def test_clear_module_state_transition(self): patch.object(self.module, "get_name", return_value="DPU0"): assert self.module.clear_module_state_transition("dpu0") is True db.hdel.assert_has_calls([ - call(self._key("DPU0"), "state_transition_in_progress"), + call(self._key("DPU0"), "transition_in_progress"), call(self._key("DPU0"), "transition_type"), call(self._key("DPU0"), "transition_start_time"), ]) @@ -609,29 +778,32 @@ def test_clear_module_state_transition_various_modules(self, mod): with patch.object(self.module, "_transition_operation_lock"), \ patch.object(self.module, "get_name", return_value="DPU0"): assert self.module.clear_module_state_transition(mod.lower()) is True - db.hdel.assert_any_call(self._key(mod), "state_transition_in_progress") + db.hdel.assert_any_call(self._key(mod), "transition_in_progress") @pytest.mark.parametrize("ret,expected", [("True", True), (None, False), ("False", False), ("weird", False)]) def test_get_module_state_transition(self, ret, expected): db = MagicMock() self.module.state_db = db db.hget.return_value = ret - assert self.module.get_module_state_transition("dpu0") is expected - db.hget.assert_called_with(self._key("DPU0"), "state_transition_in_progress") + with patch.object(self.module, "get_name", return_value="DPU0"): + assert self.module.get_module_state_transition("dpu0") is expected + db.hget.assert_called_with(self._key("DPU0"), "transition_in_progress") def test_get_module_state_transition_db_error(self, capsys): db = MagicMock() self.module.state_db = db db.hget.side_effect = Exception("DB Error") - assert self.module.get_module_state_transition("dpu0") is False + with patch.object(self.module, "get_name", return_value="DPU0"): + assert self.module.get_module_state_transition("dpu0") is False @pytest.mark.parametrize("mod", ["DPU0", "LINE-CARD1", "SUPERVISOR0", "FABRIC-CARD0"]) def test_get_module_state_transition_various_modules(self, mod): db = MagicMock() self.module.state_db = db db.hget.return_value = "True" - assert self.module.get_module_state_transition(mod.lower()) is True - db.hget.assert_called_with(self._key(mod), "state_transition_in_progress") + with patch.object(self.module, "get_name", return_value=mod): + assert self.module.get_module_state_transition(mod.lower()) is True + db.hget.assert_called_with(self._key(mod), "transition_in_progress") # ---------------------------------- Edge timeout semantics coverage -------- @pytest.mark.parametrize( From d792479f188152b1282fb62e93013f7bc7757e7e Mon Sep 17 00:00:00 2001 From: Vasundhara Volam Date: Thu, 6 Nov 2025 21:51:14 +0000 Subject: [PATCH 6/8] Set 'gnoi_halt_in_progress' flag after completing pre-shutdown --- sonic_platform_base/module_base.py | 27 ++++++++++++++++--- tests/module_base_test.py | 43 ++++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/sonic_platform_base/module_base.py b/sonic_platform_base/module_base.py index 9a1717bd4..ab03fcd03 100644 --- a/sonic_platform_base/module_base.py +++ b/sonic_platform_base/module_base.py @@ -531,9 +531,26 @@ def _get_module_gnoi_halt_in_progress(self): except Exception as e: return False + def _set_module_gnoi_halt_in_progress(self): + """ + Sets the GNOI halt operation in progress flag for the module. + Once this flag is set, gnoi_shutdown daemon starts shutting down the services gracefully. + Returns: + bool: True if the flag is successfully set, False otherwise. + """ + module_name = self.get_name() + module_key = "CHASSIS_MODULE_TABLE|" + module_name + + with self._transition_operation_lock(): + try: + self.state_db.hset(module_key, "gnoi_halt_in_progress", "True") + return True + except Exception as e: + return False + def _clear_module_gnoi_halt_in_progress(self): """ - Clears the GNOI halt operation flag for the module. + Clears the GNOI halt operation in progress flag for the module. Returns: bool: True if the flag is successfully cleared, False otherwise. @@ -561,6 +578,11 @@ def _graceful_shutdown_handler(self): end_time = time.time() + halt_timeout interval = 5 # seconds + # Set the gnoi_halt_in_progress flag to notify gnoi_shutdown daemon + if not self._set_module_gnoi_halt_in_progress(): + sys.stderr.write("Failed to set gnoi_halt_in_progress flag for module: {}\n".format(module_name)) + return False + while time.time() <= end_time: # (a) External completion: gnoi_halt_in_progress flag cleared by external process if not self._get_module_gnoi_halt_in_progress(): @@ -608,9 +630,6 @@ def set_module_state_transition(self, module_name, transition_type): # Flag not set, set it now db.hset(module_key, "transition_in_progress", "True") db.hset(module_key, "transition_type", transition_type) - # If transition_type is 'shutdown', set the gnoi_halt_in_progress flag - if transition_type == "shutdown": - db.hset(module_key, "gnoi_halt_in_progress", "True") db.hset(module_key, "transition_start_time", str(int(time.time()))) return True else: diff --git a/tests/module_base_test.py b/tests/module_base_test.py index 64ef467ef..38c10e347 100644 --- a/tests/module_base_test.py +++ b/tests/module_base_test.py @@ -629,6 +629,49 @@ def test_clear_module_gnoi_halt_in_progress_various_modules(self, module_name): assert self.module._clear_module_gnoi_halt_in_progress() is True db.hdel.assert_called_with(f"CHASSIS_MODULE_TABLE|{module_name}", "gnoi_halt_in_progress") + def test_set_module_gnoi_halt_in_progress_success(self): + """Test setting gnoi_halt_in_progress flag successfully""" + db = MagicMock() + self.module.state_db = db + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._set_module_gnoi_halt_in_progress() is True + db.hset.assert_called_once_with("CHASSIS_MODULE_TABLE|DPU0", "gnoi_halt_in_progress", "True") + + def test_set_module_gnoi_halt_in_progress_db_error(self): + """Test setting gnoi_halt_in_progress flag when database error occurs""" + db = MagicMock() + self.module.state_db = db + db.hset.side_effect = Exception("DB Error") + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._set_module_gnoi_halt_in_progress() is False + + @pytest.mark.parametrize("module_name", ["DPU0", "DPU1", "LINE-CARD0", "SUPERVISOR0"]) + def test_set_module_gnoi_halt_in_progress_various_modules(self, module_name): + """Test setting gnoi_halt_in_progress flag for various module types""" + db = MagicMock() + self.module.state_db = db + + with patch.object(self.module, "get_name", return_value=module_name), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._set_module_gnoi_halt_in_progress() is True + db.hset.assert_called_with(f"CHASSIS_MODULE_TABLE|{module_name}", "gnoi_halt_in_progress", "True") + + def test_set_module_gnoi_halt_in_progress_uses_lock(self): + """Test that _set_module_gnoi_halt_in_progress uses transition lock""" + db = MagicMock() + self.module.state_db = db + + mock_lock = MagicMock() + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock", return_value=mock_lock): + self.module._set_module_gnoi_halt_in_progress() + mock_lock.__enter__.assert_called_once() + mock_lock.__exit__.assert_called_once() + def test_graceful_shutdown_handler_multiple_checks_before_clear(self): """Test graceful shutdown checks flag multiple times before clearing on timeout""" db = MagicMock() From 42ba94897042a924d243e54d6df443230f6b7a78 Mon Sep 17 00:00:00 2001 From: Vasundhara Volam Date: Sat, 8 Nov 2025 16:14:02 +0000 Subject: [PATCH 7/8] Change polling interval to 0.5 secs in _graceful_shutdown_handler --- sonic_platform_base/module_base.py | 6 ++--- tests/module_base_test.py | 39 +++++++++++++++++++++--------- 2 files changed, 31 insertions(+), 14 deletions(-) diff --git a/sonic_platform_base/module_base.py b/sonic_platform_base/module_base.py index ab03fcd03..759b38bc3 100644 --- a/sonic_platform_base/module_base.py +++ b/sonic_platform_base/module_base.py @@ -548,7 +548,7 @@ def _set_module_gnoi_halt_in_progress(self): except Exception as e: return False - def _clear_module_gnoi_halt_in_progress(self): + def clear_module_gnoi_halt_in_progress(self): """ Clears the GNOI halt operation in progress flag for the module. @@ -576,7 +576,7 @@ def _graceful_shutdown_handler(self): halt_timeout = self._load_transition_timeouts().get("halt_services", 60) end_time = time.time() + halt_timeout - interval = 5 # seconds + interval = 0.5 # seconds # Set the gnoi_halt_in_progress flag to notify gnoi_shutdown daemon if not self._set_module_gnoi_halt_in_progress(): @@ -592,7 +592,7 @@ def _graceful_shutdown_handler(self): # (b) Timeout completion: proceed with shutdown after halt_services timeout if time.time() >= end_time: - self._clear_module_gnoi_halt_in_progress() + self.clear_module_gnoi_halt_in_progress() sys.stderr.write("Shutdown timeout reached for module: {}. Proceeding with shutdown.\n".format(module_name)) return True diff --git a/tests/module_base_test.py b/tests/module_base_test.py index 38c10e347..f63a4a36d 100644 --- a/tests/module_base_test.py +++ b/tests/module_base_test.py @@ -496,7 +496,8 @@ def test_graceful_shutdown_handler_external_completion(self): """Test graceful shutdown when external process clears gnoi_halt_in_progress flag""" db = MagicMock() self.module.state_db = db - # First call: flag is set, second call: flag is cleared + # First call to _set returns True, then _get returns True once, then False (external clear) + db.hset.return_value = True db.hget.side_effect = ["True", None] with patch.object(self.module, "get_name", return_value="DPU0"), \ @@ -504,24 +505,26 @@ def test_graceful_shutdown_handler_external_completion(self): patch("time.sleep") as ms, \ patch("time.time", side_effect=[1000, 1000, 1005, 1005]): assert self.module._graceful_shutdown_handler() is True - ms.assert_called_once_with(5) - # Verify we checked the flag twice + ms.assert_called_once_with(0.5) + # Verify _set was called once and _get checked the flag twice + assert db.hset.call_count >= 1 assert db.hget.call_count == 2 def test_graceful_shutdown_handler_timeout(self, capsys): """Test graceful shutdown when timeout is reached""" db = MagicMock() self.module.state_db = db - # Flag remains set throughout + # _set succeeds, flag remains set throughout checks + db.hset.return_value = True db.hget.return_value = "True" with patch.object(self.module, "get_name", return_value="DPU0"), \ patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 10}), \ patch("time.sleep") as ms, \ - patch("time.time", side_effect=[1000, 1000, 1005, 1005, 1010, 1010, 1015]): + patch("time.time", side_effect=[1000, 1000, 1000.5, 1000.5, 1001, 1001, 1010, 1010, 1010]): assert self.module._graceful_shutdown_handler() is True - # Verify sleep was called - ms.assert_called_with(5) + # Verify sleep was called with interval 0.5 + ms.assert_called_with(0.5) # Verify flag was cleared after timeout db.hdel.assert_called_once_with("CHASSIS_MODULE_TABLE|DPU0", "gnoi_halt_in_progress") @@ -531,6 +534,7 @@ def test_graceful_shutdown_handler_immediate_past_end(self): """Test when current time is already past end time""" db = MagicMock() self.module.state_db = db + db.hset.return_value = True db.hget.return_value = "True" with patch.object(self.module, "get_name", return_value="DPU0"), \ @@ -545,14 +549,27 @@ def test_graceful_shutdown_handler_custom_timeout(self): """Test graceful shutdown with custom halt_services timeout""" db = MagicMock() self.module.state_db = db + db.hset.return_value = True db.hget.side_effect = ["True", None] with patch.object(self.module, "get_name", return_value="DPU0"), \ patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 120}), \ patch("time.sleep"), \ - patch("time.time", side_effect=[1000, 1000, 1005, 1005]): + patch("time.time", side_effect=[1000, 1000, 1000.5, 1000.5]): assert self.module._graceful_shutdown_handler() is True + def test_graceful_shutdown_handler_set_flag_failure(self, capsys): + """Test graceful shutdown when setting gnoi_halt_in_progress flag fails""" + db = MagicMock() + self.module.state_db = db + db.hset.side_effect = Exception("DB Error") + + with patch.object(self.module, "get_name", return_value="DPU0"), \ + patch.object(self.module, "_transition_operation_lock"): + assert self.module._graceful_shutdown_handler() is False + + assert "Failed to set gnoi_halt_in_progress flag for module: DPU0" in capsys.readouterr().err + # ---------------------------------- GNOI halt flag operations -------------- def test_get_module_gnoi_halt_in_progress_true(self): """Test getting gnoi_halt_in_progress flag when it's set to True""" @@ -593,7 +610,7 @@ def test_clear_module_gnoi_halt_in_progress_success(self): with patch.object(self.module, "get_name", return_value="DPU0"), \ patch.object(self.module, "_transition_operation_lock"): - assert self.module._clear_module_gnoi_halt_in_progress() is True + assert self.module.clear_module_gnoi_halt_in_progress() is True db.hdel.assert_called_once_with("CHASSIS_MODULE_TABLE|DPU0", "gnoi_halt_in_progress") def test_clear_module_gnoi_halt_in_progress_db_error(self): @@ -604,7 +621,7 @@ def test_clear_module_gnoi_halt_in_progress_db_error(self): with patch.object(self.module, "get_name", return_value="DPU0"), \ patch.object(self.module, "_transition_operation_lock"): - assert self.module._clear_module_gnoi_halt_in_progress() is False + assert self.module.clear_module_gnoi_halt_in_progress() is False @pytest.mark.parametrize("module_name", ["DPU0", "DPU1", "LINE-CARD0", "SUPERVISOR0"]) def test_get_module_gnoi_halt_in_progress_various_modules(self, module_name): @@ -626,7 +643,7 @@ def test_clear_module_gnoi_halt_in_progress_various_modules(self, module_name): with patch.object(self.module, "get_name", return_value=module_name), \ patch.object(self.module, "_transition_operation_lock"): - assert self.module._clear_module_gnoi_halt_in_progress() is True + assert self.module.clear_module_gnoi_halt_in_progress() is True db.hdel.assert_called_with(f"CHASSIS_MODULE_TABLE|{module_name}", "gnoi_halt_in_progress") def test_set_module_gnoi_halt_in_progress_success(self): From 821aba5b5a9bfdc0a8e2c5f65ba27d0660e8bd3f Mon Sep 17 00:00:00 2001 From: Vasundhara Volam Date: Fri, 14 Nov 2025 02:09:07 +0000 Subject: [PATCH 8/8] Fix timout conditional check in graceful_shutdown_handler --- sonic_platform_base/module_base.py | 13 +++++++------ tests/module_base_test.py | 14 ++++++++------ 2 files changed, 15 insertions(+), 12 deletions(-) diff --git a/sonic_platform_base/module_base.py b/sonic_platform_base/module_base.py index 759b38bc3..0c9f7c647 100644 --- a/sonic_platform_base/module_base.py +++ b/sonic_platform_base/module_base.py @@ -506,8 +506,9 @@ def _load_transition_timeouts(self) -> dict: timeouts["startup"] = int(platform_data.get("dpu_startup_timeout", timeouts["startup"])) timeouts["shutdown"] = int(platform_data.get("dpu_shutdown_timeout", timeouts["shutdown"])) timeouts["reboot"] = int(platform_data.get("dpu_reboot_timeout", timeouts["reboot"])) + # Add 10 seconds buffer to halt_services timeout, as this is used by GNOI daemon as well timeouts["halt_services"] = int(platform_data.get("dpu_halt_services_timeout", - timeouts["halt_services"])) + timeouts["halt_services"])) + 10 except Exception as e: sys.stderr.write("Error loading transition timeouts from {}: {}\n".format(platform_json_path, str(e))) @@ -590,11 +591,11 @@ def _graceful_shutdown_handler(self): time.sleep(interval) - # (b) Timeout completion: proceed with shutdown after halt_services timeout - if time.time() >= end_time: - self.clear_module_gnoi_halt_in_progress() - sys.stderr.write("Shutdown timeout reached for module: {}. Proceeding with shutdown.\n".format(module_name)) - return True + # (b) Timeout completion: proceed with shutdown after halt_services timeout + if time.time() >= end_time: + self.clear_module_gnoi_halt_in_progress() + sys.stderr.write("Shutdown timeout reached for module: {}. Proceeding with shutdown.\n".format(module_name)) + return True return False diff --git a/tests/module_base_test.py b/tests/module_base_test.py index f63a4a36d..4466f4946 100644 --- a/tests/module_base_test.py +++ b/tests/module_base_test.py @@ -457,7 +457,7 @@ def test_load_transition_timeouts_custom(self): "startup": 600, "shutdown": 360, "reboot": 480, - "halt_services": 120 + "halt_services": 130 } def test_load_transition_timeouts_partial(self): @@ -468,7 +468,7 @@ def test_load_transition_timeouts_partial(self): "startup": 500, "shutdown": 180, "reboot": 240, - "halt_services": 60 + "halt_services": 70 } def test_load_transition_timeouts_error(self): @@ -521,7 +521,7 @@ def test_graceful_shutdown_handler_timeout(self, capsys): with patch.object(self.module, "get_name", return_value="DPU0"), \ patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 10}), \ patch("time.sleep") as ms, \ - patch("time.time", side_effect=[1000, 1000, 1000.5, 1000.5, 1001, 1001, 1010, 1010, 1010]): + patch("time.time", side_effect=[1000, 1000, 1000.5, 1001, 1005, 1010, 1010.5, 1010.5]): assert self.module._graceful_shutdown_handler() is True # Verify sleep was called with interval 0.5 ms.assert_called_with(0.5) @@ -541,9 +541,11 @@ def test_graceful_shutdown_handler_immediate_past_end(self): patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 10}), \ patch("time.sleep") as ms, \ patch("time.time", side_effect=[1000, 1020, 1020]): - # Loop condition fails immediately, returns False - assert self.module._graceful_shutdown_handler() is False + # When already past deadline, function clears flag and returns True (timeout path) + assert self.module._graceful_shutdown_handler() is True ms.assert_not_called() + # Should clear the flag when timeout is reached + db.hdel.assert_called_once_with("CHASSIS_MODULE_TABLE|DPU0", "gnoi_halt_in_progress") def test_graceful_shutdown_handler_custom_timeout(self): """Test graceful shutdown with custom halt_services timeout""" @@ -699,7 +701,7 @@ def test_graceful_shutdown_handler_multiple_checks_before_clear(self): with patch.object(self.module, "get_name", return_value="DPU0"), \ patch.object(self.module, "_load_transition_timeouts", return_value={"halt_services": 15}), \ patch("time.sleep"), \ - patch("time.time", side_effect=[1000, 1000, 1005, 1005, 1010, 1010, 1015, 1015, 1020]): + patch("time.time", side_effect=[1000, 1000, 1005, 1010, 1015, 1015.5, 1015.5]): assert self.module._graceful_shutdown_handler() is True # Should check flag at least 3 times before timeout assert db.hget.call_count >= 3