diff --git a/dask_sql/server/app.py b/dask_sql/server/app.py
index 619edf2b5..34217629e 100644
--- a/dask_sql/server/app.py
+++ b/dask_sql/server/app.py
@@ -10,6 +10,7 @@
 from uvicorn import Config, Server
 
 from dask_sql.context import Context
+from dask_sql.server.presto_jdbc import create_meta_data
 from dask_sql.server.responses import DataResults, ErrorResults, QueryResults
 
 app = FastAPI()
@@ -74,6 +75,12 @@ async def query(request: Request):
     """
     try:
         sql = (await request.body()).decode().strip()
+        # required for PrestoDB JDBC driver compatibility:
+        # rewrites queries against the unsupported `system` catalog into queries against the
+        # `system_jdbc` schema created by `create_meta_data(context)` when `jdbc_metadata=True`
+        # TODO: explore Trino, which should make JDBC compatibility easier but requires
+        # changing response headers (see https://github.com/dask-contrib/dask-sql/pull/351)
+        sql = sql.replace("system.jdbc", "system_jdbc")
         df = request.app.c.sql(sql)
 
         if df is None:
@@ -102,6 +109,7 @@ def run_server(
     startup=False,
     log_level=None,
     blocking: bool = True,
+    jdbc_metadata: bool = False,
 ):  # pragma: no cover
     """
     Run a HTTP server for answering SQL queries using ``dask-sql``.
@@ -128,6 +136,8 @@
         log_level: (:obj:`str`): The log level of the server and dask-sql
         blocking: (:obj:`bool`): If running in an environment with an event loop (e.g. a jupyter notebook),
             do not block. The server can be stopped with `context.stop_server()` afterwards.
+        jdbc_metadata: (:obj:`bool`): If enabled, create JDBC metadata tables using the schemas and
+            tables in the current dask_sql context.
 
     Example:
         It is possible to run an SQL server by using the CLI script ``dask-sql-server``
@@ -179,6 +189,8 @@
     """
     _init_app(app, context=context, client=client)
+    if jdbc_metadata:
+        create_meta_data(context)
 
     if startup:
         app.c.sql("SELECT 1 + 1").compute()
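
A minimal usage sketch (not part of the diff) of the new `jdbc_metadata` flag added above; the table name and data are invented for illustration, and `run_server`'s default host/port are assumed:

import pandas as pd

from dask_sql import Context
from dask_sql.server.app import run_server

c = Context()
# register something for the JDBC driver to discover (example data only)
c.create_table("my_table", pd.DataFrame({"id": [1, 2], "name": ["a", "b"]}))

# builds the system_jdbc metadata tables before serving; queries against
# system.jdbc.* from the PrestoDB JDBC driver are rewritten to system_jdbc.*
# by the /v1/statement endpoint shown above
run_server(context=c, jdbc_metadata=True)
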
diff --git a/dask_sql/server/presto_jdbc.py b/dask_sql/server/presto_jdbc.py
new file mode 100644
index 000000000..d3c3880cb
--- /dev/null
+++ b/dask_sql/server/presto_jdbc.py
@@ -0,0 +1,149 @@
+import logging
+
+import pandas as pd
+
+from dask_sql.context import Context
+
+logger = logging.getLogger(__name__)
+
+
+def create_meta_data(c: Context):
+    """
+    Creates the schema, table and column data for the prestodb JDBC driver so that data can be
+    viewed in a database tool like DBeaver. It doesn't create a real catalog entry, although JDBC
+    expects one, as dask-sql doesn't support catalogs; for both catalogs and procedures only empty
+    placeholder tables are created.
+
+    The metadata appears in a separate schema called ``system_jdbc``, largely because the JDBC
+    driver tries to access ``system.jdbc``; the name is unusual enough that it shouldn't clash
+    with other schemas.
+
+    A replacement is therefore required in the ``/v1/statement`` endpoint to change ``system.jdbc``
+    to ``system_jdbc`` and to ignore ORDER BY statements sent by the driver (see the query
+    handling in ``app.py``).
+
+    :param c: Context containing created tables
+    :return:
+    """
+
+    if c is None:
+        logger.warning("Context None: jdbc meta data not created")
+        return
+    catalog = ""
+    system_schema = "system_jdbc"
+    c.create_schema(system_schema)
+
+    # TODO: add support for catalogs in the presto interface
+    # see https://github.com/dask-contrib/dask-sql/pull/351
+    # if catalog and len(catalog.strip()) > 0:
+    #     catalogs = pd.DataFrame().append(create_catalog_row(catalog), ignore_index=True)
+    #     c.create_table("catalogs", catalogs, schema_name=system_schema)
+
+    schemas = pd.DataFrame().append(create_schema_row(), ignore_index=True)
+    c.create_table("schemas", schemas, schema_name=system_schema)
+    schema_rows = []
+
+    tables = pd.DataFrame().append(create_table_row(), ignore_index=True)
+    c.create_table("tables", tables, schema_name=system_schema)
+    table_rows = []
+
+    columns = pd.DataFrame().append(create_column_row(), ignore_index=True)
+    c.create_table("columns", columns, schema_name=system_schema)
+    column_rows = []
+
+    for schema_name, schema in c.schema.items():
+        schema_rows.append(create_schema_row(catalog, schema_name))
+        for table_name, dc in schema.tables.items():
+            df = dc.df
+            logger.info(f"schema {schema_name}, table {table_name}, {df}")
+            table_rows.append(create_table_row(catalog, schema_name, table_name))
+            pos: int = 0
+            for column in df.columns:
+                pos = pos + 1
+                logger.debug(f"column {column}")
+                dtype = "VARCHAR"
+                if df[column].dtype == "int64" or df[column].dtype == "int":
+                    dtype = "INTEGER"
+                elif df[column].dtype == "float64" or df[column].dtype == "float":
+                    dtype = "FLOAT"
+                elif (
+                    df[column].dtype == "datetime"
+                    or df[column].dtype == "datetime64[ns]"
+                ):
+                    dtype = "TIMESTAMP"
+                column_rows.append(
+                    create_column_row(
+                        catalog,
+                        schema_name,
+                        table_name,
+                        dtype,
+                        df[column].name,
+                        str(pos),
+                    )
+                )
+
+    schemas = pd.DataFrame(schema_rows)
+    c.create_table("schemas", schemas, schema_name=system_schema)
+    tables = pd.DataFrame(table_rows)
+    c.create_table("tables", tables, schema_name=system_schema)
+    columns = pd.DataFrame(column_rows)
+    c.create_table("columns", columns, schema_name=system_schema)
+
+    logger.info(f"jdbc meta data ready for {len(table_rows)} tables")
+
+
+def create_catalog_row(catalog: str = ""):
+    return {"TABLE_CAT": catalog}
+
+
+def create_schema_row(catalog: str = "", schema: str = ""):
+    return {"TABLE_CATALOG": catalog, "TABLE_SCHEM": schema}
+
+
+def create_table_row(catalog: str = "", schema: str = "", table: str = ""):
+    return {
+        "TABLE_CAT": catalog,
+        "TABLE_SCHEM": schema,
+        "TABLE_NAME": table,
+        "TABLE_TYPE": "",
+        "REMARKS": "",
+        "TYPE_CAT": "",
+        "TYPE_SCHEM": "",
+        "TYPE_NAME": "",
+        "SELF_REFERENCING_COL_NAME": "",
+        "REF_GENERATION": "",
+    }
+
+
+def create_column_row(
+    catalog: str = "",
+    schema: str = "",
+    table: str = "",
+    dtype: str = "",
+    column: str = "",
+    pos: str = "",
+):
+    return {
+        "TABLE_CAT": catalog,
+        "TABLE_SCHEM": schema,
+        "TABLE_NAME": table,
+        "COLUMN_NAME": column,
+        "DATA_TYPE": dtype,
+        "TYPE_NAME": dtype,
+        "COLUMN_SIZE": "",
+        "BUFFER_LENGTH": "",
+        "DECIMAL_DIGITS": "",
+        "NUM_PREC_RADIX": "",
+        "NULLABLE": "",
+        "REMARKS": "",
+        "COLUMN_DEF": "",
+        "SQL_DATA_TYPE": dtype,
+        "SQL_DATETIME_SUB": "",
+        "CHAR_OCTET_LENGTH": "",
+        "ORDINAL_POSITION": pos,
+        "IS_NULLABLE": "",
+        "SCOPE_CATALOG": "",
+        "SCOPE_SCHEMA": "",
+        "SCOPE_TABLE": "",
+        "SOURCE_DATA_TYPE": "",
+        "IS_AUTOINCREMENT": "",
+        "IS_GENERATEDCOLUMN": "",
+    }
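
A sketch (not part of the diff) of `create_meta_data` used directly on a `Context`, without starting the HTTP server; the `sales` table and its columns are made up for the example:

import pandas as pd

from dask_sql import Context
from dask_sql.server.presto_jdbc import create_meta_data

c = Context()
c.create_table("sales", pd.DataFrame({"region": ["EU", "US"], "amount": [10.5, 3.2]}))

create_meta_data(c)  # populates system_jdbc.schemas / .tables / .columns

# per the dtype mapping above, 'amount' is reported as FLOAT and 'region' as VARCHAR
print(c.sql("SELECT * FROM system_jdbc.columns").compute())
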
"SCOPE_TABLE": "", + "SOURCE_DATA_TYPE": "", + "IS_AUTOINCREMENT": "", + "IS_GENERATEDCOLUMN": "", + } diff --git a/tests/integration/test_jdbc.py b/tests/integration/test_jdbc.py new file mode 100644 index 000000000..355f1a2fb --- /dev/null +++ b/tests/integration/test_jdbc.py @@ -0,0 +1,236 @@ +from time import sleep + +import pandas as pd +import pytest + +from dask_sql import Context +from dask_sql.server.app import _init_app, app +from dask_sql.server.presto_jdbc import create_meta_data + +# needed for the testclient +pytest.importorskip("requests") + +schema = "a_schema" +table = "a_table" + + +@pytest.fixture(scope="module") +def c(): + c = Context() + c.create_schema(schema) + row = create_table_row() + tables = pd.DataFrame().append(row, ignore_index=True) + tables = tables.astype({"AN_INT": "int64"}) + c.create_table(table, tables, schema_name=schema) + + yield c + + c.drop_schema(schema) + + +@pytest.fixture(scope="module") +def app_client(c): + c.sql("SELECT 1 + 1").compute() + _init_app(app, c) + # late import for the importskip + from fastapi.testclient import TestClient + + yield TestClient(app) + + app.client.close() + + +def test_jdbc_has_schema(app_client, c): + create_meta_data(c) + + check_data(app_client) + + response = app_client.post( + "/v1/statement", data="SELECT * from system.jdbc.schemas" + ) + assert response.status_code == 200 + result = get_result_or_error(app_client, response) + + assert_result(result, 2, 3) + assert result["columns"] == [ + { + "name": "TABLE_CATALOG", + "type": "varchar", + "typeSignature": {"rawType": "varchar", "arguments": []}, + }, + { + "name": "TABLE_SCHEM", + "type": "varchar", + "typeSignature": {"rawType": "varchar", "arguments": []}, + }, + ] + assert result["data"] == [ + ["", "root"], + ["", "a_schema"], + ["", "system_jdbc"], + ] + + +def test_jdbc_has_table(app_client, c): + create_meta_data(c) + check_data(app_client) + + response = app_client.post("/v1/statement", data="SELECT * from system.jdbc.tables") + assert response.status_code == 200 + result = get_result_or_error(app_client, response) + + assert_result(result, 10, 4) + assert result["data"] == [ + ["", "a_schema", "a_table", "", "", "", "", "", "", ""], + ["", "system_jdbc", "schemas", "", "", "", "", "", "", ""], + ["", "system_jdbc", "tables", "", "", "", "", "", "", ""], + ["", "system_jdbc", "columns", "", "", "", "", "", "", ""], + ] + + +def test_jdbc_has_columns(app_client, c): + create_meta_data(c) + check_data(app_client) + + response = app_client.post( + "/v1/statement", + data=f"SELECT * from system.jdbc.columns where TABLE_NAME = '{table}'", + ) + assert response.status_code == 200 + result = get_result_or_error(app_client, response) + + assert_result(result, 24, 3) + assert result["data"] == [ + [ + "", + "a_schema", + "a_table", + "A_STR", + "VARCHAR", + "VARCHAR", + "", + "", + "", + "", + "", + "", + "", + "VARCHAR", + "", + "", + "1", + "", + "", + "", + "", + "", + "", + "", + ], + [ + "", + "a_schema", + "a_table", + "AN_INT", + "INTEGER", + "INTEGER", + "", + "", + "", + "", + "", + "", + "", + "INTEGER", + "", + "", + "2", + "", + "", + "", + "", + "", + "", + "", + ], + [ + "", + "a_schema", + "a_table", + "A_FLOAT", + "FLOAT", + "FLOAT", + "", + "", + "", + "", + "", + "", + "", + "FLOAT", + "", + "", + "3", + "", + "", + "", + "", + "", + "", + "", + ], + ] + + +def assert_result(result, col_len, data_len): + assert "columns" in result + assert "data" in result + assert "error" not in result + assert len(result["columns"]) == col_len + assert 
diff --git a/tests/integration/test_server.py b/tests/integration/test_server.py
index 19f725dfb..88d08a4f7 100644
--- a/tests/integration/test_server.py
+++ b/tests/integration/test_server.py
@@ -2,6 +2,7 @@
 
 import pytest
 
+from dask_sql import Context
 from dask_sql.server.app import _init_app, app
 
 # needed for the testclient
@@ -10,7 +11,9 @@
 
 @pytest.fixture(scope="module")
 def app_client():
-    _init_app(app)
+    c = Context()
+    c.sql("SELECT 1 + 1").compute()
+    _init_app(app, c)
     # late import for the importskip
     from fastapi.testclient import TestClient
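
For completeness, a sketch (not part of the diff) of the pattern the updated fixture uses: `_init_app` now receives an explicit `Context`, which can be exercised in-process with FastAPI's TestClient and no open port:

from dask_sql import Context
from dask_sql.server.app import _init_app, app
from fastapi.testclient import TestClient

c = Context()
c.sql("SELECT 1 + 1").compute()  # warm-up query, mirroring the fixture above
_init_app(app, c)

client = TestClient(app)
response = client.post("/v1/statement", data="SELECT 1 + 1")
print(response.status_code, response.json().get("nextUri"))
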