Skip to content

Commit

Permalink
[#noissue] Refactor grpc ssl initialization
Browse files Browse the repository at this point in the history
  • Loading branch information
emeroad committed Apr 25, 2023
1 parent 32c0b6d commit 2c887aa
Show file tree
Hide file tree
Showing 12 changed files with 124 additions and 145 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,9 @@
package com.navercorp.pinpoint.grpc.client;

import com.navercorp.pinpoint.grpc.client.config.ClientOption;
import com.navercorp.pinpoint.grpc.client.config.SslOption;

import io.grpc.ClientInterceptor;
import io.grpc.NameResolverProvider;
import io.netty.handler.ssl.SslContext;


/**
Expand All @@ -39,7 +38,7 @@ public interface ChannelFactoryBuilder {

void setClientOption(ClientOption clientOption);

void setSslOption(SslOption sslOption);
void setSslContext(SslContext sslContext);

void setNameResolverProvider(NameResolverProvider nameResolverProvider);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import com.navercorp.pinpoint.grpc.ChannelTypeEnum;
import com.navercorp.pinpoint.grpc.ExecutorUtils;
import com.navercorp.pinpoint.grpc.client.config.ClientOption;
import com.navercorp.pinpoint.grpc.security.SslClientConfig;
import com.navercorp.pinpoint.grpc.security.SslContextFactory;
import io.grpc.ClientInterceptor;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
Expand All @@ -39,7 +37,6 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import javax.net.ssl.SSLException;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
Expand All @@ -63,7 +60,8 @@ public class DefaultChannelFactory implements ChannelFactory {
private final HeaderFactory headerFactory;

private final ClientOption clientOption;
private final SslClientConfig sslClientConfig;
// nullable
private final SslContext sslContext;

private final List<ClientInterceptor> clientInterceptorList;
private final NameResolverProvider nameResolverProvider;
Expand All @@ -79,18 +77,20 @@ public class DefaultChannelFactory implements ChannelFactory {
HeaderFactory headerFactory,
NameResolverProvider nameResolverProvider,
ClientOption clientOption,
SslClientConfig sslClientConfig,
List<ClientInterceptor> clientInterceptorList) {
List<ClientInterceptor> clientInterceptorList,
SslContext sslContext) {
this.factoryName = Objects.requireNonNull(factoryName, "factoryName");
this.executorQueueSize = executorQueueSize;
this.headerFactory = Objects.requireNonNull(headerFactory, "headerFactory");
// @Nullable
this.nameResolverProvider = nameResolverProvider;
this.clientOption = Objects.requireNonNull(clientOption, "clientOption");
this.sslClientConfig = Objects.requireNonNull(sslClientConfig, "sslClientConfig");

Objects.requireNonNull(clientInterceptorList, "clientInterceptorList");
this.clientInterceptorList = new ArrayList<>(clientInterceptorList);
// nullable
this.sslContext = sslContext;


ChannelType channelType = getChannelType();
this.channelType = channelType.getChannelType();
Expand Down Expand Up @@ -151,23 +151,15 @@ public ManagedChannel build(String channelName, String host, int port) {
}
setupClientOption(channelBuilder);

if (sslClientConfig.isEnable()) {
SslContext sslContext = null;
try {
SslContextFactory factory = new SslContextFactory(sslClientConfig.getSslProviderType());
sslContext = factory.forClient(sslClientConfig);
} catch (SSLException e) {
throw new SecurityException(e);
}
if (sslContext != null) {
logger.info("{} enable SslContext", channelName);
channelBuilder.sslContext(sslContext);
channelBuilder.negotiationType(NegotiationType.TLS);
}

channelBuilder.maxTraceEvents(clientOption.getMaxTraceEvent());

final ManagedChannel channel = channelBuilder.build();

return channel;
return channelBuilder.build();
}

@SuppressWarnings("deprecation")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,11 @@

import com.navercorp.pinpoint.common.util.Assert;
import com.navercorp.pinpoint.grpc.client.config.ClientOption;
import com.navercorp.pinpoint.grpc.client.config.SslOption;
import com.navercorp.pinpoint.grpc.security.SslClientConfig;
import com.navercorp.pinpoint.grpc.util.Resource;

import io.grpc.ClientInterceptor;
import io.grpc.NameResolverProvider;
import org.apache.logging.log4j.Logger;
import io.netty.handler.ssl.SslContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.LinkedList;
import java.util.Objects;
Expand All @@ -43,7 +40,7 @@ public class DefaultChannelFactoryBuilder implements ChannelFactoryBuilder {
private HeaderFactory headerFactory;

private ClientOption clientOption;
private SslOption sslOption;
private SslContext sslContext;

private final LinkedList<ClientInterceptor> clientInterceptorList = new LinkedList<>();
private NameResolverProvider nameResolverProvider;
Expand Down Expand Up @@ -81,9 +78,8 @@ public void setClientOption(ClientOption clientOption) {
}

@Override
public void setSslOption(SslOption sslOption) {
// nullable
this.sslOption = sslOption;
public void setSslContext(SslContext sslContext) {
this.sslContext = sslContext;
}

@Override
Expand All @@ -97,15 +93,8 @@ public ChannelFactory build() {
Objects.requireNonNull(headerFactory, "headerFactory");
Objects.requireNonNull(clientOption, "clientOption");

SslClientConfig sslClientConfig = SslClientConfig.DISABLED_CONFIG;
if (sslOption != null && sslOption.isEnable()) {
String providerType = sslOption.getProviderType();
Resource trustCertResource = sslOption.getTrustCertResource();
sslClientConfig = new SslClientConfig(true, providerType, trustCertResource);
}

return new DefaultChannelFactory(factoryName, executorQueueSize,
headerFactory, nameResolverProvider,
clientOption, sslClientConfig, clientInterceptorList);
clientOption, clientInterceptorList, sslContext);
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

import com.navercorp.pinpoint.common.util.CollectionUtils;
import com.navercorp.pinpoint.common.util.StringUtils;
import com.navercorp.pinpoint.grpc.util.Resource;
import io.grpc.netty.GrpcSslContexts;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
Expand All @@ -39,7 +38,7 @@
*/
public final class SslContextFactory {

private static final Logger LOGGER = LogManager.getLogger(SslContextFactory.class);
private final Logger LOGGER = LogManager.getLogger(SslContextFactory.class);

private final SslProvider sslProvider;

Expand All @@ -66,39 +65,39 @@ public SslContext forServer(InputStream keyCertChainInputStream, InputStream key
}
}

public SslContext forClient(SslClientConfig clientConfig) throws SSLException {
Objects.requireNonNull(clientConfig, "clientConfig");
public SslContext forClient(InputStream trustCertCollectionInputStream) throws SSLException {
Objects.requireNonNull(trustCertCollectionInputStream, "trustCertCollectionInputStream");

if (!clientConfig.isEnable()) {
throw new IllegalArgumentException("sslConfig is disabled.");
try {
SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();

sslContextBuilder.trustManager(trustCertCollectionInputStream);
return createSslContext(sslContextBuilder, sslProvider);
} catch (SSLException e) {
throw e;
} catch (Exception e) {
throw new SSLException(e);
}
}

public SslContext forClient() throws SSLException {
try {
SslContextBuilder sslContextBuilder = SslContextBuilder.forClient();

Resource trustCertResource = clientConfig.getTrustCertResource();
if (trustCertResource != null) {
sslContextBuilder.trustManager(trustCertResource.getInputStream());
} else {
// Loads default Root CA certificates (generally, from JAVA_HOME/lib/cacerts)
TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init((KeyStore)null);
sslContextBuilder.trustManager(trustManagerFactory);
}

SslProvider sslProvider = getSslProvider(clientConfig.getSslProviderType());
SslContext sslContext = createSslContext(sslContextBuilder, sslProvider);
// Loads default Root CA certificates (generally, from JAVA_HOME/lib/cacerts)
TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init((KeyStore)null);
sslContextBuilder.trustManager(trustManagerFactory);

assertValidCipherSuite(sslContext);

return sslContext;
return createSslContext(sslContextBuilder, sslProvider);
} catch (SSLException e) {
throw e;
} catch (Exception e) {
throw new SSLException(e);
}
}


private SslContext createSslContext(SslContextBuilder sslContextBuilder, SslProvider sslProvider) throws SSLException {
sslContextBuilder.sslProvider(sslProvider);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,7 @@ public void read(Properties properties) {
this.spanClientOption = readSpanClientOption(properties);

// Ssl
SslOption sslOption = readSslOption(properties);
this.sslOption = sslOption;
this.sslOption = readSslOption(properties);
}

private ClientOption readAgentClientOption(final Properties properties) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
import com.navercorp.pinpoint.profiler.context.provider.grpc.MetadataGrpcDataSenderProvider;
import com.navercorp.pinpoint.profiler.context.provider.grpc.ReconnectExecutorProvider;
import com.navercorp.pinpoint.profiler.context.provider.grpc.ReconnectSchedulerProvider;
import com.navercorp.pinpoint.profiler.context.provider.grpc.SSLContextProvider;
import com.navercorp.pinpoint.profiler.context.provider.grpc.SpanGrpcDataSenderProvider;
import com.navercorp.pinpoint.profiler.context.provider.grpc.StatGrpcDataSenderProvider;
import com.navercorp.pinpoint.profiler.context.thrift.MessageConverter;
Expand All @@ -55,6 +56,7 @@
import com.navercorp.pinpoint.profiler.sender.grpc.metric.DefaultChannelzScheduledReporter;
import io.grpc.LoadBalancerRegistry;
import io.grpc.NameResolverProvider;
import io.netty.handler.ssl.SslContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

Expand Down Expand Up @@ -94,6 +96,8 @@ protected void configure() {

registerGrpcProviders(grpcTransportConfig);

bind(SslContext.class).toProvider(SSLContextProvider.class).in(Scopes.SINGLETON);

// not singleton
bind(ReconnectExecutor.class).toProvider(ReconnectExecutorProvider.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import com.navercorp.pinpoint.grpc.client.HeaderFactory;
import com.navercorp.pinpoint.grpc.client.UnaryCallDeadlineInterceptor;
import com.navercorp.pinpoint.grpc.client.config.ClientOption;
import com.navercorp.pinpoint.grpc.client.config.SslOption;
import com.navercorp.pinpoint.profiler.context.active.ActiveTraceRepository;
import com.navercorp.pinpoint.profiler.context.grpc.config.GrpcTransportConfig;
import com.navercorp.pinpoint.profiler.context.module.AgentDataSender;
Expand All @@ -43,8 +42,9 @@
import com.navercorp.pinpoint.profiler.sender.grpc.ReconnectExecutor;
import io.grpc.ClientInterceptor;
import io.grpc.NameResolverProvider;
import org.apache.logging.log4j.Logger;
import io.netty.handler.ssl.SslContext;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.util.List;
import java.util.Objects;
Expand All @@ -69,6 +69,7 @@ public class AgentGrpcDataSenderProvider implements Provider<EnhancedDataSender<
private final ActiveTraceRepository activeTraceRepository;

private List<ClientInterceptor> clientInterceptorList;
private final Provider<SslContext> sslContextProvider;

@Inject
public AgentGrpcDataSenderProvider(GrpcTransportConfig grpcTransportConfig,
Expand All @@ -77,7 +78,8 @@ public AgentGrpcDataSenderProvider(GrpcTransportConfig grpcTransportConfig,
Provider<ReconnectExecutor> reconnectExecutor,
ScheduledExecutorService retransmissionExecutor,
NameResolverProvider nameResolverProvider,
ActiveTraceRepository activeTraceRepository) {
ActiveTraceRepository activeTraceRepository,
Provider<SslContext> sslContextProvider) {
this.grpcTransportConfig = Objects.requireNonNull(grpcTransportConfig, "grpcTransportConfig");
this.messageConverter = Objects.requireNonNull(messageConverter, "messageConverter");
this.headerFactory = Objects.requireNonNull(headerFactory, "headerFactory");
Expand All @@ -88,6 +90,8 @@ public AgentGrpcDataSenderProvider(GrpcTransportConfig grpcTransportConfig,

this.nameResolverProvider = Objects.requireNonNull(nameResolverProvider, "nameResolverProvider");
this.activeTraceRepository = Objects.requireNonNull(activeTraceRepository, "activeTraceRepository");

this.sslContextProvider = Objects.requireNonNull(sslContextProvider, "sslContextProvider");
}

@Inject(optional = true)
Expand Down Expand Up @@ -141,8 +145,8 @@ ChannelFactoryBuilder newChannelFactoryBuilder(boolean sslEnable) {
channelFactoryBuilder.setClientOption(clientOption);

if (sslEnable) {
SslOption sslOption = grpcTransportConfig.getSslOption();
channelFactoryBuilder.setSslOption(sslOption);
SslContext sslContext = sslContextProvider.get();
channelFactoryBuilder.setSslContext(sslContext);
}

return channelFactoryBuilder;
Expand Down
Loading

0 comments on commit 2c887aa

Please # to comment.