Merged
Changes from 4 commits
2 changes: 2 additions & 0 deletions refact-agent/engine/src/call_validation.rs
@@ -53,6 +53,8 @@ pub struct SamplingParameters {
pub reasoning_effort: Option<ReasoningEffort>, // OpenAI style reasoning
#[serde(default)]
pub thinking: Option<serde_json::Value>, // Anthropic style reasoning
#[serde(default)]
pub enable_thinking: Option<bool>, // Qwen style reasoning
}

#[derive(Debug, Deserialize, Clone)]
8 changes: 7 additions & 1 deletion refact-agent/engine/src/forward_to_openai_endpoint.rs
@@ -44,6 +44,9 @@ pub async fn forward_to_openai_style_endpoint(
data["reasoning_effort"] = serde_json::Value::String(reasoning_effort.to_string());
} else if let Some(thinking) = sampling_parameters.thinking.clone() {
data["thinking"] = thinking.clone();
} else if let Some(enable_thinking) = sampling_parameters.enable_thinking {
data["enable_thinking"] = serde_json::Value::Bool(enable_thinking);
data["temperature"] = serde_json::Value::from(sampling_parameters.temperature);
} else if let Some(temperature) = sampling_parameters.temperature {
data["temperature"] = serde_json::Value::from(temperature);
}
@@ -130,7 +133,10 @@ pub async fn forward_to_openai_style_endpoint_streaming(
data["reasoning_effort"] = serde_json::Value::String(reasoning_effort.to_string());
} else if let Some(thinking) = sampling_parameters.thinking.clone() {
data["thinking"] = thinking.clone();
} else if let Some(temperature) = sampling_parameters.temperature {
} else if let Some(enable_thinking) = sampling_parameters.enable_thinking {
data["enable_thinking"] = serde_json::Value::Bool(enable_thinking);
data["temperature"] = serde_json::Value::from(sampling_parameters.temperature);
} else if let Some(temperature) = sampling_parameters.temperature {
data["temperature"] = serde_json::Value::from(temperature);
}
data["max_completion_tokens"] = serde_json::Value::from(sampling_parameters.max_new_tokens);
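To make the branch order explicit, here is a minimal Python sketch of the precedence implemented above (reasoning_effort, then thinking, then enable_thinking, then plain temperature). It illustrates the logic only, not the actual Rust implementation, and the shape of params is assumed:

```python
def reasoning_fields(params: dict) -> dict:
    # Sketch of the precedence: reasoning_effort > thinking > enable_thinking > temperature.
    data = {}
    if params.get("reasoning_effort") is not None:       # OpenAI-style reasoning
        data["reasoning_effort"] = str(params["reasoning_effort"])
    elif params.get("thinking") is not None:             # Anthropic-style reasoning
        data["thinking"] = params["thinking"]
    elif params.get("enable_thinking") is not None:      # Qwen-style reasoning (new branch)
        data["enable_thinking"] = params["enable_thinking"]
        data["temperature"] = params.get("temperature")  # this branch also forwards temperature
    elif params.get("temperature") is not None:
        data["temperature"] = params["temperature"]
    return data
```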
15 changes: 11 additions & 4 deletions refact-agent/engine/src/scratchpads/chat_passthrough.rs
@@ -220,6 +220,7 @@ impl ScratchpadAbstract for ChatPassthrough {
// drop all reasoning parameters in case of non-reasoning model
sampling_parameters_to_patch.reasoning_effort = None;
sampling_parameters_to_patch.thinking = None;
sampling_parameters_to_patch.enable_thinking = None;
limited_msgs
};

@@ -282,7 +283,6 @@ fn _adapt_for_reasoning_models(
sampling_parameters.reasoning_effort = Some(ReasoningEffort::High);
}
sampling_parameters.temperature = default_temperature;
sampling_parameters.thinking = None;

// NOTE: OpenAI prefer user message over system
messages.into_iter().map(|mut msg| {
@@ -304,12 +304,19 @@ fn _adapt_for_reasoning_models(
"budget_tokens": budget_tokens,
}));
}
sampling_parameters.reasoning_effort = None;
messages
},
"qwen" => {
if supports_boost_reasoning && sampling_parameters.boost_reasoning {
sampling_parameters.enable_thinking = Some(true);
} else {
sampling_parameters.enable_thinking = Some(false);
}
// NOTE: Qwen3 prefers temperature 0.7 in non-thinking mode, but we use the default temperature in both cases
sampling_parameters.temperature = default_temperature.clone();
messages
},
_ => {
sampling_parameters.reasoning_effort = None;
sampling_parameters.thinking = None;
sampling_parameters.temperature = default_temperature.clone();
messages
}
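For a model whose reasoning type is "qwen", the patched sampling parameters reduce to the two cases below. This is an illustrative Python sketch, with the 0.6 default temperature assumed from the configs.py change that follows:

```python
def qwen_sampling(boost_reasoning: bool, supports_boost_reasoning: bool = True,
                  default_temperature: float = 0.6) -> dict:
    # Illustration of the "qwen" arm of _adapt_for_reasoning_models.
    return {
        "enable_thinking": bool(supports_boost_reasoning and boost_reasoning),
        # Qwen3 prefers 0.7 for non-thinking mode, but the default is used in both cases.
        "temperature": default_temperature,
    }

assert qwen_sampling(True) == {"enable_thinking": True, "temperature": 0.6}
assert qwen_sampling(False) == {"enable_thinking": False, "temperature": 0.6}
```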
2 changes: 1 addition & 1 deletion refact-server/refact_utils/third_party/utils/configs.py
@@ -68,7 +68,7 @@ def to_chat_model_record(self) -> Dict[str, Any]:
"supports_agent": self.capabilities.agent,
"supports_reasoning": self.capabilities.reasoning,
"supports_boost_reasoning": self.capabilities.boost_reasoning,
"default_temperature": 0.6 if self.capabilities.reasoning == "deepseek" else None,
"default_temperature": 0.6 if self.capabilities.reasoning in ["deepseek", "qwen"] else None,
}


@@ -96,6 +96,7 @@ class ChatContext(NlpSamplingParams):
n: int = 1
reasoning_effort: Optional[str] = None # OpenAI style reasoning
thinking: Optional[Dict] = None # Anthropic style reasoning
enable_thinking: Optional[bool] = None # Qwen style reasoning


class EmbeddingsStyleOpenAI(BaseModel):
@@ -569,6 +570,7 @@ def _wrap_output(output: str) -> str:
"stop": post.stop if post.stop else None,
"n": post.n,
"extra_headers": model_config.extra_headers if model_config.extra_headers else None,
"timeout": 60 * 60, # An hour timeout for thinking models
}

if post.reasoning_effort or post.thinking:
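For illustration, a chat request that exercises the new ChatContext field might look roughly like this. The model name and values are hypothetical, and only fields relevant to this diff are shown:

```python
# Hypothetical chat request body; "enable_thinking" maps to the new ChatContext field.
chat_request = {
    "model": "qwen3-32b",  # hypothetical model name
    "messages": [{"role": "user", "content": "Summarize this diff."}],
    "enable_thinking": True,  # Qwen-style reasoning toggle
    "temperature": 0.6,       # default for qwen/deepseek reasoning models
    "stream": True,
}
```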
31 changes: 29 additions & 2 deletions refact-server/refact_webgui/webgui/selfhost_model_assigner.py
@@ -17,8 +17,10 @@
__all__ = ["ModelAssigner"]


ALLOWED_N_CTX = [2 ** p for p in range(10, 20)]
# ALLOWED_N_CTX = [2 ** p for p in range(10, 20)]
ALLOWED_N_CTX = [1024, 2048, 4096] + [8192 * (t + 1) for t in range(0, 16)]
Contributor:
Are we reducing the max allowed n_ctx from 2^20 to 2^17? There are already models with around 200,000 context, and 1,000,000-context models exist as well.

Contributor Author:
I'll increase it, but in fact we don't have models with context > 40k; large context requires some hacks.
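For reference, the new expression steps in increments of 8192 and tops out at 131072 (2^17), which is the reduction the comment above refers to:

```python
allowed_n_ctx = [1024, 2048, 4096] + [8192 * (t + 1) for t in range(0, 16)]
print(len(allowed_n_ctx), allowed_n_ctx[:5], allowed_n_ctx[-1])
# 19 [1024, 2048, 4096, 8192, 16384] 131072
```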

ALLOWED_GPUS_SHARD = [2 ** p for p in range(10)]
ALLOWED_CONCURRENCY = [2 ** p for p in range(9)]


def has_context_switch(filter_caps: List[str]) -> bool:
@@ -55,6 +57,7 @@ class ModelWatchdogDConfig:
share_gpu: bool
n_ctx: Optional[int] = None
has_loras: bool = False
concurrency: Optional[int] = None

def dump(self, model_cfg_j: Dict) -> str:
model_cfg_j["command_line"].extend(["--model", self.model_name])
@@ -63,6 +66,8 @@ def dump(self, model_cfg_j: Dict) -> str:
model_cfg_j["command_line"].extend(["--n-ctx", self.n_ctx])
if not self.has_loras:
model_cfg_j["command_line"].append("--loraless")
if self.concurrency:
model_cfg_j["command_line"].extend(["--concurrency", self.concurrency])

model_cfg_j["gpus"] = self.gpus
model_cfg_j["share_gpu"] = self.share_gpu
@@ -103,6 +108,10 @@ def shard_gpu_backends(self) -> Set[str]:
def share_gpu_backends(self) -> Set[str]:
return {"transformers"}

@property
def concurrency_backends(self) -> Set[str]:
return set()

@property
def models_db(self) -> Dict[str, Any]:
return models_mini_db
@@ -218,6 +227,7 @@ def _model_inference_setup(self, inference_config: Dict[str, Any]) -> Dict[str,
share_gpu=assignment.get("share_gpu", False),
n_ctx=assignment.get("n_ctx", None),
has_loras=self._has_loras(model_name),
concurrency=assignment.get("concurrency", None),
))
continue
for model_cursor in range(cursor, next_cursor, assignment["gpus_shard"]):
@@ -228,6 +238,7 @@ def _model_inference_setup(self, inference_config: Dict[str, Any]) -> Dict[str,
share_gpu=assignment.get("share_gpu", False),
n_ctx=assignment.get("n_ctx", None),
has_loras=self._has_loras(model_name),
concurrency=assignment.get("concurrency", None),
))
for _ in range(model_group.gpus_shard()):
if gpus[cursor]["mem_total_mb"] < model_group.required_memory_mb(self.models_db):
@@ -327,6 +338,13 @@ def models_info(self):
gpus_shard for gpus_shard in ALLOWED_GPUS_SHARD
if gpus_shard <= max_available_shards
]
if rec["backend"] in self.concurrency_backends:
default_concurrency = ALLOWED_CONCURRENCY[-1]
available_concurrency = ALLOWED_CONCURRENCY
else:
default_concurrency = 0
available_concurrency = []

info.append({
"name": k,
"backend": rec["backend"],
@@ -340,6 +358,8 @@ def models_info(self):
"default_n_ctx": default_n_ctx,
"available_n_ctx": available_n_ctx,
"available_shards": available_shards,
"default_concurrency": default_concurrency,
"available_concurrency": available_concurrency,
"is_deprecated": bool(rec.get("deprecated", False)),
"repo_status": self._models_repo_status[k],
"repo_url": f"https://huggingface.co/{rec['model_path']}",
@@ -367,8 +387,15 @@ def _set_n_ctx(model: str, record: Dict) -> Dict:
record["n_ctx"] = n_ctx
return record

def _set_concurrency(model: str, record: Dict) -> Dict:
if self.models_db[model]["backend"] in self.concurrency_backends:
record["concurrency"] = record.get("concurrency", ALLOWED_CONCURRENCY[-1])
else:
record["concurrency"] = 0
return record

j["model_assign"] = self._share_gpu_filter({
model: _set_n_ctx(model, v)
model: _set_concurrency(model, _set_n_ctx(model, v))
for model, v in j["model_assign"].items()
if model in self.models_db
})
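A simplified Python sketch of how a per-model concurrency setting reaches the watchdog command line, condensed from ModelWatchdogDConfig.dump above; the example values are hypothetical:

```python
def dump_command_line(model_name: str, n_ctx=None, has_loras: bool = False, concurrency=None) -> list:
    # Condensed sketch of ModelWatchdogDConfig.dump: assemble the watchdog command line.
    cmd = ["--model", model_name]
    if n_ctx:
        cmd.extend(["--n-ctx", str(n_ctx)])
    if not has_loras:
        cmd.append("--loraless")
    if concurrency:
        cmd.extend(["--concurrency", str(concurrency)])  # new flag added in this PR
    return cmd

# Hypothetical example:
# dump_command_line("qwen3-32b", n_ctx=32768, concurrency=64)
# -> ['--model', 'qwen3-32b', '--n-ctx', '32768', '--loraless', '--concurrency', '64']
```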
@@ -11,6 +11,7 @@ <h3>Hosted Models</h3>
<tr>
<th>Model</th>
<th>Context</th>
<th>Concurrency</th>
<th>Finetune</th>
<th>Sharding</th>
<th>Share GPU</th>
33 changes: 29 additions & 4 deletions refact-server/refact_webgui/webgui/static/tab-model-hosting.js
@@ -262,13 +262,15 @@ function render_models_assigned(models) {
const row = document.createElement('tr');
row.setAttribute('data-model',index);
let model_name = document.createElement("td");
model_name.style.width = "20%";
model_name.style.width = "18%";
let context = document.createElement("td");
context.style.width = "15%";
context.style.width = "12%";
let concurrency = document.createElement("td");
concurrency.style.width = "12%";
let finetune_info = document.createElement("td");
finetune_info.style.width = "35%";
finetune_info.style.width = "30%";
let select_gpus = document.createElement("td");
select_gpus.style.width = "15%";
select_gpus.style.width = "13%";
let gpus_share = document.createElement("td");
gpus_share.style.width = "10%";
let del = document.createElement("td");
@@ -317,6 +319,27 @@ function render_models_assigned(models) {
context.innerHTML = `<span class="default-context">${models_info[index].default_n_ctx}</span>`;
}

const model_concurrency = models_data.model_assign[index].concurrency || models_info[index].default_concurrency;
if (models_info[index].available_concurrency && models_info[index].available_concurrency.length > 0) {
const concurrency_options = models_info[index].available_concurrency;
const concurrency_input = document.createElement("select");
concurrency_input.classList.add('form-select','form-select-sm');
concurrency_options.forEach(element => {
const concurrency_option = document.createElement("option");
concurrency_option.setAttribute('value',element);
concurrency_option.textContent = element;
if(element === model_concurrency) {
concurrency_option.setAttribute('selected','selected');
}
concurrency_input.appendChild(concurrency_option);
});
concurrency_input.addEventListener('change', function() {
models_data.model_assign[index].concurrency = Number(this.value);
save_model_assigned();
});
concurrency.appendChild(concurrency_input);
}

let finetune_runs = [];
if (finetune_configs_and_runs) {
finetune_runs = finetune_configs_and_runs.finetune_runs.filter(
@@ -397,6 +420,7 @@ function render_models_assigned(models) {

row.appendChild(model_name);
row.appendChild(context);
row.appendChild(concurrency);
row.appendChild(finetune_info);
row.appendChild(select_gpus);
row.appendChild(gpus_share);
@@ -680,6 +704,7 @@ function render_models(models) {
models_data.model_assign[model_name] = {
gpus_shard: default_gpus_shard,
n_ctx: element.default_n_ctx,
concurrency: element.default_concurrency,
};
save_model_assigned();
add_model_modal.hide();
@@ -654,6 +654,7 @@ function showAddModelModal(providerId) {
<option value="">None</option>
<option value="openai">OpenAI</option>
<option value="anthropic">Anthropic</option>
<option value="qwen">Qwen</option>
<option value="deepseek">DeepSeek</option>
</select>
<div class="form-text">Select the reasoning type supported by this model.</div>
7 changes: 6 additions & 1 deletion refact-server/refact_webgui/webgui/tab_models_host.py
@@ -10,7 +10,7 @@
from refact_webgui.webgui.tab_loras import rm
from refact_webgui.webgui.tab_loras import unpack
from refact_webgui.webgui.tab_loras import write_to_file
from refact_webgui.webgui.selfhost_model_assigner import ModelAssigner
from refact_webgui.webgui.selfhost_model_assigner import ModelAssigner, ALLOWED_CONCURRENCY

from pathlib import Path
from pydantic import BaseModel
@@ -38,6 +38,7 @@ class TabHostModelRec(BaseModel):
gpus_shard: int = Query(default=1, ge=0, le=1024)
share_gpu: bool = False
n_ctx: Optional[int] = None
concurrency: Optional[int] = None


class TabHostModelsAssign(BaseModel):
@@ -111,11 +112,15 @@ async def _tab_host_models_assign(self, post: TabHostModelsAssign):
for model_name, model_cfg in post.model_assign.items():
if model_cfg.n_ctx is None:
raise HTTPException(status_code=400, detail=f"n_ctx must be set for {model_name}")
if model_cfg.concurrency is None:
raise HTTPException(status_code=400, detail=f"concurrency must be set for {model_name}")
for model_info in self._model_assigner.models_info["models"]:
if model_info["name"] == model_name:
max_n_ctx = model_info["default_n_ctx"]
if model_cfg.n_ctx > max_n_ctx:
raise HTTPException(status_code=400, detail=f"n_ctx must be less or equal to {max_n_ctx} for {model_name}")
if model_cfg.concurrency and model_cfg.concurrency not in model_info["available_concurrency"]:
raise HTTPException(status_code=400, detail=f"concurrency must be one of {model_info['available_concurrency']} for {model_name}")
break
else:
raise HTTPException(status_code=400, detail=f"model {model_name} not found")
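For completeness, a hypothetical model-assignment request body accepted by the updated handler; concurrency must now be set, and for backends that support it the value must appear in the model's available_concurrency list:

```python
# Hypothetical request body (field names from TabHostModelsAssign / TabHostModelRec).
model_assign_post = {
    "model_assign": {
        "qwen3-32b": {          # hypothetical model name
            "gpus_shard": 1,
            "share_gpu": False,
            "n_ctx": 32768,
            "concurrency": 256,  # must be one of the model's available_concurrency values
        },
    },
}
```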