Skip to content

Commit bc08806

Browse files
authored
Implemented ssh configurations (sonic-net#32)
HLD in sonic-net/SONiC#1075 - Why I did it Implemented ssh configurations - How I did it Added ssh config table in configDB, once changed - hostcfgd will change the relevant OS files (sshd_config) - How to verify it Tests in added in this PR. User can change relevant configs in configDB such as ports, and see sshd port was modified - Link to config_db schema for YANG module changes https://github.com/ycoheNvidia/SONiC/blob/ssh_config/doc/ssh_config/ssh_config.md
1 parent eab4a9e commit bc08806

10 files changed

Lines changed: 1192 additions & 7 deletions

File tree

scripts/hostcfgd

Lines changed: 123 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import signal
1111
import re
1212
import jinja2
1313
import threading
14+
from shutil import copy2
1415
from sonic_py_common import device_info
1516
from sonic_py_common.general import check_output_pipe
1617
from swsscommon.swsscommon import ConfigDBConnector, DBConnector, Table, SonicDBConfig
@@ -21,6 +22,8 @@ PAM_AUTH_CONF = "/etc/pam.d/common-auth-sonic"
2122
PAM_AUTH_CONF_TEMPLATE = "/usr/share/sonic/templates/common-auth-sonic.j2"
2223
PAM_PASSWORD_CONF = "/etc/pam.d/common-password"
2324
PAM_PASSWORD_CONF_TEMPLATE = "/usr/share/sonic/templates/common-password.j2"
25+
SSH_CONFG = "/etc/ssh/sshd_config"
26+
SSH_CONFG_TMP = SSH_CONFG + ".tmp"
2427
NSS_TACPLUS_CONF = "/etc/tacplus_nss.conf"
2528
NSS_TACPLUS_CONF_TEMPLATE = "/usr/share/sonic/templates/tacplus_nss.conf.j2"
2629
NSS_RADIUS_CONF = "/etc/radius_nss.conf"
@@ -35,6 +38,11 @@ ETC_LOGIN_DEF = "/etc/login.defs"
3538
LINUX_DEFAULT_PASS_MAX_DAYS = 99999
3639
LINUX_DEFAULT_PASS_WARN_AGE = 7
3740

41+
# Ssh min-max values
42+
SSH_MIN_VALUES={"authentication_retries": 3, "login_timeout": 1, "ports": 1}
43+
SSH_MAX_VALUES={"authentication_retries": 100, "login_timeout": 600, "ports": 65535}
44+
SSH_CONFIG_NAMES={"authentication_retries": "MaxAuthTries" , "login_timeout": "LoginGraceTime"}
45+
3846
ACCOUNT_NAME = 0 # index of account name
3947
AGE_DICT = { 'MAX_DAYS': {'REGEX_DAYS': r'^PASS_MAX_DAYS[ \t]*(?P<max_days>\d*)', 'DAYS': 'max_days', 'CHAGE_FLAG': '-M '},
4048
'WARN_DAYS': {'REGEX_DAYS': r'^PASS_WARN_AGE[ \t]*(?P<warn_days>\d*)', 'DAYS': 'warn_days', 'CHAGE_FLAG': '-W '}
@@ -1066,6 +1074,12 @@ class AaaCfg(object):
10661074
"{} - failed: return code - {}, output:\n{}"
10671075
.format(err.cmd, err.returncode, err.output))
10681076

1077+
def modify_single_file_inplace(filename, operations=None):
1078+
if operations:
1079+
cmd = ["sed", '-i'] + operations + [filename]
1080+
syslog.syslog(syslog.LOG_DEBUG, "modify_single_file_inplace: cmd - {}".format(cmd))
1081+
subprocess.run(cmd)
1082+
10691083

10701084
class PasswHardening(object):
10711085
def __init__(self):
@@ -1105,11 +1119,6 @@ class PasswHardening(object):
11051119
if modify_conf:
11061120
self.modify_passw_conf_file()
11071121

1108-
def modify_single_file_inplace(self, filename, operations=None):
1109-
if operations:
1110-
cmd = ["sed", '-i'] + operations + [filename]
1111-
syslog.syslog(syslog.LOG_DEBUG, "modify_single_file_inplace: cmd - {}".format(cmd))
1112-
subprocess.call(cmd)
11131122

11141123
def set_passw_hardening_policies(self, passw_policies):
11151124
# Password Hardening flow
@@ -1154,14 +1163,14 @@ class PasswHardening(object):
11541163
self.passwd_aging_expire_modify(curr_expiration, 'MAX_DAYS')
11551164

11561165
# Aging policy for new users
1157-
self.modify_single_file_inplace(ETC_LOGIN_DEF, ["/^PASS_MAX_DAYS/c\PASS_MAX_DAYS " +str(curr_expiration)])
1166+
modify_single_file_inplace(ETC_LOGIN_DEF, ["/^PASS_MAX_DAYS/c\PASS_MAX_DAYS " +str(curr_expiration)])
11581167

11591168
if self.is_passwd_aging_expire_update(curr_expiration_warning, 'WARN_DAYS'):
11601169
# Aging policy for existing users
11611170
self.passwd_aging_expire_modify(curr_expiration_warning, 'WARN_DAYS')
11621171

11631172
# Aging policy for new users
1164-
self.modify_single_file_inplace(ETC_LOGIN_DEF, ["/^PASS_WARN_AGE/c\PASS_WARN_AGE " +str(curr_expiration_warning)])
1173+
modify_single_file_inplace(ETC_LOGIN_DEF, ["/^PASS_WARN_AGE/c\PASS_WARN_AGE " +str(curr_expiration_warning)])
11651174

11661175
def passwd_aging_expire_modify(self, curr_expiration, age_type):
11671176
normal_accounts = self.get_normal_accounts()
@@ -1249,6 +1258,103 @@ class PasswHardening(object):
12491258
# set new Password Hardening policies.
12501259
self.set_passw_hardening_policies(passw_policies)
12511260

1261+
class SshServer(object):
1262+
def __init__(self):
1263+
self.policies = {}
1264+
1265+
def load(self, policies_conf):
1266+
if 'POLICIES' in policies_conf:
1267+
self.policies_update('POLICIES', policies_conf['POLICIES'], modify_conf=False)
1268+
else:
1269+
self.policies = {}
1270+
1271+
self.modify_conf_file()
1272+
1273+
def modify_conf_file(self):
1274+
ssh_policies = {}
1275+
ssh_policies.update(self.policies)
1276+
1277+
# set new SSH server policies.
1278+
if len(ssh_policies) > 0:
1279+
self.set_policies(ssh_policies)
1280+
1281+
def policies_update(self, key, data, modify_conf=True):
1282+
syslog.syslog(syslog.LOG_DEBUG, "ssh_policies_update - key: {}".format(key))
1283+
syslog.syslog(syslog.LOG_DEBUG, "ssh_policies_update - data: {}".format(data))
1284+
if data:
1285+
if 'ports' in data:
1286+
data['ports'] = data['ports'].split(',')
1287+
self.policies = data
1288+
1289+
if modify_conf:
1290+
self.modify_conf_file()
1291+
1292+
# return first line apperience of pattern - else return number of lines in the file
1293+
def get_line_num_of_pattern(self, pattern, file_path, find_commented=False):
1294+
syslog.syslog(syslog.LOG_DEBUG, "looking for pattern {} line in file {}".format(pattern, file_path))
1295+
return_value = 0
1296+
with open(file_path, 'r') as f:
1297+
for (i, line) in enumerate(f):
1298+
if re.match(pattern, line):
1299+
syslog.syslog(syslog.LOG_DEBUG, "found pattern {} in line {}".format(pattern, str(i)))
1300+
return i + 1
1301+
if find_commented and re.match('#' + pattern, line):
1302+
syslog.syslog(syslog.LOG_DEBUG, "found pattern {} in line {}".format('#' + pattern, str(i)))
1303+
return i + 1
1304+
return_value = i
1305+
return return_value
1306+
1307+
def handle_ports_set(self, values_list):
1308+
if len(values_list) == 0:
1309+
return False
1310+
key='ports'
1311+
for port_num in values_list:
1312+
if isinstance(port_num, int):
1313+
syslog.syslog(syslog.LOG_ERR, "port num value {} in wrong format".format(port_num))
1314+
return False
1315+
if int(port_num) < SSH_MIN_VALUES[key] or SSH_MAX_VALUES[key] < int(port_num):
1316+
syslog.syslog(syslog.LOG_ERR, "Ssh {} {} out of range".format('port', port_num))
1317+
return False
1318+
port_line_num = self.get_line_num_of_pattern("Port", SSH_CONFG_TMP, True)
1319+
modify_single_file_inplace(SSH_CONFG_TMP, ['-E', "/^(#)?Port [0-9]+$/d"])
1320+
1321+
for port_num in values_list:
1322+
# add port in original line
1323+
modify_single_file_inplace(SSH_CONFG_TMP, [f'{str(port_line_num)} i Port {str(port_num)}'])
1324+
return True
1325+
1326+
def set_policies(self, ssh_policies):
1327+
# Ssh server flow
1328+
# The ssh_policies from CONFIG_DB will be set in the ssh config files /etc/ssh/sshd_config
1329+
copy2(SSH_CONFG, SSH_CONFG_TMP)
1330+
1331+
for key, value in ssh_policies.items():
1332+
if key == 'ports':
1333+
if not self.handle_ports_set(value):
1334+
syslog.syslog(syslog.LOG_ERR, "Failed to update sshd config files - wrong port configuration")
1335+
return
1336+
elif int(value) < SSH_MIN_VALUES.get(key, 65535) or SSH_MAX_VALUES.get(key, -1) < int(value):
1337+
syslog.syslog(syslog.LOG_ERR, "Ssh {} {} out of range".format(key, value))
1338+
elif key in SSH_CONFIG_NAMES:
1339+
# search replace configuration - if not in config file - append
1340+
kv_str = "{} {}".format(SSH_CONFIG_NAMES[key], str(value)) # name +' '+ value format
1341+
modify_single_file_inplace(SSH_CONFG_TMP,['-E', "/^#?" + SSH_CONFIG_NAMES[key]+"/{h;s/.*/"+
1342+
kv_str + "/};${x;/^$/{s//" + kv_str + "/;H};x}"])
1343+
else:
1344+
syslog.syslog(syslog.LOG_ERR, "Failed to update sshd config file - wrong key {}".format(key))
1345+
1346+
ssh_verify_res = subprocess.run(['sudo', 'sshd', '-T', '-f', SSH_CONFG_TMP], capture_output=True)
1347+
if ssh_verify_res.returncode == 0:
1348+
os.rename(SSH_CONFG_TMP, SSH_CONFG)
1349+
try:
1350+
run_cmd(['systemctl', 'restart', 'ssh'],
1351+
log_err=True, raise_exception=True)
1352+
except Exception:
1353+
syslog.syslog(syslog.LOG_ERR, f'Failed to update sshd config file')
1354+
else:
1355+
syslog.syslog(syslog.LOG_ERR, f'Failed to update sshd config file - sshd -T returned {ssh_verify_res.returncode} with error {ssh_verify_res.stderr.decode()}')
1356+
os.remove(SSH_CONFG_TMP)
1357+
12521358

12531359
class KdumpCfg(object):
12541360
def __init__(self, CfgDb):
@@ -1720,6 +1826,9 @@ class HostConfigDaemon:
17201826
# Initialize MgmtIfaceCfg
17211827
self.mgmtifacecfg = MgmtIfaceCfg()
17221828

1829+
# Initialize SshServer
1830+
self.sshscfg = SshServer()
1831+
17231832
# Initialize RSyslogCfg
17241833
self.rsyslogcfg = RSyslogCfg()
17251834

@@ -1738,6 +1847,7 @@ class HostConfigDaemon:
17381847
ntp_global = init_data['NTP']
17391848
kdump = init_data['KDUMP']
17401849
passwh = init_data['PASSW_HARDENING']
1850+
ssh_server = init_data['SSH_SERVER']
17411851
dev_meta = init_data.get(swsscommon.CFG_DEVICE_METADATA_TABLE_NAME, {})
17421852
mgmt_ifc = init_data.get(swsscommon.CFG_MGMT_INTERFACE_TABLE_NAME, {})
17431853
mgmt_vrf = init_data.get(swsscommon.CFG_MGMT_VRF_CONFIG_TABLE_NAME, {})
@@ -1751,6 +1861,7 @@ class HostConfigDaemon:
17511861
self.ntpcfg.load(ntp_global, ntp_server)
17521862
self.kdumpCfg.load(kdump)
17531863
self.passwcfg.load(passwh)
1864+
self.sshscfg.load(ssh_server)
17541865
self.devmetacfg.load(dev_meta)
17551866
self.mgmtifacecfg.load(mgmt_ifc, mgmt_vrf)
17561867

@@ -1775,6 +1886,10 @@ class HostConfigDaemon:
17751886
self.passwcfg.passw_policies_update(key, data)
17761887
syslog.syslog(syslog.LOG_INFO, 'PASSW_HARDENING Update: key: {}, op: {}, data: {}'.format(key, op, data))
17771888

1889+
def ssh_handler(self, key, op, data):
1890+
self.sshscfg.policies_update(key, data)
1891+
syslog.syslog(syslog.LOG_INFO, 'SSH Update: key: {}, op: {}, data: {}'.format(key, op, data))
1892+
17781893
def tacacs_server_handler(self, key, op, data):
17791894
self.aaacfg.tacacs_server_update(key, data)
17801895
log_data = copy.deepcopy(data)
@@ -1902,6 +2017,7 @@ class HostConfigDaemon:
19022017
self.config_db.subscribe('RADIUS', make_callback(self.radius_global_handler))
19032018
self.config_db.subscribe('RADIUS_SERVER', make_callback(self.radius_server_handler))
19042019
self.config_db.subscribe('PASSW_HARDENING', make_callback(self.passwh_handler))
2020+
self.config_db.subscribe('SSH_SERVER', make_callback(self.ssh_handler))
19052021
# Handle IPTables configuration
19062022
self.config_db.subscribe('LOOPBACK_INTERFACE', make_callback(self.lpbk_handler))
19072023
# Handle NTP & NTP_SERVER updates
Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import importlib.machinery
2+
import importlib.util
3+
import filecmp
4+
import shutil
5+
import os
6+
import sys
7+
import subprocess
8+
import re
9+
10+
from parameterized import parameterized
11+
from unittest import TestCase, mock
12+
from tests.hostcfgd.test_ssh_server_vectors import HOSTCFGD_TEST_SSH_SERVER_VECTOR
13+
from tests.common.mock_configdb import MockConfigDb, MockDBConnector
14+
15+
test_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
16+
modules_path = os.path.dirname(test_path)
17+
scripts_path = os.path.join(modules_path, "scripts")
18+
src_path = os.path.dirname(modules_path)
19+
output_path = os.path.join(test_path, "hostcfgd/output")
20+
sample_output_path = os.path.join(test_path, "hostcfgd/sample_output")
21+
sys.path.insert(0, modules_path)
22+
23+
# Load the file under test
24+
hostcfgd_path = os.path.join(scripts_path, 'hostcfgd')
25+
loader = importlib.machinery.SourceFileLoader('hostcfgd', hostcfgd_path)
26+
spec = importlib.util.spec_from_loader(loader.name, loader)
27+
hostcfgd = importlib.util.module_from_spec(spec)
28+
loader.exec_module(hostcfgd)
29+
sys.modules['hostcfgd'] = hostcfgd
30+
31+
# Mock swsscommon classes
32+
hostcfgd.ConfigDBConnector = MockConfigDb
33+
hostcfgd.DBConnector = MockDBConnector
34+
hostcfgd.Table = mock.Mock()
35+
36+
37+
class TestHostcfgdSSHServer(TestCase):
38+
"""
39+
Test hostcfd daemon - SSHServer
40+
"""
41+
def run_diff(self, file1, file2):
42+
try:
43+
diff_out = subprocess.check_output('diff -ur {} {} || true'.format(file1, file2), shell=True)
44+
return diff_out
45+
except subprocess.CalledProcessError as err:
46+
syslog.syslog(syslog.LOG_ERR, "{} - failed: return code - {}, output:\n{}".format(err.cmd, err.returncode, err.output))
47+
return -1
48+
49+
"""
50+
Check different config
51+
"""
52+
def check_config(self, test_name, test_data, config_name):
53+
op_path = output_path + "/" + test_name + "_" + config_name
54+
sop_path = sample_output_path + "/" + test_name + "_" + config_name
55+
sop_path_common = sample_output_path + "/" + test_name
56+
hostcfgd.SSH_CONFG = op_path + "/sshd_config"
57+
hostcfgd.SSH_CONFG_TMP = hostcfgd.SSH_CONFG + ".tmp"
58+
shutil.rmtree(op_path, ignore_errors=True)
59+
os.mkdir(op_path)
60+
61+
shutil.copyfile(sop_path_common + "/sshd_config.old", op_path + "/sshd_config")
62+
MockConfigDb.set_config_db(test_data[config_name])
63+
host_config_daemon = hostcfgd.HostConfigDaemon()
64+
65+
try:
66+
ssh_table = host_config_daemon.config_db.get_table('SSH_SERVER')
67+
except Exception as e:
68+
syslog.syslog(syslog.LOG_ERR, "failed: get_table 'SSH_SERVER', exception={}".format(e))
69+
ssh_table = []
70+
71+
host_config_daemon.sshscfg.load(ssh_table)
72+
73+
74+
diff_output = ""
75+
files_to_compare = ['sshd_config']
76+
77+
# check output files exists
78+
for name in files_to_compare:
79+
if not os.path.isfile(sop_path + "/" + name):
80+
raise ValueError('filename: %s not exit' % (sop_path + "/" + name))
81+
if not os.path.isfile(op_path + "/" + name):
82+
raise ValueError('filename: %s not exit' % (op_path + "/" + name))
83+
84+
# deep comparison
85+
match, mismatch, errors = filecmp.cmpfiles(sop_path, op_path, files_to_compare, shallow=False)
86+
87+
if not match:
88+
for name in files_to_compare:
89+
diff_output += self.run_diff( sop_path + "/" + name,\
90+
op_path + "/" + name).decode('utf-8')
91+
92+
self.assertTrue(len(diff_output) == 0, diff_output)
93+
94+
@parameterized.expand(HOSTCFGD_TEST_SSH_SERVER_VECTOR)
95+
def test_hostcfgd_sshs_default_values(self, test_name, test_data):
96+
"""
97+
Test SSHS hostcfd daemon initialization
98+
99+
Args:
100+
test_name(str): test name
101+
test_data(dict): test data which contains initial Config Db tables, and expected results
102+
103+
Returns:
104+
None
105+
"""
106+
107+
self.check_config(test_name, test_data, "default_values")
108+
109+
@parameterized.expand(HOSTCFGD_TEST_SSH_SERVER_VECTOR)
110+
def test_hostcfgd_sshs_login_timeout(self, test_name, test_data):
111+
"""
112+
Test SSHS hostcfd daemon initialization
113+
114+
Args:
115+
test_name(str): test name
116+
test_data(dict): test data which contains initial Config Db tables, and expected results
117+
118+
Returns:
119+
None
120+
"""
121+
122+
self.check_config(test_name, test_data, "modify_login_timeout")
123+
124+
125+
@parameterized.expand(HOSTCFGD_TEST_SSH_SERVER_VECTOR)
126+
def test_hostcfgd_sshs_authentication_retries(self, test_name, test_data):
127+
"""
128+
Test SSHS hostcfd daemon initialization
129+
130+
Args:
131+
test_name(str): test name
132+
test_data(dict): test data which contains initial Config Db tables, and expected results
133+
134+
Returns:
135+
None
136+
"""
137+
138+
self.check_config(test_name, test_data, "modify_authentication_retries")
139+
140+
@parameterized.expand(HOSTCFGD_TEST_SSH_SERVER_VECTOR)
141+
def test_hostcfgd_sshs_ports(self, test_name, test_data):
142+
"""
143+
Test SSHS hostcfd daemon initialization
144+
145+
Args:
146+
test_name(str): test name
147+
test_data(dict): test data which contains initial Config Db tables, and expected results
148+
149+
Returns:
150+
None
151+
"""
152+
153+
self.check_config(test_name, test_data, "modify_ports")
154+
155+
@parameterized.expand(HOSTCFGD_TEST_SSH_SERVER_VECTOR)
156+
def test_hostcfgd_sshs_all(self, test_name, test_data):
157+
"""
158+
Test SSHS hostcfd daemon initialization
159+
160+
Args:
161+
test_name(str): test name
162+
test_data(dict): test data which contains initial Config Db tables, and expected results
163+
164+
Returns:
165+
None
166+
"""
167+
168+
self.check_config(test_name, test_data, "modify_all")

0 commit comments

Comments
 (0)