Skip to content

Commit f196fa8

Browse files
cky9301tensorflow-copybara
authored andcommitted
Move Servable interface to tensorflow serving.
PiperOrigin-RevId: 549286752
1 parent 72b83ed commit f196fa8

File tree

5 files changed

+298
-0
lines changed

5 files changed

+298
-0
lines changed

tensorflow_serving/servables/tensorflow/BUILD

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1072,6 +1072,38 @@ cc_library(
10721072
],
10731073
)
10741074

1075+
cc_library(
1076+
name = "servable",
1077+
srcs = ["servable.cc"],
1078+
hdrs = ["servable.h"],
1079+
visibility = [
1080+
"//visibility:public",
1081+
],
1082+
deps = [
1083+
"//tensorflow_serving/apis:classification_cc_proto",
1084+
"//tensorflow_serving/apis:get_model_metadata_cc_proto",
1085+
"//tensorflow_serving/apis:inference_cc_proto",
1086+
"//tensorflow_serving/apis:predict_cc_proto",
1087+
"//tensorflow_serving/apis:regression_cc_proto",
1088+
"@com_google_absl//absl/functional:any_invocable",
1089+
"@com_google_absl//absl/status",
1090+
"@com_google_absl//absl/status:statusor",
1091+
"@com_google_absl//absl/strings",
1092+
],
1093+
)
1094+
1095+
cc_library(
1096+
name = "mock_servable",
1097+
testonly = True,
1098+
hdrs = ["mock_servable.h"],
1099+
deps = [
1100+
":servable",
1101+
"//tensorflow_serving/test_util",
1102+
"@com_google_absl//absl/functional:any_invocable",
1103+
"@com_google_absl//absl/status",
1104+
],
1105+
)
1106+
10751107
cc_test(
10761108
name = "thread_pool_factory_test",
10771109
size = "small",
@@ -1084,3 +1116,15 @@ cc_test(
10841116
"@org_tensorflow//tensorflow/core/platform:threadpool_options",
10851117
],
10861118
)
1119+
1120+
cc_test(
1121+
name = "servable_test",
1122+
srcs = ["servable_test.cc"],
1123+
deps = [
1124+
":servable",
1125+
"//tensorflow_serving/apis:predict_cc_proto",
1126+
"//tensorflow_serving/core/test_util:test_main",
1127+
"//tensorflow_serving/test_util",
1128+
"@com_google_absl//absl/status",
1129+
],
1130+
)
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
/* Copyright 2023 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MOCK_SERVABLE_H_
17+
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MOCK_SERVABLE_H_
18+
19+
#include <gmock/gmock.h>
20+
#include "absl/functional/any_invocable.h"
21+
#include "absl/status/status.h"
22+
#include "tensorflow_serving/servables/tensorflow/servable.h"
23+
24+
namespace tensorflow {
25+
namespace serving {
26+
27+
// A mock of tensorflow::serving::Servable.
28+
class MockServable : public Servable {
29+
public:
30+
MockServable() : Servable("", 0) {}
31+
~MockServable() override = default;
32+
33+
MOCK_METHOD(absl::Status, Classify,
34+
(const tensorflow::serving::ClassificationRequest& request,
35+
tensorflow::serving::ClassificationResponse* response));
36+
MOCK_METHOD(absl::Status, Regress,
37+
(const tensorflow::serving::RegressionRequest& request,
38+
tensorflow::serving::RegressionResponse* response));
39+
MOCK_METHOD(absl::Status, Predict,
40+
(const tensorflow::serving::PredictRequest& request,
41+
tensorflow::serving::PredictResponse* response));
42+
MOCK_METHOD(absl::Status, PredictStreamed,
43+
(const tensorflow::serving::PredictRequest& request,
44+
absl::AnyInvocable<void(tensorflow::serving::PredictResponse)>
45+
response_callback));
46+
MOCK_METHOD(absl::Status, MultiInference,
47+
(const tensorflow::serving::MultiInferenceRequest& request,
48+
tensorflow::serving::MultiInferenceResponse* response));
49+
MOCK_METHOD(absl::Status, GetModelMetadata,
50+
(const tensorflow::serving::GetModelMetadataRequest& request,
51+
tensorflow::serving::GetModelMetadataResponse* response));
52+
};
53+
54+
} // namespace serving
55+
} // namespace tensorflow
56+
57+
#endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_MOCK_SERVABLE_H_
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/* Copyright 2023 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow_serving/servables/tensorflow/servable.h"
17+
18+
namespace tensorflow {
19+
namespace serving {
20+
21+
EmptyServable::EmptyServable()
22+
: Servable(/*name=*/"", /*version=*/0),
23+
error_(absl::FailedPreconditionError("No models loaded")) {}
24+
25+
} // namespace serving
26+
} // namespace tensorflow
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
/* Copyright 2023 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVABLE_H_
17+
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVABLE_H_
18+
19+
#include <stdint.h>
20+
21+
#include <string>
22+
23+
#include "absl/functional/any_invocable.h"
24+
#include "absl/status/status.h"
25+
#include "absl/status/statusor.h"
26+
#include "absl/strings/string_view.h"
27+
#include "tensorflow_serving/apis/classification.pb.h"
28+
#include "tensorflow_serving/apis/get_model_metadata.pb.h"
29+
#include "tensorflow_serving/apis/inference.pb.h"
30+
#include "tensorflow_serving/apis/predict.pb.h"
31+
#include "tensorflow_serving/apis/regression.pb.h"
32+
33+
namespace tensorflow {
34+
namespace serving {
35+
36+
// Provides a `PredictionService`-like interface. All concrete implementations
37+
// are expected to be thread-safe.
38+
class Servable {
39+
public:
40+
Servable(absl::string_view name, int64_t version)
41+
: name_(std::string(name)), version_(version) {}
42+
43+
virtual ~Servable() = default;
44+
45+
// Returns the name associated with this servable.
46+
absl::string_view name() const { return name_; }
47+
48+
// Returns the version associated with this servable.
49+
int64_t version() const { return version_; }
50+
51+
virtual absl::Status Classify(const ClassificationRequest& request,
52+
ClassificationResponse* response) = 0;
53+
54+
virtual absl::Status Regress(const RegressionRequest& request,
55+
RegressionResponse* response) = 0;
56+
57+
virtual absl::Status Predict(const PredictRequest& request,
58+
PredictResponse* response) = 0;
59+
60+
// Streamed version of `Predict`. Experimental API that is not yet part of the
61+
// PredictionService API.
62+
//
63+
// `response_callback` is called for each streamed output, zero or more times,
64+
// when the streamed output becomes available. The callback invocation is
65+
// serialized by the runtime, which means that `response_callback` does not
66+
// have to be thread-safe, but blocking inside the callback causes the next
67+
// callback invocation to be delayed. The implementation guarantees that the
68+
// callback is never called after the `PredictStreamed` method returns.
69+
virtual absl::Status PredictStreamed(
70+
const PredictRequest& request,
71+
absl::AnyInvocable<void(PredictResponse)> response_callback) = 0;
72+
73+
virtual absl::Status MultiInference(const MultiInferenceRequest& request,
74+
MultiInferenceResponse* response) = 0;
75+
76+
virtual absl::Status GetModelMetadata(const GetModelMetadataRequest& request,
77+
GetModelMetadataResponse* response) = 0;
78+
79+
private:
80+
// Metadata of this servable. Currently matches the fields in
81+
// `ServableId`.
82+
const std::string name_;
83+
const int64_t version_;
84+
};
85+
86+
// An "empty" servable where there's no model associated with the servable. All
87+
// methods will return an error.
88+
//
89+
// Empty servables can be used in places where a servable is expected but we
90+
// don't need to load any models. For example, Model Server currently expects
91+
// each task to have at least one servable loaded, but Pathways Serving requires
92+
// only the controller task to initiate loading servables. So we use empty
93+
// servables in non-zero tasks to make sure non-zero tasks don't load anything.
94+
class EmptyServable : public Servable {
95+
public:
96+
EmptyServable();
97+
98+
absl::Status Classify(const ClassificationRequest& request,
99+
ClassificationResponse* response) override {
100+
return error_;
101+
}
102+
103+
absl::Status Regress(const RegressionRequest& request,
104+
RegressionResponse* response) override {
105+
return error_;
106+
}
107+
108+
absl::Status Predict(const PredictRequest& request,
109+
PredictResponse* response) override {
110+
return error_;
111+
}
112+
113+
absl::Status PredictStreamed(
114+
const PredictRequest& request,
115+
absl::AnyInvocable<void(PredictResponse)> response_callback) override {
116+
return error_;
117+
}
118+
119+
absl::Status MultiInference(const MultiInferenceRequest& request,
120+
MultiInferenceResponse* response) override {
121+
return error_;
122+
}
123+
124+
absl::Status GetModelMetadata(const GetModelMetadataRequest& request,
125+
GetModelMetadataResponse* response) override {
126+
return error_;
127+
}
128+
129+
private:
130+
absl::Status error_;
131+
};
132+
133+
} // namespace serving
134+
} // namespace tensorflow
135+
136+
#endif // TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SERVABLE_H_
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
/* Copyright 2023 Google Inc. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "tensorflow_serving/servables/tensorflow/servable.h"
17+
18+
#include <gmock/gmock.h>
19+
#include <gtest/gtest.h>
20+
#include "absl/status/status.h"
21+
#include "tensorflow_serving/apis/predict.pb.h"
22+
23+
namespace tensorflow {
24+
namespace serving {
25+
namespace {
26+
27+
TEST(EmptyServableTest, Predict) {
28+
PredictResponse response;
29+
EXPECT_EQ(EmptyServable().Predict(PredictRequest(), &response).code(),
30+
absl::StatusCode::kFailedPrecondition);
31+
}
32+
33+
} // namespace
34+
} // namespace serving
35+
} // namespace tensorflow

0 commit comments

Comments
 (0)