45
45
import com .ibm .watson .modelmesh .api .UnregisterModelRequest ;
46
46
import com .ibm .watson .modelmesh .api .UnregisterModelResponse ;
47
47
import com .ibm .watson .modelmesh .api .VModelStatusInfo ;
48
+ import com .ibm .watson .modelmesh .payload .Payload ;
49
+ import com .ibm .watson .modelmesh .payload .PayloadProcessor ;
48
50
import com .ibm .watson .modelmesh .thrift .ApplierException ;
49
51
import com .ibm .watson .modelmesh .thrift .InvalidInputException ;
50
52
import com .ibm .watson .modelmesh .thrift .InvalidStateException ;
@@ -156,6 +158,10 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase
156
158
// null if header logging is not enabled.
157
159
protected final LogRequestHeaders logHeaders ;
158
160
161
+ private final PayloadProcessor payloadProcessor ;
162
+
163
+ private final ThreadLocal <long []> localIdCounter = ThreadLocal .withInitial (() -> new long [1 ]);
164
+
159
165
/**
160
166
* Create <b>and start</b> the server.
161
167
*
@@ -171,16 +177,18 @@ public final class ModelMeshApi extends ModelMeshGrpc.ModelMeshImplBase
171
177
* @param maxConnectionAge in seconds
172
178
* @param maxConnectionAgeGrace in seconds, custom grace time for graceful connection termination
173
179
* @param logHeaders
180
+ * @param payloadProcessor a processor of payloads
174
181
* @throws IOException
175
182
*/
176
183
public ModelMeshApi (SidecarModelMesh delegate , VModelManager vmm , int port , File keyCert , File privateKey ,
177
184
String privateKeyPassphrase , ClientAuth clientAuth , File [] trustCerts ,
178
185
int maxMessageSize , int maxHeadersSize , long maxConnectionAge , long maxConnectionAgeGrace ,
179
- LogRequestHeaders logHeaders ) throws IOException {
186
+ LogRequestHeaders logHeaders , PayloadProcessor payloadProcessor ) throws IOException {
180
187
181
188
this .delegate = delegate ;
182
189
this .vmm = vmm ;
183
190
this .logHeaders = logHeaders ;
191
+ this .payloadProcessor = payloadProcessor ;
184
192
185
193
this .multiParallelism = getMultiParallelism ();
186
194
@@ -293,6 +301,13 @@ public void shutdown(long timeout, TimeUnit unit) throws InterruptedException {
293
301
if (!done ) {
294
302
server .shutdownNow ();
295
303
}
304
+ if (payloadProcessor != null ) {
305
+ try {
306
+ payloadProcessor .close ();
307
+ } catch (IOException e ) {
308
+ logger .warn ("Error closing PayloadProcessor {}: {}" , payloadProcessor , e .getMessage ());
309
+ }
310
+ }
296
311
threads .shutdownNow ();
297
312
shutdownEventLoops ();
298
313
}
@@ -686,49 +701,57 @@ public void onHalfClose() {
686
701
call .close (INTERNAL .withDescription ("Half-closed without a request" ), emptyMeta ());
687
702
return ;
688
703
}
689
- final int reqSize = reqMessage .readableBytes ();
704
+ int reqReaderIndex = reqMessage .readerIndex ();
705
+ int reqSize = reqMessage .readableBytes ();
690
706
int respSize = -1 ;
707
+ int respReaderIndex = 0 ;
708
+
691
709
io .grpc .Status status = INTERNAL ;
710
+ String modelId = null ;
711
+ String requestId = null ;
712
+ ModelResponse response = null ;
692
713
try (InterruptingListener cancelListener = newInterruptingListener ()) {
693
714
if (logHeaders != null ) {
694
715
logHeaders .addToMDC (headers ); // MDC cleared in finally block
695
716
}
696
- ModelResponse response = null ;
717
+ if (payloadProcessor != null ) {
718
+ requestId = Thread .currentThread ().getId () + "-" + ++localIdCounter .get ()[0 ];
719
+ }
697
720
try {
698
- try {
699
- String balancedMetaVal = headers .get (BALANCED_META_KEY );
700
- Iterator <String > midIt = modelIds .iterator ();
701
- // guaranteed at least one
702
- String modelId = validateModelId (midIt .next (), isVModel );
703
- if (!midIt .hasNext ()) {
704
- // single model case (most common)
705
- response = callModel (modelId , isVModel , methodName ,
706
- balancedMetaVal , headers , reqMessage ).retain ();
707
- } else {
708
- // multi-model case (specialized)
709
- boolean allRequired = "all" .equalsIgnoreCase (headers .get (REQUIRED_KEY ));
710
- List <String > idList = new ArrayList <>();
711
- idList .add (modelId );
712
- while (midIt .hasNext ()) {
713
- idList .add (validateModelId (midIt .next (), isVModel ));
714
- }
715
- response = applyParallelMultiModel (idList , isVModel , methodName ,
716
- balancedMetaVal , headers , reqMessage , allRequired );
721
+ String balancedMetaVal = headers .get (BALANCED_META_KEY );
722
+ Iterator <String > midIt = modelIds .iterator ();
723
+ // guaranteed at least one
724
+ modelId = validateModelId (midIt .next (), isVModel );
725
+ if (!midIt .hasNext ()) {
726
+ // single model case (most common)
727
+ response = callModel (modelId , isVModel , methodName ,
728
+ balancedMetaVal , headers , reqMessage ).retain ();
729
+ } else {
730
+ // multi-model case (specialized)
731
+ boolean allRequired = "all" .equalsIgnoreCase (headers .get (REQUIRED_KEY ));
732
+ List <String > idList = new ArrayList <>();
733
+ idList .add (modelId );
734
+ while (midIt .hasNext ()) {
735
+ idList .add (validateModelId (midIt .next (), isVModel ));
717
736
}
718
- } finally {
719
- releaseReqMessage ( );
737
+ response = applyParallelMultiModel ( idList , isVModel , methodName ,
738
+ balancedMetaVal , headers , reqMessage , allRequired );
720
739
}
721
-
722
- respSize = response .data .readableBytes ();
723
- call .sendHeaders (response .metadata );
724
- call .sendMessage (response .data );
725
- response = null ;
726
740
} finally {
727
- if (response != null ) {
728
- response .release ();
741
+ if (payloadProcessor != null ) {
742
+ processPayload (reqMessage .readerIndex (reqReaderIndex ),
743
+ requestId , modelId , methodName , headers , null , true );
744
+ } else {
745
+ releaseReqMessage ();
729
746
}
747
+ reqMessage = null ; // ownership released or transferred
730
748
}
731
749
750
+ respReaderIndex = response .data .readerIndex ();
751
+ respSize = response .data .readableBytes ();
752
+ call .sendHeaders (response .metadata );
753
+ call .sendMessage (response .data );
754
+ // response is released via ReleaseAfterResponse.releaseAll()
732
755
status = OK ;
733
756
} catch (Exception e ) {
734
757
status = toStatus (e );
@@ -745,6 +768,15 @@ public void onHalfClose() {
745
768
evictMethodDescriptor (methodName );
746
769
}
747
770
} finally {
771
+ if (payloadProcessor != null ) {
772
+ ByteBuf data = null ;
773
+ Metadata metadata = null ;
774
+ if (response != null ) {
775
+ data = response .data .readerIndex (respReaderIndex );
776
+ metadata = response .metadata ;
777
+ }
778
+ processPayload (data , requestId , modelId , methodName , metadata , status , false );
779
+ }
748
780
ReleaseAfterResponse .releaseAll ();
749
781
clearThreadLocals ();
750
782
//TODO(maybe) additional trailer info in exception case?
@@ -757,6 +789,35 @@ public void onHalfClose() {
757
789
}
758
790
}
759
791
792
+ /**
793
+ * Invoke PayloadProcessor on the request/response data
794
+ * @param data the binary data
795
+ * @param payloadId the id of the request
796
+ * @param modelId the id of the model
797
+ * @param methodName the name of the invoked method
798
+ * @param metadata the method name metadata
799
+ * @param status null for requests, non-null for responses
800
+ * @param takeOwnership whether the processor should take ownership
801
+ */
802
+ private void processPayload (ByteBuf data , String payloadId , String modelId , String methodName ,
803
+ Metadata metadata , io .grpc .Status status , boolean takeOwnership ) {
804
+ Payload payload = null ;
805
+ try {
806
+ assert payloadProcessor != null ;
807
+ if (!takeOwnership ) {
808
+ data .retain ();
809
+ }
810
+ payload = new Payload (payloadId , modelId , methodName , metadata , data , status );
811
+ if (payloadProcessor .process (payload )) {
812
+ data = null ; // ownership transferred
813
+ }
814
+ } catch (Throwable t ) {
815
+ logger .warn ("Error while processing payload: {}" , payload , t );
816
+ } finally {
817
+ ReferenceCountUtil .release (data );
818
+ }
819
+ }
820
+
760
821
@ Override
761
822
public void onComplete () {
762
823
releaseReqMessage ();
0 commit comments