@@ -2,7 +2,7 @@

 from array import array
 from dataclasses import dataclass
-from typing import Dict, List, Optional, Tuple
+from typing import Optional

 import torch

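The import change above follows PEP 585: since Python 3.9 the built-in container types can be subscripted directly in annotations, so `List`, `Dict`, and `Tuple` no longer need to come from `typing` (only `Optional` stays imported in this file). A minimal before/after sketch of the style this diff adopts — the function names are illustrative, not from this file:

```python
from typing import Dict, List, Optional, Tuple  # old style; container aliases deprecated since 3.9


def first_entry_old(ids: List[int], data: Dict[int, str]) -> Tuple[int, Optional[str]]:
    return (ids[0], data.get(ids[0])) if ids else (0, None)


# PEP 585 style used by this change: built-in generics, only Optional still imported.
def first_entry_new(ids: list[int], data: dict[int, str]) -> tuple[int, Optional[str]]:
    return (ids[0], data.get(ids[0])) if ids else (0, None)


assert first_entry_new([7], {7: "x"}) == (7, "x")
```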
@@ -25,10 +25,10 @@ class SequenceGroupToSample:
     #                                   |-- query_len ---|

     # Sequence ids for the sequence group in a previous step.
-    seq_ids: List[int]
+    seq_ids: list[int]
     sampling_params: SamplingParams
     # seq_id -> sequence data.
-    seq_data: Dict[int, SequenceData]
+    seq_data: dict[int, SequenceData]
     # The length of the sequence (all tokens seen in the past + new token to
     # compute attention) of the sequence group. None if it is in a decode
     # stage.
@@ -44,9 +44,9 @@ class SequenceGroupToSample:
     is_prompt: bool
     # Query token indices from logits. to compute prompt logprob. Empty if
     # prompt logprob is not required.
-    prompt_logprob_indices: List[int]
+    prompt_logprob_indices: list[int]
     # Sample token indices from logits. Empty if sampling is not required.
-    sample_indices: List[int]
+    sample_indices: list[int]

     @property
     def do_sample(self):
@@ -78,7 +78,7 @@ class SamplingMetadataCache:
     """Used to cache SamplingMetadata objects between scheduler iterations"""

     def __init__(self):
-        self._seq_group_to_sample_cache: Dict[int, PyObjectCache] = {}
+        self._seq_group_to_sample_cache: dict[int, PyObjectCache] = {}

     def get_cached_seq_group_to_sample(self, num_seqs):
         if num_seqs not in self._seq_group_to_sample_cache:
@@ -130,9 +130,9 @@ def sample(logits):

     def __init__(
         self,
-        seq_groups: List[SequenceGroupToSample],
+        seq_groups: list[SequenceGroupToSample],
         selected_token_indices: torch.Tensor,
-        categorized_sample_indices: Dict[SamplingType, torch.Tensor],
+        categorized_sample_indices: dict[SamplingType, torch.Tensor],
         num_prompts: int,
         skip_sampler_cpu_output: bool = False,
         reuse_sampling_tensors: bool = False,
@@ -146,12 +146,12 @@ def __init__(

     @staticmethod
     def prepare(
-        seq_group_metadata_list: List[SequenceGroupMetadata],
-        seq_lens: List[int],
-        query_lens: List[int],
+        seq_group_metadata_list: list[SequenceGroupMetadata],
+        seq_lens: list[int],
+        query_lens: list[int],
         device: str,
         pin_memory: bool,
-        generators: Optional[Dict[str, torch.Generator]] = None,
+        generators: Optional[dict[str, torch.Generator]] = None,
         cache: Optional[SamplingMetadataCache] = None,
     ) -> "SamplingMetadata":
         (
@@ -195,16 +195,16 @@ def __repr__(self) -> str:


 def _prepare_seq_groups(
-    seq_group_metadata_list: List[SequenceGroupMetadata],
-    seq_lens: List[int],
-    query_lens: List[int],
+    seq_group_metadata_list: list[SequenceGroupMetadata],
+    seq_lens: list[int],
+    query_lens: list[int],
     device: str,
-    generators: Optional[Dict[str, torch.Generator]] = None,
+    generators: Optional[dict[str, torch.Generator]] = None,
     cache: Optional[SamplingMetadataCache] = None,
-) -> Tuple[
-    List[SequenceGroupToSample],
-    List[int],
-    Dict[SamplingType, List[int]],
+) -> tuple[
+    list[SequenceGroupToSample],
+    list[int],
+    dict[SamplingType, list[int]],
     int,
 ]:
     """Prepare sequence groups and indices for sampling.
@@ -227,17 +227,17 @@ def _prepare_seq_groups(
         num_prompts: Total number of prompts from `seq_group_metadata_list`.
     """
     # Batched sequence groups for the current model forward stsep.
-    seq_groups: List[SequenceGroupToSample] = []
+    seq_groups: list[SequenceGroupToSample] = []
     # A list of token indices to sample/compute logprob. It is used to
     # prune the outcome logits from the model for the performance.
-    selected_token_indices: List[int] = []
+    selected_token_indices: list[int] = []
     # Used for selected_token_indices.
     model_output_idx = 0

     # Sampling type -> (
     # indices to sample/prompt logprob within pruned output logits,
     # indices to sample within pruned logits)
-    categorized_sample_indices: Dict[SamplingType, List[int]] = {
+    categorized_sample_indices: dict[SamplingType, list[int]] = {
         t: []
         for t in SamplingType
     }
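For readers unfamiliar with the `{t: [] for t in SamplingType}` pattern in the hunk above: it pre-creates one empty index bucket per sampling type, so later code can append logit-row indices without key-existence checks. A small self-contained sketch of the pattern — the enum below is a simplified stand-in for `vllm.sampling_params.SamplingType`, not its actual definition:

```python
from enum import IntEnum


class SamplingType(IntEnum):
    # Simplified stand-in; the real enum lives in vllm.sampling_params.
    GREEDY = 0
    RANDOM = 1
    RANDOM_SEED = 2


# One empty bucket of indices per sampling type, as in the diff above.
categorized_sample_indices: dict[SamplingType, list[int]] = {
    t: [] for t in SamplingType
}

categorized_sample_indices[SamplingType.GREEDY].append(0)
categorized_sample_indices[SamplingType.RANDOM].extend([1, 2])
print(categorized_sample_indices)
# -> {GREEDY: [0], RANDOM: [1, 2], RANDOM_SEED: []}  (reprs abbreviated)
```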
@@ -265,9 +265,9 @@ def _prepare_seq_groups(
         # If the current seq group is in decode stage, it is None.
         seq_len: Optional[int] = None
         query_len: Optional[int] = None
-        prompt_logprob_indices: List[int] = (sample_obj.prompt_logprob_indices
+        prompt_logprob_indices: list[int] = (sample_obj.prompt_logprob_indices
                                              if cache is not None else [])
-        sample_indices: List[int] = (sample_obj.sample_indices
+        sample_indices: list[int] = (sample_obj.sample_indices
                                      if cache is not None else [])
         do_sample = seq_group_metadata.do_sample

@@ -389,16 +389,16 @@ def from_sampling_metadata(
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
-    ) -> Tuple["SamplingTensors", bool, bool, bool]:
-        prompt_tokens: List[array] = []
-        output_tokens: List[array] = []
-        top_ks: List[int] = []
-        temperatures: List[float] = []
-        top_ps: List[float] = []
-        min_ps: List[float] = []
-        presence_penalties: List[float] = []
-        frequency_penalties: List[float] = []
-        repetition_penalties: List[float] = []
+    ) -> tuple["SamplingTensors", bool, bool, bool]:
+        prompt_tokens: list[array] = []
+        output_tokens: list[array] = []
+        top_ks: list[int] = []
+        temperatures: list[float] = []
+        top_ps: list[float] = []
+        min_ps: list[float] = []
+        presence_penalties: list[float] = []
+        frequency_penalties: list[float] = []
+        repetition_penalties: list[float] = []
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
@@ -496,15 +496,15 @@ def from_sampling_metadata(
     @classmethod
     def from_lists(
         cls,
-        temperatures: List[float],
-        top_ps: List[float],
-        top_ks: List[int],
-        min_ps: List[float],
-        presence_penalties: List[float],
-        frequency_penalties: List[float],
-        repetition_penalties: List[float],
-        prompt_tokens: List[array],
-        output_tokens: List[array],
+        temperatures: list[float],
+        top_ps: list[float],
+        top_ks: list[int],
+        min_ps: list[float],
+        presence_penalties: list[float],
+        frequency_penalties: list[float],
+        repetition_penalties: list[float],
+        prompt_tokens: list[array],
+        output_tokens: list[array],
         vocab_size: int,
         device: torch.device,
         dtype: torch.dtype,
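One note on the runtime behaviour of the new annotations, as a quick sanity check rather than anything specific to this file: `list[int]` and `typing.List[int]` are distinct objects, but introspection via `typing.get_origin`/`typing.get_args` treats them identically, so tooling that reads these hints should be unaffected by the swap (requires Python 3.9+):

```python
import typing
from typing import List

old = List[int]
new = list[int]

# Different alias objects...
assert old is not new

# ...but the same origin and type arguments as far as introspection is concerned.
assert typing.get_origin(old) is list and typing.get_origin(new) is list
assert typing.get_args(old) == typing.get_args(new) == (int,)
```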