diff --git a/.gitignore b/.gitignore index df1576d7..17f22854 100644 --- a/.gitignore +++ b/.gitignore @@ -232,3 +232,14 @@ data/ddinter_raw/ # Sphinx build /docs/build/ /docs/source/api/ +biomni_env/setup_path.sh +CLAUDE.md + +# Analysis data/outputs generated during test runs +geo_data/ +outputs/ +outputs_dn_meta/ +dn_analysis/ +dn_glom_meta_probelevel.tsv +test_dn_meta_analysis.py +workspace/ diff --git a/biomni/agent/a1.py b/biomni/agent/a1.py index 7a62e59f..df55ca90 100644 --- a/biomni/agent/a1.py +++ b/biomni/agent/a1.py @@ -8,11 +8,11 @@ from typing import Any, Literal, TypedDict import pandas as pd -from dotenv import load_dotenv from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage from langchain_core.prompts import ChatPromptTemplate from langgraph.checkpoint.memory import MemorySaver from langgraph.graph import END, START, StateGraph +from pydantic import BaseModel, Field from biomni.config import default_config from biomni.know_how import KnowHowLoader @@ -43,9 +43,7 @@ textify_api_dict, ) -if os.path.exists(".env"): - load_dotenv(".env", override=False) - print("Loaded environment variables from .env") +# .env is loaded early in biomni.config to ensure default_config sees env vars class AgentState(TypedDict): @@ -53,6 +51,28 @@ class AgentState(TypedDict): next_step: str | None +class AgentResponse(BaseModel): + """Structured response schema for OpenAI models. + + Used with structured outputs (constrained decoding) to guarantee + format compliance instead of relying on XML tag parsing. + """ + + reasoning: str = Field( + description="Your step-by-step reasoning about the current state, what observations mean, and what to do next." + ) + action: Literal["execute", "solution"] = Field( + description="'execute' to run code (Python/R/Bash), 'solution' to provide the final answer." + ) + content: str = Field( + description=( + "If action is 'execute': the code to run. Python by default. " + "Prefix with #!R for R code, #!BASH for Bash scripts. " + "If action is 'solution': the final answer text." + ), + ) + + class A1: def __init__( self, @@ -202,6 +222,17 @@ def __init__( api_key=api_key, config=default_config, ) + + # If the constructor overrides the LLM to a different provider, + # update default_config.llm_lite to match so that all lite-model + # consumers (database.py, genomics.py, utils.py) use the same provider. + from biomni.config import BiomniConfig + + inferred_lite = BiomniConfig._infer_lite_model(llm) + if inferred_lite != default_config.llm_lite: + default_config.llm_lite = inferred_lite + print(f" Lite Model: {default_config.llm_lite} (updated to match agent provider)") + self.module2api = module2api self.use_tool_retriever = use_tool_retriever @@ -222,6 +253,12 @@ def __init__( self.timeout_seconds = timeout_seconds # 10 minutes default timeout self.configure() + @property + def _is_openai_model(self) -> bool: + """Check if the current LLM is an OpenAI model (supports structured outputs).""" + model_name = getattr(self.llm, "model_name", "") or getattr(self.llm, "model", "") + return str(model_name).lower().startswith("gpt-") or "openai" in str(type(self.llm)).lower() + def add_tool(self, api): """Add a new tool to the agent's tool registry and make it available for retrieval. @@ -352,8 +389,9 @@ def add_mcp(self, config_path: str | Path = "./tutorials/examples/mcp_config.yam Add MCP (Model Context Protocol) tools from configuration file. This method dynamically registers MCP server tools as callable functions within - the biomni agent system. Each MCP server is loaded as an independent module - with its tools exposed as synchronous wrapper functions. + the biomni agent system. Each MCP server is started as a persistent process with + a long-lived session so that stateful tools (background jobs, caching) work correctly + across multiple calls. Supports both manual tool definitions and automatic tool discovery from MCP servers. @@ -367,8 +405,10 @@ def add_mcp(self, config_path: str | Path = "./tutorials/examples/mcp_config.yam RuntimeError: If MCP server initialization fails """ import asyncio + import atexit import os import sys + import threading import types from pathlib import Path @@ -379,62 +419,120 @@ def add_mcp(self, config_path: str | Path = "./tutorials/examples/mcp_config.yam nest_asyncio.apply() - def discover_mcp_tools_sync(server_params: StdioServerParameters) -> list[dict]: - """Discover available tools from MCP server synchronously.""" - try: + # ---- Persistent MCP session management ---- + # Each MCP server gets a single long-lived process and session so that + # stateful patterns (background jobs, caching) work across tool calls. + if not hasattr(self, "_mcp_sessions"): + self._mcp_sessions: dict[str, dict] = {} # server_name -> {session, cleanup, loop, thread} - async def _discover_async(): + def _start_persistent_session(server_name: str, server_params: StdioServerParameters) -> ClientSession: + """Start an MCP server process and keep the session alive in a background thread.""" + + ready_event = threading.Event() + session_holder: dict = {} + + def _run_event_loop(): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + session_holder["loop"] = loop + + async def _keep_alive(): async with stdio_client(server_params) as (reader, writer): async with ClientSession(reader, writer) as session: await session.initialize() + session_holder["session"] = session + ready_event.set() + # Block until we're told to shut down + session_holder["shutdown_event"] = asyncio.Event() + await session_holder["shutdown_event"].wait() - # Get available tools - tools_result = await session.list_tools() - tools = tools_result.tools if hasattr(tools_result, "tools") else tools_result - - discovered_tools = [] - for tool in tools: - if hasattr(tool, "name"): - discovered_tools.append( - { - "name": tool.name, - "description": tool.description, - "inputSchema": tool.inputSchema, - } - ) - else: - print(f"Warning: Skipping tool with no name attribute: {tool}") + try: + loop.run_until_complete(_keep_alive()) + except Exception: + ready_event.set() # unblock caller even on failure + finally: + loop.close() - return discovered_tools + thread = threading.Thread(target=_run_event_loop, daemon=True, name=f"mcp-{server_name}") + thread.start() - return asyncio.run(_discover_async()) - except Exception as e: - print(f"Failed to discover tools: {e}") - return [] + # Wait for the session to be ready (up to 30 seconds) + if not ready_event.wait(timeout=30): + raise RuntimeError(f"MCP server '{server_name}' failed to start within 30 seconds") + + if "session" not in session_holder: + raise RuntimeError(f"MCP server '{server_name}' session failed to initialize") - def make_mcp_wrapper(cmd: str, args: list[str], tool_name: str, doc: str, env_vars: dict = None): - """Create a synchronous wrapper for an async MCP tool call.""" + self._mcp_sessions[server_name] = session_holder + return session_holder["session"] - def sync_tool_wrapper(**kwargs): - """Synchronous wrapper for MCP tool execution.""" + def _shutdown_all_sessions(): + """Clean up all persistent MCP sessions on exit.""" + for _name, holder in getattr(self, "_mcp_sessions", {}).items(): try: - server_params = StdioServerParameters(command=cmd, args=args, env=env_vars) - - async def async_tool_call(): - async with stdio_client(server_params) as (reader, writer): - async with ClientSession(reader, writer) as session: - await session.initialize() - result = await session.call_tool(tool_name, kwargs) - content = result.content[0] - if hasattr(content, "json"): - return content.json() - return content.text + shutdown_event = holder.get("shutdown_event") + loop = holder.get("loop") + if shutdown_event and loop and loop.is_running(): + loop.call_soon_threadsafe(shutdown_event.set) + except Exception: + pass + + atexit.register(_shutdown_all_sessions) + + def discover_mcp_tools_via_session(session: ClientSession, loop: asyncio.AbstractEventLoop) -> list[dict]: + """Discover tools using an already-running persistent session.""" + future = asyncio.run_coroutine_threadsafe(session.list_tools(), loop) + tools_result = future.result(timeout=30) + tools = tools_result.tools if hasattr(tools_result, "tools") else tools_result + + discovered_tools = [] + for tool in tools: + if hasattr(tool, "name"): + discovered_tools.append( + { + "name": tool.name, + "description": tool.description, + "inputSchema": tool.inputSchema, + } + ) + else: + print(f"Warning: Skipping tool with no name attribute: {tool}") + return discovered_tools + def make_mcp_wrapper(server_name: str, tool_name: str, doc: str, param_names: list = None): + """Create a synchronous wrapper that calls a tool on the persistent session.""" + param_names = param_names or [] + + def sync_tool_wrapper(*args_positional, **kwargs): + """Synchronous wrapper for MCP tool execution via persistent session.""" + try: + call_kwargs = dict(kwargs) + for i, arg in enumerate(args_positional): + if i < len(param_names): + call_kwargs[param_names[i]] = arg + + holder = self._mcp_sessions.get(server_name) + if not holder or "session" not in holder: + raise RuntimeError(f"MCP session for '{server_name}' is not available") + + session = holder["session"] + loop = holder["loop"] + + future = asyncio.run_coroutine_threadsafe(session.call_tool(tool_name, call_kwargs), loop) + result = future.result(timeout=300) # 5 min timeout per tool call + content = result.content[0] + # Return the text payload, not the MCP envelope. + # content is a Pydantic TextContent model — .json() would serialize + # the entire envelope (type, text, annotations, meta), but tools + # expect just the inner text (e.g. a JSON string with job_id). + # Auto-parse JSON so the agent gets dicts directly. + text = content.text try: - loop = asyncio.get_running_loop() - return loop.create_task(async_tool_call()) - except RuntimeError: - return asyncio.run(async_tool_call()) + import json as _json + + return _json.loads(text) + except (ValueError, TypeError): + return text except Exception as e: raise RuntimeError(f"MCP tool execution failed for '{tool_name}': {e}") from e @@ -487,17 +585,33 @@ async def async_tool_call(): env_vars = processed_env # Create module namespace for this MCP server - mcp_module_name = f"mcp_servers.{server_name}" + # First ensure the parent 'mcp_servers' module exists + if "mcp_servers" not in sys.modules: + sys.modules["mcp_servers"] = types.ModuleType("mcp_servers") + # Sanitize server name to be a valid Python identifier + # (e.g., "okn-wobd" -> "okn_wobd") so the agent can import it + safe_server_name = server_name.replace("-", "_").replace(".", "_").replace(" ", "_") + mcp_module_name = f"mcp_servers.{safe_server_name}" if mcp_module_name not in sys.modules: sys.modules[mcp_module_name] = types.ModuleType(mcp_module_name) server_module = sys.modules[mcp_module_name] + # Start a persistent session for this MCP server + server_params = StdioServerParameters(command=cmd, args=args, env=env_vars) + try: + session = _start_persistent_session(server_name, server_params) + except Exception as e: + print(f"Failed to start persistent session for {server_name}: {e}") + continue + + holder = self._mcp_sessions[server_name] + loop = holder["loop"] + tools_config = server_meta.get("tools", []) if not tools_config: try: - server_params = StdioServerParameters(command=cmd, args=args, env=env_vars) - tools_config = discover_mcp_tools_sync(server_params) + tools_config = discover_mcp_tools_via_session(session, loop) if tools_config: print(f"Discovered {len(tools_config)} tools from {server_name} MCP server") @@ -534,8 +648,11 @@ async def async_tool_call(): print(f"Warning: Skipping tool with no name in {server_name}") continue - # Create wrapper function - wrapper_function = make_mcp_wrapper(cmd, args, tool_name, description, env_vars) + # Get ordered list of parameter names for positional arg support + param_names_ordered = list(parameters.keys()) + + # Create wrapper function using the persistent session + wrapper_function = make_mcp_wrapper(server_name, tool_name, description, param_names_ordered) # Add to module namespace setattr(server_module, tool_name, wrapper_function) @@ -1094,8 +1211,8 @@ def format_item_with_description(name, description): # Include full content in system prompt (metadata already removed) know_how_formatted.append(f"📚 {name}:\n{content}") - # Base prompt - prompt_modifier = """ + # Base prompt — shared preamble for all providers + prompt_preamble = """ You are a helpful biomedical assistant assigned with the task of problem-solving. To achieve this, you will be using an interactive coding environment equipped with a variety of tool functions, data, and softwares to assist you throughout the process. @@ -1117,7 +1234,41 @@ def format_item_with_description(name, description): 4. [ ] Third step Always show the updated plan after each step so the user can track progress. +""" + + # Format instructions differ by provider + if self._is_openai_model: + # Structured-output mode: the response schema enforces the format, + # so the prompt only needs to explain the semantics. + prompt_format = """ +At each turn, your response will be structured with three fields: "reasoning", "action", and "content". + +- "reasoning": Provide your step-by-step thinking. Briefly note what you know so far and what you plan to do next. +- "action": Choose exactly one of: + "execute" — to run NEW code and observe its result. + "solution" — to provide your final answer. Choose this as soon as you have the information needed to answer. +- "content": The code to execute, or your final answer text. + For Python code (default): just write the code. + For R code: prefix with #!R on the first line. + For Bash scripts and commands: prefix with #!BASH on the first line. + +IMPORTANT: Once code has executed successfully and you have the result, move directly to action "solution". Do NOT re-execute the same code. Each "execute" action should run NEW code that makes progress toward the goal. + +You have many chances to interact with the environment. Decompose your work into multiple small steps. +Don't overcomplicate the code. Keep it simple and easy to understand. +When calling the existing python functions in the function dictionary, YOU MUST SAVE THE OUTPUT and PRINT OUT the result. +For example, result = understand_scRNA(XXX) print(result) +Otherwise the system will not be able to know what has been done. +TASK COMPLETION REQUIREMENTS: +Before choosing action "solution", verify that you have addressed the requirements in the original task: +- If the task asks for "at least N" items, ensure you have found N or more +- If your initial search doesn't meet requirements, try alternative search terms, broaden criteria, or search additional databases +- Do NOT choose "solution" until all stated requirements are met, OR you have exhausted reasonable alternatives and explicitly explain why the requirements cannot be met +""" + else: + # XML-tag mode for Claude and other providers + prompt_format = """ At each turn, you should first provide your thinking and reasoning given the conversation history. After that, you have two options: @@ -1142,6 +1293,8 @@ def format_item_with_description(name, description): In each response, you must include EITHER or tag. Not both at the same time. Do not respond with messages without any tags. No empty messages. """ + prompt_modifier = prompt_preamble + prompt_format + # Add self-critic instructions if needed if self_critic: prompt_modifier += """ @@ -1152,6 +1305,9 @@ def format_item_with_description(name, description): prompt_modifier += """ PROTOCOL GENERATION: If the user requests an experimental protocol, use search_protocols(), advanced_web_search_claude(), list_local_protocols(), and read_local_protocol() to generate an accurate protocol. Include details such as reagents (with catalog numbers if available), equipment specifications, replicate requirements, error handling, and troubleshooting - but ONLY include information found in these resources. Do not make up specifications, catalog numbers, or equipment details. Prioritize accuracy over completeness. + +DATABASE QUERIES: +IMPORTANT: For querying biological databases (NCBI, GEO, PubMed, ClinVar, dbSNP, UniProt, etc.), ALWAYS use the provided Python tool functions like query_geo(), query_pubmed(), query_clinvar(), query_dbsnp(), etc. Do NOT attempt to use command-line tools like esearch, efetch, or other NCBI E-utilities commands via Bash - these are not installed. The Python functions use REST APIs and will work reliably. """ # Add custom resources section first (highlighted) @@ -1379,39 +1535,92 @@ def configure(self, self_critic=False, test_time_scale_round=0): # Define the nodes def generate(state: AgentState) -> AgentState: - # Add OpenAI-specific formatting reminders if using OpenAI models - system_prompt = self.system_prompt - if hasattr(self.llm, "model_name") and ( - "gpt" in str(self.llm.model_name).lower() or "openai" in str(type(self.llm)).lower() - ): - system_prompt += "\n\nIMPORTANT FOR GPT MODELS: You MUST use XML tags or in EVERY response. Do not use markdown code blocks (```) - use tags instead." + messages = [SystemMessage(content=self.system_prompt)] + state["messages"] - messages = [SystemMessage(content=system_prompt)] + state["messages"] - response = self.llm.invoke(messages) + # ----- OpenAI structured-output path ----- + if self._is_openai_model: + try: + structured_llm = self.llm.with_structured_output(AgentResponse) + response: AgentResponse = structured_llm.invoke(messages) + + # Strip markdown code fences from content if present + # (LLMs often wrap code in ```python ... ``` even in structured output) + content = response.content + if response.action == "execute": + content = re.sub(r"^```(?:python|bash|r|shell|sh)?\s*\n?", "", content.strip()) + content = re.sub(r"\n?```\s*$", "", content) + + # Reconstruct XML-tagged message so execute() and logging work unchanged + tag = response.action # "execute" or "solution" + msg = f"{response.reasoning}\n<{tag}>{content}" + state["messages"].append(AIMessage(content=msg.strip())) + + if response.action == "solution": + # Guard: reject premature solutions when no code has been executed. + # Check if any exists in prior messages (observations + # are only added after code execution). + has_executed = any( + "" in (m.content if isinstance(m.content, str) else "") + for m in state["messages"] + ) + if not has_executed: + nudge = ( + "\nYou chose 'solution' but have not executed any code yet. " + "You MUST run code using the available tools before providing a final answer. " + "Choose action 'execute' and write code to accomplish the task." + ) + state["messages"].append(AIMessage(content=nudge.strip())) + state["next_step"] = "generate" + else: + state["next_step"] = "end" + else: + state["next_step"] = "execute" + return state - # Normalize Responses API content blocks (list of dicts) into a plain string - content = response.content - if isinstance(content, list): - # Concatenate textual parts; ignore tool_use or other non-text blocks - text_parts: list[str] = [] - for block in content: - try: - if isinstance(block, dict): - btype = block.get("type") - if btype in ("text", "output_text", "redacted_text"): - part = block.get("text") or block.get("content") or "" - if isinstance(part, str): - text_parts.append(part) - except Exception: - # Be conservative; skip malformed blocks - continue - msg = "".join(text_parts) + except Exception as e: + # If structured output fails, fall through to XML parsing + print(f"⚠️ Structured output failed ({e}), falling back to XML parsing...") + response = self.llm.invoke(messages) + content = response.content + if isinstance(content, list): + text_parts: list[str] = [] + for block in content: + try: + if isinstance(block, dict): + btype = block.get("type") + if btype in ("text", "output_text", "redacted_text"): + part = block.get("text") or block.get("content") or "" + if isinstance(part, str): + text_parts.append(part) + except Exception: + continue + msg = "".join(text_parts) + else: + msg = str(content) + + # ----- XML parsing path (Claude and fallback) ----- else: - # Fallback to string conversion for legacy content - msg = str(content) + response = self.llm.invoke(messages) + + # Normalize Responses API content blocks (list of dicts) into a plain string + content = response.content + if isinstance(content, list): + text_parts: list[str] = [] + for block in content: + try: + if isinstance(block, dict): + btype = block.get("type") + if btype in ("text", "output_text", "redacted_text"): + part = block.get("text") or block.get("content") or "" + if isinstance(part, str): + text_parts.append(part) + except Exception: + continue + msg = "".join(text_parts) + else: + msg = str(content) - # Enhanced parsing for better OpenAI compatibility - # Check for incomplete tags and fix them + # XML tag parsing (used by Claude directly, and as fallback for OpenAI) if "" in msg and "" not in msg: msg += "" if "" in msg and "" not in msg: @@ -1419,20 +1628,10 @@ def generate(state: AgentState) -> AgentState: if "" in msg and "" not in msg: msg += "" - # More flexible pattern matching for different LLM styles think_match = re.search(r"(.*?)", msg, re.DOTALL | re.IGNORECASE) execute_match = re.search(r"(.*?)", msg, re.DOTALL | re.IGNORECASE) answer_match = re.search(r"(.*?)", msg, re.DOTALL | re.IGNORECASE) - # Alternative patterns for OpenAI models that might use different formatting - if not execute_match: - # Try to find code blocks that might be intended as execute blocks - code_block_match = re.search(r"```(?:python|bash|r)?\s*(.*?)```", msg, re.DOTALL) - if code_block_match and not answer_match: - # If we found a code block and no solution, treat it as execute - execute_match = code_block_match - - # Add the message to the state before checking for errors state["messages"].append(AIMessage(content=msg.strip())) if answer_match: @@ -1442,32 +1641,43 @@ def generate(state: AgentState) -> AgentState: elif think_match: state["next_step"] = "generate" else: - print("parsing error...") - - error_count = sum( - 1 for m in state["messages"] if isinstance(m, AIMessage) and "There are no tags" in m.content + # Check if any code has been executed yet (observations exist) + has_executed = any( + "" in (m.content if isinstance(m.content, str) else "") for m in state["messages"] ) - if error_count >= 2: - # If we've already tried to correct the model twice, just end the conversation - print("Detected repeated parsing errors, ending conversation") + if not has_executed: + # No code executed yet and no tags — treat as a direct + # conversational response (e.g. user said "hello"). + # Wrap in so downstream pipeline handles it. + state["messages"][-1] = AIMessage(content=f"{msg.strip()}") state["next_step"] = "end" - # Add a final message explaining the termination - state["messages"].append( - AIMessage( - content="Execution terminated due to repeated parsing errors. Please check your input and try again." - ) - ) else: - # Try to correct it - state["messages"].append( - HumanMessage( - content="Each response must include thinking process followed by either or tag. But there are no tags in the current response. Please follow the instruction, fix and regenerate the response again." - ) + print("parsing error...") + + error_count = sum( + 1 for m in state["messages"] if isinstance(m, AIMessage) and "There are no tags" in m.content ) - state["next_step"] = "generate" + + if error_count >= 2: + print("Detected repeated parsing errors, ending conversation") + state["next_step"] = "end" + state["messages"].append( + AIMessage( + content="Execution terminated due to repeated parsing errors. Please check your input and try again." + ) + ) + else: + state["messages"].append( + HumanMessage( + content="Each response must include thinking process followed by either or tag. But there are no tags in the current response. Please follow the instruction, fix and regenerate the response again." + ) + ) + state["next_step"] = "generate" return state + _previous_code_blocks: list[str] = [] + def execute(state: AgentState) -> AgentState: last_message = state["messages"][-1].content # Only add the closing tag if it's not already there @@ -1478,6 +1688,25 @@ def execute(state: AgentState) -> AgentState: if execute_match: code = execute_match.group(1) + # Detect duplicate code execution — if the model re-submits the same code + # it already ran, nudge it to move to the solution instead of re-running. + code_normalized = code.strip() + if code_normalized in _previous_code_blocks: + observation = ( + "\nYou already executed this exact code. " + "The result was shown above. Do not re-run the same code. " + "If you have the answer, choose action 'solution' now." + ) + state["messages"].append(HumanMessage(content=observation.strip())) + return state + _previous_code_blocks.append(code_normalized) + + # Strip markdown code fences that LLMs (especially OpenAI) often wrap code in, + # even when using structured output. This is a safety net — generate() should + # already strip fences, but this catches fallback/edge cases. + code = re.sub(r"^```(?:python|bash|r|shell|sh)?\s*\n?", "", code.strip()) + code = re.sub(r"\n?```\s*$", "", code) + # Set timeout duration (10 minutes = 600 seconds) timeout = self.timeout_seconds @@ -1548,7 +1777,7 @@ def execute(state: AgentState) -> AgentState: self._execution_results.append(execution_entry) observation = f"\n{result}" - state["messages"].append(AIMessage(content=observation.strip())) + state["messages"].append(HumanMessage(content=observation.strip())) return state diff --git a/biomni/agent/env_collection.py b/biomni/agent/env_collection.py index 5e9a8ab4..81114116 100644 --- a/biomni/agent/env_collection.py +++ b/biomni/agent/env_collection.py @@ -194,15 +194,20 @@ def _consolidate_tasks(self, chunk_results: list[str]) -> dict[str, list[dict[st prompt = self.consolidation_prompt.format(task_lists=all_tasks) response = self.llm.invoke(prompt) + # Normalize content (handles string, list of blocks, etc.) + from biomni.utils import normalize_llm_content + + response_text = normalize_llm_content(response.content) + # Extract the JSON from the response try: # Try to parse the entire response as JSON - result = json.loads(response.content) + result = json.loads(response_text) except json.JSONDecodeError: # If that fails, try to extract JSON from the text try: # Look for JSON-like content between triple backticks - json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", response.content) + json_match = re.search(r"```(?:json)?\s*([\s\S]*?)\s*```", response_text) if json_match: result = json.loads(json_match.group(1)) else: @@ -211,7 +216,7 @@ def _consolidate_tasks(self, chunk_results: list[str]) -> dict[str, list[dict[st "tasks": [ { "error": "Could not parse JSON", - "raw_response": response.content, + "raw_response": response_text, } ], "databases": [], @@ -222,7 +227,7 @@ def _consolidate_tasks(self, chunk_results: list[str]) -> dict[str, list[dict[st "tasks": [ { "error": "Could not parse JSON", - "raw_response": response.content, + "raw_response": response_text, } ], "databases": [], diff --git a/biomni/config.py b/biomni/config.py index 874a1adf..a0fa5839 100644 --- a/biomni/config.py +++ b/biomni/config.py @@ -7,6 +7,16 @@ import os from dataclasses import dataclass +from pathlib import Path + +from dotenv import load_dotenv + +# Load .env before anything reads os.getenv so that default_config (created at +# module level below) picks up the values. override=False keeps real +# environment variables authoritative. +_env_path = Path(__file__).resolve().parents[1] / ".env" +if _env_path.is_file(): + load_dotenv(_env_path, override=False) @dataclass @@ -29,11 +39,13 @@ class BiomniConfig: """ # Data and execution settings - path: str = "./data" + path: str = "./data" # Data lake path (input data) + workspace: str = "./workspace" # Output directory for generated files timeout_seconds: int = 600 # LLM settings (API keys still from environment) - llm: str = "claude-sonnet-4-5" + llm: str = "claude-sonnet-4-5" # Primary model for agent reasoning + llm_lite: str | None = None # Lightweight model for simple tasks; auto-inferred from llm provider if not set temperature: float = 0.7 # Tool settings @@ -58,10 +70,14 @@ def __post_init__(self): # Support both old and new names for backwards compatibility if os.getenv("BIOMNI_PATH") or os.getenv("BIOMNI_DATA_PATH"): self.path = os.getenv("BIOMNI_PATH") or os.getenv("BIOMNI_DATA_PATH") + if os.getenv("BIOMNI_WORKSPACE"): + self.workspace = os.getenv("BIOMNI_WORKSPACE") if os.getenv("BIOMNI_TIMEOUT_SECONDS"): self.timeout_seconds = int(os.getenv("BIOMNI_TIMEOUT_SECONDS")) if os.getenv("BIOMNI_LLM") or os.getenv("BIOMNI_LLM_MODEL"): self.llm = os.getenv("BIOMNI_LLM") or os.getenv("BIOMNI_LLM_MODEL") + if os.getenv("BIOMNI_LLM_LITE"): + self.llm_lite = os.getenv("BIOMNI_LLM_LITE") if os.getenv("BIOMNI_USE_TOOL_RETRIEVER"): self.use_tool_retriever = os.getenv("BIOMNI_USE_TOOL_RETRIEVER").lower() == "true" if os.getenv("BIOMNI_COMMERCIAL_MODE"): @@ -80,12 +96,34 @@ def __post_init__(self): if env_token: self.protocols_io_access_token = env_token + # Auto-infer llm_lite from the main llm's provider if not explicitly set + if self.llm_lite is None: + self.llm_lite = self._infer_lite_model(self.llm) + + @staticmethod + def _infer_lite_model(model: str) -> str: + """Return a lightweight model that matches the provider of the given model.""" + _LITE_MODELS = { + "claude-": "claude-haiku-4-5", + "gpt-": "gpt-5-mini", + "gemini-": "gemini-1.5-flash", + "groq": "llama-3.1-8b-instant", + } + model_lower = model.lower() + for prefix, lite in _LITE_MODELS.items(): + if model_lower.startswith(prefix) or prefix in model_lower: + return lite + # Fallback: same provider detection as llm.py + return "claude-haiku-4-5" + def to_dict(self) -> dict: """Convert config to dictionary for easy access.""" return { "path": self.path, + "workspace": self.workspace, "timeout_seconds": self.timeout_seconds, "llm": self.llm, + "llm_lite": self.llm_lite, "temperature": self.temperature, "use_tool_retriever": self.use_tool_retriever, "commercial_mode": self.commercial_mode, diff --git a/biomni/env_desc.py b/biomni/env_desc.py index ec834d4c..16259782 100644 --- a/biomni/env_desc.py +++ b/biomni/env_desc.py @@ -93,6 +93,7 @@ "biotite": "[Python Package] A comprehensive library for computational molecular biology, providing tools for sequence analysis, structure analysis, and more.", "lazyslide": "[Python Package] A Python framework that brings interoperable, reproducible whole slide image analysis, enabling seamless histopathology workflows from preprocessing to deep learning.", # Genomics & Variant Analysis (Python) + "GEOparse": "[Python Package] A library for downloading and parsing data from NCBI GEO (Gene Expression Omnibus). Use GEOparse.get_GEO(geo='GSE12345', destdir='./output') to download series or sample data.", "gget": "[Python Package] A toolkit for accessing genomic databases and retrieving sequences, annotations, and other genomic data.", "lifelines": "[Python Package] A complete survival analysis library for fitting models, plotting, and statistical tests.", # "scvi-tools": "[Python Package] A package for probabilistic modeling of single-cell omics data, including deep generative models.", diff --git a/biomni/env_desc_cm.py b/biomni/env_desc_cm.py index 177920da..b0b29aab 100644 --- a/biomni/env_desc_cm.py +++ b/biomni/env_desc_cm.py @@ -93,6 +93,7 @@ "biotite": "[Python Package] A comprehensive library for computational molecular biology, providing tools for sequence analysis, structure analysis, and more.", "lazyslide": "[Python Package] A Python framework that brings interoperable, reproducible whole slide image analysis, enabling seamless histopathology workflows from preprocessing to deep learning.", # Genomics & Variant Analysis (Python) + "GEOparse": "[Python Package] A library for downloading and parsing data from NCBI GEO (Gene Expression Omnibus). Use GEOparse.get_GEO(geo='GSE12345', destdir='./output') to download series or sample data.", "gget": "[Python Package] A toolkit for accessing genomic databases and retrieving sequences, annotations, and other genomic data.", "lifelines": "[Python Package] A complete survival analysis library for fitting models, plotting, and statistical tests.", # "scvi-tools": "[Python Package] A package for probabilistic modeling of single-cell omics data, including deep generative models.", diff --git a/biomni/llm.py b/biomni/llm.py index 1ca4e6a4..6c8c008b 100644 --- a/biomni/llm.py +++ b/biomni/llm.py @@ -9,6 +9,74 @@ SourceType = Literal["OpenAI", "AzureOpenAI", "Anthropic", "Ollama", "Gemini", "Bedrock", "Groq", "Custom"] ALLOWED_SOURCES: set[str] = set(SourceType.__args__) +# Default models for each provider (used when no model is specified) +DEFAULT_MODELS = { + "Anthropic": "claude-sonnet-4-5", + "OpenAI": "gpt-4o", + "Gemini": "gemini-1.5-pro", + "Groq": "llama-3.1-70b-versatile", + "Bedrock": "anthropic.claude-3-5-sonnet-20241022-v2:0", +} + +# Lightweight models for simple tasks (parsing, classification, etc.) +DEFAULT_MODELS_LITE = { + "Anthropic": "claude-haiku-4-5", + "OpenAI": "gpt-5-mini", + "Gemini": "gemini-1.5-flash", + "Groq": "llama-3.1-8b-instant", + "Bedrock": "anthropic.claude-3-5-haiku-20241022-v1:0", +} + + +def _get_default_model() -> str: + """Select the default model based on available API keys. + + Checks for API keys in order of preference and returns the appropriate + default model for the first available provider. + + Returns: + str: The default model name for the available provider. + """ + # Check providers in order of preference + if os.getenv("ANTHROPIC_API_KEY"): + return DEFAULT_MODELS["Anthropic"] + if os.getenv("OPENAI_API_KEY"): + return DEFAULT_MODELS["OpenAI"] + if os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY"): + return DEFAULT_MODELS["Gemini"] + if os.getenv("GROQ_API_KEY"): + return DEFAULT_MODELS["Groq"] + if os.getenv("AWS_ACCESS_KEY_ID") or os.getenv("AWS_PROFILE"): + return DEFAULT_MODELS["Bedrock"] + + # Fallback to Anthropic model (will fail if no key, but provides clear error) + return DEFAULT_MODELS["Anthropic"] + + +def _get_default_model_lite() -> str: + """Select the default lightweight model based on available API keys. + + Similar to _get_default_model() but returns cheaper/faster models + suitable for simple tasks like parsing, classification, etc. + + Returns: + str: The default lite model name for the available provider. + """ + # Check providers in order of preference + if os.getenv("ANTHROPIC_API_KEY"): + return DEFAULT_MODELS_LITE["Anthropic"] + if os.getenv("OPENAI_API_KEY"): + return DEFAULT_MODELS_LITE["OpenAI"] + if os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY"): + return DEFAULT_MODELS_LITE["Gemini"] + if os.getenv("GROQ_API_KEY"): + return DEFAULT_MODELS_LITE["Groq"] + if os.getenv("AWS_ACCESS_KEY_ID") or os.getenv("AWS_PROFILE"): + return DEFAULT_MODELS_LITE["Bedrock"] + + # Fallback to Anthropic model (will fail if no key, but provides clear error) + return DEFAULT_MODELS_LITE["Anthropic"] + def get_llm( model: str | None = None, @@ -35,7 +103,7 @@ def get_llm( # Use config values for any unspecified parameters if config is not None: if model is None: - model = config.llm_model + model = config.llm # Use config's LLM setting if temperature is None: temperature = config.temperature if source is None: @@ -45,9 +113,9 @@ def get_llm( if api_key is None: api_key = config.api_key or "EMPTY" - # Use defaults if still not specified + # Use defaults if still not specified - select based on available API keys if model is None: - model = "claude-3-5-sonnet-20241022" + model = _get_default_model() if temperature is None: temperature = 0.7 if api_key is None: @@ -131,12 +199,14 @@ def _get_request_payload(self, input_, *, stop=None, **kwargs): # type: ignore[ stop_sequences=stop_sequences, use_responses_api=True, output_version="v0", + request_timeout=300, # 5 minute timeout for API requests ) else: return ChatOpenAI( model=model, temperature=temperature, stop_sequences=stop_sequences, + request_timeout=300, # 5 minute timeout for API requests ) elif source == "AzureOpenAI": @@ -186,6 +256,7 @@ def _get_request_payload(self, input_, *, stop=None, **kwargs): # type: ignore[ temperature=temperature, max_tokens=8192, stop_sequences=stop_sequences, + timeout=300, # 5 minute timeout for API requests ) elif source == "Gemini": diff --git a/biomni/model/retriever.py b/biomni/model/retriever.py index 1d5fbc0b..e3383b7e 100644 --- a/biomni/model/retriever.py +++ b/biomni/model/retriever.py @@ -2,7 +2,6 @@ import re from langchain_core.messages import HumanMessage -from langchain_openai import ChatOpenAI class ToolRetriever: @@ -88,15 +87,21 @@ def prompt_based_retrieval(self, query: str, resources: dict, llm=None) -> dict: prompt = "\n".join(prompt_sections) + response_format - # Use the provided LLM or create a new one + # Use the provided LLM or create one from config if llm is None: - llm = ChatOpenAI(model="gpt-4o") + from biomni.config import BiomniConfig + from biomni.llm import get_llm + + llm = get_llm(config=BiomniConfig()) # Invoke the LLM if hasattr(llm, "invoke"): # For LangChain-style LLMs response = llm.invoke([HumanMessage(content=prompt)]) - response_content = response.content + # Normalize content (handles string, list of blocks, etc.) + from biomni.utils import normalize_llm_content + + response_content = normalize_llm_content(response.content) else: # For other LLM interfaces response_content = str(llm(prompt)) diff --git a/biomni/tool/database.py b/biomni/tool/database.py index fa55b1c6..f4f68a99 100644 --- a/biomni/tool/database.py +++ b/biomni/tool/database.py @@ -2,17 +2,32 @@ import os import pickle import time -from typing import Any +from typing import Any, Literal import requests from Bio.Blast import NCBIWWW, NCBIXML from Bio.Seq import Seq from langchain_core.messages import HumanMessage, SystemMessage +from pydantic import BaseModel, Field from biomni.llm import get_llm from biomni.utils import parse_hpo_obo +# Pydantic models for structured LLM output +class GEOQueryResponse(BaseModel): + """Structured response for GEO database queries.""" + + search_term: str = Field(description="The GEO search query string") + database: Literal["gds", "geoprofiles"] = Field(default="gds", description="GEO database to search") + + +class DBSNPQueryResponse(BaseModel): + """Structured response for dbSNP database queries.""" + + search_term: str = Field(description="The dbSNP search query string") + + # Function to map HPO terms to names def get_hpo_names(hpo_terms: list[str], data_lake_path: str) -> list[str]: """Retrieve the names of given HPO terms. @@ -33,7 +48,7 @@ def get_hpo_names(hpo_terms: list[str], data_lake_path: str) -> list[str]: return hpo_names -def _query_llm_for_api(prompt, schema, system_template): +def _query_llm_for_api(prompt, schema, system_template, response_model=None): """Helper function to query LLMs for generating API calls based on natural language prompts. Supports multiple model providers including Claude, Gemini, GPT, and others via the unified get_llm interface. @@ -43,6 +58,8 @@ def _query_llm_for_api(prompt, schema, system_template): prompt (str): Natural language query to process schema (dict): API schema to include in the system prompt system_template (str): Template string for the system prompt (should have {schema} placeholder) + response_model (BaseModel, optional): Pydantic model for structured output. If provided, uses + LangChain's with_structured_output() for reliable JSON responses. Returns ------- @@ -50,13 +67,17 @@ def _query_llm_for_api(prompt, schema, system_template): """ # Use global config for model and api_key + # This is a simple parsing task, so use the lite model try: from biomni.config import default_config - model = default_config.llm + model = default_config.llm_lite # Use lightweight model for parsing api_key = default_config.api_key except ImportError: - model = "claude-3-5-haiku-20241022" + # Fallback: use smart default based on available API keys + from biomni.llm import _get_default_model_lite + + model = _get_default_model_lite() api_key = None try: @@ -75,6 +96,22 @@ def _query_llm_for_api(prompt, schema, system_template): except ImportError: llm = get_llm(model=model, temperature=0.0, api_key=api_key or "EMPTY") + # Use structured output if a response model is provided + if response_model is not None: + try: + structured_llm = llm.with_structured_output(response_model) + result = structured_llm.invoke(system_prompt + "\n\nUser query: " + prompt) + # Convert Pydantic model to dict + if hasattr(result, "model_dump"): + return {"success": True, "data": result.model_dump()} + elif hasattr(result, "dict"): + return {"success": True, "data": result.dict()} + else: + return {"success": True, "data": dict(result)} + except Exception: + # Fall back to manual parsing if structured output fails + pass # Continue to manual parsing below + # Compose messages messages = [ SystemMessage(content=system_prompt), @@ -83,7 +120,10 @@ def _query_llm_for_api(prompt, schema, system_template): # Query the LLM response = llm.invoke(messages) - llm_text = response.content.strip() + # Normalize content (handles string, list of blocks, etc.) + from biomni.utils import normalize_llm_content + + llm_text = normalize_llm_content(response.content) # Find JSON boundaries (in case LLM adds explanations) json_start = llm_text.find("{") @@ -218,16 +258,25 @@ def _query_ncbi_database( dict: Dictionary containing both the structured query and the results """ + # NCBI rate limit: max 3 requests/second without API key, 10 with API key + # Limit max_results to avoid overwhelming the API + max_results = min(max_results, 20) # Cap at 20 to avoid rate limits + # Query NCBI API using the structured search term esearch_url = "https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi" esearch_params = { "db": database, "term": search_term, "retmode": "json", - "retmax": 100, + "retmax": max_results, # Only fetch as many as we need "usehistory": "y", # Use history server to store results } + # Add NCBI API key if available (increases rate limit from 3 to 10 requests/sec) + ncbi_api_key = os.getenv("NCBI_API_KEY") + if ncbi_api_key: + esearch_params["api_key"] = ncbi_api_key + # Get IDs of matching entries search_response = _query_rest_api( endpoint=esearch_url, @@ -243,6 +292,9 @@ def _query_ncbi_database( # If we have results, fetch the details if "esearchresult" in search_data and int(search_data["esearchresult"]["count"]) > 0: + # Small delay to respect NCBI rate limits + time.sleep(0.4) + # Extract WebEnv and query_key from the search results webenv = search_data["esearchresult"].get("webenv", "") query_key = search_data["esearchresult"].get("querykey", "") @@ -258,6 +310,8 @@ def _query_ncbi_database( "retmode": "json", "retmax": max_results, } + if ncbi_api_key: + esummary_params["api_key"] = ncbi_api_key details_response = _query_rest_api( endpoint=esummary_url, @@ -282,6 +336,8 @@ def _query_ncbi_database( "id": ",".join(id_list), "retmode": "json", } + if ncbi_api_key: + esummary_params["api_key"] = ncbi_api_key details_response = _query_rest_api( endpoint=esummary_url, @@ -1966,16 +2022,17 @@ def query_geo( If database isn't clearly specified, default to "gds" as it contains most common experiment metadata. EXAMPLES OF CORRECT OUTPUTS: - - For "RNA-seq data in breast cancer": {"search_term": "RNA-seq AND breast cancer AND gse[ETYP]", "database": "gds"} - - For "Mouse microarray data from 2020": {"search_term": "Mus musculus[ORGN] AND 2020[PDAT] AND microarray AND gse[ETYP]", "database": "gds"} - - For "Expression profiles of TP53 in lung cancer": {"search_term": "TP53[Gene Symbol] AND lung cancer", "database": "geoprofiles"} + - For "RNA-seq data in breast cancer": {{"search_term": "RNA-seq AND breast cancer AND gse[ETYP]", "database": "gds"}} + - For "Mouse microarray data from 2020": {{"search_term": "Mus musculus[ORGN] AND 2020[PDAT] AND microarray AND gse[ETYP]", "database": "gds"}} + - For "Expression profiles of TP53 in lung cancer": {{"search_term": "TP53[Gene Symbol] AND lung cancer", "database": "geoprofiles"}} """ - # Query Claude to generate the API call + # Query LLM to generate the API call (use structured output for reliability) llm_result = _query_llm_for_api( prompt=prompt, schema=geo_schema, system_template=system_template, + response_model=GEOQueryResponse, ) if not llm_result["success"]: @@ -1999,6 +2056,495 @@ def query_geo( max_results=max_results, ) + # Format GEO results into a cleaner structure + if result.get("formatted_results") and isinstance(result["formatted_results"], dict): + raw_results = result["formatted_results"].get("result", {}) + datasets = [] + for uid, data in raw_results.items(): + if uid == "uids": + continue + if isinstance(data, dict): + datasets.append( + { + "accession": data.get("accession", f"GDS{uid}"), + "title": data.get("title", ""), + "summary": data.get("summary", ""), + "organism": data.get("taxon", ""), + "platform": data.get("gpl", ""), + "samples": data.get("n_samples", data.get("samplecount", "")), + "type": data.get("gdstype", data.get("entrytype", "")), + } + ) + result["datasets"] = datasets + + return result + + +def download_geo( + accession: str, + output_dir: str | None = None, + return_expression_matrix: bool = True, + return_metadata: bool = True, +): + """Download and parse data from a GEO accession (GSE series or GSM sample). + + Parameters + ---------- + accession : str + GEO accession ID (e.g., 'GSE123456' for series, 'GSM123456' for sample) + output_dir : str, optional + Directory to save downloaded files. Defaults to workspace/geo_data from config + return_expression_matrix : bool + Whether to return the expression matrix as a pandas DataFrame. Defaults to True + return_metadata : bool + Whether to return sample metadata. Defaults to True + + Returns + ------- + dict + Dictionary containing: + - 'accession': The GEO accession ID + - 'files': List of downloaded file paths + - 'expression_matrix': pandas DataFrame of expression values (if requested and available) + - 'metadata': Sample metadata DataFrame (if requested) + - 'error': Error message if download failed + + Examples + -------- + >>> result = download_geo("GSE161650") + >>> df = result["expression_matrix"] + >>> metadata = result["metadata"] + + Notes + ----- + This function uses the GEOparse library to download and parse GEO data. + For GSE accessions, it downloads the series matrix file. + For GSM accessions, it downloads the individual sample data. + """ + try: + import GEOparse + except ImportError: + return { + "error": "GEOparse library not installed. Install with: pip install GEOparse", + "accession": accession, + } + + # Use workspace from config if output_dir not specified + if output_dir is None: + from biomni.config import default_config + + output_dir = os.path.join(default_config.workspace, "geo_data") + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + result = { + "accession": accession, + "files": [], + "expression_matrix": None, + "metadata": None, + "error": None, + } + + try: + # Download via HTTPS (more reliable than FTP) + accession = accession.upper() + + if accession.startswith("GSE"): + # Construct HTTPS URL for SOFT file + # URL pattern: https://ftp.ncbi.nlm.nih.gov/geo/series/GSEnnn/GSE12345/soft/GSE12345_family.soft.gz + series_num = accession[3:] # Remove 'GSE' prefix + series_dir = f"GSE{series_num[:-3]}nnn" if len(series_num) > 3 else "GSEnnn" + soft_url = ( + f"https://ftp.ncbi.nlm.nih.gov/geo/series/{series_dir}/{accession}/soft/{accession}_family.soft.gz" + ) + + soft_path = os.path.join(output_dir, f"{accession}_family.soft.gz") + + # Download the file via HTTPS + if not os.path.exists(soft_path): + response = requests.get(soft_url, stream=True) + response.raise_for_status() + with open(soft_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + result["files"].append(soft_path) + + # Parse with GEOparse using the downloaded file + gse = GEOparse.get_GEO(filepath=soft_path, silent=True) + + # Get downloaded files + series_matrix_path = soft_path + if os.path.exists(series_matrix_path): + result["files"].append(series_matrix_path) + + # Extract expression matrix + if return_expression_matrix and hasattr(gse, "gpls") and len(gse.gpls) > 0: + try: + # Try to get pivoted expression data + pivot_samples = gse.pivot_samples("VALUE") + if pivot_samples is not None and not pivot_samples.empty: + result["expression_matrix"] = pivot_samples + # Save to CSV + expr_path = os.path.join(output_dir, f"{accession}_expression.csv") + pivot_samples.to_csv(expr_path) + result["files"].append(expr_path) + except Exception: + # Expression matrix not available in expected format + pass + + # Extract metadata + if return_metadata: + try: + import pandas as pd + + metadata_rows = [] + for gsm_name, gsm in gse.gsms.items(): + row = {"sample_id": gsm_name} + row.update(gsm.metadata) + # Flatten lists in metadata + for k, v in row.items(): + if isinstance(v, list) and len(v) == 1: + row[k] = v[0] + elif isinstance(v, list): + row[k] = "; ".join(str(x) for x in v) + metadata_rows.append(row) + + if metadata_rows: + metadata_df = pd.DataFrame(metadata_rows) + result["metadata"] = metadata_df + # Save to CSV + meta_path = os.path.join(output_dir, f"{accession}_metadata.csv") + metadata_df.to_csv(meta_path, index=False) + result["files"].append(meta_path) + except Exception as e: + result["metadata_error"] = str(e) + + # Summary info + result["n_samples"] = len(gse.gsms) + result["n_platforms"] = len(gse.gpls) + result["title"] = gse.metadata.get("title", ["Unknown"])[0] if gse.metadata.get("title") else "Unknown" + + elif accession.startswith("GSM"): + # Construct HTTPS URL for GSM SOFT file + # URL pattern: https://ftp.ncbi.nlm.nih.gov/geo/samples/GSMnnn/GSM12345/soft/GSM12345.soft.gz + sample_num = accession[3:] # Remove 'GSM' prefix + sample_dir = f"GSM{sample_num[:-3]}nnn" if len(sample_num) > 3 else "GSMnnn" + soft_url = f"https://ftp.ncbi.nlm.nih.gov/geo/samples/{sample_dir}/{accession}/soft/{accession}.soft.gz" + + sample_path = os.path.join(output_dir, f"{accession}.soft.gz") + + # Download the file via HTTPS + if not os.path.exists(sample_path): + response = requests.get(soft_url, stream=True) + response.raise_for_status() + with open(sample_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + + result["files"].append(sample_path) + + # Parse with GEOparse using the downloaded file + gsm = GEOparse.get_GEO(filepath=sample_path, silent=True) + + # Extract table data as expression + if return_expression_matrix and hasattr(gsm, "table") and gsm.table is not None: + result["expression_matrix"] = gsm.table + expr_path = os.path.join(output_dir, f"{accession}_data.csv") + gsm.table.to_csv(expr_path) + result["files"].append(expr_path) + + # Extract metadata + if return_metadata: + result["metadata"] = gsm.metadata + result["title"] = gsm.metadata.get("title", ["Unknown"])[0] if gsm.metadata.get("title") else "Unknown" + + else: + result["error"] = f"Unsupported accession type: {accession}. Use GSE (series) or GSM (sample) accessions." + + except Exception as e: + result["error"] = f"Failed to download {accession}: {str(e)}" + + return result + + +def download_gpl_annotation( + gpl_id: str, + output_dir: str | None = None, + gene_symbol_column: str | None = None, +): + """Download and parse GPL platform annotation file with probe-to-gene mappings. + + Parameters + ---------- + gpl_id : str + GPL platform ID (e.g., 'GPL571', 'GPL17586') + output_dir : str, optional + Directory to save downloaded files. Defaults to workspace/geo_data from config + gene_symbol_column : str, optional + Column name containing gene symbols. If None, auto-detects from common column names. + + Returns + ------- + dict + Dictionary containing: + - 'gpl_id': The GPL platform ID + - 'annotation_table': pandas DataFrame with full annotation table + - 'probe_to_gene': dict mapping probe IDs to gene symbols + - 'gene_columns': list of detected gene-related columns + - 'files': List of downloaded file paths + - 'error': Error message if download failed + + Examples + -------- + >>> result = download_gpl_annotation("GPL571") + >>> probe_to_gene = result["probe_to_gene"] + >>> gene_symbol = probe_to_gene.get("1007_s_at") # 'DDR1' + + Notes + ----- + This function uses GEOparse to download annotated GPL files which contain + probe-to-gene mappings necessary for cross-platform meta-analysis. + """ + try: + import GEOparse + import pandas as pd + except ImportError: + return { + "error": "GEOparse library not installed. Install with: pip install GEOparse", + "gpl_id": gpl_id, + } + + # Use workspace from config if output_dir not specified + if output_dir is None: + from biomni.config import default_config + + output_dir = os.path.join(default_config.workspace, "geo_data", "gpl_annotations") + + # Create output directory + os.makedirs(output_dir, exist_ok=True) + + result = { + "gpl_id": gpl_id, + "annotation_table": None, + "probe_to_gene": {}, + "gene_columns": [], + "files": [], + "error": None, + } + + try: + gpl_id = gpl_id.upper() + if not gpl_id.startswith("GPL"): + gpl_id = f"GPL{gpl_id}" + + # Download GPL with annotations (annotate_gpl=True gets the curated annotation file) + gpl = GEOparse.get_GEO(geo=gpl_id, destdir=output_dir, annotate_gpl=True, silent=True) + + if not hasattr(gpl, "table") or gpl.table is None or gpl.table.empty: + result["error"] = f"No annotation table found for {gpl_id}" + return result + + annotation_df = gpl.table + result["annotation_table"] = annotation_df + + # Find gene-related columns + gene_col_patterns = ["gene symbol", "gene_symbol", "symbol", "gene_assignment", "gene name"] + entrez_col_patterns = ["gene id", "gene_id", "entrez", "entrez_gene_id"] + + # Auto-detect gene symbol column + detected_gene_col = None + for col in annotation_df.columns: + col_lower = col.lower() + if gene_symbol_column and col_lower == gene_symbol_column.lower(): + detected_gene_col = col + break + for pattern in gene_col_patterns: + if pattern in col_lower: + detected_gene_col = col + break + if detected_gene_col: + break + + # Collect all gene-related columns for reference + all_gene_cols = [] + for col in annotation_df.columns: + col_lower = col.lower() + if any(p in col_lower for p in gene_col_patterns + entrez_col_patterns + ["refseq", "ensembl", "unigene"]): + all_gene_cols.append(col) + result["gene_columns"] = all_gene_cols + + # Build probe-to-gene mapping + if detected_gene_col and "ID" in annotation_df.columns: + probe_to_gene = {} + for _, row in annotation_df.iterrows(): + probe_id = str(row["ID"]) + gene_symbol = row.get(detected_gene_col, "") + + # Handle complex gene symbol formats (e.g., "DDR1 /// MIR4640") + if pd.notna(gene_symbol) and gene_symbol: + gene_str = str(gene_symbol).strip() + # Take the first gene if multiple are listed + if "///" in gene_str: + gene_str = gene_str.split("///")[0].strip() + elif "//" in gene_str: + gene_str = gene_str.split("//")[0].strip() + if gene_str and gene_str != "---" and gene_str.lower() != "nan": + probe_to_gene[probe_id] = gene_str + + result["probe_to_gene"] = probe_to_gene + result["gene_symbol_column"] = detected_gene_col + result["n_mapped_probes"] = len(probe_to_gene) + result["n_total_probes"] = len(annotation_df) + + # Save annotation to CSV + annot_path = os.path.join(output_dir, f"{gpl_id}_annotation.csv") + annotation_df.to_csv(annot_path, index=False) + result["files"].append(annot_path) + + # Save probe-to-gene mapping + if result["probe_to_gene"]: + mapping_path = os.path.join(output_dir, f"{gpl_id}_probe_to_gene.csv") + mapping_df = pd.DataFrame([{"probe_id": k, "gene_symbol": v} for k, v in result["probe_to_gene"].items()]) + mapping_df.to_csv(mapping_path, index=False) + result["files"].append(mapping_path) + + # Add platform info + if hasattr(gpl, "metadata"): + result["platform_title"] = ( + gpl.metadata.get("title", ["Unknown"])[0] if gpl.metadata.get("title") else "Unknown" + ) + result["platform_organism"] = ( + gpl.metadata.get("organism", ["Unknown"])[0] if gpl.metadata.get("organism") else "Unknown" + ) + + except Exception as e: + result["error"] = f"Failed to download {gpl_id}: {str(e)}" + + return result + + +def map_expression_to_genes( + expression_matrix, + gpl_id: str | None = None, + probe_to_gene: dict | None = None, + aggregation: str = "mean", +): + """Map probe-level expression data to gene-level using GPL annotations. + + Parameters + ---------- + expression_matrix : pandas.DataFrame + Expression matrix with probe IDs as index and samples as columns + gpl_id : str, optional + GPL platform ID to download annotations from (if probe_to_gene not provided) + probe_to_gene : dict, optional + Pre-computed probe-to-gene mapping dict. If not provided, will download from gpl_id + aggregation : str + Method to aggregate multiple probes per gene: 'mean', 'median', 'max', 'sum' + Default is 'mean' + + Returns + ------- + dict + Dictionary containing: + - 'gene_expression': pandas DataFrame with genes as index + - 'n_genes': Number of unique genes in output + - 'n_probes_mapped': Number of probes successfully mapped + - 'n_probes_unmapped': Number of probes without gene mapping + - 'unmapped_probes': List of probe IDs that couldn't be mapped + - 'error': Error message if mapping failed + + Examples + -------- + >>> result = map_expression_to_genes(expr_df, gpl_id="GPL571") + >>> gene_expr = result["gene_expression"] + + Notes + ----- + For cross-platform meta-analysis, this function allows mapping probe-level + data from different platforms to a common gene-level representation. + """ + + result = { + "gene_expression": None, + "n_genes": 0, + "n_probes_mapped": 0, + "n_probes_unmapped": 0, + "unmapped_probes": [], + "error": None, + } + + try: + # Get probe-to-gene mapping + if probe_to_gene is None: + if gpl_id is None: + result["error"] = "Either gpl_id or probe_to_gene mapping must be provided" + return result + + gpl_result = download_gpl_annotation(gpl_id) + if gpl_result.get("error"): + result["error"] = f"Failed to get GPL annotation: {gpl_result['error']}" + return result + + probe_to_gene = gpl_result.get("probe_to_gene", {}) + + if not probe_to_gene: + result["error"] = "No probe-to-gene mapping available" + return result + + # Map probes to genes + expr_df = expression_matrix.copy() + + # Handle index types + expr_df.index = expr_df.index.astype(str) + + # Map probe IDs to gene symbols + gene_symbols = [] + mapped_probes = [] + unmapped_probes = [] + + for probe_id in expr_df.index: + gene = probe_to_gene.get(str(probe_id)) + if gene: + gene_symbols.append(gene) + mapped_probes.append(probe_id) + else: + unmapped_probes.append(probe_id) + + result["n_probes_mapped"] = len(mapped_probes) + result["n_probes_unmapped"] = len(unmapped_probes) + result["unmapped_probes"] = unmapped_probes[:100] # Limit to first 100 + + if not mapped_probes: + result["error"] = "No probes could be mapped to genes" + return result + + # Filter to mapped probes and add gene column + expr_mapped = expr_df.loc[mapped_probes].copy() + expr_mapped["gene_symbol"] = gene_symbols + + # Aggregate by gene symbol + agg_funcs = { + "mean": "mean", + "median": "median", + "max": "max", + "sum": "sum", + } + agg_func = agg_funcs.get(aggregation, "mean") + + # Group by gene and aggregate + numeric_cols = expr_mapped.select_dtypes(include=["number"]).columns.tolist() + gene_expr = expr_mapped.groupby("gene_symbol")[numeric_cols].agg(agg_func) + + result["gene_expression"] = gene_expr + result["n_genes"] = len(gene_expr) + result["aggregation_method"] = aggregation + + except Exception as e: + result["error"] = f"Failed to map expression to genes: {str(e)}" + return result diff --git a/biomni/tool/genomics.py b/biomni/tool/genomics.py index d84d26bc..c4723cac 100644 --- a/biomni/tool/genomics.py +++ b/biomni/tool/genomics.py @@ -1,14 +1,12 @@ import os from pathlib import Path -import esm import gget import gseapy import numpy as np import pandas as pd import requests import scanpy as sc -import torch from pybiomart import Dataset from tqdm import tqdm @@ -336,6 +334,9 @@ def generate_gene_embeddings_with_ESM_models( Returns: String containing the steps performed during the embedding generation process """ + import esm + import torch + steps = [] steps.append(f"Loading ESM model: {model_name}") # model loading take a while, once loaded for smaller models generation is relatively fast @@ -466,7 +467,7 @@ def annotate_celltype_scRNA( data_info, data_lake_path, cluster="leiden", - llm="claude-3-5-sonnet-20241022", + llm=None, # Uses default_config.llm_lite if None composition=None, ): """Annotate cell types based on gene markers and transferred labels using LLM. @@ -541,6 +542,12 @@ def _cluster_info(cluster_id, marker_genes, composition_df=None): """ # Some can be a mixture of multiple cell types. + # Use default config if llm is None + # Cell type annotation is a classification task, so use the lite model + if llm is None: + from biomni.config import default_config + + llm = default_config.llm_lite llm = get_llm(llm) prompt = PromptTemplate(input_variables=["cluster_info"], template=prompt_template) chain = prompt | llm @@ -557,9 +564,11 @@ def _cluster_info(cluster_id, marker_genes, composition_df=None): while True: response = chain.invoke({"cluster_info": cluster_info}) - # Handle different response types + # Handle different response types (including OpenAI Responses API list format) + from biomni.utils import normalize_llm_content + if hasattr(response, "content"): # For AIMessage - response = response.content + response = normalize_llm_content(response.content) elif isinstance(response, dict) and "text" in response: response = response["text"] elif isinstance(response, str): @@ -2216,6 +2225,8 @@ def generate_embeddings_with_state( import os import subprocess + import torch + # Initialize steps list for logging steps = [] diff --git a/biomni/tool/literature.py b/biomni/tool/literature.py index e4b31b14..702c6be1 100644 --- a/biomni/tool/literature.py +++ b/biomni/tool/literature.py @@ -215,15 +215,77 @@ def search_google(query: str, num_results: int = 3, language: str = "en") -> lis return results_string +def advanced_web_search( + query: str, + num_results: int = 5, + max_retries: int = 3, +) -> str: + """ + Perform an advanced web search using Google search and content extraction. + Works with any LLM provider (OpenAI, Claude, etc.). + + Parameters + ---------- + query : str + The search phrase to look up. + num_results : int, optional + Number of search results to retrieve (default: 5). + max_retries : int, optional + Maximum number of retry attempts. + + Returns + ------- + str + A formatted string containing the search results with titles, URLs, and content snippets. + """ + import random + + delay = random.randint(1, 3) + + for attempt in range(1, max_retries + 1): + try: + results = [] + search_results = search_google(query, num_results=num_results) + + if search_results: + return search_results + + # If search_google returns empty, try direct Google search + for res in search(query, num_results=num_results, lang="en", advanced=True): + result_text = f"Title: {res.title}\nURL: {res.url}\nDescription: {res.description}\n" + try: + # Try to get more content from the page + content = extract_url_content(res.url) + if content and len(content) > 100: + result_text += f"Content Preview: {content[:500]}...\n" + except Exception: + pass + results.append(result_text) + + if results: + return "\n---\n".join(results) + return "No search results found." + + except Exception as e: + if attempt < max_retries: + time.sleep(delay) + delay *= 2 + continue + return f"Error performing web search after {max_retries} attempts: {str(e)}" + + def advanced_web_search_claude( query: str, max_searches: int = 1, max_retries: int = 3, -) -> tuple[str, list[dict[str, str]], list]: +) -> str: """ Initiate an advanced web search by launching a specialized agent to collect relevant information and citations through multiple rounds of web searches for a given query. Craft the query carefully for the search agent to find the most relevant information. + Note: This function requires a Claude model with web search capabilities. + If Claude is not available, it will automatically fall back to advanced_web_search(). + Parameters ---------- query : str @@ -240,8 +302,6 @@ def advanced_web_search_claude( """ import random - import anthropic - try: from biomni.config import default_config @@ -253,11 +313,17 @@ def advanced_web_search_claude( model = "claude-4-sonnet-latest" api_key = os.getenv("ANTHROPIC_API_KEY") - if "claude" not in model: - raise ValueError("Model must be a Claude model.") + # Fall back to provider-agnostic search if not using Claude + if "claude" not in model.lower(): + print(f"Note: advanced_web_search_claude requires Claude model, but {model} is configured.") + print("Falling back to advanced_web_search()...") + return advanced_web_search(query, num_results=max_searches * 5, max_retries=max_retries) if not api_key: - raise ValueError("Set your api_key explicitly.") + print("No ANTHROPIC_API_KEY found. Falling back to advanced_web_search()...") + return advanced_web_search(query, num_results=max_searches * 5, max_retries=max_retries) + + import anthropic client = anthropic.Anthropic(api_key=api_key) tool_def = { diff --git a/biomni/tool/tool_description/database.py b/biomni/tool/tool_description/database.py index 820eece7..6cb46765 100644 --- a/biomni/tool/tool_description/database.py +++ b/biomni/tool/tool_description/database.py @@ -250,6 +250,96 @@ } ], }, + { + "description": "Download and parse data from a GEO accession (GSE series or GSM sample). Returns expression matrix and sample metadata.", + "name": "download_geo", + "optional_parameters": [ + { + "name": "output_dir", + "type": "str", + "description": "Directory to save downloaded files (defaults to workspace/geo_data)", + "default": None, + }, + { + "name": "return_expression_matrix", + "type": "bool", + "description": "Whether to return expression matrix", + "default": True, + }, + { + "name": "return_metadata", + "type": "bool", + "description": "Whether to return sample metadata", + "default": True, + }, + ], + "required_parameters": [ + { + "name": "accession", + "type": "str", + "description": "GEO accession ID (e.g., 'GSE123456' for series, 'GSM123456' for sample)", + "default": None, + } + ], + }, + { + "description": "Download GPL platform annotation file with probe-to-gene mappings. Essential for cross-platform meta-analysis to map probe IDs to common gene identifiers.", + "name": "download_gpl_annotation", + "optional_parameters": [ + { + "name": "output_dir", + "type": "str", + "description": "Directory to save downloaded files (defaults to workspace/geo_data/gpl_annotations)", + "default": None, + }, + { + "name": "gene_symbol_column", + "type": "str", + "description": "Column name containing gene symbols. If None, auto-detects from common column names.", + "default": None, + }, + ], + "required_parameters": [ + { + "name": "gpl_id", + "type": "str", + "description": "GPL platform ID (e.g., 'GPL571', 'GPL17586')", + "default": None, + } + ], + }, + { + "description": "Map probe-level expression data to gene-level using GPL annotations. Use this for cross-platform meta-analysis after downloading GEO datasets.", + "name": "map_expression_to_genes", + "optional_parameters": [ + { + "name": "gpl_id", + "type": "str", + "description": "GPL platform ID to download annotations from (if probe_to_gene not provided)", + "default": None, + }, + { + "name": "probe_to_gene", + "type": "dict", + "description": "Pre-computed probe-to-gene mapping dict", + "default": None, + }, + { + "name": "aggregation", + "type": "str", + "description": "Method to aggregate multiple probes per gene: 'mean', 'median', 'max', 'sum'", + "default": "mean", + }, + ], + "required_parameters": [ + { + "name": "expression_matrix", + "type": "pandas.DataFrame", + "description": "Expression matrix with probe IDs as index and samples as columns", + "default": None, + } + ], + }, { "description": "Query the NCBI dbSNP database using natural language or direct search term.", "name": "query_dbsnp", diff --git a/biomni/tool/tool_description/genomics.py b/biomni/tool/tool_description/genomics.py index 787ff80a..224a51e4 100644 --- a/biomni/tool/tool_description/genomics.py +++ b/biomni/tool/tool_description/genomics.py @@ -13,8 +13,8 @@ "type": "str", }, { - "default": "claude-3-5-sonnet-20241022", - "description": "Language model instance for cell type prediction", + "default": None, + "description": "Language model instance for cell type prediction. Uses default_config.llm_lite if None.", "name": "llm", "type": "str", }, diff --git a/biomni/tool/tool_description/literature.py b/biomni/tool/tool_description/literature.py index 39daa7c5..69e7cef5 100644 --- a/biomni/tool/tool_description/literature.py +++ b/biomni/tool/tool_description/literature.py @@ -133,7 +133,33 @@ ], }, { - "description": "Initiate an advanced web search by launching a specialized agent to collect relevant information and citations through multiple rounds of web searches for a given query.", + "description": "Perform an advanced web search using Google search and content extraction. Works with any LLM provider.", + "name": "advanced_web_search", + "optional_parameters": [ + { + "default": 5, + "description": "Number of search results to retrieve", + "name": "num_results", + "type": "int", + }, + { + "default": 3, + "description": "Maximum number of retry attempts.", + "name": "max_retries", + "type": "int", + }, + ], + "required_parameters": [ + { + "default": None, + "description": "The search query string.", + "name": "query", + "type": "str", + } + ], + }, + { + "description": "Initiate an advanced web search using Claude's web search capabilities. Falls back to advanced_web_search if Claude is not available.", "name": "advanced_web_search_claude", "optional_parameters": [ { diff --git a/biomni/utils.py b/biomni/utils.py index 09e1ef91..195e37b5 100644 --- a/biomni/utils.py +++ b/biomni/utils.py @@ -21,6 +21,46 @@ from pydantic import BaseModel, Field, ValidationError +def normalize_llm_content(content) -> str: + """Normalize LLM response content to a string. + + Handles different response formats from various LLM APIs: + - String content (most common) + - List of content blocks (OpenAI Responses API, used by GPT-5) + - Dict with 'text' key + + Args: + content: The content from response.content (can be str, list, or dict) + + Returns: + str: Normalized string content + """ + if isinstance(content, str): + return content.strip() + elif isinstance(content, list): + # OpenAI Responses API returns list of content blocks + text_parts = [] + for block in content: + if isinstance(block, dict): + # Handle various block types + block_type = block.get("type", "") + if block_type in ("text", "output_text", "redacted_text"): + text = block.get("text") or block.get("content") or "" + if isinstance(text, str): + text_parts.append(text) + elif "text" in block: + text_parts.append(str(block["text"])) + elif isinstance(block, str): + text_parts.append(block) + elif hasattr(block, "text"): + text_parts.append(str(block.text)) + return "".join(text_parts).strip() + elif isinstance(content, dict): + return str(content.get("text", content.get("content", str(content)))).strip() + else: + return str(content).strip() + + # Add these new functions for running R code and CLI commands def run_r_code(code: str) -> str: """Run R code using subprocess. @@ -119,11 +159,12 @@ def run_bash_script(script: str) -> str: cwd = os.getcwd() # Run the Bash script with the current environment and working directory + # Use bytes mode to handle binary/non-UTF8 output gracefully result = subprocess.run( [temp_file], shell=True, capture_output=True, - text=True, + text=False, # Use bytes mode to avoid UTF-8 decode errors check=False, env=env, cwd=cwd, @@ -132,13 +173,23 @@ def run_bash_script(script: str) -> str: # Clean up the temporary file os.unlink(temp_file) + # Decode output with error handling for non-UTF8 content + try: + stdout = result.stdout.decode("utf-8", errors="replace") + except Exception: + stdout = str(result.stdout) + try: + stderr = result.stderr.decode("utf-8", errors="replace") + except Exception: + stderr = str(result.stderr) + # Return the output if result.returncode != 0: traceback.print_stack() print(result) - return f"Error running Bash script (exit code {result.returncode}):\n{result.stderr}" + return f"Error running Bash script (exit code {result.returncode}):\n{stderr}" else: - return result.stdout + return stdout except Exception as e: traceback.print_exc() return f"Error running Bash script: {str(e)}" @@ -316,11 +367,14 @@ def get_all_functions_from_file(file_path): def write_python_code(request: str): - from langchain_anthropic import ChatAnthropic from langchain_core.output_parsers import StrOutputParser from langchain_core.prompts import ChatPromptTemplate - model = ChatAnthropic(model="claude-3-5-sonnet-20240620") + from biomni.config import default_config + from biomni.llm import get_llm + + # Use lightweight model for simple code generation tasks + model = get_llm(model=default_config.llm_lite, temperature=0.7, config=default_config) template = """Write some python code to solve the user's problem. Return only python code in Markdown format, e.g.: diff --git a/launch_biomni.py b/launch_biomni.py new file mode 100644 index 00000000..96605237 --- /dev/null +++ b/launch_biomni.py @@ -0,0 +1,8 @@ +"""Launch Biomni with OKN-WOBD MCP server on Gradio.""" + +from biomni.agent import A1 + +agent = A1(path="./data", llm="gpt-5", expected_data_lake_files=[]) +agent.add_mcp(config_path="./mcp_config.yaml") +print("\n🚀 Launching Gradio UI...") +agent.launch_gradio_demo(share=False) diff --git a/mcp_config.yaml b/mcp_config.yaml new file mode 100644 index 00000000..947d1215 --- /dev/null +++ b/mcp_config.yaml @@ -0,0 +1,13 @@ +# MCP Server Config for Biomni — OKN-WOBD +# Usage: agent.add_mcp(config_path="./mcp_config.yaml") + +mcp_servers: + okn-wobd: + command: ["python3.11", "-m", "okn_wobd.mcp_server"] + enabled: true + description: > + Biomedical analysis tools for gene-disease path finding, gene neighborhood + queries, drug-disease opposing expression patterns, and differential + expression analysis via ARCHS4. + env: + PYTHONPATH: "/Users/bgood/Documents/GitHub/OKN-WOBD/src" diff --git a/proto_okn_mcp_config.yaml b/proto_okn_mcp_config.yaml new file mode 100644 index 00000000..4d770af5 --- /dev/null +++ b/proto_okn_mcp_config.yaml @@ -0,0 +1,77 @@ +# MCP Server Config for Proto-OKN SPARQL endpoints +# Each server provides tools to query a specific knowledge graph + +mcp_servers: + # Biomedical Knowledge Graphs + spoke-genelab: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/spoke-genelab/sparql"] + enabled: true + description: "SPOKE-GeneLab - biomedical KG with genes, diseases, compounds, anatomies" + + spoke: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/spoke/sparql"] + enabled: true + description: "SPOKE - Scalable Precision Medicine Open Knowledge Engine" + + nde: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/nde/sparql"] + enabled: true + description: "NDE - NIAID Data Ecosystem knowledge graph" + + biohealth: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/biohealth/sparql"] + enabled: true + description: "BioHealth knowledge graph" + + gene-expression-atlas-okn: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/gene-expression-atlas-okn/sparql"] + enabled: true + description: "Gene Expression Atlas OKN" + + # BioBricks toxicology/chemistry KGs + biobricks-aopwiki: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/biobricks-aopwiki/sparql"] + enabled: true + description: "BioBricks AOP-Wiki - Adverse Outcome Pathways" + + biobricks-mesh: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/biobricks-mesh/sparql"] + enabled: true + description: "BioBricks MeSH - Medical Subject Headings" + + biobricks-pubchem-annotations: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/biobricks-pubchem-annotations/sparql"] + enabled: true + description: "BioBricks PubChem Annotations - chemical compound annotations" + + biobricks-tox21: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/biobricks-tox21/sparql"] + enabled: true + description: "BioBricks Tox21 - toxicology screening data" + + biobricks-toxcast: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/biobricks-toxcast/sparql"] + enabled: true + description: "BioBricks ToxCast - EPA toxicity forecasting" + + biobricks-ice: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/biobricks-ice/sparql"] + enabled: true + description: "BioBricks ICE - Integrated Chemical Environment" + + # Other scientific KGs + dreamkg: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/dreamkg/sparql"] + enabled: true + description: "DreamKG knowledge graph" + + wildlifekn: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://frink.apps.renci.org/wildlifekn/sparql"] + enabled: true + description: "Wildlife Knowledge Network" + + # External endpoints + uniprot-sparql: + command: ["uvx", "mcp-proto-okn", "--endpoint", "https://sparql.uniprot.org/sparql", "--description", "UniProt protein sequence and function database"] + enabled: true + description: "UniProt SPARQL - protein sequences and annotations" diff --git a/pyproject.toml b/pyproject.toml index 0fad3a91..af09c3d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ Repository = "https://github.com/snap-stanford/biomni" [project.optional-dependencies] gradio = ["gradio>=5.0,<6.0"] +geo = ["GEOparse>=2.0"] [tool.setuptools] include-package-data = true diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..d2ce2819 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,25 @@ +"""Shared pytest configuration for Biomni tests.""" + +import pytest + + +def pytest_addoption(parser): + parser.addoption( + "--live", + action="store_true", + default=False, + help="Run tests that call live LLM APIs (requires API keys).", + ) + + +def pytest_configure(config): + config.addinivalue_line("markers", "live: marks tests that call live LLM APIs (skip unless --live)") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--live"): + return # --live given: run everything + skip_live = pytest.mark.skip(reason="needs --live option to run") + for item in items: + if "live" in item.keywords: + item.add_marker(skip_live) diff --git a/tests/test_agent_geo.py b/tests/test_agent_geo.py new file mode 100644 index 00000000..91943ac6 --- /dev/null +++ b/tests/test_agent_geo.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python +"""Test Biomni agent with GEO query functionality. + +This script tests the full agent with a simple GEO query task. +Run with: python tests/test_agent_geo.py +""" + +import os +import sys + +# Add project root to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Load environment variables from .env file +from dotenv import load_dotenv + +load_dotenv() + +from biomni.config import default_config + +print("=" * 60) +print("BIOMNI AGENT GEO QUERY TEST") +print("=" * 60) +print(f"LLM (reasoning): {default_config.llm}") +print(f"LLM Lite (simple): {default_config.llm_lite}") +print(f"OpenAI API Key set: {'Yes' if os.environ.get('OPENAI_API_KEY') else 'No'}") +print(f"Anthropic API Key set: {'Yes' if os.environ.get('ANTHROPIC_API_KEY') else 'No'}") +print() + +# Test: Run the agent with a GEO query +print("TEST: Agent query for diabetic nephropathy datasets") +print("-" * 40) + +try: + from biomni.agent import A1 + + # Create agent with config from .env + print("Creating agent...") + agent = A1( + path="./data", + llm=default_config.llm, # Use config from .env + expected_data_lake_files=[], + use_tool_retriever=False, # Disable retriever for simpler testing + ) + + # Configure the agent + print("Configuring agent...") + agent.configure() + + # Run a simple query + print("Running query...") + prompt = "Use the query_geo tool to find RNA-seq datasets related to diabetic nephropathy in humans. Return the top 3 results with their GEO accession numbers and titles." + + result = agent.go(prompt) + + print("\n" + "=" * 40) + print("AGENT RESULT:") + print("=" * 40) + + if result: + messages, final_answer = result + print(f"Final Answer:\n{final_answer}") + else: + print("No result returned") + +except Exception as e: + print(f"EXCEPTION: {type(e).__name__}: {e}") + import traceback + + traceback.print_exc() + +print() +print("=" * 60) +print("TEST COMPLETE") +print("=" * 60) diff --git a/tests/test_geo_query.py b/tests/test_geo_query.py new file mode 100644 index 00000000..334c3e7a --- /dev/null +++ b/tests/test_geo_query.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python +"""Test script for Biomni GEO query functionality. + +This script tests the query_geo function directly without the full agent. +Run with: python tests/test_geo_query.py +""" + +import os +import sys + +# Add project root to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Load environment variables from .env file +from dotenv import load_dotenv + +load_dotenv() + +from biomni.config import default_config + +print("=" * 60) +print("BIOMNI GEO QUERY TEST") +print("=" * 60) +print(f"LLM (reasoning): {default_config.llm}") +print(f"LLM Lite (simple): {default_config.llm_lite}") +print() + +# Test 1: Direct search term (bypass LLM) +print("TEST 1: Direct search term (bypassing LLM)") +print("-" * 40) + +try: + from biomni.tool.database import query_geo + + result = query_geo( + search_term="diabetic nephropathy[Title] AND Homo sapiens[Organism] AND gse[ETYP]", max_results=3 + ) + + if isinstance(result, dict): + if "total_results" in result: + print(f"SUCCESS! Found {result.get('total_results')} datasets") + elif result.get("error"): + print(f"ERROR: {result.get('error')}") + else: + print(f"Result keys: {result.keys()}") + else: + print(f"Result: {result}") + +except Exception as e: + print(f"EXCEPTION: {type(e).__name__}: {e}") + import traceback + + traceback.print_exc() + +print() + +# Test 2: Natural language query (uses LLM) +print("TEST 2: Natural language query (uses LLM lite)") +print("-" * 40) + +try: + from biomni.tool.database import query_geo + + result = query_geo(prompt="Find RNA-seq datasets for diabetic nephropathy in humans", max_results=3) + + if isinstance(result, dict): + if "total_results" in result: + print(f"SUCCESS! Found {result.get('total_results')} datasets") + print(f"Query interpretation: {result.get('query_interpretation', 'N/A')}") + elif result.get("error"): + print(f"ERROR: {result.get('error')}") + if result.get("raw_response"): + print(f"Raw response: {result.get('raw_response')[:200]}") + else: + print(f"Result keys: {result.keys()}") + else: + print(f"Result: {result}") + +except Exception as e: + print(f"EXCEPTION: {type(e).__name__}: {e}") + import traceback + + traceback.print_exc() + +print() +print("=" * 60) +print("TESTS COMPLETE") +print("=" * 60) diff --git a/tests/test_geo_query_integration.py b/tests/test_geo_query_integration.py new file mode 100644 index 00000000..1c84b62c --- /dev/null +++ b/tests/test_geo_query_integration.py @@ -0,0 +1,288 @@ +#!/usr/bin/env python +""" +Integration tests for Biomni GEO query functionality. + +Tests the query_geo tool both directly and through the agent to ensure +OpenAI models can successfully query the NCBI GEO database. + +Run with: pytest tests/test_geo_query_integration.py -v +Or directly: python tests/test_geo_query_integration.py +""" + +import os +import sys + +import pytest + +# Ensure we can import from the project root +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Load environment variables +from dotenv import load_dotenv + +load_dotenv() + + +class TestGeoQueryDirect: + """Test query_geo function directly without the agent.""" + + @pytest.fixture(autouse=True) + def setup_fixture(self): + """Set up test configuration (pytest fixture).""" + self._setup() + yield + + def _setup(self): + """Set up test configuration.""" + from biomni.config import default_config + + default_config.llm = "gpt-4o" + default_config.llm_lite = "gpt-4o-mini" # Use OpenAI lite model for tests + self.config = default_config + + def test_direct_search_term(self): + """Test query_geo with a direct search term (bypasses LLM).""" + from biomni.tool.database import query_geo + + result = query_geo( + search_term="diabetic nephropathy[Title] AND Homo sapiens[Organism] AND gse[ETYP]", max_results=3 + ) + + assert isinstance(result, dict), "Result should be a dictionary" + assert "total_results" in result or "error" in result, "Should have results or error" + + if "error" not in result: + assert result.get("total_results", 0) > 0, "Should find some datasets" + assert "formatted_results" in result, "Should have formatted results" + print(f"✓ Found {result['total_results']} datasets") + + def test_natural_language_query(self): + """Test query_geo with natural language prompt (uses LLM).""" + from biomni.tool.database import query_geo + + result = query_geo(prompt="Find RNA-seq datasets for diabetic nephropathy in humans", max_results=3) + + assert isinstance(result, dict), "Result should be a dictionary" + + # Check for success + if "error" in result: + pytest.fail(f"Query failed: {result.get('error')} - Raw: {result.get('raw_response', 'N/A')}") + + assert "total_results" in result, "Should have total_results" + assert result["total_results"] > 0, "Should find some datasets" + assert "formatted_results" in result, "Should have formatted results" + + # Verify we got actual dataset info + results = result.get("formatted_results", {}).get("result", {}) + uids = results.get("uids", []) + assert len(uids) > 0, "Should have dataset UIDs" + + # Check first dataset has expected fields + first_uid = uids[0] + first_dataset = results.get(first_uid, {}) + assert "accession" in first_dataset, "Dataset should have accession" + assert "title" in first_dataset, "Dataset should have title" + + print(f"✓ Found {result['total_results']} datasets via natural language") + print(f" Query interpretation: {result.get('query_interpretation', 'N/A')}") + for uid in uids[:3]: + ds = results.get(uid, {}) + print(f" - {ds.get('accession')}: {ds.get('title', 'N/A')[:60]}...") + + def test_llm_query_parsing(self): + """Test that the LLM correctly parses natural language to GEO query.""" + import pickle + + from biomni.tool.database import _query_llm_for_api + + # Load GEO schema + schema_path = os.path.join(os.path.dirname(__file__), "../biomni/tool/schema_db/geo.pkl") + with open(schema_path, "rb") as f: + geo_schema = pickle.load(f) + + system_template = """ + You are a bioinformatics assistant. Convert the user query into a GEO search term. + Output only a JSON object with: + 1. "search_term": The GEO search query + 2. "database": Either "gds" or "geoprofiles" + + Schema: {schema} + """ + + result = _query_llm_for_api( + prompt="Find RNA-seq data for diabetic nephropathy", schema=geo_schema, system_template=system_template + ) + + assert result.get("success"), f"LLM query failed: {result.get('error')}" + assert "data" in result, "Should have data" + assert "search_term" in result["data"], "Should have search_term" + assert "database" in result["data"], "Should have database" + + print(f"✓ LLM generated search term: {result['data']['search_term']}") + + +class TestGeoQueryAgent: + """Test query_geo through the Biomni agent.""" + + @pytest.fixture(autouse=True) + def setup_fixture(self): + """Set up test configuration (pytest fixture).""" + self._setup() + yield + + def _setup(self): + """Set up test configuration.""" + from biomni.config import default_config + + default_config.llm = "gpt-4o" + default_config.llm_lite = "gpt-4o-mini" + + @pytest.mark.slow + def test_agent_geo_query(self): + """Test that the agent can use query_geo to find datasets.""" + from biomni.agent import A1 + + # Create agent with minimal setup + agent = A1(path="./data", llm="gpt-4o", expected_data_lake_files=[], use_tool_retriever=False) + agent.configure() + + # Run query + prompt = "Use the query_geo tool to find RNA-seq datasets related to diabetic nephropathy in humans. Return the top 3 results with their GEO accession numbers and titles." + + result = agent.go(prompt) + + # Result is a tuple of (messages, final_answer) + assert result is not None, "Agent should return a result" + + messages, final_answer = result + assert final_answer is not None, "Should have a final answer" + + # Check that the answer mentions GEO accession numbers + answer_lower = final_answer.lower() + assert "gse" in answer_lower, "Answer should mention GSE accession numbers" + + # Check for specific dataset indicators + has_datasets = any( + x in answer_lower + for x in [ + "gse317266", + "gse315877", + "gse273001", # Known datasets + "diabetic", + "nephropathy", + "kidney", + ] + ) + assert has_datasets, "Answer should contain dataset information" + + print("✓ Agent successfully queried GEO and returned results") + print(f"Final answer preview: {final_answer[:500]}...") + + +class TestOpenAICompatibility: + """Test OpenAI-specific compatibility features.""" + + def test_config_uses_openai(self): + """Verify configuration is set to use OpenAI models.""" + from biomni.config import default_config + + default_config.llm = "gpt-4o" + default_config.llm_lite = "gpt-4o-mini" + + assert default_config.llm == "gpt-4o", "Config should use gpt-4o" + assert os.environ.get("OPENAI_API_KEY"), "OPENAI_API_KEY should be set" + print("✓ OpenAI configuration verified") + + def test_llm_factory_creates_openai(self): + """Test that get_llm creates OpenAI model correctly.""" + from biomni.config import default_config + from biomni.llm import get_llm + + default_config.llm = "gpt-4o" + default_config.llm_lite = "gpt-4o-mini" + + llm = get_llm(model="gpt-4o", temperature=0.7) + + assert llm is not None, "Should create LLM instance" + assert "openai" in str(type(llm)).lower(), "Should be an OpenAI model" + print(f"✓ Created LLM: {type(llm).__name__}") + + +def run_tests(): + """Run all tests and report results.""" + print("=" * 60) + print("BIOMNI GEO QUERY INTEGRATION TESTS") + print("=" * 60) + print() + + # Check environment + if not os.environ.get("OPENAI_API_KEY"): + print("ERROR: OPENAI_API_KEY not set. Please set it in .env file.") + return False + + passed = 0 + failed = 0 + + # Test 1: OpenAI configuration + print("Test 1: OpenAI Configuration") + print("-" * 40) + try: + test = TestOpenAICompatibility() + test.test_config_uses_openai() + test.test_llm_factory_creates_openai() + passed += 1 + except Exception as e: + print(f"FAILED: {e}") + failed += 1 + print() + + # Test 2: Direct search term + print("Test 2: Direct GEO Search Term") + print("-" * 40) + try: + test = TestGeoQueryDirect() + test._setup() + test.test_direct_search_term() + passed += 1 + except Exception as e: + print(f"FAILED: {e}") + failed += 1 + print() + + # Test 3: LLM query parsing + print("Test 3: LLM Query Parsing") + print("-" * 40) + try: + test = TestGeoQueryDirect() + test._setup() + test.test_llm_query_parsing() + passed += 1 + except Exception as e: + print(f"FAILED: {e}") + failed += 1 + print() + + # Test 4: Natural language query + print("Test 4: Natural Language GEO Query") + print("-" * 40) + try: + test = TestGeoQueryDirect() + test._setup() + test.test_natural_language_query() + passed += 1 + except Exception as e: + print(f"FAILED: {e}") + failed += 1 + print() + + # Summary + print("=" * 60) + print(f"RESULTS: {passed} passed, {failed} failed") + print("=" * 60) + + return failed == 0 + + +if __name__ == "__main__": + success = run_tests() + sys.exit(0 if success else 1) diff --git a/tests/test_meta_analysis.py b/tests/test_meta_analysis.py new file mode 100644 index 00000000..7d644adc --- /dev/null +++ b/tests/test_meta_analysis.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python +"""Test Biomni agent with a complex meta-analysis task. + +This script tests the agent's ability to handle a multi-step scientific task. +Run with: python tests/test_meta_analysis.py +""" + +import os +import sys + +# Add project root to path for imports +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +# Load environment variables +from dotenv import load_dotenv + +load_dotenv() + +from biomni.config import default_config + +print("=" * 70) +print("BIOMNI META-ANALYSIS TEST") +print("=" * 70) +print(f"LLM (reasoning): {default_config.llm}") +print(f"LLM Lite (simple): {default_config.llm_lite}") +print() + +TASK = """You are tasked with conducting a meta-analysis of publicly available gene expression datasets to identify genes and regulatory mechanisms that are consistently dysregulated in kidney tissue of patients with diabetic nephropathy (DN). + +Biological question: What genes are reproducibly differentially expressed in kidney samples from DN patients compared to healthy controls, and what pathways and transcription factors form the core regulatory architecture of those expression changes? + +Complete the following steps: + +STEP 1 - Data acquisition: Search the NCBI Gene Expression Omnibus (GEO) for microarray or RNA-seq datasets that profile gene expression in human kidney tissue from DN patients and non-diabetic controls. Find at least 4 independent datasets. Apply inclusion criteria: human samples, availability of expression data with case/control groups, at least 5 samples per group. Exclude animal models. + +STEP 2 - Data retrieval: For each suitable dataset, download or access the processed expression data. Extract the sample metadata to identify DN vs control groups. + +STEP 3 - Differential expression analysis: For each dataset, perform differential expression analysis comparing DN vs control samples. Identify significantly differentially expressed genes (DEGs) using appropriate statistical thresholds (e.g., adjusted p-value < 0.05, |log2FC| > 1). + +STEP 4 - Meta-analysis: Identify genes that are consistently differentially expressed across multiple datasets. Create a list of "consensus DEGs" that appear in at least 2-3 datasets. + +STEP 5 - Pathway analysis: Perform pathway enrichment analysis on the consensus DEGs to identify biological pathways dysregulated in DN. + +Execute each step in order, showing your work and intermediate results. Provide a final summary of the key findings.""" + +print(f"Task (first 500 chars):\n{TASK[:500]}...") +print("\n" + "=" * 70) + +try: + from biomni.agent import A1 + + print("Creating agent...") + agent = A1( + path="./data", + llm=default_config.llm, # Use config from .env + expected_data_lake_files=[], + use_tool_retriever=True, + ) + + print("Configuring agent...") + agent.configure() + + print("Running task...") + print("=" * 70) + + result = agent.go(TASK) + + messages, final_answer = result + + print("\n" + "=" * 70) + print("CONVERSATION LOG:") + print("=" * 70) + for i, msg in enumerate(messages[-10:]): # Last 10 messages + print(f"\n--- Message {i + 1} ---") + print(msg[:1000] if len(msg) > 1000 else msg) + + print("\n" + "=" * 70) + print("FINAL ANSWER:") + print("=" * 70) + print(final_answer) + +except KeyboardInterrupt: + print("\n\nInterrupted by user") +except Exception as e: + print(f"\nEXCEPTION: {type(e).__name__}: {e}") + import traceback + + traceback.print_exc() + +print("\n" + "=" * 70) +print("TEST COMPLETE") +print("=" * 70) diff --git a/tests/test_openai_e2e.py b/tests/test_openai_e2e.py new file mode 100644 index 00000000..2da3a75d --- /dev/null +++ b/tests/test_openai_e2e.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python +""" +End-to-end integration test for the Biomni agent with OpenAI models. + +This test creates a real A1 agent with an OpenAI model, sends it a simple task, +and verifies the full ReAct loop completes: retrieval -> generate -> execute -> result. + +Run with: pytest tests/test_openai_e2e.py -v --live -s +Requires OPENAI_API_KEY in environment (loaded from .env). +""" + +import os +import re +import sys + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +@pytest.mark.live +class TestOpenAIEndToEnd: + """Full agent loop with a real OpenAI model and real code execution. + + These tests instantiate A1 with gpt-5-mini (cheap, fast) and run + tasks that force the agent through generate -> execute -> solution. + """ + + @pytest.fixture(autouse=True) + def check_openai_key(self): + # Load .env if present (same as the agent does) + from dotenv import load_dotenv + + if os.path.exists(".env"): + load_dotenv(".env", override=False) + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + @pytest.fixture + def agent_4o_mini(self, tmp_path): + """Create a minimal A1 agent with gpt-5-mini, no datalake download.""" + from biomni.agent import A1 + + agent = A1( + path=str(tmp_path / "data"), + llm="gpt-5-mini", + use_tool_retriever=False, # skip retrieval to keep it fast + expected_data_lake_files=[], # skip datalake download + timeout_seconds=120, + ) + return agent + + @pytest.fixture + def agent_gpt5(self, tmp_path): + """Create a minimal A1 agent with gpt-5, no datalake download.""" + from biomni.agent import A1 + + try: + agent = A1( + path=str(tmp_path / "data"), + llm="gpt-5", + use_tool_retriever=False, + expected_data_lake_files=[], + timeout_seconds=120, + ) + except Exception as e: + pytest.skip(f"gpt-5 not available: {e}") + return agent + + def test_simple_math_e2e(self, agent_4o_mini): + """Agent should execute Python code and return a correct math result.""" + log, final = agent_4o_mini.go( + "Use Python to compute 7 * 13 and tell me the result. " + "You must execute the code, not just tell me the answer." + ) + + # The log should contain multiple entries (at least generate + execute + solution) + assert len(log) >= 3, f"Expected at least 3 log entries, got {len(log)}" + + # Verify code was actually executed (an tag should appear somewhere) + all_log_text = "\n".join(str(entry) for entry in log) + assert "" in all_log_text, "Agent should have produced an tag" + + # Verify the correct answer appears in the final output + assert "91" in final, f"Expected 91 in final answer, got: {final[:300]}" + + print(f"\n Log entries: {len(log)}") + print(f" Final answer (first 200 chars): {final[:200]}") + + def test_no_syntax_error_loop(self, agent_4o_mini): + """Regression test: agent should NOT loop on 'invalid syntax' from markdown fences.""" + log, final = agent_4o_mini.go( + "Write and run a Python script that creates a list of the first 10 square numbers " + "and prints them. Execute the code." + ) + + all_log_text = "\n".join(str(entry) for entry in log) + + # Count syntax errors - should be 0 or at most 1 (transient) + syntax_errors = all_log_text.count("invalid syntax") + assert syntax_errors <= 1, ( + f"Agent looped on syntax errors ({syntax_errors} occurrences). Markdown fence stripping may not be working." + ) + + # Should have produced a solution + assert "" in all_log_text or "solution" in all_log_text.lower(), ( + "Agent should have reached a solution" + ) + + # The squares should appear somewhere + assert "1" in final and "4" in final and "9" in final, f"Expected square numbers in output, got: {final[:300]}" + + print(f"\n Log entries: {len(log)}") + print(f" Syntax errors seen: {syntax_errors}") + print(f" Final answer (first 200 chars): {final[:200]}") + + def test_execute_tag_contains_clean_code(self, agent_4o_mini): + """Verify that code inside tags has no markdown fences.""" + log, final = agent_4o_mini.go("Execute Python code that prints 'biomni_test_marker_12345'. Just run it.") + + all_log_text = "\n".join(str(entry) for entry in log) + + # Find all blocks + execute_blocks = re.findall(r"(.*?)", all_log_text, re.DOTALL) + assert len(execute_blocks) >= 1, "Should have at least one execute block" + + for i, block in enumerate(execute_blocks): + assert not block.strip().startswith("```"), f"Execute block {i} starts with markdown fence: {block[:100]}" + assert not block.strip().endswith("```"), f"Execute block {i} ends with markdown fence: {block[-100:]}" + + # The marker should have been printed + assert "biomni_test_marker_12345" in all_log_text, "The test marker should appear in execution output" + + print(f"\n Execute blocks found: {len(execute_blocks)}") + print(f" First block (first 120 chars): {execute_blocks[0][:120]}") + + def test_gpt5_e2e_if_available(self, agent_gpt5): + """Same test with gpt-5 to verify the Responses API path works end-to-end.""" + log, final = agent_gpt5.go("Use Python to compute 2**10 and tell me the result. Execute the code.") + + assert len(log) >= 3, f"Expected at least 3 log entries, got {len(log)}" + + all_log_text = "\n".join(str(entry) for entry in log) + assert "" in all_log_text, "Agent should have produced an tag" + assert "1024" in final, f"Expected 1024 in final answer, got: {final[:300]}" + + # Verify no syntax error loop + syntax_errors = all_log_text.count("invalid syntax") + assert syntax_errors <= 1, f"Syntax error loop detected ({syntax_errors} occurrences)" + + print(f"\n Log entries: {len(log)}") + print(f" Final answer (first 200 chars): {final[:200]}") + + +@pytest.mark.live +class TestOpenAIWithRetriever: + """Test with tool retriever enabled — closer to real usage.""" + + @pytest.fixture(autouse=True) + def check_keys(self): + from dotenv import load_dotenv + + if os.path.exists(".env"): + load_dotenv(".env", override=False) + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + # Retriever uses the lite model which may need an Anthropic key + # depending on config, so check both + if not os.environ.get("ANTHROPIC_API_KEY") and not os.environ.get("OPENAI_API_KEY"): + pytest.skip("Need at least one API key for retriever") + + def test_retriever_plus_execute(self, tmp_path): + """Full loop with retriever: retrieve tools -> generate -> execute -> solution.""" + from biomni.agent import A1 + + agent = A1( + path=str(tmp_path / "data"), + llm="gpt-5-mini", + use_tool_retriever=True, + expected_data_lake_files=[], + timeout_seconds=180, + ) + + log, final = agent.go("Using Python, calculate the factorial of 10 and tell me the result.") + + assert len(log) >= 3, f"Expected at least 3 log entries, got {len(log)}" + + all_log_text = "\n".join(str(entry) for entry in log) + assert "" in all_log_text, "Agent should have executed code" + assert "3628800" in final, f"Expected 3628800 (10!) in final answer, got: {final[:300]}" + + # No syntax error looping + syntax_errors = all_log_text.count("invalid syntax") + assert syntax_errors <= 1, f"Syntax error loop detected ({syntax_errors} occurrences)" + + print(f"\n Log entries: {len(log)}") + print(f" Final answer (first 200 chars): {final[:200]}") + + +@pytest.mark.live +class TestOpenAIWithMCP: + """Test with OKN-WOBD MCP server — full agent loop including external tool discovery.""" + + @pytest.fixture(autouse=True) + def check_keys(self): + from dotenv import load_dotenv + + if os.path.exists(".env"): + load_dotenv(".env", override=False) + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + @pytest.fixture(autouse=True) + def check_mcp_available(self): + """Skip if the OKN-WOBD MCP server source isn't available locally.""" + okn_src = "/Users/bgood/Documents/GitHub/OKN-WOBD/src" + if not os.path.isdir(okn_src): + pytest.skip("OKN-WOBD source not found at /Users/bgood/Documents/GitHub/OKN-WOBD/src") + mcp_config = os.path.join(os.path.dirname(os.path.dirname(__file__)), "mcp_config.yaml") + if not os.path.isfile(mcp_config): + pytest.skip("mcp_config.yaml not found in project root") + + def test_mcp_osteoarthritis_geo_dataset(self, tmp_path): + """Agent uses OKN-WOBD MCP tools to find a GEO dataset about osteoarthritis. + + This exercises: + 1. MCP server startup and tool discovery + 2. Tool retriever selecting MCP tools + 3. Agent generating code that calls an MCP tool + 4. Code execution returning real results + 5. Agent summarizing findings in a solution + """ + from biomni.agent import A1 + + mcp_config = os.path.join(os.path.dirname(os.path.dirname(__file__)), "mcp_config.yaml") + + agent = A1( + path=str(tmp_path / "data"), + llm="gpt-5", + use_tool_retriever=True, + expected_data_lake_files=[], + timeout_seconds=300, + ) + agent.add_mcp(config_path=mcp_config) + + log, final = agent.go("Using okn-wobd mcp, find me a geo dataset about osteoarthritis.") + + all_log_text = "\n".join(str(entry) for entry in log) + + # Should have executed code (not just answered from knowledge) + assert "" in all_log_text, "Agent should have produced an tag" + + # Should have reached a solution + assert "" in all_log_text, "Agent should have reached a solution" + + # The answer should reference GEO datasets (GSE IDs) or osteoarthritis + final_lower = final.lower() + has_geo = "gse" in final_lower or "geo" in final_lower + has_oa = "osteoarthritis" in final_lower or "arthritis" in final_lower + assert has_geo or has_oa, f"Expected GEO dataset IDs or osteoarthritis mention in answer, got: {final[:500]}" + + # No syntax error loop + syntax_errors = all_log_text.count("invalid syntax") + assert syntax_errors <= 1, f"Syntax error loop detected ({syntax_errors} occurrences)" + + print(f"\n Log entries: {len(log)}") + print(f" Final answer (first 500 chars): {final[:500]}") + + +if __name__ == "__main__": + pytest.main([__file__, "-v", "--live", "-s"]) diff --git a/tests/test_openai_structured_output.py b/tests/test_openai_structured_output.py new file mode 100644 index 00000000..761a17dd --- /dev/null +++ b/tests/test_openai_structured_output.py @@ -0,0 +1,736 @@ +#!/usr/bin/env python +""" +Tests for OpenAI structured output support and llm_lite auto-inference. + +Exercises the changes from the goodb-wobd branch: + 1. BiomniConfig._infer_lite_model() provider matching + 2. BiomniConfig auto-inference of llm_lite from llm + 3. AgentResponse Pydantic schema correctness + 4. A1._is_openai_model detection property + 5. generate() structured output path (mocked LLM) + 6. generate() XML fallback path (mocked LLM) + +Run with: pytest tests/test_openai_structured_output.py -v +""" + +import os +import sys +from unittest.mock import MagicMock, patch + +import pytest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +# --------------------------------------------------------------------------- +# 1. BiomniConfig._infer_lite_model +# --------------------------------------------------------------------------- +class TestInferLiteModel: + """Test the static method that maps a main model to its lite counterpart.""" + + def test_claude_model(self): + from biomni.config import BiomniConfig + + assert BiomniConfig._infer_lite_model("claude-sonnet-4-5") == "claude-haiku-4-5" + + def test_gpt_model(self): + from biomni.config import BiomniConfig + + assert BiomniConfig._infer_lite_model("gpt-5.2") == "gpt-5-mini" + + def test_gpt4_model(self): + from biomni.config import BiomniConfig + + assert BiomniConfig._infer_lite_model("gpt-4o") == "gpt-5-mini" + + def test_gemini_model(self): + from biomni.config import BiomniConfig + + assert BiomniConfig._infer_lite_model("gemini-2.0-pro") == "gemini-1.5-flash" + + def test_unknown_model_falls_back_to_claude(self): + from biomni.config import BiomniConfig + + assert BiomniConfig._infer_lite_model("some-unknown-model") == "claude-haiku-4-5" + + +# --------------------------------------------------------------------------- +# 2. BiomniConfig llm_lite auto-inference integration +# --------------------------------------------------------------------------- +class TestConfigLiteAutoInference: + """Test that BiomniConfig auto-infers llm_lite when not explicitly set.""" + + def test_default_is_claude(self): + from biomni.config import BiomniConfig + + config = BiomniConfig() + assert config.llm == "claude-sonnet-4-5" + assert config.llm_lite == "claude-haiku-4-5" + + def test_openai_main_infers_openai_lite(self): + from biomni.config import BiomniConfig + + config = BiomniConfig(llm="gpt-5.2") + assert config.llm_lite == "gpt-5-mini" + + def test_gemini_main_infers_gemini_lite(self): + from biomni.config import BiomniConfig + + config = BiomniConfig(llm="gemini-2.0-pro") + assert config.llm_lite == "gemini-1.5-flash" + + def test_explicit_lite_not_overridden(self): + from biomni.config import BiomniConfig + + config = BiomniConfig(llm="gpt-5.2", llm_lite="gpt-4o") + assert config.llm_lite == "gpt-4o", "Explicit llm_lite should not be overridden" + + def test_env_var_overrides_inference(self): + """BIOMNI_LLM_LITE env var should take precedence over auto-inference.""" + from biomni.config import BiomniConfig + + with patch.dict(os.environ, {"BIOMNI_LLM_LITE": "my-custom-lite-model"}): + config = BiomniConfig(llm="gpt-5.2") + assert config.llm_lite == "my-custom-lite-model" + + def test_env_var_llm_triggers_matching_lite(self): + """When BIOMNI_LLM sets an OpenAI model, llm_lite should auto-match.""" + from biomni.config import BiomniConfig + + with patch.dict(os.environ, {"BIOMNI_LLM": "gpt-4o"}, clear=False): + # Remove BIOMNI_LLM_LITE if present to test inference + env = os.environ.copy() + env.pop("BIOMNI_LLM_LITE", None) + with patch.dict(os.environ, env, clear=True): + config = BiomniConfig() + assert config.llm == "gpt-4o" + assert config.llm_lite == "gpt-5-mini" + + +# --------------------------------------------------------------------------- +# 3. AgentResponse schema +# --------------------------------------------------------------------------- +class TestAgentResponseSchema: + """Test the Pydantic schema used for OpenAI structured outputs.""" + + def test_fields_exist(self): + from biomni.agent.a1 import AgentResponse + + fields = set(AgentResponse.model_fields.keys()) + assert fields == {"reasoning", "action", "content"} + + def test_action_enum_values(self): + from biomni.agent.a1 import AgentResponse + + # Valid actions + resp = AgentResponse(reasoning="thinking...", action="execute", content="print('hi')") + assert resp.action == "execute" + + resp = AgentResponse(reasoning="done", action="solution", content="The answer is 42") + assert resp.action == "solution" + + def test_invalid_action_rejected(self): + from biomni.agent.a1 import AgentResponse + + with pytest.raises(ValueError): + AgentResponse(reasoning="hmm", action="think", content="...") + + def test_json_schema_has_enum(self): + """The JSON schema sent to OpenAI should constrain action to the enum.""" + from biomni.agent.a1 import AgentResponse + + schema = AgentResponse.model_json_schema() + assert schema["properties"]["action"]["enum"] == ["execute", "solution"] + + def test_all_fields_required(self): + from biomni.agent.a1 import AgentResponse + + schema = AgentResponse.model_json_schema() + assert set(schema["required"]) == {"reasoning", "action", "content"} + + +# --------------------------------------------------------------------------- +# 4. A1._is_openai_model property +# --------------------------------------------------------------------------- +class TestIsOpenAIModel: + """Test the property that detects whether the LLM is an OpenAI model.""" + + def _make_stub(self, model_name): + """Create a minimal A1-like object with a mock LLM.""" + from biomni.agent.a1 import A1 + + stub = object.__new__(A1) # skip __init__ + stub.llm = MagicMock() + stub.llm.model_name = model_name + return stub + + def test_gpt5_detected(self): + stub = self._make_stub("gpt-5.2") + assert stub._is_openai_model is True + + def test_gpt4o_detected(self): + stub = self._make_stub("gpt-4o") + assert stub._is_openai_model is True + + def test_claude_not_detected(self): + stub = self._make_stub("claude-sonnet-4-5") + assert stub._is_openai_model is False + + def test_gemini_not_detected(self): + stub = self._make_stub("gemini-2.0-pro") + assert stub._is_openai_model is False + + +# --------------------------------------------------------------------------- +# 5. generate() — structured output path (OpenAI) +# --------------------------------------------------------------------------- +class TestGenerateStructuredOutput: + """Test that generate() uses structured output for OpenAI models and + produces correctly tagged messages for the downstream execute() node.""" + + def _build_generate(self, agent_response): + """Build the generate() function from a1.py with a mocked LLM. + + Returns (generate_fn, mock_llm) so tests can inspect calls. + """ + import re + + from biomni.agent.a1 import AgentResponse + from langchain_core.messages import AIMessage, SystemMessage + + mock_llm = MagicMock() + mock_structured = MagicMock() + mock_structured.invoke.return_value = agent_response + mock_llm.with_structured_output.return_value = mock_structured + mock_llm.model_name = "gpt-5.2" + + # Minimal self-like namespace + class FakeSelf: + llm = mock_llm + system_prompt = "You are a test agent." + + @property + def _is_openai_model(self): + model_name = getattr(self.llm, "model_name", "") or getattr(self.llm, "model", "") + return str(model_name).lower().startswith("gpt-") or "openai" in str(type(self.llm)).lower() + + self_obj = FakeSelf() + + # Reconstruct the generate function logic (mirrors a1.py) + def generate(state): + messages = [SystemMessage(content=self_obj.system_prompt)] + state["messages"] + + if self_obj._is_openai_model: + try: + structured_llm = self_obj.llm.with_structured_output(AgentResponse) + response = structured_llm.invoke(messages) + content = response.content + if response.action == "execute": + content = re.sub(r"^```(?:python|bash|r)?\s*\n?", "", content) + content = re.sub(r"\n?```\s*$", "", content) + tag = response.action + msg = f"{response.reasoning}\n<{tag}>{content}" + state["messages"].append(AIMessage(content=msg.strip())) + if response.action == "solution": + # Guard: reject premature solutions when no code has been executed + has_executed = any( + "" in (m.content if isinstance(m.content, str) else "") + for m in state["messages"] + ) + if not has_executed: + nudge = ( + "\nYou chose 'solution' but have not executed any code yet. " + "You MUST run code using the available tools before providing a final answer. " + "Choose action 'execute' and write code to accomplish the task." + ) + state["messages"].append(AIMessage(content=nudge.strip())) + state["next_step"] = "generate" + else: + state["next_step"] = "end" + else: + state["next_step"] = "execute" + return state + except Exception as e: + print(f"Structured output failed ({e}), falling back to XML parsing...") + response = self_obj.llm.invoke(messages) + msg = str(response.content) + else: + response = self_obj.llm.invoke(messages) + msg = str(response.content) + + # XML fallback + if "" in msg and "" not in msg: + msg += "" + if "" in msg and "" not in msg: + msg += "" + + execute_match = re.search(r"(.*?)", msg, re.DOTALL | re.IGNORECASE) + answer_match = re.search(r"(.*?)", msg, re.DOTALL | re.IGNORECASE) + + state["messages"].append(AIMessage(content=msg.strip())) + if answer_match: + state["next_step"] = "end" + elif execute_match: + state["next_step"] = "execute" + else: + state["next_step"] = "generate" + return state + + return generate, mock_llm + + def test_execute_action_produces_execute_tag(self): + from biomni.agent.a1 import AgentResponse + + resp = AgentResponse( + reasoning="I will query the database.", + action="execute", + content='from biomni.tool.database import query_geo\nresult = query_geo(prompt="test")\nprint(result)', + ) + generate, mock_llm = self._build_generate(resp) + + state = {"messages": [], "next_step": None} + result = generate(state) + + assert result["next_step"] == "execute" + last_msg = result["messages"][-1].content + assert "" in last_msg + assert "" in last_msg + assert "query_geo" in last_msg + assert "I will query the database." in last_msg + # Verify structured output was used (not raw invoke) + mock_llm.with_structured_output.assert_called_once_with(AgentResponse) + + def test_solution_action_produces_solution_tag_and_ends(self): + from biomni.agent.a1 import AgentResponse + from langchain_core.messages import AIMessage + + resp = AgentResponse( + reasoning="All steps complete. Here are the results.", + action="solution", + content="The top datasets are GSE123, GSE456, GSE789.", + ) + generate, _ = self._build_generate(resp) + + # With a prior (i.e. code has been executed), solution should end + state = { + "messages": [AIMessage(content="some output")], + "next_step": None, + } + result = generate(state) + + assert result["next_step"] == "end" + last_msg = result["messages"][-1].content + assert "" in last_msg + assert "" in last_msg + assert "GSE123" in last_msg + + def test_premature_solution_rejected_without_prior_execution(self): + """Agent cannot skip to 'solution' without executing code first.""" + from biomni.agent.a1 import AgentResponse + + resp = AgentResponse( + reasoning="I already know the answer.", + action="solution", + content="The answer is 42.", + ) + generate, _ = self._build_generate(resp) + + # No prior messages — no code has been executed + state = {"messages": [], "next_step": None} + result = generate(state) + + # Should be sent back to generate, not end + assert result["next_step"] == "generate" + # Should have a nudge message telling the agent to execute code + nudge_msg = result["messages"][-1].content + assert "have not executed any code" in nudge_msg + + def test_execute_content_extractable_by_downstream_regex(self): + """Verify the reconstructed message can be parsed by execute() regex.""" + import re + + from biomni.agent.a1 import AgentResponse + + code = "print('hello world')" + resp = AgentResponse(reasoning="Testing code.", action="execute", content=code) + generate, _ = self._build_generate(resp) + + state = {"messages": [], "next_step": None} + result = generate(state) + + last_msg = result["messages"][-1].content + match = re.search(r"(.*?)", last_msg, re.DOTALL) + assert match is not None, "execute() regex must find the code" + assert match.group(1) == code + + def test_markdown_code_fences_stripped_from_content(self): + """LLMs often wrap code in ```python fences even in structured output. These must be stripped.""" + import re + + from biomni.agent.a1 import AgentResponse + + # Simulate model returning markdown-fenced code + fenced_code = "```python\nprint('hello world')\n```" + resp = AgentResponse(reasoning="Running code.", action="execute", content=fenced_code) + generate, _ = self._build_generate(resp) + + state = {"messages": [], "next_step": None} + result = generate(state) + + last_msg = result["messages"][-1].content + match = re.search(r"(.*?)", last_msg, re.DOTALL) + assert match is not None + extracted = match.group(1) + assert "```" not in extracted, f"Markdown fences should be stripped, got: {extracted}" + assert "print('hello world')" in extracted + + def test_markdown_fences_not_stripped_from_solution(self): + """Markdown fences in solution content should be preserved (they may be intentional formatting).""" + from biomni.agent.a1 import AgentResponse + from langchain_core.messages import AIMessage + + content_with_fences = "Here is the code:\n```python\nx = 42\n```" + resp = AgentResponse(reasoning="Done.", action="solution", content=content_with_fences) + generate, _ = self._build_generate(resp) + + # Provide a prior observation so the premature-solution guard doesn't reject + state = { + "messages": [AIMessage(content="prior output")], + "next_step": None, + } + result = generate(state) + + last_msg = result["messages"][-1].content + assert "```python" in last_msg, "Solution content should preserve markdown fences" + + +# --------------------------------------------------------------------------- +# 6. generate() — XML fallback when structured output fails +# --------------------------------------------------------------------------- +class TestGenerateXMLFallback: + """Test that generate() falls back to XML parsing when structured output raises.""" + + def test_fallback_on_structured_output_error(self): + import re + + from biomni.agent.a1 import AgentResponse + from langchain_core.messages import AIMessage, SystemMessage + + mock_llm = MagicMock() + mock_llm.model_name = "gpt-4o" + # Structured output raises + mock_llm.with_structured_output.side_effect = Exception("unsupported") + # Raw invoke returns XML-tagged content + mock_raw_response = MagicMock() + mock_raw_response.content = "Thinking...\nprint('fallback')" + mock_llm.invoke.return_value = mock_raw_response + + class FakeSelf: + llm = mock_llm + system_prompt = "You are a test agent." + + @property + def _is_openai_model(self): + model_name = getattr(self.llm, "model_name", "") or getattr(self.llm, "model", "") + return str(model_name).lower().startswith("gpt-") or "openai" in str(type(self.llm)).lower() + + self_obj = FakeSelf() + + def generate(state): + messages = [SystemMessage(content=self_obj.system_prompt)] + state["messages"] + if self_obj._is_openai_model: + try: + structured_llm = self_obj.llm.with_structured_output(AgentResponse) + response = structured_llm.invoke(messages) + content = response.content + if response.action == "execute": + content = re.sub(r"^```(?:python|bash|r)?\s*\n?", "", content) + content = re.sub(r"\n?```\s*$", "", content) + tag = response.action + msg = f"{response.reasoning}\n<{tag}>{content}" + state["messages"].append(AIMessage(content=msg.strip())) + if response.action == "solution": + # Guard: reject premature solutions when no code has been executed + has_executed = any( + "" in (m.content if isinstance(m.content, str) else "") + for m in state["messages"] + ) + if not has_executed: + nudge = ( + "\nYou chose 'solution' but have not executed any code yet. " + "You MUST run code using the available tools before providing a final answer. " + "Choose action 'execute' and write code to accomplish the task." + ) + state["messages"].append(AIMessage(content=nudge.strip())) + state["next_step"] = "generate" + else: + state["next_step"] = "end" + else: + state["next_step"] = "execute" + return state + except Exception as e: + print(f"Structured output failed ({e}), falling back to XML parsing...") + response = self_obj.llm.invoke(messages) + content = response.content + msg = str(content) if not isinstance(content, list) else "" + else: + response = self_obj.llm.invoke(messages) + msg = str(response.content) + + if "" in msg and "" not in msg: + msg += "" + execute_match = re.search(r"(.*?)", msg, re.DOTALL | re.IGNORECASE) + answer_match = re.search(r"(.*?)", msg, re.DOTALL | re.IGNORECASE) + state["messages"].append(AIMessage(content=msg.strip())) + if answer_match: + state["next_step"] = "end" + elif execute_match: + state["next_step"] = "execute" + else: + state["next_step"] = "generate" + return state + + state = {"messages": [], "next_step": None} + result = generate(state) + + assert result["next_step"] == "execute" + assert "fallback" in result["messages"][-1].content + + +# --------------------------------------------------------------------------- +# 7. System prompt differs by provider +# --------------------------------------------------------------------------- +class TestSystemPromptBranching: + """Test that the system prompt format instructions differ for OpenAI vs Claude.""" + + def _make_stub(self, model_name): + from biomni.agent.a1 import A1 + + stub = object.__new__(A1) + stub.llm = MagicMock() + stub.llm.model_name = model_name + stub.data_lake_dict = {} + stub.library_content_dict = {} + stub.path = "/tmp/test_biomni/biomni_data" + stub.commercial_mode = False + return stub + + def test_openai_prompt_mentions_structured_fields(self): + """When _is_openai_model is True, prompt should reference reasoning/action/content fields.""" + stub = self._make_stub("gpt-5.2") + + prompt = stub._generate_system_prompt( + tool_desc={}, + data_lake_content=[], + library_content_list=[], + ) + assert '"reasoning"' in prompt, "OpenAI prompt should mention the reasoning field" + assert '"action"' in prompt, "OpenAI prompt should mention the action field" + assert '"content"' in prompt, "OpenAI prompt should mention the content field" + # Should NOT contain the XML format instructions + assert "Your code should be enclosed using" not in prompt + + def test_claude_prompt_mentions_xml_tags(self): + """When _is_openai_model is False, prompt should reference / tags.""" + stub = self._make_stub("claude-sonnet-4-5") + + prompt = stub._generate_system_prompt( + tool_desc={}, + data_lake_content=[], + library_content_list=[], + ) + assert "" in prompt, "Claude prompt should mention tag" + assert "" in prompt, "Claude prompt should mention tag" + # Should NOT contain structured output field names + assert '"reasoning"' not in prompt + + +# --------------------------------------------------------------------------- +# 8. Live API tests (require --live flag and API keys) +# --------------------------------------------------------------------------- +@pytest.mark.live +class TestLiveOpenAIStructuredOutput: + """Tests that call the real OpenAI API with structured output. + + Run with: pytest tests/test_openai_structured_output.py -v --live + Requires OPENAI_API_KEY in environment. + """ + + @pytest.fixture(autouse=True) + def check_openai_key(self): + if not os.environ.get("OPENAI_API_KEY"): + pytest.skip("OPENAI_API_KEY not set") + + def test_structured_output_execute(self): + """Live call: OpenAI model returns a valid AgentResponse with action=execute.""" + from biomni.agent.a1 import AgentResponse + from biomni.llm import get_llm + + llm = get_llm(model="gpt-4o", temperature=0.0) + structured_llm = llm.with_structured_output(AgentResponse) + + from langchain_core.messages import HumanMessage, SystemMessage + + messages = [ + SystemMessage(content="You are a coding assistant. Respond with action='execute' and provide Python code."), + HumanMessage(content="Write a one-line Python print statement that says hello world."), + ] + response = structured_llm.invoke(messages) + + assert isinstance(response, AgentResponse) + assert response.action == "execute" + assert "print" in response.content.lower() + assert len(response.reasoning) > 0 + print(f" reasoning: {response.reasoning[:80]}...") + print(f" action: {response.action}") + print(f" content: {response.content}") + + def test_structured_output_solution(self): + """Live call: OpenAI model returns a valid AgentResponse with action=solution.""" + from biomni.agent.a1 import AgentResponse + from biomni.llm import get_llm + + llm = get_llm(model="gpt-4o", temperature=0.0) + structured_llm = llm.with_structured_output(AgentResponse) + + from langchain_core.messages import HumanMessage, SystemMessage + + messages = [ + SystemMessage( + content=( + "You are a helpful assistant. The user's task is already complete. " + "Respond with action='solution' and provide a final answer." + ) + ), + HumanMessage(content="What is 2 + 2? Give the final answer directly."), + ] + response = structured_llm.invoke(messages) + + assert isinstance(response, AgentResponse) + assert response.action == "solution" + assert "4" in response.content + print(f" reasoning: {response.reasoning[:80]}...") + print(f" action: {response.action}") + print(f" content: {response.content}") + + def test_structured_output_roundtrip_with_xml_reconstruction(self): + """Live call: structured response reconstructs into XML that execute() regex can parse.""" + import re + + from biomni.agent.a1 import AgentResponse + from biomni.llm import get_llm + + llm = get_llm(model="gpt-4o", temperature=0.0) + structured_llm = llm.with_structured_output(AgentResponse) + + from langchain_core.messages import HumanMessage, SystemMessage + + messages = [ + SystemMessage(content="You are a coding assistant. Respond with action='execute'."), + HumanMessage(content="Write Python code: x = 1 + 1; print(x)"), + ] + response = structured_llm.invoke(messages) + + # Reconstruct the same way generate() does + tag = response.action + msg = f"{response.reasoning}\n<{tag}>{response.content}" + + # Verify the downstream execute() regex can extract the code + match = re.search(r"(.*?)", msg, re.DOTALL) + assert match is not None, "Reconstructed message must be parseable by execute()" + assert "print" in match.group(1) + print(f" Extracted code: {match.group(1)}") + + def test_gpt5_structured_output_if_available(self): + """Live call with gpt-5 (if available). Skips gracefully if model not accessible.""" + from biomni.agent.a1 import AgentResponse + from biomni.llm import get_llm + + try: + llm = get_llm(model="gpt-5", temperature=0.0) + except Exception as e: + pytest.skip(f"gpt-5 not available: {e}") + + structured_llm = llm.with_structured_output(AgentResponse) + + from langchain_core.messages import HumanMessage, SystemMessage + + messages = [ + SystemMessage(content="You are a coding assistant. Respond with action='execute'."), + HumanMessage(content="Write Python code that prints the numbers 1 through 5."), + ] + try: + response = structured_llm.invoke(messages) + except Exception as e: + pytest.skip(f"gpt-5 structured output call failed: {e}") + + assert isinstance(response, AgentResponse) + assert response.action == "execute" + print(f" gpt-5 reasoning: {response.reasoning[:80]}...") + print(f" gpt-5 content: {response.content[:120]}") + + +@pytest.mark.live +class TestLiveClaudeXML: + """Tests that call the real Anthropic API with XML tag format. + + Run with: pytest tests/test_openai_structured_output.py -v --live + Requires ANTHROPIC_API_KEY in environment. + """ + + @pytest.fixture(autouse=True) + def check_anthropic_key(self): + if not os.environ.get("ANTHROPIC_API_KEY"): + pytest.skip("ANTHROPIC_API_KEY not set") + + def test_claude_produces_execute_tag(self): + """Live call: Claude model produces tags without structured output.""" + from biomni.llm import get_llm + + llm = get_llm(model="claude-haiku-4-5", temperature=0.0) + + from langchain_core.messages import HumanMessage, SystemMessage + + messages = [ + SystemMessage( + content=( + "You are a coding assistant. Always respond with reasoning followed by code " + "inside tags. Example: I will run code.\nprint('hi')" + ) + ), + HumanMessage(content="Write a Python print statement that says hello world."), + ] + response = llm.invoke(messages) + content = str(response.content) + + assert "" in content, f"Claude should produce tag, got: {content[:200]}" + assert "" in content, f"Claude should close the tag, got: {content[:200]}" + print(f" Claude response: {content[:200]}...") + + def test_claude_produces_solution_tag(self): + """Live call: Claude model produces tags.""" + from biomni.llm import get_llm + + llm = get_llm(model="claude-haiku-4-5", temperature=0.0) + + from langchain_core.messages import HumanMessage, SystemMessage + + messages = [ + SystemMessage( + content=( + "You are a helpful assistant. Provide your final answer inside " + " tags. Example: The answer is 42" + ) + ), + HumanMessage(content="What is 2 + 2? Give the final answer."), + ] + response = llm.invoke(messages) + content = str(response.content) + + assert "" in content, f"Claude should produce tag, got: {content[:200]}" + assert "4" in content + print(f" Claude response: {content[:200]}...") + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])