-
Notifications
You must be signed in to change notification settings - Fork 293
Refactor utility functions in ros2bag #358
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8637f5b
2261b7a
ae1034c
2da101a
ba0f954
8fb0671
307320b
01ab7e8
e5c8b8d
8430ddc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| # Copyright 2020 Amazon.com, Inc. or its affiliates. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from argparse import ArgumentTypeError | ||
| import os | ||
| from typing import Any | ||
| from typing import Dict | ||
| from typing import Optional | ||
|
|
||
| from rclpy.duration import Duration | ||
| from rclpy.qos import QoSDurabilityPolicy | ||
| from rclpy.qos import QoSHistoryPolicy | ||
| from rclpy.qos import QoSLivelinessPolicy | ||
| from rclpy.qos import QoSProfile | ||
| from rclpy.qos import QoSReliabilityPolicy | ||
|
|
||
| # This map needs to be updated when new policies are introduced | ||
| _QOS_POLICY_FROM_SHORT_NAME = { | ||
| 'history': QoSHistoryPolicy.get_from_short_key, | ||
| 'reliability': QoSReliabilityPolicy.get_from_short_key, | ||
| 'durability': QoSDurabilityPolicy.get_from_short_key, | ||
| 'liveliness': QoSLivelinessPolicy.get_from_short_key | ||
| } | ||
| _DURATION_KEYS = ['deadline', 'lifespan', 'liveliness_lease_duration'] | ||
| _VALUE_KEYS = ['depth', 'avoid_ros_namespace_conventions'] | ||
|
|
||
|
|
||
| def print_error(string: str) -> str: | ||
| return '[ERROR] [ros2bag]: {}'.format(string) | ||
|
|
||
|
|
||
| def dict_to_duration(time_dict: Optional[Dict[str, int]]) -> Duration: | ||
| """Convert a QoS duration profile from YAML into an rclpy Duration.""" | ||
| if time_dict: | ||
| try: | ||
| return Duration(seconds=time_dict['sec'], nanoseconds=time_dict['nsec']) | ||
| except KeyError: | ||
| raise ValueError( | ||
| 'Time overrides must include both seconds (sec) and nanoseconds (nsec).') | ||
| else: | ||
| return Duration() | ||
|
|
||
|
|
||
| def interpret_dict_as_qos_profile(qos_profile_dict: Dict) -> QoSProfile: | ||
| """Sanitize a user provided dict of a QoS profile and verify all keys are valid.""" | ||
| new_profile_dict = {} | ||
| for policy_key, policy_value in qos_profile_dict.items(): | ||
| if policy_key in _DURATION_KEYS: | ||
| new_profile_dict[policy_key] = dict_to_duration(policy_value) | ||
| elif policy_key in _QOS_POLICY_FROM_SHORT_NAME: | ||
| new_profile_dict[policy_key] = _QOS_POLICY_FROM_SHORT_NAME[policy_key](policy_value) | ||
| elif policy_key in _VALUE_KEYS: | ||
| new_profile_dict[policy_key] = policy_value | ||
| else: | ||
| raise ValueError('Unexpected key `{}` for QoS profile.'.format(policy_key)) | ||
| return QoSProfile(**new_profile_dict) | ||
|
|
||
|
|
||
| def convert_yaml_to_qos_profile(qos_profile_dict: Dict) -> Dict[str, QoSProfile]: | ||
| """Convert a YAML file to use rclpy's QoSProfile.""" | ||
| topic_profile_dict = {} | ||
| for topic, profile in qos_profile_dict.items(): | ||
| topic_profile_dict[topic] = interpret_dict_as_qos_profile(profile) | ||
| return topic_profile_dict | ||
|
|
||
|
|
||
| def create_bag_directory(uri: str) -> Optional[str]: | ||
| """Create a directory.""" | ||
| try: | ||
| os.makedirs(uri) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see the exception-case return value used anywhere. In the successful case this function hits end of scope without returning a value. Is the intended return type
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I moved this from the
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. LGTM'd the PR, but the returned value here doesn't actually ever get printed. You can move forward though, if this is maintaining previous behavior |
||
| except OSError: | ||
| return print_error("Could not create bag folder '{}'.".format(uri)) | ||
|
|
||
|
|
||
| def check_positive_float(value: Any) -> float: | ||
| """Argparse validator to verify that a value is a float and positive.""" | ||
| try: | ||
| fvalue = float(value) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. seems to contradict your typing? (value is a floating-point value)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I wasn't sure what to put the type as since this is supposed to be used as an argparse validator. (Is |
||
| if fvalue <= 0.0: | ||
| raise ArgumentTypeError('{} is not in the valid range (> 0.0)'.format(value)) | ||
| return fvalue | ||
| except ValueError: | ||
| raise ArgumentTypeError('{} is not the valid type (float)'.format(value)) | ||
|
|
||
|
|
||
| def check_path_exists(value: Any) -> str: | ||
| """Argparse validator to verify a path exists.""" | ||
| try: | ||
| if os.path.exists(value): | ||
| return value | ||
| raise ArgumentTypeError("Bag file '{}' does not exist!".format(value)) | ||
| except ValueError: | ||
| raise ArgumentTypeError('{} is not the valid type (string)'.format(value)) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -12,61 +12,18 @@ | |
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import argparse | ||
| from argparse import FileType | ||
| import datetime | ||
| import logging | ||
| import os | ||
| from typing import Dict | ||
| from typing import Optional | ||
|
|
||
| import rclpy | ||
| from rclpy.duration import Duration | ||
| from rclpy.qos import InvalidQoSProfileException | ||
| from rclpy.qos import QoSProfile | ||
| from ros2bag.api import convert_yaml_to_qos_profile | ||
| from ros2bag.api import create_bag_directory | ||
| from ros2bag.api import print_error | ||
| from ros2bag.verb import VerbExtension | ||
| from ros2cli.node import NODE_NAME_PREFIX | ||
| import yaml | ||
|
|
||
| # This map needs to be updated when new policies are introduced | ||
| POLICY_MAP = { | ||
| 'history': rclpy.qos.QoSHistoryPolicy.get_from_short_key, | ||
| 'reliability': rclpy.qos.QoSReliabilityPolicy.get_from_short_key, | ||
| 'durability': rclpy.qos.QoSDurabilityPolicy.get_from_short_key, | ||
| 'liveliness': rclpy.qos.QoSLivelinessPolicy.get_from_short_key | ||
| } | ||
|
|
||
| logging.basicConfig(level=logging.INFO) | ||
| logger = logging.getLogger('ros2bag') | ||
|
|
||
|
|
||
| def is_dict_valid_duration(duration_dict: Dict[str, int]) -> bool: | ||
| return all(key in duration_dict for key in ['sec', 'nsec']) | ||
|
|
||
|
|
||
| def dict_to_duration(time_dict: Optional[Dict[str, int]]) -> Duration: | ||
| if time_dict: | ||
| if is_dict_valid_duration(time_dict): | ||
| return Duration(seconds=time_dict.get('sec'), nanoseconds=time_dict.get('nsec')) | ||
| else: | ||
| raise ValueError( | ||
| 'Time overrides must include both seconds (sec) and nanoseconds (nsec).') | ||
| else: | ||
| return Duration() | ||
|
|
||
|
|
||
| def validate_qos_profile_overrides(qos_profile_dict: Dict) -> Dict[str, Dict]: | ||
| """Validate the QoS profile yaml file path and its structure.""" | ||
| for name in qos_profile_dict.keys(): | ||
| profile = qos_profile_dict[name] | ||
| # Convert dict to Duration. Required for construction | ||
| conversion_keys = ['deadline', 'lifespan', 'liveliness_lease_duration'] | ||
| for k in conversion_keys: | ||
| profile[k] = dict_to_duration(profile.get(k)) | ||
| for policy in POLICY_MAP.keys(): | ||
| profile[policy] = POLICY_MAP[policy](profile.get(policy, 'system_default')) | ||
| qos_profile_dict[name] = QoSProfile(**profile) | ||
| return qos_profile_dict | ||
|
|
||
|
|
||
| class RecordVerb(VerbExtension): | ||
| """ros2 bag record.""" | ||
|
|
@@ -122,41 +79,36 @@ def add_arguments(self, parser, cli_name): # noqa: D102 | |
| help='record also hidden topics.' | ||
| ) | ||
| parser.add_argument( | ||
| '--qos-profile-overrides-path', type=argparse.FileType('r'), | ||
| '--qos-profile-overrides-path', type=FileType('r'), | ||
| help='Path to a yaml file defining overrides of the QoS profile for specific topics.' | ||
| ) | ||
| self._subparser = parser | ||
|
|
||
| def create_bag_directory(self, uri): | ||
| try: | ||
| os.makedirs(uri) | ||
| except OSError: | ||
| return "[ERROR] [ros2bag]: Could not create bag folder '{}'.".format(uri) | ||
|
|
||
| def main(self, *, args): # noqa: D102 | ||
| if args.all and args.topics: | ||
| return 'Invalid choice: Can not specify topics and -a at the same time.' | ||
| return print_error('Invalid choice: Can not specify topics and -a at the same time.') | ||
|
|
||
| uri = args.output or datetime.datetime.now().strftime('rosbag2_%Y_%m_%d-%H_%M_%S') | ||
|
|
||
| if os.path.isdir(uri): | ||
| return "[ERROR] [ros2bag]: Output folder '{}' already exists.".format(uri) | ||
| return print_error("Output folder '{}' already exists.".format(uri)) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Consider moving the argument validation in the |
||
|
|
||
| if args.compression_format and args.compression_mode == 'none': | ||
| return 'Invalid choice: Cannot specify compression format without a compression mode.' | ||
| return print_error('Invalid choice: Cannot specify compression format ' | ||
| 'without a compression mode.') | ||
|
|
||
| args.compression_mode = args.compression_mode.upper() | ||
|
|
||
| qos_profile_overrides = {} # Specify a valid default | ||
| if args.qos_profile_overrides_path: | ||
| qos_profile_dict = yaml.safe_load(args.qos_profile_overrides_path) | ||
| try: | ||
| qos_profile_overrides = validate_qos_profile_overrides( | ||
| qos_profile_overrides = convert_yaml_to_qos_profile( | ||
| qos_profile_dict) | ||
| except (InvalidQoSProfileException, ValueError) as e: | ||
| logger.error(str(e)) | ||
| return str(e) | ||
| return print_error(str(e)) | ||
|
|
||
| self.create_bag_directory(uri) | ||
| create_bag_directory(uri) | ||
|
|
||
| if args.all: | ||
| # NOTE(hidmic): in merged install workspaces on Windows, Python entrypoint lookups | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| /test_topic: | ||
| history: keep_all | ||
| depth: 0 | ||
| reliability: reliable | ||
| durability: volatile | ||
| deadline: | ||
| sec: 2147483647 # LONG_MAX |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| # Copyright 2020 Amazon.com, Inc. or its affiliates. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import unittest | ||
|
|
||
| from rclpy.qos import QoSDurabilityPolicy | ||
| from rclpy.qos import QoSHistoryPolicy | ||
| from rclpy.qos import QoSReliabilityPolicy | ||
| from ros2bag.api import convert_yaml_to_qos_profile | ||
| from ros2bag.api import dict_to_duration | ||
| from ros2bag.api import interpret_dict_as_qos_profile | ||
|
|
||
|
|
||
| class TestRos2BagRecord(unittest.TestCase): | ||
|
|
||
| def test_dict_to_duration_valid(self): | ||
| expected_nanoseconds = 1000000002 | ||
| duration_dict = {'sec': 1, 'nsec': 2} | ||
| duration = dict_to_duration(duration_dict) | ||
| assert duration.nanoseconds == expected_nanoseconds | ||
|
|
||
| def test_dict_to_duration_invalid(self): | ||
| duration_dict = {'sec': 1} | ||
| with self.assertRaises(ValueError): | ||
| dict_to_duration(duration_dict) | ||
|
|
||
| def test_interpret_dict_as_qos_profile_valid(self): | ||
| qos_dict = {'history': 'keep_last', 'depth': 10} | ||
| qos_profile = interpret_dict_as_qos_profile(qos_dict) | ||
| assert qos_profile.history == QoSHistoryPolicy.RMW_QOS_POLICY_HISTORY_KEEP_LAST | ||
| expected_seconds = 1 | ||
| expected_nanoseconds = int((expected_seconds * 1e9)) | ||
| qos_dict = {'history': 'keep_all', 'deadline': {'sec': expected_seconds, 'nsec': 0}} | ||
| qos_profile = interpret_dict_as_qos_profile(qos_dict) | ||
| assert qos_profile.deadline.nanoseconds == expected_nanoseconds | ||
| expected_convention = False | ||
| qos_dict = {'history': 'keep_all', 'avoid_ros_namespace_conventions': expected_convention} | ||
| qos_profile = interpret_dict_as_qos_profile(qos_dict) | ||
| assert qos_profile.avoid_ros_namespace_conventions == expected_convention | ||
|
|
||
| def test_interpret_dict_as_qos_profile_invalid(self): | ||
| qos_dict = {'foo': 'bar'} | ||
| with self.assertRaises(ValueError): | ||
| interpret_dict_as_qos_profile(qos_dict) | ||
|
|
||
| def test_convert_yaml_to_qos_profile(self): | ||
| topic_name_1 = '/topic1' | ||
| topic_name_2 = '/topic2' | ||
| expected_convention = False | ||
| qos_dict = { | ||
| topic_name_1: { | ||
| 'history': 'keep_all', 'durability': 'volatile', 'reliability': 'reliable'}, | ||
| topic_name_2: { | ||
| 'history': 'keep_all', 'avoid_ros_namespace_conventions': expected_convention} | ||
| } | ||
| qos_profiles = convert_yaml_to_qos_profile(qos_dict) | ||
| assert qos_profiles[topic_name_1].durability == \ | ||
| QoSDurabilityPolicy.RMW_QOS_POLICY_DURABILITY_VOLATILE | ||
| assert qos_profiles[topic_name_1].reliability == \ | ||
| QoSReliabilityPolicy.RMW_QOS_POLICY_RELIABILITY_RELIABLE | ||
| assert qos_profiles[topic_name_1].history == \ | ||
| QoSHistoryPolicy.RMW_QOS_POLICY_HISTORY_KEEP_ALL | ||
| assert qos_profiles[topic_name_2].avoid_ros_namespace_conventions == expected_convention | ||
| assert qos_profiles[topic_name_2].history == \ | ||
| QoSHistoryPolicy.RMW_QOS_POLICY_HISTORY_KEEP_ALL |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why reimplementing
logging.error?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was debating between using
loggingvs print statements and it appears that the ros2 verb extension needs toreturnwhen something goes wrong/finishes.If we use logging, we'd need to log followed by a
returnwith the same statement which is duplicate work.If there's a proper approach for
returning from a ros2 verb lmk (I'm following what other packages did).