89 changes: 89 additions & 0 deletions src/ensemble_scheduler/ensemble_scheduler.cc
@@ -28,6 +28,7 @@

#include "ensemble_scheduler.h"

#include <condition_variable>
#include <mutex>

#include "cuda_utils.h"
@@ -45,6 +46,9 @@ class EnsembleContext;

using IterationCount = size_t;

// Timeout for backpressure condition-variable waits, used to avoid a
// potential deadlock if a downstream step never completes.
constexpr int kMutexTimeoutSeconds = 300;

// Check if the model is configured to preserve the order of responses.
// This is critical for async execution of ResponseComplete callbacks.
inline bool
@@ -370,6 +374,13 @@ class EnsembleContext {

size_t inflight_step_counter_;

// Backpressure support: Limits memory growth from decoupled models.
// Tracks inflight requests per step; blocks producers when downstream
// consumers are overloaded. Only active if max_inflight_requests_ > 0.
std::vector<size_t> step_inflight_request_counts_;
std::vector<std::unique_ptr<std::mutex>> step_mutexes_;
std::vector<std::unique_ptr<std::condition_variable>> step_cv_vec_;

// pointer that either points to 'pruned_tensor_to_step_' or to
// 'info_->tensor_to_step_' if all ensemble outputs are requested
std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
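Note on the three parallel vectors above: std::mutex and std::condition_variable are neither copyable nor movable, so they cannot be stored directly in a std::vector that is resized after construction; wrapping each one in std::unique_ptr keeps the per-step layout while satisfying the container requirements. A minimal standalone sketch of the same bookkeeping (illustrative only, the struct and field names are made up):

```cpp
#include <condition_variable>
#include <memory>
#include <mutex>
#include <vector>

// Per-step backpressure state, sized once the number of ensemble steps is
// known. The unique_ptr indirection is needed because std::mutex and
// std::condition_variable are non-movable, which would otherwise prevent
// resizing the vectors.
struct StepBackpressureState {
  std::vector<size_t> inflight_counts;
  std::vector<std::unique_ptr<std::mutex>> mutexes;
  std::vector<std::unique_ptr<std::condition_variable>> cvs;

  explicit StepBackpressureState(size_t num_steps)
      : inflight_counts(num_steps, 0)
  {
    mutexes.reserve(num_steps);
    cvs.reserve(num_steps);
    for (size_t i = 0; i < num_steps; ++i) {
      mutexes.push_back(std::make_unique<std::mutex>());
      cvs.push_back(std::make_unique<std::condition_variable>());
    }
  }
};
```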
@@ -505,6 +516,20 @@ EnsembleContext::EnsembleContext(
}
}

// Initialize backpressure tracking. The per-step counters are always sized;
// the mutexes and condition variables are only allocated when enabled.
size_t num_steps = info_->steps_.size();
step_inflight_request_counts_.resize(num_steps, 0);

if (info_->max_inflight_requests_ > 0) {
step_mutexes_.resize(num_steps);
step_cv_vec_.resize(num_steps);

for (size_t i = 0; i < num_steps; i++) {
step_mutexes_[i] = std::make_unique<std::mutex>();
step_cv_vec_[i] = std::make_unique<std::condition_variable>();
}
}

if (ensemble_status_.IsOk()) {
request_id_ = lrequest->Id();
correlation_id_ = lrequest->CorrelationId();
@@ -907,6 +932,16 @@ EnsembleContext::UpdateEnsembleState(
if (completed_step->response_flags_ &
TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
inflight_step_counter_--;

size_t completed_step_idx = completed_step->step_idx_;

// Decrement step_inflight_request_counts_, then notify any producer
// threads blocked waiting for this step's capacity
if (info_->max_inflight_requests_ > 0 && !step_cv_vec_.empty()) {
std::lock_guard<std::mutex> lk(*step_mutexes_[completed_step_idx]);
step_inflight_request_counts_[completed_step_idx]--;
step_cv_vec_[completed_step_idx]->notify_one();
}
}
RETURN_IF_ERROR(ConsumeResponse(completed_step));
updated_tensors->swap(completed_step->updated_tensors_);
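The decrement above is performed while holding the same per-step mutex that ScheduleSteps locks around its wait, so a producer cannot read the count, see it at the limit, and block in the window between the decrement and the notification (a lost wakeup). A small standalone sketch of that release path, with hypothetical names (not the PR's code):

```cpp
// Hypothetical release helper mirroring the completion path above (sketch).
void ReleaseStepSlot(
    std::mutex& step_mutex, std::condition_variable& step_cv,
    size_t& inflight_count)
{
  {
    // Same mutex the waiting producer holds while evaluating its predicate,
    // so the decrement and the predicate check cannot interleave.
    std::lock_guard<std::mutex> lk(step_mutex);
    --inflight_count;
  }
  // One slot was freed, so waking a single waiter is enough. Notifying after
  // releasing the lock (unlike the in-tree code, which notifies while still
  // holding it) is also valid and spares the woken thread an immediate block
  // on the mutex.
  step_cv.notify_one();
}
```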
@@ -1392,6 +1427,39 @@ EnsembleContext::ScheduleSteps(
{
for (auto& step : steps) {
step->ctx_ = context;
size_t this_step_idx = step->step_idx_;

// Apply backpressure to downstream steps only, not the entry step
if ((this_step_idx != 0) && context->info_->max_inflight_requests_ > 0 &&
!context->step_cv_vec_.empty()) {
std::unique_lock<std::mutex> lk(*context->step_mutexes_[this_step_idx]);

auto timeout = std::chrono::seconds(kMutexTimeoutSeconds);
auto cancelled = [&]() {
auto& req = context->request_tracker_->Request();
return (req == nullptr) || req->IsCancelled();
};

bool capacity_available =
context->step_cv_vec_[this_step_idx]->wait_for(lk, timeout, [&] {
return cancelled() ||
(context->step_inflight_request_counts_[this_step_idx] <
context->info_->max_inflight_requests_);
});

// Log error only if timeout occurred (not cancellation).
if (!capacity_available && !cancelled()) {
LOG_ERROR << "[Internal Error] Ensemble '"
<< context->info_->ensemble_name_
<< "' unable to schedule step " << this_step_idx
<< " (inflight: "
<< context->step_inflight_request_counts_[this_step_idx]
<< " >= limit: " << context->info_->max_inflight_requests_
<< ") for " << kMutexTimeoutSeconds
<< " seconds. Proceeding to avoid deadlock.";
}
}

bool should_schedule = false;
// Must release lock before InferAsync to avoid deadlock, as the same thread
// will be calling request/response callbacks on cache hits, which will
@@ -1421,6 +1489,15 @@
std::unique_ptr<InferenceRequest> request = std::move(step->request_);
auto step_status = context->is_->InferAsync(request);
if (step_status.IsOk()) {
// Increment inflight counter AFTER successful scheduling. Always
// increment for ALL steps (including step 0) to ensure symmetry with
// decrement and prevent underflow when steps complete.
if (context->info_->max_inflight_requests_ > 0 &&
!context->step_mutexes_.empty()) {
std::lock_guard<std::mutex> lk(
*context->step_mutexes_[this_step_idx]);
context->step_inflight_request_counts_[this_step_idx]++;
}
step.release();
continue;
} else {
@@ -1602,6 +1679,18 @@ EnsembleScheduler::EnsembleScheduler(
}
}
callback_pool_ = is_->EnsembleCallbackPool();

// Backpressure configuration from protobuf field. Limits concurrent requests
// from decoupled steps to prevent memory growth. Value of 0 means unlimited.
if (config.has_ensemble_scheduling()) {
info_->max_inflight_requests_ =
config.ensemble_scheduling().max_inflight_requests();
if (info_->max_inflight_requests_ > 0) {
LOG_INFO << "Ensemble model '" << config.name()
<< "' configured with max_inflight_requests: "
<< info_->max_inflight_requests_;
}
}
}

EnsembleScheduler::~EnsembleScheduler()
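Taken together, the scheduler-side additions implement a bounded-inflight handshake per step: a producer blocks on the step's condition variable until the step's inflight count drops below max_inflight_requests (or the request is cancelled, or the 300-second timeout expires), and every FINAL response frees one slot. The self-contained sketch below shows that handshake in isolation for a single step; it is illustrative only (made-up class and names) and omits the cancellation and error handling of the real scheduler:

```cpp
#include <chrono>
#include <condition_variable>
#include <cstddef>
#include <iostream>
#include <mutex>
#include <thread>
#include <vector>

// Standalone illustration of the bounded-inflight handshake (not PR code).
class InflightLimiter {
 public:
  explicit InflightLimiter(size_t limit) : limit_(limit) {}

  // Producer side: block until a slot is free or the timeout expires.
  // Returns false on timeout; the count is incremented either way, mirroring
  // the scheduler's "proceed to avoid deadlock" behavior.
  bool Acquire(std::chrono::seconds timeout)
  {
    std::unique_lock<std::mutex> lk(mu_);
    bool ok = cv_.wait_for(lk, timeout, [&] { return inflight_ < limit_; });
    ++inflight_;
    return ok;
  }

  // Consumer side: called when a step's final response arrives.
  void Release()
  {
    {
      std::lock_guard<std::mutex> lk(mu_);
      --inflight_;
    }
    cv_.notify_one();
  }

 private:
  const size_t limit_;
  size_t inflight_ = 0;
  std::mutex mu_;
  std::condition_variable cv_;
};

int main()
{
  InflightLimiter limiter(2);  // e.g. max_inflight_requests = 2
  std::vector<std::thread> consumers;

  for (int i = 0; i < 5; ++i) {
    limiter.Acquire(std::chrono::seconds(300));
    std::cout << "scheduled request " << i << "\n";
    // Simulate the downstream model completing asynchronously.
    consumers.emplace_back([&limiter] {
      std::this_thread::sleep_for(std::chrono::milliseconds(50));
      limiter.Release();
    });
  }
  for (auto& t : consumers) {
    t.join();
  }
  return 0;
}
```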
7 changes: 7 additions & 0 deletions src/ensemble_scheduler/ensemble_scheduler.h
@@ -83,6 +83,13 @@ struct EnsembleInfo {

// backward path, ensemble tensor to the step that provides its data
std::unordered_map<std::string, size_t> tensor_to_prev_step_;

// Maximum concurrent inflight requests to ensemble steps (downstream
// consumers). Prevents memory growth by blocking producers when limit is
// reached. Default value is 0, which indicates unlimited (no backpressure
// applied). Configured via 'max_inflight_requests' field in
// ensemble_scheduling.
size_t max_inflight_requests_ = 0;
};

// Scheduler that implements ensemble scheduling.