4
4
import enum
5
5
import functools
6
6
from abc import abstractmethod
7
- from collections .abc import Hashable
8
7
from dataclasses import dataclass , fields , make_dataclass
9
8
from typing import (TYPE_CHECKING , Any , ClassVar , Generic , Optional , Protocol ,
10
9
TypeVar )
@@ -67,11 +66,12 @@ class CommonAttentionMetadata:
67
66
block_table_tensor : torch .Tensor
68
67
slot_mapping : torch .Tensor
69
68
69
+ causal : bool = True
70
+
71
+ # Needed by FastPrefillAttentionBuilder
70
72
logits_indices_padded : Optional [torch .Tensor ] = None
71
73
num_logits_indices : Optional [int ] = None
72
74
73
- causal : bool = True
74
-
75
75
76
76
@dataclass
77
77
class UbatchSlice :
@@ -557,9 +557,8 @@ def make_kv_sharing_fast_prefill_common_attn_metadata(
557
557
# Skip computing fast prefill path
558
558
return common_attn_metadata
559
559
560
- if (common_attn_metadata .logits_indices_padded is None
561
- or common_attn_metadata .num_logits_indices is None ):
562
- return common_attn_metadata
560
+ assert common_attn_metadata .logits_indices_padded is not None
561
+ assert common_attn_metadata .num_logits_indices is not None
563
562
564
563
logits_indices_padded = common_attn_metadata .logits_indices_padded
565
564
num_logits_indices = common_attn_metadata .num_logits_indices
@@ -750,59 +749,12 @@ def subclass_attention_metadata(
750
749
return Wrapped
751
750
752
751
753
- @functools .lru_cache
754
- def make_kv_sharing_fast_prefill_attention_metadata (
755
- metadata_cls : Hashable , ) -> Any :
756
- """
757
- Return a new subclass of `metadata_cls` for fast prefill
758
- """
759
- attn_metadata_dataclass = subclass_attention_metadata (
760
- name_prefix = "KVSharingFastPrefill" ,
761
- metadata_cls = metadata_cls ,
762
- fields = KV_SHARING_FAST_PREFILL_METADATA_FIELDS ,
763
- )
764
- # Make attention metadata type inherit
765
- # KVSharingFastPrefillAttentionMetadata type
766
- fast_prefill_metadata_type = type (
767
- attn_metadata_dataclass .__name__ ,
768
- (
769
- attn_metadata_dataclass ,
770
- KVSharingFastPrefillAttentionMetadata ,
771
- ),
772
- {},
773
- )
774
- return fast_prefill_metadata_type
775
-
776
-
777
752
@runtime_checkable
778
- class KVSharingFastPrefillAttentionMetadata (Protocol ):
753
+ class KVSharingFastPrefillMetadata (Protocol ):
779
754
logits_indices_padded : torch .Tensor
780
755
num_logits_indices : int
781
756
782
757
783
- def create_kv_sharing_fast_prefill_attn_metadata_subclass (
784
- metadata : Any ,
785
- common_attn_metadata : CommonAttentionMetadata ,
786
- ) -> Any :
787
- # Dynamically create a a dataclass type that inherits
788
- # from attention metadata type but includes additional
789
- # fields logits_indices_padded and num_logits_indices
790
- # which are required for prefill truncation
791
- fast_prefill_metadata_type = (
792
- make_kv_sharing_fast_prefill_attention_metadata (
793
- metadata_cls = type (metadata ), )) # type: ignore
794
- # Avoid deepcopy caused by dict.asdict
795
- attn_metadata_fields = {}
796
- for field in fields (metadata .__class__ ):
797
- attn_metadata_fields [field .name ] = getattr (metadata , field .name )
798
- attn_metadata_i = fast_prefill_metadata_type (
799
- ** attn_metadata_fields ,
800
- logits_indices_padded = common_attn_metadata .logits_indices_padded ,
801
- num_logits_indices = common_attn_metadata .num_logits_indices ,
802
- )
803
- return attn_metadata_i
804
-
805
-
806
758
def create_fast_prefill_custom_backend (
807
759
prefix : str ,
808
760
underlying_attn_backend : AttentionBackend ,
@@ -820,7 +772,27 @@ def build(self,
820
772
make_kv_sharing_fast_prefill_common_attn_metadata (common_attn_metadata )
821
773
metadata = super ().build (common_prefix_len ,
822
774
new_common_attn_metadata , fast_build )
823
- return create_kv_sharing_fast_prefill_attn_metadata_subclass (
775
+
776
+ class KVSharingFastPrefillAttentionMetadata (
777
+ metadata .__class__ , KVSharingFastPrefillMetadata ):
778
+
779
+ def __init__ (self , metadata , common_attention_metadata ):
780
+ # Shallow copy all fields in metadata cls
781
+ for field in fields (metadata .__class__ ):
782
+ setattr (self , field .name ,
783
+ getattr (metadata , field .name ))
784
+
785
+ # Set additional fields that will be used in model code
786
+ assert (common_attn_metadata .logits_indices_padded
787
+ is not None
788
+ and common_attn_metadata .num_logits_indices
789
+ is not None )
790
+ self .logits_indices_padded = \
791
+ common_attn_metadata .logits_indices_padded
792
+ self .num_logits_indices = \
793
+ common_attn_metadata .num_logits_indices
794
+
795
+ return KVSharingFastPrefillAttentionMetadata (
824
796
metadata , common_attn_metadata )
825
797
826
798
attn_backend = subclass_attention_backend (
0 commit comments