Skip to content

Commit f1e1341

Browse files
Add option to configure the name of the input layer of remote model.
This binary is useful to test other models based on resnet that may have slightly different names. PiperOrigin-RevId: 602779121
1 parent db74e57 commit f1e1341

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

tensorflow_serving/example/resnet_client.cc

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ class ServingClient {
119119

120120
tensorflow::string callPredict(const tensorflow::string& model_name,
121121
const tensorflow::string& model_signature_name,
122+
const tensorflow::string& input_name,
122123
const tensorflow::string& file_path) {
123124
PredictRequest predictRequest;
124125
PredictResponse response;
@@ -140,7 +141,7 @@ class ServingClient {
140141
return "execution failed";
141142
}
142143

143-
inputs["input_1"] = proto;
144+
inputs[input_name] = proto;
144145

145146
Status status = stub_->Predict(&context, predictRequest, &response);
146147

@@ -182,13 +183,15 @@ int main(int argc, char** argv) {
182183
tensorflow::string image_file = "";
183184
tensorflow::string model_name = "resnet";
184185
tensorflow::string model_signature_name = "serving_default";
186+
tensorflow::string input_name = "input_1";
185187
std::vector<tensorflow::Flag> flag_list = {
186188
tensorflow::Flag("server_port", &server_port,
187189
"the IP and port of the server"),
188190
tensorflow::Flag("image_file", &image_file, "the path to the image"),
189191
tensorflow::Flag("model_name", &model_name, "name of model"),
190192
tensorflow::Flag("model_signature_name", &model_signature_name,
191-
"name of model signature")};
193+
"name of model signature"),
194+
tensorflow::Flag("input_name", &input_name, "name of input tensor")};
192195

193196
tensorflow::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
194197
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
@@ -201,7 +204,8 @@ int main(int argc, char** argv) {
201204
grpc::CreateChannel(server_port, grpc::InsecureChannelCredentials()));
202205
std::cout << "calling predict using file: " << image_file << " ..."
203206
<< std::endl;
204-
std::cout << guide.callPredict(model_name, model_signature_name, image_file)
207+
std::cout << guide.callPredict(model_name, model_signature_name, input_name,
208+
image_file)
205209
<< std::endl;
206210
return 0;
207211
}

0 commit comments

Comments
 (0)