Compare commits — 1 commit: b0cf88704e
.gitignore (vendored) — 6 lines changed
@@ -23,6 +23,9 @@ Pipfile.lock # Consider if you want to include or exclude
 *~.nib
 *~.xib
 
+# Claude Code local settings
+.claude/
+
 # Added by author
 *.zip
 .note
@@ -39,3 +42,6 @@ b0.sh
 *.old
 *.sh
 *.back
+requirements.txt
+system_prompt.txt
+CLAUDE*
README.md — 655 lines changed
@@ -1,584 +1,301 @@
-# oAI - OpenRouter AI Chat
+# oAI - OpenRouter AI Chat Client
 
-A powerful terminal-based chat interface for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI agents to access local files and query SQLite databases directly.
+A powerful, extensible terminal-based chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
 
-## Description
-
-oAI is a feature-rich command-line chat application that provides an interactive interface to OpenRouter's AI models. It now includes **MCP integration** for local file system access and read-only database querying, allowing AI to help with code analysis, data exploration, and more.
-
 ## Features
 
 ### Core Features
 - 🤖 Interactive chat with 300+ AI models via OpenRouter
-- 🔍 Model selection with search and capability filtering
+- 🔍 Model selection with search and filtering
 - 💾 Conversation save/load/export (Markdown, JSON, HTML)
-- 📎 File attachment support (images, PDFs, code files)
+- 📎 File attachments (images, PDFs, code files)
-- 💰 Session cost tracking and credit monitoring
+- 💰 Real-time cost tracking and credit monitoring
-- 🎨 Rich terminal formatting with syntax highlighting
+- 🎨 Rich terminal UI with syntax highlighting
 - 📝 Persistent command history with search (Ctrl+R)
-- ⚙️ Configurable system prompts and token limits
-- 🗄️ SQLite-based configuration and conversation storage
 - 🌐 Online mode (web search capabilities)
-- 🧠 Conversation memory toggle (save costs with stateless mode)
+- 🧠 Conversation memory toggle
 
-### NEW: MCP (Model Context Protocol) v2.1.0-beta
+### MCP Integration
-- 🔧 **File Mode**: AI can read, search, and list your local files
+- 🔧 **File Mode**: AI can read, search, and list local files
   - Automatic .gitignore filtering (see the sketch after this section)
-  - Virtual environment exclusion (venv, node_modules, etc.)
+  - Virtual environment exclusion
-  - Supports code files, text, JSON, YAML, and more
   - Large file handling (auto-truncates >50KB)
 
-- ✍️ **Write Mode** (NEW!): AI can modify files with your permission
+- ✍️ **Write Mode**: AI can modify files with permission
-  - Create and edit files within allowed folders
+  - Create, edit, delete files
-  - Delete files (always requires confirmation)
+  - Move, copy, organize files
-  - Move, copy, and organize files
+  - Always requires explicit opt-in
-  - Create directories
-  - Ignores .gitignore for write operations
-  - OFF by default - explicit opt-in required
 
-- 🗄️ **Database Mode**: AI can query your SQLite databases
+- 🗄️ **Database Mode**: AI can query SQLite databases
-  - Read-only access (no data modification possible)
+  - Read-only access (safe)
-  - Schema inspection (tables, columns, indexes)
+  - Schema inspection
-  - Full-text search across all tables
+  - Full SQL query support
-  - SQL query execution (SELECT, JOINs, CTEs, subqueries)
-  - Query validation and timeout protection
-  - Result limiting (max 1000 rows)
 
-- 🔒 **Security Features**:
-  - Explicit folder/database approval required
-  - System directory blocking
-  - Write mode OFF by default (non-persistent)
-  - Delete operations always require user confirmation
-  - Read-only database access
-  - SQL injection protection
-  - Query timeout (5 seconds)
 
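Editor's aside: the `.gitignore` filtering in the feature list above is the kind of thing usually built on a gitwildmatch matcher. A minimal sketch of the idea — an illustration only, not the project's actual code (which lives in `oai/mcp/` and is not shown in this diff); the `pathspec` dependency is an assumption:

```python
from pathlib import Path

import pathspec  # third-party .gitignore-style matcher (assumed dependency)


def iter_visible_files(root: Path):
    """Yield files under root, skipping anything matched by root/.gitignore."""
    gi = root / ".gitignore"
    lines = gi.read_text().splitlines() if gi.exists() else []
    spec = pathspec.PathSpec.from_lines("gitwildmatch", lines)
    for path in root.rglob("*"):
        if path.is_file() and not spec.match_file(path.relative_to(root).as_posix()):
            yield path
```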
 ## Requirements
 
-- Python 3.10-3.13 (3.14 not supported yet)
+- Python 3.10-3.13
-- OpenRouter API key (get one at https://openrouter.ai)
+- OpenRouter API key ([get one here](https://openrouter.ai))
-- Function-calling model required for MCP features (GPT-4, Claude, etc.)
-
-## Screenshot
-
-[<img src="https://gitlab.pm/rune/oai/raw/branch/main/images/screenshot_01.png">](https://gitlab.pm/rune/oai/src/branch/main/README.md)
-
-*Screenshot from version 1.0 - MCP interface shows mode indicators like `[🔧 MCP: Files]` or `[🗄️ MCP: DB #1]`*
 
 ## Installation
 
-### Option 1: From Source (Recommended for Development)
+### Option 1: Install from Source (Recommended)
 
-#### 1. Install Dependencies
-
 ```bash
-pip install -r requirements.txt
+# Clone the repository
+git clone https://gitlab.pm/rune/oai.git
+cd oai
+
+# Install with pip
+pip install -e .
 ```
 
-#### 2. Make Executable
+### Option 2: Pre-built Binary (macOS/Linux)
 
-```bash
-chmod +x oai.py
-```
+Download from [Releases](https://gitlab.pm/rune/oai/releases):
+- **macOS (Apple Silicon)**: `oai_v2.1.0_mac_arm64.zip`
+- **Linux (x86_64)**: `oai_v2.1.0_linux_x86_64.zip`
 
-#### 3. Copy to PATH
-
-```bash
-# Option 1: System-wide (requires sudo)
-sudo cp oai.py /usr/local/bin/oai
-
-# Option 2: User-local (recommended)
-mkdir -p ~/.local/bin
-cp oai.py ~/.local/bin/oai
-
-# Add to PATH if needed (add to ~/.bashrc or ~/.zshrc)
-export PATH="$HOME/.local/bin:$PATH"
-```
-
-#### 4. Verify Installation
-
-```bash
-oai --version
-```
-
-### Option 2: Pre-built Binaries
-
-Download platform-specific binaries:
-- **macOS (Apple Silicon)**: `oai_vx.x.x_mac_arm64.zip`
-- **Linux (x86_64)**: `oai_vx.x.x-linux-x86_64.zip`
 
 ```bash
 # Extract and install
-unzip oai_vx.x.x_mac_arm64.zip  # or `oai_vx.x.x-linux-x86_64.zip`
+unzip oai_v2.1.0_*.zip
-chmod +x oai
-mkdir -p ~/.local/bin  # Remember to add this to your path. Or just move to folder already in your $PATH
+mkdir -p ~/.local/bin
 mv oai ~/.local/bin/
 
+# macOS only: Remove quarantine and approve
+xattr -cr ~/.local/bin/oai
+# Then right-click oai in Finder → Open With → Terminal → Click "Open"
 ```
 
-### Alternative: Shell Alias
+### Add to PATH
 
 ```bash
-# Add to ~/.bashrc or ~/.zshrc
+# Add to ~/.zshrc or ~/.bashrc
-alias oai='python3 /path/to/oai.py'
+export PATH="$HOME/.local/bin:$PATH"
 ```
 
 ## Quick Start
 
-### First Run Setup
-
 ```bash
-oai
+# Start the chat client
+oai chat
+
+# Or with options
+oai chat --model gpt-4o --mcp
 ```
 
-On first run, you'll be prompted to enter your OpenRouter API key.
+On first run, you'll be prompted for your OpenRouter API key.
 
-### Basic Usage
+### Basic Commands
 
 ```bash
-# Start chatting
-oai
-
-# Select a model
-You> /model
-
-# Enable MCP for file access
-You> /mcp on
-You> /mcp add ~/Documents
-
-# Ask AI to help with files (read-only)
-[🔧 MCP: Files] You> List all Python files in Documents
-[🔧 MCP: Files] You> Read and explain main.py
-
-# Enable write mode to let AI modify files
-You> /mcp write on
-[🔧✍️ MCP: Files+Write] You> Create a new Python file with helper functions
-[🔧✍️ MCP: Files+Write] You> Refactor main.py to use async/await
-
-# Switch to database mode
-You> /mcp add db ~/myapp/data.db
-You> /mcp db 1
-[🗄️ MCP: DB #1] You> Show me all tables
-[🗄️ MCP: DB #1] You> Find all users created this month
+# In the chat interface:
+/model     # Select AI model
+/help      # Show all commands
+/mcp on    # Enable file/database access
+/stats     # View session statistics
+exit       # Quit
 ```
 
-## MCP Guide
+## MCP (Model Context Protocol)
 
-### File Mode (Default)
+MCP allows the AI to interact with your local files and databases.
 
-**Setup:**
+### File Access
+
 ```bash
-/mcp on              # Start MCP server
+/mcp on              # Enable MCP
 /mcp add ~/Projects  # Grant access to folder
-/mcp add ~/Documents # Add another folder
-/mcp list            # View all allowed folders
-```
+/mcp list            # View allowed folders
 
-**Natural Language Usage:**
-```
+# Now ask the AI:
 "List all Python files in Projects"
-"Read and explain config.yaml"
+"Read and explain main.py"
 "Search for files containing 'TODO'"
-"What's in my Documents folder?"
 ```
 
-**Available Tools (Read-Only):**
-- `read_file` - Read complete file contents
-- `list_directory` - List files/folders (recursive optional)
-- `search_files` - Search by name or content
-
-**Available Tools (Write Mode - requires `/mcp write on`):**
-- `write_file` - Create new files or overwrite existing ones
-- `edit_file` - Find and replace text in existing files
-- `delete_file` - Delete files (always requires confirmation)
-- `create_directory` - Create directories
-- `move_file` - Move or rename files
-- `copy_file` - Copy files to new locations
-
-**Features:**
-- ✅ Automatic .gitignore filtering (read operations only)
-- ✅ Skips virtual environments (venv, node_modules)
-- ✅ Handles large files (auto-truncates >50KB)
-- ✅ Cross-platform (macOS, Linux, Windows via WSL)
-- ✅ Write mode OFF by default for safety
-- ✅ Delete operations require user confirmation with LLM's reason
-
-### Database Mode
-
-**Setup:**
-```bash
-/mcp add db ~/app/database.db  # Add SQLite database
-/mcp db list                   # View all databases
-/mcp db 1                      # Switch to database #1
-```
-
-**Natural Language Usage:**
-```
-"Show me all tables in this database"
-"Find records mentioning 'error'"
-"How many users registered last week?"
-"Get the schema for the orders table"
-"Show me the 10 most recent transactions"
-```
-
-**Available Tools:**
-- `inspect_database` - View schema, tables, columns, indexes
-- `search_database` - Full-text search across tables
-- `query_database` - Execute read-only SQL queries
-
-**Supported Queries:**
-- ✅ SELECT statements
-- ✅ JOINs (INNER, LEFT, RIGHT, FULL)
-- ✅ Subqueries
-- ✅ CTEs (Common Table Expressions)
-- ✅ Aggregations (COUNT, SUM, AVG, etc.)
-- ✅ WHERE, GROUP BY, HAVING, ORDER BY, LIMIT
-- ❌ INSERT/UPDATE/DELETE (blocked for safety)
-
 ### Write Mode
 
-**Enable Write Mode:**
 ```bash
-/mcp write on  # Enable write mode (shows warning, requires confirmation)
+/mcp write on   # Enable file modifications
+
+# AI can now:
+"Create a new file called utils.py"
+"Edit config.json and update the API URL"
+"Delete the old backup files"   # Always asks for confirmation
 ```
 
-**Natural Language Usage:**
-```
-"Create a new Python file called utils.py with helper functions"
-"Edit main.py and replace the old API endpoint with the new one"
-"Delete the backup.old file" (will prompt for confirmation)
-"Create a directory called tests"
-"Move config.json to the config folder"
-```
-
-**Important:**
-- ⚠️ Write mode is **OFF by default** and resets each session
-- ⚠️ Delete operations **always** require user confirmation
-- ⚠️ All operations are limited to allowed MCP folders
-- ✅ Write operations ignore .gitignore (can write to any file in allowed folders)
-
-**Disable Write Mode:**
-```bash
-/mcp write off  # Disable write mode (back to read-only)
-```
-
-### Mode Management
+### Database Mode
 
 ```bash
-/mcp status        # Show current mode, write mode, stats, folders/databases
-/mcp files         # Switch to file mode
-/mcp db <number>   # Switch to database mode
-/mcp gitignore on  # Enable .gitignore filtering (default)
-/mcp write on|off  # Enable/disable write mode
-/mcp remove 2      # Remove folder/database by number
+/mcp add db ~/app/data.db   # Add database
+/mcp db 1                   # Switch to database mode
+
+# Ask the AI:
+"Show all tables"
+"Find users created this month"
+"What's the schema for the orders table?"
```
 
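Editor's aside: the read-only guarantee described above (blocked INSERT/UPDATE/DELETE, `mode=ro`, a 5-second timeout, a 1000-row cap) can be layered — SQLite itself refuses writes when the file is opened read-only, and a validator rejects obvious write statements before they run. A minimal sketch of that pattern using only the standard library; this is illustrative, not the project's real implementation (which lives in `oai/mcp/` and is not shown here):

```python
import re
import sqlite3

# Reject statements that start with a write keyword; defense in depth only.
WRITE_KEYWORDS = re.compile(
    r"^\s*(INSERT|UPDATE|DELETE|DROP|ALTER|CREATE|REPLACE)\b", re.IGNORECASE
)


def query_readonly(db_path: str, sql: str, max_rows: int = 1000):
    """Run a read-only SQL query against an SQLite file, capping result size."""
    if WRITE_KEYWORDS.match(sql):
        raise ValueError("Only read-only queries are allowed")
    # mode=ro makes SQLite reject writes at the engine level, independent of
    # the regex check. Note: timeout here is the lock-wait timeout; a hard
    # per-query timeout would need e.g. conn.set_progress_handler.
    conn = sqlite3.connect(f"file:{db_path}?mode=ro", uri=True, timeout=5.0)
    try:
        return conn.execute(sql).fetchmany(max_rows)
    finally:
        conn.close()
```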
 ## Command Reference
 
-### Session Commands
-```
-/help [command]   Show help menu or detailed command help
-/help mcp         Comprehensive MCP guide
-/clear or /cl     Clear terminal screen (or Ctrl+L)
-/memory on|off    Toggle conversation memory (save costs)
-/online on|off    Enable/disable web search
-/paste [prompt]   Paste clipboard content
-/retry            Resend last prompt
-/reset            Clear history and system prompt
-/prev             View previous response
-/next             View next response
-```
+### Chat Commands
+
+| Command | Description |
+|---------|-------------|
+| `/help [cmd]` | Show help |
+| `/model [search]` | Select model |
+| `/info [model]` | Model details |
+| `/memory on\|off` | Toggle context |
+| `/online on\|off` | Toggle web search |
+| `/retry` | Resend last message |
+| `/clear` | Clear screen |
 
 ### MCP Commands
-```
-/mcp on            Start MCP server
-/mcp off           Stop MCP server
-/mcp status        Show comprehensive status (includes write mode)
-/mcp add <folder>  Add folder for file access
-/mcp add db <path> Add SQLite database
-/mcp list          List all folders
-/mcp db list       List all databases
-/mcp db <number>   Switch to database mode
-/mcp files         Switch to file mode
-/mcp remove <num>  Remove folder/database
-/mcp gitignore on  Enable .gitignore filtering
-/mcp write on      Enable write mode (create/edit/delete files)
-/mcp write off     Disable write mode (read-only)
-```
+
+| Command | Description |
+|---------|-------------|
+| `/mcp on\|off` | Enable/disable MCP |
+| `/mcp status` | Show MCP status |
+| `/mcp add <path>` | Add folder |
+| `/mcp add db <path>` | Add database |
+| `/mcp list` | List folders |
+| `/mcp db list` | List databases |
+| `/mcp db <n>` | Switch to database |
+| `/mcp files` | Switch to file mode |
+| `/mcp write on\|off` | Toggle write mode |
 
-### Model Commands
-```
-/model [search]   Select/change AI model
-/info [model_id]  Show model details (pricing, capabilities)
-```
+### Conversation Commands
+
+| Command | Description |
+|---------|-------------|
+| `/save <name>` | Save conversation |
+| `/load <name>` | Load conversation |
+| `/list` | List saved conversations |
+| `/delete <name>` | Delete conversation |
+| `/export md\|json\|html <file>` | Export |
 
 ### Configuration
-```
-/config             View all settings
-/config api         Set API key
-/config model       Set default model
-/config online      Set default online mode (on|off)
-/config stream      Enable/disable streaming (on|off)
-/config maxtoken    Set max token limit
-/config costwarning Set cost warning threshold ($)
-/config loglevel    Set log level (debug/info/warning/error)
-/config log         Set log file size (MB)
-```
+
+| Command | Description |
+|---------|-------------|
+| `/config` | View settings |
+| `/config api` | Set API key |
+| `/config model <id>` | Set default model |
+| `/config stream on\|off` | Toggle streaming |
+| `/stats` | Session statistics |
+| `/credits` | Check credits |
 
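Editor's aside: the slash commands in these tables are dispatched through a registry — the changelog below calls this the "command registry pattern", and `oai/commands/registry.py` appears later in this diff. The shape of the idea, reduced to a few lines as an illustrative sketch (not the project's exact classes):

```python
from dataclasses import dataclass
from typing import Callable, Dict


@dataclass
class Command:
    name: str                      # e.g. "/model"
    handler: Callable[[str], str]  # receives the argument string
    help_text: str = ""


class CommandRegistry:
    def __init__(self) -> None:
        self._commands: Dict[str, Command] = {}

    def register(self, command: Command) -> None:
        self._commands[command.name] = command

    def execute(self, line: str) -> str:
        # "/model gpt-4o" -> name "/model", args "gpt-4o"
        name, _, args = line.partition(" ")
        if name not in self._commands:
            raise KeyError(f"Unknown command: {name}")
        return self._commands[name].handler(args)


# Usage: register handlers once at startup, then dispatch every "/..." input.
registry = CommandRegistry()
registry.register(Command("/echo", lambda args: args, "Echo the arguments back"))
assert registry.execute("/echo hello") == "hello"
```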
-### Conversation Management
-```
-/save <name>                Save conversation
-/load <name|num>            Load saved conversation
-/delete <name|num>          Delete conversation
-/list                       List saved conversations
-/export md|json|html <file> Export conversation
-```
-
-### Token & System
-```
-/maxtoken [value]  Set session token limit
-/system [prompt]   Set system prompt (use 'clear' to reset)
-/middleout on|off  Enable prompt compression
-```
-
-### Monitoring
-```
-/stats    View session statistics
-/credits  Check OpenRouter credits
-```
-
-### File Attachments
-```
-@/path/to/file  Attach file (images, PDFs, code)
-
-Examples:
-Debug @script.py
-Analyze @data.json
-Review @screenshot.png
-```
-
-## Configuration Options
-
-All configuration stored in `~/.config/oai/`:
-
-### Files
-- `oai_config.db` - SQLite database (settings, conversations, MCP config)
-- `oai.log` - Application logs (rotating, configurable size)
-- `history.txt` - Command history (searchable with Ctrl+R)
-
-### Key Settings
-- **API Key**: OpenRouter authentication
-- **Default Model**: Auto-select on startup
-- **Streaming**: Real-time response display
-- **Max Tokens**: Global and session limits
-- **Cost Warning**: Alert threshold for expensive requests
-- **Online Mode**: Default web search setting
-- **Log Level**: debug/info/warning/error/critical
-- **Log Size**: Rotating file size in MB
-
-## Supported File Types
-
-### Code Files
-`.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm`
-
-### Data Files
-`.json, .yaml, .yml, .xml, .csv, .txt, .md`
-
-### Images
-All standard formats: PNG, JPEG, JPG, GIF, WEBP, BMP
-
-### Documents
-PDF (models with document support)
-
-### Size Limits
-- Images: 10 MB max
-- Code/Text: Auto-truncates files >50KB
-- Binary data: Displayed as `<binary: X bytes>`
-
-## MCP Security
-
-### Access Control
-- ✅ Explicit folder/database approval required
-- ✅ System directories blocked automatically
-- ✅ User confirmation for each addition
-- ✅ .gitignore patterns respected (file mode)
-
-### Database Safety
-- ✅ Read-only mode (cannot modify data)
-- ✅ SQL query validation (blocks INSERT/UPDATE/DELETE)
-- ✅ Query timeout (5 seconds max)
-- ✅ Result limits (1000 rows max)
-- ✅ Database opened in `mode=ro`
-
-### File System Safety
-- ✅ Read-only by default (write mode requires explicit opt-in)
-- ✅ Write mode OFF by default each session (non-persistent)
-- ✅ Delete operations always require user confirmation
-- ✅ Write operations limited to allowed folders only
-- ✅ System directories blocked
-- ✅ Virtual environment exclusion
-- ✅ Build artifact filtering
-- ✅ Maximum file size (10 MB)
-
-## Tips & Tricks
-
-### Command History
-- **↑/↓ arrows**: Navigate previous commands
-- **Ctrl+R**: Search command history
-- **Auto-complete**: Start typing `/` for command suggestions
-
-### Cost Optimization
+## CLI Options
+
 ```bash
-/memory off               # Disable context (stateless mode)
-/maxtoken 1000            # Limit response length
-/config costwarning 0.01  # Set alert threshold
+oai chat [OPTIONS]
+
+Options:
+  -m, --model TEXT   Model ID to use
+  -s, --system TEXT  System prompt
+  -o, --online       Enable online mode
+  --mcp              Enable MCP server
+  --help             Show help
 ```
 
-### MCP Best Practices
+Other commands:
 ```bash
-# Check status frequently
-/mcp status
-
-# Use specific paths to reduce search time
-"List Python files in Projects/app/"  # Better than
-"List all Python files"               # Slower
-
-# Database queries - be specific
-"SELECT * FROM users LIMIT 10"  # Good
-"SELECT * FROM users"           # May hit row limit
+oai config [setting] [value]  # Configure settings
+oai version                   # Show version
+oai credits                   # Check credits
 ```
 
-### Debugging
-```bash
-# Enable debug logging
-/config loglevel debug
-
-# Check log file
-tail -f ~/.config/oai/oai.log
-
-# View MCP statistics
-/mcp status  # Shows tool call counts
-```
+## Configuration
+
+Configuration is stored in `~/.config/oai/`:
+
+| File | Purpose |
+|------|---------|
+| `oai_config.db` | Settings, conversations, MCP config |
+| `oai.log` | Application logs |
+| `history.txt` | Command history |
 
+## Project Structure
+
+```
+oai/
+├── oai/
+│   ├── __init__.py
+│   ├── __main__.py      # Entry point for python -m oai
+│   ├── cli.py           # Main CLI interface
+│   ├── constants.py     # Configuration constants
+│   ├── commands/        # Slash command handlers
+│   ├── config/          # Settings and database
+│   ├── core/            # Chat client and session
+│   ├── mcp/             # MCP server and tools
+│   ├── providers/       # AI provider abstraction
+│   ├── ui/              # Terminal UI utilities
+│   └── utils/           # Logging, export, etc.
+├── pyproject.toml       # Package configuration
+├── build.sh             # Binary build script
+└── README.md
+```
 
 ## Troubleshooting
 
-### MCP Not Working
-```bash
-# 1. Check if MCP is installed
-python3 -c "import mcp; print('MCP OK')"
+### macOS Binary Issues
 
-# 2. Verify model supports function calling
+```bash
+# Remove quarantine attribute
+xattr -cr ~/.local/bin/oai
+
+# Then in Finder: right-click oai → Open With → Terminal → Click "Open"
+# After this, oai works from any terminal
+```
+
+### MCP Not Working
+
+```bash
+# Check if model supports function calling
 /info        # Look for "tools" in supported parameters
 
-# 3. Check MCP status
+# Check MCP status
 /mcp status
 
-# 4. Review logs
-tail ~/.config/oai/oai.log
+# View logs
+tail -f ~/.config/oai/oai.log
 ```
 
 ### Import Errors
 
 ```bash
-# Reinstall dependencies
-pip install --force-reinstall -r requirements.txt
+# Reinstall package
+pip install -e . --force-reinstall
 ```
 
-### Binary Issues (macOS)
-```bash
-# Remove quarantine
-xattr -cr ~/.local/bin/oai
-
-# Check security settings
-# System Settings > Privacy & Security > "Allow Anyway"
-```
-
-### Database Errors
-```bash
-# Verify it's a valid SQLite database
-sqlite3 database.db ".tables"
-
-# Check file permissions
-ls -la database.db
-```
 
 ## Version History
 
-### v2.1.0-RC1 (Current)
+### v2.1.0 (Current)
-- ✨ **NEW**: MCP (Model Context Protocol) integration
-- ✨ **NEW**: File system access (read, search, list)
-- ✨ **NEW**: Write mode - AI can create, edit, and delete files
-  - 6 write tools: write_file, edit_file, delete_file, create_directory, move_file, copy_file
-  - OFF by default - requires explicit `/mcp write on` activation
-  - Delete operations always require user confirmation
-  - Non-persistent setting (resets each session)
-- ✨ **NEW**: SQLite database querying (read-only)
-- ✨ **NEW**: Dual mode support (Files & Database)
-- ✨ **NEW**: .gitignore filtering
-- ✨ **NEW**: Binary data handling in databases
-- ✨ **NEW**: Mode indicators in prompt (shows ✍️ when write mode active)
-- ✨ **NEW**: Comprehensive `/help mcp` guide
-- 🔧 Improved error handling for tool calls
-- 🔧 Enhanced logging for MCP operations
-- 🔧 Statistics tracking for tool usage
+- 🏗️ Complete codebase refactoring to modular package structure
+- 🔌 Extensible provider architecture for adding new AI providers
+- 📦 Proper Python packaging with pyproject.toml
+- ✨ MCP integration (file access, write mode, database queries)
+- 🔧 Command registry pattern for slash commands
+- 📊 Improved cost tracking and session statistics
 
-### v1.9.6
+### v1.9.x
-- Base version with core chat functionality
+- Single-file implementation
-- Conversation management
+- Core chat functionality
 - File attachments
-- Cost tracking
+- Conversation management
-- Export capabilities
 
 ## License
 
-MIT License
+MIT License - See [LICENSE](LICENSE) for details.
 
-Copyright (c) 2024-2025 Rune Olsen
-
-Permission is hereby granted, free of charge, to any person obtaining a copy
-of this software and associated documentation files (the "Software"), to deal
-in the Software without restriction, including without limitation the rights
-to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
-copies of the Software, and to permit persons to whom the Software is
-furnished to do so, subject to the following conditions:
-
-The above copyright notice and this permission notice shall be included in all
-copies or substantial portions of the Software.
-
-THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
-IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
-FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
-AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
-OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
-SOFTWARE.
-
-Full license: https://opensource.org/licenses/MIT
 
 ## Author
 
 **Rune Olsen**
 
-- Homepage: https://ai.fubar.pm/
-- Blog: https://blog.rune.pm
 - Project: https://iurl.no/oai
+- Repository: https://gitlab.pm/rune/oai
 
 ## Contributing
 
-Contributions welcome! Please:
 1. Fork the repository
 2. Create a feature branch
-3. Submit a pull request with detailed description
+3. Submit a pull request
 
-## Acknowledgments
-
-- OpenRouter team for the unified AI API
-- Rich library for beautiful terminal output
-- MCP community for the protocol specification
 
 ---
 
-**Star ⭐ this project if you find it useful!**
+**⭐ Star this project if you find it useful!**
-
----
-
-Did you really read all the way down here? WOW! You deserve a 🍾 🥂!
 
oai/__init__.py — new file (26 lines)

"""
oAI - OpenRouter AI Chat Client

A feature-rich terminal-based chat application that provides an interactive CLI
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
integration for filesystem and database access.

Author: Rune
License: MIT
"""

__version__ = "2.1.0"
__author__ = "Rune"
__license__ = "MIT"

# Lazy imports to avoid circular dependencies and improve startup time
# Full imports are available via submodules:
# from oai.config import Settings, Database
# from oai.providers import OpenRouterProvider, AIProvider
# from oai.mcp import MCPManager

__all__ = [
    "__version__",
    "__author__",
    "__license__",
]
oai/__main__.py — new file (8 lines)

"""
Entry point for running oAI as a module: python -m oai
"""

from oai.cli import main

if __name__ == "__main__":
    main()
719
oai/cli.py
Normal file
719
oai/cli.py
Normal file
@@ -0,0 +1,719 @@
|
|||||||
|
"""
|
||||||
|
Main CLI entry point for oAI.
|
||||||
|
|
||||||
|
This module provides the command-line interface for the oAI chat application.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import typer
|
||||||
|
from prompt_toolkit import PromptSession
|
||||||
|
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
|
||||||
|
from prompt_toolkit.history import FileHistory
|
||||||
|
from rich.markdown import Markdown
|
||||||
|
from rich.panel import Panel
|
||||||
|
|
||||||
|
from oai import __version__
|
||||||
|
from oai.commands import register_all_commands, registry
|
||||||
|
from oai.commands.registry import CommandContext, CommandStatus
|
||||||
|
from oai.config.database import Database
|
||||||
|
from oai.config.settings import Settings
|
||||||
|
from oai.constants import (
|
||||||
|
APP_NAME,
|
||||||
|
APP_URL,
|
||||||
|
APP_VERSION,
|
||||||
|
CONFIG_DIR,
|
||||||
|
HISTORY_FILE,
|
||||||
|
VALID_COMMANDS,
|
||||||
|
)
|
||||||
|
from oai.core.client import AIClient
|
||||||
|
from oai.core.session import ChatSession
|
||||||
|
from oai.mcp.manager import MCPManager
|
||||||
|
from oai.providers.base import UsageStats
|
||||||
|
from oai.providers.openrouter import OpenRouterProvider
|
||||||
|
from oai.ui.console import (
|
||||||
|
clear_screen,
|
||||||
|
console,
|
||||||
|
display_panel,
|
||||||
|
print_error,
|
||||||
|
print_info,
|
||||||
|
print_success,
|
||||||
|
print_warning,
|
||||||
|
)
|
||||||
|
from oai.ui.tables import create_model_table, display_paginated_table
|
||||||
|
from oai.utils.logging import LoggingManager, get_logger
|
||||||
|
|
||||||
|
# Create Typer app
|
||||||
|
app = typer.Typer(
|
||||||
|
name="oai",
|
||||||
|
help=f"oAI - OpenRouter AI Chat Client\n\nVersion: {APP_VERSION}",
|
||||||
|
add_completion=False,
|
||||||
|
epilog="For more information, visit: " + APP_URL,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@app.callback(invoke_without_command=True)
|
||||||
|
def main_callback(
|
||||||
|
ctx: typer.Context,
|
||||||
|
version_flag: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--version",
|
||||||
|
"-v",
|
||||||
|
help="Show version information",
|
||||||
|
is_flag=True,
|
||||||
|
),
|
||||||
|
) -> None:
|
||||||
|
"""Main callback to handle global options."""
|
||||||
|
# Show version with update check if --version flag
|
||||||
|
if version_flag:
|
||||||
|
version_info = check_for_updates(APP_VERSION)
|
||||||
|
console.print(version_info)
|
||||||
|
raise typer.Exit()
|
||||||
|
|
||||||
|
# Show version with update check when --help is requested
|
||||||
|
if "--help" in sys.argv or "-h" in sys.argv:
|
||||||
|
version_info = check_for_updates(APP_VERSION)
|
||||||
|
console.print(f"\n{version_info}\n")
|
||||||
|
|
||||||
|
# Continue to subcommand if provided
|
||||||
|
if ctx.invoked_subcommand is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
def check_for_updates(current_version: str) -> str:
|
||||||
|
"""Check for available updates."""
|
||||||
|
import requests
|
||||||
|
from packaging import version as pkg_version
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = requests.get(
|
||||||
|
"https://gitlab.pm/api/v1/repos/rune/oai/releases/latest",
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
timeout=1.0,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
version_online = data.get("tag_name", "").lstrip("v")
|
||||||
|
|
||||||
|
if not version_online:
|
||||||
|
return f"[bold green]oAI version {current_version}[/]"
|
||||||
|
|
||||||
|
current = pkg_version.parse(current_version)
|
||||||
|
latest = pkg_version.parse(version_online)
|
||||||
|
|
||||||
|
if latest > current:
|
||||||
|
return (
|
||||||
|
f"[bold green]oAI version {current_version}[/] "
|
||||||
|
f"[bold red](Update available: {current_version} → {version_online})[/]"
|
||||||
|
)
|
||||||
|
return f"[bold green]oAI version {current_version} (up to date)[/]"
|
||||||
|
|
||||||
|
except Exception:
|
||||||
|
return f"[bold green]oAI version {current_version}[/]"
|
||||||
|
|
||||||
|
|
||||||
|
def show_welcome(settings: Settings, version_info: str) -> None:
|
||||||
|
"""Display welcome message."""
|
||||||
|
console.print(Panel.fit(
|
||||||
|
f"{version_info}\n\n"
|
||||||
|
"[bold cyan]Commands:[/] /help for commands, /model to select model\n"
|
||||||
|
"[bold cyan]MCP:[/] /mcp on to enable file/database access\n"
|
||||||
|
"[bold cyan]Exit:[/] Type 'exit', 'quit', or 'bye'",
|
||||||
|
title=f"[bold green]Welcome to {APP_NAME}[/]",
|
||||||
|
border_style="green",
|
||||||
|
))
|
||||||
|
|
||||||
|
|
||||||
|
def select_model(client: AIClient, search_term: Optional[str] = None) -> Optional[dict]:
|
||||||
|
"""Display model selection interface."""
|
||||||
|
try:
|
||||||
|
models = client.provider.get_raw_models()
|
||||||
|
if not models:
|
||||||
|
print_error("No models available")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Filter by search term if provided
|
||||||
|
if search_term:
|
||||||
|
search_lower = search_term.lower()
|
||||||
|
models = [m for m in models if search_lower in m.get("id", "").lower()]
|
||||||
|
|
||||||
|
if not models:
|
||||||
|
print_error(f"No models found matching '{search_term}'")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Create and display table
|
||||||
|
table = create_model_table(models)
|
||||||
|
display_paginated_table(
|
||||||
|
table,
|
||||||
|
f"[bold green]Available Models ({len(models)})[/]",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Prompt for selection
|
||||||
|
console.print("")
|
||||||
|
try:
|
||||||
|
choice = input("Enter model number (or press Enter to cancel): ").strip()
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
return None
|
||||||
|
|
||||||
|
if not choice:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
index = int(choice) - 1
|
||||||
|
if 0 <= index < len(models):
|
||||||
|
selected = models[index]
|
||||||
|
print_success(f"Selected model: {selected['id']}")
|
||||||
|
return selected
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
print_error("Invalid selection")
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Failed to fetch models: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def run_chat_loop(
|
||||||
|
session: ChatSession,
|
||||||
|
prompt_session: PromptSession,
|
||||||
|
settings: Settings,
|
||||||
|
) -> None:
|
||||||
|
"""Run the main chat loop."""
|
||||||
|
logger = get_logger()
|
||||||
|
mcp_manager = session.mcp_manager
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
# Build prompt prefix
|
||||||
|
prefix = "You> "
|
||||||
|
if mcp_manager and mcp_manager.enabled:
|
||||||
|
if mcp_manager.mode == "files":
|
||||||
|
if mcp_manager.write_enabled:
|
||||||
|
prefix = "[🔧✍️ MCP: Files+Write] You> "
|
||||||
|
else:
|
||||||
|
prefix = "[🔧 MCP: Files] You> "
|
||||||
|
elif mcp_manager.mode == "database" and mcp_manager.selected_db_index is not None:
|
||||||
|
prefix = f"[🗄️ MCP: DB #{mcp_manager.selected_db_index + 1}] You> "
|
||||||
|
|
||||||
|
# Get user input
|
||||||
|
user_input = prompt_session.prompt(
|
||||||
|
prefix,
|
||||||
|
auto_suggest=AutoSuggestFromHistory(),
|
||||||
|
).strip()
|
||||||
|
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle escape sequence
|
||||||
|
if user_input.startswith("//"):
|
||||||
|
user_input = user_input[1:]
|
||||||
|
|
||||||
|
# Check for exit
|
||||||
|
if user_input.lower() in ["exit", "quit", "bye"]:
|
||||||
|
console.print(
|
||||||
|
f"\n[bold yellow]Goodbye![/]\n"
|
||||||
|
f"[dim]Session: {session.stats.total_tokens:,} tokens, "
|
||||||
|
f"${session.stats.total_cost:.4f}[/]"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"Session ended. Messages: {session.stats.message_count}, "
|
||||||
|
f"Tokens: {session.stats.total_tokens}, "
|
||||||
|
f"Cost: ${session.stats.total_cost:.4f}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check for unknown commands
|
||||||
|
if user_input.startswith("/"):
|
||||||
|
cmd_word = user_input.split()[0].lower()
|
||||||
|
if not registry.is_command(user_input):
|
||||||
|
# Check if it's a valid command prefix
|
||||||
|
is_valid = any(cmd_word.startswith(cmd) for cmd in VALID_COMMANDS)
|
||||||
|
if not is_valid:
|
||||||
|
print_error(f"Unknown command: {cmd_word}")
|
||||||
|
print_info("Type /help to see available commands.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Try to execute as command
|
||||||
|
context = session.get_context()
|
||||||
|
result = registry.execute(user_input, context)
|
||||||
|
|
||||||
|
if result:
|
||||||
|
# Update session state from context
|
||||||
|
session.memory_enabled = context.memory_enabled
|
||||||
|
session.memory_start_index = context.memory_start_index
|
||||||
|
session.online_enabled = context.online_enabled
|
||||||
|
session.middle_out_enabled = context.middle_out_enabled
|
||||||
|
session.session_max_token = context.session_max_token
|
||||||
|
session.current_index = context.current_index
|
||||||
|
session.system_prompt = context.session_system_prompt
|
||||||
|
|
||||||
|
if result.status == CommandStatus.EXIT:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle special results
|
||||||
|
if result.data:
|
||||||
|
# Retry - resend last prompt
|
||||||
|
if "retry_prompt" in result.data:
|
||||||
|
user_input = result.data["retry_prompt"]
|
||||||
|
# Fall through to send message
|
||||||
|
|
||||||
|
# Paste - send clipboard content
|
||||||
|
elif "paste_prompt" in result.data:
|
||||||
|
user_input = result.data["paste_prompt"]
|
||||||
|
# Fall through to send message
|
||||||
|
|
||||||
|
# Model selection
|
||||||
|
elif "show_model_selector" in result.data:
|
||||||
|
search = result.data.get("search", "")
|
||||||
|
model = select_model(session.client, search if search else None)
|
||||||
|
if model:
|
||||||
|
session.set_model(model)
|
||||||
|
# If this came from /config model, also save as default
|
||||||
|
if result.data.get("set_as_default"):
|
||||||
|
settings.set_default_model(model["id"])
|
||||||
|
print_success(f"Default model set to: {model['id']}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Load conversation
|
||||||
|
elif "load_conversation" in result.data:
|
||||||
|
history = result.data.get("history", [])
|
||||||
|
session.history.clear()
|
||||||
|
from oai.core.session import HistoryEntry
|
||||||
|
for entry in history:
|
||||||
|
session.history.append(HistoryEntry(
|
||||||
|
prompt=entry.get("prompt", ""),
|
||||||
|
response=entry.get("response", ""),
|
||||||
|
prompt_tokens=entry.get("prompt_tokens", 0),
|
||||||
|
completion_tokens=entry.get("completion_tokens", 0),
|
||||||
|
msg_cost=entry.get("msg_cost", 0.0),
|
||||||
|
))
|
||||||
|
session.current_index = len(session.history) - 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
else:
|
||||||
|
# Normal command completed
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
# Command completed with no special data
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Ensure model is selected
|
||||||
|
if not session.selected_model:
|
||||||
|
print_warning("Please select a model first with /model")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Send message
|
||||||
|
stream = settings.stream_enabled
|
||||||
|
if mcp_manager and mcp_manager.enabled:
|
||||||
|
tools = session.get_mcp_tools()
|
||||||
|
if tools:
|
||||||
|
stream = False # Disable streaming with tools
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
console.print(
|
||||||
|
"[bold green]Streaming response...[/] "
|
||||||
|
"[dim](Press Ctrl+C to cancel)[/]"
|
||||||
|
)
|
||||||
|
if session.online_enabled:
|
||||||
|
console.print("[dim cyan]🌐 Online mode active[/]")
|
||||||
|
console.print("")
|
||||||
|
|
||||||
|
try:
|
||||||
|
response_text, usage, response_time = session.send_message(
|
||||||
|
user_input,
|
||||||
|
stream=stream,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Error: {e}")
|
||||||
|
logger.error(f"Message error: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not response_text:
|
||||||
|
print_error("No response received")
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Display non-streaming response
|
||||||
|
if not stream:
|
||||||
|
console.print()
|
||||||
|
display_panel(
|
||||||
|
Markdown(response_text),
|
||||||
|
title="[bold green]AI Response[/]",
|
||||||
|
border_style="green",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Calculate cost and tokens
|
||||||
|
cost = 0.0
|
||||||
|
tokens = 0
|
||||||
|
estimated = False
|
||||||
|
|
||||||
|
if usage and (usage.prompt_tokens > 0 or usage.completion_tokens > 0):
|
||||||
|
tokens = usage.total_tokens
|
||||||
|
if usage.total_cost_usd:
|
||||||
|
cost = usage.total_cost_usd
|
||||||
|
else:
|
||||||
|
cost = session.client.estimate_cost(
|
||||||
|
session.selected_model["id"],
|
||||||
|
usage.prompt_tokens,
|
||||||
|
usage.completion_tokens,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Estimate tokens when usage not available (streaming fallback)
|
||||||
|
# Rough estimate: ~4 characters per token for English text
|
||||||
|
est_input_tokens = len(user_input) // 4 + 1
|
||||||
|
est_output_tokens = len(response_text) // 4 + 1
|
||||||
|
tokens = est_input_tokens + est_output_tokens
|
||||||
|
cost = session.client.estimate_cost(
|
||||||
|
session.selected_model["id"],
|
||||||
|
est_input_tokens,
|
||||||
|
est_output_tokens,
|
||||||
|
)
|
||||||
|
# Create estimated usage for session tracking
|
||||||
|
usage = UsageStats(
|
||||||
|
prompt_tokens=est_input_tokens,
|
||||||
|
completion_tokens=est_output_tokens,
|
||||||
|
total_tokens=tokens,
|
||||||
|
)
|
||||||
|
estimated = True
|
||||||
|
|
||||||
|
# Add to history
|
||||||
|
session.add_to_history(user_input, response_text, usage, cost)
|
||||||
|
|
||||||
|
# Display metrics
|
||||||
|
est_marker = "~" if estimated else ""
|
||||||
|
context_info = ""
|
||||||
|
if session.memory_enabled:
|
||||||
|
context_count = len(session.history) - session.memory_start_index
|
||||||
|
if context_count > 1:
|
||||||
|
context_info = f", Context: {context_count} msg(s)"
|
||||||
|
else:
|
||||||
|
context_info = ", Memory: OFF"
|
||||||
|
|
||||||
|
online_emoji = " 🌐" if session.online_enabled else ""
|
||||||
|
mcp_emoji = ""
|
||||||
|
if mcp_manager and mcp_manager.enabled:
|
||||||
|
if mcp_manager.mode == "files":
|
||||||
|
mcp_emoji = " 🔧"
|
||||||
|
elif mcp_manager.mode == "database":
|
||||||
|
mcp_emoji = " 🗄️"
|
||||||
|
|
||||||
|
console.print(
|
||||||
|
f"\n[dim blue]📊 {est_marker}{tokens} tokens | {est_marker}${cost:.4f} | {response_time:.2f}s"
|
||||||
|
f"{context_info}{online_emoji}{mcp_emoji} | "
|
||||||
|
f"Session: {est_marker}{session.stats.total_tokens:,} tokens | "
|
||||||
|
f"{est_marker}${session.stats.total_cost:.4f}[/]"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check warnings
|
||||||
|
warnings = session.check_warnings()
|
||||||
|
for warning in warnings:
|
||||||
|
print_warning(warning)
|
||||||
|
|
||||||
|
# Offer to copy
|
||||||
|
console.print("")
|
||||||
|
try:
|
||||||
|
from oai.ui.prompts import prompt_copy_response
|
||||||
|
prompt_copy_response(response_text)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
console.print("")
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
console.print("\n[dim]Input cancelled[/]")
|
||||||
|
continue
|
||||||
|
except EOFError:
|
||||||
|
console.print("\n[bold yellow]Goodbye![/]")
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def chat(
|
||||||
|
model: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--model",
|
||||||
|
"-m",
|
||||||
|
help="Model ID to use",
|
||||||
|
),
|
||||||
|
system: Optional[str] = typer.Option(
|
||||||
|
None,
|
||||||
|
"--system",
|
||||||
|
"-s",
|
||||||
|
help="System prompt",
|
||||||
|
),
|
||||||
|
online: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--online",
|
||||||
|
"-o",
|
||||||
|
help="Enable online mode",
|
||||||
|
),
|
||||||
|
mcp: bool = typer.Option(
|
||||||
|
False,
|
||||||
|
"--mcp",
|
||||||
|
help="Enable MCP server",
|
||||||
|
),
|
||||||
|
) -> None:
|
||||||
|
"""Start an interactive chat session."""
|
||||||
|
# Setup logging
|
||||||
|
logging_manager = LoggingManager()
|
||||||
|
logging_manager.setup()
|
||||||
|
logger = get_logger()
|
||||||
|
|
||||||
|
# Clear screen
|
||||||
|
clear_screen()
|
||||||
|
|
||||||
|
# Load settings
|
||||||
|
settings = Settings.load()
|
||||||
|
|
||||||
|
# Check API key
|
||||||
|
if not settings.api_key:
|
||||||
|
print_error("No API key configured")
|
||||||
|
print_info("Run: oai --config api to set your API key")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Initialize client
|
||||||
|
try:
|
||||||
|
client = AIClient(
|
||||||
|
api_key=settings.api_key,
|
||||||
|
base_url=settings.base_url,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
print_error(f"Failed to initialize client: {e}")
|
||||||
|
raise typer.Exit(1)
|
||||||
|
|
||||||
|
# Register commands
|
||||||
|
register_all_commands()
|
||||||
|
|
||||||
|
# Check for updates and show welcome
|
||||||
|
version_info = check_for_updates(APP_VERSION)
|
||||||
|
show_welcome(settings, version_info)
|
||||||
|
|
||||||
|
# Initialize MCP manager
|
||||||
|
mcp_manager = MCPManager()
|
||||||
|
if mcp:
|
||||||
|
result = mcp_manager.enable()
|
||||||
|
if result["success"]:
|
||||||
|
print_success("MCP enabled")
|
||||||
|
else:
|
||||||
|
print_warning(f"MCP: {result.get('error', 'Failed to enable')}")
|
||||||
|
|
||||||
|
# Create session
|
||||||
|
session = ChatSession(
|
||||||
|
client=client,
|
||||||
|
settings=settings,
|
||||||
|
mcp_manager=mcp_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Set system prompt
|
||||||
|
if system:
|
||||||
|
session.system_prompt = system
|
||||||
|
print_info(f"System prompt: {system}")
|
||||||
|
|
||||||
|
# Set online mode
|
||||||
|
if online:
|
||||||
|
session.online_enabled = True
|
||||||
|
print_info("Online mode enabled")
|
||||||
|
|
||||||
|
# Select model
|
||||||
|
if model:
|
||||||
|
raw_model = client.get_raw_model(model)
|
||||||
|
if raw_model:
|
||||||
|
session.set_model(raw_model)
|
||||||
|
else:
|
||||||
|
print_warning(f"Model '{model}' not found")
|
||||||
|
elif settings.default_model:
|
||||||
|
raw_model = client.get_raw_model(settings.default_model)
|
||||||
|
if raw_model:
|
||||||
|
session.set_model(raw_model)
|
||||||
|
else:
|
||||||
|
print_warning(f"Default model '{settings.default_model}' not available")
|
||||||
|
|
||||||
|
# Setup prompt session
|
||||||
|
HISTORY_FILE.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
prompt_session = PromptSession(
|
||||||
|
history=FileHistory(str(HISTORY_FILE)),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Run chat loop
|
||||||
|
run_chat_loop(session, prompt_session, settings)
|
||||||
|
|
||||||
|
|
||||||
|
@app.command()
|
||||||
|
def config(
|
||||||
|
setting: Optional[str] = typer.Argument(
|
||||||
|
None,
|
||||||
|
help="Setting to configure (api, url, model, system, stream, costwarning, maxtoken, online, log, loglevel)",
|
||||||
|
),
|
||||||
|
value: Optional[str] = typer.Argument(
|
||||||
|
None,
|
||||||
|
help="Value to set",
|
||||||
|
),
|
||||||
|
) -> None:
|
||||||
|
"""View or modify configuration settings."""
|
||||||
|
settings = Settings.load()
|
||||||
|
|
||||||
|
if not setting:
|
||||||
|
# Show all settings
|
||||||
|
from rich.table import Table
|
||||||
|
from oai.constants import DEFAULT_SYSTEM_PROMPT
|
||||||
|
|
||||||
|
table = Table("Setting", "Value", show_header=True, header_style="bold magenta")
|
||||||
|
table.add_row("API Key", "***" + settings.api_key[-4:] if settings.api_key else "Not set")
|
||||||
|
table.add_row("Base URL", settings.base_url)
|
||||||
|
table.add_row("Default Model", settings.default_model or "Not set")
|
||||||
|
|
||||||
|
# Show system prompt status
|
||||||
|
if settings.default_system_prompt is None:
|
||||||
|
system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
|
||||||
|
elif settings.default_system_prompt == "":
|
||||||
|
system_prompt_display = "[blank]"
|
||||||
|
else:
|
||||||
|
system_prompt_display = settings.default_system_prompt[:50] + "..." if len(settings.default_system_prompt) > 50 else settings.default_system_prompt
|
||||||
|
table.add_row("System Prompt", system_prompt_display)
|
||||||
|
|
||||||
|
table.add_row("Streaming", "on" if settings.stream_enabled else "off")
|
||||||
|
table.add_row("Cost Warning", f"${settings.cost_warning_threshold:.4f}")
|
||||||
|
table.add_row("Max Tokens", str(settings.max_tokens))
|
||||||
|
table.add_row("Default Online", "on" if settings.default_online_mode else "off")
|
||||||
|
table.add_row("Log Level", settings.log_level)
|
||||||
|
|
||||||
|
display_panel(table, title="[bold green]Configuration[/]")
|
||||||
|
return
|
||||||
|
|
||||||
|
    setting = setting.lower()

    if setting == "api":
        if value:
            settings.set_api_key(value)
        else:
            from oai.ui.prompts import prompt_input
            new_key = prompt_input("Enter API key", password=True)
            if new_key:
                settings.set_api_key(new_key)
        print_success("API key updated")

    elif setting == "url":
        settings.set_base_url(value or "https://openrouter.ai/api/v1")
        print_success(f"Base URL set to: {settings.base_url}")

    elif setting == "model":
        if value:
            settings.set_default_model(value)
            print_success(f"Default model set to: {value}")
        else:
            print_info(f"Current default model: {settings.default_model or 'Not set'}")

    elif setting == "system":
        from oai.constants import DEFAULT_SYSTEM_PROMPT

        if value:
            # Decode escape sequences like \n for newlines
            value = value.encode().decode('unicode_escape')
            settings.set_default_system_prompt(value)
            if value:
                print_success(f"Default system prompt set to: {value}")
            else:
                print_success("Default system prompt set to blank.")
        else:
            if settings.default_system_prompt is None:
                print_info(f"Using hardcoded default: {DEFAULT_SYSTEM_PROMPT[:60]}...")
            elif settings.default_system_prompt == "":
                print_info("System prompt: [blank]")
            else:
                print_info(f"System prompt: {settings.default_system_prompt}")

    elif setting == "stream":
        if value and value.lower() in ["on", "off"]:
            settings.set_stream_enabled(value.lower() == "on")
            print_success(f"Streaming {'enabled' if settings.stream_enabled else 'disabled'}")
        else:
            print_info("Usage: oai config stream [on|off]")

    elif setting == "costwarning":
        if value:
            try:
                threshold = float(value)
                settings.set_cost_warning_threshold(threshold)
                print_success(f"Cost warning threshold set to: ${threshold:.4f}")
            except ValueError:
                print_error("Please enter a valid number")
        else:
            print_info(f"Current threshold: ${settings.cost_warning_threshold:.4f}")

    elif setting == "maxtoken":
        if value:
            try:
                max_tok = int(value)
                settings.set_max_tokens(max_tok)
                print_success(f"Max tokens set to: {max_tok}")
            except ValueError:
                print_error("Please enter a valid number")
        else:
            print_info(f"Current max tokens: {settings.max_tokens}")

    elif setting == "online":
        if value and value.lower() in ["on", "off"]:
            settings.set_default_online_mode(value.lower() == "on")
            print_success(f"Default online mode {'enabled' if settings.default_online_mode else 'disabled'}")
        else:
            print_info("Usage: oai config online [on|off]")

    elif setting == "loglevel":
        valid_levels = ["debug", "info", "warning", "error", "critical"]
        if value and value.lower() in valid_levels:
            settings.set_log_level(value.lower())
            print_success(f"Log level set to: {value.lower()}")
        else:
            print_info(f"Valid levels: {', '.join(valid_levels)}")

    else:
        print_error(f"Unknown setting: {setting}")


@app.command()
def version() -> None:
    """Show version information."""
    version_info = check_for_updates(APP_VERSION)
    console.print(version_info)


@app.command()
def credits() -> None:
    """Check account credits."""
    settings = Settings.load()

    if not settings.api_key:
        print_error("No API key configured")
        raise typer.Exit(1)

    client = AIClient(api_key=settings.api_key, base_url=settings.base_url)
    credits_data = client.get_credits()

    if not credits_data:
        print_error("Failed to fetch credits")
        raise typer.Exit(1)

    from rich.table import Table

    table = Table("Metric", "Value", show_header=True, header_style="bold magenta")
    table.add_row("Total Credits", credits_data.get("total_credits_formatted", "N/A"))
    table.add_row("Used Credits", credits_data.get("used_credits_formatted", "N/A"))
    table.add_row("Credits Left", credits_data.get("credits_left_formatted", "N/A"))

    display_panel(table, title="[bold green]Account Credits[/]")


def main() -> None:
    """Main entry point."""
    # Default to 'chat' command if no arguments provided
    if len(sys.argv) == 1:
        sys.argv.append("chat")
    app()


if __name__ == "__main__":
    main()
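The `main()` wrapper above is what makes a bare `oai` invocation start a chat: it injects the `chat` subcommand before handing control to Typer. A minimal sketch of the same pattern, independent of oAI's codebase (the command bodies here are placeholders, not the project's real implementations):

```python
# Minimal, self-contained sketch of the Typer entry-point pattern used above.
import sys
import typer

app = typer.Typer()

@app.command()
def chat() -> None:
    """Hypothetical default command."""
    typer.echo("starting chat...")

@app.command()
def version() -> None:
    typer.echo("2.1.0")

def main() -> None:
    # With no CLI arguments, fall through to 'chat', mirroring
    # the sys.argv.append("chat") trick in the module above.
    if len(sys.argv) == 1:
        sys.argv.append("chat")
    app()

if __name__ == "__main__":
    main()
```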
24 oai/commands/__init__.py Normal file
@@ -0,0 +1,24 @@
"""
Command system for oAI.

This module provides a command registry and handler system
for processing slash commands in the chat interface.
"""

from oai.commands.registry import (
    Command,
    CommandRegistry,
    CommandContext,
    CommandResult,
    registry,
)
from oai.commands.handlers import register_all_commands

__all__ = [
    "Command",
    "CommandRegistry",
    "CommandContext",
    "CommandResult",
    "registry",
    "register_all_commands",
]
1441 oai/commands/handlers.py Normal file
File diff suppressed because it is too large.

381 oai/commands/registry.py Normal file
@@ -0,0 +1,381 @@
"""
Command registry for oAI.

This module defines the command system infrastructure including
the Command base class, CommandContext for state, and CommandRegistry
for managing available commands.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING

from oai.utils.logging import get_logger

if TYPE_CHECKING:
    from oai.config.settings import Settings
    from oai.providers.base import AIProvider, ModelInfo
    from oai.mcp.manager import MCPManager


class CommandStatus(str, Enum):
    """Status of command execution."""

    SUCCESS = "success"
    ERROR = "error"
    CONTINUE = "continue"  # Continue to next handler
    EXIT = "exit"  # Exit the application


@dataclass
class CommandResult:
    """
    Result of a command execution.

    Attributes:
        status: Execution status
        message: Optional message to display
        data: Optional data payload
        should_continue: Whether to continue the main loop
    """

    status: CommandStatus = CommandStatus.SUCCESS
    message: Optional[str] = None
    data: Optional[Any] = None
    should_continue: bool = True

    @classmethod
    def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult":
        """Create a success result."""
        return cls(status=CommandStatus.SUCCESS, message=message, data=data)

    @classmethod
    def error(cls, message: str) -> "CommandResult":
        """Create an error result."""
        return cls(status=CommandStatus.ERROR, message=message)

    @classmethod
    def exit(cls, message: Optional[str] = None) -> "CommandResult":
        """Create an exit result."""
        return cls(status=CommandStatus.EXIT, message=message, should_continue=False)


@dataclass
class CommandContext:
    """
    Context object providing state to command handlers.

    Contains all the session state needed by commands including
    settings, provider, conversation history, and MCP manager.

    Attributes:
        settings: Application settings
        provider: AI provider instance
        mcp_manager: MCP manager instance
        selected_model: Currently selected model
        session_history: Conversation history
        session_system_prompt: Current system prompt
        memory_enabled: Whether memory is enabled
        online_enabled: Whether online mode is enabled
        session_tokens: Session token counts
        session_cost: Session cost total
    """

    settings: Optional["Settings"] = None
    provider: Optional["AIProvider"] = None
    mcp_manager: Optional["MCPManager"] = None
    selected_model: Optional["ModelInfo"] = None
    selected_model_raw: Optional[Dict[str, Any]] = None
    session_history: List[Dict[str, Any]] = field(default_factory=list)
    session_system_prompt: str = ""
    memory_enabled: bool = True
    memory_start_index: int = 0
    online_enabled: bool = False
    middle_out_enabled: bool = False
    session_max_token: int = 0
    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    message_count: int = 0
    current_index: int = 0


@dataclass
class CommandHelp:
    """
    Help information for a command.

    Attributes:
        description: Brief description
        usage: Usage syntax
        examples: List of (description, example) tuples
        notes: Additional notes
        aliases: Command aliases
    """

    description: str
    usage: str = ""
    examples: List[tuple] = field(default_factory=list)
    notes: str = ""
    aliases: List[str] = field(default_factory=list)


class Command(ABC):
    """
    Abstract base class for all commands.

    Commands implement the execute method to handle their logic.
    They can also provide help information and aliases.
    """

    @property
    @abstractmethod
    def name(self) -> str:
        """Get the primary command name (e.g., '/help')."""
        pass

    @property
    def aliases(self) -> List[str]:
        """Get command aliases (e.g., ['/h'] for help)."""
        return []

    @property
    @abstractmethod
    def help(self) -> CommandHelp:
        """Get command help information."""
        pass

    @abstractmethod
    def execute(self, args: str, context: CommandContext) -> CommandResult:
        """
        Execute the command.

        Args:
            args: Arguments passed to the command
            context: Command execution context

        Returns:
            CommandResult indicating success/failure
        """
        pass

    def matches(self, input_text: str) -> bool:
        """
        Check if this command matches the input.

        Args:
            input_text: User input text

        Returns:
            True if this command should handle the input
        """
        input_lower = input_text.lower()
        cmd_word = input_lower.split()[0] if input_lower.split() else ""

        # Check primary name
        if cmd_word == self.name.lower():
            return True

        # Check aliases
        for alias in self.aliases:
            if cmd_word == alias.lower():
                return True

        return False

    def get_args(self, input_text: str) -> str:
        """
        Extract arguments from the input text.

        Args:
            input_text: Full user input

        Returns:
            Arguments portion of the input
        """
        parts = input_text.split(maxsplit=1)
        return parts[1] if len(parts) > 1 else ""


class CommandRegistry:
    """
    Registry for managing available commands.

    Provides registration, lookup, and execution of commands.
    """

    def __init__(self):
        """Initialize an empty command registry."""
        self._commands: Dict[str, Command] = {}
        self._aliases: Dict[str, str] = {}
        self.logger = get_logger()

    def register(self, command: Command) -> None:
        """
        Register a command.

        Args:
            command: Command instance to register

        Raises:
            ValueError: If command name already registered
        """
        name = command.name.lower()

        if name in self._commands:
            raise ValueError(f"Command '{name}' already registered")

        self._commands[name] = command

        # Register aliases
        for alias in command.aliases:
            alias_lower = alias.lower()
            if alias_lower in self._aliases:
                self.logger.warning(
                    f"Alias '{alias}' already registered, overwriting"
                )
            self._aliases[alias_lower] = name

        self.logger.debug(f"Registered command: {name}")

    def register_function(
        self,
        name: str,
        handler: Callable[[str, CommandContext], CommandResult],
        description: str,
        usage: str = "",
        aliases: Optional[List[str]] = None,
        examples: Optional[List[tuple]] = None,
        notes: str = "",
    ) -> None:
        """
        Register a function-based command.

        Convenience method for simple commands that don't need
        a full Command class.

        Args:
            name: Command name (e.g., '/help')
            handler: Function to execute
            description: Help description
            usage: Usage syntax
            aliases: Command aliases
            examples: Example usages
            notes: Additional notes
        """
        aliases = aliases or []
        examples = examples or []

        class FunctionCommand(Command):
            @property
            def name(self) -> str:
                return name

            @property
            def aliases(self) -> List[str]:
                return aliases

            @property
            def help(self) -> CommandHelp:
                return CommandHelp(
                    description=description,
                    usage=usage,
                    examples=examples,
                    notes=notes,
                    aliases=aliases,
                )

            def execute(self, args: str, context: CommandContext) -> CommandResult:
                return handler(args, context)

        self.register(FunctionCommand())

    def get(self, name: str) -> Optional[Command]:
        """
        Get a command by name or alias.

        Args:
            name: Command name or alias

        Returns:
            Command instance or None if not found
        """
        name_lower = name.lower()

        # Check direct match
        if name_lower in self._commands:
            return self._commands[name_lower]

        # Check aliases
        if name_lower in self._aliases:
            return self._commands[self._aliases[name_lower]]

        return None

    def find(self, input_text: str) -> Optional[Command]:
        """
        Find a command that matches the input.

        Args:
            input_text: User input text

        Returns:
            Matching Command or None
        """
        cmd_word = input_text.lower().split()[0] if input_text.split() else ""
        return self.get(cmd_word)

    def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]:
        """
        Execute a command matching the input.

        Args:
            input_text: User input text
            context: Execution context

        Returns:
            CommandResult or None if no matching command
        """
        command = self.find(input_text)
        if command:
            args = command.get_args(input_text)
            self.logger.debug(f"Executing command: {command.name} with args: {args}")
            return command.execute(args, context)
        return None

    def is_command(self, input_text: str) -> bool:
        """
        Check if input is a valid command.

        Args:
            input_text: User input text

        Returns:
            True if input matches a registered command
        """
        return self.find(input_text) is not None

    def list_commands(self) -> List[Command]:
        """
        Get all registered commands.

        Returns:
            List of Command instances
        """
        return list(self._commands.values())

    def get_all_names(self) -> List[str]:
        """
        Get all command names and aliases.

        Returns:
            List of command names including aliases
        """
        names = list(self._commands.keys())
        names.extend(self._aliases.keys())
        return sorted(set(names))


# Global registry instance
registry = CommandRegistry()
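Taken together, `register_function` and `CommandRegistry.execute` give a complete dispatch path without subclassing `Command`. A minimal sketch, assuming the `oai` package is importable; the `/ping` command and its handler are hypothetical, not part of the diff:

```python
# Illustrative only: wiring a simple /ping command through the registry API above.
from oai.commands.registry import CommandContext, CommandRegistry, CommandResult

def ping_handler(args: str, context: CommandContext) -> CommandResult:
    # Echo back any arguments so the round trip is visible.
    return CommandResult.success(message=f"pong {args}".strip())

reg = CommandRegistry()
reg.register_function(
    name="/ping",
    handler=ping_handler,
    description="Health-check command.",
    usage="/ping [text]",
    aliases=["/p"],
)

# Alias lookup, argument extraction, and dispatch all happen in execute().
result = reg.execute("/p hello", CommandContext())
print(result.message)  # -> "pong hello"
```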
11 oai/config/__init__.py Normal file
@@ -0,0 +1,11 @@
"""
Configuration management for oAI.

This package handles all configuration persistence, settings management,
and database operations for the application.
"""

from oai.config.settings import Settings
from oai.config.database import Database

__all__ = ["Settings", "Database"]
472 oai/config/database.py Normal file
@@ -0,0 +1,472 @@
"""
Database persistence layer for oAI.

This module provides a clean abstraction for SQLite operations including
configuration storage, conversation persistence, and MCP statistics tracking.
All database operations are centralized here for maintainability.
"""

import sqlite3
import json
import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
from contextlib import contextmanager

from oai.constants import DATABASE_FILE, CONFIG_DIR


class Database:
    """
    SQLite database manager for oAI.

    Handles all database operations including:
    - Configuration key-value storage
    - Conversation session persistence
    - MCP configuration and statistics
    - Database registrations for MCP

    Uses context managers for safe connection handling and supports
    automatic table creation on first use.
    """

    def __init__(self, db_path: Optional[Path] = None):
        """
        Initialize the database manager.

        Args:
            db_path: Optional custom database path. Defaults to standard location.
        """
        self.db_path = db_path or DATABASE_FILE
        self._ensure_directories()
        self._ensure_tables()

    def _ensure_directories(self) -> None:
        """Ensure the configuration directory exists."""
        CONFIG_DIR.mkdir(parents=True, exist_ok=True)

    @contextmanager
    def _connection(self):
        """
        Context manager for database connections.

        Yields:
            sqlite3.Connection: Active database connection

        Example:
            with self._connection() as conn:
                conn.execute("SELECT * FROM config")
        """
        conn = sqlite3.connect(str(self.db_path))
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def _ensure_tables(self) -> None:
        """Create all required tables if they don't exist."""
        with self._connection() as conn:
            # Main configuration table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS config (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                )
            """)

            # Conversation sessions table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS conversation_sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT NOT NULL,
                    timestamp TEXT NOT NULL,
                    data TEXT NOT NULL
                )
            """)

            # MCP configuration table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS mcp_config (
                    key TEXT PRIMARY KEY,
                    value TEXT NOT NULL
                )
            """)

            # MCP statistics table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS mcp_stats (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TEXT NOT NULL,
                    tool_name TEXT NOT NULL,
                    folder TEXT,
                    success INTEGER NOT NULL,
                    error_message TEXT
                )
            """)

            # MCP databases table
            conn.execute("""
                CREATE TABLE IF NOT EXISTS mcp_databases (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    path TEXT NOT NULL UNIQUE,
                    name TEXT NOT NULL,
                    size INTEGER,
                    tables TEXT,
                    added_timestamp TEXT NOT NULL
                )
            """)

    # =========================================================================
    # CONFIGURATION METHODS
    # =========================================================================

    def get_config(self, key: str) -> Optional[str]:
        """
        Retrieve a configuration value by key.

        Args:
            key: The configuration key to retrieve

        Returns:
            The configuration value, or None if not found
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "SELECT value FROM config WHERE key = ?",
                (key,)
            )
            result = cursor.fetchone()
            return result[0] if result else None

    def set_config(self, key: str, value: str) -> None:
        """
        Set a configuration value.

        Args:
            key: The configuration key
            value: The value to store
        """
        with self._connection() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)",
                (key, value)
            )

    def delete_config(self, key: str) -> bool:
        """
        Delete a configuration value.

        Args:
            key: The configuration key to delete

        Returns:
            True if a row was deleted, False otherwise
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "DELETE FROM config WHERE key = ?",
                (key,)
            )
            return cursor.rowcount > 0

    def get_all_config(self) -> Dict[str, str]:
        """
        Retrieve all configuration values.

        Returns:
            Dictionary of all key-value pairs
        """
        with self._connection() as conn:
            cursor = conn.execute("SELECT key, value FROM config")
            return dict(cursor.fetchall())

    # =========================================================================
    # MCP CONFIGURATION METHODS
    # =========================================================================

    def get_mcp_config(self, key: str) -> Optional[str]:
        """
        Retrieve an MCP configuration value.

        Args:
            key: The MCP configuration key

        Returns:
            The configuration value, or None if not found
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "SELECT value FROM mcp_config WHERE key = ?",
                (key,)
            )
            result = cursor.fetchone()
            return result[0] if result else None

    def set_mcp_config(self, key: str, value: str) -> None:
        """
        Set an MCP configuration value.

        Args:
            key: The MCP configuration key
            value: The value to store
        """
        with self._connection() as conn:
            conn.execute(
                "INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)",
                (key, value)
            )

    # =========================================================================
    # MCP STATISTICS METHODS
    # =========================================================================

    def log_mcp_stat(
        self,
        tool_name: str,
        folder: Optional[str],
        success: bool,
        error_message: Optional[str] = None
    ) -> None:
        """
        Log an MCP tool usage event.

        Args:
            tool_name: Name of the MCP tool that was called
            folder: The folder path involved (if any)
            success: Whether the call succeeded
            error_message: Error message if the call failed
        """
        timestamp = datetime.datetime.now().isoformat()
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO mcp_stats
                   (timestamp, tool_name, folder, success, error_message)
                   VALUES (?, ?, ?, ?, ?)""",
                (timestamp, tool_name, folder, 1 if success else 0, error_message)
            )

    def get_mcp_stats(self) -> Dict[str, Any]:
        """
        Get aggregated MCP usage statistics.

        Returns:
            Dictionary containing usage statistics:
            - total_calls: Total number of tool calls
            - reads: Number of file reads
            - lists: Number of directory listings
            - searches: Number of file searches
            - db_inspects: Number of database inspections
            - db_searches: Number of database searches
            - db_queries: Number of database queries
            - last_used: Timestamp of last usage
        """
        with self._connection() as conn:
            cursor = conn.execute("""
                SELECT
                    COUNT(*) as total_calls,
                    SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads,
                    SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists,
                    SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches,
                    SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects,
                    SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches,
                    SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries,
                    MAX(timestamp) as last_used
                FROM mcp_stats
            """)
            row = cursor.fetchone()
            return {
                "total_calls": row[0] or 0,
                "reads": row[1] or 0,
                "lists": row[2] or 0,
                "searches": row[3] or 0,
                "db_inspects": row[4] or 0,
                "db_searches": row[5] or 0,
                "db_queries": row[6] or 0,
                "last_used": row[7],
            }

    # =========================================================================
    # MCP DATABASE REGISTRY METHODS
    # =========================================================================

    def add_mcp_database(self, db_info: Dict[str, Any]) -> int:
        """
        Register a database for MCP access.

        Args:
            db_info: Dictionary containing:
                - path: Database file path
                - name: Display name
                - size: File size in bytes
                - tables: List of table names
                - added: Timestamp when added

        Returns:
            The database ID
        """
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO mcp_databases
                   (path, name, size, tables, added_timestamp)
                   VALUES (?, ?, ?, ?, ?)""",
                (
                    db_info["path"],
                    db_info["name"],
                    db_info["size"],
                    json.dumps(db_info["tables"]),
                    db_info["added"]
                )
            )
            cursor = conn.execute(
                "SELECT id FROM mcp_databases WHERE path = ?",
                (db_info["path"],)
            )
            return cursor.fetchone()[0]

    def remove_mcp_database(self, db_path: str) -> bool:
        """
        Remove a database from the MCP registry.

        Args:
            db_path: Path to the database file

        Returns:
            True if a row was deleted, False otherwise
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "DELETE FROM mcp_databases WHERE path = ?",
                (db_path,)
            )
            return cursor.rowcount > 0

    def get_mcp_databases(self) -> List[Dict[str, Any]]:
        """
        Retrieve all registered MCP databases.

        Returns:
            List of database information dictionaries
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """SELECT id, path, name, size, tables, added_timestamp
                   FROM mcp_databases ORDER BY id"""
            )
            databases = []
            for row in cursor.fetchall():
                tables_list = json.loads(row[4]) if row[4] else []
                databases.append({
                    "id": row[0],
                    "path": row[1],
                    "name": row[2],
                    "size": row[3],
                    "tables": tables_list,
                    "added": row[5],
                })
            return databases

    # =========================================================================
    # CONVERSATION METHODS
    # =========================================================================

    def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None:
        """
        Save a conversation session.

        Args:
            name: Name/identifier for the conversation
            data: List of message dictionaries with 'prompt' and 'response' keys
        """
        timestamp = datetime.datetime.now().isoformat()
        data_json = json.dumps(data)
        with self._connection() as conn:
            conn.execute(
                """INSERT INTO conversation_sessions
                   (name, timestamp, data) VALUES (?, ?, ?)""",
                (name, timestamp, data_json)
            )

    def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]:
        """
        Load a conversation by name.

        Args:
            name: Name of the conversation to load

        Returns:
            List of messages, or None if not found
        """
        with self._connection() as conn:
            cursor = conn.execute(
                """SELECT data FROM conversation_sessions
                   WHERE name = ?
                   ORDER BY timestamp DESC LIMIT 1""",
                (name,)
            )
            result = cursor.fetchone()
            if result:
                return json.loads(result[0])
            return None

    def delete_conversation(self, name: str) -> int:
        """
        Delete a conversation by name.

        Args:
            name: Name of the conversation to delete

        Returns:
            Number of rows deleted
        """
        with self._connection() as conn:
            cursor = conn.execute(
                "DELETE FROM conversation_sessions WHERE name = ?",
                (name,)
            )
            return cursor.rowcount

    def list_conversations(self) -> List[Dict[str, Any]]:
        """
        List all saved conversations.

        Returns:
            List of conversation summaries with name, timestamp, and message_count
        """
        with self._connection() as conn:
            cursor = conn.execute("""
                SELECT name, MAX(timestamp) as last_saved, data
                FROM conversation_sessions
                GROUP BY name
                ORDER BY last_saved DESC
            """)
            conversations = []
            for row in cursor.fetchall():
                name, timestamp, data_json = row
                data = json.loads(data_json)
                conversations.append({
                    "name": name,
                    "timestamp": timestamp,
                    "message_count": len(data),
                })
            return conversations


# Global database instance for convenience
_db: Optional[Database] = None


def get_database() -> Database:
    """
    Get the global database instance.

    Returns:
        The shared Database instance
    """
    global _db
    if _db is None:
        _db = Database()
    return _db
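Because `Database` creates its tables on first use, it can be pointed at a throwaway file for experimentation instead of the standard location. A minimal sketch, assuming the `oai` package is importable (note the constructor still creates the standard `~/.config/oai` directory as a side effect):

```python
# Illustrative only: exercising the config and conversation layers on a temp file.
import tempfile
from pathlib import Path

from oai.config.database import Database

with tempfile.TemporaryDirectory() as tmp:
    db = Database(db_path=Path(tmp) / "test.db")  # tables created on first use
    db.set_config("default_model", "openai/gpt-4o")
    assert db.get_config("default_model") == "openai/gpt-4o"

    db.save_conversation("demo", [{"prompt": "hi", "response": "hello"}])
    print(db.list_conversations())  # [{'name': 'demo', ..., 'message_count': 1}]
```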
361 oai/config/settings.py Normal file
@@ -0,0 +1,361 @@
"""
Settings management for oAI.

This module provides a centralized settings class that handles all application
configuration with type safety, validation, and persistence.
"""

from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path

from oai.constants import (
    DEFAULT_BASE_URL,
    DEFAULT_STREAM_ENABLED,
    DEFAULT_MAX_TOKENS,
    DEFAULT_ONLINE_MODE,
    DEFAULT_COST_WARNING_THRESHOLD,
    DEFAULT_LOG_MAX_SIZE_MB,
    DEFAULT_LOG_BACKUP_COUNT,
    DEFAULT_LOG_LEVEL,
    DEFAULT_SYSTEM_PROMPT,
    VALID_LOG_LEVELS,
)
from oai.config.database import get_database


@dataclass
class Settings:
    """
    Application settings with persistence support.

    This class provides a clean interface for managing all configuration
    options. Settings are automatically loaded from the database on
    initialization and can be persisted back.

    Attributes:
        api_key: OpenRouter API key
        base_url: API base URL
        default_model: Default model ID to use
        default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
        stream_enabled: Whether to stream responses
        max_tokens: Maximum tokens per request
        cost_warning_threshold: Alert threshold for message cost
        default_online_mode: Whether online mode is enabled by default
        log_max_size_mb: Maximum log file size in MB
        log_backup_count: Number of log file backups to keep
        log_level: Logging level (debug/info/warning/error/critical)
    """

    api_key: Optional[str] = None
    base_url: str = DEFAULT_BASE_URL
    default_model: Optional[str] = None
    default_system_prompt: Optional[str] = None
    stream_enabled: bool = DEFAULT_STREAM_ENABLED
    max_tokens: int = DEFAULT_MAX_TOKENS
    cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD
    default_online_mode: bool = DEFAULT_ONLINE_MODE
    log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
    log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT
    log_level: str = DEFAULT_LOG_LEVEL

    @property
    def effective_system_prompt(self) -> str:
        """
        Get the effective system prompt to use.

        Returns:
            The custom prompt if set, hardcoded default if None, or blank if explicitly set to ""
        """
        if self.default_system_prompt is None:
            return DEFAULT_SYSTEM_PROMPT
        return self.default_system_prompt

    def __post_init__(self):
        """Validate settings after initialization."""
        self._validate()

    def _validate(self) -> None:
        """Validate all settings values."""
        # Validate log level
        if self.log_level.lower() not in VALID_LOG_LEVELS:
            raise ValueError(
                f"Invalid log level: {self.log_level}. "
                f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
            )

        # Validate numeric bounds
        if self.max_tokens < 1:
            raise ValueError("max_tokens must be at least 1")

        if self.cost_warning_threshold < 0:
            raise ValueError("cost_warning_threshold must be non-negative")

        if self.log_max_size_mb < 1:
            raise ValueError("log_max_size_mb must be at least 1")

        if self.log_backup_count < 0:
            raise ValueError("log_backup_count must be non-negative")

    @classmethod
    def load(cls) -> "Settings":
        """
        Load settings from the database.

        Returns:
            Settings instance with values from database
        """
        db = get_database()

        # Helper to safely parse boolean
        def parse_bool(value: Optional[str], default: bool) -> bool:
            if value is None:
                return default
            return value.lower() in ("on", "true", "1", "yes")

        # Helper to safely parse int
        def parse_int(value: Optional[str], default: int) -> int:
            if value is None:
                return default
            try:
                return int(value)
            except ValueError:
                return default

        # Helper to safely parse float
        def parse_float(value: Optional[str], default: float) -> float:
            if value is None:
                return default
            try:
                return float(value)
            except ValueError:
                return default

        # Get system prompt from DB: None means not set (use default), "" means explicitly blank
        system_prompt_value = db.get_config("default_system_prompt")

        return cls(
            api_key=db.get_config("api_key"),
            base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
            default_model=db.get_config("default_model"),
            default_system_prompt=system_prompt_value,
            stream_enabled=parse_bool(
                db.get_config("stream_enabled"),
                DEFAULT_STREAM_ENABLED
            ),
            max_tokens=parse_int(
                db.get_config("max_token"),
                DEFAULT_MAX_TOKENS
            ),
            cost_warning_threshold=parse_float(
                db.get_config("cost_warning_threshold"),
                DEFAULT_COST_WARNING_THRESHOLD
            ),
            default_online_mode=parse_bool(
                db.get_config("default_online_mode"),
                DEFAULT_ONLINE_MODE
            ),
            log_max_size_mb=parse_int(
                db.get_config("log_max_size_mb"),
                DEFAULT_LOG_MAX_SIZE_MB
            ),
            log_backup_count=parse_int(
                db.get_config("log_backup_count"),
                DEFAULT_LOG_BACKUP_COUNT
            ),
            log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL,
        )

    def save(self) -> None:
        """Persist all settings to the database."""
        db = get_database()

        # Only save API key if it exists
        if self.api_key:
            db.set_config("api_key", self.api_key)

        db.set_config("base_url", self.base_url)

        if self.default_model:
            db.set_config("default_model", self.default_model)

        # Save system prompt: None means not set (don't save), otherwise save the value (even if "")
        if self.default_system_prompt is not None:
            db.set_config("default_system_prompt", self.default_system_prompt)

        db.set_config("stream_enabled", "on" if self.stream_enabled else "off")
        db.set_config("max_token", str(self.max_tokens))
        db.set_config("cost_warning_threshold", str(self.cost_warning_threshold))
        db.set_config("default_online_mode", "on" if self.default_online_mode else "off")
        db.set_config("log_max_size_mb", str(self.log_max_size_mb))
        db.set_config("log_backup_count", str(self.log_backup_count))
        db.set_config("log_level", self.log_level)

    def set_api_key(self, api_key: str) -> None:
        """
        Set and persist the API key.

        Args:
            api_key: The new API key
        """
        self.api_key = api_key.strip()
        get_database().set_config("api_key", self.api_key)

    def set_base_url(self, url: str) -> None:
        """
        Set and persist the base URL.

        Args:
            url: The new base URL
        """
        self.base_url = url.strip()
        get_database().set_config("base_url", self.base_url)

    def set_default_model(self, model_id: str) -> None:
        """
        Set and persist the default model.

        Args:
            model_id: The model ID to set as default
        """
        self.default_model = model_id
        get_database().set_config("default_model", model_id)

    def set_default_system_prompt(self, prompt: str) -> None:
        """
        Set and persist the default system prompt.

        Args:
            prompt: The system prompt to use for all new sessions.
                Empty string "" means blank prompt (no system message).
        """
        self.default_system_prompt = prompt
        get_database().set_config("default_system_prompt", prompt)

    def clear_default_system_prompt(self) -> None:
        """
        Clear the custom system prompt and revert to hardcoded default.

        This removes the custom prompt from the database, causing the
        application to use the built-in DEFAULT_SYSTEM_PROMPT.
        """
        self.default_system_prompt = None
        # Remove from database to indicate "not set"
        db = get_database()
        with db._connection() as conn:
            conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",))
            conn.commit()

    def set_stream_enabled(self, enabled: bool) -> None:
        """
        Set and persist the streaming preference.

        Args:
            enabled: Whether to enable streaming
        """
        self.stream_enabled = enabled
        get_database().set_config("stream_enabled", "on" if enabled else "off")

    def set_max_tokens(self, max_tokens: int) -> None:
        """
        Set and persist the maximum tokens.

        Args:
            max_tokens: Maximum number of tokens

        Raises:
            ValueError: If max_tokens is less than 1
        """
        if max_tokens < 1:
            raise ValueError("max_tokens must be at least 1")
        self.max_tokens = max_tokens
        get_database().set_config("max_token", str(max_tokens))

    def set_cost_warning_threshold(self, threshold: float) -> None:
        """
        Set and persist the cost warning threshold.

        Args:
            threshold: Cost threshold in USD

        Raises:
            ValueError: If threshold is negative
        """
        if threshold < 0:
            raise ValueError("cost_warning_threshold must be non-negative")
        self.cost_warning_threshold = threshold
        get_database().set_config("cost_warning_threshold", str(threshold))

    def set_default_online_mode(self, enabled: bool) -> None:
        """
        Set and persist the default online mode.

        Args:
            enabled: Whether online mode should be enabled by default
        """
        self.default_online_mode = enabled
        get_database().set_config("default_online_mode", "on" if enabled else "off")

    def set_log_level(self, level: str) -> None:
        """
        Set and persist the log level.

        Args:
            level: The log level (debug/info/warning/error/critical)

        Raises:
            ValueError: If level is not valid
        """
        level_lower = level.lower()
        if level_lower not in VALID_LOG_LEVELS:
            raise ValueError(
                f"Invalid log level: {level}. "
                f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
            )
        self.log_level = level_lower
        get_database().set_config("log_level", level_lower)

    def set_log_max_size(self, size_mb: int) -> None:
        """
        Set and persist the maximum log file size.

        Args:
            size_mb: Maximum size in megabytes

        Raises:
            ValueError: If size_mb is less than 1
        """
        if size_mb < 1:
            raise ValueError("log_max_size_mb must be at least 1")
        # Cap at 100 MB for safety
        self.log_max_size_mb = min(size_mb, 100)
        get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))


# Global settings instance
_settings: Optional[Settings] = None


def get_settings() -> Settings:
    """
    Get the global settings instance.

    Returns:
        The shared Settings instance, loading from database if needed
    """
    global _settings
    if _settings is None:
        _settings = Settings.load()
    return _settings


def reload_settings() -> Settings:
    """
    Force reload settings from the database.

    Returns:
        Fresh Settings instance
    """
    global _settings
    _settings = Settings.load()
    return _settings
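The three-way system-prompt state (None for the hardcoded default, "" for an explicitly blank prompt, anything else as a custom prompt) is easiest to see end to end. A minimal sketch, assuming the `oai` package is importable; note these setters persist to your real oAI config database, so treat this as illustrative:

```python
# Illustrative only: None / "" / custom states of the default system prompt.
from oai.config.settings import Settings

settings = Settings.load()

settings.set_default_system_prompt("Answer in one sentence.")
print(settings.effective_system_prompt)        # -> "Answer in one sentence."

settings.set_default_system_prompt("")          # explicitly blank: no system message
print(repr(settings.effective_system_prompt))   # -> ''

settings.clear_default_system_prompt()          # back to the hardcoded default
print(settings.effective_system_prompt[:40], "...")
```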
448 oai/constants.py Normal file
@@ -0,0 +1,448 @@
"""
Application-wide constants for oAI.

This module contains all configuration constants, default values, and static
definitions used throughout the application. Centralizing these values makes
the codebase easier to maintain and configure.
"""

from pathlib import Path
from typing import Set, Dict, Any
import logging

# =============================================================================
# APPLICATION METADATA
# =============================================================================

APP_NAME = "oAI"
APP_VERSION = "2.1.0"
APP_URL = "https://iurl.no/oai"
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"

# =============================================================================
# FILE PATHS
# =============================================================================

HOME_DIR = Path.home()
CONFIG_DIR = HOME_DIR / ".config" / "oai"
CACHE_DIR = HOME_DIR / ".cache" / "oai"
HISTORY_FILE = CONFIG_DIR / "history.txt"
DATABASE_FILE = CONFIG_DIR / "oai_config.db"
LOG_FILE = CONFIG_DIR / "oai.log"

# =============================================================================
# API CONFIGURATION
# =============================================================================

DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
DEFAULT_STREAM_ENABLED = True
DEFAULT_MAX_TOKENS = 100_000
DEFAULT_ONLINE_MODE = False

# =============================================================================
# DEFAULT SYSTEM PROMPT
# =============================================================================

DEFAULT_SYSTEM_PROMPT = (
    "You are a knowledgeable and helpful AI assistant. Provide clear, accurate, "
    "and well-structured responses. Be concise yet thorough. When uncertain about "
    "something, acknowledge your limitations. For technical topics, include relevant "
    "details and examples when helpful."
)

# =============================================================================
# PRICING DEFAULTS (per million tokens)
# =============================================================================

DEFAULT_INPUT_PRICE = 3.0
DEFAULT_OUTPUT_PRICE = 15.0

MODEL_PRICING: Dict[str, float] = {
    "input": DEFAULT_INPUT_PRICE,
    "output": DEFAULT_OUTPUT_PRICE,
}

# =============================================================================
# CREDIT ALERTS
# =============================================================================

LOW_CREDIT_RATIO = 0.1  # Alert when credits < 10% of total
LOW_CREDIT_AMOUNT = 1.0  # Alert when credits < $1.00
DEFAULT_COST_WARNING_THRESHOLD = 0.01  # Alert when single message cost exceeds this
COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD  # Alias for convenience

# =============================================================================
# LOGGING CONFIGURATION
# =============================================================================

DEFAULT_LOG_MAX_SIZE_MB = 10
DEFAULT_LOG_BACKUP_COUNT = 2
DEFAULT_LOG_LEVEL = "info"

VALID_LOG_LEVELS: Dict[str, int] = {
    "debug": logging.DEBUG,
    "info": logging.INFO,
    "warning": logging.WARNING,
    "error": logging.ERROR,
    "critical": logging.CRITICAL,
}

# =============================================================================
# FILE HANDLING
# =============================================================================

# Maximum file size for reading (10 MB)
MAX_FILE_SIZE = 10 * 1024 * 1024

# Content truncation threshold (50 KB)
CONTENT_TRUNCATION_THRESHOLD = 50 * 1024

# Maximum items in directory listing
MAX_LIST_ITEMS = 1000

# Supported code file extensions for syntax highlighting
SUPPORTED_CODE_EXTENSIONS: Set[str] = {
    ".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp",
    ".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go",
    ".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart",
    ".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt",
}

# All allowed file extensions for attachment
ALLOWED_FILE_EXTENSIONS: Set[str] = {
    # Code files
    ".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx",
    ".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php",
    ".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1",
    # Data files
    ".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3",
    # Documents
    ".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties",
    # Images
    ".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico",
    # Archives
    ".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz",
    # Config files
    ".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc",
    ".prettierrc", ".babelrc", ".nvmrc", ".npmrc",
    # Binary/Compiled
    ".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app",
    ".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa",
    # ML/AI
    ".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx",
    ".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn",
    # Data formats
    ".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet",
    ".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds",
    # Other
    ".pdf", ".class", ".jar", ".war",
}

# =============================================================================
# SECURITY CONFIGURATION
# =============================================================================

# System directories that should never be accessed
SYSTEM_DIRS_BLACKLIST: Set[str] = {
    # macOS
    "/System", "/Library", "/private", "/usr", "/bin", "/sbin",
    # Linux
    "/boot", "/dev", "/proc", "/sys", "/root",
    # Windows
    "C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)",
}

# Directories to skip during file operations
SKIP_DIRECTORIES: Set[str] = {
    # Python virtual environments
    ".venv", "venv", "env", "virtualenv",
    "site-packages", "dist-packages",
    # Python caches
    "__pycache__", ".pytest_cache", ".mypy_cache",
    # JavaScript/Node
    "node_modules",
    # Version control
    ".git", ".svn",
    # IDEs
    ".idea", ".vscode",
    # Build directories
    "build", "dist", "eggs", ".eggs",
}

# =============================================================================
# DATABASE QUERIES - SQL SAFETY
# =============================================================================

# Maximum query execution timeout (seconds)
MAX_QUERY_TIMEOUT = 5

# Maximum rows returned from queries
MAX_QUERY_RESULTS = 1000

# Default rows per query
DEFAULT_QUERY_LIMIT = 100

# Keywords that are blocked in database queries
DANGEROUS_SQL_KEYWORDS: Set[str] = {
    "INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
    "ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH",
    "PRAGMA", "VACUUM", "REINDEX",
}

# =============================================================================
# MCP CONFIGURATION
# =============================================================================

# Maximum tool call iterations per request
MAX_TOOL_LOOPS = 5

# =============================================================================
# VALID COMMANDS
# =============================================================================

VALID_COMMANDS: Set[str] = {
    "/retry", "/online", "/memory", "/paste", "/export", "/save", "/load",
    "/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset",
    "/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear",
    "/cl", "/help", "/mcp",
}

# =============================================================================
# COMMAND HELP DATABASE
# =============================================================================

COMMAND_HELP: Dict[str, Dict[str, Any]] = {
    "/clear": {
        "aliases": ["/cl"],
        "description": "Clear the terminal screen for a clean interface.",
        "usage": "/clear\n/cl",
        "examples": [
            ("Clear screen", "/clear"),
            ("Using short alias", "/cl"),
        ],
        "notes": "You can also use the keyboard shortcut Ctrl+L.",
    },
    "/help": {
        "description": "Display help information for commands.",
        "usage": "/help [command|topic]",
        "examples": [
            ("Show all commands", "/help"),
            ("Get help for a specific command", "/help /model"),
            ("Get detailed MCP help", "/help mcp"),
        ],
        "notes": "Use /help without arguments to see the full command list.",
    },
    "mcp": {
        "description": "Complete guide to MCP (Model Context Protocol).",
        "usage": "See detailed examples below",
        "examples": [],
        "notes": """
MCP (Model Context Protocol) gives your AI assistant direct access to:
• Local files and folders (read, search, list)
• SQLite databases (inspect, search, query)

FILE MODE (default):
  /mcp on                    Start MCP server
  /mcp add ~/Documents       Grant access to folder
  /mcp list                  View all allowed folders

DATABASE MODE:
  /mcp add db ~/app/data.db  Add specific database
  /mcp db list               View all databases
  /mcp db 1                  Work with database #1
  /mcp files                 Switch back to file mode

WRITE MODE (optional):
  /mcp write on              Enable file modifications
  /mcp write off             Disable write mode (back to read-only)

For command-specific help: /help /mcp
""",
    },
    "/mcp": {
        "description": "Manage MCP for local file access and SQLite database querying.",
        "usage": "/mcp <command> [args]",
        "examples": [
            ("Enable MCP server", "/mcp on"),
            ("Disable MCP server", "/mcp off"),
            ("Show MCP status", "/mcp status"),
            ("", ""),
            ("━━━ FILE MODE ━━━", ""),
            ("Add folder for file access", "/mcp add ~/Documents"),
            ("Remove folder", "/mcp remove ~/Desktop"),
            ("List allowed folders", "/mcp list"),
            ("Enable write mode", "/mcp write on"),
            ("", ""),
            ("━━━ DATABASE MODE ━━━", ""),
            ("Add SQLite database", "/mcp add db ~/app/data.db"),
            ("List all databases", "/mcp db list"),
            ("Switch to database #1", "/mcp db 1"),
            ("Switch back to file mode", "/mcp files"),
        ],
        "notes": "MCP allows AI to read local files and query SQLite databases.",
    },
    "/memory": {
        "description": "Toggle conversation memory.",
        "usage": "/memory [on|off]",
        "examples": [
            ("Check current memory status", "/memory"),
            ("Enable conversation memory", "/memory on"),
            ("Disable memory (save costs)", "/memory off"),
        ],
        "notes": "Memory is ON by default. Disabling saves tokens.",
    },
    "/online": {
        "description": "Enable or disable online mode (web search).",
        "usage": "/online [on|off]",
        "examples": [
            ("Check online mode status", "/online"),
            ("Enable web search", "/online on"),
            ("Disable web search", "/online off"),
        ],
        "notes": "Not all models support online mode.",
    },
    "/paste": {
        "description": "Paste plain text from clipboard and send to the AI.",
        "usage": "/paste [prompt]",
        "examples": [
            ("Paste clipboard content", "/paste"),
            ("Paste with a question", "/paste Explain this code"),
        ],
        "notes": "Only plain text is supported.",
    },
    "/retry": {
        "description": "Resend the last prompt from conversation history.",
        "usage": "/retry",
        "examples": [("Retry last message", "/retry")],
        "notes": "Requires at least one message in history.",
    },
    "/next": {
        "description": "View the next response in conversation history.",
        "usage": "/next",
        "examples": [("Navigate to next response", "/next")],
        "notes": "Use /prev to go backward.",
    },
    "/prev": {
        "description": "View the previous response in conversation history.",
        "usage": "/prev",
        "examples": [("Navigate to previous response", "/prev")],
        "notes": "Use /next to go forward.",
    },
    "/reset": {
        "description": "Clear conversation history and reset system prompt.",
        "usage": "/reset",
        "examples": [("Reset conversation", "/reset")],
        "notes": "Requires confirmation.",
    },
    "/info": {
        "description": "Display detailed information about a model.",
        "usage": "/info [model_id]",
        "examples": [
            ("Show current model info", "/info"),
|
||||||
|
("Show specific model info", "/info gpt-4o"),
|
||||||
|
],
|
||||||
|
"notes": "Shows pricing, capabilities, and context length.",
|
||||||
|
},
|
||||||
|
"/model": {
|
||||||
|
"description": "Select or change the AI model.",
|
||||||
|
"usage": "/model [search_term]",
|
||||||
|
"examples": [
|
||||||
|
("List all models", "/model"),
|
||||||
|
("Search for GPT models", "/model gpt"),
|
||||||
|
("Search for Claude models", "/model claude"),
|
||||||
|
],
|
||||||
|
"notes": "Models are numbered for easy selection.",
|
||||||
|
},
|
||||||
|
"/config": {
|
||||||
|
"description": "View or modify application configuration.",
|
||||||
|
"usage": "/config [setting] [value]",
|
||||||
|
"examples": [
|
||||||
|
("View all settings", "/config"),
|
||||||
|
("Set API key", "/config api"),
|
||||||
|
("Set default model", "/config model"),
|
||||||
|
("Set system prompt", "/config system You are a helpful assistant"),
|
||||||
|
("Enable streaming", "/config stream on"),
|
||||||
|
],
|
||||||
|
"notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.",
|
||||||
|
},
|
||||||
|
"/maxtoken": {
|
||||||
|
"description": "Set a temporary session token limit.",
|
||||||
|
"usage": "/maxtoken [value]",
|
||||||
|
"examples": [
|
||||||
|
("View current session limit", "/maxtoken"),
|
||||||
|
("Set session limit to 2000", "/maxtoken 2000"),
|
||||||
|
],
|
||||||
|
"notes": "Cannot exceed stored max token limit.",
|
||||||
|
},
|
||||||
|
"/system": {
|
||||||
|
"description": "Set or clear the session-level system prompt.",
|
||||||
|
"usage": "/system [prompt|clear|default <prompt>]",
|
||||||
|
"examples": [
|
||||||
|
("View current system prompt", "/system"),
|
||||||
|
("Set as Python expert", "/system You are a Python expert"),
|
||||||
|
("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."),
|
||||||
|
("Save as default", "/system default You are a helpful assistant"),
|
||||||
|
("Revert to default", "/system clear"),
|
||||||
|
("Blank prompt", '/system ""'),
|
||||||
|
],
|
||||||
|
"notes": r"Use \n for newlines. /system clear reverts to hardcoded default.",
|
||||||
|
},
|
||||||
|
"/save": {
|
||||||
|
"description": "Save the current conversation history.",
|
||||||
|
"usage": "/save <name>",
|
||||||
|
"examples": [("Save conversation", "/save my_chat")],
|
||||||
|
"notes": "Saved conversations can be loaded later with /load.",
|
||||||
|
},
|
||||||
|
"/load": {
|
||||||
|
"description": "Load a saved conversation.",
|
||||||
|
"usage": "/load <name|number>",
|
||||||
|
"examples": [
|
||||||
|
("Load by name", "/load my_chat"),
|
||||||
|
("Load by number from /list", "/load 3"),
|
||||||
|
],
|
||||||
|
"notes": "Use /list to see numbered conversations.",
|
||||||
|
},
|
||||||
|
"/delete": {
|
||||||
|
"description": "Delete a saved conversation.",
|
||||||
|
"usage": "/delete <name|number>",
|
||||||
|
"examples": [("Delete by name", "/delete my_chat")],
|
||||||
|
"notes": "Requires confirmation. Cannot be undone.",
|
||||||
|
},
|
||||||
|
"/list": {
|
||||||
|
"description": "List all saved conversations.",
|
||||||
|
"usage": "/list",
|
||||||
|
"examples": [("Show saved conversations", "/list")],
|
||||||
|
"notes": "Conversations are numbered for use with /load and /delete.",
|
||||||
|
},
|
||||||
|
"/export": {
|
||||||
|
"description": "Export the current conversation to a file.",
|
||||||
|
"usage": "/export <format> <filename>",
|
||||||
|
"examples": [
|
||||||
|
("Export as Markdown", "/export md notes.md"),
|
||||||
|
("Export as JSON", "/export json conversation.json"),
|
||||||
|
("Export as HTML", "/export html report.html"),
|
||||||
|
],
|
||||||
|
"notes": "Available formats: md, json, html.",
|
||||||
|
},
|
||||||
|
"/stats": {
|
||||||
|
"description": "Display session statistics.",
|
||||||
|
"usage": "/stats",
|
||||||
|
"examples": [("View session statistics", "/stats")],
|
||||||
|
"notes": "Shows tokens, costs, and credits.",
|
||||||
|
},
|
||||||
|
"/credits": {
|
||||||
|
"description": "Display your OpenRouter account credits.",
|
||||||
|
"usage": "/credits",
|
||||||
|
"examples": [("Check credits", "/credits")],
|
||||||
|
"notes": "Shows total, used, and remaining credits.",
|
||||||
|
},
|
||||||
|
"/middleout": {
|
||||||
|
"description": "Toggle middle-out transform for long prompts.",
|
||||||
|
"usage": "/middleout [on|off]",
|
||||||
|
"examples": [
|
||||||
|
("Check status", "/middleout"),
|
||||||
|
("Enable compression", "/middleout on"),
|
||||||
|
],
|
||||||
|
"notes": "Compresses prompts exceeding context size.",
|
||||||
|
},
|
||||||
|
}
|
||||||
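These query limits and the keyword blocklist are what the read-only guarantee rests on. As a minimal sketch of how they might be applied (the real check lives in SQLiteQueryValidator, whose diff is not shown here; DANGEROUS_SQL_KEYWORDS is assumed to be in scope):

def is_query_safe(sql: str) -> bool:
    # Uppercase each whitespace-separated token, stripping punctuation,
    # then reject the query if any token is on the blocklist.
    tokens = {tok.strip("();,").upper() for tok in sql.split()}
    return not (tokens & DANGEROUS_SQL_KEYWORDS)

assert is_query_safe("SELECT name FROM users LIMIT 100")
assert not is_query_safe("DROP TABLE users;")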
14
oai/core/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""
Core functionality for oAI.

This module provides the main session management and AI client
classes that power the chat application.
"""

from oai.core.session import ChatSession
from oai.core.client import AIClient

__all__ = [
    "ChatSession",
    "AIClient",
]
422
oai/core/client.py
Normal file
@@ -0,0 +1,422 @@
"""
AI Client for oAI.

This module provides a high-level client for interacting with AI models
through the provider abstraction layer.
"""

import asyncio
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Union

from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
from oai.providers.base import (
    AIProvider,
    ChatMessage,
    ChatResponse,
    ModelInfo,
    StreamChunk,
    ToolCall,
    UsageStats,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.utils.logging import get_logger


class AIClient:
    """
    High-level AI client for chat interactions.

    Provides a simplified interface for sending chat requests,
    handling streaming, and managing tool calls.

    Attributes:
        provider: The underlying AI provider
        default_model: Default model ID to use
        http_headers: Custom HTTP headers for requests
    """

    def __init__(
        self,
        api_key: str,
        base_url: Optional[str] = None,
        provider_class: type = OpenRouterProvider,
        app_name: str = APP_NAME,
        app_url: str = APP_URL,
    ):
        """
        Initialize the AI client.

        Args:
            api_key: API key for authentication
            base_url: Optional custom base URL
            provider_class: Provider class to use (default: OpenRouterProvider)
            app_name: Application name for headers
            app_url: Application URL for headers
        """
        self.provider: AIProvider = provider_class(
            api_key=api_key,
            base_url=base_url,
            app_name=app_name,
            app_url=app_url,
        )
        self.default_model: Optional[str] = None
        self.logger = get_logger()

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Get available models.

        Args:
            filter_text_only: Whether to exclude video-only models

        Returns:
            List of ModelInfo objects
        """
        return self.provider.list_models(filter_text_only=filter_text_only)

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: Model identifier

        Returns:
            ModelInfo or None if not found
        """
        return self.provider.get_model(model_id)

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for provider-specific fields.

        Args:
            model_id: Model identifier

        Returns:
            Raw model dictionary or None
        """
        if hasattr(self.provider, "get_raw_model"):
            return self.provider.get_raw_model(model_id)
        return None

    def chat(
        self,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        system_prompt: Optional[str] = None,
        online: bool = False,
        transforms: Optional[List[str]] = None,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat request.

        Args:
            messages: List of message dictionaries
            model: Model ID (uses default if not specified)
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: Tool definitions for function calling
            tool_choice: Tool selection mode
            system_prompt: System prompt to prepend
            online: Whether to enable online mode
            transforms: List of transforms (e.g., ["middle-out"])

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming

        Raises:
            ValueError: If no model specified and no default set
        """
        model_id = model or self.default_model
        if not model_id:
            raise ValueError("No model specified and no default set")

        # Apply online mode suffix
        if online and hasattr(self.provider, "get_effective_model_id"):
            model_id = self.provider.get_effective_model_id(model_id, True)

        # Convert dict messages to ChatMessage objects
        chat_messages = []

        # Add system prompt if provided
        if system_prompt:
            chat_messages.append(ChatMessage(role="system", content=system_prompt))

        # Convert message dicts
        for msg in messages:
            # Convert tool_calls dicts to ToolCall objects if present
            tool_calls_data = msg.get("tool_calls")
            tool_calls_obj = None
            if tool_calls_data:
                from oai.providers.base import ToolCall, ToolFunction
                tool_calls_obj = []
                for tc in tool_calls_data:
                    # Handle both ToolCall objects and dicts
                    if isinstance(tc, ToolCall):
                        tool_calls_obj.append(tc)
                    elif isinstance(tc, dict):
                        func_data = tc.get("function", {})
                        tool_calls_obj.append(
                            ToolCall(
                                id=tc.get("id", ""),
                                type=tc.get("type", "function"),
                                function=ToolFunction(
                                    name=func_data.get("name", ""),
                                    arguments=func_data.get("arguments", "{}"),
                                ),
                            )
                        )

            chat_messages.append(
                ChatMessage(
                    role=msg.get("role", "user"),
                    content=msg.get("content"),
                    tool_calls=tool_calls_obj,
                    tool_call_id=msg.get("tool_call_id"),
                )
            )

        self.logger.debug(
            f"Sending chat request: model={model_id}, "
            f"messages={len(chat_messages)}, stream={stream}"
        )

        return self.provider.chat(
            model=model_id,
            messages=chat_messages,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            transforms=transforms,
        )

    def chat_with_tools(
        self,
        messages: List[Dict[str, Any]],
        tools: List[Dict[str, Any]],
        tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]],
        model: Optional[str] = None,
        max_loops: int = 5,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        on_tool_call: Optional[Callable[[ToolCall], None]] = None,
        on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None,
    ) -> ChatResponse:
        """
        Send a chat request with automatic tool call handling.

        Executes tool calls returned by the model and continues
        the conversation until no more tool calls are requested.

        Args:
            messages: Initial messages
            tools: Tool definitions
            tool_executor: Function to execute tool calls
            model: Model ID
            max_loops: Maximum tool call iterations
            max_tokens: Maximum response tokens
            system_prompt: System prompt
            on_tool_call: Callback when tool is called
            on_tool_result: Callback when tool returns result

        Returns:
            Final ChatResponse after all tool calls complete
        """
        model_id = model or self.default_model
        if not model_id:
            raise ValueError("No model specified and no default set")

        # Build initial messages
        chat_messages = []
        if system_prompt:
            chat_messages.append({"role": "system", "content": system_prompt})
        chat_messages.extend(messages)

        loop_count = 0
        current_response: Optional[ChatResponse] = None

        while loop_count < max_loops:
            # Send request
            response = self.chat(
                messages=chat_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
            )

            if not isinstance(response, ChatResponse):
                raise ValueError("Expected non-streaming response")

            current_response = response

            # Check for tool calls
            tool_calls = response.tool_calls
            if not tool_calls:
                break

            self.logger.info(f"Model requested {len(tool_calls)} tool call(s)")

            # Process each tool call
            tool_results = []
            for tc in tool_calls:
                if on_tool_call:
                    on_tool_call(tc)

                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    result = {"error": f"Invalid arguments: {e}"}
                else:
                    result = tool_executor(tc.function.name, args)

                if on_tool_result:
                    on_tool_result(tc.function.name, result)

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls
            assistant_msg = {
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            }
            chat_messages.append(assistant_msg)
            chat_messages.extend(tool_results)

            loop_count += 1

        if loop_count >= max_loops:
            self.logger.warning(f"Reached max tool call loops ({max_loops})")

        return current_response

    def stream_chat(
        self,
        messages: List[Dict[str, Any]],
        model: Optional[str] = None,
        max_tokens: Optional[int] = None,
        system_prompt: Optional[str] = None,
        online: bool = False,
        on_chunk: Optional[Callable[[StreamChunk], None]] = None,
    ) -> tuple[str, Optional[UsageStats]]:
        """
        Stream a chat response and collect the full text.

        Args:
            messages: Chat messages
            model: Model ID
            max_tokens: Maximum tokens
            system_prompt: System prompt
            online: Online mode
            on_chunk: Optional callback for each chunk

        Returns:
            Tuple of (full_response_text, usage_stats)
        """
        response = self.chat(
            messages=messages,
            model=model,
            stream=True,
            max_tokens=max_tokens,
            system_prompt=system_prompt,
            online=online,
        )

        if isinstance(response, ChatResponse):
            # Not actually streaming
            return response.content or "", response.usage

        full_text = ""
        usage: Optional[UsageStats] = None

        for chunk in response:
            if chunk.error:
                self.logger.error(f"Stream error: {chunk.error}")
                break

            if chunk.delta_content:
                full_text += chunk.delta_content
                if on_chunk:
                    on_chunk(chunk)

            if chunk.usage:
                usage = chunk.usage

        return full_text, usage

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credit information.

        Returns:
            Credit info dict or None if unavailable
        """
        return self.provider.get_credits()

    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate cost for a completion.

        Args:
            model_id: Model ID
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        if hasattr(self.provider, "estimate_cost"):
            return self.provider.estimate_cost(model_id, input_tokens, output_tokens)

        # Fallback to default pricing
        input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
        output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
        return input_cost + output_cost

    def set_default_model(self, model_id: str) -> None:
        """
        Set the default model.

        Args:
            model_id: Model ID to use as default
        """
        self.default_model = model_id
        self.logger.info(f"Default model set to: {model_id}")

    def clear_cache(self) -> None:
        """Clear the provider's model cache."""
        if hasattr(self.provider, "clear_cache"):
            self.provider.clear_cache()
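A minimal usage sketch of the tool-calling loop above. The API key, model ID, and the echo tool are placeholders, not values from this repo:

from oai.core.client import AIClient

def echo_executor(name: str, args: dict) -> dict:
    # Toy executor; a real one would dispatch to MCP tools.
    return {"echo": args}

client = AIClient(api_key="sk-or-...")          # placeholder key
client.set_default_model("openai/gpt-4o-mini")  # placeholder model

response = client.chat_with_tools(
    messages=[{"role": "user", "content": "Greet me via the echo tool."}],
    tools=[{
        "type": "function",
        "function": {
            "name": "echo",
            "description": "Echo back the given arguments.",
            "parameters": {
                "type": "object",
                "properties": {"text": {"type": "string"}},
            },
        },
    }],
    tool_executor=echo_executor,
)
print(response.content)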
659
oai/core/session.py
Normal file
@@ -0,0 +1,659 @@
"""
Chat session management for oAI.

This module provides the ChatSession class that manages an interactive
chat session including history, state, and message handling.
"""

import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple

from rich.live import Live
from rich.markdown import Markdown

from oai.commands.registry import CommandContext, CommandResult, registry
from oai.config.database import Database
from oai.config.settings import Settings
from oai.constants import (
    COST_WARNING_THRESHOLD,
    LOW_CREDIT_AMOUNT,
    LOW_CREDIT_RATIO,
)
from oai.core.client import AIClient
from oai.mcp.manager import MCPManager
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
from oai.ui.console import (
    console,
    display_markdown,
    display_panel,
    print_error,
    print_info,
    print_success,
    print_warning,
)
from oai.ui.prompts import prompt_copy_response
from oai.utils.logging import get_logger


@dataclass
class SessionStats:
    """
    Statistics for the current session.

    Tracks tokens, costs, and message counts.
    """

    total_input_tokens: int = 0
    total_output_tokens: int = 0
    total_cost: float = 0.0
    message_count: int = 0

    @property
    def total_tokens(self) -> int:
        """Get total token count."""
        return self.total_input_tokens + self.total_output_tokens

    def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None:
        """
        Add usage stats from a response.

        Args:
            usage: Usage statistics
            cost: Cost if not in usage
        """
        if usage:
            self.total_input_tokens += usage.prompt_tokens
            self.total_output_tokens += usage.completion_tokens
            if usage.total_cost_usd:
                self.total_cost += usage.total_cost_usd
            else:
                self.total_cost += cost
        else:
            self.total_cost += cost
        self.message_count += 1


@dataclass
class HistoryEntry:
    """
    A single entry in the conversation history.

    Stores the user prompt, assistant response, and metrics.
    """

    prompt: str
    response: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
    msg_cost: float = 0.0
    timestamp: Optional[float] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format."""
        return {
            "prompt": self.prompt,
            "response": self.response,
            "prompt_tokens": self.prompt_tokens,
            "completion_tokens": self.completion_tokens,
            "msg_cost": self.msg_cost,
        }


class ChatSession:
    """
    Manages an interactive chat session.

    Handles conversation history, state management, command processing,
    and communication with the AI client.

    Attributes:
        client: AI client for API requests
        settings: Application settings
        mcp_manager: MCP manager for file/database access
        history: Conversation history
        stats: Session statistics
    """

    def __init__(
        self,
        client: AIClient,
        settings: Settings,
        mcp_manager: Optional[MCPManager] = None,
    ):
        """
        Initialize a chat session.

        Args:
            client: AI client instance
            settings: Application settings
            mcp_manager: Optional MCP manager
        """
        self.client = client
        self.settings = settings
        self.mcp_manager = mcp_manager
        self.db = Database()

        self.history: List[HistoryEntry] = []
        self.stats = SessionStats()

        # Session state
        self.system_prompt: str = settings.effective_system_prompt
        self.memory_enabled: bool = True
        self.memory_start_index: int = 0
        self.online_enabled: bool = settings.default_online_mode
        self.middle_out_enabled: bool = False
        self.session_max_token: int = 0
        self.current_index: int = 0

        # Selected model
        self.selected_model: Optional[Dict[str, Any]] = None

        self.logger = get_logger()

    def get_context(self) -> CommandContext:
        """
        Get the current command context.

        Returns:
            CommandContext with current session state
        """
        return CommandContext(
            settings=self.settings,
            provider=self.client.provider,
            mcp_manager=self.mcp_manager,
            selected_model_raw=self.selected_model,
            session_history=[e.to_dict() for e in self.history],
            session_system_prompt=self.system_prompt,
            memory_enabled=self.memory_enabled,
            memory_start_index=self.memory_start_index,
            online_enabled=self.online_enabled,
            middle_out_enabled=self.middle_out_enabled,
            session_max_token=self.session_max_token,
            total_input_tokens=self.stats.total_input_tokens,
            total_output_tokens=self.stats.total_output_tokens,
            total_cost=self.stats.total_cost,
            message_count=self.stats.message_count,
            current_index=self.current_index,
        )

    def set_model(self, model: Dict[str, Any]) -> None:
        """
        Set the selected model.

        Args:
            model: Raw model dictionary
        """
        self.selected_model = model
        self.client.set_default_model(model["id"])
        self.logger.info(f"Model selected: {model['id']}")

    def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]:
        """
        Build the messages array for an API request.

        Includes system prompt, history (if memory enabled), and current input.

        Args:
            user_input: Current user input

        Returns:
            List of message dictionaries
        """
        messages = []

        # Add system prompt
        if self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})

        # Add database context if in database mode
        if self.mcp_manager and self.mcp_manager.enabled:
            if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None:
                db = self.mcp_manager.databases[self.mcp_manager.selected_db_index]
                db_context = (
                    f"You are connected to SQLite database: {db['name']}\n"
                    f"Available tables: {', '.join(db['tables'])}\n\n"
                    "Use inspect_database, search_database, or query_database tools. "
                    "All queries are read-only."
                )
                messages.append({"role": "system", "content": db_context})

        # Add history if memory enabled
        if self.memory_enabled:
            for i in range(self.memory_start_index, len(self.history)):
                entry = self.history[i]
                messages.append({"role": "user", "content": entry.prompt})
                messages.append({"role": "assistant", "content": entry.response})

        # Add current message
        messages.append({"role": "user", "content": user_input})

        return messages

    def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]:
        """
        Get MCP tool definitions if available.

        Returns:
            List of tool schemas or None
        """
        if not self.mcp_manager or not self.mcp_manager.enabled:
            return None

        if not self.selected_model:
            return None

        # Check if model supports tools
        supported_params = self.selected_model.get("supported_parameters", [])
        if "tools" not in supported_params and "functions" not in supported_params:
            return None

        return self.mcp_manager.get_tools_schema()

    async def execute_tool(
        self,
        tool_name: str,
        tool_args: Dict[str, Any],
    ) -> Dict[str, Any]:
        """
        Execute an MCP tool.

        Args:
            tool_name: Name of the tool
            tool_args: Tool arguments

        Returns:
            Tool execution result
        """
        if not self.mcp_manager:
            return {"error": "MCP not available"}

        return await self.mcp_manager.call_tool(tool_name, **tool_args)

    def send_message(
        self,
        user_input: str,
        stream: bool = True,
        on_stream_chunk: Optional[Callable[[str], None]] = None,
    ) -> Tuple[str, Optional[UsageStats], float]:
        """
        Send a message and get a response.

        Args:
            user_input: User's input text
            stream: Whether to stream the response
            on_stream_chunk: Callback for stream chunks

        Returns:
            Tuple of (response_text, usage_stats, response_time)
        """
        if not self.selected_model:
            raise ValueError("No model selected")

        start_time = time.time()
        messages = self.build_api_messages(user_input)

        # Get MCP tools
        tools = self.get_mcp_tools()
        if tools:
            # Disable streaming when tools are present
            stream = False

        # Build request parameters
        model_id = self.selected_model["id"]
        if self.online_enabled:
            if hasattr(self.client.provider, "get_effective_model_id"):
                model_id = self.client.provider.get_effective_model_id(model_id, True)

        transforms = ["middle-out"] if self.middle_out_enabled else None

        max_tokens = None
        if self.session_max_token > 0:
            max_tokens = self.session_max_token

        if tools:
            # Use tool handling flow
            response = self._send_with_tools(
                messages=messages,
                model_id=model_id,
                tools=tools,
                max_tokens=max_tokens,
                transforms=transforms,
            )
            response_time = time.time() - start_time
            return response.content or "", response.usage, response_time

        elif stream:
            # Use streaming flow
            full_text, usage = self._stream_response(
                messages=messages,
                model_id=model_id,
                max_tokens=max_tokens,
                transforms=transforms,
                on_chunk=on_stream_chunk,
            )
            response_time = time.time() - start_time
            return full_text, usage, response_time

        else:
            # Non-streaming request
            response = self.client.chat(
                messages=messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                transforms=transforms,
            )
            response_time = time.time() - start_time

            if isinstance(response, ChatResponse):
                return response.content or "", response.usage, response_time
            else:
                return "", None, response_time

    def _send_with_tools(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        tools: List[Dict[str, Any]],
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
    ) -> ChatResponse:
        """
        Send a request with tool call handling.

        Args:
            messages: API messages
            model_id: Model ID
            tools: Tool definitions
            max_tokens: Max tokens
            transforms: Transforms list

        Returns:
            Final ChatResponse
        """
        max_loops = 5
        loop_count = 0
        api_messages = list(messages)

        while loop_count < max_loops:
            response = self.client.chat(
                messages=api_messages,
                model=model_id,
                stream=False,
                max_tokens=max_tokens,
                tools=tools,
                tool_choice="auto",
                transforms=transforms,
            )

            if not isinstance(response, ChatResponse):
                raise ValueError("Expected ChatResponse")

            tool_calls = response.tool_calls
            if not tool_calls:
                return response

            console.print(f"\n[dim yellow]🔧 AI requesting {len(tool_calls)} tool call(s)...[/]")

            tool_results = []
            for tc in tool_calls:
                try:
                    args = json.loads(tc.function.arguments)
                except json.JSONDecodeError as e:
                    self.logger.error(f"Failed to parse tool arguments: {e}")
                    tool_results.append({
                        "tool_call_id": tc.id,
                        "role": "tool",
                        "name": tc.function.name,
                        "content": json.dumps({"error": f"Invalid arguments: {e}"}),
                    })
                    continue

                # Display tool call
                args_display = ", ".join(
                    f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
                    for k, v in args.items()
                )
                console.print(f"[dim cyan]  → {tc.function.name}({args_display})[/]")

                # Execute tool
                result = asyncio.run(self.execute_tool(tc.function.name, args))

                if "error" in result:
                    console.print(f"[dim red]  ✗ Error: {result['error']}[/]")
                else:
                    self._display_tool_success(tc.function.name, result)

                tool_results.append({
                    "tool_call_id": tc.id,
                    "role": "tool",
                    "name": tc.function.name,
                    "content": json.dumps(result),
                })

            # Add assistant message with tool calls
            api_messages.append({
                "role": "assistant",
                "content": response.content,
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": tc.type,
                        "function": {
                            "name": tc.function.name,
                            "arguments": tc.function.arguments,
                        },
                    }
                    for tc in tool_calls
                ],
            })
            api_messages.extend(tool_results)

            console.print("\n[dim cyan]💭 Processing tool results...[/]")
            loop_count += 1

        self.logger.warning(f"Reached max tool loops ({max_loops})")
        console.print(f"[bold yellow]⚠️ Reached maximum tool calls ({max_loops})[/]")
        return response

    def _display_tool_success(self, tool_name: str, result: Dict[str, Any]) -> None:
        """Display a success message for a tool call."""
        if tool_name == "search_files":
            count = result.get("count", 0)
            console.print(f"[dim green]  ✓ Found {count} file(s)[/]")
        elif tool_name == "read_file":
            size = result.get("size", 0)
            truncated = " (truncated)" if result.get("truncated") else ""
            console.print(f"[dim green]  ✓ Read {size} bytes{truncated}[/]")
        elif tool_name == "list_directory":
            count = result.get("count", 0)
            console.print(f"[dim green]  ✓ Listed {count} item(s)[/]")
        elif tool_name == "inspect_database":
            if "table" in result:
                console.print(f"[dim green]  ✓ Inspected table: {result['table']}[/]")
            else:
                console.print(f"[dim green]  ✓ Inspected database ({result.get('table_count', 0)} tables)[/]")
        elif tool_name == "search_database":
            count = result.get("count", 0)
            console.print(f"[dim green]  ✓ Found {count} match(es)[/]")
        elif tool_name == "query_database":
            count = result.get("count", 0)
            console.print(f"[dim green]  ✓ Query returned {count} row(s)[/]")
        else:
            console.print("[dim green]  ✓ Success[/]")

    def _stream_response(
        self,
        messages: List[Dict[str, Any]],
        model_id: str,
        max_tokens: Optional[int] = None,
        transforms: Optional[List[str]] = None,
        on_chunk: Optional[Callable[[str], None]] = None,
    ) -> Tuple[str, Optional[UsageStats]]:
        """
        Stream a response with live display.

        Args:
            messages: API messages
            model_id: Model ID
            max_tokens: Max tokens
            transforms: Transforms
            on_chunk: Callback for chunks

        Returns:
            Tuple of (full_text, usage)
        """
        response = self.client.chat(
            messages=messages,
            model=model_id,
            stream=True,
            max_tokens=max_tokens,
            transforms=transforms,
        )

        if isinstance(response, ChatResponse):
            return response.content or "", response.usage

        full_text = ""
        usage: Optional[UsageStats] = None

        try:
            with Live("", console=console, refresh_per_second=10) as live:
                for chunk in response:
                    if chunk.error:
                        console.print(f"\n[bold red]Stream error: {chunk.error}[/]")
                        break

                    if chunk.delta_content:
                        full_text += chunk.delta_content
                        live.update(Markdown(full_text))
                        if on_chunk:
                            on_chunk(chunk.delta_content)

                    if chunk.usage:
                        usage = chunk.usage

        except KeyboardInterrupt:
            console.print("\n[bold yellow]⚠️ Streaming interrupted[/]")
            return "", None

        return full_text, usage

    def add_to_history(
        self,
        prompt: str,
        response: str,
        usage: Optional[UsageStats] = None,
        cost: float = 0.0,
    ) -> None:
        """
        Add an exchange to the history.

        Args:
            prompt: User prompt
            response: Assistant response
            usage: Usage statistics
            cost: Cost if not in usage
        """
        entry = HistoryEntry(
            prompt=prompt,
            response=response,
            prompt_tokens=usage.prompt_tokens if usage else 0,
            completion_tokens=usage.completion_tokens if usage else 0,
            msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost,
            timestamp=time.time(),
        )
        self.history.append(entry)
        self.current_index = len(self.history) - 1
        self.stats.add_usage(usage, cost)

    def save_conversation(self, name: str) -> bool:
        """
        Save the current conversation.

        Args:
            name: Name for the saved conversation

        Returns:
            True if saved successfully
        """
        if not self.history:
            return False

        data = [e.to_dict() for e in self.history]
        self.db.save_conversation(name, data)
        self.logger.info(f"Saved conversation: {name}")
        return True

    def load_conversation(self, name: str) -> bool:
        """
        Load a saved conversation.

        Args:
            name: Name of the conversation to load

        Returns:
            True if loaded successfully
        """
        data = self.db.load_conversation(name)
        if not data:
            return False

        self.history.clear()
        for entry_dict in data:
            self.history.append(HistoryEntry(
                prompt=entry_dict.get("prompt", ""),
                response=entry_dict.get("response", ""),
                prompt_tokens=entry_dict.get("prompt_tokens", 0),
                completion_tokens=entry_dict.get("completion_tokens", 0),
                msg_cost=entry_dict.get("msg_cost", 0.0),
            ))

        self.current_index = len(self.history) - 1
        self.memory_start_index = 0
        self.stats = SessionStats()  # Reset stats for loaded conversation
        self.logger.info(f"Loaded conversation: {name}")
        return True

    def reset(self) -> None:
        """Reset the session state."""
        self.history.clear()
        self.stats = SessionStats()
        self.system_prompt = ""
        self.memory_start_index = 0
        self.current_index = 0
        self.logger.info("Session reset")
    def check_warnings(self) -> List[str]:
        """
        Check for cost and credit warnings.

        Returns:
            List of warning messages
        """
        warnings = []

        # Check last message cost
        if self.history:
            last_cost = self.history[-1].msg_cost
            threshold = self.settings.cost_warning_threshold
            if last_cost > threshold:
                warnings.append(
                    f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}"
                )

        # Check credits
        credits = self.client.get_credits()
        if credits:
            left = credits.get("credits_left", 0)
            total = credits.get("total_credits", 0)

            if left < LOW_CREDIT_AMOUNT:
                warnings.append(f"Low credits: ${left:.2f} remaining!")
            elif total > 0 and left < total * LOW_CREDIT_RATIO:
                # Derive the percentage from the constant so the message
                # stays accurate if LOW_CREDIT_RATIO changes.
                warnings.append(
                    f"Credits low: less than {LOW_CREDIT_RATIO:.0%} remaining (${left:.2f})"
                )

        return warnings
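Wiring the pieces together might look like the sketch below. The no-argument Settings() call and the model dict fields are assumptions, since their definitions sit outside this diff:

from oai.config.settings import Settings
from oai.core.client import AIClient
from oai.core.session import ChatSession

settings = Settings()  # assumption: defaults suffice
client = AIClient(api_key="sk-or-...")  # placeholder key
session = ChatSession(client=client, settings=settings)
session.set_model({"id": "openai/gpt-4o-mini", "supported_parameters": []})

text, usage, elapsed = session.send_message("One-line summary of MCP?", stream=False)
session.add_to_history("One-line summary of MCP?", text, usage)
print(f"{text}\n({elapsed:.2f}s, {session.stats.total_tokens} tokens)")

for warning in session.check_warnings():
    print(f"! {warning}")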
28
oai/mcp/__init__.py
Normal file
@@ -0,0 +1,28 @@
"""
Model Context Protocol (MCP) integration for oAI.

This package provides filesystem and database access capabilities
through the MCP standard, allowing AI models to interact with
local files and SQLite databases safely.

Key components:
- MCPManager: High-level manager for MCP operations
- MCPFilesystemServer: Filesystem and database access implementation
- GitignoreParser: Pattern matching for .gitignore support
- SQLiteQueryValidator: Query safety validation
- CrossPlatformMCPConfig: OS-specific configuration
"""

from oai.mcp.manager import MCPManager
from oai.mcp.server import MCPFilesystemServer
from oai.mcp.gitignore import GitignoreParser
from oai.mcp.validators import SQLiteQueryValidator
from oai.mcp.platform import CrossPlatformMCPConfig

__all__ = [
    "MCPManager",
    "MCPFilesystemServer",
    "GitignoreParser",
    "SQLiteQueryValidator",
    "CrossPlatformMCPConfig",
]
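MCPManager's own diff is suppressed further down, but from its call sites in ChatSession a tool invocation plausibly looks like the sketch below; the attribute and method names are taken from their use in session.py, and the constructor arguments are unknown:

import asyncio
from oai.mcp import MCPManager

async def demo(manager: MCPManager) -> None:
    # enabled, get_tools_schema, and call_tool are the members
    # exercised by ChatSession above.
    if manager.enabled:
        tools = manager.get_tools_schema()
        print(f"{len(tools)} tool(s) available")
        result = await manager.call_tool("list_directory", path="~/Documents")
        print(result)

# asyncio.run(demo(MCPManager(...)))  # constructor args not shown in this diff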
166
oai/mcp/gitignore.py
Normal file
@@ -0,0 +1,166 @@
"""
Gitignore pattern parsing for oAI MCP.

This module implements .gitignore pattern matching to filter files
during MCP filesystem operations.
"""

import fnmatch
from pathlib import Path
from typing import List, Tuple

from oai.utils.logging import get_logger


class GitignoreParser:
    """
    Parse and apply .gitignore patterns.

    Supports standard gitignore syntax including:
    - Wildcards (*) and double wildcards (**)
    - Directory-only patterns (ending with /)
    - Negation patterns (starting with !)
    - Comments (lines starting with #)

    Patterns are applied relative to the directory containing
    the .gitignore file.
    """

    def __init__(self):
        """Initialize an empty pattern collection."""
        # List of (pattern, is_negation, source_dir)
        self.patterns: List[Tuple[str, bool, Path]] = []

    def add_gitignore(self, gitignore_path: Path) -> None:
        """
        Parse and add patterns from a .gitignore file.

        Args:
            gitignore_path: Path to the .gitignore file
        """
        logger = get_logger()

        if not gitignore_path.exists():
            return

        try:
            source_dir = gitignore_path.parent

            with open(gitignore_path, "r", encoding="utf-8") as f:
                for line_num, line in enumerate(f, 1):
                    line = line.rstrip("\n\r")

                    # Skip empty lines and comments
                    if not line or line.startswith("#"):
                        continue

                    # Check for negation pattern
                    is_negation = line.startswith("!")
                    if is_negation:
                        line = line[1:]

                    # Remove leading slash (make relative to gitignore location)
                    if line.startswith("/"):
                        line = line[1:]

                    self.patterns.append((line, is_negation, source_dir))

            logger.debug(
                f"Loaded {len(self.patterns)} patterns from {gitignore_path}"
            )

        except Exception as e:
            logger.warning(f"Error reading {gitignore_path}: {e}")

    def should_ignore(self, path: Path) -> bool:
        """
        Check if a path should be ignored based on gitignore patterns.

        Patterns are evaluated in order, with later patterns overriding
        earlier ones. Negation patterns (starting with !) un-ignore
        previously matched paths.

        Args:
            path: Path to check

        Returns:
            True if the path should be ignored
        """
        if not self.patterns:
            return False

        ignored = False

        for pattern, is_negation, source_dir in self.patterns:
            # Only apply pattern if path is under the source directory
            try:
                rel_path = path.relative_to(source_dir)
            except ValueError:
                # Path is not relative to this gitignore's directory
                continue

            rel_path_str = str(rel_path)

            # Check if pattern matches
            if self._match_pattern(pattern, rel_path_str, path.is_dir()):
                if is_negation:
                    ignored = False  # Negation patterns un-ignore
                else:
                    ignored = True

        return ignored

    def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool:
        """
        Match a gitignore pattern against a path.

        Args:
            pattern: The gitignore pattern
            path: The relative path string to match
            is_dir: Whether the path is a directory

        Returns:
            True if the pattern matches
        """
        # Directory-only pattern (ends with /)
        if pattern.endswith("/"):
            if not is_dir:
                return False
            pattern = pattern[:-1]

        # Handle ** patterns (matches any number of directories)
        if "**" in pattern:
            pattern_parts = pattern.split("**")
            if len(pattern_parts) == 2:
                prefix, suffix = pattern_parts

                # Match if path starts with prefix and ends with suffix
                if prefix:
                    if not path.startswith(prefix.rstrip("/")):
                        return False
                if suffix:
                    suffix = suffix.lstrip("/")
                    if not (path.endswith(suffix) or f"/{suffix}" in path):
                        return False
                return True

        # Direct match using fnmatch
        if fnmatch.fnmatch(path, pattern):
            return True

        # Match as subdirectory pattern (pattern without / matches in any directory)
        if "/" not in pattern:
            parts = path.split("/")
            if any(fnmatch.fnmatch(part, pattern) for part in parts):
                return True

        return False

    def clear(self) -> None:
        """Clear all loaded patterns."""
        self.patterns = []

    @property
    def pattern_count(self) -> int:
        """Get the number of loaded patterns."""
        return len(self.patterns)
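A quick usage sketch for the parser above. The project path is illustrative, and since should_ignore calls is_dir(), it expects paths that exist on disk:

from pathlib import Path
from oai.mcp.gitignore import GitignoreParser

parser = GitignoreParser()
parser.add_gitignore(Path("~/project/.gitignore").expanduser())
print(parser.pattern_count, "patterns loaded")

# True for anything matched by the project's ignore rules,
# e.g. a venv/ directory when "venv/" appears in .gitignore.
print(parser.should_ignore(Path("~/project/venv").expanduser()))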
1365
oai/mcp/manager.py
Normal file
File diff suppressed because it is too large
228
oai/mcp/platform.py
Normal file
228
oai/mcp/platform.py
Normal file
@@ -0,0 +1,228 @@
```python
"""
Cross-platform MCP configuration for oAI.

This module handles OS-specific configuration, path handling,
and security checks for the MCP filesystem server.
"""

import os
import platform
import subprocess
from pathlib import Path
from typing import List, Dict, Any, Optional

from oai.constants import SYSTEM_DIRS_BLACKLIST
from oai.utils.logging import get_logger


class CrossPlatformMCPConfig:
    """
    Handle OS-specific MCP configuration.

    Provides methods for path normalization, security validation,
    and OS-specific default directories.

    Attributes:
        system: Operating system name
        is_macos: Whether running on macOS
        is_linux: Whether running on Linux
        is_windows: Whether running on Windows
    """

    def __init__(self):
        """Initialize platform detection."""
        self.system = platform.system()
        self.is_macos = self.system == "Darwin"
        self.is_linux = self.system == "Linux"
        self.is_windows = self.system == "Windows"

        logger = get_logger()
        logger.info(f"Detected OS: {self.system}")

    def get_default_allowed_dirs(self) -> List[Path]:
        """
        Get safe default directories for the current OS.

        Returns:
            List of default directories that are safe to access
        """
        home = Path.home()

        if self.is_macos:
            return [
                home / "Documents",
                home / "Desktop",
                home / "Downloads",
            ]

        elif self.is_linux:
            dirs = [home / "Documents"]

            # Try to get XDG directories
            try:
                for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]:
                    result = subprocess.run(
                        ["xdg-user-dir", xdg_dir],
                        capture_output=True,
                        text=True,
                        timeout=1
                    )
                    if result.returncode == 0:
                        dir_path = Path(result.stdout.strip())
                        if dir_path.exists():
                            dirs.append(dir_path)
            except (subprocess.TimeoutExpired, FileNotFoundError):
                # Fallback to standard locations
                dirs.extend([
                    home / "Desktop",
                    home / "Downloads",
                ])

            return list(set(dirs))

        elif self.is_windows:
            return [
                home / "Documents",
                home / "Desktop",
                home / "Downloads",
            ]

        # Fallback for unknown OS
        return [home]

    def get_python_command(self) -> str:
        """
        Get the Python executable path.

        Returns:
            Path to the Python executable
        """
        import sys
        return sys.executable

    def get_filesystem_warning(self) -> str:
        """
        Get OS-specific security warning message.

        Returns:
            Warning message for the current OS
        """
        if self.is_macos:
            return """
Note: macOS Security
The Filesystem MCP server needs access to your selected folder.
You may see a security prompt - click 'Allow' to proceed.
(System Settings > Privacy & Security > Files and Folders)
"""
        elif self.is_linux:
            return """
Note: Linux Security
The Filesystem MCP server will access your selected folder.
Ensure oAI has appropriate file permissions.
"""
        elif self.is_windows:
            return """
Note: Windows Security
The Filesystem MCP server will access your selected folder.
You may need to grant file access permissions.
"""
        return ""

    def normalize_path(self, path: str) -> Path:
        """
        Normalize a path for the current OS.

        Expands user directory (~) and resolves to absolute path.

        Args:
            path: Path string to normalize

        Returns:
            Normalized absolute Path
        """
        return Path(os.path.expanduser(path)).resolve()

    def is_system_directory(self, path: Path) -> bool:
        """
        Check if a path is a protected system directory.

        Args:
            path: Path to check

        Returns:
            True if the path is a system directory
        """
        path_str = str(path)
        for blocked in SYSTEM_DIRS_BLACKLIST:
            if path_str.startswith(blocked):
                return True
        return False

    def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool:
        """
        Check if a path is within allowed directories.

        Args:
            requested_path: Path being requested
            allowed_dirs: List of allowed parent directories

        Returns:
            True if the path is within an allowed directory
        """
        try:
            requested = requested_path.resolve()

            for allowed in allowed_dirs:
                try:
                    allowed_resolved = allowed.resolve()
                    requested.relative_to(allowed_resolved)
                    return True
                except ValueError:
                    continue

            return False
        except Exception:
            return False

    def get_folder_stats(self, folder: Path) -> Dict[str, Any]:
        """
        Get statistics for a folder.

        Args:
            folder: Path to the folder

        Returns:
            Dictionary with folder statistics:
            - exists: Whether the folder exists
            - file_count: Number of files (if exists)
            - total_size: Total size in bytes (if exists)
            - size_mb: Size in megabytes (if exists)
            - error: Error message (if any)
        """
        logger = get_logger()

        try:
            if not folder.exists() or not folder.is_dir():
                return {"exists": False}

            file_count = 0
            total_size = 0

            for item in folder.rglob("*"):
                if item.is_file():
                    file_count += 1
                    try:
                        total_size += item.stat().st_size
                    except (OSError, PermissionError):
                        pass

            return {
                "exists": True,
                "file_count": file_count,
                "total_size": total_size,
                "size_mb": total_size / (1024 * 1024),
            }

        except Exception as e:
            logger.error(f"Error getting folder stats for {folder}: {e}")
            return {"exists": False, "error": str(e)}
```
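A minimal usage sketch of how these checks compose when vetting a folder for MCP access; the `~/Documents/notes` path below is just an illustrative example:

```python
# Sketch only: the folder path is a hypothetical example, not something
# the project itself requires.
from oai.mcp.platform import CrossPlatformMCPConfig

config = CrossPlatformMCPConfig()
allowed = config.get_default_allowed_dirs()

requested = config.normalize_path("~/Documents/notes")  # expands ~ and resolves

# A path is usable only if it sits inside an allowed dir and is not a system dir
if config.is_safe_path(requested, allowed) and not config.is_system_directory(requested):
    stats = config.get_folder_stats(requested)
    if stats.get("exists"):
        print(f"{stats['file_count']} files, {stats['size_mb']:.1f} MB")
```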
1368  oai/mcp/server.py  Normal file
File diff suppressed because it is too large.
123  oai/mcp/validators.py  Normal file
@@ -0,0 +1,123 @@
```python
"""
Query validation for oAI MCP database operations.

This module provides safety validation for SQL queries to ensure
only read-only operations are executed.
"""

import re
from typing import Tuple

from oai.constants import DANGEROUS_SQL_KEYWORDS


class SQLiteQueryValidator:
    """
    Validate SQLite queries for read-only safety.

    Ensures that only SELECT queries (including CTEs) are allowed
    and blocks potentially dangerous operations like INSERT, UPDATE,
    DELETE, DROP, etc.
    """

    @staticmethod
    def is_safe_query(query: str) -> Tuple[bool, str]:
        """
        Validate that a query is a safe read-only SELECT.

        The validation:
        1. Checks that query starts with SELECT or WITH
        2. Strips string literals before checking for dangerous keywords
        3. Blocks any dangerous keywords outside of string literals

        Args:
            query: SQL query string to validate

        Returns:
            Tuple of (is_safe, error_message)
            - is_safe: True if the query is safe to execute
            - error_message: Description of why query is unsafe (empty if safe)

        Examples:
            >>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users")
            (True, "")
            >>> SQLiteQueryValidator.is_safe_query("DELETE FROM users")
            (False, "Only SELECT queries are allowed...")
            >>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users")
            (True, "")  # 'DELETE' is inside a string literal
        """
        query_upper = query.strip().upper()

        # Must start with SELECT or WITH (for CTEs)
        if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")):
            return False, "Only SELECT queries are allowed (including WITH/CTE)"

        # Remove string literals before checking for dangerous keywords
        # This prevents false positives when keywords appear in data
        query_no_strings = re.sub(r"'[^']*'", "", query_upper)
        query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings)

        # Check for dangerous keywords outside of quotes
        for keyword in DANGEROUS_SQL_KEYWORDS:
            if re.search(r"\b" + keyword + r"\b", query_no_strings):
                return False, f"Keyword '{keyword}' not allowed in read-only mode"

        return True, ""

    @staticmethod
    def sanitize_table_name(table_name: str) -> str:
        """
        Sanitize a table name to prevent SQL injection.

        Only allows alphanumeric characters and underscores.

        Args:
            table_name: Table name to sanitize

        Returns:
            Sanitized table name

        Raises:
            ValueError: If table name contains invalid characters
        """
        # Remove any characters that aren't alphanumeric or underscore
        sanitized = re.sub(r"[^\w]", "", table_name)

        if not sanitized:
            raise ValueError("Table name cannot be empty after sanitization")

        if sanitized != table_name:
            raise ValueError(
                f"Table name contains invalid characters: {table_name}"
            )

        return sanitized

    @staticmethod
    def sanitize_column_name(column_name: str) -> str:
        """
        Sanitize a column name to prevent SQL injection.

        Only allows alphanumeric characters and underscores.

        Args:
            column_name: Column name to sanitize

        Returns:
            Sanitized column name

        Raises:
            ValueError: If column name contains invalid characters
        """
        # Remove any characters that aren't alphanumeric or underscore
        sanitized = re.sub(r"[^\w]", "", column_name)

        if not sanitized:
            raise ValueError("Column name cannot be empty after sanitization")

        if sanitized != column_name:
            raise ValueError(
                f"Column name contains invalid characters: {column_name}"
            )

        return sanitized
```
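A short behavior sketch of the validator; the `users` table is hypothetical, and the exact contents of `DANGEROUS_SQL_KEYWORDS` are assumed from the class docstring:

```python
# Behavior sketch for SQLiteQueryValidator; queries are hypothetical examples.
from oai.mcp.validators import SQLiteQueryValidator

ok, err = SQLiteQueryValidator.is_safe_query("SELECT name FROM users")
# -> (True, "")

ok, err = SQLiteQueryValidator.is_safe_query("DELETE FROM users")
# -> (False, ...): rejected immediately, query does not start with SELECT/WITH

ok, err = SQLiteQueryValidator.is_safe_query("SELECT 1; DELETE FROM users")
# -> (False, ...): assuming DELETE is in DANGEROUS_SQL_KEYWORDS per the docstring

ok, _ = SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users")
# -> (True, ""): the keyword sits inside a string literal, which is stripped first

SQLiteQueryValidator.sanitize_table_name("users")   # returns "users"
# sanitize_table_name("users; --") raises ValueError: invalid characters
```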
32  oai/providers/__init__.py  Normal file
@@ -0,0 +1,32 @@
```python
"""
Provider abstraction for oAI.

This module provides a unified interface for AI model providers,
enabling easy extension to support additional providers beyond OpenRouter.
"""

from oai.providers.base import (
    AIProvider,
    ChatMessage,
    ChatResponse,
    ToolCall,
    ToolFunction,
    UsageStats,
    ModelInfo,
    ProviderCapabilities,
)
from oai.providers.openrouter import OpenRouterProvider

__all__ = [
    # Base classes and types
    "AIProvider",
    "ChatMessage",
    "ChatResponse",
    "ToolCall",
    "ToolFunction",
    "UsageStats",
    "ModelInfo",
    "ProviderCapabilities",
    # Provider implementations
    "OpenRouterProvider",
]
```
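The net effect is that callers import the public types from the package root rather than from individual submodules, for example:

```python
# Import surface re-exported by oai/providers/__init__.py
from oai.providers import AIProvider, ChatMessage, OpenRouterProvider
```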
413  oai/providers/base.py  Normal file
@@ -0,0 +1,413 @@
```python
"""
Abstract base provider for AI model integration.

This module defines the interface that all AI providers must implement,
along with common data structures for requests and responses.
"""

from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import (
    Any,
    AsyncIterator,
    Dict,
    Iterator,
    List,
    Optional,
    Union,
)


class MessageRole(str, Enum):
    """Message roles in a conversation."""

    SYSTEM = "system"
    USER = "user"
    ASSISTANT = "assistant"
    TOOL = "tool"


@dataclass
class ToolFunction:
    """
    Represents a function within a tool call.

    Attributes:
        name: The function name
        arguments: JSON string of function arguments
    """

    name: str
    arguments: str


@dataclass
class ToolCall:
    """
    Represents a tool/function call requested by the model.

    Attributes:
        id: Unique identifier for this tool call
        type: Type of tool call (usually "function")
        function: The function being called
    """

    id: str
    type: str
    function: ToolFunction


@dataclass
class UsageStats:
    """
    Token usage statistics from an API response.

    Attributes:
        prompt_tokens: Number of tokens in the prompt
        completion_tokens: Number of tokens in the completion
        total_tokens: Total tokens used
        total_cost_usd: Cost in USD (if available from API)
    """

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0
    total_cost_usd: Optional[float] = None

    @property
    def input_tokens(self) -> int:
        """Alias for prompt_tokens."""
        return self.prompt_tokens

    @property
    def output_tokens(self) -> int:
        """Alias for completion_tokens."""
        return self.completion_tokens


@dataclass
class ChatMessage:
    """
    A single message in a chat conversation.

    Attributes:
        role: The role of the message sender
        content: Message content (text or structured content blocks)
        name: Optional name for the sender
        tool_calls: List of tool calls (for assistant messages)
        tool_call_id: Tool call ID this message responds to (for tool messages)
    """

    role: str
    content: Union[str, List[Dict[str, Any]], None] = None
    name: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary format for API requests."""
        result: Dict[str, Any] = {"role": self.role}

        if self.content is not None:
            result["content"] = self.content

        if self.name:
            result["name"] = self.name

        if self.tool_calls:
            result["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments,
                    },
                }
                for tc in self.tool_calls
            ]

        if self.tool_call_id:
            result["tool_call_id"] = self.tool_call_id

        return result


@dataclass
class ChatResponseChoice:
    """
    A single choice in a chat response.

    Attributes:
        index: Index of this choice
        message: The response message
        finish_reason: Why the response ended
    """

    index: int
    message: ChatMessage
    finish_reason: Optional[str] = None


@dataclass
class ChatResponse:
    """
    Response from a chat completion request.

    Attributes:
        id: Unique identifier for this response
        choices: List of response choices
        usage: Token usage statistics
        model: Model that generated this response
        created: Unix timestamp of creation
    """

    id: str
    choices: List[ChatResponseChoice]
    usage: Optional[UsageStats] = None
    model: Optional[str] = None
    created: Optional[int] = None

    @property
    def message(self) -> Optional[ChatMessage]:
        """Get the first choice's message."""
        if self.choices:
            return self.choices[0].message
        return None

    @property
    def content(self) -> Optional[str]:
        """Get the text content of the first choice."""
        msg = self.message
        if msg and isinstance(msg.content, str):
            return msg.content
        return None

    @property
    def tool_calls(self) -> Optional[List[ToolCall]]:
        """Get tool calls from the first choice."""
        msg = self.message
        if msg:
            return msg.tool_calls
        return None


@dataclass
class StreamChunk:
    """
    A single chunk from a streaming response.

    Attributes:
        id: Response ID
        delta_content: New content in this chunk
        finish_reason: Finish reason (if this is the last chunk)
        usage: Usage stats (usually in the last chunk)
        error: Error message if something went wrong
    """

    id: str
    delta_content: Optional[str] = None
    finish_reason: Optional[str] = None
    usage: Optional[UsageStats] = None
    error: Optional[str] = None


@dataclass
class ModelInfo:
    """
    Information about an AI model.

    Attributes:
        id: Unique model identifier
        name: Display name
        description: Model description
        context_length: Maximum context window size
        pricing: Pricing info (input/output per million tokens)
        supported_parameters: List of supported API parameters
        input_modalities: Supported input types (text, image, etc.)
        output_modalities: Supported output types
    """

    id: str
    name: str
    description: str = ""
    context_length: int = 0
    pricing: Dict[str, float] = field(default_factory=dict)
    supported_parameters: List[str] = field(default_factory=list)
    input_modalities: List[str] = field(default_factory=lambda: ["text"])
    output_modalities: List[str] = field(default_factory=lambda: ["text"])

    def supports_images(self) -> bool:
        """Check if model supports image input."""
        return "image" in self.input_modalities

    def supports_tools(self) -> bool:
        """Check if model supports function calling/tools."""
        return "tools" in self.supported_parameters or "functions" in self.supported_parameters

    def supports_streaming(self) -> bool:
        """Check if model supports streaming responses."""
        return "stream" in self.supported_parameters

    def supports_online(self) -> bool:
        """Check if model supports web search (online mode)."""
        return self.supports_tools()


@dataclass
class ProviderCapabilities:
    """
    Capabilities supported by a provider.

    Attributes:
        streaming: Provider supports streaming responses
        tools: Provider supports function calling
        images: Provider supports image inputs
        online: Provider supports web search
        max_context: Maximum context length across all models
    """

    streaming: bool = True
    tools: bool = True
    images: bool = True
    online: bool = False
    max_context: int = 128000


class AIProvider(ABC):
    """
    Abstract base class for AI model providers.

    All provider implementations must inherit from this class
    and implement the required abstract methods.
    """

    def __init__(self, api_key: str, base_url: Optional[str] = None):
        """
        Initialize the provider.

        Args:
            api_key: API key for authentication
            base_url: Optional custom base URL for the API
        """
        self.api_key = api_key
        self.base_url = base_url

    @property
    @abstractmethod
    def name(self) -> str:
        """Get the provider name."""
        pass

    @property
    @abstractmethod
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        pass

    @abstractmethod
    def list_models(self) -> List[ModelInfo]:
        """
        Fetch available models from the provider.

        Returns:
            List of available models with their info
        """
        pass

    @abstractmethod
    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Model information or None if not found
        """
        pass

    @abstractmethod
    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat completion request.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: List of tool definitions for function calling
            tool_choice: How to handle tool selection ("auto", "none", etc.)
            **kwargs: Additional provider-specific parameters

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
        """
        pass

    @abstractmethod
    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send an async chat completion request.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: List of tool definitions for function calling
            tool_choice: How to handle tool selection
            **kwargs: Additional provider-specific parameters

        Returns:
            ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming
        """
        pass

    @abstractmethod
    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get account credit/balance information.

        Returns:
            Dict with credit info or None if not supported
        """
        pass

    def validate_api_key(self) -> bool:
        """
        Validate that the API key is valid.

        Returns:
            True if API key is valid
        """
        try:
            self.list_models()
            return True
        except Exception:
            return False
```
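For extension authors, a compressed and purely hypothetical sketch of the minimum a new provider must implement. `EchoProvider` is illustrative only and is not part of this changeset:

```python
# Hypothetical minimal provider used purely to illustrate the interface;
# it echoes the last user message instead of calling a real API.
from typing import Any, Dict, List, Optional

from oai.providers.base import (
    AIProvider, ChatMessage, ChatResponse, ChatResponseChoice,
    ModelInfo, ProviderCapabilities,
)


class EchoProvider(AIProvider):
    @property
    def name(self) -> str:
        return "Echo"

    @property
    def capabilities(self) -> ProviderCapabilities:
        return ProviderCapabilities(streaming=False, tools=False, images=False)

    def list_models(self) -> List[ModelInfo]:
        return [ModelInfo(id="echo-1", name="Echo v1")]

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        return next((m for m in self.list_models() if m.id == model_id), None)

    def chat(self, model, messages, stream=False, **kwargs) -> ChatResponse:
        # Echo the most recent user message back as the assistant reply
        text = next((m.content for m in reversed(messages) if m.role == "user"), "")
        reply = ChatMessage(role="assistant", content=text)
        return ChatResponse(id="echo", choices=[ChatResponseChoice(0, reply)])

    async def chat_async(self, model, messages, stream=False, **kwargs) -> ChatResponse:
        return self.chat(model, messages, stream=stream, **kwargs)

    def get_credits(self) -> Optional[Dict[str, Any]]:
        return None


resp = EchoProvider(api_key="").chat("echo-1", [ChatMessage(role="user", content="hi")])
print(resp.content)  # -> hi
```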
623  oai/providers/openrouter.py  Normal file
@@ -0,0 +1,623 @@
```python
"""
OpenRouter provider implementation.

This module implements the AIProvider interface for OpenRouter,
supporting chat completions, streaming, and function calling.
"""

from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

import requests
from openrouter import OpenRouter

from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL
from oai.providers.base import (
    AIProvider,
    ChatMessage,
    ChatResponse,
    ChatResponseChoice,
    ModelInfo,
    ProviderCapabilities,
    StreamChunk,
    ToolCall,
    ToolFunction,
    UsageStats,
)
from oai.utils.logging import get_logger


class OpenRouterProvider(AIProvider):
    """
    OpenRouter API provider implementation.

    Provides access to multiple AI models through OpenRouter's unified API,
    supporting chat completions, streaming responses, and function calling.

    Attributes:
        client: The underlying OpenRouter client
        _models_cache: Cached list of available models
    """

    def __init__(
        self,
        api_key: str,
        base_url: Optional[str] = None,
        app_name: str = APP_NAME,
        app_url: str = APP_URL,
    ):
        """
        Initialize the OpenRouter provider.

        Args:
            api_key: OpenRouter API key
            base_url: Optional custom base URL
            app_name: Application name for API headers
            app_url: Application URL for API headers
        """
        super().__init__(api_key, base_url or DEFAULT_BASE_URL)
        self.app_name = app_name
        self.app_url = app_url
        self.client = OpenRouter(api_key=api_key)
        self._models_cache: Optional[List[ModelInfo]] = None
        self._raw_models_cache: Optional[List[Dict[str, Any]]] = None

        self.logger = get_logger()
        self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}")

    @property
    def name(self) -> str:
        """Get the provider name."""
        return "OpenRouter"

    @property
    def capabilities(self) -> ProviderCapabilities:
        """Get provider capabilities."""
        return ProviderCapabilities(
            streaming=True,
            tools=True,
            images=True,
            online=True,
            max_context=2000000,  # Largest context windows on OpenRouter reach ~2M tokens
        )

    def _get_headers(self) -> Dict[str, str]:
        """Get standard HTTP headers for API requests."""
        headers = {
            "HTTP-Referer": self.app_url,
            "X-Title": self.app_name,
        }
        if self.api_key:
            headers["Authorization"] = f"Bearer {self.api_key}"
        return headers

    def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
        """
        Parse raw model data into ModelInfo.

        Args:
            model_data: Raw model data from API

        Returns:
            Parsed ModelInfo object
        """
        architecture = model_data.get("architecture", {})
        pricing_data = model_data.get("pricing", {})

        # Parse pricing (convert from string to float if needed)
        pricing = {}
        for key in ["prompt", "completion"]:
            value = pricing_data.get(key)
            if value is not None:
                try:
                    # Convert from per-token to per-million-tokens
                    pricing[key] = float(value) * 1_000_000
                except (ValueError, TypeError):
                    pricing[key] = 0.0

        return ModelInfo(
            id=model_data.get("id", ""),
            name=model_data.get("name", model_data.get("id", "")),
            description=model_data.get("description", ""),
            context_length=model_data.get("context_length", 0),
            pricing=pricing,
            supported_parameters=model_data.get("supported_parameters", []),
            input_modalities=architecture.get("input_modalities", ["text"]),
            output_modalities=architecture.get("output_modalities", ["text"]),
        )

    def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
        """
        Fetch available models from OpenRouter.

        Args:
            filter_text_only: If True, exclude video-only models

        Returns:
            List of available models

        Raises:
            Exception: If API request fails
        """
        if self._models_cache is not None:
            return self._models_cache

        try:
            response = requests.get(
                f"{self.base_url}/models",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()

            raw_models = response.json().get("data", [])
            self._raw_models_cache = raw_models

            models = []
            for model_data in raw_models:
                # Optionally filter out video-only models
                if filter_text_only:
                    modalities = model_data.get("modalities", [])
                    if modalities and "video" in modalities and "text" not in modalities:
                        continue

                models.append(self._parse_model(model_data))

            self._models_cache = models
            self.logger.info(f"Fetched {len(models)} models from OpenRouter")
            return models

        except requests.RequestException as e:
            self.logger.error(f"Failed to fetch models: {e}")
            raise

    def get_raw_models(self) -> List[Dict[str, Any]]:
        """
        Get raw model data as returned by the API.

        Useful for accessing provider-specific fields not in ModelInfo.

        Returns:
            List of raw model dictionaries
        """
        if self._raw_models_cache is None:
            self.list_models()
        return self._raw_models_cache or []

    def get_model(self, model_id: str) -> Optional[ModelInfo]:
        """
        Get information about a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Model information or None if not found
        """
        models = self.list_models()
        for model in models:
            if model.id == model_id:
                return model
        return None

    def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
        """
        Get raw model data for a specific model.

        Args:
            model_id: The model identifier

        Returns:
            Raw model dictionary or None if not found
        """
        raw_models = self.get_raw_models()
        for model in raw_models:
            if model.get("id") == model_id:
                return model
        return None

    def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
        """
        Convert ChatMessage objects to API format.

        Args:
            messages: List of ChatMessage objects

        Returns:
            List of message dictionaries for the API
        """
        return [msg.to_dict() for msg in messages]

    def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]:
        """
        Parse usage data from API response.

        Args:
            usage_data: Raw usage data from API

        Returns:
            Parsed UsageStats or None
        """
        if not usage_data:
            return None

        # Handle both attribute and dict access
        prompt_tokens = 0
        completion_tokens = 0
        total_cost = None

        if hasattr(usage_data, "prompt_tokens"):
            prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0
        elif isinstance(usage_data, dict):
            prompt_tokens = usage_data.get("prompt_tokens", 0) or 0

        if hasattr(usage_data, "completion_tokens"):
            completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0
        elif isinstance(usage_data, dict):
            completion_tokens = usage_data.get("completion_tokens", 0) or 0

        # Try alternative naming (input_tokens/output_tokens)
        if prompt_tokens == 0:
            if hasattr(usage_data, "input_tokens"):
                prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0
            elif isinstance(usage_data, dict):
                prompt_tokens = usage_data.get("input_tokens", 0) or 0

        if completion_tokens == 0:
            if hasattr(usage_data, "output_tokens"):
                completion_tokens = getattr(usage_data, "output_tokens", 0) or 0
            elif isinstance(usage_data, dict):
                completion_tokens = usage_data.get("output_tokens", 0) or 0

        # Get cost if available
        if hasattr(usage_data, "total_cost_usd"):
            total_cost = getattr(usage_data, "total_cost_usd", None)
        elif isinstance(usage_data, dict):
            total_cost = usage_data.get("total_cost_usd")

        return UsageStats(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            total_cost_usd=float(total_cost) if total_cost else None,
        )

    def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]:
        """
        Parse tool calls from API response.

        Args:
            tool_calls_data: Raw tool calls data

        Returns:
            List of ToolCall objects or None
        """
        if not tool_calls_data:
            return None

        tool_calls = []
        for tc in tool_calls_data:
            # Handle both attribute and dict access
            if hasattr(tc, "id"):
                tc_id = tc.id
                tc_type = getattr(tc, "type", "function")
                func = tc.function
                func_name = func.name
                func_args = func.arguments
            else:
                tc_id = tc.get("id", "")
                tc_type = tc.get("type", "function")
                func = tc.get("function", {})
                func_name = func.get("name", "")
                func_args = func.get("arguments", "{}")

            tool_calls.append(
                ToolCall(
                    id=tc_id,
                    type=tc_type,
                    function=ToolFunction(name=func_name, arguments=func_args),
                )
            )

        return tool_calls if tool_calls else None

    def _parse_response(self, response: Any) -> ChatResponse:
        """
        Parse API response into ChatResponse.

        Args:
            response: Raw API response

        Returns:
            Parsed ChatResponse
        """
        choices = []
        for choice in response.choices:
            msg = choice.message
            message = ChatMessage(
                role=msg.role if hasattr(msg, "role") else "assistant",
                content=msg.content if hasattr(msg, "content") else None,
                tool_calls=self._parse_tool_calls(
                    getattr(msg, "tool_calls", None)
                ),
            )
            choices.append(
                ChatResponseChoice(
                    index=choice.index if hasattr(choice, "index") else 0,
                    message=message,
                    finish_reason=getattr(choice, "finish_reason", None),
                )
            )

        return ChatResponse(
            id=response.id if hasattr(response, "id") else "",
            choices=choices,
            usage=self._parse_usage(getattr(response, "usage", None)),
            model=getattr(response, "model", None),
            created=getattr(response, "created", None),
        )

    def chat(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        transforms: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, Iterator[StreamChunk]]:
        """
        Send a chat completion request to OpenRouter.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature (0-2)
            tools: List of tool definitions for function calling
            tool_choice: How to handle tool selection ("auto", "none", etc.)
            transforms: List of transforms (e.g., ["middle-out"])
            **kwargs: Additional parameters

        Returns:
            ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
        """
        # Build request parameters
        params: Dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages(messages),
            "stream": stream,
            "http_headers": self._get_headers(),
        }

        # Request usage stats in streaming responses
        if stream:
            params["stream_options"] = {"include_usage": True}

        if max_tokens is not None:
            params["max_tokens"] = max_tokens

        if temperature is not None:
            params["temperature"] = temperature

        if tools:
            params["tools"] = tools
            params["tool_choice"] = tool_choice or "auto"

        if transforms:
            params["transforms"] = transforms

        # Add any additional parameters
        params.update(kwargs)

        self.logger.debug(f"Sending chat request to model {model}")

        try:
            response = self.client.chat.send(**params)

            if stream:
                return self._stream_response(response)
            else:
                return self._parse_response(response)

        except Exception as e:
            self.logger.error(f"Chat request failed: {e}")
            raise

    def _stream_response(self, response: Any) -> Iterator[StreamChunk]:
        """
        Process a streaming response.

        Args:
            response: Streaming response from API

        Yields:
            StreamChunk objects
        """
        last_usage = None

        try:
            for chunk in response:
                # Check for errors
                if hasattr(chunk, "error") and chunk.error:
                    yield StreamChunk(
                        id=getattr(chunk, "id", ""),
                        error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error),
                    )
                    return

                # Extract delta content
                delta_content = None
                finish_reason = None

                if hasattr(chunk, "choices") and chunk.choices:
                    choice = chunk.choices[0]
                    if hasattr(choice, "delta"):
                        delta = choice.delta
                        if hasattr(delta, "content") and delta.content:
                            delta_content = delta.content
                    finish_reason = getattr(choice, "finish_reason", None)

                # Track usage from last chunk
                if hasattr(chunk, "usage") and chunk.usage:
                    last_usage = self._parse_usage(chunk.usage)

                yield StreamChunk(
                    id=getattr(chunk, "id", ""),
                    delta_content=delta_content,
                    finish_reason=finish_reason,
                    usage=last_usage if finish_reason else None,
                )

        except Exception as e:
            self.logger.error(f"Stream error: {e}")
            yield StreamChunk(id="", error=str(e))

    async def chat_async(
        self,
        model: str,
        messages: List[ChatMessage],
        stream: bool = False,
        max_tokens: Optional[int] = None,
        temperature: Optional[float] = None,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        **kwargs: Any,
    ) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
        """
        Send an async chat completion request.

        Note: Currently wraps the sync implementation.
        TODO: Implement true async support when OpenRouter SDK supports it.

        Args:
            model: Model ID to use
            messages: List of chat messages
            stream: Whether to stream the response
            max_tokens: Maximum tokens in response
            temperature: Sampling temperature
            tools: List of tool definitions
            tool_choice: Tool selection mode
            **kwargs: Additional parameters

        Returns:
            ChatResponse for non-streaming, AsyncIterator for streaming
        """
        # For now, use sync implementation
        # TODO: Add true async when SDK supports it
        result = self.chat(
            model=model,
            messages=messages,
            stream=stream,
            max_tokens=max_tokens,
            temperature=temperature,
            tools=tools,
            tool_choice=tool_choice,
            **kwargs,
        )

        if stream and isinstance(result, Iterator):
            # Convert sync iterator to async
            async def async_iter() -> AsyncIterator[StreamChunk]:
                for chunk in result:
                    yield chunk

            return async_iter()

        return result

    def get_credits(self) -> Optional[Dict[str, Any]]:
        """
        Get OpenRouter account credit information.

        Returns:
            Dict with credit info:
            - total_credits: Total credits purchased
            - used_credits: Credits used
            - credits_left: Remaining credits

        Raises:
            Exception: If API request fails
        """
        if not self.api_key:
            return None

        try:
            response = requests.get(
                f"{self.base_url}/credits",
                headers=self._get_headers(),
                timeout=10,
            )
            response.raise_for_status()

            data = response.json().get("data", {})
            total_credits = float(data.get("total_credits", 0))
            total_usage = float(data.get("total_usage", 0))
            credits_left = total_credits - total_usage

            return {
                "total_credits": total_credits,
                "used_credits": total_usage,
                "credits_left": credits_left,
                "total_credits_formatted": f"${total_credits:.2f}",
                "used_credits_formatted": f"${total_usage:.2f}",
                "credits_left_formatted": f"${credits_left:.2f}",
            }

        except Exception as e:
            self.logger.error(f"Failed to fetch credits: {e}")
            return None

    def clear_cache(self) -> None:
        """Clear the models cache to force a refresh."""
        self._models_cache = None
        self._raw_models_cache = None
        self.logger.debug("Models cache cleared")

    def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str:
        """
        Get the effective model ID with online suffix if needed.

        Args:
            model_id: Base model ID
            online_enabled: Whether online mode is enabled

        Returns:
            Model ID with :online suffix if applicable
        """
        if online_enabled and not model_id.endswith(":online"):
            return f"{model_id}:online"
        return model_id

    def estimate_cost(
        self,
        model_id: str,
        input_tokens: int,
        output_tokens: int,
    ) -> float:
        """
        Estimate the cost for a completion.

        Args:
            model_id: Model ID
            input_tokens: Number of input tokens
            output_tokens: Number of output tokens

        Returns:
            Estimated cost in USD
        """
        model = self.get_model(model_id)
        if model and model.pricing:
            input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000
            output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000
            return input_cost + output_cost

        # Fallback to default pricing if model not found
        from oai.constants import MODEL_PRICING

        input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
        output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
        return input_cost + output_cost
```
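A usage sketch of the provider as wired above, assuming `OPENROUTER_API_KEY` is set and that the example model ID exists on OpenRouter at the time of use:

```python
# Sketch only: the model ID is an example and the API key is assumed
# to be available in the environment.
import os

from oai.providers.base import ChatMessage
from oai.providers.openrouter import OpenRouterProvider

provider = OpenRouterProvider(api_key=os.environ["OPENROUTER_API_KEY"])

messages = [
    ChatMessage(role="system", content="You are concise."),
    ChatMessage(role="user", content="Name three uses of SQLite."),
]

# Non-streaming: returns a ChatResponse
resp = provider.chat(model="openai/gpt-4o-mini", messages=messages)
print(resp.content)
if resp.usage:
    print(provider.estimate_cost("openai/gpt-4o-mini",
                                 resp.usage.input_tokens,
                                 resp.usage.output_tokens))

# Streaming: returns an iterator of StreamChunk
for chunk in provider.chat(model="openai/gpt-4o-mini", messages=messages, stream=True):
    if chunk.error:
        break
    if chunk.delta_content:
        print(chunk.delta_content, end="", flush=True)
```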
2  oai/py.typed  Normal file
@@ -0,0 +1,2 @@
```python
# Marker file for PEP 561
# This package supports type checking
```
51  oai/ui/__init__.py  Normal file
@@ -0,0 +1,51 @@
```python
"""
UI utilities for oAI.

This module provides rich terminal UI components and display helpers
for the chat application.
"""

from oai.ui.console import (
    console,
    clear_screen,
    display_panel,
    display_table,
    display_markdown,
    print_error,
    print_warning,
    print_success,
    print_info,
)
from oai.ui.tables import (
    create_model_table,
    create_stats_table,
    create_help_table,
    display_paginated_table,
)
from oai.ui.prompts import (
    prompt_confirm,
    prompt_choice,
    prompt_input,
)

__all__ = [
    # Console utilities
    "console",
    "clear_screen",
    "display_panel",
    "display_table",
    "display_markdown",
    "print_error",
    "print_warning",
    "print_success",
    "print_info",
    # Table utilities
    "create_model_table",
    "create_stats_table",
    "create_help_table",
    "display_paginated_table",
    # Prompt utilities
    "prompt_confirm",
    "prompt_choice",
    "prompt_input",
]
```
242  oai/ui/console.py  Normal file
@@ -0,0 +1,242 @@
```python
"""
Console utilities for oAI.

This module provides the Rich console instance and common display functions
for formatted terminal output.
"""

from typing import Any, Optional

from rich.console import Console
from rich.markdown import Markdown
from rich.panel import Panel
from rich.table import Table
from rich.text import Text

# Global console instance for the application
console = Console()


def clear_screen() -> None:
    """
    Clear the terminal screen.

    Uses ANSI escape codes for fast clearing, with a fallback
    for terminals that don't support them.
    """
    try:
        print("\033[H\033[J", end="", flush=True)
    except Exception:
        # Fallback: print many newlines
        print("\n" * 100)


def display_panel(
    content: Any,
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
    border_style: str = "green",
    title_align: str = "left",
    subtitle_align: str = "right",
) -> None:
    """
    Display content in a bordered panel.

    Args:
        content: Content to display (string, Table, or Markdown)
        title: Optional panel title
        subtitle: Optional panel subtitle
        border_style: Border color/style
        title_align: Title alignment ("left", "center", "right")
        subtitle_align: Subtitle alignment
    """
    panel = Panel(
        content,
        title=title,
        subtitle=subtitle,
        border_style=border_style,
        title_align=title_align,
        subtitle_align=subtitle_align,
    )
    console.print(panel)


def display_table(
    table: Table,
    title: Optional[str] = None,
    subtitle: Optional[str] = None,
) -> None:
    """
    Display a table with optional title panel.

    Args:
        table: Rich Table to display
        title: Optional panel title
        subtitle: Optional panel subtitle
    """
    if title:
        display_panel(table, title=title, subtitle=subtitle)
    else:
        console.print(table)


def display_markdown(
    content: str,
    panel: bool = False,
    title: Optional[str] = None,
) -> None:
    """
    Display markdown-formatted content.

    Args:
        content: Markdown text to display
        panel: Whether to wrap in a panel
        title: Optional panel title (if panel=True)
    """
    md = Markdown(content)
    if panel:
        display_panel(md, title=title)
    else:
        console.print(md)


def print_error(message: str, prefix: str = "Error:") -> None:
    """
    Print an error message in red.

    Args:
        message: Error message to display
        prefix: Prefix before the message (default: "Error:")
    """
    console.print(f"[bold red]{prefix}[/] {message}")


def print_warning(message: str, prefix: str = "Warning:") -> None:
    """
    Print a warning message in yellow.

    Args:
        message: Warning message to display
        prefix: Prefix before the message (default: "Warning:")
    """
    console.print(f"[bold yellow]{prefix}[/] {message}")


def print_success(message: str, prefix: str = "✓") -> None:
    """
    Print a success message in green.

    Args:
        message: Success message to display
        prefix: Prefix before the message (default: "✓")
    """
    console.print(f"[bold green]{prefix}[/] {message}")


def print_info(message: str, dim: bool = False) -> None:
    """
    Print an informational message in cyan.

    Args:
        message: Info message to display
        dim: Whether to dim the message
    """
    if dim:
        console.print(f"[dim cyan]{message}[/]")
    else:
        console.print(f"[bold cyan]{message}[/]")


def print_metrics(
    tokens: int,
    cost: float,
    time_seconds: float,
    context_info: str = "",
    online: bool = False,
    mcp_mode: Optional[str] = None,
    tool_loops: int = 0,
    session_tokens: int = 0,
    session_cost: float = 0.0,
) -> None:
    """
    Print formatted metrics for a response.

    Args:
        tokens: Total tokens used
        cost: Cost in USD
        time_seconds: Response time
        context_info: Context information string
        online: Whether online mode is active
        mcp_mode: MCP mode ("files", "database", or None)
        tool_loops: Number of tool call loops
        session_tokens: Total session tokens
        session_cost: Total session cost
    """
    parts = [
        f"📊 Metrics: {tokens} tokens",
        f"${cost:.4f}",
        f"{time_seconds:.2f}s",
    ]

    if context_info:
        parts.append(context_info)

    if online:
        parts.append("🌐")

    if mcp_mode == "files":
        parts.append("🔧")
    elif mcp_mode == "database":
        parts.append("🗄️")

    if tool_loops > 0:
        parts.append(f"({tool_loops} tool loop(s))")

    parts.append(f"Session: {session_tokens} tokens")
    parts.append(f"${session_cost:.4f}")

    console.print(f"\n[dim blue]{' | '.join(parts)}[/]")


def format_size(size_bytes: int) -> str:
    """
    Format a size in bytes to a human-readable string.

    Args:
        size_bytes: Size in bytes

    Returns:
        Formatted size string (e.g., "1.5 MB")
    """
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if abs(size_bytes) < 1024.0:
            return f"{size_bytes:.1f} {unit}"
        size_bytes /= 1024.0
    return f"{size_bytes:.1f} PB"


def format_tokens(tokens: int) -> str:
    """
    Format token count with thousands separators.

    Args:
        tokens: Number of tokens

    Returns:
        Formatted token string (e.g., "1,234,567")
    """
    return f"{tokens:,}"


def format_cost(cost: float, precision: int = 4) -> str:
    """
    Format cost in USD.

    Args:
        cost: Cost in dollars
        precision: Decimal places

    Returns:
        Formatted cost string (e.g., "$0.0123")
    """
    return f"${cost:.{precision}f}"
```
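A quick sketch of these helpers in use; all values are made up:

```python
# Demonstration values only; nothing here reflects real session data.
from oai.ui.console import format_size, format_tokens, print_metrics, print_success

print_success("Conversation saved")
print_metrics(
    tokens=1532,
    cost=0.0042,
    time_seconds=2.31,
    context_info=f"ctx {format_tokens(128000)}",
    mcp_mode="files",
    session_tokens=20480,
    session_cost=0.0611,
)
print(format_size(1536 * 1024))  # -> "1.5 MB"
```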
274  oai/ui/prompts.py  Normal file
@@ -0,0 +1,274 @@
"""
Prompt utilities for oAI.

This module provides functions for gathering user input
through confirmations, choices, and text prompts.
"""

from typing import Callable, List, Optional, TypeVar

import typer

from oai.ui.console import console

T = TypeVar("T")


def prompt_confirm(
    message: str,
    default: bool = False,
    abort: bool = False,
) -> bool:
    """
    Prompt the user for a yes/no confirmation.

    Args:
        message: The question to ask
        default: Default value if user presses Enter
        abort: Whether to abort on "no" response

    Returns:
        True if user confirms, False otherwise
    """
    try:
        return typer.confirm(message, default=default, abort=abort)
    except (EOFError, KeyboardInterrupt):
        console.print("\n[yellow]Cancelled[/]")
        return False


def prompt_choice(
    message: str,
    choices: List[str],
    default: Optional[str] = None,
) -> Optional[str]:
    """
    Prompt the user to select from a list of choices.

    Args:
        message: The question to ask
        choices: List of valid choices
        default: Default choice if user presses Enter

    Returns:
        Selected choice or None if cancelled
    """
    # Display choices
    console.print(f"\n[bold cyan]{message}[/]")
    for i, choice in enumerate(choices, 1):
        default_marker = " [default]" if choice == default else ""
        console.print(f"  {i}. {choice}{default_marker}")

    try:
        response = input("\nEnter number or value: ").strip()

        if not response and default:
            return default

        # Try as number first
        try:
            index = int(response) - 1
            if 0 <= index < len(choices):
                return choices[index]
        except ValueError:
            pass

        # Try as exact match
        if response in choices:
            return response

        # Try case-insensitive match
        response_lower = response.lower()
        for choice in choices:
            if choice.lower() == response_lower:
                return choice

        console.print(f"[red]Invalid choice: {response}[/]")
        return None

    except (EOFError, KeyboardInterrupt):
        console.print("\n[yellow]Cancelled[/]")
        return None


def prompt_input(
    message: str,
    default: Optional[str] = None,
    password: bool = False,
    required: bool = False,
) -> Optional[str]:
    """
    Prompt the user for text input.

    Args:
        message: The prompt message
        default: Default value if user presses Enter
        password: Whether to hide input (for sensitive data)
        required: Whether input is required (loops until provided)

    Returns:
        User input or default, None if cancelled
    """
    prompt_text = message
    if default:
        prompt_text += f" [{default}]"
    prompt_text += ": "

    try:
        while True:
            if password:
                import getpass

                response = getpass.getpass(prompt_text)
            else:
                response = input(prompt_text).strip()

            if not response:
                if default:
                    return default
                if required:
                    console.print("[yellow]Input required[/]")
                    continue
                return None

            return response

    except (EOFError, KeyboardInterrupt):
        console.print("\n[yellow]Cancelled[/]")
        return None


def prompt_number(
    message: str,
    min_value: Optional[int] = None,
    max_value: Optional[int] = None,
    default: Optional[int] = None,
) -> Optional[int]:
    """
    Prompt the user for a numeric input.

    Args:
        message: The prompt message
        min_value: Minimum allowed value
        max_value: Maximum allowed value
        default: Default value if user presses Enter

    Returns:
        Integer value or None if cancelled
    """
    prompt_text = message
    if default is not None:
        prompt_text += f" [{default}]"
    prompt_text += ": "

    try:
        while True:
            response = input(prompt_text).strip()

            if not response:
                if default is not None:
                    return default
                return None

            try:
                value = int(response)
            except ValueError:
                console.print("[red]Please enter a valid number[/]")
                continue

            if min_value is not None and value < min_value:
                console.print(f"[red]Value must be at least {min_value}[/]")
                continue

            if max_value is not None and value > max_value:
                console.print(f"[red]Value must be at most {max_value}[/]")
                continue

            return value

    except (EOFError, KeyboardInterrupt):
        console.print("\n[yellow]Cancelled[/]")
        return None


def prompt_selection(
    items: List[T],
    message: str = "Select an item",
    display_func: Optional[Callable[[T], str]] = None,
    allow_cancel: bool = True,
) -> Optional[T]:
    """
    Prompt the user to select an item from a list.

    Args:
        items: List of items to choose from
        message: The selection prompt
        display_func: Function to convert item to display string
        allow_cancel: Whether to allow cancellation

    Returns:
        Selected item or None if cancelled
    """
    if not items:
        console.print("[yellow]No items to select[/]")
        return None

    display = display_func or str

    console.print(f"\n[bold cyan]{message}[/]")
    for i, item in enumerate(items, 1):
        console.print(f"  {i}. {display(item)}")

    if allow_cancel:
        console.print("  0. Cancel")

    try:
        while True:
            response = input("\nEnter number: ").strip()

            try:
                index = int(response)
            except ValueError:
                console.print("[red]Please enter a valid number[/]")
                continue

            if allow_cancel and index == 0:
                return None

            if 1 <= index <= len(items):
                return items[index - 1]

            console.print(f"[red]Please enter a number between 1 and {len(items)}[/]")

    except (EOFError, KeyboardInterrupt):
        console.print("\n[yellow]Cancelled[/]")
        return None


def prompt_copy_response(response: str) -> bool:
    """
    Prompt user to copy a response to clipboard.

    Args:
        response: The response text

    Returns:
        True if copied, False otherwise
    """
    try:
        copy_choice = input("💾 Type 'c' to copy response, or press Enter to continue: ").strip().lower()
        if copy_choice == "c":
            try:
                import pyperclip

                pyperclip.copy(response)
                console.print("[bold green]✅ Response copied to clipboard![/]")
                return True
            except ImportError:
                console.print("[yellow]pyperclip not installed - cannot copy to clipboard[/]")
            except Exception as e:
                console.print(f"[red]Failed to copy: {e}[/]")
    except (EOFError, KeyboardInterrupt):
        pass

    return False
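A minimal usage sketch for the prompt helpers above, assuming `oai` is importable; the choices and limits are illustrative only.

```python
from oai.ui.prompts import prompt_choice, prompt_confirm, prompt_number

fmt = prompt_choice("Export format?", ["markdown", "json", "html"], default="markdown")
limit = prompt_number("Max tokens", min_value=1, max_value=128000, default=4096)
if fmt and prompt_confirm(f"Export as {fmt}?", default=True):
    print(f"Exporting with token limit {limit}...")
```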
373 oai/ui/tables.py Normal file
@@ -0,0 +1,373 @@
"""
Table utilities for oAI.

This module provides functions for creating and displaying
formatted tables with pagination support.
"""

import os
import sys
from typing import Any, Dict, List, Optional

from rich.panel import Panel
from rich.table import Table

from oai.ui.console import clear_screen, console


def create_model_table(
    models: List[Dict[str, Any]],
    show_capabilities: bool = True,
) -> Table:
    """
    Create a table displaying available AI models.

    Args:
        models: List of model dictionaries
        show_capabilities: Whether to show capability columns

    Returns:
        Rich Table with model information
    """
    if show_capabilities:
        table = Table(
            "No.",
            "Model ID",
            "Context",
            "Image",
            "Online",
            "Tools",
            show_header=True,
            header_style="bold magenta",
        )
    else:
        table = Table(
            "No.",
            "Model ID",
            "Context",
            show_header=True,
            header_style="bold magenta",
        )

    for i, model in enumerate(models, 1):
        model_id = model.get("id", "Unknown")
        context = model.get("context_length", 0)
        context_str = f"{context:,}" if context else "-"

        if show_capabilities:
            # Get modalities and parameters
            architecture = model.get("architecture", {})
            input_modalities = architecture.get("input_modalities", [])
            supported_params = model.get("supported_parameters", [])

            has_image = "✓" if "image" in input_modalities else "-"
            has_online = "✓" if "tools" in supported_params else "-"
            has_tools = "✓" if "tools" in supported_params or "functions" in supported_params else "-"

            table.add_row(
                str(i),
                model_id,
                context_str,
                has_image,
                has_online,
                has_tools,
            )
        else:
            table.add_row(str(i), model_id, context_str)

    return table


def create_stats_table(stats: Dict[str, Any]) -> Table:
    """
    Create a table displaying session statistics.

    Args:
        stats: Dictionary with statistics data

    Returns:
        Rich Table with stats
    """
    table = Table(
        "Metric",
        "Value",
        show_header=True,
        header_style="bold magenta",
    )

    # Token stats
    if "input_tokens" in stats:
        table.add_row("Input Tokens", f"{stats['input_tokens']:,}")
    if "output_tokens" in stats:
        table.add_row("Output Tokens", f"{stats['output_tokens']:,}")
    if "total_tokens" in stats:
        table.add_row("Total Tokens", f"{stats['total_tokens']:,}")

    # Cost stats
    if "total_cost" in stats:
        table.add_row("Total Cost", f"${stats['total_cost']:.4f}")
    if "avg_cost" in stats:
        table.add_row("Avg Cost/Message", f"${stats['avg_cost']:.4f}")

    # Message stats
    if "message_count" in stats:
        table.add_row("Messages", str(stats["message_count"]))

    # Credits
    if "credits_left" in stats:
        table.add_row("Credits Left", stats["credits_left"])

    return table


def create_help_table(commands: Dict[str, Dict[str, str]]) -> Table:
    """
    Create a help table for commands.

    Args:
        commands: Dictionary of command info

    Returns:
        Rich Table with command help
    """
    table = Table(
        "Command",
        "Description",
        "Example",
        show_header=True,
        header_style="bold magenta",
        show_lines=False,
    )

    for cmd, info in commands.items():
        description = info.get("description", "")
        example = info.get("example", "")
        table.add_row(cmd, description, example)

    return table


def create_folder_table(
    folders: List[Dict[str, Any]],
    gitignore_info: str = "",
) -> Table:
    """
    Create a table for MCP folder listing.

    Args:
        folders: List of folder dictionaries
        gitignore_info: Optional gitignore status info

    Returns:
        Rich Table with folder information
    """
    table = Table(
        "No.",
        "Path",
        "Files",
        "Size",
        show_header=True,
        header_style="bold magenta",
    )

    for folder in folders:
        number = str(folder.get("number", ""))
        path = folder.get("path", "")

        if folder.get("exists", True):
            files = f"📁 {folder.get('file_count', 0)}"
            size = f"{folder.get('size_mb', 0):.1f} MB"
        else:
            files = "[red]Not found[/red]"
            size = "-"

        table.add_row(number, path, files, size)

    return table


def create_database_table(databases: List[Dict[str, Any]]) -> Table:
    """
    Create a table for MCP database listing.

    Args:
        databases: List of database dictionaries

    Returns:
        Rich Table with database information
    """
    table = Table(
        "No.",
        "Name",
        "Tables",
        "Size",
        "Status",
        show_header=True,
        header_style="bold magenta",
    )

    for db in databases:
        number = str(db.get("number", ""))
        name = db.get("name", "")
        table_count = f"{db.get('table_count', 0)} tables"
        size = f"{db.get('size_mb', 0):.1f} MB"

        if db.get("warning"):
            status = f"[red]{db['warning']}[/red]"
        else:
            status = "[green]✓[/green]"

        table.add_row(number, name, table_count, size, status)

    return table


def display_paginated_table(
    table: Table,
    title: str,
    terminal_height: Optional[int] = None,
) -> None:
    """
    Display a table with pagination for large datasets.

    Allows navigating through pages with keyboard input.
    Press SPACE for next page, any other key to exit.

    Args:
        table: Rich Table to display
        title: Title for the table
        terminal_height: Override terminal height (auto-detected if None)
    """
    # Get terminal dimensions
    try:
        term_height = terminal_height or os.get_terminal_size().lines - 8
    except OSError:
        term_height = 20

    # Render table to segments
    from rich.segment import Segment

    segments = list(console.render(table))

    # Group segments into lines
    current_line_segments: List[Segment] = []
    all_lines: List[List[Segment]] = []

    for segment in segments:
        if segment.text == "\n":
            all_lines.append(current_line_segments)
            current_line_segments = []
        else:
            current_line_segments.append(segment)

    if current_line_segments:
        all_lines.append(current_line_segments)

    total_lines = len(all_lines)

    # If table fits in one screen, just display it
    if total_lines <= term_height:
        console.print(Panel(table, title=title, title_align="left"))
        return

    # Extract header and footer lines
    header_lines: List[List[Segment]] = []
    data_lines: List[List[Segment]] = []
    footer_line: List[Segment] = []

    # Find header end (line after the header text with border)
    header_end_index = 0
    found_header_text = False

    for i, line_segments in enumerate(all_lines):
        has_header_style = any(
            seg.style and ("bold" in str(seg.style) or "magenta" in str(seg.style))
            for seg in line_segments
        )

        if has_header_style:
            found_header_text = True

        if found_header_text and i > 0:
            line_text = "".join(seg.text for seg in line_segments)
            if any(char in line_text for char in ["─", "━", "┼", "╪", "┤", "├"]):
                header_end_index = i
                break

    # Extract footer (bottom border)
    if all_lines:
        last_line_text = "".join(seg.text for seg in all_lines[-1])
        if any(char in last_line_text for char in ["─", "━", "┴", "╧", "┘", "└"]):
            footer_line = all_lines[-1]
            all_lines = all_lines[:-1]

    # Split into header and data
    if header_end_index > 0:
        header_lines = all_lines[: header_end_index + 1]
        data_lines = all_lines[header_end_index + 1 :]
    else:
        header_lines = all_lines[: min(3, len(all_lines))]
        data_lines = all_lines[min(3, len(all_lines)) :]

    lines_per_page = term_height - len(header_lines)
    current_line = 0
    page_number = 1

    # Paginate
    while current_line < len(data_lines):
        clear_screen()
        console.print(f"[bold cyan]{title} (Page {page_number})[/]")

        # Print header
        for line_segments in header_lines:
            for segment in line_segments:
                console.print(segment.text, style=segment.style, end="")
            console.print()

        # Print data rows for this page
        end_line = min(current_line + lines_per_page, len(data_lines))
        for line_segments in data_lines[current_line:end_line]:
            for segment in line_segments:
                console.print(segment.text, style=segment.style, end="")
            console.print()

        # Print footer
        if footer_line:
            for segment in footer_line:
                console.print(segment.text, style=segment.style, end="")
            console.print()

        current_line = end_line
        page_number += 1

        # Prompt for next page
        if current_line < len(data_lines):
            console.print(
                f"\n[dim yellow]--- Press SPACE for next page, "
                f"or any other key to finish (Page {page_number - 1}, "
                f"showing {end_line}/{len(data_lines)} data rows) ---[/dim yellow]"
            )

            try:
                import termios
                import tty

                fd = sys.stdin.fileno()
                old_settings = termios.tcgetattr(fd)

                try:
                    tty.setraw(fd)
                    char = sys.stdin.read(1)
                    if char != " ":
                        break
                finally:
                    termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)

            except (ImportError, OSError, AttributeError):
                # Fallback for non-Unix systems
                try:
                    user_input = input()
                    if user_input.strip():
                        break
                except (EOFError, KeyboardInterrupt):
                    break
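A minimal sketch showing the table builders in use. The model dictionary below mirrors the fields `create_model_table` reads from the OpenRouter models payload; the values are illustrative.

```python
from oai.ui.tables import create_model_table, display_paginated_table

models = [
    {
        "id": "anthropic/claude-3.5-sonnet",
        "context_length": 200000,
        "architecture": {"input_modalities": ["text", "image"]},
        "supported_parameters": ["tools"],
    },
]
display_paginated_table(create_model_table(models), title="Available Models")
```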
20 oai/utils/__init__.py Normal file
@@ -0,0 +1,20 @@
"""
Utility modules for oAI.

This package provides common utilities used throughout the application
including logging, file handling, and export functionality.
"""

from oai.utils.logging import setup_logging, get_logger
from oai.utils.files import read_file_safe, is_binary_file
from oai.utils.export import export_as_markdown, export_as_json, export_as_html

__all__ = [
    "setup_logging",
    "get_logger",
    "read_file_safe",
    "is_binary_file",
    "export_as_markdown",
    "export_as_json",
    "export_as_html",
]
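With these re-exports, callers can import the common utilities from the package root rather than the individual submodules:

```python
# Equivalent to importing from oai.utils.logging / .files / .export directly.
from oai.utils import setup_logging, read_file_safe, export_as_markdown
```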
248 oai/utils/export.py Normal file
@@ -0,0 +1,248 @@
"""
Export utilities for oAI.

This module provides functions for exporting conversation history
in various formats including Markdown, JSON, and HTML.
"""

import json
import datetime
from typing import List, Dict
from html import escape as html_escape

from oai.constants import APP_VERSION, APP_URL


def export_as_markdown(
    session_history: List[Dict[str, str]],
    session_system_prompt: str = ""
) -> str:
    """
    Export conversation history as Markdown.

    Args:
        session_history: List of message dictionaries with 'prompt' and 'response'
        session_system_prompt: Optional system prompt to include

    Returns:
        Markdown formatted string
    """
    lines = ["# Conversation Export", ""]

    if session_system_prompt:
        lines.extend([f"**System Prompt:** {session_system_prompt}", ""])

    lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    lines.append(f"**Messages:** {len(session_history)}")
    lines.append("")
    lines.append("---")
    lines.append("")

    for i, entry in enumerate(session_history, 1):
        lines.append(f"## Message {i}")
        lines.append("")
        lines.append("**User:**")
        lines.append("")
        lines.append(entry.get("prompt", ""))
        lines.append("")
        lines.append("**Assistant:**")
        lines.append("")
        lines.append(entry.get("response", ""))
        lines.append("")
        lines.append("---")
        lines.append("")

    lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*")

    return "\n".join(lines)


def export_as_json(
    session_history: List[Dict[str, str]],
    session_system_prompt: str = ""
) -> str:
    """
    Export conversation history as JSON.

    Args:
        session_history: List of message dictionaries
        session_system_prompt: Optional system prompt to include

    Returns:
        JSON formatted string
    """
    export_data = {
        "export_date": datetime.datetime.now().isoformat(),
        "app_version": APP_VERSION,
        "system_prompt": session_system_prompt,
        "message_count": len(session_history),
        "messages": [
            {
                "index": i + 1,
                "prompt": entry.get("prompt", ""),
                "response": entry.get("response", ""),
                "prompt_tokens": entry.get("prompt_tokens", 0),
                "completion_tokens": entry.get("completion_tokens", 0),
                "cost": entry.get("msg_cost", 0.0),
            }
            for i, entry in enumerate(session_history)
        ],
        "totals": {
            "prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history),
            "completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history),
            "total_cost": sum(e.get("msg_cost", 0.0) for e in session_history),
        }
    }
    return json.dumps(export_data, indent=2, ensure_ascii=False)


def export_as_html(
    session_history: List[Dict[str, str]],
    session_system_prompt: str = ""
) -> str:
    """
    Export conversation history as styled HTML.

    Args:
        session_history: List of message dictionaries
        session_system_prompt: Optional system prompt to include

    Returns:
        HTML formatted string with embedded CSS
    """
    html_parts = [
        "<!DOCTYPE html>",
        "<html>",
        "<head>",
        "  <meta charset='UTF-8'>",
        "  <meta name='viewport' content='width=device-width, initial-scale=1.0'>",
        "  <title>Conversation Export - oAI</title>",
        "  <style>",
        "    * { box-sizing: border-box; }",
        "    body {",
        "      font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;",
        "      max-width: 900px;",
        "      margin: 40px auto;",
        "      padding: 20px;",
        "      background: #f5f5f5;",
        "      color: #333;",
        "    }",
        "    .header {",
        "      background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);",
        "      color: white;",
        "      padding: 30px;",
        "      border-radius: 10px;",
        "      margin-bottom: 30px;",
        "      box-shadow: 0 4px 6px rgba(0,0,0,0.1);",
        "    }",
        "    .header h1 {",
        "      margin: 0 0 10px 0;",
        "      font-size: 2em;",
        "    }",
        "    .export-info {",
        "      opacity: 0.9;",
        "      font-size: 0.95em;",
        "      margin: 5px 0;",
        "    }",
        "    .system-prompt {",
        "      background: #fff3cd;",
        "      padding: 20px;",
        "      border-radius: 8px;",
        "      margin-bottom: 25px;",
        "      border-left: 5px solid #ffc107;",
        "      box-shadow: 0 2px 4px rgba(0,0,0,0.05);",
        "    }",
        "    .system-prompt strong {",
        "      color: #856404;",
        "      display: block;",
        "      margin-bottom: 10px;",
        "      font-size: 1.1em;",
        "    }",
        "    .message-container { margin-bottom: 20px; }",
        "    .message {",
        "      background: white;",
        "      padding: 20px;",
        "      border-radius: 8px;",
        "      box-shadow: 0 2px 4px rgba(0,0,0,0.08);",
        "      margin-bottom: 12px;",
        "    }",
        "    .user-message { border-left: 5px solid #10b981; }",
        "    .assistant-message { border-left: 5px solid #3b82f6; }",
        "    .role {",
        "      font-weight: bold;",
        "      margin-bottom: 12px;",
        "      font-size: 1.05em;",
        "      text-transform: uppercase;",
        "      letter-spacing: 0.5px;",
        "    }",
        "    .user-role { color: #10b981; }",
        "    .assistant-role { color: #3b82f6; }",
        "    .content {",
        "      line-height: 1.8;",
        "      white-space: pre-wrap;",
        "      color: #333;",
        "    }",
        "    .message-number {",
        "      color: #6b7280;",
        "      font-size: 0.85em;",
        "      margin-bottom: 15px;",
        "      font-weight: 600;",
        "    }",
        "    .footer {",
        "      text-align: center;",
        "      margin-top: 40px;",
        "      padding: 20px;",
        "      color: #6b7280;",
        "      font-size: 0.9em;",
        "    }",
        "    .footer a { color: #667eea; text-decoration: none; }",
        "    .footer a:hover { text-decoration: underline; }",
        "    @media print {",
        "      body { background: white; }",
        "      .message { break-inside: avoid; }",
        "    }",
        "  </style>",
        "</head>",
        "<body>",
        "  <div class='header'>",
        "    <h1>Conversation Export</h1>",
        f"    <div class='export-info'>Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</div>",
        f"    <div class='export-info'>Total Messages: {len(session_history)}</div>",
        "  </div>",
    ]

    if session_system_prompt:
        html_parts.extend([
            "  <div class='system-prompt'>",
            "    <strong>System Prompt</strong>",
            f"    <div>{html_escape(session_system_prompt)}</div>",
            "  </div>",
        ])

    for i, entry in enumerate(session_history, 1):
        prompt = html_escape(entry.get("prompt", ""))
        response = html_escape(entry.get("response", ""))

        html_parts.extend([
            "  <div class='message-container'>",
            f"    <div class='message-number'>Message {i} of {len(session_history)}</div>",
            "    <div class='message user-message'>",
            "      <div class='role user-role'>User</div>",
            f"      <div class='content'>{prompt}</div>",
            "    </div>",
            "    <div class='message assistant-message'>",
            "      <div class='role assistant-role'>Assistant</div>",
            f"      <div class='content'>{response}</div>",
            "    </div>",
            "  </div>",
        ])

    html_parts.extend([
        "  <div class='footer'>",
        f"    <p>Generated by oAI v{APP_VERSION} • <a href='{APP_URL}'>{APP_URL}</a></p>",
        "  </div>",
        "</body>",
        "</html>",
    ])

    return "\n".join(html_parts)
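A minimal sketch of the exporters in use; the history entry below carries exactly the keys the functions above read, with illustrative values.

```python
from pathlib import Path
from oai.utils.export import export_as_markdown, export_as_json, export_as_html

history = [{"prompt": "Hello", "response": "Hi there!", "prompt_tokens": 3,
            "completion_tokens": 4, "msg_cost": 0.0001}]

Path("chat.md").write_text(export_as_markdown(history), encoding="utf-8")
Path("chat.json").write_text(export_as_json(history), encoding="utf-8")
Path("chat.html").write_text(export_as_html(history, "Be concise."), encoding="utf-8")
```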
323 oai/utils/files.py Normal file
@@ -0,0 +1,323 @@
"""
File handling utilities for oAI.

This module provides safe file reading, type detection, and other
file-related operations used throughout the application.
"""

import os
import mimetypes
import base64
from pathlib import Path
from typing import Optional, Dict, Any, Tuple

from oai.constants import (
    MAX_FILE_SIZE,
    CONTENT_TRUNCATION_THRESHOLD,
    SUPPORTED_CODE_EXTENSIONS,
    ALLOWED_FILE_EXTENSIONS,
)
from oai.utils.logging import get_logger


def is_binary_file(file_path: Path) -> bool:
    """
    Check if a file appears to be binary.

    Args:
        file_path: Path to the file to check

    Returns:
        True if the file appears to be binary, False otherwise
    """
    try:
        with open(file_path, "rb") as f:
            # Read first 8KB to check for binary content
            chunk = f.read(8192)
            # Check for null bytes (common in binary files)
            if b"\x00" in chunk:
                return True
            # Try to decode as UTF-8
            try:
                chunk.decode("utf-8")
                return False
            except UnicodeDecodeError:
                return True
    except Exception:
        return True


def get_file_type(file_path: Path) -> Tuple[Optional[str], str]:
    """
    Determine the MIME type and category of a file.

    Args:
        file_path: Path to the file

    Returns:
        Tuple of (mime_type, category) where category is one of:
        'image', 'pdf', 'code', 'text', 'binary', 'unknown'
    """
    mime_type, _ = mimetypes.guess_type(str(file_path))
    ext = file_path.suffix.lower()

    if mime_type and mime_type.startswith("image/"):
        return mime_type, "image"
    elif mime_type == "application/pdf" or ext == ".pdf":
        return mime_type or "application/pdf", "pdf"
    elif ext in SUPPORTED_CODE_EXTENSIONS:
        return mime_type or "text/plain", "code"
    elif mime_type and mime_type.startswith("text/"):
        return mime_type, "text"
    elif is_binary_file(file_path):
        return mime_type, "binary"
    else:
        return mime_type, "unknown"


def read_file_safe(
    file_path: Path,
    max_size: int = MAX_FILE_SIZE,
    truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD
) -> Dict[str, Any]:
    """
    Safely read a file with size limits and truncation support.

    Args:
        file_path: Path to the file to read
        max_size: Maximum file size to read (bytes)
        truncate_threshold: Threshold for truncating large files

    Returns:
        Dictionary containing:
        - content: File content (text or base64)
        - size: File size in bytes
        - truncated: Whether content was truncated
        - encoding: 'text', 'base64', or None on error
        - error: Error message if reading failed
    """
    logger = get_logger()

    try:
        path = Path(file_path).resolve()

        if not path.exists():
            return {
                "content": None,
                "size": 0,
                "truncated": False,
                "encoding": None,
                "error": f"File not found: {path}"
            }

        if not path.is_file():
            return {
                "content": None,
                "size": 0,
                "truncated": False,
                "encoding": None,
                "error": f"Not a file: {path}"
            }

        file_size = path.stat().st_size

        if file_size > max_size:
            return {
                "content": None,
                "size": file_size,
                "truncated": False,
                "encoding": None,
                "error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)"
            }

        # Try to read as text first
        try:
            content = path.read_text(encoding="utf-8")

            # Check if truncation is needed
            if file_size > truncate_threshold:
                lines = content.split("\n")
                total_lines = len(lines)

                # Keep first 500 lines and last 100 lines
                head_lines = 500
                tail_lines = 100

                if total_lines > (head_lines + tail_lines):
                    truncated_content = (
                        "\n".join(lines[:head_lines]) +
                        f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" +
                        "\n".join(lines[-tail_lines:])
                    )
                    logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)")
                    return {
                        "content": truncated_content,
                        "size": file_size,
                        "truncated": True,
                        "total_lines": total_lines,
                        "lines_shown": head_lines + tail_lines,
                        "encoding": "text",
                        "error": None
                    }

            logger.info(f"Read file: {path} ({file_size} bytes)")
            return {
                "content": content,
                "size": file_size,
                "truncated": False,
                "encoding": "text",
                "error": None
            }

        except UnicodeDecodeError:
            # File is binary, return base64 encoded
            with open(path, "rb") as f:
                binary_data = f.read()
            b64_content = base64.b64encode(binary_data).decode("utf-8")
            logger.info(f"Read binary file: {path} ({file_size} bytes)")
            return {
                "content": b64_content,
                "size": file_size,
                "truncated": False,
                "encoding": "base64",
                "error": None
            }

    except PermissionError as e:
        return {
            "content": None,
            "size": 0,
            "truncated": False,
            "encoding": None,
            "error": f"Permission denied: {e}"
        }
    except Exception as e:
        logger.error(f"Error reading file {file_path}: {e}")
        return {
            "content": None,
            "size": 0,
            "truncated": False,
            "encoding": None,
            "error": str(e)
        }


def get_file_extension(file_path: Path) -> str:
    """
    Get the lowercase file extension.

    Args:
        file_path: Path to the file

    Returns:
        Lowercase extension including the dot (e.g., '.py')
    """
    return file_path.suffix.lower()


def is_allowed_extension(file_path: Path) -> bool:
    """
    Check if a file has an allowed extension for attachment.

    Args:
        file_path: Path to the file

    Returns:
        True if the extension is allowed, False otherwise
    """
    return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS


def format_file_size(size_bytes: int) -> str:
    """
    Format a file size in human-readable format.

    Args:
        size_bytes: Size in bytes

    Returns:
        Formatted string (e.g., '1.5 MB', '512 KB')
    """
    for unit in ["B", "KB", "MB", "GB", "TB"]:
        if abs(size_bytes) < 1024:
            return f"{size_bytes:.1f} {unit}"
        size_bytes /= 1024
    return f"{size_bytes:.1f} PB"


def prepare_file_attachment(
    file_path: Path,
    model_capabilities: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
    """
    Prepare a file for attachment to an API request.

    Args:
        file_path: Path to the file
        model_capabilities: Model capability information

    Returns:
        Content block dictionary for the API, or None if unsupported
    """
    logger = get_logger()
    path = Path(file_path).resolve()

    if not path.exists():
        logger.warning(f"File not found: {path}")
        return None

    mime_type, category = get_file_type(path)
    file_size = path.stat().st_size

    if file_size > MAX_FILE_SIZE:
        logger.warning(f"File too large: {path} ({format_file_size(file_size)})")
        return None

    try:
        with open(path, "rb") as f:
            file_data = f.read()

        if category == "image":
            # Check if model supports images
            input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
            if "image" not in input_modalities:
                logger.warning("Model does not support images")
                return None

            b64_data = base64.b64encode(file_data).decode("utf-8")
            return {
                "type": "image_url",
                "image_url": {"url": f"data:{mime_type};base64,{b64_data}"}
            }

        elif category == "pdf":
            # Check if model supports PDFs
            input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
            supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
            if not supports_pdf:
                logger.warning("Model does not support PDFs")
                return None

            b64_data = base64.b64encode(file_data).decode("utf-8")
            return {
                "type": "image_url",
                "image_url": {"url": f"data:application/pdf;base64,{b64_data}"}
            }

        elif category in ("code", "text"):
            text_content = file_data.decode("utf-8")
            return {
                "type": "text",
                "text": f"File: {path.name}\n\n{text_content}"
            }

        else:
            logger.warning(f"Unsupported file type: {category} ({mime_type})")
            return None

    except UnicodeDecodeError:
        logger.error(f"Cannot decode file as UTF-8: {path}")
        return None
    except Exception as e:
        logger.error(f"Error preparing file attachment {path}: {e}")
        return None
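A minimal sketch for the safe-read path; the file name is illustrative. `read_file_safe` never raises on expected failures, so callers branch on the `error` key.

```python
from pathlib import Path
from oai.utils.files import read_file_safe, format_file_size

result = read_file_safe(Path("README.md"))
if result["error"]:
    print(f"Failed: {result['error']}")
else:
    print(f"Read {format_file_size(result['size'])}, truncated={result['truncated']}")
```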
297 oai/utils/logging.py Normal file
@@ -0,0 +1,297 @@
"""
Logging configuration for oAI.

This module provides centralized logging setup with Rich formatting,
file rotation, and configurable log levels.
"""

import io
import os
import glob
import logging
import datetime
import shutil
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional

from rich.console import Console
from rich.logging import RichHandler

from oai.constants import (
    LOG_FILE,
    CONFIG_DIR,
    DEFAULT_LOG_MAX_SIZE_MB,
    DEFAULT_LOG_BACKUP_COUNT,
    DEFAULT_LOG_LEVEL,
    VALID_LOG_LEVELS,
)


class RotatingRichHandler(RotatingFileHandler):
    """
    Custom log handler combining file rotation with Rich formatting.

    This handler writes Rich-formatted log output to a rotating file,
    providing colored and formatted logs even in file output while
    managing file size and backups automatically.
    """

    def __init__(self, *args, **kwargs):
        """Initialize the handler with a Rich console for formatting."""
        super().__init__(*args, **kwargs)
        # Create an internal console for Rich formatting
        self.rich_console = Console(
            file=io.StringIO(),
            width=120,
            force_terminal=False
        )
        self.rich_handler = RichHandler(
            console=self.rich_console,
            show_time=True,
            show_path=True,
            rich_tracebacks=True,
            tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"]
        )

    def emit(self, record: logging.LogRecord) -> None:
        """
        Emit a log record with Rich formatting.

        Args:
            record: The log record to emit
        """
        try:
            # Format with Rich
            self.rich_handler.emit(record)
            output = self.rich_console.file.getvalue()
            self.rich_console.file.seek(0)
            self.rich_console.file.truncate(0)

            if output:
                self.stream.write(output)
                self.flush()
        except Exception:
            self.handleError(record)


class LoggingManager:
    """
    Manages application logging configuration.

    Provides methods to setup, configure, and manage logging with
    support for runtime reconfiguration and level changes.
    """

    def __init__(self):
        """Initialize the logging manager."""
        self.handler: Optional[RotatingRichHandler] = None
        self.app_logger: Optional[logging.Logger] = None
        self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
        self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT
        self.log_level: str = DEFAULT_LOG_LEVEL

    def setup(
        self,
        max_size_mb: Optional[int] = None,
        backup_count: Optional[int] = None,
        log_level: Optional[str] = None
    ) -> logging.Logger:
        """
        Setup or reconfigure logging.

        Args:
            max_size_mb: Maximum log file size in MB
            backup_count: Number of backup files to keep
            log_level: Logging level string

        Returns:
            The configured application logger
        """
        # Update configuration if provided
        if max_size_mb is not None:
            self.max_size_mb = max_size_mb
        if backup_count is not None:
            self.backup_count = backup_count
        if log_level is not None:
            self.log_level = log_level

        # Ensure config directory exists
        CONFIG_DIR.mkdir(parents=True, exist_ok=True)

        # Get root logger
        root_logger = logging.getLogger()

        # Remove existing handler if present
        if self.handler is not None:
            root_logger.removeHandler(self.handler)
            try:
                self.handler.close()
            except Exception:
                pass

        # Check if log needs manual rotation
        self._check_rotation()

        # Create new handler
        max_bytes = self.max_size_mb * 1024 * 1024
        self.handler = RotatingRichHandler(
            filename=str(LOG_FILE),
            maxBytes=max_bytes,
            backupCount=self.backup_count,
            encoding="utf-8"
        )

        self.handler.setLevel(logging.NOTSET)
        root_logger.setLevel(logging.WARNING)
        root_logger.addHandler(self.handler)

        # Suppress noisy third-party loggers
        for logger_name in [
            "asyncio", "urllib3", "requests", "httpx",
            "httpcore", "openai", "openrouter"
        ]:
            logging.getLogger(logger_name).setLevel(logging.WARNING)

        # Configure application logger
        self.app_logger = logging.getLogger("oai_app")
        level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO)
        self.app_logger.setLevel(level)
        self.app_logger.propagate = True

        return self.app_logger

    def _check_rotation(self) -> None:
        """Check if the log file needs rotation and rotate if necessary."""
        if not LOG_FILE.exists():
            return

        current_size = LOG_FILE.stat().st_size
        max_bytes = self.max_size_mb * 1024 * 1024

        if current_size >= max_bytes:
            # Perform manual rotation
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            backup_file = f"{LOG_FILE}.{timestamp}"

            try:
                shutil.move(str(LOG_FILE), backup_file)
            except Exception:
                pass

            # Clean old backups
            self._cleanup_old_backups()

    def _cleanup_old_backups(self) -> None:
        """Remove old backup files exceeding the backup count."""
        log_dir = LOG_FILE.parent
        backup_pattern = f"{LOG_FILE.name}.*"
        backups = sorted(glob.glob(str(log_dir / backup_pattern)))

        while len(backups) > self.backup_count:
            oldest = backups.pop(0)
            try:
                os.remove(oldest)
            except Exception:
                pass

    def set_level(self, level: str) -> bool:
        """
        Set the application log level.

        Args:
            level: Log level string (debug/info/warning/error/critical)

        Returns:
            True if level was set successfully, False otherwise
        """
        level_lower = level.lower()
        if level_lower not in VALID_LOG_LEVELS:
            return False

        self.log_level = level_lower
        if self.app_logger:
            self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower])

        return True

    def get_logger(self) -> logging.Logger:
        """
        Get the application logger, initializing if necessary.

        Returns:
            The application logger
        """
        if self.app_logger is None:
            self.setup()
        return self.app_logger


# Global logging manager instance
_logging_manager = LoggingManager()


def setup_logging(
    max_size_mb: Optional[int] = None,
    backup_count: Optional[int] = None,
    log_level: Optional[str] = None
) -> logging.Logger:
    """
    Setup application logging.

    This is the main entry point for configuring logging. Call this
    early in application startup.

    Args:
        max_size_mb: Maximum log file size in MB
        backup_count: Number of backup files to keep
        log_level: Logging level string

    Returns:
        The configured application logger
    """
    return _logging_manager.setup(max_size_mb, backup_count, log_level)


def get_logger() -> logging.Logger:
    """
    Get the application logger.

    Returns:
        The application logger instance
    """
    return _logging_manager.get_logger()


def set_log_level(level: str) -> bool:
    """
    Set the application log level.

    Args:
        level: Log level string

    Returns:
        True if successful, False otherwise
    """
    return _logging_manager.set_level(level)


def reload_logging(
    max_size_mb: Optional[int] = None,
    backup_count: Optional[int] = None,
    log_level: Optional[str] = None
) -> logging.Logger:
    """
    Reload logging configuration.

    Useful when settings change at runtime.

    Args:
        max_size_mb: New maximum log file size
        backup_count: New backup count
        log_level: New log level

    Returns:
        The reconfigured logger
    """
    return _logging_manager.setup(max_size_mb, backup_count, log_level)
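A minimal startup sketch: configure logging once at launch, then fetch the shared logger anywhere; the size, count, and level values are illustrative, not project defaults.

```python
from oai.utils.logging import setup_logging, get_logger, set_log_level

setup_logging(max_size_mb=5, backup_count=3, log_level="info")
log = get_logger()
log.info("application started")
set_log_level("debug")  # raise verbosity at runtime
```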
134 pyproject.toml Normal file
@@ -0,0 +1,134 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "oai"
version = "2.1.0"
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
readme = "README.md"
license = {text = "MIT"}
authors = [
    {name = "Rune", email = "rune@example.com"}
]
maintainers = [
    {name = "Rune", email = "rune@example.com"}
]
keywords = [
    "ai",
    "chat",
    "openrouter",
    "cli",
    "terminal",
    "mcp",
    "llm",
]
classifiers = [
    "Development Status :: 4 - Beta",
    "Environment :: Console",
    "Intended Audience :: Developers",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.10",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering :: Artificial Intelligence",
    "Topic :: Utilities",
]
requires-python = ">=3.10"
dependencies = [
    "anyio>=4.0.0",
    "click>=8.0.0",
    "httpx>=0.24.0",
    "markdown-it-py>=3.0.0",
    "openrouter>=0.0.19",
    "packaging>=21.0",
    "prompt-toolkit>=3.0.0",
    "pyperclip>=1.8.0",
    "requests>=2.28.0",
    "rich>=13.0.0",
    "typer>=0.9.0",
    "mcp>=1.0.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "pytest-asyncio>=0.21.0",
    "pytest-cov>=4.0.0",
    "black>=23.0.0",
    "isort>=5.12.0",
    "mypy>=1.0.0",
    "ruff>=0.1.0",
]

[project.urls]
Homepage = "https://iurl.no/oai"
Repository = "https://gitlab.pm/rune/oai"
Documentation = "https://iurl.no/oai"
"Bug Tracker" = "https://gitlab.pm/rune/oai/issues"

[project.scripts]
oai = "oai.cli:main"

[tool.setuptools]
packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.ui", "oai.utils"]

[tool.setuptools.package-data]
oai = ["py.typed"]

[tool.black]
line-length = 100
target-version = ["py310", "py311", "py312"]
include = '\.pyi?$'
exclude = '''
/(
    \.git
  | \.mypy_cache
  | \.pytest_cache
  | \.venv
  | build
  | dist
)/
'''

[tool.isort]
profile = "black"
line_length = 100
skip_gitignore = true

[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
exclude = [
    "build",
    "dist",
    ".venv",
]

[tool.ruff]
line-length = 100
target-version = "py310"
select = [
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "F",   # Pyflakes
    "I",   # isort
    "B",   # flake8-bugbear
    "C4",  # flake8-comprehensions
    "UP",  # pyupgrade
]
ignore = [
    "E501",  # line too long (handled by black)
    "B008",  # do not perform function calls in argument defaults
    "C901",  # too complex
]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
asyncio_mode = "auto"
addopts = "-v --tb=short"