4040from QEfficient .generation .vlm_generation import VisionLanguageGeneration
4141from QEfficient .transformers .modeling_utils import (
4242 DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH ,
43- SPECIALIZED_PREFILL_ONLY_MODEL_ARCH ,
43+ SPECIALIZED_DISAGG_SERVING_MODEL_ARCH ,
4444)
4545from QEfficient .transformers .models .pytorch_transforms import (
4646 BlockedKVAttentionTransform ,
@@ -2522,15 +2522,18 @@ def get_seq_len_and_handle_specialized_prefill_model(
25222522
25232523 num_q_blocks = os .environ .get ("NUM_Q_BLOCKS" , None )
25242524 if num_q_blocks is None :
2525- block_size = 256
2526- if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128 :
2525+ if (
2526+ prefill_seq_len is None
2527+ or prefill_seq_len % constants .GPT_OSS_PREFILL_Q_BLOCK_SIZE != 0
2528+ or prefill_seq_len < constants .GPT_OSS_PREFILL_Q_BLOCK_SIZE
2529+ ):
25272530 raise ValueError (
2528- f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={ block_size } . "
2531+ f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={ constants . GPT_OSS_PREFILL_Q_BLOCK_SIZE } . "
25292532 f"Or set `NUM_Q_BLOCKS` ENV variable"
25302533 f"Received: prefill_seq_len={ prefill_seq_len } "
25312534 )
25322535
2533- num_q_blocks = prefill_seq_len // block_size
2536+ num_q_blocks = prefill_seq_len // constants . GPT_OSS_PREFILL_Q_BLOCK_SIZE
25342537 logger .warning (
25352538 f"Setting NUM_Q_BLOCKS={ num_q_blocks } used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override"
25362539 )
@@ -2588,31 +2591,28 @@ def export(
25882591 self .model .config , fbs if self .continuous_batching else bs , seq_len
25892592 )
25902593 enable_chunking = kwargs .get ("enable_chunking" , False )
2591- if prefill_only :
2592- if not enable_chunking and self . continuous_batching :
2593- raise NotImplementedError (
2594- "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!"
2595- )
2596- self . prefill ( enable = True , enable_chunking = enable_chunking )
2597- self .hash_params . pop ( "retain_full_kv" , None )
2598- seq_len = (
2599- self .get_seq_len_and_handle_specialized_prefill_model (
2594+
2595+ # TODO: move this to a DA Serving utility class
2596+ if self . model . config . model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH :
2597+ if prefill_only :
2598+ if self . continuous_batching and not enable_chunking :
2599+ raise NotImplementedError ( "Can't enable prefix-caching without chunking" )
2600+ self .prefill ( enable = True , enable_chunking = enable_chunking )
2601+ self . hash_params . pop ( "retain_full_kv" , None )
2602+ seq_len = self .get_seq_len_and_handle_specialized_prefill_model (
26002603 prefill_seq_len = prefill_seq_len , enable_chunking = enable_chunking
26012604 )
2602- if self .model .config .model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
2603- else seq_len
2604- )
2605- kv_cache_shape [2 ] = seq_len + self .model .config .sliding_window if enable_chunking else seq_len
2606- else :
2607- self .prefill (False , retain_full_kv = kwargs .get ("retain_full_kv" , False ))
2608- self .hash_params .pop ("prefill_only" , None )
2609- self .hash_params .pop ("NUM_Q_BLOCKS" , None )
2610- self .hash_params .pop ("NUM_FFN_BLOCKS" , None )
2611- self .hash_params .pop ("ENABLE_OPT_SWA" , None )
2612- self .hash_params .pop ("chunking" , None )
2613- if kwargs .get ("retain_full_kv" , False ):
2614- kv_cache_shape [2 ] = seq_len + self .model .config .sliding_window
2615- self .hash_params ["retain_full_kv" ] = True
2605+ kv_cache_shape [2 ] = seq_len + self .model .config .sliding_window if enable_chunking else seq_len
2606+ else :
2607+ self .prefill (False , retain_full_kv = kwargs .get ("retain_full_kv" , False ))
2608+ self .hash_params .pop ("prefill_only" , None )
2609+ self .hash_params .pop ("NUM_Q_BLOCKS" , None )
2610+ self .hash_params .pop ("NUM_FFN_BLOCKS" , None )
2611+ self .hash_params .pop ("ENABLE_OPT_SWA" , None )
2612+ self .hash_params .pop ("chunking" , None )
2613+ if kwargs .get ("retain_full_kv" , False ):
2614+ kv_cache_shape [2 ] = seq_len + self .model .config .sliding_window
2615+ self .hash_params ["retain_full_kv" ] = True
26162616
26172617 example_inputs = {
26182618 "input_ids" : torch .zeros ((bs , seq_len ), dtype = torch .int64 ),
@@ -2741,10 +2741,12 @@ def build_prefill_specialization(
27412741 Dict[str, Union[int, str]]
27422742 A dictionary defining the prefill specialization.
27432743 """
2744- if prefill_seq_len == 1 and self .continuous_batching :
2744+ if not self .continuous_batching :
2745+ exec_batch_size = batch_size
2746+ elif prefill_seq_len == 1 :
27452747 exec_batch_size = full_batch_size
27462748 else :
2747- exec_batch_size = 1 if self . continuous_batching else batch_size
2749+ exec_batch_size = 1
27482750
27492751 if hasattr (self .model , "get_specializations" ):
27502752 spec = self .model .get_specializations (
@@ -2755,7 +2757,7 @@ def build_prefill_specialization(
27552757 )[0 ]
27562758 else :
27572759 spec = {
2758- "batch_size" : 1 if self . continuous_batching else batch_size ,
2760+ "batch_size" : exec_batch_size ,
27592761 "seq_len" : prefill_seq_len ,
27602762 "ctx_len" : ctx_len ,
27612763 }
@@ -2766,8 +2768,9 @@ def build_prefill_specialization(
27662768 spec ["full_batch_size" ] = kv_cache_batch_size
27672769 else :
27682770 spec ["batch_size" ] = kv_cache_batch_size
2771+ # TODO: remove this; not required
27692772 if full_batch_size :
2770- spec ["full_batch_exec_size" ] = full_batch_size
2773+ spec ["full_batch_exec_size" ] = exec_batch_size
27712774 return {k : v for k , v in spec .items () if v is not None }
27722775
27732776 def build_decode_specialization (
@@ -2805,9 +2808,6 @@ def build_decode_specialization(
28052808 A dictionary defining the decode specialization, or None if it would be a duplicate
28062809 of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching).
28072810 """
2808- if prefill_seq_len == 1 and not self .continuous_batching :
2809- return None # Avoid duplication with prefill
2810-
28112811 if hasattr (self .model , "get_specializations" ):
28122812 spec = self .model .get_specializations (
28132813 batch_size = full_batch_size if self .continuous_batching else batch_size ,
@@ -2942,7 +2942,6 @@ def compile(
29422942 if prefill_only is None or not prefill_only :
29432943 if self .continuous_batching and full_batch_size is None :
29442944 raise TypeError ("`full_batch_size` is required when `continuous_batching=True`." )
2945-
29462945 else :
29472946 if self .continuous_batching and kv_cache_batch_size is None and full_batch_size is None :
29482947 raise ValueError (
@@ -3026,7 +3025,7 @@ def compile(
30263025 )
30273026 )
30283027
3029- if prefill_only is None or not prefill_only :
3028+ if ( prefill_only is None or not prefill_only ) and prefill_seq_len != 1 :
30303029 if self .comp_ctx_lengths_decode is not None :
30313030 # Adding elements from self.comp_ctx_lengths_decode to decode_specialization
30323031 for i in range (0 , len (self .comp_ctx_lengths_decode )):
@@ -3055,6 +3054,8 @@ def compile(
30553054 if decode_spec :
30563055 specializations .append (decode_spec )
30573056
3057+ if kw_spec := compiler_options .pop ("specializations" , None ):
3058+ specializations = kw_spec
30583059 # --- Compilation ---
30593060 kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16"
30603061 custom_io = {}
0 commit comments