33from collections .abc import Callable
44
55import torch
6- import triton
7- import triton .language as tl
86
97from vllm .config .model import LogprobsMode
8+ from vllm .triton_utils import tl , triton
109from vllm .v1 .outputs import LogprobsTensors , SamplerOutput
1110from vllm .v1 .sample .ops .topk_topp_sampler import apply_top_k_top_p
1211from vllm .v1 .worker .gpu .states import SamplingMetadata
@@ -78,7 +77,10 @@ def sample(
7877
7978@triton .jit
8079def _gumbel_sample_kernel (
81- sampled_ptr ,
80+ local_argmax_ptr ,
81+ local_argmax_stride ,
82+ local_max_ptr ,
83+ local_max_stride ,
8284 logits_ptr ,
8385 logits_stride ,
8486 seeds_ptr ,
@@ -88,57 +90,35 @@ def _gumbel_sample_kernel(
8890 BLOCK_SIZE : tl .constexpr ,
8991):
9092 req_idx = tl .program_id (0 )
91- is_greedy = tl .load (is_greedy_ptr + req_idx )
92-
93- if is_greedy :
94- # Greedy sampling. Don't apply gumbel noise.
95- max_val = float ("-inf" )
96- max_idx = 0
97- for i in range (0 , vocab_size , BLOCK_SIZE ):
98- block = i + tl .arange (0 , BLOCK_SIZE )
99- mask = block < vocab_size
100- logits = tl .load (
101- logits_ptr + req_idx * logits_stride + block ,
102- mask = mask ,
103- other = float ("-inf" ),
104- )
105-
106- idx = tl .argmax (logits , axis = 0 )
107- value = tl .max (logits , axis = 0 )
108- is_greater = value > max_val
109- max_val = tl .where (is_greater , value , max_val )
110- max_idx = tl .where (is_greater , i + idx , max_idx )
111- tl .store (sampled_ptr + req_idx , max_idx )
112- return
113-
114- # Random sampling.
115- # Calculate gumbel seed.
116- seed = tl .load (seeds_ptr + req_idx )
117- pos = tl .load (pos_ptr + req_idx )
118- gumbel_seed = tl .randint (seed , pos )
93+ block_idx = tl .program_id (1 )
94+ block = block_idx * BLOCK_SIZE + tl .arange (0 , BLOCK_SIZE )
95+ mask = block < vocab_size
96+ logits = tl .load (
97+ logits_ptr + req_idx * logits_stride + block ,
98+ mask = mask ,
99+ other = float ("-inf" ),
100+ )
119101
120- max_val = float ("-inf" )
121- max_idx = 0
122- for i in range (0 , vocab_size , BLOCK_SIZE ):
123- block = i + tl .arange (0 , BLOCK_SIZE )
124- mask = block < vocab_size
102+ is_greedy = tl .load (is_greedy_ptr + req_idx )
103+ if not is_greedy :
104+ # Calculate the seed for gumbel noise.
105+ seed = tl .load (seeds_ptr + req_idx )
106+ pos = tl .load (pos_ptr + req_idx )
107+ gumbel_seed = tl .randint (seed , pos )
125108
126109 # Generate gumbel noise.
127110 r = tl .rand (gumbel_seed , block ).to (tl .float64 )
128111 gumbel_noise = - tl .log (- tl .log (r + 1e-20 ) + 1e-20 )
129112 gumbel_noise = gumbel_noise .to (tl .float32 )
130113
131114 # Apply gumbel noise.
132- logits = tl .load (logits_ptr + req_idx * logits_stride + block , mask = mask )
133115 logits = tl .where (mask , logits + gumbel_noise , float ("-inf" ))
134116
135- # Argmax to get the sampled token.
136- idx = tl .argmax (logits , axis = 0 )
137- value = tl .max (logits , axis = 0 )
138- is_greater = value > max_val
139- max_val = tl .where (is_greater , value , max_val )
140- max_idx = tl .where (is_greater , i + idx , max_idx )
141- tl .store (sampled_ptr + req_idx , max_idx )
117+ idx = tl .argmax (logits , axis = 0 )
118+ token_id = block_idx * BLOCK_SIZE + idx
119+ value = tl .max (logits , axis = 0 )
120+ tl .store (local_argmax_ptr + req_idx * local_argmax_stride + block_idx , token_id )
121+ tl .store (local_max_ptr + req_idx * local_max_stride + block_idx , value )
142122
143123
144124def gumbel_sample (
@@ -148,23 +128,36 @@ def gumbel_sample(
148128 pos : torch .Tensor , # [num_reqs]
149129) -> torch .Tensor :
150130 num_reqs , vocab_size = logits .shape
151- # NOTE(woosuk): Use int64 for later indexing.
152- sampled = torch .empty (
131+ BLOCK_SIZE = 1024
132+ num_blocks = triton .cdiv (vocab_size , BLOCK_SIZE )
133+ local_argmax = torch .empty (
153134 num_reqs ,
135+ num_blocks ,
154136 dtype = torch .int64 ,
155137 device = logits .device ,
156138 )
157- _gumbel_sample_kernel [(num_reqs ,)](
158- sampled ,
139+ local_max = torch .empty (
140+ num_reqs ,
141+ num_blocks ,
142+ dtype = torch .float32 ,
143+ device = logits .device ,
144+ )
145+ _gumbel_sample_kernel [(num_reqs , num_blocks )](
146+ local_argmax ,
147+ local_argmax .stride (0 ),
148+ local_max ,
149+ local_max .stride (0 ),
159150 logits ,
160151 logits .stride (0 ),
161152 seed ,
162153 pos ,
163154 is_greedy ,
164155 vocab_size ,
165- num_warps = 8 ,
166- BLOCK_SIZE = 16384 , # type: ignore
156+ BLOCK_SIZE = BLOCK_SIZE ,
167157 )
158+ # NOTE(woosuk): Use int64 for later indexing.
159+ max_block_idx = local_max .argmax (dim = - 1 , keepdim = True )
160+ sampled = local_argmax .gather (dim = - 1 , index = max_block_idx ).view (- 1 )
168161 return sampled
169162
170163
0 commit comments