|
28 | 28 |
|
29 | 29 | #include "ensemble_scheduler.h" |
30 | 30 |
|
| 31 | +#include <condition_variable> |
31 | 32 | #include <mutex> |
32 | 33 |
|
| 34 | +#include "constants.h" |
33 | 35 | #include "cuda_utils.h" |
34 | 36 | #include "metrics.h" |
35 | 37 | #include "model.h" |
@@ -150,6 +152,82 @@ class RequestTracker { |
150 | 152 | triton::common::ThreadPool* const callback_pool_; |
151 | 153 | }; |
152 | 154 |
|
| 155 | +// Limits concurrent inflight requests for a single ensemble step. |
| 156 | +// Tracks inflight requests count and blocks producers when limit is reached. |
| 157 | +class StepInflightRequestLimiter { |
| 158 | + public: |
| 159 | + explicit StepInflightRequestLimiter(const size_t max_inflight) |
| 160 | + : inflight_count_(0), max_inflight_(max_inflight) |
| 161 | + { |
| 162 | + } |
| 163 | + |
| 164 | + // Wait until capacity is available or request is cancelled. |
| 165 | + // No-op if limit not configured (max_inflight_ == 0). |
| 166 | + void WaitForCapacity( |
| 167 | + RequestTracker* request_tracker, const size_t step_idx, |
| 168 | + const std::string& ensemble_name) |
| 169 | + { |
| 170 | + // No limit configured, no blocking |
| 171 | + if (max_inflight_ == 0) { |
| 172 | + return; |
| 173 | + } |
| 174 | + |
| 175 | + std::unique_lock<std::mutex> lk(mutex_); |
| 176 | + auto timeout = std::chrono::seconds(kMutexTimeoutSeconds); |
| 177 | + |
| 178 | + auto is_request_cancelled = [&]() { |
| 179 | + auto& req = request_tracker->Request(); |
| 180 | + return (req == nullptr) || req->IsCancelled(); |
| 181 | + }; |
| 182 | + |
| 183 | + bool capacity_available = cv_.wait_for(lk, timeout, [&] { |
| 184 | + return is_request_cancelled() || (inflight_count_ < max_inflight_); |
| 185 | + }); |
| 186 | + |
| 187 | + // Log error if timeout occurred (not cancellation), but proceed anyway |
| 188 | + // to avoid deadlock. Caller always continues after this call. |
| 189 | + if (!capacity_available && !is_request_cancelled()) { |
| 190 | + LOG_ERROR << "[Internal Error] Ensemble '" << ensemble_name |
| 191 | + << "' unable to schedule step " << step_idx |
| 192 | + << " (inflight: " << inflight_count_ |
| 193 | + << " >= limit: " << max_inflight_ << ") for " |
| 194 | + << kMutexTimeoutSeconds |
| 195 | + << " seconds. Proceeding to avoid deadlock."; |
| 196 | + } |
| 197 | + } |
| 198 | + |
| 199 | + // Increment inflight count after successfully scheduling a request. |
| 200 | + // No-op if limit not configured (max_inflight_ == 0). |
| 201 | + void IncrementInflightCount() |
| 202 | + { |
| 203 | + // No limit configured, no tracking needed |
| 204 | + if (max_inflight_ == 0) { |
| 205 | + return; |
| 206 | + } |
| 207 | + std::lock_guard<std::mutex> lk(mutex_); |
| 208 | + inflight_count_++; |
| 209 | + } |
| 210 | + |
| 211 | + // Decrement inflight count when a request completes, and notify waiting |
| 212 | + // producers. No-op if limit not configured (max_inflight_ == 0). |
| 213 | + void DecrementInflightCount() |
| 214 | + { |
| 215 | + // No limit configured, no tracking needed |
| 216 | + if (max_inflight_ == 0) { |
| 217 | + return; |
| 218 | + } |
| 219 | + std::lock_guard<std::mutex> lk(mutex_); |
| 220 | + inflight_count_--; |
| 221 | + cv_.notify_one(); |
| 222 | + } |
| 223 | + |
| 224 | + private: |
| 225 | + size_t inflight_count_; |
| 226 | + const size_t max_inflight_; |
| 227 | + std::mutex mutex_; |
| 228 | + std::condition_variable cv_; |
| 229 | +}; |
| 230 | + |
153 | 231 | // Step is used as 'userp' and keeps ensemble context alive |
154 | 232 | // until no more internal requests are inflight. |
155 | 233 | // Step contains metadata, and status for the |
@@ -370,6 +448,11 @@ class EnsembleContext { |
370 | 448 |
|
371 | 449 | size_t inflight_step_counter_; |
372 | 450 |
|
| 451 | + // Inflight request limiters for each ensemble step. |
| 452 | + // Only allocated when max_inflight_requests_ > 0. |
| 453 | + std::vector<std::unique_ptr<StepInflightRequestLimiter>> |
| 454 | + step_inflight_request_limiters_; |
| 455 | + |
373 | 456 | // pointer that either points to 'pruned_tensor_to_step_' or to |
374 | 457 | // 'info_->tensor_to_step_' if all ensemble outputs are requested |
375 | 458 | std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_; |
@@ -505,6 +588,17 @@ EnsembleContext::EnsembleContext( |
505 | 588 | } |
506 | 589 | } |
507 | 590 |
|
| 591 | + // Initialize step inflight request limiters for each step. |
| 592 | + if (info_->max_inflight_requests_ > 0) { |
| 593 | + size_t num_steps = info_->steps_.size(); |
| 594 | + step_inflight_request_limiters_.reserve(num_steps); |
| 595 | + for (size_t i = 0; i < num_steps; i++) { |
| 596 | + step_inflight_request_limiters_.emplace_back( |
| 597 | + std::make_unique<StepInflightRequestLimiter>( |
| 598 | + info_->max_inflight_requests_)); |
| 599 | + } |
| 600 | + } |
| 601 | + |
508 | 602 | if (ensemble_status_.IsOk()) { |
509 | 603 | request_id_ = lrequest->Id(); |
510 | 604 | correlation_id_ = lrequest->CorrelationId(); |
@@ -907,6 +1001,10 @@ EnsembleContext::UpdateEnsembleState( |
907 | 1001 | if (completed_step->response_flags_ & |
908 | 1002 | TRITONSERVER_RESPONSE_COMPLETE_FINAL) { |
909 | 1003 | inflight_step_counter_--; |
| 1004 | + if (!step_inflight_request_limiters_.empty()) { |
| 1005 | + step_inflight_request_limiters_[completed_step->step_idx_] |
| 1006 | + ->DecrementInflightCount(); |
| 1007 | + } |
910 | 1008 | } |
911 | 1009 | RETURN_IF_ERROR(ConsumeResponse(completed_step)); |
912 | 1010 | updated_tensors->swap(completed_step->updated_tensors_); |
@@ -1392,6 +1490,15 @@ EnsembleContext::ScheduleSteps( |
1392 | 1490 | { |
1393 | 1491 | for (auto& step : steps) { |
1394 | 1492 | step->ctx_ = context; |
| 1493 | + size_t this_step_idx = step->step_idx_; |
| 1494 | + |
| 1495 | + // Apply step inflight request limiters if configured. |
| 1496 | + if (!context->step_inflight_request_limiters_.empty()) { |
| 1497 | + context->step_inflight_request_limiters_[this_step_idx]->WaitForCapacity( |
| 1498 | + context->request_tracker_, this_step_idx, |
| 1499 | + context->info_->ensemble_name_); |
| 1500 | + } |
| 1501 | + |
1395 | 1502 | bool should_schedule = false; |
1396 | 1503 | // Must release lock before InferAsync to avoid deadlock, as the same thread |
1397 | 1504 | // will be calling request/response callbacks on cache hits, which will |
@@ -1421,6 +1528,13 @@ EnsembleContext::ScheduleSteps( |
1421 | 1528 | std::unique_ptr<InferenceRequest> request = std::move(step->request_); |
1422 | 1529 | auto step_status = context->is_->InferAsync(request); |
1423 | 1530 | if (step_status.IsOk()) { |
| 1531 | + // Increment inflight counter AFTER successful scheduling. Always |
| 1532 | + // increment for ALL steps (including step 0) to ensure symmetry with |
| 1533 | + // decrement and prevent underflow when steps complete. |
| 1534 | + if (!context->step_inflight_request_limiters_.empty()) { |
| 1535 | + context->step_inflight_request_limiters_[this_step_idx] |
| 1536 | + ->IncrementInflightCount(); |
| 1537 | + } |
1424 | 1538 | step.release(); |
1425 | 1539 | continue; |
1426 | 1540 | } else { |
@@ -1602,6 +1716,17 @@ EnsembleScheduler::EnsembleScheduler( |
1602 | 1716 | } |
1603 | 1717 | } |
1604 | 1718 | callback_pool_ = is_->EnsembleCallbackPool(); |
| 1719 | + |
| 1720 | + // Parse the configuration for max_inflight_requests from the protobuf field. |
| 1721 | + if (config.has_ensemble_scheduling()) { |
| 1722 | + info_->max_inflight_requests_ = |
| 1723 | + config.ensemble_scheduling().max_inflight_requests(); |
| 1724 | + if (info_->max_inflight_requests_ > 0) { |
| 1725 | + LOG_INFO << "Ensemble model '" << config.name() |
| 1726 | + << "' configured with max_inflight_requests: " |
| 1727 | + << info_->max_inflight_requests_; |
| 1728 | + } |
| 1729 | + } |
1605 | 1730 | } |
1606 | 1731 |
|
1607 | 1732 | EnsembleScheduler::~EnsembleScheduler() |
|
0 commit comments