Skip to content

Commit

Permalink
fixed StaticStrideScheduler looping too many times (#10370)
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyjongyoonan authored Jul 13, 2023
1 parent 78cf1c3 commit 37351c6
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,6 @@ public boolean isEquivalentTo(RoundRobinPicker picker) {
@VisibleForTesting
static final class StaticStrideScheduler {
private final short[] scaledWeights;
private final int sizeDivisor;
private final AtomicInteger sequence;
private static final int K_MAX_WEIGHT = 0xFFFF;

Expand All @@ -373,7 +372,7 @@ static final class StaticStrideScheduler {
if (numWeightedChannels > 0) {
meanWeight = (short) Math.round(scalingFactor * sumWeight / numWeightedChannels);
} else {
meanWeight = 1;
meanWeight = (short) Math.round(scalingFactor);
}

// scales weights s.t. max(weights) == K_MAX_WEIGHT, meanWeight is scaled accordingly
Expand All @@ -387,7 +386,6 @@ static final class StaticStrideScheduler {
}

this.scaledWeights = scaledWeights;
this.sizeDivisor = numChannels;
this.sequence = new AtomicInteger(random.nextInt());

}
Expand Down Expand Up @@ -433,15 +431,18 @@ long getSequence() {
* an offset that varies per backend index is also included to the calculation.
*/
int pick() {
int i = 0;
while (true) {
i++;
long sequence = this.nextSequence();
int backendIndex = (int) (sequence % this.sizeDivisor);
long generation = sequence / this.sizeDivisor;
int weight = Short.toUnsignedInt(this.scaledWeights[backendIndex]);
int backendIndex = (int) (sequence % scaledWeights.length);
long generation = sequence / scaledWeights.length;
int weight = Short.toUnsignedInt(scaledWeights[backendIndex]);
long offset = (long) K_MAX_WEIGHT / 2 * backendIndex;
if ((weight * generation + offset) % K_MAX_WEIGHT < K_MAX_WEIGHT - weight) {
continue;
}
assert i <= scaledWeights.length : "scheduler has more than one pass through";
return backendIndex;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,11 @@ weightedSubchannel3.new OrcaReportListener(weightedConfig.errorUtilizationPenalt
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 10000.0 - subchannel1PickRatio))
.isAtMost(0.001);
.isAtMost(0.0001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 10000.0 - subchannel2PickRatio ))
.isAtMost(0.001);
.isAtMost(0.0001);
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 10000.0 - subchannel3PickRatio ))
.isAtMost(0.001);
.isAtMost(0.0001);
}

@Test
Expand Down Expand Up @@ -751,12 +751,12 @@ weightedSubchannel2.new OrcaReportListener(weightedConfig.errorUtilizationPenalt
}
assertThat(pickCount.size()).isEqualTo(3);
assertThat(Math.abs(pickCount.get(weightedSubchannel1) / 1000.0 - 4.0 / 9))
.isAtMost(0.002);
.isAtMost(0.001);
assertThat(Math.abs(pickCount.get(weightedSubchannel2) / 1000.0 - 2.0 / 9))
.isAtMost(0.002);
.isAtMost(0.001);
// subchannel3's weight is average of subchannel1 and subchannel2
assertThat(Math.abs(pickCount.get(weightedSubchannel3) / 1000.0 - 3.0 / 9))
.isAtMost(0.002);
.isAtMost(0.001);
}

@Test
Expand Down Expand Up @@ -947,7 +947,7 @@ public void testStaticStrideSchedulerNonIntegers1() {
}
for (int i = 0; i < 3; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
.isAtMost(0.001);
}
}

Expand All @@ -964,7 +964,7 @@ public void testStaticStrideSchedulerNonIntegers2() {
}
for (int i = 0; i < 3; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
.isAtMost(0.001);
}
}

Expand All @@ -981,7 +981,7 @@ public void testTwoWeights() {
}
for (int i = 0; i < 2; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
.isAtMost(0.001);
}
}

Expand Down Expand Up @@ -1015,7 +1015,7 @@ public void testManyComplexWeights() {
}
for (int i = 0; i < 8; i++) {
assertThat(Math.abs(pickCount.getOrDefault(i, 0) / 1000.0 - weights[i] / totalWeight))
.isAtMost(0.01);
.isAtMost(0.004);
}
}

Expand Down

0 comments on commit 37351c6

Please # to comment.