diff --git a/data/debian/rules b/data/debian/rules index 47d26ccb..f32142df 100755 --- a/data/debian/rules +++ b/data/debian/rules @@ -20,5 +20,6 @@ override_dh_installsystemd: dh_installsystemd --no-start --name=procdockerstatsd dh_installsystemd --no-start --name=determine-reboot-cause dh_installsystemd --no-start --name=process-reboot-cause + dh_installsystemd --no-start --name=gnoi-shutdown dh_installsystemd $(HOST_SERVICE_OPTS) --name=sonic-hostservice diff --git a/data/debian/sonic-host-services-data.gnoi-shutdown.service b/data/debian/sonic-host-services-data.gnoi-shutdown.service new file mode 100644 index 00000000..cb50c5b1 --- /dev/null +++ b/data/debian/sonic-host-services-data.gnoi-shutdown.service @@ -0,0 +1,16 @@ +[Unit] +Description=gNOI based DPU Graceful Shutdown Daemon +Requires=database.service +Wants=network-online.target +After=network-online.target database.service + +[Service] +Type=simple +ExecStartPre=/usr/bin/python3 /usr/local/bin/check_platform.py +ExecStartPre=/bin/bash /usr/local/bin/wait-for-sonic-core.sh +ExecStart=/usr/bin/python3 /usr/local/bin/gnoi_shutdown_daemon.py +Restart=always +RestartSec=5 + +[Install] +WantedBy=multi-user.target diff --git a/scripts/check_platform.py b/scripts/check_platform.py new file mode 100644 index 00000000..4fbffefd --- /dev/null +++ b/scripts/check_platform.py @@ -0,0 +1,23 @@ +#!/usr/bin/env python3 +""" +Check if the current platform is a SmartSwitch NPU (not DPU). +Exit 0 if SmartSwitch NPU, exit 1 otherwise. 
def _get_halt_timeout() -> int:
    """Return the DPU halt timeout in seconds.

    The value is read from the ``dpu_halt_services_timeout`` key of
    ``/usr/share/sonic/device/<platform_name>/platform.json``, where
    ``<platform_name>`` is whatever the chassis object's ``get_name()``
    returns.

    Returns:
        The configured value converted to ``int``, or
        ``STATUS_POLL_TIMEOUT_SEC`` when the chassis has no usable name, the
        file does not exist, or anything in the lookup fails.

    This function never raises for lookup problems: callers use the result to
    compute polling deadlines inside worker threads, so a failure here must
    degrade to the default instead of aborting the shutdown sequence.
    """
    try:
        # Imported lazily so that an unavailable platform package is handled
        # like any other lookup failure (ImportError is caught below).
        from sonic_platform import platform

        chassis = platform.Platform().get_chassis()
        platform_name = chassis.get_name() if hasattr(chassis, 'get_name') else None
        if not platform_name:
            return STATUS_POLL_TIMEOUT_SEC

        platform_json_path = f"/usr/share/sonic/device/{platform_name}/platform.json"
        if not os.path.exists(platform_json_path):
            return STATUS_POLL_TIMEOUT_SEC

        with open(platform_json_path, 'r') as f:
            platform_data = json.load(f)

        # AttributeError (root is not a JSON object) and TypeError (value is
        # null / a list / an object) are both handled by the except clause.
        return int(platform_data.get("dpu_halt_services_timeout", STATUS_POLL_TIMEOUT_SEC))
    except (ImportError, OSError, ValueError, TypeError, KeyError,
            AttributeError, NotImplementedError) as e:
        logger.log_info(f"Could not load timeout from platform.json: {e}, using default {STATUS_POLL_TIMEOUT_SEC}s")
    return STATUS_POLL_TIMEOUT_SEC
def get_dpu_ip(config_db, dpu_name: str) -> str:
    """Retrieve the DPU midplane IP from the CONFIG_DB DHCP_SERVER_IPV4_PORT table.

    Looks up field ``ips@`` of key
    ``DHCP_SERVER_IPV4_PORT|bridge-midplane|<dpu_name lower-cased>`` through
    ``config_db.hget``.

    Args:
        config_db: Object exposing ``hget(key, field)``.
        dpu_name: Module name such as ``"DPU0"``; it is lower-cased for the key.

    Returns:
        The first non-empty address of the entry, stripped of surrounding
        whitespace, or ``None`` when the entry is missing/empty or the lookup
        fails (the failure is logged).
    """
    key = f"DHCP_SERVER_IPV4_PORT|bridge-midplane|{dpu_name.lower()}"

    try:
        ips = config_db.hget(key, "ips@")
        if not ips:
            return None

        if isinstance(ips, bytes):
            ips = ips.decode('utf-8')

        # NOTE(review): the "@" suffix looks like the CONFIG_DB convention for
        # list fields stored as one comma-separated string -- confirm. Returning
        # the raw string would yield an unusable "ip1,ip2:port" gNOI target, so
        # split it and use the first address.
        candidates = ips.split(",") if isinstance(ips, str) else ips

        if isinstance(candidates, (list, tuple)):
            for candidate in candidates:
                address = str(candidate).strip()
                if address:
                    return address
            return None

        # Unknown value type: hand it back unchanged, as before.
        return candidates

    except (AttributeError, KeyError, TypeError, UnicodeDecodeError) as e:
        logger.log_error(f"{dpu_name}: Error getting IP: {e}")

    return None
+ """ + logger.log_notice(f"{dpu_name}: Starting gNOI shutdown sequence") + + # Wait for platform PCI detach completion + if not self._wait_for_gnoi_halt_in_progress(dpu_name): + logger.log_warning(f"{dpu_name}: Timeout waiting for PCI detach, proceeding anyway") + + # Get DPU configuration + dpu_ip = None + try: + dpu_ip = get_dpu_ip(self._config_db, dpu_name) + port = get_dpu_gnmi_port(self._config_db, dpu_name) + if not dpu_ip: + logger.log_error(f"{dpu_name}: IP not found in DHCP_SERVER_IPV4_PORT table (key: bridge-midplane|{dpu_name.lower()}), cannot proceed") + self._clear_halt_flag(dpu_name) + return False + except Exception as e: + logger.log_error(f"{dpu_name}: Failed to get configuration: {e}") + self._clear_halt_flag(dpu_name) + return False + + # Send gNOI Reboot HALT command + reboot_sent = self._send_reboot_command(dpu_name, dpu_ip, port) + if not reboot_sent: + logger.log_error(f"{dpu_name}: Failed to send Reboot command") + self._clear_halt_flag(dpu_name) + return False + + # Poll for RebootStatus completion + reboot_successful = self._poll_reboot_status(dpu_name, dpu_ip, port) + + if self._clear_halt_flag(dpu_name): + logger.log_notice(f"{dpu_name}: Halting the services on DPU is successful for {dpu_name}") + + return reboot_successful + + def _wait_for_gnoi_halt_in_progress(self, dpu_name: str) -> bool: + """ + Poll for gnoi_halt_in_progress flag in STATE_DB CHASSIS_MODULE_TABLE. + This flag is set by the platform after completing PCI detach. 
+ """ + deadline = time.monotonic() + _get_halt_timeout() + + while time.monotonic() < deadline: + try: + table = swsscommon.Table(self._db, "CHASSIS_MODULE_TABLE") + (status, fvs) = table.get(dpu_name) + + if status: + entry = dict(fvs) + halt_in_progress = entry.get("gnoi_halt_in_progress", "False") + + if halt_in_progress == "True": + logger.log_notice(f"{dpu_name}: PCI detach complete, proceeding for halting services via gNOI") + return True + + except Exception as e: + logger.log_error(f"{dpu_name}: Error reading halt flag: {e}") + + time.sleep(HALT_IN_PROGRESS_POLL_INTERVAL_SEC) + + return False + + def _send_reboot_command(self, dpu_name: str, dpu_ip: str, port: str) -> bool: + """Send gNOI Reboot HALT command to the DPU.""" + reboot_cmd = [ + "docker", "exec", "gnmi", "gnoi_client", + f"-target={dpu_ip}:{port}", + "-logtostderr", "-notls", + "-module", "System", + "-rpc", "Reboot", + "-jsonin", json.dumps({"method": REBOOT_METHOD_HALT, "message": "Triggered by SmartSwitch graceful shutdown"}) + ] + rc, out, err = execute_command(reboot_cmd, timeout_sec=REBOOT_RPC_TIMEOUT_SEC, suppress_stderr=True) + if rc != 0: + logger.log_error(f"{dpu_name}: Reboot command failed") + return False + return True + + def _poll_reboot_status(self, dpu_name: str, dpu_ip: str, port: str) -> bool: + """Poll RebootStatus until completion or timeout.""" + deadline = time.monotonic() + _get_halt_timeout() + status_cmd = [ + "docker", "exec", "gnmi", "gnoi_client", + f"-target={dpu_ip}:{port}", + "-logtostderr", "-notls", + "-module", "System", + "-rpc", "RebootStatus" + ] + while time.monotonic() < deadline: + rc_s, out_s, err_s = execute_command(status_cmd, timeout_sec=STATUS_RPC_TIMEOUT_SEC) + if rc_s == 0 and out_s and ("reboot complete" in out_s.lower()): + return True + time.sleep(STATUS_POLL_INTERVAL_SEC) + logger.log_notice(f"{dpu_name}: Timeout waiting for RebootStatus completion, proceeding with halt flag clear") + return False + + def _clear_halt_flag(self, dpu_name: str) 
def main():
    """Run the gnoi-shutdown-daemon event loop; does not return normally.

    Subscribes to CONFIG_DB keyspace notifications for ``CHASSIS_MODULE|*``
    keys. For every notification the module's ``admin_status`` is re-read from
    CONFIG_DB, and when it is ``"down"`` a background thread runs
    ``GnoiRebootHandler._handle_transition`` for that module. At most one
    transition thread per module is active at any time.

    Raises:
        Exception: whatever obtaining the chassis instance raised (logged
            first, then re-raised).
    """
    # Connect for STATE_DB (for gnoi_halt_in_progress flag) and CONFIG_DB
    state_db = daemon_base.db_connect("STATE_DB")
    config_db = daemon_base.db_connect("CONFIG_DB")

    # Also connect ConfigDBConnector for pubsub support (has get_redis_client method)
    config_db_connector = swsscommon.ConfigDBConnector()
    config_db_connector.connect(wait_for_init=False)

    # Get chassis instance for accessing ModuleBase APIs
    try:
        from sonic_platform import platform
        chassis = platform.Platform().get_chassis()
        logger.log_info("Successfully obtained chassis instance")
    except Exception as e:
        logger.log_error(f"Failed to get chassis instance: {e}")
        raise

    # gNOI reboot handler
    reboot_handler = GnoiRebootHandler(state_db, config_db, chassis)

    # Track active transitions to prevent duplicate threads for the same DPU
    active_transitions = set()
    active_transitions_lock = threading.Lock()

    def handle_and_cleanup(dpu):
        """Thread body: run one transition, then release the per-DPU slot."""
        try:
            # _handle_transition reports failure through its return value, so
            # only claim success when it actually returned a truthy result.
            if reboot_handler._handle_transition(dpu, "shutdown"):
                logger.log_info(f"{dpu}: Transition thread completed successfully")
            else:
                logger.log_warning(f"{dpu}: Transition thread completed, but gNOI shutdown did not succeed")
        except Exception as e:
            logger.log_error(f"{dpu}: Transition thread failed with exception: {e}")
        finally:
            with active_transitions_lock:
                active_transitions.discard(dpu)

    # Keyspace notifications are globally enabled in docker-database
    pubsub = config_db_connector.get_redis_client(config_db_connector.db_name).pubsub()

    # Listen to keyspace notifications for CHASSIS_MODULE table keys in CONFIG_DB
    topic = f"__keyspace@{CONFIG_DB_INDEX}__:CHASSIS_MODULE|*"
    pubsub.psubscribe(topic)

    logger.log_notice("gnoi-shutdown-daemon started, monitoring CHASSIS_MODULE admin_status changes")

    while True:
        message = pubsub.get_message(timeout=1.0)
        if not message:
            continue

        msg_type = message.get("type")
        if isinstance(msg_type, bytes):
            msg_type = msg_type.decode('utf-8')
        if msg_type != "pmessage":
            continue

        channel = message.get("channel", b"")
        if isinstance(channel, bytes):
            channel = channel.decode('utf-8')

        # Extract key from channel: "__keyspace@4__:CHASSIS_MODULE|DPU0"
        notified_key = channel.split(":", 1)[-1] if ":" in channel else channel
        if not notified_key.startswith("CHASSIS_MODULE|"):
            continue

        # Extract module name; ignore keys with an empty name part
        dpu_name = notified_key.split("|", 1)[1]
        if not dpu_name:
            continue

        # Read admin_status from CONFIG_DB
        try:
            admin_status = config_db.hget(f"CHASSIS_MODULE|{dpu_name}", "admin_status")
            if not admin_status:
                continue
            if isinstance(admin_status, bytes):
                admin_status = admin_status.decode('utf-8')
        except (AttributeError, KeyError, TypeError) as e:
            logger.log_error(f"{dpu_name}: Failed to read CONFIG_DB: {e}")
            continue

        if admin_status != "down":
            continue

        # Check if already processing this DPU
        with active_transitions_lock:
            if dpu_name in active_transitions:
                continue
            active_transitions.add(dpu_name)

        logger.log_notice(f"{dpu_name}: Admin shutdown detected, initiating gNOI HALT")

        # Run in background thread so the loop keeps draining notifications
        thread = threading.Thread(
            target=handle_and_cleanup,
            args=(dpu_name,),
            name=f"gnoi-{dpu_name}",
            daemon=True
        )
        thread.start()
class TestCheckPlatform(unittest.TestCase):
    """Exit-code tests for ``check_platform.main``."""

    def _assert_exit_code(self, expected_code):
        """Run ``check_platform.main`` and check the code of the SystemExit it raises."""
        with self.assertRaises(SystemExit) as raised:
            check_platform.main()
        self.assertEqual(raised.exception.code, expected_code)

    @patch('utilities_common.chassis.is_dpu', return_value=False)
    @patch('sonic_py_common.device_info.is_smartswitch', return_value=True)
    def test_smart_switch_npu(self, mock_is_smartswitch, mock_is_dpu):
        """SmartSwitch that is not a DPU exits with 0."""
        self._assert_exit_code(0)

    @patch('utilities_common.chassis.is_dpu', return_value=True)
    @patch('sonic_py_common.device_info.is_smartswitch', return_value=True)
    def test_dpu_platform(self, mock_is_smartswitch, mock_is_dpu):
        """SmartSwitch that is a DPU exits with 1."""
        self._assert_exit_code(1)

    @patch('utilities_common.chassis.is_dpu', return_value=False)
    @patch('sonic_py_common.device_info.is_smartswitch', return_value=False)
    def test_other_platform(self, mock_is_smartswitch, mock_is_dpu):
        """A platform that is not a SmartSwitch exits with 1."""
        self._assert_exit_code(1)

    @patch('sonic_py_common.device_info.is_smartswitch', side_effect=ImportError("Test error"))
    def test_exception(self, mock_is_smartswitch):
        """An exception raised by is_smartswitch() yields exit code 1."""
        self._assert_exit_code(1)

    @patch('utilities_common.chassis.is_dpu', side_effect=AttributeError("DPU check failed"))
    @patch('sonic_py_common.device_info.is_smartswitch', return_value=True)
    def test_is_dpu_exception(self, mock_is_smartswitch, mock_is_dpu):
        """An exception raised by is_dpu() yields exit code 1."""
        self._assert_exit_code(1)

    @patch('utilities_common.chassis.is_dpu', return_value=False)
    @patch('sonic_py_common.device_info.is_smartswitch', side_effect=ImportError("Module not found"))
    def test_is_smartswitch_import_error(self, mock_is_smartswitch, mock_is_dpu):
        """An ImportError raised by is_smartswitch() yields exit code 1."""
        self._assert_exit_code(1)
{ + "type": "pmessage", + "channel": f"__keyspace@{gnoi_shutdown_daemon.CONFIG_DB_INDEX}__:CHASSIS_MODULE|DPU0", + "data": "hset", +} +mock_config_entry = { + "admin_status": "down" +} +mock_ip_entry = {"ips": ["10.0.0.1"]} +mock_port_entry = {"gnmi_port": "12345"} + + +class TestGnoiShutdownDaemon(unittest.TestCase): + + def setUp(self): + # Ensure a clean state for each test + gnoi_shutdown_daemon.main = gnoi_shutdown_daemon.__dict__["main"] + + def test_execute_command_success(self): + """Test successful execution of a gNOI command.""" + with patch("gnoi_shutdown_daemon.subprocess.run") as mock_run: + mock_run.return_value = MagicMock(returncode=0, stdout="success", stderr="") + rc, stdout, stderr = gnoi_shutdown_daemon.execute_command(["dummy"]) + self.assertEqual(rc, 0) + self.assertEqual(stdout, "success") + self.assertEqual(stderr, "") + + def test_execute_command_timeout(self): + """Test gNOI command timeout.""" + with patch("gnoi_shutdown_daemon.subprocess.run", side_effect=subprocess.TimeoutExpired(cmd=["dummy"], timeout=60)): + rc, stdout, stderr = gnoi_shutdown_daemon.execute_command(["dummy"]) + self.assertEqual(rc, -1) + self.assertEqual(stdout, "") + self.assertIn("Command timed out", stderr) + + def test_execute_command_exception(self): + """Test gNOI command failure due to an exception.""" + with patch("gnoi_shutdown_daemon.subprocess.run", side_effect=Exception("Test error")): + rc, stdout, stderr = gnoi_shutdown_daemon.execute_command(["dummy"]) + self.assertEqual(rc, -2) + self.assertEqual(stdout, "") + self.assertIn("Command failed: Test error", stderr) + + def test_get_halt_timeout_from_platform_json(self): + """Test _get_halt_timeout with platform.json containing timeout.""" + from unittest.mock import mock_open + + mock_chassis = MagicMock() + mock_chassis.get_name.return_value = "test_platform" + + mock_platform_instance = MagicMock() + mock_platform_instance.get_chassis.return_value = mock_chassis + + mock_platform_class = 
MagicMock(return_value=mock_platform_instance) + mock_platform_module = MagicMock() + mock_platform_module.Platform = mock_platform_class + + platform_json_content = {"dpu_halt_services_timeout": 120} + + with patch.dict('sys.modules', {'sonic_platform': MagicMock(), 'sonic_platform.platform': mock_platform_module}): + with patch("gnoi_shutdown_daemon.os.path.exists", return_value=True): + with patch("builtins.open", mock_open(read_data=json.dumps(platform_json_content))): + timeout = gnoi_shutdown_daemon._get_halt_timeout() + self.assertEqual(timeout, 120) + + def test_get_halt_timeout_default(self): + """Test _get_halt_timeout returns default when platform.json not found.""" + mock_chassis = MagicMock() + mock_chassis.get_name.return_value = "test_platform" + + mock_platform_instance = MagicMock() + mock_platform_instance.get_chassis.return_value = mock_chassis + + mock_platform_class = MagicMock(return_value=mock_platform_instance) + mock_platform_module = MagicMock() + mock_platform_module.Platform = mock_platform_class + + with patch.dict('sys.modules', {'sonic_platform': MagicMock(), 'sonic_platform.platform': mock_platform_module}): + with patch("gnoi_shutdown_daemon.os.path.exists", return_value=False): + timeout = gnoi_shutdown_daemon._get_halt_timeout() + self.assertEqual(timeout, gnoi_shutdown_daemon.STATUS_POLL_TIMEOUT_SEC) + + def test_get_halt_timeout_exception(self): + """Test _get_halt_timeout returns default on exception.""" + # Mock sonic_platform import to succeed, then mock file operation to raise exception + mock_chassis = MagicMock() + mock_chassis.get_name.return_value = "test-platform" + mock_platform_class = MagicMock() + mock_platform_class.return_value.get_chassis.return_value = mock_chassis + + with patch.dict('sys.modules', {'sonic_platform': MagicMock(), 'sonic_platform.platform': MagicMock(Platform=mock_platform_class)}), \ + patch('gnoi_shutdown_daemon.open', side_effect=OSError("File system error")): + timeout = 
gnoi_shutdown_daemon._get_halt_timeout() + self.assertEqual(timeout, gnoi_shutdown_daemon.STATUS_POLL_TIMEOUT_SEC) + + @patch('gnoi_shutdown_daemon.daemon_base.db_connect') + @patch('gnoi_shutdown_daemon.GnoiRebootHandler') + @patch('gnoi_shutdown_daemon.swsscommon.ConfigDBConnector') + @patch('threading.Thread') + def test_main_loop_flow(self, mock_thread, mock_config_db_connector_class, mock_gnoi_reboot_handler, mock_db_connect): + """Test the main loop processing of a shutdown event.""" + # Mock DB connections + mock_state_db = MagicMock() + mock_config_db = MagicMock() + mock_db_connect.side_effect = [mock_state_db, mock_config_db] + + # Mock config_db.hget to return admin_status=down to trigger thread creation + mock_config_db.hget.return_value = "down" + + # Mock ConfigDBConnector for pubsub + mock_config_db_connector = MagicMock() + mock_config_db_connector.db_name = "CONFIG_DB" + mock_pubsub = MagicMock() + mock_pubsub.get_message.side_effect = [mock_message, KeyboardInterrupt] + mock_redis_client = MagicMock() + mock_redis_client.pubsub.return_value = mock_pubsub + mock_config_db_connector.get_redis_client.return_value = mock_redis_client + mock_config_db_connector_class.return_value = mock_config_db_connector + + # Mock chassis + mock_chassis = MagicMock() + mock_platform_instance = MagicMock() + mock_platform_instance.get_chassis.return_value = mock_chassis + + # Create mock for sonic_platform.platform module + mock_platform_submodule = MagicMock() + mock_platform_submodule.Platform.return_value = mock_platform_instance + + # Create mock for sonic_platform parent module + mock_sonic_platform = MagicMock() + mock_sonic_platform.platform = mock_platform_submodule + + # Mock the reboot handler's _handle_transition to avoid actual execution + mock_handler_instance = MagicMock() + mock_gnoi_reboot_handler.return_value = mock_handler_instance + + # Temporarily add mocks to sys.modules for the duration of this test + with patch.dict('sys.modules', { + 
'sonic_platform': mock_sonic_platform, + 'sonic_platform.platform': mock_platform_submodule + }): + with self.assertRaises(KeyboardInterrupt): + gnoi_shutdown_daemon.main() + + # Verify initialization + mock_db_connect.assert_has_calls([call("STATE_DB"), call("CONFIG_DB")]) + mock_gnoi_reboot_handler.assert_called_with(mock_state_db, mock_config_db, mock_chassis) + + # Verify that a thread was created to handle the transition + mock_thread.assert_called_once() + # Verify the thread was started + mock_thread.return_value.start.assert_called_once() + + @patch('gnoi_shutdown_daemon._get_halt_timeout', return_value=60) + @patch('gnoi_shutdown_daemon.get_dpu_ip') + @patch('gnoi_shutdown_daemon.get_dpu_gnmi_port') + @patch('gnoi_shutdown_daemon.execute_command') + @patch('gnoi_shutdown_daemon.time.sleep') + @patch('gnoi_shutdown_daemon.time.monotonic') + def test_handle_transition_success(self, mock_monotonic, mock_sleep, mock_execute_command, mock_get_gnmi_port, mock_get_dpu_ip, mock_get_halt_timeout): + """Test the full successful transition handling.""" + mock_db = MagicMock() + mock_config_db = MagicMock() + mock_chassis = MagicMock() + + # Mock return values + mock_get_dpu_ip.return_value = "10.0.0.1" + mock_get_gnmi_port.return_value = "8080" + + # Mock table.get() for gnoi_halt_in_progress check + mock_table = MagicMock() + mock_table.get.return_value = (True, [("gnoi_halt_in_progress", "True")]) + + # Mock time for polling + mock_monotonic.side_effect = [ + 0, 1, # For _wait_for_gnoi_halt_in_progress + 2, 3 # For _poll_reboot_status + ] + + # Reboot command success, RebootStatus success + mock_execute_command.side_effect = [ + (0, "reboot sent", ""), + (0, "reboot complete", "") + ] + + # Mock module for clear operation + mock_module = MagicMock() + mock_chassis.get_module_index.return_value = 0 + mock_chassis.get_module.return_value = mock_module + + with patch('gnoi_shutdown_daemon.swsscommon.Table', return_value=mock_table): + handler = 
gnoi_shutdown_daemon.GnoiRebootHandler(mock_db, mock_config_db, mock_chassis) + result = handler._handle_transition("DPU0", "shutdown") + + self.assertTrue(result) + mock_chassis.get_module_index.assert_called_with("DPU0") + mock_chassis.get_module.assert_called_with(0) + mock_module.clear_module_gnoi_halt_in_progress.assert_called_once() + self.assertEqual(mock_execute_command.call_count, 2) + + @patch('gnoi_shutdown_daemon._get_halt_timeout', return_value=60) + @patch('gnoi_shutdown_daemon.get_dpu_ip') + @patch('gnoi_shutdown_daemon.get_dpu_gnmi_port') + @patch('gnoi_shutdown_daemon.time.sleep') + @patch('gnoi_shutdown_daemon.time.monotonic') + @patch('gnoi_shutdown_daemon.execute_command') + def test_handle_transition_gnoi_halt_timeout(self, mock_execute_command, mock_monotonic, mock_sleep, mock_get_gnmi_port, mock_get_dpu_ip, mock_get_halt_timeout): + """Test transition proceeds despite gnoi_halt_in_progress timeout.""" + mock_db = MagicMock() + mock_config_db = MagicMock() + mock_chassis = MagicMock() + + mock_get_dpu_ip.return_value = "10.0.0.1" + mock_get_gnmi_port.return_value = "8080" + + # Mock table.get() to never return True (simulates timeout in wait) + mock_table = MagicMock() + mock_table.get.return_value = (True, [("gnoi_halt_in_progress", "False")]) + + # Simulate timeout in _wait_for_gnoi_halt_in_progress, then success in _poll_reboot_status + mock_monotonic.side_effect = [ + # _wait_for_gnoi_halt_in_progress times out + 0, 1, 2, gnoi_shutdown_daemon.STATUS_POLL_TIMEOUT_SEC + 1, + # _poll_reboot_status succeeds + 0, 1 + ] + + # Reboot command and status succeed + mock_execute_command.side_effect = [ + (0, "reboot sent", ""), + (0, "reboot complete", "") + ] + + # Mock module for clear operation + mock_module = MagicMock() + mock_chassis.get_module_index.return_value = 0 + mock_chassis.get_module.return_value = mock_module + + with patch('gnoi_shutdown_daemon.swsscommon.Table', return_value=mock_table): + handler = 
gnoi_shutdown_daemon.GnoiRebootHandler(mock_db, mock_config_db, mock_chassis) + result = handler._handle_transition("DPU0", "shutdown") + + # Should still succeed - code proceeds anyway after timeout warning + self.assertTrue(result) + mock_chassis.get_module_index.assert_called_with("DPU0") + mock_chassis.get_module.assert_called_with(0) + mock_module.clear_module_gnoi_halt_in_progress.assert_called_once() + + def test_get_dpu_ip_and_port(self): + """Test DPU IP and gNMI port retrieval.""" + # Test IP retrieval + mock_config = MagicMock() + mock_config.hget.return_value = "10.0.0.1" + + ip = gnoi_shutdown_daemon.get_dpu_ip(mock_config, "DPU0") + self.assertEqual(ip, "10.0.0.1") + mock_config.hget.assert_called_with("DHCP_SERVER_IPV4_PORT|bridge-midplane|dpu0", "ips@") + + # Test port retrieval + mock_config = MagicMock() + mock_config.hget.return_value = "12345" + + port = gnoi_shutdown_daemon.get_dpu_gnmi_port(mock_config, "DPU0") + self.assertEqual(port, "12345") + + # Test port fallback + mock_config = MagicMock() + mock_config.hget.return_value = None + + port = gnoi_shutdown_daemon.get_dpu_gnmi_port(mock_config, "DPU0") + self.assertEqual(port, "8080") + + @patch('gnoi_shutdown_daemon._get_halt_timeout', return_value=60) + @patch('gnoi_shutdown_daemon.get_dpu_ip', return_value=None) + @patch('gnoi_shutdown_daemon.get_dpu_gnmi_port', return_value="8080") + def test_handle_transition_ip_failure(self, mock_get_gnmi_port, mock_get_dpu_ip, mock_get_halt_timeout): + """Test handle_transition failure on DPU IP retrieval.""" + mock_db = MagicMock() + mock_config_db = MagicMock() + mock_chassis = MagicMock() + + # Mock module for clear operation + mock_module = MagicMock() + mock_chassis.get_module_index.return_value = 0 + mock_chassis.get_module.return_value = mock_module + + handler = gnoi_shutdown_daemon.GnoiRebootHandler(mock_db, mock_config_db, mock_chassis) + + # Mock _wait_for_gnoi_halt_in_progress to return immediately to prevent hanging + 
handler._wait_for_gnoi_halt_in_progress = MagicMock(return_value=True) + + result = handler._handle_transition("DPU0", "shutdown") + + self.assertFalse(result) + # Verify that clear_module_gnoi_halt_in_progress was called + mock_chassis.get_module_index.assert_called_with("DPU0") + mock_chassis.get_module.assert_called_with(0) + mock_module.clear_module_gnoi_halt_in_progress.assert_called_once() + + @patch('gnoi_shutdown_daemon.get_dpu_ip', return_value="10.0.0.1") + @patch('gnoi_shutdown_daemon.get_dpu_gnmi_port', return_value="8080") + @patch('gnoi_shutdown_daemon.execute_command', return_value=(-1, "", "error")) + def test_send_reboot_command_failure(self, mock_execute, mock_get_port, mock_get_ip): + """Test failure of _send_reboot_command.""" + handler = gnoi_shutdown_daemon.GnoiRebootHandler(MagicMock(), MagicMock(), MagicMock()) + result = handler._send_reboot_command("DPU0", "10.0.0.1", "8080") + self.assertFalse(result) + + def test_get_dpu_gnmi_port_variants(self): + """Test DPU gNMI port retrieval with name variants.""" + mock_config = MagicMock() + mock_config.hget.side_effect = [ + None, # dpu0 fails + None, # DPU0 fails + "12345" # DPU0 succeeds + ] + + port = gnoi_shutdown_daemon.get_dpu_gnmi_port(mock_config, "DPU0") + self.assertEqual(port, "12345") + self.assertEqual(mock_config.hget.call_count, 3) + + @patch('gnoi_shutdown_daemon.daemon_base.db_connect') + @patch('gnoi_shutdown_daemon.swsscommon.ConfigDBConnector') + def test_main_loop_no_dpu_name(self, mock_config_db_connector_class, mock_db_connect): + """Test main loop with a malformed key.""" + mock_chassis = MagicMock() + mock_platform_instance = MagicMock() + mock_platform_instance.get_chassis.return_value = mock_chassis + + # Create mock for sonic_platform.platform module + mock_platform_submodule = MagicMock() + mock_platform_submodule.Platform.return_value = mock_platform_instance + + # Create mock for sonic_platform parent module + mock_sonic_platform = MagicMock() + 
mock_sonic_platform.platform = mock_platform_submodule + + mock_pubsub = MagicMock() + # Malformed message, then stop + malformed_message = mock_message.copy() + malformed_message["channel"] = f"__keyspace@{gnoi_shutdown_daemon.CONFIG_DB_INDEX}__:CHASSIS_MODULE|" + mock_pubsub.get_message.side_effect = [malformed_message, KeyboardInterrupt] + + # Mock DB connections + mock_state_db = MagicMock() + mock_config_db = MagicMock() + mock_db_connect.side_effect = [mock_state_db, mock_config_db] + + # Mock ConfigDBConnector for pubsub + mock_config_db_connector = MagicMock() + mock_config_db_connector.db_name = "CONFIG_DB" + mock_redis_client = MagicMock() + mock_redis_client.pubsub.return_value = mock_pubsub + mock_config_db_connector.get_redis_client.return_value = mock_redis_client + mock_config_db_connector_class.return_value = mock_config_db_connector + + with patch.dict('sys.modules', { + 'sonic_platform': mock_sonic_platform, + 'sonic_platform.platform': mock_platform_submodule + }): + with self.assertRaises(KeyboardInterrupt): + gnoi_shutdown_daemon.main() + + @patch('gnoi_shutdown_daemon.daemon_base.db_connect') + @patch('gnoi_shutdown_daemon.swsscommon.ConfigDBConnector') + def test_main_loop_get_transition_exception(self, mock_config_db_connector_class, mock_db_connect): + """Test main loop when hget raises an exception.""" + mock_chassis = MagicMock() + mock_platform_instance = MagicMock() + mock_platform_instance.get_chassis.return_value = mock_chassis + + # Create mock for sonic_platform.platform module + mock_platform_submodule = MagicMock() + mock_platform_submodule.Platform.return_value = mock_platform_instance + + # Create mock for sonic_platform parent module + mock_sonic_platform = MagicMock() + mock_sonic_platform.platform = mock_platform_submodule + + mock_pubsub = MagicMock() + mock_pubsub.get_message.side_effect = [mock_message, KeyboardInterrupt] + + # Mock config_db to raise exception on hget + mock_config_db = MagicMock() + mock_state_db = 
MagicMock() + mock_db_connect.side_effect = [mock_state_db, mock_config_db] + mock_config_db.hget.side_effect = AttributeError("DB error") + + # Mock ConfigDBConnector for pubsub + mock_config_db_connector = MagicMock() + mock_config_db_connector.db_name = "CONFIG_DB" + mock_redis_client = MagicMock() + mock_redis_client.pubsub.return_value = mock_pubsub + mock_config_db_connector.get_redis_client.return_value = mock_redis_client + mock_config_db_connector_class.return_value = mock_config_db_connector + + with patch.dict('sys.modules', { + 'sonic_platform': mock_sonic_platform, + 'sonic_platform.platform': mock_platform_submodule + }): + with self.assertRaises(KeyboardInterrupt): + gnoi_shutdown_daemon.main() + + @patch('gnoi_shutdown_daemon._get_halt_timeout', return_value=60) + @patch('gnoi_shutdown_daemon.execute_command', return_value=(-1, "", "RPC error")) + def test_poll_reboot_status_failure(self, mock_execute_command, mock_get_halt_timeout): + """Test _poll_reboot_status with a command failure.""" + handler = gnoi_shutdown_daemon.GnoiRebootHandler(MagicMock(), MagicMock(), MagicMock()) + with patch('gnoi_shutdown_daemon.time.monotonic', side_effect=[0, 1, 61]): + result = handler._poll_reboot_status("DPU0", "10.0.0.1", "8080") + self.assertFalse(result) + + def test_sonic_platform_import_mock(self): + """Simple test to verify sonic_platform import mocking works.""" + # Create mock chassis + mock_chassis = MagicMock() + mock_chassis.get_name.return_value = "test_chassis" + + # Create mock platform instance that returns our chassis + mock_platform_instance = MagicMock() + mock_platform_instance.get_chassis.return_value = mock_chassis + + # Create mock Platform class + mock_platform_class = MagicMock(return_value=mock_platform_instance) + + # Create mock for sonic_platform.platform module + mock_platform_submodule = MagicMock() + mock_platform_submodule.Platform = mock_platform_class + + # Create mock for sonic_platform parent module + mock_sonic_platform = 
MagicMock() + mock_sonic_platform.platform = mock_platform_submodule + + # Test that we can mock the import + with patch.dict('sys.modules', { + 'sonic_platform': mock_sonic_platform, + 'sonic_platform.platform': mock_platform_submodule + }): + # Simulate what the actual code does + from sonic_platform import platform + chassis = platform.Platform().get_chassis() + + # Verify it worked + self.assertEqual(chassis, mock_chassis) + self.assertEqual(chassis.get_name(), "test_chassis") + mock_platform_class.assert_called_once() + mock_platform_instance.get_chassis.assert_called_once() + + def test_get_dpu_ip_with_string_ips(self): + """Test get_dpu_ip when ips is a string instead of list.""" + mock_config = MagicMock() + mock_config.hget.return_value = "10.0.0.5" + + ip = gnoi_shutdown_daemon.get_dpu_ip(mock_config, "DPU1") + self.assertEqual(ip, "10.0.0.5") + + def test_get_dpu_ip_empty_entry(self): + """Test get_dpu_ip when entry is empty.""" + mock_config = MagicMock() + mock_config.hget.return_value = None + + ip = gnoi_shutdown_daemon.get_dpu_ip(mock_config, "DPU1") + self.assertIsNone(ip) + + def test_get_dpu_ip_no_ips_field(self): + """Test get_dpu_ip when hget returns None (field doesn't exist).""" + mock_config = MagicMock() + mock_config.hget.return_value = None + + ip = gnoi_shutdown_daemon.get_dpu_ip(mock_config, "DPU1") + self.assertIsNone(ip) + + def test_get_dpu_ip_exception(self): + """Test get_dpu_ip when exception occurs.""" + mock_config = MagicMock() + mock_config.hget.side_effect = AttributeError("Database error") + + ip = gnoi_shutdown_daemon.get_dpu_ip(mock_config, "DPU1") + self.assertIsNone(ip) + + def test_get_dpu_gnmi_port_exception(self): + """Test get_dpu_gnmi_port when exception occurs.""" + mock_config = MagicMock() + mock_config.hget.side_effect = AttributeError("Database error") + + port = gnoi_shutdown_daemon.get_dpu_gnmi_port(mock_config, "DPU1") + self.assertEqual(port, "8080") + + def test_send_reboot_command_success(self): + """Test 
successful _send_reboot_command.""" + with patch('gnoi_shutdown_daemon.execute_command', return_value=(0, "success", "")): + handler = gnoi_shutdown_daemon.GnoiRebootHandler(MagicMock(), MagicMock(), MagicMock()) + result = handler._send_reboot_command("DPU0", "10.0.0.1", "8080") + self.assertTrue(result) + + @patch('gnoi_shutdown_daemon._get_halt_timeout', return_value=60) + @patch('gnoi_shutdown_daemon.get_dpu_ip', return_value="10.0.0.1") + @patch('gnoi_shutdown_daemon.get_dpu_gnmi_port', side_effect=Exception("Port lookup failed")) + def test_handle_transition_config_exception(self, mock_get_port, mock_get_ip, mock_get_halt_timeout): + """Test handle_transition when configuration lookup raises exception.""" + mock_db = MagicMock() + mock_config_db = MagicMock() + mock_chassis = MagicMock() + + # Mock module for clear operation + mock_module = MagicMock() + mock_chassis.get_module_index.return_value = 0 + mock_chassis.get_module.return_value = mock_module + + handler = gnoi_shutdown_daemon.GnoiRebootHandler(mock_db, mock_config_db, mock_chassis) + + # Mock _wait_for_gnoi_halt_in_progress to return immediately to prevent hanging + handler._wait_for_gnoi_halt_in_progress = MagicMock(return_value=True) + + result = handler._handle_transition("DPU0", "shutdown") + + self.assertFalse(result) + # Verify that clear_module_gnoi_halt_in_progress was called + mock_chassis.get_module_index.assert_called_with("DPU0") + mock_chassis.get_module.assert_called_with(0) + mock_module.clear_module_gnoi_halt_in_progress.assert_called_once() + + +if __name__ == '__main__': + unittest.main()