@@ -37,6 +37,21 @@ class WebSocketEcho final : public server::handlers::WebsocketHandlerBase {
3737 }
3838};
3939
40+ class WebSocketUnauth final : public server::handlers::WebsocketHandlerBase {
41+ public:
42+ static constexpr std::string_view kName = " websocket-unauth-handler" ;
43+
44+ using WebsocketHandlerBase::WebsocketHandlerBase;
45+
46+ bool
47+ HandleHandshake (const server::http::HttpRequest&, server::http::HttpResponse& response, server::request::RequestContext&) const override {
48+ response.SetStatus (server::http::HttpStatus::kUnauthorized );
49+ return false ;
50+ }
51+
52+ void Handle (websocket::WebSocketConnection&, server::request::RequestContext&) const override {}
53+ };
54+
4055// HTTP handler for testing C++ WebSocket client
4156class TestClientHandler final : public server::handlers::HttpHandlerBase {
4257public:
@@ -66,6 +81,10 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
6681 return TestNonblockingRead (request);
6782 } else if (test_name == " nonblocking_write" ) {
6883 return TestNonblockingWrite (request);
84+ } else if (test_name == " unauth" ) {
85+ return TestUnauth (request);
86+ } else if (test_name == " connection_already_extracted" ) {
87+ return TestConnectionAlreadyExtracted (request);
6988 }
7089 return " Unknown test" ;
7190 } catch (const std::exception& e) {
@@ -74,20 +93,15 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
7493 }
7594
7695private:
77- std::shared_ptr<websocket::WebSocketConnection> MakeWebSocketConnection (
78- const server::http::HttpRequest& request,
79- const std::string& uri = " /echo"
80- ) const {
96+ clients::http::WebSocketResponse PerformWebSocket (const server::http::HttpRequest& request, const std::string& uri)
97+ const {
8198 const auto port = std::stoi (request.GetArg (" port" ));
8299
83- auto ws_response =
84- client_.CreateRequest ().url (fmt::format (" ws://localhost:{}{}" , port, uri)).PerformWebSocketHandshake ();
85-
86- return ws_response.MakeWebSocketConnection ();
100+ return client_.CreateRequest ().url (fmt::format (" ws://localhost:{}{}" , port, uri)).PerformWebSocketHandshake ();
87101 }
88102
89103 std::string TestEcho (const server::http::HttpRequest& request) const {
90- auto conn = MakeWebSocketConnection (request);
104+ auto conn = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
91105
92106 conn->SendText (" Hello WebSocket" );
93107 websocket::Message msg;
@@ -97,7 +111,7 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
97111 }
98112
99113 std::string TestLarge (const server::http::HttpRequest& request) const {
100- auto conn = MakeWebSocketConnection (request);
114+ auto conn = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
101115
102116 std::string large_msg (50000 , ' X' );
103117 conn->SendText (large_msg);
@@ -108,7 +122,7 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
108122 }
109123
110124 std::string TestMultiple (const server::http::HttpRequest& request) const {
111- auto conn = MakeWebSocketConnection (request);
125+ auto conn = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
112126
113127 for (int i = 0 ; i < 50 ; ++i) {
114128 const auto text = fmt::format (" msg{}" , i);
@@ -124,7 +138,7 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
124138 }
125139
126140 std::string TestBinary (const server::http::HttpRequest& request) const {
127- auto conn = MakeWebSocketConnection (request);
141+ auto conn = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
128142
129143 std::vector<std::byte> data{std::byte{0x01 }, std::byte{0xFF }};
130144 conn->SendBinary (data);
@@ -143,8 +157,8 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
143157 return resp.data == msg;
144158 };
145159
146- auto conn1 = MakeWebSocketConnection (request);
147- auto conn2 = MakeWebSocketConnection (request);
160+ auto conn1 = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
161+ auto conn2 = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
148162
149163 auto task1 = utils::Async (" task1" , func, std::ref (*conn1), std::string (" msg1" ));
150164 auto task2 = utils::Async (" task2" , func, std::ref (*conn2), std::string (" msg2" ));
@@ -153,8 +167,8 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
153167 }
154168
155169 std::string TestNonblockingRead (const server::http::HttpRequest& request) const {
156- auto conn0 = MakeWebSocketConnection (request);
157- auto conn1 = MakeWebSocketConnection (request);
170+ auto conn0 = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
171+ auto conn1 = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
158172
159173 conn0->SendText (" msg0" );
160174 conn1->SendText (" msg1" );
@@ -183,8 +197,8 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
183197 }
184198
185199 std::string TestNonblockingWrite (const server::http::HttpRequest& request) const {
186- auto conn0 = MakeWebSocketConnection (request);
187- auto conn1 = MakeWebSocketConnection (request);
200+ auto conn0 = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
201+ auto conn1 = PerformWebSocket (request, " /echo " ). MakeWebSocketConnection ( );
188202
189203 std::vector<std::string> messages{};
190204 for (int i = 0 ; i < 10 ; ++i) {
@@ -201,6 +215,39 @@ class TestClientHandler final : public server::handlers::HttpHandlerBase {
201215 return " OK" ;
202216 }
203217
218+ std::string TestUnauth (const server::http::HttpRequest& request) const {
219+ auto ws_response = PerformWebSocket (request, " /unauth" );
220+
221+ if (ws_response.IsProtocolUpgraded ()) {
222+ return " FAIL: Protocol upgraded" ;
223+ }
224+
225+ const auto status_code = ws_response.GetHandshakeResponse ()->status_code ();
226+ if (status_code != http::StatusCode::kUnauthorized ) {
227+ return " FAIL: Status code is " + ToString (status_code);
228+ }
229+
230+ try {
231+ ws_response.MakeWebSocketConnection ();
232+ return " FAIL: connection should not be created" ;
233+ } catch (const std::exception&) { // NOLINT(bugprone-empty-catch)
234+ }
235+
236+ return " OK" ;
237+ }
238+
239+ std::string TestConnectionAlreadyExtracted (const server::http::HttpRequest& request) const {
240+ auto ws_response = PerformWebSocket (request, " /echo" );
241+ ws_response.MakeWebSocketConnection ();
242+
243+ try {
244+ ws_response.MakeWebSocketConnection ();
245+ return " FAIL: connection should not be created twice" ;
246+ } catch (const std::exception&) { // NOLINT(bugprone-empty-catch)
247+ }
248+ return " OK" ;
249+ }
250+
204251 clients::http::Client& client_;
205252};
206253
@@ -210,6 +257,7 @@ int main(int argc, char* argv[]) {
210257 .AppendComponentList (clients::http::ComponentList ())
211258 .Append <clients::dns::Component>()
212259 .Append <WebSocketEcho>()
260+ .Append <WebSocketUnauth>()
213261 .Append <TestClientHandler>()
214262 .Append <components::TestsuiteSupport>();
215263 return utils::DaemonMain (argc, argv, component_list);
0 commit comments