Skip to content

Commit 0958fe2

Browse files
authored
Add Flow.retry and Flow.recover (#339)
Closes #44
1 parent 893d05d commit 0958fe2

File tree

4 files changed

+397
-8
lines changed

4 files changed

+397
-8
lines changed

core/src/main/scala/ox/flow/FlowOps.scala

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ import ox.forkCancellable
1818
import ox.forkUnsupervised
1919
import ox.forkUser
2020
import ox.repeatWhile
21+
import ox.resilience.RetryConfig
22+
import ox.scheduling.Schedule
2123
import ox.sleep
2224
import ox.supervised
2325
import ox.tapException
@@ -951,6 +953,60 @@ class FlowOps[+T]:
951953
def onError(f: Throwable => Unit): Flow[T] = Flow.usingEmitInline: emit =>
952954
last.run(emit).tapException(f)
953955

956+
/** Retries the upstream flow execution using the provided retry configuration. If the flow fails with an exception, it will be retried
957+
* according to the schedule defined in the retry config until it succeeds or the retry policy decides to stop.
958+
*
959+
* Each retry attempt will run the complete upstream flow, from start up to this point. The retry behavior is controlled by the
960+
* [[RetryConfig]].
961+
*
962+
* Note that this retries the flow execution itself, not individual elements within the flow. If you need to retry individual operations
963+
* within the flow, consider using retry logic inside methods such as [[map]].
964+
*
965+
* Creates an asynchronous boundary (see [[buffer]]) to isolate failures when running the upstream flow.
966+
*
967+
* @param config
968+
* The retry configuration that specifies the retry schedule and success/failure conditions.
969+
* @return
970+
* A new flow that will retry execution according to the provided configuration.
971+
* @throws anything
972+
* The exception from the last retry attempt if all retries are exhausted.
973+
* @see
974+
* [[ox.resilience.retry]]
975+
*/
976+
def retry(config: RetryConfig[Throwable, Unit])(using BufferCapacity): Flow[T] = Flow.usingEmitInline: emit =>
977+
val ch = BufferCapacity.newChannel[T]
978+
unsupervised:
979+
forkPropagate(ch) {
980+
ox.resilience.retry(config)(last.run(FlowEmit.fromInline(t => ch.send(t))))
981+
ch.done()
982+
}.discard
983+
FlowEmit.channelToEmit(ch, emit)
984+
985+
/** @see
986+
* [[retry(RetryConfig)]]
987+
*/
988+
def retry(schedule: Schedule): Flow[T] = retry(RetryConfig(schedule))
989+
990+
/** Recovers from errors in the upstream flow by emitting a recovery value when the error is handled by the partial function. If the
991+
* partial function is not defined for the error, the original error is propagated.
992+
*
993+
* Creates an asynchronous boundary (see [[buffer]]) to isolate failures when running the upstream flow.
994+
*
995+
* @param pf
996+
* A partial function that handles specific exceptions and returns a recovery value to emit.
997+
* @return
998+
* A flow that emits elements from the upstream flow, and emits a recovery value if the upstream fails with a handled exception.
999+
*/
1000+
def recover[U >: T](pf: PartialFunction[Throwable, U])(using BufferCapacity): Flow[U] = Flow.usingEmitInline: emit =>
1001+
val ch = BufferCapacity.newChannel[U]
1002+
unsupervised:
1003+
forkPropagate(ch) {
1004+
try last.run(FlowEmit.fromInline(t => ch.send(t)))
1005+
catch case e: Throwable if pf.isDefinedAt(e) => ch.send(pf(e))
1006+
ch.done()
1007+
}.discard
1008+
FlowEmit.channelToEmit(ch, emit)
1009+
9541010
//
9551011

9561012
protected def runLastToChannelAsync(ch: Sink[T])(using OxUnsupervised): Unit =

core/src/main/scala/ox/scheduling/scheduled.scala

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,16 +102,15 @@ def scheduledEither[E, T](config: ScheduledConfig[E, T])(operation: => Either[E,
102102
*/
103103
def scheduledWithErrorMode[E, F[_], T](em: ErrorMode[E, F])(config: ScheduledConfig[E, T])(operation: => F[T]): F[T] =
104104
@tailrec
105-
def loop(invocation: Int, intervals: LazyList[FiniteDuration], lastDuration: Option[FiniteDuration]): F[T] =
106-
def sleepIfNeeded(startTimestamp: Long, nextDelay: FiniteDuration) =
105+
def loop(invocation: Int, intervals: LazyList[FiniteDuration]): F[T] =
106+
def sleepIfNeeded(startTimestamp: Long, nextDelay: FiniteDuration): Unit =
107107
val delay = config.sleepMode match
108108
case SleepMode.StartToStart =>
109109
val elapsed = System.nanoTime() - startTimestamp
110110
val remaining = nextDelay.toNanos - elapsed
111111
remaining.nanos
112112
case SleepMode.EndToStart => nextDelay
113113
if delay.toMillis > 0 then sleep(delay)
114-
delay
115114
end sleepIfNeeded
116115

117116
val startTimestamp = System.nanoTime()
@@ -123,22 +122,22 @@ def scheduledWithErrorMode[E, F[_], T](em: ErrorMode[E, F])(config: ScheduledCon
123122

124123
nextDelay match
125124
case Some(nd) if !shouldStop.stop =>
126-
val delay = sleepIfNeeded(startTimestamp, nd)
127-
loop(invocation + 1, intervals.tail, Some(delay))
125+
sleepIfNeeded(startTimestamp, nd)
126+
loop(invocation + 1, intervals.tail)
128127
case _ => v
129128
case v =>
130129
val result = em.getT(v)
131130
val shouldStop = config.afterAttempt(invocation, Right(result))
132131

133132
nextDelay match
134133
case Some(nd) if !shouldStop.stop =>
135-
val delay = sleepIfNeeded(startTimestamp, nd)
136-
loop(invocation + 1, intervals.tail, Some(delay))
134+
sleepIfNeeded(startTimestamp, nd)
135+
loop(invocation + 1, intervals.tail)
137136
case _ => v
138137
end match
139138
end loop
140139

141140
config.schedule.initialDelay.foreach(sleep)
142141

143-
loop(1, config.schedule.intervals(), None)
142+
loop(1, config.schedule.intervals())
144143
end scheduledWithErrorMode
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package ox.flow
2+
3+
import org.scalatest.flatspec.AnyFlatSpec
4+
import org.scalatest.matchers.should.Matchers
5+
import ox.*
6+
import ox.channels.ChannelClosedException
7+
8+
class FlowOpsRecoverTest extends AnyFlatSpec with Matchers:
9+
10+
behavior of "Flow.recover"
11+
12+
it should "pass through elements when upstream flow succeeds" in:
13+
// given
14+
val flow = Flow.fromValues(1, 2, 3)
15+
val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException =>
16+
42
17+
}
18+
19+
// when
20+
val result = flow.recover(recoveryFunction).runToList()
21+
22+
// then
23+
result shouldBe List(1, 2, 3)
24+
25+
it should "emit recovery value when upstream flow fails with handled exception" in:
26+
// given
27+
val exception = new IllegalArgumentException("test error")
28+
val flow = Flow.fromValues(1, 2).concat(Flow.failed(exception))
29+
val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException =>
30+
42
31+
}
32+
33+
// when
34+
val result = flow.recover(recoveryFunction).runToList()
35+
36+
// then
37+
result shouldBe List(1, 2, 42)
38+
39+
it should "not emit recovery value when downstream flow fails with handled exception" in:
40+
// given
41+
val exception = new IllegalArgumentException("test error")
42+
val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException =>
43+
42
44+
}
45+
val flow = Flow.fromValues(1, 2).recover(recoveryFunction).concat(Flow.failed(exception))
46+
47+
// when & then
48+
the[IllegalArgumentException] thrownBy {
49+
flow.runToList()
50+
} should have message "test error"
51+
52+
it should "propagate unhandled exceptions" in:
53+
// given
54+
val exception = new RuntimeException("unhandled error")
55+
val flow = Flow.fromValues(1, 2).concat(Flow.failed(exception))
56+
val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException =>
57+
42
58+
}
59+
60+
// when & then
61+
val caught = the[ChannelClosedException.Error] thrownBy {
62+
flow.recover(recoveryFunction).runToList()
63+
}
64+
caught.getCause shouldBe an[RuntimeException]
65+
caught.getCause.getMessage shouldBe "unhandled error"
66+
67+
it should "handle multiple exception types" in:
68+
// given
69+
val exception = new IllegalStateException("state error")
70+
val flow = Flow.fromValues(1, 2).concat(Flow.failed(exception))
71+
val recoveryFunction: PartialFunction[Throwable, Int] = {
72+
case _: IllegalArgumentException => 42
73+
case _: IllegalStateException => 99
74+
case _: NullPointerException => 0
75+
}
76+
77+
// when
78+
val result = flow.recover(recoveryFunction).runToList()
79+
80+
// then
81+
result shouldBe List(1, 2, 99)
82+
83+
it should "work with different recovery value type" in:
84+
// given
85+
val exception = new IllegalArgumentException("test error")
86+
val flow = Flow.fromValues("a", "b").concat(Flow.failed(exception))
87+
val recoveryFunction: PartialFunction[Throwable, String] = { case _: IllegalArgumentException =>
88+
"recovered"
89+
}
90+
91+
// when
92+
val result = flow.recover(recoveryFunction).runToList()
93+
94+
// then
95+
result shouldBe List("a", "b", "recovered")
96+
97+
it should "handle exception thrown during flow processing" in:
98+
// given
99+
val flow = Flow.fromValues(1, 2, 3).map(x => if x == 3 then throw new IllegalArgumentException("map error") else x)
100+
val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException =>
101+
-1
102+
}
103+
104+
// when
105+
val result = flow.recover(recoveryFunction).runToList()
106+
107+
// then
108+
result shouldBe List(1, 2, -1)
109+
110+
it should "work with empty flow" in:
111+
// given
112+
val flow = Flow.empty[Int]
113+
val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException =>
114+
42
115+
}
116+
117+
// when
118+
val result = flow.recover(recoveryFunction).runToList()
119+
120+
// then
121+
result shouldBe List.empty
122+
123+
it should "propagate exception when partial function throws" in:
124+
// given
125+
val originalException = new IllegalArgumentException("original error")
126+
val flow = Flow.fromValues(1, 2).concat(Flow.failed(originalException))
127+
val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException =>
128+
throw new RuntimeException("recovery failed")
129+
}
130+
131+
// when & then
132+
val caught = the[ChannelClosedException.Error] thrownBy {
133+
flow.recover(recoveryFunction).runToList()
134+
}
135+
caught.getCause shouldBe an[RuntimeException]
136+
caught.getCause.getMessage shouldBe "recovery failed"
137+
end FlowOpsRecoverTest

0 commit comments

Comments
 (0)