61 changes: 42 additions & 19 deletions e2e-tests/llm-katan/llm_katan/model.py
@@ -103,9 +103,7 @@ async def generate(
             raise RuntimeError("Model not loaded. Call load_model() first.")
 
         max_tokens = max_tokens or self.config.max_tokens
-        temperature = (
-            temperature if temperature is not None else self.config.temperature
-        )
+        temperature = temperature if temperature is not None else self.config.temperature
 
         # Convert messages to prompt
         prompt = self._messages_to_prompt(messages)
@@ -136,20 +134,27 @@ async def generate(
             "object": "chat.completion",
             "created": int(time.time()),
             "model": self.config.served_model_name,
+            "system_fingerprint": "llm-katan-transformers",
             "choices": [
                 {
                     "index": 0,
                     "message": {"role": "assistant", "content": generated_text},
+                    "logprobs": None,
                     "finish_reason": "stop",
                 }
             ],
             "usage": {
                 "prompt_tokens": prompt_tokens,
                 "completion_tokens": completion_tokens,
                 "total_tokens": total_tokens,
+                "prompt_tokens_details": {"cached_tokens": 0},
+                "completion_tokens_details": {"reasoning_tokens": 0},
             },
         }
 
+        # Add token_usage as alias for better SDK compatibility
+        response_data["token_usage"] = response_data["usage"]
+
         if stream:
             # For streaming, yield chunks
             words = generated_text.split()
@@ -159,12 +164,12 @@
                     "object": "chat.completion.chunk",
                     "created": response_data["created"],
                     "model": self.config.served_model_name,
+                    "system_fingerprint": "llm-katan-transformers",
                     "choices": [
                         {
                             "index": 0,
-                            "delta": {
-                                "content": word + " " if i < len(words) - 1 else word
-                            },
+                            "delta": {"content": word + " " if i < len(words) - 1 else word},
+                            "logprobs": None,
                             "finish_reason": None,
                         }
                     ],
@@ -178,7 +183,15 @@
                 "object": "chat.completion.chunk",
                 "created": response_data["created"],
                 "model": self.config.served_model_name,
-                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+                "system_fingerprint": "llm-katan-transformers",
+                "choices": [{"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}],
+                "usage": {
+                    "prompt_tokens": prompt_tokens,
+                    "completion_tokens": completion_tokens,
+                    "total_tokens": total_tokens,
+                    "prompt_tokens_details": {"cached_tokens": 0},
+                    "completion_tokens_details": {"reasoning_tokens": 0},
+                },
             }
             yield final_chunk
         else:
@@ -268,9 +281,7 @@ async def generate(
         from vllm.sampling_params import SamplingParams
 
         max_tokens = max_tokens or self.config.max_tokens
-        temperature = (
-            temperature if temperature is not None else self.config.temperature
-        )
+        temperature = temperature if temperature is not None else self.config.temperature
 
         # Convert messages to prompt
         prompt = self._messages_to_prompt(messages)
@@ -282,9 +293,7 @@
 
         # Generate
         loop = asyncio.get_event_loop()
-        outputs = await loop.run_in_executor(
-            None, self.engine.generate, [prompt], sampling_params
-        )
+        outputs = await loop.run_in_executor(None, self.engine.generate, [prompt], sampling_params)
 
         output = outputs[0]
         generated_text = output.outputs[0].text.strip()
@@ -295,21 +304,27 @@
             "object": "chat.completion",
             "created": int(time.time()),
             "model": self.config.served_model_name,
+            "system_fingerprint": "llm-katan-vllm",
             "choices": [
                 {
                     "index": 0,
                     "message": {"role": "assistant", "content": generated_text},
+                    "logprobs": None,
                     "finish_reason": "stop",
                 }
             ],
             "usage": {
                 "prompt_tokens": len(output.prompt_token_ids),
                 "completion_tokens": len(output.outputs[0].token_ids),
-                "total_tokens": len(output.prompt_token_ids)
-                + len(output.outputs[0].token_ids),
+                "total_tokens": len(output.prompt_token_ids) + len(output.outputs[0].token_ids),
+                "prompt_tokens_details": {"cached_tokens": 0},
+                "completion_tokens_details": {"reasoning_tokens": 0},
             },
         }
 
+        # Add token_usage as alias for better SDK compatibility
+        response_data["token_usage"] = response_data["usage"]
+
         if stream:
             # For streaming, yield chunks (simplified for now)
             words = generated_text.split()
@@ -319,12 +334,12 @@
                     "object": "chat.completion.chunk",
                     "created": response_data["created"],
                     "model": self.config.served_model_name,
+                    "system_fingerprint": "llm-katan-vllm",
                     "choices": [
                         {
                             "index": 0,
-                            "delta": {
-                                "content": word + " " if i < len(words) - 1 else word
-                            },
+                            "delta": {"content": word + " " if i < len(words) - 1 else word},
+                            "logprobs": None,
                             "finish_reason": None,
                         }
                     ],
@@ -338,7 +353,15 @@
                 "object": "chat.completion.chunk",
                 "created": response_data["created"],
                 "model": self.config.served_model_name,
-                "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}],
+                "system_fingerprint": "llm-katan-vllm",
+                "choices": [{"index": 0, "delta": {}, "logprobs": None, "finish_reason": "stop"}],
+                "usage": {
+                    "prompt_tokens": len(output.prompt_token_ids),
+                    "completion_tokens": len(output.outputs[0].token_ids),
+                    "total_tokens": len(output.prompt_token_ids) + len(output.outputs[0].token_ids),
+                    "prompt_tokens_details": {"cached_tokens": 0},
+                    "completion_tokens_details": {"reasoning_tokens": 0},
+                },
             }
             yield final_chunk
         else:
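The enriched non-streaming payload can be checked end to end with a plain HTTP call. Below is a minimal sketch, assuming an llm-katan instance is listening on http://localhost:8000 and exposing the OpenAI-style /v1/chat/completions route; the URL, port, and model name are illustrative placeholders, not values taken from this diff.

# Minimal sketch: inspect the fields added in this change set.
# Assumes a local llm-katan server and an OpenAI-compatible endpoint;
# adjust the URL and model name to your setup.
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "qwen3",  # placeholder served model name
        "messages": [{"role": "user", "content": "Say hello"}],
        "stream": False,
    },
    timeout=60,
)
data = resp.json()

print(data["system_fingerprint"])                 # e.g. "llm-katan-transformers"
print(data["choices"][0]["logprobs"])             # None placeholder
print(data["usage"]["prompt_tokens_details"])     # {"cached_tokens": 0}
print(data["usage"]["completion_tokens_details"]) # {"reasoning_tokens": 0}

# "token_usage" mirrors "usage" for SDKs that expect that key.
assert data["token_usage"] == data["usage"]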
17 changes: 9 additions & 8 deletions e2e-tests/llm-katan/llm_katan/server.py
@@ -160,9 +160,7 @@ async def chat_completions(request: ChatCompletionRequest, http_request: Request
 
     try:
         # Convert messages to dict format
-        messages = [
-            {"role": msg.role, "content": msg.content} for msg in request.messages
-        ]
+        messages = [{"role": msg.role, "content": msg.content} for msg in request.messages]
 
         # Update metrics
         metrics["total_requests"] += 1
@@ -181,8 +179,13 @@ async def generate_stream():
 
             return StreamingResponse(
                 generate_stream(),
-                media_type="text/plain",
-                headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
+                media_type="text/event-stream",
+                headers={
+                    "Cache-Control": "no-cache",
+                    "Connection": "keep-alive",
+                    "Access-Control-Allow-Origin": "*",
+                    "Access-Control-Allow-Headers": "Content-Type",
+                },
             )
         else:
             # Non-streaming response
@@ -198,9 +201,7 @@ async def generate_stream():
             response_time = time.time() - start_time
             metrics["response_times"].append(response_time)
             if "choices" in response and response["choices"]:
-                generated_text = (
-                    response["choices"][0].get("message", {}).get("content", "")
-                )
+                generated_text = response["choices"][0].get("message", {}).get("content", "")
                 token_count = len(generated_text.split())  # Rough token estimate
                 metrics["total_tokens_generated"] += token_count
 
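With the media type switched to text/event-stream, the stream can be consumed as server-sent events. The rough client sketch below assumes chunks are framed as "data: {json}" lines terminated by "data: [DONE]" (the usual OpenAI-compatible convention); the exact framing comes from generate_stream, which is not shown in this diff, and the URL and model name are placeholders.

# Rough sketch of a streaming client; framing and endpoint details are assumptions.
import json
import requests

with requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "qwen3",  # placeholder served model name
        "messages": [{"role": "user", "content": "Stream a short greeting"}],
        "stream": True,
    },
    stream=True,
    timeout=60,
) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        payload = line[len("data: "):]
        if payload == "[DONE]":
            break
        chunk = json.loads(payload)
        choice = chunk["choices"][0]
        if "content" in choice["delta"]:
            print(choice["delta"]["content"], end="", flush=True)
        # Per this diff, the final chunk also carries a "usage" block.
        if choice.get("finish_reason") == "stop":
            print("\n", chunk.get("usage"))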
2 changes: 1 addition & 1 deletion e2e-tests/llm-katan/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "llm-katan"
-version = "0.1.8"
+version = "0.1.9"
 description = "LLM Katan - Lightweight LLM Server for Testing - Real tiny models with FastAPI and HuggingFace"
 readme = "README.md"
 authors = [