Skip to content

Commit db15be9

Browse files
JimHsiung and liutongxuan
authored and committed
feat: support profiling tpot in disaggregated pd mode.
1 parent 182a13c commit db15be9

File tree

5 files changed

+132
-7
lines changed

5 files changed

+132
-7
lines changed

xllm/core/common/types.h

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,9 @@ struct InstanceInfo {
199199
std::vector<int64_t> v_cache_ids;
200200
int32_t dp_size;
201201
// ttft profiling data
202-
std::vector<std::pair<int32_t, int64_t>> ttft_profiling_data;
202+
std::vector<std::pair<int32_t, double>> ttft_profiling_data;
203+
// tpot profiling data
204+
std::vector<std::tuple<int32_t, int32_t, double>> tpot_profiling_data;
203205

204206
nlohmann::json serialize_to_json() const {
205207
nlohmann::json json_val;
@@ -223,6 +225,7 @@ struct InstanceInfo {
223225
json_val["v_cache_ids"] = v_cache_ids;
224226
json_val["dp_size"] = dp_size;
225227
json_val["ttft_profiling_data"] = ttft_profiling_data;
228+
json_val["tpot_profiling_data"] = tpot_profiling_data;
226229
return json_val;
227230
}
228231
};

xllm/core/scheduler/disagg_pd_scheduler.cpp

Lines changed: 45 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,10 +57,11 @@ DisaggPDScheduler::DisaggPDScheduler(Engine* engine, const Options& options)
5757
initialize_rpc_server_and_client("DisaggPDServer");
5858
register_instance_info("DisaggPDServer", engine);
5959

60-
// Profile ttft and update instance info (for non-decode instances)
60+
// Profile ttft & tpot and update instance info (for mix instances)
6161
if (!options_.disable_ttft_profiling() &&
62-
options_.instance_role().value() != InstanceRole::DECODE) {
62+
options_.instance_role().value() == InstanceRole::MIX) {
6363
profile_ttft();
64+
profile_tpot();
6465
}
6566
}
6667
}
@@ -114,6 +115,7 @@ void DisaggPDScheduler::register_instance_info(const std::string& server_name,
114115
}
115116

116117
void DisaggPDScheduler::profile_ttft() {
118+
LOG(INFO) << "Start profiling TTFT.";
117119
// get the maximum prefill token length
118120
auto& model_args = engine_->model_args();
119121
int32_t max_context_len = model_args.max_position_embeddings();
@@ -125,16 +127,53 @@ void DisaggPDScheduler::profile_ttft() {
125127
// warm up
126128
profile_manager_->run_request(max_context_len, 0);
127129

128-
// get TTFT starting from max_context_len, dividing the token length by 2 in
129-
// each loop iteration
130+
// get TTFT starting from max_context_len
130131
for (int32_t token_length = max_context_len; token_length > 1;
131-
token_length >>= 1) {
132-
int64_t latency = profile_manager_->run_request(token_length, 0);
132+
token_length *= 0.9) {
133+
double latency = profile_manager_->run_request(token_length, 0);
133134
instance_info_.ttft_profiling_data.emplace_back(
134135
std::make_pair(token_length, latency));
135136
}
136137
}
137138

139+
void DisaggPDScheduler::profile_tpot() {
140+
LOG(INFO) << "Start profiling TPOT.";
141+
// get the maximum token length
142+
auto& model_args = engine_->model_args();
143+
int32_t max_context_len = model_args.max_position_embeddings();
144+
if (!options_.enable_chunked_prefill()) {
145+
max_context_len =
146+
std::min(max_context_len, options_.max_tokens_per_batch());
147+
}
148+
149+
int32_t num_blocks = kv_cache_manager_->num_blocks();
150+
int32_t block_size = kv_cache_manager_->block_size();
151+
int32_t max_seqs_per_batch = options_.max_seqs_per_batch();
152+
int32_t request_blocks = max_context_len / block_size + 1;
153+
int32_t max_batch_size = num_blocks / request_blocks;
154+
155+
// warm up
156+
profile_manager_->run_request(
157+
max_context_len, max_context_len - 1, max_batch_size);
158+
159+
// get TPOT starting from max_context_len, dividing the token length by 2 in
160+
// each loop iteration. Skip small token lengths to speed up profiling.
161+
for (int32_t token_length = max_context_len; token_length > 64;
162+
token_length >>= 1) {
163+
max_batch_size = num_blocks / (token_length / block_size + 1);
164+
int32_t current_max_batch_size = max_batch_size > max_seqs_per_batch
165+
? max_seqs_per_batch
166+
: max_batch_size;
167+
for (int32_t batch_size = current_max_batch_size; batch_size > 0;
168+
batch_size *= 0.9) {
169+
double latency = profile_manager_->profile_decode_step_time(
170+
token_length, batch_size, /*min_context_len=*/64, max_context_len);
171+
instance_info_.tpot_profiling_data.emplace_back(
172+
token_length, batch_size, latency);
173+
}
174+
}
175+
}
176+
138177
// TODO: maybe we should consider update info case even if info already exists
139178
// in local.
140179
bool DisaggPDScheduler::check_remote_instance_info(

xllm/core/scheduler/disagg_pd_scheduler.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ class DisaggPDScheduler : public ContinuousScheduler {
101101
// corresponding TTFT for calculating the estimated TTFT of requests.
102102
void profile_ttft();
103103

104+
void profile_tpot();
105+
104106
// check remote instance info, if not exist, get from master service
105107
bool check_remote_instance_info(const std::string& instance_name);
106108

xllm/core/scheduler/profile/profile_manager.cpp

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -604,4 +604,69 @@ double ProfileManager::run_request(
604604
return latency;
605605
}
606606

607+
// Generate a batch of decode requests and execute it, then return the step
608+
// latency.
609+
double ProfileManager::profile_decode_step_time(int32_t token_length,
610+
int32_t batch_size,
611+
int32_t min_context_len,
612+
int32_t max_context_len) {
613+
double total_latency = 0;
614+
for (int32_t i = 0; i < profile_count_per_step_; ++i) {
615+
std::vector<int32_t> token_length_vec;
616+
std::vector<int32_t> prefix_length_vec;
617+
generate_random_decode_batch(batch_size * token_length,
618+
batch_size,
619+
min_context_len,
620+
max_context_len,
621+
token_length_vec,
622+
prefix_length_vec);
623+
double latency = run_request(token_length_vec, prefix_length_vec);
624+
total_latency += latency;
625+
}
626+
return total_latency / profile_count_per_step_;
627+
}
628+
629+
// Generate a batch of random decode requests with an average length of
630+
// token_length.
631+
void ProfileManager::generate_random_decode_batch(
632+
int32_t total_length,
633+
int32_t batch_size,
634+
int32_t min_context_len,
635+
int32_t max_context_len,
636+
std::vector<int32_t>& token_length_vec,
637+
std::vector<int32_t>& prefix_length_vec) {
638+
CHECK(total_length >= batch_size * min_context_len);
639+
640+
token_length_vec.resize(batch_size, min_context_len);
641+
prefix_length_vec.resize(batch_size, min_context_len - 1);
642+
int remain = total_length - batch_size * min_context_len;
643+
644+
std::random_device rd;
645+
std::mt19937_64 gen(rd());
646+
647+
for (int i = 0; i < batch_size; ++i) {
648+
if (remain == 0) break;
649+
650+
int max = remain > (max_context_len - min_context_len)
651+
? (max_context_len - min_context_len)
652+
: remain;
653+
654+
std::uniform_int_distribution<int> dis(0, max);
655+
int add = dis(gen);
656+
token_length_vec[i] += add;
657+
prefix_length_vec[i] += add;
658+
remain -= add;
659+
}
660+
661+
int idx = 0;
662+
while (remain > 0) {
663+
if (token_length_vec[idx % batch_size] < max_context_len) {
664+
token_length_vec[idx % batch_size] += 1;
665+
prefix_length_vec[idx % batch_size] += 1;
666+
--remain;
667+
}
668+
++idx;
669+
}
670+
}
671+
607672
} // namespace xllm

xllm/core/scheduler/profile/profile_manager.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ class ProfileManager {
7979
double run_request(const std::vector<int32_t>& token_length_vec,
8080
const std::vector<int32_t>& prefix_length_vec);
8181

82+
// Generate a batch of decode requests and execute it, then return the step
83+
// latency.
84+
double profile_decode_step_time(int32_t token_length,
85+
int32_t batch_size,
86+
int32_t min_context_len,
87+
int32_t max_context_len);
88+
8289
void train_prefill_time_predictor(
8390
std::vector<std::tuple<int32_t, int32_t, double>> time_profiling_data);
8491

@@ -119,6 +126,15 @@ class ProfileManager {
119126
int32_t lower_bound,
120127
int32_t upper_bound);
121128

129+
// Generate a batch of random decode requests with an average length of
130+
// token_length.
131+
void generate_random_decode_batch(int32_t total_length,
132+
int32_t batch_size,
133+
int32_t min_context_len,
134+
int32_t max_context_len,
135+
std::vector<int32_t>& token_length_vec,
136+
std::vector<int32_t>& prefix_length_vec);
137+
122138
std::unique_ptr<TimePredictor> prefill_time_predictor_;
123139
std::unique_ptr<TimePredictor> decode_time_predictor_;
124140

0 commit comments

Comments (0)