Skip to content

Commit 5da0fff

Browse files
committed
Adding the support of CCL to the Prefilling of Disaggregated Serving
Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 33c8ff7 commit 5da0fff

File tree

3 files changed

+28
-23
lines changed

3 files changed

+28
-23
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3298,7 +3298,7 @@ def compile(
32983298
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
32993299
logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
33003300
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
3301-
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
3301+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking
33023302
)
33033303
# For supporting VLLM and Disaggregated with CCL
33043304
elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
@@ -3318,7 +3318,7 @@ def compile(
33183318
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
33193319

33203320
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
3321-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
3321+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking
33223322
)
33233323
# --- Validation ---
33243324
if prefill_only is not None and not isinstance(prefill_only, bool):
@@ -3343,8 +3343,8 @@ def compile(
33433343
ccl_lengths = self.comp_ctx_lengths_decode if prefill_seq_len == 1 else self.comp_ctx_lengths_prefill
33443344
# Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization
33453345
for i in range(0, len(ccl_lengths)):
3346-
if prefill_only or enable_chunking:
3347-
raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL")
3346+
# if prefill_only or enable_chunking:
3347+
# raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL")
33483348
specializations.append(
33493349
self.build_prefill_specialization(
33503350
prefill_seq_len=prefill_seq_len,

QEfficient/utils/check_ccl_specializations.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -132,14 +132,15 @@ def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
132132
# Clamp CCL values so that none is below the CCL minimum context length (constants.CCL_MIN_CTX_LEN)
133133
if ccl_prefill:
134134
ccl_prefill = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_prefill]
135+
# Make sure the last element of ccl_prefill is not less than ctx_len; append ctx_len if it is
136+
if ccl_prefill[-1] < ctx_len:
137+
ccl_prefill.append(ctx_len)
138+
135139
if ccl_decode:
136140
ccl_decode = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_decode]
137-
138-
# Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len
139-
if ccl_prefill[-1] < ctx_len - 1:
140-
ccl_prefill.append(ctx_len)
141-
if ccl_decode[-1] < ctx_len:
142-
ccl_decode.append(ctx_len)
141+
# Make sure the last element of ccl_decode is not less than ctx_len; append ctx_len if it is
142+
if ccl_decode[-1] < ctx_len:
143+
ccl_decode.append(ctx_len)
143144

144145
if prefill_seq_len == 1:
145146
# both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
@@ -153,22 +154,25 @@ def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
153154
if ccl_decode:
154155
ccl_decode = sorted({min(x, ctx_len) for x in (ccl_decode)})
155156

156-
# Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER)
157-
tmp_prefill = ccl_prefill
158-
ccl_prefill = []
159-
for val in tmp_prefill:
160-
while val in ccl_decode or val in ccl_prefill:
161-
val -= 1
162-
if val < 0:
163-
break # Prevent negative values
164-
if val >= 0:
165-
ccl_prefill.append(val)
166-
ccl_prefill.sort()
157+
# This check is needed for disaggregated serving, which generates two separate QPCs for prefilling and decoding: ccl_prefill will be None in the decode QPC and ccl_decode will be None in the prefill QPC.
158+
if ccl_prefill and ccl_decode:
159+
# Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER)
160+
tmp_prefill = ccl_prefill
161+
ccl_prefill = []
162+
for val in tmp_prefill:
163+
while val in ccl_decode or val in ccl_prefill:
164+
# If a value is shared between ccl_prefill and ccl_decode (or already used in ccl_prefill), lower the ccl_prefill value to the next lower multiple of CCL_UNIQNE_STEP. This avoids repeated values while keeping the value hardware- and compiler-efficient.
165+
val = (val - 1) // constants.CCL_UNIQNE_STEP * constants.CCL_UNIQNE_STEP
166+
if val < 0:
167+
break # Prevent negative values
168+
if val >= 0:
169+
ccl_prefill.append(val)
170+
ccl_prefill.sort()
167171

168172
return ccl_prefill, ccl_decode
169173

170174

171-
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
175+
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len, enable_chunking=False):
172176
"""
173177
This function evaluates the values of CCL lists based on three inputs:
174178
- ccl_prefill: optional [list]
@@ -193,7 +197,7 @@ def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_le
193197

194198
# One of ccl lists is [] or None -> replace it with [ctx_len] -> CCL lists have to have a value when CCL is enabled
195199
# Condition #3, #4, #5, and #6
196-
elif not ccl_prefill or not ccl_decode:
200+
elif not ccl_prefill or not ccl_decode and not enable_chunking:
197201
# Initial setting and will be checked with edge cases later
198202
ccl_prefill = ccl_prefill if ccl_prefill else [ctx_len]
199203
ccl_decode = ccl_decode if ccl_decode else [ctx_len]

QEfficient/utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def get_models_dir():
187187
CCL_MAX_ELEMENTS_LISTS = 5
188188
CCL_START_CTX_LEN = 4096
189189
CCL_MIN_CTX_LEN = 1024
190+
CCL_UNIQNE_STEP = 32
190191

191192
# used for gpt-oss prefill-only model Q-blocking
192193
GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256

0 commit comments

Comments
 (0)