Skip to content

Commit

Permalink
locker: refactor dependency walk logic
Browse files Browse the repository at this point in the history
Resolves: #5141
  • Loading branch information
dimbleby authored Apr 3, 2022
1 parent eb27f81 commit fb13b3a
Show file tree
Hide file tree
Showing 4 changed files with 586 additions and 193 deletions.
182 changes: 78 additions & 104 deletions src/poetry/packages/locker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@


if TYPE_CHECKING:
from poetry.core.semver.version_constraint import VersionConstraint
from poetry.core.version.markers import BaseMarker
from tomlkit.items import InlineTable
from tomlkit.toml_document import TOMLDocument

Expand Down Expand Up @@ -203,152 +205,130 @@ def locked_repository(self, with_dev_reqs: bool = False) -> Repository:

@staticmethod
def __get_locked_package(
_dependency: Dependency, packages_by_name: dict[str, list[Package]]
dependency: Dependency,
packages_by_name: dict[str, list[Package]],
decided: dict[Package, Dependency] | None = None,
) -> Package | None:
"""
Internal helper to identify corresponding locked package using dependency
version constraints.
"""
for _package in packages_by_name.get(_dependency.name, []):
if _dependency.constraint.allows(_package.version):
return _package
return None
decided = decided or {}

# Get the packages that are consistent with this dependency.
packages = [
package
for package in packages_by_name.get(dependency.name, [])
if package.python_constraint.allows_all(dependency.python_constraint)
and dependency.constraint.allows(package.version)
]

# If we've previously made a choice that is compatible with the current
# requirement, stick with it.
for package in packages:
old_decision = decided.get(package)
if (
old_decision is not None
and not old_decision.marker.intersect(dependency.marker).is_empty()
):
return package

return next(iter(packages), None)

@classmethod
def __walk_dependency_level(
def __walk_dependencies(
cls,
dependencies: list[Dependency],
level: int,
pinned_versions: bool,
packages_by_name: dict[str, list[Package]],
project_level_dependencies: set[str],
nested_dependencies: dict[tuple[str, str], Dependency],
) -> dict[tuple[str, str], Dependency]:
if not dependencies:
return nested_dependencies

next_level_dependencies = []
) -> dict[Package, Dependency]:
nested_dependencies: dict[Package, Dependency] = {}

for requirement in dependencies:
key = (requirement.name, requirement.pretty_constraint)
locked_package = cls.__get_locked_package(requirement, packages_by_name)

if locked_package:
# create dependency from locked package to retain dependency metadata
# if this is not done, we can end-up with incorrect nested dependencies
constraint = requirement.constraint
pretty_constraint = requirement.pretty_constraint
marker = requirement.marker
requirement = locked_package.to_dependency()
requirement.marker = requirement.marker.intersect(marker)

key = (requirement.name, pretty_constraint)
visited: set[tuple[Dependency, BaseMarker]] = set()
while dependencies:
requirement = dependencies.pop(0)
if (requirement, requirement.marker) in visited:
continue
visited.add((requirement, requirement.marker))

if not pinned_versions:
requirement.set_constraint(constraint)
locked_package = cls.__get_locked_package(
requirement, packages_by_name, nested_dependencies
)

for require in locked_package.requires:
if require.marker.is_empty():
require.marker = requirement.marker
else:
require.marker = require.marker.intersect(requirement.marker)
if not locked_package:
raise RuntimeError(f"Dependency walk failed at {requirement}")

require.marker = require.marker.intersect(locked_package.marker)
# create dependency from locked package to retain dependency metadata
# if this is not done, we can end-up with incorrect nested dependencies
constraint = requirement.constraint
marker = requirement.marker
extras = requirement.extras
requirement = locked_package.to_dependency()
requirement.marker = requirement.marker.intersect(marker)

if key not in nested_dependencies:
next_level_dependencies.append(require)
requirement.set_constraint(constraint)

if requirement.name in project_level_dependencies and level == 0:
# project level dependencies take precedence
continue
for require in locked_package.requires:
if require.in_extras and extras.isdisjoint(require.in_extras):
continue

if not locked_package:
# we make a copy to avoid any side-effects
requirement = deepcopy(requirement)
require = deepcopy(require)
require.marker = require.marker.intersect(
requirement.marker.without_extras()
)
if not require.marker.is_empty():
dependencies.append(require)

key = locked_package
if key not in nested_dependencies:
nested_dependencies[key] = requirement
else:
nested_dependencies[key].marker = nested_dependencies[key].marker.union(
requirement.marker
)

return cls.__walk_dependency_level(
dependencies=next_level_dependencies,
level=level + 1,
pinned_versions=pinned_versions,
packages_by_name=packages_by_name,
project_level_dependencies=project_level_dependencies,
nested_dependencies=nested_dependencies,
)
return nested_dependencies

@classmethod
def get_project_dependencies(
cls,
project_requires: list[Dependency],
locked_packages: list[Package],
pinned_versions: bool = False,
with_nested: bool = False,
) -> Iterable[Dependency]:
) -> Iterable[tuple[Package, Dependency]]:
# group packages entries by name, this is required because requirement might use
# different constraints
# different constraints.
packages_by_name: dict[str, list[Package]] = {}
for pkg in locked_packages:
if pkg.name not in packages_by_name:
packages_by_name[pkg.name] = []
packages_by_name[pkg.name].append(pkg)

project_level_dependencies = set()
dependencies = []

for dependency in project_requires:
dependency = deepcopy(dependency)
locked_package = cls.__get_locked_package(dependency, packages_by_name)
if locked_package:
locked_dependency = locked_package.to_dependency()
locked_dependency.marker = dependency.marker.intersect(
locked_package.marker
)

if not pinned_versions:
locked_dependency.set_constraint(dependency.constraint)

dependency = locked_dependency

project_level_dependencies.add(dependency.name)
dependencies.append(dependency)

if not with_nested:
# return only with project level dependencies
return dependencies
# Put higher versions first so that we prefer them.
for packages in packages_by_name.values():
packages.sort(key=lambda package: package.version, reverse=True)

nested_dependencies = cls.__walk_dependency_level(
dependencies=dependencies,
level=0,
pinned_versions=pinned_versions,
nested_dependencies = cls.__walk_dependencies(
dependencies=project_requires,
packages_by_name=packages_by_name,
project_level_dependencies=project_level_dependencies,
nested_dependencies={},
)

# Merge same dependencies using marker union
for requirement in dependencies:
key = (requirement.name, requirement.pretty_constraint)
if key not in nested_dependencies:
nested_dependencies[key] = requirement
else:
nested_dependencies[key].marker = nested_dependencies[key].marker.union(
requirement.marker
)

return sorted(nested_dependencies.values(), key=lambda x: x.name.lower())
return nested_dependencies.items()

def get_project_dependency_packages(
self,
project_requires: list[Dependency],
project_python_marker: VersionConstraint | None = None,
dev: bool = False,
extras: bool | Sequence[str] | None = None,
) -> Iterator[DependencyPackage]:
# Apply the project python marker to all requirements.
if project_python_marker is not None:
marked_requires: list[Dependency] = []
for require in project_requires:
require = deepcopy(require)
require.marker = require.marker.intersect(project_python_marker)
marked_requires.append(require)
project_requires = marked_requires

repository = self.locked_repository(with_dev_reqs=dev)

# Build a set of all packages required by our selected extras
Expand Down Expand Up @@ -379,16 +359,10 @@ def get_project_dependency_packages(

selected.append(dependency)

for dependency in self.get_project_dependencies(
for package, dependency in self.get_project_dependencies(
project_requires=selected,
locked_packages=repository.packages,
with_nested=True,
):
try:
package = repository.find_packages(dependency=dependency)[0]
except IndexError:
continue

for extra in dependency.extras:
package.requires_extras.append(extra)

Expand Down
28 changes: 14 additions & 14 deletions src/poetry/utils/exporter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from __future__ import annotations

import itertools
import urllib.parse

from typing import TYPE_CHECKING
Expand Down Expand Up @@ -70,21 +69,22 @@ def _export_requirements_txt(
content = ""
dependency_lines = set()

for package, groups in itertools.groupby(
self._poetry.locker.get_project_dependency_packages(
project_requires=self._poetry.package.all_requires,
dev=dev,
extras=extras,
),
lambda dependency_package: dependency_package.package,
# Get project dependencies.
root_package = (
self._poetry.package.clone()
if dev
else self._poetry.package.with_dependency_groups(["default"], only=True)
)

for dependency_package in self._poetry.locker.get_project_dependency_packages(
project_requires=root_package.all_requires,
project_python_marker=root_package.python_marker,
dev=dev,
extras=extras,
):
line = ""
dependency_packages = list(groups)
dependency = dependency_packages[0].dependency
marker = dependency.marker
for dep_package in dependency_packages[1:]:
marker = marker.union(dep_package.dependency.marker)
dependency.marker = marker
dependency = dependency_package.dependency
package = dependency_package.package

if package.develop:
line += "-e "
Expand Down
20 changes: 15 additions & 5 deletions tests/console/commands/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,9 @@ def _export_requirements(tester: CommandTester, poetry: Poetry) -> None:
assert poetry.locker.lock.exists()

expected = """\
foo==1.0.0
foo==1.0.0 ;\
python_version >= "2.7" and python_version < "2.8" or\
python_version >= "3.4" and python_version < "4.0"
"""

assert content == expected
Expand Down Expand Up @@ -113,7 +115,9 @@ def test_export_fails_on_invalid_format(tester: CommandTester, do_lock: None):
def test_export_prints_to_stdout_by_default(tester: CommandTester, do_lock: None):
tester.execute("--format requirements.txt")
expected = """\
foo==1.0.0
foo==1.0.0 ;\
python_version >= "2.7" and python_version < "2.8" or\
python_version >= "3.4" and python_version < "4.0"
"""
assert tester.io.fetch_output() == expected

Expand All @@ -123,16 +127,22 @@ def test_export_uses_requirements_txt_format_by_default(
):
tester.execute()
expected = """\
foo==1.0.0
foo==1.0.0 ;\
python_version >= "2.7" and python_version < "2.8" or\
python_version >= "3.4" and python_version < "4.0"
"""
assert tester.io.fetch_output() == expected


def test_export_includes_extras_by_flag(tester: CommandTester, do_lock: None):
tester.execute("--format requirements.txt --extras feature_bar")
expected = """\
bar==1.1.0
foo==1.0.0
bar==1.1.0 ;\
python_version >= "2.7" and python_version < "2.8" or\
python_version >= "3.4" and python_version < "4.0"
foo==1.0.0 ;\
python_version >= "2.7" and python_version < "2.8" or\
python_version >= "3.4" and python_version < "4.0"
"""
assert tester.io.fetch_output() == expected

Expand Down
Loading

0 comments on commit fb13b3a

Please sign in to comment.