diff --git a/ansible/devutil/device_inventory.py b/ansible/devutil/device_inventory.py new file mode 100644 index 00000000000..b9daf581254 --- /dev/null +++ b/ansible/devutil/device_inventory.py @@ -0,0 +1,93 @@ +import os +import csv +import glob +from typing import Dict, List, Optional + + +class DeviceInfo(object): + """Device information.""" + + def __init__( + self, + hostname: str, + management_ip: str, + hw_sku: str, + device_type: str, + protocol: str = "", + os: str = "", + ): + self.hostname = hostname + self.management_ip = management_ip + self.hw_sku = hw_sku + self.device_type = device_type + self.protocol = protocol + self.os = os + + @staticmethod + def from_csv_row(row: List[str]) -> "DeviceInfo": + # The device CSV file has the following columns (the last 2 are optional): + # + # Hostname,ManagementIp,HwSku,Type,Protocol,Os + # + return DeviceInfo( + row[0], + row[1].split("/")[0], + row[2], + row[3], + row[4] if len(row) > 4 else "", + row[5] if len(row) > 5 else "", + ) + + def is_ssh_supported(self) -> bool: + if self.device_type == "ConsoleServer": + return False + + if self.protocol == "snmp": + return False + + return True + + +class DeviceInventory(object): + """Device inventory from csv files.""" + + def __init__( + self, inv_name: str, device_file_name: str, devices: Dict[str, DeviceInfo] + ): + self.inv_name = inv_name + self.device_file_name = device_file_name + self.devices = devices + + @staticmethod + def from_device_files(device_file_pattern: str) -> "List[DeviceInventory]": + inv: List[DeviceInventory] = [] + for file_path in glob.glob(device_file_pattern): + device_inventory = DeviceInventory.from_device_file(file_path) + inv.append(device_inventory) + + return inv + + @staticmethod + def from_device_file(file_path: str) -> "DeviceInventory": + print(f"Loading device inventory: {file_path}") + + # Parse inv name from the file path. + # The inv name can be deducted from the file name part in the path using format sonic__devices.csv + inv_name = os.path.basename(file_path).split("_")[1] + + devices: Dict[str, DeviceInfo] = {} + with open(file_path, newline="") as file: + reader = csv.reader(file) + + # Skip the header line + next(reader) + + for row in reader: + if row: + device_info = DeviceInfo.from_csv_row(row) + devices[device_info.hostname] = device_info + + return DeviceInventory(inv_name, file_path, devices) + + def get_device(self, hostname: str) -> Optional[DeviceInfo]: + return self.devices.get(hostname) diff --git a/ansible/devutil/testbed.py b/ansible/devutil/testbed.py index 7afd3ac59ae..58892335494 100644 --- a/ansible/devutil/testbed.py +++ b/ansible/devutil/testbed.py @@ -5,19 +5,27 @@ import os import re import yaml +from typing import Any, Dict, List, Optional + +from devutil.device_inventory import DeviceInfo, DeviceInventory class TestBed(object): """Data model that represents a testbed object.""" @classmethod - def from_file(cls, testbed_file="testbed.yaml", testbed_pattern=None, hosts=None): + def from_file( + cls, + device_inventories: List[DeviceInventory], + testbed_file: str = "testbed.yaml", + testbed_pattern: Optional[str] = None, + ) -> Dict[str, "TestBed"]: """Load all testbed objects from YAML file. Args: testbed_file (str): Path to testbed file. testbed_pattern (str): Regex pattern to filter testbeds. - hosts (AnsibleHosts): AnsibleHosts object that contains all hosts in the testbed. + hosts (HostManager): AnsibleHosts object that contains all hosts in the testbed. Returns: dict: Testbed name to testbed object mapping. @@ -39,11 +47,11 @@ def from_file(cls, testbed_file="testbed.yaml", testbed_pattern=None, hosts=None for raw_testbed in raw_testbeds: if testbed_pattern and not testbed_pattern.match(raw_testbed["conf-name"]): continue - testbeds[raw_testbed["conf-name"]] = cls(raw_testbed, hosts=hosts) + testbeds[raw_testbed["conf-name"]] = cls(raw_testbed, device_inventories) return testbeds - def __init__(self, raw_dict, hosts=None): + def __init__(self, raw_dict: Any, device_inventories: List[DeviceInventory]): """Initialize a testbed object. Args: @@ -55,46 +63,26 @@ def __init__(self, raw_dict, hosts=None): setattr(self, key.replace("-", "_"), value) # Create a PTF node object - self.ptf_node = TestBedNode(self.ptf, hosts) - - # Loop through each DUT in the testbed and create TestBedNode object + self.ptf_node = DeviceInfo( + hostname=self.ptf, + management_ip=self.ptf_ip.split("/")[0], + hw_sku="Container", + device_type="PTF", + protocol="ssh", + ) + + # Loop through each DUT in the testbed and find the device info self.dut_nodes = {} for dut in raw_dict["dut"]: - self.dut_nodes[dut] = TestBedNode(dut, hosts) + for inv in device_inventories: + device = inv.get_device(dut) + if device is not None: + self.dut_nodes[dut] = device + break + else: + print(f"Error: Failed to find device info for DUT {dut}") # Some testbeds are dummy ones and doesn't have inv_name specified, # so we need to use "unknown" as inv_name instead. if not hasattr(self, "inv_name"): self.inv_name = "unknown" - - -class TestBedNode(object): - """Data model that represents a testbed node object.""" - - def __init__(self, name, hosts=None): - """Initialize a testbed node object. - - Args: - name (str): Node name. - ansible_vars (dict): Ansible variables of the node. - """ - self.name = name - self.ssh_ip = None - self.ssh_user = None - self.ssh_pass = None - - if hosts: - try: - host_vars = hosts.get_host_vars(self.name) - self.ssh_ip = host_vars["ansible_host"] - self.ssh_user = host_vars["creds"]["username"] - self.ssh_pass = host_vars["creds"]["password"][0] - except Exception as e: - print( - "Error: Failed to get host vars for {}: {}".format( - self.name, str(e) - ) - ) - self.ssh_ip = None - self.ssh_user = None - self.ssh_pass = None diff --git a/ansible/ssh_session_gen.py b/ansible/ssh_session_gen.py index 1f28868d674..f45fc160603 100644 --- a/ansible/ssh_session_gen.py +++ b/ansible/ssh_session_gen.py @@ -4,117 +4,264 @@ import argparse import os +import re +from typing import Dict, List, Optional, Tuple +from devutil.device_inventory import DeviceInfo, DeviceInventory from devutil.testbed import TestBed from devutil.inv_helpers import HostManager -from devutil.ssh_session_repo import SecureCRTSshSessionRepoGenerator, SshConfigSshSessionRepoGenerator +from devutil.ssh_session_repo import ( + SecureCRTSshSessionRepoGenerator, + SshConfigSshSessionRepoGenerator, + SshSessionRepoGenerator, +) + + +class SSHInfoSolver(object): + """SSH info solver for testbeds and devices.""" + + def __init__( + self, + ansible_hosts: HostManager, + dut_user: Optional[str], + dut_pass: Optional[str], + server_user: Optional[str], + server_pass: Optional[str], + leaf_fanout_user: Optional[str], + leaf_fanout_pass: Optional[str], + root_fanout_user: Optional[str], + root_fanout_pass: Optional[str], + console_server_user: Optional[str], + console_server_pass: Optional[str], + ptf_user: Optional[str], + ptf_pass: Optional[str], + ): + """ + Init SSH credential solver. + + Args: + ansible_hosts (HostManager): Ansible host inventory manager. + dut_user (str): Default SSH user for DUTs. + dut_pass (str): Default SSH password for DUTs. + server_user (str): Default SSH user for servers. + server_pass (str): Default SSH password for servers. + leaf_fanout_user (str): Default SSH user for leaf fanouts. + leaf_fanout_pass (str): Default SSH password for leaf fanouts. + root_fanout_user (str): Default SSH user for root fanouts. + root_fanout_pass (str): Default SSH password for root fanouts. + """ + self.ansible_hosts = ansible_hosts + + self.ssh_overrides = { + "Server": {"user": server_user, "pass": server_pass}, + "DevSonic": {"user": dut_user, "pass": dut_pass}, + "FanoutLeaf": {"user": leaf_fanout_user, "pass": leaf_fanout_pass}, + "FanoutLeafSonic": {"user": leaf_fanout_user, "pass": leaf_fanout_pass}, + "FanoutRoot": {"user": root_fanout_user, "pass": root_fanout_pass}, + "ConsoleServer": {"user": console_server_user, "pass": console_server_pass}, + "PTF": {"user": ptf_user, "pass": ptf_pass}, + } + + def get_ssh_cred(self, device: DeviceInfo) -> Tuple[str, str, str]: + """ + Get SSH info for a testbed node. + + Args: + device (DeviceInfo): Represents a connectable node in the testbed. + + Returns: + tuple: SSH IP, user and password. + """ + ssh_ip = device.management_ip + ssh_user = ( + self.ssh_overrides[device.device_type]["user"] + if device.device_type in self.ssh_overrides + else "" + ) + ssh_pass = ( + self.ssh_overrides[device.device_type]["pass"] + if device.device_type in self.ssh_overrides + else "" + ) + + if not ssh_ip or not ssh_user or not ssh_pass: + try: + host_vars = self.ansible_hosts.get_host_vars(device.hostname) + + ssh_ip = host_vars["ansible_host"] if not ssh_ip else ssh_ip + ssh_user = host_vars["creds"]["username"] if not ssh_user else ssh_user + ssh_pass = ( + host_vars["creds"]["password"][-1] if not ssh_pass else ssh_pass + ) + except Exception as e: + print( + f"Error: Failed to get SSH credential for device {device.hostname} ({device.device_type}): {str(e)}" + ) + + ssh_ip = "" if ssh_ip is None else ssh_ip + ssh_user = "" if ssh_user is None else ssh_user + ssh_pass = "" if ssh_pass is None else ssh_pass + + return ssh_ip, ssh_user, ssh_pass + + +class DeviceSshSessionRepoGenerator(object): + def __init__( + self, repo_generator: SshSessionRepoGenerator, ssh_info_solver: SSHInfoSolver + ) -> None: + self.repo_generator = repo_generator + self.ssh_info_solver = ssh_info_solver + def generate_ssh_session_for_device(self, device: DeviceInfo, session_path: str): + """Generate SSH session for a device. -class TestBedSshSessionRepoGenerator(object): + Args: + device (DeviceInfo): Represents a device. + session_path (str): Path to store the SSH session file. + """ + if not device.is_ssh_supported(): + return + + ssh_ip, ssh_user, ssh_pass = self.ssh_info_solver.get_ssh_cred(device) + if ssh_ip is None: + print( + f"WARNING: Management IP is not specified for testbed node, skipped: {device.hostname}" + ) + return + + if not ssh_user: + print( + "WARNING: SSH credential is missing for device: {}".format( + device.hostname + ) + ) + + self.repo_generator.generate( + session_path, + ssh_ip, + ssh_user, + ssh_pass, + ) + + +class TestBedSshSessionRepoGenerator(DeviceSshSessionRepoGenerator): """SSH session repo generator for testbeds.""" - def __init__(self, testbeds, repo_generator): - """Store all parameters as attributes. + def __init__( + self, + testbeds: Dict[str, TestBed], + repo_generator: SshSessionRepoGenerator, + ssh_info_solver: SSHInfoSolver, + ): + """ + Store all parameters as attributes. Args: testbeds (dict): Testbed name to testbed object mapping. repo_generator (SshSessionRepoGenerator): SSH session repo generator. """ + super().__init__(repo_generator, ssh_info_solver) self.testbeds = testbeds - self.repo_generator = repo_generator def generate(self): """Generate SSH session repo.""" + + print("\nStart generating SSH session files for all testbeds:") + for testbed in self.testbeds.values(): self._generate_ssh_sessions_for_testbed(testbed) self.repo_generator.finish() - def _generate_ssh_sessions_for_testbed(self, testbed): - """Generate SSH sessions for a testbed. + def _generate_ssh_sessions_for_testbed(self, testbed: TestBed): + """ + Generate SSH sessions for a testbed. Args: testbed (object): Represents a testbed setup. """ - print("Start generating SSH sessions for testbed: {}".format(testbed.conf_name)) - - testbed_nodes = [["ptf", testbed.ptf_node]] + [ - ["dut", item] for item in testbed.dut_nodes.values() - ] - for testbed_node in testbed_nodes: - self._generate_ssh_session_for_testbed_node( - testbed, testbed_node[0], testbed_node[1] - ) - - print( - "Finish generating SSH session files for testbed: {}\n".format( - testbed.conf_name - ) - ) + devices = [testbed.ptf_node] + list(testbed.dut_nodes.values()) + for device in devices: + self._generate_ssh_session_for_testbed_node(testbed, device) def _generate_ssh_session_for_testbed_node( - self, testbed, testbed_node_type, testbed_node + self, testbed: TestBed, device: DeviceInfo ): - """Generate SSH session for a testbed node. + """ + Generate SSH session for a testbed node. We use the following naming convention for SSH session path: - //- + testbeds///- Args: testbed (object): Represents a testbed setup. testbed_node_type (str): Type of the testbed node. It can be "ptf" or "dut". testbed_node (object): Represents a connectable node in the testbed. """ - if testbed_node.ssh_ip is None: - print( - """Skip generating SSH session for testbed node: Testbed = {}, Type = {}, Node = {} - (SSH IP is not specified)""".format( - testbed.conf_name, testbed_node_type, testbed_node.name - ) - ) - return - - print( - "Start generating SSH session for testbed node: Testbed = {}, Type = {}, Node = {}".format( - testbed.conf_name, testbed_node_type, testbed_node.name - ) - ) - - if testbed_node.ssh_user == '': - print("WARNING: SSH user is empty for testbed node: {}".format(testbed_node.name)) + device_type = "dut" if device.device_type != "PTF" else "ptf" session_path = os.path.join( + "testbeds", testbed.inv_name, testbed.conf_name, - testbed_node_type + "-" + testbed_node.name, - ) - self.repo_generator.generate( - session_path, - testbed_node.ssh_ip, - testbed_node.ssh_user, - testbed_node.ssh_pass, + device_type + "-" + device.hostname, ) + self.generate_ssh_session_for_device(device, session_path) -def main(args): - print("Loading ansible host inventory: {}\n".format(args.inventory_file_paths)) - ansible_hosts = HostManager(args.inventory_file_paths) - print( - "Loading testbed config: TestBedFile = {}, Pattern = {}".format( - args.testbed_file_path, args.testbed_pattern +device_type_pattern = re.compile(r"(?// + + Args: + device_inventory (List[DeviceInventory]): Represents a device inventory. + """ + print( + "\nStart generating SSH session files for device inventory: {}".format( + device_inventory.inv_name + ) ) - ) - testbeds = TestBed.from_file( - args.testbed_file_path, args.testbed_pattern, ansible_hosts - ) - if len(testbeds) == 0: - print("No testbeds loaded. Exit.") - return - else: - print("{} testbeds loaded.\n".format(len(testbeds))) + for device in device_inventory.devices.values(): + device_type = device_type_pattern.sub("-", device.device_type).lower() + session_path = os.path.join( + "devices", device_inventory.inv_name, device_type, device.hostname + ) + self.generate_ssh_session_for_device(device, session_path) + + +def main(args): print( - "Starting SSH session repo generation with config: Target = {}, Format = {}, Template = {}".format( + "Creating generator with config: Target = {}, Format = {}, Template = {}".format( args.target, args.format, args.template_file_path ) ) @@ -135,8 +282,50 @@ def main(args): print("Unsupported output format: {}".format(args.format)) return - testbed_repo_generator = TestBedSshSessionRepoGenerator(testbeds, repo_generator) - testbed_repo_generator.generate() + print(f"\nLoading device inventories: Files = {args.device_file_pattern}") + device_inventories = DeviceInventory.from_device_files(args.device_file_pattern) + + print( + f"\nLoading testbeds: TestBedFile = {args.testbed_file_path}, Pattern = {args.testbed_pattern}" + ) + testbeds = TestBed.from_file( + device_inventories, args.testbed_file_path, args.testbed_pattern + ) + + print(f"\nLoading ansible host inventory for getting SSH info: {args.inventory_file_paths}") + ansible_hosts = HostManager(args.inventory_file_paths) + + ssh_info_solver = SSHInfoSolver( + ansible_hosts, + args.dut_user, + args.dut_pass, + args.server_user, + args.server_pass, + args.leaf_fanout_user, + args.leaf_fanout_pass, + args.root_fanout_user, + args.root_fanout_pass, + args.console_server_user, + args.console_server_pass, + args.ptf_user, + args.ptf_pass, + ) + + if len(testbeds) == 0: + print("No testbeds loaded. Skipped.") + else: + testbed_repo_generator = TestBedSshSessionRepoGenerator( + testbeds, repo_generator, ssh_info_solver + ) + testbed_repo_generator.generate() + + if len(device_inventories) == 0: + print("No device inventories loaded. Skipped.") + else: + device_repo_generator = DeviceSessionRepoGenerator( + device_inventories, repo_generator, ssh_info_solver + ) + device_repo_generator.generate() if __name__ == "__main__": @@ -162,7 +351,17 @@ def main(args): parser = argparse.ArgumentParser( description="Generate SSH session files for console access to devices.", epilog=example_text, - formatter_class=argparse.RawDescriptionHelpFormatter) + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + + parser.add_argument( + "-d", + "--device-file", + type=str, + dest="device_file_pattern", + default="files/sonic_*_devices.csv", + help="Device file path.", + ) parser.add_argument( "-t", @@ -223,7 +422,91 @@ def main(args): type=str, dest="template_file_path", help="Session file template path. Used for clone your current session settings. " - "Only used when --format=securecrt.", + "Only used when --format=securecrt.", + ) + + parser.add_argument( + "--dut-user", + type=str, + dest="dut_user", + help="SSH user name of DUTs. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--dut-pass", + type=str, + dest="dut_pass", + help="SSH password of DUTs. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--ptf-user", + type=str, + dest="ptf_user", + help="SSH user name of PTF containers. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--ptf-pass", + type=str, + dest="ptf_pass", + help="SSH password of PTF containers. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--server-user", + type=str, + dest="server_user", + help="SSH user name of servers. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--server-pass", + type=str, + dest="server_pass", + help="SSH password of servers. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--leaf-fanout-user", + type=str, + dest="leaf_fanout_user", + help="SSH user name of leaf fanouts. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--leaf-fanout-pass", + type=str, + dest="leaf_fanout_pass", + help="SSH password of leaf fanouts. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--root-fanout-user", + type=str, + dest="root_fanout_user", + help="SSH user name of root fanouts. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--root-fanout-pass", + type=str, + dest="root_fanout_pass", + help="SSH password of root fanouts. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--console-server-user", + type=str, + dest="console_server_user", + help="SSH user name of console server. If not specified, we will use ansible to get the SSH configuration.", + ) + + parser.add_argument( + "--console-server-pass", + type=str, + dest="console_server_pass", + help="SSH password of console server. If not specified, we will use ansible to get the SSH configuration.", ) args = parser.parse_args()