Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions ros2bag/ros2bag/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentTypeError
import os
from typing import Any
from typing import Dict
from typing import Optional

from rclpy.duration import Duration
from rclpy.qos import QoSDurabilityPolicy
from rclpy.qos import QoSHistoryPolicy
from rclpy.qos import QoSLivelinessPolicy
from rclpy.qos import QoSProfile
from rclpy.qos import QoSReliabilityPolicy

# This map needs to be updated when new policies are introduced
_QOS_POLICY_FROM_SHORT_NAME = {
'history': QoSHistoryPolicy.get_from_short_key,
'reliability': QoSReliabilityPolicy.get_from_short_key,
'durability': QoSDurabilityPolicy.get_from_short_key,
'liveliness': QoSLivelinessPolicy.get_from_short_key
}
_DURATION_KEYS = ['deadline', 'lifespan', 'liveliness_lease_duration']
_VALUE_KEYS = ['depth', 'avoid_ros_namespace_conventions']


def print_error(string: str) -> str:
return '[ERROR] [ros2bag]: {}'.format(string)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why reimplementing logging.error?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was debating between using logging vs print statements and it appears that the ros2 verb extension needs to return when something goes wrong/finishes.
If we use logging, we'd need to log followed by a return with the same statement which is duplicate work.
If there's a proper approach for returning from a ros2 verb lmk (I'm following what other packages did).



def dict_to_duration(time_dict: Optional[Dict[str, int]]) -> Duration:
"""Convert a QoS duration profile from YAML into an rclpy Duration."""
if time_dict:
try:
return Duration(seconds=time_dict['sec'], nanoseconds=time_dict['nsec'])
except KeyError:
raise ValueError(
'Time overrides must include both seconds (sec) and nanoseconds (nsec).')
else:
return Duration()


def interpret_dict_as_qos_profile(qos_profile_dict: Dict) -> QoSProfile:
"""Sanitize a user provided dict of a QoS profile and verify all keys are valid."""
new_profile_dict = {}
for policy_key, policy_value in qos_profile_dict.items():
if policy_key in _DURATION_KEYS:
new_profile_dict[policy_key] = dict_to_duration(policy_value)
elif policy_key in _QOS_POLICY_FROM_SHORT_NAME:
new_profile_dict[policy_key] = _QOS_POLICY_FROM_SHORT_NAME[policy_key](policy_value)
elif policy_key in _VALUE_KEYS:
new_profile_dict[policy_key] = policy_value
else:
raise ValueError('Unexpected key `{}` for QoS profile.'.format(policy_key))
return QoSProfile(**new_profile_dict)


def convert_yaml_to_qos_profile(qos_profile_dict: Dict) -> Dict[str, QoSProfile]:
"""Convert a YAML file to use rclpy's QoSProfile."""
topic_profile_dict = {}
for topic, profile in qos_profile_dict.items():
topic_profile_dict[topic] = interpret_dict_as_qos_profile(profile)
return topic_profile_dict


def create_bag_directory(uri: str) -> Optional[str]:
"""Create a directory."""
try:
os.makedirs(uri)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see the exception-case return value used anywhere.

In the successful case this function hits end of scope without returning a value.

Is the intended return type None (and supposed to log) or is it Optional[str]?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I moved this from the play verb and I believe it's intention is to simply log the result.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM'd the PR, but the returned value here doesn't actually ever get printed. You can move forward though, if this is maintaining previous behavior

except OSError:
return print_error("Could not create bag folder '{}'.".format(uri))


def check_positive_float(value: Any) -> float:
"""Argparse validator to verify that a value is a float and positive."""
try:
fvalue = float(value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems to contradict your typing? (value is a floating-point value)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wasn't sure what to put the type as since this is supposed to be used as an argparse validator. (Is Any valid?)

if fvalue <= 0.0:
raise ArgumentTypeError('{} is not in the valid range (> 0.0)'.format(value))
return fvalue
except ValueError:
raise ArgumentTypeError('{} is not the valid type (float)'.format(value))


def check_path_exists(value: Any) -> str:
"""Argparse validator to verify a path exists."""
try:
if os.path.exists(value):
return value
raise ArgumentTypeError("Bag file '{}' does not exist!".format(value))
except ValueError:
raise ArgumentTypeError('{} is not the valid type (string)'.format(value))
23 changes: 5 additions & 18 deletions ros2bag/ros2bag/verb/play.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from argparse import ArgumentTypeError
import os

from ros2bag.api import check_path_exists
from ros2bag.api import check_positive_float
from ros2bag.verb import VerbExtension
from ros2cli.node import NODE_NAME_PREFIX

Expand All @@ -24,7 +23,7 @@ class PlayVerb(VerbExtension):

def add_arguments(self, parser, cli_name): # noqa: D102
parser.add_argument(
'bag_file', help='bag file to replay')
'bag_file', type=check_path_exists, help='bag file to replay')
parser.add_argument(
'-s', '--storage', default='sqlite3',
help='storage identifier to be used, defaults to "sqlite3"')
Expand All @@ -34,30 +33,18 @@ def add_arguments(self, parser, cli_name): # noqa: D102
'playback. Larger size will result in larger memory needs but might prevent '
'delay of message playback.')
parser.add_argument(
'-r', '--rate', type=self.check_positive_float, default=1.0,
'-r', '--rate', type=check_positive_float, default=1.0,
help='rate at which to play back messages. Valid range > 0.0.')

def check_positive_float(self, value):
try:
fvalue = float(value)
if fvalue <= 0.0:
raise ArgumentTypeError('%s is not in the valid range (> 0.0)' % value)
return fvalue
except ValueError:
raise ArgumentTypeError('%s is not of the valid type (float)' % value)

def main(self, *, args): # noqa: D102
bag_file = args.bag_file
if not os.path.exists(bag_file):
return "[ERROR] [ros2bag] bag file '{}' does not exist!".format(bag_file)
# NOTE(hidmic): in merged install workspaces on Windows, Python entrypoint lookups
# combined with constrained environments (as imposed by colcon test)
# may result in DLL loading failures when attempting to import a C
# extension. Therefore, do not import rosbag2_transport at the module
# level but on demand, right before first use.
from rosbag2_transport import rosbag2_transport_py
rosbag2_transport_py.play(
uri=bag_file,
uri=args.bag_file,
storage_id=args.storage,
node_prefix=NODE_NAME_PREFIX,
read_ahead_queue_size=args.read_ahead_queue_size,
Expand Down
74 changes: 13 additions & 61 deletions ros2bag/ros2bag/verb/record.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,61 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
from argparse import FileType
import datetime
import logging
import os
from typing import Dict
from typing import Optional

import rclpy
from rclpy.duration import Duration
from rclpy.qos import InvalidQoSProfileException
from rclpy.qos import QoSProfile
from ros2bag.api import convert_yaml_to_qos_profile
from ros2bag.api import create_bag_directory
from ros2bag.api import print_error
from ros2bag.verb import VerbExtension
from ros2cli.node import NODE_NAME_PREFIX
import yaml

# This map needs to be updated when new policies are introduced
POLICY_MAP = {
'history': rclpy.qos.QoSHistoryPolicy.get_from_short_key,
'reliability': rclpy.qos.QoSReliabilityPolicy.get_from_short_key,
'durability': rclpy.qos.QoSDurabilityPolicy.get_from_short_key,
'liveliness': rclpy.qos.QoSLivelinessPolicy.get_from_short_key
}

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger('ros2bag')


def is_dict_valid_duration(duration_dict: Dict[str, int]) -> bool:
return all(key in duration_dict for key in ['sec', 'nsec'])


def dict_to_duration(time_dict: Optional[Dict[str, int]]) -> Duration:
if time_dict:
if is_dict_valid_duration(time_dict):
return Duration(seconds=time_dict.get('sec'), nanoseconds=time_dict.get('nsec'))
else:
raise ValueError(
'Time overrides must include both seconds (sec) and nanoseconds (nsec).')
else:
return Duration()


def validate_qos_profile_overrides(qos_profile_dict: Dict) -> Dict[str, Dict]:
"""Validate the QoS profile yaml file path and its structure."""
for name in qos_profile_dict.keys():
profile = qos_profile_dict[name]
# Convert dict to Duration. Required for construction
conversion_keys = ['deadline', 'lifespan', 'liveliness_lease_duration']
for k in conversion_keys:
profile[k] = dict_to_duration(profile.get(k))
for policy in POLICY_MAP.keys():
profile[policy] = POLICY_MAP[policy](profile.get(policy, 'system_default'))
qos_profile_dict[name] = QoSProfile(**profile)
return qos_profile_dict


class RecordVerb(VerbExtension):
"""ros2 bag record."""
Expand Down Expand Up @@ -122,41 +79,36 @@ def add_arguments(self, parser, cli_name): # noqa: D102
help='record also hidden topics.'
)
parser.add_argument(
'--qos-profile-overrides-path', type=argparse.FileType('r'),
'--qos-profile-overrides-path', type=FileType('r'),
help='Path to a yaml file defining overrides of the QoS profile for specific topics.'
)
self._subparser = parser

def create_bag_directory(self, uri):
try:
os.makedirs(uri)
except OSError:
return "[ERROR] [ros2bag]: Could not create bag folder '{}'.".format(uri)

def main(self, *, args): # noqa: D102
if args.all and args.topics:
return 'Invalid choice: Can not specify topics and -a at the same time.'
return print_error('Invalid choice: Can not specify topics and -a at the same time.')

uri = args.output or datetime.datetime.now().strftime('rosbag2_%Y_%m_%d-%H_%M_%S')

if os.path.isdir(uri):
return "[ERROR] [ros2bag]: Output folder '{}' already exists.".format(uri)
return print_error("Output folder '{}' already exists.".format(uri))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider moving the argument validation in the arpgarse logic if you keep modifying this code. This will take care for you of reporting errors to the users the "proper way". You could just create a function checking if the directory exists, and call strftime to pass a default value as it's built here.


if args.compression_format and args.compression_mode == 'none':
return 'Invalid choice: Cannot specify compression format without a compression mode.'
return print_error('Invalid choice: Cannot specify compression format '
'without a compression mode.')

args.compression_mode = args.compression_mode.upper()

qos_profile_overrides = {} # Specify a valid default
if args.qos_profile_overrides_path:
qos_profile_dict = yaml.safe_load(args.qos_profile_overrides_path)
try:
qos_profile_overrides = validate_qos_profile_overrides(
qos_profile_overrides = convert_yaml_to_qos_profile(
qos_profile_dict)
except (InvalidQoSProfileException, ValueError) as e:
logger.error(str(e))
return str(e)
return print_error(str(e))

self.create_bag_directory(uri)
create_bag_directory(uri)

if args.all:
# NOTE(hidmic): in merged install workspaces on Windows, Python entrypoint lookups
Expand Down
7 changes: 7 additions & 0 deletions ros2bag/test/resources/incomplete_qos_duration.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
/test_topic:
history: keep_all
depth: 0
reliability: reliable
durability: volatile
deadline:
sec: 2147483647 # LONG_MAX
76 changes: 76 additions & 0 deletions ros2bag/test/test_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# Copyright 2020 Amazon.com, Inc. or its affiliates. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from rclpy.qos import QoSDurabilityPolicy
from rclpy.qos import QoSHistoryPolicy
from rclpy.qos import QoSReliabilityPolicy
from ros2bag.api import convert_yaml_to_qos_profile
from ros2bag.api import dict_to_duration
from ros2bag.api import interpret_dict_as_qos_profile


class TestRos2BagRecord(unittest.TestCase):

def test_dict_to_duration_valid(self):
expected_nanoseconds = 1000000002
duration_dict = {'sec': 1, 'nsec': 2}
duration = dict_to_duration(duration_dict)
assert duration.nanoseconds == expected_nanoseconds

def test_dict_to_duration_invalid(self):
duration_dict = {'sec': 1}
with self.assertRaises(ValueError):
dict_to_duration(duration_dict)

def test_interpret_dict_as_qos_profile_valid(self):
qos_dict = {'history': 'keep_last', 'depth': 10}
qos_profile = interpret_dict_as_qos_profile(qos_dict)
assert qos_profile.history == QoSHistoryPolicy.RMW_QOS_POLICY_HISTORY_KEEP_LAST
expected_seconds = 1
expected_nanoseconds = int((expected_seconds * 1e9))
qos_dict = {'history': 'keep_all', 'deadline': {'sec': expected_seconds, 'nsec': 0}}
qos_profile = interpret_dict_as_qos_profile(qos_dict)
assert qos_profile.deadline.nanoseconds == expected_nanoseconds
expected_convention = False
qos_dict = {'history': 'keep_all', 'avoid_ros_namespace_conventions': expected_convention}
qos_profile = interpret_dict_as_qos_profile(qos_dict)
assert qos_profile.avoid_ros_namespace_conventions == expected_convention

def test_interpret_dict_as_qos_profile_invalid(self):
qos_dict = {'foo': 'bar'}
with self.assertRaises(ValueError):
interpret_dict_as_qos_profile(qos_dict)

def test_convert_yaml_to_qos_profile(self):
topic_name_1 = '/topic1'
topic_name_2 = '/topic2'
expected_convention = False
qos_dict = {
topic_name_1: {
'history': 'keep_all', 'durability': 'volatile', 'reliability': 'reliable'},
topic_name_2: {
'history': 'keep_all', 'avoid_ros_namespace_conventions': expected_convention}
}
qos_profiles = convert_yaml_to_qos_profile(qos_dict)
assert qos_profiles[topic_name_1].durability == \
QoSDurabilityPolicy.RMW_QOS_POLICY_DURABILITY_VOLATILE
assert qos_profiles[topic_name_1].reliability == \
QoSReliabilityPolicy.RMW_QOS_POLICY_RELIABILITY_RELIABLE
assert qos_profiles[topic_name_1].history == \
QoSHistoryPolicy.RMW_QOS_POLICY_HISTORY_KEEP_ALL
assert qos_profiles[topic_name_2].avoid_ros_namespace_conventions == expected_convention
assert qos_profiles[topic_name_2].history == \
QoSHistoryPolicy.RMW_QOS_POLICY_HISTORY_KEEP_ALL
Loading