Skip to content

Commit 990dde9

Browse files
committed
Merge branch 'main' of github.com:dask-contrib/dask-sql into bump-codecov-step
2 parents 7cfd126 + eb8c326 commit 990dde9

2 files changed

Lines changed: 79 additions & 75 deletions

File tree

tests/integration/test_analyze.py

Lines changed: 20 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,31 @@
11
import pandas as pd
22

3+
from dask_sql.mappings import python_to_sql_type
34
from tests.utils import assert_eq
45

56

67
def test_analyze(c, df):
78
result_df = c.sql("ANALYZE TABLE df COMPUTE STATISTICS FOR ALL COLUMNS")
89

9-
expected_df = pd.DataFrame(
10-
{
11-
"a": [
12-
700.0,
13-
df.a.mean(),
14-
df.a.std(),
15-
df.a.min(),
16-
# Dask's approx quantiles do not match up with pandas and must be specified explicitly
17-
2.0,
18-
2.0,
19-
3.0,
20-
df.a.max(),
21-
"double",
22-
"a",
23-
],
24-
"b": [
25-
700.0,
26-
df.b.mean(),
27-
df.b.std(),
28-
df.b.min(),
29-
# Dask's approx quantiles do not match up with pandas and must be specified explicitly
30-
2.73108,
31-
5.20286,
32-
7.60595,
33-
df.b.max(),
34-
"double",
35-
"b",
36-
],
37-
},
38-
index=[
39-
"count",
40-
"mean",
41-
"std",
42-
"min",
43-
"25%",
44-
"50%",
45-
"75%",
46-
"max",
47-
"data_type",
48-
"col_name",
49-
],
10+
# extract table and compute stats with Dask manually
11+
expected_df = (
12+
c.sql("SELECT * FROM df")
13+
.describe()
14+
.append(
15+
pd.Series(
16+
{
17+
col: str(python_to_sql_type(df[col].dtype)).lower()
18+
for col in df.columns
19+
},
20+
name="data_type",
21+
)
22+
)
23+
.append(
24+
pd.Series(
25+
{col: col for col in df.columns},
26+
name="col_name",
27+
)
28+
)
5029
)
5130

5231
assert_eq(result_df, expected_df)

tests/integration/test_sample.py

Lines changed: 59 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,71 @@
1-
def test_sample(c, df):
2-
# Fixed sample, check absolute numbers
3-
return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (10)")
4-
5-
assert len(return_df) == 234
6-
7-
return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (11)")
8-
9-
assert len(return_df) == 468 # Yes, that is horrible, but at least fast...
10-
11-
return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (50) REPEATABLE (10)")
12-
13-
assert len(return_df) == 234
14-
15-
return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (0.001) REPEATABLE (10)")
16-
17-
assert len(return_df) == 0
1+
import numpy as np
182

19-
return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (99.999) REPEATABLE (10)")
3+
from tests.utils import assert_eq
204

21-
assert len(return_df) == len(df)
225

23-
return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (50) REPEATABLE (10)")
6+
def get_system_sample(df, fraction, seed):
7+
random_state = np.random.RandomState(seed)
8+
random_choice = random_state.choice(
9+
[True, False],
10+
size=df.npartitions,
11+
replace=True,
12+
p=[fraction, 1 - fraction],
13+
)
2414

25-
assert len(return_df) == 350
15+
if random_choice.any():
16+
df = df.partitions[random_choice]
17+
else:
18+
df = df.head(0, compute=False)
2619

27-
return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (70) REPEATABLE (10)")
20+
return df
2821

29-
assert len(return_df) == 490
3022

31-
return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (0.001) REPEATABLE (10)")
32-
33-
assert len(return_df) == 0
34-
35-
return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (99.999) REPEATABLE (10)")
36-
37-
assert len(return_df) == len(df)
38-
39-
# Not fixed sample, can only check boundaries
23+
def test_sample(c, df):
24+
ddf = c.sql("SELECT * FROM df")
25+
26+
# fixed system samples
27+
assert_eq(
28+
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (10)"),
29+
get_system_sample(ddf, 0.20, 10),
30+
)
31+
assert_eq(
32+
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (11)"),
33+
get_system_sample(ddf, 0.20, 11),
34+
)
35+
assert_eq(
36+
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (50) REPEATABLE (10)"),
37+
get_system_sample(ddf, 0.50, 10),
38+
)
39+
assert_eq(
40+
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (0.001) REPEATABLE (10)"),
41+
get_system_sample(ddf, 0.00001, 10),
42+
)
43+
assert_eq(
44+
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (99.999) REPEATABLE (10)"),
45+
get_system_sample(ddf, 0.99999, 10),
46+
)
47+
48+
# fixed bernoulli samples
49+
assert_eq(
50+
c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (50) REPEATABLE (10)"),
51+
ddf.sample(frac=0.50, replace=False, random_state=10),
52+
)
53+
assert_eq(
54+
c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (70) REPEATABLE (10)"),
55+
ddf.sample(frac=0.70, replace=False, random_state=10),
56+
)
57+
assert_eq(
58+
c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (0.001) REPEATABLE (10)"),
59+
ddf.sample(frac=0.00001, replace=False, random_state=10),
60+
)
61+
assert_eq(
62+
c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (99.999) REPEATABLE (10)"),
63+
ddf.sample(frac=0.99999, replace=False, random_state=10),
64+
)
65+
66+
# variable samples, can only check boundaries
4067
return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (50)")
41-
4268
assert len(return_df) >= 0 and len(return_df) <= len(df)
4369

4470
return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (50)")
45-
4671
assert len(return_df) >= 0 and len(return_df) <= len(df)

0 commit comments

Comments
 (0)