diff --git a/refact-agent/gui/src/features/Chat/Thread/utils.ts b/refact-agent/gui/src/features/Chat/Thread/utils.ts
index 0d7c67e82..e97b9d29d 100644
--- a/refact-agent/gui/src/features/Chat/Thread/utils.ts
+++ b/refact-agent/gui/src/features/Chat/Thread/utils.ts
@@ -36,6 +36,7 @@ import {
   isToolCallMessage,
   Usage,
 } from "../../../services/refact";
+import { v4 as uuidv4 } from "uuid";
 import { parseOrElse } from "../../../utils";
 import { type LspChatMessage } from "../../../services/refact";
 import { checkForDetailMessage } from "./types";
@@ -81,8 +82,18 @@ POINT2 FOR_FUTURE_FEREFENCE: ...
 
 function mergeToolCall(prev: ToolCall[], add: ToolCall): ToolCall[] {
   const calls = prev.slice();
-  if (calls[add.index]) {
-    const prevCall = calls[add.index];
+  // NOTE: we can't be sure the backend sends correct indexes for tool calls.
+  // With qwen3 on sglang there are two problems, both fixed here:
+  // 1. the index of the first tool call delta == 2, the next == 0 (huh?)
+  // 2. the second tool call in a row has id == null
+  if (!calls.length || add.function.name) {
+    add.index = calls.length;
+    if (!add.id) {
+      add.id = uuidv4();
+    }
+    calls[calls.length] = add;
+  } else {
+    const prevCall = calls[calls.length - 1];
     const prevArgs = prevCall.function.arguments;
     const nextArgs = prevArgs + add.function.arguments;
     const call: ToolCall = {
@@ -92,9 +103,7 @@ function mergeToolCall(prev: ToolCall[], add: ToolCall): ToolCall[] {
         arguments: nextArgs,
       },
     };
-    calls[add.index] = call;
-  } else {
-    calls[add.index] = add;
+    calls[calls.length - 1] = call;
   }
   return calls;
 }
diff --git a/refact-server/refact_utils/third_party/utils/configs.py b/refact-server/refact_utils/third_party/utils/configs.py
index acd503853..f394009e1 100644
--- a/refact-server/refact_utils/third_party/utils/configs.py
+++ b/refact-server/refact_utils/third_party/utils/configs.py
@@ -17,7 +17,7 @@ class ModelCapabilities(BaseModel):
     agent: bool
     clicks: bool
     completion: bool
-    reasoning: Optional[str] = False
+    reasoning: Optional[str] = None
     boost_reasoning: bool = False
 
 
diff --git a/refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py b/refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py
index 4060648cf..7edc582aa 100644
--- a/refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py
+++ b/refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py
@@ -30,7 +30,7 @@
 from pydantic import BaseModel
 from typing import List, Dict, Union, Optional, Tuple, Any
 
-__all__ = ["BaseCompletionsRouter", "CompletionsRouter"]
+__all__ = ["BaseCompletionsRouter", "CompletionsRouter", "ThinkingPatcher"]
 
 
 def clamp(lower, upper, x):
@@ -192,6 +192,48 @@ async def embeddings_streamer(ticket: Ticket, timeout, created_ts):
     ticket.done()
 
 
+# NOTE: some models don't support multiple parsers for now, so we need to parse thinking manually in this case
+class ThinkingPatcher:
+    def __init__(
+            self,
+            thinking_tokens: Optional[Tuple[str, str]],
+    ):
+        if thinking_tokens is None:
+            thinking_tokens = None, None
+        self._thinking_start_token, self._thinking_end_token = thinking_tokens
+        self._thinking_split_index = set()
+
+    def patch_choices(self, choices: List[Dict]) -> List[Dict]:
+        if self._thinking_end_token is None:
+            return choices
+        for choice in choices:
+            index = choice["index"]
+            if "delta" in choice:
+                if content := choice["delta"].get("content"):
+                    if self._thinking_start_token:
+                        content = content.replace(self._thinking_start_token, "")
+                    if index not in self._thinking_split_index:
+                        if self._thinking_end_token in content:
+                            self._thinking_split_index.add(index)
+                            choice["delta"]["reasoning_content"], choice["delta"]["content"] \
+                                = (*content.split(self._thinking_end_token), "")[:2]
+                        else:
+                            choice["delta"]["reasoning_content"] = content
+                            choice["delta"]["content"] = ""
+                    else:
+                        choice["delta"]["reasoning_content"] = ""
+                        choice["delta"]["content"] = content
+            elif "message" in choice:
+                if content := choice["message"].get("content", ""):
+                    if self._thinking_start_token:
+                        content = content.replace(self._thinking_start_token, "")
+                    choice["message"]["reasoning_content"], choice["message"]["content"] \
+                        = (*content.split(self._thinking_end_token), "")[:2]
+            else:
+                log(f"unknown choice type with keys: {choice.keys()}, skip thinking patch")
+        return choices
+
+
 class BaseCompletionsRouter(APIRouter):
 
     def __init__(self,
@@ -573,6 +615,18 @@ def _wrap_output(output: str) -> str:
             "timeout": 60 * 60, # An hour timeout for thinking models
         }
 
+        thinking_tokens = None
+        if model_config.capabilities.reasoning in ["qwen", "deepseek"]:
+            thinking_tokens = "<think>", "</think>"
+
+        # Qwen3 thinking arguments override
+        # NOTE: qwen3 can work in two different modes,
+        # but we don't pass this specific argument into litellm here
+        if post.enable_thinking is not None:
+            completion_kwargs["top_p"] = 0.95
+            completion_kwargs["presence_penalty"] = 1
+        thinking_patcher = ThinkingPatcher(thinking_tokens=thinking_tokens)
+
         if post.reasoning_effort or post.thinking:
             del completion_kwargs["temperature"]
             del completion_kwargs["top_p"]
@@ -592,11 +646,13 @@ async def litellm_streamer():
             async for model_response in response:
                 try:
                     data = model_response.dict()
-                    choice0 = data["choices"][0]
-                    finish_reason = choice0["finish_reason"]
-                    if delta := choice0.get("delta"):
-                        if text := delta.get("content"):
-                            generated_tokens_n += litellm.token_counter(model_config.model_id, text=text)
+                    if "choices" in data:
+                        data["choices"] = thinking_patcher.patch_choices(data["choices"])
+                        choice0 = data["choices"][0]
+                        finish_reason = choice0["finish_reason"]
+                        if delta := choice0.get("delta"):
+                            if text := delta.get("content"):
+                                generated_tokens_n += litellm.token_counter(model_config.model_id, text=text)
                 except json.JSONDecodeError:
                     data = {"choices": [{"finish_reason": finish_reason}]}
 
@@ -628,6 +684,8 @@ async def litellm_non_streamer():
                     if text := choice.get("message", {}).get("content"):
                         generated_tokens_n += litellm.token_counter(model_config.model_id, text=text)
                     finish_reason = choice.get("finish_reason")
+                if "choices" in data:
+                    data["choices"] = thinking_patcher.patch_choices(data["choices"])
                 usage_dict = model_config.compose_usage_dict(prompt_tokens_n, generated_tokens_n)
                 data.update(usage_dict)
             except json.JSONDecodeError:
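
Not part of the patch itself: a minimal sketch of how the new ThinkingPatcher is expected to split a streamed qwen/deepseek response into reasoning_content and content. The class is the one added above (now exported via __all__); the import path below follows the patched module, and the sample deltas are invented for illustration only.

    from refact_webgui.webgui.selfhost_fastapi_completions import ThinkingPatcher

    # <think>/</think> are the delimiters wired up for reasoning in ["qwen", "deepseek"]
    patcher = ThinkingPatcher(thinking_tokens=("<think>", "</think>"))

    # Invented streamed deltas: reasoning first, then the </think> split, then plain content.
    chunks = [
        {"index": 0, "delta": {"content": "<think>Let me check the file"}},
        {"index": 0, "delta": {"content": " first.</think>Sure, here is"}},
        {"index": 0, "delta": {"content": " the answer."}},
    ]

    for chunk in chunks:
        patcher.patch_choices([chunk])
        print(repr(chunk["delta"].get("reasoning_content")), repr(chunk["delta"]["content"]))

    # Expected output:
    # 'Let me check the file' ''
    # ' first.' 'Sure, here is'
    # '' ' the answer.'

Once the closing thinking token has been seen for a choice index, everything that follows is passed through as regular content, which is why the third delta keeps its text in content.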