Skip to content

Commit 77457b3

Browse files
authored
Kinesis Source: Change from SynchronousQueue to CountDownLatch for back pressure (#142)
This fixes a problem where an app could OOM when Kinesis scales to add more shards. It relates to the feature implemented in #102. It is a requirement of the Source that the ShardRecordProcessor gets blocked until the downstream app is ready to consume an event. Before this PR we used a SynchronousQueue to achieve this blocking. After this PR we instead use a CountDownLatch, plus an unbounded queue. This means we have better control over backpressure during the scenario where the Source tries to handle many shard ends at the same time.
1 parent a16f800 commit 77457b3

File tree

4 files changed

+75
-41
lines changed

4 files changed

+75
-41
lines changed

modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/streams/kinesis/source/KCLAction.scala

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,40 @@ private sealed trait KCLAction
1515

1616
private object KCLAction {
1717

18-
final case class ProcessRecords(shardId: String, processRecordsInput: ProcessRecordsInput) extends KCLAction
18+
/**
19+
* The action emitted by the ShardRecordProcessor when it receives new records
20+
*
21+
* @param await
22+
* A countdown latch used to backpressure the ShardRecordProcessor. The consumer of the queue
23+
* should release the countdown latch to unblock the ShardRecordProcessor and let it fetch more
24+
* records from Kinesis.
25+
*/
26+
final case class ProcessRecords(
27+
shardId: String,
28+
await: CountDownLatch,
29+
processRecordsInput: ProcessRecordsInput
30+
) extends KCLAction
31+
32+
/**
33+
* The action emitted by the ShardRecordProcessor when it reaches a shard end.
34+
*
35+
* @param await
36+
* A countdown latch used to block the ShardRecordProcessor until all records from this shard
37+
* have been checkpointed.
38+
*
39+
* @note
40+
* Unlike the `await` in the `ProcessRecords` class, this countdown latch must not be released
41+
* immediately by the queue consumer. It must only be released by the checkpointer.
42+
*/
1943
final case class ShardEnd(
2044
shardId: String,
2145
await: CountDownLatch,
2246
shardEndedInput: ShardEndedInput
2347
) extends KCLAction
24-
final case class KCLError(t: Throwable) extends KCLAction
48+
49+
final case class KCLError(
50+
t: Throwable,
51+
await: CountDownLatch
52+
) extends KCLAction
2553

2654
}

modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/streams/kinesis/source/KCLScheduler.scala

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import software.amazon.kinesis.retrieval.polling.PollingConfig
2525

2626
import java.net.URI
2727
import java.util.Date
28-
import java.util.concurrent.SynchronousQueue
28+
import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}
2929
import java.util.concurrent.atomic.AtomicReference
3030

3131
import com.snowplowanalytics.snowplow.streams.kinesis.KinesisSourceConfig
@@ -34,7 +34,7 @@ private[source] object KCLScheduler {
3434

3535
def populateQueue[F[_]: Async](
3636
config: KinesisSourceConfig,
37-
queue: SynchronousQueue[KCLAction],
37+
queue: LinkedBlockingQueue[KCLAction],
3838
client: SdkAsyncHttpClient,
3939
awsUserAgent: Option[String]
4040
): Resource[F, Unit] =
@@ -51,7 +51,7 @@ private[source] object KCLScheduler {
5151
dynamoDbClient: DynamoDbAsyncClient,
5252
cloudWatchClient: CloudWatchAsyncClient,
5353
kinesisConfig: KinesisSourceConfig,
54-
queue: SynchronousQueue[KCLAction]
54+
queue: LinkedBlockingQueue[KCLAction]
5555
): F[Scheduler] =
5656
Sync[F].delay {
5757
val configsBuilder =
@@ -92,8 +92,12 @@ private[source] object KCLScheduler {
9292
val coordinatorConfig = configsBuilder.coordinatorConfig
9393
.workerStateChangeListener(new WorkerStateChangeListener {
9494
def onWorkerStateChange(newState: WorkerStateChangeListener.WorkerState): Unit = ()
95-
override def onAllInitializationAttemptsFailed(e: Throwable): Unit =
96-
queue.put(KCLAction.KCLError(e))
95+
override def onAllInitializationAttemptsFailed(e: Throwable): Unit = {
96+
val countDownLatch = new CountDownLatch(1)
97+
queue.put(KCLAction.KCLError(e, countDownLatch))
98+
countDownLatch.await()
99+
()
100+
}
97101
})
98102

99103
new Scheduler(

modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/streams/kinesis/source/KinesisSource.scala

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import software.amazon.awssdk.http.async.SdkAsyncHttpClient
1717
import software.amazon.kinesis.lifecycle.events.{ProcessRecordsInput, ShardEndedInput}
1818
import software.amazon.kinesis.retrieval.kpl.ExtendedSequenceNumber
1919

20-
import java.util.concurrent.{CountDownLatch, SynchronousQueue}
20+
import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}
2121
import scala.concurrent.duration.{DurationLong, FiniteDuration}
2222
import scala.jdk.CollectionConverters._
2323

@@ -46,34 +46,37 @@ private[kinesis] object KinesisSource {
4646
}
4747
}
4848

49-
// We enable fairness on the `SynchronousQueue` to ensure all Kinesis shards are sourced at an equal rate.
50-
private val synchronousQueueFairness: Boolean = true
51-
5249
private def kinesisStream[F[_]: Async](
5350
config: KinesisSourceConfig,
5451
client: SdkAsyncHttpClient,
5552
awsUserAgent: Option[String]
5653
): Stream[F, Stream[F, Option[LowLevelEvents[Map[String, Checkpointable]]]]] = {
57-
val actionQueue = new SynchronousQueue[KCLAction](synchronousQueueFairness)
54+
val actionQueue = new LinkedBlockingQueue[KCLAction]()
5855
for {
5956
_ <- Stream.resource(KCLScheduler.populateQueue[F](config, actionQueue, client, awsUserAgent))
6057
events <- Stream.emit(pullFromQueueAndEmit(actionQueue).stream).repeat
6158
} yield events
6259
}
6360

6461
private def pullFromQueueAndEmit[F[_]: Sync](
65-
queue: SynchronousQueue[KCLAction]
62+
queue: LinkedBlockingQueue[KCLAction]
6663
): Pull[F, Option[LowLevelEvents[Map[String, Checkpointable]]], Unit] =
67-
Pull.eval(pullFromQueue(queue)).flatMap { case PullFromQueueResult(actions, hasShardEnd) =>
64+
Pull.eval(pullFromQueue(queue)).flatMap { actions =>
6865
val toEmit = actions.traverse {
69-
case KCLAction.ProcessRecords(_, processRecordsInput) if processRecordsInput.records.asScala.isEmpty =>
70-
Pull.output1(None)
71-
case KCLAction.ProcessRecords(shardId, processRecordsInput) =>
72-
Pull.output1(Some(provideNextChunk(shardId, processRecordsInput))).covary[F]
66+
case KCLAction.ProcessRecords(_, await, processRecordsInput) if processRecordsInput.records.asScala.isEmpty =>
67+
Pull.eval(Sync[F].delay(await.countDown())) >> Pull.output1(None)
68+
case KCLAction.ProcessRecords(shardId, await, processRecordsInput) =>
69+
Pull.eval(Sync[F].delay(await.countDown())) >> Pull.output1(Some(provideNextChunk(shardId, processRecordsInput))).covary[F]
7370
case KCLAction.ShardEnd(shardId, await, shardEndedInput) =>
71+
// Do not call `await.countDown()` yet. It must be released later by the checkpointer.
7472
handleShardEnd[F](shardId, await, shardEndedInput)
75-
case KCLAction.KCLError(t) =>
76-
Pull.eval(Logger[F].error(t)("Exception from Kinesis source")) *> Pull.raiseError[F](t)
73+
case KCLAction.KCLError(t, await) =>
74+
Pull.eval(Sync[F].delay(await.countDown())) >> Pull.eval(Logger[F].error(t)("Exception from Kinesis source")) >> Pull
75+
.raiseError[F](t)
76+
}
77+
val hasShardEnd = actions.exists {
78+
case _: KCLAction.ShardEnd => true
79+
case _: KCLAction => false
7780
}
7881
if (hasShardEnd) {
7982
val log = Logger[F].info {
@@ -88,31 +91,21 @@ private[kinesis] object KinesisSource {
8891
toEmit *> pullFromQueueAndEmit(queue)
8992
}
9093

91-
private case class PullFromQueueResult(actions: NonEmptyList[KCLAction], hasShardEnd: Boolean)
92-
93-
private def pullFromQueue[F[_]: Sync](queue: SynchronousQueue[KCLAction]): F[PullFromQueueResult] =
94-
resolveNextAction(queue)
95-
.flatMap {
96-
case shardEnd: KCLAction.ShardEnd =>
97-
// If we reached the end of one shard, it is likely we reached the end of other shards too.
98-
// Therefore pull more actions from the queue, to minimize the number of times we need to do
99-
// an early close of the inner stream.
100-
resolveAllActions(queue).map { more =>
101-
PullFromQueueResult(NonEmptyList(shardEnd, more), hasShardEnd = true)
102-
}
103-
case other =>
104-
PullFromQueueResult(NonEmptyList.one(other), hasShardEnd = false).pure[F]
105-
}
94+
private def pullFromQueue[F[_]: Sync](queue: LinkedBlockingQueue[KCLAction]): F[NonEmptyList[KCLAction]] =
95+
for {
96+
head <- resolveNextAction(queue)
97+
tail <- resolveAllActions(queue)
98+
} yield NonEmptyList(head, tail)
10699

107100
/** Always returns a `KCLAction`, possibly waiting until one is available */
108-
private def resolveNextAction[F[_]: Sync](queue: SynchronousQueue[KCLAction]): F[KCLAction] =
101+
private def resolveNextAction[F[_]: Sync](queue: LinkedBlockingQueue[KCLAction]): F[KCLAction] =
109102
Sync[F].delay(Option[KCLAction](queue.poll)).flatMap {
110103
case Some(action) => Sync[F].pure(action)
111104
case None => Sync[F].interruptible(queue.take)
112105
}
113106

114107
/** Returns immediately, but the `List[KCLAction]` might be empty */
115-
private def resolveAllActions[F[_]: Sync](queue: SynchronousQueue[KCLAction]): F[List[KCLAction]] =
108+
private def resolveAllActions[F[_]: Sync](queue: LinkedBlockingQueue[KCLAction]): F[List[KCLAction]] =
116109
for {
117110
ret <- Sync[F].delay(new java.util.ArrayList[KCLAction]())
118111
_ <- Sync[F].delay(queue.drainTo(ret))

modules/kinesis/src/main/scala/com/snowplowanalytics/snowplow/streams/kinesis/source/ShardRecordProcessor.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ import software.amazon.kinesis.lifecycle.events.{
1616
}
1717
import software.amazon.kinesis.processor.{ShardRecordProcessor => KCLShardProcessor}
1818

19-
import java.util.concurrent.{CountDownLatch, SynchronousQueue}
19+
import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue}
2020
import java.util.concurrent.atomic.AtomicReference
2121

2222
private[source] object ShardRecordProcessor {
2323

2424
def apply(
25-
queue: SynchronousQueue[KCLAction],
25+
queue: LinkedBlockingQueue[KCLAction],
2626
currentShardIds: AtomicReference[Set[String]]
2727
): KCLShardProcessor = new KCLShardProcessor {
2828
private var shardId: String = _
@@ -36,9 +36,15 @@ private[source] object ShardRecordProcessor {
3636
// 2. KCL re-acquires the lost lease for the same shard
3737
// 3. The original ShardRecordProcessor is not terminated until after KCL re-acquires the lease
3838
// This is a very unhealthy state, so we should kill the app.
39-
val action = KCLAction.KCLError(new RuntimeException(s"Refusing to initialize a duplicate record processor for shard $shardId"))
39+
val countDownLatch = new CountDownLatch(1)
40+
val action = KCLAction.KCLError(
41+
new RuntimeException(s"Refusing to initialize a duplicate record processor for shard $shardId"),
42+
countDownLatch
43+
)
4044
withHandledInterrupts {
4145
queue.put(action)
46+
countDownLatch.await()
47+
()
4248
}
4349
}
4450
}
@@ -54,9 +60,12 @@ private[source] object ShardRecordProcessor {
5460
}
5561

5662
override def processRecords(processRecordsInput: ProcessRecordsInput): Unit = {
57-
val action = KCLAction.ProcessRecords(shardId, processRecordsInput)
63+
val countDownLatch = new CountDownLatch(1)
64+
val action = KCLAction.ProcessRecords(shardId, countDownLatch, processRecordsInput)
5865
withHandledInterrupts {
5966
queue.put(action)
67+
countDownLatch.await()
68+
()
6069
}
6170
}
6271

0 commit comments

Comments
 (0)