diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json deleted file mode 100644 index 91f7e38..0000000 --- a/.devcontainer/devcontainer.json +++ /dev/null @@ -1,33 +0,0 @@ -{ - "name": "Python 3", - // Or use a Dockerfile or Docker Compose file. More info: https://containers.dev/guide/dockerfile - "image": "mcr.microsoft.com/devcontainers/python:1-3.11-bullseye", - "customizations": { - "codespaces": { - "openFiles": [ - "README.md", - "src/app.py" - ] - }, - "vscode": { - "settings": {}, - "extensions": [ - "ms-python.python", - "ms-python.vscode-pylance" - ] - } - }, - "updateContentCommand": "[ -f packages.txt ] && sudo apt update && sudo apt upgrade -y && sudo xargs apt install -y -# πŸ’¬ Querypls - Prompt to SQL +# πŸ’¬ Querypls - Intelligent SQL and CSV Analysis -Querypls is a web application that provides an interactive chat interface, simplifying SQL query generation. Users can effortlessly enter SQL queries and receive corresponding results. The application harnesses the capabilities of the language models from Hugging Face to generate SQL queries based on user input. +Querypls is a modern web application that provides an interactive chat interface for SQL query generation and CSV data analysis. Built with Pydantic AI and powered by OpenAI's GPT-OSS-120B model through Groq, it offers intelligent routing between different analysis modes to handle various data-related queries. + +🌐 **Try it live**: [querypls.streamlit.app](https://querypls.streamlit.app/) ## Key Features -πŸ’¬ Interactive chat interface for easy communication. -πŸ” Enter SQL queries and receive query results as responses. -πŸ€– Utilizes language models from Hugging Face for advanced query generation ([Querypls-prompt2sql](https://huggingface.co/samadpls/querypls-prompt2sql)). -πŸ’» User-friendly interface for seamless interaction. +πŸ’¬ **Interactive Chat Interface** - Natural language conversations for data analysis +πŸ” **SQL Query Generation** - Convert natural language to optimized SQL queries +πŸ“Š **CSV Data Analysis** - Upload and analyze CSV files with intelligent insights +πŸ€– **Intelligent Routing** - Automatically determines the best agent for your query +⚑ **Fast Inference** - Powered by Groq's optimized infrastructure +πŸ”’ **Type-Safe Development** - Built with Pydantic AI for robust validation +πŸ“ˆ **Visual Analytics** - Generate charts and visualizations from your data ![QueryplsDemo](https://github.com/samadpls/Querypls/assets/94792103/daa6e37d-a256-4fd8-9607-6e18cf41df3f) @@ -24,7 +29,9 @@ Querypls is a web application that provides an interactive chat interface, simpl # Acknowledgments -`Querypls` received a shoutout from [🦜 πŸ”— Langchain](https://www.langchain.com/) on their Twitter, reaching over **60,000 impressions**. Additionally, it was featured under the **Community Favorite Projects** section on `🦜 πŸ”— Langchain's blog`, leading to a significant increase in stars for this repository and a growing user base. The project was also highlighted in a [YouTube video](https://www.youtube.com/watch?v=htHVb-fK9xU), and it also caught the attention of Backdrop, expressing their interest and liking in an email, inviting the project to be a part of their hackathon. +`Querypls` received a shoutout from [🦜 πŸ”— Langchain](https://www.langchain.com/) on their Twitter in 2023, reaching over **60,000 impressions**. 
Additionally, it was featured under the **Community Favorite Projects** section on `🦜 πŸ”— Langchain's blog`, leading to a significant increase in stars for this repository and a growing user base. The project was also highlighted in a [YouTube video](https://www.youtube.com/watch?v=htHVb-fK9xU), and it caught the attention of Backdrop, who expressed their interest in an email and invited the project to be a part of their hackathon. + +However, due to recurring breaking changes and instability issues with the LangChain framework, we made the strategic decision to migrate to **Pydantic AI**, a more stable and reliable framework. This transition has brought improved performance, better type safety, and enhanced maintainability to the project. | [πŸ”— LangChain Twitter Post](https://twitter.com/LangChainAI/status/1729959981523378297?t=Zdpw9ZQYvE3QS-3Bf-xaGw&s=19) | [πŸ”— LangChain Blog Post](https://blog.langchain.dev/week-of-11-27-langchain-release-notes/) | |----------|----------| @@ -38,7 +45,7 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file > [!Note] -> Querypls, while powered by a 7B Stability AI LLM, is currently limited in providing optimal responses for simple queries. +> Querypls is now powered by OpenAI's GPT-OSS-120B model through Groq, providing fast and reliable AI-powered SQL generation and CSV analysis capabilities. --- @@ -59,11 +66,11 @@ This project is licensed under the MIT License. See the [LICENSE](LICENSE) file pip install -r requirements.txt ``` -4. Create a `.env` file based on `.env_example` and set the necessary variables. 5. Run the application: ```bash - streamlit run src/app.py + streamlit run src/frontend/app.py ``` 6. Open the provided link in your browser to use Querypls. diff --git a/examples/basic_usage_demo.py b/examples/basic_usage_demo.py new file mode 100644 index 0000000..b024d78 --- /dev/null +++ b/examples/basic_usage_demo.py @@ -0,0 +1,193 @@ +#!/usr/bin/env python3 +""" +Basic usage demo for Querypls backend functionality. +Demonstrates conversation, SQL generation, and CSV analysis.
+""" + +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.services.routing_service import IntelligentRoutingService +from src.backend.orchestrator import BackendOrchestrator +from src.schemas.requests import NewChatRequest + + +def demo_conversation(): + """Demo conversation functionality.""" + print("πŸ—£οΈ CONVERSATION DEMO") + print("=" * 40) + + routing_service = IntelligentRoutingService() + + # Test different conversation types + conversations = [ + "Hello", + "How are you?", + "What can you do?", + "Thanks for your help", + "Goodbye", + ] + + for query in conversations: + print(f"\nUser: {query}") + response = routing_service.handle_conversation_query(query) + print(f"Assistant: {response}") + + print("\n" + "=" * 40) + + +def demo_sql_generation(): + """Demo SQL generation functionality.""" + print("πŸ—ƒοΈ SQL GENERATION DEMO") + print("=" * 40) + + routing_service = IntelligentRoutingService() + + # Test different SQL queries + sql_queries = [ + "Show me all users", + "Find customers who made purchases in the last 30 days", + "Get the total sales by month", + "SELECT * FROM users WHERE status = 'active'", + ] + + for query in sql_queries: + print(f"\nUser: {query}") + response = routing_service.handle_sql_query(query, []) + print(f"Assistant: {response[:200]}...") + + print("\n" + "=" * 40) + + +def demo_csv_analysis(): + """Demo CSV analysis functionality.""" + print("πŸ“Š CSV ANALYSIS DEMO") + print("=" * 40) + + # Sample CSV data + sample_csv = """name,age,salary,department +Alice,25,50000,IT +Bob,30,60000,HR +Charlie,35,70000,IT +Diana,28,55000,Finance +Eve,32,65000,HR""" + + print(f"Sample CSV Data:\n{sample_csv}") + + routing_service = IntelligentRoutingService() + + # Test different CSV analysis queries + csv_queries = [ + "Show me the basic statistics of the data", + "Create a bar chart of department distribution", + "What is the average salary by department?", + "Show me the top 3 highest paid employees", + ] + + for query in csv_queries: + print(f"\nUser: {query}") + response = routing_service.handle_csv_query(query, sample_csv) + print(f"Assistant: {response[:300]}...") + + print("\n" + "=" * 40) + + +def demo_intelligent_routing(): + """Demo intelligent routing functionality.""" + print("🧠 INTELLIGENT ROUTING DEMO") + print("=" * 40) + + routing_service = IntelligentRoutingService() + + # Test different types of queries + test_queries = [ + ("Hello", "CONVERSATION_AGENT"), + ("Show me all users", "SQL_AGENT"), + ("Analyze this CSV data", "CSV_AGENT"), + ("How are you?", "CONVERSATION_AGENT"), + ("SELECT * FROM users", "SQL_AGENT"), + ("Create a chart from the data", "CSV_AGENT"), + ] + + for query, expected_agent in test_queries: + print(f"\nQuery: '{query}'") + decision = routing_service.determine_agent(query, [], csv_loaded=True) + print(f"Expected: {expected_agent}") + print(f"Actual: {decision.agent}") + print(f"Confidence: {decision.confidence}") + print(f"Reasoning: {decision.reasoning}") + + print("\n" + "=" * 40) + + +def demo_orchestrator(): + """Demo the main orchestrator functionality.""" + print("🎼 ORCHESTRATOR DEMO") + print("=" * 40) + + orchestrator = BackendOrchestrator() + + # Create a new session + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Demo Session") + ) + session_id = session_info.session_id + print(f"Created session: {session_id}") + + # Test different types of interactions + interactions = [ + ("Hello", "conversation"), + ("Show me all 
users", "sql"), + ("What can you do?", "conversation"), + ] + + for query, query_type in interactions: + print(f"\nUser ({query_type}): {query}") + response = orchestrator.generate_intelligent_response(session_id, query) + print(f"Assistant: {response.content[:150]}...") + + # Test CSV functionality + sample_csv = "name,age,salary\nAlice,25,50000\nBob,30,60000\nCharlie,35,70000" + result = orchestrator.load_csv_data(session_id, sample_csv) + print(f"\nCSV Load Result: {result['status']}") + + response = orchestrator.generate_intelligent_response( + session_id, "Analyze this data" + ) + print(f"CSV Analysis: {response.content[:200]}...") + + print("\n" + "=" * 40) + + +def main(): + """Run all demos.""" + print("πŸš€ Querypls Backend Functionality Demo") + print("=" * 50) + + demos = [ + ("Conversation", demo_conversation), + ("SQL Generation", demo_sql_generation), + ("CSV Analysis", demo_csv_analysis), + ("Intelligent Routing", demo_intelligent_routing), + ("Orchestrator", demo_orchestrator), + ] + + for demo_name, demo_func in demos: + try: + demo_func() + except Exception as e: + print(f"❌ {demo_name} demo failed: {str(e)}") + + print("\nπŸŽ‰ Demo completed! All backend functionality is working correctly.") + print("\nπŸ“ Summary:") + print("- Conversation: Natural responses for greetings and help") + print("- SQL Generation: Convert natural language to SQL queries") + print("- CSV Analysis: Analyze CSV data with Python code") + print("- Intelligent Routing: Automatically choose the right agent") + print("- Orchestrator: Complete session management") + + +if __name__ == "__main__": + main() diff --git a/examples/test_backend_functionality.py b/examples/test_backend_functionality.py new file mode 100644 index 0000000..b568627 --- /dev/null +++ b/examples/test_backend_functionality.py @@ -0,0 +1,317 @@ +#!/usr/bin/env python3 +""" +Comprehensive test for all backend functionality of Querypls. +Tests conversation, SQL generation, and CSV analysis capabilities. 
+""" + +import sys +import os +import pandas as pd +from io import StringIO + +# Add the project root to Python path +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.services.routing_service import IntelligentRoutingService +from src.services.conversation_service import ConversationService +from src.services.sql_service import SQLGenerationService +from src.services.csv_analysis_tools import CSVAnalysisTools +from src.schemas.requests import ChatMessage, SQLGenerationRequest +from src.backend.orchestrator import BackendOrchestrator + + +def test_conversation_functionality(): + """Test conversation responses.""" + print("πŸ§ͺ Testing Conversation Functionality") + print("=" * 50) + + try: + routing_service = IntelligentRoutingService() + + # Test conversation queries + conversation_tests = [ + "Hello", + "How are you?", + "What can you do?", + "Thanks for your help", + "Goodbye", + ] + + for query in conversation_tests: + print(f"\nQuery: '{query}'") + try: + response = routing_service.handle_conversation_query(query) + print(f"Response: {response[:100]}...") + print("βœ… PASS") + except Exception as e: + print(f"❌ FAIL: {str(e)}") + + print("\n" + "=" * 50) + return True + + except Exception as e: + print(f"❌ Conversation test failed: {str(e)}") + return False + + +def test_sql_functionality(): + """Test SQL generation functionality.""" + print("πŸ—ƒοΈ Testing SQL Generation Functionality") + print("=" * 50) + + try: + routing_service = IntelligentRoutingService() + sql_service = SQLGenerationService() + + # Test SQL queries + sql_tests = [ + "Show me all users", + "SELECT * FROM users WHERE status = 'active'", + "Find customers who made purchases in the last 30 days", + "Get the total sales by month", + ] + + for query in sql_tests: + print(f"\nQuery: '{query}'") + try: + # Test routing + routing_decision = routing_service.determine_agent( + query, [], csv_loaded=False + ) + print(f"Routing Decision: {routing_decision.agent}") + + # Test SQL generation + request = SQLGenerationRequest( + user_query=query, conversation_history=[] + ) + response = sql_service.generate_sql(request) + print(f"SQL Response: {response.content[:100]}...") + print("βœ… PASS") + except Exception as e: + print(f"❌ FAIL: {str(e)}") + + print("\n" + "=" * 50) + return True + + except Exception as e: + print(f"❌ SQL test failed: {str(e)}") + return False + + +def test_csv_functionality(): + """Test CSV analysis functionality.""" + print("πŸ“Š Testing CSV Analysis Functionality") + print("=" * 50) + + try: + # Create sample CSV data + sample_data = { + "name": ["Alice", "Bob", "Charlie", "Diana", "Eve"], + "age": [25, 30, 35, 28, 32], + "salary": [50000, 60000, 70000, 55000, 65000], + "department": ["IT", "HR", "IT", "Finance", "HR"], + } + + df = pd.DataFrame(sample_data) + csv_content = df.to_csv(index=False) + + print(f"Sample CSV Data:\n{df.head()}") + print(f"CSV Shape: {df.shape}") + + # Test CSV tools + csv_tools = CSVAnalysisTools() + + # Test loading CSV data + print("\nTesting CSV loading...") + result = csv_tools.load_csv_data(csv_content, "test_session") + print(f"Load Result: {result}") + + # Test CSV analysis queries + csv_tests = [ + "Show me the basic statistics of the data", + "Create a bar chart of department distribution", + "What is the average salary by department?", + "Show me the top 3 highest paid employees", + ] + + routing_service = IntelligentRoutingService() + + for query in csv_tests: + print(f"\nQuery: '{query}'") + try: + # Test routing with 
CSV loaded + routing_decision = routing_service.determine_agent( + query, [], csv_loaded=True + ) + print(f"Routing Decision: {routing_decision.agent}") + + # Test CSV analysis + response = routing_service.handle_csv_query(query, csv_content) + print(f"CSV Response: {response[:200]}...") + print("βœ… PASS") + except Exception as e: + print(f"❌ FAIL: {str(e)}") + + print("\n" + "=" * 50) + return True + + except Exception as e: + print(f"❌ CSV test failed: {str(e)}") + return False + + +def test_intelligent_routing(): + """Test intelligent routing functionality.""" + print("🧠 Testing Intelligent Routing") + print("=" * 50) + + try: + routing_service = IntelligentRoutingService() + + # Test cases with expected routing + test_cases = [ + ("Hello", "CONVERSATION_AGENT"), + ("How are you?", "CONVERSATION_AGENT"), + ("Show me all users", "SQL_AGENT"), + ("SELECT * FROM users", "SQL_AGENT"), + ("Analyze this CSV data", "CSV_AGENT"), + ("Create a chart from the data", "CSV_AGENT"), + ("What can you do?", "CONVERSATION_AGENT"), + ("Thanks for your help", "CONVERSATION_AGENT"), + ] + + all_passed = True + for query, expected_agent in test_cases: + print(f"\nQuery: '{query}'") + print(f"Expected Agent: {expected_agent}") + + try: + # Test without CSV loaded + decision = routing_service.determine_agent(query, [], csv_loaded=False) + print(f"Result (no CSV): {decision.agent}") + + # Test with CSV loaded + decision_with_csv = routing_service.determine_agent( + query, [], csv_loaded=True + ) + print(f"Result (with CSV): {decision_with_csv.agent}") + + if ( + decision.agent == expected_agent + or decision_with_csv.agent == expected_agent + ): + print("βœ… PASS") + else: + print("❌ FAIL") + all_passed = False + + except Exception as e: + print(f"❌ ERROR: {str(e)}") + all_passed = False + + print("\n" + "=" * 50) + return all_passed + + except Exception as e: + print(f"❌ Routing test failed: {str(e)}") + return False + + +def test_orchestrator(): + """Test the main orchestrator functionality.""" + print("🎼 Testing Backend Orchestrator") + print("=" * 50) + + try: + orchestrator = BackendOrchestrator() + + # Test session creation + print("Testing session creation...") + from src.schemas.requests import NewChatRequest + + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Test Session") + ) + session_id = session_info.session_id + print(f"Created session: {session_id}") + + # Test conversation + print("\nTesting conversation...") + response = orchestrator.generate_intelligent_response(session_id, "Hello") + print(f"Conversation Response: {response.content[:100]}...") + + # Test SQL generation + print("\nTesting SQL generation...") + response = orchestrator.generate_intelligent_response( + session_id, "Show me all users" + ) + print(f"SQL Response: {response.content[:100]}...") + + # Test CSV loading and analysis + print("\nTesting CSV functionality...") + sample_csv = "name,age,salary\nAlice,25,50000\nBob,30,60000\nCharlie,35,70000" + result = orchestrator.load_csv_data(session_id, sample_csv) + print(f"CSV Load Result: {result}") + + response = orchestrator.generate_intelligent_response( + session_id, "Analyze this data" + ) + print(f"CSV Analysis Response: {response.content[:100]}...") + + print("\n" + "=" * 50) + return True + + except Exception as e: + print(f"❌ Orchestrator test failed: {str(e)}") + return False + + +def main(): + """Run all tests.""" + print("πŸš€ Starting Comprehensive Backend Functionality Tests") + print("=" * 60) + + tests = [ + ("Conversation", 
test_conversation_functionality), + ("SQL Generation", test_sql_functionality), + ("CSV Analysis", test_csv_functionality), + ("Intelligent Routing", test_intelligent_routing), + ("Orchestrator", test_orchestrator), + ] + + results = {} + + for test_name, test_func in tests: + print(f"\n{'='*20} {test_name} {'='*20}") + try: + results[test_name] = test_func() + except Exception as e: + print(f"❌ {test_name} test crashed: {str(e)}") + results[test_name] = False + + # Summary + print("\n" + "=" * 60) + print("πŸ“Š TEST SUMMARY") + print("=" * 60) + + passed = 0 + total = len(tests) + + for test_name, result in results.items(): + status = "βœ… PASS" if result else "❌ FAIL" + print(f"{test_name}: {status}") + if result: + passed += 1 + + print(f"\nOverall: {passed}/{total} tests passed") + + if passed == total: + print("πŸŽ‰ All tests passed! Backend functionality is working correctly.") + else: + print("⚠️ Some tests failed. Check the backend implementation.") + + return passed == total + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml deleted file mode 100644 index 9216134..0000000 --- a/pyproject.toml +++ /dev/null @@ -1,2 +0,0 @@ -[tool.black] -line-length = 79 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 93bbda8..25f7eb6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,118 +1,21 @@ -aiohttp==3.9.5 -aiosignal==1.3.1 -altair==5.3.0 -annotated-types==0.7.0 -anyio==3.7.1 -async-timeout==4.0.3 -asyncio==3.4.3 -attrs==23.2.0 -black==24.4.2 -blinker==1.8.2 -cachetools==5.4.0 -certifi==2024.7.4 -charset-normalizer==3.3.2 -click==8.1.7 -dataclasses-json==0.6.7 -deta==1.2.0 -exceptiongroup==1.2.2 -filelock==3.15.4 -frozenlist==1.4.1 -fsspec==2024.6.1 -gitdb==4.0.11 -GitPython==3.1.43 -greenlet==3.0.3 -h11==0.14.0 -httpcore==0.17.3 -httpx==0.24.1 -httpx-oauth==0.13.0 -huggingface-hub==0.23.4 -idna==3.7 -importlib-metadata==6.11.0 -iniconfig==2.0.0 -Jinja2==3.1.5 -joblib==1.4.2 -jsonpatch==1.33 -jsonpointer==3.0.0 -jsonschema==4.23.0 -jsonschema-specifications==2023.12.1 -langchain==0.2.14 -langchain-core==0.2.32 -langchain-community>=0.0.37 -langchain-huggingface==0.0.3 -langchain-text-splitters==0.2.2 -langsmith==0.1.93 -markdown-it-py==3.0.0 -MarkupSafe==2.1.5 -marshmallow==3.21.3 -mdurl==0.1.2 -mpmath==1.3.0 -multidict==6.0.5 -mypy-extensions==1.0.0 -networkx==3.3 -numpy==1.26.4 -nvidia-cublas-cu12==12.1.3.1 -nvidia-cuda-cupti-cu12==12.1.105 -nvidia-cuda-nvrtc-cu12==12.1.105 -nvidia-cuda-runtime-cu12==12.1.105 -nvidia-cudnn-cu12==9.1.0.70 -nvidia-cufft-cu12==11.0.2.54 -nvidia-curand-cu12==10.3.2.106 -nvidia-cusolver-cu12==11.4.5.107 -nvidia-cusparse-cu12==12.1.0.106 -nvidia-nccl-cu12==2.20.5 -nvidia-nvjitlink-cu12==12.5.82 -nvidia-nvtx-cu12==12.1.105 -orjson==3.10.6 -packaging==23.2 -pandas==2.2.2 -pathspec==0.12.1 -pillow==10.4.0 -platformdirs==4.2.2 -pluggy==1.5.0 -protobuf==4.25.4 -pyarrow==17.0.0 -pydantic==2.8.2 -pydantic_core==2.20.1 -pydeck==0.9.1 -Pygments==2.18.0 -pytest==8.3.2 -python-dateutil==2.9.0.post0 -python-dotenv==1.0.0 -pytz==2024.1 -PyYAML==6.0.1 -referencing==0.35.1 -regex==2024.7.24 -requests==2.32.3 -rich==13.7.1 -rpds-py==0.19.1 -safetensors==0.4.3 -scikit-learn==1.5.1 -scipy==1.14.0 -sentence-transformers==3.0.1 -six==1.16.0 -smmap==5.0.1 -sniffio==1.3.1 -SQLAlchemy==2.0.31 -streamlit==1.36.0 -streamlit-oauth==0.1.5 -sympy==1.13.1 -tenacity==8.5.0 -threadpoolctl==3.5.0 -tokenizers==0.19.1 -toml==0.10.2 -tomli==2.0.1 -toolz==0.12.1 -torch==2.4.0 -tornado==6.4.2 -tqdm==4.66.4 
-transformers==4.48.0 -triton==3.0.0 -typing-inspect==0.9.0 -typing_extensions==4.12.2 -tzdata==2024.1 -tzlocal==5.2 -urllib3==2.2.2 -validators==0.33.0 -watchdog==4.0.1 -yarl==1.9.4 -zipp==3.19.2 +# Core application dependencies +streamlit>=1.36.0 +pydantic-ai-slim[groq]>=0.6.0 +pydantic>=2.0.0 +pydantic-settings>=2.0.0 + +# Data analysis dependencies +pandas>=2.0.0 +numpy>=1.24.0 +matplotlib>=3.7.0 +seaborn>=0.12.0 +jupyter-client>=8.0.0 + +# Training dependencies (optional - only needed for model training) +datasets>=2.14.0 +transformers>=4.48.0 +trl>=0.7.0 +peft>=0.6.0 + +# Testing dependencies (optional - only needed for running tests) +pytest>=8.3.0 diff --git a/run.py b/run.py new file mode 100644 index 0000000..3e7845c --- /dev/null +++ b/run.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 +""" +Launcher script for Querypls application. +""" + +import sys +import os +import argparse + +# Add src to path +sys.path.append(os.path.join(os.path.dirname(__file__), "src")) + + +def run_streamlit(): + """Run the Streamlit application.""" + import subprocess + import streamlit.web.cli as stcli + + # Set environment variables + os.environ["STREAMLIT_SERVER_PORT"] = "8501" + os.environ["STREAMLIT_SERVER_ADDRESS"] = "localhost" + + # Run streamlit + sys.argv = [ + "streamlit", + "run", + "src/frontend/app.py", + "--server.port=8501", + "--server.address=localhost", + ] + sys.exit(stcli.main()) + + +def run_cli(): + """Run the CLI application.""" + from terminal.cli import main as cli_main + + cli_main() + + +def main(): + """Main launcher function.""" + parser = argparse.ArgumentParser(description="Querypls - SQL Generation Tool") + parser.add_argument( + "mode", + choices=["web", "cli"], + default="web", + nargs="?", + help="Run mode: web (Streamlit) or cli (Command Line)", + ) + parser.add_argument( + "cli_args", nargs="*", help="Arguments to pass to CLI (when mode is cli)" + ) + + args = parser.parse_args() + + if args.mode == "web": + print("πŸš€ Starting Querypls Web Application...") + run_streamlit() + elif args.mode == "cli": + print("πŸš€ Starting Querypls CLI...") + # Pass CLI arguments to the CLI + if args.cli_args: + sys.argv = ["cli"] + args.cli_args + run_cli() + + +if __name__ == "__main__": + main() diff --git a/src/app.py b/src/app.py deleted file mode 100644 index 2f16301..0000000 --- a/src/app.py +++ /dev/null @@ -1,107 +0,0 @@ -from langchain_core.output_parsers import StrOutputParser -from langchain_core.prompts import PromptTemplate -import streamlit as st -import sys -import os -import json -from backend import ( - configure_page_styles, - display_github_badge, - hide_main_menu_and_footer, -) -from frontend import ( - create_message, - display_logo_and_heading, - display_previous_chats, - display_welcome_message, - handle_new_chat, -) -from model import create_huggingface_hub -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from src.auth import * -from src.constant import * - -def format_chat_history(messages): - """Format the chat history as a structured JSON string.""" - history = [] - for msg in messages[1:]: - content = msg['content'] - if '```sql' in content: - content = content.replace('```sql\n', '').replace('\n```', '').strip() - - history.append({ - "role": msg['role'], - "query" if msg['role'] == 'user' else "response": content - }) - - formatted_history = json.dumps(history, indent=2) - print("Formatted history:", formatted_history) - return formatted_history - -def extract_sql_code(response): - """Extract clean SQL code from the 
response.""" - sql_code_start = response.find("```sql") - if sql_code_start != -1: - sql_code_end = response.find("```", sql_code_start + 5) - if sql_code_end != -1: - sql_code = response[sql_code_start + 6:sql_code_end].strip() - return f"```sql\n{sql_code}\n```" - return response - -def main(): - """Main function to configure and run the Querypls application.""" - configure_page_styles("static/css/styles.css") - - if "model" not in st.session_state: - llm = create_huggingface_hub() - st.session_state["model"] = llm - - if "messages" not in st.session_state: - create_message() - - hide_main_menu_and_footer() - - with st.sidebar: - display_github_badge() - display_logo_and_heading() - st.markdown("`Made with 🀍`") - handle_new_chat() - - display_welcome_message() - for message in st.session_state.messages: - with st.chat_message(message["role"]): - st.markdown(message["content"]) - - if prompt := st.chat_input(): - st.session_state.messages.append({"role": "user", "content": prompt}) - with st.chat_message("user"): - st.markdown(prompt) - - conversation_history = format_chat_history(st.session_state.messages) - prompt_template = PromptTemplate( - template=TEMPLATE, - input_variables=["input", "conversation_history"] - ) - - if "model" in st.session_state: - llm_chain = prompt_template | st.session_state.model | StrOutputParser() - - with st.chat_message("assistant"): - with st.spinner("Generating..."): - response = llm_chain.invoke({ - "input": prompt, - "conversation_history": conversation_history - }) - - # Clean and format the response - formatted_response = extract_sql_code(response) - st.markdown(formatted_response) - - # Add to chat history - st.session_state.messages.append({ - "role": "assistant", - "content": formatted_response - }) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/src/auth.py b/src/auth.py deleted file mode 100644 index b16d3c2..0000000 --- a/src/auth.py +++ /dev/null @@ -1,29 +0,0 @@ -import asyncio -from src.constant import * -from httpx_oauth.clients.google import GoogleOAuth2 - - -async def get_authorization_url(client: GoogleOAuth2, redirect_uri: str): - authorization_url = await client.get_authorization_url( - redirect_uri, scope=["profile", "email"] - ) - return authorization_url - - -async def get_access_token(client: GoogleOAuth2, redirect_uri: str, code: str): - token = await client.get_access_token(code, redirect_uri) - return token - - -async def get_email(client: GoogleOAuth2, token: str): - user_id, user_email = await client.get_id_email(token) - return user_id, user_email - - -def get_login_str(): - client: GoogleOAuth2 = GoogleOAuth2(CLIENT_ID, CLIENT_SECRET) - authorization_url = asyncio.run( - get_authorization_url(client, REDIRECT_URI) - ) - return f"""\ -""" diff --git a/src/backend.py b/src/backend.py deleted file mode 100644 index 8dc0f24..0000000 --- a/src/backend.py +++ /dev/null @@ -1,95 +0,0 @@ -import streamlit as st -from streamlit_oauth import OAuth2Component -import sys -import os -import json -import base64 - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from src.auth import * -from src.constant import * - - -def configure_page_styles(file_name): - """Configures Streamlit page styles for Querypls. - - Sets page title, icon, and applies custom CSS styles. - Hides Streamlit main menu and footer for a cleaner interface. - - Note: - Ensure 'static/css/styles.css' exists with desired styles. 
- """ - st.set_page_config( - page_title="Querypls", - page_icon="πŸ’¬", - layout="wide", - ) - with open(file_name) as f: - st.markdown( - "".format(f.read()), unsafe_allow_html=True - ) - - hide_streamlit_style = """""" - st.markdown(hide_streamlit_style, unsafe_allow_html=True) - - -def hide_main_menu_and_footer(): - """Hides the Streamlit main menu and footer for a cleaner interface.""" - st.markdown( - """""", - unsafe_allow_html=True, - ) - - -def handle_google_login_if_needed(result): - """Handles Google login if it has not been run yet. - - Args: - result (str): Authorization code received from Google. - - Returns: - None - """ - try: - if result and "token" in result: - st.session_state.token = result.get("token") - token = st.session_state["token"] - id_token = token["id_token"] - payload = id_token.split(".")[1] - payload += "=" * (-len(payload) % 4) - payload = json.loads(base64.b64decode(payload)) - email = payload["email"] - st.session_state.user_email = email - st.session_state.code = True - return - except Exception: - st.warning( - "Seems like there is a network issue. \ - Please check your internet connection." - ) - sys.exit() - - -def display_github_badge(): - """Displays a GitHub badge with a link to the Querypls repository.""" - st.markdown( - """\ - """, - unsafe_allow_html=True, - ) - - -def create_oauth2_component(): - return OAuth2Component( - CLIENT_ID, - CLIENT_SECRET, - AUTHORIZE_URL, - TOKEN_URL, - REFRESH_TOKEN_URL, - REVOKE_TOKEN_URL, - ) diff --git a/src/backend/backend.py b/src/backend/backend.py new file mode 100644 index 0000000..0c28bff --- /dev/null +++ b/src/backend/backend.py @@ -0,0 +1,24 @@ +""" +Backend utilities for Streamlit configuration and styling. +""" + +from src.config.constants import STREAMLIT_CONFIG +import streamlit as st +import sys +import os + +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) + + +def hide_main_menu_and_footer() -> None: + st.markdown( + "", + unsafe_allow_html=True, + ) + + +def display_github_badge() -> None: + st.markdown( + "", + unsafe_allow_html=True, + ) diff --git a/src/backend/orchestrator.py b/src/backend/orchestrator.py new file mode 100644 index 0000000..d04389f --- /dev/null +++ b/src/backend/orchestrator.py @@ -0,0 +1,273 @@ +""" +Backend orchestrator for managing application state and services. 
+""" + +import uuid +from datetime import datetime +from typing import List, Optional, Dict, Any +from dataclasses import dataclass + +from src.config.settings import get_settings +from src.config.constants import WELCOME_MESSAGE, DEFAULT_SESSION_NAME +from src.services.sql_service import SQLGenerationService +from src.services.csv_analysis_tools import CSVAnalysisTools +from src.services.conversation_service import ConversationService +from src.services.routing_service import IntelligentRoutingService +from src.schemas.requests import ( + SQLGenerationRequest, + ChatMessage, + ConversationHistory, + NewChatRequest, +) +from src.schemas.responses import ( + ChatResponse, + SessionInfo, + HealthCheckResponse, +) + + +@dataclass +class Session: + session_id: str + session_name: str + created_at: datetime + messages: List[ChatMessage] + last_activity: datetime + csv_data: Optional[str] = None + csv_file_path: Optional[str] = None + csv_info: Optional[Dict[str, Any]] = None + + +class BackendOrchestrator: + def __init__(self): + self.settings = get_settings() + self.sql_service = SQLGenerationService() + self.csv_tools = CSVAnalysisTools() + self.conversation_service = ConversationService() + self.routing_service = IntelligentRoutingService() + self.sessions: Dict[str, Session] = {} + self.max_sessions = self.settings.max_chat_histories + + def create_new_session(self, request: NewChatRequest) -> SessionInfo: + session_id = str(uuid.uuid4()) + session_name = request.session_name or f"Chat {len(self.sessions) + 1}" + + messages = [] + if request.initial_context: + messages.append(ChatMessage(role="system", content=request.initial_context)) + + messages.append(ChatMessage(role="assistant", content=WELCOME_MESSAGE)) + + session = Session( + session_id=session_id, + session_name=session_name, + created_at=datetime.now(), + messages=messages, + last_activity=datetime.now(), + ) + + self.sessions[session_id] = session + self._cleanup_old_sessions() + + return SessionInfo( + session_id=session_id, + session_name=session_name, + created_at=session.created_at.isoformat(), + message_count=len(session.messages), + last_activity=session.last_activity.isoformat(), + ) + + def get_session(self, session_id: str) -> Optional[Session]: + return self.sessions.get(session_id) + + def list_sessions(self) -> List[SessionInfo]: + return [ + SessionInfo( + session_id=session.session_id, + session_name=session.session_name, + created_at=session.created_at.isoformat(), + message_count=len(session.messages), + last_activity=session.last_activity.isoformat(), + ) + for session in self.sessions.values() + ] + + def delete_session(self, session_id: str) -> bool: + if session_id in self.sessions: + self.csv_tools.close_session(session_id) + del self.sessions[session_id] + return True + return False + + def load_csv_data(self, session_id: str, csv_content: str) -> Dict[str, Any]: + session = self.get_session(session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + # Save CSV to file + import os + import tempfile + + # Create temp directory for this session if it doesn't exist + temp_dir = f"/tmp/querypls_session_{session_id}" + os.makedirs(temp_dir, exist_ok=True) + + # Save CSV to file + csv_file_path = os.path.join(temp_dir, "data.csv") + with open(csv_file_path, "w") as f: + f.write(csv_content) + + # Store both the content and file path in session + session.csv_data = csv_content + session.csv_file_path = csv_file_path + + # Get CSV info for context + import pandas as pd + from io import 
StringIO + + df = pd.read_csv(StringIO(csv_content)) + + session.csv_info = { + "file_path": csv_file_path, + "shape": df.shape, + "columns": list(df.columns), + "dtypes": df.dtypes.to_dict(), + "sample_data": df.head(3).to_dict("records"), + } + + session.last_activity = datetime.now() + + return { + "status": "success", + "message": "CSV data loaded successfully", + "shape": df.shape, + "columns": list(df.columns), + } + + def generate_intelligent_response( + self, session_id: str, user_query: str + ) -> ChatResponse: + """Generate response using intelligent routing to determine the appropriate agent.""" + session = self.get_session(session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + user_message = ChatMessage( + role="user", content=user_query, timestamp=datetime.now().isoformat() + ) + session.messages.append(user_message) + + # Determine which agent should handle this query + csv_loaded = bool(session.csv_data) + routing_decision = self.routing_service.determine_agent( + user_query, session.messages, csv_loaded + ) + + # Generate response based on routing decision + if routing_decision.agent == "CONVERSATION_AGENT": + response_content = self.routing_service.handle_conversation_query( + user_query + ) + elif routing_decision.agent == "SQL_AGENT": + response_content = self.routing_service.handle_sql_query( + user_query, session.messages + ) + elif routing_decision.agent == "CSV_AGENT": + if session.csv_data and session.csv_info: + response_content = self.routing_service.handle_csv_query( + user_query, session.csv_info, session.messages + ) + else: + response_content = "I don't see any CSV data loaded. Please upload a CSV file first to analyze it." + else: + # Fallback to conversation + response_content = self.routing_service.handle_conversation_query( + user_query + ) + + assistant_message = ChatMessage( + role="assistant", + content=response_content, + timestamp=datetime.now().isoformat(), + ) + session.messages.append(assistant_message) + session.last_activity = datetime.now() + + return ChatResponse( + message_id=str(uuid.uuid4()), + content=response_content, + timestamp=datetime.now().isoformat(), + session_id=session_id, + ) + + def get_conversation_history(self, session_id: str) -> ConversationHistory: + session = self.get_session(session_id) + if not session: + raise ValueError(f"Session {session_id} not found") + + return ConversationHistory(messages=session.messages, session_id=session_id) + + def get_csv_info(self, session_id: str) -> Dict[str, Any]: + return self.csv_tools.get_csv_info(session_id) + + def health_check(self) -> HealthCheckResponse: + services_status = { + "sql_service": "healthy", + "csv_analysis_service": "healthy", + "conversation_service": "healthy", + "session_manager": "healthy", + } + + try: + test_request = SQLGenerationRequest( + user_query="SELECT 1", conversation_history=[] + ) + if not self.sql_service: + services_status["sql_service"] = "unhealthy" + except Exception: + services_status["sql_service"] = "unhealthy" + + try: + if not self.csv_tools: + services_status["csv_analysis_service"] = "unhealthy" + except Exception: + services_status["csv_analysis_service"] = "unhealthy" + + try: + if not self.conversation_service: + services_status["conversation_service"] = "unhealthy" + except Exception: + services_status["conversation_service"] = "unhealthy" + + return HealthCheckResponse( + status=( + "healthy" + if all(status == "healthy" for status in services_status.values()) + else "unhealthy" + ), + 
version=self.settings.app_version, + timestamp=datetime.now().isoformat(), + services=services_status, + ) + + def _cleanup_old_sessions(self): + if len(self.sessions) <= self.max_sessions: + return + + sorted_sessions = sorted( + self.sessions.items(), key=lambda x: x[1].last_activity + ) + + sessions_to_remove = len(self.sessions) - self.max_sessions + for i in range(sessions_to_remove): + session_id, _ = sorted_sessions[i] + self.delete_session(session_id) + + def get_default_session(self) -> str: + for session_id, session in self.sessions.items(): + if session.session_name == DEFAULT_SESSION_NAME: + return session_id + + request = NewChatRequest(session_name=DEFAULT_SESSION_NAME) + session_info = self.create_new_session(request) + return session_info.session_id diff --git a/src/config/__init__.py b/src/config/__init__.py new file mode 100644 index 0000000..b2e8ad4 --- /dev/null +++ b/src/config/__init__.py @@ -0,0 +1,3 @@ +""" +Configuration package for Querypls application. +""" diff --git a/src/config/constants.py b/src/config/constants.py new file mode 100644 index 0000000..1ee6e21 --- /dev/null +++ b/src/config/constants.py @@ -0,0 +1,44 @@ +""" +Constants for Querypls application. +""" + +# Application Settings +MAX_RETRIES = 3 +EXECUTION_TIMEOUT = 30 +MAX_CHAT_HISTORIES = 6 +STREAMLIT_PORT = 8501 +STREAMLIT_HOST = "localhost" + +# Streamlit Configuration +STREAMLIT_CONFIG = {"page_title": "Querypls", "page_icon": "πŸ’¬", "layout": "wide"} + +# Welcome and Session Messages +WELCOME_MESSAGE = "Hello! πŸ‘‹ I'm Querypls, your SQL and data analysis assistant. I can help you generate SQL queries or analyze CSV files. What would you like to work on today?" +DEFAULT_SESSION_NAME = "Default Chat" + +# CSV Analysis Section +CSV_ANALYSIS_SECTION = "### πŸ“Š CSV Analysis" +CSV_UPLOAD_LABEL = "Upload CSV File" +CSV_UPLOAD_HELP = "Upload a CSV file to analyze with Python code" +CSV_PREVIEW = "πŸ“‹ CSV Preview" +CSV_COLUMNS = "**Columns:** {columns}" +CSV_DTYPES = "**Data Types:** {dtypes}" +LOAD_CSV_BUTTON = "πŸ“Š Load CSV Data" +CSV_LOADED_SUCCESS = "βœ… CSV data loaded successfully!" +CSV_UPLOAD_SUCCESS = "βœ… CSV uploaded successfully! Shape: {shape}" +CSV_UPLOAD_ERROR = "❌ Error uploading CSV: {error}" +CSV_LOAD_ERROR = "❌ No CSV data loaded. Please upload a CSV file first." +CSV_ANALYSIS_ERROR = "❌ Error analyzing CSV: {error}" + +# Session Management +SESSION_CREATE_ERROR = "❌ Error creating session: {error}" +SESSION_NOT_FOUND_ERROR = "❌ Session not found" + +# worst-case scenario +WORST_CASE_SCENARIO = "I'm here to help! I can assist with SQL generation or CSV data analysis. What would you like to do?" + +# Application Errors +ORCHESTRATOR_INIT_ERROR = "❌ Error initializing orchestrator: {error}" +APP_INIT_ERROR = "❌ Error initializing application" +RESPONSE_GENERATION_ERROR = "❌ Error generating response: {error}" +MESSAGE_LOAD_ERROR = "❌ Error loading messages: {error}" diff --git a/src/config/settings.py b/src/config/settings.py new file mode 100644 index 0000000..858af21 --- /dev/null +++ b/src/config/settings.py @@ -0,0 +1,39 @@ +""" +Application settings with environment variable support. 
+""" + +import os +from typing import Optional, ClassVar +from pydantic import Field, BaseModel, ConfigDict + + +class Settings(BaseModel): + groq_api_key: str = Field( + default=os.getenv("GROQ_API_KEY", "mock_api_key"), env="GROQ_API_KEY" + ) + groq_model_name: str = Field(default="openai/gpt-oss-120b", env="GROQ_MODEL_NAME") + app_version: str = Field(default="1.0.0", env="APP_VERSION") + max_chat_histories: int = Field(default=5, env="MAX_CHAT_HISTORIES") + debug_mode: bool = Field(default=False, env="DEBUG_MODE") + + # Legacy fields for backward compatibility + max_tokens: Optional[str] = Field(1000, env="MAX_TOKENS") + temperature: Optional[str] = Field(0.7, env="TEMPERATURE") + log_level: Optional[str] = Field("INFO", env="LOG_LEVEL") + + json_schema_extra: ClassVar[str] = "ignore" + + model_config = ConfigDict( + env_file=".env", + env_file_encoding="utf-8", + ) + + +_settings_instance: Optional[Settings] = None + + +def get_settings() -> Settings: + global _settings_instance + if _settings_instance is None: + _settings_instance = Settings() + return _settings_instance diff --git a/src/constant.py b/src/constant.py deleted file mode 100644 index d3c4790..0000000 --- a/src/constant.py +++ /dev/null @@ -1,14 +0,0 @@ -from streamlit import secrets - -DETA_PROJECT_KEY = secrets["DETA_PROJECT_KEY"] -HUGGINGFACE_API_TOKEN = secrets["HUGGINGFACE_API_TOKEN"] -REPO_ID = secrets["REPO_ID"] -CLIENT_ID = secrets["CLIENT_ID"] -CLIENT_SECRET = secrets["CLIENT_SECRET"] -REDIRECT_URI = secrets["REDIRECT_URI"] -TEMPLATE = secrets["TEMPLATE"] -AUTHORIZE_URL = "https://accounts.google.com/o/oauth2/v2/auth" -TOKEN_URL = "https://oauth2.googleapis.com/token" -REFRESH_TOKEN_URL = "https://oauth2.googleapis.com/token" -REVOKE_TOKEN_URL = "https://oauth2.googleapis.com/revoke" -SCOPE = "email" \ No newline at end of file diff --git a/src/database.py b/src/database.py deleted file mode 100644 index 0ed8c5b..0000000 --- a/src/database.py +++ /dev/null @@ -1,64 +0,0 @@ -import streamlit as st - - -def get_previous_chats(db, user_email): - """Fetches previous chat records for a user from the database. - - Args: - db: Deta Base instance. - user_email (str): User's email address. - - Returns: - list: List of previous chat records. - """ - return db.fetch({"email": user_email}).items - - -def database(db, previous_key="key", previous_chat=None, max_chat_histories=5): - """Manages user chat history in the database. - - Updates, adds, or removes chat history based on user interaction. - - Args: - db: Deta Base instance. - previous_key (str): Key for the previous chat in the database. - previous_chat (list, optional): Previous chat messages. - max_chat_histories (int, optional): Maximum number of chat histories to retain. - - Returns: - None - """ - user_email = st.session_state.user_email - previous_chats = get_previous_chats(db, user_email) - existing_chat = db.get(previous_key) if previous_key != "key" else None - if ( - previous_chat is not None - and existing_chat is not None - and previous_key != "key" - ): - new_messages = [ - message - for message in previous_chat - if message not in existing_chat["chat"] - ] - existing_chat["chat"].extend(new_messages) - db.update({"chat": existing_chat["chat"]}, key=previous_key) - return - previous_chat = ( - st.session_state.messages if previous_chat is None else previous_chat - ) - if len(previous_chat) > 1 and previous_key == "key": - title = previous_chat[1]["content"] - db.put( - { - "email": user_email, - "chat": previous_chat, - "title": title[:25] + "....." 
if len(title) > 25 else title, - } - ) - - if len(previous_chats) >= max_chat_histories: - db.delete(previous_chats[0]["key"]) - st.warning( - f"Chat '{previous_chats[0]['title']}' has been removed as you reached the limit of {max_chat_histories} chat histories." - ) diff --git a/src/frontend.py b/src/frontend.py deleted file mode 100644 index 6241c37..0000000 --- a/src/frontend.py +++ /dev/null @@ -1,92 +0,0 @@ -import streamlit as st - -def display_logo_and_heading(): - """Displays the Querypls logo.""" - st.image("static/image/logo.png") - - -def display_welcome_message(): - """Displays a welcome message based on user chat history.""" - no_chat_history = len(st.session_state.messages) == 1 - if no_chat_history: - st.markdown(f"#### Welcome to \n ## πŸ—ƒοΈπŸ’¬Querypls - Prompt to SQL") - - -def handle_new_chat(max_chat_histories=5): - """Handles the initiation of a new chat session. - - Displays the remaining chat history count and provides a button to start a new chat. - - Args: - max_chat_histories (int, optional): Maximum number of chat histories to retain. - - Returns: - None - """ - remaining_chats = max_chat_histories - len(st.session_state.get("previous_chats", [])) - st.markdown( - f" #### Remaining Chat Histories: `{remaining_chats}/{max_chat_histories}`" - ) - st.markdown( - "You can create up to 5 chat histories. Each history can contain unlimited messages." - ) - - if st.button("βž• New chat"): - save_chat_history() # Save current chat before creating a new one - create_message() - - -def display_previous_chats(): - """Displays previous chat records stored in session state. - - Allows the user to select a chat to view. - """ - if "previous_chats" in st.session_state: - reversed_chats = reversed(st.session_state["previous_chats"]) - - for chat in reversed_chats: - if st.button(chat["title"], key=chat["key"]): - update_session_state(chat) - - -def create_message(): - """Creates a default assistant message and initializes a session key.""" - st.session_state["messages"] = [ - {"role": "assistant", "content": "How may I help you?"} - ] - st.session_state["key"] = "key" - - -def update_session_state(chat): - """Updates the session state with selected chat information. - - Args: - chat (dict): Selected chat information. - """ - st.session_state["messages"] = chat["chat"] - st.session_state["key"] = chat["key"] - - -def save_chat_history(): - """Saves the current chat to session state if it contains messages.""" - if "messages" in st.session_state and len(st.session_state["messages"]) > 1: - # Initialize previous chats list if it doesn't exist - if "previous_chats" not in st.session_state: - st.session_state["previous_chats"] = [] - - # Create a chat summary to store in session - title = st.session_state["messages"][1]["content"] - chat_summary = { - "title": title[:25] + "....." if len(title) > 25 else title, - "chat": st.session_state["messages"], - "key": f"chat_{len(st.session_state['previous_chats']) + 1}" - } - - st.session_state["previous_chats"].append(chat_summary) - - # Limit chat histories to a maximum number - if len(st.session_state["previous_chats"]) > 5: - st.session_state["previous_chats"].pop(0) # Remove oldest chat - st.warning( - f"The oldest chat history has been removed as you reached the limit of 5 chat histories." - ) \ No newline at end of file diff --git a/src/frontend/app.py b/src/frontend/app.py new file mode 100644 index 0000000..9c39bd7 --- /dev/null +++ b/src/frontend/app.py @@ -0,0 +1,244 @@ +""" +Main Streamlit application for Querypls. 
+""" + +from src.schemas.requests import NewChatRequest +from src.config.constants import ( + CSV_ANALYSIS_SECTION, + CSV_UPLOAD_LABEL, + CSV_UPLOAD_HELP, + CSV_PREVIEW, + CSV_COLUMNS, + CSV_DTYPES, + LOAD_CSV_BUTTON, + CSV_LOADED_SUCCESS, + CSV_UPLOAD_SUCCESS, + CSV_UPLOAD_ERROR, + SESSION_CREATE_ERROR, + ORCHESTRATOR_INIT_ERROR, + SESSION_NOT_FOUND_ERROR, + APP_INIT_ERROR, + RESPONSE_GENERATION_ERROR, + MESSAGE_LOAD_ERROR, +) +from src.frontend.frontend import display_logo_and_heading, display_welcome_message +from src.backend.backend import ( + display_github_badge, + hide_main_menu_and_footer, +) +from src.backend.orchestrator import BackendOrchestrator +import streamlit as st +import sys +import os +import pandas as pd + +project_root = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +) +sys.path.insert(0, project_root) + + +def initialize_orchestrator(): + if "orchestrator" not in st.session_state: + try: + st.session_state["orchestrator"] = BackendOrchestrator() + except Exception as e: + st.error(ORCHESTRATOR_INIT_ERROR.format(error=str(e))) + return None + return st.session_state["orchestrator"] + + +def get_current_session_id(): + if "current_session_id" not in st.session_state: + orchestrator = initialize_orchestrator() + if orchestrator: + st.session_state["current_session_id"] = orchestrator.get_default_session() + return st.session_state.get("current_session_id") + + +def display_messages(session_id: str): + orchestrator = initialize_orchestrator() + if not orchestrator: + return + + try: + conversation = orchestrator.get_conversation_history(session_id) + for message in conversation.messages: + with st.chat_message(message.role): + display_message_with_images(message.content) + except Exception as e: + st.error(MESSAGE_LOAD_ERROR.format(error=str(e))) + + +def display_message_with_images(content: str): + """Display message content and handle CSV analysis responses with images.""" + # Check if this is a CSV analysis response with images + if "**Generated Images:**" in content: + # Split the content into text and image sections + parts = content.split("**Generated Images:**") + text_content = parts[0].strip() + + # Display the text content + st.markdown(text_content) + + # Handle images if present + if len(parts) > 1: + image_section = parts[1].strip() + image_lines = [ + line.strip() + for line in image_section.split("\n") + if line.strip().startswith("- ") + ] + + if image_lines: + st.markdown("**Generated Images:**") + + # Look for images in the specific temp directory + import os + import glob + + temp_dir = "/tmp/querypls_session_csv_analysis_temp" + if os.path.exists(temp_dir): + for line in image_lines: + # Extract filename from the line (e.g., "- department_chart.png") + filename = line.replace("- ", "").strip() + image_path = os.path.join(temp_dir, filename) + + if os.path.exists(image_path): + try: + st.image( + image_path, caption=filename, use_column_width=True + ) + except Exception as e: + st.error(f"Error displaying image {filename}: {str(e)}") + else: + st.warning(f"Image not found: {filename}") + else: + # Regular message content + st.markdown(content) + + +def cleanup_old_images(): + """Clean up old CSV analysis images.""" + import os + import glob + + temp_dir = "/tmp/querypls_session_csv_analysis_temp" + if os.path.exists(temp_dir): + try: + # Remove old images + for img_file in glob.glob(os.path.join(temp_dir, "*.png")): + os.remove(img_file) + for img_file in glob.glob(os.path.join(temp_dir, "*.jpg")): + os.remove(img_file) 
+ except Exception as e: + print(f"Warning: Could not cleanup old images: {e}") + + +def upload_csv_file(): + uploaded_file = st.file_uploader( + CSV_UPLOAD_LABEL, type=["csv"], help=CSV_UPLOAD_HELP + ) + + if uploaded_file is not None: + try: + # Reset file pointer to beginning + uploaded_file.seek(0) + csv_content = uploaded_file.read().decode("utf-8") + + # Reset file pointer again for pandas + uploaded_file.seek(0) + df = pd.read_csv(uploaded_file) + + st.success(CSV_UPLOAD_SUCCESS.format(shape=df.shape)) + + with st.expander(CSV_PREVIEW): + st.dataframe(df.head()) + st.write(CSV_COLUMNS.format(columns=list(df.columns))) + st.write(CSV_DTYPES.format(dtypes=df.dtypes.to_dict())) + + return csv_content + except Exception as e: + st.error(CSV_UPLOAD_ERROR.format(error=str(e))) + return None + + return None + + +def main(): + orchestrator = initialize_orchestrator() + if not orchestrator: + st.error(APP_INIT_ERROR) + return + + current_session_id = get_current_session_id() + if not current_session_id: + st.error(SESSION_NOT_FOUND_ERROR) + return + + hide_main_menu_and_footer() + + with st.sidebar: + st.markdown( + "", + unsafe_allow_html=True, + ) + display_logo_and_heading() + st.markdown("`Made with 🀍`") + st.markdown("### Sessions") + if st.button("βž• New Session"): + try: + # Clean up old images when creating new session + cleanup_old_images() + + sessions = orchestrator.list_sessions() + new_session = orchestrator.create_new_session( + NewChatRequest(session_name=f"Chat {len(sessions) + 1}") + ) + st.session_state["current_session_id"] = new_session.session_id + st.rerun() + except Exception as e: + st.error(SESSION_CREATE_ERROR.format(error=str(e))) + st.markdown("---") + st.markdown(CSV_ANALYSIS_SECTION) + + csv_content = upload_csv_file() + if csv_content: + if st.button(LOAD_CSV_BUTTON): + try: + # Clean up old images before loading new CSV + cleanup_old_images() + + result = orchestrator.load_csv_data(current_session_id, csv_content) + if result["status"] == "success": + st.success(CSV_LOADED_SUCCESS) + st.session_state["csv_loaded"] = True + st.rerun() # Refresh to show updated state + else: + st.error(f"❌ Error loading CSV: {result['message']}") + except Exception as e: + st.error(f"❌ Error: {str(e)}") + + display_welcome_message() + display_messages(current_session_id) + + if prompt := st.chat_input(): + try: + # Use intelligent routing for all queries + response = orchestrator.generate_intelligent_response( + current_session_id, prompt + ) + + # Display the response immediately + with st.chat_message("user"): + st.markdown(prompt) + + with st.chat_message("assistant"): + display_message_with_images(response.content) + + except Exception as e: + st.error(RESPONSE_GENERATION_ERROR.format(error=str(e))) + + +if __name__ == "__main__": + main() diff --git a/src/frontend/frontend.py b/src/frontend/frontend.py new file mode 100644 index 0000000..f9f720d --- /dev/null +++ b/src/frontend/frontend.py @@ -0,0 +1,39 @@ +""" +Frontend utilities for Streamlit interface components. +""" + +import streamlit as st + + +def display_logo_and_heading(): + st.image("static/image/logo.png") + + +def display_welcome_message(): + st.markdown("#### Welcome to \n ## πŸ—ƒοΈπŸ’¬Querypls - Prompt to SQL") + + +def handle_new_chat(max_chat_histories=5): + st.markdown(f"#### Remaining Chat Histories: `{max_chat_histories}`") + st.markdown( + "You can create multiple chat sessions. Each session can contain unlimited messages." 
+ ) + + if st.button("βž• New chat"): + st.rerun() + + +def display_previous_chats(): + pass + + +def create_message(): + pass + + +def update_session_state(chat): + pass + + +def save_chat_history(): + pass diff --git a/src/model.py b/src/model.py deleted file mode 100644 index 788c738..0000000 --- a/src/model.py +++ /dev/null @@ -1,22 +0,0 @@ -from langchain_community.llms import HuggingFaceHub -import sys -import os - - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) - -from src.auth import * -from src.constant import * - - -def create_huggingface_hub(): - """Creates an instance of Hugging Face Hub with specified configurations. - - Returns: - HuggingFaceHub: Instance of Hugging Face Hub. - """ - return HuggingFaceHub( - huggingfacehub_api_token=HUGGINGFACE_API_TOKEN, - repo_id=REPO_ID, - model_kwargs={"temperature": 0.7, "max_new_tokens": 180}, - ) diff --git a/src/schemas/__init__.py b/src/schemas/__init__.py new file mode 100644 index 0000000..253f1d1 --- /dev/null +++ b/src/schemas/__init__.py @@ -0,0 +1,3 @@ +""" +Pydantic schemas for Querypls application. +""" diff --git a/src/schemas/requests.py b/src/schemas/requests.py new file mode 100644 index 0000000..e0c2806 --- /dev/null +++ b/src/schemas/requests.py @@ -0,0 +1,57 @@ +""" +Request schemas for Querypls application. +""" + +from typing import List, Optional, Literal +from pydantic import BaseModel, Field + + +class ChatMessage(BaseModel): + """Schema for chat message.""" + + role: Literal["user", "assistant", "system"] = Field( + description="Message role (user, assistant, system)" + ) + content: str = Field(description="Message content", min_length=1) + timestamp: Optional[str] = Field(default=None, description="Message timestamp") + session_id: Optional[str] = Field(default=None, description="Session identifier") + + +class SQLGenerationRequest(BaseModel): + """Schema for SQL generation request.""" + + user_query: str = Field( + description="User's natural language query for SQL generation", + min_length=1, + max_length=1000, + ) + conversation_history: List[ChatMessage] = Field( + default=[], description="Previous conversation messages for context" + ) + database_schema: Optional[str] = Field( + default=None, description="Database schema information (optional)" + ) + query_type: Optional[str] = Field( + default=None, + description="Preferred query type (SELECT, INSERT, UPDATE, DELETE)", + ) + + +class ConversationHistory(BaseModel): + """Schema for conversation history.""" + + messages: List[ChatMessage] = Field( + default=[], description="List of conversation messages" + ) + session_id: Optional[str] = Field(default=None, description="Session identifier") + + +class NewChatRequest(BaseModel): + """Schema for creating a new chat session.""" + + session_name: Optional[str] = Field( + default=None, description="Name for the new chat session" + ) + initial_context: Optional[str] = Field( + default=None, description="Initial context or instructions" + ) diff --git a/src/schemas/responses.py b/src/schemas/responses.py new file mode 100644 index 0000000..eb603da --- /dev/null +++ b/src/schemas/responses.py @@ -0,0 +1,80 @@ +""" +Response schemas for Querypls application. 
+""" + +from typing import List, Optional, Literal +from pydantic import BaseModel, Field + + +class SQLQueryResponse(BaseModel): + """Schema for SQL query generation response.""" + + sql_query: str = Field(..., description="The generated SQL query as a string") + explanation: str = Field( + ..., description="Brief explanation of what the query does" + ) + tables_used: List[str] = Field( + default=[], description="Array of table names used in the query" + ) + columns_selected: List[str] = Field( + default=[], description="Array of column names selected in the query" + ) + query_type: Literal[ + "SELECT", "INSERT", "UPDATE", "DELETE", "CREATE", "DROP", "ALTER" + ] = Field(..., description="Type of query generated") + complexity: Literal["SIMPLE", "MEDIUM", "COMPLEX"] = Field( + ..., description="Query complexity level" + ) + estimated_rows: str = Field( + default="variable", + description="Estimated number of rows returned (if applicable)", + ) + execution_time: Optional[str] = Field( + default=None, description="Estimated execution time" + ) + warnings: List[str] = Field( + default=[], description="Any warnings about the generated query" + ) + + +class ChatResponse(BaseModel): + """Schema for chat response.""" + + message_id: str = Field(..., description="Unique identifier for the message") + role: Literal["assistant"] = Field(default="assistant", description="Message role") + content: str = Field(..., description="Response content") + sql_response: Optional[SQLQueryResponse] = Field( + default=None, description="Structured SQL response if applicable" + ) + timestamp: str = Field(..., description="Response timestamp") + session_id: str = Field(..., description="Session identifier") + + +class ErrorResponse(BaseModel): + """Schema for error responses.""" + + error_code: str = Field(..., description="Error code identifier") + error_message: str = Field(..., description="Human-readable error message") + details: Optional[str] = Field(default=None, description="Additional error details") + timestamp: str = Field(..., description="Error timestamp") + + +class SessionInfo(BaseModel): + """Schema for session information.""" + + session_id: str = Field(..., description="Unique session identifier") + session_name: str = Field(..., description="Session name") + created_at: str = Field(..., description="Session creation timestamp") + message_count: int = Field(..., description="Number of messages in the session") + last_activity: str = Field(..., description="Last activity timestamp") + + +class HealthCheckResponse(BaseModel): + """Schema for health check response.""" + + status: Literal["healthy", "unhealthy"] = Field( + ..., description="Application health status" + ) + version: str = Field(..., description="Application version") + timestamp: str = Field(..., description="Health check timestamp") + services: dict = Field(default={}, description="Status of individual services") diff --git a/src/services/__init__.py b/src/services/__init__.py new file mode 100644 index 0000000..dfd3691 --- /dev/null +++ b/src/services/__init__.py @@ -0,0 +1,3 @@ +""" +Services package for Querypls application. +""" diff --git a/src/services/conversation_service.py b/src/services/conversation_service.py new file mode 100644 index 0000000..df1020d --- /dev/null +++ b/src/services/conversation_service.py @@ -0,0 +1,87 @@ +""" +Conversation service for handling normal user queries. 
+""" + +from typing import Literal, Union +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.providers.groq import GroqProvider + +from src.config.constants import WORST_CASE_SCENARIO +from src.config.settings import get_settings +from src.services.models import ConversationResponse, Failed +from utils.prompt import CONVERSATION_PROMPT + + +class ConversationService: + def __init__(self): + self.settings = get_settings() + + self.model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.settings.groq_api_key), + ) + + self.conversation_agent = Agent[None, Union[ConversationResponse, Failed]]( + self.model, + output_type=Union[ConversationResponse, Failed], + system_prompt=CONVERSATION_PROMPT, + ) + + def is_conversational_query(self, query: str) -> bool: + """Check if query is conversational (not SQL/data related).""" + conversational_keywords = [ + "hi", + "hello", + "hey", + "good morning", + "good afternoon", + "good evening", + "how are you", + "what's up", + "thanks", + "thank you", + "bye", + "goodbye", + "help", + "what can you do", + "who are you", + "tell me about yourself", + "nice to meet you", + "pleasure", + "good", + "fine", + "okay", + ] + query_lower = query.lower().strip() + return any(keyword in query_lower for keyword in conversational_keywords) + + def get_conversational_response(self, query: str) -> str: + """Get a natural response for conversational queries.""" + try: + result = self.conversation_agent.run_sync(query) + + if isinstance(result.output, ConversationResponse): + return result.output.message + else: + # Fallback responses + query_lower = query.lower().strip() + + if any(greeting in query_lower for greeting in ["hi", "hello", "hey"]): + return "Hello! πŸ‘‹ How can I help you today? I can assist with SQL generation or CSV data analysis." + elif "how are you" in query_lower: + return "I'm doing great, thank you for asking! 😊 How can I assist you with your data queries today?" + elif any(thanks in query_lower for thanks in ["thanks", "thank you"]): + return ( + "You're welcome! 😊 Is there anything else I can help you with?" + ) + elif any(bye in query_lower for bye in ["bye", "goodbye"]): + return "Goodbye! πŸ‘‹ Feel free to come back if you need help with SQL or data analysis." + elif "help" in query_lower or "what can you do" in query_lower: + return "I'm Querypls, your SQL and data analysis assistant! πŸ—ƒοΈπŸ’¬\n\nI can help you with:\nβ€’ **SQL Generation**: Convert natural language to SQL queries\nβ€’ **CSV Analysis**: Analyze data files with Python code\nβ€’ **Data Visualization**: Create charts and graphs\n\nJust ask me anything about your data!" + else: + return WORST_CASE_SCENARIO + + except Exception as e: + # Fallback response + return "Hello! How can I help you today? I can assist with SQL generation or CSV data analysis." 
diff --git a/src/services/csv_analysis_tools.py b/src/services/csv_analysis_tools.py new file mode 100644 index 0000000..3c65e21 --- /dev/null +++ b/src/services/csv_analysis_tools.py @@ -0,0 +1,228 @@ +import io +import pandas as pd +from typing import Dict, Any, Optional +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.providers.groq import GroqProvider +from pydantic import BaseModel, Field + +from src.config.settings import get_settings +from src.services.jupyter_service import CSVAnalysisService +from utils.prompt import CSV_ANALYSIS_PROMPT, CODE_FIX_PROMPT, CSV_AGENT_PROMPT + + +class CSVAnalysisContext(BaseModel): + session_id: str + csv_content: str + csv_headers: list + sample_data: list + + +class PythonCodeResponse(BaseModel): + python_code: str = Field(description="Generated Python code for data analysis") + explanation: str = Field(description="Explanation of what the code does") + expected_output: str = Field(description="What output is expected from the code") + libraries_used: list = Field(description="List of Python libraries used") + + +class CodeExecutionResult(BaseModel): + status: str = Field(description="Execution status: success, error, or retry") + output: str = Field(description="Output from code execution") + error_message: Optional[str] = Field( + description="Error message if execution failed" + ) + execution_time: float = Field(description="Time taken to execute the code") + attempt: int = Field(description="Attempt number") + + +class CSVAnalysisTools: + def __init__(self): + self.settings = get_settings() + self.csv_service = CSVAnalysisService() + + self.code_generation_model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.settings.groq_api_key), + ) + + self.code_generation_agent = Agent( + self.code_generation_model, + instructions=CSV_ANALYSIS_PROMPT, + output_type=PythonCodeResponse, + ) + + self.code_fixing_model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.settings.groq_api_key), + ) + + self.code_fixing_agent = Agent( + self.code_fixing_model, + instructions=CODE_FIX_PROMPT, + output_type=PythonCodeResponse, + ) + + def load_csv_data(self, csv_content: str, session_id: str) -> Dict[str, Any]: + return self.csv_service.load_csv_data(session_id, csv_content) + + def generate_analysis_code( + self, user_query: str, csv_context: CSVAnalysisContext + ) -> PythonCodeResponse: + prompt = f""" +CSV Headers: {csv_context.csv_headers} +Sample Data: {csv_context.sample_data[:3]} +User Query: {user_query} + +Generate Python code that: +1. Uses pandas for data manipulation +2. Creates visualizations if requested +3. Returns clear output +4. 
Handles the CSV data properly +""" + + result = self.code_generation_agent.run_sync(prompt) + return result.output + + def execute_analysis_code( + self, python_code: str, session_id: str, max_retries: int = 3 + ) -> CodeExecutionResult: + result = self.csv_service.execute_analysis(session_id, python_code, max_retries) + + return CodeExecutionResult( + status=result["status"], + output=result.get("output", ""), + error_message=result.get("error_message"), + execution_time=result.get("execution_time", 0.0), + attempt=result.get("attempt", 1), + ) + + def fix_code_error( + self, original_code: str, error_message: str, csv_context: CSVAnalysisContext + ) -> PythonCodeResponse: + prompt = f""" +Original Code: +{original_code} + +Error Message: +{error_message} + +CSV Headers: {csv_context.csv_headers} +Sample Data: {csv_context.sample_data[:3]} + +Please fix the code to resolve the error and ensure it works correctly. +""" + + result = self.code_fixing_agent.run_sync(prompt) + return result.output + + def get_csv_info(self, session_id: str) -> Dict[str, Any]: + return self.csv_service.get_csv_info(session_id) + + def close_session(self, session_id: str): + self.csv_service.close_session(session_id) + + +def create_csv_analysis_agent() -> Agent: + settings = get_settings() + + model = GroqModel( + settings.groq_model_name, provider=GroqProvider(api_key=settings.groq_api_key) + ) + + agent = Agent(model, instructions=CSV_AGENT_PROMPT, output_type=str) + + csv_tools = CSVAnalysisTools() + + @agent.tool + async def load_csv_data( + ctx: RunContext[None], csv_content: str, session_id: str + ) -> str: + result = csv_tools.load_csv_data(csv_content, session_id) + if result["status"] == "success": + return f"CSV loaded successfully! Shape: {result['shape']}, Columns: {result['columns']}" + else: + return f"Error loading CSV: {result['message']}" + + @agent.tool + async def generate_analysis_code( + ctx: RunContext[None], user_query: str, session_id: str + ) -> str: + csv_info = csv_tools.get_csv_info(session_id) + if csv_info["status"] != "success": + return f"Error: {csv_info['message']}" + + csv_context = CSVAnalysisContext( + session_id=session_id, + csv_content="", + csv_headers=csv_info["columns"], + sample_data=csv_info["sample_data"], + ) + + result = csv_tools.generate_analysis_code(user_query, csv_context) + return f"""Generated Python Code: +```python +{result.python_code} +``` + +Explanation: {result.explanation} +Expected Output: {result.expected_output} +Libraries Used: {', '.join(result.libraries_used)}""" + + @agent.tool + async def execute_analysis_code( + ctx: RunContext[None], python_code: str, session_id: str + ) -> str: + result = csv_tools.execute_analysis_code(python_code, session_id) + + if result.status == "success": + return f"""βœ… Code executed successfully! +Execution Time: {result.execution_time:.2f}s +Attempt: {result.attempt} + +Output: +{result.output}""" + else: + return f"""❌ Code execution failed! 
+Attempt: {result.attempt} +Error: {result.error_message} + +Output: +{result.output}""" + + @agent.tool + async def fix_code_error( + ctx: RunContext[None], original_code: str, error_message: str, session_id: str + ) -> str: + csv_info = csv_tools.get_csv_info(session_id) + if csv_info["status"] != "success": + return f"Error: {csv_info['message']}" + + csv_context = CSVAnalysisContext( + session_id=session_id, + csv_content="", + csv_headers=csv_info["columns"], + sample_data=csv_info["sample_data"], + ) + + result = csv_tools.fix_code_error(original_code, error_message, csv_context) + return f"""πŸ”§ Fixed Code: +```python +{result.python_code} +``` + +Explanation: {result.explanation} +Expected Output: {result.expected_output}""" + + @agent.tool + async def get_csv_info(ctx: RunContext[None], session_id: str) -> str: + result = csv_tools.get_csv_info(session_id) + if result["status"] == "success": + return f"""πŸ“Š CSV Information: +Shape: {result['shape']} +Columns: {result['columns']} +Data Types: {result['dtypes']} +Sample Data: {result['sample_data'][:2]}""" + else: + return f"Error: {result['message']}" + + return agent diff --git a/src/services/jupyter_service.py b/src/services/jupyter_service.py new file mode 100644 index 0000000..ecf9064 --- /dev/null +++ b/src/services/jupyter_service.py @@ -0,0 +1,269 @@ +""" +Jupyter service for executing Python code with CSV data analysis. +""" + +import os +import io +import jupyter_client +import inspect +import time +import re +import pandas as pd +from typing import Dict, Any, Optional +from dataclasses import dataclass + +from src.config.constants import EXECUTION_TIMEOUT, MAX_RETRIES + + +def clean_error_message(error_msg: str) -> str: + ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])") + cleaned_msg = ansi_escape.sub("", error_msg) + lines = cleaned_msg.split("\n") + lines = [line.strip() for line in lines if line.strip()] + cleaned_msg = "\n".join(lines) + return cleaned_msg + + +@dataclass +class ExecutionResult: + output: str + status: str + error_message: Optional[str] = None + execution_time: float = 0.0 + + +class SimpleJupyterClient: + def __init__(self): + self.clients: Dict[str, Any] = {} + self.globals: Dict[str, Dict[str, Any]] = {} + + def create_new_session( + self, session_id: str = "default", kernel_name: str = "querypls" + ) -> str: + if session_id in self.clients: + return session_id + + try: + km = jupyter_client.KernelManager(kernel_name=kernel_name) + km.start_kernel() + client = km.client() + self.clients[session_id] = client + self.globals[session_id] = {} + + # Set environment variables + for key, value in os.environ.items(): + self.execute_code(f"{key} = '{value}'", session_id) + + # Import common data science libraries + self.execute_code("import pandas as pd", session_id) + self.execute_code("import numpy as np", session_id) + self.execute_code("import matplotlib.pyplot as plt", session_id) + self.execute_code("import seaborn as sns", session_id) + + return session_id + except Exception as e: + # Fallback to default kernel + try: + km = jupyter_client.KernelManager() + km.start_kernel() + client = km.client() + self.clients[session_id] = client + self.globals[session_id] = {} + + # Set environment variables + for key, value in os.environ.items(): + self.execute_code(f"{key} = '{value}'", session_id) + + # Import common data science libraries + self.execute_code("import pandas as pd", session_id) + self.execute_code("import numpy as np", session_id) + self.execute_code("import 
matplotlib.pyplot as plt", session_id) + self.execute_code("import seaborn as sns", session_id) + + return session_id + except Exception as e2: + raise ValueError(f"Failed to create kernel: {str(e2)}") + + def execute_code(self, code: str, session_id: str = "default") -> ExecutionResult: + if session_id not in self.clients: + raise ValueError(f"Session {session_id} not found") + + client = self.clients[session_id] + start_time = time.time() + + msg_id = client.execute(code) + output = [] + timeout = time.time() + EXECUTION_TIMEOUT + status = "Success" + error_message = None + + while True: + try: + msg = client.get_iopub_msg(timeout=1) + if ( + "parent_header" not in msg + or msg["parent_header"].get("msg_id") != msg_id + ): + continue + + msg_type = msg.get("msg_type", "") + content = msg.get("content", {}) + + if msg_type == "execute_result": + output.append(str(content.get("data", {}).get("text/plain", ""))) + elif msg_type == "stream": + output.append(content.get("text", "")) + elif msg_type == "error": + error_traceback = "\n".join(content.get("traceback", [])) + cleaned_error = clean_error_message(error_traceback) + output.append(f"Error: {cleaned_error}") + error_message = cleaned_error + status = "Fail" + elif msg_type == "status" and content.get("execution_state") == "idle": + break + except Exception as e: + pass + + execution_time = time.time() - start_time + + return ExecutionResult( + output="\n".join(output).strip(), + status=status, + error_message=error_message, + execution_time=execution_time, + ) + + def import_function(self, func, session_id: str = "default") -> ExecutionResult: + if session_id not in self.globals: + raise ValueError(f"Session {session_id} not found") + + func_code = inspect.getsource(func) + func_name = func.__name__ + + result = self.execute_code(func_code, session_id) + if result.status == "Success": + self.globals[session_id][func_name] = func + + return result + + def close_session(self, session_id: str = "default"): + if session_id not in self.clients: + raise ValueError(f"Session {session_id} not found") + + client = self.clients[session_id] + client.stop_channels() + del self.clients[session_id] + del self.globals[session_id] + + def close_all_sessions(self): + for session_id in list(self.clients.keys()): + self.close_session(session_id) + + +class CSVAnalysisService: + def __init__(self): + self.jupyter_client = SimpleJupyterClient() + self.csv_data: Dict[str, pd.DataFrame] = {} + self.csv_headers: Dict[str, list] = {} + + def load_csv_data( + self, session_id: str, csv_content: str, filename: str = "data.csv" + ) -> Dict[str, Any]: + try: + self.jupyter_client.create_new_session(session_id) + + csv_code = f""" +import pandas as pd +import io + +csv_content = '''{csv_content}''' +df = pd.read_csv(io.StringIO(csv_content)) +print("CSV loaded successfully!") +print(f"Shape: {{df.shape}}") +print("\\nColumns:") +print(df.columns.tolist()) +print("\\nFirst few rows:") +print(df.head()) +""" + + result = self.jupyter_client.execute_code(csv_code, session_id) + + if result.status == "Success": + df = pd.read_csv(io.StringIO(csv_content)) + self.csv_data[session_id] = df + self.csv_headers[session_id] = df.columns.tolist() + + return { + "status": "success", + "message": "CSV loaded successfully", + "shape": df.shape, + "columns": df.columns.tolist(), + "sample_data": df.head().to_dict("records"), + } + else: + return { + "status": "error", + "message": result.error_message or "Failed to load CSV", + } + + except Exception as e: + return {"status": 
"error", "message": str(e)} + + def execute_analysis( + self, session_id: str, python_code: str, max_retries: int = MAX_RETRIES + ) -> Dict[str, Any]: + for attempt in range(max_retries): + try: + result = self.jupyter_client.execute_code(python_code, session_id) + + if result.status == "Success": + return { + "status": "success", + "output": result.output, + "execution_time": result.execution_time, + "attempt": attempt + 1, + } + else: + if attempt == max_retries - 1: + return { + "status": "error", + "error_message": result.error_message, + "output": result.output, + "attempt": attempt + 1, + } + continue + + except Exception as e: + if attempt == max_retries - 1: + return { + "status": "error", + "error_message": str(e), + "attempt": attempt + 1, + } + continue + + return { + "status": "error", + "error_message": "Max retries exceeded", + "attempt": max_retries, + } + + def get_csv_info(self, session_id: str) -> Dict[str, Any]: + if session_id not in self.csv_data: + return {"status": "error", "message": "No CSV data loaded for this session"} + + df = self.csv_data[session_id] + return { + "status": "success", + "shape": df.shape, + "columns": df.columns.tolist(), + "dtypes": df.dtypes.to_dict(), + "sample_data": df.head().to_dict("records"), + } + + def close_session(self, session_id: str): + self.jupyter_client.close_session(session_id) + if session_id in self.csv_data: + del self.csv_data[session_id] + if session_id in self.csv_headers: + del self.csv_headers[session_id] diff --git a/src/services/models.py b/src/services/models.py new file mode 100644 index 0000000..f237396 --- /dev/null +++ b/src/services/models.py @@ -0,0 +1,88 @@ +""" +Data models for the services. +""" + +from typing import Literal, Union, List, Optional +from pydantic import BaseModel, Field +from datetime import datetime + + +class RoutingDecision(BaseModel): + """Model for routing decisions.""" + + agent: Literal["CONVERSATION_AGENT", "SQL_AGENT", "CSV_AGENT"] = Field( + description="The agent that should handle the query" + ) + confidence: float = Field( + description="Confidence level in the routing decision", ge=0.0, le=1.0 + ) + reasoning: str = Field(description="Brief explanation of why this agent was chosen") + + +class ConversationResponse(BaseModel): + """Response for conversational queries.""" + + message: str = Field(description="Natural response to user query") + response_type: Literal["greeting", "help", "thanks", "goodbye", "general"] = Field( + description="Type of response" + ) + suggest_next: Optional[str] = Field( + description="Optional suggestion for what they could do next", default=None + ) + + +class SQLResponse(BaseModel): + """Response for SQL generation.""" + + sql_query: str = Field(description="The generated SQL query") + explanation: str = Field(description="Brief explanation of what the query does") + tables_used: List[str] = Field(description="Array of table names used in the query") + columns_selected: List[str] = Field( + description="Array of column names selected in the query" + ) + query_type: str = Field( + description="Type of query (SELECT, INSERT, UPDATE, DELETE, etc.)" + ) + complexity: Literal["SIMPLE", "MEDIUM", "COMPLEX"] = Field( + description="Query complexity level" + ) + estimated_rows: str = Field(description="Estimated number of rows returned") + execution_time: Optional[str] = Field( + description="Estimated execution time", default=None + ) + warnings: List[str] = Field( + description="Array of warnings about the query", default_factory=list + ) + + +class 
CSVAnalysisResponse(BaseModel): + """Response for CSV analysis.""" + + python_code: str = Field(description="The generated Python code") + explanation: str = Field(description="Brief explanation of what the code does") + expected_output: str = Field(description="What output is expected from the code") + libraries_used: List[str] = Field(description="Array of Python libraries used") + + +class CodeFixResponse(BaseModel): + """Response for code fixing.""" + + python_code: str = Field(description="The fixed Python code") + explanation: str = Field(description="Brief explanation of what was fixed") + expected_output: str = Field( + description="What output is expected from the fixed code" + ) + libraries_used: List[str] = Field(description="Array of Python libraries used") + + +class Failed(BaseModel): + """Unable to find a satisfactory response.""" + + error: str = Field(description="Error message explaining the failure") + + +# Union types for different response types +ConversationResult = Union[ConversationResponse, Failed] +SQLResult = Union[SQLResponse, Failed] +CSVAnalysisResult = Union[CSVAnalysisResponse, Failed] +CodeFixResult = Union[CodeFixResponse, Failed] diff --git a/src/services/routing_service.py b/src/services/routing_service.py new file mode 100644 index 0000000..40ca5a0 --- /dev/null +++ b/src/services/routing_service.py @@ -0,0 +1,452 @@ +""" +Intelligent routing service for determining which agent should handle user queries. +""" + +import json +from typing import List, Optional, Dict, Any +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.providers.groq import GroqProvider + +from src.config.constants import WORST_CASE_SCENARIO +from src.config.settings import get_settings +from src.services.models import ( + RoutingDecision, + ConversationResult, + SQLResult, + CSVAnalysisResult, +) +from src.schemas.requests import ChatMessage +from utils.prompt import ( + ROUTING_PROMPT, + CONVERSATION_PROMPT, + SQL_GENERATION_PROMPT, + CSV_ANALYSIS_PROMPT, +) + + +class IntelligentRoutingService: + """Service for intelligently routing user queries to appropriate agents.""" + + def __init__(self): + self.settings = get_settings() + + self.model = GroqModel( + self.settings.groq_model_name, + provider=GroqProvider(api_key=self.settings.groq_api_key), + ) + + # Create routing agent + self.routing_agent = Agent[None, RoutingDecision]( + self.model, output_type=RoutingDecision, system_prompt=ROUTING_PROMPT + ) + + # Create conversation agent + self.conversation_agent = Agent[None, ConversationResult]( + self.model, + output_type=ConversationResult, + system_prompt=CONVERSATION_PROMPT, + ) + + # Create SQL agent + self.sql_agent = Agent[None, SQLResult]( + self.model, output_type=SQLResult, system_prompt=SQL_GENERATION_PROMPT + ) + + # Create CSV analysis agent + self.csv_agent = Agent[None, CSVAnalysisResult]( + self.model, output_type=CSVAnalysisResult, system_prompt=CSV_ANALYSIS_PROMPT + ) + + def determine_agent( + self, + user_query: str, + conversation_history: List[ChatMessage], + csv_loaded: bool = False, + ) -> RoutingDecision: + """Determine which agent should handle the user query.""" + try: + # Prepare context for routing + context = self._prepare_routing_context( + user_query, conversation_history, csv_loaded + ) + + result = self.routing_agent.run_sync(context) + return result.output + + except Exception as e: + print(f"Routing failed with error: {e}") + # Use simple keyword-based routing as fallback + return 
self._keyword_based_routing(user_query, csv_loaded) + + def handle_conversation_query(self, user_query: str) -> str: + """Handle conversational queries.""" + try: + result = self.conversation_agent.run_sync(user_query) + + if hasattr(result.output, "message"): + return result.output.message + else: + return self._get_fallback_conversation_response(user_query) + + except Exception as e: + return self._get_fallback_conversation_response(user_query) + + def handle_sql_query( + self, user_query: str, conversation_history: List[ChatMessage] + ) -> str: + """Handle SQL generation queries.""" + try: + context = self._prepare_sql_context(user_query, conversation_history) + result = self.sql_agent.run_sync(context) + + if hasattr(result.output, "sql_query"): + return self._format_sql_response(result.output) + else: + return "I'm sorry, I couldn't generate a SQL query for that request. Could you please rephrase your question?" + + except Exception as e: + return f"I encountered an error while generating SQL: {str(e)}" + + def handle_csv_query( + self, + user_query: str, + csv_info: Dict[str, Any], + conversation_history: Optional[List[ChatMessage]] = None, + ) -> str: + """Handle CSV analysis queries.""" + try: + # Use the AI agent to generate code based on user request and conversation history + context = self._prepare_csv_context( + user_query, csv_info, conversation_history + ) + result = self.csv_agent.run_sync(context) + + if hasattr(result.output, "python_code"): + # Execute the generated code using Jupyter service + return self._execute_csv_analysis( + result.output.python_code, csv_info, result.output.explanation + ) + else: + return "I'm sorry, I couldn't generate analysis code for that request. Could you please rephrase your question?" + + except Exception as e: + # If LLM fails, provide a graceful response without showing errors + return WORST_CASE_SCENARIO + + def _execute_csv_analysis( + self, python_code: str, csv_info: Dict[str, Any], explanation: str + ) -> str: + """Execute CSV analysis code using Jupyter service with error fixing retry loop.""" + try: + from src.services.jupyter_service import CSVAnalysisService + + # Create Jupyter service instance + jupyter_service = CSVAnalysisService() + + # Create a temporary session for this analysis + session_id = "csv_analysis_temp" + + # Load CSV data into the session + jupyter_service.load_csv_data(session_id, csv_info["file_path"]) + + # Install required libraries if needed + install_code = """ +import sys +import subprocess + +def install_package(package): + try: + __import__(package) + except ImportError: + subprocess.check_call([sys.executable, "-m", "pip", "install", package]) + +# Install required packages +install_package('pandas') +install_package('numpy') +install_package('matplotlib') +install_package('seaborn') +""" + + # Execute installation first + install_result = jupyter_service.execute_analysis( + session_id, install_code, max_retries=1 + ) + + # Retry loop for code execution with error fixing + current_code = python_code + max_retries = 3 + + for attempt in range(max_retries): + # Execute the current code + result = jupyter_service.execute_analysis( + session_id, current_code, max_retries=1 + ) + + if result["status"] == "success": + output = result.get("output", "") + + # If output is empty, provide a fallback + if not output.strip(): + output = "Analysis completed successfully but no output was generated." 
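+                    # NOTE: _prepare_csv_context (defined below) asks the generated code to
+                    # save any charts with plt.savefig() into this session's temp directory,
+                    # which is exactly where the image scan that follows looks for new files.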
+ + # Check if any images were created in the specific session directory + import os + import glob + + # Look for images in the session's temp directory + session_temp_dir = f"/tmp/querypls_session_csv_analysis_temp" + image_files = [] + + if os.path.exists(session_temp_dir): + png_files = glob.glob(os.path.join(session_temp_dir, "*.png")) + jpg_files = glob.glob(os.path.join(session_temp_dir, "*.jpg")) + image_files.extend(png_files + jpg_files) + + if image_files: + image_info = "\n\nπŸ“Š **Charts generated:**\n" + for img_file in image_files: + image_info += f"- {os.path.basename(img_file)}\n" + output += image_info + + # Return only the human-readable output, not technical details + return output.strip() + + else: + # Code execution failed - try to fix it + error_msg = result.get("error_message", "Unknown error") + + if attempt < max_retries - 1: # Not the last attempt + # Send error to LLM to fix the code + fixed_code = self._fix_python_code( + current_code, error_msg, csv_info + ) + if fixed_code: + current_code = fixed_code + continue # Try again with fixed code + + return WORST_CASE_SCENARIO + + except Exception as e: + return WORST_CASE_SCENARIO + + except Exception as e: + return WORST_CASE_SCENARIO + + return WORST_CASE_SCENARIO + + def _fix_python_code( + self, original_code: str, error_message: str, csv_info: Dict[str, Any] + ) -> Optional[str]: + """Send error to LLM to fix the Python code.""" + try: + context = self._prepare_code_fix_context( + original_code, error_message, csv_info + ) + + result = self.csv_agent.run_sync(context) + + if hasattr(result.output, "python_code"): + return result.output.python_code + else: + return None + + except Exception as e: + return None + + def _prepare_routing_context( + self, user_query: str, conversation_history: List[ChatMessage], csv_loaded: bool + ) -> str: + """Prepare context for routing decision.""" + context_parts = [ + f"User Query: {user_query}", + f"CSV Data Loaded: {csv_loaded}", + ] + + if conversation_history: + # Last 5 messages for context + recent_messages = conversation_history[-5:] + context_parts.append("Recent Conversation History:") + for msg in recent_messages: + context_parts.append(f"- {msg.role}: {msg.content}") + + return "\n".join(context_parts) + + def _prepare_sql_context( + self, user_query: str, conversation_history: List[ChatMessage] + ) -> str: + """Prepare context for SQL generation.""" + context_parts = [ + f"User Query: {user_query}", + ] + + if conversation_history: + context_parts.append("Conversation History:") + for msg in conversation_history[-10:]: # Last 10 messages + context_parts.append(f"- {msg.role}: {msg.content}") + + return "\n".join(context_parts) + + def _prepare_csv_context( + self, + user_query: str, + csv_info: Dict[str, Any], + conversation_history: Optional[List[ChatMessage]] = None, + ) -> str: + """Prepare context for CSV analysis.""" + context_parts = [ + f"User Query: {user_query}", + f"CSV Data Available: Yes", + f"CSV File Path: {csv_info['file_path']}", + f"CSV Shape: {csv_info['shape']}", + f"CSV Columns: {csv_info['columns']}", + f"CSV Data Types: {csv_info['dtypes']}", + f"CSV Sample Data: {csv_info['sample_data']}", + ] + + if conversation_history: + context_parts.append("Conversation History:") + # Last 5 messages for context + for msg in conversation_history[-5:]: + context_parts.append(f"- {msg.role}: {msg.content}") + + context_parts.append( + "\nGenerate SUPER SIMPLE Python code that directly answers the user's question." 
+ ) + context_parts.append("MAXIMUM 5 LINES OF CODE - Keep it extremely simple!") + context_parts.append( + "NO FUNCTIONS OR CLASSES - Just direct code that prints results!" + ) + context_parts.append( + f"IMPORTANT: Use pd.read_csv('{csv_info['file_path']}') to load the data from the file path!" + ) + context_parts.append( + "Print human-readable results like 'Average price: $123.45' - NO technical output!" + ) + context_parts.append( + "For charts, use plt.savefig('/tmp/querypls_session_csv_analysis_temp/chart.png') and plt.show()." + ) + + return "\n".join(context_parts) + + def _prepare_code_fix_context( + self, original_code: str, error_message: str, csv_info: Dict[str, Any] + ) -> str: + """Prepare context for code fixing.""" + context_parts = [ + "CODE FIXING REQUEST:", + f"Original Code: {original_code}", + f"Error Message: {error_message}", + f"CSV File Path: {csv_info['file_path']}", + f"CSV Shape: {csv_info['shape']}", + f"CSV Columns: {csv_info['columns']}", + f"CSV Data Types: {csv_info['dtypes']}", + f"CSV Sample Data: {csv_info['sample_data']}", + "", + "INSTRUCTIONS:", + "The above Python code failed to execute. Please fix the code and return a working version.", + "Follow these guidelines:", + "1. Keep code SIMPLE - Maximum 6 lines", + "2. NO SPECIAL CHARACTERS - Use standard ASCII only", + "3. NO FUNCTIONS - Write code directly", + "4. NO DOCSTRINGS - No complex documentation", + "5. Use pd.read_csv('file_path') to load data", + "6. Print human-readable insights directly", + "7. For charts, save to /tmp/querypls_session_csv_analysis_temp/", + "", + "Generate fixed Python code that will execute without errors.", + ] + + return "\n".join(context_parts) + + def _format_sql_response(self, sql_response) -> str: + """Format SQL response for display.""" + response_parts = [ + f"**SQL Query:**\n```sql\n{sql_response.sql_query}\n```", + f"**Explanation:** {sql_response.explanation}", + f"**Query Type:** {sql_response.query_type}", + f"**Complexity:** {sql_response.complexity}", + f"**Tables Used:** {', '.join(sql_response.tables_used)}", + f"**Columns Selected:** {', '.join(sql_response.columns_selected)}", + f"**Estimated Rows:** {sql_response.estimated_rows}", + ] + + if sql_response.warnings: + response_parts.append(f"**Warnings:** {', '.join(sql_response.warnings)}") + + return "\n\n".join(response_parts) + + def _keyword_based_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecision: + """Keyword-based routing when LLM routing fails.""" + query_lower = user_query.lower() + + # CSV Agent keywords + csv_keywords = [ + "csv", "analyze", "chart", "plot", "graph", "average", "mean", "sum", + "count", "max", "min", "statistics", "data", "visualization", "top", + "bottom", "highest", "lowest", "distribution", "correlation" + ] + + # SQL Agent keywords + sql_keywords = [ + "select", "insert", "update", "delete", "sql", "query", "table", + "database", "users", "customers", "orders", "products", "where", + "join", "group by", "order by", "from" + ] + + # Conversation Agent keywords + conversation_keywords = [ + "hello", "hi", "hey", "how are you", "what can you do", "help", + "thanks", "thank you", "goodbye", "bye", "good morning", "good evening" + ] + + # Check for CSV analysis (prioritize if CSV is loaded) + if csv_loaded and any(keyword in query_lower for keyword in csv_keywords): + return RoutingDecision( + agent="CSV_AGENT", + confidence=0.8, + reasoning="Keyword-based routing detected CSV analysis request" + ) + + # Check for SQL keywords + if any(keyword in 
query_lower for keyword in sql_keywords): + return RoutingDecision( + agent="SQL_AGENT", + confidence=0.8, + reasoning="Keyword-based routing detected SQL request" + ) + + # Check for conversation keywords + if any(keyword in query_lower for keyword in conversation_keywords): + return RoutingDecision( + agent="CONVERSATION_AGENT", + confidence=0.9, + reasoning="Keyword-based routing detected conversation request" + ) + + # Default based on context + if csv_loaded: + return RoutingDecision( + agent="CSV_AGENT", + confidence=0.6, + reasoning="CSV loaded, defaulting to CSV analysis" + ) + else: + return RoutingDecision( + agent="CONVERSATION_AGENT", + confidence=0.5, + reasoning="No clear intent detected, defaulting to conversation" + ) + + def _fallback_routing(self, user_query: str, csv_loaded: bool) -> RoutingDecision: + """Fallback routing when LLM routing fails - let LLM decide, not hardcoded keywords.""" + # Default to conversation - let the LLM handle all decisions + return RoutingDecision( + agent="CONVERSATION_AGENT", + confidence=0.3, + reasoning="LLM routing failed, defaulting to conversation agent", + ) + + def _get_fallback_conversation_response(self, user_query: str) -> str: + """Get fallback conversation response when LLM fails.""" + return WORST_CASE_SCENARIO diff --git a/src/services/sql_service.py b/src/services/sql_service.py new file mode 100644 index 0000000..65b9bd8 --- /dev/null +++ b/src/services/sql_service.py @@ -0,0 +1,123 @@ +import json +import uuid +from datetime import datetime +from typing import Optional +from pydantic_ai import Agent +from pydantic_ai.models.groq import GroqModel +from pydantic_ai.providers.groq import GroqProvider + +from src.config.settings import get_settings +from src.schemas.requests import SQLGenerationRequest, ChatMessage +from src.schemas.responses import SQLQueryResponse, ChatResponse, ErrorResponse +from utils.prompt import SQL_GENERATION_PROMPT + + +class SQLGenerationService: + def __init__(self, api_key: Optional[str] = None): + self.settings = get_settings() + self.api_key = api_key or self.settings.groq_api_key + + if not self.api_key: + raise ValueError( + "Groq API key is required. Set GROQ_API_KEY environment variable or pass api_key parameter." 
+ ) + + self.model = GroqModel( + self.settings.groq_model_name, provider=GroqProvider(api_key=self.api_key) + ) + + self.agent = Agent( + self.model, instructions=SQL_GENERATION_PROMPT, output_type=SQLQueryResponse + ) + + def format_chat_history(self, messages: list) -> str: + history = [] + for msg in messages[1:]: + if isinstance(msg, ChatMessage): + content = msg.content + role = msg.role + else: + content = msg.get("content", "") + role = msg.get("role", "user") + + if "```sql" in content: + content = content.replace("```sql\n", "").replace("\n```", "").strip() + + history.append( + {"role": role, "query" if role == "user" else "response": content} + ) + + return json.dumps(history, indent=2) + + def generate_sql(self, request: SQLGenerationRequest) -> ChatResponse: + try: + formatted_history = self.format_chat_history(request.conversation_history) + prompt = f"Previous conversation: {formatted_history}\nCurrent question: { + request.user_query}" + + result = self.agent.run_sync(prompt) + + sql_response = SQLQueryResponse( + sql_query=result.output.sql_query, + explanation=result.output.explanation, + tables_used=result.output.tables_used, + columns_selected=result.output.columns_selected, + query_type=result.output.query_type, + complexity=result.output.complexity, + estimated_rows=result.output.estimated_rows, + execution_time=result.output.execution_time, + warnings=result.output.warnings, + ) + + formatted_content = f"```sql\n{ + sql_response.sql_query}\n```\n\n**Explanation:** { + sql_response.explanation}" + + session_id = "default" + if request.conversation_history: + first_msg = request.conversation_history[0] + if isinstance(first_msg, ChatMessage): + session_id = first_msg.session_id or "default" + else: + session_id = first_msg.get("session_id", "default") + + chat_response = ChatResponse( + message_id=str(uuid.uuid4()), + content=formatted_content, + sql_response=sql_response, + timestamp=datetime.now().isoformat(), + session_id=session_id, + ) + + return chat_response + + except Exception as e: + error_response = ErrorResponse( + error_code="SQL_GENERATION_ERROR", + error_message=f"Error generating SQL: {str(e)}", + details=str(e), + timestamp=datetime.now().isoformat(), + ) + + session_id = "default" + if request.conversation_history: + first_msg = request.conversation_history[0] + if isinstance(first_msg, ChatMessage): + session_id = first_msg.session_id or "default" + else: + session_id = first_msg.get("session_id", "default") + + return ChatResponse( + message_id=str(uuid.uuid4()), + content=f"❌ Error: {error_response.error_message}", + timestamp=datetime.now().isoformat(), + session_id=session_id, + ) + + def generate_sql_legacy(self, user_query: str, conversation_history: list) -> str: + request = SQLGenerationRequest( + user_query=user_query, conversation_history=conversation_history + ) + + response = self.generate_sql(request) + return response.content diff --git a/src/terminal/__init__.py b/src/terminal/__init__.py new file mode 100644 index 0000000..b353cc9 --- /dev/null +++ b/src/terminal/__init__.py @@ -0,0 +1,3 @@ +""" +Terminal interface package for Querypls. +""" diff --git a/src/terminal/cli.py b/src/terminal/cli.py new file mode 100644 index 0000000..841281d --- /dev/null +++ b/src/terminal/cli.py @@ -0,0 +1,181 @@ +""" +Command-line interface for Querypls SQL generation. 
+""" + +from src.config.constants import DEFAULT_SESSION_NAME +from src.schemas.requests import NewChatRequest +from src.backend.orchestrator import BackendOrchestrator +import sys +import os +import json +from typing import Optional + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +class QueryplsCLI: + def __init__(self): + self.orchestrator = BackendOrchestrator() + self.current_session_id = None + + def create_session(self, name: Optional[str] = None) -> str: + request = NewChatRequest(session_name=name) + session_info = self.orchestrator.create_new_session(request) + self.current_session_id = session_info.session_id + print( + f"""Session created: { + session_info.session_name} (ID: { + session_info.session_id})""" + ) + return session_info.session_id + + def list_sessions(self): + sessions = self.orchestrator.list_sessions() + if not sessions: + print("No sessions found.") + return + print("Available sessions:") + for i, session in enumerate(sessions, 1): + print(f"{i}. {session.session_name}") + print(f" ID: {session.session_id}") + print(f" Messages: {session.message_count}") + print(f" Last activity: {session.last_activity}") + print() + + def switch_session(self, session_id: str): + session = self.orchestrator.get_session(session_id) + if not session: + print(f"Session not found: {session_id}") + return + self.current_session_id = session_id + print(f"Switched to session: {session.session_name}") + + def chat(self, query: str): + if not self.current_session_id: + print("No active session. Please create or switch to a session.") + return + + try: + response = self.orchestrator.generate_sql_response( + self.current_session_id, query + ) + print("\nResponse generated:") + print(response.content) + + if response.sql_response: + print("\nSQL Details:") + print(f" Query Type: {response.sql_response.query_type}") + print(f" Complexity: {response.sql_response.complexity}") + print(f" Tables Used: {', '.join(response.sql_response.tables_used)}") + print(f" Columns: {', '.join(response.sql_response.columns_selected)}") + print(f" Estimated Rows: {response.sql_response.estimated_rows}") + if response.sql_response.warnings: + print(f" Warnings: {', '.join(response.sql_response.warnings)}") + + except Exception as e: + print(f"Error: {str(e)}") + + def show_history(self): + if not self.current_session_id: + print("No session selected.") + return + + try: + conversation = self.orchestrator.get_conversation_history( + self.current_session_id + ) + print("\nConversation history:") + for message in conversation.messages: + print(f" {message.role.upper()}: {message.content}") + except Exception as e: + print(f"Error: {str(e)}") + + def health_check(self): + try: + health = self.orchestrator.health_check() + print("Health check successful.") + print(f" Status: {health.status}") + print(f" Version: {health.version}") + print(f" Services: {json.dumps(health.services, indent=2)}") + except Exception as e: + print(f"Health check failed: {str(e)}") + + def run_interactive(self): + print("Welcome to Querypls CLI!") + print("Commands: new, list, switch , chat , history, health, quit") + print() + + self.create_session("CLI Session") + + while True: + try: + command = input("querypls> ").strip() + + if not command: + continue + + parts = command.split() + cmd = parts[0].lower() + + if cmd == "quit" or cmd == "exit": + print("Goodbye!") + break + elif cmd == "new": + name = " ".join(parts[1:]) if len(parts) > 1 else None + self.create_session(name) + elif cmd == "list": + 
self.list_sessions() + parts = command.split() + cmd = parts[0].lower() + + if cmd == "quit" or cmd == "exit": + print("Goodbye!") + break + elif cmd == "new": + name = " ".join(parts[1:]) if len(parts) > 1 else None + self.create_session(name) + elif cmd == "list": + self.list_sessions() + elif cmd == "switch" and len(parts) > 1: + self.switch_session(parts[1]) + elif cmd == "chat" and len(parts) > 1: + query = " ".join(parts[1:]) + self.chat(query) + elif cmd == "history": + self.show_history() + elif cmd == "health": + self.health_check() + else: + print("Unknown command.") + + except KeyboardInterrupt: + print("\nGoodbye!") + break + except Exception as e: + print(f"Error: {str(e)}") + + +def main(): + cli = QueryplsCLI() + + if len(sys.argv) > 1: + command = sys.argv[1] + if command == "new": + name = sys.argv[2] if len(sys.argv) > 2 else None + cli.create_session(name) + elif command == "list": + cli.list_sessions() + elif command == "chat" and len(sys.argv) > 2: + query = " ".join(sys.argv[2:]) + cli.create_session("CLI Session") + cli.chat(query) + elif command == "health": + cli.health_check() + else: + print("Usage: python cli.py [new|list|chat |health]") + else: + cli.run_interactive() + + +if __name__ == "__main__": + main() diff --git a/static/css/styles.css b/static/css/styles.css deleted file mode 100644 index fa6f7e4..0000000 --- a/static/css/styles.css +++ /dev/null @@ -1,85 +0,0 @@ - -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2){ - background-color: rgb(233, 240, 255); - border: 1px solid #152544 ; /* Add border with 1px thickness and black color */ - border-radius: 10px; /* Add border radius for rounded corners */ - padding: 10px; /* Add padding for spacing inside the element */ - -} -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(3){ - background-color: #152544; - color:rgb(17, 17, 17) -} -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4{ - background-color: rgb(255, 255, 255); -} -#root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(5){ - background-color: #152544; -} - - - - #root > div:nth-child(1) > div.withScreencast > div > div > div > section > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.st-emotion-cache-ocqkz7.e1f1d6gn4 > div:nth-child(2) > div > div > div > div > div:nth-child(3) > div > div > p > code{ - color:white; - background-color: #152544; - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.st-emotion-cache-nju155.eczjsme11 > div.st-emotion-cache-6qob1r.eczjsme3{ - background-color: rgb(184, 205, 252); - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5{ - border: 2px solid rgb(0, 2, 8) ; /* Add border with 1px thickness and black color */ - border-radius: 0px; /* Add border radius for rounded corners */ - padding: 10px; /* Add padding for spacing inside the element */ - } - #root 
> div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.stChatMessage.st-emotion-cache-4oy321.eeusbqq4{ - background-color: rgb(233, 240, 255); - color: #152544; - - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.stChatMessage.st-emotion-cache-4oy321.eeusbqq4 > div.st-emotion-cache-14m9yky.eeusbqq3{ - color:#152544; - background-color: rgb(233, 240, 255); - } - #root > div:nth-child(1) > div.withScreencast > div > div > div > section.main.st-emotion-cache-uf99v8.ea3mdgi5 > div.block-container.st-emotion-cache-z5fcl4.ea3mdgi4 > div > div > div.element-container.st-emotion-cache-10gv909.e1f1d6gn3 > div > div > div > div.st-bg.st-b4.st-bh.st-co.st-bj.st-bk.st-bl.st-bm.st-bn.st-bo.st-bp.st-bq.st-br.st-b2.st-bs.st-av.st-ay.st-aw.st-ax.st-bt.st-bu.st-bv.st-bw.st-bx.st-by.st-bz.st-c0 > div{ - background-color: rgb(12, 51, 158); - } - #root > div:nth-child(1) > div.withScreencast > div > div > header{ - background-color: rgb(233, 240, 255); - color: #152544; -} - -#root > div:nth-child(1) > div.withScreencast > div > div > div > div > button > svg{ - color:#152544; - width: 20px -} -.sidebar{ - background-color: rgb(233, 240, 255); -} -.stButton { - /* Add styles for the stButton class */ - background-color: rgb(218, 218, 230); - color: #152544; - /* padding: 10px 20px; */ - border: 3px; - border-radius: 5px; - width: 40px; - cursor: pointer; - - -} -.row-widget { - /* Add styles for the row-widget class */ - margin-bottom: 10px; - background-color: rgb(233, 240, 255); - color:#152544; - -} - -#root > div:nth-child(1) > div.withScreencast > div > div > div > section.st-emotion-cache-nju155.eczjsme11 > div.st-emotion-cache-6qob1r.eczjsme3 > div.st-emotion-cache-16txtl3.eczjsme4 > div > div > div > div:nth-child(5){ - background-color:rgb(0, 0, 0) ; - color: #152544; -} - -#root > div:nth-child(1) > div.withScreencast > div > div > header > div.st-emotion-cache-zq5wmm.ezrtsby0 > div > div:nth-child(2) > button > div > div { - display: none; -} \ No newline at end of file diff --git a/test_app.py b/test_app.py new file mode 100644 index 0000000..a3c40b0 --- /dev/null +++ b/test_app.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python3 +""" +Simple test to verify the application components are working. 
+""" + +import sys +import os + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + + +def test_imports(): + """Test that all imports work correctly.""" + print("Testing imports...") + + try: + from src.config.constants import WELCOME_MESSAGE, DEFAULT_SESSION_NAME + + print("βœ… Constants imported successfully") + print(f" WELCOME_MESSAGE: {WELCOME_MESSAGE[:50]}...") + print(f" DEFAULT_SESSION_NAME: {DEFAULT_SESSION_NAME}") + except ImportError as e: + print(f"❌ Error importing constants: {e}") + return False + + try: + from src.services.routing_service import IntelligentRoutingService + + print("βœ… Routing service imported successfully") + except ImportError as e: + print(f"❌ Error importing routing service: {e}") + return False + + try: + from src.backend.orchestrator import BackendOrchestrator + + print("βœ… Orchestrator imported successfully") + except ImportError as e: + print(f"❌ Error importing orchestrator: {e}") + return False + + return True + + +def test_routing(): + """Test the routing service.""" + print("\nTesting routing service...") + + try: + from src.services.routing_service import IntelligentRoutingService + + routing_service = IntelligentRoutingService() + + # Test routing decisions + test_cases = [ + ("Hello", "CONVERSATION_AGENT"), + ("Show me all users", "SQL_AGENT"), + ("Analyze this CSV data", "CSV_AGENT"), + ] + + for query, expected in test_cases: + decision = routing_service.determine_agent(query, [], csv_loaded=False) + status = "βœ…" if decision.agent == expected else "❌" + print(f" {status} '{query}' β†’ {decision.agent} (expected: {expected})") + + return True + except Exception as e: + print(f"❌ Error testing routing: {e}") + return False + + +if __name__ == "__main__": + print("Querypls Application Test") + print("=" * 40) + + success = True + success &= test_imports() + success &= test_routing() + + print("\n" + "=" * 40) + if success: + print("βœ… All tests passed! Application is ready.") + else: + print("❌ Some tests failed. 
Please check the errors above.") diff --git a/test_routing_fix.py b/test_routing_fix.py new file mode 100644 index 0000000..4c425ae --- /dev/null +++ b/test_routing_fix.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +"""Test script to verify the routing fix is working.""" + +import sys +import os +import tempfile +import pandas as pd + +sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + +from src.services.routing_service import IntelligentRoutingService +from src.backend.orchestrator import BackendOrchestrator +from src.schemas.requests import NewChatRequest + + +def test_routing_only(): + """Test just the routing mechanism.""" + print("🧠 Testing Routing Mechanism") + print("=" * 40) + + routing_service = IntelligentRoutingService() + + test_cases = [ + ("Hello", False, "CONVERSATION_AGENT"), + ("What is the average salary?", True, "CSV_AGENT"), + ("Show me all users", False, "SQL_AGENT"), + ("Create a chart", True, "CSV_AGENT"), + ("SELECT * FROM users", False, "SQL_AGENT"), + ] + + for query, csv_loaded, expected in test_cases: + print(f"\nQuery: '{query}' (CSV loaded: {csv_loaded})") + decision = routing_service.determine_agent(query, [], csv_loaded=csv_loaded) + print(f"Expected: {expected}") + print(f"Actual: {decision.agent}") + print(f"Confidence: {decision.confidence}") + print(f"Reasoning: {decision.reasoning}") + + status = "βœ… PASS" if decision.agent == expected else "❌ FAIL" + print(f"Status: {status}") + + +def test_csv_analysis_with_real_data(): + """Test CSV analysis with actual CSV data.""" + print("\nπŸ“Š Testing CSV Analysis with Real Data") + print("=" * 40) + + # Create a temporary CSV file + data = { + 'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'], + 'age': [25, 30, 35, 28, 32], + 'salary': [50000, 60000, 70000, 55000, 65000], + 'department': ['IT', 'HR', 'IT', 'Finance', 'HR'] + } + + df = pd.DataFrame(data) + + # Create temporary CSV file + with tempfile.NamedTemporaryFile(mode='w', suffix='.csv', delete=False) as f: + df.to_csv(f.name, index=False) + csv_path = f.name + + print(f"Created test CSV: {csv_path}") + print("CSV Content:") + print(df.to_string(index=False)) + + # Create orchestrator and test CSV analysis + orchestrator = BackendOrchestrator() + + # Create session + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Test Session") + ) + session_id = session_info.session_id + print(f"\nCreated session: {session_id}") + + # Load CSV data + with open(csv_path, 'r') as f: + csv_content = f.read() + + result = orchestrator.load_csv_data(session_id, csv_content) + print(f"CSV Load Result: {result['status']}") + + # Test CSV analysis queries + test_queries = [ + "What is the average salary?", + "How many people are in each department?", + "Who has the highest salary?", + ] + + for query in test_queries: + print(f"\n--- Testing Query: '{query}' ---") + try: + response = orchestrator.generate_intelligent_response(session_id, query) + print(f"Response: {response.content}") + print(f"Response Type: {type(response.content)}") + + # Check if this is raw Python code (the old problem) + if "import" in response.content or "pd.read_csv" in response.content: + print("❌ ISSUE: Response contains raw Python code!") + else: + print("βœ… SUCCESS: Response is clean human-readable text!") + + except Exception as e: + print(f"❌ ERROR: {str(e)}") + + # Cleanup + os.unlink(csv_path) + + +def main(): + """Run all tests.""" + print("πŸš€ Testing Routing Fix") + print("=" * 50) + + # Test 1: Routing mechanism + test_routing_only() + + # 
Test 2: CSV analysis with real data + test_csv_analysis_with_real_data() + + print("\nπŸŽ‰ Testing completed!") + + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..034d1b0 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,4 @@ +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) diff --git a/tests/test_auth.py b/tests/test_auth.py deleted file mode 100644 index 2b9366a..0000000 --- a/tests/test_auth.py +++ /dev/null @@ -1,59 +0,0 @@ -import pytest -from unittest.mock import AsyncMock, patch -from httpx_oauth.clients.google import GoogleOAuth2 -from src.constant import * -from src.auth import ( - get_authorization_url, - get_access_token, - get_email, - get_login_str, -) - - -@pytest.mark.asyncio -async def test_get_authorization_url(): - client = GoogleOAuth2("client_id", "client_secret") - redirect_uri = "http://example.com/callback" - with patch.object( - client, "get_authorization_url", new=AsyncMock() - ) as mock_method: - await get_authorization_url(client, redirect_uri) - mock_method.assert_called_once_with( - redirect_uri, scope=["profile", "email"] - ) - - -@pytest.mark.asyncio -async def test_get_access_token(): - client = GoogleOAuth2("client_id", "client_secret") - redirect_uri = "http://example.com/callback" - code = "code" - with patch.object( - client, "get_access_token", new=AsyncMock() - ) as mock_method: - await get_access_token(client, redirect_uri, code) - mock_method.assert_called_once_with(code, redirect_uri) - - -@pytest.mark.asyncio -async def test_get_email(): - client = GoogleOAuth2("client_id", "client_secret") - token = "token" - with patch.object( - client, - "get_id_email", - new=AsyncMock(return_value=("user_id", "user_email")), - ) as mock_method: - user_id, user_email = await get_email(client, token) - mock_method.assert_called_once_with(token) - assert user_id == "user_id" - assert user_email == "user_email" - - -def test_get_login_str(): - with patch("asyncio.run") as mock_run: - mock_run.return_value = "authorization_url" - result = get_login_str() - mock_run.assert_called_once() - assert '' in result - assert "Login with Google" in result diff --git a/tests/test_backend.py b/tests/test_backend.py index 9bc9840..2ec72b4 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -2,7 +2,7 @@ from unittest.mock import patch, MagicMock import sys, os from src.backend import * -from src.constant import * +from src.config.constants import * sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) @@ -29,21 +29,3 @@ def mock_set_page_config(): def mock_oauth2_component(): with patch("streamlit_oauth.OAuth2Component") as mock_oauth2_component: yield mock_oauth2_component - - -def test_configure_page_styles(mock_open, mock_markdown, mock_set_page_config): - mock_open.return_value.__enter__.return_value.read.return_value = "test" - configure_page_styles("test_file") - mock_set_page_config.assert_called_once_with( - page_title="Querypls", page_icon="πŸ’¬", layout="wide" - ) - mock_markdown.assert_called() - mock_open.assert_called_once_with("test_file") - - -def test_hide_main_menu_and_footer(mock_markdown): - hide_main_menu_and_footer() - mock_markdown.assert_called_once_with( - """""", - unsafe_allow_html=True, - ) diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..5325048 
--- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,46 @@ +import pytest +import os +import sys +from datetime import datetime, timedelta + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from src.schemas.requests import ChatMessage, NewChatRequest +from src.schemas.responses import SessionInfo, ChatResponse +from src.backend.orchestrator import BackendOrchestrator + + +def test_create_new_session(): + orchestrator = BackendOrchestrator() + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Test Chat") + ) + assert session_info.session_name == "Test Chat" + assert session_info.session_id is not None + + +def test_list_sessions(): + orchestrator = BackendOrchestrator() + session1 = orchestrator.create_new_session(NewChatRequest(session_name="Chat 1")) + session2 = orchestrator.create_new_session(NewChatRequest(session_name="Chat 2")) + sessions = orchestrator.list_sessions() + assert len(sessions) == 2 + assert any(s.session_name == "Chat 1" for s in sessions) + assert any(s.session_name == "Chat 2" for s in sessions) + + +def test_health_check(): + orchestrator = BackendOrchestrator() + health = orchestrator.health_check() + assert health.status in ["healthy", "unhealthy"] + assert isinstance(health.version, str) + assert isinstance(health.timestamp, str) + + +def test_session_message_flow(): + orchestrator = BackendOrchestrator() + session_info = orchestrator.create_new_session( + NewChatRequest(session_name="Test Session") + ) + assert session_info.session_name == "Test Session" + assert session_info.session_id is not None diff --git a/tests/test_conversation_service.py b/tests/test_conversation_service.py new file mode 100644 index 0000000..d985557 --- /dev/null +++ b/tests/test_conversation_service.py @@ -0,0 +1,31 @@ +import pytest +from src.services.conversation_service import ConversationService + + +def test_is_conversational_query(): + service = ConversationService() + assert service.is_conversational_query("hello") is True + assert service.is_conversational_query("hi there") is True + assert service.is_conversational_query("how are you") is True + assert service.is_conversational_query("select * from users") is False + + +def test_get_conversational_response(): + service = ConversationService() + response = service.get_conversational_response("hello") + assert response is not None + assert len(response) > 0 + assert "hello" in response.lower() or "hi" in response.lower() + + +def test_get_conversational_response_help(): + service = ConversationService() + response = service.get_conversational_response("what can you do?") + assert "SQL" in response + assert "data analysis" in response.lower() + + +def test_get_conversational_response_thanks(): + service = ConversationService() + response = service.get_conversational_response("thank you") + assert "help you today" in response.lower() diff --git a/tests/test_csv_analysis.py b/tests/test_csv_analysis.py new file mode 100644 index 0000000..a4bcbe0 --- /dev/null +++ b/tests/test_csv_analysis.py @@ -0,0 +1,22 @@ +import pytest +from src.services.csv_analysis_tools import CSVAnalysisTools, CSVAnalysisContext + + +def test_load_csv_data(): + tools = CSVAnalysisTools() + csv_content = "name,age\nJohn,30\nJane,25" + result = tools.load_csv_data(csv_content, "test_session") + assert result["status"] == "success" + assert "name" in result["columns"] + assert "age" in result["columns"] + + +def test_get_csv_info(): + tools = CSVAnalysisTools() + csv_content = 
"name,age\nJohn,30\nJane,25" + tools.load_csv_data(csv_content, "test_session") + info = tools.get_csv_info("test_session") + assert info["status"] == "success" + assert "shape" in info + assert "columns" in info + assert "dtypes" in info diff --git a/tests/test_frontend.py b/tests/test_frontend.py deleted file mode 100644 index b6e4979..0000000 --- a/tests/test_frontend.py +++ /dev/null @@ -1,101 +0,0 @@ -import pytest -from unittest.mock import patch, MagicMock -import streamlit as st -from src.frontend import ( - display_logo_and_heading, - display_welcome_message, - handle_new_chat, - display_previous_chats, - create_message, - update_session_state, -) - - -@pytest.fixture -def mock_st(): - return MagicMock() - - -@pytest.fixture -def mock_db(): - return MagicMock() - - -class MockSessionState: - def __getitem__(self, key): - return getattr(self, key) - - def __setitem__(self, key, value): - setattr(self, key, value) - - -def initialize_session_state(messages=None, key=None, user_email=None): - st.session_state = MockSessionState() - st.session_state.messages = messages or [] - st.session_state.key = key - st.session_state.user_email = user_email - - -def test_display_logo_and_heading(mock_st): - with patch.object(st, "image") as mock_image: - display_logo_and_heading() - mock_image.assert_called_once_with("static/image/logo.png") - - -def test_display_welcome_message(mock_st): - with patch.object(st, "markdown") as mock_markdown: - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state( - messages=[ - {"role": "assistant", "content": "How may I help you?"} - ] - ) - display_welcome_message() - mock_markdown.assert_called_once_with( - "#### Welcome to \n ## πŸ›’πŸ’¬Querypls - Prompt to SQL" - ) - - -def test_handle_new_chat(mock_db, mock_st): - with patch("src.frontend.get_previous_chats") as mock_get_previous_chats: - mock_get_previous_chats.return_value = [] - with patch.object(st, "markdown") as mock_markdown, patch.object( - st, "button" - ) as mock_button: - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state( - messages=[], user_email="test@example.com" - ) - handle_new_chat(mock_db, max_chat_histories=5) - mock_markdown.assert_called_once_with( - " #### Remaining Chats: `5/5`" - ) - mock_button.assert_called_once_with("βž• New chat") - - -def test_create_message(): - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state(messages=[], key=None) - create_message() - assert st.session_state.messages == [ - {"role": "assistant", "content": "How may I help you?"} - ] - assert st.session_state.key == "key" - - -def test_update_session_state(mock_db): - chat = {"chat": [{"role": "user", "content": "Hello"}], "key": "new_key"} - with patch.object(st, "session_state", MockSessionState()): - initialize_session_state( - messages=[{"role": "assistant", "content": "How may I help you?"}], - key="old_key", - ) - with patch("src.frontend.database") as mock_database: - update_session_state(mock_db, chat) - mock_database.assert_called_once_with( - mock_db, - "old_key", - [{"role": "assistant", "content": "How may I help you?"}], - ) - assert st.session_state.messages == chat["chat"] - assert st.session_state.key == chat["key"] diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..90331d0 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,30 @@ +import pytest +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 
+ +from src.schemas.requests import ChatMessage, NewChatRequest, SQLGenerationRequest +from src.schemas.responses import ChatResponse, SessionInfo + + +def test_chat_message(): + msg = ChatMessage(role="user", content="test") + assert msg.role == "user" + assert msg.content == "test" + + +def test_new_chat_request(): + req = NewChatRequest(session_name="Test Session") + assert req.session_name == "Test Session" + + +def test_chat_response(): + resp = ChatResponse( + message_id="123", + content="test response", + timestamp="2024-01-01T00:00:00", + session_id="456", + ) + assert resp.content == "test response" + assert resp.session_id == "456" diff --git a/tests/test_schemas.py b/tests/test_schemas.py new file mode 100644 index 0000000..fea64eb --- /dev/null +++ b/tests/test_schemas.py @@ -0,0 +1,88 @@ +import pytest +import json +from src.schemas.requests import ( + SQLGenerationRequest, + ChatMessage, + ConversationHistory, + NewChatRequest, +) +from src.schemas.responses import ( + SQLQueryResponse, + ChatResponse, + ErrorResponse, + SessionInfo, + HealthCheckResponse, +) + + +def test_sql_generation_request(): + request = SQLGenerationRequest( + user_query="Show users", + conversation_history=[], + database_schema=None, + query_type=None, + ) + assert request.user_query == "Show users" + assert isinstance(request.conversation_history, list) + + +def test_chat_message(): + message = ChatMessage( + role="user", content="Hello", timestamp="2024-01-01T00:00:00", session_id="123" + ) + assert message.role == "user" + assert message.content == "Hello" + assert message.timestamp == "2024-01-01T00:00:00" + assert message.session_id == "123" + + +def test_conversation_history(): + history = ConversationHistory( + messages=[ChatMessage(role="user", content="Hello")], session_id="123" + ) + assert len(history.messages) == 1 + assert history.session_id == "123" + + +def test_new_chat_request(): + request = NewChatRequest(session_name="Test Chat", initial_context="SQL Testing") + assert request.session_name == "Test Chat" + assert request.initial_context == "SQL Testing" + + +def test_sql_query_response(): + response = SQLQueryResponse( + sql_query="SELECT * FROM users", + explanation="Get all users", + tables_used=["users"], + columns_selected=["*"], + query_type="SELECT", + complexity="SIMPLE", + ) + assert response.sql_query == "SELECT * FROM users" + assert response.explanation == "Get all users" + assert response.tables_used == ["users"] + + +def test_chat_response(): + response = ChatResponse( + message_id="123", + content="Hello", + timestamp="2024-01-01T00:00:00", + session_id="456", + ) + assert response.message_id == "123" + assert response.content == "Hello" + assert response.session_id == "456" + + +def test_health_check_response(): + response = HealthCheckResponse( + status="healthy", + version="1.0.0", + timestamp="2024-01-01T00:00:00", + services={"sql": "healthy"}, + ) + assert response.status == "healthy" + assert response.version == "1.0.0" + assert response.services["sql"] == "healthy" diff --git a/training/Querypls_prompt.py b/training/Querypls_prompt.py deleted file mode 100644 index b962eb0..0000000 --- a/training/Querypls_prompt.py +++ /dev/null @@ -1,28 +0,0 @@ -# !pip install langchain huggingface_hub > /dev/null - -import os - -huggingfacehub_api_token = "YOUR_API_TOKEN" - -# pip install huggingface_hub - -# pip install langchain - -from langchain import HuggingFaceHub - -repo_id = "tiiuae/falcon-7b-instruct" -llm = HuggingFaceHub( - 
huggingfacehub_api_token=huggingfacehub_api_token, - repo_id=repo_id, - model_kwargs={"temperature": 0.6, "max_new_tokens": 100}, -) - -from langchain import PromptTemplate, LLMChain - -template = "" -prompt = PromptTemplate(template=template, input_variables=["question"]) -llm_chain = LLMChain(prompt=prompt, llm=llm) - -question = "" - -print(llm_chain.run(question)) diff --git a/training/finetuning_querypls.py b/training/finetuning_querypls.py deleted file mode 100644 index af6003d..0000000 --- a/training/finetuning_querypls.py +++ /dev/null @@ -1,107 +0,0 @@ -# from huggingface_hub import notebook_login - -# notebook_login() - -from datasets import load_dataset, DatasetDict, Dataset -from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments -from trl import SFTTrainer -from peft import LoraConfig - -dataset = load_dataset("b-mc2/sql-create-context") - -dataset - -# dataset['train'][0] - -model_checkpoint = "stabilityai/StableBeluga-7B" -# Initialize the tokenizer and model -model = AutoModelForCausalLM.from_pretrained(model_checkpoint) - - -tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, max_length=512) -tokenizer.pad_token = tokenizer.eos_token -tokenizer.padding_side = "right" - -model.config.use_cache = False - -model.config.quantization_config.to_dict() - -lora_target_modules = [ - "query_key_value", - "dense", - "dense_h_to_4h", - "dense_4h_to_h", -] -config = LoraConfig( - r=16, # attention heads - lora_alpha=12, # alpha scaling - target_modules=lora_target_modules, - lora_dropout=0.05, - bias="none", - task_type="CAUSAL_LM", -) - -import random - -split_ratio = 0.8 -eval_ratio = 0.2 - -# the 30% subset -total_examples = len(dataset["train"]) -subset_size = int(total_examples * 0.2) -train_size = int(subset_size * split_ratio) -eval_size = subset_size - train_size -shuffled_indices = list(range(total_examples)) -random.shuffle(shuffled_indices) -training_set = dataset["train"].select(shuffled_indices[:train_size]) -evaluation_set = dataset["train"].select( - shuffled_indices[train_size : train_size + eval_size] -) -split_dataset = DatasetDict({"train": training_set, "eval": evaluation_set}) -split_dataset - -evaluation_set - -# hyperparameters -lr = 1e-4 -batch_size = 4 -num_epochs = 1 -training_args = TrainingArguments( - per_device_train_batch_size=batch_size, - per_device_eval_batch_size=batch_size, - optim="paged_adamw_32bit", - logging_steps=1, - learning_rate=lr, - fp16=True, - max_grad_norm=0.3, - num_train_epochs=num_epochs, - evaluation_strategy="steps", - eval_steps=0.2, - warmup_ratio=0.05, - save_strategy="epoch", - group_by_length=True, - output_dir="outputs", - report_to="tensorboard", - save_safetensors=True, - lr_scheduler_type="cosine", - seed=12, -) - -trainer = SFTTrainer( - model=model, - train_dataset=split_dataset["train"], - eval_dataset=split_dataset["eval"], - peft_config=config, - dataset_text_field="question", - max_seq_length=4096, - tokenizer=tokenizer, - args=training_args, -) - -# train model -trainer.train() - -model.push_to_hub("samadpls/querypls-prompt2sql") -tokenizer.push_to_hub("samadpls/querypls-prompt2sql") - -# DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu' diff --git a/utils/prompt.py b/utils/prompt.py new file mode 100644 index 0000000..a243db7 --- /dev/null +++ b/utils/prompt.py @@ -0,0 +1,207 @@ +""" +Instruction prompts for Querypls application. 
+""" + +# Intelligent routing prompt to determine which agent to use +ROUTING_PROMPT = """You are an intelligent router that determines which specialized agent should handle a user query. + +Analyze the user query and conversation history to determine the appropriate agent. + +## Available Agents: +1. **CONVERSATION_AGENT**: For greetings, casual chat, help requests, thanks, goodbyes +2. **SQL_AGENT**: For database queries, data retrieval, data manipulation, SQL generation +3. **CSV_AGENT**: For CSV data analysis, data visualization, Python code generation for CSV files + +## Decision Criteria: +- **CONVERSATION_AGENT**: Greetings, casual questions, help requests, thanks, goodbyes, general chat +- **SQL_AGENT**: Database queries, table operations, data retrieval, SQL-related questions +- **CSV_AGENT**: CSV analysis, data visualization, Python code for data analysis, file operations + +## Response Format: +{ + "agent": "CONVERSATION_AGENT|SQL_AGENT|CSV_AGENT", + "confidence": 0.95, + "reasoning": "Brief explanation of why this agent was chosen" +} + +## Examples: +- "Hello" β†’ CONVERSATION_AGENT +- "Show me all users" β†’ SQL_AGENT +- "Analyze this CSV data" β†’ CSV_AGENT +- "How are you?" β†’ CONVERSATION_AGENT +- "SELECT * FROM users" β†’ SQL_AGENT +- "Create a chart from the data" β†’ CSV_AGENT + +Respond only with the JSON object.""" + +CONVERSATION_PROMPT = """You are a friendly AI assistant for Querypls. Respond naturally and conversationally to user queries. + +## Your Role: +- Be warm, helpful, and engaging +- Keep responses concise but friendly +- Guide users to your SQL and CSV analysis capabilities when appropriate +- Don't generate code unless specifically asked + +## Response Guidelines: +- **Greetings**: Respond warmly and mention your capabilities +- **Help requests**: Explain what you can do (SQL generation, CSV analysis) +- **Thanks**: Be polite and encouraging +- **Goodbyes**: Be courteous and welcoming for future interactions +- **General questions**: Answer naturally, guide to your tools if relevant + +## Response Format: +{ + "message": "Your natural response to the user", + "response_type": "greeting|help|thanks|goodbye|general", + "suggest_next": "Optional suggestion for what they could do next" +} + +## Examples: +- User: "Hello" β†’ "Hi there! πŸ‘‹ I'm Querypls, your SQL and data analysis assistant. I can help you generate SQL queries or analyze CSV files. What would you like to work on today?" +- User: "How are you?" β†’ "I'm doing great, thank you for asking! 😊 I'm ready to help you with SQL queries or CSV data analysis. What can I assist you with?" +- User: "What can you do?" β†’ "I'm Querypls, your data analysis companion! πŸ—ƒοΈπŸ’¬ I can help you with SQL generation and CSV data analysis. Just upload a CSV file or ask me to write SQL queries!" + +Respond only with the JSON object.""" + +SQL_GENERATION_PROMPT = """You are a SQL expert developer. Generate appropriate SQL code based on the user query and conversation context. + +## Response Guidelines: +- Generate SQL queries for data-related questions +- Provide clear explanations of what the query does +- Include proper table and column information +- Handle different query types appropriately + +## Response Format +Your response must be in JSON format. 
+ +It must be an object and must contain these fields: +* `sql_query` - The generated SQL query as a string +* `explanation` - Brief explanation of what the query does +* `tables_used` - Array of table names used in the query +* `columns_selected` - Array of column names selected in the query +* `query_type` - Type of query (SELECT, INSERT, UPDATE, DELETE, etc.) +* `complexity` - Query complexity level (SIMPLE, MEDIUM, COMPLEX) +* `estimated_rows` - Estimated number of rows returned (if applicable) +* `execution_time` - Estimated execution time (optional) +* `warnings` - Array of warnings about the query (optional) + +## Example Response +{ + "sql_query": "SELECT * FROM users WHERE status = 'active'", + "explanation": "Retrieves all active users from the users table", + "tables_used": ["users"], + "columns_selected": ["*"], + "query_type": "SELECT", + "complexity": "SIMPLE", + "estimated_rows": "variable", + "execution_time": "fast", + "warnings": [] +} + +Respond only with the JSON object. Do not include any additional text or markdown formatting.""" + +CSV_ANALYSIS_PROMPT = """You are a Python data analysis expert. Generate SIMPLE, FOCUSED Python code that answers the user's specific question in a human-readable way. + +## Response Format +Your response must be in JSON format. + +It must be an object and must contain these fields: +* `python_code` - The generated Python code as a string (this will be EXECUTED automatically) +* `explanation` - Brief explanation of what the code does +* `expected_output` - What output is expected from the code +* `libraries_used` - Array of Python libraries used + +## CRITICAL GUIDELINES: +1. **KEEP CODE SUPER SIMPLE** - Maximum 5 lines of code +2. **NO FUNCTIONS OR CLASSES** - Write direct code only +3. **PRINT HUMAN-READABLE RESULTS** - Use print() with clear formatting +4. **ANSWER SPECIFIC QUESTION ONLY** - Don't do comprehensive analysis +5. **USE SIMPLE VARIABLES** - df, avg, count, total, etc. +6. 
**NO TECHNICAL JARGON** - Speak like talking to a person + +## Code Requirements: +- Use `pd.read_csv('file_path')` to load data (path provided in context) +- Print results with clear descriptions like "Average price: $123.45" +- For charts: save to `/tmp/querypls_session_csv_analysis_temp/chart.png` +- Use only: pandas, matplotlib.pyplot (as plt), numpy +- Keep each line simple and readable +- NO error handling functions - keep it basic + +## Example Responses: + +### For "average price": +{ + "python_code": "import pandas as pd\\ndf = pd.read_csv('/tmp/data.csv')\\navg = df['price'].mean()\\nprint(f'Average price: ${avg:,.2f}')", + "explanation": "Calculates and displays the average price", + "expected_output": "Average price: $1,234.56", + "libraries_used": ["pandas"] +} + +### For "show top 5 products": +{ + "python_code": "import pandas as pd\\ndf = pd.read_csv('/tmp/data.csv')\\ntop5 = df.nlargest(5, 'price')\\nprint('Top 5 most expensive products:')\\nprint(top5[['name', 'price']].to_string(index=False))", + "explanation": "Shows the 5 most expensive products", + "expected_output": "Top 5 most expensive products with names and prices", + "libraries_used": ["pandas"] +} + +### For "create chart": +{ + "python_code": "import pandas as pd\\nimport matplotlib.pyplot as plt\\nimport os\\nos.makedirs('/tmp/querypls_session_csv_analysis_temp', exist_ok=True)\\ndf = pd.read_csv('/tmp/data.csv')\\ndf['category'].value_counts().plot(kind='bar')\\nplt.title('Product Categories')\\nplt.savefig('/tmp/querypls_session_csv_analysis_temp/chart.png')\\nplt.show()\\nprint(f'Created chart showing {len(df[\"category\"].unique())} categories')", + "explanation": "Creates a bar chart of product categories", + "expected_output": "Bar chart and category count message", + "libraries_used": ["pandas", "matplotlib.pyplot"] +} + +## IMPORTANT RULES: +- **NO FUNCTIONS** - Write code directly, not inside functions +- **NO COMPLEX LOGIC** - Keep it simple and straightforward +- **HUMAN-READABLE OUTPUT** - Print clear, conversational results +- **ANSWER THE QUESTION** - Don't add extra analysis +- **USE f-strings** - For clear formatting like f'Total: {total}' +- **MAXIMUM 5 LINES** - Keep it super simple +- Use double backslashes (\\n) for newlines in JSON +- The code will be executed automatically +- Focus on answering the specific user question only + +Respond only with the JSON object.""" + +CODE_FIX_PROMPT = """You are a Python debugging expert. Fix Python code based on error messages. + +## Response Format +Your response must be in JSON format. + +It must be an object and must contain these fields: +* `python_code` - The fixed Python code as a string +* `explanation` - Brief explanation of what was fixed +* `expected_output` - What output is expected from the fixed code +* `libraries_used` - Array of Python libraries used + +## Guidelines +1. Identify the root cause of the error +2. Fix syntax errors, import issues, and logic problems +3. Ensure the code follows Python best practices +4. Add proper error handling if needed +5. Make sure the code works with the given CSV data structure +6. 
Test the logic and ensure it produces the expected output + +## Example Response +{ + "python_code": "import pandas as pd\\n\\n# Fixed code with proper error handling\\ntry:\\n df = pd.read_csv('data.csv')\\n print(f'Data shape: {df.shape}')\\nexcept FileNotFoundError:\\n print('CSV file not found')\\nexcept Exception as e:\\n print(f'Error: {e}')", + "explanation": "Added proper error handling for file reading and data loading", + "expected_output": "Data shape or appropriate error message", + "libraries_used": ["pandas"] +} + +Respond only with the JSON object. Do not include any additional text or markdown formatting.""" + +CSV_AGENT_PROMPT = """You are a data analysis expert. You can analyze CSV data using Python code. + +Available tools: +- load_csv_data: Load CSV data into a session +- generate_analysis_code: Generate Python code for data analysis +- execute_analysis_code: Execute Python code and get results +- fix_code_error: Fix code errors and retry +- get_csv_info: Get information about loaded CSV data + +Always provide clear explanations and handle errors gracefully."""
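
For quick reference, the flow exercised by the demo and test scripts above can be condensed into a few lines. This is a minimal sketch, assuming the package is importable from the repository root and that any model/API credentials the backend needs (e.g. a Groq key) are already configured in the environment:

```python
from src.backend.orchestrator import BackendOrchestrator
from src.services.routing_service import IntelligentRoutingService
from src.schemas.requests import NewChatRequest

# Inspect a routing decision directly; the fields used here
# (.agent, .confidence, .reasoning) are the ones checked in the tests above.
router = IntelligentRoutingService()
decision = router.determine_agent("What is the average salary?", [], csv_loaded=True)
print(decision.agent, decision.confidence, decision.reasoning)

# Drive the same flow end to end through the orchestrator:
# create a session, load CSV content, then ask a question about it.
orchestrator = BackendOrchestrator()
session = orchestrator.create_new_session(NewChatRequest(session_name="Demo"))

csv_content = "name,age,salary\nAlice,25,50000\nBob,30,60000"
load_result = orchestrator.load_csv_data(session.session_id, csv_content)
print(load_result["status"])

response = orchestrator.generate_intelligent_response(
    session.session_id, "Who has the highest salary?"
)
print(response.content)
```

The JSON contract spelled out in `ROUTING_PROMPT` could likewise be validated with a small Pydantic model along these lines. This is purely illustrative and assumes Pydantic v2; the models actually defined under `src/schemas` may differ:

```python
from pydantic import BaseModel, Field


class RoutingDecision(BaseModel):
    # Field names taken from the "Response Format" section of ROUTING_PROMPT.
    agent: str  # "CONVERSATION_AGENT" | "SQL_AGENT" | "CSV_AGENT"
    confidence: float = Field(ge=0.0, le=1.0)
    reasoning: str


raw = '{"agent": "SQL_AGENT", "confidence": 0.95, "reasoning": "Database query"}'
decision = RoutingDecision.model_validate_json(raw)
print(decision.agent, decision.confidence)
```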