Skip to content

Commit ecc3da3

Browse files
rajagurunathGurunath LankupalliVenugopalnils-braun
authored
ML model improvement : Added SHOW MODELS and DESCRIBE MODEL (#185)
* ML model improvement : Adding "SHOW MODELS and DESCRIBE MODEL" Author: rajagurunath <[email protected]> Date: Mon May 24 02:37:40 2021 +0530 * fix typo * ML model improvement : refactoring for PR * ML model improvement : Adding stmts in notebook * ML model improvement : Adding stmts in notebook * ML model improvement : also test the non-happy path Co-authored-by: Gurunath LankupalliVenugopal <[email protected]> Co-authored-by: Nils Braun <[email protected]>
1 parent 0d34cef commit ecc3da3

11 files changed

Lines changed: 304 additions & 9 deletions

File tree

dask_sql/context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ def __init__(self):
105105
RelConverter.add_plugin_class(custom.ShowColumnsPlugin, replace=False)
106106
RelConverter.add_plugin_class(custom.ShowSchemasPlugin, replace=False)
107107
RelConverter.add_plugin_class(custom.ShowTablesPlugin, replace=False)
108+
RelConverter.add_plugin_class(custom.ShowModelsPlugin, replace=False)
109+
RelConverter.add_plugin_class(custom.ShowModelParamsPlugin, replace=False)
108110

109111
RexConverter.add_plugin_class(core.RexCallPlugin, replace=False)
110112
RexConverter.add_plugin_class(core.RexInputRefPlugin, replace=False)

dask_sql/physical/rel/custom/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@
33
from .create_model import CreateModelPlugin
44
from .create_table import CreateTablePlugin
55
from .create_table_as import CreateTableAsPlugin
6+
from .describe_model import ShowModelParamsPlugin
67
from .drop_model import DropModelPlugin
78
from .drop_table import DropTablePlugin
89
from .predict import PredictModelPlugin
910
from .schemas import ShowSchemasPlugin
11+
from .show_models import ShowModelsPlugin
1012
from .tables import ShowTablesPlugin
1113

1214
__all__ = [
@@ -20,4 +22,6 @@
2022
ShowColumnsPlugin,
2123
ShowSchemasPlugin,
2224
ShowTablesPlugin,
25+
ShowModelsPlugin,
26+
ShowModelParamsPlugin,
2327
]
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import dask.dataframe as dd
2+
import pandas as pd
3+
4+
from dask_sql.datacontainer import ColumnContainer, DataContainer
5+
from dask_sql.physical.rel.base import BaseRelPlugin
6+
7+
8+
class ShowModelParamsPlugin(BaseRelPlugin):
9+
"""
10+
Show all Params used to train a given model along with the columns
11+
used for training.
12+
The SQL is:
13+
14+
DESCRIBE MODEL <model_name>
15+
16+
The result is also a table, although it is created on the fly.
17+
"""
18+
19+
class_name = "com.dask.sql.parser.SqlShowModelParams"
20+
21+
def convert(
22+
self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
23+
) -> DataContainer:
24+
model_name = str(sql.getModelName().getIdentifier())
25+
if model_name not in context.models:
26+
raise RuntimeError(f"A model with the name {model_name} is not present.")
27+
model, training_columns = context.models[model_name]
28+
model_params = model.get_params()
29+
model_params["training_columns"] = training_columns.tolist()
30+
df = pd.DataFrame.from_dict(model_params, orient="index", columns=["Params"])
31+
cc = ColumnContainer(df.columns)
32+
dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)
33+
return dc
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import dask.dataframe as dd
2+
import pandas as pd
3+
4+
from dask_sql.datacontainer import ColumnContainer, DataContainer
5+
from dask_sql.physical.rel.base import BaseRelPlugin
6+
7+
8+
class ShowModelsPlugin(BaseRelPlugin):
9+
"""
10+
Show all MODELS currently registered/trained.
11+
The SQL is:
12+
13+
SHOW MODELS
14+
15+
The result is also a table, although it is created on the fly.
16+
"""
17+
18+
class_name = "com.dask.sql.parser.SqlShowModels"
19+
20+
def convert(
21+
self, sql: "org.apache.calcite.sql.SqlNode", context: "dask_sql.Context"
22+
) -> DataContainer:
23+
24+
df = pd.DataFrame({"Models": list(context.models.keys())})
25+
26+
cc = ColumnContainer(df.columns)
27+
dc = DataContainer(dd.from_pandas(df, npartitions=1), cc)
28+
return dc

dask_sql/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from collections import defaultdict
66
from contextlib import contextmanager
77
from datetime import datetime
8-
from typing import Any, Dict, List
8+
from typing import Any, Dict, List, Tuple
99
from unittest.mock import patch
1010
from uuid import uuid4
1111

notebooks/Feature Overview.ipynb

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,26 @@
285285
"SHOW COLUMNS FROM \"iris\""
286286
]
287287
},
288+
{
289+
"cell_type": "code",
290+
"execution_count": null,
291+
"metadata": {},
292+
"outputs": [],
293+
"source": [
294+
"%%sql\n",
295+
"DESCRIBE iris"
296+
]
297+
},
298+
{
299+
"cell_type": "code",
300+
"execution_count": null,
301+
"metadata": {},
302+
"outputs": [],
303+
"source": [
304+
"%%sql\n",
305+
"DESCRIBE TABLE iris"
306+
]
307+
},
288308
{
289309
"cell_type": "markdown",
290310
"metadata": {},
@@ -507,6 +527,26 @@
507527
")"
508528
]
509529
},
530+
{
531+
"cell_type": "code",
532+
"execution_count": null,
533+
"metadata": {},
534+
"outputs": [],
535+
"source": [
536+
"%%sql\n",
537+
"SHOW MODELS"
538+
]
539+
},
540+
{
541+
"cell_type": "code",
542+
"execution_count": null,
543+
"metadata": {},
544+
"outputs": [],
545+
"source": [
546+
"%%sql\n",
547+
"DESCRIBE MODEL my_model"
548+
]
549+
},
510550
{
511551
"cell_type": "code",
512552
"execution_count": null,
@@ -571,13 +611,20 @@
571611
"\"\"\").compute() \n",
572612
"t.set_index([\"target\", \"species\"]).unstack(\"species\").number.plot.bar()"
573613
]
614+
},
615+
{
616+
"cell_type": "code",
617+
"execution_count": null,
618+
"metadata": {},
619+
"outputs": [],
620+
"source": []
574621
}
575622
],
576623
"metadata": {
577624
"kernelspec": {
578625
"display_name": "Python 3",
579626
"language": "python",
580-
"name": "python3"
627+
"name": "Python 3"
581628
},
582629
"language_info": {
583630
"codemirror_mode": {

planner/src/main/codegen/config.fmpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ data: {
3636
"com.dask.sql.parser.SqlShowColumns",
3737
"com.dask.sql.parser.SqlShowSchemas",
3838
"com.dask.sql.parser.SqlShowTables"
39+
"com.dask.sql.parser.SqlShowModels"
3940
]
4041

4142
# List of keywords.
@@ -49,6 +50,7 @@ data: {
4950
"SCHEMAS"
5051
"STATISTICS"
5152
"TABLES"
53+
"MODELS"
5254
]
5355

5456
# The keywords can only be used in a specific context,
@@ -68,11 +70,12 @@ data: {
6870
# List of methods for parsing custom SQL statements
6971
statementParserMethods: [
7072
"SqlAnalyzeTable()"
71-
"SqlDescribeTable()"
7273
"SqlShowColumns()"
7374
"SqlShowSchemas()"
7475
"SqlShowTables()"
7576
"SqlPredictModel()"
77+
"SqlShowModels()"
78+
"SqlDescribeTableOrModel()"
7679
]
7780

7881
createStatementParserMethods: [

planner/src/main/codegen/includes/show.ftl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,19 +46,31 @@ SqlNode SqlShowColumns() :
4646
}
4747
}
4848

49-
// DESCRIBE "table"
50-
SqlNode SqlDescribeTable() :
49+
// DESCRIBE <table_name>
50+
// DESCRIBE TABLE <table_name>
51+
// DESCRIBE MODEL <model_name>
52+
SqlNode SqlDescribeTableOrModel() :
5153
{
5254
final Span s;
5355
final SqlIdentifier schemaName;
5456
final SqlIdentifier tableName;
57+
final SqlModelIdentifier modelName;
5558
}
5659
{
5760
<DESCRIBE> { s = span(); }
58-
tableName = CompoundTableIdentifier()
59-
{
60-
return new SqlShowColumns(s.end(this), tableName);
61-
}
61+
(
62+
LOOKAHEAD(2)
63+
modelName = ModelIdentifier()
64+
{
65+
return new SqlShowModelParams(s.end(this), modelName);
66+
}
67+
|
68+
[ <TABLE> ]
69+
tableName = CompoundTableIdentifier()
70+
{
71+
return new SqlShowColumns(s.end(this), tableName);
72+
}
73+
)
6274
}
6375

6476
// ANALYZE TABLE table_identifier COMPUTE STATISTICS [ FOR COLUMNS col [ , ... ] | FOR ALL COLUMNS ]
@@ -86,3 +98,15 @@ SqlNode SqlAnalyzeTable() :
8698
return new SqlAnalyzeTable(s.end(this), tableName, columnList);
8799
}
88100
}
101+
102+
// SHOW MODELS
103+
SqlNode SqlShowModels() :
104+
{
105+
final Span s;
106+
}
107+
{
108+
<SHOW> { s = span(); } <MODELS>
109+
{
110+
return new SqlShowModels(s.end(this));
111+
}
112+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package com.dask.sql.parser;
2+
import java.util.List;
3+
import org.apache.calcite.sql.SqlCall;
4+
import org.apache.calcite.sql.SqlIdentifier;
5+
import org.apache.calcite.sql.SqlNode;
6+
import org.apache.calcite.sql.SqlNodeList;
7+
import org.apache.calcite.sql.SqlOperator;
8+
import org.apache.calcite.sql.SqlWriter;
9+
import org.apache.calcite.sql.parser.SqlParserPos;
10+
11+
12+
public class SqlShowModelParams extends SqlCall {
13+
final SqlModelIdentifier modelName;
14+
15+
public SqlShowModelParams(SqlParserPos pos, SqlModelIdentifier modelName) {
16+
super(pos);
17+
this.modelName = modelName;
18+
}
19+
20+
@Override
21+
public SqlOperator getOperator() {
22+
throw new UnsupportedOperationException();
23+
}
24+
25+
@Override
26+
public List<SqlNode> getOperandList() {
27+
throw new UnsupportedOperationException();
28+
}
29+
30+
@Override
31+
public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
32+
writer.keyword("DESCRIBE");
33+
writer.keyword("MODEL");
34+
this.modelName.unparse(writer, leftPrec, rightPrec);
35+
}
36+
public SqlModelIdentifier getModelName() {
37+
return this.modelName;
38+
}
39+
40+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package com.dask.sql.parser;
2+
import java.util.ArrayList;
3+
import java.util.List;
4+
import org.apache.calcite.sql.SqlCall;
5+
import org.apache.calcite.sql.SqlOperator;
6+
import org.apache.calcite.sql.SqlWriter;
7+
import org.apache.calcite.sql.SqlNode;
8+
import org.apache.calcite.sql.parser.SqlParserPos;
9+
10+
public class SqlShowModels extends SqlCall {
11+
12+
public SqlShowModels(SqlParserPos pos) {
13+
super(pos);
14+
}
15+
public SqlOperator getOperator() {
16+
throw new UnsupportedOperationException();
17+
}
18+
19+
public List<SqlNode> getOperandList() {
20+
ArrayList<SqlNode> operandList = new ArrayList<SqlNode>();
21+
return operandList;
22+
}
23+
24+
@Override
25+
public void unparse(SqlWriter writer,int leftPrec, int rightPrec) {
26+
writer.keyword("SHOW");
27+
writer.keyword("MODELS");
28+
29+
}
30+
}

0 commit comments

Comments
 (0)