Skip to content

Commit cc713a1

Browse files
Internal change.
PiperOrigin-RevId: 686677820
1 parent 0961ff4 commit cc713a1

File tree

3 files changed

+81
-7
lines changed

3 files changed

+81
-7
lines changed

tensorflow_serving/servables/tensorflow/BUILD

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,15 +159,17 @@ cc_library(
159159
],
160160
deps = [
161161
":bundle_factory_util",
162+
":saved_model_config_cc_proto",
163+
":saved_model_config_util",
162164
":session_bundle_config_cc_proto",
163165
":tflite_session_lib",
164166
"//tensorflow_serving/batching:batching_session",
165167
"//tensorflow_serving/core:loader",
166168
"//tensorflow_serving/resources:resources_cc_proto",
167-
"//tensorflow_serving/session_bundle:session_bundle_util",
169+
"//tensorflow_serving/session_bundle:session_bundle_util", # buildcleaner: keep
170+
"@com_google_absl//absl/log",
168171
"@com_google_absl//absl/strings",
169172
"@com_google_absl//absl/types:optional",
170-
"@com_google_protobuf//:cc_wkt_protos",
171173
"@org_tensorflow//tensorflow/cc/saved_model:loader",
172174
"@org_tensorflow//tensorflow/cc/saved_model:tag_constants",
173175
"@org_tensorflow//tensorflow/core:core_cpu",
@@ -190,14 +192,20 @@ cc_test(
190192
":bundle_factory_test",
191193
":bundle_factory_test_util",
192194
":saved_model_bundle_factory",
195+
":saved_model_config_cc_proto",
193196
":session_bundle_config_cc_proto",
194197
"//tensorflow_serving/core/test_util:session_test_util",
195198
"//tensorflow_serving/core/test_util:test_main",
199+
"@com_google_absl//absl/status",
196200
"@com_google_protobuf//:cc_wkt_protos",
201+
"@org_tensorflow//tensorflow/cc/saved_model:constants",
197202
"@org_tensorflow//tensorflow/cc/saved_model:loader",
198203
"@org_tensorflow//tensorflow/cc/saved_model:tag_constants",
199204
"@org_tensorflow//tensorflow/core:core_cpu",
205+
"@org_tensorflow//tensorflow/core:framework",
200206
"@org_tensorflow//tensorflow/core:lib",
207+
"@org_tensorflow//tensorflow/core:portable_tensorflow_test_lib",
208+
"@org_tensorflow//tensorflow/core:protos_all_cc",
201209
"@org_tensorflow//tensorflow/core:test",
202210
],
203211
)

tensorflow_serving/servables/tensorflow/saved_model_bundle_factory.cc

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
#include <utility>
2121
#include <vector>
2222

23+
#include "absl/log/log.h"
2324
#include "absl/strings/string_view.h"
2425
#include "tensorflow/cc/saved_model/tag_constants.h"
2526
#include "tensorflow/core/framework/tensor.pb.h"
@@ -31,6 +32,8 @@ limitations under the License.
3132
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
3233
#include "tensorflow/core/public/session_options.h"
3334
#include "tensorflow_serving/servables/tensorflow/bundle_factory_util.h"
35+
#include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
36+
#include "tensorflow_serving/servables/tensorflow/saved_model_config_util.h"
3437
#include "tensorflow_serving/servables/tensorflow/tflite_session.h"
3538
#include "tensorflow_serving/session_bundle/session_bundle_util.h"
3639

@@ -125,15 +128,16 @@ Status SavedModelBundleFactory::InternalCreateSavedModelBundle(
125128
if (saved_model_tags.empty()) {
126129
saved_model_tags.insert(kSavedModelTagServe);
127130
}
131+
bool is_tflite = config_.prefer_tflite_model() && TfLiteModelFound(path);
128132
const auto& session_options = [&]() {
129133
auto result = GetSessionOptions(config_);
130134
string mixed_precision_value = config_.mixed_precision();
135+
tensorflow::ConfigProto& config = result.config;
136+
GraphOptions* gopt = config.mutable_graph_options();
137+
RewriterConfig* rwcfg = gopt->mutable_rewrite_options();
131138
if (!mixed_precision_value.empty()) {
132139
if (mixed_precision_value == "bfloat16") {
133140
LOG(INFO) << "Running inference with bfloat16 auto mixed precision";
134-
tensorflow::ConfigProto& config = result.config;
135-
GraphOptions* gopt = config.mutable_graph_options();
136-
RewriterConfig* rwcfg = gopt->mutable_rewrite_options();
137141
rwcfg->set_auto_mixed_precision_onednn_bfloat16(RewriterConfig::ON);
138142
} else {
139143
LOG(WARNING)
@@ -147,10 +151,20 @@ Status SavedModelBundleFactory::InternalCreateSavedModelBundle(
147151
session_metadata->set_name(metadata->servable_id.name);
148152
session_metadata->set_version(metadata->servable_id.version);
149153
}
154+
// Set other options from saved model config.
155+
if (!is_tflite) {
156+
absl::StatusOr<SavedModelConfig> model_config =
157+
LoadSavedModelConfigOrDefault(path);
158+
if (!model_config.ok()) {
159+
LOG(WARNING) << "Failed to load saved model config: "
160+
<< model_config.status();
161+
} else if (model_config->has_session_overrides()) {
162+
UpdateRewriterConfig(model_config->session_overrides(), rwcfg);
163+
}
164+
}
150165
return result;
151166
}();
152167

153-
bool is_tflite = config_.prefer_tflite_model() && TfLiteModelFound(path);
154168
if (is_tflite) {
155169
int num_tflite_pools = config_.num_tflite_pools();
156170
if (num_tflite_pools == 0 && config_.num_tflite_interpreters() > 0) {

tensorflow_serving/servables/tensorflow/saved_model_bundle_factory_test.cc

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,19 +25,28 @@ limitations under the License.
2525
#include "google/protobuf/wrappers.pb.h"
2626
#include <gmock/gmock.h>
2727
#include <gtest/gtest.h>
28+
#include "absl/status/status.h"
2829
#include "tensorflow/cc/saved_model/constants.h"
2930
#include "tensorflow/cc/saved_model/loader.h"
3031
#include "tensorflow/cc/saved_model/tag_constants.h"
32+
#include "xla/tsl/lib/core/status_test_util.h"
3133
#include "tensorflow/core/framework/tensor_testutil.h"
3234
#include "tensorflow/core/lib/core/status.h"
3335
#include "tensorflow/core/lib/core/status_test_util.h"
3436
#include "tensorflow/core/lib/io/path.h"
37+
#include "tensorflow/core/platform/env.h"
38+
#include "tensorflow/core/platform/path.h"
39+
#include "tensorflow/core/platform/test.h"
3540
#include "tensorflow/core/protobuf/named_tensor.pb.h"
41+
#include "tensorflow/core/protobuf/rewriter_config.pb.h"
3642
#include "tensorflow/core/public/session.h"
43+
#include "tensorflow/core/public/session_options.h"
3744
#include "tensorflow/core/public/version.h"
45+
#include "tsl/platform/errors.h"
3846
#include "tensorflow_serving/core/test_util/session_test_util.h"
3947
#include "tensorflow_serving/servables/tensorflow/bundle_factory_test.h"
4048
#include "tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
49+
#include "tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
4150
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
4251

4352
namespace tensorflow {
@@ -56,10 +65,11 @@ Status CreateBundleFromPath(const CreationType creation_type,
5665
const string& path,
5766
std::unique_ptr<SavedModelBundle>* bundle) {
5867
std::unique_ptr<SavedModelBundleFactory> factory;
59-
TF_RETURN_IF_ERROR(SavedModelBundleFactory::Create(config, &factory));
6068
auto config_with_session_hook = config;
6169
config_with_session_hook.set_session_target(
6270
test_util::kNewSessionHookSessionTargetPrefix);
71+
TF_RETURN_IF_ERROR(
72+
SavedModelBundleFactory::Create(config_with_session_hook, &factory));
6373
test_util::SetNewSessionHook([&](const SessionOptions& session_options) {
6474
const bool enable_session_metadata =
6575
creation_type == CreationType::kWithMetadata;
@@ -288,6 +298,48 @@ TEST_P(SavedModelBundleFactoryTest, RunOptions) { TestRunOptions(); }
288298

289299
TEST_P(SavedModelBundleFactoryTest, RunOptionsError) { TestRunOptionsError(); }
290300

301+
TEST_P(SavedModelBundleFactoryTest, SetGraphOptionsFromSavedModelConfig) {
302+
const std::string dst_dir = io::JoinPath(testing::TmpDir(), "model");
303+
test_util::CopyDirOrDie(export_dir_, dst_dir);
304+
tensorflow::Env* env = tensorflow::Env::Default();
305+
TF_ASSERT_OK(env->CreateDir(io::JoinPath(dst_dir, "assets.extra")));
306+
SavedModelConfig saved_model_config;
307+
saved_model_config.mutable_session_overrides()->set_disable_meta_optimizer(
308+
true);
309+
TF_ASSERT_OK(tensorflow::WriteBinaryProto(
310+
env, io::JoinPath(dst_dir, "assets.extra", "saved_model_config.pb"),
311+
saved_model_config));
312+
export_dir_ = dst_dir;
313+
314+
SessionBundleConfig config = GetSessionBundleConfig();
315+
config.set_session_target(test_util::kNewSessionHookSessionTargetPrefix);
316+
std::unique_ptr<SavedModelBundle> bundle;
317+
if (ExpectCreateBundleFailure()) {
318+
EXPECT_FALSE(CreateBundleFromPath(GetParam().creation_type, config,
319+
export_dir_, &bundle)
320+
.ok());
321+
return;
322+
}
323+
std::unique_ptr<SavedModelBundleFactory> factory;
324+
TF_ASSERT_OK(SavedModelBundleFactory::Create(config, &factory));
325+
test_util::SetNewSessionHook([&](const SessionOptions& session_options) {
326+
EXPECT_TRUE(session_options.config.graph_options()
327+
.rewrite_options()
328+
.disable_meta_optimizer());
329+
return absl::OkStatus();
330+
});
331+
332+
switch (GetParam().creation_type) {
333+
case CreationType::kWithoutMetadata:
334+
TF_ASSERT_OK(factory->CreateSavedModelBundle(export_dir_, &bundle));
335+
break;
336+
case CreationType::kWithMetadata:
337+
TF_ASSERT_OK(factory->CreateSavedModelBundleWithMetadata(
338+
CreateMetadata(), export_dir_, &bundle));
339+
break;
340+
}
341+
}
342+
291343
} // namespace
292344
} // namespace serving
293345
} // namespace tensorflow

0 commit comments

Comments
 (0)