Skip to content
45 changes: 41 additions & 4 deletions agents/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,30 @@

logger = logging.getLogger(__name__)

# Read timeout for one request, in seconds; RITS_REQUEST_TIMEOUT_SECONDS overrides the 60.0 fallback.
REQUEST_TIMEOUT = float(os.environ.get("RITS_REQUEST_TIMEOUT_SECONDS", 60.0))
# Integer value of RITS_MAX_RETRIES; falls back to 2 when the variable is unset.
MAX_RETRIES = int(os.environ.get("RITS_MAX_RETRIES", 2))

# httpx timeout profile: only the read phase follows REQUEST_TIMEOUT, while the
# connect, write and pool phases use fixed limits.
timeout = httpx.Timeout(connect=10.0, write=30.0, pool=10.0, read=REQUEST_TIMEOUT)

class RITSChatModel(BaseChatModel):
"""LangChain-compatible chat model using httpx for internal RITS inference service."""

# Mapping from endpoint name (short) to payload model name (full)
MODEL_NAME_MAPPING: Dict[str, str] = {
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
# Open Source Models
"qwen3-5-397b-a17b-fp8": "Qwen/Qwen3.5-397B-A17B-FP8",
"mistral-large-3-675b-2512-fp4": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4",
"glm-5-1": "",
"moonshotai-kimi-k2-5":"moonshotai/Kimi-K2.5",
"gpt-oss-120b": "openai/gpt-oss-120b",
"qwen3-5-397b-a17b-fp8": "qwen/qwen3.5-397B-A17B-FP8",
"mistral-large-3-675b-2512-fp4": "mistralai/Mistral-Large-3-675B-Instruct-2512-NVFP4"
# smaller models
"llama-3-3-70b-instruct": "meta-llama/llama-3-3-70b-instruct",
"qwen2-5-72b-instruct": "Qwen/Qwen2.5-72B-Instruct",
}

model_name: str
Expand Down Expand Up @@ -125,12 +139,35 @@ async def _agenerate(
if self.bound_tools:
payload["tools"] = self.bound_tools

# TODO: enable MAX_RETRIES and timeout handling (draft kept commented out below; the active request further down does not retry)
# async with httpx.AsyncClient(timeout=timeout) as client:
# for attempt in range(MAX_RETRIES + 1):
# try:
# resp = await client.post(
# url,
# json=payload,
# headers=headers,
# )
# resp.raise_for_status()
# break

# except httpx.ReadTimeout:
# if attempt == MAX_RETRIES:
# raise
# await asyncio.sleep(2 ** attempt)

# except httpx.HTTPError:
# if attempt == MAX_RETRIES:
# raise
# await asyncio.sleep(2 ** attempt)
# data = resp.json()

async with httpx.AsyncClient() as client:
resp = await client.post(
url,
headers=headers,
json=payload,
timeout=60.0
timeout=float(os.environ.get("RITS_REQUEST_TIMEOUT_SECONDS", "60"))
)
resp.raise_for_status()
data = resp.json()
Expand Down
15 changes: 12 additions & 3 deletions benchmark_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
Results saved to: output/capability_{id}_{timestamp}/<domain>.json
e.g. output/capability_2_feb_18_11_21am/hockey.json
"""
import os
import asyncio
from contextlib import AsyncExitStack
import json
Expand Down Expand Up @@ -145,7 +146,7 @@ def _setup_phoenix(endpoint: str, project_name: str = "enterprise-benchmark") ->
Path(__file__).parent / "benchmark" / "mcp_connection_config.yaml"
)
# Timeout for agent execution (seconds); override via the AGENT_TIMEOUT_SECONDS environment variable
AGENT_TIMEOUT_SECONDS = 300
AGENT_TIMEOUT_SECONDS = float(os.environ.get("AGENT_TIMEOUT_SECONDS", "300"))


async def run_benchmark_for_domain(
Expand Down Expand Up @@ -316,7 +317,7 @@ async def run_benchmark_for_domain(
except Exception as e:
import traceback
result.status = "error"
result.error = str(e)
result.error = f"{type(e).__name__} "+str(e)
tlog(f" Status: error | {type(e).__name__}: {str(e)[:200]}")
tlog(f" Traceback: {traceback.format_exc()}")

Expand Down Expand Up @@ -357,6 +358,7 @@ async def run_capability(
top_k_tools: int = 0,
max_iterations: Optional[int] = None,
restart: bool = False,
temperature: float = 0.0,
) -> List[BenchmarkResult]:
"""Run benchmark for a given capability_id, iterating over all domain files."""

Expand Down Expand Up @@ -397,7 +399,7 @@ async def run_capability(
tlog(f"Restart mode: skipping {len(completed)} already-completed domain(s): {sorted(completed)}")
domain_list = [d for d in domain_list if d not in completed]

llm = create_llm(provider=provider, model=model)
llm = create_llm(provider=provider, model=model, temperature=temperature)

# Process each domain, writing output incrementally
all_results: List[BenchmarkResult] = []
Expand Down Expand Up @@ -553,6 +555,12 @@ def main():
default="enterprise-benchmark",
help="Phoenix project name for grouping traces (default: enterprise-benchmark)",
)
parser.add_argument(
"--temperature",
type=float,
default=0.0,
help="LLM temperature (default: 0.0)"
)

args = parser.parse_args()
capability_ids = args.capability_id # list of ints now
Expand Down Expand Up @@ -588,6 +596,7 @@ def _make_run_task_coro(tid: int):
top_k_tools=args.top_k_tools,
max_iterations=args.max_iterations,
restart=args.restart,
temperature=args.temperature
)

def _make_list_tools_coro(tid: int):
Expand Down
2 changes: 1 addition & 1 deletion evaluator/judge.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
import os
import json
import deepcopy
from copy import deepcopy
from prompt import GroundednessPrompt, CorrectnessPrompt
from utils import JudgeInput, JudgeOutput
from langchain_openai import ChatOpenAI
Expand Down
Loading