Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 20 additions & 41 deletions tests/integration/test_analyze.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,31 @@
import pandas as pd

from dask_sql.mappings import python_to_sql_type
from tests.utils import assert_eq


def test_analyze(c, df):
result_df = c.sql("ANALYZE TABLE df COMPUTE STATISTICS FOR ALL COLUMNS")

expected_df = pd.DataFrame(
{
"a": [
700.0,
df.a.mean(),
df.a.std(),
df.a.min(),
# Dask's approx quantiles do not match up with pandas and must be specified explicitly
2.0,
2.0,
3.0,
df.a.max(),
"double",
"a",
],
"b": [
700.0,
df.b.mean(),
df.b.std(),
df.b.min(),
# Dask's approx quantiles do not match up with pandas and must be specified explicitly
2.73108,
5.20286,
7.60595,
df.b.max(),
"double",
"b",
],
},
index=[
"count",
"mean",
"std",
"min",
"25%",
"50%",
"75%",
"max",
"data_type",
"col_name",
],
# extract table and compute stats with Dask manually
expected_df = (
c.sql("SELECT * FROM df")
.describe()
.append(
pd.Series(
{
col: str(python_to_sql_type(df[col].dtype)).lower()
for col in df.columns
},
name="data_type",
)
)
.append(
pd.Series(
{col: col for col in df.columns},
name="col_name",
)
)
)

assert_eq(result_df, expected_df)
Expand Down
93 changes: 59 additions & 34 deletions tests/integration/test_sample.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,71 @@
def test_sample(c, df):
# Fixed sample, check absolute numbers
return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (10)")

assert len(return_df) == 234

return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (11)")

assert len(return_df) == 468 # Yes, that is horrible, but at least fast...

return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (50) REPEATABLE (10)")

assert len(return_df) == 234

return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (0.001) REPEATABLE (10)")

assert len(return_df) == 0
import numpy as np

return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (99.999) REPEATABLE (10)")
from tests.utils import assert_eq

assert len(return_df) == len(df)

return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (50) REPEATABLE (10)")
def get_system_sample(df, fraction, seed):
random_state = np.random.RandomState(seed)
random_choice = random_state.choice(
[True, False],
size=df.npartitions,
replace=True,
p=[fraction, 1 - fraction],
)

assert len(return_df) == 350
if random_choice.any():
df = df.partitions[random_choice]
else:
df = df.head(0, compute=False)

return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (70) REPEATABLE (10)")
return df

assert len(return_df) == 490

return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (0.001) REPEATABLE (10)")

assert len(return_df) == 0

return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (99.999) REPEATABLE (10)")

assert len(return_df) == len(df)

# Not fixed sample, can only check boundaries
def test_sample(c, df):
ddf = c.sql("SELECT * FROM df")

# fixed system samples
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (10)"),
get_system_sample(ddf, 0.20, 10),
)
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (20) REPEATABLE (11)"),
get_system_sample(ddf, 0.20, 11),
)
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (50) REPEATABLE (10)"),
get_system_sample(ddf, 0.50, 10),
)
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (0.001) REPEATABLE (10)"),
get_system_sample(ddf, 0.00001, 10),
)
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (99.999) REPEATABLE (10)"),
get_system_sample(ddf, 0.99999, 10),
)

# fixed bernoulli samples
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (50) REPEATABLE (10)"),
ddf.sample(frac=0.50, replace=False, random_state=10),
)
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (70) REPEATABLE (10)"),
ddf.sample(frac=0.70, replace=False, random_state=10),
)
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (0.001) REPEATABLE (10)"),
ddf.sample(frac=0.00001, replace=False, random_state=10),
)
assert_eq(
c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (99.999) REPEATABLE (10)"),
ddf.sample(frac=0.99999, replace=False, random_state=10),
)

# variable samples, can only check boundaries
return_df = c.sql("SELECT * FROM df TABLESAMPLE BERNOULLI (50)")

assert len(return_df) >= 0 and len(return_df) <= len(df)

return_df = c.sql("SELECT * FROM df TABLESAMPLE SYSTEM (50)")

assert len(return_df) >= 0 and len(return_df) <= len(df)