Skip to content

Commit 1e16551

Browse files
Enable serialization of predict response as tensor content.
PiperOrigin-RevId: 689691887
1 parent 36aa4a6 commit 1e16551

File tree

5 files changed

+16
-1
lines changed

5 files changed

+16
-1
lines changed

tensorflow_serving/model_servers/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,7 @@ cc_library(
497497
"//tensorflow_serving/servables/tensorflow:get_model_metadata_impl",
498498
"//tensorflow_serving/servables/tensorflow:multi_inference",
499499
"//tensorflow_serving/servables/tensorflow:predict_impl",
500+
"//tensorflow_serving/servables/tensorflow:predict_response_tensor_serialization_option",
500501
"//tensorflow_serving/servables/tensorflow:regression_service",
501502
"//tensorflow_serving/servables/tensorflow:saved_model_bundle_source_adapter",
502503
"//tensorflow_serving/servables/tensorflow:session_bundle_config_cc_proto",

tensorflow_serving/model_servers/main.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,11 @@ int main(int argc, char** argv) {
307307
"Whether to skip auto initializing TPU."),
308308
tensorflow::Flag("enable_grpc_healthcheck_service",
309309
&options.enable_grpc_healthcheck_service,
310-
"Enable the standard gRPC healthcheck service.")};
310+
"Enable the standard gRPC healthcheck service."),
311+
tensorflow::Flag(
312+
"enable_serialization_as_tensor_content",
313+
&options.enable_serialization_as_tensor_content,
314+
"Enable serialization of predict response as tensor content.")};
311315

312316
const auto& usage = tensorflow::Flags::Usage(argv[0], flag_list);
313317
if (!tensorflow::Flags::Parse(&argc, argv, flag_list)) {

tensorflow_serving/model_servers/server.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ limitations under the License.
5252
#include "tensorflow_serving/model_servers/model_platform_types.h"
5353
#include "tensorflow_serving/model_servers/server_core.h"
5454
#include "tensorflow_serving/model_servers/server_init.h"
55+
#include "tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
5556
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
5657
#include "tensorflow_serving/servables/tensorflow/thread_pool_factory_config.pb.h"
5758
#include "tensorflow_serving/servables/tensorflow/util.h"
@@ -321,6 +322,10 @@ Status Server::BuildAndStart(const Options& server_options) {
321322
options.force_allow_any_version_labels_for_unavailable_models =
322323
server_options.force_allow_any_version_labels_for_unavailable_models;
323324
options.enable_cors_support = server_options.enable_cors_support;
325+
if (server_options.enable_serialization_as_tensor_content) {
326+
options.predict_response_tensor_serialization_option =
327+
internal::PredictResponseTensorSerializationOption::kAsProtoContent;
328+
}
324329

325330
TF_RETURN_IF_ERROR(ServerCore::Create(std::move(options), &server_core_));
326331

tensorflow_serving/model_servers/server.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ class Server {
105105
bool skip_initialize_tpu = false;
106106
// Misc GRPC features
107107
bool enable_grpc_healthcheck_service = false;
108+
// Control whether to serialize predict response as tensor content.
109+
bool enable_serialization_as_tensor_content = false;
108110
Options();
109111
};
110112

tensorflow_serving/servables/tensorflow/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,9 @@ cc_test(
501501
cc_library(
502502
name = "predict_response_tensor_serialization_option",
503503
hdrs = ["predict_response_tensor_serialization_option.h"],
504+
visibility = [
505+
"//visibility:public",
506+
],
504507
)
505508

506509
cc_library(

0 commit comments

Comments
 (0)