diff --git a/news/3751.feature.md b/news/3751.feature.md new file mode 100644 index 0000000000..4e31324646 --- /dev/null +++ b/news/3751.feature.md @@ -0,0 +1 @@ +Speed up dependency resolution when there are complex conflicts. diff --git a/src/pdm/resolver/providers.py b/src/pdm/resolver/providers.py index f67779c3bd..e0dbe50e11 100644 --- a/src/pdm/resolver/providers.py +++ b/src/pdm/resolver/providers.py @@ -2,6 +2,7 @@ import dataclasses import os +from collections import defaultdict from functools import cached_property from typing import TYPE_CHECKING, Callable @@ -39,6 +40,7 @@ _PROVIDER_REGISTRY: dict[str, type[BaseProvider]] = {} +_CONFLICT_PRIORITY_THRESHOLD = 5 def get_provider(strategy: str) -> type[BaseProvider]: @@ -79,6 +81,8 @@ def __init__( self.excludes = {normalize_name(k) for k in project.pyproject.resolution.get("excludes", [])} self.direct_minimal_versions = direct_minimal_versions self.locked_repository = locked_repository + self._conflict_counts: defaultdict[str, int] = defaultdict(int) + self._conflict_promoted: set[str] = set() def requirement_preference(self, requirement: Requirement) -> Comparable: """Return the preference of a requirement to find candidates. @@ -97,12 +101,49 @@ def requirement_preference(self, requirement: Requirement) -> Comparable: def identify(self, requirement_or_candidate: Requirement | Candidate) -> str: return requirement_or_candidate.identify() + def narrow_requirement_selection( + self, + identifiers: Iterable[str], + resolutions: Mapping[str, Candidate], + candidates: Mapping[str, Iterator[Candidate]], + information: Mapping[str, Iterator[RequirementInformation]], + backtrack_causes: Sequence[RequirementInformation], + ) -> Iterable[str]: + backtrack_identifiers: set[str] = set() + for requirement, parent in backtrack_causes: + names = [requirement.identify()] + if parent is not None: + names.append(parent.identify()) + for name in names: + backtrack_identifiers.add(name) + if name not in resolutions: + self._conflict_counts[name] += 1 + if self._conflict_counts[name] >= _CONFLICT_PRIORITY_THRESHOLD: + self._conflict_promoted.add(name) + + current_backtrack_causes: list[str] = [] + promoted: list[str] = [] + for identifier in identifiers: + if identifier == "python": + return [identifier] + if identifier in backtrack_identifiers: + current_backtrack_causes.append(identifier) + continue + if identifier in self._conflict_promoted: + promoted.append(identifier) + + if current_backtrack_causes: + return current_backtrack_causes + if promoted: + return promoted + return identifiers + def get_preference( self, identifier: str, - resolutions: dict[str, Candidate], - candidates: dict[str, Iterator[Candidate]], - information: dict[str, Iterator[RequirementInformation]], + resolutions: Mapping[str, Candidate], + candidates: Mapping[str, Iterator[Candidate]], + information: Mapping[str, Iterator[RequirementInformation]], backtrack_causes: Sequence[RequirementInformation], ) -> tuple[Comparable, ...]: is_top = any(parent is None for _, parent in information[identifier]) @@ -123,9 +164,11 @@ def get_preference( is_python = identifier == "python" is_pinned = any(op[:2] == "==" for op in operators) constraints = len(operators) + is_conflict_promoted = identifier in self._conflict_promoted return ( not is_python, not is_top, + not is_conflict_promoted, not is_file_or_url, not is_pinned, not is_backtrack_cause, @@ -458,9 +501,9 @@ def get_dependencies(self, candidate: Candidate) -> list[Requirement]: def get_preference( self, identifier: str, - resolutions: dict[str, Candidate], - candidates: dict[str, Iterator[Candidate]], - information: dict[str, Iterator[RequirementInformation]], + resolutions: Mapping[str, Candidate], + candidates: Mapping[str, Iterator[Candidate]], + information: Mapping[str, Iterator[RequirementInformation]], backtrack_causes: Sequence[RequirementInformation], ) -> tuple[Comparable, ...]: # Resolve tracking packages so we have a chance to unpin them first. diff --git a/tests/resolver/test_providers.py b/tests/resolver/test_providers.py new file mode 100644 index 0000000000..94f66ce099 --- /dev/null +++ b/tests/resolver/test_providers.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from collections.abc import Iterator + +from resolvelib.resolvers import RequirementInformation + +from pdm.models.candidates import Candidate +from pdm.models.requirements import parse_requirement +from pdm.resolver.providers import _CONFLICT_PRIORITY_THRESHOLD + + +def _build_candidates(identifier: str) -> dict[str, Iterator[Candidate]]: + requirement = parse_requirement(identifier) + candidate = Candidate(requirement, name=requirement.project_name, version="1.0") + return {identifier: iter([candidate])} + + +def _build_information(identifier: str) -> dict[str, Iterator[RequirementInformation]]: + requirement = parse_requirement(identifier) + return {identifier: iter([RequirementInformation(requirement, None)])} + + +def test_narrow_requirement_selection_promotes_repeated_conflicts(project, repository): + repository.add_candidate("conflict-pkg", "1.0") + repository.add_candidate("other-pkg", "1.0") + + provider = project.get_provider() + narrow = provider.narrow_requirement_selection + causes = [RequirementInformation(parse_requirement("conflict-pkg"), None)] + + for _ in range(1, _CONFLICT_PRIORITY_THRESHOLD): + result = list(narrow(["other-pkg"], {}, {}, {}, causes)) + assert result == ["other-pkg"] + + result = list(narrow(["other-pkg", "conflict-pkg"], {}, {}, {}, causes)) + assert result == ["conflict-pkg"] + + result = list(narrow(["other-pkg", "conflict-pkg"], {}, {}, {}, [])) + assert result == ["conflict-pkg"] + + other_causes = [RequirementInformation(parse_requirement("other-pkg"), None)] + result = list(narrow(["other-pkg", "conflict-pkg"], {}, {}, {}, other_causes)) + assert result == ["other-pkg"] + + +def test_get_preference_prioritizes_promoted_conflicts(project, repository): + repository.add_candidate("promoted-pkg", "1.0") + repository.add_candidate("normal-pkg", "1.0") + + provider = project.get_provider() + provider._conflict_promoted.add("promoted-pkg") + + promoted_preference = provider.get_preference( + "promoted-pkg", + {}, + _build_candidates("promoted-pkg"), + _build_information("promoted-pkg"), + [], + ) + normal_preference = provider.get_preference( + "normal-pkg", + {}, + _build_candidates("normal-pkg"), + _build_information("normal-pkg"), + [], + ) + + assert promoted_preference < normal_preference