@@ -16,6 +16,7 @@
 from pathlib import Path
 from typing import Any, Callable, Literal, Optional, Union
 
+import anthropic
 import cloudpickle
 import openai
 import pytest
@@ -194,6 +195,130 @@ def get_async_client(self, **kwargs):
                                   **kwargs)
 
 
+class RemoteAnthropicServer:
+    DUMMY_API_KEY = "token-abc123"  # vLLM's Anthropic server does not need an API key
+
+    def __init__(self,
+                 model: str,
+                 vllm_serve_args: list[str],
+                 *,
+                 env_dict: Optional[dict[str, str]] = None,
+                 seed: Optional[int] = 0,
+                 auto_port: bool = True,
+                 max_wait_seconds: Optional[float] = None) -> None:
+        if auto_port:
+            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
+                raise ValueError("You have manually specified the port "
+                                 "when `auto_port=True`.")
+
+            # Don't mutate the input args
+            vllm_serve_args = vllm_serve_args + [
+                "--port", str(get_open_port())
+            ]
+        if seed is not None:
+            if "--seed" in vllm_serve_args:
+                raise ValueError("You have manually specified the seed "
+                                 f"when `seed={seed}`.")
+
+            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
+
+        parser = FlexibleArgumentParser(
+            description="vLLM's remote Anthropic server.")
+        subparsers = parser.add_subparsers(required=False, dest="subparser")
+        parser = ServeSubcommand().subparser_init(subparsers)
+        args = parser.parse_args(["--model", model, *vllm_serve_args])
+        self.host = str(args.host or 'localhost')
+        self.port = int(args.port)
+
+        self.show_hidden_metrics = \
+            args.show_hidden_metrics_for_version is not None
+
+        # download the model before starting the server to avoid timeout
+        is_local = os.path.isdir(model)
+        if not is_local:
+            engine_args = AsyncEngineArgs.from_cli_args(args)
+            model_config = engine_args.create_model_config()
+            load_config = engine_args.create_load_config()
+
+            model_loader = get_model_loader(load_config)
+            model_loader.download_model(model_config)
+
+        env = os.environ.copy()
+        # the current process might initialize cuda,
+        # to be safe, we should use spawn method
+        env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+        if env_dict is not None:
+            env.update(env_dict)
+        # Launch the Anthropic-compatible API server in a subprocess.
+        # Note: "python" and "-m" must be separate argv entries.
+        self.proc = subprocess.Popen(
+            [
+                "python", "-m", "vllm.entrypoints.anthropic.api_server",
+                model, *vllm_serve_args
+            ],
+            env=env,
+            stdout=sys.stdout,
+            stderr=sys.stderr,
+        )
+        max_wait_seconds = max_wait_seconds or 240
+        self._wait_for_server(url=self.url_for("health"),
+                              timeout=max_wait_seconds)
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        self.proc.terminate()
+        try:
+            self.proc.wait(8)
+        except subprocess.TimeoutExpired:
+            # force kill if needed
+            self.proc.kill()
+
+    def _wait_for_server(self, *, url: str, timeout: float):
+        # run health check
+        start = time.time()
+        while True:
+            try:
+                if requests.get(url).status_code == 200:
+                    break
+            except Exception:
+                # this exception can only be raised by requests.get,
+                # which means the server is not ready yet.
+                # the stack trace is not useful, so we suppress it
+                # by using `raise from None`.
+                result = self.proc.poll()
+                if result is not None and result != 0:
+                    raise RuntimeError("Server exited unexpectedly.") from None
+
+                time.sleep(0.5)
+                if time.time() - start > timeout:
+                    raise RuntimeError(
+                        "Server failed to start in time.") from None
+
+    @property
+    def url_root(self) -> str:
+        return f"http://{self.host}:{self.port}"
+
+    def url_for(self, *parts: str) -> str:
+        return self.url_root + "/" + "/".join(parts)
+
+    def get_client(self, **kwargs):
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = 600
+        return anthropic.Anthropic(
+            base_url=self.url_for("v1"),
+            api_key=self.DUMMY_API_KEY,
+            max_retries=0,
+            **kwargs,
+        )
+
+    def get_async_client(self, **kwargs):
+        if "timeout" not in kwargs:
+            kwargs["timeout"] = 600
+        return anthropic.AsyncAnthropic(
+            base_url=self.url_for("v1"),
+            api_key=self.DUMMY_API_KEY,
+            max_retries=0,
+            **kwargs,
+        )
+
+
 def _test_completion(
     client: openai.OpenAI,
     model: str,
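
For context, a rough usage sketch of the new helper in a test. The model name, serve args, and prompt below are placeholders (not taken from this diff), and it assumes the server exposes the standard Anthropic Messages API:

```python
# Hypothetical test using the RemoteAnthropicServer helper defined above.
# The model and request contents are illustrative placeholders.
with RemoteAnthropicServer("meta-llama/Llama-3.2-1B-Instruct", []) as server:
    client = server.get_client()  # synchronous anthropic.Anthropic client
    message = client.messages.create(
        model="meta-llama/Llama-3.2-1B-Instruct",
        max_tokens=32,
        messages=[{"role": "user", "content": "Say hello."}],
    )
    assert message.content[0].text  # the server returned some text
```

The context-manager form ensures the subprocess is terminated (and force-killed if needed) even when the test body raises.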