diff --git a/.gitignore b/.gitignore index 3c2ad4a..c8e0fc6 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,9 @@ Pipfile.lock # Consider if you want to include or exclude *~.nib *~.xib +# Claude Code local settings +.claude/ + # Added by author *.zip .note @@ -39,3 +42,6 @@ b0.sh *.old *.sh *.back +requirements.txt +system_prompt.txt +CLAUDE* diff --git a/README.md b/README.md index 64f3220..61bc8fd 100644 --- a/README.md +++ b/README.md @@ -1,584 +1,301 @@ -# oAI - OpenRouter AI Chat +# oAI - OpenRouter AI Chat Client -A powerful terminal-based chat interface for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI agents to access local files and query SQLite databases directly. - -## Description - -oAI is a feature-rich command-line chat application that provides an interactive interface to OpenRouter's AI models. It now includes **MCP integration** for local file system access and read-only database querying, allowing AI to help with code analysis, data exploration, and more. +A powerful, extensible terminal-based chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases. ## Features ### Core Features - ๐Ÿค– Interactive chat with 300+ AI models via OpenRouter -- ๐Ÿ” Model selection with search and capability filtering +- ๐Ÿ” Model selection with search and filtering - ๐Ÿ’พ Conversation save/load/export (Markdown, JSON, HTML) -- ๐Ÿ“Ž File attachment support (images, PDFs, code files) -- ๐Ÿ’ฐ Session cost tracking and credit monitoring -- ๐ŸŽจ Rich terminal formatting with syntax highlighting +- ๐Ÿ“Ž File attachments (images, PDFs, code files) +- ๐Ÿ’ฐ Real-time cost tracking and credit monitoring +- ๐ŸŽจ Rich terminal UI with syntax highlighting - ๐Ÿ“ Persistent command history with search (Ctrl+R) -- โš™๏ธ Configurable system prompts and token limits -- ๐Ÿ—„๏ธ SQLite-based configuration and conversation storage - ๐ŸŒ Online mode (web search capabilities) -- ๐Ÿง  Conversation memory toggle (save costs with stateless mode) +- ๐Ÿง  Conversation memory toggle -### NEW: MCP (Model Context Protocol) v2.1.0-beta -- ๐Ÿ”ง **File Mode**: AI can read, search, and list your local files +### MCP Integration +- ๐Ÿ”ง **File Mode**: AI can read, search, and list local files - Automatic .gitignore filtering - - Virtual environment exclusion (venv, node_modules, etc.) - - Supports code files, text, JSON, YAML, and more + - Virtual environment exclusion - Large file handling (auto-truncates >50KB) - -- โœ๏ธ **Write Mode** (NEW!): AI can modify files with your permission - - Create and edit files within allowed folders - - Delete files (always requires confirmation) - - Move, copy, and organize files - - Create directories - - Ignores .gitignore for write operations - - OFF by default - explicit opt-in required - -- ๐Ÿ—„๏ธ **Database Mode**: AI can query your SQLite databases - - Read-only access (no data modification possible) - - Schema inspection (tables, columns, indexes) - - Full-text search across all tables - - SQL query execution (SELECT, JOINs, CTEs, subqueries) - - Query validation and timeout protection - - Result limiting (max 1000 rows) -- ๐Ÿ”’ **Security Features**: - - Explicit folder/database approval required - - System directory blocking - - Write mode OFF by default (non-persistent) - - Delete operations always require user confirmation - - Read-only database access - - SQL injection protection - - Query timeout (5 seconds) +- โœ๏ธ **Write Mode**: AI can modify files with permission + - Create, edit, delete files + - Move, copy, organize files + - Always requires explicit opt-in + +- ๐Ÿ—„๏ธ **Database Mode**: AI can query SQLite databases + - Read-only access (safe) + - Schema inspection + - Full SQL query support ## Requirements -- Python 3.10-3.13 (3.14 not supported yet) -- OpenRouter API key (get one at https://openrouter.ai) -- Function-calling model required for MCP features (GPT-4, Claude, etc.) - -## Screenshot - -[](https://gitlab.pm/rune/oai/src/branch/main/README.md) - -*Screenshot from version 1.0 - MCP interface shows mode indicators like `[๐Ÿ”ง MCP: Files]` or `[๐Ÿ—„๏ธ MCP: DB #1]`* +- Python 3.10-3.13 +- OpenRouter API key ([get one here](https://openrouter.ai)) ## Installation -### Option 1: From Source (Recommended for Development) - -#### 1. Install Dependencies +### Option 1: Install from Source (Recommended) ```bash -pip install -r requirements.txt +# Clone the repository +git clone https://gitlab.pm/rune/oai.git +cd oai + +# Install with pip +pip install -e . ``` -#### 2. Make Executable +### Option 2: Pre-built Binary (macOS/Linux) -```bash -chmod +x oai.py -``` - -#### 3. Copy to PATH - -```bash -# Option 1: System-wide (requires sudo) -sudo cp oai.py /usr/local/bin/oai - -# Option 2: User-local (recommended) -mkdir -p ~/.local/bin -cp oai.py ~/.local/bin/oai - -# Add to PATH if needed (add to ~/.bashrc or ~/.zshrc) -export PATH="$HOME/.local/bin:$PATH" -``` - -#### 4. Verify Installation - -```bash -oai --version -``` - -### Option 2: Pre-built Binaries - -Download platform-specific binaries: -- **macOS (Apple Silicon)**: `oai_vx.x.x_mac_arm64.zip` -- **Linux (x86_64)**: `oai_vx.x.x-linux-x86_64.zip` +Download from [Releases](https://gitlab.pm/rune/oai/releases): +- **macOS (Apple Silicon)**: `oai_v2.1.0_mac_arm64.zip` +- **Linux (x86_64)**: `oai_v2.1.0_linux_x86_64.zip` ```bash # Extract and install -unzip oai_vx.x.x_mac_arm64.zip # or `oai_vx.x.x-linux-x86_64.zip` -chmod +x oai -mkdir -p ~/.local/bin # Remember to add this to your path. Or just move to folder already in your $PATH +unzip oai_v2.1.0_*.zip +mkdir -p ~/.local/bin mv oai ~/.local/bin/ + +# macOS only: Remove quarantine and approve +xattr -cr ~/.local/bin/oai +# Then right-click oai in Finder โ†’ Open With โ†’ Terminal โ†’ Click "Open" ``` - -### Alternative: Shell Alias +### Add to PATH ```bash -# Add to ~/.bashrc or ~/.zshrc -alias oai='python3 /path/to/oai.py' +# Add to ~/.zshrc or ~/.bashrc +export PATH="$HOME/.local/bin:$PATH" ``` ## Quick Start -### First Run Setup +```bash +# Start the chat client +oai chat + +# Or with options +oai chat --model gpt-4o --mcp +``` + +On first run, you'll be prompted for your OpenRouter API key. + +### Basic Commands ```bash -oai +# In the chat interface: +/model # Select AI model +/help # Show all commands +/mcp on # Enable file/database access +/stats # View session statistics +exit # Quit ``` -On first run, you'll be prompted to enter your OpenRouter API key. +## MCP (Model Context Protocol) -### Basic Usage +MCP allows the AI to interact with your local files and databases. + +### File Access ```bash -# Start chatting -oai +/mcp on # Enable MCP +/mcp add ~/Projects # Grant access to folder +/mcp list # View allowed folders -# Select a model -You> /model - -# Enable MCP for file access -You> /mcp on -You> /mcp add ~/Documents - -# Ask AI to help with files (read-only) -[๐Ÿ”ง MCP: Files] You> List all Python files in Documents -[๐Ÿ”ง MCP: Files] You> Read and explain main.py - -# Enable write mode to let AI modify files -You> /mcp write on -[๐Ÿ”งโœ๏ธ MCP: Files+Write] You> Create a new Python file with helper functions -[๐Ÿ”งโœ๏ธ MCP: Files+Write] You> Refactor main.py to use async/await - -# Switch to database mode -You> /mcp add db ~/myapp/data.db -You> /mcp db 1 -[๐Ÿ—„๏ธ MCP: DB #1] You> Show me all tables -[๐Ÿ—„๏ธ MCP: DB #1] You> Find all users created this month -``` - -## MCP Guide - -### File Mode (Default) - -**Setup:** -```bash -/mcp on # Start MCP server -/mcp add ~/Projects # Grant access to folder -/mcp add ~/Documents # Add another folder -/mcp list # View all allowed folders -``` - -**Natural Language Usage:** -``` +# Now ask the AI: "List all Python files in Projects" -"Read and explain config.yaml" +"Read and explain main.py" "Search for files containing 'TODO'" -"What's in my Documents folder?" ``` -**Available Tools (Read-Only):** -- `read_file` - Read complete file contents -- `list_directory` - List files/folders (recursive optional) -- `search_files` - Search by name or content - -**Available Tools (Write Mode - requires `/mcp write on`):** -- `write_file` - Create new files or overwrite existing ones -- `edit_file` - Find and replace text in existing files -- `delete_file` - Delete files (always requires confirmation) -- `create_directory` - Create directories -- `move_file` - Move or rename files -- `copy_file` - Copy files to new locations - -**Features:** -- โœ… Automatic .gitignore filtering (read operations only) -- โœ… Skips virtual environments (venv, node_modules) -- โœ… Handles large files (auto-truncates >50KB) -- โœ… Cross-platform (macOS, Linux, Windows via WSL) -- โœ… Write mode OFF by default for safety -- โœ… Delete operations require user confirmation with LLM's reason - -### Database Mode - -**Setup:** -```bash -/mcp add db ~/app/database.db # Add SQLite database -/mcp db list # View all databases -/mcp db 1 # Switch to database #1 -``` - -**Natural Language Usage:** -``` -"Show me all tables in this database" -"Find records mentioning 'error'" -"How many users registered last week?" -"Get the schema for the orders table" -"Show me the 10 most recent transactions" -``` - -**Available Tools:** -- `inspect_database` - View schema, tables, columns, indexes -- `search_database` - Full-text search across tables -- `query_database` - Execute read-only SQL queries - -**Supported Queries:** -- โœ… SELECT statements -- โœ… JOINs (INNER, LEFT, RIGHT, FULL) -- โœ… Subqueries -- โœ… CTEs (Common Table Expressions) -- โœ… Aggregations (COUNT, SUM, AVG, etc.) -- โœ… WHERE, GROUP BY, HAVING, ORDER BY, LIMIT -- โŒ INSERT/UPDATE/DELETE (blocked for safety) - ### Write Mode -**Enable Write Mode:** ```bash -/mcp write on # Enable write mode (shows warning, requires confirmation) +/mcp write on # Enable file modifications + +# AI can now: +"Create a new file called utils.py" +"Edit config.json and update the API URL" +"Delete the old backup files" # Always asks for confirmation ``` -**Natural Language Usage:** -``` -"Create a new Python file called utils.py with helper functions" -"Edit main.py and replace the old API endpoint with the new one" -"Delete the backup.old file" (will prompt for confirmation) -"Create a directory called tests" -"Move config.json to the config folder" -``` - -**Important:** -- โš ๏ธ Write mode is **OFF by default** and resets each session -- โš ๏ธ Delete operations **always** require user confirmation -- โš ๏ธ All operations are limited to allowed MCP folders -- โœ… Write operations ignore .gitignore (can write to any file in allowed folders) - -**Disable Write Mode:** -```bash -/mcp write off # Disable write mode (back to read-only) -``` - -### Mode Management +### Database Mode ```bash -/mcp status # Show current mode, write mode, stats, folders/databases -/mcp files # Switch to file mode -/mcp db # Switch to database mode -/mcp gitignore on # Enable .gitignore filtering (default) -/mcp write on|off # Enable/disable write mode -/mcp remove 2 # Remove folder/database by number +/mcp add db ~/app/data.db # Add database +/mcp db 1 # Switch to database mode + +# Ask the AI: +"Show all tables" +"Find users created this month" +"What's the schema for the orders table?" ``` ## Command Reference -### Session Commands -``` -/help [command] Show help menu or detailed command help -/help mcp Comprehensive MCP guide -/clear or /cl Clear terminal screen (or Ctrl+L) -/memory on|off Toggle conversation memory (save costs) -/online on|off Enable/disable web search -/paste [prompt] Paste clipboard content -/retry Resend last prompt -/reset Clear history and system prompt -/prev View previous response -/next View next response -``` +### Chat Commands +| Command | Description | +|---------|-------------| +| `/help [cmd]` | Show help | +| `/model [search]` | Select model | +| `/info [model]` | Model details | +| `/memory on\|off` | Toggle context | +| `/online on\|off` | Toggle web search | +| `/retry` | Resend last message | +| `/clear` | Clear screen | ### MCP Commands -``` -/mcp on Start MCP server -/mcp off Stop MCP server -/mcp status Show comprehensive status (includes write mode) -/mcp add Add folder for file access -/mcp add db Add SQLite database -/mcp list List all folders -/mcp db list List all databases -/mcp db Switch to database mode -/mcp files Switch to file mode -/mcp remove Remove folder/database -/mcp gitignore on Enable .gitignore filtering -/mcp write on Enable write mode (create/edit/delete files) -/mcp write off Disable write mode (read-only) -``` +| Command | Description | +|---------|-------------| +| `/mcp on\|off` | Enable/disable MCP | +| `/mcp status` | Show MCP status | +| `/mcp add ` | Add folder | +| `/mcp add db ` | Add database | +| `/mcp list` | List folders | +| `/mcp db list` | List databases | +| `/mcp db ` | Switch to database | +| `/mcp files` | Switch to file mode | +| `/mcp write on\|off` | Toggle write mode | -### Model Commands -``` -/model [search] Select/change AI model -/info [model_id] Show model details (pricing, capabilities) -``` +### Conversation Commands +| Command | Description | +|---------|-------------| +| `/save ` | Save conversation | +| `/load ` | Load conversation | +| `/list` | List saved conversations | +| `/delete ` | Delete conversation | +| `/export md\|json\|html ` | Export | ### Configuration -``` -/config View all settings -/config api Set API key -/config model Set default model -/config online Set default online mode (on|off) -/config stream Enable/disable streaming (on|off) -/config maxtoken Set max token limit -/config costwarning Set cost warning threshold ($) -/config loglevel Set log level (debug/info/warning/error) -/config log Set log file size (MB) -``` +| Command | Description | +|---------|-------------| +| `/config` | View settings | +| `/config api` | Set API key | +| `/config model ` | Set default model | +| `/config stream on\|off` | Toggle streaming | +| `/stats` | Session statistics | +| `/credits` | Check credits | -### Conversation Management -``` -/save Save conversation -/load Load saved conversation -/delete Delete conversation -/list List saved conversations -/export md|json|html Export conversation -``` +## CLI Options -### Token & System -``` -/maxtoken [value] Set session token limit -/system [prompt] Set system prompt (use 'clear' to reset) -/middleout on|off Enable prompt compression -``` - -### Monitoring -``` -/stats View session statistics -/credits Check OpenRouter credits -``` - -### File Attachments -``` -@/path/to/file Attach file (images, PDFs, code) - -Examples: - Debug @script.py - Analyze @data.json - Review @screenshot.png -``` - -## Configuration Options - -All configuration stored in `~/.config/oai/`: - -### Files -- `oai_config.db` - SQLite database (settings, conversations, MCP config) -- `oai.log` - Application logs (rotating, configurable size) -- `history.txt` - Command history (searchable with Ctrl+R) - -### Key Settings -- **API Key**: OpenRouter authentication -- **Default Model**: Auto-select on startup -- **Streaming**: Real-time response display -- **Max Tokens**: Global and session limits -- **Cost Warning**: Alert threshold for expensive requests -- **Online Mode**: Default web search setting -- **Log Level**: debug/info/warning/error/critical -- **Log Size**: Rotating file size in MB - -## Supported File Types - -### Code Files -`.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm` - -### Data Files -`.json, .yaml, .yml, .xml, .csv, .txt, .md` - -### Images -All standard formats: PNG, JPEG, JPG, GIF, WEBP, BMP - -### Documents -PDF (models with document support) - -### Size Limits -- Images: 10 MB max -- Code/Text: Auto-truncates files >50KB -- Binary data: Displayed as `` - -## MCP Security - -### Access Control -- โœ… Explicit folder/database approval required -- โœ… System directories blocked automatically -- โœ… User confirmation for each addition -- โœ… .gitignore patterns respected (file mode) - -### Database Safety -- โœ… Read-only mode (cannot modify data) -- โœ… SQL query validation (blocks INSERT/UPDATE/DELETE) -- โœ… Query timeout (5 seconds max) -- โœ… Result limits (1000 rows max) -- โœ… Database opened in `mode=ro` - -### File System Safety -- โœ… Read-only by default (write mode requires explicit opt-in) -- โœ… Write mode OFF by default each session (non-persistent) -- โœ… Delete operations always require user confirmation -- โœ… Write operations limited to allowed folders only -- โœ… System directories blocked -- โœ… Virtual environment exclusion -- โœ… Build artifact filtering -- โœ… Maximum file size (10 MB) - -## Tips & Tricks - -### Command History -- **โ†‘/โ†“ arrows**: Navigate previous commands -- **Ctrl+R**: Search command history -- **Auto-complete**: Start typing `/` for command suggestions - -### Cost Optimization ```bash -/memory off # Disable context (stateless mode) -/maxtoken 1000 # Limit response length -/config costwarning 0.01 # Set alert threshold +oai chat [OPTIONS] + +Options: + -m, --model TEXT Model ID to use + -s, --system TEXT System prompt + -o, --online Enable online mode + --mcp Enable MCP server + --help Show help ``` -### MCP Best Practices +Other commands: ```bash -# Check status frequently -/mcp status - -# Use specific paths to reduce search time -"List Python files in Projects/app/" # Better than -"List all Python files" # Slower - -# Database queries - be specific -"SELECT * FROM users LIMIT 10" # Good -"SELECT * FROM users" # May hit row limit +oai config [setting] [value] # Configure settings +oai version # Show version +oai credits # Check credits ``` -### Debugging -```bash -# Enable debug logging -/config loglevel debug +## Configuration -# Check log file -tail -f ~/.config/oai/oai.log +Configuration is stored in `~/.config/oai/`: -# View MCP statistics -/mcp status # Shows tool call counts +| File | Purpose | +|------|---------| +| `oai_config.db` | Settings, conversations, MCP config | +| `oai.log` | Application logs | +| `history.txt` | Command history | + +## Project Structure + +``` +oai/ +โ”œโ”€โ”€ oai/ +โ”‚ โ”œโ”€โ”€ __init__.py +โ”‚ โ”œโ”€โ”€ __main__.py # Entry point for python -m oai +โ”‚ โ”œโ”€โ”€ cli.py # Main CLI interface +โ”‚ โ”œโ”€โ”€ constants.py # Configuration constants +โ”‚ โ”œโ”€โ”€ commands/ # Slash command handlers +โ”‚ โ”œโ”€โ”€ config/ # Settings and database +โ”‚ โ”œโ”€โ”€ core/ # Chat client and session +โ”‚ โ”œโ”€โ”€ mcp/ # MCP server and tools +โ”‚ โ”œโ”€โ”€ providers/ # AI provider abstraction +โ”‚ โ”œโ”€โ”€ ui/ # Terminal UI utilities +โ”‚ โ””โ”€โ”€ utils/ # Logging, export, etc. +โ”œโ”€โ”€ pyproject.toml # Package configuration +โ”œโ”€โ”€ build.sh # Binary build script +โ””โ”€โ”€ README.md ``` ## Troubleshooting -### MCP Not Working -```bash -# 1. Check if MCP is installed -python3 -c "import mcp; print('MCP OK')" +### macOS Binary Issues -# 2. Verify model supports function calling +```bash +# Remove quarantine attribute +xattr -cr ~/.local/bin/oai + +# Then in Finder: right-click oai โ†’ Open With โ†’ Terminal โ†’ Click "Open" +# After this, oai works from any terminal +``` + +### MCP Not Working + +```bash +# Check if model supports function calling /info # Look for "tools" in supported parameters -# 3. Check MCP status +# Check MCP status /mcp status -# 4. Review logs -tail ~/.config/oai/oai.log +# View logs +tail -f ~/.config/oai/oai.log ``` ### Import Errors + ```bash -# Reinstall dependencies -pip install --force-reinstall -r requirements.txt -``` - -### Binary Issues (macOS) -```bash -# Remove quarantine -xattr -cr ~/.local/bin/oai - -# Check security settings -# System Settings > Privacy & Security > "Allow Anyway" -``` - -### Database Errors -```bash -# Verify it's a valid SQLite database -sqlite3 database.db ".tables" - -# Check file permissions -ls -la database.db +# Reinstall package +pip install -e . --force-reinstall ``` ## Version History -### v2.1.0-RC1 (Current) -- โœจ **NEW**: MCP (Model Context Protocol) integration -- โœจ **NEW**: File system access (read, search, list) -- โœจ **NEW**: Write mode - AI can create, edit, and delete files - - 6 write tools: write_file, edit_file, delete_file, create_directory, move_file, copy_file - - OFF by default - requires explicit `/mcp write on` activation - - Delete operations always require user confirmation - - Non-persistent setting (resets each session) -- โœจ **NEW**: SQLite database querying (read-only) -- โœจ **NEW**: Dual mode support (Files & Database) -- โœจ **NEW**: .gitignore filtering -- โœจ **NEW**: Binary data handling in databases -- โœจ **NEW**: Mode indicators in prompt (shows โœ๏ธ when write mode active) -- โœจ **NEW**: Comprehensive `/help mcp` guide -- ๐Ÿ”ง Improved error handling for tool calls -- ๐Ÿ”ง Enhanced logging for MCP operations -- ๐Ÿ”ง Statistics tracking for tool usage +### v2.1.0 (Current) +- ๐Ÿ—๏ธ Complete codebase refactoring to modular package structure +- ๐Ÿ”Œ Extensible provider architecture for adding new AI providers +- ๐Ÿ“ฆ Proper Python packaging with pyproject.toml +- โœจ MCP integration (file access, write mode, database queries) +- ๐Ÿ”ง Command registry pattern for slash commands +- ๐Ÿ“Š Improved cost tracking and session statistics -### v1.9.6 -- Base version with core chat functionality -- Conversation management +### v1.9.x +- Single-file implementation +- Core chat functionality - File attachments -- Cost tracking -- Export capabilities +- Conversation management ## License -MIT License - -Copyright (c) 2024-2025 Rune Olsen - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Full license: https://opensource.org/licenses/MIT +MIT License - See [LICENSE](LICENSE) for details. ## Author **Rune Olsen** - -- Homepage: https://ai.fubar.pm/ -- Blog: https://blog.rune.pm - Project: https://iurl.no/oai +- Repository: https://gitlab.pm/rune/oai ## Contributing -Contributions welcome! Please: 1. Fork the repository 2. Create a feature branch -3. Submit a pull request with detailed description - -## Acknowledgments - -- OpenRouter team for the unified AI API -- Rich library for beautiful terminal output -- MCP community for the protocol specification +3. Submit a pull request --- -**Star โญ this project if you find it useful!** - ---- - -Did you really read all the way down here? WOW! You deserve a ๐Ÿพ ๐Ÿฅ‚! +**โญ Star this project if you find it useful!** diff --git a/oai.py b/oai.py deleted file mode 100644 index d3787c2..0000000 --- a/oai.py +++ /dev/null @@ -1,5987 +0,0 @@ -#!/usr/bin/python3 -W ignore::DeprecationWarning -import sys -import os -import requests -import time -import asyncio -from pathlib import Path -from typing import Optional, List, Dict, Any -import typer -from rich.console import Console -from rich.panel import Panel -from rich.table import Table -from rich.text import Text -from rich.markdown import Markdown -from rich.live import Live -from openrouter import OpenRouter -import pyperclip -import mimetypes -import base64 -import re -import sqlite3 -import json -import datetime -import logging -from logging.handlers import RotatingFileHandler -from prompt_toolkit import PromptSession -from prompt_toolkit.history import FileHistory -from rich.logging import RichHandler -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory -from prompt_toolkit.key_binding import KeyBindings -from prompt_toolkit.filters import Condition -from prompt_toolkit.application.current import get_app -from packaging import version as pkg_version -import io -import platform -import shutil -import subprocess -import fnmatch -import signal - -# MCP imports -try: - from mcp import ClientSession, StdioServerParameters - from mcp.client.stdio import stdio_client - MCP_AVAILABLE = True -except ImportError: - MCP_AVAILABLE = False - print("Warning: MCP library not found. Install with: pip install mcp") - -# App version -version = '2.1.0-RC2' - -app = typer.Typer() - -# Application identification -APP_NAME = "oAI" -APP_URL = "https://iurl.no/oai" - -# Paths -home = Path.home() -config_dir = home / '.config' / 'oai' -cache_dir = home / '.cache' / 'oai' -history_file = config_dir / 'history.txt' -database = config_dir / 'oai_config.db' -log_file = config_dir / 'oai.log' - -# Create dirs -config_dir.mkdir(parents=True, exist_ok=True) -cache_dir.mkdir(parents=True, exist_ok=True) - -# Rich console -console = Console() - -# Valid commands -VALID_COMMANDS = { - '/retry', '/online', '/memory', '/paste', '/export', '/save', '/load', - '/delete', '/list', '/prev', '/next', '/stats', '/middleout', '/reset', - '/info', '/model', '/maxtoken', '/system', '/config', '/credits', '/clear', - '/cl', '/help', '/mcp' -} - -# Command help database (COMPLETE - includes MCP comprehensive guide) -COMMAND_HELP = { - '/clear': { - 'aliases': ['/cl'], - 'description': 'Clear the terminal screen for a clean interface.', - 'usage': '/clear\n/cl', - 'examples': [ - ('Clear screen', '/clear'), - ('Using short alias', '/cl'), - ], - 'notes': 'You can also use the keyboard shortcut Ctrl+L to clear the screen.' - }, - '/help': { - 'description': 'Display help information for commands.', - 'usage': '/help [command|topic]', - 'examples': [ - ('Show all commands', '/help'), - ('Get help for a specific command', '/help /model'), - ('Get detailed MCP help', '/help mcp'), - ], - 'notes': 'Use /help without arguments to see the full command list, /help for detailed command info, or /help mcp for comprehensive MCP documentation.' - }, - 'mcp': { - 'description': 'Complete guide to MCP (Model Context Protocol) - file access and database querying for AI agents.', - 'usage': 'See detailed examples below', - 'examples': [], - 'notes': ''' -MCP (Model Context Protocol) gives your AI assistant direct access to: - โ€ข Local files and folders (read, search, list) - โ€ข SQLite databases (inspect, search, query) - -โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ -โ”‚ ๐Ÿ—‚๏ธ FILE MODE โ”‚ -โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ - -SETUP: - /mcp on Start MCP server - /mcp add ~/Documents Grant access to folder - /mcp add ~/Code/project Add another folder - /mcp list View all allowed folders - -USAGE (just ask naturally): - "List all Python files in my Code folder" - "Read the contents of Documents/notes.txt" - "Search for files containing 'budget'" - "What's in my Documents folder?" - -FEATURES: - โœ“ Automatically respects .gitignore patterns - โœ“ Skips virtual environments (venv, node_modules, etc.) - โœ“ Handles large files (auto-truncates >50KB) - โœ“ Cross-platform (macOS, Linux, Windows) - -MANAGEMENT: - /mcp remove ~/Desktop Remove folder access - /mcp remove 2 Remove by number - /mcp gitignore on|off Toggle .gitignore filtering - /mcp status Show comprehensive status - -โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ -โ”‚ โœ๏ธ WRITE MODE (OPTIONAL) โ”‚ -โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ - -ENABLE WRITE MODE: - /mcp write on Enable file modifications (requires confirmation) - /mcp write off Disable write mode (back to read-only) - -WHAT WRITE MODE ALLOWS: - โ€ข Create new files or overwrite existing ones - โ€ข Edit files (find and replace text) - โ€ข Delete files (always requires confirmation) - โ€ข Create directories - โ€ข Move and copy files - -USAGE (after enabling write mode): - "Create a new Python file called utils.py" - "Edit main.py and update the API endpoint" - "Delete the old backup.txt file" - "Create a tests directory" - "Move config.json to the configs folder" - -SAFETY FEATURES: - โœ“ OFF by default (resets each session) - โœ“ Requires explicit activation with user confirmation - โœ“ Delete operations ALWAYS require user confirmation - โœ“ Limited to allowed MCP folders only - โœ“ Ignores .gitignore (can write to any file in allowed folders) - โœ“ System directories remain blocked - -IMPORTANT: - โš ๏ธ Write mode is powerful - use with caution! - โš ๏ธ Always review what the AI plans to modify - โš ๏ธ Deletions will show file details + AI's reason for confirmation - -โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ -โ”‚ ๐Ÿ—„๏ธ DATABASE MODE โ”‚ -โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ - -SETUP: - /mcp add db ~/app/data.db Add specific database - /mcp add db ~/.config/oai/oai_config.db - /mcp db list View all databases - -SWITCH TO DATABASE: - /mcp db 1 Work with database #1 - /mcp db 2 Switch to database #2 - /mcp files Switch back to file mode - -USAGE (after selecting a database): - "Show me all tables in this database" - "Find any records mentioning 'Santa Clause'" - "How many users are in the database?" - "Show me the schema for the orders table" - "Get the 10 most recent transactions" - -FEATURES: - โœ“ Read-only mode (no data modification possible) - โœ“ Smart schema inspection - โœ“ Full-text search across all tables - โœ“ SQL query execution (SELECT only) - โœ“ JOINs, subqueries, CTEs supported - โœ“ Automatic query timeout (5 seconds) - โœ“ Result limits (max 1000 rows) - -SAFETY: - โ€ข All queries are read-only (INSERT/UPDATE/DELETE blocked) - โ€ข Database opened in read-only mode - โ€ข Query validation prevents dangerous operations - โ€ข Timeout protection prevents infinite loops - -โ•ญโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฎ -โ”‚ ๐Ÿ’ก TIPS & TRICKS โ”‚ -โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ - -MODE INDICATORS: - [๐Ÿ”ง MCP: Files] You're in file mode (read-only) - [๐Ÿ”งโœ๏ธ MCP: Files+Write] You're in file mode with write permissions - [๐Ÿ—„๏ธ MCP: DB #1] You're querying database #1 - -QUICK REFERENCE: - /mcp status See current mode, write mode, stats, folders/databases - /mcp files Switch to file mode (default) - /mcp db Switch to database mode - /mcp db list List all databases with details - /mcp list List all folders - /mcp write on Enable write mode (create/edit/delete files) - /mcp write off Disable write mode (read-only) - -TROUBLESHOOTING: - โ€ข No results? Check /mcp status to see what's accessible - โ€ข Wrong mode? Use /mcp files or /mcp db to switch - โ€ข Database errors? Ensure file exists and is valid SQLite - โ€ข .gitignore not working? Check /mcp status shows patterns - -SECURITY NOTES: - โ€ข MCP only accesses explicitly added folders/databases - โ€ข File mode: read-only by default (write mode requires opt-in) - โ€ข Write mode: OFF by default, resets each session - โ€ข Delete operations: Always require user confirmation - โ€ข Database mode: SELECT queries only (no modifications) - โ€ข System directories are blocked automatically - โ€ข Each addition requires your explicit confirmation - -For command-specific help: /help /mcp - ''' - }, - '/mcp': { - 'description': 'Manage MCP (Model Context Protocol) for local file access and SQLite database querying. When enabled with a function-calling model, AI can automatically search, read, list files, and query databases. Supports two modes: Files (default) and Database.', - 'usage': '/mcp [args]', - 'examples': [ - ('Enable MCP server', '/mcp on'), - ('Disable MCP server', '/mcp off'), - ('Show MCP status and current mode', '/mcp status'), - ('', ''), - ('โ”โ”โ” FILE MODE โ”โ”โ”', ''), - ('Add folder for file access', '/mcp add ~/Documents'), - ('Remove folder by path', '/mcp remove ~/Desktop'), - ('Remove folder by number', '/mcp remove 2'), - ('List allowed folders', '/mcp list'), - ('Toggle .gitignore filtering', '/mcp gitignore on'), - ('Enable write mode', '/mcp write on'), - ('Disable write mode', '/mcp write off'), - ('', ''), - ('โ”โ”โ” DATABASE MODE โ”โ”โ”', ''), - ('Add SQLite database', '/mcp add db ~/app/data.db'), - ('List all databases', '/mcp db list'), - ('Switch to database #1', '/mcp db 1'), - ('Remove database', '/mcp remove db 2'), - ('Switch back to file mode', '/mcp files'), - ], - 'notes': '''MCP allows AI to read local files and query SQLite databases. Works automatically with function-calling models (GPT-4, Claude, etc.). - -FILE MODE (default): -- Read-only by default (write mode requires opt-in) -- Automatically loads and respects .gitignore patterns -- Skips virtual environments and build artifacts -- Supports search, read, and list operations - -WRITE MODE (optional): -- Enable with /mcp write on (requires confirmation) -- Allows creating, editing, and deleting files -- OFF by default, resets each session -- Delete operations always require user confirmation -- Limited to allowed folders only - -DATABASE MODE: -- Read-only access (no data modification) -- Execute SELECT queries with JOINs, subqueries, CTEs -- Full-text search across all tables -- Schema inspection and data exploration -- Automatic query validation and timeout protection - -Use /help mcp for comprehensive guide with examples.''' - }, - '/memory': { - 'description': 'Toggle conversation memory. When ON, the AI remembers conversation history. When OFF, each request is independent (saves tokens and cost).', - 'usage': '/memory [on|off]', - 'examples': [ - ('Check current memory status', '/memory'), - ('Enable conversation memory', '/memory on'), - ('Disable memory (save costs)', '/memory off'), - ], - 'notes': 'Memory is ON by default. Disabling memory reduces API costs but the AI won\'t remember previous messages. Messages are still saved locally for your reference.' - }, - '/online': { - 'description': 'Enable or disable online mode (web search capabilities) for the current session.', - 'usage': '/online [on|off]', - 'examples': [ - ('Check online mode status', '/online'), - ('Enable web search', '/online on'), - ('Disable web search', '/online off'), - ], - 'notes': 'Not all models support online mode. The model must have "tools" parameter support. This setting overrides the default online mode configured with /config online.' - }, - '/paste': { - 'description': 'Paste plain text or code from clipboard and send to the AI. Optionally add a prompt.', - 'usage': '/paste [prompt]', - 'examples': [ - ('Paste clipboard content', '/paste'), - ('Paste with a question', '/paste Explain this code'), - ('Paste and ask for review', '/paste Review this for bugs'), - ], - 'notes': 'Only plain text is supported. Binary clipboard data will be rejected. The clipboard content is shown as a preview before sending.' - }, - '/retry': { - 'description': 'Resend the last prompt from conversation history.', - 'usage': '/retry', - 'examples': [ - ('Retry last message', '/retry'), - ], - 'notes': 'Useful when you get an error or want a different response to the same prompt. Requires at least one message in history.' - }, - '/next': { - 'description': 'View the next response in conversation history.', - 'usage': '/next', - 'examples': [ - ('Navigate to next response', '/next'), - ], - 'notes': 'Navigate through conversation history. Use /prev to go backward.' - }, - '/prev': { - 'description': 'View the previous response in conversation history.', - 'usage': '/prev', - 'examples': [ - ('Navigate to previous response', '/prev'), - ], - 'notes': 'Navigate through conversation history. Use /next to go forward.' - }, - '/reset': { - 'description': 'Clear conversation history and reset system prompt. This resets all session metrics.', - 'usage': '/reset', - 'examples': [ - ('Reset conversation', '/reset'), - ], - 'notes': 'Requires confirmation. This clears all message history, resets the system prompt, and resets token/cost counters. Use when starting a completely new conversation topic.' - }, - '/info': { - 'description': 'Display detailed information about a model including pricing, capabilities, context length, and online support.', - 'usage': '/info [model_id]', - 'examples': [ - ('Show current model info', '/info'), - ('Show specific model info', '/info gpt-4o'), - ('Check model capabilities', '/info claude-3-opus'), - ], - 'notes': 'Without arguments, shows info for the currently selected model. Displays pricing per million tokens, supported modalities (text, image, etc.), and parameter support.' - }, - '/model': { - 'description': 'Select or change the AI model for the current session. Shows image, online, and tool capabilities.', - 'usage': '/model [search_term]', - 'examples': [ - ('List all models', '/model'), - ('Search for GPT models', '/model gpt'), - ('Search for Claude models', '/model claude'), - ], - 'notes': 'Models are numbered for easy selection. The table shows Image (โœ“ if model accepts images), Online (โœ“ if model supports web search), and Tools (โœ“ if model supports function calling for MCP).' - }, - '/config': { - 'description': 'View or modify application configuration settings.', - 'usage': '/config [setting] [value]', - 'examples': [ - ('View all settings', '/config'), - ('Set API key', '/config api'), - ('Set default model', '/config model'), - ('Enable streaming', '/config stream on'), - ('Set cost warning threshold', '/config costwarning 0.05'), - ('Set log level', '/config loglevel debug'), - ('Set default online mode', '/config online on'), - ], - 'notes': 'Available settings: api (API key), url (base URL), model (default model), stream (on/off), costwarning (threshold $), maxtoken (limit), online (default on/off), log (size MB), loglevel (debug/info/warning/error/critical).' - }, - '/maxtoken': { - 'description': 'Set a temporary session token limit (cannot exceed stored max token limit).', - 'usage': '/maxtoken [value]', - 'examples': [ - ('View current session limit', '/maxtoken'), - ('Set session limit to 2000', '/maxtoken 2000'), - ('Set to 50000', '/maxtoken 50000'), - ], - 'notes': 'This is a session-only setting and cannot exceed the stored max token limit (set with /config maxtoken). View without arguments to see current value.' - }, - '/system': { - 'description': 'Set or clear the session-level system prompt to guide AI behavior.', - 'usage': '/system [prompt|clear]', - 'examples': [ - ('View current system prompt', '/system'), - ('Set as Python expert', '/system You are a Python expert'), - ('Set as code reviewer', '/system You are a senior code reviewer. Focus on bugs and best practices.'), - ('Clear system prompt', '/system clear'), - ], - 'notes': 'System prompts influence how the AI responds throughout the session. Use "clear" to remove the current system prompt.' - }, - '/save': { - 'description': 'Save the current conversation history to the database.', - 'usage': '/save ', - 'examples': [ - ('Save conversation', '/save my_chat'), - ('Save with descriptive name', '/save python_debugging_2024'), - ], - 'notes': 'Saved conversations can be loaded later with /load. Use descriptive names to easily find them later.' - }, - '/load': { - 'description': 'Load a saved conversation from the database by name or number.', - 'usage': '/load ', - 'examples': [ - ('Load by name', '/load my_chat'), - ('Load by number from /list', '/load 3'), - ], - 'notes': 'Use /list to see numbered conversations. Loading a conversation replaces current history and resets session metrics.' - }, - '/delete': { - 'description': 'Delete a saved conversation from the database. Requires confirmation.', - 'usage': '/delete ', - 'examples': [ - ('Delete by name', '/delete my_chat'), - ('Delete by number from /list', '/delete 3'), - ], - 'notes': 'Use /list to see numbered conversations. This action cannot be undone!' - }, - '/list': { - 'description': 'List all saved conversations with numbers, message counts, and timestamps.', - 'usage': '/list', - 'examples': [ - ('Show saved conversations', '/list'), - ], - 'notes': 'Conversations are numbered for easy use with /load and /delete commands. Shows message count and last saved time.' - }, - '/export': { - 'description': 'Export the current conversation to a file in various formats.', - 'usage': '/export ', - 'examples': [ - ('Export as Markdown', '/export md notes.md'), - ('Export as JSON', '/export json conversation.json'), - ('Export as HTML', '/export html report.html'), - ], - 'notes': 'Available formats: md (Markdown), json (JSON), html (HTML). The export includes all messages and the system prompt if set.' - }, - '/stats': { - 'description': 'Display session statistics including tokens, costs, and credits.', - 'usage': '/stats', - 'examples': [ - ('View session statistics', '/stats'), - ], - 'notes': 'Shows total input/output tokens, total cost, average cost per message, and remaining credits. Also displays credit warnings if applicable.' - }, - '/credits': { - 'description': 'Display your OpenRouter account credits and usage.', - 'usage': '/credits', - 'examples': [ - ('Check credits', '/credits'), - ], - 'notes': 'Shows total credits, used credits, and credits left. Displays warnings if credits are low (< $1 or < 10% of total).' - }, - '/middleout': { - 'description': 'Enable or disable middle-out transform to compress prompts exceeding context size.', - 'usage': '/middleout [on|off]', - 'examples': [ - ('Check status', '/middleout'), - ('Enable compression', '/middleout on'), - ('Disable compression', '/middleout off'), - ], - 'notes': 'Middle-out transform intelligently compresses long prompts to fit within model context limits. Useful for very long conversations.' - }, -} - -# Supported code extensions -SUPPORTED_CODE_EXTENSIONS = { - '.py', '.js', '.ts', '.cs', '.java', '.c', '.cpp', '.h', '.hpp', - '.rb', '.ruby', '.php', '.swift', '.kt', '.kts', '.go', - '.sh', '.bat', '.ps1', '.R', '.scala', '.pl', '.lua', '.dart', - '.elm', '.xml', '.json', '.yaml', '.yml', '.md', '.txt' -} - -# Pricing -MODEL_PRICING = { - 'input': 3.0, - 'output': 15.0 -} -LOW_CREDIT_RATIO = 0.1 -LOW_CREDIT_AMOUNT = 1.0 -HIGH_COST_WARNING = "cost_warning_threshold" - -# Valid log levels -VALID_LOG_LEVELS = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL -} - -# System directories to block (cross-platform) -SYSTEM_DIRS_BLACKLIST = { - '/System', '/Library', '/private', '/usr', '/bin', '/sbin', # macOS - '/boot', '/dev', '/proc', '/sys', '/root', # Linux - 'C:\\Windows', 'C:\\Program Files', 'C:\\Program Files (x86)' # Windows -} - -# Database functions -def create_table_if_not_exists(): - """Ensure tables exist.""" - os.makedirs(config_dir, exist_ok=True) - with sqlite3.connect(str(database)) as conn: - conn.execute('''CREATE TABLE IF NOT EXISTS config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - )''') - conn.execute('''CREATE TABLE IF NOT EXISTS conversation_sessions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - timestamp TEXT NOT NULL, - data TEXT NOT NULL - )''') - # MCP configuration table - conn.execute('''CREATE TABLE IF NOT EXISTS mcp_config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - )''') - # MCP statistics table - conn.execute('''CREATE TABLE IF NOT EXISTS mcp_stats ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - timestamp TEXT NOT NULL, - tool_name TEXT NOT NULL, - folder TEXT, - success INTEGER NOT NULL, - error_message TEXT - )''') - # MCP databases table - conn.execute('''CREATE TABLE IF NOT EXISTS mcp_databases ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - path TEXT NOT NULL UNIQUE, - name TEXT NOT NULL, - size INTEGER, - tables TEXT, - added_timestamp TEXT NOT NULL - )''') - conn.commit() - -def get_config(key: str) -> Optional[str]: - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('SELECT value FROM config WHERE key = ?', (key,)) - result = cursor.fetchone() - return result[0] if result else None - -def set_config(key: str, value: str): - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - conn.execute('INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)', (key, value)) - conn.commit() - -def get_mcp_config(key: str) -> Optional[str]: - """Get MCP configuration value.""" - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('SELECT value FROM mcp_config WHERE key = ?', (key,)) - result = cursor.fetchone() - return result[0] if result else None - -def set_mcp_config(key: str, value: str): - """Set MCP configuration value.""" - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - conn.execute('INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)', (key, value)) - conn.commit() - -def log_mcp_stat(tool_name: str, folder: Optional[str], success: bool, error_message: Optional[str] = None): - """Log MCP tool usage statistics.""" - create_table_if_not_exists() - timestamp = datetime.datetime.now().isoformat() - with sqlite3.connect(str(database)) as conn: - conn.execute( - 'INSERT INTO mcp_stats (timestamp, tool_name, folder, success, error_message) VALUES (?, ?, ?, ?, ?)', - (timestamp, tool_name, folder, 1 if success else 0, error_message) - ) - conn.commit() - -def get_mcp_stats() -> Dict[str, Any]: - """Get MCP usage statistics.""" - create_table_if_not_exists() - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute(''' - SELECT - COUNT(*) as total_calls, - SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads, - SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists, - SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches, - SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects, - SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches, - SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries, - MAX(timestamp) as last_used - FROM mcp_stats - ''') - row = cursor.fetchone() - return { - 'total_calls': row[0] or 0, - 'reads': row[1] or 0, - 'lists': row[2] or 0, - 'searches': row[3] or 0, - 'db_inspects': row[4] or 0, - 'db_searches': row[5] or 0, - 'db_queries': row[6] or 0, - 'last_used': row[7] - } - -# Rotating Rich Handler -class RotatingRichHandler(RotatingFileHandler): - """Custom handler combining RotatingFileHandler with Rich formatting.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.rich_console = Console(file=io.StringIO(), width=120, force_terminal=False) - self.rich_handler = RichHandler( - console=self.rich_console, - show_time=True, - show_path=True, - rich_tracebacks=True, - tracebacks_suppress=['requests', 'openrouter', 'urllib3', 'httpx', 'openai'] - ) - - def emit(self, record): - try: - self.rich_handler.emit(record) - output = self.rich_console.file.getvalue() - self.rich_console.file.seek(0) - self.rich_console.file.truncate(0) - if output: - self.stream.write(output) - self.flush() - except Exception: - self.handleError(record) - -# Logging setup -LOG_MAX_SIZE_MB = int(get_config('log_max_size_mb') or "10") -LOG_BACKUP_COUNT = int(get_config('log_backup_count') or "2") -LOG_LEVEL_STR = get_config('log_level') or "info" -LOG_LEVEL = VALID_LOG_LEVELS.get(LOG_LEVEL_STR.lower(), logging.INFO) - -app_handler = None -app_logger = None - -def setup_logging(): - """Setup or reset logging configuration.""" - global app_handler, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, app_logger - - root_logger = logging.getLogger() - - if app_handler is not None: - root_logger.removeHandler(app_handler) - try: - app_handler.close() - except: - pass - - # Check if log needs rotation - if os.path.exists(log_file): - current_size = os.path.getsize(log_file) - max_bytes = LOG_MAX_SIZE_MB * 1024 * 1024 - - if current_size >= max_bytes: - import shutil - timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') - backup_file = f"{log_file}.{timestamp}" - try: - shutil.move(str(log_file), backup_file) - except Exception as e: - print(f"Warning: Could not rotate log file: {e}") - - # Clean old backups - log_dir = os.path.dirname(log_file) - log_basename = os.path.basename(log_file) - backup_pattern = f"{log_basename}.*" - - import glob - backups = sorted(glob.glob(os.path.join(log_dir, backup_pattern))) - - while len(backups) > LOG_BACKUP_COUNT: - oldest = backups.pop(0) - try: - os.remove(oldest) - except: - pass - - app_handler = RotatingRichHandler( - filename=str(log_file), - maxBytes=LOG_MAX_SIZE_MB * 1024 * 1024, - backupCount=LOG_BACKUP_COUNT, - encoding='utf-8' - ) - - app_handler.setLevel(logging.NOTSET) - root_logger.setLevel(logging.WARNING) - root_logger.addHandler(app_handler) - - # Suppress noisy loggers - logging.getLogger('asyncio').setLevel(logging.WARNING) - logging.getLogger('urllib3').setLevel(logging.WARNING) - logging.getLogger('requests').setLevel(logging.WARNING) - logging.getLogger('httpx').setLevel(logging.WARNING) - logging.getLogger('httpcore').setLevel(logging.WARNING) - logging.getLogger('openai').setLevel(logging.WARNING) - logging.getLogger('openrouter').setLevel(logging.WARNING) - - app_logger = logging.getLogger("oai_app") - app_logger.setLevel(LOG_LEVEL) - app_logger.propagate = True - - return app_logger - -def set_log_level(level_str: str) -> bool: - """Set application log level.""" - global LOG_LEVEL, LOG_LEVEL_STR, app_logger - level_str_lower = level_str.lower() - if level_str_lower not in VALID_LOG_LEVELS: - return False - LOG_LEVEL = VALID_LOG_LEVELS[level_str_lower] - LOG_LEVEL_STR = level_str_lower - - if app_logger: - app_logger.setLevel(LOG_LEVEL) - - return True - -def reload_logging_config(): - """Reload logging configuration.""" - global LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, LOG_LEVEL_STR, app_logger - - LOG_MAX_SIZE_MB = int(get_config('log_max_size_mb') or "10") - LOG_BACKUP_COUNT = int(get_config('log_backup_count') or "2") - LOG_LEVEL_STR = get_config('log_level') or "info" - LOG_LEVEL = VALID_LOG_LEVELS.get(LOG_LEVEL_STR.lower(), logging.INFO) - - app_logger = setup_logging() - return app_logger - -app_logger = setup_logging() -logger = logging.getLogger(__name__) - -# Load configs -API_KEY = get_config('api_key') -OPENROUTER_BASE_URL = get_config('base_url') or "https://openrouter.ai/api/v1" -STREAM_ENABLED = get_config('stream_enabled') or "on" -DEFAULT_MODEL_ID = get_config('default_model') -MAX_TOKEN = int(get_config('max_token') or "100000") -COST_WARNING_THRESHOLD = float(get_config(HIGH_COST_WARNING) or "0.01") -DEFAULT_ONLINE_MODE = get_config('default_online_mode') or "off" - -# Fetch models -models_data = [] -text_models = [] -try: - headers = { - "Authorization": f"Bearer {API_KEY}", - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } if API_KEY else { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - response = requests.get(f"{OPENROUTER_BASE_URL}/models", headers=headers) - response.raise_for_status() - models_data = response.json()["data"] - text_models = [m for m in models_data if "modalities" not in m or "video" not in (m.get("modalities") or [])] - selected_model_default = None - if DEFAULT_MODEL_ID: - selected_model_default = next((m for m in text_models if m["id"] == DEFAULT_MODEL_ID), None) - if not selected_model_default: - console.print(f"[bold yellow]Warning: Default model '{DEFAULT_MODEL_ID}' unavailable. Use '/config model'.[/]") -except Exception as e: - models_data = [] - text_models = [] - app_logger.error(f"Failed to fetch models: {e}") - -# ============================================================================ -# MCP INTEGRATION CLASSES -# ============================================================================ - -class CrossPlatformMCPConfig: - """Handle OS-specific MCP configuration.""" - - def __init__(self): - self.system = platform.system() - self.is_macos = self.system == "Darwin" - self.is_linux = self.system == "Linux" - self.is_windows = self.system == "Windows" - - app_logger.info(f"Detected OS: {self.system}") - - def get_default_allowed_dirs(self) -> List[Path]: - """Get safe default directories.""" - home = Path.home() - - if self.is_macos: - return [ - home / "Documents", - home / "Desktop", - home / "Downloads" - ] - elif self.is_linux: - dirs = [home / "Documents"] - - try: - for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]: - result = subprocess.run( - ["xdg-user-dir", xdg_dir], - capture_output=True, - text=True, - timeout=1 - ) - if result.returncode == 0: - dir_path = Path(result.stdout.strip()) - if dir_path.exists(): - dirs.append(dir_path) - except: - dirs.extend([ - home / "Desktop", - home / "Downloads" - ]) - - return list(set(dirs)) - - elif self.is_windows: - return [ - home / "Documents", - home / "Desktop", - home / "Downloads" - ] - - return [home] - - def get_python_command(self) -> str: - """Get Python command.""" - import sys - return sys.executable - - def get_filesystem_warning(self) -> str: - """Get OS-specific security warning.""" - if self.is_macos: - return """ -โš ๏ธ macOS Security Notice: -The Filesystem MCP server needs access to your selected folder. -You may see a security prompt - click 'Allow' to proceed. -(System Settings > Privacy & Security > Files and Folders) -""" - elif self.is_linux: - return """ -โš ๏ธ Linux Security Notice: -The Filesystem MCP server will access your selected folder. -Ensure oAI has appropriate file permissions. -""" - elif self.is_windows: - return """ -โš ๏ธ Windows Security Notice: -The Filesystem MCP server will access your selected folder. -You may need to grant file access permissions. -""" - return "" - - def normalize_path(self, path: str) -> Path: - """Normalize path for current OS.""" - return Path(os.path.expanduser(path)).resolve() - - def is_system_directory(self, path: Path) -> bool: - """Check if path is a system directory.""" - path_str = str(path) - for blocked in SYSTEM_DIRS_BLACKLIST: - if path_str.startswith(blocked): - return True - return False - - def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool: - """Check if path is within allowed directories.""" - try: - requested = requested_path.resolve() - - for allowed in allowed_dirs: - try: - allowed_resolved = allowed.resolve() - requested.relative_to(allowed_resolved) - return True - except ValueError: - continue - - return False - except: - return False - - def get_folder_stats(self, folder: Path) -> Dict[str, Any]: - """Get folder statistics.""" - try: - if not folder.exists() or not folder.is_dir(): - return {'exists': False} - - file_count = 0 - total_size = 0 - - for item in folder.rglob('*'): - if item.is_file(): - file_count += 1 - try: - total_size += item.stat().st_size - except: - pass - - return { - 'exists': True, - 'file_count': file_count, - 'total_size': total_size, - 'size_mb': total_size / (1024 * 1024) - } - except Exception as e: - app_logger.error(f"Error getting folder stats for {folder}: {e}") - return {'exists': False, 'error': str(e)} - - -class GitignoreParser: - """Parse .gitignore files and check if paths should be ignored.""" - - def __init__(self): - self.patterns = [] # List of (pattern, is_negation, source_dir) - - def add_gitignore(self, gitignore_path: Path): - """Parse and add patterns from a .gitignore file.""" - if not gitignore_path.exists(): - return - - try: - source_dir = gitignore_path.parent - with open(gitignore_path, 'r', encoding='utf-8') as f: - for line_num, line in enumerate(f, 1): - line = line.rstrip('\n\r') - - # Skip empty lines and comments - if not line or line.startswith('#'): - continue - - # Check for negation pattern - is_negation = line.startswith('!') - if is_negation: - line = line[1:] - - # Remove leading slash (make relative to gitignore location) - if line.startswith('/'): - line = line[1:] - - self.patterns.append((line, is_negation, source_dir)) - - app_logger.debug(f"Loaded {len(self.patterns)} patterns from {gitignore_path}") - except Exception as e: - app_logger.warning(f"Error reading {gitignore_path}: {e}") - - def should_ignore(self, path: Path) -> bool: - """Check if a path should be ignored based on gitignore patterns.""" - if not self.patterns: - return False - - ignored = False - - for pattern, is_negation, source_dir in self.patterns: - # Only apply pattern if path is under the source directory - try: - rel_path = path.relative_to(source_dir) - except ValueError: - # Path is not relative to this gitignore's directory - continue - - rel_path_str = str(rel_path) - - # Check if pattern matches - if self._match_pattern(pattern, rel_path_str, path.is_dir()): - if is_negation: - ignored = False # Negation patterns un-ignore - else: - ignored = True - - return ignored - - def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool: - """Match a gitignore pattern against a path.""" - # Directory-only pattern (ends with /) - if pattern.endswith('/'): - if not is_dir: - return False - pattern = pattern[:-1] - - # ** matches any number of directories - if '**' in pattern: - # Convert ** to regex-like pattern - pattern_parts = pattern.split('**') - if len(pattern_parts) == 2: - prefix, suffix = pattern_parts - # Match if path starts with prefix and ends with suffix - if prefix: - if not path.startswith(prefix.rstrip('/')): - return False - if suffix: - suffix = suffix.lstrip('/') - if not (path.endswith(suffix) or f'/{suffix}' in path): - return False - return True - - # Direct match - if fnmatch.fnmatch(path, pattern): - return True - - # Match as subdirectory pattern - if '/' not in pattern: - # Pattern without / matches in any directory - parts = path.split('/') - if any(fnmatch.fnmatch(part, pattern) for part in parts): - return True - - return False - - -class SQLiteQueryValidator: - """Validate SQLite queries for read-only safety.""" - - @staticmethod - def is_safe_query(query: str) -> tuple[bool, str]: - """ - Validate that query is a safe read-only SELECT. - Returns (is_safe, error_message) - """ - query_upper = query.strip().upper() - - # Must start with SELECT or WITH (for CTEs) - if not (query_upper.startswith('SELECT') or query_upper.startswith('WITH')): - return False, "Only SELECT queries are allowed (including WITH/CTE)" - - # Block dangerous keywords even in SELECT - dangerous_keywords = [ - 'INSERT', 'UPDATE', 'DELETE', 'DROP', 'CREATE', - 'ALTER', 'TRUNCATE', 'REPLACE', 'ATTACH', 'DETACH', - 'PRAGMA', 'VACUUM', 'REINDEX' - ] - - # Check for dangerous keywords (but allow them in string literals) - # Simple check: look for keywords outside of quotes - query_no_strings = re.sub(r"'[^']*'", '', query_upper) - query_no_strings = re.sub(r'"[^"]*"', '', query_no_strings) - - for keyword in dangerous_keywords: - if re.search(r'\b' + keyword + r'\b', query_no_strings): - return False, f"Keyword '{keyword}' not allowed in read-only mode" - - return True, "" - # P2 -class MCPFilesystemServer: - """MCP Filesystem Server with file access and SQLite database querying.""" - - def __init__(self, allowed_folders: List[Path]): - self.allowed_folders = allowed_folders - self.config = CrossPlatformMCPConfig() - self.max_file_size = 10 * 1024 * 1024 # 10MB limit - self.max_list_items = 1000 # Max items to return in list_directory - self.respect_gitignore = True # Default enabled - - # Initialize gitignore parser - self.gitignore_parser = GitignoreParser() - self._load_gitignores() - - # SQLite configuration - self.max_query_timeout = 5 # seconds - self.max_query_results = 1000 # max rows - self.default_query_limit = 100 # default rows - - app_logger.info(f"MCP Filesystem Server initialized with {len(allowed_folders)} folders") - - def _load_gitignores(self): - """Load all .gitignore files from allowed folders.""" - gitignore_count = 0 - for folder in self.allowed_folders: - if not folder.exists(): - continue - - # Load root .gitignore - root_gitignore = folder / '.gitignore' - if root_gitignore.exists(): - self.gitignore_parser.add_gitignore(root_gitignore) - gitignore_count += 1 - app_logger.info(f"Loaded .gitignore from {folder}") - - # Load nested .gitignore files - try: - for gitignore_path in folder.rglob('.gitignore'): - if gitignore_path != root_gitignore: - self.gitignore_parser.add_gitignore(gitignore_path) - gitignore_count += 1 - app_logger.debug(f"Loaded nested .gitignore: {gitignore_path}") - except Exception as e: - app_logger.warning(f"Error loading nested .gitignores from {folder}: {e}") - - if gitignore_count > 0: - app_logger.info(f"Loaded {gitignore_count} .gitignore file(s) with {len(self.gitignore_parser.patterns)} total patterns") - - def reload_gitignores(self): - """Reload all .gitignore files (call when allowed_folders changes).""" - self.gitignore_parser = GitignoreParser() - self._load_gitignores() - app_logger.info("Reloaded .gitignore patterns") - - def is_allowed_path(self, path: Path) -> bool: - """Check if path is allowed.""" - return self.config.is_safe_path(path, self.allowed_folders) - - async def read_file(self, file_path: str) -> Dict[str, Any]: - """Read file contents.""" - try: - path = self.config.normalize_path(file_path) - - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} not in allowed folders" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - if not path.exists(): - error_msg = f"File not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - if not path.is_file(): - error_msg = f"Not a file: {path}" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Check if file should be ignored - if self.respect_gitignore and self.gitignore_parser.should_ignore(path): - error_msg = f"File ignored by .gitignore: {path}" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - file_size = path.stat().st_size - if file_size > self.max_file_size: - error_msg = f"File too large: {file_size / (1024*1024):.1f}MB (max: {self.max_file_size / (1024*1024):.0f}MB)" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Try to read as text - try: - content = path.read_text(encoding='utf-8') - - # If file is large (>50KB), provide summary instead of full content - max_content_size = 50 * 1024 # 50KB - if file_size > max_content_size: - lines = content.split('\n') - total_lines = len(lines) - - # Return first 500 lines + last 100 lines with notice - head_lines = 500 - tail_lines = 100 - - if total_lines > (head_lines + tail_lines): - truncated_content = ( - '\n'.join(lines[:head_lines]) + - f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted to stay within size limits] ...\n\n" + - '\n'.join(lines[-tail_lines:]) - ) - - app_logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)") - log_mcp_stat('read_file', str(path.parent), True) - - return { - 'content': truncated_content, - 'path': str(path), - 'size': file_size, - 'truncated': True, - 'total_lines': total_lines, - 'lines_shown': head_lines + tail_lines, - 'note': f'File truncated: showing first {head_lines} and last {tail_lines} lines of {total_lines} total' - } - - app_logger.info(f"Read file: {path} ({file_size} bytes)") - log_mcp_stat('read_file', str(path.parent), True) - return { - 'content': content, - 'path': str(path), - 'size': file_size - } - except UnicodeDecodeError: - error_msg = f"Cannot decode file as UTF-8: {path}" - app_logger.warning(error_msg) - log_mcp_stat('read_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - except Exception as e: - error_msg = f"Error reading file: {e}" - app_logger.error(error_msg) - log_mcp_stat('read_file', file_path, False, str(e)) - return {'error': error_msg} - - async def list_directory(self, dir_path: str, recursive: bool = False) -> Dict[str, Any]: - """List directory contents.""" - try: - path = self.config.normalize_path(dir_path) - - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} not in allowed folders" - app_logger.warning(error_msg) - log_mcp_stat('list_directory', str(path), False, error_msg) - return {'error': error_msg} - - if not path.exists(): - error_msg = f"Directory not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('list_directory', str(path), False, error_msg) - return {'error': error_msg} - - if not path.is_dir(): - error_msg = f"Not a directory: {path}" - app_logger.warning(error_msg) - log_mcp_stat('list_directory', str(path), False, error_msg) - return {'error': error_msg} - - items = [] - pattern = '**/*' if recursive else '*' - - # Directories to skip (hardcoded common patterns) - skip_dirs = { - '.venv', 'venv', 'env', 'virtualenv', - 'site-packages', 'dist-packages', - '__pycache__', '.pytest_cache', '.mypy_cache', - 'node_modules', '.git', '.svn', - '.idea', '.vscode', - 'build', 'dist', 'eggs', '.eggs' - } - - def should_skip_path(item_path: Path) -> bool: - """Check if path should be skipped.""" - # Check hardcoded skip directories - path_parts = item_path.parts - if any(part in skip_dirs for part in path_parts): - return True - - # Check gitignore patterns (if enabled) - if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): - return True - - return False - - for item in path.glob(pattern): - # Stop if we hit the limit - if len(items) >= self.max_list_items: - break - - # Skip excluded directories - if should_skip_path(item): - continue - - try: - stat = item.stat() - items.append({ - 'name': item.name, - 'path': str(item), - 'type': 'directory' if item.is_dir() else 'file', - 'size': stat.st_size if item.is_file() else 0, - 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() - }) - except: - continue - - truncated = len(items) >= self.max_list_items - - app_logger.info(f"Listed directory: {path} ({len(items)} items, recursive={recursive}, truncated={truncated})") - log_mcp_stat('list_directory', str(path), True) - - result = { - 'path': str(path), - 'items': items, - 'count': len(items), - 'truncated': truncated - } - - if truncated: - result['note'] = f'Results limited to {self.max_list_items} items. Use more specific search path or disable recursive mode.' - - return result - - except Exception as e: - error_msg = f"Error listing directory: {e}" - app_logger.error(error_msg) - log_mcp_stat('list_directory', dir_path, False, str(e)) - return {'error': error_msg} - - async def search_files(self, pattern: str, search_path: Optional[str] = None, content_search: bool = False) -> Dict[str, Any]: - """Search for files.""" - try: - # Determine search roots - if search_path: - path = self.config.normalize_path(search_path) - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} not in allowed folders" - app_logger.warning(error_msg) - log_mcp_stat('search_files', str(path), False, error_msg) - return {'error': error_msg} - search_roots = [path] - else: - search_roots = self.allowed_folders - - matches = [] - - # Directories to skip (virtual environments, caches, etc.) - skip_dirs = { - '.venv', 'venv', 'env', 'virtualenv', - 'site-packages', 'dist-packages', - '__pycache__', '.pytest_cache', '.mypy_cache', - 'node_modules', '.git', '.svn', - '.idea', '.vscode', - 'build', 'dist', 'eggs', '.eggs' - } - - def should_skip_path(item_path: Path) -> bool: - """Check if path should be skipped.""" - # Check hardcoded skip directories - path_parts = item_path.parts - if any(part in skip_dirs for part in path_parts): - return True - - # Check gitignore patterns (if enabled) - if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): - return True - - return False - - for root in search_roots: - if not root.exists(): - continue - - # Filename search - if not content_search: - for item in root.rglob(pattern): - if item.is_file() and not should_skip_path(item): - try: - stat = item.stat() - matches.append({ - 'path': str(item), - 'name': item.name, - 'size': stat.st_size, - 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() - }) - except: - continue - else: - # Content search (slower) - for item in root.rglob('*'): - if item.is_file() and not should_skip_path(item): - try: - # Skip large files - if item.stat().st_size > self.max_file_size: - continue - - # Try to read and search content - try: - content = item.read_text(encoding='utf-8') - if pattern.lower() in content.lower(): - stat = item.stat() - matches.append({ - 'path': str(item), - 'name': item.name, - 'size': stat.st_size, - 'modified': datetime.datetime.fromtimestamp(stat.st_mtime).isoformat() - }) - except: - continue - except: - continue - - search_type = "content" if content_search else "filename" - app_logger.info(f"Searched files: pattern='{pattern}', type={search_type}, found={len(matches)}") - log_mcp_stat('search_files', str(search_roots[0]) if search_roots else None, True) - - return { - 'pattern': pattern, - 'search_type': search_type, - 'matches': matches, - 'count': len(matches) - } - - except Exception as e: - error_msg = f"Error searching files: {e}" - app_logger.error(error_msg) - log_mcp_stat('search_files', search_path or "all", False, str(e)) - return {'error': error_msg} - - # ======================================================================== - # FILE WRITE METHODS - # ======================================================================== - - async def write_file(self, file_path: str, content: str) -> Dict[str, Any]: - """Create a new file or overwrite an existing file with content. - - Note: Ignores .gitignore - allows writing to any file in allowed folders. - """ - try: - path = self.config.normalize_path(file_path) - - # Permission check: must be in allowed folders - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('write_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Check if parent directory exists, create if needed - parent_dir = path.parent - if not parent_dir.exists(): - parent_dir.mkdir(parents=True, exist_ok=True) - app_logger.info(f"Created parent directory: {parent_dir}") - - # Determine if creating new file or overwriting - is_new_file = not path.exists() - - # Write content to file - path.write_text(content, encoding='utf-8') - file_size = path.stat().st_size - - app_logger.info(f"{'Created' if is_new_file else 'Updated'} file: {path} ({file_size} bytes)") - log_mcp_stat('write_file', str(path.parent), True) - - return { - 'success': True, - 'path': str(path), - 'size': file_size, - 'created': is_new_file - } - - except PermissionError as e: - error_msg = f"Permission denied writing to {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('write_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except UnicodeEncodeError as e: - error_msg = f"Encoding error writing to {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('write_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'EncodingError'} - except Exception as e: - error_msg = f"Error writing file {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('write_file', file_path, False, str(e)) - return {'error': error_msg} - - async def edit_file(self, file_path: str, old_text: str, new_text: str) -> Dict[str, Any]: - """Find and replace text in an existing file. - - Note: Ignores .gitignore - allows editing any file in allowed folders. - """ - try: - path = self.config.normalize_path(file_path) - - # Permission check - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('edit_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # File must exist - if not path.exists(): - error_msg = f"File not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('edit_file', str(path), False, error_msg) - return {'error': error_msg} - - if not path.is_file(): - error_msg = f"Not a file: {path}" - app_logger.warning(error_msg) - log_mcp_stat('edit_file', str(path), False, error_msg) - return {'error': error_msg} - - # Read current content - current_content = path.read_text(encoding='utf-8') - - # Check if old_text exists in file - if old_text not in current_content: - error_msg = f"Text not found in file: '{old_text[:50]}...'" - app_logger.warning(f"Edit failed - text not found in {path}") - log_mcp_stat('edit_file', str(path), False, error_msg) - return {'error': error_msg} - - # Count occurrences - occurrence_count = current_content.count(old_text) - - if occurrence_count > 1: - error_msg = f"Text appears {occurrence_count} times in file. Please provide more context to make the match unique." - app_logger.warning(f"Edit failed - ambiguous match in {path}") - log_mcp_stat('edit_file', str(path), False, error_msg) - return {'error': error_msg} - - # Perform replacement - new_content = current_content.replace(old_text, new_text, 1) - path.write_text(new_content, encoding='utf-8') - - app_logger.info(f"Edited file: {path} (1 replacement)") - log_mcp_stat('edit_file', str(path.parent), True) - - return { - 'success': True, - 'path': str(path), - 'changes': 1 - } - - except PermissionError as e: - error_msg = f"Permission denied editing {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('edit_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except UnicodeDecodeError as e: - error_msg = f"Cannot read file (encoding issue): {file_path}" - app_logger.error(error_msg) - log_mcp_stat('edit_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'EncodingError'} - except Exception as e: - error_msg = f"Error editing file {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('edit_file', file_path, False, str(e)) - return {'error': error_msg} - - async def delete_file(self, file_path: str, reason: str) -> Dict[str, Any]: - """Delete a file with user confirmation. - - Always requires user to confirm deletion with the LLM's provided reason. - Note: Ignores .gitignore - allows deleting any file in allowed folders. - """ - try: - path = self.config.normalize_path(file_path) - - # Permission check - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('delete_file', str(path.parent), False, error_msg) - return {'error': error_msg} - - # File must exist - if not path.exists(): - error_msg = f"File not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('delete_file', str(path), False, error_msg) - return {'error': error_msg} - - if not path.is_file(): - error_msg = f"Not a file: {path}" - app_logger.warning(error_msg) - log_mcp_stat('delete_file', str(path), False, error_msg) - return {'error': error_msg} - - # Get file info for confirmation prompt - file_size = path.stat().st_size - file_mtime = datetime.datetime.fromtimestamp(path.stat().st_mtime) - - # User confirmation required - console.print(Panel.fit( - f"[bold yellow]โš ๏ธ DELETE FILE?[/]\n\n" - f"Path: [cyan]{path}[/]\n" - f"Size: {file_size:,} bytes\n" - f"Modified: {file_mtime.strftime('%Y-%m-%d %H:%M:%S')}\n\n" - f"[bold]LLM Reason:[/] {reason}", - title="Delete File Confirmation", - border_style="yellow" - )) - - confirm = typer.confirm("Delete this file?", default=False) - - if not confirm: - app_logger.info(f"User cancelled file deletion: {path}") - log_mcp_stat('delete_file', str(path), False, "User cancelled") - return { - 'success': False, - 'user_cancelled': True, - 'path': str(path) - } - - # Delete the file - path.unlink() - - app_logger.info(f"Deleted file: {path}") - log_mcp_stat('delete_file', str(path.parent), True) - - return { - 'success': True, - 'path': str(path), - 'user_cancelled': False - } - - except PermissionError as e: - error_msg = f"Permission denied deleting {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('delete_file', file_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except Exception as e: - error_msg = f"Error deleting file {file_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('delete_file', file_path, False, str(e)) - return {'error': error_msg} - - async def create_directory(self, dir_path: str) -> Dict[str, Any]: - """Create a directory (and parent directories if needed). - - Note: Ignores .gitignore - allows creating directories anywhere in allowed folders. - """ - try: - path = self.config.normalize_path(dir_path) - - # Permission check - if not self.is_allowed_path(path): - error_msg = f"Access denied: {path} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('create_directory', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Check if already exists - already_exists = path.exists() - - if already_exists and not path.is_dir(): - error_msg = f"Path exists but is not a directory: {path}" - app_logger.warning(error_msg) - log_mcp_stat('create_directory', str(path), False, error_msg) - return {'error': error_msg} - - # Create directory (and parents) - path.mkdir(parents=True, exist_ok=True) - - if already_exists: - app_logger.info(f"Directory already exists: {path}") - else: - app_logger.info(f"Created directory: {path}") - - log_mcp_stat('create_directory', str(path.parent), True) - - return { - 'success': True, - 'path': str(path), - 'created': not already_exists - } - - except PermissionError as e: - error_msg = f"Permission denied creating directory {dir_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('create_directory', dir_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except Exception as e: - error_msg = f"Error creating directory {dir_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('create_directory', dir_path, False, str(e)) - return {'error': error_msg} - - async def move_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: - """Move or rename a file. - - Note: Ignores .gitignore - allows moving any file within allowed folders. - """ - try: - source = self.config.normalize_path(source_path) - dest = self.config.normalize_path(dest_path) - - # Both paths must be in allowed folders - if not self.is_allowed_path(source): - error_msg = f"Access denied: source {source} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('move_file', str(source.parent), False, error_msg) - return {'error': error_msg} - - if not self.is_allowed_path(dest): - error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('move_file', str(dest.parent), False, error_msg) - return {'error': error_msg} - - # Source must exist - if not source.exists(): - error_msg = f"Source file not found: {source}" - app_logger.warning(error_msg) - log_mcp_stat('move_file', str(source), False, error_msg) - return {'error': error_msg} - - if not source.is_file(): - error_msg = f"Source is not a file: {source}" - app_logger.warning(error_msg) - log_mcp_stat('move_file', str(source), False, error_msg) - return {'error': error_msg} - - # Create destination parent directory if needed - dest.parent.mkdir(parents=True, exist_ok=True) - - # Move/rename the file - import shutil - shutil.move(str(source), str(dest)) - - app_logger.info(f"Moved file: {source} -> {dest}") - log_mcp_stat('move_file', str(source.parent), True) - - return { - 'success': True, - 'source': str(source), - 'destination': str(dest) - } - - except PermissionError as e: - error_msg = f"Permission denied moving file: {e}" - app_logger.error(error_msg) - log_mcp_stat('move_file', source_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except Exception as e: - error_msg = f"Error moving file from {source_path} to {dest_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('move_file', source_path, False, str(e)) - return {'error': error_msg} - - async def copy_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: - """Copy a file to a new location. - - Note: Ignores .gitignore - allows copying any file within allowed folders. - """ - try: - source = self.config.normalize_path(source_path) - dest = self.config.normalize_path(dest_path) - - # Both paths must be in allowed folders - if not self.is_allowed_path(source): - error_msg = f"Access denied: source {source} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('copy_file', str(source.parent), False, error_msg) - return {'error': error_msg} - - if not self.is_allowed_path(dest): - error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" - app_logger.warning(error_msg) - log_mcp_stat('copy_file', str(dest.parent), False, error_msg) - return {'error': error_msg} - - # Source must exist - if not source.exists(): - error_msg = f"Source file not found: {source}" - app_logger.warning(error_msg) - log_mcp_stat('copy_file', str(source), False, error_msg) - return {'error': error_msg} - - if not source.is_file(): - error_msg = f"Source is not a file: {source}" - app_logger.warning(error_msg) - log_mcp_stat('copy_file', str(source), False, error_msg) - return {'error': error_msg} - - # Create destination parent directory if needed - dest.parent.mkdir(parents=True, exist_ok=True) - - # Copy the file - import shutil - shutil.copy2(str(source), str(dest)) - - file_size = dest.stat().st_size - - app_logger.info(f"Copied file: {source} -> {dest} ({file_size} bytes)") - log_mcp_stat('copy_file', str(source.parent), True) - - return { - 'success': True, - 'source': str(source), - 'destination': str(dest), - 'size': file_size - } - - except PermissionError as e: - error_msg = f"Permission denied copying file: {e}" - app_logger.error(error_msg) - log_mcp_stat('copy_file', source_path, False, str(e)) - return {'error': error_msg, 'error_type': 'PermissionError'} - except Exception as e: - error_msg = f"Error copying file from {source_path} to {dest_path}: {e}" - app_logger.error(error_msg) - log_mcp_stat('copy_file', source_path, False, str(e)) - return {'error': error_msg} - - # ======================================================================== - # SQLite DATABASE METHODS - # ======================================================================== - - async def inspect_database(self, db_path: str, table_name: Optional[str] = None) -> Dict[str, Any]: - """Inspect SQLite database schema.""" - try: - path = Path(db_path).resolve() - - if not path.exists(): - error_msg = f"Database not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('inspect_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - if not path.is_file(): - error_msg = f"Not a file: {path}" - app_logger.warning(error_msg) - log_mcp_stat('inspect_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Open database in read-only mode - try: - conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) - cursor = conn.cursor() - except sqlite3.DatabaseError as e: - error_msg = f"Not a valid SQLite database: {path} ({e})" - app_logger.warning(error_msg) - log_mcp_stat('inspect_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - try: - if table_name: - # Get info for specific table - cursor.execute("SELECT sql FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) - result = cursor.fetchone() - - if not result: - return {'error': f"Table '{table_name}' not found"} - - create_sql = result[0] - - # Get column info - cursor.execute(f"PRAGMA table_info({table_name})") - columns = [] - for row in cursor.fetchall(): - columns.append({ - 'id': row[0], - 'name': row[1], - 'type': row[2], - 'not_null': bool(row[3]), - 'default_value': row[4], - 'primary_key': bool(row[5]) - }) - - # Get indexes - cursor.execute(f"PRAGMA index_list({table_name})") - indexes = [] - for row in cursor.fetchall(): - indexes.append({ - 'name': row[1], - 'unique': bool(row[2]) - }) - - # Get row count - cursor.execute(f"SELECT COUNT(*) FROM {table_name}") - row_count = cursor.fetchone()[0] - - app_logger.info(f"Inspected table: {table_name} in {path}") - log_mcp_stat('inspect_database', str(path.parent), True) - - return { - 'database': str(path), - 'table': table_name, - 'create_sql': create_sql, - 'columns': columns, - 'indexes': indexes, - 'row_count': row_count - } - else: - # Get all tables - cursor.execute("SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name") - tables = [] - for row in cursor.fetchall(): - table = row[0] - # Get row count - cursor.execute(f"SELECT COUNT(*) FROM {table}") - count = cursor.fetchone()[0] - tables.append({ - 'name': table, - 'row_count': count - }) - - # Get indexes - cursor.execute("SELECT name, tbl_name FROM sqlite_master WHERE type='index' ORDER BY name") - indexes = [{'name': row[0], 'table': row[1]} for row in cursor.fetchall()] - - # Get triggers - cursor.execute("SELECT name, tbl_name FROM sqlite_master WHERE type='trigger' ORDER BY name") - triggers = [{'name': row[0], 'table': row[1]} for row in cursor.fetchall()] - - # Get database size - db_size = path.stat().st_size - - app_logger.info(f"Inspected database: {path} ({len(tables)} tables)") - log_mcp_stat('inspect_database', str(path.parent), True) - - return { - 'database': str(path), - 'size': db_size, - 'size_mb': db_size / (1024 * 1024), - 'tables': tables, - 'table_count': len(tables), - 'indexes': indexes, - 'triggers': triggers - } - finally: - conn.close() - - except Exception as e: - error_msg = f"Error inspecting database: {e}" - app_logger.error(error_msg) - log_mcp_stat('inspect_database', db_path, False, str(e)) - return {'error': error_msg} - - async def search_database(self, db_path: str, search_term: str, table_name: Optional[str] = None, column_name: Optional[str] = None) -> Dict[str, Any]: - """Search for a value across database tables.""" - try: - path = Path(db_path).resolve() - - if not path.exists(): - error_msg = f"Database not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('search_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Open database in read-only mode - try: - conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) - cursor = conn.cursor() - except sqlite3.DatabaseError as e: - error_msg = f"Not a valid SQLite database: {path} ({e})" - app_logger.warning(error_msg) - log_mcp_stat('search_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - try: - matches = [] - - # Get tables to search - if table_name: - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name=?", (table_name,)) - tables = [row[0] for row in cursor.fetchall()] - if not tables: - return {'error': f"Table '{table_name}' not found"} - else: - cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") - tables = [row[0] for row in cursor.fetchall()] - - for table in tables: - # Get columns for this table - cursor.execute(f"PRAGMA table_info({table})") - columns = [row[1] for row in cursor.fetchall()] - - # Filter columns if specified - if column_name: - if column_name not in columns: - continue - columns = [column_name] - - # Search in each column - for column in columns: - try: - # Use LIKE for partial matching - query = f"SELECT * FROM {table} WHERE {column} LIKE ? LIMIT {self.default_query_limit}" - cursor.execute(query, (f'%{search_term}%',)) - results = cursor.fetchall() - - if results: - # Get column names - col_names = [desc[0] for desc in cursor.description] - - for row in results: - # Convert row to dict, handling binary data - row_dict = {} - for col_name, value in zip(col_names, row): - if isinstance(value, bytes): - # Convert bytes to hex string or base64 - try: - # Try to decode as UTF-8 first - row_dict[col_name] = value.decode('utf-8') - except UnicodeDecodeError: - # If binary, show as hex string - row_dict[col_name] = f"" - else: - row_dict[col_name] = value - - matches.append({ - 'table': table, - 'column': column, - 'row': row_dict - }) - except sqlite3.Error: - # Skip columns that can't be searched (e.g., BLOBs) - continue - - app_logger.info(f"Searched database: {path} for '{search_term}' - found {len(matches)} matches") - log_mcp_stat('search_database', str(path.parent), True) - - result = { - 'database': str(path), - 'search_term': search_term, - 'matches': matches, - 'count': len(matches) - } - - if table_name: - result['table_filter'] = table_name - if column_name: - result['column_filter'] = column_name - - if len(matches) >= self.default_query_limit: - result['note'] = f'Results limited to {self.default_query_limit} matches. Use more specific search or query_database for full results.' - - return result - - finally: - conn.close() - - except Exception as e: - error_msg = f"Error searching database: {e}" - app_logger.error(error_msg) - log_mcp_stat('search_database', db_path, False, str(e)) - return {'error': error_msg} - - async def query_database(self, db_path: str, query: str, limit: Optional[int] = None) -> Dict[str, Any]: - """Execute a read-only SQL query on database.""" - try: - path = Path(db_path).resolve() - - if not path.exists(): - error_msg = f"Database not found: {path}" - app_logger.warning(error_msg) - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Validate query - is_safe, error_msg = SQLiteQueryValidator.is_safe_query(query) - if not is_safe: - app_logger.warning(f"Unsafe query rejected: {error_msg}") - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - # Open database in read-only mode - try: - conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) - cursor = conn.cursor() - except sqlite3.DatabaseError as e: - error_msg = f"Not a valid SQLite database: {path} ({e})" - app_logger.warning(error_msg) - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - - try: - # Set query timeout - def timeout_handler(signum, frame): - raise TimeoutError(f"Query exceeded {self.max_query_timeout} second timeout") - - # Note: signal.alarm only works on Unix-like systems - if hasattr(signal, 'SIGALRM'): - signal.signal(signal.SIGALRM, timeout_handler) - signal.alarm(self.max_query_timeout) - - try: - # Execute query - cursor.execute(query) - - # Get results - result_limit = min(limit or self.default_query_limit, self.max_query_results) - results = cursor.fetchmany(result_limit) - - # Get column names - columns = [desc[0] for desc in cursor.description] if cursor.description else [] - - # Convert to list of dicts, handling binary data - rows = [] - for row in results: - row_dict = {} - for col_name, value in zip(columns, row): - if isinstance(value, bytes): - # Convert bytes to readable format - try: - # Try to decode as UTF-8 first - row_dict[col_name] = value.decode('utf-8') - except UnicodeDecodeError: - # If binary, show as hex string or summary - row_dict[col_name] = f"" - else: - row_dict[col_name] = value - rows.append(row_dict) - - # Check if more results available - has_more = len(results) == result_limit - if has_more: - # Try to fetch one more to check - one_more = cursor.fetchone() - has_more = one_more is not None - - finally: - if hasattr(signal, 'SIGALRM'): - signal.alarm(0) # Cancel timeout - - app_logger.info(f"Executed query on {path}: returned {len(rows)} rows") - log_mcp_stat('query_database', str(path.parent), True) - - result = { - 'database': str(path), - 'query': query, - 'columns': columns, - 'rows': rows, - 'count': len(rows) - } - - if has_more: - result['truncated'] = True - result['note'] = f'Results limited to {result_limit} rows. Use LIMIT clause in query for more control.' - - return result - - except TimeoutError as e: - error_msg = str(e) - app_logger.warning(error_msg) - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - except sqlite3.Error as e: - error_msg = f"SQL error: {e}" - app_logger.warning(error_msg) - log_mcp_stat('query_database', str(path.parent), False, error_msg) - return {'error': error_msg} - finally: - if hasattr(signal, 'SIGALRM'): - signal.alarm(0) # Ensure timeout is cancelled - conn.close() - - except Exception as e: - error_msg = f"Error executing query: {e}" - app_logger.error(error_msg) - log_mcp_stat('query_database', db_path, False, str(e)) - return {'error': error_msg} - - -class MCPManager: - """Manage MCP server lifecycle, tool calls, and mode switching.""" - - def __init__(self): - self.enabled = False - self.write_enabled = False # Write mode off by default (non-persistent) - self.mode = "files" # "files" or "database" - self.selected_db_index = None - - self.server: Optional[MCPFilesystemServer] = None - - # File/folder mode - self.allowed_folders: List[Path] = [] - - # Database mode - self.databases: List[Dict[str, Any]] = [] - - self.config = CrossPlatformMCPConfig() - self.session_start_time: Optional[datetime.datetime] = None - - # Load persisted data - self._load_folders() - self._load_databases() - - app_logger.info("MCP Manager initialized") - - def _load_folders(self): - """Load allowed folders from database.""" - folders_json = get_mcp_config('allowed_folders') - if folders_json: - try: - folder_paths = json.loads(folders_json) - self.allowed_folders = [self.config.normalize_path(p) for p in folder_paths] - app_logger.info(f"Loaded {len(self.allowed_folders)} folders from config") - except Exception as e: - app_logger.error(f"Error loading MCP folders: {e}") - self.allowed_folders = [] - - def _save_folders(self): - """Save allowed folders to database.""" - folder_paths = [str(p) for p in self.allowed_folders] - set_mcp_config('allowed_folders', json.dumps(folder_paths)) - app_logger.info(f"Saved {len(self.allowed_folders)} folders to config") - - def _load_databases(self): - """Load databases from database.""" - create_table_if_not_exists() - try: - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('SELECT id, path, name, size, tables, added_timestamp FROM mcp_databases ORDER BY id') - self.databases = [] - for row in cursor.fetchall(): - tables_list = json.loads(row[4]) if row[4] else [] - self.databases.append({ - 'id': row[0], - 'path': row[1], - 'name': row[2], - 'size': row[3], - 'tables': tables_list, - 'added': row[5] - }) - app_logger.info(f"Loaded {len(self.databases)} databases from config") - except Exception as e: - app_logger.error(f"Error loading databases: {e}") - self.databases = [] - - def _save_database(self, db_info: Dict[str, Any]): - """Save a database to config.""" - create_table_if_not_exists() - try: - with sqlite3.connect(str(database)) as conn: - conn.execute( - 'INSERT INTO mcp_databases (path, name, size, tables, added_timestamp) VALUES (?, ?, ?, ?, ?)', - (db_info['path'], db_info['name'], db_info['size'], json.dumps(db_info['tables']), db_info['added']) - ) - conn.commit() - # Get the ID - cursor = conn.execute('SELECT id FROM mcp_databases WHERE path = ?', (db_info['path'],)) - db_info['id'] = cursor.fetchone()[0] - app_logger.info(f"Saved database {db_info['name']} to config") - except Exception as e: - app_logger.error(f"Error saving database: {e}") - - def _remove_database_from_config(self, db_path: str): - """Remove database from config.""" - create_table_if_not_exists() - try: - with sqlite3.connect(str(database)) as conn: - conn.execute('DELETE FROM mcp_databases WHERE path = ?', (db_path,)) - conn.commit() - app_logger.info(f"Removed database from config: {db_path}") - except Exception as e: - app_logger.error(f"Error removing database from config: {e}") - - def enable(self) -> Dict[str, Any]: - """Enable MCP server.""" - if not MCP_AVAILABLE: - return { - 'success': False, - 'error': 'MCP library not installed. Run: pip install mcp' - } - - if self.enabled: - return { - 'success': False, - 'error': 'MCP is already enabled' - } - - try: - self.server = MCPFilesystemServer(self.allowed_folders) - self.enabled = True - self.session_start_time = datetime.datetime.now() - - set_mcp_config('mcp_enabled', 'true') - - app_logger.info("MCP Filesystem Server enabled") - - return { - 'success': True, - 'folder_count': len(self.allowed_folders), - 'database_count': len(self.databases), - 'message': 'MCP Filesystem Server started successfully' - } - except Exception as e: - app_logger.error(f"Error enabling MCP: {e}") - return { - 'success': False, - 'error': str(e) - } - - def disable(self) -> Dict[str, Any]: - """Disable MCP server.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled' - } - - try: - self.server = None - self.enabled = False - self.session_start_time = None - self.mode = "files" - self.selected_db_index = None - - set_mcp_config('mcp_enabled', 'false') - - app_logger.info("MCP Filesystem Server disabled") - - return { - 'success': True, - 'message': 'MCP Filesystem Server stopped' - } - except Exception as e: - app_logger.error(f"Error disabling MCP: {e}") - return { - 'success': False, - 'error': str(e) - } - - def enable_write(self) -> None: - """Enable write mode with user confirmation.""" - if not self.enabled: - console.print("[bold red]Error:[/] MCP must be enabled first. Use '/mcp on'") - app_logger.warning("Attempted to enable write mode without MCP enabled") - return - - # Show warning panel - console.print(Panel.fit( - "[bold yellow]โš ๏ธ WRITE MODE WARNING[/]\n\n" - "Enabling write mode allows the LLM to:\n" - "โ€ข Create new files\n" - "โ€ข Modify existing files\n" - "โ€ข Delete files (with confirmation)\n" - "โ€ข Create directories\n" - "โ€ข Move and copy files\n\n" - "[dim]Operations are limited to your allowed MCP folders.[/]\n" - "[dim]Deletions will always require your confirmation.[/]\n\n" - "[bold red]Use with caution. Review LLM changes carefully.[/]", - title="MCP Write Mode", - border_style="yellow" - )) - - confirm = typer.confirm("Enable write mode?", default=False) - - if confirm: - self.write_enabled = True - console.print("[bold green]โœ“ Write mode enabled[/]") - app_logger.info("MCP write mode enabled by user") - log_mcp_stat('write_mode_enabled', '', True) - else: - console.print("[yellow]Cancelled. Write mode remains disabled.[/]") - app_logger.info("User cancelled write mode enablement") - - def disable_write(self) -> None: - """Disable write mode.""" - if not self.write_enabled: - console.print("[dim]Write mode is already disabled.[/]") - return - - self.write_enabled = False - console.print("[bold green]โœ“ Write mode disabled[/]") - app_logger.info("MCP write mode disabled") - log_mcp_stat('write_mode_disabled', '', True) - - def switch_mode(self, new_mode: str, db_index: Optional[int] = None) -> Dict[str, Any]: - """Switch between files and database mode.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled. Use /mcp on first' - } - - if new_mode == "files": - self.mode = "files" - self.selected_db_index = None - app_logger.info("Switched to file mode") - return { - 'success': True, - 'mode': 'files', - 'message': 'Switched to file mode', - 'tools': ['read_file', 'list_directory', 'search_files'] - } - elif new_mode == "database": - if db_index is None: - return { - 'success': False, - 'error': 'Database index required. Use /mcp db ' - } - - if db_index < 1 or db_index > len(self.databases): - return { - 'success': False, - 'error': f'Invalid database number. Use 1-{len(self.databases)}' - } - - self.mode = "database" - self.selected_db_index = db_index - 1 # Convert to 0-based - db = self.databases[self.selected_db_index] - - app_logger.info(f"Switched to database mode: {db['name']}") - return { - 'success': True, - 'mode': 'database', - 'database': db, - 'message': f"Switched to database #{db_index}: {db['name']}", - 'tools': ['inspect_database', 'search_database', 'query_database'] - } - else: - return { - 'success': False, - 'error': f'Invalid mode: {new_mode}' - } - - def add_folder(self, folder_path: str) -> Dict[str, Any]: - """Add folder to allowed list.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled. Use /mcp on first' - } - - try: - path = self.config.normalize_path(folder_path) - - # Validation checks - if not path.exists(): - return { - 'success': False, - 'error': f'Directory does not exist: {path}' - } - - if not path.is_dir(): - return { - 'success': False, - 'error': f'Not a directory: {path}' - } - - if self.config.is_system_directory(path): - return { - 'success': False, - 'error': f'Cannot add system directory: {path}' - } - - # Check if already added - if path in self.allowed_folders: - return { - 'success': False, - 'error': f'Folder already in allowed list: {path}' - } - - # Check for nested paths - parent_folder = None - for existing in self.allowed_folders: - try: - path.relative_to(existing) - parent_folder = existing - break - except ValueError: - continue - - # Get folder stats - stats = self.config.get_folder_stats(path) - - # Add folder - self.allowed_folders.append(path) - self._save_folders() - - # Update server and reload gitignores - if self.server: - self.server.allowed_folders = self.allowed_folders - self.server.reload_gitignores() - - app_logger.info(f"Added folder to MCP: {path}") - - result = { - 'success': True, - 'path': str(path), - 'stats': stats, - 'total_folders': len(self.allowed_folders) - } - - if parent_folder: - result['warning'] = f'Note: {parent_folder} is already allowed (parent folder)' - - return result - - except Exception as e: - app_logger.error(f"Error adding folder: {e}") - return { - 'success': False, - 'error': str(e) - } - - def remove_folder(self, folder_ref: str) -> Dict[str, Any]: - """Remove folder from allowed list (by path or number).""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled. Use /mcp on first' - } - - try: - # Check if it's a number - if folder_ref.isdigit(): - index = int(folder_ref) - 1 - if 0 <= index < len(self.allowed_folders): - path = self.allowed_folders[index] - else: - return { - 'success': False, - 'error': f'Invalid folder number: {folder_ref}' - } - else: - # Treat as path - path = self.config.normalize_path(folder_ref) - if path not in self.allowed_folders: - return { - 'success': False, - 'error': f'Folder not in allowed list: {path}' - } - - # Remove folder - self.allowed_folders.remove(path) - self._save_folders() - - # Update server and reload gitignores - if self.server: - self.server.allowed_folders = self.allowed_folders - self.server.reload_gitignores() - - app_logger.info(f"Removed folder from MCP: {path}") - - result = { - 'success': True, - 'path': str(path), - 'total_folders': len(self.allowed_folders) - } - - if len(self.allowed_folders) == 0: - result['warning'] = 'No allowed folders remaining. Add one with /mcp add ' - - return result - - except Exception as e: - app_logger.error(f"Error removing folder: {e}") - return { - 'success': False, - 'error': str(e) - } - - def add_database(self, db_path: str) -> Dict[str, Any]: - """Add SQLite database.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled. Use /mcp on first' - } - - try: - path = Path(db_path).resolve() - - # Validation - if not path.exists(): - return { - 'success': False, - 'error': f'Database file not found: {path}' - } - - if not path.is_file(): - return { - 'success': False, - 'error': f'Not a file: {path}' - } - - # Check if already added - if any(db['path'] == str(path) for db in self.databases): - return { - 'success': False, - 'error': f'Database already added: {path.name}' - } - - # Validate it's a SQLite database and get schema - try: - conn = sqlite3.connect(f'file:{path}?mode=ro', uri=True) - cursor = conn.cursor() - - # Get tables - cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name") - tables = [row[0] for row in cursor.fetchall()] - - conn.close() - except sqlite3.DatabaseError as e: - return { - 'success': False, - 'error': f'Not a valid SQLite database: {e}' - } - - # Get file size - db_size = path.stat().st_size - - # Create database entry - db_info = { - 'path': str(path), - 'name': path.name, - 'size': db_size, - 'tables': tables, - 'added': datetime.datetime.now().isoformat() - } - - # Save to config - self._save_database(db_info) - - # Add to list - self.databases.append(db_info) - - app_logger.info(f"Added database: {path.name}") - - return { - 'success': True, - 'database': db_info, - 'number': len(self.databases), - 'message': f'Added database #{len(self.databases)}: {path.name}' - } - - except Exception as e: - app_logger.error(f"Error adding database: {e}") - return { - 'success': False, - 'error': str(e) - } - - def remove_database(self, db_ref: str) -> Dict[str, Any]: - """Remove database (by number or path).""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled' - } - - try: - # Check if it's a number - if db_ref.isdigit(): - index = int(db_ref) - 1 - if 0 <= index < len(self.databases): - db = self.databases[index] - else: - return { - 'success': False, - 'error': f'Invalid database number: {db_ref}' - } - else: - # Treat as path - path = str(Path(db_ref).resolve()) - db = next((d for d in self.databases if d['path'] == path), None) - if not db: - return { - 'success': False, - 'error': f'Database not found: {db_ref}' - } - index = self.databases.index(db) - - # If currently selected, deselect - if self.mode == "database" and self.selected_db_index == index: - self.mode = "files" - self.selected_db_index = None - - # Remove from config - self._remove_database_from_config(db['path']) - - # Remove from list - self.databases.pop(index) - - app_logger.info(f"Removed database: {db['name']}") - - result = { - 'success': True, - 'database': db, - 'message': f"Removed database: {db['name']}" - } - - if len(self.databases) == 0: - result['warning'] = 'No databases remaining. Add one with /mcp add db ' - - return result - - except Exception as e: - app_logger.error(f"Error removing database: {e}") - return { - 'success': False, - 'error': str(e) - } - - def list_databases(self) -> Dict[str, Any]: - """List all databases.""" - try: - db_list = [] - for idx, db in enumerate(self.databases, 1): - db_info = { - 'number': idx, - 'name': db['name'], - 'path': db['path'], - 'size_mb': db['size'] / (1024 * 1024), - 'tables': db['tables'], - 'table_count': len(db['tables']), - 'added': db['added'] - } - - # Check if file still exists - if not Path(db['path']).exists(): - db_info['warning'] = 'File not found' - - db_list.append(db_info) - - return { - 'success': True, - 'databases': db_list, - 'count': len(db_list) - } - - except Exception as e: - app_logger.error(f"Error listing databases: {e}") - return { - 'success': False, - 'error': str(e) - } - - def toggle_gitignore(self, enabled: bool) -> Dict[str, Any]: - """Toggle .gitignore filtering.""" - if not self.enabled: - return { - 'success': False, - 'error': 'MCP is not enabled' - } - - if not self.server: - return { - 'success': False, - 'error': 'MCP server not running' - } - - self.server.respect_gitignore = enabled - status = "enabled" if enabled else "disabled" - - app_logger.info(f".gitignore filtering {status}") - - return { - 'success': True, - 'message': f'.gitignore filtering {status}', - 'pattern_count': len(self.server.gitignore_parser.patterns) if enabled else 0 - } - - def list_folders(self) -> Dict[str, Any]: - """List all allowed folders with stats.""" - try: - folders_info = [] - total_files = 0 - total_size = 0 - - for idx, folder in enumerate(self.allowed_folders, 1): - stats = self.config.get_folder_stats(folder) - - folder_info = { - 'number': idx, - 'path': str(folder), - 'exists': stats.get('exists', False) - } - - if stats.get('exists'): - folder_info['file_count'] = stats['file_count'] - folder_info['size_mb'] = stats['size_mb'] - total_files += stats['file_count'] - total_size += stats['total_size'] - - folders_info.append(folder_info) - - return { - 'success': True, - 'folders': folders_info, - 'total_folders': len(self.allowed_folders), - 'total_files': total_files, - 'total_size_mb': total_size / (1024 * 1024) - } - - except Exception as e: - app_logger.error(f"Error listing folders: {e}") - return { - 'success': False, - 'error': str(e) - } - - def get_status(self) -> Dict[str, Any]: - """Get comprehensive MCP status.""" - try: - stats = get_mcp_stats() - - uptime = None - if self.session_start_time: - delta = datetime.datetime.now() - self.session_start_time - hours = delta.seconds // 3600 - minutes = (delta.seconds % 3600) // 60 - uptime = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m" - - folder_info = self.list_folders() - - gitignore_status = "enabled" if (self.server and self.server.respect_gitignore) else "disabled" - gitignore_patterns = len(self.server.gitignore_parser.patterns) if self.server else 0 - - # Current mode info - mode_info = { - 'mode': self.mode, - 'mode_display': '๐Ÿ”ง Files' if self.mode == 'files' else f'๐Ÿ—„๏ธ DB #{self.selected_db_index + 1}' - } - - if self.mode == 'database' and self.selected_db_index is not None: - db = self.databases[self.selected_db_index] - mode_info['database'] = db - - return { - 'success': True, - 'enabled': self.enabled, - 'write_enabled': self.write_enabled, - 'uptime': uptime, - 'mode_info': mode_info, - 'folder_count': len(self.allowed_folders), - 'database_count': len(self.databases), - 'total_files': folder_info.get('total_files', 0), - 'total_size_mb': folder_info.get('total_size_mb', 0), - 'stats': stats, - 'tools_available': self._get_current_tools(), - 'gitignore_status': gitignore_status, - 'gitignore_patterns': gitignore_patterns - } - - except Exception as e: - app_logger.error(f"Error getting MCP status: {e}") - return { - 'success': False, - 'error': str(e) - } - - def _get_current_tools(self) -> List[str]: - """Get tools available in current mode.""" - if not self.enabled: - return [] - - if self.mode == "files": - return ['read_file', 'list_directory', 'search_files'] - elif self.mode == "database": - return ['inspect_database', 'search_database', 'query_database'] - - return [] - - async def call_tool(self, tool_name: str, **kwargs) -> Dict[str, Any]: - """Call an MCP tool.""" - if not self.enabled or not self.server: - return { - 'error': 'MCP is not enabled' - } - - # Check write permissions for write operations - write_tools = {'write_file', 'edit_file', 'delete_file', 'create_directory', 'move_file', 'copy_file'} - if tool_name in write_tools and not self.write_enabled: - return { - 'error': 'Write operations disabled. Enable with /mcp write on' - } - - try: - # File mode tools (read-only) - if tool_name == 'read_file': - return await self.server.read_file(kwargs.get('file_path', '')) - elif tool_name == 'list_directory': - return await self.server.list_directory( - kwargs.get('dir_path', ''), - kwargs.get('recursive', False) - ) - elif tool_name == 'search_files': - return await self.server.search_files( - kwargs.get('pattern', ''), - kwargs.get('search_path'), - kwargs.get('content_search', False) - ) - - # Database mode tools - elif tool_name == 'inspect_database': - if self.mode != 'database' or self.selected_db_index is None: - return {'error': 'Not in database mode. Use /mcp db first'} - db = self.databases[self.selected_db_index] - return await self.server.inspect_database( - db['path'], - kwargs.get('table_name') - ) - elif tool_name == 'search_database': - if self.mode != 'database' or self.selected_db_index is None: - return {'error': 'Not in database mode. Use /mcp db first'} - db = self.databases[self.selected_db_index] - return await self.server.search_database( - db['path'], - kwargs.get('search_term', ''), - kwargs.get('table_name'), - kwargs.get('column_name') - ) - elif tool_name == 'query_database': - if self.mode != 'database' or self.selected_db_index is None: - return {'error': 'Not in database mode. Use /mcp db first'} - db = self.databases[self.selected_db_index] - return await self.server.query_database( - db['path'], - kwargs.get('query', ''), - kwargs.get('limit') - ) - - # Write mode tools - elif tool_name == 'write_file': - return await self.server.write_file( - kwargs.get('file_path', ''), - kwargs.get('content', '') - ) - elif tool_name == 'edit_file': - return await self.server.edit_file( - kwargs.get('file_path', ''), - kwargs.get('old_text', ''), - kwargs.get('new_text', '') - ) - elif tool_name == 'delete_file': - return await self.server.delete_file( - kwargs.get('file_path', ''), - kwargs.get('reason', 'No reason provided') - ) - elif tool_name == 'create_directory': - return await self.server.create_directory( - kwargs.get('dir_path', '') - ) - elif tool_name == 'move_file': - return await self.server.move_file( - kwargs.get('source_path', ''), - kwargs.get('dest_path', '') - ) - elif tool_name == 'copy_file': - return await self.server.copy_file( - kwargs.get('source_path', ''), - kwargs.get('dest_path', '') - ) - else: - return { - 'error': f'Unknown tool: {tool_name}' - } - except Exception as e: - app_logger.error(f"Error calling MCP tool {tool_name}: {e}") - return { - 'error': str(e) - } - - def get_tools_schema(self) -> List[Dict[str, Any]]: - """Get MCP tools as OpenAI function calling schema (for current mode).""" - if not self.enabled: - return [] - - if self.mode == "files": - return self._get_file_tools_schema() - elif self.mode == "database": - return self._get_database_tools_schema() - - return [] - - def _get_file_tools_schema(self) -> List[Dict[str, Any]]: - """Get file mode tools schema.""" - if len(self.allowed_folders) == 0: - return [] - - allowed_dirs_str = ", ".join(str(f) for f in self.allowed_folders) - - # Base read-only tools - tools = [ - { - "type": "function", - "function": { - "name": "search_files", - "description": f"Search for files in the user's local filesystem. Allowed directories: {allowed_dirs_str}. Can search by filename pattern (e.g., '*.py' for Python files) or search within file contents. Automatically respects .gitignore patterns and excludes virtual environments.", - "parameters": { - "type": "object", - "properties": { - "pattern": { - "type": "string", - "description": "Search pattern. For filename search, use glob patterns like '*.py', '*.txt', 'report*'. For content search, use plain text to search for." - }, - "content_search": { - "type": "boolean", - "description": "If true, searches inside file contents for the pattern (SLOW - use only when specifically asked to search file contents). If false, searches only filenames (FAST - use for finding files by name). Default is false. Virtual environments and package directories are automatically excluded.", - "default": False - }, - "search_path": { - "type": "string", - "description": "Optional: specific directory to search in. If not provided, searches all allowed directories.", - } - }, - "required": ["pattern"] - } - } - }, - { - "type": "function", - "function": { - "name": "read_file", - "description": "Read the complete contents of a text file from the user's local filesystem. Only works for files within allowed directories. Maximum file size: 10MB. Files larger than 50KB are automatically truncated. Respects .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Full path to the file to read (e.g., /Users/username/Documents/report.txt or ~/Documents/notes.md)" - } - }, - "required": ["file_path"] - } - } - }, - { - "type": "function", - "function": { - "name": "list_directory", - "description": "List all files and subdirectories in a directory from the user's local filesystem. Only works for directories within allowed paths. Automatically filters out virtual environments, build artifacts, and .gitignore patterns. Limited to 1000 items.", - "parameters": { - "type": "object", - "properties": { - "dir_path": { - "type": "string", - "description": "Directory path to list (e.g., ~/Documents or /Users/username/Projects)" - }, - "recursive": { - "type": "boolean", - "description": "If true, lists all files in subdirectories recursively. If false, lists only immediate children. Default is true. WARNING: Recursive listings can be very large - use specific paths when possible.", - "default": True - } - }, - "required": ["dir_path"] - } - } - } - ] - - # Add write tools if write mode is enabled - if self.write_enabled: - tools.extend([ - { - "type": "function", - "function": { - "name": "write_file", - "description": f"Create a new file or overwrite an existing file with the specified content. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns - can write to any file in allowed folders. Automatically creates parent directories if needed.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Full path to the file to create or overwrite (e.g., /Users/username/project/src/main.py)" - }, - "content": { - "type": "string", - "description": "The complete content to write to the file" - } - }, - "required": ["file_path", "content"] - } - } - }, - { - "type": "function", - "function": { - "name": "edit_file", - "description": f"Make targeted edits to an existing file by finding and replacing specific text. The old_text must match exactly and appear only once in the file. For multiple matches, provide more context to make the match unique. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Full path to the file to edit" - }, - "old_text": { - "type": "string", - "description": "The exact text to find and replace. Must match exactly and appear only once. Include enough context to make the match unique." - }, - "new_text": { - "type": "string", - "description": "The new text to replace the old text with" - } - }, - "required": ["file_path", "old_text", "new_text"] - } - } - }, - { - "type": "function", - "function": { - "name": "delete_file", - "description": f"Delete a file from the filesystem. ALWAYS requires user confirmation before deletion. The user will see the file path, size, modification time, and your reason before deciding. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "file_path": { - "type": "string", - "description": "Full path to the file to delete" - }, - "reason": { - "type": "string", - "description": "Clear explanation for why this file should be deleted. The user will see this reason in the confirmation prompt." - } - }, - "required": ["file_path", "reason"] - } - } - }, - { - "type": "function", - "function": { - "name": "create_directory", - "description": f"Create a new directory (and all parent directories if needed). If the directory already exists, returns success without error. Works within allowed directories: {allowed_dirs_str}. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "dir_path": { - "type": "string", - "description": "Full path to the directory to create (e.g., /Users/username/project/src/components)" - } - }, - "required": ["dir_path"] - } - } - }, - { - "type": "function", - "function": { - "name": "move_file", - "description": f"Move or rename a file. Both source and destination must be within allowed directories: {allowed_dirs_str}. Creates destination parent directories if needed. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "source_path": { - "type": "string", - "description": "Full path to the file to move/rename" - }, - "dest_path": { - "type": "string", - "description": "Full path for the new location/name" - } - }, - "required": ["source_path", "dest_path"] - } - } - }, - { - "type": "function", - "function": { - "name": "copy_file", - "description": f"Copy a file to a new location. Both source and destination must be within allowed directories: {allowed_dirs_str}. Creates destination parent directories if needed. Preserves file metadata. Ignores .gitignore patterns.", - "parameters": { - "type": "object", - "properties": { - "source_path": { - "type": "string", - "description": "Full path to the file to copy" - }, - "dest_path": { - "type": "string", - "description": "Full path for the copy destination" - } - }, - "required": ["source_path", "dest_path"] - } - } - } - ]) - - return tools - - def _get_database_tools_schema(self) -> List[Dict[str, Any]]: - """Get database mode tools schema.""" - if self.selected_db_index is None or self.selected_db_index >= len(self.databases): - return [] - - db = self.databases[self.selected_db_index] - db_name = db['name'] - tables_str = ", ".join(db['tables']) - - return [ - { - "type": "function", - "function": { - "name": "inspect_database", - "description": f"Inspect the schema of the currently selected SQLite database ({db_name}). Can get all tables or details for a specific table including columns, types, indexes, and row counts. Available tables: {tables_str}.", - "parameters": { - "type": "object", - "properties": { - "table_name": { - "type": "string", - "description": f"Optional: specific table to inspect. If not provided, returns info for all tables. Available: {tables_str}" - } - }, - "required": [] - } - } - }, - { - "type": "function", - "function": { - "name": "search_database", - "description": f"Search for a value across all tables in the database ({db_name}). Performs a LIKE search (partial matching) across all columns or specific table/column. Returns matching rows with all column data. Limited to {self.server.default_query_limit} results.", - "parameters": { - "type": "object", - "properties": { - "search_term": { - "type": "string", - "description": "Value to search for (partial match supported)" - }, - "table_name": { - "type": "string", - "description": f"Optional: limit search to specific table. Available: {tables_str}" - }, - "column_name": { - "type": "string", - "description": "Optional: limit search to specific column within the table" - } - }, - "required": ["search_term"] - } - } - }, - { - "type": "function", - "function": { - "name": "query_database", - "description": f"Execute a read-only SQL query on the database ({db_name}). Supports SELECT queries including JOINs, subqueries, CTEs, aggregations, etc. Maximum {self.server.max_query_results} rows returned. Query timeout: {self.server.max_query_timeout} seconds. INSERT/UPDATE/DELETE/DROP are blocked for safety.", - "parameters": { - "type": "object", - "properties": { - "query": { - "type": "string", - "description": f"SQL SELECT query to execute. Available tables: {tables_str}. Example: SELECT * FROM table WHERE column = 'value' LIMIT 10" - }, - "limit": { - "type": "integer", - "description": f"Optional: maximum rows to return (default {self.server.default_query_limit}, max {self.server.max_query_results}). You can also use LIMIT in your query." - } - }, - "required": ["query"] - } - } - } - ] - -# Global MCP manager -mcp_manager = MCPManager() - -# ============================================================================ -# HELPER FUNCTIONS -# ============================================================================ - -def supports_function_calling(model: Dict[str, Any]) -> bool: - """Check if model supports function calling.""" - supported_params = model.get("supported_parameters", []) - return "tools" in supported_params or "functions" in supported_params - -def supports_tools(model: Dict[str, Any]) -> bool: - """Check if model supports tools/function calling (same as supports_function_calling).""" - return supports_function_calling(model) - -def check_for_updates(current_version: str) -> str: - """Check for updates.""" - try: - response = requests.get( - 'https://gitlab.pm/api/v1/repos/rune/oai/releases/latest', - headers={"Content-Type": "application/json"}, - timeout=1.0, - allow_redirects=True - ) - response.raise_for_status() - - data = response.json() - version_online = data.get('tag_name', '').lstrip('v') - - if not version_online: - logger.warning("No version found in API response") - return f"[bold green]oAI version {current_version}[/]" - - current = pkg_version.parse(current_version) - latest = pkg_version.parse(version_online) - - if latest > current: - logger.info(f"Update available: {current_version} โ†’ {version_online}") - return f"[bold green]oAI version {current_version} [/][bold red](Update available: {current_version} โ†’ {version_online})[/]" - else: - logger.debug(f"Already up to date: {current_version}") - return f"[bold green]oAI version {current_version} (up to date)[/]" - - except: - return f"[bold green]oAI version {current_version}[/]" - -def save_conversation(name: str, data: List[Dict[str, str]]): - """Save conversation.""" - timestamp = datetime.datetime.now().isoformat() - data_json = json.dumps(data) - with sqlite3.connect(str(database)) as conn: - conn.execute('INSERT INTO conversation_sessions (name, timestamp, data) VALUES (?, ?, ?)', (name, timestamp, data_json)) - conn.commit() - -def load_conversation(name: str) -> Optional[List[Dict[str, str]]]: - """Load conversation.""" - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('SELECT data FROM conversation_sessions WHERE name = ? ORDER BY timestamp DESC LIMIT 1', (name,)) - result = cursor.fetchone() - if result: - return json.loads(result[0]) - return None - -def delete_conversation(name: str) -> int: - """Delete conversation.""" - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute('DELETE FROM conversation_sessions WHERE name = ?', (name,)) - conn.commit() - return cursor.rowcount - -def list_conversations() -> List[Dict[str, Any]]: - """List conversations.""" - with sqlite3.connect(str(database)) as conn: - cursor = conn.execute(''' - SELECT name, MAX(timestamp) as last_saved, data - FROM conversation_sessions - GROUP BY name - ORDER BY last_saved DESC - ''') - conversations = [] - for row in cursor.fetchall(): - name, timestamp, data_json = row - data = json.loads(data_json) - conversations.append({ - 'name': name, - 'timestamp': timestamp, - 'message_count': len(data) - }) - return conversations - -def estimate_cost(input_tokens: int, output_tokens: int) -> float: - """Estimate cost.""" - return (input_tokens * MODEL_PRICING['input'] / 1_000_000) + (output_tokens * MODEL_PRICING['output'] / 1_000_000) - -def has_web_search_capability(model: Dict[str, Any]) -> bool: - """Check web search capability.""" - supported_params = model.get("supported_parameters", []) - return "tools" in supported_params - -def has_image_capability(model: Dict[str, Any]) -> bool: - """Check image capability.""" - architecture = model.get("architecture", {}) - input_modalities = architecture.get("input_modalities", []) - return "image" in input_modalities - -def supports_online_mode(model: Dict[str, Any]) -> bool: - """Check online mode support.""" - return has_web_search_capability(model) - -def get_effective_model_id(base_model_id: str, online_enabled: bool) -> str: - """Get effective model ID.""" - if online_enabled and not base_model_id.endswith(':online'): - return f"{base_model_id}:online" - return base_model_id - -def export_as_markdown(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export as Markdown.""" - lines = ["# Conversation Export", ""] - if session_system_prompt: - lines.extend([f"**System Prompt:** {session_system_prompt}", ""]) - lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - lines.append("") - lines.append("---") - lines.append("") - - for i, entry in enumerate(session_history, 1): - lines.append(f"## Message {i}") - lines.append("") - lines.append("**User:**") - lines.append("") - lines.append(entry['prompt']) - lines.append("") - lines.append("**Assistant:**") - lines.append("") - lines.append(entry['response']) - lines.append("") - lines.append("---") - lines.append("") - - return "\n".join(lines) - -def export_as_json(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export as JSON.""" - export_data = { - "export_date": datetime.datetime.now().isoformat(), - "system_prompt": session_system_prompt, - "message_count": len(session_history), - "messages": session_history - } - return json.dumps(export_data, indent=2, ensure_ascii=False) - -def export_as_html(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export as HTML.""" - def escape_html(text): - return text.replace('&', '&').replace('<', '<').replace('>', '>').replace('"', '"').replace("'", ''') - - html_parts = [ - "", - "", - "", - " ", - " ", - " Conversation Export", - " ", - "", - "", - "
", - "

๐Ÿ’ฌ Conversation Export

", - f"
๐Ÿ“… Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
", - f"
๐Ÿ“Š Total Messages: {len(session_history)}
", - "
", - ] - - if session_system_prompt: - html_parts.extend([ - "
", - " โš™๏ธ System Prompt", - f"
{escape_html(session_system_prompt)}
", - "
", - ]) - - for i, entry in enumerate(session_history, 1): - html_parts.extend([ - "
", - f"
Message {i} of {len(session_history)}
", - "
", - "
๐Ÿ‘ค User
", - f"
{escape_html(entry['prompt'])}
", - "
", - "
", - "
๐Ÿค– Assistant
", - f"
{escape_html(entry['response'])}
", - "
", - "
", - ]) - - html_parts.extend([ - "
", - "

Generated by oAI Chat โ€ข https://iurl.no/oai

", - "
", - "", - "", - ]) - - return "\n".join(html_parts) - -def show_command_help(command_or_topic: str): - """Display detailed help for command or topic.""" - # Handle topics (like 'mcp') - if not command_or_topic.startswith('/'): - if command_or_topic.lower() == 'mcp': - help_data = COMMAND_HELP['mcp'] - - help_content = [] - help_content.append(f"[bold cyan]Description:[/]") - help_content.append(help_data['description']) - help_content.append("") - help_content.append(help_data['notes']) - - console.print(Panel( - "\n".join(help_content), - title="[bold green]MCP - Model Context Protocol Guide[/]", - title_align="left", - border_style="green", - width=console.width - 4 - )) - - app_logger.info("Displayed MCP comprehensive guide") - return - else: - command_or_topic = '/' + command_or_topic - - if command_or_topic not in COMMAND_HELP: - console.print(f"[bold red]Unknown command: {command_or_topic}[/]") - console.print("[bold yellow]Type /help to see all available commands.[/]") - console.print("[bold yellow]Type /help mcp for comprehensive MCP guide.[/]") - app_logger.warning(f"Help requested for unknown command: {command_or_topic}") - return - - help_data = COMMAND_HELP[command_or_topic] - - help_content = [] - - if 'aliases' in help_data: - aliases_str = ", ".join(help_data['aliases']) - help_content.append(f"[bold cyan]Aliases:[/] {aliases_str}") - help_content.append("") - - help_content.append(f"[bold cyan]Description:[/]") - help_content.append(help_data['description']) - help_content.append("") - - help_content.append(f"[bold cyan]Usage:[/]") - help_content.append(f"[yellow]{help_data['usage']}[/]") - help_content.append("") - - if 'examples' in help_data and help_data['examples']: - help_content.append(f"[bold cyan]Examples:[/]") - for desc, example in help_data['examples']: - if not desc and not example: - help_content.append("") - elif desc.startswith('โ”โ”โ”'): - help_content.append(f"[bold yellow]{desc}[/]") - else: - help_content.append(f" [dim]{desc}:[/]" if desc else "") - help_content.append(f" [green]{example}[/]" if example else "") - help_content.append("") - - if 'notes' in help_data: - help_content.append(f"[bold cyan]Notes:[/]") - help_content.append(f"[dim]{help_data['notes']}[/]") - - console.print(Panel( - "\n".join(help_content), - title=f"[bold green]Help: {command_or_topic}[/]", - title_align="left", - border_style="green", - width=console.width - 4 - )) - - app_logger.info(f"Displayed detailed help for command: {command_or_topic}") - -def get_credits(api_key: str, base_url: str = OPENROUTER_BASE_URL) -> Optional[Dict[str, str]]: - """Get credits.""" - if not api_key: - return None - url = f"{base_url}/credits" - headers = { - "Authorization": f"Bearer {api_key}", - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - try: - response = requests.get(url, headers=headers) - response.raise_for_status() - data = response.json().get('data', {}) - total_credits = float(data.get('total_credits', 0)) - total_usage = float(data.get('total_usage', 0)) - credits_left = total_credits - total_usage - return { - 'total_credits': f"${total_credits:.2f}", - 'used_credits': f"${total_usage:.2f}", - 'credits_left': f"${credits_left:.2f}" - } - except Exception as e: - console.print(f"[bold red]Error fetching credits: {e}[/]") - return None - -def check_credit_alerts(credits_data: Optional[Dict[str, str]]) -> List[str]: - """Check credit alerts.""" - alerts = [] - if credits_data: - credits_left_value = float(credits_data['credits_left'].strip('$')) - total_credits_value = float(credits_data['total_credits'].strip('$')) - if credits_left_value < LOW_CREDIT_AMOUNT: - alerts.append(f"Critical credit alert: Less than ${LOW_CREDIT_AMOUNT:.2f} left ({credits_data['credits_left']})") - elif credits_left_value < total_credits_value * LOW_CREDIT_RATIO: - alerts.append(f"Low credit alert: Credits left < 10% of total ({credits_data['credits_left']})") - return alerts - -def clear_screen(): - """Clear screen.""" - try: - print("\033[H\033[J", end="", flush=True) - except: - print("\n" * 100) - -def display_paginated_table(table: Table, title: str): - """Display paginated table.""" - try: - terminal_height = os.get_terminal_size().lines - 8 - except: - terminal_height = 20 - - from rich.segment import Segment - - segments = list(console.render(table)) - - current_line_segments = [] - all_lines = [] - - for segment in segments: - if segment.text == '\n': - all_lines.append(current_line_segments) - current_line_segments = [] - else: - current_line_segments.append(segment) - - if current_line_segments: - all_lines.append(current_line_segments) - - total_lines = len(all_lines) - - if total_lines <= terminal_height: - console.print(Panel(table, title=title, title_align="left")) - return - - header_lines = [] - data_lines = [] - footer_line = [] - - header_end_index = 0 - found_header_text = False - - for i, line_segments in enumerate(all_lines): - has_header_style = any( - seg.style and ('bold' in str(seg.style) or 'magenta' in str(seg.style)) - for seg in line_segments - ) - - if has_header_style: - found_header_text = True - - if found_header_text and i > 0: - line_text = ''.join(seg.text for seg in line_segments) - if any(char in line_text for char in ['โ”€', 'โ”', 'โ”ผ', 'โ•ช', 'โ”ค', 'โ”œ']): - header_end_index = i - break - - # Extract footer (bottom border) - it's the last line of the table - if all_lines: - last_line_text = ''.join(seg.text for seg in all_lines[-1]) - # Check if last line is a border line (contains box drawing characters) - if any(char in last_line_text for char in ['โ”€', 'โ”', 'โ”ด', 'โ•ง', 'โ”˜', 'โ””']): - footer_line = all_lines[-1] - all_lines = all_lines[:-1] # Remove footer from all_lines - - if header_end_index > 0: - header_lines = all_lines[:header_end_index + 1] - data_lines = all_lines[header_end_index + 1:] - else: - header_lines = all_lines[:min(3, len(all_lines))] - data_lines = all_lines[min(3, len(all_lines)):] - - lines_per_page = terminal_height - len(header_lines) - - current_line = 0 - page_number = 1 - - while current_line < len(data_lines): - clear_screen() - - console.print(f"[bold cyan]{title} (Page {page_number})[/]") - - for line_segments in header_lines: - for segment in line_segments: - console.print(segment.text, style=segment.style, end="") - console.print() - - end_line = min(current_line + lines_per_page, len(data_lines)) - - for line_segments in data_lines[current_line:end_line]: - for segment in line_segments: - console.print(segment.text, style=segment.style, end="") - console.print() - - # Add footer (bottom border) on each page - if footer_line: - for segment in footer_line: - console.print(segment.text, style=segment.style, end="") - console.print() - - current_line = end_line - page_number += 1 - - if current_line < len(data_lines): - console.print(f"\n[dim yellow]--- Press SPACE for next page, or any other key to finish (Page {page_number - 1}, showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]") - try: - import sys - import tty - import termios - - fd = sys.stdin.fileno() - old_settings = termios.tcgetattr(fd) - - try: - tty.setraw(fd) - char = sys.stdin.read(1) - - if char != ' ': - break - finally: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - except: - input_char = input().strip() - if input_char != '': - break - else: - break - # P3 - # ============================================================================ -# MAIN CHAT FUNCTION WITH FULL MCP SUPPORT (FILES + DATABASE) -# ============================================================================ - -@app.command() -def chat(): - global API_KEY, OPENROUTER_BASE_URL, STREAM_ENABLED, MAX_TOKEN, COST_WARNING_THRESHOLD, DEFAULT_ONLINE_MODE, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, LOG_LEVEL_STR, app_logger - - session_max_token = 0 - session_system_prompt = "" - session_history = [] - current_index = -1 - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - middle_out_enabled = False - conversation_memory_enabled = True - memory_start_index = 0 - saved_conversations_cache = [] - online_mode_enabled = DEFAULT_ONLINE_MODE == "on" - - app_logger.info("Starting new chat session with memory enabled") - - if not API_KEY: - console.print("[bold red]API key not found. Use '/config api'.[/]") - try: - new_api_key = typer.prompt("Enter API key") - if new_api_key.strip(): - set_config('api_key', new_api_key.strip()) - API_KEY = new_api_key.strip() - console.print("[bold green]API key saved. Re-run.[/]") - else: - raise typer.Exit() - except: - console.print("[bold red]No API key. Exiting.[/]") - raise typer.Exit() - - if not text_models: - console.print("[bold red]No models available. Check API key/URL.[/]") - raise typer.Exit() - - credits_data = get_credits(API_KEY, OPENROUTER_BASE_URL) - startup_credit_alerts = check_credit_alerts(credits_data) - if startup_credit_alerts: - startup_alert_msg = " | ".join(startup_credit_alerts) - console.print(f"[bold red]โš ๏ธ Startup {startup_alert_msg}[/]") - app_logger.warning(f"Startup credit alerts: {startup_alert_msg}") - - selected_model = selected_model_default - - client = OpenRouter(api_key=API_KEY) - - if selected_model: - online_status = "enabled" if online_mode_enabled else "disabled" - console.print(f"[bold blue]Welcome to oAI![/] [bold red]Active model: {selected_model['name']}[/] [dim cyan](Online mode: {online_status})[/]") - else: - console.print("[bold blue]Welcome to oAI![/] [italic blue]Select a model with '/model'.[/]") - - if not selected_model: - console.print("[bold yellow]No model selected. Use '/model'.[/]") - - # Create custom key bindings to add tab support for autocomplete - kb = KeyBindings() - - @kb.add('tab') - def _(event): - """Accept suggestion with tab key.""" - b = event.current_buffer - suggestion = b.suggestion - if suggestion and b.document.is_cursor_at_the_end: - b.insert_text(suggestion.text) - - session = PromptSession( - history=FileHistory(str(history_file)), - key_bindings=kb - ) - - while True: - try: - # ============================================================ - # BUILD PROMPT PREFIX WITH MODE INDICATOR - # ============================================================ - prompt_prefix = "You> " - if mcp_manager.enabled: - if mcp_manager.mode == "files": - if mcp_manager.write_enabled: - prompt_prefix = "[๐Ÿ”งโœ๏ธ MCP: Files+Write] You> " - else: - prompt_prefix = "[๐Ÿ”ง MCP: Files] You> " - elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: - db = mcp_manager.databases[mcp_manager.selected_db_index] - prompt_prefix = f"[๐Ÿ—„๏ธ MCP: DB #{mcp_manager.selected_db_index + 1}] You> " - - # ============================================================ - # INITIALIZE LOOP VARIABLES - # ============================================================ - text_part = "" - file_attachments = [] - content_blocks = [] - - user_input = session.prompt(prompt_prefix, auto_suggest=AutoSuggestFromHistory()).strip() - - # Handle escape sequence - if user_input.startswith("//"): - user_input = user_input[1:] - - # Check unknown commands - elif user_input.startswith("/") and user_input.lower() not in ["exit", "quit", "bye"]: - command_word = user_input.split()[0].lower() if user_input.split() else user_input.lower() - - if not any(command_word.startswith(cmd) for cmd in VALID_COMMANDS): - console.print(f"[bold red]Unknown command: {command_word}[/]") - console.print("[bold yellow]Type /help to see all available commands.[/]") - app_logger.warning(f"Unknown command attempted: {command_word}") - continue - - if user_input.lower() in ["exit", "quit", "bye"]: - total_tokens = total_input_tokens + total_output_tokens - app_logger.info(f"Session ended. Total messages: {message_count}, Total tokens: {total_tokens}, Total cost: ${total_cost:.4f}") - console.print("[bold yellow]Goodbye![/]") - return - - # ============================================================ - # MCP COMMANDS - # ============================================================ - if user_input.lower().startswith("/mcp"): - parts = user_input[5:].strip().split(maxsplit=1) - mcp_command = parts[0].lower() if parts else "" - mcp_args = parts[1] if len(parts) > 1 else "" - - if mcp_command in ["enable", "on"]: - result = mcp_manager.enable() - if result['success']: - console.print(f"[bold green]โœ“ {result['message']}[/]") - if result['folder_count'] == 0 and result['database_count'] == 0: - console.print("\n[bold yellow]No folders or databases configured yet.[/]") - console.print("\n[bold cyan]To get started:[/]") - console.print(" [bold yellow]Files:[/] /mcp add ~/Documents") - console.print(" [bold yellow]Databases:[/] /mcp add db ~/app/data.db") - else: - folders_msg = f"{result['folder_count']} folder(s)" if result['folder_count'] else "no folders" - dbs_msg = f"{result['database_count']} database(s)" if result['database_count'] else "no databases" - console.print(f"\n[bold cyan]MCP active with {folders_msg} and {dbs_msg}.[/]") - - warning = mcp_manager.config.get_filesystem_warning() - if warning: - console.print(warning) - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - continue - - elif mcp_command in ["disable", "off"]: - result = mcp_manager.disable() - if result['success']: - console.print(f"[bold green]โœ“ {result['message']}[/]") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - continue - - elif mcp_command == "add": - # Check if it's a database or folder - if mcp_args.startswith("db "): - # Database - db_path = mcp_args[3:].strip() - if not db_path: - console.print("[bold red]Usage: /mcp add db [/]") - continue - - result = mcp_manager.add_database(db_path) - - if result['success']: - db = result['database'] - - console.print("\n[bold yellow]โš ๏ธ Database Check:[/]") - console.print(f"Adding: [bold]{db['name']}[/]") - console.print(f"Size: {db['size'] / (1024 * 1024):.2f} MB") - console.print(f"Tables: {', '.join(db['tables'])} ({len(db['tables'])} total)") - - if db['size'] > 100 * 1024 * 1024: # > 100MB - console.print("\n[bold yellow]โš ๏ธ Large database detected! Queries may be slow.[/]") - - console.print("\nMCP will be able to:") - console.print(" [green]โœ“[/green] Inspect database schema") - console.print(" [green]โœ“[/green] Search for data across tables") - console.print(" [green]โœ“[/green] Execute read-only SQL queries") - console.print(" [red]โœ—[/red] Modify data (read-only mode)") - - try: - confirm = typer.confirm("\nProceed?", default=True) - if not confirm: - # Remove it - mcp_manager.remove_database(str(result['number'])) - console.print("[bold yellow]Cancelled. Database not added.[/]") - continue - except (EOFError, KeyboardInterrupt): - mcp_manager.remove_database(str(result['number'])) - console.print("\n[bold yellow]Cancelled. Database not added.[/]") - continue - - console.print(f"\n[bold green]โœ“ {result['message']}[/]") - console.print(f"\n[bold cyan]Use '/mcp db {result['number']}' to start querying this database.[/]") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - else: - # Folder - folder_path = mcp_args - if not folder_path: - console.print("[bold red]Usage: /mcp add or /mcp add db [/]") - continue - - result = mcp_manager.add_folder(folder_path) - - if result['success']: - stats = result['stats'] - - console.print("\n[bold yellow]โš ๏ธ Security Check:[/]") - console.print(f"You are granting MCP access to: [bold]{result['path']}[/]") - - if stats.get('exists'): - file_count = stats['file_count'] - size_mb = stats['size_mb'] - console.print(f"This folder contains: [bold]{file_count} files ({size_mb:.1f} MB)[/]") - - console.print("\nMCP will be able to:") - console.print(" [green]โœ“[/green] Read files in this folder") - console.print(" [green]โœ“[/green] List and search files") - console.print(" [green]โœ“[/green] Access subfolders recursively") - console.print(" [green]โœ“[/green] Automatically respect .gitignore patterns") - console.print(" [red]โœ—[/red] Delete or modify files (read-only)") - - try: - confirm = typer.confirm("\nProceed?", default=True) - if not confirm: - mcp_manager.allowed_folders.remove(mcp_manager.config.normalize_path(folder_path)) - mcp_manager._save_folders() - console.print("[bold yellow]Cancelled. Folder not added.[/]") - continue - except (EOFError, KeyboardInterrupt): - mcp_manager.allowed_folders.remove(mcp_manager.config.normalize_path(folder_path)) - mcp_manager._save_folders() - console.print("\n[bold yellow]Cancelled. Folder not added.[/]") - continue - - console.print(f"\n[bold green]โœ“ Added {result['path']} to MCP allowed folders[/]") - console.print(f"[dim cyan]MCP now has access to {result['total_folders']} folder(s) total.[/]") - - if result.get('warning'): - console.print(f"\n[bold yellow]โš ๏ธ {result['warning']}[/]") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - continue - - elif mcp_command in ["remove", "rem"]: - # Check if it's a database or folder - if mcp_args.startswith("db "): - # Database - db_ref = mcp_args[3:].strip() - if not db_ref: - console.print("[bold red]Usage: /mcp remove db [/]") - continue - - result = mcp_manager.remove_database(db_ref) - - if result['success']: - console.print(f"\n[bold yellow]Removing: {result['database']['name']}[/]") - - try: - confirm = typer.confirm("Confirm removal?", default=False) - if not confirm: - console.print("[bold yellow]Cancelled.[/]") - # Re-add it - mcp_manager.databases.append(result['database']) - mcp_manager._save_database(result['database']) - continue - except (EOFError, KeyboardInterrupt): - console.print("\n[bold yellow]Cancelled.[/]") - mcp_manager.databases.append(result['database']) - mcp_manager._save_database(result['database']) - continue - - console.print(f"\n[bold green]โœ“ {result['message']}[/]") - - if result.get('warning'): - console.print(f"\n[bold yellow]โš ๏ธ {result['warning']}[/]") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - else: - # Folder - if not mcp_args: - console.print("[bold red]Usage: /mcp remove or /mcp remove db [/]") - continue - - result = mcp_manager.remove_folder(mcp_args) - - if result['success']: - console.print(f"\n[bold yellow]Removing: {result['path']}[/]") - - try: - confirm = typer.confirm("Confirm removal?", default=False) - if not confirm: - mcp_manager.allowed_folders.append(mcp_manager.config.normalize_path(result['path'])) - mcp_manager._save_folders() - console.print("[bold yellow]Cancelled. Folder not removed.[/]") - continue - except (EOFError, KeyboardInterrupt): - mcp_manager.allowed_folders.append(mcp_manager.config.normalize_path(result['path'])) - mcp_manager._save_folders() - console.print("\n[bold yellow]Cancelled. Folder not removed.[/]") - continue - - console.print(f"\n[bold green]โœ“ Removed {result['path']} from MCP allowed folders[/]") - console.print(f"[dim cyan]MCP now has access to {result['total_folders']} folder(s) total.[/]") - - if result.get('warning'): - console.print(f"\n[bold yellow]โš ๏ธ {result['warning']}[/]") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - continue - - elif mcp_command == "list": - result = mcp_manager.list_folders() - - if result['success']: - if result['total_folders'] == 0: - console.print("[bold yellow]No folders configured.[/]") - console.print("\n[bold cyan]Add a folder with: /mcp add ~/Documents[/]") - continue - - status_indicator = "[green]โœ“[/green]" if mcp_manager.enabled else "[red]โœ—[/red]" - - table = Table( - "No.", "Path", "Files", "Size", - show_header=True, - header_style="bold magenta" - ) - - for folder_info in result['folders']: - number = str(folder_info['number']) - path = folder_info['path'] - - if folder_info['exists']: - files = f"๐Ÿ“ {folder_info['file_count']}" - size = f"{folder_info['size_mb']:.1f} MB" - else: - files = "[red]Not found[/red]" - size = "-" - - table.add_row(number, path, files, size) - - gitignore_info = "" - if mcp_manager.server: - gitignore_status = "on" if mcp_manager.server.respect_gitignore else "off" - pattern_count = len(mcp_manager.server.gitignore_parser.patterns) - gitignore_info = f" | .gitignore: {gitignore_status} ({pattern_count} patterns)" - - console.print(Panel( - table, - title=f"[bold green]MCP Folders: {'Active' if mcp_manager.enabled else 'Inactive'} {status_indicator}[/]", - title_align="left", - subtitle=f"[dim]Total: {result['total_folders']} folders, {result['total_files']} files ({result['total_size_mb']:.1f} MB){gitignore_info}[/]", - subtitle_align="right" - )) - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - continue - - elif mcp_command == "db": - if not mcp_args: - # Show current database or list all - if mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: - db = mcp_manager.databases[mcp_manager.selected_db_index] - console.print(f"[bold cyan]Currently using database #{mcp_manager.selected_db_index + 1}: {db['name']}[/]") - console.print(f"[dim]Path: {db['path']}[/]") - console.print(f"[dim]Tables: {', '.join(db['tables'])}[/]") - else: - console.print("[bold yellow]Not in database mode. Use '/mcp db ' to select a database.[/]") - - # Also show hint to list - console.print("\n[dim]Use '/mcp db list' to see all databases[/]") - continue - - if mcp_args == "list": - result = mcp_manager.list_databases() - - if result['success']: - if result['count'] == 0: - console.print("[bold yellow]No databases configured.[/]") - console.print("\n[bold cyan]Add a database with: /mcp add db ~/app/data.db[/]") - continue - - table = Table( - "No.", "Name", "Tables", "Size", "Status", - show_header=True, - header_style="bold magenta" - ) - - for db_info in result['databases']: - number = str(db_info['number']) - name = db_info['name'] - table_count = f"{db_info['table_count']} tables" - size = f"{db_info['size_mb']:.1f} MB" - - if db_info.get('warning'): - status = f"[red]{db_info['warning']}[/red]" - else: - status = "[green]โœ“[/green]" - - table.add_row(number, name, table_count, size, status) - - console.print(Panel( - table, - title="[bold green]MCP Databases[/]", - title_align="left", - subtitle=f"[dim]Total: {result['count']} database(s) | Use '/mcp db ' to select[/]", - subtitle_align="right" - )) - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - continue - - # Switch to database mode - try: - db_num = int(mcp_args) - result = mcp_manager.switch_mode("database", db_num) - - if result['success']: - db = result['database'] - console.print(f"\n[bold green]โœ“ {result['message']}[/]") - console.print(f"[dim cyan]Tables: {', '.join(db['tables'])}[/]") - console.print(f"\n[bold cyan]Available tools:[/] inspect_database, search_database, query_database") - console.print(f"\n[bold yellow]You can now ask questions about this database![/]") - console.print(f"[dim]Examples:[/]") - console.print(f" 'Show me all tables'") - console.print(f" 'Search for records mentioning X'") - console.print(f" 'How many rows in the users table?'") - console.print(f"\n[dim]Switch back to files: /mcp files[/]") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - except ValueError: - console.print(f"[bold red]Invalid database number: {mcp_args}[/]") - console.print(f"[bold yellow]Use '/mcp db list' to see available databases[/]") - continue - - elif mcp_command == "files": - result = mcp_manager.switch_mode("files") - - if result['success']: - console.print(f"[bold green]โœ“ {result['message']}[/]") - console.print(f"\n[bold cyan]Available tools:[/] read_file, list_directory, search_files") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - continue - - elif mcp_command == "gitignore": - if not mcp_args: - if mcp_manager.server: - status = "enabled" if mcp_manager.server.respect_gitignore else "disabled" - pattern_count = len(mcp_manager.server.gitignore_parser.patterns) - console.print(f"[bold blue].gitignore filtering: {status}[/]") - console.print(f"[dim cyan]Loaded {pattern_count} pattern(s) from .gitignore files[/]") - else: - console.print("[bold yellow]MCP is not enabled. Enable with /mcp on[/]") - continue - - if mcp_args.lower() == "on": - result = mcp_manager.toggle_gitignore(True) - if result['success']: - console.print(f"[bold green]โœ“ {result['message']}[/]") - console.print(f"[dim cyan]Using {result['pattern_count']} .gitignore pattern(s)[/]") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - elif mcp_args.lower() == "off": - result = mcp_manager.toggle_gitignore(False) - if result['success']: - console.print(f"[bold green]โœ“ {result['message']}[/]") - console.print("[bold yellow]Warning: All files will be visible, including those in .gitignore[/]") - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - else: - console.print("[bold yellow]Usage: /mcp gitignore on|off[/]") - continue - - elif mcp_command == "write": - if not mcp_args: - # Show current write mode status - if mcp_manager.enabled: - status = "[bold green]Enabled โš ๏ธ[/]" if mcp_manager.write_enabled else "[dim]Disabled[/]" - console.print(f"[bold blue]Write Mode:[/] {status}") - if mcp_manager.write_enabled: - console.print("[dim cyan]LLM can create, modify, and delete files in allowed folders[/]") - else: - console.print("[dim cyan]Use '/mcp write on' to enable write operations[/]") - else: - console.print("[bold yellow]MCP is not enabled. Enable with /mcp on[/]") - continue - - if mcp_args.lower() == "on": - mcp_manager.enable_write() - elif mcp_args.lower() == "off": - mcp_manager.disable_write() - else: - console.print("[bold yellow]Usage: /mcp write on|off[/]") - continue - - elif mcp_command == "status": - result = mcp_manager.get_status() - - if result['success']: - status_color = "green" if result['enabled'] else "red" - status_text = "Active โœ“" if result['enabled'] else "Inactive โœ—" - - table = Table( - "Property", "Value", - show_header=True, - header_style="bold magenta" - ) - - table.add_row("Status", f"[{status_color}]{status_text}[/{status_color}]") - - # Mode info - mode_info = result['mode_info'] - table.add_row("Current Mode", mode_info['mode_display']) - if 'database' in mode_info: - db = mode_info['database'] - table.add_row("Database Path", db['path']) - table.add_row("Database Tables", ', '.join(db['tables'])) - - if result['uptime']: - table.add_row("Uptime", result['uptime']) - - table.add_row("", "") - table.add_row("[bold]Configuration[/]", "") - table.add_row("Allowed Folders", str(result['folder_count'])) - table.add_row("Databases", str(result['database_count'])) - table.add_row("Total Files Accessible", str(result['total_files'])) - table.add_row("Total Size", f"{result['total_size_mb']:.1f} MB") - table.add_row(".gitignore Filtering", result['gitignore_status']) - if result['gitignore_patterns'] > 0: - table.add_row(".gitignore Patterns", str(result['gitignore_patterns'])) - - # Write mode status - write_status = "[green]Enabled โš ๏ธ[/]" if result['write_enabled'] else "[dim]Disabled[/]" - table.add_row("Write Mode", write_status) - - if result['enabled']: - table.add_row("", "") - table.add_row("[bold]Tools Available[/]", "") - for tool in result['tools_available']: - table.add_row(f" โœ“ {tool}", "") - - stats = result['stats'] - if stats['total_calls'] > 0: - table.add_row("", "") - table.add_row("[bold]Session Stats[/]", "") - table.add_row("Total Tool Calls", str(stats['total_calls'])) - - if stats['reads'] > 0: - table.add_row("Files Read", str(stats['reads'])) - if stats['lists'] > 0: - table.add_row("Directories Listed", str(stats['lists'])) - if stats['searches'] > 0: - table.add_row("File Searches", str(stats['searches'])) - if stats['db_inspects'] > 0: - table.add_row("DB Inspections", str(stats['db_inspects'])) - if stats['db_searches'] > 0: - table.add_row("DB Searches", str(stats['db_searches'])) - if stats['db_queries'] > 0: - table.add_row("SQL Queries", str(stats['db_queries'])) - - if stats['last_used']: - try: - last_used = datetime.datetime.fromisoformat(stats['last_used']) - time_ago = datetime.datetime.now() - last_used - - if time_ago.seconds < 60: - time_str = f"{time_ago.seconds} seconds ago" - elif time_ago.seconds < 3600: - time_str = f"{time_ago.seconds // 60} minutes ago" - else: - time_str = f"{time_ago.seconds // 3600} hours ago" - - table.add_row("Last Used", time_str) - except: - pass - - console.print(Panel( - table, - title="[bold green]MCP Status[/]", - title_align="left" - )) - else: - console.print(f"[bold red]โŒ {result['error']}[/]") - continue - - else: - console.print("[bold cyan]MCP (Model Context Protocol) Commands:[/]\n") - console.print(" [bold yellow]/mcp on[/] - Start MCP server") - console.print(" [bold yellow]/mcp off[/] - Stop MCP server") - console.print(" [bold yellow]/mcp status[/] - Show comprehensive MCP status") - console.print("") - console.print(" [bold cyan]FILE MODE (default):[/]") - console.print(" [bold yellow]/mcp add [/] - Add folder for file access") - console.print(" [bold yellow]/mcp remove [/] - Remove folder") - console.print(" [bold yellow]/mcp list[/] - List all folders") - console.print(" [bold yellow]/mcp files[/] - Switch to file mode") - console.print(" [bold yellow]/mcp gitignore [on|off][/] - Toggle .gitignore filtering") - console.print("") - console.print(" [bold cyan]DATABASE MODE:[/]") - console.print(" [bold yellow]/mcp add db [/] - Add SQLite database") - console.print(" [bold yellow]/mcp db list[/] - List all databases") - console.print(" [bold yellow]/mcp db [/] - Switch to database mode") - console.print(" [bold yellow]/mcp remove db [/]- Remove database") - console.print("") - console.print("[dim]Use '/help /mcp' for detailed command help[/]") - console.print("[dim]Use '/help mcp' for comprehensive MCP guide[/]") - continue - - # ============================================================ - # ALL OTHER COMMANDS - # ============================================================ - - if user_input.lower() == "/retry": - if not session_history: - console.print("[bold red]No history to retry.[/]") - app_logger.warning("Retry attempted with no history") - continue - last_prompt = session_history[-1]['prompt'] - console.print("[bold green]Retrying last prompt...[/]") - app_logger.info(f"Retrying prompt: {last_prompt[:100]}...") - user_input = last_prompt - - elif user_input.lower().startswith("/online"): - args = user_input[8:].strip() - if not args: - status = "enabled" if online_mode_enabled else "disabled" - default_status = "enabled" if DEFAULT_ONLINE_MODE == "on" else "disabled" - console.print(f"[bold blue]Online mode (web search) {status}.[/]") - console.print(f"[dim blue]Default setting: {default_status} (use '/config online on|off' to change)[/]") - if selected_model: - if supports_online_mode(selected_model): - console.print(f"[dim green]Current model '{selected_model['name']}' supports online mode.[/]") - else: - console.print(f"[dim yellow]Current model '{selected_model['name']}' does not support online mode.[/]") - continue - if args.lower() == "on": - if not selected_model: - console.print("[bold red]No model selected. Select a model first with '/model'.[/]") - continue - if not supports_online_mode(selected_model): - console.print(f"[bold red]Model '{selected_model['name']}' does not support online mode (web search).[/]") - console.print("[dim yellow]Online mode requires models with 'tools' parameter support.[/]") - app_logger.warning(f"Online mode activation failed - model {selected_model['id']} doesn't support it") - continue - online_mode_enabled = True - console.print("[bold green]Online mode enabled for this session. Model will use web search capabilities.[/]") - console.print(f"[dim blue]Effective model ID: {get_effective_model_id(selected_model['id'], True)}[/]") - app_logger.info(f"Online mode enabled for model {selected_model['id']}") - elif args.lower() == "off": - online_mode_enabled = False - console.print("[bold green]Online mode disabled for this session. Model will not use web search.[/]") - if selected_model: - console.print(f"[dim blue]Effective model ID: {selected_model['id']}[/]") - app_logger.info("Online mode disabled") - else: - console.print("[bold yellow]Usage: /online on|off (or /online to view status)[/]") - continue - - elif user_input.lower().startswith("/memory"): - args = user_input[8:].strip() - if not args: - status = "enabled" if conversation_memory_enabled else "disabled" - history_count = len(session_history) - memory_start_index if conversation_memory_enabled and memory_start_index < len(session_history) else 0 - console.print(f"[bold blue]Conversation memory {status}.[/]") - if conversation_memory_enabled: - console.print(f"[dim blue]Tracking {history_count} message(s) since memory enabled.[/]") - else: - console.print(f"[dim yellow]Memory disabled. Each request is independent (saves tokens/cost).[/]") - continue - if args.lower() == "on": - conversation_memory_enabled = True - memory_start_index = len(session_history) - console.print("[bold green]Conversation memory enabled. Will remember conversations from this point forward.[/]") - console.print(f"[dim blue]Memory will track messages starting from index {memory_start_index}.[/]") - app_logger.info(f"Conversation memory enabled at index {memory_start_index}") - elif args.lower() == "off": - conversation_memory_enabled = False - console.print("[bold green]Conversation memory disabled. API calls will not include history (lower cost).[/]") - console.print(f"[dim yellow]Note: Messages are still saved locally but not sent to API.[/]") - app_logger.info("Conversation memory disabled") - else: - console.print("[bold yellow]Usage: /memory on|off (or /memory to view status)[/]") - continue - - elif user_input.lower().startswith("/paste"): - optional_prompt = user_input[7:].strip() - - try: - clipboard_content = pyperclip.paste() - except Exception as e: - console.print(f"[bold red]Failed to access clipboard: {e}[/]") - app_logger.error(f"Clipboard access error: {e}") - continue - - if not clipboard_content or not clipboard_content.strip(): - console.print("[bold red]Clipboard is empty.[/]") - app_logger.warning("Paste attempted with empty clipboard") - continue - - try: - clipboard_content.encode('utf-8') - - preview_lines = clipboard_content.split('\n')[:10] - preview_text = '\n'.join(preview_lines) - if len(clipboard_content.split('\n')) > 10: - preview_text += "\n... (content truncated for preview)" - - char_count = len(clipboard_content) - line_count = len(clipboard_content.split('\n')) - - console.print(Panel( - preview_text, - title=f"[bold cyan]๐Ÿ“‹ Clipboard Content Preview ({char_count} chars, {line_count} lines)[/]", - title_align="left", - border_style="cyan" - )) - - if optional_prompt: - final_prompt = f"{optional_prompt}\n\n```\n{clipboard_content}\n```" - console.print(f"[dim blue]Sending with prompt: '{optional_prompt}'[/]") - else: - final_prompt = clipboard_content - console.print("[dim blue]Sending clipboard content without additional prompt[/]") - - user_input = final_prompt - app_logger.info(f"Pasted content from clipboard: {char_count} chars, {line_count} lines, with prompt: {bool(optional_prompt)}") - - except UnicodeDecodeError: - console.print("[bold red]Clipboard contains non-text (binary) data. Only plain text is supported.[/]") - app_logger.error("Paste failed - clipboard contains binary data") - continue - except Exception as e: - console.print(f"[bold red]Error processing clipboard content: {e}[/]") - app_logger.error(f"Clipboard processing error: {e}") - continue - - elif user_input.lower().startswith("/export"): - args = user_input[8:].strip().split(maxsplit=1) - if len(args) != 2: - console.print("[bold red]Usage: /export [/]") - console.print("[bold yellow]Formats: md (Markdown), json (JSON), html (HTML)[/]") - console.print("[bold yellow]Example: /export md my_conversation.md[/]") - continue - - export_format = args[0].lower() - filename = args[1] - - if not session_history: - console.print("[bold red]No conversation history to export.[/]") - continue - - if export_format not in ['md', 'json', 'html']: - console.print("[bold red]Invalid format. Use: md, json, or html[/]") - continue - - try: - if export_format == 'md': - content = export_as_markdown(session_history, session_system_prompt) - elif export_format == 'json': - content = export_as_json(session_history, session_system_prompt) - elif export_format == 'html': - content = export_as_html(session_history, session_system_prompt) - - export_path = Path(filename).expanduser() - with open(export_path, 'w', encoding='utf-8') as f: - f.write(content) - - console.print(f"[bold green]โœ… Conversation exported to: {export_path.absolute()}[/]") - console.print(f"[dim blue]Format: {export_format.upper()} | Messages: {len(session_history)} | Size: {len(content)} bytes[/]") - app_logger.info(f"Conversation exported as {export_format} to {export_path} ({len(session_history)} messages)") - except Exception as e: - console.print(f"[bold red]Export failed: {e}[/]") - app_logger.error(f"Export error: {e}") - continue - - elif user_input.lower().startswith("/save"): - args = user_input[6:].strip() - if not args: - console.print("[bold red]Usage: /save [/]") - continue - if not session_history: - console.print("[bold red]No history to save.[/]") - continue - save_conversation(args, session_history) - console.print(f"[bold green]Conversation saved as '{args}'.[/]") - app_logger.info(f"Conversation saved as '{args}' with {len(session_history)} messages") - continue - - elif user_input.lower().startswith("/load"): - args = user_input[6:].strip() - if not args: - console.print("[bold red]Usage: /load [/]") - console.print("[bold yellow]Tip: Use /list to see numbered conversations[/]") - continue - - conversation_name = None - if args.isdigit(): - conv_number = int(args) - if saved_conversations_cache and 1 <= conv_number <= len(saved_conversations_cache): - conversation_name = saved_conversations_cache[conv_number - 1]['name'] - console.print(f"[bold cyan]Loading conversation #{conv_number}: '{conversation_name}'[/]") - else: - console.print(f"[bold red]Invalid conversation number: {conv_number}[/]") - console.print(f"[bold yellow]Use /list to see available conversations (1-{len(saved_conversations_cache) if saved_conversations_cache else 0})[/]") - continue - else: - conversation_name = args - - loaded_data = load_conversation(conversation_name) - if not loaded_data: - console.print(f"[bold red]Conversation '{conversation_name}' not found.[/]") - app_logger.warning(f"Load failed for '{conversation_name}' - not found") - continue - session_history = loaded_data - current_index = len(session_history) - 1 - if conversation_memory_enabled: - memory_start_index = 0 - - # Recalculate session totals from loaded message history - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - - for msg in session_history: - # Handle both old format (no cost data) and new format (with cost data) - total_input_tokens += msg.get('prompt_tokens', 0) - total_output_tokens += msg.get('completion_tokens', 0) - total_cost += msg.get('msg_cost', 0.0) - message_count += 1 - - console.print(f"[bold green]Conversation '{conversation_name}' loaded with {len(session_history)} messages.[/]") - if total_cost > 0: - console.print(f"[dim cyan]Restored session totals: {total_input_tokens + total_output_tokens} tokens, ${total_cost:.4f} cost[/]") - app_logger.info(f"Conversation '{conversation_name}' loaded with {len(session_history)} messages, restored cost: ${total_cost:.4f}") - continue - - elif user_input.lower().startswith("/delete"): - args = user_input[8:].strip() - if not args: - console.print("[bold red]Usage: /delete [/]") - console.print("[bold yellow]Tip: Use /list to see numbered conversations[/]") - continue - - conversation_name = None - if args.isdigit(): - conv_number = int(args) - if saved_conversations_cache and 1 <= conv_number <= len(saved_conversations_cache): - conversation_name = saved_conversations_cache[conv_number - 1]['name'] - console.print(f"[bold cyan]Deleting conversation #{conv_number}: '{conversation_name}'[/]") - else: - console.print(f"[bold red]Invalid conversation number: {conv_number}[/]") - console.print(f"[bold yellow]Use /list to see available conversations (1-{len(saved_conversations_cache) if saved_conversations_cache else 0})[/]") - continue - else: - conversation_name = args - - try: - confirm = typer.confirm(f"Delete conversation '{conversation_name}'? This cannot be undone.", default=False) - if not confirm: - console.print("[bold yellow]Deletion cancelled.[/]") - continue - except (EOFError, KeyboardInterrupt): - console.print("\n[bold yellow]Deletion cancelled.[/]") - continue - - deleted_count = delete_conversation(conversation_name) - if deleted_count > 0: - console.print(f"[bold green]Conversation '{conversation_name}' deleted ({deleted_count} version(s) removed).[/]") - app_logger.info(f"Conversation '{conversation_name}' deleted - {deleted_count} version(s)") - if saved_conversations_cache: - saved_conversations_cache = [c for c in saved_conversations_cache if c['name'] != conversation_name] - else: - console.print(f"[bold red]Conversation '{conversation_name}' not found.[/]") - app_logger.warning(f"Delete failed for '{conversation_name}' - not found") - continue - - elif user_input.lower() == "/list": - conversations = list_conversations() - if not conversations: - console.print("[bold yellow]No saved conversations found.[/]") - app_logger.info("User viewed conversation list - empty") - saved_conversations_cache = [] - continue - - saved_conversations_cache = conversations - - table = Table("No.", "Name", "Messages", "Last Saved", show_header=True, header_style="bold magenta") - for idx, conv in enumerate(conversations, 1): - try: - dt = datetime.datetime.fromisoformat(conv['timestamp']) - formatted_time = dt.strftime('%Y-%m-%d %H:%M:%S') - except: - formatted_time = conv['timestamp'] - - table.add_row( - str(idx), - conv['name'], - str(conv['message_count']), - formatted_time - ) - - console.print(Panel(table, title=f"[bold green]Saved Conversations ({len(conversations)} total)[/]", title_align="left", subtitle="[dim]Use /load or /delete to manage conversations[/]", subtitle_align="right")) - app_logger.info(f"User viewed conversation list - {len(conversations)} conversations") - continue - - elif user_input.lower() == "/prev": - if not session_history or current_index <= 0: - console.print("[bold red]No previous response.[/]") - continue - current_index -= 1 - prev_response = session_history[current_index]['response'] - md = Markdown(prev_response) - console.print(Panel(md, title=f"[bold green]Previous Response ({current_index + 1}/{len(session_history)})[/]", title_align="left")) - app_logger.debug(f"Viewed previous response at index {current_index}") - continue - - elif user_input.lower() == "/next": - if not session_history or current_index >= len(session_history) - 1: - console.print("[bold red]No next response.[/]") - continue - current_index += 1 - next_response = session_history[current_index]['response'] - md = Markdown(next_response) - console.print(Panel(md, title=f"[bold green]Next Response ({current_index + 1}/{len(session_history)})[/]", title_align="left")) - app_logger.debug(f"Viewed next response at index {current_index}") - continue - - elif user_input.lower() == "/stats": - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - credits_left = credits['credits_left'] if credits else "Unknown" - stats = f"Total Input: {total_input_tokens}, Total Output: {total_output_tokens}, Total Tokens: {total_input_tokens + total_output_tokens}, Total Cost: ${total_cost:.4f}, Avg Cost/Message: ${total_cost / message_count:.4f}" if message_count > 0 else "No messages." - table = Table("Metric", "Value", show_header=True, header_style="bold magenta") - table.add_row("Session Stats", stats) - table.add_row("Credits Left", credits_left) - console.print(Panel(table, title="[bold green]Session Cost Summary[/]", title_align="left")) - app_logger.info(f"User viewed stats: {stats}") - - warnings = check_credit_alerts(credits) - if warnings: - warning_text = '|'.join(warnings) - console.print(f"[bold red]โš ๏ธ {warning_text}[/]") - app_logger.warning(f"Warnings in stats: {warning_text}") - continue - - elif user_input.lower().startswith("/middleout"): - args = user_input[11:].strip() - if not args: - console.print(f"[bold blue]Middle-out transform {'enabled' if middle_out_enabled else 'disabled'}.[/]") - continue - if args.lower() == "on": - middle_out_enabled = True - console.print("[bold green]Middle-out transform enabled.[/]") - elif args.lower() == "off": - middle_out_enabled = False - console.print("[bold green]Middle-out transform disabled.[/]") - else: - console.print("[bold yellow]Usage: /middleout on|off (or /middleout to view status)[/]") - continue - - elif user_input.lower() == "/reset": - try: - confirm = typer.confirm("Reset conversation context? This clears history and prompt.", default=False) - if not confirm: - console.print("[bold yellow]Reset cancelled.[/]") - continue - except (EOFError, KeyboardInterrupt): - console.print("\n[bold yellow]Reset cancelled.[/]") - continue - session_history = [] - current_index = -1 - session_system_prompt = "" - memory_start_index = 0 - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - console.print("[bold green]Conversation context reset. Totals cleared.[/]") - app_logger.info("Conversation context reset by user - all totals reset to 0") - continue - - elif user_input.lower().startswith("/info"): - args = user_input[6:].strip() - if not args: - if not selected_model: - console.print("[bold red]No model selected and no model ID provided. Use '/model' first or '/info '.[/]") - continue - model_to_show = selected_model - else: - model_to_show = next((m for m in models_data if m["id"] == args or m.get("canonical_slug") == args or args.lower() in m["name"].lower()), None) - if not model_to_show: - console.print(f"[bold red]Model '{args}' not found.[/]") - continue - - pricing = model_to_show.get("pricing", {}) - architecture = model_to_show.get("architecture", {}) - supported_params = ", ".join(model_to_show.get("supported_parameters", [])) or "None" - top_provider = model_to_show.get("top_provider", {}) - - table = Table("Property", "Value", show_header=True, header_style="bold magenta") - table.add_row("ID", model_to_show["id"]) - table.add_row("Name", model_to_show["name"]) - table.add_row("Description", model_to_show.get("description", "N/A")) - table.add_row("Context Length", str(model_to_show.get("context_length", "N/A"))) - table.add_row("Online Support", "Yes" if supports_online_mode(model_to_show) else "No") - table.add_row("Pricing - Prompt ($/M tokens)", pricing.get("prompt", "N/A")) - table.add_row("Pricing - Completion ($/M tokens)", pricing.get("completion", "N/A")) - table.add_row("Pricing - Request ($)", pricing.get("request", "N/A")) - table.add_row("Pricing - Image ($)", pricing.get("image", "N/A")) - table.add_row("Input Modalities", ", ".join(architecture.get("input_modalities", [])) or "None") - table.add_row("Output Modalities", ", ".join(architecture.get("output_modalities", [])) or "None") - table.add_row("Supported Parameters", supported_params) - table.add_row("Top Provider Context Length", str(top_provider.get("context_length", "N/A"))) - table.add_row("Max Completion Tokens", str(top_provider.get("max_completion_tokens", "N/A"))) - table.add_row("Moderated", "Yes" if top_provider.get("is_moderated", False) else "No") - - console.print(Panel(table, title=f"[bold green]Model Info: {model_to_show['name']}[/]", title_align="left")) - continue - - elif user_input.startswith("/model"): - app_logger.info("User initiated model selection") - args = user_input[7:].strip() - search_term = args if args else "" - filtered_models = text_models - if search_term: - filtered_models = [m for m in text_models if search_term.lower() in m["name"].lower() or search_term.lower() in m["id"].lower()] - if not filtered_models: - console.print(f"[bold red]No models match '{search_term}'. Try '/model'.[/]") - continue - - table = Table("No.", "Name", "ID", "Image", "Online", "Tools", show_header=True, header_style="bold magenta") - for i, model in enumerate(filtered_models, 1): - image_support = "[green]โœ“[/green]" if has_image_capability(model) else "[red]โœ—[/red]" - online_support = "[green]โœ“[/green]" if supports_online_mode(model) else "[red]โœ—[/red]" - tools_support = "[green]โœ“[/green]" if supports_function_calling(model) else "[red]โœ—[/red]" - table.add_row(str(i), model["name"], model["id"], image_support, online_support, tools_support) - - title = f"[bold green]Available Models ({'All' if not search_term else f'Search: {search_term}'})[/]" - display_paginated_table(table, title) - - while True: - try: - choice = int(typer.prompt("Enter model number (or 0 to cancel)")) - if choice == 0: - break - if 1 <= choice <= len(filtered_models): - selected_model = filtered_models[choice - 1] - - if supports_online_mode(selected_model) and DEFAULT_ONLINE_MODE == "on": - online_mode_enabled = True - console.print(f"[bold cyan]Selected: {selected_model['name']} ({selected_model['id']})[/]") - console.print("[dim cyan]โœ“ Online mode auto-enabled (default setting). Use '/online off' to disable for this session.[/]") - else: - online_mode_enabled = False - console.print(f"[bold cyan]Selected: {selected_model['name']} ({selected_model['id']})[/]") - if supports_online_mode(selected_model): - console.print("[dim green]โœ“ This model supports online mode. Use '/online on' to enable web search.[/]") - - app_logger.info(f"Model selected: {selected_model['name']} ({selected_model['id']}), Online: {online_mode_enabled}") - break - console.print("[bold red]Invalid choice. Try again.[/]") - except ValueError: - console.print("[bold red]Invalid input. Enter a number.[/]") - continue - - elif user_input.startswith("/maxtoken"): - args = user_input[10:].strip() - if not args: - console.print(f"[bold blue]Current session max tokens: {session_max_token}[/]") - continue - try: - new_limit = int(args) - if new_limit < 1: - console.print("[bold red]Session token limit must be at least 1.[/]") - continue - if new_limit > MAX_TOKEN: - console.print(f"[bold yellow]Cannot exceed stored max ({MAX_TOKEN}). Capping.[/]") - new_limit = MAX_TOKEN - session_max_token = new_limit - console.print(f"[bold green]Session max tokens set to: {session_max_token}[/]") - except ValueError: - console.print("[bold red]Invalid token limit. Provide a positive integer.[/]") - continue - - elif user_input.startswith("/system"): - args = user_input[8:].strip() - if not args: - if session_system_prompt: - console.print(f"[bold blue]Current session system prompt:[/] {session_system_prompt}") - else: - console.print("[bold blue]No session system prompt set.[/]") - continue - if args.lower() == "clear": - session_system_prompt = "" - console.print("[bold green]Session system prompt cleared.[/]") - else: - session_system_prompt = args - console.print(f"[bold green]Session system prompt set to: {session_system_prompt}[/]") - continue - - elif user_input.startswith("/config"): - args = user_input[8:].strip().lower() - update = check_for_updates(version) - - if args == "api": - try: - new_api_key = typer.prompt("Enter new API key") - if new_api_key.strip(): - set_config('api_key', new_api_key.strip()) - API_KEY = new_api_key.strip() - client = OpenRouter(api_key=API_KEY) - console.print("[bold green]API key updated![/]") - else: - console.print("[bold yellow]No change.[/]") - except Exception as e: - console.print(f"[bold red]Error updating API key: {e}[/]") - elif args == "url": - try: - new_url = typer.prompt("Enter new base URL") - if new_url.strip(): - set_config('base_url', new_url.strip()) - OPENROUTER_BASE_URL = new_url.strip() - console.print("[bold green]Base URL updated![/]") - else: - console.print("[bold yellow]No change.[/]") - except Exception as e: - console.print(f"[bold red]Error updating URL: {e}[/]") - elif args.startswith("costwarning"): - sub_args = args[11:].strip() - if not sub_args: - console.print(f"[bold blue]Stored cost warning threshold: ${COST_WARNING_THRESHOLD:.4f}[/]") - continue - try: - new_threshold = float(sub_args) - if new_threshold < 0: - console.print("[bold red]Cost warning threshold must be >= 0.[/]") - continue - set_config(HIGH_COST_WARNING, str(new_threshold)) - COST_WARNING_THRESHOLD = new_threshold - console.print(f"[bold green]Cost warning threshold set to ${COST_WARNING_THRESHOLD:.4f}[/]") - except ValueError: - console.print("[bold red]Invalid cost threshold. Provide a valid number.[/]") - elif args.startswith("stream"): - sub_args = args[7:].strip() - if sub_args in ["on", "off"]: - set_config('stream_enabled', sub_args) - STREAM_ENABLED = sub_args - console.print(f"[bold green]Streaming {'enabled' if sub_args == 'on' else 'disabled'}.[/]") - else: - console.print("[bold yellow]Usage: /config stream on|off[/]") - elif args.startswith("loglevel"): - sub_args = args[8:].strip() - if not sub_args: - console.print(f"[bold blue]Current log level: {LOG_LEVEL_STR.upper()}[/]") - console.print(f"[dim yellow]Valid levels: debug, info, warning, error, critical[/]") - continue - if sub_args.lower() in VALID_LOG_LEVELS: - if set_log_level(sub_args): - set_config('log_level', sub_args.lower()) - console.print(f"[bold green]Log level set to: {sub_args.upper()}[/]") - app_logger.info(f"Log level changed to {sub_args.upper()}") - else: - console.print(f"[bold red]Failed to set log level.[/]") - else: - console.print(f"[bold red]Invalid log level: {sub_args}[/]") - console.print(f"[bold yellow]Valid levels: debug, info, warning, error, critical[/]") - elif args.startswith("log"): - sub_args = args[4:].strip() - if not sub_args: - console.print(f"[bold blue]Current log file size limit: {LOG_MAX_SIZE_MB} MB[/]") - console.print(f"[bold blue]Log backup count: {LOG_BACKUP_COUNT} files[/]") - console.print(f"[bold blue]Log level: {LOG_LEVEL_STR.upper()}[/]") - console.print(f"[dim yellow]Total max disk usage: ~{LOG_MAX_SIZE_MB * (LOG_BACKUP_COUNT + 1)} MB[/]") - continue - try: - new_size_mb = int(sub_args) - if new_size_mb < 1: - console.print("[bold red]Log size must be at least 1 MB.[/]") - continue - if new_size_mb > 100: - console.print("[bold yellow]Warning: Log size > 100MB. Capping at 100MB.[/]") - new_size_mb = 100 - set_config('log_max_size_mb', str(new_size_mb)) - LOG_MAX_SIZE_MB = new_size_mb - - app_logger = reload_logging_config() - - console.print(f"[bold green]Log size limit set to {new_size_mb} MB and applied immediately.[/]") - console.print(f"[dim cyan]Log file rotated if it exceeded the new limit.[/]") - app_logger.info(f"Log size limit updated to {new_size_mb} MB and reloaded") - except ValueError: - console.print("[bold red]Invalid size. Provide a number in MB.[/]") - elif args.startswith("online"): - sub_args = args[7:].strip() - if not sub_args: - current_default = "enabled" if DEFAULT_ONLINE_MODE == "on" else "disabled" - console.print(f"[bold blue]Default online mode: {current_default}[/]") - console.print("[dim yellow]This sets the default for new models. Use '/online on|off' to override in current session.[/]") - continue - if sub_args in ["on", "off"]: - set_config('default_online_mode', sub_args) - DEFAULT_ONLINE_MODE = sub_args - console.print(f"[bold green]Default online mode {'enabled' if sub_args == 'on' else 'disabled'}.[/]") - console.print("[dim blue]Note: This affects new model selections. Current session unchanged.[/]") - app_logger.info(f"Default online mode set to {sub_args}") - else: - console.print("[bold yellow]Usage: /config online on|off[/]") - elif args.startswith("maxtoken"): - sub_args = args[9:].strip() - if not sub_args: - console.print(f"[bold blue]Stored max token limit: {MAX_TOKEN}[/]") - continue - try: - new_max = int(sub_args) - if new_max < 1: - console.print("[bold red]Max token limit must be at least 1.[/]") - continue - if new_max > 1000000: - console.print("[bold yellow]Capped at 1M for safety.[/]") - new_max = 1000000 - set_config('max_token', str(new_max)) - MAX_TOKEN = new_max - if session_max_token > MAX_TOKEN: - session_max_token = MAX_TOKEN - console.print(f"[bold yellow]Session adjusted to {session_max_token}.[/]") - console.print(f"[bold green]Stored max token limit updated to: {MAX_TOKEN}[/]") - except ValueError: - console.print("[bold red]Invalid token limit.[/]") - elif args.startswith("model"): - sub_args = args[6:].strip() - search_term = sub_args if sub_args else "" - filtered_models = text_models - if search_term: - filtered_models = [m for m in text_models if search_term.lower() in m["name"].lower() or search_term.lower() in m["id"].lower()] - if not filtered_models: - console.print(f"[bold red]No models match '{search_term}'. Try without search.[/]") - continue - - table = Table("No.", "Name", "ID", "Image", "Online", "Tools", show_header=True, header_style="bold magenta") - for i, model in enumerate(filtered_models, 1): - image_support = "[green]โœ“[/green]" if has_image_capability(model) else "[red]โœ—[/red]" - online_support = "[green]โœ“[/green]" if supports_online_mode(model) else "[red]โœ—[/red]" - tools_support = "[green]โœ“[/green]" if supports_function_calling(model) else "[red]โœ—[/red]" - table.add_row(str(i), model["name"], model["id"], image_support, online_support, tools_support) - - title = f"[bold green]Available Models for Default ({'All' if not search_term else f'Search: {search_term}'})[/]" - display_paginated_table(table, title) - - while True: - try: - choice = int(typer.prompt("Enter model number (or 0 to cancel)")) - if choice == 0: - break - if 1 <= choice <= len(filtered_models): - default_model = filtered_models[choice - 1] - set_config('default_model', default_model["id"]) - current_name = selected_model['name'] if selected_model else "None" - console.print(f"[bold cyan]Default model set to: {default_model['name']} ({default_model['id']}). Current unchanged: {current_name}[/]") - break - console.print("[bold red]Invalid choice. Try again.[/]") - except ValueError: - console.print("[bold red]Invalid input. Enter a number.[/]") - else: - DEFAULT_MODEL_ID = get_config('default_model') - memory_status = "Enabled" if conversation_memory_enabled else "Disabled" - memory_tracked = len(session_history) - memory_start_index if conversation_memory_enabled else 0 - - mcp_status = "Enabled" if mcp_manager.enabled else "Disabled" - mcp_folders = len(mcp_manager.allowed_folders) - mcp_databases = len(mcp_manager.databases) - mcp_mode = mcp_manager.mode - gitignore_status = "enabled" if (mcp_manager.server and mcp_manager.server.respect_gitignore) else "disabled" - gitignore_patterns = len(mcp_manager.server.gitignore_parser.patterns) if mcp_manager.server else 0 - - table = Table("Setting", "Value", show_header=True, header_style="bold magenta", width=console.width - 10) - table.add_row("API Key", API_KEY or "[Not set]") - table.add_row("Base URL", OPENROUTER_BASE_URL or "[Not set]") - table.add_row("DB Path", str(database) or "[Not set]") - table.add_row("Logfile", str(log_file) or "[Not set]") - table.add_row("Log Size Limit", f"{LOG_MAX_SIZE_MB} MB") - table.add_row("Log Backups", str(LOG_BACKUP_COUNT)) - table.add_row("Log Level", LOG_LEVEL_STR.upper()) - table.add_row("Streaming", "Enabled" if STREAM_ENABLED == "on" else "Disabled") - table.add_row("Default Model", DEFAULT_MODEL_ID or "[Not set]") - table.add_row("Current Model", "[Not set]" if selected_model is None else str(selected_model["name"])) - table.add_row("Default Online Mode", "Enabled" if DEFAULT_ONLINE_MODE == "on" else "Disabled") - table.add_row("Session Online Mode", "Enabled" if online_mode_enabled else "Disabled") - table.add_row("MCP Status", f"{mcp_status} ({mcp_folders} folders, {mcp_databases} DBs)") - table.add_row("MCP Mode", f"{mcp_mode}") - table.add_row("MCP .gitignore", f"{gitignore_status} ({gitignore_patterns} patterns)") - table.add_row("Max Token", str(MAX_TOKEN)) - table.add_row("Session Token", "[Not set]" if session_max_token == 0 else str(session_max_token)) - table.add_row("Session System Prompt", session_system_prompt or "[Not set]") - table.add_row("Cost Warning Threshold", f"${COST_WARNING_THRESHOLD:.4f}") - table.add_row("Middle-out Transform", "Enabled" if middle_out_enabled else "Disabled") - table.add_row("Conversation Memory", f"{memory_status} ({memory_tracked} tracked)" if conversation_memory_enabled else memory_status) - table.add_row("History Size", str(len(session_history))) - table.add_row("Current History Index", str(current_index) if current_index >= 0 else "[None]") - table.add_row("App Name", APP_NAME) - table.add_row("App URL", APP_URL) - - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits: - table.add_row("Total Credits", credits['total_credits']) - table.add_row("Used Credits", credits['used_credits']) - table.add_row("Credits Left", credits['credits_left']) - else: - table.add_row("Total Credits", "[Unavailable - Check API key]") - table.add_row("Used Credits", "[Unavailable - Check API key]") - table.add_row("Credits Left", "[Unavailable - Check API key]") - - console.print(Panel(table, title="[bold green]Current Configurations[/]", title_align="left", subtitle="%s" %(update), subtitle_align="right")) - continue - - if user_input.lower() == "/credits": - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits: - console.print(f"[bold green]Credits left: {credits['credits_left']}[/]") - alerts = check_credit_alerts(credits) - if alerts: - for alert in alerts: - console.print(f"[bold red]โš ๏ธ {alert}[/]") - else: - console.print("[bold red]Unable to fetch credits. Check your API key or network.[/]") - continue - - if user_input.lower() == "/clear" or user_input.lower() == "/cl": - clear_screen() - DEFAULT_MODEL_ID = get_config('default_model') - token_value = session_max_token if session_max_token != 0 else "Not set" - console.print(f"[bold cyan]Token limits: Max= {MAX_TOKEN}, Session={token_value}[/]") - console.print("[bold blue]Active model[/] [bold red]%s[/]" %(str(selected_model["name"]) if selected_model else "None")) - if online_mode_enabled: - console.print("[bold cyan]Online mode: Enabled (web search active)[/]") - if mcp_manager.enabled: - gitignore_status = "on" if (mcp_manager.server and mcp_manager.server.respect_gitignore) else "off" - mode_display = "Files" if mcp_manager.mode == "files" else f"DB #{mcp_manager.selected_db_index + 1}" - console.print(f"[bold cyan]MCP: Enabled (Mode: {mode_display}, {len(mcp_manager.allowed_folders)} folders, {len(mcp_manager.databases)} DBs, .gitignore {gitignore_status})[/]") - continue - - if user_input.lower().startswith("/help"): - args = user_input[6:].strip() - - if args: - show_command_help(args) - continue - - help_table = Table("Command", "Description", "Example", show_header=True, header_style="bold cyan", width=console.width - 10) - - # SESSION COMMANDS - help_table.add_row( - "[bold yellow]โ”โ”โ” SESSION COMMANDS โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "/clear or /cl", - "Clear terminal screen. Keyboard shortcut: [bold]Ctrl+L[/]", - "/clear\n/cl" - ) - help_table.add_row( - "/help [command|topic]", - "Show help menu or detailed help. Use '/help mcp' for MCP guide.", - "/help\n/help mcp" - ) - help_table.add_row( - "/memory [on|off]", - "Toggle conversation memory. ON sends history, OFF is stateless (saves cost).", - "/memory\n/memory off" - ) - help_table.add_row( - "/next", - "View next response in history.", - "/next" - ) - help_table.add_row( - "/online [on|off]", - "Enable/disable online mode (web search) for session.", - "/online on\n/online off" - ) - help_table.add_row( - "/paste [prompt]", - "Paste clipboard text/code. Optional prompt.", - "/paste\n/paste Explain" - ) - help_table.add_row( - "/prev", - "View previous response in history.", - "/prev" - ) - help_table.add_row( - "/reset", - "Clear history and reset system prompt (requires confirmation).", - "/reset" - ) - help_table.add_row( - "/retry", - "Resend last prompt.", - "/retry" - ) - - # MCP COMMANDS - help_table.add_row( - "[bold yellow]โ”โ”โ” MCP (FILE & DATABASE ACCESS) โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "/mcp on", - "Start MCP server for file and database access.", - "/mcp on" - ) - help_table.add_row( - "/mcp off", - "Stop MCP server.", - "/mcp off" - ) - help_table.add_row( - "/mcp status", - "Show mode, stats, folders, databases, and .gitignore status.", - "/mcp status" - ) - help_table.add_row( - "/mcp add ", - "Add folder for file access (auto-loads .gitignore).", - "/mcp add ~/Documents" - ) - help_table.add_row( - "/mcp add db ", - "Add SQLite database for querying.", - "/mcp add db ~/app.db" - ) - help_table.add_row( - "/mcp list", - "List all allowed folders with stats.", - "/mcp list" - ) - help_table.add_row( - "/mcp db list", - "List all databases with details.", - "/mcp db list" - ) - help_table.add_row( - "/mcp db ", - "Switch to database mode (select DB by number).", - "/mcp db 1" - ) - help_table.add_row( - "/mcp files", - "Switch to file mode (default).", - "/mcp files" - ) - help_table.add_row( - "/mcp remove ", - "Remove folder by path or number.", - "/mcp remove 2" - ) - help_table.add_row( - "/mcp remove db ", - "Remove database by number.", - "/mcp remove db 1" - ) - help_table.add_row( - "/mcp gitignore [on|off]", - "Toggle .gitignore filtering (respects project excludes).", - "/mcp gitignore on" - ) - help_table.add_row( - "/mcp write [on|off]", - "Enable/disable write mode (create, edit, delete files). Requires confirmation.", - "/mcp write on" - ) - - # MODEL COMMANDS - help_table.add_row( - "[bold yellow]โ”โ”โ” MODEL COMMANDS โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "/info [model_id]", - "Display model details (pricing, capabilities, context).", - "/info\n/info gpt-4o" - ) - help_table.add_row( - "/model [search]", - "Select/change model. Shows image and online capabilities.", - "/model\n/model gpt" - ) - - # CONFIGURATION - help_table.add_row( - "[bold yellow]โ”โ”โ” CONFIGURATION โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "/config", - "View all configurations.", - "/config" - ) - help_table.add_row( - "/config api", - "Set/update API key.", - "/config api" - ) - help_table.add_row( - "/config costwarning [val]", - "Set cost warning threshold (USD).", - "/config costwarning 0.05" - ) - help_table.add_row( - "/config log [size_mb]", - "Set log file size limit. Takes effect immediately.", - "/config log 20" - ) - help_table.add_row( - "/config loglevel [level]", - "Set log verbosity (debug/info/warning/error/critical).", - "/config loglevel debug" - ) - help_table.add_row( - "/config maxtoken [val]", - "Set stored max token limit.", - "/config maxtoken 50000" - ) - help_table.add_row( - "/config model [search]", - "Set default startup model.", - "/config model gpt" - ) - help_table.add_row( - "/config online [on|off]", - "Set default online mode for new models.", - "/config online on" - ) - help_table.add_row( - "/config stream [on|off]", - "Enable/disable response streaming.", - "/config stream off" - ) - help_table.add_row( - "/config url", - "Set/update base URL.", - "/config url" - ) - - # TOKEN & SYSTEM - help_table.add_row( - "[bold yellow]โ”โ”โ” TOKEN & SYSTEM โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "/maxtoken [value]", - "Set temporary session token limit.", - "/maxtoken 2000" - ) - help_table.add_row( - "/middleout [on|off]", - "Enable/disable prompt compression for large contexts.", - "/middleout on" - ) - help_table.add_row( - "/system [prompt|clear]", - "Set session system prompt. Use 'clear' to reset.", - "/system You are a Python expert" - ) - - # CONVERSATION MGMT - help_table.add_row( - "[bold yellow]โ”โ”โ” CONVERSATION MGMT โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "/delete ", - "Delete saved conversation (requires confirmation).", - "/delete my_chat\n/delete 3" - ) - help_table.add_row( - "/export ", - "Export conversation (md/json/html).", - "/export md notes.md" - ) - help_table.add_row( - "/list", - "List saved conversations with numbers and stats.", - "/list" - ) - help_table.add_row( - "/load ", - "Load saved conversation.", - "/load my_chat\n/load 3" - ) - help_table.add_row( - "/save ", - "Save current conversation.", - "/save my_chat" - ) - - # MONITORING - help_table.add_row( - "[bold yellow]โ”โ”โ” MONITORING & STATS โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "/credits", - "Display OpenRouter credits with alerts.", - "/credits" - ) - help_table.add_row( - "/stats", - "Display session stats: tokens, cost, credits.", - "/stats" - ) - - # INPUT METHODS - help_table.add_row( - "[bold yellow]โ”โ”โ” INPUT METHODS โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "@/path/to/file", - "Attach files: images (PNG, JPG), PDFs, code files.", - "Debug @script.py\nAnalyze @image.png" - ) - help_table.add_row( - "Clipboard paste", - "Use /paste to send clipboard content.", - "/paste\n/paste Explain" - ) - help_table.add_row( - "// escape", - "Start with // to send literal / character.", - "//help sends '/help' as text" - ) - - # EXIT - help_table.add_row( - "[bold yellow]โ”โ”โ” EXIT โ”โ”โ”[/]", - "", - "" - ) - help_table.add_row( - "exit | quit | bye", - "Quit with session summary.", - "exit" - ) - - console.print(Panel( - help_table, - title="[bold cyan]oAI Chat Help (Version %s)[/]" % version, - title_align="center", - subtitle="๐Ÿ’ก Commands are case-insensitive โ€ข Use /help for details โ€ข Use /help mcp for MCP guide โ€ข Memory ON by default โ€ข MCP has file & DB modes โ€ข Use // to escape / โ€ข Visit: https://iurl.no/oai", - subtitle_align="center", - border_style="cyan" - )) - continue - - if not selected_model: - console.print("[bold yellow]Select a model first with '/model'.[/]") - continue - - # ============================================================ - # PROCESS FILE ATTACHMENTS - # ============================================================ - text_part = user_input - file_attachments = [] - content_blocks = [] - - # Smart file detection: Simple pattern + extension validation - # This avoids extremely long regex that can cause binary signing issues - - # Common file extensions we support - ALLOWED_FILE_EXTENSIONS = { - # Code - '.py', '.js', '.ts', '.jsx', '.tsx', '.vue', '.java', '.c', '.cpp', '.cc', '.cxx', - '.h', '.hpp', '.hxx', '.rb', '.go', '.rs', '.swift', '.kt', '.kts', '.php', - '.sh', '.bash', '.zsh', '.fish', '.bat', '.cmd', '.ps1', - # Data - '.json', '.csv', '.yaml', '.yml', '.toml', '.xml', '.sql', '.db', '.sqlite', '.sqlite3', - # Documents - '.txt', '.md', '.log', '.conf', '.cfg', '.ini', '.env', '.properties', - # Images - '.png', '.jpg', '.jpeg', '.gif', '.bmp', '.webp', '.svg', '.ico', - # Archives - '.zip', '.tar', '.gz', '.bz2', '.7z', '.rar', '.xz', - # Config files - '.lock', '.gitignore', '.dockerignore', '.editorconfig', '.eslintrc', - '.prettierrc', '.babelrc', '.nvmrc', '.npmrc', - # Binary/Compiled - '.pyc', '.pyo', '.pyd', '.so', '.dll', '.dylib', '.exe', '.app', - '.dmg', '.pkg', '.deb', '.rpm', '.apk', '.ipa', - # ML/AI - '.pkl', '.pickle', '.joblib', '.npy', '.npz', '.safetensors', '.onnx', - '.pt', '.pth', '.ckpt', '.pb', '.tflite', '.mlmodel', '.coreml', '.rknn', - # Data formats - '.wasm', '.proto', '.graphql', '.graphqls', '.grpc', '.avro', '.parquet', - '.orc', '.feather', '.arrow', '.hdf5', '.h5', '.mat', '.r', '.rdata', '.rds', - # Other - '.pdf', '.class', '.jar', '.war' - } - - # Simple pattern: @filepath where filepath starts with ~, /, . or has an extension - # Much shorter than listing all extensions in the regex - file_pattern = r'(?:^|\s)@([~/\.][\S]+|[\w][\w-]*\.[\w]+)(?=\s|$)' - - for match in re.finditer(file_pattern, user_input, re.IGNORECASE): - file_path = match.group(1) - - # Get the extension - file_ext = os.path.splitext(file_path)[1].lower() - - # Skip if it doesn't start with a path char and has no allowed extension - if not file_path.startswith(('/', '~', '.', '\\')): - if not file_ext or file_ext not in ALLOWED_FILE_EXTENSIONS: - # This looks like a domain name, not a file - continue - - expanded_path = os.path.expanduser(os.path.abspath(file_path)) - - if not os.path.exists(expanded_path) or os.path.isdir(expanded_path): - console.print(f"[bold red]File not found or is a directory: {expanded_path}[/]") - continue - - file_size = os.path.getsize(expanded_path) - if file_size > 10 * 1024 * 1024: - console.print(f"[bold red]File too large (>10MB): {expanded_path}[/]") - continue - - mime_type, _ = mimetypes.guess_type(expanded_path) - - try: - with open(expanded_path, 'rb') as f: - file_data = f.read() - - if mime_type and mime_type.startswith('image/'): - modalities = selected_model.get("architecture", {}).get("input_modalities", []) - if "image" not in modalities: - console.print("[bold red]Selected model does not support image attachments.[/]") - console.print(f"[dim yellow]Supported modalities: {', '.join(modalities) if modalities else 'text only'}[/]") - continue - b64_data = base64.b64encode(file_data).decode('utf-8') - content_blocks.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_data}"}}) - console.print(f"[dim green]โœ“ Image attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - elif mime_type == 'application/pdf' or file_ext == '.pdf': - modalities = selected_model.get("architecture", {}).get("input_modalities", []) - supports_pdf = any(mod in modalities for mod in ["document", "pdf", "file"]) - if not supports_pdf: - console.print("[bold red]Selected model does not support PDF attachments.[/]") - console.print(f"[dim yellow]Supported modalities: {', '.join(modalities) if modalities else 'text only'}[/]") - continue - b64_data = base64.b64encode(file_data).decode('utf-8') - content_blocks.append({"type": "image_url", "image_url": {"url": f"data:application/pdf;base64,{b64_data}"}}) - console.print(f"[dim green]โœ“ PDF attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - elif (mime_type == 'text/plain' or file_ext in SUPPORTED_CODE_EXTENSIONS): - text_content = file_data.decode('utf-8') - content_blocks.append({"type": "text", "text": f"Code File: {os.path.basename(expanded_path)}\n\n{text_content}"}) - console.print(f"[dim green]โœ“ Code file attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - else: - console.print(f"[bold red]Unsupported file type ({mime_type}) for {expanded_path}.[/]") - console.print("[bold yellow]Supported types: images (PNG, JPG, etc.), PDFs, and code files (.py, .js, etc.)[/]") - continue - - file_attachments.append(file_path) - app_logger.info(f"File attached: {os.path.basename(expanded_path)}, Type: {mime_type or file_ext}, Size: {file_size / 1024:.1f} KB") - except UnicodeDecodeError: - console.print(f"[bold red]Cannot decode {expanded_path} as UTF-8. File may be binary or use unsupported encoding.[/]") - app_logger.error(f"UTF-8 decode error for {expanded_path}") - continue - except Exception as e: - console.print(f"[bold red]Error reading file {expanded_path}: {e}[/]") - app_logger.error(f"File read error for {expanded_path}: {e}") - continue - - # Remove file attachments from text (use same simple pattern) - text_part = re.sub(file_pattern, lambda m: m.group(0)[0] if m.group(0)[0].isspace() else '', text_part).strip() - - # Build message content - if text_part or content_blocks: - message_content = [] - if text_part: - message_content.append({"type": "text", "text": text_part}) - message_content.extend(content_blocks) - else: - console.print("[bold red]Prompt cannot be empty.[/]") - continue - - - # ============================================================ - # BUILD API MESSAGES - # ============================================================ - api_messages = [] - - if session_system_prompt: - api_messages.append({"role": "system", "content": session_system_prompt}) - - # Add database context if in database mode - if mcp_manager.enabled and mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: - db = mcp_manager.databases[mcp_manager.selected_db_index] - db_context = f"""You are currently connected to SQLite database: {db['name']} -Available tables: {', '.join(db['tables'])} - -You can: -- Inspect the database schema with inspect_database -- Search for data across tables with search_database -- Execute read-only SQL queries with query_database - -All queries are read-only. INSERT/UPDATE/DELETE are not allowed.""" - api_messages.append({"role": "system", "content": db_context}) - - if conversation_memory_enabled: - for i in range(memory_start_index, len(session_history)): - history_entry = session_history[i] - api_messages.append({ - "role": "user", - "content": history_entry['prompt'] - }) - api_messages.append({ - "role": "assistant", - "content": history_entry['response'] - }) - - api_messages.append({"role": "user", "content": message_content}) - - # Get effective model ID - effective_model_id = get_effective_model_id(selected_model["id"], online_mode_enabled) - - # ======================================================================== - # BUILD API PARAMS WITH MCP TOOLS SUPPORT - # ======================================================================== - api_params = { - "model": effective_model_id, - "messages": api_messages, - "stream": STREAM_ENABLED == "on", - "http_headers": { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - } - - # Add MCP tools if enabled - mcp_tools_added = False - mcp_tools = [] - if mcp_manager.enabled: - if supports_function_calling(selected_model): - mcp_tools = mcp_manager.get_tools_schema() - if mcp_tools: - api_params["tools"] = mcp_tools - api_params["tool_choice"] = "auto" - mcp_tools_added = True - - # IMPORTANT: Disable streaming if MCP tools are present - if api_params["stream"]: - api_params["stream"] = False - app_logger.info("Disabled streaming due to MCP tool calls") - - app_logger.info(f"Added {len(mcp_tools)} MCP tools to request (mode: {mcp_manager.mode})") - else: - app_logger.debug(f"Model {selected_model['id']} doesn't support function calling - MCP tools not added") - - if session_max_token > 0: - api_params["max_tokens"] = session_max_token - if middle_out_enabled: - api_params["transforms"] = ["middle-out"] - - # Log API request - file_count = len(file_attachments) - history_messages_count = len(session_history) - memory_start_index if conversation_memory_enabled else 0 - memory_status = "ON" if conversation_memory_enabled else "OFF" - online_status = "ON" if online_mode_enabled else "OFF" - - if mcp_tools_added: - if mcp_manager.mode == "files": - mcp_status = f"ON (Files, {len(mcp_manager.allowed_folders)} folders, {len(mcp_tools)} tools)" - elif mcp_manager.mode == "database": - db = mcp_manager.databases[mcp_manager.selected_db_index] - mcp_status = f"ON (DB: {db['name']}, {len(mcp_tools)} tools)" - else: - mcp_status = f"ON ({len(mcp_tools)} tools)" - else: - mcp_status = "OFF" - - app_logger.info(f"API Request: Model '{effective_model_id}' (Online: {online_status}, MCP: {mcp_status}), Prompt length: {len(text_part)} chars, {file_count} file(s) attached, Memory: {memory_status}, History sent: {history_messages_count} messages.") - - # Send request - is_streaming = api_params["stream"] - if is_streaming: - console.print("[bold green]Streaming response...[/] [dim](Press Ctrl+C to cancel)[/]") - if online_mode_enabled: - console.print("[dim cyan]๐ŸŒ Online mode active - model has web search access[/]") - console.print("") - else: - thinking_msg = "Thinking" - if mcp_tools_added: - if mcp_manager.mode == "database": - thinking_msg = "Thinking (database tools available)" - else: - thinking_msg = "Thinking (MCP tools available)" - console.print(f"[bold green]{thinking_msg}...[/]", end="\r") - if online_mode_enabled: - console.print(f"[dim cyan]๐ŸŒ Online mode active[/]") - if mcp_tools_added: - if mcp_manager.mode == "files": - gitignore_status = "on" if (mcp_manager.server and mcp_manager.server.respect_gitignore) else "off" - console.print(f"[dim cyan]๐Ÿ”ง MCP active - AI can access {len(mcp_manager.allowed_folders)} folder(s) with {len(mcp_tools)} tools (.gitignore {gitignore_status})[/]") - elif mcp_manager.mode == "database": - db = mcp_manager.databases[mcp_manager.selected_db_index] - console.print(f"[dim cyan]๐Ÿ—„๏ธ MCP active - AI can query database: {db['name']} ({len(db['tables'])} tables, {len(mcp_tools)} tools)[/]") - - start_time = time.time() - try: - response = client.chat.send(**api_params) - app_logger.info(f"API call successful for model '{effective_model_id}'") - except Exception as e: - console.print(f"[bold red]Error sending request: {e}[/]") - app_logger.error(f"API Error: {type(e).__name__}: {e}") - continue - - # ======================================================================== - # HANDLE TOOL CALLS (MCP FUNCTION CALLING) - # ======================================================================== - tool_call_loop_count = 0 - max_tool_loops = 5 - - while tool_call_loop_count < max_tool_loops: - wants_tool_call = False - - if hasattr(response, 'choices') and response.choices: - message = response.choices[0].message - - if hasattr(message, 'tool_calls') and message.tool_calls: - wants_tool_call = True - tool_calls = message.tool_calls - - console.print(f"\n[dim yellow]๐Ÿ”ง AI requesting {len(tool_calls)} tool call(s)...[/]") - app_logger.info(f"Model requested {len(tool_calls)} tool calls") - - tool_results = [] - for tool_call in tool_calls: - tool_name = tool_call.function.name - - try: - tool_args = json.loads(tool_call.function.arguments) - except json.JSONDecodeError as e: - app_logger.error(f"Failed to parse tool arguments: {e}") - tool_results.append({ - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_name, - "content": json.dumps({"error": f"Invalid arguments: {e}"}) - }) - continue - - args_display = ', '.join(f'{k}="{v}"' if isinstance(v, str) else f'{k}={v}' for k, v in tool_args.items()) - console.print(f"[dim cyan] โ†’ Calling: {tool_name}({args_display})[/]") - app_logger.info(f"Executing MCP tool: {tool_name} with args: {tool_args}") - - try: - result = asyncio.run(mcp_manager.call_tool(tool_name, **tool_args)) - except Exception as e: - app_logger.error(f"MCP tool execution error: {e}") - result = {"error": str(e)} - - if 'error' in result: - result_content = json.dumps({"error": result['error']}) - console.print(f"[dim red] โœ— Error: {result['error']}[/]") - app_logger.warning(f"MCP tool {tool_name} returned error: {result['error']}") - else: - # Display appropriate success message based on tool - if tool_name == 'search_files': - count = result.get('count', 0) - console.print(f"[dim green] โœ“ Found {count} file(s)[/]") - elif tool_name == 'read_file': - size = result.get('size', 0) - truncated = " (truncated)" if result.get('truncated') else "" - console.print(f"[dim green] โœ“ Read file ({size} bytes{truncated})[/]") - elif tool_name == 'list_directory': - count = result.get('count', 0) - truncated = " (limited to 1000)" if result.get('truncated') else "" - console.print(f"[dim green] โœ“ Listed {count} item(s){truncated}[/]") - elif tool_name == 'inspect_database': - if 'table' in result: - console.print(f"[dim green] โœ“ Inspected table: {result['table']} ({result['row_count']} rows)[/]") - else: - console.print(f"[dim green] โœ“ Inspected database ({result['table_count']} tables)[/]") - elif tool_name == 'search_database': - count = result.get('count', 0) - console.print(f"[dim green] โœ“ Found {count} match(es)[/]") - elif tool_name == 'query_database': - count = result.get('count', 0) - truncated = " (truncated)" if result.get('truncated') else "" - console.print(f"[dim green] โœ“ Query returned {count} row(s){truncated}[/]") - else: - console.print(f"[dim green] โœ“ Success[/]") - - result_content = json.dumps(result, indent=2) - app_logger.info(f"MCP tool {tool_name} succeeded") - - tool_results.append({ - "tool_call_id": tool_call.id, - "role": "tool", - "name": tool_name, - "content": result_content - }) - - api_messages.append({ - "role": "assistant", - "content": message.content, - "tool_calls": [ - { - "id": tc.id, - "type": "function", - "function": { - "name": tc.function.name, - "arguments": tc.function.arguments - } - } for tc in tool_calls - ] - }) - api_messages.extend(tool_results) - - console.print("\n[dim cyan]๐Ÿ’ญ Processing tool results...[/]") - app_logger.info("Sending tool results back to model") - - followup_params = { - "model": effective_model_id, - "messages": api_messages, - "stream": False, # Keep streaming disabled for follow-ups - "http_headers": { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - } - - if mcp_tools_added: - followup_params["tools"] = mcp_tools - followup_params["tool_choice"] = "auto" - - if session_max_token > 0: - followup_params["max_tokens"] = session_max_token - - try: - response = client.chat.send(**followup_params) - tool_call_loop_count += 1 - app_logger.info(f"Follow-up request successful (loop {tool_call_loop_count})") - except Exception as e: - console.print(f"[bold red]Error getting follow-up response: {e}[/]") - app_logger.error(f"Follow-up API error: {e}") - break - - if not wants_tool_call: - break - - if tool_call_loop_count >= max_tool_loops: - console.print(f"[bold yellow]โš ๏ธ Reached maximum tool call depth ({max_tool_loops}). Stopping.[/]") - app_logger.warning(f"Hit max tool call loop limit: {max_tool_loops}") - - response_time = time.time() - start_time - - # ======================================================================== - # PROCESS FINAL RESPONSE - # ======================================================================== - full_response = "" - stream_interrupted = False - - if is_streaming: - # Store the last chunk to get usage data after streaming completes - last_chunk = None - try: - with Live("", console=console, refresh_per_second=10, auto_refresh=True) as live: - try: - for chunk in response: - last_chunk = chunk # Keep track of last chunk for usage data - if hasattr(chunk, 'error') and chunk.error: - console.print(f"\n[bold red]Stream error: {chunk.error.message}[/]") - app_logger.error(f"Stream error: {chunk.error.message}") - break - if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: - content_chunk = chunk.choices[0].delta.content - full_response += content_chunk - md = Markdown(full_response) - live.update(md) - except KeyboardInterrupt: - stream_interrupted = True - console.print("\n[bold yellow]โš ๏ธ Streaming interrupted![/]") - app_logger.info("Streaming interrupted by user (Ctrl+C)") - except Exception as stream_error: - stream_interrupted = True - console.print(f"\n[bold red]Stream error: {stream_error}[/]") - app_logger.error(f"Stream processing error: {stream_error}") - - if not stream_interrupted: - console.print("") - - except KeyboardInterrupt: - # Outer interrupt handler (in case inner misses it) - stream_interrupted = True - console.print("\n[bold yellow]โš ๏ธ Streaming interrupted![/]") - app_logger.info("Streaming interrupted by user (outer)") - except Exception as e: - stream_interrupted = True - console.print(f"\n[bold red]Error during streaming: {e}[/]") - app_logger.error(f"Streaming error: {e}") - - # If stream was interrupted, skip processing and continue to next prompt - if stream_interrupted: - if full_response: - console.print(f"\n[dim yellow]Partial response received ({len(full_response)} chars). Discarding...[/]") - console.print("[dim blue]๐Ÿ’ก Ready for next prompt[/]\n") - app_logger.info("Stream cleanup completed, returning to prompt") - - # Force close the response if possible - try: - if hasattr(response, 'close'): - response.close() - elif hasattr(response, '__exit__'): - response.__exit__(None, None, None) - except: - pass - - # Recreate client to be safe - try: - client = OpenRouter(api_key=API_KEY) - app_logger.info("Client recreated after stream interruption") - except: - pass - - continue # Now it's safe to continue - - # For streaming, try to get usage from last chunk - if last_chunk and hasattr(last_chunk, 'usage'): - response.usage = last_chunk.usage - app_logger.debug("Extracted usage data from last streaming chunk") - elif last_chunk: - app_logger.warning("Last streaming chunk has no usage data") - - else: - full_response = response.choices[0].message.content if response.choices else "" - # Clear any processing messages before showing response - console.print(f"\r{' ' * 100}\r", end="") - - if full_response: - if not is_streaming: - md = Markdown(full_response) - # Add newline before Panel to ensure clean rendering - console.print() - console.print(Panel(md, title="[bold green]AI Response[/]", title_align="left", border_style="green")) - - # Extract usage data BEFORE appending to history - usage = getattr(response, 'usage', None) - - # DEBUG: Log what OpenRouter actually returns - if usage: - app_logger.debug(f"Usage object type: {type(usage)}") - app_logger.debug(f"Usage attributes: {dir(usage)}") - if hasattr(usage, '__dict__'): - app_logger.debug(f"Usage dict: {usage.__dict__}") - else: - app_logger.warning("No usage object in response!") - - # Try both attribute naming conventions (OpenAI standard vs Anthropic) - input_tokens = 0 - output_tokens = 0 - - if usage: - # Try prompt_tokens/completion_tokens (OpenAI/OpenRouter standard) - if hasattr(usage, 'prompt_tokens'): - input_tokens = usage.prompt_tokens or 0 - elif hasattr(usage, 'input_tokens'): - input_tokens = usage.input_tokens or 0 - - if hasattr(usage, 'completion_tokens'): - output_tokens = usage.completion_tokens or 0 - elif hasattr(usage, 'output_tokens'): - output_tokens = usage.output_tokens or 0 - - app_logger.debug(f"Extracted tokens: input={input_tokens}, output={output_tokens}") - - # Get cost from API or estimate - msg_cost = 0.0 - if usage and hasattr(usage, 'total_cost_usd') and usage.total_cost_usd: - msg_cost = float(usage.total_cost_usd) - app_logger.debug(f"Using API cost: ${msg_cost:.6f}") - else: - msg_cost = estimate_cost(input_tokens, output_tokens) - app_logger.debug(f"Estimated cost: ${msg_cost:.6f} (from {input_tokens} input + {output_tokens} output tokens)") - - # NOW append to history with cost data - session_history.append({ - 'prompt': user_input, - 'response': full_response, - 'msg_cost': msg_cost, - 'prompt_tokens': input_tokens, - 'completion_tokens': output_tokens - }) - current_index = len(session_history) - 1 - - total_input_tokens += input_tokens - total_output_tokens += output_tokens - total_cost += msg_cost - message_count += 1 - - app_logger.info(f"Response: Tokens - I:{input_tokens} O:{output_tokens} T:{input_tokens + output_tokens}, Cost: ${msg_cost:.4f}, Time: {response_time:.2f}s, Tool loops: {tool_call_loop_count}") - - if conversation_memory_enabled: - context_count = len(session_history) - memory_start_index - context_info = f", Context: {context_count} msg(s)" if context_count > 1 else "" - else: - context_info = ", Memory: OFF" - - online_info = " ๐ŸŒ" if online_mode_enabled else "" - mcp_info = "" - if mcp_tools_added: - if mcp_manager.mode == "files": - mcp_info = " ๐Ÿ”ง" - elif mcp_manager.mode == "database": - mcp_info = " ๐Ÿ—„๏ธ" - tool_info = f" ({tool_call_loop_count} tool loop(s))" if tool_call_loop_count > 0 else "" - console.print(f"\n[dim blue]๐Ÿ“Š Metrics: {input_tokens + output_tokens} tokens | ${msg_cost:.4f} | {response_time:.2f}s{context_info}{online_info}{mcp_info}{tool_info} | Session: {total_input_tokens + total_output_tokens} tokens | ${total_cost:.4f}[/]") - - warnings = [] - if msg_cost > COST_WARNING_THRESHOLD: - warnings.append(f"High cost alert: This response exceeded threshold ${COST_WARNING_THRESHOLD:.4f}") - credits_data = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits_data: - warning_alerts = check_credit_alerts(credits_data) - warnings.extend(warning_alerts) - if warnings: - warning_text = ' | '.join(warnings) - console.print(f"[bold red]โš ๏ธ {warning_text}[/]") - app_logger.warning(f"Warnings triggered: {warning_text}") - - console.print("") - try: - copy_choice = input("๐Ÿ’พ Type 'c' to copy response, or press Enter to continue: ").strip().lower() - if copy_choice == "c": - pyperclip.copy(full_response) - console.print("[bold green]โœ… Response copied to clipboard![/]") - except (EOFError, KeyboardInterrupt): - pass - - console.print("") - else: - console.print("[bold red]No response received.[/]") - app_logger.error("No response from API") - - except KeyboardInterrupt: - console.print("\n[bold yellow]Input interrupted. Continuing...[/]") - app_logger.warning("Input interrupted by Ctrl+C") - continue - except EOFError: - console.print("\n[bold yellow]Goodbye![/]") - total_tokens = total_input_tokens + total_output_tokens - app_logger.info(f"Session ended via EOF. Total messages: {message_count}, Total tokens: {total_tokens}, Total cost: ${total_cost:.4f}") - return - except Exception as e: - console.print(f"[bold red]Error: {e}[/]") - console.print("[bold yellow]Try again or select a model.[/]") - app_logger.error(f"Unexpected error: {type(e).__name__}: {e}") - -if __name__ == "__main__": - clear_screen() - app() \ No newline at end of file diff --git a/oai/__init__.py b/oai/__init__.py new file mode 100644 index 0000000..77da9b0 --- /dev/null +++ b/oai/__init__.py @@ -0,0 +1,26 @@ +""" +oAI - OpenRouter AI Chat Client + +A feature-rich terminal-based chat application that provides an interactive CLI +interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP) +integration for filesystem and database access. + +Author: Rune +License: MIT +""" + +__version__ = "2.1.0" +__author__ = "Rune" +__license__ = "MIT" + +# Lazy imports to avoid circular dependencies and improve startup time +# Full imports are available via submodules: +# from oai.config import Settings, Database +# from oai.providers import OpenRouterProvider, AIProvider +# from oai.mcp import MCPManager + +__all__ = [ + "__version__", + "__author__", + "__license__", +] diff --git a/oai/__main__.py b/oai/__main__.py new file mode 100644 index 0000000..bd02e03 --- /dev/null +++ b/oai/__main__.py @@ -0,0 +1,8 @@ +""" +Entry point for running oAI as a module: python -m oai +""" + +from oai.cli import main + +if __name__ == "__main__": + main() diff --git a/oai/cli.py b/oai/cli.py new file mode 100644 index 0000000..0405885 --- /dev/null +++ b/oai/cli.py @@ -0,0 +1,719 @@ +""" +Main CLI entry point for oAI. + +This module provides the command-line interface for the oAI chat application. +""" + +import os +import sys +from pathlib import Path +from typing import Optional + +import typer +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.history import FileHistory +from rich.markdown import Markdown +from rich.panel import Panel + +from oai import __version__ +from oai.commands import register_all_commands, registry +from oai.commands.registry import CommandContext, CommandStatus +from oai.config.database import Database +from oai.config.settings import Settings +from oai.constants import ( + APP_NAME, + APP_URL, + APP_VERSION, + CONFIG_DIR, + HISTORY_FILE, + VALID_COMMANDS, +) +from oai.core.client import AIClient +from oai.core.session import ChatSession +from oai.mcp.manager import MCPManager +from oai.providers.base import UsageStats +from oai.providers.openrouter import OpenRouterProvider +from oai.ui.console import ( + clear_screen, + console, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.tables import create_model_table, display_paginated_table +from oai.utils.logging import LoggingManager, get_logger + +# Create Typer app +app = typer.Typer( + name="oai", + help=f"oAI - OpenRouter AI Chat Client\n\nVersion: {APP_VERSION}", + add_completion=False, + epilog="For more information, visit: " + APP_URL, +) + + +@app.callback(invoke_without_command=True) +def main_callback( + ctx: typer.Context, + version_flag: bool = typer.Option( + False, + "--version", + "-v", + help="Show version information", + is_flag=True, + ), +) -> None: + """Main callback to handle global options.""" + # Show version with update check if --version flag + if version_flag: + version_info = check_for_updates(APP_VERSION) + console.print(version_info) + raise typer.Exit() + + # Show version with update check when --help is requested + if "--help" in sys.argv or "-h" in sys.argv: + version_info = check_for_updates(APP_VERSION) + console.print(f"\n{version_info}\n") + + # Continue to subcommand if provided + if ctx.invoked_subcommand is None: + return + + +def check_for_updates(current_version: str) -> str: + """Check for available updates.""" + import requests + from packaging import version as pkg_version + + try: + response = requests.get( + "https://gitlab.pm/api/v1/repos/rune/oai/releases/latest", + headers={"Content-Type": "application/json"}, + timeout=1.0, + ) + response.raise_for_status() + + data = response.json() + version_online = data.get("tag_name", "").lstrip("v") + + if not version_online: + return f"[bold green]oAI version {current_version}[/]" + + current = pkg_version.parse(current_version) + latest = pkg_version.parse(version_online) + + if latest > current: + return ( + f"[bold green]oAI version {current_version}[/] " + f"[bold red](Update available: {current_version} โ†’ {version_online})[/]" + ) + return f"[bold green]oAI version {current_version} (up to date)[/]" + + except Exception: + return f"[bold green]oAI version {current_version}[/]" + + +def show_welcome(settings: Settings, version_info: str) -> None: + """Display welcome message.""" + console.print(Panel.fit( + f"{version_info}\n\n" + "[bold cyan]Commands:[/] /help for commands, /model to select model\n" + "[bold cyan]MCP:[/] /mcp on to enable file/database access\n" + "[bold cyan]Exit:[/] Type 'exit', 'quit', or 'bye'", + title=f"[bold green]Welcome to {APP_NAME}[/]", + border_style="green", + )) + + +def select_model(client: AIClient, search_term: Optional[str] = None) -> Optional[dict]: + """Display model selection interface.""" + try: + models = client.provider.get_raw_models() + if not models: + print_error("No models available") + return None + + # Filter by search term if provided + if search_term: + search_lower = search_term.lower() + models = [m for m in models if search_lower in m.get("id", "").lower()] + + if not models: + print_error(f"No models found matching '{search_term}'") + return None + + # Create and display table + table = create_model_table(models) + display_paginated_table( + table, + f"[bold green]Available Models ({len(models)})[/]", + ) + + # Prompt for selection + console.print("") + try: + choice = input("Enter model number (or press Enter to cancel): ").strip() + except (EOFError, KeyboardInterrupt): + return None + + if not choice: + return None + + try: + index = int(choice) - 1 + if 0 <= index < len(models): + selected = models[index] + print_success(f"Selected model: {selected['id']}") + return selected + except ValueError: + pass + + print_error("Invalid selection") + return None + + except Exception as e: + print_error(f"Failed to fetch models: {e}") + return None + + +def run_chat_loop( + session: ChatSession, + prompt_session: PromptSession, + settings: Settings, +) -> None: + """Run the main chat loop.""" + logger = get_logger() + mcp_manager = session.mcp_manager + + while True: + try: + # Build prompt prefix + prefix = "You> " + if mcp_manager and mcp_manager.enabled: + if mcp_manager.mode == "files": + if mcp_manager.write_enabled: + prefix = "[๐Ÿ”งโœ๏ธ MCP: Files+Write] You> " + else: + prefix = "[๐Ÿ”ง MCP: Files] You> " + elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: + prefix = f"[๐Ÿ—„๏ธ MCP: DB #{mcp_manager.selected_db_index + 1}] You> " + + # Get user input + user_input = prompt_session.prompt( + prefix, + auto_suggest=AutoSuggestFromHistory(), + ).strip() + + if not user_input: + continue + + # Handle escape sequence + if user_input.startswith("//"): + user_input = user_input[1:] + + # Check for exit + if user_input.lower() in ["exit", "quit", "bye"]: + console.print( + f"\n[bold yellow]Goodbye![/]\n" + f"[dim]Session: {session.stats.total_tokens:,} tokens, " + f"${session.stats.total_cost:.4f}[/]" + ) + logger.info( + f"Session ended. Messages: {session.stats.message_count}, " + f"Tokens: {session.stats.total_tokens}, " + f"Cost: ${session.stats.total_cost:.4f}" + ) + return + + # Check for unknown commands + if user_input.startswith("/"): + cmd_word = user_input.split()[0].lower() + if not registry.is_command(user_input): + # Check if it's a valid command prefix + is_valid = any(cmd_word.startswith(cmd) for cmd in VALID_COMMANDS) + if not is_valid: + print_error(f"Unknown command: {cmd_word}") + print_info("Type /help to see available commands.") + continue + + # Try to execute as command + context = session.get_context() + result = registry.execute(user_input, context) + + if result: + # Update session state from context + session.memory_enabled = context.memory_enabled + session.memory_start_index = context.memory_start_index + session.online_enabled = context.online_enabled + session.middle_out_enabled = context.middle_out_enabled + session.session_max_token = context.session_max_token + session.current_index = context.current_index + session.system_prompt = context.session_system_prompt + + if result.status == CommandStatus.EXIT: + return + + # Handle special results + if result.data: + # Retry - resend last prompt + if "retry_prompt" in result.data: + user_input = result.data["retry_prompt"] + # Fall through to send message + + # Paste - send clipboard content + elif "paste_prompt" in result.data: + user_input = result.data["paste_prompt"] + # Fall through to send message + + # Model selection + elif "show_model_selector" in result.data: + search = result.data.get("search", "") + model = select_model(session.client, search if search else None) + if model: + session.set_model(model) + # If this came from /config model, also save as default + if result.data.get("set_as_default"): + settings.set_default_model(model["id"]) + print_success(f"Default model set to: {model['id']}") + continue + + # Load conversation + elif "load_conversation" in result.data: + history = result.data.get("history", []) + session.history.clear() + from oai.core.session import HistoryEntry + for entry in history: + session.history.append(HistoryEntry( + prompt=entry.get("prompt", ""), + response=entry.get("response", ""), + prompt_tokens=entry.get("prompt_tokens", 0), + completion_tokens=entry.get("completion_tokens", 0), + msg_cost=entry.get("msg_cost", 0.0), + )) + session.current_index = len(session.history) - 1 + continue + + else: + # Normal command completed + continue + else: + # Command completed with no special data + continue + + # Ensure model is selected + if not session.selected_model: + print_warning("Please select a model first with /model") + continue + + # Send message + stream = settings.stream_enabled + if mcp_manager and mcp_manager.enabled: + tools = session.get_mcp_tools() + if tools: + stream = False # Disable streaming with tools + + if stream: + console.print( + "[bold green]Streaming response...[/] " + "[dim](Press Ctrl+C to cancel)[/]" + ) + if session.online_enabled: + console.print("[dim cyan]๐ŸŒ Online mode active[/]") + console.print("") + + try: + response_text, usage, response_time = session.send_message( + user_input, + stream=stream, + ) + except Exception as e: + print_error(f"Error: {e}") + logger.error(f"Message error: {e}") + continue + + if not response_text: + print_error("No response received") + continue + + # Display non-streaming response + if not stream: + console.print() + display_panel( + Markdown(response_text), + title="[bold green]AI Response[/]", + border_style="green", + ) + + # Calculate cost and tokens + cost = 0.0 + tokens = 0 + estimated = False + + if usage and (usage.prompt_tokens > 0 or usage.completion_tokens > 0): + tokens = usage.total_tokens + if usage.total_cost_usd: + cost = usage.total_cost_usd + else: + cost = session.client.estimate_cost( + session.selected_model["id"], + usage.prompt_tokens, + usage.completion_tokens, + ) + else: + # Estimate tokens when usage not available (streaming fallback) + # Rough estimate: ~4 characters per token for English text + est_input_tokens = len(user_input) // 4 + 1 + est_output_tokens = len(response_text) // 4 + 1 + tokens = est_input_tokens + est_output_tokens + cost = session.client.estimate_cost( + session.selected_model["id"], + est_input_tokens, + est_output_tokens, + ) + # Create estimated usage for session tracking + usage = UsageStats( + prompt_tokens=est_input_tokens, + completion_tokens=est_output_tokens, + total_tokens=tokens, + ) + estimated = True + + # Add to history + session.add_to_history(user_input, response_text, usage, cost) + + # Display metrics + est_marker = "~" if estimated else "" + context_info = "" + if session.memory_enabled: + context_count = len(session.history) - session.memory_start_index + if context_count > 1: + context_info = f", Context: {context_count} msg(s)" + else: + context_info = ", Memory: OFF" + + online_emoji = " ๐ŸŒ" if session.online_enabled else "" + mcp_emoji = "" + if mcp_manager and mcp_manager.enabled: + if mcp_manager.mode == "files": + mcp_emoji = " ๐Ÿ”ง" + elif mcp_manager.mode == "database": + mcp_emoji = " ๐Ÿ—„๏ธ" + + console.print( + f"\n[dim blue]๐Ÿ“Š {est_marker}{tokens} tokens | {est_marker}${cost:.4f} | {response_time:.2f}s" + f"{context_info}{online_emoji}{mcp_emoji} | " + f"Session: {est_marker}{session.stats.total_tokens:,} tokens | " + f"{est_marker}${session.stats.total_cost:.4f}[/]" + ) + + # Check warnings + warnings = session.check_warnings() + for warning in warnings: + print_warning(warning) + + # Offer to copy + console.print("") + try: + from oai.ui.prompts import prompt_copy_response + prompt_copy_response(response_text) + except Exception: + pass + console.print("") + + except KeyboardInterrupt: + console.print("\n[dim]Input cancelled[/]") + continue + except EOFError: + console.print("\n[bold yellow]Goodbye![/]") + return + + +@app.command() +def chat( + model: Optional[str] = typer.Option( + None, + "--model", + "-m", + help="Model ID to use", + ), + system: Optional[str] = typer.Option( + None, + "--system", + "-s", + help="System prompt", + ), + online: bool = typer.Option( + False, + "--online", + "-o", + help="Enable online mode", + ), + mcp: bool = typer.Option( + False, + "--mcp", + help="Enable MCP server", + ), +) -> None: + """Start an interactive chat session.""" + # Setup logging + logging_manager = LoggingManager() + logging_manager.setup() + logger = get_logger() + + # Clear screen + clear_screen() + + # Load settings + settings = Settings.load() + + # Check API key + if not settings.api_key: + print_error("No API key configured") + print_info("Run: oai --config api to set your API key") + raise typer.Exit(1) + + # Initialize client + try: + client = AIClient( + api_key=settings.api_key, + base_url=settings.base_url, + ) + except Exception as e: + print_error(f"Failed to initialize client: {e}") + raise typer.Exit(1) + + # Register commands + register_all_commands() + + # Check for updates and show welcome + version_info = check_for_updates(APP_VERSION) + show_welcome(settings, version_info) + + # Initialize MCP manager + mcp_manager = MCPManager() + if mcp: + result = mcp_manager.enable() + if result["success"]: + print_success("MCP enabled") + else: + print_warning(f"MCP: {result.get('error', 'Failed to enable')}") + + # Create session + session = ChatSession( + client=client, + settings=settings, + mcp_manager=mcp_manager, + ) + + # Set system prompt + if system: + session.system_prompt = system + print_info(f"System prompt: {system}") + + # Set online mode + if online: + session.online_enabled = True + print_info("Online mode enabled") + + # Select model + if model: + raw_model = client.get_raw_model(model) + if raw_model: + session.set_model(raw_model) + else: + print_warning(f"Model '{model}' not found") + elif settings.default_model: + raw_model = client.get_raw_model(settings.default_model) + if raw_model: + session.set_model(raw_model) + else: + print_warning(f"Default model '{settings.default_model}' not available") + + # Setup prompt session + HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True) + prompt_session = PromptSession( + history=FileHistory(str(HISTORY_FILE)), + ) + + # Run chat loop + run_chat_loop(session, prompt_session, settings) + + +@app.command() +def config( + setting: Optional[str] = typer.Argument( + None, + help="Setting to configure (api, url, model, system, stream, costwarning, maxtoken, online, log, loglevel)", + ), + value: Optional[str] = typer.Argument( + None, + help="Value to set", + ), +) -> None: + """View or modify configuration settings.""" + settings = Settings.load() + + if not setting: + # Show all settings + from rich.table import Table + from oai.constants import DEFAULT_SYSTEM_PROMPT + + table = Table("Setting", "Value", show_header=True, header_style="bold magenta") + table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set") + table.add_row("Base URL", settings.base_url) + table.add_row("Default Model", settings.default_model or "Not set") + + # Show system prompt status + if settings.default_system_prompt is None: + system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..." + elif settings.default_system_prompt == "": + system_prompt_display = "[blank]" + else: + system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt + table.add_row("System Prompt", system_prompt_display) + + table.add_row("Streaming", "on" if settings.stream_enabled else "off") + table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}") + table.add_row("Max Tokens", str(settings.max_tokens)) + table.add_row("Default Online", "on" if settings.default_online_mode else "off") + table.add_row("Log Level", settings.log_level) + + display_panel(table, title="[bold green]Configuration[/]") + return + + setting = setting.lower() + + if setting == "api": + if value: + settings.set_api_key(value) + else: + from oai.ui.prompts import prompt_input + new_key = prompt_input("Enter API key", password=True) + if new_key: + settings.set_api_key(new_key) + print_success("API key updated") + + elif setting == "url": + settings.set_base_url(value or "https://openrouter.ai/api/v1") + print_success(f"Base URL set to: {settings.base_url}") + + elif setting == "model": + if value: + settings.set_default_model(value) + print_success(f"Default model set to: {value}") + else: + print_info(f"Current default model: {settings.default_model or 'Not set'}") + + elif setting == "system": + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if value: + # Decode escape sequences like \n for newlines + value = value.encode().decode('unicode_escape') + settings.set_default_system_prompt(value) + if value: + print_success(f"Default system prompt set to: {value}") + else: + print_success("Default system prompt set to blank.") + else: + if settings.default_system_prompt is None: + print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif settings.default_system_prompt == "": + print_info("System prompt: [blank]") + else: + print_info(f"System prompt: {settings.default_system_prompt}") + + elif setting == "stream": + if value and value.lower() in ["on", "off"]: + settings.set_stream_enabled(value.lower() == "on") + print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}") + else: + print_info("Usage: oai config stream [on|off]") + + elif setting == "costwarning": + if value: + try: + threshold = float(value) + settings.set_cost_warning_threshold(threshold) + print_success(f"Cost warning threshold set to: ${threshold:.4f}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current threshold: ${settings.cost_warning_threshold:.4f}") + + elif setting == "maxtoken": + if value: + try: + max_tok = int(value) + settings.set_max_tokens(max_tok) + print_success(f"Max tokens set to: {max_tok}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current max tokens: {settings.max_tokens}") + + elif setting == "online": + if value and value.lower() in ["on", "off"]: + settings.set_default_online_mode(value.lower() == "on") + print_success(f"Default online mode {'enabled' if settings.default_online_mode else 'disabled'}") + else: + print_info("Usage: oai config online [on|off]") + + elif setting == "loglevel": + valid_levels = ["debug", "info", "warning", "error", "critical"] + if value and value.lower() in valid_levels: + settings.set_log_level(value.lower()) + print_success(f"Log level set to: {value.lower()}") + else: + print_info(f"Valid levels: {', '.join(valid_levels)}") + + else: + print_error(f"Unknown setting: {setting}") + + +@app.command() +def version() -> None: + """Show version information.""" + version_info = check_for_updates(APP_VERSION) + console.print(version_info) + + +@app.command() +def credits() -> None: + """Check account credits.""" + settings = Settings.load() + + if not settings.api_key: + print_error("No API key configured") + raise typer.Exit(1) + + client = AIClient(api_key=settings.api_key, base_url=settings.base_url) + credits_data = client.get_credits() + + if not credits_data: + print_error("Failed to fetch credits") + raise typer.Exit(1) + + from rich.table import Table + + table = Table("Metric", "Value", show_header=True, header_style="bold magenta") + table.add_row("Total Credits", credits_data.get("total_credits_formatted", "N/A")) + table.add_row("Used Credits", credits_data.get("used_credits_formatted", "N/A")) + table.add_row("Credits Left", credits_data.get("credits_left_formatted", "N/A")) + + display_panel(table, title="[bold green]Account Credits[/]") + + +def main() -> None: + """Main entry point.""" + # Default to 'chat' command if no arguments provided + if len(sys.argv) == 1: + sys.argv.append("chat") + app() + + +if __name__ == "__main__": + main() diff --git a/oai/commands/__init__.py b/oai/commands/__init__.py new file mode 100644 index 0000000..37068b3 --- /dev/null +++ b/oai/commands/__init__.py @@ -0,0 +1,24 @@ +""" +Command system for oAI. + +This module provides a command registry and handler system +for processing slash commands in the chat interface. +""" + +from oai.commands.registry import ( + Command, + CommandRegistry, + CommandContext, + CommandResult, + registry, +) +from oai.commands.handlers import register_all_commands + +__all__ = [ + "Command", + "CommandRegistry", + "CommandContext", + "CommandResult", + "registry", + "register_all_commands", +] diff --git a/oai/commands/handlers.py b/oai/commands/handlers.py new file mode 100644 index 0000000..49f23a8 --- /dev/null +++ b/oai/commands/handlers.py @@ -0,0 +1,1441 @@ +""" +Command handlers for oAI. + +This module implements all the slash commands available in the chat interface. +Each command is registered with the global registry. +""" + +import json +from typing import Any, Dict, List, Optional + +from oai.commands.registry import ( + Command, + CommandContext, + CommandHelp, + CommandResult, + CommandStatus, + registry, +) +from oai.constants import COMMAND_HELP, VALID_COMMANDS +from oai.ui.console import ( + clear_screen, + console, + display_markdown, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.prompts import prompt_confirm, prompt_input +from oai.ui.tables import ( + create_model_table, + create_stats_table, + display_paginated_table, +) +from oai.utils.export import export_as_html, export_as_json, export_as_markdown +from oai.utils.logging import get_logger + + +class HelpCommand(Command): + """Display help information for commands.""" + + @property + def name(self) -> str: + return "/help" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display help information for commands.", + usage="/help [command|topic]", + examples=[ + ("Show all commands", "/help"), + ("Get help for a specific command", "/help /model"), + ("Get detailed MCP help", "/help mcp"), + ], + notes="Use /help without arguments to see the full command list.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + logger = get_logger() + + if args: + # Show help for specific command/topic + self._show_command_help(args) + else: + # Show general help + self._show_general_help(context) + + logger.info(f"Displayed help for: {args or 'general'}") + return CommandResult.success() + + def _show_command_help(self, command_or_topic: str) -> None: + """Show help for a specific command or topic.""" + # Handle topics (like 'mcp') + if not command_or_topic.startswith("/"): + if command_or_topic.lower() == "mcp": + help_data = COMMAND_HELP.get("mcp", {}) + if help_data: + content = [] + content.append(f"[bold cyan]Description:[/]") + content.append(help_data.get("description", "")) + content.append("") + content.append(help_data.get("notes", "")) + + display_panel( + "\n".join(content), + title="[bold green]MCP - Model Context Protocol Guide[/]", + border_style="green", + ) + return + command_or_topic = "/" + command_or_topic + + help_data = COMMAND_HELP.get(command_or_topic) + if not help_data: + print_error(f"Unknown command: {command_or_topic}") + print_warning("Type /help to see all available commands.") + return + + content = [] + + if help_data.get("aliases"): + content.append(f"[bold cyan]Aliases:[/] {', '.join(help_data['aliases'])}") + content.append("") + + content.append("[bold cyan]Description:[/]") + content.append(help_data.get("description", "")) + content.append("") + + content.append("[bold cyan]Usage:[/]") + content.append(f"[yellow]{help_data.get('usage', '')}[/]") + content.append("") + + examples = help_data.get("examples", []) + if examples: + content.append("[bold cyan]Examples:[/]") + for desc, example in examples: + if not desc and not example: + content.append("") + elif desc.startswith("โ”โ”โ”"): + content.append(f"[bold yellow]{desc}[/]") + else: + if desc: + content.append(f" [dim]{desc}:[/]") + if example: + content.append(f" [green]{example}[/]") + content.append("") + + notes = help_data.get("notes") + if notes: + content.append("[bold cyan]Notes:[/]") + content.append(f"[dim]{notes}[/]") + + display_panel( + "\n".join(content), + title=f"[bold green]Help: {command_or_topic}[/]", + border_style="green", + ) + + def _show_general_help(self, context: CommandContext) -> None: + """Show general help with all commands.""" + from rich.table import Table + + table = Table( + "Command", + "Description", + "Example", + show_header=True, + header_style="bold magenta", + show_lines=False, + ) + + # Group commands by category + categories = [ + ("[bold cyan]โ”โ”โ” CHAT โ”โ”โ”[/]", [ + ("/retry", "Resend the last prompt.", "/retry"), + ("/memory", "Toggle conversation memory.", "/memory on"), + ("/online", "Toggle online mode (web search).", "/online on"), + ("/paste", "Paste from clipboard with optional prompt.", "/paste Explain"), + ]), + ("[bold cyan]โ”โ”โ” NAVIGATION โ”โ”โ”[/]", [ + ("/prev", "View previous response in history.", "/prev"), + ("/next", "View next response in history.", "/next"), + ("/reset", "Clear conversation history.", "/reset"), + ]), + ("[bold cyan]โ”โ”โ” MODEL & CONFIG โ”โ”โ”[/]", [ + ("/model", "Select AI model.", "/model gpt"), + ("/info", "Show model information.", "/info"), + ("/config", "View or change settings.", "/config stream on"), + ("/maxtoken", "Set session token limit.", "/maxtoken 2000"), + ("/system", "Set system prompt.", "/system You are an expert"), + ]), + ("[bold cyan]โ”โ”โ” SAVE & EXPORT โ”โ”โ”[/]", [ + ("/save", "Save conversation.", "/save my_chat"), + ("/load", "Load saved conversation.", "/load my_chat"), + ("/delete", "Delete saved conversation.", "/delete my_chat"), + ("/list", "List saved conversations.", "/list"), + ("/export", "Export conversation.", "/export md notes.md"), + ]), + ("[bold cyan]โ”โ”โ” STATS & INFO โ”โ”โ”[/]", [ + ("/stats", "Show session statistics.", "/stats"), + ("/credits", "Show account credits.", "/credits"), + ("/middleout", "Toggle middle-out compression.", "/middleout on"), + ]), + ("[bold cyan]โ”โ”โ” MCP (FILE/DB ACCESS) โ”โ”โ”[/]", [ + ("/mcp on", "Enable MCP server.", "/mcp on"), + ("/mcp add", "Add folder or database.", "/mcp add ~/Documents"), + ("/mcp status", "Show MCP status.", "/mcp status"), + ("/mcp write", "Toggle write mode.", "/mcp write on"), + ]), + ("[bold cyan]โ”โ”โ” UTILITY โ”โ”โ”[/]", [ + ("/clear", "Clear the screen.", "/clear"), + ("/help", "Show this help.", "/help /model"), + ]), + ("[bold yellow]โ”โ”โ” EXIT โ”โ”โ”[/]", [ + ("exit", "Quit the application.", "exit"), + ]), + ] + + for header, commands in categories: + table.add_row(header, "", "") + for cmd, desc, example in commands: + table.add_row(cmd, desc, example) + + from oai.constants import APP_VERSION + + display_panel( + table, + title=f"[bold cyan]oAI Chat Help (Version {APP_VERSION})[/]", + subtitle="๐Ÿ’ก Use /help for details โ€ข /help mcp for MCP guide", + ) + + +class ClearCommand(Command): + """Clear the terminal screen.""" + + @property + def name(self) -> str: + return "/clear" + + @property + def aliases(self) -> List[str]: + return ["/cl"] + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Clear the terminal screen.", + usage="/clear", + aliases=["/cl"], + notes="You can also use Ctrl+L.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + clear_screen() + return CommandResult.success() + + +class MemoryCommand(Command): + """Toggle conversation memory.""" + + @property + def name(self) -> str: + return "/memory" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Toggle conversation memory.", + usage="/memory [on|off]", + examples=[ + ("Check status", "/memory"), + ("Enable memory", "/memory on"), + ("Disable memory", "/memory off"), + ], + notes="When off, each message is independent (saves tokens).", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + status = "enabled" if context.memory_enabled else "disabled" + print_info(f"Conversation memory is {status}.") + return CommandResult.success() + + if args.lower() == "on": + context.memory_enabled = True + print_success("Memory enabled - AI will remember conversation.") + return CommandResult.success(data={"memory_enabled": True}) + elif args.lower() == "off": + context.memory_enabled = False + context.memory_start_index = len(context.session_history) + print_success("Memory disabled - each message is independent.") + return CommandResult.success(data={"memory_enabled": False}) + else: + print_error("Usage: /memory [on|off]") + return CommandResult.error("Invalid argument") + + +class OnlineCommand(Command): + """Toggle online mode (web search).""" + + @property + def name(self) -> str: + return "/online" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Enable or disable online mode (web search).", + usage="/online [on|off]", + examples=[ + ("Check status", "/online"), + ("Enable web search", "/online on"), + ("Disable web search", "/online off"), + ], + notes="Not all models support online mode.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + status = "enabled" if context.online_enabled else "disabled" + print_info(f"Online mode is {status}.") + return CommandResult.success() + + if args.lower() == "on": + if context.selected_model_raw: + params = context.selected_model_raw.get("supported_parameters", []) + if "tools" not in params: + print_warning("Current model may not support online mode.") + + context.online_enabled = True + print_success("Online mode enabled - AI can search the web.") + return CommandResult.success(data={"online_enabled": True}) + elif args.lower() == "off": + context.online_enabled = False + print_success("Online mode disabled.") + return CommandResult.success(data={"online_enabled": False}) + else: + print_error("Usage: /online [on|off]") + return CommandResult.error("Invalid argument") + + +class ResetCommand(Command): + """Reset conversation history.""" + + @property + def name(self) -> str: + return "/reset" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Clear conversation history and reset system prompt.", + usage="/reset", + notes="Requires confirmation. Resets all session metrics.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not prompt_confirm("Reset conversation and clear history?"): + print_info("Reset cancelled.") + return CommandResult.success() + + context.session_history.clear() + context.session_system_prompt = "" + context.memory_start_index = 0 + context.current_index = 0 + context.total_input_tokens = 0 + context.total_output_tokens = 0 + context.total_cost = 0.0 + context.message_count = 0 + + print_success("Conversation reset. History and system prompt cleared.") + get_logger().info("Conversation reset by user") + return CommandResult.success() + + +class StatsCommand(Command): + """Display session statistics.""" + + @property + def name(self) -> str: + return "/stats" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display session statistics.", + usage="/stats", + notes="Shows tokens, costs, and credits.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from rich.table import Table + + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + table.add_row("Input Tokens", f"{context.total_input_tokens:,}") + table.add_row("Output Tokens", f"{context.total_output_tokens:,}") + table.add_row( + "Total Tokens", + f"{context.total_input_tokens + context.total_output_tokens:,}", + ) + table.add_row("Total Cost", f"${context.total_cost:.4f}") + + if context.message_count > 0: + avg_cost = context.total_cost / context.message_count + table.add_row("Avg Cost/Message", f"${avg_cost:.4f}") + + table.add_row("Messages", str(context.message_count)) + + # Get credits if provider available + if context.provider: + credits = context.provider.get_credits() + if credits: + table.add_row("", "") + table.add_row("[bold]Credits[/]", "") + table.add_row( + "Credits Left", + credits.get("credits_left_formatted", "N/A"), + ) + + display_panel(table, title="[bold green]Session Statistics[/]") + return CommandResult.success() + + +class CreditsCommand(Command): + """Display account credits.""" + + @property + def name(self) -> str: + return "/credits" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display your OpenRouter account credits.", + usage="/credits", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.provider: + print_error("No provider configured.") + return CommandResult.error("No provider") + + credits = context.provider.get_credits() + if not credits: + print_error("Failed to fetch credits.") + return CommandResult.error("Failed to fetch credits") + + from rich.table import Table + + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + table.add_row("Total Credits", credits.get("total_credits_formatted", "N/A")) + table.add_row("Used Credits", credits.get("used_credits_formatted", "N/A")) + table.add_row("Credits Left", credits.get("credits_left_formatted", "N/A")) + + # Check for warnings + from oai.constants import LOW_CREDIT_AMOUNT, LOW_CREDIT_RATIO + + credits_left = credits.get("credits_left", 0) + total = credits.get("total_credits", 0) + + warnings = [] + if credits_left < LOW_CREDIT_AMOUNT: + warnings.append(f"Less than ${LOW_CREDIT_AMOUNT:.2f} remaining!") + elif total > 0 and credits_left < total * LOW_CREDIT_RATIO: + warnings.append("Less than 10% of credits remaining!") + + display_panel(table, title="[bold green]Account Credits[/]") + + for warning in warnings: + print_warning(warning) + + return CommandResult.success(data=credits) + + +class ExportCommand(Command): + """Export conversation to file.""" + + @property + def name(self) -> str: + return "/export" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Export conversation to a file.", + usage="/export ", + examples=[ + ("Export as Markdown", "/export md notes.md"), + ("Export as JSON", "/export json conversation.json"), + ("Export as HTML", "/export html report.html"), + ], + notes="Available formats: md, json, html", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + parts = args.split(maxsplit=1) + if len(parts) != 2: + print_error("Usage: /export ") + return CommandResult.error("Invalid arguments") + + fmt, filename = parts + fmt = fmt.lower() + + if fmt not in ["md", "json", "html"]: + print_error(f"Unknown format: {fmt}") + print_info("Available formats: md, json, html") + return CommandResult.error("Invalid format") + + if not context.session_history: + print_warning("No conversation to export.") + return CommandResult.error("Empty history") + + try: + if fmt == "md": + content = export_as_markdown( + context.session_history, + context.session_system_prompt, + ) + elif fmt == "json": + content = export_as_json( + context.session_history, + context.session_system_prompt, + ) + else: # html + content = export_as_html( + context.session_history, + context.session_system_prompt, + ) + + with open(filename, "w", encoding="utf-8") as f: + f.write(content) + + print_success(f"Exported to {filename}") + get_logger().info(f"Exported conversation to {filename} ({fmt})") + return CommandResult.success() + + except Exception as e: + print_error(f"Export failed: {e}") + return CommandResult.error(str(e)) + + +class MiddleOutCommand(Command): + """Toggle middle-out compression.""" + + @property + def name(self) -> str: + return "/middleout" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Toggle middle-out transform for long prompts.", + usage="/middleout [on|off]", + examples=[ + ("Check status", "/middleout"), + ("Enable", "/middleout on"), + ("Disable", "/middleout off"), + ], + notes="Compresses prompts exceeding context size.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + status = "enabled" if context.middle_out_enabled else "disabled" + print_info(f"Middle-out compression is {status}.") + return CommandResult.success() + + if args.lower() == "on": + context.middle_out_enabled = True + print_success("Middle-out compression enabled.") + return CommandResult.success(data={"middle_out_enabled": True}) + elif args.lower() == "off": + context.middle_out_enabled = False + print_success("Middle-out compression disabled.") + return CommandResult.success(data={"middle_out_enabled": False}) + else: + print_error("Usage: /middleout [on|off]") + return CommandResult.error("Invalid argument") + + +class MaxTokenCommand(Command): + """Set session token limit.""" + + @property + def name(self) -> str: + return "/maxtoken" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Set session token limit.", + usage="/maxtoken [value]", + examples=[ + ("View current", "/maxtoken"), + ("Set to 2000", "/maxtoken 2000"), + ], + notes="Cannot exceed stored max token setting.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + if context.session_max_token > 0: + print_info(f"Session max token: {context.session_max_token}") + else: + print_info("No session token limit set (using model default).") + return CommandResult.success() + + try: + value = int(args) + if value <= 0: + print_error("Token limit must be positive.") + return CommandResult.error("Invalid value") + + # Check against stored limit + stored_max = 100000 # Default + if context.settings: + stored_max = context.settings.max_tokens + + if value > stored_max: + print_warning( + f"Value {value} exceeds stored limit {stored_max}. " + f"Using {stored_max}." + ) + value = stored_max + + context.session_max_token = value + print_success(f"Session max token set to {value}.") + return CommandResult.success(data={"session_max_token": value}) + + except ValueError: + print_error("Please enter a valid number.") + return CommandResult.error("Invalid number") + + +class SystemCommand(Command): + """Set session system prompt.""" + + @property + def name(self) -> str: + return "/system" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Set or clear the session system prompt.", + usage="/system [prompt|clear|default ]", + examples=[ + ("View current", "/system"), + ("Set prompt", "/system You are a Python expert"), + ("Multiline prompt", r"/system You are an expert.\nRespond clearly."), + ("Blank prompt", '/system ""'), + ("Save as default", "/system default You are a Python expert"), + ("Revert to default", "/system clear"), + ], + notes=r'Use \n for newlines. Use /system "" for blank, /system clear to revert to default.', + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if not args: + # Show current session prompt and default + if context.session_system_prompt: + print_info(f"Session prompt: {context.session_system_prompt}") + else: + print_info("Session prompt: [blank]") + + if context.settings: + if context.settings.default_system_prompt is None: + print_info(f"Default prompt: [hardcoded] {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif context.settings.default_system_prompt == "": + print_info("Default prompt: [blank]") + else: + print_info(f"Custom default: {context.settings.default_system_prompt}") + return CommandResult.success() + + if args.lower() == "clear": + # Revert to hardcoded default + if context.settings: + context.settings.clear_default_system_prompt() + context.session_system_prompt = DEFAULT_SYSTEM_PROMPT + print_success("Reverted to hardcoded default system prompt.") + print_info(f"Default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + else: + context.session_system_prompt = DEFAULT_SYSTEM_PROMPT + print_success("Session prompt reverted to default.") + return CommandResult.success() + + # Check for default command + if args.lower().startswith("default "): + prompt = args[8:] # Remove "default " prefix (keep trailing spaces) + if not prompt: + print_error("Usage: /system default ") + return CommandResult.error("No prompt provided") + + # Decode escape sequences like \n for newlines + prompt = prompt.encode().decode('unicode_escape') + + if context.settings: + context.settings.set_default_system_prompt(prompt) + context.session_system_prompt = prompt + if prompt: + print_success(f"Default system prompt saved: {prompt}") + else: + print_success("Default system prompt set to blank.") + get_logger().info(f"Default system prompt updated: {prompt[:50]}...") + else: + print_error("Settings not available") + return CommandResult.error("No settings") + return CommandResult.success() + + # Decode escape sequences like \n for newlines + prompt = args.encode().decode('unicode_escape') + + context.session_system_prompt = prompt + if prompt: + print_success(f"Session system prompt set: {prompt}") + print_info("Use '/system default ' to save as default for all sessions") + else: + print_success("Session system prompt set to blank.") + get_logger().info(f"System prompt updated: {prompt[:50]}...") + return CommandResult.success() + + +class RetryCommand(Command): + """Resend the last prompt.""" + + @property + def name(self) -> str: + return "/retry" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Resend the last prompt.", + usage="/retry", + notes="Requires at least one message in history.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No message to retry.") + return CommandResult.error("Empty history") + + last_prompt = context.session_history[-1].get("prompt", "") + if not last_prompt: + print_error("Last message has no prompt.") + return CommandResult.error("No prompt") + + # Return the prompt to be re-sent + print_info(f"Retrying: {last_prompt[:50]}...") + return CommandResult.success(data={"retry_prompt": last_prompt}) + + +class PrevCommand(Command): + """View previous response.""" + + @property + def name(self) -> str: + return "/prev" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View previous response in history.", + usage="/prev", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No history to navigate.") + return CommandResult.error("Empty history") + + if context.current_index <= 0: + print_info("Already at the beginning of history.") + return CommandResult.success() + + context.current_index -= 1 + entry = context.session_history[context.current_index] + + console.print(f"\n[dim]Message {context.current_index + 1}/{len(context.session_history)}[/]") + console.print(f"\n[bold cyan]Prompt:[/] {entry.get('prompt', '')}") + display_markdown(entry.get("response", ""), panel=True, title="Response") + + return CommandResult.success() + + +class NextCommand(Command): + """View next response.""" + + @property + def name(self) -> str: + return "/next" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View next response in history.", + usage="/next", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No history to navigate.") + return CommandResult.error("Empty history") + + if context.current_index >= len(context.session_history) - 1: + print_info("Already at the end of history.") + return CommandResult.success() + + context.current_index += 1 + entry = context.session_history[context.current_index] + + console.print(f"\n[dim]Message {context.current_index + 1}/{len(context.session_history)}[/]") + console.print(f"\n[bold cyan]Prompt:[/] {entry.get('prompt', '')}") + display_markdown(entry.get("response", ""), panel=True, title="Response") + + return CommandResult.success() + + +class ConfigCommand(Command): + """View or modify configuration.""" + + @property + def name(self) -> str: + return "/config" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View or modify application configuration.", + usage="/config [setting] [value]", + examples=[ + ("View all settings", "/config"), + ("Set API key", "/config api"), + ("Enable streaming", "/config stream on"), + ], + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from rich.table import Table + + if not context.settings: + print_error("Settings not available") + return CommandResult.error("No settings") + + settings = context.settings + + if not args: + # Show all settings + from oai.constants import DEFAULT_SYSTEM_PROMPT + + table = Table("Setting", "Value", show_header=True, header_style="bold magenta") + table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set") + table.add_row("Base URL", settings.base_url) + table.add_row("Default Model", settings.default_model or "Not set") + + # Show system prompt status + if settings.default_system_prompt is None: + system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..." + elif settings.default_system_prompt == "": + system_prompt_display = "[blank]" + else: + system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt + table.add_row("System Prompt", system_prompt_display) + + table.add_row("Streaming", "on" if settings.stream_enabled else "off") + table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}") + table.add_row("Max Tokens", str(settings.max_tokens)) + table.add_row("Default Online", "on" if settings.default_online_mode else "off") + table.add_row("Log Level", settings.log_level) + + display_panel(table, title="[bold green]Configuration[/]") + return CommandResult.success() + + parts = args.split(maxsplit=1) + setting = parts[0].lower() + value = parts[1] if len(parts) > 1 else None + + if setting == "api": + if value: + settings.set_api_key(value) + else: + new_key = prompt_input("Enter API key", password=True) + if new_key: + settings.set_api_key(new_key) + print_success("API key updated") + + elif setting == "stream": + if value and value.lower() in ["on", "off"]: + settings.set_stream_enabled(value.lower() == "on") + print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}") + else: + print_info("Usage: /config stream [on|off]") + + elif setting == "model": + if value: + # Show model selector with search term, same as /model + return CommandResult.success(data={"show_model_selector": True, "search": value, "set_as_default": True}) + else: + print_info(f"Current: {settings.default_model or 'Not set'}") + + elif setting == "system": + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if value: + # Decode escape sequences like \n for newlines + value = value.encode().decode('unicode_escape') + settings.set_default_system_prompt(value) + if value: + print_success(f"Default system prompt set to: {value}") + else: + print_success("Default system prompt set to blank.") + else: + if settings.default_system_prompt is None: + print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif settings.default_system_prompt == "": + print_info("System prompt: [blank]") + else: + print_info(f"System prompt: {settings.default_system_prompt}") + + elif setting == "online": + if value and value.lower() in ["on", "off"]: + settings.set_default_online_mode(value.lower() == "on") + print_success(f"Default online {'enabled' if settings.default_online_mode else 'disabled'}") + else: + print_info("Usage: /config online [on|off]") + + elif setting == "costwarning": + if value: + try: + settings.set_cost_warning_threshold(float(value)) + print_success(f"Cost warning set to: ${float(value):.4f}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current: ${settings.cost_warning_threshold:.4f}") + + elif setting == "maxtoken": + if value: + try: + settings.set_max_tokens(int(value)) + print_success(f"Max tokens set to: {value}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current: {settings.max_tokens}") + + elif setting == "loglevel": + valid_levels = ["debug", "info", "warning", "error", "critical"] + if value and value.lower() in valid_levels: + settings.set_log_level(value.lower()) + print_success(f"Log level set to: {value.lower()}") + else: + print_info(f"Valid levels: {', '.join(valid_levels)}") + + else: + print_error(f"Unknown setting: {setting}") + return CommandResult.error("Unknown setting") + + return CommandResult.success() + + +class ListCommand(Command): + """List saved conversations.""" + + @property + def name(self) -> str: + return "/list" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="List all saved conversations.", + usage="/list", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from oai.config.database import Database + from rich.table import Table + + db = Database() + conversations = db.list_conversations() + + if not conversations: + print_info("No saved conversations.") + return CommandResult.success() + + table = Table("No.", "Name", "Messages", "Last Saved", show_header=True, header_style="bold magenta") + + for i, conv in enumerate(conversations, 1): + table.add_row( + str(i), + conv["name"], + str(conv["message_count"]), + conv["timestamp"][:19] if conv.get("timestamp") else "-", + ) + + display_panel(table, title="[bold green]Saved Conversations[/]") + return CommandResult.success(data={"conversations": conversations}) + + +class SaveCommand(Command): + """Save conversation.""" + + @property + def name(self) -> str: + return "/save" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Save the current conversation.", + usage="/save ", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + print_error("Usage: /save ") + return CommandResult.error("Missing name") + + if not context.session_history: + print_warning("No conversation to save.") + return CommandResult.error("Empty history") + + from oai.config.database import Database + + db = Database() + db.save_conversation(args, context.session_history) + print_success(f"Conversation saved as '{args}'") + get_logger().info(f"Saved conversation: {args}") + return CommandResult.success() + + +class LoadCommand(Command): + """Load saved conversation.""" + + @property + def name(self) -> str: + return "/load" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Load a saved conversation.", + usage="/load ", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + print_error("Usage: /load ") + return CommandResult.error("Missing name") + + from oai.config.database import Database + + db = Database() + + # Check if it's a number + name = args + if args.isdigit(): + conversations = db.list_conversations() + index = int(args) - 1 + if 0 <= index < len(conversations): + name = conversations[index]["name"] + else: + print_error(f"Invalid number. Use 1-{len(conversations)}") + return CommandResult.error("Invalid number") + + data = db.load_conversation(name) + if not data: + print_error(f"Conversation '{name}' not found") + return CommandResult.error("Not found") + + print_success(f"Loaded conversation '{name}' ({len(data)} messages)") + get_logger().info(f"Loaded conversation: {name}") + return CommandResult.success(data={"load_conversation": name, "history": data}) + + +class DeleteCommand(Command): + """Delete saved conversation.""" + + @property + def name(self) -> str: + return "/delete" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Delete a saved conversation.", + usage="/delete ", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + print_error("Usage: /delete ") + return CommandResult.error("Missing name") + + from oai.config.database import Database + + db = Database() + + # Check if it's a number + name = args + if args.isdigit(): + conversations = db.list_conversations() + index = int(args) - 1 + if 0 <= index < len(conversations): + name = conversations[index]["name"] + else: + print_error(f"Invalid number. Use 1-{len(conversations)}") + return CommandResult.error("Invalid number") + + if not prompt_confirm(f"Delete conversation '{name}'?"): + print_info("Cancelled") + return CommandResult.success() + + count = db.delete_conversation(name) + if count > 0: + print_success(f"Deleted conversation '{name}'") + get_logger().info(f"Deleted conversation: {name}") + else: + print_error(f"Conversation '{name}' not found") + return CommandResult.error("Not found") + + return CommandResult.success() + + +class InfoCommand(Command): + """Show model information.""" + + @property + def name(self) -> str: + return "/info" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display detailed model information.", + usage="/info [model_id]", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from rich.table import Table + + if not context.provider: + print_error("No provider available") + return CommandResult.error("No provider") + + model_id = args.strip() if args else None + + if not model_id and context.selected_model_raw: + model_id = context.selected_model_raw.get("id") + + if not model_id: + print_error("No model specified. Use /info or select a model first.") + return CommandResult.error("No model") + + # Get raw model data + model = None + if hasattr(context.provider, "get_raw_model"): + model = context.provider.get_raw_model(model_id) + + if not model: + print_error(f"Model '{model_id}' not found") + return CommandResult.error("Not found") + + table = Table("Property", "Value", show_header=True, header_style="bold magenta") + table.add_row("ID", model.get("id", "")) + table.add_row("Name", model.get("name", "")) + table.add_row("Context Length", f"{model.get('context_length', 0):,}") + + # Pricing + pricing = model.get("pricing", {}) + if pricing: + prompt_price = float(pricing.get("prompt", 0)) * 1_000_000 + completion_price = float(pricing.get("completion", 0)) * 1_000_000 + table.add_row("Input Price", f"${prompt_price:.2f}/M tokens") + table.add_row("Output Price", f"${completion_price:.2f}/M tokens") + + # Capabilities + arch = model.get("architecture", {}) + input_mod = arch.get("input_modalities", []) + output_mod = arch.get("output_modalities", []) + supported = model.get("supported_parameters", []) + + table.add_row("Input Modalities", ", ".join(input_mod) if input_mod else "text") + table.add_row("Output Modalities", ", ".join(output_mod) if output_mod else "text") + table.add_row("Image Support", "โœ“" if "image" in input_mod else "โœ—") + table.add_row("Tool Support", "โœ“" if "tools" in supported else "โœ—") + table.add_row("Online Support", "โœ“" if "tools" in supported else "โœ—") + + display_panel(table, title=f"[bold green]Model: {model_id}[/]") + return CommandResult.success() + + +class MCPCommand(Command): + """MCP management command.""" + + @property + def name(self) -> str: + return "/mcp" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Manage MCP (Model Context Protocol).", + usage="/mcp [args]", + examples=[ + ("Enable MCP", "/mcp on"), + ("Show status", "/mcp status"), + ("Add folder", "/mcp add ~/Documents"), + ], + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.mcp_manager: + print_error("MCP not available") + return CommandResult.error("No MCP") + + mcp = context.mcp_manager + parts = args.strip().split(maxsplit=1) + cmd = parts[0].lower() if parts else "" + cmd_args = parts[1] if len(parts) > 1 else "" + + if cmd in ["on", "enable"]: + result = mcp.enable() + if result["success"]: + print_success(result.get("message", "MCP enabled")) + if result.get("folder_count", 0) == 0: + print_info("Add folders with: /mcp add ~/Documents") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd in ["off", "disable"]: + result = mcp.disable() + if result["success"]: + print_success("MCP disabled") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "status": + result = mcp.get_status() + if result["success"]: + from rich.table import Table + table = Table("Property", "Value", show_header=True, header_style="bold magenta") + + status = "[green]Active โœ“[/]" if result["enabled"] else "[red]Inactive โœ—[/]" + table.add_row("Status", status) + table.add_row("Mode", result.get("mode_info", {}).get("mode_display", "files")) + table.add_row("Folders", str(result.get("folder_count", 0))) + table.add_row("Databases", str(result.get("database_count", 0))) + table.add_row("Write Mode", "[green]Enabled[/]" if result.get("write_enabled") else "[dim]Disabled[/]") + table.add_row(".gitignore", result.get("gitignore_status", "on")) + + display_panel(table, title="[bold green]MCP Status[/]") + return CommandResult.success() + + elif cmd == "add": + if cmd_args.startswith("db "): + db_path = cmd_args[3:].strip() + result = mcp.add_database(db_path) + else: + result = mcp.add_folder(cmd_args) + + if result["success"]: + print_success(f"Added: {cmd_args}") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd in ["remove", "rem"]: + if cmd_args.startswith("db "): + result = mcp.remove_database(cmd_args[3:].strip()) + else: + result = mcp.remove_folder(cmd_args) + + if result["success"]: + print_success(f"Removed: {cmd_args}") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "list": + result = mcp.list_folders() + if result["success"]: + from rich.table import Table + table = Table("No.", "Path", "Files", "Size", show_header=True, header_style="bold magenta") + for f in result.get("folders", []): + table.add_row( + str(f["number"]), + f["path"], + str(f.get("file_count", 0)), + f"{f.get('size_mb', 0):.1f} MB", + ) + display_panel(table, title="[bold green]MCP Folders[/]") + return CommandResult.success() + + elif cmd == "db": + if cmd_args == "list": + result = mcp.list_databases() + if result["success"]: + from rich.table import Table + table = Table("No.", "Name", "Tables", "Size", show_header=True, header_style="bold magenta") + for db in result.get("databases", []): + table.add_row( + str(db["number"]), + db["name"], + str(db.get("table_count", 0)), + f"{db.get('size_mb', 0):.1f} MB", + ) + display_panel(table, title="[bold green]MCP Databases[/]") + elif cmd_args.isdigit(): + result = mcp.switch_mode("database", int(cmd_args)) + if result["success"]: + print_success(result.get("message", "Switched to database mode")) + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "files": + result = mcp.switch_mode("files") + if result["success"]: + print_success("Switched to file mode") + return CommandResult.success() + + elif cmd == "write": + if cmd_args.lower() == "on": + mcp.enable_write() + elif cmd_args.lower() == "off": + mcp.disable_write() + else: + status = "[green]Enabled[/]" if mcp.write_enabled else "[dim]Disabled[/]" + print_info(f"Write mode: {status}") + return CommandResult.success() + + elif cmd == "gitignore": + if cmd_args.lower() in ["on", "off"]: + result = mcp.toggle_gitignore(cmd_args.lower() == "on") + if result["success"]: + print_success(result.get("message", "Updated")) + else: + print_info("Usage: /mcp gitignore [on|off]") + return CommandResult.success() + + else: + print_error(f"Unknown MCP command: {cmd}") + print_info("Commands: on, off, status, add, remove, list, db, files, write, gitignore") + return CommandResult.error("Unknown command") + + +class PasteCommand(Command): + """Paste from clipboard.""" + + @property + def name(self) -> str: + return "/paste" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Paste from clipboard and send to AI.", + usage="/paste [prompt]", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + try: + import pyperclip + content = pyperclip.paste() + except ImportError: + print_error("pyperclip not installed") + return CommandResult.error("pyperclip not installed") + except Exception as e: + print_error(f"Failed to access clipboard: {e}") + return CommandResult.error(str(e)) + + if not content: + print_warning("Clipboard is empty") + return CommandResult.error("Empty clipboard") + + # Show preview + preview = content[:200] + "..." if len(content) > 200 else content + console.print(f"\n[dim]Clipboard content ({len(content)} chars):[/]") + console.print(f"[yellow]{preview}[/]\n") + + # Build the prompt + if args: + full_prompt = f"{args}\n\n```\n{content}\n```" + else: + full_prompt = content + + return CommandResult.success(data={"paste_prompt": full_prompt}) + + +class ModelCommand(Command): + """Select AI model.""" + + @property + def name(self) -> str: + return "/model" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Select or search for AI models.", + usage="/model [search_term]", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + # This is handled specially in the CLI, but we need a handler + # to prevent it from being sent to the AI + return CommandResult.success(data={"show_model_selector": True, "search": args}) + + +def register_all_commands() -> None: + """Register all built-in commands with the global registry.""" + commands = [ + HelpCommand(), + ClearCommand(), + MemoryCommand(), + OnlineCommand(), + ResetCommand(), + StatsCommand(), + CreditsCommand(), + ExportCommand(), + MiddleOutCommand(), + MaxTokenCommand(), + SystemCommand(), + RetryCommand(), + PrevCommand(), + NextCommand(), + ConfigCommand(), + ListCommand(), + SaveCommand(), + LoadCommand(), + DeleteCommand(), + InfoCommand(), + MCPCommand(), + PasteCommand(), + ModelCommand(), + ] + + for command in commands: + try: + registry.register(command) + except ValueError as e: + get_logger().warning(f"Failed to register command: {e}") diff --git a/oai/commands/registry.py b/oai/commands/registry.py new file mode 100644 index 0000000..9be81ad --- /dev/null +++ b/oai/commands/registry.py @@ -0,0 +1,381 @@ +""" +Command registry for oAI. + +This module defines the command system infrastructure including +the Command base class, CommandContext for state, and CommandRegistry +for managing available commands. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING + +from oai.utils.logging import get_logger + +if TYPE_CHECKING: + from oai.config.settings import Settings + from oai.providers.base import AIProvider, ModelInfo + from oai.mcp.manager import MCPManager + + +class CommandStatus(str, Enum): + """Status of command execution.""" + + SUCCESS = "success" + ERROR = "error" + CONTINUE = "continue" # Continue to next handler + EXIT = "exit" # Exit the application + + +@dataclass +class CommandResult: + """ + Result of a command execution. + + Attributes: + status: Execution status + message: Optional message to display + data: Optional data payload + should_continue: Whether to continue the main loop + """ + + status: CommandStatus = CommandStatus.SUCCESS + message: Optional[str] = None + data: Optional[Any] = None + should_continue: bool = True + + @classmethod + def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult": + """Create a success result.""" + return cls(status=CommandStatus.SUCCESS, message=message, data=data) + + @classmethod + def error(cls, message: str) -> "CommandResult": + """Create an error result.""" + return cls(status=CommandStatus.ERROR, message=message) + + @classmethod + def exit(cls, message: Optional[str] = None) -> "CommandResult": + """Create an exit result.""" + return cls(status=CommandStatus.EXIT, message=message, should_continue=False) + + +@dataclass +class CommandContext: + """ + Context object providing state to command handlers. + + Contains all the session state needed by commands including + settings, provider, conversation history, and MCP manager. + + Attributes: + settings: Application settings + provider: AI provider instance + mcp_manager: MCP manager instance + selected_model: Currently selected model + session_history: Conversation history + session_system_prompt: Current system prompt + memory_enabled: Whether memory is enabled + online_enabled: Whether online mode is enabled + session_tokens: Session token counts + session_cost: Session cost total + """ + + settings: Optional["Settings"] = None + provider: Optional["AIProvider"] = None + mcp_manager: Optional["MCPManager"] = None + selected_model: Optional["ModelInfo"] = None + selected_model_raw: Optional[Dict[str, Any]] = None + session_history: List[Dict[str, Any]] = field(default_factory=list) + session_system_prompt: str = "" + memory_enabled: bool = True + memory_start_index: int = 0 + online_enabled: bool = False + middle_out_enabled: bool = False + session_max_token: int = 0 + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_cost: float = 0.0 + message_count: int = 0 + current_index: int = 0 + + +@dataclass +class CommandHelp: + """ + Help information for a command. + + Attributes: + description: Brief description + usage: Usage syntax + examples: List of (description, example) tuples + notes: Additional notes + aliases: Command aliases + """ + + description: str + usage: str = "" + examples: List[tuple] = field(default_factory=list) + notes: str = "" + aliases: List[str] = field(default_factory=list) + + +class Command(ABC): + """ + Abstract base class for all commands. + + Commands implement the execute method to handle their logic. + They can also provide help information and aliases. + """ + + @property + @abstractmethod + def name(self) -> str: + """Get the primary command name (e.g., '/help').""" + pass + + @property + def aliases(self) -> List[str]: + """Get command aliases (e.g., ['/h'] for help).""" + return [] + + @property + @abstractmethod + def help(self) -> CommandHelp: + """Get command help information.""" + pass + + @abstractmethod + def execute(self, args: str, context: CommandContext) -> CommandResult: + """ + Execute the command. + + Args: + args: Arguments passed to the command + context: Command execution context + + Returns: + CommandResult indicating success/failure + """ + pass + + def matches(self, input_text: str) -> bool: + """ + Check if this command matches the input. + + Args: + input_text: User input text + + Returns: + True if this command should handle the input + """ + input_lower = input_text.lower() + cmd_word = input_lower.split()[0] if input_lower.split() else "" + + # Check primary name + if cmd_word == self.name.lower(): + return True + + # Check aliases + for alias in self.aliases: + if cmd_word == alias.lower(): + return True + + return False + + def get_args(self, input_text: str) -> str: + """ + Extract arguments from the input text. + + Args: + input_text: Full user input + + Returns: + Arguments portion of the input + """ + parts = input_text.split(maxsplit=1) + return parts[1] if len(parts) > 1 else "" + + +class CommandRegistry: + """ + Registry for managing available commands. + + Provides registration, lookup, and execution of commands. + """ + + def __init__(self): + """Initialize an empty command registry.""" + self._commands: Dict[str, Command] = {} + self._aliases: Dict[str, str] = {} + self.logger = get_logger() + + def register(self, command: Command) -> None: + """ + Register a command. + + Args: + command: Command instance to register + + Raises: + ValueError: If command name already registered + """ + name = command.name.lower() + + if name in self._commands: + raise ValueError(f"Command '{name}' already registered") + + self._commands[name] = command + + # Register aliases + for alias in command.aliases: + alias_lower = alias.lower() + if alias_lower in self._aliases: + self.logger.warning( + f"Alias '{alias}' already registered, overwriting" + ) + self._aliases[alias_lower] = name + + self.logger.debug(f"Registered command: {name}") + + def register_function( + self, + name: str, + handler: Callable[[str, CommandContext], CommandResult], + description: str, + usage: str = "", + aliases: Optional[List[str]] = None, + examples: Optional[List[tuple]] = None, + notes: str = "", + ) -> None: + """ + Register a function-based command. + + Convenience method for simple commands that don't need + a full Command class. + + Args: + name: Command name (e.g., '/help') + handler: Function to execute + description: Help description + usage: Usage syntax + aliases: Command aliases + examples: Example usages + notes: Additional notes + """ + aliases = aliases or [] + examples = examples or [] + + class FunctionCommand(Command): + @property + def name(self) -> str: + return name + + @property + def aliases(self) -> List[str]: + return aliases + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description=description, + usage=usage, + examples=examples, + notes=notes, + aliases=aliases, + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + return handler(args, context) + + self.register(FunctionCommand()) + + def get(self, name: str) -> Optional[Command]: + """ + Get a command by name or alias. + + Args: + name: Command name or alias + + Returns: + Command instance or None if not found + """ + name_lower = name.lower() + + # Check direct match + if name_lower in self._commands: + return self._commands[name_lower] + + # Check aliases + if name_lower in self._aliases: + return self._commands[self._aliases[name_lower]] + + return None + + def find(self, input_text: str) -> Optional[Command]: + """ + Find a command that matches the input. + + Args: + input_text: User input text + + Returns: + Matching Command or None + """ + cmd_word = input_text.lower().split()[0] if input_text.split() else "" + return self.get(cmd_word) + + def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]: + """ + Execute a command matching the input. + + Args: + input_text: User input text + context: Execution context + + Returns: + CommandResult or None if no matching command + """ + command = self.find(input_text) + if command: + args = command.get_args(input_text) + self.logger.debug(f"Executing command: {command.name} with args: {args}") + return command.execute(args, context) + return None + + def is_command(self, input_text: str) -> bool: + """ + Check if input is a valid command. + + Args: + input_text: User input text + + Returns: + True if input matches a registered command + """ + return self.find(input_text) is not None + + def list_commands(self) -> List[Command]: + """ + Get all registered commands. + + Returns: + List of Command instances + """ + return list(self._commands.values()) + + def get_all_names(self) -> List[str]: + """ + Get all command names and aliases. + + Returns: + List of command names including aliases + """ + names = list(self._commands.keys()) + names.extend(self._aliases.keys()) + return sorted(set(names)) + + +# Global registry instance +registry = CommandRegistry() diff --git a/oai/config/__init__.py b/oai/config/__init__.py new file mode 100644 index 0000000..d12b503 --- /dev/null +++ b/oai/config/__init__.py @@ -0,0 +1,11 @@ +""" +Configuration management for oAI. + +This package handles all configuration persistence, settings management, +and database operations for the application. +""" + +from oai.config.settings import Settings +from oai.config.database import Database + +__all__ = ["Settings", "Database"] diff --git a/oai/config/database.py b/oai/config/database.py new file mode 100644 index 0000000..a6a3496 --- /dev/null +++ b/oai/config/database.py @@ -0,0 +1,472 @@ +""" +Database persistence layer for oAI. + +This module provides a clean abstraction for SQLite operations including +configuration storage, conversation persistence, and MCP statistics tracking. +All database operations are centralized here for maintainability. +""" + +import sqlite3 +import json +import datetime +from pathlib import Path +from typing import Optional, List, Dict, Any +from contextlib import contextmanager + +from oai.constants import DATABASE_FILE, CONFIG_DIR + + +class Database: + """ + SQLite database manager for oAI. + + Handles all database operations including: + - Configuration key-value storage + - Conversation session persistence + - MCP configuration and statistics + - Database registrations for MCP + + Uses context managers for safe connection handling and supports + automatic table creation on first use. + """ + + def __init__(self, db_path: Optional[Path] = None): + """ + Initialize the database manager. + + Args: + db_path: Optional custom database path. Defaults to standard location. + """ + self.db_path = db_path or DATABASE_FILE + self._ensure_directories() + self._ensure_tables() + + def _ensure_directories(self) -> None: + """Ensure the configuration directory exists.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + @contextmanager + def _connection(self): + """ + Context manager for database connections. + + Yields: + sqlite3.Connection: Active database connection + + Example: + with self._connection() as conn: + conn.execute("SELECT * FROM config") + """ + conn = sqlite3.connect(str(self.db_path)) + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + def _ensure_tables(self) -> None: + """Create all required tables if they don't exist.""" + with self._connection() as conn: + # Main configuration table + conn.execute(""" + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + """) + + # Conversation sessions table + conn.execute(""" + CREATE TABLE IF NOT EXISTS conversation_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + timestamp TEXT NOT NULL, + data TEXT NOT NULL + ) + """) + + # MCP configuration table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + """) + + # MCP statistics table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + tool_name TEXT NOT NULL, + folder TEXT, + success INTEGER NOT NULL, + error_message TEXT + ) + """) + + # MCP databases table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_databases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + path TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, + size INTEGER, + tables TEXT, + added_timestamp TEXT NOT NULL + ) + """) + + # ========================================================================= + # CONFIGURATION METHODS + # ========================================================================= + + def get_config(self, key: str) -> Optional[str]: + """ + Retrieve a configuration value by key. + + Args: + key: The configuration key to retrieve + + Returns: + The configuration value, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + "SELECT value FROM config WHERE key = ?", + (key,) + ) + result = cursor.fetchone() + return result[0] if result else None + + def set_config(self, key: str, value: str) -> None: + """ + Set a configuration value. + + Args: + key: The configuration key + value: The value to store + """ + with self._connection() as conn: + conn.execute( + "INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", + (key, value) + ) + + def delete_config(self, key: str) -> bool: + """ + Delete a configuration value. + + Args: + key: The configuration key to delete + + Returns: + True if a row was deleted, False otherwise + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM config WHERE key = ?", + (key,) + ) + return cursor.rowcount > 0 + + def get_all_config(self) -> Dict[str, str]: + """ + Retrieve all configuration values. + + Returns: + Dictionary of all key-value pairs + """ + with self._connection() as conn: + cursor = conn.execute("SELECT key, value FROM config") + return dict(cursor.fetchall()) + + # ========================================================================= + # MCP CONFIGURATION METHODS + # ========================================================================= + + def get_mcp_config(self, key: str) -> Optional[str]: + """ + Retrieve an MCP configuration value. + + Args: + key: The MCP configuration key + + Returns: + The configuration value, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + "SELECT value FROM mcp_config WHERE key = ?", + (key,) + ) + result = cursor.fetchone() + return result[0] if result else None + + def set_mcp_config(self, key: str, value: str) -> None: + """ + Set an MCP configuration value. + + Args: + key: The MCP configuration key + value: The value to store + """ + with self._connection() as conn: + conn.execute( + "INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)", + (key, value) + ) + + # ========================================================================= + # MCP STATISTICS METHODS + # ========================================================================= + + def log_mcp_stat( + self, + tool_name: str, + folder: Optional[str], + success: bool, + error_message: Optional[str] = None + ) -> None: + """ + Log an MCP tool usage event. + + Args: + tool_name: Name of the MCP tool that was called + folder: The folder path involved (if any) + success: Whether the call succeeded + error_message: Error message if the call failed + """ + timestamp = datetime.datetime.now().isoformat() + with self._connection() as conn: + conn.execute( + """INSERT INTO mcp_stats + (timestamp, tool_name, folder, success, error_message) + VALUES (?, ?, ?, ?, ?)""", + (timestamp, tool_name, folder, 1 if success else 0, error_message) + ) + + def get_mcp_stats(self) -> Dict[str, Any]: + """ + Get aggregated MCP usage statistics. + + Returns: + Dictionary containing usage statistics: + - total_calls: Total number of tool calls + - reads: Number of file reads + - lists: Number of directory listings + - searches: Number of file searches + - db_inspects: Number of database inspections + - db_searches: Number of database searches + - db_queries: Number of database queries + - last_used: Timestamp of last usage + """ + with self._connection() as conn: + cursor = conn.execute(""" + SELECT + COUNT(*) as total_calls, + SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads, + SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists, + SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches, + SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects, + SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches, + SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries, + MAX(timestamp) as last_used + FROM mcp_stats + """) + row = cursor.fetchone() + return { + "total_calls": row[0] or 0, + "reads": row[1] or 0, + "lists": row[2] or 0, + "searches": row[3] or 0, + "db_inspects": row[4] or 0, + "db_searches": row[5] or 0, + "db_queries": row[6] or 0, + "last_used": row[7], + } + + # ========================================================================= + # MCP DATABASE REGISTRY METHODS + # ========================================================================= + + def add_mcp_database(self, db_info: Dict[str, Any]) -> int: + """ + Register a database for MCP access. + + Args: + db_info: Dictionary containing: + - path: Database file path + - name: Display name + - size: File size in bytes + - tables: List of table names + - added: Timestamp when added + + Returns: + The database ID + """ + with self._connection() as conn: + conn.execute( + """INSERT INTO mcp_databases + (path, name, size, tables, added_timestamp) + VALUES (?, ?, ?, ?, ?)""", + ( + db_info["path"], + db_info["name"], + db_info["size"], + json.dumps(db_info["tables"]), + db_info["added"] + ) + ) + cursor = conn.execute( + "SELECT id FROM mcp_databases WHERE path = ?", + (db_info["path"],) + ) + return cursor.fetchone()[0] + + def remove_mcp_database(self, db_path: str) -> bool: + """ + Remove a database from the MCP registry. + + Args: + db_path: Path to the database file + + Returns: + True if a row was deleted, False otherwise + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM mcp_databases WHERE path = ?", + (db_path,) + ) + return cursor.rowcount > 0 + + def get_mcp_databases(self) -> List[Dict[str, Any]]: + """ + Retrieve all registered MCP databases. + + Returns: + List of database information dictionaries + """ + with self._connection() as conn: + cursor = conn.execute( + """SELECT id, path, name, size, tables, added_timestamp + FROM mcp_databases ORDER BY id""" + ) + databases = [] + for row in cursor.fetchall(): + tables_list = json.loads(row[4]) if row[4] else [] + databases.append({ + "id": row[0], + "path": row[1], + "name": row[2], + "size": row[3], + "tables": tables_list, + "added": row[5], + }) + return databases + + # ========================================================================= + # CONVERSATION METHODS + # ========================================================================= + + def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None: + """ + Save a conversation session. + + Args: + name: Name/identifier for the conversation + data: List of message dictionaries with 'prompt' and 'response' keys + """ + timestamp = datetime.datetime.now().isoformat() + data_json = json.dumps(data) + with self._connection() as conn: + conn.execute( + """INSERT INTO conversation_sessions + (name, timestamp, data) VALUES (?, ?, ?)""", + (name, timestamp, data_json) + ) + + def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]: + """ + Load a conversation by name. + + Args: + name: Name of the conversation to load + + Returns: + List of messages, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + """SELECT data FROM conversation_sessions + WHERE name = ? + ORDER BY timestamp DESC LIMIT 1""", + (name,) + ) + result = cursor.fetchone() + if result: + return json.loads(result[0]) + return None + + def delete_conversation(self, name: str) -> int: + """ + Delete a conversation by name. + + Args: + name: Name of the conversation to delete + + Returns: + Number of rows deleted + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM conversation_sessions WHERE name = ?", + (name,) + ) + return cursor.rowcount + + def list_conversations(self) -> List[Dict[str, Any]]: + """ + List all saved conversations. + + Returns: + List of conversation summaries with name, timestamp, and message_count + """ + with self._connection() as conn: + cursor = conn.execute(""" + SELECT name, MAX(timestamp) as last_saved, data + FROM conversation_sessions + GROUP BY name + ORDER BY last_saved DESC + """) + conversations = [] + for row in cursor.fetchall(): + name, timestamp, data_json = row + data = json.loads(data_json) + conversations.append({ + "name": name, + "timestamp": timestamp, + "message_count": len(data), + }) + return conversations + + +# Global database instance for convenience +_db: Optional[Database] = None + + +def get_database() -> Database: + """ + Get the global database instance. + + Returns: + The shared Database instance + """ + global _db + if _db is None: + _db = Database() + return _db diff --git a/oai/config/settings.py b/oai/config/settings.py new file mode 100644 index 0000000..907a75a --- /dev/null +++ b/oai/config/settings.py @@ -0,0 +1,361 @@ +""" +Settings management for oAI. + +This module provides a centralized settings class that handles all application +configuration with type safety, validation, and persistence. +""" + +from dataclasses import dataclass, field +from typing import Optional +from pathlib import Path + +from oai.constants import ( + DEFAULT_BASE_URL, + DEFAULT_STREAM_ENABLED, + DEFAULT_MAX_TOKENS, + DEFAULT_ONLINE_MODE, + DEFAULT_COST_WARNING_THRESHOLD, + DEFAULT_LOG_MAX_SIZE_MB, + DEFAULT_LOG_BACKUP_COUNT, + DEFAULT_LOG_LEVEL, + DEFAULT_SYSTEM_PROMPT, + VALID_LOG_LEVELS, +) +from oai.config.database import get_database + + +@dataclass +class Settings: + """ + Application settings with persistence support. + + This class provides a clean interface for managing all configuration + options. Settings are automatically loaded from the database on + initialization and can be persisted back. + + Attributes: + api_key: OpenRouter API key + base_url: API base URL + default_model: Default model ID to use + default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank) + stream_enabled: Whether to stream responses + max_tokens: Maximum tokens per request + cost_warning_threshold: Alert threshold for message cost + default_online_mode: Whether online mode is enabled by default + log_max_size_mb: Maximum log file size in MB + log_backup_count: Number of log file backups to keep + log_level: Logging level (debug/info/warning/error/critical) + """ + + api_key: Optional[str] = None + base_url: str = DEFAULT_BASE_URL + default_model: Optional[str] = None + default_system_prompt: Optional[str] = None + stream_enabled: bool = DEFAULT_STREAM_ENABLED + max_tokens: int = DEFAULT_MAX_TOKENS + cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD + default_online_mode: bool = DEFAULT_ONLINE_MODE + log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB + log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT + log_level: str = DEFAULT_LOG_LEVEL + + @property + def effective_system_prompt(self) -> str: + """ + Get the effective system prompt to use. + + Returns: + The custom prompt if set, hardcoded default if None, or blank if explicitly set to "" + """ + if self.default_system_prompt is None: + return DEFAULT_SYSTEM_PROMPT + return self.default_system_prompt + + def __post_init__(self): + """Validate settings after initialization.""" + self._validate() + + def _validate(self) -> None: + """Validate all settings values.""" + # Validate log level + if self.log_level.lower() not in VALID_LOG_LEVELS: + raise ValueError( + f"Invalid log level: {self.log_level}. " + f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}" + ) + + # Validate numeric bounds + if self.max_tokens < 1: + raise ValueError("max_tokens must be at least 1") + + if self.cost_warning_threshold < 0: + raise ValueError("cost_warning_threshold must be non-negative") + + if self.log_max_size_mb < 1: + raise ValueError("log_max_size_mb must be at least 1") + + if self.log_backup_count < 0: + raise ValueError("log_backup_count must be non-negative") + + @classmethod + def load(cls) -> "Settings": + """ + Load settings from the database. + + Returns: + Settings instance with values from database + """ + db = get_database() + + # Helper to safely parse boolean + def parse_bool(value: Optional[str], default: bool) -> bool: + if value is None: + return default + return value.lower() in ("on", "true", "1", "yes") + + # Helper to safely parse int + def parse_int(value: Optional[str], default: int) -> int: + if value is None: + return default + try: + return int(value) + except ValueError: + return default + + # Helper to safely parse float + def parse_float(value: Optional[str], default: float) -> float: + if value is None: + return default + try: + return float(value) + except ValueError: + return default + + # Get system prompt from DB: None means not set (use default), "" means explicitly blank + system_prompt_value = db.get_config("default_system_prompt") + + return cls( + api_key=db.get_config("api_key"), + base_url=db.get_config("base_url") or DEFAULT_BASE_URL, + default_model=db.get_config("default_model"), + default_system_prompt=system_prompt_value, + stream_enabled=parse_bool( + db.get_config("stream_enabled"), + DEFAULT_STREAM_ENABLED + ), + max_tokens=parse_int( + db.get_config("max_token"), + DEFAULT_MAX_TOKENS + ), + cost_warning_threshold=parse_float( + db.get_config("cost_warning_threshold"), + DEFAULT_COST_WARNING_THRESHOLD + ), + default_online_mode=parse_bool( + db.get_config("default_online_mode"), + DEFAULT_ONLINE_MODE + ), + log_max_size_mb=parse_int( + db.get_config("log_max_size_mb"), + DEFAULT_LOG_MAX_SIZE_MB + ), + log_backup_count=parse_int( + db.get_config("log_backup_count"), + DEFAULT_LOG_BACKUP_COUNT + ), + log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL, + ) + + def save(self) -> None: + """Persist all settings to the database.""" + db = get_database() + + # Only save API key if it exists + if self.api_key: + db.set_config("api_key", self.api_key) + + db.set_config("base_url", self.base_url) + + if self.default_model: + db.set_config("default_model", self.default_model) + + # Save system prompt: None means not set (don't save), otherwise save the value (even if "") + if self.default_system_prompt is not None: + db.set_config("default_system_prompt", self.default_system_prompt) + + db.set_config("stream_enabled", "on" if self.stream_enabled else "off") + db.set_config("max_token", str(self.max_tokens)) + db.set_config("cost_warning_threshold", str(self.cost_warning_threshold)) + db.set_config("default_online_mode", "on" if self.default_online_mode else "off") + db.set_config("log_max_size_mb", str(self.log_max_size_mb)) + db.set_config("log_backup_count", str(self.log_backup_count)) + db.set_config("log_level", self.log_level) + + def set_api_key(self, api_key: str) -> None: + """ + Set and persist the API key. + + Args: + api_key: The new API key + """ + self.api_key = api_key.strip() + get_database().set_config("api_key", self.api_key) + + def set_base_url(self, url: str) -> None: + """ + Set and persist the base URL. + + Args: + url: The new base URL + """ + self.base_url = url.strip() + get_database().set_config("base_url", self.base_url) + + def set_default_model(self, model_id: str) -> None: + """ + Set and persist the default model. + + Args: + model_id: The model ID to set as default + """ + self.default_model = model_id + get_database().set_config("default_model", model_id) + + def set_default_system_prompt(self, prompt: str) -> None: + """ + Set and persist the default system prompt. + + Args: + prompt: The system prompt to use for all new sessions. + Empty string "" means blank prompt (no system message). + """ + self.default_system_prompt = prompt + get_database().set_config("default_system_prompt", prompt) + + def clear_default_system_prompt(self) -> None: + """ + Clear the custom system prompt and revert to hardcoded default. + + This removes the custom prompt from the database, causing the + application to use the built-in DEFAULT_SYSTEM_PROMPT. + """ + self.default_system_prompt = None + # Remove from database to indicate "not set" + db = get_database() + with db._connection() as conn: + conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",)) + conn.commit() + + def set_stream_enabled(self, enabled: bool) -> None: + """ + Set and persist the streaming preference. + + Args: + enabled: Whether to enable streaming + """ + self.stream_enabled = enabled + get_database().set_config("stream_enabled", "on" if enabled else "off") + + def set_max_tokens(self, max_tokens: int) -> None: + """ + Set and persist the maximum tokens. + + Args: + max_tokens: Maximum number of tokens + + Raises: + ValueError: If max_tokens is less than 1 + """ + if max_tokens < 1: + raise ValueError("max_tokens must be at least 1") + self.max_tokens = max_tokens + get_database().set_config("max_token", str(max_tokens)) + + def set_cost_warning_threshold(self, threshold: float) -> None: + """ + Set and persist the cost warning threshold. + + Args: + threshold: Cost threshold in USD + + Raises: + ValueError: If threshold is negative + """ + if threshold < 0: + raise ValueError("cost_warning_threshold must be non-negative") + self.cost_warning_threshold = threshold + get_database().set_config("cost_warning_threshold", str(threshold)) + + def set_default_online_mode(self, enabled: bool) -> None: + """ + Set and persist the default online mode. + + Args: + enabled: Whether online mode should be enabled by default + """ + self.default_online_mode = enabled + get_database().set_config("default_online_mode", "on" if enabled else "off") + + def set_log_level(self, level: str) -> None: + """ + Set and persist the log level. + + Args: + level: The log level (debug/info/warning/error/critical) + + Raises: + ValueError: If level is not valid + """ + level_lower = level.lower() + if level_lower not in VALID_LOG_LEVELS: + raise ValueError( + f"Invalid log level: {level}. " + f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}" + ) + self.log_level = level_lower + get_database().set_config("log_level", level_lower) + + def set_log_max_size(self, size_mb: int) -> None: + """ + Set and persist the maximum log file size. + + Args: + size_mb: Maximum size in megabytes + + Raises: + ValueError: If size_mb is less than 1 + """ + if size_mb < 1: + raise ValueError("log_max_size_mb must be at least 1") + # Cap at 100 MB for safety + self.log_max_size_mb = min(size_mb, 100) + get_database().set_config("log_max_size_mb", str(self.log_max_size_mb)) + + +# Global settings instance +_settings: Optional[Settings] = None + + +def get_settings() -> Settings: + """ + Get the global settings instance. + + Returns: + The shared Settings instance, loading from database if needed + """ + global _settings + if _settings is None: + _settings = Settings.load() + return _settings + + +def reload_settings() -> Settings: + """ + Force reload settings from the database. + + Returns: + Fresh Settings instance + """ + global _settings + _settings = Settings.load() + return _settings diff --git a/oai/constants.py b/oai/constants.py new file mode 100644 index 0000000..a91e330 --- /dev/null +++ b/oai/constants.py @@ -0,0 +1,448 @@ +""" +Application-wide constants for oAI. + +This module contains all configuration constants, default values, and static +definitions used throughout the application. Centralizing these values makes +the codebase easier to maintain and configure. +""" + +from pathlib import Path +from typing import Set, Dict, Any +import logging + +# ============================================================================= +# APPLICATION METADATA +# ============================================================================= + +APP_NAME = "oAI" +APP_VERSION = "2.1.0" +APP_URL = "https://iurl.no/oai" +APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration" + +# ============================================================================= +# FILE PATHS +# ============================================================================= + +HOME_DIR = Path.home() +CONFIG_DIR = HOME_DIR / ".config" / "oai" +CACHE_DIR = HOME_DIR / ".cache" / "oai" +HISTORY_FILE = CONFIG_DIR / "history.txt" +DATABASE_FILE = CONFIG_DIR / "oai_config.db" +LOG_FILE = CONFIG_DIR / "oai.log" + +# ============================================================================= +# API CONFIGURATION +# ============================================================================= + +DEFAULT_BASE_URL = "https://openrouter.ai/api/v1" +DEFAULT_STREAM_ENABLED = True +DEFAULT_MAX_TOKENS = 100_000 +DEFAULT_ONLINE_MODE = False + +# ============================================================================= +# DEFAULT SYSTEM PROMPT +# ============================================================================= + +DEFAULT_SYSTEM_PROMPT = ( + "You are a knowledgeable and helpful AI assistant. Provide clear, accurate, " + "and well-structured responses. Be concise yet thorough. When uncertain about " + "something, acknowledge your limitations. For technical topics, include relevant " + "details and examples when helpful." +) + +# ============================================================================= +# PRICING DEFAULTS (per million tokens) +# ============================================================================= + +DEFAULT_INPUT_PRICE = 3.0 +DEFAULT_OUTPUT_PRICE = 15.0 + +MODEL_PRICING: Dict[str, float] = { + "input": DEFAULT_INPUT_PRICE, + "output": DEFAULT_OUTPUT_PRICE, +} + +# ============================================================================= +# CREDIT ALERTS +# ============================================================================= + +LOW_CREDIT_RATIO = 0.1 # Alert when credits < 10% of total +LOW_CREDIT_AMOUNT = 1.0 # Alert when credits < $1.00 +DEFAULT_COST_WARNING_THRESHOLD = 0.01 # Alert when single message cost exceeds this +COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD # Alias for convenience + +# ============================================================================= +# LOGGING CONFIGURATION +# ============================================================================= + +DEFAULT_LOG_MAX_SIZE_MB = 10 +DEFAULT_LOG_BACKUP_COUNT = 2 +DEFAULT_LOG_LEVEL = "info" + +VALID_LOG_LEVELS: Dict[str, int] = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +# ============================================================================= +# FILE HANDLING +# ============================================================================= + +# Maximum file size for reading (10 MB) +MAX_FILE_SIZE = 10 * 1024 * 1024 + +# Content truncation threshold (50 KB) +CONTENT_TRUNCATION_THRESHOLD = 50 * 1024 + +# Maximum items in directory listing +MAX_LIST_ITEMS = 1000 + +# Supported code file extensions for syntax highlighting +SUPPORTED_CODE_EXTENSIONS: Set[str] = { + ".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp", + ".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go", + ".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart", + ".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt", +} + +# All allowed file extensions for attachment +ALLOWED_FILE_EXTENSIONS: Set[str] = { + # Code files + ".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx", + ".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php", + ".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1", + # Data files + ".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3", + # Documents + ".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties", + # Images + ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico", + # Archives + ".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz", + # Config files + ".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc", + ".prettierrc", ".babelrc", ".nvmrc", ".npmrc", + # Binary/Compiled + ".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app", + ".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa", + # ML/AI + ".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx", + ".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn", + # Data formats + ".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet", + ".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds", + # Other + ".pdf", ".class", ".jar", ".war", +} + +# ============================================================================= +# SECURITY CONFIGURATION +# ============================================================================= + +# System directories that should never be accessed +SYSTEM_DIRS_BLACKLIST: Set[str] = { + # macOS + "/System", "/Library", "/private", "/usr", "/bin", "/sbin", + # Linux + "/boot", "/dev", "/proc", "/sys", "/root", + # Windows + "C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)", +} + +# Directories to skip during file operations +SKIP_DIRECTORIES: Set[str] = { + # Python virtual environments + ".venv", "venv", "env", "virtualenv", + "site-packages", "dist-packages", + # Python caches + "__pycache__", ".pytest_cache", ".mypy_cache", + # JavaScript/Node + "node_modules", + # Version control + ".git", ".svn", + # IDEs + ".idea", ".vscode", + # Build directories + "build", "dist", "eggs", ".eggs", +} + +# ============================================================================= +# DATABASE QUERIES - SQL SAFETY +# ============================================================================= + +# Maximum query execution timeout (seconds) +MAX_QUERY_TIMEOUT = 5 + +# Maximum rows returned from queries +MAX_QUERY_RESULTS = 1000 + +# Default rows per query +DEFAULT_QUERY_LIMIT = 100 + +# Keywords that are blocked in database queries +DANGEROUS_SQL_KEYWORDS: Set[str] = { + "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", + "ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH", + "PRAGMA", "VACUUM", "REINDEX", +} + +# ============================================================================= +# MCP CONFIGURATION +# ============================================================================= + +# Maximum tool call iterations per request +MAX_TOOL_LOOPS = 5 + +# ============================================================================= +# VALID COMMANDS +# ============================================================================= + +VALID_COMMANDS: Set[str] = { + "/retry", "/online", "/memory", "/paste", "/export", "/save", "/load", + "/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset", + "/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear", + "/cl", "/help", "/mcp", +} + +# ============================================================================= +# COMMAND HELP DATABASE +# ============================================================================= + +COMMAND_HELP: Dict[str, Dict[str, Any]] = { + "/clear": { + "aliases": ["/cl"], + "description": "Clear the terminal screen for a clean interface.", + "usage": "/clear\n/cl", + "examples": [ + ("Clear screen", "/clear"), + ("Using short alias", "/cl"), + ], + "notes": "You can also use the keyboard shortcut Ctrl+L.", + }, + "/help": { + "description": "Display help information for commands.", + "usage": "/help [command|topic]", + "examples": [ + ("Show all commands", "/help"), + ("Get help for a specific command", "/help /model"), + ("Get detailed MCP help", "/help mcp"), + ], + "notes": "Use /help without arguments to see the full command list.", + }, + "mcp": { + "description": "Complete guide to MCP (Model Context Protocol).", + "usage": "See detailed examples below", + "examples": [], + "notes": """ +MCP (Model Context Protocol) gives your AI assistant direct access to: + โ€ข Local files and folders (read, search, list) + โ€ข SQLite databases (inspect, search, query) + +FILE MODE (default): + /mcp on Start MCP server + /mcp add ~/Documents Grant access to folder + /mcp list View all allowed folders + +DATABASE MODE: + /mcp add db ~/app/data.db Add specific database + /mcp db list View all databases + /mcp db 1 Work with database #1 + /mcp files Switch back to file mode + +WRITE MODE (optional): + /mcp write on Enable file modifications + /mcp write off Disable write mode (back to read-only) + +For command-specific help: /help /mcp + """, + }, + "/mcp": { + "description": "Manage MCP for local file access and SQLite database querying.", + "usage": "/mcp [args]", + "examples": [ + ("Enable MCP server", "/mcp on"), + ("Disable MCP server", "/mcp off"), + ("Show MCP status", "/mcp status"), + ("", ""), + ("โ”โ”โ” FILE MODE โ”โ”โ”", ""), + ("Add folder for file access", "/mcp add ~/Documents"), + ("Remove folder", "/mcp remove ~/Desktop"), + ("List allowed folders", "/mcp list"), + ("Enable write mode", "/mcp write on"), + ("", ""), + ("โ”โ”โ” DATABASE MODE โ”โ”โ”", ""), + ("Add SQLite database", "/mcp add db ~/app/data.db"), + ("List all databases", "/mcp db list"), + ("Switch to database #1", "/mcp db 1"), + ("Switch back to file mode", "/mcp files"), + ], + "notes": "MCP allows AI to read local files and query SQLite databases.", + }, + "/memory": { + "description": "Toggle conversation memory.", + "usage": "/memory [on|off]", + "examples": [ + ("Check current memory status", "/memory"), + ("Enable conversation memory", "/memory on"), + ("Disable memory (save costs)", "/memory off"), + ], + "notes": "Memory is ON by default. Disabling saves tokens.", + }, + "/online": { + "description": "Enable or disable online mode (web search).", + "usage": "/online [on|off]", + "examples": [ + ("Check online mode status", "/online"), + ("Enable web search", "/online on"), + ("Disable web search", "/online off"), + ], + "notes": "Not all models support online mode.", + }, + "/paste": { + "description": "Paste plain text from clipboard and send to the AI.", + "usage": "/paste [prompt]", + "examples": [ + ("Paste clipboard content", "/paste"), + ("Paste with a question", "/paste Explain this code"), + ], + "notes": "Only plain text is supported.", + }, + "/retry": { + "description": "Resend the last prompt from conversation history.", + "usage": "/retry", + "examples": [("Retry last message", "/retry")], + "notes": "Requires at least one message in history.", + }, + "/next": { + "description": "View the next response in conversation history.", + "usage": "/next", + "examples": [("Navigate to next response", "/next")], + "notes": "Use /prev to go backward.", + }, + "/prev": { + "description": "View the previous response in conversation history.", + "usage": "/prev", + "examples": [("Navigate to previous response", "/prev")], + "notes": "Use /next to go forward.", + }, + "/reset": { + "description": "Clear conversation history and reset system prompt.", + "usage": "/reset", + "examples": [("Reset conversation", "/reset")], + "notes": "Requires confirmation.", + }, + "/info": { + "description": "Display detailed information about a model.", + "usage": "/info [model_id]", + "examples": [ + ("Show current model info", "/info"), + ("Show specific model info", "/info gpt-4o"), + ], + "notes": "Shows pricing, capabilities, and context length.", + }, + "/model": { + "description": "Select or change the AI model.", + "usage": "/model [search_term]", + "examples": [ + ("List all models", "/model"), + ("Search for GPT models", "/model gpt"), + ("Search for Claude models", "/model claude"), + ], + "notes": "Models are numbered for easy selection.", + }, + "/config": { + "description": "View or modify application configuration.", + "usage": "/config [setting] [value]", + "examples": [ + ("View all settings", "/config"), + ("Set API key", "/config api"), + ("Set default model", "/config model"), + ("Set system prompt", "/config system You are a helpful assistant"), + ("Enable streaming", "/config stream on"), + ], + "notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.", + }, + "/maxtoken": { + "description": "Set a temporary session token limit.", + "usage": "/maxtoken [value]", + "examples": [ + ("View current session limit", "/maxtoken"), + ("Set session limit to 2000", "/maxtoken 2000"), + ], + "notes": "Cannot exceed stored max token limit.", + }, + "/system": { + "description": "Set or clear the session-level system prompt.", + "usage": "/system [prompt|clear|default ]", + "examples": [ + ("View current system prompt", "/system"), + ("Set as Python expert", "/system You are a Python expert"), + ("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."), + ("Save as default", "/system default You are a helpful assistant"), + ("Revert to default", "/system clear"), + ("Blank prompt", '/system ""'), + ], + "notes": r"Use \n for newlines. /system clear reverts to hardcoded default.", + }, + "/save": { + "description": "Save the current conversation history.", + "usage": "/save ", + "examples": [("Save conversation", "/save my_chat")], + "notes": "Saved conversations can be loaded later with /load.", + }, + "/load": { + "description": "Load a saved conversation.", + "usage": "/load ", + "examples": [ + ("Load by name", "/load my_chat"), + ("Load by number from /list", "/load 3"), + ], + "notes": "Use /list to see numbered conversations.", + }, + "/delete": { + "description": "Delete a saved conversation.", + "usage": "/delete ", + "examples": [("Delete by name", "/delete my_chat")], + "notes": "Requires confirmation. Cannot be undone.", + }, + "/list": { + "description": "List all saved conversations.", + "usage": "/list", + "examples": [("Show saved conversations", "/list")], + "notes": "Conversations are numbered for use with /load and /delete.", + }, + "/export": { + "description": "Export the current conversation to a file.", + "usage": "/export ", + "examples": [ + ("Export as Markdown", "/export md notes.md"), + ("Export as JSON", "/export json conversation.json"), + ("Export as HTML", "/export html report.html"), + ], + "notes": "Available formats: md, json, html.", + }, + "/stats": { + "description": "Display session statistics.", + "usage": "/stats", + "examples": [("View session statistics", "/stats")], + "notes": "Shows tokens, costs, and credits.", + }, + "/credits": { + "description": "Display your OpenRouter account credits.", + "usage": "/credits", + "examples": [("Check credits", "/credits")], + "notes": "Shows total, used, and remaining credits.", + }, + "/middleout": { + "description": "Toggle middle-out transform for long prompts.", + "usage": "/middleout [on|off]", + "examples": [ + ("Check status", "/middleout"), + ("Enable compression", "/middleout on"), + ], + "notes": "Compresses prompts exceeding context size.", + }, +} diff --git a/oai/core/__init__.py b/oai/core/__init__.py new file mode 100644 index 0000000..d2d1e9e --- /dev/null +++ b/oai/core/__init__.py @@ -0,0 +1,14 @@ +""" +Core functionality for oAI. + +This module provides the main session management and AI client +classes that power the chat application. +""" + +from oai.core.session import ChatSession +from oai.core.client import AIClient + +__all__ = [ + "ChatSession", + "AIClient", +] diff --git a/oai/core/client.py b/oai/core/client.py new file mode 100644 index 0000000..5bcdbec --- /dev/null +++ b/oai/core/client.py @@ -0,0 +1,422 @@ +""" +AI Client for oAI. + +This module provides a high-level client for interacting with AI models +through the provider abstraction layer. +""" + +import asyncio +import json +from typing import Any, Callable, Dict, Iterator, List, Optional, Union + +from oai.constants import APP_NAME, APP_URL, MODEL_PRICING +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ModelInfo, + StreamChunk, + ToolCall, + UsageStats, +) +from oai.providers.openrouter import OpenRouterProvider +from oai.utils.logging import get_logger + + +class AIClient: + """ + High-level AI client for chat interactions. + + Provides a simplified interface for sending chat requests, + handling streaming, and managing tool calls. + + Attributes: + provider: The underlying AI provider + default_model: Default model ID to use + http_headers: Custom HTTP headers for requests + """ + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + provider_class: type = OpenRouterProvider, + app_name: str = APP_NAME, + app_url: str = APP_URL, + ): + """ + Initialize the AI client. + + Args: + api_key: API key for authentication + base_url: Optional custom base URL + provider_class: Provider class to use (default: OpenRouterProvider) + app_name: Application name for headers + app_url: Application URL for headers + """ + self.provider: AIProvider = provider_class( + api_key=api_key, + base_url=base_url, + app_name=app_name, + app_url=app_url, + ) + self.default_model: Optional[str] = None + self.logger = get_logger() + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + Get available models. + + Args: + filter_text_only: Whether to exclude video-only models + + Returns: + List of ModelInfo objects + """ + return self.provider.list_models(filter_text_only=filter_text_only) + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: Model identifier + + Returns: + ModelInfo or None if not found + """ + return self.provider.get_model(model_id) + + def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]: + """ + Get raw model data for provider-specific fields. + + Args: + model_id: Model identifier + + Returns: + Raw model dictionary or None + """ + if hasattr(self.provider, "get_raw_model"): + return self.provider.get_raw_model(model_id) + return None + + def chat( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + system_prompt: Optional[str] = None, + online: bool = False, + transforms: Optional[List[str]] = None, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat request. + + Args: + messages: List of message dictionaries + model: Model ID (uses default if not specified) + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: Tool definitions for function calling + tool_choice: Tool selection mode + system_prompt: System prompt to prepend + online: Whether to enable online mode + transforms: List of transforms (e.g., ["middle-out"]) + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + + Raises: + ValueError: If no model specified and no default set + """ + model_id = model or self.default_model + if not model_id: + raise ValueError("No model specified and no default set") + + # Apply online mode suffix + if online and hasattr(self.provider, "get_effective_model_id"): + model_id = self.provider.get_effective_model_id(model_id, True) + + # Convert dict messages to ChatMessage objects + chat_messages = [] + + # Add system prompt if provided + if system_prompt: + chat_messages.append(ChatMessage(role="system", content=system_prompt)) + + # Convert message dicts + for msg in messages: + # Convert tool_calls dicts to ToolCall objects if present + tool_calls_data = msg.get("tool_calls") + tool_calls_obj = None + if tool_calls_data: + from oai.providers.base import ToolCall, ToolFunction + tool_calls_obj = [] + for tc in tool_calls_data: + # Handle both ToolCall objects and dicts + if isinstance(tc, ToolCall): + tool_calls_obj.append(tc) + elif isinstance(tc, dict): + func_data = tc.get("function", {}) + tool_calls_obj.append( + ToolCall( + id=tc.get("id", ""), + type=tc.get("type", "function"), + function=ToolFunction( + name=func_data.get("name", ""), + arguments=func_data.get("arguments", "{}"), + ), + ) + ) + + chat_messages.append( + ChatMessage( + role=msg.get("role", "user"), + content=msg.get("content"), + tool_calls=tool_calls_obj, + tool_call_id=msg.get("tool_call_id"), + ) + ) + + self.logger.debug( + f"Sending chat request: model={model_id}, " + f"messages={len(chat_messages)}, stream={stream}" + ) + + return self.provider.chat( + model=model_id, + messages=chat_messages, + stream=stream, + max_tokens=max_tokens, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + transforms=transforms, + ) + + def chat_with_tools( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]], + model: Optional[str] = None, + max_loops: int = 5, + max_tokens: Optional[int] = None, + system_prompt: Optional[str] = None, + on_tool_call: Optional[Callable[[ToolCall], None]] = None, + on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None, + ) -> ChatResponse: + """ + Send a chat request with automatic tool call handling. + + Executes tool calls returned by the model and continues + the conversation until no more tool calls are requested. + + Args: + messages: Initial messages + tools: Tool definitions + tool_executor: Function to execute tool calls + model: Model ID + max_loops: Maximum tool call iterations + max_tokens: Maximum response tokens + system_prompt: System prompt + on_tool_call: Callback when tool is called + on_tool_result: Callback when tool returns result + + Returns: + Final ChatResponse after all tool calls complete + """ + model_id = model or self.default_model + if not model_id: + raise ValueError("No model specified and no default set") + + # Build initial messages + chat_messages = [] + if system_prompt: + chat_messages.append({"role": "system", "content": system_prompt}) + chat_messages.extend(messages) + + loop_count = 0 + current_response: Optional[ChatResponse] = None + + while loop_count < max_loops: + # Send request + response = self.chat( + messages=chat_messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + tools=tools, + tool_choice="auto", + ) + + if not isinstance(response, ChatResponse): + raise ValueError("Expected non-streaming response") + + current_response = response + + # Check for tool calls + tool_calls = response.tool_calls + if not tool_calls: + break + + self.logger.info(f"Model requested {len(tool_calls)} tool call(s)") + + # Process each tool call + tool_results = [] + for tc in tool_calls: + if on_tool_call: + on_tool_call(tc) + + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError as e: + self.logger.error(f"Failed to parse tool arguments: {e}") + result = {"error": f"Invalid arguments: {e}"} + else: + result = tool_executor(tc.function.name, args) + + if on_tool_result: + on_tool_result(tc.function.name, result) + + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps(result), + }) + + # Add assistant message with tool calls + assistant_msg = { + "role": "assistant", + "content": response.content, + "tool_calls": [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ], + } + chat_messages.append(assistant_msg) + chat_messages.extend(tool_results) + + loop_count += 1 + + if loop_count >= max_loops: + self.logger.warning(f"Reached max tool call loops ({max_loops})") + + return current_response + + def stream_chat( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + max_tokens: Optional[int] = None, + system_prompt: Optional[str] = None, + online: bool = False, + on_chunk: Optional[Callable[[StreamChunk], None]] = None, + ) -> tuple[str, Optional[UsageStats]]: + """ + Stream a chat response and collect the full text. + + Args: + messages: Chat messages + model: Model ID + max_tokens: Maximum tokens + system_prompt: System prompt + online: Online mode + on_chunk: Optional callback for each chunk + + Returns: + Tuple of (full_response_text, usage_stats) + """ + response = self.chat( + messages=messages, + model=model, + stream=True, + max_tokens=max_tokens, + system_prompt=system_prompt, + online=online, + ) + + if isinstance(response, ChatResponse): + # Not actually streaming + return response.content or "", response.usage + + full_text = "" + usage: Optional[UsageStats] = None + + for chunk in response: + if chunk.error: + self.logger.error(f"Stream error: {chunk.error}") + break + + if chunk.delta_content: + full_text += chunk.delta_content + if on_chunk: + on_chunk(chunk) + + if chunk.usage: + usage = chunk.usage + + return full_text, usage + + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get account credit information. + + Returns: + Credit info dict or None if unavailable + """ + return self.provider.get_credits() + + def estimate_cost( + self, + model_id: str, + input_tokens: int, + output_tokens: int, + ) -> float: + """ + Estimate cost for a completion. + + Args: + model_id: Model ID + input_tokens: Number of input tokens + output_tokens: Number of output tokens + + Returns: + Estimated cost in USD + """ + if hasattr(self.provider, "estimate_cost"): + return self.provider.estimate_cost(model_id, input_tokens, output_tokens) + + # Fallback to default pricing + input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000 + output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000 + return input_cost + output_cost + + def set_default_model(self, model_id: str) -> None: + """ + Set the default model. + + Args: + model_id: Model ID to use as default + """ + self.default_model = model_id + self.logger.info(f"Default model set to: {model_id}") + + def clear_cache(self) -> None: + """Clear the provider's model cache.""" + if hasattr(self.provider, "clear_cache"): + self.provider.clear_cache() diff --git a/oai/core/session.py b/oai/core/session.py new file mode 100644 index 0000000..c4e5e31 --- /dev/null +++ b/oai/core/session.py @@ -0,0 +1,659 @@ +""" +Chat session management for oAI. + +This module provides the ChatSession class that manages an interactive +chat session including history, state, and message handling. +""" + +import asyncio +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple + +from rich.live import Live +from rich.markdown import Markdown + +from oai.commands.registry import CommandContext, CommandResult, registry +from oai.config.database import Database +from oai.config.settings import Settings +from oai.constants import ( + COST_WARNING_THRESHOLD, + LOW_CREDIT_AMOUNT, + LOW_CREDIT_RATIO, +) +from oai.core.client import AIClient +from oai.mcp.manager import MCPManager +from oai.providers.base import ChatResponse, StreamChunk, UsageStats +from oai.ui.console import ( + console, + display_markdown, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.prompts import prompt_copy_response +from oai.utils.logging import get_logger + + +@dataclass +class SessionStats: + """ + Statistics for the current session. + + Tracks tokens, costs, and message counts. + """ + + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_cost: float = 0.0 + message_count: int = 0 + + @property + def total_tokens(self) -> int: + """Get total token count.""" + return self.total_input_tokens + self.total_output_tokens + + def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None: + """ + Add usage stats from a response. + + Args: + usage: Usage statistics + cost: Cost if not in usage + """ + if usage: + self.total_input_tokens += usage.prompt_tokens + self.total_output_tokens += usage.completion_tokens + if usage.total_cost_usd: + self.total_cost += usage.total_cost_usd + else: + self.total_cost += cost + else: + self.total_cost += cost + self.message_count += 1 + + +@dataclass +class HistoryEntry: + """ + A single entry in the conversation history. + + Stores the user prompt, assistant response, and metrics. + """ + + prompt: str + response: str + prompt_tokens: int = 0 + completion_tokens: int = 0 + msg_cost: float = 0.0 + timestamp: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "prompt": self.prompt, + "response": self.response, + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "msg_cost": self.msg_cost, + } + + +class ChatSession: + """ + Manages an interactive chat session. + + Handles conversation history, state management, command processing, + and communication with the AI client. + + Attributes: + client: AI client for API requests + settings: Application settings + mcp_manager: MCP manager for file/database access + history: Conversation history + stats: Session statistics + """ + + def __init__( + self, + client: AIClient, + settings: Settings, + mcp_manager: Optional[MCPManager] = None, + ): + """ + Initialize a chat session. + + Args: + client: AI client instance + settings: Application settings + mcp_manager: Optional MCP manager + """ + self.client = client + self.settings = settings + self.mcp_manager = mcp_manager + self.db = Database() + + self.history: List[HistoryEntry] = [] + self.stats = SessionStats() + + # Session state + self.system_prompt: str = settings.effective_system_prompt + self.memory_enabled: bool = True + self.memory_start_index: int = 0 + self.online_enabled: bool = settings.default_online_mode + self.middle_out_enabled: bool = False + self.session_max_token: int = 0 + self.current_index: int = 0 + + # Selected model + self.selected_model: Optional[Dict[str, Any]] = None + + self.logger = get_logger() + + def get_context(self) -> CommandContext: + """ + Get the current command context. + + Returns: + CommandContext with current session state + """ + return CommandContext( + settings=self.settings, + provider=self.client.provider, + mcp_manager=self.mcp_manager, + selected_model_raw=self.selected_model, + session_history=[e.to_dict() for e in self.history], + session_system_prompt=self.system_prompt, + memory_enabled=self.memory_enabled, + memory_start_index=self.memory_start_index, + online_enabled=self.online_enabled, + middle_out_enabled=self.middle_out_enabled, + session_max_token=self.session_max_token, + total_input_tokens=self.stats.total_input_tokens, + total_output_tokens=self.stats.total_output_tokens, + total_cost=self.stats.total_cost, + message_count=self.stats.message_count, + current_index=self.current_index, + ) + + def set_model(self, model: Dict[str, Any]) -> None: + """ + Set the selected model. + + Args: + model: Raw model dictionary + """ + self.selected_model = model + self.client.set_default_model(model["id"]) + self.logger.info(f"Model selected: {model['id']}") + + def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]: + """ + Build the messages array for an API request. + + Includes system prompt, history (if memory enabled), and current input. + + Args: + user_input: Current user input + + Returns: + List of message dictionaries + """ + messages = [] + + # Add system prompt + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + + # Add database context if in database mode + if self.mcp_manager and self.mcp_manager.enabled: + if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None: + db = self.mcp_manager.databases[self.mcp_manager.selected_db_index] + db_context = ( + f"You are connected to SQLite database: {db['name']}\n" + f"Available tables: {', '.join(db['tables'])}\n\n" + "Use inspect_database, search_database, or query_database tools. " + "All queries are read-only." + ) + messages.append({"role": "system", "content": db_context}) + + # Add history if memory enabled + if self.memory_enabled: + for i in range(self.memory_start_index, len(self.history)): + entry = self.history[i] + messages.append({"role": "user", "content": entry.prompt}) + messages.append({"role": "assistant", "content": entry.response}) + + # Add current message + messages.append({"role": "user", "content": user_input}) + + return messages + + def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]: + """ + Get MCP tool definitions if available. + + Returns: + List of tool schemas or None + """ + if not self.mcp_manager or not self.mcp_manager.enabled: + return None + + if not self.selected_model: + return None + + # Check if model supports tools + supported_params = self.selected_model.get("supported_parameters", []) + if "tools" not in supported_params and "functions" not in supported_params: + return None + + return self.mcp_manager.get_tools_schema() + + async def execute_tool( + self, + tool_name: str, + tool_args: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Execute an MCP tool. + + Args: + tool_name: Name of the tool + tool_args: Tool arguments + + Returns: + Tool execution result + """ + if not self.mcp_manager: + return {"error": "MCP not available"} + + return await self.mcp_manager.call_tool(tool_name, **tool_args) + + def send_message( + self, + user_input: str, + stream: bool = True, + on_stream_chunk: Optional[Callable[[str], None]] = None, + ) -> Tuple[str, Optional[UsageStats], float]: + """ + Send a message and get a response. + + Args: + user_input: User's input text + stream: Whether to stream the response + on_stream_chunk: Callback for stream chunks + + Returns: + Tuple of (response_text, usage_stats, response_time) + """ + if not self.selected_model: + raise ValueError("No model selected") + + start_time = time.time() + messages = self.build_api_messages(user_input) + + # Get MCP tools + tools = self.get_mcp_tools() + if tools: + # Disable streaming when tools are present + stream = False + + # Build request parameters + model_id = self.selected_model["id"] + if self.online_enabled: + if hasattr(self.client.provider, "get_effective_model_id"): + model_id = self.client.provider.get_effective_model_id(model_id, True) + + transforms = ["middle-out"] if self.middle_out_enabled else None + + max_tokens = None + if self.session_max_token > 0: + max_tokens = self.session_max_token + + if tools: + # Use tool handling flow + response = self._send_with_tools( + messages=messages, + model_id=model_id, + tools=tools, + max_tokens=max_tokens, + transforms=transforms, + ) + response_time = time.time() - start_time + return response.content or "", response.usage, response_time + + elif stream: + # Use streaming flow + full_text, usage = self._stream_response( + messages=messages, + model_id=model_id, + max_tokens=max_tokens, + transforms=transforms, + on_chunk=on_stream_chunk, + ) + response_time = time.time() - start_time + return full_text, usage, response_time + + else: + # Non-streaming request + response = self.client.chat( + messages=messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + transforms=transforms, + ) + response_time = time.time() - start_time + + if isinstance(response, ChatResponse): + return response.content or "", response.usage, response_time + else: + return "", None, response_time + + def _send_with_tools( + self, + messages: List[Dict[str, Any]], + model_id: str, + tools: List[Dict[str, Any]], + max_tokens: Optional[int] = None, + transforms: Optional[List[str]] = None, + ) -> ChatResponse: + """ + Send a request with tool call handling. + + Args: + messages: API messages + model_id: Model ID + tools: Tool definitions + max_tokens: Max tokens + transforms: Transforms list + + Returns: + Final ChatResponse + """ + max_loops = 5 + loop_count = 0 + api_messages = list(messages) + + while loop_count < max_loops: + response = self.client.chat( + messages=api_messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + tools=tools, + tool_choice="auto", + transforms=transforms, + ) + + if not isinstance(response, ChatResponse): + raise ValueError("Expected ChatResponse") + + tool_calls = response.tool_calls + if not tool_calls: + return response + + console.print(f"\n[dim yellow]๐Ÿ”ง AI requesting {len(tool_calls)} tool call(s)...[/]") + + tool_results = [] + for tc in tool_calls: + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError as e: + self.logger.error(f"Failed to parse tool arguments: {e}") + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps({"error": f"Invalid arguments: {e}"}), + }) + continue + + # Display tool call + args_display = ", ".join( + f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}" + for k, v in args.items() + ) + console.print(f"[dim cyan] โ†’ {tc.function.name}({args_display})[/]") + + # Execute tool + result = asyncio.run(self.execute_tool(tc.function.name, args)) + + if "error" in result: + console.print(f"[dim red] โœ— Error: {result['error']}[/]") + else: + self._display_tool_success(tc.function.name, result) + + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps(result), + }) + + # Add assistant message with tool calls + api_messages.append({ + "role": "assistant", + "content": response.content, + "tool_calls": [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ], + }) + api_messages.extend(tool_results) + + console.print("\n[dim cyan]๐Ÿ’ญ Processing tool results...[/]") + loop_count += 1 + + self.logger.warning(f"Reached max tool loops ({max_loops})") + console.print(f"[bold yellow]โš ๏ธ Reached maximum tool calls ({max_loops})[/]") + return response + + def _display_tool_success(self, tool_name: str, result: Dict[str, Any]) -> None: + """Display a success message for a tool call.""" + if tool_name == "search_files": + count = result.get("count", 0) + console.print(f"[dim green] โœ“ Found {count} file(s)[/]") + elif tool_name == "read_file": + size = result.get("size", 0) + truncated = " (truncated)" if result.get("truncated") else "" + console.print(f"[dim green] โœ“ Read {size} bytes{truncated}[/]") + elif tool_name == "list_directory": + count = result.get("count", 0) + console.print(f"[dim green] โœ“ Listed {count} item(s)[/]") + elif tool_name == "inspect_database": + if "table" in result: + console.print(f"[dim green] โœ“ Inspected table: {result['table']}[/]") + else: + console.print(f"[dim green] โœ“ Inspected database ({result.get('table_count', 0)} tables)[/]") + elif tool_name == "search_database": + count = result.get("count", 0) + console.print(f"[dim green] โœ“ Found {count} match(es)[/]") + elif tool_name == "query_database": + count = result.get("count", 0) + console.print(f"[dim green] โœ“ Query returned {count} row(s)[/]") + else: + console.print("[dim green] โœ“ Success[/]") + + def _stream_response( + self, + messages: List[Dict[str, Any]], + model_id: str, + max_tokens: Optional[int] = None, + transforms: Optional[List[str]] = None, + on_chunk: Optional[Callable[[str], None]] = None, + ) -> Tuple[str, Optional[UsageStats]]: + """ + Stream a response with live display. + + Args: + messages: API messages + model_id: Model ID + max_tokens: Max tokens + transforms: Transforms + on_chunk: Callback for chunks + + Returns: + Tuple of (full_text, usage) + """ + response = self.client.chat( + messages=messages, + model=model_id, + stream=True, + max_tokens=max_tokens, + transforms=transforms, + ) + + if isinstance(response, ChatResponse): + return response.content or "", response.usage + + full_text = "" + usage: Optional[UsageStats] = None + + try: + with Live("", console=console, refresh_per_second=10) as live: + for chunk in response: + if chunk.error: + console.print(f"\n[bold red]Stream error: {chunk.error}[/]") + break + + if chunk.delta_content: + full_text += chunk.delta_content + live.update(Markdown(full_text)) + if on_chunk: + on_chunk(chunk.delta_content) + + if chunk.usage: + usage = chunk.usage + + except KeyboardInterrupt: + console.print("\n[bold yellow]โš ๏ธ Streaming interrupted[/]") + return "", None + + return full_text, usage + + def add_to_history( + self, + prompt: str, + response: str, + usage: Optional[UsageStats] = None, + cost: float = 0.0, + ) -> None: + """ + Add an exchange to the history. + + Args: + prompt: User prompt + response: Assistant response + usage: Usage statistics + cost: Cost if not in usage + """ + entry = HistoryEntry( + prompt=prompt, + response=response, + prompt_tokens=usage.prompt_tokens if usage else 0, + completion_tokens=usage.completion_tokens if usage else 0, + msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost, + timestamp=time.time(), + ) + self.history.append(entry) + self.current_index = len(self.history) - 1 + self.stats.add_usage(usage, cost) + + def save_conversation(self, name: str) -> bool: + """ + Save the current conversation. + + Args: + name: Name for the saved conversation + + Returns: + True if saved successfully + """ + if not self.history: + return False + + data = [e.to_dict() for e in self.history] + self.db.save_conversation(name, data) + self.logger.info(f"Saved conversation: {name}") + return True + + def load_conversation(self, name: str) -> bool: + """ + Load a saved conversation. + + Args: + name: Name of the conversation to load + + Returns: + True if loaded successfully + """ + data = self.db.load_conversation(name) + if not data: + return False + + self.history.clear() + for entry_dict in data: + self.history.append(HistoryEntry( + prompt=entry_dict.get("prompt", ""), + response=entry_dict.get("response", ""), + prompt_tokens=entry_dict.get("prompt_tokens", 0), + completion_tokens=entry_dict.get("completion_tokens", 0), + msg_cost=entry_dict.get("msg_cost", 0.0), + )) + + self.current_index = len(self.history) - 1 + self.memory_start_index = 0 + self.stats = SessionStats() # Reset stats for loaded conversation + self.logger.info(f"Loaded conversation: {name}") + return True + + def reset(self) -> None: + """Reset the session state.""" + self.history.clear() + self.stats = SessionStats() + self.system_prompt = "" + self.memory_start_index = 0 + self.current_index = 0 + self.logger.info("Session reset") + + def check_warnings(self) -> List[str]: + """ + Check for cost and credit warnings. + + Returns: + List of warning messages + """ + warnings = [] + + # Check last message cost + if self.history: + last_cost = self.history[-1].msg_cost + threshold = self.settings.cost_warning_threshold + if last_cost > threshold: + warnings.append( + f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}" + ) + + # Check credits + credits = self.client.get_credits() + if credits: + left = credits.get("credits_left", 0) + total = credits.get("total_credits", 0) + + if left < LOW_CREDIT_AMOUNT: + warnings.append(f"Low credits: ${left:.2f} remaining!") + elif total > 0 and left < total * LOW_CREDIT_RATIO: + warnings.append(f"Credits low: less than 10% remaining (${left:.2f})") + + return warnings diff --git a/oai/mcp/__init__.py b/oai/mcp/__init__.py new file mode 100644 index 0000000..b7abc9a --- /dev/null +++ b/oai/mcp/__init__.py @@ -0,0 +1,28 @@ +""" +Model Context Protocol (MCP) integration for oAI. + +This package provides filesystem and database access capabilities +through the MCP standard, allowing AI models to interact with +local files and SQLite databases safely. + +Key components: +- MCPManager: High-level manager for MCP operations +- MCPFilesystemServer: Filesystem and database access implementation +- GitignoreParser: Pattern matching for .gitignore support +- SQLiteQueryValidator: Query safety validation +- CrossPlatformMCPConfig: OS-specific configuration +""" + +from oai.mcp.manager import MCPManager +from oai.mcp.server import MCPFilesystemServer +from oai.mcp.gitignore import GitignoreParser +from oai.mcp.validators import SQLiteQueryValidator +from oai.mcp.platform import CrossPlatformMCPConfig + +__all__ = [ + "MCPManager", + "MCPFilesystemServer", + "GitignoreParser", + "SQLiteQueryValidator", + "CrossPlatformMCPConfig", +] diff --git a/oai/mcp/gitignore.py b/oai/mcp/gitignore.py new file mode 100644 index 0000000..83ec4f4 --- /dev/null +++ b/oai/mcp/gitignore.py @@ -0,0 +1,166 @@ +""" +Gitignore pattern parsing for oAI MCP. + +This module implements .gitignore pattern matching to filter files +during MCP filesystem operations. +""" + +import fnmatch +from pathlib import Path +from typing import List, Tuple + +from oai.utils.logging import get_logger + + +class GitignoreParser: + """ + Parse and apply .gitignore patterns. + + Supports standard gitignore syntax including: + - Wildcards (*) and double wildcards (**) + - Directory-only patterns (ending with /) + - Negation patterns (starting with !) + - Comments (lines starting with #) + + Patterns are applied relative to the directory containing + the .gitignore file. + """ + + def __init__(self): + """Initialize an empty pattern collection.""" + # List of (pattern, is_negation, source_dir) + self.patterns: List[Tuple[str, bool, Path]] = [] + + def add_gitignore(self, gitignore_path: Path) -> None: + """ + Parse and add patterns from a .gitignore file. + + Args: + gitignore_path: Path to the .gitignore file + """ + logger = get_logger() + + if not gitignore_path.exists(): + return + + try: + source_dir = gitignore_path.parent + + with open(gitignore_path, "r", encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.rstrip("\n\r") + + # Skip empty lines and comments + if not line or line.startswith("#"): + continue + + # Check for negation pattern + is_negation = line.startswith("!") + if is_negation: + line = line[1:] + + # Remove leading slash (make relative to gitignore location) + if line.startswith("/"): + line = line[1:] + + self.patterns.append((line, is_negation, source_dir)) + + logger.debug( + f"Loaded {len(self.patterns)} patterns from {gitignore_path}" + ) + + except Exception as e: + logger.warning(f"Error reading {gitignore_path}: {e}") + + def should_ignore(self, path: Path) -> bool: + """ + Check if a path should be ignored based on gitignore patterns. + + Patterns are evaluated in order, with later patterns overriding + earlier ones. Negation patterns (starting with !) un-ignore + previously matched paths. + + Args: + path: Path to check + + Returns: + True if the path should be ignored + """ + if not self.patterns: + return False + + ignored = False + + for pattern, is_negation, source_dir in self.patterns: + # Only apply pattern if path is under the source directory + try: + rel_path = path.relative_to(source_dir) + except ValueError: + # Path is not relative to this gitignore's directory + continue + + rel_path_str = str(rel_path) + + # Check if pattern matches + if self._match_pattern(pattern, rel_path_str, path.is_dir()): + if is_negation: + ignored = False # Negation patterns un-ignore + else: + ignored = True + + return ignored + + def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool: + """ + Match a gitignore pattern against a path. + + Args: + pattern: The gitignore pattern + path: The relative path string to match + is_dir: Whether the path is a directory + + Returns: + True if the pattern matches + """ + # Directory-only pattern (ends with /) + if pattern.endswith("/"): + if not is_dir: + return False + pattern = pattern[:-1] + + # Handle ** patterns (matches any number of directories) + if "**" in pattern: + pattern_parts = pattern.split("**") + if len(pattern_parts) == 2: + prefix, suffix = pattern_parts + + # Match if path starts with prefix and ends with suffix + if prefix: + if not path.startswith(prefix.rstrip("/")): + return False + if suffix: + suffix = suffix.lstrip("/") + if not (path.endswith(suffix) or f"/{suffix}" in path): + return False + return True + + # Direct match using fnmatch + if fnmatch.fnmatch(path, pattern): + return True + + # Match as subdirectory pattern (pattern without / matches in any directory) + if "/" not in pattern: + parts = path.split("/") + if any(fnmatch.fnmatch(part, pattern) for part in parts): + return True + + return False + + def clear(self) -> None: + """Clear all loaded patterns.""" + self.patterns = [] + + @property + def pattern_count(self) -> int: + """Get the number of loaded patterns.""" + return len(self.patterns) diff --git a/oai/mcp/manager.py b/oai/mcp/manager.py new file mode 100644 index 0000000..c44001c --- /dev/null +++ b/oai/mcp/manager.py @@ -0,0 +1,1365 @@ +""" +MCP Manager for oAI. + +This module provides the high-level interface for managing MCP operations, +including server lifecycle, mode switching, folder/database management, +and tool schema generation. +""" + +import asyncio +import datetime +import json +from pathlib import Path +from typing import Optional, List, Dict, Any + +from oai.constants import MAX_TOOL_LOOPS +from oai.config.database import get_database +from oai.mcp.platform import CrossPlatformMCPConfig +from oai.mcp.server import MCPFilesystemServer +from oai.utils.logging import get_logger + + +class MCPManager: + """ + Manage MCP server lifecycle, tool calls, and mode switching. + + This class provides the main interface for MCP functionality including: + - Enabling/disabling the MCP server + - Managing allowed folders and databases + - Switching between file and database modes + - Generating tool schemas for API requests + - Executing tool calls + + Attributes: + enabled: Whether MCP is currently enabled + write_enabled: Whether write mode is enabled (non-persistent) + mode: Current mode ('files' or 'database') + selected_db_index: Index of selected database (if in database mode) + server: The MCPFilesystemServer instance + allowed_folders: List of allowed folder paths + databases: List of registered databases + config: Platform configuration + session_start_time: When the current session started + """ + + def __init__(self): + """Initialize the MCP manager.""" + self.enabled = False + self.write_enabled = False # Off by default, resets each session + self.mode = "files" + self.selected_db_index: Optional[int] = None + + self.server: Optional[MCPFilesystemServer] = None + self.allowed_folders: List[Path] = [] + self.databases: List[Dict[str, Any]] = [] + + self.config = CrossPlatformMCPConfig() + self.session_start_time: Optional[datetime.datetime] = None + + # Load persisted data + self._load_folders() + self._load_databases() + + get_logger().info("MCP Manager initialized") + + # ========================================================================= + # PERSISTENCE + # ========================================================================= + + def _load_folders(self) -> None: + """Load allowed folders from database.""" + logger = get_logger() + db = get_database() + folders_json = db.get_mcp_config("allowed_folders") + + if folders_json: + try: + folder_paths = json.loads(folders_json) + self.allowed_folders = [ + self.config.normalize_path(p) for p in folder_paths + ] + logger.info(f"Loaded {len(self.allowed_folders)} folders from config") + except Exception as e: + logger.error(f"Error loading MCP folders: {e}") + self.allowed_folders = [] + + def _save_folders(self) -> None: + """Save allowed folders to database.""" + db = get_database() + folder_paths = [str(p) for p in self.allowed_folders] + db.set_mcp_config("allowed_folders", json.dumps(folder_paths)) + get_logger().info(f"Saved {len(self.allowed_folders)} folders to config") + + def _load_databases(self) -> None: + """Load databases from database.""" + self.databases = get_database().get_mcp_databases() + get_logger().info(f"Loaded {len(self.databases)} databases from config") + + # ========================================================================= + # ENABLE/DISABLE + # ========================================================================= + + def enable(self) -> Dict[str, Any]: + """ + Enable MCP server. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - folder_count: Number of allowed folders + - database_count: Number of registered databases + - message: Status message + - error: Error message if failed + """ + logger = get_logger() + + if self.enabled: + return { + "success": False, + "error": "MCP is already enabled" + } + + try: + self.server = MCPFilesystemServer(self.allowed_folders) + self.enabled = True + self.session_start_time = datetime.datetime.now() + + get_database().set_mcp_config("mcp_enabled", "true") + + logger.info("MCP Filesystem Server enabled") + + return { + "success": True, + "folder_count": len(self.allowed_folders), + "database_count": len(self.databases), + "message": "MCP Filesystem Server started successfully" + } + except Exception as e: + logger.error(f"Error enabling MCP: {e}") + return { + "success": False, + "error": str(e) + } + + def disable(self) -> Dict[str, Any]: + """ + Disable MCP server. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - message: Status message + - error: Error message if failed + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + try: + self.server = None + self.enabled = False + self.session_start_time = None + self.mode = "files" + self.selected_db_index = None + + get_database().set_mcp_config("mcp_enabled", "false") + + logger.info("MCP Filesystem Server disabled") + + return { + "success": True, + "message": "MCP Filesystem Server stopped" + } + except Exception as e: + logger.error(f"Error disabling MCP: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # WRITE MODE + # ========================================================================= + + def enable_write(self) -> bool: + """ + Enable write mode. + + This method should be called after user confirmation in the UI. + + Returns: + True if write mode was enabled, False otherwise + """ + logger = get_logger() + + if not self.enabled: + logger.warning("Attempted to enable write mode without MCP enabled") + return False + + self.write_enabled = True + logger.info("MCP write mode enabled by user") + get_database().log_mcp_stat("write_mode_enabled", "", True) + return True + + def disable_write(self) -> None: + """Disable write mode.""" + if self.write_enabled: + self.write_enabled = False + get_logger().info("MCP write mode disabled") + get_database().log_mcp_stat("write_mode_disabled", "", True) + + # ========================================================================= + # MODE SWITCHING + # ========================================================================= + + def switch_mode( + self, + new_mode: str, + db_index: Optional[int] = None + ) -> Dict[str, Any]: + """ + Switch between files and database mode. + + Args: + new_mode: 'files' or 'database' + db_index: Database number (1-based) if switching to database mode + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - mode: Current mode + - message: Status message + - tools: Available tools in this mode + - database: Database info (if database mode) + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + if new_mode == "files": + self.mode = "files" + self.selected_db_index = None + logger.info("Switched to file mode") + return { + "success": True, + "mode": "files", + "message": "Switched to file mode", + "tools": ["read_file", "list_directory", "search_files"] + } + + elif new_mode == "database": + if db_index is None: + return { + "success": False, + "error": "Database index required. Use /mcp db " + } + + if db_index < 1 or db_index > len(self.databases): + return { + "success": False, + "error": f"Invalid database number. Use 1-{len(self.databases)}" + } + + self.mode = "database" + self.selected_db_index = db_index - 1 # Convert to 0-based + db = self.databases[self.selected_db_index] + + logger.info(f"Switched to database mode: {db['name']}") + return { + "success": True, + "mode": "database", + "database": db, + "message": f"Switched to database #{db_index}: {db['name']}", + "tools": ["inspect_database", "search_database", "query_database"] + } + + else: + return { + "success": False, + "error": f"Invalid mode: {new_mode}" + } + + # ========================================================================= + # FOLDER MANAGEMENT + # ========================================================================= + + def add_folder(self, folder_path: str) -> Dict[str, Any]: + """ + Add a folder to the allowed list. + + Args: + folder_path: Path to the folder + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - path: Normalized path + - stats: Folder statistics + - total_folders: Total number of allowed folders + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + try: + path = self.config.normalize_path(folder_path) + + # Validation + if not path.exists(): + return { + "success": False, + "error": f"Directory does not exist: {path}" + } + + if not path.is_dir(): + return { + "success": False, + "error": f"Not a directory: {path}" + } + + if self.config.is_system_directory(path): + return { + "success": False, + "error": f"Cannot add system directory: {path}" + } + + if path in self.allowed_folders: + return { + "success": False, + "error": f"Folder already in allowed list: {path}" + } + + # Check for nested paths + parent_folder = None + for existing in self.allowed_folders: + try: + path.relative_to(existing) + parent_folder = existing + break + except ValueError: + continue + + # Get folder stats + stats = self.config.get_folder_stats(path) + + # Add folder + self.allowed_folders.append(path) + self._save_folders() + + # Update server and reload gitignores + if self.server: + self.server.allowed_folders = self.allowed_folders + self.server.reload_gitignores() + + logger.info(f"Added folder to MCP: {path}") + + result = { + "success": True, + "path": str(path), + "stats": stats, + "total_folders": len(self.allowed_folders) + } + + if parent_folder: + result["warning"] = f"Note: {parent_folder} is already allowed (parent folder)" + + return result + + except Exception as e: + logger.error(f"Error adding folder: {e}") + return { + "success": False, + "error": str(e) + } + + def remove_folder(self, folder_ref: str) -> Dict[str, Any]: + """ + Remove a folder from the allowed list. + + Args: + folder_ref: Folder path or number (1-based) + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - path: Removed folder path + - total_folders: Remaining folder count + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + try: + # Check if it's a number + if folder_ref.isdigit(): + index = int(folder_ref) - 1 + if 0 <= index < len(self.allowed_folders): + path = self.allowed_folders[index] + else: + return { + "success": False, + "error": f"Invalid folder number: {folder_ref}" + } + else: + # Treat as path + path = self.config.normalize_path(folder_ref) + if path not in self.allowed_folders: + return { + "success": False, + "error": f"Folder not in allowed list: {path}" + } + + # Remove folder + self.allowed_folders.remove(path) + self._save_folders() + + # Update server and reload gitignores + if self.server: + self.server.allowed_folders = self.allowed_folders + self.server.reload_gitignores() + + logger.info(f"Removed folder from MCP: {path}") + + result = { + "success": True, + "path": str(path), + "total_folders": len(self.allowed_folders) + } + + if len(self.allowed_folders) == 0: + result["warning"] = "No allowed folders remaining. Add one with /mcp add " + + return result + + except Exception as e: + logger.error(f"Error removing folder: {e}") + return { + "success": False, + "error": str(e) + } + + def list_folders(self) -> Dict[str, Any]: + """ + List all allowed folders with stats. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - folders: List of folder info + - total_folders: Total folder count + - total_files: Total file count across folders + - total_size_mb: Total size in MB + """ + try: + folders_info = [] + total_files = 0 + total_size = 0 + + for idx, folder in enumerate(self.allowed_folders, 1): + stats = self.config.get_folder_stats(folder) + + folder_info = { + "number": idx, + "path": str(folder), + "exists": stats.get("exists", False) + } + + if stats.get("exists"): + folder_info["file_count"] = stats["file_count"] + folder_info["size_mb"] = stats["size_mb"] + total_files += stats["file_count"] + total_size += stats["total_size"] + + folders_info.append(folder_info) + + return { + "success": True, + "folders": folders_info, + "total_folders": len(self.allowed_folders), + "total_files": total_files, + "total_size_mb": total_size / (1024 * 1024) + } + + except Exception as e: + get_logger().error(f"Error listing folders: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # DATABASE MANAGEMENT + # ========================================================================= + + def add_database(self, db_path: str) -> Dict[str, Any]: + """ + Add a SQLite database. + + Args: + db_path: Path to the database file + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - database: Database info + - number: Database number + - message: Status message + """ + import sqlite3 + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + try: + path = Path(db_path).resolve() + + # Validation + if not path.exists(): + return { + "success": False, + "error": f"Database file not found: {path}" + } + + if not path.is_file(): + return { + "success": False, + "error": f"Not a file: {path}" + } + + # Check if already added + if any(db["path"] == str(path) for db in self.databases): + return { + "success": False, + "error": f"Database already added: {path.name}" + } + + # Validate it's a SQLite database + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + + # Get tables + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [row[0] for row in cursor.fetchall()] + + conn.close() + except sqlite3.DatabaseError as e: + return { + "success": False, + "error": f"Not a valid SQLite database: {e}" + } + + # Get file size + db_size = path.stat().st_size + + # Create database entry + db_info = { + "path": str(path), + "name": path.name, + "size": db_size, + "tables": tables, + "added": datetime.datetime.now().isoformat() + } + + # Save to config + db_id = get_database().add_mcp_database(db_info) + db_info["id"] = db_id + + # Add to list + self.databases.append(db_info) + + logger.info(f"Added database: {path.name}") + + return { + "success": True, + "database": db_info, + "number": len(self.databases), + "message": f"Added database #{len(self.databases)}: {path.name}" + } + + except Exception as e: + logger.error(f"Error adding database: {e}") + return { + "success": False, + "error": str(e) + } + + def remove_database(self, db_ref: str) -> Dict[str, Any]: + """ + Remove a database. + + Args: + db_ref: Database number (1-based) or path + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - database: Removed database info + - message: Status message + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + try: + # Check if it's a number + if db_ref.isdigit(): + index = int(db_ref) - 1 + if 0 <= index < len(self.databases): + db = self.databases[index] + else: + return { + "success": False, + "error": f"Invalid database number: {db_ref}" + } + else: + # Treat as path + path = str(Path(db_ref).resolve()) + db = next((d for d in self.databases if d["path"] == path), None) + if not db: + return { + "success": False, + "error": f"Database not found: {db_ref}" + } + index = self.databases.index(db) + + # If currently selected, deselect + if self.mode == "database" and self.selected_db_index == index: + self.mode = "files" + self.selected_db_index = None + + # Remove from config + get_database().remove_mcp_database(db["path"]) + + # Remove from list + self.databases.pop(index) + + logger.info(f"Removed database: {db['name']}") + + result = { + "success": True, + "database": db, + "message": f"Removed database: {db['name']}" + } + + if len(self.databases) == 0: + result["warning"] = "No databases remaining. Add one with /mcp add db " + + return result + + except Exception as e: + logger.error(f"Error removing database: {e}") + return { + "success": False, + "error": str(e) + } + + def list_databases(self) -> Dict[str, Any]: + """ + List all databases. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - databases: List of database info + - count: Number of databases + """ + try: + db_list = [] + for idx, db in enumerate(self.databases, 1): + db_info = { + "number": idx, + "name": db["name"], + "path": db["path"], + "size_mb": db["size"] / (1024 * 1024), + "tables": db["tables"], + "table_count": len(db["tables"]), + "added": db["added"] + } + + # Check if file still exists + if not Path(db["path"]).exists(): + db_info["warning"] = "File not found" + + db_list.append(db_info) + + return { + "success": True, + "databases": db_list, + "count": len(db_list) + } + + except Exception as e: + get_logger().error(f"Error listing databases: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # GITIGNORE + # ========================================================================= + + def toggle_gitignore(self, enabled: bool) -> Dict[str, Any]: + """ + Toggle .gitignore filtering. + + Args: + enabled: Whether to enable gitignore filtering + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - message: Status message + - pattern_count: Number of patterns (if enabled) + """ + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + if not self.server: + return { + "success": False, + "error": "MCP server not running" + } + + self.server.respect_gitignore = enabled + status = "enabled" if enabled else "disabled" + + get_logger().info(f".gitignore filtering {status}") + + return { + "success": True, + "message": f".gitignore filtering {status}", + "pattern_count": self.server.gitignore_parser.pattern_count if enabled else 0 + } + + # ========================================================================= + # STATUS + # ========================================================================= + + def get_status(self) -> Dict[str, Any]: + """ + Get comprehensive MCP status. + + Returns: + Dictionary containing full MCP status information + """ + try: + stats = get_database().get_mcp_stats() + + uptime = None + if self.session_start_time: + delta = datetime.datetime.now() - self.session_start_time + hours = delta.seconds // 3600 + minutes = (delta.seconds % 3600) // 60 + uptime = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m" + + folder_info = self.list_folders() + + gitignore_status = "enabled" if ( + self.server and self.server.respect_gitignore + ) else "disabled" + gitignore_patterns = ( + self.server.gitignore_parser.pattern_count if self.server else 0 + ) + + # Current mode info + mode_info = { + "mode": self.mode, + "mode_display": ( + "Files" if self.mode == "files" + else f"DB #{self.selected_db_index + 1}" + ) + } + + if self.mode == "database" and self.selected_db_index is not None: + mode_info["database"] = self.databases[self.selected_db_index] + + return { + "success": True, + "enabled": self.enabled, + "write_enabled": self.write_enabled, + "uptime": uptime, + "mode_info": mode_info, + "folder_count": len(self.allowed_folders), + "database_count": len(self.databases), + "total_files": folder_info.get("total_files", 0), + "total_size_mb": folder_info.get("total_size_mb", 0), + "stats": stats, + "tools_available": self._get_current_tools(), + "gitignore_status": gitignore_status, + "gitignore_patterns": gitignore_patterns + } + + except Exception as e: + get_logger().error(f"Error getting MCP status: {e}") + return { + "success": False, + "error": str(e) + } + + def _get_current_tools(self) -> List[str]: + """Get tools available in current mode.""" + if not self.enabled: + return [] + + if self.mode == "files": + tools = ["read_file", "list_directory", "search_files"] + if self.write_enabled: + tools.extend([ + "write_file", "edit_file", "delete_file", + "create_directory", "move_file", "copy_file" + ]) + return tools + elif self.mode == "database": + return ["inspect_database", "search_database", "query_database"] + + return [] + + # ========================================================================= + # TOOL EXECUTION + # ========================================================================= + + async def call_tool(self, tool_name: str, **kwargs) -> Dict[str, Any]: + """ + Call an MCP tool. + + Args: + tool_name: Name of the tool to call + **kwargs: Tool arguments + + Returns: + Tool execution result + """ + if not self.enabled or not self.server: + return {"error": "MCP is not enabled"} + + # Check write permissions for write operations + write_tools = { + "write_file", "edit_file", "delete_file", + "create_directory", "move_file", "copy_file" + } + if tool_name in write_tools and not self.write_enabled: + return {"error": "Write operations disabled. Enable with /mcp write on"} + + try: + # File mode tools (read-only) + if tool_name == "read_file": + return await self.server.read_file(kwargs.get("file_path", "")) + elif tool_name == "list_directory": + return await self.server.list_directory( + kwargs.get("dir_path", ""), + kwargs.get("recursive", False) + ) + elif tool_name == "search_files": + return await self.server.search_files( + kwargs.get("pattern", ""), + kwargs.get("search_path"), + kwargs.get("content_search", False) + ) + + # Database mode tools + elif tool_name == "inspect_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.inspect_database( + db["path"], + kwargs.get("table_name") + ) + elif tool_name == "search_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.search_database( + db["path"], + kwargs.get("search_term", ""), + kwargs.get("table_name"), + kwargs.get("column_name") + ) + elif tool_name == "query_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.query_database( + db["path"], + kwargs.get("query", ""), + kwargs.get("limit") + ) + + # Write mode tools + elif tool_name == "write_file": + return await self.server.write_file( + kwargs.get("file_path", ""), + kwargs.get("content", "") + ) + elif tool_name == "edit_file": + return await self.server.edit_file( + kwargs.get("file_path", ""), + kwargs.get("old_text", ""), + kwargs.get("new_text", "") + ) + elif tool_name == "delete_file": + return await self.server.delete_file( + kwargs.get("file_path", ""), + kwargs.get("reason", "No reason provided"), + kwargs.get("confirm_callback") + ) + elif tool_name == "create_directory": + return await self.server.create_directory( + kwargs.get("dir_path", "") + ) + elif tool_name == "move_file": + return await self.server.move_file( + kwargs.get("source_path", ""), + kwargs.get("dest_path", "") + ) + elif tool_name == "copy_file": + return await self.server.copy_file( + kwargs.get("source_path", ""), + kwargs.get("dest_path", "") + ) + else: + return {"error": f"Unknown tool: {tool_name}"} + + except Exception as e: + get_logger().error(f"Error calling MCP tool {tool_name}: {e}") + return {"error": str(e)} + + # ========================================================================= + # TOOL SCHEMA GENERATION + # ========================================================================= + + def get_tools_schema(self) -> List[Dict[str, Any]]: + """ + Get MCP tools as OpenAI function calling schema. + + Returns the appropriate schema based on current mode. + + Returns: + List of tool definitions for API request + """ + if not self.enabled: + return [] + + if self.mode == "files": + return self._get_file_tools_schema() + elif self.mode == "database": + return self._get_database_tools_schema() + + return [] + + def _get_file_tools_schema(self) -> List[Dict[str, Any]]: + """Get file mode tools schema.""" + if len(self.allowed_folders) == 0: + return [] + + allowed_dirs_str = ", ".join(str(f) for f in self.allowed_folders) + + # Base read-only tools + tools = [ + { + "type": "function", + "function": { + "name": "search_files", + "description": ( + f"Search for files in the user's local filesystem. " + f"Allowed directories: {allowed_dirs_str}. " + "Can search by filename pattern (e.g., '*.py' for Python files) " + "or search within file contents. Automatically respects .gitignore " + "patterns and excludes virtual environments." + ), + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": ( + "Search pattern. For filename search, use glob patterns " + "like '*.py', '*.txt', 'report*'. For content search, " + "use plain text to search for." + ) + }, + "content_search": { + "type": "boolean", + "description": ( + "If true, searches inside file contents (SLOW). " + "If false, searches only filenames (FAST). Default is false." + ), + "default": False + }, + "search_path": { + "type": "string", + "description": ( + "Optional: specific directory to search in. " + "If not provided, searches all allowed directories." + ), + } + }, + "required": ["pattern"] + } + } + }, + { + "type": "function", + "function": { + "name": "read_file", + "description": ( + "Read the complete contents of a text file. " + "Only works for files within allowed directories. " + "Maximum file size: 10MB. Files larger than 50KB are automatically truncated. " + "Respects .gitignore patterns." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": ( + "Full path to the file to read " + "(e.g., /Users/username/Documents/report.txt)" + ) + } + }, + "required": ["file_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "list_directory", + "description": ( + "List all files and subdirectories in a directory. " + "Automatically filters out virtual environments, build artifacts, " + "and .gitignore patterns. Limited to 1000 items." + ), + "parameters": { + "type": "object", + "properties": { + "dir_path": { + "type": "string", + "description": ( + "Directory path to list " + "(e.g., ~/Documents or /Users/username/Projects)" + ) + }, + "recursive": { + "type": "boolean", + "description": ( + "If true, lists files in subdirectories recursively. " + "Default is true. WARNING: Can be very large." + ), + "default": True + } + }, + "required": ["dir_path"] + } + } + } + ] + + # Add write tools if write mode is enabled + if self.write_enabled: + tools.extend(self._get_write_tools_schema(allowed_dirs_str)) + + return tools + + def _get_write_tools_schema(self, allowed_dirs_str: str) -> List[Dict[str, Any]]: + """Get write mode tools schema.""" + return [ + { + "type": "function", + "function": { + "name": "write_file", + "description": ( + f"Create or overwrite a file with content. " + f"Allowed directories: {allowed_dirs_str}. " + "Automatically creates parent directories if needed." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to create or overwrite" + }, + "content": { + "type": "string", + "description": "The complete content to write to the file" + } + }, + "required": ["file_path", "content"] + } + } + }, + { + "type": "function", + "function": { + "name": "edit_file", + "description": ( + "Make targeted edits by finding and replacing specific text. " + "The old_text must match exactly and appear only once in the file." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to edit" + }, + "old_text": { + "type": "string", + "description": ( + "The exact text to find and replace. " + "Must match exactly and appear only once." + ) + }, + "new_text": { + "type": "string", + "description": "The new text to replace the old text with" + } + }, + "required": ["file_path", "old_text", "new_text"] + } + } + }, + { + "type": "function", + "function": { + "name": "delete_file", + "description": ( + "Delete a file. ALWAYS requires user confirmation. " + "The user will see the file details and your reason before deciding." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to delete" + }, + "reason": { + "type": "string", + "description": ( + "Clear explanation for why this file should be deleted. " + "The user will see this reason." + ) + } + }, + "required": ["file_path", "reason"] + } + } + }, + { + "type": "function", + "function": { + "name": "create_directory", + "description": ( + "Create a new directory (and parent directories if needed). " + "Returns success if directory already exists." + ), + "parameters": { + "type": "object", + "properties": { + "dir_path": { + "type": "string", + "description": "Full path to the directory to create" + } + }, + "required": ["dir_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "move_file", + "description": "Move or rename a file within allowed directories.", + "parameters": { + "type": "object", + "properties": { + "source_path": { + "type": "string", + "description": "Full path to the file to move/rename" + }, + "dest_path": { + "type": "string", + "description": "Full path for the new location/name" + } + }, + "required": ["source_path", "dest_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "copy_file", + "description": "Copy a file to a new location within allowed directories.", + "parameters": { + "type": "object", + "properties": { + "source_path": { + "type": "string", + "description": "Full path to the file to copy" + }, + "dest_path": { + "type": "string", + "description": "Full path for the copy destination" + } + }, + "required": ["source_path", "dest_path"] + } + } + } + ] + + def _get_database_tools_schema(self) -> List[Dict[str, Any]]: + """Get database mode tools schema.""" + if self.selected_db_index is None or self.selected_db_index >= len(self.databases): + return [] + + db = self.databases[self.selected_db_index] + db_name = db["name"] + tables_str = ", ".join(db["tables"]) + + return [ + { + "type": "function", + "function": { + "name": "inspect_database", + "description": ( + f"Inspect the schema of the currently selected database ({db_name}). " + f"Can get all tables or details for a specific table. " + f"Available tables: {tables_str}." + ), + "parameters": { + "type": "object", + "properties": { + "table_name": { + "type": "string", + "description": ( + f"Optional: specific table to inspect. " + f"If not provided, returns info for all tables. " + f"Available: {tables_str}" + ) + } + }, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "search_database", + "description": ( + f"Search for a value across tables in the database ({db_name}). " + f"Performs partial matching across columns. " + f"Limited to {self.server.default_query_limit} results." + ), + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": "string", + "description": "Value to search for (partial match supported)" + }, + "table_name": { + "type": "string", + "description": f"Optional: limit search to specific table. Available: {tables_str}" + }, + "column_name": { + "type": "string", + "description": "Optional: limit search to specific column" + } + }, + "required": ["search_term"] + } + } + }, + { + "type": "function", + "function": { + "name": "query_database", + "description": ( + f"Execute a read-only SQL query on the database ({db_name}). " + f"Supports SELECT queries including JOINs, subqueries, CTEs. " + f"Maximum {self.server.max_query_results} rows. " + f"Timeout: {self.server.max_query_timeout} seconds. " + "INSERT/UPDATE/DELETE/DROP are blocked." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + f"SQL SELECT query to execute. " + f"Available tables: {tables_str}." + ) + }, + "limit": { + "type": "integer", + "description": ( + f"Optional: max rows to return " + f"(default {self.server.default_query_limit}, " + f"max {self.server.max_query_results})" + ) + } + }, + "required": ["query"] + } + } + } + ] + + +# Global MCP manager instance +_mcp_manager: Optional[MCPManager] = None + + +def get_mcp_manager() -> MCPManager: + """ + Get the global MCP manager instance. + + Returns: + The shared MCPManager instance + """ + global _mcp_manager + if _mcp_manager is None: + _mcp_manager = MCPManager() + return _mcp_manager diff --git a/oai/mcp/platform.py b/oai/mcp/platform.py new file mode 100644 index 0000000..dc941bf --- /dev/null +++ b/oai/mcp/platform.py @@ -0,0 +1,228 @@ +""" +Cross-platform MCP configuration for oAI. + +This module handles OS-specific configuration, path handling, +and security checks for the MCP filesystem server. +""" + +import os +import platform +import subprocess +from pathlib import Path +from typing import List, Dict, Any, Optional + +from oai.constants import SYSTEM_DIRS_BLACKLIST +from oai.utils.logging import get_logger + + +class CrossPlatformMCPConfig: + """ + Handle OS-specific MCP configuration. + + Provides methods for path normalization, security validation, + and OS-specific default directories. + + Attributes: + system: Operating system name + is_macos: Whether running on macOS + is_linux: Whether running on Linux + is_windows: Whether running on Windows + """ + + def __init__(self): + """Initialize platform detection.""" + self.system = platform.system() + self.is_macos = self.system == "Darwin" + self.is_linux = self.system == "Linux" + self.is_windows = self.system == "Windows" + + logger = get_logger() + logger.info(f"Detected OS: {self.system}") + + def get_default_allowed_dirs(self) -> List[Path]: + """ + Get safe default directories for the current OS. + + Returns: + List of default directories that are safe to access + """ + home = Path.home() + + if self.is_macos: + return [ + home / "Documents", + home / "Desktop", + home / "Downloads", + ] + + elif self.is_linux: + dirs = [home / "Documents"] + + # Try to get XDG directories + try: + for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]: + result = subprocess.run( + ["xdg-user-dir", xdg_dir], + capture_output=True, + text=True, + timeout=1 + ) + if result.returncode == 0: + dir_path = Path(result.stdout.strip()) + if dir_path.exists(): + dirs.append(dir_path) + except (subprocess.TimeoutExpired, FileNotFoundError): + # Fallback to standard locations + dirs.extend([ + home / "Desktop", + home / "Downloads", + ]) + + return list(set(dirs)) + + elif self.is_windows: + return [ + home / "Documents", + home / "Desktop", + home / "Downloads", + ] + + # Fallback for unknown OS + return [home] + + def get_python_command(self) -> str: + """ + Get the Python executable path. + + Returns: + Path to the Python executable + """ + import sys + return sys.executable + + def get_filesystem_warning(self) -> str: + """ + Get OS-specific security warning message. + + Returns: + Warning message for the current OS + """ + if self.is_macos: + return """ +Note: macOS Security +The Filesystem MCP server needs access to your selected folder. +You may see a security prompt - click 'Allow' to proceed. +(System Settings > Privacy & Security > Files and Folders) +""" + elif self.is_linux: + return """ +Note: Linux Security +The Filesystem MCP server will access your selected folder. +Ensure oAI has appropriate file permissions. +""" + elif self.is_windows: + return """ +Note: Windows Security +The Filesystem MCP server will access your selected folder. +You may need to grant file access permissions. +""" + return "" + + def normalize_path(self, path: str) -> Path: + """ + Normalize a path for the current OS. + + Expands user directory (~) and resolves to absolute path. + + Args: + path: Path string to normalize + + Returns: + Normalized absolute Path + """ + return Path(os.path.expanduser(path)).resolve() + + def is_system_directory(self, path: Path) -> bool: + """ + Check if a path is a protected system directory. + + Args: + path: Path to check + + Returns: + True if the path is a system directory + """ + path_str = str(path) + for blocked in SYSTEM_DIRS_BLACKLIST: + if path_str.startswith(blocked): + return True + return False + + def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool: + """ + Check if a path is within allowed directories. + + Args: + requested_path: Path being requested + allowed_dirs: List of allowed parent directories + + Returns: + True if the path is within an allowed directory + """ + try: + requested = requested_path.resolve() + + for allowed in allowed_dirs: + try: + allowed_resolved = allowed.resolve() + requested.relative_to(allowed_resolved) + return True + except ValueError: + continue + + return False + except Exception: + return False + + def get_folder_stats(self, folder: Path) -> Dict[str, Any]: + """ + Get statistics for a folder. + + Args: + folder: Path to the folder + + Returns: + Dictionary with folder statistics: + - exists: Whether the folder exists + - file_count: Number of files (if exists) + - total_size: Total size in bytes (if exists) + - size_mb: Size in megabytes (if exists) + - error: Error message (if any) + """ + logger = get_logger() + + try: + if not folder.exists() or not folder.is_dir(): + return {"exists": False} + + file_count = 0 + total_size = 0 + + for item in folder.rglob("*"): + if item.is_file(): + file_count += 1 + try: + total_size += item.stat().st_size + except (OSError, PermissionError): + pass + + return { + "exists": True, + "file_count": file_count, + "total_size": total_size, + "size_mb": total_size / (1024 * 1024), + } + + except Exception as e: + logger.error(f"Error getting folder stats for {folder}: {e}") + return {"exists": False, "error": str(e)} diff --git a/oai/mcp/server.py b/oai/mcp/server.py new file mode 100644 index 0000000..105a0de --- /dev/null +++ b/oai/mcp/server.py @@ -0,0 +1,1368 @@ +""" +MCP Filesystem Server implementation for oAI. + +This module provides the actual filesystem and database access operations +for the MCP integration. It handles file reading, directory listing, +searching, and SQLite database operations. +""" + +import datetime +import signal +import sqlite3 +from pathlib import Path +from typing import Optional, List, Dict, Any, Set + +from oai.constants import ( + MAX_FILE_SIZE, + CONTENT_TRUNCATION_THRESHOLD, + MAX_LIST_ITEMS, + MAX_QUERY_TIMEOUT, + MAX_QUERY_RESULTS, + DEFAULT_QUERY_LIMIT, + SKIP_DIRECTORIES, +) +from oai.config.database import get_database +from oai.mcp.platform import CrossPlatformMCPConfig +from oai.mcp.gitignore import GitignoreParser +from oai.mcp.validators import SQLiteQueryValidator +from oai.utils.logging import get_logger + + +class MCPFilesystemServer: + """ + MCP Filesystem Server with file access and SQLite database querying. + + Provides async methods for: + - File operations: read, write, edit, delete, copy, move + - Directory operations: list, create + - Search operations: filename and content search + - Database operations: inspect, search, query + + All operations respect allowed folder restrictions and gitignore patterns. + + Attributes: + allowed_folders: List of folders the server can access + config: Platform-specific configuration + max_file_size: Maximum file size for read operations + max_list_items: Maximum items to return in directory listings + respect_gitignore: Whether to apply gitignore filtering + gitignore_parser: Parser for gitignore patterns + max_query_timeout: Maximum seconds for database queries + max_query_results: Maximum rows to return from queries + default_query_limit: Default row limit for queries + """ + + def __init__(self, allowed_folders: List[Path]): + """ + Initialize the filesystem server. + + Args: + allowed_folders: List of folder paths the server can access + """ + self.allowed_folders = allowed_folders + self.config = CrossPlatformMCPConfig() + self.max_file_size = MAX_FILE_SIZE + self.max_list_items = MAX_LIST_ITEMS + self.respect_gitignore = True + + # Initialize gitignore parser and load patterns + self.gitignore_parser = GitignoreParser() + self._load_gitignores() + + # SQLite configuration + self.max_query_timeout = MAX_QUERY_TIMEOUT + self.max_query_results = MAX_QUERY_RESULTS + self.default_query_limit = DEFAULT_QUERY_LIMIT + + logger = get_logger() + logger.info( + f"MCP Filesystem Server initialized with {len(allowed_folders)} folders" + ) + + def _load_gitignores(self) -> None: + """Load all .gitignore files from allowed folders.""" + logger = get_logger() + gitignore_count = 0 + + for folder in self.allowed_folders: + if not folder.exists(): + continue + + # Load root .gitignore + root_gitignore = folder / ".gitignore" + if root_gitignore.exists(): + self.gitignore_parser.add_gitignore(root_gitignore) + gitignore_count += 1 + logger.info(f"Loaded .gitignore from {folder}") + + # Load nested .gitignore files + try: + for gitignore_path in folder.rglob(".gitignore"): + if gitignore_path != root_gitignore: + self.gitignore_parser.add_gitignore(gitignore_path) + gitignore_count += 1 + logger.debug(f"Loaded nested .gitignore: {gitignore_path}") + except Exception as e: + logger.warning(f"Error loading nested .gitignores from {folder}: {e}") + + if gitignore_count > 0: + logger.info( + f"Loaded {gitignore_count} .gitignore file(s) with " + f"{self.gitignore_parser.pattern_count} total patterns" + ) + + def reload_gitignores(self) -> None: + """Reload all .gitignore files (call when allowed_folders changes).""" + self.gitignore_parser.clear() + self._load_gitignores() + get_logger().info("Reloaded .gitignore patterns") + + def is_allowed_path(self, path: Path) -> bool: + """ + Check if a path is within allowed folders. + + Args: + path: Path to check + + Returns: + True if the path is allowed + """ + return self.config.is_safe_path(path, self.allowed_folders) + + def _should_skip_path(self, item_path: Path) -> bool: + """ + Check if a path should be skipped during traversal. + + Args: + item_path: Path to check + + Returns: + True if the path should be skipped + """ + # Check hardcoded skip directories + path_parts = item_path.parts + if any(part in SKIP_DIRECTORIES for part in path_parts): + return True + + # Check gitignore patterns (if enabled) + if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): + return True + + return False + + def _log_stat( + self, + tool_name: str, + folder: Optional[str], + success: bool, + error_message: Optional[str] = None + ) -> None: + """Log an MCP tool usage event.""" + get_database().log_mcp_stat(tool_name, folder, success, error_message) + + # ========================================================================= + # FILE READ OPERATIONS + # ========================================================================= + + async def read_file(self, file_path: str) -> Dict[str, Any]: + """ + Read file contents. + + Args: + file_path: Path to the file to read + + Returns: + Dictionary containing: + - content: File content + - path: Resolved file path + - size: File size in bytes + - truncated: Whether content was truncated + - error: Error message if failed + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Check if file should be ignored + if self.respect_gitignore and self.gitignore_parser.should_ignore(path): + error_msg = f"File ignored by .gitignore: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + file_size = path.stat().st_size + if file_size > self.max_file_size: + error_msg = ( + f"File too large: {file_size / (1024*1024):.1f}MB " + f"(max: {self.max_file_size / (1024*1024):.0f}MB)" + ) + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Try to read as text + try: + content = path.read_text(encoding="utf-8") + + # Truncate large files + if file_size > CONTENT_TRUNCATION_THRESHOLD: + lines = content.split("\n") + total_lines = len(lines) + + head_lines = 500 + tail_lines = 100 + + if total_lines > (head_lines + tail_lines): + truncated_content = ( + "\n".join(lines[:head_lines]) + + f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} " + f"lines omitted] ...\n\n" + + "\n".join(lines[-tail_lines:]) + ) + + logger.info( + f"Read file (truncated): {path} " + f"({file_size} bytes, {total_lines} lines)" + ) + self._log_stat("read_file", str(path.parent), True) + + return { + "content": truncated_content, + "path": str(path), + "size": file_size, + "truncated": True, + "total_lines": total_lines, + "lines_shown": head_lines + tail_lines, + "note": ( + f"File truncated: showing first {head_lines} " + f"and last {tail_lines} lines of {total_lines} total" + ) + } + + logger.info(f"Read file: {path} ({file_size} bytes)") + self._log_stat("read_file", str(path.parent), True) + + return { + "content": content, + "path": str(path), + "size": file_size + } + + except UnicodeDecodeError: + error_msg = f"Cannot decode file as UTF-8: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + except Exception as e: + error_msg = f"Error reading file: {e}" + logger.error(error_msg) + self._log_stat("read_file", file_path, False, str(e)) + return {"error": error_msg} + + async def list_directory( + self, + dir_path: str, + recursive: bool = False + ) -> Dict[str, Any]: + """ + List directory contents. + + Args: + dir_path: Path to the directory + recursive: Whether to list recursively + + Returns: + Dictionary containing: + - path: Directory path + - items: List of file/directory info + - count: Number of items + - truncated: Whether results were truncated + """ + logger = get_logger() + + try: + path = self.config.normalize_path(dir_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"Directory not found: {path}" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_dir(): + error_msg = f"Not a directory: {path}" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + items = [] + pattern = "**/*" if recursive else "*" + + for item in path.glob(pattern): + # Stop if we hit the limit + if len(items) >= self.max_list_items: + break + + # Skip excluded directories + if self._should_skip_path(item): + continue + + try: + stat = item.stat() + items.append({ + "name": item.name, + "path": str(item), + "type": "directory" if item.is_dir() else "file", + "size": stat.st_size if item.is_file() else 0, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (OSError, PermissionError): + continue + + truncated = len(items) >= self.max_list_items + + logger.info( + f"Listed directory: {path} ({len(items)} items, " + f"recursive={recursive}, truncated={truncated})" + ) + self._log_stat("list_directory", str(path), True) + + result = { + "path": str(path), + "items": items, + "count": len(items), + "truncated": truncated + } + + if truncated: + result["note"] = ( + f"Results limited to {self.max_list_items} items. " + "Use more specific search path or disable recursive mode." + ) + + return result + + except Exception as e: + error_msg = f"Error listing directory: {e}" + logger.error(error_msg) + self._log_stat("list_directory", dir_path, False, str(e)) + return {"error": error_msg} + + async def search_files( + self, + pattern: str, + search_path: Optional[str] = None, + content_search: bool = False + ) -> Dict[str, Any]: + """ + Search for files by name or content. + + Args: + pattern: Search pattern (glob for filename, text for content) + search_path: Optional path to search within + content_search: Whether to search within file contents + + Returns: + Dictionary containing: + - pattern: The search pattern used + - search_type: 'filename' or 'content' + - matches: List of matching files + - count: Number of matches + """ + logger = get_logger() + + try: + # Determine search roots + if search_path: + path = self.config.normalize_path(search_path) + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("search_files", str(path), False, error_msg) + return {"error": error_msg} + search_roots = [path] + else: + search_roots = self.allowed_folders + + matches = [] + + for root in search_roots: + if not root.exists(): + continue + + # Filename search (fast) + if not content_search: + for item in root.rglob(pattern): + if item.is_file() and not self._should_skip_path(item): + try: + stat = item.stat() + matches.append({ + "path": str(item), + "name": item.name, + "size": stat.st_size, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (OSError, PermissionError): + continue + else: + # Content search (slower) + for item in root.rglob("*"): + if item.is_file() and not self._should_skip_path(item): + try: + # Skip large files + if item.stat().st_size > self.max_file_size: + continue + + # Try to read and search content + try: + content = item.read_text(encoding="utf-8") + if pattern.lower() in content.lower(): + stat = item.stat() + matches.append({ + "path": str(item), + "name": item.name, + "size": stat.st_size, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (UnicodeDecodeError, PermissionError): + continue + except (OSError, PermissionError): + continue + + search_type = "content" if content_search else "filename" + logger.info( + f"Searched files: pattern='{pattern}', type={search_type}, " + f"found={len(matches)}" + ) + self._log_stat( + "search_files", + str(search_roots[0]) if search_roots else None, + True + ) + + return { + "pattern": pattern, + "search_type": search_type, + "matches": matches, + "count": len(matches) + } + + except Exception as e: + error_msg = f"Error searching files: {e}" + logger.error(error_msg) + self._log_stat("search_files", search_path or "all", False, str(e)) + return {"error": error_msg} + + # ========================================================================= + # FILE WRITE OPERATIONS + # ========================================================================= + + async def write_file(self, file_path: str, content: str) -> Dict[str, Any]: + """ + Create or overwrite a file. + + Note: Ignores .gitignore - allows writing to any file in allowed folders. + + Args: + file_path: Path to the file + content: Content to write + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: File path + - size: File size after writing + - created: Whether a new file was created + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("write_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Create parent directory if needed + parent_dir = path.parent + if not parent_dir.exists(): + parent_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Created parent directory: {parent_dir}") + + is_new_file = not path.exists() + + # Write content + path.write_text(content, encoding="utf-8") + file_size = path.stat().st_size + + logger.info( + f"{'Created' if is_new_file else 'Updated'} file: " + f"{path} ({file_size} bytes)" + ) + self._log_stat("write_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "size": file_size, + "created": is_new_file + } + + except PermissionError as e: + error_msg = f"Permission denied writing to {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except UnicodeEncodeError as e: + error_msg = f"Encoding error writing to {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "EncodingError"} + except Exception as e: + error_msg = f"Error writing file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg} + + async def edit_file( + self, + file_path: str, + old_text: str, + new_text: str + ) -> Dict[str, Any]: + """ + Find and replace text in a file. + + Note: Ignores .gitignore - allows editing any file in allowed folders. + + Args: + file_path: Path to the file + old_text: Text to find (must be unique in file) + new_text: Text to replace with + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: File path + - changes: Number of replacements made + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("edit_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Read current content + current_content = path.read_text(encoding="utf-8") + + # Check if old_text exists + if old_text not in current_content: + error_msg = f"Text not found in file: '{old_text[:50]}...'" + logger.warning(f"Edit failed - text not found in {path}") + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Check for unique match + occurrence_count = current_content.count(old_text) + if occurrence_count > 1: + error_msg = ( + f"Text appears {occurrence_count} times in file. " + "Please provide more context to make the match unique." + ) + logger.warning(f"Edit failed - ambiguous match in {path}") + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Perform replacement + new_content = current_content.replace(old_text, new_text, 1) + path.write_text(new_content, encoding="utf-8") + + logger.info(f"Edited file: {path} (1 replacement)") + self._log_stat("edit_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "changes": 1 + } + + except PermissionError as e: + error_msg = f"Permission denied editing {file_path}: {e}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except UnicodeDecodeError as e: + error_msg = f"Cannot read file (encoding issue): {file_path}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "EncodingError"} + except Exception as e: + error_msg = f"Error editing file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg} + + async def delete_file( + self, + file_path: str, + reason: str, + confirm_callback=None + ) -> Dict[str, Any]: + """ + Delete a file with user confirmation. + + Note: Ignores .gitignore - allows deleting any file in allowed folders. + + Args: + file_path: Path to the file to delete + reason: Reason for deletion (shown to user) + confirm_callback: Callback function for user confirmation + + Returns: + Dictionary containing: + - success: Whether the file was deleted + - path: File path + - user_cancelled: Whether user cancelled the operation + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("delete_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("delete_file", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("delete_file", str(path), False, error_msg) + return {"error": error_msg} + + # Get file info for confirmation + file_size = path.stat().st_size + file_mtime = datetime.datetime.fromtimestamp(path.stat().st_mtime) + + # User confirmation required (handled by callback) + if confirm_callback: + confirmed = confirm_callback(path, file_size, file_mtime, reason) + if not confirmed: + logger.info(f"User cancelled file deletion: {path}") + self._log_stat("delete_file", str(path), False, "User cancelled") + return { + "success": False, + "user_cancelled": True, + "path": str(path) + } + + # Delete the file + path.unlink() + + logger.info(f"Deleted file: {path}") + self._log_stat("delete_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "user_cancelled": False + } + + except PermissionError as e: + error_msg = f"Permission denied deleting {file_path}: {e}" + logger.error(error_msg) + self._log_stat("delete_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error deleting file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("delete_file", file_path, False, str(e)) + return {"error": error_msg} + + async def create_directory(self, dir_path: str) -> Dict[str, Any]: + """ + Create a directory. + + Note: Ignores .gitignore - allows creating directories in allowed folders. + + Args: + dir_path: Path to the directory to create + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: Directory path + - created: Whether a new directory was created + """ + logger = get_logger() + + try: + path = self.config.normalize_path(dir_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("create_directory", str(path.parent), False, error_msg) + return {"error": error_msg} + + already_exists = path.exists() + + if already_exists and not path.is_dir(): + error_msg = f"Path exists but is not a directory: {path}" + logger.warning(error_msg) + self._log_stat("create_directory", str(path), False, error_msg) + return {"error": error_msg} + + # Create directory (and parents) + path.mkdir(parents=True, exist_ok=True) + + if already_exists: + logger.info(f"Directory already exists: {path}") + else: + logger.info(f"Created directory: {path}") + + self._log_stat("create_directory", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "created": not already_exists + } + + except PermissionError as e: + error_msg = f"Permission denied creating directory {dir_path}: {e}" + logger.error(error_msg) + self._log_stat("create_directory", dir_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error creating directory {dir_path}: {e}" + logger.error(error_msg) + self._log_stat("create_directory", dir_path, False, str(e)) + return {"error": error_msg} + + async def move_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: + """ + Move or rename a file. + + Note: Ignores .gitignore - allows moving files within allowed folders. + + Args: + source_path: Path to the source file + dest_path: Destination path + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - source: Source path + - destination: Destination path + """ + import shutil + logger = get_logger() + + try: + source = self.config.normalize_path(source_path) + dest = self.config.normalize_path(dest_path) + + # Both paths must be in allowed folders + if not self.is_allowed_path(source): + error_msg = f"Access denied: source {source} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("move_file", str(source.parent), False, error_msg) + return {"error": error_msg} + + if not self.is_allowed_path(dest): + error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("move_file", str(dest.parent), False, error_msg) + return {"error": error_msg} + + if not source.exists(): + error_msg = f"Source file not found: {source}" + logger.warning(error_msg) + self._log_stat("move_file", str(source), False, error_msg) + return {"error": error_msg} + + if not source.is_file(): + error_msg = f"Source is not a file: {source}" + logger.warning(error_msg) + self._log_stat("move_file", str(source), False, error_msg) + return {"error": error_msg} + + # Create destination parent directory if needed + dest.parent.mkdir(parents=True, exist_ok=True) + + # Move/rename the file + shutil.move(str(source), str(dest)) + + logger.info(f"Moved file: {source} -> {dest}") + self._log_stat("move_file", str(source.parent), True) + + return { + "success": True, + "source": str(source), + "destination": str(dest) + } + + except PermissionError as e: + error_msg = f"Permission denied moving file: {e}" + logger.error(error_msg) + self._log_stat("move_file", source_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error moving file from {source_path} to {dest_path}: {e}" + logger.error(error_msg) + self._log_stat("move_file", source_path, False, str(e)) + return {"error": error_msg} + + async def copy_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: + """ + Copy a file. + + Note: Ignores .gitignore - allows copying files within allowed folders. + + Args: + source_path: Path to the source file + dest_path: Destination path + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - source: Source path + - destination: Destination path + - size: Size of copied file + """ + import shutil + logger = get_logger() + + try: + source = self.config.normalize_path(source_path) + dest = self.config.normalize_path(dest_path) + + # Both paths must be in allowed folders + if not self.is_allowed_path(source): + error_msg = f"Access denied: source {source} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("copy_file", str(source.parent), False, error_msg) + return {"error": error_msg} + + if not self.is_allowed_path(dest): + error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("copy_file", str(dest.parent), False, error_msg) + return {"error": error_msg} + + if not source.exists(): + error_msg = f"Source file not found: {source}" + logger.warning(error_msg) + self._log_stat("copy_file", str(source), False, error_msg) + return {"error": error_msg} + + if not source.is_file(): + error_msg = f"Source is not a file: {source}" + logger.warning(error_msg) + self._log_stat("copy_file", str(source), False, error_msg) + return {"error": error_msg} + + # Create destination parent directory if needed + dest.parent.mkdir(parents=True, exist_ok=True) + + # Copy the file (preserves metadata) + shutil.copy2(str(source), str(dest)) + + file_size = dest.stat().st_size + + logger.info(f"Copied file: {source} -> {dest} ({file_size} bytes)") + self._log_stat("copy_file", str(source.parent), True) + + return { + "success": True, + "source": str(source), + "destination": str(dest), + "size": file_size + } + + except PermissionError as e: + error_msg = f"Permission denied copying file: {e}" + logger.error(error_msg) + self._log_stat("copy_file", source_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error copying file from {source_path} to {dest_path}: {e}" + logger.error(error_msg) + self._log_stat("copy_file", source_path, False, str(e)) + return {"error": error_msg} + + # ========================================================================= + # DATABASE OPERATIONS + # ========================================================================= + + async def inspect_database( + self, + db_path: str, + table_name: Optional[str] = None + ) -> Dict[str, Any]: + """ + Inspect SQLite database schema. + + Args: + db_path: Path to the database file + table_name: Optional specific table to inspect + + Returns: + Dictionary containing database/table information + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + if table_name: + # Get info for specific table + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + result = cursor.fetchone() + + if not result: + return {"error": f"Table '{table_name}' not found"} + + create_sql = result[0] + + # Get column info + cursor.execute(f"PRAGMA table_info({table_name})") + columns = [] + for row in cursor.fetchall(): + columns.append({ + "id": row[0], + "name": row[1], + "type": row[2], + "not_null": bool(row[3]), + "default_value": row[4], + "primary_key": bool(row[5]) + }) + + # Get indexes + cursor.execute(f"PRAGMA index_list({table_name})") + indexes = [] + for row in cursor.fetchall(): + indexes.append({ + "name": row[1], + "unique": bool(row[2]) + }) + + # Get row count + cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + row_count = cursor.fetchone()[0] + + logger.info(f"Inspected table: {table_name} in {path}") + self._log_stat("inspect_database", str(path.parent), True) + + return { + "database": str(path), + "table": table_name, + "create_sql": create_sql, + "columns": columns, + "indexes": indexes, + "row_count": row_count + } + else: + # Get all tables + cursor.execute( + "SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [] + for row in cursor.fetchall(): + table = row[0] + cursor.execute(f"SELECT COUNT(*) FROM {table}") + count = cursor.fetchone()[0] + tables.append({ + "name": table, + "row_count": count + }) + + # Get indexes + cursor.execute( + "SELECT name, tbl_name FROM sqlite_master WHERE type='index' ORDER BY name" + ) + indexes = [{"name": row[0], "table": row[1]} for row in cursor.fetchall()] + + # Get triggers + cursor.execute( + "SELECT name, tbl_name FROM sqlite_master WHERE type='trigger' ORDER BY name" + ) + triggers = [{"name": row[0], "table": row[1]} for row in cursor.fetchall()] + + # Get database size + db_size = path.stat().st_size + + logger.info(f"Inspected database: {path} ({len(tables)} tables)") + self._log_stat("inspect_database", str(path.parent), True) + + return { + "database": str(path), + "size": db_size, + "size_mb": db_size / (1024 * 1024), + "tables": tables, + "table_count": len(tables), + "indexes": indexes, + "triggers": triggers + } + finally: + conn.close() + + except Exception as e: + error_msg = f"Error inspecting database: {e}" + logger.error(error_msg) + self._log_stat("inspect_database", db_path, False, str(e)) + return {"error": error_msg} + + async def search_database( + self, + db_path: str, + search_term: str, + table_name: Optional[str] = None, + column_name: Optional[str] = None + ) -> Dict[str, Any]: + """ + Search for a value across database tables. + + Args: + db_path: Path to the database file + search_term: Value to search for (partial match) + table_name: Optional table to limit search + column_name: Optional column to limit search + + Returns: + Dictionary containing search results + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("search_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("search_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + matches = [] + + # Get tables to search + if table_name: + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + tables = [row[0] for row in cursor.fetchall()] + if not tables: + return {"error": f"Table '{table_name}' not found"} + else: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + + for table in tables: + # Get columns for this table + cursor.execute(f"PRAGMA table_info({table})") + columns = [row[1] for row in cursor.fetchall()] + + # Filter columns if specified + if column_name: + if column_name not in columns: + continue + columns = [column_name] + + # Search in each column + for column in columns: + try: + query = f"SELECT * FROM {table} WHERE {column} LIKE ? LIMIT {self.default_query_limit}" + cursor.execute(query, (f"%{search_term}%",)) + results = cursor.fetchall() + + if results: + col_names = [desc[0] for desc in cursor.description] + + for row in results: + row_dict = {} + for col_name, value in zip(col_names, row): + if isinstance(value, bytes): + try: + row_dict[col_name] = value.decode("utf-8") + except UnicodeDecodeError: + row_dict[col_name] = f"" + else: + row_dict[col_name] = value + + matches.append({ + "table": table, + "column": column, + "row": row_dict + }) + except sqlite3.Error: + # Skip columns that can't be searched + continue + + logger.info( + f"Searched database: {path} for '{search_term}' - " + f"found {len(matches)} matches" + ) + self._log_stat("search_database", str(path.parent), True) + + result = { + "database": str(path), + "search_term": search_term, + "matches": matches, + "count": len(matches) + } + + if table_name: + result["table_filter"] = table_name + if column_name: + result["column_filter"] = column_name + + if len(matches) >= self.default_query_limit: + result["note"] = ( + f"Results limited to {self.default_query_limit} matches. " + "Use more specific search or query_database for full results." + ) + + return result + + finally: + conn.close() + + except Exception as e: + error_msg = f"Error searching database: {e}" + logger.error(error_msg) + self._log_stat("search_database", db_path, False, str(e)) + return {"error": error_msg} + + async def query_database( + self, + db_path: str, + query: str, + limit: Optional[int] = None + ) -> Dict[str, Any]: + """ + Execute a read-only SQL query. + + Args: + db_path: Path to the database file + query: SQL query to execute (SELECT only) + limit: Optional row limit + + Returns: + Dictionary containing query results + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Validate query + is_safe, error_msg = SQLiteQueryValidator.is_safe_query(query) + if not is_safe: + logger.warning(f"Unsafe query rejected: {error_msg}") + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + # Set query timeout (Unix only) + def timeout_handler(signum, frame): + raise TimeoutError( + f"Query exceeded {self.max_query_timeout} second timeout" + ) + + if hasattr(signal, "SIGALRM"): + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(self.max_query_timeout) + + try: + # Execute query + cursor.execute(query) + + # Get results + result_limit = min( + limit or self.default_query_limit, + self.max_query_results + ) + results = cursor.fetchmany(result_limit) + + # Get column names + columns = [ + desc[0] for desc in cursor.description + ] if cursor.description else [] + + # Convert to list of dicts + rows = [] + for row in results: + row_dict = {} + for col_name, value in zip(columns, row): + if isinstance(value, bytes): + try: + row_dict[col_name] = value.decode("utf-8") + except UnicodeDecodeError: + row_dict[col_name] = f"" + else: + row_dict[col_name] = value + rows.append(row_dict) + + # Check if more results available + has_more = len(results) == result_limit + if has_more: + one_more = cursor.fetchone() + has_more = one_more is not None + + finally: + if hasattr(signal, "SIGALRM"): + signal.alarm(0) + + logger.info(f"Executed query on {path}: returned {len(rows)} rows") + self._log_stat("query_database", str(path.parent), True) + + result = { + "database": str(path), + "query": query, + "columns": columns, + "rows": rows, + "count": len(rows) + } + + if has_more: + result["truncated"] = True + result["note"] = ( + f"Results limited to {result_limit} rows. " + "Use LIMIT clause in query for more control." + ) + + return result + + except TimeoutError as e: + error_msg = str(e) + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + except sqlite3.Error as e: + error_msg = f"SQL error: {e}" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + finally: + if hasattr(signal, "SIGALRM"): + signal.alarm(0) + conn.close() + + except Exception as e: + error_msg = f"Error executing query: {e}" + logger.error(error_msg) + self._log_stat("query_database", db_path, False, str(e)) + return {"error": error_msg} diff --git a/oai/mcp/validators.py b/oai/mcp/validators.py new file mode 100644 index 0000000..387a5fd --- /dev/null +++ b/oai/mcp/validators.py @@ -0,0 +1,123 @@ +""" +Query validation for oAI MCP database operations. + +This module provides safety validation for SQL queries to ensure +only read-only operations are executed. +""" + +import re +from typing import Tuple + +from oai.constants import DANGEROUS_SQL_KEYWORDS + + +class SQLiteQueryValidator: + """ + Validate SQLite queries for read-only safety. + + Ensures that only SELECT queries (including CTEs) are allowed + and blocks potentially dangerous operations like INSERT, UPDATE, + DELETE, DROP, etc. + """ + + @staticmethod + def is_safe_query(query: str) -> Tuple[bool, str]: + """ + Validate that a query is a safe read-only SELECT. + + The validation: + 1. Checks that query starts with SELECT or WITH + 2. Strips string literals before checking for dangerous keywords + 3. Blocks any dangerous keywords outside of string literals + + Args: + query: SQL query string to validate + + Returns: + Tuple of (is_safe, error_message) + - is_safe: True if the query is safe to execute + - error_message: Description of why query is unsafe (empty if safe) + + Examples: + >>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users") + (True, "") + >>> SQLiteQueryValidator.is_safe_query("DELETE FROM users") + (False, "Only SELECT queries are allowed...") + >>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users") + (True, "") # 'DELETE' is inside a string literal + """ + query_upper = query.strip().upper() + + # Must start with SELECT or WITH (for CTEs) + if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")): + return False, "Only SELECT queries are allowed (including WITH/CTE)" + + # Remove string literals before checking for dangerous keywords + # This prevents false positives when keywords appear in data + query_no_strings = re.sub(r"'[^']*'", "", query_upper) + query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings) + + # Check for dangerous keywords outside of quotes + for keyword in DANGEROUS_SQL_KEYWORDS: + if re.search(r"\b" + keyword + r"\b", query_no_strings): + return False, f"Keyword '{keyword}' not allowed in read-only mode" + + return True, "" + + @staticmethod + def sanitize_table_name(table_name: str) -> str: + """ + Sanitize a table name to prevent SQL injection. + + Only allows alphanumeric characters and underscores. + + Args: + table_name: Table name to sanitize + + Returns: + Sanitized table name + + Raises: + ValueError: If table name contains invalid characters + """ + # Remove any characters that aren't alphanumeric or underscore + sanitized = re.sub(r"[^\w]", "", table_name) + + if not sanitized: + raise ValueError("Table name cannot be empty after sanitization") + + if sanitized != table_name: + raise ValueError( + f"Table name contains invalid characters: {table_name}" + ) + + return sanitized + + @staticmethod + def sanitize_column_name(column_name: str) -> str: + """ + Sanitize a column name to prevent SQL injection. + + Only allows alphanumeric characters and underscores. + + Args: + column_name: Column name to sanitize + + Returns: + Sanitized column name + + Raises: + ValueError: If column name contains invalid characters + """ + # Remove any characters that aren't alphanumeric or underscore + sanitized = re.sub(r"[^\w]", "", column_name) + + if not sanitized: + raise ValueError("Column name cannot be empty after sanitization") + + if sanitized != column_name: + raise ValueError( + f"Column name contains invalid characters: {column_name}" + ) + + return sanitized diff --git a/oai/providers/__init__.py b/oai/providers/__init__.py new file mode 100644 index 0000000..93df1e5 --- /dev/null +++ b/oai/providers/__init__.py @@ -0,0 +1,32 @@ +""" +Provider abstraction for oAI. + +This module provides a unified interface for AI model providers, +enabling easy extension to support additional providers beyond OpenRouter. +""" + +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ToolCall, + ToolFunction, + UsageStats, + ModelInfo, + ProviderCapabilities, +) +from oai.providers.openrouter import OpenRouterProvider + +__all__ = [ + # Base classes and types + "AIProvider", + "ChatMessage", + "ChatResponse", + "ToolCall", + "ToolFunction", + "UsageStats", + "ModelInfo", + "ProviderCapabilities", + # Provider implementations + "OpenRouterProvider", +] diff --git a/oai/providers/base.py b/oai/providers/base.py new file mode 100644 index 0000000..865fe7d --- /dev/null +++ b/oai/providers/base.py @@ -0,0 +1,413 @@ +""" +Abstract base provider for AI model integration. + +This module defines the interface that all AI providers must implement, +along with common data structures for requests and responses. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Union, +) + + +class MessageRole(str, Enum): + """Message roles in a conversation.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +@dataclass +class ToolFunction: + """ + Represents a function within a tool call. + + Attributes: + name: The function name + arguments: JSON string of function arguments + """ + + name: str + arguments: str + + +@dataclass +class ToolCall: + """ + Represents a tool/function call requested by the model. + + Attributes: + id: Unique identifier for this tool call + type: Type of tool call (usually "function") + function: The function being called + """ + + id: str + type: str + function: ToolFunction + + +@dataclass +class UsageStats: + """ + Token usage statistics from an API response. + + Attributes: + prompt_tokens: Number of tokens in the prompt + completion_tokens: Number of tokens in the completion + total_tokens: Total tokens used + total_cost_usd: Cost in USD (if available from API) + """ + + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + total_cost_usd: Optional[float] = None + + @property + def input_tokens(self) -> int: + """Alias for prompt_tokens.""" + return self.prompt_tokens + + @property + def output_tokens(self) -> int: + """Alias for completion_tokens.""" + return self.completion_tokens + + +@dataclass +class ChatMessage: + """ + A single message in a chat conversation. + + Attributes: + role: The role of the message sender + content: Message content (text or structured content blocks) + name: Optional name for the sender + tool_calls: List of tool calls (for assistant messages) + tool_call_id: Tool call ID this message responds to (for tool messages) + """ + + role: str + content: Union[str, List[Dict[str, Any]], None] = None + name: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = None + tool_call_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format for API requests.""" + result: Dict[str, Any] = {"role": self.role} + + if self.content is not None: + result["content"] = self.content + + if self.name: + result["name"] = self.name + + if self.tool_calls: + result["tool_calls"] = [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in self.tool_calls + ] + + if self.tool_call_id: + result["tool_call_id"] = self.tool_call_id + + return result + + +@dataclass +class ChatResponseChoice: + """ + A single choice in a chat response. + + Attributes: + index: Index of this choice + message: The response message + finish_reason: Why the response ended + """ + + index: int + message: ChatMessage + finish_reason: Optional[str] = None + + +@dataclass +class ChatResponse: + """ + Response from a chat completion request. + + Attributes: + id: Unique identifier for this response + choices: List of response choices + usage: Token usage statistics + model: Model that generated this response + created: Unix timestamp of creation + """ + + id: str + choices: List[ChatResponseChoice] + usage: Optional[UsageStats] = None + model: Optional[str] = None + created: Optional[int] = None + + @property + def message(self) -> Optional[ChatMessage]: + """Get the first choice's message.""" + if self.choices: + return self.choices[0].message + return None + + @property + def content(self) -> Optional[str]: + """Get the text content of the first choice.""" + msg = self.message + if msg and isinstance(msg.content, str): + return msg.content + return None + + @property + def tool_calls(self) -> Optional[List[ToolCall]]: + """Get tool calls from the first choice.""" + msg = self.message + if msg: + return msg.tool_calls + return None + + +@dataclass +class StreamChunk: + """ + A single chunk from a streaming response. + + Attributes: + id: Response ID + delta_content: New content in this chunk + finish_reason: Finish reason (if this is the last chunk) + usage: Usage stats (usually in the last chunk) + error: Error message if something went wrong + """ + + id: str + delta_content: Optional[str] = None + finish_reason: Optional[str] = None + usage: Optional[UsageStats] = None + error: Optional[str] = None + + +@dataclass +class ModelInfo: + """ + Information about an AI model. + + Attributes: + id: Unique model identifier + name: Display name + description: Model description + context_length: Maximum context window size + pricing: Pricing info (input/output per million tokens) + supported_parameters: List of supported API parameters + input_modalities: Supported input types (text, image, etc.) + output_modalities: Supported output types + """ + + id: str + name: str + description: str = "" + context_length: int = 0 + pricing: Dict[str, float] = field(default_factory=dict) + supported_parameters: List[str] = field(default_factory=list) + input_modalities: List[str] = field(default_factory=lambda: ["text"]) + output_modalities: List[str] = field(default_factory=lambda: ["text"]) + + def supports_images(self) -> bool: + """Check if model supports image input.""" + return "image" in self.input_modalities + + def supports_tools(self) -> bool: + """Check if model supports function calling/tools.""" + return "tools" in self.supported_parameters or "functions" in self.supported_parameters + + def supports_streaming(self) -> bool: + """Check if model supports streaming responses.""" + return "stream" in self.supported_parameters + + def supports_online(self) -> bool: + """Check if model supports web search (online mode).""" + return self.supports_tools() + + +@dataclass +class ProviderCapabilities: + """ + Capabilities supported by a provider. + + Attributes: + streaming: Provider supports streaming responses + tools: Provider supports function calling + images: Provider supports image inputs + online: Provider supports web search + max_context: Maximum context length across all models + """ + + streaming: bool = True + tools: bool = True + images: bool = True + online: bool = False + max_context: int = 128000 + + +class AIProvider(ABC): + """ + Abstract base class for AI model providers. + + All provider implementations must inherit from this class + and implement the required abstract methods. + """ + + def __init__(self, api_key: str, base_url: Optional[str] = None): + """ + Initialize the provider. + + Args: + api_key: API key for authentication + base_url: Optional custom base URL for the API + """ + self.api_key = api_key + self.base_url = base_url + + @property + @abstractmethod + def name(self) -> str: + """Get the provider name.""" + pass + + @property + @abstractmethod + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + pass + + @abstractmethod + def list_models(self) -> List[ModelInfo]: + """ + Fetch available models from the provider. + + Returns: + List of available models with their info + """ + pass + + @abstractmethod + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: The model identifier + + Returns: + Model information or None if not found + """ + pass + + @abstractmethod + def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat completion request. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection ("auto", "none", etc.) + **kwargs: Additional provider-specific parameters + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + """ + pass + + @abstractmethod + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Send an async chat completion request. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection + **kwargs: Additional provider-specific parameters + + Returns: + ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming + """ + pass + + @abstractmethod + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get account credit/balance information. + + Returns: + Dict with credit info or None if not supported + """ + pass + + def validate_api_key(self) -> bool: + """ + Validate that the API key is valid. + + Returns: + True if API key is valid + """ + try: + self.list_models() + return True + except Exception: + return False diff --git a/oai/providers/openrouter.py b/oai/providers/openrouter.py new file mode 100644 index 0000000..075fb05 --- /dev/null +++ b/oai/providers/openrouter.py @@ -0,0 +1,623 @@ +""" +OpenRouter provider implementation. + +This module implements the AIProvider interface for OpenRouter, +supporting chat completions, streaming, and function calling. +""" + +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +import requests +from openrouter import OpenRouter + +from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ChatResponseChoice, + ModelInfo, + ProviderCapabilities, + StreamChunk, + ToolCall, + ToolFunction, + UsageStats, +) +from oai.utils.logging import get_logger + + +class OpenRouterProvider(AIProvider): + """ + OpenRouter API provider implementation. + + Provides access to multiple AI models through OpenRouter's unified API, + supporting chat completions, streaming responses, and function calling. + + Attributes: + client: The underlying OpenRouter client + _models_cache: Cached list of available models + """ + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + app_name: str = APP_NAME, + app_url: str = APP_URL, + ): + """ + Initialize the OpenRouter provider. + + Args: + api_key: OpenRouter API key + base_url: Optional custom base URL + app_name: Application name for API headers + app_url: Application URL for API headers + """ + super().__init__(api_key, base_url or DEFAULT_BASE_URL) + self.app_name = app_name + self.app_url = app_url + self.client = OpenRouter(api_key=api_key) + self._models_cache: Optional[List[ModelInfo]] = None + self._raw_models_cache: Optional[List[Dict[str, Any]]] = None + + self.logger = get_logger() + self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}") + + @property + def name(self) -> str: + """Get the provider name.""" + return "OpenRouter" + + @property + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + return ProviderCapabilities( + streaming=True, + tools=True, + images=True, + online=True, + max_context=2000000, # Claude models support up to 200k + ) + + def _get_headers(self) -> Dict[str, str]: + """Get standard HTTP headers for API requests.""" + headers = { + "HTTP-Referer": self.app_url, + "X-Title": self.app_name, + } + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + return headers + + def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo: + """ + Parse raw model data into ModelInfo. + + Args: + model_data: Raw model data from API + + Returns: + Parsed ModelInfo object + """ + architecture = model_data.get("architecture", {}) + pricing_data = model_data.get("pricing", {}) + + # Parse pricing (convert from string to float if needed) + pricing = {} + for key in ["prompt", "completion"]: + value = pricing_data.get(key) + if value is not None: + try: + # Convert from per-token to per-million-tokens + pricing[key] = float(value) * 1_000_000 + except (ValueError, TypeError): + pricing[key] = 0.0 + + return ModelInfo( + id=model_data.get("id", ""), + name=model_data.get("name", model_data.get("id", "")), + description=model_data.get("description", ""), + context_length=model_data.get("context_length", 0), + pricing=pricing, + supported_parameters=model_data.get("supported_parameters", []), + input_modalities=architecture.get("input_modalities", ["text"]), + output_modalities=architecture.get("output_modalities", ["text"]), + ) + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + Fetch available models from OpenRouter. + + Args: + filter_text_only: If True, exclude video-only models + + Returns: + List of available models + + Raises: + Exception: If API request fails + """ + if self._models_cache is not None: + return self._models_cache + + try: + response = requests.get( + f"{self.base_url}/models", + headers=self._get_headers(), + timeout=10, + ) + response.raise_for_status() + + raw_models = response.json().get("data", []) + self._raw_models_cache = raw_models + + models = [] + for model_data in raw_models: + # Optionally filter out video-only models + if filter_text_only: + modalities = model_data.get("modalities", []) + if modalities and "video" in modalities and "text" not in modalities: + continue + + models.append(self._parse_model(model_data)) + + self._models_cache = models + self.logger.info(f"Fetched {len(models)} models from OpenRouter") + return models + + except requests.RequestException as e: + self.logger.error(f"Failed to fetch models: {e}") + raise + + def get_raw_models(self) -> List[Dict[str, Any]]: + """ + Get raw model data as returned by the API. + + Useful for accessing provider-specific fields not in ModelInfo. + + Returns: + List of raw model dictionaries + """ + if self._raw_models_cache is None: + self.list_models() + return self._raw_models_cache or [] + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: The model identifier + + Returns: + Model information or None if not found + """ + models = self.list_models() + for model in models: + if model.id == model_id: + return model + return None + + def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]: + """ + Get raw model data for a specific model. + + Args: + model_id: The model identifier + + Returns: + Raw model dictionary or None if not found + """ + raw_models = self.get_raw_models() + for model in raw_models: + if model.get("id") == model_id: + return model + return None + + def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: + """ + Convert ChatMessage objects to API format. + + Args: + messages: List of ChatMessage objects + + Returns: + List of message dictionaries for the API + """ + return [msg.to_dict() for msg in messages] + + def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]: + """ + Parse usage data from API response. + + Args: + usage_data: Raw usage data from API + + Returns: + Parsed UsageStats or None + """ + if not usage_data: + return None + + # Handle both attribute and dict access + prompt_tokens = 0 + completion_tokens = 0 + total_cost = None + + if hasattr(usage_data, "prompt_tokens"): + prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0 + elif isinstance(usage_data, dict): + prompt_tokens = usage_data.get("prompt_tokens", 0) or 0 + + if hasattr(usage_data, "completion_tokens"): + completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0 + elif isinstance(usage_data, dict): + completion_tokens = usage_data.get("completion_tokens", 0) or 0 + + # Try alternative naming (input_tokens/output_tokens) + if prompt_tokens == 0: + if hasattr(usage_data, "input_tokens"): + prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0 + elif isinstance(usage_data, dict): + prompt_tokens = usage_data.get("input_tokens", 0) or 0 + + if completion_tokens == 0: + if hasattr(usage_data, "output_tokens"): + completion_tokens = getattr(usage_data, "output_tokens", 0) or 0 + elif isinstance(usage_data, dict): + completion_tokens = usage_data.get("output_tokens", 0) or 0 + + # Get cost if available + if hasattr(usage_data, "total_cost_usd"): + total_cost = getattr(usage_data, "total_cost_usd", None) + elif isinstance(usage_data, dict): + total_cost = usage_data.get("total_cost_usd") + + return UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + total_cost_usd=float(total_cost) if total_cost else None, + ) + + def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]: + """ + Parse tool calls from API response. + + Args: + tool_calls_data: Raw tool calls data + + Returns: + List of ToolCall objects or None + """ + if not tool_calls_data: + return None + + tool_calls = [] + for tc in tool_calls_data: + # Handle both attribute and dict access + if hasattr(tc, "id"): + tc_id = tc.id + tc_type = getattr(tc, "type", "function") + func = tc.function + func_name = func.name + func_args = func.arguments + else: + tc_id = tc.get("id", "") + tc_type = tc.get("type", "function") + func = tc.get("function", {}) + func_name = func.get("name", "") + func_args = func.get("arguments", "{}") + + tool_calls.append( + ToolCall( + id=tc_id, + type=tc_type, + function=ToolFunction(name=func_name, arguments=func_args), + ) + ) + + return tool_calls if tool_calls else None + + def _parse_response(self, response: Any) -> ChatResponse: + """ + Parse API response into ChatResponse. + + Args: + response: Raw API response + + Returns: + Parsed ChatResponse + """ + choices = [] + for choice in response.choices: + msg = choice.message + message = ChatMessage( + role=msg.role if hasattr(msg, "role") else "assistant", + content=msg.content if hasattr(msg, "content") else None, + tool_calls=self._parse_tool_calls( + getattr(msg, "tool_calls", None) + ), + ) + choices.append( + ChatResponseChoice( + index=choice.index if hasattr(choice, "index") else 0, + message=message, + finish_reason=getattr(choice, "finish_reason", None), + ) + ) + + return ChatResponse( + id=response.id if hasattr(response, "id") else "", + choices=choices, + usage=self._parse_usage(getattr(response, "usage", None)), + model=getattr(response, "model", None), + created=getattr(response, "created", None), + ) + + def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + transforms: Optional[List[str]] = None, + **kwargs: Any, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat completion request to OpenRouter. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature (0-2) + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection ("auto", "none", etc.) + transforms: List of transforms (e.g., ["middle-out"]) + **kwargs: Additional parameters + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + """ + # Build request parameters + params: Dict[str, Any] = { + "model": model, + "messages": self._convert_messages(messages), + "stream": stream, + "http_headers": self._get_headers(), + } + + # Request usage stats in streaming responses + if stream: + params["stream_options"] = {"include_usage": True} + + if max_tokens is not None: + params["max_tokens"] = max_tokens + + if temperature is not None: + params["temperature"] = temperature + + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice or "auto" + + if transforms: + params["transforms"] = transforms + + # Add any additional parameters + params.update(kwargs) + + self.logger.debug(f"Sending chat request to model {model}") + + try: + response = self.client.chat.send(**params) + + if stream: + return self._stream_response(response) + else: + return self._parse_response(response) + + except Exception as e: + self.logger.error(f"Chat request failed: {e}") + raise + + def _stream_response(self, response: Any) -> Iterator[StreamChunk]: + """ + Process a streaming response. + + Args: + response: Streaming response from API + + Yields: + StreamChunk objects + """ + last_usage = None + + try: + for chunk in response: + # Check for errors + if hasattr(chunk, "error") and chunk.error: + yield StreamChunk( + id=getattr(chunk, "id", ""), + error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error), + ) + return + + # Extract delta content + delta_content = None + finish_reason = None + + if hasattr(chunk, "choices") and chunk.choices: + choice = chunk.choices[0] + if hasattr(choice, "delta"): + delta = choice.delta + if hasattr(delta, "content") and delta.content: + delta_content = delta.content + finish_reason = getattr(choice, "finish_reason", None) + + # Track usage from last chunk + if hasattr(chunk, "usage") and chunk.usage: + last_usage = self._parse_usage(chunk.usage) + + yield StreamChunk( + id=getattr(chunk, "id", ""), + delta_content=delta_content, + finish_reason=finish_reason, + usage=last_usage if finish_reason else None, + ) + + except Exception as e: + self.logger.error(f"Stream error: {e}") + yield StreamChunk(id="", error=str(e)) + + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Send an async chat completion request. + + Note: Currently wraps the sync implementation. + TODO: Implement true async support when OpenRouter SDK supports it. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions + tool_choice: Tool selection mode + **kwargs: Additional parameters + + Returns: + ChatResponse for non-streaming, AsyncIterator for streaming + """ + # For now, use sync implementation + # TODO: Add true async when SDK supports it + result = self.chat( + model=model, + messages=messages, + stream=stream, + max_tokens=max_tokens, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + **kwargs, + ) + + if stream and isinstance(result, Iterator): + # Convert sync iterator to async + async def async_iter() -> AsyncIterator[StreamChunk]: + for chunk in result: + yield chunk + + return async_iter() + + return result + + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get OpenRouter account credit information. + + Returns: + Dict with credit info: + - total_credits: Total credits purchased + - used_credits: Credits used + - credits_left: Remaining credits + + Raises: + Exception: If API request fails + """ + if not self.api_key: + return None + + try: + response = requests.get( + f"{self.base_url}/credits", + headers=self._get_headers(), + timeout=10, + ) + response.raise_for_status() + + data = response.json().get("data", {}) + total_credits = float(data.get("total_credits", 0)) + total_usage = float(data.get("total_usage", 0)) + credits_left = total_credits - total_usage + + return { + "total_credits": total_credits, + "used_credits": total_usage, + "credits_left": credits_left, + "total_credits_formatted": f"${total_credits:.2f}", + "used_credits_formatted": f"${total_usage:.2f}", + "credits_left_formatted": f"${credits_left:.2f}", + } + + except Exception as e: + self.logger.error(f"Failed to fetch credits: {e}") + return None + + def clear_cache(self) -> None: + """Clear the models cache to force a refresh.""" + self._models_cache = None + self._raw_models_cache = None + self.logger.debug("Models cache cleared") + + def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str: + """ + Get the effective model ID with online suffix if needed. + + Args: + model_id: Base model ID + online_enabled: Whether online mode is enabled + + Returns: + Model ID with :online suffix if applicable + """ + if online_enabled and not model_id.endswith(":online"): + return f"{model_id}:online" + return model_id + + def estimate_cost( + self, + model_id: str, + input_tokens: int, + output_tokens: int, + ) -> float: + """ + Estimate the cost for a completion. + + Args: + model_id: Model ID + input_tokens: Number of input tokens + output_tokens: Number of output tokens + + Returns: + Estimated cost in USD + """ + model = self.get_model(model_id) + if model and model.pricing: + input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000 + output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000 + return input_cost + output_cost + + # Fallback to default pricing if model not found + from oai.constants import MODEL_PRICING + + input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000 + output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000 + return input_cost + output_cost diff --git a/oai/py.typed b/oai/py.typed new file mode 100644 index 0000000..cb5707d --- /dev/null +++ b/oai/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561 +# This package supports type checking diff --git a/oai/ui/__init__.py b/oai/ui/__init__.py new file mode 100644 index 0000000..9f4be60 --- /dev/null +++ b/oai/ui/__init__.py @@ -0,0 +1,51 @@ +""" +UI utilities for oAI. + +This module provides rich terminal UI components and display helpers +for the chat application. +""" + +from oai.ui.console import ( + console, + clear_screen, + display_panel, + display_table, + display_markdown, + print_error, + print_warning, + print_success, + print_info, +) +from oai.ui.tables import ( + create_model_table, + create_stats_table, + create_help_table, + display_paginated_table, +) +from oai.ui.prompts import ( + prompt_confirm, + prompt_choice, + prompt_input, +) + +__all__ = [ + # Console utilities + "console", + "clear_screen", + "display_panel", + "display_table", + "display_markdown", + "print_error", + "print_warning", + "print_success", + "print_info", + # Table utilities + "create_model_table", + "create_stats_table", + "create_help_table", + "display_paginated_table", + # Prompt utilities + "prompt_confirm", + "prompt_choice", + "prompt_input", +] diff --git a/oai/ui/console.py b/oai/ui/console.py new file mode 100644 index 0000000..6836b7d --- /dev/null +++ b/oai/ui/console.py @@ -0,0 +1,242 @@ +""" +Console utilities for oAI. + +This module provides the Rich console instance and common display functions +for formatted terminal output. +""" + +from typing import Any, Optional + +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +# Global console instance for the application +console = Console() + + +def clear_screen() -> None: + """ + Clear the terminal screen. + + Uses ANSI escape codes for fast clearing, with a fallback + for terminals that don't support them. + """ + try: + print("\033[H\033[J", end="", flush=True) + except Exception: + # Fallback: print many newlines + print("\n" * 100) + + +def display_panel( + content: Any, + title: Optional[str] = None, + subtitle: Optional[str] = None, + border_style: str = "green", + title_align: str = "left", + subtitle_align: str = "right", +) -> None: + """ + Display content in a bordered panel. + + Args: + content: Content to display (string, Table, or Markdown) + title: Optional panel title + subtitle: Optional panel subtitle + border_style: Border color/style + title_align: Title alignment ("left", "center", "right") + subtitle_align: Subtitle alignment + """ + panel = Panel( + content, + title=title, + subtitle=subtitle, + border_style=border_style, + title_align=title_align, + subtitle_align=subtitle_align, + ) + console.print(panel) + + +def display_table( + table: Table, + title: Optional[str] = None, + subtitle: Optional[str] = None, +) -> None: + """ + Display a table with optional title panel. + + Args: + table: Rich Table to display + title: Optional panel title + subtitle: Optional panel subtitle + """ + if title: + display_panel(table, title=title, subtitle=subtitle) + else: + console.print(table) + + +def display_markdown( + content: str, + panel: bool = False, + title: Optional[str] = None, +) -> None: + """ + Display markdown-formatted content. + + Args: + content: Markdown text to display + panel: Whether to wrap in a panel + title: Optional panel title (if panel=True) + """ + md = Markdown(content) + if panel: + display_panel(md, title=title) + else: + console.print(md) + + +def print_error(message: str, prefix: str = "Error:") -> None: + """ + Print an error message in red. + + Args: + message: Error message to display + prefix: Prefix before the message (default: "Error:") + """ + console.print(f"[bold red]{prefix}[/] {message}") + + +def print_warning(message: str, prefix: str = "Warning:") -> None: + """ + Print a warning message in yellow. + + Args: + message: Warning message to display + prefix: Prefix before the message (default: "Warning:") + """ + console.print(f"[bold yellow]{prefix}[/] {message}") + + +def print_success(message: str, prefix: str = "โœ“") -> None: + """ + Print a success message in green. + + Args: + message: Success message to display + prefix: Prefix before the message (default: "โœ“") + """ + console.print(f"[bold green]{prefix}[/] {message}") + + +def print_info(message: str, dim: bool = False) -> None: + """ + Print an informational message in cyan. + + Args: + message: Info message to display + dim: Whether to dim the message + """ + if dim: + console.print(f"[dim cyan]{message}[/]") + else: + console.print(f"[bold cyan]{message}[/]") + + +def print_metrics( + tokens: int, + cost: float, + time_seconds: float, + context_info: str = "", + online: bool = False, + mcp_mode: Optional[str] = None, + tool_loops: int = 0, + session_tokens: int = 0, + session_cost: float = 0.0, +) -> None: + """ + Print formatted metrics for a response. + + Args: + tokens: Total tokens used + cost: Cost in USD + time_seconds: Response time + context_info: Context information string + online: Whether online mode is active + mcp_mode: MCP mode ("files", "database", or None) + tool_loops: Number of tool call loops + session_tokens: Total session tokens + session_cost: Total session cost + """ + parts = [ + f"๐Ÿ“Š Metrics: {tokens} tokens", + f"${cost:.4f}", + f"{time_seconds:.2f}s", + ] + + if context_info: + parts.append(context_info) + + if online: + parts.append("๐ŸŒ") + + if mcp_mode == "files": + parts.append("๐Ÿ”ง") + elif mcp_mode == "database": + parts.append("๐Ÿ—„๏ธ") + + if tool_loops > 0: + parts.append(f"({tool_loops} tool loop(s))") + + parts.append(f"Session: {session_tokens} tokens") + parts.append(f"${session_cost:.4f}") + + console.print(f"\n[dim blue]{' | '.join(parts)}[/]") + + +def format_size(size_bytes: int) -> str: + """ + Format a size in bytes to a human-readable string. + + Args: + size_bytes: Size in bytes + + Returns: + Formatted size string (e.g., "1.5 MB") + """ + for unit in ["B", "KB", "MB", "GB", "TB"]: + if abs(size_bytes) < 1024.0: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.1f} PB" + + +def format_tokens(tokens: int) -> str: + """ + Format token count with thousands separators. + + Args: + tokens: Number of tokens + + Returns: + Formatted token string (e.g., "1,234,567") + """ + return f"{tokens:,}" + + +def format_cost(cost: float, precision: int = 4) -> str: + """ + Format cost in USD. + + Args: + cost: Cost in dollars + precision: Decimal places + + Returns: + Formatted cost string (e.g., "$0.0123") + """ + return f"${cost:.{precision}f}" diff --git a/oai/ui/prompts.py b/oai/ui/prompts.py new file mode 100644 index 0000000..654a43e --- /dev/null +++ b/oai/ui/prompts.py @@ -0,0 +1,274 @@ +""" +Prompt utilities for oAI. + +This module provides functions for gathering user input +through confirmations, choices, and text prompts. +""" + +from typing import List, Optional, TypeVar + +import typer + +from oai.ui.console import console + +T = TypeVar("T") + + +def prompt_confirm( + message: str, + default: bool = False, + abort: bool = False, +) -> bool: + """ + Prompt the user for a yes/no confirmation. + + Args: + message: The question to ask + default: Default value if user presses Enter + abort: Whether to abort on "no" response + + Returns: + True if user confirms, False otherwise + """ + try: + return typer.confirm(message, default=default, abort=abort) + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return False + + +def prompt_choice( + message: str, + choices: List[str], + default: Optional[str] = None, +) -> Optional[str]: + """ + Prompt the user to select from a list of choices. + + Args: + message: The question to ask + choices: List of valid choices + default: Default choice if user presses Enter + + Returns: + Selected choice or None if cancelled + """ + # Display choices + console.print(f"\n[bold cyan]{message}[/]") + for i, choice in enumerate(choices, 1): + default_marker = " [default]" if choice == default else "" + console.print(f" {i}. {choice}{default_marker}") + + try: + response = input("\nEnter number or value: ").strip() + + if not response and default: + return default + + # Try as number first + try: + index = int(response) - 1 + if 0 <= index < len(choices): + return choices[index] + except ValueError: + pass + + # Try as exact match + if response in choices: + return response + + # Try case-insensitive match + response_lower = response.lower() + for choice in choices: + if choice.lower() == response_lower: + return choice + + console.print(f"[red]Invalid choice: {response}[/]") + return None + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_input( + message: str, + default: Optional[str] = None, + password: bool = False, + required: bool = False, +) -> Optional[str]: + """ + Prompt the user for text input. + + Args: + message: The prompt message + default: Default value if user presses Enter + password: Whether to hide input (for sensitive data) + required: Whether input is required (loops until provided) + + Returns: + User input or default, None if cancelled + """ + prompt_text = message + if default: + prompt_text += f" [{default}]" + prompt_text += ": " + + try: + while True: + if password: + import getpass + + response = getpass.getpass(prompt_text) + else: + response = input(prompt_text).strip() + + if not response: + if default: + return default + if required: + console.print("[yellow]Input required[/]") + continue + return None + + return response + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_number( + message: str, + min_value: Optional[int] = None, + max_value: Optional[int] = None, + default: Optional[int] = None, +) -> Optional[int]: + """ + Prompt the user for a numeric input. + + Args: + message: The prompt message + min_value: Minimum allowed value + max_value: Maximum allowed value + default: Default value if user presses Enter + + Returns: + Integer value or None if cancelled + """ + prompt_text = message + if default is not None: + prompt_text += f" [{default}]" + prompt_text += ": " + + try: + while True: + response = input(prompt_text).strip() + + if not response: + if default is not None: + return default + return None + + try: + value = int(response) + except ValueError: + console.print("[red]Please enter a valid number[/]") + continue + + if min_value is not None and value < min_value: + console.print(f"[red]Value must be at least {min_value}[/]") + continue + + if max_value is not None and value > max_value: + console.print(f"[red]Value must be at most {max_value}[/]") + continue + + return value + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_selection( + items: List[T], + message: str = "Select an item", + display_func: Optional[callable] = None, + allow_cancel: bool = True, +) -> Optional[T]: + """ + Prompt the user to select an item from a list. + + Args: + items: List of items to choose from + message: The selection prompt + display_func: Function to convert item to display string + allow_cancel: Whether to allow cancellation + + Returns: + Selected item or None if cancelled + """ + if not items: + console.print("[yellow]No items to select[/]") + return None + + display = display_func or str + + console.print(f"\n[bold cyan]{message}[/]") + for i, item in enumerate(items, 1): + console.print(f" {i}. {display(item)}") + + if allow_cancel: + console.print(f" 0. Cancel") + + try: + while True: + response = input("\nEnter number: ").strip() + + try: + index = int(response) + except ValueError: + console.print("[red]Please enter a valid number[/]") + continue + + if allow_cancel and index == 0: + return None + + if 1 <= index <= len(items): + return items[index - 1] + + console.print(f"[red]Please enter a number between 1 and {len(items)}[/]") + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_copy_response(response: str) -> bool: + """ + Prompt user to copy a response to clipboard. + + Args: + response: The response text + + Returns: + True if copied, False otherwise + """ + try: + copy_choice = input("๐Ÿ’พ Type 'c' to copy response, or press Enter to continue: ").strip().lower() + if copy_choice == "c": + try: + import pyperclip + + pyperclip.copy(response) + console.print("[bold green]โœ… Response copied to clipboard![/]") + return True + except ImportError: + console.print("[yellow]pyperclip not installed - cannot copy to clipboard[/]") + except Exception as e: + console.print(f"[red]Failed to copy: {e}[/]") + except (EOFError, KeyboardInterrupt): + pass + + return False diff --git a/oai/ui/tables.py b/oai/ui/tables.py new file mode 100644 index 0000000..49a1c2e --- /dev/null +++ b/oai/ui/tables.py @@ -0,0 +1,373 @@ +""" +Table utilities for oAI. + +This module provides functions for creating and displaying +formatted tables with pagination support. +""" + +import os +import sys +from typing import Any, Dict, List, Optional + +from rich.panel import Panel +from rich.table import Table + +from oai.ui.console import clear_screen, console + + +def create_model_table( + models: List[Dict[str, Any]], + show_capabilities: bool = True, +) -> Table: + """ + Create a table displaying available AI models. + + Args: + models: List of model dictionaries + show_capabilities: Whether to show capability columns + + Returns: + Rich Table with model information + """ + if show_capabilities: + table = Table( + "No.", + "Model ID", + "Context", + "Image", + "Online", + "Tools", + show_header=True, + header_style="bold magenta", + ) + else: + table = Table( + "No.", + "Model ID", + "Context", + show_header=True, + header_style="bold magenta", + ) + + for i, model in enumerate(models, 1): + model_id = model.get("id", "Unknown") + context = model.get("context_length", 0) + context_str = f"{context:,}" if context else "-" + + if show_capabilities: + # Get modalities and parameters + architecture = model.get("architecture", {}) + input_modalities = architecture.get("input_modalities", []) + supported_params = model.get("supported_parameters", []) + + has_image = "โœ“" if "image" in input_modalities else "-" + has_online = "โœ“" if "tools" in supported_params else "-" + has_tools = "โœ“" if "tools" in supported_params or "functions" in supported_params else "-" + + table.add_row( + str(i), + model_id, + context_str, + has_image, + has_online, + has_tools, + ) + else: + table.add_row(str(i), model_id, context_str) + + return table + + +def create_stats_table(stats: Dict[str, Any]) -> Table: + """ + Create a table displaying session statistics. + + Args: + stats: Dictionary with statistics data + + Returns: + Rich Table with stats + """ + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + # Token stats + if "input_tokens" in stats: + table.add_row("Input Tokens", f"{stats['input_tokens']:,}") + if "output_tokens" in stats: + table.add_row("Output Tokens", f"{stats['output_tokens']:,}") + if "total_tokens" in stats: + table.add_row("Total Tokens", f"{stats['total_tokens']:,}") + + # Cost stats + if "total_cost" in stats: + table.add_row("Total Cost", f"${stats['total_cost']:.4f}") + if "avg_cost" in stats: + table.add_row("Avg Cost/Message", f"${stats['avg_cost']:.4f}") + + # Message stats + if "message_count" in stats: + table.add_row("Messages", str(stats["message_count"])) + + # Credits + if "credits_left" in stats: + table.add_row("Credits Left", stats["credits_left"]) + + return table + + +def create_help_table(commands: Dict[str, Dict[str, str]]) -> Table: + """ + Create a help table for commands. + + Args: + commands: Dictionary of command info + + Returns: + Rich Table with command help + """ + table = Table( + "Command", + "Description", + "Example", + show_header=True, + header_style="bold magenta", + show_lines=False, + ) + + for cmd, info in commands.items(): + description = info.get("description", "") + example = info.get("example", "") + table.add_row(cmd, description, example) + + return table + + +def create_folder_table( + folders: List[Dict[str, Any]], + gitignore_info: str = "", +) -> Table: + """ + Create a table for MCP folder listing. + + Args: + folders: List of folder dictionaries + gitignore_info: Optional gitignore status info + + Returns: + Rich Table with folder information + """ + table = Table( + "No.", + "Path", + "Files", + "Size", + show_header=True, + header_style="bold magenta", + ) + + for folder in folders: + number = str(folder.get("number", "")) + path = folder.get("path", "") + + if folder.get("exists", True): + files = f"๐Ÿ“ {folder.get('file_count', 0)}" + size = f"{folder.get('size_mb', 0):.1f} MB" + else: + files = "[red]Not found[/red]" + size = "-" + + table.add_row(number, path, files, size) + + return table + + +def create_database_table(databases: List[Dict[str, Any]]) -> Table: + """ + Create a table for MCP database listing. + + Args: + databases: List of database dictionaries + + Returns: + Rich Table with database information + """ + table = Table( + "No.", + "Name", + "Tables", + "Size", + "Status", + show_header=True, + header_style="bold magenta", + ) + + for db in databases: + number = str(db.get("number", "")) + name = db.get("name", "") + table_count = f"{db.get('table_count', 0)} tables" + size = f"{db.get('size_mb', 0):.1f} MB" + + if db.get("warning"): + status = f"[red]{db['warning']}[/red]" + else: + status = "[green]โœ“[/green]" + + table.add_row(number, name, table_count, size, status) + + return table + + +def display_paginated_table( + table: Table, + title: str, + terminal_height: Optional[int] = None, +) -> None: + """ + Display a table with pagination for large datasets. + + Allows navigating through pages with keyboard input. + Press SPACE for next page, any other key to exit. + + Args: + table: Rich Table to display + title: Title for the table + terminal_height: Override terminal height (auto-detected if None) + """ + # Get terminal dimensions + try: + term_height = terminal_height or os.get_terminal_size().lines - 8 + except OSError: + term_height = 20 + + # Render table to segments + from rich.segment import Segment + + segments = list(console.render(table)) + + # Group segments into lines + current_line_segments: List[Segment] = [] + all_lines: List[List[Segment]] = [] + + for segment in segments: + if segment.text == "\n": + all_lines.append(current_line_segments) + current_line_segments = [] + else: + current_line_segments.append(segment) + + if current_line_segments: + all_lines.append(current_line_segments) + + total_lines = len(all_lines) + + # If table fits in one screen, just display it + if total_lines <= term_height: + console.print(Panel(table, title=title, title_align="left")) + return + + # Extract header and footer lines + header_lines: List[List[Segment]] = [] + data_lines: List[List[Segment]] = [] + footer_line: List[Segment] = [] + + # Find header end (line after the header text with border) + header_end_index = 0 + found_header_text = False + + for i, line_segments in enumerate(all_lines): + has_header_style = any( + seg.style and ("bold" in str(seg.style) or "magenta" in str(seg.style)) + for seg in line_segments + ) + + if has_header_style: + found_header_text = True + + if found_header_text and i > 0: + line_text = "".join(seg.text for seg in line_segments) + if any(char in line_text for char in ["โ”€", "โ”", "โ”ผ", "โ•ช", "โ”ค", "โ”œ"]): + header_end_index = i + break + + # Extract footer (bottom border) + if all_lines: + last_line_text = "".join(seg.text for seg in all_lines[-1]) + if any(char in last_line_text for char in ["โ”€", "โ”", "โ”ด", "โ•ง", "โ”˜", "โ””"]): + footer_line = all_lines[-1] + all_lines = all_lines[:-1] + + # Split into header and data + if header_end_index > 0: + header_lines = all_lines[: header_end_index + 1] + data_lines = all_lines[header_end_index + 1 :] + else: + header_lines = all_lines[: min(3, len(all_lines))] + data_lines = all_lines[min(3, len(all_lines)) :] + + lines_per_page = term_height - len(header_lines) + current_line = 0 + page_number = 1 + + # Paginate + while current_line < len(data_lines): + clear_screen() + console.print(f"[bold cyan]{title} (Page {page_number})[/]") + + # Print header + for line_segments in header_lines: + for segment in line_segments: + console.print(segment.text, style=segment.style, end="") + console.print() + + # Print data rows for this page + end_line = min(current_line + lines_per_page, len(data_lines)) + for line_segments in data_lines[current_line:end_line]: + for segment in line_segments: + console.print(segment.text, style=segment.style, end="") + console.print() + + # Print footer + if footer_line: + for segment in footer_line: + console.print(segment.text, style=segment.style, end="") + console.print() + + current_line = end_line + page_number += 1 + + # Prompt for next page + if current_line < len(data_lines): + console.print( + f"\n[dim yellow]--- Press SPACE for next page, " + f"or any other key to finish (Page {page_number - 1}, " + f"showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]" + ) + + try: + import termios + import tty + + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + + try: + tty.setraw(fd) + char = sys.stdin.read(1) + if char != " ": + break + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + except (ImportError, OSError, AttributeError): + # Fallback for non-Unix systems + try: + user_input = input() + if user_input.strip(): + break + except (EOFError, KeyboardInterrupt): + break diff --git a/oai/utils/__init__.py b/oai/utils/__init__.py new file mode 100644 index 0000000..6048bfe --- /dev/null +++ b/oai/utils/__init__.py @@ -0,0 +1,20 @@ +""" +Utility modules for oAI. + +This package provides common utilities used throughout the application +including logging, file handling, and export functionality. +""" + +from oai.utils.logging import setup_logging, get_logger +from oai.utils.files import read_file_safe, is_binary_file +from oai.utils.export import export_as_markdown, export_as_json, export_as_html + +__all__ = [ + "setup_logging", + "get_logger", + "read_file_safe", + "is_binary_file", + "export_as_markdown", + "export_as_json", + "export_as_html", +] diff --git a/oai/utils/export.py b/oai/utils/export.py new file mode 100644 index 0000000..c476ef9 --- /dev/null +++ b/oai/utils/export.py @@ -0,0 +1,248 @@ +""" +Export utilities for oAI. + +This module provides functions for exporting conversation history +in various formats including Markdown, JSON, and HTML. +""" + +import json +import datetime +from typing import List, Dict +from html import escape as html_escape + +from oai.constants import APP_VERSION, APP_URL + + +def export_as_markdown( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as Markdown. + + Args: + session_history: List of message dictionaries with 'prompt' and 'response' + session_system_prompt: Optional system prompt to include + + Returns: + Markdown formatted string + """ + lines = ["# Conversation Export", ""] + + if session_system_prompt: + lines.extend([f"**System Prompt:** {session_system_prompt}", ""]) + + lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + lines.append(f"**Messages:** {len(session_history)}") + lines.append("") + lines.append("---") + lines.append("") + + for i, entry in enumerate(session_history, 1): + lines.append(f"## Message {i}") + lines.append("") + lines.append("**User:**") + lines.append("") + lines.append(entry.get("prompt", "")) + lines.append("") + lines.append("**Assistant:**") + lines.append("") + lines.append(entry.get("response", "")) + lines.append("") + lines.append("---") + lines.append("") + + lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*") + + return "\n".join(lines) + + +def export_as_json( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as JSON. + + Args: + session_history: List of message dictionaries + session_system_prompt: Optional system prompt to include + + Returns: + JSON formatted string + """ + export_data = { + "export_date": datetime.datetime.now().isoformat(), + "app_version": APP_VERSION, + "system_prompt": session_system_prompt, + "message_count": len(session_history), + "messages": [ + { + "index": i + 1, + "prompt": entry.get("prompt", ""), + "response": entry.get("response", ""), + "prompt_tokens": entry.get("prompt_tokens", 0), + "completion_tokens": entry.get("completion_tokens", 0), + "cost": entry.get("msg_cost", 0.0), + } + for i, entry in enumerate(session_history) + ], + "totals": { + "prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history), + "completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history), + "total_cost": sum(e.get("msg_cost", 0.0) for e in session_history), + } + } + return json.dumps(export_data, indent=2, ensure_ascii=False) + + +def export_as_html( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as styled HTML. + + Args: + session_history: List of message dictionaries + session_system_prompt: Optional system prompt to include + + Returns: + HTML formatted string with embedded CSS + """ + html_parts = [ + "", + "", + "", + " ", + " ", + " Conversation Export - oAI", + " ", + "", + "", + "
", + "

Conversation Export

", + f"
Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
", + f"
Total Messages: {len(session_history)}
", + "
", + ] + + if session_system_prompt: + html_parts.extend([ + "
", + " System Prompt", + f"
{html_escape(session_system_prompt)}
", + "
", + ]) + + for i, entry in enumerate(session_history, 1): + prompt = html_escape(entry.get("prompt", "")) + response = html_escape(entry.get("response", "")) + + html_parts.extend([ + "
", + f"
Message {i} of {len(session_history)}
", + "
", + "
User
", + f"
{prompt}
", + "
", + "
", + "
Assistant
", + f"
{response}
", + "
", + "
", + ]) + + html_parts.extend([ + " ", + "", + "", + ]) + + return "\n".join(html_parts) diff --git a/oai/utils/files.py b/oai/utils/files.py new file mode 100644 index 0000000..f793d8a --- /dev/null +++ b/oai/utils/files.py @@ -0,0 +1,323 @@ +""" +File handling utilities for oAI. + +This module provides safe file reading, type detection, and other +file-related operations used throughout the application. +""" + +import os +import mimetypes +import base64 +from pathlib import Path +from typing import Optional, Dict, Any, Tuple + +from oai.constants import ( + MAX_FILE_SIZE, + CONTENT_TRUNCATION_THRESHOLD, + SUPPORTED_CODE_EXTENSIONS, + ALLOWED_FILE_EXTENSIONS, +) +from oai.utils.logging import get_logger + + +def is_binary_file(file_path: Path) -> bool: + """ + Check if a file appears to be binary. + + Args: + file_path: Path to the file to check + + Returns: + True if the file appears to be binary, False otherwise + """ + try: + with open(file_path, "rb") as f: + # Read first 8KB to check for binary content + chunk = f.read(8192) + # Check for null bytes (common in binary files) + if b"\x00" in chunk: + return True + # Try to decode as UTF-8 + try: + chunk.decode("utf-8") + return False + except UnicodeDecodeError: + return True + except Exception: + return True + + +def get_file_type(file_path: Path) -> Tuple[Optional[str], str]: + """ + Determine the MIME type and category of a file. + + Args: + file_path: Path to the file + + Returns: + Tuple of (mime_type, category) where category is one of: + 'image', 'pdf', 'code', 'text', 'binary', 'unknown' + """ + mime_type, _ = mimetypes.guess_type(str(file_path)) + ext = file_path.suffix.lower() + + if mime_type and mime_type.startswith("image/"): + return mime_type, "image" + elif mime_type == "application/pdf" or ext == ".pdf": + return mime_type or "application/pdf", "pdf" + elif ext in SUPPORTED_CODE_EXTENSIONS: + return mime_type or "text/plain", "code" + elif mime_type and mime_type.startswith("text/"): + return mime_type, "text" + elif is_binary_file(file_path): + return mime_type, "binary" + else: + return mime_type, "unknown" + + +def read_file_safe( + file_path: Path, + max_size: int = MAX_FILE_SIZE, + truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD +) -> Dict[str, Any]: + """ + Safely read a file with size limits and truncation support. + + Args: + file_path: Path to the file to read + max_size: Maximum file size to read (bytes) + truncate_threshold: Threshold for truncating large files + + Returns: + Dictionary containing: + - content: File content (text or base64) + - size: File size in bytes + - truncated: Whether content was truncated + - encoding: 'text', 'base64', or None on error + - error: Error message if reading failed + """ + logger = get_logger() + + try: + path = Path(file_path).resolve() + + if not path.exists(): + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"File not found: {path}" + } + + if not path.is_file(): + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"Not a file: {path}" + } + + file_size = path.stat().st_size + + if file_size > max_size: + return { + "content": None, + "size": file_size, + "truncated": False, + "encoding": None, + "error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)" + } + + # Try to read as text first + try: + content = path.read_text(encoding="utf-8") + + # Check if truncation is needed + if file_size > truncate_threshold: + lines = content.split("\n") + total_lines = len(lines) + + # Keep first 500 lines and last 100 lines + head_lines = 500 + tail_lines = 100 + + if total_lines > (head_lines + tail_lines): + truncated_content = ( + "\n".join(lines[:head_lines]) + + f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" + + "\n".join(lines[-tail_lines:]) + ) + logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)") + return { + "content": truncated_content, + "size": file_size, + "truncated": True, + "total_lines": total_lines, + "lines_shown": head_lines + tail_lines, + "encoding": "text", + "error": None + } + + logger.info(f"Read file: {path} ({file_size} bytes)") + return { + "content": content, + "size": file_size, + "truncated": False, + "encoding": "text", + "error": None + } + + except UnicodeDecodeError: + # File is binary, return base64 encoded + with open(path, "rb") as f: + binary_data = f.read() + b64_content = base64.b64encode(binary_data).decode("utf-8") + logger.info(f"Read binary file: {path} ({file_size} bytes)") + return { + "content": b64_content, + "size": file_size, + "truncated": False, + "encoding": "base64", + "error": None + } + + except PermissionError as e: + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"Permission denied: {e}" + } + except Exception as e: + logger.error(f"Error reading file {file_path}: {e}") + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": str(e) + } + + +def get_file_extension(file_path: Path) -> str: + """ + Get the lowercase file extension. + + Args: + file_path: Path to the file + + Returns: + Lowercase extension including the dot (e.g., '.py') + """ + return file_path.suffix.lower() + + +def is_allowed_extension(file_path: Path) -> bool: + """ + Check if a file has an allowed extension for attachment. + + Args: + file_path: Path to the file + + Returns: + True if the extension is allowed, False otherwise + """ + return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS + + +def format_file_size(size_bytes: int) -> str: + """ + Format a file size in human-readable format. + + Args: + size_bytes: Size in bytes + + Returns: + Formatted string (e.g., '1.5 MB', '512 KB') + """ + for unit in ["B", "KB", "MB", "GB", "TB"]: + if abs(size_bytes) < 1024: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024 + return f"{size_bytes:.1f} PB" + + +def prepare_file_attachment( + file_path: Path, + model_capabilities: Dict[str, Any] +) -> Optional[Dict[str, Any]]: + """ + Prepare a file for attachment to an API request. + + Args: + file_path: Path to the file + model_capabilities: Model capability information + + Returns: + Content block dictionary for the API, or None if unsupported + """ + logger = get_logger() + path = Path(file_path).resolve() + + if not path.exists(): + logger.warning(f"File not found: {path}") + return None + + mime_type, category = get_file_type(path) + file_size = path.stat().st_size + + if file_size > MAX_FILE_SIZE: + logger.warning(f"File too large: {path} ({format_file_size(file_size)})") + return None + + try: + with open(path, "rb") as f: + file_data = f.read() + + if category == "image": + # Check if model supports images + input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", []) + if "image" not in input_modalities: + logger.warning(f"Model does not support images") + return None + + b64_data = base64.b64encode(file_data).decode("utf-8") + return { + "type": "image_url", + "image_url": {"url": f"data:{mime_type};base64,{b64_data}"} + } + + elif category == "pdf": + # Check if model supports PDFs + input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", []) + supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"]) + if not supports_pdf: + logger.warning(f"Model does not support PDFs") + return None + + b64_data = base64.b64encode(file_data).decode("utf-8") + return { + "type": "image_url", + "image_url": {"url": f"data:application/pdf;base64,{b64_data}"} + } + + elif category in ("code", "text"): + text_content = file_data.decode("utf-8") + return { + "type": "text", + "text": f"File: {path.name}\n\n{text_content}" + } + + else: + logger.warning(f"Unsupported file type: {category} ({mime_type})") + return None + + except UnicodeDecodeError: + logger.error(f"Cannot decode file as UTF-8: {path}") + return None + except Exception as e: + logger.error(f"Error preparing file attachment {path}: {e}") + return None diff --git a/oai/utils/logging.py b/oai/utils/logging.py new file mode 100644 index 0000000..859c51c --- /dev/null +++ b/oai/utils/logging.py @@ -0,0 +1,297 @@ +""" +Logging configuration for oAI. + +This module provides centralized logging setup with Rich formatting, +file rotation, and configurable log levels. +""" + +import io +import os +import glob +import logging +import datetime +import shutil +from logging.handlers import RotatingFileHandler +from pathlib import Path +from typing import Optional + +from rich.console import Console +from rich.logging import RichHandler + +from oai.constants import ( + LOG_FILE, + CONFIG_DIR, + DEFAULT_LOG_MAX_SIZE_MB, + DEFAULT_LOG_BACKUP_COUNT, + DEFAULT_LOG_LEVEL, + VALID_LOG_LEVELS, +) + + +class RotatingRichHandler(RotatingFileHandler): + """ + Custom log handler combining file rotation with Rich formatting. + + This handler writes Rich-formatted log output to a rotating file, + providing colored and formatted logs even in file output while + managing file size and backups automatically. + """ + + def __init__(self, *args, **kwargs): + """Initialize the handler with Rich console for formatting.""" + super().__init__(*args, **kwargs) + # Create an internal console for Rich formatting + self.rich_console = Console( + file=io.StringIO(), + width=120, + force_terminal=False + ) + self.rich_handler = RichHandler( + console=self.rich_console, + show_time=True, + show_path=True, + rich_tracebacks=True, + tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"] + ) + + def emit(self, record: logging.LogRecord) -> None: + """ + Emit a log record with Rich formatting. + + Args: + record: The log record to emit + """ + try: + # Format with Rich + self.rich_handler.emit(record) + output = self.rich_console.file.getvalue() + self.rich_console.file.seek(0) + self.rich_console.file.truncate(0) + + if output: + self.stream.write(output) + self.flush() + except Exception: + self.handleError(record) + + +class LoggingManager: + """ + Manages application logging configuration. + + Provides methods to setup, configure, and manage logging with + support for runtime reconfiguration and level changes. + """ + + def __init__(self): + """Initialize the logging manager.""" + self.handler: Optional[RotatingRichHandler] = None + self.app_logger: Optional[logging.Logger] = None + self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB + self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT + self.log_level: str = DEFAULT_LOG_LEVEL + + def setup( + self, + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None + ) -> logging.Logger: + """ + Setup or reconfigure logging. + + Args: + max_size_mb: Maximum log file size in MB + backup_count: Number of backup files to keep + log_level: Logging level string + + Returns: + The configured application logger + """ + # Update configuration if provided + if max_size_mb is not None: + self.max_size_mb = max_size_mb + if backup_count is not None: + self.backup_count = backup_count + if log_level is not None: + self.log_level = log_level + + # Ensure config directory exists + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + # Get root logger + root_logger = logging.getLogger() + + # Remove existing handler if present + if self.handler is not None: + root_logger.removeHandler(self.handler) + try: + self.handler.close() + except Exception: + pass + + # Check if log needs manual rotation + self._check_rotation() + + # Create new handler + max_bytes = self.max_size_mb * 1024 * 1024 + self.handler = RotatingRichHandler( + filename=str(LOG_FILE), + maxBytes=max_bytes, + backupCount=self.backup_count, + encoding="utf-8" + ) + + self.handler.setLevel(logging.NOTSET) + root_logger.setLevel(logging.WARNING) + root_logger.addHandler(self.handler) + + # Suppress noisy third-party loggers + for logger_name in [ + "asyncio", "urllib3", "requests", "httpx", + "httpcore", "openai", "openrouter" + ]: + logging.getLogger(logger_name).setLevel(logging.WARNING) + + # Configure application logger + self.app_logger = logging.getLogger("oai_app") + level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO) + self.app_logger.setLevel(level) + self.app_logger.propagate = True + + return self.app_logger + + def _check_rotation(self) -> None: + """Check if log file needs rotation and perform if necessary.""" + if not LOG_FILE.exists(): + return + + current_size = LOG_FILE.stat().st_size + max_bytes = self.max_size_mb * 1024 * 1024 + + if current_size >= max_bytes: + # Perform manual rotation + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = f"{LOG_FILE}.{timestamp}" + + try: + shutil.move(str(LOG_FILE), backup_file) + except Exception: + pass + + # Clean old backups + self._cleanup_old_backups() + + def _cleanup_old_backups(self) -> None: + """Remove old backup files exceeding the backup count.""" + log_dir = LOG_FILE.parent + backup_pattern = f"{LOG_FILE.name}.*" + backups = sorted(glob.glob(str(log_dir / backup_pattern))) + + while len(backups) > self.backup_count: + oldest = backups.pop(0) + try: + os.remove(oldest) + except Exception: + pass + + def set_level(self, level: str) -> bool: + """ + Set the application log level. + + Args: + level: Log level string (debug/info/warning/error/critical) + + Returns: + True if level was set successfully, False otherwise + """ + level_lower = level.lower() + if level_lower not in VALID_LOG_LEVELS: + return False + + self.log_level = level_lower + if self.app_logger: + self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower]) + + return True + + def get_logger(self) -> logging.Logger: + """ + Get the application logger, initializing if necessary. + + Returns: + The application logger + """ + if self.app_logger is None: + self.setup() + return self.app_logger + + +# Global logging manager instance +_logging_manager = LoggingManager() + + +def setup_logging( + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None +) -> logging.Logger: + """ + Setup application logging. + + This is the main entry point for configuring logging. Call this + early in application startup. + + Args: + max_size_mb: Maximum log file size in MB + backup_count: Number of backup files to keep + log_level: Logging level string + + Returns: + The configured application logger + """ + return _logging_manager.setup(max_size_mb, backup_count, log_level) + + +def get_logger() -> logging.Logger: + """ + Get the application logger. + + Returns: + The application logger instance + """ + return _logging_manager.get_logger() + + +def set_log_level(level: str) -> bool: + """ + Set the application log level. + + Args: + level: Log level string + + Returns: + True if successful, False otherwise + """ + return _logging_manager.set_level(level) + + +def reload_logging( + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None +) -> logging.Logger: + """ + Reload logging configuration. + + Useful when settings change at runtime. + + Args: + max_size_mb: New maximum log file size + backup_count: New backup count + log_level: New log level + + Returns: + The reconfigured logger + """ + return _logging_manager.setup(max_size_mb, backup_count, log_level) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..af80856 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,134 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "oai" +version = "2.1.0" +description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application" +readme = "README.md" +license = {text = "MIT"} +authors = [ + {name = "Rune", email = "rune@example.com"} +] +maintainers = [ + {name = "Rune", email = "rune@example.com"} +] +keywords = [ + "ai", + "chat", + "openrouter", + "cli", + "terminal", + "mcp", + "llm", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Utilities", +] +requires-python = ">=3.10" +dependencies = [ + "anyio>=4.0.0", + "click>=8.0.0", + "httpx>=0.24.0", + "markdown-it-py>=3.0.0", + "openrouter>=0.0.19", + "packaging>=21.0", + "prompt-toolkit>=3.0.0", + "pyperclip>=1.8.0", + "requests>=2.28.0", + "rich>=13.0.0", + "typer>=0.9.0", + "mcp>=1.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.0.0", + "black>=23.0.0", + "isort>=5.12.0", + "mypy>=1.0.0", + "ruff>=0.1.0", +] + +[project.urls] +Homepage = "https://iurl.no/oai" +Repository = "https://gitlab.pm/rune/oai" +Documentation = "https://iurl.no/oai" +"Bug Tracker" = "https://gitlab.pm/rune/oai/issues" + +[project.scripts] +oai = "oai.cli:main" + +[tool.setuptools] +packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.ui", "oai.utils"] + +[tool.setuptools.package-data] +oai = ["py.typed"] + +[tool.black] +line-length = 100 +target-version = ["py310", "py311", "py312"] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.mypy_cache + | \.pytest_cache + | \.venv + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 100 +skip_gitignore = true + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true +exclude = [ + "build", + "dist", + ".venv", +] + +[tool.ruff] +line-length = 100 +target-version = "py310" +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by black) + "B008", # do not perform function calls in argument defaults + "C901", # too complex +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +asyncio_mode = "auto" +addopts = "-v --tb=short"