Skip to content

Commit 4596d53

Browse files
authored
Support for limiting streams (Pekko, Akka, fs2, zio-streams) (#350)
1 parent d5c50d3 commit 4596d53

File tree

10 files changed

+336
-9
lines changed

10 files changed

+336
-9
lines changed
Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,23 @@
11
package sttp.capabilities.akka
22

3-
import akka.stream.scaladsl.{Flow, Source}
3+
import akka.stream.scaladsl.Flow
4+
import akka.stream.scaladsl.Source
45
import akka.util.ByteString
6+
import sttp.capabilities.StreamMaxLengthExceededException
57
import sttp.capabilities.Streams
68

79
trait AkkaStreams extends Streams[AkkaStreams] {
810
override type BinaryStream = Source[ByteString, Any]
911
override type Pipe[A, B] = Flow[A, B, Any]
1012
}
11-
object AkkaStreams extends AkkaStreams
13+
14+
object AkkaStreams extends AkkaStreams {
15+
16+
def limitBytes(stream: Source[ByteString, Any], maxBytes: Long): Source[ByteString, Any] = {
17+
stream
18+
.limitWeighted(maxBytes)(_.length.toLong)
19+
.mapError { case _: akka.stream.StreamLimitReachedException =>
20+
StreamMaxLengthExceededException(maxBytes)
21+
}
22+
}
23+
}
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package sttp.capabilities.akka
2+
3+
import akka.actor.ActorSystem
4+
import akka.stream.scaladsl._
5+
import akka.stream.testkit.scaladsl.TestSink
6+
import akka.util.ByteString
7+
import org.scalatest.BeforeAndAfterAll
8+
import org.scalatest.flatspec.AnyFlatSpec
9+
import org.scalatest.matchers.should.Matchers
10+
import sttp.capabilities.StreamMaxLengthExceededException
11+
12+
import scala.concurrent.Await
13+
import scala.concurrent.duration._
14+
15+
class AkkaStreamsTest extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
16+
17+
behavior of "AkkaStreams"
18+
implicit lazy val system: ActorSystem = ActorSystem()
19+
20+
override def afterAll(): Unit = {
21+
val _ = Await.result(system.terminate(), 10.seconds)
22+
}
23+
24+
it should "Pass all bytes if limit is not exceeded" in {
25+
// given
26+
val inputByteCount = 8192
27+
val maxBytes = 8192L
28+
29+
val iterator = Iterator.fill[Byte](inputByteCount)('5'.toByte)
30+
val chunkSize = 256
31+
32+
val inputStream: Source[ByteString, Any] =
33+
Source.fromIterator(() => iterator.grouped(chunkSize).map(group => ByteString(group.toArray)))
34+
35+
// when
36+
val stream = AkkaStreams.limitBytes(inputStream, maxBytes)
37+
38+
// then
39+
stream
40+
.fold(0L)((acc, bs) => acc + bs.length)
41+
.runWith(TestSink[Long]())
42+
.request(1)
43+
.expectNext(inputByteCount.toLong)
44+
.expectComplete()
45+
}
46+
47+
it should "Fail stream if limit is exceeded" in {
48+
// given
49+
val inputByteCount = 8192
50+
val maxBytes = 8191L
51+
52+
val iterator = Iterator.fill[Byte](inputByteCount)('5'.toByte)
53+
val chunkSize = 256
54+
55+
val inputStream: Source[ByteString, Any] =
56+
Source.fromIterator(() => iterator.grouped(chunkSize).map(group => ByteString(group.toArray)))
57+
58+
// when
59+
val stream = AkkaStreams.limitBytes(inputStream, maxBytes)
60+
val probe = stream.runWith(TestSink[ByteString]())
61+
val _ = for (_ <- 1 to 31) yield probe.requestNext()
62+
63+
// then
64+
probe.request(1).expectError(StreamMaxLengthExceededException(maxBytes))
65+
}
66+
}

build.sbt

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ val scala2_13 = "2.13.12"
99
val scala2 = List(scala2_11, scala2_12, scala2_13)
1010
val scala2alive = List(scala2_12, scala2_13)
1111
val scala3 = List("3.3.1")
12-
12+
val akkaVersion = "2.6.20"
13+
val pekkoVersion = "1.0.1"
1314
val sttpModelVersion = "1.6.0"
1415

1516
val scalaTestVersion = "3.2.17"
@@ -39,6 +40,7 @@ val commonSettings = commonSmlBuildSettings ++ ossPublishSettings ++ Seq(
3940
val commonJvmSettings = commonSettings ++ Seq(
4041
scalacOptions ++= Seq("-target:jvm-1.8"),
4142
ideSkipProject := (scalaVersion.value != scala2_13),
43+
bspEnabled := !ideSkipProject.value,
4244
mimaPreviousArtifacts := previousStableVersion.value.map(organization.value %% moduleName.value % _).toSet,
4345
mimaReportBinaryIssues := { if ((publish / skip).value) {} else mimaReportBinaryIssues.value }
4446
)
@@ -141,7 +143,10 @@ lazy val akka = (projectMatrix in file("akka"))
141143
.jvmPlatform(
142144
scalaVersions = scala2alive ++ scala3,
143145
settings = commonJvmSettings ++ Seq(
144-
libraryDependencies += "com.typesafe.akka" %% "akka-stream" % "2.6.20" % "provided"
146+
libraryDependencies ++= Seq(
147+
"com.typesafe.akka" %% "akka-stream" % akkaVersion % "provided",
148+
"com.typesafe.akka" %% "akka-stream-testkit" % akkaVersion % Test
149+
)
145150
)
146151
)
147152
.dependsOn(core)
@@ -153,7 +158,10 @@ lazy val pekko = (projectMatrix in file("pekko"))
153158
.jvmPlatform(
154159
scalaVersions = scala2alive ++ scala3,
155160
settings = commonJvmSettings ++ Seq(
156-
libraryDependencies += "org.apache.pekko" %% "pekko-stream" % "1.0.1" % "provided"
161+
libraryDependencies ++= Seq(
162+
"org.apache.pekko" %% "pekko-stream" % pekkoVersion % "provided",
163+
"org.apache.pekko" %% "pekko-stream-testkit" % pekkoVersion % Test
164+
)
157165
)
158166
)
159167
.dependsOn(core)
@@ -196,7 +204,10 @@ lazy val fs2 = (projectMatrix in file("fs2"))
196204
.jvmPlatform(
197205
scalaVersions = scala2alive ++ scala3,
198206
settings = commonJvmSettings ++ Seq(
199-
libraryDependencies += "co.fs2" %% "fs2-io" % fs2_3_version
207+
libraryDependencies ++= Seq(
208+
"co.fs2" %% "fs2-io" % fs2_3_version,
209+
"org.scalatest" %% "scalatest" % scalaTestVersion % Test
210+
)
200211
)
201212
)
202213
.jsPlatform(
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package sttp.capabilities
2+
3+
case class StreamMaxLengthExceededException(maxBytes: Long) extends Exception {
4+
override def getMessage(): String = s"Stream length limit of $maxBytes bytes exceeded"
5+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,33 @@
11
package sttp.capabilities.fs2
22

3+
import cats.MonadThrow
4+
import fs2.Pull
35
import fs2.Stream
6+
import sttp.capabilities.StreamMaxLengthExceededException
47
import sttp.capabilities.Streams
58

69
trait Fs2Streams[F[_]] extends Streams[Fs2Streams[F]] {
710
override type BinaryStream = Stream[F, Byte]
811
override type Pipe[A, B] = fs2.Pipe[F, A, B]
912
}
13+
1014
object Fs2Streams {
1115
def apply[F[_]]: Fs2Streams[F] = new Fs2Streams[F] {}
16+
17+
def limitBytes[F[_]](stream: Stream[F, Byte], maxBytes: Long)(implicit mErr: MonadThrow[F]): Stream[F, Byte] = {
18+
def go(s: Stream[F, Byte], remaining: Long): Pull[F, Byte, Unit] = {
19+
if (remaining < 0) Pull.raiseError(new StreamMaxLengthExceededException(maxBytes))
20+
else
21+
s.pull.uncons.flatMap {
22+
case Some((chunk, tail)) =>
23+
val chunkSize = chunk.size.toLong
24+
if (chunkSize <= remaining)
25+
Pull.output(chunk) >> go(tail, remaining - chunkSize)
26+
else
27+
Pull.raiseError(new StreamMaxLengthExceededException(maxBytes))
28+
case None => Pull.done
29+
}
30+
}
31+
go(stream, maxBytes).stream
32+
}
1233
}
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package sttp.capabilities.fs2
2+
3+
import cats.effect.IO
4+
import cats.effect.unsafe.implicits.global
5+
import fs2._
6+
import org.scalatest.flatspec.AsyncFlatSpec
7+
import org.scalatest.matchers.should.Matchers
8+
import sttp.capabilities.StreamMaxLengthExceededException
9+
10+
class Fs2StreamsTest extends AsyncFlatSpec with Matchers {
11+
behavior of "Fs2Streams"
12+
13+
it should "Pass all bytes if limit is not exceeded" in {
14+
// given
15+
val inputByteCount = 8192
16+
val maxBytes = 8192L
17+
val inputStream = Stream.fromIterator[IO](Iterator.fill[Byte](inputByteCount)('5'.toByte), chunkSize = 1024)
18+
19+
// when
20+
val stream = Fs2Streams.limitBytes(inputStream, maxBytes)
21+
22+
// then
23+
stream.fold(0L)((acc, _) => acc + 1).compile.lastOrError.unsafeToFuture().map { count =>
24+
count shouldBe inputByteCount
25+
}
26+
}
27+
28+
it should "Fail stream if limit is exceeded" in {
29+
// given
30+
val inputByteCount = 8192
31+
val maxBytes = 8191L
32+
val inputStream = Stream.fromIterator[IO](Iterator.fill[Byte](inputByteCount)('5'.toByte), chunkSize = 1024)
33+
34+
// when
35+
val stream = Fs2Streams.limitBytes(inputStream, maxBytes)
36+
37+
// then
38+
stream.compile.drain
39+
.map(_ => fail("Unexpected end of stream."))
40+
.handleErrorWith {
41+
case StreamMaxLengthExceededException(limit) =>
42+
IO(limit shouldBe maxBytes)
43+
case other =>
44+
IO(fail(s"Unexpected failure cause: $other"))
45+
}
46+
.unsafeToFuture()
47+
}
48+
}
Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,22 @@
11
package sttp.capabilities.pekko
22

33
import org.apache.pekko
4+
import sttp.capabilities.StreamMaxLengthExceededException
5+
import sttp.capabilities.Streams
6+
47
import pekko.stream.scaladsl.{Flow, Source}
58
import pekko.util.ByteString
6-
import sttp.capabilities.Streams
79

810
trait PekkoStreams extends Streams[PekkoStreams] {
911
override type BinaryStream = Source[ByteString, Any]
1012
override type Pipe[A, B] = Flow[A, B, Any]
1113
}
12-
object PekkoStreams extends PekkoStreams
14+
object PekkoStreams extends PekkoStreams {
15+
def limitBytes(stream: Source[ByteString, Any], maxBytes: Long): Source[ByteString, Any] = {
16+
stream
17+
.limitWeighted(maxBytes)(_.length.toLong)
18+
.mapError {
19+
case _: pekko.stream.StreamLimitReachedException => StreamMaxLengthExceededException(maxBytes)
20+
}
21+
}
22+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
package sttp.capabilities.pekko
2+
3+
import org.apache.pekko.actor.ActorSystem
4+
import org.apache.pekko.stream.scaladsl._
5+
import org.apache.pekko.stream.testkit.scaladsl.TestSink
6+
import org.apache.pekko.util.ByteString
7+
import org.scalatest.BeforeAndAfterAll
8+
import org.scalatest.flatspec.AnyFlatSpec
9+
import org.scalatest.matchers.should.Matchers
10+
import sttp.capabilities.StreamMaxLengthExceededException
11+
12+
import scala.concurrent.Await
13+
import scala.concurrent.duration._
14+
15+
class PekkoStreamsTest extends AnyFlatSpec with Matchers with BeforeAndAfterAll {
16+
17+
behavior of "PekkoStreams"
18+
implicit lazy val system: ActorSystem = ActorSystem()
19+
20+
override def afterAll(): Unit = {
21+
val _ = Await.result(system.terminate(), 10.seconds)
22+
}
23+
24+
it should "Pass all bytes if limit is not exceeded" in {
25+
// given
26+
val inputByteCount = 8192
27+
val maxBytes = 8192L
28+
29+
val iterator = Iterator.fill[Byte](inputByteCount)('5'.toByte)
30+
val chunkSize = 256
31+
32+
val inputStream: Source[ByteString, Any] =
33+
Source.fromIterator(() => iterator.grouped(chunkSize).map(group => ByteString(group.toArray)))
34+
35+
// when
36+
val stream = PekkoStreams.limitBytes(inputStream, maxBytes)
37+
38+
// then
39+
stream.fold(0L)((acc, bs) => acc + bs.length).runWith(TestSink[Long]()).request(1).expectNext(inputByteCount.toLong).expectComplete()
40+
}
41+
42+
it should "Fail stream if limit is exceeded" in {
43+
// given
44+
val inputByteCount = 8192
45+
val maxBytes = 8191L
46+
47+
val iterator = Iterator.fill[Byte](inputByteCount)('5'.toByte)
48+
val chunkSize = 256
49+
50+
val inputStream: Source[ByteString, Any] =
51+
Source.fromIterator(() => iterator.grouped(chunkSize).map(group => ByteString(group.toArray)))
52+
53+
// when
54+
val stream = PekkoStreams.limitBytes(inputStream, maxBytes)
55+
val probe = stream.runWith(TestSink[ByteString]())
56+
val _ = for (_ <- 1 to 31) yield probe.requestNext()
57+
58+
// then
59+
probe.request(1).expectError(StreamMaxLengthExceededException(maxBytes))
60+
}
61+
}
Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,49 @@
11
package sttp.capabilities.zio
22

3+
import sttp.capabilities.StreamMaxLengthExceededException
34
import sttp.capabilities.Streams
5+
import zio.Chunk
6+
import zio.Trace
47
import zio.stream.Stream
8+
import zio.stream.ZChannel
9+
import zio.stream.ZStream
10+
11+
import scala.util.control.NonFatal
512

613
trait ZioStreams extends Streams[ZioStreams] {
714
override type BinaryStream = Stream[Throwable, Byte]
815
override type Pipe[A, B] = Stream[Throwable, A] => Stream[Throwable, B]
916
}
10-
object ZioStreams extends ZioStreams
17+
18+
object ZioStreams extends ZioStreams {
19+
20+
def limitBytes(stream: Stream[Throwable, Byte], maxBytes: Long): Stream[Throwable, Byte] =
21+
scanChunksAccum(stream, initState = 0L) { case (accumulatedBytes, chunk) =>
22+
val byteCount = accumulatedBytes + chunk.size
23+
if (byteCount > maxBytes)
24+
throw new StreamMaxLengthExceededException(maxBytes)
25+
else
26+
byteCount
27+
}
28+
29+
private def scanChunksAccum[S, R, A](inputStream: ZStream[R, Throwable, A], initState: => S)(
30+
f: (S, Chunk[A]) => S
31+
)(implicit trace: Trace): ZStream[R, Throwable, A] =
32+
ZStream.succeed(initState).flatMap { state =>
33+
def accumulator(currS: S): ZChannel[Any, Throwable, Chunk[A], Any, Throwable, Chunk[A], Unit] =
34+
ZChannel.readWith(
35+
(in: Chunk[A]) => {
36+
try {
37+
val nextS = f(currS, in)
38+
ZChannel.write(in) *> accumulator(nextS)
39+
} catch {
40+
case NonFatal(err) => ZChannel.fail(err)
41+
}
42+
},
43+
(err: Throwable) => ZChannel.fail(err),
44+
(_: Any) => ZChannel.unit
45+
)
46+
47+
ZStream.fromChannel(inputStream.channel >>> accumulator(state))
48+
}
49+
}

0 commit comments

Comments
 (0)