|
| 1 | +import json as _json |
1 | 2 | import re |
2 | 3 | from typing import Any, Optional |
3 | 4 |
|
4 | 5 | from fastapi import APIRouter, HTTPException, Request |
| 6 | +from fastapi.responses import StreamingResponse |
5 | 7 | from pydantic import BaseModel, field_validator |
6 | 8 |
|
7 | 9 | from crud import prompts as crud |
8 | 10 | from middlewares.auth import verify_internal_key |
| 11 | +from services.llm_service import stream_chat_completion |
| 12 | +from services.proxy_service import resolve_endpoint_for_key |
9 | 13 | from utils.auth import get_org_id, get_user_id, require_admin |
10 | 14 | from utils.notifications import notify_config_change |
11 | 15 |
|
@@ -377,27 +381,42 @@ async def delete_test_dataset( |
377 | 381 | return {"deleted": True} |
378 | 382 |
|
379 | 383 |
|
@router.post("/test")
async def test_prompt(request: Request, body: TestPromptRequest):
    """
    Test a prompt by resolving variables and streaming the LLM response.

    Returns an SSE stream (``text/event-stream``) compatible with the
    frontend's streamPromptTest(): provider chunks as they arrive, a
    ``data: {"error": ...}`` event if streaming fails, and a terminating
    ``data: [DONE]`` frame in all cases.

    Raises:
        HTTPException: 404 when the endpoint slug cannot be resolved for
            the caller's organization.
    """
    verify_internal_key(request)
    org_id = get_org_id(request)

    # Resolve variable placeholders in the prompt content before sending.
    resolved_content = crud.resolve_variables(
        body.content,
        body.variables or {},
    )

    # Resolve the endpoint to get provider, model, and decrypted API key.
    # An empty allowed_endpoint_ids list imposes no per-key endpoint
    # restriction here — the internal-key check above gates this route.
    try:
        endpoint = await resolve_endpoint_for_key(
            organization_id=org_id,
            endpoint_slug=body.endpoint_slug,
            allowed_endpoint_ids=[],
        )
    except ValueError as e:
        # Chain the cause so the original resolution failure is preserved
        # in logs/tracebacks instead of being silently replaced.
        raise HTTPException(status_code=404, detail=str(e)) from e

    async def _stream():
        """Yield provider SSE chunks; convert failures into an SSE error event."""
        try:
            # NOTE(review): resolved_content is forwarded as `messages` —
            # confirm stream_chat_completion accepts this shape (plain
            # string vs. a list of chat-message dicts).
            async for chunk_str in stream_chat_completion(
                model=endpoint["model"],
                messages=resolved_content,
                api_key=endpoint["decrypted_key"],
            ):
                yield chunk_str
        except Exception as e:
            # Broad catch is deliberate at this streaming boundary: the HTTP
            # status is already sent, so errors must be surfaced in-band.
            yield f"data: {_json.dumps({'error': str(e)})}\n\n"
        # Always terminate the stream so the client-side reader can stop.
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        _stream(),
        media_type="text/event-stream",
        # Disable caching and proxy buffering (nginx) so chunks reach the
        # client promptly instead of being batched.
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )
0 commit comments