@@ -1190,7 +1190,7 @@ def test_map_multiprocessing(self, in_memory):
11901190 self .assertNotEqual (dset_test ._fingerprint , fingerprint )
11911191 assert_arrow_metadata_are_synced_with_dataset_features (dset_test )
11921192
1193- def test_new_features (self , in_memory ):
1193+ def test_map_new_features (self , in_memory ):
11941194 with tempfile .TemporaryDirectory () as tmp_dir :
11951195 with self ._create_dummy_dataset (in_memory , tmp_dir ) as dset :
11961196 features = Features ({"filename" : Value ("string" ), "label" : ClassLabel (names = ["positive" , "negative" ])})
@@ -1397,6 +1397,84 @@ def test_map_caching(self, in_memory):
13971397 finally :
13981398 datasets .enable_caching ()
13991399
def test_map_return_pa_table(self, in_memory):
    # `Dataset.map` must accept functions that return a `pa.Table`, both in
    # per-example and in batched mode, and must reject a multi-row table
    # when it is called in per-example mode.
    expected_features = Features({"id": Value("int64"), "text": Value("string")})

    def check_mapped(mapped):
        # The dummy dataset has 30 rows; every row is replaced by (0, "a").
        self.assertEqual(len(mapped), 30)
        self.assertDictEqual(mapped.features, expected_features)
        self.assertEqual(mapped[0]["id"], 0)
        self.assertEqual(mapped[0]["text"], "a")

    # Per-example mode: the function returns a one-row table.
    def one_row_table(example):
        return pa.table({"id": [0], "text": ["a"]})

    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with dset.map(one_row_table) as mapped:
            check_mapped(mapped)

    # Batched mode: the function returns one row per input row.
    def one_row_per_input_table(batch):
        first_column = next(iter(batch))
        num_rows = len(batch[first_column])
        return pa.table({"id": [0] * num_rows, "text": ["a"] * num_rows})

    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with dset.map(one_row_per_input_table, batched=True) as mapped:
            check_mapped(mapped)

    # Per-example mode with a two-row table must raise a ValueError.
    def two_row_table(example):
        return pa.table({"id": [0, 1], "text": ["a", "b"]})

    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with self.assertRaises(ValueError):
            dset.map(two_row_table)
1438+
def test_map_return_pd_dataframe(self, in_memory):
    # `Dataset.map` must accept functions that return a `pd.DataFrame`, both
    # in per-example and in batched mode, and must reject a multi-row
    # DataFrame when it is called in per-example mode.
    expected_features = Features({"id": Value("int64"), "text": Value("string")})

    def check_mapped(mapped):
        # The dummy dataset has 30 rows; every row is replaced by (0, "a").
        self.assertEqual(len(mapped), 30)
        self.assertDictEqual(mapped.features, expected_features)
        self.assertEqual(mapped[0]["id"], 0)
        self.assertEqual(mapped[0]["text"], "a")

    # Per-example mode: the function returns a one-row DataFrame.
    def one_row_frame(example):
        return pd.DataFrame({"id": [0], "text": ["a"]})

    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with dset.map(one_row_frame) as mapped:
            check_mapped(mapped)

    # Batched mode: the function returns one row per input row.
    def one_row_per_input_frame(batch):
        first_column = next(iter(batch))
        num_rows = len(batch[first_column])
        return pd.DataFrame({"id": [0] * num_rows, "text": ["a"] * num_rows})

    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with dset.map(one_row_per_input_frame, batched=True) as mapped:
            check_mapped(mapped)

    # Per-example mode with a two-row DataFrame must raise a ValueError.
    def two_row_frame(example):
        return pd.DataFrame({"id": [0, 1], "text": ["a", "b"]})

    with tempfile.TemporaryDirectory() as tmp_dir, self._create_dummy_dataset(in_memory, tmp_dir) as dset:
        with self.assertRaises(ValueError):
            dset.map(two_row_frame)
1477+
14001478 @require_torch
14011479 def test_map_torch (self , in_memory ):
14021480 import torch
0 commit comments