Skip to content

Commit eb384db

Browse files
authored
feat: Make it possible to attach a PayloadProcessor to process model predictions (opendatahub-io#84)
#### Motivation This PR seeks to address the model-mesh side of kserve/modelmesh-serving#284. #### Modifications It provides a `PayloadProcessor` interface. `PayloadProcessors` are picked up by `ModelMesh` instances at startup, and predictions (`Payloads`) are processed asynchronously at fixed intervals. A first logger implementation allows logging `Payloads` (at _info_ level). #### Result An SPI for post-processing model predictions. --- resolves kserve/modelmesh-serving#284 Signed-off-by: Tommaso Teofili <[email protected]>
1 parent 426193c commit eb384db

20 files changed

+1237
-33
lines changed

config/base/kustomization.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ vars:
3232

3333
#patchesStrategicMerge:
3434
# - patches/etcd.yaml
35+
# - patches/logger.yaml
3536
# - patches/tls.yaml
3637
# - patches/uds.yaml
3738
# - patches/max_msg_size.yaml

config/base/patches/logger.yaml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright 2023 IBM Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Use this patch to attach the logging PayloadProcessor (via the
16+
# MM_PAYLOAD_PROCESSORS env var), which logs model payloads at info level
17+
#
18+
apiVersion: apps/v1
19+
kind: Deployment
20+
metadata:
21+
name: model-mesh
22+
spec:
23+
template:
24+
spec:
25+
containers:
26+
- name: mm
27+
env:
28+
- name: MM_PAYLOAD_PROCESSORS
29+
value: "logger://*"

src/main/java/com/ibm/watson/modelmesh/ModelMesh.java

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,12 @@
6161
import com.ibm.watson.modelmesh.TypeConstraintManager.ProhibitedTypeSet;
6262
import com.ibm.watson.modelmesh.clhm.ConcurrentLinkedHashMap;
6363
import com.ibm.watson.modelmesh.clhm.ConcurrentLinkedHashMap.EvictionListenerWithTime;
64+
import com.ibm.watson.modelmesh.payload.AsyncPayloadProcessor;
65+
import com.ibm.watson.modelmesh.payload.CompositePayloadProcessor;
66+
import com.ibm.watson.modelmesh.payload.LoggingPayloadProcessor;
67+
import com.ibm.watson.modelmesh.payload.MatchingPayloadProcessor;
68+
import com.ibm.watson.modelmesh.payload.PayloadProcessor;
69+
import com.ibm.watson.modelmesh.payload.RemotePayloadProcessor;
6470
import com.ibm.watson.modelmesh.thrift.ApplierException;
6571
import com.ibm.watson.modelmesh.thrift.BaseModelMeshService;
6672
import com.ibm.watson.modelmesh.thrift.InternalException;
@@ -101,6 +107,7 @@
101107
import java.lang.management.MemoryUsage;
102108
import java.lang.reflect.InvocationTargetException;
103109
import java.lang.reflect.Method;
110+
import java.net.URI;
104111
import java.nio.channels.ClosedByInterruptException;
105112
import java.text.ParseException;
106113
import java.text.SimpleDateFormat;
@@ -421,6 +428,40 @@ public abstract class ModelMesh extends ThriftService
421428
}
422429
}
423430

431+
private PayloadProcessor initPayloadProcessor() {
432+
String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null);
433+
logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions);
434+
if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) {
435+
List<PayloadProcessor> payloadProcessors = new ArrayList<>();
436+
for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) {
437+
try {
438+
URI uri = URI.create(processorDefinition);
439+
String processorName = uri.getScheme();
440+
PayloadProcessor processor = null;
441+
String modelId = uri.getQuery();
442+
String method = uri.getFragment();
443+
if ("http".equals(processorName)) {
444+
processor = new RemotePayloadProcessor(uri);
445+
} else if ("logger".equals(processorName)) {
446+
processor = new LoggingPayloadProcessor();
447+
}
448+
if (processor != null) {
449+
MatchingPayloadProcessor p = MatchingPayloadProcessor.from(modelId, method, processor);
450+
payloadProcessors.add(p);
451+
logger.info("Added PayloadProcessor {}", p.getName());
452+
}
453+
} catch (IllegalArgumentException iae) {
454+
logger.error("Unable to parse PayloadProcessor URI definition {}", processorDefinition);
455+
}
456+
}
457+
return new AsyncPayloadProcessor(new CompositePayloadProcessor(payloadProcessors), 1, MINUTES,
458+
Executors.newScheduledThreadPool(getIntParameter(MM_PAYLOAD_PROCESSORS_THREADS, 2)),
459+
getIntParameter(MM_PAYLOAD_PROCESSORS_CAPACITY, 64));
460+
} else {
461+
return null;
462+
}
463+
}
464+
424465
/* ---------------------------------- initialization --------------------------------------------------------- */
425466

426467
@Override
@@ -854,10 +895,11 @@ protected final TProcessor initialize() throws Exception {
854895
}
855896

856897
LogRequestHeaders logHeaders = LogRequestHeaders.getConfiguredLogRequestHeaders();
898+
PayloadProcessor payloadProcessor = initPayloadProcessor();
857899

858900
grpcServer = new ModelMeshApi((SidecarModelMesh) this, vModelManager, GRPC_PORT, keyCertFile, privateKeyFile,
859901
privateKeyPassphrase, clientAuth, caCertFiles, maxGrpcMessageSize, maxGrpcHeadersSize,
860-
maxGrpcConnectionAge, maxGrpcConnectionAgeGrace, logHeaders);
902+
maxGrpcConnectionAge, maxGrpcConnectionAgeGrace, logHeaders, payloadProcessor);
861903
}
862904

863905
if (grpcServer != null) {

src/main/java/com/ibm/watson/modelmesh/ModelMeshApi.java

Lines changed: 92 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@
4545
import com.ibm.watson.modelmesh.api.UnregisterModelRequest;
4646
import com.ibm.watson.modelmesh.api.UnregisterModelResponse;
4747
import com.ibm.watson.modelmesh.api.VModelStatusInfo;
48+
import com.ibm.watson.modelmesh.payload.Payload;
49+
import com.ibm.watson.modelmesh.payload.PayloadProcessor;
4850
import com.ibm.watson.modelmesh.thrift.ApplierException;
4951
import com.ibm.watson.modelmesh.thrift.InvalidInputException;
5052
import com.ibm.watson.modelmesh.thrift.InvalidStateException;
@@ -156,6 +158,10 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase
156158
// null if header logging is not enabled.
157159
protected final LogRequestHeaders logHeaders;
158160

161+
private final PayloadProcessor payloadProcessor;
162+
163+
private final ThreadLocal<long[]> localIdCounter = ThreadLocal.withInitial(() -> new long[1]);
164+
159165
/**
160166
* Create <b>and start</b> the server.
161167
*
@@ -171,16 +177,18 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase
171177
* @param maxConnectionAge in seconds
172178
* @param maxConnectionAgeGrace in seconds, custom grace time for graceful connection termination
173179
* @param logHeaders
180+
* @param payloadProcessor a processor of payloads
174181
* @throws IOException
175182
*/
176183
public ModelMeshApi(SidecarModelMesh delegate, VModelManager vmm, int port, File keyCert, File privateKey,
177184
String privateKeyPassphrase, ClientAuth clientAuth, File[] trustCerts,
178185
int maxMessageSize, int maxHeadersSize, long maxConnectionAge, long maxConnectionAgeGrace,
179-
LogRequestHeaders logHeaders) throws IOException {
186+
LogRequestHeaders logHeaders, PayloadProcessor payloadProcessor) throws IOException {
180187

181188
this.delegate = delegate;
182189
this.vmm = vmm;
183190
this.logHeaders = logHeaders;
191+
this.payloadProcessor = payloadProcessor;
184192

185193
this.multiParallelism = getMultiParallelism();
186194

@@ -293,6 +301,13 @@ public void shutdown(long timeout, TimeUnit unit) throws InterruptedException {
293301
if (!done) {
294302
server.shutdownNow();
295303
}
304+
if (payloadProcessor != null) {
305+
try {
306+
payloadProcessor.close();
307+
} catch (IOException e) {
308+
logger.warn("Error closing PayloadProcessor {}: {}", payloadProcessor, e.getMessage());
309+
}
310+
}
296311
threads.shutdownNow();
297312
shutdownEventLoops();
298313
}
@@ -686,49 +701,57 @@ public void onHalfClose() {
686701
call.close(INTERNAL.withDescription("Half-closed without a request"), emptyMeta());
687702
return;
688703
}
689-
final int reqSize = reqMessage.readableBytes();
704+
int reqReaderIndex = reqMessage.readerIndex();
705+
int reqSize = reqMessage.readableBytes();
690706
int respSize = -1;
707+
int respReaderIndex = 0;
708+
691709
io.grpc.Status status = INTERNAL;
710+
String modelId = null;
711+
String requestId = null;
712+
ModelResponse response = null;
692713
try (InterruptingListener cancelListener = newInterruptingListener()) {
693714
if (logHeaders != null) {
694715
logHeaders.addToMDC(headers); // MDC cleared in finally block
695716
}
696-
ModelResponse response = null;
717+
if (payloadProcessor != null) {
718+
requestId = Thread.currentThread().getId() + "-" + ++localIdCounter.get()[0];
719+
}
697720
try {
698-
try {
699-
String balancedMetaVal = headers.get(BALANCED_META_KEY);
700-
Iterator<String> midIt = modelIds.iterator();
701-
// guaranteed at least one
702-
String modelId = validateModelId(midIt.next(), isVModel);
703-
if (!midIt.hasNext()) {
704-
// single model case (most common)
705-
response = callModel(modelId, isVModel, methodName,
706-
balancedMetaVal, headers, reqMessage).retain();
707-
} else {
708-
// multi-model case (specialized)
709-
boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY));
710-
List<String> idList = new ArrayList<>();
711-
idList.add(modelId);
712-
while (midIt.hasNext()) {
713-
idList.add(validateModelId(midIt.next(), isVModel));
714-
}
715-
response = applyParallelMultiModel(idList, isVModel, methodName,
716-
balancedMetaVal, headers, reqMessage, allRequired);
721+
String balancedMetaVal = headers.get(BALANCED_META_KEY);
722+
Iterator<String> midIt = modelIds.iterator();
723+
// guaranteed at least one
724+
modelId = validateModelId(midIt.next(), isVModel);
725+
if (!midIt.hasNext()) {
726+
// single model case (most common)
727+
response = callModel(modelId, isVModel, methodName,
728+
balancedMetaVal, headers, reqMessage).retain();
729+
} else {
730+
// multi-model case (specialized)
731+
boolean allRequired = "all".equalsIgnoreCase(headers.get(REQUIRED_KEY));
732+
List<String> idList = new ArrayList<>();
733+
idList.add(modelId);
734+
while (midIt.hasNext()) {
735+
idList.add(validateModelId(midIt.next(), isVModel));
717736
}
718-
} finally {
719-
releaseReqMessage();
737+
response = applyParallelMultiModel(idList, isVModel, methodName,
738+
balancedMetaVal, headers, reqMessage, allRequired);
720739
}
721-
722-
respSize = response.data.readableBytes();
723-
call.sendHeaders(response.metadata);
724-
call.sendMessage(response.data);
725-
response = null;
726740
} finally {
727-
if (response != null) {
728-
response.release();
741+
if (payloadProcessor != null) {
742+
processPayload(reqMessage.readerIndex(reqReaderIndex),
743+
requestId, modelId, methodName, headers, null, true);
744+
} else {
745+
releaseReqMessage();
729746
}
747+
reqMessage = null; // ownership released or transferred
730748
}
731749

750+
respReaderIndex = response.data.readerIndex();
751+
respSize = response.data.readableBytes();
752+
call.sendHeaders(response.metadata);
753+
call.sendMessage(response.data);
754+
// response is released via ReleaseAfterResponse.releaseAll()
732755
status = OK;
733756
} catch (Exception e) {
734757
status = toStatus(e);
@@ -745,6 +768,15 @@ public void onHalfClose() {
745768
evictMethodDescriptor(methodName);
746769
}
747770
} finally {
771+
if (payloadProcessor != null) {
772+
ByteBuf data = null;
773+
Metadata metadata = null;
774+
if (response != null) {
775+
data = response.data.readerIndex(respReaderIndex);
776+
metadata = response.metadata;
777+
}
778+
processPayload(data, requestId, modelId, methodName, metadata, status, false);
779+
}
748780
ReleaseAfterResponse.releaseAll();
749781
clearThreadLocals();
750782
//TODO(maybe) additional trailer info in exception case?
@@ -757,6 +789,35 @@ public void onHalfClose() {
757789
}
758790
}
759791

792+
/**
793+
* Invoke PayloadProcessor on the request/response data
794+
* @param data the binary data
795+
* @param payloadId the id of the request
796+
* @param modelId the id of the model
797+
* @param methodName the name of the invoked method
798+
* @param metadata the method name metadata
799+
* @param status null for requests, non-null for responses
800+
* @param takeOwnership whether the processor should take ownership
801+
*/
802+
private void processPayload(ByteBuf data, String payloadId, String modelId, String methodName,
803+
Metadata metadata, io.grpc.Status status, boolean takeOwnership) {
804+
Payload payload = null;
805+
try {
806+
assert payloadProcessor != null;
807+
if (!takeOwnership) {
808+
data.retain();
809+
}
810+
payload = new Payload(payloadId, modelId, methodName, metadata, data, status);
811+
if (payloadProcessor.process(payload)) {
812+
data = null; // ownership transferred
813+
}
814+
} catch (Throwable t) {
815+
logger.warn("Error while processing payload: {}", payload, t);
816+
} finally {
817+
ReferenceCountUtil.release(data);
818+
}
819+
}
820+
760821
@Override
761822
public void onComplete() {
762823
releaseReqMessage();

src/main/java/com/ibm/watson/modelmesh/ModelMeshEnvVars.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ public final class ModelMeshEnvVars {
2323

2424
private ModelMeshEnvVars() {}
2525

26+
public static final String MM_PAYLOAD_PROCESSORS = "MM_PAYLOAD_PROCESSORS";
27+
public static final String MM_PAYLOAD_PROCESSORS_THREADS = "MM_PAYLOAD_PROCESSORS_THREADS";
28+
public static final String MM_PAYLOAD_PROCESSORS_CAPACITY = "MM_PAYLOAD_PROCESSORS_CAPACITY";
29+
2630
// This must not be changed after model-mesh is already deployed to a particular env
2731
public static final String KV_STORE_PREFIX = "MM_KVSTORE_PREFIX";
2832

0 commit comments

Comments
 (0)