 from datasets.packaged_modules.cache.cache import Cache
 
 
+SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_single_config_in_metadata"
 SAMPLE_DATASET_TWO_CONFIG_IN_METADATA = "hf-internal-testing/audiofolder_two_configs_in_metadata"
 
 
-def test_cache(text_dir: Path):
-    ds = load_dataset(str(text_dir))
+def test_cache(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache"
+    ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
     hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
-    cache = Cache(dataset_name=text_dir.name, hash=hash)
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash=hash)
     reloaded = cache.as_dataset()
     assert list(ds) == list(reloaded)
     assert list(ds["train"]) == list(reloaded["train"])
 
 
-def test_cache_streaming(text_dir: Path):
-    ds = load_dataset(str(text_dir))
+def test_cache_streaming(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_streaming"
+    ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
     hash = Path(ds["train"].cache_files[0]["filename"]).parts[-2]
-    cache = Cache(dataset_name=text_dir.name, hash=hash)
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash=hash)
     reloaded = cache.as_streaming_dataset()
     assert list(ds) == list(reloaded)
     assert list(ds["train"]) == list(reloaded["train"])
 
 
-def test_cache_auto_hash(text_dir: Path):
-    ds = load_dataset(str(text_dir))
-    cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto")
+def test_cache_auto_hash(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_auto_hash"
+    ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto")
     reloaded = cache.as_dataset()
     assert list(ds) == list(reloaded)
     assert list(ds["train"]) == list(reloaded["train"])
 
 
-def test_cache_auto_hash_with_custom_config(text_dir: Path):
-    ds = load_dataset(str(text_dir), sample_by="paragraph")
-    another_ds = load_dataset(str(text_dir))
-    cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto", sample_by="paragraph")
-    another_cache = Cache(dataset_name=text_dir.name, version="auto", hash="auto")
+def test_cache_auto_hash_with_custom_config(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_auto_hash_with_custom_config"
+    ds = load_dataset(str(text_dir), sample_by="paragraph", cache_dir=str(cache_dir))
+    another_ds = load_dataset(str(text_dir), cache_dir=str(cache_dir))
+    cache = Cache(
+        cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto", sample_by="paragraph"
+    )
+    another_cache = Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto")
     assert cache.config_id.endswith("paragraph")
     assert not another_cache.config_id.endswith("paragraph")
     reloaded = cache.as_dataset()
@@ -50,27 +57,79 @@ def test_cache_auto_hash_with_custom_config(text_dir: Path):
     assert list(another_ds["train"]) == list(another_reloaded["train"])
 
 
-def test_cache_missing(text_dir: Path):
-    load_dataset(str(text_dir))
-    Cache(dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare()
+def test_cache_missing(text_dir: Path, tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_missing"
+    load_dataset(str(text_dir), cache_dir=str(cache_dir))
+    Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, version="auto", hash="auto").download_and_prepare()
     with pytest.raises(ValueError):
-        Cache(dataset_name="missing", version="auto", hash="auto").download_and_prepare()
+        Cache(cache_dir=str(cache_dir), dataset_name="missing", version="auto", hash="auto").download_and_prepare()
     with pytest.raises(ValueError):
-        Cache(dataset_name=text_dir.name, hash="missing").download_and_prepare()
+        Cache(cache_dir=str(cache_dir), dataset_name=text_dir.name, hash="missing").download_and_prepare()
     with pytest.raises(ValueError):
-        Cache(dataset_name=text_dir.name, config_name="missing", version="auto", hash="auto").download_and_prepare()
+        Cache(
+            cache_dir=str(cache_dir), dataset_name=text_dir.name, config_name="missing", version="auto", hash="auto"
+        ).download_and_prepare()
 
 
 @pytest.mark.integration
-def test_cache_multi_configs():
+def test_cache_multi_configs(tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_multi_configs"
     repo_id = SAMPLE_DATASET_TWO_CONFIG_IN_METADATA
     dataset_name = repo_id.split("/")[-1]
     config_name = "v1"
-    ds = load_dataset(repo_id, config_name)
-    cache = Cache(dataset_name=dataset_name, repo_id=repo_id, config_name=config_name, version="auto", hash="auto")
+    ds = load_dataset(repo_id, config_name, cache_dir=str(cache_dir))
+    cache = Cache(
+        cache_dir=str(cache_dir),
+        dataset_name=dataset_name,
+        repo_id=repo_id,
+        config_name=config_name,
+        version="auto",
+        hash="auto",
+    )
     reloaded = cache.as_dataset()
     assert list(ds) == list(reloaded)
     assert len(ds["train"]) == len(reloaded["train"])
     with pytest.raises(ValueError) as excinfo:
-        Cache(dataset_name=dataset_name, repo_id=repo_id, config_name="missing", version="auto", hash="auto")
+        Cache(
+            cache_dir=str(cache_dir),
+            dataset_name=dataset_name,
+            repo_id=repo_id,
+            config_name="missing",
+            version="auto",
+            hash="auto",
+        )
+    assert config_name in str(excinfo.value)
+
+
+@pytest.mark.integration
+def test_cache_single_config(tmp_path: Path):
+    cache_dir = tmp_path / "test_cache_single_config"
+    repo_id = SAMPLE_DATASET_SINGLE_CONFIG_IN_METADATA
+    dataset_name = repo_id.split("/")[-1]
+    config_name = "custom"
+    ds = load_dataset(repo_id, cache_dir=str(cache_dir))
+    cache = Cache(cache_dir=str(cache_dir), dataset_name=dataset_name, repo_id=repo_id, version="auto", hash="auto")
+    reloaded = cache.as_dataset()
+    assert list(ds) == list(reloaded)
+    assert len(ds["train"]) == len(reloaded["train"])
+    cache = Cache(
+        cache_dir=str(cache_dir),
+        dataset_name=dataset_name,
+        config_name=config_name,
+        repo_id=repo_id,
+        version="auto",
+        hash="auto",
+    )
+    reloaded = cache.as_dataset()
+    assert list(ds) == list(reloaded)
+    assert len(ds["train"]) == len(reloaded["train"])
+    with pytest.raises(ValueError) as excinfo:
+        Cache(
+            cache_dir=str(cache_dir),
+            dataset_name=dataset_name,
+            repo_id=repo_id,
+            config_name="missing",
+            version="auto",
+            hash="auto",
+        )
     assert config_name in str(excinfo.value)