Skip to content

Commit 8be19ce

Browse files
authored
Merge pull request #323 from kirilg/branch_147391542
Upstream internal changes
2 parents be063be + a89f907 commit 8be19ce

File tree

8 files changed

+209
-23
lines changed

8 files changed

+209
-23
lines changed

tensorflow_serving/apis/prediction_service_pb2.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
# ==============================================================================
1515
# Generated by the protocol buffer compiler. DO NOT EDIT!
1616
# source: tensorflow_serving/apis/prediction_service.proto
17+
# To regenerate run
18+
# python -m grpc.tools.protoc --python_out=. --grpc_python_out=. -I. tensorflow_serving/apis/prediction_service.proto
1719

1820
import sys
1921
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
@@ -27,17 +29,19 @@
2729
_sym_db = _symbol_database.Default()
2830

2931

32+
from tensorflow_serving.apis import classification_pb2 as tensorflow__serving_dot_apis_dot_classification__pb2
3033
from tensorflow_serving.apis import get_model_metadata_pb2 as tensorflow__serving_dot_apis_dot_get__model__metadata__pb2
3134
from tensorflow_serving.apis import predict_pb2 as tensorflow__serving_dot_apis_dot_predict__pb2
35+
from tensorflow_serving.apis import regression_pb2 as tensorflow__serving_dot_apis_dot_regression__pb2
3236

3337

3438
DESCRIPTOR = _descriptor.FileDescriptor(
3539
name='tensorflow_serving/apis/prediction_service.proto',
3640
package='tensorflow.serving',
3741
syntax='proto3',
38-
serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a\x30tensorflow_serving/apis/get_model_metadata.proto\x1a%tensorflow_serving/apis/predict.proto2\xd6\x01\n\x11PredictionService\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3')
42+
serialized_pb=_b('\n0tensorflow_serving/apis/prediction_service.proto\x12\x12tensorflow.serving\x1a,tensorflow_serving/apis/classification.proto\x1a\x30tensorflow_serving/apis/get_model_metadata.proto\x1a%tensorflow_serving/apis/predict.proto\x1a(tensorflow_serving/apis/regression.proto2\x93\x03\n\x11PredictionService\x12\x61\n\x08\x43lassify\x12).tensorflow.serving.ClassificationRequest\x1a*.tensorflow.serving.ClassificationResponse\x12X\n\x07Regress\x12%.tensorflow.serving.RegressionRequest\x1a&.tensorflow.serving.RegressionResponse\x12R\n\x07Predict\x12\".tensorflow.serving.PredictRequest\x1a#.tensorflow.serving.PredictResponse\x12m\n\x10GetModelMetadata\x12+.tensorflow.serving.GetModelMetadataRequest\x1a,.tensorflow.serving.GetModelMetadataResponseB\x03\xf8\x01\x01\x62\x06proto3')
3943
,
40-
dependencies=[tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,])
44+
dependencies=[tensorflow__serving_dot_apis_dot_classification__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_predict__pb2.DESCRIPTOR,tensorflow__serving_dot_apis_dot_regression__pb2.DESCRIPTOR,])
4145
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
4246

4347

@@ -66,6 +70,16 @@ def __init__(self, channel):
6670
Args:
6771
channel: A grpc.Channel.
6872
"""
73+
self.Classify = channel.unary_unary(
74+
'/tensorflow.serving.PredictionService/Classify',
75+
request_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString,
76+
response_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString,
77+
)
78+
self.Regress = channel.unary_unary(
79+
'/tensorflow.serving.PredictionService/Regress',
80+
request_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString,
81+
response_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString,
82+
)
6983
self.Predict = channel.unary_unary(
7084
'/tensorflow.serving.PredictionService/Predict',
7185
request_serializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
@@ -84,6 +98,20 @@ class PredictionServiceServicer(object):
8498
model_servers.
8599
"""
86100

101+
def Classify(self, request, context):
102+
"""Classify.
103+
"""
104+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
105+
context.set_details('Method not implemented!')
106+
raise NotImplementedError('Method not implemented!')
107+
108+
def Regress(self, request, context):
109+
"""Regress.
110+
"""
111+
context.set_code(grpc.StatusCode.UNIMPLEMENTED)
112+
context.set_details('Method not implemented!')
113+
raise NotImplementedError('Method not implemented!')
114+
87115
def Predict(self, request, context):
88116
"""Predict -- provides access to loaded TensorFlow model.
89117
"""
@@ -101,6 +129,16 @@ def GetModelMetadata(self, request, context):
101129

102130
def add_PredictionServiceServicer_to_server(servicer, server):
103131
rpc_method_handlers = {
132+
'Classify': grpc.unary_unary_rpc_method_handler(
133+
servicer.Classify,
134+
request_deserializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString,
135+
response_serializer=tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString,
136+
),
137+
'Regress': grpc.unary_unary_rpc_method_handler(
138+
servicer.Regress,
139+
request_deserializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString,
140+
response_serializer=tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString,
141+
),
104142
'Predict': grpc.unary_unary_rpc_method_handler(
105143
servicer.Predict,
106144
request_deserializer=tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
@@ -127,6 +165,14 @@ class BetaPredictionServiceServicer(object):
127165
PredictionService provides access to machine-learned models loaded by
128166
model_servers.
129167
"""
168+
def Classify(self, request, context):
169+
"""Classify.
170+
"""
171+
context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
172+
def Regress(self, request, context):
173+
"""Regress.
174+
"""
175+
context.code(beta_interfaces.StatusCode.UNIMPLEMENTED)
130176
def Predict(self, request, context):
131177
"""Predict -- provides access to loaded TensorFlow model.
132178
"""
@@ -147,6 +193,16 @@ class BetaPredictionServiceStub(object):
147193
PredictionService provides access to machine-learned models loaded by
148194
model_servers.
149195
"""
196+
def Classify(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
197+
"""Classify.
198+
"""
199+
raise NotImplementedError()
200+
Classify.future = None
201+
def Regress(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
202+
"""Regress.
203+
"""
204+
raise NotImplementedError()
205+
Regress.future = None
150206
def Predict(self, request, timeout, metadata=None, with_call=False, protocol_options=None):
151207
"""Predict -- provides access to loaded TensorFlow model.
152208
"""
@@ -166,16 +222,22 @@ def beta_create_PredictionService_server(servicer, pool=None, pool_size=None, de
166222
file not marked beta) for all further purposes. This function was
167223
generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
168224
request_deserializers = {
225+
('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.FromString,
169226
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.FromString,
170227
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.FromString,
228+
('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.FromString,
171229
}
172230
response_serializers = {
231+
('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.SerializeToString,
173232
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.SerializeToString,
174233
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.SerializeToString,
234+
('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.SerializeToString,
175235
}
176236
method_implementations = {
237+
('tensorflow.serving.PredictionService', 'Classify'): face_utilities.unary_unary_inline(servicer.Classify),
177238
('tensorflow.serving.PredictionService', 'GetModelMetadata'): face_utilities.unary_unary_inline(servicer.GetModelMetadata),
178239
('tensorflow.serving.PredictionService', 'Predict'): face_utilities.unary_unary_inline(servicer.Predict),
240+
('tensorflow.serving.PredictionService', 'Regress'): face_utilities.unary_unary_inline(servicer.Regress),
179241
}
180242
server_options = beta_implementations.server_options(request_deserializers=request_deserializers, response_serializers=response_serializers, thread_pool=pool, thread_pool_size=pool_size, default_timeout=default_timeout, maximum_timeout=maximum_timeout)
181243
return beta_implementations.server(method_implementations, options=server_options)
@@ -188,16 +250,22 @@ def beta_create_PredictionService_stub(channel, host=None, metadata_transformer=
188250
file not marked beta) for all further purposes. This function was
189251
generated only to ease transition from grpcio<0.15.0 to grpcio>=0.15.0"""
190252
request_serializers = {
253+
('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationRequest.SerializeToString,
191254
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataRequest.SerializeToString,
192255
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictRequest.SerializeToString,
256+
('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionRequest.SerializeToString,
193257
}
194258
response_deserializers = {
259+
('tensorflow.serving.PredictionService', 'Classify'): tensorflow__serving_dot_apis_dot_classification__pb2.ClassificationResponse.FromString,
195260
('tensorflow.serving.PredictionService', 'GetModelMetadata'): tensorflow__serving_dot_apis_dot_get__model__metadata__pb2.GetModelMetadataResponse.FromString,
196261
('tensorflow.serving.PredictionService', 'Predict'): tensorflow__serving_dot_apis_dot_predict__pb2.PredictResponse.FromString,
262+
('tensorflow.serving.PredictionService', 'Regress'): tensorflow__serving_dot_apis_dot_regression__pb2.RegressionResponse.FromString,
197263
}
198264
cardinalities = {
265+
'Classify': cardinality.Cardinality.UNARY_UNARY,
199266
'GetModelMetadata': cardinality.Cardinality.UNARY_UNARY,
200267
'Predict': cardinality.Cardinality.UNARY_UNARY,
268+
'Regress': cardinality.Cardinality.UNARY_UNARY,
201269
}
202270
stub_options = beta_implementations.stub_options(host=host, metadata_transformer=metadata_transformer, request_serializers=request_serializers, response_deserializers=response_deserializers, thread_pool=pool, thread_pool_size=pool_size)
203271
return beta_implementations.dynamic_stub(channel, 'tensorflow.serving.PredictionService', cardinalities, options=stub_options)

tensorflow_serving/batching/batching_session.cc

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,14 @@ class BatchingSession : public ServingSession {
126126
const std::vector<string>& target_node_names,
127127
std::vector<Tensor>* outputs) override;
128128

129-
// TODO(b/34971139): at the moment this method ignores run_options and
130-
// run_metadata and behaves exactly like Run.
129+
// RunOptions handling:
130+
// Since multiple of these Run() calls get batched into a single call to the
131+
// underlying Session's Run(), we select an arbitrary 'run_options' (typically
132+
// they are the same across calls). The exception is the timeout; we take the
133+
// largest value (after subtracting time spent in the batching queue).
134+
//
135+
// RunMetadata:
136+
// We copy the batched call's RunMetadata to each non-batched call's output.
131137
Status Run(const RunOptions& run_options,
132138
const std::vector<std::pair<string, Tensor>>& inputs,
133139
const std::vector<string>& output_tensor_names,
@@ -210,21 +216,21 @@ Status BatchingSession::Create(
210216
}
211217

212218
Status BatchingSession::Run(
213-
const RunOptions& run_options,
214219
const std::vector<std::pair<string, Tensor>>& inputs,
215220
const std::vector<string>& output_tensor_names,
216-
const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
217-
RunMetadata* run_metadata) {
218-
LOG(WARNING) << "Currently both run_options and run_metadata are ignored, "
219-
<< "see b/34971139";
220-
return Run(inputs, output_tensor_names, target_node_names, outputs);
221+
const std::vector<string>& target_node_names,
222+
std::vector<Tensor>* outputs) {
223+
RunMetadata run_metadata;
224+
return Run(RunOptions(), inputs, output_tensor_names, target_node_names,
225+
outputs, &run_metadata);
221226
}
222227

223228
Status BatchingSession::Run(
229+
const RunOptions& run_options,
224230
const std::vector<std::pair<string, Tensor>>& inputs,
225231
const std::vector<string>& output_tensor_names,
226-
const std::vector<string>& target_node_names,
227-
std::vector<Tensor>* outputs) {
232+
const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
233+
RunMetadata* run_metadata) {
228234
if (!target_node_names.empty()) {
229235
return errors::PermissionDenied(
230236
"BatchingSession does not support target nodes");
@@ -239,8 +245,8 @@ Status BatchingSession::Run(
239245
LOG(WARNING) << "Request doesn't match any declared signature. Bypassing "
240246
"batcher. Request signature is: "
241247
<< TensorSignatureDebugString(signature);
242-
return wrapped_->Run(inputs, output_tensor_names, target_node_names,
243-
outputs);
248+
return wrapped_->Run(run_options, inputs, output_tensor_names,
249+
target_node_names, outputs, run_metadata);
244250
}
245251
BatchScheduler<BatchingSessionTask>* batch_scheduler =
246252
batch_scheduler_it->second.get();
@@ -250,12 +256,15 @@ Status BatchingSession::Run(
250256
Notification done;
251257
Status status;
252258
auto task = std::unique_ptr<BatchingSessionTask>(new BatchingSessionTask);
259+
task->enqueue_time_micros = Env::Default()->NowMicros();
260+
task->run_options = run_options;
253261
TF_RETURN_IF_ERROR(ComputeInputSize(inputs, &task->zeroth_dim_size));
254262
task->inputs = &inputs;
255263
task->output_tensor_names = &output_tensor_names;
256264
task->done = &done;
257265
task->status = &status;
258266
task->outputs = outputs;
267+
task->run_metadata = run_metadata;
259268

260269
TF_RETURN_IF_ERROR(batch_scheduler->Schedule(&task));
261270
done.WaitForNotification();
@@ -457,18 +466,55 @@ void BatchingSession::ProcessBatch(
457466
return;
458467
}
459468

460-
Status status;
469+
const uint64 dequeue_time_micros = Env::Default()->NowMicros();
461470

462471
// Regardless of the outcome, we need to propagate the status to the
463472
// individual tasks and signal that they are done. We use MakeCleanup() to
464473
// ensure that this happens no matter how we exit the method below.
474+
Status status;
465475
auto finally = MakeCleanup([&status, &batch] {
466476
for (int i = 0; i < batch->num_tasks(); ++i) {
467477
*batch->mutable_task(i)->status = status;
468478
batch->mutable_task(i)->done->Notify();
469479
}
470480
});
471481

482+
// Make sure we have at least one task that hasn't exceeded its timeout from
483+
// queue time alone, and find the latest task deadline which we'll use for the
484+
// overall batch.
485+
bool all_tasks_timeout_exceeded = true;
486+
uint64 batch_deadline_micros = 0;
487+
for (int i = 0; i < batch->num_tasks(); ++i) {
488+
const BatchingSessionTask& task = batch->task(i);
489+
// If the caller doesn't populate RunOptions, the timeout is 0 by default.
490+
// Interpret that as "no timeout" i.e. infinity.
491+
const int64 task_timeout_micros =
492+
task.run_options.timeout_in_ms() <= 0
493+
? INT_MAX
494+
: task.run_options.timeout_in_ms() * 1000;
495+
const uint64 task_deadline_micros =
496+
task.enqueue_time_micros + task_timeout_micros;
497+
if (task_deadline_micros > dequeue_time_micros) {
498+
all_tasks_timeout_exceeded = false;
499+
if (task_deadline_micros > batch_deadline_micros) {
500+
batch_deadline_micros = task_deadline_micros;
501+
}
502+
}
503+
}
504+
if (all_tasks_timeout_exceeded) {
505+
status = Status(error::RESOURCE_EXHAUSTED,
506+
"Run() timeout exceeded while waiting in batching queue");
507+
return;
508+
}
509+
510+
RunOptions run_options = batch->task(0).run_options;
511+
if (batch_deadline_micros == INT_MAX) {
512+
run_options.set_timeout_in_ms(0);
513+
} else {
514+
run_options.set_timeout_in_ms(
515+
(batch_deadline_micros - dequeue_time_micros) / 1000);
516+
}
517+
472518
std::vector<std::pair<string, Tensor>> merged_inputs;
473519
status = MergeInputTensors(signature, *batch, &merged_inputs);
474520
if (!status.ok()) {
@@ -478,8 +524,13 @@ void BatchingSession::ProcessBatch(
478524
const std::vector<string> output_tensor_names(
479525
signature.output_tensors.begin(), signature.output_tensors.end());
480526
std::vector<Tensor> combined_outputs;
481-
status = wrapped_->Run(merged_inputs, output_tensor_names,
482-
{} /* target node names */, &combined_outputs);
527+
RunMetadata run_metadata;
528+
status = wrapped_->Run(run_options, merged_inputs, output_tensor_names,
529+
{} /* target node names */, &combined_outputs,
530+
&run_metadata);
531+
for (int i = 0; i < batch->num_tasks(); ++i) {
532+
*(batch->mutable_task(i)->run_metadata) = run_metadata;
533+
}
483534
if (!status.ok()) {
484535
return;
485536
}

tensorflow_serving/batching/batching_session.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,8 @@ struct BatchingSessionTask : public BatchTask {
167167
size_t size() const override { return zeroth_dim_size; }
168168

169169
// Fields populated when a task is received.
170+
uint64 enqueue_time_micros;
171+
RunOptions run_options;
170172
size_t zeroth_dim_size;
171173
const std::vector<std::pair<string, Tensor>>* inputs;
172174
const std::vector<string>* output_tensor_names;
@@ -175,6 +177,7 @@ struct BatchingSessionTask : public BatchTask {
175177
Notification* done;
176178
Status* status;
177179
std::vector<Tensor>* outputs;
180+
RunMetadata* run_metadata;
178181
};
179182

180183
} // namespace serving

0 commit comments

Comments
 (0)