diff --git a/Source/Client/Networking/HostUtil.cs b/Source/Client/Networking/HostUtil.cs index c72165f2..59bfcfd4 100644 --- a/Source/Client/Networking/HostUtil.cs +++ b/Source/Client/Networking/HostUtil.cs @@ -74,8 +74,8 @@ private static void PrepareLocalServer(ServerSettings settings, bool fromReplay) localServer.startingTimer = TickPatch.Timer; } - localServer.initData = - ServerInitData.Deserialize(new ByteReader(ClientJoiningState.PackInitData(settings.syncConfigs))); + localServer.StartInitData().SetResult( + ServerInitData.Deserialize(new ByteReader(ClientJoiningState.PackInitData(settings.syncConfigs)))); } private static void PrepareGame() diff --git a/Source/Common/CommandHandler.cs b/Source/Common/CommandHandler.cs index f9fad36b..3987d2ad 100644 --- a/Source/Common/CommandHandler.cs +++ b/Source/Common/CommandHandler.cs @@ -20,11 +20,11 @@ public void Send(CommandType cmd, int factionId, int mapId, byte[] data, ServerP { bool debugCmd = cmd == CommandType.DebugTools || - cmd == CommandType.Sync && server.initData!.DebugOnlySyncCmds.Contains(BitConverter.ToInt32(data, 0)); + cmd == CommandType.Sync && server.InitData!.DebugOnlySyncCmds.Contains(BitConverter.ToInt32(data, 0)); if (debugCmd && !CanUseDevMode(sourcePlayer)) return; - bool hostOnly = cmd == CommandType.Sync && server.initData!.HostOnlySyncCmds.Contains(BitConverter.ToInt32(data, 0)); + bool hostOnly = cmd == CommandType.Sync && server.InitData!.HostOnlySyncCmds.Contains(BitConverter.ToInt32(data, 0)); if (hostOnly && !sourcePlayer.IsHost) return; diff --git a/Source/Common/MultiplayerServer.cs b/Source/Common/MultiplayerServer.cs index 32db4a6d..b61a6ae8 100644 --- a/Source/Common/MultiplayerServer.cs +++ b/Source/Common/MultiplayerServer.cs @@ -47,10 +47,10 @@ static MultiplayerServer() public ActionQueue queue = new(); public ServerSettings settings; - public ServerInitData? initData; + public ServerInitData? InitData => initDataSource.Task.ResultNowOrNull(); private TaskCompletionSource initDataSource = new(); public InitDataState InitDataState => - initData != null ? InitDataState.Complete : + InitData != null ? InitDataState.Complete : // started init data must've completed with null, meaning the client disconnected while waiting for the data, // so we are waiting again initDataSource.Task.IsCompleted ? InitDataState.Waiting : @@ -250,7 +250,7 @@ public void RegisterChatCmd(string cmdName, ChatCmdHandler handler) => public void HandleChatCmd(IChatSource source, string cmd) => chatCmdManager.Handle(source, cmd); - public Task InitData() => initDataSource.Task; + public Task InitDataTask() => initDataSource.Task; /// Can only start one init data at a time. A StartInitData is considered complete once /// TaskCompletionResult.SetResult is called. Until that time no new calls to StartInitData will succeed. @@ -258,14 +258,7 @@ public void RegisterChatCmd(string cmdName, ChatCmdHandler handler) => { if (InitDataState != InitDataState.Waiting) throw new InvalidOperationException($"Can't start init data in state {InitDataState}"); - var currInitDataSource = initDataSource = new TaskCompletionSource(); - currInitDataSource.Task.ContinueWith(task => - { - if (currInitDataSource != initDataSource) - ServerLog.Error("InitDataSource changed during StartInitData"); - initData = task.Result; - }, TaskContinuationOptions.ExecuteSynchronously); - return currInitDataSource; + return initDataSource = new TaskCompletionSource(); } } diff --git a/Source/Common/Networking/State/ServerJoiningState.cs b/Source/Common/Networking/State/ServerJoiningState.cs index e0f7c83d..26f90198 100644 --- a/Source/Common/Networking/State/ServerJoiningState.cs +++ b/Source/Common/Networking/State/ServerJoiningState.cs @@ -1,4 +1,5 @@ -using System.Threading.Tasks; +using System; +using System.Threading.Tasks; namespace Multiplayer.Common; @@ -13,7 +14,7 @@ protected override async Task RunState() HandleProtocol(await Packet(Packets.Client_Protocol)); HandleUsername(await Packet(Packets.Client_Username)); - while (await Server.InitData() is null && await EndIfDead()) + while (await Server.InitDataTask() is null && await EndIfDead()) if (Server.InitDataState == InitDataState.Waiting) await RequestInitData(); @@ -105,9 +106,11 @@ private bool HandleClientJoinData(ByteReader data) var modCtorRoundMode = data.ReadEnum(); var staticCtorRoundMode = data.ReadEnum(); - if ((modCtorRoundMode, staticCtorRoundMode) != Server.initData!.RoundModes) + var serverInitData = Server.InitData ?? + throw new Exception("Server init data is null during handling of client join data"); + if ((modCtorRoundMode, staticCtorRoundMode) != serverInitData.RoundModes) { - Player.Disconnect($"FP round modes don't match: {(modCtorRoundMode, staticCtorRoundMode)} != {Server.initData!.RoundModes}"); + Player.Disconnect($"FP round modes don't match: {(modCtorRoundMode, staticCtorRoundMode)} != {serverInitData.RoundModes}"); return false; } @@ -129,7 +132,7 @@ private bool HandleClientJoinData(ByteReader data) var status = DefCheckStatus.Ok; - if (!Server.initData!.DefInfos.TryGetValue(defType, out DefInfo info)) + if (!serverInitData.DefInfos.TryGetValue(defType, out DefInfo info)) status = DefCheckStatus.Not_Found; else if (info.count != defCount) status = DefCheckStatus.Count_Diff; @@ -146,10 +149,10 @@ private bool HandleClientJoinData(ByteReader data) Packets.Server_JoinData, Server.settings.gameName, Player.id, - Server.initData!.RwVersion, + serverInitData.RwVersion, MpVersion.Version, defsResponse.ToArray(), - Server.initData.RawData + serverInitData.RawData ); return defsMatch; diff --git a/Source/Common/Util/Extensions.cs b/Source/Common/Util/Extensions.cs index 66f46686..eadc66eb 100644 --- a/Source/Common/Util/Extensions.cs +++ b/Source/Common/Util/Extensions.cs @@ -7,6 +7,7 @@ using System.Reflection; using System.Security.Cryptography; using System.Text; +using System.Threading.Tasks; namespace Multiplayer.Common { @@ -128,6 +129,9 @@ public static float MaxOrZero(this IEnumerable items, Func map) { return items.Max(i => (float?)map(i)) ?? 0f; } + + public static T? ResultNowOrNull(this Task task) where T : class? => + task.Status == TaskStatus.RanToCompletion ? task.Result : null; } public static class Utils