2727from ssl import create_default_context
2828from urllib .request import build_opener , HTTPSHandler , install_opener
2929import certifi
30+ import functools
3031import hypothesis
3132from cuml .internals .safe_imports import gpu_only_import
3233import pytest
3334import os
3435import subprocess
36+ import time
3537import pandas as pd
3638import cudf .pandas
3739
@@ -212,7 +214,7 @@ def pytest_pyfunc_call(pyfuncitem):
212214 pytest .skip ("Test requires cudf.pandas accelerator" )
213215
214216
215- @pytest .fixture (scope = "module " )
217+ @pytest .fixture (scope = "session " )
216218def nlp_20news ():
217219 try :
218220 twenty_train = fetch_20newsgroups (
@@ -228,7 +230,7 @@ def nlp_20news():
228230 return X , Y
229231
230232
231- @pytest .fixture (scope = "module " )
233+ @pytest .fixture (scope = "session " )
232234def housing_dataset ():
233235 try :
234236 data = fetch_california_housing ()
@@ -245,16 +247,30 @@ def housing_dataset():
245247 return X , y , feature_names
246248
247249
248- @pytest .fixture (scope = "module" )
250+ @functools .cache
251+ def get_boston_data ():
252+ n_retries = 3
253+ url = "https://raw.githubusercontent.com/scikit-learn/scikit-learn/baf828ca126bcb2c0ad813226963621cafe38adb/sklearn/datasets/data/boston_house_prices.csv" # noqa: E501
254+ for _ in range (n_retries ):
255+ try :
256+ return pd .read_csv (url , header = None )
257+ except Exception :
258+ time .sleep (1 )
259+ raise RuntimeError (
260+ f"Failed to download file from { url } after { n_retries } retries."
261+ )
262+
263+
264+ @pytest .fixture (scope = "session" )
249265def deprecated_boston_dataset ():
250266 # dataset was removed in Scikit-learn 1.2, we should change it for a
251267 # better dataset for tests, see
252268 # https://github.com/rapidsai/cuml/issues/5158
253269
254- df = pd . read_csv (
255- "https://raw.githubusercontent.com/scikit-learn/scikit-learn/baf828ca126bcb2c0ad813226963621cafe38adb/sklearn/datasets/data/boston_house_prices.csv" ,
256- header = None ,
257- ) # noqa: E501
270+ try :
271+ df = get_boston_data ()
272+ except : # noqa E722
273+ pytest . xfail ( reason = "Error fetching Boston housing dataset" )
258274 n_samples = int (df [0 ][0 ])
259275 data = df [list (np .arange (13 ))].values [2 :n_samples ].astype (np .float64 )
260276 targets = df [13 ].values [2 :n_samples ].astype (np .float64 )
@@ -266,7 +282,7 @@ def deprecated_boston_dataset():
266282
267283
268284@pytest .fixture (
269- scope = "module " ,
285+ scope = "session " ,
270286 params = ["digits" , "deprecated_boston_dataset" , "diabetes" , "cancer" ],
271287)
272288def test_datasets (request , deprecated_boston_dataset ):
@@ -313,7 +329,7 @@ def failure_logger(request):
313329 print (error_msg )
314330
315331
316- @pytest .fixture (scope = "module " )
332+ @pytest .fixture (scope = "session " )
317333def exact_shap_regression_dataset ():
318334 return create_synthetic_dataset (
319335 generator = skl_make_reg ,
@@ -326,7 +342,7 @@ def exact_shap_regression_dataset():
326342 )
327343
328344
329- @pytest .fixture (scope = "module " )
345+ @pytest .fixture (scope = "session " )
330346def exact_shap_classification_dataset ():
331347 return create_synthetic_dataset (
332348 generator = skl_make_clas ,
0 commit comments