@@ -16,20 +16,30 @@ limitations under the License.
 #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_util.h"
 
 #include <algorithm>
+#include <cstdint>
 #include <functional>
 #include <memory>
 #include <utility>
 
 #include "google/protobuf/wrappers.pb.h"
+#include "absl/base/thread_annotations.h"
+#include "absl/log/log.h"
+#include "absl/status/status.h"
 #include "tensorflow/cc/saved_model/constants.h"
 #include "xla/tsl/platform/errors.h"
 #include "tensorflow/core/kernels/batching_util/warmup.h"
 #include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/io/path.h"
 #include "tensorflow/core/lib/io/record_reader.h"
 #include "tensorflow/core/lib/monitoring/sampler.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/env_time.h"
+#include "tensorflow/core/platform/file_system.h"
 #include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/status.h"
+#include "tensorflow/core/platform/path.h"
+#include "tensorflow/core/platform/strcat.h"
+#include "tensorflow/core/platform/tstring.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow_serving/util/executor.h"
 #include "tensorflow_serving/util/threadpool_executor.h"
 
 namespace tensorflow {
@@ -58,22 +68,9 @@ uint64_t GetLatencyMicroseconds(const uint64_t start_microseconds) {
 constexpr char WarmupConsts::kRequestsFileName[];
 constexpr int WarmupConsts::kMaxNumRecords;
 
-absl::Status RunSavedModelWarmup(
+absl::Status RunSavedModelWarmupUntracked(
     const ModelWarmupOptions& model_warmup_options, const string export_dir,
     std::function<absl::Status(PredictionLog)> warmup_request_executor) {
-  WarmupStateRegistry::Handle warmup_handle;
-  auto per_model_data = std::make_unique<WarmupStateRegistry::PerModelData>();
-  per_model_data->warmup_all_batch_sizes =
-      model_warmup_options.enable_all_batch_sizes_warmup();
-  if (!model_warmup_options.model_name().empty()) {
-    auto h = GetGlobalWarmupStateRegistry().Register(
-        {model_warmup_options.model_name(),
-         model_warmup_options.model_version()},
-        std::move(per_model_data));
-    TF_RETURN_IF_ERROR(h.status());
-    warmup_handle = std::move(h.value());
-  }
-
   const uint64_t start_microseconds = EnvTime::NowMicros();
   const string warmup_path =
       io::JoinPath(export_dir, kSavedModelAssetsExtraDirectory,
@@ -237,6 +234,26 @@ absl::Status RunSavedModelWarmup(
   return absl::OkStatus();
 }
 
+absl::Status RunSavedModelWarmup(
+    const ModelWarmupOptions& model_warmup_options, const string export_dir,
+    std::function<absl::Status(PredictionLog)> warmup_request_executor) {
+  WarmupStateRegistry::Handle warmup_handle;
+  auto per_model_data = std::make_unique<WarmupStateRegistry::PerModelData>();
+  per_model_data->warmup_all_batch_sizes =
+      model_warmup_options.enable_all_batch_sizes_warmup();
+  if (!model_warmup_options.model_name().empty()) {
+    auto h = GetGlobalWarmupStateRegistry().Register(
+        {model_warmup_options.model_name(),
+         model_warmup_options.model_version()},
+        std::move(per_model_data));
+    TF_RETURN_IF_ERROR(h.status());
+    warmup_handle = std::move(h.value());
+  }
+
+  return RunSavedModelWarmupUntracked(model_warmup_options, export_dir,
+                                      warmup_request_executor);
+}
+
 }  // namespace internal
 }  // namespace serving
 }  // namespace tensorflow
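
A minimal caller sketch, under stated assumptions: RunSavedModelWarmup registers the model in the warmup state registry and then delegates to RunSavedModelWarmupUntracked, which replays each PredictionLog record from the warmup file through the supplied warmup_request_executor. The wrapper function and the RunWarmupRequest helper below are hypothetical stand-ins for the servable's real Predict/Classify/Regress dispatch; only the RunSavedModelWarmup call itself comes from this file.

    // Sketch only: WarmUpServable and RunWarmupRequest are hypothetical.
    #include <string>

    #include "absl/status/status.h"
    #include "tensorflow_serving/apis/prediction_log.pb.h"
    #include "tensorflow_serving/servables/tensorflow/saved_model_warmup_util.h"

    namespace tensorflow {
    namespace serving {

    // Hypothetical helper that runs one warmup record against the loaded model.
    absl::Status RunWarmupRequest(const PredictionLog& log);

    absl::Status WarmUpServable(const ModelWarmupOptions& options,
                                const std::string& export_dir) {
      // The executor is invoked once per record read from the warmup file
      // under assets.extra/ in the SavedModel export directory.
      return internal::RunSavedModelWarmup(
          options, export_dir,
          [](PredictionLog log) { return RunWarmupRequest(log); });
    }

    }  // namespace serving
    }  // namespace tensorflow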