@@ -25,19 +25,28 @@ limitations under the License.
25
25
#include " google/protobuf/wrappers.pb.h"
26
26
#include < gmock/gmock.h>
27
27
#include < gtest/gtest.h>
28
+ #include " absl/status/status.h"
28
29
#include " tensorflow/cc/saved_model/constants.h"
29
30
#include " tensorflow/cc/saved_model/loader.h"
30
31
#include " tensorflow/cc/saved_model/tag_constants.h"
32
+ #include " xla/tsl/lib/core/status_test_util.h"
31
33
#include " tensorflow/core/framework/tensor_testutil.h"
32
34
#include " tensorflow/core/lib/core/status.h"
33
35
#include " tensorflow/core/lib/core/status_test_util.h"
34
36
#include " tensorflow/core/lib/io/path.h"
37
+ #include " tensorflow/core/platform/env.h"
38
+ #include " tensorflow/core/platform/path.h"
39
+ #include " tensorflow/core/platform/test.h"
35
40
#include " tensorflow/core/protobuf/named_tensor.pb.h"
41
+ #include " tensorflow/core/protobuf/rewriter_config.pb.h"
36
42
#include " tensorflow/core/public/session.h"
43
+ #include " tensorflow/core/public/session_options.h"
37
44
#include " tensorflow/core/public/version.h"
45
+ #include " tsl/platform/errors.h"
38
46
#include " tensorflow_serving/core/test_util/session_test_util.h"
39
47
#include " tensorflow_serving/servables/tensorflow/bundle_factory_test.h"
40
48
#include " tensorflow_serving/servables/tensorflow/bundle_factory_test_util.h"
49
+ #include " tensorflow_serving/servables/tensorflow/saved_model_config.pb.h"
41
50
#include " tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
42
51
43
52
namespace tensorflow {
@@ -56,10 +65,11 @@ Status CreateBundleFromPath(const CreationType creation_type,
56
65
const string& path,
57
66
std::unique_ptr<SavedModelBundle>* bundle) {
58
67
std::unique_ptr<SavedModelBundleFactory> factory;
59
- TF_RETURN_IF_ERROR (SavedModelBundleFactory::Create (config, &factory));
60
68
auto config_with_session_hook = config;
61
69
config_with_session_hook.set_session_target (
62
70
test_util::kNewSessionHookSessionTargetPrefix );
71
+ TF_RETURN_IF_ERROR (
72
+ SavedModelBundleFactory::Create (config_with_session_hook, &factory));
63
73
test_util::SetNewSessionHook ([&](const SessionOptions& session_options) {
64
74
const bool enable_session_metadata =
65
75
creation_type == CreationType::kWithMetadata ;
@@ -288,6 +298,48 @@ TEST_P(SavedModelBundleFactoryTest, RunOptions) { TestRunOptions(); }
288
298
289
299
TEST_P (SavedModelBundleFactoryTest, RunOptionsError) { TestRunOptionsError (); }
290
300
301
+ TEST_P (SavedModelBundleFactoryTest, SetGraphOptionsFromSavedModelConfig) {
302
+ const std::string dst_dir = io::JoinPath (testing::TmpDir (), " model" );
303
+ test_util::CopyDirOrDie (export_dir_, dst_dir);
304
+ tensorflow::Env* env = tensorflow::Env::Default ();
305
+ TF_ASSERT_OK (env->CreateDir (io::JoinPath (dst_dir, " assets.extra" )));
306
+ SavedModelConfig saved_model_config;
307
+ saved_model_config.mutable_session_overrides ()->set_disable_meta_optimizer (
308
+ true );
309
+ TF_ASSERT_OK (tensorflow::WriteBinaryProto (
310
+ env, io::JoinPath (dst_dir, " assets.extra" , " saved_model_config.pb" ),
311
+ saved_model_config));
312
+ export_dir_ = dst_dir;
313
+
314
+ SessionBundleConfig config = GetSessionBundleConfig ();
315
+ config.set_session_target (test_util::kNewSessionHookSessionTargetPrefix );
316
+ std::unique_ptr<SavedModelBundle> bundle;
317
+ if (ExpectCreateBundleFailure ()) {
318
+ EXPECT_FALSE (CreateBundleFromPath (GetParam ().creation_type , config,
319
+ export_dir_, &bundle)
320
+ .ok ());
321
+ return ;
322
+ }
323
+ std::unique_ptr<SavedModelBundleFactory> factory;
324
+ TF_ASSERT_OK (SavedModelBundleFactory::Create (config, &factory));
325
+ test_util::SetNewSessionHook ([&](const SessionOptions& session_options) {
326
+ EXPECT_TRUE (session_options.config .graph_options ()
327
+ .rewrite_options ()
328
+ .disable_meta_optimizer ());
329
+ return absl::OkStatus ();
330
+ });
331
+
332
+ switch (GetParam ().creation_type ) {
333
+ case CreationType::kWithoutMetadata :
334
+ TF_ASSERT_OK (factory->CreateSavedModelBundle (export_dir_, &bundle));
335
+ break ;
336
+ case CreationType::kWithMetadata :
337
+ TF_ASSERT_OK (factory->CreateSavedModelBundleWithMetadata (
338
+ CreateMetadata (), export_dir_, &bundle));
339
+ break ;
340
+ }
341
+ }
342
+
291
343
} // namespace
292
344
} // namespace serving
293
345
} // namespace tensorflow
0 commit comments