diff --git a/petastorm/transform.py b/petastorm/transform.py index e8cebf1d7..ded2c9b8e 100644 --- a/petastorm/transform.py +++ b/petastorm/transform.py @@ -17,7 +17,7 @@ class TransformSpec(object): - def __init__(self, func=None, edit_fields=None, removed_fields=None): + def __init__(self, func=None, edit_fields=None, removed_fields=None, selected_fields=None): """TransformSpec defines a user transformation that is applied to a loaded row on a worker thread/process. The object defines the function (callable) that perform the transform as well as the @@ -34,10 +34,12 @@ def __init__(self, func=None, edit_fields=None, removed_fields=None): :param edit_fields: Optional. A list of 4-tuples with the following fields: ``(name, numpy_dtype, shape, is_nullable)`` :param removed_fields: Optional[list]. A list of field names that will be removed from the original schema. + :param selected_fields: Optional[list]. A list of field names specify the fields to be selected. """ self.func = func self.edit_fields = edit_fields or [] self.removed_fields = removed_fields or [] + self.selected_fields = selected_fields def transform_schema(schema, transform_spec): @@ -61,4 +63,8 @@ def transform_schema(schema, transform_spec): shape=field_to_edit[2], codec=None, nullable=field_to_edit[3]) fields.append(edited_unischema_field) + if transform_spec.selected_fields is not None: + fields = [f for f in fields if f.name in transform_spec.selected_fields] + fields = sorted(fields, key=lambda f: transform_spec.selected_fields.index(f.name)) + return Unischema(schema._name + '_transformed', fields) diff --git a/petastorm/unischema.py b/petastorm/unischema.py index 472abf7c4..2002f717e 100644 --- a/petastorm/unischema.py +++ b/petastorm/unischema.py @@ -95,11 +95,11 @@ def get(parent_schema_name, field_names): :return: A namedtuple with field names defined by `field_names` """ # Cache key is a combination of schema name and all field names - sorted_names = list(sorted(field_names)) - key = ' '.join([parent_schema_name] + sorted_names) + field_names = list(field_names) + key = ' '.join([parent_schema_name] + field_names) if key not in _NamedtupleCache._store: _NamedtupleCache._store[key] = \ - _new_gt_255_compatible_namedtuple('{}_view'.format(parent_schema_name), sorted_names) + _new_gt_255_compatible_namedtuple('{}_view'.format(parent_schema_name), field_names) return _NamedtupleCache._store[key] @@ -185,6 +185,8 @@ def __init__(self, name, fields): warnings.warn(('Can not create dynamic property {} because it conflicts with an existing property of ' 'Unischema').format(f.name)) + self._ordered_fields_name = [f.name for f in fields] + def create_schema_view(self, fields): """Creates a new instance of the schema using a subset of fields. @@ -229,7 +231,7 @@ def create_schema_view(self, fields): return Unischema('{}_view'.format(self._name), view_fields) def _get_namedtuple(self): - return _NamedtupleCache.get(self._name, self._fields.keys()) + return _NamedtupleCache.get(self._name, self._ordered_fields_name) def __str__(self): """Represent this as the following form: