Skip to content

Commit 17b76f8

Browse files
authored
Additional websocket handshake header support (#2322)
1 parent 4bfb834 commit 17b76f8

File tree

4 files changed

+737
-1
lines changed

4 files changed

+737
-1
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ out/
99
CMakeUserPresets.json
1010
**/CMakeFiles/
1111
**/CMakeCache.txt
12+
**/Testing/
1213

1314
# IDE files
1415
.idea/

docs/networking/websocket-client.md

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ int main() {
5353
- **Message Masking**: Automatically masks outgoing messages in client mode (per RFC 6455)
5454
- **Secure Connections**: Built-in SSL/TLS support for WSS connections
5555
- **Configurable**: Adjustable max message size and other options
56+
- **Custom Handshake Headers**: Add headers such as `Authorization` for authenticated endpoints
5657
- **Shared io_context**: Can share an ASIO io_context with other network operations
5758

5859
## Constructor
@@ -181,6 +182,71 @@ Default is 16 MB (16 * 1024 * 1024 bytes).
181182
client.set_max_message_size(1024 * 1024 * 32); // 32 MB
182183
```
183184

185+
### set_header()
186+
187+
Sets an additional HTTP header for the opening WebSocket handshake.
188+
189+
```cpp
190+
bool set_header(std::string_view name, std::string_view value);
191+
```
192+
193+
Header names and values are validated to prevent malformed handshakes:
194+
- Header names must be valid HTTP token names
195+
- Header values cannot contain control characters such as CR/LF
196+
- Reserved handshake headers are rejected and cannot be overridden:
197+
`Host`, `Upgrade`, `Connection`, and `Sec-WebSocket-*`
198+
- Setting the same header name again (case-insensitive) replaces the previous value
199+
- Returns `true` if the header is accepted and stored
200+
- Returns `false` if validation fails
201+
If validation fails, no header is changed.
202+
203+
**Example:**
204+
```cpp
205+
if (!client.set_header("Authorization", "Bearer your-token")) {
206+
std::cerr << "Invalid Authorization header" << std::endl;
207+
}
208+
client.set_header("X-API-Key", "your-key");
209+
```
210+
211+
### last_header_error()
212+
213+
Returns the validation status from the most recent `set_header()` call.
214+
215+
```cpp
216+
websocket_client::header_validation_error last_header_error() const;
217+
```
218+
219+
Possible values:
220+
- `none`
221+
- `empty_name`
222+
- `reserved_name`
223+
- `invalid_name`
224+
- `invalid_value`
225+
226+
**Example:**
227+
```cpp
228+
if (!client.set_header("Sec-WebSocket-Key", "bad")) {
229+
if (client.last_header_error() == glz::websocket_client::header_validation_error::reserved_name) {
230+
std::cerr << "Reserved handshake header cannot be overridden" << std::endl;
231+
}
232+
}
233+
```
234+
235+
### clear_headers()
236+
237+
Clears all additional handshake headers previously set via `set_header()`.
238+
Headers persist across `connect()` calls until cleared.
239+
This does not change `last_header_error()`, which only reflects the most recent `set_header()` call.
240+
241+
```cpp
242+
void clear_headers();
243+
```
244+
245+
**Example:**
246+
```cpp
247+
client.clear_headers();
248+
```
249+
184250
## Event Handlers
185251

186252
Event handlers are callback functions that are invoked when specific events occur.
@@ -395,6 +461,30 @@ int main() {
395461
}
396462
```
397463

464+
### Authenticated WebSocket Client
465+
466+
```cpp
467+
#include "glaze/net/websocket_client.hpp"
468+
#include <iostream>
469+
470+
int main() {
471+
glz::websocket_client client;
472+
473+
client.set_header("Authorization", "Bearer your-token");
474+
475+
client.on_open([]() {
476+
std::cout << "Authenticated websocket connected" << std::endl;
477+
});
478+
479+
client.on_error([](std::error_code ec) {
480+
std::cerr << "Error: " << ec.message() << std::endl;
481+
});
482+
483+
client.connect("wss://api.example.com/ws");
484+
client.run();
485+
}
486+
```
487+
398488
### JSON Message Exchange
399489

400490
```cpp

include/glaze/net/websocket_client.hpp

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
#pragma once
22

3+
#include <algorithm>
4+
#include <atomic>
5+
#include <cctype>
6+
#include <cstdint>
37
#include <memory>
48
#include <mutex>
59
#include <random>
10+
#include <vector>
611
#include <variant>
712

813
#include "glaze/net/http_client.hpp"
@@ -12,6 +17,15 @@ namespace glz
1217
{
1318
struct websocket_client
1419
{
20+
enum class header_validation_error : uint8_t
21+
{
22+
none,
23+
empty_name,
24+
reserved_name,
25+
invalid_name,
26+
invalid_value,
27+
};
28+
1529
using message_handler_t = std::function<void(std::string_view, ws_opcode)>;
1630
using open_handler_t = std::function<void()>;
1731
using close_handler_t = std::function<void(ws_close_code, std::string_view)>;
@@ -47,6 +61,9 @@ namespace glz
4761
std::shared_ptr<asio::ssl::context> ssl_ctx_;
4862
std::shared_ptr<ssl_socket> ssl_socket_;
4963
#endif
64+
mutable std::mutex request_headers_mutex_;
65+
std::vector<std::pair<std::string, std::string>> request_headers_;
66+
std::atomic<header_validation_error> last_header_validation_error_{header_validation_error::none};
5067

5168
size_t max_message_size{1024 * 1024 * 16}; // 16 MB limit
5269
#ifdef GLZ_ENABLE_SSL
@@ -55,6 +72,128 @@ namespace glz
5572

5673
explicit impl(std::shared_ptr<asio::io_context> context) : ctx(std::move(context)) {}
5774

75+
static bool header_name_equal(std::string_view lhs, std::string_view rhs)
76+
{
77+
if (lhs.size() != rhs.size()) return false;
78+
return std::equal(lhs.begin(), lhs.end(), rhs.begin(), [](char a, char b) {
79+
return std::tolower(static_cast<unsigned char>(a)) == std::tolower(static_cast<unsigned char>(b));
80+
});
81+
}
82+
83+
static bool header_name_starts_with(std::string_view value, std::string_view prefix)
84+
{
85+
if (value.size() < prefix.size()) return false;
86+
return std::equal(prefix.begin(), prefix.end(), value.begin(), [](char a, char b) {
87+
return std::tolower(static_cast<unsigned char>(a)) == std::tolower(static_cast<unsigned char>(b));
88+
});
89+
}
90+
91+
static bool is_tchar(unsigned char c)
92+
{
93+
if ((c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9')) return true;
94+
95+
switch (c) {
96+
case '!':
97+
case '#':
98+
case '$':
99+
case '%':
100+
case '&':
101+
case '\'':
102+
case '*':
103+
case '+':
104+
case '-':
105+
case '.':
106+
case '^':
107+
case '_':
108+
case '`':
109+
case '|':
110+
case '~':
111+
return true;
112+
default:
113+
return false;
114+
}
115+
}
116+
117+
static bool is_reserved_handshake_header(std::string_view name)
118+
{
119+
return header_name_equal(name, "Host") || header_name_equal(name, "Upgrade") ||
120+
header_name_equal(name, "Connection") || header_name_starts_with(name, "Sec-WebSocket-");
121+
}
122+
123+
static bool validate_header_name(std::string_view name, header_validation_error& error)
124+
{
125+
if (name.empty()) {
126+
error = header_validation_error::empty_name;
127+
return false;
128+
}
129+
130+
if (is_reserved_handshake_header(name)) {
131+
error = header_validation_error::reserved_name;
132+
return false;
133+
}
134+
135+
for (const unsigned char c : name) {
136+
if (!is_tchar(c)) {
137+
error = header_validation_error::invalid_name;
138+
return false;
139+
}
140+
}
141+
return true;
142+
}
143+
144+
static bool validate_header_value(std::string_view value, header_validation_error& error)
145+
{
146+
for (const unsigned char c : value) {
147+
if (c == '\r' || c == '\n' || c == 127) {
148+
error = header_validation_error::invalid_value;
149+
return false;
150+
}
151+
if (c < 32 && c != '\t') {
152+
error = header_validation_error::invalid_value;
153+
return false;
154+
}
155+
}
156+
return true;
157+
}
158+
159+
std::vector<std::pair<std::string, std::string>> request_headers_snapshot() const
160+
{
161+
std::lock_guard<std::mutex> lock(request_headers_mutex_);
162+
return request_headers_;
163+
}
164+
165+
bool set_request_header(std::string_view name, std::string_view value)
166+
{
167+
header_validation_error error = header_validation_error::none;
168+
if (!validate_header_name(name, error) || !validate_header_value(value, error)) {
169+
last_header_validation_error_.store(error, std::memory_order_relaxed);
170+
return false;
171+
}
172+
173+
std::lock_guard<std::mutex> lock(request_headers_mutex_);
174+
for (auto& [existing_name, existing_value] : request_headers_) {
175+
if (header_name_equal(existing_name, name)) {
176+
existing_value = std::string(value);
177+
last_header_validation_error_.store(header_validation_error::none, std::memory_order_relaxed);
178+
return true;
179+
}
180+
}
181+
request_headers_.emplace_back(std::string(name), std::string(value));
182+
last_header_validation_error_.store(header_validation_error::none, std::memory_order_relaxed);
183+
return true;
184+
}
185+
186+
void clear_request_headers()
187+
{
188+
std::lock_guard<std::mutex> lock(request_headers_mutex_);
189+
request_headers_.clear();
190+
}
191+
192+
header_validation_error last_header_validation_error() const
193+
{
194+
return last_header_validation_error_.load(std::memory_order_relaxed);
195+
}
196+
58197
void cancel_all()
59198
{
60199
// Clear handlers first to prevent callbacks during cleanup
@@ -214,7 +353,12 @@ namespace glz
214353

215354
std::string handshake = "GET " + url.path + " HTTP/1.1\r\n" + "Host: " + url.host + "\r\n" +
216355
"Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Key: " + key +
217-
"\r\n" + "Sec-WebSocket-Version: 13\r\n\r\n";
356+
"\r\n" + "Sec-WebSocket-Version: 13\r\n";
357+
358+
for (const auto& [name, value] : request_headers_snapshot()) {
359+
handshake += name + ": " + value + "\r\n";
360+
}
361+
handshake += "\r\n";
218362

219363
auto req_buf = std::make_shared<std::string>(std::move(handshake));
220364
std::weak_ptr<impl> weak_self = weak_from_this();
@@ -413,6 +557,22 @@ namespace glz
413557
void set_ssl_verify_mode(asio::ssl::verify_mode mode) { impl_->ssl_verify_mode_ = mode; }
414558
#endif
415559

560+
// Set an additional HTTP header for the opening WebSocket handshake.
561+
// Reserved handshake headers (Host, Upgrade, Connection, Sec-WebSocket-*) cannot be overridden.
562+
// Returns false if the name/value fails validation.
563+
[[nodiscard]] bool set_header(std::string_view name, std::string_view value)
564+
{
565+
return impl_->set_request_header(name, value);
566+
}
567+
568+
// Clear all additional handshake headers previously set via set_header().
569+
void clear_headers() { impl_->clear_request_headers(); }
570+
571+
[[nodiscard]] header_validation_error last_header_error() const
572+
{
573+
return impl_->last_header_validation_error();
574+
}
575+
416576
std::shared_ptr<asio::io_context>& context() { return impl_->ctx; }
417577

418578
asio::ip::tcp::socket& socket() { return impl_->get_tcp_socket_ref(); }

0 commit comments

Comments
 (0)