@@ -83,25 +83,29 @@ def inject_fixtures(self, caplog):
8383 self ._caplog = caplog
8484
8585 def _create_dummy_dataset (
86- self , in_memory : bool , tmp_dir : str , multiple_columns = False , array_features = False
86+ self , in_memory : bool , tmp_dir : str , multiple_columns = False , array_features = False , nested_features = False
8787 ) -> Dataset :
88+ assert int (multiple_columns ) + int (array_features ) + int (nested_features ) < 2
8889 if multiple_columns :
89- if array_features :
90- data = {
91- "col_1" : [[[True , False ], [False , True ]]] * 4 , # 2D
92- "col_2" : [[[["a" , "b" ], ["c" , "d" ]], [["e" , "f" ], ["g" , "h" ]]]] * 4 , # 3D array
93- "col_3" : [[3 , 2 , 1 , 0 ]] * 4 , # Sequence
90+ data = {"col_1" : [3 , 2 , 1 , 0 ], "col_2" : ["a" , "b" , "c" , "d" ], "col_3" : [False , True , False , True ]}
91+ dset = Dataset .from_dict (data )
92+ elif array_features :
93+ data = {
94+ "col_1" : [[[True , False ], [False , True ]]] * 4 , # 2D
95+ "col_2" : [[[["a" , "b" ], ["c" , "d" ]], [["e" , "f" ], ["g" , "h" ]]]] * 4 , # 3D array
96+ "col_3" : [[3 , 2 , 1 , 0 ]] * 4 , # Sequence
97+ }
98+ features = Features (
99+ {
100+ "col_1" : Array2D (shape = (2 , 2 ), dtype = "bool" ),
101+ "col_2" : Array3D (shape = (2 , 2 , 2 ), dtype = "string" ),
102+ "col_3" : Sequence (feature = Value ("int64" )),
94103 }
95- features = Features (
96- {
97- "col_1" : Array2D (shape = (2 , 2 ), dtype = "bool" ),
98- "col_2" : Array3D (shape = (2 , 2 , 2 ), dtype = "string" ),
99- "col_3" : Sequence (feature = Value ("int64" )),
100- }
101- )
102- else :
103- data = {"col_1" : [3 , 2 , 1 , 0 ], "col_2" : ["a" , "b" , "c" , "d" ], "col_3" : [False , True , False , True ]}
104- features = None
104+ )
105+ dset = Dataset .from_dict (data , features = features )
106+ elif nested_features :
107+ data = {"nested" : [{"a" : i , "x" : i * 10 , "c" : i * 100 } for i in range (1 , 11 )]}
108+ features = Features ({"nested" : {"a" : Value ("int64" ), "x" : Value ("int64" ), "c" : Value ("int64" )}})
105109 dset = Dataset .from_dict (data , features = features )
106110 else :
107111 dset = Dataset .from_dict ({"filename" : ["my_name-train" + "_" + str (x ) for x in np .arange (30 ).tolist ()]})
@@ -139,7 +143,7 @@ def test_dummy_dataset(self, in_memory):
139143 self .assertEqual (dset ["col_1" ][0 ], 3 )
140144
141145 with tempfile .TemporaryDirectory () as tmp_dir :
142- with self ._create_dummy_dataset (in_memory , tmp_dir , multiple_columns = True , array_features = True ) as dset :
146+ with self ._create_dummy_dataset (in_memory , tmp_dir , array_features = True ) as dset :
143147 self .assertDictEqual (
144148 dset .features ,
145149 Features (
@@ -249,6 +253,19 @@ def test_dummy_dataset_serialize(self, in_memory):
249253 self .assertEqual (dset [0 ]["filename" ], "my_name-train_0" )
250254 self .assertEqual (dset ["filename" ][0 ], "my_name-train_0" )
251255
256+ with self ._create_dummy_dataset (in_memory , tmp_dir , nested_features = True ) as dset :
257+ with assert_arrow_memory_doesnt_increase ():
258+ dset .save_to_disk (dataset_path )
259+
260+ with Dataset .load_from_disk (dataset_path ) as dset :
261+ self .assertEqual (len (dset ), 10 )
262+ self .assertDictEqual (
263+ dset .features ,
264+ Features ({"nested" : {"a" : Value ("int64" ), "x" : Value ("int64" ), "c" : Value ("int64" )}}),
265+ )
266+ self .assertDictEqual (dset [0 ]["nested" ], {"a" : 1 , "c" : 100 , "x" : 10 })
267+ self .assertDictEqual (dset ["nested" ][0 ], {"a" : 1 , "c" : 100 , "x" : 10 })
268+
252269 def test_dummy_dataset_load_from_disk (self , in_memory ):
253270 with tempfile .TemporaryDirectory () as tmp_dir :
254271
@@ -453,7 +470,7 @@ def test_class_encode_column(self, in_memory):
453470 assert_arrow_metadata_are_synced_with_dataset_features (casted_dset )
454471
455472 # Test raises if feature is an array / sequence
456- with self ._create_dummy_dataset (in_memory , tmp_dir , multiple_columns = True , array_features = True ) as dset :
473+ with self ._create_dummy_dataset (in_memory , tmp_dir , array_features = True ) as dset :
457474 for column in dset .column_names :
458475 with self .assertRaises (ValueError ):
459476 dset .class_encode_column (column )
@@ -1597,7 +1614,7 @@ def test_to_csv(self, in_memory):
15971614 self .assertListEqual (list (csv_dset .columns ), list (dset .column_names ))
15981615
15991616 # With array features
1600- with self ._create_dummy_dataset (in_memory , tmp_dir , multiple_columns = True , array_features = True ) as dset :
1617+ with self ._create_dummy_dataset (in_memory , tmp_dir , array_features = True ) as dset :
16011618 file_path = os .path .join (tmp_dir , "test_path.csv" )
16021619 bytes_written = dset .to_csv (path_or_buf = file_path )
16031620
0 commit comments