
Commit 016354d

feat: Add support for max_inflight_requests parameter to prevent unbounded memory growth in ensemble models (#455)
1 parent e813ef8 commit 016354d

3 files changed: +136 -0 lines changed


src/constants.h

Lines changed: 2 additions & 0 deletions
@@ -95,6 +95,8 @@ constexpr int MAX_GRPC_MESSAGE_SIZE = INT32_MAX;
 constexpr uint64_t SEQUENCE_IDLE_DEFAULT_MICROSECONDS = 1000 * 1000;
 constexpr size_t CUDA_IPC_STRUCT_SIZE = 64;
 
+constexpr int kMutexTimeoutSeconds = 300;
+
 #ifdef TRITON_ENABLE_METRICS
 // MetricModelReporter expects a device ID for GPUs, but we reuse this device
 // ID for other metrics as well such as for CPU and Response Cache metrics
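The new kMutexTimeoutSeconds constant bounds how long the scheduler will wait for inflight capacity before logging an error and proceeding (see WaitForCapacity in the next file). Below is a minimal, Triton-independent sketch of the std::condition_variable::wait_for pattern the constant is used with; the timeout and names here are placeholders for the demo, not values from the commit:

// Minimal sketch (not Triton code): a bounded wait with a predicate.
#include <chrono>
#include <condition_variable>
#include <iostream>
#include <mutex>
#include <thread>

int main() {
  std::mutex mu;
  std::condition_variable cv;
  bool capacity_available = false;  // stand-in for inflight_count_ < max_inflight_

  // Another thread frees capacity after a short delay.
  std::thread consumer([&] {
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    {
      std::lock_guard<std::mutex> lk(mu);
      capacity_available = true;
    }
    cv.notify_one();
  });

  std::unique_lock<std::mutex> lk(mu);
  // wait_for returns true as soon as the predicate holds, and false only if
  // the timeout expires with the predicate still false; that timeout path is
  // what the commit logs as an internal error and then continues past.
  bool ok = cv.wait_for(
      lk, std::chrono::seconds(5), [&] { return capacity_available; });
  std::cout << (ok ? "capacity granted\n" : "timed out; proceeding anyway\n");

  lk.unlock();
  consumer.join();
}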

src/ensemble_scheduler/ensemble_scheduler.cc

Lines changed: 125 additions & 0 deletions
@@ -28,8 +28,10 @@
 
 #include "ensemble_scheduler.h"
 
+#include <condition_variable>
 #include <mutex>
 
+#include "constants.h"
 #include "cuda_utils.h"
 #include "metrics.h"
 #include "model.h"
@@ -150,6 +152,82 @@ class RequestTracker {
   triton::common::ThreadPool* const callback_pool_;
 };
 
+// Limits concurrent inflight requests for a single ensemble step.
+// Tracks inflight requests count and blocks producers when limit is reached.
+class StepInflightRequestLimiter {
+ public:
+  explicit StepInflightRequestLimiter(const size_t max_inflight)
+      : inflight_count_(0), max_inflight_(max_inflight)
+  {
+  }
+
+  // Wait until capacity is available or request is cancelled.
+  // No-op if limit not configured (max_inflight_ == 0).
+  void WaitForCapacity(
+      RequestTracker* request_tracker, const size_t step_idx,
+      const std::string& ensemble_name)
+  {
+    // No limit configured, no blocking
+    if (max_inflight_ == 0) {
+      return;
+    }
+
+    std::unique_lock<std::mutex> lk(mutex_);
+    auto timeout = std::chrono::seconds(kMutexTimeoutSeconds);
+
+    auto is_request_cancelled = [&]() {
+      auto& req = request_tracker->Request();
+      return (req == nullptr) || req->IsCancelled();
+    };
+
+    bool capacity_available = cv_.wait_for(lk, timeout, [&] {
+      return is_request_cancelled() || (inflight_count_ < max_inflight_);
+    });
+
+    // Log error if timeout occurred (not cancellation), but proceed anyway
+    // to avoid deadlock. Caller always continues after this call.
+    if (!capacity_available && !is_request_cancelled()) {
+      LOG_ERROR << "[Internal Error] Ensemble '" << ensemble_name
+                << "' unable to schedule step " << step_idx
+                << " (inflight: " << inflight_count_
+                << " >= limit: " << max_inflight_ << ") for "
+                << kMutexTimeoutSeconds
+                << " seconds. Proceeding to avoid deadlock.";
+    }
+  }
+
+  // Increment inflight count after successfully scheduling a request.
+  // No-op if limit not configured (max_inflight_ == 0).
+  void IncrementInflightCount()
+  {
+    // No limit configured, no tracking needed
+    if (max_inflight_ == 0) {
+      return;
+    }
+    std::lock_guard<std::mutex> lk(mutex_);
+    inflight_count_++;
+  }
+
+  // Decrement inflight count when a request completes, and notify waiting
+  // producers. No-op if limit not configured (max_inflight_ == 0).
+  void DecrementInflightCount()
+  {
+    // No limit configured, no tracking needed
+    if (max_inflight_ == 0) {
+      return;
+    }
+    std::lock_guard<std::mutex> lk(mutex_);
+    inflight_count_--;
+    cv_.notify_one();
+  }
+
+ private:
+  size_t inflight_count_;
+  const size_t max_inflight_;
+  std::mutex mutex_;
+  std::condition_variable cv_;
+};
+
 // Step is used as 'userp' and keeps ensemble context alive
 // until no more internal requests are inflight.
 // Step contains metadata, and status for the
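Taken on its own, StepInflightRequestLimiter behaves like a per-step counting gate: a producer calls WaitForCapacity before scheduling a request, IncrementInflightCount once scheduling succeeds, and the completion path calls DecrementInflightCount to wake one waiting producer. Below is a self-contained sketch of that discipline using a simplified limiter; the names are hypothetical and it drops the RequestTracker, cancellation, and timeout handling of the real class:

#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Simplified stand-in for StepInflightRequestLimiter.
class InflightLimiter {
 public:
  explicit InflightLimiter(size_t max_inflight) : max_inflight_(max_inflight) {}

  void WaitForCapacity() {
    std::unique_lock<std::mutex> lk(mutex_);
    cv_.wait(lk, [&] { return inflight_count_ < max_inflight_; });
  }
  void Increment() {
    std::lock_guard<std::mutex> lk(mutex_);
    inflight_count_++;
  }
  void Decrement() {
    {
      std::lock_guard<std::mutex> lk(mutex_);
      inflight_count_--;
    }
    cv_.notify_one();
  }

 private:
  size_t inflight_count_ = 0;
  const size_t max_inflight_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

int main() {
  InflightLimiter limiter(4);  // analogous to max_inflight_requests = 4
  std::vector<std::thread> completions;

  for (int i = 0; i < 16; i++) {
    limiter.WaitForCapacity();  // block while 4 "requests" are inflight
    limiter.Increment();        // only after "scheduling" succeeds
    // Simulated async completion, like the final response flag arriving.
    completions.emplace_back([&limiter, i] {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      std::cout << "request " << i << " completed\n";
      limiter.Decrement();
    });
  }
  for (auto& t : completions) t.join();
}

With the limit set to 4, at most four simulated requests are inflight at any moment; later iterations block in WaitForCapacity until an earlier completion calls Decrement.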
@@ -370,6 +448,11 @@ class EnsembleContext {
 
   size_t inflight_step_counter_;
 
+  // Inflight request limiters for each ensemble step.
+  // Only allocated when max_inflight_requests_ > 0.
+  std::vector<std::unique_ptr<StepInflightRequestLimiter>>
+      step_inflight_request_limiters_;
+
   // pointer that either points to 'pruned_tensor_to_step_' or to
   // 'info_->tensor_to_step_' if all ensemble outputs are requested
   std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
@@ -505,6 +588,17 @@ EnsembleContext::EnsembleContext(
     }
   }
 
+  // Initialize step inflight request limiters for each step.
+  if (info_->max_inflight_requests_ > 0) {
+    size_t num_steps = info_->steps_.size();
+    step_inflight_request_limiters_.reserve(num_steps);
+    for (size_t i = 0; i < num_steps; i++) {
+      step_inflight_request_limiters_.emplace_back(
+          std::make_unique<StepInflightRequestLimiter>(
+              info_->max_inflight_requests_));
+    }
+  }
+
   if (ensemble_status_.IsOk()) {
     request_id_ = lrequest->Id();
     correlation_id_ = lrequest->CorrelationId();
@@ -907,6 +1001,10 @@ EnsembleContext::UpdateEnsembleState(
     if (completed_step->response_flags_ &
         TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
       inflight_step_counter_--;
+      if (!step_inflight_request_limiters_.empty()) {
+        step_inflight_request_limiters_[completed_step->step_idx_]
+            ->DecrementInflightCount();
+      }
     }
     RETURN_IF_ERROR(ConsumeResponse(completed_step));
     updated_tensors->swap(completed_step->updated_tensors_);
@@ -1392,6 +1490,15 @@ EnsembleContext::ScheduleSteps(
 {
   for (auto& step : steps) {
     step->ctx_ = context;
+    size_t this_step_idx = step->step_idx_;
+
+    // Apply step inflight request limiters if configured.
+    if (!context->step_inflight_request_limiters_.empty()) {
+      context->step_inflight_request_limiters_[this_step_idx]->WaitForCapacity(
+          context->request_tracker_, this_step_idx,
+          context->info_->ensemble_name_);
+    }
+
     bool should_schedule = false;
     // Must release lock before InferAsync to avoid deadlock, as the same thread
     // will be calling request/response callbacks on cache hits, which will
@@ -1421,6 +1528,13 @@
       std::unique_ptr<InferenceRequest> request = std::move(step->request_);
       auto step_status = context->is_->InferAsync(request);
       if (step_status.IsOk()) {
+        // Increment inflight counter AFTER successful scheduling. Always
+        // increment for ALL steps (including step 0) to ensure symmetry with
+        // decrement and prevent underflow when steps complete.
+        if (!context->step_inflight_request_limiters_.empty()) {
+          context->step_inflight_request_limiters_[this_step_idx]
+              ->IncrementInflightCount();
+        }
         step.release();
         continue;
       } else {
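A note on ordering in this hunk and the UpdateEnsembleState hunk above: the count is incremented only after InferAsync succeeds, on the scheduling path, while the matching decrement runs wherever the step's TRITONSERVER_RESPONSE_COMPLETE_FINAL flag is handled, possibly on a different thread. Because acquisition and release are split across call sites, and only the success path acquires, the change uses explicit, symmetric increment/decrement calls for every step rather than a scope guard; that symmetry is what the comment about preventing underflow refers to.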
@@ -1602,6 +1716,17 @@ EnsembleScheduler::EnsembleScheduler(
     }
   }
   callback_pool_ = is_->EnsembleCallbackPool();
+
+  // Parse the configuration for max_inflight_requests from the protobuf field.
+  if (config.has_ensemble_scheduling()) {
+    info_->max_inflight_requests_ =
+        config.ensemble_scheduling().max_inflight_requests();
+    if (info_->max_inflight_requests_ > 0) {
+      LOG_INFO << "Ensemble model '" << config.name()
+               << "' configured with max_inflight_requests: "
+               << info_->max_inflight_requests_;
+    }
+  }
 }
 
 EnsembleScheduler::~EnsembleScheduler()
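The value itself comes from the ensemble_scheduling message of the model configuration and is read once when the scheduler is constructed. As a hypothetical round-trip sketch, assuming the standard protobuf-generated accessors implied by the getters above (the exact generated header and namespace may differ in the Triton core tree):

// Hypothetical sketch, not part of the commit: set and read the field using
// protobuf-generated accessors matching the getters used in the scheduler.
#include <iostream>

#include "model_config.pb.h"  // assumed generated header for the model config

int main() {
  inference::ModelConfig config;  // assumed message type and namespace
  config.set_name("my_ensemble");
  config.mutable_ensemble_scheduling()->set_max_inflight_requests(8);

  if (config.has_ensemble_scheduling() &&
      config.ensemble_scheduling().max_inflight_requests() > 0) {
    std::cout << "max_inflight_requests = "
              << config.ensemble_scheduling().max_inflight_requests() << "\n";
  }
}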

src/ensemble_scheduler/ensemble_scheduler.h

Lines changed: 9 additions & 0 deletions
@@ -83,6 +83,15 @@ struct EnsembleInfo {
 
   // backward path, ensemble tensor to the step that provides its data
   std::unordered_map<std::string, size_t> tensor_to_prev_step_;
+
+  // The maximum number of concurrent inflight requests allowed at each ensemble
+  // step per inference request. This limit is applied per step and per
+  // inference request, not globally for the entire ensemble model. This limit
+  // prevents unbounded memory growth when ensemble steps produce responses
+  // faster than downstream steps can consume them. Default value is 0, which
+  // indicates that no limit is enforced. Configured via 'max_inflight_requests'
+  // field in ensemble_scheduling.
+  size_t max_inflight_requests_ = 0;
 };
 
 // Scheduler that implements ensemble scheduling.
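Since max_inflight_requests_ is applied per step and per inference request rather than globally, as the new comment spells out, the worst-case number of inflight internal requests still scales with the number of steps and with client concurrency. A small illustrative calculation with hypothetical numbers (none of these are defaults from the commit):

#include <cstddef>
#include <iostream>

int main() {
  const size_t max_inflight_requests = 8;          // per step, per ensemble request
  const size_t num_steps = 3;                      // steps in the ensemble
  const size_t concurrent_ensemble_requests = 10;  // requests hitting the ensemble

  // Upper bound if every step of every ensemble request saturates its limiter.
  const size_t worst_case =
      max_inflight_requests * num_steps * concurrent_ensemble_requests;
  std::cout << "worst-case inflight internal requests: " << worst_case << "\n";
}

So the setting bounds buffering within each ensemble request; it is not a server-wide admission limit.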
