
Commit a4dce6f

Kinesis Source: Change from SynchronousQueue to CountDownLatch for backpressure
This fixes a problem where an app could OOM when Kinesis scales to add more shards. It relates to the feature implemented in #102. The Source requires that the ShardRecordProcessor is blocked until the downstream app is ready to consume an event. Before this PR we used a SynchronousQueue to achieve this blocking. After this PR we instead use a CountDownLatch plus an unbounded queue. This gives us better control over backpressure in the scenario where the Source tries to handle many shard ends at the same time.
1 parent 98d5aaf commit a4dce6f
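The mechanism is easiest to see in isolation. The sketch below is illustrative only (the `Item` class and thread bodies are not from this repository): each queued item carries its own `CountDownLatch`, and the producer blocks after `put` until the consumer releases the latch. Because every producer (one KCL thread per shard) can have at most one unacknowledged item in flight, the queue stays small in practice even though `LinkedBlockingQueue` itself is unbounded.

```scala
import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}

object LatchBackpressureSketch extends App {

  // Each item carries its own latch; the producer blocks on the latch
  // after enqueueing, so it has at most one item in flight at a time.
  final case class Item(payload: String, await: CountDownLatch)

  val queue = new LinkedBlockingQueue[Item]()

  val producer = new Thread(() => {
    for (i <- 1 to 3) {
      val latch = new CountDownLatch(1)
      queue.put(Item(s"record-$i", latch))
      latch.await() // blocked until the consumer acknowledges the item
      println(s"producer: record-$i acknowledged")
    }
  })

  val consumer = new Thread(() => {
    for (_ <- 1 to 3) {
      val item = queue.take()
      println(s"consumer: processing ${item.payload}")
      item.await.countDown() // release the producer so it can fetch more
    }
  })

  producer.start(); consumer.start()
  producer.join(); consumer.join()
}
```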

File tree: 4 files changed (+75, −41 lines)


modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/streams/kinesis/source/KCLAction.scala

Lines changed: 30 additions & 2 deletions
```diff
@@ -15,12 +15,40 @@ private sealed trait KCLAction
 
 private object KCLAction {
 
-  final case class ProcessRecords(shardId: String, processRecordsInput: ProcessRecordsInput) extends KCLAction
+  /**
+   * The action emitted by the ShardRecordProcessor when it receives new records
+   *
+   * @param await
+   *   A countdown latch used to backpressure the ShardRecordProcessor. The consumer of the queue
+   *   should release the countdown latch to unblock the ShardRecordProcessor and let it fetch more
+   *   records from Kinesis.
+   */
+  final case class ProcessRecords(
+    shardId: String,
+    await: CountDownLatch,
+    processRecordsInput: ProcessRecordsInput
+  ) extends KCLAction
+
+  /**
+   * The action emitted by the ShardRecordProcessor when it reaches a shard end.
+   *
+   * @param await
+   *   A countdown latch used to block the ShardRecordProcessor until all records from this stream
+   *   have been checkpointed.
+   *
+   * @note
+   *   Unlike the `await` in the `ProcessRecords` class, this countdown latch must not be released
+   *   immediately by the queue consumer. It must only be released by the checkpointer.
+   */
   final case class ShardEnd(
     shardId: String,
     await: CountDownLatch,
     shardEndedInput: ShardEndedInput
   ) extends KCLAction
-  final case class KCLError(t: Throwable) extends KCLAction
+
+  final case class KCLError(
+    t: Throwable,
+    await: CountDownLatch
+  ) extends KCLAction
 
 }
```
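The scaladocs above describe two different latch protocols. As a minimal, hypothetical sketch (the `Action` type and `consume` method are illustrative, not from this codebase), a queue consumer honouring them would release a `ProcessRecords` latch immediately but hand a `ShardEnd` latch over to the checkpointing logic to release later:

```scala
import java.util.concurrent.CountDownLatch
import scala.collection.mutable.ListBuffer

// Hypothetical simplified actions mirroring the two latch protocols.
sealed trait Action
final case class ProcessRecords(await: CountDownLatch) extends Action
final case class ShardEnd(await: CountDownLatch) extends Action

// ProcessRecords latches are released at once; ShardEnd latches are only
// collected here, to be released later by the checkpointer.
def consume(actions: List[Action], pendingShardEnds: ListBuffer[CountDownLatch]): Unit =
  actions.foreach {
    case ProcessRecords(await) => await.countDown() // processor may fetch more records
    case ShardEnd(await)       => pendingShardEnds += await // deferred release
  }
```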

modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/streams/kinesis/source/KCLScheduler.scala

Lines changed: 9 additions & 5 deletions
```diff
@@ -24,7 +24,7 @@ import software.amazon.kinesis.retrieval.polling.PollingConfig
 
 import java.net.URI
 import java.util.Date
-import java.util.concurrent.SynchronousQueue
+import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}
 import java.util.concurrent.atomic.AtomicReference
 
 import com.snowplowanalytics.snowplow.streams.kinesis.KinesisSourceConfig
@@ -33,7 +33,7 @@ private[source] object KCLScheduler {
 
   def populateQueue[F[_]: Async](
     config: KinesisSourceConfig,
-    queue: SynchronousQueue[KCLAction],
+    queue: LinkedBlockingQueue[KCLAction],
     client: SdkAsyncHttpClient
   ): Resource[F, Unit] =
    for {
@@ -49,7 +49,7 @@
     dynamoDbClient: DynamoDbAsyncClient,
     cloudWatchClient: CloudWatchAsyncClient,
     kinesisConfig: KinesisSourceConfig,
-    queue: SynchronousQueue[KCLAction]
+    queue: LinkedBlockingQueue[KCLAction]
   ): F[Scheduler] =
     Sync[F].delay {
       val configsBuilder =
@@ -90,8 +90,12 @@
       val coordinatorConfig = configsBuilder.coordinatorConfig
         .workerStateChangeListener(new WorkerStateChangeListener {
           def onWorkerStateChange(newState: WorkerStateChangeListener.WorkerState): Unit = ()
-          override def onAllInitializationAttemptsFailed(e: Throwable): Unit =
-            queue.put(KCLAction.KCLError(e))
+          override def onAllInitializationAttemptsFailed(e: Throwable): Unit = {
+            val countDownLatch = new CountDownLatch(1)
+            queue.put(KCLAction.KCLError(e, countDownLatch))
+            countDownLatch.await()
+            ()
+          }
         })
 
       new Scheduler(
```

modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/streams/kinesis/source/KinesisSource.scala

Lines changed: 23 additions & 30 deletions
```diff
@@ -17,7 +17,7 @@ import software.amazon.awssdk.http.async.SdkAsyncHttpClient
 import software.amazon.kinesis.lifecycle.events.{ProcessRecordsInput, ShardEndedInput}
 import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber
 
-import java.util.concurrent.{CountDownLatch, SynchronousQueue}
+import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}
 import scala.concurrent.duration.{DurationLong, FiniteDuration}
 import scala.jdk.CollectionConverters._
 
@@ -42,33 +42,36 @@ private[kinesis] object KinesisSource {
     }
   }
 
-  // We enable fairness on the `SynchronousQueue` to ensure all Kinesis shards are sourced at an equal rate.
-  private val synchronousQueueFairness: Boolean = true
-
   private def kinesisStream[F[_]: Async](
     config: KinesisSourceConfig,
     client: SdkAsyncHttpClient
   ): Stream[F, Stream[F, Option[LowLevelEvents[Map[String, Checkpointable]]]]] = {
-    val actionQueue = new SynchronousQueue[KCLAction](synchronousQueueFairness)
+    val actionQueue = new LinkedBlockingQueue[KCLAction]()
     for {
      _ <- Stream.resource(KCLScheduler.populateQueue[F](config, actionQueue, client))
      events <- Stream.emit(pullFromQueueAndEmit(actionQueue).stream).repeat
    } yield events
  }
 
   private def pullFromQueueAndEmit[F[_]: Sync](
-    queue: SynchronousQueue[KCLAction]
+    queue: LinkedBlockingQueue[KCLAction]
   ): Pull[F, Option[LowLevelEvents[Map[String, Checkpointable]]], Unit] =
-    Pull.eval(pullFromQueue(queue)).flatMap { case PullFromQueueResult(actions, hasShardEnd) =>
+    Pull.eval(pullFromQueue(queue)).flatMap { actions =>
       val toEmit = actions.traverse {
-        case KCLAction.ProcessRecords(_, processRecordsInput) if processRecordsInput.records.asScala.isEmpty =>
-          Pull.output1(None)
-        case KCLAction.ProcessRecords(shardId, processRecordsInput) =>
-          Pull.output1(Some(provideNextChunk(shardId, processRecordsInput))).covary[F]
+        case KCLAction.ProcessRecords(_, await, processRecordsInput) if processRecordsInput.records.asScala.isEmpty =>
+          Pull.eval(Sync[F].delay(await.countDown())) >> Pull.output1(None)
+        case KCLAction.ProcessRecords(shardId, await, processRecordsInput) =>
+          Pull.eval(Sync[F].delay(await.countDown())) >> Pull.output1(Some(provideNextChunk(shardId, processRecordsInput))).covary[F]
         case KCLAction.ShardEnd(shardId, await, shardEndedInput) =>
+          // Do not call `await.countDown()` yet. It must be released later by the checkpointer.
           handleShardEnd[F](shardId, await, shardEndedInput)
-        case KCLAction.KCLError(t) =>
-          Pull.eval(Logger[F].error(t)("Exception from Kinesis source")) *> Pull.raiseError[F](t)
+        case KCLAction.KCLError(t, await) =>
+          Pull.eval(Sync[F].delay(await.countDown())) >> Pull.eval(Logger[F].error(t)("Exception from Kinesis source")) >> Pull
+            .raiseError[F](t)
+      }
+      val hasShardEnd = actions.exists {
+        case _: KCLAction.ShardEnd => true
+        case _: KCLAction          => false
       }
       if (hasShardEnd) {
         val log = Logger[F].info {
@@ -83,31 +86,21 @@
       toEmit *> pullFromQueueAndEmit(queue)
     }
 
-  private case class PullFromQueueResult(actions: NonEmptyList[KCLAction], hasShardEnd: Boolean)
-
-  private def pullFromQueue[F[_]: Sync](queue: SynchronousQueue[KCLAction]): F[PullFromQueueResult] =
-    resolveNextAction(queue)
-      .flatMap {
-        case shardEnd: KCLAction.ShardEnd =>
-          // If we reached the end of one shard, it is likely we reached the end of other shards too.
-          // Therefore pull more actions from the queue, to minimize the number of times we need to do
-          // an early close of the inner stream.
-          resolveAllActions(queue).map { more =>
-            PullFromQueueResult(NonEmptyList(shardEnd, more), hasShardEnd = true)
-          }
-        case other =>
-          PullFromQueueResult(NonEmptyList.one(other), hasShardEnd = false).pure[F]
-      }
+  private def pullFromQueue[F[_]: Sync](queue: LinkedBlockingQueue[KCLAction]): F[NonEmptyList[KCLAction]] =
+    for {
+      head <- resolveNextAction(queue)
+      tail <- resolveAllActions(queue)
+    } yield NonEmptyList(head, tail)
 
   /** Always returns a `KCLAction`, possibly waiting until one is available */
-  private def resolveNextAction[F[_]: Sync](queue: SynchronousQueue[KCLAction]): F[KCLAction] =
+  private def resolveNextAction[F[_]: Sync](queue: LinkedBlockingQueue[KCLAction]): F[KCLAction] =
     Sync[F].delay(Option[KCLAction](queue.poll)).flatMap {
       case Some(action) => Sync[F].pure(action)
       case None => Sync[F].interruptible(queue.take)
     }
 
   /** Returns immediately, but the `List[KCLAction]` might be empty */
-  private def resolveAllActions[F[_]: Sync](queue: SynchronousQueue[KCLAction]): F[List[KCLAction]] =
+  private def resolveAllActions[F[_]: Sync](queue: LinkedBlockingQueue[KCLAction]): F[List[KCLAction]] =
     for {
       ret <- Sync[F].delay(new java.util.ArrayList[KCLAction]())
       _ <- Sync[F].delay(queue.drainTo(ret))
```
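The new `pullFromQueue` uses a standard blocking-queue batching idiom: block for the first element, then drain whatever has already accumulated, so that several shard ends arriving together are observed in a single pull. A standalone sketch of the same idiom, with a plain `List` in place of cats' `NonEmptyList` and no effect wrapper:

```scala
import java.util.concurrent.LinkedBlockingQueue
import scala.jdk.CollectionConverters._

// Block until at least one element is available, then opportunistically
// drain the rest without blocking. The returned list is never empty.
def takeBatch[A](queue: LinkedBlockingQueue[A]): List[A] = {
  val head = queue.take() // blocks while the queue is empty
  val rest = new java.util.ArrayList[A]()
  queue.drainTo(rest) // returns immediately, possibly adding nothing
  head :: rest.asScala.toList
}
```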

modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/streams/kinesis/source/ShardRecordProcessor.scala

Lines changed: 13 additions & 4 deletions
```diff
@@ -16,13 +16,13 @@ import software.amazon.kinesis.lifecycle.events.{
 }
 import software.amazon.kinesis.processor.{ShardRecordProcessor => KCLShardProcessor}
 
-import java.util.concurrent.{CountDownLatch, SynchronousQueue}
+import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}
 import java.util.concurrent.atomic.AtomicReference
 
 private[source] object ShardRecordProcessor {
 
   def apply(
-    queue: SynchronousQueue[KCLAction],
+    queue: LinkedBlockingQueue[KCLAction],
     currentShardIds: AtomicReference[Set[String]]
   ): KCLShardProcessor = new KCLShardProcessor {
     private var shardId: String = _
@@ -36,9 +36,15 @@
         // 2. KCL re-aquires the lost lease for the same shard
         // 3. The original ShardRecordProcessor is not terminated until after KCL re-aquires the lease
         // This is a very unhealthy state, so we should kill the app.
-        val action = KCLAction.KCLError(new RuntimeException(s"Refusing to initialize a duplicate record processor for shard $shardId"))
+        val countDownLatch = new CountDownLatch(1)
+        val action = KCLAction.KCLError(
+          new RuntimeException(s"Refusing to initialize a duplicate record processor for shard $shardId"),
+          countDownLatch
+        )
         withHandledInterrupts {
           queue.put(action)
+          countDownLatch.await()
+          ()
         }
       }
     }
@@ -54,9 +60,12 @@
     }
 
     override def processRecords(processRecordsInput: ProcessRecordsInput): Unit = {
-      val action = KCLAction.ProcessRecords(shardId, processRecordsInput)
+      val countDownLatch = new CountDownLatch(1)
+      val action = KCLAction.ProcessRecords(shardId, countDownLatch, processRecordsInput)
       withHandledInterrupts {
        queue.put(action)
+        countDownLatch.await()
+        ()
       }
     }
 
```
