Skip to content

Commit 2405dba

Browse files
Merge pull request opendatahub-io#65 from ruivieira/RHOAIENG-4963-b
RHOAIENG-4963: ModelMesh should support TLS in payload processors
2 parents cf3fcf6 + ed8161a commit 2405dba

File tree

3 files changed

+100
-16
lines changed

3 files changed

+100
-16
lines changed

src/main/java/com/ibm/watson/modelmesh/ModelMesh.java

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,15 +100,18 @@
100100
import org.eclipse.collections.impl.list.mutable.primitive.IntArrayList;
101101

102102
import javax.annotation.concurrent.GuardedBy;
103-
import java.io.File;
104-
import java.io.InterruptedIOException;
103+
import javax.net.ssl.SSLContext;
104+
import javax.net.ssl.TrustManagerFactory;
105+
import java.io.*;
105106
import java.lang.management.ManagementFactory;
106107
import java.lang.management.MemoryMXBean;
107108
import java.lang.management.MemoryUsage;
108109
import java.lang.reflect.InvocationTargetException;
109110
import java.lang.reflect.Method;
110111
import java.net.URI;
111112
import java.nio.channels.ClosedByInterruptException;
113+
import java.security.KeyStore;
114+
import java.security.NoSuchAlgorithmException;
112115
import java.text.ParseException;
113116
import java.text.SimpleDateFormat;
114117
import java.util.*;
@@ -428,10 +431,38 @@ public abstract class ModelMesh extends ThriftService
428431
}
429432
}
430433

434+
private static final String SSL_TRUSTSTORE_PATH_PROPERTY = "watson.ssl.truststore.path";
435+
private static final String SSL_TRUSTSTORE_PASSWORD_PROPERTY = "watson.ssl.truststore.password";
436+
437+
private static SSLContext sslContext = null;
438+
439+
private static SSLContext loadSSLContext() throws Exception {
440+
if (sslContext == null) {
441+
final String trustStorePath = System.getProperty(SSL_TRUSTSTORE_PATH_PROPERTY);
442+
final String trustStorePassword = System.getProperty(SSL_TRUSTSTORE_PASSWORD_PROPERTY);
443+
444+
if (trustStorePath == null || trustStorePassword == null) {
445+
throw new IllegalArgumentException("Truststore settings not found in system properties");
446+
}
447+
448+
final KeyStore trustStore = KeyStore.getInstance("JKS");
449+
try (FileInputStream trustStoreStream = new FileInputStream(trustStorePath)) {
450+
trustStore.load(trustStoreStream, trustStorePassword.toCharArray());
451+
}
452+
453+
final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
454+
trustManagerFactory.init(trustStore);
455+
456+
sslContext = SSLContext.getInstance("TLS");
457+
sslContext.init(null, trustManagerFactory.getTrustManagers(), null);
458+
}
459+
return sslContext;
460+
}
461+
431462
private PayloadProcessor initPayloadProcessor() {
432463
String payloadProcessorsDefinitions = getStringParameter(MM_PAYLOAD_PROCESSORS, null);
433464
logger.info("Parsing PayloadProcessor definition '{}'", payloadProcessorsDefinitions);
434-
if (payloadProcessorsDefinitions != null && payloadProcessorsDefinitions.length() > 0) {
465+
if (payloadProcessorsDefinitions != null && !payloadProcessorsDefinitions.isEmpty()) {
435466
List<PayloadProcessor> payloadProcessors = new ArrayList<>();
436467
for (String processorDefinition : payloadProcessorsDefinitions.split(" ")) {
437468
try {
@@ -441,7 +472,17 @@ private PayloadProcessor initPayloadProcessor() {
441472
String modelId = uri.getQuery();
442473
String method = uri.getFragment();
443474
if ("http".equals(processorName)) {
475+
logger.info("Initializing HTTP payload processor");
444476
processor = new RemotePayloadProcessor(uri);
477+
} else if ("https".equals(processorName)) {
478+
SSLContext sslContext;
479+
try {
480+
sslContext = loadSSLContext();
481+
} catch (Exception missingAlgorithmException) {
482+
throw new UncheckedIOException(new IOException(missingAlgorithmException));
483+
}
484+
logger.info("Initializing HTTPS payload processor");
485+
processor = new RemotePayloadProcessor(uri, sslContext, sslContext.getDefaultSSLParameters());
445486
} else if ("logger".equals(processorName)) {
446487
processor = new LoggingPayloadProcessor();
447488
}

src/main/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessor.java

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import java.nio.charset.StandardCharsets;
2424
import java.util.HashMap;
2525
import java.util.Map;
26+
import javax.net.ssl.SSLContext;
27+
import javax.net.ssl.SSLParameters;
2628

2729
import com.fasterxml.jackson.databind.ObjectMapper;
2830
import io.grpc.Metadata;
@@ -42,11 +44,27 @@ public class RemotePayloadProcessor implements PayloadProcessor {
4244

4345
private final URI uri;
4446

47+
private final SSLContext sslContext;
48+
private final SSLParameters sslParameters;
49+
4550
private final HttpClient client;
4651

4752
public RemotePayloadProcessor(URI uri) {
53+
this(uri, null, null);
54+
}
55+
56+
public RemotePayloadProcessor(URI uri, SSLContext sslContext, SSLParameters sslParameters) {
4857
this.uri = uri;
49-
this.client = HttpClient.newHttpClient();
58+
this.sslContext = sslContext;
59+
this.sslParameters = sslParameters;
60+
if (sslContext != null && sslParameters != null) {
61+
this.client = HttpClient.newBuilder()
62+
.sslContext(sslContext)
63+
.sslParameters(sslParameters)
64+
.build();
65+
} else {
66+
this.client = HttpClient.newHttpClient();
67+
}
5068
}
5169

5270
@Override

src/test/java/com/ibm/watson/modelmesh/payload/RemotePayloadProcessorTest.java

Lines changed: 37 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -16,30 +16,55 @@
1616

1717
package com.ibm.watson.modelmesh.payload;
1818

19+
import java.io.IOException;
1920
import java.net.URI;
21+
import java.security.NoSuchAlgorithmException;
2022

2123
import io.grpc.Metadata;
2224
import io.grpc.Status;
2325
import io.netty.buffer.ByteBuf;
2426
import io.netty.buffer.Unpooled;
2527
import org.junit.jupiter.api.Test;
2628

29+
import javax.net.ssl.SSLContext;
30+
import javax.net.ssl.SSLParameters;
31+
2732
import static org.junit.jupiter.api.Assertions.assertFalse;
2833

2934
class RemotePayloadProcessorTest {
3035

36+
void testDestinationUnreachable() throws IOException {
37+
URI uri = URI.create("http://this-does-not-exist:123");
38+
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri)) {
39+
String id = "123";
40+
String modelId = "456";
41+
String method = "predict";
42+
Status kind = Status.INVALID_ARGUMENT;
43+
Metadata metadata = new Metadata();
44+
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
45+
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
46+
ByteBuf data = Unpooled.buffer(4);
47+
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
48+
assertFalse(remotePayloadProcessor.process(payload));
49+
}
50+
}
51+
3152
@Test
32-
void testDestinationUnreachable() {
33-
RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(URI.create("http://this-does-not-exist:123"));
34-
String id = "123";
35-
String modelId = "456";
36-
String method = "predict";
37-
Status kind = Status.INVALID_ARGUMENT;
38-
Metadata metadata = new Metadata();
39-
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
40-
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
41-
ByteBuf data = Unpooled.buffer(4);
42-
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
43-
assertFalse(remotePayloadProcessor.process(payload));
53+
void testDestinationUnreachableHTTPS() throws IOException, NoSuchAlgorithmException {
54+
URI uri = URI.create("https://this-does-not-exist:123");
55+
SSLContext sslContext = SSLContext.getDefault();
56+
SSLParameters sslParameters = sslContext.getDefaultSSLParameters();
57+
try (RemotePayloadProcessor remotePayloadProcessor = new RemotePayloadProcessor(uri, sslContext, sslParameters)) {
58+
String id = "123";
59+
String modelId = "456";
60+
String method = "predict";
61+
Status kind = Status.INVALID_ARGUMENT;
62+
Metadata metadata = new Metadata();
63+
metadata.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar");
64+
metadata.put(Metadata.Key.of("binary-bin", Metadata.BINARY_BYTE_MARSHALLER), "string".getBytes());
65+
ByteBuf data = Unpooled.buffer(4);
66+
Payload payload = new Payload(id, modelId, method, metadata, data, kind);
67+
assertFalse(remotePayloadProcessor.process(payload));
68+
}
4469
}
4570
}

0 commit comments

Comments
 (0)