diff --git a/build.gradle b/build.gradle index a7e6489db4..084a6d7aad 100644 --- a/build.gradle +++ b/build.gradle @@ -22,6 +22,7 @@ plugins { allprojects { repositories { + mavenLocal() mavenCentral() } } @@ -30,7 +31,7 @@ ext { // Platforms grpcVersion = '1.58.1' // [1.38.0,) Needed for io.grpc.protobuf.services.HealthStatusManager jacksonVersion = '2.14.2' // [2.9.0,) - nexusVersion = '0.4.0-alpha' + nexusVersion = '0.5.0-SNAPSHOT' // we don't upgrade to 1.10.x because it requires kotlin 1.6. Users may use 1.10.x in their environments though. micrometerVersion = project.hasProperty("edgeDepsTest") ? '1.13.6' : '1.9.9' // [1.0.0,) diff --git a/temporal-opentracing/src/main/java/io/temporal/opentracing/OpenTracingClientInterceptor.java b/temporal-opentracing/src/main/java/io/temporal/opentracing/OpenTracingClientInterceptor.java index 095ddd99cc..619a7378fd 100644 --- a/temporal-opentracing/src/main/java/io/temporal/opentracing/OpenTracingClientInterceptor.java +++ b/temporal-opentracing/src/main/java/io/temporal/opentracing/OpenTracingClientInterceptor.java @@ -1,8 +1,10 @@ package io.temporal.opentracing; +import io.temporal.common.interceptors.NexusServiceClientCallsInterceptor; import io.temporal.common.interceptors.WorkflowClientCallsInterceptor; import io.temporal.common.interceptors.WorkflowClientInterceptorBase; import io.temporal.opentracing.internal.ContextAccessor; +import io.temporal.opentracing.internal.OpenTracingNexusServiceClientCallsInterceptor; import io.temporal.opentracing.internal.OpenTracingWorkflowClientCallsInterceptor; import io.temporal.opentracing.internal.SpanFactory; @@ -27,4 +29,11 @@ public WorkflowClientCallsInterceptor workflowClientCallsInterceptor( return new OpenTracingWorkflowClientCallsInterceptor( next, options, spanFactory, contextAccessor); } + + @Override + public NexusServiceClientCallsInterceptor nexusServiceClientCallsInterceptor( + NexusServiceClientCallsInterceptor next) { + return new OpenTracingNexusServiceClientCallsInterceptor( + next, options, spanFactory, contextAccessor); + } } diff --git a/temporal-opentracing/src/main/java/io/temporal/opentracing/SpanOperationType.java b/temporal-opentracing/src/main/java/io/temporal/opentracing/SpanOperationType.java index 2f8a27429d..bf44101ba1 100644 --- a/temporal-opentracing/src/main/java/io/temporal/opentracing/SpanOperationType.java +++ b/temporal-opentracing/src/main/java/io/temporal/opentracing/SpanOperationType.java @@ -17,7 +17,13 @@ public enum SpanOperationType { HANDLE_UPDATE("HandleUpdate"), START_NEXUS_OPERATION("StartNexusOperation"), RUN_START_NEXUS_OPERATION("RunStartNexusOperationHandler"), - RUN_CANCEL_NEXUS_OPERATION("RunCancelNexusOperationHandler"); + RUN_CANCEL_NEXUS_OPERATION("RunCancelNexusOperationHandler"), + RUN_FETCH_NEXUS_OPERATION_INFO("RunFetchNexusOperationInfoHandler"), + RUN_FETCH_NEXUS_OPERATION_RESULT("RunFetchNexusOperationResultHandler"), + CLIENT_START_NEXUS_OPERATION("ClientStartNexusOperation"), + CLIENT_CANCEL_NEXUS_OPERATION("ClientCancelNexusOperation"), + CLIENT_FETCH_NEXUS_OPERATION_INFO("ClientFetchNexusOperationInfo"), + CLIENT_FETCH_NEXUS_OPERATION_RESULT("ClientFetchNexusOperationResult"); private final String defaultPrefix; diff --git a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/ActionTypeAndNameSpanBuilderProvider.java b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/ActionTypeAndNameSpanBuilderProvider.java index 1734f3a36d..4dacd1ed6d 100644 --- a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/ActionTypeAndNameSpanBuilderProvider.java +++ b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/ActionTypeAndNameSpanBuilderProvider.java @@ -86,7 +86,13 @@ protected Map getSpanTags(SpanCreationContext context) { StandardTagNames.RUN_ID, context.getRunId()); case RUN_START_NEXUS_OPERATION: case RUN_CANCEL_NEXUS_OPERATION: + case RUN_FETCH_NEXUS_OPERATION_INFO: + case RUN_FETCH_NEXUS_OPERATION_RESULT: case HANDLE_QUERY: + case CLIENT_START_NEXUS_OPERATION: + case CLIENT_CANCEL_NEXUS_OPERATION: + case CLIENT_FETCH_NEXUS_OPERATION_INFO: + case CLIENT_FETCH_NEXUS_OPERATION_RESULT: return ImmutableMap.of(); } throw new IllegalArgumentException("Unknown span operation type provided"); diff --git a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingNexusOperationInboundCallsInterceptor.java b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingNexusOperationInboundCallsInterceptor.java index d0cb151524..6389ce8312 100644 --- a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingNexusOperationInboundCallsInterceptor.java +++ b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingNexusOperationInboundCallsInterceptor.java @@ -1,6 +1,7 @@ package io.temporal.opentracing.internal; import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; import io.opentracing.Scope; import io.opentracing.Span; import io.opentracing.SpanContext; @@ -73,4 +74,48 @@ public CancelOperationOutput cancelOperation(CancelOperationInput input) { operationCancelSpan.finish(); } } + + @Override + public FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationException, OperationStillRunningException { + SpanContext rootSpanContext = + contextAccessor.readSpanContextFromHeader(input.getOperationContext().getHeaders(), tracer); + + Span operationFetchResultSpan = + spanFactory + .createFetchNexusOperationResultSpan( + tracer, + input.getOperationContext().getService(), + input.getOperationContext().getOperation(), + rootSpanContext) + .start(); + try (Scope scope = tracer.scopeManager().activate(operationFetchResultSpan)) { + return super.fetchOperationResult(input); + } catch (Throwable t) { + spanFactory.logFail(operationFetchResultSpan, t); + throw t; + } finally { + operationFetchResultSpan.finish(); + } + } + + @Override + public FetchOperationInfoResponse fetchOperationInfo(FetchOperationInfoInput input) { + SpanContext rootSpanContext = + contextAccessor.readSpanContextFromHeader(input.getOperationContext().getHeaders(), tracer); + + Span operationFetchInfoSpan = + spanFactory + .createFetchNexusOperationInfoSpan( + tracer, + input.getOperationContext().getService(), + input.getOperationContext().getOperation(), + rootSpanContext) + .start(); + try (Scope scope = tracer.scopeManager().activate(operationFetchInfoSpan)) { + return super.fetchOperationInfo(input); + } finally { + operationFetchInfoSpan.finish(); + } + } } diff --git a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingNexusServiceClientCallsInterceptor.java b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingNexusServiceClientCallsInterceptor.java new file mode 100644 index 0000000000..bab3ff02ff --- /dev/null +++ b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/OpenTracingNexusServiceClientCallsInterceptor.java @@ -0,0 +1,239 @@ +package io.temporal.opentracing.internal; + +import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; +import io.opentracing.Scope; +import io.opentracing.Span; +import io.opentracing.Tracer; +import io.temporal.common.interceptors.NexusServiceClientCallsInterceptor; +import io.temporal.common.interceptors.NexusServiceClientCallsInterceptorBase; +import io.temporal.opentracing.OpenTracingOptions; +import java.util.Map; +import java.util.concurrent.CompletableFuture; + +/** + * Nexus service client interceptor that creates OpenTracing spans and propagates the active span + * context. + */ +public class OpenTracingNexusServiceClientCallsInterceptor + extends NexusServiceClientCallsInterceptorBase { + private final SpanFactory spanFactory; + private final Tracer tracer; + private final ContextAccessor contextAccessor; + + public OpenTracingNexusServiceClientCallsInterceptor( + NexusServiceClientCallsInterceptor next, + OpenTracingOptions options, + SpanFactory spanFactory, + ContextAccessor contextAccessor) { + super(next); + this.spanFactory = spanFactory; + this.tracer = options.getTracer(); + this.contextAccessor = contextAccessor; + } + + @Override + public StartOperationOutput startOperation(StartOperationInput input) throws OperationException { + Span span = + contextAccessor.writeSpanContextToHeader( + () -> + spanFactory + .createClientStartNexusOperationSpan( + tracer, input.getServiceName(), input.getOperationName()) + .start(), + input.getOptions().getHeaders(), + tracer); + try (Scope ignored = tracer.scopeManager().activate(span)) { + return super.startOperation(input); + } catch (Throwable t) { + spanFactory.logFail(span, t); + throw t; + } finally { + span.finish(); + } + } + + @Override + public CompletableFuture startOperationAsync(StartOperationInput input) { + Span span = + contextAccessor.writeSpanContextToHeader( + () -> + spanFactory + .createClientStartNexusOperationSpan( + tracer, input.getServiceName(), input.getOperationName()) + .start(), + input.getOptions().getHeaders(), + tracer); + try (Scope ignored = tracer.scopeManager().activate(span)) { + return super.startOperationAsync(input) + .whenComplete( + (r, e) -> { + if (e != null) { + spanFactory.logFail(span, e); + } + span.finish(); + }); + } + } + + @Override + public CancelOperationOutput cancelOperation(CancelOperationInput input) { + Span span = + contextAccessor.writeSpanContextToHeader( + () -> + spanFactory + .createClientCancelNexusOperationSpan( + tracer, input.getServiceName(), input.getOperationName()) + .start(), + input.getOptions().getHeaders(), + tracer); + try (Scope ignored = tracer.scopeManager().activate(span)) { + return super.cancelOperation(input); + } catch (Throwable t) { + spanFactory.logFail(span, t); + throw t; + } finally { + span.finish(); + } + } + + @Override + public CompletableFuture cancelOperationAsync(CancelOperationInput input) { + Span span = + contextAccessor.writeSpanContextToHeader( + () -> + spanFactory + .createClientCancelNexusOperationSpan( + tracer, input.getServiceName(), input.getOperationName()) + .start(), + input.getOptions().getHeaders(), + tracer); + try (Scope ignored = tracer.scopeManager().activate(span)) { + return super.cancelOperationAsync(input) + .whenComplete( + (r, e) -> { + if (e != null) { + spanFactory.logFail(span, e); + } + span.finish(); + }); + } + } + + @Override + public FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationException, OperationStillRunningException { + Span span = + contextAccessor.writeSpanContextToHeader( + () -> + spanFactory + .createClientFetchNexusOperationResultSpan( + tracer, + input.getServiceName(), + input.getOperationName(), + input.getOperationToken()) + .start(), + input.getOptions().getHeaders(), + tracer); + try (Scope ignored = tracer.scopeManager().activate(span)) { + return super.fetchOperationResult(input); + } catch (Throwable t) { + spanFactory.logFail(span, t); + throw t; + } finally { + span.finish(); + } + } + + @Override + public FetchOperationInfoOutput fetchOperationInfo(FetchOperationInfoInput input) { + Span span = + contextAccessor.writeSpanContextToHeader( + () -> + spanFactory + .createClientFetchNexusOperationInfoSpan( + tracer, input.getServiceName(), input.getOperationName()) + .start(), + input.getOptions().getHeaders(), + tracer); + try (Scope ignored = tracer.scopeManager().activate(span)) { + return super.fetchOperationInfo(input); + } catch (Throwable t) { + spanFactory.logFail(span, t); + throw t; + } finally { + span.finish(); + } + } + + @Override + public CompleteOperationOutput completeOperation(CompleteOperationInput input) { + propagate(input.getOptions().getHeaders()); + return super.completeOperation(input); + } + + @Override + public CompletableFuture fetchOperationResultAsync( + FetchOperationResultInput input) { + Span span = + contextAccessor.writeSpanContextToHeader( + () -> + spanFactory + .createClientFetchNexusOperationResultSpan( + tracer, + input.getServiceName(), + input.getOperationName(), + input.getOperationToken()) + .start(), + input.getOptions().getHeaders(), + tracer); + try (Scope ignored = tracer.scopeManager().activate(span)) { + return super.fetchOperationResultAsync(input) + .whenComplete( + (r, e) -> { + if (e != null) { + spanFactory.logFail(span, e); + } + span.finish(); + }); + } + } + + @Override + public CompletableFuture fetchOperationInfoAsync( + FetchOperationInfoInput input) { + Span span = + contextAccessor.writeSpanContextToHeader( + () -> + spanFactory + .createClientFetchNexusOperationInfoSpan( + tracer, input.getServiceName(), input.getOperationName()) + .start(), + input.getOptions().getHeaders(), + tracer); + try (Scope ignored = tracer.scopeManager().activate(span)) { + return super.fetchOperationInfoAsync(input) + .whenComplete( + (r, e) -> { + if (e != null) { + spanFactory.logFail(span, e); + } + span.finish(); + }); + } + } + + @Override + public CompletableFuture completeOperationAsync( + CompleteOperationAsyncInput input) { + propagate(input.getOptions().getHeaders()); + return super.completeOperationAsync(input); + } + + private void propagate(Map headers) { + Span activeSpan = tracer.scopeManager().activeSpan(); + if (activeSpan != null) { + contextAccessor.writeSpanContextToHeader(activeSpan.context(), headers, tracer); + } + } +} diff --git a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/SpanFactory.java b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/SpanFactory.java index 0848c6b652..041bd6ccde 100644 --- a/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/SpanFactory.java +++ b/temporal-opentracing/src/main/java/io/temporal/opentracing/internal/SpanFactory.java @@ -69,6 +69,46 @@ public Tracer.SpanBuilder createChildWorkflowStartSpan( return createSpan(context, tracer, null, References.CHILD_OF); } + public Tracer.SpanBuilder createClientStartNexusOperationSpan( + Tracer tracer, String serviceName, String operationName) { + SpanCreationContext context = + SpanCreationContext.newBuilder() + .setSpanOperationType(SpanOperationType.CLIENT_START_NEXUS_OPERATION) + .setActionName(serviceName + "/" + operationName) + .build(); + return createSpan(context, tracer, null, References.CHILD_OF); + } + + public Tracer.SpanBuilder createClientCancelNexusOperationSpan( + Tracer tracer, String serviceName, String operationName) { + SpanCreationContext context = + SpanCreationContext.newBuilder() + .setSpanOperationType(SpanOperationType.CLIENT_CANCEL_NEXUS_OPERATION) + .setActionName(serviceName + "/" + operationName) + .build(); + return createSpan(context, tracer, null, References.CHILD_OF); + } + + public Tracer.SpanBuilder createClientFetchNexusOperationInfoSpan( + Tracer tracer, String serviceName, String operationName) { + SpanCreationContext context = + SpanCreationContext.newBuilder() + .setSpanOperationType(SpanOperationType.CLIENT_FETCH_NEXUS_OPERATION_INFO) + .setActionName(serviceName + "/" + operationName) + .build(); + return createSpan(context, tracer, null, References.CHILD_OF); + } + + public Tracer.SpanBuilder createClientFetchNexusOperationResultSpan( + Tracer tracer, String serviceName, String operationName, String operationToken) { + SpanCreationContext context = + SpanCreationContext.newBuilder() + .setSpanOperationType(SpanOperationType.CLIENT_FETCH_NEXUS_OPERATION_RESULT) + .setActionName(serviceName + "/" + operationName) + .build(); + return createSpan(context, tracer, null, References.CHILD_OF); + } + public Tracer.SpanBuilder createExternalWorkflowSignalSpan( Tracer tracer, String signalName, String workflowId, String runId) { SpanCreationContext context = @@ -185,6 +225,26 @@ public Tracer.SpanBuilder createCancelNexusOperationSpan( return createSpan(context, tracer, nexusStartSpanContext, References.FOLLOWS_FROM); } + public Tracer.SpanBuilder createFetchNexusOperationResultSpan( + Tracer tracer, String serviceName, String operationName, SpanContext nexusStartSpanContext) { + SpanCreationContext context = + SpanCreationContext.newBuilder() + .setSpanOperationType(SpanOperationType.RUN_FETCH_NEXUS_OPERATION_RESULT) + .setActionName(serviceName + "/" + operationName) + .build(); + return createSpan(context, tracer, nexusStartSpanContext, References.FOLLOWS_FROM); + } + + public Tracer.SpanBuilder createFetchNexusOperationInfoSpan( + Tracer tracer, String serviceName, String operationName, SpanContext nexusStartSpanContext) { + SpanCreationContext context = + SpanCreationContext.newBuilder() + .setSpanOperationType(SpanOperationType.RUN_FETCH_NEXUS_OPERATION_INFO) + .setActionName(serviceName + "/" + operationName) + .build(); + return createSpan(context, tracer, nexusStartSpanContext, References.FOLLOWS_FROM); + } + public Tracer.SpanBuilder createWorkflowStartUpdateSpan( Tracer tracer, String updateName, String workflowId, String runId) { SpanCreationContext context = diff --git a/temporal-opentracing/src/test/java/io/temporal/opentracing/NexusServiceClientTracingTest.java b/temporal-opentracing/src/test/java/io/temporal/opentracing/NexusServiceClientTracingTest.java new file mode 100644 index 0000000000..c981cb1586 --- /dev/null +++ b/temporal-opentracing/src/test/java/io/temporal/opentracing/NexusServiceClientTracingTest.java @@ -0,0 +1,151 @@ +package io.temporal.opentracing; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import io.nexusrpc.OperationInfo; +import io.nexusrpc.OperationState; +import io.nexusrpc.client.OperationHandle; +import io.nexusrpc.client.ServiceClient; +import io.nexusrpc.client.StartOperationResponse; +import io.nexusrpc.handler.*; +import io.opentracing.Scope; +import io.opentracing.mock.MockSpan; +import io.opentracing.mock.MockTracer; +import io.opentracing.util.ThreadLocalScopeManager; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.worker.WorkerFactoryOptions; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; + +public class NexusServiceClientTracingTest { + private static final String NEXUS_OPERATION_TOKEN = "test-operation-token"; + private final MockTracer mockTracer = + new MockTracer(new ThreadLocalScopeManager(), MockTracer.Propagator.TEXT_MAP); + + private final OpenTracingOptions OT_OPTIONS = + OpenTracingOptions.newBuilder().setTracer(mockTracer).build(); + + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setWorkflowClientOptions( + WorkflowClientOptions.newBuilder() + .setInterceptors(new OpenTracingClientInterceptor(OT_OPTIONS)) + .validateAndBuildWithDefaults()) + .setWorkerFactoryOptions( + WorkerFactoryOptions.newBuilder() + .setWorkerInterceptors(new OpenTracingWorkerInterceptor(OT_OPTIONS)) + .validateAndBuildWithDefaults()) + .setNexusServiceImplementation(new TestNexusServiceImpl()) + .build(); + + @After + public void tearDown() { + mockTracer.reset(); + } + + @Test + public void testTracing() throws Exception { + MockSpan span = mockTracer.buildSpan("ClientFunction").start(); + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(NexusOperationTest.TestNexusService.class); + + try (Scope scope = mockTracer.scopeManager().activate(span)) { + StartOperationResponse result = + serviceClient.startOperation(NexusOperationTest.TestNexusService::operation, "World"); + assertTrue(result instanceof StartOperationResponse.Async); + OperationHandle handle = ((StartOperationResponse.Async) result).getHandle(); + handle.cancel(); + assertEquals("Hello World", handle.fetchResult()); + handle.fetchInfo(); + } finally { + span.finish(); + } + + OpenTracingSpansHelper spansHelper = new OpenTracingSpansHelper(mockTracer.finishedSpans()); + // Verify the start span from the client and the handler + MockSpan clientSpan = spansHelper.getSpanByOperationName("ClientFunction"); + MockSpan startSpan = spansHelper.getByParentSpan(clientSpan).get(0); + assertEquals(clientSpan.context().spanId(), startSpan.parentId()); + assertEquals("ClientStartNexusOperation:TestNexusService/operation", startSpan.operationName()); + + MockSpan runSpan = spansHelper.getByParentSpan(startSpan).get(0); + assertEquals(startSpan.context().spanId(), runSpan.parentId()); + assertEquals( + "RunStartNexusOperationHandler:TestNexusService/operation", runSpan.operationName()); + + // Verify the cancel span from the client and the handler + MockSpan clientCancelSpan = spansHelper.getByParentSpan(clientSpan).get(1); + assertEquals(clientSpan.context().spanId(), clientCancelSpan.parentId()); + assertEquals( + "ClientCancelNexusOperation:TestNexusService/operation", clientCancelSpan.operationName()); + + MockSpan handlerCancelSpan = spansHelper.getByParentSpan(clientCancelSpan).get(0); + assertEquals(clientCancelSpan.context().spanId(), handlerCancelSpan.parentId()); + assertEquals( + "RunCancelNexusOperationHandler:TestNexusService/operation", + handlerCancelSpan.operationName()); + + // Verify the fetchResult span from the client and the handler + MockSpan fetchResultSpan = spansHelper.getByParentSpan(clientSpan).get(2); + assertEquals(clientSpan.context().spanId(), fetchResultSpan.parentId()); + assertEquals( + "ClientFetchNexusOperationResult:TestNexusService/operation", + fetchResultSpan.operationName()); + + MockSpan handlerFetchResultSpan = spansHelper.getByParentSpan(fetchResultSpan).get(0); + assertEquals(fetchResultSpan.context().spanId(), handlerFetchResultSpan.parentId()); + assertEquals( + "RunFetchNexusOperationResultHandler:TestNexusService/operation", + handlerFetchResultSpan.operationName()); + + // Verify the fetchInfo span from the client and the handler + MockSpan fetchInfoSpan = spansHelper.getByParentSpan(clientSpan).get(3); + assertEquals(clientSpan.context().spanId(), fetchInfoSpan.parentId()); + assertEquals( + "ClientFetchNexusOperationInfo:TestNexusService/operation", fetchInfoSpan.operationName()); + + MockSpan handlerFetchInfoSpan = spansHelper.getByParentSpan(fetchInfoSpan).get(0); + assertEquals(fetchInfoSpan.context().spanId(), handlerFetchInfoSpan.parentId()); + assertEquals( + "RunFetchNexusOperationInfoHandler:TestNexusService/operation", + handlerFetchInfoSpan.operationName()); + } + + @ServiceImpl(service = NexusOperationTest.TestNexusService.class) + public static class TestNexusServiceImpl { + @OperationImpl + public OperationHandler operation() { + return new OperationHandler() { + @Override + public OperationStartResult start( + OperationContext context, OperationStartDetails details, String param) + throws HandlerException { + return OperationStartResult.async(NEXUS_OPERATION_TOKEN); + } + + @Override + public String fetchResult(OperationContext context, OperationFetchResultDetails details) + throws HandlerException { + return "Hello World"; + } + + @Override + public OperationInfo fetchInfo(OperationContext context, OperationFetchInfoDetails details) + throws HandlerException { + return OperationInfo.newBuilder() + .setState(OperationState.SUCCEEDED) + .setToken(NEXUS_OPERATION_TOKEN) + .build(); + } + + @Override + public void cancel(OperationContext context, OperationCancelDetails details) + throws HandlerException {} + }; + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/client/TemporalNexusServiceClientOptions.java b/temporal-sdk/src/main/java/io/temporal/client/TemporalNexusServiceClientOptions.java new file mode 100644 index 0000000000..09e6656271 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/client/TemporalNexusServiceClientOptions.java @@ -0,0 +1,39 @@ +package io.temporal.client; + +import com.google.common.base.Strings; +import io.temporal.common.Experimental; + +/** Options for configuring the Temporal Nexus Service client. */ +@Experimental +public class TemporalNexusServiceClientOptions { + public static Builder newBuilder() { + return new Builder(); + } + + private final String endpoint; + + TemporalNexusServiceClientOptions(String endpoint) { + this.endpoint = endpoint; + } + + public String getEndpoint() { + return endpoint; + } + + public static final class Builder { + private String endpoint; + + public Builder setEndpoint(String endpoint) { + this.endpoint = endpoint; + return this; + } + + public TemporalNexusServiceClientOptions build() { + if (Strings.isNullOrEmpty(endpoint)) { + throw new IllegalArgumentException("Must provide an endpoint"); + } + + return new TemporalNexusServiceClientOptions(endpoint); + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/client/WorkflowClient.java b/temporal-sdk/src/main/java/io/temporal/client/WorkflowClient.java index 49387639df..36a5cf979c 100644 --- a/temporal-sdk/src/main/java/io/temporal/client/WorkflowClient.java +++ b/temporal-sdk/src/main/java/io/temporal/client/WorkflowClient.java @@ -1,5 +1,7 @@ package io.temporal.client; +import io.nexusrpc.client.CompletionClient; +import io.nexusrpc.client.ServiceClient; import io.temporal.activity.Activity; import io.temporal.activity.ActivityExecutionContext; import io.temporal.api.common.v1.WorkflowExecution; @@ -126,6 +128,25 @@ static WorkflowClient newInstance(WorkflowServiceStubs service, WorkflowClientOp WorkflowServiceStubs getWorkflowServiceStubs(); + /** + * Create a new {@link ServiceClient} that can be used to start operations or get handlers to + * operations on a Nexus services. + * + * @param nexusServiceInterface The interface of the Nexus service to create a client for. + * @return A new {@link ServiceClient} instance backed by this {@link WorkflowClient} instance. + */ + @Experimental + ServiceClient newNexusServiceClient( + Class nexusServiceInterface, TemporalNexusServiceClientOptions options); + + /** + * Creates a new {@link CompletionClient} that can be used to complete or fail async operations + * + * @return A new {@link CompletionClient} instance backed by this {@link WorkflowClient} instance. + */ + @Experimental + CompletionClient newNexusCompletionClient(); + /** * Creates workflow client stub that can be used to start a single workflow execution. The first * call must be to a method annotated with @WorkflowMethod. After workflow is started it can be diff --git a/temporal-sdk/src/main/java/io/temporal/client/WorkflowClientInternalImpl.java b/temporal-sdk/src/main/java/io/temporal/client/WorkflowClientInternalImpl.java index c9aa37326d..19c701c896 100644 --- a/temporal-sdk/src/main/java/io/temporal/client/WorkflowClientInternalImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/client/WorkflowClientInternalImpl.java @@ -6,6 +6,9 @@ import com.google.common.base.Strings; import com.google.common.reflect.TypeToken; import com.uber.m3.tally.Scope; +import io.nexusrpc.client.CompletionClient; +import io.nexusrpc.client.ServiceClient; +import io.nexusrpc.client.ServiceClientOptions; import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.enums.v1.TaskReachability; import io.temporal.api.history.v1.History; @@ -13,14 +16,17 @@ import io.temporal.api.workflowservice.v1.*; import io.temporal.client.WorkflowInvocationHandler.InvocationType; import io.temporal.common.WorkflowExecutionHistory; +import io.temporal.common.interceptors.NexusServiceClientCallsInterceptor; import io.temporal.common.interceptors.WorkflowClientCallsInterceptor; import io.temporal.common.interceptors.WorkflowClientInterceptor; import io.temporal.internal.WorkflowThreadMarker; import io.temporal.internal.client.*; +import io.temporal.internal.client.NexusServiceClientCallsInterceptorRoot; import io.temporal.internal.client.NexusStartWorkflowResponse; import io.temporal.internal.client.external.GenericWorkflowClient; import io.temporal.internal.client.external.GenericWorkflowClientImpl; import io.temporal.internal.client.external.ManualActivityCompletionClientFactory; +import io.temporal.internal.nexus.PayloadSerializer; import io.temporal.internal.sync.StubMarker; import io.temporal.serviceclient.MetricsTag; import io.temporal.serviceclient.WorkflowServiceStubs; @@ -101,6 +107,32 @@ public WorkflowServiceStubs getWorkflowServiceStubs() { return workflowServiceStubs; } + @Override + public ServiceClient newNexusServiceClient( + Class nexusServiceInterface, TemporalNexusServiceClientOptions serviceClientOptions) { + NexusServiceClientCallsInterceptor interceptorChain = + new NexusServiceClientCallsInterceptorRoot(genericClient, options, serviceClientOptions); + for (WorkflowClientInterceptor interceptor : interceptors) { + interceptorChain = interceptor.nexusServiceClientCallsInterceptor(interceptorChain); + } + return new ServiceClient<>( + ServiceClientOptions.newBuilder(nexusServiceInterface) + .setTransport(new temporalNexusTransport(interceptorChain)) + .setSerializer(new PayloadSerializer(options.getDataConverter())) + .build()); + } + + @Override + public CompletionClient newNexusCompletionClient() { + NexusServiceClientCallsInterceptor interceptorChain = + new NexusServiceClientCallsInterceptorRoot( + genericClient, options, TemporalNexusServiceClientOptions.newBuilder().build()); + for (WorkflowClientInterceptor interceptor : interceptors) { + interceptorChain = interceptor.nexusServiceClientCallsInterceptor(interceptorChain); + } + return new CompletionClient(new temporalNexusTransport(interceptorChain)); + } + @Override public WorkflowClientOptions getOptions() { return options; diff --git a/temporal-sdk/src/main/java/io/temporal/client/temporalNexusTransport.java b/temporal-sdk/src/main/java/io/temporal/client/temporalNexusTransport.java new file mode 100644 index 0000000000..db3f37b960 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/client/temporalNexusTransport.java @@ -0,0 +1,132 @@ +package io.temporal.client; + +import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; +import io.nexusrpc.client.transport.*; +import io.temporal.common.interceptors.NexusServiceClientCallsInterceptor; +import java.util.concurrent.CompletableFuture; + +class temporalNexusTransport implements Transport { + private final NexusServiceClientCallsInterceptor interceptor; + + public temporalNexusTransport(NexusServiceClientCallsInterceptor interceptor) { + this.interceptor = interceptor; + } + + @Override + public StartOperationResponse startOperation( + String operationName, String serviceName, Object input, StartOperationOptions options) + throws OperationException { + return interceptor + .startOperation( + new NexusServiceClientCallsInterceptor.StartOperationInput( + operationName, serviceName, input, options)) + .getResponse(); + } + + @Override + public FetchOperationResultResponse fetchOperationResult( + String operationName, + String serviceName, + String operationToken, + FetchOperationResultOptions options) + throws OperationException, OperationStillRunningException { + return interceptor + .fetchOperationResult( + new NexusServiceClientCallsInterceptor.FetchOperationResultInput( + operationName, serviceName, operationToken, options)) + .getResponse(); + } + + @Override + public FetchOperationInfoResponse fetchOperationInfo( + String operationName, + String serviceName, + String operationToken, + FetchOperationInfoOptions options) { + return interceptor + .fetchOperationInfo( + new NexusServiceClientCallsInterceptor.FetchOperationInfoInput( + operationName, serviceName, operationToken, options)) + .getResponse(); + } + + @Override + public CancelOperationResponse cancelOperation( + String operationName, + String serviceName, + String operationToken, + CancelOperationOptions options) { + return interceptor + .cancelOperation( + new NexusServiceClientCallsInterceptor.CancelOperationInput( + operationName, serviceName, operationToken, options)) + .getResponse(); + } + + @Override + public CompleteOperationResponse completeOperation(String url, CompleteOperationOptions options) { + return interceptor + .completeOperation( + new NexusServiceClientCallsInterceptor.CompleteOperationInput(url, options)) + .getResponse(); + } + + @Override + public CompletableFuture startOperationAsync( + String operationName, String serviceName, Object input, StartOperationOptions options) { + return interceptor + .startOperationAsync( + new NexusServiceClientCallsInterceptor.StartOperationInput( + operationName, serviceName, input, options)) + .thenApply(NexusServiceClientCallsInterceptor.StartOperationOutput::getResponse); + } + + @Override + public CompletableFuture fetchOperationResultAsync( + String operationName, + String serviceName, + String operationToken, + FetchOperationResultOptions options) { + return interceptor + .fetchOperationResultAsync( + new NexusServiceClientCallsInterceptor.FetchOperationResultInput( + operationName, serviceName, operationToken, options)) + .thenApply(NexusServiceClientCallsInterceptor.FetchOperationResultOutput::getResponse); + } + + @Override + public CompletableFuture fetchOperationInfoAsync( + String operationName, + String serviceName, + String operationToken, + FetchOperationInfoOptions options) { + return interceptor + .fetchOperationInfoAsync( + new NexusServiceClientCallsInterceptor.FetchOperationInfoInput( + operationName, serviceName, operationToken, options)) + .thenApply(NexusServiceClientCallsInterceptor.FetchOperationInfoOutput::getResponse); + } + + @Override + public CompletableFuture cancelOperationAsync( + String operationName, + String serviceName, + String operationToken, + CancelOperationOptions options) { + return interceptor + .cancelOperationAsync( + new NexusServiceClientCallsInterceptor.CancelOperationInput( + operationName, serviceName, operationToken, options)) + .thenApply(NexusServiceClientCallsInterceptor.CancelOperationOutput::getResponse); + } + + @Override + public CompletableFuture completeOperationAsync( + String url, CompleteOperationOptions options) { + return interceptor + .completeOperationAsync( + new NexusServiceClientCallsInterceptor.CompleteOperationAsyncInput(url, options)) + .thenApply(NexusServiceClientCallsInterceptor.CompleteOperationOutput::getResponse); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusOperationInboundCallsInterceptor.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusOperationInboundCallsInterceptor.java index ccf867fa40..a81e9f7666 100644 --- a/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusOperationInboundCallsInterceptor.java +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusOperationInboundCallsInterceptor.java @@ -1,6 +1,8 @@ package io.temporal.common.interceptors; import io.nexusrpc.OperationException; +import io.nexusrpc.OperationInfo; +import io.nexusrpc.OperationStillRunningException; import io.nexusrpc.handler.*; import io.temporal.common.Experimental; @@ -76,6 +78,69 @@ public OperationCancelDetails getCancelDetails() { final class CancelOperationOutput {} + final class FetchOperationResultInput { + private final OperationContext operationContext; + private final OperationFetchResultDetails operationFetchResultDetails; + + public FetchOperationResultInput( + OperationContext operationContext, + OperationFetchResultDetails operationFetchResultDetails) { + this.operationContext = operationContext; + this.operationFetchResultDetails = operationFetchResultDetails; + } + + public OperationContext getOperationContext() { + return operationContext; + } + + public OperationFetchResultDetails getOperationFetchResultDetails() { + return operationFetchResultDetails; + } + } + + final class FetchOperationResultOutput { + private final Object result; + + public FetchOperationResultOutput(Object result) { + this.result = result; + } + + public Object getResult() { + return result; + } + } + + final class FetchOperationInfoResponse { + private final OperationInfo operationInfo; + + public FetchOperationInfoResponse(OperationInfo operationInfo) { + this.operationInfo = operationInfo; + } + + public OperationInfo getOperationInfo() { + return operationInfo; + } + } + + final class FetchOperationInfoInput { + private final OperationContext operationContext; + private final OperationFetchInfoDetails operationFetchInfoDetails; + + public FetchOperationInfoInput( + OperationContext operationContext, OperationFetchInfoDetails operationFetchInfoDetails) { + this.operationContext = operationContext; + this.operationFetchInfoDetails = operationFetchInfoDetails; + } + + public OperationContext getOperationContext() { + return operationContext; + } + + public OperationFetchInfoDetails getOperationFetchInfoDetails() { + return operationFetchInfoDetails; + } + } + void init(NexusOperationOutboundCallsInterceptor outboundCalls); /** @@ -87,6 +152,25 @@ final class CancelOperationOutput {} */ StartOperationOutput startOperation(StartOperationInput input) throws OperationException; + /** + * Intercepts a call to fetch a Nexus operation result. + * + * @param input input to the operation result retrieval. + * @throws OperationStillRunningException if the operation is still running. + * @throws OperationException if the operation failed. + * @return result of the operation. + */ + FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationStillRunningException, OperationException; + + /** + * Intercepts a call to fetch information about a Nexus operation. + * + * @param input input to the operation info retrieval. + * @return information about the operation. + */ + FetchOperationInfoResponse fetchOperationInfo(FetchOperationInfoInput input); + /** * Intercepts a call to cancel a Nexus operation. * diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusOperationInboundCallsInterceptorBase.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusOperationInboundCallsInterceptorBase.java index 06dcf94ba3..e3093023d6 100644 --- a/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusOperationInboundCallsInterceptorBase.java +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusOperationInboundCallsInterceptorBase.java @@ -1,6 +1,7 @@ package io.temporal.common.interceptors; import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; import io.temporal.common.Experimental; /** Convenience base class for {@link NexusOperationInboundCallsInterceptor} implementations. */ @@ -23,6 +24,17 @@ public StartOperationOutput startOperation(StartOperationInput input) throws Ope return next.startOperation(input); } + @Override + public FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationStillRunningException, OperationException { + return next.fetchOperationResult(input); + } + + @Override + public FetchOperationInfoResponse fetchOperationInfo(FetchOperationInfoInput input) { + return next.fetchOperationInfo(input); + } + @Override public CancelOperationOutput cancelOperation(CancelOperationInput input) { return next.cancelOperation(input); diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusServiceClientCallsInterceptor.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusServiceClientCallsInterceptor.java new file mode 100644 index 0000000000..98e89d149a --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusServiceClientCallsInterceptor.java @@ -0,0 +1,343 @@ +package io.temporal.common.interceptors; + +import io.nexusrpc.Experimental; +import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; +import io.nexusrpc.client.transport.CancelOperationOptions; +import io.nexusrpc.client.transport.CancelOperationResponse; +import io.nexusrpc.client.transport.CompleteOperationOptions; +import io.nexusrpc.client.transport.CompleteOperationResponse; +import io.nexusrpc.client.transport.FetchOperationInfoOptions; +import io.nexusrpc.client.transport.FetchOperationInfoResponse; +import io.nexusrpc.client.transport.FetchOperationResultOptions; +import io.nexusrpc.client.transport.FetchOperationResultResponse; +import io.nexusrpc.client.transport.StartOperationOptions; +import java.util.concurrent.CompletableFuture; + +/** + * Intercepts calls made by a {@link io.nexusrpc.client.ServiceClient}. + * + *

Prefer extending {@link NexusServiceClientCallsInterceptorBase} and overriding only the + * methods you need instead of implementing this interface directly. {@link + * NexusServiceClientCallsInterceptorBase} provides correct default implementations to all the + * methods of this interface. + */ +@Experimental +public interface NexusServiceClientCallsInterceptor { + + /** + * Intercepts a request to start a Nexus operation. + * + * @param input operation start request + * @return output containing the start response + * @throws OperationException if the operation fails + */ + StartOperationOutput startOperation(StartOperationInput input) throws OperationException; + + /** + * Intercepts a request to fetch the result of a Nexus operation. + * + * @param input operation result request + * @return output containing the operation result + * @throws OperationException if the operation failed + * @throws OperationStillRunningException if the operation is still running + */ + FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationException, OperationStillRunningException; + + /** + * Intercepts a request to fetch information about a Nexus operation. + * + * @param input operation info request + * @return output containing the operation information + */ + FetchOperationInfoOutput fetchOperationInfo(FetchOperationInfoInput input); + + /** + * Intercepts a request to cancel a Nexus operation. + * + * @param input cancellation request + * @return output containing the cancellation result + */ + CancelOperationOutput cancelOperation(CancelOperationInput input); + + /** + * Intercepts a request to complete a Nexus operation. + * + * @param input completion request + * @return output containing the completion result + */ + CompleteOperationOutput completeOperation(CompleteOperationInput input); + + /** + * Intercepts an asynchronous request to start a Nexus operation. + * + * @param input operation start request + * @return future containing the start response + */ + CompletableFuture startOperationAsync(StartOperationInput input); + + /** + * Intercepts an asynchronous request to fetch the result of a Nexus operation. + * + * @param input operation result request + * @return future containing the operation result + */ + CompletableFuture fetchOperationResultAsync( + FetchOperationResultInput input); + + /** + * Intercepts an asynchronous request to fetch information about a Nexus operation. + * + * @param input operation info request + * @return future containing the operation information + */ + CompletableFuture fetchOperationInfoAsync( + FetchOperationInfoInput input); + + /** + * Intercepts an asynchronous request to cancel a Nexus operation. + * + * @param input cancellation request + * @return future containing the cancellation result + */ + CompletableFuture cancelOperationAsync(CancelOperationInput input); + + /** + * Intercepts an asynchronous request to complete a Nexus operation. + * + * @param input completion request + * @return future containing the completion result + */ + CompletableFuture completeOperationAsync( + CompleteOperationAsyncInput input); + + final class StartOperationInput { + private final String operationName; + private final String serviceName; + private final Object input; + private final StartOperationOptions options; + + public StartOperationInput( + String operationName, String serviceName, Object input, StartOperationOptions options) { + this.operationName = operationName; + this.serviceName = serviceName; + this.input = input; + this.options = options; + } + + public String getOperationName() { + return operationName; + } + + public String getServiceName() { + return serviceName; + } + + public Object getInput() { + return input; + } + + public StartOperationOptions getOptions() { + return options; + } + } + + final class FetchOperationResultInput { + private final String operationName; + private final String serviceName; + private final String operationToken; + private final FetchOperationResultOptions options; + + public FetchOperationResultInput( + String operationName, + String serviceName, + String operationToken, + FetchOperationResultOptions options) { + this.operationName = operationName; + this.serviceName = serviceName; + this.operationToken = operationToken; + this.options = options; + } + + public String getOperationName() { + return operationName; + } + + public String getServiceName() { + return serviceName; + } + + public String getOperationToken() { + return operationToken; + } + + public FetchOperationResultOptions getOptions() { + return options; + } + } + + final class FetchOperationInfoInput { + private final String operationName; + private final String serviceName; + private final String operationToken; + private final FetchOperationInfoOptions options; + + public FetchOperationInfoInput( + String operationName, + String serviceName, + String operationToken, + FetchOperationInfoOptions options) { + this.operationName = operationName; + this.serviceName = serviceName; + this.operationToken = operationToken; + this.options = options; + } + + public String getOperationName() { + return operationName; + } + + public String getServiceName() { + return serviceName; + } + + public String getOperationToken() { + return operationToken; + } + + public FetchOperationInfoOptions getOptions() { + return options; + } + } + + final class CancelOperationInput { + private final String operationName; + private final String serviceName; + private final String operationToken; + private final CancelOperationOptions options; + + public CancelOperationInput( + String operationName, + String serviceName, + String operationToken, + CancelOperationOptions options) { + this.operationName = operationName; + this.serviceName = serviceName; + this.operationToken = operationToken; + this.options = options; + } + + public String getOperationName() { + return operationName; + } + + public String getServiceName() { + return serviceName; + } + + public String getOperationToken() { + return operationToken; + } + + public CancelOperationOptions getOptions() { + return options; + } + } + + final class CompleteOperationInput { + private final String url; + private final CompleteOperationOptions options; + + public CompleteOperationInput(String url, CompleteOperationOptions options) { + this.url = url; + this.options = options; + } + + public String getUrl() { + return url; + } + + public CompleteOperationOptions getOptions() { + return options; + } + } + + final class CompleteOperationAsyncInput { + private final String url; + private final CompleteOperationOptions options; + + public CompleteOperationAsyncInput(String url, CompleteOperationOptions options) { + this.url = url; + this.options = options; + } + + public String getUrl() { + return url; + } + + public CompleteOperationOptions getOptions() { + return options; + } + } + + final class StartOperationOutput { + private final io.nexusrpc.client.transport.StartOperationResponse response; + + public StartOperationOutput(io.nexusrpc.client.transport.StartOperationResponse response) { + this.response = response; + } + + public io.nexusrpc.client.transport.StartOperationResponse getResponse() { + return response; + } + } + + final class FetchOperationResultOutput { + private final FetchOperationResultResponse response; + + public FetchOperationResultOutput(FetchOperationResultResponse response) { + this.response = response; + } + + public FetchOperationResultResponse getResponse() { + return response; + } + } + + final class FetchOperationInfoOutput { + private final FetchOperationInfoResponse response; + + public FetchOperationInfoOutput(FetchOperationInfoResponse response) { + this.response = response; + } + + public FetchOperationInfoResponse getResponse() { + return response; + } + } + + final class CancelOperationOutput { + private final CancelOperationResponse response; + + public CancelOperationOutput(CancelOperationResponse response) { + this.response = response; + } + + public CancelOperationResponse getResponse() { + return response; + } + } + + final class CompleteOperationOutput { + private final CompleteOperationResponse response; + + public CompleteOperationOutput(CompleteOperationResponse response) { + this.response = response; + } + + public CompleteOperationResponse getResponse() { + return response; + } + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusServiceClientCallsInterceptorBase.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusServiceClientCallsInterceptorBase.java new file mode 100644 index 0000000000..7353058cae --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/NexusServiceClientCallsInterceptorBase.java @@ -0,0 +1,71 @@ +package io.temporal.common.interceptors; + +import io.nexusrpc.Experimental; +import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; +import java.util.concurrent.CompletableFuture; + +/** Convenience base class for {@link NexusServiceClientCallsInterceptor} implementations. */ +@Experimental +public class NexusServiceClientCallsInterceptorBase implements NexusServiceClientCallsInterceptor { + + private final NexusServiceClientCallsInterceptor next; + + public NexusServiceClientCallsInterceptorBase(NexusServiceClientCallsInterceptor next) { + this.next = next; + } + + @Override + public StartOperationOutput startOperation(StartOperationInput input) throws OperationException { + return next.startOperation(input); + } + + @Override + public FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationException, OperationStillRunningException { + return next.fetchOperationResult(input); + } + + @Override + public FetchOperationInfoOutput fetchOperationInfo(FetchOperationInfoInput input) { + return next.fetchOperationInfo(input); + } + + @Override + public CancelOperationOutput cancelOperation(CancelOperationInput input) { + return next.cancelOperation(input); + } + + @Override + public CompleteOperationOutput completeOperation(CompleteOperationInput input) { + return next.completeOperation(input); + } + + @Override + public CompletableFuture startOperationAsync(StartOperationInput input) { + return next.startOperationAsync(input); + } + + @Override + public CompletableFuture fetchOperationResultAsync( + FetchOperationResultInput input) { + return next.fetchOperationResultAsync(input); + } + + @Override + public CompletableFuture fetchOperationInfoAsync( + FetchOperationInfoInput input) { + return next.fetchOperationInfoAsync(input); + } + + @Override + public CompletableFuture cancelOperationAsync(CancelOperationInput input) { + return next.cancelOperationAsync(input); + } + + @Override + public CompletableFuture completeOperationAsync( + CompleteOperationAsyncInput input) { + return next.completeOperationAsync(input); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowClientInterceptor.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowClientInterceptor.java index 8d8db247c3..bcef237903 100644 --- a/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowClientInterceptor.java +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowClientInterceptor.java @@ -39,6 +39,19 @@ WorkflowStub newUntypedWorkflowStub( ActivityCompletionClient newActivityCompletionClient(ActivityCompletionClient next); + /** + * Called when a Nexus {@link io.nexusrpc.client.ServiceClient} is created through {@link + * io.temporal.client.WorkflowClient#newNexusServiceClient(Class, + * io.temporal.client.TemporalNexusServiceClientOptions)}. Allows decorating the temporalTransport + * used by the service client. + * + * @param next next interceptor in the chain + * @return interceptor that should decorate calls to {@code next} + */ + @Experimental + NexusServiceClientCallsInterceptor nexusServiceClientCallsInterceptor( + NexusServiceClientCallsInterceptor next); + /** * Called once during creation of WorkflowClient to create a chain of Client Workflow Interceptors * diff --git a/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowClientInterceptorBase.java b/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowClientInterceptorBase.java index 832df4c59e..138b8fc0f6 100644 --- a/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowClientInterceptorBase.java +++ b/temporal-sdk/src/main/java/io/temporal/common/interceptors/WorkflowClientInterceptorBase.java @@ -28,6 +28,12 @@ public ActivityCompletionClient newActivityCompletionClient(ActivityCompletionCl return next; } + @Override + public NexusServiceClientCallsInterceptor nexusServiceClientCallsInterceptor( + NexusServiceClientCallsInterceptor next) { + return next; + } + @Override public WorkflowClientCallsInterceptor workflowClientCallsInterceptor( WorkflowClientCallsInterceptor next) { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/client/NexusServiceClientCallsInterceptorRoot.java b/temporal-sdk/src/main/java/io/temporal/internal/client/NexusServiceClientCallsInterceptorRoot.java new file mode 100644 index 0000000000..5dd4052219 --- /dev/null +++ b/temporal-sdk/src/main/java/io/temporal/internal/client/NexusServiceClientCallsInterceptorRoot.java @@ -0,0 +1,543 @@ +package io.temporal.internal.client; + +import static io.temporal.internal.common.NexusFailureUtil.nexusFailureToAPIFailure; +import static io.temporal.internal.common.NexusUtil.exceptionToNexusFailure; + +import com.google.common.base.Strings; +import io.grpc.StatusRuntimeException; +import io.nexusrpc.*; +import io.nexusrpc.client.transport.*; +import io.temporal.api.common.v1.Callback; +import io.temporal.api.nexus.v1.HandlerError; +import io.temporal.api.nexus.v1.TaskDispatchTarget; +import io.temporal.api.nexus.v1.UnsuccessfulOperationError; +import io.temporal.api.workflowservice.v1.*; +import io.temporal.client.TemporalNexusServiceClientOptions; +import io.temporal.client.WorkflowClientOptions; +import io.temporal.common.interceptors.NexusServiceClientCallsInterceptor; +import io.temporal.internal.client.external.GenericWorkflowClient; +import io.temporal.internal.common.NexusFailureUtil; +import io.temporal.internal.common.NexusUtil; +import io.temporal.internal.common.ProtobufTimeUtils; +import java.time.Duration; +import java.time.Instant; +import java.util.UUID; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; + +public final class NexusServiceClientCallsInterceptorRoot + implements NexusServiceClientCallsInterceptor { + private final GenericWorkflowClient client; + private final WorkflowClientOptions clientOptions; + private final TaskDispatchTarget dispatchTarget; + + public NexusServiceClientCallsInterceptorRoot( + GenericWorkflowClient client, + WorkflowClientOptions clientOptions, + TemporalNexusServiceClientOptions serviceClientOptions) { + this.client = client; + this.clientOptions = clientOptions; + this.dispatchTarget = + TaskDispatchTarget.newBuilder().setEndpoint(serviceClientOptions.getEndpoint()).build(); + } + + private OperationState deserializeOperationState(String state) { + switch (state) { + case "running": + return OperationState.RUNNING; + case "succeeded": + return OperationState.SUCCEEDED; + case "failed": + return OperationState.FAILED; + case "canceled": + return OperationState.CANCELED; + default: + throw new IllegalArgumentException("Unknown operation state: " + state); + } + } + + private StartNexusOperationRequest createStartOperationRequest( + String operationName, String serviceName, Object input, StartOperationOptions options) { + StartNexusOperationRequest.Builder request = + StartNexusOperationRequest.newBuilder() + .setIdentity(clientOptions.getIdentity()) + .setNamespace(clientOptions.getNamespace()) + .setTarget(dispatchTarget) + .setOperation(operationName) + .setService(serviceName) + .putAllCallbackHeader(options.getCallbackHeaders()) + .putAllHeader(options.getHeaders()); + + if (Strings.isNullOrEmpty(options.getRequestId())) { + request.setRequestId(UUID.randomUUID().toString()); + } else { + request.setRequestId(options.getRequestId()); + } + + if (!Strings.isNullOrEmpty(options.getCallbackURL())) { + request.setCallback(options.getCallbackURL()); + } + + clientOptions.getDataConverter().toPayload(input).ifPresent(request::setPayload); + + options.getInboundLinks().stream() + .map( + link -> + io.temporal.api.nexus.v1.Link.newBuilder() + .setType(link.getType()) + .setUrl(link.getUri().toString()) + .build()) + .forEach(request::addLinks); + return request.build(); + } + + private StartOperationResponse createStartOperationResponse(StartNexusOperationResponse response) + throws OperationException { + if (response.hasSyncSuccess()) { + StartNexusOperationResponse.Sync syncResult = response.getSyncSuccess(); + return StartOperationResponse.newBuilder() + .setResult( + Serializer.Content.newBuilder().setData(syncResult.getResult().toByteArray()).build()) + .build(); + } else if (response.hasAsyncSuccess()) { + StartNexusOperationResponse.Async asyncResult = response.getAsyncSuccess(); + return StartOperationResponse.newBuilder() + .setAsyncOperationToken(asyncResult.getOperationToken()) + .build(); + } else if (response.hasUnsuccessful()) { + StartNexusOperationResponse.Unsuccessful unsuccessful = response.getUnsuccessful(); + UnsuccessfulOperationError error = unsuccessful.getOperationError(); + Throwable cause = + clientOptions + .getDataConverter() + .failureToException( + NexusFailureUtil.nexusFailureToAPIFailure(error.getFailure(), false)); + if (error.getOperationState().equals("canceled")) { + throw OperationException.canceled(cause); + } else { + throw OperationException.failure(cause); + } + } else if (response.hasHandlerError()) { + HandlerError error = response.getHandlerError(); + throw clientOptions + .getDataConverter() + .failureToException(NexusFailureUtil.handlerErrorToFailure(error)); + } else { + throw new IllegalStateException("Unknown response from startNexusCall: " + response); + } + } + + @Override + public StartOperationOutput startOperation(StartOperationInput input) throws OperationException { + try { + StartNexusOperationResponse response = + client.startNexusOperation( + createStartOperationRequest( + input.getOperationName(), + input.getServiceName(), + input.getInput(), + input.getOptions())); + return new StartOperationOutput(createStartOperationResponse(response)); + } catch (StatusRuntimeException sre) { + throw NexusUtil.grpcExceptionToHandlerException(sre); + } + } + + private GetNexusOperationResultRequest createGetNexusOperationResultRequest( + String operationName, + String serviceName, + String operationToken, + FetchOperationResultOptions options) { + GetNexusOperationResultRequest.Builder request = + GetNexusOperationResultRequest.newBuilder() + .setIdentity(clientOptions.getIdentity()) + .setNamespace(clientOptions.getNamespace()) + .setOperation(operationName) + .setService(serviceName) + .setTarget(dispatchTarget) + .setOperationToken(operationToken) + .setWait(ProtobufTimeUtils.toProtoDuration(options.getTimeout())); + + options.getHeaders().forEach(request::putHeader); + + return request.build(); + } + + private FetchOperationResultResponse createGetOperationResultResponse( + GetNexusOperationResultResponse response) + throws OperationException, OperationStillRunningException { + if (response.hasSuccessful()) { + GetNexusOperationResultResponse.Successful successful = response.getSuccessful(); + return FetchOperationResultResponse.newBuilder() + .setResult( + Serializer.Content.newBuilder().setData(successful.getResult().toByteArray()).build()) + .build(); + } else if (response.hasUnsuccessful()) { + GetNexusOperationResultResponse.Unsuccessful unsuccessful = response.getUnsuccessful(); + UnsuccessfulOperationError error = unsuccessful.getOperationError(); + Throwable cause = + clientOptions + .getDataConverter() + .failureToException(nexusFailureToAPIFailure(error.getFailure(), false)); + if (error.getOperationState().equals("canceled")) { + throw OperationException.canceled(cause); + } else { + throw OperationException.failure(cause); + } + } else if (response.hasHandlerError()) { + HandlerError error = response.getHandlerError(); + throw clientOptions + .getDataConverter() + .failureToException(NexusFailureUtil.handlerErrorToFailure(error)); + } else if (response.hasStillRunning()) { + throw new OperationStillRunningException(); + } else { + throw new IllegalStateException("Unknown response from startNexusCall: " + response); + } + } + + @Override + public FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationException, OperationStillRunningException { + Instant startTime = Instant.now(); + while (true) { + try { + try { + GetNexusOperationResultResponse response = + client.getNexusOperationResult( + createGetNexusOperationResultRequest( + input.getOperationName(), + input.getServiceName(), + input.getOperationToken(), + input.getOptions())); + return new FetchOperationResultOutput(createGetOperationResultResponse(response)); + } catch (StatusRuntimeException sre) { + throw NexusUtil.grpcExceptionToHandlerException(sre); + } + } catch (OperationStillRunningException e) { + // If the operation is still running, we wait for the specified timeout before retrying. + if (Instant.now().isAfter(startTime.plus(input.getOptions().getTimeout()))) { + throw e; // Timeout reached, rethrow the exception. + } + // TODO implement exponential backoff or other retry strategies. + } + } + } + + private GetNexusOperationInfoRequest createGetNexusOperationInfoRequest( + String operationName, + String serviceName, + String operationToken, + FetchOperationInfoOptions options) { + GetNexusOperationInfoRequest.Builder request = + GetNexusOperationInfoRequest.newBuilder() + .setIdentity(clientOptions.getIdentity()) + .setNamespace(clientOptions.getNamespace()) + .setTarget(dispatchTarget) + .setOperation(operationName) + .setService(serviceName) + .setOperationToken(operationToken); + + options.getHeaders().forEach(request::putHeader); + + return request.build(); + } + + private FetchOperationInfoResponse createGetOperationInfoResponse( + GetNexusOperationInfoResponse response) { + if (response.hasHandlerError()) { + HandlerError error = response.getHandlerError(); + throw clientOptions + .getDataConverter() + .failureToException(NexusFailureUtil.handlerErrorToFailure(error)); + } + + return FetchOperationInfoResponse.newBuilder() + .setOperationInfo( + OperationInfo.newBuilder() + .setToken(response.getInfo().getToken()) + .setState(deserializeOperationState(response.getInfo().getState())) + .build()) + .build(); + } + + @Override + public FetchOperationInfoOutput fetchOperationInfo(FetchOperationInfoInput input) { + try { + return new FetchOperationInfoOutput( + createGetOperationInfoResponse( + client.getNexusOperationInfo( + createGetNexusOperationInfoRequest( + input.getOperationName(), + input.getServiceName(), + input.getOperationToken(), + input.getOptions())))); + } catch (StatusRuntimeException sre) { + throw NexusUtil.grpcExceptionToHandlerException(sre); + } + } + + private RequestCancelNexusOperationRequest createRequestCancelNexusOperationRequest( + String operationName, + String serviceName, + String operationToken, + CancelOperationOptions options) { + RequestCancelNexusOperationRequest.Builder request = + RequestCancelNexusOperationRequest.newBuilder() + .setIdentity(clientOptions.getIdentity()) + .setNamespace(clientOptions.getNamespace()) + .setTarget(dispatchTarget) + .setOperation(operationName) + .setService(serviceName) + .setOperationToken(operationToken); + + options.getHeaders().forEach(request::putHeader); + + return request.build(); + } + + private CancelOperationResponse createRequestCancelNexusOperationResponse( + RequestCancelNexusOperationResponse response) { + if (response.hasHandlerError()) { + HandlerError error = response.getHandlerError(); + throw clientOptions + .getDataConverter() + .failureToException(NexusFailureUtil.handlerErrorToFailure(error)); + } + + return new CancelOperationResponse(); + } + + @Override + public CancelOperationOutput cancelOperation(CancelOperationInput input) { + try { + return new CancelOperationOutput( + createRequestCancelNexusOperationResponse( + client.requestCancelNexusOperation( + createRequestCancelNexusOperationRequest( + input.getOperationName(), + input.getServiceName(), + input.getOperationToken(), + input.getOptions())))); + } catch (StatusRuntimeException sre) { + throw NexusUtil.grpcExceptionToHandlerException(sre); + } + } + + private CompleteNexusOperationRequest createCompleteNexusOperationRequest( + String url, CompleteOperationOptions options) { + Callback.Nexus.Builder callbackBuilder = Callback.Nexus.newBuilder().setUrl(url); + if (options.getHeaders() != null) { + callbackBuilder.putAllHeader(options.getHeaders()); + } + + CompleteNexusOperationRequest.Builder request = + CompleteNexusOperationRequest.newBuilder() + .setIdentity(clientOptions.getIdentity()) + .setNamespace(clientOptions.getNamespace()) + .setCallback(callbackBuilder.build()); + + request.setRequestId(UUID.randomUUID().toString()); + + if (options.getStartTime() != null) { + request.setStartedTime(ProtobufTimeUtils.toProtoTimestamp(options.getStartTime())); + } + + if (options.getLinks() != null) { + options.getLinks().stream() + .map( + link -> + io.temporal.api.nexus.v1.Link.newBuilder() + .setType(link.getType()) + .setUrl(link.getUri().toString()) + .build()) + .forEach(request::addLinks); + } + + if (options.getResult() != null) { + request.setResult(clientOptions.getDataConverter().toPayload(options.getResult()).get()); + } else if (options.getError() != null) { + OperationException operationException = options.getError(); + request.setOperationError( + UnsuccessfulOperationError.newBuilder() + .setOperationState(options.getError().getState().toString().toLowerCase()) + .setFailure( + exceptionToNexusFailure(operationException, clientOptions.getDataConverter())) + .build()); + } + return request.build(); + } + + private CompleteOperationResponse createCompleteOperationResponse( + CompleteNexusOperationResponse response) { + if (response.hasHandlerError()) { + HandlerError error = response.getHandlerError(); + throw clientOptions + .getDataConverter() + .failureToException(NexusFailureUtil.handlerErrorToFailure(error)); + } + + return new CompleteOperationResponse(); + } + + @Override + public CompleteOperationOutput completeOperation(CompleteOperationInput input) { + try { + return new CompleteOperationOutput( + createCompleteOperationResponse( + client.completeNexusOperation( + createCompleteNexusOperationRequest(input.getUrl(), input.getOptions())))); + } catch (StatusRuntimeException sre) { + throw NexusUtil.grpcExceptionToHandlerException(sre); + } + } + + @Override + public CompletableFuture startOperationAsync(StartOperationInput input) { + return client + .startNexusOperationAsync( + createStartOperationRequest( + input.getOperationName(), + input.getServiceName(), + input.getInput(), + input.getOptions())) + .thenApply( + response -> { + try { + return createStartOperationResponse(response); + } catch (OperationException e) { + throw new CompletionException(e); + } + }) + .thenApply(StartOperationOutput::new) + .exceptionally( + ex -> { + if (ex.getCause() instanceof StatusRuntimeException) { + throw NexusUtil.grpcExceptionToHandlerException( + (StatusRuntimeException) ex.getCause()); + } else if (ex instanceof RuntimeException) { + throw (RuntimeException) ex; + } else { + throw new CompletionException(ex); + } + }); + } + + private CompletableFuture waitForResult( + Instant startTime, + Duration timeout, + GetNexusOperationResultRequest request, + GetNexusOperationResultResponse response) { + try { + return CompletableFuture.completedFuture(createGetOperationResultResponse(response)); + } catch (OperationException e) { + throw new CompletionException(e); + } catch (OperationStillRunningException e) { + if (Instant.now().isAfter(startTime.plus(timeout))) { + throw new CompletionException(e); + } + return client + .getNexusOperationResultAsync(request) + .thenComposeAsync(r -> waitForResult(startTime, timeout, request, r)); + } + } + + @Override + public CompletableFuture fetchOperationResultAsync( + FetchOperationResultInput input) { + Instant startTime = Instant.now(); + GetNexusOperationResultRequest request = + createGetNexusOperationResultRequest( + input.getOperationName(), + input.getServiceName(), + input.getOperationToken(), + input.getOptions()); + CompletableFuture response = + client.getNexusOperationResultAsync(request); + return response + .thenComposeAsync( + r -> waitForResult(startTime, input.getOptions().getTimeout(), request, r)) + .thenApply(FetchOperationResultOutput::new) + .exceptionally( + ex -> { + if (ex.getCause() instanceof StatusRuntimeException) { + throw NexusUtil.grpcExceptionToHandlerException( + (StatusRuntimeException) ex.getCause()); + } else if (ex instanceof RuntimeException) { + throw (RuntimeException) ex; + } else { + throw new CompletionException(ex); + } + }); + } + + @Override + public CompletableFuture fetchOperationInfoAsync( + FetchOperationInfoInput input) { + return client + .getNexusOperationInfoAsync( + createGetNexusOperationInfoRequest( + input.getOperationName(), + input.getServiceName(), + input.getOperationToken(), + input.getOptions())) + .thenApply(this::createGetOperationInfoResponse) + .thenApply(FetchOperationInfoOutput::new) + .exceptionally( + ex -> { + if (ex.getCause() instanceof StatusRuntimeException) { + throw NexusUtil.grpcExceptionToHandlerException( + (StatusRuntimeException) ex.getCause()); + } else if (ex instanceof RuntimeException) { + throw (RuntimeException) ex; + } else { + throw new CompletionException(ex); + } + }); + } + + @Override + public CompletableFuture cancelOperationAsync(CancelOperationInput input) { + return client + .requestCancelNexusOperationAsync( + createRequestCancelNexusOperationRequest( + input.getOperationName(), + input.getServiceName(), + input.getOperationToken(), + input.getOptions())) + .thenApply(this::createRequestCancelNexusOperationResponse) + .thenApply(CancelOperationOutput::new) + .exceptionally( + ex -> { + if (ex.getCause() instanceof StatusRuntimeException) { + throw NexusUtil.grpcExceptionToHandlerException( + (StatusRuntimeException) ex.getCause()); + } else if (ex instanceof RuntimeException) { + throw (RuntimeException) ex; + } else { + throw new CompletionException(ex); + } + }); + } + + @Override + public CompletableFuture completeOperationAsync( + CompleteOperationAsyncInput input) { + return client + .completeNexusOperationAsync( + createCompleteNexusOperationRequest(input.getUrl(), input.getOptions())) + .thenApply(this::createCompleteOperationResponse) + .thenApply(CompleteOperationOutput::new) + .exceptionally( + ex -> { + if (ex.getCause() instanceof StatusRuntimeException) { + throw NexusUtil.grpcExceptionToHandlerException( + (StatusRuntimeException) ex.getCause()); + } else if (ex instanceof RuntimeException) { + throw (RuntimeException) ex; + } else { + throw new CompletionException(ex); + } + }); + } +} diff --git a/temporal-sdk/src/main/java/io/temporal/internal/client/external/GenericWorkflowClient.java b/temporal-sdk/src/main/java/io/temporal/internal/client/external/GenericWorkflowClient.java index 96d1db3b20..f544396641 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/client/external/GenericWorkflowClient.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/client/external/GenericWorkflowClient.java @@ -75,4 +75,40 @@ GetWorkerBuildIdCompatibilityResponse getWorkerBuildIdCompatability( @Experimental GetWorkerTaskReachabilityResponse GetWorkerTaskReachability(GetWorkerTaskReachabilityRequest req); + + @Experimental + GetNexusOperationInfoResponse getNexusOperationInfo(GetNexusOperationInfoRequest request); + + @Experimental + StartNexusOperationResponse startNexusOperation(StartNexusOperationRequest request); + + @Experimental + RequestCancelNexusOperationResponse requestCancelNexusOperation( + RequestCancelNexusOperationRequest request); + + @Experimental + GetNexusOperationResultResponse getNexusOperationResult(GetNexusOperationResultRequest request); + + @Experimental + CompleteNexusOperationResponse completeNexusOperation(CompleteNexusOperationRequest request); + + @Experimental + CompletableFuture getNexusOperationInfoAsync( + GetNexusOperationInfoRequest request); + + @Experimental + CompletableFuture startNexusOperationAsync( + StartNexusOperationRequest request); + + @Experimental + CompletableFuture requestCancelNexusOperationAsync( + RequestCancelNexusOperationRequest request); + + @Experimental + CompletableFuture getNexusOperationResultAsync( + GetNexusOperationResultRequest request); + + @Experimental + CompletableFuture completeNexusOperationAsync( + CompleteNexusOperationRequest request); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/client/external/GenericWorkflowClientImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/client/external/GenericWorkflowClientImpl.java index 044894d74b..7ae0cb18c1 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/client/external/GenericWorkflowClientImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/client/external/GenericWorkflowClientImpl.java @@ -60,6 +60,13 @@ private static Map tagsForStartWorkflow(StartWorkflowExecutionRe .build(); } + private static Map tagsForNexusOperations(String service, String operation) { + return new ImmutableMap.Builder(2) + .put(MetricsTag.NEXUS_SERVICE, service) + .put(MetricsTag.OPERATION_NAME, operation) + .build(); + } + @Override public void signal(SignalWorkflowExecutionRequest request) { Map tags = @@ -423,4 +430,149 @@ public ExecuteMultiOperationResponse executeMultiOperation( .executeMultiOperation(req), grpcRetryerOptions); } + + @Override + public GetNexusOperationInfoResponse getNexusOperationInfo(GetNexusOperationInfoRequest request) { + return grpcRetryer.retryWithResult( + () -> + service + .blockingStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, metricsScope) + .getNexusOperationInfo(request), + grpcRetryerOptions); + } + + @Override + public StartNexusOperationResponse startNexusOperation(StartNexusOperationRequest request) { + Map tags = tagsForNexusOperations(request.getService(), request.getOperation()); + Scope scope = metricsScope.tagged(tags); + return grpcRetryer.retryWithResult( + () -> + service + .blockingStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, scope) + .startNexusOperation(request), + grpcRetryerOptions); + } + + @Override + public RequestCancelNexusOperationResponse requestCancelNexusOperation( + RequestCancelNexusOperationRequest request) { + Map tags = tagsForNexusOperations(request.getService(), request.getOperation()); + Scope scope = metricsScope.tagged(tags); + return grpcRetryer.retryWithResult( + () -> + service + .blockingStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, scope) + .requestCancelNexusOperation(request), + grpcRetryerOptions); + } + + @Override + public GetNexusOperationResultResponse getNexusOperationResult( + GetNexusOperationResultRequest request) { + Map tags = tagsForNexusOperations(request.getService(), request.getOperation()); + Scope scope = metricsScope.tagged(tags); + // Deadline deadline = + // Deadline.after(request.getWait().getSeconds() * 1000, TimeUnit.MILLISECONDS); + return grpcRetryer.retryWithResult( + () -> + service + .blockingStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, scope) + // .withDeadline(deadline) + .getNexusOperationResult(request), + grpcRetryerOptions); + } + + @Override + public CompleteNexusOperationResponse completeNexusOperation( + CompleteNexusOperationRequest request) { + return grpcRetryer.retryWithResult( + () -> + service + .blockingStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, metricsScope) + .completeNexusOperation(request), + grpcRetryerOptions); + } + + @Override + public CompletableFuture getNexusOperationInfoAsync( + GetNexusOperationInfoRequest request) { + Map tags = tagsForNexusOperations(request.getService(), request.getOperation()); + Scope scope = metricsScope.tagged(tags); + return grpcRetryer.retryWithResultAsync( + asyncThrottlerExecutor, + () -> + toCompletableFuture( + service + .futureStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, scope) + .getNexusOperationInfo(request)), + grpcRetryerOptions); + } + + @Override + public CompletableFuture startNexusOperationAsync( + StartNexusOperationRequest request) { + Map tags = tagsForNexusOperations(request.getService(), request.getOperation()); + Scope scope = metricsScope.tagged(tags); + return grpcRetryer.retryWithResultAsync( + asyncThrottlerExecutor, + () -> + toCompletableFuture( + service + .futureStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, scope) + .startNexusOperation(request)), + grpcRetryerOptions); + } + + @Override + public CompletableFuture requestCancelNexusOperationAsync( + RequestCancelNexusOperationRequest request) { + Map tags = tagsForNexusOperations(request.getService(), request.getOperation()); + Scope scope = metricsScope.tagged(tags); + return grpcRetryer.retryWithResultAsync( + asyncThrottlerExecutor, + () -> + toCompletableFuture( + service + .futureStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, scope) + .requestCancelNexusOperation(request)), + grpcRetryerOptions); + } + + @Override + public CompletableFuture getNexusOperationResultAsync( + GetNexusOperationResultRequest request) { + Map tags = tagsForNexusOperations(request.getService(), request.getOperation()); + Scope scope = metricsScope.tagged(tags); + return grpcRetryer.retryWithResultAsync( + asyncThrottlerExecutor, + () -> + toCompletableFuture( + service + .futureStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, scope) + .getNexusOperationResult(request)), + grpcRetryerOptions); + } + + @Override + public CompletableFuture completeNexusOperationAsync( + CompleteNexusOperationRequest request) { + return grpcRetryer.retryWithResultAsync( + asyncThrottlerExecutor, + () -> + toCompletableFuture( + service + .futureStub() + .withOption(METRICS_TAGS_CALL_OPTIONS_KEY, metricsScope) + .completeNexusOperation(request)), + grpcRetryerOptions); + } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/common/InternalUtils.java b/temporal-sdk/src/main/java/io/temporal/internal/common/InternalUtils.java index a494f2e83b..9e646ffc2d 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/common/InternalUtils.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/common/InternalUtils.java @@ -2,6 +2,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.google.common.base.Defaults; +import com.google.common.base.Strings; import io.nexusrpc.Header; import io.nexusrpc.handler.HandlerException; import io.nexusrpc.handler.ServiceImplInstance; @@ -88,23 +89,7 @@ public static NexusWorkflowStarter createNexusBoundStub( HandlerException.ErrorType.BAD_REQUEST, new IllegalArgumentException("failed to generate workflow operation token", e)); } - // Add the Nexus operation ID to the headers if it is not already present to support fabricating - // a NexusOperationStarted event if the completion is received before the response to a - // StartOperation request. - Map headers = - request.getCallbackHeaders().entrySet().stream() - .collect( - Collectors.toMap( - (k) -> k.getKey().toLowerCase(), - Map.Entry::getValue, - (a, b) -> a, - () -> new TreeMap<>(String.CASE_INSENSITIVE_ORDER))); - if (!headers.containsKey(Header.OPERATION_ID)) { - headers.put(Header.OPERATION_ID.toLowerCase(), operationToken); - } - if (!headers.containsKey(Header.OPERATION_TOKEN)) { - headers.put(Header.OPERATION_TOKEN.toLowerCase(), operationToken); - } + List links = request.getLinks() == null ? null @@ -127,21 +112,43 @@ public static NexusWorkflowStarter createNexusBoundStub( }) .filter(Objects::nonNull) .collect(Collectors.toList()); - Callback.Builder cbBuilder = - Callback.newBuilder() - .setNexus( - Callback.Nexus.newBuilder() - .setUrl(request.getCallbackUrl()) - .putAllHeader(headers) - .build()); - if (links != null) { - cbBuilder.addAllLinks(links); - } + WorkflowOptions.Builder nexusWorkflowOptions = - WorkflowOptions.newBuilder(options) - .setRequestId(request.getRequestId()) - .setCompletionCallbacks(Collections.singletonList(cbBuilder.build())) - .setLinks(links); + WorkflowOptions.newBuilder(options).setRequestId(request.getRequestId()).setLinks(links); + + if (!Strings.isNullOrEmpty(request.getCallbackUrl())) { + // Add the Nexus operation ID to the headers if it is not already present to support + // fabricating + // a NexusOperationStarted event if the completion is received before the response to a + // StartOperation request. + Map callbackHeaders = + request.getCallbackHeaders().entrySet().stream() + .collect( + Collectors.toMap( + (k) -> k.getKey().toLowerCase(), + Map.Entry::getValue, + (a, b) -> a, + () -> new TreeMap<>(String.CASE_INSENSITIVE_ORDER))); + if (!callbackHeaders.containsKey(Header.OPERATION_ID)) { + callbackHeaders.put(Header.OPERATION_ID.toLowerCase(), operationToken); + } + if (!callbackHeaders.containsKey(Header.OPERATION_TOKEN)) { + callbackHeaders.put(Header.OPERATION_TOKEN.toLowerCase(), operationToken); + } + + Callback.Builder cbBuilder = + Callback.newBuilder() + .setNexus( + Callback.Nexus.newBuilder() + .setUrl(request.getCallbackUrl()) + .putAllHeader(callbackHeaders) + .build()); + if (links != null) { + cbBuilder.addAllLinks(links); + } + nexusWorkflowOptions.setCompletionCallbacks(Collections.singletonList(cbBuilder.build())); + } + if (options.getTaskQueue() == null) { nexusWorkflowOptions.setTaskQueue(request.getTaskQueue()); } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/common/NexusUtil.java b/temporal-sdk/src/main/java/io/temporal/internal/common/NexusUtil.java index 88d8e1d03b..50f3c37d42 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/common/NexusUtil.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/common/NexusUtil.java @@ -3,7 +3,10 @@ import com.google.protobuf.ByteString; import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.util.JsonFormat; +import io.grpc.Status; +import io.grpc.StatusRuntimeException; import io.nexusrpc.Link; +import io.nexusrpc.handler.HandlerException; import io.temporal.api.nexus.v1.Failure; import io.temporal.common.converter.DataConverter; import java.net.URI; @@ -65,5 +68,41 @@ public static Failure exceptionToNexusFailure(Throwable exception, DataConverter .build(); } + public static HandlerException grpcExceptionToHandlerException(StatusRuntimeException sre) { + Status status = sre.getStatus(); + switch (status.getCode()) { + case ALREADY_EXISTS: + case INVALID_ARGUMENT: + case FAILED_PRECONDITION: + case OUT_OF_RANGE: + return new HandlerException(HandlerException.ErrorType.BAD_REQUEST, sre); + case ABORTED: + case UNAVAILABLE: + return new HandlerException(HandlerException.ErrorType.UNAVAILABLE, sre); + case CANCELLED: + return new HandlerException(HandlerException.ErrorType.INTERNAL, sre); + case DATA_LOSS: + case INTERNAL: + case UNKNOWN: + return new HandlerException(HandlerException.ErrorType.INTERNAL, sre); + case UNAUTHENTICATED: + return new HandlerException(HandlerException.ErrorType.UNAUTHENTICATED, sre); + case PERMISSION_DENIED: + return new HandlerException(HandlerException.ErrorType.UNAUTHORIZED, sre); + case NOT_FOUND: + return new HandlerException(HandlerException.ErrorType.NOT_FOUND, sre); + case RESOURCE_EXHAUSTED: + return new HandlerException(HandlerException.ErrorType.RESOURCE_EXHAUSTED, sre); + case UNIMPLEMENTED: + return new HandlerException(HandlerException.ErrorType.NOT_IMPLEMENTED, sre); + case DEADLINE_EXCEEDED: + return new HandlerException(HandlerException.ErrorType.UPSTREAM_TIMEOUT, sre); + default: + return new HandlerException( + HandlerException.ErrorType.INTERNAL, + new IllegalStateException("Unexpected gRPC status code: " + status.getCode(), sre)); + } + } + private NexusUtil() {} } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/nexus/InternalNexusOperationContext.java b/temporal-sdk/src/main/java/io/temporal/internal/nexus/InternalNexusOperationContext.java index b32e4de2c1..4f775bc05b 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/nexus/InternalNexusOperationContext.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/nexus/InternalNexusOperationContext.java @@ -1,6 +1,7 @@ package io.temporal.internal.nexus; import com.uber.m3.tally.Scope; +import io.nexusrpc.OperationDefinition; import io.temporal.api.common.v1.Link; import io.temporal.client.WorkflowClient; import io.temporal.common.interceptors.NexusOperationOutboundCallsInterceptor; @@ -9,15 +10,21 @@ public class InternalNexusOperationContext { private final String namespace; private final String taskQueue; + private final OperationDefinition operationDefinition; private final Scope metricScope; private final WorkflowClient client; NexusOperationOutboundCallsInterceptor outboundCalls; Link startWorkflowResponseLink; public InternalNexusOperationContext( - String namespace, String taskQueue, Scope metricScope, WorkflowClient client) { + String namespace, + String taskQueue, + OperationDefinition operationDefinition, + Scope metricScope, + WorkflowClient client) { this.namespace = namespace; this.taskQueue = taskQueue; + this.operationDefinition = operationDefinition; this.metricScope = metricScope; this.client = client; } @@ -38,6 +45,10 @@ public String getNamespace() { return namespace; } + public OperationDefinition getOperationDefinition() { + return operationDefinition; + } + public void setOutboundInterceptor(NexusOperationOutboundCallsInterceptor outboundCalls) { this.outboundCalls = outboundCalls; } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusTaskHandlerImpl.java b/temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusTaskHandlerImpl.java index c7d8c062c5..998d15dc79 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusTaskHandlerImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/nexus/NexusTaskHandlerImpl.java @@ -5,11 +5,14 @@ import com.uber.m3.tally.Scope; import io.nexusrpc.Header; +import io.nexusrpc.OperationDefinition; import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; import io.nexusrpc.handler.*; import io.temporal.api.common.v1.Payload; import io.temporal.api.enums.v1.NexusHandlerErrorRetryBehavior; import io.temporal.api.nexus.v1.*; +import io.temporal.api.workflowservice.v1.PollNexusTaskQueueResponseOrBuilder; import io.temporal.client.WorkflowClient; import io.temporal.client.WorkflowException; import io.temporal.common.converter.DataConverter; @@ -17,6 +20,7 @@ import io.temporal.failure.ApplicationFailure; import io.temporal.internal.common.InternalUtils; import io.temporal.internal.common.NexusUtil; +import io.temporal.internal.common.ProtobufTimeUtils; import io.temporal.internal.worker.NexusTask; import io.temporal.internal.worker.NexusTaskHandler; import io.temporal.internal.worker.ShutdownManager; @@ -72,6 +76,34 @@ public boolean start() { return true; } + private String getNexusTaskService(PollNexusTaskQueueResponseOrBuilder pollResponse) { + Request request = pollResponse.getRequest(); + if (request.hasStartOperation()) { + return request.getStartOperation().getService(); + } else if (request.hasCancelOperation()) { + return request.getCancelOperation().getService(); + } else if (request.hasGetOperationInfo()) { + return request.getGetOperationInfo().getService(); + } else if (request.hasGetOperationResult()) { + return request.getGetOperationResult().getService(); + } + return ""; + } + + private String getNexusTaskOperation(PollNexusTaskQueueResponseOrBuilder pollResponse) { + Request request = pollResponse.getRequest(); + if (request.hasStartOperation()) { + return request.getStartOperation().getOperation(); + } else if (request.hasCancelOperation()) { + return request.getCancelOperation().getOperation(); + } else if (request.hasGetOperationInfo()) { + return request.getGetOperationInfo().getOperation(); + } else if (request.hasGetOperationResult()) { + return request.getGetOperationResult().getOperation(); + } + return ""; + } + @Override public Result handle(NexusTask task, Scope metricsScope) throws TimeoutException { Request request = task.getResponse().getRequest(); @@ -109,8 +141,20 @@ public Result handle(NexusTask task, Scope metricsScope) throws TimeoutException } } + // TODO: refactor + OperationDefinition operationDefinition = null; + String service = getNexusTaskService(task.getResponse()); + String operation = getNexusTaskOperation(task.getResponse()); + if (serviceHandler.getInstances().containsKey(service)) { + ServiceImplInstance serviceImpl = serviceHandler.getInstances().get(service); + if (serviceImpl.getDefinition().getOperations().containsKey(operation)) { + operationDefinition = serviceImpl.getDefinition().getOperations().get(operation); + } + } + CurrentNexusOperationContext.set( - new InternalNexusOperationContext(namespace, taskQueue, metricsScope, client)); + new InternalNexusOperationContext( + namespace, taskQueue, operationDefinition, metricsScope, client)); switch (request.getVariantCase()) { case START_OPERATION: @@ -121,6 +165,14 @@ public Result handle(NexusTask task, Scope metricsScope) throws TimeoutException CancelOperationResponse cancelResponse = handleCancelledOperation(ctx, request.getCancelOperation()); return new Result(Response.newBuilder().setCancelOperation(cancelResponse).build()); + case GET_OPERATION_INFO: + GetOperationInfoResponse getInfoResponse = + handleGetOperationInfo(ctx, request.getGetOperationInfo()); + return new Result(Response.newBuilder().setGetOperationInfo(getInfoResponse).build()); + case GET_OPERATION_RESULT: + GetOperationResultResponse getResultResponse = + handleGetOperationResult(ctx, request.getGetOperationResult()); + return new Result(Response.newBuilder().setGetOperationResult(getResultResponse).build()); default: throw new HandlerException( HandlerException.ErrorType.NOT_IMPLEMENTED, @@ -152,6 +204,92 @@ public Result handle(NexusTask task, Scope metricsScope) throws TimeoutException } } + private GetOperationResultResponse handleGetOperationResult( + OperationContext.Builder ctx, GetOperationResultRequest task) { + ctx.setService(task.getService()).setOperation(task.getOperation()); + + OperationFetchResultDetails operationFetchResultDetails = + OperationFetchResultDetails.newBuilder() + .setOperationToken(task.getOperationToken()) + .setTimeout(ProtobufTimeUtils.toJavaDuration(task.getWait())) + .build(); + try { + try { + HandlerResultContent result = + serviceHandler.fetchOperationResult(ctx.build(), operationFetchResultDetails); + return GetOperationResultResponse.newBuilder() + .setSuccessful( + GetOperationResultResponse.Successful.newBuilder() + .setResult(Payload.parseFrom(result.getDataBytes())) + .build()) + .build(); + } catch (OperationStillRunningException e) { + return GetOperationResultResponse.newBuilder() + .setStillRunning(GetOperationResultResponse.StillRunning.newBuilder().build()) + .build(); + } catch (Throwable e) { + Throwable failure = CheckedExceptionWrapper.unwrap(e); + log.warn( + "Nexus fetch operation result failure. Service={}, Operation={}", + task.getService(), + task.getOperation(), + failure); + // Re-throw the original exception to handle it in the caller + throw e; + } + } catch (OperationException e) { + return GetOperationResultResponse.newBuilder() + .setUnsuccessful( + GetOperationResultResponse.Unsuccessful.newBuilder() + .setOperationError( + UnsuccessfulOperationError.newBuilder() + .setOperationState(e.getState().toString().toLowerCase()) + .setFailure(exceptionToNexusFailure(e.getCause(), dataConverter)) + .build())) + .build(); + } catch (Throwable failure) { + convertKnownFailures(failure); + } + throw new HandlerException( + HandlerException.ErrorType.INTERNAL, + new RuntimeException("Failed to handle get operation result"), + HandlerException.RetryBehavior.NON_RETRYABLE); + } + + private GetOperationInfoResponse handleGetOperationInfo( + OperationContext.Builder ctx, GetOperationInfoRequest task) { + ctx.setService(task.getService()).setOperation(task.getOperation()); + + OperationFetchInfoDetails details = + OperationFetchInfoDetails.newBuilder().setOperationToken(task.getOperationToken()).build(); + try { + try { + io.nexusrpc.OperationInfo info = serviceHandler.fetchOperationInfo(ctx.build(), details); + return GetOperationInfoResponse.newBuilder() + .setInfo( + OperationInfo.newBuilder() + .setState(info.getState().toString().toLowerCase()) + .setToken(info.getToken()) + .build()) + .build(); + } catch (Throwable e) { + Throwable failure = CheckedExceptionWrapper.unwrap(e); + log.warn( + "Nexus cancel operation failure. Service={}, Operation={}", + task.getService(), + task.getOperation(), + failure); + // Re-throw the original exception to handle it in the caller + throw e; + } + } catch (Throwable failure) { + convertKnownFailures(failure); + } + return GetOperationInfoResponse.newBuilder() + .setInfo(OperationInfo.newBuilder().build()) + .build(); + } + private NexusHandlerErrorRetryBehavior mapRetryBehavior( HandlerException.RetryBehavior retryBehavior) { switch (retryBehavior) { @@ -263,7 +401,6 @@ private StartOperationResponse handleStartOperation( HandlerInputContent.Builder input = HandlerInputContent.newBuilder().setDataStream(task.getPayload().toByteString().newInput()); - StartOperationResponse.Builder startResponseBuilder = StartOperationResponse.newBuilder(); OperationContext context = ctx.build(); try { diff --git a/temporal-sdk/src/main/java/io/temporal/internal/nexus/PayloadSerializer.java b/temporal-sdk/src/main/java/io/temporal/internal/nexus/PayloadSerializer.java index e28f4a49ba..8b77e64a05 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/nexus/PayloadSerializer.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/nexus/PayloadSerializer.java @@ -14,10 +14,10 @@ * io.nexusrpc.Serializer.Content} objects by using the {@link DataConverter} to convert objects to * and from {@link Payload} objects. */ -class PayloadSerializer implements Serializer { +public class PayloadSerializer implements Serializer { DataConverter dataConverter; - PayloadSerializer(DataConverter dataConverter) { + public PayloadSerializer(DataConverter dataConverter) { this.dataConverter = dataConverter; } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/nexus/RootNexusOperationInboundCallsInterceptor.java b/temporal-sdk/src/main/java/io/temporal/internal/nexus/RootNexusOperationInboundCallsInterceptor.java index b4c24100b0..ba91a9a358 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/nexus/RootNexusOperationInboundCallsInterceptor.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/nexus/RootNexusOperationInboundCallsInterceptor.java @@ -1,6 +1,7 @@ package io.temporal.internal.nexus; import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; import io.nexusrpc.handler.OperationHandler; import io.nexusrpc.handler.OperationStartResult; import io.temporal.common.interceptors.NexusOperationInboundCallsInterceptor; @@ -8,10 +9,10 @@ public class RootNexusOperationInboundCallsInterceptor implements NexusOperationInboundCallsInterceptor { - private final OperationHandler operationInterceptor; + private final OperationHandler rootHandler; - RootNexusOperationInboundCallsInterceptor(OperationHandler operationInterceptor) { - this.operationInterceptor = operationInterceptor; + RootNexusOperationInboundCallsInterceptor(OperationHandler rootHandler) { + this.rootHandler = rootHandler; } @Override @@ -22,14 +23,28 @@ public void init(NexusOperationOutboundCallsInterceptor outboundCalls) { @Override public StartOperationOutput startOperation(StartOperationInput input) throws OperationException { OperationStartResult result = - operationInterceptor.start( - input.getOperationContext(), input.getStartDetails(), input.getInput()); + rootHandler.start(input.getOperationContext(), input.getStartDetails(), input.getInput()); return new StartOperationOutput(result); } + @Override + public FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationStillRunningException, OperationException { + Object result = + rootHandler.fetchResult( + input.getOperationContext(), input.getOperationFetchResultDetails()); + return new FetchOperationResultOutput(result); + } + + @Override + public FetchOperationInfoResponse fetchOperationInfo(FetchOperationInfoInput input) { + return new FetchOperationInfoResponse( + rootHandler.fetchInfo(input.getOperationContext(), input.getOperationFetchInfoDetails())); + } + @Override public CancelOperationOutput cancelOperation(CancelOperationInput input) { - operationInterceptor.cancel(input.getOperationContext(), input.getCancelDetails()); + rootHandler.cancel(input.getOperationContext(), input.getCancelDetails()); return new CancelOperationOutput(); } } diff --git a/temporal-sdk/src/main/java/io/temporal/internal/nexus/TemporalInterceptorMiddleware.java b/temporal-sdk/src/main/java/io/temporal/internal/nexus/TemporalInterceptorMiddleware.java index 68ecfc9208..f7beef2be6 100644 --- a/temporal-sdk/src/main/java/io/temporal/internal/nexus/TemporalInterceptorMiddleware.java +++ b/temporal-sdk/src/main/java/io/temporal/internal/nexus/TemporalInterceptorMiddleware.java @@ -2,6 +2,7 @@ import io.nexusrpc.OperationException; import io.nexusrpc.OperationInfo; +import io.nexusrpc.OperationStillRunningException; import io.nexusrpc.handler.*; import io.temporal.common.interceptors.NexusOperationInboundCallsInterceptor; import io.temporal.common.interceptors.WorkerInterceptor; @@ -51,15 +52,21 @@ public OperationStartResult start( @Override public Object fetchResult( OperationContext operationContext, OperationFetchResultDetails operationFetchResultDetails) - throws OperationException { - throw new UnsupportedOperationException("Not implemented"); + throws OperationException, OperationStillRunningException { + return next.fetchOperationResult( + new NexusOperationInboundCallsInterceptor.FetchOperationResultInput( + operationContext, operationFetchResultDetails)) + .getResult(); } @Override public OperationInfo fetchInfo( OperationContext operationContext, OperationFetchInfoDetails operationFetchInfoDetails) throws HandlerException { - throw new UnsupportedOperationException("Not implemented"); + return next.fetchOperationInfo( + new NexusOperationInboundCallsInterceptor.FetchOperationInfoInput( + operationContext, operationFetchInfoDetails)) + .getOperationInfo(); } @Override diff --git a/temporal-sdk/src/main/java/io/temporal/nexus/WorkflowRunOperationImpl.java b/temporal-sdk/src/main/java/io/temporal/nexus/WorkflowRunOperationImpl.java index a26bcd2a62..841d344123 100644 --- a/temporal-sdk/src/main/java/io/temporal/nexus/WorkflowRunOperationImpl.java +++ b/temporal-sdk/src/main/java/io/temporal/nexus/WorkflowRunOperationImpl.java @@ -3,19 +3,30 @@ import static io.temporal.internal.common.LinkConverter.workflowEventToNexusLink; import static io.temporal.internal.common.NexusUtil.nexusProtoLinkToLink; +import io.nexusrpc.OperationException; import io.nexusrpc.OperationInfo; +import io.nexusrpc.OperationState; +import io.nexusrpc.OperationStillRunningException; import io.nexusrpc.handler.*; import io.nexusrpc.handler.OperationHandler; import io.temporal.api.common.v1.Link; import io.temporal.api.common.v1.WorkflowExecution; import io.temporal.api.enums.v1.EventType; import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowException; +import io.temporal.client.WorkflowExecutionDescription; +import io.temporal.client.WorkflowNotFoundException; +import io.temporal.failure.CanceledFailure; import io.temporal.internal.client.NexusStartWorkflowRequest; import io.temporal.internal.client.NexusStartWorkflowResponse; import io.temporal.internal.nexus.CurrentNexusOperationContext; import io.temporal.internal.nexus.InternalNexusOperationContext; import io.temporal.internal.nexus.OperationTokenUtil; +import java.lang.reflect.Type; import java.net.URISyntaxException; +import java.time.temporal.ChronoUnit; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; class WorkflowRunOperationImpl implements OperationHandler { private final WorkflowHandleFactory handleFactory; @@ -29,7 +40,7 @@ public OperationStartResult start( OperationContext ctx, OperationStartDetails operationStartDetails, T input) { InternalNexusOperationContext nexusCtx = CurrentNexusOperationContext.get(); - WorkflowHandle handle = handleFactory.apply(ctx, operationStartDetails, input); + WorkflowHandle handle = handleFactory.apply(ctx, operationStartDetails, input); NexusStartWorkflowRequest nexusRequest = new NexusStartWorkflowRequest( @@ -78,32 +89,98 @@ public OperationStartResult start( } @Override + @SuppressWarnings("unchecked") public R fetchResult( - OperationContext operationContext, OperationFetchResultDetails operationFetchResultDetails) { - throw new UnsupportedOperationException("Not implemented"); + OperationContext operationContext, OperationFetchResultDetails operationFetchResultDetails) + throws OperationStillRunningException, OperationException { + String workflowId = extractWorkflowIdFromToken(operationFetchResultDetails.getOperationToken()); + WorkflowClient client = CurrentNexusOperationContext.get().getWorkflowClient(); + + Type outputType = CurrentNexusOperationContext.get().getOperationDefinition().getOutputType(); + try { + return (R) + client + .newUntypedWorkflowStub(workflowId) + .getResult( + operationFetchResultDetails.getTimeout().get(ChronoUnit.SECONDS), + TimeUnit.SECONDS, + com.google.common.reflect.TypeToken.of(outputType).getRawType(), + outputType); + } catch (TimeoutException te) { + throw new OperationStillRunningException(); + } catch (WorkflowNotFoundException e) { + throw new HandlerException(HandlerException.ErrorType.NOT_FOUND, e); + } catch (WorkflowException we) { + if (we.getCause() instanceof CanceledFailure) { + throw OperationException.canceled(we.getCause()); + } else { + throw OperationException.failure(we.getCause()); + } + } } @Override public OperationInfo fetchInfo( OperationContext operationContext, OperationFetchInfoDetails operationFetchInfoDetails) { - throw new UnsupportedOperationException("Not implemented"); + String workflowId = extractWorkflowIdFromToken(operationFetchInfoDetails.getOperationToken()); + try { + WorkflowClient client = CurrentNexusOperationContext.get().getWorkflowClient(); + WorkflowExecutionDescription description = + client.newUntypedWorkflowStub(workflowId).describe(); + OperationState state = null; + switch (description.getStatus()) { + case WORKFLOW_EXECUTION_STATUS_RUNNING: + case WORKFLOW_EXECUTION_STATUS_CONTINUED_AS_NEW: + // WORKFLOW_EXECUTION_STATUS_CONTINUED_AS_NEW really shouldn't be possible here, + // but we handle it gracefully by treating it as RUNNING. + state = OperationState.RUNNING; + break; + case WORKFLOW_EXECUTION_STATUS_COMPLETED: + state = OperationState.SUCCEEDED; + break; + case WORKFLOW_EXECUTION_STATUS_CANCELED: + state = OperationState.CANCELED; + break; + case WORKFLOW_EXECUTION_STATUS_FAILED: + case WORKFLOW_EXECUTION_STATUS_TIMED_OUT: + case WORKFLOW_EXECUTION_STATUS_TERMINATED: + state = OperationState.FAILED; + break; + default: + throw new HandlerException( + HandlerException.ErrorType.INTERNAL, + new IllegalArgumentException("Unknown workflow status: " + description.getStatus())); + } + return OperationInfo.newBuilder() + .setState(state) + .setToken(operationFetchInfoDetails.getOperationToken()) + .build(); + } catch (WorkflowNotFoundException e) { + throw new HandlerException(HandlerException.ErrorType.NOT_FOUND, e); + } } @Override public void cancel( OperationContext operationContext, OperationCancelDetails operationCancelDetails) { + try { + String workflowId = extractWorkflowIdFromToken(operationCancelDetails.getOperationToken()); + WorkflowClient client = CurrentNexusOperationContext.get().getWorkflowClient(); + client.newUntypedWorkflowStub(workflowId).cancel(); + } catch (WorkflowNotFoundException e) { + throw new HandlerException(HandlerException.ErrorType.NOT_FOUND, e); + } + } + + private String extractWorkflowIdFromToken(String operationToken) { String workflowId; try { - workflowId = - OperationTokenUtil.loadWorkflowIdFromOperationToken( - operationCancelDetails.getOperationToken()); + workflowId = OperationTokenUtil.loadWorkflowIdFromOperationToken(operationToken); } catch (IllegalArgumentException e) { throw new HandlerException( HandlerException.ErrorType.BAD_REQUEST, new IllegalArgumentException("failed to parse operation token", e)); } - - WorkflowClient client = CurrentNexusOperationContext.get().getWorkflowClient(); - client.newUntypedWorkflowStub(workflowId).cancel(); + return workflowId; } } diff --git a/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientCallsInterceptorTest.java b/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientCallsInterceptorTest.java new file mode 100644 index 0000000000..81c3ff6766 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientCallsInterceptorTest.java @@ -0,0 +1,65 @@ +package io.temporal.client; + +import io.nexusrpc.client.ServiceClient; +import io.nexusrpc.handler.OperationContext; +import io.nexusrpc.handler.OperationHandler; +import io.nexusrpc.handler.OperationImpl; +import io.nexusrpc.handler.OperationStartDetails; +import io.nexusrpc.handler.ServiceImpl; +import io.temporal.common.interceptors.NexusServiceClientCallsInterceptor; +import io.temporal.common.interceptors.NexusServiceClientCallsInterceptorBase; +import io.temporal.common.interceptors.WorkflowClientInterceptorBase; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.workflow.shared.TestNexusServices; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; + +public class NexusServiceClientCallsInterceptorTest { + private final AtomicInteger intercepted = new AtomicInteger(); + + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setNexusServiceImplementation(new TestNexusServiceImpl()) + .setWorkflowClientOptions( + WorkflowClientOptions.newBuilder() + .setInterceptors( + new WorkflowClientInterceptorBase() { + @Override + public NexusServiceClientCallsInterceptor + nexusServiceClientCallsInterceptor( + NexusServiceClientCallsInterceptor next) { + return new NexusServiceClientCallsInterceptorBase(next) { + @Override + public StartOperationOutput startOperation(StartOperationInput input) + throws io.nexusrpc.OperationException { + intercepted.incrementAndGet(); + return super.startOperation(input); + } + }; + } + }) + .validateAndBuildWithDefaults()) + .build(); + + @Test + public void interceptorIsInvoked() throws Exception { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + String result = + serviceClient.executeOperation(TestNexusServices.TestNexusService1::operation, "World"); + Assert.assertEquals("Hello World", result); + Assert.assertEquals(1, intercepted.get()); + } + + @ServiceImpl(service = TestNexusServices.TestNexusService1.class) + public static class TestNexusServiceImpl { + @OperationImpl + public OperationHandler operation() { + return OperationHandler.sync( + (OperationContext ctx, OperationStartDetails details, String param) -> "Hello " + param); + } + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientSyncOperationTest.java b/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientSyncOperationTest.java new file mode 100644 index 0000000000..7af3840fcc --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientSyncOperationTest.java @@ -0,0 +1,192 @@ +package io.temporal.client; + +import io.nexusrpc.OperationException; +import io.nexusrpc.OperationState; +import io.nexusrpc.OperationStillRunningException; +import io.nexusrpc.client.ServiceClient; +import io.nexusrpc.client.StartOperationResponse; +import io.nexusrpc.handler.*; +import io.temporal.failure.ApplicationFailure; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.workflow.shared.TestNexusServices; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; + +public class NexusServiceClientSyncOperationTest { + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setNexusServiceImplementation(new TestNexusServiceImpl()) + .build(); + + @Test + public void executeSyncOperation() + throws OperationException, + OperationStillRunningException, + ExecutionException, + InterruptedException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + String result = + serviceClient.executeOperation(TestNexusServices.TestNexusService1::operation, "World"); + Assert.assertEquals("Hello World", result); + + String asyncResult = + serviceClient + .executeOperationAsync(TestNexusServices.TestNexusService1::operation, "World Async") + .get(); + Assert.assertEquals("Hello World Async", asyncResult); + } + + @Test + public void executeSyncOperationFail() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + OperationException oe = + Assert.assertThrows( + OperationException.class, + () -> + serviceClient.executeOperation( + TestNexusServices.TestNexusService1::operation, "fail")); + Assert.assertEquals(OperationState.FAILED, oe.getState()); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + + CompletableFuture result = + serviceClient.executeOperationAsync(TestNexusServices.TestNexusService1::operation, "fail"); + ExecutionException ee = Assert.assertThrows(ExecutionException.class, result::get); + Assert.assertTrue(ee.getCause() instanceof OperationException); + oe = (OperationException) ee.getCause(); + Assert.assertEquals(OperationState.FAILED, oe.getState()); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + } + + @Test + public void executeSyncOperationHandlerError() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + HandlerException he = + Assert.assertThrows( + HandlerException.class, + () -> + serviceClient.executeOperation( + TestNexusServices.TestNexusService1::operation, "handlerError")); + System.out.println(he.getMessage()); + } + + @Test + public void executeSyncOperationCancel() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + OperationException oe = + Assert.assertThrows( + OperationException.class, + () -> + serviceClient.executeOperation( + TestNexusServices.TestNexusService1::operation, "cancel")); + Assert.assertEquals(OperationState.CANCELED, oe.getState()); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + + CompletableFuture result = + serviceClient.executeOperationAsync( + TestNexusServices.TestNexusService1::operation, "cancel"); + ExecutionException ee = Assert.assertThrows(ExecutionException.class, result::get); + Assert.assertTrue(ee.getCause() instanceof OperationException); + oe = (OperationException) ee.getCause(); + Assert.assertEquals(OperationState.CANCELED, oe.getState()); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + } + + @Test + public void startSyncOperation() throws OperationException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + StartOperationResponse result = + serviceClient.startOperation(TestNexusServices.TestNexusService1::operation, "World"); + Assert.assertTrue(result instanceof StartOperationResponse.Sync); + StartOperationResponse.Sync syncResult = (StartOperationResponse.Sync) result; + Assert.assertEquals("Hello World", syncResult.getResult()); + + CompletableFuture> asyncResult = + serviceClient.startOperationAsync( + TestNexusServices.TestNexusService1::operation, "World Async"); + Assert.assertTrue(asyncResult.join() instanceof StartOperationResponse.Sync); + syncResult = (StartOperationResponse.Sync) asyncResult.join(); + Assert.assertEquals("Hello World Async", syncResult.getResult()); + } + + @Test + public void startSyncOperationFail() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + OperationException oe = + Assert.assertThrows( + OperationException.class, + () -> + serviceClient.startOperation( + TestNexusServices.TestNexusService1::operation, "fail")); + Assert.assertEquals(OperationState.FAILED, oe.getState()); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + + CompletableFuture> asyncResult = + serviceClient.startOperationAsync(TestNexusServices.TestNexusService1::operation, "fail"); + ExecutionException ee = Assert.assertThrows(ExecutionException.class, asyncResult::get); + Assert.assertTrue(ee.getCause() instanceof OperationException); + oe = (OperationException) ee.getCause(); + Assert.assertEquals(OperationState.FAILED, oe.getState()); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + } + + @Test + public void startSyncOperationCancel() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + OperationException oe = + Assert.assertThrows( + OperationException.class, + () -> + serviceClient.startOperation( + TestNexusServices.TestNexusService1::operation, "cancel")); + Assert.assertEquals(OperationState.CANCELED, oe.getState()); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + + CompletableFuture> asyncResult = + serviceClient.startOperationAsync(TestNexusServices.TestNexusService1::operation, "cancel"); + ExecutionException ee = Assert.assertThrows(ExecutionException.class, asyncResult::get); + Assert.assertTrue(ee.getCause() instanceof OperationException); + oe = (OperationException) ee.getCause(); + Assert.assertEquals(OperationState.CANCELED, oe.getState()); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + } + + @ServiceImpl(service = TestNexusServices.TestNexusService1.class) + public static class TestNexusServiceImpl { + @OperationImpl + public OperationHandler operation() { + return OperationHandler.sync( + (OperationContext context, OperationStartDetails details, String param) -> { + if (Objects.equals(param, "fail")) { + throw OperationException.failure(new IllegalArgumentException("fail")); + } else if (Objects.equals(param, "cancel")) { + throw OperationException.canceled(new IllegalArgumentException("cancel")); + } else if (Objects.equals(param, "handlerError")) { + throw new HandlerException( + HandlerException.ErrorType.RESOURCE_EXHAUSTED, + new IllegalArgumentException("handlerError"), + HandlerException.RetryBehavior.RETRYABLE); + } + return "Hello " + param; + }); + } + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientWorkflowOperationTest.java b/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientWorkflowOperationTest.java new file mode 100644 index 0000000000..fffc661881 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/client/NexusServiceClientWorkflowOperationTest.java @@ -0,0 +1,353 @@ +package io.temporal.client; + +import com.fasterxml.jackson.core.JsonProcessingException; +import io.nexusrpc.OperationException; +import io.nexusrpc.OperationInfo; +import io.nexusrpc.OperationState; +import io.nexusrpc.OperationStillRunningException; +import io.nexusrpc.client.*; +import io.nexusrpc.handler.*; +import io.temporal.failure.ApplicationFailure; +import io.temporal.internal.nexus.OperationTokenUtil; +import io.temporal.nexus.Nexus; +import io.temporal.nexus.WorkflowRunOperation; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.workflow.Workflow; +import io.temporal.workflow.shared.TestNexusServices; +import io.temporal.workflow.shared.TestWorkflows; +import java.time.Duration; +import java.util.Objects; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; + +public class NexusServiceClientWorkflowOperationTest { + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setWorkflowTypes(TestNexus.class) + .setNexusServiceImplementation(new TestNexusServiceImpl()) + .build(); + + @Test(timeout = 5000000) + public void executeWorkflowOperationSuccess() + throws OperationStillRunningException, OperationException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + Assert.assertEquals( + "Hello World", + serviceClient.executeOperation( + TestNexusServices.TestNexusService1::operation, + "World", + ExecuteOperationOptions.newBuilder().setTimeout(Duration.ofMinutes(2)).build())); + + CompletableFuture resultAsync = + serviceClient.executeOperationAsync( + TestNexusServices.TestNexusService1::operation, + "World Async", + ExecuteOperationOptions.newBuilder().setTimeout(Duration.ofMinutes(2)).build()); + Assert.assertEquals("Hello World Async", resultAsync.join()); + } + + @Test + public void executeWorkflowOperationFail() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + OperationException oe = + Assert.assertThrows( + OperationException.class, + () -> + serviceClient.executeOperation( + TestNexusServices.TestNexusService1::operation, + "fail", + ExecuteOperationOptions.newBuilder() + .setTimeout(Duration.ofSeconds(10)) + .build())); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + + CompletableFuture resultAsync = + serviceClient.executeOperationAsync( + TestNexusServices.TestNexusService1::operation, + "fail", + ExecuteOperationOptions.newBuilder().setTimeout(Duration.ofSeconds(10)).build()); + CompletionException ce = + Assert.assertThrows(CompletionException.class, () -> resultAsync.join()); + Assert.assertTrue(ce.getCause() instanceof OperationException); + OperationException asyncOperationException = (OperationException) ce.getCause(); + Assert.assertTrue(asyncOperationException.getCause() instanceof ApplicationFailure); + } + + @Test + public void executeWorkflowOperationStillRunning() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + Assert.assertThrows( + OperationStillRunningException.class, + () -> + serviceClient.executeOperation( + TestNexusServices.TestNexusService1::operation, + "World", + ExecuteOperationOptions.newBuilder().setTimeout(Duration.ofSeconds(1)).build())); + + CompletableFuture resultAsync = + serviceClient.executeOperationAsync( + TestNexusServices.TestNexusService1::operation, + "World Async", + ExecuteOperationOptions.newBuilder().setTimeout(Duration.ofSeconds(1)).build()); + CompletionException ce = + Assert.assertThrows(CompletionException.class, () -> resultAsync.join()); + Assert.assertTrue(ce.getCause() instanceof OperationStillRunningException); + } + + @Test + public void createHandle() + throws OperationException, OperationStillRunningException, InterruptedException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + StartOperationResponse startResult = + serviceClient.startOperation(TestNexusServices.TestNexusService1::operation, "World"); + Assert.assertTrue(startResult instanceof StartOperationResponse.Async); + OperationHandle handle = + ((StartOperationResponse.Async) startResult).getHandle(); + OperationHandle newHandler = + serviceClient.newHandle( + TestNexusServices.TestNexusService1::operation, handle.getOperationToken()); + Thread.sleep(6000); // Wait for the operation to complete + String operationResult = + newHandler.fetchResult( + FetchOperationResultOptions.newBuilder().setTimeout(Duration.ofSeconds(1)).build()); + Assert.assertEquals("Hello World", operationResult); + } + + @Test + public void createHandleUntyped() + throws OperationException, OperationStillRunningException, InterruptedException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + StartOperationResponse startResult = + serviceClient.startOperation(TestNexusServices.TestNexusService1::operation, "World"); + Assert.assertTrue(startResult instanceof StartOperationResponse.Async); + OperationHandle handle = + ((StartOperationResponse.Async) startResult).getHandle(); + + OperationHandle newHandler = + serviceClient.newHandle( + TestNexusServices.TestNexusService1::operation, handle.getOperationToken()); + Thread.sleep(6000); // Wait for the operation to complete + String operationResult = + newHandler.fetchResult( + FetchOperationResultOptions.newBuilder().setTimeout(Duration.ofSeconds(1)).build()); + Assert.assertEquals("Hello World", operationResult); + } + + @Test + public void createInvalidHandle() throws JsonProcessingException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + OperationHandle badHandle = + serviceClient.newHandle(TestNexusServices.TestNexusService1::operation, "BAD_TOKEN"); + HandlerException he = + Assert.assertThrows( + HandlerException.class, + () -> + badHandle.fetchResult( + FetchOperationResultOptions.newBuilder() + .setTimeout(Duration.ofSeconds(1)) + .build())); + System.out.println("Expected exception: " + he.getMessage()); + + String token = OperationTokenUtil.generateWorkflowRunOperationToken("workflowId", "namespace"); + OperationHandle missingHandle = + serviceClient.newHandle(TestNexusServices.TestNexusService1::operation, token); + he = + Assert.assertThrows( + HandlerException.class, + () -> + missingHandle.fetchResult( + FetchOperationResultOptions.newBuilder() + .setTimeout(Duration.ofSeconds(5)) + .build())); + Assert.assertEquals(HandlerException.ErrorType.NOT_FOUND, he.getErrorType()); + + he = Assert.assertThrows(HandlerException.class, missingHandle::cancel); + Assert.assertEquals(HandlerException.ErrorType.NOT_FOUND, he.getErrorType()); + + he = Assert.assertThrows(HandlerException.class, missingHandle::fetchInfo); + Assert.assertEquals(HandlerException.ErrorType.NOT_FOUND, he.getErrorType()); + } + + @Test + public void startAsyncOperation() throws OperationException, OperationStillRunningException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + StartOperationResponse startResult = + serviceClient.startOperation(TestNexusServices.TestNexusService1::operation, "World"); + Assert.assertTrue(startResult instanceof StartOperationResponse.Async); + OperationHandle handle = + ((StartOperationResponse.Async) startResult).getHandle(); + Assert.assertThrows( + OperationStillRunningException.class, + () -> + handle.fetchResult( + FetchOperationResultOptions.newBuilder() + .setTimeout(Duration.ofSeconds(1)) + .build())); + Assert.assertEquals(OperationState.RUNNING, handle.fetchInfo().getState()); + // Thread.sleep(6000); // Wait for the operation to complete + String operationResult = + handle.fetchResult( + FetchOperationResultOptions.newBuilder().setTimeout(Duration.ofSeconds(10)).build()); + Assert.assertEquals("Hello World", operationResult); + Assert.assertEquals(OperationState.SUCCEEDED, handle.fetchInfo().getState()); + } + + @Test + public void cancelAsyncOperation() throws OperationException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + StartOperationResponse startResult = + serviceClient.startOperation(TestNexusServices.TestNexusService1::operation, "World"); + Assert.assertTrue(startResult instanceof StartOperationResponse.Async); + OperationHandle handle = + ((StartOperationResponse.Async) startResult).getHandle(); + handle.cancel(); + // Verify that we can call cancel again without issues + handle.cancel(); + OperationException oe = + Assert.assertThrows( + OperationException.class, + () -> + handle.fetchResult( + FetchOperationResultOptions.newBuilder() + .setTimeout(Duration.ofSeconds(5)) + .build())); + Assert.assertEquals(OperationState.CANCELED, oe.getState()); + // Verify that we can call cancel after the operation is already completed + handle.cancel(); + } + + @Test + public void cancelAsyncOperationAsync() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + StartOperationResponse startResult = + serviceClient + .startOperationAsync(TestNexusServices.TestNexusService1::operation, "World") + .join(); + Assert.assertTrue(startResult instanceof StartOperationResponse.Async); + OperationHandle handle = + ((StartOperationResponse.Async) startResult).getHandle(); + Assert.assertNotNull(handle); + handle.cancelAsync().join(); + CompletionException ce = + Assert.assertThrows( + CompletionException.class, + () -> + handle + .fetchResultAsync( + FetchOperationResultOptions.newBuilder() + .setTimeout(Duration.ofSeconds(5)) + .build()) + .join()); + Assert.assertTrue(ce.getCause() instanceof OperationException); + OperationException oe = (OperationException) ce.getCause(); + Assert.assertEquals(OperationState.CANCELED, oe.getState()); + } + + @Test + public void startAsyncOperationAsync() { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + StartOperationResponse startResult = + serviceClient + .startOperationAsync(TestNexusServices.TestNexusService1::operation, "World") + .join(); + Assert.assertTrue(startResult instanceof StartOperationResponse.Async); + OperationHandle handle = + ((StartOperationResponse.Async) startResult).getHandle(); + Assert.assertEquals(OperationState.RUNNING, handle.fetchInfoAsync().join().getState()); + CompletionException ce = + Assert.assertThrows( + CompletionException.class, + () -> + handle + .fetchResultAsync( + FetchOperationResultOptions.newBuilder() + .setTimeout(Duration.ofSeconds(1)) + .build()) + .join()); + Assert.assertTrue(ce.getCause() instanceof OperationStillRunningException); + // Thread.sleep(6000); // Wait for the operation to complete + String operationResult = + handle + .fetchResultAsync( + FetchOperationResultOptions.newBuilder().setTimeout(Duration.ofSeconds(10)).build()) + .join(); + Assert.assertEquals("Hello World", operationResult); + Assert.assertEquals(OperationState.SUCCEEDED, handle.fetchInfoAsync().join().getState()); + } + + @Test + public void startWorkflowOperationFail() throws OperationException { + ServiceClient serviceClient = + testWorkflowRule.newNexusServiceClient(TestNexusServices.TestNexusService1.class); + + StartOperationResponse startResult = + serviceClient.startOperation(TestNexusServices.TestNexusService1::operation, "fail"); + Assert.assertTrue(startResult instanceof StartOperationResponse.Async); + OperationHandle handle = + ((StartOperationResponse.Async) startResult).getHandle(); + OperationException oe = + Assert.assertThrows( + OperationException.class, + () -> + handle.fetchResult( + FetchOperationResultOptions.newBuilder() + .setTimeout(Duration.ofSeconds(5)) + .build())); + Assert.assertTrue(oe.getCause() instanceof ApplicationFailure); + + OperationInfo info = handle.fetchInfo(); + Assert.assertEquals(OperationState.FAILED, info.getState()); + } + + public static class TestNexus implements TestWorkflows.TestWorkflow1 { + + @Override + public String execute(String arg) { + if (Objects.equals(arg, "fail")) { + throw ApplicationFailure.newNonRetryableFailure("fail workflow", "TestError"); + } + Workflow.sleep(Duration.ofSeconds(5)); + return "Hello " + arg; + } + } + + @ServiceImpl(service = TestNexusServices.TestNexusService1.class) + public static class TestNexusServiceImpl { + @OperationImpl + public OperationHandler operation() { + return WorkflowRunOperation.fromWorkflowMethod( + (context, details, input) -> + Nexus.getOperationContext() + .getWorkflowClient() + .newWorkflowStub( + TestWorkflows.TestWorkflow1.class, + WorkflowOptions.newBuilder() + .setWorkflowId(details.getRequestId()) + .build()) + ::execute); + } + } +} diff --git a/temporal-sdk/src/test/java/io/temporal/workflow/nexus/NexusOperationClientTest.java b/temporal-sdk/src/test/java/io/temporal/workflow/nexus/NexusOperationClientTest.java new file mode 100644 index 0000000000..dcccfa3542 --- /dev/null +++ b/temporal-sdk/src/test/java/io/temporal/workflow/nexus/NexusOperationClientTest.java @@ -0,0 +1,121 @@ +package io.temporal.workflow.nexus; + +import io.nexusrpc.OperationException; +import io.nexusrpc.OperationInfo; +import io.nexusrpc.OperationStillRunningException; +import io.nexusrpc.client.CompleteOperationOptions; +import io.nexusrpc.client.CompletionClient; +import io.nexusrpc.handler.*; +import io.temporal.client.WorkflowClient; +import io.temporal.client.WorkflowException; +import io.temporal.testing.internal.SDKTestWorkflowRule; +import io.temporal.workflow.Workflow; +import io.temporal.workflow.shared.TestNexusServices; +import io.temporal.workflow.shared.TestWorkflows; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import org.junit.Assert; +import org.junit.Rule; +import org.junit.Test; + +public class NexusOperationClientTest { + static String url; + static Map headers; + static CountDownLatch latch = new CountDownLatch(1); + + @Rule + public SDKTestWorkflowRule testWorkflowRule = + SDKTestWorkflowRule.newBuilder() + .setWorkflowTypes(TestNexus.class) + .setNexusServiceImplementation(new TestNexusServiceImpl()) + .build(); + + @Test + public void testCompletionClientSucceed() throws InterruptedException { + TestWorkflows.TestWorkflow1 workflowStub = + testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflows.TestWorkflow1.class); + // Start the workflow + WorkflowClient.start(workflowStub::execute, testWorkflowRule.getTaskQueue()); + // Wait for the operation to start + latch.await(); + // Complete the operation + CompletionClient nexusCompletionClient = + testWorkflowRule.getWorkflowClient().newNexusCompletionClient(); + Thread.sleep(100); + nexusCompletionClient.succeedOperation( + url, "result", CompleteOperationOptions.newBuilder().setHeaders(headers).build()); + // Wait for the workflow to complete + String result = workflowStub.execute(testWorkflowRule.getTaskQueue()); + Assert.assertEquals("result", result); + } + + @Test + public void testCompletionClientFail() throws InterruptedException { + TestWorkflows.TestWorkflow1 workflowStub = + testWorkflowRule.newWorkflowStubTimeoutOptions(TestWorkflows.TestWorkflow1.class); + // Start the workflow + WorkflowClient.start(workflowStub::execute, testWorkflowRule.getTaskQueue()); + // Wait for the operation to start + latch.await(); + // Complete the operation + CompletionClient nexusCompletionClient = + testWorkflowRule.getWorkflowClient().newNexusCompletionClient(); + Thread.sleep(100); + nexusCompletionClient.failOperation( + url, + OperationException.failure(new Exception("test failure")), + CompleteOperationOptions.newBuilder().setHeaders(headers).build()); + // Wait for the workflow to complete with a failure + WorkflowException we = + Assert.assertThrows( + WorkflowException.class, () -> workflowStub.execute(testWorkflowRule.getTaskQueue())); + System.out.println(we); + } + + public static class TestNexus implements TestWorkflows.TestWorkflow1 { + @Override + public String execute(String input) { + // Try to call with the typed stub + TestNexusServices.TestNexusService1 serviceStub = + Workflow.newNexusServiceStub(TestNexusServices.TestNexusService1.class); + return serviceStub.operation(input); + } + } + + @ServiceImpl(service = TestNexusServices.TestNexusService1.class) + public static class TestNexusServiceImpl { + @OperationImpl + public OperationHandler operation() { + return new TestAsyncHandler(); + } + } + + public static class TestAsyncHandler implements OperationHandler { + + @Override + public OperationStartResult start( + OperationContext context, OperationStartDetails details, String param) + throws OperationException, HandlerException { + url = details.getCallbackUrl(); + headers = details.getCallbackHeaders(); + latch.countDown(); + return OperationStartResult.async("token-" + param); + } + + @Override + public String fetchResult(OperationContext context, OperationFetchResultDetails details) + throws OperationStillRunningException, OperationException, HandlerException { + return ""; + } + + @Override + public OperationInfo fetchInfo(OperationContext context, OperationFetchInfoDetails details) + throws HandlerException { + return null; + } + + @Override + public void cancel(OperationContext context, OperationCancelDetails details) + throws HandlerException {} + } +} diff --git a/temporal-serviceclient/src/main/java/io/temporal/internal/common/NexusFailureUtil.java b/temporal-serviceclient/src/main/java/io/temporal/internal/common/NexusFailureUtil.java new file mode 100644 index 0000000000..277fceefbe --- /dev/null +++ b/temporal-serviceclient/src/main/java/io/temporal/internal/common/NexusFailureUtil.java @@ -0,0 +1,69 @@ +package io.temporal.internal.common; + +import static java.nio.charset.StandardCharsets.UTF_8; + +import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.util.JsonFormat; +import io.temporal.api.common.v1.Payload; +import io.temporal.api.common.v1.Payloads; +import io.temporal.api.failure.v1.ApplicationFailureInfo; +import io.temporal.api.failure.v1.Failure; +import io.temporal.api.failure.v1.NexusHandlerFailureInfo; +import io.temporal.api.nexus.v1.HandlerError; +import java.util.Map; +import java.util.stream.Collectors; + +public class NexusFailureUtil { + private static final JsonFormat.Parser JSON_PARSER = JsonFormat.parser(); + private static final String FAILURE_TYPE_STRING = Failure.getDescriptor().getFullName(); + + public static Failure handlerErrorToFailure(HandlerError err) { + return Failure.newBuilder() + .setMessage(err.getFailure().getMessage()) + .setNexusHandlerFailureInfo( + NexusHandlerFailureInfo.newBuilder() + .setType(err.getErrorType()) + .setRetryBehavior(err.getRetryBehavior()) + .build()) + .setCause(nexusFailureToAPIFailure(err.getFailure(), false)) + .build(); + } + + /** + * nexusFailureToAPIFailure converts a Nexus Failure to an API proto Failure. If the failure + * metadata "type" field is set to the fullname of the temporal API Failure message, the failure + * is reconstructed using protojson.Unmarshal on the failure details field. + */ + public static Failure nexusFailureToAPIFailure( + io.temporal.api.nexus.v1.Failure failure, boolean retryable) { + Failure.Builder apiFailure = Failure.newBuilder(); + if (failure.getMetadataMap().containsKey("type") + && failure.getMetadataMap().get("type").equals(FAILURE_TYPE_STRING)) { + try { + JSON_PARSER.merge(failure.getDetails().toString(UTF_8), apiFailure); + } catch (InvalidProtocolBufferException e) { + throw new RuntimeException(e); + } + } else { + Payloads payloads = nexusFailureMetadataToPayloads(failure); + ApplicationFailureInfo.Builder applicationFailureInfo = ApplicationFailureInfo.newBuilder(); + applicationFailureInfo.setType("NexusFailure"); + applicationFailureInfo.setDetails(payloads); + applicationFailureInfo.setNonRetryable(!retryable); + apiFailure.setApplicationFailureInfo(applicationFailureInfo.build()); + } + apiFailure.setMessage(failure.getMessage()); + return apiFailure.build(); + } + + public static Payloads nexusFailureMetadataToPayloads(io.temporal.api.nexus.v1.Failure failure) { + Map metadata = + failure.getMetadataMap().entrySet().stream() + .collect( + Collectors.toMap(Map.Entry::getKey, e -> ByteString.copyFromUtf8(e.getValue()))); + return Payloads.newBuilder() + .addPayloads(Payload.newBuilder().putAllMetadata(metadata).setData(failure.getDetails())) + .build(); + } +} diff --git a/temporal-serviceclient/src/main/proto b/temporal-serviceclient/src/main/proto index 49f9286fae..23f2591fec 160000 --- a/temporal-serviceclient/src/main/proto +++ b/temporal-serviceclient/src/main/proto @@ -1 +1 @@ -Subproject commit 49f9286fae31a472ba4ca953df6a7432c493085f +Subproject commit 23f2591fec250c31c6c6bb21aabdc90843e8ecdb diff --git a/temporal-test-server/src/main/java/io/temporal/internal/testservice/NexusTaskToken.java b/temporal-test-server/src/main/java/io/temporal/internal/testservice/NexusTaskToken.java index e7d3a8918c..bf2259f5d4 100644 --- a/temporal-test-server/src/main/java/io/temporal/internal/testservice/NexusTaskToken.java +++ b/temporal-test-server/src/main/java/io/temporal/internal/testservice/NexusTaskToken.java @@ -2,93 +2,12 @@ import com.google.protobuf.ByteString; import io.grpc.Status; -import io.temporal.api.common.v1.WorkflowExecution; -import java.io.*; -import java.util.Objects; -import javax.annotation.Nonnull; +import java.io.ByteArrayInputStream; +import java.io.DataInputStream; +import java.io.IOException; public class NexusTaskToken { - - @Nonnull private final NexusOperationRef ref; - private final int attempt; - private final boolean isCancel; - - NexusTaskToken( - @Nonnull String namespace, - @Nonnull WorkflowExecution execution, - long scheduledEventId, - int attempt, - boolean isCancel) { - this( - new ExecutionId(Objects.requireNonNull(namespace), Objects.requireNonNull(execution)), - scheduledEventId, - attempt, - isCancel); - } - - NexusTaskToken( - @Nonnull String namespace, - @Nonnull String workflowId, - @Nonnull String runId, - long scheduledEventId, - int attempt, - boolean isCancel) { - this( - namespace, - WorkflowExecution.newBuilder() - .setWorkflowId(Objects.requireNonNull(workflowId)) - .setRunId(Objects.requireNonNull(runId)) - .build(), - scheduledEventId, - attempt, - isCancel); - } - - NexusTaskToken( - @Nonnull ExecutionId executionId, long scheduledEventId, int attempt, boolean isCancel) { - this( - new NexusOperationRef(Objects.requireNonNull(executionId), scheduledEventId), - attempt, - isCancel); - } - - public NexusTaskToken(@Nonnull NexusOperationRef ref, int attempt, boolean isCancel) { - this.ref = Objects.requireNonNull(ref); - this.attempt = attempt; - this.isCancel = isCancel; - } - - public NexusOperationRef getOperationRef() { - return ref; - } - - public long getAttempt() { - return attempt; - } - - public boolean isCancel() { - return isCancel; - } - - /** Used for task tokens. */ - public ByteString toBytes() { - try (ByteArrayOutputStream bout = new ByteArrayOutputStream(); - DataOutputStream out = new DataOutputStream(bout)) { - ExecutionId executionId = ref.getExecutionId(); - out.writeUTF(executionId.getNamespace()); - WorkflowExecution execution = executionId.getExecution(); - out.writeUTF(execution.getWorkflowId()); - out.writeUTF(execution.getRunId()); - out.writeLong(ref.getScheduledEventId()); - out.writeInt(attempt); - out.writeBoolean(isCancel); - return ByteString.copyFrom(bout.toByteArray()); - } catch (IOException e) { - throw Status.INTERNAL.withCause(e).withDescription(e.getMessage()).asRuntimeException(); - } - } - - public static NexusTaskToken fromBytes(ByteString serialized) { + public static NexusWorkflowTaskToken fromBytes(ByteString serialized) { ByteArrayInputStream bin = new ByteArrayInputStream(serialized.toByteArray()); DataInputStream in = new DataInputStream(bin); try { @@ -98,7 +17,8 @@ public static NexusTaskToken fromBytes(ByteString serialized) { long scheduledEventId = in.readLong(); int attempt = in.readInt(); boolean isCancel = in.readBoolean(); - return new NexusTaskToken(namespace, workflowId, runId, scheduledEventId, attempt, isCancel); + return new NexusWorkflowTaskToken( + namespace, workflowId, runId, scheduledEventId, attempt, isCancel); } catch (IOException e) { throw Status.INVALID_ARGUMENT .withCause(e) diff --git a/temporal-test-server/src/main/java/io/temporal/internal/testservice/NexusWorkflowTaskToken.java b/temporal-test-server/src/main/java/io/temporal/internal/testservice/NexusWorkflowTaskToken.java new file mode 100644 index 0000000000..c90fb0fe74 --- /dev/null +++ b/temporal-test-server/src/main/java/io/temporal/internal/testservice/NexusWorkflowTaskToken.java @@ -0,0 +1,96 @@ +package io.temporal.internal.testservice; + +import com.google.protobuf.ByteString; +import io.temporal.api.common.v1.WorkflowExecution; +import io.temporal.api.testservice.internal.v1.NexusTaskToken; +import java.io.*; +import java.util.Objects; +import javax.annotation.Nonnull; + +public class NexusWorkflowTaskToken { + + @Nonnull private final NexusOperationRef ref; + private final int attempt; + private final boolean isCancel; + + NexusWorkflowTaskToken( + @Nonnull String namespace, + @Nonnull WorkflowExecution execution, + long scheduledEventId, + int attempt, + boolean isCancel) { + this( + new ExecutionId(Objects.requireNonNull(namespace), Objects.requireNonNull(execution)), + scheduledEventId, + attempt, + isCancel); + } + + NexusWorkflowTaskToken( + @Nonnull String namespace, + @Nonnull String workflowId, + @Nonnull String runId, + long scheduledEventId, + int attempt, + boolean isCancel) { + this( + namespace, + WorkflowExecution.newBuilder() + .setWorkflowId(Objects.requireNonNull(workflowId)) + .setRunId(Objects.requireNonNull(runId)) + .build(), + scheduledEventId, + attempt, + isCancel); + } + + NexusWorkflowTaskToken( + @Nonnull ExecutionId executionId, long scheduledEventId, int attempt, boolean isCancel) { + this( + new NexusOperationRef(Objects.requireNonNull(executionId), scheduledEventId), + attempt, + isCancel); + } + + public NexusWorkflowTaskToken(@Nonnull NexusOperationRef ref, int attempt, boolean isCancel) { + this.ref = Objects.requireNonNull(ref); + this.attempt = attempt; + this.isCancel = isCancel; + } + + public static NexusWorkflowTaskToken fromTaskToken(NexusTaskToken nexusTaskToken) { + return new NexusWorkflowTaskToken( + nexusTaskToken.getWorkflowCaller().getNamespace(), + nexusTaskToken.getWorkflowCaller().getExecution(), + nexusTaskToken.getWorkflowCaller().getScheduledEventId(), + nexusTaskToken.getAttempt(), + nexusTaskToken.getCancelled()); + } + + public NexusOperationRef getOperationRef() { + return ref; + } + + public long getAttempt() { + return attempt; + } + + public boolean isCancel() { + return isCancel; + } + + /** Used for task tokens. */ + public ByteString toBytes() { + return NexusTaskToken.newBuilder() + .setAttempt(attempt) + .setCancelled(isCancel) + .setWorkflowCaller( + NexusTaskToken.WorkflowCallerTaskToken.newBuilder() + .setExecution(ref.getExecutionId().getExecution()) + .setNamespace(ref.getExecutionId().getNamespace()) + .setScheduledEventId(ref.getScheduledEventId()) + .build()) + .build() + .toByteString(); + } +} diff --git a/temporal-test-server/src/main/java/io/temporal/internal/testservice/StateMachines.java b/temporal-test-server/src/main/java/io/temporal/internal/testservice/StateMachines.java index f1138f9b41..703d796ab0 100644 --- a/temporal-test-server/src/main/java/io/temporal/internal/testservice/StateMachines.java +++ b/temporal-test-server/src/main/java/io/temporal/internal/testservice/StateMachines.java @@ -675,7 +675,7 @@ private static void scheduleNexusOperation( long scheduledEventId = ctx.addEvent(event.build()); NexusOperationRef ref = new NexusOperationRef(ctx.getExecutionId(), scheduledEventId); - NexusTaskToken taskToken = new NexusTaskToken(ref, data.getAttempt(), false); + NexusWorkflowTaskToken taskToken = new NexusWorkflowTaskToken(ref, data.getAttempt(), false); Link link = workflowEventToNexusLink( @@ -914,7 +914,7 @@ private static RetryState attemptNexusOperationRetry( data.nextAttemptScheduleTime = Timestamps.add(ProtobufTimeUtils.getCurrentProtoTime(), data.nextBackoffInterval); task.setTaskToken( - new NexusTaskToken( + new NexusWorkflowTaskToken( ctx.getExecutionId(), data.scheduledEventId, nextAttempt.getAttempt(), @@ -942,8 +942,9 @@ private static void requestCancelNexusOperation( .setWorkflowTaskCompletedEventId(workflowTaskCompletedId)) .build()); - NexusTaskToken taskToken = - new NexusTaskToken(ctx.getExecutionId(), data.scheduledEventId, data.getAttempt(), true); + NexusWorkflowTaskToken taskToken = + new NexusWorkflowTaskToken( + ctx.getExecutionId(), data.scheduledEventId, data.getAttempt(), true); PollNexusTaskQueueResponse.Builder pollResponse = PollNexusTaskQueueResponse.newBuilder() diff --git a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowMutableState.java b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowMutableState.java index 9fee4c6256..bcef214b00 100644 --- a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowMutableState.java +++ b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowMutableState.java @@ -112,7 +112,7 @@ void completeAsyncNexusOperation( void failNexusOperation(NexusOperationRef ref, Failure failure); - boolean validateOperationTaskToken(NexusTaskToken tt); + boolean validateOperationTaskToken(NexusWorkflowTaskToken tt); QueryWorkflowResponse query(QueryWorkflowRequest queryRequest, long deadline); diff --git a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowMutableStateImpl.java b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowMutableStateImpl.java index f40ceab1a3..051f597328 100644 --- a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowMutableStateImpl.java +++ b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowMutableStateImpl.java @@ -3651,7 +3651,7 @@ private boolean operationInFlight(StateMachines.State operationState) { } @Override - public boolean validateOperationTaskToken(NexusTaskToken tt) { + public boolean validateOperationTaskToken(NexusWorkflowTaskToken tt) { NexusOperationData data = getPendingNexusOperation(tt.getOperationRef().getScheduledEventId()).getData(); if (tt.getAttempt() != data.getAttempt()) { diff --git a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowService.java b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowService.java index 59a67d8b68..5bc301ed4f 100644 --- a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowService.java +++ b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowService.java @@ -6,9 +6,9 @@ import static io.temporal.api.workflowservice.v1.ExecuteMultiOperationRequest.Operation.OperationCase.START_WORKFLOW; import static io.temporal.api.workflowservice.v1.ExecuteMultiOperationRequest.Operation.OperationCase.UPDATE_WORKFLOW; import static io.temporal.internal.testservice.CronUtils.getBackoffInterval; -import static java.nio.charset.StandardCharsets.UTF_8; import com.google.common.base.Preconditions; +import com.google.common.base.Strings; import com.google.common.base.Throwables; import com.google.protobuf.*; import com.google.protobuf.util.JsonFormat; @@ -37,6 +37,7 @@ import io.temporal.api.workflow.v1.RequestIdInfo; import io.temporal.api.workflow.v1.WorkflowExecutionInfo; import io.temporal.api.workflowservice.v1.*; +import io.temporal.internal.common.NexusFailureUtil; import io.temporal.internal.common.ProtoUtils; import io.temporal.internal.common.ProtobufTimeUtils; import io.temporal.internal.testservice.TestWorkflowStore.WorkflowState; @@ -509,7 +510,7 @@ public void getWorkflowExecutionHistory( // context deadline and throw DEADLINE_EXCEEDED if the deadline is less // than 20s. // If it's longer than 20 seconds - we return an empty result. - Deadline.after(20, TimeUnit.SECONDS))); + Deadline.after(100, TimeUnit.MILLISECONDS))); responseObserver.onCompleted(); } catch (StatusRuntimeException e) { if (e.getStatus().getCode() == Status.Code.INTERNAL) { @@ -885,57 +886,96 @@ private static Failure wrapNexusOperationFailure(Failure cause) { .build(); } + private io.temporal.api.testservice.internal.v1.NexusTaskToken parseNexusTaskToken( + ByteString token) { + try { + return io.temporal.api.testservice.internal.v1.NexusTaskToken.parseFrom(token); + } catch (Exception e) { + throw Status.INVALID_ARGUMENT + .withCause(e) + .withDescription(e.getMessage()) + .asRuntimeException(); + } + } + @Override public void respondNexusTaskCompleted( RespondNexusTaskCompletedRequest request, StreamObserver responseObserver) { try { - NexusTaskToken tt = NexusTaskToken.fromBytes(request.getTaskToken()); - TestWorkflowMutableState mutableState = - getMutableState(tt.getOperationRef().getExecutionId()); - if (!mutableState.validateOperationTaskToken(tt)) { - responseObserver.onNext(RespondNexusTaskCompletedResponse.getDefaultInstance()); - responseObserver.onCompleted(); - return; - } - - if (request.getResponse().hasCancelOperation()) { - mutableState.cancelNexusOperationRequestAcknowledge(tt.getOperationRef()); - } else if (request.getResponse().hasStartOperation()) { - StartOperationResponse startResp = request.getResponse().getStartOperation(); - if (startResp.hasOperationError()) { - UnsuccessfulOperationError opError = startResp.getOperationError(); - Failure.Builder b = Failure.newBuilder().setMessage(opError.getFailure().getMessage()); - - if (startResp.getOperationError().getOperationState().equals("canceled")) { - b.setCanceledFailureInfo( - CanceledFailureInfo.newBuilder() - .setDetails(nexusFailureMetadataToPayloads(opError.getFailure()))); - mutableState.cancelNexusOperation(tt.getOperationRef(), b.build()); + io.temporal.api.testservice.internal.v1.NexusTaskToken nexusTaskToken = + parseNexusTaskToken(request.getTaskToken()); + + if (nexusTaskToken.hasWorkflowCaller()) { + NexusWorkflowTaskToken tt = NexusWorkflowTaskToken.fromTaskToken(nexusTaskToken); + TestWorkflowMutableState mutableState = + getMutableState(tt.getOperationRef().getExecutionId()); + if (!mutableState.validateOperationTaskToken(tt)) { + responseObserver.onNext(RespondNexusTaskCompletedResponse.getDefaultInstance()); + responseObserver.onCompleted(); + return; + } + if (request.getResponse().hasCancelOperation()) { + mutableState.cancelNexusOperationRequestAcknowledge(tt.getOperationRef()); + } else if (request.getResponse().hasStartOperation()) { + StartOperationResponse startResp = request.getResponse().getStartOperation(); + if (startResp.hasOperationError()) { + UnsuccessfulOperationError opError = startResp.getOperationError(); + Failure.Builder b = Failure.newBuilder().setMessage(opError.getFailure().getMessage()); + + if (startResp.getOperationError().getOperationState().equals("canceled")) { + b.setCanceledFailureInfo( + CanceledFailureInfo.newBuilder() + .setDetails( + NexusFailureUtil.nexusFailureMetadataToPayloads(opError.getFailure()))); + mutableState.cancelNexusOperation(tt.getOperationRef(), b.build()); + } else { + mutableState.failNexusOperation( + tt.getOperationRef(), + wrapNexusOperationFailure( + NexusFailureUtil.nexusFailureToAPIFailure(opError.getFailure(), false))); + } + } else if (startResp.hasAsyncSuccess()) { + // Start event is only recorded for async success + mutableState.startNexusOperation( + tt.getOperationRef().getScheduledEventId(), + request.getIdentity(), + startResp.getAsyncSuccess()); + } else if (startResp.hasSyncSuccess()) { + mutableState.completeNexusOperation( + tt.getOperationRef(), startResp.getSyncSuccess().getPayload()); } else { - mutableState.failNexusOperation( - tt.getOperationRef(), - wrapNexusOperationFailure(nexusFailureToAPIFailure(opError.getFailure(), false))); + throw Status.INVALID_ARGUMENT + .withDescription("Expected success or OperationError to be set on request.") + .asRuntimeException(); } - } else if (startResp.hasAsyncSuccess()) { - // Start event is only recorded for async success - mutableState.startNexusOperation( - tt.getOperationRef().getScheduledEventId(), - request.getIdentity(), - startResp.getAsyncSuccess()); - } else if (startResp.hasSyncSuccess()) { - mutableState.completeNexusOperation( - tt.getOperationRef(), startResp.getSyncSuccess().getPayload()); } else { throw Status.INVALID_ARGUMENT - .withDescription("Expected success or OperationError to be set on request.") + .withDescription("Expected StartOperation or CancelOperation to be set on request.") + .asRuntimeException(); + } + } else if (nexusTaskToken.hasExternalCaller()) { + if (request.getResponse().hasStartOperation()) { + store.respondStartNexusOperationTask( + nexusTaskToken.getExternalCaller().getId(), + request.getResponse().getStartOperation()); + } else if (request.getResponse().hasCancelOperation()) { + store.respondCancelNexusOperationTask(nexusTaskToken.getExternalCaller().getId()); + } else if (request.getResponse().hasGetOperationInfo()) { + store.respondGetNexusOperationInfoTask( + nexusTaskToken.getExternalCaller().getId(), + request.getResponse().getGetOperationInfo()); + } else if (request.getResponse().hasGetOperationResult()) { + store.respondGetNexusOperationResultTask( + nexusTaskToken.getExternalCaller().getId(), + request.getResponse().getGetOperationResult()); + } else { + throw Status.INVALID_ARGUMENT + .withDescription("Expected StartOperation or CancelOperation to be set on request.") .asRuntimeException(); } - } else { - throw Status.INVALID_ARGUMENT - .withDescription("Expected StartOperation or CancelOperation to be set on request.") - .asRuntimeException(); } + responseObserver.onNext(RespondNexusTaskCompletedResponse.getDefaultInstance()); responseObserver.onCompleted(); } catch (StatusRuntimeException e) { @@ -953,12 +993,20 @@ public void respondNexusTaskFailed( .withDescription("Nexus handler error not set on RespondNexusTaskFailedRequest") .asRuntimeException(); } - NexusTaskToken tt = NexusTaskToken.fromBytes(request.getTaskToken()); - TestWorkflowMutableState mutableState = - getMutableState(tt.getOperationRef().getExecutionId()); - if (mutableState.validateOperationTaskToken(tt)) { - Failure failure = handlerErrorToFailure(request.getError()); - mutableState.failNexusOperation(tt.getOperationRef(), failure); + io.temporal.api.testservice.internal.v1.NexusTaskToken nexusTaskToken = + parseNexusTaskToken(request.getTaskToken()); + + if (nexusTaskToken.hasWorkflowCaller()) { + NexusWorkflowTaskToken tt = NexusWorkflowTaskToken.fromTaskToken(nexusTaskToken); + TestWorkflowMutableState mutableState = + getMutableState(tt.getOperationRef().getExecutionId()); + if (mutableState.validateOperationTaskToken(tt)) { + Failure failure = NexusFailureUtil.handlerErrorToFailure(request.getError()); + mutableState.failNexusOperation(tt.getOperationRef(), failure); + } + } else if (nexusTaskToken.hasExternalCaller()) { + String requestId = nexusTaskToken.getExternalCaller().getId(); + store.respondFailNexusTask(requestId, request.getError()); } responseObserver.onNext(RespondNexusTaskFailedResponse.getDefaultInstance()); responseObserver.onCompleted(); @@ -1029,55 +1077,6 @@ public void completeNexusOperation( } } - private static Failure handlerErrorToFailure(HandlerError err) { - return Failure.newBuilder() - .setMessage(err.getFailure().getMessage()) - .setNexusHandlerFailureInfo( - NexusHandlerFailureInfo.newBuilder() - .setType(err.getErrorType()) - .setRetryBehavior(err.getRetryBehavior()) - .build()) - .setCause(nexusFailureToAPIFailure(err.getFailure(), false)) - .build(); - } - - /** - * nexusFailureToAPIFailure converts a Nexus Failure to an API proto Failure. If the failure - * metadata "type" field is set to the fullname of the temporal API Failure message, the failure - * is reconstructed using protojson.Unmarshal on the failure details field. - */ - private static Failure nexusFailureToAPIFailure( - io.temporal.api.nexus.v1.Failure failure, boolean retryable) { - Failure.Builder apiFailure = Failure.newBuilder(); - if (failure.getMetadataMap().containsKey("type") - && failure.getMetadataMap().get("type").equals(FAILURE_TYPE_STRING)) { - try { - JSON_PARSER.merge(failure.getDetails().toString(UTF_8), apiFailure); - } catch (InvalidProtocolBufferException e) { - throw new RuntimeException(e); - } - } else { - Payloads payloads = nexusFailureMetadataToPayloads(failure); - ApplicationFailureInfo.Builder applicationFailureInfo = ApplicationFailureInfo.newBuilder(); - applicationFailureInfo.setType("NexusFailure"); - applicationFailureInfo.setDetails(payloads); - applicationFailureInfo.setNonRetryable(!retryable); - apiFailure.setApplicationFailureInfo(applicationFailureInfo.build()); - } - apiFailure.setMessage(failure.getMessage()); - return apiFailure.build(); - } - - private static Payloads nexusFailureMetadataToPayloads(io.temporal.api.nexus.v1.Failure failure) { - Map metadata = - failure.getMetadataMap().entrySet().stream() - .collect( - Collectors.toMap(Map.Entry::getKey, e -> ByteString.copyFromUtf8(e.getValue()))); - return Payloads.newBuilder() - .addPayloads(Payload.newBuilder().putAllMetadata(metadata).setData(failure.getDetails())) - .build(); - } - @Override public void requestCancelWorkflowExecution( RequestCancelWorkflowExecutionRequest cancelRequest, @@ -1872,6 +1871,229 @@ public void updateWorkerDeploymentVersionMetadata( responseObserver); } + private TestWorkflowStore.TaskQueueId getTaskQueueIdFromTarget( + String namespace, TaskDispatchTarget target) { + if (target.hasEndpoint()) { + Endpoint endpoint = nexusEndpointStore.getEndpointByName(target.getEndpoint()); + return new TestWorkflowStore.TaskQueueId( + endpoint.getSpec().getTarget().getWorker().getNamespace(), + endpoint.getSpec().getTarget().getWorker().getTaskQueue()); + } else if (target.hasTaskQueue()) { + return new TestWorkflowStore.TaskQueueId(namespace, target.getTaskQueue()); + } else { + throw createInvalidArgument("Target must have either endpoint or task queue set."); + } + } + + public void startNexusOperation( + StartNexusOperationRequest request, + StreamObserver responseObserver) { + try { + if (request.getNamespace().isEmpty()) { + throw createInvalidArgument("Namespace not set on request."); + } + if (!request.hasTarget()) { + throw createInvalidArgument("Target not set on request."); + } + if (Strings.isNullOrEmpty(request.getOperation())) { + throw createInvalidArgument("Operation not set on request."); + } + if (Strings.isNullOrEmpty(request.getService())) { + throw createInvalidArgument("Service not set on request."); + } + + TestWorkflowStore.TaskQueueId taskQueueId = + getTaskQueueIdFromTarget(request.getNamespace(), request.getTarget()); + StartOperationRequest.Builder taskRequest = + StartOperationRequest.newBuilder() + .setService(request.getService()) + .setOperation(request.getOperation()) + .setRequestId(request.getRequestId()) + .setCallback(request.getCallback()) + .putAllCallbackHeader(request.getCallbackHeaderMap()) + .addAllLinks(request.getLinksList()); + + if (request.hasPayload()) { + taskRequest.setPayload(request.getPayload()); + } + + // @Nullable Deadline deadline = Context.current().getDeadline(); + store + .startNexusOperation(taskQueueId, taskRequest.build(), request.getHeaderMap()) + .thenAccept( + response -> { + responseObserver.onNext(response); + responseObserver.onCompleted(); + }); + } catch (StatusRuntimeException e) { + handleStatusRuntimeException(e, responseObserver); + } + } + + @Override + public void requestCancelNexusOperation( + RequestCancelNexusOperationRequest request, + StreamObserver responseObserver) { + try { + if (request.getNamespace().isEmpty()) { + throw createInvalidArgument("Namespace not set on request."); + } + if (!request.hasTarget()) { + throw createInvalidArgument("Target not set on request."); + } + if (Strings.isNullOrEmpty(request.getOperation())) { + throw createInvalidArgument("Operation not set on request."); + } + if (Strings.isNullOrEmpty(request.getService())) { + throw createInvalidArgument("Service not set on request."); + } + if (Strings.isNullOrEmpty(request.getOperationToken())) { + throw createInvalidArgument("Operation token not set on request."); + } + + TestWorkflowStore.TaskQueueId taskQueueId = + getTaskQueueIdFromTarget(request.getNamespace(), request.getTarget()); + store + .requestCancelNexusOperation( + taskQueueId, + CancelOperationRequest.newBuilder() + .setOperation(request.getOperation()) + .setService(request.getService()) + .setOperationToken(request.getOperationToken()) + .build(), + request.getHeaderMap()) + .thenApply( + response -> { + responseObserver.onNext(response); + responseObserver.onCompleted(); + return null; + }); + } catch (StatusRuntimeException e) { + handleStatusRuntimeException(e, responseObserver); + } + } + + public void getNexusOperationInfo( + GetNexusOperationInfoRequest request, + io.grpc.stub.StreamObserver responseObserver) { + try { + if (request.getNamespace().isEmpty()) { + throw createInvalidArgument("Namespace not set on request."); + } + if (!request.hasTarget()) { + throw createInvalidArgument("Target not set on request."); + } + if (Strings.isNullOrEmpty(request.getOperation())) { + throw createInvalidArgument("Operation not set on request."); + } + if (Strings.isNullOrEmpty(request.getService())) { + throw createInvalidArgument("Service not set on request."); + } + if (Strings.isNullOrEmpty(request.getOperationToken())) { + throw createInvalidArgument("Operation token not set on request."); + } + + TestWorkflowStore.TaskQueueId taskQueueId = + getTaskQueueIdFromTarget(request.getNamespace(), request.getTarget()); + store + .getNexusOperationInfo( + taskQueueId, + GetOperationInfoRequest.newBuilder() + .setService(request.getService()) + .setOperation(request.getOperation()) + .setOperationToken(request.getOperationToken()) + .build(), + request.getHeaderMap()) + .thenAccept( + response -> { + responseObserver.onNext(response); + responseObserver.onCompleted(); + }); + } catch (StatusRuntimeException e) { + handleStatusRuntimeException(e, responseObserver); + } + } + + @Override + public void getNexusOperationResult( + GetNexusOperationResultRequest request, + StreamObserver responseObserver) { + try { + if (request.getNamespace().isEmpty()) { + throw createInvalidArgument("Namespace not set on request."); + } + if (!request.hasTarget()) { + throw createInvalidArgument("Target not set on request."); + } + if (Strings.isNullOrEmpty(request.getOperation())) { + throw createInvalidArgument("Operation not set on request."); + } + if (Strings.isNullOrEmpty(request.getService())) { + throw createInvalidArgument("Service not set on request."); + } + if (Strings.isNullOrEmpty(request.getOperationToken())) { + throw createInvalidArgument("Operation token not set on request."); + } + + TestWorkflowStore.TaskQueueId taskQueueId = + getTaskQueueIdFromTarget(request.getNamespace(), request.getTarget()); + + GetOperationResultRequest.Builder taskRequest = + GetOperationResultRequest.newBuilder() + .setService(request.getService()) + .setOperation(request.getOperation()) + .setOperationToken(request.getOperationToken()); + if (request.hasWait()) { + taskRequest.setWait(request.getWait()); + } + + store + .getNexusOperationResult(taskQueueId, taskRequest.build(), request.getHeaderMap()) + .thenAccept( + response -> { + responseObserver.onNext(response); + responseObserver.onCompleted(); + }); + } catch (StatusRuntimeException e) { + handleStatusRuntimeException(e, responseObserver); + } + } + + public void completeNexusOperation( + CompleteNexusOperationRequest request, + StreamObserver responseObserver) { + try { + if (!request.hasCallback()) { + throw createInvalidArgument("Callback not set on request."); + } + + String serializedRef = request.getCallback().getHeaderOrThrow("operation-reference"); + NexusOperationRef ref = NexusOperationRef.fromBytes(serializedRef.getBytes()); + TestWorkflowMutableState target = getMutableState(ref.getExecutionId()); + Payload p = request.hasResult() ? request.getResult() : Payload.getDefaultInstance(); + if (request.hasResult()) { + target.completeAsyncNexusOperation( + ref, + p, + request.getOperationToken(), + io.temporal.api.nexus.v1.Link.getDefaultInstance()); + } else if (request.hasOperationError()) { + target.failNexusOperation( + ref, + wrapNexusOperationFailure( + NexusFailureUtil.nexusFailureToAPIFailure( + request.getOperationError().getFailure(), false))); + } else { + throw createInvalidArgument("Either result or operation error must be set on request."); + } + + responseObserver.onNext(CompleteNexusOperationResponse.getDefaultInstance()); + responseObserver.onCompleted(); + } catch (StatusRuntimeException e) { + handleStatusRuntimeException(e, responseObserver); + } + } + private R requireNotNull(String fieldName, R value) { if (value == null) { throw Status.INVALID_ARGUMENT diff --git a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowStore.java b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowStore.java index ee11316837..b44c944494 100644 --- a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowStore.java +++ b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowStore.java @@ -3,12 +3,15 @@ import com.google.protobuf.Timestamp; import io.grpc.Deadline; import io.temporal.api.common.v1.Priority; +import io.temporal.api.nexus.v1.*; import io.temporal.api.workflow.v1.WorkflowExecutionInfo; import io.temporal.api.workflowservice.v1.*; import java.time.Duration; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; interface TestWorkflowStore { @@ -171,6 +174,28 @@ void sendQueryTask( PollWorkflowTaskQueueResponse.Builder task, Priority priority); + void respondGetNexusOperationInfoTask(String id, GetOperationInfoResponse getOperationInfo); + + void respondCancelNexusOperationTask(String requestId); + + void respondStartNexusOperationTask(String id, StartOperationResponse startOperation); + + void respondGetNexusOperationResultTask(String id, GetOperationResultResponse getOperationResult); + + void respondFailNexusTask(String requestId, HandlerError handlerError); + + CompletableFuture startNexusOperation( + TaskQueueId taskQueueId, StartOperationRequest taskRequest, Map headers); + + CompletableFuture requestCancelNexusOperation( + TaskQueueId taskQueueId, CancelOperationRequest taskRequest, Map headers); + + CompletableFuture getNexusOperationInfo( + TaskQueueId taskQueueId, GetOperationInfoRequest taskRequest, Map headers); + + CompletableFuture getNexusOperationResult( + TaskQueueId taskQueueId, GetOperationResultRequest taskRequest, Map headers); + GetWorkflowExecutionHistoryResponse getWorkflowExecutionHistory( ExecutionId executionId, GetWorkflowExecutionHistoryRequest getRequest, diff --git a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowStoreImpl.java b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowStoreImpl.java index 16c144c5e2..97c5ec65b3 100644 --- a/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowStoreImpl.java +++ b/temporal-test-server/src/main/java/io/temporal/internal/testservice/TestWorkflowStoreImpl.java @@ -13,7 +13,9 @@ import io.temporal.api.enums.v1.WorkflowExecutionStatus; import io.temporal.api.history.v1.History; import io.temporal.api.history.v1.HistoryEvent; +import io.temporal.api.nexus.v1.*; import io.temporal.api.taskqueue.v1.StickyExecutionAttributes; +import io.temporal.api.testservice.internal.v1.NexusTaskToken; import io.temporal.api.workflow.v1.WorkflowExecutionInfo; import io.temporal.api.workflowservice.v1.*; import io.temporal.common.WorkflowExecutionHistory; @@ -22,17 +24,15 @@ import io.temporal.internal.testservice.RequestContext.Timer; import io.temporal.workflow.Functions; import java.time.Duration; -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; +import java.util.*; import java.util.Map.Entry; -import java.util.Optional; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Condition; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; import java.util.stream.Collectors; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -48,6 +48,14 @@ class TestWorkflowStoreImpl implements TestWorkflowStore { private final Map> workflowTaskQueues = new HashMap<>(); private final Map> nexusTaskQueues = new HashMap<>(); + private final Map> + nexusExternalCallerStartRequests = new HashMap<>(); + private final Map> + nexusExternalCallerCancelRequests = new HashMap<>(); + private final Map> + nexusExternalCallerGetInfoRequests = new HashMap<>(); + private final Map> + nexusExternalCallerGetResultRequests = new HashMap<>(); private final SelfAdvancingTimer selfAdvancingTimer; private static class HistoryStore { @@ -384,6 +392,267 @@ public void sendQueryTask( workflowTaskQueue.add(task, priority); } + @Override + public void respondGetNexusOperationInfoTask( + String requestId, GetOperationInfoResponse getOperationInfo) { + Consumer callback = + nexusExternalCallerGetInfoRequests.remove(requestId); + if (callback == null) { + throw Status.NOT_FOUND + .withDescription("No such requestId: " + requestId) + .asRuntimeException(); + } + callback.accept( + GetNexusOperationInfoResponse.newBuilder().setInfo(getOperationInfo.getInfo()).build()); + } + + @Override + public void respondCancelNexusOperationTask(String requestId) { + Consumer callback = + nexusExternalCallerCancelRequests.remove(requestId); + if (callback == null) { + throw Status.NOT_FOUND + .withDescription("No such requestId: " + requestId) + .asRuntimeException(); + } + callback.accept(RequestCancelNexusOperationResponse.getDefaultInstance()); + } + + @Override + public void respondStartNexusOperationTask( + String requestId, StartOperationResponse startOperation) { + Consumer callback = + nexusExternalCallerStartRequests.remove(requestId); + if (callback == null) { + throw Status.NOT_FOUND + .withDescription("No such requestId: " + requestId) + .asRuntimeException(); + } + StartNexusOperationResponse.Builder response = StartNexusOperationResponse.newBuilder(); + if (startOperation.hasSyncSuccess()) { + response.setSyncSuccess( + StartNexusOperationResponse.Sync.newBuilder() + .setResult(startOperation.getSyncSuccess().getPayload()) + .build()); + } else if (startOperation.hasAsyncSuccess()) { + response.setAsyncSuccess( + StartNexusOperationResponse.Async.newBuilder() + .setOperationToken(startOperation.getAsyncSuccess().getOperationToken()) + .build()); + } else if (startOperation.hasOperationError()) { + response.setUnsuccessful( + StartNexusOperationResponse.Unsuccessful.newBuilder() + .setOperationError(startOperation.getOperationError()) + .build()); + } else { + throw Status.INTERNAL + .withDescription("Unexpected StartOperationResponse: " + startOperation) + .asRuntimeException(); + } + callback.accept(response.build()); + } + + @Override + public void respondGetNexusOperationResultTask( + String requestId, GetOperationResultResponse getOperationResult) { + Consumer callback = + nexusExternalCallerGetResultRequests.remove(requestId); + if (callback == null) { + throw Status.NOT_FOUND + .withDescription("No such requestId: " + requestId) + .asRuntimeException(); + } + GetNexusOperationResultResponse.Builder response = GetNexusOperationResultResponse.newBuilder(); + if (getOperationResult.hasSuccessful()) { + response.setSuccessful( + GetNexusOperationResultResponse.Successful.newBuilder() + .setResult(getOperationResult.getSuccessful().getResult()) + .build()); + } else if (getOperationResult.hasStillRunning()) { + response.setStillRunning(GetNexusOperationResultResponse.StillRunning.newBuilder().build()); + } else if (getOperationResult.hasUnsuccessful()) { + response.setUnsuccessful( + GetNexusOperationResultResponse.Unsuccessful.newBuilder() + .setOperationError(getOperationResult.getUnsuccessful().getOperationError()) + .build()); + } else { + throw Status.INTERNAL + .withDescription("Unexpected GetOperationResultResponse: " + getOperationResult) + .asRuntimeException(); + } + callback.accept(response.build()); + } + + @Override + public void respondFailNexusTask(String requestId, HandlerError handlerError) { + if (nexusExternalCallerStartRequests.containsKey(requestId)) { + nexusExternalCallerStartRequests + .remove(requestId) + .accept(StartNexusOperationResponse.newBuilder().setHandlerError(handlerError).build()); + } else if (nexusExternalCallerCancelRequests.containsKey(requestId)) { + Consumer f = + nexusExternalCallerCancelRequests.remove(requestId); + f.accept( + RequestCancelNexusOperationResponse.newBuilder().setHandlerError(handlerError).build()); + } else if (nexusExternalCallerGetInfoRequests.containsKey(requestId)) { + Consumer f = + nexusExternalCallerGetInfoRequests.remove(requestId); + f.accept(GetNexusOperationInfoResponse.newBuilder().setHandlerError(handlerError).build()); + } else if (nexusExternalCallerGetResultRequests.containsKey(requestId)) { + Consumer f = + nexusExternalCallerGetResultRequests.remove(requestId); + f.accept(GetNexusOperationResultResponse.newBuilder().setHandlerError(handlerError).build()); + } else { + throw Status.NOT_FOUND + .withDescription("No such requestId: " + requestId) + .asRuntimeException(); + } + } + + @Override + public CompletableFuture startNexusOperation( + TaskQueueId taskQueueId, StartOperationRequest startRequest, Map headers) { + TaskQueue taskQueue = getNexusTaskQueueQueue(taskQueueId); + + // Create the task token + String requestId = UUID.randomUUID().toString(); + CompletableFuture future = new CompletableFuture<>(); + nexusExternalCallerStartRequests.put(requestId, future::complete); + NexusTaskToken nexusTaskToken = + NexusTaskToken.newBuilder() + .setAttempt(1) + .setExternalCaller( + NexusTaskToken.ExternalCallerTaskToken.newBuilder().setId(requestId).build()) + .build(); + // Create the task + Request request = + Request.newBuilder() + .setScheduledTime(Timestamps.now()) + .setStartOperation(startRequest) + .putAllHeader(headers) + .build(); + PollNexusTaskQueueResponse.Builder pollNexusTaskQueueResponse = + PollNexusTaskQueueResponse.newBuilder() + .setTaskToken(nexusTaskToken.toByteString()) + .setRequest(request); + taskQueue.add( + new NexusTask( + taskQueueId, + pollNexusTaskQueueResponse, + // TODO: Derive the deadline from the context + Timestamps.fromMillis(System.currentTimeMillis() + 1000 * 10) // 10s + )); + return future; + } + + @Override + public CompletableFuture requestCancelNexusOperation( + TaskQueueId taskQueueId, CancelOperationRequest taskRequest, Map headers) { + TaskQueue taskQueue = getNexusTaskQueueQueue(taskQueueId); + + // Create the task token + String requestId = UUID.randomUUID().toString(); + CompletableFuture future = new CompletableFuture<>(); + nexusExternalCallerCancelRequests.put(requestId, future::complete); + NexusTaskToken nexusTaskToken = + NexusTaskToken.newBuilder() + .setAttempt(1) + .setExternalCaller( + NexusTaskToken.ExternalCallerTaskToken.newBuilder().setId(requestId).build()) + .build(); + // Create the task + Request request = + Request.newBuilder() + .setScheduledTime(Timestamps.now()) + .setCancelOperation(taskRequest) + .putAllHeader(headers) + .build(); + PollNexusTaskQueueResponse.Builder pollNexusTaskQueueResponse = + PollNexusTaskQueueResponse.newBuilder() + .setTaskToken(nexusTaskToken.toByteString()) + .setRequest(request); + taskQueue.add( + new NexusTask( + taskQueueId, + pollNexusTaskQueueResponse, + // TODO: Derive the deadline from the context + Timestamps.fromMillis(System.currentTimeMillis() + 1000 * 10) // 10s + )); + return future; + } + + @Override + public CompletableFuture getNexusOperationInfo( + TaskQueueId taskQueueId, GetOperationInfoRequest taskRequest, Map headers) { + TaskQueue taskQueue = getNexusTaskQueueQueue(taskQueueId); + + // Create the task token + String requestId = UUID.randomUUID().toString(); + CompletableFuture future = new CompletableFuture<>(); + nexusExternalCallerGetInfoRequests.put(requestId, future::complete); + NexusTaskToken nexusTaskToken = + NexusTaskToken.newBuilder() + .setAttempt(1) + .setExternalCaller( + NexusTaskToken.ExternalCallerTaskToken.newBuilder().setId(requestId).build()) + .build(); + // Create the task + Request request = + Request.newBuilder() + .setScheduledTime(Timestamps.now()) + .setGetOperationInfo(taskRequest) + .putAllHeader(headers) + .build(); + PollNexusTaskQueueResponse.Builder pollNexusTaskQueueResponse = + PollNexusTaskQueueResponse.newBuilder() + .setTaskToken(nexusTaskToken.toByteString()) + .setRequest(request); + taskQueue.add( + new NexusTask( + taskQueueId, + pollNexusTaskQueueResponse, + // TODO: Derive the deadline from the context + Timestamps.fromMillis(System.currentTimeMillis() + 1000 * 10) // 10s + )); + return future; + } + + @Override + public CompletableFuture getNexusOperationResult( + TaskQueueId taskQueueId, GetOperationResultRequest taskRequest, Map headers) { + TaskQueue taskQueue = getNexusTaskQueueQueue(taskQueueId); + + // Create the task token + String requestId = UUID.randomUUID().toString(); + CompletableFuture future = new CompletableFuture<>(); + nexusExternalCallerGetResultRequests.put(requestId, future::complete); + NexusTaskToken nexusTaskToken = + NexusTaskToken.newBuilder() + .setAttempt(1) + .setExternalCaller( + NexusTaskToken.ExternalCallerTaskToken.newBuilder().setId(requestId).build()) + .build(); + // Create the task + Request request = + Request.newBuilder() + .setScheduledTime(Timestamps.now()) + .setGetOperationResult(taskRequest) + .putAllHeader(headers) + .build(); + PollNexusTaskQueueResponse.Builder pollNexusTaskQueueResponse = + PollNexusTaskQueueResponse.newBuilder() + .setTaskToken(nexusTaskToken.toByteString()) + .setRequest(request); + taskQueue.add( + new NexusTask( + taskQueueId, + pollNexusTaskQueueResponse, + // TODO: Derive the deadline from the context + Timestamps.fromMillis(System.currentTimeMillis() + 1000 * 10) // 10s + )); + return future; + } + @Override public GetWorkflowExecutionHistoryResponse getWorkflowExecutionHistory( ExecutionId executionId, diff --git a/temporal-test-server/src/main/proto/api-linter.yaml b/temporal-test-server/src/main/proto/api-linter.yaml index 6204827d0e..fc99951253 100644 --- a/temporal-test-server/src/main/proto/api-linter.yaml +++ b/temporal-test-server/src/main/proto/api-linter.yaml @@ -25,6 +25,22 @@ - 'core::0158::response-plural-first-field' - 'core::0158::response-repeated-first-field' +- included_paths: + - '**/testservice/internal/v1/request_response.proto' + disabled_rules: + - 'core::0122::name-suffix' + - 'core::0131::request-name-required' + - 'core::0131::request-unknown-fields' + - 'core::0132::request-parent-required' + - 'core::0132::request-unknown-fields' + - 'core::0132::response-unknown-fields' + - 'core::0134::request-unknown-fields' + - 'core::0158::request-page-size-field' + - 'core::0158::request-page-token-field' + - 'core::0158::response-next-page-token-field' + - 'core::0158::response-plural-first-field' + - 'core::0158::response-repeated-first-field' + - included_paths: - '**/testservice/v1/service.proto' disabled_rules: diff --git a/temporal-test-server/src/main/proto/buf.yaml b/temporal-test-server/src/main/proto/buf.yaml index 827b7ef613..f6ad5f842d 100644 --- a/temporal-test-server/src/main/proto/buf.yaml +++ b/temporal-test-server/src/main/proto/buf.yaml @@ -1,4 +1,6 @@ version: v1beta1 +deps: + - ../temporal-serviceclient/src/main/proto build: roots: - . diff --git a/temporal-test-server/src/main/proto/temporal/api/testservice/internal/v1/messages.proto b/temporal-test-server/src/main/proto/temporal/api/testservice/internal/v1/messages.proto new file mode 100644 index 0000000000..38a040f6da --- /dev/null +++ b/temporal-test-server/src/main/proto/temporal/api/testservice/internal/v1/messages.proto @@ -0,0 +1,30 @@ +syntax = "proto3"; + +package temporal.api.testservice.internal.v1; + +option java_package = "io.temporal.api.testservice.internal.v1"; +option java_multiple_files = true; +option java_outer_classname = "Messages"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; +import "temporal/api/common/v1/message.proto"; + +message NexusTaskToken { + message WorkflowCallerTaskToken { + string namespace = 1; + temporal.api.common.v1.WorkflowExecution execution = 2; + int64 scheduledEventId = 3; + } + + message ExternalCallerTaskToken { + string id = 1; + } + + oneof test_oneof { + ExternalCallerTaskToken externalCaller = 1; + WorkflowCallerTaskToken workflowCaller = 2; + } + int32 attempt = 3; + bool cancelled = 4; +} diff --git a/temporal-test-server/src/test/java/io/temporal/testserver/functional/NexusWorkflowTest.java b/temporal-test-server/src/test/java/io/temporal/testserver/functional/NexusWorkflowTest.java index be8c4570ca..e69d4e29c5 100644 --- a/temporal-test-server/src/test/java/io/temporal/testserver/functional/NexusWorkflowTest.java +++ b/temporal-test-server/src/test/java/io/temporal/testserver/functional/NexusWorkflowTest.java @@ -22,6 +22,7 @@ import io.temporal.client.WorkflowStub; import io.temporal.internal.common.LinkConverter; import io.temporal.internal.testservice.NexusTaskToken; +import io.temporal.internal.testservice.NexusWorkflowTaskToken; import io.temporal.testing.internal.SDKTestWorkflowRule; import io.temporal.testserver.functional.common.TestWorkflows; import java.util.Arrays; @@ -834,9 +835,9 @@ public void testNexusOperationInvalidRef() { pollNexusTask() .thenCompose( task -> { - NexusTaskToken valid = NexusTaskToken.fromBytes(task.getTaskToken()); - NexusTaskToken invalid = - new NexusTaskToken( + NexusWorkflowTaskToken valid = NexusTaskToken.fromBytes(task.getTaskToken()); + NexusWorkflowTaskToken invalid = + new NexusWorkflowTaskToken( valid.getOperationRef(), (int) (valid.getAttempt() + 20), valid.isCancel()); diff --git a/temporal-testing/src/main/java/io/temporal/testing/internal/SDKTestWorkflowRule.java b/temporal-testing/src/main/java/io/temporal/testing/internal/SDKTestWorkflowRule.java index 57ec25a7a0..0d671e9788 100644 --- a/temporal-testing/src/main/java/io/temporal/testing/internal/SDKTestWorkflowRule.java +++ b/temporal-testing/src/main/java/io/temporal/testing/internal/SDKTestWorkflowRule.java @@ -8,15 +8,13 @@ import com.google.common.io.CharSink; import com.google.common.io.Files; import com.uber.m3.tally.Scope; +import io.nexusrpc.client.ServiceClient; import io.temporal.api.enums.v1.EventType; import io.temporal.api.enums.v1.IndexedValueType; import io.temporal.api.history.v1.History; import io.temporal.api.history.v1.HistoryEvent; import io.temporal.api.nexus.v1.Endpoint; -import io.temporal.client.WorkflowClient; -import io.temporal.client.WorkflowClientOptions; -import io.temporal.client.WorkflowQueryException; -import io.temporal.client.WorkflowStub; +import io.temporal.client.*; import io.temporal.common.SearchAttributeKey; import io.temporal.common.WorkerDeploymentVersion; import io.temporal.common.WorkflowExecutionHistory; @@ -367,6 +365,16 @@ public WorkflowClient getWorkflowClient() { return testWorkflowRule.getWorkflowClient(); } + public ServiceClient newNexusServiceClient(Class nexusServiceInterface) { + return testWorkflowRule + .getWorkflowClient() + .newNexusServiceClient( + nexusServiceInterface, + TemporalNexusServiceClientOptions.newBuilder() + .setEndpoint(getNexusEndpoint().getSpec().getName()) + .build()); + } + public WorkflowServiceStubs getWorkflowServiceStubs() { return testWorkflowRule.getWorkflowServiceStubs(); } diff --git a/temporal-testing/src/main/java/io/temporal/testing/internal/TracingWorkerInterceptor.java b/temporal-testing/src/main/java/io/temporal/testing/internal/TracingWorkerInterceptor.java index 8dfcd48b69..2e48b69b58 100644 --- a/temporal-testing/src/main/java/io/temporal/testing/internal/TracingWorkerInterceptor.java +++ b/temporal-testing/src/main/java/io/temporal/testing/internal/TracingWorkerInterceptor.java @@ -5,6 +5,7 @@ import com.uber.m3.tally.Scope; import io.nexusrpc.OperationException; +import io.nexusrpc.OperationStillRunningException; import io.nexusrpc.handler.OperationContext; import io.temporal.activity.ActivityExecutionContext; import io.temporal.client.ActivityCompletionException; @@ -484,6 +485,27 @@ public StartOperationOutput startOperation(StartOperationInput input) return next.startOperation(input); } + @Override + public FetchOperationResultOutput fetchOperationResult(FetchOperationResultInput input) + throws OperationStillRunningException, OperationException { + trace.add( + "fetchOperationResult " + + input.getOperationContext().getService() + + " " + + input.getOperationContext().getOperation()); + return next.fetchOperationResult(input); + } + + @Override + public FetchOperationInfoResponse fetchOperationInfo(FetchOperationInfoInput input) { + trace.add( + "fetchOperationInfo " + + input.getOperationContext().getService() + + " " + + input.getOperationContext().getOperation()); + return next.fetchOperationInfo(input); + } + @Override public CancelOperationOutput cancelOperation(CancelOperationInput input) { trace.add(