11#pragma once
22
3+ #include < algorithm>
4+ #include < atomic>
5+ #include < cctype>
6+ #include < cstdint>
37#include < memory>
48#include < mutex>
59#include < random>
10+ #include < vector>
611#include < variant>
712
813#include " glaze/net/http_client.hpp"
@@ -12,6 +17,15 @@ namespace glz
1217{
1318 struct websocket_client
1419 {
20+ enum class header_validation_error : uint8_t
21+ {
22+ none,
23+ empty_name,
24+ reserved_name,
25+ invalid_name,
26+ invalid_value,
27+ };
28+
1529 using message_handler_t = std::function<void (std::string_view, ws_opcode)>;
1630 using open_handler_t = std::function<void ()>;
1731 using close_handler_t = std::function<void (ws_close_code, std::string_view)>;
@@ -47,6 +61,9 @@ namespace glz
4761 std::shared_ptr<asio::ssl::context> ssl_ctx_;
4862 std::shared_ptr<ssl_socket> ssl_socket_;
4963#endif
64+ mutable std::mutex request_headers_mutex_;
65+ std::vector<std::pair<std::string, std::string>> request_headers_;
66+ std::atomic<header_validation_error> last_header_validation_error_{header_validation_error::none};
5067
5168 size_t max_message_size{1024 * 1024 * 16 }; // 16 MB limit
5269#ifdef GLZ_ENABLE_SSL
@@ -55,6 +72,128 @@ namespace glz
5572
5673 explicit impl (std::shared_ptr<asio::io_context> context) : ctx(std::move(context)) {}
5774
75+ static bool header_name_equal (std::string_view lhs, std::string_view rhs)
76+ {
77+ if (lhs.size () != rhs.size ()) return false ;
78+ return std::equal (lhs.begin (), lhs.end (), rhs.begin (), [](char a, char b) {
79+ return std::tolower (static_cast <unsigned char >(a)) == std::tolower (static_cast <unsigned char >(b));
80+ });
81+ }
82+
83+ static bool header_name_starts_with (std::string_view value, std::string_view prefix)
84+ {
85+ if (value.size () < prefix.size ()) return false ;
86+ return std::equal (prefix.begin (), prefix.end (), value.begin (), [](char a, char b) {
87+ return std::tolower (static_cast <unsigned char >(a)) == std::tolower (static_cast <unsigned char >(b));
88+ });
89+ }
90+
91+ static bool is_tchar (unsigned char c)
92+ {
93+ if ((c >= ' A' && c <= ' Z' ) || (c >= ' a' && c <= ' z' ) || (c >= ' 0' && c <= ' 9' )) return true ;
94+
95+ switch (c) {
96+ case ' !' :
97+ case ' #' :
98+ case ' $' :
99+ case ' %' :
100+ case ' &' :
101+ case ' \' ' :
102+ case ' *' :
103+ case ' +' :
104+ case ' -' :
105+ case ' .' :
106+ case ' ^' :
107+ case ' _' :
108+ case ' `' :
109+ case ' |' :
110+ case ' ~' :
111+ return true ;
112+ default :
113+ return false ;
114+ }
115+ }
116+
117+ static bool is_reserved_handshake_header (std::string_view name)
118+ {
119+ return header_name_equal (name, " Host" ) || header_name_equal (name, " Upgrade" ) ||
120+ header_name_equal (name, " Connection" ) || header_name_starts_with (name, " Sec-WebSocket-" );
121+ }
122+
123+ static bool validate_header_name (std::string_view name, header_validation_error& error)
124+ {
125+ if (name.empty ()) {
126+ error = header_validation_error::empty_name;
127+ return false ;
128+ }
129+
130+ if (is_reserved_handshake_header (name)) {
131+ error = header_validation_error::reserved_name;
132+ return false ;
133+ }
134+
135+ for (const unsigned char c : name) {
136+ if (!is_tchar (c)) {
137+ error = header_validation_error::invalid_name;
138+ return false ;
139+ }
140+ }
141+ return true ;
142+ }
143+
144+ static bool validate_header_value (std::string_view value, header_validation_error& error)
145+ {
146+ for (const unsigned char c : value) {
147+ if (c == ' \r ' || c == ' \n ' || c == 127 ) {
148+ error = header_validation_error::invalid_value;
149+ return false ;
150+ }
151+ if (c < 32 && c != ' \t ' ) {
152+ error = header_validation_error::invalid_value;
153+ return false ;
154+ }
155+ }
156+ return true ;
157+ }
158+
159+ std::vector<std::pair<std::string, std::string>> request_headers_snapshot () const
160+ {
161+ std::lock_guard<std::mutex> lock (request_headers_mutex_);
162+ return request_headers_;
163+ }
164+
165+ bool set_request_header (std::string_view name, std::string_view value)
166+ {
167+ header_validation_error error = header_validation_error::none;
168+ if (!validate_header_name (name, error) || !validate_header_value (value, error)) {
169+ last_header_validation_error_.store (error, std::memory_order_relaxed);
170+ return false ;
171+ }
172+
173+ std::lock_guard<std::mutex> lock (request_headers_mutex_);
174+ for (auto & [existing_name, existing_value] : request_headers_) {
175+ if (header_name_equal (existing_name, name)) {
176+ existing_value = std::string (value);
177+ last_header_validation_error_.store (header_validation_error::none, std::memory_order_relaxed);
178+ return true ;
179+ }
180+ }
181+ request_headers_.emplace_back (std::string (name), std::string (value));
182+ last_header_validation_error_.store (header_validation_error::none, std::memory_order_relaxed);
183+ return true ;
184+ }
185+
186+ void clear_request_headers ()
187+ {
188+ std::lock_guard<std::mutex> lock (request_headers_mutex_);
189+ request_headers_.clear ();
190+ }
191+
192+ header_validation_error last_header_validation_error () const
193+ {
194+ return last_header_validation_error_.load (std::memory_order_relaxed);
195+ }
196+
58197 void cancel_all ()
59198 {
60199 // Clear handlers first to prevent callbacks during cleanup
@@ -214,7 +353,12 @@ namespace glz
214353
215354 std::string handshake = " GET " + url.path + " HTTP/1.1\r\n " + " Host: " + url.host + " \r\n " +
216355 " Upgrade: websocket\r\n " + " Connection: Upgrade\r\n " + " Sec-WebSocket-Key: " + key +
217- " \r\n " + " Sec-WebSocket-Version: 13\r\n\r\n " ;
356+ " \r\n " + " Sec-WebSocket-Version: 13\r\n " ;
357+
358+ for (const auto & [name, value] : request_headers_snapshot ()) {
359+ handshake += name + " : " + value + " \r\n " ;
360+ }
361+ handshake += " \r\n " ;
218362
219363 auto req_buf = std::make_shared<std::string>(std::move (handshake));
220364 std::weak_ptr<impl> weak_self = weak_from_this ();
@@ -413,6 +557,22 @@ namespace glz
413557 void set_ssl_verify_mode (asio::ssl::verify_mode mode) { impl_->ssl_verify_mode_ = mode; }
414558#endif
415559
560+ // Set an additional HTTP header for the opening WebSocket handshake.
561+ // Reserved handshake headers (Host, Upgrade, Connection, Sec-WebSocket-*) cannot be overridden.
562+ // Returns false if the name/value fails validation.
563+ [[nodiscard]] bool set_header (std::string_view name, std::string_view value)
564+ {
565+ return impl_->set_request_header (name, value);
566+ }
567+
568+ // Clear all additional handshake headers previously set via set_header().
569+ void clear_headers () { impl_->clear_request_headers (); }
570+
571+ [[nodiscard]] header_validation_error last_header_error () const
572+ {
573+ return impl_->last_header_validation_error ();
574+ }
575+
416576 std::shared_ptr<asio::io_context>& context () { return impl_->ctx ; }
417577
418578 asio::ip::tcp::socket& socket () { return impl_->get_tcp_socket_ref (); }
0 commit comments