Skip to content

Commit b45df09

Browse files
committed
[CSStep] Don't favor choices until the disjunction is picked
1 parent 52a67e9 commit b45df09

File tree

4 files changed

+33
-22
lines changed

4 files changed

+33
-22
lines changed

include/swift/Sema/ConstraintSystem.h

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5750,8 +5750,9 @@ class ConstraintSystem {
57505750

57515751
/// Pick a disjunction from the InactiveConstraints list.
57525752
///
5753-
/// \returns The selected disjunction.
5754-
Constraint *selectDisjunction();
5753+
/// \returns The selected disjunction and a set of it's favored choices.
5754+
Optional<std::pair<Constraint *, llvm::TinyPtrVector<Constraint *>>>
5755+
selectDisjunction();
57555756

57565757
/// Pick a conjunction from the InactiveConstraints list.
57575758
///
@@ -6693,7 +6694,8 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
66936694
public:
66946695
using Element = DisjunctionChoice;
66956696

6696-
DisjunctionChoiceProducer(ConstraintSystem &cs, Constraint *disjunction)
6697+
DisjunctionChoiceProducer(ConstraintSystem &cs, Constraint *disjunction,
6698+
llvm::TinyPtrVector<Constraint *> &favorites)
66976699
: BindingProducer(cs, disjunction->shouldRememberChoice()
66986700
? disjunction->getLocator()
66996701
: nullptr),
@@ -6703,6 +6705,11 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
67036705
assert(disjunction->getKind() == ConstraintKind::Disjunction);
67046706
assert(!disjunction->shouldRememberChoice() || disjunction->getLocator());
67056707

6708+
// Mark constraints as favored. This information
6709+
// is going to be used by partitioner.
6710+
for (auto *choice : favorites)
6711+
cs.favorConstraint(choice);
6712+
67066713
// Order and partition the disjunction choices.
67076714
partitionDisjunction(Ordering, PartitionBeginning);
67086715
}
@@ -6747,8 +6754,9 @@ class DisjunctionChoiceProducer : public BindingProducer<DisjunctionChoice> {
67476754
// Partition the choices in the disjunction into groups that we will
67486755
// iterate over in an order appropriate to attempt to stop before we
67496756
// have to visit all of the options.
6750-
void partitionDisjunction(SmallVectorImpl<unsigned> &Ordering,
6751-
SmallVectorImpl<unsigned> &PartitionBeginning);
6757+
void
6758+
partitionDisjunction(SmallVectorImpl<unsigned> &Ordering,
6759+
SmallVectorImpl<unsigned> &PartitionBeginning);
67526760

67536761
/// Partition the choices in the range \c first to \c last into groups and
67546762
/// order the groups in the best order to attempt based on the argument

lib/Sema/CSOptimizer.cpp

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -473,15 +473,16 @@ selectBestBindingDisjunction(ConstraintSystem &cs,
473473
return firstBindDisjunction;
474474
}
475475

476-
Constraint *ConstraintSystem::selectDisjunction() {
476+
Optional<std::pair<Constraint *, llvm::TinyPtrVector<Constraint *>>>
477+
ConstraintSystem::selectDisjunction() {
477478
SmallVector<Constraint *, 4> disjunctions;
478479

479480
collectDisjunctions(disjunctions);
480481
if (disjunctions.empty())
481-
return nullptr;
482+
return None;
482483

483484
if (auto *disjunction = selectBestBindingDisjunction(*this, disjunctions))
484-
return disjunction;
485+
return std::make_pair(disjunction, llvm::TinyPtrVector<Constraint *>());
485486

486487
llvm::DenseMap<Constraint *, llvm::TinyPtrVector<Constraint *>> favorings;
487488
determineBestChoicesInContext(*this, disjunctions, favorings);
@@ -513,14 +514,8 @@ Constraint *ConstraintSystem::selectDisjunction() {
513514
return firstFavored < secondFavored;
514515
});
515516

516-
if (bestDisjunction != disjunctions.end()) {
517-
// If selected disjunction has any choices that should be favored
518-
// let's record them now.
519-
for (auto *choice : favorings[*bestDisjunction])
520-
favorConstraint(choice);
521-
522-
return *bestDisjunction;
523-
}
517+
if (bestDisjunction != disjunctions.end())
518+
return std::make_pair(*bestDisjunction, favorings[*bestDisjunction]);
524519

525-
return nullptr;
520+
return None;
526521
}

lib/Sema/CSStep.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ StepResult ComponentStep::take(bool prevFailed) {
359359
}
360360
});
361361

362-
auto *disjunction = CS.selectDisjunction();
362+
auto disjunction = CS.selectDisjunction();
363363
auto *conjunction = CS.selectConjunction();
364364

365365
if (CS.isDebugMode()) {
@@ -402,7 +402,8 @@ StepResult ComponentStep::take(bool prevFailed) {
402402
// Bindings usually happen first, but sometimes we want to prioritize a
403403
// disjunction or conjunction.
404404
if (bestBindings) {
405-
if (disjunction && !bestBindings->favoredOverDisjunction(disjunction))
405+
if (disjunction &&
406+
!bestBindings->favoredOverDisjunction(disjunction->first))
406407
return StepKind::Disjunction;
407408

408409
if (conjunction && !bestBindings->favoredOverConjunction(conjunction))
@@ -426,7 +427,7 @@ StepResult ComponentStep::take(bool prevFailed) {
426427
std::make_unique<TypeVariableStep>(*bestBindings, Solutions));
427428
case StepKind::Disjunction:
428429
return suspend(
429-
std::make_unique<DisjunctionStep>(CS, disjunction, Solutions));
430+
std::make_unique<DisjunctionStep>(CS, *disjunction, Solutions));
430431
case StepKind::Conjunction:
431432
return suspend(
432433
std::make_unique<ConjunctionStep>(CS, conjunction, Solutions));

lib/Sema/CSStep.h

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -684,10 +684,17 @@ class DisjunctionStep final : public BindingStep<DisjunctionChoiceProducer> {
684684
Optional<std::pair<Constraint *, Score>> LastSolvedChoice;
685685

686686
public:
687+
DisjunctionStep(
688+
ConstraintSystem &cs,
689+
std::pair<Constraint *, llvm::TinyPtrVector<Constraint *>> &disjunction,
690+
SmallVectorImpl<Solution> &solutions)
691+
: DisjunctionStep(cs, disjunction.first, disjunction.second, solutions) {}
692+
687693
DisjunctionStep(ConstraintSystem &cs, Constraint *disjunction,
694+
llvm::TinyPtrVector<Constraint *> &favoredChoices,
688695
SmallVectorImpl<Solution> &solutions)
689-
: BindingStep(cs, {cs, disjunction}, solutions), Disjunction(disjunction),
690-
AfterDisjunction(erase(disjunction)) {
696+
: BindingStep(cs, {cs, disjunction, favoredChoices}, solutions),
697+
Disjunction(disjunction), AfterDisjunction(erase(disjunction)) {
691698
assert(Disjunction->getKind() == ConstraintKind::Disjunction);
692699
pruneOverloadSet(Disjunction);
693700
++cs.solverState->NumDisjunctions;

0 commit comments

Comments
 (0)