4 changes: 4 additions & 0 deletions refact-agent/engine/src/call_validation.rs
@@ -39,6 +39,10 @@ impl Default for ReasoningEffort {
}
}

impl ReasoningEffort {
pub fn to_string(&self) -> String { format!("{:?}", self).to_lowercase() }
}

#[derive(Debug, Serialize, Deserialize, Clone, Default)]
pub struct SamplingParameters {
#[serde(default)]
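A quick sanity check of the lowercasing helper added above — a minimal, self-contained sketch, assuming the enum carries Low/Medium/High variants (only "medium" and "high" appear in the KNOWN_MODELS JSON below; the real definition lives earlier in call_validation.rs):

```rust
// Minimal sketch, not the crate's actual definition: assumes ReasoningEffort
// has Low/Medium/High variants as implied by the known_models entries below.
#[derive(Debug, Clone, PartialEq)]
enum ReasoningEffort { Low, Medium, High }

impl ReasoningEffort {
    // Mirrors the helper added above: Debug prints the variant name
    // ("High"), which is then lowercased to the wire format ("high").
    pub fn to_string(&self) -> String { format!("{:?}", self).to_lowercase() }
}

fn main() {
    assert_eq!(ReasoningEffort::High.to_string(), "high");
    assert_eq!(ReasoningEffort::Medium.to_string(), "medium");
}
```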
33 changes: 32 additions & 1 deletion refact-agent/engine/src/caps.rs
@@ -14,6 +14,7 @@ use tracing::{error, info, warn};

use crate::custom_error::ScratchError;
use crate::global_context::{try_load_caps_quickly_if_not_present, GlobalContext};
use crate::call_validation::ReasoningEffort;
use crate::known_models::KNOWN_MODELS;


@@ -39,6 +40,14 @@ pub struct ModelRecord {
pub supports_clicks: bool,
#[serde(default)]
pub supports_agent: bool,
#[serde(default)]
pub supports_reasoning: bool,
#[serde(default)]
pub supports_reasoning_effort: Vec<ReasoningEffort>,
#[serde(default)]
pub default_temperature: Option<f32>,
#[serde(default)]
pub inference_model_name: Option<String>,
}

#[derive(Debug, Deserialize)]
@@ -445,7 +454,29 @@ fn _inherit_r1_from_r0(

for (rec_name, rec) in r0.code_chat_models.iter() {
if rec_name == &k_stripped || rec.similar_models.contains(&k_stripped) {
r1.code_chat_models.insert(k.to_string(), rec.clone());
if rec.supports_reasoning_effort.is_empty() {
r1.code_chat_models.insert(k.to_string(), rec.clone());
} else {
// NOTE: expand model list with all supported reasoning efforts
for reasoning_effort in &rec.supports_reasoning_effort {
let mut model_name = k.to_string();
let mut reasoning_rec = rec.clone();
let tokenizer_rewrite_path = if let Some(path) = r1.tokenizer_rewrite_path.get(k) {
path.clone()
} else if let Some(path) = r0.tokenizer_rewrite_path.get(k) {
path.clone()
} else {
k.to_string()
};
if reasoning_effort.clone() != ReasoningEffort::Medium {
model_name = format!("{}-{}", model_name, reasoning_effort.to_string());
}
reasoning_rec.supports_reasoning_effort = vec![reasoning_effort.clone()];
reasoning_rec.inference_model_name = Some(k.to_string());
r1.code_chat_models.insert(model_name.clone(), reasoning_rec);
r1.tokenizer_rewrite_path.insert(model_name, tokenizer_rewrite_path);
}
}
}
}
}
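The net effect of the loop above: one record listing `supports_reasoning_effort: ["medium", "high"]` fans out into one chat-model entry per effort, where Medium keeps the unsuffixed name, every other effort gets a `-<effort>` suffix, and `inference_model_name` points back at the base model. A toy sketch of just the naming rule (model name illustrative):

```rust
// Toy sketch of the fan-out naming rule; "o3-mini" is illustrative.
fn expanded_names(base: &str, efforts: &[&str]) -> Vec<String> {
    efforts
        .iter()
        .map(|effort| {
            if *effort == "medium" {
                base.to_string() // Medium keeps the plain model name
            } else {
                format!("{}-{}", base, effort) // others get a suffix, e.g. "o3-mini-high"
            }
        })
        .collect()
}

fn main() {
    assert_eq!(
        expanded_names("o3-mini", &["medium", "high"]),
        vec!["o3-mini".to_string(), "o3-mini-high".to_string()]
    );
}
```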
22 changes: 13 additions & 9 deletions refact-agent/engine/src/forward_to_openai_endpoint.rs
@@ -45,17 +45,14 @@ pub async fn forward_to_openai_style_endpoint(
if let Some(n) = sampling_parameters.n {
data["n"] = serde_json::Value::from(n);
}
if model_name != "o1-mini" {
data["temperature"] = serde_json::Value::from(sampling_parameters.temperature);
if let Some(reasoning_effort) = sampling_parameters.reasoning_effort.clone() {
// NOTE: reasoning_effort is supported by OpenAI models only
data["reasoning_effort"] = serde_json::Value::String(reasoning_effort.to_string());
data["max_completion_tokens"] = serde_json::Value::from(sampling_parameters.max_new_tokens);
} else {
data["temperature"] = serde_json::Value::from(sampling_parameters.temperature);
data["max_completion_tokens"] = serde_json::Value::from(sampling_parameters.max_new_tokens);
}
if let Some(n) = sampling_parameters.n {
if n > 1 {
data["n"] = serde_json::Value::from(n);
}
}
info!("NOT STREAMING TEMP {}", sampling_parameters.temperature
.map(|x| x.to_string())
.unwrap_or("None".to_string()));
@@ -122,8 +119,6 @@ pub async fn forward_to_openai_style_endpoint_streaming(
let mut data = json!({
"model": model_name,
"stream": true,
"temperature": sampling_parameters.temperature,
"max_completion_tokens": sampling_parameters.max_new_tokens,
"stream_options": {"include_usage": true},
});
if !sampling_parameters.stop.is_empty() { // openai does not like empty stop
@@ -132,6 +127,14 @@
if let Some(n) = sampling_parameters.n {
data["n"] = serde_json::Value::from(n);
}
if let Some(reasoning_effort) = sampling_parameters.reasoning_effort.clone() {
// NOTE: reasoning_effort is supported by OpenAI models only
data["reasoning_effort"] = serde_json::Value::String(reasoning_effort.to_string());
data["max_completion_tokens"] = serde_json::Value::from(sampling_parameters.max_new_tokens);
} else {
data["temperature"] = serde_json::Value::from(sampling_parameters.temperature);
data["max_completion_tokens"] = serde_json::Value::from(sampling_parameters.max_new_tokens);
}
info!("STREAMING TEMP {}", sampling_parameters.temperature
.map(|x| x.to_string())
.unwrap_or("None".to_string()));
@@ -152,6 +155,7 @@
Ok(event_source)
}

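Both the streaming and non-streaming paths now build one of two payload shapes: reasoning models get `reasoning_effort` and no `temperature`, everything else keeps `temperature`, and both use `max_completion_tokens`. A sketch with illustrative values, not captured traffic:

```rust
use serde_json::json;

fn main() {
    // Reasoning branch: effort is sent, temperature is omitted entirely.
    let reasoning_body = json!({
        "model": "o3-mini",
        "stream": true,
        "stream_options": {"include_usage": true},
        "reasoning_effort": "high",
        "max_completion_tokens": 4096
    });
    // Default branch: temperature as before.
    let default_body = json!({
        "model": "gpt-4o",
        "stream": true,
        "stream_options": {"include_usage": true},
        "temperature": 0.2,
        "max_completion_tokens": 4096
    });
    assert!(reasoning_body.get("temperature").is_none());
    assert!(default_body.get("reasoning_effort").is_none());
}
```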
// NOTE: questionable function, no idea why we need it
fn passthrough_messages_to_json(
data: &mut serde_json::Value,
prompt: &str,
6 changes: 6 additions & 0 deletions refact-agent/engine/src/known_models.rs
@@ -379,6 +379,7 @@ pub const KNOWN_MODELS: &str = r####"
"supports_tools": true,
"supports_multimodality": true,
"supports_reasoning": true,
"supports_reasoning_effort": ["medium", "high"],
"supports_scratchpads": {
"PASSTHROUGH": {
}
@@ -388,6 +389,7 @@
"n_ctx": 128000,
"supports_tools": true,
"supports_reasoning": true,
"supports_reasoning_effort": ["medium", "high"],
"supports_scratchpads": {
"PASSTHROUGH": {
}
@@ -399,6 +401,7 @@
"supports_multimodality": false,
"supports_agent": true,
"supports_reasoning": true,
"supports_reasoning_effort": ["medium", "high"],
"supports_scratchpads": {
"PASSTHROUGH": {
}
@@ -726,6 +729,7 @@ pub const KNOWN_MODELS: &str = r####"
"supports_tools": false,
"supports_multimodality": false,
"supports_reasoning": true,
"default_temperature": 0.6,
"supports_scratchpads": {
"PASSTHROUGH": {}
}
@@ -760,6 +764,8 @@
},
"deepseek-r1-distill/1.5b/vllm": {
"n_ctx": 32768,
"supports_reasoning": true,
"default_temperature": 0.6,
"supports_scratchpads": {
"PASSTHROUGH": {}
},
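Because every new `ModelRecord` field is marked `#[serde(default)]`, older KNOWN_MODELS entries that omit them still parse. A self-contained sketch with a simplified record; lowercase deserialization of the effort names is an assumption that matches the JSON above:

```rust
use serde::Deserialize;

// Simplified stand-ins for the real types.
#[derive(Debug, Deserialize, Clone, PartialEq)]
#[serde(rename_all = "lowercase")]
enum ReasoningEffort { Low, Medium, High }

#[derive(Debug, Deserialize)]
struct Rec {
    #[serde(default)]
    supports_reasoning: bool,
    #[serde(default)]
    supports_reasoning_effort: Vec<ReasoningEffort>,
    #[serde(default)]
    default_temperature: Option<f32>,
}

fn main() {
    let new_style: Rec = serde_json::from_str(
        r#"{"supports_reasoning": true, "supports_reasoning_effort": ["medium", "high"]}"#,
    ).unwrap();
    let old_style: Rec = serde_json::from_str("{}").unwrap(); // pre-existing entry
    assert!(new_style.supports_reasoning);
    assert_eq!(new_style.supports_reasoning_effort.len(), 2);
    assert!(old_style.supports_reasoning_effort.is_empty());
    assert_eq!(old_style.default_temperature, None);
}
```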
25 changes: 15 additions & 10 deletions refact-agent/engine/src/restream.rs
@@ -23,26 +23,27 @@ async fn _get_endpoint_and_stuff_from_model_name(
gcx: Arc<ARwLock<crate::global_context::GlobalContext>>,
caps: Arc<StdRwLock<crate::caps::CodeAssistantCaps>>,
model_name: String,
) -> (String, String, String, String)
) -> (String, String, String, String, String)
{
let (
custom_apikey,
mut endpoint_style,
custom_endpoint_style,
mut endpoint_template,
custom_endpoint_template,
endpoint_chat_passthrough
endpoint_chat_passthrough,
inference_model_name,
) = {
let caps_locked = caps.read().unwrap();
let is_chat = caps_locked.code_chat_models.contains_key(&model_name);
if is_chat {
if let Some(model_record) = caps_locked.code_chat_models.get(&model_name) {
(
caps_locked.chat_apikey.clone(),
caps_locked.endpoint_style.clone(), // abstract
caps_locked.chat_endpoint_style.clone(), // chat-specific
caps_locked.endpoint_template.clone(), // abstract
caps_locked.chat_endpoint.clone(), // chat-specific
caps_locked.endpoint_chat_passthrough.clone(),
model_record.clone().inference_model_name.unwrap_or(model_name.clone())
)
} else {
(
@@ -52,6 +53,7 @@
caps_locked.endpoint_template.clone(), // abstract
caps_locked.completion_endpoint.clone(), // completion-specific
"".to_string(),
model_name.clone(),
)
}
};
@@ -62,11 +64,12 @@
if !custom_endpoint_template.is_empty() {
endpoint_template = custom_endpoint_template;
}
return (
(
api_key,
endpoint_template,
endpoint_style,
endpoint_chat_passthrough,
inference_model_name,
)
}

@@ -98,19 +101,20 @@ pub async fn scratchpad_interaction_not_stream_json(
endpoint_template,
endpoint_style,
endpoint_chat_passthrough,
inference_model_name,
) = _get_endpoint_and_stuff_from_model_name(gcx.clone(), caps.clone(), model_name.clone()).await;

let mut save_url: String = String::new();
let _ = slowdown_arc.acquire().await;
let metadata_supported = crate::global_context::is_metadata_supported(gcx.clone()).await;
let mut model_says = if only_deterministic_messages {
save_url = "only-det-messages".to_string();
Ok(serde_json::Value::Object(serde_json::Map::new()))
Ok(Value::Object(serde_json::Map::new()))
} else if endpoint_style == "hf" {
crate::forward_to_hf_endpoint::forward_to_hf_style_endpoint(
&mut save_url,
bearer.clone(),
&model_name,
&inference_model_name,
&prompt,
&client,
&endpoint_template,
@@ -121,7 +125,7 @@
crate::forward_to_openai_endpoint::forward_to_openai_style_endpoint(
&mut save_url,
bearer.clone(),
&model_name,
&inference_model_name,
&prompt,
&client,
&endpoint_template,
@@ -323,6 +327,7 @@ pub async fn scratchpad_interaction_stream(
endpoint_template,
endpoint_style,
endpoint_chat_passthrough,
inference_model_name,
) = _get_endpoint_and_stuff_from_model_name(gcx.clone(), caps.clone(), model_name.clone()).await;

let t0 = std::time::Instant::now();
@@ -401,7 +406,7 @@
crate::forward_to_hf_endpoint::forward_to_hf_style_endpoint_streaming(
&mut save_url,
bearer.clone(),
&model_name,
&inference_model_name,
prompt.as_str(),
&client,
&endpoint_template,
@@ -412,7 +417,7 @@
crate::forward_to_openai_endpoint::forward_to_openai_style_endpoint_streaming(
&mut save_url,
bearer.clone(),
&model_name,
&inference_model_name,
prompt.as_str(),
&client,
&endpoint_template,
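The extra tuple element exists because the user-facing model name may now carry an effort suffix (e.g. `o3-mini-high`) while the upstream endpoint must receive the base name; `inference_model_name` carries that base name and falls back to the requested name. A sketch of the resolution rule, with the record simplified:

```rust
// Simplified stand-in for caps::ModelRecord, just for the fallback rule.
struct ModelRecord {
    inference_model_name: Option<String>,
}

fn resolve_upstream_name(requested: &str, rec: &ModelRecord) -> String {
    rec.inference_model_name
        .clone()
        .unwrap_or_else(|| requested.to_string())
}

fn main() {
    let expanded = ModelRecord { inference_model_name: Some("o3-mini".into()) };
    let plain = ModelRecord { inference_model_name: None };
    assert_eq!(resolve_upstream_name("o3-mini-high", &expanded), "o3-mini");
    assert_eq!(resolve_upstream_name("gpt-4o", &plain), "gpt-4o");
}
```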
59 changes: 39 additions & 20 deletions refact-agent/engine/src/scratchpads/chat_passthrough.rs
@@ -10,7 +10,8 @@ use tracing::{error, info};

use crate::at_commands::execute_at::{run_at_commands_locally, run_at_commands_remotely};
use crate::at_commands::at_commands::AtCommandsContext;
use crate::call_validation::{ChatContent, ChatMessage, ChatPost, SamplingParameters};
use crate::call_validation::{ChatContent, ChatMessage, ChatPost, ReasoningEffort, SamplingParameters};
use crate::caps::ModelRecord;
use crate::http::http_get_json;
use crate::integrations::docker::docker_container_manager::docker_container_get_host_lsp_port_to_connect;
use crate::scratchpad_abstract::{FinishReason, HasTokenizerAndEot, ScratchpadAbstract};
@@ -145,9 +146,24 @@ impl ScratchpadAbstract for ChatPassthrough {
);
_remove_invalid_tool_calls_and_tool_calls_results(&mut messages);

// Handle models that support reasoning
let messages = if model_supports_reasoning(&self.post.model) {
_adapt_for_reasoning_models(&messages, sampling_parameters_to_patch)
let caps = {
let gcx_locked = gcx.write().await;
gcx_locked.caps.clone().unwrap()
};
let model_record_mb = {
let caps_locked = caps.read().unwrap();
caps_locked.code_chat_models.get(&self.post.model).cloned()
};

let (supports_reasoning, default_temperature, default_reasoning_effort) =
_model_reasoning_params(model_record_mb);
let messages = if supports_reasoning {
_adapt_for_reasoning_models(
&messages,
sampling_parameters_to_patch,
default_temperature,
default_reasoning_effort,
)
} else {
messages
};
@@ -157,7 +173,7 @@
vec![]
});

if self.prepend_system_prompt && !model_supports_reasoning(&self.post.model) {
if self.prepend_system_prompt && !supports_reasoning {
assert_eq!(limited_msgs.first().unwrap().role, "system");
}
let converted_messages = convert_messages_to_openai_format(limited_msgs, &style);
@@ -340,29 +356,32 @@ fn _replace_broken_tool_call_messages(
}
}

pub fn model_supports_reasoning(model_name: &str) -> bool {
let known_models: serde_json::Value = serde_json::from_str(crate::known_models::KNOWN_MODELS)
.expect("Failed to parse KNOWN_MODELS");

// Check if the model exists in code_chat_models and has supports_reasoning set to true
if let Some(chat_models) = known_models.get("code_chat_models") {
if let Some(model) = chat_models.get(model_name) {
return model.get("supports_reasoning")
.and_then(|v| v.as_bool())
.unwrap_or(false);
}
fn _model_reasoning_params(
model_record_mb: Option<ModelRecord>,
) -> (bool, Option<f32>, Option<ReasoningEffort>) {
let mut support_reasoning: bool = false;
let mut temperature: Option<f32> = None;
let mut reasoning_effort: Option<ReasoningEffort> = None;

if let Some(model_record) = model_record_mb {
support_reasoning = model_record.supports_reasoning.clone();
temperature = model_record.default_temperature.clone();
reasoning_effort = model_record.supports_reasoning_effort.first().cloned();
}

// If model is not found or doesn't have supports_reasoning field, return false
false
(support_reasoning, temperature, reasoning_effort)
}

fn _adapt_for_reasoning_models(
messages: &Vec<ChatMessage>,
sampling_parameters: &mut SamplingParameters,
default_temperature: Option<f32>,
default_reasoning_effort: Option<ReasoningEffort>,
) -> Vec<ChatMessage> {
// Set temperature to None
sampling_parameters.temperature = None;
sampling_parameters.temperature = default_temperature.clone();
if sampling_parameters.reasoning_effort.is_none() {
sampling_parameters.reasoning_effort = default_reasoning_effort.clone();
}

// Convert system messages to user messages
messages.iter().map(|msg| {
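The diff is truncated just as the message conversion starts; per the comment above it, system messages are rewritten as user messages for reasoning models. A minimal sketch of what that mapping plausibly looks like, with `ChatMessage` simplified to role plus content:

```rust
#[derive(Clone, Debug, PartialEq)]
struct ChatMessage {
    role: String,
    content: String,
}

// Sketch only: rewrites "system" turns as "user" turns, leaving the rest intact.
fn adapt_roles(messages: &[ChatMessage]) -> Vec<ChatMessage> {
    messages
        .iter()
        .map(|msg| {
            let mut adapted = msg.clone();
            if adapted.role == "system" {
                adapted.role = "user".to_string();
            }
            adapted
        })
        .collect()
}

fn main() {
    let msgs = vec![ChatMessage { role: "system".into(), content: "You are helpful.".into() }];
    assert_eq!(adapt_roles(&msgs)[0].role, "user");
}
```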
11 changes: 11 additions & 0 deletions refact-server/refact_known_models/passthrough.py
@@ -68,6 +68,17 @@
"pp1000t_generated": 12_000, # $12.00 / 1M tokens (2025 january)
"filter_caps": ["chat", "tools"],
},
"o3-mini": {
"backend": "litellm",
"provider": "openai",
"tokenizer_path": "Xenova/gpt-4o",
"resolve_as": "o3-mini-2025-01-31",
"T": 200_000,
"T_out": 64_000,
"pp1000t_prompt": 1_100, # $1.10 / 1M tokens (2025 january)
"pp1000t_generated": 4_400, # $4.40 / 1M tokens (2025 january)
"filter_caps": ["chat", "tools"],
},

# Anthropic models
"claude-3-5-sonnet": {