@@ -149,7 +149,7 @@ class SamplingParams(
149
149
top_p: Float that controls the cumulative probability of the top tokens
150
150
to consider. Must be in (0, 1]. Set to 1 to consider all tokens.
151
151
top_k: Integer that controls the number of top tokens to consider. Set
152
- to -1 to consider all tokens.
152
+ to 0 (or -1) to consider all tokens.
153
153
min_p: Float that represents the minimum probability for a token to be
154
154
considered, relative to the probability of the most likely token.
155
155
Must be in [0, 1]. Set to 0 to disable this.
@@ -209,7 +209,7 @@ class SamplingParams(
209
209
repetition_penalty : float = 1.0
210
210
temperature : float = 1.0
211
211
top_p : float = 1.0
212
- top_k : int = - 1
212
+ top_k : int = 0
213
213
min_p : float = 0.0
214
214
seed : Optional [int ] = None
215
215
stop : Optional [Union [str , list [str ]]] = None
@@ -256,7 +256,7 @@ def from_optional(
256
256
repetition_penalty : Optional [float ] = 1.0 ,
257
257
temperature : Optional [float ] = 1.0 ,
258
258
top_p : Optional [float ] = 1.0 ,
259
- top_k : int = - 1 ,
259
+ top_k : int = 0 ,
260
260
min_p : float = 0.0 ,
261
261
seed : Optional [int ] = None ,
262
262
stop : Optional [Union [str , list [str ]]] = None ,
@@ -376,7 +376,7 @@ def __post_init__(self) -> None:
376
376
if self .temperature < _SAMPLING_EPS :
377
377
# Zero temperature means greedy sampling.
378
378
self .top_p = 1.0
379
- self .top_k = - 1
379
+ self .top_k = 0
380
380
self .min_p = 0.0
381
381
self ._verify_greedy_sampling ()
382
382
@@ -404,8 +404,9 @@ def _verify_args(self) -> None:
404
404
f"temperature must be non-negative, got { self .temperature } ." )
405
405
if not 0.0 < self .top_p <= 1.0 :
406
406
raise ValueError (f"top_p must be in (0, 1], got { self .top_p } ." )
407
- if self .top_k < - 1 or self .top_k == 0 :
408
- raise ValueError (f"top_k must be -1 (disable), or at least 1, "
407
+ # 0 is the preferred "disabled" value; -1 (the previous sentinel) is still silently accepted.
408
+ if self .top_k < - 1 :
409
+ raise ValueError (f"top_k must be 0 (disable), or at least 1, "
409
410
f"got { self .top_k } ." )
410
411
if not isinstance (self .top_k , int ):
411
412
raise TypeError (
0 commit comments