diff --git a/httplib.h b/httplib.h index 5c649c4e89..473404644d 100644 --- a/httplib.h +++ b/httplib.h @@ -517,6 +517,9 @@ using Progress = std::function; struct Response; using ResponseHandler = std::function; +class Stream; +using StreamHandler = std::function; + struct MultipartFormData { std::string name; std::string content; @@ -634,6 +637,7 @@ struct Request { // for client ResponseHandler response_handler; + StreamHandler stream_handler; ContentReceiverWithProgress content_receiver; Progress progress; #ifdef CPPHTTPLIB_OPENSSL_SUPPORT @@ -703,6 +707,8 @@ struct Response { const std::string &content_type); void set_file_content(const std::string &path); + void set_stream_handler(StreamHandler stream_handler); + Response() = default; Response(const Response &) = default; Response &operator=(const Response &) = default; @@ -722,6 +728,7 @@ struct Response { bool content_provider_success_ = false; std::string file_content_path_; std::string file_content_content_type_; + StreamHandler stream_handler_; }; class Stream { @@ -731,6 +738,10 @@ class Stream { virtual bool is_readable() const = 0; virtual bool is_writable() const = 0; + // Returns maximum size that may be passed to read() without blocking; read() + // may return fewer bytes + virtual size_t nonblocking_read_size() const = 0; + virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; @@ -1148,6 +1159,7 @@ enum class Error { Compression, ConnectionTimeout, ProxyConnection, + StreamHandler, // For internal use only SSLPeerCouldBeClosed_, @@ -2199,6 +2211,7 @@ inline std::string to_string(const Error error) { case Error::Compression: return "Compression failed"; case Error::ConnectionTimeout: return "Connection timed out"; case Error::ProxyConnection: return "Proxy connection failed"; + case Error::StreamHandler: return "Stream handler failed"; case Error::Unknown: return "Unknown"; default: break; } @@ -2361,12 +2374,23 @@ bool parse_multipart_boundary(const std::string &content_type, bool parse_range_header(const std::string &s, Ranges &ranges); +void random_bytes(char *ptr, size_t length, bool alphanum); + +std::string random_string(size_t length); + int close_socket(socket_t sock); ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); +ssize_t select_read(socket_t sock, time_t sec, time_t usec); + +ssize_t select_read(socket_t sock, socket_t extra_fd, time_t sec, time_t usec, + bool &sock_readable, bool &extra_fd_readable); + +void set_nonblocking(socket_t sock, bool nonblocking); + enum class EncodingType { None = 0, Gzip, Brotli }; EncodingType encoding_type(const Request &req, const Response &res); @@ -2378,6 +2402,7 @@ class BufferStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + size_t nonblocking_read_size() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3173,34 +3198,81 @@ inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, }); } -inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { +template +inline ssize_t select_read_impl(socket_t sock, socket_t extra_fd, time_t sec, + time_t usec, bool *sock_readable, + bool *extra_fd_readable) { #ifdef CPPHTTPLIB_USE_POLL - struct pollfd pfd_read; - pfd_read.fd = sock; - pfd_read.events = POLLIN; - - auto timeout = static_cast(sec * 1000 + usec / 1000); - - return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + constexpr size_t nfds = WithExtraFD ? 2 : 1; + struct pollfd pfd_read[nfds]; + pfd_read[0].fd = sock; + pfd_read[0].events = POLLIN; + if (WithExtraFD) { + pfd_read[1].fd = extra_fd; + pfd_read[1].events = POLLIN; + } + + auto timeout = + static_cast((WithExtraFD && sec == static_cast(-1)) + ? -1 + : sec * 1000 + usec / 1000); + + size_t ret = handle_EINTR([&]() { return poll(pfd_read, nfds, timeout); }); + if (WithExtraFD && ret > 0) { + assert(sock_readable && extra_fd_readable); + *sock_readable = pfd_read[0].revents & POLLIN; + *extra_fd_readable = pfd_read[1].revents & POLLIN; + } + return ret; #else #ifndef _WIN32 - if (sock >= FD_SETSIZE) { return -1; } + if (sock >= FD_SETSIZE || (WithExtraFD && extra_fd >= FD_SETSIZE)) { + return -1; + } #endif + int nfds; fd_set fds; FD_ZERO(&fds); FD_SET(sock, &fds); - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); + if (WithExtraFD) { + FD_SET(extra_fd, &fds); + nfds = static_cast((std::max)(sock, extra_fd) + 1); + } else { + nfds = static_cast(sock + 1); + } - return handle_EINTR([&]() { - return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); - }); + timeval tv, *ptv = nullptr; + if (!WithExtraFD || sec != static_cast(-1)) { + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + ptv = &tv; + } + + ssize_t ret = + handle_EINTR([&]() { return select(nfds, &fds, nullptr, nullptr, ptv); }); + if (WithExtraFD && ret > 0) { + assert(sock_readable && extra_fd_readable); + *sock_readable = FD_ISSET(sock, &fds); + *extra_fd_readable = FD_ISSET(extra_fd, &fds); + } + return ret; #endif } +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { + return select_read_impl(sock, INVALID_SOCKET, sec, usec, nullptr, + nullptr); +} + +inline ssize_t select_read(socket_t sock, socket_t extra_fd, time_t sec, + time_t usec, bool &sock_readable, + bool &extra_fd_readable) { + return select_read_impl(sock, extra_fd, sec, usec, &sock_readable, + &extra_fd_readable); +} + inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { #ifdef CPPHTTPLIB_USE_POLL struct pollfd pfd_read; @@ -3305,6 +3377,7 @@ class SocketStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + size_t nonblocking_read_size() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -3335,6 +3408,7 @@ class SSLSocketStream final : public Stream { bool is_readable() const override; bool is_writable() const override; + size_t nonblocking_read_size() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; @@ -5119,7 +5193,7 @@ class MultipartFormDataParser { size_t buf_epos_ = 0; }; -inline std::string random_string(size_t length) { +inline void random_bytes(char *ptr, size_t length, bool alphanum) { static const char data[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; @@ -5134,11 +5208,26 @@ inline std::string random_string(size_t length) { static std::mt19937 engine(seed_sequence); - std::string result; - for (size_t i = 0; i < length; i++) { - result += data[engine() % (sizeof(data) - 1)]; + if (alphanum) { + for (size_t i = 0; i < length; i++) { + *(ptr++) = data[engine() % (sizeof(data) - 1)]; + } + } else { + for (size_t i = 0; i < length;) { + auto val = engine(); + for (size_t j = 0; i < length && j < sizeof(val); ++i, ++j) { + *(ptr++) = static_cast(val); + val >>= 8; + } + } } - return result; +} + +inline std::string random_string(size_t length) { + std::string s; + s.resize(length); + random_bytes(&s[0], length, true); + return s; } inline std::string make_multipart_data_boundary() { @@ -5922,6 +6011,10 @@ inline void Response::set_file_content(const std::string &path) { file_content_path_ = path; } +inline void Response::set_stream_handler(StreamHandler stream_handler) { + stream_handler_ = std::move(stream_handler); +} + // Result implementation inline bool Result::has_request_header(const std::string &key) const { return request_headers_.find(key) != request_headers_.end(); @@ -5971,6 +6064,10 @@ inline bool SocketStream::is_writable() const { is_socket_alive(sock_); } +inline size_t SocketStream::nonblocking_read_size() const { + return read_buff_content_size_ - read_buff_off_; +} + inline ssize_t SocketStream::read(char *ptr, size_t size) { #ifdef _WIN32 size = @@ -6045,6 +6142,10 @@ inline bool BufferStream::is_readable() const { return true; } inline bool BufferStream::is_writable() const { return true; } +inline size_t BufferStream::nonblocking_read_size() const { + return (std::numeric_limits::max)(); +} + inline ssize_t BufferStream::read(char *ptr, size_t size) { #if defined(_MSC_VER) && _MSC_VER < 1910 auto len_read = buffer._Copy_s(ptr, size, size, position); @@ -6512,31 +6613,34 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, std::string content_type; std::string boundary; - if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } + if (!res.stream_handler_) { + if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } - // Prepare additional headers - if (close_connection || req.get_header_value("Connection") == "close") { - res.set_header("Connection", "close"); - } else { - std::string s = "timeout="; - s += std::to_string(keep_alive_timeout_sec_); - s += ", max="; - s += std::to_string(keep_alive_max_count_); - res.set_header("Keep-Alive", s); - } + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::string s = "timeout="; + s += std::to_string(keep_alive_timeout_sec_); + s += ", max="; + s += std::to_string(keep_alive_max_count_); + res.set_header("Keep-Alive", s); + } - if ((!res.body.empty() || res.content_length_ > 0 || res.content_provider_) && - !res.has_header("Content-Type")) { - res.set_header("Content-Type", "text/plain"); - } + if ((!res.body.empty() || res.content_length_ > 0 || + res.content_provider_) && + !res.has_header("Content-Type")) { + res.set_header("Content-Type", "text/plain"); + } - if (res.body.empty() && !res.content_length_ && !res.content_provider_ && - !res.has_header("Content-Length")) { - res.set_header("Content-Length", "0"); - } + if (res.body.empty() && !res.content_length_ && !res.content_provider_ && + !res.has_header("Content-Length")) { + res.set_header("Content-Length", "0"); + } - if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { - res.set_header("Accept-Ranges", "bytes"); + if (req.method == "HEAD" && !res.has_header("Accept-Ranges")) { + res.set_header("Accept-Ranges", "bytes"); + } } if (post_routing_handler_) { post_routing_handler_(req, res); } @@ -6554,16 +6658,24 @@ inline bool Server::write_response_core(Stream &strm, bool close_connection, // Body auto ret = true; - if (req.method != "HEAD") { - if (!res.body.empty()) { - if (!detail::write_data(strm, res.body.data(), res.body.size())) { - ret = false; - } - } else if (res.content_provider_) { - if (write_content_with_provider(strm, req, res, boundary, content_type)) { - res.content_provider_success_ = true; - } else { - ret = false; + if (res.stream_handler_) { + // Log + if (logger_) { logger_(req, res); } + + return res.stream_handler_(strm); + } else { + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, + content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; + } } } } @@ -8064,10 +8176,23 @@ inline bool ClientImpl::process_request(Stream &strm, Request &req, res.status != StatusCode::NotModified_304 && follow_location_; - if (req.response_handler && !redirect) { - if (!req.response_handler(res)) { - error = Error::Canceled; - return false; + if (!redirect) { + if (req.response_handler) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + if (req.stream_handler) { + // Log + if (logger_) { logger_(req, res); } + + if (!req.stream_handler(strm)) { + error = Error::StreamHandler; + return false; + } + return true; } } @@ -9102,6 +9227,10 @@ inline bool SSLSocketStream::is_writable() const { is_socket_alive(sock_); } +inline size_t SSLSocketStream::nonblocking_read_size() const { + return static_cast(SSL_pending(ssl_)); +} + inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { return SSL_read(ssl_, ptr, static_cast(size)); diff --git a/test/fuzzing/server_fuzzer.cc b/test/fuzzing/server_fuzzer.cc index 3cffbae244..8424246473 100644 --- a/test/fuzzing/server_fuzzer.cc +++ b/test/fuzzing/server_fuzzer.cc @@ -1,4 +1,5 @@ #include +#include #include @@ -27,6 +28,8 @@ class FuzzedStream : public httplib::Stream { bool is_writable() const override { return true; } + size_t nonblocking_read_size() const override { return (std::numeric_limits::max)(); } + void get_remote_ip_and_port(std::string &ip, int &port) const override { ip = "127.0.0.1"; port = 8080;