diff --git a/src/dbt_mcp/dbt_admin/tools.py b/src/dbt_mcp/dbt_admin/tools.py index 07984653..5cc3d6d4 100644 --- a/src/dbt_mcp/dbt_admin/tools.py +++ b/src/dbt_mcp/dbt_admin/tools.py @@ -10,6 +10,7 @@ from dbt_mcp.dbt_admin.client import DbtAdminAPIClient from dbt_mcp.prompts.prompts import get_prompt from dbt_mcp.tools.annotations import create_tool_annotations +from dbt_mcp.tools.config import DbtMcpContext from dbt_mcp.tools.definitions import ToolDefinition from dbt_mcp.tools.register import register_tools from dbt_mcp.tools.tool_names import ToolName @@ -38,10 +39,21 @@ class JobRunStatus(str, Enum): } -def create_admin_api_tool_definitions( - admin_client: DbtAdminAPIClient, admin_api_config: AdminApiConfig -) -> list[ToolDefinition]: +def get_admin_client_and_config( + ctx: DbtMcpContext, +) -> tuple[DbtAdminAPIClient, AdminApiConfig]: + admin_api_config = ctx.get_admin_api_config() + if admin_api_config is None: + raise ValueError("admin api config is not set") + admin_api_client = ctx.get_admin_api_client() + if admin_api_client is None: + raise ValueError("admin api client is not set") + return admin_api_client, admin_api_config + + +def create_admin_api_tool_definitions() -> list[ToolDefinition]: def list_jobs( + ctx: DbtMcpContext, # TODO: add support for project_id in the future # project_id: Optional[int] = None, limit: int | None = None, @@ -49,6 +61,7 @@ def list_jobs( ) -> list[dict[str, Any]] | str: """List jobs in an account.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) params = {} # if project_id: # params["project_id"] = project_id @@ -65,15 +78,20 @@ def list_jobs( ) return str(e) - def get_job_details(job_id: int) -> dict[str, Any] | str: + def get_job_details( + ctx: DbtMcpContext, + job_id: int, + ) -> dict[str, Any] | str: """Get details for a specific job.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) return admin_client.get_job_details(admin_api_config.account_id, job_id) except Exception as e: logger.error(f"Error getting job {job_id}: {e}") return str(e) def trigger_job_run( + ctx: DbtMcpContext, job_id: int, cause: str = "Triggered by dbt MCP", git_branch: str | None = None, @@ -82,6 +100,7 @@ def trigger_job_run( ) -> dict[str, Any] | str: """Trigger a job run.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) kwargs = {} if git_branch: kwargs["git_branch"] = git_branch @@ -97,6 +116,7 @@ def trigger_job_run( return str(e) def list_jobs_runs( + ctx: DbtMcpContext, job_id: int | None = None, status: JobRunStatus | None = None, limit: int | None = None, @@ -105,6 +125,7 @@ def list_jobs_runs( ) -> list[dict[str, Any]] | str: """List runs in an account.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) params: dict[str, Any] = {} if job_id: params["job_definition_id"] = job_id @@ -125,6 +146,7 @@ def list_jobs_runs( return str(e) def get_job_run_details( + ctx: DbtMcpContext, run_id: int, debug: bool = Field( default=False, @@ -133,6 +155,7 @@ def get_job_run_details( ) -> dict[str, Any] | str: """Get details for a specific job run.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) return admin_client.get_job_run_details( admin_api_config.account_id, run_id, debug=debug ) @@ -140,25 +163,37 @@ def get_job_run_details( logger.error(f"Error getting run {run_id}: {e}") return str(e) - def cancel_job_run(run_id: int) -> dict[str, Any] | str: + def cancel_job_run( + ctx: DbtMcpContext, + run_id: int, + ) -> dict[str, Any] | str: """Cancel a job run.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) return admin_client.cancel_job_run(admin_api_config.account_id, run_id) except Exception as e: logger.error(f"Error cancelling run {run_id}: {e}") return str(e) - def retry_job_run(run_id: int) -> dict[str, Any] | str: + def retry_job_run( + ctx: DbtMcpContext, + run_id: int, + ) -> dict[str, Any] | str: """Retry a failed job run.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) return admin_client.retry_job_run(admin_api_config.account_id, run_id) except Exception as e: logger.error(f"Error retrying run {run_id}: {e}") return str(e) - def list_job_run_artifacts(run_id: int) -> list[str] | str: + def list_job_run_artifacts( + ctx: DbtMcpContext, + run_id: int, + ) -> list[str] | str: """List artifacts for a job run.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) return admin_client.list_job_run_artifacts( admin_api_config.account_id, run_id ) @@ -167,10 +202,14 @@ def list_job_run_artifacts(run_id: int) -> list[str] | str: return str(e) def get_job_run_artifact( - run_id: int, artifact_path: str, step: int | None = None + ctx: DbtMcpContext, + run_id: int, + artifact_path: str, + step: int | None = None, ) -> Any | str: """Get a specific job run artifact.""" try: + admin_client, admin_api_config = get_admin_client_and_config(ctx) return admin_client.get_job_run_artifact( admin_api_config.account_id, run_id, artifact_path, step ) @@ -276,13 +315,11 @@ def get_job_run_artifact( def register_admin_api_tools( dbt_mcp: FastMCP, - admin_config: AdminApiConfig, exclude_tools: Sequence[ToolName] = [], ) -> None: """Register dbt Admin API tools.""" - admin_client = DbtAdminAPIClient(admin_config) register_tools( dbt_mcp, - create_admin_api_tool_definitions(admin_client, admin_config), + create_admin_api_tool_definitions(), exclude_tools, ) diff --git a/src/dbt_mcp/dbt_cli/tools.py b/src/dbt_mcp/dbt_cli/tools.py index f7634bc0..9d8f74cf 100644 --- a/src/dbt_mcp/dbt_cli/tools.py +++ b/src/dbt_mcp/dbt_cli/tools.py @@ -7,14 +7,23 @@ from dbt_mcp.config.config import DbtCliConfig from dbt_mcp.prompts.prompts import get_prompt +from dbt_mcp.tools.annotations import create_tool_annotations +from dbt_mcp.tools.config import DbtMcpContext from dbt_mcp.tools.definitions import ToolDefinition from dbt_mcp.tools.register import register_tools from dbt_mcp.tools.tool_names import ToolName -from dbt_mcp.tools.annotations import create_tool_annotations -def create_dbt_cli_tool_definitions(config: DbtCliConfig) -> list[ToolDefinition]: +def get_cli_config(ctx: DbtMcpContext) -> DbtCliConfig: + dbt_cli_config = ctx.get_dbt_cli_config() + if dbt_cli_config is None: + raise ValueError("dbt cli config is not set") + return dbt_cli_config + + +def create_dbt_cli_tool_definitions() -> list[ToolDefinition]: def _run_dbt_command( + ctx: DbtMcpContext, command: list[str], selector: str | None = None, resource_type: list[str] | None = None, @@ -22,6 +31,7 @@ def _run_dbt_command( is_full_refresh: bool | None = False, vars: str | None = None, ) -> str: + config = ctx.get_dbt_cli_config() try: # Commands that should always be quiet to reduce output verbosity verbose_commands = [ @@ -79,6 +89,7 @@ def _run_dbt_command( return str(e) def build( + ctx: DbtMcpContext, selector: str | None = Field( default=None, description=get_prompt("dbt_cli/args/selectors") ), @@ -90,6 +101,7 @@ def build( ), ) -> str: return _run_dbt_command( + ctx, ["build"], selector, is_selectable=True, @@ -97,13 +109,14 @@ def build( vars=vars, ) - def compile() -> str: - return _run_dbt_command(["compile"]) + def compile(ctx: DbtMcpContext) -> str: + return _run_dbt_command(ctx, ["compile"]) - def docs() -> str: - return _run_dbt_command(["docs", "generate"]) + def docs(ctx: DbtMcpContext) -> str: + return _run_dbt_command(ctx, ["docs", "generate"]) def ls( + ctx: DbtMcpContext, selector: str | None = Field( default=None, description=get_prompt("dbt_cli/args/selectors") ), @@ -113,16 +126,18 @@ def ls( ), ) -> str: return _run_dbt_command( + ctx, ["list"], selector, resource_type=resource_type, is_selectable=True, ) - def parse() -> str: - return _run_dbt_command(["parse"]) + def parse(ctx: DbtMcpContext) -> str: + return _run_dbt_command(ctx, ["parse"]) def run( + ctx: DbtMcpContext, selector: str | None = Field( default=None, description=get_prompt("dbt_cli/args/selectors") ), @@ -134,6 +149,7 @@ def run( ), ) -> str: return _run_dbt_command( + ctx, ["run"], selector, is_selectable=True, @@ -142,6 +158,7 @@ def run( ) def test( + ctx: DbtMcpContext, selector: str | None = Field( default=None, description=get_prompt("dbt_cli/args/selectors") ), @@ -149,9 +166,10 @@ def test( default=None, description=get_prompt("dbt_cli/args/vars") ), ) -> str: - return _run_dbt_command(["test"], selector, is_selectable=True, vars=vars) + return _run_dbt_command(ctx, ["test"], selector, is_selectable=True, vars=vars) def show( + ctx: DbtMcpContext, sql_query: str = Field(description=get_prompt("dbt_cli/args/sql_query")), limit: int = Field(default=5, description=get_prompt("dbt_cli/args/limit")), ) -> str: @@ -171,7 +189,7 @@ def show( if cli_limit is not None: args.extend(["--limit", str(cli_limit)]) args.extend(["--output", "json"]) - return _run_dbt_command(args) + return _run_dbt_command(ctx, args) return [ ToolDefinition( @@ -260,11 +278,10 @@ def show( def register_dbt_cli_tools( dbt_mcp: FastMCP, - config: DbtCliConfig, exclude_tools: Sequence[ToolName] = [], ) -> None: register_tools( dbt_mcp, - create_dbt_cli_tool_definitions(config), + create_dbt_cli_tool_definitions(), exclude_tools, ) diff --git a/src/dbt_mcp/discovery/tools.py b/src/dbt_mcp/discovery/tools.py index 4a33e8ce..824fb4c6 100644 --- a/src/dbt_mcp/discovery/tools.py +++ b/src/dbt_mcp/discovery/tools.py @@ -3,10 +3,10 @@ from mcp.server.fastmcp import FastMCP -from dbt_mcp.config.config import DiscoveryConfig from dbt_mcp.discovery.client import MetadataAPIClient, ModelsFetcher from dbt_mcp.prompts.prompts import get_prompt from dbt_mcp.tools.annotations import create_tool_annotations +from dbt_mcp.tools.config import DbtMcpContext from dbt_mcp.tools.definitions import ToolDefinition from dbt_mcp.tools.register import register_tools from dbt_mcp.tools.tool_names import ToolName @@ -14,17 +14,25 @@ logger = logging.getLogger(__name__) -def create_discovery_tool_definitions(config: DiscoveryConfig) -> list[ToolDefinition]: +def get_fetcher(ctx: DbtMcpContext) -> ModelsFetcher: + discovery_config = ctx.get_discovery_config() + if discovery_config is None: + raise ValueError("Discovery config is not set") api_client = MetadataAPIClient( - url=config.url, - headers=config.headers, + url=discovery_config.url, + headers=discovery_config.headers, ) models_fetcher = ModelsFetcher( - api_client=api_client, environment_id=config.environment_id + api_client=api_client, environment_id=discovery_config.environment_id ) + return models_fetcher - def get_mart_models() -> list[dict] | str: + +def create_discovery_tool_definitions() -> list[ToolDefinition]: + + def get_mart_models(ctx: DbtMcpContext) -> list[dict] | str: try: + models_fetcher = get_fetcher(ctx) mart_models = models_fetcher.fetch_models( model_filter={"modelingLayer": "marts"} ) @@ -32,40 +40,45 @@ def get_mart_models() -> list[dict] | str: except Exception as e: return str(e) - def get_all_models() -> list[dict] | str: + def get_all_models(ctx: DbtMcpContext) -> list[dict] | str: try: + models_fetcher = get_fetcher(ctx) return models_fetcher.fetch_models() except Exception as e: return str(e) def get_model_details( - model_name: str | None = None, unique_id: str | None = None + ctx: DbtMcpContext, model_name: str | None = None, unique_id: str | None = None ) -> dict | str: try: + models_fetcher = get_fetcher(ctx) return models_fetcher.fetch_model_details(model_name, unique_id) except Exception as e: return str(e) def get_model_parents( - model_name: str | None = None, unique_id: str | None = None + ctx: DbtMcpContext, model_name: str | None = None, unique_id: str | None = None ) -> list[dict] | str: try: + models_fetcher = get_fetcher(ctx) return models_fetcher.fetch_model_parents(model_name, unique_id) except Exception as e: return str(e) def get_model_children( - model_name: str | None = None, unique_id: str | None = None + ctx: DbtMcpContext, model_name: str | None = None, unique_id: str | None = None ) -> list[dict] | str: try: + models_fetcher = get_fetcher(ctx) return models_fetcher.fetch_model_children(model_name, unique_id) except Exception as e: return str(e) def get_model_health( - model_name: str | None = None, unique_id: str | None = None + ctx: DbtMcpContext, model_name: str | None = None, unique_id: str | None = None ) -> list[dict] | str: try: + models_fetcher = get_fetcher(ctx) return models_fetcher.fetch_model_health(model_name, unique_id) except Exception as e: return str(e) @@ -136,11 +149,10 @@ def get_model_health( def register_discovery_tools( dbt_mcp: FastMCP, - config: DiscoveryConfig, exclude_tools: Sequence[ToolName] = [], ) -> None: register_tools( dbt_mcp, - create_discovery_tool_definitions(config), + create_discovery_tool_definitions(), exclude_tools, ) diff --git a/src/dbt_mcp/mcp/server.py b/src/dbt_mcp/mcp/server.py index 5f844504..9768c76a 100644 --- a/src/dbt_mcp/mcp/server.py +++ b/src/dbt_mcp/mcp/server.py @@ -1,24 +1,22 @@ import logging import time from collections.abc import AsyncIterator, Sequence -from contextlib import ( - asynccontextmanager, -) +from contextlib import asynccontextmanager from typing import Any from dbtlabs_vortex.producer import shutdown +from dbtsl.client.sync import SyncSemanticLayerClient from mcp.server.fastmcp import FastMCP -from mcp.types import ( - ContentBlock, - TextContent, -) +from mcp.types import ContentBlock, TextContent from dbt_mcp.config.config import Config +from dbt_mcp.dbt_admin.client import DbtAdminAPIClient from dbt_mcp.dbt_admin.tools import register_admin_api_tools from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools from dbt_mcp.discovery.tools import register_discovery_tools from dbt_mcp.semantic_layer.tools import register_sl_tools from dbt_mcp.sql.tools import SqlToolsManager, register_sql_tools +from dbt_mcp.tools.config import DbtMcpContext from dbt_mcp.tracking.tracking import UsageTracker logger = logging.getLogger(__name__) @@ -99,6 +97,38 @@ async def call_tool( ) return result + def get_context(self) -> DbtMcpContext: + """ + Returns a Context object. Note that the context will only be valid + during a request; outside a request, most methods will error. + """ + try: + request_context = self._mcp_server.request_context + except LookupError: + request_context = None + return DbtMcpContext( + request_context=request_context, + fastmcp=self, + semantic_layer_config=self.config.semantic_layer_config, + semantic_layer_client=( + SyncSemanticLayerClient( + environment_id=self.config.semantic_layer_config.prod_environment_id, + auth_token=self.config.semantic_layer_config.service_token, + host=self.config.semantic_layer_config.host, + ) + if self.config.semantic_layer_config + else None + ), + discovery_config=self.config.discovery_config, + dbt_cli_config=self.config.dbt_cli_config, + admin_api_config=self.config.admin_api_config, + admin_api_client=( + DbtAdminAPIClient(self.config.admin_api_config) + if self.config.admin_api_config + else None + ), + ) + async def create_dbt_mcp(config: Config): dbt_mcp = DbtMCP( @@ -110,19 +140,19 @@ async def create_dbt_mcp(config: Config): if config.semantic_layer_config: logger.info("Registering semantic layer tools") - register_sl_tools(dbt_mcp, config.semantic_layer_config, config.disable_tools) + register_sl_tools(dbt_mcp, config.disable_tools) if config.discovery_config: logger.info("Registering discovery tools") - register_discovery_tools(dbt_mcp, config.discovery_config, config.disable_tools) + register_discovery_tools(dbt_mcp, config.disable_tools) if config.dbt_cli_config: logger.info("Registering dbt cli tools") - register_dbt_cli_tools(dbt_mcp, config.dbt_cli_config, config.disable_tools) + register_dbt_cli_tools(dbt_mcp, config.disable_tools) if config.admin_api_config: logger.info("Registering dbt admin API tools") - register_admin_api_tools(dbt_mcp, config.admin_api_config, config.disable_tools) + register_admin_api_tools(dbt_mcp, config.disable_tools) if config.sql_config: logger.info("Registering SQL tools") diff --git a/src/dbt_mcp/semantic_layer/tools.py b/src/dbt_mcp/semantic_layer/tools.py index b6314a26..1dd964b0 100644 --- a/src/dbt_mcp/semantic_layer/tools.py +++ b/src/dbt_mcp/semantic_layer/tools.py @@ -5,12 +5,8 @@ from dbtsl.client.sync import SyncSemanticLayerClient from mcp.server.fastmcp import FastMCP -from dbt_mcp.config.config import SemanticLayerConfig from dbt_mcp.prompts.prompts import get_prompt -from dbt_mcp.semantic_layer.client import ( - SemanticLayerClientProtocol, - SemanticLayerFetcher, -) +from dbt_mcp.semantic_layer.client import SemanticLayerFetcher from dbt_mcp.semantic_layer.types import ( DimensionToolResponse, EntityToolResponse, @@ -19,41 +15,55 @@ OrderByParam, QueryMetricsSuccess, ) +from dbt_mcp.tools.annotations import create_tool_annotations +from dbt_mcp.tools.config import DbtMcpContext from dbt_mcp.tools.definitions import ToolDefinition from dbt_mcp.tools.register import register_tools from dbt_mcp.tools.tool_names import ToolName -from dbt_mcp.tools.annotations import create_tool_annotations logger = logging.getLogger(__name__) -def create_sl_tool_definitions( - config: SemanticLayerConfig, sl_client: SemanticLayerClientProtocol -) -> list[ToolDefinition]: - semantic_layer_fetcher = SemanticLayerFetcher( +def get_fetcher(ctx: DbtMcpContext) -> SemanticLayerFetcher: + sl_config = ctx.get_semantic_layer_config() + sl_client = ctx.get_semantic_layer_client() + return SemanticLayerFetcher( sl_client=sl_client, - config=config, + config=sl_config, ) - def list_metrics() -> list[MetricToolResponse] | str: + +def create_sl_tool_definitions() -> list[ToolDefinition]: + + def list_metrics( + ctx: DbtMcpContext, + ) -> list[MetricToolResponse] | str: try: + semantic_layer_fetcher = get_fetcher(ctx) return semantic_layer_fetcher.list_metrics() except Exception as e: return str(e) - def get_dimensions(metrics: list[str]) -> list[DimensionToolResponse] | str: + def get_dimensions( + ctx: DbtMcpContext, metrics: list[str] + ) -> list[DimensionToolResponse] | str: try: + semantic_layer_fetcher = get_fetcher(ctx) return semantic_layer_fetcher.get_dimensions(metrics=metrics) except Exception as e: return str(e) - def get_entities(metrics: list[str]) -> list[EntityToolResponse] | str: + def get_entities( + ctx: DbtMcpContext, metrics: list[str] + ) -> list[EntityToolResponse] | str: try: + semantic_layer_fetcher = get_fetcher(ctx) return semantic_layer_fetcher.get_entities(metrics=metrics) except Exception as e: return str(e) def query_metrics( + ctx: DbtMcpContext, metrics: list[str], group_by: list[GroupByParam] | None = None, order_by: list[OrderByParam] | None = None, @@ -61,6 +71,7 @@ def query_metrics( limit: int | None = None, ) -> str: try: + semantic_layer_fetcher = get_fetcher(ctx) result = semantic_layer_fetcher.query_metrics( metrics=metrics, group_by=group_by, @@ -76,6 +87,7 @@ def query_metrics( return str(e) def get_metrics_compiled_sql( + ctx: DbtMcpContext, metrics: list[str], group_by: list[GroupByParam] | None = None, order_by: list[OrderByParam] | None = None, @@ -83,6 +95,7 @@ def get_metrics_compiled_sql( limit: int | None = None, ) -> str: try: + semantic_layer_fetcher = get_fetcher(ctx) result = semantic_layer_fetcher.get_metrics_compiled_sql( metrics=metrics, group_by=group_by, @@ -153,18 +166,10 @@ def get_metrics_compiled_sql( def register_sl_tools( dbt_mcp: FastMCP, - config: SemanticLayerConfig, exclude_tools: Sequence[ToolName] = [], ) -> None: register_tools( dbt_mcp, - create_sl_tool_definitions( - config, - SyncSemanticLayerClient( - environment_id=config.prod_environment_id, - auth_token=config.service_token, - host=config.host, - ), - ), + create_sl_tool_definitions(), exclude_tools, ) diff --git a/src/dbt_mcp/tools/config.py b/src/dbt_mcp/tools/config.py new file mode 100644 index 00000000..c2e1778b --- /dev/null +++ b/src/dbt_mcp/tools/config.py @@ -0,0 +1,73 @@ +from mcp import ServerSession +from mcp.server.fastmcp import Context, FastMCP +from mcp.shared.context import RequestContext +from starlette.requests import Request + +from dbt_mcp.config.config import ( + AdminApiConfig, + DbtCliConfig, + DiscoveryConfig, + SemanticLayerConfig, +) +from dbt_mcp.dbt_admin.client import DbtAdminAPIClient +from dbt_mcp.semantic_layer.client import SemanticLayerClientProtocol + + +class DbtMcpContext(Context[ServerSession, object, Request]): + """Custom context for the MCP server""" + + _semantic_layer_config: SemanticLayerConfig | None = None + _semantic_layer_client: SemanticLayerClientProtocol | None = None + _discovery_config: DiscoveryConfig | None = None + _dbt_cli_config: DbtCliConfig | None = None + _admin_api_config: AdminApiConfig | None = None + _admin_api_client: DbtAdminAPIClient | None = None + + def __init__( + self, + request_context: RequestContext[ServerSession, object, Request] | None = None, + fastmcp: FastMCP | None = None, + semantic_layer_config: SemanticLayerConfig | None = None, + semantic_layer_client: SemanticLayerClientProtocol | None = None, + discovery_config: DiscoveryConfig | None = None, + dbt_cli_config: DbtCliConfig | None = None, + admin_api_config: AdminApiConfig | None = None, + admin_api_client: DbtAdminAPIClient | None = None, + ): + super().__init__(request_context=request_context, fastmcp=fastmcp) + self._semantic_layer_config = semantic_layer_config + self._semantic_layer_client = semantic_layer_client + self._discovery_config = discovery_config + self._dbt_cli_config = dbt_cli_config + self._admin_api_config = admin_api_config + self._admin_api_client = admin_api_client + + def get_semantic_layer_config(self) -> SemanticLayerConfig: + if self._semantic_layer_config is None: + raise ValueError("Semantic layer config is not set") + return self._semantic_layer_config + + def get_semantic_layer_client(self) -> SemanticLayerClientProtocol: + if self._semantic_layer_client is None: + raise ValueError("Semantic layer client is not set") + return self._semantic_layer_client + + def get_discovery_config(self) -> DiscoveryConfig: + if self._discovery_config is None: + raise ValueError("Discovery config is not set") + return self._discovery_config + + def get_dbt_cli_config(self) -> DbtCliConfig: + if self._dbt_cli_config is None: + raise ValueError("Dbt cli config is not set") + return self._dbt_cli_config + + def get_admin_api_config(self) -> AdminApiConfig: + if self._admin_api_config is None: + raise ValueError("Admin api config is not set") + return self._admin_api_config + + def get_admin_api_client(self) -> DbtAdminAPIClient: + if self._admin_api_client is None: + raise ValueError("Admin api client is not set") + return self._admin_api_client diff --git a/tests/unit/dbt_admin/test_tools.py b/tests/unit/dbt_admin/test_tools.py index d62f4a8c..80a3688c 100644 --- a/tests/unit/dbt_admin/test_tools.py +++ b/tests/unit/dbt_admin/test_tools.py @@ -8,6 +8,7 @@ create_admin_api_tool_definitions, register_admin_api_tools, ) +from dbt_mcp.tools.config import DbtMcpContext @pytest.fixture @@ -103,7 +104,7 @@ def test_register_admin_api_tools_all_tools( mock_get_prompt.return_value = "Test prompt" fastmcp, tools = mock_fastmcp - register_admin_api_tools(fastmcp, mock_config.admin_api_config, []) + register_admin_api_tools(fastmcp, []) # Should call register_tools with 9 tool definitions mock_register_tools.assert_called_once() @@ -121,7 +122,7 @@ def test_register_admin_api_tools_with_disabled_tools( fastmcp, tools = mock_fastmcp disable_tools = ["list_jobs", "get_job", "trigger_job_run"] - register_admin_api_tools(fastmcp, mock_config.admin_api_config, disable_tools) + register_admin_api_tools(fastmcp, disable_tools) # Should still call register_tools with all 9 tool definitions # The exclude_tools parameter is passed to register_tools to handle filtering @@ -137,12 +138,11 @@ def test_register_admin_api_tools_with_disabled_tools( def test_list_jobs_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "List jobs prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) list_jobs_tool = tool_definitions[0].fn # First tool is list_jobs - result = list_jobs_tool(limit=10) + result = list_jobs_tool(ctx=context, limit=10) assert isinstance(result, list) mock_admin_client.list_jobs.assert_called_once_with(12345, limit=10) @@ -152,12 +152,11 @@ def test_list_jobs_tool(mock_get_prompt, mock_config, mock_admin_client): def test_get_job_details_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "Get job prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) get_job_details_tool = tool_definitions[1].fn # Second tool is get_job_details - result = get_job_details_tool(job_id=1) + result = get_job_details_tool(ctx=context, job_id=1) assert isinstance(result, dict) mock_admin_client.get_job_details.assert_called_once_with(12345, 1) @@ -167,12 +166,13 @@ def test_get_job_details_tool(mock_get_prompt, mock_config, mock_admin_client): def test_trigger_job_run_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "Trigger job run prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) trigger_job_run_tool = tool_definitions[2].fn # Third tool is trigger_job_run - result = trigger_job_run_tool(job_id=1, cause="Manual trigger", git_branch="main") + result = trigger_job_run_tool( + ctx=context, job_id=1, cause="Manual trigger", git_branch="main" + ) assert isinstance(result, dict) mock_admin_client.trigger_job_run.assert_called_once_with( @@ -184,12 +184,13 @@ def test_trigger_job_run_tool(mock_get_prompt, mock_config, mock_admin_client): def test_list_jobs_runs_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "List runs prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) list_jobs_runs_tool = tool_definitions[3].fn # Fourth tool is list_jobs_runs - result = list_jobs_runs_tool(job_id=1, status=JobRunStatus.SUCCESS, limit=5) + result = list_jobs_runs_tool( + ctx=context, job_id=1, status=JobRunStatus.SUCCESS, limit=5 + ) assert isinstance(result, list) mock_admin_client.list_jobs_runs.assert_called_once_with( @@ -201,14 +202,13 @@ def test_list_jobs_runs_tool(mock_get_prompt, mock_config, mock_admin_client): def test_get_job_run_details_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "Get run prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) get_job_run_details_tool = tool_definitions[ 4 ].fn # Fifth tool is get_job_run_details - result = get_job_run_details_tool(run_id=100, debug=True) + result = get_job_run_details_tool(ctx=context, run_id=100, debug=True) assert isinstance(result, dict) mock_admin_client.get_job_run_details.assert_called_once_with( @@ -220,12 +220,11 @@ def test_get_job_run_details_tool(mock_get_prompt, mock_config, mock_admin_clien def test_cancel_job_run_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "Cancel run prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) cancel_job_run_tool = tool_definitions[5].fn # Sixth tool is cancel_job_run - result = cancel_job_run_tool(run_id=100) + result = cancel_job_run_tool(ctx=context, run_id=100) assert isinstance(result, dict) mock_admin_client.cancel_job_run.assert_called_once_with(12345, 100) @@ -235,12 +234,11 @@ def test_cancel_job_run_tool(mock_get_prompt, mock_config, mock_admin_client): def test_retry_job_run_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "Retry run prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) retry_job_run_tool = tool_definitions[6].fn # Seventh tool is retry_job_run - result = retry_job_run_tool(run_id=100) + result = retry_job_run_tool(ctx=context, run_id=100) assert isinstance(result, dict) mock_admin_client.retry_job_run.assert_called_once_with(12345, 100) @@ -250,14 +248,13 @@ def test_retry_job_run_tool(mock_get_prompt, mock_config, mock_admin_client): def test_list_job_run_artifacts_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "List run artifacts prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) list_job_run_artifacts_tool = tool_definitions[ 7 ].fn # Eighth tool is list_job_run_artifacts - result = list_job_run_artifacts_tool(run_id=100) + result = list_job_run_artifacts_tool(ctx=context, run_id=100) assert isinstance(result, list) mock_admin_client.list_job_run_artifacts.assert_called_once_with(12345, 100) @@ -267,15 +264,14 @@ def test_list_job_run_artifacts_tool(mock_get_prompt, mock_config, mock_admin_cl def test_get_job_run_artifact_tool(mock_get_prompt, mock_config, mock_admin_client): mock_get_prompt.return_value = "Get run artifact prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) get_job_run_artifact_tool = tool_definitions[ 8 ].fn # Ninth tool is get_job_run_artifact result = get_job_run_artifact_tool( - run_id=100, artifact_path="manifest.json", step=1 + ctx=context, run_id=100, artifact_path="manifest.json", step=1 ) assert result is not None @@ -290,12 +286,11 @@ def test_tools_handle_exceptions(mock_get_prompt, mock_config): mock_admin_client = Mock() mock_admin_client.list_jobs.side_effect = Exception("API Error") - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) list_jobs_tool = tool_definitions[0].fn # First tool is list_jobs - result = list_jobs_tool() + result = list_jobs_tool(ctx=context) assert isinstance(result, str) assert "API Error" in result @@ -307,25 +302,24 @@ def test_tools_with_no_optional_parameters( ): mock_get_prompt.return_value = "Test prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) # Test list_jobs with no parameters list_jobs_tool = tool_definitions[0].fn - result = list_jobs_tool() + result = list_jobs_tool(ctx=context) assert isinstance(result, list) mock_admin_client.list_jobs.assert_called_with(12345) # Test list_jobs_runs with no parameters list_jobs_runs_tool = tool_definitions[3].fn - result = list_jobs_runs_tool() + result = list_jobs_runs_tool(ctx=context) assert isinstance(result, list) mock_admin_client.list_jobs_runs.assert_called_with(12345) # Test get_job_run_details with default debug parameter get_job_run_details_tool = tool_definitions[4].fn - result = get_job_run_details_tool(run_id=100) + result = get_job_run_details_tool(ctx=context, run_id=100) assert isinstance(result, dict) # The debug parameter should be a Field object with default False call_args = mock_admin_client.get_job_run_details.call_args @@ -341,12 +335,12 @@ def test_trigger_job_run_with_all_optional_params( ): mock_get_prompt.return_value = "Trigger job run prompt" - tool_definitions = create_admin_api_tool_definitions( - mock_admin_client, mock_config.admin_api_config - ) + tool_definitions = create_admin_api_tool_definitions() + context = DbtMcpContext(admin_api_config=mock_config.admin_api_config, admin_api_client=mock_admin_client) trigger_job_run_tool = tool_definitions[2].fn # Third tool is trigger_job_run result = trigger_job_run_tool( + ctx=context, job_id=1, cause="Manual trigger", git_branch="feature-branch", diff --git a/tests/unit/dbt_cli/test_cli_integration.py b/tests/unit/dbt_cli/test_cli_integration.py index 136ff740..6e4d805c 100644 --- a/tests/unit/dbt_cli/test_cli_integration.py +++ b/tests/unit/dbt_cli/test_cli_integration.py @@ -1,6 +1,7 @@ import unittest from unittest.mock import MagicMock, patch +from dbt_mcp.tools.config import DbtMcpContext from tests.mocks.config import mock_config @@ -35,7 +36,7 @@ def decorator(func): mock_fastmcp.tool = mock_tool_decorator # Register the tools - register_dbt_cli_tools(mock_fastmcp, mock_config.dbt_cli_config) + register_dbt_cli_tools(mock_fastmcp) # Test cases for different command types test_cases = [ @@ -89,12 +90,14 @@ def decorator(func): ), ] + context = DbtMcpContext(dbt_cli_config=mock_config.dbt_cli_config) + # Run each test case for command_name, args, expected_args in test_cases: mock_popen.reset_mock() # Call the function - result = tools[command_name](*args) + result = tools[command_name](context, *args) # Verify the command was called correctly mock_popen.assert_called_once() diff --git a/tests/unit/dbt_cli/test_tools.py b/tests/unit/dbt_cli/test_tools.py index 16e1cb75..f23f8579 100644 --- a/tests/unit/dbt_cli/test_tools.py +++ b/tests/unit/dbt_cli/test_tools.py @@ -4,6 +4,7 @@ from pytest import MonkeyPatch from dbt_mcp.dbt_cli.tools import register_dbt_cli_tools +from dbt_mcp.tools.config import DbtMcpContext from tests.mocks.config import mock_dbt_cli_config @@ -115,11 +116,12 @@ def mock_popen(args, **kwargs): # Register tools and get show tool fastmcp, tools = mock_fastmcp - register_dbt_cli_tools(fastmcp, mock_dbt_cli_config) + register_dbt_cli_tools(fastmcp) + context = DbtMcpContext(dbt_cli_config=mock_dbt_cli_config) show_tool = tools["show"] # Call show tool with test parameters - show_tool(sql_query=sql_query, limit=limit_param) + show_tool(ctx=context, sql_query=sql_query, limit=limit_param) # Verify the command was called with expected arguments assert mock_calls @@ -141,11 +143,12 @@ def mock_popen(args, **kwargs): # Setup mock_fastmcp_obj, tools = mock_fastmcp - register_dbt_cli_tools(mock_fastmcp_obj, mock_dbt_cli_config) + register_dbt_cli_tools(mock_fastmcp_obj) + context = DbtMcpContext(dbt_cli_config=mock_dbt_cli_config) run_tool = tools["run"] # Execute - run_tool() + run_tool(ctx=context) # Verify assert mock_calls @@ -168,11 +171,12 @@ def mock_popen(args, **kwargs): fastmcp, tools = mock_fastmcp # Register the tools - register_dbt_cli_tools(fastmcp, mock_dbt_cli_config) + register_dbt_cli_tools(fastmcp) + context = DbtMcpContext(dbt_cli_config=mock_dbt_cli_config) run_tool = tools["run"] # Run the command with a selector - run_tool(selector="my_model") + run_tool(ctx=context, selector="my_model") # Verify the command is correctly formatted assert mock_calls @@ -200,11 +204,12 @@ def mock_popen(args, **kwargs): # Setup mock_fastmcp_obj, tools = mock_fastmcp - register_dbt_cli_tools(mock_fastmcp_obj, mock_dbt_cli_config) + register_dbt_cli_tools(mock_fastmcp_obj) + context = DbtMcpContext(dbt_cli_config=mock_dbt_cli_config) show_tool = tools["show"] # Execute - show_tool(sql_query="SELECT * FROM my_model") + show_tool(ctx=context, sql_query="SELECT * FROM my_model") # Verify assert mock_calls @@ -229,16 +234,17 @@ def mock_popen(*args, **kwargs): # Setup mock_fastmcp_obj, tools = mock_fastmcp - register_dbt_cli_tools(mock_fastmcp_obj, mock_dbt_cli_config) + register_dbt_cli_tools(mock_fastmcp_obj) + context = DbtMcpContext(dbt_cli_config=mock_dbt_cli_config) list_tool = tools["ls"] # Test timeout case - result = list_tool(resource_type=["model", "snapshot"]) + result = list_tool(ctx=context, resource_type=["model", "snapshot"]) assert "Timeout: dbt command took too long to complete" in result assert "Try using a specific selector to narrow down the results" in result # Test with selector - should still timeout - result = list_tool(selector="my_model", resource_type=["model"]) + result = list_tool(ctx=context, selector="my_model", resource_type=["model"]) assert "Timeout: dbt command took too long to complete" in result assert "Try using a specific selector to narrow down the results" in result @@ -256,10 +262,11 @@ def mock_popen(args, **kwargs): monkeypatch.setattr("subprocess.Popen", mock_popen) fastmcp, tools = mock_fastmcp - register_dbt_cli_tools(fastmcp, mock_dbt_cli_config) + register_dbt_cli_tools(fastmcp) + context = DbtMcpContext(dbt_cli_config=mock_dbt_cli_config) tool = tools[command_name] - tool(is_full_refresh=True) + tool(ctx=context, is_full_refresh=True) assert mock_calls args_list = mock_calls[0] @@ -279,10 +286,11 @@ def mock_popen(args, **kwargs): monkeypatch.setattr("subprocess.Popen", mock_popen) fastmcp, tools = mock_fastmcp - register_dbt_cli_tools(fastmcp, mock_dbt_cli_config) + register_dbt_cli_tools(fastmcp) + context = DbtMcpContext(dbt_cli_config=mock_dbt_cli_config) tool = tools[command_name] - tool(vars="environment: production") + tool(ctx=context, vars="environment: production") assert mock_calls args_list = mock_calls[0] @@ -300,10 +308,11 @@ def mock_popen(args, **kwargs): monkeypatch.setattr("subprocess.Popen", mock_popen) fastmcp, tools = mock_fastmcp - register_dbt_cli_tools(fastmcp, mock_dbt_cli_config) + register_dbt_cli_tools(fastmcp) + context = DbtMcpContext(dbt_cli_config=mock_dbt_cli_config) build_tool = tools["build"] - build_tool() # Non-explicit + build_tool(ctx=context) # Non-explicit assert mock_calls args_list = mock_calls[0]