Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Force SSL for all connections of Acceptor #2231

Merged
merged 3 commits into from
Jun 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion src/brpc/acceptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Acceptor::Acceptor(bthread_keytable_pool_t* pool)
, _listened_fd(-1)
, _acception_id(0)
, _empty_cond(&_map_mutex)
, _force_ssl(false)
, _ssl_ctx(NULL)
, _use_rdma(false) {
}
Expand All @@ -48,11 +49,18 @@ Acceptor::~Acceptor() {
}

int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec,
const std::shared_ptr<SocketSSLContext>& ssl_ctx) {
const std::shared_ptr<SocketSSLContext>& ssl_ctx,
bool force_ssl) {
if (listened_fd < 0) {
LOG(FATAL) << "Invalid listened_fd=" << listened_fd;
return -1;
}

if (!ssl_ctx && force_ssl) {
LOG(ERROR) << "Fail to force SSL for all connections "
" because ssl_ctx is NULL";
return -1;
}

BAIDU_SCOPED_LOCK(_map_mutex);
if (_status == UNINITIALIZED) {
Expand All @@ -74,6 +82,7 @@ int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec,
}
}
_idle_timeout_sec = idle_timeout_sec;
_force_ssl = force_ssl;
_ssl_ctx = ssl_ctx;

// Creation of _acception_id is inside lock so that OnNewConnections
Expand Down Expand Up @@ -274,6 +283,7 @@ void Acceptor::OnNewConnectionsUntilEAGAIN(Socket* acception) {
options.fd = in_fd;
butil::sockaddr2endpoint(&in_addr, in_len, &options.remote_side);
options.user = acception->user();
options.force_ssl = am->_force_ssl;
options.initial_ssl_ctx = am->_ssl_ctx;
#if BRPC_WITH_RDMA
if (am->_use_rdma) {
Expand Down
4 changes: 3 additions & 1 deletion src/brpc/acceptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ friend class Server;
// `idle_timeout_sec' > 0
// Return 0 on success, -1 otherwise.
int StartAccept(int listened_fd, int idle_timeout_sec,
const std::shared_ptr<SocketSSLContext>& ssl_ctx);
const std::shared_ptr<SocketSSLContext>& ssl_ctx,
bool force_ssl);

// [thread-safe] Stop accepting connections.
// `closewait_ms' is not used anymore.
Expand Down Expand Up @@ -106,6 +107,7 @@ friend class Server;
// The map containing all the accepted sockets
SocketMap _socket_map;

bool _force_ssl;
std::shared_ptr<SocketSSLContext> _ssl_ctx;

// Whether to use rdma or not
Expand Down
11 changes: 9 additions & 2 deletions src/brpc/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ ServerOptions::ServerOptions()
, bthread_init_count(0)
, internal_port(-1)
, has_builtin_services(true)
, force_ssl(false)
, use_rdma(false)
, http_master_service(NULL)
, health_reporter(NULL)
Expand Down Expand Up @@ -932,6 +933,10 @@ int Server::StartInternal(const butil::EndPoint& endpoint,
return -1;
}
}
} else if (_options.force_ssl) {
LOG(ERROR) << "Fail to force SSL for all connections "
"without ServerOptions.ssl_options";
return -1;
}

_concurrency = 0;
Expand Down Expand Up @@ -1044,7 +1049,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint,

// Pass ownership of `sockfd' to `_am'
if (_am->StartAccept(sockfd, _options.idle_timeout_sec,
_default_ssl_ctx) != 0) {
_default_ssl_ctx,
_options.force_ssl) != 0) {
LOG(ERROR) << "Fail to start acceptor";
return -1;
}
Expand Down Expand Up @@ -1084,7 +1090,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint,
}
// Pass ownership of `sockfd' to `_internal_am'
if (_internal_am->StartAccept(sockfd, _options.idle_timeout_sec,
_default_ssl_ctx) != 0) {
_default_ssl_ctx,
false) != 0) {
LOG(ERROR) << "Fail to start internal_acceptor";
return -1;
}
Expand Down
3 changes: 3 additions & 0 deletions src/brpc/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ struct ServerOptions {
const ServerSSLOptions& ssl_options() const { return *_ssl_options; }
ServerSSLOptions* mutable_ssl_options();

// Force ssl for all connections of the port to Start().
bool force_ssl;

// Whether the server uses rdma or not
// Default: false
bool use_rdma;
Expand Down
5 changes: 5 additions & 0 deletions src/brpc/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -698,6 +698,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) {
m->SetFailed(rc2, "Fail to create auth_id: %s", berror(rc2));
return -1;
}
m->_force_ssl = options.force_ssl;
// Disable SSL check if there is no SSL context
m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN);
m->_ssl_session = NULL;
Expand Down Expand Up @@ -2021,6 +2022,10 @@ ssize_t Socket::DoRead(size_t size_hint) {
}
// _ssl_state has been set
if (ssl_state() == SSL_OFF) {
if (_force_ssl) {
errno = ESSL;
return -1;
}
CHECK(_rdma_state == RDMA_OFF);
return _read_buf.append_from_file_descriptor(fd(), size_hint);
}
Expand Down
4 changes: 4 additions & 0 deletions src/brpc/socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,8 @@ struct SocketOptions {
// one thread at any time.
void (*on_edge_triggered_events)(Socket*);
int health_check_interval_s;
// Only accept ssl connection.
bool force_ssl;
std::shared_ptr<SocketSSLContext> initial_ssl_ctx;
bool use_rdma;
bthread_keytable_pool_t* keytable_pool;
Expand Down Expand Up @@ -826,6 +828,8 @@ friend void DereferenceSocket(Socket*);
// exists in server side
AuthContext* _auth_context;

// Only accept ssl connection.
bool _force_ssl;
SSLState _ssl_state;
SSL* _ssl_session; // owner
std::shared_ptr<SocketSSLContext> _ssl_ctx;
Expand Down
1 change: 1 addition & 0 deletions src/brpc/socket_inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ inline SocketOptions::SocketOptions()
, user(NULL)
, on_edge_triggered_events(NULL)
, health_check_interval_s(-1)
, force_ssl(false)
, use_rdma(false)
, keytable_pool(NULL)
, conn(NULL)
Expand Down
2 changes: 1 addition & 1 deletion test/brpc_channel_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ class ChannelTest : public ::testing::Test{
return -1;
}
}
if (_messenger.StartAccept(listening_fd, -1, NULL) != 0) {
if (_messenger.StartAccept(listening_fd, -1, NULL, false) != 0) {
return -1;
}
return 0;
Expand Down
2 changes: 1 addition & 1 deletion test/brpc_input_messenger_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ TEST_F(MessengerTest, dispatch_tasks) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger[i].AddHandler(pairs[0]));
ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL));
ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL, false));
}

for (size_t i = 0; i < NCLIENT; ++i) {
Expand Down
4 changes: 2 additions & 2 deletions test/brpc_socket_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -339,7 +339,7 @@ TEST_F(SocketTest, single_threaded_connect_and_write) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger->AddHandler(pairs[0]));
ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL));
ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false));

brpc::SocketId id = 8888;
brpc::SocketOptions options;
Expand Down Expand Up @@ -727,7 +727,7 @@ TEST_F(SocketTest, health_check) {
ASSERT_TRUE(listening_fd > 0);
butil::make_non_blocking(listening_fd);
ASSERT_EQ(0, messenger->AddHandler(pairs[0]));
ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL));
ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false));

int64_t start_time = butil::gettimeofday_us();
nref = -1;
Expand Down
50 changes: 50 additions & 0 deletions test/brpc_ssl_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "echo.pb.h"

namespace brpc {

void ExtractHostnames(X509* x, std::vector<std::string>* hostnames);
} // namespace brpc

Expand Down Expand Up @@ -175,6 +176,55 @@ TEST_F(SSLTest, sanity) {
ASSERT_EQ(0, server.Join());
}

TEST_F(SSLTest, force_ssl) {
const int port = 8613;
brpc::Server server;
brpc::ServerOptions options;
EchoServiceImpl echo_svc;
ASSERT_EQ(0, server.AddService(
&echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE));

options.force_ssl = true;
ASSERT_EQ(-1, server.Start(port, &options));

brpc::CertInfo cert;
cert.certificate = "cert1.crt";
cert.private_key = "cert1.key";
options.mutable_ssl_options()->default_cert = cert;

ASSERT_EQ(0, server.Start(port, &options));

test::EchoRequest req;
req.set_message(EXP_REQUEST);
{
brpc::Channel channel;
brpc::ChannelOptions coptions;
coptions.mutable_ssl_options();
coptions.mutable_ssl_options()->sni_name = "localhost";
ASSERT_EQ(0, channel.Init("localhost", port, &coptions));

brpc::Controller cntl;
test::EchoService_Stub stub(&channel);
test::EchoResponse res;
stub.Echo(&cntl, &req, &res, NULL);
EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText();
}

{
brpc::Channel channel;
ASSERT_EQ(0, channel.Init("localhost", port, NULL));

brpc::Controller cntl;
test::EchoService_Stub stub(&channel);
test::EchoResponse res;
stub.Echo(&cntl, &req, &res, NULL);
EXPECT_TRUE(cntl.Failed());
}

ASSERT_EQ(0, server.Stop(0));
ASSERT_EQ(0, server.Join());
}

void CheckCert(const char* cname, const char* cert) {
const int port = 8613;
brpc::Channel channel;
Expand Down