From 558eeac3a7e1a3d2364d1384def5d6ec8941fe33 Mon Sep 17 00:00:00 2001 From: Florian Albrechtskirchinger Date: Wed, 19 Feb 2025 19:15:31 +0100 Subject: [PATCH] Refactor streams: rename is_* to wait_* for clarity - Replace is_readable() with wait_readable() and is_writable() with wait_writable() in the Stream interface. - Implement a new is_readable() function with semantics that more closely reflect its name. It returns immediately whether data is available for reading, without waiting. - Update call sites of is_writable(), removing redundant checks. --- httplib.h | 64 ++++++++++++++++++++--------------- test/fuzzing/server_fuzzer.cc | 4 ++- test/test.cc | 10 +++--- 3 files changed, 45 insertions(+), 33 deletions(-) diff --git a/httplib.h b/httplib.h index b833e18b97..999a069ed6 100644 --- a/httplib.h +++ b/httplib.h @@ -751,7 +751,8 @@ class Stream { virtual ~Stream() = default; virtual bool is_readable() const = 0; - virtual bool is_writable() const = 0; + virtual bool wait_readable() const = 0; + virtual bool wait_writable() const = 0; virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; @@ -2466,7 +2467,8 @@ class BufferStream final : public Stream { ~BufferStream() override = default; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3380,7 +3382,8 @@ class SocketStream final : public Stream { ~SocketStream() override; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3416,7 +3419,8 @@ class SSLSocketStream final : public Stream { ~SSLSocketStream() override; bool is_readable() const override; - bool is_writable() const override; + bool wait_readable() const override; + bool wait_writable() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -4578,7 +4582,7 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, data_sink.write = [&](const char *d, size_t l) -> bool { if (ok) { - if (strm.is_writable() && write_data(strm, d, l)) { + if (write_data(strm, d, l)) { offset += l; } else { ok = false; @@ -4587,10 +4591,10 @@ inline bool write_content(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; while (offset < end_offset && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { error = Error::Write; return false; } else if (!content_provider(offset, end_offset - offset, data_sink)) { @@ -4628,17 +4632,17 @@ write_content_without_length(Stream &strm, data_sink.write = [&](const char *d, size_t l) -> bool { if (ok) { offset += l; - if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; } + if (!write_data(strm, d, l)) { ok = false; } } return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; data_sink.done = [&](void) { data_available = false; }; while (data_available && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { return false; } else if (!content_provider(offset, 0, data_sink)) { return false; @@ -4673,10 +4677,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, // Emit chunked response header and footer for each chunk auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (!strm.is_writable() || - !write_data(strm, chunk.data(), chunk.size())) { - ok = false; - } + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; } } } else { ok = false; @@ -4685,7 +4686,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -4705,8 +4706,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, if (!payload.empty()) { // Emit chunked response header and footer for each chunk auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; - if (!strm.is_writable() || - !write_data(strm, chunk.data(), chunk.size())) { + if (!write_data(strm, chunk.data(), chunk.size())) { ok = false; return; } @@ -4738,7 +4738,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; while (data_available && !is_shutting_down()) { - if (!strm.is_writable()) { + if (!strm.wait_writable()) { error = Error::Write; return false; } else if (!content_provider(offset, 0, data_sink)) { @@ -6029,6 +6029,10 @@ inline SocketStream::SocketStream( inline SocketStream::~SocketStream() = default; inline bool SocketStream::is_readable() const { + return read_buff_off_ < read_buff_content_size_; +} + +inline bool SocketStream::wait_readable() const { if (max_timeout_msec_ <= 0) { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } @@ -6041,7 +6045,7 @@ inline bool SocketStream::is_readable() const { return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } -inline bool SocketStream::is_writable() const { +inline bool SocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_); } @@ -6068,7 +6072,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) { } } - if (!is_readable()) { return -1; } + if (!wait_readable()) { return -1; } read_buff_off_ = 0; read_buff_content_size_ = 0; @@ -6093,7 +6097,7 @@ inline ssize_t SocketStream::read(char *ptr, size_t size) { } inline ssize_t SocketStream::write(const char *ptr, size_t size) { - if (!is_writable()) { return -1; } + if (!wait_writable()) { return -1; } #if defined(_WIN32) && !defined(_WIN64) size = @@ -6124,7 +6128,9 @@ inline time_t SocketStream::duration() const { // Buffer stream implementation inline bool BufferStream::is_readable() const { return true; } -inline bool BufferStream::is_writable() const { return true; } +inline bool BufferStream::wait_readable() const { return true; } + +inline bool BufferStream::wait_writable() const { return true; } inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER < 1910 @@ -9161,6 +9167,10 @@ inline SSLSocketStream::SSLSocketStream( inline SSLSocketStream::~SSLSocketStream() = default; inline bool SSLSocketStream::is_readable() const { + return SSL_pending(ssl_) > 0; +} + +inline bool SSLSocketStream::wait_readable() const { if (max_timeout_msec_ <= 0) { return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; } @@ -9173,7 +9183,7 @@ inline bool SSLSocketStream::is_readable() const { return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; } -inline bool SSLSocketStream::is_writable() const { +inline bool SSLSocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); } @@ -9181,7 +9191,7 @@ inline bool SSLSocketStream::is_writable() const { inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { + } else if (wait_readable()) { auto ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret < 0) { auto err = SSL_get_error(ssl_, ret); @@ -9195,7 +9205,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { #endif if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); - } else if (is_readable()) { + } else if (wait_readable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret >= 0) { return ret; } @@ -9212,7 +9222,7 @@ inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { } inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (is_writable()) { + if (wait_writable()) { auto handle_size = static_cast( std::min(size, (std::numeric_limits::max)())); @@ -9227,7 +9237,7 @@ inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { #else while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { #endif - if (is_writable()) { + if (wait_writable()) { std::this_thread::sleep_for(std::chrono::microseconds{10}); ret = SSL_write(ssl_, ptr, static_cast(handle_size)); if (ret >= 0) { return ret; } diff --git a/test/fuzzing/server_fuzzer.cc b/test/fuzzing/server_fuzzer.cc index b1ba3dbf85..a0f7c0eb83 100644 --- a/test/fuzzing/server_fuzzer.cc +++ b/test/fuzzing/server_fuzzer.cc @@ -25,7 +25,9 @@ class FuzzedStream : public httplib::Stream { bool is_readable() const override { return true; } - bool is_writable() const override { return true; } + bool wait_readable() const override { return true; } + + bool wait_writable() const override { return true; } void get_remote_ip_and_port(std::string &ip, int &port) const override { ip = "127.0.0.1"; diff --git a/test/test.cc b/test/test.cc index 423762b7a6..9762df24fe 100644 --- a/test/test.cc +++ b/test/test.cc @@ -156,7 +156,7 @@ TEST_F(UnixSocketTest, abstract) { } #endif -TEST(SocketStream, is_writable_UNIX) { +TEST(SocketStream, wait_writable_UNIX) { int fds[2]; ASSERT_EQ(0, socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); @@ -167,17 +167,17 @@ TEST(SocketStream, is_writable_UNIX) { }; asSocketStream(fds[0], [&](Stream &s0) { EXPECT_EQ(s0.socket(), fds[0]); - EXPECT_TRUE(s0.is_writable()); + EXPECT_TRUE(s0.wait_writable()); EXPECT_EQ(0, close(fds[1])); - EXPECT_FALSE(s0.is_writable()); + EXPECT_FALSE(s0.wait_writable()); return true; }); EXPECT_EQ(0, close(fds[0])); } -TEST(SocketStream, is_writable_INET) { +TEST(SocketStream, wait_writable_INET) { sockaddr_in addr; memset(&addr, 0, sizeof(addr)); addr.sin_family = AF_INET; @@ -212,7 +212,7 @@ TEST(SocketStream, is_writable_INET) { }; asSocketStream(disconnected_svr_sock, [&](Stream &ss) { EXPECT_EQ(ss.socket(), disconnected_svr_sock); - EXPECT_FALSE(ss.is_writable()); + EXPECT_FALSE(ss.wait_writable()); return true; });