Skip to content

Commit c3ddb1e

Browse files
authored
Fix CI (#6780)
* Fix CI * Nit
1 parent ad3467e commit c3ddb1e

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

tests/test_arrow_dataset.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3980,7 +3980,7 @@ def _check_sql_dataset(dataset, expected_features):
39803980

39813981
@require_sqlalchemy
39823982
@pytest.mark.parametrize("con_type", ["string", "engine"])
3983-
def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning):
3983+
def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path, set_sqlalchemy_silence_uber_warning, caplog):
39843984
cache_dir = tmp_path / "cache"
39853985
expected_features = {"col_1": "string", "col_2": "int64", "col_3": "float64"}
39863986
if con_type == "string":
@@ -3989,17 +3989,16 @@ def test_dataset_from_sql_con_type(con_type, sqlite_path, tmp_path, set_sqlalche
39893989
import sqlalchemy
39903990

39913991
con = sqlalchemy.create_engine("sqlite:///" + sqlite_path)
3992-
# # https://github.com/huggingface/datasets/issues/2832 needs to be fixed first for this to work
3993-
# with caplog.at_level(INFO):
3994-
# dataset = Dataset.from_sql(
3995-
# "dataset",
3996-
# con,
3997-
# cache_dir=cache_dir,
3998-
# )
3999-
# if con_type == "string":
4000-
# assert "couldn't be hashed properly" not in caplog.text
4001-
# elif con_type == "engine":
4002-
# assert "couldn't be hashed properly" in caplog.text
3992+
with caplog.at_level(INFO, logger=get_logger().name):
3993+
dataset = Dataset.from_sql(
3994+
"dataset",
3995+
con,
3996+
cache_dir=cache_dir,
3997+
)
3998+
if con_type == "string":
3999+
assert "couldn't be hashed properly" not in caplog.text
4000+
elif con_type == "engine":
4001+
assert "couldn't be hashed properly" in caplog.text
40034002
dataset = Dataset.from_sql(
40044003
"dataset",
40054004
con,

tests/test_inspect.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
pytestmark = pytest.mark.integration
1919

2020

21-
@pytest.mark.parametrize("path", ["lhoestq/test", csv.__file__])
21+
@pytest.mark.parametrize("path", ["hf-internal-testing/dataset_with_script", csv.__file__])
2222
def test_inspect_dataset(path, tmp_path):
2323
inspect_dataset(path, tmp_path)
2424
script_name = Path(path).stem + ".py"

tests/test_load.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -418,24 +418,28 @@ def setUp(self):
418418
)
419419

420420
def test_HubDatasetModuleFactoryWithScript_dont_trust_remote_code(self):
421-
# "lhoestq/test" has a dataset script
422421
factory = HubDatasetModuleFactoryWithScript(
423-
"lhoestq/test", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
422+
"hf-internal-testing/dataset_with_script",
423+
download_config=self.download_config,
424+
dynamic_modules_path=self.dynamic_modules_path,
424425
)
425426
with patch.object(config, "HF_DATASETS_TRUST_REMOTE_CODE", None): # this will be the default soon
426427
self.assertRaises(ValueError, factory.get_module)
427428
factory = HubDatasetModuleFactoryWithScript(
428-
"lhoestq/test",
429+
"hf-internal-testing/dataset_with_script",
429430
download_config=self.download_config,
430431
dynamic_modules_path=self.dynamic_modules_path,
431432
trust_remote_code=False,
432433
)
433434
self.assertRaises(ValueError, factory.get_module)
434435

435-
def test_HubDatasetModuleFactoryWithScript_with_github_dataset(self):
436+
def test_HubDatasetModuleFactoryWithScript_with_hub_dataset(self):
436437
# "wmt_t2t" has additional imports (internal)
437438
factory = HubDatasetModuleFactoryWithScript(
438-
"wmt_t2t", download_config=self.download_config, dynamic_modules_path=self.dynamic_modules_path
439+
"wmt_t2t",
440+
download_config=self.download_config,
441+
dynamic_modules_path=self.dynamic_modules_path,
442+
revision="861aac88b2c6247dd93ade8b1c189ce714627750",
439443
)
440444
module_factory_result = factory.get_module()
441445
assert importlib.import_module(module_factory_result.module_path) is not None

0 commit comments

Comments
 (0)