|
40 | 40 | from QEfficient.generation.vlm_generation import VisionLanguageGeneration |
41 | 41 | from QEfficient.transformers.modeling_utils import ( |
42 | 42 | DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, |
43 | | - SPECIALIZED_PREFILL_ONLY_MODEL_ARCH, |
| 43 | + SPECIALIZED_DISAGG_SERVING_MODEL_ARCH, |
44 | 44 | ) |
45 | 45 | from QEfficient.transformers.models.pytorch_transforms import ( |
46 | 46 | BlockedKVAttentionTransform, |
@@ -2522,15 +2522,18 @@ def get_seq_len_and_handle_specialized_prefill_model( |
2522 | 2522 |
|
2523 | 2523 | num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None) |
2524 | 2524 | if num_q_blocks is None: |
2525 | | - block_size = 256 |
2526 | | - if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128: |
| 2525 | + if ( |
| 2526 | + prefill_seq_len is None |
| 2527 | + or prefill_seq_len % constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE != 0 |
| 2528 | + or prefill_seq_len < constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE |
| 2529 | + ): |
2527 | 2530 | raise ValueError( |
2528 | | - f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. " |
| 2531 | + f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE}. " |
2529 | 2532 | f"Or set `NUM_Q_BLOCKS` ENV variable" |
2530 | 2533 | f"Received: prefill_seq_len={prefill_seq_len}" |
2531 | 2534 | ) |
2532 | 2535 |
|
2533 | | - num_q_blocks = prefill_seq_len // block_size |
| 2536 | + num_q_blocks = prefill_seq_len // constants.GPT_OSS_PREFILL_Q_BLOCK_SIZE |
2534 | 2537 | logger.warning( |
2535 | 2538 | f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override" |
2536 | 2539 | ) |
@@ -2588,31 +2591,28 @@ def export( |
2588 | 2591 | self.model.config, fbs if self.continuous_batching else bs, seq_len |
2589 | 2592 | ) |
2590 | 2593 | enable_chunking = kwargs.get("enable_chunking", False) |
2591 | | - if prefill_only: |
2592 | | - if not enable_chunking and self.continuous_batching: |
2593 | | - raise NotImplementedError( |
2594 | | - "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" |
2595 | | - ) |
2596 | | - self.prefill(enable=True, enable_chunking=enable_chunking) |
2597 | | - self.hash_params.pop("retain_full_kv", None) |
2598 | | - seq_len = ( |
2599 | | - self.get_seq_len_and_handle_specialized_prefill_model( |
| 2594 | + |
| 2595 | + # TODO: move this to a DA Serving utility class |
| 2596 | + if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: |
| 2597 | + if prefill_only: |
| 2598 | + if self.continuous_batching and not enable_chunking: |
| 2599 | + raise NotImplementedError("Can't enable prefix-caching without chunking") |
| 2600 | + self.prefill(enable=True, enable_chunking=enable_chunking) |
| 2601 | + self.hash_params.pop("retain_full_kv", None) |
| 2602 | + seq_len = self.get_seq_len_and_handle_specialized_prefill_model( |
2600 | 2603 | prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking |
2601 | 2604 | ) |
2602 | | - if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH |
2603 | | - else seq_len |
2604 | | - ) |
2605 | | - kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len |
2606 | | - else: |
2607 | | - self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) |
2608 | | - self.hash_params.pop("prefill_only", None) |
2609 | | - self.hash_params.pop("NUM_Q_BLOCKS", None) |
2610 | | - self.hash_params.pop("NUM_FFN_BLOCKS", None) |
2611 | | - self.hash_params.pop("ENABLE_OPT_SWA", None) |
2612 | | - self.hash_params.pop("chunking", None) |
2613 | | - if kwargs.get("retain_full_kv", False): |
2614 | | - kv_cache_shape[2] = seq_len + self.model.config.sliding_window |
2615 | | - self.hash_params["retain_full_kv"] = True |
| 2605 | + kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len |
| 2606 | + else: |
| 2607 | + self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) |
| 2608 | + self.hash_params.pop("prefill_only", None) |
| 2609 | + self.hash_params.pop("NUM_Q_BLOCKS", None) |
| 2610 | + self.hash_params.pop("NUM_FFN_BLOCKS", None) |
| 2611 | + self.hash_params.pop("ENABLE_OPT_SWA", None) |
| 2612 | + self.hash_params.pop("chunking", None) |
| 2613 | + if kwargs.get("retain_full_kv", False): |
| 2614 | + kv_cache_shape[2] = seq_len + self.model.config.sliding_window |
| 2615 | + self.hash_params["retain_full_kv"] = True |
2616 | 2616 |
|
2617 | 2617 | example_inputs = { |
2618 | 2618 | "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), |
@@ -2942,7 +2942,6 @@ def compile( |
2942 | 2942 | if prefill_only is None or not prefill_only: |
2943 | 2943 | if self.continuous_batching and full_batch_size is None: |
2944 | 2944 | raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") |
2945 | | - |
2946 | 2945 | else: |
2947 | 2946 | if self.continuous_batching and kv_cache_batch_size is None and full_batch_size is None: |
2948 | 2947 | raise ValueError( |
|
0 commit comments