From 90a165482c70431c6423bce95f867cbe63800c0a Mon Sep 17 00:00:00 2001 From: Bright Chen Date: Sun, 25 Jun 2023 14:35:36 +0800 Subject: [PATCH] Force SSL for all connections of Acceptor (#2231) * Force SSL for all connections * Force SSL for all connections of Acceptor * Force SSL option in ServerOptions --- src/brpc/acceptor.cpp | 12 ++++++- src/brpc/acceptor.h | 4 ++- src/brpc/server.cpp | 11 ++++-- src/brpc/server.h | 3 ++ src/brpc/socket.cpp | 5 +++ src/brpc/socket.h | 4 +++ src/brpc/socket_inl.h | 1 + test/brpc_channel_unittest.cpp | 2 +- test/brpc_input_messenger_unittest.cpp | 2 +- test/brpc_socket_unittest.cpp | 4 +-- test/brpc_ssl_unittest.cpp | 50 ++++++++++++++++++++++++++ 11 files changed, 90 insertions(+), 8 deletions(-) diff --git a/src/brpc/acceptor.cpp b/src/brpc/acceptor.cpp index 62732881f2..f2d1c0871c 100644 --- a/src/brpc/acceptor.cpp +++ b/src/brpc/acceptor.cpp @@ -38,6 +38,7 @@ Acceptor::Acceptor(bthread_keytable_pool_t* pool) , _listened_fd(-1) , _acception_id(0) , _empty_cond(&_map_mutex) + , _force_ssl(false) , _ssl_ctx(NULL) , _use_rdma(false) { } @@ -48,11 +49,18 @@ Acceptor::~Acceptor() { } int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec, - const std::shared_ptr& ssl_ctx) { + const std::shared_ptr& ssl_ctx, + bool force_ssl) { if (listened_fd < 0) { LOG(FATAL) << "Invalid listened_fd=" << listened_fd; return -1; } + + if (!ssl_ctx && force_ssl) { + LOG(ERROR) << "Fail to force SSL for all connections " + " because ssl_ctx is NULL"; + return -1; + } BAIDU_SCOPED_LOCK(_map_mutex); if (_status == UNINITIALIZED) { @@ -74,6 +82,7 @@ int Acceptor::StartAccept(int listened_fd, int idle_timeout_sec, } } _idle_timeout_sec = idle_timeout_sec; + _force_ssl = force_ssl; _ssl_ctx = ssl_ctx; // Creation of _acception_id is inside lock so that OnNewConnections @@ -274,6 +283,7 @@ void Acceptor::OnNewConnectionsUntilEAGAIN(Socket* acception) { options.fd = in_fd; butil::sockaddr2endpoint(&in_addr, in_len, &options.remote_side); options.user = acception->user(); + options.force_ssl = am->_force_ssl; options.initial_ssl_ctx = am->_ssl_ctx; #if BRPC_WITH_RDMA if (am->_use_rdma) { diff --git a/src/brpc/acceptor.h b/src/brpc/acceptor.h index c442a60c8a..c82cdcc19a 100644 --- a/src/brpc/acceptor.h +++ b/src/brpc/acceptor.h @@ -55,7 +55,8 @@ friend class Server; // `idle_timeout_sec' > 0 // Return 0 on success, -1 otherwise. int StartAccept(int listened_fd, int idle_timeout_sec, - const std::shared_ptr& ssl_ctx); + const std::shared_ptr& ssl_ctx, + bool force_ssl); // [thread-safe] Stop accepting connections. // `closewait_ms' is not used anymore. @@ -106,6 +107,7 @@ friend class Server; // The map containing all the accepted sockets SocketMap _socket_map; + bool _force_ssl; std::shared_ptr _ssl_ctx; // Whether to use rdma or not diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp index 4953f88c5d..ce5a0dd2a3 100644 --- a/src/brpc/server.cpp +++ b/src/brpc/server.cpp @@ -139,6 +139,7 @@ ServerOptions::ServerOptions() , bthread_init_count(0) , internal_port(-1) , has_builtin_services(true) + , force_ssl(false) , use_rdma(false) , http_master_service(NULL) , health_reporter(NULL) @@ -933,6 +934,10 @@ int Server::StartInternal(const butil::EndPoint& endpoint, return -1; } } + } else if (_options.force_ssl) { + LOG(ERROR) << "Fail to force SSL for all connections " + "without ServerOptions.ssl_options"; + return -1; } _concurrency = 0; @@ -1045,7 +1050,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint, // Pass ownership of `sockfd' to `_am' if (_am->StartAccept(sockfd, _options.idle_timeout_sec, - _default_ssl_ctx) != 0) { + _default_ssl_ctx, + _options.force_ssl) != 0) { LOG(ERROR) << "Fail to start acceptor"; return -1; } @@ -1085,7 +1091,8 @@ int Server::StartInternal(const butil::EndPoint& endpoint, } // Pass ownership of `sockfd' to `_internal_am' if (_internal_am->StartAccept(sockfd, _options.idle_timeout_sec, - _default_ssl_ctx) != 0) { + _default_ssl_ctx, + false) != 0) { LOG(ERROR) << "Fail to start internal_acceptor"; return -1; } diff --git a/src/brpc/server.h b/src/brpc/server.h index c00f9dc898..e598a6e8b0 100644 --- a/src/brpc/server.h +++ b/src/brpc/server.h @@ -217,6 +217,9 @@ struct ServerOptions { const ServerSSLOptions& ssl_options() const { return *_ssl_options; } ServerSSLOptions* mutable_ssl_options(); + // Force ssl for all connections of the port to Start(). + bool force_ssl; + // Whether the server uses rdma or not // Default: false bool use_rdma; diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index e0a69422fc..c49ca08358 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -698,6 +698,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) { m->SetFailed(rc2, "Fail to create auth_id: %s", berror(rc2)); return -1; } + m->_force_ssl = options.force_ssl; // Disable SSL check if there is no SSL context m->_ssl_state = (options.initial_ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN); m->_ssl_session = NULL; @@ -2026,6 +2027,10 @@ ssize_t Socket::DoRead(size_t size_hint) { } // _ssl_state has been set if (ssl_state() == SSL_OFF) { + if (_force_ssl) { + errno = ESSL; + return -1; + } CHECK(_rdma_state == RDMA_OFF); return _read_buf.append_from_file_descriptor(fd(), size_hint); } diff --git a/src/brpc/socket.h b/src/brpc/socket.h index bd753f6069..eff9474ce5 100644 --- a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -205,6 +205,8 @@ struct SocketOptions { // one thread at any time. void (*on_edge_triggered_events)(Socket*); int health_check_interval_s; + // Only accept ssl connection. + bool force_ssl; std::shared_ptr initial_ssl_ctx; bool use_rdma; bthread_keytable_pool_t* keytable_pool; @@ -826,6 +828,8 @@ friend void DereferenceSocket(Socket*); // exists in server side AuthContext* _auth_context; + // Only accept ssl connection. + bool _force_ssl; SSLState _ssl_state; // SSL objects cannot be read and written at the same time. // Use mutex to protect SSL objects when ssl_state is SSL_CONNECTED. diff --git a/src/brpc/socket_inl.h b/src/brpc/socket_inl.h index 9423bfdf0e..df93ac7109 100644 --- a/src/brpc/socket_inl.h +++ b/src/brpc/socket_inl.h @@ -57,6 +57,7 @@ inline SocketOptions::SocketOptions() , user(NULL) , on_edge_triggered_events(NULL) , health_check_interval_s(-1) + , force_ssl(false) , use_rdma(false) , keytable_pool(NULL) , conn(NULL) diff --git a/test/brpc_channel_unittest.cpp b/test/brpc_channel_unittest.cpp index 4de8e350b5..694f3f7f5c 100644 --- a/test/brpc_channel_unittest.cpp +++ b/test/brpc_channel_unittest.cpp @@ -263,7 +263,7 @@ class ChannelTest : public ::testing::Test{ return -1; } } - if (_messenger.StartAccept(listening_fd, -1, NULL) != 0) { + if (_messenger.StartAccept(listening_fd, -1, NULL, false) != 0) { return -1; } return 0; diff --git a/test/brpc_input_messenger_unittest.cpp b/test/brpc_input_messenger_unittest.cpp index 7682b83b3f..00b14ed41e 100644 --- a/test/brpc_input_messenger_unittest.cpp +++ b/test/brpc_input_messenger_unittest.cpp @@ -169,7 +169,7 @@ TEST_F(MessengerTest, dispatch_tasks) { ASSERT_TRUE(listening_fd > 0); butil::make_non_blocking(listening_fd); ASSERT_EQ(0, messenger[i].AddHandler(pairs[0])); - ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL)); + ASSERT_EQ(0, messenger[i].StartAccept(listening_fd, -1, NULL, false)); } for (size_t i = 0; i < NCLIENT; ++i) { diff --git a/test/brpc_socket_unittest.cpp b/test/brpc_socket_unittest.cpp index 3f08091115..36a3b1b019 100644 --- a/test/brpc_socket_unittest.cpp +++ b/test/brpc_socket_unittest.cpp @@ -339,7 +339,7 @@ TEST_F(SocketTest, single_threaded_connect_and_write) { ASSERT_TRUE(listening_fd > 0); butil::make_non_blocking(listening_fd); ASSERT_EQ(0, messenger->AddHandler(pairs[0])); - ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL)); + ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false)); brpc::SocketId id = 8888; brpc::SocketOptions options; @@ -727,7 +727,7 @@ TEST_F(SocketTest, health_check) { ASSERT_TRUE(listening_fd > 0); butil::make_non_blocking(listening_fd); ASSERT_EQ(0, messenger->AddHandler(pairs[0])); - ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL)); + ASSERT_EQ(0, messenger->StartAccept(listening_fd, -1, NULL, false)); int64_t start_time = butil::gettimeofday_us(); nref = -1; diff --git a/test/brpc_ssl_unittest.cpp b/test/brpc_ssl_unittest.cpp index f32dbcb7c2..7d58e4550c 100644 --- a/test/brpc_ssl_unittest.cpp +++ b/test/brpc_ssl_unittest.cpp @@ -35,6 +35,7 @@ #include "echo.pb.h" namespace brpc { + void ExtractHostnames(X509* x, std::vector* hostnames); } // namespace brpc @@ -175,6 +176,55 @@ TEST_F(SSLTest, sanity) { ASSERT_EQ(0, server.Join()); } +TEST_F(SSLTest, force_ssl) { + const int port = 8613; + brpc::Server server; + brpc::ServerOptions options; + EchoServiceImpl echo_svc; + ASSERT_EQ(0, server.AddService( + &echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE)); + + options.force_ssl = true; + ASSERT_EQ(-1, server.Start(port, &options)); + + brpc::CertInfo cert; + cert.certificate = "cert1.crt"; + cert.private_key = "cert1.key"; + options.mutable_ssl_options()->default_cert = cert; + + ASSERT_EQ(0, server.Start(port, &options)); + + test::EchoRequest req; + req.set_message(EXP_REQUEST); + { + brpc::Channel channel; + brpc::ChannelOptions coptions; + coptions.mutable_ssl_options(); + coptions.mutable_ssl_options()->sni_name = "localhost"; + ASSERT_EQ(0, channel.Init("localhost", port, &coptions)); + + brpc::Controller cntl; + test::EchoService_Stub stub(&channel); + test::EchoResponse res; + stub.Echo(&cntl, &req, &res, NULL); + EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText(); + } + + { + brpc::Channel channel; + ASSERT_EQ(0, channel.Init("localhost", port, NULL)); + + brpc::Controller cntl; + test::EchoService_Stub stub(&channel); + test::EchoResponse res; + stub.Echo(&cntl, &req, &res, NULL); + EXPECT_TRUE(cntl.Failed()); + } + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); +} + void CheckCert(const char* cname, const char* cert) { const int port = 8613; brpc::Channel channel;