From b0cf88704ebf83effdbc27913132b9e6edc6f573 Mon Sep 17 00:00:00 2001 From: Rune Olsen Date: Tue, 3 Feb 2026 09:02:44 +0100 Subject: [PATCH] 2.1 (#2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Final release of version 2.1. Headlights: ### Core Features - 🤖 Interactive chat with 300+ AI models via OpenRouter - 🔍 Model selection with search and filtering - 💾 Conversation save/load/export (Markdown, JSON, HTML) - 📎 File attachments (images, PDFs, code files) - 💰 Real-time cost tracking and credit monitoring - 🎨 Rich terminal UI with syntax highlighting - 📝 Persistent command history with search (Ctrl+R) - 🌐 Online mode (web search capabilities) - 🧠 Conversation memory toggle ### MCP Integration - 🔧 **File Mode**: AI can read, search, and list local files - Automatic .gitignore filtering - Virtual environment exclusion - Large file handling (auto-truncates >50KB) - ✍️ **Write Mode**: AI can modify files with permission - Create, edit, delete files - Move, copy, organize files - Always requires explicit opt-in - 🗄️ **Database Mode**: AI can query SQLite databases - Read-only access (safe) - Schema inspection - Full SQL query support Reviewed-on: https://gitlab.pm/rune/oai/pulls/2 Co-authored-by: Rune Olsen Co-committed-by: Rune Olsen --- .gitignore | 17 +- README.md | 411 ++++--- oai.py | 2273 ----------------------------------- oai/__init__.py | 26 + oai/__main__.py | 8 + oai/cli.py | 719 +++++++++++ oai/commands/__init__.py | 24 + oai/commands/handlers.py | 1441 ++++++++++++++++++++++ oai/commands/registry.py | 381 ++++++ oai/config/__init__.py | 11 + oai/config/database.py | 472 ++++++++ oai/config/settings.py | 361 ++++++ oai/constants.py | 448 +++++++ oai/core/__init__.py | 14 + oai/core/client.py | 422 +++++++ oai/core/session.py | 659 ++++++++++ oai/mcp/__init__.py | 28 + oai/mcp/gitignore.py | 166 +++ oai/mcp/manager.py | 1365 +++++++++++++++++++++ oai/mcp/platform.py | 228 ++++ oai/mcp/server.py | 1368 +++++++++++++++++++++ oai/mcp/validators.py | 123 ++ oai/providers/__init__.py | 32 + oai/providers/base.py | 413 +++++++ oai/providers/openrouter.py | 623 ++++++++++ oai/py.typed | 2 + oai/ui/__init__.py | 51 + oai/ui/console.py | 242 ++++ oai/ui/prompts.py | 274 +++++ oai/ui/tables.py | 373 ++++++ oai/utils/__init__.py | 20 + oai/utils/export.py | 248 ++++ oai/utils/files.py | 323 +++++ oai/utils/logging.py | 297 +++++ pyproject.toml | 134 +++ requirements.txt | 64 +- 36 files changed, 11576 insertions(+), 2485 deletions(-) delete mode 100644 oai.py create mode 100644 oai/__init__.py create mode 100644 oai/__main__.py create mode 100644 oai/cli.py create mode 100644 oai/commands/__init__.py create mode 100644 oai/commands/handlers.py create mode 100644 oai/commands/registry.py create mode 100644 oai/config/__init__.py create mode 100644 oai/config/database.py create mode 100644 oai/config/settings.py create mode 100644 oai/constants.py create mode 100644 oai/core/__init__.py create mode 100644 oai/core/client.py create mode 100644 oai/core/session.py create mode 100644 oai/mcp/__init__.py create mode 100644 oai/mcp/gitignore.py create mode 100644 oai/mcp/manager.py create mode 100644 oai/mcp/platform.py create mode 100644 oai/mcp/server.py create mode 100644 oai/mcp/validators.py create mode 100644 oai/providers/__init__.py create mode 100644 oai/providers/base.py create mode 100644 oai/providers/openrouter.py create mode 100644 oai/py.typed create mode 100644 oai/ui/__init__.py create mode 100644 oai/ui/console.py create mode 100644 oai/ui/prompts.py create mode 100644 oai/ui/tables.py create mode 100644 oai/utils/__init__.py create mode 100644 oai/utils/export.py create mode 100644 oai/utils/files.py create mode 100644 oai/utils/logging.py create mode 100644 pyproject.toml diff --git a/.gitignore b/.gitignore index 107b564..c8e0fc6 100644 --- a/.gitignore +++ b/.gitignore @@ -22,8 +22,12 @@ Pipfile.lock # Consider if you want to include or exclude ._* *~.nib *~.xib -README.md.old -oai.zip + +# Claude Code local settings +.claude/ + +# Added by author +*.zip .note diagnose.py *.log @@ -33,4 +37,11 @@ build* compiled/ images/oai-iOS-Default-1024x1024@1x.png images/oai.icon/ -b0.sh \ No newline at end of file +b0.sh +*.bak +*.old +*.sh +*.back +requirements.txt +system_prompt.txt +CLAUDE* diff --git a/README.md b/README.md index 0c91bd6..61bc8fd 100644 --- a/README.md +++ b/README.md @@ -1,232 +1,301 @@ -# oAI - OpenRouter AI Chat +# oAI - OpenRouter AI Chat Client -A terminal-based chat interface for OpenRouter API with conversation management, cost tracking, and rich formatting. - -## Description - -oAI is a command-line chat application that provides an interactive interface to OpenRouter's AI models. It features conversation persistence, file attachments, export capabilities, and detailed session metrics. +A powerful, extensible terminal-based chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases. ## Features -- Interactive chat with multiple AI models via OpenRouter -- Model selection with search functionality -- Conversation save/load/export (Markdown, JSON, HTML) -- File attachment support (code files and images) -- Session cost tracking and credit monitoring -- Rich terminal formatting with syntax highlighting -- Persistent command history -- Configurable system prompts and token limits -- SQLite-based configuration and conversation storage +### Core Features +- 🤖 Interactive chat with 300+ AI models via OpenRouter +- 🔍 Model selection with search and filtering +- 💾 Conversation save/load/export (Markdown, JSON, HTML) +- 📎 File attachments (images, PDFs, code files) +- 💰 Real-time cost tracking and credit monitoring +- 🎨 Rich terminal UI with syntax highlighting +- 📝 Persistent command history with search (Ctrl+R) +- 🌐 Online mode (web search capabilities) +- 🧠 Conversation memory toggle + +### MCP Integration +- 🔧 **File Mode**: AI can read, search, and list local files + - Automatic .gitignore filtering + - Virtual environment exclusion + - Large file handling (auto-truncates >50KB) + +- ✍️ **Write Mode**: AI can modify files with permission + - Create, edit, delete files + - Move, copy, organize files + - Always requires explicit opt-in + +- 🗄️ **Database Mode**: AI can query SQLite databases + - Read-only access (safe) + - Schema inspection + - Full SQL query support ## Requirements -- Python 3.7 or higher -- OpenRouter API key (get one at https://openrouter.ai) - -## Screenshot (from version 1.0) - -[](https://gitlab.pm/rune/oai/src/branch/main/README.md) - -Screenshot of `/help` screen. +- Python 3.10-3.13 +- OpenRouter API key ([get one here](https://openrouter.ai)) ## Installation -### 1. Install Dependencies - -Use the included `requirements.txt` file to install the dependencies: +### Option 1: Install from Source (Recommended) ```bash -pip install -r requirements.txt +# Clone the repository +git clone https://gitlab.pm/rune/oai.git +cd oai + +# Install with pip +pip install -e . ``` -### 2. Make the Script Executable +### Option 2: Pre-built Binary (macOS/Linux) + +Download from [Releases](https://gitlab.pm/rune/oai/releases): +- **macOS (Apple Silicon)**: `oai_v2.1.0_mac_arm64.zip` +- **Linux (x86_64)**: `oai_v2.1.0_linux_x86_64.zip` ```bash -chmod +x oai.py -``` - -### 3. Copy to PATH - -Copy the script to a directory in your `$PATH` environment variable. Common locations include: - -```bash -# Option 1: System-wide (requires sudo) -sudo cp oai.py /usr/local/bin/oai - -# Option 2: User-local (recommended) +# Extract and install +unzip oai_v2.1.0_*.zip mkdir -p ~/.local/bin -cp oai.py ~/.local/bin/oai +mv oai ~/.local/bin/ -# Add to PATH if not already (add to ~/.bashrc or ~/.zshrc) +# macOS only: Remove quarantine and approve +xattr -cr ~/.local/bin/oai +# Then right-click oai in Finder → Open With → Terminal → Click "Open" +``` + +### Add to PATH + +```bash +# Add to ~/.zshrc or ~/.bashrc export PATH="$HOME/.local/bin:$PATH" ``` -### 4. Verify Installation +## Quick Start ```bash -oai +# Start the chat client +oai chat + +# Or with options +oai chat --model gpt-4o --mcp ``` -### 5. Alternative Installation (for *nix systems) - -If you have issues with the above method you can add an alias in your `.bashrc`, `.zshrc` etc. - -```bash -alias oai='python3 ' -``` - -On first run, you will be prompted to enter your OpenRouter API key. - -### 6. Use Binaries - -You can also just download the supplied binary for either Mac wit Mx (M1, M2 etc) `oai_mac_arm64.zip` and follow [#3](https://gitlab.pm/rune/oai#3-copy-to-path). Or download for Linux (64bit) `oai_linux_x86_64.zip` and also follow [#3](https://gitlab.pm/rune/oai#3-copy-to-path). - -## Usage - -### Starting the Application - -```bash -oai -``` +On first run, you'll be prompted for your OpenRouter API key. ### Basic Commands -``` -/help Show all available commands -/model Select an AI model -/config api Set OpenRouter API key -exit Quit the application +```bash +# In the chat interface: +/model # Select AI model +/help # Show all commands +/mcp on # Enable file/database access +/stats # View session statistics +exit # Quit ``` -### Configuration +## MCP (Model Context Protocol) -All configuration is stored in `~/.config/oai/`: -- `oai_config.db` - SQLite database for settings and conversations -- `oai.log` - Application log file -- `history.txt` - Command history +MCP allows the AI to interact with your local files and databases. -### Common Workflows +### File Access -**Select a Model:** -``` -/model +```bash +/mcp on # Enable MCP +/mcp add ~/Projects # Grant access to folder +/mcp list # View allowed folders + +# Now ask the AI: +"List all Python files in Projects" +"Read and explain main.py" +"Search for files containing 'TODO'" ``` -**Paste from clipboard:** -Paste and send content to model -``` -/paste +### Write Mode + +```bash +/mcp write on # Enable file modifications + +# AI can now: +"Create a new file called utils.py" +"Edit config.json and update the API URL" +"Delete the old backup files" # Always asks for confirmation ``` -Paste with prompt and send content to model -``` -/paste Analyze this text -``` +### Database Mode -**Start Chatting:** -``` -You> Hello, how are you? -``` +```bash +/mcp add db ~/app/data.db # Add database +/mcp db 1 # Switch to database mode -**Attach Files:** +# Ask the AI: +"Show all tables" +"Find users created this month" +"What's the schema for the orders table?" ``` -You> Debug this code @/path/to/script.py -You> Analyze this image @/path/to/image.png -``` - -**Save Conversation:** -``` -/save my_conversation -``` - -**Export to File:** -``` -/export md notes.md -/export json backup.json -/export html report.html -``` - -**View Session Stats:** - -``` -/stats -/credits -``` - -**Prevous commands input:** - -Use the up/down arrows to see earlier `/`commands and earlier input to model and `` to execute the same command or resend the same input. ## Command Reference -Use `/help` within the application for a complete command reference organized by category: -- Session Commands -- Model Commands -- Configuration -- Token & System -- Conversation Management -- Monitoring & Stats -- File Attachments +### Chat Commands +| Command | Description | +|---------|-------------| +| `/help [cmd]` | Show help | +| `/model [search]` | Select model | +| `/info [model]` | Model details | +| `/memory on\|off` | Toggle context | +| `/online on\|off` | Toggle web search | +| `/retry` | Resend last message | +| `/clear` | Clear screen | -## Configuration Options +### MCP Commands +| Command | Description | +|---------|-------------| +| `/mcp on\|off` | Enable/disable MCP | +| `/mcp status` | Show MCP status | +| `/mcp add ` | Add folder | +| `/mcp add db ` | Add database | +| `/mcp list` | List folders | +| `/mcp db list` | List databases | +| `/mcp db ` | Switch to database | +| `/mcp files` | Switch to file mode | +| `/mcp write on\|off` | Toggle write mode | -- API Key: `/config api` -- Base URL: `/config url` -- Streaming: `/config stream on|off` -- Default Model: `/config model` -- Cost Warning: `/config costwarning ` -- Max Token Limit: `/config maxtoken ` +### Conversation Commands +| Command | Description | +|---------|-------------| +| `/save ` | Save conversation | +| `/load ` | Load conversation | +| `/list` | List saved conversations | +| `/delete ` | Delete conversation | +| `/export md\|json\|html ` | Export | -## File Support +### Configuration +| Command | Description | +|---------|-------------| +| `/config` | View settings | +| `/config api` | Set API key | +| `/config model ` | Set default model | +| `/config stream on\|off` | Toggle streaming | +| `/stats` | Session statistics | +| `/credits` | Check credits | -**Supported Code Extensions:** -.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm, .xml, .json, .yaml, .yml, .md, .txt +## CLI Options -**Image Support:** -Any image format with proper MIME type (PNG, JPEG, GIF, etc.) +```bash +oai chat [OPTIONS] -## Data Storage +Options: + -m, --model TEXT Model ID to use + -s, --system TEXT System prompt + -o, --online Enable online mode + --mcp Enable MCP server + --help Show help +``` -- Configuration: `~/.config/oai/oai_config.db` -- Logs: `~/.config/oai/oai.log` -- History: `~/.config/oai/history.txt` +Other commands: +```bash +oai config [setting] [value] # Configure settings +oai version # Show version +oai credits # Check credits +``` + +## Configuration + +Configuration is stored in `~/.config/oai/`: + +| File | Purpose | +|------|---------| +| `oai_config.db` | Settings, conversations, MCP config | +| `oai.log` | Application logs | +| `history.txt` | Command history | + +## Project Structure + +``` +oai/ +├── oai/ +│ ├── __init__.py +│ ├── __main__.py # Entry point for python -m oai +│ ├── cli.py # Main CLI interface +│ ├── constants.py # Configuration constants +│ ├── commands/ # Slash command handlers +│ ├── config/ # Settings and database +│ ├── core/ # Chat client and session +│ ├── mcp/ # MCP server and tools +│ ├── providers/ # AI provider abstraction +│ ├── ui/ # Terminal UI utilities +│ └── utils/ # Logging, export, etc. +├── pyproject.toml # Package configuration +├── build.sh # Binary build script +└── README.md +``` + +## Troubleshooting + +### macOS Binary Issues + +```bash +# Remove quarantine attribute +xattr -cr ~/.local/bin/oai + +# Then in Finder: right-click oai → Open With → Terminal → Click "Open" +# After this, oai works from any terminal +``` + +### MCP Not Working + +```bash +# Check if model supports function calling +/info # Look for "tools" in supported parameters + +# Check MCP status +/mcp status + +# View logs +tail -f ~/.config/oai/oai.log +``` + +### Import Errors + +```bash +# Reinstall package +pip install -e . --force-reinstall +``` + +## Version History + +### v2.1.0 (Current) +- 🏗️ Complete codebase refactoring to modular package structure +- 🔌 Extensible provider architecture for adding new AI providers +- 📦 Proper Python packaging with pyproject.toml +- ✨ MCP integration (file access, write mode, database queries) +- 🔧 Command registry pattern for slash commands +- 📊 Improved cost tracking and session statistics + +### v1.9.x +- Single-file implementation +- Core chat functionality +- File attachments +- Conversation management ## License -MIT License - -Copyright (c) 2024 Rune Olsen - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. - -Full license: https://opensource.org/licenses/MIT +MIT License - See [LICENSE](LICENSE) for details. ## Author **Rune Olsen** +- Project: https://iurl.no/oai +- Repository: https://gitlab.pm/rune/oai -Blog: https://blog.rune.pm +## Contributing -## Version +1. Fork the repository +2. Create a feature branch +3. Submit a pull request -1.0 +--- -## Support - -For issues, questions, or contributions, visit https://iurl.no/oai and create an issue. \ No newline at end of file +**⭐ Star this project if you find it useful!** diff --git a/oai.py b/oai.py deleted file mode 100644 index 26b92ef..0000000 --- a/oai.py +++ /dev/null @@ -1,2273 +0,0 @@ -#!/usr/bin/python3 -W ignore::DeprecationWarning -import sys -import os -import requests -import time # For response time tracking -from pathlib import Path -from typing import Optional, List, Dict, Any -import typer -from rich.console import Console -from rich.panel import Panel -from rich.table import Table -from rich.text import Text -from rich.markdown import Markdown -from rich.live import Live -from openrouter import OpenRouter -import pyperclip -import mimetypes -import base64 -import re -import sqlite3 -import json -import datetime -import logging -from logging.handlers import RotatingFileHandler # Added for log rotation -from prompt_toolkit import PromptSession -from prompt_toolkit.history import FileHistory -from rich.logging import RichHandler -from prompt_toolkit.auto_suggest import AutoSuggestFromHistory -from packaging import version as pkg_version -import io # Added for custom handler - -# App version. Changes by author with new releases. -version = '1.9.6' - -app = typer.Typer() - -# Application identification for OpenRouter -APP_NAME = "oAI" -APP_URL = "https://iurl.no/oai" - -# Paths -home = Path.home() -config_dir = home / '.config' / 'oai' -cache_dir = home / '.cache' / 'oai' -history_file = config_dir / 'history.txt' # Persistent input history file -database = config_dir / 'oai_config.db' -log_file = config_dir / 'oai.log' - -# Create dirs if needed -config_dir.mkdir(parents=True, exist_ok=True) -cache_dir.mkdir(parents=True, exist_ok=True) - -# Rich console for chat UI (separate from logging) -console = Console() - -# Valid commands list for validation -VALID_COMMANDS = { - '/retry', '/online', '/memory', '/paste', '/export', '/save', '/load', - '/delete', '/list', '/prev', '/next', '/stats', '/middleout', '/reset', - '/info', '/model', '/maxtoken', '/system', '/config', '/credits', '/clear', '/cl', '/help' -} - -# Detailed command help database -COMMAND_HELP = { - '/clear': { - 'aliases': ['/cl'], - 'description': 'Clear the terminal screen for a clean interface.', - 'usage': '/clear\n/cl', - 'examples': [ - ('Clear screen', '/clear'), - ('Using short alias', '/cl'), - ], - 'notes': 'You can also use the keyboard shortcut Ctrl+L to clear the screen.' - }, - '/help': { - 'description': 'Display help information for commands.', - 'usage': '/help [command]', - 'examples': [ - ('Show all commands', '/help'), - ('Get help for a specific command', '/help /model'), - ('Get help for config', '/help /config'), - ], - 'notes': 'Use /help without arguments to see the full command list, or /help for detailed information about a specific command.' - }, - '/memory': { - 'description': 'Toggle conversation memory. When ON, the AI remembers conversation history. When OFF, each request is independent (saves tokens and cost).', - 'usage': '/memory [on|off]', - 'examples': [ - ('Check current memory status', '/memory'), - ('Enable conversation memory', '/memory on'), - ('Disable memory (save costs)', '/memory off'), - ], - 'notes': 'Memory is ON by default. Disabling memory reduces API costs but the AI won\'t remember previous messages. Messages are still saved locally for your reference.' - }, - '/online': { - 'description': 'Enable or disable online mode (web search capabilities) for the current session.', - 'usage': '/online [on|off]', - 'examples': [ - ('Check online mode status', '/online'), - ('Enable web search', '/online on'), - ('Disable web search', '/online off'), - ], - 'notes': 'Not all models support online mode. The model must have "tools" parameter support. This setting overrides the default online mode configured with /config online.' - }, - '/paste': { - 'description': 'Paste plain text or code from clipboard and send to the AI. Optionally add a prompt.', - 'usage': '/paste [prompt]', - 'examples': [ - ('Paste clipboard content', '/paste'), - ('Paste with a question', '/paste Explain this code'), - ('Paste and ask for review', '/paste Review this for bugs'), - ], - 'notes': 'Only plain text is supported. Binary clipboard data will be rejected. The clipboard content is shown as a preview before sending.' - }, - '/retry': { - 'description': 'Resend the last prompt from conversation history.', - 'usage': '/retry', - 'examples': [ - ('Retry last message', '/retry'), - ], - 'notes': 'Useful when you get an error or want a different response to the same prompt. Requires at least one message in history.' - }, - '/next': { - 'description': 'View the next response in conversation history.', - 'usage': '/next', - 'examples': [ - ('Navigate to next response', '/next'), - ], - 'notes': 'Navigate through conversation history. Use /prev to go backward.' - }, - '/prev': { - 'description': 'View the previous response in conversation history.', - 'usage': '/prev', - 'examples': [ - ('Navigate to previous response', '/prev'), - ], - 'notes': 'Navigate through conversation history. Use /next to go forward.' - }, - '/reset': { - 'description': 'Clear conversation history and reset system prompt. This resets all session metrics.', - 'usage': '/reset', - 'examples': [ - ('Reset conversation', '/reset'), - ], - 'notes': 'Requires confirmation. This clears all message history, resets the system prompt, and resets token/cost counters. Use when starting a completely new conversation topic.' - }, - '/info': { - 'description': 'Display detailed information about a model including pricing, capabilities, context length, and online support.', - 'usage': '/info [model_id]', - 'examples': [ - ('Show current model info', '/info'), - ('Show specific model info', '/info gpt-4o'), - ('Check model capabilities', '/info claude-3-opus'), - ], - 'notes': 'Without arguments, shows info for the currently selected model. Displays pricing per million tokens, supported modalities (text, image, etc.), and parameter support.' - }, - '/model': { - 'description': 'Select or change the AI model for the current session. Shows image and online capabilities.', - 'usage': '/model [search_term]', - 'examples': [ - ('List all models', '/model'), - ('Search for GPT models', '/model gpt'), - ('Search for Claude models', '/model claude'), - ], - 'notes': 'Models are numbered for easy selection. The table shows Image (✓ if model accepts images) and Online (✓ if model supports web search) columns.' - }, - '/config': { - 'description': 'View or modify application configuration settings.', - 'usage': '/config [setting] [value]', - 'examples': [ - ('View all settings', '/config'), - ('Set API key', '/config api'), - ('Set default model', '/config model'), - ('Enable streaming', '/config stream on'), - ('Set cost warning threshold', '/config costwarning 0.05'), - ('Set log level', '/config loglevel debug'), - ('Set default online mode', '/config online on'), - ], - 'notes': 'Available settings: api (API key), url (base URL), model (default model), stream (on/off), costwarning (threshold $), maxtoken (limit), online (default on/off), log (size MB), loglevel (debug/info/warning/error/critical).' - }, - '/maxtoken': { - 'description': 'Set a temporary session token limit (cannot exceed stored max token limit).', - 'usage': '/maxtoken [value]', - 'examples': [ - ('View current session limit', '/maxtoken'), - ('Set session limit to 2000', '/maxtoken 2000'), - ('Set to 50000', '/maxtoken 50000'), - ], - 'notes': 'This is a session-only setting and cannot exceed the stored max token limit (set with /config maxtoken). View without arguments to see current value.' - }, - '/system': { - 'description': 'Set or clear the session-level system prompt to guide AI behavior.', - 'usage': '/system [prompt|clear]', - 'examples': [ - ('View current system prompt', '/system'), - ('Set as Python expert', '/system You are a Python expert'), - ('Set as code reviewer', '/system You are a senior code reviewer. Focus on bugs and best practices.'), - ('Clear system prompt', '/system clear'), - ], - 'notes': 'System prompts influence how the AI responds throughout the session. Use "clear" to remove the current system prompt.' - }, - '/save': { - 'description': 'Save the current conversation history to the database.', - 'usage': '/save ', - 'examples': [ - ('Save conversation', '/save my_chat'), - ('Save with descriptive name', '/save python_debugging_2024'), - ], - 'notes': 'Saved conversations can be loaded later with /load. Use descriptive names to easily find them later.' - }, - '/load': { - 'description': 'Load a saved conversation from the database by name or number.', - 'usage': '/load ', - 'examples': [ - ('Load by name', '/load my_chat'), - ('Load by number from /list', '/load 3'), - ], - 'notes': 'Use /list to see numbered conversations. Loading a conversation replaces current history and resets session metrics.' - }, - '/delete': { - 'description': 'Delete a saved conversation from the database. Requires confirmation.', - 'usage': '/delete ', - 'examples': [ - ('Delete by name', '/delete my_chat'), - ('Delete by number from /list', '/delete 3'), - ], - 'notes': 'Use /list to see numbered conversations. This action cannot be undone!' - }, - '/list': { - 'description': 'List all saved conversations with numbers, message counts, and timestamps.', - 'usage': '/list', - 'examples': [ - ('Show saved conversations', '/list'), - ], - 'notes': 'Conversations are numbered for easy use with /load and /delete commands. Shows message count and last saved time.' - }, - '/export': { - 'description': 'Export the current conversation to a file in various formats.', - 'usage': '/export ', - 'examples': [ - ('Export as Markdown', '/export md notes.md'), - ('Export as JSON', '/export json conversation.json'), - ('Export as HTML', '/export html report.html'), - ], - 'notes': 'Available formats: md (Markdown), json (JSON), html (HTML). The export includes all messages and the system prompt if set.' - }, - '/stats': { - 'description': 'Display session statistics including tokens, costs, and credits.', - 'usage': '/stats', - 'examples': [ - ('View session statistics', '/stats'), - ], - 'notes': 'Shows total input/output tokens, total cost, average cost per message, and remaining credits. Also displays credit warnings if applicable.' - }, - '/credits': { - 'description': 'Display your OpenRouter account credits and usage.', - 'usage': '/credits', - 'examples': [ - ('Check credits', '/credits'), - ], - 'notes': 'Shows total credits, used credits, and credits left. Displays warnings if credits are low (< $1 or < 10% of total).' - }, - '/middleout': { - 'description': 'Enable or disable middle-out transform to compress prompts exceeding context size.', - 'usage': '/middleout [on|off]', - 'examples': [ - ('Check status', '/middleout'), - ('Enable compression', '/middleout on'), - ('Disable compression', '/middleout off'), - ], - 'notes': 'Middle-out transform intelligently compresses long prompts to fit within model context limits. Useful for very long conversations.' - }, -} - -# Supported code file extensions -SUPPORTED_CODE_EXTENSIONS = { - '.py', '.js', '.ts', '.cs', '.java', '.c', '.cpp', '.h', '.hpp', - '.rb', '.ruby', '.php', '.swift', '.kt', '.kts', '.go', - '.sh', '.bat', '.ps1', '.R', '.scala', '.pl', '.lua', '.dart', - '.elm', '.xml', '.json', '.yaml', '.yml', '.md', '.txt' -} - -# Session metrics constants (per 1M tokens, in USD; adjustable) -MODEL_PRICING = { - 'input': 3.0, # $3/M input tokens (adjustable) - 'output': 15.0 # $15/M output tokens (adjustable) -} -LOW_CREDIT_RATIO = 0.1 # Warn if credits left < 10% of total -LOW_CREDIT_AMOUNT = 1.0 # Warn if credits left < $1 in absolute terms -HIGH_COST_WARNING = "cost_warning_threshold" # Configurable key for cost threshold, default $0.01 - -# Valid log levels mapping -VALID_LOG_LEVELS = { - 'debug': logging.DEBUG, - 'info': logging.INFO, - 'warning': logging.WARNING, - 'error': logging.ERROR, - 'critical': logging.CRITICAL -} - -# DB configuration -database = config_dir / 'oai_config.db' -DB_FILE = str(database) - -def create_table_if_not_exists(): - """Ensure the config and conversation_sessions tables exist.""" - os.makedirs(config_dir, exist_ok=True) - with sqlite3.connect(DB_FILE) as conn: - conn.execute('''CREATE TABLE IF NOT EXISTS config ( - key TEXT PRIMARY KEY, - value TEXT NOT NULL - )''') - conn.execute('''CREATE TABLE IF NOT EXISTS conversation_sessions ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - name TEXT NOT NULL, - timestamp TEXT NOT NULL, - data TEXT NOT NULL -- JSON of session_history - )''') - conn.commit() - -def get_config(key: str) -> Optional[str]: - create_table_if_not_exists() - with sqlite3.connect(DB_FILE) as conn: - cursor = conn.execute('SELECT value FROM config WHERE key = ?', (key,)) - result = cursor.fetchone() - return result[0] if result else None - -def set_config(key: str, value: str): - create_table_if_not_exists() - with sqlite3.connect(DB_FILE) as conn: - conn.execute('INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)', (key, value)) - conn.commit() - -# ============================================================================ -# ROTATING RICH HANDLER - Combines RotatingFileHandler with Rich formatting -# ============================================================================ -class RotatingRichHandler(RotatingFileHandler): - """Custom handler that combines RotatingFileHandler with Rich formatting.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - # Create a Rich console that writes to a string buffer - self.rich_console = Console(file=io.StringIO(), width=120, force_terminal=False) - self.rich_handler = RichHandler( - console=self.rich_console, - show_time=True, - show_path=True, - rich_tracebacks=True, - tracebacks_suppress=['requests', 'openrouter', 'urllib3', 'httpx', 'openai'] - ) - - def emit(self, record): - try: - # Let RichHandler format the record - self.rich_handler.emit(record) - - # Get the formatted output from the string buffer - output = self.rich_console.file.getvalue() - - # Clear the buffer for next use - self.rich_console.file.seek(0) - self.rich_console.file.truncate(0) - - # Write the Rich-formatted output to our rotating file - if output: - self.stream.write(output) - self.flush() - - except Exception: - self.handleError(record) - -# ============================================================================ -# LOGGING SETUP - MUST BE DONE AFTER CONFIG IS LOADED -# ============================================================================ - -# Load log configuration from DB FIRST (before creating handler) -LOG_MAX_SIZE_MB = int(get_config('log_max_size_mb') or "10") -LOG_BACKUP_COUNT = int(get_config('log_backup_count') or "2") -LOG_LEVEL_STR = get_config('log_level') or "info" -LOG_LEVEL = VALID_LOG_LEVELS.get(LOG_LEVEL_STR.lower(), logging.INFO) - -# Global reference to the handler for dynamic reloading -app_handler = None -app_logger = None - -def setup_logging(): - """Setup or reset logging configuration with current settings.""" - global app_handler, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, app_logger - - # Get the root logger - root_logger = logging.getLogger() - - # Remove existing handler if present - if app_handler is not None: - root_logger.removeHandler(app_handler) - try: - app_handler.close() - except: - pass - - # Check if log file needs immediate rotation - if os.path.exists(log_file): - current_size = os.path.getsize(log_file) - max_bytes = LOG_MAX_SIZE_MB * 1024 * 1024 - - if current_size >= max_bytes: - # Perform immediate rotation - import shutil - timestamp = datetime.datetime.now().strftime('%Y%m%d_%H%M%S') - backup_file = f"{log_file}.{timestamp}" - try: - shutil.move(str(log_file), backup_file) - except Exception as e: - print(f"Warning: Could not rotate log file: {e}") - - # Clean up old backups if exceeding limit - log_dir = os.path.dirname(log_file) - log_basename = os.path.basename(log_file) - backup_pattern = f"{log_basename}.*" - - import glob - backups = sorted(glob.glob(os.path.join(log_dir, backup_pattern))) - - # Keep only the most recent backups - while len(backups) > LOG_BACKUP_COUNT: - oldest = backups.pop(0) - try: - os.remove(oldest) - except: - pass - - # Create new handler with current settings - app_handler = RotatingRichHandler( - filename=str(log_file), - maxBytes=LOG_MAX_SIZE_MB * 1024 * 1024, - backupCount=LOG_BACKUP_COUNT, - encoding='utf-8' - ) - - # Set handler level to NOTSET so it processes all records - app_handler.setLevel(logging.NOTSET) - - # Configure root logger - set to WARNING to suppress third-party library noise - root_logger.setLevel(logging.WARNING) - root_logger.addHandler(app_handler) - - # Suppress noisy third-party loggers - # These libraries create DEBUG logs that pollute our log file - logging.getLogger('asyncio').setLevel(logging.WARNING) - logging.getLogger('urllib3').setLevel(logging.WARNING) - logging.getLogger('requests').setLevel(logging.WARNING) - logging.getLogger('httpx').setLevel(logging.WARNING) - logging.getLogger('httpcore').setLevel(logging.WARNING) - logging.getLogger('openai').setLevel(logging.WARNING) - logging.getLogger('openrouter').setLevel(logging.WARNING) - - # Get or create app logger and set its level (this filters what gets logged) - app_logger = logging.getLogger("oai_app") - app_logger.setLevel(LOG_LEVEL) - # Don't propagate to avoid root logger filtering - app_logger.propagate = True - - return app_logger - -# Initial logging setup -app_logger = setup_logging() - -def set_log_level(level_str: str) -> bool: - """Set the application log level. Returns True if successful.""" - global LOG_LEVEL, LOG_LEVEL_STR, app_logger - level_str_lower = level_str.lower() - if level_str_lower not in VALID_LOG_LEVELS: - return False - LOG_LEVEL = VALID_LOG_LEVELS[level_str_lower] - LOG_LEVEL_STR = level_str_lower - - # Update the logger level immediately - if app_logger: - app_logger.setLevel(LOG_LEVEL) - - return True - -def reload_logging_config(): - """Reload logging configuration from database and reinitialize handler.""" - global LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, LOG_LEVEL_STR, app_logger - - # Reload from database - LOG_MAX_SIZE_MB = int(get_config('log_max_size_mb') or "10") - LOG_BACKUP_COUNT = int(get_config('log_backup_count') or "2") - LOG_LEVEL_STR = get_config('log_level') or "info" - LOG_LEVEL = VALID_LOG_LEVELS.get(LOG_LEVEL_STR.lower(), logging.INFO) - - # Reinitialize logging - app_logger = setup_logging() - - return app_logger - -# ============================================================================ -# END OF LOGGING SETUP -# ============================================================================ - -logger = logging.getLogger(__name__) - -def check_for_updates(current_version: str) -> str: - """ - Check if a new version is available using semantic versioning. - - Returns: - Formatted status string for display - """ - try: - response = requests.get( - 'https://gitlab.pm/api/v1/repos/rune/oai/releases/latest', - headers={"Content-Type": "application/json"}, - timeout=1.0, - allow_redirects=True - ) - response.raise_for_status() - - data = response.json() - version_online = data.get('tag_name', '').lstrip('v') - - if not version_online: - logger.warning("No version found in API response") - return f"[bold green]oAI version {current_version}[/]" - - current = pkg_version.parse(current_version) - latest = pkg_version.parse(version_online) - - if latest > current: - logger.info(f"Update available: {current_version} → {version_online}") - return f"[bold green]oAI version {current_version} [/][bold red](Update available: {current_version} → {version_online})[/]" - else: - logger.debug(f"Already up to date: {current_version}") - return f"[bold green]oAI version {current_version} (up to date)[/]" - - except requests.exceptions.HTTPError as e: - logger.warning(f"HTTP error checking for updates: {e.response.status_code}") - return f"[bold green]oAI version {current_version}[/]" - except requests.exceptions.ConnectionError: - logger.warning("Network error checking for updates (offline?)") - return f"[bold green]oAI version {current_version}[/]" - except requests.exceptions.Timeout: - logger.warning("Timeout checking for updates") - return f"[bold green]oAI version {current_version}[/]" - except requests.exceptions.RequestException as e: - logger.warning(f"Request error checking for updates: {type(e).__name__}") - return f"[bold green]oAI version {current_version}[/]" - except (KeyError, ValueError) as e: - logger.warning(f"Invalid API response checking for updates: {e}") - return f"[bold green]oAI version {current_version}[/]" - except Exception as e: - logger.error(f"Unexpected error checking for updates: {e}") - return f"[bold green]oAI version {current_version}[/]" - -def save_conversation(name: str, data: List[Dict[str, str]]): - """Save conversation history to DB.""" - timestamp = datetime.datetime.now().isoformat() - data_json = json.dumps(data) - with sqlite3.connect(DB_FILE) as conn: - conn.execute('INSERT INTO conversation_sessions (name, timestamp, data) VALUES (?, ?, ?)', (name, timestamp, data_json)) - conn.commit() - -def load_conversation(name: str) -> Optional[List[Dict[str, str]]]: - """Load conversation history from DB (latest by timestamp).""" - with sqlite3.connect(DB_FILE) as conn: - cursor = conn.execute('SELECT data FROM conversation_sessions WHERE name = ? ORDER BY timestamp DESC LIMIT 1', (name,)) - result = cursor.fetchone() - if result: - return json.loads(result[0]) - return None - -def delete_conversation(name: str) -> int: - """Delete all conversation sessions with the given name. Returns number of deleted rows.""" - with sqlite3.connect(DB_FILE) as conn: - cursor = conn.execute('DELETE FROM conversation_sessions WHERE name = ?', (name,)) - conn.commit() - return cursor.rowcount - -def list_conversations() -> List[Dict[str, Any]]: - """List all saved conversations from DB with metadata.""" - with sqlite3.connect(DB_FILE) as conn: - cursor = conn.execute(''' - SELECT name, MAX(timestamp) as last_saved, data - FROM conversation_sessions - GROUP BY name - ORDER BY last_saved DESC - ''') - conversations = [] - for row in cursor.fetchall(): - name, timestamp, data_json = row - data = json.loads(data_json) - conversations.append({ - 'name': name, - 'timestamp': timestamp, - 'message_count': len(data) - }) - return conversations - -def estimate_cost(input_tokens: int, output_tokens: int) -> float: - """Estimate cost in USD based on token counts.""" - return (input_tokens * MODEL_PRICING['input'] / 1_000_000) + (output_tokens * MODEL_PRICING['output'] / 1_000_000) - -def has_web_search_capability(model: Dict[str, Any]) -> bool: - """Check if model supports web search based on supported_parameters.""" - supported_params = model.get("supported_parameters", []) - # Web search is typically indicated by 'tools' parameter support - return "tools" in supported_params - -def has_image_capability(model: Dict[str, Any]) -> bool: - """Check if model supports image input based on input modalities.""" - architecture = model.get("architecture", {}) - input_modalities = architecture.get("input_modalities", []) - return "image" in input_modalities - -def supports_online_mode(model: Dict[str, Any]) -> bool: - """Check if model supports :online suffix for web search.""" - # Models that support tools parameter can use :online - return has_web_search_capability(model) - -def get_effective_model_id(base_model_id: str, online_enabled: bool) -> str: - """Get the effective model ID with :online suffix if enabled.""" - if online_enabled and not base_model_id.endswith(':online'): - return f"{base_model_id}:online" - return base_model_id - -def export_as_markdown(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export conversation history as Markdown.""" - lines = ["# Conversation Export", ""] - if session_system_prompt: - lines.extend([f"**System Prompt:** {session_system_prompt}", ""]) - lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") - lines.append("") - lines.append("---") - lines.append("") - - for i, entry in enumerate(session_history, 1): - lines.append(f"## Message {i}") - lines.append("") - lines.append("**User:**") - lines.append("") - lines.append(entry['prompt']) - lines.append("") - lines.append("**Assistant:**") - lines.append("") - lines.append(entry['response']) - lines.append("") - lines.append("---") - lines.append("") - - return "\n".join(lines) - -def export_as_json(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export conversation history as JSON.""" - export_data = { - "export_date": datetime.datetime.now().isoformat(), - "system_prompt": session_system_prompt, - "message_count": len(session_history), - "messages": session_history - } - return json.dumps(export_data, indent=2, ensure_ascii=False) - -def export_as_html(session_history: List[Dict[str, str]], session_system_prompt: str = "") -> str: - """Export conversation history as HTML.""" - # Escape HTML special characters - def escape_html(text): - return text.replace('&', '&').replace('<', '<').replace('>', '>').replace('"', '"').replace("'", ''') - - html_parts = [ - "", - "", - "", - " ", - " ", - " Conversation Export", - " ", - "", - "", - "
", - "

💬 Conversation Export

", - f"
📅 Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
", - f"
📊 Total Messages: {len(session_history)}
", - "
", - ] - - if session_system_prompt: - html_parts.extend([ - "
", - " ⚙️ System Prompt", - f"
{escape_html(session_system_prompt)}
", - "
", - ]) - - for i, entry in enumerate(session_history, 1): - html_parts.extend([ - "
", - f"
Message {i} of {len(session_history)}
", - "
", - "
👤 User
", - f"
{escape_html(entry['prompt'])}
", - "
", - "
", - "
🤖 Assistant
", - f"
{escape_html(entry['response'])}
", - "
", - "
", - ]) - - html_parts.extend([ - "
", - "

Generated by oAI Chat • https://iurl.no/oai

", - "
", - "", - "", - ]) - - return "\n".join(html_parts) - -def show_command_help(command: str): - """Display detailed help for a specific command.""" - # Normalize command to ensure it starts with / - if not command.startswith('/'): - command = '/' + command - - # Check if command exists - if command not in COMMAND_HELP: - console.print(f"[bold red]Unknown command: {command}[/]") - console.print("[bold yellow]Type /help to see all available commands.[/]") - app_logger.warning(f"Help requested for unknown command: {command}") - return - - help_data = COMMAND_HELP[command] - - # Create detailed help panel - help_content = [] - - # Aliases if available - if 'aliases' in help_data: - aliases_str = ", ".join(help_data['aliases']) - help_content.append(f"[bold cyan]Aliases:[/] {aliases_str}") - help_content.append("") - - # Description - help_content.append(f"[bold cyan]Description:[/]") - help_content.append(help_data['description']) - help_content.append("") - - # Usage - help_content.append(f"[bold cyan]Usage:[/]") - help_content.append(f"[yellow]{help_data['usage']}[/]") - help_content.append("") - - # Examples - if 'examples' in help_data and help_data['examples']: - help_content.append(f"[bold cyan]Examples:[/]") - for desc, example in help_data['examples']: - help_content.append(f" [dim]{desc}:[/]") - help_content.append(f" [green]{example}[/]") - help_content.append("") - - # Notes - if 'notes' in help_data: - help_content.append(f"[bold cyan]Notes:[/]") - help_content.append(f"[dim]{help_data['notes']}[/]") - - console.print(Panel( - "\n".join(help_content), - title=f"[bold green]Help: {command}[/]", - title_align="left", - border_style="green", - width=console.width - 4 - )) - - app_logger.info(f"Displayed detailed help for command: {command}") - -# Load configs (AFTER logging is set up) -API_KEY = get_config('api_key') -OPENROUTER_BASE_URL = get_config('base_url') or "https://openrouter.ai/api/v1" -STREAM_ENABLED = get_config('stream_enabled') or "on" -DEFAULT_MODEL_ID = get_config('default_model') -MAX_TOKEN = int(get_config('max_token') or "100000") -COST_WARNING_THRESHOLD = float(get_config(HIGH_COST_WARNING) or "0.01") # Configurable cost threshold for alerts -DEFAULT_ONLINE_MODE = get_config('default_online_mode') or "off" # New: Default online mode setting - -# Fetch models with app identification headers -models_data = [] -text_models = [] -try: - headers = { - "Authorization": f"Bearer {API_KEY}", - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } if API_KEY else { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - response = requests.get(f"{OPENROUTER_BASE_URL}/models", headers=headers) - response.raise_for_status() - models_data = response.json()["data"] - text_models = [m for m in models_data if "modalities" not in m or "video" not in (m.get("modalities") or [])] - selected_model_default = None - if DEFAULT_MODEL_ID: - selected_model_default = next((m for m in text_models if m["id"] == DEFAULT_MODEL_ID), None) - if not selected_model_default: - console.print(f"[bold yellow]Warning: Default model '{DEFAULT_MODEL_ID}' unavailable. Use '/config model'.[/]") -except Exception as e: - models_data = [] - text_models = [] - app_logger.error(f"Failed to fetch models: {e}") - -def get_credits(api_key: str, base_url: str = OPENROUTER_BASE_URL) -> Optional[Dict[str, str]]: - if not api_key: - return None - url = f"{base_url}/credits" - headers = { - "Authorization": f"Bearer {api_key}", - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - try: - response = requests.get(url, headers=headers) - response.raise_for_status() - data = response.json().get('data', {}) - total_credits = float(data.get('total_credits', 0)) - total_usage = float(data.get('total_usage', 0)) - credits_left = total_credits - total_usage - return { - 'total_credits': f"${total_credits:.2f}", - 'used_credits': f"${total_usage:.2f}", - 'credits_left': f"${credits_left:.2f}" - } - except Exception as e: - console.print(f"[bold red]Error fetching credits: {e}[/]") - return None - -def check_credit_alerts(credits_data: Optional[Dict[str, str]]) -> List[str]: - """Check and return list of credit-related alerts.""" - alerts = [] - if credits_data: - credits_left_value = float(credits_data['credits_left'].strip('$')) - total_credits_value = float(credits_data['total_credits'].strip('$')) - if credits_left_value < LOW_CREDIT_AMOUNT: - alerts.append(f"Critical credit alert: Less than ${LOW_CREDIT_AMOUNT:.2f} left ({credits_data['credits_left']})") - elif credits_left_value < total_credits_value * LOW_CREDIT_RATIO: - alerts.append(f"Low credit alert: Credits left < 10% of total ({credits_data['credits_left']})") - return alerts - -def clear_screen(): - try: - print("\033[H\033[J", end="", flush=True) - except: - print("\n" * 100) - -def display_paginated_table(table: Table, title: str): - """Display a table with pagination support using Rich console for colored output, repeating header on each page.""" - # Get terminal height (subtract some lines for prompt and margins) - try: - terminal_height = os.get_terminal_size().lines - 8 - except: - terminal_height = 20 # Fallback if terminal size can't be determined - - # Create a segment-based approach to capture Rich-rendered output - from rich.segment import Segment - - # Render the table to segments - segments = list(console.render(table)) - - # Convert segments to lines while preserving style - current_line_segments = [] - all_lines = [] - - for segment in segments: - if segment.text == '\n': - all_lines.append(current_line_segments) - current_line_segments = [] - else: - current_line_segments.append(segment) - - # Add last line if not empty - if current_line_segments: - all_lines.append(current_line_segments) - - total_lines = len(all_lines) - - # If fits on one screen after segment analysis - if total_lines <= terminal_height: - console.print(Panel(table, title=title, title_align="left")) - return - - # Separate header from data rows - header_lines = [] - data_lines = [] - - # Find where the header ends - header_end_index = 0 - found_header_text = False - - for i, line_segments in enumerate(all_lines): - # Check if this line contains header-style text - has_header_style = any( - seg.style and ('bold' in str(seg.style) or 'magenta' in str(seg.style)) - for seg in line_segments - ) - - if has_header_style: - found_header_text = True - - # After finding header text, the next line with box-drawing chars is the separator - if found_header_text and i > 0: - line_text = ''.join(seg.text for seg in line_segments) - if any(char in line_text for char in ['─', '━', '┼', '╪', '┤', '├']): - header_end_index = i - break - - # If we found a header separator, split there - if header_end_index > 0: - header_lines = all_lines[:header_end_index + 1] - data_lines = all_lines[header_end_index + 1:] - else: - # Fallback: assume first 3 lines are header - header_lines = all_lines[:min(3, len(all_lines))] - data_lines = all_lines[min(3, len(all_lines)):] - - # Calculate how many data lines fit per page - lines_per_page = terminal_height - len(header_lines) - - # Display with pagination - current_line = 0 - page_number = 1 - - while current_line < len(data_lines): - # Clear screen for each page - clear_screen() - - # Print title - console.print(f"[bold cyan]{title} (Page {page_number})[/]") - - # Print header on every page - for line_segments in header_lines: - for segment in line_segments: - console.print(segment.text, style=segment.style, end="") - console.print() - - # Calculate how many data lines to show on this page - end_line = min(current_line + lines_per_page, len(data_lines)) - - # Print data lines for this page - for line_segments in data_lines[current_line:end_line]: - for segment in line_segments: - console.print(segment.text, style=segment.style, end="") - console.print() - - # Update position - current_line = end_line - page_number += 1 - - # If there's more content, wait for user - if current_line < len(data_lines): - console.print(f"\n[dim yellow]--- Press SPACE for next page, or any other key to finish (Page {page_number - 1}, showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]") - try: - import sys - import tty - import termios - - fd = sys.stdin.fileno() - old_settings = termios.tcgetattr(fd) - - try: - tty.setraw(fd) - char = sys.stdin.read(1) - - if char != ' ': - break - finally: - termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - except: - # Fallback for Windows or if termios not available - input_char = input().strip() - if input_char != '': - break - else: - break - -@app.command() -def chat(): - global API_KEY, OPENROUTER_BASE_URL, STREAM_ENABLED, MAX_TOKEN, COST_WARNING_THRESHOLD, DEFAULT_ONLINE_MODE, LOG_MAX_SIZE_MB, LOG_BACKUP_COUNT, LOG_LEVEL, LOG_LEVEL_STR, app_logger - session_max_token = 0 - session_system_prompt = "" - session_history = [] - current_index = -1 - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - middle_out_enabled = False # Session-level middle-out transform flag - conversation_memory_enabled = True # Memory ON by default - memory_start_index = 0 # Track when memory was last enabled - saved_conversations_cache = [] # Cache for /list results to use with /load by number - online_mode_enabled = DEFAULT_ONLINE_MODE == "on" # Initialize from config - - app_logger.info("Starting new chat session with memory enabled") - - if not API_KEY: - console.print("[bold red]API key not found. Use '/config api'.[/]") - try: - new_api_key = typer.prompt("Enter API key") - if new_api_key.strip(): - set_config('api_key', new_api_key.strip()) - API_KEY = new_api_key.strip() - console.print("[bold green]API key saved. Re-run.[/]") - else: - raise typer.Exit() - except: - console.print("[bold red]No API key. Exiting.[/]") - raise typer.Exit() - - if not text_models: - console.print("[bold red]No models available. Check API key/URL.[/]") - raise typer.Exit() - - # Check for credit alerts at startup - credits_data = get_credits(API_KEY, OPENROUTER_BASE_URL) - startup_credit_alerts = check_credit_alerts(credits_data) - if startup_credit_alerts: - startup_alert_msg = " | ".join(startup_credit_alerts) - console.print(f"[bold red]⚠️ Startup {startup_alert_msg}[/]") - app_logger.warning(f"Startup credit alerts: {startup_alert_msg}") - - selected_model = selected_model_default - - # Initialize OpenRouter client - client = OpenRouter(api_key=API_KEY) - - if selected_model: - online_status = "enabled" if online_mode_enabled else "disabled" - console.print(f"[bold blue]Welcome to oAI![/] [bold red]Active model: {selected_model['name']}[/] [dim cyan](Online mode: {online_status})[/]") - else: - console.print("[bold blue]Welcome to oAI![/] [italic blue]Select a model with '/model'.[/]") - - if not selected_model: - console.print("[bold yellow]No model selected. Use '/model'.[/]") - - # Persistent input history - session = PromptSession(history=FileHistory(str(history_file))) - - while True: - try: - user_input = session.prompt("You> ", auto_suggest=AutoSuggestFromHistory()).strip() - - # Handle // escape sequence - convert to single / and treat as regular text - if user_input.startswith("//"): - user_input = user_input[1:] # Remove first slash, keep the rest - - # Check for unknown commands - elif user_input.startswith("/") and user_input.lower() not in ["exit", "quit", "bye"]: - command_word = user_input.split()[0].lower() if user_input.split() else user_input.lower() - - if not any(command_word.startswith(cmd) for cmd in VALID_COMMANDS): - console.print(f"[bold red]Unknown command: {command_word}[/]") - console.print("[bold yellow]Type /help to see all available commands.[/]") - app_logger.warning(f"Unknown command attempted: {command_word}") - continue - - if user_input.lower() in ["exit", "quit", "bye"]: - total_tokens = total_input_tokens + total_output_tokens - app_logger.info(f"Session ended. Total messages: {message_count}, Total tokens: {total_tokens}, Total cost: ${total_cost:.4f}") - console.print("[bold yellow]Goodbye![/]") - return - - # Commands with logging - if user_input.lower() == "/retry": - if not session_history: - console.print("[bold red]No history to retry.[/]") - app_logger.warning("Retry attempted with no history") - continue - last_prompt = session_history[-1]['prompt'] - console.print("[bold green]Retrying last prompt...[/]") - app_logger.info(f"Retrying prompt: {last_prompt[:100]}...") - user_input = last_prompt - elif user_input.lower().startswith("/online"): - args = user_input[8:].strip() - if not args: - status = "enabled" if online_mode_enabled else "disabled" - default_status = "enabled" if DEFAULT_ONLINE_MODE == "on" else "disabled" - console.print(f"[bold blue]Online mode (web search) {status}.[/]") - console.print(f"[dim blue]Default setting: {default_status} (use '/config online on|off' to change)[/]") - if selected_model: - if supports_online_mode(selected_model): - console.print(f"[dim green]Current model '{selected_model['name']}' supports online mode.[/]") - else: - console.print(f"[dim yellow]Current model '{selected_model['name']}' does not support online mode.[/]") - continue - if args.lower() == "on": - if not selected_model: - console.print("[bold red]No model selected. Select a model first with '/model'.[/]") - continue - if not supports_online_mode(selected_model): - console.print(f"[bold red]Model '{selected_model['name']}' does not support online mode (web search).[/]") - console.print("[dim yellow]Online mode requires models with 'tools' parameter support.[/]") - app_logger.warning(f"Online mode activation failed - model {selected_model['id']} doesn't support it") - continue - online_mode_enabled = True - console.print("[bold green]Online mode enabled for this session. Model will use web search capabilities.[/]") - console.print(f"[dim blue]Effective model ID: {get_effective_model_id(selected_model['id'], True)}[/]") - app_logger.info(f"Online mode enabled for model {selected_model['id']}") - elif args.lower() == "off": - online_mode_enabled = False - console.print("[bold green]Online mode disabled for this session. Model will not use web search.[/]") - if selected_model: - console.print(f"[dim blue]Effective model ID: {selected_model['id']}[/]") - app_logger.info("Online mode disabled") - else: - console.print("[bold yellow]Usage: /online on|off (or /online to view status)[/]") - continue - elif user_input.lower().startswith("/memory"): - args = user_input[8:].strip() - if not args: - status = "enabled" if conversation_memory_enabled else "disabled" - history_count = len(session_history) - memory_start_index if conversation_memory_enabled and memory_start_index < len(session_history) else 0 - console.print(f"[bold blue]Conversation memory {status}.[/]") - if conversation_memory_enabled: - console.print(f"[dim blue]Tracking {history_count} message(s) since memory enabled.[/]") - else: - console.print(f"[dim yellow]Memory disabled. Each request is independent (saves tokens/cost).[/]") - continue - if args.lower() == "on": - conversation_memory_enabled = True - memory_start_index = len(session_history) - console.print("[bold green]Conversation memory enabled. Will remember conversations from this point forward.[/]") - console.print(f"[dim blue]Memory will track messages starting from index {memory_start_index}.[/]") - app_logger.info(f"Conversation memory enabled at index {memory_start_index}") - elif args.lower() == "off": - conversation_memory_enabled = False - console.print("[bold green]Conversation memory disabled. API calls will not include history (lower cost).[/]") - console.print(f"[dim yellow]Note: Messages are still saved locally but not sent to API.[/]") - app_logger.info("Conversation memory disabled") - else: - console.print("[bold yellow]Usage: /memory on|off (or /memory to view status)[/]") - continue - elif user_input.lower().startswith("/paste"): - optional_prompt = user_input[7:].strip() - - try: - clipboard_content = pyperclip.paste() - except Exception as e: - console.print(f"[bold red]Failed to access clipboard: {e}[/]") - app_logger.error(f"Clipboard access error: {e}") - continue - - if not clipboard_content or not clipboard_content.strip(): - console.print("[bold red]Clipboard is empty.[/]") - app_logger.warning("Paste attempted with empty clipboard") - continue - - try: - clipboard_content.encode('utf-8') - - preview_lines = clipboard_content.split('\n')[:10] - preview_text = '\n'.join(preview_lines) - if len(clipboard_content.split('\n')) > 10: - preview_text += "\n... (content truncated for preview)" - - char_count = len(clipboard_content) - line_count = len(clipboard_content.split('\n')) - - console.print(Panel( - preview_text, - title=f"[bold cyan]📋 Clipboard Content Preview ({char_count} chars, {line_count} lines)[/]", - title_align="left", - border_style="cyan" - )) - - if optional_prompt: - final_prompt = f"{optional_prompt}\n\n```\n{clipboard_content}\n```" - console.print(f"[dim blue]Sending with prompt: '{optional_prompt}'[/]") - else: - final_prompt = clipboard_content - console.print("[dim blue]Sending clipboard content without additional prompt[/]") - - user_input = final_prompt - app_logger.info(f"Pasted content from clipboard: {char_count} chars, {line_count} lines, with prompt: {bool(optional_prompt)}") - - except UnicodeDecodeError: - console.print("[bold red]Clipboard contains non-text (binary) data. Only plain text is supported.[/]") - app_logger.error("Paste failed - clipboard contains binary data") - continue - except Exception as e: - console.print(f"[bold red]Error processing clipboard content: {e}[/]") - app_logger.error(f"Clipboard processing error: {e}") - continue - - elif user_input.lower().startswith("/export"): - args = user_input[8:].strip().split(maxsplit=1) - if len(args) != 2: - console.print("[bold red]Usage: /export [/]") - console.print("[bold yellow]Formats: md (Markdown), json (JSON), html (HTML)[/]") - console.print("[bold yellow]Example: /export md my_conversation.md[/]") - continue - - export_format = args[0].lower() - filename = args[1] - - if not session_history: - console.print("[bold red]No conversation history to export.[/]") - continue - - if export_format not in ['md', 'json', 'html']: - console.print("[bold red]Invalid format. Use: md, json, or html[/]") - continue - - try: - if export_format == 'md': - content = export_as_markdown(session_history, session_system_prompt) - elif export_format == 'json': - content = export_as_json(session_history, session_system_prompt) - elif export_format == 'html': - content = export_as_html(session_history, session_system_prompt) - - export_path = Path(filename).expanduser() - with open(export_path, 'w', encoding='utf-8') as f: - f.write(content) - - console.print(f"[bold green]✅ Conversation exported to: {export_path.absolute()}[/]") - console.print(f"[dim blue]Format: {export_format.upper()} | Messages: {len(session_history)} | Size: {len(content)} bytes[/]") - app_logger.info(f"Conversation exported as {export_format} to {export_path} ({len(session_history)} messages)") - except Exception as e: - console.print(f"[bold red]Export failed: {e}[/]") - app_logger.error(f"Export error: {e}") - continue - elif user_input.lower().startswith("/save"): - args = user_input[6:].strip() - if not args: - console.print("[bold red]Usage: /save [/]") - continue - if not session_history: - console.print("[bold red]No history to save.[/]") - continue - save_conversation(args, session_history) - console.print(f"[bold green]Conversation saved as '{args}'.[/]") - app_logger.info(f"Conversation saved as '{args}' with {len(session_history)} messages") - continue - elif user_input.lower().startswith("/load"): - args = user_input[6:].strip() - if not args: - console.print("[bold red]Usage: /load [/]") - console.print("[bold yellow]Tip: Use /list to see numbered conversations[/]") - continue - - conversation_name = None - if args.isdigit(): - conv_number = int(args) - if saved_conversations_cache and 1 <= conv_number <= len(saved_conversations_cache): - conversation_name = saved_conversations_cache[conv_number - 1]['name'] - console.print(f"[bold cyan]Loading conversation #{conv_number}: '{conversation_name}'[/]") - else: - console.print(f"[bold red]Invalid conversation number: {conv_number}[/]") - console.print(f"[bold yellow]Use /list to see available conversations (1-{len(saved_conversations_cache) if saved_conversations_cache else 0})[/]") - continue - else: - conversation_name = args - - loaded_data = load_conversation(conversation_name) - if not loaded_data: - console.print(f"[bold red]Conversation '{conversation_name}' not found.[/]") - app_logger.warning(f"Load failed for '{conversation_name}' - not found") - continue - session_history = loaded_data - current_index = len(session_history) - 1 - if conversation_memory_enabled: - memory_start_index = 0 - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - console.print(f"[bold green]Conversation '{conversation_name}' loaded with {len(session_history)} messages.[/]") - app_logger.info(f"Conversation '{conversation_name}' loaded with {len(session_history)} messages") - continue - elif user_input.lower().startswith("/delete"): - args = user_input[8:].strip() - if not args: - console.print("[bold red]Usage: /delete [/]") - console.print("[bold yellow]Tip: Use /list to see numbered conversations[/]") - continue - - conversation_name = None - if args.isdigit(): - conv_number = int(args) - if saved_conversations_cache and 1 <= conv_number <= len(saved_conversations_cache): - conversation_name = saved_conversations_cache[conv_number - 1]['name'] - console.print(f"[bold cyan]Deleting conversation #{conv_number}: '{conversation_name}'[/]") - else: - console.print(f"[bold red]Invalid conversation number: {conv_number}[/]") - console.print(f"[bold yellow]Use /list to see available conversations (1-{len(saved_conversations_cache) if saved_conversations_cache else 0})[/]") - continue - else: - conversation_name = args - - try: - confirm = typer.confirm(f"Delete conversation '{conversation_name}'? This cannot be undone.", default=False) - if not confirm: - console.print("[bold yellow]Deletion cancelled.[/]") - continue - except (EOFError, KeyboardInterrupt): - console.print("\n[bold yellow]Deletion cancelled.[/]") - continue - - deleted_count = delete_conversation(conversation_name) - if deleted_count > 0: - console.print(f"[bold green]Conversation '{conversation_name}' deleted ({deleted_count} version(s) removed).[/]") - app_logger.info(f"Conversation '{conversation_name}' deleted - {deleted_count} version(s)") - if saved_conversations_cache: - saved_conversations_cache = [c for c in saved_conversations_cache if c['name'] != conversation_name] - else: - console.print(f"[bold red]Conversation '{conversation_name}' not found.[/]") - app_logger.warning(f"Delete failed for '{conversation_name}' - not found") - continue - elif user_input.lower() == "/list": - conversations = list_conversations() - if not conversations: - console.print("[bold yellow]No saved conversations found.[/]") - app_logger.info("User viewed conversation list - empty") - saved_conversations_cache = [] - continue - - saved_conversations_cache = conversations - - table = Table("No.", "Name", "Messages", "Last Saved", show_header=True, header_style="bold magenta") - for idx, conv in enumerate(conversations, 1): - try: - dt = datetime.datetime.fromisoformat(conv['timestamp']) - formatted_time = dt.strftime('%Y-%m-%d %H:%M:%S') - except: - formatted_time = conv['timestamp'] - - table.add_row( - str(idx), - conv['name'], - str(conv['message_count']), - formatted_time - ) - - console.print(Panel(table, title=f"[bold green]Saved Conversations ({len(conversations)} total)[/]", title_align="left", subtitle="[dim]Use /load or /delete to manage conversations[/]", subtitle_align="right")) - app_logger.info(f"User viewed conversation list - {len(conversations)} conversations") - continue - elif user_input.lower() == "/prev": - if not session_history or current_index <= 0: - console.print("[bold red]No previous response.[/]") - continue - current_index -= 1 - prev_response = session_history[current_index]['response'] - md = Markdown(prev_response) - console.print(Panel(md, title=f"[bold green]Previous Response ({current_index + 1}/{len(session_history)})[/]", title_align="left")) - app_logger.debug(f"Viewed previous response at index {current_index}") - continue - elif user_input.lower() == "/next": - if not session_history or current_index >= len(session_history) - 1: - console.print("[bold red]No next response.[/]") - continue - current_index += 1 - next_response = session_history[current_index]['response'] - md = Markdown(next_response) - console.print(Panel(md, title=f"[bold green]Next Response ({current_index + 1}/{len(session_history)})[/]", title_align="left")) - app_logger.debug(f"Viewed next response at index {current_index}") - continue - elif user_input.lower() == "/stats": - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - credits_left = credits['credits_left'] if credits else "Unknown" - stats = f"Total Input: {total_input_tokens}, Total Output: {total_output_tokens}, Total Tokens: {total_input_tokens + total_output_tokens}, Total Cost: ${total_cost:.4f}, Avg Cost/Message: ${total_cost / message_count:.4f}" if message_count > 0 else "No messages." - table = Table("Metric", "Value", show_header=True, header_style="bold magenta") - table.add_row("Session Stats", stats) - table.add_row("Credits Left", credits_left) - console.print(Panel(table, title="[bold green]Session Cost Summary[/]", title_align="left")) - app_logger.info(f"User viewed stats: {stats}") - - warnings = check_credit_alerts(credits) - if warnings: - warning_text = '|'.join(warnings) - console.print(f"[bold red]⚠️ {warning_text}[/]") - app_logger.warning(f"Warnings in stats: {warning_text}") - continue - elif user_input.lower().startswith("/middleout"): - args = user_input[11:].strip() - if not args: - console.print(f"[bold blue]Middle-out transform {'enabled' if middle_out_enabled else 'disabled'}.[/]") - continue - if args.lower() == "on": - middle_out_enabled = True - console.print("[bold green]Middle-out transform enabled.[/]") - elif args.lower() == "off": - middle_out_enabled = False - console.print("[bold green]Middle-out transform disabled.[/]") - else: - console.print("[bold yellow]Usage: /middleout on|off (or /middleout to view status)[/]") - continue - elif user_input.lower() == "/reset": - try: - confirm = typer.confirm("Reset conversation context? This clears history and prompt.", default=False) - if not confirm: - console.print("[bold yellow]Reset cancelled.[/]") - continue - except (EOFError, KeyboardInterrupt): - console.print("\n[bold yellow]Reset cancelled.[/]") - continue - session_history = [] - current_index = -1 - session_system_prompt = "" - memory_start_index = 0 - total_input_tokens = 0 - total_output_tokens = 0 - total_cost = 0.0 - message_count = 0 - console.print("[bold green]Conversation context reset.[/]") - app_logger.info("Conversation context reset by user") - continue - - elif user_input.lower().startswith("/info"): - args = user_input[6:].strip() - if not args: - if not selected_model: - console.print("[bold red]No model selected and no model ID provided. Use '/model' first or '/info '.[/]") - continue - model_to_show = selected_model - else: - model_to_show = next((m for m in models_data if m["id"] == args or m.get("canonical_slug") == args or args.lower() in m["name"].lower()), None) - if not model_to_show: - console.print(f"[bold red]Model '{args}' not found.[/]") - continue - - pricing = model_to_show.get("pricing", {}) - architecture = model_to_show.get("architecture", {}) - supported_params = ", ".join(model_to_show.get("supported_parameters", [])) or "None" - top_provider = model_to_show.get("top_provider", {}) - - table = Table("Property", "Value", show_header=True, header_style="bold magenta") - table.add_row("ID", model_to_show["id"]) - table.add_row("Name", model_to_show["name"]) - table.add_row("Description", model_to_show.get("description", "N/A")) - table.add_row("Context Length", str(model_to_show.get("context_length", "N/A"))) - table.add_row("Online Support", "Yes" if supports_online_mode(model_to_show) else "No") - table.add_row("Pricing - Prompt ($/M tokens)", pricing.get("prompt", "N/A")) - table.add_row("Pricing - Completion ($/M tokens)", pricing.get("completion", "N/A")) - table.add_row("Pricing - Request ($)", pricing.get("request", "N/A")) - table.add_row("Pricing - Image ($)", pricing.get("image", "N/A")) - table.add_row("Input Modalities", ", ".join(architecture.get("input_modalities", [])) or "None") - table.add_row("Output Modalities", ", ".join(architecture.get("output_modalities", [])) or "None") - table.add_row("Supported Parameters", supported_params) - table.add_row("Top Provider Context Length", str(top_provider.get("context_length", "N/A"))) - table.add_row("Max Completion Tokens", str(top_provider.get("max_completion_tokens", "N/A"))) - table.add_row("Moderated", "Yes" if top_provider.get("is_moderated", False) else "No") - - console.print(Panel(table, title=f"[bold green]Model Info: {model_to_show['name']}[/]", title_align="left")) - continue - - elif user_input.startswith("/model"): - app_logger.info("User initiated model selection") - args = user_input[7:].strip() - search_term = args if args else "" - filtered_models = text_models - if search_term: - filtered_models = [m for m in text_models if search_term.lower() in m["name"].lower() or search_term.lower() in m["id"].lower()] - if not filtered_models: - console.print(f"[bold red]No models match '{search_term}'. Try '/model'.[/]") - continue - - table = Table("No.", "Name", "ID", "Image", "Online", show_header=True, header_style="bold magenta") - for i, model in enumerate(filtered_models, 1): - image_support = "[green]✓[/green]" if has_image_capability(model) else "[red]✗[/red]" - online_support = "[green]✓[/green]" if supports_online_mode(model) else "[red]✗[/red]" - table.add_row(str(i), model["name"], model["id"], image_support, online_support) - - title = f"[bold green]Available Models ({'All' if not search_term else f'Search: {search_term}'})[/]" - display_paginated_table(table, title) - - while True: - try: - choice = int(typer.prompt("Enter model number (or 0 to cancel)")) - if choice == 0: - break - if 1 <= choice <= len(filtered_models): - selected_model = filtered_models[choice - 1] - - # Apply default online mode if model supports it - if supports_online_mode(selected_model) and DEFAULT_ONLINE_MODE == "on": - online_mode_enabled = True - console.print(f"[bold cyan]Selected: {selected_model['name']} ({selected_model['id']})[/]") - console.print("[dim cyan]✓ Online mode auto-enabled (default setting). Use '/online off' to disable for this session.[/]") - else: - online_mode_enabled = False - console.print(f"[bold cyan]Selected: {selected_model['name']} ({selected_model['id']})[/]") - if supports_online_mode(selected_model): - console.print("[dim green]✓ This model supports online mode. Use '/online on' to enable web search.[/]") - - app_logger.info(f"Model selected: {selected_model['name']} ({selected_model['id']}), Online: {online_mode_enabled}") - break - console.print("[bold red]Invalid choice. Try again.[/]") - except ValueError: - console.print("[bold red]Invalid input. Enter a number.[/]") - continue - - elif user_input.startswith("/maxtoken"): - args = user_input[10:].strip() - if not args: - console.print(f"[bold blue]Current session max tokens: {session_max_token}[/]") - continue - try: - new_limit = int(args) - if new_limit < 1: - console.print("[bold red]Session token limit must be at least 1.[/]") - continue - if new_limit > MAX_TOKEN: - console.print(f"[bold yellow]Cannot exceed stored max ({MAX_TOKEN}). Capping.[/]") - new_limit = MAX_TOKEN - session_max_token = new_limit - console.print(f"[bold green]Session max tokens set to: {session_max_token}[/]") - except ValueError: - console.print("[bold red]Invalid token limit. Provide a positive integer.[/]") - continue - - elif user_input.startswith("/system"): - args = user_input[8:].strip() - if not args: - if session_system_prompt: - console.print(f"[bold blue]Current session system prompt:[/] {session_system_prompt}") - else: - console.print("[bold blue]No session system prompt set.[/]") - continue - if args.lower() == "clear": - session_system_prompt = "" - console.print("[bold green]Session system prompt cleared.[/]") - else: - session_system_prompt = args - console.print(f"[bold green]Session system prompt set to: {session_system_prompt}[/]") - continue - - elif user_input.startswith("/config"): - args = user_input[8:].strip().lower() - update = check_for_updates(version) - - if args == "api": - try: - new_api_key = typer.prompt("Enter new API key") - if new_api_key.strip(): - set_config('api_key', new_api_key.strip()) - API_KEY = new_api_key.strip() - client = OpenRouter(api_key=API_KEY) - console.print("[bold green]API key updated![/]") - else: - console.print("[bold yellow]No change.[/]") - except Exception as e: - console.print(f"[bold red]Error updating API key: {e}[/]") - elif args == "url": - try: - new_url = typer.prompt("Enter new base URL") - if new_url.strip(): - set_config('base_url', new_url.strip()) - OPENROUTER_BASE_URL = new_url.strip() - console.print("[bold green]Base URL updated![/]") - else: - console.print("[bold yellow]No change.[/]") - except Exception as e: - console.print(f"[bold red]Error updating URL: {e}[/]") - elif args.startswith("costwarning"): - sub_args = args[11:].strip() - if not sub_args: - console.print(f"[bold blue]Stored cost warning threshold: ${COST_WARNING_THRESHOLD:.4f}[/]") - continue - try: - new_threshold = float(sub_args) - if new_threshold < 0: - console.print("[bold red]Cost warning threshold must be >= 0.[/]") - continue - set_config(HIGH_COST_WARNING, str(new_threshold)) - COST_WARNING_THRESHOLD = new_threshold - console.print(f"[bold green]Cost warning threshold set to ${COST_WARNING_THRESHOLD:.4f}[/]") - except ValueError: - console.print("[bold red]Invalid cost threshold. Provide a valid number.[/]") - elif args.startswith("stream"): - sub_args = args[7:].strip() - if sub_args in ["on", "off"]: - set_config('stream_enabled', sub_args) - STREAM_ENABLED = sub_args - console.print(f"[bold green]Streaming {'enabled' if sub_args == 'on' else 'disabled'}.[/]") - else: - console.print("[bold yellow]Usage: /config stream on|off[/]") - elif args.startswith("loglevel"): - sub_args = args[8:].strip() - if not sub_args: - console.print(f"[bold blue]Current log level: {LOG_LEVEL_STR.upper()}[/]") - console.print(f"[dim yellow]Valid levels: debug, info, warning, error, critical[/]") - continue - if sub_args.lower() in VALID_LOG_LEVELS: - if set_log_level(sub_args): - set_config('log_level', sub_args.lower()) - console.print(f"[bold green]Log level set to: {sub_args.upper()}[/]") - app_logger.info(f"Log level changed to {sub_args.upper()}") - else: - console.print(f"[bold red]Failed to set log level.[/]") - else: - console.print(f"[bold red]Invalid log level: {sub_args}[/]") - console.print(f"[bold yellow]Valid levels: debug, info, warning, error, critical[/]") - elif args.startswith("log"): - sub_args = args[4:].strip() - if not sub_args: - console.print(f"[bold blue]Current log file size limit: {LOG_MAX_SIZE_MB} MB[/]") - console.print(f"[bold blue]Log backup count: {LOG_BACKUP_COUNT} files[/]") - console.print(f"[bold blue]Log level: {LOG_LEVEL_STR.upper()}[/]") - console.print(f"[dim yellow]Total max disk usage: ~{LOG_MAX_SIZE_MB * (LOG_BACKUP_COUNT + 1)} MB[/]") - continue - try: - new_size_mb = int(sub_args) - if new_size_mb < 1: - console.print("[bold red]Log size must be at least 1 MB.[/]") - continue - if new_size_mb > 100: - console.print("[bold yellow]Warning: Log size > 100MB. Capping at 100MB.[/]") - new_size_mb = 100 - set_config('log_max_size_mb', str(new_size_mb)) - LOG_MAX_SIZE_MB = new_size_mb - - # Reload logging configuration immediately - app_logger = reload_logging_config() - - console.print(f"[bold green]Log size limit set to {new_size_mb} MB and applied immediately.[/]") - console.print(f"[dim cyan]Log file rotated if it exceeded the new limit.[/]") - app_logger.info(f"Log size limit updated to {new_size_mb} MB and reloaded") - except ValueError: - console.print("[bold red]Invalid size. Provide a number in MB.[/]") - elif args.startswith("online"): - sub_args = args[7:].strip() - if not sub_args: - current_default = "enabled" if DEFAULT_ONLINE_MODE == "on" else "disabled" - console.print(f"[bold blue]Default online mode: {current_default}[/]") - console.print("[dim yellow]This sets the default for new models. Use '/online on|off' to override in current session.[/]") - continue - if sub_args in ["on", "off"]: - set_config('default_online_mode', sub_args) - DEFAULT_ONLINE_MODE = sub_args - console.print(f"[bold green]Default online mode {'enabled' if sub_args == 'on' else 'disabled'}.[/]") - console.print("[dim blue]Note: This affects new model selections. Current session unchanged.[/]") - app_logger.info(f"Default online mode set to {sub_args}") - else: - console.print("[bold yellow]Usage: /config online on|off[/]") - elif args.startswith("maxtoken"): - sub_args = args[9:].strip() - if not sub_args: - console.print(f"[bold blue]Stored max token limit: {MAX_TOKEN}[/]") - continue - try: - new_max = int(sub_args) - if new_max < 1: - console.print("[bold red]Max token limit must be at least 1.[/]") - continue - if new_max > 1000000: - console.print("[bold yellow]Capped at 1M for safety.[/]") - new_max = 1000000 - set_config('max_token', str(new_max)) - MAX_TOKEN = new_max - if session_max_token > MAX_TOKEN: - session_max_token = MAX_TOKEN - console.print(f"[bold yellow]Session adjusted to {session_max_token}.[/]") - console.print(f"[bold green]Stored max token limit updated to: {MAX_TOKEN}[/]") - except ValueError: - console.print("[bold red]Invalid token limit.[/]") - elif args.startswith("model"): - sub_args = args[6:].strip() - search_term = sub_args if sub_args else "" - filtered_models = text_models - if search_term: - filtered_models = [m for m in text_models if search_term.lower() in m["name"].lower() or search_term.lower() in m["id"].lower()] - if not filtered_models: - console.print(f"[bold red]No models match '{search_term}'. Try without search.[/]") - continue - - table = Table("No.", "Name", "ID", "Image", "Online", show_header=True, header_style="bold magenta") - for i, model in enumerate(filtered_models, 1): - image_support = "[green]✓[/green]" if has_image_capability(model) else "[red]✗[/red]" - online_support = "[green]✓[/green]" if supports_online_mode(model) else "[red]✗[/red]" - table.add_row(str(i), model["name"], model["id"], image_support, online_support) - - title = f"[bold green]Available Models for Default ({'All' if not search_term else f'Search: {search_term}'})[/]" - display_paginated_table(table, title) - - while True: - try: - choice = int(typer.prompt("Enter model number (or 0 to cancel)")) - if choice == 0: - break - if 1 <= choice <= len(filtered_models): - default_model = filtered_models[choice - 1] - set_config('default_model', default_model["id"]) - current_name = selected_model['name'] if selected_model else "None" - console.print(f"[bold cyan]Default model set to: {default_model['name']} ({default_model['id']}). Current unchanged: {current_name}[/]") - break - console.print("[bold red]Invalid choice. Try again.[/]") - except ValueError: - console.print("[bold red]Invalid input. Enter a number.[/]") - else: - DEFAULT_MODEL_ID = get_config('default_model') - memory_status = "Enabled" if conversation_memory_enabled else "Disabled" - memory_tracked = len(session_history) - memory_start_index if conversation_memory_enabled else 0 - table = Table("Setting", "Value", show_header=True, header_style="bold magenta", width=console.width - 10) - table.add_row("API Key", API_KEY or "[Not set]") - table.add_row("Base URL", OPENROUTER_BASE_URL or "[Not set]") - table.add_row("DB Path", str(database) or "[Not set]") - table.add_row("Logfile", str(log_file) or "[Not set]") - table.add_row("Log Size Limit", f"{LOG_MAX_SIZE_MB} MB") - table.add_row("Log Backups", str(LOG_BACKUP_COUNT)) - table.add_row("Log Level", LOG_LEVEL_STR.upper()) - table.add_row("Streaming", "Enabled" if STREAM_ENABLED == "on" else "Disabled") - table.add_row("Default Model", DEFAULT_MODEL_ID or "[Not set]") - table.add_row("Current Model", "[Not set]" if selected_model is None else str(selected_model["name"])) - table.add_row("Default Online Mode", "Enabled" if DEFAULT_ONLINE_MODE == "on" else "Disabled") - table.add_row("Session Online Mode", "Enabled" if online_mode_enabled else "Disabled") - table.add_row("Max Token", str(MAX_TOKEN)) - table.add_row("Session Token", "[Not set]" if session_max_token == 0 else str(session_max_token)) - table.add_row("Session System Prompt", session_system_prompt or "[Not set]") - table.add_row("Cost Warning Threshold", f"${COST_WARNING_THRESHOLD:.4f}") - table.add_row("Middle-out Transform", "Enabled" if middle_out_enabled else "Disabled") - table.add_row("Conversation Memory", f"{memory_status} ({memory_tracked} tracked)" if conversation_memory_enabled else memory_status) - table.add_row("History Size", str(len(session_history))) - table.add_row("Current History Index", str(current_index) if current_index >= 0 else "[None]") - table.add_row("App Name", APP_NAME) - table.add_row("App URL", APP_URL) - - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits: - table.add_row("Total Credits", credits['total_credits']) - table.add_row("Used Credits", credits['used_credits']) - table.add_row("Credits Left", credits['credits_left']) - else: - table.add_row("Total Credits", "[Unavailable - Check API key]") - table.add_row("Used Credits", "[Unavailable - Check API key]") - table.add_row("Credits Left", "[Unavailable - Check API key]") - - console.print(Panel(table, title="[bold green]Current Configurations[/]", title_align="left", subtitle="%s" %(update), subtitle_align="right")) - continue - - if user_input.lower() == "/credits": - credits = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits: - console.print(f"[bold green]Credits left: {credits['credits_left']}[/]") - alerts = check_credit_alerts(credits) - if alerts: - for alert in alerts: - console.print(f"[bold red]⚠️ {alert}[/]") - else: - console.print("[bold red]Unable to fetch credits. Check your API key or network.[/]") - continue - - if user_input.lower() == "/clear" or user_input.lower() == "/cl": - clear_screen() - DEFAULT_MODEL_ID = get_config('default_model') - token_value = session_max_token if session_max_token != 0 else " Not set" - console.print(f"[bold cyan]Token limits: Max= {MAX_TOKEN}, Session={token_value}[/]") - console.print("[bold blue]Active model[/] [bold red]%s[/]" %(str(selected_model["name"]) if selected_model else "None")) - if online_mode_enabled: - console.print("[bold cyan]Online mode: Enabled (web search active)[/]") - continue - - if user_input.lower().startswith("/help"): - args = user_input[6:].strip() - - # If a specific command is requested - if args: - show_command_help(args) - continue - - # Otherwise show the full help menu - help_table = Table("Command", "Description", "Example", show_header=True, header_style="bold cyan", width=console.width - 10) - - # SESSION COMMANDS - help_table.add_row( - "[bold yellow]━━━ SESSION COMMANDS ━━━[/]", - "", - "" - ) - help_table.add_row( - "/clear or /cl", - "Clear the terminal screen for a clean interface. You can also use the keycombo [bold]ctrl+l[/]", - "/clear\n/cl" - ) - help_table.add_row( - "/help [command]", - "Show this help menu or get detailed help for a specific command.", - "/help\n/help /model" - ) - help_table.add_row( - "/memory [on|off]", - "Toggle conversation memory. ON sends history (AI remembers), OFF sends only current message (saves cost).", - "/memory\n/memory off" - ) - help_table.add_row( - "/next", - "View the next response in history.", - "/next" - ) - help_table.add_row( - "/online [on|off]", - "Enable/disable online mode (web search) for current session. Overrides default setting.", - "/online on\n/online off" - ) - help_table.add_row( - "/paste [prompt]", - "Paste plain text/code from clipboard and send to AI. Optional prompt can be added.", - "/paste\n/paste Explain this code" - ) - help_table.add_row( - "/prev", - "View the previous response in history.", - "/prev" - ) - help_table.add_row( - "/reset", - "Clear conversation history and reset system prompt (resets session metrics). Requires confirmation.", - "/reset" - ) - help_table.add_row( - "/retry", - "Resend the last prompt from history.", - "/retry" - ) - - # MODEL COMMANDS - help_table.add_row( - "[bold yellow]━━━ MODEL COMMANDS ━━━[/]", - "", - "" - ) - help_table.add_row( - "/info [model_id]", - "Display detailed info (pricing, modalities, context length, online support, etc.) for current or specified model.", - "/info\n/info gpt-4o" - ) - help_table.add_row( - "/model [search]", - "Select or change the current model for the session. Shows image and online capabilities. Supports searching by name or ID.", - "/model\n/model gpt" - ) - - # CONFIGURATION COMMANDS - help_table.add_row( - "[bold yellow]━━━ CONFIGURATION ━━━[/]", - "", - "" - ) - help_table.add_row( - "/config", - "View all current configurations, including limits, credits, and history.", - "/config" - ) - help_table.add_row( - "/config api", - "Set or update the OpenRouter API key.", - "/config api" - ) - help_table.add_row( - "/config costwarning [value]", - "Set the cost warning threshold. Alerts when response exceeds this cost (in USD).", - "/config costwarning 0.05" - ) - help_table.add_row( - "/config log [size_mb]", - "Set log file size limit in MB. Older logs are rotated automatically. Takes effect immediately.", - "/config log 20" - ) - help_table.add_row( - "/config loglevel [level]", - "Set log verbosity level. Valid levels: debug, info, warning, error, critical. Takes effect immediately.", - "/config loglevel debug\n/config loglevel warning" - ) - help_table.add_row( - "/config maxtoken [value]", - "Set stored max token limit (persisted in DB). View current if no value provided.", - "/config maxtoken 50000" - ) - help_table.add_row( - "/config model [search]", - "Set default model that loads on startup. Shows image and online capabilities. Doesn't change current session model.", - "/config model gpt" - ) - help_table.add_row( - "/config online [on|off]", - "Set default online mode for new model selections. Use '/online on|off' to override current session.", - "/config online on" - ) - help_table.add_row( - "/config stream [on|off]", - "Enable or disable response streaming.", - "/config stream off" - ) - help_table.add_row( - "/config url", - "Set or update the base URL for OpenRouter API.", - "/config url" - ) - - # TOKEN & SYSTEM COMMANDS - help_table.add_row( - "[bold yellow]━━━ TOKEN & SYSTEM ━━━[/]", - "", - "" - ) - help_table.add_row( - "/maxtoken [value]", - "Set temporary session token limit (≤ stored max). View current if no value provided.", - "/maxtoken 2000" - ) - help_table.add_row( - "/middleout [on|off]", - "Enable/disable middle-out transform to compress prompts exceeding context size.", - "/middleout on" - ) - help_table.add_row( - "/system [prompt|clear]", - "Set session-level system prompt to guide AI behavior. Use 'clear' to reset.", - "/system You are a Python expert" - ) - - # CONVERSATION MANAGEMENT - help_table.add_row( - "[bold yellow]━━━ CONVERSATION MGMT ━━━[/]", - "", - "" - ) - help_table.add_row( - "/delete ", - "Delete a saved conversation by name or number (from /list). Requires confirmation.", - "/delete my_chat\n/delete 3" - ) - help_table.add_row( - "/export ", - "Export conversation to file. Formats: md (Markdown), json (JSON), html (HTML).", - "/export md notes.md\n/export html report.html" - ) - help_table.add_row( - "/list", - "List all saved conversations with numbers, message counts, and timestamps.", - "/list" - ) - help_table.add_row( - "/load ", - "Load a saved conversation by name or number (from /list). Resets session metrics.", - "/load my_chat\n/load 3" - ) - help_table.add_row( - "/save ", - "Save current conversation history to database.", - "/save my_chat" - ) - - # MONITORING & STATS - help_table.add_row( - "[bold yellow]━━━ MONITORING & STATS ━━━[/]", - "", - "" - ) - help_table.add_row( - "/credits", - "Display credits left on your OpenRouter account with alerts.", - "/credits" - ) - help_table.add_row( - "/stats", - "Display session cost summary: tokens, cost, credits left, and warnings.", - "/stats" - ) - - # INPUT METHODS - help_table.add_row( - "[bold yellow]━━━ INPUT METHODS ━━━[/]", - "", - "" - ) - help_table.add_row( - "@/path/to/file", - "Attach files to messages: images (PNG, JPG, etc.), PDFs, and code files (.py, .js, etc.).", - "Debug @script.py\nSummarize @document.pdf\nAnalyze @image.png" - ) - help_table.add_row( - "Clipboard paste", - "Use /paste to send clipboard content (plain text/code) to AI.", - "/paste\n/paste Explain this" - ) - help_table.add_row( - "// escape", - "Start message with // to send a literal / character (e.g., //command sends '/command' as text, not a command)", - "//help sends '/help' as text" - ) - - # EXIT - help_table.add_row( - "[bold yellow]━━━ EXIT ━━━[/]", - "", - "" - ) - help_table.add_row( - "exit | quit | bye", - "Quit the chat application and display session summary.", - "exit" - ) - - console.print(Panel( - help_table, - title="[bold cyan]oAI Chat Help (Version %s)[/]" % version, - title_align="center", - subtitle="💡 Tip: Commands are case-insensitive • Use /help for detailed help • Memory ON by default • Use // to escape / • Visit: https://iurl.no/oai", - subtitle_align="center", - border_style="cyan" - )) - continue - - if not selected_model: - console.print("[bold yellow]Select a model first with '/model'.[/]") - continue - # Process file attachments with PDF support - content_blocks = [] - text_part = user_input - file_attachments = [] - for match in re.finditer(r'@([^\s]+)', user_input): - file_path = match.group(1) - expanded_path = os.path.expanduser(os.path.abspath(file_path)) - if not os.path.exists(expanded_path) or os.path.isdir(expanded_path): - console.print(f"[bold red]File not found or is a directory: {expanded_path}[/]") - continue - file_size = os.path.getsize(expanded_path) - if file_size > 10 * 1024 * 1024: - console.print(f"[bold red]File too large (>10MB): {expanded_path}[/]") - continue - mime_type, _ = mimetypes.guess_type(expanded_path) - file_ext = os.path.splitext(expanded_path)[1].lower() - try: - with open(expanded_path, 'rb') as f: - file_data = f.read() - - # Handle images - if mime_type and mime_type.startswith('image/'): - modalities = selected_model.get("architecture", {}).get("input_modalities", []) - if "image" not in modalities: - console.print("[bold red]Selected model does not support image attachments.[/]") - console.print(f"[dim yellow]Supported modalities: {', '.join(modalities) if modalities else 'text only'}[/]") - continue - b64_data = base64.b64encode(file_data).decode('utf-8') - content_blocks.append({"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64_data}"}}) - console.print(f"[dim green]✓ Image attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - # Handle PDFs - elif mime_type == 'application/pdf' or file_ext == '.pdf': - modalities = selected_model.get("architecture", {}).get("input_modalities", []) - supports_pdf = any(mod in modalities for mod in ["document", "pdf", "file"]) - if not supports_pdf: - console.print("[bold red]Selected model does not support PDF attachments.[/]") - console.print(f"[dim yellow]Supported modalities: {', '.join(modalities) if modalities else 'text only'}[/]") - continue - b64_data = base64.b64encode(file_data).decode('utf-8') - content_blocks.append({"type": "image_url", "image_url": {"url": f"data:application/pdf;base64,{b64_data}"}}) - console.print(f"[dim green]✓ PDF attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - # Handle code/text files - elif (mime_type == 'text/plain' or file_ext in SUPPORTED_CODE_EXTENSIONS): - text_content = file_data.decode('utf-8') - content_blocks.append({"type": "text", "text": f"Code File: {os.path.basename(expanded_path)}\n\n{text_content}"}) - console.print(f"[dim green]✓ Code file attached: {os.path.basename(expanded_path)} ({file_size / 1024:.1f} KB)[/]") - - else: - console.print(f"[bold red]Unsupported file type ({mime_type}) for {expanded_path}.[/]") - console.print("[bold yellow]Supported types: images (PNG, JPG, etc.), PDFs, and code files (.py, .js, etc.)[/]") - continue - - file_attachments.append(file_path) - app_logger.info(f"File attached: {os.path.basename(expanded_path)}, Type: {mime_type or file_ext}, Size: {file_size / 1024:.1f} KB") - except UnicodeDecodeError: - console.print(f"[bold red]Cannot decode {expanded_path} as UTF-8. File may be binary or use unsupported encoding.[/]") - app_logger.error(f"UTF-8 decode error for {expanded_path}") - continue - except Exception as e: - console.print(f"[bold red]Error reading file {expanded_path}: {e}[/]") - app_logger.error(f"File read error for {expanded_path}: {e}") - continue - text_part = re.sub(r'@([^\s]+)', '', text_part).strip() - - # Build message content - if text_part or content_blocks: - message_content = [] - if text_part: - message_content.append({"type": "text", "text": text_part}) - message_content.extend(content_blocks) - else: - console.print("[bold red]Prompt cannot be empty.[/]") - continue - - # Build API messages with conversation history if memory is enabled - api_messages = [] - - if session_system_prompt: - api_messages.append({"role": "system", "content": session_system_prompt}) - - if conversation_memory_enabled: - for i in range(memory_start_index, len(session_history)): - history_entry = session_history[i] - api_messages.append({ - "role": "user", - "content": history_entry['prompt'] - }) - api_messages.append({ - "role": "assistant", - "content": history_entry['response'] - }) - - api_messages.append({"role": "user", "content": message_content}) - - # Get effective model ID with :online suffix if enabled - effective_model_id = get_effective_model_id(selected_model["id"], online_mode_enabled) - - # Build API params - api_params = { - "model": effective_model_id, - "messages": api_messages, - "stream": STREAM_ENABLED == "on", - "http_headers": { - "HTTP-Referer": APP_URL, - "X-Title": APP_NAME - } - } - if session_max_token > 0: - api_params["max_tokens"] = session_max_token - if middle_out_enabled: - api_params["transforms"] = ["middle-out"] - - # Log API request - file_count = len(file_attachments) - history_messages_count = len(session_history) - memory_start_index if conversation_memory_enabled else 0 - memory_status = "ON" if conversation_memory_enabled else "OFF" - online_status = "ON" if online_mode_enabled else "OFF" - app_logger.info(f"API Request: Model '{effective_model_id}' (Online: {online_status}), Prompt length: {len(text_part)} chars, {file_count} file(s) attached, Memory: {memory_status}, History sent: {history_messages_count} messages, Transforms: middle-out {'enabled' if middle_out_enabled else 'disabled'}, App: {APP_NAME} ({APP_URL}).") - - # Send and handle response - is_streaming = STREAM_ENABLED == "on" - if is_streaming: - console.print("[bold green]Streaming response...[/] [dim](Press Ctrl+C to cancel)[/]") - if online_mode_enabled: - console.print("[dim cyan]🌐 Online mode active - model has web search access[/]") - console.print("") - else: - console.print("[bold green]Thinking...[/]", end="\r") - - start_time = time.time() - try: - response = client.chat.send(**api_params) - app_logger.info(f"API call successful for model '{effective_model_id}'") - except Exception as e: - console.print(f"[bold red]Error sending request: {e}[/]") - app_logger.error(f"API Error: {type(e).__name__}: {e}") - continue - - response_time = time.time() - start_time - - full_response = "" - if is_streaming: - try: - with Live("", console=console, refresh_per_second=10, auto_refresh=True) as live: - for chunk in response: - if hasattr(chunk, 'error') and chunk.error: - console.print(f"\n[bold red]Stream error: {chunk.error.message}[/]") - app_logger.error(f"Stream error: {chunk.error.message}") - break - if hasattr(chunk.choices[0].delta, 'content') and chunk.choices[0].delta.content: - content_chunk = chunk.choices[0].delta.content - full_response += content_chunk - md = Markdown(full_response) - live.update(md) - - console.print("") - - except KeyboardInterrupt: - console.print("\n[bold yellow]Streaming cancelled![/]") - app_logger.info("Streaming cancelled by user") - continue - else: - full_response = response.choices[0].message.content if response.choices else "" - console.print(f"\r{' ' * 20}\r", end="") - - if full_response: - if not is_streaming: - md = Markdown(full_response) - console.print(Panel(md, title="[bold green]AI Response[/]", title_align="left", border_style="green")) - - session_history.append({'prompt': user_input, 'response': full_response}) - current_index = len(session_history) - 1 - - # Process metrics - usage = getattr(response, 'usage', None) - input_tokens = usage.input_tokens if usage and hasattr(usage, 'input_tokens') else 0 - output_tokens = usage.output_tokens if usage and hasattr(usage, 'output_tokens') else 0 - msg_cost = usage.total_cost_usd if usage and hasattr(usage, 'total_cost_usd') else estimate_cost(input_tokens, output_tokens) - - total_input_tokens += input_tokens - total_output_tokens += output_tokens - total_cost += msg_cost - message_count += 1 - - app_logger.info(f"Response: Tokens - I:{input_tokens} O:{output_tokens} T:{input_tokens + output_tokens}, Cost: ${msg_cost:.4f}, Time: {response_time:.2f}s, Online: {online_mode_enabled}") - - # Per-message metrics display - if conversation_memory_enabled: - context_count = len(session_history) - memory_start_index - context_info = f", Context: {context_count} msg(s)" if context_count > 1 else "" - else: - context_info = ", Memory: OFF" - - online_info = " 🌐" if online_mode_enabled else "" - console.print(f"\n[dim blue]📊 Metrics: {input_tokens + output_tokens} tokens | ${msg_cost:.4f} | {response_time:.2f}s{context_info}{online_info} | Session: {total_input_tokens + total_output_tokens} tokens | ${total_cost:.4f}[/]") - - # Cost and credit alerts - warnings = [] - if msg_cost > COST_WARNING_THRESHOLD: - warnings.append(f"High cost alert: This response exceeded configurable threshold ${COST_WARNING_THRESHOLD:.4f}") - credits_data = get_credits(API_KEY, OPENROUTER_BASE_URL) - if credits_data: - warning_alerts = check_credit_alerts(credits_data) - warnings.extend(warning_alerts) - if warnings: - warning_text = ' | '.join(warnings) - console.print(f"[bold red]⚠️ {warning_text}[/]") - app_logger.warning(f"Warnings triggered: {warning_text}") - - console.print("") - try: - copy_choice = input("💾 Type 'c' to copy response, or press Enter to continue: ").strip().lower() - if copy_choice == "c": - pyperclip.copy(full_response) - console.print("[bold green]✅ Response copied to clipboard![/]") - except (EOFError, KeyboardInterrupt): - pass - - console.print("") - else: - console.print("[bold red]No response received.[/]") - app_logger.error("No response from API") - - except KeyboardInterrupt: - console.print("\n[bold yellow]Input interrupted. Continuing...[/]") - app_logger.warning("Input interrupted by Ctrl+C") - continue - except EOFError: - console.print("\n[bold yellow]Goodbye![/]") - total_tokens = total_input_tokens + total_output_tokens - app_logger.info(f"Session ended via EOF. Total messages: {message_count}, Total tokens: {total_tokens}, Total cost: ${total_cost:.4f}") - return - except Exception as e: - console.print(f"[bold red]Error: {e}[/]") - console.print("[bold yellow]Try again or select a model.[/]") - app_logger.error(f"Unexpected error: {type(e).__name__}: {e}") - -if __name__ == "__main__": - clear_screen() - app() \ No newline at end of file diff --git a/oai/__init__.py b/oai/__init__.py new file mode 100644 index 0000000..77da9b0 --- /dev/null +++ b/oai/__init__.py @@ -0,0 +1,26 @@ +""" +oAI - OpenRouter AI Chat Client + +A feature-rich terminal-based chat application that provides an interactive CLI +interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP) +integration for filesystem and database access. + +Author: Rune +License: MIT +""" + +__version__ = "2.1.0" +__author__ = "Rune" +__license__ = "MIT" + +# Lazy imports to avoid circular dependencies and improve startup time +# Full imports are available via submodules: +# from oai.config import Settings, Database +# from oai.providers import OpenRouterProvider, AIProvider +# from oai.mcp import MCPManager + +__all__ = [ + "__version__", + "__author__", + "__license__", +] diff --git a/oai/__main__.py b/oai/__main__.py new file mode 100644 index 0000000..bd02e03 --- /dev/null +++ b/oai/__main__.py @@ -0,0 +1,8 @@ +""" +Entry point for running oAI as a module: python -m oai +""" + +from oai.cli import main + +if __name__ == "__main__": + main() diff --git a/oai/cli.py b/oai/cli.py new file mode 100644 index 0000000..0405885 --- /dev/null +++ b/oai/cli.py @@ -0,0 +1,719 @@ +""" +Main CLI entry point for oAI. + +This module provides the command-line interface for the oAI chat application. +""" + +import os +import sys +from pathlib import Path +from typing import Optional + +import typer +from prompt_toolkit import PromptSession +from prompt_toolkit.auto_suggest import AutoSuggestFromHistory +from prompt_toolkit.history import FileHistory +from rich.markdown import Markdown +from rich.panel import Panel + +from oai import __version__ +from oai.commands import register_all_commands, registry +from oai.commands.registry import CommandContext, CommandStatus +from oai.config.database import Database +from oai.config.settings import Settings +from oai.constants import ( + APP_NAME, + APP_URL, + APP_VERSION, + CONFIG_DIR, + HISTORY_FILE, + VALID_COMMANDS, +) +from oai.core.client import AIClient +from oai.core.session import ChatSession +from oai.mcp.manager import MCPManager +from oai.providers.base import UsageStats +from oai.providers.openrouter import OpenRouterProvider +from oai.ui.console import ( + clear_screen, + console, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.tables import create_model_table, display_paginated_table +from oai.utils.logging import LoggingManager, get_logger + +# Create Typer app +app = typer.Typer( + name="oai", + help=f"oAI - OpenRouter AI Chat Client\n\nVersion: {APP_VERSION}", + add_completion=False, + epilog="For more information, visit: " + APP_URL, +) + + +@app.callback(invoke_without_command=True) +def main_callback( + ctx: typer.Context, + version_flag: bool = typer.Option( + False, + "--version", + "-v", + help="Show version information", + is_flag=True, + ), +) -> None: + """Main callback to handle global options.""" + # Show version with update check if --version flag + if version_flag: + version_info = check_for_updates(APP_VERSION) + console.print(version_info) + raise typer.Exit() + + # Show version with update check when --help is requested + if "--help" in sys.argv or "-h" in sys.argv: + version_info = check_for_updates(APP_VERSION) + console.print(f"\n{version_info}\n") + + # Continue to subcommand if provided + if ctx.invoked_subcommand is None: + return + + +def check_for_updates(current_version: str) -> str: + """Check for available updates.""" + import requests + from packaging import version as pkg_version + + try: + response = requests.get( + "https://gitlab.pm/api/v1/repos/rune/oai/releases/latest", + headers={"Content-Type": "application/json"}, + timeout=1.0, + ) + response.raise_for_status() + + data = response.json() + version_online = data.get("tag_name", "").lstrip("v") + + if not version_online: + return f"[bold green]oAI version {current_version}[/]" + + current = pkg_version.parse(current_version) + latest = pkg_version.parse(version_online) + + if latest > current: + return ( + f"[bold green]oAI version {current_version}[/] " + f"[bold red](Update available: {current_version} → {version_online})[/]" + ) + return f"[bold green]oAI version {current_version} (up to date)[/]" + + except Exception: + return f"[bold green]oAI version {current_version}[/]" + + +def show_welcome(settings: Settings, version_info: str) -> None: + """Display welcome message.""" + console.print(Panel.fit( + f"{version_info}\n\n" + "[bold cyan]Commands:[/] /help for commands, /model to select model\n" + "[bold cyan]MCP:[/] /mcp on to enable file/database access\n" + "[bold cyan]Exit:[/] Type 'exit', 'quit', or 'bye'", + title=f"[bold green]Welcome to {APP_NAME}[/]", + border_style="green", + )) + + +def select_model(client: AIClient, search_term: Optional[str] = None) -> Optional[dict]: + """Display model selection interface.""" + try: + models = client.provider.get_raw_models() + if not models: + print_error("No models available") + return None + + # Filter by search term if provided + if search_term: + search_lower = search_term.lower() + models = [m for m in models if search_lower in m.get("id", "").lower()] + + if not models: + print_error(f"No models found matching '{search_term}'") + return None + + # Create and display table + table = create_model_table(models) + display_paginated_table( + table, + f"[bold green]Available Models ({len(models)})[/]", + ) + + # Prompt for selection + console.print("") + try: + choice = input("Enter model number (or press Enter to cancel): ").strip() + except (EOFError, KeyboardInterrupt): + return None + + if not choice: + return None + + try: + index = int(choice) - 1 + if 0 <= index < len(models): + selected = models[index] + print_success(f"Selected model: {selected['id']}") + return selected + except ValueError: + pass + + print_error("Invalid selection") + return None + + except Exception as e: + print_error(f"Failed to fetch models: {e}") + return None + + +def run_chat_loop( + session: ChatSession, + prompt_session: PromptSession, + settings: Settings, +) -> None: + """Run the main chat loop.""" + logger = get_logger() + mcp_manager = session.mcp_manager + + while True: + try: + # Build prompt prefix + prefix = "You> " + if mcp_manager and mcp_manager.enabled: + if mcp_manager.mode == "files": + if mcp_manager.write_enabled: + prefix = "[🔧✍️ MCP: Files+Write] You> " + else: + prefix = "[🔧 MCP: Files] You> " + elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None: + prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> " + + # Get user input + user_input = prompt_session.prompt( + prefix, + auto_suggest=AutoSuggestFromHistory(), + ).strip() + + if not user_input: + continue + + # Handle escape sequence + if user_input.startswith("//"): + user_input = user_input[1:] + + # Check for exit + if user_input.lower() in ["exit", "quit", "bye"]: + console.print( + f"\n[bold yellow]Goodbye![/]\n" + f"[dim]Session: {session.stats.total_tokens:,} tokens, " + f"${session.stats.total_cost:.4f}[/]" + ) + logger.info( + f"Session ended. Messages: {session.stats.message_count}, " + f"Tokens: {session.stats.total_tokens}, " + f"Cost: ${session.stats.total_cost:.4f}" + ) + return + + # Check for unknown commands + if user_input.startswith("/"): + cmd_word = user_input.split()[0].lower() + if not registry.is_command(user_input): + # Check if it's a valid command prefix + is_valid = any(cmd_word.startswith(cmd) for cmd in VALID_COMMANDS) + if not is_valid: + print_error(f"Unknown command: {cmd_word}") + print_info("Type /help to see available commands.") + continue + + # Try to execute as command + context = session.get_context() + result = registry.execute(user_input, context) + + if result: + # Update session state from context + session.memory_enabled = context.memory_enabled + session.memory_start_index = context.memory_start_index + session.online_enabled = context.online_enabled + session.middle_out_enabled = context.middle_out_enabled + session.session_max_token = context.session_max_token + session.current_index = context.current_index + session.system_prompt = context.session_system_prompt + + if result.status == CommandStatus.EXIT: + return + + # Handle special results + if result.data: + # Retry - resend last prompt + if "retry_prompt" in result.data: + user_input = result.data["retry_prompt"] + # Fall through to send message + + # Paste - send clipboard content + elif "paste_prompt" in result.data: + user_input = result.data["paste_prompt"] + # Fall through to send message + + # Model selection + elif "show_model_selector" in result.data: + search = result.data.get("search", "") + model = select_model(session.client, search if search else None) + if model: + session.set_model(model) + # If this came from /config model, also save as default + if result.data.get("set_as_default"): + settings.set_default_model(model["id"]) + print_success(f"Default model set to: {model['id']}") + continue + + # Load conversation + elif "load_conversation" in result.data: + history = result.data.get("history", []) + session.history.clear() + from oai.core.session import HistoryEntry + for entry in history: + session.history.append(HistoryEntry( + prompt=entry.get("prompt", ""), + response=entry.get("response", ""), + prompt_tokens=entry.get("prompt_tokens", 0), + completion_tokens=entry.get("completion_tokens", 0), + msg_cost=entry.get("msg_cost", 0.0), + )) + session.current_index = len(session.history) - 1 + continue + + else: + # Normal command completed + continue + else: + # Command completed with no special data + continue + + # Ensure model is selected + if not session.selected_model: + print_warning("Please select a model first with /model") + continue + + # Send message + stream = settings.stream_enabled + if mcp_manager and mcp_manager.enabled: + tools = session.get_mcp_tools() + if tools: + stream = False # Disable streaming with tools + + if stream: + console.print( + "[bold green]Streaming response...[/] " + "[dim](Press Ctrl+C to cancel)[/]" + ) + if session.online_enabled: + console.print("[dim cyan]🌐 Online mode active[/]") + console.print("") + + try: + response_text, usage, response_time = session.send_message( + user_input, + stream=stream, + ) + except Exception as e: + print_error(f"Error: {e}") + logger.error(f"Message error: {e}") + continue + + if not response_text: + print_error("No response received") + continue + + # Display non-streaming response + if not stream: + console.print() + display_panel( + Markdown(response_text), + title="[bold green]AI Response[/]", + border_style="green", + ) + + # Calculate cost and tokens + cost = 0.0 + tokens = 0 + estimated = False + + if usage and (usage.prompt_tokens > 0 or usage.completion_tokens > 0): + tokens = usage.total_tokens + if usage.total_cost_usd: + cost = usage.total_cost_usd + else: + cost = session.client.estimate_cost( + session.selected_model["id"], + usage.prompt_tokens, + usage.completion_tokens, + ) + else: + # Estimate tokens when usage not available (streaming fallback) + # Rough estimate: ~4 characters per token for English text + est_input_tokens = len(user_input) // 4 + 1 + est_output_tokens = len(response_text) // 4 + 1 + tokens = est_input_tokens + est_output_tokens + cost = session.client.estimate_cost( + session.selected_model["id"], + est_input_tokens, + est_output_tokens, + ) + # Create estimated usage for session tracking + usage = UsageStats( + prompt_tokens=est_input_tokens, + completion_tokens=est_output_tokens, + total_tokens=tokens, + ) + estimated = True + + # Add to history + session.add_to_history(user_input, response_text, usage, cost) + + # Display metrics + est_marker = "~" if estimated else "" + context_info = "" + if session.memory_enabled: + context_count = len(session.history) - session.memory_start_index + if context_count > 1: + context_info = f", Context: {context_count} msg(s)" + else: + context_info = ", Memory: OFF" + + online_emoji = " 🌐" if session.online_enabled else "" + mcp_emoji = "" + if mcp_manager and mcp_manager.enabled: + if mcp_manager.mode == "files": + mcp_emoji = " 🔧" + elif mcp_manager.mode == "database": + mcp_emoji = " 🗄️" + + console.print( + f"\n[dim blue]📊 {est_marker}{tokens} tokens | {est_marker}${cost:.4f} | {response_time:.2f}s" + f"{context_info}{online_emoji}{mcp_emoji} | " + f"Session: {est_marker}{session.stats.total_tokens:,} tokens | " + f"{est_marker}${session.stats.total_cost:.4f}[/]" + ) + + # Check warnings + warnings = session.check_warnings() + for warning in warnings: + print_warning(warning) + + # Offer to copy + console.print("") + try: + from oai.ui.prompts import prompt_copy_response + prompt_copy_response(response_text) + except Exception: + pass + console.print("") + + except KeyboardInterrupt: + console.print("\n[dim]Input cancelled[/]") + continue + except EOFError: + console.print("\n[bold yellow]Goodbye![/]") + return + + +@app.command() +def chat( + model: Optional[str] = typer.Option( + None, + "--model", + "-m", + help="Model ID to use", + ), + system: Optional[str] = typer.Option( + None, + "--system", + "-s", + help="System prompt", + ), + online: bool = typer.Option( + False, + "--online", + "-o", + help="Enable online mode", + ), + mcp: bool = typer.Option( + False, + "--mcp", + help="Enable MCP server", + ), +) -> None: + """Start an interactive chat session.""" + # Setup logging + logging_manager = LoggingManager() + logging_manager.setup() + logger = get_logger() + + # Clear screen + clear_screen() + + # Load settings + settings = Settings.load() + + # Check API key + if not settings.api_key: + print_error("No API key configured") + print_info("Run: oai --config api to set your API key") + raise typer.Exit(1) + + # Initialize client + try: + client = AIClient( + api_key=settings.api_key, + base_url=settings.base_url, + ) + except Exception as e: + print_error(f"Failed to initialize client: {e}") + raise typer.Exit(1) + + # Register commands + register_all_commands() + + # Check for updates and show welcome + version_info = check_for_updates(APP_VERSION) + show_welcome(settings, version_info) + + # Initialize MCP manager + mcp_manager = MCPManager() + if mcp: + result = mcp_manager.enable() + if result["success"]: + print_success("MCP enabled") + else: + print_warning(f"MCP: {result.get('error', 'Failed to enable')}") + + # Create session + session = ChatSession( + client=client, + settings=settings, + mcp_manager=mcp_manager, + ) + + # Set system prompt + if system: + session.system_prompt = system + print_info(f"System prompt: {system}") + + # Set online mode + if online: + session.online_enabled = True + print_info("Online mode enabled") + + # Select model + if model: + raw_model = client.get_raw_model(model) + if raw_model: + session.set_model(raw_model) + else: + print_warning(f"Model '{model}' not found") + elif settings.default_model: + raw_model = client.get_raw_model(settings.default_model) + if raw_model: + session.set_model(raw_model) + else: + print_warning(f"Default model '{settings.default_model}' not available") + + # Setup prompt session + HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True) + prompt_session = PromptSession( + history=FileHistory(str(HISTORY_FILE)), + ) + + # Run chat loop + run_chat_loop(session, prompt_session, settings) + + +@app.command() +def config( + setting: Optional[str] = typer.Argument( + None, + help="Setting to configure (api, url, model, system, stream, costwarning, maxtoken, online, log, loglevel)", + ), + value: Optional[str] = typer.Argument( + None, + help="Value to set", + ), +) -> None: + """View or modify configuration settings.""" + settings = Settings.load() + + if not setting: + # Show all settings + from rich.table import Table + from oai.constants import DEFAULT_SYSTEM_PROMPT + + table = Table("Setting", "Value", show_header=True, header_style="bold magenta") + table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set") + table.add_row("Base URL", settings.base_url) + table.add_row("Default Model", settings.default_model or "Not set") + + # Show system prompt status + if settings.default_system_prompt is None: + system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..." + elif settings.default_system_prompt == "": + system_prompt_display = "[blank]" + else: + system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt + table.add_row("System Prompt", system_prompt_display) + + table.add_row("Streaming", "on" if settings.stream_enabled else "off") + table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}") + table.add_row("Max Tokens", str(settings.max_tokens)) + table.add_row("Default Online", "on" if settings.default_online_mode else "off") + table.add_row("Log Level", settings.log_level) + + display_panel(table, title="[bold green]Configuration[/]") + return + + setting = setting.lower() + + if setting == "api": + if value: + settings.set_api_key(value) + else: + from oai.ui.prompts import prompt_input + new_key = prompt_input("Enter API key", password=True) + if new_key: + settings.set_api_key(new_key) + print_success("API key updated") + + elif setting == "url": + settings.set_base_url(value or "https://openrouter.ai/api/v1") + print_success(f"Base URL set to: {settings.base_url}") + + elif setting == "model": + if value: + settings.set_default_model(value) + print_success(f"Default model set to: {value}") + else: + print_info(f"Current default model: {settings.default_model or 'Not set'}") + + elif setting == "system": + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if value: + # Decode escape sequences like \n for newlines + value = value.encode().decode('unicode_escape') + settings.set_default_system_prompt(value) + if value: + print_success(f"Default system prompt set to: {value}") + else: + print_success("Default system prompt set to blank.") + else: + if settings.default_system_prompt is None: + print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif settings.default_system_prompt == "": + print_info("System prompt: [blank]") + else: + print_info(f"System prompt: {settings.default_system_prompt}") + + elif setting == "stream": + if value and value.lower() in ["on", "off"]: + settings.set_stream_enabled(value.lower() == "on") + print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}") + else: + print_info("Usage: oai config stream [on|off]") + + elif setting == "costwarning": + if value: + try: + threshold = float(value) + settings.set_cost_warning_threshold(threshold) + print_success(f"Cost warning threshold set to: ${threshold:.4f}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current threshold: ${settings.cost_warning_threshold:.4f}") + + elif setting == "maxtoken": + if value: + try: + max_tok = int(value) + settings.set_max_tokens(max_tok) + print_success(f"Max tokens set to: {max_tok}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current max tokens: {settings.max_tokens}") + + elif setting == "online": + if value and value.lower() in ["on", "off"]: + settings.set_default_online_mode(value.lower() == "on") + print_success(f"Default online mode {'enabled' if settings.default_online_mode else 'disabled'}") + else: + print_info("Usage: oai config online [on|off]") + + elif setting == "loglevel": + valid_levels = ["debug", "info", "warning", "error", "critical"] + if value and value.lower() in valid_levels: + settings.set_log_level(value.lower()) + print_success(f"Log level set to: {value.lower()}") + else: + print_info(f"Valid levels: {', '.join(valid_levels)}") + + else: + print_error(f"Unknown setting: {setting}") + + +@app.command() +def version() -> None: + """Show version information.""" + version_info = check_for_updates(APP_VERSION) + console.print(version_info) + + +@app.command() +def credits() -> None: + """Check account credits.""" + settings = Settings.load() + + if not settings.api_key: + print_error("No API key configured") + raise typer.Exit(1) + + client = AIClient(api_key=settings.api_key, base_url=settings.base_url) + credits_data = client.get_credits() + + if not credits_data: + print_error("Failed to fetch credits") + raise typer.Exit(1) + + from rich.table import Table + + table = Table("Metric", "Value", show_header=True, header_style="bold magenta") + table.add_row("Total Credits", credits_data.get("total_credits_formatted", "N/A")) + table.add_row("Used Credits", credits_data.get("used_credits_formatted", "N/A")) + table.add_row("Credits Left", credits_data.get("credits_left_formatted", "N/A")) + + display_panel(table, title="[bold green]Account Credits[/]") + + +def main() -> None: + """Main entry point.""" + # Default to 'chat' command if no arguments provided + if len(sys.argv) == 1: + sys.argv.append("chat") + app() + + +if __name__ == "__main__": + main() diff --git a/oai/commands/__init__.py b/oai/commands/__init__.py new file mode 100644 index 0000000..37068b3 --- /dev/null +++ b/oai/commands/__init__.py @@ -0,0 +1,24 @@ +""" +Command system for oAI. + +This module provides a command registry and handler system +for processing slash commands in the chat interface. +""" + +from oai.commands.registry import ( + Command, + CommandRegistry, + CommandContext, + CommandResult, + registry, +) +from oai.commands.handlers import register_all_commands + +__all__ = [ + "Command", + "CommandRegistry", + "CommandContext", + "CommandResult", + "registry", + "register_all_commands", +] diff --git a/oai/commands/handlers.py b/oai/commands/handlers.py new file mode 100644 index 0000000..49f23a8 --- /dev/null +++ b/oai/commands/handlers.py @@ -0,0 +1,1441 @@ +""" +Command handlers for oAI. + +This module implements all the slash commands available in the chat interface. +Each command is registered with the global registry. +""" + +import json +from typing import Any, Dict, List, Optional + +from oai.commands.registry import ( + Command, + CommandContext, + CommandHelp, + CommandResult, + CommandStatus, + registry, +) +from oai.constants import COMMAND_HELP, VALID_COMMANDS +from oai.ui.console import ( + clear_screen, + console, + display_markdown, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.prompts import prompt_confirm, prompt_input +from oai.ui.tables import ( + create_model_table, + create_stats_table, + display_paginated_table, +) +from oai.utils.export import export_as_html, export_as_json, export_as_markdown +from oai.utils.logging import get_logger + + +class HelpCommand(Command): + """Display help information for commands.""" + + @property + def name(self) -> str: + return "/help" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display help information for commands.", + usage="/help [command|topic]", + examples=[ + ("Show all commands", "/help"), + ("Get help for a specific command", "/help /model"), + ("Get detailed MCP help", "/help mcp"), + ], + notes="Use /help without arguments to see the full command list.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + logger = get_logger() + + if args: + # Show help for specific command/topic + self._show_command_help(args) + else: + # Show general help + self._show_general_help(context) + + logger.info(f"Displayed help for: {args or 'general'}") + return CommandResult.success() + + def _show_command_help(self, command_or_topic: str) -> None: + """Show help for a specific command or topic.""" + # Handle topics (like 'mcp') + if not command_or_topic.startswith("/"): + if command_or_topic.lower() == "mcp": + help_data = COMMAND_HELP.get("mcp", {}) + if help_data: + content = [] + content.append(f"[bold cyan]Description:[/]") + content.append(help_data.get("description", "")) + content.append("") + content.append(help_data.get("notes", "")) + + display_panel( + "\n".join(content), + title="[bold green]MCP - Model Context Protocol Guide[/]", + border_style="green", + ) + return + command_or_topic = "/" + command_or_topic + + help_data = COMMAND_HELP.get(command_or_topic) + if not help_data: + print_error(f"Unknown command: {command_or_topic}") + print_warning("Type /help to see all available commands.") + return + + content = [] + + if help_data.get("aliases"): + content.append(f"[bold cyan]Aliases:[/] {', '.join(help_data['aliases'])}") + content.append("") + + content.append("[bold cyan]Description:[/]") + content.append(help_data.get("description", "")) + content.append("") + + content.append("[bold cyan]Usage:[/]") + content.append(f"[yellow]{help_data.get('usage', '')}[/]") + content.append("") + + examples = help_data.get("examples", []) + if examples: + content.append("[bold cyan]Examples:[/]") + for desc, example in examples: + if not desc and not example: + content.append("") + elif desc.startswith("━━━"): + content.append(f"[bold yellow]{desc}[/]") + else: + if desc: + content.append(f" [dim]{desc}:[/]") + if example: + content.append(f" [green]{example}[/]") + content.append("") + + notes = help_data.get("notes") + if notes: + content.append("[bold cyan]Notes:[/]") + content.append(f"[dim]{notes}[/]") + + display_panel( + "\n".join(content), + title=f"[bold green]Help: {command_or_topic}[/]", + border_style="green", + ) + + def _show_general_help(self, context: CommandContext) -> None: + """Show general help with all commands.""" + from rich.table import Table + + table = Table( + "Command", + "Description", + "Example", + show_header=True, + header_style="bold magenta", + show_lines=False, + ) + + # Group commands by category + categories = [ + ("[bold cyan]━━━ CHAT ━━━[/]", [ + ("/retry", "Resend the last prompt.", "/retry"), + ("/memory", "Toggle conversation memory.", "/memory on"), + ("/online", "Toggle online mode (web search).", "/online on"), + ("/paste", "Paste from clipboard with optional prompt.", "/paste Explain"), + ]), + ("[bold cyan]━━━ NAVIGATION ━━━[/]", [ + ("/prev", "View previous response in history.", "/prev"), + ("/next", "View next response in history.", "/next"), + ("/reset", "Clear conversation history.", "/reset"), + ]), + ("[bold cyan]━━━ MODEL & CONFIG ━━━[/]", [ + ("/model", "Select AI model.", "/model gpt"), + ("/info", "Show model information.", "/info"), + ("/config", "View or change settings.", "/config stream on"), + ("/maxtoken", "Set session token limit.", "/maxtoken 2000"), + ("/system", "Set system prompt.", "/system You are an expert"), + ]), + ("[bold cyan]━━━ SAVE & EXPORT ━━━[/]", [ + ("/save", "Save conversation.", "/save my_chat"), + ("/load", "Load saved conversation.", "/load my_chat"), + ("/delete", "Delete saved conversation.", "/delete my_chat"), + ("/list", "List saved conversations.", "/list"), + ("/export", "Export conversation.", "/export md notes.md"), + ]), + ("[bold cyan]━━━ STATS & INFO ━━━[/]", [ + ("/stats", "Show session statistics.", "/stats"), + ("/credits", "Show account credits.", "/credits"), + ("/middleout", "Toggle middle-out compression.", "/middleout on"), + ]), + ("[bold cyan]━━━ MCP (FILE/DB ACCESS) ━━━[/]", [ + ("/mcp on", "Enable MCP server.", "/mcp on"), + ("/mcp add", "Add folder or database.", "/mcp add ~/Documents"), + ("/mcp status", "Show MCP status.", "/mcp status"), + ("/mcp write", "Toggle write mode.", "/mcp write on"), + ]), + ("[bold cyan]━━━ UTILITY ━━━[/]", [ + ("/clear", "Clear the screen.", "/clear"), + ("/help", "Show this help.", "/help /model"), + ]), + ("[bold yellow]━━━ EXIT ━━━[/]", [ + ("exit", "Quit the application.", "exit"), + ]), + ] + + for header, commands in categories: + table.add_row(header, "", "") + for cmd, desc, example in commands: + table.add_row(cmd, desc, example) + + from oai.constants import APP_VERSION + + display_panel( + table, + title=f"[bold cyan]oAI Chat Help (Version {APP_VERSION})[/]", + subtitle="💡 Use /help for details • /help mcp for MCP guide", + ) + + +class ClearCommand(Command): + """Clear the terminal screen.""" + + @property + def name(self) -> str: + return "/clear" + + @property + def aliases(self) -> List[str]: + return ["/cl"] + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Clear the terminal screen.", + usage="/clear", + aliases=["/cl"], + notes="You can also use Ctrl+L.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + clear_screen() + return CommandResult.success() + + +class MemoryCommand(Command): + """Toggle conversation memory.""" + + @property + def name(self) -> str: + return "/memory" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Toggle conversation memory.", + usage="/memory [on|off]", + examples=[ + ("Check status", "/memory"), + ("Enable memory", "/memory on"), + ("Disable memory", "/memory off"), + ], + notes="When off, each message is independent (saves tokens).", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + status = "enabled" if context.memory_enabled else "disabled" + print_info(f"Conversation memory is {status}.") + return CommandResult.success() + + if args.lower() == "on": + context.memory_enabled = True + print_success("Memory enabled - AI will remember conversation.") + return CommandResult.success(data={"memory_enabled": True}) + elif args.lower() == "off": + context.memory_enabled = False + context.memory_start_index = len(context.session_history) + print_success("Memory disabled - each message is independent.") + return CommandResult.success(data={"memory_enabled": False}) + else: + print_error("Usage: /memory [on|off]") + return CommandResult.error("Invalid argument") + + +class OnlineCommand(Command): + """Toggle online mode (web search).""" + + @property + def name(self) -> str: + return "/online" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Enable or disable online mode (web search).", + usage="/online [on|off]", + examples=[ + ("Check status", "/online"), + ("Enable web search", "/online on"), + ("Disable web search", "/online off"), + ], + notes="Not all models support online mode.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + status = "enabled" if context.online_enabled else "disabled" + print_info(f"Online mode is {status}.") + return CommandResult.success() + + if args.lower() == "on": + if context.selected_model_raw: + params = context.selected_model_raw.get("supported_parameters", []) + if "tools" not in params: + print_warning("Current model may not support online mode.") + + context.online_enabled = True + print_success("Online mode enabled - AI can search the web.") + return CommandResult.success(data={"online_enabled": True}) + elif args.lower() == "off": + context.online_enabled = False + print_success("Online mode disabled.") + return CommandResult.success(data={"online_enabled": False}) + else: + print_error("Usage: /online [on|off]") + return CommandResult.error("Invalid argument") + + +class ResetCommand(Command): + """Reset conversation history.""" + + @property + def name(self) -> str: + return "/reset" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Clear conversation history and reset system prompt.", + usage="/reset", + notes="Requires confirmation. Resets all session metrics.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not prompt_confirm("Reset conversation and clear history?"): + print_info("Reset cancelled.") + return CommandResult.success() + + context.session_history.clear() + context.session_system_prompt = "" + context.memory_start_index = 0 + context.current_index = 0 + context.total_input_tokens = 0 + context.total_output_tokens = 0 + context.total_cost = 0.0 + context.message_count = 0 + + print_success("Conversation reset. History and system prompt cleared.") + get_logger().info("Conversation reset by user") + return CommandResult.success() + + +class StatsCommand(Command): + """Display session statistics.""" + + @property + def name(self) -> str: + return "/stats" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display session statistics.", + usage="/stats", + notes="Shows tokens, costs, and credits.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from rich.table import Table + + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + table.add_row("Input Tokens", f"{context.total_input_tokens:,}") + table.add_row("Output Tokens", f"{context.total_output_tokens:,}") + table.add_row( + "Total Tokens", + f"{context.total_input_tokens + context.total_output_tokens:,}", + ) + table.add_row("Total Cost", f"${context.total_cost:.4f}") + + if context.message_count > 0: + avg_cost = context.total_cost / context.message_count + table.add_row("Avg Cost/Message", f"${avg_cost:.4f}") + + table.add_row("Messages", str(context.message_count)) + + # Get credits if provider available + if context.provider: + credits = context.provider.get_credits() + if credits: + table.add_row("", "") + table.add_row("[bold]Credits[/]", "") + table.add_row( + "Credits Left", + credits.get("credits_left_formatted", "N/A"), + ) + + display_panel(table, title="[bold green]Session Statistics[/]") + return CommandResult.success() + + +class CreditsCommand(Command): + """Display account credits.""" + + @property + def name(self) -> str: + return "/credits" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display your OpenRouter account credits.", + usage="/credits", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.provider: + print_error("No provider configured.") + return CommandResult.error("No provider") + + credits = context.provider.get_credits() + if not credits: + print_error("Failed to fetch credits.") + return CommandResult.error("Failed to fetch credits") + + from rich.table import Table + + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + table.add_row("Total Credits", credits.get("total_credits_formatted", "N/A")) + table.add_row("Used Credits", credits.get("used_credits_formatted", "N/A")) + table.add_row("Credits Left", credits.get("credits_left_formatted", "N/A")) + + # Check for warnings + from oai.constants import LOW_CREDIT_AMOUNT, LOW_CREDIT_RATIO + + credits_left = credits.get("credits_left", 0) + total = credits.get("total_credits", 0) + + warnings = [] + if credits_left < LOW_CREDIT_AMOUNT: + warnings.append(f"Less than ${LOW_CREDIT_AMOUNT:.2f} remaining!") + elif total > 0 and credits_left < total * LOW_CREDIT_RATIO: + warnings.append("Less than 10% of credits remaining!") + + display_panel(table, title="[bold green]Account Credits[/]") + + for warning in warnings: + print_warning(warning) + + return CommandResult.success(data=credits) + + +class ExportCommand(Command): + """Export conversation to file.""" + + @property + def name(self) -> str: + return "/export" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Export conversation to a file.", + usage="/export ", + examples=[ + ("Export as Markdown", "/export md notes.md"), + ("Export as JSON", "/export json conversation.json"), + ("Export as HTML", "/export html report.html"), + ], + notes="Available formats: md, json, html", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + parts = args.split(maxsplit=1) + if len(parts) != 2: + print_error("Usage: /export ") + return CommandResult.error("Invalid arguments") + + fmt, filename = parts + fmt = fmt.lower() + + if fmt not in ["md", "json", "html"]: + print_error(f"Unknown format: {fmt}") + print_info("Available formats: md, json, html") + return CommandResult.error("Invalid format") + + if not context.session_history: + print_warning("No conversation to export.") + return CommandResult.error("Empty history") + + try: + if fmt == "md": + content = export_as_markdown( + context.session_history, + context.session_system_prompt, + ) + elif fmt == "json": + content = export_as_json( + context.session_history, + context.session_system_prompt, + ) + else: # html + content = export_as_html( + context.session_history, + context.session_system_prompt, + ) + + with open(filename, "w", encoding="utf-8") as f: + f.write(content) + + print_success(f"Exported to {filename}") + get_logger().info(f"Exported conversation to {filename} ({fmt})") + return CommandResult.success() + + except Exception as e: + print_error(f"Export failed: {e}") + return CommandResult.error(str(e)) + + +class MiddleOutCommand(Command): + """Toggle middle-out compression.""" + + @property + def name(self) -> str: + return "/middleout" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Toggle middle-out transform for long prompts.", + usage="/middleout [on|off]", + examples=[ + ("Check status", "/middleout"), + ("Enable", "/middleout on"), + ("Disable", "/middleout off"), + ], + notes="Compresses prompts exceeding context size.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + status = "enabled" if context.middle_out_enabled else "disabled" + print_info(f"Middle-out compression is {status}.") + return CommandResult.success() + + if args.lower() == "on": + context.middle_out_enabled = True + print_success("Middle-out compression enabled.") + return CommandResult.success(data={"middle_out_enabled": True}) + elif args.lower() == "off": + context.middle_out_enabled = False + print_success("Middle-out compression disabled.") + return CommandResult.success(data={"middle_out_enabled": False}) + else: + print_error("Usage: /middleout [on|off]") + return CommandResult.error("Invalid argument") + + +class MaxTokenCommand(Command): + """Set session token limit.""" + + @property + def name(self) -> str: + return "/maxtoken" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Set session token limit.", + usage="/maxtoken [value]", + examples=[ + ("View current", "/maxtoken"), + ("Set to 2000", "/maxtoken 2000"), + ], + notes="Cannot exceed stored max token setting.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + if context.session_max_token > 0: + print_info(f"Session max token: {context.session_max_token}") + else: + print_info("No session token limit set (using model default).") + return CommandResult.success() + + try: + value = int(args) + if value <= 0: + print_error("Token limit must be positive.") + return CommandResult.error("Invalid value") + + # Check against stored limit + stored_max = 100000 # Default + if context.settings: + stored_max = context.settings.max_tokens + + if value > stored_max: + print_warning( + f"Value {value} exceeds stored limit {stored_max}. " + f"Using {stored_max}." + ) + value = stored_max + + context.session_max_token = value + print_success(f"Session max token set to {value}.") + return CommandResult.success(data={"session_max_token": value}) + + except ValueError: + print_error("Please enter a valid number.") + return CommandResult.error("Invalid number") + + +class SystemCommand(Command): + """Set session system prompt.""" + + @property + def name(self) -> str: + return "/system" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Set or clear the session system prompt.", + usage="/system [prompt|clear|default ]", + examples=[ + ("View current", "/system"), + ("Set prompt", "/system You are a Python expert"), + ("Multiline prompt", r"/system You are an expert.\nRespond clearly."), + ("Blank prompt", '/system ""'), + ("Save as default", "/system default You are a Python expert"), + ("Revert to default", "/system clear"), + ], + notes=r'Use \n for newlines. Use /system "" for blank, /system clear to revert to default.', + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if not args: + # Show current session prompt and default + if context.session_system_prompt: + print_info(f"Session prompt: {context.session_system_prompt}") + else: + print_info("Session prompt: [blank]") + + if context.settings: + if context.settings.default_system_prompt is None: + print_info(f"Default prompt: [hardcoded] {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif context.settings.default_system_prompt == "": + print_info("Default prompt: [blank]") + else: + print_info(f"Custom default: {context.settings.default_system_prompt}") + return CommandResult.success() + + if args.lower() == "clear": + # Revert to hardcoded default + if context.settings: + context.settings.clear_default_system_prompt() + context.session_system_prompt = DEFAULT_SYSTEM_PROMPT + print_success("Reverted to hardcoded default system prompt.") + print_info(f"Default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + else: + context.session_system_prompt = DEFAULT_SYSTEM_PROMPT + print_success("Session prompt reverted to default.") + return CommandResult.success() + + # Check for default command + if args.lower().startswith("default "): + prompt = args[8:] # Remove "default " prefix (keep trailing spaces) + if not prompt: + print_error("Usage: /system default ") + return CommandResult.error("No prompt provided") + + # Decode escape sequences like \n for newlines + prompt = prompt.encode().decode('unicode_escape') + + if context.settings: + context.settings.set_default_system_prompt(prompt) + context.session_system_prompt = prompt + if prompt: + print_success(f"Default system prompt saved: {prompt}") + else: + print_success("Default system prompt set to blank.") + get_logger().info(f"Default system prompt updated: {prompt[:50]}...") + else: + print_error("Settings not available") + return CommandResult.error("No settings") + return CommandResult.success() + + # Decode escape sequences like \n for newlines + prompt = args.encode().decode('unicode_escape') + + context.session_system_prompt = prompt + if prompt: + print_success(f"Session system prompt set: {prompt}") + print_info("Use '/system default ' to save as default for all sessions") + else: + print_success("Session system prompt set to blank.") + get_logger().info(f"System prompt updated: {prompt[:50]}...") + return CommandResult.success() + + +class RetryCommand(Command): + """Resend the last prompt.""" + + @property + def name(self) -> str: + return "/retry" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Resend the last prompt.", + usage="/retry", + notes="Requires at least one message in history.", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No message to retry.") + return CommandResult.error("Empty history") + + last_prompt = context.session_history[-1].get("prompt", "") + if not last_prompt: + print_error("Last message has no prompt.") + return CommandResult.error("No prompt") + + # Return the prompt to be re-sent + print_info(f"Retrying: {last_prompt[:50]}...") + return CommandResult.success(data={"retry_prompt": last_prompt}) + + +class PrevCommand(Command): + """View previous response.""" + + @property + def name(self) -> str: + return "/prev" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View previous response in history.", + usage="/prev", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No history to navigate.") + return CommandResult.error("Empty history") + + if context.current_index <= 0: + print_info("Already at the beginning of history.") + return CommandResult.success() + + context.current_index -= 1 + entry = context.session_history[context.current_index] + + console.print(f"\n[dim]Message {context.current_index + 1}/{len(context.session_history)}[/]") + console.print(f"\n[bold cyan]Prompt:[/] {entry.get('prompt', '')}") + display_markdown(entry.get("response", ""), panel=True, title="Response") + + return CommandResult.success() + + +class NextCommand(Command): + """View next response.""" + + @property + def name(self) -> str: + return "/next" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View next response in history.", + usage="/next", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.session_history: + print_error("No history to navigate.") + return CommandResult.error("Empty history") + + if context.current_index >= len(context.session_history) - 1: + print_info("Already at the end of history.") + return CommandResult.success() + + context.current_index += 1 + entry = context.session_history[context.current_index] + + console.print(f"\n[dim]Message {context.current_index + 1}/{len(context.session_history)}[/]") + console.print(f"\n[bold cyan]Prompt:[/] {entry.get('prompt', '')}") + display_markdown(entry.get("response", ""), panel=True, title="Response") + + return CommandResult.success() + + +class ConfigCommand(Command): + """View or modify configuration.""" + + @property + def name(self) -> str: + return "/config" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="View or modify application configuration.", + usage="/config [setting] [value]", + examples=[ + ("View all settings", "/config"), + ("Set API key", "/config api"), + ("Enable streaming", "/config stream on"), + ], + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from rich.table import Table + + if not context.settings: + print_error("Settings not available") + return CommandResult.error("No settings") + + settings = context.settings + + if not args: + # Show all settings + from oai.constants import DEFAULT_SYSTEM_PROMPT + + table = Table("Setting", "Value", show_header=True, header_style="bold magenta") + table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set") + table.add_row("Base URL", settings.base_url) + table.add_row("Default Model", settings.default_model or "Not set") + + # Show system prompt status + if settings.default_system_prompt is None: + system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..." + elif settings.default_system_prompt == "": + system_prompt_display = "[blank]" + else: + system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt + table.add_row("System Prompt", system_prompt_display) + + table.add_row("Streaming", "on" if settings.stream_enabled else "off") + table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}") + table.add_row("Max Tokens", str(settings.max_tokens)) + table.add_row("Default Online", "on" if settings.default_online_mode else "off") + table.add_row("Log Level", settings.log_level) + + display_panel(table, title="[bold green]Configuration[/]") + return CommandResult.success() + + parts = args.split(maxsplit=1) + setting = parts[0].lower() + value = parts[1] if len(parts) > 1 else None + + if setting == "api": + if value: + settings.set_api_key(value) + else: + new_key = prompt_input("Enter API key", password=True) + if new_key: + settings.set_api_key(new_key) + print_success("API key updated") + + elif setting == "stream": + if value and value.lower() in ["on", "off"]: + settings.set_stream_enabled(value.lower() == "on") + print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}") + else: + print_info("Usage: /config stream [on|off]") + + elif setting == "model": + if value: + # Show model selector with search term, same as /model + return CommandResult.success(data={"show_model_selector": True, "search": value, "set_as_default": True}) + else: + print_info(f"Current: {settings.default_model or 'Not set'}") + + elif setting == "system": + from oai.constants import DEFAULT_SYSTEM_PROMPT + + if value: + # Decode escape sequences like \n for newlines + value = value.encode().decode('unicode_escape') + settings.set_default_system_prompt(value) + if value: + print_success(f"Default system prompt set to: {value}") + else: + print_success("Default system prompt set to blank.") + else: + if settings.default_system_prompt is None: + print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...") + elif settings.default_system_prompt == "": + print_info("System prompt: [blank]") + else: + print_info(f"System prompt: {settings.default_system_prompt}") + + elif setting == "online": + if value and value.lower() in ["on", "off"]: + settings.set_default_online_mode(value.lower() == "on") + print_success(f"Default online {'enabled' if settings.default_online_mode else 'disabled'}") + else: + print_info("Usage: /config online [on|off]") + + elif setting == "costwarning": + if value: + try: + settings.set_cost_warning_threshold(float(value)) + print_success(f"Cost warning set to: ${float(value):.4f}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current: ${settings.cost_warning_threshold:.4f}") + + elif setting == "maxtoken": + if value: + try: + settings.set_max_tokens(int(value)) + print_success(f"Max tokens set to: {value}") + except ValueError: + print_error("Please enter a valid number") + else: + print_info(f"Current: {settings.max_tokens}") + + elif setting == "loglevel": + valid_levels = ["debug", "info", "warning", "error", "critical"] + if value and value.lower() in valid_levels: + settings.set_log_level(value.lower()) + print_success(f"Log level set to: {value.lower()}") + else: + print_info(f"Valid levels: {', '.join(valid_levels)}") + + else: + print_error(f"Unknown setting: {setting}") + return CommandResult.error("Unknown setting") + + return CommandResult.success() + + +class ListCommand(Command): + """List saved conversations.""" + + @property + def name(self) -> str: + return "/list" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="List all saved conversations.", + usage="/list", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from oai.config.database import Database + from rich.table import Table + + db = Database() + conversations = db.list_conversations() + + if not conversations: + print_info("No saved conversations.") + return CommandResult.success() + + table = Table("No.", "Name", "Messages", "Last Saved", show_header=True, header_style="bold magenta") + + for i, conv in enumerate(conversations, 1): + table.add_row( + str(i), + conv["name"], + str(conv["message_count"]), + conv["timestamp"][:19] if conv.get("timestamp") else "-", + ) + + display_panel(table, title="[bold green]Saved Conversations[/]") + return CommandResult.success(data={"conversations": conversations}) + + +class SaveCommand(Command): + """Save conversation.""" + + @property + def name(self) -> str: + return "/save" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Save the current conversation.", + usage="/save ", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + print_error("Usage: /save ") + return CommandResult.error("Missing name") + + if not context.session_history: + print_warning("No conversation to save.") + return CommandResult.error("Empty history") + + from oai.config.database import Database + + db = Database() + db.save_conversation(args, context.session_history) + print_success(f"Conversation saved as '{args}'") + get_logger().info(f"Saved conversation: {args}") + return CommandResult.success() + + +class LoadCommand(Command): + """Load saved conversation.""" + + @property + def name(self) -> str: + return "/load" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Load a saved conversation.", + usage="/load ", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + print_error("Usage: /load ") + return CommandResult.error("Missing name") + + from oai.config.database import Database + + db = Database() + + # Check if it's a number + name = args + if args.isdigit(): + conversations = db.list_conversations() + index = int(args) - 1 + if 0 <= index < len(conversations): + name = conversations[index]["name"] + else: + print_error(f"Invalid number. Use 1-{len(conversations)}") + return CommandResult.error("Invalid number") + + data = db.load_conversation(name) + if not data: + print_error(f"Conversation '{name}' not found") + return CommandResult.error("Not found") + + print_success(f"Loaded conversation '{name}' ({len(data)} messages)") + get_logger().info(f"Loaded conversation: {name}") + return CommandResult.success(data={"load_conversation": name, "history": data}) + + +class DeleteCommand(Command): + """Delete saved conversation.""" + + @property + def name(self) -> str: + return "/delete" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Delete a saved conversation.", + usage="/delete ", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not args: + print_error("Usage: /delete ") + return CommandResult.error("Missing name") + + from oai.config.database import Database + + db = Database() + + # Check if it's a number + name = args + if args.isdigit(): + conversations = db.list_conversations() + index = int(args) - 1 + if 0 <= index < len(conversations): + name = conversations[index]["name"] + else: + print_error(f"Invalid number. Use 1-{len(conversations)}") + return CommandResult.error("Invalid number") + + if not prompt_confirm(f"Delete conversation '{name}'?"): + print_info("Cancelled") + return CommandResult.success() + + count = db.delete_conversation(name) + if count > 0: + print_success(f"Deleted conversation '{name}'") + get_logger().info(f"Deleted conversation: {name}") + else: + print_error(f"Conversation '{name}' not found") + return CommandResult.error("Not found") + + return CommandResult.success() + + +class InfoCommand(Command): + """Show model information.""" + + @property + def name(self) -> str: + return "/info" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Display detailed model information.", + usage="/info [model_id]", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + from rich.table import Table + + if not context.provider: + print_error("No provider available") + return CommandResult.error("No provider") + + model_id = args.strip() if args else None + + if not model_id and context.selected_model_raw: + model_id = context.selected_model_raw.get("id") + + if not model_id: + print_error("No model specified. Use /info or select a model first.") + return CommandResult.error("No model") + + # Get raw model data + model = None + if hasattr(context.provider, "get_raw_model"): + model = context.provider.get_raw_model(model_id) + + if not model: + print_error(f"Model '{model_id}' not found") + return CommandResult.error("Not found") + + table = Table("Property", "Value", show_header=True, header_style="bold magenta") + table.add_row("ID", model.get("id", "")) + table.add_row("Name", model.get("name", "")) + table.add_row("Context Length", f"{model.get('context_length', 0):,}") + + # Pricing + pricing = model.get("pricing", {}) + if pricing: + prompt_price = float(pricing.get("prompt", 0)) * 1_000_000 + completion_price = float(pricing.get("completion", 0)) * 1_000_000 + table.add_row("Input Price", f"${prompt_price:.2f}/M tokens") + table.add_row("Output Price", f"${completion_price:.2f}/M tokens") + + # Capabilities + arch = model.get("architecture", {}) + input_mod = arch.get("input_modalities", []) + output_mod = arch.get("output_modalities", []) + supported = model.get("supported_parameters", []) + + table.add_row("Input Modalities", ", ".join(input_mod) if input_mod else "text") + table.add_row("Output Modalities", ", ".join(output_mod) if output_mod else "text") + table.add_row("Image Support", "✓" if "image" in input_mod else "✗") + table.add_row("Tool Support", "✓" if "tools" in supported else "✗") + table.add_row("Online Support", "✓" if "tools" in supported else "✗") + + display_panel(table, title=f"[bold green]Model: {model_id}[/]") + return CommandResult.success() + + +class MCPCommand(Command): + """MCP management command.""" + + @property + def name(self) -> str: + return "/mcp" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Manage MCP (Model Context Protocol).", + usage="/mcp [args]", + examples=[ + ("Enable MCP", "/mcp on"), + ("Show status", "/mcp status"), + ("Add folder", "/mcp add ~/Documents"), + ], + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + if not context.mcp_manager: + print_error("MCP not available") + return CommandResult.error("No MCP") + + mcp = context.mcp_manager + parts = args.strip().split(maxsplit=1) + cmd = parts[0].lower() if parts else "" + cmd_args = parts[1] if len(parts) > 1 else "" + + if cmd in ["on", "enable"]: + result = mcp.enable() + if result["success"]: + print_success(result.get("message", "MCP enabled")) + if result.get("folder_count", 0) == 0: + print_info("Add folders with: /mcp add ~/Documents") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd in ["off", "disable"]: + result = mcp.disable() + if result["success"]: + print_success("MCP disabled") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "status": + result = mcp.get_status() + if result["success"]: + from rich.table import Table + table = Table("Property", "Value", show_header=True, header_style="bold magenta") + + status = "[green]Active ✓[/]" if result["enabled"] else "[red]Inactive ✗[/]" + table.add_row("Status", status) + table.add_row("Mode", result.get("mode_info", {}).get("mode_display", "files")) + table.add_row("Folders", str(result.get("folder_count", 0))) + table.add_row("Databases", str(result.get("database_count", 0))) + table.add_row("Write Mode", "[green]Enabled[/]" if result.get("write_enabled") else "[dim]Disabled[/]") + table.add_row(".gitignore", result.get("gitignore_status", "on")) + + display_panel(table, title="[bold green]MCP Status[/]") + return CommandResult.success() + + elif cmd == "add": + if cmd_args.startswith("db "): + db_path = cmd_args[3:].strip() + result = mcp.add_database(db_path) + else: + result = mcp.add_folder(cmd_args) + + if result["success"]: + print_success(f"Added: {cmd_args}") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd in ["remove", "rem"]: + if cmd_args.startswith("db "): + result = mcp.remove_database(cmd_args[3:].strip()) + else: + result = mcp.remove_folder(cmd_args) + + if result["success"]: + print_success(f"Removed: {cmd_args}") + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "list": + result = mcp.list_folders() + if result["success"]: + from rich.table import Table + table = Table("No.", "Path", "Files", "Size", show_header=True, header_style="bold magenta") + for f in result.get("folders", []): + table.add_row( + str(f["number"]), + f["path"], + str(f.get("file_count", 0)), + f"{f.get('size_mb', 0):.1f} MB", + ) + display_panel(table, title="[bold green]MCP Folders[/]") + return CommandResult.success() + + elif cmd == "db": + if cmd_args == "list": + result = mcp.list_databases() + if result["success"]: + from rich.table import Table + table = Table("No.", "Name", "Tables", "Size", show_header=True, header_style="bold magenta") + for db in result.get("databases", []): + table.add_row( + str(db["number"]), + db["name"], + str(db.get("table_count", 0)), + f"{db.get('size_mb', 0):.1f} MB", + ) + display_panel(table, title="[bold green]MCP Databases[/]") + elif cmd_args.isdigit(): + result = mcp.switch_mode("database", int(cmd_args)) + if result["success"]: + print_success(result.get("message", "Switched to database mode")) + else: + print_error(result.get("error", "Failed")) + return CommandResult.success() + + elif cmd == "files": + result = mcp.switch_mode("files") + if result["success"]: + print_success("Switched to file mode") + return CommandResult.success() + + elif cmd == "write": + if cmd_args.lower() == "on": + mcp.enable_write() + elif cmd_args.lower() == "off": + mcp.disable_write() + else: + status = "[green]Enabled[/]" if mcp.write_enabled else "[dim]Disabled[/]" + print_info(f"Write mode: {status}") + return CommandResult.success() + + elif cmd == "gitignore": + if cmd_args.lower() in ["on", "off"]: + result = mcp.toggle_gitignore(cmd_args.lower() == "on") + if result["success"]: + print_success(result.get("message", "Updated")) + else: + print_info("Usage: /mcp gitignore [on|off]") + return CommandResult.success() + + else: + print_error(f"Unknown MCP command: {cmd}") + print_info("Commands: on, off, status, add, remove, list, db, files, write, gitignore") + return CommandResult.error("Unknown command") + + +class PasteCommand(Command): + """Paste from clipboard.""" + + @property + def name(self) -> str: + return "/paste" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Paste from clipboard and send to AI.", + usage="/paste [prompt]", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + try: + import pyperclip + content = pyperclip.paste() + except ImportError: + print_error("pyperclip not installed") + return CommandResult.error("pyperclip not installed") + except Exception as e: + print_error(f"Failed to access clipboard: {e}") + return CommandResult.error(str(e)) + + if not content: + print_warning("Clipboard is empty") + return CommandResult.error("Empty clipboard") + + # Show preview + preview = content[:200] + "..." if len(content) > 200 else content + console.print(f"\n[dim]Clipboard content ({len(content)} chars):[/]") + console.print(f"[yellow]{preview}[/]\n") + + # Build the prompt + if args: + full_prompt = f"{args}\n\n```\n{content}\n```" + else: + full_prompt = content + + return CommandResult.success(data={"paste_prompt": full_prompt}) + + +class ModelCommand(Command): + """Select AI model.""" + + @property + def name(self) -> str: + return "/model" + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description="Select or search for AI models.", + usage="/model [search_term]", + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + # This is handled specially in the CLI, but we need a handler + # to prevent it from being sent to the AI + return CommandResult.success(data={"show_model_selector": True, "search": args}) + + +def register_all_commands() -> None: + """Register all built-in commands with the global registry.""" + commands = [ + HelpCommand(), + ClearCommand(), + MemoryCommand(), + OnlineCommand(), + ResetCommand(), + StatsCommand(), + CreditsCommand(), + ExportCommand(), + MiddleOutCommand(), + MaxTokenCommand(), + SystemCommand(), + RetryCommand(), + PrevCommand(), + NextCommand(), + ConfigCommand(), + ListCommand(), + SaveCommand(), + LoadCommand(), + DeleteCommand(), + InfoCommand(), + MCPCommand(), + PasteCommand(), + ModelCommand(), + ] + + for command in commands: + try: + registry.register(command) + except ValueError as e: + get_logger().warning(f"Failed to register command: {e}") diff --git a/oai/commands/registry.py b/oai/commands/registry.py new file mode 100644 index 0000000..9be81ad --- /dev/null +++ b/oai/commands/registry.py @@ -0,0 +1,381 @@ +""" +Command registry for oAI. + +This module defines the command system infrastructure including +the Command base class, CommandContext for state, and CommandRegistry +for managing available commands. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING + +from oai.utils.logging import get_logger + +if TYPE_CHECKING: + from oai.config.settings import Settings + from oai.providers.base import AIProvider, ModelInfo + from oai.mcp.manager import MCPManager + + +class CommandStatus(str, Enum): + """Status of command execution.""" + + SUCCESS = "success" + ERROR = "error" + CONTINUE = "continue" # Continue to next handler + EXIT = "exit" # Exit the application + + +@dataclass +class CommandResult: + """ + Result of a command execution. + + Attributes: + status: Execution status + message: Optional message to display + data: Optional data payload + should_continue: Whether to continue the main loop + """ + + status: CommandStatus = CommandStatus.SUCCESS + message: Optional[str] = None + data: Optional[Any] = None + should_continue: bool = True + + @classmethod + def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult": + """Create a success result.""" + return cls(status=CommandStatus.SUCCESS, message=message, data=data) + + @classmethod + def error(cls, message: str) -> "CommandResult": + """Create an error result.""" + return cls(status=CommandStatus.ERROR, message=message) + + @classmethod + def exit(cls, message: Optional[str] = None) -> "CommandResult": + """Create an exit result.""" + return cls(status=CommandStatus.EXIT, message=message, should_continue=False) + + +@dataclass +class CommandContext: + """ + Context object providing state to command handlers. + + Contains all the session state needed by commands including + settings, provider, conversation history, and MCP manager. + + Attributes: + settings: Application settings + provider: AI provider instance + mcp_manager: MCP manager instance + selected_model: Currently selected model + session_history: Conversation history + session_system_prompt: Current system prompt + memory_enabled: Whether memory is enabled + online_enabled: Whether online mode is enabled + session_tokens: Session token counts + session_cost: Session cost total + """ + + settings: Optional["Settings"] = None + provider: Optional["AIProvider"] = None + mcp_manager: Optional["MCPManager"] = None + selected_model: Optional["ModelInfo"] = None + selected_model_raw: Optional[Dict[str, Any]] = None + session_history: List[Dict[str, Any]] = field(default_factory=list) + session_system_prompt: str = "" + memory_enabled: bool = True + memory_start_index: int = 0 + online_enabled: bool = False + middle_out_enabled: bool = False + session_max_token: int = 0 + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_cost: float = 0.0 + message_count: int = 0 + current_index: int = 0 + + +@dataclass +class CommandHelp: + """ + Help information for a command. + + Attributes: + description: Brief description + usage: Usage syntax + examples: List of (description, example) tuples + notes: Additional notes + aliases: Command aliases + """ + + description: str + usage: str = "" + examples: List[tuple] = field(default_factory=list) + notes: str = "" + aliases: List[str] = field(default_factory=list) + + +class Command(ABC): + """ + Abstract base class for all commands. + + Commands implement the execute method to handle their logic. + They can also provide help information and aliases. + """ + + @property + @abstractmethod + def name(self) -> str: + """Get the primary command name (e.g., '/help').""" + pass + + @property + def aliases(self) -> List[str]: + """Get command aliases (e.g., ['/h'] for help).""" + return [] + + @property + @abstractmethod + def help(self) -> CommandHelp: + """Get command help information.""" + pass + + @abstractmethod + def execute(self, args: str, context: CommandContext) -> CommandResult: + """ + Execute the command. + + Args: + args: Arguments passed to the command + context: Command execution context + + Returns: + CommandResult indicating success/failure + """ + pass + + def matches(self, input_text: str) -> bool: + """ + Check if this command matches the input. + + Args: + input_text: User input text + + Returns: + True if this command should handle the input + """ + input_lower = input_text.lower() + cmd_word = input_lower.split()[0] if input_lower.split() else "" + + # Check primary name + if cmd_word == self.name.lower(): + return True + + # Check aliases + for alias in self.aliases: + if cmd_word == alias.lower(): + return True + + return False + + def get_args(self, input_text: str) -> str: + """ + Extract arguments from the input text. + + Args: + input_text: Full user input + + Returns: + Arguments portion of the input + """ + parts = input_text.split(maxsplit=1) + return parts[1] if len(parts) > 1 else "" + + +class CommandRegistry: + """ + Registry for managing available commands. + + Provides registration, lookup, and execution of commands. + """ + + def __init__(self): + """Initialize an empty command registry.""" + self._commands: Dict[str, Command] = {} + self._aliases: Dict[str, str] = {} + self.logger = get_logger() + + def register(self, command: Command) -> None: + """ + Register a command. + + Args: + command: Command instance to register + + Raises: + ValueError: If command name already registered + """ + name = command.name.lower() + + if name in self._commands: + raise ValueError(f"Command '{name}' already registered") + + self._commands[name] = command + + # Register aliases + for alias in command.aliases: + alias_lower = alias.lower() + if alias_lower in self._aliases: + self.logger.warning( + f"Alias '{alias}' already registered, overwriting" + ) + self._aliases[alias_lower] = name + + self.logger.debug(f"Registered command: {name}") + + def register_function( + self, + name: str, + handler: Callable[[str, CommandContext], CommandResult], + description: str, + usage: str = "", + aliases: Optional[List[str]] = None, + examples: Optional[List[tuple]] = None, + notes: str = "", + ) -> None: + """ + Register a function-based command. + + Convenience method for simple commands that don't need + a full Command class. + + Args: + name: Command name (e.g., '/help') + handler: Function to execute + description: Help description + usage: Usage syntax + aliases: Command aliases + examples: Example usages + notes: Additional notes + """ + aliases = aliases or [] + examples = examples or [] + + class FunctionCommand(Command): + @property + def name(self) -> str: + return name + + @property + def aliases(self) -> List[str]: + return aliases + + @property + def help(self) -> CommandHelp: + return CommandHelp( + description=description, + usage=usage, + examples=examples, + notes=notes, + aliases=aliases, + ) + + def execute(self, args: str, context: CommandContext) -> CommandResult: + return handler(args, context) + + self.register(FunctionCommand()) + + def get(self, name: str) -> Optional[Command]: + """ + Get a command by name or alias. + + Args: + name: Command name or alias + + Returns: + Command instance or None if not found + """ + name_lower = name.lower() + + # Check direct match + if name_lower in self._commands: + return self._commands[name_lower] + + # Check aliases + if name_lower in self._aliases: + return self._commands[self._aliases[name_lower]] + + return None + + def find(self, input_text: str) -> Optional[Command]: + """ + Find a command that matches the input. + + Args: + input_text: User input text + + Returns: + Matching Command or None + """ + cmd_word = input_text.lower().split()[0] if input_text.split() else "" + return self.get(cmd_word) + + def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]: + """ + Execute a command matching the input. + + Args: + input_text: User input text + context: Execution context + + Returns: + CommandResult or None if no matching command + """ + command = self.find(input_text) + if command: + args = command.get_args(input_text) + self.logger.debug(f"Executing command: {command.name} with args: {args}") + return command.execute(args, context) + return None + + def is_command(self, input_text: str) -> bool: + """ + Check if input is a valid command. + + Args: + input_text: User input text + + Returns: + True if input matches a registered command + """ + return self.find(input_text) is not None + + def list_commands(self) -> List[Command]: + """ + Get all registered commands. + + Returns: + List of Command instances + """ + return list(self._commands.values()) + + def get_all_names(self) -> List[str]: + """ + Get all command names and aliases. + + Returns: + List of command names including aliases + """ + names = list(self._commands.keys()) + names.extend(self._aliases.keys()) + return sorted(set(names)) + + +# Global registry instance +registry = CommandRegistry() diff --git a/oai/config/__init__.py b/oai/config/__init__.py new file mode 100644 index 0000000..d12b503 --- /dev/null +++ b/oai/config/__init__.py @@ -0,0 +1,11 @@ +""" +Configuration management for oAI. + +This package handles all configuration persistence, settings management, +and database operations for the application. +""" + +from oai.config.settings import Settings +from oai.config.database import Database + +__all__ = ["Settings", "Database"] diff --git a/oai/config/database.py b/oai/config/database.py new file mode 100644 index 0000000..a6a3496 --- /dev/null +++ b/oai/config/database.py @@ -0,0 +1,472 @@ +""" +Database persistence layer for oAI. + +This module provides a clean abstraction for SQLite operations including +configuration storage, conversation persistence, and MCP statistics tracking. +All database operations are centralized here for maintainability. +""" + +import sqlite3 +import json +import datetime +from pathlib import Path +from typing import Optional, List, Dict, Any +from contextlib import contextmanager + +from oai.constants import DATABASE_FILE, CONFIG_DIR + + +class Database: + """ + SQLite database manager for oAI. + + Handles all database operations including: + - Configuration key-value storage + - Conversation session persistence + - MCP configuration and statistics + - Database registrations for MCP + + Uses context managers for safe connection handling and supports + automatic table creation on first use. + """ + + def __init__(self, db_path: Optional[Path] = None): + """ + Initialize the database manager. + + Args: + db_path: Optional custom database path. Defaults to standard location. + """ + self.db_path = db_path or DATABASE_FILE + self._ensure_directories() + self._ensure_tables() + + def _ensure_directories(self) -> None: + """Ensure the configuration directory exists.""" + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + @contextmanager + def _connection(self): + """ + Context manager for database connections. + + Yields: + sqlite3.Connection: Active database connection + + Example: + with self._connection() as conn: + conn.execute("SELECT * FROM config") + """ + conn = sqlite3.connect(str(self.db_path)) + try: + yield conn + conn.commit() + except Exception: + conn.rollback() + raise + finally: + conn.close() + + def _ensure_tables(self) -> None: + """Create all required tables if they don't exist.""" + with self._connection() as conn: + # Main configuration table + conn.execute(""" + CREATE TABLE IF NOT EXISTS config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + """) + + # Conversation sessions table + conn.execute(""" + CREATE TABLE IF NOT EXISTS conversation_sessions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name TEXT NOT NULL, + timestamp TEXT NOT NULL, + data TEXT NOT NULL + ) + """) + + # MCP configuration table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_config ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ) + """) + + # MCP statistics table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_stats ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + timestamp TEXT NOT NULL, + tool_name TEXT NOT NULL, + folder TEXT, + success INTEGER NOT NULL, + error_message TEXT + ) + """) + + # MCP databases table + conn.execute(""" + CREATE TABLE IF NOT EXISTS mcp_databases ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + path TEXT NOT NULL UNIQUE, + name TEXT NOT NULL, + size INTEGER, + tables TEXT, + added_timestamp TEXT NOT NULL + ) + """) + + # ========================================================================= + # CONFIGURATION METHODS + # ========================================================================= + + def get_config(self, key: str) -> Optional[str]: + """ + Retrieve a configuration value by key. + + Args: + key: The configuration key to retrieve + + Returns: + The configuration value, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + "SELECT value FROM config WHERE key = ?", + (key,) + ) + result = cursor.fetchone() + return result[0] if result else None + + def set_config(self, key: str, value: str) -> None: + """ + Set a configuration value. + + Args: + key: The configuration key + value: The value to store + """ + with self._connection() as conn: + conn.execute( + "INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)", + (key, value) + ) + + def delete_config(self, key: str) -> bool: + """ + Delete a configuration value. + + Args: + key: The configuration key to delete + + Returns: + True if a row was deleted, False otherwise + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM config WHERE key = ?", + (key,) + ) + return cursor.rowcount > 0 + + def get_all_config(self) -> Dict[str, str]: + """ + Retrieve all configuration values. + + Returns: + Dictionary of all key-value pairs + """ + with self._connection() as conn: + cursor = conn.execute("SELECT key, value FROM config") + return dict(cursor.fetchall()) + + # ========================================================================= + # MCP CONFIGURATION METHODS + # ========================================================================= + + def get_mcp_config(self, key: str) -> Optional[str]: + """ + Retrieve an MCP configuration value. + + Args: + key: The MCP configuration key + + Returns: + The configuration value, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + "SELECT value FROM mcp_config WHERE key = ?", + (key,) + ) + result = cursor.fetchone() + return result[0] if result else None + + def set_mcp_config(self, key: str, value: str) -> None: + """ + Set an MCP configuration value. + + Args: + key: The MCP configuration key + value: The value to store + """ + with self._connection() as conn: + conn.execute( + "INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)", + (key, value) + ) + + # ========================================================================= + # MCP STATISTICS METHODS + # ========================================================================= + + def log_mcp_stat( + self, + tool_name: str, + folder: Optional[str], + success: bool, + error_message: Optional[str] = None + ) -> None: + """ + Log an MCP tool usage event. + + Args: + tool_name: Name of the MCP tool that was called + folder: The folder path involved (if any) + success: Whether the call succeeded + error_message: Error message if the call failed + """ + timestamp = datetime.datetime.now().isoformat() + with self._connection() as conn: + conn.execute( + """INSERT INTO mcp_stats + (timestamp, tool_name, folder, success, error_message) + VALUES (?, ?, ?, ?, ?)""", + (timestamp, tool_name, folder, 1 if success else 0, error_message) + ) + + def get_mcp_stats(self) -> Dict[str, Any]: + """ + Get aggregated MCP usage statistics. + + Returns: + Dictionary containing usage statistics: + - total_calls: Total number of tool calls + - reads: Number of file reads + - lists: Number of directory listings + - searches: Number of file searches + - db_inspects: Number of database inspections + - db_searches: Number of database searches + - db_queries: Number of database queries + - last_used: Timestamp of last usage + """ + with self._connection() as conn: + cursor = conn.execute(""" + SELECT + COUNT(*) as total_calls, + SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads, + SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists, + SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches, + SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects, + SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches, + SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries, + MAX(timestamp) as last_used + FROM mcp_stats + """) + row = cursor.fetchone() + return { + "total_calls": row[0] or 0, + "reads": row[1] or 0, + "lists": row[2] or 0, + "searches": row[3] or 0, + "db_inspects": row[4] or 0, + "db_searches": row[5] or 0, + "db_queries": row[6] or 0, + "last_used": row[7], + } + + # ========================================================================= + # MCP DATABASE REGISTRY METHODS + # ========================================================================= + + def add_mcp_database(self, db_info: Dict[str, Any]) -> int: + """ + Register a database for MCP access. + + Args: + db_info: Dictionary containing: + - path: Database file path + - name: Display name + - size: File size in bytes + - tables: List of table names + - added: Timestamp when added + + Returns: + The database ID + """ + with self._connection() as conn: + conn.execute( + """INSERT INTO mcp_databases + (path, name, size, tables, added_timestamp) + VALUES (?, ?, ?, ?, ?)""", + ( + db_info["path"], + db_info["name"], + db_info["size"], + json.dumps(db_info["tables"]), + db_info["added"] + ) + ) + cursor = conn.execute( + "SELECT id FROM mcp_databases WHERE path = ?", + (db_info["path"],) + ) + return cursor.fetchone()[0] + + def remove_mcp_database(self, db_path: str) -> bool: + """ + Remove a database from the MCP registry. + + Args: + db_path: Path to the database file + + Returns: + True if a row was deleted, False otherwise + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM mcp_databases WHERE path = ?", + (db_path,) + ) + return cursor.rowcount > 0 + + def get_mcp_databases(self) -> List[Dict[str, Any]]: + """ + Retrieve all registered MCP databases. + + Returns: + List of database information dictionaries + """ + with self._connection() as conn: + cursor = conn.execute( + """SELECT id, path, name, size, tables, added_timestamp + FROM mcp_databases ORDER BY id""" + ) + databases = [] + for row in cursor.fetchall(): + tables_list = json.loads(row[4]) if row[4] else [] + databases.append({ + "id": row[0], + "path": row[1], + "name": row[2], + "size": row[3], + "tables": tables_list, + "added": row[5], + }) + return databases + + # ========================================================================= + # CONVERSATION METHODS + # ========================================================================= + + def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None: + """ + Save a conversation session. + + Args: + name: Name/identifier for the conversation + data: List of message dictionaries with 'prompt' and 'response' keys + """ + timestamp = datetime.datetime.now().isoformat() + data_json = json.dumps(data) + with self._connection() as conn: + conn.execute( + """INSERT INTO conversation_sessions + (name, timestamp, data) VALUES (?, ?, ?)""", + (name, timestamp, data_json) + ) + + def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]: + """ + Load a conversation by name. + + Args: + name: Name of the conversation to load + + Returns: + List of messages, or None if not found + """ + with self._connection() as conn: + cursor = conn.execute( + """SELECT data FROM conversation_sessions + WHERE name = ? + ORDER BY timestamp DESC LIMIT 1""", + (name,) + ) + result = cursor.fetchone() + if result: + return json.loads(result[0]) + return None + + def delete_conversation(self, name: str) -> int: + """ + Delete a conversation by name. + + Args: + name: Name of the conversation to delete + + Returns: + Number of rows deleted + """ + with self._connection() as conn: + cursor = conn.execute( + "DELETE FROM conversation_sessions WHERE name = ?", + (name,) + ) + return cursor.rowcount + + def list_conversations(self) -> List[Dict[str, Any]]: + """ + List all saved conversations. + + Returns: + List of conversation summaries with name, timestamp, and message_count + """ + with self._connection() as conn: + cursor = conn.execute(""" + SELECT name, MAX(timestamp) as last_saved, data + FROM conversation_sessions + GROUP BY name + ORDER BY last_saved DESC + """) + conversations = [] + for row in cursor.fetchall(): + name, timestamp, data_json = row + data = json.loads(data_json) + conversations.append({ + "name": name, + "timestamp": timestamp, + "message_count": len(data), + }) + return conversations + + +# Global database instance for convenience +_db: Optional[Database] = None + + +def get_database() -> Database: + """ + Get the global database instance. + + Returns: + The shared Database instance + """ + global _db + if _db is None: + _db = Database() + return _db diff --git a/oai/config/settings.py b/oai/config/settings.py new file mode 100644 index 0000000..907a75a --- /dev/null +++ b/oai/config/settings.py @@ -0,0 +1,361 @@ +""" +Settings management for oAI. + +This module provides a centralized settings class that handles all application +configuration with type safety, validation, and persistence. +""" + +from dataclasses import dataclass, field +from typing import Optional +from pathlib import Path + +from oai.constants import ( + DEFAULT_BASE_URL, + DEFAULT_STREAM_ENABLED, + DEFAULT_MAX_TOKENS, + DEFAULT_ONLINE_MODE, + DEFAULT_COST_WARNING_THRESHOLD, + DEFAULT_LOG_MAX_SIZE_MB, + DEFAULT_LOG_BACKUP_COUNT, + DEFAULT_LOG_LEVEL, + DEFAULT_SYSTEM_PROMPT, + VALID_LOG_LEVELS, +) +from oai.config.database import get_database + + +@dataclass +class Settings: + """ + Application settings with persistence support. + + This class provides a clean interface for managing all configuration + options. Settings are automatically loaded from the database on + initialization and can be persisted back. + + Attributes: + api_key: OpenRouter API key + base_url: API base URL + default_model: Default model ID to use + default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank) + stream_enabled: Whether to stream responses + max_tokens: Maximum tokens per request + cost_warning_threshold: Alert threshold for message cost + default_online_mode: Whether online mode is enabled by default + log_max_size_mb: Maximum log file size in MB + log_backup_count: Number of log file backups to keep + log_level: Logging level (debug/info/warning/error/critical) + """ + + api_key: Optional[str] = None + base_url: str = DEFAULT_BASE_URL + default_model: Optional[str] = None + default_system_prompt: Optional[str] = None + stream_enabled: bool = DEFAULT_STREAM_ENABLED + max_tokens: int = DEFAULT_MAX_TOKENS + cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD + default_online_mode: bool = DEFAULT_ONLINE_MODE + log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB + log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT + log_level: str = DEFAULT_LOG_LEVEL + + @property + def effective_system_prompt(self) -> str: + """ + Get the effective system prompt to use. + + Returns: + The custom prompt if set, hardcoded default if None, or blank if explicitly set to "" + """ + if self.default_system_prompt is None: + return DEFAULT_SYSTEM_PROMPT + return self.default_system_prompt + + def __post_init__(self): + """Validate settings after initialization.""" + self._validate() + + def _validate(self) -> None: + """Validate all settings values.""" + # Validate log level + if self.log_level.lower() not in VALID_LOG_LEVELS: + raise ValueError( + f"Invalid log level: {self.log_level}. " + f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}" + ) + + # Validate numeric bounds + if self.max_tokens < 1: + raise ValueError("max_tokens must be at least 1") + + if self.cost_warning_threshold < 0: + raise ValueError("cost_warning_threshold must be non-negative") + + if self.log_max_size_mb < 1: + raise ValueError("log_max_size_mb must be at least 1") + + if self.log_backup_count < 0: + raise ValueError("log_backup_count must be non-negative") + + @classmethod + def load(cls) -> "Settings": + """ + Load settings from the database. + + Returns: + Settings instance with values from database + """ + db = get_database() + + # Helper to safely parse boolean + def parse_bool(value: Optional[str], default: bool) -> bool: + if value is None: + return default + return value.lower() in ("on", "true", "1", "yes") + + # Helper to safely parse int + def parse_int(value: Optional[str], default: int) -> int: + if value is None: + return default + try: + return int(value) + except ValueError: + return default + + # Helper to safely parse float + def parse_float(value: Optional[str], default: float) -> float: + if value is None: + return default + try: + return float(value) + except ValueError: + return default + + # Get system prompt from DB: None means not set (use default), "" means explicitly blank + system_prompt_value = db.get_config("default_system_prompt") + + return cls( + api_key=db.get_config("api_key"), + base_url=db.get_config("base_url") or DEFAULT_BASE_URL, + default_model=db.get_config("default_model"), + default_system_prompt=system_prompt_value, + stream_enabled=parse_bool( + db.get_config("stream_enabled"), + DEFAULT_STREAM_ENABLED + ), + max_tokens=parse_int( + db.get_config("max_token"), + DEFAULT_MAX_TOKENS + ), + cost_warning_threshold=parse_float( + db.get_config("cost_warning_threshold"), + DEFAULT_COST_WARNING_THRESHOLD + ), + default_online_mode=parse_bool( + db.get_config("default_online_mode"), + DEFAULT_ONLINE_MODE + ), + log_max_size_mb=parse_int( + db.get_config("log_max_size_mb"), + DEFAULT_LOG_MAX_SIZE_MB + ), + log_backup_count=parse_int( + db.get_config("log_backup_count"), + DEFAULT_LOG_BACKUP_COUNT + ), + log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL, + ) + + def save(self) -> None: + """Persist all settings to the database.""" + db = get_database() + + # Only save API key if it exists + if self.api_key: + db.set_config("api_key", self.api_key) + + db.set_config("base_url", self.base_url) + + if self.default_model: + db.set_config("default_model", self.default_model) + + # Save system prompt: None means not set (don't save), otherwise save the value (even if "") + if self.default_system_prompt is not None: + db.set_config("default_system_prompt", self.default_system_prompt) + + db.set_config("stream_enabled", "on" if self.stream_enabled else "off") + db.set_config("max_token", str(self.max_tokens)) + db.set_config("cost_warning_threshold", str(self.cost_warning_threshold)) + db.set_config("default_online_mode", "on" if self.default_online_mode else "off") + db.set_config("log_max_size_mb", str(self.log_max_size_mb)) + db.set_config("log_backup_count", str(self.log_backup_count)) + db.set_config("log_level", self.log_level) + + def set_api_key(self, api_key: str) -> None: + """ + Set and persist the API key. + + Args: + api_key: The new API key + """ + self.api_key = api_key.strip() + get_database().set_config("api_key", self.api_key) + + def set_base_url(self, url: str) -> None: + """ + Set and persist the base URL. + + Args: + url: The new base URL + """ + self.base_url = url.strip() + get_database().set_config("base_url", self.base_url) + + def set_default_model(self, model_id: str) -> None: + """ + Set and persist the default model. + + Args: + model_id: The model ID to set as default + """ + self.default_model = model_id + get_database().set_config("default_model", model_id) + + def set_default_system_prompt(self, prompt: str) -> None: + """ + Set and persist the default system prompt. + + Args: + prompt: The system prompt to use for all new sessions. + Empty string "" means blank prompt (no system message). + """ + self.default_system_prompt = prompt + get_database().set_config("default_system_prompt", prompt) + + def clear_default_system_prompt(self) -> None: + """ + Clear the custom system prompt and revert to hardcoded default. + + This removes the custom prompt from the database, causing the + application to use the built-in DEFAULT_SYSTEM_PROMPT. + """ + self.default_system_prompt = None + # Remove from database to indicate "not set" + db = get_database() + with db._connection() as conn: + conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",)) + conn.commit() + + def set_stream_enabled(self, enabled: bool) -> None: + """ + Set and persist the streaming preference. + + Args: + enabled: Whether to enable streaming + """ + self.stream_enabled = enabled + get_database().set_config("stream_enabled", "on" if enabled else "off") + + def set_max_tokens(self, max_tokens: int) -> None: + """ + Set and persist the maximum tokens. + + Args: + max_tokens: Maximum number of tokens + + Raises: + ValueError: If max_tokens is less than 1 + """ + if max_tokens < 1: + raise ValueError("max_tokens must be at least 1") + self.max_tokens = max_tokens + get_database().set_config("max_token", str(max_tokens)) + + def set_cost_warning_threshold(self, threshold: float) -> None: + """ + Set and persist the cost warning threshold. + + Args: + threshold: Cost threshold in USD + + Raises: + ValueError: If threshold is negative + """ + if threshold < 0: + raise ValueError("cost_warning_threshold must be non-negative") + self.cost_warning_threshold = threshold + get_database().set_config("cost_warning_threshold", str(threshold)) + + def set_default_online_mode(self, enabled: bool) -> None: + """ + Set and persist the default online mode. + + Args: + enabled: Whether online mode should be enabled by default + """ + self.default_online_mode = enabled + get_database().set_config("default_online_mode", "on" if enabled else "off") + + def set_log_level(self, level: str) -> None: + """ + Set and persist the log level. + + Args: + level: The log level (debug/info/warning/error/critical) + + Raises: + ValueError: If level is not valid + """ + level_lower = level.lower() + if level_lower not in VALID_LOG_LEVELS: + raise ValueError( + f"Invalid log level: {level}. " + f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}" + ) + self.log_level = level_lower + get_database().set_config("log_level", level_lower) + + def set_log_max_size(self, size_mb: int) -> None: + """ + Set and persist the maximum log file size. + + Args: + size_mb: Maximum size in megabytes + + Raises: + ValueError: If size_mb is less than 1 + """ + if size_mb < 1: + raise ValueError("log_max_size_mb must be at least 1") + # Cap at 100 MB for safety + self.log_max_size_mb = min(size_mb, 100) + get_database().set_config("log_max_size_mb", str(self.log_max_size_mb)) + + +# Global settings instance +_settings: Optional[Settings] = None + + +def get_settings() -> Settings: + """ + Get the global settings instance. + + Returns: + The shared Settings instance, loading from database if needed + """ + global _settings + if _settings is None: + _settings = Settings.load() + return _settings + + +def reload_settings() -> Settings: + """ + Force reload settings from the database. + + Returns: + Fresh Settings instance + """ + global _settings + _settings = Settings.load() + return _settings diff --git a/oai/constants.py b/oai/constants.py new file mode 100644 index 0000000..a91e330 --- /dev/null +++ b/oai/constants.py @@ -0,0 +1,448 @@ +""" +Application-wide constants for oAI. + +This module contains all configuration constants, default values, and static +definitions used throughout the application. Centralizing these values makes +the codebase easier to maintain and configure. +""" + +from pathlib import Path +from typing import Set, Dict, Any +import logging + +# ============================================================================= +# APPLICATION METADATA +# ============================================================================= + +APP_NAME = "oAI" +APP_VERSION = "2.1.0" +APP_URL = "https://iurl.no/oai" +APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration" + +# ============================================================================= +# FILE PATHS +# ============================================================================= + +HOME_DIR = Path.home() +CONFIG_DIR = HOME_DIR / ".config" / "oai" +CACHE_DIR = HOME_DIR / ".cache" / "oai" +HISTORY_FILE = CONFIG_DIR / "history.txt" +DATABASE_FILE = CONFIG_DIR / "oai_config.db" +LOG_FILE = CONFIG_DIR / "oai.log" + +# ============================================================================= +# API CONFIGURATION +# ============================================================================= + +DEFAULT_BASE_URL = "https://openrouter.ai/api/v1" +DEFAULT_STREAM_ENABLED = True +DEFAULT_MAX_TOKENS = 100_000 +DEFAULT_ONLINE_MODE = False + +# ============================================================================= +# DEFAULT SYSTEM PROMPT +# ============================================================================= + +DEFAULT_SYSTEM_PROMPT = ( + "You are a knowledgeable and helpful AI assistant. Provide clear, accurate, " + "and well-structured responses. Be concise yet thorough. When uncertain about " + "something, acknowledge your limitations. For technical topics, include relevant " + "details and examples when helpful." +) + +# ============================================================================= +# PRICING DEFAULTS (per million tokens) +# ============================================================================= + +DEFAULT_INPUT_PRICE = 3.0 +DEFAULT_OUTPUT_PRICE = 15.0 + +MODEL_PRICING: Dict[str, float] = { + "input": DEFAULT_INPUT_PRICE, + "output": DEFAULT_OUTPUT_PRICE, +} + +# ============================================================================= +# CREDIT ALERTS +# ============================================================================= + +LOW_CREDIT_RATIO = 0.1 # Alert when credits < 10% of total +LOW_CREDIT_AMOUNT = 1.0 # Alert when credits < $1.00 +DEFAULT_COST_WARNING_THRESHOLD = 0.01 # Alert when single message cost exceeds this +COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD # Alias for convenience + +# ============================================================================= +# LOGGING CONFIGURATION +# ============================================================================= + +DEFAULT_LOG_MAX_SIZE_MB = 10 +DEFAULT_LOG_BACKUP_COUNT = 2 +DEFAULT_LOG_LEVEL = "info" + +VALID_LOG_LEVELS: Dict[str, int] = { + "debug": logging.DEBUG, + "info": logging.INFO, + "warning": logging.WARNING, + "error": logging.ERROR, + "critical": logging.CRITICAL, +} + +# ============================================================================= +# FILE HANDLING +# ============================================================================= + +# Maximum file size for reading (10 MB) +MAX_FILE_SIZE = 10 * 1024 * 1024 + +# Content truncation threshold (50 KB) +CONTENT_TRUNCATION_THRESHOLD = 50 * 1024 + +# Maximum items in directory listing +MAX_LIST_ITEMS = 1000 + +# Supported code file extensions for syntax highlighting +SUPPORTED_CODE_EXTENSIONS: Set[str] = { + ".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp", + ".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go", + ".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart", + ".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt", +} + +# All allowed file extensions for attachment +ALLOWED_FILE_EXTENSIONS: Set[str] = { + # Code files + ".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx", + ".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php", + ".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1", + # Data files + ".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3", + # Documents + ".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties", + # Images + ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico", + # Archives + ".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz", + # Config files + ".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc", + ".prettierrc", ".babelrc", ".nvmrc", ".npmrc", + # Binary/Compiled + ".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app", + ".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa", + # ML/AI + ".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx", + ".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn", + # Data formats + ".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet", + ".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds", + # Other + ".pdf", ".class", ".jar", ".war", +} + +# ============================================================================= +# SECURITY CONFIGURATION +# ============================================================================= + +# System directories that should never be accessed +SYSTEM_DIRS_BLACKLIST: Set[str] = { + # macOS + "/System", "/Library", "/private", "/usr", "/bin", "/sbin", + # Linux + "/boot", "/dev", "/proc", "/sys", "/root", + # Windows + "C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)", +} + +# Directories to skip during file operations +SKIP_DIRECTORIES: Set[str] = { + # Python virtual environments + ".venv", "venv", "env", "virtualenv", + "site-packages", "dist-packages", + # Python caches + "__pycache__", ".pytest_cache", ".mypy_cache", + # JavaScript/Node + "node_modules", + # Version control + ".git", ".svn", + # IDEs + ".idea", ".vscode", + # Build directories + "build", "dist", "eggs", ".eggs", +} + +# ============================================================================= +# DATABASE QUERIES - SQL SAFETY +# ============================================================================= + +# Maximum query execution timeout (seconds) +MAX_QUERY_TIMEOUT = 5 + +# Maximum rows returned from queries +MAX_QUERY_RESULTS = 1000 + +# Default rows per query +DEFAULT_QUERY_LIMIT = 100 + +# Keywords that are blocked in database queries +DANGEROUS_SQL_KEYWORDS: Set[str] = { + "INSERT", "UPDATE", "DELETE", "DROP", "CREATE", + "ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH", + "PRAGMA", "VACUUM", "REINDEX", +} + +# ============================================================================= +# MCP CONFIGURATION +# ============================================================================= + +# Maximum tool call iterations per request +MAX_TOOL_LOOPS = 5 + +# ============================================================================= +# VALID COMMANDS +# ============================================================================= + +VALID_COMMANDS: Set[str] = { + "/retry", "/online", "/memory", "/paste", "/export", "/save", "/load", + "/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset", + "/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear", + "/cl", "/help", "/mcp", +} + +# ============================================================================= +# COMMAND HELP DATABASE +# ============================================================================= + +COMMAND_HELP: Dict[str, Dict[str, Any]] = { + "/clear": { + "aliases": ["/cl"], + "description": "Clear the terminal screen for a clean interface.", + "usage": "/clear\n/cl", + "examples": [ + ("Clear screen", "/clear"), + ("Using short alias", "/cl"), + ], + "notes": "You can also use the keyboard shortcut Ctrl+L.", + }, + "/help": { + "description": "Display help information for commands.", + "usage": "/help [command|topic]", + "examples": [ + ("Show all commands", "/help"), + ("Get help for a specific command", "/help /model"), + ("Get detailed MCP help", "/help mcp"), + ], + "notes": "Use /help without arguments to see the full command list.", + }, + "mcp": { + "description": "Complete guide to MCP (Model Context Protocol).", + "usage": "See detailed examples below", + "examples": [], + "notes": """ +MCP (Model Context Protocol) gives your AI assistant direct access to: + • Local files and folders (read, search, list) + • SQLite databases (inspect, search, query) + +FILE MODE (default): + /mcp on Start MCP server + /mcp add ~/Documents Grant access to folder + /mcp list View all allowed folders + +DATABASE MODE: + /mcp add db ~/app/data.db Add specific database + /mcp db list View all databases + /mcp db 1 Work with database #1 + /mcp files Switch back to file mode + +WRITE MODE (optional): + /mcp write on Enable file modifications + /mcp write off Disable write mode (back to read-only) + +For command-specific help: /help /mcp + """, + }, + "/mcp": { + "description": "Manage MCP for local file access and SQLite database querying.", + "usage": "/mcp [args]", + "examples": [ + ("Enable MCP server", "/mcp on"), + ("Disable MCP server", "/mcp off"), + ("Show MCP status", "/mcp status"), + ("", ""), + ("━━━ FILE MODE ━━━", ""), + ("Add folder for file access", "/mcp add ~/Documents"), + ("Remove folder", "/mcp remove ~/Desktop"), + ("List allowed folders", "/mcp list"), + ("Enable write mode", "/mcp write on"), + ("", ""), + ("━━━ DATABASE MODE ━━━", ""), + ("Add SQLite database", "/mcp add db ~/app/data.db"), + ("List all databases", "/mcp db list"), + ("Switch to database #1", "/mcp db 1"), + ("Switch back to file mode", "/mcp files"), + ], + "notes": "MCP allows AI to read local files and query SQLite databases.", + }, + "/memory": { + "description": "Toggle conversation memory.", + "usage": "/memory [on|off]", + "examples": [ + ("Check current memory status", "/memory"), + ("Enable conversation memory", "/memory on"), + ("Disable memory (save costs)", "/memory off"), + ], + "notes": "Memory is ON by default. Disabling saves tokens.", + }, + "/online": { + "description": "Enable or disable online mode (web search).", + "usage": "/online [on|off]", + "examples": [ + ("Check online mode status", "/online"), + ("Enable web search", "/online on"), + ("Disable web search", "/online off"), + ], + "notes": "Not all models support online mode.", + }, + "/paste": { + "description": "Paste plain text from clipboard and send to the AI.", + "usage": "/paste [prompt]", + "examples": [ + ("Paste clipboard content", "/paste"), + ("Paste with a question", "/paste Explain this code"), + ], + "notes": "Only plain text is supported.", + }, + "/retry": { + "description": "Resend the last prompt from conversation history.", + "usage": "/retry", + "examples": [("Retry last message", "/retry")], + "notes": "Requires at least one message in history.", + }, + "/next": { + "description": "View the next response in conversation history.", + "usage": "/next", + "examples": [("Navigate to next response", "/next")], + "notes": "Use /prev to go backward.", + }, + "/prev": { + "description": "View the previous response in conversation history.", + "usage": "/prev", + "examples": [("Navigate to previous response", "/prev")], + "notes": "Use /next to go forward.", + }, + "/reset": { + "description": "Clear conversation history and reset system prompt.", + "usage": "/reset", + "examples": [("Reset conversation", "/reset")], + "notes": "Requires confirmation.", + }, + "/info": { + "description": "Display detailed information about a model.", + "usage": "/info [model_id]", + "examples": [ + ("Show current model info", "/info"), + ("Show specific model info", "/info gpt-4o"), + ], + "notes": "Shows pricing, capabilities, and context length.", + }, + "/model": { + "description": "Select or change the AI model.", + "usage": "/model [search_term]", + "examples": [ + ("List all models", "/model"), + ("Search for GPT models", "/model gpt"), + ("Search for Claude models", "/model claude"), + ], + "notes": "Models are numbered for easy selection.", + }, + "/config": { + "description": "View or modify application configuration.", + "usage": "/config [setting] [value]", + "examples": [ + ("View all settings", "/config"), + ("Set API key", "/config api"), + ("Set default model", "/config model"), + ("Set system prompt", "/config system You are a helpful assistant"), + ("Enable streaming", "/config stream on"), + ], + "notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.", + }, + "/maxtoken": { + "description": "Set a temporary session token limit.", + "usage": "/maxtoken [value]", + "examples": [ + ("View current session limit", "/maxtoken"), + ("Set session limit to 2000", "/maxtoken 2000"), + ], + "notes": "Cannot exceed stored max token limit.", + }, + "/system": { + "description": "Set or clear the session-level system prompt.", + "usage": "/system [prompt|clear|default ]", + "examples": [ + ("View current system prompt", "/system"), + ("Set as Python expert", "/system You are a Python expert"), + ("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."), + ("Save as default", "/system default You are a helpful assistant"), + ("Revert to default", "/system clear"), + ("Blank prompt", '/system ""'), + ], + "notes": r"Use \n for newlines. /system clear reverts to hardcoded default.", + }, + "/save": { + "description": "Save the current conversation history.", + "usage": "/save ", + "examples": [("Save conversation", "/save my_chat")], + "notes": "Saved conversations can be loaded later with /load.", + }, + "/load": { + "description": "Load a saved conversation.", + "usage": "/load ", + "examples": [ + ("Load by name", "/load my_chat"), + ("Load by number from /list", "/load 3"), + ], + "notes": "Use /list to see numbered conversations.", + }, + "/delete": { + "description": "Delete a saved conversation.", + "usage": "/delete ", + "examples": [("Delete by name", "/delete my_chat")], + "notes": "Requires confirmation. Cannot be undone.", + }, + "/list": { + "description": "List all saved conversations.", + "usage": "/list", + "examples": [("Show saved conversations", "/list")], + "notes": "Conversations are numbered for use with /load and /delete.", + }, + "/export": { + "description": "Export the current conversation to a file.", + "usage": "/export ", + "examples": [ + ("Export as Markdown", "/export md notes.md"), + ("Export as JSON", "/export json conversation.json"), + ("Export as HTML", "/export html report.html"), + ], + "notes": "Available formats: md, json, html.", + }, + "/stats": { + "description": "Display session statistics.", + "usage": "/stats", + "examples": [("View session statistics", "/stats")], + "notes": "Shows tokens, costs, and credits.", + }, + "/credits": { + "description": "Display your OpenRouter account credits.", + "usage": "/credits", + "examples": [("Check credits", "/credits")], + "notes": "Shows total, used, and remaining credits.", + }, + "/middleout": { + "description": "Toggle middle-out transform for long prompts.", + "usage": "/middleout [on|off]", + "examples": [ + ("Check status", "/middleout"), + ("Enable compression", "/middleout on"), + ], + "notes": "Compresses prompts exceeding context size.", + }, +} diff --git a/oai/core/__init__.py b/oai/core/__init__.py new file mode 100644 index 0000000..d2d1e9e --- /dev/null +++ b/oai/core/__init__.py @@ -0,0 +1,14 @@ +""" +Core functionality for oAI. + +This module provides the main session management and AI client +classes that power the chat application. +""" + +from oai.core.session import ChatSession +from oai.core.client import AIClient + +__all__ = [ + "ChatSession", + "AIClient", +] diff --git a/oai/core/client.py b/oai/core/client.py new file mode 100644 index 0000000..5bcdbec --- /dev/null +++ b/oai/core/client.py @@ -0,0 +1,422 @@ +""" +AI Client for oAI. + +This module provides a high-level client for interacting with AI models +through the provider abstraction layer. +""" + +import asyncio +import json +from typing import Any, Callable, Dict, Iterator, List, Optional, Union + +from oai.constants import APP_NAME, APP_URL, MODEL_PRICING +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ModelInfo, + StreamChunk, + ToolCall, + UsageStats, +) +from oai.providers.openrouter import OpenRouterProvider +from oai.utils.logging import get_logger + + +class AIClient: + """ + High-level AI client for chat interactions. + + Provides a simplified interface for sending chat requests, + handling streaming, and managing tool calls. + + Attributes: + provider: The underlying AI provider + default_model: Default model ID to use + http_headers: Custom HTTP headers for requests + """ + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + provider_class: type = OpenRouterProvider, + app_name: str = APP_NAME, + app_url: str = APP_URL, + ): + """ + Initialize the AI client. + + Args: + api_key: API key for authentication + base_url: Optional custom base URL + provider_class: Provider class to use (default: OpenRouterProvider) + app_name: Application name for headers + app_url: Application URL for headers + """ + self.provider: AIProvider = provider_class( + api_key=api_key, + base_url=base_url, + app_name=app_name, + app_url=app_url, + ) + self.default_model: Optional[str] = None + self.logger = get_logger() + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + Get available models. + + Args: + filter_text_only: Whether to exclude video-only models + + Returns: + List of ModelInfo objects + """ + return self.provider.list_models(filter_text_only=filter_text_only) + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: Model identifier + + Returns: + ModelInfo or None if not found + """ + return self.provider.get_model(model_id) + + def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]: + """ + Get raw model data for provider-specific fields. + + Args: + model_id: Model identifier + + Returns: + Raw model dictionary or None + """ + if hasattr(self.provider, "get_raw_model"): + return self.provider.get_raw_model(model_id) + return None + + def chat( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + system_prompt: Optional[str] = None, + online: bool = False, + transforms: Optional[List[str]] = None, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat request. + + Args: + messages: List of message dictionaries + model: Model ID (uses default if not specified) + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: Tool definitions for function calling + tool_choice: Tool selection mode + system_prompt: System prompt to prepend + online: Whether to enable online mode + transforms: List of transforms (e.g., ["middle-out"]) + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + + Raises: + ValueError: If no model specified and no default set + """ + model_id = model or self.default_model + if not model_id: + raise ValueError("No model specified and no default set") + + # Apply online mode suffix + if online and hasattr(self.provider, "get_effective_model_id"): + model_id = self.provider.get_effective_model_id(model_id, True) + + # Convert dict messages to ChatMessage objects + chat_messages = [] + + # Add system prompt if provided + if system_prompt: + chat_messages.append(ChatMessage(role="system", content=system_prompt)) + + # Convert message dicts + for msg in messages: + # Convert tool_calls dicts to ToolCall objects if present + tool_calls_data = msg.get("tool_calls") + tool_calls_obj = None + if tool_calls_data: + from oai.providers.base import ToolCall, ToolFunction + tool_calls_obj = [] + for tc in tool_calls_data: + # Handle both ToolCall objects and dicts + if isinstance(tc, ToolCall): + tool_calls_obj.append(tc) + elif isinstance(tc, dict): + func_data = tc.get("function", {}) + tool_calls_obj.append( + ToolCall( + id=tc.get("id", ""), + type=tc.get("type", "function"), + function=ToolFunction( + name=func_data.get("name", ""), + arguments=func_data.get("arguments", "{}"), + ), + ) + ) + + chat_messages.append( + ChatMessage( + role=msg.get("role", "user"), + content=msg.get("content"), + tool_calls=tool_calls_obj, + tool_call_id=msg.get("tool_call_id"), + ) + ) + + self.logger.debug( + f"Sending chat request: model={model_id}, " + f"messages={len(chat_messages)}, stream={stream}" + ) + + return self.provider.chat( + model=model_id, + messages=chat_messages, + stream=stream, + max_tokens=max_tokens, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + transforms=transforms, + ) + + def chat_with_tools( + self, + messages: List[Dict[str, Any]], + tools: List[Dict[str, Any]], + tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]], + model: Optional[str] = None, + max_loops: int = 5, + max_tokens: Optional[int] = None, + system_prompt: Optional[str] = None, + on_tool_call: Optional[Callable[[ToolCall], None]] = None, + on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None, + ) -> ChatResponse: + """ + Send a chat request with automatic tool call handling. + + Executes tool calls returned by the model and continues + the conversation until no more tool calls are requested. + + Args: + messages: Initial messages + tools: Tool definitions + tool_executor: Function to execute tool calls + model: Model ID + max_loops: Maximum tool call iterations + max_tokens: Maximum response tokens + system_prompt: System prompt + on_tool_call: Callback when tool is called + on_tool_result: Callback when tool returns result + + Returns: + Final ChatResponse after all tool calls complete + """ + model_id = model or self.default_model + if not model_id: + raise ValueError("No model specified and no default set") + + # Build initial messages + chat_messages = [] + if system_prompt: + chat_messages.append({"role": "system", "content": system_prompt}) + chat_messages.extend(messages) + + loop_count = 0 + current_response: Optional[ChatResponse] = None + + while loop_count < max_loops: + # Send request + response = self.chat( + messages=chat_messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + tools=tools, + tool_choice="auto", + ) + + if not isinstance(response, ChatResponse): + raise ValueError("Expected non-streaming response") + + current_response = response + + # Check for tool calls + tool_calls = response.tool_calls + if not tool_calls: + break + + self.logger.info(f"Model requested {len(tool_calls)} tool call(s)") + + # Process each tool call + tool_results = [] + for tc in tool_calls: + if on_tool_call: + on_tool_call(tc) + + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError as e: + self.logger.error(f"Failed to parse tool arguments: {e}") + result = {"error": f"Invalid arguments: {e}"} + else: + result = tool_executor(tc.function.name, args) + + if on_tool_result: + on_tool_result(tc.function.name, result) + + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps(result), + }) + + # Add assistant message with tool calls + assistant_msg = { + "role": "assistant", + "content": response.content, + "tool_calls": [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ], + } + chat_messages.append(assistant_msg) + chat_messages.extend(tool_results) + + loop_count += 1 + + if loop_count >= max_loops: + self.logger.warning(f"Reached max tool call loops ({max_loops})") + + return current_response + + def stream_chat( + self, + messages: List[Dict[str, Any]], + model: Optional[str] = None, + max_tokens: Optional[int] = None, + system_prompt: Optional[str] = None, + online: bool = False, + on_chunk: Optional[Callable[[StreamChunk], None]] = None, + ) -> tuple[str, Optional[UsageStats]]: + """ + Stream a chat response and collect the full text. + + Args: + messages: Chat messages + model: Model ID + max_tokens: Maximum tokens + system_prompt: System prompt + online: Online mode + on_chunk: Optional callback for each chunk + + Returns: + Tuple of (full_response_text, usage_stats) + """ + response = self.chat( + messages=messages, + model=model, + stream=True, + max_tokens=max_tokens, + system_prompt=system_prompt, + online=online, + ) + + if isinstance(response, ChatResponse): + # Not actually streaming + return response.content or "", response.usage + + full_text = "" + usage: Optional[UsageStats] = None + + for chunk in response: + if chunk.error: + self.logger.error(f"Stream error: {chunk.error}") + break + + if chunk.delta_content: + full_text += chunk.delta_content + if on_chunk: + on_chunk(chunk) + + if chunk.usage: + usage = chunk.usage + + return full_text, usage + + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get account credit information. + + Returns: + Credit info dict or None if unavailable + """ + return self.provider.get_credits() + + def estimate_cost( + self, + model_id: str, + input_tokens: int, + output_tokens: int, + ) -> float: + """ + Estimate cost for a completion. + + Args: + model_id: Model ID + input_tokens: Number of input tokens + output_tokens: Number of output tokens + + Returns: + Estimated cost in USD + """ + if hasattr(self.provider, "estimate_cost"): + return self.provider.estimate_cost(model_id, input_tokens, output_tokens) + + # Fallback to default pricing + input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000 + output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000 + return input_cost + output_cost + + def set_default_model(self, model_id: str) -> None: + """ + Set the default model. + + Args: + model_id: Model ID to use as default + """ + self.default_model = model_id + self.logger.info(f"Default model set to: {model_id}") + + def clear_cache(self) -> None: + """Clear the provider's model cache.""" + if hasattr(self.provider, "clear_cache"): + self.provider.clear_cache() diff --git a/oai/core/session.py b/oai/core/session.py new file mode 100644 index 0000000..c4e5e31 --- /dev/null +++ b/oai/core/session.py @@ -0,0 +1,659 @@ +""" +Chat session management for oAI. + +This module provides the ChatSession class that manages an interactive +chat session including history, state, and message handling. +""" + +import asyncio +import json +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple + +from rich.live import Live +from rich.markdown import Markdown + +from oai.commands.registry import CommandContext, CommandResult, registry +from oai.config.database import Database +from oai.config.settings import Settings +from oai.constants import ( + COST_WARNING_THRESHOLD, + LOW_CREDIT_AMOUNT, + LOW_CREDIT_RATIO, +) +from oai.core.client import AIClient +from oai.mcp.manager import MCPManager +from oai.providers.base import ChatResponse, StreamChunk, UsageStats +from oai.ui.console import ( + console, + display_markdown, + display_panel, + print_error, + print_info, + print_success, + print_warning, +) +from oai.ui.prompts import prompt_copy_response +from oai.utils.logging import get_logger + + +@dataclass +class SessionStats: + """ + Statistics for the current session. + + Tracks tokens, costs, and message counts. + """ + + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_cost: float = 0.0 + message_count: int = 0 + + @property + def total_tokens(self) -> int: + """Get total token count.""" + return self.total_input_tokens + self.total_output_tokens + + def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None: + """ + Add usage stats from a response. + + Args: + usage: Usage statistics + cost: Cost if not in usage + """ + if usage: + self.total_input_tokens += usage.prompt_tokens + self.total_output_tokens += usage.completion_tokens + if usage.total_cost_usd: + self.total_cost += usage.total_cost_usd + else: + self.total_cost += cost + else: + self.total_cost += cost + self.message_count += 1 + + +@dataclass +class HistoryEntry: + """ + A single entry in the conversation history. + + Stores the user prompt, assistant response, and metrics. + """ + + prompt: str + response: str + prompt_tokens: int = 0 + completion_tokens: int = 0 + msg_cost: float = 0.0 + timestamp: Optional[float] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format.""" + return { + "prompt": self.prompt, + "response": self.response, + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "msg_cost": self.msg_cost, + } + + +class ChatSession: + """ + Manages an interactive chat session. + + Handles conversation history, state management, command processing, + and communication with the AI client. + + Attributes: + client: AI client for API requests + settings: Application settings + mcp_manager: MCP manager for file/database access + history: Conversation history + stats: Session statistics + """ + + def __init__( + self, + client: AIClient, + settings: Settings, + mcp_manager: Optional[MCPManager] = None, + ): + """ + Initialize a chat session. + + Args: + client: AI client instance + settings: Application settings + mcp_manager: Optional MCP manager + """ + self.client = client + self.settings = settings + self.mcp_manager = mcp_manager + self.db = Database() + + self.history: List[HistoryEntry] = [] + self.stats = SessionStats() + + # Session state + self.system_prompt: str = settings.effective_system_prompt + self.memory_enabled: bool = True + self.memory_start_index: int = 0 + self.online_enabled: bool = settings.default_online_mode + self.middle_out_enabled: bool = False + self.session_max_token: int = 0 + self.current_index: int = 0 + + # Selected model + self.selected_model: Optional[Dict[str, Any]] = None + + self.logger = get_logger() + + def get_context(self) -> CommandContext: + """ + Get the current command context. + + Returns: + CommandContext with current session state + """ + return CommandContext( + settings=self.settings, + provider=self.client.provider, + mcp_manager=self.mcp_manager, + selected_model_raw=self.selected_model, + session_history=[e.to_dict() for e in self.history], + session_system_prompt=self.system_prompt, + memory_enabled=self.memory_enabled, + memory_start_index=self.memory_start_index, + online_enabled=self.online_enabled, + middle_out_enabled=self.middle_out_enabled, + session_max_token=self.session_max_token, + total_input_tokens=self.stats.total_input_tokens, + total_output_tokens=self.stats.total_output_tokens, + total_cost=self.stats.total_cost, + message_count=self.stats.message_count, + current_index=self.current_index, + ) + + def set_model(self, model: Dict[str, Any]) -> None: + """ + Set the selected model. + + Args: + model: Raw model dictionary + """ + self.selected_model = model + self.client.set_default_model(model["id"]) + self.logger.info(f"Model selected: {model['id']}") + + def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]: + """ + Build the messages array for an API request. + + Includes system prompt, history (if memory enabled), and current input. + + Args: + user_input: Current user input + + Returns: + List of message dictionaries + """ + messages = [] + + # Add system prompt + if self.system_prompt: + messages.append({"role": "system", "content": self.system_prompt}) + + # Add database context if in database mode + if self.mcp_manager and self.mcp_manager.enabled: + if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None: + db = self.mcp_manager.databases[self.mcp_manager.selected_db_index] + db_context = ( + f"You are connected to SQLite database: {db['name']}\n" + f"Available tables: {', '.join(db['tables'])}\n\n" + "Use inspect_database, search_database, or query_database tools. " + "All queries are read-only." + ) + messages.append({"role": "system", "content": db_context}) + + # Add history if memory enabled + if self.memory_enabled: + for i in range(self.memory_start_index, len(self.history)): + entry = self.history[i] + messages.append({"role": "user", "content": entry.prompt}) + messages.append({"role": "assistant", "content": entry.response}) + + # Add current message + messages.append({"role": "user", "content": user_input}) + + return messages + + def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]: + """ + Get MCP tool definitions if available. + + Returns: + List of tool schemas or None + """ + if not self.mcp_manager or not self.mcp_manager.enabled: + return None + + if not self.selected_model: + return None + + # Check if model supports tools + supported_params = self.selected_model.get("supported_parameters", []) + if "tools" not in supported_params and "functions" not in supported_params: + return None + + return self.mcp_manager.get_tools_schema() + + async def execute_tool( + self, + tool_name: str, + tool_args: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Execute an MCP tool. + + Args: + tool_name: Name of the tool + tool_args: Tool arguments + + Returns: + Tool execution result + """ + if not self.mcp_manager: + return {"error": "MCP not available"} + + return await self.mcp_manager.call_tool(tool_name, **tool_args) + + def send_message( + self, + user_input: str, + stream: bool = True, + on_stream_chunk: Optional[Callable[[str], None]] = None, + ) -> Tuple[str, Optional[UsageStats], float]: + """ + Send a message and get a response. + + Args: + user_input: User's input text + stream: Whether to stream the response + on_stream_chunk: Callback for stream chunks + + Returns: + Tuple of (response_text, usage_stats, response_time) + """ + if not self.selected_model: + raise ValueError("No model selected") + + start_time = time.time() + messages = self.build_api_messages(user_input) + + # Get MCP tools + tools = self.get_mcp_tools() + if tools: + # Disable streaming when tools are present + stream = False + + # Build request parameters + model_id = self.selected_model["id"] + if self.online_enabled: + if hasattr(self.client.provider, "get_effective_model_id"): + model_id = self.client.provider.get_effective_model_id(model_id, True) + + transforms = ["middle-out"] if self.middle_out_enabled else None + + max_tokens = None + if self.session_max_token > 0: + max_tokens = self.session_max_token + + if tools: + # Use tool handling flow + response = self._send_with_tools( + messages=messages, + model_id=model_id, + tools=tools, + max_tokens=max_tokens, + transforms=transforms, + ) + response_time = time.time() - start_time + return response.content or "", response.usage, response_time + + elif stream: + # Use streaming flow + full_text, usage = self._stream_response( + messages=messages, + model_id=model_id, + max_tokens=max_tokens, + transforms=transforms, + on_chunk=on_stream_chunk, + ) + response_time = time.time() - start_time + return full_text, usage, response_time + + else: + # Non-streaming request + response = self.client.chat( + messages=messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + transforms=transforms, + ) + response_time = time.time() - start_time + + if isinstance(response, ChatResponse): + return response.content or "", response.usage, response_time + else: + return "", None, response_time + + def _send_with_tools( + self, + messages: List[Dict[str, Any]], + model_id: str, + tools: List[Dict[str, Any]], + max_tokens: Optional[int] = None, + transforms: Optional[List[str]] = None, + ) -> ChatResponse: + """ + Send a request with tool call handling. + + Args: + messages: API messages + model_id: Model ID + tools: Tool definitions + max_tokens: Max tokens + transforms: Transforms list + + Returns: + Final ChatResponse + """ + max_loops = 5 + loop_count = 0 + api_messages = list(messages) + + while loop_count < max_loops: + response = self.client.chat( + messages=api_messages, + model=model_id, + stream=False, + max_tokens=max_tokens, + tools=tools, + tool_choice="auto", + transforms=transforms, + ) + + if not isinstance(response, ChatResponse): + raise ValueError("Expected ChatResponse") + + tool_calls = response.tool_calls + if not tool_calls: + return response + + console.print(f"\n[dim yellow]🔧 AI requesting {len(tool_calls)} tool call(s)...[/]") + + tool_results = [] + for tc in tool_calls: + try: + args = json.loads(tc.function.arguments) + except json.JSONDecodeError as e: + self.logger.error(f"Failed to parse tool arguments: {e}") + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps({"error": f"Invalid arguments: {e}"}), + }) + continue + + # Display tool call + args_display = ", ".join( + f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}" + for k, v in args.items() + ) + console.print(f"[dim cyan] → {tc.function.name}({args_display})[/]") + + # Execute tool + result = asyncio.run(self.execute_tool(tc.function.name, args)) + + if "error" in result: + console.print(f"[dim red] ✗ Error: {result['error']}[/]") + else: + self._display_tool_success(tc.function.name, result) + + tool_results.append({ + "tool_call_id": tc.id, + "role": "tool", + "name": tc.function.name, + "content": json.dumps(result), + }) + + # Add assistant message with tool calls + api_messages.append({ + "role": "assistant", + "content": response.content, + "tool_calls": [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in tool_calls + ], + }) + api_messages.extend(tool_results) + + console.print("\n[dim cyan]💭 Processing tool results...[/]") + loop_count += 1 + + self.logger.warning(f"Reached max tool loops ({max_loops})") + console.print(f"[bold yellow]⚠️ Reached maximum tool calls ({max_loops})[/]") + return response + + def _display_tool_success(self, tool_name: str, result: Dict[str, Any]) -> None: + """Display a success message for a tool call.""" + if tool_name == "search_files": + count = result.get("count", 0) + console.print(f"[dim green] ✓ Found {count} file(s)[/]") + elif tool_name == "read_file": + size = result.get("size", 0) + truncated = " (truncated)" if result.get("truncated") else "" + console.print(f"[dim green] ✓ Read {size} bytes{truncated}[/]") + elif tool_name == "list_directory": + count = result.get("count", 0) + console.print(f"[dim green] ✓ Listed {count} item(s)[/]") + elif tool_name == "inspect_database": + if "table" in result: + console.print(f"[dim green] ✓ Inspected table: {result['table']}[/]") + else: + console.print(f"[dim green] ✓ Inspected database ({result.get('table_count', 0)} tables)[/]") + elif tool_name == "search_database": + count = result.get("count", 0) + console.print(f"[dim green] ✓ Found {count} match(es)[/]") + elif tool_name == "query_database": + count = result.get("count", 0) + console.print(f"[dim green] ✓ Query returned {count} row(s)[/]") + else: + console.print("[dim green] ✓ Success[/]") + + def _stream_response( + self, + messages: List[Dict[str, Any]], + model_id: str, + max_tokens: Optional[int] = None, + transforms: Optional[List[str]] = None, + on_chunk: Optional[Callable[[str], None]] = None, + ) -> Tuple[str, Optional[UsageStats]]: + """ + Stream a response with live display. + + Args: + messages: API messages + model_id: Model ID + max_tokens: Max tokens + transforms: Transforms + on_chunk: Callback for chunks + + Returns: + Tuple of (full_text, usage) + """ + response = self.client.chat( + messages=messages, + model=model_id, + stream=True, + max_tokens=max_tokens, + transforms=transforms, + ) + + if isinstance(response, ChatResponse): + return response.content or "", response.usage + + full_text = "" + usage: Optional[UsageStats] = None + + try: + with Live("", console=console, refresh_per_second=10) as live: + for chunk in response: + if chunk.error: + console.print(f"\n[bold red]Stream error: {chunk.error}[/]") + break + + if chunk.delta_content: + full_text += chunk.delta_content + live.update(Markdown(full_text)) + if on_chunk: + on_chunk(chunk.delta_content) + + if chunk.usage: + usage = chunk.usage + + except KeyboardInterrupt: + console.print("\n[bold yellow]⚠️ Streaming interrupted[/]") + return "", None + + return full_text, usage + + def add_to_history( + self, + prompt: str, + response: str, + usage: Optional[UsageStats] = None, + cost: float = 0.0, + ) -> None: + """ + Add an exchange to the history. + + Args: + prompt: User prompt + response: Assistant response + usage: Usage statistics + cost: Cost if not in usage + """ + entry = HistoryEntry( + prompt=prompt, + response=response, + prompt_tokens=usage.prompt_tokens if usage else 0, + completion_tokens=usage.completion_tokens if usage else 0, + msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost, + timestamp=time.time(), + ) + self.history.append(entry) + self.current_index = len(self.history) - 1 + self.stats.add_usage(usage, cost) + + def save_conversation(self, name: str) -> bool: + """ + Save the current conversation. + + Args: + name: Name for the saved conversation + + Returns: + True if saved successfully + """ + if not self.history: + return False + + data = [e.to_dict() for e in self.history] + self.db.save_conversation(name, data) + self.logger.info(f"Saved conversation: {name}") + return True + + def load_conversation(self, name: str) -> bool: + """ + Load a saved conversation. + + Args: + name: Name of the conversation to load + + Returns: + True if loaded successfully + """ + data = self.db.load_conversation(name) + if not data: + return False + + self.history.clear() + for entry_dict in data: + self.history.append(HistoryEntry( + prompt=entry_dict.get("prompt", ""), + response=entry_dict.get("response", ""), + prompt_tokens=entry_dict.get("prompt_tokens", 0), + completion_tokens=entry_dict.get("completion_tokens", 0), + msg_cost=entry_dict.get("msg_cost", 0.0), + )) + + self.current_index = len(self.history) - 1 + self.memory_start_index = 0 + self.stats = SessionStats() # Reset stats for loaded conversation + self.logger.info(f"Loaded conversation: {name}") + return True + + def reset(self) -> None: + """Reset the session state.""" + self.history.clear() + self.stats = SessionStats() + self.system_prompt = "" + self.memory_start_index = 0 + self.current_index = 0 + self.logger.info("Session reset") + + def check_warnings(self) -> List[str]: + """ + Check for cost and credit warnings. + + Returns: + List of warning messages + """ + warnings = [] + + # Check last message cost + if self.history: + last_cost = self.history[-1].msg_cost + threshold = self.settings.cost_warning_threshold + if last_cost > threshold: + warnings.append( + f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}" + ) + + # Check credits + credits = self.client.get_credits() + if credits: + left = credits.get("credits_left", 0) + total = credits.get("total_credits", 0) + + if left < LOW_CREDIT_AMOUNT: + warnings.append(f"Low credits: ${left:.2f} remaining!") + elif total > 0 and left < total * LOW_CREDIT_RATIO: + warnings.append(f"Credits low: less than 10% remaining (${left:.2f})") + + return warnings diff --git a/oai/mcp/__init__.py b/oai/mcp/__init__.py new file mode 100644 index 0000000..b7abc9a --- /dev/null +++ b/oai/mcp/__init__.py @@ -0,0 +1,28 @@ +""" +Model Context Protocol (MCP) integration for oAI. + +This package provides filesystem and database access capabilities +through the MCP standard, allowing AI models to interact with +local files and SQLite databases safely. + +Key components: +- MCPManager: High-level manager for MCP operations +- MCPFilesystemServer: Filesystem and database access implementation +- GitignoreParser: Pattern matching for .gitignore support +- SQLiteQueryValidator: Query safety validation +- CrossPlatformMCPConfig: OS-specific configuration +""" + +from oai.mcp.manager import MCPManager +from oai.mcp.server import MCPFilesystemServer +from oai.mcp.gitignore import GitignoreParser +from oai.mcp.validators import SQLiteQueryValidator +from oai.mcp.platform import CrossPlatformMCPConfig + +__all__ = [ + "MCPManager", + "MCPFilesystemServer", + "GitignoreParser", + "SQLiteQueryValidator", + "CrossPlatformMCPConfig", +] diff --git a/oai/mcp/gitignore.py b/oai/mcp/gitignore.py new file mode 100644 index 0000000..83ec4f4 --- /dev/null +++ b/oai/mcp/gitignore.py @@ -0,0 +1,166 @@ +""" +Gitignore pattern parsing for oAI MCP. + +This module implements .gitignore pattern matching to filter files +during MCP filesystem operations. +""" + +import fnmatch +from pathlib import Path +from typing import List, Tuple + +from oai.utils.logging import get_logger + + +class GitignoreParser: + """ + Parse and apply .gitignore patterns. + + Supports standard gitignore syntax including: + - Wildcards (*) and double wildcards (**) + - Directory-only patterns (ending with /) + - Negation patterns (starting with !) + - Comments (lines starting with #) + + Patterns are applied relative to the directory containing + the .gitignore file. + """ + + def __init__(self): + """Initialize an empty pattern collection.""" + # List of (pattern, is_negation, source_dir) + self.patterns: List[Tuple[str, bool, Path]] = [] + + def add_gitignore(self, gitignore_path: Path) -> None: + """ + Parse and add patterns from a .gitignore file. + + Args: + gitignore_path: Path to the .gitignore file + """ + logger = get_logger() + + if not gitignore_path.exists(): + return + + try: + source_dir = gitignore_path.parent + + with open(gitignore_path, "r", encoding="utf-8") as f: + for line_num, line in enumerate(f, 1): + line = line.rstrip("\n\r") + + # Skip empty lines and comments + if not line or line.startswith("#"): + continue + + # Check for negation pattern + is_negation = line.startswith("!") + if is_negation: + line = line[1:] + + # Remove leading slash (make relative to gitignore location) + if line.startswith("/"): + line = line[1:] + + self.patterns.append((line, is_negation, source_dir)) + + logger.debug( + f"Loaded {len(self.patterns)} patterns from {gitignore_path}" + ) + + except Exception as e: + logger.warning(f"Error reading {gitignore_path}: {e}") + + def should_ignore(self, path: Path) -> bool: + """ + Check if a path should be ignored based on gitignore patterns. + + Patterns are evaluated in order, with later patterns overriding + earlier ones. Negation patterns (starting with !) un-ignore + previously matched paths. + + Args: + path: Path to check + + Returns: + True if the path should be ignored + """ + if not self.patterns: + return False + + ignored = False + + for pattern, is_negation, source_dir in self.patterns: + # Only apply pattern if path is under the source directory + try: + rel_path = path.relative_to(source_dir) + except ValueError: + # Path is not relative to this gitignore's directory + continue + + rel_path_str = str(rel_path) + + # Check if pattern matches + if self._match_pattern(pattern, rel_path_str, path.is_dir()): + if is_negation: + ignored = False # Negation patterns un-ignore + else: + ignored = True + + return ignored + + def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool: + """ + Match a gitignore pattern against a path. + + Args: + pattern: The gitignore pattern + path: The relative path string to match + is_dir: Whether the path is a directory + + Returns: + True if the pattern matches + """ + # Directory-only pattern (ends with /) + if pattern.endswith("/"): + if not is_dir: + return False + pattern = pattern[:-1] + + # Handle ** patterns (matches any number of directories) + if "**" in pattern: + pattern_parts = pattern.split("**") + if len(pattern_parts) == 2: + prefix, suffix = pattern_parts + + # Match if path starts with prefix and ends with suffix + if prefix: + if not path.startswith(prefix.rstrip("/")): + return False + if suffix: + suffix = suffix.lstrip("/") + if not (path.endswith(suffix) or f"/{suffix}" in path): + return False + return True + + # Direct match using fnmatch + if fnmatch.fnmatch(path, pattern): + return True + + # Match as subdirectory pattern (pattern without / matches in any directory) + if "/" not in pattern: + parts = path.split("/") + if any(fnmatch.fnmatch(part, pattern) for part in parts): + return True + + return False + + def clear(self) -> None: + """Clear all loaded patterns.""" + self.patterns = [] + + @property + def pattern_count(self) -> int: + """Get the number of loaded patterns.""" + return len(self.patterns) diff --git a/oai/mcp/manager.py b/oai/mcp/manager.py new file mode 100644 index 0000000..c44001c --- /dev/null +++ b/oai/mcp/manager.py @@ -0,0 +1,1365 @@ +""" +MCP Manager for oAI. + +This module provides the high-level interface for managing MCP operations, +including server lifecycle, mode switching, folder/database management, +and tool schema generation. +""" + +import asyncio +import datetime +import json +from pathlib import Path +from typing import Optional, List, Dict, Any + +from oai.constants import MAX_TOOL_LOOPS +from oai.config.database import get_database +from oai.mcp.platform import CrossPlatformMCPConfig +from oai.mcp.server import MCPFilesystemServer +from oai.utils.logging import get_logger + + +class MCPManager: + """ + Manage MCP server lifecycle, tool calls, and mode switching. + + This class provides the main interface for MCP functionality including: + - Enabling/disabling the MCP server + - Managing allowed folders and databases + - Switching between file and database modes + - Generating tool schemas for API requests + - Executing tool calls + + Attributes: + enabled: Whether MCP is currently enabled + write_enabled: Whether write mode is enabled (non-persistent) + mode: Current mode ('files' or 'database') + selected_db_index: Index of selected database (if in database mode) + server: The MCPFilesystemServer instance + allowed_folders: List of allowed folder paths + databases: List of registered databases + config: Platform configuration + session_start_time: When the current session started + """ + + def __init__(self): + """Initialize the MCP manager.""" + self.enabled = False + self.write_enabled = False # Off by default, resets each session + self.mode = "files" + self.selected_db_index: Optional[int] = None + + self.server: Optional[MCPFilesystemServer] = None + self.allowed_folders: List[Path] = [] + self.databases: List[Dict[str, Any]] = [] + + self.config = CrossPlatformMCPConfig() + self.session_start_time: Optional[datetime.datetime] = None + + # Load persisted data + self._load_folders() + self._load_databases() + + get_logger().info("MCP Manager initialized") + + # ========================================================================= + # PERSISTENCE + # ========================================================================= + + def _load_folders(self) -> None: + """Load allowed folders from database.""" + logger = get_logger() + db = get_database() + folders_json = db.get_mcp_config("allowed_folders") + + if folders_json: + try: + folder_paths = json.loads(folders_json) + self.allowed_folders = [ + self.config.normalize_path(p) for p in folder_paths + ] + logger.info(f"Loaded {len(self.allowed_folders)} folders from config") + except Exception as e: + logger.error(f"Error loading MCP folders: {e}") + self.allowed_folders = [] + + def _save_folders(self) -> None: + """Save allowed folders to database.""" + db = get_database() + folder_paths = [str(p) for p in self.allowed_folders] + db.set_mcp_config("allowed_folders", json.dumps(folder_paths)) + get_logger().info(f"Saved {len(self.allowed_folders)} folders to config") + + def _load_databases(self) -> None: + """Load databases from database.""" + self.databases = get_database().get_mcp_databases() + get_logger().info(f"Loaded {len(self.databases)} databases from config") + + # ========================================================================= + # ENABLE/DISABLE + # ========================================================================= + + def enable(self) -> Dict[str, Any]: + """ + Enable MCP server. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - folder_count: Number of allowed folders + - database_count: Number of registered databases + - message: Status message + - error: Error message if failed + """ + logger = get_logger() + + if self.enabled: + return { + "success": False, + "error": "MCP is already enabled" + } + + try: + self.server = MCPFilesystemServer(self.allowed_folders) + self.enabled = True + self.session_start_time = datetime.datetime.now() + + get_database().set_mcp_config("mcp_enabled", "true") + + logger.info("MCP Filesystem Server enabled") + + return { + "success": True, + "folder_count": len(self.allowed_folders), + "database_count": len(self.databases), + "message": "MCP Filesystem Server started successfully" + } + except Exception as e: + logger.error(f"Error enabling MCP: {e}") + return { + "success": False, + "error": str(e) + } + + def disable(self) -> Dict[str, Any]: + """ + Disable MCP server. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - message: Status message + - error: Error message if failed + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + try: + self.server = None + self.enabled = False + self.session_start_time = None + self.mode = "files" + self.selected_db_index = None + + get_database().set_mcp_config("mcp_enabled", "false") + + logger.info("MCP Filesystem Server disabled") + + return { + "success": True, + "message": "MCP Filesystem Server stopped" + } + except Exception as e: + logger.error(f"Error disabling MCP: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # WRITE MODE + # ========================================================================= + + def enable_write(self) -> bool: + """ + Enable write mode. + + This method should be called after user confirmation in the UI. + + Returns: + True if write mode was enabled, False otherwise + """ + logger = get_logger() + + if not self.enabled: + logger.warning("Attempted to enable write mode without MCP enabled") + return False + + self.write_enabled = True + logger.info("MCP write mode enabled by user") + get_database().log_mcp_stat("write_mode_enabled", "", True) + return True + + def disable_write(self) -> None: + """Disable write mode.""" + if self.write_enabled: + self.write_enabled = False + get_logger().info("MCP write mode disabled") + get_database().log_mcp_stat("write_mode_disabled", "", True) + + # ========================================================================= + # MODE SWITCHING + # ========================================================================= + + def switch_mode( + self, + new_mode: str, + db_index: Optional[int] = None + ) -> Dict[str, Any]: + """ + Switch between files and database mode. + + Args: + new_mode: 'files' or 'database' + db_index: Database number (1-based) if switching to database mode + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - mode: Current mode + - message: Status message + - tools: Available tools in this mode + - database: Database info (if database mode) + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + if new_mode == "files": + self.mode = "files" + self.selected_db_index = None + logger.info("Switched to file mode") + return { + "success": True, + "mode": "files", + "message": "Switched to file mode", + "tools": ["read_file", "list_directory", "search_files"] + } + + elif new_mode == "database": + if db_index is None: + return { + "success": False, + "error": "Database index required. Use /mcp db " + } + + if db_index < 1 or db_index > len(self.databases): + return { + "success": False, + "error": f"Invalid database number. Use 1-{len(self.databases)}" + } + + self.mode = "database" + self.selected_db_index = db_index - 1 # Convert to 0-based + db = self.databases[self.selected_db_index] + + logger.info(f"Switched to database mode: {db['name']}") + return { + "success": True, + "mode": "database", + "database": db, + "message": f"Switched to database #{db_index}: {db['name']}", + "tools": ["inspect_database", "search_database", "query_database"] + } + + else: + return { + "success": False, + "error": f"Invalid mode: {new_mode}" + } + + # ========================================================================= + # FOLDER MANAGEMENT + # ========================================================================= + + def add_folder(self, folder_path: str) -> Dict[str, Any]: + """ + Add a folder to the allowed list. + + Args: + folder_path: Path to the folder + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - path: Normalized path + - stats: Folder statistics + - total_folders: Total number of allowed folders + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + try: + path = self.config.normalize_path(folder_path) + + # Validation + if not path.exists(): + return { + "success": False, + "error": f"Directory does not exist: {path}" + } + + if not path.is_dir(): + return { + "success": False, + "error": f"Not a directory: {path}" + } + + if self.config.is_system_directory(path): + return { + "success": False, + "error": f"Cannot add system directory: {path}" + } + + if path in self.allowed_folders: + return { + "success": False, + "error": f"Folder already in allowed list: {path}" + } + + # Check for nested paths + parent_folder = None + for existing in self.allowed_folders: + try: + path.relative_to(existing) + parent_folder = existing + break + except ValueError: + continue + + # Get folder stats + stats = self.config.get_folder_stats(path) + + # Add folder + self.allowed_folders.append(path) + self._save_folders() + + # Update server and reload gitignores + if self.server: + self.server.allowed_folders = self.allowed_folders + self.server.reload_gitignores() + + logger.info(f"Added folder to MCP: {path}") + + result = { + "success": True, + "path": str(path), + "stats": stats, + "total_folders": len(self.allowed_folders) + } + + if parent_folder: + result["warning"] = f"Note: {parent_folder} is already allowed (parent folder)" + + return result + + except Exception as e: + logger.error(f"Error adding folder: {e}") + return { + "success": False, + "error": str(e) + } + + def remove_folder(self, folder_ref: str) -> Dict[str, Any]: + """ + Remove a folder from the allowed list. + + Args: + folder_ref: Folder path or number (1-based) + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - path: Removed folder path + - total_folders: Remaining folder count + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + try: + # Check if it's a number + if folder_ref.isdigit(): + index = int(folder_ref) - 1 + if 0 <= index < len(self.allowed_folders): + path = self.allowed_folders[index] + else: + return { + "success": False, + "error": f"Invalid folder number: {folder_ref}" + } + else: + # Treat as path + path = self.config.normalize_path(folder_ref) + if path not in self.allowed_folders: + return { + "success": False, + "error": f"Folder not in allowed list: {path}" + } + + # Remove folder + self.allowed_folders.remove(path) + self._save_folders() + + # Update server and reload gitignores + if self.server: + self.server.allowed_folders = self.allowed_folders + self.server.reload_gitignores() + + logger.info(f"Removed folder from MCP: {path}") + + result = { + "success": True, + "path": str(path), + "total_folders": len(self.allowed_folders) + } + + if len(self.allowed_folders) == 0: + result["warning"] = "No allowed folders remaining. Add one with /mcp add " + + return result + + except Exception as e: + logger.error(f"Error removing folder: {e}") + return { + "success": False, + "error": str(e) + } + + def list_folders(self) -> Dict[str, Any]: + """ + List all allowed folders with stats. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - folders: List of folder info + - total_folders: Total folder count + - total_files: Total file count across folders + - total_size_mb: Total size in MB + """ + try: + folders_info = [] + total_files = 0 + total_size = 0 + + for idx, folder in enumerate(self.allowed_folders, 1): + stats = self.config.get_folder_stats(folder) + + folder_info = { + "number": idx, + "path": str(folder), + "exists": stats.get("exists", False) + } + + if stats.get("exists"): + folder_info["file_count"] = stats["file_count"] + folder_info["size_mb"] = stats["size_mb"] + total_files += stats["file_count"] + total_size += stats["total_size"] + + folders_info.append(folder_info) + + return { + "success": True, + "folders": folders_info, + "total_folders": len(self.allowed_folders), + "total_files": total_files, + "total_size_mb": total_size / (1024 * 1024) + } + + except Exception as e: + get_logger().error(f"Error listing folders: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # DATABASE MANAGEMENT + # ========================================================================= + + def add_database(self, db_path: str) -> Dict[str, Any]: + """ + Add a SQLite database. + + Args: + db_path: Path to the database file + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - database: Database info + - number: Database number + - message: Status message + """ + import sqlite3 + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled. Use /mcp on first" + } + + try: + path = Path(db_path).resolve() + + # Validation + if not path.exists(): + return { + "success": False, + "error": f"Database file not found: {path}" + } + + if not path.is_file(): + return { + "success": False, + "error": f"Not a file: {path}" + } + + # Check if already added + if any(db["path"] == str(path) for db in self.databases): + return { + "success": False, + "error": f"Database already added: {path.name}" + } + + # Validate it's a SQLite database + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + + # Get tables + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [row[0] for row in cursor.fetchall()] + + conn.close() + except sqlite3.DatabaseError as e: + return { + "success": False, + "error": f"Not a valid SQLite database: {e}" + } + + # Get file size + db_size = path.stat().st_size + + # Create database entry + db_info = { + "path": str(path), + "name": path.name, + "size": db_size, + "tables": tables, + "added": datetime.datetime.now().isoformat() + } + + # Save to config + db_id = get_database().add_mcp_database(db_info) + db_info["id"] = db_id + + # Add to list + self.databases.append(db_info) + + logger.info(f"Added database: {path.name}") + + return { + "success": True, + "database": db_info, + "number": len(self.databases), + "message": f"Added database #{len(self.databases)}: {path.name}" + } + + except Exception as e: + logger.error(f"Error adding database: {e}") + return { + "success": False, + "error": str(e) + } + + def remove_database(self, db_ref: str) -> Dict[str, Any]: + """ + Remove a database. + + Args: + db_ref: Database number (1-based) or path + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - database: Removed database info + - message: Status message + - warning: Optional warning message + """ + logger = get_logger() + + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + try: + # Check if it's a number + if db_ref.isdigit(): + index = int(db_ref) - 1 + if 0 <= index < len(self.databases): + db = self.databases[index] + else: + return { + "success": False, + "error": f"Invalid database number: {db_ref}" + } + else: + # Treat as path + path = str(Path(db_ref).resolve()) + db = next((d for d in self.databases if d["path"] == path), None) + if not db: + return { + "success": False, + "error": f"Database not found: {db_ref}" + } + index = self.databases.index(db) + + # If currently selected, deselect + if self.mode == "database" and self.selected_db_index == index: + self.mode = "files" + self.selected_db_index = None + + # Remove from config + get_database().remove_mcp_database(db["path"]) + + # Remove from list + self.databases.pop(index) + + logger.info(f"Removed database: {db['name']}") + + result = { + "success": True, + "database": db, + "message": f"Removed database: {db['name']}" + } + + if len(self.databases) == 0: + result["warning"] = "No databases remaining. Add one with /mcp add db " + + return result + + except Exception as e: + logger.error(f"Error removing database: {e}") + return { + "success": False, + "error": str(e) + } + + def list_databases(self) -> Dict[str, Any]: + """ + List all databases. + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - databases: List of database info + - count: Number of databases + """ + try: + db_list = [] + for idx, db in enumerate(self.databases, 1): + db_info = { + "number": idx, + "name": db["name"], + "path": db["path"], + "size_mb": db["size"] / (1024 * 1024), + "tables": db["tables"], + "table_count": len(db["tables"]), + "added": db["added"] + } + + # Check if file still exists + if not Path(db["path"]).exists(): + db_info["warning"] = "File not found" + + db_list.append(db_info) + + return { + "success": True, + "databases": db_list, + "count": len(db_list) + } + + except Exception as e: + get_logger().error(f"Error listing databases: {e}") + return { + "success": False, + "error": str(e) + } + + # ========================================================================= + # GITIGNORE + # ========================================================================= + + def toggle_gitignore(self, enabled: bool) -> Dict[str, Any]: + """ + Toggle .gitignore filtering. + + Args: + enabled: Whether to enable gitignore filtering + + Returns: + Dictionary containing: + - success: Whether operation succeeded + - message: Status message + - pattern_count: Number of patterns (if enabled) + """ + if not self.enabled: + return { + "success": False, + "error": "MCP is not enabled" + } + + if not self.server: + return { + "success": False, + "error": "MCP server not running" + } + + self.server.respect_gitignore = enabled + status = "enabled" if enabled else "disabled" + + get_logger().info(f".gitignore filtering {status}") + + return { + "success": True, + "message": f".gitignore filtering {status}", + "pattern_count": self.server.gitignore_parser.pattern_count if enabled else 0 + } + + # ========================================================================= + # STATUS + # ========================================================================= + + def get_status(self) -> Dict[str, Any]: + """ + Get comprehensive MCP status. + + Returns: + Dictionary containing full MCP status information + """ + try: + stats = get_database().get_mcp_stats() + + uptime = None + if self.session_start_time: + delta = datetime.datetime.now() - self.session_start_time + hours = delta.seconds // 3600 + minutes = (delta.seconds % 3600) // 60 + uptime = f"{hours}h {minutes}m" if hours > 0 else f"{minutes}m" + + folder_info = self.list_folders() + + gitignore_status = "enabled" if ( + self.server and self.server.respect_gitignore + ) else "disabled" + gitignore_patterns = ( + self.server.gitignore_parser.pattern_count if self.server else 0 + ) + + # Current mode info + mode_info = { + "mode": self.mode, + "mode_display": ( + "Files" if self.mode == "files" + else f"DB #{self.selected_db_index + 1}" + ) + } + + if self.mode == "database" and self.selected_db_index is not None: + mode_info["database"] = self.databases[self.selected_db_index] + + return { + "success": True, + "enabled": self.enabled, + "write_enabled": self.write_enabled, + "uptime": uptime, + "mode_info": mode_info, + "folder_count": len(self.allowed_folders), + "database_count": len(self.databases), + "total_files": folder_info.get("total_files", 0), + "total_size_mb": folder_info.get("total_size_mb", 0), + "stats": stats, + "tools_available": self._get_current_tools(), + "gitignore_status": gitignore_status, + "gitignore_patterns": gitignore_patterns + } + + except Exception as e: + get_logger().error(f"Error getting MCP status: {e}") + return { + "success": False, + "error": str(e) + } + + def _get_current_tools(self) -> List[str]: + """Get tools available in current mode.""" + if not self.enabled: + return [] + + if self.mode == "files": + tools = ["read_file", "list_directory", "search_files"] + if self.write_enabled: + tools.extend([ + "write_file", "edit_file", "delete_file", + "create_directory", "move_file", "copy_file" + ]) + return tools + elif self.mode == "database": + return ["inspect_database", "search_database", "query_database"] + + return [] + + # ========================================================================= + # TOOL EXECUTION + # ========================================================================= + + async def call_tool(self, tool_name: str, **kwargs) -> Dict[str, Any]: + """ + Call an MCP tool. + + Args: + tool_name: Name of the tool to call + **kwargs: Tool arguments + + Returns: + Tool execution result + """ + if not self.enabled or not self.server: + return {"error": "MCP is not enabled"} + + # Check write permissions for write operations + write_tools = { + "write_file", "edit_file", "delete_file", + "create_directory", "move_file", "copy_file" + } + if tool_name in write_tools and not self.write_enabled: + return {"error": "Write operations disabled. Enable with /mcp write on"} + + try: + # File mode tools (read-only) + if tool_name == "read_file": + return await self.server.read_file(kwargs.get("file_path", "")) + elif tool_name == "list_directory": + return await self.server.list_directory( + kwargs.get("dir_path", ""), + kwargs.get("recursive", False) + ) + elif tool_name == "search_files": + return await self.server.search_files( + kwargs.get("pattern", ""), + kwargs.get("search_path"), + kwargs.get("content_search", False) + ) + + # Database mode tools + elif tool_name == "inspect_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.inspect_database( + db["path"], + kwargs.get("table_name") + ) + elif tool_name == "search_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.search_database( + db["path"], + kwargs.get("search_term", ""), + kwargs.get("table_name"), + kwargs.get("column_name") + ) + elif tool_name == "query_database": + if self.mode != "database" or self.selected_db_index is None: + return {"error": "Not in database mode. Use /mcp db first"} + db = self.databases[self.selected_db_index] + return await self.server.query_database( + db["path"], + kwargs.get("query", ""), + kwargs.get("limit") + ) + + # Write mode tools + elif tool_name == "write_file": + return await self.server.write_file( + kwargs.get("file_path", ""), + kwargs.get("content", "") + ) + elif tool_name == "edit_file": + return await self.server.edit_file( + kwargs.get("file_path", ""), + kwargs.get("old_text", ""), + kwargs.get("new_text", "") + ) + elif tool_name == "delete_file": + return await self.server.delete_file( + kwargs.get("file_path", ""), + kwargs.get("reason", "No reason provided"), + kwargs.get("confirm_callback") + ) + elif tool_name == "create_directory": + return await self.server.create_directory( + kwargs.get("dir_path", "") + ) + elif tool_name == "move_file": + return await self.server.move_file( + kwargs.get("source_path", ""), + kwargs.get("dest_path", "") + ) + elif tool_name == "copy_file": + return await self.server.copy_file( + kwargs.get("source_path", ""), + kwargs.get("dest_path", "") + ) + else: + return {"error": f"Unknown tool: {tool_name}"} + + except Exception as e: + get_logger().error(f"Error calling MCP tool {tool_name}: {e}") + return {"error": str(e)} + + # ========================================================================= + # TOOL SCHEMA GENERATION + # ========================================================================= + + def get_tools_schema(self) -> List[Dict[str, Any]]: + """ + Get MCP tools as OpenAI function calling schema. + + Returns the appropriate schema based on current mode. + + Returns: + List of tool definitions for API request + """ + if not self.enabled: + return [] + + if self.mode == "files": + return self._get_file_tools_schema() + elif self.mode == "database": + return self._get_database_tools_schema() + + return [] + + def _get_file_tools_schema(self) -> List[Dict[str, Any]]: + """Get file mode tools schema.""" + if len(self.allowed_folders) == 0: + return [] + + allowed_dirs_str = ", ".join(str(f) for f in self.allowed_folders) + + # Base read-only tools + tools = [ + { + "type": "function", + "function": { + "name": "search_files", + "description": ( + f"Search for files in the user's local filesystem. " + f"Allowed directories: {allowed_dirs_str}. " + "Can search by filename pattern (e.g., '*.py' for Python files) " + "or search within file contents. Automatically respects .gitignore " + "patterns and excludes virtual environments." + ), + "parameters": { + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": ( + "Search pattern. For filename search, use glob patterns " + "like '*.py', '*.txt', 'report*'. For content search, " + "use plain text to search for." + ) + }, + "content_search": { + "type": "boolean", + "description": ( + "If true, searches inside file contents (SLOW). " + "If false, searches only filenames (FAST). Default is false." + ), + "default": False + }, + "search_path": { + "type": "string", + "description": ( + "Optional: specific directory to search in. " + "If not provided, searches all allowed directories." + ), + } + }, + "required": ["pattern"] + } + } + }, + { + "type": "function", + "function": { + "name": "read_file", + "description": ( + "Read the complete contents of a text file. " + "Only works for files within allowed directories. " + "Maximum file size: 10MB. Files larger than 50KB are automatically truncated. " + "Respects .gitignore patterns." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": ( + "Full path to the file to read " + "(e.g., /Users/username/Documents/report.txt)" + ) + } + }, + "required": ["file_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "list_directory", + "description": ( + "List all files and subdirectories in a directory. " + "Automatically filters out virtual environments, build artifacts, " + "and .gitignore patterns. Limited to 1000 items." + ), + "parameters": { + "type": "object", + "properties": { + "dir_path": { + "type": "string", + "description": ( + "Directory path to list " + "(e.g., ~/Documents or /Users/username/Projects)" + ) + }, + "recursive": { + "type": "boolean", + "description": ( + "If true, lists files in subdirectories recursively. " + "Default is true. WARNING: Can be very large." + ), + "default": True + } + }, + "required": ["dir_path"] + } + } + } + ] + + # Add write tools if write mode is enabled + if self.write_enabled: + tools.extend(self._get_write_tools_schema(allowed_dirs_str)) + + return tools + + def _get_write_tools_schema(self, allowed_dirs_str: str) -> List[Dict[str, Any]]: + """Get write mode tools schema.""" + return [ + { + "type": "function", + "function": { + "name": "write_file", + "description": ( + f"Create or overwrite a file with content. " + f"Allowed directories: {allowed_dirs_str}. " + "Automatically creates parent directories if needed." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to create or overwrite" + }, + "content": { + "type": "string", + "description": "The complete content to write to the file" + } + }, + "required": ["file_path", "content"] + } + } + }, + { + "type": "function", + "function": { + "name": "edit_file", + "description": ( + "Make targeted edits by finding and replacing specific text. " + "The old_text must match exactly and appear only once in the file." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to edit" + }, + "old_text": { + "type": "string", + "description": ( + "The exact text to find and replace. " + "Must match exactly and appear only once." + ) + }, + "new_text": { + "type": "string", + "description": "The new text to replace the old text with" + } + }, + "required": ["file_path", "old_text", "new_text"] + } + } + }, + { + "type": "function", + "function": { + "name": "delete_file", + "description": ( + "Delete a file. ALWAYS requires user confirmation. " + "The user will see the file details and your reason before deciding." + ), + "parameters": { + "type": "object", + "properties": { + "file_path": { + "type": "string", + "description": "Full path to the file to delete" + }, + "reason": { + "type": "string", + "description": ( + "Clear explanation for why this file should be deleted. " + "The user will see this reason." + ) + } + }, + "required": ["file_path", "reason"] + } + } + }, + { + "type": "function", + "function": { + "name": "create_directory", + "description": ( + "Create a new directory (and parent directories if needed). " + "Returns success if directory already exists." + ), + "parameters": { + "type": "object", + "properties": { + "dir_path": { + "type": "string", + "description": "Full path to the directory to create" + } + }, + "required": ["dir_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "move_file", + "description": "Move or rename a file within allowed directories.", + "parameters": { + "type": "object", + "properties": { + "source_path": { + "type": "string", + "description": "Full path to the file to move/rename" + }, + "dest_path": { + "type": "string", + "description": "Full path for the new location/name" + } + }, + "required": ["source_path", "dest_path"] + } + } + }, + { + "type": "function", + "function": { + "name": "copy_file", + "description": "Copy a file to a new location within allowed directories.", + "parameters": { + "type": "object", + "properties": { + "source_path": { + "type": "string", + "description": "Full path to the file to copy" + }, + "dest_path": { + "type": "string", + "description": "Full path for the copy destination" + } + }, + "required": ["source_path", "dest_path"] + } + } + } + ] + + def _get_database_tools_schema(self) -> List[Dict[str, Any]]: + """Get database mode tools schema.""" + if self.selected_db_index is None or self.selected_db_index >= len(self.databases): + return [] + + db = self.databases[self.selected_db_index] + db_name = db["name"] + tables_str = ", ".join(db["tables"]) + + return [ + { + "type": "function", + "function": { + "name": "inspect_database", + "description": ( + f"Inspect the schema of the currently selected database ({db_name}). " + f"Can get all tables or details for a specific table. " + f"Available tables: {tables_str}." + ), + "parameters": { + "type": "object", + "properties": { + "table_name": { + "type": "string", + "description": ( + f"Optional: specific table to inspect. " + f"If not provided, returns info for all tables. " + f"Available: {tables_str}" + ) + } + }, + "required": [] + } + } + }, + { + "type": "function", + "function": { + "name": "search_database", + "description": ( + f"Search for a value across tables in the database ({db_name}). " + f"Performs partial matching across columns. " + f"Limited to {self.server.default_query_limit} results." + ), + "parameters": { + "type": "object", + "properties": { + "search_term": { + "type": "string", + "description": "Value to search for (partial match supported)" + }, + "table_name": { + "type": "string", + "description": f"Optional: limit search to specific table. Available: {tables_str}" + }, + "column_name": { + "type": "string", + "description": "Optional: limit search to specific column" + } + }, + "required": ["search_term"] + } + } + }, + { + "type": "function", + "function": { + "name": "query_database", + "description": ( + f"Execute a read-only SQL query on the database ({db_name}). " + f"Supports SELECT queries including JOINs, subqueries, CTEs. " + f"Maximum {self.server.max_query_results} rows. " + f"Timeout: {self.server.max_query_timeout} seconds. " + "INSERT/UPDATE/DELETE/DROP are blocked." + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": ( + f"SQL SELECT query to execute. " + f"Available tables: {tables_str}." + ) + }, + "limit": { + "type": "integer", + "description": ( + f"Optional: max rows to return " + f"(default {self.server.default_query_limit}, " + f"max {self.server.max_query_results})" + ) + } + }, + "required": ["query"] + } + } + } + ] + + +# Global MCP manager instance +_mcp_manager: Optional[MCPManager] = None + + +def get_mcp_manager() -> MCPManager: + """ + Get the global MCP manager instance. + + Returns: + The shared MCPManager instance + """ + global _mcp_manager + if _mcp_manager is None: + _mcp_manager = MCPManager() + return _mcp_manager diff --git a/oai/mcp/platform.py b/oai/mcp/platform.py new file mode 100644 index 0000000..dc941bf --- /dev/null +++ b/oai/mcp/platform.py @@ -0,0 +1,228 @@ +""" +Cross-platform MCP configuration for oAI. + +This module handles OS-specific configuration, path handling, +and security checks for the MCP filesystem server. +""" + +import os +import platform +import subprocess +from pathlib import Path +from typing import List, Dict, Any, Optional + +from oai.constants import SYSTEM_DIRS_BLACKLIST +from oai.utils.logging import get_logger + + +class CrossPlatformMCPConfig: + """ + Handle OS-specific MCP configuration. + + Provides methods for path normalization, security validation, + and OS-specific default directories. + + Attributes: + system: Operating system name + is_macos: Whether running on macOS + is_linux: Whether running on Linux + is_windows: Whether running on Windows + """ + + def __init__(self): + """Initialize platform detection.""" + self.system = platform.system() + self.is_macos = self.system == "Darwin" + self.is_linux = self.system == "Linux" + self.is_windows = self.system == "Windows" + + logger = get_logger() + logger.info(f"Detected OS: {self.system}") + + def get_default_allowed_dirs(self) -> List[Path]: + """ + Get safe default directories for the current OS. + + Returns: + List of default directories that are safe to access + """ + home = Path.home() + + if self.is_macos: + return [ + home / "Documents", + home / "Desktop", + home / "Downloads", + ] + + elif self.is_linux: + dirs = [home / "Documents"] + + # Try to get XDG directories + try: + for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]: + result = subprocess.run( + ["xdg-user-dir", xdg_dir], + capture_output=True, + text=True, + timeout=1 + ) + if result.returncode == 0: + dir_path = Path(result.stdout.strip()) + if dir_path.exists(): + dirs.append(dir_path) + except (subprocess.TimeoutExpired, FileNotFoundError): + # Fallback to standard locations + dirs.extend([ + home / "Desktop", + home / "Downloads", + ]) + + return list(set(dirs)) + + elif self.is_windows: + return [ + home / "Documents", + home / "Desktop", + home / "Downloads", + ] + + # Fallback for unknown OS + return [home] + + def get_python_command(self) -> str: + """ + Get the Python executable path. + + Returns: + Path to the Python executable + """ + import sys + return sys.executable + + def get_filesystem_warning(self) -> str: + """ + Get OS-specific security warning message. + + Returns: + Warning message for the current OS + """ + if self.is_macos: + return """ +Note: macOS Security +The Filesystem MCP server needs access to your selected folder. +You may see a security prompt - click 'Allow' to proceed. +(System Settings > Privacy & Security > Files and Folders) +""" + elif self.is_linux: + return """ +Note: Linux Security +The Filesystem MCP server will access your selected folder. +Ensure oAI has appropriate file permissions. +""" + elif self.is_windows: + return """ +Note: Windows Security +The Filesystem MCP server will access your selected folder. +You may need to grant file access permissions. +""" + return "" + + def normalize_path(self, path: str) -> Path: + """ + Normalize a path for the current OS. + + Expands user directory (~) and resolves to absolute path. + + Args: + path: Path string to normalize + + Returns: + Normalized absolute Path + """ + return Path(os.path.expanduser(path)).resolve() + + def is_system_directory(self, path: Path) -> bool: + """ + Check if a path is a protected system directory. + + Args: + path: Path to check + + Returns: + True if the path is a system directory + """ + path_str = str(path) + for blocked in SYSTEM_DIRS_BLACKLIST: + if path_str.startswith(blocked): + return True + return False + + def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool: + """ + Check if a path is within allowed directories. + + Args: + requested_path: Path being requested + allowed_dirs: List of allowed parent directories + + Returns: + True if the path is within an allowed directory + """ + try: + requested = requested_path.resolve() + + for allowed in allowed_dirs: + try: + allowed_resolved = allowed.resolve() + requested.relative_to(allowed_resolved) + return True + except ValueError: + continue + + return False + except Exception: + return False + + def get_folder_stats(self, folder: Path) -> Dict[str, Any]: + """ + Get statistics for a folder. + + Args: + folder: Path to the folder + + Returns: + Dictionary with folder statistics: + - exists: Whether the folder exists + - file_count: Number of files (if exists) + - total_size: Total size in bytes (if exists) + - size_mb: Size in megabytes (if exists) + - error: Error message (if any) + """ + logger = get_logger() + + try: + if not folder.exists() or not folder.is_dir(): + return {"exists": False} + + file_count = 0 + total_size = 0 + + for item in folder.rglob("*"): + if item.is_file(): + file_count += 1 + try: + total_size += item.stat().st_size + except (OSError, PermissionError): + pass + + return { + "exists": True, + "file_count": file_count, + "total_size": total_size, + "size_mb": total_size / (1024 * 1024), + } + + except Exception as e: + logger.error(f"Error getting folder stats for {folder}: {e}") + return {"exists": False, "error": str(e)} diff --git a/oai/mcp/server.py b/oai/mcp/server.py new file mode 100644 index 0000000..105a0de --- /dev/null +++ b/oai/mcp/server.py @@ -0,0 +1,1368 @@ +""" +MCP Filesystem Server implementation for oAI. + +This module provides the actual filesystem and database access operations +for the MCP integration. It handles file reading, directory listing, +searching, and SQLite database operations. +""" + +import datetime +import signal +import sqlite3 +from pathlib import Path +from typing import Optional, List, Dict, Any, Set + +from oai.constants import ( + MAX_FILE_SIZE, + CONTENT_TRUNCATION_THRESHOLD, + MAX_LIST_ITEMS, + MAX_QUERY_TIMEOUT, + MAX_QUERY_RESULTS, + DEFAULT_QUERY_LIMIT, + SKIP_DIRECTORIES, +) +from oai.config.database import get_database +from oai.mcp.platform import CrossPlatformMCPConfig +from oai.mcp.gitignore import GitignoreParser +from oai.mcp.validators import SQLiteQueryValidator +from oai.utils.logging import get_logger + + +class MCPFilesystemServer: + """ + MCP Filesystem Server with file access and SQLite database querying. + + Provides async methods for: + - File operations: read, write, edit, delete, copy, move + - Directory operations: list, create + - Search operations: filename and content search + - Database operations: inspect, search, query + + All operations respect allowed folder restrictions and gitignore patterns. + + Attributes: + allowed_folders: List of folders the server can access + config: Platform-specific configuration + max_file_size: Maximum file size for read operations + max_list_items: Maximum items to return in directory listings + respect_gitignore: Whether to apply gitignore filtering + gitignore_parser: Parser for gitignore patterns + max_query_timeout: Maximum seconds for database queries + max_query_results: Maximum rows to return from queries + default_query_limit: Default row limit for queries + """ + + def __init__(self, allowed_folders: List[Path]): + """ + Initialize the filesystem server. + + Args: + allowed_folders: List of folder paths the server can access + """ + self.allowed_folders = allowed_folders + self.config = CrossPlatformMCPConfig() + self.max_file_size = MAX_FILE_SIZE + self.max_list_items = MAX_LIST_ITEMS + self.respect_gitignore = True + + # Initialize gitignore parser and load patterns + self.gitignore_parser = GitignoreParser() + self._load_gitignores() + + # SQLite configuration + self.max_query_timeout = MAX_QUERY_TIMEOUT + self.max_query_results = MAX_QUERY_RESULTS + self.default_query_limit = DEFAULT_QUERY_LIMIT + + logger = get_logger() + logger.info( + f"MCP Filesystem Server initialized with {len(allowed_folders)} folders" + ) + + def _load_gitignores(self) -> None: + """Load all .gitignore files from allowed folders.""" + logger = get_logger() + gitignore_count = 0 + + for folder in self.allowed_folders: + if not folder.exists(): + continue + + # Load root .gitignore + root_gitignore = folder / ".gitignore" + if root_gitignore.exists(): + self.gitignore_parser.add_gitignore(root_gitignore) + gitignore_count += 1 + logger.info(f"Loaded .gitignore from {folder}") + + # Load nested .gitignore files + try: + for gitignore_path in folder.rglob(".gitignore"): + if gitignore_path != root_gitignore: + self.gitignore_parser.add_gitignore(gitignore_path) + gitignore_count += 1 + logger.debug(f"Loaded nested .gitignore: {gitignore_path}") + except Exception as e: + logger.warning(f"Error loading nested .gitignores from {folder}: {e}") + + if gitignore_count > 0: + logger.info( + f"Loaded {gitignore_count} .gitignore file(s) with " + f"{self.gitignore_parser.pattern_count} total patterns" + ) + + def reload_gitignores(self) -> None: + """Reload all .gitignore files (call when allowed_folders changes).""" + self.gitignore_parser.clear() + self._load_gitignores() + get_logger().info("Reloaded .gitignore patterns") + + def is_allowed_path(self, path: Path) -> bool: + """ + Check if a path is within allowed folders. + + Args: + path: Path to check + + Returns: + True if the path is allowed + """ + return self.config.is_safe_path(path, self.allowed_folders) + + def _should_skip_path(self, item_path: Path) -> bool: + """ + Check if a path should be skipped during traversal. + + Args: + item_path: Path to check + + Returns: + True if the path should be skipped + """ + # Check hardcoded skip directories + path_parts = item_path.parts + if any(part in SKIP_DIRECTORIES for part in path_parts): + return True + + # Check gitignore patterns (if enabled) + if self.respect_gitignore and self.gitignore_parser.should_ignore(item_path): + return True + + return False + + def _log_stat( + self, + tool_name: str, + folder: Optional[str], + success: bool, + error_message: Optional[str] = None + ) -> None: + """Log an MCP tool usage event.""" + get_database().log_mcp_stat(tool_name, folder, success, error_message) + + # ========================================================================= + # FILE READ OPERATIONS + # ========================================================================= + + async def read_file(self, file_path: str) -> Dict[str, Any]: + """ + Read file contents. + + Args: + file_path: Path to the file to read + + Returns: + Dictionary containing: + - content: File content + - path: Resolved file path + - size: File size in bytes + - truncated: Whether content was truncated + - error: Error message if failed + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Check if file should be ignored + if self.respect_gitignore and self.gitignore_parser.should_ignore(path): + error_msg = f"File ignored by .gitignore: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + file_size = path.stat().st_size + if file_size > self.max_file_size: + error_msg = ( + f"File too large: {file_size / (1024*1024):.1f}MB " + f"(max: {self.max_file_size / (1024*1024):.0f}MB)" + ) + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Try to read as text + try: + content = path.read_text(encoding="utf-8") + + # Truncate large files + if file_size > CONTENT_TRUNCATION_THRESHOLD: + lines = content.split("\n") + total_lines = len(lines) + + head_lines = 500 + tail_lines = 100 + + if total_lines > (head_lines + tail_lines): + truncated_content = ( + "\n".join(lines[:head_lines]) + + f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} " + f"lines omitted] ...\n\n" + + "\n".join(lines[-tail_lines:]) + ) + + logger.info( + f"Read file (truncated): {path} " + f"({file_size} bytes, {total_lines} lines)" + ) + self._log_stat("read_file", str(path.parent), True) + + return { + "content": truncated_content, + "path": str(path), + "size": file_size, + "truncated": True, + "total_lines": total_lines, + "lines_shown": head_lines + tail_lines, + "note": ( + f"File truncated: showing first {head_lines} " + f"and last {tail_lines} lines of {total_lines} total" + ) + } + + logger.info(f"Read file: {path} ({file_size} bytes)") + self._log_stat("read_file", str(path.parent), True) + + return { + "content": content, + "path": str(path), + "size": file_size + } + + except UnicodeDecodeError: + error_msg = f"Cannot decode file as UTF-8: {path}" + logger.warning(error_msg) + self._log_stat("read_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + except Exception as e: + error_msg = f"Error reading file: {e}" + logger.error(error_msg) + self._log_stat("read_file", file_path, False, str(e)) + return {"error": error_msg} + + async def list_directory( + self, + dir_path: str, + recursive: bool = False + ) -> Dict[str, Any]: + """ + List directory contents. + + Args: + dir_path: Path to the directory + recursive: Whether to list recursively + + Returns: + Dictionary containing: + - path: Directory path + - items: List of file/directory info + - count: Number of items + - truncated: Whether results were truncated + """ + logger = get_logger() + + try: + path = self.config.normalize_path(dir_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"Directory not found: {path}" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_dir(): + error_msg = f"Not a directory: {path}" + logger.warning(error_msg) + self._log_stat("list_directory", str(path), False, error_msg) + return {"error": error_msg} + + items = [] + pattern = "**/*" if recursive else "*" + + for item in path.glob(pattern): + # Stop if we hit the limit + if len(items) >= self.max_list_items: + break + + # Skip excluded directories + if self._should_skip_path(item): + continue + + try: + stat = item.stat() + items.append({ + "name": item.name, + "path": str(item), + "type": "directory" if item.is_dir() else "file", + "size": stat.st_size if item.is_file() else 0, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (OSError, PermissionError): + continue + + truncated = len(items) >= self.max_list_items + + logger.info( + f"Listed directory: {path} ({len(items)} items, " + f"recursive={recursive}, truncated={truncated})" + ) + self._log_stat("list_directory", str(path), True) + + result = { + "path": str(path), + "items": items, + "count": len(items), + "truncated": truncated + } + + if truncated: + result["note"] = ( + f"Results limited to {self.max_list_items} items. " + "Use more specific search path or disable recursive mode." + ) + + return result + + except Exception as e: + error_msg = f"Error listing directory: {e}" + logger.error(error_msg) + self._log_stat("list_directory", dir_path, False, str(e)) + return {"error": error_msg} + + async def search_files( + self, + pattern: str, + search_path: Optional[str] = None, + content_search: bool = False + ) -> Dict[str, Any]: + """ + Search for files by name or content. + + Args: + pattern: Search pattern (glob for filename, text for content) + search_path: Optional path to search within + content_search: Whether to search within file contents + + Returns: + Dictionary containing: + - pattern: The search pattern used + - search_type: 'filename' or 'content' + - matches: List of matching files + - count: Number of matches + """ + logger = get_logger() + + try: + # Determine search roots + if search_path: + path = self.config.normalize_path(search_path) + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} not in allowed folders" + logger.warning(error_msg) + self._log_stat("search_files", str(path), False, error_msg) + return {"error": error_msg} + search_roots = [path] + else: + search_roots = self.allowed_folders + + matches = [] + + for root in search_roots: + if not root.exists(): + continue + + # Filename search (fast) + if not content_search: + for item in root.rglob(pattern): + if item.is_file() and not self._should_skip_path(item): + try: + stat = item.stat() + matches.append({ + "path": str(item), + "name": item.name, + "size": stat.st_size, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (OSError, PermissionError): + continue + else: + # Content search (slower) + for item in root.rglob("*"): + if item.is_file() and not self._should_skip_path(item): + try: + # Skip large files + if item.stat().st_size > self.max_file_size: + continue + + # Try to read and search content + try: + content = item.read_text(encoding="utf-8") + if pattern.lower() in content.lower(): + stat = item.stat() + matches.append({ + "path": str(item), + "name": item.name, + "size": stat.st_size, + "modified": datetime.datetime.fromtimestamp( + stat.st_mtime + ).isoformat() + }) + except (UnicodeDecodeError, PermissionError): + continue + except (OSError, PermissionError): + continue + + search_type = "content" if content_search else "filename" + logger.info( + f"Searched files: pattern='{pattern}', type={search_type}, " + f"found={len(matches)}" + ) + self._log_stat( + "search_files", + str(search_roots[0]) if search_roots else None, + True + ) + + return { + "pattern": pattern, + "search_type": search_type, + "matches": matches, + "count": len(matches) + } + + except Exception as e: + error_msg = f"Error searching files: {e}" + logger.error(error_msg) + self._log_stat("search_files", search_path or "all", False, str(e)) + return {"error": error_msg} + + # ========================================================================= + # FILE WRITE OPERATIONS + # ========================================================================= + + async def write_file(self, file_path: str, content: str) -> Dict[str, Any]: + """ + Create or overwrite a file. + + Note: Ignores .gitignore - allows writing to any file in allowed folders. + + Args: + file_path: Path to the file + content: Content to write + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: File path + - size: File size after writing + - created: Whether a new file was created + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("write_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Create parent directory if needed + parent_dir = path.parent + if not parent_dir.exists(): + parent_dir.mkdir(parents=True, exist_ok=True) + logger.info(f"Created parent directory: {parent_dir}") + + is_new_file = not path.exists() + + # Write content + path.write_text(content, encoding="utf-8") + file_size = path.stat().st_size + + logger.info( + f"{'Created' if is_new_file else 'Updated'} file: " + f"{path} ({file_size} bytes)" + ) + self._log_stat("write_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "size": file_size, + "created": is_new_file + } + + except PermissionError as e: + error_msg = f"Permission denied writing to {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except UnicodeEncodeError as e: + error_msg = f"Encoding error writing to {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "EncodingError"} + except Exception as e: + error_msg = f"Error writing file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("write_file", file_path, False, str(e)) + return {"error": error_msg} + + async def edit_file( + self, + file_path: str, + old_text: str, + new_text: str + ) -> Dict[str, Any]: + """ + Find and replace text in a file. + + Note: Ignores .gitignore - allows editing any file in allowed folders. + + Args: + file_path: Path to the file + old_text: Text to find (must be unique in file) + new_text: Text to replace with + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: File path + - changes: Number of replacements made + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("edit_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Read current content + current_content = path.read_text(encoding="utf-8") + + # Check if old_text exists + if old_text not in current_content: + error_msg = f"Text not found in file: '{old_text[:50]}...'" + logger.warning(f"Edit failed - text not found in {path}") + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Check for unique match + occurrence_count = current_content.count(old_text) + if occurrence_count > 1: + error_msg = ( + f"Text appears {occurrence_count} times in file. " + "Please provide more context to make the match unique." + ) + logger.warning(f"Edit failed - ambiguous match in {path}") + self._log_stat("edit_file", str(path), False, error_msg) + return {"error": error_msg} + + # Perform replacement + new_content = current_content.replace(old_text, new_text, 1) + path.write_text(new_content, encoding="utf-8") + + logger.info(f"Edited file: {path} (1 replacement)") + self._log_stat("edit_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "changes": 1 + } + + except PermissionError as e: + error_msg = f"Permission denied editing {file_path}: {e}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except UnicodeDecodeError as e: + error_msg = f"Cannot read file (encoding issue): {file_path}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "EncodingError"} + except Exception as e: + error_msg = f"Error editing file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("edit_file", file_path, False, str(e)) + return {"error": error_msg} + + async def delete_file( + self, + file_path: str, + reason: str, + confirm_callback=None + ) -> Dict[str, Any]: + """ + Delete a file with user confirmation. + + Note: Ignores .gitignore - allows deleting any file in allowed folders. + + Args: + file_path: Path to the file to delete + reason: Reason for deletion (shown to user) + confirm_callback: Callback function for user confirmation + + Returns: + Dictionary containing: + - success: Whether the file was deleted + - path: File path + - user_cancelled: Whether user cancelled the operation + """ + logger = get_logger() + + try: + path = self.config.normalize_path(file_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("delete_file", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.exists(): + error_msg = f"File not found: {path}" + logger.warning(error_msg) + self._log_stat("delete_file", str(path), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("delete_file", str(path), False, error_msg) + return {"error": error_msg} + + # Get file info for confirmation + file_size = path.stat().st_size + file_mtime = datetime.datetime.fromtimestamp(path.stat().st_mtime) + + # User confirmation required (handled by callback) + if confirm_callback: + confirmed = confirm_callback(path, file_size, file_mtime, reason) + if not confirmed: + logger.info(f"User cancelled file deletion: {path}") + self._log_stat("delete_file", str(path), False, "User cancelled") + return { + "success": False, + "user_cancelled": True, + "path": str(path) + } + + # Delete the file + path.unlink() + + logger.info(f"Deleted file: {path}") + self._log_stat("delete_file", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "user_cancelled": False + } + + except PermissionError as e: + error_msg = f"Permission denied deleting {file_path}: {e}" + logger.error(error_msg) + self._log_stat("delete_file", file_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error deleting file {file_path}: {e}" + logger.error(error_msg) + self._log_stat("delete_file", file_path, False, str(e)) + return {"error": error_msg} + + async def create_directory(self, dir_path: str) -> Dict[str, Any]: + """ + Create a directory. + + Note: Ignores .gitignore - allows creating directories in allowed folders. + + Args: + dir_path: Path to the directory to create + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - path: Directory path + - created: Whether a new directory was created + """ + logger = get_logger() + + try: + path = self.config.normalize_path(dir_path) + + if not self.is_allowed_path(path): + error_msg = f"Access denied: {path} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("create_directory", str(path.parent), False, error_msg) + return {"error": error_msg} + + already_exists = path.exists() + + if already_exists and not path.is_dir(): + error_msg = f"Path exists but is not a directory: {path}" + logger.warning(error_msg) + self._log_stat("create_directory", str(path), False, error_msg) + return {"error": error_msg} + + # Create directory (and parents) + path.mkdir(parents=True, exist_ok=True) + + if already_exists: + logger.info(f"Directory already exists: {path}") + else: + logger.info(f"Created directory: {path}") + + self._log_stat("create_directory", str(path.parent), True) + + return { + "success": True, + "path": str(path), + "created": not already_exists + } + + except PermissionError as e: + error_msg = f"Permission denied creating directory {dir_path}: {e}" + logger.error(error_msg) + self._log_stat("create_directory", dir_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error creating directory {dir_path}: {e}" + logger.error(error_msg) + self._log_stat("create_directory", dir_path, False, str(e)) + return {"error": error_msg} + + async def move_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: + """ + Move or rename a file. + + Note: Ignores .gitignore - allows moving files within allowed folders. + + Args: + source_path: Path to the source file + dest_path: Destination path + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - source: Source path + - destination: Destination path + """ + import shutil + logger = get_logger() + + try: + source = self.config.normalize_path(source_path) + dest = self.config.normalize_path(dest_path) + + # Both paths must be in allowed folders + if not self.is_allowed_path(source): + error_msg = f"Access denied: source {source} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("move_file", str(source.parent), False, error_msg) + return {"error": error_msg} + + if not self.is_allowed_path(dest): + error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("move_file", str(dest.parent), False, error_msg) + return {"error": error_msg} + + if not source.exists(): + error_msg = f"Source file not found: {source}" + logger.warning(error_msg) + self._log_stat("move_file", str(source), False, error_msg) + return {"error": error_msg} + + if not source.is_file(): + error_msg = f"Source is not a file: {source}" + logger.warning(error_msg) + self._log_stat("move_file", str(source), False, error_msg) + return {"error": error_msg} + + # Create destination parent directory if needed + dest.parent.mkdir(parents=True, exist_ok=True) + + # Move/rename the file + shutil.move(str(source), str(dest)) + + logger.info(f"Moved file: {source} -> {dest}") + self._log_stat("move_file", str(source.parent), True) + + return { + "success": True, + "source": str(source), + "destination": str(dest) + } + + except PermissionError as e: + error_msg = f"Permission denied moving file: {e}" + logger.error(error_msg) + self._log_stat("move_file", source_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error moving file from {source_path} to {dest_path}: {e}" + logger.error(error_msg) + self._log_stat("move_file", source_path, False, str(e)) + return {"error": error_msg} + + async def copy_file(self, source_path: str, dest_path: str) -> Dict[str, Any]: + """ + Copy a file. + + Note: Ignores .gitignore - allows copying files within allowed folders. + + Args: + source_path: Path to the source file + dest_path: Destination path + + Returns: + Dictionary containing: + - success: Whether the operation succeeded + - source: Source path + - destination: Destination path + - size: Size of copied file + """ + import shutil + logger = get_logger() + + try: + source = self.config.normalize_path(source_path) + dest = self.config.normalize_path(dest_path) + + # Both paths must be in allowed folders + if not self.is_allowed_path(source): + error_msg = f"Access denied: source {source} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("copy_file", str(source.parent), False, error_msg) + return {"error": error_msg} + + if not self.is_allowed_path(dest): + error_msg = f"Access denied: destination {dest} is not in allowed MCP folders" + logger.warning(error_msg) + self._log_stat("copy_file", str(dest.parent), False, error_msg) + return {"error": error_msg} + + if not source.exists(): + error_msg = f"Source file not found: {source}" + logger.warning(error_msg) + self._log_stat("copy_file", str(source), False, error_msg) + return {"error": error_msg} + + if not source.is_file(): + error_msg = f"Source is not a file: {source}" + logger.warning(error_msg) + self._log_stat("copy_file", str(source), False, error_msg) + return {"error": error_msg} + + # Create destination parent directory if needed + dest.parent.mkdir(parents=True, exist_ok=True) + + # Copy the file (preserves metadata) + shutil.copy2(str(source), str(dest)) + + file_size = dest.stat().st_size + + logger.info(f"Copied file: {source} -> {dest} ({file_size} bytes)") + self._log_stat("copy_file", str(source.parent), True) + + return { + "success": True, + "source": str(source), + "destination": str(dest), + "size": file_size + } + + except PermissionError as e: + error_msg = f"Permission denied copying file: {e}" + logger.error(error_msg) + self._log_stat("copy_file", source_path, False, str(e)) + return {"error": error_msg, "error_type": "PermissionError"} + except Exception as e: + error_msg = f"Error copying file from {source_path} to {dest_path}: {e}" + logger.error(error_msg) + self._log_stat("copy_file", source_path, False, str(e)) + return {"error": error_msg} + + # ========================================================================= + # DATABASE OPERATIONS + # ========================================================================= + + async def inspect_database( + self, + db_path: str, + table_name: Optional[str] = None + ) -> Dict[str, Any]: + """ + Inspect SQLite database schema. + + Args: + db_path: Path to the database file + table_name: Optional specific table to inspect + + Returns: + Dictionary containing database/table information + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + if not path.is_file(): + error_msg = f"Not a file: {path}" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("inspect_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + if table_name: + # Get info for specific table + cursor.execute( + "SELECT sql FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + result = cursor.fetchone() + + if not result: + return {"error": f"Table '{table_name}' not found"} + + create_sql = result[0] + + # Get column info + cursor.execute(f"PRAGMA table_info({table_name})") + columns = [] + for row in cursor.fetchall(): + columns.append({ + "id": row[0], + "name": row[1], + "type": row[2], + "not_null": bool(row[3]), + "default_value": row[4], + "primary_key": bool(row[5]) + }) + + # Get indexes + cursor.execute(f"PRAGMA index_list({table_name})") + indexes = [] + for row in cursor.fetchall(): + indexes.append({ + "name": row[1], + "unique": bool(row[2]) + }) + + # Get row count + cursor.execute(f"SELECT COUNT(*) FROM {table_name}") + row_count = cursor.fetchone()[0] + + logger.info(f"Inspected table: {table_name} in {path}") + self._log_stat("inspect_database", str(path.parent), True) + + return { + "database": str(path), + "table": table_name, + "create_sql": create_sql, + "columns": columns, + "indexes": indexes, + "row_count": row_count + } + else: + # Get all tables + cursor.execute( + "SELECT name, sql FROM sqlite_master WHERE type='table' ORDER BY name" + ) + tables = [] + for row in cursor.fetchall(): + table = row[0] + cursor.execute(f"SELECT COUNT(*) FROM {table}") + count = cursor.fetchone()[0] + tables.append({ + "name": table, + "row_count": count + }) + + # Get indexes + cursor.execute( + "SELECT name, tbl_name FROM sqlite_master WHERE type='index' ORDER BY name" + ) + indexes = [{"name": row[0], "table": row[1]} for row in cursor.fetchall()] + + # Get triggers + cursor.execute( + "SELECT name, tbl_name FROM sqlite_master WHERE type='trigger' ORDER BY name" + ) + triggers = [{"name": row[0], "table": row[1]} for row in cursor.fetchall()] + + # Get database size + db_size = path.stat().st_size + + logger.info(f"Inspected database: {path} ({len(tables)} tables)") + self._log_stat("inspect_database", str(path.parent), True) + + return { + "database": str(path), + "size": db_size, + "size_mb": db_size / (1024 * 1024), + "tables": tables, + "table_count": len(tables), + "indexes": indexes, + "triggers": triggers + } + finally: + conn.close() + + except Exception as e: + error_msg = f"Error inspecting database: {e}" + logger.error(error_msg) + self._log_stat("inspect_database", db_path, False, str(e)) + return {"error": error_msg} + + async def search_database( + self, + db_path: str, + search_term: str, + table_name: Optional[str] = None, + column_name: Optional[str] = None + ) -> Dict[str, Any]: + """ + Search for a value across database tables. + + Args: + db_path: Path to the database file + search_term: Value to search for (partial match) + table_name: Optional table to limit search + column_name: Optional column to limit search + + Returns: + Dictionary containing search results + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("search_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("search_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + matches = [] + + # Get tables to search + if table_name: + cursor.execute( + "SELECT name FROM sqlite_master WHERE type='table' AND name=?", + (table_name,) + ) + tables = [row[0] for row in cursor.fetchall()] + if not tables: + return {"error": f"Table '{table_name}' not found"} + else: + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + + for table in tables: + # Get columns for this table + cursor.execute(f"PRAGMA table_info({table})") + columns = [row[1] for row in cursor.fetchall()] + + # Filter columns if specified + if column_name: + if column_name not in columns: + continue + columns = [column_name] + + # Search in each column + for column in columns: + try: + query = f"SELECT * FROM {table} WHERE {column} LIKE ? LIMIT {self.default_query_limit}" + cursor.execute(query, (f"%{search_term}%",)) + results = cursor.fetchall() + + if results: + col_names = [desc[0] for desc in cursor.description] + + for row in results: + row_dict = {} + for col_name, value in zip(col_names, row): + if isinstance(value, bytes): + try: + row_dict[col_name] = value.decode("utf-8") + except UnicodeDecodeError: + row_dict[col_name] = f"" + else: + row_dict[col_name] = value + + matches.append({ + "table": table, + "column": column, + "row": row_dict + }) + except sqlite3.Error: + # Skip columns that can't be searched + continue + + logger.info( + f"Searched database: {path} for '{search_term}' - " + f"found {len(matches)} matches" + ) + self._log_stat("search_database", str(path.parent), True) + + result = { + "database": str(path), + "search_term": search_term, + "matches": matches, + "count": len(matches) + } + + if table_name: + result["table_filter"] = table_name + if column_name: + result["column_filter"] = column_name + + if len(matches) >= self.default_query_limit: + result["note"] = ( + f"Results limited to {self.default_query_limit} matches. " + "Use more specific search or query_database for full results." + ) + + return result + + finally: + conn.close() + + except Exception as e: + error_msg = f"Error searching database: {e}" + logger.error(error_msg) + self._log_stat("search_database", db_path, False, str(e)) + return {"error": error_msg} + + async def query_database( + self, + db_path: str, + query: str, + limit: Optional[int] = None + ) -> Dict[str, Any]: + """ + Execute a read-only SQL query. + + Args: + db_path: Path to the database file + query: SQL query to execute (SELECT only) + limit: Optional row limit + + Returns: + Dictionary containing query results + """ + logger = get_logger() + + try: + path = Path(db_path).resolve() + + if not path.exists(): + error_msg = f"Database not found: {path}" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Validate query + is_safe, error_msg = SQLiteQueryValidator.is_safe_query(query) + if not is_safe: + logger.warning(f"Unsafe query rejected: {error_msg}") + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + # Open database in read-only mode + try: + conn = sqlite3.connect(f"file:{path}?mode=ro", uri=True) + cursor = conn.cursor() + except sqlite3.DatabaseError as e: + error_msg = f"Not a valid SQLite database: {path} ({e})" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + + try: + # Set query timeout (Unix only) + def timeout_handler(signum, frame): + raise TimeoutError( + f"Query exceeded {self.max_query_timeout} second timeout" + ) + + if hasattr(signal, "SIGALRM"): + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(self.max_query_timeout) + + try: + # Execute query + cursor.execute(query) + + # Get results + result_limit = min( + limit or self.default_query_limit, + self.max_query_results + ) + results = cursor.fetchmany(result_limit) + + # Get column names + columns = [ + desc[0] for desc in cursor.description + ] if cursor.description else [] + + # Convert to list of dicts + rows = [] + for row in results: + row_dict = {} + for col_name, value in zip(columns, row): + if isinstance(value, bytes): + try: + row_dict[col_name] = value.decode("utf-8") + except UnicodeDecodeError: + row_dict[col_name] = f"" + else: + row_dict[col_name] = value + rows.append(row_dict) + + # Check if more results available + has_more = len(results) == result_limit + if has_more: + one_more = cursor.fetchone() + has_more = one_more is not None + + finally: + if hasattr(signal, "SIGALRM"): + signal.alarm(0) + + logger.info(f"Executed query on {path}: returned {len(rows)} rows") + self._log_stat("query_database", str(path.parent), True) + + result = { + "database": str(path), + "query": query, + "columns": columns, + "rows": rows, + "count": len(rows) + } + + if has_more: + result["truncated"] = True + result["note"] = ( + f"Results limited to {result_limit} rows. " + "Use LIMIT clause in query for more control." + ) + + return result + + except TimeoutError as e: + error_msg = str(e) + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + except sqlite3.Error as e: + error_msg = f"SQL error: {e}" + logger.warning(error_msg) + self._log_stat("query_database", str(path.parent), False, error_msg) + return {"error": error_msg} + finally: + if hasattr(signal, "SIGALRM"): + signal.alarm(0) + conn.close() + + except Exception as e: + error_msg = f"Error executing query: {e}" + logger.error(error_msg) + self._log_stat("query_database", db_path, False, str(e)) + return {"error": error_msg} diff --git a/oai/mcp/validators.py b/oai/mcp/validators.py new file mode 100644 index 0000000..387a5fd --- /dev/null +++ b/oai/mcp/validators.py @@ -0,0 +1,123 @@ +""" +Query validation for oAI MCP database operations. + +This module provides safety validation for SQL queries to ensure +only read-only operations are executed. +""" + +import re +from typing import Tuple + +from oai.constants import DANGEROUS_SQL_KEYWORDS + + +class SQLiteQueryValidator: + """ + Validate SQLite queries for read-only safety. + + Ensures that only SELECT queries (including CTEs) are allowed + and blocks potentially dangerous operations like INSERT, UPDATE, + DELETE, DROP, etc. + """ + + @staticmethod + def is_safe_query(query: str) -> Tuple[bool, str]: + """ + Validate that a query is a safe read-only SELECT. + + The validation: + 1. Checks that query starts with SELECT or WITH + 2. Strips string literals before checking for dangerous keywords + 3. Blocks any dangerous keywords outside of string literals + + Args: + query: SQL query string to validate + + Returns: + Tuple of (is_safe, error_message) + - is_safe: True if the query is safe to execute + - error_message: Description of why query is unsafe (empty if safe) + + Examples: + >>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users") + (True, "") + >>> SQLiteQueryValidator.is_safe_query("DELETE FROM users") + (False, "Only SELECT queries are allowed...") + >>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users") + (True, "") # 'DELETE' is inside a string literal + """ + query_upper = query.strip().upper() + + # Must start with SELECT or WITH (for CTEs) + if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")): + return False, "Only SELECT queries are allowed (including WITH/CTE)" + + # Remove string literals before checking for dangerous keywords + # This prevents false positives when keywords appear in data + query_no_strings = re.sub(r"'[^']*'", "", query_upper) + query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings) + + # Check for dangerous keywords outside of quotes + for keyword in DANGEROUS_SQL_KEYWORDS: + if re.search(r"\b" + keyword + r"\b", query_no_strings): + return False, f"Keyword '{keyword}' not allowed in read-only mode" + + return True, "" + + @staticmethod + def sanitize_table_name(table_name: str) -> str: + """ + Sanitize a table name to prevent SQL injection. + + Only allows alphanumeric characters and underscores. + + Args: + table_name: Table name to sanitize + + Returns: + Sanitized table name + + Raises: + ValueError: If table name contains invalid characters + """ + # Remove any characters that aren't alphanumeric or underscore + sanitized = re.sub(r"[^\w]", "", table_name) + + if not sanitized: + raise ValueError("Table name cannot be empty after sanitization") + + if sanitized != table_name: + raise ValueError( + f"Table name contains invalid characters: {table_name}" + ) + + return sanitized + + @staticmethod + def sanitize_column_name(column_name: str) -> str: + """ + Sanitize a column name to prevent SQL injection. + + Only allows alphanumeric characters and underscores. + + Args: + column_name: Column name to sanitize + + Returns: + Sanitized column name + + Raises: + ValueError: If column name contains invalid characters + """ + # Remove any characters that aren't alphanumeric or underscore + sanitized = re.sub(r"[^\w]", "", column_name) + + if not sanitized: + raise ValueError("Column name cannot be empty after sanitization") + + if sanitized != column_name: + raise ValueError( + f"Column name contains invalid characters: {column_name}" + ) + + return sanitized diff --git a/oai/providers/__init__.py b/oai/providers/__init__.py new file mode 100644 index 0000000..93df1e5 --- /dev/null +++ b/oai/providers/__init__.py @@ -0,0 +1,32 @@ +""" +Provider abstraction for oAI. + +This module provides a unified interface for AI model providers, +enabling easy extension to support additional providers beyond OpenRouter. +""" + +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ToolCall, + ToolFunction, + UsageStats, + ModelInfo, + ProviderCapabilities, +) +from oai.providers.openrouter import OpenRouterProvider + +__all__ = [ + # Base classes and types + "AIProvider", + "ChatMessage", + "ChatResponse", + "ToolCall", + "ToolFunction", + "UsageStats", + "ModelInfo", + "ProviderCapabilities", + # Provider implementations + "OpenRouterProvider", +] diff --git a/oai/providers/base.py b/oai/providers/base.py new file mode 100644 index 0000000..865fe7d --- /dev/null +++ b/oai/providers/base.py @@ -0,0 +1,413 @@ +""" +Abstract base provider for AI model integration. + +This module defines the interface that all AI providers must implement, +along with common data structures for requests and responses. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import ( + Any, + AsyncIterator, + Dict, + Iterator, + List, + Optional, + Union, +) + + +class MessageRole(str, Enum): + """Message roles in a conversation.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + TOOL = "tool" + + +@dataclass +class ToolFunction: + """ + Represents a function within a tool call. + + Attributes: + name: The function name + arguments: JSON string of function arguments + """ + + name: str + arguments: str + + +@dataclass +class ToolCall: + """ + Represents a tool/function call requested by the model. + + Attributes: + id: Unique identifier for this tool call + type: Type of tool call (usually "function") + function: The function being called + """ + + id: str + type: str + function: ToolFunction + + +@dataclass +class UsageStats: + """ + Token usage statistics from an API response. + + Attributes: + prompt_tokens: Number of tokens in the prompt + completion_tokens: Number of tokens in the completion + total_tokens: Total tokens used + total_cost_usd: Cost in USD (if available from API) + """ + + prompt_tokens: int = 0 + completion_tokens: int = 0 + total_tokens: int = 0 + total_cost_usd: Optional[float] = None + + @property + def input_tokens(self) -> int: + """Alias for prompt_tokens.""" + return self.prompt_tokens + + @property + def output_tokens(self) -> int: + """Alias for completion_tokens.""" + return self.completion_tokens + + +@dataclass +class ChatMessage: + """ + A single message in a chat conversation. + + Attributes: + role: The role of the message sender + content: Message content (text or structured content blocks) + name: Optional name for the sender + tool_calls: List of tool calls (for assistant messages) + tool_call_id: Tool call ID this message responds to (for tool messages) + """ + + role: str + content: Union[str, List[Dict[str, Any]], None] = None + name: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = None + tool_call_id: Optional[str] = None + + def to_dict(self) -> Dict[str, Any]: + """Convert to dictionary format for API requests.""" + result: Dict[str, Any] = {"role": self.role} + + if self.content is not None: + result["content"] = self.content + + if self.name: + result["name"] = self.name + + if self.tool_calls: + result["tool_calls"] = [ + { + "id": tc.id, + "type": tc.type, + "function": { + "name": tc.function.name, + "arguments": tc.function.arguments, + }, + } + for tc in self.tool_calls + ] + + if self.tool_call_id: + result["tool_call_id"] = self.tool_call_id + + return result + + +@dataclass +class ChatResponseChoice: + """ + A single choice in a chat response. + + Attributes: + index: Index of this choice + message: The response message + finish_reason: Why the response ended + """ + + index: int + message: ChatMessage + finish_reason: Optional[str] = None + + +@dataclass +class ChatResponse: + """ + Response from a chat completion request. + + Attributes: + id: Unique identifier for this response + choices: List of response choices + usage: Token usage statistics + model: Model that generated this response + created: Unix timestamp of creation + """ + + id: str + choices: List[ChatResponseChoice] + usage: Optional[UsageStats] = None + model: Optional[str] = None + created: Optional[int] = None + + @property + def message(self) -> Optional[ChatMessage]: + """Get the first choice's message.""" + if self.choices: + return self.choices[0].message + return None + + @property + def content(self) -> Optional[str]: + """Get the text content of the first choice.""" + msg = self.message + if msg and isinstance(msg.content, str): + return msg.content + return None + + @property + def tool_calls(self) -> Optional[List[ToolCall]]: + """Get tool calls from the first choice.""" + msg = self.message + if msg: + return msg.tool_calls + return None + + +@dataclass +class StreamChunk: + """ + A single chunk from a streaming response. + + Attributes: + id: Response ID + delta_content: New content in this chunk + finish_reason: Finish reason (if this is the last chunk) + usage: Usage stats (usually in the last chunk) + error: Error message if something went wrong + """ + + id: str + delta_content: Optional[str] = None + finish_reason: Optional[str] = None + usage: Optional[UsageStats] = None + error: Optional[str] = None + + +@dataclass +class ModelInfo: + """ + Information about an AI model. + + Attributes: + id: Unique model identifier + name: Display name + description: Model description + context_length: Maximum context window size + pricing: Pricing info (input/output per million tokens) + supported_parameters: List of supported API parameters + input_modalities: Supported input types (text, image, etc.) + output_modalities: Supported output types + """ + + id: str + name: str + description: str = "" + context_length: int = 0 + pricing: Dict[str, float] = field(default_factory=dict) + supported_parameters: List[str] = field(default_factory=list) + input_modalities: List[str] = field(default_factory=lambda: ["text"]) + output_modalities: List[str] = field(default_factory=lambda: ["text"]) + + def supports_images(self) -> bool: + """Check if model supports image input.""" + return "image" in self.input_modalities + + def supports_tools(self) -> bool: + """Check if model supports function calling/tools.""" + return "tools" in self.supported_parameters or "functions" in self.supported_parameters + + def supports_streaming(self) -> bool: + """Check if model supports streaming responses.""" + return "stream" in self.supported_parameters + + def supports_online(self) -> bool: + """Check if model supports web search (online mode).""" + return self.supports_tools() + + +@dataclass +class ProviderCapabilities: + """ + Capabilities supported by a provider. + + Attributes: + streaming: Provider supports streaming responses + tools: Provider supports function calling + images: Provider supports image inputs + online: Provider supports web search + max_context: Maximum context length across all models + """ + + streaming: bool = True + tools: bool = True + images: bool = True + online: bool = False + max_context: int = 128000 + + +class AIProvider(ABC): + """ + Abstract base class for AI model providers. + + All provider implementations must inherit from this class + and implement the required abstract methods. + """ + + def __init__(self, api_key: str, base_url: Optional[str] = None): + """ + Initialize the provider. + + Args: + api_key: API key for authentication + base_url: Optional custom base URL for the API + """ + self.api_key = api_key + self.base_url = base_url + + @property + @abstractmethod + def name(self) -> str: + """Get the provider name.""" + pass + + @property + @abstractmethod + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + pass + + @abstractmethod + def list_models(self) -> List[ModelInfo]: + """ + Fetch available models from the provider. + + Returns: + List of available models with their info + """ + pass + + @abstractmethod + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: The model identifier + + Returns: + Model information or None if not found + """ + pass + + @abstractmethod + def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat completion request. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection ("auto", "none", etc.) + **kwargs: Additional provider-specific parameters + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + """ + pass + + @abstractmethod + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Send an async chat completion request. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection + **kwargs: Additional provider-specific parameters + + Returns: + ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming + """ + pass + + @abstractmethod + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get account credit/balance information. + + Returns: + Dict with credit info or None if not supported + """ + pass + + def validate_api_key(self) -> bool: + """ + Validate that the API key is valid. + + Returns: + True if API key is valid + """ + try: + self.list_models() + return True + except Exception: + return False diff --git a/oai/providers/openrouter.py b/oai/providers/openrouter.py new file mode 100644 index 0000000..075fb05 --- /dev/null +++ b/oai/providers/openrouter.py @@ -0,0 +1,623 @@ +""" +OpenRouter provider implementation. + +This module implements the AIProvider interface for OpenRouter, +supporting chat completions, streaming, and function calling. +""" + +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union + +import requests +from openrouter import OpenRouter + +from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL +from oai.providers.base import ( + AIProvider, + ChatMessage, + ChatResponse, + ChatResponseChoice, + ModelInfo, + ProviderCapabilities, + StreamChunk, + ToolCall, + ToolFunction, + UsageStats, +) +from oai.utils.logging import get_logger + + +class OpenRouterProvider(AIProvider): + """ + OpenRouter API provider implementation. + + Provides access to multiple AI models through OpenRouter's unified API, + supporting chat completions, streaming responses, and function calling. + + Attributes: + client: The underlying OpenRouter client + _models_cache: Cached list of available models + """ + + def __init__( + self, + api_key: str, + base_url: Optional[str] = None, + app_name: str = APP_NAME, + app_url: str = APP_URL, + ): + """ + Initialize the OpenRouter provider. + + Args: + api_key: OpenRouter API key + base_url: Optional custom base URL + app_name: Application name for API headers + app_url: Application URL for API headers + """ + super().__init__(api_key, base_url or DEFAULT_BASE_URL) + self.app_name = app_name + self.app_url = app_url + self.client = OpenRouter(api_key=api_key) + self._models_cache: Optional[List[ModelInfo]] = None + self._raw_models_cache: Optional[List[Dict[str, Any]]] = None + + self.logger = get_logger() + self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}") + + @property + def name(self) -> str: + """Get the provider name.""" + return "OpenRouter" + + @property + def capabilities(self) -> ProviderCapabilities: + """Get provider capabilities.""" + return ProviderCapabilities( + streaming=True, + tools=True, + images=True, + online=True, + max_context=2000000, # Claude models support up to 200k + ) + + def _get_headers(self) -> Dict[str, str]: + """Get standard HTTP headers for API requests.""" + headers = { + "HTTP-Referer": self.app_url, + "X-Title": self.app_name, + } + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + return headers + + def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo: + """ + Parse raw model data into ModelInfo. + + Args: + model_data: Raw model data from API + + Returns: + Parsed ModelInfo object + """ + architecture = model_data.get("architecture", {}) + pricing_data = model_data.get("pricing", {}) + + # Parse pricing (convert from string to float if needed) + pricing = {} + for key in ["prompt", "completion"]: + value = pricing_data.get(key) + if value is not None: + try: + # Convert from per-token to per-million-tokens + pricing[key] = float(value) * 1_000_000 + except (ValueError, TypeError): + pricing[key] = 0.0 + + return ModelInfo( + id=model_data.get("id", ""), + name=model_data.get("name", model_data.get("id", "")), + description=model_data.get("description", ""), + context_length=model_data.get("context_length", 0), + pricing=pricing, + supported_parameters=model_data.get("supported_parameters", []), + input_modalities=architecture.get("input_modalities", ["text"]), + output_modalities=architecture.get("output_modalities", ["text"]), + ) + + def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]: + """ + Fetch available models from OpenRouter. + + Args: + filter_text_only: If True, exclude video-only models + + Returns: + List of available models + + Raises: + Exception: If API request fails + """ + if self._models_cache is not None: + return self._models_cache + + try: + response = requests.get( + f"{self.base_url}/models", + headers=self._get_headers(), + timeout=10, + ) + response.raise_for_status() + + raw_models = response.json().get("data", []) + self._raw_models_cache = raw_models + + models = [] + for model_data in raw_models: + # Optionally filter out video-only models + if filter_text_only: + modalities = model_data.get("modalities", []) + if modalities and "video" in modalities and "text" not in modalities: + continue + + models.append(self._parse_model(model_data)) + + self._models_cache = models + self.logger.info(f"Fetched {len(models)} models from OpenRouter") + return models + + except requests.RequestException as e: + self.logger.error(f"Failed to fetch models: {e}") + raise + + def get_raw_models(self) -> List[Dict[str, Any]]: + """ + Get raw model data as returned by the API. + + Useful for accessing provider-specific fields not in ModelInfo. + + Returns: + List of raw model dictionaries + """ + if self._raw_models_cache is None: + self.list_models() + return self._raw_models_cache or [] + + def get_model(self, model_id: str) -> Optional[ModelInfo]: + """ + Get information about a specific model. + + Args: + model_id: The model identifier + + Returns: + Model information or None if not found + """ + models = self.list_models() + for model in models: + if model.id == model_id: + return model + return None + + def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]: + """ + Get raw model data for a specific model. + + Args: + model_id: The model identifier + + Returns: + Raw model dictionary or None if not found + """ + raw_models = self.get_raw_models() + for model in raw_models: + if model.get("id") == model_id: + return model + return None + + def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]: + """ + Convert ChatMessage objects to API format. + + Args: + messages: List of ChatMessage objects + + Returns: + List of message dictionaries for the API + """ + return [msg.to_dict() for msg in messages] + + def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]: + """ + Parse usage data from API response. + + Args: + usage_data: Raw usage data from API + + Returns: + Parsed UsageStats or None + """ + if not usage_data: + return None + + # Handle both attribute and dict access + prompt_tokens = 0 + completion_tokens = 0 + total_cost = None + + if hasattr(usage_data, "prompt_tokens"): + prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0 + elif isinstance(usage_data, dict): + prompt_tokens = usage_data.get("prompt_tokens", 0) or 0 + + if hasattr(usage_data, "completion_tokens"): + completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0 + elif isinstance(usage_data, dict): + completion_tokens = usage_data.get("completion_tokens", 0) or 0 + + # Try alternative naming (input_tokens/output_tokens) + if prompt_tokens == 0: + if hasattr(usage_data, "input_tokens"): + prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0 + elif isinstance(usage_data, dict): + prompt_tokens = usage_data.get("input_tokens", 0) or 0 + + if completion_tokens == 0: + if hasattr(usage_data, "output_tokens"): + completion_tokens = getattr(usage_data, "output_tokens", 0) or 0 + elif isinstance(usage_data, dict): + completion_tokens = usage_data.get("output_tokens", 0) or 0 + + # Get cost if available + if hasattr(usage_data, "total_cost_usd"): + total_cost = getattr(usage_data, "total_cost_usd", None) + elif isinstance(usage_data, dict): + total_cost = usage_data.get("total_cost_usd") + + return UsageStats( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + total_cost_usd=float(total_cost) if total_cost else None, + ) + + def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]: + """ + Parse tool calls from API response. + + Args: + tool_calls_data: Raw tool calls data + + Returns: + List of ToolCall objects or None + """ + if not tool_calls_data: + return None + + tool_calls = [] + for tc in tool_calls_data: + # Handle both attribute and dict access + if hasattr(tc, "id"): + tc_id = tc.id + tc_type = getattr(tc, "type", "function") + func = tc.function + func_name = func.name + func_args = func.arguments + else: + tc_id = tc.get("id", "") + tc_type = tc.get("type", "function") + func = tc.get("function", {}) + func_name = func.get("name", "") + func_args = func.get("arguments", "{}") + + tool_calls.append( + ToolCall( + id=tc_id, + type=tc_type, + function=ToolFunction(name=func_name, arguments=func_args), + ) + ) + + return tool_calls if tool_calls else None + + def _parse_response(self, response: Any) -> ChatResponse: + """ + Parse API response into ChatResponse. + + Args: + response: Raw API response + + Returns: + Parsed ChatResponse + """ + choices = [] + for choice in response.choices: + msg = choice.message + message = ChatMessage( + role=msg.role if hasattr(msg, "role") else "assistant", + content=msg.content if hasattr(msg, "content") else None, + tool_calls=self._parse_tool_calls( + getattr(msg, "tool_calls", None) + ), + ) + choices.append( + ChatResponseChoice( + index=choice.index if hasattr(choice, "index") else 0, + message=message, + finish_reason=getattr(choice, "finish_reason", None), + ) + ) + + return ChatResponse( + id=response.id if hasattr(response, "id") else "", + choices=choices, + usage=self._parse_usage(getattr(response, "usage", None)), + model=getattr(response, "model", None), + created=getattr(response, "created", None), + ) + + def chat( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + transforms: Optional[List[str]] = None, + **kwargs: Any, + ) -> Union[ChatResponse, Iterator[StreamChunk]]: + """ + Send a chat completion request to OpenRouter. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature (0-2) + tools: List of tool definitions for function calling + tool_choice: How to handle tool selection ("auto", "none", etc.) + transforms: List of transforms (e.g., ["middle-out"]) + **kwargs: Additional parameters + + Returns: + ChatResponse for non-streaming, Iterator[StreamChunk] for streaming + """ + # Build request parameters + params: Dict[str, Any] = { + "model": model, + "messages": self._convert_messages(messages), + "stream": stream, + "http_headers": self._get_headers(), + } + + # Request usage stats in streaming responses + if stream: + params["stream_options"] = {"include_usage": True} + + if max_tokens is not None: + params["max_tokens"] = max_tokens + + if temperature is not None: + params["temperature"] = temperature + + if tools: + params["tools"] = tools + params["tool_choice"] = tool_choice or "auto" + + if transforms: + params["transforms"] = transforms + + # Add any additional parameters + params.update(kwargs) + + self.logger.debug(f"Sending chat request to model {model}") + + try: + response = self.client.chat.send(**params) + + if stream: + return self._stream_response(response) + else: + return self._parse_response(response) + + except Exception as e: + self.logger.error(f"Chat request failed: {e}") + raise + + def _stream_response(self, response: Any) -> Iterator[StreamChunk]: + """ + Process a streaming response. + + Args: + response: Streaming response from API + + Yields: + StreamChunk objects + """ + last_usage = None + + try: + for chunk in response: + # Check for errors + if hasattr(chunk, "error") and chunk.error: + yield StreamChunk( + id=getattr(chunk, "id", ""), + error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error), + ) + return + + # Extract delta content + delta_content = None + finish_reason = None + + if hasattr(chunk, "choices") and chunk.choices: + choice = chunk.choices[0] + if hasattr(choice, "delta"): + delta = choice.delta + if hasattr(delta, "content") and delta.content: + delta_content = delta.content + finish_reason = getattr(choice, "finish_reason", None) + + # Track usage from last chunk + if hasattr(chunk, "usage") and chunk.usage: + last_usage = self._parse_usage(chunk.usage) + + yield StreamChunk( + id=getattr(chunk, "id", ""), + delta_content=delta_content, + finish_reason=finish_reason, + usage=last_usage if finish_reason else None, + ) + + except Exception as e: + self.logger.error(f"Stream error: {e}") + yield StreamChunk(id="", error=str(e)) + + async def chat_async( + self, + model: str, + messages: List[ChatMessage], + stream: bool = False, + max_tokens: Optional[int] = None, + temperature: Optional[float] = None, + tools: Optional[List[Dict[str, Any]]] = None, + tool_choice: Optional[str] = None, + **kwargs: Any, + ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]: + """ + Send an async chat completion request. + + Note: Currently wraps the sync implementation. + TODO: Implement true async support when OpenRouter SDK supports it. + + Args: + model: Model ID to use + messages: List of chat messages + stream: Whether to stream the response + max_tokens: Maximum tokens in response + temperature: Sampling temperature + tools: List of tool definitions + tool_choice: Tool selection mode + **kwargs: Additional parameters + + Returns: + ChatResponse for non-streaming, AsyncIterator for streaming + """ + # For now, use sync implementation + # TODO: Add true async when SDK supports it + result = self.chat( + model=model, + messages=messages, + stream=stream, + max_tokens=max_tokens, + temperature=temperature, + tools=tools, + tool_choice=tool_choice, + **kwargs, + ) + + if stream and isinstance(result, Iterator): + # Convert sync iterator to async + async def async_iter() -> AsyncIterator[StreamChunk]: + for chunk in result: + yield chunk + + return async_iter() + + return result + + def get_credits(self) -> Optional[Dict[str, Any]]: + """ + Get OpenRouter account credit information. + + Returns: + Dict with credit info: + - total_credits: Total credits purchased + - used_credits: Credits used + - credits_left: Remaining credits + + Raises: + Exception: If API request fails + """ + if not self.api_key: + return None + + try: + response = requests.get( + f"{self.base_url}/credits", + headers=self._get_headers(), + timeout=10, + ) + response.raise_for_status() + + data = response.json().get("data", {}) + total_credits = float(data.get("total_credits", 0)) + total_usage = float(data.get("total_usage", 0)) + credits_left = total_credits - total_usage + + return { + "total_credits": total_credits, + "used_credits": total_usage, + "credits_left": credits_left, + "total_credits_formatted": f"${total_credits:.2f}", + "used_credits_formatted": f"${total_usage:.2f}", + "credits_left_formatted": f"${credits_left:.2f}", + } + + except Exception as e: + self.logger.error(f"Failed to fetch credits: {e}") + return None + + def clear_cache(self) -> None: + """Clear the models cache to force a refresh.""" + self._models_cache = None + self._raw_models_cache = None + self.logger.debug("Models cache cleared") + + def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str: + """ + Get the effective model ID with online suffix if needed. + + Args: + model_id: Base model ID + online_enabled: Whether online mode is enabled + + Returns: + Model ID with :online suffix if applicable + """ + if online_enabled and not model_id.endswith(":online"): + return f"{model_id}:online" + return model_id + + def estimate_cost( + self, + model_id: str, + input_tokens: int, + output_tokens: int, + ) -> float: + """ + Estimate the cost for a completion. + + Args: + model_id: Model ID + input_tokens: Number of input tokens + output_tokens: Number of output tokens + + Returns: + Estimated cost in USD + """ + model = self.get_model(model_id) + if model and model.pricing: + input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000 + output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000 + return input_cost + output_cost + + # Fallback to default pricing if model not found + from oai.constants import MODEL_PRICING + + input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000 + output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000 + return input_cost + output_cost diff --git a/oai/py.typed b/oai/py.typed new file mode 100644 index 0000000..cb5707d --- /dev/null +++ b/oai/py.typed @@ -0,0 +1,2 @@ +# Marker file for PEP 561 +# This package supports type checking diff --git a/oai/ui/__init__.py b/oai/ui/__init__.py new file mode 100644 index 0000000..9f4be60 --- /dev/null +++ b/oai/ui/__init__.py @@ -0,0 +1,51 @@ +""" +UI utilities for oAI. + +This module provides rich terminal UI components and display helpers +for the chat application. +""" + +from oai.ui.console import ( + console, + clear_screen, + display_panel, + display_table, + display_markdown, + print_error, + print_warning, + print_success, + print_info, +) +from oai.ui.tables import ( + create_model_table, + create_stats_table, + create_help_table, + display_paginated_table, +) +from oai.ui.prompts import ( + prompt_confirm, + prompt_choice, + prompt_input, +) + +__all__ = [ + # Console utilities + "console", + "clear_screen", + "display_panel", + "display_table", + "display_markdown", + "print_error", + "print_warning", + "print_success", + "print_info", + # Table utilities + "create_model_table", + "create_stats_table", + "create_help_table", + "display_paginated_table", + # Prompt utilities + "prompt_confirm", + "prompt_choice", + "prompt_input", +] diff --git a/oai/ui/console.py b/oai/ui/console.py new file mode 100644 index 0000000..6836b7d --- /dev/null +++ b/oai/ui/console.py @@ -0,0 +1,242 @@ +""" +Console utilities for oAI. + +This module provides the Rich console instance and common display functions +for formatted terminal output. +""" + +from typing import Any, Optional + +from rich.console import Console +from rich.markdown import Markdown +from rich.panel import Panel +from rich.table import Table +from rich.text import Text + +# Global console instance for the application +console = Console() + + +def clear_screen() -> None: + """ + Clear the terminal screen. + + Uses ANSI escape codes for fast clearing, with a fallback + for terminals that don't support them. + """ + try: + print("\033[H\033[J", end="", flush=True) + except Exception: + # Fallback: print many newlines + print("\n" * 100) + + +def display_panel( + content: Any, + title: Optional[str] = None, + subtitle: Optional[str] = None, + border_style: str = "green", + title_align: str = "left", + subtitle_align: str = "right", +) -> None: + """ + Display content in a bordered panel. + + Args: + content: Content to display (string, Table, or Markdown) + title: Optional panel title + subtitle: Optional panel subtitle + border_style: Border color/style + title_align: Title alignment ("left", "center", "right") + subtitle_align: Subtitle alignment + """ + panel = Panel( + content, + title=title, + subtitle=subtitle, + border_style=border_style, + title_align=title_align, + subtitle_align=subtitle_align, + ) + console.print(panel) + + +def display_table( + table: Table, + title: Optional[str] = None, + subtitle: Optional[str] = None, +) -> None: + """ + Display a table with optional title panel. + + Args: + table: Rich Table to display + title: Optional panel title + subtitle: Optional panel subtitle + """ + if title: + display_panel(table, title=title, subtitle=subtitle) + else: + console.print(table) + + +def display_markdown( + content: str, + panel: bool = False, + title: Optional[str] = None, +) -> None: + """ + Display markdown-formatted content. + + Args: + content: Markdown text to display + panel: Whether to wrap in a panel + title: Optional panel title (if panel=True) + """ + md = Markdown(content) + if panel: + display_panel(md, title=title) + else: + console.print(md) + + +def print_error(message: str, prefix: str = "Error:") -> None: + """ + Print an error message in red. + + Args: + message: Error message to display + prefix: Prefix before the message (default: "Error:") + """ + console.print(f"[bold red]{prefix}[/] {message}") + + +def print_warning(message: str, prefix: str = "Warning:") -> None: + """ + Print a warning message in yellow. + + Args: + message: Warning message to display + prefix: Prefix before the message (default: "Warning:") + """ + console.print(f"[bold yellow]{prefix}[/] {message}") + + +def print_success(message: str, prefix: str = "✓") -> None: + """ + Print a success message in green. + + Args: + message: Success message to display + prefix: Prefix before the message (default: "✓") + """ + console.print(f"[bold green]{prefix}[/] {message}") + + +def print_info(message: str, dim: bool = False) -> None: + """ + Print an informational message in cyan. + + Args: + message: Info message to display + dim: Whether to dim the message + """ + if dim: + console.print(f"[dim cyan]{message}[/]") + else: + console.print(f"[bold cyan]{message}[/]") + + +def print_metrics( + tokens: int, + cost: float, + time_seconds: float, + context_info: str = "", + online: bool = False, + mcp_mode: Optional[str] = None, + tool_loops: int = 0, + session_tokens: int = 0, + session_cost: float = 0.0, +) -> None: + """ + Print formatted metrics for a response. + + Args: + tokens: Total tokens used + cost: Cost in USD + time_seconds: Response time + context_info: Context information string + online: Whether online mode is active + mcp_mode: MCP mode ("files", "database", or None) + tool_loops: Number of tool call loops + session_tokens: Total session tokens + session_cost: Total session cost + """ + parts = [ + f"📊 Metrics: {tokens} tokens", + f"${cost:.4f}", + f"{time_seconds:.2f}s", + ] + + if context_info: + parts.append(context_info) + + if online: + parts.append("🌐") + + if mcp_mode == "files": + parts.append("🔧") + elif mcp_mode == "database": + parts.append("🗄️") + + if tool_loops > 0: + parts.append(f"({tool_loops} tool loop(s))") + + parts.append(f"Session: {session_tokens} tokens") + parts.append(f"${session_cost:.4f}") + + console.print(f"\n[dim blue]{' | '.join(parts)}[/]") + + +def format_size(size_bytes: int) -> str: + """ + Format a size in bytes to a human-readable string. + + Args: + size_bytes: Size in bytes + + Returns: + Formatted size string (e.g., "1.5 MB") + """ + for unit in ["B", "KB", "MB", "GB", "TB"]: + if abs(size_bytes) < 1024.0: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024.0 + return f"{size_bytes:.1f} PB" + + +def format_tokens(tokens: int) -> str: + """ + Format token count with thousands separators. + + Args: + tokens: Number of tokens + + Returns: + Formatted token string (e.g., "1,234,567") + """ + return f"{tokens:,}" + + +def format_cost(cost: float, precision: int = 4) -> str: + """ + Format cost in USD. + + Args: + cost: Cost in dollars + precision: Decimal places + + Returns: + Formatted cost string (e.g., "$0.0123") + """ + return f"${cost:.{precision}f}" diff --git a/oai/ui/prompts.py b/oai/ui/prompts.py new file mode 100644 index 0000000..654a43e --- /dev/null +++ b/oai/ui/prompts.py @@ -0,0 +1,274 @@ +""" +Prompt utilities for oAI. + +This module provides functions for gathering user input +through confirmations, choices, and text prompts. +""" + +from typing import List, Optional, TypeVar + +import typer + +from oai.ui.console import console + +T = TypeVar("T") + + +def prompt_confirm( + message: str, + default: bool = False, + abort: bool = False, +) -> bool: + """ + Prompt the user for a yes/no confirmation. + + Args: + message: The question to ask + default: Default value if user presses Enter + abort: Whether to abort on "no" response + + Returns: + True if user confirms, False otherwise + """ + try: + return typer.confirm(message, default=default, abort=abort) + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return False + + +def prompt_choice( + message: str, + choices: List[str], + default: Optional[str] = None, +) -> Optional[str]: + """ + Prompt the user to select from a list of choices. + + Args: + message: The question to ask + choices: List of valid choices + default: Default choice if user presses Enter + + Returns: + Selected choice or None if cancelled + """ + # Display choices + console.print(f"\n[bold cyan]{message}[/]") + for i, choice in enumerate(choices, 1): + default_marker = " [default]" if choice == default else "" + console.print(f" {i}. {choice}{default_marker}") + + try: + response = input("\nEnter number or value: ").strip() + + if not response and default: + return default + + # Try as number first + try: + index = int(response) - 1 + if 0 <= index < len(choices): + return choices[index] + except ValueError: + pass + + # Try as exact match + if response in choices: + return response + + # Try case-insensitive match + response_lower = response.lower() + for choice in choices: + if choice.lower() == response_lower: + return choice + + console.print(f"[red]Invalid choice: {response}[/]") + return None + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_input( + message: str, + default: Optional[str] = None, + password: bool = False, + required: bool = False, +) -> Optional[str]: + """ + Prompt the user for text input. + + Args: + message: The prompt message + default: Default value if user presses Enter + password: Whether to hide input (for sensitive data) + required: Whether input is required (loops until provided) + + Returns: + User input or default, None if cancelled + """ + prompt_text = message + if default: + prompt_text += f" [{default}]" + prompt_text += ": " + + try: + while True: + if password: + import getpass + + response = getpass.getpass(prompt_text) + else: + response = input(prompt_text).strip() + + if not response: + if default: + return default + if required: + console.print("[yellow]Input required[/]") + continue + return None + + return response + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_number( + message: str, + min_value: Optional[int] = None, + max_value: Optional[int] = None, + default: Optional[int] = None, +) -> Optional[int]: + """ + Prompt the user for a numeric input. + + Args: + message: The prompt message + min_value: Minimum allowed value + max_value: Maximum allowed value + default: Default value if user presses Enter + + Returns: + Integer value or None if cancelled + """ + prompt_text = message + if default is not None: + prompt_text += f" [{default}]" + prompt_text += ": " + + try: + while True: + response = input(prompt_text).strip() + + if not response: + if default is not None: + return default + return None + + try: + value = int(response) + except ValueError: + console.print("[red]Please enter a valid number[/]") + continue + + if min_value is not None and value < min_value: + console.print(f"[red]Value must be at least {min_value}[/]") + continue + + if max_value is not None and value > max_value: + console.print(f"[red]Value must be at most {max_value}[/]") + continue + + return value + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_selection( + items: List[T], + message: str = "Select an item", + display_func: Optional[callable] = None, + allow_cancel: bool = True, +) -> Optional[T]: + """ + Prompt the user to select an item from a list. + + Args: + items: List of items to choose from + message: The selection prompt + display_func: Function to convert item to display string + allow_cancel: Whether to allow cancellation + + Returns: + Selected item or None if cancelled + """ + if not items: + console.print("[yellow]No items to select[/]") + return None + + display = display_func or str + + console.print(f"\n[bold cyan]{message}[/]") + for i, item in enumerate(items, 1): + console.print(f" {i}. {display(item)}") + + if allow_cancel: + console.print(f" 0. Cancel") + + try: + while True: + response = input("\nEnter number: ").strip() + + try: + index = int(response) + except ValueError: + console.print("[red]Please enter a valid number[/]") + continue + + if allow_cancel and index == 0: + return None + + if 1 <= index <= len(items): + return items[index - 1] + + console.print(f"[red]Please enter a number between 1 and {len(items)}[/]") + + except (EOFError, KeyboardInterrupt): + console.print("\n[yellow]Cancelled[/]") + return None + + +def prompt_copy_response(response: str) -> bool: + """ + Prompt user to copy a response to clipboard. + + Args: + response: The response text + + Returns: + True if copied, False otherwise + """ + try: + copy_choice = input("💾 Type 'c' to copy response, or press Enter to continue: ").strip().lower() + if copy_choice == "c": + try: + import pyperclip + + pyperclip.copy(response) + console.print("[bold green]✅ Response copied to clipboard![/]") + return True + except ImportError: + console.print("[yellow]pyperclip not installed - cannot copy to clipboard[/]") + except Exception as e: + console.print(f"[red]Failed to copy: {e}[/]") + except (EOFError, KeyboardInterrupt): + pass + + return False diff --git a/oai/ui/tables.py b/oai/ui/tables.py new file mode 100644 index 0000000..49a1c2e --- /dev/null +++ b/oai/ui/tables.py @@ -0,0 +1,373 @@ +""" +Table utilities for oAI. + +This module provides functions for creating and displaying +formatted tables with pagination support. +""" + +import os +import sys +from typing import Any, Dict, List, Optional + +from rich.panel import Panel +from rich.table import Table + +from oai.ui.console import clear_screen, console + + +def create_model_table( + models: List[Dict[str, Any]], + show_capabilities: bool = True, +) -> Table: + """ + Create a table displaying available AI models. + + Args: + models: List of model dictionaries + show_capabilities: Whether to show capability columns + + Returns: + Rich Table with model information + """ + if show_capabilities: + table = Table( + "No.", + "Model ID", + "Context", + "Image", + "Online", + "Tools", + show_header=True, + header_style="bold magenta", + ) + else: + table = Table( + "No.", + "Model ID", + "Context", + show_header=True, + header_style="bold magenta", + ) + + for i, model in enumerate(models, 1): + model_id = model.get("id", "Unknown") + context = model.get("context_length", 0) + context_str = f"{context:,}" if context else "-" + + if show_capabilities: + # Get modalities and parameters + architecture = model.get("architecture", {}) + input_modalities = architecture.get("input_modalities", []) + supported_params = model.get("supported_parameters", []) + + has_image = "✓" if "image" in input_modalities else "-" + has_online = "✓" if "tools" in supported_params else "-" + has_tools = "✓" if "tools" in supported_params or "functions" in supported_params else "-" + + table.add_row( + str(i), + model_id, + context_str, + has_image, + has_online, + has_tools, + ) + else: + table.add_row(str(i), model_id, context_str) + + return table + + +def create_stats_table(stats: Dict[str, Any]) -> Table: + """ + Create a table displaying session statistics. + + Args: + stats: Dictionary with statistics data + + Returns: + Rich Table with stats + """ + table = Table( + "Metric", + "Value", + show_header=True, + header_style="bold magenta", + ) + + # Token stats + if "input_tokens" in stats: + table.add_row("Input Tokens", f"{stats['input_tokens']:,}") + if "output_tokens" in stats: + table.add_row("Output Tokens", f"{stats['output_tokens']:,}") + if "total_tokens" in stats: + table.add_row("Total Tokens", f"{stats['total_tokens']:,}") + + # Cost stats + if "total_cost" in stats: + table.add_row("Total Cost", f"${stats['total_cost']:.4f}") + if "avg_cost" in stats: + table.add_row("Avg Cost/Message", f"${stats['avg_cost']:.4f}") + + # Message stats + if "message_count" in stats: + table.add_row("Messages", str(stats["message_count"])) + + # Credits + if "credits_left" in stats: + table.add_row("Credits Left", stats["credits_left"]) + + return table + + +def create_help_table(commands: Dict[str, Dict[str, str]]) -> Table: + """ + Create a help table for commands. + + Args: + commands: Dictionary of command info + + Returns: + Rich Table with command help + """ + table = Table( + "Command", + "Description", + "Example", + show_header=True, + header_style="bold magenta", + show_lines=False, + ) + + for cmd, info in commands.items(): + description = info.get("description", "") + example = info.get("example", "") + table.add_row(cmd, description, example) + + return table + + +def create_folder_table( + folders: List[Dict[str, Any]], + gitignore_info: str = "", +) -> Table: + """ + Create a table for MCP folder listing. + + Args: + folders: List of folder dictionaries + gitignore_info: Optional gitignore status info + + Returns: + Rich Table with folder information + """ + table = Table( + "No.", + "Path", + "Files", + "Size", + show_header=True, + header_style="bold magenta", + ) + + for folder in folders: + number = str(folder.get("number", "")) + path = folder.get("path", "") + + if folder.get("exists", True): + files = f"📁 {folder.get('file_count', 0)}" + size = f"{folder.get('size_mb', 0):.1f} MB" + else: + files = "[red]Not found[/red]" + size = "-" + + table.add_row(number, path, files, size) + + return table + + +def create_database_table(databases: List[Dict[str, Any]]) -> Table: + """ + Create a table for MCP database listing. + + Args: + databases: List of database dictionaries + + Returns: + Rich Table with database information + """ + table = Table( + "No.", + "Name", + "Tables", + "Size", + "Status", + show_header=True, + header_style="bold magenta", + ) + + for db in databases: + number = str(db.get("number", "")) + name = db.get("name", "") + table_count = f"{db.get('table_count', 0)} tables" + size = f"{db.get('size_mb', 0):.1f} MB" + + if db.get("warning"): + status = f"[red]{db['warning']}[/red]" + else: + status = "[green]✓[/green]" + + table.add_row(number, name, table_count, size, status) + + return table + + +def display_paginated_table( + table: Table, + title: str, + terminal_height: Optional[int] = None, +) -> None: + """ + Display a table with pagination for large datasets. + + Allows navigating through pages with keyboard input. + Press SPACE for next page, any other key to exit. + + Args: + table: Rich Table to display + title: Title for the table + terminal_height: Override terminal height (auto-detected if None) + """ + # Get terminal dimensions + try: + term_height = terminal_height or os.get_terminal_size().lines - 8 + except OSError: + term_height = 20 + + # Render table to segments + from rich.segment import Segment + + segments = list(console.render(table)) + + # Group segments into lines + current_line_segments: List[Segment] = [] + all_lines: List[List[Segment]] = [] + + for segment in segments: + if segment.text == "\n": + all_lines.append(current_line_segments) + current_line_segments = [] + else: + current_line_segments.append(segment) + + if current_line_segments: + all_lines.append(current_line_segments) + + total_lines = len(all_lines) + + # If table fits in one screen, just display it + if total_lines <= term_height: + console.print(Panel(table, title=title, title_align="left")) + return + + # Extract header and footer lines + header_lines: List[List[Segment]] = [] + data_lines: List[List[Segment]] = [] + footer_line: List[Segment] = [] + + # Find header end (line after the header text with border) + header_end_index = 0 + found_header_text = False + + for i, line_segments in enumerate(all_lines): + has_header_style = any( + seg.style and ("bold" in str(seg.style) or "magenta" in str(seg.style)) + for seg in line_segments + ) + + if has_header_style: + found_header_text = True + + if found_header_text and i > 0: + line_text = "".join(seg.text for seg in line_segments) + if any(char in line_text for char in ["─", "━", "┼", "╪", "┤", "├"]): + header_end_index = i + break + + # Extract footer (bottom border) + if all_lines: + last_line_text = "".join(seg.text for seg in all_lines[-1]) + if any(char in last_line_text for char in ["─", "━", "┴", "╧", "┘", "└"]): + footer_line = all_lines[-1] + all_lines = all_lines[:-1] + + # Split into header and data + if header_end_index > 0: + header_lines = all_lines[: header_end_index + 1] + data_lines = all_lines[header_end_index + 1 :] + else: + header_lines = all_lines[: min(3, len(all_lines))] + data_lines = all_lines[min(3, len(all_lines)) :] + + lines_per_page = term_height - len(header_lines) + current_line = 0 + page_number = 1 + + # Paginate + while current_line < len(data_lines): + clear_screen() + console.print(f"[bold cyan]{title} (Page {page_number})[/]") + + # Print header + for line_segments in header_lines: + for segment in line_segments: + console.print(segment.text, style=segment.style, end="") + console.print() + + # Print data rows for this page + end_line = min(current_line + lines_per_page, len(data_lines)) + for line_segments in data_lines[current_line:end_line]: + for segment in line_segments: + console.print(segment.text, style=segment.style, end="") + console.print() + + # Print footer + if footer_line: + for segment in footer_line: + console.print(segment.text, style=segment.style, end="") + console.print() + + current_line = end_line + page_number += 1 + + # Prompt for next page + if current_line < len(data_lines): + console.print( + f"\n[dim yellow]--- Press SPACE for next page, " + f"or any other key to finish (Page {page_number - 1}, " + f"showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]" + ) + + try: + import termios + import tty + + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + + try: + tty.setraw(fd) + char = sys.stdin.read(1) + if char != " ": + break + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + except (ImportError, OSError, AttributeError): + # Fallback for non-Unix systems + try: + user_input = input() + if user_input.strip(): + break + except (EOFError, KeyboardInterrupt): + break diff --git a/oai/utils/__init__.py b/oai/utils/__init__.py new file mode 100644 index 0000000..6048bfe --- /dev/null +++ b/oai/utils/__init__.py @@ -0,0 +1,20 @@ +""" +Utility modules for oAI. + +This package provides common utilities used throughout the application +including logging, file handling, and export functionality. +""" + +from oai.utils.logging import setup_logging, get_logger +from oai.utils.files import read_file_safe, is_binary_file +from oai.utils.export import export_as_markdown, export_as_json, export_as_html + +__all__ = [ + "setup_logging", + "get_logger", + "read_file_safe", + "is_binary_file", + "export_as_markdown", + "export_as_json", + "export_as_html", +] diff --git a/oai/utils/export.py b/oai/utils/export.py new file mode 100644 index 0000000..c476ef9 --- /dev/null +++ b/oai/utils/export.py @@ -0,0 +1,248 @@ +""" +Export utilities for oAI. + +This module provides functions for exporting conversation history +in various formats including Markdown, JSON, and HTML. +""" + +import json +import datetime +from typing import List, Dict +from html import escape as html_escape + +from oai.constants import APP_VERSION, APP_URL + + +def export_as_markdown( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as Markdown. + + Args: + session_history: List of message dictionaries with 'prompt' and 'response' + session_system_prompt: Optional system prompt to include + + Returns: + Markdown formatted string + """ + lines = ["# Conversation Export", ""] + + if session_system_prompt: + lines.extend([f"**System Prompt:** {session_system_prompt}", ""]) + + lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}") + lines.append(f"**Messages:** {len(session_history)}") + lines.append("") + lines.append("---") + lines.append("") + + for i, entry in enumerate(session_history, 1): + lines.append(f"## Message {i}") + lines.append("") + lines.append("**User:**") + lines.append("") + lines.append(entry.get("prompt", "")) + lines.append("") + lines.append("**Assistant:**") + lines.append("") + lines.append(entry.get("response", "")) + lines.append("") + lines.append("---") + lines.append("") + + lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*") + + return "\n".join(lines) + + +def export_as_json( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as JSON. + + Args: + session_history: List of message dictionaries + session_system_prompt: Optional system prompt to include + + Returns: + JSON formatted string + """ + export_data = { + "export_date": datetime.datetime.now().isoformat(), + "app_version": APP_VERSION, + "system_prompt": session_system_prompt, + "message_count": len(session_history), + "messages": [ + { + "index": i + 1, + "prompt": entry.get("prompt", ""), + "response": entry.get("response", ""), + "prompt_tokens": entry.get("prompt_tokens", 0), + "completion_tokens": entry.get("completion_tokens", 0), + "cost": entry.get("msg_cost", 0.0), + } + for i, entry in enumerate(session_history) + ], + "totals": { + "prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history), + "completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history), + "total_cost": sum(e.get("msg_cost", 0.0) for e in session_history), + } + } + return json.dumps(export_data, indent=2, ensure_ascii=False) + + +def export_as_html( + session_history: List[Dict[str, str]], + session_system_prompt: str = "" +) -> str: + """ + Export conversation history as styled HTML. + + Args: + session_history: List of message dictionaries + session_system_prompt: Optional system prompt to include + + Returns: + HTML formatted string with embedded CSS + """ + html_parts = [ + "", + "", + "", + " ", + " ", + " Conversation Export - oAI", + " ", + "", + "", + "
", + "

Conversation Export

", + f"
Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}
", + f"
Total Messages: {len(session_history)}
", + "
", + ] + + if session_system_prompt: + html_parts.extend([ + "
", + " System Prompt", + f"
{html_escape(session_system_prompt)}
", + "
", + ]) + + for i, entry in enumerate(session_history, 1): + prompt = html_escape(entry.get("prompt", "")) + response = html_escape(entry.get("response", "")) + + html_parts.extend([ + "
", + f"
Message {i} of {len(session_history)}
", + "
", + "
User
", + f"
{prompt}
", + "
", + "
", + "
Assistant
", + f"
{response}
", + "
", + "
", + ]) + + html_parts.extend([ + " ", + "", + "", + ]) + + return "\n".join(html_parts) diff --git a/oai/utils/files.py b/oai/utils/files.py new file mode 100644 index 0000000..f793d8a --- /dev/null +++ b/oai/utils/files.py @@ -0,0 +1,323 @@ +""" +File handling utilities for oAI. + +This module provides safe file reading, type detection, and other +file-related operations used throughout the application. +""" + +import os +import mimetypes +import base64 +from pathlib import Path +from typing import Optional, Dict, Any, Tuple + +from oai.constants import ( + MAX_FILE_SIZE, + CONTENT_TRUNCATION_THRESHOLD, + SUPPORTED_CODE_EXTENSIONS, + ALLOWED_FILE_EXTENSIONS, +) +from oai.utils.logging import get_logger + + +def is_binary_file(file_path: Path) -> bool: + """ + Check if a file appears to be binary. + + Args: + file_path: Path to the file to check + + Returns: + True if the file appears to be binary, False otherwise + """ + try: + with open(file_path, "rb") as f: + # Read first 8KB to check for binary content + chunk = f.read(8192) + # Check for null bytes (common in binary files) + if b"\x00" in chunk: + return True + # Try to decode as UTF-8 + try: + chunk.decode("utf-8") + return False + except UnicodeDecodeError: + return True + except Exception: + return True + + +def get_file_type(file_path: Path) -> Tuple[Optional[str], str]: + """ + Determine the MIME type and category of a file. + + Args: + file_path: Path to the file + + Returns: + Tuple of (mime_type, category) where category is one of: + 'image', 'pdf', 'code', 'text', 'binary', 'unknown' + """ + mime_type, _ = mimetypes.guess_type(str(file_path)) + ext = file_path.suffix.lower() + + if mime_type and mime_type.startswith("image/"): + return mime_type, "image" + elif mime_type == "application/pdf" or ext == ".pdf": + return mime_type or "application/pdf", "pdf" + elif ext in SUPPORTED_CODE_EXTENSIONS: + return mime_type or "text/plain", "code" + elif mime_type and mime_type.startswith("text/"): + return mime_type, "text" + elif is_binary_file(file_path): + return mime_type, "binary" + else: + return mime_type, "unknown" + + +def read_file_safe( + file_path: Path, + max_size: int = MAX_FILE_SIZE, + truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD +) -> Dict[str, Any]: + """ + Safely read a file with size limits and truncation support. + + Args: + file_path: Path to the file to read + max_size: Maximum file size to read (bytes) + truncate_threshold: Threshold for truncating large files + + Returns: + Dictionary containing: + - content: File content (text or base64) + - size: File size in bytes + - truncated: Whether content was truncated + - encoding: 'text', 'base64', or None on error + - error: Error message if reading failed + """ + logger = get_logger() + + try: + path = Path(file_path).resolve() + + if not path.exists(): + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"File not found: {path}" + } + + if not path.is_file(): + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"Not a file: {path}" + } + + file_size = path.stat().st_size + + if file_size > max_size: + return { + "content": None, + "size": file_size, + "truncated": False, + "encoding": None, + "error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)" + } + + # Try to read as text first + try: + content = path.read_text(encoding="utf-8") + + # Check if truncation is needed + if file_size > truncate_threshold: + lines = content.split("\n") + total_lines = len(lines) + + # Keep first 500 lines and last 100 lines + head_lines = 500 + tail_lines = 100 + + if total_lines > (head_lines + tail_lines): + truncated_content = ( + "\n".join(lines[:head_lines]) + + f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" + + "\n".join(lines[-tail_lines:]) + ) + logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)") + return { + "content": truncated_content, + "size": file_size, + "truncated": True, + "total_lines": total_lines, + "lines_shown": head_lines + tail_lines, + "encoding": "text", + "error": None + } + + logger.info(f"Read file: {path} ({file_size} bytes)") + return { + "content": content, + "size": file_size, + "truncated": False, + "encoding": "text", + "error": None + } + + except UnicodeDecodeError: + # File is binary, return base64 encoded + with open(path, "rb") as f: + binary_data = f.read() + b64_content = base64.b64encode(binary_data).decode("utf-8") + logger.info(f"Read binary file: {path} ({file_size} bytes)") + return { + "content": b64_content, + "size": file_size, + "truncated": False, + "encoding": "base64", + "error": None + } + + except PermissionError as e: + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": f"Permission denied: {e}" + } + except Exception as e: + logger.error(f"Error reading file {file_path}: {e}") + return { + "content": None, + "size": 0, + "truncated": False, + "encoding": None, + "error": str(e) + } + + +def get_file_extension(file_path: Path) -> str: + """ + Get the lowercase file extension. + + Args: + file_path: Path to the file + + Returns: + Lowercase extension including the dot (e.g., '.py') + """ + return file_path.suffix.lower() + + +def is_allowed_extension(file_path: Path) -> bool: + """ + Check if a file has an allowed extension for attachment. + + Args: + file_path: Path to the file + + Returns: + True if the extension is allowed, False otherwise + """ + return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS + + +def format_file_size(size_bytes: int) -> str: + """ + Format a file size in human-readable format. + + Args: + size_bytes: Size in bytes + + Returns: + Formatted string (e.g., '1.5 MB', '512 KB') + """ + for unit in ["B", "KB", "MB", "GB", "TB"]: + if abs(size_bytes) < 1024: + return f"{size_bytes:.1f} {unit}" + size_bytes /= 1024 + return f"{size_bytes:.1f} PB" + + +def prepare_file_attachment( + file_path: Path, + model_capabilities: Dict[str, Any] +) -> Optional[Dict[str, Any]]: + """ + Prepare a file for attachment to an API request. + + Args: + file_path: Path to the file + model_capabilities: Model capability information + + Returns: + Content block dictionary for the API, or None if unsupported + """ + logger = get_logger() + path = Path(file_path).resolve() + + if not path.exists(): + logger.warning(f"File not found: {path}") + return None + + mime_type, category = get_file_type(path) + file_size = path.stat().st_size + + if file_size > MAX_FILE_SIZE: + logger.warning(f"File too large: {path} ({format_file_size(file_size)})") + return None + + try: + with open(path, "rb") as f: + file_data = f.read() + + if category == "image": + # Check if model supports images + input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", []) + if "image" not in input_modalities: + logger.warning(f"Model does not support images") + return None + + b64_data = base64.b64encode(file_data).decode("utf-8") + return { + "type": "image_url", + "image_url": {"url": f"data:{mime_type};base64,{b64_data}"} + } + + elif category == "pdf": + # Check if model supports PDFs + input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", []) + supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"]) + if not supports_pdf: + logger.warning(f"Model does not support PDFs") + return None + + b64_data = base64.b64encode(file_data).decode("utf-8") + return { + "type": "image_url", + "image_url": {"url": f"data:application/pdf;base64,{b64_data}"} + } + + elif category in ("code", "text"): + text_content = file_data.decode("utf-8") + return { + "type": "text", + "text": f"File: {path.name}\n\n{text_content}" + } + + else: + logger.warning(f"Unsupported file type: {category} ({mime_type})") + return None + + except UnicodeDecodeError: + logger.error(f"Cannot decode file as UTF-8: {path}") + return None + except Exception as e: + logger.error(f"Error preparing file attachment {path}: {e}") + return None diff --git a/oai/utils/logging.py b/oai/utils/logging.py new file mode 100644 index 0000000..859c51c --- /dev/null +++ b/oai/utils/logging.py @@ -0,0 +1,297 @@ +""" +Logging configuration for oAI. + +This module provides centralized logging setup with Rich formatting, +file rotation, and configurable log levels. +""" + +import io +import os +import glob +import logging +import datetime +import shutil +from logging.handlers import RotatingFileHandler +from pathlib import Path +from typing import Optional + +from rich.console import Console +from rich.logging import RichHandler + +from oai.constants import ( + LOG_FILE, + CONFIG_DIR, + DEFAULT_LOG_MAX_SIZE_MB, + DEFAULT_LOG_BACKUP_COUNT, + DEFAULT_LOG_LEVEL, + VALID_LOG_LEVELS, +) + + +class RotatingRichHandler(RotatingFileHandler): + """ + Custom log handler combining file rotation with Rich formatting. + + This handler writes Rich-formatted log output to a rotating file, + providing colored and formatted logs even in file output while + managing file size and backups automatically. + """ + + def __init__(self, *args, **kwargs): + """Initialize the handler with Rich console for formatting.""" + super().__init__(*args, **kwargs) + # Create an internal console for Rich formatting + self.rich_console = Console( + file=io.StringIO(), + width=120, + force_terminal=False + ) + self.rich_handler = RichHandler( + console=self.rich_console, + show_time=True, + show_path=True, + rich_tracebacks=True, + tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"] + ) + + def emit(self, record: logging.LogRecord) -> None: + """ + Emit a log record with Rich formatting. + + Args: + record: The log record to emit + """ + try: + # Format with Rich + self.rich_handler.emit(record) + output = self.rich_console.file.getvalue() + self.rich_console.file.seek(0) + self.rich_console.file.truncate(0) + + if output: + self.stream.write(output) + self.flush() + except Exception: + self.handleError(record) + + +class LoggingManager: + """ + Manages application logging configuration. + + Provides methods to setup, configure, and manage logging with + support for runtime reconfiguration and level changes. + """ + + def __init__(self): + """Initialize the logging manager.""" + self.handler: Optional[RotatingRichHandler] = None + self.app_logger: Optional[logging.Logger] = None + self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB + self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT + self.log_level: str = DEFAULT_LOG_LEVEL + + def setup( + self, + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None + ) -> logging.Logger: + """ + Setup or reconfigure logging. + + Args: + max_size_mb: Maximum log file size in MB + backup_count: Number of backup files to keep + log_level: Logging level string + + Returns: + The configured application logger + """ + # Update configuration if provided + if max_size_mb is not None: + self.max_size_mb = max_size_mb + if backup_count is not None: + self.backup_count = backup_count + if log_level is not None: + self.log_level = log_level + + # Ensure config directory exists + CONFIG_DIR.mkdir(parents=True, exist_ok=True) + + # Get root logger + root_logger = logging.getLogger() + + # Remove existing handler if present + if self.handler is not None: + root_logger.removeHandler(self.handler) + try: + self.handler.close() + except Exception: + pass + + # Check if log needs manual rotation + self._check_rotation() + + # Create new handler + max_bytes = self.max_size_mb * 1024 * 1024 + self.handler = RotatingRichHandler( + filename=str(LOG_FILE), + maxBytes=max_bytes, + backupCount=self.backup_count, + encoding="utf-8" + ) + + self.handler.setLevel(logging.NOTSET) + root_logger.setLevel(logging.WARNING) + root_logger.addHandler(self.handler) + + # Suppress noisy third-party loggers + for logger_name in [ + "asyncio", "urllib3", "requests", "httpx", + "httpcore", "openai", "openrouter" + ]: + logging.getLogger(logger_name).setLevel(logging.WARNING) + + # Configure application logger + self.app_logger = logging.getLogger("oai_app") + level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO) + self.app_logger.setLevel(level) + self.app_logger.propagate = True + + return self.app_logger + + def _check_rotation(self) -> None: + """Check if log file needs rotation and perform if necessary.""" + if not LOG_FILE.exists(): + return + + current_size = LOG_FILE.stat().st_size + max_bytes = self.max_size_mb * 1024 * 1024 + + if current_size >= max_bytes: + # Perform manual rotation + timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S") + backup_file = f"{LOG_FILE}.{timestamp}" + + try: + shutil.move(str(LOG_FILE), backup_file) + except Exception: + pass + + # Clean old backups + self._cleanup_old_backups() + + def _cleanup_old_backups(self) -> None: + """Remove old backup files exceeding the backup count.""" + log_dir = LOG_FILE.parent + backup_pattern = f"{LOG_FILE.name}.*" + backups = sorted(glob.glob(str(log_dir / backup_pattern))) + + while len(backups) > self.backup_count: + oldest = backups.pop(0) + try: + os.remove(oldest) + except Exception: + pass + + def set_level(self, level: str) -> bool: + """ + Set the application log level. + + Args: + level: Log level string (debug/info/warning/error/critical) + + Returns: + True if level was set successfully, False otherwise + """ + level_lower = level.lower() + if level_lower not in VALID_LOG_LEVELS: + return False + + self.log_level = level_lower + if self.app_logger: + self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower]) + + return True + + def get_logger(self) -> logging.Logger: + """ + Get the application logger, initializing if necessary. + + Returns: + The application logger + """ + if self.app_logger is None: + self.setup() + return self.app_logger + + +# Global logging manager instance +_logging_manager = LoggingManager() + + +def setup_logging( + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None +) -> logging.Logger: + """ + Setup application logging. + + This is the main entry point for configuring logging. Call this + early in application startup. + + Args: + max_size_mb: Maximum log file size in MB + backup_count: Number of backup files to keep + log_level: Logging level string + + Returns: + The configured application logger + """ + return _logging_manager.setup(max_size_mb, backup_count, log_level) + + +def get_logger() -> logging.Logger: + """ + Get the application logger. + + Returns: + The application logger instance + """ + return _logging_manager.get_logger() + + +def set_log_level(level: str) -> bool: + """ + Set the application log level. + + Args: + level: Log level string + + Returns: + True if successful, False otherwise + """ + return _logging_manager.set_level(level) + + +def reload_logging( + max_size_mb: Optional[int] = None, + backup_count: Optional[int] = None, + log_level: Optional[str] = None +) -> logging.Logger: + """ + Reload logging configuration. + + Useful when settings change at runtime. + + Args: + max_size_mb: New maximum log file size + backup_count: New backup count + log_level: New log level + + Returns: + The reconfigured logger + """ + return _logging_manager.setup(max_size_mb, backup_count, log_level) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..af80856 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,134 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "oai" +version = "2.1.0" +description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application" +readme = "README.md" +license = {text = "MIT"} +authors = [ + {name = "Rune", email = "rune@example.com"} +] +maintainers = [ + {name = "Rune", email = "rune@example.com"} +] +keywords = [ + "ai", + "chat", + "openrouter", + "cli", + "terminal", + "mcp", + "llm", +] +classifiers = [ + "Development Status :: 4 - Beta", + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Utilities", +] +requires-python = ">=3.10" +dependencies = [ + "anyio>=4.0.0", + "click>=8.0.0", + "httpx>=0.24.0", + "markdown-it-py>=3.0.0", + "openrouter>=0.0.19", + "packaging>=21.0", + "prompt-toolkit>=3.0.0", + "pyperclip>=1.8.0", + "requests>=2.28.0", + "rich>=13.0.0", + "typer>=0.9.0", + "mcp>=1.0.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0.0", + "pytest-asyncio>=0.21.0", + "pytest-cov>=4.0.0", + "black>=23.0.0", + "isort>=5.12.0", + "mypy>=1.0.0", + "ruff>=0.1.0", +] + +[project.urls] +Homepage = "https://iurl.no/oai" +Repository = "https://gitlab.pm/rune/oai" +Documentation = "https://iurl.no/oai" +"Bug Tracker" = "https://gitlab.pm/rune/oai/issues" + +[project.scripts] +oai = "oai.cli:main" + +[tool.setuptools] +packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.ui", "oai.utils"] + +[tool.setuptools.package-data] +oai = ["py.typed"] + +[tool.black] +line-length = 100 +target-version = ["py310", "py311", "py312"] +include = '\.pyi?$' +exclude = ''' +/( + \.git + | \.mypy_cache + | \.pytest_cache + | \.venv + | build + | dist +)/ +''' + +[tool.isort] +profile = "black" +line_length = 100 +skip_gitignore = true + +[tool.mypy] +python_version = "3.10" +warn_return_any = true +warn_unused_configs = true +ignore_missing_imports = true +exclude = [ + "build", + "dist", + ".venv", +] + +[tool.ruff] +line-length = 100 +target-version = "py310" +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # Pyflakes + "I", # isort + "B", # flake8-bugbear + "C4", # flake8-comprehensions + "UP", # pyupgrade +] +ignore = [ + "E501", # line too long (handled by black) + "B008", # do not perform function calls in argument defaults + "C901", # too complex +] + +[tool.pytest.ini_options] +testpaths = ["tests"] +python_files = ["test_*.py"] +asyncio_mode = "auto" +addopts = "-v --tb=short" diff --git a/requirements.txt b/requirements.txt index 8a0a132..ea6105b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,38 +1,26 @@ -anyio==4.11.0 -beautifulsoup4==4.14.2 -charset-normalizer==3.4.4 -click==8.3.1 -docopt==0.6.2 -h11==0.16.0 -httpcore==1.0.9 -httpx==0.28.1 -idna==3.11 -latex2mathml==3.78.1 -loguru==0.7.3 -markdown-it-py==4.0.0 -markdown2==2.5.4 -mdurl==0.1.2 -natsort==8.4.0 -openrouter==0.0.19 -packaging==25.0 -pipreqs==0.4.13 -prompt-toolkit==3.0.52 -Pygments==2.19.2 -pyperclip==1.11.0 -python-dateutil==2.9.0.post0 -python-magic==0.4.27 -PyYAML==6.0.3 -requests==2.32.5 -rich==14.2.0 -shellingham==1.5.4 -six==1.17.0 -sniffio==1.3.1 -soupsieve==2.8 -svgwrite==1.4.3 -tqdm==4.67.1 -typer==0.20.0 -typing-extensions==4.15.0 -urllib3==2.5.0 -wavedrom==2.0.3.post3 -wcwidth==0.2.14 -yarg==0.1.10 +# oai.py v2.1.0-beta - Core Dependencies +anyio>=4.11.0 +charset-normalizer>=3.4.4 +click>=8.3.1 +h11>=0.16.0 +httpcore>=1.0.9 +httpx>=0.28.1 +idna>=3.11 +markdown-it-py>=4.0.0 +mdurl>=0.1.2 +openrouter>=0.0.19 +packaging>=25.0 +prompt-toolkit>=3.0.52 +Pygments>=2.19.2 +pyperclip>=1.11.0 +requests>=2.32.5 +rich>=14.2.0 +shellingham>=1.5.4 +sniffio>=1.3.1 +typer>=0.20.0 +typing-extensions>=4.15.0 +urllib3>=2.5.0 +wcwidth>=0.2.14 + +# MCP (Model Context Protocol) +mcp>=1.25.0 \ No newline at end of file