From 6e2e645ff1246440572a6e635053995145857ca9 Mon Sep 17 00:00:00 2001 From: Maxim Koltsov Date: Sun, 7 Aug 2022 19:14:13 +0300 Subject: [PATCH] Preserve package's source when it has extras --- src/poetry/puzzle/provider.py | 10 +++++++++- tests/puzzle/test_provider.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/src/poetry/puzzle/provider.py b/src/poetry/puzzle/provider.py index 15668ac0b52..f0ba16e5669 100644 --- a/src/poetry/puzzle/provider.py +++ b/src/poetry/puzzle/provider.py @@ -603,7 +603,15 @@ def complete_package( ) package = dependency_package.package dependency = dependency_package.dependency - _dependencies.append(package.without_features().to_dependency()) + new_dependency = package.without_features().to_dependency() + + # When adding dependency foo[extra] -> foo, preserve foo's source, if it's + # specified. This prevents us from trying to get foo from PyPI + # when user explicitly set repo for foo[extra]. + if not new_dependency.source_name and dependency.source_name: + new_dependency.source_name = dependency.source_name + + _dependencies.append(new_dependency) for dep in requires: if not self._python_constraint.allows_any(dep.python_constraint): diff --git a/tests/puzzle/test_provider.py b/tests/puzzle/test_provider.py index 60a45e9fca6..a6c1575681a 100644 --- a/tests/puzzle/test_provider.py +++ b/tests/puzzle/test_provider.py @@ -688,3 +688,31 @@ def test_complete_package_preserves_source_type_with_subdirectories( dependency_one_copy.to_pep_508(), dependency_two.to_pep_508(), } + + +@pytest.mark.parametrize("source_name", [None, "repo"]) +def test_complete_package_with_extras_preserves_source_name( + provider: Provider, repository: Repository, source_name: str | None +) -> None: + package_a = Package("A", "1.0") + package_b = Package("B", "1.0") + dep = get_dependency("B", "^1.0", optional=True) + package_a.add_dependency(dep) + package_a.extras = {"foo": [dep]} + repository.add_package(package_a) + repository.add_package(package_b) + + dependency = Dependency("A", "1.0", extras=["foo"]) + if source_name: + dependency.source_name = source_name + + complete_package = provider.complete_package( + DependencyPackage(dependency, package_a) + ) + + requires = complete_package.package.all_requires + assert len(requires) == 2 + assert requires[0].name == "a" + assert requires[0].source_name == source_name + assert requires[1].name == "b" + assert requires[1].source_name is None