Skip to content

Commit aa9674f

Browse files
authored
Merge pull request #488 from softwaremill/ensure2
Add ensure2, which runs the finalizer even when constructing the effect throws an exception. Deprecate ensure.
2 parents 1a86c98 + 4d73521 commit aa9674f

File tree

4 files changed

+114
-33
lines changed

4 files changed

+114
-33
lines changed

core/src/main/scala/sttp/monad/MonadError.scala

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,17 @@ trait MonadError[F[_]] {
4242
case Failure(e) => error(e)
4343
}
4444

45+
/** Deprecated method which doesn't work properly when constructing the `f` effect itself throws exceptions - the
46+
* finalizer `e` is not run in that case. Use `ensure2` instead, which uses a lazy-evaluated by-name parameter.
47+
*/
48+
@deprecated(message = "Use ensure2 for proper exception handling", since = "1.5.0")
4549
def ensure[T](f: F[T], e: => F[Unit]): F[T]
4650

51+
/** Runs `f`, and ensures that `e` is always run afterwards, regardless of the outcome. `e` is run even when `f`
52+
* throws exceptions during construction of the effect.
53+
*/
54+
def ensure2[T](f: => F[T], e: => F[Unit]): F[T] = ensure(f, e)
55+
4756
def blocking[T](t: => T): F[T] = eval(t)
4857
}
4958

@@ -62,7 +71,7 @@ object syntax {
6271
def map[B](f: A => B)(implicit ME: MonadError[F]): F[B] = ME.map(r)(f)
6372
def flatMap[B](f: A => F[B])(implicit ME: MonadError[F]): F[B] = ME.flatMap(r)(f)
6473
def handleError[T](h: PartialFunction[Throwable, F[A]])(implicit ME: MonadError[F]): F[A] = ME.handleError(r)(h)
65-
def ensure(e: => F[Unit])(implicit ME: MonadError[F]): F[A] = ME.ensure(r, e)
74+
def ensure(e: => F[Unit])(implicit ME: MonadError[F]): F[A] = ME.ensure2(r, e)
6675
def flatTap[B](f: A => F[B])(implicit ME: MonadError[F]): F[A] = ME.flatTap(r)(f)
6776
}
6877

@@ -98,13 +107,24 @@ object EitherMonad extends MonadError[Either[Throwable, *]] {
98107
case _ => rt
99108
}
100109

101-
override def ensure[T](f: Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = {
110+
override def ensure[T](f: Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = ensure2(f, e)
111+
112+
override def ensure2[T](f: => Either[Throwable, T], e: => Either[Throwable, Unit]): Either[Throwable, T] = {
102113
def runE =
103114
Try(e) match {
104115
case Failure(f) => Left(f)
105116
case Success(v) => v
106117
}
107-
f match {
118+
119+
val ef =
120+
try f
121+
catch {
122+
case t: Throwable =>
123+
runE
124+
throw t
125+
}
126+
127+
ef match {
108128
case Left(f) => runE.right.flatMap(_ => Left(f))
109129
case Right(v) => runE.right.map(_ => v)
110130
}
@@ -125,11 +145,22 @@ object TryMonad extends MonadError[Try] {
125145

126146
override def fromTry[T](t: Try[T]): Try[T] = t
127147

128-
override def ensure[T](f: Try[T], e: => Try[Unit]): Try[T] =
129-
f match {
148+
override def ensure[T](f: Try[T], e: => Try[Unit]): Try[T] = ensure2(f, e)
149+
150+
override def ensure2[T](f: => Try[T], e: => Try[Unit]): Try[T] = {
151+
val ef =
152+
try f
153+
catch {
154+
case t: Throwable =>
155+
e
156+
throw t
157+
}
158+
159+
ef match {
130160
case Success(v) => Try(e).flatten.map(_ => v)
131161
case Failure(f) => Try(e).flatten.flatMap(_ => Failure(f))
132162
}
163+
}
133164
}
134165
class FutureMonad(implicit ec: ExecutionContext) extends MonadAsyncError[Future] {
135166
override def unit[T](t: T): Future[T] = Future.successful(t)
@@ -155,16 +186,23 @@ class FutureMonad(implicit ec: ExecutionContext) extends MonadAsyncError[Future]
155186
p.future
156187
}
157188

158-
override def ensure[T](f: Future[T], e: => Future[Unit]): Future[T] = {
189+
override def ensure[T](f: Future[T], e: => Future[Unit]): Future[T] = ensure2(f, e)
190+
191+
override def ensure2[T](f: => Future[T], e: => Future[Unit]): Future[T] = {
159192
val p = Promise[T]()
160193
def runE =
161194
Try(e) match {
162195
case Failure(f) => Future.failed(f)
163196
case Success(v) => v
164197
}
165-
f.onComplete {
166-
case Success(v) => runE.map(_ => v).onComplete(p.complete(_))
167-
case Failure(f) => runE.flatMap(_ => Future.failed(f)).onComplete(p.complete(_))
198+
try {
199+
f.onComplete {
200+
case Success(v) => runE.map(_ => v).onComplete(p.complete(_))
201+
case Failure(f) => runE.flatMap(_ => Future.failed(f)).onComplete(p.complete(_))
202+
}
203+
} catch {
204+
case t: Throwable =>
205+
e.onComplete(_ => p.complete(Failure(t)))
168206
}
169207
p.future
170208
}
@@ -181,7 +219,8 @@ object IdentityMonad extends MonadError[Identity] {
181219
h: PartialFunction[Throwable, Identity[T]]
182220
): Identity[T] = rt
183221
override def eval[T](t: => T): Identity[T] = t
184-
override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] =
222+
override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] = ensure2(f, e)
223+
override def ensure2[T](f: => Identity[T], e: => Identity[Unit]): Identity[T] =
185224
try f
186225
finally e
187226
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
package sttp.monad
2+
3+
import org.scalatest.concurrent.ScalaFutures.convertScalaFuture
4+
import org.scalatest.flatspec.AnyFlatSpec
5+
import org.scalatest.matchers.should.Matchers
6+
7+
import java.util.concurrent.atomic.AtomicBoolean
8+
import scala.concurrent.ExecutionContext.Implicits.global
9+
import scala.concurrent.Future
10+
11+
class FutureMonadTest extends AnyFlatSpec with Matchers {
12+
implicit val m: MonadError[Future] = new FutureMonad()
13+
14+
it should "ensure" in {
15+
val ran = new AtomicBoolean(false)
16+
17+
intercept[RuntimeException] {
18+
m.ensure2((throw new RuntimeException("boom!")): Future[Int], Future(ran.set(true))).futureValue
19+
}
20+
21+
ran.get shouldBe true
22+
}
23+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
package sttp.monad
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
6+
import sttp.monad.syntax._
7+
import sttp.shared.Identity
8+
9+
class IdentityMonadTest extends AnyFlatSpec with Matchers {
10+
implicit val m: MonadError[Identity] = IdentityMonad
11+
12+
it should "map" in {
13+
m.map(10)(_ + 2) shouldBe 12
14+
}
15+
16+
it should "ensure" in {
17+
var ran = false
18+
19+
intercept[RuntimeException] {
20+
m.ensure2(m.error(new RuntimeException("boom!")), { ran = true })
21+
}
22+
23+
ran shouldBe true
24+
}
25+
26+
it should "ensure using syntax" in {
27+
var ran = false
28+
29+
intercept[RuntimeException] {
30+
m.error[Int](new RuntimeException("boom!")).ensure({ ran = true })
31+
}
32+
33+
ran shouldBe true
34+
}
35+
}

ws/src/test/scala/sttp/ws/testing/WebSocketStubTests.scala

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,15 @@ import sttp.monad.MonadError
77
import sttp.ws.{WebSocketClosed, WebSocketFrame}
88

99
import scala.util.{Failure, Success}
10+
import sttp.monad.IdentityMonad
1011

1112
class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
12-
type Identity[X] = X
13-
object IdMonad extends MonadError[Identity] {
14-
override def unit[T](t: T): Identity[T] = t
15-
override def map[T, T2](fa: Identity[T])(f: T => T2): Identity[T2] = f(fa)
16-
override def flatMap[T, T2](fa: Identity[T])(f: T => Identity[T2]): Identity[T2] = f(fa)
17-
18-
override def error[T](t: Throwable): Identity[T] = throw t
19-
override protected def handleWrappedError[T](rt: Identity[T])(
20-
h: PartialFunction[Throwable, Identity[T]]
21-
): Identity[T] = rt
22-
23-
override def eval[T](t: => T): Identity[T] = t
24-
override def ensure[T](f: Identity[T], e: => Identity[Unit]): Identity[T] =
25-
try f
26-
finally e
27-
}
28-
2913
class MyException extends Exception
3014

3115
"web socket stub" should "return initial Incoming frames on 'receive'" in {
3216
val frames = List("a", "b", "c").map(WebSocketFrame.text)
3317
val webSocketStub = WebSocketStub.initialReceive(frames)
34-
val ws = webSocketStub.build(IdMonad)
18+
val ws = webSocketStub.build(IdentityMonad)
3519

3620
ws.receive() shouldBe WebSocketFrame.text("a")
3721
ws.receive() shouldBe WebSocketFrame.text("b")
@@ -42,7 +26,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
4226
val okFrame = WebSocketFrame.text("abc")
4327
val exception = new MyException
4428
val webSocketStub = WebSocketStub.initialReceiveWith(List(Success(okFrame), Failure(exception)))
45-
val ws = webSocketStub.build(IdMonad)
29+
val ws = webSocketStub.build(IdentityMonad)
4630

4731
ws.receive() shouldBe WebSocketFrame.text("abc")
4832
assertThrows[MyException](ws.receive())
@@ -59,7 +43,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
5943
case `expectedFrame` => List(secondFrame, thirdFrame)
6044
case _ => List.empty
6145
}
62-
val ws = webSocketStub.build(IdMonad)
46+
val ws = webSocketStub.build(IdentityMonad)
6347

6448
ws.receive() shouldBe WebSocketFrame.text("No. 1")
6549
assertThrows[IllegalStateException](ws.receive()) // no more stubbed messages
@@ -76,7 +60,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
7660

7761
val webSocketStub = WebSocketStub.noInitialReceive
7862
.thenRespondWith(_ => List(Success(ok), Failure(exception)))
79-
val ws = webSocketStub.build(IdMonad)
63+
val ws = webSocketStub.build(IdentityMonad)
8064

8165
ws.send(WebSocketFrame.text("let's add responses"))
8266
ws.receive() shouldBe WebSocketFrame.text("ok")
@@ -87,7 +71,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
8771
val ok = WebSocketFrame.text("ok")
8872
val closeFrame = WebSocketFrame.Close(500, "internal error")
8973
val webSocketStub = WebSocketStub.initialReceive(List(closeFrame, ok))
90-
val ws = webSocketStub.build(IdMonad)
74+
val ws = webSocketStub.build(IdentityMonad)
9175

9276
ws.send(WebSocketFrame.text("let's add responses"))
9377
ws.receive() shouldBe closeFrame
@@ -100,7 +84,7 @@ class WebSocketStubTests extends AnyFlatSpec with Matchers with ScalaFutures {
10084
.thenRespondS(0) { case (counter, _) =>
10185
(counter + 1, List(WebSocketFrame.text(s"No. $counter")))
10286
}
103-
val ws = webSocketStub.build(IdMonad)
87+
val ws = webSocketStub.build(IdentityMonad)
10488

10589
ws.send(WebSocketFrame.text("a"))
10690
ws.send(WebSocketFrame.text("b"))

0 commit comments

Comments
 (0)