Skip to content

Commit

Permalink
xds: ensure server interceptors are created in a sync context (#11930)
Browse files Browse the repository at this point in the history
`XdsServerWrapper#generatePerRouteInterceptors` was always intended
to be executed within a sync context. This PR ensures that by calling
`syncContext.throwIfNotInThisSynchronizationContext()`.

This change is needed for upcoming xDS filter state retention because
the new tests in XdsServerWrapperTest flake with this NPE:

> `Cannot invoke "io.grpc.xds.client.XdsClient$ResourceWatcher.onChanged(io.grpc.xds.client.XdsClient$ResourceUpdate)" because "this.ldsWatcher" is null`
  • Loading branch information
sergiitk authored Mar 3, 2025
1 parent cdab410 commit 1a2285b
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 36 deletions.
4 changes: 1 addition & 3 deletions xds/src/main/java/io/grpc/xds/XdsServerWrapper.java
Original file line number Diff line number Diff line change
Expand Up @@ -524,9 +524,7 @@ private AtomicReference<ServerRoutingConfig> generateRoutingConfig(FilterChain f

private ImmutableMap<Route, ServerInterceptor> generatePerRouteInterceptors(
@Nullable List<NamedFilterConfig> filterConfigs, List<VirtualHost> virtualHosts) {
// This should always be called from the sync context.
// Ideally we'd want to throw otherwise, but this breaks the tests now.
// syncContext.throwIfNotInThisSynchronizationContext();
syncContext.throwIfNotInThisSynchronizationContext();

ImmutableMap.Builder<Route, ServerInterceptor> perRouteInterceptors =
new ImmutableMap.Builder<>();
Expand Down
83 changes: 65 additions & 18 deletions xds/src/test/java/io/grpc/xds/XdsServerTestHelper.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,18 @@
import io.grpc.xds.client.XdsClient;
import io.grpc.xds.client.XdsInitializationException;
import io.grpc.xds.client.XdsResourceType;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import javax.annotation.Nullable;

/**
Expand Down Expand Up @@ -174,12 +178,18 @@ public List<String> getTargets() {
}
}

// Implementation details:
// 1. Use `synchronized` in methods where XdsClientImpl uses its own `syncContext`.
// 2. Use `serverExecutor` via `execute()` in methods where XdsClientImpl uses watcher's executor.
static final class FakeXdsClient extends XdsClient {
boolean shutdown;
SettableFuture<String> ldsResource = SettableFuture.create();
ResourceWatcher<LdsUpdate> ldsWatcher;
CountDownLatch rdsCount = new CountDownLatch(1);
public static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(5);

private boolean shutdown;
@Nullable SettableFuture<String> ldsResource = SettableFuture.create();
@Nullable ResourceWatcher<LdsUpdate> ldsWatcher;
private CountDownLatch rdsCount = new CountDownLatch(1);
final Map<String, ResourceWatcher<RdsUpdate>> rdsWatchers = new HashMap<>();
@Nullable private volatile Executor serverExecutor;

@Override
public TlsContextManager getSecurityConfig() {
Expand All @@ -193,14 +203,20 @@ public BootstrapInfo getBootstrapInfo() {

@Override
@SuppressWarnings("unchecked")
public <T extends ResourceUpdate> void watchXdsResource(XdsResourceType<T> resourceType,
String resourceName,
ResourceWatcher<T> watcher,
Executor syncContext) {
public synchronized <T extends ResourceUpdate> void watchXdsResource(
XdsResourceType<T> resourceType,
String resourceName,
ResourceWatcher<T> watcher,
Executor executor) {
if (serverExecutor != null) {
assertThat(executor).isEqualTo(serverExecutor);
}

switch (resourceType.typeName()) {
case "LDS":
assertThat(ldsWatcher).isNull();
ldsWatcher = (ResourceWatcher<LdsUpdate>) watcher;
serverExecutor = executor;
ldsResource.set(resourceName);
break;
case "RDS":
Expand All @@ -213,14 +229,14 @@ public <T extends ResourceUpdate> void watchXdsResource(XdsResourceType<T> resou
}

@Override
public <T extends ResourceUpdate> void cancelXdsResourceWatch(XdsResourceType<T> type,
String resourceName,
ResourceWatcher<T> watcher) {
public synchronized <T extends ResourceUpdate> void cancelXdsResourceWatch(
XdsResourceType<T> type, String resourceName, ResourceWatcher<T> watcher) {
switch (type.typeName()) {
case "LDS":
assertThat(ldsWatcher).isNotNull();
ldsResource = null;
ldsWatcher = null;
serverExecutor = null;
break;
case "RDS":
rdsWatchers.remove(resourceName);
Expand All @@ -230,27 +246,58 @@ public <T extends ResourceUpdate> void cancelXdsResourceWatch(XdsResourceType<T>
}

@Override
public void shutdown() {
public synchronized void shutdown() {
shutdown = true;
}

@Override
public boolean isShutDown() {
public synchronized boolean isShutDown() {
return shutdown;
}

public void awaitRds(Duration timeout) throws InterruptedException, TimeoutException {
if (!rdsCount.await(timeout.toMillis(), TimeUnit.MILLISECONDS)) {
throw new TimeoutException("Timeout " + timeout + " waiting for RDSs");
}
}

public void setExpectedRdsCount(int count) {
rdsCount = new CountDownLatch(count);
}

private void execute(Runnable action) {
// This method ensures that all watcher updates:
// - Happen after the server started watching LDS.
// - Are executed within the sync context of the server.
//
// Note that this doesn't guarantee that any of the RDS watchers are created.
// Tests should use setExpectedRdsCount(int) and awaitRds() for that.
if (ldsResource == null) {
throw new IllegalStateException("xDS resource update after watcher cancel");
}
try {
ldsResource.get(DEFAULT_TIMEOUT.toMillis(), TimeUnit.MILLISECONDS);
} catch (ExecutionException | TimeoutException e) {
throw new RuntimeException("Can't resolve LDS resource name in " + DEFAULT_TIMEOUT, e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(e);
}
serverExecutor.execute(action);
}

void deliverLdsUpdate(List<FilterChain> filterChains,
FilterChain defaultFilterChain) {
ldsWatcher.onChanged(LdsUpdate.forTcpListener(Listener.create(
"listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain)));
deliverLdsUpdate(LdsUpdate.forTcpListener(Listener.create(
"listener", "0.0.0.0:1", ImmutableList.copyOf(filterChains), defaultFilterChain)));
}

void deliverLdsUpdate(LdsUpdate ldsUpdate) {
ldsWatcher.onChanged(ldsUpdate);
execute(() -> ldsWatcher.onChanged(ldsUpdate));
}

void deliverRdsUpdate(String rdsName, List<VirtualHost> virtualHosts) {
rdsWatchers.get(rdsName).onChanged(new RdsUpdate(virtualHosts));
void deliverRdsUpdate(String resourceName, List<VirtualHost> virtualHosts) {
execute(() -> rdsWatchers.get(resourceName).onChanged(new RdsUpdate(virtualHosts)));
}
}
}
28 changes: 13 additions & 15 deletions xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
Expand Down Expand Up @@ -252,7 +251,7 @@ public void run() {
FilterChain f0 = createFilterChain("filter-chain-0", hcm_virtual);
FilterChain f1 = createFilterChain("filter-chain-1", createRds("rds"));
xdsClient.deliverLdsUpdate(Collections.singletonList(f0), f1);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1")));
verify(listener, timeout(5000)).onServing();
Expand All @@ -261,7 +260,7 @@ public void run() {
xdsServerWrapper.shutdown();
assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue();
assertThat(xdsClient.isShutDown()).isTrue();
verify(mockServer).shutdown();
assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue();
assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue();
Expand Down Expand Up @@ -303,7 +302,7 @@ public void run() {
verify(mockServer, never()).start();
assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue();
assertThat(xdsClient.isShutDown()).isTrue();
verify(mockServer).shutdown();
assertThat(f0.sslContextProviderSupplier().isShutdown()).isTrue();
assertThat(f1.sslContextProviderSupplier().isShutdown()).isTrue();
Expand Down Expand Up @@ -342,7 +341,7 @@ public void run() {
xdsServerWrapper.shutdown();
assertThat(xdsServerWrapper.isShutdown()).isTrue();
assertThat(xdsClient.ldsResource).isNull();
assertThat(xdsClient.shutdown).isTrue();
assertThat(xdsClient.isShutDown()).isTrue();
verify(mockBuilder, times(1)).build();
verify(mockServer, times(1)).shutdown();
xdsServerWrapper.awaitTermination(1, TimeUnit.SECONDS);
Expand All @@ -367,7 +366,7 @@ public void run() {
FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds"));
SslContextProviderSupplier sslSupplier = filterChain.sslContextProviderSupplier();
xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1")));
try {
Expand Down Expand Up @@ -434,7 +433,7 @@ public void run() {
xdsClient.ldsResource.get(5, TimeUnit.SECONDS);
FilterChain filterChain = createFilterChain("filter-chain-1", createRds("rds"));
xdsClient.deliverLdsUpdate(Collections.singletonList(filterChain), null);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("rds",
Collections.singletonList(createVirtualHost("virtual-host-1")));
try {
Expand Down Expand Up @@ -544,7 +543,7 @@ public void run() {
0L, Collections.singletonList(virtualHost), new ArrayList<NamedFilterConfig>());
EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual);
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
xdsClient.rdsCount = new CountDownLatch(3);
xdsClient.setExpectedRdsCount(3);
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null);
assertThat(start.isDone()).isFalse();
assertThat(selectorManager.getSelectorToUpdateSelector()).isNull();
Expand All @@ -556,7 +555,7 @@ public void run() {
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f2), f3);
verify(mockServer, never()).start();
verify(listener, never()).onServing();
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);

xdsClient.deliverRdsUpdate("r1",
Collections.singletonList(createVirtualHost("virtual-host-1")));
Expand Down Expand Up @@ -602,12 +601,11 @@ public void run() {
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
EnvoyServerProtoData.FilterChain f2 = createFilterChain("filter-chain-2", createRds("r0"));

xdsClient.rdsCount = new CountDownLatch(1);
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), f2);
assertThat(start.isDone()).isFalse();
assertThat(selectorManager.getSelectorToUpdateSelector()).isNull();

xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r0",
Collections.singletonList(createVirtualHost("virtual-host-0")));
start.get(5000, TimeUnit.MILLISECONDS);
Expand All @@ -633,9 +631,9 @@ public void run() {
EnvoyServerProtoData.FilterChain f3 = createFilterChain("filter-chain-3", createRds("r0"));
EnvoyServerProtoData.FilterChain f4 = createFilterChain("filter-chain-4", createRds("r1"));
EnvoyServerProtoData.FilterChain f5 = createFilterChain("filter-chain-4", createRds("r1"));
xdsClient.rdsCount = new CountDownLatch(1);
xdsClient.setExpectedRdsCount(1);
xdsClient.deliverLdsUpdate(Arrays.asList(f5, f3), f4);
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r1",
Collections.singletonList(createVirtualHost("virtual-host-1")));
xdsClient.deliverRdsUpdate("r0",
Expand Down Expand Up @@ -688,7 +686,7 @@ public void run() {
EnvoyServerProtoData.FilterChain f0 = createFilterChain("filter-chain-0", hcmVirtual);
EnvoyServerProtoData.FilterChain f1 = createFilterChain("filter-chain-1", createRds("r0"));
xdsClient.deliverLdsUpdate(Arrays.asList(f0, f1), null);
xdsClient.rdsCount.await();
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.rdsWatchers.get("r0").onError(Status.CANCELLED);
start.get(5000, TimeUnit.MILLISECONDS);
assertThat(selectorManager.getSelectorToUpdateSelector().getRoutingConfigs().size())
Expand Down Expand Up @@ -1235,7 +1233,7 @@ public <ReqT, RespT> ServerCall.Listener<ReqT> interceptCall(ServerCall<ReqT, Re
VirtualHost virtualHost = VirtualHost.create(
"v1", Collections.singletonList("foo.google.com"), Arrays.asList(route),
ImmutableMap.of("filter-config-name-0", f0Override));
xdsClient.rdsCount.await(5, TimeUnit.SECONDS);
xdsClient.awaitRds(FakeXdsClient.DEFAULT_TIMEOUT);
xdsClient.deliverRdsUpdate("r0", Collections.singletonList(virtualHost));
start.get(5000, TimeUnit.MILLISECONDS);
verify(mockServer).start();
Expand Down

0 comments on commit 1a2285b

Please # to comment.