Skip to content

Commit 2e089d0

Browse files
authored
Merge pull request #45 from tpeczek/race-condition-when-managing-groups-collection
Removing race condition when managing groups. Resolves #43
2 parents 11ac4e3 + 68710ce commit 2e089d0

File tree

4 files changed

+122
-37
lines changed

4 files changed

+122
-37
lines changed

DocFx.AspNetCore.ServerSentEvents/articles/groups.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ A group is a collection of clients associated with a name. Groups are the recomm
44

55
## Adding to a group
66

7-
Client can be added to a group via the `IServerSentEventsService.AddToGroupAsync` method (the method will return an information if a client has been added to an existing or a new group).
7+
Client can be added to a group via the `IServerSentEventsService.AddToGroup` method (the method will return an information if a client has been added to an existing or a new group).
88

99
```cs
10-
await _serverSentEventsService.AddToGroupAsync(groupName, client);
10+
_serverSentEventsService.AddToGroup(groupName, client);
1111
```
1212

1313
Group membership isn't preserved when a client reconnects. The client needs to rejoin the group when connection is re-established. One of possible ways to handle this is putting group assigment logic into `ServerSentEventsServiceOptions.OnClientConnected` callback.
@@ -21,12 +21,12 @@ public class Startup
2121
{
2222
services.AddServerSentEvents(options =>
2323
{
24-
options.OnClientConnected = async (service, clientConnectedArgs) =>
24+
options.OnClientConnected = (service, clientConnectedArgs) =>
2525
{
2626
// Logic which determines the client group.
2727
...
2828

29-
await service.AddToGroupAsync(groupName, clientConnectedArgs.Client);
29+
service.AddToGroup(groupName, clientConnectedArgs.Client);
3030
};
3131
});
3232

Lib.AspNetCore.ServerSentEvents/IServerSentEventsService.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ public interface IServerSentEventsService
5656
/// </summary>
5757
/// <param name="groupName">The group name.</param>
5858
/// <param name="client">The client to add to a group.</param>
59-
/// <returns>The task object representing the result of asynchronous operation</returns>
60-
Task<ServerSentEventsAddToGroupResult> AddToGroupAsync(string groupName, IServerSentEventsClient client);
59+
/// <returns>The result of operation.</returns>
60+
ServerSentEventsAddToGroupResult AddToGroup(string groupName, IServerSentEventsClient client);
6161

6262
/// <summary>
6363
/// Changes the interval after which clients will attempt to reestablish failed connections.

Lib.AspNetCore.ServerSentEvents/ServerSentEventsService.cs

Lines changed: 15 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public class ServerSentEventsService : IServerSentEventsService
3232

3333
private readonly SemaphoreSlim _groupsSemaphore = new SemaphoreSlim(1, 1);
3434
private static readonly IReadOnlyCollection<IServerSentEventsClient> _emptyGroup = new IServerSentEventsClient[0];
35-
private readonly Dictionary<string, ConcurrentDictionary<Guid, IServerSentEventsClient>> _groups = new Dictionary<string, ConcurrentDictionary<Guid, IServerSentEventsClient>>();
35+
private readonly ConcurrentDictionary<string, ConcurrentDictionary<Guid, IServerSentEventsClient>> _groups = new ConcurrentDictionary<string, ConcurrentDictionary<Guid, IServerSentEventsClient>>();
3636
#endregion
3737

3838
#region Constructors
@@ -99,9 +99,9 @@ public IReadOnlyCollection<IServerSentEventsClient> GetClients()
9999
/// <returns>The clients in the specified group.</returns>
100100
public IReadOnlyCollection<IServerSentEventsClient> GetClients(string groupName)
101101
{
102-
if (_groups.ContainsKey(groupName))
102+
if (_groups.TryGetValue(groupName, out ConcurrentDictionary<Guid, IServerSentEventsClient> group))
103103
{
104-
return _groups[groupName].Values.ToArray();
104+
return (IReadOnlyCollection<IServerSentEventsClient>)group.Values;
105105
}
106106

107107
return _emptyGroup;
@@ -113,18 +113,18 @@ public IReadOnlyCollection<IServerSentEventsClient> GetClients(string groupName)
113113
/// <param name="groupName">The group name.</param>
114114
/// /// <param name="client">The client to add to a group.</param>
115115
/// <returns>The task object representing the result of asynchronous operation</returns>
116-
public async Task<ServerSentEventsAddToGroupResult> AddToGroupAsync(string groupName, IServerSentEventsClient client)
116+
public ServerSentEventsAddToGroupResult AddToGroup(string groupName, IServerSentEventsClient client)
117117
{
118118
ServerSentEventsAddToGroupResult result = ServerSentEventsAddToGroupResult.AddedToExistingGroup;
119119

120-
if (!_groups.ContainsKey(groupName))
120+
ConcurrentDictionary<Guid, IServerSentEventsClient> group = _groups.GetOrAdd(groupName, (_) =>
121121
{
122-
await CreateGroupAsync(groupName);
123-
124122
result = ServerSentEventsAddToGroupResult.AddedToNewGroup;
125-
}
126123

127-
_groups[groupName].TryAdd(client.Id, client);
124+
return new ConcurrentDictionary<Guid, IServerSentEventsClient>();
125+
});
126+
127+
group.TryAdd(client.Id, client);
128128

129129
return result;
130130
}
@@ -389,9 +389,9 @@ internal void RemoveClient(ServerSentEventsClient client)
389389

390390
_clients.TryRemove(client.Id, out _);
391391

392-
foreach(ConcurrentDictionary<Guid, IServerSentEventsClient> group in _groups.Values)
392+
foreach(KeyValuePair<string, ConcurrentDictionary<Guid, IServerSentEventsClient>> group in _groups)
393393
{
394-
group.TryRemove(client.Id, out _);
394+
group.Value.TryRemove(client.Id, out _);
395395
}
396396
}
397397

@@ -400,22 +400,6 @@ internal bool IsClientConnected(Guid clientId)
400400
return _clients.ContainsKey(clientId);
401401
}
402402

403-
private async Task CreateGroupAsync(string groupName)
404-
{
405-
await _groupsSemaphore.WaitAsync();
406-
try
407-
{
408-
if (!_groups.ContainsKey(groupName))
409-
{
410-
_groups.Add(groupName, new ConcurrentDictionary<Guid, IServerSentEventsClient>());
411-
}
412-
}
413-
finally
414-
{
415-
_groupsSemaphore.Release();
416-
}
417-
}
418-
419403
internal Task SendAsync(ServerSentEventBytes serverSentEventBytes, CancellationToken cancellationToken)
420404
{
421405
return SendAsync(_clients.Values, serverSentEventBytes, cancellationToken);
@@ -428,19 +412,19 @@ internal Task SendAsync(ServerSentEventBytes serverSentEventBytes, Func<IServerS
428412

429413
internal Task SendAsync(string groupName, ServerSentEventBytes serverSentEventBytes, CancellationToken cancellationToken)
430414
{
431-
if (_groups.ContainsKey(groupName))
415+
if (_groups.TryGetValue(groupName, out ConcurrentDictionary<Guid, IServerSentEventsClient> group))
432416
{
433-
return SendAsync(_groups[groupName].Values, serverSentEventBytes, cancellationToken);
417+
return SendAsync(group.Values, serverSentEventBytes, cancellationToken);
434418
}
435419

436420
return Task.CompletedTask;
437421
}
438422

439423
internal Task SendAsync(string groupName, ServerSentEventBytes serverSentEventBytes, Func<IServerSentEventsClient, bool> clientPredicate, CancellationToken cancellationToken)
440424
{
441-
if (_groups.ContainsKey(groupName))
425+
if (_groups.TryGetValue(groupName, out ConcurrentDictionary<Guid, IServerSentEventsClient> group))
442426
{
443-
return SendAsync(_groups[groupName].Values.Where(clientPredicate), serverSentEventBytes, cancellationToken);
427+
return SendAsync(group.Values.Where(clientPredicate), serverSentEventBytes, cancellationToken);
444428
}
445429

446430
return Task.CompletedTask;
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
using System;
2+
using System.Linq;
3+
using System.Security.Claims;
4+
using System.Threading.Tasks;
5+
using System.Collections.Generic;
6+
using Microsoft.AspNetCore.Http;
7+
using Microsoft.Extensions.Options;
8+
using Xunit;
9+
using Lib.AspNetCore.ServerSentEvents;
10+
using Lib.AspNetCore.ServerSentEvents.Internals;
11+
12+
namespace Test.AspNetCore.ServerSentEvents
13+
{
14+
public class ServerSentEventsServiceTests
15+
{
16+
#region Prepare SUT
17+
private static async Task<ServerSentEventsClient> PrepareAndAddServerSentEventsClientAsync(ServerSentEventsService serverSentEventsService)
18+
{
19+
HttpContext context = new DefaultHttpContext();
20+
ServerSentEventsClient serverSentEventsClient = new ServerSentEventsClient(Guid.NewGuid(), new ClaimsPrincipal(), context.Response, false);
21+
22+
await serverSentEventsService.OnConnectAsync(context.Request, serverSentEventsClient);
23+
24+
serverSentEventsService.AddClient(serverSentEventsClient);
25+
26+
return serverSentEventsClient;
27+
}
28+
#endregion
29+
30+
#region Tests
31+
[Fact]
32+
public void RemoveClient_ClientsBeingAddedToGroupsInParaller_NoRaceCondition()
33+
{
34+
// ARRANGE
35+
ServerSentEventsService serverSentEventsService = new ServerSentEventsService(Options.Create<ServerSentEventsServiceOptions<ServerSentEventsService>>(new ServerSentEventsServiceOptions<ServerSentEventsService>
36+
{
37+
OnClientConnected = (service, clientConnectedArgs) =>
38+
{
39+
for (var i = 0; i < 1000; i++)
40+
{
41+
service.AddToGroup(Guid.NewGuid().ToString(), clientConnectedArgs.Client);
42+
}
43+
}
44+
}));
45+
46+
// ACT
47+
Task[] clientsTasks = new Task[100];
48+
for (int i = 0; i < clientsTasks.Length; i++)
49+
{
50+
clientsTasks[i] = Task.Run(async () =>
51+
{
52+
ServerSentEventsClient serverSentEventsClient = await PrepareAndAddServerSentEventsClientAsync(serverSentEventsService);
53+
54+
await Task.Delay(100);
55+
56+
serverSentEventsService.RemoveClient(serverSentEventsClient);
57+
});
58+
}
59+
Exception recordedException = Record.Exception(() => Task.WaitAll(clientsTasks));
60+
61+
// ASSERT
62+
Assert.False((recordedException as AggregateException)?.InnerExceptions.Any(ex => (ex as InvalidOperationException)?.Message.Contains("Collection was modified") ?? false) ?? false);
63+
}
64+
65+
[Fact]
66+
public async Task GetClients_GroupNameProvidedAndGroupExists_ReturnsGroup()
67+
{
68+
// ARRANGE
69+
const string serverSentEventsClientsGroupName = nameof(GetClients_GroupNameProvidedAndGroupExists_ReturnsGroup);
70+
ServerSentEventsService serverSentEventsService = new ServerSentEventsService(Options.Create<ServerSentEventsServiceOptions<ServerSentEventsService>>(new ServerSentEventsServiceOptions<ServerSentEventsService>
71+
{
72+
OnClientConnected = async (service, clientConnectedArgs) =>
73+
{
74+
service.AddToGroup(serverSentEventsClientsGroupName, clientConnectedArgs.Client);
75+
}
76+
}));
77+
ServerSentEventsClient serverSentEventsClient = await PrepareAndAddServerSentEventsClientAsync(serverSentEventsService);
78+
79+
// ACT
80+
IReadOnlyCollection<IServerSentEventsClient> serverSentEventsClientsGroup = serverSentEventsService.GetClients(serverSentEventsClientsGroupName);
81+
82+
// ASSERT
83+
Assert.Single(serverSentEventsClientsGroup, serverSentEventsClient);
84+
}
85+
86+
[Fact]
87+
public void GetClients_GroupNameProvidedAndGroupNotExists_ReturnsEmptyGroup()
88+
{
89+
// ARRANGE
90+
const string serverSentEventsClientsGroupName = nameof(GetClients_GroupNameProvidedAndGroupNotExists_ReturnsEmptyGroup);
91+
ServerSentEventsService serverSentEventsService = new ServerSentEventsService(Options.Create<ServerSentEventsServiceOptions<ServerSentEventsService>>(null));
92+
93+
// ACT
94+
IReadOnlyCollection<IServerSentEventsClient> serverSentEventsClientsGroup = serverSentEventsService.GetClients(serverSentEventsClientsGroupName);
95+
96+
// ASSERT
97+
Assert.Empty(serverSentEventsClientsGroup);
98+
}
99+
#endregion
100+
}
101+
}

0 commit comments

Comments
 (0)