Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion sros2/package.xml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
<test_depend>ament_pep257</test_depend>
<test_depend>python3-pytest</test_depend>
<test_depend>std_msgs</test_depend>

<test_depend>std_srvs</test_depend>

<export>
<build_type>ament_python</build_type>
</export>
Expand Down
4 changes: 4 additions & 0 deletions sros2/sros2/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def get_service_info(node, node_name):
return get_topics(node_name, node.get_service_names_and_types_by_node)


def get_client_info(node, node_name):
return get_topics(node_name, node.get_client_names_and_types_by_node)


def _write_key(
key,
key_path,
Expand Down
5 changes: 5 additions & 0 deletions sros2/sros2/verb/generate_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def FilesCompleter(*, allowednames, directories):
from ros2cli.node.strategy import NodeStrategy

from sros2.api import (
get_client_info,
get_node_names,
get_publisher_info,
get_service_info,
Expand Down Expand Up @@ -138,6 +139,10 @@ def main(self, *, args):
if reply_services:
self.add_permission(
profile, 'service', 'reply', 'ALLOW', reply_services, node_name)
request_services = get_client_info(node=node, node_name=node_name)
if request_services:
self.add_permission(
profile, 'service', 'request', 'ALLOW', request_services, node_name)

with open(args.POLICY_FILE_PATH, 'w') as stream:
dump_policy(policy, stream)
Expand Down
43 changes: 42 additions & 1 deletion sros2/test/sros2/commands/security/verbs/test_generate_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@
from ros2cli import cli
from sros2.policy import load_policy
from std_msgs.msg import String
from std_srvs.srv import Trigger


def test_generate_policy():
def test_generate_policy_topics():
with tempfile.TemporaryDirectory() as tmpdir:
# Create a test-specific context so that generate_policy can still init
context = rclpy.Context()
Expand Down Expand Up @@ -62,6 +63,46 @@ def test_generate_policy():
assert len([t for t in topics if t.text == 'topic_pub']) == 0


def test_generate_policy_services():
with tempfile.TemporaryDirectory() as tmpdir:
# Create a test-specific context so that generate_policy can still init
context = rclpy.Context()
rclpy.init(context=context)
node = rclpy.create_node('test_node', context=context)

try:
# Create a server and client
node.create_client(Trigger, 'service_client')
node.create_service(Trigger, 'service_server', lambda request,
response: response)

# Generate the policy for the running node
assert cli.main(
argv=['security', 'generate_policy', os.path.join(tmpdir, 'test-policy.xml')]) == 0
finally:
node.destroy_node()
rclpy.shutdown(context=context)

# Load the policy and pull out allowed replies and requests
policy = load_policy(os.path.join(tmpdir, 'test-policy.xml'))
profile = policy.find(path='profiles/profile[@ns="/"][@node="test_node"]')
assert profile is not None
service_reply_allowed = profile.find(path='services[@reply="ALLOW"]')
assert service_reply_allowed is not None
service_request_allowed = profile.find(path='services[@request="ALLOW"]')
assert service_request_allowed is not None

# Verify that the allowed replies include service_server and not service_client
services = service_reply_allowed.findall('service')
assert len([s for s in services if s.text == 'service_server']) == 1
assert len([s for s in services if s.text == 'service_client']) == 0

# Verify that the allowed requests include service_client and not service_server
services = service_request_allowed.findall('service')
assert len([s for s in services if s.text == 'service_client']) == 1
assert len([s for s in services if s.text == 'service_server']) == 0


def test_generate_policy_no_nodes(capsys):
with tempfile.TemporaryDirectory() as tmpdir:
assert cli.main(argv=[
Expand Down