Skip to content

Commit b934151

Browse files
vjanfazaAnn Kuruvilla
authored andcommitted
Adding the support of CCL to the Prefilling of Disaggregated Serving (quic#825)
In this PR, I have added the support of CCL during prefilling of Disaggregated Serving. In the current version, we only have the support of CCL during decoding of DA which results in very high TTFT for larger Context Lengths. With this added we can compile the model with the largest CL and yet get good TTFT for smaller PL using the related CCL value instead of CL. These changes don't affect other applications and are only related to Disaggregated Serving and only prefilling of Disaggregated Serving. --------- Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 3e76f60 commit b934151

File tree

3 files changed

+26
-23
lines changed

3 files changed

+26
-23
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3288,7 +3288,7 @@ def compile(
32883288
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
32893289
logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
32903290
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
3291-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
3291+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking
32923292
)
32933293
# For supporting VLLM and Disaggregated with CCL
32943294
elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
@@ -3308,7 +3308,7 @@ def compile(
33083308
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
33093309

33103310
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
3311-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
3311+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking
33123312
)
33133313
# --- Validation ---
33143314
if prefill_only is not None and not isinstance(prefill_only, bool):
@@ -3333,8 +3333,6 @@ def compile(
33333333
ccl_lengths = self.comp_ctx_lengths_decode if prefill_seq_len == 1 else self.comp_ctx_lengths_prefill
33343334
# Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization
33353335
for i in range(0, len(ccl_lengths)):
3336-
if prefill_only or enable_chunking:
3337-
raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL")
33383336
specializations.append(
33393337
self.build_prefill_specialization(
33403338
prefill_seq_len=prefill_seq_len,

QEfficient/utils/check_ccl_specializations.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,15 @@ def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
132132
# Check CCL values are not negative and more than the CCL minimum context length = constants.CCL_MIN_CTX_LEN
133133
if ccl_prefill:
134134
ccl_prefill = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_prefill]
135+
# Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len
136+
if ccl_prefill[-1] < ctx_len:
137+
ccl_prefill.append(ctx_len)
138+
135139
if ccl_decode:
136140
ccl_decode = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_decode]
137-
138-
# Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len
139-
if ccl_prefill[-1] < ctx_len - 1:
140-
ccl_prefill.append(ctx_len)
141-
if ccl_decode[-1] < ctx_len:
142-
ccl_decode.append(ctx_len)
141+
# Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len
142+
if ccl_decode[-1] < ctx_len:
143+
ccl_decode.append(ctx_len)
143144

144145
if prefill_seq_len == 1:
145146
# both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
@@ -153,22 +154,25 @@ def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
153154
if ccl_decode:
154155
ccl_decode = sorted({min(x, ctx_len) for x in (ccl_decode)})
155156

156-
# Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER)
157-
tmp_prefill = ccl_prefill
158-
ccl_prefill = []
159-
for val in tmp_prefill:
160-
while val in ccl_decode or val in ccl_prefill:
161-
val -= 1
162-
if val < 0:
163-
break # Prevent negative values
164-
if val >= 0:
165-
ccl_prefill.append(val)
166-
ccl_prefill.sort()
157+
# This cheking is related to disaggregated serving application since it generates two separate QPCs for prefilling and decoding. So, ccl_prefill will be None in decode QPC and ccl_decode will be None in prefill QPC
158+
if ccl_prefill and ccl_decode:
159+
# Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER)
160+
tmp_prefill = ccl_prefill
161+
ccl_prefill = []
162+
for val in tmp_prefill:
163+
while val in ccl_decode or val in ccl_prefill:
164+
# In case of common values between ccl_prefill and ccl_decode, change the value in ccl_prefill and set it to the closest value which is multiple of CCL_UNIQNE_STEP to avoid repetition and also be hardware and compiler efficient
165+
val = (val - 1) // constants.CCL_UNIQNE_STEP * constants.CCL_UNIQNE_STEP
166+
if val < 0:
167+
break # Prevent negative values
168+
if val >= 0:
169+
ccl_prefill.append(val)
170+
ccl_prefill.sort()
167171

168172
return ccl_prefill, ccl_decode
169173

170174

171-
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
175+
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len, enable_chunking=False):
172176
"""
173177
This function evaluates the values of CCL lists based on three inputs:
174178
- ccl_prefill: optional [list]
@@ -193,7 +197,7 @@ def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_le
193197

194198
# One of ccl lists is [] or None -> replace it with [ctx_len] -> CCL lists have to have a value when CCL is enabled
195199
# Condition #3, #4, #5, and #6
196-
elif not ccl_prefill or not ccl_decode:
200+
elif not ccl_prefill or not ccl_decode and not enable_chunking:
197201
# Initial setting and will be checked with edge cases later
198202
ccl_prefill = ccl_prefill if ccl_prefill else [ctx_len]
199203
ccl_decode = ccl_decode if ccl_decode else [ctx_len]

QEfficient/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ def get_models_dir():
190190
CCL_MAX_ELEMENTS_LISTS = 5
191191
CCL_START_CTX_LEN = 4096
192192
CCL_MIN_CTX_LEN = 1024
193+
CCL_UNIQNE_STEP = 32
193194

194195
# used for gpt-oss prefill-only model Q-blocking
195196
GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256

0 commit comments

Comments
 (0)