 import enum
 import json
 from dataclasses import dataclass, field, fields
-from typing import (TYPE_CHECKING, ClassVar, List, Mapping, Optional, Tuple,
-                    Type, Union)
+from typing import (TYPE_CHECKING, Any, ClassVar, Dict, List, Mapping,
+                    Optional, Tuple, Type, Union)
 
 import torch
 from transformers import PretrainedConfig
@@ -115,35 +115,39 @@ class ModelConfig:
             the model name will be the same as `model`.
         limit_mm_per_prompt: Maximum number of data instances per modality
             per prompt. Only applicable for multimodal models.
+        override_neuron_config: Initialize non-default Neuron config or
+            override default Neuron config that is specific to Neuron
+            devices. This argument is used to configure Neuron options
+            that cannot be gathered from the vLLM arguments.
118 | 122 | """ |
119 | 123 |
|
     def __init__(
-        self,
-        model: str,
-        tokenizer: str,
-        tokenizer_mode: str,
-        trust_remote_code: bool,
-        dtype: Union[str, torch.dtype],
-        seed: int,
-        revision: Optional[str] = None,
-        code_revision: Optional[str] = None,
-        rope_scaling: Optional[dict] = None,
-        rope_theta: Optional[float] = None,
-        tokenizer_revision: Optional[str] = None,
-        max_model_len: Optional[int] = None,
-        spec_target_max_model_len: Optional[int] = None,
-        quantization: Optional[str] = None,
-        quantization_param_path: Optional[str] = None,
-        enforce_eager: Optional[bool] = None,
-        max_context_len_to_capture: Optional[int] = None,
-        max_seq_len_to_capture: Optional[int] = None,
-        max_logprobs: int = 20,
-        disable_sliding_window: bool = False,
-        skip_tokenizer_init: bool = False,
-        served_model_name: Optional[Union[str, List[str]]] = None,
-        limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
-        use_async_output_proc: bool = True,
-    ) -> None:
+            self,
+            model: str,
+            tokenizer: str,
+            tokenizer_mode: str,
+            trust_remote_code: bool,
+            dtype: Union[str, torch.dtype],
+            seed: int,
+            revision: Optional[str] = None,
+            code_revision: Optional[str] = None,
+            rope_scaling: Optional[dict] = None,
+            rope_theta: Optional[float] = None,
+            tokenizer_revision: Optional[str] = None,
+            max_model_len: Optional[int] = None,
+            spec_target_max_model_len: Optional[int] = None,
+            quantization: Optional[str] = None,
+            quantization_param_path: Optional[str] = None,
+            enforce_eager: Optional[bool] = None,
+            max_context_len_to_capture: Optional[int] = None,
+            max_seq_len_to_capture: Optional[int] = None,
+            max_logprobs: int = 20,
+            disable_sliding_window: bool = False,
+            skip_tokenizer_init: bool = False,
+            served_model_name: Optional[Union[str, List[str]]] = None,
+            limit_mm_per_prompt: Optional[Mapping[str, int]] = None,
+            use_async_output_proc: bool = True,
+            override_neuron_config: Optional[Dict[str, Any]] = None) -> None:
         self.model = model
         self.tokenizer = tokenizer
         self.tokenizer_mode = tokenizer_mode
@@ -227,6 +231,9 @@ def __init__(
             limit_mm_per_prompt)
         if not self.skip_tokenizer_init:
             self._verify_tokenizer_mode()
+
+        self.override_neuron_config = override_neuron_config if is_neuron(
+        ) else None
         self._verify_embedding_mode()
         self._verify_quantization()
         self._verify_cuda_graph()
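The new `override_neuron_config` argument is an escape hatch for Neuron-only settings that have no dedicated vLLM engine argument; note that the attribute is silently forced to `None` on non-Neuron hosts. A minimal usage sketch, under the assumption that the dict keys are whatever the Neuron backend (e.g. transformers-neuronx) accepts — the key shown is illustrative, not defined by this diff:

```python
from vllm.config import ModelConfig

# Hypothetical Neuron-only settings; accepted keys depend on the Neuron
# backend and are assumptions here, not part of this diff.
neuron_overrides = {"cast_logits_dtype": "bfloat16"}

config = ModelConfig(
    model="facebook/opt-125m",
    tokenizer="facebook/opt-125m",
    tokenizer_mode="auto",
    trust_remote_code=False,
    dtype="float16",
    seed=0,
    override_neuron_config=neuron_overrides,
)

# On a non-Neuron host this prints None, so downstream Neuron code can
# treat "is not None" as "running on Neuron with overrides supplied".
print(config.override_neuron_config)
```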
@@ -275,6 +282,7 @@ def _verify_quantization(self) -> None:
             "experts_int8"
         ]
         tpu_supported_quantization = ["tpu_int8"]
+        neuron_supported_quantization = ["neuron_quant"]
         if self.quantization is not None:
            self.quantization = self.quantization.lower()
 
@@ -329,6 +337,11 @@ def _verify_quantization(self) -> None:
                     "Using AWQ quantization with ROCm, but VLLM_USE_TRITON_AWQ"
                     " is not set, enabling VLLM_USE_TRITON_AWQ.")
                 envs.VLLM_USE_TRITON_AWQ = True
+        if is_neuron(
+        ) and self.quantization not in neuron_supported_quantization:
+            raise ValueError(
+                f"{self.quantization} quantization is currently not "
+                f"supported in Neuron Backend.")
 
     def _verify_cuda_graph(self) -> None:
         if self.max_seq_len_to_capture is None:
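The guard added to `_verify_quantization` means that on Neuron hosts any quantization method outside `neuron_supported_quantization` fails fast at config time. A standalone re-statement of that check for illustration only (the real check lives inside `ModelConfig` and detects the platform via `is_neuron()`):

```python
NEURON_SUPPORTED_QUANTIZATION = ["neuron_quant"]

def verify_neuron_quantization(quantization: str, on_neuron: bool) -> None:
    """Mirror of the new check: reject unsupported methods on Neuron."""
    if on_neuron and quantization not in NEURON_SUPPORTED_QUANTIZATION:
        raise ValueError(f"{quantization} quantization is currently not "
                         f"supported in Neuron Backend.")

verify_neuron_quantization("neuron_quant", on_neuron=True)  # passes
try:
    verify_neuron_quantization("awq", on_neuron=True)
except ValueError as err:
    print(err)  # awq quantization is currently not supported in Neuron Backend.
```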