diff --git a/continuous_integration/recipe/meta.yaml b/continuous_integration/recipe/meta.yaml
index 7c7b2e210..3098fb668 100644
--- a/continuous_integration/recipe/meta.yaml
+++ b/continuous_integration/recipe/meta.yaml
@@ -35,6 +35,7 @@ requirements:
     - prompt_toolkit >=3.0.8
     - pygments >=2.7.3
     - nest-asyncio >=1.0.0
+    - tabulate >=0.8.9
 
 test:
   commands:
diff --git a/dask_sql/cmd.py b/dask_sql/cmd.py
index d6857bd19..a7ed0c6e2 100644
--- a/dask_sql/cmd.py
+++ b/dask_sql/cmd.py
@@ -1,11 +1,19 @@
 import logging
+import os
+import sys
+import tempfile
 import traceback
 from argparse import ArgumentParser
 from functools import partial
+from typing import Union
 
 import pandas as pd
 from dask.datasets import timeseries
-from dask.distributed import Client
+from dask.distributed import Client, as_completed
+from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
+from prompt_toolkit.completion import WordCompleter
+from prompt_toolkit.history import FileHistory
+from prompt_toolkit.shortcuts import ProgressBar
 from pygments.lexers.sql import SqlLexer
 
 try:
@@ -17,6 +25,10 @@
 
 from dask_sql.context import Context
 
+meta_command_completer = WordCompleter(
+    ["\\l", "\\d?", "\\dt", "\\df", "\\de", "\\dm", "\\conninfo", "quit"]
+)
+
 
 class CompatiblePromptSession:
     """
@@ -32,17 +44,106 @@ class CompatiblePromptSession:
     """
 
     def __init__(self, lexer) -> None:  # pragma: no cover
+        # make sure everytime dask-sql uses same history file
+        kwargs = {
+            "lexer": lexer,
+            "history": FileHistory(
+                os.path.join(tempfile.gettempdir(), "dask-sql-history")
+            ),
+            "auto_suggest": AutoSuggestFromHistory(),
+            "completer": meta_command_completer,
+        }
         try:
             # Version >= 2.0.1: we can use the session object
             from prompt_toolkit import PromptSession
 
-            session = PromptSession(lexer=lexer)
+            session = PromptSession(**kwargs)
             self.prompt = session.prompt
         except ImportError:
             # Version < 2.0: there is no session object
             from prompt_toolkit.shortcuts import prompt
 
-            self.prompt = partial(prompt, lexer=lexer)
+            self.prompt = partial(prompt, **kwargs)
+
+
+def _display_markdown(content, **kwargs):
+    df = pd.DataFrame(content, **kwargs)
+    print(df.to_markdown(tablefmt="fancy_grid"))
+
+
+def _parse_meta_command(sql):
+    command, _, arg = sql.partition(" ")
+    return command, arg.strip()
+
+
+def _meta_commands(sql: str, context: Context, client: Client) -> Union[bool, Client]:
+    """
+    parses metacommands and prints their result
+    returns True if meta commands detected
+    """
+    cmd, schema_name = _parse_meta_command(sql)
+    available_commands = [
+        ["\\l", "List schemas"],
+        ["\\d?, help, ?", "Show available commands"],
+        ["\\conninfo", "Show Dask cluster info"],
+        ["\\dt [schema]", "List tables"],
+        ["\\df [schema]", "List functions"],
+        ["\\dm [schema]", "List models"],
+        ["\\de [schema]", "List experiments"],
+        ["\\dss [schema]", "Switch schema"],
+        ["\\dsc [dask scheduler address]", "Switch Dask cluster"],
+        ["quit", "Quits dask-sql-cli"],
+    ]
+    if cmd == "\\dsc":
+        # Switch Dask cluster
+        _, scheduler_address = _parse_meta_command(sql)
+        client = Client(scheduler_address)
+        return client  # pragma: no cover
+    schema_name = schema_name or context.schema_name
+    if cmd == "\\d?" or cmd == "help" or cmd == "?":
+        _display_markdown(available_commands, columns=["Commands", "Description"])
+    elif cmd == "\\l":
+        _display_markdown(context.schema.keys(), columns=["Schemas"])
+    elif cmd == "\\dt":
+        _display_markdown(context.schema[schema_name].tables.keys(), columns=["Tables"])
+    elif cmd == "\\df":
+        _display_markdown(
+            context.schema[schema_name].functions.keys(), columns=["Functions"]
+        )
+    elif cmd == "\\de":
+        _display_markdown(
+            context.schema[schema_name].experiments.keys(), columns=["Experiments"]
+        )
+    elif cmd == "\\dm":
+        _display_markdown(context.schema[schema_name].models.keys(), columns=["Models"])
+    elif cmd == "\\conninfo":
+        cluster_info = [
+            ["Dask scheduler", client.scheduler.__dict__["addr"]],
+            ["Dask dashboard", client.dashboard_link],
+            ["Cluster status", client.status],
+            ["Dask workers", len(client.cluster.workers)],
+        ]
+        _display_markdown(
+            cluster_info, columns=["components", "value"]
+        )  # pragma: no cover
+    elif cmd == "\\dss":
+        if schema_name in context.schema:
+            context.schema_name = schema_name
+        else:
+            print(f"Schema {schema_name} not available")
+    elif cmd == "quit":
+        print("Quitting dask-sql ...")
+        client.close()  # for safer side
+        sys.exit()
+    elif cmd.startswith("\\"):
+        print(
+            f"The meta command {cmd} not available, please use commands from below list"
+        )
+        _display_markdown(available_commands, columns=["Commands", "Description"])
+    else:
+        # nothing detected probably not a meta command
+        return False
+    return True
 
 
 def cmd_loop(
@@ -103,11 +204,27 @@ def cmd_loop(
         if not text:
             continue
 
-        try:
-            df = context.sql(text, return_futures=False)
-            print(df)
-        except Exception:
-            traceback.print_exc()
+        meta_command_detected = _meta_commands(text, context=context, client=client)
+        if isinstance(meta_command_detected, Client):
+            client = meta_command_detected
+
+        if not meta_command_detected:
+            try:
+                df = context.sql(text, return_futures=True)
+                if df is not None:  # some sql commands returns None
+                    df = df.persist()
+                    # Now turn it into a list of futures
+                    futures = client.futures_of(df)
+                    with ProgressBar() as pb:
+                        for _ in pb(
+                            as_completed(futures), total=len(futures), label="Executing"
+                        ):
+                            continue
+                    df = df.compute()
+                    print(df.to_markdown(tablefmt="fancy_grid"))
+
+            except Exception:
+                traceback.print_exc()
 
 
 def main():  # pragma: no cover
diff --git a/setup.py b/setup.py
index c29b41ff1..d13801e71 100755
--- a/setup.py
+++ b/setup.py
@@ -83,6 +83,7 @@ def run(self):
         "tzlocal>=2.1",
         "prompt_toolkit",
         "pygments",
+        "tabulate",
         "nest-asyncio",
         # backport for python versions without importlib.metadata
         "importlib_metadata; python_version < '3.8.0'",
diff --git a/tests/integration/test_cmd.py b/tests/integration/test_cmd.py
new file mode 100644
index 000000000..8193fb6e8
--- /dev/null
+++ b/tests/integration/test_cmd.py
@@ -0,0 +1,137 @@
+import pytest
+from mock import MagicMock, patch
+from prompt_toolkit.application import create_app_session
+from prompt_toolkit.input import create_pipe_input
+from prompt_toolkit.output import DummyOutput
+from prompt_toolkit.shortcuts import PromptSession
+
+from dask_sql.cmd import _meta_commands
+
+
+@pytest.fixture(autouse=True, scope="function")
+def mock_prompt_input():
+    pipe_input = create_pipe_input()
+    try:
+        with create_app_session(input=pipe_input, output=DummyOutput()):
+            yield pipe_input
+    finally:
+        pipe_input.close()
+
+
+def _feed_cli_with_input(
+    text,
+    editing_mode=None,
+    clipboard=None,
+    history=None,
+    multiline=False,
+    check_line_ending=True,
+    key_bindings=None,
+):
+    """
+    Create a Prompt, feed it with the given user input and return the CLI
+    object.
+    This returns a (result, Application) tuple.
+    """
+    # If the given text doesn't end with a newline, the interface won't finish.
+    if check_line_ending:
+        assert text.endswith("\r")
+
+    inp = create_pipe_input()
+
+    try:
+        inp.send_text(text)
+        session = PromptSession(
+            input=inp,
+            output=DummyOutput(),
+            editing_mode=editing_mode,
+            history=history,
+            multiline=multiline,
+            clipboard=clipboard,
+            key_bindings=key_bindings,
+        )
+
+        result = session.prompt()
+        return session.default_buffer.document, session.app
+
+    finally:
+        inp.close()
+
+
+def test_meta_commands(c, client, capsys):
+    _meta_commands("?", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Commands" in captured.out
+
+    _meta_commands("help", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Commands" in captured.out
+
+    _meta_commands("\\d?", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Commands" in captured.out
+
+    _meta_commands("\\l", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Schemas" in captured.out
+
+    _meta_commands("\\dt", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Tables" in captured.out
+
+    _meta_commands("\\dm", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Models" in captured.out
+
+    _meta_commands("\\df", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Functions" in captured.out
+
+    _meta_commands("\\de", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Experiments" in captured.out
+
+    c.create_schema("test_schema")
+    _meta_commands("\\dss test_schema", context=c, client=client)
+    assert c.schema_name == "test_schema"
+
+    _meta_commands("\\dss not_exists", context=c, client=client)
+    captured = capsys.readouterr()
+    assert "Schema not_exists not available\n" == captured.out
+
+    with pytest.raises(
+        OSError,
+        match="Timed out during handshake while "
+        "connecting to tcp://localhost:8787 after 5 s",
+    ):
+        client = _meta_commands("\\dsc localhost:8787", context=c, client=client)
+        assert client.scheduler.__dict__["addr"] == "localhost:8787"
+
+
+def test_connection_info(c, client, capsys):
+    dummy_client = MagicMock()
+    dummy_client.scheduler.__dict__["addr"] = "somewhereonearth:8787"
+    dummy_client.cluster.workers = ["worker1", "worker2"]
+
+    _meta_commands("\\conninfo", context=c, client=dummy_client)
+    captured = capsys.readouterr()
+    assert "somewhereonearth" in captured.out
+
+
+def test_quit(c, client, capsys):
+    with patch("sys.exit", return_value=lambda: "exit"):
+        _meta_commands("quit", context=c, client=client)
+        captured = capsys.readouterr()
+        assert captured.out == "Quitting dask-sql ...\n"
+
+
+def test_non_meta_commands(c, client, capsys):
+    _meta_commands("\\x", context=c, client=client)
+    captured = capsys.readouterr()
+    assert (
+        "The meta command \\x not available, please use commands from below list"
+        in captured.out
+    )
+
+    res = _meta_commands("Select 42 as answer", context=c, client=client)
+    captured = capsys.readouterr()
+    assert res is False