-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathTensorFlowModelPrediction.java
More file actions
127 lines (111 loc) · 4.44 KB
/
TensorFlowModelPrediction.java
File metadata and controls
127 lines (111 loc) · 4.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
package plugins.frauzufall.deeplearning.tensorflow;
import icy.image.IcyBufferedImage;
import icy.sequence.Sequence;
import icy.type.DataType;
import org.apache.commons.io.FileUtils;
import org.apache.commons.math3.stat.StatUtils;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.MetaGraphDef;
import org.tensorflow.proto.framework.SignatureDef;
import plugins.frauzufall.deeplearning.AbstractModelPrediction;
import plugins.frauzufall.deeplearning.Converter;
import plugins.frauzufall.deeplearning.DefaultPredictionOptions;
import plugins.frauzufall.deeplearning.PredictionOptions;
import plugins.frauzufall.deeplearning.util.ZipUtils;
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.List;
/**
 * Runs predictions with a TensorFlow SavedModel that is stored as a zip archive.
 * <p>
 * The archive configured in {@link #getOptions()} is extracted to a temporary folder and loaded with the
 * {@code "serve"} tag. It is then queried through its {@code "serving_default"} signature. Input images
 * are percentile-normalized before they are handed to {@link AbstractModelPrediction#predict(Sequence, int)}.
 */
public class TensorFlowModelPrediction extends AbstractModelPrediction<Tensor> {

    /** Tag of the MetaGraph loaded from the SavedModel. */
    private static final String SERVING_TAG = "serve";

    /** Key of the signature used to look up the model's input and output tensor names. */
    private static final String DEFAULT_SERVING_SIGNATURE_DEF_KEY =
            "serving_default";

    private final PredictionOptions options;

    private SavedModelBundle modelBundle;
    private SignatureDef signature;

    /** Creates a prediction configured with {@link DefaultPredictionOptions}. */
    public TensorFlowModelPrediction() {
        this.options = new DefaultPredictionOptions();
    }

    /**
     * Returns the options of this prediction (model file, normalization percentiles).
     *
     * @return the prediction options
     */
    public PredictionOptions getOptions() {
        return options;
    }

    /**
     * Loads the model file from the options and looks up its {@code "serving_default"} signature.
     * <p>
     * A previously loaded model is closed only after the new one has been loaded successfully. On
     * failure, the previously loaded model (if any) stays active.
     *
     * @return {@code true} if the model and its serving signature were loaded, {@code false} otherwise
     *         (the stack trace is printed)
     */
    @Override
    public boolean initialize() {
        SavedModelBundle loaded = null;
        SignatureDef loadedSignature;
        try {
            loaded = load(getOptions().getModelFile().getAbsolutePath(), SERVING_TAG);
            // Parse the MetaGraphDef proto to access the model's signature map.
            loadedSignature = MetaGraphDef.parseFrom(loaded.metaGraphDef().toByteArray()).getSignatureDefOrThrow(
                    DEFAULT_SERVING_SIGNATURE_DEF_KEY);
        } catch (IOException | IllegalArgumentException e) {
            // IllegalArgumentException: the model has no "serving_default" signature.
            if (loaded != null) {
                loaded.close();
            }
            e.printStackTrace();
            return false;
        }
        System.out.println("Model inputs: " + loadedSignature.getInputsMap().toString().replace("\n", " ").replace("\t", " "));
        System.out.println("Model outputs: " + loadedSignature.getOutputsMap().toString().replace("\n", " ").replace("\t", " "));
        // Release the native resources of a model loaded by an earlier call.
        if (modelBundle != null) {
            modelBundle.close();
        }
        modelBundle = loaded;
        signature = loadedSignature;
        return true;
    }

    /**
     * Extracts a zipped SavedModel into a temporary folder, loads it, and removes the folder again.
     *
     * @param zipFile  absolute path of the zip archive containing the SavedModel
     * @param modelTag tag of the MetaGraph to load
     * @return the loaded bundle
     * @throws IOException if the temporary folder cannot be created (or, presumably, if extraction fails)
     */
    private SavedModelBundle load(String zipFile, String modelTag) throws IOException {
        Path tmpFolder = Files.createTempDirectory(new File(zipFile).getName());
        String tmpPath = tmpFolder.toAbsolutePath().toString();
        try {
            ZipUtils.extractFolder(zipFile, tmpPath);
            return SavedModelBundle.load(tmpPath, modelTag);
        } finally {
            // Remove the extracted files even when extraction or loading fails. Cleanup is best-effort:
            // a leftover temporary folder must not turn a successful load into a failure.
            FileUtils.deleteQuietly(tmpFolder.toFile());
        }
    }

    /**
     * Feeds {@code tensor} to the first input of the serving signature and returns the first output.
     *
     * @param tensor the input tensor
     * @return the tensor produced for the signature's first output
     */
    @Override
    protected Tensor process(Tensor tensor) {
        Session.Runner runner = getSession().runner();
        runner.feed(getInputTensorName(), tensor);
        runner.fetch(getOutputTensorName());
        List<Tensor<?>> tensors = runner.run();
        return tensors.get(0);
    }

    /** @return a converter between Icy sequences and TensorFlow tensors */
    @Override
    protected Converter<Tensor> createConverter() {
        return new TensorFlowConverter();
    }

    /** @return the graph name of the first input declared by the serving signature */
    private String getInputTensorName() {
        return signature.getInputsMap().values().iterator().next().getName();
    }

    /** @return the graph name of the first output declared by the serving signature */
    private String getOutputTensorName() {
        return signature.getOutputsMap().values().iterator().next().getName();
    }

    /** @return the session of the currently loaded model */
    private Session getSession() {
        return modelBundle.session();
    }

    /**
     * Predicts on the first time-point of {@code inputSequence}.
     *
     * @param inputSequence the sequence to process
     * @return the prediction result
     */
    public Sequence predict(Sequence inputSequence) {
        return predict(inputSequence, 0);
    }

    /**
     * Normalizes time-point {@code time} of {@code inputSequence} and predicts on it.
     *
     * @param inputSequence the sequence to process
     * @param time          the time-point to process
     * @return the prediction result
     */
    @Override
    public Sequence predict(Sequence inputSequence, int time) {
        Sequence normalized = normalize(inputSequence, time);
        // Because the normalized sequence has only 1 time-point, we have to process its first time-point, the number 0.
        return super.predict(normalized, 0);
    }

    /**
     * Normalizes the time-point of the specified sequence, and returns the results in a <b>1 time-point</b> sequence.
     * <p>
     * Only z-slice 0 and channel 0 are read. Each value {@code v} is mapped to
     * {@code (v - pLow) / (pHigh - pLow)}, where {@code pLow} and {@code pHigh} are the bottom and top
     * percentiles from the options. When both percentiles are equal, the values are only shifted by
     * {@code pLow}; dividing by zero would turn them into NaN or Infinity.
     *
     * @param input the sequence to normalize.
     * @param time the time-point to normalize.
     * @return a new sequence, with only one time-point.
     */
    private Sequence normalize(Sequence input, int time) {
        final int width = input.getWidth();
        final int height = input.getHeight();
        double[] values = new double[width * height];
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                values[x + y * width] = input.getData(time, 0, 0, y, x);
            }
        }
        double minVal = StatUtils.percentile(values, getOptions().getPercentileBottom());
        double maxVal = StatUtils.percentile(values, getOptions().getPercentileTop());
        double range = maxVal - minVal;
        // Guard against a zero range (e.g. a constant or mostly-background image).
        double factor = range != 0 ? 1. / range : 1.;
        // NOTE(review): the result has getSizeC() channels, but only channel 0 is filled below.
        IcyBufferedImage resImg = new IcyBufferedImage(width, height, input.getSizeC(), DataType.FLOAT);
        float[] valuesOut = new float[width * height];
        for (int i = 0; i < values.length; i++) {
            valuesOut[i] = (float) ((values[i] - minVal) * factor);
        }
        Sequence res = new Sequence(resImg);
        res.setDataXY(0, 0, 0, valuesOut);
        return res;
    }
}