refact-agent/gui/src/features/Chat/Thread/utils.ts (19 changes: 14 additions & 5 deletions)
@@ -36,6 +36,7 @@ import {
isToolCallMessage,
Usage,
} from "../../../services/refact";
import { v4 as uuidv4 } from "uuid";
import { parseOrElse } from "../../../utils";
import { type LspChatMessage } from "../../../services/refact";
import { checkForDetailMessage } from "./types";
@@ -81,8 +82,18 @@ POINT2 FOR_FUTURE_FEREFENCE: ...
function mergeToolCall(prev: ToolCall[], add: ToolCall): ToolCall[] {
const calls = prev.slice();

if (calls[add.index]) {
const prevCall = calls[add.index];
// NOTE: we can't be sure the backend sends correct indexes for tool calls.
// With qwen3 served by sglang we hit two problems, both handled here:
// 1. the index of the first tool call delta is 2 and the next one is 0 (huh?)
// 2. the second tool call in a row arrives with id == null
if (!calls.length || add.function.name) {
add.index = calls.length;
if (!add.id) {
add.id = uuidv4();
}
calls[calls.length] = add;
} else {
const prevCall = calls[calls.length - 1];
const prevArgs = prevCall.function.arguments;
const nextArgs = prevArgs + add.function.arguments;
const call: ToolCall = {
@@ -92,9 +103,7 @@ function mergeToolCall(prev: ToolCall[], add: ToolCall): ToolCall[] {
arguments: nextArgs,
},
};
calls[add.index] = call;
} else {
calls[add.index] = add;
calls[calls.length - 1] = call;
}
return calls;
}
refact-server/refact_utils/third_party/utils/configs.py (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@ class ModelCapabilities(BaseModel):
agent: bool
clicks: bool
completion: bool
reasoning: Optional[str] = False
reasoning: Optional[str] = None
boost_reasoning: bool = False


refact-server/refact_webgui/webgui/selfhost_fastapi_completions.py (70 changes: 64 additions & 6 deletions)
@@ -30,7 +30,7 @@
from pydantic import BaseModel
from typing import List, Dict, Union, Optional, Tuple, Any

__all__ = ["BaseCompletionsRouter", "CompletionsRouter"]
__all__ = ["BaseCompletionsRouter", "CompletionsRouter", "ThinkingPatcher"]


def clamp(lower, upper, x):
@@ -192,6 +192,48 @@ async def embeddings_streamer(ticket: Ticket, timeout, created_ts):
ticket.done()


# NOTE: some models don't support multiple parsers for now; in that case we have to parse thinking output manually
class ThinkingPatcher:
def __init__(
self,
thinking_tokens: Optional[Tuple[str, str]],
):
if thinking_tokens is None:
thinking_tokens = None, None
self._thinking_start_token, self._thinking_end_token = thinking_tokens
self._thinking_split_index = set()

def patch_choices(self, choices: List[Dict]) -> List[Dict]:
if self._thinking_end_token is None:
return choices
for choice in choices:
index = choice["index"]
if "delta" in choice:
if content := choice["delta"].get("content"):
if self._thinking_start_token:
content = content.replace(self._thinking_start_token, "")
if index not in self._thinking_split_index:
if self._thinking_end_token in content:
self._thinking_split_index.add(index)
choice["delta"]["reasoning_content"], choice["delta"]["content"] \
= (*content.split(self._thinking_end_token), "")[:2]
else:
choice["delta"]["reasoning_content"] = content
choice["delta"]["content"] = ""
else:
choice["delta"]["reasoning_content"] = ""
choice["delta"]["content"] = content
elif "message" in choice:
if content := choice["message"].get("content", ""):
if self._thinking_start_token:
content = content.replace(self._thinking_start_token, "")
choice["message"]["reasoning_content"], choice["message"]["content"] \
= (*content.split(self._thinking_end_token), "")[:2]
else:
log(f"unknown choice type with keys: {choice.keys()}, skip thinking patch")
return choices
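For reference, here is a minimal sketch (not part of the diff) of how patch_choices behaves on streamed deltas, assuming the ThinkingPatcher class above; the chunk contents below are hypothetical:

patcher = ThinkingPatcher(thinking_tokens=("<think>", "</think>"))

# While the model is still thinking, everything is moved to reasoning_content.
chunk = [{"index": 0, "delta": {"content": "<think>plan the answer"}}]
patcher.patch_choices(chunk)
assert chunk[0]["delta"] == {"reasoning_content": "plan the answer", "content": ""}

# The delta carrying the closing token is split exactly once per choice index.
chunk = [{"index": 0, "delta": {"content": " done</think>Hello"}}]
patcher.patch_choices(chunk)
assert chunk[0]["delta"] == {"reasoning_content": " done", "content": "Hello"}

# Every later delta for that index passes through as plain content.
chunk = [{"index": 0, "delta": {"content": ", world"}}]
patcher.patch_choices(chunk)
assert chunk[0]["delta"] == {"reasoning_content": "", "content": ", world"}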


class BaseCompletionsRouter(APIRouter):

def __init__(self,
@@ -573,6 +615,18 @@ def _wrap_output(output: str) -> str:
"timeout": 60 * 60, # An hour timeout for thinking models
}

thinking_tokens = None
if model_config.capabilities.reasoning in ["qwen", "deepseek"]:
thinking_tokens = "<think>", "</think>"

# Qwen3 thinking arguments override
# NOTE: qwen3 can work in two different modes,
# but we don't pass the enable_thinking flag itself into litellm here
if post.enable_thinking is not None:
completion_kwargs["top_p"] = 0.95
completion_kwargs["presence_penalty"] = 1
thinking_patcher = ThinkingPatcher(thinking_tokens=thinking_tokens)

if post.reasoning_effort or post.thinking:
del completion_kwargs["temperature"]
del completion_kwargs["top_p"]
@@ -592,11 +646,13 @@ async def litellm_streamer():
async for model_response in response:
try:
data = model_response.dict()
choice0 = data["choices"][0]
finish_reason = choice0["finish_reason"]
if delta := choice0.get("delta"):
if text := delta.get("content"):
generated_tokens_n += litellm.token_counter(model_config.model_id, text=text)
if "choices" in data:
data["choices"] = thinking_patcher.patch_choices(data["choices"])
choice0 = data["choices"][0]
finish_reason = choice0["finish_reason"]
if delta := choice0.get("delta"):
if text := delta.get("content"):
generated_tokens_n += litellm.token_counter(model_config.model_id, text=text)

except json.JSONDecodeError:
data = {"choices": [{"finish_reason": finish_reason}]}
@@ -628,6 +684,8 @@ async def litellm_non_streamer():
if text := choice.get("message", {}).get("content"):
generated_tokens_n += litellm.token_counter(model_config.model_id, text=text)
finish_reason = choice.get("finish_reason")
if "choices" in data:
data["choices"] = thinking_patcher.patch_choices(data["choices"])
usage_dict = model_config.compose_usage_dict(prompt_tokens_n, generated_tokens_n)
data.update(usage_dict)
except json.JSONDecodeError:
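The non-streaming path patches the final message the same way; a minimal sketch (not part of the diff), again with a hypothetical choice and assuming the ThinkingPatcher defined above:

patcher = ThinkingPatcher(thinking_tokens=("<think>", "</think>"))
choices = [{"index": 0, "message": {"role": "assistant", "content": "<think>reason first</think>Final answer"}}]
patcher.patch_choices(choices)
assert choices[0]["message"]["reasoning_content"] == "reason first"
assert choices[0]["message"]["content"] == "Final answer"
# If the closing token never appears, the whole text ends up in reasoning_content and "content" is left empty.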