diff --git a/experimental/lambda/src/main/java/io/serverlessworkflow/impl/expressions/func/JavaModel.java b/experimental/lambda/src/main/java/io/serverlessworkflow/impl/expressions/func/JavaModel.java index ec05a1b9..884d0fe4 100644 --- a/experimental/lambda/src/main/java/io/serverlessworkflow/impl/expressions/func/JavaModel.java +++ b/experimental/lambda/src/main/java/io/serverlessworkflow/impl/expressions/func/JavaModel.java @@ -94,6 +94,9 @@ public Class objectClass() { @Override public Optional as(Class clazz) { + if (WorkflowModel.class.isAssignableFrom(clazz)) { + return Optional.of(clazz.cast(this)); + } return object != null && clazz.isAssignableFrom(object.getClass()) ? Optional.of(clazz.cast(object)) : Optional.empty(); diff --git a/experimental/lambda/src/test/java/io/serverless/workflow/impl/CallTest.java b/experimental/lambda/src/test/java/io/serverless/workflow/impl/CallTest.java index 078cc5e3..43625509 100644 --- a/experimental/lambda/src/test/java/io/serverless/workflow/impl/CallTest.java +++ b/experimental/lambda/src/test/java/io/serverless/workflow/impl/CallTest.java @@ -33,6 +33,7 @@ import io.serverlessworkflow.api.types.func.SwitchCaseFunction; import io.serverlessworkflow.impl.WorkflowApplication; import io.serverlessworkflow.impl.WorkflowDefinition; +import io.serverlessworkflow.impl.WorkflowModel; import io.serverlessworkflow.impl.expressions.TaskMetadataKeys; import java.util.Collection; import java.util.List; @@ -166,6 +167,30 @@ void testIf() throws InterruptedException, ExecutionException { } } + @Test + void testIfWithModel() throws InterruptedException, ExecutionException { + try (WorkflowApplication app = WorkflowApplication.builder().build()) { + Workflow workflow = + new Workflow() + .withDocument( + new Document().withNamespace("test").withName("testIf").withVersion("1.0")) + .withDo( + List.of( + new TaskItem( + "java", + new Task() + .withCallTask( + new CallTaskJava( + withPredicate( + CallJava.function( + CallTest::zeroWithModel, WorkflowModel.class), + CallTest::isOdd)))))); + WorkflowDefinition definition = app.workflowDefinition(workflow); + assertThat(definition.instance(3).start().get().asNumber().orElseThrow()).isEqualTo(0); + assertThat(definition.instance(4).start().get().asNumber().orElseThrow()).isEqualTo(4); + } + } + private CallJava withPredicate(CallJava call, Predicate pred) { return (CallJava) call.withMetadata( @@ -184,6 +209,10 @@ public static int zero(Integer value) { return 0; } + public static int zeroWithModel(WorkflowModel value) { + return 0; + } + public static Integer sum(Object model, Integer item) { return model instanceof Collection ? item : (Integer) model + item; }