Skip to content

Commit a419266

Browse files
AspirinSJLtensorflow-copybara
authored andcommitted
Extract the untracked part of RunSavedModelWarmup to a separate function
PiperOrigin-RevId: 791455959
1 parent ea320f0 commit a419266

File tree

4 files changed

+124
-61
lines changed

4 files changed

+124
-61
lines changed

tensorflow_serving/servables/tensorflow/BUILD

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -933,10 +933,13 @@ cc_library(
933933
deps = [
934934
":session_bundle_config_cc_proto",
935935
"//tensorflow_serving/apis:prediction_log_cc_proto",
936+
"//tensorflow_serving/util:executor",
936937
"//tensorflow_serving/util:threadpool_executor",
938+
"@com_google_absl//absl/base:core_headers",
939+
"@com_google_absl//absl/log",
940+
"@com_google_absl//absl/status",
937941
"@com_google_protobuf//:cc_wkt_protos",
938942
"@org_tensorflow//tensorflow/cc/saved_model:constants",
939-
"@org_tensorflow//tensorflow/cc/saved_model:loader",
940943
"@org_tensorflow//tensorflow/core:lib",
941944
"@org_tensorflow//tensorflow/core:protos_all_cc",
942945
"@org_tensorflow//tensorflow/core/kernels/batching_util:warmup",
@@ -979,14 +982,22 @@ cc_test(
979982
"//tensorflow_serving/apis:predict_cc_proto",
980983
"//tensorflow_serving/apis:prediction_log_cc_proto",
981984
"//tensorflow_serving/apis:regression_cc_proto",
982-
"//tensorflow_serving/core/test_util:test_main",
985+
"//tensorflow_serving/core/test_util:test_main", # buildcleaner: keep
986+
"@com_google_absl//absl/status",
987+
"@com_google_absl//absl/strings:string_view",
983988
"@com_google_protobuf//:cc_wkt_protos",
989+
"@local_xla//xla/tsl/lib/core:status_test_util",
984990
"@org_tensorflow//tensorflow/cc/saved_model:constants",
985-
"@org_tensorflow//tensorflow/cc/saved_model:signature_constants",
991+
"@org_tensorflow//tensorflow/cc/saved_model:loader",
992+
"@org_tensorflow//tensorflow/core:framework_lite",
986993
"@org_tensorflow//tensorflow/core:lib",
987994
"@org_tensorflow//tensorflow/core:protos_all_cc",
988995
"@org_tensorflow//tensorflow/core:test",
996+
"@org_tensorflow//tensorflow/core:tflite_portable_logging",
997+
"@org_tensorflow//tensorflow/core/example:example_protos_cc",
989998
"@org_tensorflow//tensorflow/core/kernels/batching_util:warmup",
999+
"@org_tensorflow//tensorflow/core/platform:errors",
1000+
"@org_tensorflow//tensorflow/core/platform:path",
9901001
],
9911002
)
9921003

tensorflow_serving/servables/tensorflow/saved_model_warmup_util.cc

Lines changed: 33 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,20 +16,30 @@ limitations under the License.
1616
#include "tensorflow_serving/servables/tensorflow/saved_model_warmup_util.h"
1717

1818
#include <algorithm>
19+
#include <cstdint>
1920
#include <functional>
2021
#include <memory>
2122
#include <utility>
2223

2324
#include "google/protobuf/wrappers.pb.h"
25+
#include "absl/base/thread_annotations.h"
26+
#include "absl/log/log.h"
27+
#include "absl/status/status.h"
2428
#include "tensorflow/cc/saved_model/constants.h"
2529
#include "xla/tsl/platform/errors.h"
2630
#include "tensorflow/core/kernels/batching_util/warmup.h"
2731
#include "tensorflow/core/lib/core/errors.h"
28-
#include "tensorflow/core/lib/io/path.h"
2932
#include "tensorflow/core/lib/io/record_reader.h"
3033
#include "tensorflow/core/lib/monitoring/sampler.h"
34+
#include "tensorflow/core/platform/env.h"
35+
#include "tensorflow/core/platform/env_time.h"
36+
#include "tensorflow/core/platform/file_system.h"
3137
#include "tensorflow/core/platform/mutex.h"
32-
#include "tensorflow/core/platform/status.h"
38+
#include "tensorflow/core/platform/path.h"
39+
#include "tensorflow/core/platform/strcat.h"
40+
#include "tensorflow/core/platform/tstring.h"
41+
#include "tensorflow/core/platform/types.h"
42+
#include "tensorflow_serving/util/executor.h"
3343
#include "tensorflow_serving/util/threadpool_executor.h"
3444

3545
namespace tensorflow {
@@ -58,22 +68,9 @@ uint64_t GetLatencyMicroseconds(const uint64_t start_microseconds) {
5868
constexpr char WarmupConsts::kRequestsFileName[];
5969
constexpr int WarmupConsts::kMaxNumRecords;
6070

61-
absl::Status RunSavedModelWarmup(
71+
absl::Status RunSavedModelWarmupUntracked(
6272
const ModelWarmupOptions& model_warmup_options, const string export_dir,
6373
std::function<absl::Status(PredictionLog)> warmup_request_executor) {
64-
WarmupStateRegistry::Handle warmup_handle;
65-
auto per_model_data = std::make_unique<WarmupStateRegistry::PerModelData>();
66-
per_model_data->warmup_all_batch_sizes =
67-
model_warmup_options.enable_all_batch_sizes_warmup();
68-
if (!model_warmup_options.model_name().empty()) {
69-
auto h = GetGlobalWarmupStateRegistry().Register(
70-
{model_warmup_options.model_name(),
71-
model_warmup_options.model_version()},
72-
std::move(per_model_data));
73-
TF_RETURN_IF_ERROR(h.status());
74-
warmup_handle = std::move(h.value());
75-
}
76-
7774
const uint64_t start_microseconds = EnvTime::NowMicros();
7875
const string warmup_path =
7976
io::JoinPath(export_dir, kSavedModelAssetsExtraDirectory,
@@ -237,6 +234,26 @@ absl::Status RunSavedModelWarmup(
237234
return absl::OkStatus();
238235
}
239236

237+
absl::Status RunSavedModelWarmup(
238+
const ModelWarmupOptions& model_warmup_options, const string export_dir,
239+
std::function<absl::Status(PredictionLog)> warmup_request_executor) {
240+
WarmupStateRegistry::Handle warmup_handle;
241+
auto per_model_data = std::make_unique<WarmupStateRegistry::PerModelData>();
242+
per_model_data->warmup_all_batch_sizes =
243+
model_warmup_options.enable_all_batch_sizes_warmup();
244+
if (!model_warmup_options.model_name().empty()) {
245+
auto h = GetGlobalWarmupStateRegistry().Register(
246+
{model_warmup_options.model_name(),
247+
model_warmup_options.model_version()},
248+
std::move(per_model_data));
249+
TF_RETURN_IF_ERROR(h.status());
250+
warmup_handle = std::move(h.value());
251+
}
252+
253+
return RunSavedModelWarmupUntracked(model_warmup_options, export_dir,
254+
warmup_request_executor);
255+
}
256+
240257
} // namespace internal
241258
} // namespace serving
242259
} // namespace tensorflow

tensorflow_serving/servables/tensorflow/saved_model_warmup_util.h

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,10 @@ limitations under the License.
1616
#ifndef THIRD_PARTY_TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SAVED_MODEL_WARMUP_UTIL_H_
1717
#define THIRD_PARTY_TENSORFLOW_SERVING_SERVABLES_TENSORFLOW_SAVED_MODEL_WARMUP_UTIL_H_
1818

19-
#include "tensorflow/cc/saved_model/loader.h"
19+
#include <functional>
20+
21+
#include "absl/status/status.h"
22+
#include "tensorflow/core/platform/types.h"
2023
#include "tensorflow/core/protobuf/saved_model.pb.h"
2124
#include "tensorflow_serving/apis/prediction_log.pb.h"
2225
#include "tensorflow_serving/servables/tensorflow/session_bundle_config.pb.h"
@@ -35,9 +38,18 @@ struct WarmupConsts {
3538
// to trigger lazy initializations (such as TF optimizations, XLA compilations)
3639
// at load time, and consequently improve first request latency.
3740
// Warmup is skipped if no warmup file present.
38-
Status RunSavedModelWarmup(
41+
absl::Status RunSavedModelWarmup(
42+
const ModelWarmupOptions& model_warmup_options, const string export_dir,
43+
std::function<absl::Status(PredictionLog)> warmup_request_executor);
44+
45+
// Similar to `RunSavedModelWarmup()`, but does not track the warmup state.
46+
//
47+
// WARNING: Inside the function, multiple warmup threads might be dispatched to
48+
// run `warmup_request_executor`. Use with caution, especially when batching is
49+
// involved.
50+
absl::Status RunSavedModelWarmupUntracked(
3951
const ModelWarmupOptions& model_warmup_options, const string export_dir,
40-
std::function<Status(PredictionLog)> warmup_request_executor);
52+
std::function<absl::Status(PredictionLog)> warmup_request_executor);
4153

4254
} // namespace internal
4355
} // namespace serving

0 commit comments

Comments
 (0)