@@ -625,56 +625,6 @@ def forward(
 ClassifierFn = Callable[[torch.Tensor], torch.Tensor]
 
 
-class VisionPooler(Pooler):
-
-    @classmethod
-    def from_config(cls, model_config: ModelConfig) -> "VisionPooler":
-        return cls(model_config)
-
-    def __init__(self, config: ModelConfig):
-        super().__init__()
-        self.config = config
-
-    def get_pooling_params(self, task: PoolingTask) -> Optional[PoolingParams]:
-        if task == "embed":
-            return PoolingParams(pooling_type="vision",
-                                 logits_processing_needs_token_ids=True)
-        return None
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        pooling_metadata: PoolingMetadata,
-    ) -> PoolerOutput:
-        assert isinstance(pooling_metadata, V1PoolingMetadata)
-
-        pooled_outputs = []
-        for i in range(len(pooling_metadata.prompt_lens)):
-            start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                         hf_config.vision_start_token_id).nonzero()[-1].item()
-            end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                       hf_config.vision_end_token_id).nonzero()[-1].item()
-
-            seq_start = torch.cumsum(
-                torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
-                dim=0)[i]
-            seq_len = pooling_metadata.prompt_lens[i]
-
-            output = torch.empty(self.config.hidden_size,
-                                 device=hidden_states.device,
-                                 dtype=hidden_states.dtype)
-
-            grid = lambda meta: (self.config.hidden_size, )
-            mean_pool_with_position_kernel[grid](hidden_states, output,
-                                                 seq_start, seq_len,
-                                                 self.config.hidden_size,
-                                                 start_pos, end_pos + 1)
-
-            pooled_outputs.append(output)
-
-        return build_output(torch.stack(pooled_outputs))
-
-
 if HAS_TRITON:
 
     @triton.jit
@@ -688,7 +638,6 @@ def mean_pool_with_position_kernel(
         pool_end,
         BLOCK_SIZE: tl.constexpr,
     ):
-        """Triton kernel to perform mean pooling over a specified token range."""
         pid = tl.program_id(0)
 
         if pid >= hidden_size:
@@ -817,10 +766,12 @@ def forward(
 
         pooled_outputs = []
         for i in range(len(pooling_metadata.prompt_lens)):
-            start_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                         hf_config.vision_start_token_id).nonzero()[-1].item()
-            end_pos = (pooling_metadata.prompt_token_ids[i] == self.config.
-                       hf_config.vision_end_token_id).nonzero()[-1].item()
+            start_pos = (pooling_metadata.prompt_token_ids[i] ==
+                         self.config.hf_config.vision_start_token_id
+                         ).nonzero()[-1].item()
+            end_pos = (pooling_metadata.prompt_token_ids[i] ==
+                       self.config.hf_config.vision_end_token_id
+                       ).nonzero()[-1].item()
 
             seq_start = torch.cumsum(
                 torch.tensor([0] + pooling_metadata.prompt_lens.tolist()),
@@ -832,41 +783,18 @@ def forward(
                                  dtype=hidden_states.dtype)
 
             grid = lambda meta: (self.config.hidden_size, )
-            mean_pool_with_position_kernel[grid](hidden_states, output,
-                                                 seq_start, seq_len,
-                                                 self.config.hidden_size,
-                                                 start_pos, end_pos + 1)
+            if HAS_TRITON:
+                mean_pool_with_position_kernel[grid](hidden_states, output,
+                                                     seq_start, seq_len,
+                                                     self.config.hidden_size,
+                                                     start_pos, end_pos + 1)
+            else:
+                # Fallback to PyTorch implementation if Triton is not available
+                vision_tokens_range = hidden_states[seq_start + start_pos:seq_start + end_pos + 1]
+                output = vision_tokens_range.mean(dim=0)
 
             pooled_outputs.append(output)
 
         return build_output(torch.stack(pooled_outputs))
 
 
-if HAS_TRITON:
-
-    @triton.jit
-    def mean_pool_with_position_kernel(
-        hidden_states_ptr,
-        output_ptr,
-        seq_start,
-        seq_len,
-        hidden_size,
-        pool_start,
-        pool_end,
-        BLOCK_SIZE: tl.constexpr,
-    ):
-        """Triton kernel to perform mean pooling over a specified token range."""
-        pid = tl.program_id(0)
-
-        if pid >= hidden_size:
-            return
-
-        accumulator = 0.0
-        for i in range(pool_start, pool_end):
-            hidden_val = tl.load(hidden_states_ptr +
-                                 (seq_start + i) * hidden_size + pid)
-            accumulator += hidden_val
-
-        # Store mean pooled result
-        result = accumulator / (pool_end - pool_start)
-        tl.store(output_ptr + pid, result)
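For reference, the PyTorch fallback added in the last hunk computes the same per-column mean that the Triton kernel produces: slice the packed hidden states between the vision start/end positions of one sequence and average over the token dimension. Below is a minimal, self-contained sketch of that fallback path with toy shapes and no vLLM imports; `seq_start`, `start_pos`, and `end_pos` are stand-in values rather than positions derived from real vision token IDs.

```python
import torch

# Toy setup: two sequences packed back to back along dim 0.
hidden_size = 8
prompt_lens = torch.tensor([5, 7])
hidden_states = torch.randn(int(prompt_lens.sum()), hidden_size)

i = 1  # pool the second sequence
seq_start = torch.cumsum(torch.tensor([0] + prompt_lens.tolist()), dim=0)[i]
start_pos, end_pos = 2, 5  # stand-ins for the vision start/end token offsets

# Same computation as the non-Triton branch in the diff:
vision_tokens_range = hidden_states[seq_start + start_pos:seq_start + end_pos + 1]
output = vision_tokens_range.mean(dim=0)
assert output.shape == (hidden_size,)
```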