Skip to content

Commit f21a415

Browse files
committed
Core/Authserver: Minor span/string_view modernization
1 parent 13f1873 commit f21a415

File tree

2 files changed

+35
-55
lines changed

2 files changed

+35
-55
lines changed

src/server/authserver/Server/AuthSession.cpp

Lines changed: 29 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -56,18 +56,18 @@ typedef struct AUTH_LOGON_CHALLENGE_C
5656
uint8 cmd;
5757
uint8 error;
5858
uint16 size;
59-
uint8 gamename[4];
59+
uint32 gamename;
6060
uint8 version1;
6161
uint8 version2;
6262
uint8 version3;
6363
uint16 build;
64-
uint8 platform[4];
65-
uint8 os[4];
66-
uint8 country[4];
64+
uint32 platform;
65+
uint32 os;
66+
uint32 country;
6767
uint32 timezone_bias;
6868
uint32 ip;
6969
uint8 I_len;
70-
uint8 I[1];
70+
char I[1];
7171
} sAuthLogonChallenge_C;
7272
static_assert(sizeof(sAuthLogonChallenge_C) == (1 + 1 + 2 + 4 + 1 + 1 + 1 + 2 + 4 + 4 + 4 + 4 + 4 + 1 + 1));
7373

@@ -180,10 +180,10 @@ void AccountInfo::LoadResult(Field* fields)
180180
//FROM account a LEFT JOIN account_access aa ON a.id = aa.AccountID LEFT JOIN account_banned ab ON ab.id = a.id AND ab.active = 1 WHERE a.username = ?
181181

182182
Id = fields[0].GetUInt32();
183-
Login = fields[1].GetString();
183+
Login = fields[1].GetStringView();
184184
IsLockedToIP = fields[2].GetBool();
185-
LockCountry = fields[3].GetString();
186-
LastIP = fields[4].GetString();
185+
LockCountry = fields[3].GetStringView();
186+
LastIP = fields[4].GetStringView();
187187
FailedLogins = fields[5].GetUInt32();
188188
IsBanned = fields[6].GetUInt64() != 0;
189189
IsPermanenetlyBanned = fields[7].GetUInt64() != 0;
@@ -196,7 +196,9 @@ void AccountInfo::LoadResult(Field* fields)
196196
}
197197

198198
AuthSession::AuthSession(tcp::socket&& socket) : Socket(std::move(socket)),
199-
_status(STATUS_CHALLENGE), _build(0), _timezoneOffset(0min), _expversion(0) { }
199+
_status(STATUS_CHALLENGE), _locale(LOCALE_enUS), _os(0), _build(0), _expversion(0), _timezoneOffset(0min)
200+
{
201+
}
200202

201203
void AuthSession::Start()
202204
{
@@ -311,28 +313,19 @@ bool AuthSession::HandleLogonChallenge()
311313
if (challenge->size - (sizeof(sAuthLogonChallenge_C) - AUTH_LOGON_CHALLENGE_INITIAL_SIZE - 1) != challenge->I_len)
312314
return false;
313315

314-
std::string login((char const*)challenge->I, challenge->I_len);
316+
std::string_view login(challenge->I, challenge->I_len);
315317
TC_LOG_DEBUG("server.authserver", "[AuthChallenge] '{}'", login);
316318

317319
_build = challenge->build;
318320
_expversion = uint8(AuthHelper::IsPostBCAcceptedClientBuild(_build) ? POST_BC_EXP_FLAG : (AuthHelper::IsPreBCAcceptedClientBuild(_build) ? PRE_BC_EXP_FLAG : NO_VALID_EXP_FLAG));
319-
std::array<char, 5> os;
320-
os.fill('\0');
321-
memcpy(os.data(), challenge->os, sizeof(challenge->os));
322-
_os = os.data();
323-
324-
// Restore string order as its byte order is reversed
325-
std::reverse(_os.begin(), _os.end());
326-
327-
_localizationName.resize(4);
328-
for (int i = 0; i < 4; ++i)
329-
_localizationName[i] = challenge->country[4 - i - 1];
321+
_os = challenge->os;
322+
_locale = GetLocaleByName(ClientBuild::ToCharArray(challenge->country).data());
330323

331324
_timezoneOffset = Minutes(challenge->timezone_bias);
332325

333326
// Get the account details from the account table
334327
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_LOGONCHALLENGE);
335-
stmt->setString(0, login);
328+
stmt->setStringView(0, login);
336329

337330
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
338331
.WithPreparedCallback([this](PreparedQueryResult result) { LogonChallengeCallback(std::move(result)); }));
@@ -467,7 +460,7 @@ void AuthSession::LogonChallengeCallback(PreparedQueryResult result)
467460
pkt << uint8(1);
468461

469462
TC_LOG_DEBUG("server.authserver", "'{}:{}' [AuthChallenge] account {} is using '{}' locale ({})",
470-
ipAddress, port, _accountInfo.Login, _localizationName, GetLocaleByName(_localizationName));
463+
ipAddress, port, _accountInfo.Login, localeNames[_locale], uint32(_locale));
471464

472465
_status = STATUS_LOGON_PROOF;
473466
}
@@ -524,7 +517,7 @@ bool AuthSession::HandleLogonProof()
524517
return true;
525518
}
526519

527-
if (!VerifyVersion(logonProof->A.data(), logonProof->A.size(), logonProof->crc_hash, false))
520+
if (!VerifyVersion(logonProof->A, logonProof->crc_hash, false))
528521
{
529522
ByteBuffer packet;
530523
packet << uint8(AUTH_LOGON_PROOF);
@@ -542,8 +535,8 @@ bool AuthSession::HandleLogonProof()
542535
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_UPD_LOGONPROOF);
543536
stmt->setBinary(0, _sessionKey);
544537
stmt->setString(1, address);
545-
stmt->setUInt32(2, GetLocaleByName(_localizationName));
546-
stmt->setString(3, _os);
538+
stmt->setUInt32(2, _locale);
539+
stmt->setStringView(3, ClientBuild::ToCharArray(_os).data());
547540
stmt->setInt16(4, _timezoneOffset.count());
548541
stmt->setString(5, _accountInfo.Login);
549542
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
@@ -651,28 +644,19 @@ bool AuthSession::HandleReconnectChallenge()
651644
if (challenge->size - (sizeof(sAuthLogonChallenge_C) - AUTH_LOGON_CHALLENGE_INITIAL_SIZE - 1) != challenge->I_len)
652645
return false;
653646

654-
std::string login((char const*)challenge->I, challenge->I_len);
647+
std::string_view login(challenge->I, challenge->I_len);
655648
TC_LOG_DEBUG("server.authserver", "[ReconnectChallenge] '{}'", login);
656649

657650
_build = challenge->build;
658651
_expversion = uint8(AuthHelper::IsPostBCAcceptedClientBuild(_build) ? POST_BC_EXP_FLAG : (AuthHelper::IsPreBCAcceptedClientBuild(_build) ? PRE_BC_EXP_FLAG : NO_VALID_EXP_FLAG));
659-
std::array<char, 5> os;
660-
os.fill('\0');
661-
memcpy(os.data(), challenge->os, sizeof(challenge->os));
662-
_os = os.data();
663-
664-
// Restore string order as its byte order is reversed
665-
std::reverse(_os.begin(), _os.end());
666-
667-
_localizationName.resize(4);
668-
for (int i = 0; i < 4; ++i)
669-
_localizationName[i] = challenge->country[4 - i - 1];
652+
_os = challenge->os;
653+
_locale = GetLocaleByName(ClientBuild::ToCharArray(challenge->country).data());
670654

671655
_timezoneOffset = Minutes(challenge->timezone_bias);
672656

673657
// Get the account details from the account table
674658
LoginDatabasePreparedStatement* stmt = LoginDatabase.GetPreparedStatement(LOGIN_SEL_RECONNECTCHALLENGE);
675-
stmt->setString(0, login);
659+
stmt->setStringView(0, login);
676660

677661
_queryProcessor.AddCallback(LoginDatabase.AsyncQuery(stmt)
678662
.WithPreparedCallback([this](PreparedQueryResult result) { ReconnectChallengeCallback(std::move(result)); }));
@@ -724,7 +708,7 @@ bool AuthSession::HandleReconnectProof()
724708

725709
if (sha.GetDigest() == reconnectProof->R2)
726710
{
727-
if (!VerifyVersion(reconnectProof->R1, sizeof(reconnectProof->R1), reconnectProof->R3, true))
711+
if (!VerifyVersion(reconnectProof->R1, reconnectProof->R3, true))
728712
{
729713
ByteBuffer packet;
730714
packet << uint8(AUTH_RECONNECT_PROOF);
@@ -800,11 +784,7 @@ void AuthSession::RealmListCallback(PreparedQueryResult result)
800784

801785
std::string name = realm.Name;
802786
if (_expversion & PRE_BC_EXP_FLAG && flag & REALM_FLAG_SPECIFYBUILD)
803-
{
804-
std::ostringstream ss;
805-
ss << name << " (" << buildInfo->MajorVersion << '.' << buildInfo->MinorVersion << '.' << buildInfo->BugfixVersion << ')';
806-
name = ss.str();
807-
}
787+
Trinity::StringFormatTo(std::back_inserter(name), " ({}.{}.{})", buildInfo->MajorVersion, buildInfo->MinorVersion, buildInfo->BugfixVersion);
808788

809789
uint8 lock = (realm.AllowedSecurityLevel > _accountInfo.SecurityLevel) ? 1 : 0;
810790

@@ -886,7 +866,7 @@ bool AuthSession::HandleXferCancel()
886866
return false;
887867
}
888868

889-
bool AuthSession::VerifyVersion(uint8 const* a, int32 aLength, Trinity::Crypto::SHA1::Digest const& versionProof, bool isReconnect)
869+
bool AuthSession::VerifyVersion(std::span<uint8 const> a, Trinity::Crypto::SHA1::Digest const& versionProof, bool isReconnect)
890870
{
891871
if (!sConfigMgr->GetBoolDefault("StrictVersionCheck", false))
892872
return true;
@@ -899,7 +879,7 @@ bool AuthSession::VerifyVersion(uint8 const* a, int32 aLength, Trinity::Crypto::
899879
if (!buildInfo)
900880
return false;
901881

902-
auto platformItr = std::ranges::find(buildInfo->ExecutableHashes, ClientBuild::ToFourCC(_os), &ClientBuild::ExecutableHash::Platform);
882+
auto platformItr = std::ranges::find(buildInfo->ExecutableHashes, _os, &ClientBuild::ExecutableHash::Platform);
903883
if (platformItr == buildInfo->ExecutableHashes.end())
904884
return true; // not filled serverside
905885

@@ -909,9 +889,9 @@ bool AuthSession::VerifyVersion(uint8 const* a, int32 aLength, Trinity::Crypto::
909889
versionHash = &zeros;
910890

911891
Trinity::Crypto::SHA1 version;
912-
version.UpdateData(a, aLength);
892+
version.UpdateData(a);
913893
version.UpdateData(*versionHash);
914894
version.Finalize();
915895

916-
return (versionProof == version.GetDigest());
896+
return versionProof == version.GetDigest();
917897
}

src/server/authserver/Server/AuthSession.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727
#include "Socket.h"
2828
#include "SRP6.h"
2929
#include <boost/asio/ip/tcp.hpp>
30+
#include <span>
3031

3132
using boost::asio::ip::tcp;
3233

3334
class AuthHandlerTable;
3435
class ByteBuffer;
35-
enum eAuthCmd : uint8;
3636

3737
enum AuthStatus
3838
{
@@ -91,7 +91,7 @@ class AuthSession : public Socket<AuthSession>
9191
void ReconnectChallengeCallback(PreparedQueryResult result);
9292
void RealmListCallback(PreparedQueryResult result);
9393

94-
bool VerifyVersion(uint8 const* a, int32 aLength, Trinity::Crypto::SHA1::Digest const& versionProof, bool isReconnect);
94+
bool VerifyVersion(std::span<uint8 const> a, Trinity::Crypto::SHA1::Digest const& versionProof, bool isReconnect);
9595

9696
Optional<Trinity::Crypto::SRP6> _srp6;
9797
SessionKey _sessionKey = {};
@@ -100,12 +100,12 @@ class AuthSession : public Socket<AuthSession>
100100
AuthStatus _status;
101101
AccountInfo _accountInfo;
102102
Optional<std::vector<uint8>> _totpSecret;
103-
std::string _localizationName;
104-
std::string _os;
105-
std::string _ipCountry;
103+
LocaleConstant _locale;
104+
uint32 _os;
105+
std::string_view _ipCountry;
106106
uint16 _build;
107-
Minutes _timezoneOffset;
108107
uint8 _expversion;
108+
Minutes _timezoneOffset;
109109

110110
QueryCallbackProcessor _queryProcessor;
111111
};

0 commit comments

Comments
 (0)