diff --git a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java index 7446905194e..7625b43f4ab 100644 --- a/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/internal/BinderClientTransportTest.java @@ -24,6 +24,8 @@ import android.os.RemoteException; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; +import com.google.common.util.concurrent.Futures; +import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import com.google.protobuf.Empty; import io.grpc.CallOptions; @@ -35,6 +37,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.AsyncSecurityPolicy; import io.grpc.binder.BinderServerBuilder; import io.grpc.binder.HostServices; import io.grpc.binder.SecurityPolicy; @@ -381,6 +384,34 @@ public void testBlackHoleSecurityPolicyConnectTimeout() throws Exception { blockingSecurityPolicy.provideNextCheckAuthorizationResult(Status.OK); } + @Test + public void testAsyncSecurityPolicyFailure() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = new BinderClientTransportBuilder() + .setSecurityPolicy(securityPolicy) + .build(); + RuntimeException exception = new NullPointerException(); + securityPolicy.setAuthorizationException(exception); + transport.start(transportListener).run(); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.INTERNAL); + assertThat(transportStatus.getCause()).isEqualTo(exception); + transportListener.awaitTermination(); + } + + @Test + public void testAsyncSecurityPolicySuccess() throws Exception { + SettableAsyncSecurityPolicy securityPolicy = new SettableAsyncSecurityPolicy(); + transport = new BinderClientTransportBuilder() + .setSecurityPolicy(securityPolicy) + .build(); + securityPolicy.setAuthorizationResult(Status.PERMISSION_DENIED); + transport.start(transportListener).run(); + Status transportStatus = transportListener.awaitShutdown(); + assertThat(transportStatus.getCode()).isEqualTo(Code.PERMISSION_DENIED); + transportListener.awaitTermination(); + } + private static void startAndAwaitReady( BinderTransport.BinderClientTransport transport, TestTransportListener transportListener) throws Exception { @@ -540,4 +571,27 @@ public Status checkAuthorization(int uid) { } } } + + /** + * An AsyncSecurityPolicy that lets a test specify the outcome of checkAuthorizationAsync(). + */ + static class SettableAsyncSecurityPolicy extends AsyncSecurityPolicy { + private SettableFuture result = SettableFuture.create(); + + public void clearAuthorizationResult() { + result = SettableFuture.create(); + } + + public boolean setAuthorizationResult(Status status) { + return result.set(status); + } + + public boolean setAuthorizationException(Throwable t) { + return result.setException(t); + } + + public ListenableFuture checkAuthorizationAsync(int uid) { + return Futures.nonCancellationPropagating(result); + } + } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index 9deb2bfaea1..dbdcaef6908 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -32,6 +32,8 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Ticker; import com.google.common.base.Verify; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; import com.google.common.util.concurrent.ListenableFuture; import io.grpc.Attributes; import io.grpc.CallOptions; @@ -47,6 +49,7 @@ import io.grpc.Status; import io.grpc.StatusException; import io.grpc.binder.AndroidComponentAddress; +import io.grpc.binder.AsyncSecurityPolicy; import io.grpc.binder.InboundParcelablePolicy; import io.grpc.binder.SecurityPolicy; import io.grpc.internal.ClientStream; @@ -743,8 +746,8 @@ void notifyTerminated() { @Override @GuardedBy("this") protected void handleSetupTransport(Parcel parcel) { - // Add the remote uid to our attributes. - attributes = setSecurityAttrs(attributes, Binder.getCallingUid()); + int remoteUid = Binder.getCallingUid(); + attributes = setSecurityAttrs(attributes, remoteUid); if (inState(TransportState.SETUP)) { int version = parcel.readInt(); IBinder binder = parcel.readStrongBinder(); @@ -755,46 +758,54 @@ protected void handleSetupTransport(Parcel parcel) { shutdownInternal( Status.UNAVAILABLE.withDescription("Malformed SETUP_TRANSPORT data"), true); } else { - offloadExecutor.execute(() -> checkSecurityPolicy(binder)); + ListenableFuture authFuture = (securityPolicy instanceof AsyncSecurityPolicy) ? + ((AsyncSecurityPolicy) securityPolicy).checkAuthorizationAsync(remoteUid) : + Futures.submit(() -> securityPolicy.checkAuthorization(remoteUid), offloadExecutor); + Futures.addCallback( + authFuture, + new FutureCallback() { + @Override + public void onSuccess(Status result) { handleAuthResult(binder, result); } + + @Override + public void onFailure(Throwable t) { handleAuthResult(t); } + }, + offloadExecutor); } } } - private void checkSecurityPolicy(IBinder binder) { - Status authorization; - Integer remoteUid; - synchronized (this) { - remoteUid = attributes.get(REMOTE_UID); - } - if (remoteUid == null) { - authorization = Status.UNAUTHENTICATED.withDescription("No remote UID available"); - } else { - authorization = securityPolicy.checkAuthorization(remoteUid); - } - synchronized (this) { - if (inState(TransportState.SETUP)) { - if (!authorization.isOk()) { - shutdownInternal(authorization, true); - } else if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); - } else { - // Check state again, since a failure inside setOutgoingBinder (or a callback it - // triggers), could have shut us down. - if (!isShutdown()) { - setState(TransportState.READY); - attributes = clientTransportListener.filterTransport(attributes); - clientTransportListener.transportReady(); - if (readyTimeoutFuture != null) { - readyTimeoutFuture.cancel(false); - readyTimeoutFuture = null; - } + private synchronized void handleAuthResult(IBinder binder, Status authorization) { + if (inState(TransportState.SETUP)) { + if (!authorization.isOk()) { + shutdownInternal(authorization, true); + } else if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); + } else { + // Check state again, since a failure inside setOutgoingBinder (or a callback it + // triggers), could have shut us down. + if (!isShutdown()) { + setState(TransportState.READY); + attributes = clientTransportListener.filterTransport(attributes); + clientTransportListener.transportReady(); + if (readyTimeoutFuture != null) { + readyTimeoutFuture.cancel(false); + readyTimeoutFuture = null; } } } } } + private synchronized void handleAuthResult(Throwable t) { + shutdownInternal( + Status.INTERNAL + .withDescription("Could not evaluate SecurityPolicy") + .withCause(t), + true); + } + @GuardedBy("this") @Override protected void handlePingResponse(Parcel parcel) {