Skip to content

Commit 7729826

Browse files
author
abakumov-mik
committed
feat userver: handle non-101 HTTP responses in WebSocket handshake
#### Summary Fixed WebSocket client to properly handle HTTP error responses (401, 403, etc.) during handshake instead of treating them as connection errors. #### Problem curl expects 101 status for WebSocket upgrade. When server returns other HTTP status (e.g., 401 Unauthorized), curl reports `kHttpReturnedError`, preventing client code from accessing the actual HTTP response. #### Solution Ignore curl error in `request_state.cpp` for WebSocket handshake when a valid HTTP status code is received. This allows client to inspect response status, headers, and body. commit_hash:fe9f0782a43dfefc66295357d5a38c54f026d1b5
1 parent 027f1dc commit 7729826

File tree

5 files changed

+98
-20
lines changed

5 files changed

+98
-20
lines changed

core/functional_tests/websocket_client/service.cpp

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,21 @@ class WebSocketEcho final : public server::handlers::WebsocketHandlerBase {
3737
}
3838
};
3939

40+
class WebSocketUnauth final : public server::handlers::WebsocketHandlerBase {
41+
public:
42+
static constexpr std::string_view kName = "websocket-unauth-handler";
43+
44+
using WebsocketHandlerBase::WebsocketHandlerBase;
45+
46+
bool
47+
HandleHandshake(const server::http::HttpRequest&, server::http::HttpResponse& response, server::request::RequestContext&) const override {
48+
response.SetStatus(server::http::HttpStatus::kUnauthorized);
49+
return false;
50+
}
51+
52+
void Handle(websocket::WebSocketConnection&, server::request::RequestContext&) const override {}
53+
};
54+
4055
// HTTP handler for testing C++ WebSocket client
4156
class TestClientHandler final : public server::handlers::HttpHandlerBase {
4257
public:
@@ -66,6 +81,10 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
6681
return TestNonblockingRead(request);
6782
} else if (test_name == "nonblocking_write") {
6883
return TestNonblockingWrite(request);
84+
} else if (test_name == "unauth") {
85+
return TestUnauth(request);
86+
} else if (test_name == "connection_already_extracted") {
87+
return TestConnectionAlreadyExtracted(request);
6988
}
7089
return "Unknown test";
7190
} catch (const std::exception& e) {
@@ -74,20 +93,15 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
7493
}
7594

7695
private:
77-
std::shared_ptr<websocket::WebSocketConnection> MakeWebSocketConnection(
78-
const server::http::HttpRequest& request,
79-
const std::string& uri = "/echo"
80-
) const {
96+
clients::http::WebSocketResponse PerformWebSocket(const server::http::HttpRequest& request, const std::string& uri)
97+
const {
8198
const auto port = std::stoi(request.GetArg("port"));
8299

83-
auto ws_response =
84-
client_.CreateRequest().url(fmt::format("ws://localhost:{}{}", port, uri)).PerformWebSocketHandshake();
85-
86-
return ws_response.MakeWebSocketConnection();
100+
return client_.CreateRequest().url(fmt::format("ws://localhost:{}{}", port, uri)).PerformWebSocketHandshake();
87101
}
88102

89103
std::string TestEcho(const server::http::HttpRequest& request) const {
90-
auto conn = MakeWebSocketConnection(request);
104+
auto conn = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
91105

92106
conn->SendText("Hello WebSocket");
93107
websocket::Message msg;
@@ -97,7 +111,7 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
97111
}
98112

99113
std::string TestLarge(const server::http::HttpRequest& request) const {
100-
auto conn = MakeWebSocketConnection(request);
114+
auto conn = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
101115

102116
std::string large_msg(50000, 'X');
103117
conn->SendText(large_msg);
@@ -108,7 +122,7 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
108122
}
109123

110124
std::string TestMultiple(const server::http::HttpRequest& request) const {
111-
auto conn = MakeWebSocketConnection(request);
125+
auto conn = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
112126

113127
for (int i = 0; i < 50; ++i) {
114128
const auto text = fmt::format("msg{}", i);
@@ -124,7 +138,7 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
124138
}
125139

126140
std::string TestBinary(const server::http::HttpRequest& request) const {
127-
auto conn = MakeWebSocketConnection(request);
141+
auto conn = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
128142

129143
std::vector<std::byte> data{std::byte{0x01}, std::byte{0xFF}};
130144
conn->SendBinary(data);
@@ -143,8 +157,8 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
143157
return resp.data == msg;
144158
};
145159

146-
auto conn1 = MakeWebSocketConnection(request);
147-
auto conn2 = MakeWebSocketConnection(request);
160+
auto conn1 = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
161+
auto conn2 = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
148162

149163
auto task1 = utils::Async("task1", func, std::ref(*conn1), std::string("msg1"));
150164
auto task2 = utils::Async("task2", func, std::ref(*conn2), std::string("msg2"));
@@ -153,8 +167,8 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
153167
}
154168

155169
std::string TestNonblockingRead(const server::http::HttpRequest& request) const {
156-
auto conn0 = MakeWebSocketConnection(request);
157-
auto conn1 = MakeWebSocketConnection(request);
170+
auto conn0 = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
171+
auto conn1 = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
158172

159173
conn0->SendText("msg0");
160174
conn1->SendText("msg1");
@@ -183,8 +197,8 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
183197
}
184198

185199
std::string TestNonblockingWrite(const server::http::HttpRequest& request) const {
186-
auto conn0 = MakeWebSocketConnection(request);
187-
auto conn1 = MakeWebSocketConnection(request);
200+
auto conn0 = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
201+
auto conn1 = PerformWebSocket(request, "/echo").MakeWebSocketConnection();
188202

189203
std::vector<std::string> messages{};
190204
for (int i = 0; i < 10; ++i) {
@@ -201,6 +215,39 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
201215
return "OK";
202216
}
203217

218+
std::string TestUnauth(const server::http::HttpRequest& request) const {
219+
auto ws_response = PerformWebSocket(request, "/unauth");
220+
221+
if (ws_response.IsProtocolUpgraded()) {
222+
return "FAIL: Protocol upgraded";
223+
}
224+
225+
const auto status_code = ws_response.GetHandshakeResponse()->status_code();
226+
if (status_code != http::StatusCode::kUnauthorized) {
227+
return "FAIL: Status code is " + ToString(status_code);
228+
}
229+
230+
try {
231+
ws_response.MakeWebSocketConnection();
232+
return "FAIL: connection should not be created";
233+
} catch (const std::exception&) { // NOLINT(bugprone-empty-catch)
234+
}
235+
236+
return "OK";
237+
}
238+
239+
std::string TestConnectionAlreadyExtracted(const server::http::HttpRequest& request) const {
240+
auto ws_response = PerformWebSocket(request, "/echo");
241+
ws_response.MakeWebSocketConnection();
242+
243+
try {
244+
ws_response.MakeWebSocketConnection();
245+
return "FAIL: connection should not be created twice";
246+
} catch (const std::exception&) { // NOLINT(bugprone-empty-catch)
247+
}
248+
return "OK";
249+
}
250+
204251
clients::http::Client& client_;
205252
};
206253

@@ -210,6 +257,7 @@ int main(int argc, char* argv[]) {
210257
.AppendComponentList(clients::http::ComponentList())
211258
.Append<clients::dns::Component>()
212259
.Append<WebSocketEcho>()
260+
.Append<WebSocketUnauth>()
213261
.Append<TestClientHandler>()
214262
.Append<components::TestsuiteSupport>();
215263
return utils::DaemonMain(argc, argv, component_list);

core/functional_tests/websocket_client/static_config.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ components_manager:
2929
max-remote-payload: 100000
3030
fragment-size: 65536
3131

32+
websocket-unauth-handler:
33+
path: /unauth
34+
method: GET
35+
task_processor: main-task-processor
36+
max-remote-payload: 100000
37+
fragment-size: 65536
38+
3239
# HTTP handler for C++ client tests
3340
test-client:
3441
path: /test-client

core/functional_tests/websocket_client/tests/test_websocket_client.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,18 @@
22

33

44
@pytest.mark.parametrize(
5-
'test_name', ['echo', 'large', 'multiple', 'binary', 'concurrent', 'nonblocking_read', 'nonblocking_write']
5+
'test_name',
6+
[
7+
'echo',
8+
'large',
9+
'multiple',
10+
'binary',
11+
'concurrent',
12+
'nonblocking_read',
13+
'nonblocking_write',
14+
'unauth',
15+
'connection_already_extracted',
16+
],
617
)
718
async def test_client(service_client, service_port, test_name):
819
response = await service_client.get(

core/src/clients/http/request_state.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -513,6 +513,14 @@ void RequestState::OnCompleted(std::shared_ptr<RequestState> holder, std::error_
513513

514514
holder->CheckResponseDeadline(err, status_code);
515515

516+
if (err && std::get_if<WebSocketHandshakeData>(&holder->data_) &&
517+
err == std::error_code(curl::errc::EasyErrorCode::kHttpReturnedError) && status_code != Status::kInvalid)
518+
{
519+
// curl expects 101 for WebSocket, treats other statuses as error.
520+
// Ignore error if got complete HTTP response (e.g. 401, 403).
521+
err = {};
522+
}
523+
516524
if (holder->testsuite_config_ && !err) {
517525
const auto& headers = holder->response()->headers();
518526
err = TestsuiteResponseHook(status_code, headers, span);

core/src/clients/http/websocket_response.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,12 @@ bool WebSocketResponse::IsProtocolUpgraded() const {
1919
}
2020

2121
std::shared_ptr<websocket::WebSocketConnection> WebSocketResponse::MakeWebSocketConnection() {
22+
if (!IsProtocolUpgraded()) {
23+
throw std::runtime_error("Protocol is not upgraded to WebSocket");
24+
}
25+
2226
if (!socket_.IsOpen()) {
23-
return nullptr;
27+
throw std::runtime_error("WebSocketConnection has already been extracted");
2428
}
2529

2630
auto socket = std::make_unique<engine::io::Socket>(socket_.GetNative());

0 commit comments

Comments
 (0)