Skip to content

Commit e69c257

Browse files
impl and working tests
1 parent 3e08417 commit e69c257

File tree

3 files changed

+234
-63
lines changed

3 files changed

+234
-63
lines changed

src/WouterVanRanst.Utils.Tests/TaskCompletionBufferTests.cs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,136 @@ public async Task TaskSetTemporarilyEmptyButMoreTasksAreAdded()
187187
}
188188

189189

190+
191+
192+
[Fact]
193+
public async Task AddTaskAfterEnumeratorStarts()
194+
{
195+
var buffer = new TaskCompletionBuffer<string>();
196+
var task1 = SimulateTask("Task1", 200);
197+
buffer.Add(task1);
198+
199+
await using var enumerator = buffer.GetConsumingEnumerable().GetAsyncEnumerator();
200+
var moveNextTask = enumerator.MoveNextAsync();
201+
202+
var task2 = SimulateTask("Task2", 100);
203+
buffer.Add(task2);
204+
buffer.CompleteAdding();
205+
206+
Assert.True(await moveNextTask);
207+
Assert.Equal("Task2", await enumerator.Current);
208+
209+
Assert.True(await enumerator.MoveNextAsync());
210+
Assert.Equal("Task1", await enumerator.Current);
211+
212+
Assert.False(await enumerator.MoveNextAsync());
213+
}
214+
215+
[Fact]
216+
public async Task AllTasksCanceled()
217+
{
218+
var buffer = new TaskCompletionBuffer<string>();
219+
var cts = new CancellationTokenSource();
220+
cts.Cancel();
221+
222+
var task1 = Task.FromCanceled<string>(cts.Token);
223+
buffer.Add(task1);
224+
buffer.CompleteAdding();
225+
226+
var processedTasks = new List<string>();
227+
await foreach (var task in buffer.GetConsumingEnumerable())
228+
{
229+
try
230+
{
231+
processedTasks.Add(await task);
232+
}
233+
catch (TaskCanceledException)
234+
{
235+
processedTasks.Add("Canceled");
236+
}
237+
}
238+
239+
Assert.Equal(new[] { "Canceled" }, processedTasks);
240+
}
241+
242+
[Fact]
243+
public async Task MultipleProducersAddingConcurrently()
244+
{
245+
var buffer = new TaskCompletionBuffer<string>();
246+
var producer1 = Task.Run(() =>
247+
{
248+
buffer.Add(SimulateTask("P1T1", 200));
249+
buffer.Add(SimulateTask("P1T2", 100));
250+
});
251+
252+
var producer2 = Task.Run(() =>
253+
{
254+
buffer.Add(SimulateTask("P2T1", 150));
255+
buffer.Add(SimulateTask("P2T2", 50));
256+
});
257+
258+
await Task.WhenAll(producer1, producer2);
259+
buffer.CompleteAdding();
260+
261+
var processedTasks = new List<string>();
262+
await foreach (var task in buffer.GetConsumingEnumerable())
263+
{
264+
processedTasks.Add(await task);
265+
}
266+
267+
Assert.Equal(new[] { "P2T2", "P1T2", "P2T1", "P1T1" }, processedTasks);
268+
}
269+
270+
[Fact]
271+
public async Task AddCompletedTaskImmediatelyProcessed()
272+
{
273+
var buffer = new TaskCompletionBuffer<string>();
274+
var task = Task.FromResult("Immediate");
275+
buffer.Add(task);
276+
buffer.CompleteAdding();
277+
278+
var processedTasks = new List<string>();
279+
await foreach (var t in buffer.GetConsumingEnumerable())
280+
{
281+
processedTasks.Add(await t);
282+
}
283+
284+
Assert.Equal(new[] { "Immediate" }, processedTasks);
285+
}
286+
287+
[Fact]
288+
public void CompleteAddingIsIdempotent()
289+
{
290+
var buffer = new TaskCompletionBuffer<string>();
291+
buffer.CompleteAdding();
292+
buffer.CompleteAdding(); // Should not throw
293+
}
294+
295+
[Fact]
296+
public async Task EnumerateWithoutCompleteAddingBlocks()
297+
{
298+
var buffer = new TaskCompletionBuffer<string>();
299+
var task = SimulateTask("Task", 100);
300+
buffer.Add(task);
301+
302+
var processedTasks = new List<string>();
303+
var cts = new CancellationTokenSource(500); // Cancel after 500ms if it blocks indefinitely
304+
305+
await Assert.ThrowsAsync<OperationCanceledException>(async () =>
306+
{
307+
await foreach (var t in buffer.GetConsumingEnumerable(cts.Token))
308+
{
309+
processedTasks.Add(await t);
310+
}
311+
});
312+
313+
Assert.Single(processedTasks); // The task should have completed within 500ms
314+
Assert.Equal("Task", processedTasks[0]);
315+
}
316+
317+
318+
319+
190320
private async Task<string> SimulateTask(string name, int delay)
191321
{
192322
await Task.Delay(delay);

src/WouterVanRanst.Utils/Collections/ConcurrentConsumingTaskCollection.cs

Lines changed: 0 additions & 63 deletions
This file was deleted.
Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
using System.Threading.Channels;
2+
3+
namespace WouterVanRanst.Utils.Collections;
4+
5+
public class TaskCompletionBuffer<T>
6+
{
7+
private readonly Channel<Task<T>> _channel = Channel.CreateUnbounded<Task<T>>();
8+
private readonly object _lock = new object();
9+
private volatile bool _completedAdding;
10+
private int _pendingTaskCount;
11+
12+
public void Add(Task<T> task)
13+
{
14+
lock (_lock)
15+
{
16+
if (_completedAdding)
17+
throw new InvalidOperationException("Cannot add tasks after CompleteAdding has been called.");
18+
Interlocked.Increment(ref _pendingTaskCount);
19+
}
20+
21+
task.ContinueWith(t =>
22+
{
23+
_channel.Writer.TryWrite(t);
24+
int newCount = Interlocked.Decrement(ref _pendingTaskCount);
25+
if (_completedAdding && newCount == 0)
26+
_channel.Writer.TryComplete();
27+
}, TaskContinuationOptions.ExecuteSynchronously);
28+
}
29+
30+
public void CompleteAdding()
31+
{
32+
lock (_lock)
33+
{
34+
if (_completedAdding)
35+
return;
36+
_completedAdding = true;
37+
}
38+
39+
if (Interlocked.CompareExchange(ref _pendingTaskCount, 0, 0) == 0)
40+
_channel.Writer.TryComplete();
41+
}
42+
43+
public IAsyncEnumerable<Task<T>> GetConsumingEnumerable(CancellationToken cancellationToken = default)
44+
{
45+
return _channel.Reader.ReadAllAsync(cancellationToken);
46+
}
47+
}
48+
49+
//public class TaskCompletionBuffer<T>
50+
//{
51+
// private readonly Channel<Task<T>> _taskChannel = Channel.CreateUnbounded<Task<T>>();
52+
// private readonly CancellationTokenSource _completionSignal = new();
53+
// private bool _isAddingCompleted;
54+
55+
// public void Add(Task<T> task)
56+
// {
57+
// if (_isAddingCompleted)
58+
// throw new InvalidOperationException("Adding tasks has been marked as complete.");
59+
60+
// _taskChannel.Writer.TryWrite(task);
61+
// }
62+
63+
// public void CompleteAdding()
64+
// {
65+
// _isAddingCompleted = true;
66+
// _taskChannel.Writer.Complete();
67+
// _completionSignal.Cancel(); // Signal to exit enumeration when done
68+
// }
69+
70+
// public async IAsyncEnumerable<Task<T>> GetConsumingEnumerable(
71+
// [EnumeratorCancellation] CancellationToken cancellationToken = default)
72+
// {
73+
// var pendingTasks = new List<Task<T>>();
74+
// var combinedToken = CancellationTokenSource.CreateLinkedTokenSource(
75+
// cancellationToken, _completionSignal.Token).Token;
76+
77+
// var reader = _taskChannel.Reader;
78+
79+
// while (!combinedToken.IsCancellationRequested || pendingTasks.Count > 0)
80+
// {
81+
// // Check for newly added tasks
82+
// while (reader.TryRead(out var task))
83+
// pendingTasks.Add(task);
84+
85+
// if (pendingTasks.Count == 0)
86+
// {
87+
// if (_isAddingCompleted) break; // No more tasks coming
88+
89+
// // Wait for new tasks or completion signal
90+
// await reader.WaitToReadAsync(combinedToken).ConfigureAwait(false);
91+
// continue;
92+
// }
93+
94+
// // Wait for any task to complete or new tasks to arrive
95+
// var completedTask = await Task.WhenAny([..pendingTasks, reader.WaitToReadAsync(combinedToken).AsTask()]).ConfigureAwait(false);
96+
97+
// if (completedTask is Task<T> resultTask)
98+
// {
99+
// pendingTasks.Remove(resultTask);
100+
// yield return resultTask;
101+
// }
102+
// }
103+
// }
104+
//}

0 commit comments

Comments
 (0)