diff --git a/core/src/main/scala/ox/flow/FlowOps.scala b/core/src/main/scala/ox/flow/FlowOps.scala index 395e8751..05b35a0b 100644 --- a/core/src/main/scala/ox/flow/FlowOps.scala +++ b/core/src/main/scala/ox/flow/FlowOps.scala @@ -18,6 +18,8 @@ import ox.forkCancellable import ox.forkUnsupervised import ox.forkUser import ox.repeatWhile +import ox.resilience.RetryConfig +import ox.scheduling.Schedule import ox.sleep import ox.supervised import ox.tapException @@ -951,6 +953,60 @@ class FlowOps[+T]: def onError(f: Throwable => Unit): Flow[T] = Flow.usingEmitInline: emit => last.run(emit).tapException(f) + /** Retries the upstream flow execution using the provided retry configuration. If the flow fails with an exception, it will be retried + * according to the schedule defined in the retry config until it succeeds or the retry policy decides to stop. + * + * Each retry attempt will run the complete upstream flow, from start up to this point. The retry behavior is controlled by the + * [[RetryConfig]]. + * + * Note that this retries the flow execution itself, not individual elements within the flow. If you need to retry individual operations + * within the flow, consider using retry logic inside methods such as [[map]]. + * + * Creates an asynchronous boundary (see [[buffer]]) to isolate failures when running the upstream flow. + * + * @param config + * The retry configuration that specifies the retry schedule and success/failure conditions. + * @return + * A new flow that will retry execution according to the provided configuration. + * @throws anything + * The exception from the last retry attempt if all retries are exhausted. + * @see + * [[ox.resilience.retry]] + */ + def retry(config: RetryConfig[Throwable, Unit])(using BufferCapacity): Flow[T] = Flow.usingEmitInline: emit => + val ch = BufferCapacity.newChannel[T] + unsupervised: + forkPropagate(ch) { + ox.resilience.retry(config)(last.run(FlowEmit.fromInline(t => ch.send(t)))) + ch.done() + }.discard + FlowEmit.channelToEmit(ch, emit) + + /** @see + * [[retry(RetryConfig)]] + */ + def retry(schedule: Schedule): Flow[T] = retry(RetryConfig(schedule)) + + /** Recovers from errors in the upstream flow by emitting a recovery value when the error is handled by the partial function. If the + * partial function is not defined for the error, the original error is propagated. + * + * Creates an asynchronous boundary (see [[buffer]]) to isolate failures when running the upstream flow. + * + * @param pf + * A partial function that handles specific exceptions and returns a recovery value to emit. + * @return + * A flow that emits elements from the upstream flow, and emits a recovery value if the upstream fails with a handled exception. + */ + def recover[U >: T](pf: PartialFunction[Throwable, U])(using BufferCapacity): Flow[U] = Flow.usingEmitInline: emit => + val ch = BufferCapacity.newChannel[U] + unsupervised: + forkPropagate(ch) { + try last.run(FlowEmit.fromInline(t => ch.send(t))) + catch case e: Throwable if pf.isDefinedAt(e) => ch.send(pf(e)) + ch.done() + }.discard + FlowEmit.channelToEmit(ch, emit) + // protected def runLastToChannelAsync(ch: Sink[T])(using OxUnsupervised): Unit = diff --git a/core/src/main/scala/ox/scheduling/scheduled.scala b/core/src/main/scala/ox/scheduling/scheduled.scala index 985b961c..dad2190e 100644 --- a/core/src/main/scala/ox/scheduling/scheduled.scala +++ b/core/src/main/scala/ox/scheduling/scheduled.scala @@ -102,8 +102,8 @@ def scheduledEither[E, T](config: ScheduledConfig[E, T])(operation: => Either[E, */ def scheduledWithErrorMode[E, F[_], T](em: ErrorMode[E, F])(config: ScheduledConfig[E, T])(operation: => F[T]): F[T] = @tailrec - def loop(invocation: Int, intervals: LazyList[FiniteDuration], lastDuration: Option[FiniteDuration]): F[T] = - def sleepIfNeeded(startTimestamp: Long, nextDelay: FiniteDuration) = + def loop(invocation: Int, intervals: LazyList[FiniteDuration]): F[T] = + def sleepIfNeeded(startTimestamp: Long, nextDelay: FiniteDuration): Unit = val delay = config.sleepMode match case SleepMode.StartToStart => val elapsed = System.nanoTime() - startTimestamp @@ -111,7 +111,6 @@ def scheduledWithErrorMode[E, F[_], T](em: ErrorMode[E, F])(config: ScheduledCon remaining.nanos case SleepMode.EndToStart => nextDelay if delay.toMillis > 0 then sleep(delay) - delay end sleepIfNeeded val startTimestamp = System.nanoTime() @@ -123,8 +122,8 @@ def scheduledWithErrorMode[E, F[_], T](em: ErrorMode[E, F])(config: ScheduledCon nextDelay match case Some(nd) if !shouldStop.stop => - val delay = sleepIfNeeded(startTimestamp, nd) - loop(invocation + 1, intervals.tail, Some(delay)) + sleepIfNeeded(startTimestamp, nd) + loop(invocation + 1, intervals.tail) case _ => v case v => val result = em.getT(v) @@ -132,13 +131,13 @@ def scheduledWithErrorMode[E, F[_], T](em: ErrorMode[E, F])(config: ScheduledCon nextDelay match case Some(nd) if !shouldStop.stop => - val delay = sleepIfNeeded(startTimestamp, nd) - loop(invocation + 1, intervals.tail, Some(delay)) + sleepIfNeeded(startTimestamp, nd) + loop(invocation + 1, intervals.tail) case _ => v end match end loop config.schedule.initialDelay.foreach(sleep) - loop(1, config.schedule.intervals(), None) + loop(1, config.schedule.intervals()) end scheduledWithErrorMode diff --git a/core/src/test/scala/ox/flow/FlowOpsRecoverTest.scala b/core/src/test/scala/ox/flow/FlowOpsRecoverTest.scala new file mode 100644 index 00000000..54716cbc --- /dev/null +++ b/core/src/test/scala/ox/flow/FlowOpsRecoverTest.scala @@ -0,0 +1,137 @@ +package ox.flow + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* +import ox.channels.ChannelClosedException + +class FlowOpsRecoverTest extends AnyFlatSpec with Matchers: + + behavior of "Flow.recover" + + it should "pass through elements when upstream flow succeeds" in: + // given + val flow = Flow.fromValues(1, 2, 3) + val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException => + 42 + } + + // when + val result = flow.recover(recoveryFunction).runToList() + + // then + result shouldBe List(1, 2, 3) + + it should "emit recovery value when upstream flow fails with handled exception" in: + // given + val exception = new IllegalArgumentException("test error") + val flow = Flow.fromValues(1, 2).concat(Flow.failed(exception)) + val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException => + 42 + } + + // when + val result = flow.recover(recoveryFunction).runToList() + + // then + result shouldBe List(1, 2, 42) + + it should "not emit recovery value when downstream flow fails with handled exception" in: + // given + val exception = new IllegalArgumentException("test error") + val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException => + 42 + } + val flow = Flow.fromValues(1, 2).recover(recoveryFunction).concat(Flow.failed(exception)) + + // when & then + the[IllegalArgumentException] thrownBy { + flow.runToList() + } should have message "test error" + + it should "propagate unhandled exceptions" in: + // given + val exception = new RuntimeException("unhandled error") + val flow = Flow.fromValues(1, 2).concat(Flow.failed(exception)) + val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException => + 42 + } + + // when & then + val caught = the[ChannelClosedException.Error] thrownBy { + flow.recover(recoveryFunction).runToList() + } + caught.getCause shouldBe an[RuntimeException] + caught.getCause.getMessage shouldBe "unhandled error" + + it should "handle multiple exception types" in: + // given + val exception = new IllegalStateException("state error") + val flow = Flow.fromValues(1, 2).concat(Flow.failed(exception)) + val recoveryFunction: PartialFunction[Throwable, Int] = { + case _: IllegalArgumentException => 42 + case _: IllegalStateException => 99 + case _: NullPointerException => 0 + } + + // when + val result = flow.recover(recoveryFunction).runToList() + + // then + result shouldBe List(1, 2, 99) + + it should "work with different recovery value type" in: + // given + val exception = new IllegalArgumentException("test error") + val flow = Flow.fromValues("a", "b").concat(Flow.failed(exception)) + val recoveryFunction: PartialFunction[Throwable, String] = { case _: IllegalArgumentException => + "recovered" + } + + // when + val result = flow.recover(recoveryFunction).runToList() + + // then + result shouldBe List("a", "b", "recovered") + + it should "handle exception thrown during flow processing" in: + // given + val flow = Flow.fromValues(1, 2, 3).map(x => if x == 3 then throw new IllegalArgumentException("map error") else x) + val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException => + -1 + } + + // when + val result = flow.recover(recoveryFunction).runToList() + + // then + result shouldBe List(1, 2, -1) + + it should "work with empty flow" in: + // given + val flow = Flow.empty[Int] + val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException => + 42 + } + + // when + val result = flow.recover(recoveryFunction).runToList() + + // then + result shouldBe List.empty + + it should "propagate exception when partial function throws" in: + // given + val originalException = new IllegalArgumentException("original error") + val flow = Flow.fromValues(1, 2).concat(Flow.failed(originalException)) + val recoveryFunction: PartialFunction[Throwable, Int] = { case _: IllegalArgumentException => + throw new RuntimeException("recovery failed") + } + + // when & then + val caught = the[ChannelClosedException.Error] thrownBy { + flow.recover(recoveryFunction).runToList() + } + caught.getCause shouldBe an[RuntimeException] + caught.getCause.getMessage shouldBe "recovery failed" +end FlowOpsRecoverTest diff --git a/core/src/test/scala/ox/flow/FlowOpsRetryTest.scala b/core/src/test/scala/ox/flow/FlowOpsRetryTest.scala new file mode 100644 index 00000000..0b17a34f --- /dev/null +++ b/core/src/test/scala/ox/flow/FlowOpsRetryTest.scala @@ -0,0 +1,197 @@ +package ox.flow + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers +import ox.* +import ox.channels.ChannelClosedException +import ox.resilience.ResultPolicy +import ox.resilience.RetryConfig +import ox.scheduling.Schedule +import ox.util.ElapsedTime + +import java.util.concurrent.atomic.AtomicInteger +import scala.concurrent.duration.* + +class FlowOpsRetryTest extends AnyFlatSpec with Matchers with ElapsedTime: + + behavior of "Flow.retry" + + it should "successfully run a flow without retries when no errors occur" in: + // given + val flow = Flow.fromValues(1, 2, 3) + + // when + val result = flow.retry(Schedule.immediate.maxRepeats(3)).runToList() + + // then + result shouldBe List(1, 2, 3) + + it should "retry a failing flow with immediate schedule" in: + // given + val attemptCounter = new AtomicInteger(0) + val maxRetries = 3 + + val flow = Flow.usingEmit[Int] { emit => + val attempt = attemptCounter.incrementAndGet() + if attempt <= maxRetries then throw new RuntimeException(s"attempt $attempt failed") + else emit(42) + } + + // when + val result = flow.retry(Schedule.immediate.maxRepeats(maxRetries)).runToList() + + // then + result shouldBe List(42) + attemptCounter.get() shouldBe maxRetries + 1 + + it should "retry a failing flow with fixed interval schedule" in: + // given + val attemptCounter = new AtomicInteger(0) + val maxRetries = 2 + val interval = 50.millis + + val flow = Flow.usingEmit[Int] { emit => + val attempt = attemptCounter.incrementAndGet() + if attempt <= maxRetries then throw new RuntimeException(s"attempt $attempt failed") + else emit(100) + } + + // when + val (result, elapsedTime) = measure { + flow.retry(Schedule.fixedInterval(interval).maxRepeats(maxRetries)).runToList() + } + + // then + result shouldBe List(100) + attemptCounter.get() shouldBe maxRetries + 1 + elapsedTime.toMillis should be >= (maxRetries * interval.toMillis) + + it should "not retry a flow which fails downstream" in: + // given + val upstreamInvocationCounter = new AtomicInteger(0) + val downstreamInvocationCounter = new AtomicInteger(0) + + val flow = Flow + .fromValues(1, 2, 3) + .tap(_ => upstreamInvocationCounter.incrementAndGet().discard) + .retry(Schedule.immediate.maxRepeats(3)) + .tap { value => + downstreamInvocationCounter.incrementAndGet().discard + if value == 2 then throw new RuntimeException("downstream failure") + } + + // when + val result = intercept[RuntimeException](flow.runToList()) + + // then + upstreamInvocationCounter.get() shouldBe 3 // 1, 2, 3 + downstreamInvocationCounter.get() shouldBe 2 // 1, 2 + result.getMessage() shouldBe "downstream failure" + + it should "fail after exhausting all retry attempts" in: + // given + val attemptCounter = new AtomicInteger(0) + val maxRetries = 3 + val errorMessage = "persistent failure" + + val flow = Flow.usingEmit[Int] { _ => + attemptCounter.incrementAndGet() + throw new RuntimeException(errorMessage) + } + + // when/then + val exception = the[ChannelClosedException.Error] thrownBy { + flow.retry(Schedule.immediate.maxRepeats(maxRetries)).runToList() + } + + exception.getCause.getMessage shouldBe errorMessage + attemptCounter.get() shouldBe maxRetries + 1 + + it should "use custom ResultPolicy to determine retry worthiness" in: + // given + val attemptCounter = new AtomicInteger(0) + val maxRetries = 3 + val fatalErrorMessage = "fatal error" + val retryableErrorMessage = "retryable error" + + val flow = Flow.usingEmit[Int] { emit => + val attempt = attemptCounter.incrementAndGet() + if attempt == 1 then throw new RuntimeException(retryableErrorMessage) + else if attempt == 2 then throw new RuntimeException(fatalErrorMessage) + else emit(50) + } + + val config = RetryConfig[Throwable, Unit]( + Schedule.immediate.maxRepeats(maxRetries), + ResultPolicy.retryWhen[Throwable, Unit](_.getMessage != fatalErrorMessage) + ) + + // when/then + val exception = the[ChannelClosedException.Error] thrownBy { + flow.retry(config).runToList() + } + + exception.getCause.getMessage shouldBe fatalErrorMessage + attemptCounter.get() shouldBe 2 // Should stop after fatal error, not retry + + it should "handle empty flows correctly" in: + // given + val flow = Flow.empty[Int] + + // when + val result = flow.retry(Schedule.immediate.maxRepeats(3)).runToList() + + // then + result shouldBe List.empty + + it should "handle flows that complete successfully on first attempt" in: + // given + val invocationCounter = new AtomicInteger(0) + val flow = Flow.usingEmit[String] { emit => + invocationCounter.incrementAndGet() + emit("first try success") + } + + // when + val result = flow.retry(Schedule.immediate.maxRepeats(5)).runToList() + + // then + result shouldBe List("first try success") + invocationCounter.get() shouldBe 1 // Should only run once + + it should "retry the entire flow when processing fails" in: + // given + val invocationCounter = new AtomicInteger(0) + val flow = Flow.fromValues(1, 2, 3).map { value => + val invocation = invocationCounter.incrementAndGet() + if invocation == 1 then throw new RuntimeException("processing error") + else value * 2 + } + + // when + val result = flow.retry(Schedule.immediate.maxRepeats(2)).runToList() + + // then + result shouldBe List(2, 4, 6) + invocationCounter.get() shouldBe 4 // First attempt fails on element 1, second attempt processes all 3 elements + + it should "work with complex flows containing transformations" in: + // given + val invocationCounter = new AtomicInteger(0) + val flow = Flow + .fromValues(1, 2, 3, 4, 5) + .filter(_ % 2 == 0) // Keep only even numbers: 2, 4 + .map { value => + val invocation = invocationCounter.incrementAndGet() + if invocation == 1 then throw new RuntimeException("transformation failed") + else value * 10 + } + + // when + val result = flow.retry(Schedule.immediate.maxRepeats(1)).runToList() + + // then + result shouldBe List(20, 40) + invocationCounter.get() shouldBe 3 // First attempt fails on element 2, second attempt processes both filtered elements (2, 4) + +end FlowOpsRetryTest