File tree Expand file tree Collapse file tree 5 files changed +16
-1
lines changed Expand file tree Collapse file tree 5 files changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -497,6 +497,7 @@ cc_library(
497
497
"//tensorflow_serving/servables/tensorflow:get_model_metadata_impl" ,
498
498
"//tensorflow_serving/servables/tensorflow:multi_inference" ,
499
499
"//tensorflow_serving/servables/tensorflow:predict_impl" ,
500
+ "//tensorflow_serving/servables/tensorflow:predict_response_tensor_serialization_option" ,
500
501
"//tensorflow_serving/servables/tensorflow:regression_service" ,
501
502
"//tensorflow_serving/servables/tensorflow:saved_model_bundle_source_adapter" ,
502
503
"//tensorflow_serving/servables/tensorflow:session_bundle_config_cc_proto" ,
Original file line number Diff line number Diff line change @@ -307,7 +307,11 @@ int main(int argc, char** argv) {
307
307
" Whether to skip auto initializing TPU." ),
308
308
tensorflow::Flag (" enable_grpc_healthcheck_service" ,
309
309
&options.enable_grpc_healthcheck_service ,
310
- " Enable the standard gRPC healthcheck service." )};
310
+ " Enable the standard gRPC healthcheck service." ),
311
+ tensorflow::Flag (
312
+ " enable_serialization_as_tensor_content" ,
313
+ &options.enable_serialization_as_tensor_content ,
314
+ " Enable serialization of predict response as tensor content." )};
311
315
312
316
const auto & usage = tensorflow::Flags::Usage (argv[0 ], flag_list);
313
317
if (!tensorflow::Flags::Parse (&argc, argv, flag_list)) {
Original file line number Diff line number Diff line change @@ -52,6 +52,7 @@ limitations under the License.
52
52
#include " tensorflow_serving/model_servers/model_platform_types.h"
53
53
#include " tensorflow_serving/model_servers/server_core.h"
54
54
#include " tensorflow_serving/model_servers/server_init.h"
55
+ #include " tensorflow_serving/servables/tensorflow/predict_response_tensor_serialization_option.h"
55
56
#include " tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
56
57
#include " tensorflow_serving/servables/tensorflow/thread_pool_factory_config.pb.h"
57
58
#include " tensorflow_serving/servables/tensorflow/util.h"
@@ -321,6 +322,10 @@ Status Server::BuildAndStart(const Options& server_options) {
321
322
options.force_allow_any_version_labels_for_unavailable_models =
322
323
server_options.force_allow_any_version_labels_for_unavailable_models ;
323
324
options.enable_cors_support = server_options.enable_cors_support ;
325
+ if (server_options.enable_serialization_as_tensor_content ) {
326
+ options.predict_response_tensor_serialization_option =
327
+ internal::PredictResponseTensorSerializationOption::kAsProtoContent ;
328
+ }
324
329
325
330
TF_RETURN_IF_ERROR (ServerCore::Create (std::move (options), &server_core_));
326
331
Original file line number Diff line number Diff line change @@ -105,6 +105,8 @@ class Server {
105
105
bool skip_initialize_tpu = false ;
106
106
// Misc GRPC features
107
107
bool enable_grpc_healthcheck_service = false ;
108
+ // Control whether to serialize predict response as tensor content.
109
+ bool enable_serialization_as_tensor_content = false ;
108
110
Options ();
109
111
};
110
112
Original file line number Diff line number Diff line change @@ -501,6 +501,9 @@ cc_test(
501
501
cc_library (
502
502
name = "predict_response_tensor_serialization_option" ,
503
503
hdrs = ["predict_response_tensor_serialization_option.h" ],
504
+ visibility = [
505
+ "//visibility:public" ,
506
+ ],
504
507
)
505
508
506
509
cc_library (
You can’t perform that action at this time.
0 commit comments