Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/constants.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ constexpr int MAX_GRPC_MESSAGE_SIZE = INT32_MAX;
constexpr uint64_t SEQUENCE_IDLE_DEFAULT_MICROSECONDS = 1000 * 1000;
constexpr size_t CUDA_IPC_STRUCT_SIZE = 64;

// Upper bound, in seconds, on a blocking condition-variable wait. The ensemble
// scheduler uses it as the timeout while a step waits for inflight-request
// capacity; when it expires the waiter logs an error and proceeds instead of
// blocking forever.
constexpr int kMutexTimeoutSeconds = 300;

#ifdef TRITON_ENABLE_METRICS
// MetricModelReporter expects a device ID for GPUs, but we reuse this device
// ID for other metrics as well such as for CPU and Response Cache metrics
Expand Down
124 changes: 124 additions & 0 deletions src/ensemble_scheduler/ensemble_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@

#include "ensemble_scheduler.h"

#include <condition_variable>
#include <mutex>

#include "constants.h"
#include "cuda_utils.h"
#include "metrics.h"
#include "model.h"
Expand Down Expand Up @@ -150,6 +152,82 @@ class RequestTracker {
triton::common::ThreadPool* const callback_pool_;
};

// Limits concurrent inflight requests for a single ensemble step.
//
// Producers call WaitForCapacity() before scheduling a request for the step,
// IncrementInflightCount() once the request has been scheduled successfully,
// and DecrementInflightCount() when the request completes. All three calls
// are no-ops when no limit is configured (max_inflight == 0).
//
// Notes on the guarantees provided:
//  - Waiting and incrementing are separate critical sections, so the limit is
//    a throttle rather than a hard cap: several producers may observe
//    capacity at the same time and all go on to schedule.
//  - A completion may be reported before the matching increment (the caller
//    increments only after scheduling returns). The counter is therefore
//    signed: a transiently negative value is treated as "capacity available"
//    instead of wrapping around to a huge unsigned value that would stall
//    waiters until the next completion or the timeout.
//  - Nothing in this class notifies waiters on cancellation; a cancelled
//    request is noticed when the waiter is next woken (completion, spurious
//    wakeup) or when the timeout expires.
class StepInflightRequestLimiter {
 public:
  // 'max_inflight' is the per-step limit; 0 disables limiting entirely.
  explicit StepInflightRequestLimiter(const size_t max_inflight)
      : inflight_count_(0), max_inflight_(max_inflight)
  {
  }

  // Block until capacity is available, the request is cancelled (or its
  // request object is gone), or kMutexTimeoutSeconds elapse.
  //
  // 'request_tracker' must be non-null; it is queried for the cancellation
  // state of the ensemble request. 'step_idx' and 'ensemble_name' are only
  // used for the timeout log message.
  //
  // The function always returns normally: on timeout it logs an error and
  // lets the caller proceed so that a missed completion cannot deadlock the
  // ensemble.
  void WaitForCapacity(
      RequestTracker* request_tracker, const size_t step_idx,
      const std::string& ensemble_name)
  {
    // No limit configured, no blocking
    if (max_inflight_ == 0) {
      return;
    }

    std::unique_lock<std::mutex> lk(mutex_);
    const auto timeout = std::chrono::seconds(kMutexTimeoutSeconds);

    auto is_request_cancelled = [&]() {
      auto& req = request_tracker->Request();
      return (req == nullptr) || req->IsCancelled();
    };

    // wait_for() returns the final value of the predicate, i.e. false only
    // when the timeout expired with neither cancellation nor capacity.
    const bool capacity_available = cv_.wait_for(lk, timeout, [&] {
      return is_request_cancelled() || HasCapacityLocked();
    });

    // Log error if timeout occurred (not cancellation), but proceed anyway
    // to avoid deadlock. Caller always continues after this call.
    if (!capacity_available && !is_request_cancelled()) {
      LOG_ERROR << "[Internal Error] Ensemble '" << ensemble_name
                << "' unable to schedule step " << step_idx
                << " (inflight: " << inflight_count_
                << " >= limit: " << max_inflight_ << ") for "
                << kMutexTimeoutSeconds
                << " seconds. Proceeding to avoid deadlock.";
    }
  }

  // Increment inflight count after successfully scheduling a request.
  // No-op if limit not configured (max_inflight_ == 0).
  void IncrementInflightCount()
  {
    // No limit configured, no tracking needed
    if (max_inflight_ == 0) {
      return;
    }
    std::lock_guard<std::mutex> lk(mutex_);
    ++inflight_count_;
  }

  // Decrement inflight count when a request completes, and notify one waiting
  // producer. No-op if limit not configured (max_inflight_ == 0).
  //
  // May run before the matching IncrementInflightCount(); the count then goes
  // negative until the increment arrives, which waiters treat as free
  // capacity.
  void DecrementInflightCount()
  {
    // No limit configured, no tracking needed
    if (max_inflight_ == 0) {
      return;
    }
    std::lock_guard<std::mutex> lk(mutex_);
    --inflight_count_;
    cv_.notify_one();
  }

 private:
  // Returns true when another request may be scheduled. Must be called with
  // 'mutex_' held. A negative count (completion observed before its
  // increment) always counts as available capacity.
  bool HasCapacityLocked() const
  {
    return (inflight_count_ < 0) ||
           (static_cast<size_t>(inflight_count_) < max_inflight_);
  }

  // Signed on purpose, see the class comment. Guarded by 'mutex_'.
  int64_t inflight_count_;
  const size_t max_inflight_;
  std::mutex mutex_;
  std::condition_variable cv_;
};

// Step is used as 'userp' and keeps ensemble context alive
// until no more internal requests are inflight.
// Step contains metadata, and status for the
Expand Down Expand Up @@ -370,6 +448,11 @@ class EnsembleContext {

size_t inflight_step_counter_;

// Inflight request limiters for each ensemble step.
// Only allocated when max_inflight_requests_ > 0.
std::vector<std::unique_ptr<StepInflightRequestLimiter>>
step_inflight_request_limiters_;

// pointer that either points to 'pruned_tensor_to_step_' or to
// 'info_->tensor_to_step_' if all ensemble outputs are requested
std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
Expand Down Expand Up @@ -505,6 +588,16 @@ EnsembleContext::EnsembleContext(
}
}

// Initialize step inflight request limiters for each step.
if (info_->max_inflight_requests_ > 0) {
size_t num_steps = info_->steps_.size();
for (size_t i = 0; i < num_steps; i++) {
step_inflight_request_limiters_.emplace_back(
std::make_unique<StepInflightRequestLimiter>(
info_->max_inflight_requests_));
}
}

if (ensemble_status_.IsOk()) {
request_id_ = lrequest->Id();
correlation_id_ = lrequest->CorrelationId();
Expand Down Expand Up @@ -907,6 +1000,10 @@ EnsembleContext::UpdateEnsembleState(
if (completed_step->response_flags_ &
TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
inflight_step_counter_--;
if (!step_inflight_request_limiters_.empty()) {
step_inflight_request_limiters_[completed_step->step_idx_]
->DecrementInflightCount();
}
}
RETURN_IF_ERROR(ConsumeResponse(completed_step));
updated_tensors->swap(completed_step->updated_tensors_);
Expand Down Expand Up @@ -1392,6 +1489,15 @@ EnsembleContext::ScheduleSteps(
{
for (auto& step : steps) {
step->ctx_ = context;
size_t this_step_idx = step->step_idx_;

// Apply step inflight request limiters if configured.
if (!context->step_inflight_request_limiters_.empty()) {
context->step_inflight_request_limiters_[this_step_idx]->WaitForCapacity(
context->request_tracker_, this_step_idx,
context->info_->ensemble_name_);
}

bool should_schedule = false;
// Must release lock before InferAsync to avoid deadlock, as the same thread
// will be calling request/response callbacks on cache hits, which will
Expand Down Expand Up @@ -1421,6 +1527,13 @@ EnsembleContext::ScheduleSteps(
std::unique_ptr<InferenceRequest> request = std::move(step->request_);
auto step_status = context->is_->InferAsync(request);
if (step_status.IsOk()) {
// Increment inflight counter AFTER successful scheduling. Always
// increment for ALL steps (including step 0) to ensure symmetry with
// decrement and prevent underflow when steps complete.
if (!context->step_inflight_request_limiters_.empty()) {
context->step_inflight_request_limiters_[this_step_idx]
->IncrementInflightCount();
}
step.release();
continue;
} else {
Expand Down Expand Up @@ -1602,6 +1715,17 @@ EnsembleScheduler::EnsembleScheduler(
}
}
callback_pool_ = is_->EnsembleCallbackPool();

// Parse the configuration for max_inflight_requests from the protobuf field.
if (config.has_ensemble_scheduling()) {
info_->max_inflight_requests_ =
config.ensemble_scheduling().max_inflight_requests();
if (info_->max_inflight_requests_ > 0) {
LOG_INFO << "Ensemble model '" << config.name()
<< "' configured with max_inflight_requests: "
<< info_->max_inflight_requests_;
}
}
}

EnsembleScheduler::~EnsembleScheduler()
Expand Down
9 changes: 9 additions & 0 deletions src/ensemble_scheduler/ensemble_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ struct EnsembleInfo {

// backward path, ensemble tensor to the step that provides its data
std::unordered_map<std::string, size_t> tensor_to_prev_step_;

// The maximum number of concurrent inflight requests allowed at each ensemble
// step per inference request. This limit is applied per step, not globally
// for the entire ensemble model. This limit prevents unbounded memory growth
// when ensemble steps produce responses faster than downstream steps can
// consume them. Default value is 0, which indicates that no limit is
// enforced. Configured via 'max_inflight_requests' field in
// ensemble_scheduling.
size_t max_inflight_requests_ = 0;
};

// Scheduler that implements ensemble scheduling.
Expand Down