
Commit 2bd1d18

Switch tests from pd.testing.assert_frame_equal to dd.assert_eq (#365)
* Start moving tests to dd.assert_eq
* Use assert_eq in datetime filter test
* Resolve most resulting test failures
* Resolve remaining test failures
* Convert over tests
* Convert more tests
* Consolidate select limit cpu/gpu test
* Remove remaining assert_series_equal
* Remove explicit cudf imports from many tests
* Resolve rex test failures
* Remove some additional compute calls
* Consolidate sorting tests with getfixturevalue
* Fix failed join test
* Remove breakpoint
* Use custom assert_eq function for tests (sketched below)
* Resolve test failures / seg faults
* Remove unnecessary testing utils
* Resolve local test failures
* Generalize RAND test
* Avoid closing client if using independent cluster
* Fix failures on Windows
* Resolve black failures
* Make random test variables more clear
1 parent 1eb30c1 commit 2bd1d18

27 files changed: +651 -1028 lines
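The tests.utils.assert_eq helper imported throughout this diff is not itself shown in the excerpt below. A minimal sketch of what such a wrapper could look like, assuming it simply forwards to dask.dataframe.utils.assert_eq, Dask's own testing utility, which computes lazy collections before comparing and passes extra keyword arguments (check_dtype, rtol, ...) through to the pandas testers:

    # Hypothetical sketch of tests/utils.py; the real helper is not part of
    # this excerpt and may add extra defaults or GPU handling.
    from dask.dataframe.utils import assert_eq as _dd_assert_eq


    def assert_eq(*args, **kwargs):
        # Dask's assert_eq accepts pandas, Dask, and (with dask-cudf
        # installed) cudf objects, and computes lazy collections itself,
        # which is why the tests below can drop their .compute() and
        # .to_pandas() calls before comparing.
        return _dd_assert_eq(*args, **kwargs)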

tests/integration/fixtures.py

Lines changed: 3 additions & 15 deletions
@@ -5,9 +5,9 @@
 import numpy as np
 import pandas as pd
 import pytest
-from dask.datasets import timeseries
 from dask.distributed import Client
-from pandas.testing import assert_frame_equal
+
+from tests.utils import assert_eq

 try:
     import cudf
@@ -23,18 +23,6 @@
 SCHEDULER_ADDR = os.getenv("DASK_SQL_TEST_SCHEDULER", None)


-@pytest.fixture()
-def timeseries_df(c):
-    pdf = timeseries(freq="1d").compute().reset_index(drop=True)
-    # impute nans in pandas dataframe
-    col1_index = np.random.randint(0, 30, size=int(pdf.shape[0] * 0.2))
-    col2_index = np.random.randint(0, 30, size=int(pdf.shape[0] * 0.3))
-    pdf.loc[col1_index, "x"] = np.nan
-    pdf.loc[col2_index, "y"] = np.nan
-    c.create_table("timeseries", pdf, persist=True)
-    return pdf
-
-
 @pytest.fixture()
 def df_simple():
     return pd.DataFrame({"a": [1, 2, 3], "b": [1.1, 2.2, 3.3]})
@@ -311,7 +299,7 @@ def _assert_query_gives_same_result(query, sort_columns=None, **kwargs):
         sql_result = sql_result.reset_index(drop=True)
         dask_result = dask_result.reset_index(drop=True)

-        assert_frame_equal(sql_result, dask_result, check_dtype=False, **kwargs)
+        assert_eq(sql_result, dask_result, check_dtype=False, **kwargs)

     return _assert_query_gives_same_result

tests/integration/test_analyze.py

Lines changed: 5 additions & 11 deletions
@@ -1,11 +1,10 @@
-import numpy as np
 import pandas as pd
-from pandas.testing import assert_frame_equal
+
+from tests.utils import assert_eq


 def test_analyze(c, df):
     result_df = c.sql("ANALYZE TABLE df COMPUTE STATISTICS FOR ALL COLUMNS")
-    result_df = result_df.compute()

     expected_df = pd.DataFrame(
         {
@@ -15,8 +14,7 @@ def test_analyze(c, df):
                 df.a.std(),
                 1.0,
                 2.0,
-                # That is actually wrong. But the approximate quantile function in dask gives a different result than the actual computation
-                result_df["a"].iloc[5],
+                2.0,  # incorrect, but what Dask gives for approx quantile
                 3.0,
                 3.0,
                 "double",
@@ -50,12 +48,8 @@ def test_analyze(c, df):
     )

     # The percentiles are calculated only approximately, therefore we do not use exact matching
-    p = ["25%", "50%", "75%"]
-    result_df.loc[p, :] = result_df.loc[p, :].astype(float).apply(np.ceil)
-    expected_df.loc[p, :] = expected_df.loc[p, :].astype(float).apply(np.ceil)
-    assert_frame_equal(result_df, expected_df, check_exact=False)
+    assert_eq(result_df, expected_df, rtol=0.135)

     result_df = c.sql("ANALYZE TABLE df COMPUTE STATISTICS FOR COLUMNS a")
-    result_df = result_df.compute()

-    assert_frame_equal(result_df, expected_df[["a"]])
+    assert_eq(result_df, expected_df[["a"]], rtol=0.135)
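The rtol keyword replaces the earlier trick of rounding the approximate percentiles up before an exact comparison: Dask computes quantiles only approximately, so the test now accepts any value within a relative tolerance. Assuming rtol is forwarded through to pandas' assert_frame_equal, values compare roughly as abs(result - expected) <= rtol * abs(expected); a toy illustration (not from the diff):

    import pandas as pd

    # Stand-in for tests.utils.assert_eq (assumed to forward kwargs).
    from dask.dataframe.utils import assert_eq

    exact = pd.DataFrame({"a": [2.0]})
    approx = pd.DataFrame({"a": [2.2]})  # e.g. an approximate quantile

    assert_eq(approx, exact, rtol=0.135)  # passes: |2.2 - 2.0| <= 0.135 * 2.0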

tests/integration/test_compatibility.py

Lines changed: 4 additions & 4 deletions
@@ -14,14 +14,14 @@

 import numpy as np
 import pandas as pd
-from pandas.testing import assert_frame_equal

 from dask_sql import Context
 from dask_sql.utils import ParsingException
+from tests.utils import assert_eq


 def cast_datetime_to_string(df):
-    cols = df.select_dtypes(include=["datetime64[ns]"]).columns
+    cols = df.select_dtypes(include=["datetime64[ns]"]).columns.tolist()
     # Casting to object first as
     # directly converting to string looses second precision
     df[cols] = df[cols].astype("object").astype("string")
@@ -36,7 +36,7 @@ def eq_sqlite(sql, **dfs):
         c.create_table(name, df)
         df.to_sql(name, engine, index=False)

-    dask_result = c.sql(sql).compute().reset_index(drop=True)
+    dask_result = c.sql(sql).reset_index(drop=True)
     sqlite_result = pd.read_sql(sql, engine).reset_index(drop=True)

     # casting to object to ensure equality with sql-lite
@@ -47,7 +47,7 @@ def eq_sqlite(sql, **dfs):
     dask_result = dask_result.fillna(np.NaN)
     sqlite_result = sqlite_result.fillna(np.NaN)

-    assert_frame_equal(dask_result, sqlite_result, check_dtype=False)
+    assert_eq(dask_result, sqlite_result, check_dtype=False)


 def make_rand_df(size: int, **kwargs):
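Dropping the .compute() in eq_sqlite works because the later casting and fillna steps run unchanged on the lazy Dask frame, and assert_eq evaluates it only at comparison time. A self-contained illustration, using Dask's assert_eq directly as a stand-in for the tests.utils wrapper:

    import dask.dataframe as dd
    import pandas as pd

    # Stand-in for tests.utils.assert_eq.
    from dask.dataframe.utils import assert_eq

    pdf = pd.DataFrame({"a": [1, 2, 3], "b": [1.5, None, 3.5]})
    ddf = dd.from_pandas(pdf, npartitions=2)

    # The Dask frame stays lazy through intermediate transformations;
    # assert_eq triggers the computation when it compares the results.
    assert_eq(ddf.fillna(0.0), pdf.fillna(0.0), check_dtype=False)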

tests/integration/test_complex.py

Lines changed: 2 additions & 5 deletions
@@ -28,9 +28,6 @@ def test_complex_query(c):
             lhs.name = rhs.max_name AND
             lhs.x = rhs.max_x
         """
-    )
+    ).compute()

-    # should not fail
-    df = result.compute()
-
-    assert len(df) > 0
+    assert len(result) > 0

tests/integration/test_create.py

Lines changed: 31 additions & 52 deletions
@@ -1,9 +1,9 @@
 import dask.dataframe as dd
 import pandas as pd
 import pytest
-from pandas.testing import assert_frame_equal

 import dask_sql
+from tests.utils import assert_eq


 @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
@@ -26,12 +26,9 @@ def test_create_from_csv(c, df, temporary_data_file, gpu):
         """
         SELECT * FROM new_table
         """
-    ).compute()
-
-    if gpu:
-        result_df = result_df.to_pandas()
+    )

-    assert_frame_equal(result_df, df)
+    assert_eq(result_df, df)


 @pytest.mark.parametrize(
@@ -60,12 +57,9 @@ def test_cluster_memory(client, c, df, gpu):
         """
         SELECT * FROM new_table
         """
-    ).compute()
-
-    if gpu:
-        return_df = return_df.to_pandas()
+    )

-    assert_frame_equal(df, return_df)
+    assert_eq(df, return_df)

     client.unpublish_dataset("df")

@@ -91,12 +85,9 @@ def test_create_from_csv_persist(c, df, temporary_data_file, gpu):
         """
         SELECT * FROM new_table
         """
-    ).compute()
-
-    if gpu:
-        return_df = return_df.to_pandas()
+    )

-    assert_frame_equal(df, return_df)
+    assert_eq(df, return_df)


 def test_wrong_create(c):
@@ -139,9 +130,9 @@ def test_create_from_query(c, df):
         """
         SELECT * FROM new_table
         """
-    ).compute()
+    )

-    assert_frame_equal(df, return_df)
+    assert_eq(df, return_df)

     c.sql(
         """
@@ -157,9 +148,9 @@ def test_create_from_query(c, df):
         """
         SELECT * FROM new_table
         """
-    ).compute()
+    )

-    assert_frame_equal(df, return_df)
+    assert_eq(df, return_df)


 @pytest.mark.parametrize(
@@ -210,27 +201,19 @@ def test_view_table_persist(c, temporary_data_file, df, gpu):
         """
     )

-    from_view = c.sql("SELECT c FROM count_view").compute()
-    from_table = c.sql("SELECT c FROM count_table").compute()
-
-    if gpu:
-        from_view = from_view.to_pandas()
-        from_table = from_table.to_pandas()
+    from_view = c.sql("SELECT c FROM count_view")
+    from_table = c.sql("SELECT c FROM count_table")

-    assert_frame_equal(from_view, pd.DataFrame({"c": [700]}))
-    assert_frame_equal(from_table, pd.DataFrame({"c": [700]}))
+    assert_eq(from_view, pd.DataFrame({"c": [700]}))
+    assert_eq(from_table, pd.DataFrame({"c": [700]}))

     df.iloc[:10].to_csv(temporary_data_file, index=False)

-    from_view = c.sql("SELECT c FROM count_view").compute()
-    from_table = c.sql("SELECT c FROM count_table").compute()
+    from_view = c.sql("SELECT c FROM count_view")
+    from_table = c.sql("SELECT c FROM count_table")

-    if gpu:
-        from_view = from_view.to_pandas()
-        from_table = from_table.to_pandas()
-
-    assert_frame_equal(from_view, pd.DataFrame({"c": [10]}))
-    assert_frame_equal(from_table, pd.DataFrame({"c": [700]}))
+    assert_eq(from_view, pd.DataFrame({"c": [10]}))
+    assert_eq(from_table, pd.DataFrame({"c": [700]}))


 def test_replace_and_error(c, temporary_data_file, df):
@@ -244,8 +227,8 @@ def test_replace_and_error(c, temporary_data_file, df):
         """
     )

-    assert_frame_equal(
-        c.sql("SELECT a FROM new_table").compute(),
+    assert_eq(
+        c.sql("SELECT a FROM new_table"),
         pd.DataFrame({"a": [1]}),
         check_dtype=False,
     )
@@ -271,8 +254,8 @@ def test_replace_and_error(c, temporary_data_file, df):
         """
     )

-    assert_frame_equal(
-        c.sql("SELECT a FROM new_table").compute(),
+    assert_eq(
+        c.sql("SELECT a FROM new_table"),
         pd.DataFrame({"a": [1]}),
         check_dtype=False,
     )
@@ -287,8 +270,8 @@ def test_replace_and_error(c, temporary_data_file, df):
         """
     )

-    assert_frame_equal(
-        c.sql("SELECT a FROM new_table").compute(),
+    assert_eq(
+        c.sql("SELECT a FROM new_table"),
         pd.DataFrame({"a": [2]}),
         check_dtype=False,
     )
@@ -308,8 +291,8 @@ def test_replace_and_error(c, temporary_data_file, df):
         """
     )

-    assert_frame_equal(
-        c.sql("SELECT a FROM new_table").compute(),
+    assert_eq(
+        c.sql("SELECT a FROM new_table"),
         pd.DataFrame({"a": [3]}),
         check_dtype=False,
     )
@@ -338,8 +321,8 @@ def test_replace_and_error(c, temporary_data_file, df):
         """
     )

-    assert_frame_equal(
-        c.sql("SELECT a FROM new_table").compute(),
+    assert_eq(
+        c.sql("SELECT a FROM new_table"),
         pd.DataFrame({"a": [3]}),
         check_dtype=False,
     )
@@ -355,13 +338,9 @@ def test_replace_and_error(c, temporary_data_file, df):
         """
     )

-    result_df = c.sql(
-        """
-        SELECT * FROM new_table
-        """
-    ).compute()
+    result_df = c.sql("SELECT * FROM new_table")

-    assert_frame_equal(result_df, df)
+    assert_eq(result_df, df)


 def test_drop(c):

tests/integration/test_distributeby.py

Lines changed: 2 additions & 13 deletions
@@ -2,24 +2,13 @@
 import pandas as pd
 import pytest

-try:
-    import cudf
-except ImportError:
-    cudf = None
-

 @pytest.mark.parametrize("gpu", [False, pytest.param(True, marks=pytest.mark.gpu)])
 def test_distribute_by(c, gpu):
-
-    if gpu:
-        xd = cudf
-    else:
-        xd = pd
-
-    df = xd.DataFrame({"id": [0, 1, 2, 1, 2, 3], "val": [0, 1, 2, 1, 2, 3]})
+    df = pd.DataFrame({"id": [0, 1, 2, 1, 2, 3], "val": [0, 1, 2, 1, 2, 3]})
     ddf = dd.from_pandas(df, npartitions=2)

-    c.create_table("test", ddf)
+    c.create_table("test", ddf, gpu=gpu)
     partitioned_ddf = c.sql(
         """
         SELECT
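The try/except cudf import and the xd alias disappear because the gpu flag on Context.create_table now handles backend selection, presumably converting the input to a dask-cudf frame on registration, so the test body stays backend-agnostic. The pattern in isolation (assuming cudf/dask-cudf are installed when gpu=True):

    import dask.dataframe as dd
    import pandas as pd
    from dask_sql import Context

    c = Context()
    df = pd.DataFrame({"id": [0, 1, 2, 1], "val": [0, 1, 2, 1]})
    ddf = dd.from_pandas(df, npartitions=2)

    # gpu=False registers the frame as-is; gpu=True asks dask-sql to move
    # it to the GPU, keeping the test code free of explicit cudf imports.
    c.create_table("test", ddf, gpu=False)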

0 commit comments

Comments
 (0)