llm_llama_server.py (40 additions, 0 deletions)
@@ -1,5 +1,7 @@
import httpx
import llm
from llm.default_plugins.openai_models import Chat, AsyncChat
import os


class LlamaServer(Chat):
@@ -14,6 +16,19 @@ def __init__(self, **kwargs):
            **kwargs,
        )

    def execute(self, prompt, stream, response, conversation=None, key=None):
        yield from super().execute(prompt, stream, response, conversation, key)
        # Quick, timeout-limited hit to /models to fetch the resolved model id
        try:
            http_response = httpx.get(
                f"{self.api_base}/models",
                timeout=httpx.Timeout(0.1, connect=0.1),
            )
            http_response.raise_for_status()
            set_resolved_model(response, http_response.json())
        except httpx.HTTPError:
            # Best effort: a failed lookup should never break the response
            pass

    def __str__(self):
        return "llama-server: {}".format(self.model_id)

@@ -30,6 +45,20 @@ def __init__(self, **kwargs):
            **kwargs,
        )

    async def execute(self, prompt, stream, response, conversation=None, key=None):
        async for chunk in super().execute(prompt, stream, response, conversation, key):
            yield chunk
        # Quick, timeout-limited hit to /models to fetch the resolved model id
        try:
            async with httpx.AsyncClient() as client:
                http_response = await client.get(
                    f"{self.api_base}/models",
                    timeout=httpx.Timeout(0.1, connect=0.1),
                )
                http_response.raise_for_status()
                set_resolved_model(response, http_response.json())
        except httpx.HTTPError:
            # Best effort, mirroring the sync path: never break the response
            pass

    def __str__(self):
        return f"llama-server (async): {self.model_id}"

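The explicit loop in the async execute above is required because Python forbids yield from inside async generators. A minimal standalone sketch of the same delegation pattern (the inner/outer names are illustrative, not part of this PR):

import asyncio

async def inner():
    yield "chunk-1"
    yield "chunk-2"

async def outer():
    # Async generators have no "yield from"; delegate with an explicit loop
    async for item in inner():
        yield item

async def main():
    async for item in outer():
        print(item)

asyncio.run(main())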
@@ -50,6 +79,17 @@ class AsyncLlamaServerTools(AsyncLlamaServer):
    model_id = "llama-server-tools"


def set_resolved_model(response, data):
    try:
        model_path = data["data"][0]["id"]
        # This will be something like:
        # '.../Caches/llama.cpp/unsloth_gemma-3-12b-it-qat-GGUF_gemma-3-12b-it-qat-Q4_K_M.gguf'
        resolved_model = os.path.basename(model_path)
        response.set_resolved_model(resolved_model)
    except (IndexError, KeyError):
        # Ignore unexpected payload shapes rather than failing the whole prompt
        pass

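For reference, set_resolved_model assumes llama-server's OpenAI-compatible /models response, where each entry's id is the path of the GGUF file the server loaded. A minimal sketch of that assumed shape, with a hypothetical path:

import os

# Hypothetical /models payload in the OpenAI list format (fields abridged)
payload = {
    "object": "list",
    "data": [{"id": "/models/gemma-3-12b-it-qat-Q4_K_M.gguf", "object": "model"}],
}
assert os.path.basename(payload["data"][0]["id"]) == "gemma-3-12b-it-qat-Q4_K_M.gguf"

The 0.1-second timeout on the lookup bounds the added latency to roughly 100 ms per prompt, which is why a slow or failed lookup is silently ignored.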

@llm.hookimpl
def register_models(register):
    register(
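End to end, the recorded id surfaces through llm's Python API. A hedged usage sketch, assuming the plugin is installed, the model is registered under the id llama-server, and a llama-server instance is listening locally:

import llm

model = llm.get_model("llama-server")
response = model.prompt("Say hello")
print(response.text())
# Once the response has completed, its resolved model (set via
# set_resolved_model above) should name the GGUF file the server loaded.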