forked from dask-contrib/dask-sql
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathanalyze_table.py
More file actions
70 lines (54 loc) · 2.24 KB
/
analyze_table.py
File metadata and controls
70 lines (54 loc) · 2.24 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from typing import TYPE_CHECKING
import dask.dataframe as dd
import pandas as pd
from dask_sql.datacontainer import ColumnContainer, DataContainer
from dask_sql.mappings import python_to_sql_type
from dask_sql.physical.rel.base import BaseRelPlugin
if TYPE_CHECKING:
import dask_sql
from dask_planner.rust import LogicalPlan
class AnalyzeTablePlugin(BaseRelPlugin):
    """
    Report summary information (like mean, max etc.) for a table,
    either for every column or for an explicitly listed subset.

    The SQL is:

        ANALYZE TABLE <table> COMPUTE STATISTICS FOR [ALL COLUMNS | COLUMNS a, b, ...]

    The result is itself a table, which is assembled on the fly.

    Please note: although the syntax closely resembles e.g.
    [the spark version](https://spark.apache.org/docs/3.0.0/sql-ref-syntax-aux-analyze-table.html),
    this call does not help with query optimization (as the spark call would do),
    because that is currently not implemented in dask-sql.
    """

    class_name = "AnalyzeTable"

    def convert(self, rel: "LogicalPlan", context: "dask_sql.Context") -> DataContainer:
        """
        Build the statistics table for an ``ANALYZE TABLE`` statement.

        The target table is looked up in ``context.schema`` using the schema
        name carried by the plan node, falling back to ``context.schema_name``
        when the node does not provide one. If the plan node lists no
        columns, all columns of the table's column container are analyzed.

        The returned container wraps the concatenation of three frames:
        the output of ``describe()`` on the selected columns, one row
        labelled ``data_type`` and one row labelled ``col_name``.
        """
        plan_node = rel.analyze_table()

        schema_name = plan_node.getSchemaName() or context.schema_name
        table_name = plan_node.getTableName()
        source = context.schema[schema_name].tables[table_name]

        # An empty column list means "analyze everything".
        requested = plan_node.getColumns() or source.column_container.columns

        # Translate the frontend (SQL-visible) names into backend names once
        # and reuse the result for every frame below.
        to_backend = source.column_container.get_backend_by_frontend_name
        frame = source.df
        backend_names = [to_backend(name) for name in requested]

        # Summary statistics as produced by describe() on the selection.
        described = frame[backend_names].describe()

        # One row holding the lower-cased SQL type name of each column.
        type_row = pd.DataFrame(
            {
                backend: str(python_to_sql_type(frame[backend].dtype)).lower()
                for backend in backend_names
            },
            index=["data_type"],
        )

        # One row holding the frontend name each backend column stands for.
        name_row = pd.DataFrame(
            dict(zip(backend_names, requested)),
            index=["col_name"],
        )

        statistics = dd.concat([described, type_row, name_row])

        return DataContainer(statistics, ColumnContainer(statistics.columns))