@@ -26,6 +26,7 @@ limitations under the License.
2626#include " tensorflow/core/platform/types.h"
2727#include " tensorflow_serving/core/loader.h"
2828#include " tensorflow_serving/core/source_adapter.h"
29+ #include " tensorflow_serving/resources/resource_util.h"
2930#include " tensorflow_serving/resources/resource_values.h"
3031#include " tensorflow_serving/util/any_ptr.h"
3132#include " tensorflow_serving/util/optional.h"
@@ -62,6 +63,9 @@ namespace serving {
6263// };
6364// std::unique_ptr<Loader> loader(new SimpleLoader<time_t>(
6465// servable_creator, resource_estimator));
66+ //
67+ // This class is not thread-safe. Synchronization is assumed to be done by the
68+ // caller.
6569template <typename ServableType>
6670class SimpleLoader : public Loader {
6771 public:
@@ -80,7 +84,19 @@ class SimpleLoader : public Loader {
8084 // and hence the serving system cannot enforce resource safety.
8185 static ResourceEstimator EstimateNoResources ();
8286
87+ // Constructor that takes a single resource estimator, to use for estimating
88+ // the resources needed during load as well as post-load.
8389 SimpleLoader (Creator creator, ResourceEstimator resource_estimator);
90+
91+ // Constructor that takes two resource estimators: one to use for estimating
92+ // the resources needed during load, as well as a second one that gives a
93+ // different estimate after loading has finished. See the documentation on
94+ // Loader::EstimateResources() for (a) potential reasons the estimate might
95+ // decrease, and (b) correctness constraints on how the estimate is allowed to
96+ // change over time.
97+ SimpleLoader (Creator creator, ResourceEstimator resource_estimator,
98+ ResourceEstimator post_load_resource_estimator);
99+
84100 ~SimpleLoader () override = default ;
85101
86102 Status EstimateResources (ResourceAllocation* estimate) const override ;
@@ -94,11 +110,20 @@ class SimpleLoader : public Loader {
94110 private:
95111 Creator creator_;
96112
113+ // A function that estimates the resources needed to load the servable.
97114 ResourceEstimator resource_estimator_;
98115
99- // The memoized estimated resource requirement of the session bundle servable.
116+ // An optional function that estimates the resources needed for the servable
117+ // after it has been loaded. (If omitted, 'resource_estimator_' should be used
118+ // for all estimates, i.e. before, during and after load.)
119+ optional<ResourceEstimator> post_load_resource_estimator_;
120+
121+ // The memoized estimated resource requirement of the servable.
100122 mutable optional<ResourceAllocation> memoized_resource_estimate_;
101123
124+ std::unique_ptr<ResourceUtil> resource_util_;
125+ Resource ram_resource_;
126+
102127 std::unique_ptr<ServableType> servable_;
103128
104129 TF_DISALLOW_COPY_AND_ASSIGN (SimpleLoader);
@@ -180,7 +205,23 @@ SimpleLoader<ServableType>::EstimateNoResources() {
180205template <typename ServableType>
181206SimpleLoader<ServableType>::SimpleLoader(Creator creator,
182207 ResourceEstimator resource_estimator)
183- : creator_(creator), resource_estimator_(resource_estimator) {}
208+ : creator_(creator), resource_estimator_(resource_estimator) {
209+ ResourceUtil::Options resource_util_options;
210+ resource_util_options.devices = {{device_types::kMain , 1 }};
211+ resource_util_ =
212+ std::unique_ptr<ResourceUtil>(new ResourceUtil (resource_util_options));
213+
214+ ram_resource_ = resource_util_->CreateBoundResource (
215+ device_types::kMain , resource_kinds::kRamBytes );
216+ }
217+
218+ template <typename ServableType>
219+ SimpleLoader<ServableType>::SimpleLoader(
220+ Creator creator, ResourceEstimator resource_estimator,
221+ ResourceEstimator post_load_resource_estimator)
222+ : SimpleLoader(creator, resource_estimator) {
223+ post_load_resource_estimator_ = post_load_resource_estimator;
224+ }
184225
185226template <typename ServableType>
186227Status SimpleLoader<ServableType>::EstimateResources(
@@ -198,8 +239,36 @@ Status SimpleLoader<ServableType>::EstimateResources(
198239
199240template <typename ServableType>
200241Status SimpleLoader<ServableType>::Load() {
201- const Status status = creator_ (&servable_);
202- return status;
242+ TF_RETURN_IF_ERROR (creator_ (&servable_));
243+
244+ if (post_load_resource_estimator_) {
245+ // Save the during-load estimate (may be able to use the memoized value).
246+ ResourceAllocation during_load_resource_estimate;
247+ TF_RETURN_IF_ERROR (EstimateResources (&during_load_resource_estimate));
248+
249+ // Obtain the post-load estimate, and store it as the memoized value.
250+ ResourceAllocation post_load_resource_estimate;
251+ TF_RETURN_IF_ERROR (
252+ (*post_load_resource_estimator_)(&post_load_resource_estimate));
253+ memoized_resource_estimate_ = post_load_resource_estimate;
254+
255+ // Release any transient memory used only during load to the OS.
256+ const uint64 during_load_ram_estimate = resource_util_->GetQuantity (
257+ ram_resource_, during_load_resource_estimate);
258+ const uint64 post_load_ram_estimate =
259+ resource_util_->GetQuantity (ram_resource_, post_load_resource_estimate);
260+ if (post_load_ram_estimate < during_load_ram_estimate) {
261+ const uint64 transient_ram_estimate =
262+ during_load_ram_estimate - post_load_ram_estimate;
263+ LOG (INFO) << " Calling MallocExtension_ReleaseToSystem() after servable "
264+ " load with "
265+ << transient_ram_estimate;
266+ ::tensorflow::port::MallocExtension_ReleaseToSystem (
267+ transient_ram_estimate);
268+ }
269+ }
270+
271+ return Status::OK ();
203272}
204273
205274template <typename ServableType>
@@ -219,14 +288,13 @@ void SimpleLoader<ServableType>::Unload() {
219288
220289 // If we have a main-memory footprint estimate, release that amount of memory
221290 // to the OS.
222- for (const ResourceAllocation::Entry& entry :
223- resource_estimate.resource_quantities ()) {
224- if (entry.resource ().device () == device_types::kMain &&
225- entry.resource ().kind () == resource_kinds::kRamBytes ) {
226- LOG (INFO) << " Calling MallocExtension_ReleaseToSystem() with "
227- << entry.quantity ();
228- ::tensorflow::port::MallocExtension_ReleaseToSystem (entry.quantity());
229- }
291+ const uint64 memory_estimate =
292+ resource_util_->GetQuantity (ram_resource_, resource_estimate);
293+ if (memory_estimate > 0 ) {
294+ LOG (INFO) << " Calling MallocExtension_ReleaseToSystem() after servable "
295+ " unload with "
296+ << memory_estimate;
297+ ::tensorflow::port::MallocExtension_ReleaseToSystem (memory_estimate);
230298 }
231299}
232300
0 commit comments