Skip to content

Commit ac1586b

Browse files
SiqiaoWu1993tensorflow-copybara
authored andcommitted
Internal change only
PiperOrigin-RevId: 792359286
1 parent a419266 commit ac1586b

File tree

5 files changed

+94
-10
lines changed

5 files changed

+94
-10
lines changed

tensorflow_serving/servables/tensorflow/BUILD

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1254,6 +1254,8 @@ cc_library(
12541254
"//tensorflow_serving/resources:resource_values",
12551255
"//tensorflow_serving/resources:resources_cc_proto",
12561256
"//tensorflow_serving/session_bundle:graph_rewriter",
1257+
"@com_google_absl//absl/base:core_headers",
1258+
"@com_google_absl//absl/log",
12571259
"@com_google_absl//absl/status",
12581260
"@com_google_absl//absl/status:statusor",
12591261
"@com_google_absl//absl/strings",
@@ -1266,7 +1268,9 @@ cc_library(
12661268
"@org_tensorflow//tensorflow/core:core_cpu",
12671269
"@org_tensorflow//tensorflow/core:lib",
12681270
"@org_tensorflow//tensorflow/core:protos_all_cc",
1271+
"@org_tensorflow//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
12691272
"@org_tensorflow//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs",
1273+
"@org_tensorflow//tensorflow/core/tfrt/graph_executor:graph_execution_options",
12701274
"@org_tensorflow//tensorflow/core/tfrt/runtime",
12711275
"@org_tensorflow//tensorflow/core/tfrt/saved_model:saved_model_cpu",
12721276
],
@@ -1320,18 +1324,22 @@ cc_library(
13201324
],
13211325
deps = [
13221326
":bundle_factory_util",
1327+
":file_acl",
13231328
":machine_learning_metadata",
13241329
":servable",
13251330
":tfrt_saved_model_factory",
13261331
":tfrt_saved_model_source_adapter_cc_proto",
13271332
":tfrt_servable",
1333+
"//tensorflow_serving/core:loader",
13281334
"//tensorflow_serving/core:simple_loader",
13291335
"//tensorflow_serving/core:source_adapter",
13301336
"//tensorflow_serving/core:storage_path",
13311337
"//tensorflow_serving/resources:resource_util",
13321338
"//tensorflow_serving/resources:resource_values",
13331339
"//tensorflow_serving/resources:resources_cc_proto",
1334-
"//tensorflow_serving/servables/tensorflow:file_acl",
1340+
"@com_google_absl//absl/status",
1341+
"@com_google_absl//absl/strings",
1342+
"@com_google_absl//absl/strings:string_view",
13351343
"@org_tensorflow//tensorflow/cc/saved_model:loader",
13361344
"@org_tensorflow//tensorflow/core:lib",
13371345
],

tensorflow_serving/servables/tensorflow/tfrt_saved_model_factory.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,15 @@ limitations under the License.
1717

1818
#include <algorithm>
1919
#include <cstdint>
20+
#include <functional>
2021
#include <memory>
2122
#include <string>
2223
#include <unordered_set>
2324
#include <utility>
2425
#include <vector>
2526

2627
#include "google/protobuf/wrappers.pb.h"
28+
#include "absl/log/log.h"
2729
#include "absl/status/status.h"
2830
#include "absl/status/statusor.h"
2931
#include "absl/strings/string_view.h"
@@ -33,8 +35,11 @@ limitations under the License.
3335
#include "tensorflow/compiler/mlir/tfrt/translate/tfrt_compile_options.h"
3436
#include "xla/tsl/platform/env.h"
3537
#include "xla/tsl/platform/errors.h"
38+
#include "xla/tsl/platform/statusor.h"
39+
#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
3640
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
3741
#include "tensorflow/core/lib/core/errors.h"
42+
#include "tensorflow/core/platform/types.h"
3843
#include "tensorflow/core/protobuf/config.pb.h"
3944
#include "tensorflow/core/protobuf/meta_graph.pb.h"
4045
#include "tensorflow/core/public/session_options.h"
@@ -401,5 +406,14 @@ CreateThreadPoolFactoryFromConfig(const TfrtSavedModelConfig& config) {
401406
return thread_pool_factory;
402407
}
403408

409+
// copybara:strip_begin (Do not leak in tesorflow serving OSS.)
410+
absl::Status TfrtSavedModelFactory::CreateOrbaxServable(
411+
const Loader::Metadata& metadata, const string& path,
412+
std::unique_ptr<Servable>* servable) {
413+
return absl::UnimplementedError(
414+
"CreateOrbaxServable is not implemented yet.");
415+
}
416+
// copybara:strip_end
417+
404418
} // namespace serving
405419
} // namespace tensorflow

tensorflow_serving/servables/tensorflow/tfrt_saved_model_factory.h

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,30 @@ limitations under the License.
1616
#ifndef TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SAVED_MODEL_FACTORY_H_
1717
#define TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_TFRT_SAVED_MODEL_FACTORY_H_
1818

19+
#include <cstdint>
1920
#include <functional>
2021
#include <memory>
2122
#include <string>
23+
#include <unordered_set>
24+
#include <utility>
2225

26+
#include "absl/base/attributes.h"
27+
#include "absl/base/thread_annotations.h"
28+
#include "absl/log/log.h"
2329
#include "absl/status/status.h"
30+
#include "absl/status/statusor.h"
31+
#include "absl/strings/string_view.h"
2432
#include "absl/synchronization/mutex.h"
25-
#include "absl/types/optional.h"
33+
#include "xla/tsl/platform/macros.h"
2634
#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
27-
#include "tensorflow/core/lib/core/status.h"
28-
#include "tensorflow/core/platform/macros.h"
35+
#include "tensorflow/core/platform/types.h"
36+
#include "tensorflow/core/tfrt/graph_executor/graph_execution_options.h"
2937
#include "tensorflow/core/tfrt/runtime/runtime.h"
3038
#include "tensorflow/core/tfrt/saved_model/saved_model.h"
3139
#include "tensorflow_serving/batching/tfrt_saved_model_with_batching.h"
3240
#include "tensorflow_serving/core/loader.h"
3341
#include "tensorflow_serving/resources/resources.pb.h"
42+
#include "tensorflow_serving/servables/tensorflow/servable.h"
3443
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
3544
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.pb.h"
3645
#include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
@@ -106,6 +115,12 @@ class TfrtSavedModelFactory {
106115
absl::Status EstimateResourceRequirement(const string& path,
107116
ResourceAllocation* estimate) const;
108117

118+
// copybara:strip_begin (Do not leak in tesorflow serving OSS.)
119+
virtual absl::Status CreateOrbaxServable(const Loader::Metadata& metadata,
120+
const string& path,
121+
std::unique_ptr<Servable>* servable);
122+
// copybara:strip_end
123+
109124
const TfrtSavedModelConfig& config() const { return config_; }
110125
TfrtSavedModelConfig& mutable_config() { return config_; }
111126
absl::string_view GetServingResourceType() const;

tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.cc

Lines changed: 47 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,21 @@ limitations under the License.
1616
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.h"
1717

1818
#include <memory>
19-
19+
#include <string>
20+
#include <utility>
21+
22+
#include "absl/status/status.h"
23+
#include "absl/strings/str_cat.h"
24+
#include "absl/strings/string_view.h"
25+
#include "tensorflow/cc/saved_model/loader.h"
26+
#include "xla/tsl/platform/env.h"
2027
#include "xla/tsl/platform/errors.h"
21-
#include "tensorflow/core/lib/core/errors.h"
28+
#include "tensorflow/core/platform/path.h"
2229
#include "tensorflow/core/platform/types.h"
30+
#include "tensorflow_serving/core/loader.h"
2331
#include "tensorflow_serving/core/simple_loader.h"
32+
#include "tensorflow_serving/core/source_adapter.h"
33+
#include "tensorflow_serving/core/storage_path.h"
2434
#include "tensorflow_serving/resources/resource_util.h"
2535
#include "tensorflow_serving/resources/resource_values.h"
2636
#include "tensorflow_serving/resources/resources.pb.h"
@@ -29,11 +39,31 @@ limitations under the License.
2939
#include "tensorflow_serving/servables/tensorflow/machine_learning_metadata.h"
3040
#include "tensorflow_serving/servables/tensorflow/servable.h"
3141
#include "tensorflow_serving/servables/tensorflow/tfrt_saved_model_factory.h"
32-
#include "tensorflow_serving/servables/tensorflow/tfrt_servable.h"
3342

3443
namespace tensorflow {
3544
namespace serving {
3645

46+
// copybara:strip_begin (Do not leak in tensorflow serving OSS.)
47+
namespace {
48+
// Orbax manifest file name.
49+
inline constexpr char kOrbaxModelManifestPb[] = "manifest.pb";
50+
// Orbax manifest version file name.
51+
inline constexpr char kOrbaxModelManifestVersionTxt[] = "manifest_version.txt";
52+
53+
absl::Status IsOrbaxModelDirectory(absl::string_view path) {
54+
const std::string orbax_model_manifest_pb_path =
55+
tensorflow::io::JoinPath(path, kOrbaxModelManifestPb);
56+
const std::string orbax_model_manifest_version_path =
57+
tensorflow::io::JoinPath(path, kOrbaxModelManifestVersionTxt);
58+
tsl::Env* env = tsl::Env::Default();
59+
TF_RETURN_IF_ERROR(env->FileExists(orbax_model_manifest_pb_path));
60+
TF_RETURN_IF_ERROR(env->FileExists(orbax_model_manifest_version_path));
61+
return absl::OkStatus();
62+
}
63+
64+
} // namespace
65+
// copybara:strip_end
66+
3767
absl::Status TfrtSavedModelSourceAdapter::Create(
3868
const TfrtSavedModelSourceAdapterConfig& config,
3969
std::unique_ptr<TfrtSavedModelSourceAdapter>* adapter) {
@@ -57,9 +87,20 @@ TfrtSavedModelSourceAdapter::GetServableCreator(
5787
return [factory, path](const Loader::Metadata& metadata,
5888
std::unique_ptr<Servable>* servable) {
5989
TF_RETURN_IF_ERROR(RegisterModelRoot(metadata.servable_id, path));
60-
TF_RETURN_IF_ERROR(
61-
factory->CreateTfrtSavedModelWithMetadata(metadata, path, servable));
62-
return absl::OkStatus();
90+
if (MaybeSavedModelDirectory(path)) {
91+
return factory->CreateTfrtSavedModelWithMetadata(metadata, path,
92+
servable);
93+
}
94+
95+
// copybara:strip_begin (Do not leak in tesorflow serving OSS.)
96+
if (IsOrbaxModelDirectory(path).ok()) {
97+
return factory->CreateOrbaxServable(metadata, path, servable);
98+
}
99+
// copybara:strip_end
100+
101+
return absl::InvalidArgumentError(
102+
absl::StrCat("Unsupported model directory: ", path,
103+
". Only SavedModel and Orbax Model are supported."));
63104
};
64105
}
65106

tensorflow_serving/servables/tensorflow/tfrt_saved_model_source_adapter.proto

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,12 @@ message TfrtSavedModelConfig {
238238
// If non-zero, all models on this server are switched to use a prioritized
239239
// batching function using this number of global threads.
240240
int64 tfrt_batch_queue_global_prioritization_num_threads = 2029;
241+
242+
// copybara:strip_begin (Do not leak in tesorflow serving OSS.)
243+
// If true, allow Orbax as an additional format for loading models.
244+
bool allow_orbax = 2030;
245+
// copybara:strip_end
246+
bool allow_xla_cpu = 2031;
241247
}
242248

243249
// Config proto for TfrtSavedModelSourceAdapter.

0 commit comments

Comments
 (0)