Skip to content

Commit

Permalink
Force SSL for all connections of Acceptor (#2231)
Browse files Browse the repository at this point in the history
* Force SSL for all connections

* Force SSL for all connections of Acceptor

* Force SSL option in ServerOptions
  • Loading branch information
chenBright authored Jun 25, 2023
1 parent 21a6aab commit bc6f30d
Show file tree
Hide file tree
Showing 11 changed files with 90 additions and 8 deletions.
12 changes: 11 additions & 1 deletion src/brpc/acceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Acceptor::Acceptor(bthread_keytable_pool_t* pool)
, _listened_fd(-1)
, _acception_id(0)
, _empty_cond(&_map_mutex)
, _force_ssl(false)
, _ssl_ctx(NULL)
, _use_rdma(false) {
}
Expand All @@ -48,11 +49,18 @@ Acceptor::~Acceptor() {
}

int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec,
const std::shared_ptr<SocketSSLContext>& ssl_ctx) {
const std::shared_ptr<SocketSSLContext>& ssl_ctx,
bool force_ssl) {
if (listened_fd < 0) {
LOG(FATAL) << "Invalid listened_fd=" << listened_fd;
return -1;
}

if (!ssl_ctx && force_ssl) {
LOG(ERROR) << "Fail to force SSL for all connections "
" because ssl_ctx is NULL";
return -1;
}

BAIDU_SCOPED_LOCK(_map_mutex);
if (_status == UNINITIALIZED) {
Expand All @@ -74,6 +82,7 @@ int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec,
}
}
_idle_timeout_sec = idle_timeout_sec;
_force_ssl = force_ssl;
_ssl_ctx = ssl_ctx;

// Creation of _acception_id is inside lock so that OnNewConnections
Expand Down Expand Up @@ -274,6 +283,7 @@ void Acceptor::OnNewConnectionsUntilEAGAIN(Socket* acception) {
options.fd = in_fd;
butil::sockaddr2endpoint(&in_addr, in_len, &options.remote_side);
options.user = acception->user();
options.force_ssl = am->_force_ssl;
options.initial_ssl_ctx = am->_ssl_ctx;
#if BRPC_WITH_RDMA
if (am->_use_rdma) {
Expand Down
4 changes: 3 additions & 1 deletion src/brpc/acceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ friend class Server;
// `idle_timeout_sec' > 0
// Return 0 on success, -1 otherwise.
int StartAccept(int listened_fd, int idle_timeout_sec,
const std::shared_ptr<SocketSSLContext>& ssl_ctx);
const std::shared_ptr<SocketSSLContext>& ssl_ctx,
bool force_ssl);

// [thread-safe] Stop accepting connections.
// `closewait_ms' is not used anymore.
Expand Down Expand Up @@ -106,6 +107,7 @@ friend class Server;
// The map containing all the accepted sockets
SocketMap _socket_map;

bool _force_ssl;
std::shared_ptr<SocketSSLContext> _ssl_ctx;

// Whether to use rdma or not
Expand Down
11 changes: 9 additions & 2 deletions src/brpc/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ ServerOptions::ServerOptions()
, bthread_init_count(0)
, internal_port(-1)
, has_builtin_services(true)
, force_ssl(false)
, use_rdma(false)
, http_master_service(NULL)
, health_reporter(NULL)
Expand Down Expand Up @@ -933,6 +934,10 @@ int Server::StartInternal(const butil::EndPoint& endpoint,
return -1;
}
}
} else if (_options.force_ssl) {
LOG(ERROR) << "Fail to force SSL for all connections "
"without ServerOptions.ssl_options";
return -1;
}

_concurrency = 0;
Expand Down Expand Up @@ -1045,7 +1050,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint,

// Pass ownership of `sockfd' to `_am'
if (_am->StartAccept(sockfd, _options.idle_timeout_sec,
_default_ssl_ctx) != 0) {
_default_ssl_ctx,
_options.force_ssl) != 0) {
LOG(ERROR) << "Fail to start acceptor";
return -1;
}
Expand Down Expand Up @@ -1085,7 +1091,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint,
}
// Pass ownership of `sockfd' to `_internal_am'
if (_internal_am->StartAccept(sockfd, _options.idle_timeout_sec,
_default_ssl_ctx) != 0) {
_default_ssl_ctx,
false) != 0) {
LOG(ERROR) << "Fail to start internal_acceptor";
return -1;
}
Expand Down
3 changes: 3 additions & 0 deletions src/brpc/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ struct ServerOptions {
const ServerSSLOptions& ssl_options() const { return *_ssl_options; }
ServerSSLOptions* mutable_ssl_options();

// Force ssl for all connections of the port to Start().
bool force_ssl;

// Whether the server uses rdma or not
// Default: false
bool use_rdma;
Expand Down
5 changes: 5 additions & 0 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
m->SetFailed(rc2, "Fail to create auth_id: %s", berror(rc2));
return -1;
}
m->_force_ssl = options.force_ssl;
// Disable SSL check if there is no SSL context
m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN);
m->_ssl_session = NULL;
Expand Down Expand Up @@ -2026,6 +2027,10 @@ ssize_t Socket::DoRead(size_t size_hint) {
}
// _ssl_state has been set
if (ssl_state() == SSL_OFF) {
if (_force_ssl) {
errno = ESSL;
return -1;
}
CHECK(_rdma_state == RDMA_OFF);
return _read_buf.append_from_file_descriptor(fd(), size_hint);
}
Expand Down
4 changes: 4 additions & 0 deletions src/brpc/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ struct SocketOptions {
// one thread at any time.
void (*on_edge_triggered_events)(Socket*);
int health_check_interval_s;
// Only accept ssl connection.
bool force_ssl;
std::shared_ptr<SocketSSLContext> initial_ssl_ctx;
bool use_rdma;
bthread_keytable_pool_t* keytable_pool;
Expand Down Expand Up @@ -826,6 +828,8 @@ friend void DereferenceSocket(Socket*);
// exists in server side
AuthContext* _auth_context;

// Only accept ssl connection.
bool _force_ssl;
SSLState _ssl_state;
// SSL objects cannot be read and written at the same time.
// Use mutex to protect SSL objects when ssl_state is SSL_CONNECTED.
Expand Down
1 change: 1 addition & 0 deletions src/brpc/socket_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ inline SocketOptions::SocketOptions()
, user(NULL)
, on_edge_triggered_events(NULL)
, health_check_interval_s(-1)
, force_ssl(false)
, use_rdma(false)
, keytable_pool(NULL)
, conn(NULL)
Expand Down
2 changes: 1 addition & 1 deletion test/brpc_channel_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ class ChannelTest : public ::testing::Test{
return -1;
}
}
if (_messenger.StartAccept(listening_fd, -1, NULL) != 0) {
if (_messenger.StartAccept(listening_fd, -1, NULL, false) != 0) {
return -1;
}
return 0;
Expand Down
2 changes: 1 addition & 1 deletion test/brpc_input_messenger_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ TEST_F(MessengerTest, dispatch_tasks) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger[i].AddHandler(pairs[0]));
ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL));
ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL, false));
}

for (size_t i = 0; i < NCLIENT; ++i) {
Expand Down
4 changes: 2 additions & 2 deletions test/brpc_socket_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ TEST_F(SocketTest, single_threaded_connect_and_write) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger->AddHandler(pairs[0]));
ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL));
ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false));

brpc::SocketId id = 8888;
brpc::SocketOptions options;
Expand Down Expand Up @@ -727,7 +727,7 @@ TEST_F(SocketTest, health_check) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger->AddHandler(pairs[0]));
ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL));
ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false));

int64_t start_time = butil::gettimeofday_us();
nref = -1;
Expand Down
50 changes: 50 additions & 0 deletions test/brpc_ssl_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "echo.pb.h"

namespace brpc {

void ExtractHostnames(X509* x, std::vector<std::string>* hostnames);
} // namespace brpc

Expand Down Expand Up @@ -175,6 +176,55 @@ TEST_F(SSLTest, sanity) {
ASSERT_EQ(0, server.Join());
}

TEST_F(SSLTest, force_ssl) {
const int port = 8613;
brpc::Server server;
brpc::ServerOptions options;
EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService(
&echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE));

options.force_ssl = true;
ASSERT_EQ(-1, server.Start(port, &options));

brpc::CertInfo cert;
cert.certificate = "cert1.crt";
cert.private_key = "cert1.key";
options.mutable_ssl_options()->default_cert = cert;

ASSERT_EQ(0, server.Start(port, &options));

test::EchoRequest req;
req.set_message(EXP_REQUEST);
{
brpc::Channel channel;
brpc::ChannelOptions coptions;
coptions.mutable_ssl_options();
coptions.mutable_ssl_options()->sni_name = "localhost";
ASSERT_EQ(0, channel.Init("localhost", port, &coptions));

brpc::Controller cntl;
test::EchoService_Stub stub(&channel);
test::EchoResponse res;
stub.Echo(&cntl, &req, &res, NULL);
EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText();
}

{
brpc::Channel channel;
ASSERT_EQ(0, channel.Init("localhost", port, NULL));

brpc::Controller cntl;
test::EchoService_Stub stub(&channel);
test::EchoResponse res;
stub.Echo(&cntl, &req, &res, NULL);
EXPECT_TRUE(cntl.Failed());
}

ASSERT_EQ(0, server.Stop(0));
ASSERT_EQ(0, server.Join());
}

void CheckCert(const char* cname, const char* cert) {
const int port = 8613;
brpc::Channel channel;
Expand Down

0 comments on commit bc6f30d

Please # to comment.