diff --git a/src/FRC.NetworkTables/Dispatcher.cs b/src/FRC.NetworkTables/Dispatcher.cs index 1600dd3..1a92b77 100644 --- a/src/FRC.NetworkTables/Dispatcher.cs +++ b/src/FRC.NetworkTables/Dispatcher.cs @@ -4,6 +4,7 @@ using NetworkTables.Interfaces; using NetworkTables.TcpSockets; using NetworkTables.Logging; +using System; namespace NetworkTables { @@ -18,7 +19,7 @@ private Dispatcher() : this(Storage.Instance, Notifier.Instance) public Dispatcher(Storage storage, Notifier notifier) : base(storage, notifier) { - + } /// @@ -45,22 +46,26 @@ public void StartServer(string persistentFilename, string listenAddress, int por public void SetServer(string serverName, int port) { - SetConnector(() => TcpConnector.Connect(serverName, port, Logger.Instance, 1)); + SetConnector(() => + { + return TcpConnector.ConnectParallel(new List<(string server, int port)> { (serverName, port) }, Logger.Instance, TimeSpan.FromSeconds(3)); + }); } public void SetServer(IList servers) { - List connectors = new List(); + List<(string server, int port)> addresses = new List<(string server, int port)>(servers.Count); foreach (var server in servers) { - connectors.Add(() => TcpConnector.Connect(server.IpAddress, server.Port, Logger.Instance, 1)); + addresses.Add((server.IpAddress, server.Port)); } - SetConnector(connectors); + + SetConnector(() => TcpConnector.ConnectParallel(addresses, Logger.Instance, TimeSpan.FromSeconds(3))); } public void SetServerOverride(IPAddress address, int port) { - SetConnectorOverride(() => TcpConnector.Connect(address.ToString(), port, Logger.Instance, 1)); + SetConnectorOverride(() => TcpConnector.ConnectParallel(new List<(string server, int port)> { (address.ToString(), port) }, Logger.Instance, TimeSpan.FromSeconds(3))); } public void ClearServerOverride() @@ -68,4 +73,4 @@ public void ClearServerOverride() ClearConnectorOverride(); } } -} +} \ No newline at end of file diff --git a/src/FRC.NetworkTables/DispatcherBase.cs b/src/FRC.NetworkTables/DispatcherBase.cs index a5d4512..e41ed32 100644 --- a/src/FRC.NetworkTables/DispatcherBase.cs +++ b/src/FRC.NetworkTables/DispatcherBase.cs @@ -13,7 +13,7 @@ namespace NetworkTables { internal class DispatcherBase : IDisposable { - public delegate NtTcpClient Connector(); + public delegate IClient Connector(); public const double MinimumUpdateTime = 0.01; //100ms public const double MaximumUpdateTime = 1.0; //1 second diff --git a/src/FRC.NetworkTables/TcpSockets/TcpConnector.cs b/src/FRC.NetworkTables/TcpSockets/TcpConnector.cs index 4d6f8e6..fb9b6f0 100644 --- a/src/FRC.NetworkTables/TcpSockets/TcpConnector.cs +++ b/src/FRC.NetworkTables/TcpSockets/TcpConnector.cs @@ -6,101 +6,116 @@ using System.Threading.Tasks; using System.Runtime.ExceptionServices; using static NetworkTables.Logging.Logger; +using Nito.AsyncEx; +using System.IO; namespace NetworkTables.TcpSockets { internal class TcpConnector { - private static bool WaitAndUnwrapException(Task task, int timeout) + public class TcpClientNt : IClient { - try + private readonly TcpClient m_client; + + internal TcpClientNt(TcpClient client) { - return task.Wait(timeout); + m_client = client; } - catch (AggregateException ex) + + public Stream GetStream() { - ExceptionDispatchInfo.Capture(ex.InnerException).Throw(); - throw ex.InnerException; + return m_client.GetStream(); } - } - - private static int ResolveHostName(string hostName, out IPAddress[] addr) - { - try + public EndPoint RemoteEndPoint { - var entries = Dns.GetHostAddressesAsync(hostName); - var success = WaitAndUnwrapException(entries, 1000); - if (!success) + get { - addr = null; - return 1; + return m_client.Client.RemoteEndPoint; } - List addresses = new List(); - foreach (var ipAddress in entries.Result) + } + public bool NoDelay + { + set { - // Only allow IPV4 addresses for now - // Sockets don't all support IPV6 - if (ipAddress.AddressFamily == AddressFamily.InterNetwork) - { - if (!addresses.Contains(ipAddress)) - { - addresses.Add(ipAddress); - } - } } - addr = addresses.ToArray(); - } - catch (SocketException e) + + public void Dispose() { - addr = null; - return (int)e.SocketErrorCode; + m_client.Dispose(); } - return 0; } - public static NtTcpClient Connect(string server, int port, Logger logger, int timeout = 0) + private static void PrintConnectFailList(IList<(string server, int port)> servers, Logger logger) { - if (ResolveHostName(server, out IPAddress[] addr) != 0) + Logger.Error(logger, "Failed to connect to the following IP Addresses:"); + foreach (var item in servers) { - try - { - addr = new IPAddress[1]; - addr[0] = IPAddress.Parse(server); - } - catch (FormatException) - { - Error(logger, $"could not resolve {server} address"); - return null; - } + Logger.Error(logger, $" Server: {item.server} Port: {item.port}"); } + } - //Create out client - NtTcpClient client = new NtTcpClient(AddressFamily.InterNetwork); - // No time limit, connect forever - if (timeout == 0) + public static IClient ConnectParallel(IList<(string ip, int port)> conns, Logger logger, TimeSpan timeout) + { + return ConnectParallelAsync(conns, logger, timeout).Result; + } + + public static Task ConnectParallelAsync(IList<(string ip, int port)> conns, Logger logger, TimeSpan timeout) + { + List clients = new List(); + List tasks = new List(); + + foreach (var item in conns) { + var client = new TcpClient(); + Task connectTask; try { - client.Connect(addr, port); + connectTask = client.ConnectAsync(item.ip, item.port); + + } + catch (ArgumentOutOfRangeException aore) + { + // TODO: Log + Logger.Error(logger, $"Bad argument {aore}"); + continue; } - catch (SocketException ex) + catch (SocketException se) { - Error(logger, $"Connect() to {server} port {port.ToString()} failed: {ex.SocketErrorCode}"); - ((IDisposable)client).Dispose(); - return null; + // TODO: Log + Logger.Warning(logger, $"Socket connect failed {se}"); + continue; } - return client; + clients.Add(client); + tasks.Add(connectTask); } - //Connect with time limit - bool connectedWithTimeout = client.ConnectWithTimeout(addr, port, logger, timeout); - if (!connectedWithTimeout) + var delayTask = Task.Delay(timeout); + tasks.Add(delayTask); + + async Task ConnectAsyncInternal() { - ((IDisposable)client).Dispose(); + while (tasks.Count > 0) + { + var task = await Task.WhenAny(tasks); + if (task == delayTask) + { + return null; + } + var index = tasks.IndexOf(task); + var client = clients[index]; + if (client.Connected) + { + return new TcpClientNt(client); + } + clients.RemoveAt(index); + tasks.RemoveAt(index); + } return null; } - return client; + + return ConnectAsyncInternal(); } + } -} +} \ No newline at end of file