From aa7dcbe4416c4acfd2de1bcd794934012d2f2452 Mon Sep 17 00:00:00 2001 From: Roman Sokolkov Date: Mon, 20 Jun 2022 13:44:33 +0200 Subject: [PATCH] Add enum support to ROSIDL Signed-off-by: Roman Sokolkov --- rosidl_generator_py/CMakeLists.txt | 1 + rosidl_generator_py/idl/Enums.idl | 25 ++++++ rosidl_generator_py/resource/_msg.py.em | 33 ++++++- .../resource/_msg_support.c.em | 88 +++++++++++++------ .../rosidl_generator_py/generate_py_impl.py | 6 ++ rosidl_generator_py/test/test_interfaces.py | 26 ++++++ 6 files changed, 151 insertions(+), 28 deletions(-) create mode 100644 rosidl_generator_py/idl/Enums.idl diff --git a/rosidl_generator_py/CMakeLists.txt b/rosidl_generator_py/CMakeLists.txt index 96854450..b004b8d3 100644 --- a/rosidl_generator_py/CMakeLists.txt +++ b/rosidl_generator_py/CMakeLists.txt @@ -59,6 +59,7 @@ if(BUILD_TESTING) msg/BuiltinTypeSequencesIdl.idl msg/StringArrays.msg msg/Property.msg + idl/Enums.idl ADD_LINTER_TESTS SKIP_INSTALL ) diff --git a/rosidl_generator_py/idl/Enums.idl b/rosidl_generator_py/idl/Enums.idl new file mode 100644 index 00000000..5f98c83d --- /dev/null +++ b/rosidl_generator_py/idl/Enums.idl @@ -0,0 +1,25 @@ +module rosidl_generator_py { + module idl { + typedef SomeEnum SomeEnum__3[3]; + + module Enums_Enums { + enum SomeEnum { + ENUMERATOR1, + ENUMERATOR2 + }; + }; + + struct Enums { + SomeEnum enum_value; + + @default (value="ENUMERATOR2") + SomeEnum enum_default_value; + + SomeEnum__3 static_array_values; + + sequence bounded_array_values; + + sequence dynamic_array_values; + }; + }; +}; diff --git a/rosidl_generator_py/resource/_msg.py.em b/rosidl_generator_py/resource/_msg.py.em index 4b49f1a4..fdf548df 100644 --- a/rosidl_generator_py/resource/_msg.py.em +++ b/rosidl_generator_py/resource/_msg.py.em @@ -18,6 +18,7 @@ from rosidl_parser.definition import BasicType from rosidl_parser.definition import BOOLEAN_TYPE from rosidl_parser.definition import BoundedSequence from rosidl_parser.definition import CHARACTER_TYPES +from rosidl_parser.definition import EnumerationType from rosidl_parser.definition import EMPTY_STRUCTURE_REQUIRED_MEMBER_NAME from rosidl_parser.definition import FLOATING_POINT_TYPES from rosidl_parser.definition import INTEGER_TYPES @@ -35,6 +36,8 @@ imports = OrderedDict() if message.structure.members: imports.setdefault( 'import rosidl_parser.definition', []) # used for SLOT_TYPES +if message.enumerations: + imports.setdefault('from enum import IntEnum', []) for member in message.structure.members: if member.name != EMPTY_STRUCTURE_REQUIRED_MEMBER_NAME: imports.setdefault( @@ -78,6 +81,16 @@ for member in message.structure.members: @[ end for]@ @[end if]@ @#>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>> +@[for enum in message.enumerations]@ + + +class @(enum.enumeration_type.name)(IntEnum): + """Enumeration '@(enum.enumeration_type.name)'.""" + +@[ for i in range(len(enum.enumerators))]@ + @(enum.enumerators[i]) = @(i) +@[ end for]@ +@[end for]@ class Metaclass_@(message.structure.namespaced_type.name)(type): @@ -178,6 +191,7 @@ for member in message.structure.members: class @(message.structure.namespaced_type.name)(metaclass=Metaclass_@(message.structure.namespaced_type.name)): @[if not message.constants]@ """Message class '@(message.structure.namespaced_type.name)'.""" + @[else]@ """ Message class '@(message.structure.namespaced_type.name)'. @@ -187,8 +201,11 @@ class @(message.structure.namespaced_type.name)(metaclass=Metaclass_@(message.st @(constant_name) @[ end for]@ """ -@[end if]@ +@[end if]@ +@[for enum in message.enumerations]@ + @(enum.enumeration_type.name) = @(enum.enumeration_type.name) +@[end for]@ __slots__ = [ @[for member in message.structure.members]@ @[ if len(message.structure.members) == 1 and member.name == EMPTY_STRUCTURE_REQUIRED_MEMBER_NAME]@ @@ -216,6 +233,8 @@ sequence<@ @# the typename of the non-nested type or the nested basetype @[ if isinstance(type_, BasicType)]@ @(type_.typename)@ +@[ elif isinstance(type_, EnumerationType)]@ +@(type_.name)@ @[ elif isinstance(type_, AbstractGenericString)]@ @ @[ if isinstance(type_, AbstractWString)]@ @@ -259,6 +278,8 @@ if isinstance(type_, AbstractNestedType): @(type_.__class__.__module__).@(type_.__class__.__name__)(@ @[ if isinstance(type_, BasicType)]@ '@(type_.typename)'@ +@[ elif isinstance(type_, EnumerationType)]@ +@(type_.namespaces), '@(type_.name)'@ @[ elif isinstance(type_, AbstractGenericString) and type_.has_maximum_size()]@ @(type_.maximum_size)@ @[ elif isinstance(type_, NamespacedType)]@ @@ -323,6 +344,11 @@ if isinstance(type_, AbstractNestedType): else: self.@(member.name) = numpy.array(kwargs.get('@(member.name)'), dtype=@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'])) assert self.@(member.name).shape == (@(member.type.size), ) +@[ elif isinstance(member.type.value_type, EnumerationType)]@ + self.@(member.name) = kwargs.get( + '@(member.name)', + [@(get_python_type(type_))(0) for x in range(@(member.type.size))] + ) @[ else]@ self.@(member.name) = kwargs.get( '@(member.name)', @@ -340,6 +366,8 @@ if isinstance(type_, AbstractNestedType): self.@(member.name) = kwargs.get('@(member.name)', bytes([0])) @[ elif isinstance(type_, BasicType) and type_.typename in CHARACTER_TYPES]@ self.@(member.name) = kwargs.get('@(member.name)', chr(0)) +@[ elif isinstance(type_, EnumerationType)]@ + self.@(member.name) = kwargs.get('@(member.name)', @(get_python_type(type_))(0)) @[ else]@ self.@(member.name) = kwargs.get('@(member.name)', @(get_python_type(type_))()) @[ end if]@ @@ -572,6 +600,9 @@ bound = 1.7976931348623157e+308 assert value >= -@(bound) and value <= @(bound), \ "The '@(member.name)' field must be a @(name) in [@(-bound), @(bound)]" @[ end if]@ +@[ elif isinstance(type_, EnumerationType)]@ + isinstance(value, @(get_python_type(type_))), \ + "The '@(member.name)' field must be of type '@(get_python_type(type_))'" @[ else]@ False @[ end if]@ diff --git a/rosidl_generator_py/resource/_msg_support.c.em b/rosidl_generator_py/resource/_msg_support.c.em index 8701b183..fc35d67e 100644 --- a/rosidl_generator_py/resource/_msg_support.c.em +++ b/rosidl_generator_py/resource/_msg_support.c.em @@ -10,18 +10,22 @@ from rosidl_parser.definition import AbstractWString from rosidl_parser.definition import Array from rosidl_parser.definition import BasicType from rosidl_parser.definition import EMPTY_STRUCTURE_REQUIRED_MEMBER_NAME +from rosidl_parser.definition import EnumerationType from rosidl_parser.definition import NamespacedType -def primitive_msg_type_to_c(type_): +def primitive_msg_type_or_enum_to_c(type_): from rosidl_generator_c import BASIC_IDL_TYPES_TO_C from rosidl_parser.definition import AbstractString from rosidl_parser.definition import AbstractWString from rosidl_parser.definition import BasicType + from rosidl_parser.definition import EnumerationType if isinstance(type_, AbstractString): return 'rosidl_runtime_c__String' if isinstance(type_, AbstractWString): return 'rosidl_runtime_c__U16String' + if isinstance(type_, EnumerationType): + return '__'.join(type_.namespaced_name()) assert isinstance(type_, BasicType) return BASIC_IDL_TYPES_TO_C[type_.typename] @@ -257,14 +261,14 @@ nested_type = '__'.join(type_.namespaced_name()) Py_DECREF(field); return false; } - Py_ssize_t size = view.len / sizeof(@primitive_msg_type_to_c(member.type.value_type)); + Py_ssize_t size = view.len / sizeof(@primitive_msg_type_or_enum_to_c(member.type.value_type)); if (!rosidl_runtime_c__@(member.type.value_type.typename)__Sequence__init(&(ros_message->@(member.name)), size)) { PyErr_SetString(PyExc_RuntimeError, "unable to create @(member.type.value_type.typename)__Sequence ros_message"); PyBuffer_Release(&view); Py_DECREF(field); return false; } - @primitive_msg_type_to_c(member.type.value_type) * dest = ros_message->@(member.name).data; + @primitive_msg_type_or_enum_to_c(member.type.value_type) * dest = ros_message->@(member.name).data; rc = PyBuffer_ToContiguous(dest, &view, view.len, 'C'); if (rc < 0) { PyBuffer_Release(&view); @@ -313,6 +317,13 @@ nested_type = '__'.join(type_.namespaced_name()) Py_DECREF(field); return false; } +@[ elif isinstance(member.type.value_type, EnumerationType)]@ + if (!@(primitive_msg_type_or_enum_to_c(member.type.value_type))__Sequence__init(&(ros_message->@(member.name)), size)) { + PyErr_SetString(PyExc_RuntimeError, "unable to create @(member.type.value_type.name)__Sequence ros_message"); + Py_DECREF(seq_field); + Py_DECREF(field); + return false; + } @[ else]@ if (!rosidl_runtime_c__@(member.type.value_type.typename)__Sequence__init(&(ros_message->@(member.name)), size)) { PyErr_SetString(PyExc_RuntimeError, "unable to create @(member.type.value_type.typename)__Sequence ros_message"); @@ -321,10 +332,10 @@ nested_type = '__'.join(type_.namespaced_name()) return false; } @[ end if]@ - @primitive_msg_type_to_c(member.type.value_type) * dest = ros_message->@(member.name).data; + @primitive_msg_type_or_enum_to_c(member.type.value_type) * dest = ros_message->@(member.name).data; @[ else]@ Py_ssize_t size = @(member.type.size); - @primitive_msg_type_to_c(member.type.value_type) * dest = ros_message->@(member.name); + @primitive_msg_type_or_enum_to_c(member.type.value_type) * dest = ros_message->@(member.name); @[ end if]@ for (Py_ssize_t i = 0; i < size; ++i) { @[ if not isinstance(member.type, Array) or not isinstance(member.type.value_type, BasicType) or member.type.value_type.typename not in SPECIAL_NESTED_BASIC_TYPES]@ @@ -336,7 +347,7 @@ nested_type = '__'.join(type_.namespaced_name()) } @[ end if]@ @[ if isinstance(member.type, Array) and isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in SPECIAL_NESTED_BASIC_TYPES]@ - @primitive_msg_type_to_c(member.type.value_type) tmp = *(@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'npy_')) *)PyArray_GETPTR1(seq_field, i); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = *(@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'npy_')) *)PyArray_GETPTR1(seq_field, i); @[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'char']@ assert(PyUnicode_Check(item)); PyObject * encoded_item = PyUnicode_AsUTF8String(item); @@ -345,11 +356,11 @@ nested_type = '__'.join(type_.namespaced_name()) Py_DECREF(field); return false; } - @primitive_msg_type_to_c(member.type.value_type) tmp = PyBytes_AS_STRING(encoded_item)[0]; + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = PyBytes_AS_STRING(encoded_item)[0]; Py_DECREF(encoded_item); @[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'octet']@ assert(PyBytes_Check(item)); - @primitive_msg_type_to_c(member.type.value_type) tmp = PyBytes_AS_STRING(item)[0]; + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = PyBytes_AS_STRING(item)[0]; @[ elif isinstance(member.type.value_type, AbstractString)]@ assert(PyUnicode_Check(item)); PyObject * encoded_item = PyUnicode_AsUTF8String(item); @@ -388,13 +399,13 @@ nested_type = '__'.join(type_.namespaced_name()) } @[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'boolean']@ assert(PyBool_Check(item)); - @primitive_msg_type_to_c(member.type.value_type) tmp = (item == Py_True); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = (item == Py_True); @[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in ('float', 'double')]@ assert(PyFloat_Check(item)); @[ if member.type.value_type.typename == 'float']@ - @primitive_msg_type_to_c(member.type.value_type) tmp = (float)PyFloat_AS_DOUBLE(item); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = (float)PyFloat_AS_DOUBLE(item); @[ else]@ - @primitive_msg_type_to_c(member.type.value_type) tmp = PyFloat_AS_DOUBLE(item); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = PyFloat_AS_DOUBLE(item); @[ end if]@ @[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in ( 'int8', @@ -402,7 +413,7 @@ nested_type = '__'.join(type_.namespaced_name()) 'int32', )]@ assert(PyLong_Check(item)); - @primitive_msg_type_to_c(member.type.value_type) tmp = (@(primitive_msg_type_to_c(member.type.value_type)))PyLong_AsLong(item); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = (@(primitive_msg_type_or_enum_to_c(member.type.value_type)))PyLong_AsLong(item); @[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename in ( 'uint8', 'uint16', @@ -410,19 +421,24 @@ nested_type = '__'.join(type_.namespaced_name()) )]@ assert(PyLong_Check(item)); @[ if isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'uint32']@ - @primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsUnsignedLong(item); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = PyLong_AsUnsignedLong(item); @[ else]@ - @primitive_msg_type_to_c(member.type.value_type) tmp = (@(primitive_msg_type_to_c(member.type.value_type)))PyLong_AsUnsignedLong(item); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = (@(primitive_msg_type_or_enum_to_c(member.type.value_type)))PyLong_AsUnsignedLong(item); @[ end if] @[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'int64']@ assert(PyLong_Check(item)); - @primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsLongLong(item); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = PyLong_AsLongLong(item); @[ elif isinstance(member.type.value_type, BasicType) and member.type.value_type.typename == 'uint64']@ assert(PyLong_Check(item)); - @primitive_msg_type_to_c(member.type.value_type) tmp = PyLong_AsUnsignedLongLong(item); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = PyLong_AsUnsignedLongLong(item); +@[ elif isinstance(member.type.value_type, EnumerationType)]@ + assert(PyLong_Check(item)); + @primitive_msg_type_or_enum_to_c(member.type.value_type) tmp = PyLong_AsLong(item); @[ end if]@ @[ if isinstance(member.type.value_type, BasicType)]@ - memcpy(&dest[i], &tmp, sizeof(@primitive_msg_type_to_c(member.type.value_type))); + memcpy(&dest[i], &tmp, sizeof(@primitive_msg_type_or_enum_to_c(member.type.value_type))); +@[ elif isinstance(member.type.value_type, EnumerationType)]@ + memcpy(&dest[i], &tmp, sizeof(@primitive_msg_type_or_enum_to_c(member.type.value_type))); @[ end if]@ } Py_DECREF(seq_field); @@ -489,7 +505,7 @@ nested_type = '__'.join(type_.namespaced_name()) 'int32', )]@ assert(PyLong_Check(field)); - ros_message->@(member.name) = (@(primitive_msg_type_to_c(member.type)))PyLong_AsLong(field); + ros_message->@(member.name) = (@(primitive_msg_type_or_enum_to_c(member.type)))PyLong_AsLong(field); @[ elif isinstance(member.type, BasicType) and member.type.typename in ( 'uint8', 'uint16', @@ -499,7 +515,7 @@ nested_type = '__'.join(type_.namespaced_name()) @[ if member.type.typename == 'uint32']@ ros_message->@(member.name) = PyLong_AsUnsignedLong(field); @[ else]@ - ros_message->@(member.name) = (@(primitive_msg_type_to_c(member.type)))PyLong_AsUnsignedLong(field); + ros_message->@(member.name) = (@(primitive_msg_type_or_enum_to_c(member.type)))PyLong_AsUnsignedLong(field); @[ end if]@ @[ elif isinstance(member.type, BasicType) and member.type.typename == 'int64']@ assert(PyLong_Check(field)); @@ -507,6 +523,9 @@ nested_type = '__'.join(type_.namespaced_name()) @[ elif isinstance(member.type, BasicType) and member.type.typename == 'uint64']@ assert(PyLong_Check(field)); ros_message->@(member.name) = PyLong_AsUnsignedLongLong(field); +@[ elif isinstance(member.type, EnumerationType)]@ + assert(PyLong_Check(field)); + ros_message->@(member.name) = (@(primitive_msg_type_or_enum_to_c(member.type)))PyLong_AsLong(field); @[ else]@ assert(false); @[ end if]@ @@ -562,10 +581,10 @@ if isinstance(type_, AbstractNestedType): PyArrayObject * seq_field = (PyArrayObject *)field; assert(PyArray_NDIM(seq_field) == 1); assert(PyArray_TYPE(seq_field) == @(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'NPY_').upper())); - assert(sizeof(@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'npy_'))) == sizeof(@primitive_msg_type_to_c(member.type.value_type))); + assert(sizeof(@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'npy_'))) == sizeof(@primitive_msg_type_or_enum_to_c(member.type.value_type))); @(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'npy_')) * dst = (@(SPECIAL_NESTED_BASIC_TYPES[member.type.value_type.typename]['dtype'].replace('numpy.', 'npy_')) *)PyArray_GETPTR1(seq_field, 0); - @primitive_msg_type_to_c(member.type.value_type) * src = &(ros_message->@(member.name)[0]); - memcpy(dst, src, @(member.type.size) * sizeof(@primitive_msg_type_to_c(member.type.value_type))); + @primitive_msg_type_or_enum_to_c(member.type.value_type) * src = &(ros_message->@(member.name)[0]); + memcpy(dst, src, @(member.type.size) * sizeof(@primitive_msg_type_or_enum_to_c(member.type.value_type))); Py_DECREF(field); @[ elif isinstance(member.type, AbstractSequence)]@ field = PyObject_GetAttrString(_pymessage, "@(member.name)"); @@ -580,7 +599,7 @@ if isinstance(type_, AbstractNestedType): assert(itemsize_attr != NULL); size_t itemsize = PyLong_AsSize_t(itemsize_attr); Py_DECREF(itemsize_attr); - if (itemsize != sizeof(@primitive_msg_type_to_c(member.type.value_type))) { + if (itemsize != sizeof(@primitive_msg_type_or_enum_to_c(member.type.value_type))) { PyErr_SetString(PyExc_RuntimeError, "itemsize doesn't match expectation"); Py_DECREF(field); return NULL; @@ -609,8 +628,8 @@ if isinstance(type_, AbstractNestedType): // populating the array.array using the frombytes method PyObject * frombytes = PyObject_GetAttrString(field, "frombytes"); assert(frombytes != NULL); - @primitive_msg_type_to_c(member.type.value_type) * src = &(ros_message->@(member.name).data[0]); - PyObject * data = PyBytes_FromStringAndSize((const char *)src, ros_message->@(member.name).size * sizeof(@primitive_msg_type_to_c(member.type.value_type))); + @primitive_msg_type_or_enum_to_c(member.type.value_type) * src = &(ros_message->@(member.name).data[0]); + PyObject * data = PyBytes_FromStringAndSize((const char *)src, ros_message->@(member.name).size * sizeof(@primitive_msg_type_or_enum_to_c(member.type.value_type))); assert(data != NULL); PyObject * ret = PyObject_CallFunctionObjArgs(frombytes, data, NULL); Py_DECREF(data); @@ -664,10 +683,10 @@ nested_type = '__'.join(type_.namespaced_name()) @[ elif isinstance(member.type, AbstractNestedType)]@ @[ if isinstance(member.type, AbstractSequence)]@ size_t size = ros_message->@(member.name).size; - @primitive_msg_type_to_c(member.type.value_type) * src = ros_message->@(member.name).data; + @primitive_msg_type_or_enum_to_c(member.type.value_type) * src = ros_message->@(member.name).data; @[ else]@ size_t size = @(member.type.size); - @primitive_msg_type_to_c(member.type.value_type) * src = ros_message->@(member.name); + @primitive_msg_type_or_enum_to_c(member.type.value_type) * src = ros_message->@(member.name); @[ end if]@ field = PyList_New(size); if (!field) { @@ -732,6 +751,15 @@ nested_type = '__'.join(type_.namespaced_name()) int rc = PyList_SetItem(field, i, PyLong_FromUnsignedLongLong(src[i])); (void)rc; assert(rc == 0); +@[ elif isinstance(member.type.value_type, EnumerationType)]@ + PyObject * enum_class = PyObject_GetAttrString(_pymessage, "@(member.type.value_type.name)"); + assert(enum_class); + PyObject * value = PyObject_CallFunction(enum_class, "(i)", (int32_t)src[i]); + assert(value); + Py_DECREF(enum_class); + int rc = PyList_SetItem(field, i, value); + (void)rc; + assert(rc == 0); @[ end if]@ } assert(PySequence_Check(field)); @@ -783,6 +811,12 @@ nested_type = '__'.join(type_.namespaced_name()) field = PyLong_FromLongLong(ros_message->@(member.name)); @[ elif isinstance(member.type, BasicType) and member.type.typename == 'uint64']@ field = PyLong_FromUnsignedLongLong(ros_message->@(member.name)); +@[ elif isinstance(member.type, EnumerationType)]@ + PyObject * enum_class = PyObject_GetAttrString(_pymessage, "@(member.type.name)"); + assert(enum_class); + field = PyObject_CallFunction(enum_class, "(i)", (int32_t)ros_message->@(member.name)); + assert(field); + Py_DECREF(enum_class); @[ else]@ assert(false); @[ end if]@ diff --git a/rosidl_generator_py/rosidl_generator_py/generate_py_impl.py b/rosidl_generator_py/rosidl_generator_py/generate_py_impl.py index dfa2d94c..423fad6a 100644 --- a/rosidl_generator_py/rosidl_generator_py/generate_py_impl.py +++ b/rosidl_generator_py/rosidl_generator_py/generate_py_impl.py @@ -30,6 +30,7 @@ from rosidl_parser.definition import Array from rosidl_parser.definition import BasicType from rosidl_parser.definition import CHARACTER_TYPES +from rosidl_parser.definition import EnumerationType from rosidl_parser.definition import FLOATING_POINT_TYPES from rosidl_parser.definition import IdlContent from rosidl_parser.definition import IdlLocator @@ -166,6 +167,8 @@ def value_to_py(type_, value, array_as_tuple=False): assert value is not None if not isinstance(type_, AbstractNestedType): + if isinstance(type_, EnumerationType): + return '%s.%s' % (type_.name, value) return primitive_value_to_py(type_, value) py_values = [] @@ -259,6 +262,9 @@ def get_python_type(type_): if isinstance(type_, NamespacedType): return type_.name + if isinstance(type_, EnumerationType): + return type_.name + if isinstance(type_, AbstractGenericString): return 'str' diff --git a/rosidl_generator_py/test/test_interfaces.py b/rosidl_generator_py/test/test_interfaces.py index 9a2d5571..06a428dc 100644 --- a/rosidl_generator_py/test/test_interfaces.py +++ b/rosidl_generator_py/test/test_interfaces.py @@ -17,6 +17,7 @@ import numpy import pytest +from rosidl_generator_py.idl import Enums from rosidl_generator_py.msg import Arrays from rosidl_generator_py.msg import BasicTypes from rosidl_generator_py.msg import BoundedSequences @@ -926,3 +927,28 @@ def test_builtin_sequence_slot_attributes(): builtin_sequence_slot_types_dict = getattr(msg, 'get_fields_and_field_types')() builtin_sequence_slots = getattr(msg, '__slots__') assert len(builtin_sequence_slot_types_dict) == len(builtin_sequence_slots) + + +def test_enums(): + array_size = 3 + expected_value = Enums.SomeEnum.ENUMERATOR2 + expected_default_value = Enums.SomeEnum.ENUMERATOR2 + + msg = Enums() + assert isinstance(msg.enum_value, Enums.SomeEnum) + assert isinstance(msg.static_array_values, list) + assert isinstance(msg.bounded_array_values, list) + assert isinstance(msg.dynamic_array_values, list) + + msg.enum_value = expected_value + for i in range(array_size): + msg.static_array_values[i] = expected_value + msg.bounded_array_values.append(expected_value) + msg.dynamic_array_values.append(expected_value) + + assert (msg.enum_value == expected_value) + assert (msg.enum_default_value == expected_default_value) + for i in range(array_size): + assert (msg.static_array_values[i] == expected_value) + assert (msg.dynamic_array_values[i] == expected_value) + assert (msg.bounded_array_values[i] == expected_value)