Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 49 additions & 10 deletions core/src/main/scala/sttp/monad/MonadError.scala
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,17 @@ trait MonadError[F[_]] {
case Failure(e) => error(e)
}

/** Deprecated method which doesn't work properly when constructing the `f` effect itself throws exceptions - the
* finalizer `e` is not run in that case. Use `ensure2` instead, which uses a lazy-evaluated by-name parameter.
*/
@deprecated(message = "Use ensure2 for proper exception handling", since = "1.5.0")
def ensure[T](f: F[T], e: => F[Unit]): F[T]

/** Runs `f`, and ensures that `e` is always run afterwards, regardless of the outcome. `e` is run even when `f`
* throws exceptions during construction of the effect.
*/
def ensure2[T](f: => F[T], e: => F[Unit]): F[T] = ensure(f, e)

def blocking[T](t: => T): F[T] = eval(t)
}

Expand All @@ -62,7 +71,7 @@ object syntax {
def map[B](f: A => B)(implicit ME: MonadError[F]): F[B] = ME.map(r)(f)
def flatMap[B](f: A => F[B])(implicit ME: MonadError[F]): F[B] = ME.flatMap(r)(f)
def handleError[T](h: PartialFunction[Throwable, F[A]])(implicit ME: MonadError[F]): F[A] = ME.handleError(r)(h)
def ensure(e: => F[Unit])(implicit ME: MonadError[F]): F[A] = ME.ensure(r, e)
def ensure(e: => F[Unit])(implicit ME: MonadError[F]): F[A] = ME.ensure2(r, e)
def flatTap[B](f: A => F[B])(implicit ME: MonadError[F]): F[A] = ME.flatTap(r)(f)
}

Expand Down Expand Up @@ -98,13 +107,24 @@ object EitherMonad extends MonadError[Either[Throwable, *]] {
case _ => rt
}

override def ensure[T](f: Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = {
override def ensure[T](f: Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = ensure2(f, e)

override def ensure2[T](f: => Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = {
def runE =
Try(e) match {
case Failure(f) => Left(f)
case Success(v) => v
}
f match {

val ef =
try f
catch {
case t: Throwable =>
runE
throw t
}

ef match {
case Left(f) => runE.right.flatMap(_ => Left(f))
case Right(v) => runE.right.map(_ => v)
}
Expand All @@ -125,11 +145,22 @@ object TryMonad extends MonadError[Try] {

override def fromTry[T](t: Try[T]): Try[T] = t

override def ensure[T](f: Try[T], e: => Try[Unit]): Try[T] =
f match {
override def ensure[T](f: Try[T], e: => Try[Unit]): Try[T] = ensure2(f, e)

override def ensure2[T](f: => Try[T], e: => Try[Unit]): Try[T] = {
val ef =
try f
catch {
case t: Throwable =>
e
throw t
}

ef match {
case Success(v) => Try(e).flatten.map(_ => v)
case Failure(f) => Try(e).flatten.flatMap(_ => Failure(f))
}
}
}
class FutureMonad(implicit ec: ExecutionContext) extends MonadAsyncError[Future] {
override def unit[T](t: T): Future[T] = Future.successful(t)
Expand All @@ -155,16 +186,23 @@ class FutureMonad(implicit ec: ExecutionContext) extends MonadAsyncError[Future]
p.future
}

override def ensure[T](f: Future[T], e: => Future[Unit]): Future[T] = {
override def ensure[T](f: Future[T], e: => Future[Unit]): Future[T] = ensure2(f, e)

override def ensure2[T](f: => Future[T], e: => Future[Unit]): Future[T] = {
val p = Promise[T]()
def runE =
Try(e) match {
case Failure(f) => Future.failed(f)
case Success(v) => v
}
f.onComplete {
case Success(v) => runE.map(_ => v).onComplete(p.complete(_))
case Failure(f) => runE.flatMap(_ => Future.failed(f)).onComplete(p.complete(_))
try {
f.onComplete {
case Success(v) => runE.map(_ => v).onComplete(p.complete(_))
case Failure(f) => runE.flatMap(_ => Future.failed(f)).onComplete(p.complete(_))
}
} catch {
case t: Throwable =>
e.onComplete(_ => p.complete(Failure(t)))
}
p.future
}
Expand All @@ -181,7 +219,8 @@ object IdentityMonad extends MonadError[Identity] {
h: PartialFunction[Throwable, Identity[T]]
): Identity[T] = rt
override def eval[T](t: => T): Identity[T] = t
override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] =
override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] = ensure2(f, e)
override def ensure2[T](f: => Identity[T], e: => Identity[Unit]): Identity[T] =
try f
finally e
}
23 changes: 23 additions & 0 deletions core/src/test/scalajvm/sttp/monad/FutureMonadTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package sttp.monad

import org.scalatest.concurrent.ScalaFutures.convertScalaFuture
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import java.util.concurrent.atomic.AtomicBoolean
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.Future

class FutureMonadTest extends AnyFlatSpec with Matchers {
implicit val m: MonadError[Future] = new FutureMonad()

it should "ensure" in {
val ran = new AtomicBoolean(false)

intercept[RuntimeException] {
m.ensure2((throw new RuntimeException("boom!")): Future[Int], Future(ran.set(true))).futureValue
}

ran.get shouldBe true
}
}
35 changes: 35 additions & 0 deletions core/src/test/scalajvm/sttp/monad/IdentityMonadTest.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
package sttp.monad

import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import sttp.monad.syntax._
import sttp.shared.Identity

class IdentityMonadTest extends AnyFlatSpec with Matchers {
implicit val m: MonadError[Identity] = IdentityMonad

it should "map" in {
m.map(10)(_ + 2) shouldBe 12
}

it should "ensure" in {
var ran = false

intercept[RuntimeException] {
m.ensure2(m.error(new RuntimeException("boom!")), { ran = true })
}

ran shouldBe true
}

it should "ensure using syntax" in {
var ran = false

intercept[RuntimeException] {
m.error[Int](new RuntimeException("boom!")).ensure({ ran = true })
}

ran shouldBe true
}
}
30 changes: 7 additions & 23 deletions ws/src/test/scala/sttp/ws/testing/WebSocketStubTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,15 @@ import sttp.monad.MonadError
import sttp.ws.{WebSocketClosed, WebSocketFrame}

import scala.util.{Failure, Success}
import sttp.monad.IdentityMonad

class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
type Identity[X] = X
object IdMonad extends MonadError[Identity] {
override def unit[T](t: T): Identity[T] = t
override def map[T, T2](fa: Identity[T])(f: T => T2): Identity[T2] = f(fa)
override def flatMap[T, T2](fa: Identity[T])(f: T => Identity[T2]): Identity[T2] = f(fa)

override def error[T](t: Throwable): Identity[T] = throw t
override protected def handleWrappedError[T](rt: Identity[T])(
h: PartialFunction[Throwable, Identity[T]]
): Identity[T] = rt

override def eval[T](t: => T): Identity[T] = t
override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] =
try f
finally e
}

class MyException extends Exception

"web socket stub" should "return initial Incoming frames on 'receive'" in {
val frames = List("a", "b", "c").map(WebSocketFrame.text)
val webSocketStub = WebSocketStub.initialReceive(frames)
val ws = webSocketStub.build(IdMonad)
val ws = webSocketStub.build(IdentityMonad)

ws.receive() shouldBe WebSocketFrame.text("a")
ws.receive() shouldBe WebSocketFrame.text("b")
Expand All @@ -42,7 +26,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
val okFrame = WebSocketFrame.text("abc")
val exception = new MyException
val webSocketStub = WebSocketStub.initialReceiveWith(List(Success(okFrame), Failure(exception)))
val ws = webSocketStub.build(IdMonad)
val ws = webSocketStub.build(IdentityMonad)

ws.receive() shouldBe WebSocketFrame.text("abc")
assertThrows[MyException](ws.receive())
Expand All @@ -59,7 +43,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
case `expectedFrame` => List(secondFrame, thirdFrame)
case _ => List.empty
}
val ws = webSocketStub.build(IdMonad)
val ws = webSocketStub.build(IdentityMonad)

ws.receive() shouldBe WebSocketFrame.text("No. 1")
assertThrows[IllegalStateException](ws.receive()) // no more stubbed messages
Expand All @@ -76,7 +60,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {

val webSocketStub = WebSocketStub.noInitialReceive
.thenRespondWith(_ => List(Success(ok), Failure(exception)))
val ws = webSocketStub.build(IdMonad)
val ws = webSocketStub.build(IdentityMonad)

ws.send(WebSocketFrame.text("let's add responses"))
ws.receive() shouldBe WebSocketFrame.text("ok")
Expand All @@ -87,7 +71,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
val ok = WebSocketFrame.text("ok")
val closeFrame = WebSocketFrame.Close(500, "internal error")
val webSocketStub = WebSocketStub.initialReceive(List(closeFrame, ok))
val ws = webSocketStub.build(IdMonad)
val ws = webSocketStub.build(IdentityMonad)

ws.send(WebSocketFrame.text("let's add responses"))
ws.receive() shouldBe closeFrame
Expand All @@ -100,7 +84,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
.thenRespondS(0) { case (counter, _) =>
(counter + 1, List(WebSocketFrame.text(s"No. $counter")))
}
val ws = webSocketStub.build(IdMonad)
val ws = webSocketStub.build(IdentityMonad)

ws.send(WebSocketFrame.text("a"))
ws.send(WebSocketFrame.text("b"))
Expand Down