2727#include " DatabaseEnv.h"
2828#include " IPLocation.h"
2929#include " IoContext.h"
30+ #include " IpBanCheckConnectionInitializer.h"
3031#include " Log.h"
3132#include " RealmList.h"
3233#include " SecretMgr.h"
@@ -199,21 +200,23 @@ void AccountInfo::LoadResult(Field* fields)
199200 Utf8ToUpperOnlyLatin (Login);
200201}
201202
202- AuthSession::AuthSession (tcp::socket && socket) : Socket(std::move(socket)),
203- _timeout(* underlying_stream ().get_executor().target<boost::asio::io_context::executor_type> ()),
203+ AuthSession::AuthSession (Trinity::Net::IoContextTcpSocket && socket) : Socket(std::move(socket)),
204+ _timeout(underlying_stream().get_executor()),
204205 _status(STATUS_CHALLENGE), _locale(LOCALE_enUS), _os(0 ), _build(0 ), _expversion(0 ), _timezoneOffset(0min)
205206{
206207}
207208
208209void AuthSession::Start ()
209210{
210- std::string ip_address = GetRemoteIpAddress ().to_string ();
211- TC_LOG_TRACE (" session" , " Accepted connection from {}" , ip_address);
212-
213- LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement (LOGIN_SEL_IP_INFO);
214- stmt->setString (0 , ip_address);
215-
216- _queryProcessor.AddCallback (LoginDatabase.AsyncQuery (stmt).WithPreparedCallback (std::bind (&AuthSession::CheckIpCallback, this , std::placeholders::_1)));
211+ // build initializer chain
212+ std::array<std::shared_ptr<Trinity::Net::SocketConnectionInitializer>, 3 > initializers =
213+ { {
214+ std::make_shared<Trinity::Net::IpBanCheckConnectionInitializer<AuthSession>>(this ),
215+ std::make_shared<Trinity::Net::ReadConnectionInitializer<AuthSession>>(this ),
216+ } };
217+
218+ Trinity::Net::SocketConnectionInitializer::SetupChain (initializers)->Start ();
219+ SetTimeout ();
217220}
218221
219222bool AuthSession::Update ()
@@ -226,36 +229,7 @@ bool AuthSession::Update()
226229 return true ;
227230}
228231
229- void AuthSession::CheckIpCallback (PreparedQueryResult result)
230- {
231- if (result)
232- {
233- bool banned = false ;
234- do
235- {
236- Field* fields = result->Fetch ();
237- if (fields[0 ].GetUInt64 () != 0 )
238- banned = true ;
239-
240- } while (result->NextRow ());
241-
242- if (banned)
243- {
244- ByteBuffer pkt;
245- pkt << uint8 (AUTH_LOGON_CHALLENGE);
246- pkt << uint8 (0x00 );
247- pkt << uint8 (WOW_FAIL_BANNED);
248- SendPacket (pkt);
249- TC_LOG_DEBUG (" session" , " [AuthSession::CheckIpCallback] Banned ip '{}:{}' tries to login!" , GetRemoteIpAddress ().to_string (), GetRemotePort ());
250- return ;
251- }
252- }
253-
254- AsyncRead ();
255- SetTimeout ();
256- }
257-
258- void AuthSession::ReadHandler ()
232+ Trinity::Net::SocketReadCallbackResult AuthSession::ReadHandler ()
259233{
260234 MessageBuffer& packet = GetReadBuffer ();
261235 while (packet.GetActiveSize ())
@@ -265,7 +239,7 @@ void AuthSession::ReadHandler()
265239 if (!itr || _status != itr->status )
266240 {
267241 CloseSocket ();
268- return ;
242+ return Trinity::Net::SocketReadCallbackResult::Stop ;
269243 }
270244
271245 std::size_t size = itr->packetSize ;
@@ -279,7 +253,7 @@ void AuthSession::ReadHandler()
279253 if (size > MAX_ACCEPTED_CHALLENGE_SIZE)
280254 {
281255 CloseSocket ();
282- return ;
256+ return Trinity::Net::SocketReadCallbackResult::Stop ;
283257 }
284258 }
285259
@@ -289,14 +263,19 @@ void AuthSession::ReadHandler()
289263 if (!itr->handler (this ))
290264 {
291265 CloseSocket ();
292- return ;
266+ return Trinity::Net::SocketReadCallbackResult::Stop ;
293267 }
294268
295269 packet.ReadCompleted (size);
296270 SetTimeout ();
297271 }
298272
299- AsyncRead ();
273+ return Trinity::Net::SocketReadCallbackResult::KeepReading;
274+ }
275+
276+ void AuthSession::QueueQuery (QueryCallback&& queryCallback)
277+ {
278+ _queryProcessor.AddCallback (std::move (queryCallback));
300279}
301280
302281void AuthSession::SendPacket (ByteBuffer& packet)
@@ -334,7 +313,7 @@ bool AuthSession::HandleLogonChallenge()
334313 LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement (LOGIN_SEL_LOGONCHALLENGE);
335314 stmt->setStringView (0 , login);
336315
337- _queryProcessor. AddCallback (LoginDatabase.AsyncQuery (stmt)
316+ QueueQuery (LoginDatabase.AsyncQuery (stmt)
338317 .WithPreparedCallback ([this ](PreparedQueryResult result) { LogonChallengeCallback (std::move (result)); }));
339318 return true ;
340319}
@@ -546,7 +525,7 @@ bool AuthSession::HandleLogonProof()
546525 stmt->setStringView (3 , ClientBuild::ToCharArray (_os).data ());
547526 stmt->setInt16 (4 , _timezoneOffset.count ());
548527 stmt->setString (5 , _accountInfo.Login );
549- _queryProcessor. AddCallback (LoginDatabase.AsyncQuery (stmt)
528+ QueueQuery (LoginDatabase.AsyncQuery (stmt)
550529 .WithPreparedCallback ([this , M2 = Trinity::Crypto::SRP6::GetSessionVerifier (logonProof->A , logonProof->clientM , _sessionKey)](PreparedQueryResult const &)
551530 {
552531 // Finish SRP6 and send the final result to the client
@@ -665,7 +644,7 @@ bool AuthSession::HandleReconnectChallenge()
665644 LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement (LOGIN_SEL_RECONNECTCHALLENGE);
666645 stmt->setStringView (0 , login);
667646
668- _queryProcessor. AddCallback (LoginDatabase.AsyncQuery (stmt)
647+ QueueQuery (LoginDatabase.AsyncQuery (stmt)
669648 .WithPreparedCallback ([this ](PreparedQueryResult result) { ReconnectChallengeCallback (std::move (result)); }));
670649 return true ;
671650}
@@ -748,7 +727,7 @@ bool AuthSession::HandleRealmList()
748727 LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement (LOGIN_SEL_REALM_CHARACTER_COUNTS);
749728 stmt->setUInt32 (0 , _accountInfo.Id );
750729
751- _queryProcessor. AddCallback (LoginDatabase.AsyncQuery (stmt).WithPreparedCallback (std::bind (&AuthSession::RealmListCallback, this , std::placeholders::_1)));
730+ QueueQuery (LoginDatabase.AsyncQuery (stmt).WithPreparedCallback (std::bind (&AuthSession::RealmListCallback, this , std::placeholders::_1)));
752731 _status = STATUS_WAITING_FOR_REALM_LIST;
753732 return true ;
754733}
@@ -922,7 +901,7 @@ void AuthSession::SetTimeout()
922901
923902 _timeout.async_wait ([selfRef = weak_from_this ()](boost::system::error_code const & error)
924903 {
925- std::shared_ptr<AuthSession> self = selfRef.lock ();
904+ std::shared_ptr<AuthSession> self = static_pointer_cast<AuthSession>( selfRef.lock () );
926905 if (!self)
927906 return ;
928907
0 commit comments