Skip to content

Commit 1d2c7ad

Browse files
committed
Core/Network: Socket refactors
* Devirtualize calls to Read and Update by marking concrete implementations as final * Removed derived class template argument * Specialize boost::asio::basic_stream_socket for boost::asio::io_context instead of type-erased any_io_executor * Make socket initialization easier composable (before entering Read loop) * Remove use of deprecated boost::asio::null_buffers and boost::beast::ssl_stream (cherry picked from commit e8b2be3)
1 parent f246eb9 commit 1d2c7ad

File tree

21 files changed

+639
-314
lines changed

21 files changed

+639
-314
lines changed

src/common/Utilities/Containers.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,7 @@ namespace Trinity
207207
if (!p(*rpos))
208208
{
209209
if (rpos != wpos)
210-
std::swap(*rpos, *wpos);
210+
std::ranges::swap(*rpos, *wpos);
211211
++wpos;
212212
}
213213
}

src/server/authserver/Server/AuthSession.cpp

Lines changed: 27 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "DatabaseEnv.h"
2828
#include "IPLocation.h"
2929
#include "IoContext.h"
30+
#include "IpBanCheckConnectionInitializer.h"
3031
#include "Log.h"
3132
#include "RealmList.h"
3233
#include "SecretMgr.h"
@@ -199,21 +200,23 @@ void AccountInfo::LoadResult(Field* fields)
199200
Utf8ToUpperOnlyLatin(Login);
200201
}
201202

202-
AuthSession::AuthSession(tcp::socket&& socket) : Socket(std::move(socket)),
203-
_timeout(*underlying_stream().get_executor().target<boost::asio::io_context::executor_type>()),
203+
AuthSession::AuthSession(Trinity::Net::IoContextTcpSocket&& socket) : Socket(std::move(socket)),
204+
_timeout(underlying_stream().get_executor()),
204205
_status(STATUS_CHALLENGE), _locale(LOCALE_enUS), _os(0), _build(0), _expversion(0), _timezoneOffset(0min)
205206
{
206207
}
207208

208209
void AuthSession::Start()
209210
{
210-
std::string ip_address = GetRemoteIpAddress().to_string();
211-
TC_LOG_TRACE("session", "Accepted connection from {}", ip_address);
212-
213-
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_IP_INFO);
214-
stmt->setString(0, ip_address);
215-
216-
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::CheckIpCallback, this, std::placeholders::_1)));
211+
// build initializer chain
212+
std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3> initializers =
213+
{ {
214+
std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<AuthSession>>(this),
215+
std::make_shared<Trinity::Net::ReadConnectionInitializer<AuthSession>>(this),
216+
} };
217+
218+
Trinity::Net::SocketConnectionInitializer::SetupChain(initializers)->Start();
219+
SetTimeout();
217220
}
218221

219222
bool AuthSession::Update()
@@ -226,36 +229,7 @@ bool AuthSession::Update()
226229
return true;
227230
}
228231

229-
void AuthSession::CheckIpCallback(PreparedQueryResult result)
230-
{
231-
if (result)
232-
{
233-
bool banned = false;
234-
do
235-
{
236-
Field* fields = result->Fetch();
237-
if (fields[0].GetUInt64() != 0)
238-
banned = true;
239-
240-
} while (result->NextRow());
241-
242-
if (banned)
243-
{
244-
ByteBuffer pkt;
245-
pkt << uint8(AUTH_LOGON_CHALLENGE);
246-
pkt << uint8(0x00);
247-
pkt << uint8(WOW_FAIL_BANNED);
248-
SendPacket(pkt);
249-
TC_LOG_DEBUG("session", "[AuthSession::CheckIpCallback] Banned ip '{}:{}' tries to login!", GetRemoteIpAddress().to_string(), GetRemotePort());
250-
return;
251-
}
252-
}
253-
254-
AsyncRead();
255-
SetTimeout();
256-
}
257-
258-
void AuthSession::ReadHandler()
232+
Trinity::Net::SocketReadCallbackResult AuthSession::ReadHandler()
259233
{
260234
MessageBuffer& packet = GetReadBuffer();
261235
while (packet.GetActiveSize())
@@ -265,7 +239,7 @@ void AuthSession::ReadHandler()
265239
if (!itr || _status != itr->status)
266240
{
267241
CloseSocket();
268-
return;
242+
return Trinity::Net::SocketReadCallbackResult::Stop;
269243
}
270244

271245
std::size_t size = itr->packetSize;
@@ -279,7 +253,7 @@ void AuthSession::ReadHandler()
279253
if (size > MAX_ACCEPTED_CHALLENGE_SIZE)
280254
{
281255
CloseSocket();
282-
return;
256+
return Trinity::Net::SocketReadCallbackResult::Stop;
283257
}
284258
}
285259

@@ -289,14 +263,19 @@ void AuthSession::ReadHandler()
289263
if (!itr->handler(this))
290264
{
291265
CloseSocket();
292-
return;
266+
return Trinity::Net::SocketReadCallbackResult::Stop;
293267
}
294268

295269
packet.ReadCompleted(size);
296270
SetTimeout();
297271
}
298272

299-
AsyncRead();
273+
return Trinity::Net::SocketReadCallbackResult::KeepReading;
274+
}
275+
276+
void AuthSession::QueueQuery(QueryCallback&& queryCallback)
277+
{
278+
_queryProcessor.AddCallback(std::move(queryCallback));
300279
}
301280

302281
void AuthSession::SendPacket(ByteBuffer& packet)
@@ -334,7 +313,7 @@ bool AuthSession::HandleLogonChallenge()
334313
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_LOGONCHALLENGE);
335314
stmt->setStringView(0, login);
336315

337-
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
316+
QueueQuery(LoginDatabase.AsyncQuery(stmt)
338317
.WithPreparedCallback([this](PreparedQueryResult result) { LogonChallengeCallback(std::move(result)); }));
339318
return true;
340319
}
@@ -546,7 +525,7 @@ bool AuthSession::HandleLogonProof()
546525
stmt->setStringView(3, ClientBuild::ToCharArray(_os).data());
547526
stmt->setInt16(4, _timezoneOffset.count());
548527
stmt->setString(5, _accountInfo.Login);
549-
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
528+
QueueQuery(LoginDatabase.AsyncQuery(stmt)
550529
.WithPreparedCallback([this, M2 = Trinity::Crypto::SRP6::GetSessionVerifier(logonProof->A, logonProof->clientM, _sessionKey)](PreparedQueryResult const&)
551530
{
552531
// Finish SRP6 and send the final result to the client
@@ -665,7 +644,7 @@ bool AuthSession::HandleReconnectChallenge()
665644
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_RECONNECTCHALLENGE);
666645
stmt->setStringView(0, login);
667646

668-
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
647+
QueueQuery(LoginDatabase.AsyncQuery(stmt)
669648
.WithPreparedCallback([this](PreparedQueryResult result) { ReconnectChallengeCallback(std::move(result)); }));
670649
return true;
671650
}
@@ -748,7 +727,7 @@ bool AuthSession::HandleRealmList()
748727
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_REALM_CHARACTER_COUNTS);
749728
stmt->setUInt32(0, _accountInfo.Id);
750729

751-
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::RealmListCallback, this, std::placeholders::_1)));
730+
QueueQuery(LoginDatabase.AsyncQuery(stmt).WithPreparedCallback(std::bind(&AuthSession::RealmListCallback, this, std::placeholders::_1)));
752731
_status = STATUS_WAITING_FOR_REALM_LIST;
753732
return true;
754733
}
@@ -922,7 +901,7 @@ void AuthSession::SetTimeout()
922901

923902
_timeout.async_wait([selfRef = weak_from_this()](boost::system::error_code const& error)
924903
{
925-
std::shared_ptr<AuthSession> self = selfRef.lock();
904+
std::shared_ptr<AuthSession> self = static_pointer_cast<AuthSession>(selfRef.lock());
926905
if (!self)
927906
return;
928907

src/server/authserver/Server/AuthSession.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,20 +61,21 @@ struct AccountInfo
6161
AccountTypes SecurityLevel = SEC_PLAYER;
6262
};
6363

64-
class AuthSession : public Socket<AuthSession>
64+
class AuthSession final : public Trinity::Net::Socket<>
6565
{
66-
typedef Socket<AuthSession> AuthSocket;
66+
using AuthSocket = Socket;
6767

6868
public:
69-
AuthSession(tcp::socket&& socket);
69+
AuthSession(Trinity::Net::IoContextTcpSocket&& socket);
7070

7171
void Start() override;
7272
bool Update() override;
7373

7474
void SendPacket(ByteBuffer& packet);
7575

76-
protected:
77-
void ReadHandler() override;
76+
Trinity::Net::SocketReadCallbackResult ReadHandler() override;
77+
78+
void QueueQuery(QueryCallback&& queryCallback);
7879

7980
private:
8081
friend AuthHandlerTable;
@@ -87,7 +88,6 @@ class AuthSession : public Socket<AuthSession>
8788
bool HandleXferResume();
8889
bool HandleXferCancel();
8990

90-
void CheckIpCallback(PreparedQueryResult result);
9191
void LogonChallengeCallback(PreparedQueryResult result);
9292
void ReconnectChallengeCallback(PreparedQueryResult result);
9393
void RealmListCallback(PreparedQueryResult result);

src/server/authserver/Server/AuthSocketMgr.h

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#include "SocketMgr.h"
2222
#include "AuthSession.h"
2323

24-
class AuthSocketMgr : public SocketMgr<AuthSession>
24+
class AuthSocketMgr : public Trinity::Net::SocketMgr<AuthSession>
2525
{
2626
typedef SocketMgr<AuthSession> BaseSocketMgr;
2727

@@ -37,19 +37,17 @@ class AuthSocketMgr : public SocketMgr<AuthSession>
3737
if (!BaseSocketMgr::StartNetwork(ioContext, bindIp, port, threadCount))
3838
return false;
3939

40-
_acceptor->AsyncAcceptWithCallback<&AuthSocketMgr::OnSocketAccept>();
40+
_acceptor->AsyncAccept([this](Trinity::Net::IoContextTcpSocket&& sock, uint32 threadIndex)
41+
{
42+
OnSocketOpen(std::move(sock), threadIndex);
43+
});
4144
return true;
4245
}
4346

4447
protected:
45-
NetworkThread<AuthSession>* CreateThreads() const override
48+
Trinity::Net::NetworkThread<AuthSession>* CreateThreads() const override
4649
{
47-
return new NetworkThread<AuthSession>[1];
48-
}
49-
50-
static void OnSocketAccept(tcp::socket&& sock, uint32 threadIndex)
51-
{
52-
Instance().OnSocketOpen(std::forward<tcp::socket>(sock), threadIndex);
50+
return new Trinity::Net::NetworkThread<AuthSession>[1];
5351
}
5452
};
5553

src/server/game/Scripting/ScriptMgr.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,14 +1274,14 @@ void ScriptMgr::OnNetworkStop()
12741274
FOREACH_SCRIPT(ServerScript)->OnNetworkStop();
12751275
}
12761276

1277-
void ScriptMgr::OnSocketOpen(std::shared_ptr<WorldSocket> socket)
1277+
void ScriptMgr::OnSocketOpen(std::shared_ptr<WorldSocket> const& socket)
12781278
{
12791279
ASSERT(socket);
12801280

12811281
FOREACH_SCRIPT(ServerScript)->OnSocketOpen(socket);
12821282
}
12831283

1284-
void ScriptMgr::OnSocketClose(std::shared_ptr<WorldSocket> socket)
1284+
void ScriptMgr::OnSocketClose(std::shared_ptr<WorldSocket> const& socket)
12851285
{
12861286
ASSERT(socket);
12871287

src/server/game/Scripting/ScriptMgr.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -884,8 +884,8 @@ class TC_GAME_API ScriptMgr
884884

885885
void OnNetworkStart();
886886
void OnNetworkStop();
887-
void OnSocketOpen(std::shared_ptr<WorldSocket> socket);
888-
void OnSocketClose(std::shared_ptr<WorldSocket> socket);
887+
void OnSocketOpen(std::shared_ptr<WorldSocket> const& socket);
888+
void OnSocketClose(std::shared_ptr<WorldSocket> const& socket);
889889
void OnPacketReceive(WorldSession* session, WorldPacket const& packet);
890890
void OnPacketSend(WorldSession* session, WorldPacket const& packet);
891891

0 commit comments

Comments
 (0)