diff --git a/docs/docs/sql/index.mdx b/docs/docs/sql/index.mdx index 510504fb37c..8d0aa21a92f 100644 --- a/docs/docs/sql/index.mdx +++ b/docs/docs/sql/index.mdx @@ -52,3 +52,40 @@ You can run the linter via `lint` goal: ``` pants lint --only=sqlfluff :: ``` + +## Using different templaters + +As of 3.2.5, sqlfluff doesn't support using different templaters in +subdirectories. If you try to set `templater` in a subdirectory you'll get +something like this: + +``` +WARNING Attempt to set templater to jinja failed. Using python templater. +Templater cannot be set in a .sqlfluff file in a subdirectory of the current +working directory. It can be set in a .sqlfluff in the current working +directory. See Nesting section of the docs for more details. +``` + +However, pants overcomes this limitation by running multiple sqlfluff processes +in parallel. You can set the `templater` config value and pants will figure out +how to partition your files: + +```toml tab={"label":"a/.sqlfluff"} +[sqlfluff] +templater = python +dialect = postgres +``` + +```sql tab={"label":"a/query.sql"} +select * from {table} +``` + +```toml tab={"label":"b/.sqlfluff"} +[sqlfluff] +templater = jinja +dialect = postgres +``` + +```sql tab={"label":"b/query.sql"} +select * from {{table}} +``` diff --git a/docs/notes/2.26.x.md b/docs/notes/2.26.x.md index 48fdb3f765b..88053c6b21e 100644 --- a/docs/notes/2.26.x.md +++ b/docs/notes/2.26.x.md @@ -38,6 +38,10 @@ Some deprecations have expired and been removed: For the `tfsec` linter, the deprecation of support for leading `v`s in the `version` and `known_versions` field has expired and been removed. Write `1.28.13` instead of `v1.28.13`. +#### SQL + +Sqlfluff can now be used with multiple different templaters in a single repo. 
+ ### Plugin API changes diff --git a/src/python/pants/backend/sql/BUILD b/src/python/pants/backend/sql/BUILD index b95a2079024..3928f7e3ba8 100644 --- a/src/python/pants/backend/sql/BUILD +++ b/src/python/pants/backend/sql/BUILD @@ -1,9 +1,4 @@ # Copyright 2024 Pants project contributors (see CONTRIBUTORS.md). # Licensed under the Apache License, Version 2.0 (see LICENSE). -python_sources( - sources=[ - "tailor.py", - "target_types.py", - ], -) +python_sources() diff --git a/src/python/pants/backend/sql/lint/sqlfluff/BUILD b/src/python/pants/backend/sql/lint/sqlfluff/BUILD index 8d68558a9ac..dd1a3ff744e 100644 --- a/src/python/pants/backend/sql/lint/sqlfluff/BUILD +++ b/src/python/pants/backend/sql/lint/sqlfluff/BUILD @@ -8,6 +8,7 @@ python_sources( python_tests( name="tests", + timeout=600, overrides={ "rules_integration_test.py": { "tags": ["platform_specific_behavior"], diff --git a/src/python/pants/backend/sql/lint/sqlfluff/rules.py b/src/python/pants/backend/sql/lint/sqlfluff/rules.py index b02b7cd4e23..7b6ec242940 100644 --- a/src/python/pants/backend/sql/lint/sqlfluff/rules.py +++ b/src/python/pants/backend/sql/lint/sqlfluff/rules.py @@ -2,8 +2,11 @@ # Licensed under the Apache License, Version 2.0 (see LICENSE). 
from __future__ import annotations +import logging +import re +from collections import defaultdict from dataclasses import dataclass -from typing import Any, Iterable, Tuple +from typing import Any, DefaultDict, Iterable, Iterator, Sequence, Tuple from typing_extensions import assert_never @@ -14,22 +17,26 @@ from pants.core.goals.fmt import FmtResult, FmtTargetsRequest from pants.core.goals.lint import LintResult, LintTargetsRequest from pants.core.util_rules.config_files import find_config_file -from pants.core.util_rules.partitions import PartitionerType +from pants.core.util_rules.partitions import Partition, PartitionerType, Partitions from pants.core.util_rules.source_files import SourceFilesRequest, determine_source_files -from pants.engine.fs import MergeDigests +from pants.engine.collection import Collection +from pants.engine.fs import Digest, DigestContents, FileContent, MergeDigests from pants.engine.internals.native_engine import Snapshot from pants.engine.intrinsics import execute_process, merge_digests from pants.engine.process import FallibleProcessResult -from pants.engine.rules import Rule, collect_rules, concurrently, implicitly, rule +from pants.engine.rules import Get, Rule, collect_rules, concurrently, implicitly, rule +from pants.util.frozendict import FrozenDict from pants.util.logging import LogLevel from pants.util.meta import classproperty from pants.util.strutil import pluralize +logger = logging.getLogger(__name__) + class SqlfluffFixRequest(FixTargetsRequest): field_set_type = SqlfluffFieldSet tool_subsystem = Sqlfluff - partitioner_type = PartitionerType.DEFAULT_SINGLE_PARTITION + partitioner_type = PartitionerType.CUSTOM # We don't need to include automatically added lint rules for this SqlfluffFixRequest, # because these lint rules are already checked by SqlfluffLintRequest. 
@@ -39,13 +46,13 @@ class SqlfluffFixRequest(FixTargetsRequest): class SqlfluffLintRequest(LintTargetsRequest): field_set_type = SqlfluffFieldSet tool_subsystem = Sqlfluff - partitioner_type = PartitionerType.DEFAULT_SINGLE_PARTITION + partitioner_type = PartitionerType.CUSTOM class SqlfluffFormatRequest(FmtTargetsRequest): field_set_type = SqlfluffFieldSet tool_subsystem = Sqlfluff - partitioner_type = PartitionerType.DEFAULT_SINGLE_PARTITION + partitioner_type = PartitionerType.CUSTOM @classproperty def tool_name(cls) -> str: @@ -59,6 +66,7 @@ def tool_id(cls) -> str: @dataclass(frozen=True) class _RunSqlfluffRequest: snapshot: Snapshot + templater: str mode: SqlfluffMode @@ -84,12 +92,19 @@ async def run_sqlfluff( assert_never(request.mode) conf_args = ["--config", sqlfluff.config] if sqlfluff.config else [] + templater_args = ["--templater", request.templater] if request.templater is not None else [] result = await execute_process( **implicitly( VenvPexProcess( sqlfluff_pex, - argv=(*initial_args, *conf_args, *sqlfluff.args, *request.snapshot.files), + argv=( + *initial_args, + *templater_args, + *conf_args, + *sqlfluff.args, + *request.snapshot.files, + ), input_digest=input_digest, output_files=request.snapshot.files, description=f"Run sqlfluff {' '.join(initial_args)} on {pluralize(len(request.snapshot.files), 'file')}.", @@ -100,14 +115,172 @@ async def run_sqlfluff( return result +@dataclass(frozen=True) +class TemplaterMetadata: + templater: str | None + + @property + def description(self) -> str: + return f"templater={self.templater}" + + +class ConfigParser: + def __init__(self) -> None: + self.regex = re.compile("^templater *= *(?P<templater>[^ ]+) *$") + + def parse_templater(self, content: str) -> str | None: + for line in content.splitlines(): + if match := self.regex.match(line): + return match.group("templater").strip('"') + return None + + +@dataclass(frozen=True) +class NestedConfig: + templaters: dict[str, str] + + @classmethod + def new(cls, parser: 
ConfigParser, contents: Collection[FileContent]) -> NestedConfig: + templaters = {} + for file_content in contents: + content = file_content.content.decode("utf-8") + templater = parser.parse_templater(content) + if templater is None: + continue + + directory = file_content.path.rsplit("/", 1)[0] + templaters[directory] = templater + + return NestedConfig(templaters) + + def find_templater(self, directory: str) -> str | None: + for d in recursively(directory): + templater = self.templaters.get(d) + if templater is not None: + return templater + return None + + +def recursively(directory: str) -> Iterator[str]: + while True: + yield directory + parts = directory.rsplit("/", 1) + if len(parts) == 1: + return + + directory, _ = parts + + +@dataclass(frozen=True) +class _GroupByTemplaterRequest: + field_sets: Sequence[SqlfluffFieldSet] + + +@dataclass(frozen=True) +class _GroupByTemplaterResult: + groups: FrozenDict[str | None, tuple[SqlfluffFieldSet, ...]] + + +@rule +async def _group_by_templater( + request: _GroupByTemplaterRequest, + sqlfluff: Sqlfluff, +) -> _GroupByTemplaterResult: + dirs = [ + directory + for field_set in request.field_sets + for directory in recursively(field_set.address.spec_path) + ] + + config_files = await find_config_file(sqlfluff.config_request(dirs)) + + logger.debug("sqlfluff config files: %s", config_files.snapshot.files) + contents = await Get(DigestContents, Digest, config_files.snapshot.digest) + + parser = ConfigParser() + nested_config = NestedConfig.new(parser, contents) + logger.debug("sqlfluff nested config: %s", nested_config) + + result: DefaultDict[str | None, list[SqlfluffFieldSet]] = defaultdict(list) + for field_set in request.field_sets: + directory = field_set.address.spec_path + templater = nested_config.find_templater(directory) + if templater is None: + result[None].append(field_set) + continue + + result[templater].append(field_set) + return _GroupByTemplaterResult(groups=FrozenDict((k, tuple(v)) for k, v in 
result.items())) + + +@dataclass(frozen=True) +class _GroupFilesByTemplaterRequest: + field_sets: Sequence[SqlfluffFieldSet] + + +@rule +async def _group_files_by_templater(request: _GroupFilesByTemplaterRequest) -> Partitions: + result = await _group_by_templater(**implicitly(_GroupByTemplaterRequest(request.field_sets))) + gets = [ + determine_source_files(SourceFilesRequest(field_set.source for field_set in field_sets)) + for field_sets in result.groups.values() + ] + all_source_files = (await concurrently(*gets)) if gets else [] + + partitions = Partitions( + Partition( + elements=source_files.files, + metadata=TemplaterMetadata(templater), + ) + for templater, source_files in zip(result.groups, all_source_files) + ) + return partitions + + +@dataclass(frozen=True) +class _GroupFieldSetsByTemplaterRequest: + field_sets: Sequence[SqlfluffFieldSet] + + +@rule +async def _group_field_sets_by_templater( + request: _GroupFieldSetsByTemplaterRequest, +) -> Partitions: + result = await _group_by_templater(**implicitly(_GroupByTemplaterRequest(request.field_sets))) + partitions = Partitions( + Partition( + elements=tuple(sorted(field_sets, key=lambda fs: fs.address)), + metadata=TemplaterMetadata(templater), + ) + for templater, field_sets in result.groups.items() + ) + logger.debug("sqlfluff partitions: %s", partitions) + return partitions + + @rule(desc="Fix with sqlfluff fix", level=LogLevel.DEBUG) async def sqlfluff_fix(request: SqlfluffFixRequest.Batch, sqlfluff: Sqlfluff) -> FixResult: result = await run_sqlfluff( - _RunSqlfluffRequest(snapshot=request.snapshot, mode=SqlfluffMode.FIX), sqlfluff + _RunSqlfluffRequest( + snapshot=request.snapshot, + templater=request.partition_metadata.templater, + mode=SqlfluffMode.FIX, + ), + sqlfluff, ) return await FixResult.create(request, result) +@rule +async def sqlfluff_fix_partition( + request: SqlfluffFixRequest.PartitionRequest, +) -> Partitions: + return await Get( + Partitions, + 
_GroupFilesByTemplaterRequest(field_sets=request.field_sets), + ) + + @rule(desc="Lint with sqlfluff lint", level=LogLevel.DEBUG) async def sqlfluff_lint( request: SqlfluffLintRequest.Batch[SqlfluffFieldSet, Any], sqlfluff: Sqlfluff @@ -116,19 +289,49 @@ async def sqlfluff_lint( SourceFilesRequest(field_set.source for field_set in request.elements) ) result = await run_sqlfluff( - _RunSqlfluffRequest(snapshot=source_files.snapshot, mode=SqlfluffMode.LINT), sqlfluff + _RunSqlfluffRequest( + snapshot=source_files.snapshot, + templater=request.partition_metadata.templater, + mode=SqlfluffMode.LINT, + ), + sqlfluff, ) return LintResult.create(request, result) +@rule +async def sqlfluff_lint_partition( + request: SqlfluffLintRequest.PartitionRequest, +) -> Partitions: + return await Get( + Partitions, + _GroupFieldSetsByTemplaterRequest(field_sets=request.field_sets), + ) + + @rule(desc="Format with sqlfluff format", level=LogLevel.DEBUG) async def sqlfluff_fmt(request: SqlfluffFormatRequest.Batch, sqlfluff: Sqlfluff) -> FmtResult: result = await run_sqlfluff( - _RunSqlfluffRequest(snapshot=request.snapshot, mode=SqlfluffMode.FMT), sqlfluff + _RunSqlfluffRequest( + snapshot=request.snapshot, + templater=request.partition_metadata.templater, + mode=SqlfluffMode.FMT, + ), + sqlfluff, ) return await FmtResult.create(request, result) +@rule +async def sqlfluff_fmt_partition( + request: SqlfluffFormatRequest.PartitionRequest, +) -> Partitions: + return await Get( + Partitions, + _GroupFilesByTemplaterRequest(field_sets=request.field_sets), + ) + + def rules() -> Iterable[Rule]: return ( *collect_rules(), diff --git a/src/python/pants/backend/sql/lint/sqlfluff/rules_integration_test.py b/src/python/pants/backend/sql/lint/sqlfluff/rules_integration_test.py index eae9ae6b9eb..cc8156c06c8 100644 --- a/src/python/pants/backend/sql/lint/sqlfluff/rules_integration_test.py +++ b/src/python/pants/backend/sql/lint/sqlfluff/rules_integration_test.py @@ -2,6 +2,7 @@ # Licensed under the 
Apache License, Version 2.0 (see LICENSE). from __future__ import annotations +import os from pathlib import Path from textwrap import dedent @@ -15,17 +16,13 @@ SqlfluffFormatRequest, SqlfluffLintRequest, ) -from pants.backend.sql.lint.sqlfluff.skip_field import SkipSqlfluffField -from pants.backend.sql.lint.sqlfluff.subsystem import SqlfluffFieldSet from pants.backend.sql.target_types import SqlSourcesGeneratorTarget from pants.core.goals.fix import FixResult from pants.core.goals.fmt import FmtResult from pants.core.goals.lint import LintResult from pants.core.util_rules import config_files -from pants.core.util_rules.partitions import _EmptyMetadata from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest -from pants.engine.addresses import Address -from pants.engine.target import Target +from pants.testutil.pants_integration_test import PantsResult, run_pants, setup_tmpdir from pants.testutil.rule_runner import QueryRule, RuleRunner GOOD_FILE = dedent( @@ -56,6 +53,20 @@ dialect = postgres """ ) +CONFIG_PYPROJECT = dedent( + """\ + [tool.sqlfluff.core] + dialect = "postgres" + exclude_rules = ["RF03"] + """ +) +CONFIG_NATIVE = dedent( + """\ + [sqlfluff] + dialect = postgres + exclude_rules = RF03 + """ +) @pytest.fixture @@ -77,238 +88,295 @@ def rule_runner() -> RuleRunner: def run_sqlfluff( rule_runner: RuleRunner, - targets: list[Target], + paths: list[str], + files: dict[str, str], *, extra_args: list[str] | None = None, -) -> tuple[FixResult, LintResult, FmtResult]: +) -> tuple[PantsResult, PantsResult, PantsResult]: args = [ - "--backend-packages=pants.backend.sql.lint.sqlfluff", + "--backend-packages=['pants.backend.experimental.sql','pants.backend.experimental.sql.lint.sqlfluff']", + "--no-watch-filesystem", + "--no-pantsd", '--sqlfluff-fix-args="--force"', *(extra_args or ()), ] - rule_runner.set_options(args, env_inherit={"PATH", "PYENV_ROOT", "HOME"}) - field_sets = [ - SqlfluffFieldSet.create(tgt) for tgt in targets if 
SqlfluffFieldSet.is_applicable(tgt) - ] - source_reqs = [SourceFilesRequest(field_set.source for field_set in field_sets)] - input_sources = rule_runner.request(SourceFiles, source_reqs) - - fix_result = rule_runner.request( - FixResult, - [ - SqlfluffFixRequest.Batch( - "", - tuple(field_sets), - partition_metadata=_EmptyMetadata(), - snapshot=input_sources.snapshot, - ), - ], - ) - lint_result = rule_runner.request( - LintResult, - [ - SqlfluffLintRequest.Batch( - "", - tuple(field_sets), - partition_metadata=_EmptyMetadata(), - ), - ], - ) - fmt_result = rule_runner.request( - FmtResult, - [ - SqlfluffFormatRequest.Batch( - "", - tuple(field_sets), - partition_metadata=_EmptyMetadata(), - snapshot=input_sources.snapshot, - ) - ], - ) + with setup_tmpdir(files) as workdir: + result = run_pants(command=[*args, "list", f"{workdir}/::"]) + result.assert_success() + assert result.stdout == "abc" + addresses = [f"{workdir}/{path}" for path in paths] + fix_result = run_pants(command=[*args, "fix", *addresses]) + lint_result = run_pants(command=[*args, "lint", *addresses]) + fmt_result = run_pants(command=[*args, "fmt", *addresses]) return fix_result, lint_result, fmt_result -@pytest.mark.platform_specific_behavior -def test_passing(rule_runner: RuleRunner) -> None: - rule_runner.write_files( - { - "query.sql": GOOD_FILE, - "BUILD": "sql_sources(name='t')", - ".sqlfluff": CONFIG_POSTGRES, - } - ) - tgt = rule_runner.get_target(Address("", target_name="t", relative_file_path="query.sql")) - fix_result, lint_result, fmt_result = run_sqlfluff(rule_runner, [tgt]) - assert fix_result.stdout == dedent( - """\ - ==== finding fixable violations ==== - FORCE MODE: Attempting fixes... - ==== no fixable linting violations found ==== - All Finished! 
- """ - ) - assert fix_result.stderr == "" - assert lint_result.exit_code == 0 - assert not fix_result.did_change - assert fix_result.output == rule_runner.make_snapshot({"query.sql": GOOD_FILE}) - assert not fmt_result.did_change - assert fmt_result.output == rule_runner.make_snapshot({"query.sql": GOOD_FILE}) - - -def test_failing(rule_runner: RuleRunner) -> None: - rule_runner.write_files( - { - "query.sql": BAD_FILE, - "BUILD": "sql_sources(name='t')", - ".sqlfluff": CONFIG_POSTGRES, - } - ) - tgt = rule_runner.get_target(Address("", target_name="t", relative_file_path="query.sql")) - fix_result, lint_result, fmt_result = run_sqlfluff(rule_runner, [tgt]) - assert fix_result.stdout == dedent( - """\ - ==== finding fixable violations ==== - FORCE MODE: Attempting fixes... - == [query.sql] FAIL - L: 3 | P: 5 | RF03 | Unqualified reference 'name' found in single table - | select. [references.consistent] - == [query.sql] FIXED - 1 fixable linting violations found - [1 unfixable linting violations found] - """ - ) - assert fix_result.stderr == "" - assert lint_result.stdout == dedent( - """\ - == [query.sql] FAIL - L: 3 | P: 5 | RF03 | Unqualified reference 'name' found in single table - | select. [references.consistent] - L: 3 | P: 5 | RF03 | Unqualified reference 'name' found in single table - | select which is inconsistent with previous references. - | [references.consistent] - All Finished! 
- """ - ) - assert lint_result.stderr == "" - assert lint_result.exit_code == 1 - assert fix_result.did_change - assert fix_result.output == rule_runner.make_snapshot({"query.sql": GOOD_FILE}) - assert not fmt_result.did_change - assert fmt_result.output == rule_runner.make_snapshot({"query.sql": BAD_FILE}) - - -def test_multiple_targets(rule_runner: RuleRunner) -> None: - rule_runner.write_files( - { - "good.sql": GOOD_FILE, - "bad.sql": BAD_FILE, - "unformatted.sql": UNFORMATTED_FILE, - "BUILD": "sql_sources(name='t')", - ".sqlfluff": CONFIG_POSTGRES, - } - ) - tgts = [ - rule_runner.get_target(Address("", target_name="t", relative_file_path="good.sql")), - rule_runner.get_target(Address("", target_name="t", relative_file_path="bad.sql")), - rule_runner.get_target(Address("", target_name="t", relative_file_path="unformatted.sql")), +def collect_files(rootdir: str) -> dict: + result = {} + for root, _, files in os.walk(rootdir): + for file in files: + path = f"{root}/{file}" + relpath = os.path.relpath(path, rootdir) + result[relpath] = Path(path).read_text() + return result + + +@pytest.fixture +def args(): + return [ + "--backend-packages=['pants.backend.experimental.sql','pants.backend.experimental.sql.lint.sqlfluff']", + '--sqlfluff-fix-args="--force"', + "--python-interpreter-constraints=['==3.12.*']", ] - fix_result, lint_result, fmt_result = run_sqlfluff(rule_runner, tgts) - assert lint_result.exit_code == 1 - assert fix_result.output == rule_runner.make_snapshot( - {"good.sql": GOOD_FILE, "bad.sql": GOOD_FILE, "unformatted.sql": GOOD_FILE} - ) - assert fix_result.did_change is True - assert fmt_result.output == rule_runner.make_snapshot( - {"good.sql": GOOD_FILE, "bad.sql": BAD_FILE, "unformatted.sql": GOOD_FILE} + + +@pytest.fixture +def good_query(): + return { + "project/query.sql": GOOD_FILE, + "project/BUILD": "sql_sources()", + "project/.sqlfluff": CONFIG_POSTGRES, + } + + +def test_passing_lint(good_query: dict[str, str], args: list[str]) -> None: + 
with setup_tmpdir(good_query) as tmpdir: + result = run_pants([*args, "lint", f"{tmpdir}/project:"]) + result.assert_success() + + +def test_passing_fix(good_query: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(good_query) as tmpdir: + result = run_pants([*args, "fix", f"{tmpdir}/project:"]) + files = collect_files(tmpdir) + result.assert_success() + assert result.stdout == "" + assert "sqlfluff format made no changes." in result.stderr + assert "sqlfluff made no changes." in result.stderr + assert files["project/query.sql"] == GOOD_FILE + + +def test_passing_fmt(good_query: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(good_query) as tmpdir: + result = run_pants([*args, "fmt", f"{tmpdir}/project:"]) + files = collect_files(tmpdir) + result.assert_success() + assert result.stdout == "" + assert "sqlfluff format made no changes." in result.stderr + assert files["project/query.sql"] == GOOD_FILE + + +@pytest.fixture +def bad_query(): + return { + "project/query.sql": BAD_FILE, + "project/BUILD": "sql_sources()", + "project/.sqlfluff": CONFIG_POSTGRES, + } + + +def test_failing_lint(bad_query: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(bad_query) as tmpdir: + result = run_pants([*args, "lint", f"{tmpdir}/project:"]) + result.assert_failure() + assert ( + dedent( + f"""\ + == [{tmpdir}/project/query.sql] FAIL + L: 3 | P: 5 | RF03 | Unqualified reference 'name' found in single table + | select. [references.consistent] + L: 3 | P: 5 | RF03 | Unqualified reference 'name' found in single table + | select which is inconsistent with previous references. + | [references.consistent] + All Finished! 
+ """ + ) + in result.stderr ) - assert fmt_result.did_change is True -def test_skip_field(rule_runner: RuleRunner) -> None: - rule_runner.write_files( - { - "good.sql": GOOD_FILE, - "bad.sql": BAD_FILE, - "unformatted.sql": UNFORMATTED_FILE, - "BUILD": "sql_sources(name='t', skip_sqlfluff=True)", - } +def test_failing_fix(bad_query: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(bad_query) as tmpdir: + result = run_pants([*args, "fix", f"{tmpdir}/project:"]) + files = collect_files(tmpdir) + result.assert_success() + assert "sqlfluff made changes." in result.stderr + assert files["project/query.sql"] == GOOD_FILE + + +def test_failing_fmt(bad_query: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(bad_query) as tmpdir: + result = run_pants([*args, "fmt", f"{tmpdir}/project:"]) + files = collect_files(tmpdir) + result.assert_success() + assert "sqlfluff format made no changes." in result.stderr + assert files["project/query.sql"] == BAD_FILE + + +@pytest.fixture +def multiple_queries() -> dict[str, str]: + return { + "project/good.sql": GOOD_FILE, + "project/bad.sql": BAD_FILE, + "project/unformatted.sql": UNFORMATTED_FILE, + "project/BUILD": "sql_sources(name='t')", + "project/.sqlfluff": CONFIG_POSTGRES, + } + + +def test_multiple_targets_lint(multiple_queries: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(multiple_queries) as tmpdir: + result = run_pants([*args, "lint", f"{tmpdir}/project:"]) + result.assert_failure() + assert ( + dedent( + f"""\ + == [{tmpdir}/project/bad.sql] FAIL + L: 3 | P: 5 | RF03 | Unqualified reference 'name' found in single table + | select. [references.consistent] + L: 3 | P: 5 | RF03 | Unqualified reference 'name' found in single table + | select which is inconsistent with previous references. + | [references.consistent] + == [{tmpdir}/project/unformatted.sql] FAIL + L: 1 | P: 1 | LT09 | Select targets should be on a new line unless there is + | only one select target. 
[layout.select_targets] + All Finished! + """ + ) + in result.stderr ) - tgts = [ - rule_runner.get_target(Address("", target_name="t", relative_file_path="good.sql")), - rule_runner.get_target(Address("", target_name="t", relative_file_path="bad.sql")), - rule_runner.get_target(Address("", target_name="t", relative_file_path="unformatted.sql")), - ] - for tgt in tgts: - assert tgt.get(SkipSqlfluffField).value is True - fix_result, lint_result, fmt_result = run_sqlfluff(rule_runner, tgts) - assert lint_result.exit_code == 0 - assert fix_result.output == rule_runner.make_snapshot({}) - assert fix_result.did_change is False - assert fmt_result.output == rule_runner.make_snapshot({}) - assert fmt_result.did_change is False +def test_multiple_targets_fix(multiple_queries: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(multiple_queries) as tmpdir: + result = run_pants([*args, "fix", f"{tmpdir}/project:"]) + files = collect_files(tmpdir) + result.assert_success() + assert "sqlfluff made changes." in result.stderr + assert files["project/good.sql"] == GOOD_FILE + assert files["project/bad.sql"] == GOOD_FILE + assert files["project/unformatted.sql"] == GOOD_FILE + + +def test_multiple_targets_fmt(multiple_queries: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(multiple_queries) as tmpdir: + result = run_pants([*args, "fmt", f"{tmpdir}/project:"]) + files = collect_files(tmpdir) + result.assert_success() + assert "sqlfluff format made changes." 
in result.stderr + assert files["project/good.sql"] == GOOD_FILE + assert files["project/bad.sql"] == BAD_FILE + assert files["project/unformatted.sql"] == GOOD_FILE + + +@pytest.fixture +def skip_queries() -> dict[str, str]: + return { + "project/good.sql": GOOD_FILE, + "project/bad.sql": BAD_FILE, + "project/unformatted.sql": UNFORMATTED_FILE, + "project/BUILD": "sql_sources(skip_sqlfluff=True)", + } + + +def test_skip_field_lint(skip_queries: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(skip_queries) as tmpdir: + result = run_pants([*args, "lint", f"{tmpdir}/project:"]) + result.assert_success() + + +def test_skip_field_fix(skip_queries: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(skip_queries) as tmpdir: + result = run_pants([*args, "fix", f"{tmpdir}/project:"]) + files = collect_files(tmpdir) + result.assert_success() + assert files["project/good.sql"] == GOOD_FILE + assert files["project/bad.sql"] == BAD_FILE + assert files["project/unformatted.sql"] == UNFORMATTED_FILE + + +def test_skip_field_fmt(skip_queries: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(skip_queries) as tmpdir: + result = run_pants([*args, "fmt", f"{tmpdir}/project:"]) + files = collect_files(tmpdir) + result.assert_success() + assert files["project/good.sql"] == GOOD_FILE + assert files["project/bad.sql"] == BAD_FILE + assert files["project/unformatted.sql"] == UNFORMATTED_FILE @pytest.mark.parametrize( - "file_path,config_path,extra_args,should_change", + ("files",), ( - [Path("query.sql"), Path("pyproject.toml"), [], False], - [Path("query.sql"), Path(".sqlfluff"), [], False], - [Path("custom/query.sql"), Path("custom/pyproject.toml"), [], False], - [Path("custom/query.sql"), Path("custom/.sqlfluff"), [], False], - [ - Path("query.sql"), - Path("custom/config.sqlfluff"), - ["--sqlfluff-config=custom/config.sqlfluff"], - False, - ], - [ - Path("query.sql"), - Path("custom/.sqlfluff"), - ['--sqlfluff-args="--dialect=postgres"'], - True, 
- ], + pytest.param( + { + "query.sql": BAD_FILE, + "BUILD": "sql_sources()", + "pyproject.toml": CONFIG_PYPROJECT, + }, + id="root:query.sql+pyproject.toml", + ), + pytest.param( + { + "query.sql": BAD_FILE, + "BUILD": "sql_sources()", + ".sqlfluff": CONFIG_NATIVE, + }, + id="root:query.sql+.sqlfluff", + ), + pytest.param( + { + "project/query.sql": BAD_FILE, + "project/BUILD": "sql_sources()", + "project/pyproject.toml": CONFIG_PYPROJECT, + }, + id="subdir:query.sql+pyproject.toml", + ), + pytest.param( + { + "project/query.sql": BAD_FILE, + "project/BUILD": "sql_sources()", + "project/.sqlfluff": CONFIG_NATIVE, + }, + id="subdir:query.sql+.sqlfluff", + ), ), ) -def test_config_file( - rule_runner: RuleRunner, - file_path: Path, - config_path: Path, - extra_args: list[str], - should_change: bool, -) -> None: - if config_path.stem == "pyproject": - config = dedent( - """\ - [tool.sqlfluff.core] - dialect = "postgres" - exclude_rules = ["RF03"] - """ - ) - else: - config = dedent( - """\ - [sqlfluff] - dialect = postgres - exclude_rules = RF03 - """ - ) +def test_config_file(files: dict[str, str], args: list[str]) -> None: + with setup_tmpdir(files) as tmpdir: + result = run_pants([*args, "fix", f"{tmpdir}::"]) - rule_runner.write_files( - { - file_path: BAD_FILE, - file_path.parent / "BUILD": "sql_sources()", - config_path: config, - } - ) + result.assert_success() + assert "sqlfluff made no changes" in result.stderr + + +def test_custom_config_path(args: list[str]) -> None: + files = { + "query.sql": BAD_FILE, + "BUILD": "sql_sources()", + "custom/config.sqlfluff": CONFIG_NATIVE, + } + + with setup_tmpdir(files) as tmpdir: + extra_args = [f"--sqlfluff-config={tmpdir}/custom/config.sqlfluff"] + result = run_pants([*args, *extra_args, "fix", f"{tmpdir}::"]) + + result.assert_success() + assert "sqlfluff made no changes" in result.stderr + + +def test_project_config_doesnt_affect_root_query(args: list[str]) -> None: + files = { + "query.sql": BAD_FILE, + "BUILD": 
"sql_sources()", + "custom/.sqlfluff": CONFIG_NATIVE, + } + extra_args = ['--sqlfluff-args="--dialect=postgres"'] + + with setup_tmpdir(files) as tmpdir: + result = run_pants([*args, *extra_args, "fix", f"{tmpdir}::"]) - spec_path = str(file_path.parent).replace(".", "") - rel_file_path = file_path.relative_to(*file_path.parts[:1]) if spec_path else file_path - addr = Address(spec_path, relative_file_path=str(rel_file_path)) - tgt = rule_runner.get_target(addr) - fix_result, lint_result, _ = run_sqlfluff(rule_runner, [tgt], extra_args=extra_args) - assert lint_result.exit_code == (1 if should_change else 0) - assert fix_result.did_change is should_change + result.assert_success() + assert "sqlfluff made changes" in result.stderr