From 49383b693bb368ef815b7b657b2291d61e2000fe Mon Sep 17 00:00:00 2001
From: Simon Willison
Date: Wed, 28 May 2025 07:41:45 -0700
Subject: [PATCH] set_resolved_model(), refs #2

---
 llm_llama_server.py | 40 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 40 insertions(+)

diff --git a/llm_llama_server.py b/llm_llama_server.py
index 83306ca..252a5c5 100644
--- a/llm_llama_server.py
+++ b/llm_llama_server.py
@@ -1,5 +1,7 @@
+import httpx
 import llm
 from llm.default_plugins.openai_models import Chat, AsyncChat
+import os
 
 
 class LlamaServer(Chat):
@@ -14,6 +16,19 @@ def __init__(self, **kwargs):
             **kwargs,
         )
 
+    def execute(self, prompt, stream, response, conversation=None, key=None):
+        yield from super().execute(prompt, stream, response, conversation, key)
+        # Quick timeout-limited hit to get resolved_model_id
+        try:
+            http_response = httpx.get(
+                f"{self.api_base}/models",
+                timeout=httpx.Timeout(0.1, connect=0.1),
+            )
+            http_response.raise_for_status()
+            set_resolved_model(response, http_response.json())
+        except httpx.HTTPError:
+            pass
+
     def __str__(self):
         return "llama-server: {}".format(self.model_id)
 
@@ -30,6 +45,20 @@ def __init__(self, **kwargs):
             **kwargs,
         )
 
+    async def execute(self, prompt, stream, response, conversation=None, key=None):
+        async for chunk in super().execute(prompt, stream, response, conversation, key):
+            yield chunk
+        try:
+            async with httpx.AsyncClient() as client:
+                http_response = await client.get(
+                    f"{self.api_base}/models",
+                    timeout=httpx.Timeout(0.1, connect=0.1),
+                )
+                http_response.raise_for_status()
+                set_resolved_model(response, http_response.json())
+        except httpx.HTTPError:
+            pass
+
     def __str__(self):
         return f"llama-server (async): {self.model_id}"
 
@@ -50,6 +79,17 @@ class AsyncLlamaServerTools(AsyncLlamaServer):
     model_id = "llama-server-tools"
 
 
+def set_resolved_model(response, data):
+    try:
+        model_path = data["data"][0]["id"]
+        # This will be something like:
+        # '.../Caches/llama.cpp/unsloth_gemma-3-12b-it-qat-GGUF_gemma-3-12b-it-qat-Q4_K_M.gguf'
+        resolved_model = os.path.basename(model_path)
+        response.set_resolved_model(resolved_model)
+    except (IndexError, KeyError):
+        pass
+
+
 @llm.hookimpl
 def register_models(register):
     register(
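
Note (illustration, not part of the patch): set_resolved_model() assumes the
/models payload llama-server returns, where, per the comment in the patch,
each entry's "id" is the local GGUF path. A minimal sketch of the extraction
with a hypothetical payload:

    import os

    # Hypothetical /models response body; the real path under a
    # .../Caches/llama.cpp/ directory will differ.
    data = {
        "data": [
            {"id": "/example/Caches/llama.cpp/gemma-3-12b-it-qat-Q4_K_M.gguf"}
        ]
    }

    model_path = data["data"][0]["id"]
    print(os.path.basename(model_path))
    # gemma-3-12b-it-qat-Q4_K_M.gguf

Because the lookup is wrapped in try/except, a payload missing "data" or with
an empty list is silently ignored, matching the best-effort HTTP fetch in both
execute() implementations.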