Skip to content

Conversation

@psmyth94
Copy link
Contributor

Hello,

When working with bio-data, each feature often has metadata associated with it (e.g. species, lineage, snp position, etc). To store this, I like to use the feature classes with the added metadata attribute. However, when saving or loading with custom features, you get an error since that class doesn't exist in the global namespace in datasets.features.features. Take for example,

from dataclasses import dataclass, field
from datasets import Dataset
from datasets.features.features import Value, Features

@dataclass
class FeatureA(Value):
    metadata: dict = field(default=dict)
    _type: str = field(default="FeatureA", init=False, repr=False)

@dataclass
class FeatureB(Value):
    metadata: dict = field(default_factory=dict)
    _type: str = field(default="FeatureB", init=False, repr=False)

test_data = {
    "a": [1, 2, 3],
    "b": [4, 5, 6],
}
test_data = Dataset.from_dict(
    test_data, 
    features=Features({
        "a": FeatureA("int32", metadata={"species": "lactobacillus acetotolerans"}),
        "b": FeatureB("int32", metadata={"species": "lactobacillus iners"}),
    })
)

# returns an error since FeatureA and FeatureB are not in the global namespace
test_data.save_to_disk('./test_data')
Saving the dataset (0/1 shards):   0%|          | 0/3 [00:00<?, ? examples/s]



---------------------------------------------------------------------------

KeyError                                  Traceback (most recent call last)

Cell In[2], line 28
     19 test_data = Dataset.from_dict(
     20     test_data, 
     21     features=Features({
   (...)
     24     })
     25 )
     27 # returns an error since FeatureA and FeatureB are not in the global namespace
---> 28 test_data.save_to_disk('./test_data')
...
File ~\Documents\datasets\src\datasets\features\features.py:1361, in generate_from_dict(obj)
   1359     return {key: generate_from_dict(value) for key, value in obj.items()}
   1360 obj = dict(obj)
-> 1361 class_type = globals()[obj.pop("_type")]
   1363 if class_type == Sequence:
   1364     return Sequence(feature=generate_from_dict(obj["feature"]), length=obj.get("length", -1))


KeyError: 'FeatureA'

We can avoid this by having a registry (like formatters) and doing

from datasets.features.features import register_feature
register_feature(FeatureA, "FeatureA")
register_feature(FeatureB, "FeatureB")
test_data.save_to_disk('./test_data')
Saving the dataset (1/1 shards): 100%|------| 3/3 [00:00<00:00, 211.13 examples/s]

and loading from disk returns with all metadata information

from datasets import load_from_disk
test_data = load_from_disk('./test_data')
test_data.features
{'a': FeatureA(dtype='int32', id=None, metadata={'species': 'lactobacillus acetotolerans'}),
 'b': FeatureB(dtype='int32', id=None, metadata={'species': 'lactobacillus iners'})}

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool ! Though for now I'd keep this feature experimental if it's fine for you.

I just added a few comments:

@psmyth94
Copy link
Contributor Author

looks like some files are missing in your google storage

@lhoestq
Copy link
Member

lhoestq commented Mar 12, 2024

cc @mariosasko is it related to #6474 ? The files should ideally not move for backward compatibility anyway

@mariosasko
Copy link
Collaborator

@lhoestq All the files are still there.

The problem is that the natural_questions is now a no-code dataset, so the test's paths are no longer correct (unless the revision is pinned to the previous version).

@psmyth94 This has been fixed on main, so you can make the CI tests green with the following:

git remote add upstream https://github.com/huggingface/datasets.git
git pull upstream main
git push

@lhoestq
Copy link
Member

lhoestq commented Mar 12, 2024

Thank you @mariosasko ! I'm updating this branch if you don't mind @psmyth94

Copy link
Member

@lhoestq lhoestq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks all good now, thanks !

@lhoestq lhoestq merged commit 4591ac1 into huggingface:main Mar 13, 2024
@github-actions
Copy link

Show benchmarks

PyArrow==8.0.0

Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.004903 / 0.011353 (-0.006450) 0.003105 / 0.011008 (-0.007903) 0.061980 / 0.038508 (0.023471) 0.029726 / 0.023109 (0.006617) 0.243406 / 0.275898 (-0.032492) 0.262530 / 0.323480 (-0.060950) 0.003905 / 0.007986 (-0.004081) 0.002617 / 0.004328 (-0.001712) 0.047851 / 0.004250 (0.043601) 0.040397 / 0.037052 (0.003345) 0.259461 / 0.258489 (0.000972) 0.285059 / 0.293841 (-0.008782) 0.027321 / 0.128546 (-0.101225) 0.009876 / 0.075646 (-0.065770) 0.206999 / 0.419271 (-0.212273) 0.034906 / 0.043533 (-0.008626) 0.245120 / 0.255139 (-0.010019) 0.270490 / 0.283200 (-0.012710) 0.017341 / 0.141683 (-0.124342) 1.128182 / 1.452155 (-0.323973) 1.173024 / 1.492716 (-0.319693)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.089337 / 0.018006 (0.071331) 0.298256 / 0.000490 (0.297767) 0.000216 / 0.000200 (0.000016) 0.000047 / 0.000054 (-0.000007)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.018179 / 0.037411 (-0.019233) 0.061275 / 0.014526 (0.046749) 0.073137 / 0.176557 (-0.103419) 0.119603 / 0.737135 (-0.617532) 0.073969 / 0.296338 (-0.222370)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.283109 / 0.215209 (0.067900) 2.765441 / 2.077655 (0.687787) 1.471276 / 1.504120 (-0.032844) 1.346365 / 1.541195 (-0.194830) 1.360668 / 1.468490 (-0.107822) 0.549947 / 4.584777 (-4.034830) 2.344213 / 3.745712 (-1.401499) 2.700905 / 5.269862 (-2.568956) 1.689936 / 4.565676 (-2.875741) 0.061985 / 0.424275 (-0.362290) 0.004923 / 0.007607 (-0.002684) 0.329833 / 0.226044 (0.103788) 3.277580 / 2.268929 (1.008652) 1.833987 / 55.444624 (-53.610638) 1.571023 / 6.876477 (-5.305454) 1.573259 / 2.142072 (-0.568813) 0.627504 / 4.805227 (-4.177723) 0.114106 / 6.500664 (-6.386558) 0.041197 / 0.075469 (-0.034272)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 0.967400 / 1.841788 (-0.874388) 11.046527 / 8.074308 (2.972219) 9.542214 / 10.191392 (-0.649178) 0.140745 / 0.680424 (-0.539679) 0.013627 / 0.534201 (-0.520574) 0.288429 / 0.579283 (-0.290855) 0.260509 / 0.434364 (-0.173855) 0.324704 / 0.540337 (-0.215633) 0.419366 / 1.386936 (-0.967570)
PyArrow==latest
Show updated benchmarks!

Benchmark: benchmark_array_xd.json

metric read_batch_formatted_as_numpy after write_array2d read_batch_formatted_as_numpy after write_flattened_sequence read_batch_formatted_as_numpy after write_nested_sequence read_batch_unformated after write_array2d read_batch_unformated after write_flattened_sequence read_batch_unformated after write_nested_sequence read_col_formatted_as_numpy after write_array2d read_col_formatted_as_numpy after write_flattened_sequence read_col_formatted_as_numpy after write_nested_sequence read_col_unformated after write_array2d read_col_unformated after write_flattened_sequence read_col_unformated after write_nested_sequence read_formatted_as_numpy after write_array2d read_formatted_as_numpy after write_flattened_sequence read_formatted_as_numpy after write_nested_sequence read_unformated after write_array2d read_unformated after write_flattened_sequence read_unformated after write_nested_sequence write_array2d write_flattened_sequence write_nested_sequence
new / old (diff) 0.005123 / 0.011353 (-0.006230) 0.003119 / 0.011008 (-0.007890) 0.048931 / 0.038508 (0.010423) 0.032067 / 0.023109 (0.008958) 0.276825 / 0.275898 (0.000927) 0.297589 / 0.323480 (-0.025890) 0.004075 / 0.007986 (-0.003911) 0.002579 / 0.004328 (-0.001750) 0.047862 / 0.004250 (0.043612) 0.044032 / 0.037052 (0.006980) 0.289469 / 0.258489 (0.030980) 0.327269 / 0.293841 (0.033428) 0.029369 / 0.128546 (-0.099177) 0.010180 / 0.075646 (-0.065466) 0.057111 / 0.419271 (-0.362161) 0.051046 / 0.043533 (0.007513) 0.276758 / 0.255139 (0.021619) 0.296084 / 0.283200 (0.012884) 0.017376 / 0.141683 (-0.124306) 1.154486 / 1.452155 (-0.297669) 1.192699 / 1.492716 (-0.300018)

Benchmark: benchmark_getitem_100B.json

metric get_batch_of_1024_random_rows get_batch_of_1024_rows get_first_row get_last_row
new / old (diff) 0.085981 / 0.018006 (0.067974) 0.296956 / 0.000490 (0.296466) 0.000211 / 0.000200 (0.000011) 0.000050 / 0.000054 (-0.000004)

Benchmark: benchmark_indices_mapping.json

metric select shard shuffle sort train_test_split
new / old (diff) 0.021239 / 0.037411 (-0.016172) 0.074851 / 0.014526 (0.060326) 0.085676 / 0.176557 (-0.090881) 0.125876 / 0.737135 (-0.611259) 0.087573 / 0.296338 (-0.208765)

Benchmark: benchmark_iterating.json

metric read 5000 read 50000 read_batch 50000 10 read_batch 50000 100 read_batch 50000 1000 read_formatted numpy 5000 read_formatted pandas 5000 read_formatted tensorflow 5000 read_formatted torch 5000 read_formatted_batch numpy 5000 10 read_formatted_batch numpy 5000 1000 shuffled read 5000 shuffled read 50000 shuffled read_batch 50000 10 shuffled read_batch 50000 100 shuffled read_batch 50000 1000 shuffled read_formatted numpy 5000 shuffled read_formatted_batch numpy 5000 10 shuffled read_formatted_batch numpy 5000 1000
new / old (diff) 0.289220 / 0.215209 (0.074011) 2.812342 / 2.077655 (0.734688) 1.572886 / 1.504120 (0.068766) 1.446442 / 1.541195 (-0.094752) 1.458737 / 1.468490 (-0.009753) 0.562010 / 4.584777 (-4.022767) 2.422896 / 3.745712 (-1.322816) 2.578408 / 5.269862 (-2.691454) 1.689998 / 4.565676 (-2.875678) 0.064782 / 0.424275 (-0.359493) 0.005051 / 0.007607 (-0.002556) 0.339982 / 0.226044 (0.113938) 3.309882 / 2.268929 (1.040953) 1.910273 / 55.444624 (-53.534351) 1.649723 / 6.876477 (-5.226753) 1.744073 / 2.142072 (-0.397999) 0.651905 / 4.805227 (-4.153323) 0.114606 / 6.500664 (-6.386058) 0.040030 / 0.075469 (-0.035439)

Benchmark: benchmark_map_filter.json

metric filter map fast-tokenizer batched map identity map identity batched map no-op batched map no-op batched numpy map no-op batched pandas map no-op batched pytorch map no-op batched tensorflow
new / old (diff) 1.008374 / 1.841788 (-0.833414) 11.547300 / 8.074308 (3.472992) 9.966061 / 10.191392 (-0.225331) 0.144874 / 0.680424 (-0.535550) 0.014400 / 0.534201 (-0.519801) 0.285435 / 0.579283 (-0.293848) 0.274755 / 0.434364 (-0.159609) 0.323105 / 0.540337 (-0.217232) 0.439172 / 1.386936 (-0.947764)

@psmyth94 psmyth94 deleted the fetch-features-from-registry branch March 13, 2024 12:08
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants