Skip to content

Commit 4f313bb

Browse files
Add poller autoscaling (#2535)
Add poller autoscaling
1 parent 0268390 commit 4f313bb

40 files changed

+2626
-506
lines changed

temporal-sdk/src/main/java/io/temporal/internal/common/GrpcUtils.java

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
package io.temporal.internal.common;
22

3+
import com.google.common.util.concurrent.ListenableFuture;
34
import io.grpc.Status;
45
import io.grpc.StatusRuntimeException;
6+
import java.util.concurrent.CompletableFuture;
7+
import java.util.concurrent.ExecutionException;
8+
import java.util.concurrent.ForkJoinPool;
59

610
public class GrpcUtils {
711
/**
@@ -14,4 +18,20 @@ public static boolean isChannelShutdownException(StatusRuntimeException ex) {
1418
&& (description.startsWith("Channel shutdown")
1519
|| description.startsWith("Subchannel shutdown")));
1620
}
21+
22+
public static <T> CompletableFuture<T> toCompletableFuture(ListenableFuture<T> listenableFuture) {
23+
CompletableFuture<T> result = new CompletableFuture<>();
24+
listenableFuture.addListener(
25+
() -> {
26+
try {
27+
result.complete(listenableFuture.get());
28+
} catch (ExecutionException e) {
29+
result.completeExceptionally(e.getCause());
30+
} catch (Exception e) {
31+
result.completeExceptionally(e);
32+
}
33+
},
34+
ForkJoinPool.commonPool());
35+
return result;
36+
}
1737
}

temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityPollTask.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import org.slf4j.Logger;
2525
import org.slf4j.LoggerFactory;
2626

27-
final class ActivityPollTask implements Poller.PollTask<ActivityTask> {
27+
final class ActivityPollTask implements MultiThreadedPoller.PollTask<ActivityTask> {
2828
private static final Logger log = LoggerFactory.getLogger(ActivityPollTask.class);
2929

3030
private final WorkflowServiceStubs service;
@@ -92,7 +92,7 @@ public ActivityTask poll() {
9292
log.warn("Error while trying to reserve a slot for an activity", e.getCause());
9393
return null;
9494
}
95-
permit = Poller.getSlotPermitAndHandleInterrupts(future, slotSupplier);
95+
permit = MultiThreadedPoller.getSlotPermitAndHandleInterrupts(future, slotSupplier);
9696
if (permit == null) return null;
9797

9898
MetricsTag.tagged(metricsScope, PollerTypeMetricsTag.PollerType.ACTIVITY_TASK)

temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityTask.java

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import io.temporal.worker.tuning.SlotPermit;
55
import io.temporal.workflow.Functions;
66
import javax.annotation.Nonnull;
7+
import javax.annotation.Nullable;
78

8-
public final class ActivityTask {
9+
public final class ActivityTask implements ScalingTask {
910
private final @Nonnull PollActivityTaskQueueResponseOrBuilder response;
1011
private final @Nonnull SlotPermit permit;
1112
private final @Nonnull Functions.Proc completionCallback;
@@ -37,4 +38,15 @@ public Functions.Proc getCompletionCallback() {
3738
/** Returns the slot permit stored in this task's non-null {@code permit} field. */
public SlotPermit getPermit() {
3839
return permit;
3940
}
41+
42+
@Nullable
43+
@Override
44+
public ScalingDecision getScalingDecision() {
45+
if (!response.hasPollerScalingDecision()) {
46+
return null;
47+
}
48+
49+
return new ScalingTask.ScalingDecision(
50+
response.getPollerScalingDecision().getPollRequestDeltaSuggestion());
51+
}
4052
}

temporal-sdk/src/main/java/io/temporal/internal/worker/ActivityWorker.java

Lines changed: 43 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import io.temporal.worker.MetricsType;
2222
import io.temporal.worker.WorkerMetricsTag;
2323
import io.temporal.worker.tuning.*;
24+
import io.temporal.worker.tuning.PollerBehaviorAutoscaling;
2425
import java.util.Objects;
2526
import java.util.Optional;
2627
import java.util.concurrent.CompletableFuture;
@@ -85,23 +86,48 @@ public boolean start() {
8586
pollerOptions,
8687
slotSupplier.maximumSlots().orElse(Integer.MAX_VALUE),
8788
options.isUsingVirtualThreads());
88-
poller =
89-
new Poller<>(
90-
options.getIdentity(),
91-
new ActivityPollTask(
92-
service,
93-
namespace,
94-
taskQueue,
95-
options.getIdentity(),
96-
options.getBuildId(),
97-
options.isUsingBuildIdForVersioning(),
98-
taskQueueActivitiesPerSecond,
99-
this.slotSupplier,
100-
workerMetricsScope,
101-
service.getServerCapabilities()),
102-
this.pollTaskExecutor,
103-
pollerOptions,
104-
workerMetricsScope);
89+
90+
boolean useAsyncPoller =
91+
pollerOptions.getPollerBehavior() instanceof PollerBehaviorAutoscaling;
92+
if (useAsyncPoller) {
93+
poller =
94+
new AsyncPoller<>(
95+
slotSupplier,
96+
new SlotReservationData(taskQueue, options.getIdentity(), options.getBuildId()),
97+
new AsyncActivityPollTask(
98+
service,
99+
namespace,
100+
taskQueue,
101+
options.getIdentity(),
102+
options.getBuildId(),
103+
options.isUsingBuildIdForVersioning(),
104+
taskQueueActivitiesPerSecond,
105+
this.slotSupplier,
106+
workerMetricsScope,
107+
service.getServerCapabilities()),
108+
this.pollTaskExecutor,
109+
pollerOptions,
110+
workerMetricsScope);
111+
112+
} else {
113+
poller =
114+
new MultiThreadedPoller<>(
115+
options.getIdentity(),
116+
new ActivityPollTask(
117+
service,
118+
namespace,
119+
taskQueue,
120+
options.getIdentity(),
121+
options.getBuildId(),
122+
options.isUsingBuildIdForVersioning(),
123+
taskQueueActivitiesPerSecond,
124+
this.slotSupplier,
125+
workerMetricsScope,
126+
service.getServerCapabilities()),
127+
this.pollTaskExecutor,
128+
pollerOptions,
129+
workerMetricsScope);
130+
}
105131
poller.start();
106132
workerMetricsScope.counter(MetricsType.WORKER_START_COUNTER).inc(1);
107133
return true;
temporal-sdk/src/main/java/io/temporal/internal/worker/AdjustableSemaphore.java

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package io.temporal.internal.worker;
2+
3+
import java.util.concurrent.Semaphore;
4+
import javax.annotation.concurrent.ThreadSafe;
5+
6+
/** A simple implementation of an adjustable semaphore. */
7+
@ThreadSafe
8+
public final class AdjustableSemaphore {
9+
private final ResizeableSemaphore semaphore;
10+
11+
/**
12+
* how many permits are allowed as governed by this semaphore. Access must be synchronized on this
13+
* object.
14+
*/
15+
private int maxPermits = 0;
16+
17+
/**
18+
* Create a new adjustable semaphore with the given number of initial permits.
19+
*
20+
* @param initialPermits the initial number of permits, must be at least 1
21+
*/
22+
public AdjustableSemaphore(int initialPermits) {
23+
if (initialPermits < 1) {
24+
throw new IllegalArgumentException(
25+
"Semaphore size must be at least 1," + " was " + initialPermits);
26+
}
27+
this.maxPermits = initialPermits;
28+
this.semaphore = new ResizeableSemaphore(initialPermits);
29+
}
30+
31+
/**
32+
* Set the max number of permits. Must be greater than zero.
33+
*
34+
* <p>Note that if there are more than the new max number of permits currently outstanding, any
35+
* currently blocking threads or any new threads that start to block after the call will wait
36+
* until enough permits have been released to have the number of outstanding permits fall below
37+
* the new maximum. In other words, it does what you probably think it should.
38+
*
39+
* @param newMax the new maximum number of permits
40+
*/
41+
synchronized void setMaxPermits(int newMax) {
42+
if (newMax < 1) {
43+
throw new IllegalArgumentException("Semaphore size must be at least 1," + " was " + newMax);
44+
}
45+
46+
int delta = newMax - this.maxPermits;
47+
48+
if (delta == 0) {
49+
return;
50+
} else if (delta > 0) {
51+
// new max is higher, so release that many permits
52+
this.semaphore.release(delta);
53+
} else {
54+
// delta < 0.
55+
// reducePermits needs a positive #
56+
this.semaphore.reducePermits(Math.abs(delta));
57+
}
58+
59+
this.maxPermits = newMax;
60+
}
61+
62+
/** Release a permit back to the semaphore. */
63+
void release() {
64+
this.semaphore.release();
65+
}
66+
67+
/**
68+
* Get a permit, blocking if necessary.
69+
*
70+
* @throws InterruptedException if interrupted while waiting for a permit
71+
*/
72+
void acquire() throws InterruptedException {
73+
this.semaphore.acquire();
74+
}
75+
76+
/**
77+
* A trivial subclass of <code>Semaphore</code> that exposes the reducePermits call to the parent
78+
* class.
79+
*/
80+
private static final class ResizeableSemaphore extends Semaphore {
81+
/** */
82+
private static final long serialVersionUID = 1L;
83+
84+
/** Create a new semaphore with 0 permits. */
85+
ResizeableSemaphore(int initialPermits) {
86+
super(initialPermits);
87+
}
88+
89+
@Override
90+
protected void reducePermits(int reduction) {
91+
super.reducePermits(reduction);
92+
}
93+
}
94+
}
temporal-sdk/src/main/java/io/temporal/internal/worker/AsyncActivityPollTask.java

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
package io.temporal.internal.worker;
2+
3+
import static io.temporal.serviceclient.MetricsTag.METRICS_TAGS_CALL_OPTIONS_KEY;
4+
5+
import com.google.protobuf.DoubleValue;
6+
import com.google.protobuf.Timestamp;
7+
import com.uber.m3.tally.Scope;
8+
import io.grpc.Context;
9+
import io.temporal.api.common.v1.WorkerVersionCapabilities;
10+
import io.temporal.api.taskqueue.v1.TaskQueue;
11+
import io.temporal.api.taskqueue.v1.TaskQueueMetadata;
12+
import io.temporal.api.workflowservice.v1.GetSystemInfoResponse;
13+
import io.temporal.api.workflowservice.v1.PollActivityTaskQueueRequest;
14+
import io.temporal.api.workflowservice.v1.PollActivityTaskQueueResponse;
15+
import io.temporal.internal.common.GrpcUtils;
16+
import io.temporal.internal.common.ProtobufTimeUtils;
17+
import io.temporal.serviceclient.MetricsTag;
18+
import io.temporal.serviceclient.WorkflowServiceStubs;
19+
import io.temporal.worker.MetricsType;
20+
import io.temporal.worker.PollerTypeMetricsTag;
21+
import io.temporal.worker.tuning.ActivitySlotInfo;
22+
import io.temporal.worker.tuning.SlotPermit;
23+
import io.temporal.worker.tuning.SlotReleaseReason;
24+
import java.util.concurrent.CompletableFuture;
25+
import java.util.concurrent.atomic.AtomicInteger;
26+
import java.util.function.Supplier;
27+
import javax.annotation.Nonnull;
28+
import javax.annotation.Nullable;
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
31+
32+
public class AsyncActivityPollTask implements AsyncPoller.PollTaskAsync<ActivityTask> {
33+
private static final Logger log = LoggerFactory.getLogger(AsyncActivityPollTask.class);
34+
35+
private final TrackingSlotSupplier<?> slotSupplier;
36+
private final WorkflowServiceStubs service;
37+
private final Scope metricsScope;
38+
private final PollActivityTaskQueueRequest pollRequest;
39+
private final AtomicInteger pollGauge = new AtomicInteger();
40+
private final Context.CancellableContext grpcContext = Context.ROOT.withCancellation();
41+
42+
@SuppressWarnings("deprecation")
43+
public AsyncActivityPollTask(
44+
@Nonnull WorkflowServiceStubs service,
45+
@Nonnull String namespace,
46+
@Nonnull String taskQueue,
47+
@Nonnull String identity,
48+
@Nullable String buildId,
49+
boolean useBuildIdForVersioning,
50+
double activitiesPerSecond,
51+
@Nonnull TrackingSlotSupplier<ActivitySlotInfo> slotSupplier,
52+
@Nonnull Scope metricsScope,
53+
@Nonnull Supplier<GetSystemInfoResponse.Capabilities> serverCapabilities) {
54+
this.service = service;
55+
this.slotSupplier = slotSupplier;
56+
this.metricsScope = metricsScope;
57+
58+
PollActivityTaskQueueRequest.Builder pollRequest =
59+
PollActivityTaskQueueRequest.newBuilder()
60+
.setNamespace(namespace)
61+
.setIdentity(identity)
62+
.setTaskQueue(TaskQueue.newBuilder().setName(taskQueue));
63+
if (activitiesPerSecond > 0) {
64+
pollRequest.setTaskQueueMetadata(
65+
TaskQueueMetadata.newBuilder()
66+
.setMaxTasksPerSecond(DoubleValue.newBuilder().setValue(activitiesPerSecond).build())
67+
.build());
68+
}
69+
70+
if (serverCapabilities.get().getBuildIdBasedVersioning()) {
71+
pollRequest.setWorkerVersionCapabilities(
72+
WorkerVersionCapabilities.newBuilder()
73+
.setBuildId(buildId)
74+
.setUseVersioning(useBuildIdForVersioning)
75+
.build());
76+
}
77+
this.pollRequest = pollRequest.build();
78+
}
79+
80+
@Override
81+
public CompletableFuture<ActivityTask> poll(SlotPermit permit) {
82+
if (log.isTraceEnabled()) {
83+
log.trace("poll request begin: " + pollRequest);
84+
}
85+
86+
MetricsTag.tagged(metricsScope, PollerTypeMetricsTag.PollerType.ACTIVITY_TASK)
87+
.gauge(MetricsType.NUM_POLLERS)
88+
.update(pollGauge.incrementAndGet());
89+
90+
CompletableFuture<PollActivityTaskQueueResponse> response = null;
91+
try {
92+
response =
93+
grpcContext.call(
94+
() ->
95+
GrpcUtils.toCompletableFuture(
96+
service
97+
.futureStub()
98+
.withOption(METRICS_TAGS_CALL_OPTIONS_KEY, metricsScope)
99+
.pollActivityTaskQueue(pollRequest)));
100+
} catch (Exception e) {
101+
MetricsTag.tagged(metricsScope, PollerTypeMetricsTag.PollerType.ACTIVITY_TASK)
102+
.gauge(MetricsType.NUM_POLLERS)
103+
.update(pollGauge.decrementAndGet());
104+
throw new RuntimeException(e);
105+
}
106+
107+
return response
108+
.thenApply(
109+
r -> {
110+
if (r == null || r.getTaskToken().isEmpty()) {
111+
metricsScope.counter(MetricsType.ACTIVITY_POLL_NO_TASK_COUNTER).inc(1);
112+
return null;
113+
}
114+
Timestamp startedTime = ProtobufTimeUtils.getCurrentProtoTime();
115+
metricsScope
116+
.timer(MetricsType.ACTIVITY_SCHEDULE_TO_START_LATENCY)
117+
.record(ProtobufTimeUtils.toM3Duration(startedTime, r.getScheduledTime()));
118+
return new ActivityTask(
119+
r,
120+
permit,
121+
() -> slotSupplier.releaseSlot(SlotReleaseReason.taskComplete(), permit));
122+
})
123+
.whenComplete(
124+
(r, e) ->
125+
MetricsTag.tagged(metricsScope, PollerTypeMetricsTag.PollerType.ACTIVITY_TASK)
126+
.gauge(MetricsType.NUM_POLLERS)
127+
.update(pollGauge.decrementAndGet()));
128+
}
129+
130+
@Override
131+
public void cancel(Throwable cause) {
132+
grpcContext.cancel(cause);
133+
}
134+
135+
@Override
136+
public String getLabel() {
137+
return "AsyncActivityPollTask";
138+
}
139+
140+
@Override
141+
public String toString() {
142+
return "AsyncActivityPollTask{}";
143+
}
144+
}

0 commit comments

Comments
 (0)