diff --git a/core/src/main/scala/sttp/monad/MonadError.scala b/core/src/main/scala/sttp/monad/MonadError.scala index d2c5c5e..1cdf94a 100644 --- a/core/src/main/scala/sttp/monad/MonadError.scala +++ b/core/src/main/scala/sttp/monad/MonadError.scala @@ -42,8 +42,17 @@ trait MonadError[F[_]] { case Failure(e) => error(e) } + /** Deprecated method which doesn't work properly when constructing the `f` effect itself throws exceptions - the + * finalizer `e` is not run in that case. Use `ensure2` instead, which uses a lazy-evaluated by-name parameter. + */ + @deprecated(message = "Use ensure2 for proper exception handling", since = "1.5.0") def ensure[T](f: F[T], e: => F[Unit]): F[T] + /** Runs `f`, and ensures that `e` is always run afterwards, regardless of the outcome. `e` is run even when `f` + * throws exceptions during construction of the effect. + */ + def ensure2[T](f: => F[T], e: => F[Unit]): F[T] = ensure(f, e) + def blocking[T](t: => T): F[T] = eval(t) } @@ -62,7 +71,7 @@ object syntax { def map[B](f: A => B)(implicit ME: MonadError[F]): F[B] = ME.map(r)(f) def flatMap[B](f: A => F[B])(implicit ME: MonadError[F]): F[B] = ME.flatMap(r)(f) def handleError[T](h: PartialFunction[Throwable, F[A]])(implicit ME: MonadError[F]): F[A] = ME.handleError(r)(h) - def ensure(e: => F[Unit])(implicit ME: MonadError[F]): F[A] = ME.ensure(r, e) + def ensure(e: => F[Unit])(implicit ME: MonadError[F]): F[A] = ME.ensure2(r, e) def flatTap[B](f: A => F[B])(implicit ME: MonadError[F]): F[A] = ME.flatTap(r)(f) } @@ -98,13 +107,24 @@ object EitherMonad extends MonadError[Either[Throwable, *]] { case _ => rt } - override def ensure[T](f: Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = { + override def ensure[T](f: Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = ensure2(f, e) + + override def ensure2[T](f: => Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = { def runE = Try(e) match { case Failure(f) => Left(f) case Success(v) => v } - f match { + + val ef = + try f + catch { + case t: Throwable => + runE + throw t + } + + ef match { case Left(f) => runE.right.flatMap(_ => Left(f)) case Right(v) => runE.right.map(_ => v) } @@ -125,11 +145,22 @@ object TryMonad extends MonadError[Try] { override def fromTry[T](t: Try[T]): Try[T] = t - override def ensure[T](f: Try[T], e: => Try[Unit]): Try[T] = - f match { + override def ensure[T](f: Try[T], e: => Try[Unit]): Try[T] = ensure2(f, e) + + override def ensure2[T](f: => Try[T], e: => Try[Unit]): Try[T] = { + val ef = + try f + catch { + case t: Throwable => + e + throw t + } + + ef match { case Success(v) => Try(e).flatten.map(_ => v) case Failure(f) => Try(e).flatten.flatMap(_ => Failure(f)) } + } } class FutureMonad(implicit ec: ExecutionContext) extends MonadAsyncError[Future] { override def unit[T](t: T): Future[T] = Future.successful(t) @@ -155,16 +186,23 @@ class FutureMonad(implicit ec: ExecutionContext) extends MonadAsyncError[Future] p.future } - override def ensure[T](f: Future[T], e: => Future[Unit]): Future[T] = { + override def ensure[T](f: Future[T], e: => Future[Unit]): Future[T] = ensure2(f, e) + + override def ensure2[T](f: => Future[T], e: => Future[Unit]): Future[T] = { val p = Promise[T]() def runE = Try(e) match { case Failure(f) => Future.failed(f) case Success(v) => v } - f.onComplete { - case Success(v) => runE.map(_ => v).onComplete(p.complete(_)) - case Failure(f) => runE.flatMap(_ => Future.failed(f)).onComplete(p.complete(_)) + try { + f.onComplete { + case Success(v) => runE.map(_ => v).onComplete(p.complete(_)) + case Failure(f) => runE.flatMap(_ => Future.failed(f)).onComplete(p.complete(_)) + } + } catch { + case t: Throwable => + e.onComplete(_ => p.complete(Failure(t))) } p.future } @@ -181,7 +219,8 @@ object IdentityMonad extends MonadError[Identity] { h: PartialFunction[Throwable, Identity[T]] ): Identity[T] = rt override def eval[T](t: => T): Identity[T] = t - override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] = + override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] = ensure2(f, e) + override def ensure2[T](f: => Identity[T], e: => Identity[Unit]): Identity[T] = try f finally e } diff --git a/core/src/test/scalajvm/sttp/monad/FutureMonadTest.scala b/core/src/test/scalajvm/sttp/monad/FutureMonadTest.scala new file mode 100644 index 0000000..6e58e6a --- /dev/null +++ b/core/src/test/scalajvm/sttp/monad/FutureMonadTest.scala @@ -0,0 +1,23 @@ +package sttp.monad + +import org.scalatest.concurrent.ScalaFutures.convertScalaFuture +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import java.util.concurrent.atomic.AtomicBoolean +import scala.concurrent.ExecutionContext.Implicits.global +import scala.concurrent.Future + +class FutureMonadTest extends AnyFlatSpec with Matchers { + implicit val m: MonadError[Future] = new FutureMonad() + + it should "ensure" in { + val ran = new AtomicBoolean(false) + + intercept[RuntimeException] { + m.ensure2((throw new RuntimeException("boom!")): Future[Int], Future(ran.set(true))).futureValue + } + + ran.get shouldBe true + } +} diff --git a/core/src/test/scalajvm/sttp/monad/IdentityMonadTest.scala b/core/src/test/scalajvm/sttp/monad/IdentityMonadTest.scala new file mode 100644 index 0000000..cb0b785 --- /dev/null +++ b/core/src/test/scalajvm/sttp/monad/IdentityMonadTest.scala @@ -0,0 +1,35 @@ +package sttp.monad + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +import sttp.monad.syntax._ +import sttp.shared.Identity + +class IdentityMonadTest extends AnyFlatSpec with Matchers { + implicit val m: MonadError[Identity] = IdentityMonad + + it should "map" in { + m.map(10)(_ + 2) shouldBe 12 + } + + it should "ensure" in { + var ran = false + + intercept[RuntimeException] { + m.ensure2(m.error(new RuntimeException("boom!")), { ran = true }) + } + + ran shouldBe true + } + + it should "ensure using syntax" in { + var ran = false + + intercept[RuntimeException] { + m.error[Int](new RuntimeException("boom!")).ensure({ ran = true }) + } + + ran shouldBe true + } +} diff --git a/ws/src/test/scala/sttp/ws/testing/WebSocketStubTests.scala b/ws/src/test/scala/sttp/ws/testing/WebSocketStubTests.scala index e34b24b..72f4fb2 100644 --- a/ws/src/test/scala/sttp/ws/testing/WebSocketStubTests.scala +++ b/ws/src/test/scala/sttp/ws/testing/WebSocketStubTests.scala @@ -7,31 +7,15 @@ import sttp.monad.MonadError import sttp.ws.{WebSocketClosed, WebSocketFrame} import scala.util.{Failure, Success} +import sttp.monad.IdentityMonad class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures { - type Identity[X] = X - object IdMonad extends MonadError[Identity] { - override def unit[T](t: T): Identity[T] = t - override def map[T, T2](fa: Identity[T])(f: T => T2): Identity[T2] = f(fa) - override def flatMap[T, T2](fa: Identity[T])(f: T => Identity[T2]): Identity[T2] = f(fa) - - override def error[T](t: Throwable): Identity[T] = throw t - override protected def handleWrappedError[T](rt: Identity[T])( - h: PartialFunction[Throwable, Identity[T]] - ): Identity[T] = rt - - override def eval[T](t: => T): Identity[T] = t - override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] = - try f - finally e - } - class MyException extends Exception "web socket stub" should "return initial Incoming frames on 'receive'" in { val frames = List("a", "b", "c").map(WebSocketFrame.text) val webSocketStub = WebSocketStub.initialReceive(frames) - val ws = webSocketStub.build(IdMonad) + val ws = webSocketStub.build(IdentityMonad) ws.receive() shouldBe WebSocketFrame.text("a") ws.receive() shouldBe WebSocketFrame.text("b") @@ -42,7 +26,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures { val okFrame = WebSocketFrame.text("abc") val exception = new MyException val webSocketStub = WebSocketStub.initialReceiveWith(List(Success(okFrame), Failure(exception))) - val ws = webSocketStub.build(IdMonad) + val ws = webSocketStub.build(IdentityMonad) ws.receive() shouldBe WebSocketFrame.text("abc") assertThrows[MyException](ws.receive()) @@ -59,7 +43,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures { case `expectedFrame` => List(secondFrame, thirdFrame) case _ => List.empty } - val ws = webSocketStub.build(IdMonad) + val ws = webSocketStub.build(IdentityMonad) ws.receive() shouldBe WebSocketFrame.text("No. 1") assertThrows[IllegalStateException](ws.receive()) // no more stubbed messages @@ -76,7 +60,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures { val webSocketStub = WebSocketStub.noInitialReceive .thenRespondWith(_ => List(Success(ok), Failure(exception))) - val ws = webSocketStub.build(IdMonad) + val ws = webSocketStub.build(IdentityMonad) ws.send(WebSocketFrame.text("let's add responses")) ws.receive() shouldBe WebSocketFrame.text("ok") @@ -87,7 +71,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures { val ok = WebSocketFrame.text("ok") val closeFrame = WebSocketFrame.Close(500, "internal error") val webSocketStub = WebSocketStub.initialReceive(List(closeFrame, ok)) - val ws = webSocketStub.build(IdMonad) + val ws = webSocketStub.build(IdentityMonad) ws.send(WebSocketFrame.text("let's add responses")) ws.receive() shouldBe closeFrame @@ -100,7 +84,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures { .thenRespondS(0) { case (counter, _) => (counter + 1, List(WebSocketFrame.text(s"No. $counter"))) } - val ws = webSocketStub.build(IdMonad) + val ws = webSocketStub.build(IdentityMonad) ws.send(WebSocketFrame.text("a")) ws.send(WebSocketFrame.text("b"))