diff --git a/src/java/src/main/java/triton/client/InferResult.java b/src/java/src/main/java/triton/client/InferResult.java index c5166f3f0..7870fae71 100644 --- a/src/java/src/main/java/triton/client/InferResult.java +++ b/src/java/src/main/java/triton/client/InferResult.java @@ -307,6 +307,39 @@ public double[] getOutputAsDouble(String output) return (double[]) getOutputImpl(out, double.class, ByteBuffer::getDouble); } + /** + * Get String tensor named as by parameter output. The tensor must be of + * DataType.BYTES. + * + * @param output name of output tensor. + * @return null if output not found or the tensor in String. + */ + public String getOutputAsString(String output) + { + IOTensor out = this.response.getOutputByName(output); + if (out == null) { + return null; + } + + Object[] data = out.getData(); + if (data == null || data.length == 0) { + return null; + } + + Preconditions.checkArgument( + out.getDatatype() == DataType.BYTES, + "Could not get String from data of type %s on output %s.", + out.getDatatype(), out.getName()); + if (data[0] instanceof String) { + return (String) data[0]; + } else if (data[0] instanceof byte[]) { + return new String( + (byte[]) data[0], java.nio.charset.StandardCharsets.UTF_8); + } else { + return data[0].toString(); + } + } + private Object getOutputImpl(IOTensor out, Class clazz, Function getter) {