33#include < string>
44#include < utility>
55#include < iostream>
6+ #include < algorithm>
7+
8+ #include " sleepy_discord/net_endian.h"
9+ #include " version_helper.h"
10+
611#include " asio_include.h"
712#if defined(SLEEPY_USE_BOOST_ASIO)
813#include < boost/asio/ssl.hpp>
1722#include < asio/read.hpp>
1823#include < asio/streambuf.hpp>
1924#endif
25+
2026#include < openssl/x509.h>
2127#include < openssl/evp.h>
2228#include < openssl/rand.h>
2329#include < openssl/sha.h>
24- #include " endian.h"
25- #include " version_helper.h"
2630
27- // to do: handle muliple frames in one read
2831// to do: use a log event and remove iostream include
2932
3033namespace SleepyDiscord {
@@ -118,11 +121,11 @@ namespace SleepyDiscord {
118121
119122 cancelSignal->emit (asio::cancellation_type::all);
120123 self->logError (" connect took too long, canceling" , err);
121- self->socketPtr ->async_shutdown ();
124+ self->socketPtr ->async_shutdown ([]( const asio::error_code& _) {} );
122125 });
123126
124127 auto cleanUp = [self, connectDeadline]() {
125- self->socketPtr ->async_shutdown ();
128+ self->socketPtr ->async_shutdown ([]( const asio::error_code& _) {} );
126129 connectDeadline->cancel ();
127130 };
128131
@@ -238,11 +241,12 @@ namespace SleepyDiscord {
238241 std::string headerFieldRight = response.substr (offset, newLinePos - offset);
239242 offset = newLinePos + 2 ;
240243 // insert to the map with the key being lowercase ASCII only
241- std::transform (headerFieldLeft.begin (), headerFieldLeft.end (), headerFieldLeft.begin (), std::tolower);
244+ std::transform (headerFieldLeft.begin (), headerFieldLeft.end (), headerFieldLeft.begin (), [](char c) {
245+ return std::tolower (c);
246+ });
242247 headers.insert (std::make_pair (headerFieldLeft, std::move (headerFieldRight)));
243248 }
244249 offset += 2 ; // + 2 to skip end of header
245- const auto bodyStart = offset;
246250
247251 // easyer to write it like this, but a map of functions might be faster
248252 {
@@ -253,7 +257,9 @@ namespace SleepyDiscord {
253257 }
254258 const std::string expectedValue = " websocket" ;
255259 std::string& value = iter->second ;
256- std::transform (value.begin (), value.end (), value.begin (), std::tolower);
260+ std::transform (value.begin (), value.end (), value.begin (), [](char c) {
261+ return std::tolower (c);
262+ });
257263 if (value != expectedValue) {
258264 self->logError (" HTTP protocol switch to websocket failed, incorrect Upgrade value from server" , err);
259265 return cleanUp ();
@@ -268,7 +274,9 @@ namespace SleepyDiscord {
268274 }
269275 const std::string expectedValue = " upgrade" ;
270276 std::string& value = iter->second ;
271- std::transform (value.begin (), value.end (), value.begin (), std::tolower);
277+ std::transform (value.begin (), value.end (), value.begin (), [](char c) {
278+ return std::tolower (c);
279+ });
272280 if (value != expectedValue) {
273281 self->logError (" HTTP protocol switch to websocket failed, incorrect Connection value from server" , err);
274282 return cleanUp ();
@@ -344,7 +352,7 @@ namespace SleepyDiscord {
344352 // so they are basically the same to us. But, the server might do the validation, so this
345353 // should be set correctly by the sender to prevent a invalid UTF-8 error.
346354 uint16_t opCode,
347- std::function<void (asio::error_code)> callback
355+ std::function<void (const asio::error_code& )> callback
348356 ) {
349357 if (!ready.load ()) { // can't send while not connected
350358 std::cerr << " Can't send while not connected\n " ;
@@ -374,11 +382,11 @@ namespace SleepyDiscord {
374382
375383 switch (shortPayloadLength) {
376384 case 126 : {
377- auto netInt16 = system2net16 (static_cast <uint16_t >(payloadLength));
385+ auto netInt16 = system2net16< uint8_t > (static_cast <uint16_t >(payloadLength));
378386 message->append (netInt16.data (), netInt16.size ());
379387 } break ;
380388 case 127 : {
381- auto netInt64 = system2net64 (static_cast <uint64_t >(payloadLength));
389+ auto netInt64 = system2net64< uint8_t > (static_cast <uint64_t >(payloadLength));
382390 message->append (netInt64.data (), netInt64.size ());
383391 } break ;
384392 default : break ;
@@ -411,10 +419,10 @@ namespace SleepyDiscord {
411419
412420 if (socketPtr == nullptr ) return ; // already disconnected
413421 auto self = shared_from_this ();
414- auto netCloseInt = system2net16 (code);
422+ auto netCloseInt = system2net16< uint8_t > (code);
415423 std::basic_string<uint8_t > closePayload{ netCloseInt.data (), netCloseInt.size () };
416424 closePayload.append (reinterpret_cast <const uint8_t *>(reason.data ()), reason.length ());
417- send (closePayload, 0x8 /* Close*/ , [self, closeSocketAfterSend](asio::error_code& err) {
425+ send (closePayload, 0x8 /* Close*/ , [self, closeSocketAfterSend](const asio::error_code& err) {
418426 if (err || closeSocketAfterSend) {
419427 self->shutdown ();
420428 }
@@ -469,7 +477,7 @@ namespace SleepyDiscord {
469477
470478 void receiveFullHandshake (std::shared_ptr<std::array<char , 256 >> responsePtr, const std::size_t length, int score = 0 ) {
471479 const static std::array<char , 4 > endOfHandshake = { ' \r ' , ' \n ' , ' \r ' , ' \n ' };
472- for (int i = 0 ; i < length; i += 1 ) {
480+ for (std:: size_t i = 0 ; i < length; i += 1 ) {
473481 if (responsePtr->data ()[i] == endOfHandshake[score]) {
474482 score += 1 ;
475483 if (score == 4 ) {
@@ -529,15 +537,15 @@ namespace SleepyDiscord {
529537 std::memcpy (&secondByte, &temp, 1 );
530538 }
531539
532- constexpr int8_t hasMaskByte = 1 ; constexpr int8_t hasMaskBit = static_cast <uint8_t >(0b1000'0000 );
540+ constexpr int8_t hasMaskBit = static_cast <uint8_t >(0b1000'0000 );
533541 bool hasMask = (secondByte & hasMaskBit) == hasMaskBit;
534542 const std::size_t maskLength = hasMask ? sizeof (uint32_t ) : 0 ;
535543 if (hasMask) { // Servers shouldn't sent masked payloads
536544 std::cerr << " Recevied masked websocket message, Servers shouldn't send masked payloads\n " ;
537545 return ;
538546 }
539547
540- constexpr int8_t hasExtendedLengthByte = 1 ; constexpr int8_t hasExtendedLength = static_cast <uint8_t >(0b0111'1111 );
548+ constexpr int8_t hasExtendedLength = static_cast <uint8_t >(0b0111'1111 );
541549 int8_t length7Bit = secondByte & hasExtendedLength;
542550 std::size_t lengthOfLength;
543551 switch (length7Bit) {
@@ -580,7 +588,7 @@ namespace SleepyDiscord {
580588 auto self = shared_from_this ();
581589 std::istream inputStream (readBuffer.get ());
582590 int extendedLengthPrefixLength;
583- uint64_t payloadLength;
591+ std:: size_t payloadLength;
584592 switch (frame.length7Bit ) {
585593 case 126 : {
586594 constexpr std::size_t size = sizeof (uint16_t );
@@ -600,7 +608,19 @@ namespace SleepyDiscord {
600608 inputStream.read (temp.data (), networkInt.size ());
601609 std::memcpy (networkInt.data (), temp.data (), networkInt.size ());
602610 }
603- payloadLength = net2System64 (networkInt);
611+ uint64_t temp = net2System64 (networkInt);
612+
613+ constexpr bool pointerIsTooSmall = (std::numeric_limits<std::size_t >::max)() < (std::numeric_limits<uint64_t >::max)();
614+ if ( // on a 32 bit or 16 bit system, we can't put the message in memory
615+ // the compiler should be smart enough to get rid of this on 64-bit systems
616+ pointerIsTooSmall && ((std::numeric_limits<std::size_t >::max)() < payloadLength)
617+ ) { // doubt that Discord would ever sent a 4 GB websocket payload, so just kill it
618+ std::cerr << " Can't read message larger then max pointer value\n " ;
619+ disconnect (1009 , " sizeof(PTR)<MSG" );
620+ return ;
621+ }
622+
623+ payloadLength = static_cast <std::size_t >(temp);
604624 } break ;
605625 default :
606626 extendedLengthPrefixLength = 0 ;
@@ -610,11 +630,11 @@ namespace SleepyDiscord {
610630
611631 const std::size_t maskStart = 2 + static_cast <size_t >(extendedLengthPrefixLength);
612632 const std::size_t payloadStart = maskStart + frame.maskLength ;
613- const std::size_t fullMessageLength = payloadStart + payloadLength ;
633+ const std::size_t leftOverLength = length - payloadStart ;
614634
615- if (length < fullMessageLength ) {
635+ if (leftOverLength < payloadLength ) {
616636 // we don't have the whole message, loop receive until we have the whole message
617- const std::size_t bytesLeftToGet = fullMessageLength - length ;
637+ const std::size_t bytesLeftToGet = payloadLength - leftOverLength ;
618638 asio::async_read (*(self->socketPtr ), *readBuffer, asio::transfer_at_least (bytesLeftToGet), [self, readBuffer, frame, payloadLength, bytesLeftToGet](const asio::error_code& err, const std::size_t length) {
619639 if (err) {
620640 std::cerr << " failed to read whole\n " ;
@@ -627,7 +647,7 @@ namespace SleepyDiscord {
627647 else {
628648 // we have the whole message, we don't need to read again
629649 self->onPayload (frame.opCode , readBuffer, payloadLength);
630- afterPayload (length - fullMessageLength , readBuffer);
650+ afterPayload (leftOverLength - payloadLength , readBuffer);
631651 }
632652 }
633653
@@ -671,6 +691,8 @@ namespace SleepyDiscord {
671691 case pongOp:
672692 // only needed if ping send is implemented
673693 break ;
694+ case continueOp:
695+ // shouldn't happen because we don't support websocket fragmentation
674696 default :
675697 std::cerr << " Unknown op code from server\n " ;
676698 return ;
0 commit comments