Skip to content

Commit 5a5cd6d

Browse files
authored
fix timestamp conversion from pd to py in streaming (#4541)
1 parent 47cccc9 commit 5a5cd6d

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

src/datasets/features/features.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -336,9 +336,23 @@ def _cast_to_python_objects(obj: Any, only_1d_for_numpy: bool, optimize_list_cas
336336
elif config.PIL_AVAILABLE and "PIL" in sys.modules and isinstance(obj, PIL.Image.Image):
337337
return encode_pil_image(obj), True
338338
elif isinstance(obj, pd.Series):
339-
return obj.values.tolist(), True
339+
return (
340+
_cast_to_python_objects(
341+
obj.tolist(), only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
342+
)[0],
343+
True,
344+
)
340345
elif isinstance(obj, pd.DataFrame):
341-
return obj.to_dict("list"), True
346+
return {
347+
key: _cast_to_python_objects(
348+
value, only_1d_for_numpy=only_1d_for_numpy, optimize_list_casting=optimize_list_casting
349+
)[0]
350+
for key, value in obj.to_dict("list").items()
351+
}, True
352+
elif isinstance(obj, pd.Timestamp):
353+
return obj.to_pydatetime(), True
354+
elif isinstance(obj, pd.Timedelta):
355+
return obj.to_pytimedelta(), True
342356
elif isinstance(obj, Mapping): # check for dict-like to handle nested LazyDict objects
343357
has_changed = not isinstance(obj, dict)
344358
output = {}

tests/features/test_features.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,26 @@ def test_cast_to_python_objects_dataframe(self):
482482
casted_obj = cast_to_python_objects(obj)
483483
self.assertDictEqual(casted_obj, expected_obj)
484484

485+
def test_cast_to_python_objects_pandas_timestamp(self):
486+
obj = pd.Timestamp(2020, 1, 1)
487+
expected_obj = obj.to_pydatetime()
488+
casted_obj = cast_to_python_objects(obj)
489+
self.assertEqual(casted_obj, expected_obj)
490+
casted_obj = cast_to_python_objects(pd.Series([obj]))
491+
self.assertListEqual(casted_obj, [expected_obj])
492+
casted_obj = cast_to_python_objects(pd.DataFrame({"a": [obj]}))
493+
self.assertDictEqual(casted_obj, {"a": [expected_obj]})
494+
495+
def test_cast_to_python_objects_pandas_timedelta(self):
496+
obj = pd.Timedelta(seconds=1)
497+
expected_obj = obj.to_pytimedelta()
498+
casted_obj = cast_to_python_objects(obj)
499+
self.assertEqual(casted_obj, expected_obj)
500+
casted_obj = cast_to_python_objects(pd.Series([obj]))
501+
self.assertListEqual(casted_obj, [expected_obj])
502+
casted_obj = cast_to_python_objects(pd.DataFrame({"a": [obj]}))
503+
self.assertDictEqual(casted_obj, {"a": [expected_obj]})
504+
485505
@require_torch
486506
def test_cast_to_python_objects_torch(self):
487507
import torch

tests/test_iterable_dataset.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from itertools import chain, islice
33

44
import numpy as np
5+
import pandas as pd
56
import pytest
67

78
from datasets import load_dataset
@@ -701,6 +702,21 @@ def test_iterable_dataset_features(features):
701702
assert list(dataset) == expected
702703

703704

705+
def test_iterable_dataset_features_cast_to_python():
706+
ex_iterable = ExamplesIterable(
707+
generate_examples_fn, {"timestamp": pd.Timestamp(2020, 1, 1), "array": np.ones(5), "n": 1}
708+
)
709+
features = Features(
710+
{
711+
"id": Value("int64"),
712+
"timestamp": Value("timestamp[us]"),
713+
"array": [Value("int64")],
714+
}
715+
)
716+
dataset = IterableDataset(ex_iterable, info=DatasetInfo(features=features))
717+
assert list(dataset) == [{"timestamp": pd.Timestamp(2020, 1, 1).to_pydatetime(), "array": [1] * 5, "id": 0}]
718+
719+
704720
@require_torch
705721
@pytest.mark.parametrize("format_type", [None, "torch", "python"])
706722
def test_iterable_dataset_with_format(dataset: IterableDataset, format_type):

0 commit comments

Comments
 (0)