|
2 | 2 | from enum import IntEnum
|
3 | 3 | from functools import cached_property
|
4 | 4 | from typing import Callable, List, Optional, Union
|
| 5 | + |
5 | 6 | import torch
|
6 | 7 |
|
7 | 8 | _SAMPLING_EPS = 1e-5
|
@@ -70,6 +71,8 @@ class SamplingParams:
|
70 | 71 | stop_token_ids: List of tokens that stop the generation when they are
|
71 | 72 | generated. The returned output will contain the stop tokens unless
|
72 | 73 | the stop tokens are special tokens.
|
| 74 | + include_stop_str_in_output: Whether to include the stop strings in output |
| 75 | + text. Defaults to False. |
73 | 76 | ignore_eos: Whether to ignore the EOS token and continue generating
|
74 | 77 | tokens after the EOS token is generated.
|
75 | 78 | max_tokens: Maximum number of tokens to generate per output sequence.
|
@@ -103,6 +106,7 @@ def __init__(
|
103 | 106 | early_stopping: Union[bool, str] = False,
|
104 | 107 | stop: Optional[Union[str, List[str]]] = None,
|
105 | 108 | stop_token_ids: Optional[List[int]] = None,
|
| 109 | + include_stop_str_in_output: bool = False, |
106 | 110 | ignore_eos: bool = False,
|
107 | 111 | max_tokens: int = 16,
|
108 | 112 | logprobs: Optional[int] = None,
|
@@ -140,6 +144,7 @@ def __init__(
|
140 | 144 | self.skip_special_tokens = skip_special_tokens
|
141 | 145 | self.spaces_between_special_tokens = spaces_between_special_tokens
|
142 | 146 | self.logits_processors = logits_processors
|
| 147 | + self.include_stop_str_in_output = include_stop_str_in_output |
143 | 148 | self._verify_args()
|
144 | 149 | if self.use_beam_search:
|
145 | 150 | self._verify_beam_search()
|
@@ -227,24 +232,26 @@ def sampling_type(self) -> SamplingType:
|
227 | 232 | return SamplingType.RANDOM
|
228 | 233 |
|
229 | 234 | def __repr__(self) -> str:
|
230 |
| - return (f"SamplingParams(n={self.n}, " |
231 |
| - f"best_of={self.best_of}, " |
232 |
| - f"presence_penalty={self.presence_penalty}, " |
233 |
| - f"frequency_penalty={self.frequency_penalty}, " |
234 |
| - f"repetition_penalty={self.repetition_penalty}, " |
235 |
| - f"temperature={self.temperature}, " |
236 |
| - f"top_p={self.top_p}, " |
237 |
| - f"top_k={self.top_k}, " |
238 |
| - f"min_p={self.min_p}, " |
239 |
| - f"use_beam_search={self.use_beam_search}, " |
240 |
| - f"length_penalty={self.length_penalty}, " |
241 |
| - f"early_stopping={self.early_stopping}, " |
242 |
| - f"stop={self.stop}, " |
243 |
| - f"stop_token_ids={self.stop_token_ids}, " |
244 |
| - f"ignore_eos={self.ignore_eos}, " |
245 |
| - f"max_tokens={self.max_tokens}, " |
246 |
| - f"logprobs={self.logprobs}, " |
247 |
| - f"prompt_logprobs={self.prompt_logprobs}, " |
248 |
| - f"skip_special_tokens={self.skip_special_tokens}, " |
249 |
| - "spaces_between_special_tokens=" |
250 |
| - f"{self.spaces_between_special_tokens})") |
| 235 | + return ( |
| 236 | + f"SamplingParams(n={self.n}, " |
| 237 | + f"best_of={self.best_of}, " |
| 238 | + f"presence_penalty={self.presence_penalty}, " |
| 239 | + f"frequency_penalty={self.frequency_penalty}, " |
| 240 | + f"repetition_penalty={self.repetition_penalty}, " |
| 241 | + f"temperature={self.temperature}, " |
| 242 | + f"top_p={self.top_p}, " |
| 243 | + f"top_k={self.top_k}, " |
| 244 | + f"min_p={self.min_p}, " |
| 245 | + f"use_beam_search={self.use_beam_search}, " |
| 246 | + f"length_penalty={self.length_penalty}, " |
| 247 | + f"early_stopping={self.early_stopping}, " |
| 248 | + f"stop={self.stop}, " |
| 249 | + f"stop_token_ids={self.stop_token_ids}, " |
| 250 | + f"include_stop_str_in_output={self.include_stop_str_in_output}, " |
| 251 | + f"ignore_eos={self.ignore_eos}, " |
| 252 | + f"max_tokens={self.max_tokens}, " |
| 253 | + f"logprobs={self.logprobs}, " |
| 254 | + f"prompt_logprobs={self.prompt_logprobs}, " |
| 255 | + f"skip_special_tokens={self.skip_special_tokens}, " |
| 256 | + "spaces_between_special_tokens=" |
| 257 | + f"{self.spaces_between_special_tokens})") |
0 commit comments