|
1 | 1 | package sttp.capabilities.zio
|
2 | 2 |
|
| 3 | +import org.scalatest.flatspec.AsyncFlatSpec |
| 4 | +import org.scalatest.matchers.should.Matchers |
3 | 5 | import sttp.capabilities.StreamMaxLengthExceededException
|
4 | 6 | import zio._
|
5 | 7 | import zio.stream.ZStream
|
6 |
| -import zio.test._ |
7 |
| - |
8 |
| -object ZioStreamsTest extends ZIOSpecDefault { |
9 |
| - def spec: Spec[TestEnvironment, Any] = suite("ZioStreams")( |
10 |
| - test("should Pass all bytes if limit is not exceeded") { |
11 |
| - // given |
12 |
| - val inputByteCount = 8192 |
13 |
| - val maxBytes = 8192L |
14 |
| - val inputStream = ZStream.fromIterator(Iterator.fill[Byte](inputByteCount)('5'.toByte)) |
15 |
| - |
16 |
| - // when |
17 |
| - val stream = ZioStreams.limitBytes(inputStream, maxBytes) |
18 |
| - |
19 |
| - // then |
20 |
| - for { |
21 |
| - count <- stream.runFold(0L)((acc, _) => acc + 1) |
22 |
| - } yield assertTrue(count == inputByteCount) |
23 |
| - }, |
24 |
| - test("should Fail stream if limit is exceeded") { |
25 |
| - val inputByteCount = 8192 |
26 |
| - val maxBytes = 8191L |
27 |
| - val inputStream = ZStream.fromIterator(Iterator.fill[Byte](inputByteCount)('5'.toByte)) |
28 |
| - |
29 |
| - // when |
30 |
| - val stream = ZioStreams.limitBytes(inputStream, maxBytes) |
31 |
| - |
32 |
| - // then |
33 |
| - for { |
34 |
| - limit <- stream.runLast.flip |
35 |
| - .flatMap { |
| 8 | +import scala.concurrent.ExecutionContext |
| 9 | + |
| 10 | +class ZioStreamsTest extends AsyncFlatSpec with Matchers { |
| 11 | + override implicit val executionContext: ExecutionContext = Runtime.defaultExecutor.asExecutionContext |
| 12 | + |
| 13 | + behavior of "ZioStreams" |
| 14 | + |
| 15 | + implicit val r: Runtime[Any] = Runtime.default |
| 16 | + |
| 17 | + it should "Pass all bytes if limit is not exceeded" in { |
| 18 | + // given |
| 19 | + val inputByteCount = 8192 |
| 20 | + val maxBytes = 8192L |
| 21 | + val inputStream = ZStream.fromIterator(Iterator.fill[Byte](inputByteCount)('5'.toByte)) |
| 22 | + |
| 23 | + // when |
| 24 | + val stream = ZioStreams.limitBytes(inputStream, maxBytes) |
| 25 | + |
| 26 | + // then |
| 27 | + Unsafe.unsafe(implicit u => |
| 28 | + r.unsafe.runToFuture(stream.runFold(0L)((acc, _) => acc + 1).map { count => |
| 29 | + count shouldBe inputByteCount |
| 30 | + }) |
| 31 | + ) |
| 32 | + } |
| 33 | + |
| 34 | + it should "Fail stream if limit is exceeded" in { |
| 35 | + // given |
| 36 | + val inputByteCount = 8192 |
| 37 | + val maxBytes = 8191L |
| 38 | + val inputStream = ZStream.fromIterator(Iterator.fill[Byte](inputByteCount)('5'.toByte)) |
| 39 | + |
| 40 | + // when |
| 41 | + val stream = ZioStreams.limitBytes(inputStream, maxBytes) |
| 42 | + |
| 43 | + // then |
| 44 | + Unsafe.unsafe(implicit u => |
| 45 | + r.unsafe.runToFuture( |
| 46 | + stream.runLast |
| 47 | + .flatMap(_ => ZIO.succeed(fail("Unexpected end of stream"))) |
| 48 | + .catchSome { |
36 | 49 | case StreamMaxLengthExceededException(limit) =>
|
37 |
| - ZIO.succeed(limit) |
| 50 | + ZIO.succeed(limit shouldBe maxBytes) |
38 | 51 | case other =>
|
39 |
| - ZIO.fail(s"Unexpected failure cause: $other") |
| 52 | + ZIO.succeed(fail(s"Unexpected failure cause: $other")) |
40 | 53 | }
|
41 |
| - } yield assertTrue(limit == maxBytes) |
42 |
| - } |
43 |
| - ) |
| 54 | + ) |
| 55 | + ) |
| 56 | + } |
44 | 57 | }
|
0 commit comments