Skip to content
265 changes: 254 additions & 11 deletions sonic_platform_base/module_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
import sys
import os
import fcntl
import time
from . import device_base
import json
import threading
import contextlib
import shutil
import subprocess
import os

# PCI state database constants
PCIE_DETACH_INFO_TABLE = "PCIE_DETACH_INFO"
Expand All @@ -30,6 +28,7 @@ class ModuleBase(device_base.DeviceBase):
DEVICE_TYPE = "module"
PCI_OPERATION_LOCK_FILE_PATH = "/var/lock/{}_pci.lock"
SENSORD_OPERATION_LOCK_FILE_PATH = "/var/lock/sensord.lock"
TRANSITION_OPERATION_LOCK_FILE_PATH = "/var/lock/{}_transition.lock"

# Possible card types for modular chassis
MODULE_TYPE_SUPERVISOR = "SUPERVISOR"
Expand Down Expand Up @@ -87,7 +86,7 @@ def __init__(self):
self._thermal_list = []
self._voltage_sensor_list = []
self._current_sensor_list = []
self.state_db_connector = None
self.state_db = None
self.pci_bus_info = None

# List of SfpBase-derived objects representing all sfps
Expand All @@ -100,7 +99,21 @@ def __init__(self):

# Flag to indicate if the module is running on the host/container
self.is_host = self._is_host()

self.state_db = self.initialize_state_db()

def initialize_state_db(self):
"""
Initializes and returns the state database connector.

Returns:
A database connector object for the state database.
"""
if self.state_db is None:
import sonic_py_common.daemon_base as daemon_base
self.state_db = daemon_base.db_connect("STATE_DB")
return self.state_db

@contextlib.contextmanager
def _file_operation_lock(self, lock_file_path):
"""Common file-based lock for operations using flock"""
Expand All @@ -125,6 +138,13 @@ def _sensord_operation_lock(self):
with self._file_operation_lock(lock_file_path):
yield

@contextlib.contextmanager
def _transition_operation_lock(self):
"""File-based lock for module state transition operations using flock"""
lock_file_path = self.TRANSITION_OPERATION_LOCK_FILE_PATH.format(self.get_name())
with self._file_operation_lock(lock_file_path):
yield

def get_base_mac(self):
"""
Retrieves the base MAC address for the module
Expand Down Expand Up @@ -347,16 +367,12 @@ def pci_entry_state_db(self, pcie_string, operation):
RuntimeError: If state database connection fails
"""
try:
# Do not use import if swsscommon is not needed
import swsscommon
PCIE_DETACH_INFO_TABLE_KEY = PCIE_DETACH_INFO_TABLE+"|"+pcie_string
if not self.state_db_connector:
self.state_db_connector = swsscommon.swsscommon.DBConnector("STATE_DB", 0)
if operation == PCIE_OPERATION_ATTACHING:
self.state_db_connector.delete(PCIE_DETACH_INFO_TABLE_KEY)
self.state_db.delete(PCIE_DETACH_INFO_TABLE_KEY)
return
self.state_db_connector.hset(PCIE_DETACH_INFO_TABLE_KEY, "bus_info", pcie_string)
self.state_db_connector.hset(PCIE_DETACH_INFO_TABLE_KEY, "dpu_state", operation)
self.state_db.hset(PCIE_DETACH_INFO_TABLE_KEY, "bus_info", pcie_string)
self.state_db.hset(PCIE_DETACH_INFO_TABLE_KEY, "dpu_state", operation)
except Exception as e:
sys.stderr.write("Failed to write pcie bus info to state database: {}\n".format(str(e)))

Expand Down Expand Up @@ -396,6 +412,233 @@ def pci_reattach(self):
"""
raise NotImplementedError

def set_admin_state_gracefully(self, up):
"""
Request to keep the module in administratively up/down state with graceful shutdown.

This function is designed for SmartSwitch platforms to ensure a graceful shutdown
of the module when transitioning to the admin-down state.

For non-SmartSwitch platforms, use the standard set_admin_state() method.

Args:
up: A boolean, True to set the admin-state to UP. False to set the
admin-state to DOWN.
Returns:
bool: True if the request has been issued successfully, False if not
"""
module_name = self.get_name()
# Set the module state to administratively up.
if up:
if not self.set_module_state_transition(module_name, "startup"):
sys.stderr.write("Failed to set module state transition for admin state UP\n")
return False

admin_status = self.set_admin_state(True)

# This is only valid on platforms which have pci_rescan sensord changes required. If it is not implemented,
# there are no actions taken during this function execution.
if not self.module_post_startup():
sys.stderr.write("module_post_startup() failed\n")

if not self.clear_module_state_transition(module_name):
sys.stderr.write("Failed to clear module state transition for admin state UP\n")

return admin_status
else:
# Initiate graceful shutdown before setting admin state to down.
if not self.set_module_state_transition(module_name, "shutdown"):
sys.stderr.write("Failed to set module state transition for admin state DOWN\n")
return False

# This is only valid on platforms which have pci_detach and sensord changes required. If it is not implemented,
# there are no actions taken during this function execution.
if not self.module_pre_shutdown():
sys.stderr.write("module_pre_shutdown() failed\n")

if not self._graceful_shutdown_handler():
sys.stderr.write("Graceful shutdown handler failed or timed out for module: {}\n".format(module_name))
# Proceeding with admin down even if graceful shutdown fails.

admin_status = self.set_admin_state(False)

if not self.clear_module_state_transition(module_name):
sys.stderr.write("Failed to clear module state transition for admin state DOWN\n")

return admin_status

##############################################
# Smartswitch module helpers
##############################################
_TRANSITION_TIMEOUT_DEFAULTS = {
"startup": 300, # 5 mins
"shutdown": 180, # 3 mins
"reboot": 240, # 4 mins
}

_TRANSITION_TIMEOUTS_CACHE = None

def _load_transition_timeouts(self) -> dict:
"""
Loads module state transition timeouts from /usr/share/sonic/platform/platform.json if present,
otherwise fall back to _TRANSITION_TIMEOUT_DEFAULTS.

Reads the following keys from the JSON file:
- dpu_startup_timeout
- dpu_shutdown_timeout
- dpu_reboot_timeout
Returns:
dict: A dictionary with transition types as keys and their corresponding timeouts
in seconds as values.
"""
if ModuleBase._TRANSITION_TIMEOUTS_CACHE is not None:
return ModuleBase._TRANSITION_TIMEOUTS_CACHE

timeouts = self._TRANSITION_TIMEOUT_DEFAULTS.copy()
platform_json_path = "/usr/share/sonic/platform/platform.json"

try:
if os.path.exists(platform_json_path):
with open(platform_json_path, 'r') as f:
platform_data = json.load(f)
timeouts["startup"] = int(platform_data.get("dpu_startup_timeout", timeouts["startup"]))
timeouts["shutdown"] = int(platform_data.get("dpu_shutdown_timeout", timeouts["shutdown"]))
timeouts["reboot"] = int(platform_data.get("dpu_reboot_timeout", timeouts["reboot"]))
except Exception as e:
sys.stderr.write("Error loading transition timeouts from {}: {}\n".format(platform_json_path, str(e)))

ModuleBase._TRANSITION_TIMEOUTS_CACHE = timeouts
return timeouts

def _graceful_shutdown_handler(self):
"""
Initiates a graceful shutdown of the module invoked with state transition in progress.

Returns:
bool: True if the shutdown process was initiated successfully, False otherwise.
"""
module_name = self.get_name()

shutdown_timeout = self._load_transition_timeouts().get("shutdown", 180)
end_time = time.time() + shutdown_timeout
interval = 5 # seconds

while time.time() <= end_time:
# (a) External completion: transition flag cleared by external process
if not self.get_module_state_transition(module_name):
return True

time.sleep(interval)

# (b) Timeout completion: proceed with shutdown after timeout
if time.time() >= end_time:
self.clear_module_state_transition(module_name)
sys.stderr.write("Shutdown timeout reached for module: {}. Proceeding with shutdown.\n".format(module_name))
return True

return False

# ############################################################
# Centralized APIs for CHASSIS_MODULE_TABLE transition flags
# ############################################################
def set_module_state_transition(self, module_name, transition_type):
"""
Sets the module state transition flag 'state_transition_in_progress' in the CHASSIS_MODULE_TABLE.

Args:
db: The database connector.
module_name: The name of the module.
transition_type: The type of transition (e.g., "startup", "shutdown", "reset").
Returns:
bool: Returns True if the flag is successfully set.

If the flag was already set but has timed out, the function resets the flag and returns True.
Returns False in all other cases.
"""
transition_type = transition_type.lower()
if transition_type not in self._TRANSITION_TIMEOUT_DEFAULTS.keys():
sys.stderr.write("Invalid transition type: {}\n".format(transition_type))
return False

module_name = module_name.upper()
module_key = "CHASSIS_MODULE_TABLE|" + module_name
db = self.state_db

with self._transition_operation_lock():
try:
current_flag = db.hget(module_key, "state_transition_in_progress")
if current_flag is None:
# Flag not set, set it now
db.hset(module_key, "state_transition_in_progress", "True")
db.hset(module_key, "transition_type", transition_type)
db.hset(module_key, "transition_start_time", str(int(time.time())))
return True
else:
# Flag already set, check for timeout
start_time_str = db.hget(module_key, "transition_start_time")
if start_time_str is None:
sys.stderr.write("Missing start time for transition flag on module: {}\n".format(module_name))
return False

start_time = int(start_time_str)
current_time = int(time.time())
timeout = self._load_transition_timeouts().get(transition_type, 0)
if current_time - start_time > timeout:
# Timeout occurred, reset the flag
db.hset(module_key, "state_transition_in_progress", "True")
db.hset(module_key, "transition_type", transition_type)
db.hset(module_key, "transition_start_time", str(current_time))
return True
else:
# Still within timeout period
sys.stderr.write("Transition already in progress for module: {}\n".format(module_name))
return False
except Exception as e:
sys.stderr.write("Error setting transition flag for module {}: {}\n".format(module_name, str(e)))
return False

def clear_module_state_transition(self, module_name):
"""
Clears the module state transition flag 'state_transition_in_progress' in the CHASSIS_MODULE_TABLE.

Args:
db: The database connector.
module_name: The name of the module.
Returns:
bool: Returns True if the flag is successfully cleared, False otherwise.
"""
module_name = module_name.upper()
module_key = "CHASSIS_MODULE_TABLE|" + module_name

with self._transition_operation_lock():
try:
self.state_db.hdel(module_key, "state_transition_in_progress")
self.state_db.hdel(module_key, "transition_type")
self.state_db.hdel(module_key, "transition_start_time")
return True
except Exception as e:
sys.stderr.write("Error clearing transition flag for module {}: {}\n".format(module_name, str(e)))
return False

def get_module_state_transition(self, module_name):
"""
Retrieves the module state transition flag 'state_transition_in_progress' from the CHASSIS_MODULE_TABLE.

Args:
db: The database connector.
module_name: The name of the module.
Returns:
bool: Returns True if the flag is set, False otherwise.
"""
module_name = module_name.upper()
module_key = "CHASSIS_MODULE_TABLE|" + module_name

try:
current_flag = self.state_db.hget(module_key, "state_transition_in_progress")
return current_flag == "True"
except Exception as e:
return False

##############################################
# Component methods
##############################################
Expand Down
Loading
Loading