@@ -3178,6 +3178,60 @@ def test_dataset_from_text_path_type(path_type, text_path, tmp_path):
     _check_text_dataset(dataset, expected_features)
 
 
+@pytest.fixture
+def data_generator():
+    def _gen():
+        data = [
+            {"col_1": "0", "col_2": 0, "col_3": 0.0},
+            {"col_1": "1", "col_2": 1, "col_3": 1.0},
+            {"col_1": "2", "col_2": 2, "col_3": 2.0},
+            {"col_1": "3", "col_2": 3, "col_3": 3.0},
+        ]
+        for item in data:
+            yield item
+
+    return _gen
+
+
+def _check_generator_dataset(dataset, expected_features):
+    assert isinstance(dataset, Dataset)
+    assert dataset.num_rows == 4
+    assert dataset.num_columns == 3
+    assert dataset.column_names == ["col_1", "col_2", "col_3"]
+    for feature, expected_dtype in expected_features.items():
+        assert dataset.features[feature].dtype == expected_dtype
+
+
+@pytest.mark.parametrize("keep_in_memory", [False, True])
+def test_dataset_from_generator_keep_in_memory(keep_in_memory, data_generator, tmp_path):
+    cache_dir = tmp_path / "cache"
+    expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    with assert_arrow_memory_increases() if keep_in_memory else assert_arrow_memory_doesnt_increase():
+        dataset = Dataset.from_generator(data_generator, cache_dir=cache_dir, keep_in_memory=keep_in_memory)
+    _check_generator_dataset(dataset, expected_features)
+
+
+@pytest.mark.parametrize(
+    "features",
+    [
+        None,
+        {"col_1": "string", "col_2": "int64", "col_3": "float64"},
+        {"col_1": "string", "col_2": "string", "col_3": "string"},
+        {"col_1": "int32", "col_2": "int32", "col_3": "int32"},
+        {"col_1": "float32", "col_2": "float32", "col_3": "float32"},
+    ],
+)
+def test_dataset_from_generator_features(features, data_generator, tmp_path):
+    cache_dir = tmp_path / "cache"
+    default_expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
+    expected_features = features.copy() if features else default_expected_features
+    features = (
+        Features({feature: Value(dtype) for feature, dtype in features.items()}) if features is not None else None
+    )
+    dataset = Dataset.from_generator(data_generator, features=features, cache_dir=cache_dir)
+    _check_generator_dataset(dataset, expected_features)
+
+
 def test_dataset_to_json(dataset, tmp_path):
     file_path = tmp_path / "test_path.jsonl"
     bytes_written = dataset.to_json(path_or_buf=file_path)
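For reference, a minimal usage sketch of the Dataset.from_generator API that these tests exercise. The generator, column names, and dtypes mirror the fixtures above; the explicit-schema variant is illustrative, not part of the diff:

from datasets import Dataset, Features, Value

def gen():
    # Yield one row per dict; column types are inferred unless `features` is passed.
    for i in range(4):
        yield {"col_1": str(i), "col_2": i, "col_3": float(i)}

# Inferred schema: col_1 -> string, col_2 -> int64, col_3 -> float64
ds = Dataset.from_generator(gen)

# Explicit schema, as in the parametrized `features` cases above
features = Features({"col_1": Value("string"), "col_2": Value("int32"), "col_3": Value("float32")})
ds_typed = Dataset.from_generator(gen, features=features, keep_in_memory=True)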