diff --git a/.gitignore b/.gitignore index 03622ad00..8c897a095 100644 --- a/.gitignore +++ b/.gitignore @@ -294,3 +294,5 @@ dist .vite # Refact binary/symlink **/refact/bin/refact-lsp + +.refact_knowledge*/ diff --git a/refact-agent/engine/src/global_context.rs b/refact-agent/engine/src/global_context.rs index c3acd17be..dc54e324d 100644 --- a/refact-agent/engine/src/global_context.rs +++ b/refact-agent/engine/src/global_context.rs @@ -178,6 +178,7 @@ pub struct GlobalContext { pub init_shadow_repos_lock: Arc>, pub git_operations_abort_flag: Arc, pub app_searchable_id: String, + pub trajectory_events_tx: Option>, } pub type SharedGlobalContext = Arc>; // TODO: remove this type alias, confusing @@ -426,6 +427,7 @@ pub async fn create_global_context( init_shadow_repos_lock: Arc::new(AMutex::new(false)), git_operations_abort_flag: Arc::new(AtomicBool::new(false)), app_searchable_id: get_app_searchable_id(&workspace_dirs), + trajectory_events_tx: Some(tokio::sync::broadcast::channel(100).0), }; let gcx = Arc::new(ARwLock::new(cx)); crate::files_in_workspace::watcher_init(gcx.clone()).await; diff --git a/refact-agent/engine/src/http/routers/v1.rs b/refact-agent/engine/src/http/routers/v1.rs index 18ea86c60..af1df48e6 100644 --- a/refact-agent/engine/src/http/routers/v1.rs +++ b/refact-agent/engine/src/http/routers/v1.rs @@ -1,6 +1,6 @@ use at_tools::handle_v1_post_tools; use axum::Router; -use axum::routing::{get, post, delete}; +use axum::routing::{get, post, put, delete}; use tower_http::cors::CorsLayer; use crate::http::utils::telemetry_middleware; @@ -13,7 +13,6 @@ use crate::http::routers::v1::caps::handle_v1_caps; use crate::http::routers::v1::caps::handle_v1_ping; use crate::http::routers::v1::chat::{handle_v1_chat, handle_v1_chat_completions}; use crate::http::routers::v1::chat_based_handlers::{handle_v1_commit_message_from_diff, handle_v1_trajectory_compress}; -use crate::http::routers::v1::chat_based_handlers::handle_v1_trajectory_save; use crate::http::routers::v1::dashboard::get_dashboard_plots; use crate::http::routers::v1::docker::{handle_v1_docker_container_action, handle_v1_docker_container_list}; use crate::http::routers::v1::git::{handle_v1_git_commit, handle_v1_checkpoints_preview, handle_v1_checkpoints_restore}; @@ -40,6 +39,11 @@ use crate::http::routers::v1::v1_integrations::{handle_v1_integration_get, handl use crate::http::routers::v1::file_edit_tools::handle_v1_file_edit_tool_dry_run; use crate::http::routers::v1::code_edit::handle_v1_code_edit; use crate::http::routers::v1::workspace::{handle_v1_get_app_searchable_id, handle_v1_set_active_group_id}; +use crate::http::routers::v1::trajectories::{ + handle_v1_trajectories_list, handle_v1_trajectories_get, + handle_v1_trajectories_save, handle_v1_trajectories_delete, + handle_v1_trajectories_subscribe, +}; mod ast; pub mod at_commands; @@ -71,6 +75,7 @@ mod v1_integrations; pub mod vecdb; mod workspace; mod knowledge_graph; +pub mod trajectories; pub fn make_v1_router() -> Router { let builder = Router::new() @@ -171,8 +176,12 @@ pub fn make_v1_router() -> Router { .route("/vdb-search", post(handle_v1_vecdb_search)) .route("/vdb-status", get(handle_v1_vecdb_status)) .route("/knowledge-graph", get(handle_v1_knowledge_graph)) - .route("/trajectory-save", post(handle_v1_trajectory_save)) .route("/trajectory-compress", post(handle_v1_trajectory_compress)) + .route("/trajectories", get(handle_v1_trajectories_list)) + .route("/trajectories/subscribe", get(handle_v1_trajectories_subscribe)) + .route("/trajectories/:id", get(handle_v1_trajectories_get)) + .route("/trajectories/:id", put(handle_v1_trajectories_save)) + .route("/trajectories/:id", delete(handle_v1_trajectories_delete)) ; builder diff --git a/refact-agent/engine/src/http/routers/v1/chat_based_handlers.rs b/refact-agent/engine/src/http/routers/v1/chat_based_handlers.rs index 4b2a91371..59d561783 100644 --- a/refact-agent/engine/src/http/routers/v1/chat_based_handlers.rs +++ b/refact-agent/engine/src/http/routers/v1/chat_based_handlers.rs @@ -74,28 +74,4 @@ pub async fn handle_v1_trajectory_compress( } -pub async fn handle_v1_trajectory_save( - Extension(gcx): Extension>>, - body_bytes: hyper::body::Bytes, -) -> axum::response::Result, ScratchError> { - let post = serde_json::from_slice::(&body_bytes).map_err(|e| { - ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, format!("JSON problem: {}", e)) - })?; - let trajectory = compress_trajectory(gcx.clone(), &post.messages) - .await.map_err(|e| ScratchError::new(StatusCode::UNPROCESSABLE_ENTITY, e))?; - - let file_path = crate::memories::save_trajectory(gcx, &trajectory) - .await.map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?; - - let response = serde_json::json!({ - "trajectory": trajectory, - "file_path": file_path.to_string_lossy(), - }); - - Ok(Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(Body::from(serde_json::to_string(&response).unwrap())) - .unwrap()) -} diff --git a/refact-agent/engine/src/http/routers/v1/trajectories.rs b/refact-agent/engine/src/http/routers/v1/trajectories.rs new file mode 100644 index 000000000..1a8d5b3ce --- /dev/null +++ b/refact-agent/engine/src/http/routers/v1/trajectories.rs @@ -0,0 +1,538 @@ +use std::path::PathBuf; +use std::sync::Arc; +use axum::extract::Path; +use axum::http::{Response, StatusCode}; +use axum::Extension; +use hyper::Body; +use serde::{Deserialize, Serialize}; +use tokio::sync::RwLock as ARwLock; +use tokio::sync::Mutex as AMutex; +use tokio::sync::broadcast; +use tokio::fs; +use tracing::{info, warn}; + +use crate::at_commands::at_commands::AtCommandsContext; +use crate::call_validation::ChatMessage; +use crate::custom_error::ScratchError; +use crate::global_context::{GlobalContext, try_load_caps_quickly_if_not_present}; +use crate::files_correction::get_project_dirs; +use crate::subchat::subchat_single; + +const TRAJECTORIES_FOLDER: &str = ".refact/trajectories"; +const TITLE_GENERATION_PROMPT: &str = "Summarize this chat in 2-4 words. Prefer filenames, classes, entities, and avoid generic terms. Write only the title, nothing else."; + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TrajectoryEvent { + #[serde(rename = "type")] + pub event_type: String, + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub title: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TrajectoryMeta { + pub id: String, + pub title: String, + pub created_at: String, + pub updated_at: String, + pub model: String, + pub mode: String, + pub message_count: usize, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TrajectoryData { + pub id: String, + pub title: String, + pub created_at: String, + pub updated_at: String, + pub model: String, + pub mode: String, + pub tool_use: String, + pub messages: Vec, + #[serde(flatten)] + pub extra: serde_json::Map, +} + +async fn get_trajectories_dir(gcx: Arc>) -> Result { + let project_dirs = get_project_dirs(gcx).await; + let workspace_root = project_dirs.first().ok_or("No workspace folder found")?; + Ok(workspace_root.join(TRAJECTORIES_FOLDER)) +} + +fn validate_trajectory_id(id: &str) -> Result<(), ScratchError> { + if id.contains('/') || id.contains('\\') || id.contains("..") || id.contains('\0') { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "Invalid trajectory id".to_string())); + } + Ok(()) +} + +async fn atomic_write_json(path: &PathBuf, data: &impl Serialize) -> Result<(), String> { + let tmp_path = path.with_extension("json.tmp"); + let json = serde_json::to_string_pretty(data).map_err(|e| e.to_string())?; + fs::write(&tmp_path, &json).await.map_err(|e| e.to_string())?; + fs::rename(&tmp_path, path).await.map_err(|e| e.to_string())?; + Ok(()) +} + +fn is_placeholder_title(title: &str) -> bool { + let normalized = title.trim().to_lowercase(); + normalized.is_empty() || normalized == "new chat" || normalized == "untitled" +} + +fn extract_first_user_message(messages: &[serde_json::Value]) -> Option { + for msg in messages { + let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or(""); + if role != "user" { + continue; + } + + // Handle string content + if let Some(content) = msg.get("content").and_then(|c| c.as_str()) { + let trimmed = content.trim(); + if !trimmed.is_empty() { + return Some(trimmed.chars().take(200).collect()); + } + } + + // Handle array content (multimodal) + if let Some(content_arr) = msg.get("content").and_then(|c| c.as_array()) { + for item in content_arr { + if let Some(text) = item.get("text").and_then(|t| t.as_str()) { + let trimmed = text.trim(); + if !trimmed.is_empty() { + return Some(trimmed.chars().take(200).collect()); + } + } + if let Some(text) = item.get("m_content").and_then(|t| t.as_str()) { + let trimmed = text.trim(); + if !trimmed.is_empty() { + return Some(trimmed.chars().take(200).collect()); + } + } + } + } + } + None +} + +fn build_title_generation_context(messages: &[serde_json::Value]) -> String { + let mut context = String::new(); + let max_messages = 6; + let max_chars_per_message = 500; + + for (i, msg) in messages.iter().take(max_messages).enumerate() { + let role = msg.get("role").and_then(|r| r.as_str()).unwrap_or("unknown"); + + // Skip tool messages and context files for title generation + if role == "tool" || role == "context_file" || role == "cd_instruction" { + continue; + } + + let content_text = if let Some(content) = msg.get("content").and_then(|c| c.as_str()) { + content.to_string() + } else if let Some(content_arr) = msg.get("content").and_then(|c| c.as_array()) { + content_arr.iter() + .filter_map(|item| { + item.get("text").and_then(|t| t.as_str()) + .or_else(|| item.get("m_content").and_then(|t| t.as_str())) + }) + .collect::>() + .join(" ") + } else { + continue; + }; + + let truncated: String = content_text.chars().take(max_chars_per_message).collect(); + if !truncated.trim().is_empty() { + context.push_str(&format!("{}: {}\n\n", role, truncated)); + } + + if i >= max_messages - 1 { + break; + } + } + + context +} + +fn clean_generated_title(raw_title: &str) -> String { + let cleaned = raw_title + .trim() + .trim_matches('"') + .trim_matches('\'') + .trim_matches('`') + .trim_matches('*') + .replace('\n', " ") + .split_whitespace() + .collect::>() + .join(" "); + + // Limit to ~60 chars + if cleaned.chars().count() > 60 { + cleaned.chars().take(57).collect::() + "..." + } else { + cleaned + } +} + +async fn generate_title_llm( + gcx: Arc>, + messages: &[serde_json::Value], +) -> Option { + let caps = match try_load_caps_quickly_if_not_present(gcx.clone(), 0).await { + Ok(caps) => caps, + Err(e) => { + warn!("Failed to load caps for title generation: {:?}", e); + return None; + } + }; + + // Use light model if available, otherwise default + let model_id = if !caps.defaults.chat_light_model.is_empty() { + caps.defaults.chat_light_model.clone() + } else { + caps.defaults.chat_default_model.clone() + }; + + if model_id.is_empty() { + warn!("No model available for title generation"); + return None; + } + + let context = build_title_generation_context(messages); + if context.trim().is_empty() { + return None; + } + + let prompt = format!("Chat conversation:\n{}\n\n{}", context, TITLE_GENERATION_PROMPT); + + let ccx = Arc::new(AMutex::new(AtCommandsContext::new( + gcx.clone(), + 2048, + 5, + false, + vec![], + "title-generation".to_string(), + false, + model_id.clone(), + ).await)); + + let chat_messages = vec![ + ChatMessage::new("user".to_string(), prompt), + ]; + + match subchat_single( + ccx, + &model_id, + chat_messages, + Some(vec![]), // No tools + Some("none".to_string()), // No tool choice + false, + Some(0.3), // Low temperature for consistent titles + Some(50), // Max tokens - titles should be short + 1, // n=1 + None, // No reasoning effort + false, // No system prompt + None, // No usage collector + None, // No tool id + None, // No chat id + ).await { + Ok(results) => { + if let Some(messages) = results.first() { + if let Some(last_msg) = messages.last() { + let raw_title = last_msg.content.content_text_only(); + let cleaned = clean_generated_title(&raw_title); + if !cleaned.is_empty() && cleaned.to_lowercase() != "new chat" { + info!("Generated title: {}", cleaned); + return Some(cleaned); + } + } + } + None + } + Err(e) => { + warn!("Title generation failed: {}", e); + None + } + } +} + +async fn spawn_title_generation_task( + gcx: Arc>, + id: String, + messages: Vec, + trajectories_dir: PathBuf, +) { + tokio::spawn(async move { + // Generate title via LLM + let generated_title = generate_title_llm(gcx.clone(), &messages).await; + + let title = match generated_title { + Some(t) => t, + None => { + // Fallback to truncated first user message + match extract_first_user_message(&messages) { + Some(first_msg) => { + let truncated: String = first_msg.chars().take(60).collect(); + if truncated.len() < first_msg.len() { + format!("{}...", truncated.trim_end()) + } else { + truncated + } + } + None => return, // No title to generate + } + } + }; + + // Read current trajectory data + let file_path = trajectories_dir.join(format!("{}.json", id)); + let content = match fs::read_to_string(&file_path).await { + Ok(c) => c, + Err(e) => { + warn!("Failed to read trajectory for title update: {}", e); + return; + } + }; + + let mut data: TrajectoryData = match serde_json::from_str(&content) { + Ok(d) => d, + Err(e) => { + warn!("Failed to parse trajectory for title update: {}", e); + return; + } + }; + + // Update title and mark as generated + data.title = title.clone(); + data.extra.insert("isTitleGenerated".to_string(), serde_json::json!(true)); + + // Write back + if let Err(e) = atomic_write_json(&file_path, &data).await { + warn!("Failed to write trajectory with generated title: {}", e); + return; + } + + info!("Updated trajectory {} with generated title: {}", id, title); + + // Emit SSE event with new title + let event = TrajectoryEvent { + event_type: "updated".to_string(), + id: id.clone(), + updated_at: Some(data.updated_at.clone()), + title: Some(title), + }; + + if let Some(tx) = &gcx.read().await.trajectory_events_tx { + let _ = tx.send(event); + } + }); +} + +pub async fn handle_v1_trajectories_list( + Extension(gcx): Extension>>, +) -> Result, ScratchError> { + let trajectories_dir = get_trajectories_dir(gcx).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?; + + let mut result: Vec = Vec::new(); + + if trajectories_dir.exists() { + let mut entries = fs::read_dir(&trajectories_dir).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + while let Some(entry) = entries.next_entry().await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))? { + let path = entry.path(); + if path.extension().and_then(|e| e.to_str()) != Some("json") { + continue; + } + if let Ok(content) = fs::read_to_string(&path).await { + if let Ok(data) = serde_json::from_str::(&content) { + result.push(TrajectoryMeta { + id: data.id, + title: data.title, + created_at: data.created_at, + updated_at: data.updated_at, + model: data.model, + mode: data.mode, + message_count: data.messages.len(), + }); + } + } + } + } + + result.sort_by(|a, b| b.updated_at.cmp(&a.updated_at)); + + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(Body::from(serde_json::to_string(&result).unwrap())) + .unwrap()) +} + +pub async fn handle_v1_trajectories_get( + Extension(gcx): Extension>>, + Path(id): Path, +) -> Result, ScratchError> { + validate_trajectory_id(&id)?; + + let trajectories_dir = get_trajectories_dir(gcx).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?; + + let file_path = trajectories_dir.join(format!("{}.json", id)); + + if !file_path.exists() { + return Err(ScratchError::new(StatusCode::NOT_FOUND, "Trajectory not found".to_string())); + } + + let content = fs::read_to_string(&file_path).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(Body::from(content)) + .unwrap()) +} + +pub async fn handle_v1_trajectories_save( + Extension(gcx): Extension>>, + Path(id): Path, + body_bytes: hyper::body::Bytes, +) -> Result, ScratchError> { + validate_trajectory_id(&id)?; + + let data: TrajectoryData = serde_json::from_slice(&body_bytes) + .map_err(|e| ScratchError::new(StatusCode::BAD_REQUEST, format!("Invalid JSON: {}", e)))?; + + if data.id != id { + return Err(ScratchError::new(StatusCode::BAD_REQUEST, "ID mismatch".to_string())); + } + + let trajectories_dir = get_trajectories_dir(gcx.clone()).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?; + + fs::create_dir_all(&trajectories_dir).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let file_path = trajectories_dir.join(format!("{}.json", id)); + let is_new = !file_path.exists(); + + // Check if we need to generate a title + let is_title_generated = data.extra.get("isTitleGenerated") + .and_then(|v| v.as_bool()) + .unwrap_or(false); + + let should_generate_title = is_placeholder_title(&data.title) + && !is_title_generated + && !data.messages.is_empty(); + + atomic_write_json(&file_path, &data).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?; + + let event = TrajectoryEvent { + event_type: if is_new { "created".to_string() } else { "updated".to_string() }, + id: id.clone(), + updated_at: Some(data.updated_at.clone()), + title: if is_new { Some(data.title.clone()) } else { None }, + }; + + if let Some(tx) = &gcx.read().await.trajectory_events_tx { + let _ = tx.send(event); + } + + // Spawn async title generation if needed + if should_generate_title { + spawn_title_generation_task( + gcx.clone(), + id.clone(), + data.messages.clone(), + trajectories_dir, + ).await; + } + + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(Body::from(r#"{"status":"ok"}"#)) + .unwrap()) +} + +pub async fn handle_v1_trajectories_delete( + Extension(gcx): Extension>>, + Path(id): Path, +) -> Result, ScratchError> { + validate_trajectory_id(&id)?; + + let trajectories_dir = get_trajectories_dir(gcx.clone()).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e))?; + + let file_path = trajectories_dir.join(format!("{}.json", id)); + + if !file_path.exists() { + return Err(ScratchError::new(StatusCode::NOT_FOUND, "Trajectory not found".to_string())); + } + + fs::remove_file(&file_path).await + .map_err(|e| ScratchError::new(StatusCode::INTERNAL_SERVER_ERROR, e.to_string()))?; + + let event = TrajectoryEvent { + event_type: "deleted".to_string(), + id: id.clone(), + updated_at: None, + title: None, + }; + + if let Some(tx) = &gcx.read().await.trajectory_events_tx { + let _ = tx.send(event); + } + + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(Body::from(r#"{"status":"ok"}"#)) + .unwrap()) +} + +pub async fn handle_v1_trajectories_subscribe( + Extension(gcx): Extension>>, +) -> Result, ScratchError> { + let rx = { + let gcx_locked = gcx.read().await; + match &gcx_locked.trajectory_events_tx { + Some(tx) => tx.subscribe(), + None => return Err(ScratchError::new( + StatusCode::SERVICE_UNAVAILABLE, + "Trajectory events not available".to_string() + )), + } + }; + + let stream = async_stream::stream! { + let mut rx = rx; + loop { + match rx.recv().await { + Ok(event) => { + let json = serde_json::to_string(&event).unwrap_or_default(); + yield Ok::<_, std::convert::Infallible>(format!("data: {}\n\n", json)); + } + Err(broadcast::error::RecvError::Lagged(_)) => continue, + Err(broadcast::error::RecvError::Closed) => break, + } + } + }; + + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "text/event-stream") + .header("Cache-Control", "no-cache") + .header("Connection", "keep-alive") + .body(Body::wrap_stream(stream)) + .unwrap()) +} diff --git a/refact-agent/engine/src/memories.rs b/refact-agent/engine/src/memories.rs index 499933a14..d440f9bfe 100644 --- a/refact-agent/engine/src/memories.rs +++ b/refact-agent/engine/src/memories.rs @@ -252,39 +252,6 @@ async fn memories_search_fallback( Ok(scored_results.into_iter().take(top_n).map(|(_, r)| r).collect()) } -pub async fn save_trajectory( - gcx: Arc>, - compressed_trajectory: &str, -) -> Result { - let knowledge_dir = get_knowledge_dir(gcx.clone()).await?; - let trajectories_dir = knowledge_dir.join("trajectories"); - fs::create_dir_all(&trajectories_dir).await.map_err(|e| format!("Failed to create trajectories dir: {}", e))?; - - let filename = generate_filename(compressed_trajectory); - let file_path = trajectories_dir.join(&filename); - - let frontmatter = create_frontmatter( - compressed_trajectory.lines().next(), - &["trajectory".to_string()], - &[], - &[], - "trajectory", - ); - - let md_content = format!("{}\n\n{}", frontmatter.to_yaml(), compressed_trajectory); - fs::write(&file_path, &md_content).await.map_err(|e| format!("Failed to write trajectory file: {}", e))?; - - info!("Saved trajectory: {}", file_path.display()); - - if let Some(vecdb) = gcx.read().await.vec_db.lock().await.as_ref() { - vecdb.vectorizer_enqueue_files(&vec![file_path.to_string_lossy().to_string()], true).await; - } - - let _ = build_knowledge_graph(gcx).await; - - Ok(file_path) -} - pub async fn deprecate_document( gcx: Arc>, doc_path: &PathBuf, diff --git a/refact-agent/gui/AGENTS.md b/refact-agent/gui/AGENTS.md index 5ce3a62a5..f41466ece 100644 --- a/refact-agent/gui/AGENTS.md +++ b/refact-agent/gui/AGENTS.md @@ -1,7 +1,7 @@ # Refact Agent GUI - Developer Guide **Last Updated**: December 2024 -**Version**: 2.0.10-alpha.3 +**Version**: 2.0.10-alpha.4 **Repository**: https://github.com/smallcloudai/refact/tree/main/refact-agent/gui --- @@ -18,11 +18,12 @@ 8. [API Services](#api-services) 9. [IDE Integration](#ide-integration) 10. [Tool Calling System](#tool-calling-system) -11. [Development Workflows](#development-workflows) -12. [Testing](#testing) -13. [Debugging](#debugging) -14. [Special Features](#special-features) -15. [Common Patterns](#common-patterns) +11. [Multi-Tab Chat & Background Threads](#multi-tab-chat--background-threads) +12. [Development Workflows](#development-workflows) +13. [Testing](#testing) +14. [Debugging](#debugging) +15. [Special Features](#special-features) +16. [Common Patterns](#common-patterns) --- @@ -2421,6 +2422,304 @@ type ToolStatus = --- +## Multi-Tab Chat & Background Threads + +### Thread State Model + +Each chat thread has **two layers of state**: + +| Layer | Type | Storage | Contents | +|-------|------|---------|----------| +| **Thread data** | `ChatThread` | `state.chat.threads[id].thread` | title, messages, model, mode, checkpoints | +| **Runtime** | `ChatThreadRuntime` | `state.chat.threads[id]` | streaming, waiting, queue, confirmation, errors, attached_images | + +**Visibility modes**: +- **Open tab**: `id ∈ state.chat.open_thread_ids` (visible in toolbar) +- **Background runtime**: in `state.chat.threads` but not in `open_thread_ids` + +**Key files**: +- Types: `src/features/Chat/Thread/types.ts` +- Reducers: `src/features/Chat/Thread/reducer.ts` +- Selectors: `src/features/Chat/Thread/selectors.ts` + +### Per-Thread State Machine + +``` +┌─────────┐ user submits ┌─────────┐ first chunk ┌───────────┐ +│ IDLE │ ──────────────► │ WAITING │ ─────────────► │ STREAMING │ +└─────────┘ └─────────┘ └───────────┘ + ▲ │ + │ ┌─────────┐ │ + │◄─────────────────────│ PAUSED │◄─────────────────────┤ + │ user confirms └─────────┘ needs confirmation │ + │ │ + │ ┌─────────┐ │ + └──────────────────────│ STOPPED │◄─────────────────────┘ + doneStreaming └─────────┘ error/abort + (no more tools) +``` + +**State flags per runtime**: +```typescript +{ + streaming: boolean, // Currently receiving chunks + waiting_for_response: boolean, // Request sent, awaiting first chunk + prevent_send: boolean, // Blocked (error, abort, rejection) + error: string | null, // Error message if failed + confirmation: { + pause: boolean, // Waiting for user confirmation + pause_reasons: [], // Why paused (tool names, rules) + status: { + wasInteracted: boolean, // User has interacted with confirmation + confirmationStatus: boolean // Tools are confirmed + } + } +} +``` + +### Complete Chat Flow + +#### 1. User Sends Message +``` +ChatForm.onSubmit + → useSendChatRequest.submit() [hooks/useSendChatRequest.ts] + → if busy: enqueueUserMessage() [actions.ts → reducer.ts] + → else: sendMessages() + → setIsWaitingForResponse({id, true}) [reducer.ts] + → pre-flight confirmation check [toolsApi.checkForConfirmation] + → if pause: setThreadPauseReasons() + return early + → chatAskQuestionThunk() [actions.ts] +``` + +#### 2. Streaming Response +``` +chatAskQuestionThunk [actions.ts] + → sendChat(stream: true) [services/refact/chat.ts] + → for each chunk: dispatch(chatResponse()) [reducer.ts] + → streaming = true + → waiting_for_response = false + → merge chunk into messages (formatChatResponse) + → finally: dispatch(doneStreaming()) [reducer.ts] + → streaming = false + → postProcessMessagesAfterStreaming() +``` + +#### 3. Auto-Continuation (Middleware) +``` +doneStreaming listener [middleware.ts:346-393] + → resetThreadImages (if current thread) + → skip if: error, prevent_send, already paused + → selectHasUncalledToolsById() [selectors.ts] + → if uncalled tools exist: + → checkForConfirmation() [toolsApi] + → if pause needed: + → setThreadPauseReasons() + → auto-switch to thread (if background) + → return + → else: + → setIsWaitingForResponse({id, true}) + → chatAskQuestionThunk() to continue +``` + +#### 4. Tool Confirmation Flow +``` +setThreadPauseReasons [reducer.ts:567-577] + → pause = true + → pause_reasons = [...] + → confirmationStatus = false (blocks autosend) + → streaming = false + → waiting_for_response = false + +Auto-switch listener [middleware.ts:593-605] + → if thread ≠ current: switchToThread() + → switchToThread adds to open_thread_ids [reducer.ts:407-415] + +ChatForm renders ToolConfirmation [ChatForm.tsx:327-330] + when confirmation.pause === true + +User clicks Confirm → confirmToolUsage() [useSendChatRequest.ts:303-308] + → clearThreadPauseReasons() + → setThreadConfirmationStatus(wasInteracted: true) + → sendMessages(currentMessages) to continue +``` + +### Background Thread Handling + +#### Background Continuation (Option B) +Chats continue processing even without an open tab: + +```typescript +// closeThread preserves busy runtimes +builder.addCase(closeThread, (state, action) => { + state.open_thread_ids = state.open_thread_ids.filter(tid => tid !== id); + const rt = state.threads[id]; + // Only delete if safe (not streaming, waiting, or paused) + if (rt && (force || (!rt.streaming && !rt.waiting_for_response && !rt.confirmation.pause))) { + delete state.threads[id]; + } +}); +``` + +#### Auto-Switch on Confirmation +When a background thread needs confirmation, user is auto-switched: + +```typescript +// middleware.ts +startListening({ + actionCreator: setThreadPauseReasons, + effect: (action, listenerApi) => { + const currentThreadId = selectCurrentThreadId(state); + if (action.payload.id !== currentThreadId) { + listenerApi.dispatch(switchToThread({ id: action.payload.id })); + } + }, +}); +``` + +#### Restoring Background Threads +When user clicks a history item that has a background runtime: + +```typescript +// restoreChat adds to open_thread_ids if runtime exists +builder.addCase(restoreChat, (state, action) => { + const existingRt = getRuntime(state, action.payload.id); + if (existingRt) { + if (!state.open_thread_ids.includes(action.payload.id)) { + state.open_thread_ids.push(action.payload.id); + } + state.current_thread_id = action.payload.id; + return; // Don't overwrite existing runtime + } + // ... create new runtime from history +}); +``` + +### SSE Subscription (Metadata Sync) + +Backend sends trajectory updates via Server-Sent Events: + +```typescript +// useTrajectoriesSubscription.ts +eventSource.onmessage = (event) => { + const data: TrajectoryEvent = JSON.parse(event.data); + + if (data.type === "deleted") { + dispatch(deleteChatById(data.id)); + dispatch(closeThread({ id: data.id, force: true })); + } else if (data.type === "updated" || data.type === "created") { + // Fetch full trajectory and update + dispatch(hydrateHistory([trajectory])); + // IMPORTANT: Only sync metadata, NOT messages + dispatch(updateOpenThread({ + id: data.id, + thread: { + title: thread.title, + isTitleGenerated: thread.isTitleGenerated, + // NO messages - they are local-authoritative + }, + })); + } +}; +``` + +**Critical**: Messages are never synced from SSE to prevent overwriting in-progress conversations. + +### useAutoSend Hook + +Handles automatic continuation and queue flushing: + +```typescript +// useSendChatRequest.ts:351-462 +const stopForToolConfirmation = useMemo(() => { + if (isIntegration) return false; + if (isPaused) return true; // Hard stop when paused + return !wasInteracted && !areToolsConfirmed; +}, [isIntegration, isPaused, wasInteracted, areToolsConfirmed]); + +// Queue flushing +useEffect(() => { + if (queuedMessages.length === 0) return; + const nextQueued = queuedMessages[0]; + const isPriority = nextQueued.priority; + + // Priority: flush after streaming ends + // Regular: flush only when fully idle (no tools pending) + const canFlush = isPriority ? canFlushBase : isFullyIdle; + if (!canFlush) return; + + dispatch(dequeueUserMessage({ queuedId: nextQueued.id })); + void sendMessages([...currentMessages, nextQueued.message]); +}, [/* deps */]); +``` + +### Tab UI Indicators + +```typescript +// Toolbar.tsx - tab spinner logic +const tabs = open_thread_ids.map(id => { + const runtime = threads[id]; + return { + id, + title: runtime.thread.title, + streaming: runtime.streaming, + waiting: runtime.waiting_for_response, + }; +}); + +// Render spinner if busy +{(tab.streaming || tab.waiting) && } +``` + +```typescript +// HistoryItem.tsx - history list spinner +const runtime = threads[historyItem.id]; +const isBusy = runtime?.streaming || runtime?.waiting_for_response; +{isBusy && } +``` + +### File Reference Map + +| Concern | Primary File(s) | +|---------|-----------------| +| State types | `features/Chat/Thread/types.ts` | +| Actions | `features/Chat/Thread/actions.ts` | +| Reducers | `features/Chat/Thread/reducer.ts` | +| Selectors | `features/Chat/Thread/selectors.ts` | +| Send logic & hooks | `hooks/useSendChatRequest.ts` | +| Auto-continuation | `app/middleware.ts` (doneStreaming listener) | +| Background switch | `app/middleware.ts` (setThreadPauseReasons listener) | +| IDE tool handling | `app/middleware.ts` (ideToolCallResponse listener) | +| Tab UI | `components/Toolbar/Toolbar.tsx` | +| Chat form | `components/ChatForm/ChatForm.tsx` | +| Stop button | `components/ChatContent/ChatContent.tsx` | +| Confirmation UI | `components/ChatForm/ToolConfirmation.tsx` | +| SSE sync | `hooks/useTrajectoriesSubscription.ts` | +| History list | `components/ChatHistory/HistoryItem.tsx` | + +### Critical Invariants + +```typescript +// Chat can proceed if ALL true: +!runtime.streaming +!runtime.waiting_for_response +!runtime.prevent_send +!runtime.error +!runtime.confirmation.pause +!selectHasUncalledTools(state, chatId) + +// Confirmation blocks everything when: +runtime.confirmation.pause === true +// This sets confirmationStatus=false, which makes stopForToolConfirmation=true + +// Thread is safe to delete when: +!runtime.streaming && !runtime.waiting_for_response && !runtime.confirmation.pause + +// Auto-send is blocked when: +isPaused || (!wasInteracted && !areToolsConfirmed) +``` + +--- + ## Development Workflows ### How to Add a New Redux Slice diff --git a/refact-agent/gui/src/__fixtures__/chat.ts b/refact-agent/gui/src/__fixtures__/chat.ts index 523352ec1..dccb0247c 100644 --- a/refact-agent/gui/src/__fixtures__/chat.ts +++ b/refact-agent/gui/src/__fixtures__/chat.ts @@ -1,9 +1,8 @@ -import type { RootState } from "../app/store"; +import type { ChatThread } from "../features/Chat/Thread/types"; import { ChatHistoryItem } from "../features/History/historySlice"; export * from "./some_chrome_screenshots"; -type ChatThread = RootState["chat"]["thread"]; type ChatMessages = ChatThread["messages"]; export const MARS_ROVER_CHAT: ChatHistoryItem = { diff --git a/refact-agent/gui/src/__fixtures__/chat_config_thread.ts b/refact-agent/gui/src/__fixtures__/chat_config_thread.ts index 39e59fbb7..ffc273b7d 100644 --- a/refact-agent/gui/src/__fixtures__/chat_config_thread.ts +++ b/refact-agent/gui/src/__fixtures__/chat_config_thread.ts @@ -1,10 +1,15 @@ import type { Chat } from "../features/Chat/Thread"; +const THREAD_ID = "941fb8f4-409c-4430-a3b2-6450fafdb9f4"; + export const CHAT_CONFIG_THREAD: Chat = { - streaming: false, - thread: { - mode: "CONFIGURE", - id: "941fb8f4-409c-4430-a3b2-6450fafdb9f4", + current_thread_id: THREAD_ID, + open_thread_ids: [THREAD_ID], + threads: { + [THREAD_ID]: { + thread: { + mode: "CONFIGURE", + id: THREAD_ID, messages: [ { role: "user", @@ -482,16 +487,24 @@ export const CHAT_CONFIG_THREAD: Chat = { new_chat_suggested: { wasSuggested: false, }, - createdAt: "2024-12-02T14:42:18.902Z", - updatedAt: "2024-12-02T14:42:18.902Z", + createdAt: "2024-12-02T14:42:18.902Z", + updatedAt: "2024-12-02T14:42:18.902Z", + }, + streaming: false, + waiting_for_response: false, + prevent_send: true, + error: null, + queued_messages: [], + send_immediately: false, + attached_images: [], + confirmation: { + pause: false, + pause_reasons: [], + status: { wasInteracted: false, confirmationStatus: true }, + }, + }, }, - error: null, - prevent_send: true, - waiting_for_response: false, max_new_tokens: 4096, - cache: {}, system_prompt: {}, tool_use: "agent", - send_immediately: false, - queued_messages: [], }; diff --git a/refact-agent/gui/src/__fixtures__/msw.ts b/refact-agent/gui/src/__fixtures__/msw.ts index 7d8a1449c..d17931946 100644 --- a/refact-agent/gui/src/__fixtures__/msw.ts +++ b/refact-agent/gui/src/__fixtures__/msw.ts @@ -5,12 +5,10 @@ import { STUB_LINKS_FOR_CHAT_RESPONSE } from "./chat_links_response"; import { TOOLS, CHAT_LINKS_URL, - KNOWLEDGE_CREATE_URL, } from "../services/refact/consts"; import { STUB_TOOL_RESPONSE } from "./tools_response"; import { GoodPollingResponse } from "../services/smallcloud/types"; import type { LinksForChatResponse } from "../services/refact/links"; -import { SaveTrajectoryResponse } from "../services/refact/knowledge"; import { ToolConfirmationResponse } from "../services/refact"; export const goodPing: HttpHandler = http.get( @@ -136,16 +134,7 @@ export const goodTools: HttpHandler = http.get( }, ); -export const makeKnowledgeFromChat: HttpHandler = http.post( - `http://127.0.0.1:8001${KNOWLEDGE_CREATE_URL}`, - () => { - const result: SaveTrajectoryResponse = { - memid: "foo", - trajectory: "something", - }; - return HttpResponse.json(result); - }, -); + export const loginPollingGood: HttpHandler = http.get( "https://www.smallcloud.ai/v1/streamlined-login-recall-ticket", @@ -235,3 +224,31 @@ export const ToolConfirmation = http.post( return HttpResponse.json(response); }, ); + +export const emptyTrajectories: HttpHandler = http.get( + "http://127.0.0.1:8001/v1/trajectories", + () => { + return HttpResponse.json([]); + }, +); + +export const trajectoryGet: HttpHandler = http.get( + "http://127.0.0.1:8001/v1/trajectories/:id", + () => { + return HttpResponse.json({ status: "not_found" }, { status: 404 }); + }, +); + +export const trajectorySave: HttpHandler = http.put( + "http://127.0.0.1:8001/v1/trajectories/:id", + () => { + return HttpResponse.json({ status: "ok" }); + }, +); + +export const trajectoryDelete: HttpHandler = http.delete( + "http://127.0.0.1:8001/v1/trajectories/:id", + () => { + return HttpResponse.json({ status: "ok" }); + }, +); diff --git a/refact-agent/gui/src/__tests__/ChatCapsFetchError.test.tsx b/refact-agent/gui/src/__tests__/ChatCapsFetchError.test.tsx index 4ba4668b8..c9e69482f 100644 --- a/refact-agent/gui/src/__tests__/ChatCapsFetchError.test.tsx +++ b/refact-agent/gui/src/__tests__/ChatCapsFetchError.test.tsx @@ -10,6 +10,8 @@ import { chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, } from "../utils/mockServer"; import { Chat } from "../features/Chat"; @@ -25,6 +27,8 @@ describe("chat caps error", () => { chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, http.get("http://127.0.0.1:8001/v1/caps", () => { return HttpResponse.json( { diff --git a/refact-agent/gui/src/__tests__/DeleteChat.test.tsx b/refact-agent/gui/src/__tests__/DeleteChat.test.tsx index 7887b6292..faa24768d 100644 --- a/refact-agent/gui/src/__tests__/DeleteChat.test.tsx +++ b/refact-agent/gui/src/__tests__/DeleteChat.test.tsx @@ -8,6 +8,9 @@ import { telemetryChat, telemetryNetwork, goodCaps, + emptyTrajectories, + trajectorySave, + trajectoryDelete, } from "../utils/mockServer"; import { InnerApp } from "../features/App"; import { HistoryState } from "../features/History/historySlice"; @@ -20,6 +23,9 @@ describe("Delete a Chat form history", () => { telemetryChat, telemetryNetwork, goodCaps, + emptyTrajectories, + trajectorySave, + trajectoryDelete, ); it("can delete a chat", async () => { const now = new Date().toISOString(); diff --git a/refact-agent/gui/src/__tests__/RestoreChat.test.tsx b/refact-agent/gui/src/__tests__/RestoreChat.test.tsx index 144b2bf58..f1e5c877d 100644 --- a/refact-agent/gui/src/__tests__/RestoreChat.test.tsx +++ b/refact-agent/gui/src/__tests__/RestoreChat.test.tsx @@ -12,6 +12,8 @@ import { chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, } from "../utils/mockServer"; import { InnerApp } from "../features/App"; @@ -28,6 +30,8 @@ describe("Restore Chat from history", () => { chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, ); const { user, ...app } = render(, { diff --git a/refact-agent/gui/src/__tests__/StartNewChat.test.tsx b/refact-agent/gui/src/__tests__/StartNewChat.test.tsx index 62a464abe..99ed62dc8 100644 --- a/refact-agent/gui/src/__tests__/StartNewChat.test.tsx +++ b/refact-agent/gui/src/__tests__/StartNewChat.test.tsx @@ -13,6 +13,8 @@ import { telemetryChat, telemetryNetwork, goodCapsWithKnowledgeFeature, + emptyTrajectories, + trajectorySave, } from "../utils/mockServer"; import { InnerApp } from "../features/App"; import { stubResizeObserver } from "../utils/test-utils"; @@ -34,6 +36,8 @@ describe("Start a new chat", () => { chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, ); }); diff --git a/refact-agent/gui/src/__tests__/UserSurvey.test.tsx b/refact-agent/gui/src/__tests__/UserSurvey.test.tsx index 86f48f919..17b464f8d 100644 --- a/refact-agent/gui/src/__tests__/UserSurvey.test.tsx +++ b/refact-agent/gui/src/__tests__/UserSurvey.test.tsx @@ -14,6 +14,8 @@ import { chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, } from "../utils/mockServer"; import { InnerApp } from "../features/App"; @@ -66,6 +68,8 @@ describe("Start a new chat", () => { chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, ); const { user, ...app } = render(, { diff --git a/refact-agent/gui/src/app/middleware.ts b/refact-agent/gui/src/app/middleware.ts index e8a4b8fe3..cab2b359e 100644 --- a/refact-agent/gui/src/app/middleware.ts +++ b/refact-agent/gui/src/app/middleware.ts @@ -15,6 +15,13 @@ import { sendCurrentChatToLspAfterToolCallUpdate, chatResponse, chatError, + selectHasUncalledToolsById, + clearThreadPauseReasons, + setThreadConfirmationStatus, + setThreadPauseReasons, + resetThreadImages, + switchToThread, + selectCurrentThreadId, } from "../features/Chat/Thread"; import { statisticsApi } from "../services/refact/statistics"; import { integrationsApi } from "../services/refact/integrations"; @@ -35,14 +42,9 @@ import { setIsAuthError, } from "../features/Errors/errorsSlice"; import { setThemeMode, updateConfig } from "../features/Config/configSlice"; -import { resetAttachedImagesSlice } from "../features/AttachedImages"; import { nextTip } from "../features/TipOfTheDay"; import { telemetryApi } from "../services/refact/telemetry"; import { CONFIG_PATH_URL, FULL_PATH_URL } from "../services/refact/consts"; -import { - resetConfirmationInteractedState, - updateConfirmationAfterIdeToolUse, -} from "../features/ToolConfirmation/confirmationSlice"; import { ideToolCallResponse, ideForceReloadProjectTreeFiles, @@ -60,24 +62,24 @@ const startListening = listenerMiddleware.startListening.withTypes< >(); startListening({ - // TODO: figure out why this breaks the tests when it's not a function :/ matcher: isAnyOf( (d: unknown): d is ReturnType => newChatAction.match(d), (d: unknown): d is ReturnType => restoreChat.match(d), ), effect: (_action, listenerApi) => { + const state = listenerApi.getState(); + const chatId = state.chat.current_thread_id; + [ - // pingApi.util.resetApiState(), statisticsApi.util.resetApiState(), - // capsApi.util.resetApiState(), - // promptsApi.util.resetApiState(), toolsApi.util.resetApiState(), commandsApi.util.resetApiState(), - resetAttachedImagesSlice(), - resetConfirmationInteractedState(), ].forEach((api) => listenerApi.dispatch(api)); + listenerApi.dispatch(resetThreadImages({ id: chatId })); + listenerApi.dispatch(clearThreadPauseReasons({ id: chatId })); + listenerApi.dispatch(setThreadConfirmationStatus({ id: chatId, wasInteracted: false, confirmationStatus: true })); listenerApi.dispatch(clearError()); }, }); @@ -343,11 +345,64 @@ startListening({ startListening({ actionCreator: doneStreaming, - effect: (action, listenerApi) => { + effect: async (action, listenerApi) => { const state = listenerApi.getState(); - if (action.payload.id === state.chat.thread.id) { - listenerApi.dispatch(resetAttachedImagesSlice()); + const chatId = action.payload.id; + const isCurrentThread = chatId === state.chat.current_thread_id; + + if (isCurrentThread) { + listenerApi.dispatch(resetThreadImages({ id: chatId })); + } + + const runtime = state.chat.threads[chatId]; + if (!runtime) return; + if (runtime.error) return; + if (runtime.prevent_send) return; + if (runtime.confirmation.pause) return; + + const hasUncalledTools = selectHasUncalledToolsById(state, chatId); + if (!hasUncalledTools) return; + + const lastMessage = runtime.thread.messages[runtime.thread.messages.length - 1]; + if (!lastMessage || !("tool_calls" in lastMessage) || !lastMessage.tool_calls) return; + + // IMPORTANT: Set waiting=true immediately to prevent race conditions + // This blocks any other sender (like useAutoSend) from starting a duplicate request + // during the async confirmation check below + listenerApi.dispatch(setIsWaitingForResponse({ id: chatId, value: true })); + + const isIntegrationChat = runtime.thread.mode === "CONFIGURE"; + if (!isIntegrationChat) { + const confirmationResult = await listenerApi.dispatch( + toolsApi.endpoints.checkForConfirmation.initiate({ + tool_calls: lastMessage.tool_calls, + messages: runtime.thread.messages, + }), + ); + + if ("data" in confirmationResult && confirmationResult.data?.pause) { + // setThreadPauseReasons will reset waiting_for_response to false + listenerApi.dispatch(setThreadPauseReasons({ id: chatId, pauseReasons: confirmationResult.data.pause_reasons })); + return; + } } + + // Re-check state after async operation to prevent duplicate requests + const latestState = listenerApi.getState(); + const latestRuntime = latestState.chat.threads[chatId]; + if (!latestRuntime) return; + if (latestRuntime.streaming) return; + if (latestRuntime.prevent_send) return; + if (latestRuntime.confirmation.pause) return; + + void listenerApi.dispatch( + chatAskQuestionThunk({ + messages: runtime.thread.messages, + chatId, + mode: runtime.thread.mode, + checkpointsEnabled: latestState.chat.checkpoints_enabled, + }), + ); }, }); @@ -377,12 +432,12 @@ startListening({ actionCreator: newIntegrationChat, effect: async (_action, listenerApi) => { const state = listenerApi.getState(); - // TODO: set mode to configure ? or infer it later - // TODO: create a dedicated thunk for this. + const runtime = state.chat.threads[state.chat.current_thread_id]; + if (!runtime) return; await listenerApi.dispatch( chatAskQuestionThunk({ - messages: state.chat.thread.messages, - chatId: state.chat.thread.id, + messages: runtime.thread.messages, + chatId: runtime.thread.id, }), ); }, @@ -407,11 +462,9 @@ startListening({ const state = listenerApi.getState(); if (chatAskQuestionThunk.rejected.match(action) && !action.meta.condition) { const { chatId, mode } = action.meta.arg; - const thread = - chatId in state.chat.cache - ? state.chat.cache[chatId] - : state.chat.thread; - const scope = `sendChat_${thread.model}_${mode}`; + const runtime = state.chat.threads[chatId]; + const thread = runtime?.thread; + const scope = `sendChat_${thread?.model ?? "unknown"}_${mode}`; if (isDetailMessageWithErrorType(action.payload)) { const errorMessage = action.payload.detail; @@ -431,11 +484,9 @@ startListening({ if (chatAskQuestionThunk.fulfilled.match(action)) { const { chatId, mode } = action.meta.arg; - const thread = - chatId in state.chat.cache - ? state.chat.cache[chatId] - : state.chat.thread; - const scope = `sendChat_${thread.model}_${mode}`; + const runtime = state.chat.threads[chatId]; + const thread = runtime?.thread; + const scope = `sendChat_${thread?.model ?? "unknown"}_${mode}`; const thunk = telemetryApi.endpoints.sendTelemetryChatEvent.initiate({ scope, @@ -500,29 +551,39 @@ startListening({ }, }); -// Tool Call results from ide. startListening({ actionCreator: ideToolCallResponse, effect: (action, listenerApi) => { const state = listenerApi.getState(); + const chatId = action.payload.chatId; + const runtime = state.chat.threads[chatId]; listenerApi.dispatch(upsertToolCallIntoHistory(action.payload)); listenerApi.dispatch(upsertToolCall(action.payload)); - listenerApi.dispatch(updateConfirmationAfterIdeToolUse(action.payload)); - const pauseReasons = state.confirmation.pauseReasons.filter( + if (!runtime) return; + + const pauseReasons = runtime.confirmation.pause_reasons.filter( (reason) => reason.tool_call_id !== action.payload.toolCallId, ); if (pauseReasons.length === 0) { - listenerApi.dispatch(resetConfirmationInteractedState()); - listenerApi.dispatch(setIsWaitingForResponse(false)); + listenerApi.dispatch(clearThreadPauseReasons({ id: chatId })); + listenerApi.dispatch(setThreadConfirmationStatus({ id: chatId, wasInteracted: true, confirmationStatus: true })); + // If we're about to dispatch a follow-up, set waiting=true; otherwise false + if (action.payload.accepted) { + listenerApi.dispatch(setIsWaitingForResponse({ id: chatId, value: true })); + } else { + listenerApi.dispatch(setIsWaitingForResponse({ id: chatId, value: false })); + } + } else { + listenerApi.dispatch(setThreadPauseReasons({ id: chatId, pauseReasons })); } if (pauseReasons.length === 0 && action.payload.accepted) { void listenerApi.dispatch( sendCurrentChatToLspAfterToolCallUpdate({ - chatId: action.payload.chatId, + chatId, toolCallId: action.payload.toolCallId, }), ); @@ -545,6 +606,21 @@ startListening({ }, }); +// Auto-switch to thread when it needs confirmation (background chat support) +startListening({ + actionCreator: setThreadPauseReasons, + effect: (action, listenerApi) => { + const state = listenerApi.getState(); + const currentThreadId = selectCurrentThreadId(state); + const threadIdNeedingConfirmation = action.payload.id; + + // If the thread needing confirmation is not the current one, switch to it + if (threadIdNeedingConfirmation !== currentThreadId) { + listenerApi.dispatch(switchToThread({ id: threadIdNeedingConfirmation })); + } + }, +}); + // JB file refresh // TBD: this could include diff messages to startListening({ diff --git a/refact-agent/gui/src/app/storage.ts b/refact-agent/gui/src/app/storage.ts index 3e4d18558..f584f9857 100644 --- a/refact-agent/gui/src/app/storage.ts +++ b/refact-agent/gui/src/app/storage.ts @@ -1,55 +1,4 @@ import type { WebStorage } from "redux-persist"; -import { - ChatHistoryItem, - HistoryState, -} from "../features/History/historySlice"; -import { parseOrElse } from "../utils"; - -type StoredState = { - tipOfTheDay: string; - tour: string; - history: string; -}; - -function getOldest(history: HistoryState): ChatHistoryItem | null { - const sorted = Object.values(history).sort((a, b) => { - return new Date(a.updatedAt).getTime() - new Date(b.updatedAt).getTime(); - }); - const oldest = sorted[0] ?? null; - return oldest; -} - -function prune(key: string, stored: StoredState) { - const history = parseOrElse(stored.history, {}); - const oldest = getOldest(history); - - if (!oldest) return; - const nextHistory = Object.values(history).reduce( - (acc, cur) => { - if (cur.id === oldest.id) return acc; - return { ...acc, [cur.id]: cur }; - }, - {}, - ); - const nextStorage = { ...stored, history: JSON.stringify(nextHistory) }; - try { - const newHistory = JSON.stringify(nextStorage); - localStorage.setItem(key, newHistory); - } catch (e) { - prune(key, nextStorage); - } -} - -function pruneHistory(key: string, item: string) { - const storedString = item; - if (!storedString) return; - try { - const stored = JSON.parse(storedString) as StoredState; - prune(key, stored); - } catch (e) { - /* empty */ - } -} function removeOldEntry(key: string) { if ( @@ -72,22 +21,22 @@ export function storage(): WebStorage { cleanOldEntries(); return { getItem(key: string): Promise { - return new Promise((resolve, _reject) => { + return new Promise((resolve) => { resolve(localStorage.getItem(key)); }); }, setItem(key: string, item: string): Promise { - return new Promise((resolve, _reject) => { + return new Promise((resolve) => { try { localStorage.setItem(key, item); } catch { - pruneHistory(key, item); + // Storage quota exceeded, ignore } resolve(); }); }, removeItem(key: string): Promise { - return new Promise((resolve, _reject) => { + return new Promise((resolve) => { localStorage.removeItem(key); resolve(); }); diff --git a/refact-agent/gui/src/app/store.ts b/refact-agent/gui/src/app/store.ts index b9a4ee02c..5cad7aa04 100644 --- a/refact-agent/gui/src/app/store.ts +++ b/refact-agent/gui/src/app/store.ts @@ -25,6 +25,7 @@ import { providersApi, modelsApi, teamsApi, + trajectoriesApi, } from "../services/refact"; import { smallCloudApi } from "../services/smallcloud"; import { reducer as fimReducer } from "../features/FIM/reducer"; @@ -44,8 +45,6 @@ import { pagesSlice } from "../features/Pages/pagesSlice"; import mergeInitialState from "redux-persist/lib/stateReconciler/autoMergeLevel2"; import { listenerMiddleware } from "./middleware"; import { informationSlice } from "../features/Errors/informationSlice"; -import { confirmationSlice } from "../features/ToolConfirmation/confirmationSlice"; -import { attachedImagesSlice } from "../features/AttachedImages"; import { teamsSlice } from "../features/Teams"; import { userSurveySlice } from "../features/UserSurvey/userSurveySlice"; import { linksApi } from "../services/refact/links"; @@ -95,6 +94,7 @@ const rootReducer = combineSlices( [teamsApi.reducerPath]: teamsApi.reducer, [providersApi.reducerPath]: providersApi.reducer, [modelsApi.reducerPath]: modelsApi.reducer, + [trajectoriesApi.reducerPath]: trajectoriesApi.reducer, }, historySlice, errorSlice, @@ -102,8 +102,6 @@ const rootReducer = combineSlices( pagesSlice, integrationsApi, dockerApi, - confirmationSlice, - attachedImagesSlice, userSurveySlice, teamsSlice, integrationsSlice, @@ -115,7 +113,7 @@ const rootReducer = combineSlices( const rootPersistConfig = { key: "root", storage: storage(), - whitelist: [historySlice.reducerPath, "tour", userSurveySlice.reducerPath], + whitelist: ["tour", userSurveySlice.reducerPath], stateReconciler: mergeInitialState, }; @@ -179,6 +177,7 @@ export function setUpStore(preloadedState?: Partial) { providersApi.middleware, modelsApi.middleware, teamsApi.middleware, + trajectoriesApi.middleware, ) .prepend(historyMiddleware.middleware) // .prepend(errorMiddleware.middleware) diff --git a/refact-agent/gui/src/components/Chat/Chat.stories.tsx b/refact-agent/gui/src/components/Chat/Chat.stories.tsx index 5fe2aaf3f..80d22bdab 100644 --- a/refact-agent/gui/src/components/Chat/Chat.stories.tsx +++ b/refact-agent/gui/src/components/Chat/Chat.stories.tsx @@ -20,7 +20,6 @@ import { goodTools, noTools, // noChatLinks, - makeKnowledgeFromChat, } from "../../__fixtures__/msw"; import { TourProvider } from "../../features/Tour"; import { Flex } from "@radix-ui/themes"; @@ -38,22 +37,34 @@ const Template: React.FC<{ wasSuggested: false, }, }; + const threadId = threadData.id ?? "test"; const store = setUpStore({ tour: { type: "finished", }, chat: { - streaming: false, - prevent_send: false, - waiting_for_response: false, + current_thread_id: threadId, + open_thread_ids: [threadId], + threads: { + [threadId]: { + thread: threadData, + streaming: false, + waiting_for_response: false, + prevent_send: false, + error: null, + queued_messages: [], + send_immediately: false, + attached_images: [], + confirmation: { + pause: false, + pause_reasons: [], + status: { wasInteracted: false, confirmationStatus: true }, + }, + }, + }, max_new_tokens: 4096, tool_use: "agent", - send_immediately: false, - error: null, - cache: {}, system_prompt: {}, - thread: threadData, - queued_messages: [], }, config, }); @@ -105,7 +116,8 @@ export const Primary: Story = {}; export const Configuration: Story = { args: { - thread: CHAT_CONFIG_THREAD.thread, + thread: + CHAT_CONFIG_THREAD.threads[CHAT_CONFIG_THREAD.current_thread_id]?.thread, }, }; @@ -148,7 +160,7 @@ export const Knowledge: Story = { // noChatLinks, chatLinks, noTools, - makeKnowledgeFromChat, + ], }, }, @@ -190,7 +202,7 @@ export const EmptySpaceAtBottom: Story = { // noChatLinks, chatLinks, noTools, - makeKnowledgeFromChat, + ], }, }, @@ -271,7 +283,7 @@ export const UserMessageEmptySpaceAtBottom: Story = { // noChatLinks, chatLinks, noTools, - makeKnowledgeFromChat, + ], }, }, @@ -354,7 +366,7 @@ export const CompressButton: Story = { // noChatLinks, chatLinks, noTools, - makeKnowledgeFromChat, + ], }, }, @@ -381,7 +393,6 @@ export const LowBalance: Story = { goodPrompts, chatLinks, noTools, - makeKnowledgeFromChat, lowBalance, }, }, diff --git a/refact-agent/gui/src/components/ChatContent/AssistantInput.tsx b/refact-agent/gui/src/components/ChatContent/AssistantInput.tsx index aba0860ad..6cc91e0a5 100644 --- a/refact-agent/gui/src/components/ChatContent/AssistantInput.tsx +++ b/refact-agent/gui/src/components/ChatContent/AssistantInput.tsx @@ -2,27 +2,18 @@ import React, { useCallback, useMemo } from "react"; import { Markdown } from "../Markdown"; import { Container, Box, Flex, Text, Link, Card } from "@radix-ui/themes"; -import { ToolCall, Usage, WebSearchCitation } from "../../services/refact"; +import { ToolCall, WebSearchCitation } from "../../services/refact"; import { ToolContent } from "./ToolsContent"; import { fallbackCopying } from "../../utils/fallbackCopying"; import { telemetryApi } from "../../services/refact/telemetry"; -import { LikeButton } from "./LikeButton"; -import { ResendButton } from "./ResendButton"; import { ReasoningContent } from "./ReasoningContent"; -import { MessageUsageInfo } from "./MessageUsageInfo"; type ChatInputProps = { message: string | null; reasoningContent?: string | null; toolCalls?: ToolCall[] | null; - serverExecutedTools?: ToolCall[] | null; // Tools that were executed by the provider (srvtoolu_*) + serverExecutedTools?: ToolCall[] | null; citations?: WebSearchCitation[] | null; - isLast?: boolean; - usage?: Usage | null; - metering_coins_prompt?: number; - metering_coins_generated?: number; - metering_coins_cache_creation?: number; - metering_coins_cache_read?: number; }; export const AssistantInput: React.FC = ({ @@ -31,12 +22,6 @@ export const AssistantInput: React.FC = ({ toolCalls, serverExecutedTools, citations, - isLast, - usage, - metering_coins_prompt, - metering_coins_generated, - metering_coins_cache_creation, - metering_coins_cache_read, }) => { const [sendTelemetryEvent] = telemetryApi.useLazySendTelemetryChatEventQuery(); @@ -85,18 +70,8 @@ export const AssistantInput: React.FC = ({ [sendTelemetryEvent], ); - const hasMessageFirst = !reasoningContent && message; - return ( - {reasoningContent && ( = ({ /> )} {message && ( - + {message} @@ -153,20 +128,6 @@ export const AssistantInput: React.FC = ({ )} {toolCalls && } - {isLast && ( - - - - - - - )} ); }; diff --git a/refact-agent/gui/src/components/ChatContent/ChatContent.stories.tsx b/refact-agent/gui/src/components/ChatContent/ChatContent.stories.tsx index e37fb28c4..df8254ecc 100644 --- a/refact-agent/gui/src/components/ChatContent/ChatContent.stories.tsx +++ b/refact-agent/gui/src/components/ChatContent/ChatContent.stories.tsx @@ -27,7 +27,6 @@ import { goodPing, goodPrompts, goodUser, - makeKnowledgeFromChat, noCommandPreview, noCompletions, noTools, @@ -46,19 +45,31 @@ const MockedStore: React.FC<{ wasSuggested: false, }, }; + const threadId = threadData.id ?? "test"; const store = setUpStore({ chat: { - streaming: false, - prevent_send: false, - waiting_for_response: false, + current_thread_id: threadId, + open_thread_ids: [threadId], + threads: { + [threadId]: { + thread: threadData, + streaming: false, + waiting_for_response: false, + prevent_send: false, + error: null, + queued_messages: [], + send_immediately: false, + attached_images: [], + confirmation: { + pause: false, + pause_reasons: [], + status: { wasInteracted: false, confirmationStatus: true }, + }, + }, + }, max_new_tokens: 4096, tool_use: "quick", - send_immediately: false, - error: null, - cache: {}, system_prompt: {}, - thread: threadData, - queued_messages: [], }, }); @@ -147,7 +158,8 @@ export const MultiModal: Story = { export const IntegrationChat: Story = { args: { - thread: CHAT_CONFIG_THREAD.thread, + thread: + CHAT_CONFIG_THREAD.threads[CHAT_CONFIG_THREAD.current_thread_id]?.thread, }, parameters: { msw: { @@ -173,7 +185,7 @@ export const TextDoc: Story = { goodUser, // noChatLinks, noTools, - makeKnowledgeFromChat, + ToolConfirmation, noCompletions, noCommandPreview, @@ -195,7 +207,7 @@ export const MarkdownIssue: Story = { goodUser, // noChatLinks, noTools, - makeKnowledgeFromChat, + ToolConfirmation, noCompletions, noCommandPreview, @@ -237,7 +249,7 @@ export const ToolWaiting: Story = { goodUser, // noChatLinks, noTools, - makeKnowledgeFromChat, + ToolConfirmation, noCompletions, noCommandPreview, diff --git a/refact-agent/gui/src/components/ChatContent/ChatContent.tsx b/refact-agent/gui/src/components/ChatContent/ChatContent.tsx index c77da7238..bb26b3919 100644 --- a/refact-agent/gui/src/components/ChatContent/ChatContent.tsx +++ b/refact-agent/gui/src/components/ChatContent/ChatContent.tsx @@ -1,7 +1,6 @@ import React, { useCallback, useMemo } from "react"; import { ChatMessages, - isAssistantMessage, isChatContextFileMessage, isDiffMessage, isToolMessage, @@ -14,7 +13,9 @@ import { Flex, Container, Button, Box } from "@radix-ui/themes"; import styles from "./ChatContent.module.css"; import { ContextFiles } from "./ContextFiles"; import { AssistantInput } from "./AssistantInput"; + import { PlainText } from "./PlainText"; +import { MessageUsageInfo } from "./MessageUsageInfo"; import { useAppDispatch, useDiffFileReload } from "../../hooks"; import { useAppSelector } from "../../hooks"; import { @@ -31,13 +32,10 @@ import { popBackTo } from "../../features/Pages/pagesSlice"; import { ChatLinks, UncommittedChangesWarning } from "../ChatLinks"; import { telemetryApi } from "../../services/refact/telemetry"; import { PlaceHolderText } from "./PlaceHolderText"; -import { UsageCounter } from "../UsageCounter"; + import { QueuedMessage } from "./QueuedMessage"; -import { - getConfirmationPauseStatus, - getPauseReasonsWithPauseStatus, -} from "../../features/ToolConfirmation/confirmationSlice"; -import { useUsageCounter } from "../UsageCounter/useUsageCounter.ts"; +import { selectThreadConfirmation, selectThreadPause } from "../../features/Chat"; + import { LogoAnimation } from "../LogoAnimation/LogoAnimation.tsx"; export type ChatContentProps = { @@ -50,18 +48,18 @@ export const ChatContent: React.FC = ({ onRetry, }) => { const dispatch = useAppDispatch(); - const pauseReasonsWithPause = useAppSelector(getPauseReasonsWithPauseStatus); + const pauseReasonsWithPause = useAppSelector(selectThreadConfirmation); const messages = useAppSelector(selectMessages); const queuedMessages = useAppSelector(selectQueuedMessages); const isStreaming = useAppSelector(selectIsStreaming); const thread = useAppSelector(selectThread); - const { shouldShow } = useUsageCounter(); - const isConfig = thread.mode === "CONFIGURE"; + + const isConfig = thread?.mode === "CONFIGURE"; const isWaiting = useAppSelector(selectIsWaiting); const [sendTelemetryEvent] = telemetryApi.useLazySendTelemetryChatEventQuery(); const integrationMeta = useAppSelector(selectIntegration); - const isWaitingForConfirmation = useAppSelector(getConfirmationPauseStatus); + const isWaitingForConfirmation = useAppSelector(selectThreadPause); const onRetryWrapper = (index: number, question: UserMessage["content"]) => { onRetry(index, question); @@ -74,18 +72,18 @@ export const ChatContent: React.FC = ({ dispatch( popBackTo({ name: "integrations page", - projectPath: thread.integration?.project, - integrationName: thread.integration?.name, - integrationPath: thread.integration?.path, + projectPath: thread?.integration?.project, + integrationName: thread?.integration?.name, + integrationPath: thread?.integration?.path, wasOpenedThroughChat: true, }), ); }, [ onStopStreaming, dispatch, - thread.integration?.project, - thread.integration?.name, - thread.integration?.path, + thread?.integration?.project, + thread?.integration?.name, + thread?.integration?.path, ]); const handleManualStopStreamingClick = useCallback(() => { @@ -138,7 +136,6 @@ export const ChatContent: React.FC = ({ - {shouldShow && } {!isWaitingForConfirmation && ( 0) { + const nextMsg = tempTail[0]; + if (isToolMessage(nextMsg)) { + // Skip tool messages (they're handled internally) + skipCount++; + tempTail = tempTail.slice(1); + } else if (isChatContextFileMessage(nextMsg)) { + // Collect context_file messages to render after assistant + const ctxKey = "context-file-" + (index + 1 + skipCount); + contextFilesAfter.push(); + skipCount++; + tempTail = tempTail.slice(1); + } else { + // Stop at any other message type (user, assistant, etc.) + break; + } + } + const nextMemo = [ ...memo, , + ...contextFilesAfter, + , ]; - return renderMessages(tail, onRetry, waiting, nextMemo, index + 1); + // Skip the tool and context_file messages we already processed + const newTail = tail.slice(skipCount); + return renderMessages(newTail, onRetry, waiting, nextMemo, index + 1 + skipCount); } if (head.role === "user") { diff --git a/refact-agent/gui/src/components/ChatContent/LikeButton.module.css b/refact-agent/gui/src/components/ChatContent/LikeButton.module.css deleted file mode 100644 index 094b72b5b..000000000 --- a/refact-agent/gui/src/components/ChatContent/LikeButton.module.css +++ /dev/null @@ -1,20 +0,0 @@ -.like__button__success { - animation: successAnimation 0.5s ease-in-out; - animation-fill-mode: forwards; -} - -@keyframes successAnimation { - 0% { - transform: scale(1); - color: var(--green-9); - } - 50% { - transform: scale(1.2); - color: var(--yellow-9); - } - 100% { - transform: scale(1); - color: var(--blue-9); - display: none; - } -} diff --git a/refact-agent/gui/src/components/ChatContent/LikeButton.tsx b/refact-agent/gui/src/components/ChatContent/LikeButton.tsx deleted file mode 100644 index 3f17c9bbe..000000000 --- a/refact-agent/gui/src/components/ChatContent/LikeButton.tsx +++ /dev/null @@ -1,76 +0,0 @@ -import React from "react"; -import { IconButton, Tooltip } from "@radix-ui/themes"; -import classnames from "classnames"; -import { knowledgeApi } from "../../services/refact/knowledge"; -import { useAppSelector } from "../../hooks"; -import { - selectIsStreaming, - selectIsWaiting, - selectMessages, -} from "../../features/Chat"; -import styles from "./LikeButton.module.css"; -import { useSelector } from "react-redux"; -import { selectThreadProjectOrCurrentProject } from "../../features/Chat/currentProject"; - -function useCreateMemory() { - const messages = useAppSelector(selectMessages); - const isStreaming = useAppSelector(selectIsStreaming); - const isWaiting = useAppSelector(selectIsWaiting); - const currentProjectName = useSelector(selectThreadProjectOrCurrentProject); - const [saveTrajectory, saveResponse] = - knowledgeApi.useCreateNewMemoryFromMessagesMutation(); - - const submitSave = React.useCallback(() => { - void saveTrajectory({ project: currentProjectName, messages }); - }, [currentProjectName, messages, saveTrajectory]); - - const shouldShow = React.useMemo(() => { - if (messages.length === 0) return false; - if (isStreaming) return false; - if (isWaiting) return false; - return true; - }, [messages.length, isStreaming, isWaiting]); - - return { submitSave, saveResponse, shouldShow }; -} - -export const LikeButton = () => { - const { submitSave, saveResponse, shouldShow } = useCreateMemory(); - - if (!shouldShow) return null; - return ( - - - - - - ); -}; - -const SaveIcon: React.FC = () => { - return ( - - - - ); -}; diff --git a/refact-agent/gui/src/components/ChatContent/MessageUsageInfo.tsx b/refact-agent/gui/src/components/ChatContent/MessageUsageInfo.tsx index 960836741..880153abc 100644 --- a/refact-agent/gui/src/components/ChatContent/MessageUsageInfo.tsx +++ b/refact-agent/gui/src/components/ChatContent/MessageUsageInfo.tsx @@ -11,7 +11,6 @@ type MessageUsageInfoProps = { metering_coins_generated?: number; metering_coins_cache_creation?: number; metering_coins_cache_read?: number; - topOffset?: string; }; const TokenDisplay: React.FC<{ label: string; value: number }> = ({ @@ -48,7 +47,6 @@ export const MessageUsageInfo: React.FC = ({ metering_coins_generated = 0, metering_coins_cache_creation = 0, metering_coins_cache_read = 0, - topOffset = "0", }) => { const outputTokens = useMemo(() => { return calculateUsageInputTokens({ @@ -76,13 +74,7 @@ export const MessageUsageInfo: React.FC = ({ if (!usage && totalCoins === 0) return null; return ( - + = ({ cursor: "pointer", }} > - - {Math.round(totalCoins)} - + + {contextTokens > 0 && ( + + ctx: + {formatNumberToFixed(contextTokens)} + + )} + + {Math.round(totalCoins)} + + @@ -160,6 +160,6 @@ export const MessageUsageInfo: React.FC = ({ - + ); }; diff --git a/refact-agent/gui/src/components/ChatContent/ResendButton.tsx b/refact-agent/gui/src/components/ChatContent/ResendButton.tsx index c0f1407ed..ac7e12802 100644 --- a/refact-agent/gui/src/components/ChatContent/ResendButton.tsx +++ b/refact-agent/gui/src/components/ChatContent/ResendButton.tsx @@ -37,7 +37,7 @@ export const ResendButton = () => { return ( - + @@ -47,8 +47,8 @@ export const ResendButton = () => { const ResendIcon: React.FC = () => { return ( { const isPatchAutomatic = useAppSelector(selectAutomaticPatch); const isAgentRollbackEnabled = useAppSelector(selectCheckpointsEnabled); const areFollowUpsEnabled = useAppSelector(selectAreFollowUpsEnabled); - const isTitleGenerationEnabled = useAppSelector( - selectIsTitleGenerationEnabled, - ); const useCompression = useAppSelector(selectUseCompression); const includeProjectInfo = useAppSelector(selectIncludeProjectInfo); const messages = useAppSelector(selectMessages); const isNewChat = messages.length === 0; + const { shouldShow: shouldShowUsage } = useUsageCounter(); const agenticFeatures = useMemo(() => { return [ @@ -60,11 +59,6 @@ export const AgentCapabilities = () => { enabled: areFollowUpsEnabled, switcher: , }, - { - name: "Chat Titles", - enabled: isTitleGenerationEnabled, - switcher: , - }, { name: "Compression", enabled: useCompression, @@ -81,7 +75,6 @@ export const AgentCapabilities = () => { isPatchAutomatic, isAgentRollbackEnabled, areFollowUpsEnabled, - isTitleGenerationEnabled, useCompression, includeProjectInfo, isNewChat, @@ -99,38 +92,45 @@ export const AgentCapabilities = () => { ); return ( - - - - - - - - - - {agenticFeatures.map((feature) => { - if ("hide" in feature && feature.hide) return null; - return {feature.switcher}; - })} - - - - - - - Enabled Features: - {enabledAgenticFeatures} - - - - - - - - Here you can control special features affecting Agent behaviour - - - + + + + + + + + + + + {agenticFeatures.map((feature) => { + if ("hide" in feature && feature.hide) return null; + return {feature.switcher}; + })} + + + + + + + Enabled Features: + {enabledAgenticFeatures} + + + + + + + + Here you can control special features affecting Agent behaviour + + + + + {shouldShowUsage && ( + + + + )} ); }; diff --git a/refact-agent/gui/src/components/ChatForm/ChatControls.tsx b/refact-agent/gui/src/components/ChatForm/ChatControls.tsx index aa3678144..457d85fbb 100644 --- a/refact-agent/gui/src/components/ChatForm/ChatControls.tsx +++ b/refact-agent/gui/src/components/ChatForm/ChatControls.tsx @@ -32,14 +32,12 @@ import { selectChatId, selectCheckpointsEnabled, selectIsStreaming, - selectIsTitleGenerationEnabled, selectIsWaiting, selectMessages, selectToolUse, selectUseCompression, selectIncludeProjectInfo, setAreFollowUpsEnabled, - setIsTitleGenerationEnabled, setAutomaticPatch, setEnabledCheckpoints, setToolUse, @@ -238,72 +236,6 @@ export const FollowUpsSwitch: React.FC = () => { ); }; -export const TitleGenerationSwitch: React.FC = () => { - const dispatch = useAppDispatch(); - const isTitleGenerationEnabled = useAppSelector( - selectIsTitleGenerationEnabled, - ); - - const handleTitleGenerationEnabledChange = (checked: boolean) => { - dispatch(setIsTitleGenerationEnabled(checked)); - }; - - return ( - - - Chat Titles - - - - - - - - - - - When enabled, Refact Agent will automatically generate - summarized chat title for the conversation - - - - - - Warning: may increase coins spending - - - - - - - - - ); -}; - export const UseCompressionSwitch: React.FC = () => { const dispatch = useAppDispatch(); const useCompression = useAppSelector(selectUseCompression); diff --git a/refact-agent/gui/src/components/ChatForm/ChatForm.test.tsx b/refact-agent/gui/src/components/ChatForm/ChatForm.test.tsx index 698979e8c..73763350d 100644 --- a/refact-agent/gui/src/components/ChatForm/ChatForm.test.tsx +++ b/refact-agent/gui/src/components/ChatForm/ChatForm.test.tsx @@ -13,6 +13,8 @@ import { noCompletions, goodPing, goodUser, + emptyTrajectories, + trajectorySave, } from "../../utils/mockServer"; const handlers = [ @@ -23,6 +25,8 @@ const handlers = [ noCommandPreview, noCompletions, goodPing, + emptyTrajectories, + trajectorySave, ]; server.use(...handlers); diff --git a/refact-agent/gui/src/components/ChatForm/ChatForm.tsx b/refact-agent/gui/src/components/ChatForm/ChatForm.tsx index 5b7b02f0a..16131c514 100644 --- a/refact-agent/gui/src/components/ChatForm/ChatForm.tsx +++ b/refact-agent/gui/src/components/ChatForm/ChatForm.tsx @@ -47,8 +47,9 @@ import { InformationCallout, } from "../Callout/Callout"; import { ToolConfirmation } from "./ToolConfirmation"; -import { getPauseReasonsWithPauseStatus } from "../../features/ToolConfirmation/confirmationSlice"; +import { selectThreadConfirmation } from "../../features/Chat"; import { AttachImagesButton, FileList } from "../Dropzone"; +import { ResendButton } from "../ChatContent/ResendButton"; import { useAttachedImages } from "../../hooks/useAttachedImages"; import { selectChatError, @@ -92,7 +93,7 @@ export const ChatForm: React.FC = ({ const globalErrorType = useAppSelector(getErrorType); const chatError = useAppSelector(selectChatError); const information = useAppSelector(getInformationMessage); - const pauseReasonsWithPause = useAppSelector(getPauseReasonsWithPauseStatus); + const pauseReasonsWithPause = useAppSelector(selectThreadConfirmation); const [helpInfo, setHelpInfo] = React.useState(null); const isOnline = useIsOnline(); const { retry } = useSendChatRequest(); @@ -324,9 +325,9 @@ export const ChatForm: React.FC = ({ ); } - if (!isStreaming && pauseReasonsWithPause.pause) { + if (pauseReasonsWithPause.pause) { return ( - + ); } @@ -449,6 +450,7 @@ export const ChatForm: React.FC = ({ )} {/* TODO: Reserved space for microphone button coming later on */} + { - const actions = [ - newChatAction(), - clearPauseReasonsAndHandleToolsStatus({ - wasInteracted: false, - confirmationStatus: true, - }), - popBackTo({ name: "history" }), - push({ name: "chat" }), - ]; - - actions.forEach((action) => dispatch(action)); + dispatch(newChatAction()); + dispatch(clearThreadPauseReasons({ id: chatId })); + dispatch(setThreadConfirmationStatus({ id: chatId, wasInteracted: false, confirmationStatus: true })); + dispatch(popBackTo({ name: "history" })); + dispatch(push({ name: "chat" })); void sendTelemetryEvent({ scope: `openNewChat`, success: true, error_message: "", }); - }, [dispatch, sendTelemetryEvent]); + }, [dispatch, chatId, sendTelemetryEvent]); const tipText = useMemo(() => { if (isWarning) diff --git a/refact-agent/gui/src/components/ChatForm/ToolConfirmation.stories.tsx b/refact-agent/gui/src/components/ChatForm/ToolConfirmation.stories.tsx index 757c3c044..1ee6c4817 100644 --- a/refact-agent/gui/src/components/ChatForm/ToolConfirmation.stories.tsx +++ b/refact-agent/gui/src/components/ChatForm/ToolConfirmation.stories.tsx @@ -16,16 +16,7 @@ import { const MockedStore: React.FC<{ pauseReasons: ToolConfirmationPauseReason[]; }> = ({ pauseReasons }) => { - const store = setUpStore({ - confirmation: { - pauseReasons, - pause: true, - status: { - wasInteracted: false, - confirmationStatus: false, - }, - }, - }); + const store = setUpStore(); return ( diff --git a/refact-agent/gui/src/components/ChatHistory/HistoryItem.tsx b/refact-agent/gui/src/components/ChatHistory/HistoryItem.tsx index e5a6c9d38..1281036c7 100644 --- a/refact-agent/gui/src/components/ChatHistory/HistoryItem.tsx +++ b/refact-agent/gui/src/components/ChatHistory/HistoryItem.tsx @@ -19,7 +19,7 @@ export const HistoryItem: React.FC<{ }> = ({ historyItem, onClick, onDelete, onOpenInTab, disabled }) => { const dateCreated = new Date(historyItem.createdAt); const dateTimeString = dateCreated.toLocaleString(); - const cache = useAppSelector((app) => app.chat.cache); + const threads = useAppSelector((app) => app.chat.threads); const totalCost = useMemo(() => { const totals = getTotalCostMeteringForMessages(historyItem.messages); @@ -34,7 +34,9 @@ export const HistoryItem: React.FC<{ ); }, [historyItem.messages]); - const isStreaming = historyItem.id in cache; + const isStreaming = threads[historyItem.id]?.streaming ?? false; + const isWaiting = threads[historyItem.id]?.waiting_for_response ?? false; + const isBusy = isStreaming || isWaiting; return ( - - {isStreaming && } - {!isStreaming && historyItem.read === false && ( + + {isBusy && } + {!isBusy && historyItem.read === false && ( )} = ({ takingNotes, style }) => { const onHistoryItemClick = useCallback( (thread: ChatHistoryItem) => { - dispatch(restoreChat(thread)); + // Fetch fresh data from backend before restoring + void dispatch(restoreChatFromBackend({ id: thread.id, fallback: thread })); dispatch(push({ name: "chat" })); }, [dispatch], diff --git a/refact-agent/gui/src/components/Toolbar/Toolbar.tsx b/refact-agent/gui/src/components/Toolbar/Toolbar.tsx index 4b8307594..7f0d124e4 100644 --- a/refact-agent/gui/src/components/Toolbar/Toolbar.tsx +++ b/refact-agent/gui/src/components/Toolbar/Toolbar.tsx @@ -10,6 +10,7 @@ import { } from "@radix-ui/themes"; import { Dropdown, DropdownNavigationOptions } from "./Dropdown"; import { + Cross1Icon, DotFilledIcon, DotsVerticalIcon, HomeIcon, @@ -21,6 +22,7 @@ import { popBackTo, push } from "../../features/Pages/pagesSlice"; import { ChangeEvent, KeyboardEvent, + MouseEvent, useCallback, useEffect, useMemo, @@ -29,10 +31,18 @@ import { } from "react"; import { deleteChatById, - getHistory, updateChatTitleById, } from "../../features/History/historySlice"; -import { restoreChat, saveTitle, selectThread } from "../../features/Chat"; +import { + saveTitle, + selectOpenThreadIds, + selectAllThreads, + closeThread, + switchToThread, + selectChatId, + clearThreadPauseReasons, + setThreadConfirmationStatus, +} from "../../features/Chat"; import { TruncateLeft } from "../Text"; import { useAppDispatch, @@ -40,7 +50,6 @@ import { useEventsBusForIDE, } from "../../hooks"; import { useWindowDimensions } from "../../hooks/useWindowDimensions"; -import { clearPauseReasonsAndHandleToolsStatus } from "../../features/ToolConfirmation/confirmationSlice"; import { telemetryApi } from "../../services/refact/telemetry"; import styles from "./Toolbar.module.css"; @@ -80,24 +89,16 @@ export const Toolbar = ({ activeTab }: ToolbarProps) => { const [sendTelemetryEvent] = telemetryApi.useLazySendTelemetryChatEventQuery(); - const history = useAppSelector(getHistory, { - devModeChecks: { stabilityCheck: "never" }, - }); - const isStreaming = useAppSelector((app) => app.chat.streaming); - const { isTitleGenerated, id: chatId } = useAppSelector(selectThread); - const cache = useAppSelector((app) => app.chat.cache); + const openThreadIds = useAppSelector(selectOpenThreadIds); + const allThreads = useAppSelector(selectAllThreads); + const currentChatId = useAppSelector(selectChatId); const { newChatEnabled } = useActiveTeamsGroup(); const { openSettings, openHotKeys } = useEventsBusForIDE(); - const [isOnlyOneChatTab, setIsOnlyOneChatTab] = useState(false); - const [isRenaming, setIsRenaming] = useState(false); + const [renamingTabId, setRenamingTabId] = useState(null); const [newTitle, setNewTitle] = useState(null); - const shouldChatTabLinkBeNotClickable = useMemo(() => { - return isOnlyOneChatTab && !isDashboardTab(activeTab); - }, [isOnlyOneChatTab, activeTab]); - const handleNavigation = useCallback( (to: DropdownNavigationOptions | "chat") => { if (to === "settings") { @@ -160,33 +161,46 @@ export const Toolbar = ({ activeTab }: ToolbarProps) => { ); const onCreateNewChat = useCallback(() => { - setIsRenaming((prev) => (prev ? !prev : prev)); + setRenamingTabId(null); + + // Auto-close empty chat tab when creating a new chat + if (currentChatId) { + const currentThread = allThreads[currentChatId]; + if (currentThread && currentThread.thread.messages.length === 0) { + dispatch(closeThread({ id: currentChatId })); + } + } + dispatch(newChatAction()); - dispatch( - clearPauseReasonsAndHandleToolsStatus({ - wasInteracted: false, - confirmationStatus: true, - }), - ); + dispatch(clearThreadPauseReasons({ id: currentChatId })); + dispatch(setThreadConfirmationStatus({ id: currentChatId, wasInteracted: false, confirmationStatus: true })); handleNavigation("chat"); void sendTelemetryEvent({ scope: `openNewChat`, success: true, error_message: "", }); - }, [dispatch, sendTelemetryEvent, handleNavigation]); + }, [dispatch, currentChatId, allThreads, sendTelemetryEvent, handleNavigation]); const goToTab = useCallback( (tab: Tab) => { + // Auto-close empty chat tab when navigating away + if (isChatTab(activeTab)) { + const currentThread = allThreads[activeTab.id]; + const isNavigatingToSameTab = isChatTab(tab) && tab.id === activeTab.id; + if ( + !isNavigatingToSameTab && + currentThread && + currentThread.thread.messages.length === 0 + ) { + dispatch(closeThread({ id: activeTab.id })); + } + } + if (tab.type === "dashboard") { dispatch(popBackTo({ name: "history" })); - dispatch(newChatAction()); } else { - if (shouldChatTabLinkBeNotClickable) return; - const chat = history.find((chat) => chat.id === tab.id); - if (chat != undefined) { - dispatch(restoreChat(chat)); - } + dispatch(switchToThread({ id: tab.id })); dispatch(popBackTo({ name: "history" })); dispatch(push({ name: "chat" })); } @@ -196,7 +210,7 @@ export const Toolbar = ({ activeTab }: ToolbarProps) => { error_message: "", }); }, - [dispatch, history, shouldChatTabLinkBeNotClickable, sendTelemetryEvent], + [dispatch, sendTelemetryEvent, activeTab, allThreads], ); useEffect(() => { @@ -217,58 +231,77 @@ export const Toolbar = ({ activeTab }: ToolbarProps) => { }, [focus]); const tabs = useMemo(() => { - return history.filter( - (chat) => - chat.read === false || - (activeTab.type === "chat" && activeTab.id == chat.id), - ); - }, [history, activeTab]); + return openThreadIds + .map((id) => { + const runtime = allThreads[id]; + if (!runtime) return null; + return { + id, + title: runtime.thread.title || "New Chat", + read: runtime.thread.read, + streaming: runtime.streaming, + waiting: runtime.waiting_for_response, + }; + }) + .filter((t): t is NonNullable => t !== null); + }, [openThreadIds, allThreads]); const shouldCollapse = useMemo(() => { - const dashboardWidth = windowWidth < 400 ? 47 : 70; // todo: compute this + const dashboardWidth = windowWidth < 400 ? 47 : 70; const totalWidth = dashboardWidth + 140 * tabs.length; return tabNavWidth < totalWidth; }, [tabNavWidth, tabs.length, windowWidth]); - const handleChatThreadDeletion = useCallback(() => { - dispatch(deleteChatById(chatId)); - goToTab({ type: "dashboard" }); - }, [dispatch, chatId, goToTab]); + const handleChatThreadDeletion = useCallback((tabId: string) => { + dispatch(deleteChatById(tabId)); + dispatch(closeThread({ id: tabId })); + if (activeTab.type === "chat" && activeTab.id === tabId) { + goToTab({ type: "dashboard" }); + } + }, [dispatch, activeTab, goToTab]); - const handleChatThreadRenaming = useCallback(() => { - setIsRenaming(true); + const handleChatThreadRenaming = useCallback((tabId: string) => { + setRenamingTabId(tabId); }, []); const handleKeyUpOnRename = useCallback( - (event: KeyboardEvent) => { + (event: KeyboardEvent, tabId: string) => { if (event.code === "Escape") { - setIsRenaming(false); + setRenamingTabId(null); } if (event.code === "Enter") { - setIsRenaming(false); + setRenamingTabId(null); if (!newTitle || newTitle.trim() === "") return; - if (!isTitleGenerated) { - dispatch( - saveTitle({ - id: chatId, - title: newTitle, - isTitleGenerated: true, - }), - ); - } - dispatch(updateChatTitleById({ chatId: chatId, newTitle: newTitle })); + dispatch( + saveTitle({ + id: tabId, + title: newTitle, + isTitleGenerated: true, + }), + ); + dispatch(updateChatTitleById({ chatId: tabId, newTitle: newTitle })); } }, - [dispatch, newTitle, chatId, isTitleGenerated], + [dispatch, newTitle], ); const handleChatTitleChange = (event: ChangeEvent) => { setNewTitle(event.target.value); }; - useEffect(() => { - setIsOnlyOneChatTab(tabs.length < 2); - }, [tabs]); + const handleCloseTab = useCallback((event: MouseEvent, tabId: string) => { + event.stopPropagation(); + event.preventDefault(); + dispatch(closeThread({ id: tabId })); + if (activeTab.type === "chat" && activeTab.id === tabId) { + const remainingTabs = tabs.filter((t) => t.id !== tabId); + if (remainingTabs.length > 0) { + goToTab({ type: "chat", id: remainingTabs[0].id }); + } else { + goToTab({ type: "dashboard" }); + } + } + }, [dispatch, activeTab, tabs, goToTab]); return ( @@ -278,29 +311,28 @@ export const Toolbar = ({ activeTab }: ToolbarProps) => { active={isDashboardTab(activeTab)} ref={(x) => refs.setBack(x)} onClick={() => { - setIsRenaming((prev) => (prev ? !prev : prev)); + setRenamingTabId(null); goToTab({ type: "dashboard" }); }} style={{ cursor: "pointer" }} > {windowWidth < 400 || shouldCollapse ? : "Home"} - {tabs.map((chat) => { - const isStreamingThisTab = - chat.id in cache || - (isChatTab(activeTab) && chat.id === activeTab.id && isStreaming); - const isActive = isChatTab(activeTab) && activeTab.id == chat.id; + {tabs.map((tab) => { + const isActive = isChatTab(activeTab) && activeTab.id === tab.id; + const isRenaming = renamingTabId === tab.id; + if (isRenaming) { return ( setIsRenaming(false)} + onKeyUp={(e) => handleKeyUpOnRename(e, tab.id)} + onBlur={() => setRenamingTabId(null)} autoFocus size="2" - defaultValue={isTitleGenerated ? chat.title : ""} + defaultValue={tab.title} onChange={handleChatTitleChange} className={styles.RenameInput} /> @@ -309,35 +341,31 @@ export const Toolbar = ({ activeTab }: ToolbarProps) => { return ( { - if (shouldChatTabLinkBeNotClickable) return; - goToTab({ type: "chat", id: chat.id }); - }} + key={tab.id} + onClick={() => goToTab({ type: "chat", id: tab.id })} style={{ minWidth: 0, maxWidth: "150px", cursor: "pointer" }} ref={isActive ? setFocus : undefined} - title={chat.title} + title={tab.title} > - {isStreamingThisTab && } - {!isStreamingThisTab && chat.read === false && ( - - )} + {(tab.streaming || tab.waiting) && } + {!tab.streaming && !tab.waiting && tab.read === false && } - {chat.title} + {tab.title} - {isActive && !isStreamingThisTab && isOnlyOneChatTab && ( + e.stopPropagation()} > @@ -346,22 +374,29 @@ export const Toolbar = ({ activeTab }: ToolbarProps) => { size="1" side="bottom" align="end" - style={{ - minWidth: 110, - }} + style={{ minWidth: 110 }} > - + handleChatThreadRenaming(tab.id)}> Rename handleChatThreadDeletion(tab.id)} color="red" > Delete chat - )} + handleCloseTab(e, tab.id)} + > + + + ); diff --git a/refact-agent/gui/src/components/UsageCounter/UsageCounter.module.css b/refact-agent/gui/src/components/UsageCounter/UsageCounter.module.css index bccdf5a9c..444f74a8a 100644 --- a/refact-agent/gui/src/components/UsageCounter/UsageCounter.module.css +++ b/refact-agent/gui/src/components/UsageCounter/UsageCounter.module.css @@ -3,12 +3,44 @@ margin-right: var(--space-3); display: flex; align-items: center; - padding: var(--space-2) var(--space-3); + padding: var(--space-1) var(--space-2); gap: 8px; max-width: max-content; opacity: 0.7; } +.usageCounterBorderless { + --base-card-surface-box-shadow: none; + --base-card-surface-hover-box-shadow: none; + --base-card-surface-active-box-shadow: none; + background: transparent; + padding: 0; + margin-right: var(--space-2); +} + +.circularProgress { + transform: rotate(-90deg); +} + +.circularProgressBg { + fill: none; + stroke: var(--gray-a5); +} + +.circularProgressFill { + fill: none; + stroke: var(--accent-9); + transition: stroke-dashoffset 0.3s ease; +} + +.circularProgressFillWarning { + stroke: var(--yellow-9); +} + +.circularProgressFillOverflown { + stroke: var(--red-9); +} + .usageCounterContainerInline { padding: calc(var(--space-1) * 1.5); --color-panel: transparent; diff --git a/refact-agent/gui/src/components/UsageCounter/UsageCounter.stories.tsx b/refact-agent/gui/src/components/UsageCounter/UsageCounter.stories.tsx index 431c0c236..edfd10286 100644 --- a/refact-agent/gui/src/components/UsageCounter/UsageCounter.stories.tsx +++ b/refact-agent/gui/src/components/UsageCounter/UsageCounter.stories.tsx @@ -28,6 +28,7 @@ const MockedStore: React.FC<{ isInline = false, isMessageEmpty = false, }) => { + const threadId = "test"; const store = setUpStore({ config: { themeProps: { @@ -37,36 +38,47 @@ const MockedStore: React.FC<{ lspPort: 8001, }, chat: { - streaming: false, - error: null, - waiting_for_response: false, - prevent_send: false, - send_immediately: false, - tool_use: "agent", - system_prompt: {}, - cache: {}, - queued_messages: [], - thread: { - id: "test", - messages: [ - { - role: "user", - content: "Hello, how are you?", + current_thread_id: threadId, + open_thread_ids: [threadId], + threads: { + [threadId]: { + thread: { + id: threadId, + messages: [ + { + role: "user", + content: "Hello, how are you?", + }, + { + role: "assistant", + content: "Test content", + usage, + }, + ], + model: "claude-3-5-sonnet", + mode: "AGENT", + new_chat_suggested: { + wasSuggested: false, + }, + currentMaximumContextTokens: threadMaximumContextTokens, + currentMessageContextTokens, }, - { - role: "assistant", - content: "Test content", - usage, + streaming: false, + waiting_for_response: false, + prevent_send: false, + error: null, + queued_messages: [], + send_immediately: false, + attached_images: [], + confirmation: { + pause: false, + pause_reasons: [], + status: { wasInteracted: false, confirmationStatus: true }, }, - ], - model: "claude-3-5-sonnet", - mode: "AGENT", - new_chat_suggested: { - wasSuggested: false, }, - currentMaximumContextTokens: threadMaximumContextTokens, - currentMessageContextTokens, }, + tool_use: "agent", + system_prompt: {}, }, }); diff --git a/refact-agent/gui/src/components/UsageCounter/UsageCounter.tsx b/refact-agent/gui/src/components/UsageCounter/UsageCounter.tsx index 37170107e..bb0b5f12f 100644 --- a/refact-agent/gui/src/components/UsageCounter/UsageCounter.tsx +++ b/refact-agent/gui/src/components/UsageCounter/UsageCounter.tsx @@ -1,5 +1,4 @@ -import { ArrowDownIcon, ArrowUpIcon } from "@radix-ui/react-icons"; -import { Box, Card, Flex, HoverCard, Tabs, Text } from "@radix-ui/themes"; +import { Card, Flex, HoverCard, Text, Box } from "@radix-ui/themes"; import classNames from "classnames"; import React, { useMemo, useState } from "react"; @@ -7,10 +6,10 @@ import { calculateUsageInputTokens } from "../../utils/calculateUsageInputTokens import { ScrollArea } from "../ScrollArea"; import { useUsageCounter } from "./useUsageCounter"; -import { selectAllImages } from "../../features/AttachedImages"; import { selectThreadCurrentMessageTokens, selectThreadMaximumTokens, + selectThreadImages, } from "../../features/Chat"; import { formatNumberToFixed } from "../../utils/formatNumberToFixed"; import { @@ -22,7 +21,57 @@ import { import styles from "./UsageCounter.module.css"; import { Coin } from "../../images"; -import { CompressionStrength, Usage } from "../../services/refact"; + +type CircularProgressProps = { + value: number; + max: number; + size?: number; + strokeWidth?: number; +}; + +const CircularProgress: React.FC = ({ + value, + max, + size = 20, + strokeWidth = 3, +}) => { + const percentage = max > 0 ? Math.min((value / max) * 100, 100) : 0; + const radius = (size - strokeWidth) / 2; + const circumference = 2 * Math.PI * radius; + const strokeDashoffset = circumference - (percentage / 100) * circumference; + + const isWarning = percentage >= 70 && percentage < 90; + const isOverflown = percentage >= 90; + + return ( + + + + + ); +}; type UsageCounterProps = | { @@ -46,51 +95,6 @@ const TokenDisplay: React.FC<{ label: string; value: number }> = ({ ); -const TokensDisplay: React.FC<{ - currentThreadUsage?: Usage | null; - inputTokens: number; - outputTokens: number; -}> = ({ currentThreadUsage, inputTokens, outputTokens }) => { - if (!currentThreadUsage) return; - const { - cache_read_input_tokens, - cache_creation_input_tokens, - completion_tokens_details, - prompt_tokens, - } = currentThreadUsage; - - return ( - - - Tokens spent per chat thread: - - - - - - {cache_read_input_tokens !== undefined && ( - - )} - {cache_creation_input_tokens !== undefined && ( - - )} - - {completion_tokens_details?.reasoning_tokens !== null && ( - - )} - - ); -}; - const CoinDisplay: React.FC<{ label: React.ReactNode; value: number }> = ({ label, value, @@ -109,40 +113,6 @@ const CoinDisplay: React.FC<{ label: React.ReactNode; value: number }> = ({ ); }; -const CoinsDisplay: React.FC<{ - total: number; - prompt?: number; - generated?: number; - cacheRead?: number; - cacheCreation?: number; -}> = ({ total, prompt, generated, cacheRead, cacheCreation }) => { - return ( - - - Coins spent - - - {Math.round(total)} - - - - - {prompt && } - - {generated !== undefined && ( - - )} - - {cacheRead !== undefined && ( - - )} - {cacheCreation !== undefined && ( - - )} - - ); -}; - const InlineHoverCard: React.FC<{ messageTokens: number }> = ({ messageTokens, }) => { @@ -165,109 +135,6 @@ const InlineHoverCard: React.FC<{ messageTokens: number }> = ({ ); }; -const DefaultHoverCard: React.FC<{ - inputTokens: number; - outputTokens: number; -}> = ({ inputTokens, outputTokens }) => { - const cost = useTotalCostForChat(); - const meteringTokens = useTotalTokenMeteringForChat(); - const { currentThreadUsage } = useUsageCounter(); - const total = useMemo(() => { - return ( - (cost?.metering_coins_prompt ?? 0) + - (cost?.metering_coins_generated ?? 0) + - (cost?.metering_coins_cache_creation ?? 0) + - (cost?.metering_coins_cache_read ?? 0) - ); - }, [cost]); - const totalMetering = useMemo(() => { - if (meteringTokens === null) return null; - return Object.values(meteringTokens).reduce( - (acc, cur) => acc + cur, - 0, - ); - }, [meteringTokens]); - - const tabsOptions = useMemo(() => { - const options = []; - if (total > 0) { - options.push({ - value: "coins", - label: "Coins", - }); - } - options.push({ - value: "tokens", - label: "Tokens", - }); - return options; - }, [total]); - - const renderContent = (optionValue: string) => { - if (optionValue === "tokens" && meteringTokens && totalMetering !== null) { - const usage: Usage = { - prompt_tokens: meteringTokens.metering_prompt_tokens_n, - total_tokens: totalMetering, - cache_creation_input_tokens: - meteringTokens.metering_cache_creation_tokens_n, - cache_read_input_tokens: meteringTokens.metering_cache_read_tokens_n, - completion_tokens: meteringTokens.metering_generated_tokens_n, - }; - return ( - - ); - } else if (optionValue === "tokens") { - return ( - - ); - } - return ( - - ); - }; - - if (tabsOptions.length === 1) { - return {renderContent(tabsOptions[0].value)}; - } - - return ( - - - {tabsOptions.map((option) => ( - - {option.label} - - ))} - - - {tabsOptions.map((option) => ( - - {renderContent(option.value)} - - ))} - - - ); -}; - const InlineHoverTriggerContent: React.FC<{ messageTokens: number }> = ({ messageTokens, }) => { @@ -281,87 +148,140 @@ const InlineHoverTriggerContent: React.FC<{ messageTokens: number }> = ({ ); }; -const formatCompressionStage = ( - strength: CompressionStrength | null | undefined, -): string | null => { - switch (strength) { - case "low": - return "1/3"; - case "medium": - return "2/3"; - case "high": - return "3/3"; - case "absent": - default: - return null; - } +const CoinsHoverContent: React.FC<{ + totalCoins: number; + prompt?: number; + generated?: number; + cacheRead?: number; + cacheCreation?: number; +}> = ({ totalCoins, prompt, generated, cacheRead, cacheCreation }) => { + return ( + + + Total coins + + + {Math.round(totalCoins)} + + + + {prompt !== undefined && prompt > 0 && ( + + )} + {generated !== undefined && generated > 0 && ( + + )} + {cacheRead !== undefined && cacheRead > 0 && ( + + )} + {cacheCreation !== undefined && cacheCreation > 0 && ( + + )} + + ); }; -const DefaultHoverTriggerContent: React.FC<{ +const TokensHoverContent: React.FC<{ + currentSessionTokens: number; + maxContextTokens: number; inputTokens: number; outputTokens: number; +}> = ({ currentSessionTokens, maxContextTokens, inputTokens, outputTokens }) => { + const percentage = maxContextTokens > 0 + ? Math.round((currentSessionTokens / maxContextTokens) * 100) + : 0; + + return ( + + + Context usage + {percentage}% + + + + {(inputTokens > 0 || outputTokens > 0) && ( + <> + + Total tokens + {inputTokens > 0 && } + {outputTokens > 0 && } + + )} + + ); +}; + +const DefaultHoverTriggerContent: React.FC<{ currentSessionTokens: number; - compressionStrength?: CompressionStrength | null; + maxContextTokens: number; totalCoins?: number; + inputTokens: number; + outputTokens: number; + coinsPrompt?: number; + coinsGenerated?: number; + coinsCacheRead?: number; + coinsCacheCreation?: number; }> = ({ - inputTokens, - outputTokens, currentSessionTokens, - compressionStrength, + maxContextTokens, totalCoins, + inputTokens, + outputTokens, + coinsPrompt, + coinsGenerated, + coinsCacheRead, + coinsCacheCreation, }) => { - const compressionLabel = formatCompressionStage(compressionStrength); - const hasCoinsOrContext = + const hasContent = (totalCoins !== undefined && totalCoins > 0) || currentSessionTokens !== 0; - const hasInputOutput = inputTokens !== 0 || outputTokens !== 0; + + if (!hasContent) return null; return ( - - {hasCoinsOrContext && ( - - {totalCoins !== undefined && totalCoins > 0 && ( - + + {totalCoins !== undefined && totalCoins > 0 && ( + + + {Math.round(totalCoins)} - )} - {currentSessionTokens !== 0 && ( - - ctx: {formatNumberToFixed(currentSessionTokens)} - - )} - {compressionLabel && ( - - ⚡{compressionLabel} - - )} - + + + + + )} - {hasInputOutput && ( - - {inputTokens !== 0 && ( - - - {formatNumberToFixed(inputTokens)} + {currentSessionTokens !== 0 && maxContextTokens > 0 && ( + + + + + + {formatNumberToFixed(currentSessionTokens)} + - )} - {outputTokens !== 0 && ( - - - {formatNumberToFixed(outputTokens)} - - )} - + + + + + )} ); @@ -372,12 +292,11 @@ export const UsageCounter: React.FC = ({ isMessageEmpty, }) => { const [open, setOpen] = useState(false); - const maybeAttachedImages = useAppSelector(selectAllImages); + const maybeAttachedImages = useAppSelector(selectThreadImages); const { currentThreadUsage, isOverflown, isWarning, - compressionStrength, currentSessionTokens, } = useUsageCounter(); const currentMessageTokens = useAppSelector(selectThreadCurrentMessageTokens); @@ -433,9 +352,14 @@ export const UsageCounter: React.FC = ({ return outputMeteringTokens ?? outputUsageTokens; }, [outputMeteringTokens, outputUsageTokens]); + const maxContextTokens = useAppSelector(selectThreadMaximumTokens) ?? 0; + const shouldUsageBeHidden = useMemo(() => { - return !isInline && inputTokens === 0 && outputTokens === 0; - }, [outputTokens, inputTokens, isInline]); + if (isInline) return false; + const hasCoins = totalCoins > 0; + const hasContext = currentSessionTokens > 0; + return !hasCoins && !hasContext; + }, [totalCoins, currentSessionTokens, isInline]); useEffectOnce(() => { const handleScroll = (event: WheelEvent) => { @@ -455,6 +379,32 @@ export const UsageCounter: React.FC = ({ if (shouldUsageBeHidden) return null; + // For non-inline (panel) usage, render borderless with individual hovercards + if (!isInline) { + return ( + + + + ); + } + + // For inline usage (chat form), keep the HoverCard with detailed info return ( @@ -465,17 +415,7 @@ export const UsageCounter: React.FC = ({ [styles.isOverflown]: isOverflown, })} > - {isInline ? ( - - ) : ( - - )} + @@ -485,18 +425,11 @@ export const UsageCounter: React.FC = ({ maxWidth="90vw" minWidth="300px" avoidCollisions - align={isInline ? "center" : "end"} + align="center" side="top" hideWhenDetached > - {isInline ? ( - - ) : ( - - )} + diff --git a/refact-agent/gui/src/features/App.tsx b/refact-agent/gui/src/features/App.tsx index 74d6084ed..60ae81465 100644 --- a/refact-agent/gui/src/features/App.tsx +++ b/refact-agent/gui/src/features/App.tsx @@ -8,6 +8,7 @@ import { useConfig, useEffectOnce, useEventsBusForIDE, + useTrajectoriesSubscription, } from "../hooks"; import { FIMDebug } from "./FIM"; import { store, persistor, RootState } from "../app/store"; @@ -70,6 +71,7 @@ export const InnerApp: React.FC = ({ style }: AppProps) => { useEventBusForWeb(); useEventBusForApp(); usePatchesAndDiffsEventsForIDE(); + useTrajectoriesSubscription(); const [isPaddingApplied, setIsPaddingApplied] = useState(false); diff --git a/refact-agent/gui/src/features/AttachedImages/imagesSlice.ts b/refact-agent/gui/src/features/AttachedImages/imagesSlice.ts deleted file mode 100644 index 7b432f9dc..000000000 --- a/refact-agent/gui/src/features/AttachedImages/imagesSlice.ts +++ /dev/null @@ -1,40 +0,0 @@ -import { createSlice, type PayloadAction } from "@reduxjs/toolkit"; - -export type ImageFile = { - name: string; - content: string | ArrayBuffer | null; - type: string; -}; - -const initialState: { - images: ImageFile[]; -} = { - images: [], -}; - -export const attachedImagesSlice = createSlice({ - name: "attachedImages", - initialState: initialState, - reducers: { - addImage: (state, action: PayloadAction) => { - if (state.images.length < 10) { - state.images = state.images.concat(action.payload); - } - }, - removeImageByIndex: (state, action: PayloadAction) => { - state.images = state.images.filter( - (_image, index) => index !== action.payload, - ); - }, - resetAttachedImagesSlice: () => { - return initialState; - }, - }, - selectors: { - selectAllImages: (state) => state.images, - }, -}); - -export const { selectAllImages } = attachedImagesSlice.selectors; -export const { addImage, removeImageByIndex, resetAttachedImagesSlice } = - attachedImagesSlice.actions; diff --git a/refact-agent/gui/src/features/AttachedImages/index.ts b/refact-agent/gui/src/features/AttachedImages/index.ts deleted file mode 100644 index 338444c6b..000000000 --- a/refact-agent/gui/src/features/AttachedImages/index.ts +++ /dev/null @@ -1 +0,0 @@ -export * from "./imagesSlice"; diff --git a/refact-agent/gui/src/features/Chat/Chat.test.tsx b/refact-agent/gui/src/features/Chat/Chat.test.tsx index 01aae5a1a..700a37df7 100644 --- a/refact-agent/gui/src/features/Chat/Chat.test.tsx +++ b/refact-agent/gui/src/features/Chat/Chat.test.tsx @@ -47,6 +47,8 @@ import { chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, } from "../../utils/mockServer"; const handlers = [ @@ -60,6 +62,8 @@ const handlers = [ chatLinks, telemetryChat, telemetryNetwork, + emptyTrajectories, + trajectorySave, ]; // const handlers = [ diff --git a/refact-agent/gui/src/features/Chat/Thread/actions.ts b/refact-agent/gui/src/features/Chat/Thread/actions.ts index 7495754dd..66f1ddb23 100644 --- a/refact-agent/gui/src/features/Chat/Thread/actions.ts +++ b/refact-agent/gui/src/features/Chat/Thread/actions.ts @@ -4,17 +4,17 @@ import { type ChatThread, type PayloadWithId, type ToolUse, + type ImageFile, IntegrationMeta, LspChatMode, PayloadWithChatAndMessageId, PayloadWithChatAndBoolean, PayloadWithChatAndNumber, } from "./types"; +import type { ToolConfirmationPauseReason } from "../../../services/refact"; import { - isAssistantDelta, isAssistantMessage, isCDInstructionMessage, - isChatResponseChoice, isToolCallMessage, isToolMessage, isUserMessage, @@ -26,15 +26,16 @@ import { import type { AppDispatch, RootState } from "../../../app/store"; import { type SystemPrompts } from "../../../services/refact/prompts"; import { formatMessagesForLsp, consumeStream } from "./utils"; -import { generateChatTitle, sendChat } from "../../../services/refact/chat"; +import { sendChat } from "../../../services/refact/chat"; // import { ToolCommand, toolsApi } from "../../../services/refact/tools"; import { scanFoDuplicatesWith, takeFromEndWhile } from "../../../utils"; import { ChatHistoryItem } from "../../History/historySlice"; import { ideToolCallResponse } from "../../../hooks/useEventBusForIDE"; import { - capsApi, DetailMessageWithErrorType, isDetailMessage, + trajectoriesApi, + trajectoryDataToChatThread, } from "../../../services/refact"; export const newChatAction = createAction | undefined>( @@ -51,9 +52,7 @@ export const chatResponse = createAction( "chatThread/response", ); -export const chatTitleGenerationResponse = createAction< - PayloadWithId & ChatResponse ->("chatTitleGeneration/response"); + export const chatAskedQuestion = createAction( "chatThread/askQuestion", @@ -91,7 +90,7 @@ export const doneStreaming = createAction( export const setChatModel = createAction("chatThread/setChatModel"); export const getSelectedChatModel = (state: RootState) => - state.chat.thread.model; + state.chat.threads[state.chat.current_thread_id]?.thread.model ?? ""; export const setSystemPrompt = createAction( "chatThread/setSystemPrompt", @@ -105,6 +104,48 @@ export const restoreChat = createAction( "chatThread/restoreChat", ); +// Update an already-open thread with fresh data from backend (used by subscription) +export const updateOpenThread = createAction<{ + id: string; + thread: Partial; +}>("chatThread/updateOpenThread"); + +export const switchToThread = createAction( + "chatThread/switchToThread", +); + +export const closeThread = createAction( + "chatThread/closeThread", +); + +export const setThreadPauseReasons = createAction<{ + id: string; + pauseReasons: ToolConfirmationPauseReason[]; +}>("chatThread/setPauseReasons"); + +export const clearThreadPauseReasons = createAction( + "chatThread/clearPauseReasons", +); + +export const setThreadConfirmationStatus = createAction<{ + id: string; + wasInteracted: boolean; + confirmationStatus: boolean; +}>("chatThread/setConfirmationStatus"); + +export const addThreadImage = createAction<{ id: string; image: ImageFile }>( + "chatThread/addImage", +); + +export const removeThreadImageByIndex = createAction<{ + id: string; + index: number; +}>("chatThread/removeImageByIndex"); + +export const resetThreadImages = createAction( + "chatThread/resetImages", +); + export const clearChatError = createAction( "chatThread/clearError", ); @@ -116,9 +157,6 @@ export const setPreventSend = createAction( export const setAreFollowUpsEnabled = createAction( "chat/setAreFollowUpsEnabled", ); -export const setIsTitleGenerationEnabled = createAction( - "chat/setIsTitleGenerationEnabled", -); export const setUseCompression = createAction( "chat/setUseCompression", @@ -170,7 +208,7 @@ export const setIntegrationData = createAction | null>( "chatThread/setIntegrationData", ); -export const setIsWaitingForResponse = createAction( +export const setIsWaitingForResponse = createAction<{ id: string; value: boolean }>( "chatThread/setIsWaiting", ); @@ -205,91 +243,6 @@ const createAppAsyncThunk = createAsyncThunk.withTypes<{ dispatch: AppDispatch; }>(); -export const chatGenerateTitleThunk = createAppAsyncThunk< - unknown, - { - messages: ChatMessages; - chatId: string; - } ->("chatThread/generateTitle", async ({ messages, chatId }, thunkAPI) => { - const state = thunkAPI.getState(); - - const messagesToSend = messages.filter( - (msg) => - !isToolMessage(msg) && !isAssistantMessage(msg) && msg.content !== "", - ); - // .map((msg) => { - // if (isAssistantMessage(msg)) { - // return { - // role: msg.role, - // content: msg.content, - // }; - // } - // return msg; - // }); - - const caps = await thunkAPI - .dispatch(capsApi.endpoints.getCaps.initiate(undefined)) - .unwrap(); - const model = caps.chat_default_model; - const messagesForLsp = formatMessagesForLsp([ - ...messagesToSend, - { - role: "user", - content: - "Summarize the chat above in 2-3 words. Prefer filenames, classes, entities, and avoid generic terms. Example: 'Explain MyClass::f()'. Write nothing else, only the 2-3 words.", - checkpoints: [], - }, - ]); - - const chatResponseChunks: ChatResponse[] = []; - - return generateChatTitle({ - messages: messagesForLsp, - model, - stream: true, - abortSignal: thunkAPI.signal, - chatId, - apiKey: state.config.apiKey, - port: state.config.lspPort, - }) - .then((response) => { - if (!response.ok) { - return Promise.reject(new Error(response.statusText)); - } - const reader = response.body?.getReader(); - if (!reader) return; - const onAbort = () => thunkAPI.dispatch(setPreventSend({ id: chatId })); - const onChunk = (json: Record) => { - chatResponseChunks.push(json as ChatResponse); - }; - return consumeStream(reader, thunkAPI.signal, onAbort, onChunk); - }) - .catch((err: Error) => { - thunkAPI.dispatch(doneStreaming({ id: chatId })); - thunkAPI.dispatch(chatError({ id: chatId, message: err.message })); - return thunkAPI.rejectWithValue(err.message); - }) - .finally(() => { - const title = chatResponseChunks.reduce((acc, chunk) => { - if (isChatResponseChoice(chunk)) { - if (isAssistantDelta(chunk.choices[0].delta)) { - const deltaContent = chunk.choices[0].delta.content; - if (deltaContent) { - return acc + deltaContent; - } - } - } - return acc; - }, ""); - - thunkAPI.dispatch( - saveTitle({ id: chatId, title, isTitleGenerated: true }), - ); - thunkAPI.dispatch(doneStreaming({ id: chatId })); - }); -}); - function checkForToolLoop(message: ChatMessages): boolean { const assistantOrToolMessages = takeFromEndWhile(message, (message) => { return ( @@ -338,22 +291,16 @@ export const chatAskQuestionThunk = createAppAsyncThunk< messages: ChatMessages; chatId: string; checkpointsEnabled?: boolean; - mode?: LspChatMode; // used once for actions - // TODO: make a separate function for this... and it'll need to be saved. + mode?: LspChatMode; } >( "chatThread/sendChat", ({ messages, chatId, mode, checkpointsEnabled }, thunkAPI) => { const state = thunkAPI.getState(); - const thread = - chatId in state.chat.cache - ? state.chat.cache[chatId] - : state.chat.thread.id === chatId - ? state.chat.thread - : null; + const runtime = state.chat.threads[chatId]; + const thread = runtime?.thread ?? null; - // stops the stream const onlyDeterministicMessages = checkForToolLoop(messages); const messagesForLsp = formatMessagesForLsp(messages); @@ -361,23 +308,21 @@ export const chatAskQuestionThunk = createAppAsyncThunk< const maybeLastUserMessageId = thread?.last_user_message_id; const boostReasoning = thread?.boost_reasoning ?? false; const increaseMaxTokens = thread?.increase_max_tokens ?? false; - // Only send include_project_info on the first message of a chat - // Check if there's only one user message (the current one being sent) const userMessageCount = messages.filter(isUserMessage).length; const includeProjectInfo = userMessageCount <= 1 ? thread?.include_project_info ?? true : undefined; - // Context tokens cap - send on every request, default to max if not set const contextTokensCap = thread?.context_tokens_cap ?? thread?.currentMaximumContextTokens; - // Use compression - get from state const useCompression = state.chat.use_compression; + const model = thread?.model ?? ""; + return sendChat({ messages: messagesForLsp, last_user_message_id: maybeLastUserMessageId, - model: state.chat.thread.model, + model, stream: true, abortSignal: thunkAPI.signal, increase_max_tokens: increaseMaxTokens, @@ -403,6 +348,8 @@ export const chatAskQuestionThunk = createAppAsyncThunk< const onAbort = () => { thunkAPI.dispatch(setPreventSend({ id: chatId })); thunkAPI.dispatch(fixBrokenToolMessages({ id: chatId })); + // Dispatch doneStreaming immediately on abort to clean up state + thunkAPI.dispatch(doneStreaming({ id: chatId })); }; const onChunk = (json: Record) => { const action = chatResponse({ @@ -414,9 +361,7 @@ export const chatAskQuestionThunk = createAppAsyncThunk< return consumeStream(reader, thunkAPI.signal, onAbort, onChunk); }) .catch((err: unknown) => { - // console.log("Catch called"); const isError = err instanceof Error; - thunkAPI.dispatch(doneStreaming({ id: chatId })); thunkAPI.dispatch(fixBrokenToolMessages({ id: chatId })); const errorObject: DetailMessageWithErrorType = { @@ -431,7 +376,11 @@ export const chatAskQuestionThunk = createAppAsyncThunk< return thunkAPI.rejectWithValue(errorObject); }) .finally(() => { - thunkAPI.dispatch(doneStreaming({ id: chatId })); + // Only dispatch doneStreaming if not aborted - abort handler already did it + // This prevents "late cleanup" from corrupting a new request that started + if (!thunkAPI.signal.aborted) { + thunkAPI.dispatch(doneStreaming({ id: chatId })); + } }); }, ); @@ -443,16 +392,15 @@ export const sendCurrentChatToLspAfterToolCallUpdate = createAppAsyncThunk< "chatThread/sendCurrentChatToLspAfterToolCallUpdate", async ({ chatId, toolCallId }, thunkApi) => { const state = thunkApi.getState(); - if (state.chat.thread.id !== chatId) return; - if ( - state.chat.streaming || - state.chat.prevent_send || - state.chat.waiting_for_response - ) { + const runtime = state.chat.threads[chatId]; + if (!runtime) return; + + if (runtime.streaming || runtime.prevent_send || runtime.waiting_for_response) { return; } + const lastMessages = takeFromEndWhile( - state.chat.thread.messages, + runtime.thread.messages, (message) => !isUserMessage(message) && !isAssistantMessage(message), ); @@ -462,15 +410,47 @@ export const sendCurrentChatToLspAfterToolCallUpdate = createAppAsyncThunk< ); if (!toolUseInThisSet) return; - thunkApi.dispatch(setIsWaitingForResponse(true)); + thunkApi.dispatch(setIsWaitingForResponse({ id: chatId, value: true })); return thunkApi.dispatch( chatAskQuestionThunk({ - messages: state.chat.thread.messages, + messages: runtime.thread.messages, chatId, - mode: state.chat.thread.mode, + mode: runtime.thread.mode, checkpointsEnabled: state.chat.checkpoints_enabled, }), ); }, ); + +// Fetch fresh thread data from backend before restoring (re-opening a closed tab) +export const restoreChatFromBackend = createAsyncThunk< + void, + { id: string; fallback: ChatHistoryItem }, + { dispatch: AppDispatch; state: RootState } +>( + "chatThread/restoreChatFromBackend", + async ({ id, fallback }, thunkApi) => { + try { + const result = await thunkApi.dispatch( + trajectoriesApi.endpoints.getTrajectory.initiate(id, { + forceRefetch: true, + }), + ).unwrap(); + + const thread = trajectoryDataToChatThread(result); + const historyItem: ChatHistoryItem = { + ...thread, + createdAt: result.created_at, + updatedAt: result.updated_at, + title: result.title, + isTitleGenerated: result.isTitleGenerated, + }; + + thunkApi.dispatch(restoreChat(historyItem)); + } catch { + // Backend not available, use fallback from history + thunkApi.dispatch(restoreChat(fallback)); + } + }, +); diff --git a/refact-agent/gui/src/features/Chat/Thread/reducer.test.ts b/refact-agent/gui/src/features/Chat/Thread/reducer.test.ts index c56c3a80d..a3e662e82 100644 --- a/refact-agent/gui/src/features/Chat/Thread/reducer.test.ts +++ b/refact-agent/gui/src/features/Chat/Thread/reducer.test.ts @@ -1,19 +1,22 @@ import { expect, test, describe } from "vitest"; import { chatReducer } from "./reducer"; -import { chatResponse } from "./actions"; -import { createAction } from "@reduxjs/toolkit"; +import { chatResponse, newChatAction } from "./actions"; describe("Chat Thread Reducer", () => { test("streaming should be true on any response", () => { - const init = chatReducer(undefined, createAction("noop")()); + // Create initial empty state and then add a new thread + const emptyState = chatReducer(undefined, { type: "@@INIT" }); + const stateWithThread = chatReducer(emptyState, newChatAction(undefined)); + const chatId = stateWithThread.current_thread_id; + const msg = chatResponse({ - id: init.thread.id, + id: chatId, role: "tool", tool_call_id: "test_tool", content: "👀", }); - const result = chatReducer(init, msg); - expect(result.streaming).toEqual(true); + const result = chatReducer(stateWithThread, msg); + expect(result.threads[chatId]?.streaming).toEqual(true); }); }); diff --git a/refact-agent/gui/src/features/Chat/Thread/reducer.ts b/refact-agent/gui/src/features/Chat/Thread/reducer.ts index 32987febc..952cf7245 100644 --- a/refact-agent/gui/src/features/Chat/Thread/reducer.ts +++ b/refact-agent/gui/src/features/Chat/Thread/reducer.ts @@ -2,6 +2,7 @@ import { createReducer, Draft } from "@reduxjs/toolkit"; import { Chat, ChatThread, + ChatThreadRuntime, IntegrationMeta, ToolUse, LspChatMode, @@ -40,13 +41,22 @@ import { upsertToolCall, setIncreaseMaxTokens, setAreFollowUpsEnabled, - setIsTitleGenerationEnabled, setIncludeProjectInfo, setContextTokensCap, setUseCompression, enqueueUserMessage, dequeueUserMessage, clearQueuedMessages, + closeThread, + switchToThread, + updateOpenThread, + setThreadPauseReasons, + clearThreadPauseReasons, + setThreadConfirmationStatus, + addThreadImage, + removeThreadImageByIndex, + resetThreadImages, + chatAskQuestionThunk, } from "./actions"; import { formatChatResponse, postProcessMessagesAfterStreaming } from "./utils"; import { @@ -61,7 +71,6 @@ import { isUserResponse, ToolCall, ToolMessage, - UserMessage, validateToolCall, } from "../../../services/refact"; import { capsApi } from "../../../services/refact"; @@ -71,7 +80,7 @@ const createChatThread = ( integration?: IntegrationMeta | null, mode?: LspChatMode, ): ChatThread => { - const chat: ChatThread = { + return { id: uuidv4(), messages: [], title: "", @@ -80,100 +89,115 @@ const createChatThread = ( tool_use, integration, mode, - new_chat_suggested: { - wasSuggested: false, - }, + new_chat_suggested: { wasSuggested: false }, boost_reasoning: false, automatic_patch: false, increase_max_tokens: false, include_project_info: true, context_tokens_cap: undefined, }; - return chat; }; -type createInitialStateArgs = { - tool_use?: ToolUse; - integration?: IntegrationMeta | null; - maybeMode?: LspChatMode; +const createThreadRuntime = ( + tool_use: ToolUse, + integration?: IntegrationMeta | null, + mode?: LspChatMode, +): ChatThreadRuntime => { + return { + thread: createChatThread(tool_use, integration, mode), + streaming: false, + waiting_for_response: false, + prevent_send: false, + error: null, + queued_messages: [], + send_immediately: false, + attached_images: [], + confirmation: { + pause: false, + pause_reasons: [], + status: { + wasInteracted: false, + confirmationStatus: true, + }, + }, + }; }; const getThreadMode = ({ tool_use, integration, maybeMode, -}: createInitialStateArgs) => { - if (integration) { - return "CONFIGURE"; - } - if (maybeMode) { - return maybeMode === "CONFIGURE" ? "AGENT" : maybeMode; - } - +}: { + tool_use?: ToolUse; + integration?: IntegrationMeta | null; + maybeMode?: LspChatMode; +}) => { + if (integration) return "CONFIGURE"; + if (maybeMode) return maybeMode === "CONFIGURE" ? "AGENT" : maybeMode; return chatModeToLspMode({ toolUse: tool_use }); }; -const createInitialState = ({ - tool_use = "agent", - integration, - maybeMode, -}: createInitialStateArgs): Chat => { - const mode = getThreadMode({ tool_use, integration, maybeMode }); - +const createInitialState = (): Chat => { return { - streaming: false, - thread: createChatThread(tool_use, integration, mode), - error: null, - prevent_send: false, - waiting_for_response: false, - cache: {}, + current_thread_id: "", + open_thread_ids: [], + threads: {}, system_prompt: {}, - tool_use, + tool_use: "agent", checkpoints_enabled: true, - send_immediately: false, - queued_messages: [], + follow_ups_enabled: undefined, + use_compression: undefined, }; }; -const initialState = createInitialState({}); +const initialState = createInitialState(); + +const getRuntime = (state: Draft, chatId: string): Draft | null => { + return state.threads[chatId] ?? null; +}; + +const getCurrentRuntime = (state: Draft): Draft | null => { + return getRuntime(state, state.current_thread_id); +}; + + export const chatReducer = createReducer(initialState, (builder) => { builder.addCase(setToolUse, (state, action) => { - state.thread.tool_use = action.payload; state.tool_use = action.payload; - state.thread.mode = chatModeToLspMode({ toolUse: action.payload }); + const rt = getCurrentRuntime(state); + if (rt) { + rt.thread.tool_use = action.payload; + rt.thread.mode = chatModeToLspMode({ toolUse: action.payload }); + } }); builder.addCase(setPreventSend, (state, action) => { - if (state.thread.id !== action.payload.id) return state; - state.prevent_send = true; + const rt = getRuntime(state, action.payload.id); + if (rt) rt.prevent_send = true; }); builder.addCase(enableSend, (state, action) => { - if (state.thread.id !== action.payload.id) return state; - state.prevent_send = false; + const rt = getRuntime(state, action.payload.id); + if (rt) rt.prevent_send = false; }); builder.addCase(setAreFollowUpsEnabled, (state, action) => { state.follow_ups_enabled = action.payload; }); - builder.addCase(setIsTitleGenerationEnabled, (state, action) => { - state.title_generation_enabled = action.payload; - }); - builder.addCase(setUseCompression, (state, action) => { state.use_compression = action.payload; }); builder.addCase(clearChatError, (state, action) => { - if (state.thread.id !== action.payload.id) return state; - state.error = null; + const rt = getRuntime(state, action.payload.id); + if (rt) rt.error = null; }); builder.addCase(setChatModel, (state, action) => { - state.thread.model = action.payload; - state.thread.model = action.payload; + const rt = getCurrentRuntime(state); + if (rt) rt.thread.model = action.payload; }); builder.addCase(setSystemPrompt, (state, action) => { @@ -181,57 +205,40 @@ export const chatReducer = createReducer(initialState, (builder) => { }); builder.addCase(newChatAction, (state, action) => { - const next = createInitialState({ - tool_use: state.tool_use, - maybeMode: state.thread.mode, - }); - next.cache = { ...state.cache }; - if (state.streaming || state.waiting_for_response) { - next.cache[state.thread.id] = { ...state.thread, read: false }; + const currentRt = getCurrentRuntime(state); + const mode = getThreadMode({ tool_use: state.tool_use, maybeMode: currentRt?.thread.mode }); + const newRuntime = createThreadRuntime(state.tool_use, null, mode); + + if (currentRt) { + newRuntime.thread.model = currentRt.thread.model; + newRuntime.thread.boost_reasoning = currentRt.thread.boost_reasoning; } - next.thread.model = state.thread.model; - next.system_prompt = state.system_prompt; - next.checkpoints_enabled = state.checkpoints_enabled; - next.follow_ups_enabled = state.follow_ups_enabled; - next.title_generation_enabled = state.title_generation_enabled; - next.use_compression = state.use_compression; - next.thread.boost_reasoning = state.thread.boost_reasoning; - next.queued_messages = []; - // next.thread.automatic_patch = state.thread.automatic_patch; + if (action.payload?.messages) { - next.thread.messages = action.payload.messages; + newRuntime.thread.messages = action.payload.messages; } - return next; + + const newId = newRuntime.thread.id; + state.threads[newId] = newRuntime; + state.open_thread_ids.push(newId); + state.current_thread_id = newId; }); builder.addCase(chatResponse, (state, action) => { - if ( - action.payload.id !== state.thread.id && - !(action.payload.id in state.cache) - ) { - return state; - } - - if (action.payload.id in state.cache) { - const thread = state.cache[action.payload.id]; - // TODO: this might not be needed any more, because we can mutate the last message. - const messages = formatChatResponse(thread.messages, action.payload); - thread.messages = messages; - return state; - } + const rt = getRuntime(state, action.payload.id); + if (!rt) return; - const messages = formatChatResponse(state.thread.messages, action.payload); - - state.thread.messages = messages; - state.streaming = true; - state.waiting_for_response = false; + const messages = formatChatResponse(rt.thread.messages, action.payload); + rt.thread.messages = messages; + rt.streaming = true; + rt.waiting_for_response = false; if ( isUserResponse(action.payload) && action.payload.compression_strength && action.payload.compression_strength !== "absent" ) { - state.thread.new_chat_suggested = { + rt.thread.new_chat_suggested = { wasRejectedByUser: false, wasSuggested: true, }; @@ -239,48 +246,52 @@ export const chatReducer = createReducer(initialState, (builder) => { }); builder.addCase(backUpMessages, (state, action) => { - // TODO: should it also save to history? - state.error = null; - // state.previous_message_length = state.thread.messages.length; - state.thread.messages = action.payload.messages; + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.error = null; + rt.thread.messages = action.payload.messages; + } }); builder.addCase(chatError, (state, action) => { - state.streaming = false; - state.prevent_send = true; - state.waiting_for_response = false; - state.error = action.payload.message; + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.streaming = false; + rt.prevent_send = true; + rt.waiting_for_response = false; + rt.error = action.payload.message; + } }); builder.addCase(doneStreaming, (state, action) => { - if (state.thread.id !== action.payload.id) return state; - state.streaming = false; - state.waiting_for_response = false; - state.thread.read = true; - state.thread.messages = postProcessMessagesAfterStreaming( - state.thread.messages, - ); + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.streaming = false; + rt.waiting_for_response = false; + rt.thread.read = action.payload.id === state.current_thread_id; + rt.thread.messages = postProcessMessagesAfterStreaming(rt.thread.messages); + } }); builder.addCase(setAutomaticPatch, (state, action) => { - if (state.thread.id !== action.payload.chatId) return state; - state.thread.automatic_patch = action.payload.value; + const rt = getRuntime(state, action.payload.chatId); + if (rt) rt.thread.automatic_patch = action.payload.value; }); builder.addCase(setIsNewChatSuggested, (state, action) => { - if (state.thread.id !== action.payload.chatId) return state; - state.thread.new_chat_suggested = { - wasSuggested: action.payload.value, - }; + const rt = getRuntime(state, action.payload.chatId); + if (rt) rt.thread.new_chat_suggested = { wasSuggested: action.payload.value }; }); builder.addCase(setIsNewChatSuggestionRejected, (state, action) => { - if (state.thread.id !== action.payload.chatId) return state; - state.prevent_send = false; - state.thread.new_chat_suggested = { - ...state.thread.new_chat_suggested, - wasRejectedByUser: action.payload.value, - }; + const rt = getRuntime(state, action.payload.chatId); + if (rt) { + rt.prevent_send = false; + rt.thread.new_chat_suggested = { + ...rt.thread.new_chat_suggested, + wasRejectedByUser: action.payload.value, + }; + } }); builder.addCase(setEnabledCheckpoints, (state, action) => { @@ -288,194 +299,253 @@ export const chatReducer = createReducer(initialState, (builder) => { }); builder.addCase(setBoostReasoning, (state, action) => { - if (state.thread.id !== action.payload.chatId) return state; - state.thread.boost_reasoning = action.payload.value; + const rt = getRuntime(state, action.payload.chatId); + if (rt) rt.thread.boost_reasoning = action.payload.value; }); builder.addCase(setLastUserMessageId, (state, action) => { - if (state.thread.id !== action.payload.chatId) return state; - state.thread.last_user_message_id = action.payload.messageId; + const rt = getRuntime(state, action.payload.chatId); + if (rt) rt.thread.last_user_message_id = action.payload.messageId; }); builder.addCase(chatAskedQuestion, (state, action) => { - if (state.thread.id !== action.payload.id) return state; - state.send_immediately = false; - state.waiting_for_response = true; - state.thread.read = false; - state.prevent_send = false; + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.send_immediately = false; + rt.waiting_for_response = true; + rt.thread.read = false; + rt.prevent_send = false; + } }); builder.addCase(removeChatFromCache, (state, action) => { - if (!(action.payload.id in state.cache)) return state; + const id = action.payload.id; + const rt = state.threads[id]; + if (rt && !rt.streaming && !rt.confirmation.pause) { + delete state.threads[id]; + state.open_thread_ids = state.open_thread_ids.filter((tid) => tid !== id); + } + }); - const cache = Object.entries(state.cache).reduce< - Record - >((acc, cur) => { - if (cur[0] === action.payload.id) return acc; - return { ...acc, [cur[0]]: cur[1] }; - }, {}); - state.cache = cache; + builder.addCase(closeThread, (state, action) => { + const id = action.payload.id; + const force = action.payload.force ?? false; + state.open_thread_ids = state.open_thread_ids.filter((tid) => tid !== id); + const rt = state.threads[id]; + if (rt && (force || (!rt.streaming && !rt.waiting_for_response && !rt.confirmation.pause))) { + delete state.threads[id]; + } + if (state.current_thread_id === id) { + state.current_thread_id = state.open_thread_ids[0] ?? ""; + } }); builder.addCase(restoreChat, (state, action) => { - if (state.thread.id === action.payload.id) return state; - const mostUptoDateThread = - action.payload.id in state.cache - ? { ...state.cache[action.payload.id] } - : { ...action.payload, read: true }; - - state.error = null; - state.waiting_for_response = false; - - if (state.streaming) { - state.cache[state.thread.id] = { ...state.thread, read: false }; - } - if (action.payload.id in state.cache) { - const { [action.payload.id]: _, ...rest } = state.cache; - state.cache = rest; - state.streaming = true; - } else { - state.streaming = false; + const existingRt = getRuntime(state, action.payload.id); + if (existingRt) { + // Runtime exists (possibly running in background) - re-add to open tabs if needed + if (!state.open_thread_ids.includes(action.payload.id)) { + state.open_thread_ids.push(action.payload.id); + } + state.current_thread_id = action.payload.id; + existingRt.thread.read = true; + return; } - state.prevent_send = true; - state.thread = { - new_chat_suggested: { wasSuggested: false }, - ...mostUptoDateThread, + + const mode = action.payload.mode && isLspChatMode(action.payload.mode) + ? action.payload.mode + : "AGENT"; + const newRuntime: ChatThreadRuntime = { + thread: { + new_chat_suggested: { wasSuggested: false }, + ...action.payload, + mode, + tool_use: action.payload.tool_use ?? state.tool_use, + read: true, + }, + streaming: false, + waiting_for_response: false, + prevent_send: false, + error: null, + queued_messages: [], + send_immediately: false, + attached_images: [], + confirmation: { + pause: false, + pause_reasons: [], + status: { + wasInteracted: false, + confirmationStatus: true, + }, + }, }; - state.thread.messages = postProcessMessagesAfterStreaming( - state.thread.messages, + newRuntime.thread.messages = postProcessMessagesAfterStreaming( + newRuntime.thread.messages, ); - state.thread.tool_use = state.thread.tool_use ?? state.tool_use; - if (action.payload.mode && !isLspChatMode(action.payload.mode)) { - state.thread.mode = "AGENT"; - } - const lastUserMessage = action.payload.messages.reduce( - (acc, cur) => { - if (isUserMessage(cur)) return cur; - return acc; - }, + const lastUserMessage = action.payload.messages.reduce( + (acc, cur) => (isUserMessage(cur) ? cur : acc), null, ); - if ( lastUserMessage?.compression_strength && lastUserMessage.compression_strength !== "absent" ) { - state.thread.new_chat_suggested = { + newRuntime.thread.new_chat_suggested = { wasRejectedByUser: false, wasSuggested: true, }; } + + state.threads[action.payload.id] = newRuntime; + if (!state.open_thread_ids.includes(action.payload.id)) { + state.open_thread_ids.push(action.payload.id); + } + state.current_thread_id = action.payload.id; + }); + + builder.addCase(switchToThread, (state, action) => { + const id = action.payload.id; + const existingRt = getRuntime(state, id); + if (existingRt) { + if (!state.open_thread_ids.includes(id)) { + state.open_thread_ids.push(id); + } + state.current_thread_id = id; + existingRt.thread.read = true; + } + }); + + // Update an already-open thread with fresh data from backend (used by subscription) + builder.addCase(updateOpenThread, (state, action) => { + const existingRt = getRuntime(state, action.payload.id); + if (!existingRt) return; + + const incomingTitle = action.payload.thread.title; + const incomingTitleGenerated = action.payload.thread.isTitleGenerated; + + // Always allow title updates if backend generated it and local didn't + if (incomingTitle && incomingTitleGenerated && !existingRt.thread.isTitleGenerated) { + existingRt.thread.title = incomingTitle; + existingRt.thread.isTitleGenerated = true; + } + + // For other fields, only update non-busy, non-current threads + // IMPORTANT: Exclude messages - local runtime is authoritative for messages + const isCurrentThread = action.payload.id === state.current_thread_id; + if (!existingRt.streaming && !existingRt.waiting_for_response && !existingRt.error && !isCurrentThread) { + const { title, isTitleGenerated, messages, ...otherFields } = action.payload.thread; + existingRt.thread = { + ...existingRt.thread, + ...otherFields, + }; + } }); - // New builder to save chat title within the current thread and not only inside of a history thread builder.addCase(saveTitle, (state, action) => { - if (state.thread.id !== action.payload.id) return state; - state.thread.title = action.payload.title; - state.thread.isTitleGenerated = action.payload.isTitleGenerated; + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.thread.title = action.payload.title; + rt.thread.isTitleGenerated = action.payload.isTitleGenerated; + } }); builder.addCase(newIntegrationChat, (state, action) => { - // TODO: find out about tool use - // TODO: should be CONFIGURE ? - const next = createInitialState({ - tool_use: "agent", - integration: action.payload.integration, - maybeMode: "CONFIGURE", - }); - next.thread.last_user_message_id = action.payload.request_attempt_id; - next.thread.integration = action.payload.integration; - next.thread.messages = action.payload.messages; - - next.thread.model = state.thread.model; - next.system_prompt = state.system_prompt; - next.cache = { ...state.cache }; - if (state.streaming) { - next.cache[state.thread.id] = { ...state.thread, read: false }; + const currentRt = getCurrentRuntime(state); + const newRuntime = createThreadRuntime("agent", action.payload.integration, "CONFIGURE"); + newRuntime.thread.last_user_message_id = action.payload.request_attempt_id; + newRuntime.thread.messages = action.payload.messages; + if (currentRt) { + newRuntime.thread.model = currentRt.thread.model; } - return next; + + const newId = newRuntime.thread.id; + state.threads[newId] = newRuntime; + state.open_thread_ids.push(newId); + state.current_thread_id = newId; }); builder.addCase(setSendImmediately, (state, action) => { - state.send_immediately = action.payload; + const rt = getCurrentRuntime(state); + if (rt) rt.send_immediately = action.payload; }); builder.addCase(enqueueUserMessage, (state, action) => { + const rt = getCurrentRuntime(state); + if (!rt) return; const { priority, ...rest } = action.payload; const messagePayload = { ...rest, priority }; if (priority) { - // Insert at front for "send next" (next available turn) - // Find the position after existing priority messages (stable FIFO among priority) - const insertAt = state.queued_messages.findIndex((m) => !m.priority); + const insertAt = rt.queued_messages.findIndex((m) => !m.priority); if (insertAt === -1) { - state.queued_messages.push(messagePayload); + rt.queued_messages.push(messagePayload); } else { - state.queued_messages.splice(insertAt, 0, messagePayload); + rt.queued_messages.splice(insertAt, 0, messagePayload); } } else { - state.queued_messages.push(messagePayload); + rt.queued_messages.push(messagePayload); } }); builder.addCase(dequeueUserMessage, (state, action) => { - state.queued_messages = state.queued_messages.filter( - (q) => q.id !== action.payload.queuedId, - ); + const rt = getCurrentRuntime(state); + if (rt) { + rt.queued_messages = rt.queued_messages.filter( + (q) => q.id !== action.payload.queuedId, + ); + } }); builder.addCase(clearQueuedMessages, (state) => { - state.queued_messages = []; + const rt = getCurrentRuntime(state); + if (rt) rt.queued_messages = []; }); builder.addCase(setChatMode, (state, action) => { - state.thread.mode = action.payload; + const rt = getCurrentRuntime(state); + if (rt) rt.thread.mode = action.payload; }); builder.addCase(setIntegrationData, (state, action) => { - state.thread.integration = action.payload; + const rt = getCurrentRuntime(state); + if (rt) rt.thread.integration = action.payload; }); builder.addCase(setIsWaitingForResponse, (state, action) => { - state.waiting_for_response = action.payload; + const rt = getRuntime(state, action.payload.id); + if (rt) rt.waiting_for_response = action.payload.value; }); - // TBD: should be safe to remove? builder.addCase(setMaxNewTokens, (state, action) => { - state.thread.currentMaximumContextTokens = action.payload; - // Also adjust context_tokens_cap if it exceeds the new max - if ( - state.thread.context_tokens_cap === undefined || - state.thread.context_tokens_cap > action.payload - ) { - state.thread.context_tokens_cap = action.payload; + const rt = getCurrentRuntime(state); + if (rt) { + rt.thread.currentMaximumContextTokens = action.payload; + if ( + rt.thread.context_tokens_cap === undefined || + rt.thread.context_tokens_cap > action.payload + ) { + rt.thread.context_tokens_cap = action.payload; + } } }); builder.addCase(fixBrokenToolMessages, (state, action) => { - if (action.payload.id !== state.thread.id) return state; - if (state.thread.messages.length === 0) return state; - const lastMessage = state.thread.messages[state.thread.messages.length - 1]; - if (!isToolCallMessage(lastMessage)) return state; - if (lastMessage.tool_calls.every(validateToolCall)) return state; + const rt = getRuntime(state, action.payload.id); + if (!rt || rt.thread.messages.length === 0) return; + const lastMessage = rt.thread.messages[rt.thread.messages.length - 1]; + if (!isToolCallMessage(lastMessage)) return; + if (lastMessage.tool_calls.every(validateToolCall)) return; const validToolCalls = lastMessage.tool_calls.filter(validateToolCall); - const messages = state.thread.messages.slice(0, -1); + const messages = rt.thread.messages.slice(0, -1); const newMessage = { ...lastMessage, tool_calls: validToolCalls }; - state.thread.messages = [...messages, newMessage]; + rt.thread.messages = [...messages, newMessage]; }); builder.addCase(upsertToolCall, (state, action) => { - // if (action.payload.toolCallId !== state.thread.id && !(action.payload.chatId in state.cache)) return state; - if (action.payload.chatId === state.thread.id) { - maybeAppendToolCallResultFromIdeToMessages( - state.thread.messages, - action.payload.toolCallId, - action.payload.accepted, - ); - } else if (action.payload.chatId in state.cache) { - const thread = state.cache[action.payload.chatId]; + const rt = getRuntime(state, action.payload.chatId); + if (rt) { maybeAppendToolCallResultFromIdeToMessages( - thread.messages, + rt.thread.messages, action.payload.toolCallId, action.payload.accepted, action.payload.replaceOnly, @@ -484,38 +554,91 @@ export const chatReducer = createReducer(initialState, (builder) => { }); builder.addCase(setIncreaseMaxTokens, (state, action) => { - state.thread.increase_max_tokens = action.payload; + const rt = getCurrentRuntime(state); + if (rt) rt.thread.increase_max_tokens = action.payload; }); builder.addCase(setIncludeProjectInfo, (state, action) => { - if (state.thread.id !== action.payload.chatId) return state; - state.thread.include_project_info = action.payload.value; + const rt = getRuntime(state, action.payload.chatId); + if (rt) rt.thread.include_project_info = action.payload.value; }); builder.addCase(setContextTokensCap, (state, action) => { - if (state.thread.id !== action.payload.chatId) return state; - state.thread.context_tokens_cap = action.payload.value; + const rt = getRuntime(state, action.payload.chatId); + if (rt) rt.thread.context_tokens_cap = action.payload.value; + }); + + builder.addCase(setThreadPauseReasons, (state, action) => { + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.confirmation.pause = true; + rt.confirmation.pause_reasons = action.payload.pauseReasons; + rt.confirmation.status.wasInteracted = false; + rt.confirmation.status.confirmationStatus = false; + rt.streaming = false; + rt.waiting_for_response = false; + } + }); + + builder.addCase(clearThreadPauseReasons, (state, action) => { + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.confirmation.pause = false; + rt.confirmation.pause_reasons = []; + } + }); + + builder.addCase(setThreadConfirmationStatus, (state, action) => { + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.confirmation.status.wasInteracted = action.payload.wasInteracted; + rt.confirmation.status.confirmationStatus = action.payload.confirmationStatus; + } + }); + + builder.addCase(addThreadImage, (state, action) => { + const rt = getRuntime(state, action.payload.id); + if (rt && rt.attached_images.length < 10) { + rt.attached_images.push(action.payload.image); + } + }); + + builder.addCase(removeThreadImageByIndex, (state, action) => { + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.attached_images = rt.attached_images.filter( + (_, index) => index !== action.payload.index, + ); + } + }); + + builder.addCase(resetThreadImages, (state, action) => { + const rt = getRuntime(state, action.payload.id); + if (rt) { + rt.attached_images = []; + } }); builder.addMatcher( capsApi.endpoints.getCaps.matchFulfilled, (state, action) => { const defaultModel = action.payload.chat_default_model; + const rt = getCurrentRuntime(state); + if (!rt) return; - const model = state.thread.model || defaultModel; + const model = rt.thread.model || defaultModel; if (!(model in action.payload.chat_models)) return; const currentModelMaximumContextTokens = action.payload.chat_models[model].n_ctx; - state.thread.currentMaximumContextTokens = - currentModelMaximumContextTokens; + rt.thread.currentMaximumContextTokens = currentModelMaximumContextTokens; if ( - state.thread.context_tokens_cap === undefined || - state.thread.context_tokens_cap > currentModelMaximumContextTokens + rt.thread.context_tokens_cap === undefined || + rt.thread.context_tokens_cap > currentModelMaximumContextTokens ) { - state.thread.context_tokens_cap = currentModelMaximumContextTokens; + rt.thread.context_tokens_cap = currentModelMaximumContextTokens; } }, ); @@ -523,8 +646,27 @@ export const chatReducer = createReducer(initialState, (builder) => { builder.addMatcher( commandsApi.endpoints.getCommandPreview.matchFulfilled, (state, action) => { - state.thread.currentMaximumContextTokens = action.payload.number_context; - state.thread.currentMessageContextTokens = action.payload.current_context; // assuming that this number is amount of tokens per current message + const rt = getCurrentRuntime(state); + if (rt) { + rt.thread.currentMaximumContextTokens = action.payload.number_context; + rt.thread.currentMessageContextTokens = action.payload.current_context; + } + }, + ); + + // Handle rejected chat requests - set error state so spinner hides and SSE doesn't overwrite + builder.addMatcher( + chatAskQuestionThunk.rejected.match, + (state, action) => { + const chatId = action.meta.arg.chatId; + const rt = getRuntime(state, chatId); + if (rt && action.payload) { + const payload = action.payload as { detail?: string }; + rt.error = payload.detail ?? "Unknown error"; + rt.prevent_send = true; + rt.streaming = false; + rt.waiting_for_response = false; + } }, ); }); @@ -588,7 +730,6 @@ export function maybeAppendToolCallResultFromIdeToMessages( content: { content: message, tool_call_id: toolCallId, - // assuming, that tool_failed is always false at this point tool_failed: false, }, }; diff --git a/refact-agent/gui/src/features/Chat/Thread/selectors.ts b/refact-agent/gui/src/features/Chat/Thread/selectors.ts index 6a10a1050..d5f3bc1fa 100644 --- a/refact-agent/gui/src/features/Chat/Thread/selectors.ts +++ b/refact-agent/gui/src/features/Chat/Thread/selectors.ts @@ -6,76 +6,150 @@ import { isDiffMessage, isToolMessage, isUserMessage, + ChatMessages, } from "../../../services/refact/types"; import { takeFromLast } from "../../../utils/takeFromLast"; +import { ChatThreadRuntime, QueuedUserMessage, ThreadConfirmation } from "./types"; + +// Constant default values to avoid creating new references on each selector call +const EMPTY_MESSAGES: ChatMessages = []; +const EMPTY_QUEUED: QueuedUserMessage[] = []; +const EMPTY_PAUSE_REASONS: string[] = []; +const EMPTY_IMAGES: string[] = []; +const DEFAULT_NEW_CHAT_SUGGESTED = { wasSuggested: false } as const; +const DEFAULT_CONFIRMATION: ThreadConfirmation = { + pause: false, + pause_reasons: [], + status: { wasInteracted: false, confirmationStatus: true }, +}; +const DEFAULT_CONFIRMATION_STATUS = { wasInteracted: false, confirmationStatus: true } as const; + +export const selectCurrentThreadId = (state: RootState) => state.chat.current_thread_id; +export const selectOpenThreadIds = (state: RootState) => state.chat.open_thread_ids; +export const selectAllThreads = (state: RootState) => state.chat.threads; + +export const selectRuntimeById = (state: RootState, chatId: string): ChatThreadRuntime | null => + state.chat.threads[chatId] ?? null; + +export const selectCurrentRuntime = (state: RootState): ChatThreadRuntime | null => + state.chat.threads[state.chat.current_thread_id] ?? null; + +export const selectThreadById = (state: RootState, chatId: string) => + state.chat.threads[chatId]?.thread ?? null; + +export const selectThread = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.thread ?? null; + +export const selectThreadTitle = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.thread.title; + +export const selectChatId = (state: RootState) => + state.chat.current_thread_id; + +export const selectModel = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.thread.model ?? ""; + +export const selectMessages = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.thread.messages ?? EMPTY_MESSAGES; + +export const selectMessagesById = (state: RootState, chatId: string) => + state.chat.threads[chatId]?.thread.messages ?? EMPTY_MESSAGES; -export const selectThread = (state: RootState) => state.chat.thread; -export const selectThreadTitle = (state: RootState) => state.chat.thread.title; -export const selectChatId = (state: RootState) => state.chat.thread.id; -export const selectModel = (state: RootState) => state.chat.thread.model; -export const selectMessages = (state: RootState) => state.chat.thread.messages; export const selectToolUse = (state: RootState) => state.chat.tool_use; + export const selectThreadToolUse = (state: RootState) => - state.chat.thread.tool_use; + state.chat.threads[state.chat.current_thread_id]?.thread.tool_use; + export const selectAutomaticPatch = (state: RootState) => - state.chat.thread.automatic_patch; + state.chat.threads[state.chat.current_thread_id]?.thread.automatic_patch; export const selectCheckpointsEnabled = (state: RootState) => state.chat.checkpoints_enabled; export const selectThreadBoostReasoning = (state: RootState) => - state.chat.thread.boost_reasoning; + state.chat.threads[state.chat.current_thread_id]?.thread.boost_reasoning; export const selectIncludeProjectInfo = (state: RootState) => - state.chat.thread.include_project_info; + state.chat.threads[state.chat.current_thread_id]?.thread.include_project_info; export const selectContextTokensCap = (state: RootState) => - state.chat.thread.context_tokens_cap; + state.chat.threads[state.chat.current_thread_id]?.thread.context_tokens_cap; -// TBD: only used when `/links` suggests a new chat. export const selectThreadNewChatSuggested = (state: RootState) => - state.chat.thread.new_chat_suggested; + state.chat.threads[state.chat.current_thread_id]?.thread.new_chat_suggested ?? DEFAULT_NEW_CHAT_SUGGESTED; + export const selectThreadMaximumTokens = (state: RootState) => - state.chat.thread.currentMaximumContextTokens; + state.chat.threads[state.chat.current_thread_id]?.thread.currentMaximumContextTokens; + export const selectThreadCurrentMessageTokens = (state: RootState) => - state.chat.thread.currentMessageContextTokens; + state.chat.threads[state.chat.current_thread_id]?.thread.currentMessageContextTokens; + export const selectIsWaiting = (state: RootState) => - state.chat.waiting_for_response; + state.chat.threads[state.chat.current_thread_id]?.waiting_for_response ?? false; + +export const selectIsWaitingById = (state: RootState, chatId: string) => + state.chat.threads[chatId]?.waiting_for_response ?? false; + export const selectAreFollowUpsEnabled = (state: RootState) => state.chat.follow_ups_enabled; -export const selectIsTitleGenerationEnabled = (state: RootState) => - state.chat.title_generation_enabled; + export const selectUseCompression = (state: RootState) => state.chat.use_compression; -export const selectIsStreaming = (state: RootState) => state.chat.streaming; -export const selectPreventSend = (state: RootState) => state.chat.prevent_send; -export const selectChatError = (state: RootState) => state.chat.error; + +export const selectIsStreaming = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.streaming ?? false; + +export const selectIsStreamingById = (state: RootState, chatId: string) => + state.chat.threads[chatId]?.streaming ?? false; + +export const selectPreventSend = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.prevent_send ?? false; + +export const selectPreventSendById = (state: RootState, chatId: string) => + state.chat.threads[chatId]?.prevent_send ?? false; + +export const selectChatError = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.error ?? null; + +export const selectChatErrorById = (state: RootState, chatId: string) => + state.chat.threads[chatId]?.error ?? null; + export const selectSendImmediately = (state: RootState) => - state.chat.send_immediately; + state.chat.threads[state.chat.current_thread_id]?.send_immediately ?? false; + export const getSelectedSystemPrompt = (state: RootState) => state.chat.system_prompt; +export const selectAnyThreadStreaming = createSelector( + [selectAllThreads], + (threads) => Object.values(threads).some((rt) => rt.streaming), +); + +export const selectStreamingThreadIds = createSelector( + [selectAllThreads], + (threads) => + Object.entries(threads) + .filter(([, rt]) => rt.streaming) + .map(([id]) => id), +); + export const toolMessagesSelector = createSelector( selectMessages, - (messages) => { - return messages.filter(isToolMessage); - }, + (messages) => messages.filter(isToolMessage), ); export const selectToolResultById = createSelector( [toolMessagesSelector, (_, id?: string) => id], - (messages, id) => { - return messages.find((message) => message.content.tool_call_id === id) - ?.content; - }, + (messages, id) => + messages.find((message) => message.content.tool_call_id === id)?.content, ); export const selectManyToolResultsByIds = (ids: string[]) => - createSelector(toolMessagesSelector, (messages) => { - return messages + createSelector(toolMessagesSelector, (messages) => + messages .filter((message) => ids.includes(message.content.tool_call_id)) - .map((toolMessage) => toolMessage.content); - }); + .map((toolMessage) => toolMessage.content), + ); const selectDiffMessages = createSelector(selectMessages, (messages) => messages.filter(isDiffMessage), @@ -83,27 +157,25 @@ const selectDiffMessages = createSelector(selectMessages, (messages) => export const selectDiffMessageById = createSelector( [selectDiffMessages, (_, id?: string) => id], - (messages, id) => { - return messages.find((message) => message.tool_call_id === id); - }, + (messages, id) => messages.find((message) => message.tool_call_id === id), ); export const selectManyDiffMessageByIds = (ids: string[]) => - createSelector(selectDiffMessages, (diffs) => { - return diffs.filter((message) => ids.includes(message.tool_call_id)); - }); + createSelector(selectDiffMessages, (diffs) => + diffs.filter((message) => ids.includes(message.tool_call_id)), + ); export const getSelectedToolUse = (state: RootState) => - state.chat.thread.tool_use; + state.chat.threads[state.chat.current_thread_id]?.thread.tool_use; export const selectIntegration = createSelector( selectThread, - (thread) => thread.integration, + (thread) => thread?.integration, ); export const selectThreadMode = createSelector( selectThread, - (thread) => thread.mode, + (thread) => thread?.mode, ); export const selectLastSentCompression = createSelector( @@ -121,13 +193,12 @@ export const selectLastSentCompression = createSelector( }, null, ); - return lastCompression; }, ); export const selectQueuedMessages = (state: RootState) => - state.chat.queued_messages; + state.chat.threads[state.chat.current_thread_id]?.queued_messages ?? EMPTY_QUEUED; export const selectQueuedMessagesCount = createSelector( selectQueuedMessages, @@ -139,40 +210,58 @@ export const selectHasQueuedMessages = createSelector( (queued) => queued.length > 0, ); +function hasUncalledToolsInMessages(messages: ReturnType): boolean { + if (messages.length === 0) return false; + const tailMessages = takeFromLast(messages, isUserMessage); + + const toolCalls = tailMessages.reduce((acc, cur) => { + if (!isAssistantMessage(cur)) return acc; + if (!cur.tool_calls || cur.tool_calls.length === 0) return acc; + const curToolCallIds = cur.tool_calls + .map((toolCall) => toolCall.id) + .filter((id) => id !== undefined); + return [...acc, ...curToolCallIds]; + }, []); + + if (toolCalls.length === 0) return false; + + const toolMessages = tailMessages + .map((msg) => { + if (isToolMessage(msg)) return msg.content.tool_call_id; + if ("tool_call_id" in msg && typeof msg.tool_call_id === "string") + return msg.tool_call_id; + return undefined; + }) + .filter((id): id is string => typeof id === "string"); + + return toolCalls.some((toolCallId) => !toolMessages.includes(toolCallId)); +} + +export const selectHasUncalledToolsById = (state: RootState, chatId: string): boolean => + hasUncalledToolsInMessages(selectMessagesById(state, chatId)); + export const selectHasUncalledTools = createSelector( selectMessages, - (messages) => { - if (messages.length === 0) return false; - const tailMessages = takeFromLast(messages, isUserMessage); + hasUncalledToolsInMessages, +); - const toolCalls = tailMessages.reduce((acc, cur) => { - if (!isAssistantMessage(cur)) return acc; - if (!cur.tool_calls || cur.tool_calls.length === 0) return acc; - const curToolCallIds = cur.tool_calls - .map((toolCall) => toolCall.id) - .filter((id) => id !== undefined); +export const selectThreadConfirmation = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.confirmation ?? DEFAULT_CONFIRMATION; - return [...acc, ...curToolCallIds]; - }, []); +export const selectThreadConfirmationById = (state: RootState, chatId: string) => + state.chat.threads[chatId]?.confirmation ?? DEFAULT_CONFIRMATION; - if (toolCalls.length === 0) return false; +export const selectThreadPauseReasons = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.confirmation.pause_reasons ?? EMPTY_PAUSE_REASONS; - const toolMessages = tailMessages - .map((msg) => { - if (isToolMessage(msg)) { - return msg.content.tool_call_id; - } - if ("tool_call_id" in msg && typeof msg.tool_call_id === "string") { - return msg.tool_call_id; - } - return undefined; - }) - .filter((id): id is string => typeof id === "string"); +export const selectThreadPause = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.confirmation.pause ?? false; - const hasUnsentTools = toolCalls.some( - (toolCallId) => !toolMessages.includes(toolCallId), - ); +export const selectThreadConfirmationStatus = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.confirmation.status ?? DEFAULT_CONFIRMATION_STATUS; - return hasUnsentTools; - }, -); +export const selectThreadImages = (state: RootState) => + state.chat.threads[state.chat.current_thread_id]?.attached_images ?? EMPTY_IMAGES; + +export const selectThreadImagesById = (state: RootState, chatId: string) => + state.chat.threads[chatId]?.attached_images ?? EMPTY_IMAGES; diff --git a/refact-agent/gui/src/features/Chat/Thread/types.ts b/refact-agent/gui/src/features/Chat/Thread/types.ts index 25091e93e..d76914bfc 100644 --- a/refact-agent/gui/src/features/Chat/Thread/types.ts +++ b/refact-agent/gui/src/features/Chat/Thread/types.ts @@ -1,8 +1,19 @@ -import { Usage } from "../../../services/refact"; +import { ToolConfirmationPauseReason, Usage } from "../../../services/refact"; import { SystemPrompts } from "../../../services/refact/prompts"; import { ChatMessages, UserMessage } from "../../../services/refact/types"; import { parseOrElse } from "../../../utils/parseOrElse"; +export type ImageFile = { + name: string; + content: string | ArrayBuffer | null; + type: string; +}; + +export type ToolConfirmationStatus = { + wasInteracted: boolean; + confirmationStatus: boolean; +}; + export type QueuedUserMessage = { id: string; message: UserMessage; @@ -16,6 +27,7 @@ export type IntegrationMeta = { project?: string; shouldIntermediatePageShowUp?: boolean; }; + export type ChatThread = { id: string; messages: ChatMessages; @@ -47,22 +59,32 @@ export type SuggestedChat = { export type ToolUse = "quick" | "explore" | "agent"; -export type Chat = { - streaming: boolean; +export type ChatThreadRuntime = { thread: ChatThread; - error: null | string; - prevent_send: boolean; - checkpoints_enabled?: boolean; + streaming: boolean; waiting_for_response: boolean; - max_new_tokens?: number; - cache: Record; + prevent_send: boolean; + error: string | null; + queued_messages: QueuedUserMessage[]; + send_immediately: boolean; + attached_images: ImageFile[]; + confirmation: { + pause: boolean; + pause_reasons: ToolConfirmationPauseReason[]; + status: ToolConfirmationStatus; + }; +}; + +export type Chat = { + current_thread_id: string; + open_thread_ids: string[]; + threads: Record; system_prompt: SystemPrompts; tool_use: ToolUse; - send_immediately: boolean; + checkpoints_enabled?: boolean; follow_ups_enabled?: boolean; - title_generation_enabled?: boolean; use_compression?: boolean; - queued_messages: QueuedUserMessage[]; + max_new_tokens?: number; }; export type PayloadWithId = { id: string }; diff --git a/refact-agent/gui/src/features/Chat/Thread/utils.ts b/refact-agent/gui/src/features/Chat/Thread/utils.ts index a5ec4012c..f614c6bbc 100644 --- a/refact-agent/gui/src/features/Chat/Thread/utils.ts +++ b/refact-agent/gui/src/features/Chat/Thread/utils.ts @@ -890,6 +890,14 @@ export function consumeStream( onChunk: (chunk: Record) => void, ) { const decoder = new TextDecoder(); + let abortHandled = false; + + const handleAbort = () => { + if (!abortHandled) { + abortHandled = true; + onAbort(); + } + }; function pump({ done, @@ -897,7 +905,7 @@ export function consumeStream( }: ReadableStreamReadResult): Promise { if (done) return Promise.resolve(); if (signal.aborted) { - onAbort(); + handleAbort(); return Promise.resolve(); } @@ -931,6 +939,13 @@ export function consumeStream( if (deltas.length === 0) return Promise.resolve(); for (const delta of deltas) { + // Check abort signal before processing each chunk to prevent late chunks + // from corrupting state after user stops streaming + if (signal.aborted) { + handleAbort(); + return Promise.resolve(); + } + if (!delta.startsWith("data: ")) { // eslint-disable-next-line no-console console.log("Unexpected data in streaming buf: " + delta); @@ -974,6 +989,13 @@ export function consumeStream( onChunk(json); } + + // Check abort before continuing to read more chunks + if (signal.aborted) { + handleAbort(); + return Promise.resolve(); + } + return reader.read().then(pump); } diff --git a/refact-agent/gui/src/features/Chat/currentProject.ts b/refact-agent/gui/src/features/Chat/currentProject.ts index 39c74e05c..86ed162f4 100644 --- a/refact-agent/gui/src/features/Chat/currentProject.ts +++ b/refact-agent/gui/src/features/Chat/currentProject.ts @@ -24,8 +24,10 @@ export const currentProjectInfoReducer = createReducer( ); export const selectThreadProjectOrCurrentProject = (state: RootState) => { - if (state.chat.thread.integration?.project) { - return state.chat.thread.integration.project; + const runtime = state.chat.threads[state.chat.current_thread_id]; + const thread = runtime?.thread; + if (thread?.integration?.project) { + return thread.integration.project; } - return state.chat.thread.project_name ?? state.current_project.name; + return thread?.project_name ?? state.current_project.name; }; diff --git a/refact-agent/gui/src/features/History/historySlice.ts b/refact-agent/gui/src/features/History/historySlice.ts index 9f9fabe22..6d81faee2 100644 --- a/refact-agent/gui/src/features/History/historySlice.ts +++ b/refact-agent/gui/src/features/History/historySlice.ts @@ -6,20 +6,19 @@ import { import { backUpMessages, chatAskedQuestion, - chatGenerateTitleThunk, ChatThread, doneStreaming, isLspChatMode, maybeAppendToolCallResultFromIdeToMessages, - removeChatFromCache, restoreChat, setChatMode, SuggestedChat, } from "../Chat/Thread"; import { - isAssistantMessage, - isChatGetTitleActionPayload, - isUserMessage, + trajectoriesApi, + chatThreadToTrajectoryData, + TrajectoryData, + trajectoryDataToChatThread, } from "../../services/refact"; import { AppDispatch, RootState } from "../../app/store"; import { ideToolCallResponse } from "../../hooks/useEventBusForIDE"; @@ -42,17 +41,20 @@ export type HistoryState = Record; const initialState: HistoryState = {}; function getFirstUserContentFromChat(messages: ChatThread["messages"]): string { - const message = messages.find(isUserMessage); + const message = messages.find( + (msg): msg is ChatThread["messages"][number] & { role: "user" } => + msg.role === "user", + ); if (!message) return "New Chat"; if (typeof message.content === "string") { - return message.content.replace(/^\s+/, ""); + return message.content.replace(/^\s+/, "").slice(0, 100); } - const firstUserInput = message.content.find((message) => { - if ("m_type" in message && message.m_type === "text") { + const firstUserInput = message.content.find((item) => { + if ("m_type" in item && item.m_type === "text") { return true; } - if ("type" in message && message.type === "text") { + if ("type" in item && item.type === "text") { return true; } return false; @@ -65,7 +67,37 @@ function getFirstUserContentFromChat(messages: ChatThread["messages"]): string { ? firstUserInput.text : "New Chat"; - return text.replace(/^\s+/, ""); + return text.replace(/^\s+/, "").slice(0, 100); +} + +function chatThreadToHistoryItem(thread: ChatThread): ChatHistoryItem { + const now = new Date().toISOString(); + const updatedMode = + thread.mode && !isLspChatMode(thread.mode) ? "AGENT" : thread.mode; + + return { + ...thread, + // Use thread title if available, otherwise truncated first user message + title: thread.title || getFirstUserContentFromChat(thread.messages), + createdAt: thread.createdAt ?? now, + updatedAt: now, + integration: thread.integration, + currentMaximumContextTokens: thread.currentMaximumContextTokens, + isTitleGenerated: thread.isTitleGenerated, + automatic_patch: thread.automatic_patch, + mode: updatedMode, + }; +} + +function trajectoryToHistoryItem(data: TrajectoryData): ChatHistoryItem { + const thread = trajectoryDataToChatThread(data); + return { + ...thread, + createdAt: data.created_at, + updatedAt: data.updated_at, + title: data.title, + isTitleGenerated: data.isTitleGenerated, + }; } export const historySlice = createSlice({ @@ -74,83 +106,57 @@ export const historySlice = createSlice({ reducers: { saveChat: (state, action: PayloadAction) => { if (action.payload.messages.length === 0) return state; - const now = new Date().toISOString(); - - const updatedMode = - action.payload.mode && !isLspChatMode(action.payload.mode) - ? "AGENT" - : action.payload.mode; - - const chat: ChatHistoryItem = { - ...action.payload, - title: action.payload.title - ? action.payload.title - : getFirstUserContentFromChat(action.payload.messages), - createdAt: action.payload.createdAt ?? now, - updatedAt: now, - // TODO: check if this integration may cause any issues - integration: action.payload.integration, - currentMaximumContextTokens: action.payload.currentMaximumContextTokens, - isTitleGenerated: action.payload.isTitleGenerated, - automatic_patch: action.payload.automatic_patch, - mode: updatedMode, - }; - - const messageMap = { - ...state, - }; - messageMap[chat.id] = chat; - - const messages = Object.values(messageMap); - if (messages.length <= 100) { - return messageMap; + const chat = chatThreadToHistoryItem(action.payload); + const existing = state[chat.id]; + if (existing?.isTitleGenerated && !chat.isTitleGenerated) { + chat.title = existing.title; + chat.isTitleGenerated = true; } + state[chat.id] = chat; - const sortedByLastUpdated = messages - .slice(0) - .sort((a, b) => b.updatedAt.localeCompare(a.updatedAt)); - - const newHistory = sortedByLastUpdated.slice(0, 100); - const nextState = newHistory.reduce( - (acc, chat) => ({ ...acc, [chat.id]: chat }), - {}, - ); - return nextState; + const messages = Object.values(state); + if (messages.length > 100) { + const sorted = messages.sort((a, b) => + b.updatedAt.localeCompare(a.updatedAt), + ); + return sorted.slice(0, 100).reduce( + (acc, c) => ({ ...acc, [c.id]: c }), + {}, + ); + } }, - setTitleGenerationCompletionForChat: ( - state, - action: PayloadAction, - ) => { - const chatId = action.payload; - state[chatId].isTitleGenerated = true; + hydrateHistory: (state, action: PayloadAction) => { + for (const data of action.payload) { + state[data.id] = trajectoryToHistoryItem(data); + } }, markChatAsUnread: (state, action: PayloadAction) => { - const chatId = action.payload; - state[chatId].read = false; + if (action.payload in state) { + state[action.payload].read = false; + } }, markChatAsRead: (state, action: PayloadAction) => { - const chatId = action.payload; - state[chatId].read = true; + if (action.payload in state) { + state[action.payload].read = true; + } }, deleteChatById: (state, action: PayloadAction) => { - return Object.entries(state).reduce>( - (acc, [key, value]) => { - if (key === action.payload) return acc; - return { ...acc, [key]: value }; - }, - {}, - ); + delete state[action.payload]; }, + updateChatTitleById: ( state, action: PayloadAction<{ chatId: string; newTitle: string }>, ) => { - state[action.payload.chatId].title = action.payload.newTitle; + if (action.payload.chatId in state) { + state[action.payload.chatId].title = action.payload.newTitle; + } }, + clearHistory: () => { return {}; }, @@ -187,17 +193,25 @@ export const historySlice = createSlice({ export const { saveChat, + hydrateHistory, deleteChatById, markChatAsUnread, markChatAsRead, - setTitleGenerationCompletionForChat, updateChatTitleById, clearHistory, upsertToolCallIntoHistory, } = historySlice.actions; export const { getChatById, getHistory } = historySlice.selectors; -// We could use this or reduce-reducers packages +async function persistToBackend( + dispatch: AppDispatch, + thread: ChatThread, + existingCreatedAt?: string, +) { + const data = chatThreadToTrajectoryData(thread, existingCreatedAt); + dispatch(trajectoriesApi.endpoints.saveTrajectory.initiate(data)); +} + export const historyMiddleware = createListenerMiddleware(); const startHistoryListening = historyMiddleware.startListening.withTypes< RootState, @@ -208,75 +222,17 @@ startHistoryListening({ actionCreator: doneStreaming, effect: (action, listenerApi) => { const state = listenerApi.getState(); - const isTitleGenerationEnabled = state.chat.title_generation_enabled; - - const thread = - action.payload.id in state.chat.cache - ? state.chat.cache[action.payload.id] - : state.chat.thread; - - const lastMessage = thread.messages.slice(-1)[0]; - const isTitleGenerated = thread.isTitleGenerated; - // Checking for reliable chat pause - if ( - thread.messages.length && - isAssistantMessage(lastMessage) && - !lastMessage.tool_calls - ) { - // Getting user message - const firstUserMessage = thread.messages.find(isUserMessage); - if (firstUserMessage) { - // Checking if chat title is already generated, if not - generating it - if (!isTitleGenerated && isTitleGenerationEnabled) { - listenerApi - .dispatch( - chatGenerateTitleThunk({ - messages: [firstUserMessage], - chatId: state.chat.thread.id, - }), - ) - .unwrap() - .then((response) => { - if (isChatGetTitleActionPayload(response)) { - if (typeof response.title === "string") { - listenerApi.dispatch( - saveChat({ - ...thread, - title: response.title, - }), - ); - listenerApi.dispatch( - setTitleGenerationCompletionForChat(thread.id), - ); - } - } - }) - .catch(() => { - // TODO: handle error in case if not generated, now returning user message as a title - const title = getFirstUserContentFromChat([firstUserMessage]); - listenerApi.dispatch( - saveChat({ - ...thread, - title: title, - }), - ); - }); - } - } - } else { - // Probably chat was paused with uncalled tools - listenerApi.dispatch( - saveChat({ - ...thread, - }), - ); - } - if (state.chat.thread.id === action.payload.id) { - listenerApi.dispatch(saveChat(state.chat.thread)); - } else if (action.payload.id in state.chat.cache) { - listenerApi.dispatch(saveChat(state.chat.cache[action.payload.id])); - listenerApi.dispatch(removeChatFromCache({ id: action.payload.id })); - } + + const runtime = state.chat.threads[action.payload.id]; + if (!runtime) return; + const thread = runtime.thread; + + const existingChat = state.history[thread.id]; + const existingCreatedAt = existingChat?.createdAt; + + // Title generation is now handled by the backend + listenerApi.dispatch(saveChat(thread)); + persistToBackend(listenerApi.dispatch, thread, existingCreatedAt); }, }); @@ -284,14 +240,18 @@ startHistoryListening({ actionCreator: backUpMessages, effect: (action, listenerApi) => { const state = listenerApi.getState(); - const thread = state.chat.thread; - if (thread.id !== action.payload.id) return; + const runtime = state.chat.threads[action.payload.id]; + if (!runtime) return; + const thread = runtime.thread; + + const existingChat = state.history[thread.id]; const toSave = { ...thread, messages: action.payload.messages, project_name: thread.project_name ?? state.current_project.name, }; listenerApi.dispatch(saveChat(toSave)); + persistToBackend(listenerApi.dispatch, toSave, existingChat?.createdAt); }, }); @@ -306,8 +266,8 @@ startHistoryListening({ actionCreator: restoreChat, effect: (action, listenerApi) => { const chat = listenerApi.getState().chat; - if (chat.thread.id == action.payload.id && chat.streaming) return; - if (action.payload.id in chat.cache) return; + const runtime = chat.threads[action.payload.id]; + if (runtime?.streaming) return; listenerApi.dispatch(markChatAsRead(action.payload.id)); }, }); @@ -316,12 +276,23 @@ startHistoryListening({ actionCreator: setChatMode, effect: (action, listenerApi) => { const state = listenerApi.getState(); - const thread = state.chat.thread; + const runtime = state.chat.threads[state.chat.current_thread_id]; + if (!runtime) return; + const thread = runtime.thread; if (!(thread.id in state.history)) return; + const existingChat = state.history[thread.id]; const toSave = { ...thread, mode: action.payload }; listenerApi.dispatch(saveChat(toSave)); + persistToBackend(listenerApi.dispatch, toSave, existingChat?.createdAt); }, }); -// TODO: add a listener for creating a new chat ? +startHistoryListening({ + actionCreator: deleteChatById, + effect: (action, listenerApi) => { + listenerApi.dispatch( + trajectoriesApi.endpoints.deleteTrajectory.initiate(action.payload), + ); + }, +}); diff --git a/refact-agent/gui/src/features/ToolConfirmation/confirmationSlice.ts b/refact-agent/gui/src/features/ToolConfirmation/confirmationSlice.ts deleted file mode 100644 index 129c62585..000000000 --- a/refact-agent/gui/src/features/ToolConfirmation/confirmationSlice.ts +++ /dev/null @@ -1,85 +0,0 @@ -import { createSlice, PayloadAction } from "@reduxjs/toolkit"; -import type { ToolConfirmationPauseReason } from "../../services/refact"; -import { ideToolCallResponse } from "../../hooks/useEventBusForIDE"; - -export type ConfirmationState = { - pauseReasons: ToolConfirmationPauseReason[]; - pause: boolean; - status: { - wasInteracted: boolean; - confirmationStatus: boolean; - }; -}; - -const initialState: ConfirmationState = { - pauseReasons: [], - pause: false, - status: { - wasInteracted: false, - confirmationStatus: true, - }, -}; - -type ConfirmationActionPayload = { - wasInteracted: boolean; - confirmationStatus: boolean; -}; - -export const confirmationSlice = createSlice({ - name: "confirmation", - initialState, - reducers: { - setPauseReasons( - state, - action: PayloadAction, - ) { - state.pause = true; - state.pauseReasons = action.payload; - }, - resetConfirmationInteractedState(state) { - state.status.wasInteracted = false; - state.pause = false; - state.pauseReasons = []; - }, - clearPauseReasonsAndHandleToolsStatus( - state, - action: PayloadAction, - ) { - state.pause = false; - state.pauseReasons = []; - state.status = action.payload; - }, - - updateConfirmationAfterIdeToolUse( - state, - action: PayloadAction[0]>, - ) { - const pauseReasons = state.pauseReasons.filter( - (reason) => reason.tool_call_id !== action.payload.toolCallId, - ); - if (pauseReasons.length === 0) { - state.status.wasInteracted = true; // work around for auto send. - } - state.pauseReasons = pauseReasons; - }, - }, - selectors: { - getPauseReasonsWithPauseStatus: (state) => state, - getToolsInteractionStatus: (state) => state.status.wasInteracted, - getToolsConfirmationStatus: (state) => state.status.confirmationStatus, - getConfirmationPauseStatus: (state) => state.pause, - }, -}); - -export const { - setPauseReasons, - resetConfirmationInteractedState, - clearPauseReasonsAndHandleToolsStatus, - updateConfirmationAfterIdeToolUse, -} = confirmationSlice.actions; -export const { - getPauseReasonsWithPauseStatus, - getToolsConfirmationStatus, - getToolsInteractionStatus, - getConfirmationPauseStatus, -} = confirmationSlice.selectors; diff --git a/refact-agent/gui/src/hooks/index.ts b/refact-agent/gui/src/hooks/index.ts index 22d2c9e15..18e7664c3 100644 --- a/refact-agent/gui/src/hooks/index.ts +++ b/refact-agent/gui/src/hooks/index.ts @@ -38,3 +38,4 @@ export * from "./useCompressionStop"; export * from "./useEventBusForApp"; export * from "./useTotalCostForChat"; export * from "./useCheckpoints"; +export * from "./useTrajectoriesSubscription"; diff --git a/refact-agent/gui/src/hooks/useAttachedImages.ts b/refact-agent/gui/src/hooks/useAttachedImages.ts index fb0062c65..023e969b5 100644 --- a/refact-agent/gui/src/hooks/useAttachedImages.ts +++ b/refact-agent/gui/src/hooks/useAttachedImages.ts @@ -2,35 +2,35 @@ import { useCallback, useEffect } from "react"; import { useAppSelector } from "./useAppSelector"; import { useAppDispatch } from "./useAppDispatch"; import { - selectAllImages, - removeImageByIndex, - addImage, + selectThreadImages, + selectChatId, + addThreadImage, + removeThreadImageByIndex, + resetThreadImages, type ImageFile, - resetAttachedImagesSlice, -} from "../features/AttachedImages"; +} from "../features/Chat"; import { setError } from "../features/Errors/errorsSlice"; import { setInformation } from "../features/Errors/informationSlice"; import { useCapsForToolUse } from "./useCapsForToolUse"; export function useAttachedImages() { - const images = useAppSelector(selectAllImages); + const images = useAppSelector(selectThreadImages); + const chatId = useAppSelector(selectChatId); const { isMultimodalitySupportedForCurrentModel } = useCapsForToolUse(); const dispatch = useAppDispatch(); const removeImage = useCallback( (index: number) => { - const action = removeImageByIndex(index); - dispatch(action); + dispatch(removeThreadImageByIndex({ id: chatId, index })); }, - [dispatch], + [dispatch, chatId], ); const insertImage = useCallback( (file: ImageFile) => { - const action = addImage(file); - dispatch(action); + dispatch(addThreadImage({ id: chatId, image: file })); }, - [dispatch], + [dispatch, chatId], ); const handleError = useCallback( @@ -63,10 +63,9 @@ export function useAttachedImages() { useEffect(() => { if (!isMultimodalitySupportedForCurrentModel) { - const action = resetAttachedImagesSlice(); - dispatch(action); + dispatch(resetThreadImages({ id: chatId })); } - }, [isMultimodalitySupportedForCurrentModel, dispatch]); + }, [isMultimodalitySupportedForCurrentModel, dispatch, chatId]); return { images, diff --git a/refact-agent/gui/src/hooks/useCompressChat.ts b/refact-agent/gui/src/hooks/useCompressChat.ts index f539c0483..816894b0e 100644 --- a/refact-agent/gui/src/hooks/useCompressChat.ts +++ b/refact-agent/gui/src/hooks/useCompressChat.ts @@ -12,16 +12,18 @@ export function useCompressChat() { const thread = useAppSelector(selectThread); const [submit, request] = knowledgeApi.useCompressMessagesMutation({ - fixedCacheKey: thread.id, + fixedCacheKey: thread?.id ?? "", }); const compressChat = useCallback(async () => { - dispatch(setIsWaitingForResponse(true)); + if (!thread) return; + + dispatch(setIsWaitingForResponse({ id: thread.id, value: true })); const result = await submit({ messages: thread.messages, project: thread.project_name ?? "", }); - dispatch(setIsWaitingForResponse(false)); + dispatch(setIsWaitingForResponse({ id: thread.id, value: false })); if (result.error) { // TODO: handle errors @@ -40,7 +42,7 @@ export function useCompressChat() { dispatch(action); dispatch(setSendImmediately(true)); } - }, [dispatch, submit, thread.messages, thread.project_name, thread.title]); + }, [dispatch, submit, thread]); return { compressChat, diff --git a/refact-agent/gui/src/hooks/useGoToLink.ts b/refact-agent/gui/src/hooks/useGoToLink.ts index d633f4550..60d27b149 100644 --- a/refact-agent/gui/src/hooks/useGoToLink.ts +++ b/refact-agent/gui/src/hooks/useGoToLink.ts @@ -4,15 +4,15 @@ import { isAbsolutePath } from "../utils/isAbsolutePath"; import { useAppDispatch } from "./useAppDispatch"; import { popBackTo, push } from "../features/Pages/pagesSlice"; import { useAppSelector } from "./useAppSelector"; -import { selectIntegration } from "../features/Chat/Thread/selectors"; +import { selectIntegration, selectChatId } from "../features/Chat/Thread/selectors"; import { debugIntegrations } from "../debugConfig"; -import { newChatAction } from "../features/Chat/Thread/actions"; -import { clearPauseReasonsAndHandleToolsStatus } from "../features/ToolConfirmation/confirmationSlice"; +import { newChatAction, clearThreadPauseReasons, setThreadConfirmationStatus } from "../features/Chat/Thread/actions"; export function useGoToLink() { const dispatch = useAppDispatch(); const { queryPathThenOpenFile } = useEventsBusForIDE(); const maybeIntegration = useAppSelector(selectIntegration); + const chatId = useAppSelector(selectChatId); const handleGoTo = useCallback( ({ goto }: { goto?: string }) => { @@ -55,12 +55,8 @@ export function useGoToLink() { case "newchat": { dispatch(newChatAction()); - dispatch( - clearPauseReasonsAndHandleToolsStatus({ - wasInteracted: false, - confirmationStatus: true, - }), - ); + dispatch(clearThreadPauseReasons({ id: chatId })); + dispatch(setThreadConfirmationStatus({ id: chatId, wasInteracted: false, confirmationStatus: true })); dispatch(popBackTo({ name: "history" })); dispatch(push({ name: "chat" })); return; @@ -72,15 +68,7 @@ export function useGoToLink() { } } }, - [ - dispatch, - // maybeIntegration?.name, - // maybeIntegration?.path, - // maybeIntegration?.project, - // maybeIntegration?.shouldIntermediatePageShowUp, - maybeIntegration, - queryPathThenOpenFile, - ], + [dispatch, chatId, maybeIntegration, queryPathThenOpenFile], ); return { handleGoTo }; diff --git a/refact-agent/gui/src/hooks/useSendChatRequest.ts b/refact-agent/gui/src/hooks/useSendChatRequest.ts index 3788321d8..86610e18c 100644 --- a/refact-agent/gui/src/hooks/useSendChatRequest.ts +++ b/refact-agent/gui/src/hooks/useSendChatRequest.ts @@ -18,6 +18,9 @@ import { selectThread, selectThreadMode, selectThreadToolUse, + selectThreadConfirmationStatus, + selectThreadImages, + selectThreadPause, } from "../features/Chat/Thread/selectors"; import { useCheckForConfirmationMutation } from "./useGetToolGroupsQuery"; import { @@ -35,17 +38,12 @@ import { setSendImmediately, enqueueUserMessage, dequeueUserMessage, + setThreadPauseReasons, + clearThreadPauseReasons, + setThreadConfirmationStatus, } from "../features/Chat/Thread/actions"; -import { selectAllImages } from "../features/AttachedImages"; import { useAbortControllers } from "./useAbortControllers"; -import { - clearPauseReasonsAndHandleToolsStatus, - getToolsConfirmationStatus, - getToolsInteractionStatus, - resetConfirmationInteractedState, - setPauseReasons, -} from "../features/ToolConfirmation/confirmationSlice"; import { chatModeToLspMode, doneStreaming, @@ -114,11 +112,12 @@ export const useSendChatRequest = () => { const currentMessages = useAppSelector(selectMessages); const systemPrompt = useAppSelector(getSelectedSystemPrompt); const toolUse = useAppSelector(selectThreadToolUse); - const attachedImages = useAppSelector(selectAllImages); + const attachedImages = useAppSelector(selectThreadImages); const threadMode = useAppSelector(selectThreadMode); const threadIntegration = useAppSelector(selectIntegration); - const wasInteracted = useAppSelector(getToolsInteractionStatus); // shows if tool confirmation popup was interacted by user - const areToolsConfirmed = useAppSelector(getToolsConfirmationStatus); + const confirmationStatus = useAppSelector(selectThreadConfirmationStatus); + const wasInteracted = confirmationStatus.wasInteracted; + const areToolsConfirmed = confirmationStatus.confirmationStatus; const isPatchAutomatic = useAppSelector(selectAutomaticPatch); const checkpointsEnabled = useAppSelector(selectCheckpointsEnabled); @@ -137,20 +136,24 @@ export const useSendChatRequest = () => { const sendMessages = useCallback( async (messages: ChatMessages, maybeMode?: LspChatMode) => { - dispatch(setIsWaitingForResponse(true)); + dispatch(setIsWaitingForResponse({ id: chatId, value: true })); const lastMessage = messages.slice(-1)[0]; if ( !isWaiting && !wasInteracted && isAssistantMessage(lastMessage) && - lastMessage.tool_calls + lastMessage.tool_calls && + lastMessage.tool_calls.length > 0 ) { const toolCalls = lastMessage.tool_calls; + const firstToolCall = toolCalls[0]; + // Safety check for incomplete tool calls (can happen after aborted streams) + const firstToolName = firstToolCall?.function?.name; if ( !( - toolCalls[0].function.name && - PATCH_LIKE_FUNCTIONS.includes(toolCalls[0].function.name) && + firstToolName && + PATCH_LIKE_FUNCTIONS.includes(firstToolName) && isPatchAutomatic ) ) { @@ -159,7 +162,7 @@ export const useSendChatRequest = () => { messages: messages, }).unwrap(); if (confirmationResponse.pause) { - dispatch(setPauseReasons(confirmationResponse.pause_reasons)); + dispatch(setThreadPauseReasons({ id: chatId, pauseReasons: confirmationResponse.pause_reasons })); return; } } @@ -288,46 +291,38 @@ export const useSendChatRequest = () => { abortControllers.abort(chatId); dispatch(setPreventSend({ id: chatId })); dispatch(fixBrokenToolMessages({ id: chatId })); - dispatch(setIsWaitingForResponse(false)); + dispatch(setIsWaitingForResponse({ id: chatId, value: false })); dispatch(doneStreaming({ id: chatId })); }, [abortControllers, chatId, dispatch]); const retry = useCallback( (messages: ChatMessages) => { abort(); - dispatch( - clearPauseReasonsAndHandleToolsStatus({ - wasInteracted: false, - confirmationStatus: areToolsConfirmed, - }), - ); + dispatch(clearThreadPauseReasons({ id: chatId })); + dispatch(setThreadConfirmationStatus({ id: chatId, wasInteracted: false, confirmationStatus: areToolsConfirmed })); void sendMessages(messages); }, - [abort, sendMessages, dispatch, areToolsConfirmed], + [abort, sendMessages, dispatch, chatId, areToolsConfirmed], ); const confirmToolUsage = useCallback(() => { - dispatch( - clearPauseReasonsAndHandleToolsStatus({ - wasInteracted: true, - confirmationStatus: true, - }), - ); - - dispatch(setIsWaitingForResponse(false)); - }, [dispatch]); + dispatch(clearThreadPauseReasons({ id: chatId })); + dispatch(setThreadConfirmationStatus({ id: chatId, wasInteracted: true, confirmationStatus: true })); + // Continue the conversation - sendMessages will set waiting=true and proceed + // since wasInteracted is now true, the confirmation check will be skipped + void sendMessages(currentMessages); + }, [dispatch, chatId, sendMessages, currentMessages]); const rejectToolUsage = useCallback( (toolCallIds: string[]) => { toolCallIds.forEach((toolCallId) => { - dispatch( - upsertToolCallIntoHistory({ toolCallId, chatId, accepted: false }), - ); + dispatch(upsertToolCallIntoHistory({ toolCallId, chatId, accepted: false })); dispatch(upsertToolCall({ toolCallId, chatId, accepted: false })); }); - dispatch(resetConfirmationInteractedState()); - dispatch(setIsWaitingForResponse(false)); + dispatch(clearThreadPauseReasons({ id: chatId })); + dispatch(setThreadConfirmationStatus({ id: chatId, wasInteracted: false, confirmationStatus: true })); + dispatch(setIsWaitingForResponse({ id: chatId, value: false })); dispatch(doneStreaming({ id: chatId })); dispatch(setPreventSend({ id: chatId })); }, @@ -358,7 +353,6 @@ export const useSendChatRequest = () => { }; }; -// NOTE: only use this once export function useAutoSend() { const dispatch = useAppDispatch(); const streaming = useAppSelector(selectIsStreaming); @@ -367,14 +361,15 @@ export function useAutoSend() { const preventSend = useAppSelector(selectPreventSend); const isWaiting = useAppSelector(selectIsWaiting); const sendImmediately = useAppSelector(selectSendImmediately); - const wasInteracted = useAppSelector(getToolsInteractionStatus); // shows if tool confirmation popup was interacted by user - const areToolsConfirmed = useAppSelector(getToolsConfirmationStatus); + const confirmationStatus = useAppSelector(selectThreadConfirmationStatus); + const wasInteracted = confirmationStatus.wasInteracted; + const areToolsConfirmed = confirmationStatus.confirmationStatus; + const isPaused = useAppSelector(selectThreadPause); const hasUnsentTools = useAppSelector(selectHasUncalledTools); const queuedMessages = useAppSelector(selectQueuedMessages); const { sendMessages, messagesWithSystemPrompt } = useSendChatRequest(); - // TODO: make a selector for this, or show tool formation const thread = useAppSelector(selectThread); - const isIntegration = thread.integration ?? false; + const isIntegration = thread?.integration ?? false; useEffect(() => { if (sendImmediately) { @@ -393,8 +388,9 @@ export function useAutoSend() { const stopForToolConfirmation = useMemo(() => { if (isIntegration) return false; + if (isPaused) return true; return !wasInteracted && !areToolsConfirmed; - }, [isIntegration, wasInteracted, areToolsConfirmed]); + }, [isIntegration, isPaused, wasInteracted, areToolsConfirmed]); // Base conditions for flushing queue (streaming must be done) const canFlushBase = useMemo(() => { @@ -432,7 +428,7 @@ export function useAutoSend() { dispatch(dequeueUserMessage({ queuedId: nextQueued.id })); // Send the queued message - void sendMessages([...currentMessages, nextQueued.message], thread.mode); + void sendMessages([...currentMessages, nextQueued.message], thread?.mode); }, [ canFlushBase, isFullyIdle, @@ -440,7 +436,7 @@ export function useAutoSend() { dispatch, sendMessages, currentMessages, - thread.mode, + thread?.mode, ]); // Check if there are priority messages waiting @@ -449,29 +445,11 @@ export function useAutoSend() { [queuedMessages], ); - useEffect(() => { - if (stop) return; - if (stopForToolConfirmation) return; - // Don't run tool follow-up if there are priority messages waiting - // Let the queue flush handle them first - if (hasPriorityMessages) return; - - dispatch( - clearPauseReasonsAndHandleToolsStatus({ - wasInteracted: false, - confirmationStatus: areToolsConfirmed, - }), - ); - - void sendMessages(currentMessages, thread.mode); - }, [ - areToolsConfirmed, - currentMessages, - dispatch, - hasPriorityMessages, - sendMessages, - stop, - stopForToolConfirmation, - thread.mode, - ]); + // NOTE: Tool auto-continue is handled by middleware (doneStreaming listener) + // Having it here as well caused a race condition where both would fire, + // resulting in two overlapping streaming requests that mixed up messages. + // See middleware.ts doneStreaming listener for the single source of truth. + + // Export these for components that need to know idle state + return { stop, stopForToolConfirmation, hasPriorityMessages }; } diff --git a/refact-agent/gui/src/hooks/useTrajectoriesSubscription.ts b/refact-agent/gui/src/hooks/useTrajectoriesSubscription.ts new file mode 100644 index 000000000..f6e25bc75 --- /dev/null +++ b/refact-agent/gui/src/hooks/useTrajectoriesSubscription.ts @@ -0,0 +1,195 @@ +import { useEffect, useRef, useCallback } from "react"; +import { useAppDispatch } from "./useAppDispatch"; +import { useConfig } from "./useConfig"; +import { + trajectoriesApi, + TrajectoryEvent, + chatThreadToTrajectoryData, + trajectoryDataToChatThread, +} from "../services/refact/trajectories"; +import { hydrateHistory, deleteChatById, ChatHistoryItem } from "../features/History/historySlice"; +import { updateOpenThread, closeThread } from "../features/Chat/Thread"; + +const MIGRATION_KEY = "refact-trajectories-migrated"; + +function getLegacyHistory(): ChatHistoryItem[] { + try { + const raw = localStorage.getItem("persist:root"); + if (!raw) return []; + + const parsed = JSON.parse(raw) as Record; + if (!parsed.history) return []; + + const historyState = JSON.parse(parsed.history) as Record; + return Object.values(historyState); + } catch { + return []; + } +} + +function clearLegacyHistory() { + try { + const raw = localStorage.getItem("persist:root"); + if (!raw) return; + + const parsed = JSON.parse(raw) as Record; + parsed.history = "{}"; + localStorage.setItem("persist:root", JSON.stringify(parsed)); + } catch { + // ignore + } +} + +function isMigrationDone(): boolean { + return localStorage.getItem(MIGRATION_KEY) === "true"; +} + +function markMigrationDone() { + localStorage.setItem(MIGRATION_KEY, "true"); +} + +export function useTrajectoriesSubscription() { + const dispatch = useAppDispatch(); + const config = useConfig(); + const eventSourceRef = useRef(null); + const reconnectTimeoutRef = useRef | null>(null); + + const connect = useCallback(() => { + if (typeof EventSource === "undefined") return; + + const port = config.lspPort ?? 8001; + const url = `http://127.0.0.1:${port}/v1/trajectories/subscribe`; + + if (eventSourceRef.current) { + eventSourceRef.current.close(); + } + + try { + const eventSource = new EventSource(url); + eventSourceRef.current = eventSource; + + eventSource.onmessage = (event) => { + try { + const data: TrajectoryEvent = JSON.parse(event.data); + if (data.type === "deleted") { + dispatch(deleteChatById(data.id)); + // Force delete runtime even if it's streaming - backend says it's gone + dispatch(closeThread({ id: data.id, force: true })); + } else if (data.type === "updated" || data.type === "created") { + dispatch( + trajectoriesApi.endpoints.getTrajectory.initiate(data.id, { + forceRefetch: true, + }), + ) + .unwrap() + .then((trajectory) => { + // Update history + dispatch(hydrateHistory([trajectory])); + // Also update open thread metadata if it exists (subscription signal) + // IMPORTANT: Only sync metadata, NOT messages - messages are local-authoritative + // to prevent SSE from overwriting in-progress or recently-completed conversations + const thread = trajectoryDataToChatThread(trajectory); + dispatch(updateOpenThread({ + id: data.id, + thread: { + title: thread.title, + isTitleGenerated: thread.isTitleGenerated, + // Don't sync `read` - it's a per-client concern + // Don't pass messages - they could be stale from backend + }, + })); + }) + .catch(() => {}); + } + } catch { + // ignore parse errors + } + }; + + eventSource.onerror = () => { + eventSource.close(); + // Clear any existing reconnect timer before scheduling a new one + if (reconnectTimeoutRef.current) { + clearTimeout(reconnectTimeoutRef.current); + } + reconnectTimeoutRef.current = setTimeout(connect, 5000); + }; + } catch { + // EventSource not available or connection failed + } + }, [dispatch, config.lspPort]); + + const migrateFromLocalStorage = useCallback(async () => { + if (isMigrationDone()) return; + + const legacyChats = getLegacyHistory(); + if (legacyChats.length === 0) { + markMigrationDone(); + return; + } + + let successCount = 0; + for (const chat of legacyChats) { + if (!chat.messages || chat.messages.length === 0) continue; + + try { + const trajectoryData = chatThreadToTrajectoryData( + { + ...chat, + new_chat_suggested: chat.new_chat_suggested ?? { wasSuggested: false }, + }, + chat.createdAt, + ); + trajectoryData.updated_at = chat.updatedAt; + + await dispatch( + trajectoriesApi.endpoints.saveTrajectory.initiate(trajectoryData), + ).unwrap(); + successCount++; + } catch { + // Failed to migrate this chat, continue with others + } + } + + if (successCount > 0) { + clearLegacyHistory(); + } + markMigrationDone(); + }, [dispatch]); + + const loadInitialHistory = useCallback(async () => { + try { + await migrateFromLocalStorage(); + + const result = await dispatch( + trajectoriesApi.endpoints.listTrajectories.initiate(), + ).unwrap(); + + const trajectories = await Promise.all( + result.map((meta) => + dispatch( + trajectoriesApi.endpoints.getTrajectory.initiate(meta.id), + ).unwrap(), + ), + ); + + dispatch(hydrateHistory(trajectories)); + } catch { + // Backend not available + } + }, [dispatch, migrateFromLocalStorage]); + + useEffect(() => { + loadInitialHistory(); + connect(); + + return () => { + if (eventSourceRef.current) { + eventSourceRef.current.close(); + } + if (reconnectTimeoutRef.current) { + clearTimeout(reconnectTimeoutRef.current); + } + }; + }, [connect, loadInitialHistory]); +} diff --git a/refact-agent/gui/src/services/refact/chat.ts b/refact-agent/gui/src/services/refact/chat.ts index 4f0dfd043..a25eb9c87 100644 --- a/refact-agent/gui/src/services/refact/chat.ts +++ b/refact-agent/gui/src/services/refact/chat.ts @@ -73,35 +73,6 @@ type SendChatArgs = { use_compression?: boolean; } & StreamArgs; -type GetChatTitleArgs = { - messages: LspChatMessage[]; - model: string; - lspUrl?: string; - takeNote?: boolean; - onlyDeterministicMessages?: boolean; - chatId?: string; - port?: number; - apiKey?: string | null; - boost_reasoning?: boolean; -} & StreamArgs; - -export type GetChatTitleResponse = { - choices: Choice[]; - created: number; - deterministic_messages: DeterministicMessage[]; - id: string; - metering_balance: number; - model: string; - object: string; - system_fingerprint: string; - usage: Usage; -}; - -export type GetChatTitleActionPayload = { - chatId: string; - title: string; -}; - export type Choice = { finish_reason: string; index: number; @@ -220,41 +191,4 @@ export async function sendChat({ }); } -export async function generateChatTitle({ - messages, - stream, - model, - onlyDeterministicMessages: only_deterministic_messages, - chatId: chat_id, - port = 8001, - apiKey, -}: GetChatTitleArgs): Promise { - const body = JSON.stringify({ - messages, - model, - stream, - max_tokens: 300, - only_deterministic_messages: only_deterministic_messages, - chat_id, - // NOTE: we don't want to use reasoning here, for example Anthropic requires at least max_tokens=1024 for thinking - // parameters: boost_reasoning ? { boost_reasoning: true } : undefined, - }); - - const headers = { - "Content-Type": "application/json", - ...(apiKey ? { Authorization: "Bearer " + apiKey } : {}), - }; - - const url = `http://127.0.0.1:${port}${CHAT_URL}`; - return fetch(url, { - method: "POST", - headers, - body, - redirect: "follow", - cache: "no-cache", - // TODO: causes an error during tests :/ - // referrer: "no-referrer", - credentials: "same-origin", - }); -} diff --git a/refact-agent/gui/src/services/refact/checkpoints.ts b/refact-agent/gui/src/services/refact/checkpoints.ts index 4d7f1ac85..d2000e28e 100644 --- a/refact-agent/gui/src/services/refact/checkpoints.ts +++ b/refact-agent/gui/src/services/refact/checkpoints.ts @@ -36,8 +36,9 @@ export const checkpointsApi = createApi({ const port = state.config.lspPort as unknown as number; const url = `http://127.0.0.1:${port}${PREVIEW_CHECKPOINTS}`; - const chat_id = state.chat.thread.id; - const mode = state.chat.thread.mode; + const runtime = state.chat.threads[state.chat.current_thread_id]; + const chat_id = runtime?.thread.id ?? ""; + const mode = runtime?.thread.mode; const result = await baseQuery({ url, @@ -78,8 +79,9 @@ export const checkpointsApi = createApi({ const port = state.config.lspPort as unknown as number; const url = `http://127.0.0.1:${port}${RESTORE_CHECKPOINTS}`; - const chat_id = state.chat.thread.id; - const mode = state.chat.thread.mode; + const runtime = state.chat.threads[state.chat.current_thread_id]; + const chat_id = runtime?.thread.id ?? ""; + const mode = runtime?.thread.mode; const result = await baseQuery({ url, diff --git a/refact-agent/gui/src/services/refact/consts.ts b/refact-agent/gui/src/services/refact/consts.ts index 7d0553a34..7b33efdd4 100644 --- a/refact-agent/gui/src/services/refact/consts.ts +++ b/refact-agent/gui/src/services/refact/consts.ts @@ -35,7 +35,6 @@ export const RESTORE_CHECKPOINTS = "/v1/checkpoints-restore"; export const TELEMETRY_CHAT_PATH = "/v1/telemetry-chat"; export const TELEMETRY_NET_PATH = "/v1/telemetry-network"; -export const KNOWLEDGE_CREATE_URL = "/v1/trajectory-save"; export const COMPRESS_MESSAGES_URL = "/v1/trajectory-compress"; export const SET_ACTIVE_GROUP_ID = "/v1/set-active-group-id"; diff --git a/refact-agent/gui/src/services/refact/index.ts b/refact-agent/gui/src/services/refact/index.ts index 9047e19d1..dda9ad6af 100644 --- a/refact-agent/gui/src/services/refact/index.ts +++ b/refact-agent/gui/src/services/refact/index.ts @@ -16,3 +16,4 @@ export * from "./docker"; export * from "./telemetry"; export * from "./knowledge"; export * from "./teams"; +export * from "./trajectories"; diff --git a/refact-agent/gui/src/services/refact/knowledge.ts b/refact-agent/gui/src/services/refact/knowledge.ts index 169be2988..8f7d3833a 100644 --- a/refact-agent/gui/src/services/refact/knowledge.ts +++ b/refact-agent/gui/src/services/refact/knowledge.ts @@ -1,7 +1,7 @@ import { RootState } from "../../app/store"; import { createApi, fetchBaseQuery } from "@reduxjs/toolkit/query/react"; import { formatMessagesForLsp } from "../../features/Chat/Thread/utils"; -import { COMPRESS_MESSAGES_URL, KNOWLEDGE_CREATE_URL } from "./consts"; +import { COMPRESS_MESSAGES_URL } from "./consts"; import { type ChatMessages } from "."; export type SubscribeArgs = @@ -68,21 +68,6 @@ export type CompressTrajectoryPost = { messages: ChatMessages; }; -export type SaveTrajectoryResponse = { - memid: string; - trajectory: string; -}; - -function isSaveTrajectoryResponse(obj: unknown): obj is SaveTrajectoryResponse { - if (!obj) return false; - if (typeof obj !== "object") return false; - if (!("memid" in obj) || typeof obj.memid !== "string") return false; - if (!("trajectory" in obj) || typeof obj.trajectory !== "string") { - return false; - } - return true; -} - export const knowledgeApi = createApi({ reducerPath: "knowledgeApi", baseQuery: fetchBaseQuery({ @@ -95,41 +80,6 @@ export const knowledgeApi = createApi({ }, }), endpoints: (builder) => ({ - createNewMemoryFromMessages: builder.mutation< - SaveTrajectoryResponse, - CompressTrajectoryPost - >({ - async queryFn(arg, api, extraOptions, baseQuery) { - const messagesForLsp = formatMessagesForLsp(arg.messages); - - const state = api.getState() as RootState; - const port = state.config.lspPort as unknown as number; - const url = `http://127.0.0.1:${port}${KNOWLEDGE_CREATE_URL}`; - const response = await baseQuery({ - ...extraOptions, - url, - method: "POST", - body: { project: arg.project, messages: messagesForLsp }, - }); - - if (response.error) { - return { error: response.error }; - } - - if (!isSaveTrajectoryResponse(response.data)) { - return { - error: { - status: "CUSTOM_ERROR", - error: `Invalid response from ${url}`, - data: response.data, - }, - }; - } - - return { data: response.data }; - }, - }), - compressMessages: builder.mutation< { goal: string; trajectory: string }, CompressTrajectoryPost diff --git a/refact-agent/gui/src/services/refact/trajectories.ts b/refact-agent/gui/src/services/refact/trajectories.ts new file mode 100644 index 000000000..6362a1da8 --- /dev/null +++ b/refact-agent/gui/src/services/refact/trajectories.ts @@ -0,0 +1,124 @@ +import { createApi, fetchBaseQuery } from "@reduxjs/toolkit/query/react"; +import { ChatThread } from "../../features/Chat/Thread/types"; +import { ChatMessages } from "./types"; + +export type TrajectoryMeta = { + id: string; + title: string; + created_at: string; + updated_at: string; + model: string; + mode: string; + message_count: number; +}; + +export type TrajectoryData = { + id: string; + title: string; + created_at: string; + updated_at: string; + model: string; + mode: string; + tool_use: string; + messages: ChatMessages; + boost_reasoning?: boolean; + context_tokens_cap?: number; + include_project_info?: boolean; + increase_max_tokens?: boolean; + automatic_patch?: boolean; + project_name?: string; + read?: boolean; + isTitleGenerated?: boolean; +}; + +export type TrajectoryEvent = { + type: "created" | "updated" | "deleted"; + id: string; + updated_at?: string; + title?: string; +}; + +export function chatThreadToTrajectoryData(thread: ChatThread, createdAt?: string): TrajectoryData { + const now = new Date().toISOString(); + return { + id: thread.id, + title: thread.title || "New Chat", + created_at: createdAt || now, + updated_at: now, + model: thread.model, + mode: thread.mode || "AGENT", + tool_use: thread.tool_use || "agent", + messages: thread.messages, + boost_reasoning: thread.boost_reasoning, + context_tokens_cap: thread.context_tokens_cap, + include_project_info: thread.include_project_info, + increase_max_tokens: thread.increase_max_tokens, + automatic_patch: thread.automatic_patch, + project_name: thread.project_name, + read: thread.read, + isTitleGenerated: thread.isTitleGenerated, + }; +} + +export function trajectoryDataToChatThread(data: TrajectoryData): ChatThread { + return { + id: data.id, + title: data.title, + model: data.model, + mode: data.mode as ChatThread["mode"], + tool_use: data.tool_use as ChatThread["tool_use"], + messages: data.messages, + boost_reasoning: data.boost_reasoning ?? false, + context_tokens_cap: data.context_tokens_cap, + include_project_info: data.include_project_info ?? true, + increase_max_tokens: data.increase_max_tokens ?? false, + automatic_patch: data.automatic_patch ?? false, + project_name: data.project_name, + read: data.read, + isTitleGenerated: data.isTitleGenerated, + createdAt: data.created_at, + last_user_message_id: "", + new_chat_suggested: { wasSuggested: false }, + }; +} + +export const trajectoriesApi = createApi({ + reducerPath: "trajectoriesApi", + baseQuery: fetchBaseQuery({ baseUrl: "/v1" }), + tagTypes: ["Trajectory"], + endpoints: (builder) => ({ + listTrajectories: builder.query({ + query: () => "/trajectories", + providesTags: ["Trajectory"], + }), + getTrajectory: builder.query({ + query: (id) => `/trajectories/${id}`, + providesTags: (_result, _error, id) => [{ type: "Trajectory", id }], + }), + saveTrajectory: builder.mutation({ + query: (data) => ({ + url: `/trajectories/${data.id}`, + method: "PUT", + body: data, + }), + invalidatesTags: (_result, _error, data) => [ + { type: "Trajectory", id: data.id }, + "Trajectory", + ], + }), + deleteTrajectory: builder.mutation({ + query: (id) => ({ + url: `/trajectories/${id}`, + method: "DELETE", + }), + invalidatesTags: ["Trajectory"], + }), + }), +}); + +export const { + useListTrajectoriesQuery, + useGetTrajectoryQuery, + useSaveTrajectoryMutation, + useDeleteTrajectoryMutation, +} = trajectoriesApi; diff --git a/refact-agent/gui/src/services/refact/types.ts b/refact-agent/gui/src/services/refact/types.ts index 1a837c77b..2c8183a96 100644 --- a/refact-agent/gui/src/services/refact/types.ts +++ b/refact-agent/gui/src/services/refact/types.ts @@ -1,6 +1,6 @@ import { LspChatMode } from "../../features/Chat"; import { Checkpoint } from "../../features/Checkpoints/types"; -import { GetChatTitleActionPayload, GetChatTitleResponse, Usage } from "./chat"; +import { Usage } from "./chat"; import { MCPArgs, MCPEnvs } from "./integrations"; export type ChatRole = @@ -473,36 +473,6 @@ export type UserMessageResponse = ChatUserMessageResponse & { role: "user"; }; -export function isChatGetTitleResponse( - json: unknown, -): json is GetChatTitleResponse { - if (!json || typeof json !== "object") return false; - - const requiredKeys = [ - "id", - "choices", - // "metering_balance", // not in BYOK - "model", - "object", - "system_fingerprint", - "usage", - "created", - "deterministic_messages", - ]; - - return requiredKeys.every((key) => key in json); -} - -export function isChatGetTitleActionPayload( - json: unknown, -): json is GetChatTitleActionPayload { - if (!json || typeof json !== "object") return false; - - const requiredKeys = ["title", "chatId"]; - - return requiredKeys.every((key) => key in json); -} - export function isUserResponse(json: unknown): json is UserMessageResponse { if (!isChatUserMessageResponse(json)) return false; return json.role === "user"; diff --git a/refact-agent/gui/src/utils/test-utils.tsx b/refact-agent/gui/src/utils/test-utils.tsx index ba3176a6a..2cdad71b0 100644 --- a/refact-agent/gui/src/utils/test-utils.tsx +++ b/refact-agent/gui/src/utils/test-utils.tsx @@ -8,6 +8,55 @@ import { Provider } from "react-redux"; import { AppStore, RootState, setUpStore } from "../app/store"; import { TourProvider } from "../features/Tour"; import { AbortControllerProvider } from "../contexts/AbortControllers"; +import { v4 as uuidv4 } from "uuid"; +import type { ChatThreadRuntime } from "../features/Chat/Thread/types"; + +// Helper to create a default thread runtime for tests +const createTestThreadRuntime = (): ChatThreadRuntime => { + return { + thread: { + id: uuidv4(), + messages: [], + title: "", + model: "", + last_user_message_id: "", + tool_use: "explore", + new_chat_suggested: { wasSuggested: false }, + boost_reasoning: false, + automatic_patch: false, + increase_max_tokens: false, + include_project_info: true, + context_tokens_cap: undefined, + }, + streaming: false, + waiting_for_response: false, + prevent_send: false, + error: null, + queued_messages: [], + send_immediately: false, + attached_images: [], + confirmation: { + pause: false, + pause_reasons: [], + status: { + wasInteracted: false, + confirmationStatus: true, + }, + }, + }; +}; + +// Helper to create default chat state with a thread +export const createDefaultChatState = () => { + const runtime = createTestThreadRuntime(); + return { + current_thread_id: runtime.thread.id, + open_thread_ids: [runtime.thread.id], + threads: { [runtime.thread.id]: runtime }, + system_prompt: {}, + tool_use: "explore" as const, + }; +}; // This type interface extends the default options for render from RTL, as well // as allows the user to specify other things such as initialState, store. @@ -28,6 +77,8 @@ const customRender = ( store = setUpStore({ // @ts-expect-error finished tour: { type: "finished", step: 0 }, + // Provide default chat state with a thread for tests + chat: createDefaultChatState(), ...preloadedState, }), ...renderOptions