From 2e7c49bf68be131624da0b882957e9679863bb23 Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Tue, 6 Jan 2026 13:52:44 +0100 Subject: [PATCH 01/10] Major update --- .gitignore | 11 +- README.md | 535 ++++++-- oai.py | 3397 ++++++++++++++++++++++++++++++++++++++++++---- requirements.txt | 64 +- 4 files changed, 3557 insertions(+), 450 deletions(-) diff --git a/.gitignore b/.gitignore index 107b564..3c2ad4a 100644 --- a/.gitignore +++ b/.gitignore @@ -22,8 +22,9 @@ Pipfile.lock # Consider if you want to include or exclude ._* *~.nib *~.xib -README.md.old -oai.zip + +# Added by author +*.zip .note diagnose.py *.log @@ -33,4 +34,8 @@ build* compiled/ images/oai-iOS-Default-1024x1024@1x.png images/oai.icon/ -b0.sh \ No newline at end of file +b0.sh +*.bak +*.old +*.sh +*.back diff --git a/README.md b/README.md index 0c91bd6..a3c2e5d 100644 --- a/README.md +++ b/README.md @@ -1,53 +1,77 @@ # oAI - OpenRouter AI Chat -A terminal-based chat interface for OpenRouter API with conversation management, cost tracking, and rich formatting. +A powerful terminal-based chat interface for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI agents to access local files and query SQLite databases directly. ## Description -oAI is a command-line chat application that provides an interactive interface to OpenRouter's AI models. It features conversation persistence, file attachments, export capabilities, and detailed session metrics. +oAI is a feature-rich command-line chat application that provides an interactive interface to OpenRouter's AI models. It now includes **MCP integration** for local file system access and read-only database querying, allowing AI to help with code analysis, data exploration, and more. ## Features -- Interactive chat with multiple AI models via OpenRouter -- Model selection with search functionality -- Conversation save/load/export (Markdown, JSON, HTML) -- File attachment support (code files and images) -- Session cost tracking and credit monitoring -- Rich terminal formatting with syntax highlighting -- Persistent command history -- Configurable system prompts and token limits -- SQLite-based configuration and conversation storage +### Core Features +- 🤖 Interactive chat with 300+ AI models via OpenRouter +- 🔍 Model selection with search and capability filtering +- 💾 Conversation save/load/export (Markdown, JSON, HTML) +- 📎 File attachment support (images, PDFs, code files) +- 💰 Session cost tracking and credit monitoring +- 🎨 Rich terminal formatting with syntax highlighting +- 📝 Persistent command history with search (Ctrl+R) +- ⚙️ Configurable system prompts and token limits +- 🗄️ SQLite-based configuration and conversation storage +- 🌐 Online mode (web search capabilities) +- 🧠 Conversation memory toggle (save costs with stateless mode) + +### NEW: MCP (Model Context Protocol) v2.1.0-beta +- 🔧 **File Mode**: AI can read, search, and list your local files + - Automatic .gitignore filtering + - Virtual environment exclusion (venv, node_modules, etc.) 
+ - Supports code files, text, JSON, YAML, and more + - Large file handling (auto-truncates >50KB) + +- 🗄️ **Database Mode**: AI can query your SQLite databases + - Read-only access (no data modification possible) + - Schema inspection (tables, columns, indexes) + - Full-text search across all tables + - SQL query execution (SELECT, JOINs, CTEs, subqueries) + - Query validation and timeout protection + - Result limiting (max 1000 rows) + +- 🔒 **Security Features**: + - Explicit folder/database approval required + - System directory blocking + - Read-only database access + - SQL injection protection + - Query timeout (5 seconds) ## Requirements -- Python 3.7 or higher +- Python 3.10-3.13 (3.14 not supported yet) - OpenRouter API key (get one at https://openrouter.ai) +- Function-calling model required for MCP features (GPT-4, Claude, etc.) -## Screenshot (from version 1.0) +## Screenshot [](https://gitlab.pm/rune/oai/src/branch/main/README.md) -Screenshot of `/help` screen. +*Screenshot from version 1.0 - MCP interface shows mode indicators like `[🔧 MCP: Files]` or `[🗄️ MCP: DB #1]`* ## Installation -### 1. Install Dependencies +### Option 1: From Source (Recommended for Development) -Use the included `requirements.txt` file to install the dependencies: +#### 1. Install Dependencies ```bash pip install -r requirements.txt ``` -### 2. Make the Script Executable +#### 2. Make Executable ```bash chmod +x oai.py ``` -### 3. Copy to PATH - -Copy the script to a directory in your `$PATH` environment variable. Common locations include: +#### 3. Copy to PATH ```bash # Option 1: System-wide (requires sudo) @@ -57,145 +81,408 @@ sudo cp oai.py /usr/local/bin/oai mkdir -p ~/.local/bin cp oai.py ~/.local/bin/oai -# Add to PATH if not already (add to ~/.bashrc or ~/.zshrc) +# Add to PATH if needed (add to ~/.bashrc or ~/.zshrc) export PATH="$HOME/.local/bin:$PATH" ``` -### 4. Verify Installation +#### 4. Verify Installation + +```bash +oai --version +``` + +### Option 2: Pre-built Binaries + +Download platform-specific binaries: +- **macOS (Apple Silicon)**: `oai_vx.x.x_mac_arm64.zip` +- **Linux (x86_64)**: `oai_vx.x.x-linux-x86_64.zip` + +```bash +# Extract and install +unzip oai_vx.x.x_mac_arm64.zip # or `oai_vx.x.x-linux-x86_64.zip` +chmod +x oai +mkdir -p ~/.local/bin +mv oai ~/.local/bin/ +``` + +### Option 3: Build Your Own Binary + +```bash +# Install build dependencies +pip install -r requirements.txt +pip install nuitka ordered-set zstandard + +# Run build script +chmod +x build.sh +./build.sh + +# Binary will be in dist/oai +cp dist/oai ~/.local/bin/ +``` + +### Alternative: Shell Alias + +```bash +# Add to ~/.bashrc or ~/.zshrc +alias oai='python3 /path/to/oai.py' +``` + +## Quick Start + +### First Run Setup ```bash oai ``` -### 5. Alternative Installation (for *nix systems) +On first run, you'll be prompted to enter your OpenRouter API key. -If you have issues with the above method you can add an alias in your `.bashrc`, `.zshrc` etc. - -```bash -alias oai='python3 ' -``` - -On first run, you will be prompted to enter your OpenRouter API key. - -### 6. Use Binaries - -You can also just download the supplied binary for either Mac wit Mx (M1, M2 etc) `oai_mac_arm64.zip` and follow [#3](https://gitlab.pm/rune/oai#3-copy-to-path). Or download for Linux (64bit) `oai_linux_x86_64.zip` and also follow [#3](https://gitlab.pm/rune/oai#3-copy-to-path). 
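If you would rather keep the dependencies isolated from your system Python, a virtual environment works just as well. This is a generic sketch (the venv location is only an example; use any interpreter in the supported 3.10–3.13 range):

```bash
# Create an isolated environment and install the requirements into it
python3.12 -m venv ~/.venvs/oai
~/.venvs/oai/bin/pip install -r requirements.txt

# Run oAI with that interpreter (or point your alias/wrapper at it)
~/.venvs/oai/bin/python oai.py
```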
- -## Usage - -### Starting the Application +### Basic Usage ```bash +# Start chatting oai + +# Select a model +You> /model + +# Enable MCP for file access +You> /mcp enable +You> /mcp add ~/Documents + +# Ask AI to help with files +[🔧 MCP: Files] You> List all Python files in Documents +[🔧 MCP: Files] You> Read and explain main.py + +# Switch to database mode +You> /mcp add db ~/myapp/data.db +You> /mcp db 1 +[🗄️ MCP: DB #1] You> Show me all tables +[🗄️ MCP: DB #1] You> Find all users created this month ``` -### Basic Commands +## MCP Guide -``` -/help Show all available commands -/model Select an AI model -/config api Set OpenRouter API key -exit Quit the application +### File Mode (Default) + +**Setup:** +```bash +/mcp enable # Start MCP server +/mcp add ~/Projects # Grant access to folder +/mcp add ~/Documents # Add another folder +/mcp list # View all allowed folders ``` -### Configuration - -All configuration is stored in `~/.config/oai/`: -- `oai_config.db` - SQLite database for settings and conversations -- `oai.log` - Application log file -- `history.txt` - Command history - -### Common Workflows - -**Select a Model:** +**Natural Language Usage:** ``` -/model +"List all Python files in Projects" +"Read and explain config.yaml" +"Search for files containing 'TODO'" +"What's in my Documents folder?" ``` -**Paste from clipboard:** -Paste and send content to model -``` -/paste +**Available Tools:** +- `read_file` - Read complete file contents +- `list_directory` - List files/folders (recursive optional) +- `search_files` - Search by name or content + +**Features:** +- ✅ Automatic .gitignore filtering +- ✅ Skips virtual environments (venv, node_modules) +- ✅ Handles large files (auto-truncates >50KB) +- ✅ Cross-platform (macOS, Linux, Windows via WSL) + +### Database Mode + +**Setup:** +```bash +/mcp add db ~/app/database.db # Add SQLite database +/mcp db list # View all databases +/mcp db 1 # Switch to database #1 ``` -Paste with prompt and send content to model +**Natural Language Usage:** ``` -/paste Analyze this text +"Show me all tables in this database" +"Find records mentioning 'error'" +"How many users registered last week?" +"Get the schema for the orders table" +"Show me the 10 most recent transactions" ``` -**Start Chatting:** -``` -You> Hello, how are you? -``` +**Available Tools:** +- `inspect_database` - View schema, tables, columns, indexes +- `search_database` - Full-text search across tables +- `query_database` - Execute read-only SQL queries -**Attach Files:** -``` -You> Debug this code @/path/to/script.py -You> Analyze this image @/path/to/image.png -``` +**Supported Queries:** +- ✅ SELECT statements +- ✅ JOINs (INNER, LEFT, RIGHT, FULL) +- ✅ Subqueries +- ✅ CTEs (Common Table Expressions) +- ✅ Aggregations (COUNT, SUM, AVG, etc.) +- ✅ WHERE, GROUP BY, HAVING, ORDER BY, LIMIT +- ❌ INSERT/UPDATE/DELETE (blocked for safety) -**Save Conversation:** -``` -/save my_conversation -``` +### Mode Management -**Export to File:** +```bash +/mcp status # Show current mode, stats, folders/databases +/mcp files # Switch to file mode +/mcp db # Switch to database mode +/mcp gitignore on # Enable .gitignore filtering (default) +/mcp remove 2 # Remove folder/database by number ``` -/export md notes.md -/export json backup.json -/export html report.html -``` - -**View Session Stats:** - -``` -/stats -/credits -``` - -**Prevous commands input:** - -Use the up/down arrows to see earlier `/`commands and earlier input to model and `` to execute the same command or resend the same input. 
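To make the MCP database tools concrete, here is a sketch of how a natural-language request typically maps onto tool calls (the prompt, table name, and query are illustrative, not verbatim output):

```
[🗄️ MCP: DB #1] You> How many orders were placed this week?

  → inspect_database   discovers the tables and the columns of `orders`
  → query_database     SELECT COUNT(*) FROM orders
                       WHERE created_at >= date('now', '-7 days')
  ← the AI answers with the count returned by the query
```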
## Command Reference -Use `/help` within the application for a complete command reference organized by category: -- Session Commands -- Model Commands -- Configuration -- Token & System -- Conversation Management -- Monitoring & Stats -- File Attachments +### Session Commands +``` +/help [command] Show help menu or detailed command help +/help mcp Comprehensive MCP guide +/clear or /cl Clear terminal screen (or Ctrl+L) +/memory on|off Toggle conversation memory (save costs) +/online on|off Enable/disable web search +/paste [prompt] Paste clipboard content +/retry Resend last prompt +/reset Clear history and system prompt +/prev View previous response +/next View next response +``` + +### MCP Commands +``` +/mcp enable Start MCP server +/mcp disable Stop MCP server +/mcp status Show comprehensive status +/mcp add Add folder for file access +/mcp add db Add SQLite database +/mcp list List all folders +/mcp db list List all databases +/mcp db Switch to database mode +/mcp files Switch to file mode +/mcp remove Remove folder/database +/mcp gitignore on Enable .gitignore filtering +``` + +### Model Commands +``` +/model [search] Select/change AI model +/info [model_id] Show model details (pricing, capabilities) +``` + +### Configuration +``` +/config View all settings +/config api Set API key +/config model Set default model +/config online Set default online mode (on|off) +/config stream Enable/disable streaming (on|off) +/config maxtoken Set max token limit +/config costwarning Set cost warning threshold ($) +/config loglevel Set log level (debug/info/warning/error) +/config log Set log file size (MB) +``` + +### Conversation Management +``` +/save Save conversation +/load Load saved conversation +/delete Delete conversation +/list List saved conversations +/export md|json|html Export conversation +``` + +### Token & System +``` +/maxtoken [value] Set session token limit +/system [prompt] Set system prompt (use 'clear' to reset) +/middleout on|off Enable prompt compression +``` + +### Monitoring +``` +/stats View session statistics +/credits Check OpenRouter credits +``` + +### File Attachments +``` +@/path/to/file Attach file (images, PDFs, code) + +Examples: + Debug @script.py + Analyze @data.json + Review @screenshot.png +``` ## Configuration Options -- API Key: `/config api` -- Base URL: `/config url` -- Streaming: `/config stream on|off` -- Default Model: `/config model` -- Cost Warning: `/config costwarning ` -- Max Token Limit: `/config maxtoken ` +All configuration stored in `~/.config/oai/`: -## File Support +### Files +- `oai_config.db` - SQLite database (settings, conversations, MCP config) +- `oai.log` - Application logs (rotating, configurable size) +- `history.txt` - Command history (searchable with Ctrl+R) -**Supported Code Extensions:** -.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm, .xml, .json, .yaml, .yml, .md, .txt +### Key Settings +- **API Key**: OpenRouter authentication +- **Default Model**: Auto-select on startup +- **Streaming**: Real-time response display +- **Max Tokens**: Global and session limits +- **Cost Warning**: Alert threshold for expensive requests +- **Online Mode**: Default web search setting +- **Log Level**: debug/info/warning/error/critical +- **Log Size**: Rotating file size in MB -**Image Support:** -Any image format with proper MIME type (PNG, JPEG, GIF, etc.) 
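Because settings and conversations live in a single SQLite file, you can also inspect them directly with the `sqlite3` CLI. Treat this as read-only (oAI manages these tables itself via `/config`, `/save`, etc.); the table names below match the schema `oai.py` creates:

```bash
# Stored settings (key/value pairs such as api_key, default_model, max_token)
sqlite3 ~/.config/oai/oai_config.db "SELECT key, value FROM config;"

# Saved conversations
sqlite3 ~/.config/oai/oai_config.db "SELECT id, name, timestamp FROM conversation_sessions;"
```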
+## Supported File Types -## Data Storage +### Code Files +`.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm` -- Configuration: `~/.config/oai/oai_config.db` -- Logs: `~/.config/oai/oai.log` -- History: `~/.config/oai/history.txt` +### Data Files +`.json, .yaml, .yml, .xml, .csv, .txt, .md` + +### Images +All standard formats: PNG, JPEG, JPG, GIF, WEBP, BMP + +### Documents +PDF (models with document support) + +### Size Limits +- Images: 10 MB max +- Code/Text: Auto-truncates files >50KB +- Binary data: Displayed as `` + +## MCP Security + +### Access Control +- ✅ Explicit folder/database approval required +- ✅ System directories blocked automatically +- ✅ User confirmation for each addition +- ✅ .gitignore patterns respected (file mode) + +### Database Safety +- ✅ Read-only mode (cannot modify data) +- ✅ SQL query validation (blocks INSERT/UPDATE/DELETE) +- ✅ Query timeout (5 seconds max) +- ✅ Result limits (1000 rows max) +- ✅ Database opened in `mode=ro` + +### File System Safety +- ✅ Read-only access (no write/delete) +- ✅ Virtual environment exclusion +- ✅ Build artifact filtering +- ✅ Maximum file size (10 MB) + +## Tips & Tricks + +### Command History +- **↑/↓ arrows**: Navigate previous commands +- **Ctrl+R**: Search command history +- **Auto-complete**: Start typing `/` for command suggestions + +### Cost Optimization +```bash +/memory off # Disable context (stateless mode) +/maxtoken 1000 # Limit response length +/config costwarning 0.01 # Set alert threshold +``` + +### MCP Best Practices +```bash +# Check status frequently +/mcp status + +# Use specific paths to reduce search time +"List Python files in Projects/app/" # Better than +"List all Python files" # Slower + +# Database queries - be specific +"SELECT * FROM users LIMIT 10" # Good +"SELECT * FROM users" # May hit row limit +``` + +### Debugging +```bash +# Enable debug logging +/config loglevel debug + +# Check log file +tail -f ~/.config/oai/oai.log + +# View MCP statistics +/mcp status # Shows tool call counts +``` + +## Troubleshooting + +### MCP Not Working +```bash +# 1. Check if MCP is installed +python3 -c "import mcp; print('MCP OK')" + +# 2. Verify model supports function calling +/info # Look for "tools" in supported parameters + +# 3. Check MCP status +/mcp status + +# 4. 
Review logs +tail ~/.config/oai/oai.log +``` + +### Import Errors +```bash +# Reinstall dependencies +pip install --force-reinstall -r requirements.txt +``` + +### Binary Issues (macOS) +```bash +# Remove quarantine +xattr -cr ~/.local/bin/oai + +# Check security settings +# System Settings > Privacy & Security > "Allow Anyway" +``` + +### Database Errors +```bash +# Verify it's a valid SQLite database +sqlite3 database.db ".tables" + +# Check file permissions +ls -la database.db +``` + +## Version History + +### v2.1.0-beta (Current) +- ✨ **NEW**: MCP (Model Context Protocol) integration +- ✨ **NEW**: File system access (read, search, list) +- ✨ **NEW**: SQLite database querying (read-only) +- ✨ **NEW**: Dual mode support (Files & Database) +- ✨ **NEW**: .gitignore filtering +- ✨ **NEW**: Binary data handling in databases +- ✨ **NEW**: Mode indicators in prompt +- ✨ **NEW**: Comprehensive `/help mcp` guide +- 🔧 Improved error handling for tool calls +- 🔧 Enhanced logging for MCP operations +- 🔧 Statistics tracking for tool usage + +### v1.9.6 +- Base version with core chat functionality +- Conversation management +- File attachments +- Cost tracking +- Export capabilities ## License MIT License -Copyright (c) 2024 Rune Olsen +Copyright (c) 2024-2025 Rune Olsen Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal @@ -221,12 +508,22 @@ Full license: https://opensource.org/licenses/MIT **Rune Olsen** -Blog: https://blog.rune.pm +- Blog: https://blog.rune.pm +- Project: https://iurl.no/oai -## Version +## Contributing -1.0 +Contributions welcome! Please: +1. Fork the repository +2. Create a feature branch +3. Submit a pull request with detailed description -## Support +## Acknowledgments -For issues, questions, or contributions, visit https://iurl.no/oai and create an issue. \ No newline at end of file +- OpenRouter team for the unified AI API +- Rich library for beautiful terminal output +- MCP community for the protocol specification + +--- + +**Star ⭐ this project if you find it useful!** diff --git a/oai.py b/oai.py index 26b92ef..e679ce5 100644 --- a/oai.py +++ b/oai.py @@ -2,7 +2,8 @@ import sys import os import requests -import time # For response time tracking +import time +import asyncio from pathlib import Path from typing import Optional, List, Dict, Any import typer @@ -21,20 +22,34 @@ import sqlite3 import json import datetime import logging -from logging.handlers import RotatingFileHandler # Added for log rotation +from logging.handlers import RotatingFileHandler from prompt_toolkit import PromptSession from prompt_toolkit.history import FileHistory from rich.logging import RichHandler from prompt_toolkit.auto_suggest import AutoSuggestFromHistory from packaging import version as pkg_version -import io # Added for custom handler +import io +import platform +import shutil +import subprocess +import fnmatch +import signal -# App version. Changes by author with new releases. -version = '1.9.6' +# MCP imports +try: + from mcp import ClientSession, StdioServerParameters + from mcp.client.stdio import stdio_client + MCP_AVAILABLE = True +except ImportError: + MCP_AVAILABLE = False + print("Warning: MCP library not found. 
Install with: pip install mcp") + +# App version +version = '2.1.0-beta' app = typer.Typer() -# Application identification for OpenRouter +# Application identification APP_NAME = "oAI" APP_URL = "https://iurl.no/oai" @@ -42,25 +57,26 @@ APP_URL = "https://iurl.no/oai" home = Path.home() config_dir = home / '.config' / 'oai' cache_dir = home / '.cache' / 'oai' -history_file = config_dir / 'history.txt' # Persistent input history file +history_file = config_dir / 'history.txt' database = config_dir / 'oai_config.db' log_file = config_dir / 'oai.log' -# Create dirs if needed +# Create dirs config_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True) -# Rich console for chat UI (separate from logging) +# Rich console console = Console() -# Valid commands list for validation +# Valid commands VALID_COMMANDS = { '/retry', '/online', '/memory', '/paste', '/export', '/save', '/load', '/delete', '/list', '/prev', '/next', '/stats', '/middleout', '/reset', - '/info', '/model', '/maxtoken', '/system', '/config', '/credits', '/clear', '/cl', '/help' + '/info', '/model', '/maxtoken', '/system', '/config', '/credits', '/clear', + '/cl', '/help', '/mcp' } -# Detailed command help database +# Command help database (COMPLETE - includes MCP comprehensive guide) COMMAND_HELP = { '/clear': { 'aliases': ['/cl'], @@ -74,13 +90,155 @@ COMMAND_HELP = { }, '/help': { 'description': 'Display help information for commands.', - 'usage': '/help [command]', + 'usage': '/help [command|topic]', 'examples': [ ('Show all commands', '/help'), ('Get help for a specific command', '/help /model'), - ('Get help for config', '/help /config'), + ('Get detailed MCP help', '/help mcp'), ], - 'notes': 'Use /help without arguments to see the full command list, or /help for detailed information about a specific command.' + 'notes': 'Use /help without arguments to see the full command list, /help for detailed command info, or /help mcp for comprehensive MCP documentation.' + }, + 'mcp': { + 'description': 'Complete guide to MCP (Model Context Protocol) - file access and database querying for AI agents.', + 'usage': 'See detailed examples below', + 'examples': [], + 'notes': ''' +MCP (Model Context Protocol) gives your AI assistant direct access to: + • Local files and folders (read, search, list) + • SQLite databases (inspect, search, query) + +╭─────────────────────────────────────────────────────────╮ +│ 🗂️ FILE MODE │ +╰─────────────────────────────────────────────────────────╯ + +SETUP: + /mcp enable Start MCP server + /mcp add ~/Documents Grant access to folder + /mcp add ~/Code/project Add another folder + /mcp list View all allowed folders + +USAGE (just ask naturally): + "List all Python files in my Code folder" + "Read the contents of Documents/notes.txt" + "Search for files containing 'budget'" + "What's in my Documents folder?" + +FEATURES: + ✓ Automatically respects .gitignore patterns + ✓ Skips virtual environments (venv, node_modules, etc.) 
+ ✓ Handles large files (auto-truncates >50KB) + ✓ Cross-platform (macOS, Linux, Windows) + +MANAGEMENT: + /mcp remove ~/Desktop Remove folder access + /mcp remove 2 Remove by number + /mcp gitignore on|off Toggle .gitignore filtering + /mcp status Show comprehensive status + +╭─────────────────────────────────────────────────────────╮ +│ 🗄️ DATABASE MODE │ +╰─────────────────────────────────────────────────────────╯ + +SETUP: + /mcp add db ~/app/data.db Add specific database + /mcp add db ~/.config/oai/oai_config.db + /mcp db list View all databases + +SWITCH TO DATABASE: + /mcp db 1 Work with database #1 + /mcp db 2 Switch to database #2 + /mcp files Switch back to file mode + +USAGE (after selecting a database): + "Show me all tables in this database" + "Find any records mentioning 'Santa Clause'" + "How many users are in the database?" + "Show me the schema for the orders table" + "Get the 10 most recent transactions" + +FEATURES: + ✓ Read-only mode (no data modification possible) + ✓ Smart schema inspection + ✓ Full-text search across all tables + ✓ SQL query execution (SELECT only) + ✓ JOINs, subqueries, CTEs supported + ✓ Automatic query timeout (5 seconds) + ✓ Result limits (max 1000 rows) + +SAFETY: + • All queries are read-only (INSERT/UPDATE/DELETE blocked) + • Database opened in read-only mode + • Query validation prevents dangerous operations + • Timeout protection prevents infinite loops + +╭─────────────────────────────────────────────────────────╮ +│ 💡 TIPS & TRICKS │ +╰─────────────────────────────────────────────────────────╯ + +MODE INDICATORS: + [🔧 MCP: Files] You're in file mode + [🗄️ MCP: DB #1] You're querying database #1 + +QUICK REFERENCE: + /mcp status See current mode, stats, folders/databases + /mcp files Switch to file mode (default) + /mcp db Switch to database mode + /mcp db list List all databases with details + /mcp list List all folders + +TROUBLESHOOTING: + • No results? Check /mcp status to see what's accessible + • Wrong mode? Use /mcp files or /mcp db to switch + • Database errors? Ensure file exists and is valid SQLite + • .gitignore not working? Check /mcp status shows patterns + +SECURITY NOTES: + • MCP only accesses explicitly added folders/databases + • File mode: read-only access (no write/delete) + • Database mode: SELECT queries only (no modifications) + • System directories are blocked automatically + • Each addition requires your explicit confirmation + +For command-specific help: /help /mcp + ''' + }, + '/mcp': { + 'description': 'Manage MCP (Model Context Protocol) for local file access and SQLite database querying. When enabled with a function-calling model, AI can automatically search, read, list files, and query databases. 
Supports two modes: Files (default) and Database.', + 'usage': '/mcp [args]', + 'examples': [ + ('Enable MCP server', '/mcp enable'), + ('Disable MCP server', '/mcp disable'), + ('Show MCP status and current mode', '/mcp status'), + ('', ''), + ('━━━ FILE MODE ━━━', ''), + ('Add folder for file access', '/mcp add ~/Documents'), + ('Remove folder by path', '/mcp remove ~/Desktop'), + ('Remove folder by number', '/mcp remove 2'), + ('List allowed folders', '/mcp list'), + ('Toggle .gitignore filtering', '/mcp gitignore on'), + ('', ''), + ('━━━ DATABASE MODE ━━━', ''), + ('Add SQLite database', '/mcp add db ~/app/data.db'), + ('List all databases', '/mcp db list'), + ('Switch to database #1', '/mcp db 1'), + ('Remove database', '/mcp remove db 2'), + ('Switch back to file mode', '/mcp files'), + ], + 'notes': '''MCP allows AI to read local files and query SQLite databases. Works automatically with function-calling models (GPT-4, Claude, etc.). + +FILE MODE (default): +- Automatically loads and respects .gitignore patterns +- Skips virtual environments and build artifacts +- Supports search, read, and list operations + +DATABASE MODE: +- Read-only access (no data modification) +- Execute SELECT queries with JOINs, subqueries, CTEs +- Full-text search across all tables +- Schema inspection and data exploration +- Automatic query validation and timeout protection + +Use /help mcp for comprehensive guide with examples.''' }, '/memory': { 'description': 'Toggle conversation memory. When ON, the AI remembers conversation history. When OFF, each request is independent (saves tokens and cost).', @@ -272,7 +430,7 @@ COMMAND_HELP = { }, } -# Supported code file extensions +# Supported code extensions SUPPORTED_CODE_EXTENSIONS = { '.py', '.js', '.ts', '.cs', '.java', '.c', '.cpp', '.h', '.hpp', '.rb', '.ruby', '.php', '.swift', '.kt', '.kts', '.go', @@ -280,16 +438,16 @@ SUPPORTED_CODE_EXTENSIONS = { '.elm', '.xml', '.json', '.yaml', '.yml', '.md', '.txt' } -# Session metrics constants (per 1M tokens, in USD; adjustable) +# Pricing MODEL_PRICING = { - 'input': 3.0, # $3/M input tokens (adjustable) - 'output': 15.0 # $15/M output tokens (adjustable) + 'input': 3.0, + 'output': 15.0 } -LOW_CREDIT_RATIO = 0.1 # Warn if credits left < 10% of total -LOW_CREDIT_AMOUNT = 1.0 # Warn if credits left < $1 in absolute terms -HIGH_COST_WARNING = "cost_warning_threshold" # Configurable key for cost threshold, default $0.01 +LOW_CREDIT_RATIO = 0.1 +LOW_CREDIT_AMOUNT = 1.0 +HIGH_COST_WARNING = "cost_warning_threshold" -# Valid log levels mapping +# Valid log levels VALID_LOG_LEVELS = { 'debug': logging.DEBUG, 'info': logging.INFO, @@ -298,14 +456,18 @@ VALID_LOG_LEVELS = { 'critical': logging.CRITICAL } -# DB configuration -database = config_dir / 'oai_config.db' -DB_FILE = str(database) +# System directories to block (cross-platform) +SYSTEM_DIRS_BLACKLIST = { + '/System', '/Library', '/private', '/usr', '/bin', '/sbin', # macOS + '/boot', '/dev', '/proc', '/sys', '/root', # Linux + 'C:\\Windows', 'C:\\Program Files', 'C:\\Program Files (x86)' # Windows +} +# Database functions def create_table_if_not_exists(): - """Ensure the config and conversation_sessions tables exist.""" + """Ensure tables exist.""" os.makedirs(config_dir, exist_ok=True) - with sqlite3.connect(DB_FILE) as conn: + with sqlite3.connect(str(database)) as conn: conn.execute('''CREATE TABLE IF NOT EXISTS config ( key TEXT PRIMARY KEY, value TEXT NOT NULL @@ -314,32 +476,106 @@ def create_table_if_not_exists(): id INTEGER PRIMARY KEY AUTOINCREMENT, 
name TEXT NOT NULL, timestamp TEXT NOT NULL, - data TEXT NOT NULL -- JSON of session_history + data TEXT NOT NULL + )''') + # MCP configuration table + conn.execute('''CREATE TABLE IF NOT EXISTS mcp_config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + )''') + # MCP statistics table + conn.execute('''CREATE TABLE IF NOT EXISTS mcp_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + tool_name TEXT NOT NULL, + folder TEXT, + success INTEGER NOT NULL, + error_message TEXT + )''') + # MCP databases table + conn.execute('''CREATE TABLE IF NOT EXISTS mcp_databases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + path TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, + size INTEGER, + tables TEXT, + added_timestamp TEXT NOT NULL )''') conn.commit() def get_config(key: str) -> Optional[str]: create_table_if_not_exists() - with sqlite3.connect(DB_FILE) as conn: + with sqlite3.connect(str(database)) as conn: cursor = conn.execute('SELECT value FROM config WHERE key = ?', (key,)) result = cursor.fetchone() return result[0] if result else None def set_config(key: str, value: str): create_table_if_not_exists() - with sqlite3.connect(DB_FILE) as conn: + with sqlite3.connect(str(database)) as conn: conn.execute('INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)', (key, value)) conn.commit() -# ============================================================================ -# ROTATING RICH HANDLER - Combines RotatingFileHandler with Rich formatting -# ============================================================================ +def get_mcp_config(key: str) -> Optional[str]: + """Get MCP configuration value.""" + create_table_if_not_exists() + with sqlite3.connect(str(database)) as conn: + cursor = conn.execute('SELECT value FROM mcp_config WHERE key = ?', (key,)) + result = cursor.fetchone() + return result[0] if result else None + +def set_mcp_config(key: str, value: str): + """Set MCP configuration value.""" + create_table_if_not_exists() + with sqlite3.connect(str(database)) as conn: + conn.execute('INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)', (key, value)) + conn.commit() + +def log_mcp_stat(tool_name: str, folder: Optional[str], success: bool, error_message: Optional[str] = None): + """Log MCP tool usage statistics.""" + create_table_if_not_exists() + timestamp = datetime.datetime.now().isoformat() + with sqlite3.connect(str(database)) as conn: + conn.execute( + 'INSERT INTO mcp_stats (timestamp, tool_name, folder, success, error_message) VALUES (?, ?, ?, ?, ?)', + (timestamp, tool_name, folder, 1 if success else 0, error_message) + ) + conn.commit() + +def get_mcp_stats() -> Dict[str, Any]: + """Get MCP usage statistics.""" + create_table_if_not_exists() + with sqlite3.connect(str(database)) as conn: + cursor = conn.execute(''' + SELECT + COUNT(*) as total_calls, + SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads, + SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists, + SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches, + SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects, + SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches, + SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries, + MAX(timestamp) as last_used + FROM mcp_stats + ''') + row = cursor.fetchone() + return { + 'total_calls': row[0] or 0, + 'reads': row[1] or 0, + 'lists': row[2] or 0, + 'searches': row[3] or 0, + 'db_inspects': row[4] or 0, + 'db_searches': row[5] 
or 0, + 'db_queries': row[6] or 0, + 'last_used': row[7] + } + +# Rotating Rich Handler class RotatingRichHandler(RotatingFileHandler): - """Custom handler that combines RotatingFileHandler with Rich formatting.""" + """Custom handler combining RotatingFileHandler with Rich formatting.""" def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - # Create a Rich console that writes to a string buffer self.rich_console = Console(file=io.StringIO(), width=120, force_terminal=False) self.rich_handler = RichHandler( console=self.rich_console, @@ -351,46 +587,31 @@ class RotatingRichHandler(RotatingFileHandler): def emit(self, record): try: - # Let RichHandler format the record self.rich_handler.emit(record) - - # Get the formatted output from the string buffer output = self.rich_console.file.getvalue() - - # Clear the buffer for next use self.rich_console.file.seek(0) self.rich_console.file.truncate(0) - - # Write the Rich-formatted output to our rotating file if output: self.stream.write(output) self.flush() - except Exception: self.handleError(record) -# ============================================================================ -# LOGGING SETUP - MUST BE DONE AFTER CONFIG IS LOADED -# ============================================================================ - -# Load log configuration from DB FIRST (before creating handler) +# Logging setup LOG_MAX_SIZE_MB = int(get_config('log_max_size_mb') or "10") LOG_BACKUP_COUNT = int(get_config('log_backup_count') or "2") LOG_LEVEL_STR = get_config('log_level') or "info" LOG_LEVEL = VALID_LOG_LEVELS.get(LOG_LEVEL_STR.lower(), logging.INFO) -# Global reference to the handler for dynamic reloading app_handler = None app_logger = None def setup_logging(): - """Setup or reset logging configuration with current settings.""" + """Setup or reset logging configuration.""" global app_handler, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, app_logger - # Get the root logger root_logger = logging.getLogger() - # Remove existing handler if present if app_handler is not None: root_logger.removeHandler(app_handler) try: @@ -398,13 +619,12 @@ def setup_logging(): except: pass - # Check if log file needs immediate rotation + # Check if log needs rotation if os.path.exists(log_file): current_size = os.path.getsize(log_file) max_bytes = LOG_MAX_SIZE_MB * 1024 * 1024 if current_size >= max_bytes: - # Perform immediate rotation import shutil timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') backup_file = f"{log_file}.{timestamp}" @@ -413,7 +633,7 @@ def setup_logging(): except Exception as e: print(f"Warning: Could not rotate log file: {e}") - # Clean up old backups if exceeding limit + # Clean old backups log_dir = os.path.dirname(log_file) log_basename = os.path.basename(log_file) backup_pattern = f"{log_basename}.*" @@ -421,7 +641,6 @@ def setup_logging(): import glob backups = sorted(glob.glob(os.path.join(log_dir, backup_pattern))) - # Keep only the most recent backups while len(backups) > LOG_BACKUP_COUNT: oldest = backups.pop(0) try: @@ -429,7 +648,6 @@ def setup_logging(): except: pass - # Create new handler with current settings app_handler = RotatingRichHandler( filename=str(log_file), maxBytes=LOG_MAX_SIZE_MB * 1024 * 1024, @@ -437,15 +655,11 @@ def setup_logging(): encoding='utf-8' ) - # Set handler level to NOTSET so it processes all records app_handler.setLevel(logging.NOTSET) - - # Configure root logger - set to WARNING to suppress third-party library noise root_logger.setLevel(logging.WARNING) 
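+    # Root logger stays at WARNING so noisy third-party libraries are filtered out;
+    # the dedicated 'oai_app' logger below carries the user-configured LOG_LEVEL.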
root_logger.addHandler(app_handler) - # Suppress noisy third-party loggers - # These libraries create DEBUG logs that pollute our log file + # Suppress noisy loggers logging.getLogger('asyncio').setLevel(logging.WARNING) logging.getLogger('urllib3').setLevel(logging.WARNING) logging.getLogger('requests').setLevel(logging.WARNING) @@ -454,19 +668,14 @@ def setup_logging(): logging.getLogger('openai').setLevel(logging.WARNING) logging.getLogger('openrouter').setLevel(logging.WARNING) - # Get or create app logger and set its level (this filters what gets logged) app_logger = logging.getLogger("oai_app") app_logger.setLevel(LOG_LEVEL) - # Don't propagate to avoid root logger filtering app_logger.propagate = True return app_logger -# Initial logging setup -app_logger = setup_logging() - def set_log_level(level_str: str) -> bool: - """Set the application log level. Returns True if successful.""" + """Set application log level.""" global LOG_LEVEL, LOG_LEVEL_STR, app_logger level_str_lower = level_str.lower() if level_str_lower not in VALID_LOG_LEVELS: @@ -474,40 +683,1911 @@ def set_log_level(level_str: str) -> bool: LOG_LEVEL = VALID_LOG_LEVELS[level_str_lower] LOG_LEVEL_STR = level_str_lower - # Update the logger level immediately if app_logger: app_logger.setLevel(LOG_LEVEL) return True def reload_logging_config(): - """Reload logging configuration from database and reinitialize handler.""" + """Reload logging configuration.""" global LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, LOG_LEVEL_STR, app_logger - # Reload from database LOG_MAX_SIZE_MB = int(get_config('log_max_size_mb') or "10") LOG_BACKUP_COUNT = int(get_config('log_backup_count') or "2") LOG_LEVEL_STR = get_config('log_level') or "info" LOG_LEVEL = VALID_LOG_LEVELS.get(LOG_LEVEL_STR.lower(), logging.INFO) - # Reinitialize logging app_logger = setup_logging() - return app_logger -# ============================================================================ -# END OF LOGGING SETUP -# ============================================================================ - +app_logger = setup_logging() logger = logging.getLogger(__name__) -def check_for_updates(current_version: str) -> str: - """ - Check if a new version is available using semantic versioning. +# Load configs +API_KEY = get_config('api_key') +OPENROUTER_BASE_URL = get_config('base_url') or "https://openrouter.ai/api/v1" +STREAM_ENABLED = get_config('stream_enabled') or "on" +DEFAULT_MODEL_ID = get_config('default_model') +MAX_TOKEN = int(get_config('max_token') or "100000") +COST_WARNING_THRESHOLD = float(get_config(HIGH_COST_WARNING) or "0.01") +DEFAULT_ONLINE_MODE = get_config('default_online_mode') or "off" + +# Fetch models +models_data = [] +text_models = [] +try: + headers = { + "Authorization": f"Bearer {API_KEY}", + "HTTP-Referer": APP_URL, + "X-Title": APP_NAME + } if API_KEY else { + "HTTP-Referer": APP_URL, + "X-Title": APP_NAME + } + response = requests.get(f"{OPENROUTER_BASE_URL}/models", headers=headers) + response.raise_for_status() + models_data = response.json()["data"] + text_models = [m for m in models_data if "modalities" not in m or "video" not in (m.get("modalities") or [])] + selected_model_default = None + if DEFAULT_MODEL_ID: + selected_model_default = next((m for m in text_models if m["id"] == DEFAULT_MODEL_ID), None) + if not selected_model_default: + console.print(f"[bold yellow]Warning: Default model '{DEFAULT_MODEL_ID}' unavailable. 
Use '/config model'.[/]") +except Exception as e: + models_data = [] + text_models = [] + app_logger.error(f"Failed to fetch models: {e}") + +# ============================================================================ +# MCP INTEGRATION CLASSES +# ============================================================================ + +class CrossPlatformMCPConfig: + """Handle OS-specific MCP configuration.""" - Returns: - Formatted status string for display - """ + def __init__(self): + self.system = platform.system() + self.is_macos = self.system == "Darwin" + self.is_linux = self.system == "Linux" + self.is_windows = self.system == "Windows" + + app_logger.info(f"Detected OS: {self.system}") + + def get_default_allowed_dirs(self) -> List[Path]: + """Get safe default directories.""" + home = Path.home() + + if self.is_macos: + return [ + home / "Documents", + home / "Desktop", + home / "Downloads" + ] + elif self.is_linux: + dirs = [home / "Documents"] + + try: + for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]: + result = subprocess.run( + ["xdg-user-dir", xdg_dir], + capture_output=True, + text=True, + timeout=1 + ) + if result.returncode == 0: + dir_path = Path(result.stdout.strip()) + if dir_path.exists(): + dirs.append(dir_path) + except: + dirs.extend([ + home / "Desktop", + home / "Downloads" + ]) + + return list(set(dirs)) + + elif self.is_windows: + return [ + home / "Documents", + home / "Desktop", + home / "Downloads" + ] + + return [home] + + def get_python_command(self) -> str: + """Get Python command.""" + import sys + return sys.executable + + def get_filesystem_warning(self) -> str: + """Get OS-specific security warning.""" + if self.is_macos: + return """ +⚠️ macOS Security Notice: +The Filesystem MCP server needs access to your selected folder. +You may see a security prompt - click 'Allow' to proceed. +(System Settings > Privacy & Security > Files and Folders) +""" + elif self.is_linux: + return """ +⚠️ Linux Security Notice: +The Filesystem MCP server will access your selected folder. +Ensure oAI has appropriate file permissions. +""" + elif self.is_windows: + return """ +⚠️ Windows Security Notice: +The Filesystem MCP server will access your selected folder. +You may need to grant file access permissions. 
+""" + return "" + + def normalize_path(self, path: str) -> Path: + """Normalize path for current OS.""" + return Path(os.path.expanduser(path)).resolve() + + def is_system_directory(self, path: Path) -> bool: + """Check if path is a system directory.""" + path_str = str(path) + for blocked in SYSTEM_DIRS_BLACKLIST: + if path_str.startswith(blocked): + return True + return False + + def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool: + """Check if path is within allowed directories.""" + try: + requested = requested_path.resolve() + + for allowed in allowed_dirs: + try: + allowed_resolved = allowed.resolve() + requested.relative_to(allowed_resolved) + return True + except ValueError: + continue + + return False + except: + return False + + def get_folder_stats(self, folder: Path) -> Dict[str, Any]: + """Get folder statistics.""" + try: + if not folder.exists() or not folder.is_dir(): + return {'exists': False} + + file_count = 0 + total_size = 0 + + for item in folder.rglob('*'): + if item.is_file(): + file_count += 1 + try: + total_size += item.stat().st_size + except: + pass + + return { + 'exists': True, + 'file_count': file_count, + 'total_size': total_size, + 'size_mb': total_size / (1024 * 1024) + } + except Exception as e: + app_logger.error(f"Error getting folder stats for {folder}: {e}") + return {'exists': False, 'error': str(e)} + + +class GitignoreParser: + """Parse .gitignore files and check if paths should be ignored.""" + + def __init__(self): + self.patterns = [] # List of (pattern, is_negation, source_dir) + + def add_gitignore(self, gitignore_path: Path): + """Parse and add patterns from a .gitignore file.""" + if not gitignore_path.exists(): + return + + try: + source_dir = gitignore_path.parent + with open(gitignore_path, 'r', encoding='utf-8') as f: + for line_num, line in enumerate(f, 1): + line = line.rstrip('\n\r') + + # Skip empty lines and comments + if not line or line.startswith('#'): + continue + + # Check for negation pattern + is_negation = line.startswith('!') + if is_negation: + line = line[1:] + + # Remove leading slash (make relative to gitignore location) + if line.startswith('/'): + line = line[1:] + + self.patterns.append((line, is_negation, source_dir)) + + app_logger.debug(f"Loaded {len(self.patterns)} patterns from {gitignore_path}") + except Exception as e: + app_logger.warning(f"Error reading {gitignore_path}: {e}") + + def should_ignore(self, path: Path) -> bool: + """Check if a path should be ignored based on gitignore patterns.""" + if not self.patterns: + return False + + ignored = False + + for pattern, is_negation, source_dir in self.patterns: + # Only apply pattern if path is under the source directory + try: + rel_path = path.relative_to(source_dir) + except ValueError: + # Path is not relative to this gitignore's directory + continue + + rel_path_str = str(rel_path) + + # Check if pattern matches + if self._match_pattern(pattern, rel_path_str, path.is_dir()): + if is_negation: + ignored = False # Negation patterns un-ignore + else: + ignored = True + + return ignored + + def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool: + """Match a gitignore pattern against a path.""" + # Directory-only pattern (ends with /) + if pattern.endswith('/'): + if not is_dir: + return False + pattern = pattern[:-1] + + # ** matches any number of directories + if '**' in pattern: + # Convert ** to regex-like pattern + pattern_parts = pattern.split('**') + if len(pattern_parts) == 2: + prefix, suffix = 
pattern_parts + # Match if path starts with prefix and ends with suffix + if prefix: + if not path.startswith(prefix.rstrip('/')): + return False + if suffix: + suffix = suffix.lstrip('/') + if not (path.endswith(suffix) or f'/{suffix}' in path): + return False + return True + + # Direct match + if fnmatch.fnmatch(path, pattern): + return True + + # Match as subdirectory pattern + if '/' not in pattern: + # Pattern without / matches in any directory + parts = path.split('/') + if any(fnmatch.fnmatch(part, pattern) for part in parts): + return True + + return False + + +class SQLiteQueryValidator: + """Validate SQLite queries for read-only safety.""" + + @staticmethod + def is_safe_query(query: str) -> tuple[bool, str]: + """ + Validate that query is a safe read-only SELECT. + Returns (is_safe, error_message) + """ + query_upper = query.strip().upper() + + # Must start with SELECT or WITH (for CTEs) + if not (query_upper.startswith('SELECT') or query_upper.startswith('WITH')): + return False, "Only SELECT queries are allowed (including WITH/CTE)" + + # Block dangerous keywords even in SELECT + dangerous_keywords = [ + 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', + 'ALTER', 'TRUNCATE', 'REPLACE', 'ATTACH', 'DETACH', + 'PRAGMA', 'VACUUM', 'REINDEX' + ] + + # Check for dangerous keywords (but allow them in string literals) + # Simple check: look for keywords outside of quotes + query_no_strings = re.sub(r"'[^']*'", '', query_upper) + query_no_strings = re.sub(r'"[^"]*"', '', query_no_strings) + + for keyword in dangerous_keywords: + if re.search(r'\b' + keyword + r'\b', query_no_strings): + return False, f"Keyword '{keyword}' not allowed in read-only mode" + + return True, "" + # P2 +class MCPFilesystemServer: + """MCP Filesystem Server with file access and SQLite database querying.""" + + def __init__(self, allowed_folders: List[Path]): + self.allowed_folders = allowed_folders + self.config = CrossPlatformMCPConfig() + self.max_file_size = 10 * 1024 * 1024 # 10MB limit + self.max_list_items = 1000 # Max items to return in list_directory + self.respect_gitignore = True # Default enabled + + # Initialize gitignore parser + self.gitignore_parser = GitignoreParser() + self._load_gitignores() + + # SQLite configuration + self.max_query_timeout = 5 # seconds + self.max_query_results = 1000 # max rows + self.default_query_limit = 100 # default rows + + app_logger.info(f"MCP Filesystem Server initialized with {len(allowed_folders)} folders") + + def _load_gitignores(self): + """Load all .gitignore files from allowed folders.""" + gitignore_count = 0 + for folder in self.allowed_folders: + if not folder.exists(): + continue + + # Load root .gitignore + root_gitignore = folder / '.gitignore' + if root_gitignore.exists(): + self.gitignore_parser.add_gitignore(root_gitignore) + gitignore_count += 1 + app_logger.info(f"Loaded .gitignore from {folder}") + + # Load nested .gitignore files + try: + for gitignore_path in folder.rglob('.gitignore'): + if gitignore_path != root_gitignore: + self.gitignore_parser.add_gitignore(gitignore_path) + gitignore_count += 1 + app_logger.debug(f"Loaded nested .gitignore: {gitignore_path}") + except Exception as e: + app_logger.warning(f"Error loading nested .gitignores from {folder}: {e}") + + if gitignore_count > 0: + app_logger.info(f"Loaded {gitignore_count} .gitignore file(s) with {len(self.gitignore_parser.patterns)} total patterns") + + def reload_gitignores(self): + """Reload all .gitignore files (call when allowed_folders changes).""" + self.gitignore_parser = 
GitignoreParser() + self._load_gitignores() + app_logger.info("Reloaded .gitignore patterns") + + def is_allowed_path(self, path: Path) -> bool: + """Check if path is allowed.""" + return self.config.is_safe_path(path, self.allowed_folders) + + async def read_file(self, file_path: str) -> Dict[str, Any]: + """Read file contents.""" + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + app_logger.warning(error_msg) + log_mcp_stat('read_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + app_logger.warning(error_msg) + log_mcp_stat('read_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + app_logger.warning(error_msg) + log_mcp_stat('read_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + # Check if file should be ignored + if self.respect_gitignore and self.gitignore_parser.should_ignore(path): + error_msg = f"File ignored by .gitignore: {path}" + app_logger.warning(error_msg) + log_mcp_stat('read_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + file_size = path.stat().st_size + if file_size > self.max_file_size: + error_msg = f"File too large: {file_size / (1024*1024):.1f}MB (max: {self.max_file_size / (1024*1024):.0f}MB)" + app_logger.warning(error_msg) + log_mcp_stat('read_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + # Try to read as text + try: + content = path.read_text(encoding='utf-8') + + # If file is large (>50KB), provide summary instead of full content + max_content_size = 50 * 1024 # 50KB + if file_size > max_content_size: + lines = content.split('\n') + total_lines = len(lines) + + # Return first 500 lines + last 100 lines with notice + head_lines = 500 + tail_lines = 100 + + if total_lines > (head_lines + tail_lines): + truncated_content = ( + '\n'.join(lines[:head_lines]) + + f"\n\n... 
[TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted to stay within size limits] ...\n\n" + + '\n'.join(lines[-tail_lines:]) + ) + + app_logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)") + log_mcp_stat('read_file', str(path.parent), True) + + return { + 'content': truncated_content, + 'path': str(path), + 'size': file_size, + 'truncated': True, + 'total_lines': total_lines, + 'lines_shown': head_lines + tail_lines, + 'note': f'File truncated: showing first {head_lines} and last {tail_lines} lines of {total_lines} total' + } + + app_logger.info(f"Read file: {path} ({file_size} bytes)") + log_mcp_stat('read_file', str(path.parent), True) + return { + 'content': content, + 'path': str(path), + 'size': file_size + } + except UnicodeDecodeError: + error_msg = f"Cannot decode file as UTF-8: {path}" + app_logger.warning(error_msg) + log_mcp_stat('read_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + except Exception as e: + error_msg = f"Error reading file: {e}" + app_logger.error(error_msg) + log_mcp_stat('read_file', file_path, False, str(e)) + return {'error': error_msg} + + async def list_directory(self, dir_path: str, recursive: bool = False) -> Dict[str, Any]: + """List directory contents.""" + try: + path = self.config.normalize_path(dir_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + app_logger.warning(error_msg) + log_mcp_stat('list_directory', str(path), False, error_msg) + return {'error': error_msg} + + if not path.exists(): + error_msg = f"Directory not found: {path}" + app_logger.warning(error_msg) + log_mcp_stat('list_directory', str(path), False, error_msg) + return {'error': error_msg} + + if not path.is_dir(): + error_msg = f"Not a directory: {path}" + app_logger.warning(error_msg) + log_mcp_stat('list_directory', str(path), False, error_msg) + return {'error': error_msg} + + items = [] + pattern = '**/*' if recursive else '*' + + # Directories to skip (hardcoded common patterns) + skip_dirs = { + '.venv', 'venv', 'env', 'virtualenv', + 'site-packages', 'dist-packages', + '__pycache__', '.pytest_cache', '.mypy_cache', + 'node_modules', '.git', '.svn', + '.idea', '.vscode', + 'build', 'dist', 'eggs', '.eggs' + } + + def should_skip_path(item_path: Path) -> bool: + """Check if path should be skipped.""" + # Check hardcoded skip directories + path_parts = item_path.parts + if any(part in skip_dirs for part in path_parts): + return True + + # Check gitignore patterns (if enabled) + if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): + return True + + return False + + for item in path.glob(pattern): + # Stop if we hit the limit + if len(items) >= self.max_list_items: + break + + # Skip excluded directories + if should_skip_path(item): + continue + + try: + stat = item.stat() + items.append({ + 'name': item.name, + 'path': str(item), + 'type': 'directory' if item.is_dir() else 'file', + 'size': stat.st_size if item.is_file() else 0, + 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() + }) + except: + continue + + truncated = len(items) >= self.max_list_items + + app_logger.info(f"Listed directory: {path} ({len(items)} items, recursive={recursive}, truncated={truncated})") + log_mcp_stat('list_directory', str(path), True) + + result = { + 'path': str(path), + 'items': items, + 'count': len(items), + 'truncated': truncated + } + + if truncated: + result['note'] = f'Results limited to 
{self.max_list_items} items. Use more specific search path or disable recursive mode.' + + return result + + except Exception as e: + error_msg = f"Error listing directory: {e}" + app_logger.error(error_msg) + log_mcp_stat('list_directory', dir_path, False, str(e)) + return {'error': error_msg} + + async def search_files(self, pattern: str, search_path: Optional[str] = None, content_search: bool = False) -> Dict[str, Any]: + """Search for files.""" + try: + # Determine search roots + if search_path: + path = self.config.normalize_path(search_path) + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + app_logger.warning(error_msg) + log_mcp_stat('search_files', str(path), False, error_msg) + return {'error': error_msg} + search_roots = [path] + else: + search_roots = self.allowed_folders + + matches = [] + + # Directories to skip (virtual environments, caches, etc.) + skip_dirs = { + '.venv', 'venv', 'env', 'virtualenv', + 'site-packages', 'dist-packages', + '__pycache__', '.pytest_cache', '.mypy_cache', + 'node_modules', '.git', '.svn', + '.idea', '.vscode', + 'build', 'dist', 'eggs', '.eggs' + } + + def should_skip_path(item_path: Path) -> bool: + """Check if path should be skipped.""" + # Check hardcoded skip directories + path_parts = item_path.parts + if any(part in skip_dirs for part in path_parts): + return True + + # Check gitignore patterns (if enabled) + if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): + return True + + return False + + for root in search_roots: + if not root.exists(): + continue + + # Filename search + if not content_search: + for item in root.rglob(pattern): + if item.is_file() and not should_skip_path(item): + try: + stat = item.stat() + matches.append({ + 'path': str(item), + 'name': item.name, + 'size': stat.st_size, + 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() + }) + except: + continue + else: + # Content search (slower) + for item in root.rglob('*'): + if item.is_file() and not should_skip_path(item): + try: + # Skip large files + if item.stat().st_size > self.max_file_size: + continue + + # Try to read and search content + try: + content = item.read_text(encoding='utf-8') + if pattern.lower() in content.lower(): + stat = item.stat() + matches.append({ + 'path': str(item), + 'name': item.name, + 'size': stat.st_size, + 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() + }) + except: + continue + except: + continue + + search_type = "content" if content_search else "filename" + app_logger.info(f"Searched files: pattern='{pattern}', type={search_type}, found={len(matches)}") + log_mcp_stat('search_files', str(search_roots[0]) if search_roots else None, True) + + return { + 'pattern': pattern, + 'search_type': search_type, + 'matches': matches, + 'count': len(matches) + } + + except Exception as e: + error_msg = f"Error searching files: {e}" + app_logger.error(error_msg) + log_mcp_stat('search_files', search_path or "all", False, str(e)) + return {'error': error_msg} + + # ======================================================================== + # SQLite DATABASE METHODS + # ======================================================================== + + async def inspect_database(self, db_path: str, table_name: Optional[str] = None) -> Dict[str, Any]: + """Inspect SQLite database schema.""" + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + app_logger.warning(error_msg) + 
log_mcp_stat('inspect_database', str(path.parent), False, error_msg) + return {'error': error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + app_logger.warning(error_msg) + log_mcp_stat('inspect_database', str(path.parent), False, error_msg) + return {'error': error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + app_logger.warning(error_msg) + log_mcp_stat('inspect_database', str(path.parent), False, error_msg) + return {'error': error_msg} + + try: + if table_name: + # Get info for specific table + cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + result = cursor.fetchone() + + if not result: + return {'error': f"Table '{table_name}' not found"} + + create_sql = result[0] + + # Get column info + cursor.execute(f"PRAGMA table_info({table_name})") + columns = [] + for row in cursor.fetchall(): + columns.append({ + 'id': row[0], + 'name': row[1], + 'type': row[2], + 'not_null': bool(row[3]), + 'default_value': row[4], + 'primary_key': bool(row[5]) + }) + + # Get indexes + cursor.execute(f"PRAGMA index_list({table_name})") + indexes = [] + for row in cursor.fetchall(): + indexes.append({ + 'name': row[1], + 'unique': bool(row[2]) + }) + + # Get row count + cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + row_count = cursor.fetchone()[0] + + app_logger.info(f"Inspected table: {table_name} in {path}") + log_mcp_stat('inspect_database', str(path.parent), True) + + return { + 'database': str(path), + 'table': table_name, + 'create_sql': create_sql, + 'columns': columns, + 'indexes': indexes, + 'row_count': row_count + } + else: + # Get all tables + cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name") + tables = [] + for row in cursor.fetchall(): + table = row[0] + # Get row count + cursor.execute(f"SELECT COUNT(*) FROM {table}") + count = cursor.fetchone()[0] + tables.append({ + 'name': table, + 'row_count': count + }) + + # Get indexes + cursor.execute("SELECT name, tbl_name FROM sqlite_master WHERE type='index' ORDER BY name") + indexes = [{'name': row[0], 'table': row[1]} for row in cursor.fetchall()] + + # Get triggers + cursor.execute("SELECT name, tbl_name FROM sqlite_master WHERE type='trigger' ORDER BY name") + triggers = [{'name': row[0], 'table': row[1]} for row in cursor.fetchall()] + + # Get database size + db_size = path.stat().st_size + + app_logger.info(f"Inspected database: {path} ({len(tables)} tables)") + log_mcp_stat('inspect_database', str(path.parent), True) + + return { + 'database': str(path), + 'size': db_size, + 'size_mb': db_size / (1024 * 1024), + 'tables': tables, + 'table_count': len(tables), + 'indexes': indexes, + 'triggers': triggers + } + finally: + conn.close() + + except Exception as e: + error_msg = f"Error inspecting database: {e}" + app_logger.error(error_msg) + log_mcp_stat('inspect_database', db_path, False, str(e)) + return {'error': error_msg} + + async def search_database(self, db_path: str, search_term: str, table_name: Optional[str] = None, column_name: Optional[str] = None) -> Dict[str, Any]: + """Search for a value across database tables.""" + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + app_logger.warning(error_msg) + log_mcp_stat('search_database', str(path.parent), False, error_msg) + return 
{'error': error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + app_logger.warning(error_msg) + log_mcp_stat('search_database', str(path.parent), False, error_msg) + return {'error': error_msg} + + try: + matches = [] + + # Get tables to search + if table_name: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) + tables = [row[0] for row in cursor.fetchall()] + if not tables: + return {'error': f"Table '{table_name}' not found"} + else: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + + for table in tables: + # Get columns for this table + cursor.execute(f"PRAGMA table_info({table})") + columns = [row[1] for row in cursor.fetchall()] + + # Filter columns if specified + if column_name: + if column_name not in columns: + continue + columns = [column_name] + + # Search in each column + for column in columns: + try: + # Use LIKE for partial matching + query = f"SELECT * FROM {table} WHERE {column} LIKE ? LIMIT {self.default_query_limit}" + cursor.execute(query, (f'%{search_term}%',)) + results = cursor.fetchall() + + if results: + # Get column names + col_names = [desc[0] for desc in cursor.description] + + for row in results: + # Convert row to dict, handling binary data + row_dict = {} + for col_name, value in zip(col_names, row): + if isinstance(value, bytes): + # Convert bytes to hex string or base64 + try: + # Try to decode as UTF-8 first + row_dict[col_name] = value.decode('utf-8') + except UnicodeDecodeError: + # If binary, show as hex string + row_dict[col_name] = f"" + else: + row_dict[col_name] = value + + matches.append({ + 'table': table, + 'column': column, + 'row': row_dict + }) + except sqlite3.Error: + # Skip columns that can't be searched (e.g., BLOBs) + continue + + app_logger.info(f"Searched database: {path} for '{search_term}' - found {len(matches)} matches") + log_mcp_stat('search_database', str(path.parent), True) + + result = { + 'database': str(path), + 'search_term': search_term, + 'matches': matches, + 'count': len(matches) + } + + if table_name: + result['table_filter'] = table_name + if column_name: + result['column_filter'] = column_name + + if len(matches) >= self.default_query_limit: + result['note'] = f'Results limited to {self.default_query_limit} matches. Use more specific search or query_database for full results.' 
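+
+                # Illustrative note (assumption, not part of the original patch): a typical
+                # result returned just below might look like
+                #   {'database': '/path/to/app.db', 'search_term': 'alice', 'count': 1,
+                #    'matches': [{'table': 'users', 'column': 'name', 'row': {...}}]}
+                # The field names mirror the dict built above; the table, column and search
+                # values in this sketch are hypothetical.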
+ + return result + + finally: + conn.close() + + except Exception as e: + error_msg = f"Error searching database: {e}" + app_logger.error(error_msg) + log_mcp_stat('search_database', db_path, False, str(e)) + return {'error': error_msg} + + async def query_database(self, db_path: str, query: str, limit: Optional[int] = None) -> Dict[str, Any]: + """Execute a read-only SQL query on database.""" + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + app_logger.warning(error_msg) + log_mcp_stat('query_database', str(path.parent), False, error_msg) + return {'error': error_msg} + + # Validate query + is_safe, error_msg = SQLiteQueryValidator.is_safe_query(query) + if not is_safe: + app_logger.warning(f"Unsafe query rejected: {error_msg}") + log_mcp_stat('query_database', str(path.parent), False, error_msg) + return {'error': error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + app_logger.warning(error_msg) + log_mcp_stat('query_database', str(path.parent), False, error_msg) + return {'error': error_msg} + + try: + # Set query timeout + def timeout_handler(signum, frame): + raise TimeoutError(f"Query exceeded {self.max_query_timeout} second timeout") + + # Note: signal.alarm only works on Unix-like systems + if hasattr(signal, 'SIGALRM'): + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(self.max_query_timeout) + + try: + # Execute query + cursor.execute(query) + + # Get results + result_limit = min(limit or self.default_query_limit, self.max_query_results) + results = cursor.fetchmany(result_limit) + + # Get column names + columns = [desc[0] for desc in cursor.description] if cursor.description else [] + + # Convert to list of dicts, handling binary data + rows = [] + for row in results: + row_dict = {} + for col_name, value in zip(columns, row): + if isinstance(value, bytes): + # Convert bytes to readable format + try: + # Try to decode as UTF-8 first + row_dict[col_name] = value.decode('utf-8') + except UnicodeDecodeError: + # If binary, show as hex string or summary + row_dict[col_name] = f"" + else: + row_dict[col_name] = value + rows.append(row_dict) + + # Check if more results available + has_more = len(results) == result_limit + if has_more: + # Try to fetch one more to check + one_more = cursor.fetchone() + has_more = one_more is not None + + finally: + if hasattr(signal, 'SIGALRM'): + signal.alarm(0) # Cancel timeout + + app_logger.info(f"Executed query on {path}: returned {len(rows)} rows") + log_mcp_stat('query_database', str(path.parent), True) + + result = { + 'database': str(path), + 'query': query, + 'columns': columns, + 'rows': rows, + 'count': len(rows) + } + + if has_more: + result['truncated'] = True + result['note'] = f'Results limited to {result_limit} rows. Use LIMIT clause in query for more control.' 
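+
+                    # Illustrative usage (assumption, not part of the original patch): a caller
+                    # would typically pass a plain SELECT, e.g.
+                    #   await server.query_database('/path/to/app.db',
+                    #       "SELECT name, COUNT(*) AS n FROM users GROUP BY name", limit=20)
+                    # 'rows' then holds at most min(limit or default_query_limit, max_query_results)
+                    # dicts keyed by column name; the table and column names above are hypothetical.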
+ + return result + + except TimeoutError as e: + error_msg = str(e) + app_logger.warning(error_msg) + log_mcp_stat('query_database', str(path.parent), False, error_msg) + return {'error': error_msg} + except sqlite3.Error as e: + error_msg = f"SQL error: {e}" + app_logger.warning(error_msg) + log_mcp_stat('query_database', str(path.parent), False, error_msg) + return {'error': error_msg} + finally: + if hasattr(signal, 'SIGALRM'): + signal.alarm(0) # Ensure timeout is cancelled + conn.close() + + except Exception as e: + error_msg = f"Error executing query: {e}" + app_logger.error(error_msg) + log_mcp_stat('query_database', db_path, False, str(e)) + return {'error': error_msg} + + +class MCPManager: + """Manage MCP server lifecycle, tool calls, and mode switching.""" + + def __init__(self): + self.enabled = False + self.mode = "files" # "files" or "database" + self.selected_db_index = None + + self.server: Optional[MCPFilesystemServer] = None + + # File/folder mode + self.allowed_folders: List[Path] = [] + + # Database mode + self.databases: List[Dict[str, Any]] = [] + + self.config = CrossPlatformMCPConfig() + self.session_start_time: Optional[datetime.datetime] = None + + # Load persisted data + self._load_folders() + self._load_databases() + + app_logger.info("MCP Manager initialized") + + def _load_folders(self): + """Load allowed folders from database.""" + folders_json = get_mcp_config('allowed_folders') + if folders_json: + try: + folder_paths = json.loads(folders_json) + self.allowed_folders = [self.config.normalize_path(p) for p in folder_paths] + app_logger.info(f"Loaded {len(self.allowed_folders)} folders from config") + except Exception as e: + app_logger.error(f"Error loading MCP folders: {e}") + self.allowed_folders = [] + + def _save_folders(self): + """Save allowed folders to database.""" + folder_paths = [str(p) for p in self.allowed_folders] + set_mcp_config('allowed_folders', json.dumps(folder_paths)) + app_logger.info(f"Saved {len(self.allowed_folders)} folders to config") + + def _load_databases(self): + """Load databases from database.""" + create_table_if_not_exists() + try: + with sqlite3.connect(str(database)) as conn: + cursor = conn.execute('SELECT id, path, name, size, tables, added_timestamp FROM mcp_databases ORDER BY id') + self.databases = [] + for row in cursor.fetchall(): + tables_list = json.loads(row[4]) if row[4] else [] + self.databases.append({ + 'id': row[0], + 'path': row[1], + 'name': row[2], + 'size': row[3], + 'tables': tables_list, + 'added': row[5] + }) + app_logger.info(f"Loaded {len(self.databases)} databases from config") + except Exception as e: + app_logger.error(f"Error loading databases: {e}") + self.databases = [] + + def _save_database(self, db_info: Dict[str, Any]): + """Save a database to config.""" + create_table_if_not_exists() + try: + with sqlite3.connect(str(database)) as conn: + conn.execute( + 'INSERT INTO mcp_databases (path, name, size, tables, added_timestamp) VALUES (?, ?, ?, ?, ?)', + (db_info['path'], db_info['name'], db_info['size'], json.dumps(db_info['tables']), db_info['added']) + ) + conn.commit() + # Get the ID + cursor = conn.execute('SELECT id FROM mcp_databases WHERE path = ?', (db_info['path'],)) + db_info['id'] = cursor.fetchone()[0] + app_logger.info(f"Saved database {db_info['name']} to config") + except Exception as e: + app_logger.error(f"Error saving database: {e}") + + def _remove_database_from_config(self, db_path: str): + """Remove database from config.""" + create_table_if_not_exists() + try: + with 
sqlite3.connect(str(database)) as conn: + conn.execute('DELETE FROM mcp_databases WHERE path = ?', (db_path,)) + conn.commit() + app_logger.info(f"Removed database from config: {db_path}") + except Exception as e: + app_logger.error(f"Error removing database from config: {e}") + + def enable(self) -> Dict[str, Any]: + """Enable MCP server.""" + if not MCP_AVAILABLE: + return { + 'success': False, + 'error': 'MCP library not installed. Run: pip install mcp' + } + + if self.enabled: + return { + 'success': False, + 'error': 'MCP is already enabled' + } + + try: + self.server = MCPFilesystemServer(self.allowed_folders) + self.enabled = True + self.session_start_time = datetime.datetime.now() + + set_mcp_config('mcp_enabled', 'true') + + app_logger.info("MCP Filesystem Server enabled") + + return { + 'success': True, + 'folder_count': len(self.allowed_folders), + 'database_count': len(self.databases), + 'message': 'MCP Filesystem Server started successfully' + } + except Exception as e: + app_logger.error(f"Error enabling MCP: {e}") + return { + 'success': False, + 'error': str(e) + } + + def disable(self) -> Dict[str, Any]: + """Disable MCP server.""" + if not self.enabled: + return { + 'success': False, + 'error': 'MCP is not enabled' + } + + try: + self.server = None + self.enabled = False + self.session_start_time = None + self.mode = "files" + self.selected_db_index = None + + set_mcp_config('mcp_enabled', 'false') + + app_logger.info("MCP Filesystem Server disabled") + + return { + 'success': True, + 'message': 'MCP Filesystem Server stopped' + } + except Exception as e: + app_logger.error(f"Error disabling MCP: {e}") + return { + 'success': False, + 'error': str(e) + } + + def switch_mode(self, new_mode: str, db_index: Optional[int] = None) -> Dict[str, Any]: + """Switch between files and database mode.""" + if not self.enabled: + return { + 'success': False, + 'error': 'MCP is not enabled. Use /mcp enable first' + } + + if new_mode == "files": + self.mode = "files" + self.selected_db_index = None + app_logger.info("Switched to file mode") + return { + 'success': True, + 'mode': 'files', + 'message': 'Switched to file mode', + 'tools': ['read_file', 'list_directory', 'search_files'] + } + elif new_mode == "database": + if db_index is None: + return { + 'success': False, + 'error': 'Database index required. Use /mcp db ' + } + + if db_index < 1 or db_index > len(self.databases): + return { + 'success': False, + 'error': f'Invalid database number. Use 1-{len(self.databases)}' + } + + self.mode = "database" + self.selected_db_index = db_index - 1 # Convert to 0-based + db = self.databases[self.selected_db_index] + + app_logger.info(f"Switched to database mode: {db['name']}") + return { + 'success': True, + 'mode': 'database', + 'database': db, + 'message': f"Switched to database #{db_index}: {db['name']}", + 'tools': ['inspect_database', 'search_database', 'query_database'] + } + else: + return { + 'success': False, + 'error': f'Invalid mode: {new_mode}' + } + + def add_folder(self, folder_path: str) -> Dict[str, Any]: + """Add folder to allowed list.""" + if not self.enabled: + return { + 'success': False, + 'error': 'MCP is not enabled. 
Use /mcp enable first' + } + + try: + path = self.config.normalize_path(folder_path) + + # Validation checks + if not path.exists(): + return { + 'success': False, + 'error': f'Directory does not exist: {path}' + } + + if not path.is_dir(): + return { + 'success': False, + 'error': f'Not a directory: {path}' + } + + if self.config.is_system_directory(path): + return { + 'success': False, + 'error': f'Cannot add system directory: {path}' + } + + # Check if already added + if path in self.allowed_folders: + return { + 'success': False, + 'error': f'Folder already in allowed list: {path}' + } + + # Check for nested paths + parent_folder = None + for existing in self.allowed_folders: + try: + path.relative_to(existing) + parent_folder = existing + break + except ValueError: + continue + + # Get folder stats + stats = self.config.get_folder_stats(path) + + # Add folder + self.allowed_folders.append(path) + self._save_folders() + + # Update server and reload gitignores + if self.server: + self.server.allowed_folders = self.allowed_folders + self.server.reload_gitignores() + + app_logger.info(f"Added folder to MCP: {path}") + + result = { + 'success': True, + 'path': str(path), + 'stats': stats, + 'total_folders': len(self.allowed_folders) + } + + if parent_folder: + result['warning'] = f'Note: {parent_folder} is already allowed (parent folder)' + + return result + + except Exception as e: + app_logger.error(f"Error adding folder: {e}") + return { + 'success': False, + 'error': str(e) + } + + def remove_folder(self, folder_ref: str) -> Dict[str, Any]: + """Remove folder from allowed list (by path or number).""" + if not self.enabled: + return { + 'success': False, + 'error': 'MCP is not enabled. Use /mcp enable first' + } + + try: + # Check if it's a number + if folder_ref.isdigit(): + index = int(folder_ref) - 1 + if 0 <= index < len(self.allowed_folders): + path = self.allowed_folders[index] + else: + return { + 'success': False, + 'error': f'Invalid folder number: {folder_ref}' + } + else: + # Treat as path + path = self.config.normalize_path(folder_ref) + if path not in self.allowed_folders: + return { + 'success': False, + 'error': f'Folder not in allowed list: {path}' + } + + # Remove folder + self.allowed_folders.remove(path) + self._save_folders() + + # Update server and reload gitignores + if self.server: + self.server.allowed_folders = self.allowed_folders + self.server.reload_gitignores() + + app_logger.info(f"Removed folder from MCP: {path}") + + result = { + 'success': True, + 'path': str(path), + 'total_folders': len(self.allowed_folders) + } + + if len(self.allowed_folders) == 0: + result['warning'] = 'No allowed folders remaining. Add one with /mcp add ' + + return result + + except Exception as e: + app_logger.error(f"Error removing folder: {e}") + return { + 'success': False, + 'error': str(e) + } + + def add_database(self, db_path: str) -> Dict[str, Any]: + """Add SQLite database.""" + if not self.enabled: + return { + 'success': False, + 'error': 'MCP is not enabled. 
Use /mcp enable first' + } + + try: + path = Path(db_path).resolve() + + # Validation + if not path.exists(): + return { + 'success': False, + 'error': f'Database file not found: {path}' + } + + if not path.is_file(): + return { + 'success': False, + 'error': f'Not a file: {path}' + } + + # Check if already added + if any(db['path'] == str(path) for db in self.databases): + return { + 'success': False, + 'error': f'Database already added: {path.name}' + } + + # Validate it's a SQLite database and get schema + try: + conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) + cursor = conn.cursor() + + # Get tables + cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") + tables = [row[0] for row in cursor.fetchall()] + + conn.close() + except sqlite3.DatabaseError as e: + return { + 'success': False, + 'error': f'Not a valid SQLite database: {e}' + } + + # Get file size + db_size = path.stat().st_size + + # Create database entry + db_info = { + 'path': str(path), + 'name': path.name, + 'size': db_size, + 'tables': tables, + 'added': datetime.datetime.now().isoformat() + } + + # Save to config + self._save_database(db_info) + + # Add to list + self.databases.append(db_info) + + app_logger.info(f"Added database: {path.name}") + + return { + 'success': True, + 'database': db_info, + 'number': len(self.databases), + 'message': f'Added database #{len(self.databases)}: {path.name}' + } + + except Exception as e: + app_logger.error(f"Error adding database: {e}") + return { + 'success': False, + 'error': str(e) + } + + def remove_database(self, db_ref: str) -> Dict[str, Any]: + """Remove database (by number or path).""" + if not self.enabled: + return { + 'success': False, + 'error': 'MCP is not enabled' + } + + try: + # Check if it's a number + if db_ref.isdigit(): + index = int(db_ref) - 1 + if 0 <= index < len(self.databases): + db = self.databases[index] + else: + return { + 'success': False, + 'error': f'Invalid database number: {db_ref}' + } + else: + # Treat as path + path = str(Path(db_ref).resolve()) + db = next((d for d in self.databases if d['path'] == path), None) + if not db: + return { + 'success': False, + 'error': f'Database not found: {db_ref}' + } + index = self.databases.index(db) + + # If currently selected, deselect + if self.mode == "database" and self.selected_db_index == index: + self.mode = "files" + self.selected_db_index = None + + # Remove from config + self._remove_database_from_config(db['path']) + + # Remove from list + self.databases.pop(index) + + app_logger.info(f"Removed database: {db['name']}") + + result = { + 'success': True, + 'database': db, + 'message': f"Removed database: {db['name']}" + } + + if len(self.databases) == 0: + result['warning'] = 'No databases remaining. 
Add one with /mcp add db ' + + return result + + except Exception as e: + app_logger.error(f"Error removing database: {e}") + return { + 'success': False, + 'error': str(e) + } + + def list_databases(self) -> Dict[str, Any]: + """List all databases.""" + try: + db_list = [] + for idx, db in enumerate(self.databases, 1): + db_info = { + 'number': idx, + 'name': db['name'], + 'path': db['path'], + 'size_mb': db['size'] / (1024 * 1024), + 'tables': db['tables'], + 'table_count': len(db['tables']), + 'added': db['added'] + } + + # Check if file still exists + if not Path(db['path']).exists(): + db_info['warning'] = 'File not found' + + db_list.append(db_info) + + return { + 'success': True, + 'databases': db_list, + 'count': len(db_list) + } + + except Exception as e: + app_logger.error(f"Error listing databases: {e}") + return { + 'success': False, + 'error': str(e) + } + + def toggle_gitignore(self, enabled: bool) -> Dict[str, Any]: + """Toggle .gitignore filtering.""" + if not self.enabled: + return { + 'success': False, + 'error': 'MCP is not enabled' + } + + if not self.server: + return { + 'success': False, + 'error': 'MCP server not running' + } + + self.server.respect_gitignore = enabled + status = "enabled" if enabled else "disabled" + + app_logger.info(f".gitignore filtering {status}") + + return { + 'success': True, + 'message': f'.gitignore filtering {status}', + 'pattern_count': len(self.server.gitignore_parser.patterns) if enabled else 0 + } + + def list_folders(self) -> Dict[str, Any]: + """List all allowed folders with stats.""" + try: + folders_info = [] + total_files = 0 + total_size = 0 + + for idx, folder in enumerate(self.allowed_folders, 1): + stats = self.config.get_folder_stats(folder) + + folder_info = { + 'number': idx, + 'path': str(folder), + 'exists': stats.get('exists', False) + } + + if stats.get('exists'): + folder_info['file_count'] = stats['file_count'] + folder_info['size_mb'] = stats['size_mb'] + total_files += stats['file_count'] + total_size += stats['total_size'] + + folders_info.append(folder_info) + + return { + 'success': True, + 'folders': folders_info, + 'total_folders': len(self.allowed_folders), + 'total_files': total_files, + 'total_size_mb': total_size / (1024 * 1024) + } + + except Exception as e: + app_logger.error(f"Error listing folders: {e}") + return { + 'success': False, + 'error': str(e) + } + + def get_status(self) -> Dict[str, Any]: + """Get comprehensive MCP status.""" + try: + stats = get_mcp_stats() + + uptime = None + if self.session_start_time: + delta = datetime.datetime.now() - self.session_start_time + hours = delta.seconds // 3600 + minutes = (delta.seconds % 3600) // 60 + uptime = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m" + + folder_info = self.list_folders() + + gitignore_status = "enabled" if (self.server and self.server.respect_gitignore) else "disabled" + gitignore_patterns = len(self.server.gitignore_parser.patterns) if self.server else 0 + + # Current mode info + mode_info = { + 'mode': self.mode, + 'mode_display': '🔧 Files' if self.mode == 'files' else f'🗄️ DB #{self.selected_db_index + 1}' + } + + if self.mode == 'database' and self.selected_db_index is not None: + db = self.databases[self.selected_db_index] + mode_info['database'] = db + + return { + 'success': True, + 'enabled': self.enabled, + 'uptime': uptime, + 'mode_info': mode_info, + 'folder_count': len(self.allowed_folders), + 'database_count': len(self.databases), + 'total_files': folder_info.get('total_files', 0), + 'total_size_mb': 
folder_info.get('total_size_mb', 0), + 'stats': stats, + 'tools_available': self._get_current_tools(), + 'gitignore_status': gitignore_status, + 'gitignore_patterns': gitignore_patterns + } + + except Exception as e: + app_logger.error(f"Error getting MCP status: {e}") + return { + 'success': False, + 'error': str(e) + } + + def _get_current_tools(self) -> List[str]: + """Get tools available in current mode.""" + if not self.enabled: + return [] + + if self.mode == "files": + return ['read_file', 'list_directory', 'search_files'] + elif self.mode == "database": + return ['inspect_database', 'search_database', 'query_database'] + + return [] + + async def call_tool(self, tool_name: str, **kwargs) -> Dict[str, Any]: + """Call an MCP tool.""" + if not self.enabled or not self.server: + return { + 'error': 'MCP is not enabled' + } + + try: + # File mode tools + if tool_name == 'read_file': + return await self.server.read_file(kwargs.get('file_path', '')) + elif tool_name == 'list_directory': + return await self.server.list_directory( + kwargs.get('dir_path', ''), + kwargs.get('recursive', False) + ) + elif tool_name == 'search_files': + return await self.server.search_files( + kwargs.get('pattern', ''), + kwargs.get('search_path'), + kwargs.get('content_search', False) + ) + + # Database mode tools + elif tool_name == 'inspect_database': + if self.mode != 'database' or self.selected_db_index is None: + return {'error': 'Not in database mode. Use /mcp db first'} + db = self.databases[self.selected_db_index] + return await self.server.inspect_database( + db['path'], + kwargs.get('table_name') + ) + elif tool_name == 'search_database': + if self.mode != 'database' or self.selected_db_index is None: + return {'error': 'Not in database mode. Use /mcp db first'} + db = self.databases[self.selected_db_index] + return await self.server.search_database( + db['path'], + kwargs.get('search_term', ''), + kwargs.get('table_name'), + kwargs.get('column_name') + ) + elif tool_name == 'query_database': + if self.mode != 'database' or self.selected_db_index is None: + return {'error': 'Not in database mode. Use /mcp db first'} + db = self.databases[self.selected_db_index] + return await self.server.query_database( + db['path'], + kwargs.get('query', ''), + kwargs.get('limit') + ) + else: + return { + 'error': f'Unknown tool: {tool_name}' + } + except Exception as e: + app_logger.error(f"Error calling MCP tool {tool_name}: {e}") + return { + 'error': str(e) + } + + def get_tools_schema(self) -> List[Dict[str, Any]]: + """Get MCP tools as OpenAI function calling schema (for current mode).""" + if not self.enabled: + return [] + + if self.mode == "files": + return self._get_file_tools_schema() + elif self.mode == "database": + return self._get_database_tools_schema() + + return [] + + def _get_file_tools_schema(self) -> List[Dict[str, Any]]: + """Get file mode tools schema.""" + if len(self.allowed_folders) == 0: + return [] + + allowed_dirs_str = ", ".join(str(f) for f in self.allowed_folders) + + return [ + { + "type": "function", + "function": { + "name": "search_files", + "description": f"Search for files in the user's local filesystem. Allowed directories: {allowed_dirs_str}. Can search by filename pattern (e.g., '*.py' for Python files) or search within file contents. Automatically respects .gitignore patterns and excludes virtual environments.", + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Search pattern. 
For filename search, use glob patterns like '*.py', '*.txt', 'report*'. For content search, use plain text to search for." + }, + "content_search": { + "type": "boolean", + "description": "If true, searches inside file contents for the pattern (SLOW - use only when specifically asked to search file contents). If false, searches only filenames (FAST - use for finding files by name). Default is false. Virtual environments and package directories are automatically excluded.", + "default": False + }, + "search_path": { + "type": "string", + "description": "Optional: specific directory to search in. If not provided, searches all allowed directories.", + } + }, + "required": ["pattern"] + } + } + }, + { + "type": "function", + "function": { + "name": "read_file", + "description": "Read the complete contents of a text file from the user's local filesystem. Only works for files within allowed directories. Maximum file size: 10MB. Files larger than 50KB are automatically truncated. Respects .gitignore patterns.", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to read (e.g., /Users/username/Documents/report.txt or ~/Documents/notes.md)" + } + }, + "required": ["file_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "list_directory", + "description": "List all files and subdirectories in a directory from the user's local filesystem. Only works for directories within allowed paths. Automatically filters out virtual environments, build artifacts, and .gitignore patterns. Limited to 1000 items.", + "parameters": { + "type": "object", + "properties": { + "dir_path": { + "type": "string", + "description": "Directory path to list (e.g., ~/Documents or /Users/username/Projects)" + }, + "recursive": { + "type": "boolean", + "description": "If true, lists all files in subdirectories recursively. If false, lists only immediate children. Default is true. WARNING: Recursive listings can be very large - use specific paths when possible.", + "default": True + } + }, + "required": ["dir_path"] + } + } + } + ] + + def _get_database_tools_schema(self) -> List[Dict[str, Any]]: + """Get database mode tools schema.""" + if self.selected_db_index is None or self.selected_db_index >= len(self.databases): + return [] + + db = self.databases[self.selected_db_index] + db_name = db['name'] + tables_str = ", ".join(db['tables']) + + return [ + { + "type": "function", + "function": { + "name": "inspect_database", + "description": f"Inspect the schema of the currently selected SQLite database ({db_name}). Can get all tables or details for a specific table including columns, types, indexes, and row counts. Available tables: {tables_str}.", + "parameters": { + "type": "object", + "properties": { + "table_name": { + "type": "string", + "description": f"Optional: specific table to inspect. If not provided, returns info for all tables. Available: {tables_str}" + } + }, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "search_database", + "description": f"Search for a value across all tables in the database ({db_name}). Performs a LIKE search (partial matching) across all columns or specific table/column. Returns matching rows with all column data. 
Limited to {self.server.default_query_limit} results.", + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": "string", + "description": "Value to search for (partial match supported)" + }, + "table_name": { + "type": "string", + "description": f"Optional: limit search to specific table. Available: {tables_str}" + }, + "column_name": { + "type": "string", + "description": "Optional: limit search to specific column within the table" + } + }, + "required": ["search_term"] + } + } + }, + { + "type": "function", + "function": { + "name": "query_database", + "description": f"Execute a read-only SQL query on the database ({db_name}). Supports SELECT queries including JOINs, subqueries, CTEs, aggregations, etc. Maximum {self.server.max_query_results} rows returned. Query timeout: {self.server.max_query_timeout} seconds. INSERT/UPDATE/DELETE/DROP are blocked for safety.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": f"SQL SELECT query to execute. Available tables: {tables_str}. Example: SELECT * FROM table WHERE column = 'value' LIMIT 10" + }, + "limit": { + "type": "integer", + "description": f"Optional: maximum rows to return (default {self.server.default_query_limit}, max {self.server.max_query_results}). You can also use LIMIT in your query." + } + }, + "required": ["query"] + } + } + } + ] + +# Global MCP manager +mcp_manager = MCPManager() + +# ============================================================================ +# HELPER FUNCTIONS +# ============================================================================ + +def supports_function_calling(model: Dict[str, Any]) -> bool: + """Check if model supports function calling.""" + supported_params = model.get("supported_parameters", []) + return "tools" in supported_params or "functions" in supported_params + +def check_for_updates(current_version: str) -> str: + """Check for updates.""" try: response = requests.get( 'https://gitlab.pm/api/v1/repos/rune/oai/releases/latest', @@ -534,36 +2614,20 @@ def check_for_updates(current_version: str) -> str: logger.debug(f"Already up to date: {current_version}") return f"[bold green]oAI version {current_version} (up to date)[/]" - except requests.exceptions.HTTPError as e: - logger.warning(f"HTTP error checking for updates: {e.response.status_code}") - return f"[bold green]oAI version {current_version}[/]" - except requests.exceptions.ConnectionError: - logger.warning("Network error checking for updates (offline?)") - return f"[bold green]oAI version {current_version}[/]" - except requests.exceptions.Timeout: - logger.warning("Timeout checking for updates") - return f"[bold green]oAI version {current_version}[/]" - except requests.exceptions.RequestException as e: - logger.warning(f"Request error checking for updates: {type(e).__name__}") - return f"[bold green]oAI version {current_version}[/]" - except (KeyError, ValueError) as e: - logger.warning(f"Invalid API response checking for updates: {e}") - return f"[bold green]oAI version {current_version}[/]" - except Exception as e: - logger.error(f"Unexpected error checking for updates: {e}") + except: return f"[bold green]oAI version {current_version}[/]" def save_conversation(name: str, data: List[Dict[str, str]]): - """Save conversation history to DB.""" + """Save conversation.""" timestamp = datetime.datetime.now().isoformat() data_json = json.dumps(data) - with sqlite3.connect(DB_FILE) as conn: + with sqlite3.connect(str(database)) as conn: 
conn.execute('INSERT INTO conversation_sessions (name, timestamp, data) VALUES (?, ?, ?)', (name, timestamp, data_json)) conn.commit() def load_conversation(name: str) -> Optional[List[Dict[str, str]]]: - """Load conversation history from DB (latest by timestamp).""" - with sqlite3.connect(DB_FILE) as conn: + """Load conversation.""" + with sqlite3.connect(str(database)) as conn: cursor = conn.execute('SELECT data FROM conversation_sessions WHERE name = ? ORDER BY timestamp DESC LIMIT 1', (name,)) result = cursor.fetchone() if result: @@ -571,15 +2635,15 @@ def load_conversation(name: str) -> Optional[List[Dict[str, str]]]: return None def delete_conversation(name: str) -> int: - """Delete all conversation sessions with the given name. Returns number of deleted rows.""" - with sqlite3.connect(DB_FILE) as conn: + """Delete conversation.""" + with sqlite3.connect(str(database)) as conn: cursor = conn.execute('DELETE FROM conversation_sessions WHERE name = ?', (name,)) conn.commit() return cursor.rowcount def list_conversations() -> List[Dict[str, Any]]: - """List all saved conversations from DB with metadata.""" - with sqlite3.connect(DB_FILE) as conn: + """List conversations.""" + with sqlite3.connect(str(database)) as conn: cursor = conn.execute(''' SELECT name, MAX(timestamp) as last_saved, data FROM conversation_sessions @@ -598,34 +2662,32 @@ def list_conversations() -> List[Dict[str, Any]]: return conversations def estimate_cost(input_tokens: int, output_tokens: int) -> float: - """Estimate cost in USD based on token counts.""" + """Estimate cost.""" return (input_tokens * MODEL_PRICING['input'] / 1_000_000) + (output_tokens * MODEL_PRICING['output'] / 1_000_000) def has_web_search_capability(model: Dict[str, Any]) -> bool: - """Check if model supports web search based on supported_parameters.""" + """Check web search capability.""" supported_params = model.get("supported_parameters", []) - # Web search is typically indicated by 'tools' parameter support return "tools" in supported_params def has_image_capability(model: Dict[str, Any]) -> bool: - """Check if model supports image input based on input modalities.""" + """Check image capability.""" architecture = model.get("architecture", {}) input_modalities = architecture.get("input_modalities", []) return "image" in input_modalities def supports_online_mode(model: Dict[str, Any]) -> bool: - """Check if model supports :online suffix for web search.""" - # Models that support tools parameter can use :online + """Check online mode support.""" return has_web_search_capability(model) def get_effective_model_id(base_model_id: str, online_enabled: bool) -> str: - """Get the effective model ID with :online suffix if enabled.""" + """Get effective model ID.""" if online_enabled and not base_model_id.endswith(':online'): return f"{base_model_id}:online" return base_model_id def export_as_markdown(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export conversation history as Markdown.""" + """Export as Markdown.""" lines = ["# Conversation Export", ""] if session_system_prompt: lines.extend([f"**System Prompt:** {session_system_prompt}", ""]) @@ -651,7 +2713,7 @@ def export_as_markdown(session_history: List[Dict[str, str]], session_system_pro return "\n".join(lines) def export_as_json(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export conversation history as JSON.""" + """Export as JSON.""" export_data = { "export_date": datetime.datetime.now().isoformat(), 
"system_prompt": session_system_prompt, @@ -661,8 +2723,7 @@ def export_as_json(session_history: List[Dict[str, str]], session_system_prompt: return json.dumps(export_data, indent=2, ensure_ascii=False) def export_as_html(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export conversation history as HTML.""" - # Escape HTML special characters + """Export as HTML.""" def escape_html(text): return text.replace('&', '&').replace('<', '<').replace('>', '>').replace('"', '"').replace("'", ''') @@ -733,99 +2794,84 @@ def export_as_html(session_history: List[Dict[str, str]], session_system_prompt: return "\n".join(html_parts) -def show_command_help(command: str): - """Display detailed help for a specific command.""" - # Normalize command to ensure it starts with / - if not command.startswith('/'): - command = '/' + command +def show_command_help(command_or_topic: str): + """Display detailed help for command or topic.""" + # Handle topics (like 'mcp') + if not command_or_topic.startswith('/'): + if command_or_topic.lower() == 'mcp': + help_data = COMMAND_HELP['mcp'] + + help_content = [] + help_content.append(f"[bold cyan]Description:[/]") + help_content.append(help_data['description']) + help_content.append("") + help_content.append(help_data['notes']) + + console.print(Panel( + "\n".join(help_content), + title="[bold green]MCP - Model Context Protocol Guide[/]", + title_align="left", + border_style="green", + width=console.width - 4 + )) + + app_logger.info("Displayed MCP comprehensive guide") + return + else: + command_or_topic = '/' + command_or_topic - # Check if command exists - if command not in COMMAND_HELP: - console.print(f"[bold red]Unknown command: {command}[/]") + if command_or_topic not in COMMAND_HELP: + console.print(f"[bold red]Unknown command: {command_or_topic}[/]") console.print("[bold yellow]Type /help to see all available commands.[/]") - app_logger.warning(f"Help requested for unknown command: {command}") + console.print("[bold yellow]Type /help mcp for comprehensive MCP guide.[/]") + app_logger.warning(f"Help requested for unknown command: {command_or_topic}") return - help_data = COMMAND_HELP[command] + help_data = COMMAND_HELP[command_or_topic] - # Create detailed help panel help_content = [] - # Aliases if available if 'aliases' in help_data: aliases_str = ", ".join(help_data['aliases']) help_content.append(f"[bold cyan]Aliases:[/] {aliases_str}") help_content.append("") - # Description help_content.append(f"[bold cyan]Description:[/]") help_content.append(help_data['description']) help_content.append("") - # Usage help_content.append(f"[bold cyan]Usage:[/]") help_content.append(f"[yellow]{help_data['usage']}[/]") help_content.append("") - # Examples if 'examples' in help_data and help_data['examples']: help_content.append(f"[bold cyan]Examples:[/]") for desc, example in help_data['examples']: - help_content.append(f" [dim]{desc}:[/]") - help_content.append(f" [green]{example}[/]") - help_content.append("") + if not desc and not example: + help_content.append("") + elif desc.startswith('━━━'): + help_content.append(f"[bold yellow]{desc}[/]") + else: + help_content.append(f" [dim]{desc}:[/]" if desc else "") + help_content.append(f" [green]{example}[/]" if example else "") + help_content.append("") - # Notes if 'notes' in help_data: help_content.append(f"[bold cyan]Notes:[/]") help_content.append(f"[dim]{help_data['notes']}[/]") console.print(Panel( "\n".join(help_content), - title=f"[bold green]Help: {command}[/]", + title=f"[bold 
green]Help: {command_or_topic}[/]", title_align="left", border_style="green", width=console.width - 4 )) - app_logger.info(f"Displayed detailed help for command: {command}") - -# Load configs (AFTER logging is set up) -API_KEY = get_config('api_key') -OPENROUTER_BASE_URL = get_config('base_url') or "https://openrouter.ai/api/v1" -STREAM_ENABLED = get_config('stream_enabled') or "on" -DEFAULT_MODEL_ID = get_config('default_model') -MAX_TOKEN = int(get_config('max_token') or "100000") -COST_WARNING_THRESHOLD = float(get_config(HIGH_COST_WARNING) or "0.01") # Configurable cost threshold for alerts -DEFAULT_ONLINE_MODE = get_config('default_online_mode') or "off" # New: Default online mode setting - -# Fetch models with app identification headers -models_data = [] -text_models = [] -try: - headers = { - "Authorization": f"Bearer {API_KEY}", - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } if API_KEY else { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - response = requests.get(f"{OPENROUTER_BASE_URL}/models", headers=headers) - response.raise_for_status() - models_data = response.json()["data"] - text_models = [m for m in models_data if "modalities" not in m or "video" not in (m.get("modalities") or [])] - selected_model_default = None - if DEFAULT_MODEL_ID: - selected_model_default = next((m for m in text_models if m["id"] == DEFAULT_MODEL_ID), None) - if not selected_model_default: - console.print(f"[bold yellow]Warning: Default model '{DEFAULT_MODEL_ID}' unavailable. Use '/config model'.[/]") -except Exception as e: - models_data = [] - text_models = [] - app_logger.error(f"Failed to fetch models: {e}") + app_logger.info(f"Displayed detailed help for command: {command_or_topic}") def get_credits(api_key: str, base_url: str = OPENROUTER_BASE_URL) -> Optional[Dict[str, str]]: + """Get credits.""" if not api_key: return None url = f"{base_url}/credits" @@ -851,7 +2897,7 @@ def get_credits(api_key: str, base_url: str = OPENROUTER_BASE_URL) -> Optional[D return None def check_credit_alerts(credits_data: Optional[Dict[str, str]]) -> List[str]: - """Check and return list of credit-related alerts.""" + """Check credit alerts.""" alerts = [] if credits_data: credits_left_value = float(credits_data['credits_left'].strip('$')) @@ -863,26 +2909,23 @@ def check_credit_alerts(credits_data: Optional[Dict[str, str]]) -> List[str]: return alerts def clear_screen(): + """Clear screen.""" try: print("\033[H\033[J", end="", flush=True) except: print("\n" * 100) def display_paginated_table(table: Table, title: str): - """Display a table with pagination support using Rich console for colored output, repeating header on each page.""" - # Get terminal height (subtract some lines for prompt and margins) + """Display paginated table.""" try: terminal_height = os.get_terminal_size().lines - 8 except: - terminal_height = 20 # Fallback if terminal size can't be determined + terminal_height = 20 - # Create a segment-based approach to capture Rich-rendered output from rich.segment import Segment - # Render the table to segments segments = list(console.render(table)) - # Convert segments to lines while preserving style current_line_segments = [] all_lines = [] @@ -893,27 +2936,22 @@ def display_paginated_table(table: Table, title: str): else: current_line_segments.append(segment) - # Add last line if not empty if current_line_segments: all_lines.append(current_line_segments) total_lines = len(all_lines) - # If fits on one screen after segment analysis if total_lines <= terminal_height: 
console.print(Panel(table, title=title, title_align="left")) return - # Separate header from data rows header_lines = [] data_lines = [] - # Find where the header ends header_end_index = 0 found_header_text = False for i, line_segments in enumerate(all_lines): - # Check if this line contains header-style text has_header_style = any( seg.style and ('bold' in str(seg.style) or 'magenta' in str(seg.style)) for seg in line_segments @@ -922,56 +2960,44 @@ def display_paginated_table(table: Table, title: str): if has_header_style: found_header_text = True - # After finding header text, the next line with box-drawing chars is the separator if found_header_text and i > 0: line_text = ''.join(seg.text for seg in line_segments) if any(char in line_text for char in ['─', '━', '┼', '╪', '┤', '├']): header_end_index = i break - # If we found a header separator, split there if header_end_index > 0: header_lines = all_lines[:header_end_index + 1] data_lines = all_lines[header_end_index + 1:] else: - # Fallback: assume first 3 lines are header header_lines = all_lines[:min(3, len(all_lines))] data_lines = all_lines[min(3, len(all_lines)):] - # Calculate how many data lines fit per page lines_per_page = terminal_height - len(header_lines) - # Display with pagination current_line = 0 page_number = 1 while current_line < len(data_lines): - # Clear screen for each page clear_screen() - # Print title console.print(f"[bold cyan]{title} (Page {page_number})[/]") - # Print header on every page for line_segments in header_lines: for segment in line_segments: console.print(segment.text, style=segment.style, end="") console.print() - # Calculate how many data lines to show on this page end_line = min(current_line + lines_per_page, len(data_lines)) - # Print data lines for this page for line_segments in data_lines[current_line:end_line]: for segment in line_segments: console.print(segment.text, style=segment.style, end="") console.print() - # Update position current_line = end_line page_number += 1 - # If there's more content, wait for user if current_line < len(data_lines): console.print(f"\n[dim yellow]--- Press SPACE for next page, or any other key to finish (Page {page_number - 1}, showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]") try: @@ -991,16 +3017,20 @@ def display_paginated_table(table: Table, title: str): finally: termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) except: - # Fallback for Windows or if termios not available input_char = input().strip() if input_char != '': break else: break + # P3 + # ============================================================================ +# MAIN CHAT FUNCTION WITH FULL MCP SUPPORT (FILES + DATABASE) +# ============================================================================ @app.command() def chat(): global API_KEY, OPENROUTER_BASE_URL, STREAM_ENABLED, MAX_TOKEN, COST_WARNING_THRESHOLD, DEFAULT_ONLINE_MODE, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, LOG_LEVEL_STR, app_logger + session_max_token = 0 session_system_prompt = "" session_history = [] @@ -1009,11 +3039,11 @@ def chat(): total_output_tokens = 0 total_cost = 0.0 message_count = 0 - middle_out_enabled = False # Session-level middle-out transform flag - conversation_memory_enabled = True # Memory ON by default - memory_start_index = 0 # Track when memory was last enabled - saved_conversations_cache = [] # Cache for /list results to use with /load by number - online_mode_enabled = DEFAULT_ONLINE_MODE == "on" # Initialize from config + middle_out_enabled = False + 
conversation_memory_enabled = True + memory_start_index = 0 + saved_conversations_cache = [] + online_mode_enabled = DEFAULT_ONLINE_MODE == "on" app_logger.info("Starting new chat session with memory enabled") @@ -1035,7 +3065,6 @@ def chat(): console.print("[bold red]No models available. Check API key/URL.[/]") raise typer.Exit() - # Check for credit alerts at startup credits_data = get_credits(API_KEY, OPENROUTER_BASE_URL) startup_credit_alerts = check_credit_alerts(credits_data) if startup_credit_alerts: @@ -1045,7 +3074,6 @@ def chat(): selected_model = selected_model_default - # Initialize OpenRouter client client = OpenRouter(api_key=API_KEY) if selected_model: @@ -1057,18 +3085,35 @@ def chat(): if not selected_model: console.print("[bold yellow]No model selected. Use '/model'.[/]") - # Persistent input history session = PromptSession(history=FileHistory(str(history_file))) while True: try: - user_input = session.prompt("You> ", auto_suggest=AutoSuggestFromHistory()).strip() + # ============================================================ + # BUILD PROMPT PREFIX WITH MODE INDICATOR + # ============================================================ + prompt_prefix = "You> " + if mcp_manager.enabled: + if mcp_manager.mode == "files": + prompt_prefix = "[🔧 MCP: Files] You> " + elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: + db = mcp_manager.databases[mcp_manager.selected_db_index] + prompt_prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> " - # Handle // escape sequence - convert to single / and treat as regular text + # ============================================================ + # INITIALIZE LOOP VARIABLES + # ============================================================ + text_part = "" + file_attachments = [] + content_blocks = [] + + user_input = session.prompt(prompt_prefix, auto_suggest=AutoSuggestFromHistory()).strip() + + # Handle escape sequence if user_input.startswith("//"): - user_input = user_input[1:] # Remove first slash, keep the rest + user_input = user_input[1:] - # Check for unknown commands + # Check unknown commands elif user_input.startswith("/") and user_input.lower() not in ["exit", "quit", "bye"]: command_word = user_input.split()[0].lower() if user_input.split() else user_input.lower() @@ -1084,7 +3129,479 @@ def chat(): console.print("[bold yellow]Goodbye![/]") return - # Commands with logging + # ============================================================ + # MCP COMMANDS + # ============================================================ + if user_input.lower().startswith("/mcp"): + parts = user_input[5:].strip().split(maxsplit=1) + mcp_command = parts[0].lower() if parts else "" + mcp_args = parts[1] if len(parts) > 1 else "" + + if mcp_command == "enable": + result = mcp_manager.enable() + if result['success']: + console.print(f"[bold green]✓ {result['message']}[/]") + if result['folder_count'] == 0 and result['database_count'] == 0: + console.print("\n[bold yellow]No folders or databases configured yet.[/]") + console.print("\n[bold cyan]To get started:[/]") + console.print(" [bold yellow]Files:[/] /mcp add ~/Documents") + console.print(" [bold yellow]Databases:[/] /mcp add db ~/app/data.db") + else: + folders_msg = f"{result['folder_count']} folder(s)" if result['folder_count'] else "no folders" + dbs_msg = f"{result['database_count']} database(s)" if result['database_count'] else "no databases" + console.print(f"\n[bold cyan]MCP active with {folders_msg} and {dbs_msg}.[/]") + + warning = 
mcp_manager.config.get_filesystem_warning() + if warning: + console.print(warning) + else: + console.print(f"[bold red]❌ {result['error']}[/]") + continue + + elif mcp_command == "disable": + result = mcp_manager.disable() + if result['success']: + console.print(f"[bold green]✓ {result['message']}[/]") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + continue + + elif mcp_command == "add": + # Check if it's a database or folder + if mcp_args.startswith("db "): + # Database + db_path = mcp_args[3:].strip() + if not db_path: + console.print("[bold red]Usage: /mcp add db [/]") + continue + + result = mcp_manager.add_database(db_path) + + if result['success']: + db = result['database'] + + console.print("\n[bold yellow]⚠️ Database Check:[/]") + console.print(f"Adding: [bold]{db['name']}[/]") + console.print(f"Size: {db['size'] / (1024 * 1024):.2f} MB") + console.print(f"Tables: {', '.join(db['tables'])} ({len(db['tables'])} total)") + + if db['size'] > 100 * 1024 * 1024: # > 100MB + console.print("\n[bold yellow]⚠️ Large database detected! Queries may be slow.[/]") + + console.print("\nMCP will be able to:") + console.print(" [green]✓[/green] Inspect database schema") + console.print(" [green]✓[/green] Search for data across tables") + console.print(" [green]✓[/green] Execute read-only SQL queries") + console.print(" [red]✗[/red] Modify data (read-only mode)") + + try: + confirm = typer.confirm("\nProceed?", default=True) + if not confirm: + # Remove it + mcp_manager.remove_database(str(result['number'])) + console.print("[bold yellow]Cancelled. Database not added.[/]") + continue + except (EOFError, KeyboardInterrupt): + mcp_manager.remove_database(str(result['number'])) + console.print("\n[bold yellow]Cancelled. Database not added.[/]") + continue + + console.print(f"\n[bold green]✓ {result['message']}[/]") + console.print(f"\n[bold cyan]Use '/mcp db {result['number']}' to start querying this database.[/]") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + else: + # Folder + folder_path = mcp_args + if not folder_path: + console.print("[bold red]Usage: /mcp add or /mcp add db [/]") + continue + + result = mcp_manager.add_folder(folder_path) + + if result['success']: + stats = result['stats'] + + console.print("\n[bold yellow]⚠️ Security Check:[/]") + console.print(f"You are granting MCP access to: [bold]{result['path']}[/]") + + if stats.get('exists'): + file_count = stats['file_count'] + size_mb = stats['size_mb'] + console.print(f"This folder contains: [bold]{file_count} files ({size_mb:.1f} MB)[/]") + + console.print("\nMCP will be able to:") + console.print(" [green]✓[/green] Read files in this folder") + console.print(" [green]✓[/green] List and search files") + console.print(" [green]✓[/green] Access subfolders recursively") + console.print(" [green]✓[/green] Automatically respect .gitignore patterns") + console.print(" [red]✗[/red] Delete or modify files (read-only)") + + try: + confirm = typer.confirm("\nProceed?", default=True) + if not confirm: + mcp_manager.allowed_folders.remove(mcp_manager.config.normalize_path(folder_path)) + mcp_manager._save_folders() + console.print("[bold yellow]Cancelled. Folder not added.[/]") + continue + except (EOFError, KeyboardInterrupt): + mcp_manager.allowed_folders.remove(mcp_manager.config.normalize_path(folder_path)) + mcp_manager._save_folders() + console.print("\n[bold yellow]Cancelled. 
Folder not added.[/]") + continue + + console.print(f"\n[bold green]✓ Added {result['path']} to MCP allowed folders[/]") + console.print(f"[dim cyan]MCP now has access to {result['total_folders']} folder(s) total.[/]") + + if result.get('warning'): + console.print(f"\n[bold yellow]⚠️ {result['warning']}[/]") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + continue + + elif mcp_command in ["remove", "rem"]: + # Check if it's a database or folder + if mcp_args.startswith("db "): + # Database + db_ref = mcp_args[3:].strip() + if not db_ref: + console.print("[bold red]Usage: /mcp remove db [/]") + continue + + result = mcp_manager.remove_database(db_ref) + + if result['success']: + console.print(f"\n[bold yellow]Removing: {result['database']['name']}[/]") + + try: + confirm = typer.confirm("Confirm removal?", default=False) + if not confirm: + console.print("[bold yellow]Cancelled.[/]") + # Re-add it + mcp_manager.databases.append(result['database']) + mcp_manager._save_database(result['database']) + continue + except (EOFError, KeyboardInterrupt): + console.print("\n[bold yellow]Cancelled.[/]") + mcp_manager.databases.append(result['database']) + mcp_manager._save_database(result['database']) + continue + + console.print(f"\n[bold green]✓ {result['message']}[/]") + + if result.get('warning'): + console.print(f"\n[bold yellow]⚠️ {result['warning']}[/]") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + else: + # Folder + if not mcp_args: + console.print("[bold red]Usage: /mcp remove or /mcp remove db [/]") + continue + + result = mcp_manager.remove_folder(mcp_args) + + if result['success']: + console.print(f"\n[bold yellow]Removing: {result['path']}[/]") + + try: + confirm = typer.confirm("Confirm removal?", default=False) + if not confirm: + mcp_manager.allowed_folders.append(mcp_manager.config.normalize_path(result['path'])) + mcp_manager._save_folders() + console.print("[bold yellow]Cancelled. Folder not removed.[/]") + continue + except (EOFError, KeyboardInterrupt): + mcp_manager.allowed_folders.append(mcp_manager.config.normalize_path(result['path'])) + mcp_manager._save_folders() + console.print("\n[bold yellow]Cancelled. 
Folder not removed.[/]") + continue + + console.print(f"\n[bold green]✓ Removed {result['path']} from MCP allowed folders[/]") + console.print(f"[dim cyan]MCP now has access to {result['total_folders']} folder(s) total.[/]") + + if result.get('warning'): + console.print(f"\n[bold yellow]⚠️ {result['warning']}[/]") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + continue + + elif mcp_command == "list": + result = mcp_manager.list_folders() + + if result['success']: + if result['total_folders'] == 0: + console.print("[bold yellow]No folders configured.[/]") + console.print("\n[bold cyan]Add a folder with: /mcp add ~/Documents[/]") + continue + + status_indicator = "[green]✓[/green]" if mcp_manager.enabled else "[red]✗[/red]" + + table = Table( + "No.", "Path", "Files", "Size", + show_header=True, + header_style="bold magenta" + ) + + for folder_info in result['folders']: + number = str(folder_info['number']) + path = folder_info['path'] + + if folder_info['exists']: + files = f"📁 {folder_info['file_count']}" + size = f"{folder_info['size_mb']:.1f} MB" + else: + files = "[red]Not found[/red]" + size = "-" + + table.add_row(number, path, files, size) + + gitignore_info = "" + if mcp_manager.server: + gitignore_status = "on" if mcp_manager.server.respect_gitignore else "off" + pattern_count = len(mcp_manager.server.gitignore_parser.patterns) + gitignore_info = f" | .gitignore: {gitignore_status} ({pattern_count} patterns)" + + console.print(Panel( + table, + title=f"[bold green]MCP Folders: {'Active' if mcp_manager.enabled else 'Inactive'} {status_indicator}[/]", + title_align="left", + subtitle=f"[dim]Total: {result['total_folders']} folders, {result['total_files']} files ({result['total_size_mb']:.1f} MB){gitignore_info}[/]", + subtitle_align="right" + )) + else: + console.print(f"[bold red]❌ {result['error']}[/]") + continue + + elif mcp_command == "db": + if not mcp_args: + # Show current database or list all + if mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: + db = mcp_manager.databases[mcp_manager.selected_db_index] + console.print(f"[bold cyan]Currently using database #{mcp_manager.selected_db_index + 1}: {db['name']}[/]") + console.print(f"[dim]Path: {db['path']}[/]") + console.print(f"[dim]Tables: {', '.join(db['tables'])}[/]") + else: + console.print("[bold yellow]Not in database mode. 
Use '/mcp db ' to select a database.[/]") + + # Also show hint to list + console.print("\n[dim]Use '/mcp db list' to see all databases[/]") + continue + + if mcp_args == "list": + result = mcp_manager.list_databases() + + if result['success']: + if result['count'] == 0: + console.print("[bold yellow]No databases configured.[/]") + console.print("\n[bold cyan]Add a database with: /mcp add db ~/app/data.db[/]") + continue + + table = Table( + "No.", "Name", "Tables", "Size", "Status", + show_header=True, + header_style="bold magenta" + ) + + for db_info in result['databases']: + number = str(db_info['number']) + name = db_info['name'] + table_count = f"{db_info['table_count']} tables" + size = f"{db_info['size_mb']:.1f} MB" + + if db_info.get('warning'): + status = f"[red]{db_info['warning']}[/red]" + else: + status = "[green]✓[/green]" + + table.add_row(number, name, table_count, size, status) + + console.print(Panel( + table, + title="[bold green]MCP Databases[/]", + title_align="left", + subtitle=f"[dim]Total: {result['count']} database(s) | Use '/mcp db ' to select[/]", + subtitle_align="right" + )) + else: + console.print(f"[bold red]❌ {result['error']}[/]") + continue + + # Switch to database mode + try: + db_num = int(mcp_args) + result = mcp_manager.switch_mode("database", db_num) + + if result['success']: + db = result['database'] + console.print(f"\n[bold green]✓ {result['message']}[/]") + console.print(f"[dim cyan]Tables: {', '.join(db['tables'])}[/]") + console.print(f"\n[bold cyan]Available tools:[/] inspect_database, search_database, query_database") + console.print(f"\n[bold yellow]You can now ask questions about this database![/]") + console.print(f"[dim]Examples:[/]") + console.print(f" 'Show me all tables'") + console.print(f" 'Search for records mentioning X'") + console.print(f" 'How many rows in the users table?'") + console.print(f"\n[dim]Switch back to files: /mcp files[/]") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + except ValueError: + console.print(f"[bold red]Invalid database number: {mcp_args}[/]") + console.print(f"[bold yellow]Use '/mcp db list' to see available databases[/]") + continue + + elif mcp_command == "files": + result = mcp_manager.switch_mode("files") + + if result['success']: + console.print(f"[bold green]✓ {result['message']}[/]") + console.print(f"\n[bold cyan]Available tools:[/] read_file, list_directory, search_files") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + continue + + elif mcp_command == "gitignore": + if not mcp_args: + if mcp_manager.server: + status = "enabled" if mcp_manager.server.respect_gitignore else "disabled" + pattern_count = len(mcp_manager.server.gitignore_parser.patterns) + console.print(f"[bold blue].gitignore filtering: {status}[/]") + console.print(f"[dim cyan]Loaded {pattern_count} pattern(s) from .gitignore files[/]") + else: + console.print("[bold yellow]MCP is not enabled. 
Enable with /mcp enable[/]") + continue + + if mcp_args.lower() == "on": + result = mcp_manager.toggle_gitignore(True) + if result['success']: + console.print(f"[bold green]✓ {result['message']}[/]") + console.print(f"[dim cyan]Using {result['pattern_count']} .gitignore pattern(s)[/]") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + elif mcp_args.lower() == "off": + result = mcp_manager.toggle_gitignore(False) + if result['success']: + console.print(f"[bold green]✓ {result['message']}[/]") + console.print("[bold yellow]Warning: All files will be visible, including those in .gitignore[/]") + else: + console.print(f"[bold red]❌ {result['error']}[/]") + else: + console.print("[bold yellow]Usage: /mcp gitignore on|off[/]") + continue + + elif mcp_command == "status": + result = mcp_manager.get_status() + + if result['success']: + status_color = "green" if result['enabled'] else "red" + status_text = "Active ✓" if result['enabled'] else "Inactive ✗" + + table = Table( + "Property", "Value", + show_header=True, + header_style="bold magenta" + ) + + table.add_row("Status", f"[{status_color}]{status_text}[/{status_color}]") + + # Mode info + mode_info = result['mode_info'] + table.add_row("Current Mode", mode_info['mode_display']) + if 'database' in mode_info: + db = mode_info['database'] + table.add_row("Database Path", db['path']) + table.add_row("Database Tables", ', '.join(db['tables'])) + + if result['uptime']: + table.add_row("Uptime", result['uptime']) + + table.add_row("", "") + table.add_row("[bold]Configuration[/]", "") + table.add_row("Allowed Folders", str(result['folder_count'])) + table.add_row("Databases", str(result['database_count'])) + table.add_row("Total Files Accessible", str(result['total_files'])) + table.add_row("Total Size", f"{result['total_size_mb']:.1f} MB") + table.add_row(".gitignore Filtering", result['gitignore_status']) + if result['gitignore_patterns'] > 0: + table.add_row(".gitignore Patterns", str(result['gitignore_patterns'])) + + if result['enabled']: + table.add_row("", "") + table.add_row("[bold]Tools Available[/]", "") + for tool in result['tools_available']: + table.add_row(f" ✓ {tool}", "") + + stats = result['stats'] + if stats['total_calls'] > 0: + table.add_row("", "") + table.add_row("[bold]Session Stats[/]", "") + table.add_row("Total Tool Calls", str(stats['total_calls'])) + + if stats['reads'] > 0: + table.add_row("Files Read", str(stats['reads'])) + if stats['lists'] > 0: + table.add_row("Directories Listed", str(stats['lists'])) + if stats['searches'] > 0: + table.add_row("File Searches", str(stats['searches'])) + if stats['db_inspects'] > 0: + table.add_row("DB Inspections", str(stats['db_inspects'])) + if stats['db_searches'] > 0: + table.add_row("DB Searches", str(stats['db_searches'])) + if stats['db_queries'] > 0: + table.add_row("SQL Queries", str(stats['db_queries'])) + + if stats['last_used']: + try: + last_used = datetime.datetime.fromisoformat(stats['last_used']) + time_ago = datetime.datetime.now() - last_used + + if time_ago.seconds < 60: + time_str = f"{time_ago.seconds} seconds ago" + elif time_ago.seconds < 3600: + time_str = f"{time_ago.seconds // 60} minutes ago" + else: + time_str = f"{time_ago.seconds // 3600} hours ago" + + table.add_row("Last Used", time_str) + except: + pass + + console.print(Panel( + table, + title="[bold green]MCP Status[/]", + title_align="left" + )) + else: + console.print(f"[bold red]❌ {result['error']}[/]") + continue + + else: + console.print("[bold cyan]MCP (Model Context Protocol) 
Commands:[/]\n") + console.print(" [bold yellow]/mcp enable[/] - Start MCP server") + console.print(" [bold yellow]/mcp disable[/] - Stop MCP server") + console.print(" [bold yellow]/mcp status[/] - Show comprehensive MCP status") + console.print("") + console.print(" [bold cyan]FILE MODE (default):[/]") + console.print(" [bold yellow]/mcp add [/] - Add folder for file access") + console.print(" [bold yellow]/mcp remove [/] - Remove folder") + console.print(" [bold yellow]/mcp list[/] - List all folders") + console.print(" [bold yellow]/mcp files[/] - Switch to file mode") + console.print(" [bold yellow]/mcp gitignore [on|off][/] - Toggle .gitignore filtering") + console.print("") + console.print(" [bold cyan]DATABASE MODE:[/]") + console.print(" [bold yellow]/mcp add db [/] - Add SQLite database") + console.print(" [bold yellow]/mcp db list[/] - List all databases") + console.print(" [bold yellow]/mcp db [/] - Switch to database mode") + console.print(" [bold yellow]/mcp remove db [/]- Remove database") + console.print("") + console.print("[dim]Use '/help /mcp' for detailed command help[/]") + console.print("[dim]Use '/help mcp' for comprehensive MCP guide[/]") + continue + + # ============================================================ + # ALL OTHER COMMANDS + # ============================================================ + if user_input.lower() == "/retry": if not session_history: console.print("[bold red]No history to retry.[/]") @@ -1094,6 +3611,7 @@ def chat(): console.print("[bold green]Retrying last prompt...[/]") app_logger.info(f"Retrying prompt: {last_prompt[:100]}...") user_input = last_prompt + elif user_input.lower().startswith("/online"): args = user_input[8:].strip() if not args: @@ -1129,6 +3647,7 @@ def chat(): else: console.print("[bold yellow]Usage: /online on|off (or /online to view status)[/]") continue + elif user_input.lower().startswith("/memory"): args = user_input[8:].strip() if not args: @@ -1154,6 +3673,7 @@ def chat(): else: console.print("[bold yellow]Usage: /memory on|off (or /memory to view status)[/]") continue + elif user_input.lower().startswith("/paste"): optional_prompt = user_input[7:].strip() @@ -1205,7 +3725,7 @@ def chat(): console.print(f"[bold red]Error processing clipboard content: {e}[/]") app_logger.error(f"Clipboard processing error: {e}") continue - + elif user_input.lower().startswith("/export"): args = user_input[8:].strip().split(maxsplit=1) if len(args) != 2: @@ -1244,6 +3764,7 @@ def chat(): console.print(f"[bold red]Export failed: {e}[/]") app_logger.error(f"Export error: {e}") continue + elif user_input.lower().startswith("/save"): args = user_input[6:].strip() if not args: @@ -1256,6 +3777,7 @@ def chat(): console.print(f"[bold green]Conversation saved as '{args}'.[/]") app_logger.info(f"Conversation saved as '{args}' with {len(session_history)} messages") continue + elif user_input.lower().startswith("/load"): args = user_input[6:].strip() if not args: @@ -1292,6 +3814,7 @@ def chat(): console.print(f"[bold green]Conversation '{conversation_name}' loaded with {len(session_history)} messages.[/]") app_logger.info(f"Conversation '{conversation_name}' loaded with {len(session_history)} messages") continue + elif user_input.lower().startswith("/delete"): args = user_input[8:].strip() if not args: @@ -1331,6 +3854,7 @@ def chat(): console.print(f"[bold red]Conversation '{conversation_name}' not found.[/]") app_logger.warning(f"Delete failed for '{conversation_name}' - not found") continue + elif user_input.lower() == "/list": 
conversations = list_conversations() if not conversations: @@ -1359,6 +3883,7 @@ def chat(): console.print(Panel(table, title=f"[bold green]Saved Conversations ({len(conversations)} total)[/]", title_align="left", subtitle="[dim]Use /load or /delete to manage conversations[/]", subtitle_align="right")) app_logger.info(f"User viewed conversation list - {len(conversations)} conversations") continue + elif user_input.lower() == "/prev": if not session_history or current_index <= 0: console.print("[bold red]No previous response.[/]") @@ -1369,6 +3894,7 @@ def chat(): console.print(Panel(md, title=f"[bold green]Previous Response ({current_index + 1}/{len(session_history)})[/]", title_align="left")) app_logger.debug(f"Viewed previous response at index {current_index}") continue + elif user_input.lower() == "/next": if not session_history or current_index >= len(session_history) - 1: console.print("[bold red]No next response.[/]") @@ -1379,6 +3905,7 @@ def chat(): console.print(Panel(md, title=f"[bold green]Next Response ({current_index + 1}/{len(session_history)})[/]", title_align="left")) app_logger.debug(f"Viewed next response at index {current_index}") continue + elif user_input.lower() == "/stats": credits = get_credits(API_KEY, OPENROUTER_BASE_URL) credits_left = credits['credits_left'] if credits else "Unknown" @@ -1395,6 +3922,7 @@ def chat(): console.print(f"[bold red]⚠️ {warning_text}[/]") app_logger.warning(f"Warnings in stats: {warning_text}") continue + elif user_input.lower().startswith("/middleout"): args = user_input[11:].strip() if not args: @@ -1409,6 +3937,7 @@ def chat(): else: console.print("[bold yellow]Usage: /middleout on|off (or /middleout to view status)[/]") continue + elif user_input.lower() == "/reset": try: confirm = typer.confirm("Reset conversation context? 
This clears history and prompt.", default=False) @@ -1496,7 +4025,6 @@ def chat(): if 1 <= choice <= len(filtered_models): selected_model = filtered_models[choice - 1] - # Apply default online mode if model supports it if supports_online_mode(selected_model) and DEFAULT_ONLINE_MODE == "on": online_mode_enabled = True console.print(f"[bold cyan]Selected: {selected_model['name']} ({selected_model['id']})[/]") @@ -1634,7 +4162,6 @@ def chat(): set_config('log_max_size_mb', str(new_size_mb)) LOG_MAX_SIZE_MB = new_size_mb - # Reload logging configuration immediately app_logger = reload_logging_config() console.print(f"[bold green]Log size limit set to {new_size_mb} MB and applied immediately.[/]") @@ -1715,6 +4242,14 @@ def chat(): DEFAULT_MODEL_ID = get_config('default_model') memory_status = "Enabled" if conversation_memory_enabled else "Disabled" memory_tracked = len(session_history) - memory_start_index if conversation_memory_enabled else 0 + + mcp_status = "Enabled" if mcp_manager.enabled else "Disabled" + mcp_folders = len(mcp_manager.allowed_folders) + mcp_databases = len(mcp_manager.databases) + mcp_mode = mcp_manager.mode + gitignore_status = "enabled" if (mcp_manager.server and mcp_manager.server.respect_gitignore) else "disabled" + gitignore_patterns = len(mcp_manager.server.gitignore_parser.patterns) if mcp_manager.server else 0 + table = Table("Setting", "Value", show_header=True, header_style="bold magenta", width=console.width - 10) table.add_row("API Key", API_KEY or "[Not set]") table.add_row("Base URL", OPENROUTER_BASE_URL or "[Not set]") @@ -1728,6 +4263,9 @@ def chat(): table.add_row("Current Model", "[Not set]" if selected_model is None else str(selected_model["name"])) table.add_row("Default Online Mode", "Enabled" if DEFAULT_ONLINE_MODE == "on" else "Disabled") table.add_row("Session Online Mode", "Enabled" if online_mode_enabled else "Disabled") + table.add_row("MCP Status", f"{mcp_status} ({mcp_folders} folders, {mcp_databases} DBs)") + table.add_row("MCP Mode", f"{mcp_mode}") + table.add_row("MCP .gitignore", f"{gitignore_status} ({gitignore_patterns} patterns)") table.add_row("Max Token", str(MAX_TOKEN)) table.add_row("Session Token", "[Not set]" if session_max_token == 0 else str(session_max_token)) table.add_row("Session System Prompt", session_system_prompt or "[Not set]") @@ -1767,22 +4305,24 @@ def chat(): if user_input.lower() == "/clear" or user_input.lower() == "/cl": clear_screen() DEFAULT_MODEL_ID = get_config('default_model') - token_value = session_max_token if session_max_token != 0 else " Not set" + token_value = session_max_token if session_max_token != 0 else "Not set" console.print(f"[bold cyan]Token limits: Max= {MAX_TOKEN}, Session={token_value}[/]") console.print("[bold blue]Active model[/] [bold red]%s[/]" %(str(selected_model["name"]) if selected_model else "None")) if online_mode_enabled: console.print("[bold cyan]Online mode: Enabled (web search active)[/]") + if mcp_manager.enabled: + gitignore_status = "on" if (mcp_manager.server and mcp_manager.server.respect_gitignore) else "off" + mode_display = "Files" if mcp_manager.mode == "files" else f"DB #{mcp_manager.selected_db_index + 1}" + console.print(f"[bold cyan]MCP: Enabled (Mode: {mode_display}, {len(mcp_manager.allowed_folders)} folders, {len(mcp_manager.databases)} DBs, .gitignore {gitignore_status})[/]") continue if user_input.lower().startswith("/help"): args = user_input[6:].strip() - # If a specific command is requested if args: show_command_help(args) continue - # Otherwise show the 
full help menu help_table = Table("Command", "Description", "Example", show_header=True, header_style="bold cyan", width=console.width - 10) # SESSION COMMANDS @@ -1793,50 +4333,117 @@ def chat(): ) help_table.add_row( "/clear or /cl", - "Clear the terminal screen for a clean interface. You can also use the keycombo [bold]ctrl+l[/]", + "Clear terminal screen. Keyboard shortcut: [bold]Ctrl+L[/]", "/clear\n/cl" ) help_table.add_row( - "/help [command]", - "Show this help menu or get detailed help for a specific command.", - "/help\n/help /model" + "/help [command|topic]", + "Show help menu or detailed help. Use '/help mcp' for MCP guide.", + "/help\n/help mcp" ) help_table.add_row( "/memory [on|off]", - "Toggle conversation memory. ON sends history (AI remembers), OFF sends only current message (saves cost).", + "Toggle conversation memory. ON sends history, OFF is stateless (saves cost).", "/memory\n/memory off" ) help_table.add_row( "/next", - "View the next response in history.", + "View next response in history.", "/next" ) help_table.add_row( "/online [on|off]", - "Enable/disable online mode (web search) for current session. Overrides default setting.", + "Enable/disable online mode (web search) for session.", "/online on\n/online off" ) help_table.add_row( "/paste [prompt]", - "Paste plain text/code from clipboard and send to AI. Optional prompt can be added.", - "/paste\n/paste Explain this code" + "Paste clipboard text/code. Optional prompt.", + "/paste\n/paste Explain" ) help_table.add_row( "/prev", - "View the previous response in history.", + "View previous response in history.", "/prev" ) help_table.add_row( "/reset", - "Clear conversation history and reset system prompt (resets session metrics). Requires confirmation.", + "Clear history and reset system prompt (requires confirmation).", "/reset" ) help_table.add_row( "/retry", - "Resend the last prompt from history.", + "Resend last prompt.", "/retry" ) + # MCP COMMANDS + help_table.add_row( + "[bold yellow]━━━ MCP (FILE & DATABASE ACCESS) ━━━[/]", + "", + "" + ) + help_table.add_row( + "/mcp enable", + "Start MCP server for file and database access.", + "/mcp enable" + ) + help_table.add_row( + "/mcp disable", + "Stop MCP server.", + "/mcp disable" + ) + help_table.add_row( + "/mcp status", + "Show mode, stats, folders, databases, and .gitignore status.", + "/mcp status" + ) + help_table.add_row( + "/mcp add ", + "Add folder for file access (auto-loads .gitignore).", + "/mcp add ~/Documents" + ) + help_table.add_row( + "/mcp add db ", + "Add SQLite database for querying.", + "/mcp add db ~/app.db" + ) + help_table.add_row( + "/mcp list", + "List all allowed folders with stats.", + "/mcp list" + ) + help_table.add_row( + "/mcp db list", + "List all databases with details.", + "/mcp db list" + ) + help_table.add_row( + "/mcp db ", + "Switch to database mode (select DB by number).", + "/mcp db 1" + ) + help_table.add_row( + "/mcp files", + "Switch to file mode (default).", + "/mcp files" + ) + help_table.add_row( + "/mcp remove ", + "Remove folder by path or number.", + "/mcp remove 2" + ) + help_table.add_row( + "/mcp remove db ", + "Remove database by number.", + "/mcp remove db 1" + ) + help_table.add_row( + "/mcp gitignore [on|off]", + "Toggle .gitignore filtering (respects project excludes).", + "/mcp gitignore on" + ) + # MODEL COMMANDS help_table.add_row( "[bold yellow]━━━ MODEL COMMANDS ━━━[/]", @@ -1845,16 +4452,16 @@ def chat(): ) help_table.add_row( "/info [model_id]", - "Display detailed info (pricing, modalities, 
context length, online support, etc.) for current or specified model.", + "Display model details (pricing, capabilities, context).", "/info\n/info gpt-4o" ) help_table.add_row( "/model [search]", - "Select or change the current model for the session. Shows image and online capabilities. Supports searching by name or ID.", + "Select/change model. Shows image and online capabilities.", "/model\n/model gpt" ) - # CONFIGURATION COMMANDS + # CONFIGURATION help_table.add_row( "[bold yellow]━━━ CONFIGURATION ━━━[/]", "", @@ -1862,56 +4469,56 @@ def chat(): ) help_table.add_row( "/config", - "View all current configurations, including limits, credits, and history.", + "View all configurations.", "/config" ) help_table.add_row( "/config api", - "Set or update the OpenRouter API key.", + "Set/update API key.", "/config api" ) help_table.add_row( - "/config costwarning [value]", - "Set the cost warning threshold. Alerts when response exceeds this cost (in USD).", + "/config costwarning [val]", + "Set cost warning threshold (USD).", "/config costwarning 0.05" ) help_table.add_row( "/config log [size_mb]", - "Set log file size limit in MB. Older logs are rotated automatically. Takes effect immediately.", + "Set log file size limit. Takes effect immediately.", "/config log 20" ) help_table.add_row( "/config loglevel [level]", - "Set log verbosity level. Valid levels: debug, info, warning, error, critical. Takes effect immediately.", - "/config loglevel debug\n/config loglevel warning" + "Set log verbosity (debug/info/warning/error/critical).", + "/config loglevel debug" ) help_table.add_row( - "/config maxtoken [value]", - "Set stored max token limit (persisted in DB). View current if no value provided.", + "/config maxtoken [val]", + "Set stored max token limit.", "/config maxtoken 50000" ) help_table.add_row( "/config model [search]", - "Set default model that loads on startup. Shows image and online capabilities. Doesn't change current session model.", + "Set default startup model.", "/config model gpt" ) help_table.add_row( "/config online [on|off]", - "Set default online mode for new model selections. Use '/online on|off' to override current session.", + "Set default online mode for new models.", "/config online on" ) help_table.add_row( "/config stream [on|off]", - "Enable or disable response streaming.", + "Enable/disable response streaming.", "/config stream off" ) help_table.add_row( "/config url", - "Set or update the base URL for OpenRouter API.", + "Set/update base URL.", "/config url" ) - # TOKEN & SYSTEM COMMANDS + # TOKEN & SYSTEM help_table.add_row( "[bold yellow]━━━ TOKEN & SYSTEM ━━━[/]", "", @@ -1919,21 +4526,21 @@ def chat(): ) help_table.add_row( "/maxtoken [value]", - "Set temporary session token limit (≤ stored max). View current if no value provided.", + "Set temporary session token limit.", "/maxtoken 2000" ) help_table.add_row( "/middleout [on|off]", - "Enable/disable middle-out transform to compress prompts exceeding context size.", + "Enable/disable prompt compression for large contexts.", "/middleout on" ) help_table.add_row( "/system [prompt|clear]", - "Set session-level system prompt to guide AI behavior. Use 'clear' to reset.", + "Set session system prompt. Use 'clear' to reset.", "/system You are a Python expert" ) - # CONVERSATION MANAGEMENT + # CONVERSATION MGMT help_table.add_row( "[bold yellow]━━━ CONVERSATION MGMT ━━━[/]", "", @@ -1941,31 +4548,31 @@ def chat(): ) help_table.add_row( "/delete ", - "Delete a saved conversation by name or number (from /list). 
Requires confirmation.", + "Delete saved conversation (requires confirmation).", "/delete my_chat\n/delete 3" ) help_table.add_row( - "/export ", - "Export conversation to file. Formats: md (Markdown), json (JSON), html (HTML).", - "/export md notes.md\n/export html report.html" + "/export ", + "Export conversation (md/json/html).", + "/export md notes.md" ) help_table.add_row( "/list", - "List all saved conversations with numbers, message counts, and timestamps.", + "List saved conversations with numbers and stats.", "/list" ) help_table.add_row( "/load ", - "Load a saved conversation by name or number (from /list). Resets session metrics.", + "Load saved conversation.", "/load my_chat\n/load 3" ) help_table.add_row( "/save ", - "Save current conversation history to database.", + "Save current conversation.", "/save my_chat" ) - # MONITORING & STATS + # MONITORING help_table.add_row( "[bold yellow]━━━ MONITORING & STATS ━━━[/]", "", @@ -1973,12 +4580,12 @@ def chat(): ) help_table.add_row( "/credits", - "Display credits left on your OpenRouter account with alerts.", + "Display OpenRouter credits with alerts.", "/credits" ) help_table.add_row( "/stats", - "Display session cost summary: tokens, cost, credits left, and warnings.", + "Display session stats: tokens, cost, credits.", "/stats" ) @@ -1990,17 +4597,17 @@ def chat(): ) help_table.add_row( "@/path/to/file", - "Attach files to messages: images (PNG, JPG, etc.), PDFs, and code files (.py, .js, etc.).", - "Debug @script.py\nSummarize @document.pdf\nAnalyze @image.png" + "Attach files: images (PNG, JPG), PDFs, code files.", + "Debug @script.py\nAnalyze @image.png" ) help_table.add_row( "Clipboard paste", - "Use /paste to send clipboard content (plain text/code) to AI.", - "/paste\n/paste Explain this" + "Use /paste to send clipboard content.", + "/paste\n/paste Explain" ) help_table.add_row( "// escape", - "Start message with // to send a literal / character (e.g., //command sends '/command' as text, not a command)", + "Start with // to send literal / character.", "//help sends '/help' as text" ) @@ -2012,7 +4619,7 @@ def chat(): ) help_table.add_row( "exit | quit | bye", - "Quit the chat application and display session summary.", + "Quit with session summary.", "exit" ) @@ -2020,7 +4627,7 @@ def chat(): help_table, title="[bold cyan]oAI Chat Help (Version %s)[/]" % version, title_align="center", - subtitle="💡 Tip: Commands are case-insensitive • Use /help for detailed help • Memory ON by default • Use // to escape / • Visit: https://iurl.no/oai", + subtitle="💡 Commands are case-insensitive • Use /help for details • Use /help mcp for MCP guide • Memory ON by default • MCP has file & DB modes • Use // to escape / • Visit: https://iurl.no/oai", subtitle_align="center", border_style="cyan" )) @@ -2029,10 +4636,11 @@ def chat(): if not selected_model: console.print("[bold yellow]Select a model first with '/model'.[/]") continue - # Process file attachments with PDF support - content_blocks = [] + + # ============================================================ + # PROCESS FILE ATTACHMENTS + # ============================================================ text_part = user_input - file_attachments = [] for match in re.finditer(r'@([^\s]+)', user_input): file_path = match.group(1) expanded_path = os.path.expanduser(os.path.abspath(file_path)) @@ -2049,7 +4657,6 @@ def chat(): with open(expanded_path, 'rb') as f: file_data = f.read() - # Handle images if mime_type and mime_type.startswith('image/'): modalities = selected_model.get("architecture", 
{}).get("input_modalities", []) if "image" not in modalities: @@ -2060,7 +4667,6 @@ def chat(): content_blocks.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_data}"}}) console.print(f"[dim green]✓ Image attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - # Handle PDFs elif mime_type == 'application/pdf' or file_ext == '.pdf': modalities = selected_model.get("architecture", {}).get("input_modalities", []) supports_pdf = any(mod in modalities for mod in ["document", "pdf", "file"]) @@ -2072,7 +4678,6 @@ def chat(): content_blocks.append({"type": "image_url", "image_url": {"url": f"data:application/pdf;base64,{b64_data}"}}) console.print(f"[dim green]✓ PDF attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - # Handle code/text files elif (mime_type == 'text/plain' or file_ext in SUPPORTED_CODE_EXTENSIONS): text_content = file_data.decode('utf-8') content_blocks.append({"type": "text", "text": f"Code File: {os.path.basename(expanded_path)}\n\n{text_content}"}) @@ -2105,12 +4710,28 @@ def chat(): console.print("[bold red]Prompt cannot be empty.[/]") continue - # Build API messages with conversation history if memory is enabled + # ============================================================ + # BUILD API MESSAGES + # ============================================================ api_messages = [] if session_system_prompt: api_messages.append({"role": "system", "content": session_system_prompt}) + # Add database context if in database mode + if mcp_manager.enabled and mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: + db = mcp_manager.databases[mcp_manager.selected_db_index] + db_context = f"""You are currently connected to SQLite database: {db['name']} +Available tables: {', '.join(db['tables'])} + +You can: +- Inspect the database schema with inspect_database +- Search for data across tables with search_database +- Execute read-only SQL queries with query_database + +All queries are read-only. 
INSERT/UPDATE/DELETE are not allowed.""" + api_messages.append({"role": "system", "content": db_context}) + if conversation_memory_enabled: for i in range(memory_start_index, len(session_history)): history_entry = session_history[i] @@ -2125,10 +4746,12 @@ def chat(): api_messages.append({"role": "user", "content": message_content}) - # Get effective model ID with :online suffix if enabled + # Get effective model ID effective_model_id = get_effective_model_id(selected_model["id"], online_mode_enabled) - # Build API params + # ======================================================================== + # BUILD API PARAMS WITH MCP TOOLS SUPPORT + # ======================================================================== api_params = { "model": effective_model_id, "messages": api_messages, @@ -2138,6 +4761,27 @@ def chat(): "X-Title": APP_NAME } } + + # Add MCP tools if enabled + mcp_tools_added = False + mcp_tools = [] + if mcp_manager.enabled: + if supports_function_calling(selected_model): + mcp_tools = mcp_manager.get_tools_schema() + if mcp_tools: + api_params["tools"] = mcp_tools + api_params["tool_choice"] = "auto" + mcp_tools_added = True + + # IMPORTANT: Disable streaming if MCP tools are present + if api_params["stream"]: + api_params["stream"] = False + app_logger.info("Disabled streaming due to MCP tool calls") + + app_logger.info(f"Added {len(mcp_tools)} MCP tools to request (mode: {mcp_manager.mode})") + else: + app_logger.debug(f"Model {selected_model['id']} doesn't support function calling - MCP tools not added") + if session_max_token > 0: api_params["max_tokens"] = session_max_token if middle_out_enabled: @@ -2148,17 +4792,44 @@ def chat(): history_messages_count = len(session_history) - memory_start_index if conversation_memory_enabled else 0 memory_status = "ON" if conversation_memory_enabled else "OFF" online_status = "ON" if online_mode_enabled else "OFF" - app_logger.info(f"API Request: Model '{effective_model_id}' (Online: {online_status}), Prompt length: {len(text_part)} chars, {file_count} file(s) attached, Memory: {memory_status}, History sent: {history_messages_count} messages, Transforms: middle-out {'enabled' if middle_out_enabled else 'disabled'}, App: {APP_NAME} ({APP_URL}).") + + if mcp_tools_added: + if mcp_manager.mode == "files": + mcp_status = f"ON (Files, {len(mcp_manager.allowed_folders)} folders, {len(mcp_tools)} tools)" + elif mcp_manager.mode == "database": + db = mcp_manager.databases[mcp_manager.selected_db_index] + mcp_status = f"ON (DB: {db['name']}, {len(mcp_tools)} tools)" + else: + mcp_status = f"ON ({len(mcp_tools)} tools)" + else: + mcp_status = "OFF" + + app_logger.info(f"API Request: Model '{effective_model_id}' (Online: {online_status}, MCP: {mcp_status}), Prompt length: {len(text_part)} chars, {file_count} file(s) attached, Memory: {memory_status}, History sent: {history_messages_count} messages.") - # Send and handle response - is_streaming = STREAM_ENABLED == "on" + # Send request + is_streaming = api_params["stream"] if is_streaming: console.print("[bold green]Streaming response...[/] [dim](Press Ctrl+C to cancel)[/]") if online_mode_enabled: console.print("[dim cyan]🌐 Online mode active - model has web search access[/]") console.print("") else: - console.print("[bold green]Thinking...[/]", end="\r") + thinking_msg = "Thinking" + if mcp_tools_added: + if mcp_manager.mode == "database": + thinking_msg = "Thinking (database tools available)" + else: + thinking_msg = "Thinking (MCP tools available)" + console.print(f"[bold 
green]{thinking_msg}...[/]", end="\r") + if online_mode_enabled: + console.print(f"[dim cyan]🌐 Online mode active[/]") + if mcp_tools_added: + if mcp_manager.mode == "files": + gitignore_status = "on" if (mcp_manager.server and mcp_manager.server.respect_gitignore) else "off" + console.print(f"[dim cyan]🔧 MCP active - AI can access {len(mcp_manager.allowed_folders)} folder(s) with {len(mcp_tools)} tools (.gitignore {gitignore_status})[/]") + elif mcp_manager.mode == "database": + db = mcp_manager.databases[mcp_manager.selected_db_index] + console.print(f"[dim cyan]🗄️ MCP active - AI can query database: {db['name']} ({len(db['tables'])} tables, {len(mcp_tools)} tools)[/]") start_time = time.time() try: @@ -2169,8 +4840,150 @@ def chat(): app_logger.error(f"API Error: {type(e).__name__}: {e}") continue + # ======================================================================== + # HANDLE TOOL CALLS (MCP FUNCTION CALLING) + # ======================================================================== + tool_call_loop_count = 0 + max_tool_loops = 5 + + while tool_call_loop_count < max_tool_loops: + wants_tool_call = False + + if hasattr(response, 'choices') and response.choices: + message = response.choices[0].message + + if hasattr(message, 'tool_calls') and message.tool_calls: + wants_tool_call = True + tool_calls = message.tool_calls + + console.print(f"\n[dim yellow]🔧 AI requesting {len(tool_calls)} tool call(s)...[/]") + app_logger.info(f"Model requested {len(tool_calls)} tool calls") + + tool_results = [] + for tool_call in tool_calls: + tool_name = tool_call.function.name + + try: + tool_args = json.loads(tool_call.function.arguments) + except json.JSONDecodeError as e: + app_logger.error(f"Failed to parse tool arguments: {e}") + tool_results.append({ + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_name, + "content": json.dumps({"error": f"Invalid arguments: {e}"}) + }) + continue + + args_display = ', '.join(f'{k}="{v}"' if isinstance(v, str) else f'{k}={v}' for k, v in tool_args.items()) + console.print(f"[dim cyan] → Calling: {tool_name}({args_display})[/]") + app_logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}") + + try: + result = asyncio.run(mcp_manager.call_tool(tool_name, **tool_args)) + except Exception as e: + app_logger.error(f"MCP tool execution error: {e}") + result = {"error": str(e)} + + if 'error' in result: + result_content = json.dumps({"error": result['error']}) + console.print(f"[dim red] ✗ Error: {result['error']}[/]") + app_logger.warning(f"MCP tool {tool_name} returned error: {result['error']}") + else: + # Display appropriate success message based on tool + if tool_name == 'search_files': + count = result.get('count', 0) + console.print(f"[dim green] ✓ Found {count} file(s)[/]") + elif tool_name == 'read_file': + size = result.get('size', 0) + truncated = " (truncated)" if result.get('truncated') else "" + console.print(f"[dim green] ✓ Read file ({size} bytes{truncated})[/]") + elif tool_name == 'list_directory': + count = result.get('count', 0) + truncated = " (limited to 1000)" if result.get('truncated') else "" + console.print(f"[dim green] ✓ Listed {count} item(s){truncated}[/]") + elif tool_name == 'inspect_database': + if 'table' in result: + console.print(f"[dim green] ✓ Inspected table: {result['table']} ({result['row_count']} rows)[/]") + else: + console.print(f"[dim green] ✓ Inspected database ({result['table_count']} tables)[/]") + elif tool_name == 'search_database': + count = result.get('count', 0) + 
console.print(f"[dim green] ✓ Found {count} match(es)[/]") + elif tool_name == 'query_database': + count = result.get('count', 0) + truncated = " (truncated)" if result.get('truncated') else "" + console.print(f"[dim green] ✓ Query returned {count} row(s){truncated}[/]") + else: + console.print(f"[dim green] ✓ Success[/]") + + result_content = json.dumps(result, indent=2) + app_logger.info(f"MCP tool {tool_name} succeeded") + + tool_results.append({ + "tool_call_id": tool_call.id, + "role": "tool", + "name": tool_name, + "content": result_content + }) + + api_messages.append({ + "role": "assistant", + "content": message.content, + "tool_calls": [ + { + "id": tc.id, + "type": "function", + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments + } + } for tc in tool_calls + ] + }) + api_messages.extend(tool_results) + + console.print("\n[dim cyan]💭 Processing tool results...[/]") + app_logger.info("Sending tool results back to model") + + followup_params = { + "model": effective_model_id, + "messages": api_messages, + "stream": False, # Keep streaming disabled for follow-ups + "http_headers": { + "HTTP-Referer": APP_URL, + "X-Title": APP_NAME + } + } + + if mcp_tools_added: + followup_params["tools"] = mcp_tools + followup_params["tool_choice"] = "auto" + + if session_max_token > 0: + followup_params["max_tokens"] = session_max_token + + try: + response = client.chat.send(**followup_params) + tool_call_loop_count += 1 + app_logger.info(f"Follow-up request successful (loop {tool_call_loop_count})") + except Exception as e: + console.print(f"[bold red]Error getting follow-up response: {e}[/]") + app_logger.error(f"Follow-up API error: {e}") + break + + if not wants_tool_call: + break + + if tool_call_loop_count >= max_tool_loops: + console.print(f"[bold yellow]⚠️ Reached maximum tool call depth ({max_tool_loops}). 
Stopping.[/]") + app_logger.warning(f"Hit max tool call loop limit: {max_tool_loops}") + response_time = time.time() - start_time + # ======================================================================== + # PROCESS FINAL RESPONSE + # ======================================================================== full_response = "" if is_streaming: try: @@ -2194,7 +5007,7 @@ def chat(): continue else: full_response = response.choices[0].message.content if response.choices else "" - console.print(f"\r{' ' * 20}\r", end="") + console.print(f"\r{' ' * 50}\r", end="") if full_response: if not is_streaming: @@ -2204,7 +5017,6 @@ def chat(): session_history.append({'prompt': user_input, 'response': full_response}) current_index = len(session_history) - 1 - # Process metrics usage = getattr(response, 'usage', None) input_tokens = usage.input_tokens if usage and hasattr(usage, 'input_tokens') else 0 output_tokens = usage.output_tokens if usage and hasattr(usage, 'output_tokens') else 0 @@ -2215,9 +5027,8 @@ def chat(): total_cost += msg_cost message_count += 1 - app_logger.info(f"Response: Tokens - I:{input_tokens} O:{output_tokens} T:{input_tokens + output_tokens}, Cost: ${msg_cost:.4f}, Time: {response_time:.2f}s, Online: {online_mode_enabled}") + app_logger.info(f"Response: Tokens - I:{input_tokens} O:{output_tokens} T:{input_tokens + output_tokens}, Cost: ${msg_cost:.4f}, Time: {response_time:.2f}s, Tool loops: {tool_call_loop_count}") - # Per-message metrics display if conversation_memory_enabled: context_count = len(session_history) - memory_start_index context_info = f", Context: {context_count} msg(s)" if context_count > 1 else "" @@ -2225,12 +5036,18 @@ def chat(): context_info = ", Memory: OFF" online_info = " 🌐" if online_mode_enabled else "" - console.print(f"\n[dim blue]📊 Metrics: {input_tokens + output_tokens} tokens | ${msg_cost:.4f} | {response_time:.2f}s{context_info}{online_info} | Session: {total_input_tokens + total_output_tokens} tokens | ${total_cost:.4f}[/]") + mcp_info = "" + if mcp_tools_added: + if mcp_manager.mode == "files": + mcp_info = " 🔧" + elif mcp_manager.mode == "database": + mcp_info = " 🗄️" + tool_info = f" ({tool_call_loop_count} tool loop(s))" if tool_call_loop_count > 0 else "" + console.print(f"\n[dim blue]📊 Metrics: {input_tokens + output_tokens} tokens | ${msg_cost:.4f} | {response_time:.2f}s{context_info}{online_info}{mcp_info}{tool_info} | Session: {total_input_tokens + total_output_tokens} tokens | ${total_cost:.4f}[/]") - # Cost and credit alerts warnings = [] if msg_cost > COST_WARNING_THRESHOLD: - warnings.append(f"High cost alert: This response exceeded configurable threshold ${COST_WARNING_THRESHOLD:.4f}") + warnings.append(f"High cost alert: This response exceeded threshold ${COST_WARNING_THRESHOLD:.4f}") credits_data = get_credits(API_KEY, OPENROUTER_BASE_URL) if credits_data: warning_alerts = check_credit_alerts(credits_data) diff --git a/requirements.txt b/requirements.txt index 8a0a132..ea6105b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,38 +1,26 @@ -anyio==4.11.0 -beautifulsoup4==4.14.2 -charset-normalizer==3.4.4 -click==8.3.1 -docopt==0.6.2 -h11==0.16.0 -httpcore==1.0.9 -httpx==0.28.1 -idna==3.11 -latex2mathml==3.78.1 -loguru==0.7.3 -markdown-it-py==4.0.0 -markdown2==2.5.4 -mdurl==0.1.2 -natsort==8.4.0 -openrouter==0.0.19 -packaging==25.0 -pipreqs==0.4.13 -prompt-toolkit==3.0.52 -Pygments==2.19.2 -pyperclip==1.11.0 -python-dateutil==2.9.0.post0 -python-magic==0.4.27 -PyYAML==6.0.3 -requests==2.32.5 -rich==14.2.0 
-shellingham==1.5.4 -six==1.17.0 -sniffio==1.3.1 -soupsieve==2.8 -svgwrite==1.4.3 -tqdm==4.67.1 -typer==0.20.0 -typing-extensions==4.15.0 -urllib3==2.5.0 -wavedrom==2.0.3.post3 -wcwidth==0.2.14 -yarg==0.1.10 +# oai.py v2.1.0-beta - Core Dependencies +anyio>=4.11.0 +charset-normalizer>=3.4.4 +click>=8.3.1 +h11>=0.16.0 +httpcore>=1.0.9 +httpx>=0.28.1 +idna>=3.11 +markdown-it-py>=4.0.0 +mdurl>=0.1.2 +openrouter>=0.0.19 +packaging>=25.0 +prompt-toolkit>=3.0.52 +Pygments>=2.19.2 +pyperclip>=1.11.0 +requests>=2.32.5 +rich>=14.2.0 +shellingham>=1.5.4 +sniffio>=1.3.1 +typer>=0.20.0 +typing-extensions>=4.15.0 +urllib3>=2.5.0 +wcwidth>=0.2.14 + +# MCP (Model Context Protocol) +mcp>=1.25.0 \ No newline at end of file -- 2.49.1 From d4f1a1c6a4fd68e8f61e19e86a4a23fdb5653125 Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Wed, 7 Jan 2026 08:01:33 +0100 Subject: [PATCH 02/10] Fixed bug in ctrl+c handling during response streaming --- oai.py | 61 ++++++++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 47 insertions(+), 14 deletions(-) diff --git a/oai.py b/oai.py index e679ce5..be8c1b7 100644 --- a/oai.py +++ b/oai.py @@ -4985,26 +4985,59 @@ All queries are read-only. INSERT/UPDATE/DELETE are not allowed.""" # PROCESS FINAL RESPONSE # ======================================================================== full_response = "" + stream_interrupted = False + if is_streaming: try: with Live("", console=console, refresh_per_second=10, auto_refresh=True) as live: - for chunk in response: - if hasattr(chunk, 'error') and chunk.error: - console.print(f"\n[bold red]Stream error: {chunk.error.message}[/]") - app_logger.error(f"Stream error: {chunk.error.message}") - break - if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: - content_chunk = chunk.choices[0].delta.content - full_response += content_chunk - md = Markdown(full_response) - live.update(md) + try: + for chunk in response: + if hasattr(chunk, 'error') and chunk.error: + console.print(f"\n[bold red]Stream error: {chunk.error.message}[/]") + app_logger.error(f"Stream error: {chunk.error.message}") + break + if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: + content_chunk = chunk.choices[0].delta.content + full_response += content_chunk + md = Markdown(full_response) + live.update(md) + except KeyboardInterrupt: + stream_interrupted = True + console.print("\n[bold yellow]⚠️ Streaming interrupted![/]") + app_logger.info("Streaming interrupted by user (Ctrl+C)") + except Exception as stream_error: + stream_interrupted = True + console.print(f"\n[bold red]Stream error: {stream_error}[/]") + app_logger.error(f"Stream processing error: {stream_error}") - console.print("") + if not stream_interrupted: + console.print("") except KeyboardInterrupt: - console.print("\n[bold yellow]Streaming cancelled![/]") - app_logger.info("Streaming cancelled by user") - continue + # Outer interrupt handler (in case inner misses it) + stream_interrupted = True + console.print("\n[bold yellow]⚠️ Streaming interrupted![/]") + app_logger.info("Streaming interrupted by user (outer)") + except Exception as e: + stream_interrupted = True + console.print(f"\n[bold red]Error during streaming: {e}[/]") + app_logger.error(f"Streaming error: {e}") + + # If stream was interrupted, skip processing and continue to next prompt + if stream_interrupted: + if full_response: + console.print(f"\n[dim yellow]Partial response received ({len(full_response)} chars). 
Discarding...[/]") + console.print("[dim blue]💡 Ready for next prompt[/]\n") + app_logger.info("Stream cleanup completed, returning to prompt") + + # Force close the response if possible + try: + if hasattr(response, 'close'): + response.close() + except: + pass + + continue # Now it's safe to continue else: full_response = response.choices[0].message.content if response.choices else "" console.print(f"\r{' ' * 50}\r", end="") -- 2.49.1 From 2c9f33868e89e0d1d6b9565488f31fe665cb9f9c Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Wed, 7 Jan 2026 09:58:58 +0100 Subject: [PATCH 03/10] More bugfixes --- oai.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/oai.py b/oai.py index be8c1b7..f64a7fe 100644 --- a/oai.py +++ b/oai.py @@ -5034,9 +5034,18 @@ All queries are read-only. INSERT/UPDATE/DELETE are not allowed.""" try: if hasattr(response, 'close'): response.close() + elif hasattr(response, '__exit__'): + response.__exit__(None, None, None) except: pass + # Recreate client to be safe + try: + client = OpenRouter(api_key=API_KEY) + app_logger.info("Client recreated after stream interruption") + except: + pass + continue # Now it's safe to continue else: full_response = response.choices[0].message.content if response.choices else "" -- 2.49.1 From c305d5cf49d1d52177b86f1dd1a844b4cbc2be2d Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Wed, 7 Jan 2026 11:09:13 +0100 Subject: [PATCH 04/10] More bugfixes --- oai.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/oai.py b/oai.py index f64a7fe..991d679 100644 --- a/oai.py +++ b/oai.py @@ -313,14 +313,14 @@ Use /help mcp for comprehensive guide with examples.''' 'notes': 'Without arguments, shows info for the currently selected model. Displays pricing per million tokens, supported modalities (text, image, etc.), and parameter support.' }, '/model': { - 'description': 'Select or change the AI model for the current session. Shows image and online capabilities.', + 'description': 'Select or change the AI model for the current session. Shows image, online, and tool capabilities.', 'usage': '/model [search_term]', 'examples': [ ('List all models', '/model'), ('Search for GPT models', '/model gpt'), ('Search for Claude models', '/model claude'), ], - 'notes': 'Models are numbered for easy selection. The table shows Image (✓ if model accepts images) and Online (✓ if model supports web search) columns.' + 'notes': 'Models are numbered for easy selection. The table shows Image (✓ if model accepts images), Online (✓ if model supports web search), and Tools (✓ if model supports function calling for MCP).' }, '/config': { 'description': 'View or modify application configuration settings.', @@ -2586,6 +2586,10 @@ def supports_function_calling(model: Dict[str, Any]) -> bool: supported_params = model.get("supported_parameters", []) return "tools" in supported_params or "functions" in supported_params +def supports_tools(model: Dict[str, Any]) -> bool: + """Check if model supports tools/function calling (same as supports_function_calling).""" + return supports_function_calling(model) + def check_for_updates(current_version: str) -> str: """Check for updates.""" try: @@ -4008,11 +4012,12 @@ def chat(): console.print(f"[bold red]No models match '{search_term}'. 
Try '/model'.[/]") continue - table = Table("No.", "Name", "ID", "Image", "Online", show_header=True, header_style="bold magenta") + table = Table("No.", "Name", "ID", "Image", "Online", "Tools", show_header=True, header_style="bold magenta") for i, model in enumerate(filtered_models, 1): image_support = "[green]✓[/green]" if has_image_capability(model) else "[red]✗[/red]" online_support = "[green]✓[/green]" if supports_online_mode(model) else "[red]✗[/red]" - table.add_row(str(i), model["name"], model["id"], image_support, online_support) + tools_support = "[green]✓[/green]" if supports_function_calling(model) else "[red]✗[/red]" + table.add_row(str(i), model["name"], model["id"], image_support, online_support, tools_support) title = f"[bold green]Available Models ({'All' if not search_term else f'Search: {search_term}'})[/]" display_paginated_table(table, title) @@ -4215,11 +4220,12 @@ def chat(): console.print(f"[bold red]No models match '{search_term}'. Try without search.[/]") continue - table = Table("No.", "Name", "ID", "Image", "Online", show_header=True, header_style="bold magenta") + table = Table("No.", "Name", "ID", "Image", "Online", "Tools", show_header=True, header_style="bold magenta") for i, model in enumerate(filtered_models, 1): image_support = "[green]✓[/green]" if has_image_capability(model) else "[red]✗[/red]" online_support = "[green]✓[/green]" if supports_online_mode(model) else "[red]✗[/red]" - table.add_row(str(i), model["name"], model["id"], image_support, online_support) + tools_support = "[green]✓[/green]" if supports_function_calling(model) else "[red]✗[/red]" + table.add_row(str(i), model["name"], model["id"], image_support, online_support, tools_support) title = f"[bold green]Available Models for Default ({'All' if not search_term else f'Search: {search_term}'})[/]" display_paginated_table(table, title) -- 2.49.1 From 8a113a9bbe4ea03995e0e3d75c00d7f733a89b2d Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Thu, 8 Jan 2026 13:12:17 +0100 Subject: [PATCH 05/10] Fixed some regex for files --- oai.py | 31 +++++++++++++++++++++++++++++-- 1 file changed, 29 insertions(+), 2 deletions(-) diff --git a/oai.py b/oai.py index 991d679..d58266b 100644 --- a/oai.py +++ b/oai.py @@ -4647,18 +4647,42 @@ def chat(): # PROCESS FILE ATTACHMENTS # ============================================================ text_part = user_input - for match in re.finditer(r'@([^\s]+)', user_input): + file_attachments = [] + content_blocks = [] + + # Smart file detection with extension whitelist + # Only matches files with known extensions or clear path prefixes + # + # Valid matches: + # @/path/to/file + # @~/file.txt + # @./script.py + # @report.pdf + # + # Excludes: + # @domain.com (no whitelisted extension) + # @mail.server.eu (TLD not in whitelist) + # user@domain.com (email pattern) + # Diagnostic-Code: @server.eu (TLD not in whitelist) + + file_pattern = 
r'(?:^|\s)@((?:[~/][\S]+)|(?:\.[\S]+)|(?:[\w][\w-]*\.(?:py|txt|md|log|json|csv|pdf|png|jpg|jpeg|gif|bmp|webp|svg|ico|zip|tar|gz|bz2|7z|rar|xz|js|ts|jsx|tsx|vue|html|css|scss|sass|less|xml|yaml|yml|toml|ini|conf|cfg|env|properties|sh|bash|zsh|fish|bat|cmd|ps1|c|cpp|cc|cxx|h|hpp|hxx|java|class|jar|war|rb|go|rs|swift|kt|kts|php|sql|db|sqlite|sqlite3|lock|gitignore|dockerignore|editorconfig|eslintrc|prettierrc|babelrc|nvmrc|npmrc|pyc|pyo|pyd|so|dll|dylib|exe|app|dmg|pkg|deb|rpm|apk|ipa|wasm|proto|graphql|graphqls|grpc|avro|parquet|orc|feather|arrow|hdf5|h5|mat|r|rdata|rds|pkl|pickle|joblib|npy|npz|safetensors|onnx|pt|pth|ckpt|pb|tflite|mlmodel|coreml|rknn)))(?=\s|$)' + + for match in re.finditer(file_pattern, user_input, re.IGNORECASE): file_path = match.group(1) expanded_path = os.path.expanduser(os.path.abspath(file_path)) + if not os.path.exists(expanded_path) or os.path.isdir(expanded_path): console.print(f"[bold red]File not found or is a directory: {expanded_path}[/]") continue + file_size = os.path.getsize(expanded_path) if file_size > 10 * 1024 * 1024: console.print(f"[bold red]File too large (>10MB): {expanded_path}[/]") continue + mime_type, _ = mimetypes.guess_type(expanded_path) file_ext = os.path.splitext(expanded_path)[1].lower() + try: with open(expanded_path, 'rb') as f: file_data = f.read() @@ -4704,7 +4728,9 @@ def chat(): console.print(f"[bold red]Error reading file {expanded_path}: {e}[/]") app_logger.error(f"File read error for {expanded_path}: {e}") continue - text_part = re.sub(r'@([^\s]+)', '', text_part).strip() + + # Remove file attachments from text (use same pattern as file detection) + text_part = re.sub(file_pattern, lambda m: m.group(0)[0] if m.group(0)[0].isspace() else '', text_part).strip() # Build message content if text_part or content_blocks: @@ -4716,6 +4742,7 @@ def chat(): console.print("[bold red]Prompt cannot be empty.[/]") continue + # ============================================================ # BUILD API MESSAGES # ============================================================ -- 2.49.1 From 8c49339452ebcd994ed1aa8749f28777f07ca0f6 Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Thu, 8 Jan 2026 15:00:59 +0100 Subject: [PATCH 06/10] Minor bug fixes --- README.md | 21 ++++++------------- oai.py | 63 ++++++++++++++++++++++++++++++++++++++++--------------- 2 files changed, 52 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index a3c2e5d..c07b3af 100644 --- a/README.md +++ b/README.md @@ -101,24 +101,10 @@ Download platform-specific binaries: # Extract and install unzip oai_vx.x.x_mac_arm64.zip # or `oai_vx.x.x-linux-x86_64.zip` chmod +x oai -mkdir -p ~/.local/bin +mkdir -p ~/.local/bin # Remember to add this to your path. Or just move to folder already in your $PATH mv oai ~/.local/bin/ ``` -### Option 3: Build Your Own Binary - -```bash -# Install build dependencies -pip install -r requirements.txt -pip install nuitka ordered-set zstandard - -# Run build script -chmod +x build.sh -./build.sh - -# Binary will be in dist/oai -cp dist/oai ~/.local/bin/ -``` ### Alternative: Shell Alias @@ -508,6 +494,7 @@ Full license: https://opensource.org/licenses/MIT **Rune Olsen** +- Homepage: https://ai.fubar.pm/ - Blog: https://blog.rune.pm - Project: https://iurl.no/oai @@ -527,3 +514,7 @@ Contributions welcome! Please: --- **Star ⭐ this project if you find it useful!** + +--- + +Did you really read all the way down here? WOW! You deserve a 🍾 🥂! 
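The oai.py hunk below swaps the long extension-listing regex from the previous patch for a short path pattern plus an `ALLOWED_FILE_EXTENSIONS` whitelist check. A minimal, self-contained sketch of that detection logic is shown here for illustration only — the extension set is trimmed to a few entries, and the real implementation in oai.py also handles file size limits and MIME types:

```python
import os
import re

# Same pattern as in the patch: a path-like token (starts with ~, / or .)
# or a bare "name.ext" token; the leading (?:^|\s) avoids matching user@host.com.
FILE_PATTERN = r'(?:^|\s)@([~/\.][\S]+|[\w][\w-]*\.[\w]+)(?=\s|$)'

# Illustrative subset of the full ALLOWED_FILE_EXTENSIONS set from the patch.
ALLOWED_EXTENSIONS = {'.py', '.txt', '.md', '.json', '.pdf', '.png'}

def extract_attachments(user_input: str) -> list[str]:
    """Return @-references that look like real files, skipping domains and e-mails."""
    paths = []
    for match in re.finditer(FILE_PATTERN, user_input, re.IGNORECASE):
        candidate = match.group(1)
        ext = os.path.splitext(candidate)[1].lower()
        # Bare tokens (no path prefix) must carry a whitelisted extension,
        # otherwise they are treated as domain names and ignored.
        if not candidate.startswith(('/', '~', '.', '\\')) and ext not in ALLOWED_EXTENSIONS:
            continue
        paths.append(candidate)
    return paths

print(extract_attachments("Debug @./script.py and summarize @report.pdf"))
# ['./script.py', 'report.pdf']
print(extract_attachments("Contact user@domain.com or visit @example.com"))
# []
```

Path-prefixed tokens (`@~/file.txt`, `@./script.py`) are always treated as files, while bare tokens without a whitelisted extension (such as `@example.com`) are dropped — the behaviour described in the patch comments.
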
diff --git a/oai.py b/oai.py index d58266b..339c744 100644 --- a/oai.py +++ b/oai.py @@ -4650,25 +4650,55 @@ def chat(): file_attachments = [] content_blocks = [] - # Smart file detection with extension whitelist - # Only matches files with known extensions or clear path prefixes - # - # Valid matches: - # @/path/to/file - # @~/file.txt - # @./script.py - # @report.pdf - # - # Excludes: - # @domain.com (no whitelisted extension) - # @mail.server.eu (TLD not in whitelist) - # user@domain.com (email pattern) - # Diagnostic-Code: @server.eu (TLD not in whitelist) + # Smart file detection: Simple pattern + extension validation + # This avoids extremely long regex that can cause binary signing issues - file_pattern = r'(?:^|\s)@((?:[~/][\S]+)|(?:\.[\S]+)|(?:[\w][\w-]*\.(?:py|txt|md|log|json|csv|pdf|png|jpg|jpeg|gif|bmp|webp|svg|ico|zip|tar|gz|bz2|7z|rar|xz|js|ts|jsx|tsx|vue|html|css|scss|sass|less|xml|yaml|yml|toml|ini|conf|cfg|env|properties|sh|bash|zsh|fish|bat|cmd|ps1|c|cpp|cc|cxx|h|hpp|hxx|java|class|jar|war|rb|go|rs|swift|kt|kts|php|sql|db|sqlite|sqlite3|lock|gitignore|dockerignore|editorconfig|eslintrc|prettierrc|babelrc|nvmrc|npmrc|pyc|pyo|pyd|so|dll|dylib|exe|app|dmg|pkg|deb|rpm|apk|ipa|wasm|proto|graphql|graphqls|grpc|avro|parquet|orc|feather|arrow|hdf5|h5|mat|r|rdata|rds|pkl|pickle|joblib|npy|npz|safetensors|onnx|pt|pth|ckpt|pb|tflite|mlmodel|coreml|rknn)))(?=\s|$)' + # Common file extensions we support + ALLOWED_FILE_EXTENSIONS = { + # Code + '.py', '.js', '.ts', '.jsx', '.tsx', '.vue', '.java', '.c', '.cpp', '.cc', '.cxx', + '.h', '.hpp', '.hxx', '.rb', '.go', '.rs', '.swift', '.kt', '.kts', '.php', + '.sh', '.bash', '.zsh', '.fish', '.bat', '.cmd', '.ps1', + # Data + '.json', '.csv', '.yaml', '.yml', '.toml', '.xml', '.sql', '.db', '.sqlite', '.sqlite3', + # Documents + '.txt', '.md', '.log', '.conf', '.cfg', '.ini', '.env', '.properties', + # Images + '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.svg', '.ico', + # Archives + '.zip', '.tar', '.gz', '.bz2', '.7z', '.rar', '.xz', + # Config files + '.lock', '.gitignore', '.dockerignore', '.editorconfig', '.eslintrc', + '.prettierrc', '.babelrc', '.nvmrc', '.npmrc', + # Binary/Compiled + '.pyc', '.pyo', '.pyd', '.so', '.dll', '.dylib', '.exe', '.app', + '.dmg', '.pkg', '.deb', '.rpm', '.apk', '.ipa', + # ML/AI + '.pkl', '.pickle', '.joblib', '.npy', '.npz', '.safetensors', '.onnx', + '.pt', '.pth', '.ckpt', '.pb', '.tflite', '.mlmodel', '.coreml', '.rknn', + # Data formats + '.wasm', '.proto', '.graphql', '.graphqls', '.grpc', '.avro', '.parquet', + '.orc', '.feather', '.arrow', '.hdf5', '.h5', '.mat', '.r', '.rdata', '.rds', + # Other + '.pdf', '.class', '.jar', '.war' + } + + # Simple pattern: @filepath where filepath starts with ~, /, . 
or has an extension + # Much shorter than listing all extensions in the regex + file_pattern = r'(?:^|\s)@([~/\.][\S]+|[\w][\w-]*\.[\w]+)(?=\s|$)' for match in re.finditer(file_pattern, user_input, re.IGNORECASE): file_path = match.group(1) + + # Get the extension + file_ext = os.path.splitext(file_path)[1].lower() + + # Skip if it doesn't start with a path char and has no allowed extension + if not file_path.startswith(('/', '~', '.', '\\')): + if not file_ext or file_ext not in ALLOWED_FILE_EXTENSIONS: + # This looks like a domain name, not a file + continue + expanded_path = os.path.expanduser(os.path.abspath(file_path)) if not os.path.exists(expanded_path) or os.path.isdir(expanded_path): @@ -4681,7 +4711,6 @@ def chat(): continue mime_type, _ = mimetypes.guess_type(expanded_path) - file_ext = os.path.splitext(expanded_path)[1].lower() try: with open(expanded_path, 'rb') as f: @@ -4729,7 +4758,7 @@ def chat(): app_logger.error(f"File read error for {expanded_path}: {e}") continue - # Remove file attachments from text (use same pattern as file detection) + # Remove file attachments from text (use same simple pattern) text_part = re.sub(file_pattern, lambda m: m.group(0)[0] if m.group(0)[0].isspace() else '', text_part).strip() # Build message content -- 2.49.1 From 106ef676e25376f1b210bbf4545e9ea110dcbdbb Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Thu, 15 Jan 2026 10:13:49 +0100 Subject: [PATCH 07/10] Updated to RC1 --- README.md | 88 +++++- oai.py | 781 ++++++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 831 insertions(+), 38 deletions(-) diff --git a/README.md b/README.md index c07b3af..64f3220 100644 --- a/README.md +++ b/README.md @@ -28,6 +28,14 @@ oAI is a feature-rich command-line chat application that provides an interactive - Supports code files, text, JSON, YAML, and more - Large file handling (auto-truncates >50KB) +- ✍️ **Write Mode** (NEW!): AI can modify files with your permission + - Create and edit files within allowed folders + - Delete files (always requires confirmation) + - Move, copy, and organize files + - Create directories + - Ignores .gitignore for write operations + - OFF by default - explicit opt-in required + - 🗄️ **Database Mode**: AI can query your SQLite databases - Read-only access (no data modification possible) - Schema inspection (tables, columns, indexes) @@ -39,6 +47,8 @@ oAI is a feature-rich command-line chat application that provides an interactive - 🔒 **Security Features**: - Explicit folder/database approval required - System directory blocking + - Write mode OFF by default (non-persistent) + - Delete operations always require user confirmation - Read-only database access - SQL injection protection - Query timeout (5 seconds) @@ -133,13 +143,18 @@ oai You> /model # Enable MCP for file access -You> /mcp enable +You> /mcp on You> /mcp add ~/Documents -# Ask AI to help with files +# Ask AI to help with files (read-only) [🔧 MCP: Files] You> List all Python files in Documents [🔧 MCP: Files] You> Read and explain main.py +# Enable write mode to let AI modify files +You> /mcp write on +[🔧✍️ MCP: Files+Write] You> Create a new Python file with helper functions +[🔧✍️ MCP: Files+Write] You> Refactor main.py to use async/await + # Switch to database mode You> /mcp add db ~/myapp/data.db You> /mcp db 1 @@ -153,7 +168,7 @@ You> /mcp db 1 **Setup:** ```bash -/mcp enable # Start MCP server +/mcp on # Start MCP server /mcp add ~/Projects # Grant access to folder /mcp add ~/Documents # Add another folder /mcp list # View all allowed 
folders @@ -167,16 +182,26 @@ You> /mcp db 1 "What's in my Documents folder?" ``` -**Available Tools:** +**Available Tools (Read-Only):** - `read_file` - Read complete file contents - `list_directory` - List files/folders (recursive optional) - `search_files` - Search by name or content +**Available Tools (Write Mode - requires `/mcp write on`):** +- `write_file` - Create new files or overwrite existing ones +- `edit_file` - Find and replace text in existing files +- `delete_file` - Delete files (always requires confirmation) +- `create_directory` - Create directories +- `move_file` - Move or rename files +- `copy_file` - Copy files to new locations + **Features:** -- ✅ Automatic .gitignore filtering +- ✅ Automatic .gitignore filtering (read operations only) - ✅ Skips virtual environments (venv, node_modules) - ✅ Handles large files (auto-truncates >50KB) - ✅ Cross-platform (macOS, Linux, Windows via WSL) +- ✅ Write mode OFF by default for safety +- ✅ Delete operations require user confirmation with LLM's reason ### Database Mode @@ -210,13 +235,41 @@ You> /mcp db 1 - ✅ WHERE, GROUP BY, HAVING, ORDER BY, LIMIT - ❌ INSERT/UPDATE/DELETE (blocked for safety) +### Write Mode + +**Enable Write Mode:** +```bash +/mcp write on # Enable write mode (shows warning, requires confirmation) +``` + +**Natural Language Usage:** +``` +"Create a new Python file called utils.py with helper functions" +"Edit main.py and replace the old API endpoint with the new one" +"Delete the backup.old file" (will prompt for confirmation) +"Create a directory called tests" +"Move config.json to the config folder" +``` + +**Important:** +- ⚠️ Write mode is **OFF by default** and resets each session +- ⚠️ Delete operations **always** require user confirmation +- ⚠️ All operations are limited to allowed MCP folders +- ✅ Write operations ignore .gitignore (can write to any file in allowed folders) + +**Disable Write Mode:** +```bash +/mcp write off # Disable write mode (back to read-only) +``` + ### Mode Management ```bash -/mcp status # Show current mode, stats, folders/databases +/mcp status # Show current mode, write mode, stats, folders/databases /mcp files # Switch to file mode /mcp db # Switch to database mode /mcp gitignore on # Enable .gitignore filtering (default) +/mcp write on|off # Enable/disable write mode /mcp remove 2 # Remove folder/database by number ``` @@ -238,9 +291,9 @@ You> /mcp db 1 ### MCP Commands ``` -/mcp enable Start MCP server -/mcp disable Stop MCP server -/mcp status Show comprehensive status +/mcp on Start MCP server +/mcp off Stop MCP server +/mcp status Show comprehensive status (includes write mode) /mcp add Add folder for file access /mcp add db Add SQLite database /mcp list List all folders @@ -249,6 +302,8 @@ You> /mcp db 1 /mcp files Switch to file mode /mcp remove Remove folder/database /mcp gitignore on Enable .gitignore filtering +/mcp write on Enable write mode (create/edit/delete files) +/mcp write off Disable write mode (read-only) ``` ### Model Commands @@ -356,7 +411,11 @@ PDF (models with document support) - ✅ Database opened in `mode=ro` ### File System Safety -- ✅ Read-only access (no write/delete) +- ✅ Read-only by default (write mode requires explicit opt-in) +- ✅ Write mode OFF by default each session (non-persistent) +- ✅ Delete operations always require user confirmation +- ✅ Write operations limited to allowed folders only +- ✅ System directories blocked - ✅ Virtual environment exclusion - ✅ Build artifact filtering - ✅ Maximum file size (10 MB) @@ -444,14 +503,19 @@ ls 
-la database.db ## Version History -### v2.1.0-beta (Current) +### v2.1.0-RC1 (Current) - ✨ **NEW**: MCP (Model Context Protocol) integration - ✨ **NEW**: File system access (read, search, list) +- ✨ **NEW**: Write mode - AI can create, edit, and delete files + - 6 write tools: write_file, edit_file, delete_file, create_directory, move_file, copy_file + - OFF by default - requires explicit `/mcp write on` activation + - Delete operations always require user confirmation + - Non-persistent setting (resets each session) - ✨ **NEW**: SQLite database querying (read-only) - ✨ **NEW**: Dual mode support (Files & Database) - ✨ **NEW**: .gitignore filtering - ✨ **NEW**: Binary data handling in databases -- ✨ **NEW**: Mode indicators in prompt +- ✨ **NEW**: Mode indicators in prompt (shows ✍️ when write mode active) - ✨ **NEW**: Comprehensive `/help mcp` guide - 🔧 Improved error handling for tool calls - 🔧 Enhanced logging for MCP operations diff --git a/oai.py b/oai.py index 339c744..746f759 100644 --- a/oai.py +++ b/oai.py @@ -27,6 +27,9 @@ from prompt_toolkit import PromptSession from prompt_toolkit.history import FileHistory from rich.logging import RichHandler from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.key_binding import KeyBindings +from prompt_toolkit.filters import Condition +from prompt_toolkit.application.current import get_app from packaging import version as pkg_version import io import platform @@ -45,7 +48,7 @@ except ImportError: print("Warning: MCP library not found. Install with: pip install mcp") # App version -version = '2.1.0-beta' +version = '2.1.0-RC1' app = typer.Typer() @@ -112,7 +115,7 @@ MCP (Model Context Protocol) gives your AI assistant direct access to: ╰─────────────────────────────────────────────────────────╯ SETUP: - /mcp enable Start MCP server + /mcp on Start MCP server /mcp add ~/Documents Grant access to folder /mcp add ~/Code/project Add another folder /mcp list View all allowed folders @@ -135,6 +138,41 @@ MANAGEMENT: /mcp gitignore on|off Toggle .gitignore filtering /mcp status Show comprehensive status +╭─────────────────────────────────────────────────────────╮ +│ ✍️ WRITE MODE (OPTIONAL) │ +╰─────────────────────────────────────────────────────────╯ + +ENABLE WRITE MODE: + /mcp write on Enable file modifications (requires confirmation) + /mcp write off Disable write mode (back to read-only) + +WHAT WRITE MODE ALLOWS: + • Create new files or overwrite existing ones + • Edit files (find and replace text) + • Delete files (always requires confirmation) + • Create directories + • Move and copy files + +USAGE (after enabling write mode): + "Create a new Python file called utils.py" + "Edit main.py and update the API endpoint" + "Delete the old backup.txt file" + "Create a tests directory" + "Move config.json to the configs folder" + +SAFETY FEATURES: + ✓ OFF by default (resets each session) + ✓ Requires explicit activation with user confirmation + ✓ Delete operations ALWAYS require user confirmation + ✓ Limited to allowed MCP folders only + ✓ Ignores .gitignore (can write to any file in allowed folders) + ✓ System directories remain blocked + +IMPORTANT: + ⚠️ Write mode is powerful - use with caution! 
+ ⚠️ Always review what the AI plans to modify + ⚠️ Deletions will show file details + AI's reason for confirmation + ╭─────────────────────────────────────────────────────────╮ │ 🗄️ DATABASE MODE │ ╰─────────────────────────────────────────────────────────╯ @@ -176,15 +214,18 @@ SAFETY: ╰─────────────────────────────────────────────────────────╯ MODE INDICATORS: - [🔧 MCP: Files] You're in file mode - [🗄️ MCP: DB #1] You're querying database #1 + [🔧 MCP: Files] You're in file mode (read-only) + [🔧✍️ MCP: Files+Write] You're in file mode with write permissions + [🗄️ MCP: DB #1] You're querying database #1 QUICK REFERENCE: - /mcp status See current mode, stats, folders/databases + /mcp status See current mode, write mode, stats, folders/databases /mcp files Switch to file mode (default) /mcp db Switch to database mode /mcp db list List all databases with details /mcp list List all folders + /mcp write on Enable write mode (create/edit/delete files) + /mcp write off Disable write mode (read-only) TROUBLESHOOTING: • No results? Check /mcp status to see what's accessible @@ -194,7 +235,9 @@ TROUBLESHOOTING: SECURITY NOTES: • MCP only accesses explicitly added folders/databases - • File mode: read-only access (no write/delete) + • File mode: read-only by default (write mode requires opt-in) + • Write mode: OFF by default, resets each session + • Delete operations: Always require user confirmation • Database mode: SELECT queries only (no modifications) • System directories are blocked automatically • Each addition requires your explicit confirmation @@ -206,8 +249,8 @@ For command-specific help: /help /mcp 'description': 'Manage MCP (Model Context Protocol) for local file access and SQLite database querying. When enabled with a function-calling model, AI can automatically search, read, list files, and query databases. Supports two modes: Files (default) and Database.', 'usage': '/mcp [args]', 'examples': [ - ('Enable MCP server', '/mcp enable'), - ('Disable MCP server', '/mcp disable'), + ('Enable MCP server', '/mcp on'), + ('Disable MCP server', '/mcp off'), ('Show MCP status and current mode', '/mcp status'), ('', ''), ('━━━ FILE MODE ━━━', ''), @@ -216,6 +259,8 @@ For command-specific help: /help /mcp ('Remove folder by number', '/mcp remove 2'), ('List allowed folders', '/mcp list'), ('Toggle .gitignore filtering', '/mcp gitignore on'), + ('Enable write mode', '/mcp write on'), + ('Disable write mode', '/mcp write off'), ('', ''), ('━━━ DATABASE MODE ━━━', ''), ('Add SQLite database', '/mcp add db ~/app/data.db'), @@ -227,10 +272,18 @@ For command-specific help: /help /mcp 'notes': '''MCP allows AI to read local files and query SQLite databases. Works automatically with function-calling models (GPT-4, Claude, etc.). 
FILE MODE (default): +- Read-only by default (write mode requires opt-in) - Automatically loads and respects .gitignore patterns - Skips virtual environments and build artifacts - Supports search, read, and list operations +WRITE MODE (optional): +- Enable with /mcp write on (requires confirmation) +- Allows creating, editing, and deleting files +- OFF by default, resets each session +- Delete operations always require user confirmation +- Limited to allowed folders only + DATABASE MODE: - Read-only access (no data modification) - Execute SELECT queries with JOINs, subqueries, CTEs @@ -1354,6 +1407,396 @@ class MCPFilesystemServer: log_mcp_stat('search_files', search_path or "all", False, str(e)) return {'error': error_msg} + # ======================================================================== + # FILE WRITE METHODS + # ======================================================================== + + async def write_file(self, file_path: str, content: str) -> Dict[str, Any]: + """Create a new file or overwrite an existing file with content. + + Note: Ignores .gitignore - allows writing to any file in allowed folders. + """ + try: + path = self.config.normalize_path(file_path) + + # Permission check: must be in allowed folders + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + app_logger.warning(error_msg) + log_mcp_stat('write_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + # Check if parent directory exists, create if needed + parent_dir = path.parent + if not parent_dir.exists(): + parent_dir.mkdir(parents=True, exist_ok=True) + app_logger.info(f"Created parent directory: {parent_dir}") + + # Determine if creating new file or overwriting + is_new_file = not path.exists() + + # Write content to file + path.write_text(content, encoding='utf-8') + file_size = path.stat().st_size + + app_logger.info(f"{'Created' if is_new_file else 'Updated'} file: {path} ({file_size} bytes)") + log_mcp_stat('write_file', str(path.parent), True) + + return { + 'success': True, + 'path': str(path), + 'size': file_size, + 'created': is_new_file + } + + except PermissionError as e: + error_msg = f"Permission denied writing to {file_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('write_file', file_path, False, str(e)) + return {'error': error_msg, 'error_type': 'PermissionError'} + except UnicodeEncodeError as e: + error_msg = f"Encoding error writing to {file_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('write_file', file_path, False, str(e)) + return {'error': error_msg, 'error_type': 'EncodingError'} + except Exception as e: + error_msg = f"Error writing file {file_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('write_file', file_path, False, str(e)) + return {'error': error_msg} + + async def edit_file(self, file_path: str, old_text: str, new_text: str) -> Dict[str, Any]: + """Find and replace text in an existing file. + + Note: Ignores .gitignore - allows editing any file in allowed folders. 
+ """ + try: + path = self.config.normalize_path(file_path) + + # Permission check + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + app_logger.warning(error_msg) + log_mcp_stat('edit_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + # File must exist + if not path.exists(): + error_msg = f"File not found: {path}" + app_logger.warning(error_msg) + log_mcp_stat('edit_file', str(path), False, error_msg) + return {'error': error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + app_logger.warning(error_msg) + log_mcp_stat('edit_file', str(path), False, error_msg) + return {'error': error_msg} + + # Read current content + current_content = path.read_text(encoding='utf-8') + + # Check if old_text exists in file + if old_text not in current_content: + error_msg = f"Text not found in file: '{old_text[:50]}...'" + app_logger.warning(f"Edit failed - text not found in {path}") + log_mcp_stat('edit_file', str(path), False, error_msg) + return {'error': error_msg} + + # Count occurrences + occurrence_count = current_content.count(old_text) + + if occurrence_count > 1: + error_msg = f"Text appears {occurrence_count} times in file. Please provide more context to make the match unique." + app_logger.warning(f"Edit failed - ambiguous match in {path}") + log_mcp_stat('edit_file', str(path), False, error_msg) + return {'error': error_msg} + + # Perform replacement + new_content = current_content.replace(old_text, new_text, 1) + path.write_text(new_content, encoding='utf-8') + + app_logger.info(f"Edited file: {path} (1 replacement)") + log_mcp_stat('edit_file', str(path.parent), True) + + return { + 'success': True, + 'path': str(path), + 'changes': 1 + } + + except PermissionError as e: + error_msg = f"Permission denied editing {file_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('edit_file', file_path, False, str(e)) + return {'error': error_msg, 'error_type': 'PermissionError'} + except UnicodeDecodeError as e: + error_msg = f"Cannot read file (encoding issue): {file_path}" + app_logger.error(error_msg) + log_mcp_stat('edit_file', file_path, False, str(e)) + return {'error': error_msg, 'error_type': 'EncodingError'} + except Exception as e: + error_msg = f"Error editing file {file_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('edit_file', file_path, False, str(e)) + return {'error': error_msg} + + async def delete_file(self, file_path: str, reason: str) -> Dict[str, Any]: + """Delete a file with user confirmation. + + Always requires user to confirm deletion with the LLM's provided reason. + Note: Ignores .gitignore - allows deleting any file in allowed folders. 
+ """ + try: + path = self.config.normalize_path(file_path) + + # Permission check + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + app_logger.warning(error_msg) + log_mcp_stat('delete_file', str(path.parent), False, error_msg) + return {'error': error_msg} + + # File must exist + if not path.exists(): + error_msg = f"File not found: {path}" + app_logger.warning(error_msg) + log_mcp_stat('delete_file', str(path), False, error_msg) + return {'error': error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + app_logger.warning(error_msg) + log_mcp_stat('delete_file', str(path), False, error_msg) + return {'error': error_msg} + + # Get file info for confirmation prompt + file_size = path.stat().st_size + file_mtime = datetime.datetime.fromtimestamp(path.stat().st_mtime) + + # User confirmation required + console.print(Panel.fit( + f"[bold yellow]⚠️ DELETE FILE?[/]\n\n" + f"Path: [cyan]{path}[/]\n" + f"Size: {file_size:,} bytes\n" + f"Modified: {file_mtime.strftime('%Y-%m-%d %H:%M:%S')}\n\n" + f"[bold]LLM Reason:[/] {reason}", + title="Delete File Confirmation", + border_style="yellow" + )) + + confirm = typer.confirm("Delete this file?", default=False) + + if not confirm: + app_logger.info(f"User cancelled file deletion: {path}") + log_mcp_stat('delete_file', str(path), False, "User cancelled") + return { + 'success': False, + 'user_cancelled': True, + 'path': str(path) + } + + # Delete the file + path.unlink() + + app_logger.info(f"Deleted file: {path}") + log_mcp_stat('delete_file', str(path.parent), True) + + return { + 'success': True, + 'path': str(path), + 'user_cancelled': False + } + + except PermissionError as e: + error_msg = f"Permission denied deleting {file_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('delete_file', file_path, False, str(e)) + return {'error': error_msg, 'error_type': 'PermissionError'} + except Exception as e: + error_msg = f"Error deleting file {file_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('delete_file', file_path, False, str(e)) + return {'error': error_msg} + + async def create_directory(self, dir_path: str) -> Dict[str, Any]: + """Create a directory (and parent directories if needed). + + Note: Ignores .gitignore - allows creating directories anywhere in allowed folders. 
+ """ + try: + path = self.config.normalize_path(dir_path) + + # Permission check + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + app_logger.warning(error_msg) + log_mcp_stat('create_directory', str(path.parent), False, error_msg) + return {'error': error_msg} + + # Check if already exists + already_exists = path.exists() + + if already_exists and not path.is_dir(): + error_msg = f"Path exists but is not a directory: {path}" + app_logger.warning(error_msg) + log_mcp_stat('create_directory', str(path), False, error_msg) + return {'error': error_msg} + + # Create directory (and parents) + path.mkdir(parents=True, exist_ok=True) + + if already_exists: + app_logger.info(f"Directory already exists: {path}") + else: + app_logger.info(f"Created directory: {path}") + + log_mcp_stat('create_directory', str(path.parent), True) + + return { + 'success': True, + 'path': str(path), + 'created': not already_exists + } + + except PermissionError as e: + error_msg = f"Permission denied creating directory {dir_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('create_directory', dir_path, False, str(e)) + return {'error': error_msg, 'error_type': 'PermissionError'} + except Exception as e: + error_msg = f"Error creating directory {dir_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('create_directory', dir_path, False, str(e)) + return {'error': error_msg} + + async def move_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: + """Move or rename a file. + + Note: Ignores .gitignore - allows moving any file within allowed folders. + """ + try: + source = self.config.normalize_path(source_path) + dest = self.config.normalize_path(dest_path) + + # Both paths must be in allowed folders + if not self.is_allowed_path(source): + error_msg = f"Access denied: source {source} is not in allowed MCP folders" + app_logger.warning(error_msg) + log_mcp_stat('move_file', str(source.parent), False, error_msg) + return {'error': error_msg} + + if not self.is_allowed_path(dest): + error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" + app_logger.warning(error_msg) + log_mcp_stat('move_file', str(dest.parent), False, error_msg) + return {'error': error_msg} + + # Source must exist + if not source.exists(): + error_msg = f"Source file not found: {source}" + app_logger.warning(error_msg) + log_mcp_stat('move_file', str(source), False, error_msg) + return {'error': error_msg} + + if not source.is_file(): + error_msg = f"Source is not a file: {source}" + app_logger.warning(error_msg) + log_mcp_stat('move_file', str(source), False, error_msg) + return {'error': error_msg} + + # Create destination parent directory if needed + dest.parent.mkdir(parents=True, exist_ok=True) + + # Move/rename the file + import shutil + shutil.move(str(source), str(dest)) + + app_logger.info(f"Moved file: {source} -> {dest}") + log_mcp_stat('move_file', str(source.parent), True) + + return { + 'success': True, + 'source': str(source), + 'destination': str(dest) + } + + except PermissionError as e: + error_msg = f"Permission denied moving file: {e}" + app_logger.error(error_msg) + log_mcp_stat('move_file', source_path, False, str(e)) + return {'error': error_msg, 'error_type': 'PermissionError'} + except Exception as e: + error_msg = f"Error moving file from {source_path} to {dest_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('move_file', source_path, False, str(e)) + return {'error': error_msg} + + async def copy_file(self, 
source_path: str, dest_path: str) -> Dict[str, Any]: + """Copy a file to a new location. + + Note: Ignores .gitignore - allows copying any file within allowed folders. + """ + try: + source = self.config.normalize_path(source_path) + dest = self.config.normalize_path(dest_path) + + # Both paths must be in allowed folders + if not self.is_allowed_path(source): + error_msg = f"Access denied: source {source} is not in allowed MCP folders" + app_logger.warning(error_msg) + log_mcp_stat('copy_file', str(source.parent), False, error_msg) + return {'error': error_msg} + + if not self.is_allowed_path(dest): + error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" + app_logger.warning(error_msg) + log_mcp_stat('copy_file', str(dest.parent), False, error_msg) + return {'error': error_msg} + + # Source must exist + if not source.exists(): + error_msg = f"Source file not found: {source}" + app_logger.warning(error_msg) + log_mcp_stat('copy_file', str(source), False, error_msg) + return {'error': error_msg} + + if not source.is_file(): + error_msg = f"Source is not a file: {source}" + app_logger.warning(error_msg) + log_mcp_stat('copy_file', str(source), False, error_msg) + return {'error': error_msg} + + # Create destination parent directory if needed + dest.parent.mkdir(parents=True, exist_ok=True) + + # Copy the file + import shutil + shutil.copy2(str(source), str(dest)) + + file_size = dest.stat().st_size + + app_logger.info(f"Copied file: {source} -> {dest} ({file_size} bytes)") + log_mcp_stat('copy_file', str(source.parent), True) + + return { + 'success': True, + 'source': str(source), + 'destination': str(dest), + 'size': file_size + } + + except PermissionError as e: + error_msg = f"Permission denied copying file: {e}" + app_logger.error(error_msg) + log_mcp_stat('copy_file', source_path, False, str(e)) + return {'error': error_msg, 'error_type': 'PermissionError'} + except Exception as e: + error_msg = f"Error copying file from {source_path} to {dest_path}: {e}" + app_logger.error(error_msg) + log_mcp_stat('copy_file', source_path, False, str(e)) + return {'error': error_msg} + # ======================================================================== # SQLite DATABASE METHODS # ======================================================================== @@ -1710,6 +2153,7 @@ class MCPManager: def __init__(self): self.enabled = False + self.write_enabled = False # Write mode off by default (non-persistent) self.mode = "files" # "files" or "database" self.selected_db_index = None @@ -1864,12 +2308,57 @@ class MCPManager: 'error': str(e) } + def enable_write(self) -> None: + """Enable write mode with user confirmation.""" + if not self.enabled: + console.print("[bold red]Error:[/] MCP must be enabled first. Use '/mcp on'") + app_logger.warning("Attempted to enable write mode without MCP enabled") + return + + # Show warning panel + console.print(Panel.fit( + "[bold yellow]⚠️ WRITE MODE WARNING[/]\n\n" + "Enabling write mode allows the LLM to:\n" + "• Create new files\n" + "• Modify existing files\n" + "• Delete files (with confirmation)\n" + "• Create directories\n" + "• Move and copy files\n\n" + "[dim]Operations are limited to your allowed MCP folders.[/]\n" + "[dim]Deletions will always require your confirmation.[/]\n\n" + "[bold red]Use with caution. 
Review LLM changes carefully.[/]", + title="MCP Write Mode", + border_style="yellow" + )) + + confirm = typer.confirm("Enable write mode?", default=False) + + if confirm: + self.write_enabled = True + console.print("[bold green]✓ Write mode enabled[/]") + app_logger.info("MCP write mode enabled by user") + log_mcp_stat('write_mode_enabled', '', True) + else: + console.print("[yellow]Cancelled. Write mode remains disabled.[/]") + app_logger.info("User cancelled write mode enablement") + + def disable_write(self) -> None: + """Disable write mode.""" + if not self.write_enabled: + console.print("[dim]Write mode is already disabled.[/]") + return + + self.write_enabled = False + console.print("[bold green]✓ Write mode disabled[/]") + app_logger.info("MCP write mode disabled") + log_mcp_stat('write_mode_disabled', '', True) + def switch_mode(self, new_mode: str, db_index: Optional[int] = None) -> Dict[str, Any]: """Switch between files and database mode.""" if not self.enabled: return { 'success': False, - 'error': 'MCP is not enabled. Use /mcp enable first' + 'error': 'MCP is not enabled. Use /mcp on first' } if new_mode == "files": @@ -1918,7 +2407,7 @@ class MCPManager: if not self.enabled: return { 'success': False, - 'error': 'MCP is not enabled. Use /mcp enable first' + 'error': 'MCP is not enabled. Use /mcp on first' } try: @@ -1998,7 +2487,7 @@ class MCPManager: if not self.enabled: return { 'success': False, - 'error': 'MCP is not enabled. Use /mcp enable first' + 'error': 'MCP is not enabled. Use /mcp on first' } try: @@ -2055,7 +2544,7 @@ class MCPManager: if not self.enabled: return { 'success': False, - 'error': 'MCP is not enabled. Use /mcp enable first' + 'error': 'MCP is not enabled. Use /mcp on first' } try: @@ -2320,6 +2809,7 @@ class MCPManager: return { 'success': True, 'enabled': self.enabled, + 'write_enabled': self.write_enabled, 'uptime': uptime, 'mode_info': mode_info, 'folder_count': len(self.allowed_folders), @@ -2358,8 +2848,15 @@ class MCPManager: 'error': 'MCP is not enabled' } + # Check write permissions for write operations + write_tools = {'write_file', 'edit_file', 'delete_file', 'create_directory', 'move_file', 'copy_file'} + if tool_name in write_tools and not self.write_enabled: + return { + 'error': 'Write operations disabled. 
Enable with /mcp write on' + } + try: - # File mode tools + # File mode tools (read-only) if tool_name == 'read_file': return await self.server.read_file(kwargs.get('file_path', '')) elif tool_name == 'list_directory': @@ -2402,6 +2899,38 @@ class MCPManager: kwargs.get('query', ''), kwargs.get('limit') ) + + # Write mode tools + elif tool_name == 'write_file': + return await self.server.write_file( + kwargs.get('file_path', ''), + kwargs.get('content', '') + ) + elif tool_name == 'edit_file': + return await self.server.edit_file( + kwargs.get('file_path', ''), + kwargs.get('old_text', ''), + kwargs.get('new_text', '') + ) + elif tool_name == 'delete_file': + return await self.server.delete_file( + kwargs.get('file_path', ''), + kwargs.get('reason', 'No reason provided') + ) + elif tool_name == 'create_directory': + return await self.server.create_directory( + kwargs.get('dir_path', '') + ) + elif tool_name == 'move_file': + return await self.server.move_file( + kwargs.get('source_path', ''), + kwargs.get('dest_path', '') + ) + elif tool_name == 'copy_file': + return await self.server.copy_file( + kwargs.get('source_path', ''), + kwargs.get('dest_path', '') + ) else: return { 'error': f'Unknown tool: {tool_name}' @@ -2431,7 +2960,8 @@ class MCPManager: allowed_dirs_str = ", ".join(str(f) for f in self.allowed_folders) - return [ + # Base read-only tools + tools = [ { "type": "function", "function": { @@ -2498,6 +3028,139 @@ class MCPManager: } } ] + + # Add write tools if write mode is enabled + if self.write_enabled: + tools.extend([ + { + "type": "function", + "function": { + "name": "write_file", + "description": f"Create a new file or overwrite an existing file with the specified content. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns - can write to any file in allowed folders. Automatically creates parent directories if needed.", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to create or overwrite (e.g., /Users/username/project/src/main.py)" + }, + "content": { + "type": "string", + "description": "The complete content to write to the file" + } + }, + "required": ["file_path", "content"] + } + } + }, + { + "type": "function", + "function": { + "name": "edit_file", + "description": f"Make targeted edits to an existing file by finding and replacing specific text. The old_text must match exactly and appear only once in the file. For multiple matches, provide more context to make the match unique. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns.", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to edit" + }, + "old_text": { + "type": "string", + "description": "The exact text to find and replace. Must match exactly and appear only once. Include enough context to make the match unique." + }, + "new_text": { + "type": "string", + "description": "The new text to replace the old text with" + } + }, + "required": ["file_path", "old_text", "new_text"] + } + } + }, + { + "type": "function", + "function": { + "name": "delete_file", + "description": f"Delete a file from the filesystem. ALWAYS requires user confirmation before deletion. The user will see the file path, size, modification time, and your reason before deciding. Works within allowed directories: {allowed_dirs_str}. 
Ignores .gitignore patterns.", + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to delete" + }, + "reason": { + "type": "string", + "description": "Clear explanation for why this file should be deleted. The user will see this reason in the confirmation prompt." + } + }, + "required": ["file_path", "reason"] + } + } + }, + { + "type": "function", + "function": { + "name": "create_directory", + "description": f"Create a new directory (and all parent directories if needed). If the directory already exists, returns success without error. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns.", + "parameters": { + "type": "object", + "properties": { + "dir_path": { + "type": "string", + "description": "Full path to the directory to create (e.g., /Users/username/project/src/components)" + } + }, + "required": ["dir_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "move_file", + "description": f"Move or rename a file. Both source and destination must be within allowed directories: {allowed_dirs_str}. Creates destination parent directories if needed. Ignores .gitignore patterns.", + "parameters": { + "type": "object", + "properties": { + "source_path": { + "type": "string", + "description": "Full path to the file to move/rename" + }, + "dest_path": { + "type": "string", + "description": "Full path for the new location/name" + } + }, + "required": ["source_path", "dest_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "copy_file", + "description": f"Copy a file to a new location. Both source and destination must be within allowed directories: {allowed_dirs_str}. Creates destination parent directories if needed. Preserves file metadata. Ignores .gitignore patterns.", + "parameters": { + "type": "object", + "properties": { + "source_path": { + "type": "string", + "description": "Full path to the file to copy" + }, + "dest_path": { + "type": "string", + "description": "Full path for the copy destination" + } + }, + "required": ["source_path", "dest_path"] + } + } + } + ]) + + return tools def _get_database_tools_schema(self) -> List[Dict[str, Any]]: """Get database mode tools schema.""" @@ -2951,6 +3614,7 @@ def display_paginated_table(table: Table, title: str): header_lines = [] data_lines = [] + footer_line = [] header_end_index = 0 found_header_text = False @@ -2970,6 +3634,14 @@ def display_paginated_table(table: Table, title: str): header_end_index = i break + # Extract footer (bottom border) - it's the last line of the table + if all_lines: + last_line_text = ''.join(seg.text for seg in all_lines[-1]) + # Check if last line is a border line (contains box drawing characters) + if any(char in last_line_text for char in ['─', '━', '┴', '╧', '┘', '└']): + footer_line = all_lines[-1] + all_lines = all_lines[:-1] # Remove footer from all_lines + if header_end_index > 0: header_lines = all_lines[:header_end_index + 1] data_lines = all_lines[header_end_index + 1:] @@ -2999,6 +3671,12 @@ def display_paginated_table(table: Table, title: str): console.print(segment.text, style=segment.style, end="") console.print() + # Add footer (bottom border) on each page + if footer_line: + for segment in footer_line: + console.print(segment.text, style=segment.style, end="") + console.print() + current_line = end_line page_number += 1 @@ -3089,7 +3767,21 @@ def chat(): if not selected_model: console.print("[bold yellow]No model selected. 
Use '/model'.[/]") - session = PromptSession(history=FileHistory(str(history_file))) + # Create custom key bindings to add tab support for autocomplete + kb = KeyBindings() + + @kb.add('tab') + def _(event): + """Accept suggestion with tab key.""" + b = event.current_buffer + suggestion = b.suggestion + if suggestion and b.document.is_cursor_at_the_end: + b.insert_text(suggestion.text) + + session = PromptSession( + history=FileHistory(str(history_file)), + key_bindings=kb + ) while True: try: @@ -3099,7 +3791,10 @@ def chat(): prompt_prefix = "You> " if mcp_manager.enabled: if mcp_manager.mode == "files": - prompt_prefix = "[🔧 MCP: Files] You> " + if mcp_manager.write_enabled: + prompt_prefix = "[🔧✍️ MCP: Files+Write] You> " + else: + prompt_prefix = "[🔧 MCP: Files] You> " elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: db = mcp_manager.databases[mcp_manager.selected_db_index] prompt_prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> " @@ -3141,7 +3836,7 @@ def chat(): mcp_command = parts[0].lower() if parts else "" mcp_args = parts[1] if len(parts) > 1 else "" - if mcp_command == "enable": + if mcp_command in ["enable", "on"]: result = mcp_manager.enable() if result['success']: console.print(f"[bold green]✓ {result['message']}[/]") @@ -3162,7 +3857,7 @@ def chat(): console.print(f"[bold red]❌ {result['error']}[/]") continue - elif mcp_command == "disable": + elif mcp_command in ["disable", "off"]: result = mcp_manager.disable() if result['success']: console.print(f"[bold green]✓ {result['message']}[/]") @@ -3472,7 +4167,7 @@ def chat(): console.print(f"[bold blue].gitignore filtering: {status}[/]") console.print(f"[dim cyan]Loaded {pattern_count} pattern(s) from .gitignore files[/]") else: - console.print("[bold yellow]MCP is not enabled. Enable with /mcp enable[/]") + console.print("[bold yellow]MCP is not enabled. Enable with /mcp on[/]") continue if mcp_args.lower() == "on": @@ -3493,6 +4188,28 @@ def chat(): console.print("[bold yellow]Usage: /mcp gitignore on|off[/]") continue + elif mcp_command == "write": + if not mcp_args: + # Show current write mode status + if mcp_manager.enabled: + status = "[bold green]Enabled ⚠️[/]" if mcp_manager.write_enabled else "[dim]Disabled[/]" + console.print(f"[bold blue]Write Mode:[/] {status}") + if mcp_manager.write_enabled: + console.print("[dim cyan]LLM can create, modify, and delete files in allowed folders[/]") + else: + console.print("[dim cyan]Use '/mcp write on' to enable write operations[/]") + else: + console.print("[bold yellow]MCP is not enabled. 
Enable with /mcp on[/]") + continue + + if mcp_args.lower() == "on": + mcp_manager.enable_write() + elif mcp_args.lower() == "off": + mcp_manager.disable_write() + else: + console.print("[bold yellow]Usage: /mcp write on|off[/]") + continue + elif mcp_command == "status": result = mcp_manager.get_status() @@ -3529,6 +4246,10 @@ def chat(): if result['gitignore_patterns'] > 0: table.add_row(".gitignore Patterns", str(result['gitignore_patterns'])) + # Write mode status + write_status = "[green]Enabled ⚠️[/]" if result['write_enabled'] else "[dim]Disabled[/]" + table.add_row("Write Mode", write_status) + if result['enabled']: table.add_row("", "") table.add_row("[bold]Tools Available[/]", "") @@ -3581,8 +4302,8 @@ def chat(): else: console.print("[bold cyan]MCP (Model Context Protocol) Commands:[/]\n") - console.print(" [bold yellow]/mcp enable[/] - Start MCP server") - console.print(" [bold yellow]/mcp disable[/] - Stop MCP server") + console.print(" [bold yellow]/mcp on[/] - Start MCP server") + console.print(" [bold yellow]/mcp off[/] - Stop MCP server") console.print(" [bold yellow]/mcp status[/] - Show comprehensive MCP status") console.print("") console.print(" [bold cyan]FILE MODE (default):[/]") @@ -4390,14 +5111,14 @@ def chat(): "" ) help_table.add_row( - "/mcp enable", + "/mcp on", "Start MCP server for file and database access.", - "/mcp enable" + "/mcp on" ) help_table.add_row( - "/mcp disable", + "/mcp off", "Stop MCP server.", - "/mcp disable" + "/mcp off" ) help_table.add_row( "/mcp status", @@ -4449,6 +5170,11 @@ def chat(): "Toggle .gitignore filtering (respects project excludes).", "/mcp gitignore on" ) + help_table.add_row( + "/mcp write [on|off]", + "Enable/disable write mode (create, edit, delete files). Requires confirmation.", + "/mcp write on" + ) # MODEL COMMANDS help_table.add_row( @@ -5111,11 +5837,14 @@ All queries are read-only. 
INSERT/UPDATE/DELETE are not allowed.""" continue # Now it's safe to continue else: full_response = response.choices[0].message.content if response.choices else "" - console.print(f"\r{' ' * 50}\r", end="") + # Clear any processing messages before showing response + console.print(f"\r{' ' * 100}\r", end="") if full_response: if not is_streaming: md = Markdown(full_response) + # Add newline before Panel to ensure clean rendering + console.print() console.print(Panel(md, title="[bold green]AI Response[/]", title_align="left", border_style="green")) session_history.append({'prompt': user_input, 'response': full_response}) -- 2.49.1 From 9a76ce2c1f22e6530dc295da63d008463ddfacb4 Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Mon, 26 Jan 2026 11:16:39 +0100 Subject: [PATCH 08/10] Fixed bug in cost display and handling --- oai.py | 82 +++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 73 insertions(+), 9 deletions(-) diff --git a/oai.py b/oai.py index 746f759..b7b38b8 100644 --- a/oai.py +++ b/oai.py @@ -4532,12 +4532,24 @@ def chat(): current_index = len(session_history) - 1 if conversation_memory_enabled: memory_start_index = 0 + + # Recalculate session totals from loaded message history total_input_tokens = 0 total_output_tokens = 0 total_cost = 0.0 message_count = 0 + + for msg in session_history: + # Handle both old format (no cost data) and new format (with cost data) + total_input_tokens += msg.get('prompt_tokens', 0) + total_output_tokens += msg.get('completion_tokens', 0) + total_cost += msg.get('msg_cost', 0.0) + message_count += 1 + console.print(f"[bold green]Conversation '{conversation_name}' loaded with {len(session_history)} messages.[/]") - app_logger.info(f"Conversation '{conversation_name}' loaded with {len(session_history)} messages") + if total_cost > 0: + console.print(f"[dim cyan]Restored session totals: {total_input_tokens + total_output_tokens} tokens, ${total_cost:.4f} cost[/]") + app_logger.info(f"Conversation '{conversation_name}' loaded with {len(session_history)} messages, restored cost: ${total_cost:.4f}") continue elif user_input.lower().startswith("/delete"): @@ -4680,8 +4692,8 @@ def chat(): total_output_tokens = 0 total_cost = 0.0 message_count = 0 - console.print("[bold green]Conversation context reset.[/]") - app_logger.info("Conversation context reset by user") + console.print("[bold green]Conversation context reset. Totals cleared.[/]") + app_logger.info("Conversation context reset by user - all totals reset to 0") continue elif user_input.lower().startswith("/info"): @@ -5776,10 +5788,13 @@ All queries are read-only. INSERT/UPDATE/DELETE are not allowed.""" stream_interrupted = False if is_streaming: + # Store the last chunk to get usage data after streaming completes + last_chunk = None try: with Live("", console=console, refresh_per_second=10, auto_refresh=True) as live: try: for chunk in response: + last_chunk = chunk # Keep track of last chunk for usage data if hasattr(chunk, 'error') and chunk.error: console.print(f"\n[bold red]Stream error: {chunk.error.message}[/]") app_logger.error(f"Stream error: {chunk.error.message}") @@ -5835,6 +5850,14 @@ All queries are read-only. 
INSERT/UPDATE/DELETE are not allowed.""" pass continue # Now it's safe to continue + + # For streaming, try to get usage from last chunk + if last_chunk and hasattr(last_chunk, 'usage'): + response.usage = last_chunk.usage + app_logger.debug("Extracted usage data from last streaming chunk") + elif last_chunk: + app_logger.warning("Last streaming chunk has no usage data") + else: full_response = response.choices[0].message.content if response.choices else "" # Clear any processing messages before showing response @@ -5847,13 +5870,54 @@ All queries are read-only. INSERT/UPDATE/DELETE are not allowed.""" console.print() console.print(Panel(md, title="[bold green]AI Response[/]", title_align="left", border_style="green")) - session_history.append({'prompt': user_input, 'response': full_response}) - current_index = len(session_history) - 1 - + # Extract usage data BEFORE appending to history usage = getattr(response, 'usage', None) - input_tokens = usage.input_tokens if usage and hasattr(usage, 'input_tokens') else 0 - output_tokens = usage.output_tokens if usage and hasattr(usage, 'output_tokens') else 0 - msg_cost = usage.total_cost_usd if usage and hasattr(usage, 'total_cost_usd') else estimate_cost(input_tokens, output_tokens) + + # DEBUG: Log what OpenRouter actually returns + if usage: + app_logger.debug(f"Usage object type: {type(usage)}") + app_logger.debug(f"Usage attributes: {dir(usage)}") + if hasattr(usage, '__dict__'): + app_logger.debug(f"Usage dict: {usage.__dict__}") + else: + app_logger.warning("No usage object in response!") + + # Try both attribute naming conventions (OpenAI standard vs Anthropic) + input_tokens = 0 + output_tokens = 0 + + if usage: + # Try prompt_tokens/completion_tokens (OpenAI/OpenRouter standard) + if hasattr(usage, 'prompt_tokens'): + input_tokens = usage.prompt_tokens or 0 + elif hasattr(usage, 'input_tokens'): + input_tokens = usage.input_tokens or 0 + + if hasattr(usage, 'completion_tokens'): + output_tokens = usage.completion_tokens or 0 + elif hasattr(usage, 'output_tokens'): + output_tokens = usage.output_tokens or 0 + + app_logger.debug(f"Extracted tokens: input={input_tokens}, output={output_tokens}") + + # Get cost from API or estimate + msg_cost = 0.0 + if usage and hasattr(usage, 'total_cost_usd') and usage.total_cost_usd: + msg_cost = float(usage.total_cost_usd) + app_logger.debug(f"Using API cost: ${msg_cost:.6f}") + else: + msg_cost = estimate_cost(input_tokens, output_tokens) + app_logger.debug(f"Estimated cost: ${msg_cost:.6f} (from {input_tokens} input + {output_tokens} output tokens)") + + # NOW append to history with cost data + session_history.append({ + 'prompt': user_input, + 'response': full_response, + 'msg_cost': msg_cost, + 'prompt_tokens': input_tokens, + 'completion_tokens': output_tokens + }) + current_index = len(session_history) - 1 total_input_tokens += input_tokens total_output_tokens += output_tokens -- 2.49.1 From a0ed0eaaf034868139959a513aa78a2f866bc90b Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Mon, 26 Jan 2026 11:39:42 +0100 Subject: [PATCH 09/10] Fixed bug in cost display and handling --- oai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oai.py b/oai.py index b7b38b8..d3787c2 100644 --- a/oai.py +++ b/oai.py @@ -48,7 +48,7 @@ except ImportError: print("Warning: MCP library not found. 
Install with: pip install mcp") # App version -version = '2.1.0-RC1' +version = '2.1.0-RC2' app = typer.Typer() -- 2.49.1 From 2ffed27ab6b45722a96f5eb54204ca58ca9d553b Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Tue, 3 Feb 2026 08:57:16 +0100 Subject: [PATCH 10/10] Final release of 2.1 --- .gitignore | 6 + README.md | 667 ++-- oai.py | 5987 ----------------------------------- oai/__init__.py | 26 + oai/__main__.py | 8 + oai/cli.py | 719 +++++ oai/commands/__init__.py | 24 + oai/commands/handlers.py | 1441 +++++++++ oai/commands/registry.py | 381 +++ oai/config/__init__.py | 11 + oai/config/database.py | 472 +++ oai/config/settings.py | 361 +++ oai/constants.py | 448 +++ oai/core/__init__.py | 14 + oai/core/client.py | 422 +++ oai/core/session.py | 659 ++++ oai/mcp/__init__.py | 28 + oai/mcp/gitignore.py | 166 + oai/mcp/manager.py | 1365 ++++++++ oai/mcp/platform.py | 228 ++ oai/mcp/server.py | 1368 ++++++++ oai/mcp/validators.py | 123 + oai/providers/__init__.py | 32 + oai/providers/base.py | 413 +++ oai/providers/openrouter.py | 623 ++++ oai/py.typed | 2 + oai/ui/__init__.py | 51 + oai/ui/console.py | 242 ++ oai/ui/prompts.py | 274 ++ oai/ui/tables.py | 373 +++ oai/utils/__init__.py | 20 + oai/utils/export.py | 248 ++ oai/utils/files.py | 323 ++ oai/utils/logging.py | 297 ++ pyproject.toml | 134 + 35 files changed, 11494 insertions(+), 6462 deletions(-) delete mode 100644 oai.py create mode 100644 oai/__init__.py create mode 100644 oai/__main__.py create mode 100644 oai/cli.py create mode 100644 oai/commands/__init__.py create mode 100644 oai/commands/handlers.py create mode 100644 oai/commands/registry.py create mode 100644 oai/config/__init__.py create mode 100644 oai/config/database.py create mode 100644 oai/config/settings.py create mode 100644 oai/constants.py create mode 100644 oai/core/__init__.py create mode 100644 oai/core/client.py create mode 100644 oai/core/session.py create mode 100644 oai/mcp/__init__.py create mode 100644 oai/mcp/gitignore.py create mode 100644 oai/mcp/manager.py create mode 100644 oai/mcp/platform.py create mode 100644 oai/mcp/server.py create mode 100644 oai/mcp/validators.py create mode 100644 oai/providers/__init__.py create mode 100644 oai/providers/base.py create mode 100644 oai/providers/openrouter.py create mode 100644 oai/py.typed create mode 100644 oai/ui/__init__.py create mode 100644 oai/ui/console.py create mode 100644 oai/ui/prompts.py create mode 100644 oai/ui/tables.py create mode 100644 oai/utils/__init__.py create mode 100644 oai/utils/export.py create mode 100644 oai/utils/files.py create mode 100644 oai/utils/logging.py create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore index 3c2ad4a..c8e0fc6 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,9 @@ Pipfile.lock # Consider if you want to include or exclude *~.nib *~.xib +# Claude Code local settings +.claude/ + # Added by author *.zip .note @@ -39,3 +42,6 @@ b0.sh *.old *.sh *.back +requirements.txt +system_prompt.txt +CLAUDE* diff --git a/README.md b/README.md index 64f3220..61bc8fd 100644 --- a/README.md +++ b/README.md @@ -1,584 +1,301 @@ -# oAI - OpenRouter AI Chat +# oAI - OpenRouter AI Chat Client -A powerful terminal-based chat interface for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI agents to access local files and query SQLite databases directly. - -## Description - -oAI is a feature-rich command-line chat application that provides an interactive interface to OpenRouter's AI models. 
It now includes **MCP integration** for local file system access and read-only database querying, allowing AI to help with code analysis, data exploration, and more. +A powerful, extensible terminal-based chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases. ## Features ### Core Features - 🤖 Interactive chat with 300+ AI models via OpenRouter -- 🔍 Model selection with search and capability filtering +- 🔍 Model selection with search and filtering - 💾 Conversation save/load/export (Markdown, JSON, HTML) -- 📎 File attachment support (images, PDFs, code files) -- 💰 Session cost tracking and credit monitoring -- 🎨 Rich terminal formatting with syntax highlighting +- 📎 File attachments (images, PDFs, code files) +- 💰 Real-time cost tracking and credit monitoring +- 🎨 Rich terminal UI with syntax highlighting - 📝 Persistent command history with search (Ctrl+R) -- ⚙️ Configurable system prompts and token limits -- 🗄️ SQLite-based configuration and conversation storage - 🌐 Online mode (web search capabilities) -- 🧠 Conversation memory toggle (save costs with stateless mode) +- 🧠 Conversation memory toggle -### NEW: MCP (Model Context Protocol) v2.1.0-beta -- 🔧 **File Mode**: AI can read, search, and list your local files +### MCP Integration +- 🔧 **File Mode**: AI can read, search, and list local files - Automatic .gitignore filtering - - Virtual environment exclusion (venv, node_modules, etc.) - - Supports code files, text, JSON, YAML, and more + - Virtual environment exclusion - Large file handling (auto-truncates >50KB) - -- ✍️ **Write Mode** (NEW!): AI can modify files with your permission - - Create and edit files within allowed folders - - Delete files (always requires confirmation) - - Move, copy, and organize files - - Create directories - - Ignores .gitignore for write operations - - OFF by default - explicit opt-in required - -- 🗄️ **Database Mode**: AI can query your SQLite databases - - Read-only access (no data modification possible) - - Schema inspection (tables, columns, indexes) - - Full-text search across all tables - - SQL query execution (SELECT, JOINs, CTEs, subqueries) - - Query validation and timeout protection - - Result limiting (max 1000 rows) -- 🔒 **Security Features**: - - Explicit folder/database approval required - - System directory blocking - - Write mode OFF by default (non-persistent) - - Delete operations always require user confirmation - - Read-only database access - - SQL injection protection - - Query timeout (5 seconds) +- ✍️ **Write Mode**: AI can modify files with permission + - Create, edit, delete files + - Move, copy, organize files + - Always requires explicit opt-in + +- 🗄️ **Database Mode**: AI can query SQLite databases + - Read-only access (safe) + - Schema inspection + - Full SQL query support ## Requirements -- Python 3.10-3.13 (3.14 not supported yet) -- OpenRouter API key (get one at https://openrouter.ai) -- Function-calling model required for MCP features (GPT-4, Claude, etc.) - -## Screenshot - -[](https://gitlab.pm/rune/oai/src/branch/main/README.md) - -*Screenshot from version 1.0 - MCP interface shows mode indicators like `[🔧 MCP: Files]` or `[🗄️ MCP: DB #1]`* +- Python 3.10-3.13 +- OpenRouter API key ([get one here](https://openrouter.ai)) ## Installation -### Option 1: From Source (Recommended for Development) - -#### 1. 
Install Dependencies +### Option 1: Install from Source (Recommended) ```bash -pip install -r requirements.txt +# Clone the repository +git clone https://gitlab.pm/rune/oai.git +cd oai + +# Install with pip +pip install -e . ``` -#### 2. Make Executable +### Option 2: Pre-built Binary (macOS/Linux) -```bash -chmod +x oai.py -``` - -#### 3. Copy to PATH - -```bash -# Option 1: System-wide (requires sudo) -sudo cp oai.py /usr/local/bin/oai - -# Option 2: User-local (recommended) -mkdir -p ~/.local/bin -cp oai.py ~/.local/bin/oai - -# Add to PATH if needed (add to ~/.bashrc or ~/.zshrc) -export PATH="$HOME/.local/bin:$PATH" -``` - -#### 4. Verify Installation - -```bash -oai --version -``` - -### Option 2: Pre-built Binaries - -Download platform-specific binaries: -- **macOS (Apple Silicon)**: `oai_vx.x.x_mac_arm64.zip` -- **Linux (x86_64)**: `oai_vx.x.x-linux-x86_64.zip` +Download from [Releases](https://gitlab.pm/rune/oai/releases): +- **macOS (Apple Silicon)**: `oai_v2.1.0_mac_arm64.zip` +- **Linux (x86_64)**: `oai_v2.1.0_linux_x86_64.zip` ```bash # Extract and install -unzip oai_vx.x.x_mac_arm64.zip # or `oai_vx.x.x-linux-x86_64.zip` -chmod +x oai -mkdir -p ~/.local/bin # Remember to add this to your path. Or just move to folder already in your $PATH +unzip oai_v2.1.0_*.zip +mkdir -p ~/.local/bin mv oai ~/.local/bin/ + +# macOS only: Remove quarantine and approve +xattr -cr ~/.local/bin/oai +# Then right-click oai in Finder → Open With → Terminal → Click "Open" ``` - -### Alternative: Shell Alias +### Add to PATH ```bash -# Add to ~/.bashrc or ~/.zshrc -alias oai='python3 /path/to/oai.py' +# Add to ~/.zshrc or ~/.bashrc +export PATH="$HOME/.local/bin:$PATH" ``` ## Quick Start -### First Run Setup +```bash +# Start the chat client +oai chat + +# Or with options +oai chat --model gpt-4o --mcp +``` + +On first run, you'll be prompted for your OpenRouter API key. + +### Basic Commands ```bash -oai +# In the chat interface: +/model # Select AI model +/help # Show all commands +/mcp on # Enable file/database access +/stats # View session statistics +exit # Quit ``` -On first run, you'll be prompted to enter your OpenRouter API key. +## MCP (Model Context Protocol) -### Basic Usage +MCP allows the AI to interact with your local files and databases. + +### File Access ```bash -# Start chatting -oai +/mcp on # Enable MCP +/mcp add ~/Projects # Grant access to folder +/mcp list # View allowed folders -# Select a model -You> /model - -# Enable MCP for file access -You> /mcp on -You> /mcp add ~/Documents - -# Ask AI to help with files (read-only) -[🔧 MCP: Files] You> List all Python files in Documents -[🔧 MCP: Files] You> Read and explain main.py - -# Enable write mode to let AI modify files -You> /mcp write on -[🔧✍️ MCP: Files+Write] You> Create a new Python file with helper functions -[🔧✍️ MCP: Files+Write] You> Refactor main.py to use async/await - -# Switch to database mode -You> /mcp add db ~/myapp/data.db -You> /mcp db 1 -[🗄️ MCP: DB #1] You> Show me all tables -[🗄️ MCP: DB #1] You> Find all users created this month -``` - -## MCP Guide - -### File Mode (Default) - -**Setup:** -```bash -/mcp on # Start MCP server -/mcp add ~/Projects # Grant access to folder -/mcp add ~/Documents # Add another folder -/mcp list # View all allowed folders -``` - -**Natural Language Usage:** -``` +# Now ask the AI: "List all Python files in Projects" -"Read and explain config.yaml" +"Read and explain main.py" "Search for files containing 'TODO'" -"What's in my Documents folder?" 
``` -**Available Tools (Read-Only):** -- `read_file` - Read complete file contents -- `list_directory` - List files/folders (recursive optional) -- `search_files` - Search by name or content - -**Available Tools (Write Mode - requires `/mcp write on`):** -- `write_file` - Create new files or overwrite existing ones -- `edit_file` - Find and replace text in existing files -- `delete_file` - Delete files (always requires confirmation) -- `create_directory` - Create directories -- `move_file` - Move or rename files -- `copy_file` - Copy files to new locations - -**Features:** -- ✅ Automatic .gitignore filtering (read operations only) -- ✅ Skips virtual environments (venv, node_modules) -- ✅ Handles large files (auto-truncates >50KB) -- ✅ Cross-platform (macOS, Linux, Windows via WSL) -- ✅ Write mode OFF by default for safety -- ✅ Delete operations require user confirmation with LLM's reason - -### Database Mode - -**Setup:** -```bash -/mcp add db ~/app/database.db # Add SQLite database -/mcp db list # View all databases -/mcp db 1 # Switch to database #1 -``` - -**Natural Language Usage:** -``` -"Show me all tables in this database" -"Find records mentioning 'error'" -"How many users registered last week?" -"Get the schema for the orders table" -"Show me the 10 most recent transactions" -``` - -**Available Tools:** -- `inspect_database` - View schema, tables, columns, indexes -- `search_database` - Full-text search across tables -- `query_database` - Execute read-only SQL queries - -**Supported Queries:** -- ✅ SELECT statements -- ✅ JOINs (INNER, LEFT, RIGHT, FULL) -- ✅ Subqueries -- ✅ CTEs (Common Table Expressions) -- ✅ Aggregations (COUNT, SUM, AVG, etc.) -- ✅ WHERE, GROUP BY, HAVING, ORDER BY, LIMIT -- ❌ INSERT/UPDATE/DELETE (blocked for safety) - ### Write Mode -**Enable Write Mode:** ```bash -/mcp write on # Enable write mode (shows warning, requires confirmation) +/mcp write on # Enable file modifications + +# AI can now: +"Create a new file called utils.py" +"Edit config.json and update the API URL" +"Delete the old backup files" # Always asks for confirmation ``` -**Natural Language Usage:** -``` -"Create a new Python file called utils.py with helper functions" -"Edit main.py and replace the old API endpoint with the new one" -"Delete the backup.old file" (will prompt for confirmation) -"Create a directory called tests" -"Move config.json to the config folder" -``` - -**Important:** -- ⚠️ Write mode is **OFF by default** and resets each session -- ⚠️ Delete operations **always** require user confirmation -- ⚠️ All operations are limited to allowed MCP folders -- ✅ Write operations ignore .gitignore (can write to any file in allowed folders) - -**Disable Write Mode:** -```bash -/mcp write off # Disable write mode (back to read-only) -``` - -### Mode Management +### Database Mode ```bash -/mcp status # Show current mode, write mode, stats, folders/databases -/mcp files # Switch to file mode -/mcp db # Switch to database mode -/mcp gitignore on # Enable .gitignore filtering (default) -/mcp write on|off # Enable/disable write mode -/mcp remove 2 # Remove folder/database by number +/mcp add db ~/app/data.db # Add database +/mcp db 1 # Switch to database mode + +# Ask the AI: +"Show all tables" +"Find users created this month" +"What's the schema for the orders table?" 
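+
+# Queries are validated as read-only: only SELECT/WITH statements run,
+# INSERT/UPDATE/DELETE are blocked, results are capped at 1000 rows,
+# and long-running queries time out after 5 seconds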
``` ## Command Reference -### Session Commands -``` -/help [command] Show help menu or detailed command help -/help mcp Comprehensive MCP guide -/clear or /cl Clear terminal screen (or Ctrl+L) -/memory on|off Toggle conversation memory (save costs) -/online on|off Enable/disable web search -/paste [prompt] Paste clipboard content -/retry Resend last prompt -/reset Clear history and system prompt -/prev View previous response -/next View next response -``` +### Chat Commands +| Command | Description | +|---------|-------------| +| `/help [cmd]` | Show help | +| `/model [search]` | Select model | +| `/info [model]` | Model details | +| `/memory on\|off` | Toggle context | +| `/online on\|off` | Toggle web search | +| `/retry` | Resend last message | +| `/clear` | Clear screen | ### MCP Commands -``` -/mcp on Start MCP server -/mcp off Stop MCP server -/mcp status Show comprehensive status (includes write mode) -/mcp add Add folder for file access -/mcp add db Add SQLite database -/mcp list List all folders -/mcp db list List all databases -/mcp db Switch to database mode -/mcp files Switch to file mode -/mcp remove Remove folder/database -/mcp gitignore on Enable .gitignore filtering -/mcp write on Enable write mode (create/edit/delete files) -/mcp write off Disable write mode (read-only) -``` +| Command | Description | +|---------|-------------| +| `/mcp on\|off` | Enable/disable MCP | +| `/mcp status` | Show MCP status | +| `/mcp add ` | Add folder | +| `/mcp add db ` | Add database | +| `/mcp list` | List folders | +| `/mcp db list` | List databases | +| `/mcp db ` | Switch to database | +| `/mcp files` | Switch to file mode | +| `/mcp write on\|off` | Toggle write mode | -### Model Commands -``` -/model [search] Select/change AI model -/info [model_id] Show model details (pricing, capabilities) -``` +### Conversation Commands +| Command | Description | +|---------|-------------| +| `/save ` | Save conversation | +| `/load ` | Load conversation | +| `/list` | List saved conversations | +| `/delete ` | Delete conversation | +| `/export md\|json\|html ` | Export | ### Configuration -``` -/config View all settings -/config api Set API key -/config model Set default model -/config online Set default online mode (on|off) -/config stream Enable/disable streaming (on|off) -/config maxtoken Set max token limit -/config costwarning Set cost warning threshold ($) -/config loglevel Set log level (debug/info/warning/error) -/config log Set log file size (MB) -``` +| Command | Description | +|---------|-------------| +| `/config` | View settings | +| `/config api` | Set API key | +| `/config model ` | Set default model | +| `/config stream on\|off` | Toggle streaming | +| `/stats` | Session statistics | +| `/credits` | Check credits | -### Conversation Management -``` -/save Save conversation -/load Load saved conversation -/delete Delete conversation -/list List saved conversations -/export md|json|html Export conversation -``` +## CLI Options -### Token & System -``` -/maxtoken [value] Set session token limit -/system [prompt] Set system prompt (use 'clear' to reset) -/middleout on|off Enable prompt compression -``` - -### Monitoring -``` -/stats View session statistics -/credits Check OpenRouter credits -``` - -### File Attachments -``` -@/path/to/file Attach file (images, PDFs, code) - -Examples: - Debug @script.py - Analyze @data.json - Review @screenshot.png -``` - -## Configuration Options - -All configuration stored in `~/.config/oai/`: - -### Files -- `oai_config.db` - SQLite database 
(settings, conversations, MCP config) -- `oai.log` - Application logs (rotating, configurable size) -- `history.txt` - Command history (searchable with Ctrl+R) - -### Key Settings -- **API Key**: OpenRouter authentication -- **Default Model**: Auto-select on startup -- **Streaming**: Real-time response display -- **Max Tokens**: Global and session limits -- **Cost Warning**: Alert threshold for expensive requests -- **Online Mode**: Default web search setting -- **Log Level**: debug/info/warning/error/critical -- **Log Size**: Rotating file size in MB - -## Supported File Types - -### Code Files -`.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm` - -### Data Files -`.json, .yaml, .yml, .xml, .csv, .txt, .md` - -### Images -All standard formats: PNG, JPEG, JPG, GIF, WEBP, BMP - -### Documents -PDF (models with document support) - -### Size Limits -- Images: 10 MB max -- Code/Text: Auto-truncates files >50KB -- Binary data: Displayed as `` - -## MCP Security - -### Access Control -- ✅ Explicit folder/database approval required -- ✅ System directories blocked automatically -- ✅ User confirmation for each addition -- ✅ .gitignore patterns respected (file mode) - -### Database Safety -- ✅ Read-only mode (cannot modify data) -- ✅ SQL query validation (blocks INSERT/UPDATE/DELETE) -- ✅ Query timeout (5 seconds max) -- ✅ Result limits (1000 rows max) -- ✅ Database opened in `mode=ro` - -### File System Safety -- ✅ Read-only by default (write mode requires explicit opt-in) -- ✅ Write mode OFF by default each session (non-persistent) -- ✅ Delete operations always require user confirmation -- ✅ Write operations limited to allowed folders only -- ✅ System directories blocked -- ✅ Virtual environment exclusion -- ✅ Build artifact filtering -- ✅ Maximum file size (10 MB) - -## Tips & Tricks - -### Command History -- **↑/↓ arrows**: Navigate previous commands -- **Ctrl+R**: Search command history -- **Auto-complete**: Start typing `/` for command suggestions - -### Cost Optimization ```bash -/memory off # Disable context (stateless mode) -/maxtoken 1000 # Limit response length -/config costwarning 0.01 # Set alert threshold +oai chat [OPTIONS] + +Options: + -m, --model TEXT Model ID to use + -s, --system TEXT System prompt + -o, --online Enable online mode + --mcp Enable MCP server + --help Show help ``` -### MCP Best Practices +Other commands: ```bash -# Check status frequently -/mcp status - -# Use specific paths to reduce search time -"List Python files in Projects/app/" # Better than -"List all Python files" # Slower - -# Database queries - be specific -"SELECT * FROM users LIMIT 10" # Good -"SELECT * FROM users" # May hit row limit +oai config [setting] [value] # Configure settings +oai version # Show version +oai credits # Check credits ``` -### Debugging -```bash -# Enable debug logging -/config loglevel debug +## Configuration -# Check log file -tail -f ~/.config/oai/oai.log +Configuration is stored in `~/.config/oai/`: -# View MCP statistics -/mcp status # Shows tool call counts +| File | Purpose | +|------|---------| +| `oai_config.db` | Settings, conversations, MCP config | +| `oai.log` | Application logs | +| `history.txt` | Command history | + +## Project Structure + +``` +oai/ +├── oai/ +│ ├── __init__.py +│ ├── __main__.py # Entry point for python -m oai +│ ├── cli.py # Main CLI interface +│ ├── constants.py # Configuration constants +│ ├── commands/ # Slash command handlers +│ ├── config/ # 
Settings and database +│ ├── core/ # Chat client and session +│ ├── mcp/ # MCP server and tools +│ ├── providers/ # AI provider abstraction +│ ├── ui/ # Terminal UI utilities +│ └── utils/ # Logging, export, etc. +├── pyproject.toml # Package configuration +├── build.sh # Binary build script +└── README.md ``` ## Troubleshooting -### MCP Not Working -```bash -# 1. Check if MCP is installed -python3 -c "import mcp; print('MCP OK')" +### macOS Binary Issues -# 2. Verify model supports function calling +```bash +# Remove quarantine attribute +xattr -cr ~/.local/bin/oai + +# Then in Finder: right-click oai → Open With → Terminal → Click "Open" +# After this, oai works from any terminal +``` + +### MCP Not Working + +```bash +# Check if model supports function calling /info # Look for "tools" in supported parameters -# 3. Check MCP status +# Check MCP status /mcp status -# 4. Review logs -tail ~/.config/oai/oai.log +# View logs +tail -f ~/.config/oai/oai.log ``` ### Import Errors + ```bash -# Reinstall dependencies -pip install --force-reinstall -r requirements.txt -``` - -### Binary Issues (macOS) -```bash -# Remove quarantine -xattr -cr ~/.local/bin/oai - -# Check security settings -# System Settings > Privacy & Security > "Allow Anyway" -``` - -### Database Errors -```bash -# Verify it's a valid SQLite database -sqlite3 database.db ".tables" - -# Check file permissions -ls -la database.db +# Reinstall package +pip install -e . --force-reinstall ``` ## Version History -### v2.1.0-RC1 (Current) -- ✨ **NEW**: MCP (Model Context Protocol) integration -- ✨ **NEW**: File system access (read, search, list) -- ✨ **NEW**: Write mode - AI can create, edit, and delete files - - 6 write tools: write_file, edit_file, delete_file, create_directory, move_file, copy_file - - OFF by default - requires explicit `/mcp write on` activation - - Delete operations always require user confirmation - - Non-persistent setting (resets each session) -- ✨ **NEW**: SQLite database querying (read-only) -- ✨ **NEW**: Dual mode support (Files & Database) -- ✨ **NEW**: .gitignore filtering -- ✨ **NEW**: Binary data handling in databases -- ✨ **NEW**: Mode indicators in prompt (shows ✍️ when write mode active) -- ✨ **NEW**: Comprehensive `/help mcp` guide -- 🔧 Improved error handling for tool calls -- 🔧 Enhanced logging for MCP operations -- 🔧 Statistics tracking for tool usage +### v2.1.0 (Current) +- 🏗️ Complete codebase refactoring to modular package structure +- 🔌 Extensible provider architecture for adding new AI providers +- 📦 Proper Python packaging with pyproject.toml +- ✨ MCP integration (file access, write mode, database queries) +- 🔧 Command registry pattern for slash commands +- 📊 Improved cost tracking and session statistics -### v1.9.6 -- Base version with core chat functionality -- Conversation management +### v1.9.x +- Single-file implementation +- Core chat functionality - File attachments -- Cost tracking -- Export capabilities +- Conversation management ## License -MIT License - -Copyright (c) 2024-2025 Rune Olsen - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice 
shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Full license: https://opensource.org/licenses/MIT +MIT License - See [LICENSE](LICENSE) for details. ## Author **Rune Olsen** - -- Homepage: https://ai.fubar.pm/ -- Blog: https://blog.rune.pm - Project: https://iurl.no/oai +- Repository: https://gitlab.pm/rune/oai ## Contributing -Contributions welcome! Please: 1. Fork the repository 2. Create a feature branch -3. Submit a pull request with detailed description - -## Acknowledgments - -- OpenRouter team for the unified AI API -- Rich library for beautiful terminal output -- MCP community for the protocol specification +3. Submit a pull request --- -**Star ⭐ this project if you find it useful!** - ---- - -Did you really read all the way down here? WOW! You deserve a 🍾 🥂! +**⭐ Star this project if you find it useful!** diff --git a/oai.py b/oai.py deleted file mode 100644 index d3787c2..0000000 --- a/oai.py +++ /dev/null @@ -1,5987 +0,0 @@ -#!/usr/bin/python3 -W ignore::DeprecationWarning -import sys -import os -import requests -import time -import asyncio -from pathlib import Path -from typing import Optional, List, Dict, Any -import typer -from rich.console import Console -from rich.panel import Panel -from rich.table import Table -from rich.text import Text -from rich.markdown import Markdown -from rich.live import Live -from openrouter import OpenRouter -import pyperclip -import mimetypes -import base64 -import re -import sqlite3 -import json -import datetime -import logging -from logging.handlers import RotatingFileHandler -from prompt_toolkit import PromptSession -from prompt_toolkit.history import FileHistory -from rich.logging import RichHandler -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory -from prompt_toolkit.key_binding import KeyBindings -from prompt_toolkit.filters import Condition -from prompt_toolkit.application.current import get_app -from packaging import version as pkg_version -import io -import platform -import shutil -import subprocess -import fnmatch -import signal - -# MCP imports -try: - from mcp import ClientSession, StdioServerParameters - from mcp.client.stdio import stdio_client - MCP_AVAILABLE = True -except ImportError: - MCP_AVAILABLE = False - print("Warning: MCP library not found. 
Install with: pip install mcp") - -# App version -version = '2.1.0-RC2' - -app = typer.Typer() - -# Application identification -APP_NAME = "oAI" -APP_URL = "https://iurl.no/oai" - -# Paths -home = Path.home() -config_dir = home / '.config' / 'oai' -cache_dir = home / '.cache' / 'oai' -history_file = config_dir / 'history.txt' -database = config_dir / 'oai_config.db' -log_file = config_dir / 'oai.log' - -# Create dirs -config_dir.mkdir(parents=True, exist_ok=True) -cache_dir.mkdir(parents=True, exist_ok=True) - -# Rich console -console = Console() - -# Valid commands -VALID_COMMANDS = { - '/retry', '/online', '/memory', '/paste', '/export', '/save', '/load', - '/delete', '/list', '/prev', '/next', '/stats', '/middleout', '/reset', - '/info', '/model', '/maxtoken', '/system', '/config', '/credits', '/clear', - '/cl', '/help', '/mcp' -} - -# Command help database (COMPLETE - includes MCP comprehensive guide) -COMMAND_HELP = { - '/clear': { - 'aliases': ['/cl'], - 'description': 'Clear the terminal screen for a clean interface.', - 'usage': '/clear\n/cl', - 'examples': [ - ('Clear screen', '/clear'), - ('Using short alias', '/cl'), - ], - 'notes': 'You can also use the keyboard shortcut Ctrl+L to clear the screen.' - }, - '/help': { - 'description': 'Display help information for commands.', - 'usage': '/help [command|topic]', - 'examples': [ - ('Show all commands', '/help'), - ('Get help for a specific command', '/help /model'), - ('Get detailed MCP help', '/help mcp'), - ], - 'notes': 'Use /help without arguments to see the full command list, /help for detailed command info, or /help mcp for comprehensive MCP documentation.' - }, - 'mcp': { - 'description': 'Complete guide to MCP (Model Context Protocol) - file access and database querying for AI agents.', - 'usage': 'See detailed examples below', - 'examples': [], - 'notes': ''' -MCP (Model Context Protocol) gives your AI assistant direct access to: - • Local files and folders (read, search, list) - • SQLite databases (inspect, search, query) - -╭─────────────────────────────────────────────────────────╮ -│ 🗂️ FILE MODE │ -╰─────────────────────────────────────────────────────────╯ - -SETUP: - /mcp on Start MCP server - /mcp add ~/Documents Grant access to folder - /mcp add ~/Code/project Add another folder - /mcp list View all allowed folders - -USAGE (just ask naturally): - "List all Python files in my Code folder" - "Read the contents of Documents/notes.txt" - "Search for files containing 'budget'" - "What's in my Documents folder?" - -FEATURES: - ✓ Automatically respects .gitignore patterns - ✓ Skips virtual environments (venv, node_modules, etc.) 
- ✓ Handles large files (auto-truncates >50KB) - ✓ Cross-platform (macOS, Linux, Windows) - -MANAGEMENT: - /mcp remove ~/Desktop Remove folder access - /mcp remove 2 Remove by number - /mcp gitignore on|off Toggle .gitignore filtering - /mcp status Show comprehensive status - -╭─────────────────────────────────────────────────────────╮ -│ ✍️ WRITE MODE (OPTIONAL) │ -╰─────────────────────────────────────────────────────────╯ - -ENABLE WRITE MODE: - /mcp write on Enable file modifications (requires confirmation) - /mcp write off Disable write mode (back to read-only) - -WHAT WRITE MODE ALLOWS: - • Create new files or overwrite existing ones - • Edit files (find and replace text) - • Delete files (always requires confirmation) - • Create directories - • Move and copy files - -USAGE (after enabling write mode): - "Create a new Python file called utils.py" - "Edit main.py and update the API endpoint" - "Delete the old backup.txt file" - "Create a tests directory" - "Move config.json to the configs folder" - -SAFETY FEATURES: - ✓ OFF by default (resets each session) - ✓ Requires explicit activation with user confirmation - ✓ Delete operations ALWAYS require user confirmation - ✓ Limited to allowed MCP folders only - ✓ Ignores .gitignore (can write to any file in allowed folders) - ✓ System directories remain blocked - -IMPORTANT: - ⚠️ Write mode is powerful - use with caution! - ⚠️ Always review what the AI plans to modify - ⚠️ Deletions will show file details + AI's reason for confirmation - -╭─────────────────────────────────────────────────────────╮ -│ 🗄️ DATABASE MODE │ -╰─────────────────────────────────────────────────────────╯ - -SETUP: - /mcp add db ~/app/data.db Add specific database - /mcp add db ~/.config/oai/oai_config.db - /mcp db list View all databases - -SWITCH TO DATABASE: - /mcp db 1 Work with database #1 - /mcp db 2 Switch to database #2 - /mcp files Switch back to file mode - -USAGE (after selecting a database): - "Show me all tables in this database" - "Find any records mentioning 'Santa Clause'" - "How many users are in the database?" - "Show me the schema for the orders table" - "Get the 10 most recent transactions" - -FEATURES: - ✓ Read-only mode (no data modification possible) - ✓ Smart schema inspection - ✓ Full-text search across all tables - ✓ SQL query execution (SELECT only) - ✓ JOINs, subqueries, CTEs supported - ✓ Automatic query timeout (5 seconds) - ✓ Result limits (max 1000 rows) - -SAFETY: - • All queries are read-only (INSERT/UPDATE/DELETE blocked) - • Database opened in read-only mode - • Query validation prevents dangerous operations - • Timeout protection prevents infinite loops - -╭─────────────────────────────────────────────────────────╮ -│ 💡 TIPS & TRICKS │ -╰─────────────────────────────────────────────────────────╯ - -MODE INDICATORS: - [🔧 MCP: Files] You're in file mode (read-only) - [🔧✍️ MCP: Files+Write] You're in file mode with write permissions - [🗄️ MCP: DB #1] You're querying database #1 - -QUICK REFERENCE: - /mcp status See current mode, write mode, stats, folders/databases - /mcp files Switch to file mode (default) - /mcp db Switch to database mode - /mcp db list List all databases with details - /mcp list List all folders - /mcp write on Enable write mode (create/edit/delete files) - /mcp write off Disable write mode (read-only) - -TROUBLESHOOTING: - • No results? Check /mcp status to see what's accessible - • Wrong mode? Use /mcp files or /mcp db to switch - • Database errors? 
Ensure file exists and is valid SQLite - • .gitignore not working? Check /mcp status shows patterns - -SECURITY NOTES: - • MCP only accesses explicitly added folders/databases - • File mode: read-only by default (write mode requires opt-in) - • Write mode: OFF by default, resets each session - • Delete operations: Always require user confirmation - • Database mode: SELECT queries only (no modifications) - • System directories are blocked automatically - • Each addition requires your explicit confirmation - -For command-specific help: /help /mcp - ''' - }, - '/mcp': { - 'description': 'Manage MCP (Model Context Protocol) for local file access and SQLite database querying. When enabled with a function-calling model, AI can automatically search, read, list files, and query databases. Supports two modes: Files (default) and Database.', - 'usage': '/mcp [args]', - 'examples': [ - ('Enable MCP server', '/mcp on'), - ('Disable MCP server', '/mcp off'), - ('Show MCP status and current mode', '/mcp status'), - ('', ''), - ('━━━ FILE MODE ━━━', ''), - ('Add folder for file access', '/mcp add ~/Documents'), - ('Remove folder by path', '/mcp remove ~/Desktop'), - ('Remove folder by number', '/mcp remove 2'), - ('List allowed folders', '/mcp list'), - ('Toggle .gitignore filtering', '/mcp gitignore on'), - ('Enable write mode', '/mcp write on'), - ('Disable write mode', '/mcp write off'), - ('', ''), - ('━━━ DATABASE MODE ━━━', ''), - ('Add SQLite database', '/mcp add db ~/app/data.db'), - ('List all databases', '/mcp db list'), - ('Switch to database #1', '/mcp db 1'), - ('Remove database', '/mcp remove db 2'), - ('Switch back to file mode', '/mcp files'), - ], - 'notes': '''MCP allows AI to read local files and query SQLite databases. Works automatically with function-calling models (GPT-4, Claude, etc.). - -FILE MODE (default): -- Read-only by default (write mode requires opt-in) -- Automatically loads and respects .gitignore patterns -- Skips virtual environments and build artifacts -- Supports search, read, and list operations - -WRITE MODE (optional): -- Enable with /mcp write on (requires confirmation) -- Allows creating, editing, and deleting files -- OFF by default, resets each session -- Delete operations always require user confirmation -- Limited to allowed folders only - -DATABASE MODE: -- Read-only access (no data modification) -- Execute SELECT queries with JOINs, subqueries, CTEs -- Full-text search across all tables -- Schema inspection and data exploration -- Automatic query validation and timeout protection - -Use /help mcp for comprehensive guide with examples.''' - }, - '/memory': { - 'description': 'Toggle conversation memory. When ON, the AI remembers conversation history. When OFF, each request is independent (saves tokens and cost).', - 'usage': '/memory [on|off]', - 'examples': [ - ('Check current memory status', '/memory'), - ('Enable conversation memory', '/memory on'), - ('Disable memory (save costs)', '/memory off'), - ], - 'notes': 'Memory is ON by default. Disabling memory reduces API costs but the AI won\'t remember previous messages. Messages are still saved locally for your reference.' - }, - '/online': { - 'description': 'Enable or disable online mode (web search capabilities) for the current session.', - 'usage': '/online [on|off]', - 'examples': [ - ('Check online mode status', '/online'), - ('Enable web search', '/online on'), - ('Disable web search', '/online off'), - ], - 'notes': 'Not all models support online mode. 
The model must have "tools" parameter support. This setting overrides the default online mode configured with /config online.' - }, - '/paste': { - 'description': 'Paste plain text or code from clipboard and send to the AI. Optionally add a prompt.', - 'usage': '/paste [prompt]', - 'examples': [ - ('Paste clipboard content', '/paste'), - ('Paste with a question', '/paste Explain this code'), - ('Paste and ask for review', '/paste Review this for bugs'), - ], - 'notes': 'Only plain text is supported. Binary clipboard data will be rejected. The clipboard content is shown as a preview before sending.' - }, - '/retry': { - 'description': 'Resend the last prompt from conversation history.', - 'usage': '/retry', - 'examples': [ - ('Retry last message', '/retry'), - ], - 'notes': 'Useful when you get an error or want a different response to the same prompt. Requires at least one message in history.' - }, - '/next': { - 'description': 'View the next response in conversation history.', - 'usage': '/next', - 'examples': [ - ('Navigate to next response', '/next'), - ], - 'notes': 'Navigate through conversation history. Use /prev to go backward.' - }, - '/prev': { - 'description': 'View the previous response in conversation history.', - 'usage': '/prev', - 'examples': [ - ('Navigate to previous response', '/prev'), - ], - 'notes': 'Navigate through conversation history. Use /next to go forward.' - }, - '/reset': { - 'description': 'Clear conversation history and reset system prompt. This resets all session metrics.', - 'usage': '/reset', - 'examples': [ - ('Reset conversation', '/reset'), - ], - 'notes': 'Requires confirmation. This clears all message history, resets the system prompt, and resets token/cost counters. Use when starting a completely new conversation topic.' - }, - '/info': { - 'description': 'Display detailed information about a model including pricing, capabilities, context length, and online support.', - 'usage': '/info [model_id]', - 'examples': [ - ('Show current model info', '/info'), - ('Show specific model info', '/info gpt-4o'), - ('Check model capabilities', '/info claude-3-opus'), - ], - 'notes': 'Without arguments, shows info for the currently selected model. Displays pricing per million tokens, supported modalities (text, image, etc.), and parameter support.' - }, - '/model': { - 'description': 'Select or change the AI model for the current session. Shows image, online, and tool capabilities.', - 'usage': '/model [search_term]', - 'examples': [ - ('List all models', '/model'), - ('Search for GPT models', '/model gpt'), - ('Search for Claude models', '/model claude'), - ], - 'notes': 'Models are numbered for easy selection. The table shows Image (✓ if model accepts images), Online (✓ if model supports web search), and Tools (✓ if model supports function calling for MCP).' - }, - '/config': { - 'description': 'View or modify application configuration settings.', - 'usage': '/config [setting] [value]', - 'examples': [ - ('View all settings', '/config'), - ('Set API key', '/config api'), - ('Set default model', '/config model'), - ('Enable streaming', '/config stream on'), - ('Set cost warning threshold', '/config costwarning 0.05'), - ('Set log level', '/config loglevel debug'), - ('Set default online mode', '/config online on'), - ], - 'notes': 'Available settings: api (API key), url (base URL), model (default model), stream (on/off), costwarning (threshold $), maxtoken (limit), online (default on/off), log (size MB), loglevel (debug/info/warning/error/critical).' 
- }, - '/maxtoken': { - 'description': 'Set a temporary session token limit (cannot exceed stored max token limit).', - 'usage': '/maxtoken [value]', - 'examples': [ - ('View current session limit', '/maxtoken'), - ('Set session limit to 2000', '/maxtoken 2000'), - ('Set to 50000', '/maxtoken 50000'), - ], - 'notes': 'This is a session-only setting and cannot exceed the stored max token limit (set with /config maxtoken). View without arguments to see current value.' - }, - '/system': { - 'description': 'Set or clear the session-level system prompt to guide AI behavior.', - 'usage': '/system [prompt|clear]', - 'examples': [ - ('View current system prompt', '/system'), - ('Set as Python expert', '/system You are a Python expert'), - ('Set as code reviewer', '/system You are a senior code reviewer. Focus on bugs and best practices.'), - ('Clear system prompt', '/system clear'), - ], - 'notes': 'System prompts influence how the AI responds throughout the session. Use "clear" to remove the current system prompt.' - }, - '/save': { - 'description': 'Save the current conversation history to the database.', - 'usage': '/save ', - 'examples': [ - ('Save conversation', '/save my_chat'), - ('Save with descriptive name', '/save python_debugging_2024'), - ], - 'notes': 'Saved conversations can be loaded later with /load. Use descriptive names to easily find them later.' - }, - '/load': { - 'description': 'Load a saved conversation from the database by name or number.', - 'usage': '/load ', - 'examples': [ - ('Load by name', '/load my_chat'), - ('Load by number from /list', '/load 3'), - ], - 'notes': 'Use /list to see numbered conversations. Loading a conversation replaces current history and resets session metrics.' - }, - '/delete': { - 'description': 'Delete a saved conversation from the database. Requires confirmation.', - 'usage': '/delete ', - 'examples': [ - ('Delete by name', '/delete my_chat'), - ('Delete by number from /list', '/delete 3'), - ], - 'notes': 'Use /list to see numbered conversations. This action cannot be undone!' - }, - '/list': { - 'description': 'List all saved conversations with numbers, message counts, and timestamps.', - 'usage': '/list', - 'examples': [ - ('Show saved conversations', '/list'), - ], - 'notes': 'Conversations are numbered for easy use with /load and /delete commands. Shows message count and last saved time.' - }, - '/export': { - 'description': 'Export the current conversation to a file in various formats.', - 'usage': '/export ', - 'examples': [ - ('Export as Markdown', '/export md notes.md'), - ('Export as JSON', '/export json conversation.json'), - ('Export as HTML', '/export html report.html'), - ], - 'notes': 'Available formats: md (Markdown), json (JSON), html (HTML). The export includes all messages and the system prompt if set.' - }, - '/stats': { - 'description': 'Display session statistics including tokens, costs, and credits.', - 'usage': '/stats', - 'examples': [ - ('View session statistics', '/stats'), - ], - 'notes': 'Shows total input/output tokens, total cost, average cost per message, and remaining credits. Also displays credit warnings if applicable.' - }, - '/credits': { - 'description': 'Display your OpenRouter account credits and usage.', - 'usage': '/credits', - 'examples': [ - ('Check credits', '/credits'), - ], - 'notes': 'Shows total credits, used credits, and credits left. Displays warnings if credits are low (< $1 or < 10% of total).' 
- }, - '/middleout': { - 'description': 'Enable or disable middle-out transform to compress prompts exceeding context size.', - 'usage': '/middleout [on|off]', - 'examples': [ - ('Check status', '/middleout'), - ('Enable compression', '/middleout on'), - ('Disable compression', '/middleout off'), - ], - 'notes': 'Middle-out transform intelligently compresses long prompts to fit within model context limits. Useful for very long conversations.' - }, -} - -# Supported code extensions -SUPPORTED_CODE_EXTENSIONS = { - '.py', '.js', '.ts', '.cs', '.java', '.c', '.cpp', '.h', '.hpp', - '.rb', '.ruby', '.php', '.swift', '.kt', '.kts', '.go', - '.sh', '.bat', '.ps1', '.R', '.scala', '.pl', '.lua', '.dart', - '.elm', '.xml', '.json', '.yaml', '.yml', '.md', '.txt' -} - -# Pricing -MODEL_PRICING = { - 'input': 3.0, - 'output': 15.0 -} -LOW_CREDIT_RATIO = 0.1 -LOW_CREDIT_AMOUNT = 1.0 -HIGH_COST_WARNING = "cost_warning_threshold" - -# Valid log levels -VALID_LOG_LEVELS = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL -} - -# System directories to block (cross-platform) -SYSTEM_DIRS_BLACKLIST = { - '/System', '/Library', '/private', '/usr', '/bin', '/sbin', # macOS - '/boot', '/dev', '/proc', '/sys', '/root', # Linux - 'C:\\Windows', 'C:\\Program Files', 'C:\\Program Files (x86)' # Windows -} - -# Database functions -def create_table_if_not_exists(): - """Ensure tables exist.""" - os.makedirs(config_dir, exist_ok=True) - with sqlite3.connect(str(database)) as conn: - conn.execute('''CREATE TABLE IF NOT EXISTS config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - )''') - conn.execute('''CREATE TABLE IF NOT EXISTS conversation_sessions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - timestamp TEXT NOT NULL, - data TEXT NOT NULL - )''') - # MCP configuration table - conn.execute('''CREATE TABLE IF NOT EXISTS mcp_config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - )''') - # MCP statistics table - conn.execute('''CREATE TABLE IF NOT EXISTS mcp_stats ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL, - tool_name TEXT NOT NULL, - folder TEXT, - success INTEGER NOT NULL, - error_message TEXT - )''') - # MCP databases table - conn.execute('''CREATE TABLE IF NOT EXISTS mcp_databases ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - path TEXT NOT NULL UNIQUE, - name TEXT NOT NULL, - size INTEGER, - tables TEXT, - added_timestamp TEXT NOT NULL - )''') - conn.commit() - -def get_config(key: str) -> Optional[str]: - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('SELECT value FROM config WHERE key = ?', (key,)) - result = cursor.fetchone() - return result[0] if result else None - -def set_config(key: str, value: str): - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - conn.execute('INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)', (key, value)) - conn.commit() - -def get_mcp_config(key: str) -> Optional[str]: - """Get MCP configuration value.""" - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('SELECT value FROM mcp_config WHERE key = ?', (key,)) - result = cursor.fetchone() - return result[0] if result else None - -def set_mcp_config(key: str, value: str): - """Set MCP configuration value.""" - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - conn.execute('INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)', 
(key, value)) - conn.commit() - -def log_mcp_stat(tool_name: str, folder: Optional[str], success: bool, error_message: Optional[str] = None): - """Log MCP tool usage statistics.""" - create_table_if_not_exists() - timestamp = datetime.datetime.now().isoformat() - with sqlite3.connect(str(database)) as conn: - conn.execute( - 'INSERT INTO mcp_stats (timestamp, tool_name, folder, success, error_message) VALUES (?, ?, ?, ?, ?)', - (timestamp, tool_name, folder, 1 if success else 0, error_message) - ) - conn.commit() - -def get_mcp_stats() -> Dict[str, Any]: - """Get MCP usage statistics.""" - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute(''' - SELECT - COUNT(*) as total_calls, - SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads, - SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists, - SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches, - SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects, - SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches, - SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries, - MAX(timestamp) as last_used - FROM mcp_stats - ''') - row = cursor.fetchone() - return { - 'total_calls': row[0] or 0, - 'reads': row[1] or 0, - 'lists': row[2] or 0, - 'searches': row[3] or 0, - 'db_inspects': row[4] or 0, - 'db_searches': row[5] or 0, - 'db_queries': row[6] or 0, - 'last_used': row[7] - } - -# Rotating Rich Handler -class RotatingRichHandler(RotatingFileHandler): - """Custom handler combining RotatingFileHandler with Rich formatting.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.rich_console = Console(file=io.StringIO(), width=120, force_terminal=False) - self.rich_handler = RichHandler( - console=self.rich_console, - show_time=True, - show_path=True, - rich_tracebacks=True, - tracebacks_suppress=['requests', 'openrouter', 'urllib3', 'httpx', 'openai'] - ) - - def emit(self, record): - try: - self.rich_handler.emit(record) - output = self.rich_console.file.getvalue() - self.rich_console.file.seek(0) - self.rich_console.file.truncate(0) - if output: - self.stream.write(output) - self.flush() - except Exception: - self.handleError(record) - -# Logging setup -LOG_MAX_SIZE_MB = int(get_config('log_max_size_mb') or "10") -LOG_BACKUP_COUNT = int(get_config('log_backup_count') or "2") -LOG_LEVEL_STR = get_config('log_level') or "info" -LOG_LEVEL = VALID_LOG_LEVELS.get(LOG_LEVEL_STR.lower(), logging.INFO) - -app_handler = None -app_logger = None - -def setup_logging(): - """Setup or reset logging configuration.""" - global app_handler, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, app_logger - - root_logger = logging.getLogger() - - if app_handler is not None: - root_logger.removeHandler(app_handler) - try: - app_handler.close() - except: - pass - - # Check if log needs rotation - if os.path.exists(log_file): - current_size = os.path.getsize(log_file) - max_bytes = LOG_MAX_SIZE_MB * 1024 * 1024 - - if current_size >= max_bytes: - import shutil - timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') - backup_file = f"{log_file}.{timestamp}" - try: - shutil.move(str(log_file), backup_file) - except Exception as e: - print(f"Warning: Could not rotate log file: {e}") - - # Clean old backups - log_dir = os.path.dirname(log_file) - log_basename = os.path.basename(log_file) - backup_pattern = f"{log_basename}.*" - - import glob - backups = 
sorted(glob.glob(os.path.join(log_dir, backup_pattern))) - - while len(backups) > LOG_BACKUP_COUNT: - oldest = backups.pop(0) - try: - os.remove(oldest) - except: - pass - - app_handler = RotatingRichHandler( - filename=str(log_file), - maxBytes=LOG_MAX_SIZE_MB * 1024 * 1024, - backupCount=LOG_BACKUP_COUNT, - encoding='utf-8' - ) - - app_handler.setLevel(logging.NOTSET) - root_logger.setLevel(logging.WARNING) - root_logger.addHandler(app_handler) - - # Suppress noisy loggers - logging.getLogger('asyncio').setLevel(logging.WARNING) - logging.getLogger('urllib3').setLevel(logging.WARNING) - logging.getLogger('requests').setLevel(logging.WARNING) - logging.getLogger('httpx').setLevel(logging.WARNING) - logging.getLogger('httpcore').setLevel(logging.WARNING) - logging.getLogger('openai').setLevel(logging.WARNING) - logging.getLogger('openrouter').setLevel(logging.WARNING) - - app_logger = logging.getLogger("oai_app") - app_logger.setLevel(LOG_LEVEL) - app_logger.propagate = True - - return app_logger - -def set_log_level(level_str: str) -> bool: - """Set application log level.""" - global LOG_LEVEL, LOG_LEVEL_STR, app_logger - level_str_lower = level_str.lower() - if level_str_lower not in VALID_LOG_LEVELS: - return False - LOG_LEVEL = VALID_LOG_LEVELS[level_str_lower] - LOG_LEVEL_STR = level_str_lower - - if app_logger: - app_logger.setLevel(LOG_LEVEL) - - return True - -def reload_logging_config(): - """Reload logging configuration.""" - global LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, LOG_LEVEL_STR, app_logger - - LOG_MAX_SIZE_MB = int(get_config('log_max_size_mb') or "10") - LOG_BACKUP_COUNT = int(get_config('log_backup_count') or "2") - LOG_LEVEL_STR = get_config('log_level') or "info" - LOG_LEVEL = VALID_LOG_LEVELS.get(LOG_LEVEL_STR.lower(), logging.INFO) - - app_logger = setup_logging() - return app_logger - -app_logger = setup_logging() -logger = logging.getLogger(__name__) - -# Load configs -API_KEY = get_config('api_key') -OPENROUTER_BASE_URL = get_config('base_url') or "https://openrouter.ai/api/v1" -STREAM_ENABLED = get_config('stream_enabled') or "on" -DEFAULT_MODEL_ID = get_config('default_model') -MAX_TOKEN = int(get_config('max_token') or "100000") -COST_WARNING_THRESHOLD = float(get_config(HIGH_COST_WARNING) or "0.01") -DEFAULT_ONLINE_MODE = get_config('default_online_mode') or "off" - -# Fetch models -models_data = [] -text_models = [] -try: - headers = { - "Authorization": f"Bearer {API_KEY}", - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } if API_KEY else { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - response = requests.get(f"{OPENROUTER_BASE_URL}/models", headers=headers) - response.raise_for_status() - models_data = response.json()["data"] - text_models = [m for m in models_data if "modalities" not in m or "video" not in (m.get("modalities") or [])] - selected_model_default = None - if DEFAULT_MODEL_ID: - selected_model_default = next((m for m in text_models if m["id"] == DEFAULT_MODEL_ID), None) - if not selected_model_default: - console.print(f"[bold yellow]Warning: Default model '{DEFAULT_MODEL_ID}' unavailable. 
Use '/config model'.[/]") -except Exception as e: - models_data = [] - text_models = [] - app_logger.error(f"Failed to fetch models: {e}") - -# ============================================================================ -# MCP INTEGRATION CLASSES -# ============================================================================ - -class CrossPlatformMCPConfig: - """Handle OS-specific MCP configuration.""" - - def __init__(self): - self.system = platform.system() - self.is_macos = self.system == "Darwin" - self.is_linux = self.system == "Linux" - self.is_windows = self.system == "Windows" - - app_logger.info(f"Detected OS: {self.system}") - - def get_default_allowed_dirs(self) -> List[Path]: - """Get safe default directories.""" - home = Path.home() - - if self.is_macos: - return [ - home / "Documents", - home / "Desktop", - home / "Downloads" - ] - elif self.is_linux: - dirs = [home / "Documents"] - - try: - for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]: - result = subprocess.run( - ["xdg-user-dir", xdg_dir], - capture_output=True, - text=True, - timeout=1 - ) - if result.returncode == 0: - dir_path = Path(result.stdout.strip()) - if dir_path.exists(): - dirs.append(dir_path) - except: - dirs.extend([ - home / "Desktop", - home / "Downloads" - ]) - - return list(set(dirs)) - - elif self.is_windows: - return [ - home / "Documents", - home / "Desktop", - home / "Downloads" - ] - - return [home] - - def get_python_command(self) -> str: - """Get Python command.""" - import sys - return sys.executable - - def get_filesystem_warning(self) -> str: - """Get OS-specific security warning.""" - if self.is_macos: - return """ -⚠️ macOS Security Notice: -The Filesystem MCP server needs access to your selected folder. -You may see a security prompt - click 'Allow' to proceed. -(System Settings > Privacy & Security > Files and Folders) -""" - elif self.is_linux: - return """ -⚠️ Linux Security Notice: -The Filesystem MCP server will access your selected folder. -Ensure oAI has appropriate file permissions. -""" - elif self.is_windows: - return """ -⚠️ Windows Security Notice: -The Filesystem MCP server will access your selected folder. -You may need to grant file access permissions. 
-""" - return "" - - def normalize_path(self, path: str) -> Path: - """Normalize path for current OS.""" - return Path(os.path.expanduser(path)).resolve() - - def is_system_directory(self, path: Path) -> bool: - """Check if path is a system directory.""" - path_str = str(path) - for blocked in SYSTEM_DIRS_BLACKLIST: - if path_str.startswith(blocked): - return True - return False - - def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool: - """Check if path is within allowed directories.""" - try: - requested = requested_path.resolve() - - for allowed in allowed_dirs: - try: - allowed_resolved = allowed.resolve() - requested.relative_to(allowed_resolved) - return True - except ValueError: - continue - - return False - except: - return False - - def get_folder_stats(self, folder: Path) -> Dict[str, Any]: - """Get folder statistics.""" - try: - if not folder.exists() or not folder.is_dir(): - return {'exists': False} - - file_count = 0 - total_size = 0 - - for item in folder.rglob('*'): - if item.is_file(): - file_count += 1 - try: - total_size += item.stat().st_size - except: - pass - - return { - 'exists': True, - 'file_count': file_count, - 'total_size': total_size, - 'size_mb': total_size / (1024 * 1024) - } - except Exception as e: - app_logger.error(f"Error getting folder stats for {folder}: {e}") - return {'exists': False, 'error': str(e)} - - -class GitignoreParser: - """Parse .gitignore files and check if paths should be ignored.""" - - def __init__(self): - self.patterns = [] # List of (pattern, is_negation, source_dir) - - def add_gitignore(self, gitignore_path: Path): - """Parse and add patterns from a .gitignore file.""" - if not gitignore_path.exists(): - return - - try: - source_dir = gitignore_path.parent - with open(gitignore_path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - line = line.rstrip('\n\r') - - # Skip empty lines and comments - if not line or line.startswith('#'): - continue - - # Check for negation pattern - is_negation = line.startswith('!') - if is_negation: - line = line[1:] - - # Remove leading slash (make relative to gitignore location) - if line.startswith('/'): - line = line[1:] - - self.patterns.append((line, is_negation, source_dir)) - - app_logger.debug(f"Loaded {len(self.patterns)} patterns from {gitignore_path}") - except Exception as e: - app_logger.warning(f"Error reading {gitignore_path}: {e}") - - def should_ignore(self, path: Path) -> bool: - """Check if a path should be ignored based on gitignore patterns.""" - if not self.patterns: - return False - - ignored = False - - for pattern, is_negation, source_dir in self.patterns: - # Only apply pattern if path is under the source directory - try: - rel_path = path.relative_to(source_dir) - except ValueError: - # Path is not relative to this gitignore's directory - continue - - rel_path_str = str(rel_path) - - # Check if pattern matches - if self._match_pattern(pattern, rel_path_str, path.is_dir()): - if is_negation: - ignored = False # Negation patterns un-ignore - else: - ignored = True - - return ignored - - def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool: - """Match a gitignore pattern against a path.""" - # Directory-only pattern (ends with /) - if pattern.endswith('/'): - if not is_dir: - return False - pattern = pattern[:-1] - - # ** matches any number of directories - if '**' in pattern: - # Convert ** to regex-like pattern - pattern_parts = pattern.split('**') - if len(pattern_parts) == 2: - prefix, suffix = 
pattern_parts - # Match if path starts with prefix and ends with suffix - if prefix: - if not path.startswith(prefix.rstrip('/')): - return False - if suffix: - suffix = suffix.lstrip('/') - if not (path.endswith(suffix) or f'/{suffix}' in path): - return False - return True - - # Direct match - if fnmatch.fnmatch(path, pattern): - return True - - # Match as subdirectory pattern - if '/' not in pattern: - # Pattern without / matches in any directory - parts = path.split('/') - if any(fnmatch.fnmatch(part, pattern) for part in parts): - return True - - return False - - -class SQLiteQueryValidator: - """Validate SQLite queries for read-only safety.""" - - @staticmethod - def is_safe_query(query: str) -> tuple[bool, str]: - """ - Validate that query is a safe read-only SELECT. - Returns (is_safe, error_message) - """ - query_upper = query.strip().upper() - - # Must start with SELECT or WITH (for CTEs) - if not (query_upper.startswith('SELECT') or query_upper.startswith('WITH')): - return False, "Only SELECT queries are allowed (including WITH/CTE)" - - # Block dangerous keywords even in SELECT - dangerous_keywords = [ - 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', - 'ALTER', 'TRUNCATE', 'REPLACE', 'ATTACH', 'DETACH', - 'PRAGMA', 'VACUUM', 'REINDEX' - ] - - # Check for dangerous keywords (but allow them in string literals) - # Simple check: look for keywords outside of quotes - query_no_strings = re.sub(r"'[^']*'", '', query_upper) - query_no_strings = re.sub(r'"[^"]*"', '', query_no_strings) - - for keyword in dangerous_keywords: - if re.search(r'\b' + keyword + r'\b', query_no_strings): - return False, f"Keyword '{keyword}' not allowed in read-only mode" - - return True, "" - # P2 -class MCPFilesystemServer: - """MCP Filesystem Server with file access and SQLite database querying.""" - - def __init__(self, allowed_folders: List[Path]): - self.allowed_folders = allowed_folders - self.config = CrossPlatformMCPConfig() - self.max_file_size = 10 * 1024 * 1024 # 10MB limit - self.max_list_items = 1000 # Max items to return in list_directory - self.respect_gitignore = True # Default enabled - - # Initialize gitignore parser - self.gitignore_parser = GitignoreParser() - self._load_gitignores() - - # SQLite configuration - self.max_query_timeout = 5 # seconds - self.max_query_results = 1000 # max rows - self.default_query_limit = 100 # default rows - - app_logger.info(f"MCP Filesystem Server initialized with {len(allowed_folders)} folders") - - def _load_gitignores(self): - """Load all .gitignore files from allowed folders.""" - gitignore_count = 0 - for folder in self.allowed_folders: - if not folder.exists(): - continue - - # Load root .gitignore - root_gitignore = folder / '.gitignore' - if root_gitignore.exists(): - self.gitignore_parser.add_gitignore(root_gitignore) - gitignore_count += 1 - app_logger.info(f"Loaded .gitignore from {folder}") - - # Load nested .gitignore files - try: - for gitignore_path in folder.rglob('.gitignore'): - if gitignore_path != root_gitignore: - self.gitignore_parser.add_gitignore(gitignore_path) - gitignore_count += 1 - app_logger.debug(f"Loaded nested .gitignore: {gitignore_path}") - except Exception as e: - app_logger.warning(f"Error loading nested .gitignores from {folder}: {e}") - - if gitignore_count > 0: - app_logger.info(f"Loaded {gitignore_count} .gitignore file(s) with {len(self.gitignore_parser.patterns)} total patterns") - - def reload_gitignores(self): - """Reload all .gitignore files (call when allowed_folders changes).""" - self.gitignore_parser = 
GitignoreParser() - self._load_gitignores() - app_logger.info("Reloaded .gitignore patterns") - - def is_allowed_path(self, path: Path) -> bool: - """Check if path is allowed.""" - return self.config.is_safe_path(path, self.allowed_folders) - - async def read_file(self, file_path: str) -> Dict[str, Any]: - """Read file contents.""" - try: - path = self.config.normalize_path(file_path) - - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} not in allowed folders" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - if not path.exists(): - error_msg = f"File not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - if not path.is_file(): - error_msg = f"Not a file: {path}" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Check if file should be ignored - if self.respect_gitignore and self.gitignore_parser.should_ignore(path): - error_msg = f"File ignored by .gitignore: {path}" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - file_size = path.stat().st_size - if file_size > self.max_file_size: - error_msg = f"File too large: {file_size / (1024*1024):.1f}MB (max: {self.max_file_size / (1024*1024):.0f}MB)" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Try to read as text - try: - content = path.read_text(encoding='utf-8') - - # If file is large (>50KB), provide summary instead of full content - max_content_size = 50 * 1024 # 50KB - if file_size > max_content_size: - lines = content.split('\n') - total_lines = len(lines) - - # Return first 500 lines + last 100 lines with notice - head_lines = 500 - tail_lines = 100 - - if total_lines > (head_lines + tail_lines): - truncated_content = ( - '\n'.join(lines[:head_lines]) + - f"\n\n... 
[TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted to stay within size limits] ...\n\n" + - '\n'.join(lines[-tail_lines:]) - ) - - app_logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)") - log_mcp_stat('read_file', str(path.parent), True) - - return { - 'content': truncated_content, - 'path': str(path), - 'size': file_size, - 'truncated': True, - 'total_lines': total_lines, - 'lines_shown': head_lines + tail_lines, - 'note': f'File truncated: showing first {head_lines} and last {tail_lines} lines of {total_lines} total' - } - - app_logger.info(f"Read file: {path} ({file_size} bytes)") - log_mcp_stat('read_file', str(path.parent), True) - return { - 'content': content, - 'path': str(path), - 'size': file_size - } - except UnicodeDecodeError: - error_msg = f"Cannot decode file as UTF-8: {path}" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - except Exception as e: - error_msg = f"Error reading file: {e}" - app_logger.error(error_msg) - log_mcp_stat('read_file', file_path, False, str(e)) - return {'error': error_msg} - - async def list_directory(self, dir_path: str, recursive: bool = False) -> Dict[str, Any]: - """List directory contents.""" - try: - path = self.config.normalize_path(dir_path) - - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} not in allowed folders" - app_logger.warning(error_msg) - log_mcp_stat('list_directory', str(path), False, error_msg) - return {'error': error_msg} - - if not path.exists(): - error_msg = f"Directory not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('list_directory', str(path), False, error_msg) - return {'error': error_msg} - - if not path.is_dir(): - error_msg = f"Not a directory: {path}" - app_logger.warning(error_msg) - log_mcp_stat('list_directory', str(path), False, error_msg) - return {'error': error_msg} - - items = [] - pattern = '**/*' if recursive else '*' - - # Directories to skip (hardcoded common patterns) - skip_dirs = { - '.venv', 'venv', 'env', 'virtualenv', - 'site-packages', 'dist-packages', - '__pycache__', '.pytest_cache', '.mypy_cache', - 'node_modules', '.git', '.svn', - '.idea', '.vscode', - 'build', 'dist', 'eggs', '.eggs' - } - - def should_skip_path(item_path: Path) -> bool: - """Check if path should be skipped.""" - # Check hardcoded skip directories - path_parts = item_path.parts - if any(part in skip_dirs for part in path_parts): - return True - - # Check gitignore patterns (if enabled) - if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): - return True - - return False - - for item in path.glob(pattern): - # Stop if we hit the limit - if len(items) >= self.max_list_items: - break - - # Skip excluded directories - if should_skip_path(item): - continue - - try: - stat = item.stat() - items.append({ - 'name': item.name, - 'path': str(item), - 'type': 'directory' if item.is_dir() else 'file', - 'size': stat.st_size if item.is_file() else 0, - 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() - }) - except: - continue - - truncated = len(items) >= self.max_list_items - - app_logger.info(f"Listed directory: {path} ({len(items)} items, recursive={recursive}, truncated={truncated})") - log_mcp_stat('list_directory', str(path), True) - - result = { - 'path': str(path), - 'items': items, - 'count': len(items), - 'truncated': truncated - } - - if truncated: - result['note'] = f'Results limited to 
{self.max_list_items} items. Use more specific search path or disable recursive mode.' - - return result - - except Exception as e: - error_msg = f"Error listing directory: {e}" - app_logger.error(error_msg) - log_mcp_stat('list_directory', dir_path, False, str(e)) - return {'error': error_msg} - - async def search_files(self, pattern: str, search_path: Optional[str] = None, content_search: bool = False) -> Dict[str, Any]: - """Search for files.""" - try: - # Determine search roots - if search_path: - path = self.config.normalize_path(search_path) - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} not in allowed folders" - app_logger.warning(error_msg) - log_mcp_stat('search_files', str(path), False, error_msg) - return {'error': error_msg} - search_roots = [path] - else: - search_roots = self.allowed_folders - - matches = [] - - # Directories to skip (virtual environments, caches, etc.) - skip_dirs = { - '.venv', 'venv', 'env', 'virtualenv', - 'site-packages', 'dist-packages', - '__pycache__', '.pytest_cache', '.mypy_cache', - 'node_modules', '.git', '.svn', - '.idea', '.vscode', - 'build', 'dist', 'eggs', '.eggs' - } - - def should_skip_path(item_path: Path) -> bool: - """Check if path should be skipped.""" - # Check hardcoded skip directories - path_parts = item_path.parts - if any(part in skip_dirs for part in path_parts): - return True - - # Check gitignore patterns (if enabled) - if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): - return True - - return False - - for root in search_roots: - if not root.exists(): - continue - - # Filename search - if not content_search: - for item in root.rglob(pattern): - if item.is_file() and not should_skip_path(item): - try: - stat = item.stat() - matches.append({ - 'path': str(item), - 'name': item.name, - 'size': stat.st_size, - 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() - }) - except: - continue - else: - # Content search (slower) - for item in root.rglob('*'): - if item.is_file() and not should_skip_path(item): - try: - # Skip large files - if item.stat().st_size > self.max_file_size: - continue - - # Try to read and search content - try: - content = item.read_text(encoding='utf-8') - if pattern.lower() in content.lower(): - stat = item.stat() - matches.append({ - 'path': str(item), - 'name': item.name, - 'size': stat.st_size, - 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() - }) - except: - continue - except: - continue - - search_type = "content" if content_search else "filename" - app_logger.info(f"Searched files: pattern='{pattern}', type={search_type}, found={len(matches)}") - log_mcp_stat('search_files', str(search_roots[0]) if search_roots else None, True) - - return { - 'pattern': pattern, - 'search_type': search_type, - 'matches': matches, - 'count': len(matches) - } - - except Exception as e: - error_msg = f"Error searching files: {e}" - app_logger.error(error_msg) - log_mcp_stat('search_files', search_path or "all", False, str(e)) - return {'error': error_msg} - - # ======================================================================== - # FILE WRITE METHODS - # ======================================================================== - - async def write_file(self, file_path: str, content: str) -> Dict[str, Any]: - """Create a new file or overwrite an existing file with content. - - Note: Ignores .gitignore - allows writing to any file in allowed folders. 
- """ - try: - path = self.config.normalize_path(file_path) - - # Permission check: must be in allowed folders - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('write_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Check if parent directory exists, create if needed - parent_dir = path.parent - if not parent_dir.exists(): - parent_dir.mkdir(parents=True, exist_ok=True) - app_logger.info(f"Created parent directory: {parent_dir}") - - # Determine if creating new file or overwriting - is_new_file = not path.exists() - - # Write content to file - path.write_text(content, encoding='utf-8') - file_size = path.stat().st_size - - app_logger.info(f"{'Created' if is_new_file else 'Updated'} file: {path} ({file_size} bytes)") - log_mcp_stat('write_file', str(path.parent), True) - - return { - 'success': True, - 'path': str(path), - 'size': file_size, - 'created': is_new_file - } - - except PermissionError as e: - error_msg = f"Permission denied writing to {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('write_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except UnicodeEncodeError as e: - error_msg = f"Encoding error writing to {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('write_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'EncodingError'} - except Exception as e: - error_msg = f"Error writing file {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('write_file', file_path, False, str(e)) - return {'error': error_msg} - - async def edit_file(self, file_path: str, old_text: str, new_text: str) -> Dict[str, Any]: - """Find and replace text in an existing file. - - Note: Ignores .gitignore - allows editing any file in allowed folders. - """ - try: - path = self.config.normalize_path(file_path) - - # Permission check - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('edit_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # File must exist - if not path.exists(): - error_msg = f"File not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('edit_file', str(path), False, error_msg) - return {'error': error_msg} - - if not path.is_file(): - error_msg = f"Not a file: {path}" - app_logger.warning(error_msg) - log_mcp_stat('edit_file', str(path), False, error_msg) - return {'error': error_msg} - - # Read current content - current_content = path.read_text(encoding='utf-8') - - # Check if old_text exists in file - if old_text not in current_content: - error_msg = f"Text not found in file: '{old_text[:50]}...'" - app_logger.warning(f"Edit failed - text not found in {path}") - log_mcp_stat('edit_file', str(path), False, error_msg) - return {'error': error_msg} - - # Count occurrences - occurrence_count = current_content.count(old_text) - - if occurrence_count > 1: - error_msg = f"Text appears {occurrence_count} times in file. Please provide more context to make the match unique." 
- app_logger.warning(f"Edit failed - ambiguous match in {path}") - log_mcp_stat('edit_file', str(path), False, error_msg) - return {'error': error_msg} - - # Perform replacement - new_content = current_content.replace(old_text, new_text, 1) - path.write_text(new_content, encoding='utf-8') - - app_logger.info(f"Edited file: {path} (1 replacement)") - log_mcp_stat('edit_file', str(path.parent), True) - - return { - 'success': True, - 'path': str(path), - 'changes': 1 - } - - except PermissionError as e: - error_msg = f"Permission denied editing {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('edit_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except UnicodeDecodeError as e: - error_msg = f"Cannot read file (encoding issue): {file_path}" - app_logger.error(error_msg) - log_mcp_stat('edit_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'EncodingError'} - except Exception as e: - error_msg = f"Error editing file {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('edit_file', file_path, False, str(e)) - return {'error': error_msg} - - async def delete_file(self, file_path: str, reason: str) -> Dict[str, Any]: - """Delete a file with user confirmation. - - Always requires user to confirm deletion with the LLM's provided reason. - Note: Ignores .gitignore - allows deleting any file in allowed folders. - """ - try: - path = self.config.normalize_path(file_path) - - # Permission check - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('delete_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # File must exist - if not path.exists(): - error_msg = f"File not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('delete_file', str(path), False, error_msg) - return {'error': error_msg} - - if not path.is_file(): - error_msg = f"Not a file: {path}" - app_logger.warning(error_msg) - log_mcp_stat('delete_file', str(path), False, error_msg) - return {'error': error_msg} - - # Get file info for confirmation prompt - file_size = path.stat().st_size - file_mtime = datetime.datetime.fromtimestamp(path.stat().st_mtime) - - # User confirmation required - console.print(Panel.fit( - f"[bold yellow]⚠️ DELETE FILE?[/]\n\n" - f"Path: [cyan]{path}[/]\n" - f"Size: {file_size:,} bytes\n" - f"Modified: {file_mtime.strftime('%Y-%m-%d %H:%M:%S')}\n\n" - f"[bold]LLM Reason:[/] {reason}", - title="Delete File Confirmation", - border_style="yellow" - )) - - confirm = typer.confirm("Delete this file?", default=False) - - if not confirm: - app_logger.info(f"User cancelled file deletion: {path}") - log_mcp_stat('delete_file', str(path), False, "User cancelled") - return { - 'success': False, - 'user_cancelled': True, - 'path': str(path) - } - - # Delete the file - path.unlink() - - app_logger.info(f"Deleted file: {path}") - log_mcp_stat('delete_file', str(path.parent), True) - - return { - 'success': True, - 'path': str(path), - 'user_cancelled': False - } - - except PermissionError as e: - error_msg = f"Permission denied deleting {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('delete_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except Exception as e: - error_msg = f"Error deleting file {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('delete_file', file_path, False, str(e)) - return {'error': error_msg} - 
- async def create_directory(self, dir_path: str) -> Dict[str, Any]: - """Create a directory (and parent directories if needed). - - Note: Ignores .gitignore - allows creating directories anywhere in allowed folders. - """ - try: - path = self.config.normalize_path(dir_path) - - # Permission check - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('create_directory', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Check if already exists - already_exists = path.exists() - - if already_exists and not path.is_dir(): - error_msg = f"Path exists but is not a directory: {path}" - app_logger.warning(error_msg) - log_mcp_stat('create_directory', str(path), False, error_msg) - return {'error': error_msg} - - # Create directory (and parents) - path.mkdir(parents=True, exist_ok=True) - - if already_exists: - app_logger.info(f"Directory already exists: {path}") - else: - app_logger.info(f"Created directory: {path}") - - log_mcp_stat('create_directory', str(path.parent), True) - - return { - 'success': True, - 'path': str(path), - 'created': not already_exists - } - - except PermissionError as e: - error_msg = f"Permission denied creating directory {dir_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('create_directory', dir_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except Exception as e: - error_msg = f"Error creating directory {dir_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('create_directory', dir_path, False, str(e)) - return {'error': error_msg} - - async def move_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: - """Move or rename a file. - - Note: Ignores .gitignore - allows moving any file within allowed folders. 
- """ - try: - source = self.config.normalize_path(source_path) - dest = self.config.normalize_path(dest_path) - - # Both paths must be in allowed folders - if not self.is_allowed_path(source): - error_msg = f"Access denied: source {source} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('move_file', str(source.parent), False, error_msg) - return {'error': error_msg} - - if not self.is_allowed_path(dest): - error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('move_file', str(dest.parent), False, error_msg) - return {'error': error_msg} - - # Source must exist - if not source.exists(): - error_msg = f"Source file not found: {source}" - app_logger.warning(error_msg) - log_mcp_stat('move_file', str(source), False, error_msg) - return {'error': error_msg} - - if not source.is_file(): - error_msg = f"Source is not a file: {source}" - app_logger.warning(error_msg) - log_mcp_stat('move_file', str(source), False, error_msg) - return {'error': error_msg} - - # Create destination parent directory if needed - dest.parent.mkdir(parents=True, exist_ok=True) - - # Move/rename the file - import shutil - shutil.move(str(source), str(dest)) - - app_logger.info(f"Moved file: {source} -> {dest}") - log_mcp_stat('move_file', str(source.parent), True) - - return { - 'success': True, - 'source': str(source), - 'destination': str(dest) - } - - except PermissionError as e: - error_msg = f"Permission denied moving file: {e}" - app_logger.error(error_msg) - log_mcp_stat('move_file', source_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except Exception as e: - error_msg = f"Error moving file from {source_path} to {dest_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('move_file', source_path, False, str(e)) - return {'error': error_msg} - - async def copy_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: - """Copy a file to a new location. - - Note: Ignores .gitignore - allows copying any file within allowed folders. 
- """ - try: - source = self.config.normalize_path(source_path) - dest = self.config.normalize_path(dest_path) - - # Both paths must be in allowed folders - if not self.is_allowed_path(source): - error_msg = f"Access denied: source {source} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('copy_file', str(source.parent), False, error_msg) - return {'error': error_msg} - - if not self.is_allowed_path(dest): - error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('copy_file', str(dest.parent), False, error_msg) - return {'error': error_msg} - - # Source must exist - if not source.exists(): - error_msg = f"Source file not found: {source}" - app_logger.warning(error_msg) - log_mcp_stat('copy_file', str(source), False, error_msg) - return {'error': error_msg} - - if not source.is_file(): - error_msg = f"Source is not a file: {source}" - app_logger.warning(error_msg) - log_mcp_stat('copy_file', str(source), False, error_msg) - return {'error': error_msg} - - # Create destination parent directory if needed - dest.parent.mkdir(parents=True, exist_ok=True) - - # Copy the file - import shutil - shutil.copy2(str(source), str(dest)) - - file_size = dest.stat().st_size - - app_logger.info(f"Copied file: {source} -> {dest} ({file_size} bytes)") - log_mcp_stat('copy_file', str(source.parent), True) - - return { - 'success': True, - 'source': str(source), - 'destination': str(dest), - 'size': file_size - } - - except PermissionError as e: - error_msg = f"Permission denied copying file: {e}" - app_logger.error(error_msg) - log_mcp_stat('copy_file', source_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except Exception as e: - error_msg = f"Error copying file from {source_path} to {dest_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('copy_file', source_path, False, str(e)) - return {'error': error_msg} - - # ======================================================================== - # SQLite DATABASE METHODS - # ======================================================================== - - async def inspect_database(self, db_path: str, table_name: Optional[str] = None) -> Dict[str, Any]: - """Inspect SQLite database schema.""" - try: - path = Path(db_path).resolve() - - if not path.exists(): - error_msg = f"Database not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('inspect_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - if not path.is_file(): - error_msg = f"Not a file: {path}" - app_logger.warning(error_msg) - log_mcp_stat('inspect_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Open database in read-only mode - try: - conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) - cursor = conn.cursor() - except sqlite3.DatabaseError as e: - error_msg = f"Not a valid SQLite database: {path} ({e})" - app_logger.warning(error_msg) - log_mcp_stat('inspect_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - try: - if table_name: - # Get info for specific table - cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) - result = cursor.fetchone() - - if not result: - return {'error': f"Table '{table_name}' not found"} - - create_sql = result[0] - - # Get column info - cursor.execute(f"PRAGMA table_info({table_name})") - columns = [] - for row in cursor.fetchall(): - columns.append({ - 'id': row[0], - 'name': row[1], - 
'type': row[2], - 'not_null': bool(row[3]), - 'default_value': row[4], - 'primary_key': bool(row[5]) - }) - - # Get indexes - cursor.execute(f"PRAGMA index_list({table_name})") - indexes = [] - for row in cursor.fetchall(): - indexes.append({ - 'name': row[1], - 'unique': bool(row[2]) - }) - - # Get row count - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") - row_count = cursor.fetchone()[0] - - app_logger.info(f"Inspected table: {table_name} in {path}") - log_mcp_stat('inspect_database', str(path.parent), True) - - return { - 'database': str(path), - 'table': table_name, - 'create_sql': create_sql, - 'columns': columns, - 'indexes': indexes, - 'row_count': row_count - } - else: - # Get all tables - cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name") - tables = [] - for row in cursor.fetchall(): - table = row[0] - # Get row count - cursor.execute(f"SELECT COUNT(*) FROM {table}") - count = cursor.fetchone()[0] - tables.append({ - 'name': table, - 'row_count': count - }) - - # Get indexes - cursor.execute("SELECT name, tbl_name FROM sqlite_master WHERE type='index' ORDER BY name") - indexes = [{'name': row[0], 'table': row[1]} for row in cursor.fetchall()] - - # Get triggers - cursor.execute("SELECT name, tbl_name FROM sqlite_master WHERE type='trigger' ORDER BY name") - triggers = [{'name': row[0], 'table': row[1]} for row in cursor.fetchall()] - - # Get database size - db_size = path.stat().st_size - - app_logger.info(f"Inspected database: {path} ({len(tables)} tables)") - log_mcp_stat('inspect_database', str(path.parent), True) - - return { - 'database': str(path), - 'size': db_size, - 'size_mb': db_size / (1024 * 1024), - 'tables': tables, - 'table_count': len(tables), - 'indexes': indexes, - 'triggers': triggers - } - finally: - conn.close() - - except Exception as e: - error_msg = f"Error inspecting database: {e}" - app_logger.error(error_msg) - log_mcp_stat('inspect_database', db_path, False, str(e)) - return {'error': error_msg} - - async def search_database(self, db_path: str, search_term: str, table_name: Optional[str] = None, column_name: Optional[str] = None) -> Dict[str, Any]: - """Search for a value across database tables.""" - try: - path = Path(db_path).resolve() - - if not path.exists(): - error_msg = f"Database not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('search_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Open database in read-only mode - try: - conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) - cursor = conn.cursor() - except sqlite3.DatabaseError as e: - error_msg = f"Not a valid SQLite database: {path} ({e})" - app_logger.warning(error_msg) - log_mcp_stat('search_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - try: - matches = [] - - # Get tables to search - if table_name: - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) - tables = [row[0] for row in cursor.fetchall()] - if not tables: - return {'error': f"Table '{table_name}' not found"} - else: - cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") - tables = [row[0] for row in cursor.fetchall()] - - for table in tables: - # Get columns for this table - cursor.execute(f"PRAGMA table_info({table})") - columns = [row[1] for row in cursor.fetchall()] - - # Filter columns if specified - if column_name: - if column_name not in columns: - continue - columns = [column_name] - - # Search in each column - for column in 
columns: - try: - # Use LIKE for partial matching - query = f"SELECT * FROM {table} WHERE {column} LIKE ? LIMIT {self.default_query_limit}" - cursor.execute(query, (f'%{search_term}%',)) - results = cursor.fetchall() - - if results: - # Get column names - col_names = [desc[0] for desc in cursor.description] - - for row in results: - # Convert row to dict, handling binary data - row_dict = {} - for col_name, value in zip(col_names, row): - if isinstance(value, bytes): - # Convert bytes to hex string or base64 - try: - # Try to decode as UTF-8 first - row_dict[col_name] = value.decode('utf-8') - except UnicodeDecodeError: - # If binary, show as hex string - row_dict[col_name] = f"" - else: - row_dict[col_name] = value - - matches.append({ - 'table': table, - 'column': column, - 'row': row_dict - }) - except sqlite3.Error: - # Skip columns that can't be searched (e.g., BLOBs) - continue - - app_logger.info(f"Searched database: {path} for '{search_term}' - found {len(matches)} matches") - log_mcp_stat('search_database', str(path.parent), True) - - result = { - 'database': str(path), - 'search_term': search_term, - 'matches': matches, - 'count': len(matches) - } - - if table_name: - result['table_filter'] = table_name - if column_name: - result['column_filter'] = column_name - - if len(matches) >= self.default_query_limit: - result['note'] = f'Results limited to {self.default_query_limit} matches. Use more specific search or query_database for full results.' - - return result - - finally: - conn.close() - - except Exception as e: - error_msg = f"Error searching database: {e}" - app_logger.error(error_msg) - log_mcp_stat('search_database', db_path, False, str(e)) - return {'error': error_msg} - - async def query_database(self, db_path: str, query: str, limit: Optional[int] = None) -> Dict[str, Any]: - """Execute a read-only SQL query on database.""" - try: - path = Path(db_path).resolve() - - if not path.exists(): - error_msg = f"Database not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Validate query - is_safe, error_msg = SQLiteQueryValidator.is_safe_query(query) - if not is_safe: - app_logger.warning(f"Unsafe query rejected: {error_msg}") - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Open database in read-only mode - try: - conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) - cursor = conn.cursor() - except sqlite3.DatabaseError as e: - error_msg = f"Not a valid SQLite database: {path} ({e})" - app_logger.warning(error_msg) - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - try: - # Set query timeout - def timeout_handler(signum, frame): - raise TimeoutError(f"Query exceeded {self.max_query_timeout} second timeout") - - # Note: signal.alarm only works on Unix-like systems - if hasattr(signal, 'SIGALRM'): - signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(self.max_query_timeout) - - try: - # Execute query - cursor.execute(query) - - # Get results - result_limit = min(limit or self.default_query_limit, self.max_query_results) - results = cursor.fetchmany(result_limit) - - # Get column names - columns = [desc[0] for desc in cursor.description] if cursor.description else [] - - # Convert to list of dicts, handling binary data - rows = [] - for row in results: - row_dict = {} - for col_name, value in zip(columns, row): - if isinstance(value, bytes): - # Convert 
bytes to readable format - try: - # Try to decode as UTF-8 first - row_dict[col_name] = value.decode('utf-8') - except UnicodeDecodeError: - # If binary, show as hex string or summary - row_dict[col_name] = f"" - else: - row_dict[col_name] = value - rows.append(row_dict) - - # Check if more results available - has_more = len(results) == result_limit - if has_more: - # Try to fetch one more to check - one_more = cursor.fetchone() - has_more = one_more is not None - - finally: - if hasattr(signal, 'SIGALRM'): - signal.alarm(0) # Cancel timeout - - app_logger.info(f"Executed query on {path}: returned {len(rows)} rows") - log_mcp_stat('query_database', str(path.parent), True) - - result = { - 'database': str(path), - 'query': query, - 'columns': columns, - 'rows': rows, - 'count': len(rows) - } - - if has_more: - result['truncated'] = True - result['note'] = f'Results limited to {result_limit} rows. Use LIMIT clause in query for more control.' - - return result - - except TimeoutError as e: - error_msg = str(e) - app_logger.warning(error_msg) - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - except sqlite3.Error as e: - error_msg = f"SQL error: {e}" - app_logger.warning(error_msg) - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - finally: - if hasattr(signal, 'SIGALRM'): - signal.alarm(0) # Ensure timeout is cancelled - conn.close() - - except Exception as e: - error_msg = f"Error executing query: {e}" - app_logger.error(error_msg) - log_mcp_stat('query_database', db_path, False, str(e)) - return {'error': error_msg} - - -class MCPManager: - """Manage MCP server lifecycle, tool calls, and mode switching.""" - - def __init__(self): - self.enabled = False - self.write_enabled = False # Write mode off by default (non-persistent) - self.mode = "files" # "files" or "database" - self.selected_db_index = None - - self.server: Optional[MCPFilesystemServer] = None - - # File/folder mode - self.allowed_folders: List[Path] = [] - - # Database mode - self.databases: List[Dict[str, Any]] = [] - - self.config = CrossPlatformMCPConfig() - self.session_start_time: Optional[datetime.datetime] = None - - # Load persisted data - self._load_folders() - self._load_databases() - - app_logger.info("MCP Manager initialized") - - def _load_folders(self): - """Load allowed folders from database.""" - folders_json = get_mcp_config('allowed_folders') - if folders_json: - try: - folder_paths = json.loads(folders_json) - self.allowed_folders = [self.config.normalize_path(p) for p in folder_paths] - app_logger.info(f"Loaded {len(self.allowed_folders)} folders from config") - except Exception as e: - app_logger.error(f"Error loading MCP folders: {e}") - self.allowed_folders = [] - - def _save_folders(self): - """Save allowed folders to database.""" - folder_paths = [str(p) for p in self.allowed_folders] - set_mcp_config('allowed_folders', json.dumps(folder_paths)) - app_logger.info(f"Saved {len(self.allowed_folders)} folders to config") - - def _load_databases(self): - """Load databases from database.""" - create_table_if_not_exists() - try: - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('SELECT id, path, name, size, tables, added_timestamp FROM mcp_databases ORDER BY id') - self.databases = [] - for row in cursor.fetchall(): - tables_list = json.loads(row[4]) if row[4] else [] - self.databases.append({ - 'id': row[0], - 'path': row[1], - 'name': row[2], - 'size': row[3], - 'tables': tables_list, - 
'added': row[5] - }) - app_logger.info(f"Loaded {len(self.databases)} databases from config") - except Exception as e: - app_logger.error(f"Error loading databases: {e}") - self.databases = [] - - def _save_database(self, db_info: Dict[str, Any]): - """Save a database to config.""" - create_table_if_not_exists() - try: - with sqlite3.connect(str(database)) as conn: - conn.execute( - 'INSERT INTO mcp_databases (path, name, size, tables, added_timestamp) VALUES (?, ?, ?, ?, ?)', - (db_info['path'], db_info['name'], db_info['size'], json.dumps(db_info['tables']), db_info['added']) - ) - conn.commit() - # Get the ID - cursor = conn.execute('SELECT id FROM mcp_databases WHERE path = ?', (db_info['path'],)) - db_info['id'] = cursor.fetchone()[0] - app_logger.info(f"Saved database {db_info['name']} to config") - except Exception as e: - app_logger.error(f"Error saving database: {e}") - - def _remove_database_from_config(self, db_path: str): - """Remove database from config.""" - create_table_if_not_exists() - try: - with sqlite3.connect(str(database)) as conn: - conn.execute('DELETE FROM mcp_databases WHERE path = ?', (db_path,)) - conn.commit() - app_logger.info(f"Removed database from config: {db_path}") - except Exception as e: - app_logger.error(f"Error removing database from config: {e}") - - def enable(self) -> Dict[str, Any]: - """Enable MCP server.""" - if not MCP_AVAILABLE: - return { - 'success': False, - 'error': 'MCP library not installed. Run: pip install mcp' - } - - if self.enabled: - return { - 'success': False, - 'error': 'MCP is already enabled' - } - - try: - self.server = MCPFilesystemServer(self.allowed_folders) - self.enabled = True - self.session_start_time = datetime.datetime.now() - - set_mcp_config('mcp_enabled', 'true') - - app_logger.info("MCP Filesystem Server enabled") - - return { - 'success': True, - 'folder_count': len(self.allowed_folders), - 'database_count': len(self.databases), - 'message': 'MCP Filesystem Server started successfully' - } - except Exception as e: - app_logger.error(f"Error enabling MCP: {e}") - return { - 'success': False, - 'error': str(e) - } - - def disable(self) -> Dict[str, Any]: - """Disable MCP server.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled' - } - - try: - self.server = None - self.enabled = False - self.session_start_time = None - self.mode = "files" - self.selected_db_index = None - - set_mcp_config('mcp_enabled', 'false') - - app_logger.info("MCP Filesystem Server disabled") - - return { - 'success': True, - 'message': 'MCP Filesystem Server stopped' - } - except Exception as e: - app_logger.error(f"Error disabling MCP: {e}") - return { - 'success': False, - 'error': str(e) - } - - def enable_write(self) -> None: - """Enable write mode with user confirmation.""" - if not self.enabled: - console.print("[bold red]Error:[/] MCP must be enabled first. Use '/mcp on'") - app_logger.warning("Attempted to enable write mode without MCP enabled") - return - - # Show warning panel - console.print(Panel.fit( - "[bold yellow]⚠️ WRITE MODE WARNING[/]\n\n" - "Enabling write mode allows the LLM to:\n" - "• Create new files\n" - "• Modify existing files\n" - "• Delete files (with confirmation)\n" - "• Create directories\n" - "• Move and copy files\n\n" - "[dim]Operations are limited to your allowed MCP folders.[/]\n" - "[dim]Deletions will always require your confirmation.[/]\n\n" - "[bold red]Use with caution. 
Review LLM changes carefully.[/]", - title="MCP Write Mode", - border_style="yellow" - )) - - confirm = typer.confirm("Enable write mode?", default=False) - - if confirm: - self.write_enabled = True - console.print("[bold green]✓ Write mode enabled[/]") - app_logger.info("MCP write mode enabled by user") - log_mcp_stat('write_mode_enabled', '', True) - else: - console.print("[yellow]Cancelled. Write mode remains disabled.[/]") - app_logger.info("User cancelled write mode enablement") - - def disable_write(self) -> None: - """Disable write mode.""" - if not self.write_enabled: - console.print("[dim]Write mode is already disabled.[/]") - return - - self.write_enabled = False - console.print("[bold green]✓ Write mode disabled[/]") - app_logger.info("MCP write mode disabled") - log_mcp_stat('write_mode_disabled', '', True) - - def switch_mode(self, new_mode: str, db_index: Optional[int] = None) -> Dict[str, Any]: - """Switch between files and database mode.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled. Use /mcp on first' - } - - if new_mode == "files": - self.mode = "files" - self.selected_db_index = None - app_logger.info("Switched to file mode") - return { - 'success': True, - 'mode': 'files', - 'message': 'Switched to file mode', - 'tools': ['read_file', 'list_directory', 'search_files'] - } - elif new_mode == "database": - if db_index is None: - return { - 'success': False, - 'error': 'Database index required. Use /mcp db ' - } - - if db_index < 1 or db_index > len(self.databases): - return { - 'success': False, - 'error': f'Invalid database number. Use 1-{len(self.databases)}' - } - - self.mode = "database" - self.selected_db_index = db_index - 1 # Convert to 0-based - db = self.databases[self.selected_db_index] - - app_logger.info(f"Switched to database mode: {db['name']}") - return { - 'success': True, - 'mode': 'database', - 'database': db, - 'message': f"Switched to database #{db_index}: {db['name']}", - 'tools': ['inspect_database', 'search_database', 'query_database'] - } - else: - return { - 'success': False, - 'error': f'Invalid mode: {new_mode}' - } - - def add_folder(self, folder_path: str) -> Dict[str, Any]: - """Add folder to allowed list.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled. 
Use /mcp on first' - } - - try: - path = self.config.normalize_path(folder_path) - - # Validation checks - if not path.exists(): - return { - 'success': False, - 'error': f'Directory does not exist: {path}' - } - - if not path.is_dir(): - return { - 'success': False, - 'error': f'Not a directory: {path}' - } - - if self.config.is_system_directory(path): - return { - 'success': False, - 'error': f'Cannot add system directory: {path}' - } - - # Check if already added - if path in self.allowed_folders: - return { - 'success': False, - 'error': f'Folder already in allowed list: {path}' - } - - # Check for nested paths - parent_folder = None - for existing in self.allowed_folders: - try: - path.relative_to(existing) - parent_folder = existing - break - except ValueError: - continue - - # Get folder stats - stats = self.config.get_folder_stats(path) - - # Add folder - self.allowed_folders.append(path) - self._save_folders() - - # Update server and reload gitignores - if self.server: - self.server.allowed_folders = self.allowed_folders - self.server.reload_gitignores() - - app_logger.info(f"Added folder to MCP: {path}") - - result = { - 'success': True, - 'path': str(path), - 'stats': stats, - 'total_folders': len(self.allowed_folders) - } - - if parent_folder: - result['warning'] = f'Note: {parent_folder} is already allowed (parent folder)' - - return result - - except Exception as e: - app_logger.error(f"Error adding folder: {e}") - return { - 'success': False, - 'error': str(e) - } - - def remove_folder(self, folder_ref: str) -> Dict[str, Any]: - """Remove folder from allowed list (by path or number).""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled. Use /mcp on first' - } - - try: - # Check if it's a number - if folder_ref.isdigit(): - index = int(folder_ref) - 1 - if 0 <= index < len(self.allowed_folders): - path = self.allowed_folders[index] - else: - return { - 'success': False, - 'error': f'Invalid folder number: {folder_ref}' - } - else: - # Treat as path - path = self.config.normalize_path(folder_ref) - if path not in self.allowed_folders: - return { - 'success': False, - 'error': f'Folder not in allowed list: {path}' - } - - # Remove folder - self.allowed_folders.remove(path) - self._save_folders() - - # Update server and reload gitignores - if self.server: - self.server.allowed_folders = self.allowed_folders - self.server.reload_gitignores() - - app_logger.info(f"Removed folder from MCP: {path}") - - result = { - 'success': True, - 'path': str(path), - 'total_folders': len(self.allowed_folders) - } - - if len(self.allowed_folders) == 0: - result['warning'] = 'No allowed folders remaining. Add one with /mcp add ' - - return result - - except Exception as e: - app_logger.error(f"Error removing folder: {e}") - return { - 'success': False, - 'error': str(e) - } - - def add_database(self, db_path: str) -> Dict[str, Any]: - """Add SQLite database.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled. 
Use /mcp on first' - } - - try: - path = Path(db_path).resolve() - - # Validation - if not path.exists(): - return { - 'success': False, - 'error': f'Database file not found: {path}' - } - - if not path.is_file(): - return { - 'success': False, - 'error': f'Not a file: {path}' - } - - # Check if already added - if any(db['path'] == str(path) for db in self.databases): - return { - 'success': False, - 'error': f'Database already added: {path.name}' - } - - # Validate it's a SQLite database and get schema - try: - conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) - cursor = conn.cursor() - - # Get tables - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") - tables = [row[0] for row in cursor.fetchall()] - - conn.close() - except sqlite3.DatabaseError as e: - return { - 'success': False, - 'error': f'Not a valid SQLite database: {e}' - } - - # Get file size - db_size = path.stat().st_size - - # Create database entry - db_info = { - 'path': str(path), - 'name': path.name, - 'size': db_size, - 'tables': tables, - 'added': datetime.datetime.now().isoformat() - } - - # Save to config - self._save_database(db_info) - - # Add to list - self.databases.append(db_info) - - app_logger.info(f"Added database: {path.name}") - - return { - 'success': True, - 'database': db_info, - 'number': len(self.databases), - 'message': f'Added database #{len(self.databases)}: {path.name}' - } - - except Exception as e: - app_logger.error(f"Error adding database: {e}") - return { - 'success': False, - 'error': str(e) - } - - def remove_database(self, db_ref: str) -> Dict[str, Any]: - """Remove database (by number or path).""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled' - } - - try: - # Check if it's a number - if db_ref.isdigit(): - index = int(db_ref) - 1 - if 0 <= index < len(self.databases): - db = self.databases[index] - else: - return { - 'success': False, - 'error': f'Invalid database number: {db_ref}' - } - else: - # Treat as path - path = str(Path(db_ref).resolve()) - db = next((d for d in self.databases if d['path'] == path), None) - if not db: - return { - 'success': False, - 'error': f'Database not found: {db_ref}' - } - index = self.databases.index(db) - - # If currently selected, deselect - if self.mode == "database" and self.selected_db_index == index: - self.mode = "files" - self.selected_db_index = None - - # Remove from config - self._remove_database_from_config(db['path']) - - # Remove from list - self.databases.pop(index) - - app_logger.info(f"Removed database: {db['name']}") - - result = { - 'success': True, - 'database': db, - 'message': f"Removed database: {db['name']}" - } - - if len(self.databases) == 0: - result['warning'] = 'No databases remaining. 
Add one with /mcp add db ' - - return result - - except Exception as e: - app_logger.error(f"Error removing database: {e}") - return { - 'success': False, - 'error': str(e) - } - - def list_databases(self) -> Dict[str, Any]: - """List all databases.""" - try: - db_list = [] - for idx, db in enumerate(self.databases, 1): - db_info = { - 'number': idx, - 'name': db['name'], - 'path': db['path'], - 'size_mb': db['size'] / (1024 * 1024), - 'tables': db['tables'], - 'table_count': len(db['tables']), - 'added': db['added'] - } - - # Check if file still exists - if not Path(db['path']).exists(): - db_info['warning'] = 'File not found' - - db_list.append(db_info) - - return { - 'success': True, - 'databases': db_list, - 'count': len(db_list) - } - - except Exception as e: - app_logger.error(f"Error listing databases: {e}") - return { - 'success': False, - 'error': str(e) - } - - def toggle_gitignore(self, enabled: bool) -> Dict[str, Any]: - """Toggle .gitignore filtering.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled' - } - - if not self.server: - return { - 'success': False, - 'error': 'MCP server not running' - } - - self.server.respect_gitignore = enabled - status = "enabled" if enabled else "disabled" - - app_logger.info(f".gitignore filtering {status}") - - return { - 'success': True, - 'message': f'.gitignore filtering {status}', - 'pattern_count': len(self.server.gitignore_parser.patterns) if enabled else 0 - } - - def list_folders(self) -> Dict[str, Any]: - """List all allowed folders with stats.""" - try: - folders_info = [] - total_files = 0 - total_size = 0 - - for idx, folder in enumerate(self.allowed_folders, 1): - stats = self.config.get_folder_stats(folder) - - folder_info = { - 'number': idx, - 'path': str(folder), - 'exists': stats.get('exists', False) - } - - if stats.get('exists'): - folder_info['file_count'] = stats['file_count'] - folder_info['size_mb'] = stats['size_mb'] - total_files += stats['file_count'] - total_size += stats['total_size'] - - folders_info.append(folder_info) - - return { - 'success': True, - 'folders': folders_info, - 'total_folders': len(self.allowed_folders), - 'total_files': total_files, - 'total_size_mb': total_size / (1024 * 1024) - } - - except Exception as e: - app_logger.error(f"Error listing folders: {e}") - return { - 'success': False, - 'error': str(e) - } - - def get_status(self) -> Dict[str, Any]: - """Get comprehensive MCP status.""" - try: - stats = get_mcp_stats() - - uptime = None - if self.session_start_time: - delta = datetime.datetime.now() - self.session_start_time - hours = delta.seconds // 3600 - minutes = (delta.seconds % 3600) // 60 - uptime = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m" - - folder_info = self.list_folders() - - gitignore_status = "enabled" if (self.server and self.server.respect_gitignore) else "disabled" - gitignore_patterns = len(self.server.gitignore_parser.patterns) if self.server else 0 - - # Current mode info - mode_info = { - 'mode': self.mode, - 'mode_display': '🔧 Files' if self.mode == 'files' else f'🗄️ DB #{self.selected_db_index + 1}' - } - - if self.mode == 'database' and self.selected_db_index is not None: - db = self.databases[self.selected_db_index] - mode_info['database'] = db - - return { - 'success': True, - 'enabled': self.enabled, - 'write_enabled': self.write_enabled, - 'uptime': uptime, - 'mode_info': mode_info, - 'folder_count': len(self.allowed_folders), - 'database_count': len(self.databases), - 'total_files': 
folder_info.get('total_files', 0), - 'total_size_mb': folder_info.get('total_size_mb', 0), - 'stats': stats, - 'tools_available': self._get_current_tools(), - 'gitignore_status': gitignore_status, - 'gitignore_patterns': gitignore_patterns - } - - except Exception as e: - app_logger.error(f"Error getting MCP status: {e}") - return { - 'success': False, - 'error': str(e) - } - - def _get_current_tools(self) -> List[str]: - """Get tools available in current mode.""" - if not self.enabled: - return [] - - if self.mode == "files": - return ['read_file', 'list_directory', 'search_files'] - elif self.mode == "database": - return ['inspect_database', 'search_database', 'query_database'] - - return [] - - async def call_tool(self, tool_name: str, **kwargs) -> Dict[str, Any]: - """Call an MCP tool.""" - if not self.enabled or not self.server: - return { - 'error': 'MCP is not enabled' - } - - # Check write permissions for write operations - write_tools = {'write_file', 'edit_file', 'delete_file', 'create_directory', 'move_file', 'copy_file'} - if tool_name in write_tools and not self.write_enabled: - return { - 'error': 'Write operations disabled. Enable with /mcp write on' - } - - try: - # File mode tools (read-only) - if tool_name == 'read_file': - return await self.server.read_file(kwargs.get('file_path', '')) - elif tool_name == 'list_directory': - return await self.server.list_directory( - kwargs.get('dir_path', ''), - kwargs.get('recursive', False) - ) - elif tool_name == 'search_files': - return await self.server.search_files( - kwargs.get('pattern', ''), - kwargs.get('search_path'), - kwargs.get('content_search', False) - ) - - # Database mode tools - elif tool_name == 'inspect_database': - if self.mode != 'database' or self.selected_db_index is None: - return {'error': 'Not in database mode. Use /mcp db first'} - db = self.databases[self.selected_db_index] - return await self.server.inspect_database( - db['path'], - kwargs.get('table_name') - ) - elif tool_name == 'search_database': - if self.mode != 'database' or self.selected_db_index is None: - return {'error': 'Not in database mode. Use /mcp db first'} - db = self.databases[self.selected_db_index] - return await self.server.search_database( - db['path'], - kwargs.get('search_term', ''), - kwargs.get('table_name'), - kwargs.get('column_name') - ) - elif tool_name == 'query_database': - if self.mode != 'database' or self.selected_db_index is None: - return {'error': 'Not in database mode. 
Use /mcp db first'} - db = self.databases[self.selected_db_index] - return await self.server.query_database( - db['path'], - kwargs.get('query', ''), - kwargs.get('limit') - ) - - # Write mode tools - elif tool_name == 'write_file': - return await self.server.write_file( - kwargs.get('file_path', ''), - kwargs.get('content', '') - ) - elif tool_name == 'edit_file': - return await self.server.edit_file( - kwargs.get('file_path', ''), - kwargs.get('old_text', ''), - kwargs.get('new_text', '') - ) - elif tool_name == 'delete_file': - return await self.server.delete_file( - kwargs.get('file_path', ''), - kwargs.get('reason', 'No reason provided') - ) - elif tool_name == 'create_directory': - return await self.server.create_directory( - kwargs.get('dir_path', '') - ) - elif tool_name == 'move_file': - return await self.server.move_file( - kwargs.get('source_path', ''), - kwargs.get('dest_path', '') - ) - elif tool_name == 'copy_file': - return await self.server.copy_file( - kwargs.get('source_path', ''), - kwargs.get('dest_path', '') - ) - else: - return { - 'error': f'Unknown tool: {tool_name}' - } - except Exception as e: - app_logger.error(f"Error calling MCP tool {tool_name}: {e}") - return { - 'error': str(e) - } - - def get_tools_schema(self) -> List[Dict[str, Any]]: - """Get MCP tools as OpenAI function calling schema (for current mode).""" - if not self.enabled: - return [] - - if self.mode == "files": - return self._get_file_tools_schema() - elif self.mode == "database": - return self._get_database_tools_schema() - - return [] - - def _get_file_tools_schema(self) -> List[Dict[str, Any]]: - """Get file mode tools schema.""" - if len(self.allowed_folders) == 0: - return [] - - allowed_dirs_str = ", ".join(str(f) for f in self.allowed_folders) - - # Base read-only tools - tools = [ - { - "type": "function", - "function": { - "name": "search_files", - "description": f"Search for files in the user's local filesystem. Allowed directories: {allowed_dirs_str}. Can search by filename pattern (e.g., '*.py' for Python files) or search within file contents. Automatically respects .gitignore patterns and excludes virtual environments.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Search pattern. For filename search, use glob patterns like '*.py', '*.txt', 'report*'. For content search, use plain text to search for." - }, - "content_search": { - "type": "boolean", - "description": "If true, searches inside file contents for the pattern (SLOW - use only when specifically asked to search file contents). If false, searches only filenames (FAST - use for finding files by name). Default is false. Virtual environments and package directories are automatically excluded.", - "default": False - }, - "search_path": { - "type": "string", - "description": "Optional: specific directory to search in. If not provided, searches all allowed directories.", - } - }, - "required": ["pattern"] - } - } - }, - { - "type": "function", - "function": { - "name": "read_file", - "description": "Read the complete contents of a text file from the user's local filesystem. Only works for files within allowed directories. Maximum file size: 10MB. Files larger than 50KB are automatically truncated. 
Respects .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Full path to the file to read (e.g., /Users/username/Documents/report.txt or ~/Documents/notes.md)" - } - }, - "required": ["file_path"] - } - } - }, - { - "type": "function", - "function": { - "name": "list_directory", - "description": "List all files and subdirectories in a directory from the user's local filesystem. Only works for directories within allowed paths. Automatically filters out virtual environments, build artifacts, and .gitignore patterns. Limited to 1000 items.", - "parameters": { - "type": "object", - "properties": { - "dir_path": { - "type": "string", - "description": "Directory path to list (e.g., ~/Documents or /Users/username/Projects)" - }, - "recursive": { - "type": "boolean", - "description": "If true, lists all files in subdirectories recursively. If false, lists only immediate children. Default is true. WARNING: Recursive listings can be very large - use specific paths when possible.", - "default": True - } - }, - "required": ["dir_path"] - } - } - } - ] - - # Add write tools if write mode is enabled - if self.write_enabled: - tools.extend([ - { - "type": "function", - "function": { - "name": "write_file", - "description": f"Create a new file or overwrite an existing file with the specified content. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns - can write to any file in allowed folders. Automatically creates parent directories if needed.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Full path to the file to create or overwrite (e.g., /Users/username/project/src/main.py)" - }, - "content": { - "type": "string", - "description": "The complete content to write to the file" - } - }, - "required": ["file_path", "content"] - } - } - }, - { - "type": "function", - "function": { - "name": "edit_file", - "description": f"Make targeted edits to an existing file by finding and replacing specific text. The old_text must match exactly and appear only once in the file. For multiple matches, provide more context to make the match unique. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Full path to the file to edit" - }, - "old_text": { - "type": "string", - "description": "The exact text to find and replace. Must match exactly and appear only once. Include enough context to make the match unique." - }, - "new_text": { - "type": "string", - "description": "The new text to replace the old text with" - } - }, - "required": ["file_path", "old_text", "new_text"] - } - } - }, - { - "type": "function", - "function": { - "name": "delete_file", - "description": f"Delete a file from the filesystem. ALWAYS requires user confirmation before deletion. The user will see the file path, size, modification time, and your reason before deciding. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Full path to the file to delete" - }, - "reason": { - "type": "string", - "description": "Clear explanation for why this file should be deleted. The user will see this reason in the confirmation prompt." 
- } - }, - "required": ["file_path", "reason"] - } - } - }, - { - "type": "function", - "function": { - "name": "create_directory", - "description": f"Create a new directory (and all parent directories if needed). If the directory already exists, returns success without error. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "dir_path": { - "type": "string", - "description": "Full path to the directory to create (e.g., /Users/username/project/src/components)" - } - }, - "required": ["dir_path"] - } - } - }, - { - "type": "function", - "function": { - "name": "move_file", - "description": f"Move or rename a file. Both source and destination must be within allowed directories: {allowed_dirs_str}. Creates destination parent directories if needed. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "source_path": { - "type": "string", - "description": "Full path to the file to move/rename" - }, - "dest_path": { - "type": "string", - "description": "Full path for the new location/name" - } - }, - "required": ["source_path", "dest_path"] - } - } - }, - { - "type": "function", - "function": { - "name": "copy_file", - "description": f"Copy a file to a new location. Both source and destination must be within allowed directories: {allowed_dirs_str}. Creates destination parent directories if needed. Preserves file metadata. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "source_path": { - "type": "string", - "description": "Full path to the file to copy" - }, - "dest_path": { - "type": "string", - "description": "Full path for the copy destination" - } - }, - "required": ["source_path", "dest_path"] - } - } - } - ]) - - return tools - - def _get_database_tools_schema(self) -> List[Dict[str, Any]]: - """Get database mode tools schema.""" - if self.selected_db_index is None or self.selected_db_index >= len(self.databases): - return [] - - db = self.databases[self.selected_db_index] - db_name = db['name'] - tables_str = ", ".join(db['tables']) - - return [ - { - "type": "function", - "function": { - "name": "inspect_database", - "description": f"Inspect the schema of the currently selected SQLite database ({db_name}). Can get all tables or details for a specific table including columns, types, indexes, and row counts. Available tables: {tables_str}.", - "parameters": { - "type": "object", - "properties": { - "table_name": { - "type": "string", - "description": f"Optional: specific table to inspect. If not provided, returns info for all tables. Available: {tables_str}" - } - }, - "required": [] - } - } - }, - { - "type": "function", - "function": { - "name": "search_database", - "description": f"Search for a value across all tables in the database ({db_name}). Performs a LIKE search (partial matching) across all columns or specific table/column. Returns matching rows with all column data. Limited to {self.server.default_query_limit} results.", - "parameters": { - "type": "object", - "properties": { - "search_term": { - "type": "string", - "description": "Value to search for (partial match supported)" - }, - "table_name": { - "type": "string", - "description": f"Optional: limit search to specific table. 
Available: {tables_str}" - }, - "column_name": { - "type": "string", - "description": "Optional: limit search to specific column within the table" - } - }, - "required": ["search_term"] - } - } - }, - { - "type": "function", - "function": { - "name": "query_database", - "description": f"Execute a read-only SQL query on the database ({db_name}). Supports SELECT queries including JOINs, subqueries, CTEs, aggregations, etc. Maximum {self.server.max_query_results} rows returned. Query timeout: {self.server.max_query_timeout} seconds. INSERT/UPDATE/DELETE/DROP are blocked for safety.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": f"SQL SELECT query to execute. Available tables: {tables_str}. Example: SELECT * FROM table WHERE column = 'value' LIMIT 10" - }, - "limit": { - "type": "integer", - "description": f"Optional: maximum rows to return (default {self.server.default_query_limit}, max {self.server.max_query_results}). You can also use LIMIT in your query." - } - }, - "required": ["query"] - } - } - } - ] - -# Global MCP manager -mcp_manager = MCPManager() - -# ============================================================================ -# HELPER FUNCTIONS -# ============================================================================ - -def supports_function_calling(model: Dict[str, Any]) -> bool: - """Check if model supports function calling.""" - supported_params = model.get("supported_parameters", []) - return "tools" in supported_params or "functions" in supported_params - -def supports_tools(model: Dict[str, Any]) -> bool: - """Check if model supports tools/function calling (same as supports_function_calling).""" - return supports_function_calling(model) - -def check_for_updates(current_version: str) -> str: - """Check for updates.""" - try: - response = requests.get( - 'https://gitlab.pm/api/v1/repos/rune/oai/releases/latest', - headers={"Content-Type": "application/json"}, - timeout=1.0, - allow_redirects=True - ) - response.raise_for_status() - - data = response.json() - version_online = data.get('tag_name', '').lstrip('v') - - if not version_online: - logger.warning("No version found in API response") - return f"[bold green]oAI version {current_version}[/]" - - current = pkg_version.parse(current_version) - latest = pkg_version.parse(version_online) - - if latest > current: - logger.info(f"Update available: {current_version} → {version_online}") - return f"[bold green]oAI version {current_version} [/][bold red](Update available: {current_version} → {version_online})[/]" - else: - logger.debug(f"Already up to date: {current_version}") - return f"[bold green]oAI version {current_version} (up to date)[/]" - - except: - return f"[bold green]oAI version {current_version}[/]" - -def save_conversation(name: str, data: List[Dict[str, str]]): - """Save conversation.""" - timestamp = datetime.datetime.now().isoformat() - data_json = json.dumps(data) - with sqlite3.connect(str(database)) as conn: - conn.execute('INSERT INTO conversation_sessions (name, timestamp, data) VALUES (?, ?, ?)', (name, timestamp, data_json)) - conn.commit() - -def load_conversation(name: str) -> Optional[List[Dict[str, str]]]: - """Load conversation.""" - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('SELECT data FROM conversation_sessions WHERE name = ? 
ORDER BY timestamp DESC LIMIT 1', (name,)) - result = cursor.fetchone() - if result: - return json.loads(result[0]) - return None - -def delete_conversation(name: str) -> int: - """Delete conversation.""" - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('DELETE FROM conversation_sessions WHERE name = ?', (name,)) - conn.commit() - return cursor.rowcount - -def list_conversations() -> List[Dict[str, Any]]: - """List conversations.""" - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute(''' - SELECT name, MAX(timestamp) as last_saved, data - FROM conversation_sessions - GROUP BY name - ORDER BY last_saved DESC - ''') - conversations = [] - for row in cursor.fetchall(): - name, timestamp, data_json = row - data = json.loads(data_json) - conversations.append({ - 'name': name, - 'timestamp': timestamp, - 'message_count': len(data) - }) - return conversations - -def estimate_cost(input_tokens: int, output_tokens: int) -> float: - """Estimate cost.""" - return (input_tokens * MODEL_PRICING['input'] / 1_000_000) + (output_tokens * MODEL_PRICING['output'] / 1_000_000) - -def has_web_search_capability(model: Dict[str, Any]) -> bool: - """Check web search capability.""" - supported_params = model.get("supported_parameters", []) - return "tools" in supported_params - -def has_image_capability(model: Dict[str, Any]) -> bool: - """Check image capability.""" - architecture = model.get("architecture", {}) - input_modalities = architecture.get("input_modalities", []) - return "image" in input_modalities - -def supports_online_mode(model: Dict[str, Any]) -> bool: - """Check online mode support.""" - return has_web_search_capability(model) - -def get_effective_model_id(base_model_id: str, online_enabled: bool) -> str: - """Get effective model ID.""" - if online_enabled and not base_model_id.endswith(':online'): - return f"{base_model_id}:online" - return base_model_id - -def export_as_markdown(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export as Markdown.""" - lines = ["# Conversation Export", ""] - if session_system_prompt: - lines.extend([f"**System Prompt:** {session_system_prompt}", ""]) - lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - lines.append("") - lines.append("---") - lines.append("") - - for i, entry in enumerate(session_history, 1): - lines.append(f"## Message {i}") - lines.append("") - lines.append("**User:**") - lines.append("") - lines.append(entry['prompt']) - lines.append("") - lines.append("**Assistant:**") - lines.append("") - lines.append(entry['response']) - lines.append("") - lines.append("---") - lines.append("") - - return "\n".join(lines) - -def export_as_json(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export as JSON.""" - export_data = { - "export_date": datetime.datetime.now().isoformat(), - "system_prompt": session_system_prompt, - "message_count": len(session_history), - "messages": session_history - } - return json.dumps(export_data, indent=2, ensure_ascii=False) - -def export_as_html(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export as HTML.""" - def escape_html(text): - return text.replace('&', '&').replace('<', '<').replace('>', '>').replace('"', '"').replace("'", ''') - - html_parts = [ - "", - "", - "", - " ", - " ", - " Conversation Export", - " ", - "", - "", - "
", - "

💬 Conversation Export

", - f"
📅 Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
", - f"
📊 Total Messages: {len(session_history)}
", - "
", - ] - - if session_system_prompt: - html_parts.extend([ - "
", - " ⚙️ System Prompt", - f"
{escape_html(session_system_prompt)}
", - "
", - ]) - - for i, entry in enumerate(session_history, 1): - html_parts.extend([ - "
", - f"
Message {i} of {len(session_history)}
", - "
", - "
👤 User
", - f"
{escape_html(entry['prompt'])}
", - "
", - "
", - "
🤖 Assistant
", - f"
{escape_html(entry['response'])}
", - "
", - "
", - ]) - - html_parts.extend([ - "
", - "

Generated by oAI Chat • https://iurl.no/oai

", - "
", - "", - "", - ]) - - return "\n".join(html_parts) - -def show_command_help(command_or_topic: str): - """Display detailed help for command or topic.""" - # Handle topics (like 'mcp') - if not command_or_topic.startswith('/'): - if command_or_topic.lower() == 'mcp': - help_data = COMMAND_HELP['mcp'] - - help_content = [] - help_content.append(f"[bold cyan]Description:[/]") - help_content.append(help_data['description']) - help_content.append("") - help_content.append(help_data['notes']) - - console.print(Panel( - "\n".join(help_content), - title="[bold green]MCP - Model Context Protocol Guide[/]", - title_align="left", - border_style="green", - width=console.width - 4 - )) - - app_logger.info("Displayed MCP comprehensive guide") - return - else: - command_or_topic = '/' + command_or_topic - - if command_or_topic not in COMMAND_HELP: - console.print(f"[bold red]Unknown command: {command_or_topic}[/]") - console.print("[bold yellow]Type /help to see all available commands.[/]") - console.print("[bold yellow]Type /help mcp for comprehensive MCP guide.[/]") - app_logger.warning(f"Help requested for unknown command: {command_or_topic}") - return - - help_data = COMMAND_HELP[command_or_topic] - - help_content = [] - - if 'aliases' in help_data: - aliases_str = ", ".join(help_data['aliases']) - help_content.append(f"[bold cyan]Aliases:[/] {aliases_str}") - help_content.append("") - - help_content.append(f"[bold cyan]Description:[/]") - help_content.append(help_data['description']) - help_content.append("") - - help_content.append(f"[bold cyan]Usage:[/]") - help_content.append(f"[yellow]{help_data['usage']}[/]") - help_content.append("") - - if 'examples' in help_data and help_data['examples']: - help_content.append(f"[bold cyan]Examples:[/]") - for desc, example in help_data['examples']: - if not desc and not example: - help_content.append("") - elif desc.startswith('━━━'): - help_content.append(f"[bold yellow]{desc}[/]") - else: - help_content.append(f" [dim]{desc}:[/]" if desc else "") - help_content.append(f" [green]{example}[/]" if example else "") - help_content.append("") - - if 'notes' in help_data: - help_content.append(f"[bold cyan]Notes:[/]") - help_content.append(f"[dim]{help_data['notes']}[/]") - - console.print(Panel( - "\n".join(help_content), - title=f"[bold green]Help: {command_or_topic}[/]", - title_align="left", - border_style="green", - width=console.width - 4 - )) - - app_logger.info(f"Displayed detailed help for command: {command_or_topic}") - -def get_credits(api_key: str, base_url: str = OPENROUTER_BASE_URL) -> Optional[Dict[str, str]]: - """Get credits.""" - if not api_key: - return None - url = f"{base_url}/credits" - headers = { - "Authorization": f"Bearer {api_key}", - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - try: - response = requests.get(url, headers=headers) - response.raise_for_status() - data = response.json().get('data', {}) - total_credits = float(data.get('total_credits', 0)) - total_usage = float(data.get('total_usage', 0)) - credits_left = total_credits - total_usage - return { - 'total_credits': f"${total_credits:.2f}", - 'used_credits': f"${total_usage:.2f}", - 'credits_left': f"${credits_left:.2f}" - } - except Exception as e: - console.print(f"[bold red]Error fetching credits: {e}[/]") - return None - -def check_credit_alerts(credits_data: Optional[Dict[str, str]]) -> List[str]: - """Check credit alerts.""" - alerts = [] - if credits_data: - credits_left_value = float(credits_data['credits_left'].strip('$')) - total_credits_value = 
float(credits_data['total_credits'].strip('$')) - if credits_left_value < LOW_CREDIT_AMOUNT: - alerts.append(f"Critical credit alert: Less than ${LOW_CREDIT_AMOUNT:.2f} left ({credits_data['credits_left']})") - elif credits_left_value < total_credits_value * LOW_CREDIT_RATIO: - alerts.append(f"Low credit alert: Credits left < 10% of total ({credits_data['credits_left']})") - return alerts - -def clear_screen(): - """Clear screen.""" - try: - print("\033[H\033[J", end="", flush=True) - except: - print("\n" * 100) - -def display_paginated_table(table: Table, title: str): - """Display paginated table.""" - try: - terminal_height = os.get_terminal_size().lines - 8 - except: - terminal_height = 20 - - from rich.segment import Segment - - segments = list(console.render(table)) - - current_line_segments = [] - all_lines = [] - - for segment in segments: - if segment.text == '\n': - all_lines.append(current_line_segments) - current_line_segments = [] - else: - current_line_segments.append(segment) - - if current_line_segments: - all_lines.append(current_line_segments) - - total_lines = len(all_lines) - - if total_lines <= terminal_height: - console.print(Panel(table, title=title, title_align="left")) - return - - header_lines = [] - data_lines = [] - footer_line = [] - - header_end_index = 0 - found_header_text = False - - for i, line_segments in enumerate(all_lines): - has_header_style = any( - seg.style and ('bold' in str(seg.style) or 'magenta' in str(seg.style)) - for seg in line_segments - ) - - if has_header_style: - found_header_text = True - - if found_header_text and i > 0: - line_text = ''.join(seg.text for seg in line_segments) - if any(char in line_text for char in ['─', '━', '┼', '╪', '┤', '├']): - header_end_index = i - break - - # Extract footer (bottom border) - it's the last line of the table - if all_lines: - last_line_text = ''.join(seg.text for seg in all_lines[-1]) - # Check if last line is a border line (contains box drawing characters) - if any(char in last_line_text for char in ['─', '━', '┴', '╧', '┘', '└']): - footer_line = all_lines[-1] - all_lines = all_lines[:-1] # Remove footer from all_lines - - if header_end_index > 0: - header_lines = all_lines[:header_end_index + 1] - data_lines = all_lines[header_end_index + 1:] - else: - header_lines = all_lines[:min(3, len(all_lines))] - data_lines = all_lines[min(3, len(all_lines)):] - - lines_per_page = terminal_height - len(header_lines) - - current_line = 0 - page_number = 1 - - while current_line < len(data_lines): - clear_screen() - - console.print(f"[bold cyan]{title} (Page {page_number})[/]") - - for line_segments in header_lines: - for segment in line_segments: - console.print(segment.text, style=segment.style, end="") - console.print() - - end_line = min(current_line + lines_per_page, len(data_lines)) - - for line_segments in data_lines[current_line:end_line]: - for segment in line_segments: - console.print(segment.text, style=segment.style, end="") - console.print() - - # Add footer (bottom border) on each page - if footer_line: - for segment in footer_line: - console.print(segment.text, style=segment.style, end="") - console.print() - - current_line = end_line - page_number += 1 - - if current_line < len(data_lines): - console.print(f"\n[dim yellow]--- Press SPACE for next page, or any other key to finish (Page {page_number - 1}, showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]") - try: - import sys - import tty - import termios - - fd = sys.stdin.fileno() - old_settings = termios.tcgetattr(fd) - - 
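            # The block below temporarily switches the terminal into raw mode so a
            # single keypress can be read without waiting for Enter: SPACE advances
            # to the next page, any other key stops pagination. The saved termios
            # settings are always restored in the finally clause; on platforms
            # without termios/tty support (e.g. Windows), the bare except falls
            # back to a plain input() prompt instead.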
try: - tty.setraw(fd) - char = sys.stdin.read(1) - - if char != ' ': - break - finally: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - except: - input_char = input().strip() - if input_char != '': - break - else: - break - # P3 - # ============================================================================ -# MAIN CHAT FUNCTION WITH FULL MCP SUPPORT (FILES + DATABASE) -# ============================================================================ - -@app.command() -def chat(): - global API_KEY, OPENROUTER_BASE_URL, STREAM_ENABLED, MAX_TOKEN, COST_WARNING_THRESHOLD, DEFAULT_ONLINE_MODE, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, LOG_LEVEL_STR, app_logger - - session_max_token = 0 - session_system_prompt = "" - session_history = [] - current_index = -1 - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - middle_out_enabled = False - conversation_memory_enabled = True - memory_start_index = 0 - saved_conversations_cache = [] - online_mode_enabled = DEFAULT_ONLINE_MODE == "on" - - app_logger.info("Starting new chat session with memory enabled") - - if not API_KEY: - console.print("[bold red]API key not found. Use '/config api'.[/]") - try: - new_api_key = typer.prompt("Enter API key") - if new_api_key.strip(): - set_config('api_key', new_api_key.strip()) - API_KEY = new_api_key.strip() - console.print("[bold green]API key saved. Re-run.[/]") - else: - raise typer.Exit() - except: - console.print("[bold red]No API key. Exiting.[/]") - raise typer.Exit() - - if not text_models: - console.print("[bold red]No models available. Check API key/URL.[/]") - raise typer.Exit() - - credits_data = get_credits(API_KEY, OPENROUTER_BASE_URL) - startup_credit_alerts = check_credit_alerts(credits_data) - if startup_credit_alerts: - startup_alert_msg = " | ".join(startup_credit_alerts) - console.print(f"[bold red]⚠️ Startup {startup_alert_msg}[/]") - app_logger.warning(f"Startup credit alerts: {startup_alert_msg}") - - selected_model = selected_model_default - - client = OpenRouter(api_key=API_KEY) - - if selected_model: - online_status = "enabled" if online_mode_enabled else "disabled" - console.print(f"[bold blue]Welcome to oAI![/] [bold red]Active model: {selected_model['name']}[/] [dim cyan](Online mode: {online_status})[/]") - else: - console.print("[bold blue]Welcome to oAI![/] [italic blue]Select a model with '/model'.[/]") - - if not selected_model: - console.print("[bold yellow]No model selected. 
Use '/model'.[/]") - - # Create custom key bindings to add tab support for autocomplete - kb = KeyBindings() - - @kb.add('tab') - def _(event): - """Accept suggestion with tab key.""" - b = event.current_buffer - suggestion = b.suggestion - if suggestion and b.document.is_cursor_at_the_end: - b.insert_text(suggestion.text) - - session = PromptSession( - history=FileHistory(str(history_file)), - key_bindings=kb - ) - - while True: - try: - # ============================================================ - # BUILD PROMPT PREFIX WITH MODE INDICATOR - # ============================================================ - prompt_prefix = "You> " - if mcp_manager.enabled: - if mcp_manager.mode == "files": - if mcp_manager.write_enabled: - prompt_prefix = "[🔧✍️ MCP: Files+Write] You> " - else: - prompt_prefix = "[🔧 MCP: Files] You> " - elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: - db = mcp_manager.databases[mcp_manager.selected_db_index] - prompt_prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> " - - # ============================================================ - # INITIALIZE LOOP VARIABLES - # ============================================================ - text_part = "" - file_attachments = [] - content_blocks = [] - - user_input = session.prompt(prompt_prefix, auto_suggest=AutoSuggestFromHistory()).strip() - - # Handle escape sequence - if user_input.startswith("//"): - user_input = user_input[1:] - - # Check unknown commands - elif user_input.startswith("/") and user_input.lower() not in ["exit", "quit", "bye"]: - command_word = user_input.split()[0].lower() if user_input.split() else user_input.lower() - - if not any(command_word.startswith(cmd) for cmd in VALID_COMMANDS): - console.print(f"[bold red]Unknown command: {command_word}[/]") - console.print("[bold yellow]Type /help to see all available commands.[/]") - app_logger.warning(f"Unknown command attempted: {command_word}") - continue - - if user_input.lower() in ["exit", "quit", "bye"]: - total_tokens = total_input_tokens + total_output_tokens - app_logger.info(f"Session ended. 
Total messages: {message_count}, Total tokens: {total_tokens}, Total cost: ${total_cost:.4f}") - console.print("[bold yellow]Goodbye![/]") - return - - # ============================================================ - # MCP COMMANDS - # ============================================================ - if user_input.lower().startswith("/mcp"): - parts = user_input[5:].strip().split(maxsplit=1) - mcp_command = parts[0].lower() if parts else "" - mcp_args = parts[1] if len(parts) > 1 else "" - - if mcp_command in ["enable", "on"]: - result = mcp_manager.enable() - if result['success']: - console.print(f"[bold green]✓ {result['message']}[/]") - if result['folder_count'] == 0 and result['database_count'] == 0: - console.print("\n[bold yellow]No folders or databases configured yet.[/]") - console.print("\n[bold cyan]To get started:[/]") - console.print(" [bold yellow]Files:[/] /mcp add ~/Documents") - console.print(" [bold yellow]Databases:[/] /mcp add db ~/app/data.db") - else: - folders_msg = f"{result['folder_count']} folder(s)" if result['folder_count'] else "no folders" - dbs_msg = f"{result['database_count']} database(s)" if result['database_count'] else "no databases" - console.print(f"\n[bold cyan]MCP active with {folders_msg} and {dbs_msg}.[/]") - - warning = mcp_manager.config.get_filesystem_warning() - if warning: - console.print(warning) - else: - console.print(f"[bold red]❌ {result['error']}[/]") - continue - - elif mcp_command in ["disable", "off"]: - result = mcp_manager.disable() - if result['success']: - console.print(f"[bold green]✓ {result['message']}[/]") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - continue - - elif mcp_command == "add": - # Check if it's a database or folder - if mcp_args.startswith("db "): - # Database - db_path = mcp_args[3:].strip() - if not db_path: - console.print("[bold red]Usage: /mcp add db [/]") - continue - - result = mcp_manager.add_database(db_path) - - if result['success']: - db = result['database'] - - console.print("\n[bold yellow]⚠️ Database Check:[/]") - console.print(f"Adding: [bold]{db['name']}[/]") - console.print(f"Size: {db['size'] / (1024 * 1024):.2f} MB") - console.print(f"Tables: {', '.join(db['tables'])} ({len(db['tables'])} total)") - - if db['size'] > 100 * 1024 * 1024: # > 100MB - console.print("\n[bold yellow]⚠️ Large database detected! Queries may be slow.[/]") - - console.print("\nMCP will be able to:") - console.print(" [green]✓[/green] Inspect database schema") - console.print(" [green]✓[/green] Search for data across tables") - console.print(" [green]✓[/green] Execute read-only SQL queries") - console.print(" [red]✗[/red] Modify data (read-only mode)") - - try: - confirm = typer.confirm("\nProceed?", default=True) - if not confirm: - # Remove it - mcp_manager.remove_database(str(result['number'])) - console.print("[bold yellow]Cancelled. Database not added.[/]") - continue - except (EOFError, KeyboardInterrupt): - mcp_manager.remove_database(str(result['number'])) - console.print("\n[bold yellow]Cancelled. 
Database not added.[/]") - continue - - console.print(f"\n[bold green]✓ {result['message']}[/]") - console.print(f"\n[bold cyan]Use '/mcp db {result['number']}' to start querying this database.[/]") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - else: - # Folder - folder_path = mcp_args - if not folder_path: - console.print("[bold red]Usage: /mcp add or /mcp add db [/]") - continue - - result = mcp_manager.add_folder(folder_path) - - if result['success']: - stats = result['stats'] - - console.print("\n[bold yellow]⚠️ Security Check:[/]") - console.print(f"You are granting MCP access to: [bold]{result['path']}[/]") - - if stats.get('exists'): - file_count = stats['file_count'] - size_mb = stats['size_mb'] - console.print(f"This folder contains: [bold]{file_count} files ({size_mb:.1f} MB)[/]") - - console.print("\nMCP will be able to:") - console.print(" [green]✓[/green] Read files in this folder") - console.print(" [green]✓[/green] List and search files") - console.print(" [green]✓[/green] Access subfolders recursively") - console.print(" [green]✓[/green] Automatically respect .gitignore patterns") - console.print(" [red]✗[/red] Delete or modify files (read-only)") - - try: - confirm = typer.confirm("\nProceed?", default=True) - if not confirm: - mcp_manager.allowed_folders.remove(mcp_manager.config.normalize_path(folder_path)) - mcp_manager._save_folders() - console.print("[bold yellow]Cancelled. Folder not added.[/]") - continue - except (EOFError, KeyboardInterrupt): - mcp_manager.allowed_folders.remove(mcp_manager.config.normalize_path(folder_path)) - mcp_manager._save_folders() - console.print("\n[bold yellow]Cancelled. Folder not added.[/]") - continue - - console.print(f"\n[bold green]✓ Added {result['path']} to MCP allowed folders[/]") - console.print(f"[dim cyan]MCP now has access to {result['total_folders']} folder(s) total.[/]") - - if result.get('warning'): - console.print(f"\n[bold yellow]⚠️ {result['warning']}[/]") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - continue - - elif mcp_command in ["remove", "rem"]: - # Check if it's a database or folder - if mcp_args.startswith("db "): - # Database - db_ref = mcp_args[3:].strip() - if not db_ref: - console.print("[bold red]Usage: /mcp remove db [/]") - continue - - result = mcp_manager.remove_database(db_ref) - - if result['success']: - console.print(f"\n[bold yellow]Removing: {result['database']['name']}[/]") - - try: - confirm = typer.confirm("Confirm removal?", default=False) - if not confirm: - console.print("[bold yellow]Cancelled.[/]") - # Re-add it - mcp_manager.databases.append(result['database']) - mcp_manager._save_database(result['database']) - continue - except (EOFError, KeyboardInterrupt): - console.print("\n[bold yellow]Cancelled.[/]") - mcp_manager.databases.append(result['database']) - mcp_manager._save_database(result['database']) - continue - - console.print(f"\n[bold green]✓ {result['message']}[/]") - - if result.get('warning'): - console.print(f"\n[bold yellow]⚠️ {result['warning']}[/]") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - else: - # Folder - if not mcp_args: - console.print("[bold red]Usage: /mcp remove or /mcp remove db [/]") - continue - - result = mcp_manager.remove_folder(mcp_args) - - if result['success']: - console.print(f"\n[bold yellow]Removing: {result['path']}[/]") - - try: - confirm = typer.confirm("Confirm removal?", default=False) - if not confirm: - mcp_manager.allowed_folders.append(mcp_manager.config.normalize_path(result['path'])) - 
mcp_manager._save_folders() - console.print("[bold yellow]Cancelled. Folder not removed.[/]") - continue - except (EOFError, KeyboardInterrupt): - mcp_manager.allowed_folders.append(mcp_manager.config.normalize_path(result['path'])) - mcp_manager._save_folders() - console.print("\n[bold yellow]Cancelled. Folder not removed.[/]") - continue - - console.print(f"\n[bold green]✓ Removed {result['path']} from MCP allowed folders[/]") - console.print(f"[dim cyan]MCP now has access to {result['total_folders']} folder(s) total.[/]") - - if result.get('warning'): - console.print(f"\n[bold yellow]⚠️ {result['warning']}[/]") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - continue - - elif mcp_command == "list": - result = mcp_manager.list_folders() - - if result['success']: - if result['total_folders'] == 0: - console.print("[bold yellow]No folders configured.[/]") - console.print("\n[bold cyan]Add a folder with: /mcp add ~/Documents[/]") - continue - - status_indicator = "[green]✓[/green]" if mcp_manager.enabled else "[red]✗[/red]" - - table = Table( - "No.", "Path", "Files", "Size", - show_header=True, - header_style="bold magenta" - ) - - for folder_info in result['folders']: - number = str(folder_info['number']) - path = folder_info['path'] - - if folder_info['exists']: - files = f"📁 {folder_info['file_count']}" - size = f"{folder_info['size_mb']:.1f} MB" - else: - files = "[red]Not found[/red]" - size = "-" - - table.add_row(number, path, files, size) - - gitignore_info = "" - if mcp_manager.server: - gitignore_status = "on" if mcp_manager.server.respect_gitignore else "off" - pattern_count = len(mcp_manager.server.gitignore_parser.patterns) - gitignore_info = f" | .gitignore: {gitignore_status} ({pattern_count} patterns)" - - console.print(Panel( - table, - title=f"[bold green]MCP Folders: {'Active' if mcp_manager.enabled else 'Inactive'} {status_indicator}[/]", - title_align="left", - subtitle=f"[dim]Total: {result['total_folders']} folders, {result['total_files']} files ({result['total_size_mb']:.1f} MB){gitignore_info}[/]", - subtitle_align="right" - )) - else: - console.print(f"[bold red]❌ {result['error']}[/]") - continue - - elif mcp_command == "db": - if not mcp_args: - # Show current database or list all - if mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: - db = mcp_manager.databases[mcp_manager.selected_db_index] - console.print(f"[bold cyan]Currently using database #{mcp_manager.selected_db_index + 1}: {db['name']}[/]") - console.print(f"[dim]Path: {db['path']}[/]") - console.print(f"[dim]Tables: {', '.join(db['tables'])}[/]") - else: - console.print("[bold yellow]Not in database mode. 
Use '/mcp db ' to select a database.[/]") - - # Also show hint to list - console.print("\n[dim]Use '/mcp db list' to see all databases[/]") - continue - - if mcp_args == "list": - result = mcp_manager.list_databases() - - if result['success']: - if result['count'] == 0: - console.print("[bold yellow]No databases configured.[/]") - console.print("\n[bold cyan]Add a database with: /mcp add db ~/app/data.db[/]") - continue - - table = Table( - "No.", "Name", "Tables", "Size", "Status", - show_header=True, - header_style="bold magenta" - ) - - for db_info in result['databases']: - number = str(db_info['number']) - name = db_info['name'] - table_count = f"{db_info['table_count']} tables" - size = f"{db_info['size_mb']:.1f} MB" - - if db_info.get('warning'): - status = f"[red]{db_info['warning']}[/red]" - else: - status = "[green]✓[/green]" - - table.add_row(number, name, table_count, size, status) - - console.print(Panel( - table, - title="[bold green]MCP Databases[/]", - title_align="left", - subtitle=f"[dim]Total: {result['count']} database(s) | Use '/mcp db ' to select[/]", - subtitle_align="right" - )) - else: - console.print(f"[bold red]❌ {result['error']}[/]") - continue - - # Switch to database mode - try: - db_num = int(mcp_args) - result = mcp_manager.switch_mode("database", db_num) - - if result['success']: - db = result['database'] - console.print(f"\n[bold green]✓ {result['message']}[/]") - console.print(f"[dim cyan]Tables: {', '.join(db['tables'])}[/]") - console.print(f"\n[bold cyan]Available tools:[/] inspect_database, search_database, query_database") - console.print(f"\n[bold yellow]You can now ask questions about this database![/]") - console.print(f"[dim]Examples:[/]") - console.print(f" 'Show me all tables'") - console.print(f" 'Search for records mentioning X'") - console.print(f" 'How many rows in the users table?'") - console.print(f"\n[dim]Switch back to files: /mcp files[/]") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - except ValueError: - console.print(f"[bold red]Invalid database number: {mcp_args}[/]") - console.print(f"[bold yellow]Use '/mcp db list' to see available databases[/]") - continue - - elif mcp_command == "files": - result = mcp_manager.switch_mode("files") - - if result['success']: - console.print(f"[bold green]✓ {result['message']}[/]") - console.print(f"\n[bold cyan]Available tools:[/] read_file, list_directory, search_files") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - continue - - elif mcp_command == "gitignore": - if not mcp_args: - if mcp_manager.server: - status = "enabled" if mcp_manager.server.respect_gitignore else "disabled" - pattern_count = len(mcp_manager.server.gitignore_parser.patterns) - console.print(f"[bold blue].gitignore filtering: {status}[/]") - console.print(f"[dim cyan]Loaded {pattern_count} pattern(s) from .gitignore files[/]") - else: - console.print("[bold yellow]MCP is not enabled. 
Enable with /mcp on[/]") - continue - - if mcp_args.lower() == "on": - result = mcp_manager.toggle_gitignore(True) - if result['success']: - console.print(f"[bold green]✓ {result['message']}[/]") - console.print(f"[dim cyan]Using {result['pattern_count']} .gitignore pattern(s)[/]") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - elif mcp_args.lower() == "off": - result = mcp_manager.toggle_gitignore(False) - if result['success']: - console.print(f"[bold green]✓ {result['message']}[/]") - console.print("[bold yellow]Warning: All files will be visible, including those in .gitignore[/]") - else: - console.print(f"[bold red]❌ {result['error']}[/]") - else: - console.print("[bold yellow]Usage: /mcp gitignore on|off[/]") - continue - - elif mcp_command == "write": - if not mcp_args: - # Show current write mode status - if mcp_manager.enabled: - status = "[bold green]Enabled ⚠️[/]" if mcp_manager.write_enabled else "[dim]Disabled[/]" - console.print(f"[bold blue]Write Mode:[/] {status}") - if mcp_manager.write_enabled: - console.print("[dim cyan]LLM can create, modify, and delete files in allowed folders[/]") - else: - console.print("[dim cyan]Use '/mcp write on' to enable write operations[/]") - else: - console.print("[bold yellow]MCP is not enabled. Enable with /mcp on[/]") - continue - - if mcp_args.lower() == "on": - mcp_manager.enable_write() - elif mcp_args.lower() == "off": - mcp_manager.disable_write() - else: - console.print("[bold yellow]Usage: /mcp write on|off[/]") - continue - - elif mcp_command == "status": - result = mcp_manager.get_status() - - if result['success']: - status_color = "green" if result['enabled'] else "red" - status_text = "Active ✓" if result['enabled'] else "Inactive ✗" - - table = Table( - "Property", "Value", - show_header=True, - header_style="bold magenta" - ) - - table.add_row("Status", f"[{status_color}]{status_text}[/{status_color}]") - - # Mode info - mode_info = result['mode_info'] - table.add_row("Current Mode", mode_info['mode_display']) - if 'database' in mode_info: - db = mode_info['database'] - table.add_row("Database Path", db['path']) - table.add_row("Database Tables", ', '.join(db['tables'])) - - if result['uptime']: - table.add_row("Uptime", result['uptime']) - - table.add_row("", "") - table.add_row("[bold]Configuration[/]", "") - table.add_row("Allowed Folders", str(result['folder_count'])) - table.add_row("Databases", str(result['database_count'])) - table.add_row("Total Files Accessible", str(result['total_files'])) - table.add_row("Total Size", f"{result['total_size_mb']:.1f} MB") - table.add_row(".gitignore Filtering", result['gitignore_status']) - if result['gitignore_patterns'] > 0: - table.add_row(".gitignore Patterns", str(result['gitignore_patterns'])) - - # Write mode status - write_status = "[green]Enabled ⚠️[/]" if result['write_enabled'] else "[dim]Disabled[/]" - table.add_row("Write Mode", write_status) - - if result['enabled']: - table.add_row("", "") - table.add_row("[bold]Tools Available[/]", "") - for tool in result['tools_available']: - table.add_row(f" ✓ {tool}", "") - - stats = result['stats'] - if stats['total_calls'] > 0: - table.add_row("", "") - table.add_row("[bold]Session Stats[/]", "") - table.add_row("Total Tool Calls", str(stats['total_calls'])) - - if stats['reads'] > 0: - table.add_row("Files Read", str(stats['reads'])) - if stats['lists'] > 0: - table.add_row("Directories Listed", str(stats['lists'])) - if stats['searches'] > 0: - table.add_row("File Searches", str(stats['searches'])) - if 
stats['db_inspects'] > 0: - table.add_row("DB Inspections", str(stats['db_inspects'])) - if stats['db_searches'] > 0: - table.add_row("DB Searches", str(stats['db_searches'])) - if stats['db_queries'] > 0: - table.add_row("SQL Queries", str(stats['db_queries'])) - - if stats['last_used']: - try: - last_used = datetime.datetime.fromisoformat(stats['last_used']) - time_ago = datetime.datetime.now() - last_used - - if time_ago.seconds < 60: - time_str = f"{time_ago.seconds} seconds ago" - elif time_ago.seconds < 3600: - time_str = f"{time_ago.seconds // 60} minutes ago" - else: - time_str = f"{time_ago.seconds // 3600} hours ago" - - table.add_row("Last Used", time_str) - except: - pass - - console.print(Panel( - table, - title="[bold green]MCP Status[/]", - title_align="left" - )) - else: - console.print(f"[bold red]❌ {result['error']}[/]") - continue - - else: - console.print("[bold cyan]MCP (Model Context Protocol) Commands:[/]\n") - console.print(" [bold yellow]/mcp on[/] - Start MCP server") - console.print(" [bold yellow]/mcp off[/] - Stop MCP server") - console.print(" [bold yellow]/mcp status[/] - Show comprehensive MCP status") - console.print("") - console.print(" [bold cyan]FILE MODE (default):[/]") - console.print(" [bold yellow]/mcp add [/] - Add folder for file access") - console.print(" [bold yellow]/mcp remove [/] - Remove folder") - console.print(" [bold yellow]/mcp list[/] - List all folders") - console.print(" [bold yellow]/mcp files[/] - Switch to file mode") - console.print(" [bold yellow]/mcp gitignore [on|off][/] - Toggle .gitignore filtering") - console.print("") - console.print(" [bold cyan]DATABASE MODE:[/]") - console.print(" [bold yellow]/mcp add db [/] - Add SQLite database") - console.print(" [bold yellow]/mcp db list[/] - List all databases") - console.print(" [bold yellow]/mcp db [/] - Switch to database mode") - console.print(" [bold yellow]/mcp remove db [/]- Remove database") - console.print("") - console.print("[dim]Use '/help /mcp' for detailed command help[/]") - console.print("[dim]Use '/help mcp' for comprehensive MCP guide[/]") - continue - - # ============================================================ - # ALL OTHER COMMANDS - # ============================================================ - - if user_input.lower() == "/retry": - if not session_history: - console.print("[bold red]No history to retry.[/]") - app_logger.warning("Retry attempted with no history") - continue - last_prompt = session_history[-1]['prompt'] - console.print("[bold green]Retrying last prompt...[/]") - app_logger.info(f"Retrying prompt: {last_prompt[:100]}...") - user_input = last_prompt - - elif user_input.lower().startswith("/online"): - args = user_input[8:].strip() - if not args: - status = "enabled" if online_mode_enabled else "disabled" - default_status = "enabled" if DEFAULT_ONLINE_MODE == "on" else "disabled" - console.print(f"[bold blue]Online mode (web search) {status}.[/]") - console.print(f"[dim blue]Default setting: {default_status} (use '/config online on|off' to change)[/]") - if selected_model: - if supports_online_mode(selected_model): - console.print(f"[dim green]Current model '{selected_model['name']}' supports online mode.[/]") - else: - console.print(f"[dim yellow]Current model '{selected_model['name']}' does not support online mode.[/]") - continue - if args.lower() == "on": - if not selected_model: - console.print("[bold red]No model selected. 
Select a model first with '/model'.[/]") - continue - if not supports_online_mode(selected_model): - console.print(f"[bold red]Model '{selected_model['name']}' does not support online mode (web search).[/]") - console.print("[dim yellow]Online mode requires models with 'tools' parameter support.[/]") - app_logger.warning(f"Online mode activation failed - model {selected_model['id']} doesn't support it") - continue - online_mode_enabled = True - console.print("[bold green]Online mode enabled for this session. Model will use web search capabilities.[/]") - console.print(f"[dim blue]Effective model ID: {get_effective_model_id(selected_model['id'], True)}[/]") - app_logger.info(f"Online mode enabled for model {selected_model['id']}") - elif args.lower() == "off": - online_mode_enabled = False - console.print("[bold green]Online mode disabled for this session. Model will not use web search.[/]") - if selected_model: - console.print(f"[dim blue]Effective model ID: {selected_model['id']}[/]") - app_logger.info("Online mode disabled") - else: - console.print("[bold yellow]Usage: /online on|off (or /online to view status)[/]") - continue - - elif user_input.lower().startswith("/memory"): - args = user_input[8:].strip() - if not args: - status = "enabled" if conversation_memory_enabled else "disabled" - history_count = len(session_history) - memory_start_index if conversation_memory_enabled and memory_start_index < len(session_history) else 0 - console.print(f"[bold blue]Conversation memory {status}.[/]") - if conversation_memory_enabled: - console.print(f"[dim blue]Tracking {history_count} message(s) since memory enabled.[/]") - else: - console.print(f"[dim yellow]Memory disabled. Each request is independent (saves tokens/cost).[/]") - continue - if args.lower() == "on": - conversation_memory_enabled = True - memory_start_index = len(session_history) - console.print("[bold green]Conversation memory enabled. Will remember conversations from this point forward.[/]") - console.print(f"[dim blue]Memory will track messages starting from index {memory_start_index}.[/]") - app_logger.info(f"Conversation memory enabled at index {memory_start_index}") - elif args.lower() == "off": - conversation_memory_enabled = False - console.print("[bold green]Conversation memory disabled. API calls will not include history (lower cost).[/]") - console.print(f"[dim yellow]Note: Messages are still saved locally but not sent to API.[/]") - app_logger.info("Conversation memory disabled") - else: - console.print("[bold yellow]Usage: /memory on|off (or /memory to view status)[/]") - continue - - elif user_input.lower().startswith("/paste"): - optional_prompt = user_input[7:].strip() - - try: - clipboard_content = pyperclip.paste() - except Exception as e: - console.print(f"[bold red]Failed to access clipboard: {e}[/]") - app_logger.error(f"Clipboard access error: {e}") - continue - - if not clipboard_content or not clipboard_content.strip(): - console.print("[bold red]Clipboard is empty.[/]") - app_logger.warning("Paste attempted with empty clipboard") - continue - - try: - clipboard_content.encode('utf-8') - - preview_lines = clipboard_content.split('\n')[:10] - preview_text = '\n'.join(preview_lines) - if len(clipboard_content.split('\n')) > 10: - preview_text += "\n... 
(content truncated for preview)" - - char_count = len(clipboard_content) - line_count = len(clipboard_content.split('\n')) - - console.print(Panel( - preview_text, - title=f"[bold cyan]📋 Clipboard Content Preview ({char_count} chars, {line_count} lines)[/]", - title_align="left", - border_style="cyan" - )) - - if optional_prompt: - final_prompt = f"{optional_prompt}\n\n```\n{clipboard_content}\n```" - console.print(f"[dim blue]Sending with prompt: '{optional_prompt}'[/]") - else: - final_prompt = clipboard_content - console.print("[dim blue]Sending clipboard content without additional prompt[/]") - - user_input = final_prompt - app_logger.info(f"Pasted content from clipboard: {char_count} chars, {line_count} lines, with prompt: {bool(optional_prompt)}") - - except UnicodeDecodeError: - console.print("[bold red]Clipboard contains non-text (binary) data. Only plain text is supported.[/]") - app_logger.error("Paste failed - clipboard contains binary data") - continue - except Exception as e: - console.print(f"[bold red]Error processing clipboard content: {e}[/]") - app_logger.error(f"Clipboard processing error: {e}") - continue - - elif user_input.lower().startswith("/export"): - args = user_input[8:].strip().split(maxsplit=1) - if len(args) != 2: - console.print("[bold red]Usage: /export [/]") - console.print("[bold yellow]Formats: md (Markdown), json (JSON), html (HTML)[/]") - console.print("[bold yellow]Example: /export md my_conversation.md[/]") - continue - - export_format = args[0].lower() - filename = args[1] - - if not session_history: - console.print("[bold red]No conversation history to export.[/]") - continue - - if export_format not in ['md', 'json', 'html']: - console.print("[bold red]Invalid format. Use: md, json, or html[/]") - continue - - try: - if export_format == 'md': - content = export_as_markdown(session_history, session_system_prompt) - elif export_format == 'json': - content = export_as_json(session_history, session_system_prompt) - elif export_format == 'html': - content = export_as_html(session_history, session_system_prompt) - - export_path = Path(filename).expanduser() - with open(export_path, 'w', encoding='utf-8') as f: - f.write(content) - - console.print(f"[bold green]✅ Conversation exported to: {export_path.absolute()}[/]") - console.print(f"[dim blue]Format: {export_format.upper()} | Messages: {len(session_history)} | Size: {len(content)} bytes[/]") - app_logger.info(f"Conversation exported as {export_format} to {export_path} ({len(session_history)} messages)") - except Exception as e: - console.print(f"[bold red]Export failed: {e}[/]") - app_logger.error(f"Export error: {e}") - continue - - elif user_input.lower().startswith("/save"): - args = user_input[6:].strip() - if not args: - console.print("[bold red]Usage: /save [/]") - continue - if not session_history: - console.print("[bold red]No history to save.[/]") - continue - save_conversation(args, session_history) - console.print(f"[bold green]Conversation saved as '{args}'.[/]") - app_logger.info(f"Conversation saved as '{args}' with {len(session_history)} messages") - continue - - elif user_input.lower().startswith("/load"): - args = user_input[6:].strip() - if not args: - console.print("[bold red]Usage: /load [/]") - console.print("[bold yellow]Tip: Use /list to see numbered conversations[/]") - continue - - conversation_name = None - if args.isdigit(): - conv_number = int(args) - if saved_conversations_cache and 1 <= conv_number <= len(saved_conversations_cache): - conversation_name = 
saved_conversations_cache[conv_number - 1]['name'] - console.print(f"[bold cyan]Loading conversation #{conv_number}: '{conversation_name}'[/]") - else: - console.print(f"[bold red]Invalid conversation number: {conv_number}[/]") - console.print(f"[bold yellow]Use /list to see available conversations (1-{len(saved_conversations_cache) if saved_conversations_cache else 0})[/]") - continue - else: - conversation_name = args - - loaded_data = load_conversation(conversation_name) - if not loaded_data: - console.print(f"[bold red]Conversation '{conversation_name}' not found.[/]") - app_logger.warning(f"Load failed for '{conversation_name}' - not found") - continue - session_history = loaded_data - current_index = len(session_history) - 1 - if conversation_memory_enabled: - memory_start_index = 0 - - # Recalculate session totals from loaded message history - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - - for msg in session_history: - # Handle both old format (no cost data) and new format (with cost data) - total_input_tokens += msg.get('prompt_tokens', 0) - total_output_tokens += msg.get('completion_tokens', 0) - total_cost += msg.get('msg_cost', 0.0) - message_count += 1 - - console.print(f"[bold green]Conversation '{conversation_name}' loaded with {len(session_history)} messages.[/]") - if total_cost > 0: - console.print(f"[dim cyan]Restored session totals: {total_input_tokens + total_output_tokens} tokens, ${total_cost:.4f} cost[/]") - app_logger.info(f"Conversation '{conversation_name}' loaded with {len(session_history)} messages, restored cost: ${total_cost:.4f}") - continue - - elif user_input.lower().startswith("/delete"): - args = user_input[8:].strip() - if not args: - console.print("[bold red]Usage: /delete [/]") - console.print("[bold yellow]Tip: Use /list to see numbered conversations[/]") - continue - - conversation_name = None - if args.isdigit(): - conv_number = int(args) - if saved_conversations_cache and 1 <= conv_number <= len(saved_conversations_cache): - conversation_name = saved_conversations_cache[conv_number - 1]['name'] - console.print(f"[bold cyan]Deleting conversation #{conv_number}: '{conversation_name}'[/]") - else: - console.print(f"[bold red]Invalid conversation number: {conv_number}[/]") - console.print(f"[bold yellow]Use /list to see available conversations (1-{len(saved_conversations_cache) if saved_conversations_cache else 0})[/]") - continue - else: - conversation_name = args - - try: - confirm = typer.confirm(f"Delete conversation '{conversation_name}'? 
This cannot be undone.", default=False) - if not confirm: - console.print("[bold yellow]Deletion cancelled.[/]") - continue - except (EOFError, KeyboardInterrupt): - console.print("\n[bold yellow]Deletion cancelled.[/]") - continue - - deleted_count = delete_conversation(conversation_name) - if deleted_count > 0: - console.print(f"[bold green]Conversation '{conversation_name}' deleted ({deleted_count} version(s) removed).[/]") - app_logger.info(f"Conversation '{conversation_name}' deleted - {deleted_count} version(s)") - if saved_conversations_cache: - saved_conversations_cache = [c for c in saved_conversations_cache if c['name'] != conversation_name] - else: - console.print(f"[bold red]Conversation '{conversation_name}' not found.[/]") - app_logger.warning(f"Delete failed for '{conversation_name}' - not found") - continue - - elif user_input.lower() == "/list": - conversations = list_conversations() - if not conversations: - console.print("[bold yellow]No saved conversations found.[/]") - app_logger.info("User viewed conversation list - empty") - saved_conversations_cache = [] - continue - - saved_conversations_cache = conversations - - table = Table("No.", "Name", "Messages", "Last Saved", show_header=True, header_style="bold magenta") - for idx, conv in enumerate(conversations, 1): - try: - dt = datetime.datetime.fromisoformat(conv['timestamp']) - formatted_time = dt.strftime('%Y-%m-%d %H:%M:%S') - except: - formatted_time = conv['timestamp'] - - table.add_row( - str(idx), - conv['name'], - str(conv['message_count']), - formatted_time - ) - - console.print(Panel(table, title=f"[bold green]Saved Conversations ({len(conversations)} total)[/]", title_align="left", subtitle="[dim]Use /load or /delete to manage conversations[/]", subtitle_align="right")) - app_logger.info(f"User viewed conversation list - {len(conversations)} conversations") - continue - - elif user_input.lower() == "/prev": - if not session_history or current_index <= 0: - console.print("[bold red]No previous response.[/]") - continue - current_index -= 1 - prev_response = session_history[current_index]['response'] - md = Markdown(prev_response) - console.print(Panel(md, title=f"[bold green]Previous Response ({current_index + 1}/{len(session_history)})[/]", title_align="left")) - app_logger.debug(f"Viewed previous response at index {current_index}") - continue - - elif user_input.lower() == "/next": - if not session_history or current_index >= len(session_history) - 1: - console.print("[bold red]No next response.[/]") - continue - current_index += 1 - next_response = session_history[current_index]['response'] - md = Markdown(next_response) - console.print(Panel(md, title=f"[bold green]Next Response ({current_index + 1}/{len(session_history)})[/]", title_align="left")) - app_logger.debug(f"Viewed next response at index {current_index}") - continue - - elif user_input.lower() == "/stats": - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - credits_left = credits['credits_left'] if credits else "Unknown" - stats = f"Total Input: {total_input_tokens}, Total Output: {total_output_tokens}, Total Tokens: {total_input_tokens + total_output_tokens}, Total Cost: ${total_cost:.4f}, Avg Cost/Message: ${total_cost / message_count:.4f}" if message_count > 0 else "No messages." 
- table = Table("Metric", "Value", show_header=True, header_style="bold magenta") - table.add_row("Session Stats", stats) - table.add_row("Credits Left", credits_left) - console.print(Panel(table, title="[bold green]Session Cost Summary[/]", title_align="left")) - app_logger.info(f"User viewed stats: {stats}") - - warnings = check_credit_alerts(credits) - if warnings: - warning_text = '|'.join(warnings) - console.print(f"[bold red]⚠️ {warning_text}[/]") - app_logger.warning(f"Warnings in stats: {warning_text}") - continue - - elif user_input.lower().startswith("/middleout"): - args = user_input[11:].strip() - if not args: - console.print(f"[bold blue]Middle-out transform {'enabled' if middle_out_enabled else 'disabled'}.[/]") - continue - if args.lower() == "on": - middle_out_enabled = True - console.print("[bold green]Middle-out transform enabled.[/]") - elif args.lower() == "off": - middle_out_enabled = False - console.print("[bold green]Middle-out transform disabled.[/]") - else: - console.print("[bold yellow]Usage: /middleout on|off (or /middleout to view status)[/]") - continue - - elif user_input.lower() == "/reset": - try: - confirm = typer.confirm("Reset conversation context? This clears history and prompt.", default=False) - if not confirm: - console.print("[bold yellow]Reset cancelled.[/]") - continue - except (EOFError, KeyboardInterrupt): - console.print("\n[bold yellow]Reset cancelled.[/]") - continue - session_history = [] - current_index = -1 - session_system_prompt = "" - memory_start_index = 0 - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - console.print("[bold green]Conversation context reset. Totals cleared.[/]") - app_logger.info("Conversation context reset by user - all totals reset to 0") - continue - - elif user_input.lower().startswith("/info"): - args = user_input[6:].strip() - if not args: - if not selected_model: - console.print("[bold red]No model selected and no model ID provided. 
Use '/model' first or '/info '.[/]") - continue - model_to_show = selected_model - else: - model_to_show = next((m for m in models_data if m["id"] == args or m.get("canonical_slug") == args or args.lower() in m["name"].lower()), None) - if not model_to_show: - console.print(f"[bold red]Model '{args}' not found.[/]") - continue - - pricing = model_to_show.get("pricing", {}) - architecture = model_to_show.get("architecture", {}) - supported_params = ", ".join(model_to_show.get("supported_parameters", [])) or "None" - top_provider = model_to_show.get("top_provider", {}) - - table = Table("Property", "Value", show_header=True, header_style="bold magenta") - table.add_row("ID", model_to_show["id"]) - table.add_row("Name", model_to_show["name"]) - table.add_row("Description", model_to_show.get("description", "N/A")) - table.add_row("Context Length", str(model_to_show.get("context_length", "N/A"))) - table.add_row("Online Support", "Yes" if supports_online_mode(model_to_show) else "No") - table.add_row("Pricing - Prompt ($/M tokens)", pricing.get("prompt", "N/A")) - table.add_row("Pricing - Completion ($/M tokens)", pricing.get("completion", "N/A")) - table.add_row("Pricing - Request ($)", pricing.get("request", "N/A")) - table.add_row("Pricing - Image ($)", pricing.get("image", "N/A")) - table.add_row("Input Modalities", ", ".join(architecture.get("input_modalities", [])) or "None") - table.add_row("Output Modalities", ", ".join(architecture.get("output_modalities", [])) or "None") - table.add_row("Supported Parameters", supported_params) - table.add_row("Top Provider Context Length", str(top_provider.get("context_length", "N/A"))) - table.add_row("Max Completion Tokens", str(top_provider.get("max_completion_tokens", "N/A"))) - table.add_row("Moderated", "Yes" if top_provider.get("is_moderated", False) else "No") - - console.print(Panel(table, title=f"[bold green]Model Info: {model_to_show['name']}[/]", title_align="left")) - continue - - elif user_input.startswith("/model"): - app_logger.info("User initiated model selection") - args = user_input[7:].strip() - search_term = args if args else "" - filtered_models = text_models - if search_term: - filtered_models = [m for m in text_models if search_term.lower() in m["name"].lower() or search_term.lower() in m["id"].lower()] - if not filtered_models: - console.print(f"[bold red]No models match '{search_term}'. 
Try '/model'.[/]") - continue - - table = Table("No.", "Name", "ID", "Image", "Online", "Tools", show_header=True, header_style="bold magenta") - for i, model in enumerate(filtered_models, 1): - image_support = "[green]✓[/green]" if has_image_capability(model) else "[red]✗[/red]" - online_support = "[green]✓[/green]" if supports_online_mode(model) else "[red]✗[/red]" - tools_support = "[green]✓[/green]" if supports_function_calling(model) else "[red]✗[/red]" - table.add_row(str(i), model["name"], model["id"], image_support, online_support, tools_support) - - title = f"[bold green]Available Models ({'All' if not search_term else f'Search: {search_term}'})[/]" - display_paginated_table(table, title) - - while True: - try: - choice = int(typer.prompt("Enter model number (or 0 to cancel)")) - if choice == 0: - break - if 1 <= choice <= len(filtered_models): - selected_model = filtered_models[choice - 1] - - if supports_online_mode(selected_model) and DEFAULT_ONLINE_MODE == "on": - online_mode_enabled = True - console.print(f"[bold cyan]Selected: {selected_model['name']} ({selected_model['id']})[/]") - console.print("[dim cyan]✓ Online mode auto-enabled (default setting). Use '/online off' to disable for this session.[/]") - else: - online_mode_enabled = False - console.print(f"[bold cyan]Selected: {selected_model['name']} ({selected_model['id']})[/]") - if supports_online_mode(selected_model): - console.print("[dim green]✓ This model supports online mode. Use '/online on' to enable web search.[/]") - - app_logger.info(f"Model selected: {selected_model['name']} ({selected_model['id']}), Online: {online_mode_enabled}") - break - console.print("[bold red]Invalid choice. Try again.[/]") - except ValueError: - console.print("[bold red]Invalid input. Enter a number.[/]") - continue - - elif user_input.startswith("/maxtoken"): - args = user_input[10:].strip() - if not args: - console.print(f"[bold blue]Current session max tokens: {session_max_token}[/]") - continue - try: - new_limit = int(args) - if new_limit < 1: - console.print("[bold red]Session token limit must be at least 1.[/]") - continue - if new_limit > MAX_TOKEN: - console.print(f"[bold yellow]Cannot exceed stored max ({MAX_TOKEN}). Capping.[/]") - new_limit = MAX_TOKEN - session_max_token = new_limit - console.print(f"[bold green]Session max tokens set to: {session_max_token}[/]") - except ValueError: - console.print("[bold red]Invalid token limit. 
Provide a positive integer.[/]") - continue - - elif user_input.startswith("/system"): - args = user_input[8:].strip() - if not args: - if session_system_prompt: - console.print(f"[bold blue]Current session system prompt:[/] {session_system_prompt}") - else: - console.print("[bold blue]No session system prompt set.[/]") - continue - if args.lower() == "clear": - session_system_prompt = "" - console.print("[bold green]Session system prompt cleared.[/]") - else: - session_system_prompt = args - console.print(f"[bold green]Session system prompt set to: {session_system_prompt}[/]") - continue - - elif user_input.startswith("/config"): - args = user_input[8:].strip().lower() - update = check_for_updates(version) - - if args == "api": - try: - new_api_key = typer.prompt("Enter new API key") - if new_api_key.strip(): - set_config('api_key', new_api_key.strip()) - API_KEY = new_api_key.strip() - client = OpenRouter(api_key=API_KEY) - console.print("[bold green]API key updated![/]") - else: - console.print("[bold yellow]No change.[/]") - except Exception as e: - console.print(f"[bold red]Error updating API key: {e}[/]") - elif args == "url": - try: - new_url = typer.prompt("Enter new base URL") - if new_url.strip(): - set_config('base_url', new_url.strip()) - OPENROUTER_BASE_URL = new_url.strip() - console.print("[bold green]Base URL updated![/]") - else: - console.print("[bold yellow]No change.[/]") - except Exception as e: - console.print(f"[bold red]Error updating URL: {e}[/]") - elif args.startswith("costwarning"): - sub_args = args[11:].strip() - if not sub_args: - console.print(f"[bold blue]Stored cost warning threshold: ${COST_WARNING_THRESHOLD:.4f}[/]") - continue - try: - new_threshold = float(sub_args) - if new_threshold < 0: - console.print("[bold red]Cost warning threshold must be >= 0.[/]") - continue - set_config(HIGH_COST_WARNING, str(new_threshold)) - COST_WARNING_THRESHOLD = new_threshold - console.print(f"[bold green]Cost warning threshold set to ${COST_WARNING_THRESHOLD:.4f}[/]") - except ValueError: - console.print("[bold red]Invalid cost threshold. 
Provide a valid number.[/]") - elif args.startswith("stream"): - sub_args = args[7:].strip() - if sub_args in ["on", "off"]: - set_config('stream_enabled', sub_args) - STREAM_ENABLED = sub_args - console.print(f"[bold green]Streaming {'enabled' if sub_args == 'on' else 'disabled'}.[/]") - else: - console.print("[bold yellow]Usage: /config stream on|off[/]") - elif args.startswith("loglevel"): - sub_args = args[8:].strip() - if not sub_args: - console.print(f"[bold blue]Current log level: {LOG_LEVEL_STR.upper()}[/]") - console.print(f"[dim yellow]Valid levels: debug, info, warning, error, critical[/]") - continue - if sub_args.lower() in VALID_LOG_LEVELS: - if set_log_level(sub_args): - set_config('log_level', sub_args.lower()) - console.print(f"[bold green]Log level set to: {sub_args.upper()}[/]") - app_logger.info(f"Log level changed to {sub_args.upper()}") - else: - console.print(f"[bold red]Failed to set log level.[/]") - else: - console.print(f"[bold red]Invalid log level: {sub_args}[/]") - console.print(f"[bold yellow]Valid levels: debug, info, warning, error, critical[/]") - elif args.startswith("log"): - sub_args = args[4:].strip() - if not sub_args: - console.print(f"[bold blue]Current log file size limit: {LOG_MAX_SIZE_MB} MB[/]") - console.print(f"[bold blue]Log backup count: {LOG_BACKUP_COUNT} files[/]") - console.print(f"[bold blue]Log level: {LOG_LEVEL_STR.upper()}[/]") - console.print(f"[dim yellow]Total max disk usage: ~{LOG_MAX_SIZE_MB * (LOG_BACKUP_COUNT + 1)} MB[/]") - continue - try: - new_size_mb = int(sub_args) - if new_size_mb < 1: - console.print("[bold red]Log size must be at least 1 MB.[/]") - continue - if new_size_mb > 100: - console.print("[bold yellow]Warning: Log size > 100MB. Capping at 100MB.[/]") - new_size_mb = 100 - set_config('log_max_size_mb', str(new_size_mb)) - LOG_MAX_SIZE_MB = new_size_mb - - app_logger = reload_logging_config() - - console.print(f"[bold green]Log size limit set to {new_size_mb} MB and applied immediately.[/]") - console.print(f"[dim cyan]Log file rotated if it exceeded the new limit.[/]") - app_logger.info(f"Log size limit updated to {new_size_mb} MB and reloaded") - except ValueError: - console.print("[bold red]Invalid size. Provide a number in MB.[/]") - elif args.startswith("online"): - sub_args = args[7:].strip() - if not sub_args: - current_default = "enabled" if DEFAULT_ONLINE_MODE == "on" else "disabled" - console.print(f"[bold blue]Default online mode: {current_default}[/]") - console.print("[dim yellow]This sets the default for new models. Use '/online on|off' to override in current session.[/]") - continue - if sub_args in ["on", "off"]: - set_config('default_online_mode', sub_args) - DEFAULT_ONLINE_MODE = sub_args - console.print(f"[bold green]Default online mode {'enabled' if sub_args == 'on' else 'disabled'}.[/]") - console.print("[dim blue]Note: This affects new model selections. 
Current session unchanged.[/]") - app_logger.info(f"Default online mode set to {sub_args}") - else: - console.print("[bold yellow]Usage: /config online on|off[/]") - elif args.startswith("maxtoken"): - sub_args = args[9:].strip() - if not sub_args: - console.print(f"[bold blue]Stored max token limit: {MAX_TOKEN}[/]") - continue - try: - new_max = int(sub_args) - if new_max < 1: - console.print("[bold red]Max token limit must be at least 1.[/]") - continue - if new_max > 1000000: - console.print("[bold yellow]Capped at 1M for safety.[/]") - new_max = 1000000 - set_config('max_token', str(new_max)) - MAX_TOKEN = new_max - if session_max_token > MAX_TOKEN: - session_max_token = MAX_TOKEN - console.print(f"[bold yellow]Session adjusted to {session_max_token}.[/]") - console.print(f"[bold green]Stored max token limit updated to: {MAX_TOKEN}[/]") - except ValueError: - console.print("[bold red]Invalid token limit.[/]") - elif args.startswith("model"): - sub_args = args[6:].strip() - search_term = sub_args if sub_args else "" - filtered_models = text_models - if search_term: - filtered_models = [m for m in text_models if search_term.lower() in m["name"].lower() or search_term.lower() in m["id"].lower()] - if not filtered_models: - console.print(f"[bold red]No models match '{search_term}'. Try without search.[/]") - continue - - table = Table("No.", "Name", "ID", "Image", "Online", "Tools", show_header=True, header_style="bold magenta") - for i, model in enumerate(filtered_models, 1): - image_support = "[green]✓[/green]" if has_image_capability(model) else "[red]✗[/red]" - online_support = "[green]✓[/green]" if supports_online_mode(model) else "[red]✗[/red]" - tools_support = "[green]✓[/green]" if supports_function_calling(model) else "[red]✗[/red]" - table.add_row(str(i), model["name"], model["id"], image_support, online_support, tools_support) - - title = f"[bold green]Available Models for Default ({'All' if not search_term else f'Search: {search_term}'})[/]" - display_paginated_table(table, title) - - while True: - try: - choice = int(typer.prompt("Enter model number (or 0 to cancel)")) - if choice == 0: - break - if 1 <= choice <= len(filtered_models): - default_model = filtered_models[choice - 1] - set_config('default_model', default_model["id"]) - current_name = selected_model['name'] if selected_model else "None" - console.print(f"[bold cyan]Default model set to: {default_model['name']} ({default_model['id']}). Current unchanged: {current_name}[/]") - break - console.print("[bold red]Invalid choice. Try again.[/]") - except ValueError: - console.print("[bold red]Invalid input. 
Enter a number.[/]") - else: - DEFAULT_MODEL_ID = get_config('default_model') - memory_status = "Enabled" if conversation_memory_enabled else "Disabled" - memory_tracked = len(session_history) - memory_start_index if conversation_memory_enabled else 0 - - mcp_status = "Enabled" if mcp_manager.enabled else "Disabled" - mcp_folders = len(mcp_manager.allowed_folders) - mcp_databases = len(mcp_manager.databases) - mcp_mode = mcp_manager.mode - gitignore_status = "enabled" if (mcp_manager.server and mcp_manager.server.respect_gitignore) else "disabled" - gitignore_patterns = len(mcp_manager.server.gitignore_parser.patterns) if mcp_manager.server else 0 - - table = Table("Setting", "Value", show_header=True, header_style="bold magenta", width=console.width - 10) - table.add_row("API Key", API_KEY or "[Not set]") - table.add_row("Base URL", OPENROUTER_BASE_URL or "[Not set]") - table.add_row("DB Path", str(database) or "[Not set]") - table.add_row("Logfile", str(log_file) or "[Not set]") - table.add_row("Log Size Limit", f"{LOG_MAX_SIZE_MB} MB") - table.add_row("Log Backups", str(LOG_BACKUP_COUNT)) - table.add_row("Log Level", LOG_LEVEL_STR.upper()) - table.add_row("Streaming", "Enabled" if STREAM_ENABLED == "on" else "Disabled") - table.add_row("Default Model", DEFAULT_MODEL_ID or "[Not set]") - table.add_row("Current Model", "[Not set]" if selected_model is None else str(selected_model["name"])) - table.add_row("Default Online Mode", "Enabled" if DEFAULT_ONLINE_MODE == "on" else "Disabled") - table.add_row("Session Online Mode", "Enabled" if online_mode_enabled else "Disabled") - table.add_row("MCP Status", f"{mcp_status} ({mcp_folders} folders, {mcp_databases} DBs)") - table.add_row("MCP Mode", f"{mcp_mode}") - table.add_row("MCP .gitignore", f"{gitignore_status} ({gitignore_patterns} patterns)") - table.add_row("Max Token", str(MAX_TOKEN)) - table.add_row("Session Token", "[Not set]" if session_max_token == 0 else str(session_max_token)) - table.add_row("Session System Prompt", session_system_prompt or "[Not set]") - table.add_row("Cost Warning Threshold", f"${COST_WARNING_THRESHOLD:.4f}") - table.add_row("Middle-out Transform", "Enabled" if middle_out_enabled else "Disabled") - table.add_row("Conversation Memory", f"{memory_status} ({memory_tracked} tracked)" if conversation_memory_enabled else memory_status) - table.add_row("History Size", str(len(session_history))) - table.add_row("Current History Index", str(current_index) if current_index >= 0 else "[None]") - table.add_row("App Name", APP_NAME) - table.add_row("App URL", APP_URL) - - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits: - table.add_row("Total Credits", credits['total_credits']) - table.add_row("Used Credits", credits['used_credits']) - table.add_row("Credits Left", credits['credits_left']) - else: - table.add_row("Total Credits", "[Unavailable - Check API key]") - table.add_row("Used Credits", "[Unavailable - Check API key]") - table.add_row("Credits Left", "[Unavailable - Check API key]") - - console.print(Panel(table, title="[bold green]Current Configurations[/]", title_align="left", subtitle="%s" %(update), subtitle_align="right")) - continue - - if user_input.lower() == "/credits": - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits: - console.print(f"[bold green]Credits left: {credits['credits_left']}[/]") - alerts = check_credit_alerts(credits) - if alerts: - for alert in alerts: - console.print(f"[bold red]⚠️ {alert}[/]") - else: - console.print("[bold red]Unable to fetch credits. 
Check your API key or network.[/]") - continue - - if user_input.lower() == "/clear" or user_input.lower() == "/cl": - clear_screen() - DEFAULT_MODEL_ID = get_config('default_model') - token_value = session_max_token if session_max_token != 0 else "Not set" - console.print(f"[bold cyan]Token limits: Max= {MAX_TOKEN}, Session={token_value}[/]") - console.print("[bold blue]Active model[/] [bold red]%s[/]" %(str(selected_model["name"]) if selected_model else "None")) - if online_mode_enabled: - console.print("[bold cyan]Online mode: Enabled (web search active)[/]") - if mcp_manager.enabled: - gitignore_status = "on" if (mcp_manager.server and mcp_manager.server.respect_gitignore) else "off" - mode_display = "Files" if mcp_manager.mode == "files" else f"DB #{mcp_manager.selected_db_index + 1}" - console.print(f"[bold cyan]MCP: Enabled (Mode: {mode_display}, {len(mcp_manager.allowed_folders)} folders, {len(mcp_manager.databases)} DBs, .gitignore {gitignore_status})[/]") - continue - - if user_input.lower().startswith("/help"): - args = user_input[6:].strip() - - if args: - show_command_help(args) - continue - - help_table = Table("Command", "Description", "Example", show_header=True, header_style="bold cyan", width=console.width - 10) - - # SESSION COMMANDS - help_table.add_row( - "[bold yellow]━━━ SESSION COMMANDS ━━━[/]", - "", - "" - ) - help_table.add_row( - "/clear or /cl", - "Clear terminal screen. Keyboard shortcut: [bold]Ctrl+L[/]", - "/clear\n/cl" - ) - help_table.add_row( - "/help [command|topic]", - "Show help menu or detailed help. Use '/help mcp' for MCP guide.", - "/help\n/help mcp" - ) - help_table.add_row( - "/memory [on|off]", - "Toggle conversation memory. ON sends history, OFF is stateless (saves cost).", - "/memory\n/memory off" - ) - help_table.add_row( - "/next", - "View next response in history.", - "/next" - ) - help_table.add_row( - "/online [on|off]", - "Enable/disable online mode (web search) for session.", - "/online on\n/online off" - ) - help_table.add_row( - "/paste [prompt]", - "Paste clipboard text/code. 
Optional prompt.", - "/paste\n/paste Explain" - ) - help_table.add_row( - "/prev", - "View previous response in history.", - "/prev" - ) - help_table.add_row( - "/reset", - "Clear history and reset system prompt (requires confirmation).", - "/reset" - ) - help_table.add_row( - "/retry", - "Resend last prompt.", - "/retry" - ) - - # MCP COMMANDS - help_table.add_row( - "[bold yellow]━━━ MCP (FILE & DATABASE ACCESS) ━━━[/]", - "", - "" - ) - help_table.add_row( - "/mcp on", - "Start MCP server for file and database access.", - "/mcp on" - ) - help_table.add_row( - "/mcp off", - "Stop MCP server.", - "/mcp off" - ) - help_table.add_row( - "/mcp status", - "Show mode, stats, folders, databases, and .gitignore status.", - "/mcp status" - ) - help_table.add_row( - "/mcp add ", - "Add folder for file access (auto-loads .gitignore).", - "/mcp add ~/Documents" - ) - help_table.add_row( - "/mcp add db ", - "Add SQLite database for querying.", - "/mcp add db ~/app.db" - ) - help_table.add_row( - "/mcp list", - "List all allowed folders with stats.", - "/mcp list" - ) - help_table.add_row( - "/mcp db list", - "List all databases with details.", - "/mcp db list" - ) - help_table.add_row( - "/mcp db ", - "Switch to database mode (select DB by number).", - "/mcp db 1" - ) - help_table.add_row( - "/mcp files", - "Switch to file mode (default).", - "/mcp files" - ) - help_table.add_row( - "/mcp remove ", - "Remove folder by path or number.", - "/mcp remove 2" - ) - help_table.add_row( - "/mcp remove db ", - "Remove database by number.", - "/mcp remove db 1" - ) - help_table.add_row( - "/mcp gitignore [on|off]", - "Toggle .gitignore filtering (respects project excludes).", - "/mcp gitignore on" - ) - help_table.add_row( - "/mcp write [on|off]", - "Enable/disable write mode (create, edit, delete files). Requires confirmation.", - "/mcp write on" - ) - - # MODEL COMMANDS - help_table.add_row( - "[bold yellow]━━━ MODEL COMMANDS ━━━[/]", - "", - "" - ) - help_table.add_row( - "/info [model_id]", - "Display model details (pricing, capabilities, context).", - "/info\n/info gpt-4o" - ) - help_table.add_row( - "/model [search]", - "Select/change model. Shows image and online capabilities.", - "/model\n/model gpt" - ) - - # CONFIGURATION - help_table.add_row( - "[bold yellow]━━━ CONFIGURATION ━━━[/]", - "", - "" - ) - help_table.add_row( - "/config", - "View all configurations.", - "/config" - ) - help_table.add_row( - "/config api", - "Set/update API key.", - "/config api" - ) - help_table.add_row( - "/config costwarning [val]", - "Set cost warning threshold (USD).", - "/config costwarning 0.05" - ) - help_table.add_row( - "/config log [size_mb]", - "Set log file size limit. 
Takes effect immediately.", - "/config log 20" - ) - help_table.add_row( - "/config loglevel [level]", - "Set log verbosity (debug/info/warning/error/critical).", - "/config loglevel debug" - ) - help_table.add_row( - "/config maxtoken [val]", - "Set stored max token limit.", - "/config maxtoken 50000" - ) - help_table.add_row( - "/config model [search]", - "Set default startup model.", - "/config model gpt" - ) - help_table.add_row( - "/config online [on|off]", - "Set default online mode for new models.", - "/config online on" - ) - help_table.add_row( - "/config stream [on|off]", - "Enable/disable response streaming.", - "/config stream off" - ) - help_table.add_row( - "/config url", - "Set/update base URL.", - "/config url" - ) - - # TOKEN & SYSTEM - help_table.add_row( - "[bold yellow]━━━ TOKEN & SYSTEM ━━━[/]", - "", - "" - ) - help_table.add_row( - "/maxtoken [value]", - "Set temporary session token limit.", - "/maxtoken 2000" - ) - help_table.add_row( - "/middleout [on|off]", - "Enable/disable prompt compression for large contexts.", - "/middleout on" - ) - help_table.add_row( - "/system [prompt|clear]", - "Set session system prompt. Use 'clear' to reset.", - "/system You are a Python expert" - ) - - # CONVERSATION MGMT - help_table.add_row( - "[bold yellow]━━━ CONVERSATION MGMT ━━━[/]", - "", - "" - ) - help_table.add_row( - "/delete ", - "Delete saved conversation (requires confirmation).", - "/delete my_chat\n/delete 3" - ) - help_table.add_row( - "/export ", - "Export conversation (md/json/html).", - "/export md notes.md" - ) - help_table.add_row( - "/list", - "List saved conversations with numbers and stats.", - "/list" - ) - help_table.add_row( - "/load ", - "Load saved conversation.", - "/load my_chat\n/load 3" - ) - help_table.add_row( - "/save ", - "Save current conversation.", - "/save my_chat" - ) - - # MONITORING - help_table.add_row( - "[bold yellow]━━━ MONITORING & STATS ━━━[/]", - "", - "" - ) - help_table.add_row( - "/credits", - "Display OpenRouter credits with alerts.", - "/credits" - ) - help_table.add_row( - "/stats", - "Display session stats: tokens, cost, credits.", - "/stats" - ) - - # INPUT METHODS - help_table.add_row( - "[bold yellow]━━━ INPUT METHODS ━━━[/]", - "", - "" - ) - help_table.add_row( - "@/path/to/file", - "Attach files: images (PNG, JPG), PDFs, code files.", - "Debug @script.py\nAnalyze @image.png" - ) - help_table.add_row( - "Clipboard paste", - "Use /paste to send clipboard content.", - "/paste\n/paste Explain" - ) - help_table.add_row( - "// escape", - "Start with // to send literal / character.", - "//help sends '/help' as text" - ) - - # EXIT - help_table.add_row( - "[bold yellow]━━━ EXIT ━━━[/]", - "", - "" - ) - help_table.add_row( - "exit | quit | bye", - "Quit with session summary.", - "exit" - ) - - console.print(Panel( - help_table, - title="[bold cyan]oAI Chat Help (Version %s)[/]" % version, - title_align="center", - subtitle="💡 Commands are case-insensitive • Use /help for details • Use /help mcp for MCP guide • Memory ON by default • MCP has file & DB modes • Use // to escape / • Visit: https://iurl.no/oai", - subtitle_align="center", - border_style="cyan" - )) - continue - - if not selected_model: - console.print("[bold yellow]Select a model first with '/model'.[/]") - continue - - # ============================================================ - # PROCESS FILE ATTACHMENTS - # ============================================================ - text_part = user_input - file_attachments = [] - content_blocks = [] - - # Smart file 
detection: Simple pattern + extension validation - # This avoids extremely long regex that can cause binary signing issues - - # Common file extensions we support - ALLOWED_FILE_EXTENSIONS = { - # Code - '.py', '.js', '.ts', '.jsx', '.tsx', '.vue', '.java', '.c', '.cpp', '.cc', '.cxx', - '.h', '.hpp', '.hxx', '.rb', '.go', '.rs', '.swift', '.kt', '.kts', '.php', - '.sh', '.bash', '.zsh', '.fish', '.bat', '.cmd', '.ps1', - # Data - '.json', '.csv', '.yaml', '.yml', '.toml', '.xml', '.sql', '.db', '.sqlite', '.sqlite3', - # Documents - '.txt', '.md', '.log', '.conf', '.cfg', '.ini', '.env', '.properties', - # Images - '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.svg', '.ico', - # Archives - '.zip', '.tar', '.gz', '.bz2', '.7z', '.rar', '.xz', - # Config files - '.lock', '.gitignore', '.dockerignore', '.editorconfig', '.eslintrc', - '.prettierrc', '.babelrc', '.nvmrc', '.npmrc', - # Binary/Compiled - '.pyc', '.pyo', '.pyd', '.so', '.dll', '.dylib', '.exe', '.app', - '.dmg', '.pkg', '.deb', '.rpm', '.apk', '.ipa', - # ML/AI - '.pkl', '.pickle', '.joblib', '.npy', '.npz', '.safetensors', '.onnx', - '.pt', '.pth', '.ckpt', '.pb', '.tflite', '.mlmodel', '.coreml', '.rknn', - # Data formats - '.wasm', '.proto', '.graphql', '.graphqls', '.grpc', '.avro', '.parquet', - '.orc', '.feather', '.arrow', '.hdf5', '.h5', '.mat', '.r', '.rdata', '.rds', - # Other - '.pdf', '.class', '.jar', '.war' - } - - # Simple pattern: @filepath where filepath starts with ~, /, . or has an extension - # Much shorter than listing all extensions in the regex - file_pattern = r'(?:^|\s)@([~/\.][\S]+|[\w][\w-]*\.[\w]+)(?=\s|$)' - - for match in re.finditer(file_pattern, user_input, re.IGNORECASE): - file_path = match.group(1) - - # Get the extension - file_ext = os.path.splitext(file_path)[1].lower() - - # Skip if it doesn't start with a path char and has no allowed extension - if not file_path.startswith(('/', '~', '.', '\\')): - if not file_ext or file_ext not in ALLOWED_FILE_EXTENSIONS: - # This looks like a domain name, not a file - continue - - expanded_path = os.path.expanduser(os.path.abspath(file_path)) - - if not os.path.exists(expanded_path) or os.path.isdir(expanded_path): - console.print(f"[bold red]File not found or is a directory: {expanded_path}[/]") - continue - - file_size = os.path.getsize(expanded_path) - if file_size > 10 * 1024 * 1024: - console.print(f"[bold red]File too large (>10MB): {expanded_path}[/]") - continue - - mime_type, _ = mimetypes.guess_type(expanded_path) - - try: - with open(expanded_path, 'rb') as f: - file_data = f.read() - - if mime_type and mime_type.startswith('image/'): - modalities = selected_model.get("architecture", {}).get("input_modalities", []) - if "image" not in modalities: - console.print("[bold red]Selected model does not support image attachments.[/]") - console.print(f"[dim yellow]Supported modalities: {', '.join(modalities) if modalities else 'text only'}[/]") - continue - b64_data = base64.b64encode(file_data).decode('utf-8') - content_blocks.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_data}"}}) - console.print(f"[dim green]✓ Image attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - elif mime_type == 'application/pdf' or file_ext == '.pdf': - modalities = selected_model.get("architecture", {}).get("input_modalities", []) - supports_pdf = any(mod in modalities for mod in ["document", "pdf", "file"]) - if not supports_pdf: - console.print("[bold red]Selected model does not support PDF 
attachments.[/]") - console.print(f"[dim yellow]Supported modalities: {', '.join(modalities) if modalities else 'text only'}[/]") - continue - b64_data = base64.b64encode(file_data).decode('utf-8') - content_blocks.append({"type": "image_url", "image_url": {"url": f"data:application/pdf;base64,{b64_data}"}}) - console.print(f"[dim green]✓ PDF attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - elif (mime_type == 'text/plain' or file_ext in SUPPORTED_CODE_EXTENSIONS): - text_content = file_data.decode('utf-8') - content_blocks.append({"type": "text", "text": f"Code File: {os.path.basename(expanded_path)}\n\n{text_content}"}) - console.print(f"[dim green]✓ Code file attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - else: - console.print(f"[bold red]Unsupported file type ({mime_type}) for {expanded_path}.[/]") - console.print("[bold yellow]Supported types: images (PNG, JPG, etc.), PDFs, and code files (.py, .js, etc.)[/]") - continue - - file_attachments.append(file_path) - app_logger.info(f"File attached: {os.path.basename(expanded_path)}, Type: {mime_type or file_ext}, Size: {file_size / 1024:.1f} KB") - except UnicodeDecodeError: - console.print(f"[bold red]Cannot decode {expanded_path} as UTF-8. File may be binary or use unsupported encoding.[/]") - app_logger.error(f"UTF-8 decode error for {expanded_path}") - continue - except Exception as e: - console.print(f"[bold red]Error reading file {expanded_path}: {e}[/]") - app_logger.error(f"File read error for {expanded_path}: {e}") - continue - - # Remove file attachments from text (use same simple pattern) - text_part = re.sub(file_pattern, lambda m: m.group(0)[0] if m.group(0)[0].isspace() else '', text_part).strip() - - # Build message content - if text_part or content_blocks: - message_content = [] - if text_part: - message_content.append({"type": "text", "text": text_part}) - message_content.extend(content_blocks) - else: - console.print("[bold red]Prompt cannot be empty.[/]") - continue - - - # ============================================================ - # BUILD API MESSAGES - # ============================================================ - api_messages = [] - - if session_system_prompt: - api_messages.append({"role": "system", "content": session_system_prompt}) - - # Add database context if in database mode - if mcp_manager.enabled and mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: - db = mcp_manager.databases[mcp_manager.selected_db_index] - db_context = f"""You are currently connected to SQLite database: {db['name']} -Available tables: {', '.join(db['tables'])} - -You can: -- Inspect the database schema with inspect_database -- Search for data across tables with search_database -- Execute read-only SQL queries with query_database - -All queries are read-only. 
INSERT/UPDATE/DELETE are not allowed.""" - api_messages.append({"role": "system", "content": db_context}) - - if conversation_memory_enabled: - for i in range(memory_start_index, len(session_history)): - history_entry = session_history[i] - api_messages.append({ - "role": "user", - "content": history_entry['prompt'] - }) - api_messages.append({ - "role": "assistant", - "content": history_entry['response'] - }) - - api_messages.append({"role": "user", "content": message_content}) - - # Get effective model ID - effective_model_id = get_effective_model_id(selected_model["id"], online_mode_enabled) - - # ======================================================================== - # BUILD API PARAMS WITH MCP TOOLS SUPPORT - # ======================================================================== - api_params = { - "model": effective_model_id, - "messages": api_messages, - "stream": STREAM_ENABLED == "on", - "http_headers": { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - } - - # Add MCP tools if enabled - mcp_tools_added = False - mcp_tools = [] - if mcp_manager.enabled: - if supports_function_calling(selected_model): - mcp_tools = mcp_manager.get_tools_schema() - if mcp_tools: - api_params["tools"] = mcp_tools - api_params["tool_choice"] = "auto" - mcp_tools_added = True - - # IMPORTANT: Disable streaming if MCP tools are present - if api_params["stream"]: - api_params["stream"] = False - app_logger.info("Disabled streaming due to MCP tool calls") - - app_logger.info(f"Added {len(mcp_tools)} MCP tools to request (mode: {mcp_manager.mode})") - else: - app_logger.debug(f"Model {selected_model['id']} doesn't support function calling - MCP tools not added") - - if session_max_token > 0: - api_params["max_tokens"] = session_max_token - if middle_out_enabled: - api_params["transforms"] = ["middle-out"] - - # Log API request - file_count = len(file_attachments) - history_messages_count = len(session_history) - memory_start_index if conversation_memory_enabled else 0 - memory_status = "ON" if conversation_memory_enabled else "OFF" - online_status = "ON" if online_mode_enabled else "OFF" - - if mcp_tools_added: - if mcp_manager.mode == "files": - mcp_status = f"ON (Files, {len(mcp_manager.allowed_folders)} folders, {len(mcp_tools)} tools)" - elif mcp_manager.mode == "database": - db = mcp_manager.databases[mcp_manager.selected_db_index] - mcp_status = f"ON (DB: {db['name']}, {len(mcp_tools)} tools)" - else: - mcp_status = f"ON ({len(mcp_tools)} tools)" - else: - mcp_status = "OFF" - - app_logger.info(f"API Request: Model '{effective_model_id}' (Online: {online_status}, MCP: {mcp_status}), Prompt length: {len(text_part)} chars, {file_count} file(s) attached, Memory: {memory_status}, History sent: {history_messages_count} messages.") - - # Send request - is_streaming = api_params["stream"] - if is_streaming: - console.print("[bold green]Streaming response...[/] [dim](Press Ctrl+C to cancel)[/]") - if online_mode_enabled: - console.print("[dim cyan]🌐 Online mode active - model has web search access[/]") - console.print("") - else: - thinking_msg = "Thinking" - if mcp_tools_added: - if mcp_manager.mode == "database": - thinking_msg = "Thinking (database tools available)" - else: - thinking_msg = "Thinking (MCP tools available)" - console.print(f"[bold green]{thinking_msg}...[/]", end="\r") - if online_mode_enabled: - console.print(f"[dim cyan]🌐 Online mode active[/]") - if mcp_tools_added: - if mcp_manager.mode == "files": - gitignore_status = "on" if (mcp_manager.server and 
mcp_manager.server.respect_gitignore) else "off" - console.print(f"[dim cyan]🔧 MCP active - AI can access {len(mcp_manager.allowed_folders)} folder(s) with {len(mcp_tools)} tools (.gitignore {gitignore_status})[/]") - elif mcp_manager.mode == "database": - db = mcp_manager.databases[mcp_manager.selected_db_index] - console.print(f"[dim cyan]🗄️ MCP active - AI can query database: {db['name']} ({len(db['tables'])} tables, {len(mcp_tools)} tools)[/]") - - start_time = time.time() - try: - response = client.chat.send(**api_params) - app_logger.info(f"API call successful for model '{effective_model_id}'") - except Exception as e: - console.print(f"[bold red]Error sending request: {e}[/]") - app_logger.error(f"API Error: {type(e).__name__}: {e}") - continue - - # ======================================================================== - # HANDLE TOOL CALLS (MCP FUNCTION CALLING) - # ======================================================================== - tool_call_loop_count = 0 - max_tool_loops = 5 - - while tool_call_loop_count < max_tool_loops: - wants_tool_call = False - - if hasattr(response, 'choices') and response.choices: - message = response.choices[0].message - - if hasattr(message, 'tool_calls') and message.tool_calls: - wants_tool_call = True - tool_calls = message.tool_calls - - console.print(f"\n[dim yellow]🔧 AI requesting {len(tool_calls)} tool call(s)...[/]") - app_logger.info(f"Model requested {len(tool_calls)} tool calls") - - tool_results = [] - for tool_call in tool_calls: - tool_name = tool_call.function.name - - try: - tool_args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError as e: - app_logger.error(f"Failed to parse tool arguments: {e}") - tool_results.append({ - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_name, - "content": json.dumps({"error": f"Invalid arguments: {e}"}) - }) - continue - - args_display = ', '.join(f'{k}="{v}"' if isinstance(v, str) else f'{k}={v}' for k, v in tool_args.items()) - console.print(f"[dim cyan] → Calling: {tool_name}({args_display})[/]") - app_logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}") - - try: - result = asyncio.run(mcp_manager.call_tool(tool_name, **tool_args)) - except Exception as e: - app_logger.error(f"MCP tool execution error: {e}") - result = {"error": str(e)} - - if 'error' in result: - result_content = json.dumps({"error": result['error']}) - console.print(f"[dim red] ✗ Error: {result['error']}[/]") - app_logger.warning(f"MCP tool {tool_name} returned error: {result['error']}") - else: - # Display appropriate success message based on tool - if tool_name == 'search_files': - count = result.get('count', 0) - console.print(f"[dim green] ✓ Found {count} file(s)[/]") - elif tool_name == 'read_file': - size = result.get('size', 0) - truncated = " (truncated)" if result.get('truncated') else "" - console.print(f"[dim green] ✓ Read file ({size} bytes{truncated})[/]") - elif tool_name == 'list_directory': - count = result.get('count', 0) - truncated = " (limited to 1000)" if result.get('truncated') else "" - console.print(f"[dim green] ✓ Listed {count} item(s){truncated}[/]") - elif tool_name == 'inspect_database': - if 'table' in result: - console.print(f"[dim green] ✓ Inspected table: {result['table']} ({result['row_count']} rows)[/]") - else: - console.print(f"[dim green] ✓ Inspected database ({result['table_count']} tables)[/]") - elif tool_name == 'search_database': - count = result.get('count', 0) - console.print(f"[dim green] ✓ Found {count} match(es)[/]") - 
elif tool_name == 'query_database': - count = result.get('count', 0) - truncated = " (truncated)" if result.get('truncated') else "" - console.print(f"[dim green] ✓ Query returned {count} row(s){truncated}[/]") - else: - console.print(f"[dim green] ✓ Success[/]") - - result_content = json.dumps(result, indent=2) - app_logger.info(f"MCP tool {tool_name} succeeded") - - tool_results.append({ - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_name, - "content": result_content - }) - - api_messages.append({ - "role": "assistant", - "content": message.content, - "tool_calls": [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments - } - } for tc in tool_calls - ] - }) - api_messages.extend(tool_results) - - console.print("\n[dim cyan]💭 Processing tool results...[/]") - app_logger.info("Sending tool results back to model") - - followup_params = { - "model": effective_model_id, - "messages": api_messages, - "stream": False, # Keep streaming disabled for follow-ups - "http_headers": { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - } - - if mcp_tools_added: - followup_params["tools"] = mcp_tools - followup_params["tool_choice"] = "auto" - - if session_max_token > 0: - followup_params["max_tokens"] = session_max_token - - try: - response = client.chat.send(**followup_params) - tool_call_loop_count += 1 - app_logger.info(f"Follow-up request successful (loop {tool_call_loop_count})") - except Exception as e: - console.print(f"[bold red]Error getting follow-up response: {e}[/]") - app_logger.error(f"Follow-up API error: {e}") - break - - if not wants_tool_call: - break - - if tool_call_loop_count >= max_tool_loops: - console.print(f"[bold yellow]⚠️ Reached maximum tool call depth ({max_tool_loops}). 
Stopping.[/]") - app_logger.warning(f"Hit max tool call loop limit: {max_tool_loops}") - - response_time = time.time() - start_time - - # ======================================================================== - # PROCESS FINAL RESPONSE - # ======================================================================== - full_response = "" - stream_interrupted = False - - if is_streaming: - # Store the last chunk to get usage data after streaming completes - last_chunk = None - try: - with Live("", console=console, refresh_per_second=10, auto_refresh=True) as live: - try: - for chunk in response: - last_chunk = chunk # Keep track of last chunk for usage data - if hasattr(chunk, 'error') and chunk.error: - console.print(f"\n[bold red]Stream error: {chunk.error.message}[/]") - app_logger.error(f"Stream error: {chunk.error.message}") - break - if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: - content_chunk = chunk.choices[0].delta.content - full_response += content_chunk - md = Markdown(full_response) - live.update(md) - except KeyboardInterrupt: - stream_interrupted = True - console.print("\n[bold yellow]⚠️ Streaming interrupted![/]") - app_logger.info("Streaming interrupted by user (Ctrl+C)") - except Exception as stream_error: - stream_interrupted = True - console.print(f"\n[bold red]Stream error: {stream_error}[/]") - app_logger.error(f"Stream processing error: {stream_error}") - - if not stream_interrupted: - console.print("") - - except KeyboardInterrupt: - # Outer interrupt handler (in case inner misses it) - stream_interrupted = True - console.print("\n[bold yellow]⚠️ Streaming interrupted![/]") - app_logger.info("Streaming interrupted by user (outer)") - except Exception as e: - stream_interrupted = True - console.print(f"\n[bold red]Error during streaming: {e}[/]") - app_logger.error(f"Streaming error: {e}") - - # If stream was interrupted, skip processing and continue to next prompt - if stream_interrupted: - if full_response: - console.print(f"\n[dim yellow]Partial response received ({len(full_response)} chars). 
Discarding...[/]") - console.print("[dim blue]💡 Ready for next prompt[/]\n") - app_logger.info("Stream cleanup completed, returning to prompt") - - # Force close the response if possible - try: - if hasattr(response, 'close'): - response.close() - elif hasattr(response, '__exit__'): - response.__exit__(None, None, None) - except: - pass - - # Recreate client to be safe - try: - client = OpenRouter(api_key=API_KEY) - app_logger.info("Client recreated after stream interruption") - except: - pass - - continue # Now it's safe to continue - - # For streaming, try to get usage from last chunk - if last_chunk and hasattr(last_chunk, 'usage'): - response.usage = last_chunk.usage - app_logger.debug("Extracted usage data from last streaming chunk") - elif last_chunk: - app_logger.warning("Last streaming chunk has no usage data") - - else: - full_response = response.choices[0].message.content if response.choices else "" - # Clear any processing messages before showing response - console.print(f"\r{' ' * 100}\r", end="") - - if full_response: - if not is_streaming: - md = Markdown(full_response) - # Add newline before Panel to ensure clean rendering - console.print() - console.print(Panel(md, title="[bold green]AI Response[/]", title_align="left", border_style="green")) - - # Extract usage data BEFORE appending to history - usage = getattr(response, 'usage', None) - - # DEBUG: Log what OpenRouter actually returns - if usage: - app_logger.debug(f"Usage object type: {type(usage)}") - app_logger.debug(f"Usage attributes: {dir(usage)}") - if hasattr(usage, '__dict__'): - app_logger.debug(f"Usage dict: {usage.__dict__}") - else: - app_logger.warning("No usage object in response!") - - # Try both attribute naming conventions (OpenAI standard vs Anthropic) - input_tokens = 0 - output_tokens = 0 - - if usage: - # Try prompt_tokens/completion_tokens (OpenAI/OpenRouter standard) - if hasattr(usage, 'prompt_tokens'): - input_tokens = usage.prompt_tokens or 0 - elif hasattr(usage, 'input_tokens'): - input_tokens = usage.input_tokens or 0 - - if hasattr(usage, 'completion_tokens'): - output_tokens = usage.completion_tokens or 0 - elif hasattr(usage, 'output_tokens'): - output_tokens = usage.output_tokens or 0 - - app_logger.debug(f"Extracted tokens: input={input_tokens}, output={output_tokens}") - - # Get cost from API or estimate - msg_cost = 0.0 - if usage and hasattr(usage, 'total_cost_usd') and usage.total_cost_usd: - msg_cost = float(usage.total_cost_usd) - app_logger.debug(f"Using API cost: ${msg_cost:.6f}") - else: - msg_cost = estimate_cost(input_tokens, output_tokens) - app_logger.debug(f"Estimated cost: ${msg_cost:.6f} (from {input_tokens} input + {output_tokens} output tokens)") - - # NOW append to history with cost data - session_history.append({ - 'prompt': user_input, - 'response': full_response, - 'msg_cost': msg_cost, - 'prompt_tokens': input_tokens, - 'completion_tokens': output_tokens - }) - current_index = len(session_history) - 1 - - total_input_tokens += input_tokens - total_output_tokens += output_tokens - total_cost += msg_cost - message_count += 1 - - app_logger.info(f"Response: Tokens - I:{input_tokens} O:{output_tokens} T:{input_tokens + output_tokens}, Cost: ${msg_cost:.4f}, Time: {response_time:.2f}s, Tool loops: {tool_call_loop_count}") - - if conversation_memory_enabled: - context_count = len(session_history) - memory_start_index - context_info = f", Context: {context_count} msg(s)" if context_count > 1 else "" - else: - context_info = ", Memory: OFF" - - online_info = " 🌐" if 
online_mode_enabled else "" - mcp_info = "" - if mcp_tools_added: - if mcp_manager.mode == "files": - mcp_info = " 🔧" - elif mcp_manager.mode == "database": - mcp_info = " 🗄️" - tool_info = f" ({tool_call_loop_count} tool loop(s))" if tool_call_loop_count > 0 else "" - console.print(f"\n[dim blue]📊 Metrics: {input_tokens + output_tokens} tokens | ${msg_cost:.4f} | {response_time:.2f}s{context_info}{online_info}{mcp_info}{tool_info} | Session: {total_input_tokens + total_output_tokens} tokens | ${total_cost:.4f}[/]") - - warnings = [] - if msg_cost > COST_WARNING_THRESHOLD: - warnings.append(f"High cost alert: This response exceeded threshold ${COST_WARNING_THRESHOLD:.4f}") - credits_data = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits_data: - warning_alerts = check_credit_alerts(credits_data) - warnings.extend(warning_alerts) - if warnings: - warning_text = ' | '.join(warnings) - console.print(f"[bold red]⚠️ {warning_text}[/]") - app_logger.warning(f"Warnings triggered: {warning_text}") - - console.print("") - try: - copy_choice = input("💾 Type 'c' to copy response, or press Enter to continue: ").strip().lower() - if copy_choice == "c": - pyperclip.copy(full_response) - console.print("[bold green]✅ Response copied to clipboard![/]") - except (EOFError, KeyboardInterrupt): - pass - - console.print("") - else: - console.print("[bold red]No response received.[/]") - app_logger.error("No response from API") - - except KeyboardInterrupt: - console.print("\n[bold yellow]Input interrupted. Continuing...[/]") - app_logger.warning("Input interrupted by Ctrl+C") - continue - except EOFError: - console.print("\n[bold yellow]Goodbye![/]") - total_tokens = total_input_tokens + total_output_tokens - app_logger.info(f"Session ended via EOF. Total messages: {message_count}, Total tokens: {total_tokens}, Total cost: ${total_cost:.4f}") - return - except Exception as e: - console.print(f"[bold red]Error: {e}[/]") - console.print("[bold yellow]Try again or select a model.[/]") - app_logger.error(f"Unexpected error: {type(e).__name__}: {e}") - -if __name__ == "__main__": - clear_screen() - app() \ No newline at end of file diff --git a/oai/__init__.py b/oai/__init__.py new file mode 100644 index 0000000..77da9b0 --- /dev/null +++ b/oai/__init__.py @@ -0,0 +1,26 @@ +""" +oAI - OpenRouter AI Chat Client + +A feature-rich terminal-based chat application that provides an interactive CLI +interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP) +integration for filesystem and database access. + +Author: Rune +License: MIT +""" + +__version__ = "2.1.0" +__author__ = "Rune" +__license__ = "MIT" + +# Lazy imports to avoid circular dependencies and improve startup time +# Full imports are available via submodules: +# from oai.config import Settings, Database +# from oai.providers import OpenRouterProvider, AIProvider +# from oai.mcp import MCPManager + +__all__ = [ + "__version__", + "__author__", + "__license__", +] diff --git a/oai/__main__.py b/oai/__main__.py new file mode 100644 index 0000000..bd02e03 --- /dev/null +++ b/oai/__main__.py @@ -0,0 +1,8 @@ +""" +Entry point for running oAI as a module: python -m oai +""" + +from oai.cli import main + +if __name__ == "__main__": + main() diff --git a/oai/cli.py b/oai/cli.py new file mode 100644 index 0000000..0405885 --- /dev/null +++ b/oai/cli.py @@ -0,0 +1,719 @@ +""" +Main CLI entry point for oAI. + +This module provides the command-line interface for the oAI chat application. 
+""" + +import os +import sys +from pathlib import Path +from typing import Optional + +import typer +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.history import FileHistory +from rich.markdown import Markdown +from rich.panel import Panel + +from oai import __version__ +from oai.commands import register_all_commands, registry +from oai.commands.registry import CommandContext, CommandStatus +from oai.config.database import Database +from oai.config.settings import Settings +from oai.constants import ( + APP_NAME, + APP_URL, + APP_VERSION, + CONFIG_DIR, + HISTORY_FILE, + VALID_COMMANDS, +) +from oai.core.client import AIClient +from oai.core.session import ChatSession +from oai.mcp.manager import MCPManager +from oai.providers.base import UsageStats +from oai.providers.openrouter import OpenRouterProvider +from oai.ui.console import ( + clear_screen, + console, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.tables import create_model_table, display_paginated_table +from oai.utils.logging import LoggingManager, get_logger + +# Create Typer app +app = typer.Typer( + name="oai", + help=f"oAI - OpenRouter AI Chat Client\n\nVersion: {APP_VERSION}", + add_completion=False, + epilog="For more information, visit: " + APP_URL, +) + + +@app.callback(invoke_without_command=True) +def main_callback( + ctx: typer.Context, + version_flag: bool = typer.Option( + False, + "--version", + "-v", + help="Show version information", + is_flag=True, + ), +) -> None: + """Main callback to handle global options.""" + # Show version with update check if --version flag + if version_flag: + version_info = check_for_updates(APP_VERSION) + console.print(version_info) + raise typer.Exit() + + # Show version with update check when --help is requested + if "--help" in sys.argv or "-h" in sys.argv: + version_info = check_for_updates(APP_VERSION) + console.print(f"\n{version_info}\n") + + # Continue to subcommand if provided + if ctx.invoked_subcommand is None: + return + + +def check_for_updates(current_version: str) -> str: + """Check for available updates.""" + import requests + from packaging import version as pkg_version + + try: + response = requests.get( + "https://gitlab.pm/api/v1/repos/rune/oai/releases/latest", + headers={"Content-Type": "application/json"}, + timeout=1.0, + ) + response.raise_for_status() + + data = response.json() + version_online = data.get("tag_name", "").lstrip("v") + + if not version_online: + return f"[bold green]oAI version {current_version}[/]" + + current = pkg_version.parse(current_version) + latest = pkg_version.parse(version_online) + + if latest > current: + return ( + f"[bold green]oAI version {current_version}[/] " + f"[bold red](Update available: {current_version} → {version_online})[/]" + ) + return f"[bold green]oAI version {current_version} (up to date)[/]" + + except Exception: + return f"[bold green]oAI version {current_version}[/]" + + +def show_welcome(settings: Settings, version_info: str) -> None: + """Display welcome message.""" + console.print(Panel.fit( + f"{version_info}\n\n" + "[bold cyan]Commands:[/] /help for commands, /model to select model\n" + "[bold cyan]MCP:[/] /mcp on to enable file/database access\n" + "[bold cyan]Exit:[/] Type 'exit', 'quit', or 'bye'", + title=f"[bold green]Welcome to {APP_NAME}[/]", + border_style="green", + )) + + +def select_model(client: AIClient, search_term: Optional[str] = None) -> Optional[dict]: + 
"""Display model selection interface.""" + try: + models = client.provider.get_raw_models() + if not models: + print_error("No models available") + return None + + # Filter by search term if provided + if search_term: + search_lower = search_term.lower() + models = [m for m in models if search_lower in m.get("id", "").lower()] + + if not models: + print_error(f"No models found matching '{search_term}'") + return None + + # Create and display table + table = create_model_table(models) + display_paginated_table( + table, + f"[bold green]Available Models ({len(models)})[/]", + ) + + # Prompt for selection + console.print("") + try: + choice = input("Enter model number (or press Enter to cancel): ").strip() + except (EOFError, KeyboardInterrupt): + return None + + if not choice: + return None + + try: + index = int(choice) - 1 + if 0 <= index < len(models): + selected = models[index] + print_success(f"Selected model: {selected['id']}") + return selected + except ValueError: + pass + + print_error("Invalid selection") + return None + + except Exception as e: + print_error(f"Failed to fetch models: {e}") + return None + + +def run_chat_loop( + session: ChatSession, + prompt_session: PromptSession, + settings: Settings, +) -> None: + """Run the main chat loop.""" + logger = get_logger() + mcp_manager = session.mcp_manager + + while True: + try: + # Build prompt prefix + prefix = "You> " + if mcp_manager and mcp_manager.enabled: + if mcp_manager.mode == "files": + if mcp_manager.write_enabled: + prefix = "[🔧✍️ MCP: Files+Write] You> " + else: + prefix = "[🔧 MCP: Files] You> " + elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: + prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> " + + # Get user input + user_input = prompt_session.prompt( + prefix, + auto_suggest=AutoSuggestFromHistory(), + ).strip() + + if not user_input: + continue + + # Handle escape sequence + if user_input.startswith("//"): + user_input = user_input[1:] + + # Check for exit + if user_input.lower() in ["exit", "quit", "bye"]: + console.print( + f"\n[bold yellow]Goodbye![/]\n" + f"[dim]Session: {session.stats.total_tokens:,} tokens, " + f"${session.stats.total_cost:.4f}[/]" + ) + logger.info( + f"Session ended. 
Messages: {session.stats.message_count}, " + f"Tokens: {session.stats.total_tokens}, " + f"Cost: ${session.stats.total_cost:.4f}" + ) + return + + # Check for unknown commands + if user_input.startswith("/"): + cmd_word = user_input.split()[0].lower() + if not registry.is_command(user_input): + # Check if it's a valid command prefix + is_valid = any(cmd_word.startswith(cmd) for cmd in VALID_COMMANDS) + if not is_valid: + print_error(f"Unknown command: {cmd_word}") + print_info("Type /help to see available commands.") + continue + + # Try to execute as command + context = session.get_context() + result = registry.execute(user_input, context) + + if result: + # Update session state from context + session.memory_enabled = context.memory_enabled + session.memory_start_index = context.memory_start_index + session.online_enabled = context.online_enabled + session.middle_out_enabled = context.middle_out_enabled + session.session_max_token = context.session_max_token + session.current_index = context.current_index + session.system_prompt = context.session_system_prompt + + if result.status == CommandStatus.EXIT: + return + + # Handle special results + if result.data: + # Retry - resend last prompt + if "retry_prompt" in result.data: + user_input = result.data["retry_prompt"] + # Fall through to send message + + # Paste - send clipboard content + elif "paste_prompt" in result.data: + user_input = result.data["paste_prompt"] + # Fall through to send message + + # Model selection + elif "show_model_selector" in result.data: + search = result.data.get("search", "") + model = select_model(session.client, search if search else None) + if model: + session.set_model(model) + # If this came from /config model, also save as default + if result.data.get("set_as_default"): + settings.set_default_model(model["id"]) + print_success(f"Default model set to: {model['id']}") + continue + + # Load conversation + elif "load_conversation" in result.data: + history = result.data.get("history", []) + session.history.clear() + from oai.core.session import HistoryEntry + for entry in history: + session.history.append(HistoryEntry( + prompt=entry.get("prompt", ""), + response=entry.get("response", ""), + prompt_tokens=entry.get("prompt_tokens", 0), + completion_tokens=entry.get("completion_tokens", 0), + msg_cost=entry.get("msg_cost", 0.0), + )) + session.current_index = len(session.history) - 1 + continue + + else: + # Normal command completed + continue + else: + # Command completed with no special data + continue + + # Ensure model is selected + if not session.selected_model: + print_warning("Please select a model first with /model") + continue + + # Send message + stream = settings.stream_enabled + if mcp_manager and mcp_manager.enabled: + tools = session.get_mcp_tools() + if tools: + stream = False # Disable streaming with tools + + if stream: + console.print( + "[bold green]Streaming response...[/] " + "[dim](Press Ctrl+C to cancel)[/]" + ) + if session.online_enabled: + console.print("[dim cyan]🌐 Online mode active[/]") + console.print("") + + try: + response_text, usage, response_time = session.send_message( + user_input, + stream=stream, + ) + except Exception as e: + print_error(f"Error: {e}") + logger.error(f"Message error: {e}") + continue + + if not response_text: + print_error("No response received") + continue + + # Display non-streaming response + if not stream: + console.print() + display_panel( + Markdown(response_text), + title="[bold green]AI Response[/]", + border_style="green", + ) + + # Calculate 
cost and tokens + cost = 0.0 + tokens = 0 + estimated = False + + if usage and (usage.prompt_tokens > 0 or usage.completion_tokens > 0): + tokens = usage.total_tokens + if usage.total_cost_usd: + cost = usage.total_cost_usd + else: + cost = session.client.estimate_cost( + session.selected_model["id"], + usage.prompt_tokens, + usage.completion_tokens, + ) + else: + # Estimate tokens when usage not available (streaming fallback) + # Rough estimate: ~4 characters per token for English text + est_input_tokens = len(user_input) // 4 + 1 + est_output_tokens = len(response_text) // 4 + 1 + tokens = est_input_tokens + est_output_tokens + cost = session.client.estimate_cost( + session.selected_model["id"], + est_input_tokens, + est_output_tokens, + ) + # Create estimated usage for session tracking + usage = UsageStats( + prompt_tokens=est_input_tokens, + completion_tokens=est_output_tokens, + total_tokens=tokens, + ) + estimated = True + + # Add to history + session.add_to_history(user_input, response_text, usage, cost) + + # Display metrics + est_marker = "~" if estimated else "" + context_info = "" + if session.memory_enabled: + context_count = len(session.history) - session.memory_start_index + if context_count > 1: + context_info = f", Context: {context_count} msg(s)" + else: + context_info = ", Memory: OFF" + + online_emoji = " 🌐" if session.online_enabled else "" + mcp_emoji = "" + if mcp_manager and mcp_manager.enabled: + if mcp_manager.mode == "files": + mcp_emoji = " 🔧" + elif mcp_manager.mode == "database": + mcp_emoji = " 🗄️" + + console.print( + f"\n[dim blue]📊 {est_marker}{tokens} tokens | {est_marker}${cost:.4f} | {response_time:.2f}s" + f"{context_info}{online_emoji}{mcp_emoji} | " + f"Session: {est_marker}{session.stats.total_tokens:,} tokens | " + f"{est_marker}${session.stats.total_cost:.4f}[/]" + ) + + # Check warnings + warnings = session.check_warnings() + for warning in warnings: + print_warning(warning) + + # Offer to copy + console.print("") + try: + from oai.ui.prompts import prompt_copy_response + prompt_copy_response(response_text) + except Exception: + pass + console.print("") + + except KeyboardInterrupt: + console.print("\n[dim]Input cancelled[/]") + continue + except EOFError: + console.print("\n[bold yellow]Goodbye![/]") + return + + +@app.command() +def chat( + model: Optional[str] = typer.Option( + None, + "--model", + "-m", + help="Model ID to use", + ), + system: Optional[str] = typer.Option( + None, + "--system", + "-s", + help="System prompt", + ), + online: bool = typer.Option( + False, + "--online", + "-o", + help="Enable online mode", + ), + mcp: bool = typer.Option( + False, + "--mcp", + help="Enable MCP server", + ), +) -> None: + """Start an interactive chat session.""" + # Setup logging + logging_manager = LoggingManager() + logging_manager.setup() + logger = get_logger() + + # Clear screen + clear_screen() + + # Load settings + settings = Settings.load() + + # Check API key + if not settings.api_key: + print_error("No API key configured") + print_info("Run: oai --config api to set your API key") + raise typer.Exit(1) + + # Initialize client + try: + client = AIClient( + api_key=settings.api_key, + base_url=settings.base_url, + ) + except Exception as e: + print_error(f"Failed to initialize client: {e}") + raise typer.Exit(1) + + # Register commands + register_all_commands() + + # Check for updates and show welcome + version_info = check_for_updates(APP_VERSION) + show_welcome(settings, version_info) + + # Initialize MCP manager + mcp_manager = 
MCPManager() + if mcp: + result = mcp_manager.enable() + if result["success"]: + print_success("MCP enabled") + else: + print_warning(f"MCP: {result.get('error', 'Failed to enable')}") + + # Create session + session = ChatSession( + client=client, + settings=settings, + mcp_manager=mcp_manager, + ) + + # Set system prompt + if system: + session.system_prompt = system + print_info(f"System prompt: {system}") + + # Set online mode + if online: + session.online_enabled = True + print_info("Online mode enabled") + + # Select model + if model: + raw_model = client.get_raw_model(model) + if raw_model: + session.set_model(raw_model) + else: + print_warning(f"Model '{model}' not found") + elif settings.default_model: + raw_model = client.get_raw_model(settings.default_model) + if raw_model: + session.set_model(raw_model) + else: + print_warning(f"Default model '{settings.default_model}' not available") + + # Setup prompt session + HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True) + prompt_session = PromptSession( + history=FileHistory(str(HISTORY_FILE)), + ) + + # Run chat loop + run_chat_loop(session, prompt_session, settings) + + +@app.command() +def config( + setting: Optional[str] = typer.Argument( + None, + help="Setting to configure (api, url, model, system, stream, costwarning, maxtoken, online, log, loglevel)", + ), + value: Optional[str] = typer.Argument( + None, + help="Value to set", + ), +) -> None: + """View or modify configuration settings.""" + settings = Settings.load() + + if not setting: + # Show all settings + from rich.table import Table + from oai.constants import DEFAULT_SYSTEM_PROMPT + + table = Table("Setting", "Value", show_header=True, header_style="bold magenta") + table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set") + table.add_row("Base URL", settings.base_url) + table.add_row("Default Model", settings.default_model or "Not set") + + # Show system prompt status + if settings.default_system_prompt is None: + system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..." + elif settings.default_system_prompt == "": + system_prompt_display = "[blank]" + else: + system_prompt_display = settings.default_system_prompt[:50] + "..." 
if len(settings.default_system_prompt) > 50 else settings.default_system_prompt + table.add_row("System Prompt", system_prompt_display) + + table.add_row("Streaming", "on" if settings.stream_enabled else "off") + table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}") + table.add_row("Max Tokens", str(settings.max_tokens)) + table.add_row("Default Online", "on" if settings.default_online_mode else "off") + table.add_row("Log Level", settings.log_level) + + display_panel(table, title="[bold green]Configuration[/]") + return + + setting = setting.lower() + + if setting == "api": + if value: + settings.set_api_key(value) + else: + from oai.ui.prompts import prompt_input + new_key = prompt_input("Enter API key", password=True) + if new_key: + settings.set_api_key(new_key) + print_success("API key updated") + + elif setting == "url": + settings.set_base_url(value or "https://openrouter.ai/api/v1") + print_success(f"Base URL set to: {settings.base_url}") + + elif setting == "model": + if value: + settings.set_default_model(value) + print_success(f"Default model set to: {value}") + else: + print_info(f"Current default model: {settings.default_model or 'Not set'}") + + elif setting == "system": + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if value: + # Decode escape sequences like \n for newlines + value = value.encode().decode('unicode_escape') + settings.set_default_system_prompt(value) + if value: + print_success(f"Default system prompt set to: {value}") + else: + print_success("Default system prompt set to blank.") + else: + if settings.default_system_prompt is None: + print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif settings.default_system_prompt == "": + print_info("System prompt: [blank]") + else: + print_info(f"System prompt: {settings.default_system_prompt}") + + elif setting == "stream": + if value and value.lower() in ["on", "off"]: + settings.set_stream_enabled(value.lower() == "on") + print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}") + else: + print_info("Usage: oai config stream [on|off]") + + elif setting == "costwarning": + if value: + try: + threshold = float(value) + settings.set_cost_warning_threshold(threshold) + print_success(f"Cost warning threshold set to: ${threshold:.4f}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current threshold: ${settings.cost_warning_threshold:.4f}") + + elif setting == "maxtoken": + if value: + try: + max_tok = int(value) + settings.set_max_tokens(max_tok) + print_success(f"Max tokens set to: {max_tok}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current max tokens: {settings.max_tokens}") + + elif setting == "online": + if value and value.lower() in ["on", "off"]: + settings.set_default_online_mode(value.lower() == "on") + print_success(f"Default online mode {'enabled' if settings.default_online_mode else 'disabled'}") + else: + print_info("Usage: oai config online [on|off]") + + elif setting == "loglevel": + valid_levels = ["debug", "info", "warning", "error", "critical"] + if value and value.lower() in valid_levels: + settings.set_log_level(value.lower()) + print_success(f"Log level set to: {value.lower()}") + else: + print_info(f"Valid levels: {', '.join(valid_levels)}") + + else: + print_error(f"Unknown setting: {setting}") + + +@app.command() +def version() -> None: + """Show version information.""" + version_info = check_for_updates(APP_VERSION) + 
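+    # Note (descriptive comment, not part of the original patch): check_for_updates()
+    # above queries the gitlab.pm releases API with a 1-second timeout and falls back
+    # to the plain version string on any error, so this command never blocks or fails
+    # when offline.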
console.print(version_info) + + +@app.command() +def credits() -> None: + """Check account credits.""" + settings = Settings.load() + + if not settings.api_key: + print_error("No API key configured") + raise typer.Exit(1) + + client = AIClient(api_key=settings.api_key, base_url=settings.base_url) + credits_data = client.get_credits() + + if not credits_data: + print_error("Failed to fetch credits") + raise typer.Exit(1) + + from rich.table import Table + + table = Table("Metric", "Value", show_header=True, header_style="bold magenta") + table.add_row("Total Credits", credits_data.get("total_credits_formatted", "N/A")) + table.add_row("Used Credits", credits_data.get("used_credits_formatted", "N/A")) + table.add_row("Credits Left", credits_data.get("credits_left_formatted", "N/A")) + + display_panel(table, title="[bold green]Account Credits[/]") + + +def main() -> None: + """Main entry point.""" + # Default to 'chat' command if no arguments provided + if len(sys.argv) == 1: + sys.argv.append("chat") + app() + + +if __name__ == "__main__": + main() diff --git a/oai/commands/__init__.py b/oai/commands/__init__.py new file mode 100644 index 0000000..37068b3 --- /dev/null +++ b/oai/commands/__init__.py @@ -0,0 +1,24 @@ +""" +Command system for oAI. + +This module provides a command registry and handler system +for processing slash commands in the chat interface. +""" + +from oai.commands.registry import ( + Command, + CommandRegistry, + CommandContext, + CommandResult, + registry, +) +from oai.commands.handlers import register_all_commands + +__all__ = [ + "Command", + "CommandRegistry", + "CommandContext", + "CommandResult", + "registry", + "register_all_commands", +] diff --git a/oai/commands/handlers.py b/oai/commands/handlers.py new file mode 100644 index 0000000..49f23a8 --- /dev/null +++ b/oai/commands/handlers.py @@ -0,0 +1,1441 @@ +""" +Command handlers for oAI. + +This module implements all the slash commands available in the chat interface. +Each command is registered with the global registry. 
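+New commands subclass Command and are added to the registry in register_all_commands().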
+""" + +import json +from typing import Any, Dict, List, Optional + +from oai.commands.registry import ( + Command, + CommandContext, + CommandHelp, + CommandResult, + CommandStatus, + registry, +) +from oai.constants import COMMAND_HELP, VALID_COMMANDS +from oai.ui.console import ( + clear_screen, + console, + display_markdown, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.prompts import prompt_confirm, prompt_input +from oai.ui.tables import ( + create_model_table, + create_stats_table, + display_paginated_table, +) +from oai.utils.export import export_as_html, export_as_json, export_as_markdown +from oai.utils.logging import get_logger + + +class HelpCommand(Command): + """Display help information for commands.""" + + @property + def name(self) -> str: + return "/help" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display help information for commands.", + usage="/help [command|topic]", + examples=[ + ("Show all commands", "/help"), + ("Get help for a specific command", "/help /model"), + ("Get detailed MCP help", "/help mcp"), + ], + notes="Use /help without arguments to see the full command list.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + logger = get_logger() + + if args: + # Show help for specific command/topic + self._show_command_help(args) + else: + # Show general help + self._show_general_help(context) + + logger.info(f"Displayed help for: {args or 'general'}") + return CommandResult.success() + + def _show_command_help(self, command_or_topic: str) -> None: + """Show help for a specific command or topic.""" + # Handle topics (like 'mcp') + if not command_or_topic.startswith("/"): + if command_or_topic.lower() == "mcp": + help_data = COMMAND_HELP.get("mcp", {}) + if help_data: + content = [] + content.append(f"[bold cyan]Description:[/]") + content.append(help_data.get("description", "")) + content.append("") + content.append(help_data.get("notes", "")) + + display_panel( + "\n".join(content), + title="[bold green]MCP - Model Context Protocol Guide[/]", + border_style="green", + ) + return + command_or_topic = "/" + command_or_topic + + help_data = COMMAND_HELP.get(command_or_topic) + if not help_data: + print_error(f"Unknown command: {command_or_topic}") + print_warning("Type /help to see all available commands.") + return + + content = [] + + if help_data.get("aliases"): + content.append(f"[bold cyan]Aliases:[/] {', '.join(help_data['aliases'])}") + content.append("") + + content.append("[bold cyan]Description:[/]") + content.append(help_data.get("description", "")) + content.append("") + + content.append("[bold cyan]Usage:[/]") + content.append(f"[yellow]{help_data.get('usage', '')}[/]") + content.append("") + + examples = help_data.get("examples", []) + if examples: + content.append("[bold cyan]Examples:[/]") + for desc, example in examples: + if not desc and not example: + content.append("") + elif desc.startswith("━━━"): + content.append(f"[bold yellow]{desc}[/]") + else: + if desc: + content.append(f" [dim]{desc}:[/]") + if example: + content.append(f" [green]{example}[/]") + content.append("") + + notes = help_data.get("notes") + if notes: + content.append("[bold cyan]Notes:[/]") + content.append(f"[dim]{notes}[/]") + + display_panel( + "\n".join(content), + title=f"[bold green]Help: {command_or_topic}[/]", + border_style="green", + ) + + def _show_general_help(self, context: CommandContext) -> None: + """Show general help with all commands.""" 
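+        # Commands are grouped by category and rendered as a single Rich table below.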
+ from rich.table import Table + + table = Table( + "Command", + "Description", + "Example", + show_header=True, + header_style="bold magenta", + show_lines=False, + ) + + # Group commands by category + categories = [ + ("[bold cyan]━━━ CHAT ━━━[/]", [ + ("/retry", "Resend the last prompt.", "/retry"), + ("/memory", "Toggle conversation memory.", "/memory on"), + ("/online", "Toggle online mode (web search).", "/online on"), + ("/paste", "Paste from clipboard with optional prompt.", "/paste Explain"), + ]), + ("[bold cyan]━━━ NAVIGATION ━━━[/]", [ + ("/prev", "View previous response in history.", "/prev"), + ("/next", "View next response in history.", "/next"), + ("/reset", "Clear conversation history.", "/reset"), + ]), + ("[bold cyan]━━━ MODEL & CONFIG ━━━[/]", [ + ("/model", "Select AI model.", "/model gpt"), + ("/info", "Show model information.", "/info"), + ("/config", "View or change settings.", "/config stream on"), + ("/maxtoken", "Set session token limit.", "/maxtoken 2000"), + ("/system", "Set system prompt.", "/system You are an expert"), + ]), + ("[bold cyan]━━━ SAVE & EXPORT ━━━[/]", [ + ("/save", "Save conversation.", "/save my_chat"), + ("/load", "Load saved conversation.", "/load my_chat"), + ("/delete", "Delete saved conversation.", "/delete my_chat"), + ("/list", "List saved conversations.", "/list"), + ("/export", "Export conversation.", "/export md notes.md"), + ]), + ("[bold cyan]━━━ STATS & INFO ━━━[/]", [ + ("/stats", "Show session statistics.", "/stats"), + ("/credits", "Show account credits.", "/credits"), + ("/middleout", "Toggle middle-out compression.", "/middleout on"), + ]), + ("[bold cyan]━━━ MCP (FILE/DB ACCESS) ━━━[/]", [ + ("/mcp on", "Enable MCP server.", "/mcp on"), + ("/mcp add", "Add folder or database.", "/mcp add ~/Documents"), + ("/mcp status", "Show MCP status.", "/mcp status"), + ("/mcp write", "Toggle write mode.", "/mcp write on"), + ]), + ("[bold cyan]━━━ UTILITY ━━━[/]", [ + ("/clear", "Clear the screen.", "/clear"), + ("/help", "Show this help.", "/help /model"), + ]), + ("[bold yellow]━━━ EXIT ━━━[/]", [ + ("exit", "Quit the application.", "exit"), + ]), + ] + + for header, commands in categories: + table.add_row(header, "", "") + for cmd, desc, example in commands: + table.add_row(cmd, desc, example) + + from oai.constants import APP_VERSION + + display_panel( + table, + title=f"[bold cyan]oAI Chat Help (Version {APP_VERSION})[/]", + subtitle="💡 Use /help for details • /help mcp for MCP guide", + ) + + +class ClearCommand(Command): + """Clear the terminal screen.""" + + @property + def name(self) -> str: + return "/clear" + + @property + def aliases(self) -> List[str]: + return ["/cl"] + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Clear the terminal screen.", + usage="/clear", + aliases=["/cl"], + notes="You can also use Ctrl+L.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + clear_screen() + return CommandResult.success() + + +class MemoryCommand(Command): + """Toggle conversation memory.""" + + @property + def name(self) -> str: + return "/memory" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Toggle conversation memory.", + usage="/memory [on|off]", + examples=[ + ("Check status", "/memory"), + ("Enable memory", "/memory on"), + ("Disable memory", "/memory off"), + ], + notes="When off, each message is independent (saves tokens).", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + 
status = "enabled" if context.memory_enabled else "disabled" + print_info(f"Conversation memory is {status}.") + return CommandResult.success() + + if args.lower() == "on": + context.memory_enabled = True + print_success("Memory enabled - AI will remember conversation.") + return CommandResult.success(data={"memory_enabled": True}) + elif args.lower() == "off": + context.memory_enabled = False + context.memory_start_index = len(context.session_history) + print_success("Memory disabled - each message is independent.") + return CommandResult.success(data={"memory_enabled": False}) + else: + print_error("Usage: /memory [on|off]") + return CommandResult.error("Invalid argument") + + +class OnlineCommand(Command): + """Toggle online mode (web search).""" + + @property + def name(self) -> str: + return "/online" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Enable or disable online mode (web search).", + usage="/online [on|off]", + examples=[ + ("Check status", "/online"), + ("Enable web search", "/online on"), + ("Disable web search", "/online off"), + ], + notes="Not all models support online mode.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + status = "enabled" if context.online_enabled else "disabled" + print_info(f"Online mode is {status}.") + return CommandResult.success() + + if args.lower() == "on": + if context.selected_model_raw: + params = context.selected_model_raw.get("supported_parameters", []) + if "tools" not in params: + print_warning("Current model may not support online mode.") + + context.online_enabled = True + print_success("Online mode enabled - AI can search the web.") + return CommandResult.success(data={"online_enabled": True}) + elif args.lower() == "off": + context.online_enabled = False + print_success("Online mode disabled.") + return CommandResult.success(data={"online_enabled": False}) + else: + print_error("Usage: /online [on|off]") + return CommandResult.error("Invalid argument") + + +class ResetCommand(Command): + """Reset conversation history.""" + + @property + def name(self) -> str: + return "/reset" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Clear conversation history and reset system prompt.", + usage="/reset", + notes="Requires confirmation. Resets all session metrics.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not prompt_confirm("Reset conversation and clear history?"): + print_info("Reset cancelled.") + return CommandResult.success() + + context.session_history.clear() + context.session_system_prompt = "" + context.memory_start_index = 0 + context.current_index = 0 + context.total_input_tokens = 0 + context.total_output_tokens = 0 + context.total_cost = 0.0 + context.message_count = 0 + + print_success("Conversation reset. 
History and system prompt cleared.") + get_logger().info("Conversation reset by user") + return CommandResult.success() + + +class StatsCommand(Command): + """Display session statistics.""" + + @property + def name(self) -> str: + return "/stats" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display session statistics.", + usage="/stats", + notes="Shows tokens, costs, and credits.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from rich.table import Table + + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + table.add_row("Input Tokens", f"{context.total_input_tokens:,}") + table.add_row("Output Tokens", f"{context.total_output_tokens:,}") + table.add_row( + "Total Tokens", + f"{context.total_input_tokens + context.total_output_tokens:,}", + ) + table.add_row("Total Cost", f"${context.total_cost:.4f}") + + if context.message_count > 0: + avg_cost = context.total_cost / context.message_count + table.add_row("Avg Cost/Message", f"${avg_cost:.4f}") + + table.add_row("Messages", str(context.message_count)) + + # Get credits if provider available + if context.provider: + credits = context.provider.get_credits() + if credits: + table.add_row("", "") + table.add_row("[bold]Credits[/]", "") + table.add_row( + "Credits Left", + credits.get("credits_left_formatted", "N/A"), + ) + + display_panel(table, title="[bold green]Session Statistics[/]") + return CommandResult.success() + + +class CreditsCommand(Command): + """Display account credits.""" + + @property + def name(self) -> str: + return "/credits" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display your OpenRouter account credits.", + usage="/credits", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.provider: + print_error("No provider configured.") + return CommandResult.error("No provider") + + credits = context.provider.get_credits() + if not credits: + print_error("Failed to fetch credits.") + return CommandResult.error("Failed to fetch credits") + + from rich.table import Table + + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + table.add_row("Total Credits", credits.get("total_credits_formatted", "N/A")) + table.add_row("Used Credits", credits.get("used_credits_formatted", "N/A")) + table.add_row("Credits Left", credits.get("credits_left_formatted", "N/A")) + + # Check for warnings + from oai.constants import LOW_CREDIT_AMOUNT, LOW_CREDIT_RATIO + + credits_left = credits.get("credits_left", 0) + total = credits.get("total_credits", 0) + + warnings = [] + if credits_left < LOW_CREDIT_AMOUNT: + warnings.append(f"Less than ${LOW_CREDIT_AMOUNT:.2f} remaining!") + elif total > 0 and credits_left < total * LOW_CREDIT_RATIO: + warnings.append("Less than 10% of credits remaining!") + + display_panel(table, title="[bold green]Account Credits[/]") + + for warning in warnings: + print_warning(warning) + + return CommandResult.success(data=credits) + + +class ExportCommand(Command): + """Export conversation to file.""" + + @property + def name(self) -> str: + return "/export" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Export conversation to a file.", + usage="/export ", + examples=[ + ("Export as Markdown", "/export md notes.md"), + ("Export as JSON", "/export json conversation.json"), + ("Export as HTML", "/export html report.html"), + ], + 
notes="Available formats: md, json, html", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + parts = args.split(maxsplit=1) + if len(parts) != 2: + print_error("Usage: /export ") + return CommandResult.error("Invalid arguments") + + fmt, filename = parts + fmt = fmt.lower() + + if fmt not in ["md", "json", "html"]: + print_error(f"Unknown format: {fmt}") + print_info("Available formats: md, json, html") + return CommandResult.error("Invalid format") + + if not context.session_history: + print_warning("No conversation to export.") + return CommandResult.error("Empty history") + + try: + if fmt == "md": + content = export_as_markdown( + context.session_history, + context.session_system_prompt, + ) + elif fmt == "json": + content = export_as_json( + context.session_history, + context.session_system_prompt, + ) + else: # html + content = export_as_html( + context.session_history, + context.session_system_prompt, + ) + + with open(filename, "w", encoding="utf-8") as f: + f.write(content) + + print_success(f"Exported to {filename}") + get_logger().info(f"Exported conversation to {filename} ({fmt})") + return CommandResult.success() + + except Exception as e: + print_error(f"Export failed: {e}") + return CommandResult.error(str(e)) + + +class MiddleOutCommand(Command): + """Toggle middle-out compression.""" + + @property + def name(self) -> str: + return "/middleout" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Toggle middle-out transform for long prompts.", + usage="/middleout [on|off]", + examples=[ + ("Check status", "/middleout"), + ("Enable", "/middleout on"), + ("Disable", "/middleout off"), + ], + notes="Compresses prompts exceeding context size.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + status = "enabled" if context.middle_out_enabled else "disabled" + print_info(f"Middle-out compression is {status}.") + return CommandResult.success() + + if args.lower() == "on": + context.middle_out_enabled = True + print_success("Middle-out compression enabled.") + return CommandResult.success(data={"middle_out_enabled": True}) + elif args.lower() == "off": + context.middle_out_enabled = False + print_success("Middle-out compression disabled.") + return CommandResult.success(data={"middle_out_enabled": False}) + else: + print_error("Usage: /middleout [on|off]") + return CommandResult.error("Invalid argument") + + +class MaxTokenCommand(Command): + """Set session token limit.""" + + @property + def name(self) -> str: + return "/maxtoken" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Set session token limit.", + usage="/maxtoken [value]", + examples=[ + ("View current", "/maxtoken"), + ("Set to 2000", "/maxtoken 2000"), + ], + notes="Cannot exceed stored max token setting.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + if context.session_max_token > 0: + print_info(f"Session max token: {context.session_max_token}") + else: + print_info("No session token limit set (using model default).") + return CommandResult.success() + + try: + value = int(args) + if value <= 0: + print_error("Token limit must be positive.") + return CommandResult.error("Invalid value") + + # Check against stored limit + stored_max = 100000 # Default + if context.settings: + stored_max = context.settings.max_tokens + + if value > stored_max: + print_warning( + f"Value {value} exceeds stored limit {stored_max}. 
" + f"Using {stored_max}." + ) + value = stored_max + + context.session_max_token = value + print_success(f"Session max token set to {value}.") + return CommandResult.success(data={"session_max_token": value}) + + except ValueError: + print_error("Please enter a valid number.") + return CommandResult.error("Invalid number") + + +class SystemCommand(Command): + """Set session system prompt.""" + + @property + def name(self) -> str: + return "/system" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Set or clear the session system prompt.", + usage="/system [prompt|clear|default ]", + examples=[ + ("View current", "/system"), + ("Set prompt", "/system You are a Python expert"), + ("Multiline prompt", r"/system You are an expert.\nRespond clearly."), + ("Blank prompt", '/system ""'), + ("Save as default", "/system default You are a Python expert"), + ("Revert to default", "/system clear"), + ], + notes=r'Use \n for newlines. Use /system "" for blank, /system clear to revert to default.', + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if not args: + # Show current session prompt and default + if context.session_system_prompt: + print_info(f"Session prompt: {context.session_system_prompt}") + else: + print_info("Session prompt: [blank]") + + if context.settings: + if context.settings.default_system_prompt is None: + print_info(f"Default prompt: [hardcoded] {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif context.settings.default_system_prompt == "": + print_info("Default prompt: [blank]") + else: + print_info(f"Custom default: {context.settings.default_system_prompt}") + return CommandResult.success() + + if args.lower() == "clear": + # Revert to hardcoded default + if context.settings: + context.settings.clear_default_system_prompt() + context.session_system_prompt = DEFAULT_SYSTEM_PROMPT + print_success("Reverted to hardcoded default system prompt.") + print_info(f"Default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + else: + context.session_system_prompt = DEFAULT_SYSTEM_PROMPT + print_success("Session prompt reverted to default.") + return CommandResult.success() + + # Check for default command + if args.lower().startswith("default "): + prompt = args[8:] # Remove "default " prefix (keep trailing spaces) + if not prompt: + print_error("Usage: /system default ") + return CommandResult.error("No prompt provided") + + # Decode escape sequences like \n for newlines + prompt = prompt.encode().decode('unicode_escape') + + if context.settings: + context.settings.set_default_system_prompt(prompt) + context.session_system_prompt = prompt + if prompt: + print_success(f"Default system prompt saved: {prompt}") + else: + print_success("Default system prompt set to blank.") + get_logger().info(f"Default system prompt updated: {prompt[:50]}...") + else: + print_error("Settings not available") + return CommandResult.error("No settings") + return CommandResult.success() + + # Decode escape sequences like \n for newlines + prompt = args.encode().decode('unicode_escape') + + context.session_system_prompt = prompt + if prompt: + print_success(f"Session system prompt set: {prompt}") + print_info("Use '/system default ' to save as default for all sessions") + else: + print_success("Session system prompt set to blank.") + get_logger().info(f"System prompt updated: {prompt[:50]}...") + return CommandResult.success() + + +class RetryCommand(Command): + """Resend the last prompt.""" + + @property + def name(self) 
-> str: + return "/retry" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Resend the last prompt.", + usage="/retry", + notes="Requires at least one message in history.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No message to retry.") + return CommandResult.error("Empty history") + + last_prompt = context.session_history[-1].get("prompt", "") + if not last_prompt: + print_error("Last message has no prompt.") + return CommandResult.error("No prompt") + + # Return the prompt to be re-sent + print_info(f"Retrying: {last_prompt[:50]}...") + return CommandResult.success(data={"retry_prompt": last_prompt}) + + +class PrevCommand(Command): + """View previous response.""" + + @property + def name(self) -> str: + return "/prev" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View previous response in history.", + usage="/prev", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No history to navigate.") + return CommandResult.error("Empty history") + + if context.current_index <= 0: + print_info("Already at the beginning of history.") + return CommandResult.success() + + context.current_index -= 1 + entry = context.session_history[context.current_index] + + console.print(f"\n[dim]Message {context.current_index + 1}/{len(context.session_history)}[/]") + console.print(f"\n[bold cyan]Prompt:[/] {entry.get('prompt', '')}") + display_markdown(entry.get("response", ""), panel=True, title="Response") + + return CommandResult.success() + + +class NextCommand(Command): + """View next response.""" + + @property + def name(self) -> str: + return "/next" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View next response in history.", + usage="/next", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No history to navigate.") + return CommandResult.error("Empty history") + + if context.current_index >= len(context.session_history) - 1: + print_info("Already at the end of history.") + return CommandResult.success() + + context.current_index += 1 + entry = context.session_history[context.current_index] + + console.print(f"\n[dim]Message {context.current_index + 1}/{len(context.session_history)}[/]") + console.print(f"\n[bold cyan]Prompt:[/] {entry.get('prompt', '')}") + display_markdown(entry.get("response", ""), panel=True, title="Response") + + return CommandResult.success() + + +class ConfigCommand(Command): + """View or modify configuration.""" + + @property + def name(self) -> str: + return "/config" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View or modify application configuration.", + usage="/config [setting] [value]", + examples=[ + ("View all settings", "/config"), + ("Set API key", "/config api"), + ("Enable streaming", "/config stream on"), + ], + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from rich.table import Table + + if not context.settings: + print_error("Settings not available") + return CommandResult.error("No settings") + + settings = context.settings + + if not args: + # Show all settings + from oai.constants import DEFAULT_SYSTEM_PROMPT + + table = Table("Setting", "Value", show_header=True, header_style="bold magenta") + table.add_row("API Key", "***" + 
settings.api_key[-4:] if settings.api_key else "Not set") + table.add_row("Base URL", settings.base_url) + table.add_row("Default Model", settings.default_model or "Not set") + + # Show system prompt status + if settings.default_system_prompt is None: + system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..." + elif settings.default_system_prompt == "": + system_prompt_display = "[blank]" + else: + system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt + table.add_row("System Prompt", system_prompt_display) + + table.add_row("Streaming", "on" if settings.stream_enabled else "off") + table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}") + table.add_row("Max Tokens", str(settings.max_tokens)) + table.add_row("Default Online", "on" if settings.default_online_mode else "off") + table.add_row("Log Level", settings.log_level) + + display_panel(table, title="[bold green]Configuration[/]") + return CommandResult.success() + + parts = args.split(maxsplit=1) + setting = parts[0].lower() + value = parts[1] if len(parts) > 1 else None + + if setting == "api": + if value: + settings.set_api_key(value) + else: + new_key = prompt_input("Enter API key", password=True) + if new_key: + settings.set_api_key(new_key) + print_success("API key updated") + + elif setting == "stream": + if value and value.lower() in ["on", "off"]: + settings.set_stream_enabled(value.lower() == "on") + print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}") + else: + print_info("Usage: /config stream [on|off]") + + elif setting == "model": + if value: + # Show model selector with search term, same as /model + return CommandResult.success(data={"show_model_selector": True, "search": value, "set_as_default": True}) + else: + print_info(f"Current: {settings.default_model or 'Not set'}") + + elif setting == "system": + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if value: + # Decode escape sequences like \n for newlines + value = value.encode().decode('unicode_escape') + settings.set_default_system_prompt(value) + if value: + print_success(f"Default system prompt set to: {value}") + else: + print_success("Default system prompt set to blank.") + else: + if settings.default_system_prompt is None: + print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif settings.default_system_prompt == "": + print_info("System prompt: [blank]") + else: + print_info(f"System prompt: {settings.default_system_prompt}") + + elif setting == "online": + if value and value.lower() in ["on", "off"]: + settings.set_default_online_mode(value.lower() == "on") + print_success(f"Default online {'enabled' if settings.default_online_mode else 'disabled'}") + else: + print_info("Usage: /config online [on|off]") + + elif setting == "costwarning": + if value: + try: + settings.set_cost_warning_threshold(float(value)) + print_success(f"Cost warning set to: ${float(value):.4f}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current: ${settings.cost_warning_threshold:.4f}") + + elif setting == "maxtoken": + if value: + try: + settings.set_max_tokens(int(value)) + print_success(f"Max tokens set to: {value}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current: {settings.max_tokens}") + + elif setting == "loglevel": + valid_levels = ["debug", "info", "warning", "error", "critical"] + if value and value.lower() 
in valid_levels:
+                settings.set_log_level(value.lower())
+                print_success(f"Log level set to: {value.lower()}")
+            else:
+                print_info(f"Valid levels: {', '.join(valid_levels)}")
+
+        else:
+            print_error(f"Unknown setting: {setting}")
+            return CommandResult.error("Unknown setting")
+
+        return CommandResult.success()
+
+
+class ListCommand(Command):
+    """List saved conversations."""
+
+    @property
+    def name(self) -> str:
+        return "/list"
+
+    @property
+    def help(self) -> CommandHelp:
+        return CommandHelp(
+            description="List all saved conversations.",
+            usage="/list",
+        )
+
+    def execute(self, args: str, context: CommandContext) -> CommandResult:
+        from oai.config.database import Database
+        from rich.table import Table
+
+        db = Database()
+        conversations = db.list_conversations()
+
+        if not conversations:
+            print_info("No saved conversations.")
+            return CommandResult.success()
+
+        table = Table("No.", "Name", "Messages", "Last Saved", show_header=True, header_style="bold magenta")
+
+        for i, conv in enumerate(conversations, 1):
+            table.add_row(
+                str(i),
+                conv["name"],
+                str(conv["message_count"]),
+                conv["timestamp"][:19] if conv.get("timestamp") else "-",
+            )
+
+        display_panel(table, title="[bold green]Saved Conversations[/]")
+        return CommandResult.success(data={"conversations": conversations})
+
+
+class SaveCommand(Command):
+    """Save conversation."""
+
+    @property
+    def name(self) -> str:
+        return "/save"
+
+    @property
+    def help(self) -> CommandHelp:
+        return CommandHelp(
+            description="Save the current conversation.",
+            usage="/save <name>",
+        )
+
+    def execute(self, args: str, context: CommandContext) -> CommandResult:
+        if not args:
+            print_error("Usage: /save <name>")
+            return CommandResult.error("Missing name")
+
+        if not context.session_history:
+            print_warning("No conversation to save.")
+            return CommandResult.error("Empty history")
+
+        from oai.config.database import Database
+
+        db = Database()
+        db.save_conversation(args, context.session_history)
+        print_success(f"Conversation saved as '{args}'")
+        get_logger().info(f"Saved conversation: {args}")
+        return CommandResult.success()
+
+
+class LoadCommand(Command):
+    """Load saved conversation."""
+
+    @property
+    def name(self) -> str:
+        return "/load"
+
+    @property
+    def help(self) -> CommandHelp:
+        return CommandHelp(
+            description="Load a saved conversation.",
+            usage="/load <name or number>",
+        )
+
+    def execute(self, args: str, context: CommandContext) -> CommandResult:
+        if not args:
+            print_error("Usage: /load <name or number>")
+            return CommandResult.error("Missing name")
+
+        from oai.config.database import Database
+
+        db = Database()
+
+        # Check if it's a number
+        name = args
+        if args.isdigit():
+            conversations = db.list_conversations()
+            index = int(args) - 1
+            if 0 <= index < len(conversations):
+                name = conversations[index]["name"]
+            else:
+                print_error(f"Invalid number. Use 1-{len(conversations)}")
+                return CommandResult.error("Invalid number")
+
+        data = db.load_conversation(name)
+        if not data:
+            print_error(f"Conversation '{name}' not found")
+            return CommandResult.error("Not found")
+
+        print_success(f"Loaded conversation '{name}' ({len(data)} messages)")
+        get_logger().info(f"Loaded conversation: {name}")
+        return CommandResult.success(data={"load_conversation": name, "history": data})
+
+
+class DeleteCommand(Command):
+    """Delete saved conversation."""
+
+    @property
+    def name(self) -> str:
+        return "/delete"
+
+    @property
+    def help(self) -> CommandHelp:
+        return CommandHelp(
+            description="Delete a saved conversation.",
+            usage="/delete <name or number>",
+        )
+
+    def execute(self, args: str, context: CommandContext) -> CommandResult:
+        if not args:
+            print_error("Usage: /delete <name or number>")
+            return CommandResult.error("Missing name")
+
+        from oai.config.database import Database
+
+        db = Database()
+
+        # Check if it's a number
+        name = args
+        if args.isdigit():
+            conversations = db.list_conversations()
+            index = int(args) - 1
+            if 0 <= index < len(conversations):
+                name = conversations[index]["name"]
+            else:
+                print_error(f"Invalid number. Use 1-{len(conversations)}")
+                return CommandResult.error("Invalid number")
+
+        if not prompt_confirm(f"Delete conversation '{name}'?"):
+            print_info("Cancelled")
+            return CommandResult.success()
+
+        count = db.delete_conversation(name)
+        if count > 0:
+            print_success(f"Deleted conversation '{name}'")
+            get_logger().info(f"Deleted conversation: {name}")
+        else:
+            print_error(f"Conversation '{name}' not found")
+            return CommandResult.error("Not found")
+
+        return CommandResult.success()
+
+
+class InfoCommand(Command):
+    """Show model information."""
+
+    @property
+    def name(self) -> str:
+        return "/info"
+
+    @property
+    def help(self) -> CommandHelp:
+        return CommandHelp(
+            description="Display detailed model information.",
+            usage="/info [model_id]",
+        )
+
+    def execute(self, args: str, context: CommandContext) -> CommandResult:
+        from rich.table import Table
+
+        if not context.provider:
+            print_error("No provider available")
+            return CommandResult.error("No provider")
+
+        model_id = args.strip() if args else None
+
+        if not model_id and context.selected_model_raw:
+            model_id = context.selected_model_raw.get("id")
+
+        if not model_id:
+            print_error("No model specified. 
Use /info or select a model first.") + return CommandResult.error("No model") + + # Get raw model data + model = None + if hasattr(context.provider, "get_raw_model"): + model = context.provider.get_raw_model(model_id) + + if not model: + print_error(f"Model '{model_id}' not found") + return CommandResult.error("Not found") + + table = Table("Property", "Value", show_header=True, header_style="bold magenta") + table.add_row("ID", model.get("id", "")) + table.add_row("Name", model.get("name", "")) + table.add_row("Context Length", f"{model.get('context_length', 0):,}") + + # Pricing + pricing = model.get("pricing", {}) + if pricing: + prompt_price = float(pricing.get("prompt", 0)) * 1_000_000 + completion_price = float(pricing.get("completion", 0)) * 1_000_000 + table.add_row("Input Price", f"${prompt_price:.2f}/M tokens") + table.add_row("Output Price", f"${completion_price:.2f}/M tokens") + + # Capabilities + arch = model.get("architecture", {}) + input_mod = arch.get("input_modalities", []) + output_mod = arch.get("output_modalities", []) + supported = model.get("supported_parameters", []) + + table.add_row("Input Modalities", ", ".join(input_mod) if input_mod else "text") + table.add_row("Output Modalities", ", ".join(output_mod) if output_mod else "text") + table.add_row("Image Support", "✓" if "image" in input_mod else "✗") + table.add_row("Tool Support", "✓" if "tools" in supported else "✗") + table.add_row("Online Support", "✓" if "tools" in supported else "✗") + + display_panel(table, title=f"[bold green]Model: {model_id}[/]") + return CommandResult.success() + + +class MCPCommand(Command): + """MCP management command.""" + + @property + def name(self) -> str: + return "/mcp" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Manage MCP (Model Context Protocol).", + usage="/mcp [args]", + examples=[ + ("Enable MCP", "/mcp on"), + ("Show status", "/mcp status"), + ("Add folder", "/mcp add ~/Documents"), + ], + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.mcp_manager: + print_error("MCP not available") + return CommandResult.error("No MCP") + + mcp = context.mcp_manager + parts = args.strip().split(maxsplit=1) + cmd = parts[0].lower() if parts else "" + cmd_args = parts[1] if len(parts) > 1 else "" + + if cmd in ["on", "enable"]: + result = mcp.enable() + if result["success"]: + print_success(result.get("message", "MCP enabled")) + if result.get("folder_count", 0) == 0: + print_info("Add folders with: /mcp add ~/Documents") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd in ["off", "disable"]: + result = mcp.disable() + if result["success"]: + print_success("MCP disabled") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "status": + result = mcp.get_status() + if result["success"]: + from rich.table import Table + table = Table("Property", "Value", show_header=True, header_style="bold magenta") + + status = "[green]Active ✓[/]" if result["enabled"] else "[red]Inactive ✗[/]" + table.add_row("Status", status) + table.add_row("Mode", result.get("mode_info", {}).get("mode_display", "files")) + table.add_row("Folders", str(result.get("folder_count", 0))) + table.add_row("Databases", str(result.get("database_count", 0))) + table.add_row("Write Mode", "[green]Enabled[/]" if result.get("write_enabled") else "[dim]Disabled[/]") + table.add_row(".gitignore", result.get("gitignore_status", "on")) + + 
display_panel(table, title="[bold green]MCP Status[/]") + return CommandResult.success() + + elif cmd == "add": + if cmd_args.startswith("db "): + db_path = cmd_args[3:].strip() + result = mcp.add_database(db_path) + else: + result = mcp.add_folder(cmd_args) + + if result["success"]: + print_success(f"Added: {cmd_args}") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd in ["remove", "rem"]: + if cmd_args.startswith("db "): + result = mcp.remove_database(cmd_args[3:].strip()) + else: + result = mcp.remove_folder(cmd_args) + + if result["success"]: + print_success(f"Removed: {cmd_args}") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "list": + result = mcp.list_folders() + if result["success"]: + from rich.table import Table + table = Table("No.", "Path", "Files", "Size", show_header=True, header_style="bold magenta") + for f in result.get("folders", []): + table.add_row( + str(f["number"]), + f["path"], + str(f.get("file_count", 0)), + f"{f.get('size_mb', 0):.1f} MB", + ) + display_panel(table, title="[bold green]MCP Folders[/]") + return CommandResult.success() + + elif cmd == "db": + if cmd_args == "list": + result = mcp.list_databases() + if result["success"]: + from rich.table import Table + table = Table("No.", "Name", "Tables", "Size", show_header=True, header_style="bold magenta") + for db in result.get("databases", []): + table.add_row( + str(db["number"]), + db["name"], + str(db.get("table_count", 0)), + f"{db.get('size_mb', 0):.1f} MB", + ) + display_panel(table, title="[bold green]MCP Databases[/]") + elif cmd_args.isdigit(): + result = mcp.switch_mode("database", int(cmd_args)) + if result["success"]: + print_success(result.get("message", "Switched to database mode")) + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "files": + result = mcp.switch_mode("files") + if result["success"]: + print_success("Switched to file mode") + return CommandResult.success() + + elif cmd == "write": + if cmd_args.lower() == "on": + mcp.enable_write() + elif cmd_args.lower() == "off": + mcp.disable_write() + else: + status = "[green]Enabled[/]" if mcp.write_enabled else "[dim]Disabled[/]" + print_info(f"Write mode: {status}") + return CommandResult.success() + + elif cmd == "gitignore": + if cmd_args.lower() in ["on", "off"]: + result = mcp.toggle_gitignore(cmd_args.lower() == "on") + if result["success"]: + print_success(result.get("message", "Updated")) + else: + print_info("Usage: /mcp gitignore [on|off]") + return CommandResult.success() + + else: + print_error(f"Unknown MCP command: {cmd}") + print_info("Commands: on, off, status, add, remove, list, db, files, write, gitignore") + return CommandResult.error("Unknown command") + + +class PasteCommand(Command): + """Paste from clipboard.""" + + @property + def name(self) -> str: + return "/paste" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Paste from clipboard and send to AI.", + usage="/paste [prompt]", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + try: + import pyperclip + content = pyperclip.paste() + except ImportError: + print_error("pyperclip not installed") + return CommandResult.error("pyperclip not installed") + except Exception as e: + print_error(f"Failed to access clipboard: {e}") + return CommandResult.error(str(e)) + + if not content: + print_warning("Clipboard is empty") + return 
CommandResult.error("Empty clipboard") + + # Show preview + preview = content[:200] + "..." if len(content) > 200 else content + console.print(f"\n[dim]Clipboard content ({len(content)} chars):[/]") + console.print(f"[yellow]{preview}[/]\n") + + # Build the prompt + if args: + full_prompt = f"{args}\n\n```\n{content}\n```" + else: + full_prompt = content + + return CommandResult.success(data={"paste_prompt": full_prompt}) + + +class ModelCommand(Command): + """Select AI model.""" + + @property + def name(self) -> str: + return "/model" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Select or search for AI models.", + usage="/model [search_term]", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + # This is handled specially in the CLI, but we need a handler + # to prevent it from being sent to the AI + return CommandResult.success(data={"show_model_selector": True, "search": args}) + + +def register_all_commands() -> None: + """Register all built-in commands with the global registry.""" + commands = [ + HelpCommand(), + ClearCommand(), + MemoryCommand(), + OnlineCommand(), + ResetCommand(), + StatsCommand(), + CreditsCommand(), + ExportCommand(), + MiddleOutCommand(), + MaxTokenCommand(), + SystemCommand(), + RetryCommand(), + PrevCommand(), + NextCommand(), + ConfigCommand(), + ListCommand(), + SaveCommand(), + LoadCommand(), + DeleteCommand(), + InfoCommand(), + MCPCommand(), + PasteCommand(), + ModelCommand(), + ] + + for command in commands: + try: + registry.register(command) + except ValueError as e: + get_logger().warning(f"Failed to register command: {e}") diff --git a/oai/commands/registry.py b/oai/commands/registry.py new file mode 100644 index 0000000..9be81ad --- /dev/null +++ b/oai/commands/registry.py @@ -0,0 +1,381 @@ +""" +Command registry for oAI. + +This module defines the command system infrastructure including +the Command base class, CommandContext for state, and CommandRegistry +for managing available commands. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING + +from oai.utils.logging import get_logger + +if TYPE_CHECKING: + from oai.config.settings import Settings + from oai.providers.base import AIProvider, ModelInfo + from oai.mcp.manager import MCPManager + + +class CommandStatus(str, Enum): + """Status of command execution.""" + + SUCCESS = "success" + ERROR = "error" + CONTINUE = "continue" # Continue to next handler + EXIT = "exit" # Exit the application + + +@dataclass +class CommandResult: + """ + Result of a command execution. 
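+    The success(), error(), and exit() class methods are convenience constructors.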
+ + Attributes: + status: Execution status + message: Optional message to display + data: Optional data payload + should_continue: Whether to continue the main loop + """ + + status: CommandStatus = CommandStatus.SUCCESS + message: Optional[str] = None + data: Optional[Any] = None + should_continue: bool = True + + @classmethod + def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult": + """Create a success result.""" + return cls(status=CommandStatus.SUCCESS, message=message, data=data) + + @classmethod + def error(cls, message: str) -> "CommandResult": + """Create an error result.""" + return cls(status=CommandStatus.ERROR, message=message) + + @classmethod + def exit(cls, message: Optional[str] = None) -> "CommandResult": + """Create an exit result.""" + return cls(status=CommandStatus.EXIT, message=message, should_continue=False) + + +@dataclass +class CommandContext: + """ + Context object providing state to command handlers. + + Contains all the session state needed by commands including + settings, provider, conversation history, and MCP manager. + + Attributes: + settings: Application settings + provider: AI provider instance + mcp_manager: MCP manager instance + selected_model: Currently selected model + session_history: Conversation history + session_system_prompt: Current system prompt + memory_enabled: Whether memory is enabled + online_enabled: Whether online mode is enabled + session_tokens: Session token counts + session_cost: Session cost total + """ + + settings: Optional["Settings"] = None + provider: Optional["AIProvider"] = None + mcp_manager: Optional["MCPManager"] = None + selected_model: Optional["ModelInfo"] = None + selected_model_raw: Optional[Dict[str, Any]] = None + session_history: List[Dict[str, Any]] = field(default_factory=list) + session_system_prompt: str = "" + memory_enabled: bool = True + memory_start_index: int = 0 + online_enabled: bool = False + middle_out_enabled: bool = False + session_max_token: int = 0 + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_cost: float = 0.0 + message_count: int = 0 + current_index: int = 0 + + +@dataclass +class CommandHelp: + """ + Help information for a command. + + Attributes: + description: Brief description + usage: Usage syntax + examples: List of (description, example) tuples + notes: Additional notes + aliases: Command aliases + """ + + description: str + usage: str = "" + examples: List[tuple] = field(default_factory=list) + notes: str = "" + aliases: List[str] = field(default_factory=list) + + +class Command(ABC): + """ + Abstract base class for all commands. + + Commands implement the execute method to handle their logic. + They can also provide help information and aliases. + """ + + @property + @abstractmethod + def name(self) -> str: + """Get the primary command name (e.g., '/help').""" + pass + + @property + def aliases(self) -> List[str]: + """Get command aliases (e.g., ['/h'] for help).""" + return [] + + @property + @abstractmethod + def help(self) -> CommandHelp: + """Get command help information.""" + pass + + @abstractmethod + def execute(self, args: str, context: CommandContext) -> CommandResult: + """ + Execute the command. + + Args: + args: Arguments passed to the command + context: Command execution context + + Returns: + CommandResult indicating success/failure + """ + pass + + def matches(self, input_text: str) -> bool: + """ + Check if this command matches the input. 
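+        The first word of the input is compared against the command name and its aliases (case-insensitively).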
+ + Args: + input_text: User input text + + Returns: + True if this command should handle the input + """ + input_lower = input_text.lower() + cmd_word = input_lower.split()[0] if input_lower.split() else "" + + # Check primary name + if cmd_word == self.name.lower(): + return True + + # Check aliases + for alias in self.aliases: + if cmd_word == alias.lower(): + return True + + return False + + def get_args(self, input_text: str) -> str: + """ + Extract arguments from the input text. + + Args: + input_text: Full user input + + Returns: + Arguments portion of the input + """ + parts = input_text.split(maxsplit=1) + return parts[1] if len(parts) > 1 else "" + + +class CommandRegistry: + """ + Registry for managing available commands. + + Provides registration, lookup, and execution of commands. + """ + + def __init__(self): + """Initialize an empty command registry.""" + self._commands: Dict[str, Command] = {} + self._aliases: Dict[str, str] = {} + self.logger = get_logger() + + def register(self, command: Command) -> None: + """ + Register a command. + + Args: + command: Command instance to register + + Raises: + ValueError: If command name already registered + """ + name = command.name.lower() + + if name in self._commands: + raise ValueError(f"Command '{name}' already registered") + + self._commands[name] = command + + # Register aliases + for alias in command.aliases: + alias_lower = alias.lower() + if alias_lower in self._aliases: + self.logger.warning( + f"Alias '{alias}' already registered, overwriting" + ) + self._aliases[alias_lower] = name + + self.logger.debug(f"Registered command: {name}") + + def register_function( + self, + name: str, + handler: Callable[[str, CommandContext], CommandResult], + description: str, + usage: str = "", + aliases: Optional[List[str]] = None, + examples: Optional[List[tuple]] = None, + notes: str = "", + ) -> None: + """ + Register a function-based command. + + Convenience method for simple commands that don't need + a full Command class. + + Args: + name: Command name (e.g., '/help') + handler: Function to execute + description: Help description + usage: Usage syntax + aliases: Command aliases + examples: Example usages + notes: Additional notes + """ + aliases = aliases or [] + examples = examples or [] + + class FunctionCommand(Command): + @property + def name(self) -> str: + return name + + @property + def aliases(self) -> List[str]: + return aliases + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description=description, + usage=usage, + examples=examples, + notes=notes, + aliases=aliases, + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + return handler(args, context) + + self.register(FunctionCommand()) + + def get(self, name: str) -> Optional[Command]: + """ + Get a command by name or alias. + + Args: + name: Command name or alias + + Returns: + Command instance or None if not found + """ + name_lower = name.lower() + + # Check direct match + if name_lower in self._commands: + return self._commands[name_lower] + + # Check aliases + if name_lower in self._aliases: + return self._commands[self._aliases[name_lower]] + + return None + + def find(self, input_text: str) -> Optional[Command]: + """ + Find a command that matches the input. 
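+        Lookup is delegated to get(), so aliases are resolved as well.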
+ + Args: + input_text: User input text + + Returns: + Matching Command or None + """ + cmd_word = input_text.lower().split()[0] if input_text.split() else "" + return self.get(cmd_word) + + def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]: + """ + Execute a command matching the input. + + Args: + input_text: User input text + context: Execution context + + Returns: + CommandResult or None if no matching command + """ + command = self.find(input_text) + if command: + args = command.get_args(input_text) + self.logger.debug(f"Executing command: {command.name} with args: {args}") + return command.execute(args, context) + return None + + def is_command(self, input_text: str) -> bool: + """ + Check if input is a valid command. + + Args: + input_text: User input text + + Returns: + True if input matches a registered command + """ + return self.find(input_text) is not None + + def list_commands(self) -> List[Command]: + """ + Get all registered commands. + + Returns: + List of Command instances + """ + return list(self._commands.values()) + + def get_all_names(self) -> List[str]: + """ + Get all command names and aliases. + + Returns: + List of command names including aliases + """ + names = list(self._commands.keys()) + names.extend(self._aliases.keys()) + return sorted(set(names)) + + +# Global registry instance +registry = CommandRegistry() diff --git a/oai/config/__init__.py b/oai/config/__init__.py new file mode 100644 index 0000000..d12b503 --- /dev/null +++ b/oai/config/__init__.py @@ -0,0 +1,11 @@ +""" +Configuration management for oAI. + +This package handles all configuration persistence, settings management, +and database operations for the application. +""" + +from oai.config.settings import Settings +from oai.config.database import Database + +__all__ = ["Settings", "Database"] diff --git a/oai/config/database.py b/oai/config/database.py new file mode 100644 index 0000000..a6a3496 --- /dev/null +++ b/oai/config/database.py @@ -0,0 +1,472 @@ +""" +Database persistence layer for oAI. + +This module provides a clean abstraction for SQLite operations including +configuration storage, conversation persistence, and MCP statistics tracking. +All database operations are centralized here for maintainability. +""" + +import sqlite3 +import json +import datetime +from pathlib import Path +from typing import Optional, List, Dict, Any +from contextlib import contextmanager + +from oai.constants import DATABASE_FILE, CONFIG_DIR + + +class Database: + """ + SQLite database manager for oAI. + + Handles all database operations including: + - Configuration key-value storage + - Conversation session persistence + - MCP configuration and statistics + - Database registrations for MCP + + Uses context managers for safe connection handling and supports + automatic table creation on first use. + """ + + def __init__(self, db_path: Optional[Path] = None): + """ + Initialize the database manager. + + Args: + db_path: Optional custom database path. Defaults to standard location. + """ + self.db_path = db_path or DATABASE_FILE + self._ensure_directories() + self._ensure_tables() + + def _ensure_directories(self) -> None: + """Ensure the configuration directory exists.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + @contextmanager + def _connection(self): + """ + Context manager for database connections. 
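+        Commits on success, rolls back on exceptions, and always closes the connection.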
+ + Yields: + sqlite3.Connection: Active database connection + + Example: + with self._connection() as conn: + conn.execute("SELECT * FROM config") + """ + conn = sqlite3.connect(str(self.db_path)) + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + def _ensure_tables(self) -> None: + """Create all required tables if they don't exist.""" + with self._connection() as conn: + # Main configuration table + conn.execute(""" + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + """) + + # Conversation sessions table + conn.execute(""" + CREATE TABLE IF NOT EXISTS conversation_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + timestamp TEXT NOT NULL, + data TEXT NOT NULL + ) + """) + + # MCP configuration table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + """) + + # MCP statistics table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + tool_name TEXT NOT NULL, + folder TEXT, + success INTEGER NOT NULL, + error_message TEXT + ) + """) + + # MCP databases table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_databases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + path TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, + size INTEGER, + tables TEXT, + added_timestamp TEXT NOT NULL + ) + """) + + # ========================================================================= + # CONFIGURATION METHODS + # ========================================================================= + + def get_config(self, key: str) -> Optional[str]: + """ + Retrieve a configuration value by key. + + Args: + key: The configuration key to retrieve + + Returns: + The configuration value, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + "SELECT value FROM config WHERE key = ?", + (key,) + ) + result = cursor.fetchone() + return result[0] if result else None + + def set_config(self, key: str, value: str) -> None: + """ + Set a configuration value. + + Args: + key: The configuration key + value: The value to store + """ + with self._connection() as conn: + conn.execute( + "INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", + (key, value) + ) + + def delete_config(self, key: str) -> bool: + """ + Delete a configuration value. + + Args: + key: The configuration key to delete + + Returns: + True if a row was deleted, False otherwise + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM config WHERE key = ?", + (key,) + ) + return cursor.rowcount > 0 + + def get_all_config(self) -> Dict[str, str]: + """ + Retrieve all configuration values. + + Returns: + Dictionary of all key-value pairs + """ + with self._connection() as conn: + cursor = conn.execute("SELECT key, value FROM config") + return dict(cursor.fetchall()) + + # ========================================================================= + # MCP CONFIGURATION METHODS + # ========================================================================= + + def get_mcp_config(self, key: str) -> Optional[str]: + """ + Retrieve an MCP configuration value. 
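+        MCP settings are kept in the separate mcp_config table, not the main config table.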
+ + Args: + key: The MCP configuration key + + Returns: + The configuration value, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + "SELECT value FROM mcp_config WHERE key = ?", + (key,) + ) + result = cursor.fetchone() + return result[0] if result else None + + def set_mcp_config(self, key: str, value: str) -> None: + """ + Set an MCP configuration value. + + Args: + key: The MCP configuration key + value: The value to store + """ + with self._connection() as conn: + conn.execute( + "INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)", + (key, value) + ) + + # ========================================================================= + # MCP STATISTICS METHODS + # ========================================================================= + + def log_mcp_stat( + self, + tool_name: str, + folder: Optional[str], + success: bool, + error_message: Optional[str] = None + ) -> None: + """ + Log an MCP tool usage event. + + Args: + tool_name: Name of the MCP tool that was called + folder: The folder path involved (if any) + success: Whether the call succeeded + error_message: Error message if the call failed + """ + timestamp = datetime.datetime.now().isoformat() + with self._connection() as conn: + conn.execute( + """INSERT INTO mcp_stats + (timestamp, tool_name, folder, success, error_message) + VALUES (?, ?, ?, ?, ?)""", + (timestamp, tool_name, folder, 1 if success else 0, error_message) + ) + + def get_mcp_stats(self) -> Dict[str, Any]: + """ + Get aggregated MCP usage statistics. + + Returns: + Dictionary containing usage statistics: + - total_calls: Total number of tool calls + - reads: Number of file reads + - lists: Number of directory listings + - searches: Number of file searches + - db_inspects: Number of database inspections + - db_searches: Number of database searches + - db_queries: Number of database queries + - last_used: Timestamp of last usage + """ + with self._connection() as conn: + cursor = conn.execute(""" + SELECT + COUNT(*) as total_calls, + SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads, + SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists, + SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches, + SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects, + SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches, + SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries, + MAX(timestamp) as last_used + FROM mcp_stats + """) + row = cursor.fetchone() + return { + "total_calls": row[0] or 0, + "reads": row[1] or 0, + "lists": row[2] or 0, + "searches": row[3] or 0, + "db_inspects": row[4] or 0, + "db_searches": row[5] or 0, + "db_queries": row[6] or 0, + "last_used": row[7], + } + + # ========================================================================= + # MCP DATABASE REGISTRY METHODS + # ========================================================================= + + def add_mcp_database(self, db_info: Dict[str, Any]) -> int: + """ + Register a database for MCP access. 
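+
+        Example (illustrative values; keys match the Args description):
+            db.add_mcp_database({
+                "path": "/home/user/app.db", "name": "app.db", "size": 40960,
+                "tables": ["users", "orders"], "added": "2026-01-06T12:00:00",
+            })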
+ + Args: + db_info: Dictionary containing: + - path: Database file path + - name: Display name + - size: File size in bytes + - tables: List of table names + - added: Timestamp when added + + Returns: + The database ID + """ + with self._connection() as conn: + conn.execute( + """INSERT INTO mcp_databases + (path, name, size, tables, added_timestamp) + VALUES (?, ?, ?, ?, ?)""", + ( + db_info["path"], + db_info["name"], + db_info["size"], + json.dumps(db_info["tables"]), + db_info["added"] + ) + ) + cursor = conn.execute( + "SELECT id FROM mcp_databases WHERE path = ?", + (db_info["path"],) + ) + return cursor.fetchone()[0] + + def remove_mcp_database(self, db_path: str) -> bool: + """ + Remove a database from the MCP registry. + + Args: + db_path: Path to the database file + + Returns: + True if a row was deleted, False otherwise + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM mcp_databases WHERE path = ?", + (db_path,) + ) + return cursor.rowcount > 0 + + def get_mcp_databases(self) -> List[Dict[str, Any]]: + """ + Retrieve all registered MCP databases. + + Returns: + List of database information dictionaries + """ + with self._connection() as conn: + cursor = conn.execute( + """SELECT id, path, name, size, tables, added_timestamp + FROM mcp_databases ORDER BY id""" + ) + databases = [] + for row in cursor.fetchall(): + tables_list = json.loads(row[4]) if row[4] else [] + databases.append({ + "id": row[0], + "path": row[1], + "name": row[2], + "size": row[3], + "tables": tables_list, + "added": row[5], + }) + return databases + + # ========================================================================= + # CONVERSATION METHODS + # ========================================================================= + + def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None: + """ + Save a conversation session. + + Args: + name: Name/identifier for the conversation + data: List of message dictionaries with 'prompt' and 'response' keys + """ + timestamp = datetime.datetime.now().isoformat() + data_json = json.dumps(data) + with self._connection() as conn: + conn.execute( + """INSERT INTO conversation_sessions + (name, timestamp, data) VALUES (?, ?, ?)""", + (name, timestamp, data_json) + ) + + def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]: + """ + Load a conversation by name. + + Args: + name: Name of the conversation to load + + Returns: + List of messages, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + """SELECT data FROM conversation_sessions + WHERE name = ? + ORDER BY timestamp DESC LIMIT 1""", + (name,) + ) + result = cursor.fetchone() + if result: + return json.loads(result[0]) + return None + + def delete_conversation(self, name: str) -> int: + """ + Delete a conversation by name. + + Args: + name: Name of the conversation to delete + + Returns: + Number of rows deleted + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM conversation_sessions WHERE name = ?", + (name,) + ) + return cursor.rowcount + + def list_conversations(self) -> List[Dict[str, Any]]: + """ + List all saved conversations. 
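+
+        Example (illustrative return value):
+            [{"name": "my_chat", "timestamp": "2026-01-06T12:00:00",
+              "message_count": 4}]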
+ + Returns: + List of conversation summaries with name, timestamp, and message_count + """ + with self._connection() as conn: + cursor = conn.execute(""" + SELECT name, MAX(timestamp) as last_saved, data + FROM conversation_sessions + GROUP BY name + ORDER BY last_saved DESC + """) + conversations = [] + for row in cursor.fetchall(): + name, timestamp, data_json = row + data = json.loads(data_json) + conversations.append({ + "name": name, + "timestamp": timestamp, + "message_count": len(data), + }) + return conversations + + +# Global database instance for convenience +_db: Optional[Database] = None + + +def get_database() -> Database: + """ + Get the global database instance. + + Returns: + The shared Database instance + """ + global _db + if _db is None: + _db = Database() + return _db diff --git a/oai/config/settings.py b/oai/config/settings.py new file mode 100644 index 0000000..907a75a --- /dev/null +++ b/oai/config/settings.py @@ -0,0 +1,361 @@ +""" +Settings management for oAI. + +This module provides a centralized settings class that handles all application +configuration with type safety, validation, and persistence. +""" + +from dataclasses import dataclass, field +from typing import Optional +from pathlib import Path + +from oai.constants import ( + DEFAULT_BASE_URL, + DEFAULT_STREAM_ENABLED, + DEFAULT_MAX_TOKENS, + DEFAULT_ONLINE_MODE, + DEFAULT_COST_WARNING_THRESHOLD, + DEFAULT_LOG_MAX_SIZE_MB, + DEFAULT_LOG_BACKUP_COUNT, + DEFAULT_LOG_LEVEL, + DEFAULT_SYSTEM_PROMPT, + VALID_LOG_LEVELS, +) +from oai.config.database import get_database + + +@dataclass +class Settings: + """ + Application settings with persistence support. + + This class provides a clean interface for managing all configuration + options. Settings are automatically loaded from the database on + initialization and can be persisted back. + + Attributes: + api_key: OpenRouter API key + base_url: API base URL + default_model: Default model ID to use + default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank) + stream_enabled: Whether to stream responses + max_tokens: Maximum tokens per request + cost_warning_threshold: Alert threshold for message cost + default_online_mode: Whether online mode is enabled by default + log_max_size_mb: Maximum log file size in MB + log_backup_count: Number of log file backups to keep + log_level: Logging level (debug/info/warning/error/critical) + """ + + api_key: Optional[str] = None + base_url: str = DEFAULT_BASE_URL + default_model: Optional[str] = None + default_system_prompt: Optional[str] = None + stream_enabled: bool = DEFAULT_STREAM_ENABLED + max_tokens: int = DEFAULT_MAX_TOKENS + cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD + default_online_mode: bool = DEFAULT_ONLINE_MODE + log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB + log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT + log_level: str = DEFAULT_LOG_LEVEL + + @property + def effective_system_prompt(self) -> str: + """ + Get the effective system prompt to use. + + Returns: + The custom prompt if set, hardcoded default if None, or blank if explicitly set to "" + """ + if self.default_system_prompt is None: + return DEFAULT_SYSTEM_PROMPT + return self.default_system_prompt + + def __post_init__(self): + """Validate settings after initialization.""" + self._validate() + + def _validate(self) -> None: + """Validate all settings values.""" + # Validate log level + if self.log_level.lower() not in VALID_LOG_LEVELS: + raise ValueError( + f"Invalid log level: {self.log_level}. 
" + f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}" + ) + + # Validate numeric bounds + if self.max_tokens < 1: + raise ValueError("max_tokens must be at least 1") + + if self.cost_warning_threshold < 0: + raise ValueError("cost_warning_threshold must be non-negative") + + if self.log_max_size_mb < 1: + raise ValueError("log_max_size_mb must be at least 1") + + if self.log_backup_count < 0: + raise ValueError("log_backup_count must be non-negative") + + @classmethod + def load(cls) -> "Settings": + """ + Load settings from the database. + + Returns: + Settings instance with values from database + """ + db = get_database() + + # Helper to safely parse boolean + def parse_bool(value: Optional[str], default: bool) -> bool: + if value is None: + return default + return value.lower() in ("on", "true", "1", "yes") + + # Helper to safely parse int + def parse_int(value: Optional[str], default: int) -> int: + if value is None: + return default + try: + return int(value) + except ValueError: + return default + + # Helper to safely parse float + def parse_float(value: Optional[str], default: float) -> float: + if value is None: + return default + try: + return float(value) + except ValueError: + return default + + # Get system prompt from DB: None means not set (use default), "" means explicitly blank + system_prompt_value = db.get_config("default_system_prompt") + + return cls( + api_key=db.get_config("api_key"), + base_url=db.get_config("base_url") or DEFAULT_BASE_URL, + default_model=db.get_config("default_model"), + default_system_prompt=system_prompt_value, + stream_enabled=parse_bool( + db.get_config("stream_enabled"), + DEFAULT_STREAM_ENABLED + ), + max_tokens=parse_int( + db.get_config("max_token"), + DEFAULT_MAX_TOKENS + ), + cost_warning_threshold=parse_float( + db.get_config("cost_warning_threshold"), + DEFAULT_COST_WARNING_THRESHOLD + ), + default_online_mode=parse_bool( + db.get_config("default_online_mode"), + DEFAULT_ONLINE_MODE + ), + log_max_size_mb=parse_int( + db.get_config("log_max_size_mb"), + DEFAULT_LOG_MAX_SIZE_MB + ), + log_backup_count=parse_int( + db.get_config("log_backup_count"), + DEFAULT_LOG_BACKUP_COUNT + ), + log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL, + ) + + def save(self) -> None: + """Persist all settings to the database.""" + db = get_database() + + # Only save API key if it exists + if self.api_key: + db.set_config("api_key", self.api_key) + + db.set_config("base_url", self.base_url) + + if self.default_model: + db.set_config("default_model", self.default_model) + + # Save system prompt: None means not set (don't save), otherwise save the value (even if "") + if self.default_system_prompt is not None: + db.set_config("default_system_prompt", self.default_system_prompt) + + db.set_config("stream_enabled", "on" if self.stream_enabled else "off") + db.set_config("max_token", str(self.max_tokens)) + db.set_config("cost_warning_threshold", str(self.cost_warning_threshold)) + db.set_config("default_online_mode", "on" if self.default_online_mode else "off") + db.set_config("log_max_size_mb", str(self.log_max_size_mb)) + db.set_config("log_backup_count", str(self.log_backup_count)) + db.set_config("log_level", self.log_level) + + def set_api_key(self, api_key: str) -> None: + """ + Set and persist the API key. + + Args: + api_key: The new API key + """ + self.api_key = api_key.strip() + get_database().set_config("api_key", self.api_key) + + def set_base_url(self, url: str) -> None: + """ + Set and persist the base URL. 
+ + Args: + url: The new base URL + """ + self.base_url = url.strip() + get_database().set_config("base_url", self.base_url) + + def set_default_model(self, model_id: str) -> None: + """ + Set and persist the default model. + + Args: + model_id: The model ID to set as default + """ + self.default_model = model_id + get_database().set_config("default_model", model_id) + + def set_default_system_prompt(self, prompt: str) -> None: + """ + Set and persist the default system prompt. + + Args: + prompt: The system prompt to use for all new sessions. + Empty string "" means blank prompt (no system message). + """ + self.default_system_prompt = prompt + get_database().set_config("default_system_prompt", prompt) + + def clear_default_system_prompt(self) -> None: + """ + Clear the custom system prompt and revert to hardcoded default. + + This removes the custom prompt from the database, causing the + application to use the built-in DEFAULT_SYSTEM_PROMPT. + """ + self.default_system_prompt = None + # Remove from database to indicate "not set" + db = get_database() + with db._connection() as conn: + conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",)) + conn.commit() + + def set_stream_enabled(self, enabled: bool) -> None: + """ + Set and persist the streaming preference. + + Args: + enabled: Whether to enable streaming + """ + self.stream_enabled = enabled + get_database().set_config("stream_enabled", "on" if enabled else "off") + + def set_max_tokens(self, max_tokens: int) -> None: + """ + Set and persist the maximum tokens. + + Args: + max_tokens: Maximum number of tokens + + Raises: + ValueError: If max_tokens is less than 1 + """ + if max_tokens < 1: + raise ValueError("max_tokens must be at least 1") + self.max_tokens = max_tokens + get_database().set_config("max_token", str(max_tokens)) + + def set_cost_warning_threshold(self, threshold: float) -> None: + """ + Set and persist the cost warning threshold. + + Args: + threshold: Cost threshold in USD + + Raises: + ValueError: If threshold is negative + """ + if threshold < 0: + raise ValueError("cost_warning_threshold must be non-negative") + self.cost_warning_threshold = threshold + get_database().set_config("cost_warning_threshold", str(threshold)) + + def set_default_online_mode(self, enabled: bool) -> None: + """ + Set and persist the default online mode. + + Args: + enabled: Whether online mode should be enabled by default + """ + self.default_online_mode = enabled + get_database().set_config("default_online_mode", "on" if enabled else "off") + + def set_log_level(self, level: str) -> None: + """ + Set and persist the log level. + + Args: + level: The log level (debug/info/warning/error/critical) + + Raises: + ValueError: If level is not valid + """ + level_lower = level.lower() + if level_lower not in VALID_LOG_LEVELS: + raise ValueError( + f"Invalid log level: {level}. " + f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}" + ) + self.log_level = level_lower + get_database().set_config("log_level", level_lower) + + def set_log_max_size(self, size_mb: int) -> None: + """ + Set and persist the maximum log file size. 
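+
+        Example (illustrative; values above 100 MB are capped):
+            settings.set_log_max_size(500)   # log_max_size_mb becomes 100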
+ + Args: + size_mb: Maximum size in megabytes + + Raises: + ValueError: If size_mb is less than 1 + """ + if size_mb < 1: + raise ValueError("log_max_size_mb must be at least 1") + # Cap at 100 MB for safety + self.log_max_size_mb = min(size_mb, 100) + get_database().set_config("log_max_size_mb", str(self.log_max_size_mb)) + + +# Global settings instance +_settings: Optional[Settings] = None + + +def get_settings() -> Settings: + """ + Get the global settings instance. + + Returns: + The shared Settings instance, loading from database if needed + """ + global _settings + if _settings is None: + _settings = Settings.load() + return _settings + + +def reload_settings() -> Settings: + """ + Force reload settings from the database. + + Returns: + Fresh Settings instance + """ + global _settings + _settings = Settings.load() + return _settings diff --git a/oai/constants.py b/oai/constants.py new file mode 100644 index 0000000..a91e330 --- /dev/null +++ b/oai/constants.py @@ -0,0 +1,448 @@ +""" +Application-wide constants for oAI. + +This module contains all configuration constants, default values, and static +definitions used throughout the application. Centralizing these values makes +the codebase easier to maintain and configure. +""" + +from pathlib import Path +from typing import Set, Dict, Any +import logging + +# ============================================================================= +# APPLICATION METADATA +# ============================================================================= + +APP_NAME = "oAI" +APP_VERSION = "2.1.0" +APP_URL = "https://iurl.no/oai" +APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration" + +# ============================================================================= +# FILE PATHS +# ============================================================================= + +HOME_DIR = Path.home() +CONFIG_DIR = HOME_DIR / ".config" / "oai" +CACHE_DIR = HOME_DIR / ".cache" / "oai" +HISTORY_FILE = CONFIG_DIR / "history.txt" +DATABASE_FILE = CONFIG_DIR / "oai_config.db" +LOG_FILE = CONFIG_DIR / "oai.log" + +# ============================================================================= +# API CONFIGURATION +# ============================================================================= + +DEFAULT_BASE_URL = "https://openrouter.ai/api/v1" +DEFAULT_STREAM_ENABLED = True +DEFAULT_MAX_TOKENS = 100_000 +DEFAULT_ONLINE_MODE = False + +# ============================================================================= +# DEFAULT SYSTEM PROMPT +# ============================================================================= + +DEFAULT_SYSTEM_PROMPT = ( + "You are a knowledgeable and helpful AI assistant. Provide clear, accurate, " + "and well-structured responses. Be concise yet thorough. When uncertain about " + "something, acknowledge your limitations. For technical topics, include relevant " + "details and examples when helpful." 
+) + +# ============================================================================= +# PRICING DEFAULTS (per million tokens) +# ============================================================================= + +DEFAULT_INPUT_PRICE = 3.0 +DEFAULT_OUTPUT_PRICE = 15.0 + +MODEL_PRICING: Dict[str, float] = { + "input": DEFAULT_INPUT_PRICE, + "output": DEFAULT_OUTPUT_PRICE, +} + +# ============================================================================= +# CREDIT ALERTS +# ============================================================================= + +LOW_CREDIT_RATIO = 0.1 # Alert when credits < 10% of total +LOW_CREDIT_AMOUNT = 1.0 # Alert when credits < $1.00 +DEFAULT_COST_WARNING_THRESHOLD = 0.01 # Alert when single message cost exceeds this +COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD # Alias for convenience + +# ============================================================================= +# LOGGING CONFIGURATION +# ============================================================================= + +DEFAULT_LOG_MAX_SIZE_MB = 10 +DEFAULT_LOG_BACKUP_COUNT = 2 +DEFAULT_LOG_LEVEL = "info" + +VALID_LOG_LEVELS: Dict[str, int] = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +# ============================================================================= +# FILE HANDLING +# ============================================================================= + +# Maximum file size for reading (10 MB) +MAX_FILE_SIZE = 10 * 1024 * 1024 + +# Content truncation threshold (50 KB) +CONTENT_TRUNCATION_THRESHOLD = 50 * 1024 + +# Maximum items in directory listing +MAX_LIST_ITEMS = 1000 + +# Supported code file extensions for syntax highlighting +SUPPORTED_CODE_EXTENSIONS: Set[str] = { + ".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp", + ".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go", + ".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart", + ".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt", +} + +# All allowed file extensions for attachment +ALLOWED_FILE_EXTENSIONS: Set[str] = { + # Code files + ".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx", + ".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php", + ".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1", + # Data files + ".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3", + # Documents + ".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties", + # Images + ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico", + # Archives + ".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz", + # Config files + ".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc", + ".prettierrc", ".babelrc", ".nvmrc", ".npmrc", + # Binary/Compiled + ".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app", + ".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa", + # ML/AI + ".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx", + ".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn", + # Data formats + ".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet", + ".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds", + # Other + ".pdf", ".class", ".jar", ".war", +} + +# ============================================================================= +# SECURITY CONFIGURATION +# 
============================================================================= + +# System directories that should never be accessed +SYSTEM_DIRS_BLACKLIST: Set[str] = { + # macOS + "/System", "/Library", "/private", "/usr", "/bin", "/sbin", + # Linux + "/boot", "/dev", "/proc", "/sys", "/root", + # Windows + "C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)", +} + +# Directories to skip during file operations +SKIP_DIRECTORIES: Set[str] = { + # Python virtual environments + ".venv", "venv", "env", "virtualenv", + "site-packages", "dist-packages", + # Python caches + "__pycache__", ".pytest_cache", ".mypy_cache", + # JavaScript/Node + "node_modules", + # Version control + ".git", ".svn", + # IDEs + ".idea", ".vscode", + # Build directories + "build", "dist", "eggs", ".eggs", +} + +# ============================================================================= +# DATABASE QUERIES - SQL SAFETY +# ============================================================================= + +# Maximum query execution timeout (seconds) +MAX_QUERY_TIMEOUT = 5 + +# Maximum rows returned from queries +MAX_QUERY_RESULTS = 1000 + +# Default rows per query +DEFAULT_QUERY_LIMIT = 100 + +# Keywords that are blocked in database queries +DANGEROUS_SQL_KEYWORDS: Set[str] = { + "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", + "ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH", + "PRAGMA", "VACUUM", "REINDEX", +} + +# ============================================================================= +# MCP CONFIGURATION +# ============================================================================= + +# Maximum tool call iterations per request +MAX_TOOL_LOOPS = 5 + +# ============================================================================= +# VALID COMMANDS +# ============================================================================= + +VALID_COMMANDS: Set[str] = { + "/retry", "/online", "/memory", "/paste", "/export", "/save", "/load", + "/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset", + "/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear", + "/cl", "/help", "/mcp", +} + +# ============================================================================= +# COMMAND HELP DATABASE +# ============================================================================= + +COMMAND_HELP: Dict[str, Dict[str, Any]] = { + "/clear": { + "aliases": ["/cl"], + "description": "Clear the terminal screen for a clean interface.", + "usage": "/clear\n/cl", + "examples": [ + ("Clear screen", "/clear"), + ("Using short alias", "/cl"), + ], + "notes": "You can also use the keyboard shortcut Ctrl+L.", + }, + "/help": { + "description": "Display help information for commands.", + "usage": "/help [command|topic]", + "examples": [ + ("Show all commands", "/help"), + ("Get help for a specific command", "/help /model"), + ("Get detailed MCP help", "/help mcp"), + ], + "notes": "Use /help without arguments to see the full command list.", + }, + "mcp": { + "description": "Complete guide to MCP (Model Context Protocol).", + "usage": "See detailed examples below", + "examples": [], + "notes": """ +MCP (Model Context Protocol) gives your AI assistant direct access to: + • Local files and folders (read, search, list) + • SQLite databases (inspect, search, query) + +FILE MODE (default): + /mcp on Start MCP server + /mcp add ~/Documents Grant access to folder + /mcp list View all allowed folders + +DATABASE MODE: + /mcp add db ~/app/data.db Add specific database + /mcp db list View all databases + /mcp db 
1 Work with database #1 + /mcp files Switch back to file mode + +WRITE MODE (optional): + /mcp write on Enable file modifications + /mcp write off Disable write mode (back to read-only) + +For command-specific help: /help /mcp + """, + }, + "/mcp": { + "description": "Manage MCP for local file access and SQLite database querying.", + "usage": "/mcp [args]", + "examples": [ + ("Enable MCP server", "/mcp on"), + ("Disable MCP server", "/mcp off"), + ("Show MCP status", "/mcp status"), + ("", ""), + ("━━━ FILE MODE ━━━", ""), + ("Add folder for file access", "/mcp add ~/Documents"), + ("Remove folder", "/mcp remove ~/Desktop"), + ("List allowed folders", "/mcp list"), + ("Enable write mode", "/mcp write on"), + ("", ""), + ("━━━ DATABASE MODE ━━━", ""), + ("Add SQLite database", "/mcp add db ~/app/data.db"), + ("List all databases", "/mcp db list"), + ("Switch to database #1", "/mcp db 1"), + ("Switch back to file mode", "/mcp files"), + ], + "notes": "MCP allows AI to read local files and query SQLite databases.", + }, + "/memory": { + "description": "Toggle conversation memory.", + "usage": "/memory [on|off]", + "examples": [ + ("Check current memory status", "/memory"), + ("Enable conversation memory", "/memory on"), + ("Disable memory (save costs)", "/memory off"), + ], + "notes": "Memory is ON by default. Disabling saves tokens.", + }, + "/online": { + "description": "Enable or disable online mode (web search).", + "usage": "/online [on|off]", + "examples": [ + ("Check online mode status", "/online"), + ("Enable web search", "/online on"), + ("Disable web search", "/online off"), + ], + "notes": "Not all models support online mode.", + }, + "/paste": { + "description": "Paste plain text from clipboard and send to the AI.", + "usage": "/paste [prompt]", + "examples": [ + ("Paste clipboard content", "/paste"), + ("Paste with a question", "/paste Explain this code"), + ], + "notes": "Only plain text is supported.", + }, + "/retry": { + "description": "Resend the last prompt from conversation history.", + "usage": "/retry", + "examples": [("Retry last message", "/retry")], + "notes": "Requires at least one message in history.", + }, + "/next": { + "description": "View the next response in conversation history.", + "usage": "/next", + "examples": [("Navigate to next response", "/next")], + "notes": "Use /prev to go backward.", + }, + "/prev": { + "description": "View the previous response in conversation history.", + "usage": "/prev", + "examples": [("Navigate to previous response", "/prev")], + "notes": "Use /next to go forward.", + }, + "/reset": { + "description": "Clear conversation history and reset system prompt.", + "usage": "/reset", + "examples": [("Reset conversation", "/reset")], + "notes": "Requires confirmation.", + }, + "/info": { + "description": "Display detailed information about a model.", + "usage": "/info [model_id]", + "examples": [ + ("Show current model info", "/info"), + ("Show specific model info", "/info gpt-4o"), + ], + "notes": "Shows pricing, capabilities, and context length.", + }, + "/model": { + "description": "Select or change the AI model.", + "usage": "/model [search_term]", + "examples": [ + ("List all models", "/model"), + ("Search for GPT models", "/model gpt"), + ("Search for Claude models", "/model claude"), + ], + "notes": "Models are numbered for easy selection.", + }, + "/config": { + "description": "View or modify application configuration.", + "usage": "/config [setting] [value]", + "examples": [ + ("View all settings", "/config"), + ("Set API 
key", "/config api"),
+            ("Set default model", "/config model"),
+            ("Set system prompt", "/config system You are a helpful assistant"),
+            ("Enable streaming", "/config stream on"),
+        ],
+        "notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.",
+    },
+    "/maxtoken": {
+        "description": "Set a temporary session token limit.",
+        "usage": "/maxtoken [value]",
+        "examples": [
+            ("View current session limit", "/maxtoken"),
+            ("Set session limit to 2000", "/maxtoken 2000"),
+        ],
+        "notes": "Cannot exceed stored max token limit.",
+    },
+    "/system": {
+        "description": "Set or clear the session-level system prompt.",
+        "usage": "/system [prompt|clear|default <prompt>]",
+        "examples": [
+            ("View current system prompt", "/system"),
+            ("Set as Python expert", "/system You are a Python expert"),
+            ("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."),
+            ("Save as default", "/system default You are a helpful assistant"),
+            ("Revert to default", "/system clear"),
+            ("Blank prompt", '/system ""'),
+        ],
+        "notes": r"Use \n for newlines. /system clear reverts to hardcoded default.",
+    },
+    "/save": {
+        "description": "Save the current conversation history.",
+        "usage": "/save <name>",
+        "examples": [("Save conversation", "/save my_chat")],
+        "notes": "Saved conversations can be loaded later with /load.",
+    },
+    "/load": {
+        "description": "Load a saved conversation.",
+        "usage": "/load <name|number>",
+        "examples": [
+            ("Load by name", "/load my_chat"),
+            ("Load by number from /list", "/load 3"),
+        ],
+        "notes": "Use /list to see numbered conversations.",
+    },
+    "/delete": {
+        "description": "Delete a saved conversation.",
+        "usage": "/delete <name>",
+        "examples": [("Delete by name", "/delete my_chat")],
+        "notes": "Requires confirmation. Cannot be undone.",
+    },
+    "/list": {
+        "description": "List all saved conversations.",
+        "usage": "/list",
+        "examples": [("Show saved conversations", "/list")],
+        "notes": "Conversations are numbered for use with /load and /delete.",
+    },
+    "/export": {
+        "description": "Export the current conversation to a file.",
+        "usage": "/export <format> <filename>",
+        "examples": [
+            ("Export as Markdown", "/export md notes.md"),
+            ("Export as JSON", "/export json conversation.json"),
+            ("Export as HTML", "/export html report.html"),
+        ],
+        "notes": "Available formats: md, json, html.",
+    },
+    "/stats": {
+        "description": "Display session statistics.",
+        "usage": "/stats",
+        "examples": [("View session statistics", "/stats")],
+        "notes": "Shows tokens, costs, and credits.",
+    },
+    "/credits": {
+        "description": "Display your OpenRouter account credits.",
+        "usage": "/credits",
+        "examples": [("Check credits", "/credits")],
+        "notes": "Shows total, used, and remaining credits.",
+    },
+    "/middleout": {
+        "description": "Toggle middle-out transform for long prompts.",
+        "usage": "/middleout [on|off]",
+        "examples": [
+            ("Check status", "/middleout"),
+            ("Enable compression", "/middleout on"),
+        ],
+        "notes": "Compresses prompts exceeding context size.",
+    },
+}
diff --git a/oai/core/__init__.py b/oai/core/__init__.py
new file mode 100644
index 0000000..d2d1e9e
--- /dev/null
+++ b/oai/core/__init__.py
@@ -0,0 +1,14 @@
+"""
+Core functionality for oAI.
+
+This module provides the main session management and AI client
+classes that power the chat application.
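+
+Typical usage (illustrative):
+
+    from oai.core import AIClient, ChatSession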
+""" + +from oai.core.session import ChatSession +from oai.core.client import AIClient + +__all__ = [ + "ChatSession", + "AIClient", +] diff --git a/oai/core/client.py b/oai/core/client.py new file mode 100644 index 0000000..5bcdbec --- /dev/null +++ b/oai/core/client.py @@ -0,0 +1,422 @@ +""" +AI Client for oAI. + +This module provides a high-level client for interacting with AI models +through the provider abstraction layer. +""" + +import asyncio +import json +from typing import Any, Callable, Dict, Iterator, List, Optional, Union + +from oai.constants import APP_NAME, APP_URL, MODEL_PRICING +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ModelInfo, + StreamChunk, + ToolCall, + UsageStats, +) +from oai.providers.openrouter import OpenRouterProvider +from oai.utils.logging import get_logger + + +class AIClient: + """ + High-level AI client for chat interactions. + + Provides a simplified interface for sending chat requests, + handling streaming, and managing tool calls. + + Attributes: + provider: The underlying AI provider + default_model: Default model ID to use + http_headers: Custom HTTP headers for requests + """ + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + provider_class: type = OpenRouterProvider, + app_name: str = APP_NAME, + app_url: str = APP_URL, + ): + """ + Initialize the AI client. + + Args: + api_key: API key for authentication + base_url: Optional custom base URL + provider_class: Provider class to use (default: OpenRouterProvider) + app_name: Application name for headers + app_url: Application URL for headers + """ + self.provider: AIProvider = provider_class( + api_key=api_key, + base_url=base_url, + app_name=app_name, + app_url=app_url, + ) + self.default_model: Optional[str] = None + self.logger = get_logger() + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + Get available models. + + Args: + filter_text_only: Whether to exclude video-only models + + Returns: + List of ModelInfo objects + """ + return self.provider.list_models(filter_text_only=filter_text_only) + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: Model identifier + + Returns: + ModelInfo or None if not found + """ + return self.provider.get_model(model_id) + + def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]: + """ + Get raw model data for provider-specific fields. + + Args: + model_id: Model identifier + + Returns: + Raw model dictionary or None + """ + if hasattr(self.provider, "get_raw_model"): + return self.provider.get_raw_model(model_id) + return None + + def chat( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + system_prompt: Optional[str] = None, + online: bool = False, + transforms: Optional[List[str]] = None, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat request. 
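+
+        Example (illustrative; the API key and model ID are placeholders):
+            client = AIClient(api_key="sk-or-...")
+            response = client.chat(
+                messages=[{"role": "user", "content": "Hello"}],
+                model="openai/gpt-4o",
+            )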
+ + Args: + messages: List of message dictionaries + model: Model ID (uses default if not specified) + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: Tool definitions for function calling + tool_choice: Tool selection mode + system_prompt: System prompt to prepend + online: Whether to enable online mode + transforms: List of transforms (e.g., ["middle-out"]) + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + + Raises: + ValueError: If no model specified and no default set + """ + model_id = model or self.default_model + if not model_id: + raise ValueError("No model specified and no default set") + + # Apply online mode suffix + if online and hasattr(self.provider, "get_effective_model_id"): + model_id = self.provider.get_effective_model_id(model_id, True) + + # Convert dict messages to ChatMessage objects + chat_messages = [] + + # Add system prompt if provided + if system_prompt: + chat_messages.append(ChatMessage(role="system", content=system_prompt)) + + # Convert message dicts + for msg in messages: + # Convert tool_calls dicts to ToolCall objects if present + tool_calls_data = msg.get("tool_calls") + tool_calls_obj = None + if tool_calls_data: + from oai.providers.base import ToolCall, ToolFunction + tool_calls_obj = [] + for tc in tool_calls_data: + # Handle both ToolCall objects and dicts + if isinstance(tc, ToolCall): + tool_calls_obj.append(tc) + elif isinstance(tc, dict): + func_data = tc.get("function", {}) + tool_calls_obj.append( + ToolCall( + id=tc.get("id", ""), + type=tc.get("type", "function"), + function=ToolFunction( + name=func_data.get("name", ""), + arguments=func_data.get("arguments", "{}"), + ), + ) + ) + + chat_messages.append( + ChatMessage( + role=msg.get("role", "user"), + content=msg.get("content"), + tool_calls=tool_calls_obj, + tool_call_id=msg.get("tool_call_id"), + ) + ) + + self.logger.debug( + f"Sending chat request: model={model_id}, " + f"messages={len(chat_messages)}, stream={stream}" + ) + + return self.provider.chat( + model=model_id, + messages=chat_messages, + stream=stream, + max_tokens=max_tokens, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + transforms=transforms, + ) + + def chat_with_tools( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]], + model: Optional[str] = None, + max_loops: int = 5, + max_tokens: Optional[int] = None, + system_prompt: Optional[str] = None, + on_tool_call: Optional[Callable[[ToolCall], None]] = None, + on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None, + ) -> ChatResponse: + """ + Send a chat request with automatic tool call handling. + + Executes tool calls returned by the model and continues + the conversation until no more tool calls are requested. 
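+
+        Example (illustrative; the executor name and return shape are placeholders):
+            def run_tool(name: str, args: Dict[str, Any]) -> Dict[str, Any]:
+                return {"ok": True, "tool": name, "args": args}
+
+            client.chat_with_tools(messages, tools, run_tool)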
+ + Args: + messages: Initial messages + tools: Tool definitions + tool_executor: Function to execute tool calls + model: Model ID + max_loops: Maximum tool call iterations + max_tokens: Maximum response tokens + system_prompt: System prompt + on_tool_call: Callback when tool is called + on_tool_result: Callback when tool returns result + + Returns: + Final ChatResponse after all tool calls complete + """ + model_id = model or self.default_model + if not model_id: + raise ValueError("No model specified and no default set") + + # Build initial messages + chat_messages = [] + if system_prompt: + chat_messages.append({"role": "system", "content": system_prompt}) + chat_messages.extend(messages) + + loop_count = 0 + current_response: Optional[ChatResponse] = None + + while loop_count < max_loops: + # Send request + response = self.chat( + messages=chat_messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + tools=tools, + tool_choice="auto", + ) + + if not isinstance(response, ChatResponse): + raise ValueError("Expected non-streaming response") + + current_response = response + + # Check for tool calls + tool_calls = response.tool_calls + if not tool_calls: + break + + self.logger.info(f"Model requested {len(tool_calls)} tool call(s)") + + # Process each tool call + tool_results = [] + for tc in tool_calls: + if on_tool_call: + on_tool_call(tc) + + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError as e: + self.logger.error(f"Failed to parse tool arguments: {e}") + result = {"error": f"Invalid arguments: {e}"} + else: + result = tool_executor(tc.function.name, args) + + if on_tool_result: + on_tool_result(tc.function.name, result) + + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps(result), + }) + + # Add assistant message with tool calls + assistant_msg = { + "role": "assistant", + "content": response.content, + "tool_calls": [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ], + } + chat_messages.append(assistant_msg) + chat_messages.extend(tool_results) + + loop_count += 1 + + if loop_count >= max_loops: + self.logger.warning(f"Reached max tool call loops ({max_loops})") + + return current_response + + def stream_chat( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + max_tokens: Optional[int] = None, + system_prompt: Optional[str] = None, + online: bool = False, + on_chunk: Optional[Callable[[StreamChunk], None]] = None, + ) -> tuple[str, Optional[UsageStats]]: + """ + Stream a chat response and collect the full text. 
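+
+        Example (illustrative; assumes a default model has been set):
+            text, usage = client.stream_chat(
+                messages=[{"role": "user", "content": "Hi"}],
+                on_chunk=lambda c: print(c.delta_content, end=""),
+            )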
+ + Args: + messages: Chat messages + model: Model ID + max_tokens: Maximum tokens + system_prompt: System prompt + online: Online mode + on_chunk: Optional callback for each chunk + + Returns: + Tuple of (full_response_text, usage_stats) + """ + response = self.chat( + messages=messages, + model=model, + stream=True, + max_tokens=max_tokens, + system_prompt=system_prompt, + online=online, + ) + + if isinstance(response, ChatResponse): + # Not actually streaming + return response.content or "", response.usage + + full_text = "" + usage: Optional[UsageStats] = None + + for chunk in response: + if chunk.error: + self.logger.error(f"Stream error: {chunk.error}") + break + + if chunk.delta_content: + full_text += chunk.delta_content + if on_chunk: + on_chunk(chunk) + + if chunk.usage: + usage = chunk.usage + + return full_text, usage + + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get account credit information. + + Returns: + Credit info dict or None if unavailable + """ + return self.provider.get_credits() + + def estimate_cost( + self, + model_id: str, + input_tokens: int, + output_tokens: int, + ) -> float: + """ + Estimate cost for a completion. + + Args: + model_id: Model ID + input_tokens: Number of input tokens + output_tokens: Number of output tokens + + Returns: + Estimated cost in USD + """ + if hasattr(self.provider, "estimate_cost"): + return self.provider.estimate_cost(model_id, input_tokens, output_tokens) + + # Fallback to default pricing + input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000 + output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000 + return input_cost + output_cost + + def set_default_model(self, model_id: str) -> None: + """ + Set the default model. + + Args: + model_id: Model ID to use as default + """ + self.default_model = model_id + self.logger.info(f"Default model set to: {model_id}") + + def clear_cache(self) -> None: + """Clear the provider's model cache.""" + if hasattr(self.provider, "clear_cache"): + self.provider.clear_cache() diff --git a/oai/core/session.py b/oai/core/session.py new file mode 100644 index 0000000..c4e5e31 --- /dev/null +++ b/oai/core/session.py @@ -0,0 +1,659 @@ +""" +Chat session management for oAI. + +This module provides the ChatSession class that manages an interactive +chat session including history, state, and message handling. +""" + +import asyncio +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple + +from rich.live import Live +from rich.markdown import Markdown + +from oai.commands.registry import CommandContext, CommandResult, registry +from oai.config.database import Database +from oai.config.settings import Settings +from oai.constants import ( + COST_WARNING_THRESHOLD, + LOW_CREDIT_AMOUNT, + LOW_CREDIT_RATIO, +) +from oai.core.client import AIClient +from oai.mcp.manager import MCPManager +from oai.providers.base import ChatResponse, StreamChunk, UsageStats +from oai.ui.console import ( + console, + display_markdown, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.prompts import prompt_copy_response +from oai.utils.logging import get_logger + + +@dataclass +class SessionStats: + """ + Statistics for the current session. + + Tracks tokens, costs, and message counts. 
+ """ + + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_cost: float = 0.0 + message_count: int = 0 + + @property + def total_tokens(self) -> int: + """Get total token count.""" + return self.total_input_tokens + self.total_output_tokens + + def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None: + """ + Add usage stats from a response. + + Args: + usage: Usage statistics + cost: Cost if not in usage + """ + if usage: + self.total_input_tokens += usage.prompt_tokens + self.total_output_tokens += usage.completion_tokens + if usage.total_cost_usd: + self.total_cost += usage.total_cost_usd + else: + self.total_cost += cost + else: + self.total_cost += cost + self.message_count += 1 + + +@dataclass +class HistoryEntry: + """ + A single entry in the conversation history. + + Stores the user prompt, assistant response, and metrics. + """ + + prompt: str + response: str + prompt_tokens: int = 0 + completion_tokens: int = 0 + msg_cost: float = 0.0 + timestamp: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "prompt": self.prompt, + "response": self.response, + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "msg_cost": self.msg_cost, + } + + +class ChatSession: + """ + Manages an interactive chat session. + + Handles conversation history, state management, command processing, + and communication with the AI client. + + Attributes: + client: AI client for API requests + settings: Application settings + mcp_manager: MCP manager for file/database access + history: Conversation history + stats: Session statistics + """ + + def __init__( + self, + client: AIClient, + settings: Settings, + mcp_manager: Optional[MCPManager] = None, + ): + """ + Initialize a chat session. + + Args: + client: AI client instance + settings: Application settings + mcp_manager: Optional MCP manager + """ + self.client = client + self.settings = settings + self.mcp_manager = mcp_manager + self.db = Database() + + self.history: List[HistoryEntry] = [] + self.stats = SessionStats() + + # Session state + self.system_prompt: str = settings.effective_system_prompt + self.memory_enabled: bool = True + self.memory_start_index: int = 0 + self.online_enabled: bool = settings.default_online_mode + self.middle_out_enabled: bool = False + self.session_max_token: int = 0 + self.current_index: int = 0 + + # Selected model + self.selected_model: Optional[Dict[str, Any]] = None + + self.logger = get_logger() + + def get_context(self) -> CommandContext: + """ + Get the current command context. + + Returns: + CommandContext with current session state + """ + return CommandContext( + settings=self.settings, + provider=self.client.provider, + mcp_manager=self.mcp_manager, + selected_model_raw=self.selected_model, + session_history=[e.to_dict() for e in self.history], + session_system_prompt=self.system_prompt, + memory_enabled=self.memory_enabled, + memory_start_index=self.memory_start_index, + online_enabled=self.online_enabled, + middle_out_enabled=self.middle_out_enabled, + session_max_token=self.session_max_token, + total_input_tokens=self.stats.total_input_tokens, + total_output_tokens=self.stats.total_output_tokens, + total_cost=self.stats.total_cost, + message_count=self.stats.message_count, + current_index=self.current_index, + ) + + def set_model(self, model: Dict[str, Any]) -> None: + """ + Set the selected model. 
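+
+        Example (illustrative; a raw model dict needs at least an "id" key here):
+            session.set_model({"id": "anthropic/claude-3.5-sonnet"})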
+ + Args: + model: Raw model dictionary + """ + self.selected_model = model + self.client.set_default_model(model["id"]) + self.logger.info(f"Model selected: {model['id']}") + + def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]: + """ + Build the messages array for an API request. + + Includes system prompt, history (if memory enabled), and current input. + + Args: + user_input: Current user input + + Returns: + List of message dictionaries + """ + messages = [] + + # Add system prompt + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + + # Add database context if in database mode + if self.mcp_manager and self.mcp_manager.enabled: + if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None: + db = self.mcp_manager.databases[self.mcp_manager.selected_db_index] + db_context = ( + f"You are connected to SQLite database: {db['name']}\n" + f"Available tables: {', '.join(db['tables'])}\n\n" + "Use inspect_database, search_database, or query_database tools. " + "All queries are read-only." + ) + messages.append({"role": "system", "content": db_context}) + + # Add history if memory enabled + if self.memory_enabled: + for i in range(self.memory_start_index, len(self.history)): + entry = self.history[i] + messages.append({"role": "user", "content": entry.prompt}) + messages.append({"role": "assistant", "content": entry.response}) + + # Add current message + messages.append({"role": "user", "content": user_input}) + + return messages + + def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]: + """ + Get MCP tool definitions if available. + + Returns: + List of tool schemas or None + """ + if not self.mcp_manager or not self.mcp_manager.enabled: + return None + + if not self.selected_model: + return None + + # Check if model supports tools + supported_params = self.selected_model.get("supported_parameters", []) + if "tools" not in supported_params and "functions" not in supported_params: + return None + + return self.mcp_manager.get_tools_schema() + + async def execute_tool( + self, + tool_name: str, + tool_args: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Execute an MCP tool. + + Args: + tool_name: Name of the tool + tool_args: Tool arguments + + Returns: + Tool execution result + """ + if not self.mcp_manager: + return {"error": "MCP not available"} + + return await self.mcp_manager.call_tool(tool_name, **tool_args) + + def send_message( + self, + user_input: str, + stream: bool = True, + on_stream_chunk: Optional[Callable[[str], None]] = None, + ) -> Tuple[str, Optional[UsageStats], float]: + """ + Send a message and get a response. 
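+
+        Example (illustrative; assumes a model has been selected):
+            text, usage, seconds = session.send_message("Summarize the repo")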
+ + Args: + user_input: User's input text + stream: Whether to stream the response + on_stream_chunk: Callback for stream chunks + + Returns: + Tuple of (response_text, usage_stats, response_time) + """ + if not self.selected_model: + raise ValueError("No model selected") + + start_time = time.time() + messages = self.build_api_messages(user_input) + + # Get MCP tools + tools = self.get_mcp_tools() + if tools: + # Disable streaming when tools are present + stream = False + + # Build request parameters + model_id = self.selected_model["id"] + if self.online_enabled: + if hasattr(self.client.provider, "get_effective_model_id"): + model_id = self.client.provider.get_effective_model_id(model_id, True) + + transforms = ["middle-out"] if self.middle_out_enabled else None + + max_tokens = None + if self.session_max_token > 0: + max_tokens = self.session_max_token + + if tools: + # Use tool handling flow + response = self._send_with_tools( + messages=messages, + model_id=model_id, + tools=tools, + max_tokens=max_tokens, + transforms=transforms, + ) + response_time = time.time() - start_time + return response.content or "", response.usage, response_time + + elif stream: + # Use streaming flow + full_text, usage = self._stream_response( + messages=messages, + model_id=model_id, + max_tokens=max_tokens, + transforms=transforms, + on_chunk=on_stream_chunk, + ) + response_time = time.time() - start_time + return full_text, usage, response_time + + else: + # Non-streaming request + response = self.client.chat( + messages=messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + transforms=transforms, + ) + response_time = time.time() - start_time + + if isinstance(response, ChatResponse): + return response.content or "", response.usage, response_time + else: + return "", None, response_time + + def _send_with_tools( + self, + messages: List[Dict[str, Any]], + model_id: str, + tools: List[Dict[str, Any]], + max_tokens: Optional[int] = None, + transforms: Optional[List[str]] = None, + ) -> ChatResponse: + """ + Send a request with tool call handling. 
+ + Args: + messages: API messages + model_id: Model ID + tools: Tool definitions + max_tokens: Max tokens + transforms: Transforms list + + Returns: + Final ChatResponse + """ + max_loops = 5 + loop_count = 0 + api_messages = list(messages) + + while loop_count < max_loops: + response = self.client.chat( + messages=api_messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + tools=tools, + tool_choice="auto", + transforms=transforms, + ) + + if not isinstance(response, ChatResponse): + raise ValueError("Expected ChatResponse") + + tool_calls = response.tool_calls + if not tool_calls: + return response + + console.print(f"\n[dim yellow]🔧 AI requesting {len(tool_calls)} tool call(s)...[/]") + + tool_results = [] + for tc in tool_calls: + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError as e: + self.logger.error(f"Failed to parse tool arguments: {e}") + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps({"error": f"Invalid arguments: {e}"}), + }) + continue + + # Display tool call + args_display = ", ".join( + f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}" + for k, v in args.items() + ) + console.print(f"[dim cyan] → {tc.function.name}({args_display})[/]") + + # Execute tool + result = asyncio.run(self.execute_tool(tc.function.name, args)) + + if "error" in result: + console.print(f"[dim red] ✗ Error: {result['error']}[/]") + else: + self._display_tool_success(tc.function.name, result) + + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps(result), + }) + + # Add assistant message with tool calls + api_messages.append({ + "role": "assistant", + "content": response.content, + "tool_calls": [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ], + }) + api_messages.extend(tool_results) + + console.print("\n[dim cyan]💭 Processing tool results...[/]") + loop_count += 1 + + self.logger.warning(f"Reached max tool loops ({max_loops})") + console.print(f"[bold yellow]⚠️ Reached maximum tool calls ({max_loops})[/]") + return response + + def _display_tool_success(self, tool_name: str, result: Dict[str, Any]) -> None: + """Display a success message for a tool call.""" + if tool_name == "search_files": + count = result.get("count", 0) + console.print(f"[dim green] ✓ Found {count} file(s)[/]") + elif tool_name == "read_file": + size = result.get("size", 0) + truncated = " (truncated)" if result.get("truncated") else "" + console.print(f"[dim green] ✓ Read {size} bytes{truncated}[/]") + elif tool_name == "list_directory": + count = result.get("count", 0) + console.print(f"[dim green] ✓ Listed {count} item(s)[/]") + elif tool_name == "inspect_database": + if "table" in result: + console.print(f"[dim green] ✓ Inspected table: {result['table']}[/]") + else: + console.print(f"[dim green] ✓ Inspected database ({result.get('table_count', 0)} tables)[/]") + elif tool_name == "search_database": + count = result.get("count", 0) + console.print(f"[dim green] ✓ Found {count} match(es)[/]") + elif tool_name == "query_database": + count = result.get("count", 0) + console.print(f"[dim green] ✓ Query returned {count} row(s)[/]") + else: + console.print("[dim green] ✓ Success[/]") + + def _stream_response( + self, + messages: List[Dict[str, Any]], + model_id: str, + max_tokens: Optional[int] = None, + transforms: Optional[List[str]] = None, + 
on_chunk: Optional[Callable[[str], None]] = None, + ) -> Tuple[str, Optional[UsageStats]]: + """ + Stream a response with live display. + + Args: + messages: API messages + model_id: Model ID + max_tokens: Max tokens + transforms: Transforms + on_chunk: Callback for chunks + + Returns: + Tuple of (full_text, usage) + """ + response = self.client.chat( + messages=messages, + model=model_id, + stream=True, + max_tokens=max_tokens, + transforms=transforms, + ) + + if isinstance(response, ChatResponse): + return response.content or "", response.usage + + full_text = "" + usage: Optional[UsageStats] = None + + try: + with Live("", console=console, refresh_per_second=10) as live: + for chunk in response: + if chunk.error: + console.print(f"\n[bold red]Stream error: {chunk.error}[/]") + break + + if chunk.delta_content: + full_text += chunk.delta_content + live.update(Markdown(full_text)) + if on_chunk: + on_chunk(chunk.delta_content) + + if chunk.usage: + usage = chunk.usage + + except KeyboardInterrupt: + console.print("\n[bold yellow]⚠️ Streaming interrupted[/]") + return "", None + + return full_text, usage + + def add_to_history( + self, + prompt: str, + response: str, + usage: Optional[UsageStats] = None, + cost: float = 0.0, + ) -> None: + """ + Add an exchange to the history. + + Args: + prompt: User prompt + response: Assistant response + usage: Usage statistics + cost: Cost if not in usage + """ + entry = HistoryEntry( + prompt=prompt, + response=response, + prompt_tokens=usage.prompt_tokens if usage else 0, + completion_tokens=usage.completion_tokens if usage else 0, + msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost, + timestamp=time.time(), + ) + self.history.append(entry) + self.current_index = len(self.history) - 1 + self.stats.add_usage(usage, cost) + + def save_conversation(self, name: str) -> bool: + """ + Save the current conversation. + + Args: + name: Name for the saved conversation + + Returns: + True if saved successfully + """ + if not self.history: + return False + + data = [e.to_dict() for e in self.history] + self.db.save_conversation(name, data) + self.logger.info(f"Saved conversation: {name}") + return True + + def load_conversation(self, name: str) -> bool: + """ + Load a saved conversation. + + Args: + name: Name of the conversation to load + + Returns: + True if loaded successfully + """ + data = self.db.load_conversation(name) + if not data: + return False + + self.history.clear() + for entry_dict in data: + self.history.append(HistoryEntry( + prompt=entry_dict.get("prompt", ""), + response=entry_dict.get("response", ""), + prompt_tokens=entry_dict.get("prompt_tokens", 0), + completion_tokens=entry_dict.get("completion_tokens", 0), + msg_cost=entry_dict.get("msg_cost", 0.0), + )) + + self.current_index = len(self.history) - 1 + self.memory_start_index = 0 + self.stats = SessionStats() # Reset stats for loaded conversation + self.logger.info(f"Loaded conversation: {name}") + return True + + def reset(self) -> None: + """Reset the session state.""" + self.history.clear() + self.stats = SessionStats() + self.system_prompt = "" + self.memory_start_index = 0 + self.current_index = 0 + self.logger.info("Session reset") + + def check_warnings(self) -> List[str]: + """ + Check for cost and credit warnings. 
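+
+        Illustrative use after an exchange (the session and console names are
+        assumptions about the caller):
+
+            for warning in session.check_warnings():
+                console.print(f"[bold yellow]{warning}[/]")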
+ + Returns: + List of warning messages + """ + warnings = [] + + # Check last message cost + if self.history: + last_cost = self.history[-1].msg_cost + threshold = self.settings.cost_warning_threshold + if last_cost > threshold: + warnings.append( + f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}" + ) + + # Check credits + credits = self.client.get_credits() + if credits: + left = credits.get("credits_left", 0) + total = credits.get("total_credits", 0) + + if left < LOW_CREDIT_AMOUNT: + warnings.append(f"Low credits: ${left:.2f} remaining!") + elif total > 0 and left < total * LOW_CREDIT_RATIO: + warnings.append(f"Credits low: less than 10% remaining (${left:.2f})") + + return warnings diff --git a/oai/mcp/__init__.py b/oai/mcp/__init__.py new file mode 100644 index 0000000..b7abc9a --- /dev/null +++ b/oai/mcp/__init__.py @@ -0,0 +1,28 @@ +""" +Model Context Protocol (MCP) integration for oAI. + +This package provides filesystem and database access capabilities +through the MCP standard, allowing AI models to interact with +local files and SQLite databases safely. + +Key components: +- MCPManager: High-level manager for MCP operations +- MCPFilesystemServer: Filesystem and database access implementation +- GitignoreParser: Pattern matching for .gitignore support +- SQLiteQueryValidator: Query safety validation +- CrossPlatformMCPConfig: OS-specific configuration +""" + +from oai.mcp.manager import MCPManager +from oai.mcp.server import MCPFilesystemServer +from oai.mcp.gitignore import GitignoreParser +from oai.mcp.validators import SQLiteQueryValidator +from oai.mcp.platform import CrossPlatformMCPConfig + +__all__ = [ + "MCPManager", + "MCPFilesystemServer", + "GitignoreParser", + "SQLiteQueryValidator", + "CrossPlatformMCPConfig", +] diff --git a/oai/mcp/gitignore.py b/oai/mcp/gitignore.py new file mode 100644 index 0000000..83ec4f4 --- /dev/null +++ b/oai/mcp/gitignore.py @@ -0,0 +1,166 @@ +""" +Gitignore pattern parsing for oAI MCP. + +This module implements .gitignore pattern matching to filter files +during MCP filesystem operations. +""" + +import fnmatch +from pathlib import Path +from typing import List, Tuple + +from oai.utils.logging import get_logger + + +class GitignoreParser: + """ + Parse and apply .gitignore patterns. + + Supports standard gitignore syntax including: + - Wildcards (*) and double wildcards (**) + - Directory-only patterns (ending with /) + - Negation patterns (starting with !) + - Comments (lines starting with #) + + Patterns are applied relative to the directory containing + the .gitignore file. + """ + + def __init__(self): + """Initialize an empty pattern collection.""" + # List of (pattern, is_negation, source_dir) + self.patterns: List[Tuple[str, bool, Path]] = [] + + def add_gitignore(self, gitignore_path: Path) -> None: + """ + Parse and add patterns from a .gitignore file. 
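+
+        Illustrative use (the project path is an assumption):
+
+            parser = GitignoreParser()
+            parser.add_gitignore(Path("/home/user/project/.gitignore"))
+            parser.should_ignore(Path("/home/user/project/build/out.o"))  # True if a pattern matches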
+ + Args: + gitignore_path: Path to the .gitignore file + """ + logger = get_logger() + + if not gitignore_path.exists(): + return + + try: + source_dir = gitignore_path.parent + + with open(gitignore_path, "r", encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.rstrip("\n\r") + + # Skip empty lines and comments + if not line or line.startswith("#"): + continue + + # Check for negation pattern + is_negation = line.startswith("!") + if is_negation: + line = line[1:] + + # Remove leading slash (make relative to gitignore location) + if line.startswith("/"): + line = line[1:] + + self.patterns.append((line, is_negation, source_dir)) + + logger.debug( + f"Loaded {len(self.patterns)} patterns from {gitignore_path}" + ) + + except Exception as e: + logger.warning(f"Error reading {gitignore_path}: {e}") + + def should_ignore(self, path: Path) -> bool: + """ + Check if a path should be ignored based on gitignore patterns. + + Patterns are evaluated in order, with later patterns overriding + earlier ones. Negation patterns (starting with !) un-ignore + previously matched paths. + + Args: + path: Path to check + + Returns: + True if the path should be ignored + """ + if not self.patterns: + return False + + ignored = False + + for pattern, is_negation, source_dir in self.patterns: + # Only apply pattern if path is under the source directory + try: + rel_path = path.relative_to(source_dir) + except ValueError: + # Path is not relative to this gitignore's directory + continue + + rel_path_str = str(rel_path) + + # Check if pattern matches + if self._match_pattern(pattern, rel_path_str, path.is_dir()): + if is_negation: + ignored = False # Negation patterns un-ignore + else: + ignored = True + + return ignored + + def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool: + """ + Match a gitignore pattern against a path. + + Args: + pattern: The gitignore pattern + path: The relative path string to match + is_dir: Whether the path is a directory + + Returns: + True if the pattern matches + """ + # Directory-only pattern (ends with /) + if pattern.endswith("/"): + if not is_dir: + return False + pattern = pattern[:-1] + + # Handle ** patterns (matches any number of directories) + if "**" in pattern: + pattern_parts = pattern.split("**") + if len(pattern_parts) == 2: + prefix, suffix = pattern_parts + + # Match if path starts with prefix and ends with suffix + if prefix: + if not path.startswith(prefix.rstrip("/")): + return False + if suffix: + suffix = suffix.lstrip("/") + if not (path.endswith(suffix) or f"/{suffix}" in path): + return False + return True + + # Direct match using fnmatch + if fnmatch.fnmatch(path, pattern): + return True + + # Match as subdirectory pattern (pattern without / matches in any directory) + if "/" not in pattern: + parts = path.split("/") + if any(fnmatch.fnmatch(part, pattern) for part in parts): + return True + + return False + + def clear(self) -> None: + """Clear all loaded patterns.""" + self.patterns = [] + + @property + def pattern_count(self) -> int: + """Get the number of loaded patterns.""" + return len(self.patterns) diff --git a/oai/mcp/manager.py b/oai/mcp/manager.py new file mode 100644 index 0000000..c44001c --- /dev/null +++ b/oai/mcp/manager.py @@ -0,0 +1,1365 @@ +""" +MCP Manager for oAI. + +This module provides the high-level interface for managing MCP operations, +including server lifecycle, mode switching, folder/database management, +and tool schema generation. 
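+
+A minimal usage sketch (the folder path is an assumption; all names shown are
+defined in this module):
+
+    manager = get_mcp_manager()
+    manager.enable()
+    manager.add_folder("~/Projects/demo")
+    tools = manager.get_tools_schema()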
+""" + +import asyncio +import datetime +import json +from pathlib import Path +from typing import Optional, List, Dict, Any + +from oai.constants import MAX_TOOL_LOOPS +from oai.config.database import get_database +from oai.mcp.platform import CrossPlatformMCPConfig +from oai.mcp.server import MCPFilesystemServer +from oai.utils.logging import get_logger + + +class MCPManager: + """ + Manage MCP server lifecycle, tool calls, and mode switching. + + This class provides the main interface for MCP functionality including: + - Enabling/disabling the MCP server + - Managing allowed folders and databases + - Switching between file and database modes + - Generating tool schemas for API requests + - Executing tool calls + + Attributes: + enabled: Whether MCP is currently enabled + write_enabled: Whether write mode is enabled (non-persistent) + mode: Current mode ('files' or 'database') + selected_db_index: Index of selected database (if in database mode) + server: The MCPFilesystemServer instance + allowed_folders: List of allowed folder paths + databases: List of registered databases + config: Platform configuration + session_start_time: When the current session started + """ + + def __init__(self): + """Initialize the MCP manager.""" + self.enabled = False + self.write_enabled = False # Off by default, resets each session + self.mode = "files" + self.selected_db_index: Optional[int] = None + + self.server: Optional[MCPFilesystemServer] = None + self.allowed_folders: List[Path] = [] + self.databases: List[Dict[str, Any]] = [] + + self.config = CrossPlatformMCPConfig() + self.session_start_time: Optional[datetime.datetime] = None + + # Load persisted data + self._load_folders() + self._load_databases() + + get_logger().info("MCP Manager initialized") + + # ========================================================================= + # PERSISTENCE + # ========================================================================= + + def _load_folders(self) -> None: + """Load allowed folders from database.""" + logger = get_logger() + db = get_database() + folders_json = db.get_mcp_config("allowed_folders") + + if folders_json: + try: + folder_paths = json.loads(folders_json) + self.allowed_folders = [ + self.config.normalize_path(p) for p in folder_paths + ] + logger.info(f"Loaded {len(self.allowed_folders)} folders from config") + except Exception as e: + logger.error(f"Error loading MCP folders: {e}") + self.allowed_folders = [] + + def _save_folders(self) -> None: + """Save allowed folders to database.""" + db = get_database() + folder_paths = [str(p) for p in self.allowed_folders] + db.set_mcp_config("allowed_folders", json.dumps(folder_paths)) + get_logger().info(f"Saved {len(self.allowed_folders)} folders to config") + + def _load_databases(self) -> None: + """Load databases from database.""" + self.databases = get_database().get_mcp_databases() + get_logger().info(f"Loaded {len(self.databases)} databases from config") + + # ========================================================================= + # ENABLE/DISABLE + # ========================================================================= + + def enable(self) -> Dict[str, Any]: + """ + Enable MCP server. 
+ + Returns: + Dictionary containing: + - success: Whether operation succeeded + - folder_count: Number of allowed folders + - database_count: Number of registered databases + - message: Status message + - error: Error message if failed + """ + logger = get_logger() + + if self.enabled: + return { + "success": False, + "error": "MCP is already enabled" + } + + try: + self.server = MCPFilesystemServer(self.allowed_folders) + self.enabled = True + self.session_start_time = datetime.datetime.now() + + get_database().set_mcp_config("mcp_enabled", "true") + + logger.info("MCP Filesystem Server enabled") + + return { + "success": True, + "folder_count": len(self.allowed_folders), + "database_count": len(self.databases), + "message": "MCP Filesystem Server started successfully" + } + except Exception as e: + logger.error(f"Error enabling MCP: {e}") + return { + "success": False, + "error": str(e) + } + + def disable(self) -> Dict[str, Any]: + """ + Disable MCP server. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - message: Status message + - error: Error message if failed + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + try: + self.server = None + self.enabled = False + self.session_start_time = None + self.mode = "files" + self.selected_db_index = None + + get_database().set_mcp_config("mcp_enabled", "false") + + logger.info("MCP Filesystem Server disabled") + + return { + "success": True, + "message": "MCP Filesystem Server stopped" + } + except Exception as e: + logger.error(f"Error disabling MCP: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # WRITE MODE + # ========================================================================= + + def enable_write(self) -> bool: + """ + Enable write mode. + + This method should be called after user confirmation in the UI. + + Returns: + True if write mode was enabled, False otherwise + """ + logger = get_logger() + + if not self.enabled: + logger.warning("Attempted to enable write mode without MCP enabled") + return False + + self.write_enabled = True + logger.info("MCP write mode enabled by user") + get_database().log_mcp_stat("write_mode_enabled", "", True) + return True + + def disable_write(self) -> None: + """Disable write mode.""" + if self.write_enabled: + self.write_enabled = False + get_logger().info("MCP write mode disabled") + get_database().log_mcp_stat("write_mode_disabled", "", True) + + # ========================================================================= + # MODE SWITCHING + # ========================================================================= + + def switch_mode( + self, + new_mode: str, + db_index: Optional[int] = None + ) -> Dict[str, Any]: + """ + Switch between files and database mode. + + Args: + new_mode: 'files' or 'database' + db_index: Database number (1-based) if switching to database mode + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - mode: Current mode + - message: Status message + - tools: Available tools in this mode + - database: Database info (if database mode) + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. 
Use /mcp on first" + } + + if new_mode == "files": + self.mode = "files" + self.selected_db_index = None + logger.info("Switched to file mode") + return { + "success": True, + "mode": "files", + "message": "Switched to file mode", + "tools": ["read_file", "list_directory", "search_files"] + } + + elif new_mode == "database": + if db_index is None: + return { + "success": False, + "error": "Database index required. Use /mcp db " + } + + if db_index < 1 or db_index > len(self.databases): + return { + "success": False, + "error": f"Invalid database number. Use 1-{len(self.databases)}" + } + + self.mode = "database" + self.selected_db_index = db_index - 1 # Convert to 0-based + db = self.databases[self.selected_db_index] + + logger.info(f"Switched to database mode: {db['name']}") + return { + "success": True, + "mode": "database", + "database": db, + "message": f"Switched to database #{db_index}: {db['name']}", + "tools": ["inspect_database", "search_database", "query_database"] + } + + else: + return { + "success": False, + "error": f"Invalid mode: {new_mode}" + } + + # ========================================================================= + # FOLDER MANAGEMENT + # ========================================================================= + + def add_folder(self, folder_path: str) -> Dict[str, Any]: + """ + Add a folder to the allowed list. + + Args: + folder_path: Path to the folder + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - path: Normalized path + - stats: Folder statistics + - total_folders: Total number of allowed folders + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + try: + path = self.config.normalize_path(folder_path) + + # Validation + if not path.exists(): + return { + "success": False, + "error": f"Directory does not exist: {path}" + } + + if not path.is_dir(): + return { + "success": False, + "error": f"Not a directory: {path}" + } + + if self.config.is_system_directory(path): + return { + "success": False, + "error": f"Cannot add system directory: {path}" + } + + if path in self.allowed_folders: + return { + "success": False, + "error": f"Folder already in allowed list: {path}" + } + + # Check for nested paths + parent_folder = None + for existing in self.allowed_folders: + try: + path.relative_to(existing) + parent_folder = existing + break + except ValueError: + continue + + # Get folder stats + stats = self.config.get_folder_stats(path) + + # Add folder + self.allowed_folders.append(path) + self._save_folders() + + # Update server and reload gitignores + if self.server: + self.server.allowed_folders = self.allowed_folders + self.server.reload_gitignores() + + logger.info(f"Added folder to MCP: {path}") + + result = { + "success": True, + "path": str(path), + "stats": stats, + "total_folders": len(self.allowed_folders) + } + + if parent_folder: + result["warning"] = f"Note: {parent_folder} is already allowed (parent folder)" + + return result + + except Exception as e: + logger.error(f"Error adding folder: {e}") + return { + "success": False, + "error": str(e) + } + + def remove_folder(self, folder_ref: str) -> Dict[str, Any]: + """ + Remove a folder from the allowed list. 
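+
+        Both reference forms are accepted (the path is an assumption):
+
+            manager.remove_folder("2")                # by list number
+            manager.remove_folder("~/Projects/demo")  # by path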
+ + Args: + folder_ref: Folder path or number (1-based) + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - path: Removed folder path + - total_folders: Remaining folder count + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + try: + # Check if it's a number + if folder_ref.isdigit(): + index = int(folder_ref) - 1 + if 0 <= index < len(self.allowed_folders): + path = self.allowed_folders[index] + else: + return { + "success": False, + "error": f"Invalid folder number: {folder_ref}" + } + else: + # Treat as path + path = self.config.normalize_path(folder_ref) + if path not in self.allowed_folders: + return { + "success": False, + "error": f"Folder not in allowed list: {path}" + } + + # Remove folder + self.allowed_folders.remove(path) + self._save_folders() + + # Update server and reload gitignores + if self.server: + self.server.allowed_folders = self.allowed_folders + self.server.reload_gitignores() + + logger.info(f"Removed folder from MCP: {path}") + + result = { + "success": True, + "path": str(path), + "total_folders": len(self.allowed_folders) + } + + if len(self.allowed_folders) == 0: + result["warning"] = "No allowed folders remaining. Add one with /mcp add " + + return result + + except Exception as e: + logger.error(f"Error removing folder: {e}") + return { + "success": False, + "error": str(e) + } + + def list_folders(self) -> Dict[str, Any]: + """ + List all allowed folders with stats. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - folders: List of folder info + - total_folders: Total folder count + - total_files: Total file count across folders + - total_size_mb: Total size in MB + """ + try: + folders_info = [] + total_files = 0 + total_size = 0 + + for idx, folder in enumerate(self.allowed_folders, 1): + stats = self.config.get_folder_stats(folder) + + folder_info = { + "number": idx, + "path": str(folder), + "exists": stats.get("exists", False) + } + + if stats.get("exists"): + folder_info["file_count"] = stats["file_count"] + folder_info["size_mb"] = stats["size_mb"] + total_files += stats["file_count"] + total_size += stats["total_size"] + + folders_info.append(folder_info) + + return { + "success": True, + "folders": folders_info, + "total_folders": len(self.allowed_folders), + "total_files": total_files, + "total_size_mb": total_size / (1024 * 1024) + } + + except Exception as e: + get_logger().error(f"Error listing folders: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # DATABASE MANAGEMENT + # ========================================================================= + + def add_database(self, db_path: str) -> Dict[str, Any]: + """ + Add a SQLite database. + + Args: + db_path: Path to the database file + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - database: Database info + - number: Database number + - message: Status message + """ + import sqlite3 + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. 
Use /mcp on first" + } + + try: + path = Path(db_path).resolve() + + # Validation + if not path.exists(): + return { + "success": False, + "error": f"Database file not found: {path}" + } + + if not path.is_file(): + return { + "success": False, + "error": f"Not a file: {path}" + } + + # Check if already added + if any(db["path"] == str(path) for db in self.databases): + return { + "success": False, + "error": f"Database already added: {path.name}" + } + + # Validate it's a SQLite database + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + + # Get tables + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [row[0] for row in cursor.fetchall()] + + conn.close() + except sqlite3.DatabaseError as e: + return { + "success": False, + "error": f"Not a valid SQLite database: {e}" + } + + # Get file size + db_size = path.stat().st_size + + # Create database entry + db_info = { + "path": str(path), + "name": path.name, + "size": db_size, + "tables": tables, + "added": datetime.datetime.now().isoformat() + } + + # Save to config + db_id = get_database().add_mcp_database(db_info) + db_info["id"] = db_id + + # Add to list + self.databases.append(db_info) + + logger.info(f"Added database: {path.name}") + + return { + "success": True, + "database": db_info, + "number": len(self.databases), + "message": f"Added database #{len(self.databases)}: {path.name}" + } + + except Exception as e: + logger.error(f"Error adding database: {e}") + return { + "success": False, + "error": str(e) + } + + def remove_database(self, db_ref: str) -> Dict[str, Any]: + """ + Remove a database. + + Args: + db_ref: Database number (1-based) or path + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - database: Removed database info + - message: Status message + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + try: + # Check if it's a number + if db_ref.isdigit(): + index = int(db_ref) - 1 + if 0 <= index < len(self.databases): + db = self.databases[index] + else: + return { + "success": False, + "error": f"Invalid database number: {db_ref}" + } + else: + # Treat as path + path = str(Path(db_ref).resolve()) + db = next((d for d in self.databases if d["path"] == path), None) + if not db: + return { + "success": False, + "error": f"Database not found: {db_ref}" + } + index = self.databases.index(db) + + # If currently selected, deselect + if self.mode == "database" and self.selected_db_index == index: + self.mode = "files" + self.selected_db_index = None + + # Remove from config + get_database().remove_mcp_database(db["path"]) + + # Remove from list + self.databases.pop(index) + + logger.info(f"Removed database: {db['name']}") + + result = { + "success": True, + "database": db, + "message": f"Removed database: {db['name']}" + } + + if len(self.databases) == 0: + result["warning"] = "No databases remaining. Add one with /mcp add db " + + return result + + except Exception as e: + logger.error(f"Error removing database: {e}") + return { + "success": False, + "error": str(e) + } + + def list_databases(self) -> Dict[str, Any]: + """ + List all databases. 
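+
+        Illustrative use (field names are documented below):
+
+            result = manager.list_databases()
+            if result["success"]:
+                for db in result["databases"]:
+                    print(db["number"], db["name"], db["table_count"])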
+ + Returns: + Dictionary containing: + - success: Whether operation succeeded + - databases: List of database info + - count: Number of databases + """ + try: + db_list = [] + for idx, db in enumerate(self.databases, 1): + db_info = { + "number": idx, + "name": db["name"], + "path": db["path"], + "size_mb": db["size"] / (1024 * 1024), + "tables": db["tables"], + "table_count": len(db["tables"]), + "added": db["added"] + } + + # Check if file still exists + if not Path(db["path"]).exists(): + db_info["warning"] = "File not found" + + db_list.append(db_info) + + return { + "success": True, + "databases": db_list, + "count": len(db_list) + } + + except Exception as e: + get_logger().error(f"Error listing databases: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # GITIGNORE + # ========================================================================= + + def toggle_gitignore(self, enabled: bool) -> Dict[str, Any]: + """ + Toggle .gitignore filtering. + + Args: + enabled: Whether to enable gitignore filtering + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - message: Status message + - pattern_count: Number of patterns (if enabled) + """ + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + if not self.server: + return { + "success": False, + "error": "MCP server not running" + } + + self.server.respect_gitignore = enabled + status = "enabled" if enabled else "disabled" + + get_logger().info(f".gitignore filtering {status}") + + return { + "success": True, + "message": f".gitignore filtering {status}", + "pattern_count": self.server.gitignore_parser.pattern_count if enabled else 0 + } + + # ========================================================================= + # STATUS + # ========================================================================= + + def get_status(self) -> Dict[str, Any]: + """ + Get comprehensive MCP status. 
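+
+        Illustrative use (keys shown are among those returned below):
+
+            status = manager.get_status()
+            if status["success"]:
+                print(status["mode_info"]["mode_display"], status["tools_available"])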
+ + Returns: + Dictionary containing full MCP status information + """ + try: + stats = get_database().get_mcp_stats() + + uptime = None + if self.session_start_time: + delta = datetime.datetime.now() - self.session_start_time + hours = delta.seconds // 3600 + minutes = (delta.seconds % 3600) // 60 + uptime = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m" + + folder_info = self.list_folders() + + gitignore_status = "enabled" if ( + self.server and self.server.respect_gitignore + ) else "disabled" + gitignore_patterns = ( + self.server.gitignore_parser.pattern_count if self.server else 0 + ) + + # Current mode info + mode_info = { + "mode": self.mode, + "mode_display": ( + "Files" if self.mode == "files" + else f"DB #{self.selected_db_index + 1}" + ) + } + + if self.mode == "database" and self.selected_db_index is not None: + mode_info["database"] = self.databases[self.selected_db_index] + + return { + "success": True, + "enabled": self.enabled, + "write_enabled": self.write_enabled, + "uptime": uptime, + "mode_info": mode_info, + "folder_count": len(self.allowed_folders), + "database_count": len(self.databases), + "total_files": folder_info.get("total_files", 0), + "total_size_mb": folder_info.get("total_size_mb", 0), + "stats": stats, + "tools_available": self._get_current_tools(), + "gitignore_status": gitignore_status, + "gitignore_patterns": gitignore_patterns + } + + except Exception as e: + get_logger().error(f"Error getting MCP status: {e}") + return { + "success": False, + "error": str(e) + } + + def _get_current_tools(self) -> List[str]: + """Get tools available in current mode.""" + if not self.enabled: + return [] + + if self.mode == "files": + tools = ["read_file", "list_directory", "search_files"] + if self.write_enabled: + tools.extend([ + "write_file", "edit_file", "delete_file", + "create_directory", "move_file", "copy_file" + ]) + return tools + elif self.mode == "database": + return ["inspect_database", "search_database", "query_database"] + + return [] + + # ========================================================================= + # TOOL EXECUTION + # ========================================================================= + + async def call_tool(self, tool_name: str, **kwargs) -> Dict[str, Any]: + """ + Call an MCP tool. + + Args: + tool_name: Name of the tool to call + **kwargs: Tool arguments + + Returns: + Tool execution result + """ + if not self.enabled or not self.server: + return {"error": "MCP is not enabled"} + + # Check write permissions for write operations + write_tools = { + "write_file", "edit_file", "delete_file", + "create_directory", "move_file", "copy_file" + } + if tool_name in write_tools and not self.write_enabled: + return {"error": "Write operations disabled. Enable with /mcp write on"} + + try: + # File mode tools (read-only) + if tool_name == "read_file": + return await self.server.read_file(kwargs.get("file_path", "")) + elif tool_name == "list_directory": + return await self.server.list_directory( + kwargs.get("dir_path", ""), + kwargs.get("recursive", False) + ) + elif tool_name == "search_files": + return await self.server.search_files( + kwargs.get("pattern", ""), + kwargs.get("search_path"), + kwargs.get("content_search", False) + ) + + # Database mode tools + elif tool_name == "inspect_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. 
Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.inspect_database( + db["path"], + kwargs.get("table_name") + ) + elif tool_name == "search_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.search_database( + db["path"], + kwargs.get("search_term", ""), + kwargs.get("table_name"), + kwargs.get("column_name") + ) + elif tool_name == "query_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.query_database( + db["path"], + kwargs.get("query", ""), + kwargs.get("limit") + ) + + # Write mode tools + elif tool_name == "write_file": + return await self.server.write_file( + kwargs.get("file_path", ""), + kwargs.get("content", "") + ) + elif tool_name == "edit_file": + return await self.server.edit_file( + kwargs.get("file_path", ""), + kwargs.get("old_text", ""), + kwargs.get("new_text", "") + ) + elif tool_name == "delete_file": + return await self.server.delete_file( + kwargs.get("file_path", ""), + kwargs.get("reason", "No reason provided"), + kwargs.get("confirm_callback") + ) + elif tool_name == "create_directory": + return await self.server.create_directory( + kwargs.get("dir_path", "") + ) + elif tool_name == "move_file": + return await self.server.move_file( + kwargs.get("source_path", ""), + kwargs.get("dest_path", "") + ) + elif tool_name == "copy_file": + return await self.server.copy_file( + kwargs.get("source_path", ""), + kwargs.get("dest_path", "") + ) + else: + return {"error": f"Unknown tool: {tool_name}"} + + except Exception as e: + get_logger().error(f"Error calling MCP tool {tool_name}: {e}") + return {"error": str(e)} + + # ========================================================================= + # TOOL SCHEMA GENERATION + # ========================================================================= + + def get_tools_schema(self) -> List[Dict[str, Any]]: + """ + Get MCP tools as OpenAI function calling schema. + + Returns the appropriate schema based on current mode. + + Returns: + List of tool definitions for API request + """ + if not self.enabled: + return [] + + if self.mode == "files": + return self._get_file_tools_schema() + elif self.mode == "database": + return self._get_database_tools_schema() + + return [] + + def _get_file_tools_schema(self) -> List[Dict[str, Any]]: + """Get file mode tools schema.""" + if len(self.allowed_folders) == 0: + return [] + + allowed_dirs_str = ", ".join(str(f) for f in self.allowed_folders) + + # Base read-only tools + tools = [ + { + "type": "function", + "function": { + "name": "search_files", + "description": ( + f"Search for files in the user's local filesystem. " + f"Allowed directories: {allowed_dirs_str}. " + "Can search by filename pattern (e.g., '*.py' for Python files) " + "or search within file contents. Automatically respects .gitignore " + "patterns and excludes virtual environments." + ), + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": ( + "Search pattern. For filename search, use glob patterns " + "like '*.py', '*.txt', 'report*'. For content search, " + "use plain text to search for." + ) + }, + "content_search": { + "type": "boolean", + "description": ( + "If true, searches inside file contents (SLOW). 
" + "If false, searches only filenames (FAST). Default is false." + ), + "default": False + }, + "search_path": { + "type": "string", + "description": ( + "Optional: specific directory to search in. " + "If not provided, searches all allowed directories." + ), + } + }, + "required": ["pattern"] + } + } + }, + { + "type": "function", + "function": { + "name": "read_file", + "description": ( + "Read the complete contents of a text file. " + "Only works for files within allowed directories. " + "Maximum file size: 10MB. Files larger than 50KB are automatically truncated. " + "Respects .gitignore patterns." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": ( + "Full path to the file to read " + "(e.g., /Users/username/Documents/report.txt)" + ) + } + }, + "required": ["file_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "list_directory", + "description": ( + "List all files and subdirectories in a directory. " + "Automatically filters out virtual environments, build artifacts, " + "and .gitignore patterns. Limited to 1000 items." + ), + "parameters": { + "type": "object", + "properties": { + "dir_path": { + "type": "string", + "description": ( + "Directory path to list " + "(e.g., ~/Documents or /Users/username/Projects)" + ) + }, + "recursive": { + "type": "boolean", + "description": ( + "If true, lists files in subdirectories recursively. " + "Default is true. WARNING: Can be very large." + ), + "default": True + } + }, + "required": ["dir_path"] + } + } + } + ] + + # Add write tools if write mode is enabled + if self.write_enabled: + tools.extend(self._get_write_tools_schema(allowed_dirs_str)) + + return tools + + def _get_write_tools_schema(self, allowed_dirs_str: str) -> List[Dict[str, Any]]: + """Get write mode tools schema.""" + return [ + { + "type": "function", + "function": { + "name": "write_file", + "description": ( + f"Create or overwrite a file with content. " + f"Allowed directories: {allowed_dirs_str}. " + "Automatically creates parent directories if needed." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to create or overwrite" + }, + "content": { + "type": "string", + "description": "The complete content to write to the file" + } + }, + "required": ["file_path", "content"] + } + } + }, + { + "type": "function", + "function": { + "name": "edit_file", + "description": ( + "Make targeted edits by finding and replacing specific text. " + "The old_text must match exactly and appear only once in the file." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to edit" + }, + "old_text": { + "type": "string", + "description": ( + "The exact text to find and replace. " + "Must match exactly and appear only once." + ) + }, + "new_text": { + "type": "string", + "description": "The new text to replace the old text with" + } + }, + "required": ["file_path", "old_text", "new_text"] + } + } + }, + { + "type": "function", + "function": { + "name": "delete_file", + "description": ( + "Delete a file. ALWAYS requires user confirmation. " + "The user will see the file details and your reason before deciding." 
+ ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to delete" + }, + "reason": { + "type": "string", + "description": ( + "Clear explanation for why this file should be deleted. " + "The user will see this reason." + ) + } + }, + "required": ["file_path", "reason"] + } + } + }, + { + "type": "function", + "function": { + "name": "create_directory", + "description": ( + "Create a new directory (and parent directories if needed). " + "Returns success if directory already exists." + ), + "parameters": { + "type": "object", + "properties": { + "dir_path": { + "type": "string", + "description": "Full path to the directory to create" + } + }, + "required": ["dir_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "move_file", + "description": "Move or rename a file within allowed directories.", + "parameters": { + "type": "object", + "properties": { + "source_path": { + "type": "string", + "description": "Full path to the file to move/rename" + }, + "dest_path": { + "type": "string", + "description": "Full path for the new location/name" + } + }, + "required": ["source_path", "dest_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "copy_file", + "description": "Copy a file to a new location within allowed directories.", + "parameters": { + "type": "object", + "properties": { + "source_path": { + "type": "string", + "description": "Full path to the file to copy" + }, + "dest_path": { + "type": "string", + "description": "Full path for the copy destination" + } + }, + "required": ["source_path", "dest_path"] + } + } + } + ] + + def _get_database_tools_schema(self) -> List[Dict[str, Any]]: + """Get database mode tools schema.""" + if self.selected_db_index is None or self.selected_db_index >= len(self.databases): + return [] + + db = self.databases[self.selected_db_index] + db_name = db["name"] + tables_str = ", ".join(db["tables"]) + + return [ + { + "type": "function", + "function": { + "name": "inspect_database", + "description": ( + f"Inspect the schema of the currently selected database ({db_name}). " + f"Can get all tables or details for a specific table. " + f"Available tables: {tables_str}." + ), + "parameters": { + "type": "object", + "properties": { + "table_name": { + "type": "string", + "description": ( + f"Optional: specific table to inspect. " + f"If not provided, returns info for all tables. " + f"Available: {tables_str}" + ) + } + }, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "search_database", + "description": ( + f"Search for a value across tables in the database ({db_name}). " + f"Performs partial matching across columns. " + f"Limited to {self.server.default_query_limit} results." + ), + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": "string", + "description": "Value to search for (partial match supported)" + }, + "table_name": { + "type": "string", + "description": f"Optional: limit search to specific table. Available: {tables_str}" + }, + "column_name": { + "type": "string", + "description": "Optional: limit search to specific column" + } + }, + "required": ["search_term"] + } + } + }, + { + "type": "function", + "function": { + "name": "query_database", + "description": ( + f"Execute a read-only SQL query on the database ({db_name}). " + f"Supports SELECT queries including JOINs, subqueries, CTEs. " + f"Maximum {self.server.max_query_results} rows. 
" + f"Timeout: {self.server.max_query_timeout} seconds. " + "INSERT/UPDATE/DELETE/DROP are blocked." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + f"SQL SELECT query to execute. " + f"Available tables: {tables_str}." + ) + }, + "limit": { + "type": "integer", + "description": ( + f"Optional: max rows to return " + f"(default {self.server.default_query_limit}, " + f"max {self.server.max_query_results})" + ) + } + }, + "required": ["query"] + } + } + } + ] + + +# Global MCP manager instance +_mcp_manager: Optional[MCPManager] = None + + +def get_mcp_manager() -> MCPManager: + """ + Get the global MCP manager instance. + + Returns: + The shared MCPManager instance + """ + global _mcp_manager + if _mcp_manager is None: + _mcp_manager = MCPManager() + return _mcp_manager diff --git a/oai/mcp/platform.py b/oai/mcp/platform.py new file mode 100644 index 0000000..dc941bf --- /dev/null +++ b/oai/mcp/platform.py @@ -0,0 +1,228 @@ +""" +Cross-platform MCP configuration for oAI. + +This module handles OS-specific configuration, path handling, +and security checks for the MCP filesystem server. +""" + +import os +import platform +import subprocess +from pathlib import Path +from typing import List, Dict, Any, Optional + +from oai.constants import SYSTEM_DIRS_BLACKLIST +from oai.utils.logging import get_logger + + +class CrossPlatformMCPConfig: + """ + Handle OS-specific MCP configuration. + + Provides methods for path normalization, security validation, + and OS-specific default directories. + + Attributes: + system: Operating system name + is_macos: Whether running on macOS + is_linux: Whether running on Linux + is_windows: Whether running on Windows + """ + + def __init__(self): + """Initialize platform detection.""" + self.system = platform.system() + self.is_macos = self.system == "Darwin" + self.is_linux = self.system == "Linux" + self.is_windows = self.system == "Windows" + + logger = get_logger() + logger.info(f"Detected OS: {self.system}") + + def get_default_allowed_dirs(self) -> List[Path]: + """ + Get safe default directories for the current OS. + + Returns: + List of default directories that are safe to access + """ + home = Path.home() + + if self.is_macos: + return [ + home / "Documents", + home / "Desktop", + home / "Downloads", + ] + + elif self.is_linux: + dirs = [home / "Documents"] + + # Try to get XDG directories + try: + for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]: + result = subprocess.run( + ["xdg-user-dir", xdg_dir], + capture_output=True, + text=True, + timeout=1 + ) + if result.returncode == 0: + dir_path = Path(result.stdout.strip()) + if dir_path.exists(): + dirs.append(dir_path) + except (subprocess.TimeoutExpired, FileNotFoundError): + # Fallback to standard locations + dirs.extend([ + home / "Desktop", + home / "Downloads", + ]) + + return list(set(dirs)) + + elif self.is_windows: + return [ + home / "Documents", + home / "Desktop", + home / "Downloads", + ] + + # Fallback for unknown OS + return [home] + + def get_python_command(self) -> str: + """ + Get the Python executable path. + + Returns: + Path to the Python executable + """ + import sys + return sys.executable + + def get_filesystem_warning(self) -> str: + """ + Get OS-specific security warning message. + + Returns: + Warning message for the current OS + """ + if self.is_macos: + return """ +Note: macOS Security +The Filesystem MCP server needs access to your selected folder. +You may see a security prompt - click 'Allow' to proceed. 
+(System Settings > Privacy & Security > Files and Folders) +""" + elif self.is_linux: + return """ +Note: Linux Security +The Filesystem MCP server will access your selected folder. +Ensure oAI has appropriate file permissions. +""" + elif self.is_windows: + return """ +Note: Windows Security +The Filesystem MCP server will access your selected folder. +You may need to grant file access permissions. +""" + return "" + + def normalize_path(self, path: str) -> Path: + """ + Normalize a path for the current OS. + + Expands user directory (~) and resolves to absolute path. + + Args: + path: Path string to normalize + + Returns: + Normalized absolute Path + """ + return Path(os.path.expanduser(path)).resolve() + + def is_system_directory(self, path: Path) -> bool: + """ + Check if a path is a protected system directory. + + Args: + path: Path to check + + Returns: + True if the path is a system directory + """ + path_str = str(path) + for blocked in SYSTEM_DIRS_BLACKLIST: + if path_str.startswith(blocked): + return True + return False + + def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool: + """ + Check if a path is within allowed directories. + + Args: + requested_path: Path being requested + allowed_dirs: List of allowed parent directories + + Returns: + True if the path is within an allowed directory + """ + try: + requested = requested_path.resolve() + + for allowed in allowed_dirs: + try: + allowed_resolved = allowed.resolve() + requested.relative_to(allowed_resolved) + return True + except ValueError: + continue + + return False + except Exception: + return False + + def get_folder_stats(self, folder: Path) -> Dict[str, Any]: + """ + Get statistics for a folder. + + Args: + folder: Path to the folder + + Returns: + Dictionary with folder statistics: + - exists: Whether the folder exists + - file_count: Number of files (if exists) + - total_size: Total size in bytes (if exists) + - size_mb: Size in megabytes (if exists) + - error: Error message (if any) + """ + logger = get_logger() + + try: + if not folder.exists() or not folder.is_dir(): + return {"exists": False} + + file_count = 0 + total_size = 0 + + for item in folder.rglob("*"): + if item.is_file(): + file_count += 1 + try: + total_size += item.stat().st_size + except (OSError, PermissionError): + pass + + return { + "exists": True, + "file_count": file_count, + "total_size": total_size, + "size_mb": total_size / (1024 * 1024), + } + + except Exception as e: + logger.error(f"Error getting folder stats for {folder}: {e}") + return {"exists": False, "error": str(e)} diff --git a/oai/mcp/server.py b/oai/mcp/server.py new file mode 100644 index 0000000..105a0de --- /dev/null +++ b/oai/mcp/server.py @@ -0,0 +1,1368 @@ +""" +MCP Filesystem Server implementation for oAI. + +This module provides the actual filesystem and database access operations +for the MCP integration. It handles file reading, directory listing, +searching, and SQLite database operations. 
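+
+All operations are exposed as async coroutines; a minimal standalone sketch
+(the folder and file path are assumptions):
+
+    server = MCPFilesystemServer([Path.home() / "Documents"])
+    result = asyncio.run(server.read_file("~/Documents/notes.txt"))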
+""" + +import datetime +import signal +import sqlite3 +from pathlib import Path +from typing import Optional, List, Dict, Any, Set + +from oai.constants import ( + MAX_FILE_SIZE, + CONTENT_TRUNCATION_THRESHOLD, + MAX_LIST_ITEMS, + MAX_QUERY_TIMEOUT, + MAX_QUERY_RESULTS, + DEFAULT_QUERY_LIMIT, + SKIP_DIRECTORIES, +) +from oai.config.database import get_database +from oai.mcp.platform import CrossPlatformMCPConfig +from oai.mcp.gitignore import GitignoreParser +from oai.mcp.validators import SQLiteQueryValidator +from oai.utils.logging import get_logger + + +class MCPFilesystemServer: + """ + MCP Filesystem Server with file access and SQLite database querying. + + Provides async methods for: + - File operations: read, write, edit, delete, copy, move + - Directory operations: list, create + - Search operations: filename and content search + - Database operations: inspect, search, query + + All operations respect allowed folder restrictions and gitignore patterns. + + Attributes: + allowed_folders: List of folders the server can access + config: Platform-specific configuration + max_file_size: Maximum file size for read operations + max_list_items: Maximum items to return in directory listings + respect_gitignore: Whether to apply gitignore filtering + gitignore_parser: Parser for gitignore patterns + max_query_timeout: Maximum seconds for database queries + max_query_results: Maximum rows to return from queries + default_query_limit: Default row limit for queries + """ + + def __init__(self, allowed_folders: List[Path]): + """ + Initialize the filesystem server. + + Args: + allowed_folders: List of folder paths the server can access + """ + self.allowed_folders = allowed_folders + self.config = CrossPlatformMCPConfig() + self.max_file_size = MAX_FILE_SIZE + self.max_list_items = MAX_LIST_ITEMS + self.respect_gitignore = True + + # Initialize gitignore parser and load patterns + self.gitignore_parser = GitignoreParser() + self._load_gitignores() + + # SQLite configuration + self.max_query_timeout = MAX_QUERY_TIMEOUT + self.max_query_results = MAX_QUERY_RESULTS + self.default_query_limit = DEFAULT_QUERY_LIMIT + + logger = get_logger() + logger.info( + f"MCP Filesystem Server initialized with {len(allowed_folders)} folders" + ) + + def _load_gitignores(self) -> None: + """Load all .gitignore files from allowed folders.""" + logger = get_logger() + gitignore_count = 0 + + for folder in self.allowed_folders: + if not folder.exists(): + continue + + # Load root .gitignore + root_gitignore = folder / ".gitignore" + if root_gitignore.exists(): + self.gitignore_parser.add_gitignore(root_gitignore) + gitignore_count += 1 + logger.info(f"Loaded .gitignore from {folder}") + + # Load nested .gitignore files + try: + for gitignore_path in folder.rglob(".gitignore"): + if gitignore_path != root_gitignore: + self.gitignore_parser.add_gitignore(gitignore_path) + gitignore_count += 1 + logger.debug(f"Loaded nested .gitignore: {gitignore_path}") + except Exception as e: + logger.warning(f"Error loading nested .gitignores from {folder}: {e}") + + if gitignore_count > 0: + logger.info( + f"Loaded {gitignore_count} .gitignore file(s) with " + f"{self.gitignore_parser.pattern_count} total patterns" + ) + + def reload_gitignores(self) -> None: + """Reload all .gitignore files (call when allowed_folders changes).""" + self.gitignore_parser.clear() + self._load_gitignores() + get_logger().info("Reloaded .gitignore patterns") + + def is_allowed_path(self, path: Path) -> bool: + """ + Check if a path is within allowed 
folders. + + Args: + path: Path to check + + Returns: + True if the path is allowed + """ + return self.config.is_safe_path(path, self.allowed_folders) + + def _should_skip_path(self, item_path: Path) -> bool: + """ + Check if a path should be skipped during traversal. + + Args: + item_path: Path to check + + Returns: + True if the path should be skipped + """ + # Check hardcoded skip directories + path_parts = item_path.parts + if any(part in SKIP_DIRECTORIES for part in path_parts): + return True + + # Check gitignore patterns (if enabled) + if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): + return True + + return False + + def _log_stat( + self, + tool_name: str, + folder: Optional[str], + success: bool, + error_message: Optional[str] = None + ) -> None: + """Log an MCP tool usage event.""" + get_database().log_mcp_stat(tool_name, folder, success, error_message) + + # ========================================================================= + # FILE READ OPERATIONS + # ========================================================================= + + async def read_file(self, file_path: str) -> Dict[str, Any]: + """ + Read file contents. + + Args: + file_path: Path to the file to read + + Returns: + Dictionary containing: + - content: File content + - path: Resolved file path + - size: File size in bytes + - truncated: Whether content was truncated + - error: Error message if failed + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Check if file should be ignored + if self.respect_gitignore and self.gitignore_parser.should_ignore(path): + error_msg = f"File ignored by .gitignore: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + file_size = path.stat().st_size + if file_size > self.max_file_size: + error_msg = ( + f"File too large: {file_size / (1024*1024):.1f}MB " + f"(max: {self.max_file_size / (1024*1024):.0f}MB)" + ) + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Try to read as text + try: + content = path.read_text(encoding="utf-8") + + # Truncate large files + if file_size > CONTENT_TRUNCATION_THRESHOLD: + lines = content.split("\n") + total_lines = len(lines) + + head_lines = 500 + tail_lines = 100 + + if total_lines > (head_lines + tail_lines): + truncated_content = ( + "\n".join(lines[:head_lines]) + + f"\n\n... 
[TRUNCATED: {total_lines - head_lines - tail_lines} " + f"lines omitted] ...\n\n" + + "\n".join(lines[-tail_lines:]) + ) + + logger.info( + f"Read file (truncated): {path} " + f"({file_size} bytes, {total_lines} lines)" + ) + self._log_stat("read_file", str(path.parent), True) + + return { + "content": truncated_content, + "path": str(path), + "size": file_size, + "truncated": True, + "total_lines": total_lines, + "lines_shown": head_lines + tail_lines, + "note": ( + f"File truncated: showing first {head_lines} " + f"and last {tail_lines} lines of {total_lines} total" + ) + } + + logger.info(f"Read file: {path} ({file_size} bytes)") + self._log_stat("read_file", str(path.parent), True) + + return { + "content": content, + "path": str(path), + "size": file_size + } + + except UnicodeDecodeError: + error_msg = f"Cannot decode file as UTF-8: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + except Exception as e: + error_msg = f"Error reading file: {e}" + logger.error(error_msg) + self._log_stat("read_file", file_path, False, str(e)) + return {"error": error_msg} + + async def list_directory( + self, + dir_path: str, + recursive: bool = False + ) -> Dict[str, Any]: + """ + List directory contents. + + Args: + dir_path: Path to the directory + recursive: Whether to list recursively + + Returns: + Dictionary containing: + - path: Directory path + - items: List of file/directory info + - count: Number of items + - truncated: Whether results were truncated + """ + logger = get_logger() + + try: + path = self.config.normalize_path(dir_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"Directory not found: {path}" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_dir(): + error_msg = f"Not a directory: {path}" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + items = [] + pattern = "**/*" if recursive else "*" + + for item in path.glob(pattern): + # Stop if we hit the limit + if len(items) >= self.max_list_items: + break + + # Skip excluded directories + if self._should_skip_path(item): + continue + + try: + stat = item.stat() + items.append({ + "name": item.name, + "path": str(item), + "type": "directory" if item.is_dir() else "file", + "size": stat.st_size if item.is_file() else 0, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (OSError, PermissionError): + continue + + truncated = len(items) >= self.max_list_items + + logger.info( + f"Listed directory: {path} ({len(items)} items, " + f"recursive={recursive}, truncated={truncated})" + ) + self._log_stat("list_directory", str(path), True) + + result = { + "path": str(path), + "items": items, + "count": len(items), + "truncated": truncated + } + + if truncated: + result["note"] = ( + f"Results limited to {self.max_list_items} items. " + "Use more specific search path or disable recursive mode." 
+ ) + + return result + + except Exception as e: + error_msg = f"Error listing directory: {e}" + logger.error(error_msg) + self._log_stat("list_directory", dir_path, False, str(e)) + return {"error": error_msg} + + async def search_files( + self, + pattern: str, + search_path: Optional[str] = None, + content_search: bool = False + ) -> Dict[str, Any]: + """ + Search for files by name or content. + + Args: + pattern: Search pattern (glob for filename, text for content) + search_path: Optional path to search within + content_search: Whether to search within file contents + + Returns: + Dictionary containing: + - pattern: The search pattern used + - search_type: 'filename' or 'content' + - matches: List of matching files + - count: Number of matches + """ + logger = get_logger() + + try: + # Determine search roots + if search_path: + path = self.config.normalize_path(search_path) + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("search_files", str(path), False, error_msg) + return {"error": error_msg} + search_roots = [path] + else: + search_roots = self.allowed_folders + + matches = [] + + for root in search_roots: + if not root.exists(): + continue + + # Filename search (fast) + if not content_search: + for item in root.rglob(pattern): + if item.is_file() and not self._should_skip_path(item): + try: + stat = item.stat() + matches.append({ + "path": str(item), + "name": item.name, + "size": stat.st_size, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (OSError, PermissionError): + continue + else: + # Content search (slower) + for item in root.rglob("*"): + if item.is_file() and not self._should_skip_path(item): + try: + # Skip large files + if item.stat().st_size > self.max_file_size: + continue + + # Try to read and search content + try: + content = item.read_text(encoding="utf-8") + if pattern.lower() in content.lower(): + stat = item.stat() + matches.append({ + "path": str(item), + "name": item.name, + "size": stat.st_size, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (UnicodeDecodeError, PermissionError): + continue + except (OSError, PermissionError): + continue + + search_type = "content" if content_search else "filename" + logger.info( + f"Searched files: pattern='{pattern}', type={search_type}, " + f"found={len(matches)}" + ) + self._log_stat( + "search_files", + str(search_roots[0]) if search_roots else None, + True + ) + + return { + "pattern": pattern, + "search_type": search_type, + "matches": matches, + "count": len(matches) + } + + except Exception as e: + error_msg = f"Error searching files: {e}" + logger.error(error_msg) + self._log_stat("search_files", search_path or "all", False, str(e)) + return {"error": error_msg} + + # ========================================================================= + # FILE WRITE OPERATIONS + # ========================================================================= + + async def write_file(self, file_path: str, content: str) -> Dict[str, Any]: + """ + Create or overwrite a file. + + Note: Ignores .gitignore - allows writing to any file in allowed folders. 
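+
+        Example (illustrative sketch; "handler" is a placeholder name for an
+        instance of this handler whose target folder was already approved):
+
+            # handler: placeholder for this MCP file handler instance
+            result = await handler.write_file("project/notes.md", "# Notes\n")
+            if "error" not in result:
+                print(result["path"], result["size"], result["created"])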
+ + Args: + file_path: Path to the file + content: Content to write + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: File path + - size: File size after writing + - created: Whether a new file was created + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("write_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Create parent directory if needed + parent_dir = path.parent + if not parent_dir.exists(): + parent_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Created parent directory: {parent_dir}") + + is_new_file = not path.exists() + + # Write content + path.write_text(content, encoding="utf-8") + file_size = path.stat().st_size + + logger.info( + f"{'Created' if is_new_file else 'Updated'} file: " + f"{path} ({file_size} bytes)" + ) + self._log_stat("write_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "size": file_size, + "created": is_new_file + } + + except PermissionError as e: + error_msg = f"Permission denied writing to {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except UnicodeEncodeError as e: + error_msg = f"Encoding error writing to {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "EncodingError"} + except Exception as e: + error_msg = f"Error writing file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg} + + async def edit_file( + self, + file_path: str, + old_text: str, + new_text: str + ) -> Dict[str, Any]: + """ + Find and replace text in a file. + + Note: Ignores .gitignore - allows editing any file in allowed folders. + + Args: + file_path: Path to the file + old_text: Text to find (must be unique in file) + new_text: Text to replace with + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: File path + - changes: Number of replacements made + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("edit_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Read current content + current_content = path.read_text(encoding="utf-8") + + # Check if old_text exists + if old_text not in current_content: + error_msg = f"Text not found in file: '{old_text[:50]}...'" + logger.warning(f"Edit failed - text not found in {path}") + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Check for unique match + occurrence_count = current_content.count(old_text) + if occurrence_count > 1: + error_msg = ( + f"Text appears {occurrence_count} times in file. 
" + "Please provide more context to make the match unique." + ) + logger.warning(f"Edit failed - ambiguous match in {path}") + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Perform replacement + new_content = current_content.replace(old_text, new_text, 1) + path.write_text(new_content, encoding="utf-8") + + logger.info(f"Edited file: {path} (1 replacement)") + self._log_stat("edit_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "changes": 1 + } + + except PermissionError as e: + error_msg = f"Permission denied editing {file_path}: {e}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except UnicodeDecodeError as e: + error_msg = f"Cannot read file (encoding issue): {file_path}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "EncodingError"} + except Exception as e: + error_msg = f"Error editing file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg} + + async def delete_file( + self, + file_path: str, + reason: str, + confirm_callback=None + ) -> Dict[str, Any]: + """ + Delete a file with user confirmation. + + Note: Ignores .gitignore - allows deleting any file in allowed folders. + + Args: + file_path: Path to the file to delete + reason: Reason for deletion (shown to user) + confirm_callback: Callback function for user confirmation + + Returns: + Dictionary containing: + - success: Whether the file was deleted + - path: File path + - user_cancelled: Whether user cancelled the operation + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("delete_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("delete_file", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("delete_file", str(path), False, error_msg) + return {"error": error_msg} + + # Get file info for confirmation + file_size = path.stat().st_size + file_mtime = datetime.datetime.fromtimestamp(path.stat().st_mtime) + + # User confirmation required (handled by callback) + if confirm_callback: + confirmed = confirm_callback(path, file_size, file_mtime, reason) + if not confirmed: + logger.info(f"User cancelled file deletion: {path}") + self._log_stat("delete_file", str(path), False, "User cancelled") + return { + "success": False, + "user_cancelled": True, + "path": str(path) + } + + # Delete the file + path.unlink() + + logger.info(f"Deleted file: {path}") + self._log_stat("delete_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "user_cancelled": False + } + + except PermissionError as e: + error_msg = f"Permission denied deleting {file_path}: {e}" + logger.error(error_msg) + self._log_stat("delete_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error deleting file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("delete_file", 
file_path, False, str(e)) + return {"error": error_msg} + + async def create_directory(self, dir_path: str) -> Dict[str, Any]: + """ + Create a directory. + + Note: Ignores .gitignore - allows creating directories in allowed folders. + + Args: + dir_path: Path to the directory to create + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: Directory path + - created: Whether a new directory was created + """ + logger = get_logger() + + try: + path = self.config.normalize_path(dir_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("create_directory", str(path.parent), False, error_msg) + return {"error": error_msg} + + already_exists = path.exists() + + if already_exists and not path.is_dir(): + error_msg = f"Path exists but is not a directory: {path}" + logger.warning(error_msg) + self._log_stat("create_directory", str(path), False, error_msg) + return {"error": error_msg} + + # Create directory (and parents) + path.mkdir(parents=True, exist_ok=True) + + if already_exists: + logger.info(f"Directory already exists: {path}") + else: + logger.info(f"Created directory: {path}") + + self._log_stat("create_directory", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "created": not already_exists + } + + except PermissionError as e: + error_msg = f"Permission denied creating directory {dir_path}: {e}" + logger.error(error_msg) + self._log_stat("create_directory", dir_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error creating directory {dir_path}: {e}" + logger.error(error_msg) + self._log_stat("create_directory", dir_path, False, str(e)) + return {"error": error_msg} + + async def move_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: + """ + Move or rename a file. + + Note: Ignores .gitignore - allows moving files within allowed folders. 
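+
+        Example (illustrative; "handler" is a placeholder for this handler
+        instance, and both paths must resolve inside approved folders):
+
+            # handler: placeholder for this MCP file handler instance
+            result = await handler.move_file("drafts/old.md", "archive/old.md")
+            if result.get("success"):
+                print("Moved to", result["destination"])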
+ + Args: + source_path: Path to the source file + dest_path: Destination path + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - source: Source path + - destination: Destination path + """ + import shutil + logger = get_logger() + + try: + source = self.config.normalize_path(source_path) + dest = self.config.normalize_path(dest_path) + + # Both paths must be in allowed folders + if not self.is_allowed_path(source): + error_msg = f"Access denied: source {source} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("move_file", str(source.parent), False, error_msg) + return {"error": error_msg} + + if not self.is_allowed_path(dest): + error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("move_file", str(dest.parent), False, error_msg) + return {"error": error_msg} + + if not source.exists(): + error_msg = f"Source file not found: {source}" + logger.warning(error_msg) + self._log_stat("move_file", str(source), False, error_msg) + return {"error": error_msg} + + if not source.is_file(): + error_msg = f"Source is not a file: {source}" + logger.warning(error_msg) + self._log_stat("move_file", str(source), False, error_msg) + return {"error": error_msg} + + # Create destination parent directory if needed + dest.parent.mkdir(parents=True, exist_ok=True) + + # Move/rename the file + shutil.move(str(source), str(dest)) + + logger.info(f"Moved file: {source} -> {dest}") + self._log_stat("move_file", str(source.parent), True) + + return { + "success": True, + "source": str(source), + "destination": str(dest) + } + + except PermissionError as e: + error_msg = f"Permission denied moving file: {e}" + logger.error(error_msg) + self._log_stat("move_file", source_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error moving file from {source_path} to {dest_path}: {e}" + logger.error(error_msg) + self._log_stat("move_file", source_path, False, str(e)) + return {"error": error_msg} + + async def copy_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: + """ + Copy a file. + + Note: Ignores .gitignore - allows copying files within allowed folders. 
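+
+        Example (illustrative; "handler" is a placeholder for this handler
+        instance, and both paths must resolve inside approved folders):
+
+            # handler: placeholder for this MCP file handler instance
+            result = await handler.copy_file("config.yaml", "backups/config.yaml")
+            if result.get("success"):
+                print("Copied", result["size"], "bytes to", result["destination"])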
+ + Args: + source_path: Path to the source file + dest_path: Destination path + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - source: Source path + - destination: Destination path + - size: Size of copied file + """ + import shutil + logger = get_logger() + + try: + source = self.config.normalize_path(source_path) + dest = self.config.normalize_path(dest_path) + + # Both paths must be in allowed folders + if not self.is_allowed_path(source): + error_msg = f"Access denied: source {source} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("copy_file", str(source.parent), False, error_msg) + return {"error": error_msg} + + if not self.is_allowed_path(dest): + error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("copy_file", str(dest.parent), False, error_msg) + return {"error": error_msg} + + if not source.exists(): + error_msg = f"Source file not found: {source}" + logger.warning(error_msg) + self._log_stat("copy_file", str(source), False, error_msg) + return {"error": error_msg} + + if not source.is_file(): + error_msg = f"Source is not a file: {source}" + logger.warning(error_msg) + self._log_stat("copy_file", str(source), False, error_msg) + return {"error": error_msg} + + # Create destination parent directory if needed + dest.parent.mkdir(parents=True, exist_ok=True) + + # Copy the file (preserves metadata) + shutil.copy2(str(source), str(dest)) + + file_size = dest.stat().st_size + + logger.info(f"Copied file: {source} -> {dest} ({file_size} bytes)") + self._log_stat("copy_file", str(source.parent), True) + + return { + "success": True, + "source": str(source), + "destination": str(dest), + "size": file_size + } + + except PermissionError as e: + error_msg = f"Permission denied copying file: {e}" + logger.error(error_msg) + self._log_stat("copy_file", source_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error copying file from {source_path} to {dest_path}: {e}" + logger.error(error_msg) + self._log_stat("copy_file", source_path, False, str(e)) + return {"error": error_msg} + + # ========================================================================= + # DATABASE OPERATIONS + # ========================================================================= + + async def inspect_database( + self, + db_path: str, + table_name: Optional[str] = None + ) -> Dict[str, Any]: + """ + Inspect SQLite database schema. 
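+
+        Example (illustrative; "handler" is a placeholder for this handler
+        instance and the database path has already been approved):
+
+            # Whole-database overview, then a single-table drill-down
+            overview = await handler.inspect_database("data/app.db")
+            print(overview["table_count"], [t["name"] for t in overview["tables"]])
+
+            users = await handler.inspect_database("data/app.db", table_name="users")
+            print(users["row_count"], [c["name"] for c in users["columns"]])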
+ + Args: + db_path: Path to the database file + table_name: Optional specific table to inspect + + Returns: + Dictionary containing database/table information + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + if table_name: + # Get info for specific table + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + result = cursor.fetchone() + + if not result: + return {"error": f"Table '{table_name}' not found"} + + create_sql = result[0] + + # Get column info + cursor.execute(f"PRAGMA table_info({table_name})") + columns = [] + for row in cursor.fetchall(): + columns.append({ + "id": row[0], + "name": row[1], + "type": row[2], + "not_null": bool(row[3]), + "default_value": row[4], + "primary_key": bool(row[5]) + }) + + # Get indexes + cursor.execute(f"PRAGMA index_list({table_name})") + indexes = [] + for row in cursor.fetchall(): + indexes.append({ + "name": row[1], + "unique": bool(row[2]) + }) + + # Get row count + cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + row_count = cursor.fetchone()[0] + + logger.info(f"Inspected table: {table_name} in {path}") + self._log_stat("inspect_database", str(path.parent), True) + + return { + "database": str(path), + "table": table_name, + "create_sql": create_sql, + "columns": columns, + "indexes": indexes, + "row_count": row_count + } + else: + # Get all tables + cursor.execute( + "SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [] + for row in cursor.fetchall(): + table = row[0] + cursor.execute(f"SELECT COUNT(*) FROM {table}") + count = cursor.fetchone()[0] + tables.append({ + "name": table, + "row_count": count + }) + + # Get indexes + cursor.execute( + "SELECT name, tbl_name FROM sqlite_master WHERE type='index' ORDER BY name" + ) + indexes = [{"name": row[0], "table": row[1]} for row in cursor.fetchall()] + + # Get triggers + cursor.execute( + "SELECT name, tbl_name FROM sqlite_master WHERE type='trigger' ORDER BY name" + ) + triggers = [{"name": row[0], "table": row[1]} for row in cursor.fetchall()] + + # Get database size + db_size = path.stat().st_size + + logger.info(f"Inspected database: {path} ({len(tables)} tables)") + self._log_stat("inspect_database", str(path.parent), True) + + return { + "database": str(path), + "size": db_size, + "size_mb": db_size / (1024 * 1024), + "tables": tables, + "table_count": len(tables), + "indexes": indexes, + "triggers": triggers + } + finally: + conn.close() + + except Exception as e: + error_msg = f"Error inspecting database: {e}" + logger.error(error_msg) + self._log_stat("inspect_database", db_path, False, str(e)) + return {"error": error_msg} + + async def search_database( + self, + db_path: str, + search_term: str, + table_name: 
Optional[str] = None, + column_name: Optional[str] = None + ) -> Dict[str, Any]: + """ + Search for a value across database tables. + + Args: + db_path: Path to the database file + search_term: Value to search for (partial match) + table_name: Optional table to limit search + column_name: Optional column to limit search + + Returns: + Dictionary containing search results + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("search_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("search_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + matches = [] + + # Get tables to search + if table_name: + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + tables = [row[0] for row in cursor.fetchall()] + if not tables: + return {"error": f"Table '{table_name}' not found"} + else: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + + for table in tables: + # Get columns for this table + cursor.execute(f"PRAGMA table_info({table})") + columns = [row[1] for row in cursor.fetchall()] + + # Filter columns if specified + if column_name: + if column_name not in columns: + continue + columns = [column_name] + + # Search in each column + for column in columns: + try: + query = f"SELECT * FROM {table} WHERE {column} LIKE ? LIMIT {self.default_query_limit}" + cursor.execute(query, (f"%{search_term}%",)) + results = cursor.fetchall() + + if results: + col_names = [desc[0] for desc in cursor.description] + + for row in results: + row_dict = {} + for col_name, value in zip(col_names, row): + if isinstance(value, bytes): + try: + row_dict[col_name] = value.decode("utf-8") + except UnicodeDecodeError: + row_dict[col_name] = f"" + else: + row_dict[col_name] = value + + matches.append({ + "table": table, + "column": column, + "row": row_dict + }) + except sqlite3.Error: + # Skip columns that can't be searched + continue + + logger.info( + f"Searched database: {path} for '{search_term}' - " + f"found {len(matches)} matches" + ) + self._log_stat("search_database", str(path.parent), True) + + result = { + "database": str(path), + "search_term": search_term, + "matches": matches, + "count": len(matches) + } + + if table_name: + result["table_filter"] = table_name + if column_name: + result["column_filter"] = column_name + + if len(matches) >= self.default_query_limit: + result["note"] = ( + f"Results limited to {self.default_query_limit} matches. " + "Use more specific search or query_database for full results." + ) + + return result + + finally: + conn.close() + + except Exception as e: + error_msg = f"Error searching database: {e}" + logger.error(error_msg) + self._log_stat("search_database", db_path, False, str(e)) + return {"error": error_msg} + + async def query_database( + self, + db_path: str, + query: str, + limit: Optional[int] = None + ) -> Dict[str, Any]: + """ + Execute a read-only SQL query. 
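+
+        Example (illustrative; "handler" is a placeholder for this handler
+        instance; only SELECT/WITH queries pass validation, anything else
+        comes back as an error dictionary instead of being executed):
+
+            result = await handler.query_database(
+                "data/app.db",
+                "SELECT name, email FROM users WHERE active = 1",
+                limit=50,
+            )
+            for row in result.get("rows", []):
+                print(row["name"], row["email"])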
+ + Args: + db_path: Path to the database file + query: SQL query to execute (SELECT only) + limit: Optional row limit + + Returns: + Dictionary containing query results + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Validate query + is_safe, error_msg = SQLiteQueryValidator.is_safe_query(query) + if not is_safe: + logger.warning(f"Unsafe query rejected: {error_msg}") + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + # Set query timeout (Unix only) + def timeout_handler(signum, frame): + raise TimeoutError( + f"Query exceeded {self.max_query_timeout} second timeout" + ) + + if hasattr(signal, "SIGALRM"): + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(self.max_query_timeout) + + try: + # Execute query + cursor.execute(query) + + # Get results + result_limit = min( + limit or self.default_query_limit, + self.max_query_results + ) + results = cursor.fetchmany(result_limit) + + # Get column names + columns = [ + desc[0] for desc in cursor.description + ] if cursor.description else [] + + # Convert to list of dicts + rows = [] + for row in results: + row_dict = {} + for col_name, value in zip(columns, row): + if isinstance(value, bytes): + try: + row_dict[col_name] = value.decode("utf-8") + except UnicodeDecodeError: + row_dict[col_name] = f"" + else: + row_dict[col_name] = value + rows.append(row_dict) + + # Check if more results available + has_more = len(results) == result_limit + if has_more: + one_more = cursor.fetchone() + has_more = one_more is not None + + finally: + if hasattr(signal, "SIGALRM"): + signal.alarm(0) + + logger.info(f"Executed query on {path}: returned {len(rows)} rows") + self._log_stat("query_database", str(path.parent), True) + + result = { + "database": str(path), + "query": query, + "columns": columns, + "rows": rows, + "count": len(rows) + } + + if has_more: + result["truncated"] = True + result["note"] = ( + f"Results limited to {result_limit} rows. " + "Use LIMIT clause in query for more control." + ) + + return result + + except TimeoutError as e: + error_msg = str(e) + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + except sqlite3.Error as e: + error_msg = f"SQL error: {e}" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + finally: + if hasattr(signal, "SIGALRM"): + signal.alarm(0) + conn.close() + + except Exception as e: + error_msg = f"Error executing query: {e}" + logger.error(error_msg) + self._log_stat("query_database", db_path, False, str(e)) + return {"error": error_msg} diff --git a/oai/mcp/validators.py b/oai/mcp/validators.py new file mode 100644 index 0000000..387a5fd --- /dev/null +++ b/oai/mcp/validators.py @@ -0,0 +1,123 @@ +""" +Query validation for oAI MCP database operations. 
+ +This module provides safety validation for SQL queries to ensure +only read-only operations are executed. +""" + +import re +from typing import Tuple + +from oai.constants import DANGEROUS_SQL_KEYWORDS + + +class SQLiteQueryValidator: + """ + Validate SQLite queries for read-only safety. + + Ensures that only SELECT queries (including CTEs) are allowed + and blocks potentially dangerous operations like INSERT, UPDATE, + DELETE, DROP, etc. + """ + + @staticmethod + def is_safe_query(query: str) -> Tuple[bool, str]: + """ + Validate that a query is a safe read-only SELECT. + + The validation: + 1. Checks that query starts with SELECT or WITH + 2. Strips string literals before checking for dangerous keywords + 3. Blocks any dangerous keywords outside of string literals + + Args: + query: SQL query string to validate + + Returns: + Tuple of (is_safe, error_message) + - is_safe: True if the query is safe to execute + - error_message: Description of why query is unsafe (empty if safe) + + Examples: + >>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users") + (True, "") + >>> SQLiteQueryValidator.is_safe_query("DELETE FROM users") + (False, "Only SELECT queries are allowed...") + >>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users") + (True, "") # 'DELETE' is inside a string literal + """ + query_upper = query.strip().upper() + + # Must start with SELECT or WITH (for CTEs) + if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")): + return False, "Only SELECT queries are allowed (including WITH/CTE)" + + # Remove string literals before checking for dangerous keywords + # This prevents false positives when keywords appear in data + query_no_strings = re.sub(r"'[^']*'", "", query_upper) + query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings) + + # Check for dangerous keywords outside of quotes + for keyword in DANGEROUS_SQL_KEYWORDS: + if re.search(r"\b" + keyword + r"\b", query_no_strings): + return False, f"Keyword '{keyword}' not allowed in read-only mode" + + return True, "" + + @staticmethod + def sanitize_table_name(table_name: str) -> str: + """ + Sanitize a table name to prevent SQL injection. + + Only allows alphanumeric characters and underscores. + + Args: + table_name: Table name to sanitize + + Returns: + Sanitized table name + + Raises: + ValueError: If table name contains invalid characters + """ + # Remove any characters that aren't alphanumeric or underscore + sanitized = re.sub(r"[^\w]", "", table_name) + + if not sanitized: + raise ValueError("Table name cannot be empty after sanitization") + + if sanitized != table_name: + raise ValueError( + f"Table name contains invalid characters: {table_name}" + ) + + return sanitized + + @staticmethod + def sanitize_column_name(column_name: str) -> str: + """ + Sanitize a column name to prevent SQL injection. + + Only allows alphanumeric characters and underscores. 
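+
+        Examples (doctest-style, mirroring is_safe_query above):
+            >>> SQLiteQueryValidator.sanitize_column_name("user_id")
+            'user_id'
+            >>> SQLiteQueryValidator.sanitize_column_name("user_id; --")
+            Traceback (most recent call last):
+                ...
+            ValueError: Column name contains invalid characters: user_id; --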
+ + Args: + column_name: Column name to sanitize + + Returns: + Sanitized column name + + Raises: + ValueError: If column name contains invalid characters + """ + # Remove any characters that aren't alphanumeric or underscore + sanitized = re.sub(r"[^\w]", "", column_name) + + if not sanitized: + raise ValueError("Column name cannot be empty after sanitization") + + if sanitized != column_name: + raise ValueError( + f"Column name contains invalid characters: {column_name}" + ) + + return sanitized diff --git a/oai/providers/__init__.py b/oai/providers/__init__.py new file mode 100644 index 0000000..93df1e5 --- /dev/null +++ b/oai/providers/__init__.py @@ -0,0 +1,32 @@ +""" +Provider abstraction for oAI. + +This module provides a unified interface for AI model providers, +enabling easy extension to support additional providers beyond OpenRouter. +""" + +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ToolCall, + ToolFunction, + UsageStats, + ModelInfo, + ProviderCapabilities, +) +from oai.providers.openrouter import OpenRouterProvider + +__all__ = [ + # Base classes and types + "AIProvider", + "ChatMessage", + "ChatResponse", + "ToolCall", + "ToolFunction", + "UsageStats", + "ModelInfo", + "ProviderCapabilities", + # Provider implementations + "OpenRouterProvider", +] diff --git a/oai/providers/base.py b/oai/providers/base.py new file mode 100644 index 0000000..865fe7d --- /dev/null +++ b/oai/providers/base.py @@ -0,0 +1,413 @@ +""" +Abstract base provider for AI model integration. + +This module defines the interface that all AI providers must implement, +along with common data structures for requests and responses. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Union, +) + + +class MessageRole(str, Enum): + """Message roles in a conversation.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +@dataclass +class ToolFunction: + """ + Represents a function within a tool call. + + Attributes: + name: The function name + arguments: JSON string of function arguments + """ + + name: str + arguments: str + + +@dataclass +class ToolCall: + """ + Represents a tool/function call requested by the model. + + Attributes: + id: Unique identifier for this tool call + type: Type of tool call (usually "function") + function: The function being called + """ + + id: str + type: str + function: ToolFunction + + +@dataclass +class UsageStats: + """ + Token usage statistics from an API response. + + Attributes: + prompt_tokens: Number of tokens in the prompt + completion_tokens: Number of tokens in the completion + total_tokens: Total tokens used + total_cost_usd: Cost in USD (if available from API) + """ + + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + total_cost_usd: Optional[float] = None + + @property + def input_tokens(self) -> int: + """Alias for prompt_tokens.""" + return self.prompt_tokens + + @property + def output_tokens(self) -> int: + """Alias for completion_tokens.""" + return self.completion_tokens + + +@dataclass +class ChatMessage: + """ + A single message in a chat conversation. 
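+
+    Example (illustrative):
+
+        msg = ChatMessage(role="user", content="Summarize this repository.")
+        payload = msg.to_dict()
+        # payload == {"role": "user", "content": "Summarize this repository."}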
+ + Attributes: + role: The role of the message sender + content: Message content (text or structured content blocks) + name: Optional name for the sender + tool_calls: List of tool calls (for assistant messages) + tool_call_id: Tool call ID this message responds to (for tool messages) + """ + + role: str + content: Union[str, List[Dict[str, Any]], None] = None + name: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = None + tool_call_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format for API requests.""" + result: Dict[str, Any] = {"role": self.role} + + if self.content is not None: + result["content"] = self.content + + if self.name: + result["name"] = self.name + + if self.tool_calls: + result["tool_calls"] = [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in self.tool_calls + ] + + if self.tool_call_id: + result["tool_call_id"] = self.tool_call_id + + return result + + +@dataclass +class ChatResponseChoice: + """ + A single choice in a chat response. + + Attributes: + index: Index of this choice + message: The response message + finish_reason: Why the response ended + """ + + index: int + message: ChatMessage + finish_reason: Optional[str] = None + + +@dataclass +class ChatResponse: + """ + Response from a chat completion request. + + Attributes: + id: Unique identifier for this response + choices: List of response choices + usage: Token usage statistics + model: Model that generated this response + created: Unix timestamp of creation + """ + + id: str + choices: List[ChatResponseChoice] + usage: Optional[UsageStats] = None + model: Optional[str] = None + created: Optional[int] = None + + @property + def message(self) -> Optional[ChatMessage]: + """Get the first choice's message.""" + if self.choices: + return self.choices[0].message + return None + + @property + def content(self) -> Optional[str]: + """Get the text content of the first choice.""" + msg = self.message + if msg and isinstance(msg.content, str): + return msg.content + return None + + @property + def tool_calls(self) -> Optional[List[ToolCall]]: + """Get tool calls from the first choice.""" + msg = self.message + if msg: + return msg.tool_calls + return None + + +@dataclass +class StreamChunk: + """ + A single chunk from a streaming response. + + Attributes: + id: Response ID + delta_content: New content in this chunk + finish_reason: Finish reason (if this is the last chunk) + usage: Usage stats (usually in the last chunk) + error: Error message if something went wrong + """ + + id: str + delta_content: Optional[str] = None + finish_reason: Optional[str] = None + usage: Optional[UsageStats] = None + error: Optional[str] = None + + +@dataclass +class ModelInfo: + """ + Information about an AI model. + + Attributes: + id: Unique model identifier + name: Display name + description: Model description + context_length: Maximum context window size + pricing: Pricing info (input/output per million tokens) + supported_parameters: List of supported API parameters + input_modalities: Supported input types (text, image, etc.) 
+ output_modalities: Supported output types + """ + + id: str + name: str + description: str = "" + context_length: int = 0 + pricing: Dict[str, float] = field(default_factory=dict) + supported_parameters: List[str] = field(default_factory=list) + input_modalities: List[str] = field(default_factory=lambda: ["text"]) + output_modalities: List[str] = field(default_factory=lambda: ["text"]) + + def supports_images(self) -> bool: + """Check if model supports image input.""" + return "image" in self.input_modalities + + def supports_tools(self) -> bool: + """Check if model supports function calling/tools.""" + return "tools" in self.supported_parameters or "functions" in self.supported_parameters + + def supports_streaming(self) -> bool: + """Check if model supports streaming responses.""" + return "stream" in self.supported_parameters + + def supports_online(self) -> bool: + """Check if model supports web search (online mode).""" + return self.supports_tools() + + +@dataclass +class ProviderCapabilities: + """ + Capabilities supported by a provider. + + Attributes: + streaming: Provider supports streaming responses + tools: Provider supports function calling + images: Provider supports image inputs + online: Provider supports web search + max_context: Maximum context length across all models + """ + + streaming: bool = True + tools: bool = True + images: bool = True + online: bool = False + max_context: int = 128000 + + +class AIProvider(ABC): + """ + Abstract base class for AI model providers. + + All provider implementations must inherit from this class + and implement the required abstract methods. + """ + + def __init__(self, api_key: str, base_url: Optional[str] = None): + """ + Initialize the provider. + + Args: + api_key: API key for authentication + base_url: Optional custom base URL for the API + """ + self.api_key = api_key + self.base_url = base_url + + @property + @abstractmethod + def name(self) -> str: + """Get the provider name.""" + pass + + @property + @abstractmethod + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + pass + + @abstractmethod + def list_models(self) -> List[ModelInfo]: + """ + Fetch available models from the provider. + + Returns: + List of available models with their info + """ + pass + + @abstractmethod + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: The model identifier + + Returns: + Model information or None if not found + """ + pass + + @abstractmethod + def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat completion request. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection ("auto", "none", etc.) 
+ **kwargs: Additional provider-specific parameters + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + """ + pass + + @abstractmethod + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Send an async chat completion request. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection + **kwargs: Additional provider-specific parameters + + Returns: + ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming + """ + pass + + @abstractmethod + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get account credit/balance information. + + Returns: + Dict with credit info or None if not supported + """ + pass + + def validate_api_key(self) -> bool: + """ + Validate that the API key is valid. + + Returns: + True if API key is valid + """ + try: + self.list_models() + return True + except Exception: + return False diff --git a/oai/providers/openrouter.py b/oai/providers/openrouter.py new file mode 100644 index 0000000..075fb05 --- /dev/null +++ b/oai/providers/openrouter.py @@ -0,0 +1,623 @@ +""" +OpenRouter provider implementation. + +This module implements the AIProvider interface for OpenRouter, +supporting chat completions, streaming, and function calling. +""" + +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +import requests +from openrouter import OpenRouter + +from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ChatResponseChoice, + ModelInfo, + ProviderCapabilities, + StreamChunk, + ToolCall, + ToolFunction, + UsageStats, +) +from oai.utils.logging import get_logger + + +class OpenRouterProvider(AIProvider): + """ + OpenRouter API provider implementation. + + Provides access to multiple AI models through OpenRouter's unified API, + supporting chat completions, streaming responses, and function calling. + + Attributes: + client: The underlying OpenRouter client + _models_cache: Cached list of available models + """ + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + app_name: str = APP_NAME, + app_url: str = APP_URL, + ): + """ + Initialize the OpenRouter provider. 
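+
+        Example (illustrative; the API key below is a placeholder):
+
+            from oai.providers import OpenRouterProvider
+
+            provider = OpenRouterProvider(api_key="sk-or-...")  # placeholder key
+            models = provider.list_models()
+            print(len(models), "models available")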
+ + Args: + api_key: OpenRouter API key + base_url: Optional custom base URL + app_name: Application name for API headers + app_url: Application URL for API headers + """ + super().__init__(api_key, base_url or DEFAULT_BASE_URL) + self.app_name = app_name + self.app_url = app_url + self.client = OpenRouter(api_key=api_key) + self._models_cache: Optional[List[ModelInfo]] = None + self._raw_models_cache: Optional[List[Dict[str, Any]]] = None + + self.logger = get_logger() + self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}") + + @property + def name(self) -> str: + """Get the provider name.""" + return "OpenRouter" + + @property + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + return ProviderCapabilities( + streaming=True, + tools=True, + images=True, + online=True, + max_context=2000000, # Claude models support up to 200k + ) + + def _get_headers(self) -> Dict[str, str]: + """Get standard HTTP headers for API requests.""" + headers = { + "HTTP-Referer": self.app_url, + "X-Title": self.app_name, + } + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + return headers + + def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo: + """ + Parse raw model data into ModelInfo. + + Args: + model_data: Raw model data from API + + Returns: + Parsed ModelInfo object + """ + architecture = model_data.get("architecture", {}) + pricing_data = model_data.get("pricing", {}) + + # Parse pricing (convert from string to float if needed) + pricing = {} + for key in ["prompt", "completion"]: + value = pricing_data.get(key) + if value is not None: + try: + # Convert from per-token to per-million-tokens + pricing[key] = float(value) * 1_000_000 + except (ValueError, TypeError): + pricing[key] = 0.0 + + return ModelInfo( + id=model_data.get("id", ""), + name=model_data.get("name", model_data.get("id", "")), + description=model_data.get("description", ""), + context_length=model_data.get("context_length", 0), + pricing=pricing, + supported_parameters=model_data.get("supported_parameters", []), + input_modalities=architecture.get("input_modalities", ["text"]), + output_modalities=architecture.get("output_modalities", ["text"]), + ) + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + Fetch available models from OpenRouter. + + Args: + filter_text_only: If True, exclude video-only models + + Returns: + List of available models + + Raises: + Exception: If API request fails + """ + if self._models_cache is not None: + return self._models_cache + + try: + response = requests.get( + f"{self.base_url}/models", + headers=self._get_headers(), + timeout=10, + ) + response.raise_for_status() + + raw_models = response.json().get("data", []) + self._raw_models_cache = raw_models + + models = [] + for model_data in raw_models: + # Optionally filter out video-only models + if filter_text_only: + modalities = model_data.get("modalities", []) + if modalities and "video" in modalities and "text" not in modalities: + continue + + models.append(self._parse_model(model_data)) + + self._models_cache = models + self.logger.info(f"Fetched {len(models)} models from OpenRouter") + return models + + except requests.RequestException as e: + self.logger.error(f"Failed to fetch models: {e}") + raise + + def get_raw_models(self) -> List[Dict[str, Any]]: + """ + Get raw model data as returned by the API. + + Useful for accessing provider-specific fields not in ModelInfo. 
+ + Returns: + List of raw model dictionaries + """ + if self._raw_models_cache is None: + self.list_models() + return self._raw_models_cache or [] + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: The model identifier + + Returns: + Model information or None if not found + """ + models = self.list_models() + for model in models: + if model.id == model_id: + return model + return None + + def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]: + """ + Get raw model data for a specific model. + + Args: + model_id: The model identifier + + Returns: + Raw model dictionary or None if not found + """ + raw_models = self.get_raw_models() + for model in raw_models: + if model.get("id") == model_id: + return model + return None + + def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: + """ + Convert ChatMessage objects to API format. + + Args: + messages: List of ChatMessage objects + + Returns: + List of message dictionaries for the API + """ + return [msg.to_dict() for msg in messages] + + def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]: + """ + Parse usage data from API response. + + Args: + usage_data: Raw usage data from API + + Returns: + Parsed UsageStats or None + """ + if not usage_data: + return None + + # Handle both attribute and dict access + prompt_tokens = 0 + completion_tokens = 0 + total_cost = None + + if hasattr(usage_data, "prompt_tokens"): + prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0 + elif isinstance(usage_data, dict): + prompt_tokens = usage_data.get("prompt_tokens", 0) or 0 + + if hasattr(usage_data, "completion_tokens"): + completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0 + elif isinstance(usage_data, dict): + completion_tokens = usage_data.get("completion_tokens", 0) or 0 + + # Try alternative naming (input_tokens/output_tokens) + if prompt_tokens == 0: + if hasattr(usage_data, "input_tokens"): + prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0 + elif isinstance(usage_data, dict): + prompt_tokens = usage_data.get("input_tokens", 0) or 0 + + if completion_tokens == 0: + if hasattr(usage_data, "output_tokens"): + completion_tokens = getattr(usage_data, "output_tokens", 0) or 0 + elif isinstance(usage_data, dict): + completion_tokens = usage_data.get("output_tokens", 0) or 0 + + # Get cost if available + if hasattr(usage_data, "total_cost_usd"): + total_cost = getattr(usage_data, "total_cost_usd", None) + elif isinstance(usage_data, dict): + total_cost = usage_data.get("total_cost_usd") + + return UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + total_cost_usd=float(total_cost) if total_cost else None, + ) + + def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]: + """ + Parse tool calls from API response. 
+ + Args: + tool_calls_data: Raw tool calls data + + Returns: + List of ToolCall objects or None + """ + if not tool_calls_data: + return None + + tool_calls = [] + for tc in tool_calls_data: + # Handle both attribute and dict access + if hasattr(tc, "id"): + tc_id = tc.id + tc_type = getattr(tc, "type", "function") + func = tc.function + func_name = func.name + func_args = func.arguments + else: + tc_id = tc.get("id", "") + tc_type = tc.get("type", "function") + func = tc.get("function", {}) + func_name = func.get("name", "") + func_args = func.get("arguments", "{}") + + tool_calls.append( + ToolCall( + id=tc_id, + type=tc_type, + function=ToolFunction(name=func_name, arguments=func_args), + ) + ) + + return tool_calls if tool_calls else None + + def _parse_response(self, response: Any) -> ChatResponse: + """ + Parse API response into ChatResponse. + + Args: + response: Raw API response + + Returns: + Parsed ChatResponse + """ + choices = [] + for choice in response.choices: + msg = choice.message + message = ChatMessage( + role=msg.role if hasattr(msg, "role") else "assistant", + content=msg.content if hasattr(msg, "content") else None, + tool_calls=self._parse_tool_calls( + getattr(msg, "tool_calls", None) + ), + ) + choices.append( + ChatResponseChoice( + index=choice.index if hasattr(choice, "index") else 0, + message=message, + finish_reason=getattr(choice, "finish_reason", None), + ) + ) + + return ChatResponse( + id=response.id if hasattr(response, "id") else "", + choices=choices, + usage=self._parse_usage(getattr(response, "usage", None)), + model=getattr(response, "model", None), + created=getattr(response, "created", None), + ) + + def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + transforms: Optional[List[str]] = None, + **kwargs: Any, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat completion request to OpenRouter. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature (0-2) + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection ("auto", "none", etc.) 
+ transforms: List of transforms (e.g., ["middle-out"]) + **kwargs: Additional parameters + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + """ + # Build request parameters + params: Dict[str, Any] = { + "model": model, + "messages": self._convert_messages(messages), + "stream": stream, + "http_headers": self._get_headers(), + } + + # Request usage stats in streaming responses + if stream: + params["stream_options"] = {"include_usage": True} + + if max_tokens is not None: + params["max_tokens"] = max_tokens + + if temperature is not None: + params["temperature"] = temperature + + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice or "auto" + + if transforms: + params["transforms"] = transforms + + # Add any additional parameters + params.update(kwargs) + + self.logger.debug(f"Sending chat request to model {model}") + + try: + response = self.client.chat.send(**params) + + if stream: + return self._stream_response(response) + else: + return self._parse_response(response) + + except Exception as e: + self.logger.error(f"Chat request failed: {e}") + raise + + def _stream_response(self, response: Any) -> Iterator[StreamChunk]: + """ + Process a streaming response. + + Args: + response: Streaming response from API + + Yields: + StreamChunk objects + """ + last_usage = None + + try: + for chunk in response: + # Check for errors + if hasattr(chunk, "error") and chunk.error: + yield StreamChunk( + id=getattr(chunk, "id", ""), + error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error), + ) + return + + # Extract delta content + delta_content = None + finish_reason = None + + if hasattr(chunk, "choices") and chunk.choices: + choice = chunk.choices[0] + if hasattr(choice, "delta"): + delta = choice.delta + if hasattr(delta, "content") and delta.content: + delta_content = delta.content + finish_reason = getattr(choice, "finish_reason", None) + + # Track usage from last chunk + if hasattr(chunk, "usage") and chunk.usage: + last_usage = self._parse_usage(chunk.usage) + + yield StreamChunk( + id=getattr(chunk, "id", ""), + delta_content=delta_content, + finish_reason=finish_reason, + usage=last_usage if finish_reason else None, + ) + + except Exception as e: + self.logger.error(f"Stream error: {e}") + yield StreamChunk(id="", error=str(e)) + + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Send an async chat completion request. + + Note: Currently wraps the sync implementation. + TODO: Implement true async support when OpenRouter SDK supports it. 
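+
+        Example (illustrative; "provider" is an initialized OpenRouterProvider
+        and the chosen model supports streaming):
+
+            stream = await provider.chat_async(
+                model="openai/gpt-4o",  # illustrative model id
+                messages=[ChatMessage(role="user", content="Hello")],
+                stream=True,
+            )
+            async for chunk in stream:
+                if chunk.delta_content:
+                    print(chunk.delta_content, end="")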
+ + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions + tool_choice: Tool selection mode + **kwargs: Additional parameters + + Returns: + ChatResponse for non-streaming, AsyncIterator for streaming + """ + # For now, use sync implementation + # TODO: Add true async when SDK supports it + result = self.chat( + model=model, + messages=messages, + stream=stream, + max_tokens=max_tokens, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + **kwargs, + ) + + if stream and isinstance(result, Iterator): + # Convert sync iterator to async + async def async_iter() -> AsyncIterator[StreamChunk]: + for chunk in result: + yield chunk + + return async_iter() + + return result + + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get OpenRouter account credit information. + + Returns: + Dict with credit info: + - total_credits: Total credits purchased + - used_credits: Credits used + - credits_left: Remaining credits + + Raises: + Exception: If API request fails + """ + if not self.api_key: + return None + + try: + response = requests.get( + f"{self.base_url}/credits", + headers=self._get_headers(), + timeout=10, + ) + response.raise_for_status() + + data = response.json().get("data", {}) + total_credits = float(data.get("total_credits", 0)) + total_usage = float(data.get("total_usage", 0)) + credits_left = total_credits - total_usage + + return { + "total_credits": total_credits, + "used_credits": total_usage, + "credits_left": credits_left, + "total_credits_formatted": f"${total_credits:.2f}", + "used_credits_formatted": f"${total_usage:.2f}", + "credits_left_formatted": f"${credits_left:.2f}", + } + + except Exception as e: + self.logger.error(f"Failed to fetch credits: {e}") + return None + + def clear_cache(self) -> None: + """Clear the models cache to force a refresh.""" + self._models_cache = None + self._raw_models_cache = None + self.logger.debug("Models cache cleared") + + def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str: + """ + Get the effective model ID with online suffix if needed. + + Args: + model_id: Base model ID + online_enabled: Whether online mode is enabled + + Returns: + Model ID with :online suffix if applicable + """ + if online_enabled and not model_id.endswith(":online"): + return f"{model_id}:online" + return model_id + + def estimate_cost( + self, + model_id: str, + input_tokens: int, + output_tokens: int, + ) -> float: + """ + Estimate the cost for a completion. 
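+
+        Example (illustrative numbers): for a model priced at $3.00 per million
+        input tokens and $15.00 per million output tokens, a call that used
+        1,200 input and 400 output tokens costs roughly
+        3.00 * 1200 / 1_000_000 + 15.00 * 400 / 1_000_000 = $0.0096.
+
+            cost = provider.estimate_cost("vendor/model-id", 1200, 400)  # placeholder id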
+ + Args: + model_id: Model ID + input_tokens: Number of input tokens + output_tokens: Number of output tokens + + Returns: + Estimated cost in USD + """ + model = self.get_model(model_id) + if model and model.pricing: + input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000 + output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000 + return input_cost + output_cost + + # Fallback to default pricing if model not found + from oai.constants import MODEL_PRICING + + input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000 + output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000 + return input_cost + output_cost diff --git a/oai/py.typed b/oai/py.typed new file mode 100644 index 0000000..cb5707d --- /dev/null +++ b/oai/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561 +# This package supports type checking diff --git a/oai/ui/__init__.py b/oai/ui/__init__.py new file mode 100644 index 0000000..9f4be60 --- /dev/null +++ b/oai/ui/__init__.py @@ -0,0 +1,51 @@ +""" +UI utilities for oAI. + +This module provides rich terminal UI components and display helpers +for the chat application. +""" + +from oai.ui.console import ( + console, + clear_screen, + display_panel, + display_table, + display_markdown, + print_error, + print_warning, + print_success, + print_info, +) +from oai.ui.tables import ( + create_model_table, + create_stats_table, + create_help_table, + display_paginated_table, +) +from oai.ui.prompts import ( + prompt_confirm, + prompt_choice, + prompt_input, +) + +__all__ = [ + # Console utilities + "console", + "clear_screen", + "display_panel", + "display_table", + "display_markdown", + "print_error", + "print_warning", + "print_success", + "print_info", + # Table utilities + "create_model_table", + "create_stats_table", + "create_help_table", + "display_paginated_table", + # Prompt utilities + "prompt_confirm", + "prompt_choice", + "prompt_input", +] diff --git a/oai/ui/console.py b/oai/ui/console.py new file mode 100644 index 0000000..6836b7d --- /dev/null +++ b/oai/ui/console.py @@ -0,0 +1,242 @@ +""" +Console utilities for oAI. + +This module provides the Rich console instance and common display functions +for formatted terminal output. +""" + +from typing import Any, Optional + +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +# Global console instance for the application +console = Console() + + +def clear_screen() -> None: + """ + Clear the terminal screen. + + Uses ANSI escape codes for fast clearing, with a fallback + for terminals that don't support them. + """ + try: + print("\033[H\033[J", end="", flush=True) + except Exception: + # Fallback: print many newlines + print("\n" * 100) + + +def display_panel( + content: Any, + title: Optional[str] = None, + subtitle: Optional[str] = None, + border_style: str = "green", + title_align: str = "left", + subtitle_align: str = "right", +) -> None: + """ + Display content in a bordered panel. 
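+
+    Example (illustrative strings):
+
+        # Renders a green-bordered panel on the shared console
+        display_panel("Settings saved", title="Config", subtitle="oAI")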
+ + Args: + content: Content to display (string, Table, or Markdown) + title: Optional panel title + subtitle: Optional panel subtitle + border_style: Border color/style + title_align: Title alignment ("left", "center", "right") + subtitle_align: Subtitle alignment + """ + panel = Panel( + content, + title=title, + subtitle=subtitle, + border_style=border_style, + title_align=title_align, + subtitle_align=subtitle_align, + ) + console.print(panel) + + +def display_table( + table: Table, + title: Optional[str] = None, + subtitle: Optional[str] = None, +) -> None: + """ + Display a table with optional title panel. + + Args: + table: Rich Table to display + title: Optional panel title + subtitle: Optional panel subtitle + """ + if title: + display_panel(table, title=title, subtitle=subtitle) + else: + console.print(table) + + +def display_markdown( + content: str, + panel: bool = False, + title: Optional[str] = None, +) -> None: + """ + Display markdown-formatted content. + + Args: + content: Markdown text to display + panel: Whether to wrap in a panel + title: Optional panel title (if panel=True) + """ + md = Markdown(content) + if panel: + display_panel(md, title=title) + else: + console.print(md) + + +def print_error(message: str, prefix: str = "Error:") -> None: + """ + Print an error message in red. + + Args: + message: Error message to display + prefix: Prefix before the message (default: "Error:") + """ + console.print(f"[bold red]{prefix}[/] {message}") + + +def print_warning(message: str, prefix: str = "Warning:") -> None: + """ + Print a warning message in yellow. + + Args: + message: Warning message to display + prefix: Prefix before the message (default: "Warning:") + """ + console.print(f"[bold yellow]{prefix}[/] {message}") + + +def print_success(message: str, prefix: str = "✓") -> None: + """ + Print a success message in green. + + Args: + message: Success message to display + prefix: Prefix before the message (default: "✓") + """ + console.print(f"[bold green]{prefix}[/] {message}") + + +def print_info(message: str, dim: bool = False) -> None: + """ + Print an informational message in cyan. + + Args: + message: Info message to display + dim: Whether to dim the message + """ + if dim: + console.print(f"[dim cyan]{message}[/]") + else: + console.print(f"[bold cyan]{message}[/]") + + +def print_metrics( + tokens: int, + cost: float, + time_seconds: float, + context_info: str = "", + online: bool = False, + mcp_mode: Optional[str] = None, + tool_loops: int = 0, + session_tokens: int = 0, + session_cost: float = 0.0, +) -> None: + """ + Print formatted metrics for a response. 
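+
+    Example (illustrative values):
+
+        # Prints a dim blue line such as:
+        # 📊 Metrics: 1234 tokens | $0.0042 | 3.10s | 🌐 | Session: 5678 tokens | $0.0210
+        print_metrics(tokens=1234, cost=0.0042, time_seconds=3.1,
+                      online=True, session_tokens=5678, session_cost=0.021)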
+ + Args: + tokens: Total tokens used + cost: Cost in USD + time_seconds: Response time + context_info: Context information string + online: Whether online mode is active + mcp_mode: MCP mode ("files", "database", or None) + tool_loops: Number of tool call loops + session_tokens: Total session tokens + session_cost: Total session cost + """ + parts = [ + f"📊 Metrics: {tokens} tokens", + f"${cost:.4f}", + f"{time_seconds:.2f}s", + ] + + if context_info: + parts.append(context_info) + + if online: + parts.append("🌐") + + if mcp_mode == "files": + parts.append("🔧") + elif mcp_mode == "database": + parts.append("🗄️") + + if tool_loops > 0: + parts.append(f"({tool_loops} tool loop(s))") + + parts.append(f"Session: {session_tokens} tokens") + parts.append(f"${session_cost:.4f}") + + console.print(f"\n[dim blue]{' | '.join(parts)}[/]") + + +def format_size(size_bytes: int) -> str: + """ + Format a size in bytes to a human-readable string. + + Args: + size_bytes: Size in bytes + + Returns: + Formatted size string (e.g., "1.5 MB") + """ + for unit in ["B", "KB", "MB", "GB", "TB"]: + if abs(size_bytes) < 1024.0: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.1f} PB" + + +def format_tokens(tokens: int) -> str: + """ + Format token count with thousands separators. + + Args: + tokens: Number of tokens + + Returns: + Formatted token string (e.g., "1,234,567") + """ + return f"{tokens:,}" + + +def format_cost(cost: float, precision: int = 4) -> str: + """ + Format cost in USD. + + Args: + cost: Cost in dollars + precision: Decimal places + + Returns: + Formatted cost string (e.g., "$0.0123") + """ + return f"${cost:.{precision}f}" diff --git a/oai/ui/prompts.py b/oai/ui/prompts.py new file mode 100644 index 0000000..654a43e --- /dev/null +++ b/oai/ui/prompts.py @@ -0,0 +1,274 @@ +""" +Prompt utilities for oAI. + +This module provides functions for gathering user input +through confirmations, choices, and text prompts. +""" + +from typing import List, Optional, TypeVar + +import typer + +from oai.ui.console import console + +T = TypeVar("T") + + +def prompt_confirm( + message: str, + default: bool = False, + abort: bool = False, +) -> bool: + """ + Prompt the user for a yes/no confirmation. + + Args: + message: The question to ask + default: Default value if user presses Enter + abort: Whether to abort on "no" response + + Returns: + True if user confirms, False otherwise + """ + try: + return typer.confirm(message, default=default, abort=abort) + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return False + + +def prompt_choice( + message: str, + choices: List[str], + default: Optional[str] = None, +) -> Optional[str]: + """ + Prompt the user to select from a list of choices. + + Args: + message: The question to ask + choices: List of valid choices + default: Default choice if user presses Enter + + Returns: + Selected choice or None if cancelled + """ + # Display choices + console.print(f"\n[bold cyan]{message}[/]") + for i, choice in enumerate(choices, 1): + default_marker = " [default]" if choice == default else "" + console.print(f" {i}. 
{choice}{default_marker}") + + try: + response = input("\nEnter number or value: ").strip() + + if not response and default: + return default + + # Try as number first + try: + index = int(response) - 1 + if 0 <= index < len(choices): + return choices[index] + except ValueError: + pass + + # Try as exact match + if response in choices: + return response + + # Try case-insensitive match + response_lower = response.lower() + for choice in choices: + if choice.lower() == response_lower: + return choice + + console.print(f"[red]Invalid choice: {response}[/]") + return None + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_input( + message: str, + default: Optional[str] = None, + password: bool = False, + required: bool = False, +) -> Optional[str]: + """ + Prompt the user for text input. + + Args: + message: The prompt message + default: Default value if user presses Enter + password: Whether to hide input (for sensitive data) + required: Whether input is required (loops until provided) + + Returns: + User input or default, None if cancelled + """ + prompt_text = message + if default: + prompt_text += f" [{default}]" + prompt_text += ": " + + try: + while True: + if password: + import getpass + + response = getpass.getpass(prompt_text) + else: + response = input(prompt_text).strip() + + if not response: + if default: + return default + if required: + console.print("[yellow]Input required[/]") + continue + return None + + return response + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_number( + message: str, + min_value: Optional[int] = None, + max_value: Optional[int] = None, + default: Optional[int] = None, +) -> Optional[int]: + """ + Prompt the user for a numeric input. + + Args: + message: The prompt message + min_value: Minimum allowed value + max_value: Maximum allowed value + default: Default value if user presses Enter + + Returns: + Integer value or None if cancelled + """ + prompt_text = message + if default is not None: + prompt_text += f" [{default}]" + prompt_text += ": " + + try: + while True: + response = input(prompt_text).strip() + + if not response: + if default is not None: + return default + return None + + try: + value = int(response) + except ValueError: + console.print("[red]Please enter a valid number[/]") + continue + + if min_value is not None and value < min_value: + console.print(f"[red]Value must be at least {min_value}[/]") + continue + + if max_value is not None and value > max_value: + console.print(f"[red]Value must be at most {max_value}[/]") + continue + + return value + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_selection( + items: List[T], + message: str = "Select an item", + display_func: Optional[callable] = None, + allow_cancel: bool = True, +) -> Optional[T]: + """ + Prompt the user to select an item from a list. + + Args: + items: List of items to choose from + message: The selection prompt + display_func: Function to convert item to display string + allow_cancel: Whether to allow cancellation + + Returns: + Selected item or None if cancelled + """ + if not items: + console.print("[yellow]No items to select[/]") + return None + + display = display_func or str + + console.print(f"\n[bold cyan]{message}[/]") + for i, item in enumerate(items, 1): + console.print(f" {i}. {display(item)}") + + if allow_cancel: + console.print(f" 0. 
Cancel") + + try: + while True: + response = input("\nEnter number: ").strip() + + try: + index = int(response) + except ValueError: + console.print("[red]Please enter a valid number[/]") + continue + + if allow_cancel and index == 0: + return None + + if 1 <= index <= len(items): + return items[index - 1] + + console.print(f"[red]Please enter a number between 1 and {len(items)}[/]") + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_copy_response(response: str) -> bool: + """ + Prompt user to copy a response to clipboard. + + Args: + response: The response text + + Returns: + True if copied, False otherwise + """ + try: + copy_choice = input("💾 Type 'c' to copy response, or press Enter to continue: ").strip().lower() + if copy_choice == "c": + try: + import pyperclip + + pyperclip.copy(response) + console.print("[bold green]✅ Response copied to clipboard![/]") + return True + except ImportError: + console.print("[yellow]pyperclip not installed - cannot copy to clipboard[/]") + except Exception as e: + console.print(f"[red]Failed to copy: {e}[/]") + except (EOFError, KeyboardInterrupt): + pass + + return False diff --git a/oai/ui/tables.py b/oai/ui/tables.py new file mode 100644 index 0000000..49a1c2e --- /dev/null +++ b/oai/ui/tables.py @@ -0,0 +1,373 @@ +""" +Table utilities for oAI. + +This module provides functions for creating and displaying +formatted tables with pagination support. +""" + +import os +import sys +from typing import Any, Dict, List, Optional + +from rich.panel import Panel +from rich.table import Table + +from oai.ui.console import clear_screen, console + + +def create_model_table( + models: List[Dict[str, Any]], + show_capabilities: bool = True, +) -> Table: + """ + Create a table displaying available AI models. + + Args: + models: List of model dictionaries + show_capabilities: Whether to show capability columns + + Returns: + Rich Table with model information + """ + if show_capabilities: + table = Table( + "No.", + "Model ID", + "Context", + "Image", + "Online", + "Tools", + show_header=True, + header_style="bold magenta", + ) + else: + table = Table( + "No.", + "Model ID", + "Context", + show_header=True, + header_style="bold magenta", + ) + + for i, model in enumerate(models, 1): + model_id = model.get("id", "Unknown") + context = model.get("context_length", 0) + context_str = f"{context:,}" if context else "-" + + if show_capabilities: + # Get modalities and parameters + architecture = model.get("architecture", {}) + input_modalities = architecture.get("input_modalities", []) + supported_params = model.get("supported_parameters", []) + + has_image = "✓" if "image" in input_modalities else "-" + has_online = "✓" if "tools" in supported_params else "-" + has_tools = "✓" if "tools" in supported_params or "functions" in supported_params else "-" + + table.add_row( + str(i), + model_id, + context_str, + has_image, + has_online, + has_tools, + ) + else: + table.add_row(str(i), model_id, context_str) + + return table + + +def create_stats_table(stats: Dict[str, Any]) -> Table: + """ + Create a table displaying session statistics. 
+ + Args: + stats: Dictionary with statistics data + + Returns: + Rich Table with stats + """ + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + # Token stats + if "input_tokens" in stats: + table.add_row("Input Tokens", f"{stats['input_tokens']:,}") + if "output_tokens" in stats: + table.add_row("Output Tokens", f"{stats['output_tokens']:,}") + if "total_tokens" in stats: + table.add_row("Total Tokens", f"{stats['total_tokens']:,}") + + # Cost stats + if "total_cost" in stats: + table.add_row("Total Cost", f"${stats['total_cost']:.4f}") + if "avg_cost" in stats: + table.add_row("Avg Cost/Message", f"${stats['avg_cost']:.4f}") + + # Message stats + if "message_count" in stats: + table.add_row("Messages", str(stats["message_count"])) + + # Credits + if "credits_left" in stats: + table.add_row("Credits Left", stats["credits_left"]) + + return table + + +def create_help_table(commands: Dict[str, Dict[str, str]]) -> Table: + """ + Create a help table for commands. + + Args: + commands: Dictionary of command info + + Returns: + Rich Table with command help + """ + table = Table( + "Command", + "Description", + "Example", + show_header=True, + header_style="bold magenta", + show_lines=False, + ) + + for cmd, info in commands.items(): + description = info.get("description", "") + example = info.get("example", "") + table.add_row(cmd, description, example) + + return table + + +def create_folder_table( + folders: List[Dict[str, Any]], + gitignore_info: str = "", +) -> Table: + """ + Create a table for MCP folder listing. + + Args: + folders: List of folder dictionaries + gitignore_info: Optional gitignore status info + + Returns: + Rich Table with folder information + """ + table = Table( + "No.", + "Path", + "Files", + "Size", + show_header=True, + header_style="bold magenta", + ) + + for folder in folders: + number = str(folder.get("number", "")) + path = folder.get("path", "") + + if folder.get("exists", True): + files = f"📁 {folder.get('file_count', 0)}" + size = f"{folder.get('size_mb', 0):.1f} MB" + else: + files = "[red]Not found[/red]" + size = "-" + + table.add_row(number, path, files, size) + + return table + + +def create_database_table(databases: List[Dict[str, Any]]) -> Table: + """ + Create a table for MCP database listing. + + Args: + databases: List of database dictionaries + + Returns: + Rich Table with database information + """ + table = Table( + "No.", + "Name", + "Tables", + "Size", + "Status", + show_header=True, + header_style="bold magenta", + ) + + for db in databases: + number = str(db.get("number", "")) + name = db.get("name", "") + table_count = f"{db.get('table_count', 0)} tables" + size = f"{db.get('size_mb', 0):.1f} MB" + + if db.get("warning"): + status = f"[red]{db['warning']}[/red]" + else: + status = "[green]✓[/green]" + + table.add_row(number, name, table_count, size, status) + + return table + + +def display_paginated_table( + table: Table, + title: str, + terminal_height: Optional[int] = None, +) -> None: + """ + Display a table with pagination for large datasets. + + Allows navigating through pages with keyboard input. + Press SPACE for next page, any other key to exit. 
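+
+    Example (illustrative; ``models`` is a list of model dictionaries):
+
+        table = create_model_table(models)
+        display_paginated_table(table, title="Available Models")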
+ + Args: + table: Rich Table to display + title: Title for the table + terminal_height: Override terminal height (auto-detected if None) + """ + # Get terminal dimensions + try: + term_height = terminal_height or os.get_terminal_size().lines - 8 + except OSError: + term_height = 20 + + # Render table to segments + from rich.segment import Segment + + segments = list(console.render(table)) + + # Group segments into lines + current_line_segments: List[Segment] = [] + all_lines: List[List[Segment]] = [] + + for segment in segments: + if segment.text == "\n": + all_lines.append(current_line_segments) + current_line_segments = [] + else: + current_line_segments.append(segment) + + if current_line_segments: + all_lines.append(current_line_segments) + + total_lines = len(all_lines) + + # If table fits in one screen, just display it + if total_lines <= term_height: + console.print(Panel(table, title=title, title_align="left")) + return + + # Extract header and footer lines + header_lines: List[List[Segment]] = [] + data_lines: List[List[Segment]] = [] + footer_line: List[Segment] = [] + + # Find header end (line after the header text with border) + header_end_index = 0 + found_header_text = False + + for i, line_segments in enumerate(all_lines): + has_header_style = any( + seg.style and ("bold" in str(seg.style) or "magenta" in str(seg.style)) + for seg in line_segments + ) + + if has_header_style: + found_header_text = True + + if found_header_text and i > 0: + line_text = "".join(seg.text for seg in line_segments) + if any(char in line_text for char in ["─", "━", "┼", "╪", "┤", "├"]): + header_end_index = i + break + + # Extract footer (bottom border) + if all_lines: + last_line_text = "".join(seg.text for seg in all_lines[-1]) + if any(char in last_line_text for char in ["─", "━", "┴", "╧", "┘", "└"]): + footer_line = all_lines[-1] + all_lines = all_lines[:-1] + + # Split into header and data + if header_end_index > 0: + header_lines = all_lines[: header_end_index + 1] + data_lines = all_lines[header_end_index + 1 :] + else: + header_lines = all_lines[: min(3, len(all_lines))] + data_lines = all_lines[min(3, len(all_lines)) :] + + lines_per_page = term_height - len(header_lines) + current_line = 0 + page_number = 1 + + # Paginate + while current_line < len(data_lines): + clear_screen() + console.print(f"[bold cyan]{title} (Page {page_number})[/]") + + # Print header + for line_segments in header_lines: + for segment in line_segments: + console.print(segment.text, style=segment.style, end="") + console.print() + + # Print data rows for this page + end_line = min(current_line + lines_per_page, len(data_lines)) + for line_segments in data_lines[current_line:end_line]: + for segment in line_segments: + console.print(segment.text, style=segment.style, end="") + console.print() + + # Print footer + if footer_line: + for segment in footer_line: + console.print(segment.text, style=segment.style, end="") + console.print() + + current_line = end_line + page_number += 1 + + # Prompt for next page + if current_line < len(data_lines): + console.print( + f"\n[dim yellow]--- Press SPACE for next page, " + f"or any other key to finish (Page {page_number - 1}, " + f"showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]" + ) + + try: + import termios + import tty + + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + + try: + tty.setraw(fd) + char = sys.stdin.read(1) + if char != " ": + break + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + except (ImportError, 
OSError, AttributeError): + # Fallback for non-Unix systems + try: + user_input = input() + if user_input.strip(): + break + except (EOFError, KeyboardInterrupt): + break diff --git a/oai/utils/__init__.py b/oai/utils/__init__.py new file mode 100644 index 0000000..6048bfe --- /dev/null +++ b/oai/utils/__init__.py @@ -0,0 +1,20 @@ +""" +Utility modules for oAI. + +This package provides common utilities used throughout the application +including logging, file handling, and export functionality. +""" + +from oai.utils.logging import setup_logging, get_logger +from oai.utils.files import read_file_safe, is_binary_file +from oai.utils.export import export_as_markdown, export_as_json, export_as_html + +__all__ = [ + "setup_logging", + "get_logger", + "read_file_safe", + "is_binary_file", + "export_as_markdown", + "export_as_json", + "export_as_html", +] diff --git a/oai/utils/export.py b/oai/utils/export.py new file mode 100644 index 0000000..c476ef9 --- /dev/null +++ b/oai/utils/export.py @@ -0,0 +1,248 @@ +""" +Export utilities for oAI. + +This module provides functions for exporting conversation history +in various formats including Markdown, JSON, and HTML. +""" + +import json +import datetime +from typing import List, Dict +from html import escape as html_escape + +from oai.constants import APP_VERSION, APP_URL + + +def export_as_markdown( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as Markdown. + + Args: + session_history: List of message dictionaries with 'prompt' and 'response' + session_system_prompt: Optional system prompt to include + + Returns: + Markdown formatted string + """ + lines = ["# Conversation Export", ""] + + if session_system_prompt: + lines.extend([f"**System Prompt:** {session_system_prompt}", ""]) + + lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + lines.append(f"**Messages:** {len(session_history)}") + lines.append("") + lines.append("---") + lines.append("") + + for i, entry in enumerate(session_history, 1): + lines.append(f"## Message {i}") + lines.append("") + lines.append("**User:**") + lines.append("") + lines.append(entry.get("prompt", "")) + lines.append("") + lines.append("**Assistant:**") + lines.append("") + lines.append(entry.get("response", "")) + lines.append("") + lines.append("---") + lines.append("") + + lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*") + + return "\n".join(lines) + + +def export_as_json( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as JSON. 
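+
+    Example (illustrative history entry):
+
+        history = [{"prompt": "Hi", "response": "Hello!", "msg_cost": 0.0001}]
+        print(export_as_json(history, session_system_prompt="Be concise"))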
+ + Args: + session_history: List of message dictionaries + session_system_prompt: Optional system prompt to include + + Returns: + JSON formatted string + """ + export_data = { + "export_date": datetime.datetime.now().isoformat(), + "app_version": APP_VERSION, + "system_prompt": session_system_prompt, + "message_count": len(session_history), + "messages": [ + { + "index": i + 1, + "prompt": entry.get("prompt", ""), + "response": entry.get("response", ""), + "prompt_tokens": entry.get("prompt_tokens", 0), + "completion_tokens": entry.get("completion_tokens", 0), + "cost": entry.get("msg_cost", 0.0), + } + for i, entry in enumerate(session_history) + ], + "totals": { + "prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history), + "completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history), + "total_cost": sum(e.get("msg_cost", 0.0) for e in session_history), + } + } + return json.dumps(export_data, indent=2, ensure_ascii=False) + + +def export_as_html( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as styled HTML. + + Args: + session_history: List of message dictionaries + session_system_prompt: Optional system prompt to include + + Returns: + HTML formatted string with embedded CSS + """ + html_parts = [ + "", + "", + "", + " ", + " ", + " Conversation Export - oAI", + " ", + "", + "", + "
", + "

Conversation Export

", + f"
Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
", + f"
Total Messages: {len(session_history)}
", + "
", + ] + + if session_system_prompt: + html_parts.extend([ + "
", + " System Prompt", + f"
{html_escape(session_system_prompt)}
", + "
", + ]) + + for i, entry in enumerate(session_history, 1): + prompt = html_escape(entry.get("prompt", "")) + response = html_escape(entry.get("response", "")) + + html_parts.extend([ + "
", + f"
Message {i} of {len(session_history)}
", + "
", + "
User
", + f"
{prompt}
", + "
", + "
", + "
Assistant
", + f"
{response}
", + "
", + "
", + ]) + + html_parts.extend([ + " ", + "", + "", + ]) + + return "\n".join(html_parts) diff --git a/oai/utils/files.py b/oai/utils/files.py new file mode 100644 index 0000000..f793d8a --- /dev/null +++ b/oai/utils/files.py @@ -0,0 +1,323 @@ +""" +File handling utilities for oAI. + +This module provides safe file reading, type detection, and other +file-related operations used throughout the application. +""" + +import os +import mimetypes +import base64 +from pathlib import Path +from typing import Optional, Dict, Any, Tuple + +from oai.constants import ( + MAX_FILE_SIZE, + CONTENT_TRUNCATION_THRESHOLD, + SUPPORTED_CODE_EXTENSIONS, + ALLOWED_FILE_EXTENSIONS, +) +from oai.utils.logging import get_logger + + +def is_binary_file(file_path: Path) -> bool: + """ + Check if a file appears to be binary. + + Args: + file_path: Path to the file to check + + Returns: + True if the file appears to be binary, False otherwise + """ + try: + with open(file_path, "rb") as f: + # Read first 8KB to check for binary content + chunk = f.read(8192) + # Check for null bytes (common in binary files) + if b"\x00" in chunk: + return True + # Try to decode as UTF-8 + try: + chunk.decode("utf-8") + return False + except UnicodeDecodeError: + return True + except Exception: + return True + + +def get_file_type(file_path: Path) -> Tuple[Optional[str], str]: + """ + Determine the MIME type and category of a file. + + Args: + file_path: Path to the file + + Returns: + Tuple of (mime_type, category) where category is one of: + 'image', 'pdf', 'code', 'text', 'binary', 'unknown' + """ + mime_type, _ = mimetypes.guess_type(str(file_path)) + ext = file_path.suffix.lower() + + if mime_type and mime_type.startswith("image/"): + return mime_type, "image" + elif mime_type == "application/pdf" or ext == ".pdf": + return mime_type or "application/pdf", "pdf" + elif ext in SUPPORTED_CODE_EXTENSIONS: + return mime_type or "text/plain", "code" + elif mime_type and mime_type.startswith("text/"): + return mime_type, "text" + elif is_binary_file(file_path): + return mime_type, "binary" + else: + return mime_type, "unknown" + + +def read_file_safe( + file_path: Path, + max_size: int = MAX_FILE_SIZE, + truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD +) -> Dict[str, Any]: + """ + Safely read a file with size limits and truncation support. 
+ + Args: + file_path: Path to the file to read + max_size: Maximum file size to read (bytes) + truncate_threshold: Threshold for truncating large files + + Returns: + Dictionary containing: + - content: File content (text or base64) + - size: File size in bytes + - truncated: Whether content was truncated + - encoding: 'text', 'base64', or None on error + - error: Error message if reading failed + """ + logger = get_logger() + + try: + path = Path(file_path).resolve() + + if not path.exists(): + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"File not found: {path}" + } + + if not path.is_file(): + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"Not a file: {path}" + } + + file_size = path.stat().st_size + + if file_size > max_size: + return { + "content": None, + "size": file_size, + "truncated": False, + "encoding": None, + "error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)" + } + + # Try to read as text first + try: + content = path.read_text(encoding="utf-8") + + # Check if truncation is needed + if file_size > truncate_threshold: + lines = content.split("\n") + total_lines = len(lines) + + # Keep first 500 lines and last 100 lines + head_lines = 500 + tail_lines = 100 + + if total_lines > (head_lines + tail_lines): + truncated_content = ( + "\n".join(lines[:head_lines]) + + f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" + + "\n".join(lines[-tail_lines:]) + ) + logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)") + return { + "content": truncated_content, + "size": file_size, + "truncated": True, + "total_lines": total_lines, + "lines_shown": head_lines + tail_lines, + "encoding": "text", + "error": None + } + + logger.info(f"Read file: {path} ({file_size} bytes)") + return { + "content": content, + "size": file_size, + "truncated": False, + "encoding": "text", + "error": None + } + + except UnicodeDecodeError: + # File is binary, return base64 encoded + with open(path, "rb") as f: + binary_data = f.read() + b64_content = base64.b64encode(binary_data).decode("utf-8") + logger.info(f"Read binary file: {path} ({file_size} bytes)") + return { + "content": b64_content, + "size": file_size, + "truncated": False, + "encoding": "base64", + "error": None + } + + except PermissionError as e: + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"Permission denied: {e}" + } + except Exception as e: + logger.error(f"Error reading file {file_path}: {e}") + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": str(e) + } + + +def get_file_extension(file_path: Path) -> str: + """ + Get the lowercase file extension. + + Args: + file_path: Path to the file + + Returns: + Lowercase extension including the dot (e.g., '.py') + """ + return file_path.suffix.lower() + + +def is_allowed_extension(file_path: Path) -> bool: + """ + Check if a file has an allowed extension for attachment. + + Args: + file_path: Path to the file + + Returns: + True if the extension is allowed, False otherwise + """ + return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS + + +def format_file_size(size_bytes: int) -> str: + """ + Format a file size in human-readable format. 
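+
+    Example:
+
+        format_file_size(1_572_864)   # -> "1.5 MB"
+        format_file_size(512)         # -> "512.0 B"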
+ + Args: + size_bytes: Size in bytes + + Returns: + Formatted string (e.g., '1.5 MB', '512 KB') + """ + for unit in ["B", "KB", "MB", "GB", "TB"]: + if abs(size_bytes) < 1024: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024 + return f"{size_bytes:.1f} PB" + + +def prepare_file_attachment( + file_path: Path, + model_capabilities: Dict[str, Any] +) -> Optional[Dict[str, Any]]: + """ + Prepare a file for attachment to an API request. + + Args: + file_path: Path to the file + model_capabilities: Model capability information + + Returns: + Content block dictionary for the API, or None if unsupported + """ + logger = get_logger() + path = Path(file_path).resolve() + + if not path.exists(): + logger.warning(f"File not found: {path}") + return None + + mime_type, category = get_file_type(path) + file_size = path.stat().st_size + + if file_size > MAX_FILE_SIZE: + logger.warning(f"File too large: {path} ({format_file_size(file_size)})") + return None + + try: + with open(path, "rb") as f: + file_data = f.read() + + if category == "image": + # Check if model supports images + input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", []) + if "image" not in input_modalities: + logger.warning(f"Model does not support images") + return None + + b64_data = base64.b64encode(file_data).decode("utf-8") + return { + "type": "image_url", + "image_url": {"url": f"data:{mime_type};base64,{b64_data}"} + } + + elif category == "pdf": + # Check if model supports PDFs + input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", []) + supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"]) + if not supports_pdf: + logger.warning(f"Model does not support PDFs") + return None + + b64_data = base64.b64encode(file_data).decode("utf-8") + return { + "type": "image_url", + "image_url": {"url": f"data:application/pdf;base64,{b64_data}"} + } + + elif category in ("code", "text"): + text_content = file_data.decode("utf-8") + return { + "type": "text", + "text": f"File: {path.name}\n\n{text_content}" + } + + else: + logger.warning(f"Unsupported file type: {category} ({mime_type})") + return None + + except UnicodeDecodeError: + logger.error(f"Cannot decode file as UTF-8: {path}") + return None + except Exception as e: + logger.error(f"Error preparing file attachment {path}: {e}") + return None diff --git a/oai/utils/logging.py b/oai/utils/logging.py new file mode 100644 index 0000000..859c51c --- /dev/null +++ b/oai/utils/logging.py @@ -0,0 +1,297 @@ +""" +Logging configuration for oAI. + +This module provides centralized logging setup with Rich formatting, +file rotation, and configurable log levels. +""" + +import io +import os +import glob +import logging +import datetime +import shutil +from logging.handlers import RotatingFileHandler +from pathlib import Path +from typing import Optional + +from rich.console import Console +from rich.logging import RichHandler + +from oai.constants import ( + LOG_FILE, + CONFIG_DIR, + DEFAULT_LOG_MAX_SIZE_MB, + DEFAULT_LOG_BACKUP_COUNT, + DEFAULT_LOG_LEVEL, + VALID_LOG_LEVELS, +) + + +class RotatingRichHandler(RotatingFileHandler): + """ + Custom log handler combining file rotation with Rich formatting. + + This handler writes Rich-formatted log output to a rotating file, + providing colored and formatted logs even in file output while + managing file size and backups automatically. 
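+
+    Example (illustrative file name and limits):
+
+        handler = RotatingRichHandler("oai.log", maxBytes=5 * 1024 * 1024,
+                                      backupCount=3, encoding="utf-8")
+        logging.getLogger().addHandler(handler)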
+ """ + + def __init__(self, *args, **kwargs): + """Initialize the handler with Rich console for formatting.""" + super().__init__(*args, **kwargs) + # Create an internal console for Rich formatting + self.rich_console = Console( + file=io.StringIO(), + width=120, + force_terminal=False + ) + self.rich_handler = RichHandler( + console=self.rich_console, + show_time=True, + show_path=True, + rich_tracebacks=True, + tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"] + ) + + def emit(self, record: logging.LogRecord) -> None: + """ + Emit a log record with Rich formatting. + + Args: + record: The log record to emit + """ + try: + # Format with Rich + self.rich_handler.emit(record) + output = self.rich_console.file.getvalue() + self.rich_console.file.seek(0) + self.rich_console.file.truncate(0) + + if output: + self.stream.write(output) + self.flush() + except Exception: + self.handleError(record) + + +class LoggingManager: + """ + Manages application logging configuration. + + Provides methods to setup, configure, and manage logging with + support for runtime reconfiguration and level changes. + """ + + def __init__(self): + """Initialize the logging manager.""" + self.handler: Optional[RotatingRichHandler] = None + self.app_logger: Optional[logging.Logger] = None + self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB + self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT + self.log_level: str = DEFAULT_LOG_LEVEL + + def setup( + self, + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None + ) -> logging.Logger: + """ + Setup or reconfigure logging. + + Args: + max_size_mb: Maximum log file size in MB + backup_count: Number of backup files to keep + log_level: Logging level string + + Returns: + The configured application logger + """ + # Update configuration if provided + if max_size_mb is not None: + self.max_size_mb = max_size_mb + if backup_count is not None: + self.backup_count = backup_count + if log_level is not None: + self.log_level = log_level + + # Ensure config directory exists + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + # Get root logger + root_logger = logging.getLogger() + + # Remove existing handler if present + if self.handler is not None: + root_logger.removeHandler(self.handler) + try: + self.handler.close() + except Exception: + pass + + # Check if log needs manual rotation + self._check_rotation() + + # Create new handler + max_bytes = self.max_size_mb * 1024 * 1024 + self.handler = RotatingRichHandler( + filename=str(LOG_FILE), + maxBytes=max_bytes, + backupCount=self.backup_count, + encoding="utf-8" + ) + + self.handler.setLevel(logging.NOTSET) + root_logger.setLevel(logging.WARNING) + root_logger.addHandler(self.handler) + + # Suppress noisy third-party loggers + for logger_name in [ + "asyncio", "urllib3", "requests", "httpx", + "httpcore", "openai", "openrouter" + ]: + logging.getLogger(logger_name).setLevel(logging.WARNING) + + # Configure application logger + self.app_logger = logging.getLogger("oai_app") + level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO) + self.app_logger.setLevel(level) + self.app_logger.propagate = True + + return self.app_logger + + def _check_rotation(self) -> None: + """Check if log file needs rotation and perform if necessary.""" + if not LOG_FILE.exists(): + return + + current_size = LOG_FILE.stat().st_size + max_bytes = self.max_size_mb * 1024 * 1024 + + if current_size >= max_bytes: + # Perform manual rotation + timestamp = 
datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = f"{LOG_FILE}.{timestamp}" + + try: + shutil.move(str(LOG_FILE), backup_file) + except Exception: + pass + + # Clean old backups + self._cleanup_old_backups() + + def _cleanup_old_backups(self) -> None: + """Remove old backup files exceeding the backup count.""" + log_dir = LOG_FILE.parent + backup_pattern = f"{LOG_FILE.name}.*" + backups = sorted(glob.glob(str(log_dir / backup_pattern))) + + while len(backups) > self.backup_count: + oldest = backups.pop(0) + try: + os.remove(oldest) + except Exception: + pass + + def set_level(self, level: str) -> bool: + """ + Set the application log level. + + Args: + level: Log level string (debug/info/warning/error/critical) + + Returns: + True if level was set successfully, False otherwise + """ + level_lower = level.lower() + if level_lower not in VALID_LOG_LEVELS: + return False + + self.log_level = level_lower + if self.app_logger: + self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower]) + + return True + + def get_logger(self) -> logging.Logger: + """ + Get the application logger, initializing if necessary. + + Returns: + The application logger + """ + if self.app_logger is None: + self.setup() + return self.app_logger + + +# Global logging manager instance +_logging_manager = LoggingManager() + + +def setup_logging( + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None +) -> logging.Logger: + """ + Setup application logging. + + This is the main entry point for configuring logging. Call this + early in application startup. + + Args: + max_size_mb: Maximum log file size in MB + backup_count: Number of backup files to keep + log_level: Logging level string + + Returns: + The configured application logger + """ + return _logging_manager.setup(max_size_mb, backup_count, log_level) + + +def get_logger() -> logging.Logger: + """ + Get the application logger. + + Returns: + The application logger instance + """ + return _logging_manager.get_logger() + + +def set_log_level(level: str) -> bool: + """ + Set the application log level. + + Args: + level: Log level string + + Returns: + True if successful, False otherwise + """ + return _logging_manager.set_level(level) + + +def reload_logging( + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None +) -> logging.Logger: + """ + Reload logging configuration. + + Useful when settings change at runtime. 
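+
+    Example (illustrative values):
+
+        logger = reload_logging(max_size_mb=5, backup_count=3, log_level="debug")
+        logger.debug("Logging reconfigured")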
+ + Args: + max_size_mb: New maximum log file size + backup_count: New backup count + log_level: New log level + + Returns: + The reconfigured logger + """ + return _logging_manager.setup(max_size_mb, backup_count, log_level) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..af80856 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,134 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "oai" +version = "2.1.0" +description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application" +readme = "README.md" +license = {text = "MIT"} +authors = [ + {name = "Rune", email = "rune@example.com"} +] +maintainers = [ + {name = "Rune", email = "rune@example.com"} +] +keywords = [ + "ai", + "chat", + "openrouter", + "cli", + "terminal", + "mcp", + "llm", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Utilities", +] +requires-python = ">=3.10" +dependencies = [ + "anyio>=4.0.0", + "click>=8.0.0", + "httpx>=0.24.0", + "markdown-it-py>=3.0.0", + "openrouter>=0.0.19", + "packaging>=21.0", + "prompt-toolkit>=3.0.0", + "pyperclip>=1.8.0", + "requests>=2.28.0", + "rich>=13.0.0", + "typer>=0.9.0", + "mcp>=1.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.0.0", + "black>=23.0.0", + "isort>=5.12.0", + "mypy>=1.0.0", + "ruff>=0.1.0", +] + +[project.urls] +Homepage = "https://iurl.no/oai" +Repository = "https://gitlab.pm/rune/oai" +Documentation = "https://iurl.no/oai" +"Bug Tracker" = "https://gitlab.pm/rune/oai/issues" + +[project.scripts] +oai = "oai.cli:main" + +[tool.setuptools] +packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.ui", "oai.utils"] + +[tool.setuptools.package-data] +oai = ["py.typed"] + +[tool.black] +line-length = 100 +target-version = ["py310", "py311", "py312"] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.mypy_cache + | \.pytest_cache + | \.venv + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 100 +skip_gitignore = true + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true +exclude = [ + "build", + "dist", + ".venv", +] + +[tool.ruff] +line-length = 100 +target-version = "py310" +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by black) + "B008", # do not perform function calls in argument defaults + "C901", # too complex +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +asyncio_mode = "auto" +addopts = "-v --tb=short" -- 2.49.1