Skip to content

Commit c4f26c1

Browse files
Automated Code Change
PiperOrigin-RevId: 765482765
1 parent fabf1fd commit c4f26c1

File tree

6 files changed

+91
-85
lines changed

6 files changed

+91
-85
lines changed

tensorflow_serving/servables/tensorflow/bundle_factory_util_test.cc

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -51,17 +51,18 @@ using Batcher = SharedBatchScheduler<BatchingSessionTask>;
5151

5252
class MockSession : public Session {
5353
public:
54-
MOCK_METHOD(tensorflow::Status, Create, (const GraphDef& graph), (override));
55-
MOCK_METHOD(tensorflow::Status, Extend, (const GraphDef& graph), (override));
56-
MOCK_METHOD(tensorflow::Status, ListDevices,
54+
MOCK_METHOD(absl::Status, Create, (const GraphDef& graph), (override));
55+
MOCK_METHOD(absl::Status, Extend, (const GraphDef& graph), (override));
56+
MOCK_METHOD(absl::Status, ListDevices,
5757
(std::vector<DeviceAttributes> * response), (override));
58-
MOCK_METHOD(tensorflow::Status, Close, (), (override));
59-
60-
Status Run(const RunOptions& run_options,
61-
const std::vector<std::pair<string, Tensor>>& inputs,
62-
const std::vector<string>& output_tensor_names,
63-
const std::vector<string>& target_node_names,
64-
std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
58+
MOCK_METHOD(absl::Status, Close, (), (override));
59+
60+
absl::Status Run(const RunOptions& run_options,
61+
const std::vector<std::pair<string, Tensor>>& inputs,
62+
const std::vector<string>& output_tensor_names,
63+
const std::vector<string>& target_node_names,
64+
std::vector<Tensor>* outputs,
65+
RunMetadata* run_metadata) override {
6566
// half plus two: output should be input / 2 + 2.
6667
const auto& input = inputs[0].second.flat<float>();
6768
Tensor output(DT_FLOAT, inputs[0].second.shape());
@@ -72,9 +73,10 @@ class MockSession : public Session {
7273
}
7374

7475
// Unused, but we need to provide a definition (virtual = 0).
75-
Status Run(const std::vector<std::pair<std::string, Tensor>>&,
76-
const std::vector<std::string>&, const std::vector<std::string>&,
77-
std::vector<Tensor>* outputs) override {
76+
absl::Status Run(const std::vector<std::pair<std::string, Tensor>>&,
77+
const std::vector<std::string>&,
78+
const std::vector<std::string>&,
79+
std::vector<Tensor>* outputs) override {
7880
return errors::Unimplemented(
7981
"Run with threadpool is not supported for this session.");
8082
}
@@ -171,7 +173,7 @@ TEST_F(BundleFactoryUtilTest, WrapSessionForBatchingConfigError) {
171173
auto status = WrapSessionForBatching(batching_params, batch_scheduler,
172174
{test_util::GetTestSessionSignature()},
173175
&bundle.session);
174-
ASSERT_TRUE(errors::IsInvalidArgument(status));
176+
ASSERT_TRUE(absl::IsInvalidArgument(status));
175177
}
176178

177179
TEST_F(BundleFactoryUtilTest, GetPerModelBatchingParams) {
@@ -219,7 +221,7 @@ TEST_F(BundleFactoryUtilTest, GetPerModelBatchingParams) {
219221

220222
TEST_F(BundleFactoryUtilTest, EstimateResourceFromPathWithBadExport) {
221223
ResourceAllocation resource_requirement;
222-
const Status status = EstimateResourceFromPath(
224+
const absl::Status status = EstimateResourceFromPath(
223225
"/a/bogus/export/dir",
224226
/*use_validation_result=*/false, &resource_requirement);
225227
EXPECT_FALSE(status.ok());

tensorflow_serving/servables/tensorflow/classifier.cc

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ class SavedModelTensorFlowClassifier : public ClassifierInterface {
5959

6060
~SavedModelTensorFlowClassifier() override = default;
6161

62-
Status Classify(const ClassificationRequest& request,
63-
ClassificationResult* result) override {
62+
absl::Status Classify(const ClassificationRequest& request,
63+
ClassificationResult* result) override {
6464
TRACELITERAL("TensorFlowClassifier::Classify");
6565

6666
string input_tensor_name;
@@ -100,8 +100,8 @@ class SavedModelClassifier : public ClassifierInterface {
100100

101101
~SavedModelClassifier() override = default;
102102

103-
Status Classify(const ClassificationRequest& request,
104-
ClassificationResult* result) override {
103+
absl::Status Classify(const ClassificationRequest& request,
104+
ClassificationResult* result) override {
105105
// Get the default signature of the graph. Expected to be a
106106
// classification signature.
107107
// TODO(b/26220896): Move TensorFlowClassifier creation to construction
@@ -123,22 +123,22 @@ class SavedModelClassifier : public ClassifierInterface {
123123

124124
} // namespace
125125

126-
Status CreateClassifierFromSavedModelBundle(
126+
absl::Status CreateClassifierFromSavedModelBundle(
127127
const RunOptions& run_options, std::unique_ptr<SavedModelBundle> bundle,
128128
std::unique_ptr<ClassifierInterface>* service) {
129129
service->reset(new SavedModelClassifier(run_options, std::move(bundle)));
130130
return absl::OkStatus();
131131
}
132132

133-
Status CreateFlyweightTensorFlowClassifier(
133+
absl::Status CreateFlyweightTensorFlowClassifier(
134134
const RunOptions& run_options, Session* session,
135135
const SignatureDef* signature,
136136
std::unique_ptr<ClassifierInterface>* service) {
137137
return CreateFlyweightTensorFlowClassifier(
138138
run_options, session, signature, thread::ThreadPoolOptions(), service);
139139
}
140140

141-
Status CreateFlyweightTensorFlowClassifier(
141+
absl::Status CreateFlyweightTensorFlowClassifier(
142142
const RunOptions& run_options, Session* session,
143143
const SignatureDef* signature,
144144
const thread::ThreadPoolOptions& thread_pool_options,
@@ -148,9 +148,9 @@ Status CreateFlyweightTensorFlowClassifier(
148148
return absl::OkStatus();
149149
}
150150

151-
Status GetClassificationSignatureDef(const ModelSpec& model_spec,
152-
const MetaGraphDef& meta_graph_def,
153-
SignatureDef* signature) {
151+
absl::Status GetClassificationSignatureDef(const ModelSpec& model_spec,
152+
const MetaGraphDef& meta_graph_def,
153+
SignatureDef* signature) {
154154
const string signature_name = model_spec.signature_name().empty()
155155
? kDefaultServingSignatureDefKey
156156
: model_spec.signature_name();
@@ -173,9 +173,9 @@ Status GetClassificationSignatureDef(const ModelSpec& model_spec,
173173
return absl::OkStatus();
174174
}
175175

176-
Status PreProcessClassification(const SignatureDef& signature,
177-
string* input_tensor_name,
178-
std::vector<string>* output_tensor_names) {
176+
absl::Status PreProcessClassification(
177+
const SignatureDef& signature, string* input_tensor_name,
178+
std::vector<string>* output_tensor_names) {
179179
if (GetSignatureMethodNameCheckFeature() &&
180180
signature.method_name() != kClassifyMethodName) {
181181
return errors::InvalidArgument(strings::StrCat(
@@ -222,7 +222,7 @@ Status PreProcessClassification(const SignatureDef& signature,
222222
return absl::OkStatus();
223223
}
224224

225-
Status PostProcessClassificationResult(
225+
absl::Status PostProcessClassificationResult(
226226
const SignatureDef& signature, int num_examples,
227227
const std::vector<string>& output_tensor_names,
228228
const std::vector<Tensor>& output_tensors, ClassificationResult* result) {
@@ -323,12 +323,12 @@ Status PostProcessClassificationResult(
323323
return absl::OkStatus();
324324
}
325325

326-
Status RunClassify(const RunOptions& run_options,
327-
const MetaGraphDef& meta_graph_def,
328-
const absl::optional<int64_t>& servable_version,
329-
Session* session, const ClassificationRequest& request,
330-
ClassificationResponse* response,
331-
const thread::ThreadPoolOptions& thread_pool_options) {
326+
absl::Status RunClassify(const RunOptions& run_options,
327+
const MetaGraphDef& meta_graph_def,
328+
const absl::optional<int64_t>& servable_version,
329+
Session* session, const ClassificationRequest& request,
330+
ClassificationResponse* response,
331+
const thread::ThreadPoolOptions& thread_pool_options) {
332332
SignatureDef signature;
333333
TF_RETURN_IF_ERROR(GetClassificationSignatureDef(request.model_spec(),
334334
meta_graph_def, &signature));

tensorflow_serving/servables/tensorflow/classifier_test.cc

Lines changed: 39 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -69,25 +69,25 @@ class FakeSession : public tensorflow::Session {
6969
explicit FakeSession(absl::optional<int64_t> expected_timeout)
7070
: expected_timeout_(expected_timeout) {}
7171
~FakeSession() override = default;
72-
Status Create(const GraphDef& graph) override {
72+
absl::Status Create(const GraphDef& graph) override {
7373
return errors::Unimplemented("not available in fake");
7474
}
75-
Status Extend(const GraphDef& graph) override {
75+
absl::Status Extend(const GraphDef& graph) override {
7676
return errors::Unimplemented("not available in fake");
7777
}
7878

79-
Status Close() override {
79+
absl::Status Close() override {
8080
return errors::Unimplemented("not available in fake");
8181
}
8282

83-
Status ListDevices(std::vector<DeviceAttributes>* response) override {
83+
absl::Status ListDevices(std::vector<DeviceAttributes>* response) override {
8484
return errors::Unimplemented("not available in fake");
8585
}
8686

87-
Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
88-
const std::vector<string>& output_names,
89-
const std::vector<string>& target_nodes,
90-
std::vector<Tensor>* outputs) override {
87+
absl::Status Run(const std::vector<std::pair<string, Tensor>>& inputs,
88+
const std::vector<string>& output_names,
89+
const std::vector<string>& target_nodes,
90+
std::vector<Tensor>* outputs) override {
9191
if (expected_timeout_) {
9292
LOG(FATAL) << "Run() without RunOptions not expected to be called";
9393
}
@@ -96,21 +96,23 @@ class FakeSession : public tensorflow::Session {
9696
&run_metadata);
9797
}
9898

99-
Status Run(const RunOptions& run_options,
100-
const std::vector<std::pair<string, Tensor>>& inputs,
101-
const std::vector<string>& output_names,
102-
const std::vector<string>& target_nodes,
103-
std::vector<Tensor>* outputs, RunMetadata* run_metadata) override {
99+
absl::Status Run(const RunOptions& run_options,
100+
const std::vector<std::pair<string, Tensor>>& inputs,
101+
const std::vector<string>& output_names,
102+
const std::vector<string>& target_nodes,
103+
std::vector<Tensor>* outputs,
104+
RunMetadata* run_metadata) override {
104105
return Run(run_options, inputs, output_names, target_nodes, outputs,
105106
run_metadata, thread::ThreadPoolOptions());
106107
}
107108

108-
Status Run(const RunOptions& run_options,
109-
const std::vector<std::pair<string, Tensor>>& inputs,
110-
const std::vector<string>& output_names,
111-
const std::vector<string>& target_nodes,
112-
std::vector<Tensor>* outputs, RunMetadata* run_metadata,
113-
const thread::ThreadPoolOptions& thread_pool_options) override {
109+
absl::Status Run(
110+
const RunOptions& run_options,
111+
const std::vector<std::pair<string, Tensor>>& inputs,
112+
const std::vector<string>& output_names,
113+
const std::vector<string>& target_nodes, std::vector<Tensor>* outputs,
114+
RunMetadata* run_metadata,
115+
const thread::ThreadPoolOptions& thread_pool_options) override {
114116
if (expected_timeout_) {
115117
CHECK_EQ(*expected_timeout_, run_options.timeout_in_ms());
116118
}
@@ -143,8 +145,8 @@ class FakeSession : public tensorflow::Session {
143145
}
144146

145147
// Parses TensorFlow Examples from a string Tensor.
146-
static Status GetExamples(const Tensor& input,
147-
std::vector<Example>* examples) {
148+
static absl::Status GetExamples(const Tensor& input,
149+
std::vector<Example>* examples) {
148150
examples->clear();
149151
const int batch_size = input.dim_size(0);
150152
const auto& flat_input = input.flat<tstring>();
@@ -183,9 +185,9 @@ class FakeSession : public tensorflow::Session {
183185
// Creates a Tensor by copying the "class" feature from each Example.
184186
// Requires each Example have an bytes feature called "class" which is of the
185187
// same non-zero length.
186-
static Status GetClassTensor(const std::vector<Example>& examples,
187-
const std::vector<string>& output_names,
188-
Tensor* classes, Tensor* scores) {
188+
static absl::Status GetClassTensor(const std::vector<Example>& examples,
189+
const std::vector<string>& output_names,
190+
Tensor* classes, Tensor* scores) {
189191
if (examples.empty()) {
190192
return errors::Internal("empty example list");
191193
}
@@ -281,7 +283,7 @@ class ClassifierTest : public ::testing::TestWithParam<bool> {
281283
return example;
282284
}
283285

284-
Status Create() {
286+
absl::Status Create() {
285287
std::unique_ptr<SavedModelBundle> saved_model(new SavedModelBundle);
286288
saved_model->meta_graph_def = saved_model_bundle_->meta_graph_def;
287289
saved_model->session = std::move(saved_model_bundle_->session);
@@ -699,7 +701,7 @@ TEST_P(ClassifierTest, InvalidNamedSignature) {
699701
request_.mutable_input()->mutable_example_list()->mutable_examples();
700702
*examples->Add() = example({{"dos", 2}, {"uno", 1}});
701703
*examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
702-
Status status = classifier_->Classify(request_, &result_);
704+
absl::Status status = classifier_->Classify(request_, &result_);
703705

704706
ASSERT_FALSE(status.ok());
705707
EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
@@ -723,7 +725,7 @@ TEST_P(ClassifierTest, MalformedScores) {
723725
request_.mutable_input()->mutable_example_list()->mutable_examples();
724726
*examples->Add() = example({{"dos", 2}, {"uno", 1}});
725727
*examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
726-
Status status = classifier_->Classify(request_, &result_);
728+
absl::Status status = classifier_->Classify(request_, &result_);
727729

728730
ASSERT_FALSE(status.ok());
729731
EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
@@ -748,7 +750,7 @@ TEST_P(ClassifierTest, MissingClassificationSignature) {
748750
request_.mutable_input()->mutable_example_list()->mutable_examples();
749751
*examples->Add() = example({{"dos", 2}});
750752
// TODO(b/26220896): This error should move to construction time.
751-
Status status = classifier_->Classify(request_, &result_);
753+
absl::Status status = classifier_->Classify(request_, &result_);
752754
ASSERT_FALSE(status.ok());
753755
EXPECT_EQ(static_cast<absl::StatusCode>(absl::StatusCode::kInvalidArgument),
754756
status.code())
@@ -767,7 +769,7 @@ TEST_P(ClassifierTest, EmptyInput) {
767769
TF_ASSERT_OK(Create());
768770
// Touch input.
769771
request_.mutable_input();
770-
Status status = classifier_->Classify(request_, &result_);
772+
absl::Status status = classifier_->Classify(request_, &result_);
771773
ASSERT_FALSE(status.ok());
772774
EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
773775
EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
@@ -784,7 +786,7 @@ TEST_P(ClassifierTest, EmptyExampleList) {
784786
TF_ASSERT_OK(Create());
785787
// Touch ExampleList.
786788
request_.mutable_input()->mutable_example_list();
787-
Status status = classifier_->Classify(request_, &result_);
789+
absl::Status status = classifier_->Classify(request_, &result_);
788790
ASSERT_FALSE(status.ok());
789791
EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
790792
EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
@@ -803,7 +805,7 @@ TEST_P(ClassifierTest, EmptyExampleListWithContext) {
803805
*request_.mutable_input()
804806
->mutable_example_list_with_context()
805807
->mutable_context() = example({{"dos", 2}});
806-
Status status = classifier_->Classify(request_, &result_);
808+
absl::Status status = classifier_->Classify(request_, &result_);
807809
ASSERT_FALSE(status.ok());
808810
EXPECT_EQ(status.code(), error::Code::INVALID_ARGUMENT);
809811
EXPECT_THAT(status.message(), ::testing::HasSubstr("Input is empty"));
@@ -826,7 +828,7 @@ TEST_P(ClassifierTest, RunsFails) {
826828
auto* examples =
827829
request_.mutable_input()->mutable_example_list()->mutable_examples();
828830
*examples->Add() = example({{"dos", 2}});
829-
Status status = classifier_->Classify(request_, &result_);
831+
absl::Status status = classifier_->Classify(request_, &result_);
830832
ASSERT_FALSE(status.ok());
831833
EXPECT_THAT(status.ToString(), ::testing::HasSubstr("Run totally failed"));
832834

@@ -853,7 +855,7 @@ TEST_P(ClassifierTest, ClassesIncorrectTensorBatchSize) {
853855
*examples->Add() = example({{"dos", 2}, {"uno", 1}});
854856
*examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
855857

856-
Status status = classifier_->Classify(request_, &result_);
858+
absl::Status status = classifier_->Classify(request_, &result_);
857859
ASSERT_FALSE(status.ok());
858860
EXPECT_THAT(status.ToString(), ::testing::HasSubstr("batch size"));
859861

@@ -881,7 +883,7 @@ TEST_P(ClassifierTest, ClassesIncorrectTensorType) {
881883
*examples->Add() = example({{"dos", 2}, {"uno", 1}});
882884
*examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
883885

884-
Status status = classifier_->Classify(request_, &result_);
886+
absl::Status status = classifier_->Classify(request_, &result_);
885887
ASSERT_FALSE(status.ok());
886888
EXPECT_THAT(status.ToString(),
887889
::testing::HasSubstr("Expected classes Tensor of DT_STRING"));
@@ -909,7 +911,7 @@ TEST_P(ClassifierTest, ScoresIncorrectTensorBatchSize) {
909911
*examples->Add() = example({{"dos", 2}, {"uno", 1}});
910912
*examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
911913

912-
Status status = classifier_->Classify(request_, &result_);
914+
absl::Status status = classifier_->Classify(request_, &result_);
913915
ASSERT_FALSE(status.ok());
914916
EXPECT_THAT(status.ToString(), ::testing::HasSubstr("batch size"));
915917

@@ -936,7 +938,7 @@ TEST_P(ClassifierTest, ScoresIncorrectTensorType) {
936938
*examples->Add() = example({{"dos", 2}, {"uno", 1}});
937939
*examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
938940

939-
Status status = classifier_->Classify(request_, &result_);
941+
absl::Status status = classifier_->Classify(request_, &result_);
940942
ASSERT_FALSE(status.ok());
941943
EXPECT_THAT(status.ToString(),
942944
::testing::HasSubstr("Expected scores Tensor of DT_FLOAT"));
@@ -965,7 +967,7 @@ TEST_P(ClassifierTest, MismatchedNumberOfTensorClasses) {
965967
*examples->Add() = example({{"dos", 2}, {"uno", 1}});
966968
*examples->Add() = example({{"cuatro", 4}, {"tres", 3}});
967969

968-
Status status = classifier_->Classify(request_, &result_);
970+
absl::Status status = classifier_->Classify(request_, &result_);
969971
ASSERT_FALSE(status.ok());
970972
EXPECT_THAT(
971973
status.ToString(),

0 commit comments

Comments
 (0)