Skip to content
New issue

Have a question about this project? # for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “#”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? # to your account

Add unit tests for stream handlers and Stream::is_readable() #2075

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 70 additions & 27 deletions httplib.h
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ using Progress = std::function<bool(uint64_t current, uint64_t total)>;
struct Response;
using ResponseHandler = std::function<bool(const Response &response)>;

class Stream;
// Note: do not replace 'std::function<bool(Stream &strm)>' with StreamHandler;
// signature is not final
using StreamHandler = std::function<bool(Stream &strm)>;

struct MultipartFormData {
std::string name;
std::string content;
Expand Down Expand Up @@ -641,6 +646,7 @@ struct Request {

// for client
ResponseHandler response_handler;
StreamHandler stream_handler; // EXPERIMENTAL function signature may change
ContentReceiverWithProgress content_receiver;
Progress progress;
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
Expand Down Expand Up @@ -712,6 +718,9 @@ struct Response {
const std::string &content_type);
void set_file_content(const std::string &path);

// EXPERIMENTAL callback function signature may change
void set_stream_handler(StreamHandler stream_handler);

Response() = default;
Response(const Response &) = default;
Response &operator=(const Response &) = default;
Expand All @@ -731,6 +740,8 @@ struct Response {
bool content_provider_success_ = false;
std::string file_content_path_;
std::string file_content_content_type_;
// EXPERIMENTAL function signature may change
StreamHandler stream_handler_;
};

class Stream {
Expand Down Expand Up @@ -1170,6 +1181,7 @@ enum class Error {
Compression,
ConnectionTimeout,
ProxyConnection,
StreamHandler,

// For internal use only
SSLPeerCouldBeClosed_,
Expand Down Expand Up @@ -2260,6 +2272,7 @@ inline std::string to_string(const Error error) {
case Error::Compression: return "Compression failed";
case Error::ConnectionTimeout: return "Connection timed out";
case Error::ProxyConnection: return "Proxy connection failed";
case Error::StreamHandler: return "Stream handler failed";
case Error::Unknown: return "Unknown";
default: break;
}
Expand Down Expand Up @@ -5906,6 +5919,10 @@ inline void Response::set_file_content(const std::string &path) {
file_content_path_ = path;
}

inline void Response::set_stream_handler(StreamHandler stream_handler) {
stream_handler_ = std::move(stream_handler);
}

// Result implementation
inline bool Result::has_request_header(const std::string &key) const {
return request_headers_.find(key) != request_headers_.end();
Expand Down Expand Up @@ -6548,18 +6565,21 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,
res.set_header("Keep-Alive", s);
}

if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) &&
!res.has_header("Content-Type")) {
res.set_header("Content-Type", "text/plain");
}
if (!res.stream_handler_) {
if ((!res.body.empty() || res.content_length_ > 0 ||
res.content_provider_) &&
!res.has_header("Content-Type")) {
res.set_header("Content-Type", "text/plain");
}

if (res.body.empty() && !res.content_length_ && !res.content_provider_ &&
!res.has_header("Content-Length")) {
res.set_header("Content-Length", "0");
}
if (res.body.empty() && !res.content_length_ && !res.content_provider_ &&
!res.has_header("Content-Length")) {
res.set_header("Content-Length", "0");
}

if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) {
res.set_header("Accept-Ranges", "bytes");
if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) {
res.set_header("Accept-Ranges", "bytes");
}
}

if (post_routing_handler_) { post_routing_handler_(req, res); }
Expand All @@ -6577,16 +6597,24 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection,

// Body
auto ret = true;
if (req.method != "HEAD") {
if (!res.body.empty()) {
if (!detail::write_data(strm, res.body.data(), res.body.size())) {
ret = false;
}
} else if (res.content_provider_) {
if (write_content_with_provider(strm, req, res, boundary, content_type)) {
res.content_provider_success_ = true;
} else {
ret = false;
if (res.stream_handler_) {
// Log early
if (logger_) { logger_(req, res); }

return res.stream_handler_(strm);
} else {
if (req.method != "HEAD") {
if (!res.body.empty()) {
if (!detail::write_data(strm, res.body.data(), res.body.size())) {
ret = false;
}
} else if (res.content_provider_) {
if (write_content_with_provider(strm, req, res, boundary,
content_type)) {
res.content_provider_success_ = true;
} else {
ret = false;
}
}
}
}
Expand Down Expand Up @@ -7800,10 +7828,12 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req,
}
}

if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); }
if (!req.stream_handler && !req.has_header("Accept")) {
req.set_header("Accept", "*/*");
}

if (!req.content_receiver) {
if (!req.has_header("Accept-Encoding")) {
if (!req.stream_handler && !req.has_header("Accept-Encoding")) {
std::string accept_encoding;
#ifdef CPPHTTPLIB_BROTLI_SUPPORT
accept_encoding = "br";
Expand All @@ -7821,7 +7851,7 @@ inline bool ClientImpl::write_request(Stream &strm, Request &req,
req.set_header("User-Agent", agent);
}
#endif
};
}

if (req.body.empty()) {
if (req.content_provider_) {
Expand Down Expand Up @@ -8053,10 +8083,23 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req,
res.status != StatusCode::NotModified_304 &&
follow_location_;

if (req.response_handler && !redirect) {
if (!req.response_handler(res)) {
error = Error::Canceled;
return false;
if (!redirect) {
if (req.response_handler) {
if (!req.response_handler(res)) {
error = Error::Canceled;
return false;
}
}

if (req.stream_handler) {
// Log early
if (logger_) { logger_(req, res); }

if (!req.stream_handler(strm)) {
error = Error::StreamHandler;
return false;
}
return true;
}
}

Expand Down
111 changes: 111 additions & 0 deletions test/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8434,3 +8434,114 @@ TEST(ClientInThreadTest, Issue2068) {
t.join();
}
}

template <typename S, typename C>
static void stream_handler_test(S &svr, C &cli) {
const auto delay = std::chrono::milliseconds{200};
const auto timeout_us =
std::chrono::duration_cast<std::chrono::microseconds>(delay).count() / 2;

svr.Get("/", [delay](const Request &req, Response &res) {
// Request should contain limited default headers
EXPECT_EQ(req.has_header("Host"), true);
EXPECT_EQ(req.has_header("User-Agent"), true);
EXPECT_EQ(req.has_header("Connection"), true);
// Need connection to close at the end for test to succeed
EXPECT_EQ(req.get_header_value("Connection"), "close");
// REMOTE_ADDR, REMOTE_PORT, LOCAL_ADDR, LOCAL_PORT = 4
EXPECT_EQ(req.headers.size(), (4 + 3));

res.set_stream_handler([&](Stream &strm) -> bool {
char buf[16]{};
// Client shpuld time out first
std::this_thread::sleep_for(delay);
strm.write(buf, sizeof(buf));

// Synchronize with client and close connection
EXPECT_TRUE(strm.wait_readable());

// Read to avoid RST on Windows
strm.read(buf, sizeof(buf));

return true;
});
});
auto thread = std::thread([&]() { svr.listen(HOST, PORT); });

auto se = detail::scope_exit([&] {
svr.stop();
thread.join();
ASSERT_FALSE(svr.is_running());
});

svr.wait_until_ready();

Request req;
req.method = "GET";
req.path = "/";
req.response_handler = [](const Response &res) -> bool {
EXPECT_EQ(res.get_header_value("Connection"), "close");
EXPECT_EQ(res.headers.size(), 1);
return true;
};
req.stream_handler = [delay](Stream &strm) -> bool {
char buf[16]{};
ssize_t n = 0;
// Buffer should be empty and first read should time out
EXPECT_FALSE(strm.is_readable());
EXPECT_FALSE(strm.wait_readable());

// Sever will send data soon
std::this_thread::sleep_for(delay);
EXPECT_TRUE(strm.wait_readable());

n = strm.read(buf, sizeof(buf) / 2);
EXPECT_EQ(sizeof(buf) / 2, n);

// Server sent 16 bytes, we read 8; remainder should be buffered
EXPECT_TRUE(strm.is_readable());

// Read remaining bytes from buffer
n = strm.read(buf, sizeof(buf) / 2);
EXPECT_EQ(sizeof(buf) / 2, n);

// Buffer should be empty
EXPECT_FALSE(strm.is_readable());

// Signal server to close connection
strm.write(buf, sizeof(buf));
std::this_thread::sleep_for(delay);

// Server should have closed connection
n = strm.read(buf, sizeof(buf));
EXPECT_EQ(0, n);

return true;
};

cli.set_read_timeout(0, timeout_us);

Response res;
Error error;
ASSERT_TRUE(cli.send(req, res, error));
EXPECT_EQ(StatusCode::OK_200, res.status);
EXPECT_EQ(res.headers.size(), 1);
EXPECT_TRUE(res.body.empty());
}

TEST(StreamHandlerTest, Basic) {
Server svr;
Client cli(HOST, PORT);

stream_handler_test(svr, cli);
}

#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
TEST(StreamHandlerTest, BasicSSL) {
SSLServer svr(SERVER_CERT_FILE, SERVER_PRIVATE_KEY_FILE);
SSLClient cli(HOST, PORT);
cli.enable_server_certificate_verification(false);

stream_handler_test(svr, cli);
}
#endif
Loading