Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a0ed0eaaf0 | |||
| 9a76ce2c1f | |||
| 106ef676e2 | |||
| 8c49339452 | |||
| 8a113a9bbe | |||
| c305d5cf49 | |||
| 2c9f33868e | |||
| d4f1a1c6a4 | |||
| 2e7c49bf68 |
7
.gitignore
vendored
7
.gitignore
vendored
@@ -23,9 +23,6 @@ Pipfile.lock # Consider if you want to include or exclude
|
|||||||
*~.nib
|
*~.nib
|
||||||
*~.xib
|
*~.xib
|
||||||
|
|
||||||
# Claude Code local settings
|
|
||||||
.claude/
|
|
||||||
|
|
||||||
# Added by author
|
# Added by author
|
||||||
*.zip
|
*.zip
|
||||||
.note
|
.note
|
||||||
@@ -42,7 +39,3 @@ b0.sh
|
|||||||
*.old
|
*.old
|
||||||
*.sh
|
*.sh
|
||||||
*.back
|
*.back
|
||||||
requirements.txt
|
|
||||||
system_prompt.txt
|
|
||||||
CLAUDE*
|
|
||||||
SESSION*_COMPLETE.md
|
|
||||||
|
|||||||
686
README.md
686
README.md
@@ -1,326 +1,584 @@
|
|||||||
# oAI - OpenRouter AI Chat Client
|
# oAI - OpenRouter AI Chat
|
||||||
|
|
||||||
A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
|
A powerful terminal-based chat interface for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI agents to access local files and query SQLite databases directly.
|
||||||
|
|
||||||
|
## Description
|
||||||
|
|
||||||
|
oAI is a feature-rich command-line chat application that provides an interactive interface to OpenRouter's AI models. It now includes **MCP integration** for local file system access and read-only database querying, allowing AI to help with code analysis, data exploration, and more.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
### Core Features
|
### Core Features
|
||||||
- 🖥️ **Modern Textual TUI** with async streaming and beautiful interface
|
|
||||||
- 🤖 Interactive chat with 300+ AI models via OpenRouter
|
- 🤖 Interactive chat with 300+ AI models via OpenRouter
|
||||||
- 🔍 Model selection with search, filtering, and capability icons
|
- 🔍 Model selection with search and capability filtering
|
||||||
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
|
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
|
||||||
- 📎 File attachments (images, PDFs, code files)
|
- 📎 File attachment support (images, PDFs, code files)
|
||||||
- 💰 Real-time cost tracking and credit monitoring
|
- 💰 Session cost tracking and credit monitoring
|
||||||
- 🎨 Dark theme with syntax highlighting and Markdown rendering
|
- 🎨 Rich terminal formatting with syntax highlighting
|
||||||
- 📝 Command history navigation (Up/Down arrows)
|
- 📝 Persistent command history with search (Ctrl+R)
|
||||||
|
- ⚙️ Configurable system prompts and token limits
|
||||||
|
- 🗄️ SQLite-based configuration and conversation storage
|
||||||
- 🌐 Online mode (web search capabilities)
|
- 🌐 Online mode (web search capabilities)
|
||||||
- 🧠 Conversation memory toggle
|
- 🧠 Conversation memory toggle (save costs with stateless mode)
|
||||||
- ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats)
|
|
||||||
|
|
||||||
### MCP Integration
|
### NEW: MCP (Model Context Protocol) v2.1.0-beta
|
||||||
- 🔧 **File Mode**: AI can read, search, and list local files
|
- 🔧 **File Mode**: AI can read, search, and list your local files
|
||||||
- Automatic .gitignore filtering
|
- Automatic .gitignore filtering
|
||||||
- Virtual environment exclusion
|
- Virtual environment exclusion (venv, node_modules, etc.)
|
||||||
|
- Supports code files, text, JSON, YAML, and more
|
||||||
- Large file handling (auto-truncates >50KB)
|
- Large file handling (auto-truncates >50KB)
|
||||||
|
|
||||||
|
- ✍️ **Write Mode** (NEW!): AI can modify files with your permission
|
||||||
|
- Create and edit files within allowed folders
|
||||||
|
- Delete files (always requires confirmation)
|
||||||
|
- Move, copy, and organize files
|
||||||
|
- Create directories
|
||||||
|
- Ignores .gitignore for write operations
|
||||||
|
- OFF by default - explicit opt-in required
|
||||||
|
|
||||||
|
- 🗄️ **Database Mode**: AI can query your SQLite databases
|
||||||
|
- Read-only access (no data modification possible)
|
||||||
|
- Schema inspection (tables, columns, indexes)
|
||||||
|
- Full-text search across all tables
|
||||||
|
- SQL query execution (SELECT, JOINs, CTEs, subqueries)
|
||||||
|
- Query validation and timeout protection
|
||||||
|
- Result limiting (max 1000 rows)
|
||||||
|
|
||||||
- ✍️ **Write Mode**: AI can modify files with permission
|
- 🔒 **Security Features**:
|
||||||
- Create, edit, delete files
|
- Explicit folder/database approval required
|
||||||
- Move, copy, organize files
|
- System directory blocking
|
||||||
- Always requires explicit opt-in
|
- Write mode OFF by default (non-persistent)
|
||||||
|
- Delete operations always require user confirmation
|
||||||
- 🗄️ **Database Mode**: AI can query SQLite databases
|
- Read-only database access
|
||||||
- Read-only access (safe)
|
- SQL injection protection
|
||||||
- Schema inspection
|
- Query timeout (5 seconds)
|
||||||
- Full SQL query support
|
|
||||||
|
|
||||||
## Requirements
|
## Requirements
|
||||||
|
|
||||||
- Python 3.10-3.13
|
- Python 3.10-3.13 (3.14 not supported yet)
|
||||||
- OpenRouter API key ([get one here](https://openrouter.ai))
|
- OpenRouter API key (get one at https://openrouter.ai)
|
||||||
|
- Function-calling model required for MCP features (GPT-4, Claude, etc.)
|
||||||
|
|
||||||
|
## Screenshot
|
||||||
|
|
||||||
|
[<img src="https://gitlab.pm/rune/oai/raw/branch/main/images/screenshot_01.png">](https://gitlab.pm/rune/oai/src/branch/main/README.md)
|
||||||
|
|
||||||
|
*Screenshot from version 1.0 - MCP interface shows mode indicators like `[🔧 MCP: Files]` or `[🗄️ MCP: DB #1]`*
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
### Option 1: Pre-built Binary (macOS/Linux) (Recommended)
|
### Option 1: From Source (Recommended for Development)
|
||||||
|
|
||||||
Download from [Releases](https://gitlab.pm/rune/oai/releases):
|
#### 1. Install Dependencies
|
||||||
- **macOS (Apple Silicon)**: `oai_v3.0.0_mac_arm64.zip`
|
|
||||||
- **Linux (x86_64)**: `oai_v3.0.0_linux_x86_64.zip`
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Extract and install
|
pip install -r requirements.txt
|
||||||
unzip oai_v3.0.0_*.zip
|
|
||||||
mkdir -p ~/.local/bin
|
|
||||||
mv oai ~/.local/bin/
|
|
||||||
|
|
||||||
# macOS only: Remove quarantine and approve
|
|
||||||
xattr -cr ~/.local/bin/oai
|
|
||||||
# Then right-click oai in Finder → Open With → Terminal → Click "Open"
|
|
||||||
```
|
```
|
||||||
|
|
||||||
### Add to PATH
|
#### 2. Make Executable
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Add to ~/.zshrc or ~/.bashrc
|
chmod +x oai.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 3. Copy to PATH
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Option 1: System-wide (requires sudo)
|
||||||
|
sudo cp oai.py /usr/local/bin/oai
|
||||||
|
|
||||||
|
# Option 2: User-local (recommended)
|
||||||
|
mkdir -p ~/.local/bin
|
||||||
|
cp oai.py ~/.local/bin/oai
|
||||||
|
|
||||||
|
# Add to PATH if needed (add to ~/.bashrc or ~/.zshrc)
|
||||||
export PATH="$HOME/.local/bin:$PATH"
|
export PATH="$HOME/.local/bin:$PATH"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### 4. Verify Installation
|
||||||
### Option 2: Install from Source
|
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Clone the repository
|
oai --version
|
||||||
git clone https://gitlab.pm/rune/oai.git
|
|
||||||
cd oai
|
|
||||||
|
|
||||||
# Install with pip
|
|
||||||
pip install -e .
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Option 2: Pre-built Binaries
|
||||||
|
|
||||||
|
Download platform-specific binaries:
|
||||||
|
- **macOS (Apple Silicon)**: `oai_vx.x.x_mac_arm64.zip`
|
||||||
|
- **Linux (x86_64)**: `oai_vx.x.x-linux-x86_64.zip`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Extract and install
|
||||||
|
unzip oai_vx.x.x_mac_arm64.zip # or `oai_vx.x.x-linux-x86_64.zip`
|
||||||
|
chmod +x oai
|
||||||
|
mkdir -p ~/.local/bin # Remember to add this to your path. Or just move to folder already in your $PATH
|
||||||
|
mv oai ~/.local/bin/
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
### Alternative: Shell Alias
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Add to ~/.bashrc or ~/.zshrc
|
||||||
|
alias oai='python3 /path/to/oai.py'
|
||||||
|
```
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
|
### First Run Setup
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Start oAI (launches TUI)
|
oai
|
||||||
|
```
|
||||||
|
|
||||||
|
On first run, you'll be prompted to enter your OpenRouter API key.
|
||||||
|
|
||||||
|
### Basic Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start chatting
|
||||||
oai
|
oai
|
||||||
|
|
||||||
# Or with options
|
# Select a model
|
||||||
oai --model gpt-4o --online --mcp
|
You> /model
|
||||||
|
|
||||||
# Show version
|
# Enable MCP for file access
|
||||||
oai version
|
You> /mcp on
|
||||||
|
You> /mcp add ~/Documents
|
||||||
|
|
||||||
|
# Ask AI to help with files (read-only)
|
||||||
|
[🔧 MCP: Files] You> List all Python files in Documents
|
||||||
|
[🔧 MCP: Files] You> Read and explain main.py
|
||||||
|
|
||||||
|
# Enable write mode to let AI modify files
|
||||||
|
You> /mcp write on
|
||||||
|
[🔧✍️ MCP: Files+Write] You> Create a new Python file with helper functions
|
||||||
|
[🔧✍️ MCP: Files+Write] You> Refactor main.py to use async/await
|
||||||
|
|
||||||
|
# Switch to database mode
|
||||||
|
You> /mcp add db ~/myapp/data.db
|
||||||
|
You> /mcp db 1
|
||||||
|
[🗄️ MCP: DB #1] You> Show me all tables
|
||||||
|
[🗄️ MCP: DB #1] You> Find all users created this month
|
||||||
```
|
```
|
||||||
|
|
||||||
On first run, you'll be prompted for your OpenRouter API key.
|
## MCP Guide
|
||||||
|
|
||||||
### Basic Commands
|
### File Mode (Default)
|
||||||
|
|
||||||
|
**Setup:**
|
||||||
```bash
|
```bash
|
||||||
# In the TUI interface:
|
/mcp on # Start MCP server
|
||||||
/model # Select AI model (or press F2)
|
/mcp add ~/Projects # Grant access to folder
|
||||||
/help # Show all commands (or press F1)
|
/mcp add ~/Documents # Add another folder
|
||||||
/mcp on # Enable file/database access
|
/mcp list # View all allowed folders
|
||||||
/stats # View session statistics (or press Ctrl+S)
|
|
||||||
/config # View configuration settings
|
|
||||||
/credits # Check account credits
|
|
||||||
Ctrl+Q # Quit
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## MCP (Model Context Protocol)
|
**Natural Language Usage:**
|
||||||
|
```
|
||||||
MCP allows the AI to interact with your local files and databases.
|
|
||||||
|
|
||||||
### File Access
|
|
||||||
|
|
||||||
```bash
|
|
||||||
/mcp on # Enable MCP
|
|
||||||
/mcp add ~/Projects # Grant access to folder
|
|
||||||
/mcp list # View allowed folders
|
|
||||||
|
|
||||||
# Now ask the AI:
|
|
||||||
"List all Python files in Projects"
|
"List all Python files in Projects"
|
||||||
"Read and explain main.py"
|
"Read and explain config.yaml"
|
||||||
"Search for files containing 'TODO'"
|
"Search for files containing 'TODO'"
|
||||||
|
"What's in my Documents folder?"
|
||||||
```
|
```
|
||||||
|
|
||||||
### Write Mode
|
**Available Tools (Read-Only):**
|
||||||
|
- `read_file` - Read complete file contents
|
||||||
|
- `list_directory` - List files/folders (recursive optional)
|
||||||
|
- `search_files` - Search by name or content
|
||||||
|
|
||||||
```bash
|
**Available Tools (Write Mode - requires `/mcp write on`):**
|
||||||
/mcp write on # Enable file modifications
|
- `write_file` - Create new files or overwrite existing ones
|
||||||
|
- `edit_file` - Find and replace text in existing files
|
||||||
|
- `delete_file` - Delete files (always requires confirmation)
|
||||||
|
- `create_directory` - Create directories
|
||||||
|
- `move_file` - Move or rename files
|
||||||
|
- `copy_file` - Copy files to new locations
|
||||||
|
|
||||||
# AI can now:
|
**Features:**
|
||||||
"Create a new file called utils.py"
|
- ✅ Automatic .gitignore filtering (read operations only)
|
||||||
"Edit config.json and update the API URL"
|
- ✅ Skips virtual environments (venv, node_modules)
|
||||||
"Delete the old backup files" # Always asks for confirmation
|
- ✅ Handles large files (auto-truncates >50KB)
|
||||||
```
|
- ✅ Cross-platform (macOS, Linux, Windows via WSL)
|
||||||
|
- ✅ Write mode OFF by default for safety
|
||||||
|
- ✅ Delete operations require user confirmation with LLM's reason
|
||||||
|
|
||||||
### Database Mode
|
### Database Mode
|
||||||
|
|
||||||
|
**Setup:**
|
||||||
```bash
|
```bash
|
||||||
/mcp add db ~/app/data.db # Add database
|
/mcp add db ~/app/database.db # Add SQLite database
|
||||||
/mcp db 1 # Switch to database mode
|
/mcp db list # View all databases
|
||||||
|
/mcp db 1 # Switch to database #1
|
||||||
|
```
|
||||||
|
|
||||||
# Ask the AI:
|
**Natural Language Usage:**
|
||||||
"Show all tables"
|
```
|
||||||
"Find users created this month"
|
"Show me all tables in this database"
|
||||||
"What's the schema for the orders table?"
|
"Find records mentioning 'error'"
|
||||||
|
"How many users registered last week?"
|
||||||
|
"Get the schema for the orders table"
|
||||||
|
"Show me the 10 most recent transactions"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Available Tools:**
|
||||||
|
- `inspect_database` - View schema, tables, columns, indexes
|
||||||
|
- `search_database` - Full-text search across tables
|
||||||
|
- `query_database` - Execute read-only SQL queries
|
||||||
|
|
||||||
|
**Supported Queries:**
|
||||||
|
- ✅ SELECT statements
|
||||||
|
- ✅ JOINs (INNER, LEFT, RIGHT, FULL)
|
||||||
|
- ✅ Subqueries
|
||||||
|
- ✅ CTEs (Common Table Expressions)
|
||||||
|
- ✅ Aggregations (COUNT, SUM, AVG, etc.)
|
||||||
|
- ✅ WHERE, GROUP BY, HAVING, ORDER BY, LIMIT
|
||||||
|
- ❌ INSERT/UPDATE/DELETE (blocked for safety)
|
||||||
|
|
||||||
|
### Write Mode
|
||||||
|
|
||||||
|
**Enable Write Mode:**
|
||||||
|
```bash
|
||||||
|
/mcp write on # Enable write mode (shows warning, requires confirmation)
|
||||||
|
```
|
||||||
|
|
||||||
|
**Natural Language Usage:**
|
||||||
|
```
|
||||||
|
"Create a new Python file called utils.py with helper functions"
|
||||||
|
"Edit main.py and replace the old API endpoint with the new one"
|
||||||
|
"Delete the backup.old file" (will prompt for confirmation)
|
||||||
|
"Create a directory called tests"
|
||||||
|
"Move config.json to the config folder"
|
||||||
|
```
|
||||||
|
|
||||||
|
**Important:**
|
||||||
|
- ⚠️ Write mode is **OFF by default** and resets each session
|
||||||
|
- ⚠️ Delete operations **always** require user confirmation
|
||||||
|
- ⚠️ All operations are limited to allowed MCP folders
|
||||||
|
- ✅ Write operations ignore .gitignore (can write to any file in allowed folders)
|
||||||
|
|
||||||
|
**Disable Write Mode:**
|
||||||
|
```bash
|
||||||
|
/mcp write off # Disable write mode (back to read-only)
|
||||||
|
```
|
||||||
|
|
||||||
|
### Mode Management
|
||||||
|
|
||||||
|
```bash
|
||||||
|
/mcp status # Show current mode, write mode, stats, folders/databases
|
||||||
|
/mcp files # Switch to file mode
|
||||||
|
/mcp db <number> # Switch to database mode
|
||||||
|
/mcp gitignore on # Enable .gitignore filtering (default)
|
||||||
|
/mcp write on|off # Enable/disable write mode
|
||||||
|
/mcp remove 2 # Remove folder/database by number
|
||||||
```
|
```
|
||||||
|
|
||||||
## Command Reference
|
## Command Reference
|
||||||
|
|
||||||
### Chat Commands
|
### Session Commands
|
||||||
| Command | Description |
|
```
|
||||||
|---------|-------------|
|
/help [command] Show help menu or detailed command help
|
||||||
| `/help [cmd]` | Show help |
|
/help mcp Comprehensive MCP guide
|
||||||
| `/model [search]` | Select model |
|
/clear or /cl Clear terminal screen (or Ctrl+L)
|
||||||
| `/info [model]` | Model details |
|
/memory on|off Toggle conversation memory (save costs)
|
||||||
| `/memory on\|off` | Toggle context |
|
/online on|off Enable/disable web search
|
||||||
| `/online on\|off` | Toggle web search |
|
/paste [prompt] Paste clipboard content
|
||||||
| `/retry` | Resend last message |
|
/retry Resend last prompt
|
||||||
| `/clear` | Clear screen |
|
/reset Clear history and system prompt
|
||||||
|
/prev View previous response
|
||||||
|
/next View next response
|
||||||
|
```
|
||||||
|
|
||||||
### MCP Commands
|
### MCP Commands
|
||||||
| Command | Description |
|
```
|
||||||
|---------|-------------|
|
/mcp on Start MCP server
|
||||||
| `/mcp on\|off` | Enable/disable MCP |
|
/mcp off Stop MCP server
|
||||||
| `/mcp status` | Show MCP status |
|
/mcp status Show comprehensive status (includes write mode)
|
||||||
| `/mcp add <path>` | Add folder |
|
/mcp add <folder> Add folder for file access
|
||||||
| `/mcp add db <path>` | Add database |
|
/mcp add db <path> Add SQLite database
|
||||||
| `/mcp list` | List folders |
|
/mcp list List all folders
|
||||||
| `/mcp db list` | List databases |
|
/mcp db list List all databases
|
||||||
| `/mcp db <n>` | Switch to database |
|
/mcp db <number> Switch to database mode
|
||||||
| `/mcp files` | Switch to file mode |
|
/mcp files Switch to file mode
|
||||||
| `/mcp write on\|off` | Toggle write mode |
|
/mcp remove <num> Remove folder/database
|
||||||
|
/mcp gitignore on Enable .gitignore filtering
|
||||||
|
/mcp write on Enable write mode (create/edit/delete files)
|
||||||
|
/mcp write off Disable write mode (read-only)
|
||||||
|
```
|
||||||
|
|
||||||
### Conversation Commands
|
### Model Commands
|
||||||
| Command | Description |
|
```
|
||||||
|---------|-------------|
|
/model [search] Select/change AI model
|
||||||
| `/save <name>` | Save conversation |
|
/info [model_id] Show model details (pricing, capabilities)
|
||||||
| `/load <name>` | Load conversation |
|
```
|
||||||
| `/list` | List saved conversations |
|
|
||||||
| `/delete <name>` | Delete conversation |
|
|
||||||
| `/export md\|json\|html <file>` | Export |
|
|
||||||
|
|
||||||
### Configuration
|
### Configuration
|
||||||
| Command | Description |
|
```
|
||||||
|---------|-------------|
|
/config View all settings
|
||||||
| `/config` | View settings |
|
/config api Set API key
|
||||||
| `/config api` | Set API key |
|
/config model Set default model
|
||||||
| `/config model <id>` | Set default model |
|
/config online Set default online mode (on|off)
|
||||||
| `/config stream on\|off` | Toggle streaming |
|
/config stream Enable/disable streaming (on|off)
|
||||||
| `/stats` | Session statistics |
|
/config maxtoken Set max token limit
|
||||||
| `/credits` | Check credits |
|
/config costwarning Set cost warning threshold ($)
|
||||||
|
/config loglevel Set log level (debug/info/warning/error)
|
||||||
|
/config log Set log file size (MB)
|
||||||
|
```
|
||||||
|
|
||||||
## CLI Options
|
### Conversation Management
|
||||||
|
```
|
||||||
|
/save <name> Save conversation
|
||||||
|
/load <name|num> Load saved conversation
|
||||||
|
/delete <name|num> Delete conversation
|
||||||
|
/list List saved conversations
|
||||||
|
/export md|json|html <file> Export conversation
|
||||||
|
```
|
||||||
|
|
||||||
|
### Token & System
|
||||||
|
```
|
||||||
|
/maxtoken [value] Set session token limit
|
||||||
|
/system [prompt] Set system prompt (use 'clear' to reset)
|
||||||
|
/middleout on|off Enable prompt compression
|
||||||
|
```
|
||||||
|
|
||||||
|
### Monitoring
|
||||||
|
```
|
||||||
|
/stats View session statistics
|
||||||
|
/credits Check OpenRouter credits
|
||||||
|
```
|
||||||
|
|
||||||
|
### File Attachments
|
||||||
|
```
|
||||||
|
@/path/to/file Attach file (images, PDFs, code)
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
Debug @script.py
|
||||||
|
Analyze @data.json
|
||||||
|
Review @screenshot.png
|
||||||
|
```
|
||||||
|
|
||||||
|
## Configuration Options
|
||||||
|
|
||||||
|
All configuration stored in `~/.config/oai/`:
|
||||||
|
|
||||||
|
### Files
|
||||||
|
- `oai_config.db` - SQLite database (settings, conversations, MCP config)
|
||||||
|
- `oai.log` - Application logs (rotating, configurable size)
|
||||||
|
- `history.txt` - Command history (searchable with Ctrl+R)
|
||||||
|
|
||||||
|
### Key Settings
|
||||||
|
- **API Key**: OpenRouter authentication
|
||||||
|
- **Default Model**: Auto-select on startup
|
||||||
|
- **Streaming**: Real-time response display
|
||||||
|
- **Max Tokens**: Global and session limits
|
||||||
|
- **Cost Warning**: Alert threshold for expensive requests
|
||||||
|
- **Online Mode**: Default web search setting
|
||||||
|
- **Log Level**: debug/info/warning/error/critical
|
||||||
|
- **Log Size**: Rotating file size in MB
|
||||||
|
|
||||||
|
## Supported File Types
|
||||||
|
|
||||||
|
### Code Files
|
||||||
|
`.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm`
|
||||||
|
|
||||||
|
### Data Files
|
||||||
|
`.json, .yaml, .yml, .xml, .csv, .txt, .md`
|
||||||
|
|
||||||
|
### Images
|
||||||
|
All standard formats: PNG, JPEG, JPG, GIF, WEBP, BMP
|
||||||
|
|
||||||
|
### Documents
|
||||||
|
PDF (models with document support)
|
||||||
|
|
||||||
|
### Size Limits
|
||||||
|
- Images: 10 MB max
|
||||||
|
- Code/Text: Auto-truncates files >50KB
|
||||||
|
- Binary data: Displayed as `<binary: X bytes>`
|
||||||
|
|
||||||
|
## MCP Security
|
||||||
|
|
||||||
|
### Access Control
|
||||||
|
- ✅ Explicit folder/database approval required
|
||||||
|
- ✅ System directories blocked automatically
|
||||||
|
- ✅ User confirmation for each addition
|
||||||
|
- ✅ .gitignore patterns respected (file mode)
|
||||||
|
|
||||||
|
### Database Safety
|
||||||
|
- ✅ Read-only mode (cannot modify data)
|
||||||
|
- ✅ SQL query validation (blocks INSERT/UPDATE/DELETE)
|
||||||
|
- ✅ Query timeout (5 seconds max)
|
||||||
|
- ✅ Result limits (1000 rows max)
|
||||||
|
- ✅ Database opened in `mode=ro`
|
||||||
|
|
||||||
|
### File System Safety
|
||||||
|
- ✅ Read-only by default (write mode requires explicit opt-in)
|
||||||
|
- ✅ Write mode OFF by default each session (non-persistent)
|
||||||
|
- ✅ Delete operations always require user confirmation
|
||||||
|
- ✅ Write operations limited to allowed folders only
|
||||||
|
- ✅ System directories blocked
|
||||||
|
- ✅ Virtual environment exclusion
|
||||||
|
- ✅ Build artifact filtering
|
||||||
|
- ✅ Maximum file size (10 MB)
|
||||||
|
|
||||||
|
## Tips & Tricks
|
||||||
|
|
||||||
|
### Command History
|
||||||
|
- **↑/↓ arrows**: Navigate previous commands
|
||||||
|
- **Ctrl+R**: Search command history
|
||||||
|
- **Auto-complete**: Start typing `/` for command suggestions
|
||||||
|
|
||||||
|
### Cost Optimization
|
||||||
```bash
|
```bash
|
||||||
oai [OPTIONS]
|
/memory off # Disable context (stateless mode)
|
||||||
|
/maxtoken 1000 # Limit response length
|
||||||
Options:
|
/config costwarning 0.01 # Set alert threshold
|
||||||
-m, --model TEXT Model ID to use
|
|
||||||
-s, --system TEXT System prompt
|
|
||||||
-o, --online Enable online mode
|
|
||||||
--mcp Enable MCP server
|
|
||||||
-v, --version Show version
|
|
||||||
--help Show help
|
|
||||||
```
|
```
|
||||||
|
|
||||||
Commands:
|
### MCP Best Practices
|
||||||
```bash
|
```bash
|
||||||
oai # Launch TUI (default)
|
# Check status frequently
|
||||||
oai version # Show version information
|
/mcp status
|
||||||
oai --help # Show help message
|
|
||||||
|
# Use specific paths to reduce search time
|
||||||
|
"List Python files in Projects/app/" # Better than
|
||||||
|
"List all Python files" # Slower
|
||||||
|
|
||||||
|
# Database queries - be specific
|
||||||
|
"SELECT * FROM users LIMIT 10" # Good
|
||||||
|
"SELECT * FROM users" # May hit row limit
|
||||||
```
|
```
|
||||||
|
|
||||||
## Configuration
|
### Debugging
|
||||||
|
```bash
|
||||||
|
# Enable debug logging
|
||||||
|
/config loglevel debug
|
||||||
|
|
||||||
Configuration is stored in `~/.config/oai/`:
|
# Check log file
|
||||||
|
tail -f ~/.config/oai/oai.log
|
||||||
|
|
||||||
| File | Purpose |
|
# View MCP statistics
|
||||||
|------|---------|
|
/mcp status # Shows tool call counts
|
||||||
| `oai_config.db` | Settings, conversations, MCP config |
|
|
||||||
| `oai.log` | Application logs |
|
|
||||||
| `history.txt` | Command history |
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
oai/
|
|
||||||
├── oai/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── __main__.py # Entry point for python -m oai
|
|
||||||
│ ├── cli.py # Main CLI entry point
|
|
||||||
│ ├── constants.py # Configuration constants
|
|
||||||
│ ├── commands/ # Slash command handlers
|
|
||||||
│ ├── config/ # Settings and database
|
|
||||||
│ ├── core/ # Chat client and session
|
|
||||||
│ ├── mcp/ # MCP server and tools
|
|
||||||
│ ├── providers/ # AI provider abstraction
|
|
||||||
│ ├── tui/ # Textual TUI interface
|
|
||||||
│ │ ├── app.py # Main TUI application
|
|
||||||
│ │ ├── widgets/ # Custom widgets
|
|
||||||
│ │ ├── screens/ # Modal screens
|
|
||||||
│ │ └── styles.tcss # TUI styling
|
|
||||||
│ └── utils/ # Logging, export, etc.
|
|
||||||
├── pyproject.toml # Package configuration
|
|
||||||
├── build.sh # Binary build script
|
|
||||||
└── README.md
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## Troubleshooting
|
## Troubleshooting
|
||||||
|
|
||||||
### macOS Binary Issues
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Remove quarantine attribute
|
|
||||||
xattr -cr ~/.local/bin/oai
|
|
||||||
|
|
||||||
# Then in Finder: right-click oai → Open With → Terminal → Click "Open"
|
|
||||||
# After this, oai works from any terminal
|
|
||||||
```
|
|
||||||
|
|
||||||
### MCP Not Working
|
### MCP Not Working
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Check if model supports function calling
|
# 1. Check if MCP is installed
|
||||||
|
python3 -c "import mcp; print('MCP OK')"
|
||||||
|
|
||||||
|
# 2. Verify model supports function calling
|
||||||
/info # Look for "tools" in supported parameters
|
/info # Look for "tools" in supported parameters
|
||||||
|
|
||||||
# Check MCP status
|
# 3. Check MCP status
|
||||||
/mcp status
|
/mcp status
|
||||||
|
|
||||||
# View logs
|
# 4. Review logs
|
||||||
tail -f ~/.config/oai/oai.log
|
tail ~/.config/oai/oai.log
|
||||||
```
|
```
|
||||||
|
|
||||||
### Import Errors
|
### Import Errors
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Reinstall package
|
# Reinstall dependencies
|
||||||
pip install -e . --force-reinstall
|
pip install --force-reinstall -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
### Binary Issues (macOS)
|
||||||
|
```bash
|
||||||
|
# Remove quarantine
|
||||||
|
xattr -cr ~/.local/bin/oai
|
||||||
|
|
||||||
|
# Check security settings
|
||||||
|
# System Settings > Privacy & Security > "Allow Anyway"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Errors
|
||||||
|
```bash
|
||||||
|
# Verify it's a valid SQLite database
|
||||||
|
sqlite3 database.db ".tables"
|
||||||
|
|
||||||
|
# Check file permissions
|
||||||
|
ls -la database.db
|
||||||
```
|
```
|
||||||
|
|
||||||
## Version History
|
## Version History
|
||||||
|
|
||||||
### v3.0.0 (Current)
|
### v2.1.0-RC1 (Current)
|
||||||
- 🎨 **Complete migration to Textual TUI** - Modern async terminal interface
|
- ✨ **NEW**: MCP (Model Context Protocol) integration
|
||||||
- 🗑️ **Removed CLI interface** - TUI-only for cleaner codebase (11.6% smaller)
|
- ✨ **NEW**: File system access (read, search, list)
|
||||||
- 🖱️ **Modal screens** - Help, stats, config, credits, model selector
|
- ✨ **NEW**: Write mode - AI can create, edit, and delete files
|
||||||
- ⌨️ **Keyboard shortcuts** - F1 (help), F2 (models), Ctrl+S (stats), etc.
|
- 6 write tools: write_file, edit_file, delete_file, create_directory, move_file, copy_file
|
||||||
- 🎯 **Capability indicators** - Visual icons for model features (vision, tools, online)
|
- OFF by default - requires explicit `/mcp write on` activation
|
||||||
- 🎨 **Consistent dark theme** - Professional styling throughout
|
- Delete operations always require user confirmation
|
||||||
- 📊 **Enhanced model selector** - Search, filter, capability columns
|
- Non-persistent setting (resets each session)
|
||||||
- 🚀 **Default command** - Just run `oai` to launch TUI
|
- ✨ **NEW**: SQLite database querying (read-only)
|
||||||
- 🧹 **Code cleanup** - Removed 1,300+ lines of CLI code
|
- ✨ **NEW**: Dual mode support (Files & Database)
|
||||||
|
- ✨ **NEW**: .gitignore filtering
|
||||||
|
- ✨ **NEW**: Binary data handling in databases
|
||||||
|
- ✨ **NEW**: Mode indicators in prompt (shows ✍️ when write mode active)
|
||||||
|
- ✨ **NEW**: Comprehensive `/help mcp` guide
|
||||||
|
- 🔧 Improved error handling for tool calls
|
||||||
|
- 🔧 Enhanced logging for MCP operations
|
||||||
|
- 🔧 Statistics tracking for tool usage
|
||||||
|
|
||||||
### v2.1.0
|
### v1.9.6
|
||||||
- 🏗️ Complete codebase refactoring to modular package structure
|
- Base version with core chat functionality
|
||||||
- 🔌 Extensible provider architecture for adding new AI providers
|
|
||||||
- 📦 Proper Python packaging with pyproject.toml
|
|
||||||
- ✨ MCP integration (file access, write mode, database queries)
|
|
||||||
- 🔧 Command registry pattern for slash commands
|
|
||||||
- 📊 Improved cost tracking and session statistics
|
|
||||||
|
|
||||||
### v1.9.x
|
|
||||||
- Single-file implementation
|
|
||||||
- Core chat functionality
|
|
||||||
- File attachments
|
|
||||||
- Conversation management
|
- Conversation management
|
||||||
|
- File attachments
|
||||||
|
- Cost tracking
|
||||||
|
- Export capabilities
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
MIT License - See [LICENSE](LICENSE) for details.
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2024-2025 Rune Olsen
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
||||||
|
Full license: https://opensource.org/licenses/MIT
|
||||||
|
|
||||||
## Author
|
## Author
|
||||||
|
|
||||||
**Rune Olsen**
|
**Rune Olsen**
|
||||||
|
|
||||||
|
- Homepage: https://ai.fubar.pm/
|
||||||
|
- Blog: https://blog.rune.pm
|
||||||
- Project: https://iurl.no/oai
|
- Project: https://iurl.no/oai
|
||||||
- Repository: https://gitlab.pm/rune/oai
|
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
|
Contributions welcome! Please:
|
||||||
1. Fork the repository
|
1. Fork the repository
|
||||||
2. Create a feature branch
|
2. Create a feature branch
|
||||||
3. Submit a pull request
|
3. Submit a pull request with detailed description
|
||||||
|
|
||||||
|
## Acknowledgments
|
||||||
|
|
||||||
|
- OpenRouter team for the unified AI API
|
||||||
|
- Rich library for beautiful terminal output
|
||||||
|
- MCP community for the protocol specification
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**⭐ Star this project if you find it useful!**
|
**Star ⭐ this project if you find it useful!**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
Did you really read all the way down here? WOW! You deserve a 🍾 🥂!
|
||||||
|
|||||||
@@ -1,26 +0,0 @@
|
|||||||
"""
|
|
||||||
oAI - OpenRouter AI Chat Client
|
|
||||||
|
|
||||||
A feature-rich terminal-based chat application that provides an interactive CLI
|
|
||||||
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
|
|
||||||
integration for filesystem and database access.
|
|
||||||
|
|
||||||
Author: Rune
|
|
||||||
License: MIT
|
|
||||||
"""
|
|
||||||
|
|
||||||
__version__ = "2.1.0"
|
|
||||||
__author__ = "Rune"
|
|
||||||
__license__ = "MIT"
|
|
||||||
|
|
||||||
# Lazy imports to avoid circular dependencies and improve startup time
|
|
||||||
# Full imports are available via submodules:
|
|
||||||
# from oai.config import Settings, Database
|
|
||||||
# from oai.providers import OpenRouterProvider, AIProvider
|
|
||||||
# from oai.mcp import MCPManager
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"__version__",
|
|
||||||
"__author__",
|
|
||||||
"__license__",
|
|
||||||
]
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
"""
|
|
||||||
Entry point for running oAI as a module: python -m oai
|
|
||||||
"""
|
|
||||||
|
|
||||||
from oai.cli import main
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
199
oai/cli.py
199
oai/cli.py
@@ -1,199 +0,0 @@
|
|||||||
"""
|
|
||||||
Main CLI entry point for oAI.
|
|
||||||
|
|
||||||
This module provides the command-line interface for the oAI TUI application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import typer
|
|
||||||
|
|
||||||
from oai import __version__
|
|
||||||
from oai.commands import register_all_commands
|
|
||||||
from oai.config.settings import Settings
|
|
||||||
from oai.constants import APP_URL, APP_VERSION
|
|
||||||
from oai.core.client import AIClient
|
|
||||||
from oai.core.session import ChatSession
|
|
||||||
from oai.mcp.manager import MCPManager
|
|
||||||
from oai.utils.logging import LoggingManager, get_logger
|
|
||||||
|
|
||||||
# Create Typer app
|
|
||||||
app = typer.Typer(
|
|
||||||
name="oai",
|
|
||||||
help=f"oAI - OpenRouter AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
|
|
||||||
add_completion=False,
|
|
||||||
epilog="For more information, visit: " + APP_URL,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.callback(invoke_without_command=True)
|
|
||||||
def main_callback(
|
|
||||||
ctx: typer.Context,
|
|
||||||
version_flag: bool = typer.Option(
|
|
||||||
False,
|
|
||||||
"--version",
|
|
||||||
"-v",
|
|
||||||
help="Show version information",
|
|
||||||
is_flag=True,
|
|
||||||
),
|
|
||||||
model: Optional[str] = typer.Option(
|
|
||||||
None,
|
|
||||||
"--model",
|
|
||||||
"-m",
|
|
||||||
help="Model ID to use",
|
|
||||||
),
|
|
||||||
system: Optional[str] = typer.Option(
|
|
||||||
None,
|
|
||||||
"--system",
|
|
||||||
"-s",
|
|
||||||
help="System prompt",
|
|
||||||
),
|
|
||||||
online: bool = typer.Option(
|
|
||||||
False,
|
|
||||||
"--online",
|
|
||||||
"-o",
|
|
||||||
help="Enable online mode",
|
|
||||||
),
|
|
||||||
mcp: bool = typer.Option(
|
|
||||||
False,
|
|
||||||
"--mcp",
|
|
||||||
help="Enable MCP server",
|
|
||||||
),
|
|
||||||
) -> None:
|
|
||||||
"""Main callback - launches TUI by default."""
|
|
||||||
if version_flag:
|
|
||||||
typer.echo(f"oAI version {APP_VERSION}")
|
|
||||||
raise typer.Exit()
|
|
||||||
|
|
||||||
# If no subcommand provided, launch TUI
|
|
||||||
if ctx.invoked_subcommand is None:
|
|
||||||
_launch_tui(model, system, online, mcp)
|
|
||||||
|
|
||||||
|
|
||||||
def _launch_tui(
|
|
||||||
model: Optional[str] = None,
|
|
||||||
system: Optional[str] = None,
|
|
||||||
online: bool = False,
|
|
||||||
mcp: bool = False,
|
|
||||||
) -> None:
|
|
||||||
"""Launch the Textual TUI interface."""
|
|
||||||
# Setup logging
|
|
||||||
logging_manager = LoggingManager()
|
|
||||||
logging_manager.setup()
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
# Load settings
|
|
||||||
settings = Settings.load()
|
|
||||||
|
|
||||||
# Check API key
|
|
||||||
if not settings.api_key:
|
|
||||||
typer.echo("Error: No API key configured", err=True)
|
|
||||||
typer.echo("Run: oai config api to set your API key", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Initialize client
|
|
||||||
try:
|
|
||||||
client = AIClient(
|
|
||||||
api_key=settings.api_key,
|
|
||||||
base_url=settings.base_url,
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
typer.echo(f"Error: Failed to initialize client: {e}", err=True)
|
|
||||||
raise typer.Exit(1)
|
|
||||||
|
|
||||||
# Register commands
|
|
||||||
register_all_commands()
|
|
||||||
|
|
||||||
# Initialize MCP manager (always create it, even if not enabled)
|
|
||||||
mcp_manager = MCPManager()
|
|
||||||
if mcp:
|
|
||||||
try:
|
|
||||||
result = mcp_manager.enable()
|
|
||||||
if result["success"]:
|
|
||||||
logger.info("MCP server enabled in files mode")
|
|
||||||
else:
|
|
||||||
logger.warning(f"MCP: {result.get('error', 'Failed to enable')}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to enable MCP: {e}")
|
|
||||||
|
|
||||||
# Create session with MCP manager
|
|
||||||
session = ChatSession(
|
|
||||||
client=client,
|
|
||||||
settings=settings,
|
|
||||||
mcp_manager=mcp_manager,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set system prompt if provided
|
|
||||||
if system:
|
|
||||||
session.set_system_prompt(system)
|
|
||||||
|
|
||||||
# Enable online mode if requested
|
|
||||||
if online:
|
|
||||||
session.online_enabled = True
|
|
||||||
|
|
||||||
# Set model if specified, otherwise use default
|
|
||||||
if model:
|
|
||||||
raw_model = client.get_raw_model(model)
|
|
||||||
if raw_model:
|
|
||||||
session.set_model(raw_model)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Model '{model}' not found")
|
|
||||||
elif settings.default_model:
|
|
||||||
raw_model = client.get_raw_model(settings.default_model)
|
|
||||||
if raw_model:
|
|
||||||
session.set_model(raw_model)
|
|
||||||
else:
|
|
||||||
logger.warning(f"Default model '{settings.default_model}' not available")
|
|
||||||
|
|
||||||
# Run Textual app
|
|
||||||
from oai.tui.app import oAIChatApp
|
|
||||||
|
|
||||||
app_instance = oAIChatApp(session, settings, model)
|
|
||||||
app_instance.run()
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def tui(
|
|
||||||
model: Optional[str] = typer.Option(
|
|
||||||
None,
|
|
||||||
"--model",
|
|
||||||
"-m",
|
|
||||||
help="Model ID to use",
|
|
||||||
),
|
|
||||||
system: Optional[str] = typer.Option(
|
|
||||||
None,
|
|
||||||
"--system",
|
|
||||||
"-s",
|
|
||||||
help="System prompt",
|
|
||||||
),
|
|
||||||
online: bool = typer.Option(
|
|
||||||
False,
|
|
||||||
"--online",
|
|
||||||
"-o",
|
|
||||||
help="Enable online mode",
|
|
||||||
),
|
|
||||||
mcp: bool = typer.Option(
|
|
||||||
False,
|
|
||||||
"--mcp",
|
|
||||||
help="Enable MCP server",
|
|
||||||
),
|
|
||||||
) -> None:
|
|
||||||
"""Start Textual TUI interface (alias for just running 'oai')."""
|
|
||||||
_launch_tui(model, system, online, mcp)
|
|
||||||
|
|
||||||
|
|
||||||
@app.command()
|
|
||||||
def version() -> None:
|
|
||||||
"""Show version information."""
|
|
||||||
typer.echo(f"oAI version {APP_VERSION}")
|
|
||||||
typer.echo(f"Visit {APP_URL} for more information")
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
"""Entry point for the CLI."""
|
|
||||||
app()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
"""
|
|
||||||
Command system for oAI.
|
|
||||||
|
|
||||||
This module provides a command registry and handler system
|
|
||||||
for processing slash commands in the chat interface.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from oai.commands.registry import (
|
|
||||||
Command,
|
|
||||||
CommandRegistry,
|
|
||||||
CommandContext,
|
|
||||||
CommandResult,
|
|
||||||
registry,
|
|
||||||
)
|
|
||||||
from oai.commands.handlers import register_all_commands
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Command",
|
|
||||||
"CommandRegistry",
|
|
||||||
"CommandContext",
|
|
||||||
"CommandResult",
|
|
||||||
"registry",
|
|
||||||
"register_all_commands",
|
|
||||||
]
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,382 +0,0 @@
|
|||||||
"""
|
|
||||||
Command registry for oAI.
|
|
||||||
|
|
||||||
This module defines the command system infrastructure including
|
|
||||||
the Command base class, CommandContext for state, and CommandRegistry
|
|
||||||
for managing available commands.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
|
||||||
|
|
||||||
from oai.utils.logging import get_logger
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from oai.config.settings import Settings
|
|
||||||
from oai.providers.base import AIProvider, ModelInfo
|
|
||||||
from oai.mcp.manager import MCPManager
|
|
||||||
|
|
||||||
|
|
||||||
class CommandStatus(str, Enum):
|
|
||||||
"""Status of command execution."""
|
|
||||||
|
|
||||||
SUCCESS = "success"
|
|
||||||
ERROR = "error"
|
|
||||||
CONTINUE = "continue" # Continue to next handler
|
|
||||||
EXIT = "exit" # Exit the application
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CommandResult:
|
|
||||||
"""
|
|
||||||
Result of a command execution.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
status: Execution status
|
|
||||||
message: Optional message to display
|
|
||||||
data: Optional data payload
|
|
||||||
should_continue: Whether to continue the main loop
|
|
||||||
"""
|
|
||||||
|
|
||||||
status: CommandStatus = CommandStatus.SUCCESS
|
|
||||||
message: Optional[str] = None
|
|
||||||
data: Optional[Any] = None
|
|
||||||
should_continue: bool = True
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult":
|
|
||||||
"""Create a success result."""
|
|
||||||
return cls(status=CommandStatus.SUCCESS, message=message, data=data)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def error(cls, message: str) -> "CommandResult":
|
|
||||||
"""Create an error result."""
|
|
||||||
return cls(status=CommandStatus.ERROR, message=message)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def exit(cls, message: Optional[str] = None) -> "CommandResult":
|
|
||||||
"""Create an exit result."""
|
|
||||||
return cls(status=CommandStatus.EXIT, message=message, should_continue=False)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CommandContext:
|
|
||||||
"""
|
|
||||||
Context object providing state to command handlers.
|
|
||||||
|
|
||||||
Contains all the session state needed by commands including
|
|
||||||
settings, provider, conversation history, and MCP manager.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
settings: Application settings
|
|
||||||
provider: AI provider instance
|
|
||||||
mcp_manager: MCP manager instance
|
|
||||||
selected_model: Currently selected model
|
|
||||||
session_history: Conversation history
|
|
||||||
session_system_prompt: Current system prompt
|
|
||||||
memory_enabled: Whether memory is enabled
|
|
||||||
online_enabled: Whether online mode is enabled
|
|
||||||
session_tokens: Session token counts
|
|
||||||
session_cost: Session cost total
|
|
||||||
"""
|
|
||||||
|
|
||||||
settings: Optional["Settings"] = None
|
|
||||||
provider: Optional["AIProvider"] = None
|
|
||||||
mcp_manager: Optional["MCPManager"] = None
|
|
||||||
selected_model: Optional["ModelInfo"] = None
|
|
||||||
selected_model_raw: Optional[Dict[str, Any]] = None
|
|
||||||
session_history: List[Dict[str, Any]] = field(default_factory=list)
|
|
||||||
session_system_prompt: str = ""
|
|
||||||
memory_enabled: bool = True
|
|
||||||
memory_start_index: int = 0
|
|
||||||
online_enabled: bool = False
|
|
||||||
middle_out_enabled: bool = False
|
|
||||||
session_max_token: int = 0
|
|
||||||
total_input_tokens: int = 0
|
|
||||||
total_output_tokens: int = 0
|
|
||||||
total_cost: float = 0.0
|
|
||||||
message_count: int = 0
|
|
||||||
is_tui: bool = False # Flag for TUI mode
|
|
||||||
current_index: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class CommandHelp:
|
|
||||||
"""
|
|
||||||
Help information for a command.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
description: Brief description
|
|
||||||
usage: Usage syntax
|
|
||||||
examples: List of (description, example) tuples
|
|
||||||
notes: Additional notes
|
|
||||||
aliases: Command aliases
|
|
||||||
"""
|
|
||||||
|
|
||||||
description: str
|
|
||||||
usage: str = ""
|
|
||||||
examples: List[tuple] = field(default_factory=list)
|
|
||||||
notes: str = ""
|
|
||||||
aliases: List[str] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class Command(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for all commands.
|
|
||||||
|
|
||||||
Commands implement the execute method to handle their logic.
|
|
||||||
They can also provide help information and aliases.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def name(self) -> str:
|
|
||||||
"""Get the primary command name (e.g., '/help')."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
def aliases(self) -> List[str]:
|
|
||||||
"""Get command aliases (e.g., ['/h'] for help)."""
|
|
||||||
return []
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def help(self) -> CommandHelp:
|
|
||||||
"""Get command help information."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def execute(self, args: str, context: CommandContext) -> CommandResult:
|
|
||||||
"""
|
|
||||||
Execute the command.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
args: Arguments passed to the command
|
|
||||||
context: Command execution context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CommandResult indicating success/failure
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def matches(self, input_text: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if this command matches the input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_text: User input text
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if this command should handle the input
|
|
||||||
"""
|
|
||||||
input_lower = input_text.lower()
|
|
||||||
cmd_word = input_lower.split()[0] if input_lower.split() else ""
|
|
||||||
|
|
||||||
# Check primary name
|
|
||||||
if cmd_word == self.name.lower():
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check aliases
|
|
||||||
for alias in self.aliases:
|
|
||||||
if cmd_word == alias.lower():
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_args(self, input_text: str) -> str:
|
|
||||||
"""
|
|
||||||
Extract arguments from the input text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_text: Full user input
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Arguments portion of the input
|
|
||||||
"""
|
|
||||||
parts = input_text.split(maxsplit=1)
|
|
||||||
return parts[1] if len(parts) > 1 else ""
|
|
||||||
|
|
||||||
|
|
||||||
class CommandRegistry:
|
|
||||||
"""
|
|
||||||
Registry for managing available commands.
|
|
||||||
|
|
||||||
Provides registration, lookup, and execution of commands.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize an empty command registry."""
|
|
||||||
self._commands: Dict[str, Command] = {}
|
|
||||||
self._aliases: Dict[str, str] = {}
|
|
||||||
self.logger = get_logger()
|
|
||||||
|
|
||||||
def register(self, command: Command) -> None:
|
|
||||||
"""
|
|
||||||
Register a command.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
command: Command instance to register
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If command name already registered
|
|
||||||
"""
|
|
||||||
name = command.name.lower()
|
|
||||||
|
|
||||||
if name in self._commands:
|
|
||||||
raise ValueError(f"Command '{name}' already registered")
|
|
||||||
|
|
||||||
self._commands[name] = command
|
|
||||||
|
|
||||||
# Register aliases
|
|
||||||
for alias in command.aliases:
|
|
||||||
alias_lower = alias.lower()
|
|
||||||
if alias_lower in self._aliases:
|
|
||||||
self.logger.warning(
|
|
||||||
f"Alias '{alias}' already registered, overwriting"
|
|
||||||
)
|
|
||||||
self._aliases[alias_lower] = name
|
|
||||||
|
|
||||||
self.logger.debug(f"Registered command: {name}")
|
|
||||||
|
|
||||||
def register_function(
|
|
||||||
self,
|
|
||||||
name: str,
|
|
||||||
handler: Callable[[str, CommandContext], CommandResult],
|
|
||||||
description: str,
|
|
||||||
usage: str = "",
|
|
||||||
aliases: Optional[List[str]] = None,
|
|
||||||
examples: Optional[List[tuple]] = None,
|
|
||||||
notes: str = "",
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Register a function-based command.
|
|
||||||
|
|
||||||
Convenience method for simple commands that don't need
|
|
||||||
a full Command class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Command name (e.g., '/help')
|
|
||||||
handler: Function to execute
|
|
||||||
description: Help description
|
|
||||||
usage: Usage syntax
|
|
||||||
aliases: Command aliases
|
|
||||||
examples: Example usages
|
|
||||||
notes: Additional notes
|
|
||||||
"""
|
|
||||||
aliases = aliases or []
|
|
||||||
examples = examples or []
|
|
||||||
|
|
||||||
class FunctionCommand(Command):
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
return name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def aliases(self) -> List[str]:
|
|
||||||
return aliases
|
|
||||||
|
|
||||||
@property
|
|
||||||
def help(self) -> CommandHelp:
|
|
||||||
return CommandHelp(
|
|
||||||
description=description,
|
|
||||||
usage=usage,
|
|
||||||
examples=examples,
|
|
||||||
notes=notes,
|
|
||||||
aliases=aliases,
|
|
||||||
)
|
|
||||||
|
|
||||||
def execute(self, args: str, context: CommandContext) -> CommandResult:
|
|
||||||
return handler(args, context)
|
|
||||||
|
|
||||||
self.register(FunctionCommand())
|
|
||||||
|
|
||||||
def get(self, name: str) -> Optional[Command]:
|
|
||||||
"""
|
|
||||||
Get a command by name or alias.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Command name or alias
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Command instance or None if not found
|
|
||||||
"""
|
|
||||||
name_lower = name.lower()
|
|
||||||
|
|
||||||
# Check direct match
|
|
||||||
if name_lower in self._commands:
|
|
||||||
return self._commands[name_lower]
|
|
||||||
|
|
||||||
# Check aliases
|
|
||||||
if name_lower in self._aliases:
|
|
||||||
return self._commands[self._aliases[name_lower]]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def find(self, input_text: str) -> Optional[Command]:
|
|
||||||
"""
|
|
||||||
Find a command that matches the input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_text: User input text
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Matching Command or None
|
|
||||||
"""
|
|
||||||
cmd_word = input_text.lower().split()[0] if input_text.split() else ""
|
|
||||||
return self.get(cmd_word)
|
|
||||||
|
|
||||||
def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]:
|
|
||||||
"""
|
|
||||||
Execute a command matching the input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_text: User input text
|
|
||||||
context: Execution context
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CommandResult or None if no matching command
|
|
||||||
"""
|
|
||||||
command = self.find(input_text)
|
|
||||||
if command:
|
|
||||||
args = command.get_args(input_text)
|
|
||||||
self.logger.debug(f"Executing command: {command.name} with args: {args}")
|
|
||||||
return command.execute(args, context)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def is_command(self, input_text: str) -> bool:
|
|
||||||
"""
|
|
||||||
Check if input is a valid command.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_text: User input text
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if input matches a registered command
|
|
||||||
"""
|
|
||||||
return self.find(input_text) is not None
|
|
||||||
|
|
||||||
def list_commands(self) -> List[Command]:
|
|
||||||
"""
|
|
||||||
Get all registered commands.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of Command instances
|
|
||||||
"""
|
|
||||||
return list(self._commands.values())
|
|
||||||
|
|
||||||
def get_all_names(self) -> List[str]:
|
|
||||||
"""
|
|
||||||
Get all command names and aliases.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of command names including aliases
|
|
||||||
"""
|
|
||||||
names = list(self._commands.keys())
|
|
||||||
names.extend(self._aliases.keys())
|
|
||||||
return sorted(set(names))
|
|
||||||
|
|
||||||
|
|
||||||
# Global registry instance
|
|
||||||
registry = CommandRegistry()
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
"""
|
|
||||||
Configuration management for oAI.
|
|
||||||
|
|
||||||
This package handles all configuration persistence, settings management,
|
|
||||||
and database operations for the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from oai.config.settings import Settings
|
|
||||||
from oai.config.database import Database
|
|
||||||
|
|
||||||
__all__ = ["Settings", "Database"]
|
|
||||||
@@ -1,472 +0,0 @@
|
|||||||
"""
|
|
||||||
Database persistence layer for oAI.
|
|
||||||
|
|
||||||
This module provides a clean abstraction for SQLite operations including
|
|
||||||
configuration storage, conversation persistence, and MCP statistics tracking.
|
|
||||||
All database operations are centralized here for maintainability.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sqlite3
|
|
||||||
import json
|
|
||||||
import datetime
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, List, Dict, Any
|
|
||||||
from contextlib import contextmanager
|
|
||||||
|
|
||||||
from oai.constants import DATABASE_FILE, CONFIG_DIR
|
|
||||||
|
|
||||||
|
|
||||||
class Database:
|
|
||||||
"""
|
|
||||||
SQLite database manager for oAI.
|
|
||||||
|
|
||||||
Handles all database operations including:
|
|
||||||
- Configuration key-value storage
|
|
||||||
- Conversation session persistence
|
|
||||||
- MCP configuration and statistics
|
|
||||||
- Database registrations for MCP
|
|
||||||
|
|
||||||
Uses context managers for safe connection handling and supports
|
|
||||||
automatic table creation on first use.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, db_path: Optional[Path] = None):
|
|
||||||
"""
|
|
||||||
Initialize the database manager.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_path: Optional custom database path. Defaults to standard location.
|
|
||||||
"""
|
|
||||||
self.db_path = db_path or DATABASE_FILE
|
|
||||||
self._ensure_directories()
|
|
||||||
self._ensure_tables()
|
|
||||||
|
|
||||||
def _ensure_directories(self) -> None:
|
|
||||||
"""Ensure the configuration directory exists."""
|
|
||||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _connection(self):
|
|
||||||
"""
|
|
||||||
Context manager for database connections.
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
sqlite3.Connection: Active database connection
|
|
||||||
|
|
||||||
Example:
|
|
||||||
with self._connection() as conn:
|
|
||||||
conn.execute("SELECT * FROM config")
|
|
||||||
"""
|
|
||||||
conn = sqlite3.connect(str(self.db_path))
|
|
||||||
try:
|
|
||||||
yield conn
|
|
||||||
conn.commit()
|
|
||||||
except Exception:
|
|
||||||
conn.rollback()
|
|
||||||
raise
|
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
def _ensure_tables(self) -> None:
|
|
||||||
"""Create all required tables if they don't exist."""
|
|
||||||
with self._connection() as conn:
|
|
||||||
# Main configuration table
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS config (
|
|
||||||
key TEXT PRIMARY KEY,
|
|
||||||
value TEXT NOT NULL
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
# Conversation sessions table
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS conversation_sessions (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
timestamp TEXT NOT NULL,
|
|
||||||
data TEXT NOT NULL
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
# MCP configuration table
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS mcp_config (
|
|
||||||
key TEXT PRIMARY KEY,
|
|
||||||
value TEXT NOT NULL
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
# MCP statistics table
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS mcp_stats (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
timestamp TEXT NOT NULL,
|
|
||||||
tool_name TEXT NOT NULL,
|
|
||||||
folder TEXT,
|
|
||||||
success INTEGER NOT NULL,
|
|
||||||
error_message TEXT
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
# MCP databases table
|
|
||||||
conn.execute("""
|
|
||||||
CREATE TABLE IF NOT EXISTS mcp_databases (
|
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
||||||
path TEXT NOT NULL UNIQUE,
|
|
||||||
name TEXT NOT NULL,
|
|
||||||
size INTEGER,
|
|
||||||
tables TEXT,
|
|
||||||
added_timestamp TEXT NOT NULL
|
|
||||||
)
|
|
||||||
""")
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# CONFIGURATION METHODS
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_config(self, key: str) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Retrieve a configuration value by key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The configuration key to retrieve
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The configuration value, or None if not found
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
"SELECT value FROM config WHERE key = ?",
|
|
||||||
(key,)
|
|
||||||
)
|
|
||||||
result = cursor.fetchone()
|
|
||||||
return result[0] if result else None
|
|
||||||
|
|
||||||
def set_config(self, key: str, value: str) -> None:
|
|
||||||
"""
|
|
||||||
Set a configuration value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The configuration key
|
|
||||||
value: The value to store
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
conn.execute(
|
|
||||||
"INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)",
|
|
||||||
(key, value)
|
|
||||||
)
|
|
||||||
|
|
||||||
def delete_config(self, key: str) -> bool:
|
|
||||||
"""
|
|
||||||
Delete a configuration value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The configuration key to delete
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if a row was deleted, False otherwise
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
"DELETE FROM config WHERE key = ?",
|
|
||||||
(key,)
|
|
||||||
)
|
|
||||||
return cursor.rowcount > 0
|
|
||||||
|
|
||||||
def get_all_config(self) -> Dict[str, str]:
|
|
||||||
"""
|
|
||||||
Retrieve all configuration values.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary of all key-value pairs
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute("SELECT key, value FROM config")
|
|
||||||
return dict(cursor.fetchall())
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# MCP CONFIGURATION METHODS
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def get_mcp_config(self, key: str) -> Optional[str]:
|
|
||||||
"""
|
|
||||||
Retrieve an MCP configuration value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The MCP configuration key
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The configuration value, or None if not found
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
"SELECT value FROM mcp_config WHERE key = ?",
|
|
||||||
(key,)
|
|
||||||
)
|
|
||||||
result = cursor.fetchone()
|
|
||||||
return result[0] if result else None
|
|
||||||
|
|
||||||
def set_mcp_config(self, key: str, value: str) -> None:
|
|
||||||
"""
|
|
||||||
Set an MCP configuration value.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: The MCP configuration key
|
|
||||||
value: The value to store
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
conn.execute(
|
|
||||||
"INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)",
|
|
||||||
(key, value)
|
|
||||||
)
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# MCP STATISTICS METHODS
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def log_mcp_stat(
|
|
||||||
self,
|
|
||||||
tool_name: str,
|
|
||||||
folder: Optional[str],
|
|
||||||
success: bool,
|
|
||||||
error_message: Optional[str] = None
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Log an MCP tool usage event.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: Name of the MCP tool that was called
|
|
||||||
folder: The folder path involved (if any)
|
|
||||||
success: Whether the call succeeded
|
|
||||||
error_message: Error message if the call failed
|
|
||||||
"""
|
|
||||||
timestamp = datetime.datetime.now().isoformat()
|
|
||||||
with self._connection() as conn:
|
|
||||||
conn.execute(
|
|
||||||
"""INSERT INTO mcp_stats
|
|
||||||
(timestamp, tool_name, folder, success, error_message)
|
|
||||||
VALUES (?, ?, ?, ?, ?)""",
|
|
||||||
(timestamp, tool_name, folder, 1 if success else 0, error_message)
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_mcp_stats(self) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get aggregated MCP usage statistics.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing usage statistics:
|
|
||||||
- total_calls: Total number of tool calls
|
|
||||||
- reads: Number of file reads
|
|
||||||
- lists: Number of directory listings
|
|
||||||
- searches: Number of file searches
|
|
||||||
- db_inspects: Number of database inspections
|
|
||||||
- db_searches: Number of database searches
|
|
||||||
- db_queries: Number of database queries
|
|
||||||
- last_used: Timestamp of last usage
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute("""
|
|
||||||
SELECT
|
|
||||||
COUNT(*) as total_calls,
|
|
||||||
SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads,
|
|
||||||
SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists,
|
|
||||||
SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches,
|
|
||||||
SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects,
|
|
||||||
SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches,
|
|
||||||
SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries,
|
|
||||||
MAX(timestamp) as last_used
|
|
||||||
FROM mcp_stats
|
|
||||||
""")
|
|
||||||
row = cursor.fetchone()
|
|
||||||
return {
|
|
||||||
"total_calls": row[0] or 0,
|
|
||||||
"reads": row[1] or 0,
|
|
||||||
"lists": row[2] or 0,
|
|
||||||
"searches": row[3] or 0,
|
|
||||||
"db_inspects": row[4] or 0,
|
|
||||||
"db_searches": row[5] or 0,
|
|
||||||
"db_queries": row[6] or 0,
|
|
||||||
"last_used": row[7],
|
|
||||||
}
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# MCP DATABASE REGISTRY METHODS
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def add_mcp_database(self, db_info: Dict[str, Any]) -> int:
|
|
||||||
"""
|
|
||||||
Register a database for MCP access.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_info: Dictionary containing:
|
|
||||||
- path: Database file path
|
|
||||||
- name: Display name
|
|
||||||
- size: File size in bytes
|
|
||||||
- tables: List of table names
|
|
||||||
- added: Timestamp when added
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The database ID
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
conn.execute(
|
|
||||||
"""INSERT INTO mcp_databases
|
|
||||||
(path, name, size, tables, added_timestamp)
|
|
||||||
VALUES (?, ?, ?, ?, ?)""",
|
|
||||||
(
|
|
||||||
db_info["path"],
|
|
||||||
db_info["name"],
|
|
||||||
db_info["size"],
|
|
||||||
json.dumps(db_info["tables"]),
|
|
||||||
db_info["added"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cursor = conn.execute(
|
|
||||||
"SELECT id FROM mcp_databases WHERE path = ?",
|
|
||||||
(db_info["path"],)
|
|
||||||
)
|
|
||||||
return cursor.fetchone()[0]
|
|
||||||
|
|
||||||
def remove_mcp_database(self, db_path: str) -> bool:
|
|
||||||
"""
|
|
||||||
Remove a database from the MCP registry.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
db_path: Path to the database file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if a row was deleted, False otherwise
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
"DELETE FROM mcp_databases WHERE path = ?",
|
|
||||||
(db_path,)
|
|
||||||
)
|
|
||||||
return cursor.rowcount > 0
|
|
||||||
|
|
||||||
def get_mcp_databases(self) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Retrieve all registered MCP databases.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of database information dictionaries
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
"""SELECT id, path, name, size, tables, added_timestamp
|
|
||||||
FROM mcp_databases ORDER BY id"""
|
|
||||||
)
|
|
||||||
databases = []
|
|
||||||
for row in cursor.fetchall():
|
|
||||||
tables_list = json.loads(row[4]) if row[4] else []
|
|
||||||
databases.append({
|
|
||||||
"id": row[0],
|
|
||||||
"path": row[1],
|
|
||||||
"name": row[2],
|
|
||||||
"size": row[3],
|
|
||||||
"tables": tables_list,
|
|
||||||
"added": row[5],
|
|
||||||
})
|
|
||||||
return databases
|
|
||||||
|
|
||||||
# =========================================================================
|
|
||||||
# CONVERSATION METHODS
|
|
||||||
# =========================================================================
|
|
||||||
|
|
||||||
def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None:
|
|
||||||
"""
|
|
||||||
Save a conversation session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name/identifier for the conversation
|
|
||||||
data: List of message dictionaries with 'prompt' and 'response' keys
|
|
||||||
"""
|
|
||||||
timestamp = datetime.datetime.now().isoformat()
|
|
||||||
data_json = json.dumps(data)
|
|
||||||
with self._connection() as conn:
|
|
||||||
conn.execute(
|
|
||||||
"""INSERT INTO conversation_sessions
|
|
||||||
(name, timestamp, data) VALUES (?, ?, ?)""",
|
|
||||||
(name, timestamp, data_json)
|
|
||||||
)
|
|
||||||
|
|
||||||
def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]:
|
|
||||||
"""
|
|
||||||
Load a conversation by name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of the conversation to load
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of messages, or None if not found
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
"""SELECT data FROM conversation_sessions
|
|
||||||
WHERE name = ?
|
|
||||||
ORDER BY timestamp DESC LIMIT 1""",
|
|
||||||
(name,)
|
|
||||||
)
|
|
||||||
result = cursor.fetchone()
|
|
||||||
if result:
|
|
||||||
return json.loads(result[0])
|
|
||||||
return None
|
|
||||||
|
|
||||||
def delete_conversation(self, name: str) -> int:
|
|
||||||
"""
|
|
||||||
Delete a conversation by name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of the conversation to delete
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of rows deleted
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute(
|
|
||||||
"DELETE FROM conversation_sessions WHERE name = ?",
|
|
||||||
(name,)
|
|
||||||
)
|
|
||||||
return cursor.rowcount
|
|
||||||
|
|
||||||
def list_conversations(self) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
List all saved conversations.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of conversation summaries with name, timestamp, and message_count
|
|
||||||
"""
|
|
||||||
with self._connection() as conn:
|
|
||||||
cursor = conn.execute("""
|
|
||||||
SELECT name, MAX(timestamp) as last_saved, data
|
|
||||||
FROM conversation_sessions
|
|
||||||
GROUP BY name
|
|
||||||
ORDER BY last_saved DESC
|
|
||||||
""")
|
|
||||||
conversations = []
|
|
||||||
for row in cursor.fetchall():
|
|
||||||
name, timestamp, data_json = row
|
|
||||||
data = json.loads(data_json)
|
|
||||||
conversations.append({
|
|
||||||
"name": name,
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"message_count": len(data),
|
|
||||||
})
|
|
||||||
return conversations
|
|
||||||
|
|
||||||
|
|
||||||
# Global database instance for convenience
|
|
||||||
_db: Optional[Database] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_database() -> Database:
|
|
||||||
"""
|
|
||||||
Get the global database instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The shared Database instance
|
|
||||||
"""
|
|
||||||
global _db
|
|
||||||
if _db is None:
|
|
||||||
_db = Database()
|
|
||||||
return _db
|
|
||||||
@@ -1,361 +0,0 @@
|
|||||||
"""
|
|
||||||
Settings management for oAI.
|
|
||||||
|
|
||||||
This module provides a centralized settings class that handles all application
|
|
||||||
configuration with type safety, validation, and persistence.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Optional
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from oai.constants import (
|
|
||||||
DEFAULT_BASE_URL,
|
|
||||||
DEFAULT_STREAM_ENABLED,
|
|
||||||
DEFAULT_MAX_TOKENS,
|
|
||||||
DEFAULT_ONLINE_MODE,
|
|
||||||
DEFAULT_COST_WARNING_THRESHOLD,
|
|
||||||
DEFAULT_LOG_MAX_SIZE_MB,
|
|
||||||
DEFAULT_LOG_BACKUP_COUNT,
|
|
||||||
DEFAULT_LOG_LEVEL,
|
|
||||||
DEFAULT_SYSTEM_PROMPT,
|
|
||||||
VALID_LOG_LEVELS,
|
|
||||||
)
|
|
||||||
from oai.config.database import get_database
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Settings:
|
|
||||||
"""
|
|
||||||
Application settings with persistence support.
|
|
||||||
|
|
||||||
This class provides a clean interface for managing all configuration
|
|
||||||
options. Settings are automatically loaded from the database on
|
|
||||||
initialization and can be persisted back.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
api_key: OpenRouter API key
|
|
||||||
base_url: API base URL
|
|
||||||
default_model: Default model ID to use
|
|
||||||
default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
|
|
||||||
stream_enabled: Whether to stream responses
|
|
||||||
max_tokens: Maximum tokens per request
|
|
||||||
cost_warning_threshold: Alert threshold for message cost
|
|
||||||
default_online_mode: Whether online mode is enabled by default
|
|
||||||
log_max_size_mb: Maximum log file size in MB
|
|
||||||
log_backup_count: Number of log file backups to keep
|
|
||||||
log_level: Logging level (debug/info/warning/error/critical)
|
|
||||||
"""
|
|
||||||
|
|
||||||
api_key: Optional[str] = None
|
|
||||||
base_url: str = DEFAULT_BASE_URL
|
|
||||||
default_model: Optional[str] = None
|
|
||||||
default_system_prompt: Optional[str] = None
|
|
||||||
stream_enabled: bool = DEFAULT_STREAM_ENABLED
|
|
||||||
max_tokens: int = DEFAULT_MAX_TOKENS
|
|
||||||
cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD
|
|
||||||
default_online_mode: bool = DEFAULT_ONLINE_MODE
|
|
||||||
log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
|
|
||||||
log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT
|
|
||||||
log_level: str = DEFAULT_LOG_LEVEL
|
|
||||||
|
|
||||||
@property
|
|
||||||
def effective_system_prompt(self) -> str:
|
|
||||||
"""
|
|
||||||
Get the effective system prompt to use.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The custom prompt if set, hardcoded default if None, or blank if explicitly set to ""
|
|
||||||
"""
|
|
||||||
if self.default_system_prompt is None:
|
|
||||||
return DEFAULT_SYSTEM_PROMPT
|
|
||||||
return self.default_system_prompt
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
"""Validate settings after initialization."""
|
|
||||||
self._validate()
|
|
||||||
|
|
||||||
def _validate(self) -> None:
|
|
||||||
"""Validate all settings values."""
|
|
||||||
# Validate log level
|
|
||||||
if self.log_level.lower() not in VALID_LOG_LEVELS:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid log level: {self.log_level}. "
|
|
||||||
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate numeric bounds
|
|
||||||
if self.max_tokens < 1:
|
|
||||||
raise ValueError("max_tokens must be at least 1")
|
|
||||||
|
|
||||||
if self.cost_warning_threshold < 0:
|
|
||||||
raise ValueError("cost_warning_threshold must be non-negative")
|
|
||||||
|
|
||||||
if self.log_max_size_mb < 1:
|
|
||||||
raise ValueError("log_max_size_mb must be at least 1")
|
|
||||||
|
|
||||||
if self.log_backup_count < 0:
|
|
||||||
raise ValueError("log_backup_count must be non-negative")
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls) -> "Settings":
|
|
||||||
"""
|
|
||||||
Load settings from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Settings instance with values from database
|
|
||||||
"""
|
|
||||||
db = get_database()
|
|
||||||
|
|
||||||
# Helper to safely parse boolean
|
|
||||||
def parse_bool(value: Optional[str], default: bool) -> bool:
|
|
||||||
if value is None:
|
|
||||||
return default
|
|
||||||
return value.lower() in ("on", "true", "1", "yes")
|
|
||||||
|
|
||||||
# Helper to safely parse int
|
|
||||||
def parse_int(value: Optional[str], default: int) -> int:
|
|
||||||
if value is None:
|
|
||||||
return default
|
|
||||||
try:
|
|
||||||
return int(value)
|
|
||||||
except ValueError:
|
|
||||||
return default
|
|
||||||
|
|
||||||
# Helper to safely parse float
|
|
||||||
def parse_float(value: Optional[str], default: float) -> float:
|
|
||||||
if value is None:
|
|
||||||
return default
|
|
||||||
try:
|
|
||||||
return float(value)
|
|
||||||
except ValueError:
|
|
||||||
return default
|
|
||||||
|
|
||||||
# Get system prompt from DB: None means not set (use default), "" means explicitly blank
|
|
||||||
system_prompt_value = db.get_config("default_system_prompt")
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
api_key=db.get_config("api_key"),
|
|
||||||
base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
|
|
||||||
default_model=db.get_config("default_model"),
|
|
||||||
default_system_prompt=system_prompt_value,
|
|
||||||
stream_enabled=parse_bool(
|
|
||||||
db.get_config("stream_enabled"),
|
|
||||||
DEFAULT_STREAM_ENABLED
|
|
||||||
),
|
|
||||||
max_tokens=parse_int(
|
|
||||||
db.get_config("max_token"),
|
|
||||||
DEFAULT_MAX_TOKENS
|
|
||||||
),
|
|
||||||
cost_warning_threshold=parse_float(
|
|
||||||
db.get_config("cost_warning_threshold"),
|
|
||||||
DEFAULT_COST_WARNING_THRESHOLD
|
|
||||||
),
|
|
||||||
default_online_mode=parse_bool(
|
|
||||||
db.get_config("default_online_mode"),
|
|
||||||
DEFAULT_ONLINE_MODE
|
|
||||||
),
|
|
||||||
log_max_size_mb=parse_int(
|
|
||||||
db.get_config("log_max_size_mb"),
|
|
||||||
DEFAULT_LOG_MAX_SIZE_MB
|
|
||||||
),
|
|
||||||
log_backup_count=parse_int(
|
|
||||||
db.get_config("log_backup_count"),
|
|
||||||
DEFAULT_LOG_BACKUP_COUNT
|
|
||||||
),
|
|
||||||
log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL,
|
|
||||||
)
|
|
||||||
|
|
||||||
def save(self) -> None:
|
|
||||||
"""Persist all settings to the database."""
|
|
||||||
db = get_database()
|
|
||||||
|
|
||||||
# Only save API key if it exists
|
|
||||||
if self.api_key:
|
|
||||||
db.set_config("api_key", self.api_key)
|
|
||||||
|
|
||||||
db.set_config("base_url", self.base_url)
|
|
||||||
|
|
||||||
if self.default_model:
|
|
||||||
db.set_config("default_model", self.default_model)
|
|
||||||
|
|
||||||
# Save system prompt: None means not set (don't save), otherwise save the value (even if "")
|
|
||||||
if self.default_system_prompt is not None:
|
|
||||||
db.set_config("default_system_prompt", self.default_system_prompt)
|
|
||||||
|
|
||||||
db.set_config("stream_enabled", "on" if self.stream_enabled else "off")
|
|
||||||
db.set_config("max_token", str(self.max_tokens))
|
|
||||||
db.set_config("cost_warning_threshold", str(self.cost_warning_threshold))
|
|
||||||
db.set_config("default_online_mode", "on" if self.default_online_mode else "off")
|
|
||||||
db.set_config("log_max_size_mb", str(self.log_max_size_mb))
|
|
||||||
db.set_config("log_backup_count", str(self.log_backup_count))
|
|
||||||
db.set_config("log_level", self.log_level)
|
|
||||||
|
|
||||||
def set_api_key(self, api_key: str) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the API key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: The new API key
|
|
||||||
"""
|
|
||||||
self.api_key = api_key.strip()
|
|
||||||
get_database().set_config("api_key", self.api_key)
|
|
||||||
|
|
||||||
def set_base_url(self, url: str) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the base URL.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
url: The new base URL
|
|
||||||
"""
|
|
||||||
self.base_url = url.strip()
|
|
||||||
get_database().set_config("base_url", self.base_url)
|
|
||||||
|
|
||||||
def set_default_model(self, model_id: str) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the default model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: The model ID to set as default
|
|
||||||
"""
|
|
||||||
self.default_model = model_id
|
|
||||||
get_database().set_config("default_model", model_id)
|
|
||||||
|
|
||||||
def set_default_system_prompt(self, prompt: str) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the default system prompt.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: The system prompt to use for all new sessions.
|
|
||||||
Empty string "" means blank prompt (no system message).
|
|
||||||
"""
|
|
||||||
self.default_system_prompt = prompt
|
|
||||||
get_database().set_config("default_system_prompt", prompt)
|
|
||||||
|
|
||||||
def clear_default_system_prompt(self) -> None:
|
|
||||||
"""
|
|
||||||
Clear the custom system prompt and revert to hardcoded default.
|
|
||||||
|
|
||||||
This removes the custom prompt from the database, causing the
|
|
||||||
application to use the built-in DEFAULT_SYSTEM_PROMPT.
|
|
||||||
"""
|
|
||||||
self.default_system_prompt = None
|
|
||||||
# Remove from database to indicate "not set"
|
|
||||||
db = get_database()
|
|
||||||
with db._connection() as conn:
|
|
||||||
conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",))
|
|
||||||
conn.commit()
|
|
||||||
|
|
||||||
def set_stream_enabled(self, enabled: bool) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the streaming preference.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
enabled: Whether to enable streaming
|
|
||||||
"""
|
|
||||||
self.stream_enabled = enabled
|
|
||||||
get_database().set_config("stream_enabled", "on" if enabled else "off")
|
|
||||||
|
|
||||||
def set_max_tokens(self, max_tokens: int) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the maximum tokens.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_tokens: Maximum number of tokens
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If max_tokens is less than 1
|
|
||||||
"""
|
|
||||||
if max_tokens < 1:
|
|
||||||
raise ValueError("max_tokens must be at least 1")
|
|
||||||
self.max_tokens = max_tokens
|
|
||||||
get_database().set_config("max_token", str(max_tokens))
|
|
||||||
|
|
||||||
def set_cost_warning_threshold(self, threshold: float) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the cost warning threshold.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
threshold: Cost threshold in USD
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If threshold is negative
|
|
||||||
"""
|
|
||||||
if threshold < 0:
|
|
||||||
raise ValueError("cost_warning_threshold must be non-negative")
|
|
||||||
self.cost_warning_threshold = threshold
|
|
||||||
get_database().set_config("cost_warning_threshold", str(threshold))
|
|
||||||
|
|
||||||
def set_default_online_mode(self, enabled: bool) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the default online mode.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
enabled: Whether online mode should be enabled by default
|
|
||||||
"""
|
|
||||||
self.default_online_mode = enabled
|
|
||||||
get_database().set_config("default_online_mode", "on" if enabled else "off")
|
|
||||||
|
|
||||||
def set_log_level(self, level: str) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the log level.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
level: The log level (debug/info/warning/error/critical)
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If level is not valid
|
|
||||||
"""
|
|
||||||
level_lower = level.lower()
|
|
||||||
if level_lower not in VALID_LOG_LEVELS:
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid log level: {level}. "
|
|
||||||
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
|
|
||||||
)
|
|
||||||
self.log_level = level_lower
|
|
||||||
get_database().set_config("log_level", level_lower)
|
|
||||||
|
|
||||||
def set_log_max_size(self, size_mb: int) -> None:
|
|
||||||
"""
|
|
||||||
Set and persist the maximum log file size.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
size_mb: Maximum size in megabytes
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If size_mb is less than 1
|
|
||||||
"""
|
|
||||||
if size_mb < 1:
|
|
||||||
raise ValueError("log_max_size_mb must be at least 1")
|
|
||||||
# Cap at 100 MB for safety
|
|
||||||
self.log_max_size_mb = min(size_mb, 100)
|
|
||||||
get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
|
|
||||||
|
|
||||||
|
|
||||||
# Global settings instance
|
|
||||||
_settings: Optional[Settings] = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_settings() -> Settings:
|
|
||||||
"""
|
|
||||||
Get the global settings instance.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The shared Settings instance, loading from database if needed
|
|
||||||
"""
|
|
||||||
global _settings
|
|
||||||
if _settings is None:
|
|
||||||
_settings = Settings.load()
|
|
||||||
return _settings
|
|
||||||
|
|
||||||
|
|
||||||
def reload_settings() -> Settings:
|
|
||||||
"""
|
|
||||||
Force reload settings from the database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Fresh Settings instance
|
|
||||||
"""
|
|
||||||
global _settings
|
|
||||||
_settings = Settings.load()
|
|
||||||
return _settings
|
|
||||||
448
oai/constants.py
448
oai/constants.py
@@ -1,448 +0,0 @@
|
|||||||
"""
|
|
||||||
Application-wide constants for oAI.
|
|
||||||
|
|
||||||
This module contains all configuration constants, default values, and static
|
|
||||||
definitions used throughout the application. Centralizing these values makes
|
|
||||||
the codebase easier to maintain and configure.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Set, Dict, Any
|
|
||||||
import logging
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# APPLICATION METADATA
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
APP_NAME = "oAI"
|
|
||||||
APP_VERSION = "3.0.0"
|
|
||||||
APP_URL = "https://iurl.no/oai"
|
|
||||||
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# FILE PATHS
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
HOME_DIR = Path.home()
|
|
||||||
CONFIG_DIR = HOME_DIR / ".config" / "oai"
|
|
||||||
CACHE_DIR = HOME_DIR / ".cache" / "oai"
|
|
||||||
HISTORY_FILE = CONFIG_DIR / "history.txt"
|
|
||||||
DATABASE_FILE = CONFIG_DIR / "oai_config.db"
|
|
||||||
LOG_FILE = CONFIG_DIR / "oai.log"
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# API CONFIGURATION
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
|
|
||||||
DEFAULT_STREAM_ENABLED = True
|
|
||||||
DEFAULT_MAX_TOKENS = 100_000
|
|
||||||
DEFAULT_ONLINE_MODE = False
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# DEFAULT SYSTEM PROMPT
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
DEFAULT_SYSTEM_PROMPT = (
|
|
||||||
"You are a knowledgeable and helpful AI assistant. Provide clear, accurate, "
|
|
||||||
"and well-structured responses. Be concise yet thorough. When uncertain about "
|
|
||||||
"something, acknowledge your limitations. For technical topics, include relevant "
|
|
||||||
"details and examples when helpful."
|
|
||||||
)
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# PRICING DEFAULTS (per million tokens)
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
DEFAULT_INPUT_PRICE = 3.0
|
|
||||||
DEFAULT_OUTPUT_PRICE = 15.0
|
|
||||||
|
|
||||||
MODEL_PRICING: Dict[str, float] = {
|
|
||||||
"input": DEFAULT_INPUT_PRICE,
|
|
||||||
"output": DEFAULT_OUTPUT_PRICE,
|
|
||||||
}
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# CREDIT ALERTS
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
LOW_CREDIT_RATIO = 0.1 # Alert when credits < 10% of total
|
|
||||||
LOW_CREDIT_AMOUNT = 1.0 # Alert when credits < $1.00
|
|
||||||
DEFAULT_COST_WARNING_THRESHOLD = 0.01 # Alert when single message cost exceeds this
|
|
||||||
COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD # Alias for convenience
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# LOGGING CONFIGURATION
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
DEFAULT_LOG_MAX_SIZE_MB = 10
|
|
||||||
DEFAULT_LOG_BACKUP_COUNT = 2
|
|
||||||
DEFAULT_LOG_LEVEL = "info"
|
|
||||||
|
|
||||||
VALID_LOG_LEVELS: Dict[str, int] = {
|
|
||||||
"debug": logging.DEBUG,
|
|
||||||
"info": logging.INFO,
|
|
||||||
"warning": logging.WARNING,
|
|
||||||
"error": logging.ERROR,
|
|
||||||
"critical": logging.CRITICAL,
|
|
||||||
}
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# FILE HANDLING
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# Maximum file size for reading (10 MB)
|
|
||||||
MAX_FILE_SIZE = 10 * 1024 * 1024
|
|
||||||
|
|
||||||
# Content truncation threshold (50 KB)
|
|
||||||
CONTENT_TRUNCATION_THRESHOLD = 50 * 1024
|
|
||||||
|
|
||||||
# Maximum items in directory listing
|
|
||||||
MAX_LIST_ITEMS = 1000
|
|
||||||
|
|
||||||
# Supported code file extensions for syntax highlighting
|
|
||||||
SUPPORTED_CODE_EXTENSIONS: Set[str] = {
|
|
||||||
".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp",
|
|
||||||
".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go",
|
|
||||||
".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart",
|
|
||||||
".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt",
|
|
||||||
}
|
|
||||||
|
|
||||||
# All allowed file extensions for attachment
|
|
||||||
ALLOWED_FILE_EXTENSIONS: Set[str] = {
|
|
||||||
# Code files
|
|
||||||
".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx",
|
|
||||||
".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php",
|
|
||||||
".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1",
|
|
||||||
# Data files
|
|
||||||
".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3",
|
|
||||||
# Documents
|
|
||||||
".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties",
|
|
||||||
# Images
|
|
||||||
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico",
|
|
||||||
# Archives
|
|
||||||
".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz",
|
|
||||||
# Config files
|
|
||||||
".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc",
|
|
||||||
".prettierrc", ".babelrc", ".nvmrc", ".npmrc",
|
|
||||||
# Binary/Compiled
|
|
||||||
".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app",
|
|
||||||
".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa",
|
|
||||||
# ML/AI
|
|
||||||
".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx",
|
|
||||||
".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn",
|
|
||||||
# Data formats
|
|
||||||
".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet",
|
|
||||||
".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds",
|
|
||||||
# Other
|
|
||||||
".pdf", ".class", ".jar", ".war",
|
|
||||||
}
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# SECURITY CONFIGURATION
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# System directories that should never be accessed
|
|
||||||
SYSTEM_DIRS_BLACKLIST: Set[str] = {
|
|
||||||
# macOS
|
|
||||||
"/System", "/Library", "/private", "/usr", "/bin", "/sbin",
|
|
||||||
# Linux
|
|
||||||
"/boot", "/dev", "/proc", "/sys", "/root",
|
|
||||||
# Windows
|
|
||||||
"C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Directories to skip during file operations
|
|
||||||
SKIP_DIRECTORIES: Set[str] = {
|
|
||||||
# Python virtual environments
|
|
||||||
".venv", "venv", "env", "virtualenv",
|
|
||||||
"site-packages", "dist-packages",
|
|
||||||
# Python caches
|
|
||||||
"__pycache__", ".pytest_cache", ".mypy_cache",
|
|
||||||
# JavaScript/Node
|
|
||||||
"node_modules",
|
|
||||||
# Version control
|
|
||||||
".git", ".svn",
|
|
||||||
# IDEs
|
|
||||||
".idea", ".vscode",
|
|
||||||
# Build directories
|
|
||||||
"build", "dist", "eggs", ".eggs",
|
|
||||||
}
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# DATABASE QUERIES - SQL SAFETY
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# Maximum query execution timeout (seconds)
|
|
||||||
MAX_QUERY_TIMEOUT = 5
|
|
||||||
|
|
||||||
# Maximum rows returned from queries
|
|
||||||
MAX_QUERY_RESULTS = 1000
|
|
||||||
|
|
||||||
# Default rows per query
|
|
||||||
DEFAULT_QUERY_LIMIT = 100
|
|
||||||
|
|
||||||
# Keywords that are blocked in database queries
|
|
||||||
DANGEROUS_SQL_KEYWORDS: Set[str] = {
|
|
||||||
"INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
|
|
||||||
"ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH",
|
|
||||||
"PRAGMA", "VACUUM", "REINDEX",
|
|
||||||
}
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# MCP CONFIGURATION
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
# Maximum tool call iterations per request
|
|
||||||
MAX_TOOL_LOOPS = 5
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# VALID COMMANDS
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
VALID_COMMANDS: Set[str] = {
|
|
||||||
"/retry", "/online", "/memory", "/paste", "/export", "/save", "/load",
|
|
||||||
"/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset",
|
|
||||||
"/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear",
|
|
||||||
"/cl", "/help", "/mcp",
|
|
||||||
}
|
|
||||||
|
|
||||||
# =============================================================================
|
|
||||||
# COMMAND HELP DATABASE
|
|
||||||
# =============================================================================
|
|
||||||
|
|
||||||
COMMAND_HELP: Dict[str, Dict[str, Any]] = {
|
|
||||||
"/clear": {
|
|
||||||
"aliases": ["/cl"],
|
|
||||||
"description": "Clear the terminal screen for a clean interface.",
|
|
||||||
"usage": "/clear\n/cl",
|
|
||||||
"examples": [
|
|
||||||
("Clear screen", "/clear"),
|
|
||||||
("Using short alias", "/cl"),
|
|
||||||
],
|
|
||||||
"notes": "You can also use the keyboard shortcut Ctrl+L.",
|
|
||||||
},
|
|
||||||
"/help": {
|
|
||||||
"description": "Display help information for commands.",
|
|
||||||
"usage": "/help [command|topic]",
|
|
||||||
"examples": [
|
|
||||||
("Show all commands", "/help"),
|
|
||||||
("Get help for a specific command", "/help /model"),
|
|
||||||
("Get detailed MCP help", "/help mcp"),
|
|
||||||
],
|
|
||||||
"notes": "Use /help without arguments to see the full command list.",
|
|
||||||
},
|
|
||||||
"mcp": {
|
|
||||||
"description": "Complete guide to MCP (Model Context Protocol).",
|
|
||||||
"usage": "See detailed examples below",
|
|
||||||
"examples": [],
|
|
||||||
"notes": """
|
|
||||||
MCP (Model Context Protocol) gives your AI assistant direct access to:
|
|
||||||
• Local files and folders (read, search, list)
|
|
||||||
• SQLite databases (inspect, search, query)
|
|
||||||
|
|
||||||
FILE MODE (default):
|
|
||||||
/mcp on Start MCP server
|
|
||||||
/mcp add ~/Documents Grant access to folder
|
|
||||||
/mcp list View all allowed folders
|
|
||||||
|
|
||||||
DATABASE MODE:
|
|
||||||
/mcp add db ~/app/data.db Add specific database
|
|
||||||
/mcp db list View all databases
|
|
||||||
/mcp db 1 Work with database #1
|
|
||||||
/mcp files Switch back to file mode
|
|
||||||
|
|
||||||
WRITE MODE (optional):
|
|
||||||
/mcp write on Enable file modifications
|
|
||||||
/mcp write off Disable write mode (back to read-only)
|
|
||||||
|
|
||||||
For command-specific help: /help /mcp
|
|
||||||
""",
|
|
||||||
},
|
|
||||||
"/mcp": {
|
|
||||||
"description": "Manage MCP for local file access and SQLite database querying.",
|
|
||||||
"usage": "/mcp <command> [args]",
|
|
||||||
"examples": [
|
|
||||||
("Enable MCP server", "/mcp on"),
|
|
||||||
("Disable MCP server", "/mcp off"),
|
|
||||||
("Show MCP status", "/mcp status"),
|
|
||||||
("", ""),
|
|
||||||
("━━━ FILE MODE ━━━", ""),
|
|
||||||
("Add folder for file access", "/mcp add ~/Documents"),
|
|
||||||
("Remove folder", "/mcp remove ~/Desktop"),
|
|
||||||
("List allowed folders", "/mcp list"),
|
|
||||||
("Enable write mode", "/mcp write on"),
|
|
||||||
("", ""),
|
|
||||||
("━━━ DATABASE MODE ━━━", ""),
|
|
||||||
("Add SQLite database", "/mcp add db ~/app/data.db"),
|
|
||||||
("List all databases", "/mcp db list"),
|
|
||||||
("Switch to database #1", "/mcp db 1"),
|
|
||||||
("Switch back to file mode", "/mcp files"),
|
|
||||||
],
|
|
||||||
"notes": "MCP allows AI to read local files and query SQLite databases.",
|
|
||||||
},
|
|
||||||
"/memory": {
|
|
||||||
"description": "Toggle conversation memory.",
|
|
||||||
"usage": "/memory [on|off]",
|
|
||||||
"examples": [
|
|
||||||
("Check current memory status", "/memory"),
|
|
||||||
("Enable conversation memory", "/memory on"),
|
|
||||||
("Disable memory (save costs)", "/memory off"),
|
|
||||||
],
|
|
||||||
"notes": "Memory is ON by default. Disabling saves tokens.",
|
|
||||||
},
|
|
||||||
"/online": {
|
|
||||||
"description": "Enable or disable online mode (web search).",
|
|
||||||
"usage": "/online [on|off]",
|
|
||||||
"examples": [
|
|
||||||
("Check online mode status", "/online"),
|
|
||||||
("Enable web search", "/online on"),
|
|
||||||
("Disable web search", "/online off"),
|
|
||||||
],
|
|
||||||
"notes": "Not all models support online mode.",
|
|
||||||
},
|
|
||||||
"/paste": {
|
|
||||||
"description": "Paste plain text from clipboard and send to the AI.",
|
|
||||||
"usage": "/paste [prompt]",
|
|
||||||
"examples": [
|
|
||||||
("Paste clipboard content", "/paste"),
|
|
||||||
("Paste with a question", "/paste Explain this code"),
|
|
||||||
],
|
|
||||||
"notes": "Only plain text is supported.",
|
|
||||||
},
|
|
||||||
"/retry": {
|
|
||||||
"description": "Resend the last prompt from conversation history.",
|
|
||||||
"usage": "/retry",
|
|
||||||
"examples": [("Retry last message", "/retry")],
|
|
||||||
"notes": "Requires at least one message in history.",
|
|
||||||
},
|
|
||||||
"/next": {
|
|
||||||
"description": "View the next response in conversation history.",
|
|
||||||
"usage": "/next",
|
|
||||||
"examples": [("Navigate to next response", "/next")],
|
|
||||||
"notes": "Use /prev to go backward.",
|
|
||||||
},
|
|
||||||
"/prev": {
|
|
||||||
"description": "View the previous response in conversation history.",
|
|
||||||
"usage": "/prev",
|
|
||||||
"examples": [("Navigate to previous response", "/prev")],
|
|
||||||
"notes": "Use /next to go forward.",
|
|
||||||
},
|
|
||||||
"/reset": {
|
|
||||||
"description": "Clear conversation history and reset system prompt.",
|
|
||||||
"usage": "/reset",
|
|
||||||
"examples": [("Reset conversation", "/reset")],
|
|
||||||
"notes": "Requires confirmation.",
|
|
||||||
},
|
|
||||||
"/info": {
|
|
||||||
"description": "Display detailed information about a model.",
|
|
||||||
"usage": "/info [model_id]",
|
|
||||||
"examples": [
|
|
||||||
("Show current model info", "/info"),
|
|
||||||
("Show specific model info", "/info gpt-4o"),
|
|
||||||
],
|
|
||||||
"notes": "Shows pricing, capabilities, and context length.",
|
|
||||||
},
|
|
||||||
"/model": {
|
|
||||||
"description": "Select or change the AI model.",
|
|
||||||
"usage": "/model [search_term]",
|
|
||||||
"examples": [
|
|
||||||
("List all models", "/model"),
|
|
||||||
("Search for GPT models", "/model gpt"),
|
|
||||||
("Search for Claude models", "/model claude"),
|
|
||||||
],
|
|
||||||
"notes": "Models are numbered for easy selection.",
|
|
||||||
},
|
|
||||||
"/config": {
|
|
||||||
"description": "View or modify application configuration.",
|
|
||||||
"usage": "/config [setting] [value]",
|
|
||||||
"examples": [
|
|
||||||
("View all settings", "/config"),
|
|
||||||
("Set API key", "/config api"),
|
|
||||||
("Set default model", "/config model"),
|
|
||||||
("Set system prompt", "/config system You are a helpful assistant"),
|
|
||||||
("Enable streaming", "/config stream on"),
|
|
||||||
],
|
|
||||||
"notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.",
|
|
||||||
},
|
|
||||||
"/maxtoken": {
|
|
||||||
"description": "Set a temporary session token limit.",
|
|
||||||
"usage": "/maxtoken [value]",
|
|
||||||
"examples": [
|
|
||||||
("View current session limit", "/maxtoken"),
|
|
||||||
("Set session limit to 2000", "/maxtoken 2000"),
|
|
||||||
],
|
|
||||||
"notes": "Cannot exceed stored max token limit.",
|
|
||||||
},
|
|
||||||
"/system": {
|
|
||||||
"description": "Set or clear the session-level system prompt.",
|
|
||||||
"usage": "/system [prompt|clear|default <prompt>]",
|
|
||||||
"examples": [
|
|
||||||
("View current system prompt", "/system"),
|
|
||||||
("Set as Python expert", "/system You are a Python expert"),
|
|
||||||
("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."),
|
|
||||||
("Save as default", "/system default You are a helpful assistant"),
|
|
||||||
("Revert to default", "/system clear"),
|
|
||||||
("Blank prompt", '/system ""'),
|
|
||||||
],
|
|
||||||
"notes": r"Use \n for newlines. /system clear reverts to hardcoded default.",
|
|
||||||
},
|
|
||||||
"/save": {
|
|
||||||
"description": "Save the current conversation history.",
|
|
||||||
"usage": "/save <name>",
|
|
||||||
"examples": [("Save conversation", "/save my_chat")],
|
|
||||||
"notes": "Saved conversations can be loaded later with /load.",
|
|
||||||
},
|
|
||||||
"/load": {
|
|
||||||
"description": "Load a saved conversation.",
|
|
||||||
"usage": "/load <name|number>",
|
|
||||||
"examples": [
|
|
||||||
("Load by name", "/load my_chat"),
|
|
||||||
("Load by number from /list", "/load 3"),
|
|
||||||
],
|
|
||||||
"notes": "Use /list to see numbered conversations.",
|
|
||||||
},
|
|
||||||
"/delete": {
|
|
||||||
"description": "Delete a saved conversation.",
|
|
||||||
"usage": "/delete <name|number>",
|
|
||||||
"examples": [("Delete by name", "/delete my_chat")],
|
|
||||||
"notes": "Requires confirmation. Cannot be undone.",
|
|
||||||
},
|
|
||||||
"/list": {
|
|
||||||
"description": "List all saved conversations.",
|
|
||||||
"usage": "/list",
|
|
||||||
"examples": [("Show saved conversations", "/list")],
|
|
||||||
"notes": "Conversations are numbered for use with /load and /delete.",
|
|
||||||
},
|
|
||||||
"/export": {
|
|
||||||
"description": "Export the current conversation to a file.",
|
|
||||||
"usage": "/export <format> <filename>",
|
|
||||||
"examples": [
|
|
||||||
("Export as Markdown", "/export md notes.md"),
|
|
||||||
("Export as JSON", "/export json conversation.json"),
|
|
||||||
("Export as HTML", "/export html report.html"),
|
|
||||||
],
|
|
||||||
"notes": "Available formats: md, json, html.",
|
|
||||||
},
|
|
||||||
"/stats": {
|
|
||||||
"description": "Display session statistics.",
|
|
||||||
"usage": "/stats",
|
|
||||||
"examples": [("View session statistics", "/stats")],
|
|
||||||
"notes": "Shows tokens, costs, and credits.",
|
|
||||||
},
|
|
||||||
"/credits": {
|
|
||||||
"description": "Display your OpenRouter account credits.",
|
|
||||||
"usage": "/credits",
|
|
||||||
"examples": [("Check credits", "/credits")],
|
|
||||||
"notes": "Shows total, used, and remaining credits.",
|
|
||||||
},
|
|
||||||
"/middleout": {
|
|
||||||
"description": "Toggle middle-out transform for long prompts.",
|
|
||||||
"usage": "/middleout [on|off]",
|
|
||||||
"examples": [
|
|
||||||
("Check status", "/middleout"),
|
|
||||||
("Enable compression", "/middleout on"),
|
|
||||||
],
|
|
||||||
"notes": "Compresses prompts exceeding context size.",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
"""
|
|
||||||
Core functionality for oAI.
|
|
||||||
|
|
||||||
This module provides the main session management and AI client
|
|
||||||
classes that power the chat application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from oai.core.session import ChatSession
|
|
||||||
from oai.core.client import AIClient
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ChatSession",
|
|
||||||
"AIClient",
|
|
||||||
]
|
|
||||||
@@ -1,422 +0,0 @@
|
|||||||
"""
|
|
||||||
AI Client for oAI.
|
|
||||||
|
|
||||||
This module provides a high-level client for interacting with AI models
|
|
||||||
through the provider abstraction layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
|
|
||||||
|
|
||||||
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
|
|
||||||
from oai.providers.base import (
|
|
||||||
AIProvider,
|
|
||||||
ChatMessage,
|
|
||||||
ChatResponse,
|
|
||||||
ModelInfo,
|
|
||||||
StreamChunk,
|
|
||||||
ToolCall,
|
|
||||||
UsageStats,
|
|
||||||
)
|
|
||||||
from oai.providers.openrouter import OpenRouterProvider
|
|
||||||
from oai.utils.logging import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
class AIClient:
|
|
||||||
"""
|
|
||||||
High-level AI client for chat interactions.
|
|
||||||
|
|
||||||
Provides a simplified interface for sending chat requests,
|
|
||||||
handling streaming, and managing tool calls.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
provider: The underlying AI provider
|
|
||||||
default_model: Default model ID to use
|
|
||||||
http_headers: Custom HTTP headers for requests
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: str,
|
|
||||||
base_url: Optional[str] = None,
|
|
||||||
provider_class: type = OpenRouterProvider,
|
|
||||||
app_name: str = APP_NAME,
|
|
||||||
app_url: str = APP_URL,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the AI client.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: API key for authentication
|
|
||||||
base_url: Optional custom base URL
|
|
||||||
provider_class: Provider class to use (default: OpenRouterProvider)
|
|
||||||
app_name: Application name for headers
|
|
||||||
app_url: Application URL for headers
|
|
||||||
"""
|
|
||||||
self.provider: AIProvider = provider_class(
|
|
||||||
api_key=api_key,
|
|
||||||
base_url=base_url,
|
|
||||||
app_name=app_name,
|
|
||||||
app_url=app_url,
|
|
||||||
)
|
|
||||||
self.default_model: Optional[str] = None
|
|
||||||
self.logger = get_logger()
|
|
||||||
|
|
||||||
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
|
||||||
"""
|
|
||||||
Get available models.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filter_text_only: Whether to exclude video-only models
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ModelInfo objects
|
|
||||||
"""
|
|
||||||
return self.provider.list_models(filter_text_only=filter_text_only)
|
|
||||||
|
|
||||||
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
|
||||||
"""
|
|
||||||
Get information about a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: Model identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ModelInfo or None if not found
|
|
||||||
"""
|
|
||||||
return self.provider.get_model(model_id)
|
|
||||||
|
|
||||||
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get raw model data for provider-specific fields.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: Model identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Raw model dictionary or None
|
|
||||||
"""
|
|
||||||
if hasattr(self.provider, "get_raw_model"):
|
|
||||||
return self.provider.get_raw_model(model_id)
|
|
||||||
return None
|
|
||||||
|
|
||||||
def chat(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, Any]],
|
|
||||||
model: Optional[str] = None,
|
|
||||||
stream: bool = False,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
|
||||||
tool_choice: Optional[str] = None,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
online: bool = False,
|
|
||||||
transforms: Optional[List[str]] = None,
|
|
||||||
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
|
||||||
"""
|
|
||||||
Send a chat request.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: List of message dictionaries
|
|
||||||
model: Model ID (uses default if not specified)
|
|
||||||
stream: Whether to stream the response
|
|
||||||
max_tokens: Maximum tokens in response
|
|
||||||
temperature: Sampling temperature
|
|
||||||
tools: Tool definitions for function calling
|
|
||||||
tool_choice: Tool selection mode
|
|
||||||
system_prompt: System prompt to prepend
|
|
||||||
online: Whether to enable online mode
|
|
||||||
transforms: List of transforms (e.g., ["middle-out"])
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If no model specified and no default set
|
|
||||||
"""
|
|
||||||
model_id = model or self.default_model
|
|
||||||
if not model_id:
|
|
||||||
raise ValueError("No model specified and no default set")
|
|
||||||
|
|
||||||
# Apply online mode suffix
|
|
||||||
if online and hasattr(self.provider, "get_effective_model_id"):
|
|
||||||
model_id = self.provider.get_effective_model_id(model_id, True)
|
|
||||||
|
|
||||||
# Convert dict messages to ChatMessage objects
|
|
||||||
chat_messages = []
|
|
||||||
|
|
||||||
# Add system prompt if provided
|
|
||||||
if system_prompt:
|
|
||||||
chat_messages.append(ChatMessage(role="system", content=system_prompt))
|
|
||||||
|
|
||||||
# Convert message dicts
|
|
||||||
for msg in messages:
|
|
||||||
# Convert tool_calls dicts to ToolCall objects if present
|
|
||||||
tool_calls_data = msg.get("tool_calls")
|
|
||||||
tool_calls_obj = None
|
|
||||||
if tool_calls_data:
|
|
||||||
from oai.providers.base import ToolCall, ToolFunction
|
|
||||||
tool_calls_obj = []
|
|
||||||
for tc in tool_calls_data:
|
|
||||||
# Handle both ToolCall objects and dicts
|
|
||||||
if isinstance(tc, ToolCall):
|
|
||||||
tool_calls_obj.append(tc)
|
|
||||||
elif isinstance(tc, dict):
|
|
||||||
func_data = tc.get("function", {})
|
|
||||||
tool_calls_obj.append(
|
|
||||||
ToolCall(
|
|
||||||
id=tc.get("id", ""),
|
|
||||||
type=tc.get("type", "function"),
|
|
||||||
function=ToolFunction(
|
|
||||||
name=func_data.get("name", ""),
|
|
||||||
arguments=func_data.get("arguments", "{}"),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
chat_messages.append(
|
|
||||||
ChatMessage(
|
|
||||||
role=msg.get("role", "user"),
|
|
||||||
content=msg.get("content"),
|
|
||||||
tool_calls=tool_calls_obj,
|
|
||||||
tool_call_id=msg.get("tool_call_id"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.logger.debug(
|
|
||||||
f"Sending chat request: model={model_id}, "
|
|
||||||
f"messages={len(chat_messages)}, stream={stream}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.provider.chat(
|
|
||||||
model=model_id,
|
|
||||||
messages=chat_messages,
|
|
||||||
stream=stream,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
tools=tools,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
transforms=transforms,
|
|
||||||
)
|
|
||||||
|
|
||||||
def chat_with_tools(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, Any]],
|
|
||||||
tools: List[Dict[str, Any]],
|
|
||||||
tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]],
|
|
||||||
model: Optional[str] = None,
|
|
||||||
max_loops: int = 5,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
on_tool_call: Optional[Callable[[ToolCall], None]] = None,
|
|
||||||
on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""
|
|
||||||
Send a chat request with automatic tool call handling.
|
|
||||||
|
|
||||||
Executes tool calls returned by the model and continues
|
|
||||||
the conversation until no more tool calls are requested.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Initial messages
|
|
||||||
tools: Tool definitions
|
|
||||||
tool_executor: Function to execute tool calls
|
|
||||||
model: Model ID
|
|
||||||
max_loops: Maximum tool call iterations
|
|
||||||
max_tokens: Maximum response tokens
|
|
||||||
system_prompt: System prompt
|
|
||||||
on_tool_call: Callback when tool is called
|
|
||||||
on_tool_result: Callback when tool returns result
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Final ChatResponse after all tool calls complete
|
|
||||||
"""
|
|
||||||
model_id = model or self.default_model
|
|
||||||
if not model_id:
|
|
||||||
raise ValueError("No model specified and no default set")
|
|
||||||
|
|
||||||
# Build initial messages
|
|
||||||
chat_messages = []
|
|
||||||
if system_prompt:
|
|
||||||
chat_messages.append({"role": "system", "content": system_prompt})
|
|
||||||
chat_messages.extend(messages)
|
|
||||||
|
|
||||||
loop_count = 0
|
|
||||||
current_response: Optional[ChatResponse] = None
|
|
||||||
|
|
||||||
while loop_count < max_loops:
|
|
||||||
# Send request
|
|
||||||
response = self.chat(
|
|
||||||
messages=chat_messages,
|
|
||||||
model=model_id,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
tools=tools,
|
|
||||||
tool_choice="auto",
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(response, ChatResponse):
|
|
||||||
raise ValueError("Expected non-streaming response")
|
|
||||||
|
|
||||||
current_response = response
|
|
||||||
|
|
||||||
# Check for tool calls
|
|
||||||
tool_calls = response.tool_calls
|
|
||||||
if not tool_calls:
|
|
||||||
break
|
|
||||||
|
|
||||||
self.logger.info(f"Model requested {len(tool_calls)} tool call(s)")
|
|
||||||
|
|
||||||
# Process each tool call
|
|
||||||
tool_results = []
|
|
||||||
for tc in tool_calls:
|
|
||||||
if on_tool_call:
|
|
||||||
on_tool_call(tc)
|
|
||||||
|
|
||||||
try:
|
|
||||||
args = json.loads(tc.function.arguments)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
self.logger.error(f"Failed to parse tool arguments: {e}")
|
|
||||||
result = {"error": f"Invalid arguments: {e}"}
|
|
||||||
else:
|
|
||||||
result = tool_executor(tc.function.name, args)
|
|
||||||
|
|
||||||
if on_tool_result:
|
|
||||||
on_tool_result(tc.function.name, result)
|
|
||||||
|
|
||||||
tool_results.append({
|
|
||||||
"tool_call_id": tc.id,
|
|
||||||
"role": "tool",
|
|
||||||
"name": tc.function.name,
|
|
||||||
"content": json.dumps(result),
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add assistant message with tool calls
|
|
||||||
assistant_msg = {
|
|
||||||
"role": "assistant",
|
|
||||||
"content": response.content,
|
|
||||||
"tool_calls": [
|
|
||||||
{
|
|
||||||
"id": tc.id,
|
|
||||||
"type": tc.type,
|
|
||||||
"function": {
|
|
||||||
"name": tc.function.name,
|
|
||||||
"arguments": tc.function.arguments,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tc in tool_calls
|
|
||||||
],
|
|
||||||
}
|
|
||||||
chat_messages.append(assistant_msg)
|
|
||||||
chat_messages.extend(tool_results)
|
|
||||||
|
|
||||||
loop_count += 1
|
|
||||||
|
|
||||||
if loop_count >= max_loops:
|
|
||||||
self.logger.warning(f"Reached max tool call loops ({max_loops})")
|
|
||||||
|
|
||||||
return current_response
|
|
||||||
|
|
||||||
def stream_chat(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, Any]],
|
|
||||||
model: Optional[str] = None,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
system_prompt: Optional[str] = None,
|
|
||||||
online: bool = False,
|
|
||||||
on_chunk: Optional[Callable[[StreamChunk], None]] = None,
|
|
||||||
) -> tuple[str, Optional[UsageStats]]:
|
|
||||||
"""
|
|
||||||
Stream a chat response and collect the full text.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: Chat messages
|
|
||||||
model: Model ID
|
|
||||||
max_tokens: Maximum tokens
|
|
||||||
system_prompt: System prompt
|
|
||||||
online: Online mode
|
|
||||||
on_chunk: Optional callback for each chunk
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (full_response_text, usage_stats)
|
|
||||||
"""
|
|
||||||
response = self.chat(
|
|
||||||
messages=messages,
|
|
||||||
model=model,
|
|
||||||
stream=True,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
online=online,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(response, ChatResponse):
|
|
||||||
# Not actually streaming
|
|
||||||
return response.content or "", response.usage
|
|
||||||
|
|
||||||
full_text = ""
|
|
||||||
usage: Optional[UsageStats] = None
|
|
||||||
|
|
||||||
for chunk in response:
|
|
||||||
if chunk.error:
|
|
||||||
self.logger.error(f"Stream error: {chunk.error}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if chunk.delta_content:
|
|
||||||
full_text += chunk.delta_content
|
|
||||||
if on_chunk:
|
|
||||||
on_chunk(chunk)
|
|
||||||
|
|
||||||
if chunk.usage:
|
|
||||||
usage = chunk.usage
|
|
||||||
|
|
||||||
return full_text, usage
|
|
||||||
|
|
||||||
def get_credits(self) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get account credit information.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Credit info dict or None if unavailable
|
|
||||||
"""
|
|
||||||
return self.provider.get_credits()
|
|
||||||
|
|
||||||
def estimate_cost(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
input_tokens: int,
|
|
||||||
output_tokens: int,
|
|
||||||
) -> float:
|
|
||||||
"""
|
|
||||||
Estimate cost for a completion.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: Model ID
|
|
||||||
input_tokens: Number of input tokens
|
|
||||||
output_tokens: Number of output tokens
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Estimated cost in USD
|
|
||||||
"""
|
|
||||||
if hasattr(self.provider, "estimate_cost"):
|
|
||||||
return self.provider.estimate_cost(model_id, input_tokens, output_tokens)
|
|
||||||
|
|
||||||
# Fallback to default pricing
|
|
||||||
input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
|
|
||||||
output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
|
|
||||||
return input_cost + output_cost
|
|
||||||
|
|
||||||
def set_default_model(self, model_id: str) -> None:
|
|
||||||
"""
|
|
||||||
Set the default model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: Model ID to use as default
|
|
||||||
"""
|
|
||||||
self.default_model = model_id
|
|
||||||
self.logger.info(f"Default model set to: {model_id}")
|
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
|
||||||
"""Clear the provider's model cache."""
|
|
||||||
if hasattr(self.provider, "clear_cache"):
|
|
||||||
self.provider.clear_cache()
|
|
||||||
@@ -1,891 +0,0 @@
|
|||||||
"""
|
|
||||||
Chat session management for oAI.
|
|
||||||
|
|
||||||
This module provides the ChatSession class that manages an interactive
|
|
||||||
chat session including history, state, and message handling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple
|
|
||||||
|
|
||||||
from oai.commands.registry import CommandContext, CommandResult, registry
|
|
||||||
from oai.config.database import Database
|
|
||||||
from oai.config.settings import Settings
|
|
||||||
from oai.constants import (
|
|
||||||
COST_WARNING_THRESHOLD,
|
|
||||||
LOW_CREDIT_AMOUNT,
|
|
||||||
LOW_CREDIT_RATIO,
|
|
||||||
)
|
|
||||||
from oai.core.client import AIClient
|
|
||||||
from oai.mcp.manager import MCPManager
|
|
||||||
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
|
|
||||||
from oai.utils.logging import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SessionStats:
|
|
||||||
"""
|
|
||||||
Statistics for the current session.
|
|
||||||
|
|
||||||
Tracks tokens, costs, and message counts.
|
|
||||||
"""
|
|
||||||
|
|
||||||
total_input_tokens: int = 0
|
|
||||||
total_output_tokens: int = 0
|
|
||||||
total_cost: float = 0.0
|
|
||||||
message_count: int = 0
|
|
||||||
|
|
||||||
@property
|
|
||||||
def total_tokens(self) -> int:
|
|
||||||
"""Get total token count."""
|
|
||||||
return self.total_input_tokens + self.total_output_tokens
|
|
||||||
|
|
||||||
def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None:
|
|
||||||
"""
|
|
||||||
Add usage stats from a response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
usage: Usage statistics
|
|
||||||
cost: Cost if not in usage
|
|
||||||
"""
|
|
||||||
if usage:
|
|
||||||
self.total_input_tokens += usage.prompt_tokens
|
|
||||||
self.total_output_tokens += usage.completion_tokens
|
|
||||||
if usage.total_cost_usd:
|
|
||||||
self.total_cost += usage.total_cost_usd
|
|
||||||
else:
|
|
||||||
self.total_cost += cost
|
|
||||||
else:
|
|
||||||
self.total_cost += cost
|
|
||||||
self.message_count += 1
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class HistoryEntry:
|
|
||||||
"""
|
|
||||||
A single entry in the conversation history.
|
|
||||||
|
|
||||||
Stores the user prompt, assistant response, and metrics.
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt: str
|
|
||||||
response: str
|
|
||||||
prompt_tokens: int = 0
|
|
||||||
completion_tokens: int = 0
|
|
||||||
msg_cost: float = 0.0
|
|
||||||
timestamp: Optional[float] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""Convert to dictionary format."""
|
|
||||||
return {
|
|
||||||
"prompt": self.prompt,
|
|
||||||
"response": self.response,
|
|
||||||
"prompt_tokens": self.prompt_tokens,
|
|
||||||
"completion_tokens": self.completion_tokens,
|
|
||||||
"msg_cost": self.msg_cost,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class ChatSession:
|
|
||||||
"""
|
|
||||||
Manages an interactive chat session.
|
|
||||||
|
|
||||||
Handles conversation history, state management, command processing,
|
|
||||||
and communication with the AI client.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
client: AI client for API requests
|
|
||||||
settings: Application settings
|
|
||||||
mcp_manager: MCP manager for file/database access
|
|
||||||
history: Conversation history
|
|
||||||
stats: Session statistics
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
client: AIClient,
|
|
||||||
settings: Settings,
|
|
||||||
mcp_manager: Optional[MCPManager] = None,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize a chat session.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
client: AI client instance
|
|
||||||
settings: Application settings
|
|
||||||
mcp_manager: Optional MCP manager
|
|
||||||
"""
|
|
||||||
self.client = client
|
|
||||||
self.settings = settings
|
|
||||||
self.mcp_manager = mcp_manager
|
|
||||||
self.db = Database()
|
|
||||||
|
|
||||||
self.history: List[HistoryEntry] = []
|
|
||||||
self.stats = SessionStats()
|
|
||||||
|
|
||||||
# Session state
|
|
||||||
self.system_prompt: str = settings.effective_system_prompt
|
|
||||||
self.memory_enabled: bool = True
|
|
||||||
self.memory_start_index: int = 0
|
|
||||||
self.online_enabled: bool = settings.default_online_mode
|
|
||||||
self.middle_out_enabled: bool = False
|
|
||||||
self.session_max_token: int = 0
|
|
||||||
self.current_index: int = 0
|
|
||||||
|
|
||||||
# Selected model
|
|
||||||
self.selected_model: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
self.logger = get_logger()
|
|
||||||
|
|
||||||
def get_context(self) -> CommandContext:
|
|
||||||
"""
|
|
||||||
Get the current command context.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
CommandContext with current session state
|
|
||||||
"""
|
|
||||||
return CommandContext(
|
|
||||||
settings=self.settings,
|
|
||||||
provider=self.client.provider,
|
|
||||||
mcp_manager=self.mcp_manager,
|
|
||||||
selected_model_raw=self.selected_model,
|
|
||||||
session_history=[e.to_dict() for e in self.history],
|
|
||||||
session_system_prompt=self.system_prompt,
|
|
||||||
memory_enabled=self.memory_enabled,
|
|
||||||
memory_start_index=self.memory_start_index,
|
|
||||||
online_enabled=self.online_enabled,
|
|
||||||
middle_out_enabled=self.middle_out_enabled,
|
|
||||||
session_max_token=self.session_max_token,
|
|
||||||
total_input_tokens=self.stats.total_input_tokens,
|
|
||||||
total_output_tokens=self.stats.total_output_tokens,
|
|
||||||
total_cost=self.stats.total_cost,
|
|
||||||
message_count=self.stats.message_count,
|
|
||||||
current_index=self.current_index,
|
|
||||||
)
|
|
||||||
|
|
||||||
def set_model(self, model: Dict[str, Any]) -> None:
|
|
||||||
"""
|
|
||||||
Set the selected model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Raw model dictionary
|
|
||||||
"""
|
|
||||||
self.selected_model = model
|
|
||||||
self.client.set_default_model(model["id"])
|
|
||||||
self.logger.info(f"Model selected: {model['id']}")
|
|
||||||
|
|
||||||
def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Build the messages array for an API request.
|
|
||||||
|
|
||||||
Includes system prompt, history (if memory enabled), and current input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_input: Current user input
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of message dictionaries
|
|
||||||
"""
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
# Add system prompt
|
|
||||||
if self.system_prompt:
|
|
||||||
messages.append({"role": "system", "content": self.system_prompt})
|
|
||||||
|
|
||||||
# Add database context if in database mode
|
|
||||||
if self.mcp_manager and self.mcp_manager.enabled:
|
|
||||||
if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None:
|
|
||||||
db = self.mcp_manager.databases[self.mcp_manager.selected_db_index]
|
|
||||||
db_context = (
|
|
||||||
f"You are connected to SQLite database: {db['name']}\n"
|
|
||||||
f"Available tables: {', '.join(db['tables'])}\n\n"
|
|
||||||
"Use inspect_database, search_database, or query_database tools. "
|
|
||||||
"All queries are read-only."
|
|
||||||
)
|
|
||||||
messages.append({"role": "system", "content": db_context})
|
|
||||||
|
|
||||||
# Add history if memory enabled
|
|
||||||
if self.memory_enabled:
|
|
||||||
for i in range(self.memory_start_index, len(self.history)):
|
|
||||||
entry = self.history[i]
|
|
||||||
messages.append({"role": "user", "content": entry.prompt})
|
|
||||||
messages.append({"role": "assistant", "content": entry.response})
|
|
||||||
|
|
||||||
# Add current message
|
|
||||||
messages.append({"role": "user", "content": user_input})
|
|
||||||
|
|
||||||
return messages
|
|
||||||
|
|
||||||
def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]:
|
|
||||||
"""
|
|
||||||
Get MCP tool definitions if available.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of tool schemas or None
|
|
||||||
"""
|
|
||||||
if not self.mcp_manager or not self.mcp_manager.enabled:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not self.selected_model:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Check if model supports tools
|
|
||||||
supported_params = self.selected_model.get("supported_parameters", [])
|
|
||||||
if "tools" not in supported_params and "functions" not in supported_params:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return self.mcp_manager.get_tools_schema()
|
|
||||||
|
|
||||||
async def execute_tool(
|
|
||||||
self,
|
|
||||||
tool_name: str,
|
|
||||||
tool_args: Dict[str, Any],
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Execute an MCP tool.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_name: Name of the tool
|
|
||||||
tool_args: Tool arguments
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tool execution result
|
|
||||||
"""
|
|
||||||
if not self.mcp_manager:
|
|
||||||
return {"error": "MCP not available"}
|
|
||||||
|
|
||||||
return await self.mcp_manager.call_tool(tool_name, **tool_args)
|
|
||||||
|
|
||||||
def send_message(
|
|
||||||
self,
|
|
||||||
user_input: str,
|
|
||||||
stream: bool = True,
|
|
||||||
on_stream_chunk: Optional[Callable[[str], None]] = None,
|
|
||||||
) -> Tuple[str, Optional[UsageStats], float]:
|
|
||||||
"""
|
|
||||||
Send a message and get a response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_input: User's input text
|
|
||||||
stream: Whether to stream the response
|
|
||||||
on_stream_chunk: Callback for stream chunks
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (response_text, usage_stats, response_time)
|
|
||||||
"""
|
|
||||||
if not self.selected_model:
|
|
||||||
raise ValueError("No model selected")
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
messages = self.build_api_messages(user_input)
|
|
||||||
|
|
||||||
# Get MCP tools
|
|
||||||
tools = self.get_mcp_tools()
|
|
||||||
if tools:
|
|
||||||
# Disable streaming when tools are present
|
|
||||||
stream = False
|
|
||||||
|
|
||||||
# Build request parameters
|
|
||||||
model_id = self.selected_model["id"]
|
|
||||||
if self.online_enabled:
|
|
||||||
if hasattr(self.client.provider, "get_effective_model_id"):
|
|
||||||
model_id = self.client.provider.get_effective_model_id(model_id, True)
|
|
||||||
|
|
||||||
transforms = ["middle-out"] if self.middle_out_enabled else None
|
|
||||||
|
|
||||||
max_tokens = None
|
|
||||||
if self.session_max_token > 0:
|
|
||||||
max_tokens = self.session_max_token
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
# Use tool handling flow
|
|
||||||
response = self._send_with_tools(
|
|
||||||
messages=messages,
|
|
||||||
model_id=model_id,
|
|
||||||
tools=tools,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
transforms=transforms,
|
|
||||||
)
|
|
||||||
response_time = time.time() - start_time
|
|
||||||
return response.content or "", response.usage, response_time
|
|
||||||
|
|
||||||
elif stream:
|
|
||||||
# Use streaming flow
|
|
||||||
full_text, usage = self._stream_response(
|
|
||||||
messages=messages,
|
|
||||||
model_id=model_id,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
transforms=transforms,
|
|
||||||
on_chunk=on_stream_chunk,
|
|
||||||
)
|
|
||||||
response_time = time.time() - start_time
|
|
||||||
return full_text, usage, response_time
|
|
||||||
|
|
||||||
else:
|
|
||||||
# Non-streaming request
|
|
||||||
response = self.client.chat(
|
|
||||||
messages=messages,
|
|
||||||
model=model_id,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
transforms=transforms,
|
|
||||||
)
|
|
||||||
response_time = time.time() - start_time
|
|
||||||
|
|
||||||
if isinstance(response, ChatResponse):
|
|
||||||
return response.content or "", response.usage, response_time
|
|
||||||
else:
|
|
||||||
return "", None, response_time
|
|
||||||
|
|
||||||
def _send_with_tools(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, Any]],
|
|
||||||
model_id: str,
|
|
||||||
tools: List[Dict[str, Any]],
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
transforms: Optional[List[str]] = None,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""
|
|
||||||
Send a request with tool call handling.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: API messages
|
|
||||||
model_id: Model ID
|
|
||||||
tools: Tool definitions
|
|
||||||
max_tokens: Max tokens
|
|
||||||
transforms: Transforms list
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Final ChatResponse
|
|
||||||
"""
|
|
||||||
max_loops = 5
|
|
||||||
loop_count = 0
|
|
||||||
api_messages = list(messages)
|
|
||||||
|
|
||||||
while loop_count < max_loops:
|
|
||||||
response = self.client.chat(
|
|
||||||
messages=api_messages,
|
|
||||||
model=model_id,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
tools=tools,
|
|
||||||
tool_choice="auto",
|
|
||||||
transforms=transforms,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(response, ChatResponse):
|
|
||||||
raise ValueError("Expected ChatResponse")
|
|
||||||
|
|
||||||
tool_calls = response.tool_calls
|
|
||||||
if not tool_calls:
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Tool calls requested by AI
|
|
||||||
|
|
||||||
tool_results = []
|
|
||||||
for tc in tool_calls:
|
|
||||||
try:
|
|
||||||
args = json.loads(tc.function.arguments)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
self.logger.error(f"Failed to parse tool arguments: {e}")
|
|
||||||
tool_results.append({
|
|
||||||
"tool_call_id": tc.id,
|
|
||||||
"role": "tool",
|
|
||||||
"name": tc.function.name,
|
|
||||||
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Display tool call
|
|
||||||
args_display = ", ".join(
|
|
||||||
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
|
|
||||||
for k, v in args.items()
|
|
||||||
)
|
|
||||||
# Executing tool: {tc.function.name}
|
|
||||||
|
|
||||||
# Execute tool
|
|
||||||
result = asyncio.run(self.execute_tool(tc.function.name, args))
|
|
||||||
|
|
||||||
if "error" in result:
|
|
||||||
# Tool execution error logged
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# Tool execution successful
|
|
||||||
pass
|
|
||||||
|
|
||||||
tool_results.append({
|
|
||||||
"tool_call_id": tc.id,
|
|
||||||
"role": "tool",
|
|
||||||
"name": tc.function.name,
|
|
||||||
"content": json.dumps(result),
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add assistant message with tool calls
|
|
||||||
api_messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": response.content,
|
|
||||||
"tool_calls": [
|
|
||||||
{
|
|
||||||
"id": tc.id,
|
|
||||||
"type": tc.type,
|
|
||||||
"function": {
|
|
||||||
"name": tc.function.name,
|
|
||||||
"arguments": tc.function.arguments,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tc in tool_calls
|
|
||||||
],
|
|
||||||
})
|
|
||||||
api_messages.extend(tool_results)
|
|
||||||
|
|
||||||
# Processing tool results
|
|
||||||
loop_count += 1
|
|
||||||
|
|
||||||
self.logger.warning(f"Reached max tool loops ({max_loops})")
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def _stream_response(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, Any]],
|
|
||||||
model_id: str,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
transforms: Optional[List[str]] = None,
|
|
||||||
on_chunk: Optional[Callable[[str], None]] = None,
|
|
||||||
) -> Tuple[str, Optional[UsageStats]]:
|
|
||||||
"""
|
|
||||||
Stream a response with live display.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: API messages
|
|
||||||
model_id: Model ID
|
|
||||||
max_tokens: Max tokens
|
|
||||||
transforms: Transforms
|
|
||||||
on_chunk: Callback for chunks
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (full_text, usage)
|
|
||||||
"""
|
|
||||||
response = self.client.chat(
|
|
||||||
messages=messages,
|
|
||||||
model=model_id,
|
|
||||||
stream=True,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
transforms=transforms,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(response, ChatResponse):
|
|
||||||
return response.content or "", response.usage
|
|
||||||
|
|
||||||
full_text = ""
|
|
||||||
usage: Optional[UsageStats] = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
for chunk in response:
|
|
||||||
if chunk.error:
|
|
||||||
self.logger.error(f"Stream error: {chunk.error}")
|
|
||||||
break
|
|
||||||
|
|
||||||
if chunk.delta_content:
|
|
||||||
full_text += chunk.delta_content
|
|
||||||
if on_chunk:
|
|
||||||
on_chunk(chunk.delta_content)
|
|
||||||
|
|
||||||
if chunk.usage:
|
|
||||||
usage = chunk.usage
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
self.logger.info("Streaming interrupted")
|
|
||||||
return "", None
|
|
||||||
|
|
||||||
return full_text, usage
|
|
||||||
|
|
||||||
# ========== ASYNC METHODS FOR TUI ==========
|
|
||||||
|
|
||||||
async def send_message_async(
|
|
||||||
self,
|
|
||||||
user_input: str,
|
|
||||||
stream: bool = True,
|
|
||||||
) -> AsyncIterator[StreamChunk]:
|
|
||||||
"""
|
|
||||||
Async version of send_message for Textual TUI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_input: User's input text
|
|
||||||
stream: Whether to stream the response
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
StreamChunk objects for progressive display
|
|
||||||
"""
|
|
||||||
if not self.selected_model:
|
|
||||||
raise ValueError("No model selected")
|
|
||||||
|
|
||||||
messages = self.build_api_messages(user_input)
|
|
||||||
tools = self.get_mcp_tools()
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
# Disable streaming when tools are present
|
|
||||||
stream = False
|
|
||||||
|
|
||||||
model_id = self.selected_model["id"]
|
|
||||||
if self.online_enabled:
|
|
||||||
if hasattr(self.client.provider, "get_effective_model_id"):
|
|
||||||
model_id = self.client.provider.get_effective_model_id(model_id, True)
|
|
||||||
|
|
||||||
transforms = ["middle-out"] if self.middle_out_enabled else None
|
|
||||||
max_tokens = None
|
|
||||||
if self.session_max_token > 0:
|
|
||||||
max_tokens = self.session_max_token
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
# Use async tool handling flow
|
|
||||||
async for chunk in self._send_with_tools_async(
|
|
||||||
messages=messages,
|
|
||||||
model_id=model_id,
|
|
||||||
tools=tools,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
transforms=transforms,
|
|
||||||
):
|
|
||||||
yield chunk
|
|
||||||
elif stream:
|
|
||||||
# Use async streaming flow
|
|
||||||
async for chunk in self._stream_response_async(
|
|
||||||
messages=messages,
|
|
||||||
model_id=model_id,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
transforms=transforms,
|
|
||||||
):
|
|
||||||
yield chunk
|
|
||||||
else:
|
|
||||||
# Non-streaming request
|
|
||||||
response = self.client.chat(
|
|
||||||
messages=messages,
|
|
||||||
model=model_id,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
transforms=transforms,
|
|
||||||
)
|
|
||||||
if isinstance(response, ChatResponse):
|
|
||||||
# Yield single chunk with complete response
|
|
||||||
chunk = StreamChunk(
|
|
||||||
id="",
|
|
||||||
delta_content=response.content,
|
|
||||||
usage=response.usage,
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
async def _send_with_tools_async(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, Any]],
|
|
||||||
model_id: str,
|
|
||||||
tools: List[Dict[str, Any]],
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
transforms: Optional[List[str]] = None,
|
|
||||||
) -> AsyncIterator[StreamChunk]:
|
|
||||||
"""
|
|
||||||
Async version of _send_with_tools for TUI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: API messages
|
|
||||||
model_id: Model ID
|
|
||||||
tools: Tool definitions
|
|
||||||
max_tokens: Max tokens
|
|
||||||
transforms: Transforms list
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
StreamChunk objects including tool call notifications
|
|
||||||
"""
|
|
||||||
max_loops = 5
|
|
||||||
loop_count = 0
|
|
||||||
api_messages = list(messages)
|
|
||||||
|
|
||||||
while loop_count < max_loops:
|
|
||||||
response = self.client.chat(
|
|
||||||
messages=api_messages,
|
|
||||||
model=model_id,
|
|
||||||
stream=False,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
tools=tools,
|
|
||||||
tool_choice="auto",
|
|
||||||
transforms=transforms,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(response, ChatResponse):
|
|
||||||
raise ValueError("Expected ChatResponse")
|
|
||||||
|
|
||||||
tool_calls = response.tool_calls
|
|
||||||
if not tool_calls:
|
|
||||||
# Final response, yield it
|
|
||||||
chunk = StreamChunk(
|
|
||||||
id="",
|
|
||||||
delta_content=response.content,
|
|
||||||
usage=response.usage,
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
yield chunk
|
|
||||||
return
|
|
||||||
|
|
||||||
# Yield notification about tool calls
|
|
||||||
tool_notification = f"\n🔧 AI requesting {len(tool_calls)} tool call(s)...\n"
|
|
||||||
yield StreamChunk(id="", delta_content=tool_notification, usage=None, error=None)
|
|
||||||
|
|
||||||
tool_results = []
|
|
||||||
for tc in tool_calls:
|
|
||||||
try:
|
|
||||||
args = json.loads(tc.function.arguments)
|
|
||||||
except json.JSONDecodeError as e:
|
|
||||||
self.logger.error(f"Failed to parse tool arguments: {e}")
|
|
||||||
tool_results.append({
|
|
||||||
"tool_call_id": tc.id,
|
|
||||||
"role": "tool",
|
|
||||||
"name": tc.function.name,
|
|
||||||
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
|
|
||||||
})
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Yield tool call display
|
|
||||||
args_display = ", ".join(
|
|
||||||
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
|
|
||||||
for k, v in args.items()
|
|
||||||
)
|
|
||||||
tool_display = f" → {tc.function.name}({args_display})\n"
|
|
||||||
yield StreamChunk(id="", delta_content=tool_display, usage=None, error=None)
|
|
||||||
|
|
||||||
# Execute tool (await instead of asyncio.run)
|
|
||||||
result = await self.execute_tool(tc.function.name, args)
|
|
||||||
|
|
||||||
if "error" in result:
|
|
||||||
error_msg = f" ✗ Error: {result['error']}\n"
|
|
||||||
yield StreamChunk(id="", delta_content=error_msg, usage=None, error=None)
|
|
||||||
else:
|
|
||||||
success_msg = self._format_tool_success(tc.function.name, result)
|
|
||||||
yield StreamChunk(id="", delta_content=success_msg, usage=None, error=None)
|
|
||||||
|
|
||||||
tool_results.append({
|
|
||||||
"tool_call_id": tc.id,
|
|
||||||
"role": "tool",
|
|
||||||
"name": tc.function.name,
|
|
||||||
"content": json.dumps(result),
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add assistant message with tool calls
|
|
||||||
api_messages.append({
|
|
||||||
"role": "assistant",
|
|
||||||
"content": response.content,
|
|
||||||
"tool_calls": [
|
|
||||||
{
|
|
||||||
"id": tc.id,
|
|
||||||
"type": tc.type,
|
|
||||||
"function": {
|
|
||||||
"name": tc.function.name,
|
|
||||||
"arguments": tc.function.arguments,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tc in tool_calls
|
|
||||||
],
|
|
||||||
})
|
|
||||||
|
|
||||||
# Add tool results
|
|
||||||
api_messages.extend(tool_results)
|
|
||||||
loop_count += 1
|
|
||||||
|
|
||||||
# Max loops reached
|
|
||||||
yield StreamChunk(
|
|
||||||
id="",
|
|
||||||
delta_content="\n⚠️ Maximum tool call loops reached\n",
|
|
||||||
usage=None,
|
|
||||||
error="Max loops reached"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _format_tool_success(self, tool_name: str, result: Dict[str, Any]) -> str:
|
|
||||||
"""Format a success message for a tool call."""
|
|
||||||
if tool_name == "search_files":
|
|
||||||
count = result.get("count", 0)
|
|
||||||
return f" ✓ Found {count} file(s)\n"
|
|
||||||
elif tool_name == "read_file":
|
|
||||||
size = result.get("size", 0)
|
|
||||||
truncated = " (truncated)" if result.get("truncated") else ""
|
|
||||||
return f" ✓ Read {size} bytes{truncated}\n"
|
|
||||||
elif tool_name == "list_directory":
|
|
||||||
count = result.get("count", 0)
|
|
||||||
return f" ✓ Listed {count} item(s)\n"
|
|
||||||
elif tool_name == "inspect_database":
|
|
||||||
if "table" in result:
|
|
||||||
return f" ✓ Inspected table: {result['table']}\n"
|
|
||||||
else:
|
|
||||||
return f" ✓ Inspected database ({result.get('table_count', 0)} tables)\n"
|
|
||||||
elif tool_name == "search_database":
|
|
||||||
count = result.get("count", 0)
|
|
||||||
return f" ✓ Found {count} match(es)\n"
|
|
||||||
elif tool_name == "query_database":
|
|
||||||
count = result.get("count", 0)
|
|
||||||
return f" ✓ Query returned {count} row(s)\n"
|
|
||||||
else:
|
|
||||||
return " ✓ Success\n"
|
|
||||||
|
|
||||||
async def _stream_response_async(
|
|
||||||
self,
|
|
||||||
messages: List[Dict[str, Any]],
|
|
||||||
model_id: str,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
transforms: Optional[List[str]] = None,
|
|
||||||
) -> AsyncIterator[StreamChunk]:
|
|
||||||
"""
|
|
||||||
Async version of _stream_response for TUI.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: API messages
|
|
||||||
model_id: Model ID
|
|
||||||
max_tokens: Max tokens
|
|
||||||
transforms: Transforms
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
StreamChunk objects
|
|
||||||
"""
|
|
||||||
response = self.client.chat(
|
|
||||||
messages=messages,
|
|
||||||
model=model_id,
|
|
||||||
stream=True,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
transforms=transforms,
|
|
||||||
)
|
|
||||||
|
|
||||||
if isinstance(response, ChatResponse):
|
|
||||||
# Non-streaming response
|
|
||||||
chunk = StreamChunk(
|
|
||||||
id="",
|
|
||||||
delta_content=response.content,
|
|
||||||
usage=response.usage,
|
|
||||||
error=None,
|
|
||||||
)
|
|
||||||
yield chunk
|
|
||||||
return
|
|
||||||
|
|
||||||
# Stream the response
|
|
||||||
for chunk in response:
|
|
||||||
if chunk.error:
|
|
||||||
yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
|
|
||||||
break
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
# ========== END ASYNC METHODS ==========
|
|
||||||
|
|
||||||
def add_to_history(
|
|
||||||
self,
|
|
||||||
prompt: str,
|
|
||||||
response: str,
|
|
||||||
usage: Optional[UsageStats] = None,
|
|
||||||
cost: float = 0.0,
|
|
||||||
) -> None:
|
|
||||||
"""
|
|
||||||
Add an exchange to the history.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
prompt: User prompt
|
|
||||||
response: Assistant response
|
|
||||||
usage: Usage statistics
|
|
||||||
cost: Cost if not in usage
|
|
||||||
"""
|
|
||||||
entry = HistoryEntry(
|
|
||||||
prompt=prompt,
|
|
||||||
response=response,
|
|
||||||
prompt_tokens=usage.prompt_tokens if usage else 0,
|
|
||||||
completion_tokens=usage.completion_tokens if usage else 0,
|
|
||||||
msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost,
|
|
||||||
timestamp=time.time(),
|
|
||||||
)
|
|
||||||
self.history.append(entry)
|
|
||||||
self.current_index = len(self.history) - 1
|
|
||||||
self.stats.add_usage(usage, cost)
|
|
||||||
|
|
||||||
def save_conversation(self, name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Save the current conversation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name for the saved conversation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if saved successfully
|
|
||||||
"""
|
|
||||||
if not self.history:
|
|
||||||
return False
|
|
||||||
|
|
||||||
data = [e.to_dict() for e in self.history]
|
|
||||||
self.db.save_conversation(name, data)
|
|
||||||
self.logger.info(f"Saved conversation: {name}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def load_conversation(self, name: str) -> bool:
|
|
||||||
"""
|
|
||||||
Load a saved conversation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Name of the conversation to load
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if loaded successfully
|
|
||||||
"""
|
|
||||||
data = self.db.load_conversation(name)
|
|
||||||
if not data:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.history.clear()
|
|
||||||
for entry_dict in data:
|
|
||||||
self.history.append(HistoryEntry(
|
|
||||||
prompt=entry_dict.get("prompt", ""),
|
|
||||||
response=entry_dict.get("response", ""),
|
|
||||||
prompt_tokens=entry_dict.get("prompt_tokens", 0),
|
|
||||||
completion_tokens=entry_dict.get("completion_tokens", 0),
|
|
||||||
msg_cost=entry_dict.get("msg_cost", 0.0),
|
|
||||||
))
|
|
||||||
|
|
||||||
self.current_index = len(self.history) - 1
|
|
||||||
self.memory_start_index = 0
|
|
||||||
self.stats = SessionStats() # Reset stats for loaded conversation
|
|
||||||
self.logger.info(f"Loaded conversation: {name}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Reset the session state."""
|
|
||||||
self.history.clear()
|
|
||||||
self.stats = SessionStats()
|
|
||||||
self.system_prompt = ""
|
|
||||||
self.memory_start_index = 0
|
|
||||||
self.current_index = 0
|
|
||||||
self.logger.info("Session reset")
|
|
||||||
|
|
||||||
def check_warnings(self) -> List[str]:
|
|
||||||
"""
|
|
||||||
Check for cost and credit warnings.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of warning messages
|
|
||||||
"""
|
|
||||||
warnings = []
|
|
||||||
|
|
||||||
# Check last message cost
|
|
||||||
if self.history:
|
|
||||||
last_cost = self.history[-1].msg_cost
|
|
||||||
threshold = self.settings.cost_warning_threshold
|
|
||||||
if last_cost > threshold:
|
|
||||||
warnings.append(
|
|
||||||
f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check credits
|
|
||||||
credits = self.client.get_credits()
|
|
||||||
if credits:
|
|
||||||
left = credits.get("credits_left", 0)
|
|
||||||
total = credits.get("total_credits", 0)
|
|
||||||
|
|
||||||
if left < LOW_CREDIT_AMOUNT:
|
|
||||||
warnings.append(f"Low credits: ${left:.2f} remaining!")
|
|
||||||
elif total > 0 and left < total * LOW_CREDIT_RATIO:
|
|
||||||
warnings.append(f"Credits low: less than 10% remaining (${left:.2f})")
|
|
||||||
|
|
||||||
return warnings
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
"""
|
|
||||||
Model Context Protocol (MCP) integration for oAI.
|
|
||||||
|
|
||||||
This package provides filesystem and database access capabilities
|
|
||||||
through the MCP standard, allowing AI models to interact with
|
|
||||||
local files and SQLite databases safely.
|
|
||||||
|
|
||||||
Key components:
|
|
||||||
- MCPManager: High-level manager for MCP operations
|
|
||||||
- MCPFilesystemServer: Filesystem and database access implementation
|
|
||||||
- GitignoreParser: Pattern matching for .gitignore support
|
|
||||||
- SQLiteQueryValidator: Query safety validation
|
|
||||||
- CrossPlatformMCPConfig: OS-specific configuration
|
|
||||||
"""
|
|
||||||
|
|
||||||
from oai.mcp.manager import MCPManager
|
|
||||||
from oai.mcp.server import MCPFilesystemServer
|
|
||||||
from oai.mcp.gitignore import GitignoreParser
|
|
||||||
from oai.mcp.validators import SQLiteQueryValidator
|
|
||||||
from oai.mcp.platform import CrossPlatformMCPConfig
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"MCPManager",
|
|
||||||
"MCPFilesystemServer",
|
|
||||||
"GitignoreParser",
|
|
||||||
"SQLiteQueryValidator",
|
|
||||||
"CrossPlatformMCPConfig",
|
|
||||||
]
|
|
||||||
@@ -1,166 +0,0 @@
|
|||||||
"""
|
|
||||||
Gitignore pattern parsing for oAI MCP.
|
|
||||||
|
|
||||||
This module implements .gitignore pattern matching to filter files
|
|
||||||
during MCP filesystem operations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import fnmatch
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Tuple
|
|
||||||
|
|
||||||
from oai.utils.logging import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
class GitignoreParser:
|
|
||||||
"""
|
|
||||||
Parse and apply .gitignore patterns.
|
|
||||||
|
|
||||||
Supports standard gitignore syntax including:
|
|
||||||
- Wildcards (*) and double wildcards (**)
|
|
||||||
- Directory-only patterns (ending with /)
|
|
||||||
- Negation patterns (starting with !)
|
|
||||||
- Comments (lines starting with #)
|
|
||||||
|
|
||||||
Patterns are applied relative to the directory containing
|
|
||||||
the .gitignore file.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize an empty pattern collection."""
|
|
||||||
# List of (pattern, is_negation, source_dir)
|
|
||||||
self.patterns: List[Tuple[str, bool, Path]] = []
|
|
||||||
|
|
||||||
def add_gitignore(self, gitignore_path: Path) -> None:
|
|
||||||
"""
|
|
||||||
Parse and add patterns from a .gitignore file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
gitignore_path: Path to the .gitignore file
|
|
||||||
"""
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
if not gitignore_path.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
source_dir = gitignore_path.parent
|
|
||||||
|
|
||||||
with open(gitignore_path, "r", encoding="utf-8") as f:
|
|
||||||
for line_num, line in enumerate(f, 1):
|
|
||||||
line = line.rstrip("\n\r")
|
|
||||||
|
|
||||||
# Skip empty lines and comments
|
|
||||||
if not line or line.startswith("#"):
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check for negation pattern
|
|
||||||
is_negation = line.startswith("!")
|
|
||||||
if is_negation:
|
|
||||||
line = line[1:]
|
|
||||||
|
|
||||||
# Remove leading slash (make relative to gitignore location)
|
|
||||||
if line.startswith("/"):
|
|
||||||
line = line[1:]
|
|
||||||
|
|
||||||
self.patterns.append((line, is_negation, source_dir))
|
|
||||||
|
|
||||||
logger.debug(
|
|
||||||
f"Loaded {len(self.patterns)} patterns from {gitignore_path}"
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Error reading {gitignore_path}: {e}")
|
|
||||||
|
|
||||||
def should_ignore(self, path: Path) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a path should be ignored based on gitignore patterns.
|
|
||||||
|
|
||||||
Patterns are evaluated in order, with later patterns overriding
|
|
||||||
earlier ones. Negation patterns (starting with !) un-ignore
|
|
||||||
previously matched paths.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Path to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the path should be ignored
|
|
||||||
"""
|
|
||||||
if not self.patterns:
|
|
||||||
return False
|
|
||||||
|
|
||||||
ignored = False
|
|
||||||
|
|
||||||
for pattern, is_negation, source_dir in self.patterns:
|
|
||||||
# Only apply pattern if path is under the source directory
|
|
||||||
try:
|
|
||||||
rel_path = path.relative_to(source_dir)
|
|
||||||
except ValueError:
|
|
||||||
# Path is not relative to this gitignore's directory
|
|
||||||
continue
|
|
||||||
|
|
||||||
rel_path_str = str(rel_path)
|
|
||||||
|
|
||||||
# Check if pattern matches
|
|
||||||
if self._match_pattern(pattern, rel_path_str, path.is_dir()):
|
|
||||||
if is_negation:
|
|
||||||
ignored = False # Negation patterns un-ignore
|
|
||||||
else:
|
|
||||||
ignored = True
|
|
||||||
|
|
||||||
return ignored
|
|
||||||
|
|
||||||
def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool:
|
|
||||||
"""
|
|
||||||
Match a gitignore pattern against a path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
pattern: The gitignore pattern
|
|
||||||
path: The relative path string to match
|
|
||||||
is_dir: Whether the path is a directory
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the pattern matches
|
|
||||||
"""
|
|
||||||
# Directory-only pattern (ends with /)
|
|
||||||
if pattern.endswith("/"):
|
|
||||||
if not is_dir:
|
|
||||||
return False
|
|
||||||
pattern = pattern[:-1]
|
|
||||||
|
|
||||||
# Handle ** patterns (matches any number of directories)
|
|
||||||
if "**" in pattern:
|
|
||||||
pattern_parts = pattern.split("**")
|
|
||||||
if len(pattern_parts) == 2:
|
|
||||||
prefix, suffix = pattern_parts
|
|
||||||
|
|
||||||
# Match if path starts with prefix and ends with suffix
|
|
||||||
if prefix:
|
|
||||||
if not path.startswith(prefix.rstrip("/")):
|
|
||||||
return False
|
|
||||||
if suffix:
|
|
||||||
suffix = suffix.lstrip("/")
|
|
||||||
if not (path.endswith(suffix) or f"/{suffix}" in path):
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Direct match using fnmatch
|
|
||||||
if fnmatch.fnmatch(path, pattern):
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Match as subdirectory pattern (pattern without / matches in any directory)
|
|
||||||
if "/" not in pattern:
|
|
||||||
parts = path.split("/")
|
|
||||||
if any(fnmatch.fnmatch(part, pattern) for part in parts):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def clear(self) -> None:
|
|
||||||
"""Clear all loaded patterns."""
|
|
||||||
self.patterns = []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def pattern_count(self) -> int:
|
|
||||||
"""Get the number of loaded patterns."""
|
|
||||||
return len(self.patterns)
|
|
||||||
1365
oai/mcp/manager.py
1365
oai/mcp/manager.py
File diff suppressed because it is too large
Load Diff
@@ -1,228 +0,0 @@
|
|||||||
"""
|
|
||||||
Cross-platform MCP configuration for oAI.
|
|
||||||
|
|
||||||
This module handles OS-specific configuration, path handling,
|
|
||||||
and security checks for the MCP filesystem server.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import platform
|
|
||||||
import subprocess
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
|
|
||||||
from oai.constants import SYSTEM_DIRS_BLACKLIST
|
|
||||||
from oai.utils.logging import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
class CrossPlatformMCPConfig:
|
|
||||||
"""
|
|
||||||
Handle OS-specific MCP configuration.
|
|
||||||
|
|
||||||
Provides methods for path normalization, security validation,
|
|
||||||
and OS-specific default directories.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
system: Operating system name
|
|
||||||
is_macos: Whether running on macOS
|
|
||||||
is_linux: Whether running on Linux
|
|
||||||
is_windows: Whether running on Windows
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize platform detection."""
|
|
||||||
self.system = platform.system()
|
|
||||||
self.is_macos = self.system == "Darwin"
|
|
||||||
self.is_linux = self.system == "Linux"
|
|
||||||
self.is_windows = self.system == "Windows"
|
|
||||||
|
|
||||||
logger = get_logger()
|
|
||||||
logger.info(f"Detected OS: {self.system}")
|
|
||||||
|
|
||||||
def get_default_allowed_dirs(self) -> List[Path]:
|
|
||||||
"""
|
|
||||||
Get safe default directories for the current OS.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of default directories that are safe to access
|
|
||||||
"""
|
|
||||||
home = Path.home()
|
|
||||||
|
|
||||||
if self.is_macos:
|
|
||||||
return [
|
|
||||||
home / "Documents",
|
|
||||||
home / "Desktop",
|
|
||||||
home / "Downloads",
|
|
||||||
]
|
|
||||||
|
|
||||||
elif self.is_linux:
|
|
||||||
dirs = [home / "Documents"]
|
|
||||||
|
|
||||||
# Try to get XDG directories
|
|
||||||
try:
|
|
||||||
for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]:
|
|
||||||
result = subprocess.run(
|
|
||||||
["xdg-user-dir", xdg_dir],
|
|
||||||
capture_output=True,
|
|
||||||
text=True,
|
|
||||||
timeout=1
|
|
||||||
)
|
|
||||||
if result.returncode == 0:
|
|
||||||
dir_path = Path(result.stdout.strip())
|
|
||||||
if dir_path.exists():
|
|
||||||
dirs.append(dir_path)
|
|
||||||
except (subprocess.TimeoutExpired, FileNotFoundError):
|
|
||||||
# Fallback to standard locations
|
|
||||||
dirs.extend([
|
|
||||||
home / "Desktop",
|
|
||||||
home / "Downloads",
|
|
||||||
])
|
|
||||||
|
|
||||||
return list(set(dirs))
|
|
||||||
|
|
||||||
elif self.is_windows:
|
|
||||||
return [
|
|
||||||
home / "Documents",
|
|
||||||
home / "Desktop",
|
|
||||||
home / "Downloads",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Fallback for unknown OS
|
|
||||||
return [home]
|
|
||||||
|
|
||||||
def get_python_command(self) -> str:
|
|
||||||
"""
|
|
||||||
Get the Python executable path.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Path to the Python executable
|
|
||||||
"""
|
|
||||||
import sys
|
|
||||||
return sys.executable
|
|
||||||
|
|
||||||
def get_filesystem_warning(self) -> str:
|
|
||||||
"""
|
|
||||||
Get OS-specific security warning message.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Warning message for the current OS
|
|
||||||
"""
|
|
||||||
if self.is_macos:
|
|
||||||
return """
|
|
||||||
Note: macOS Security
|
|
||||||
The Filesystem MCP server needs access to your selected folder.
|
|
||||||
You may see a security prompt - click 'Allow' to proceed.
|
|
||||||
(System Settings > Privacy & Security > Files and Folders)
|
|
||||||
"""
|
|
||||||
elif self.is_linux:
|
|
||||||
return """
|
|
||||||
Note: Linux Security
|
|
||||||
The Filesystem MCP server will access your selected folder.
|
|
||||||
Ensure oAI has appropriate file permissions.
|
|
||||||
"""
|
|
||||||
elif self.is_windows:
|
|
||||||
return """
|
|
||||||
Note: Windows Security
|
|
||||||
The Filesystem MCP server will access your selected folder.
|
|
||||||
You may need to grant file access permissions.
|
|
||||||
"""
|
|
||||||
return ""
|
|
||||||
|
|
||||||
def normalize_path(self, path: str) -> Path:
|
|
||||||
"""
|
|
||||||
Normalize a path for the current OS.
|
|
||||||
|
|
||||||
Expands user directory (~) and resolves to absolute path.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Path string to normalize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Normalized absolute Path
|
|
||||||
"""
|
|
||||||
return Path(os.path.expanduser(path)).resolve()
|
|
||||||
|
|
||||||
def is_system_directory(self, path: Path) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a path is a protected system directory.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
path: Path to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the path is a system directory
|
|
||||||
"""
|
|
||||||
path_str = str(path)
|
|
||||||
for blocked in SYSTEM_DIRS_BLACKLIST:
|
|
||||||
if path_str.startswith(blocked):
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a path is within allowed directories.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
requested_path: Path being requested
|
|
||||||
allowed_dirs: List of allowed parent directories
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the path is within an allowed directory
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
requested = requested_path.resolve()
|
|
||||||
|
|
||||||
for allowed in allowed_dirs:
|
|
||||||
try:
|
|
||||||
allowed_resolved = allowed.resolve()
|
|
||||||
requested.relative_to(allowed_resolved)
|
|
||||||
return True
|
|
||||||
except ValueError:
|
|
||||||
continue
|
|
||||||
|
|
||||||
return False
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_folder_stats(self, folder: Path) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Get statistics for a folder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
folder: Path to the folder
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary with folder statistics:
|
|
||||||
- exists: Whether the folder exists
|
|
||||||
- file_count: Number of files (if exists)
|
|
||||||
- total_size: Total size in bytes (if exists)
|
|
||||||
- size_mb: Size in megabytes (if exists)
|
|
||||||
- error: Error message (if any)
|
|
||||||
"""
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
try:
|
|
||||||
if not folder.exists() or not folder.is_dir():
|
|
||||||
return {"exists": False}
|
|
||||||
|
|
||||||
file_count = 0
|
|
||||||
total_size = 0
|
|
||||||
|
|
||||||
for item in folder.rglob("*"):
|
|
||||||
if item.is_file():
|
|
||||||
file_count += 1
|
|
||||||
try:
|
|
||||||
total_size += item.stat().st_size
|
|
||||||
except (OSError, PermissionError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {
|
|
||||||
"exists": True,
|
|
||||||
"file_count": file_count,
|
|
||||||
"total_size": total_size,
|
|
||||||
"size_mb": total_size / (1024 * 1024),
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error getting folder stats for {folder}: {e}")
|
|
||||||
return {"exists": False, "error": str(e)}
|
|
||||||
1368
oai/mcp/server.py
1368
oai/mcp/server.py
File diff suppressed because it is too large
Load Diff
@@ -1,123 +0,0 @@
|
|||||||
"""
|
|
||||||
Query validation for oAI MCP database operations.
|
|
||||||
|
|
||||||
This module provides safety validation for SQL queries to ensure
|
|
||||||
only read-only operations are executed.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from oai.constants import DANGEROUS_SQL_KEYWORDS
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteQueryValidator:
|
|
||||||
"""
|
|
||||||
Validate SQLite queries for read-only safety.
|
|
||||||
|
|
||||||
Ensures that only SELECT queries (including CTEs) are allowed
|
|
||||||
and blocks potentially dangerous operations like INSERT, UPDATE,
|
|
||||||
DELETE, DROP, etc.
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def is_safe_query(query: str) -> Tuple[bool, str]:
|
|
||||||
"""
|
|
||||||
Validate that a query is a safe read-only SELECT.
|
|
||||||
|
|
||||||
The validation:
|
|
||||||
1. Checks that query starts with SELECT or WITH
|
|
||||||
2. Strips string literals before checking for dangerous keywords
|
|
||||||
3. Blocks any dangerous keywords outside of string literals
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: SQL query string to validate
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (is_safe, error_message)
|
|
||||||
- is_safe: True if the query is safe to execute
|
|
||||||
- error_message: Description of why query is unsafe (empty if safe)
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
>>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users")
|
|
||||||
(True, "")
|
|
||||||
>>> SQLiteQueryValidator.is_safe_query("DELETE FROM users")
|
|
||||||
(False, "Only SELECT queries are allowed...")
|
|
||||||
>>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users")
|
|
||||||
(True, "") # 'DELETE' is inside a string literal
|
|
||||||
"""
|
|
||||||
query_upper = query.strip().upper()
|
|
||||||
|
|
||||||
# Must start with SELECT or WITH (for CTEs)
|
|
||||||
if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")):
|
|
||||||
return False, "Only SELECT queries are allowed (including WITH/CTE)"
|
|
||||||
|
|
||||||
# Remove string literals before checking for dangerous keywords
|
|
||||||
# This prevents false positives when keywords appear in data
|
|
||||||
query_no_strings = re.sub(r"'[^']*'", "", query_upper)
|
|
||||||
query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings)
|
|
||||||
|
|
||||||
# Check for dangerous keywords outside of quotes
|
|
||||||
for keyword in DANGEROUS_SQL_KEYWORDS:
|
|
||||||
if re.search(r"\b" + keyword + r"\b", query_no_strings):
|
|
||||||
return False, f"Keyword '{keyword}' not allowed in read-only mode"
|
|
||||||
|
|
||||||
return True, ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def sanitize_table_name(table_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Sanitize a table name to prevent SQL injection.
|
|
||||||
|
|
||||||
Only allows alphanumeric characters and underscores.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
table_name: Table name to sanitize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized table name
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If table name contains invalid characters
|
|
||||||
"""
|
|
||||||
# Remove any characters that aren't alphanumeric or underscore
|
|
||||||
sanitized = re.sub(r"[^\w]", "", table_name)
|
|
||||||
|
|
||||||
if not sanitized:
|
|
||||||
raise ValueError("Table name cannot be empty after sanitization")
|
|
||||||
|
|
||||||
if sanitized != table_name:
|
|
||||||
raise ValueError(
|
|
||||||
f"Table name contains invalid characters: {table_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return sanitized
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def sanitize_column_name(column_name: str) -> str:
|
|
||||||
"""
|
|
||||||
Sanitize a column name to prevent SQL injection.
|
|
||||||
|
|
||||||
Only allows alphanumeric characters and underscores.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
column_name: Column name to sanitize
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Sanitized column name
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If column name contains invalid characters
|
|
||||||
"""
|
|
||||||
# Remove any characters that aren't alphanumeric or underscore
|
|
||||||
sanitized = re.sub(r"[^\w]", "", column_name)
|
|
||||||
|
|
||||||
if not sanitized:
|
|
||||||
raise ValueError("Column name cannot be empty after sanitization")
|
|
||||||
|
|
||||||
if sanitized != column_name:
|
|
||||||
raise ValueError(
|
|
||||||
f"Column name contains invalid characters: {column_name}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return sanitized
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
"""
|
|
||||||
Provider abstraction for oAI.
|
|
||||||
|
|
||||||
This module provides a unified interface for AI model providers,
|
|
||||||
enabling easy extension to support additional providers beyond OpenRouter.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from oai.providers.base import (
|
|
||||||
AIProvider,
|
|
||||||
ChatMessage,
|
|
||||||
ChatResponse,
|
|
||||||
ToolCall,
|
|
||||||
ToolFunction,
|
|
||||||
UsageStats,
|
|
||||||
ModelInfo,
|
|
||||||
ProviderCapabilities,
|
|
||||||
)
|
|
||||||
from oai.providers.openrouter import OpenRouterProvider
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# Base classes and types
|
|
||||||
"AIProvider",
|
|
||||||
"ChatMessage",
|
|
||||||
"ChatResponse",
|
|
||||||
"ToolCall",
|
|
||||||
"ToolFunction",
|
|
||||||
"UsageStats",
|
|
||||||
"ModelInfo",
|
|
||||||
"ProviderCapabilities",
|
|
||||||
# Provider implementations
|
|
||||||
"OpenRouterProvider",
|
|
||||||
]
|
|
||||||
@@ -1,413 +0,0 @@
|
|||||||
"""
|
|
||||||
Abstract base provider for AI model integration.
|
|
||||||
|
|
||||||
This module defines the interface that all AI providers must implement,
|
|
||||||
along with common data structures for requests and responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from typing import (
|
|
||||||
Any,
|
|
||||||
AsyncIterator,
|
|
||||||
Dict,
|
|
||||||
Iterator,
|
|
||||||
List,
|
|
||||||
Optional,
|
|
||||||
Union,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MessageRole(str, Enum):
|
|
||||||
"""Message roles in a conversation."""
|
|
||||||
|
|
||||||
SYSTEM = "system"
|
|
||||||
USER = "user"
|
|
||||||
ASSISTANT = "assistant"
|
|
||||||
TOOL = "tool"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolFunction:
|
|
||||||
"""
|
|
||||||
Represents a function within a tool call.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
name: The function name
|
|
||||||
arguments: JSON string of function arguments
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
arguments: str
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ToolCall:
|
|
||||||
"""
|
|
||||||
Represents a tool/function call requested by the model.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Unique identifier for this tool call
|
|
||||||
type: Type of tool call (usually "function")
|
|
||||||
function: The function being called
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
type: str
|
|
||||||
function: ToolFunction
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UsageStats:
|
|
||||||
"""
|
|
||||||
Token usage statistics from an API response.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
prompt_tokens: Number of tokens in the prompt
|
|
||||||
completion_tokens: Number of tokens in the completion
|
|
||||||
total_tokens: Total tokens used
|
|
||||||
total_cost_usd: Cost in USD (if available from API)
|
|
||||||
"""
|
|
||||||
|
|
||||||
prompt_tokens: int = 0
|
|
||||||
completion_tokens: int = 0
|
|
||||||
total_tokens: int = 0
|
|
||||||
total_cost_usd: Optional[float] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input_tokens(self) -> int:
|
|
||||||
"""Alias for prompt_tokens."""
|
|
||||||
return self.prompt_tokens
|
|
||||||
|
|
||||||
@property
|
|
||||||
def output_tokens(self) -> int:
|
|
||||||
"""Alias for completion_tokens."""
|
|
||||||
return self.completion_tokens
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ChatMessage:
|
|
||||||
"""
|
|
||||||
A single message in a chat conversation.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
role: The role of the message sender
|
|
||||||
content: Message content (text or structured content blocks)
|
|
||||||
name: Optional name for the sender
|
|
||||||
tool_calls: List of tool calls (for assistant messages)
|
|
||||||
tool_call_id: Tool call ID this message responds to (for tool messages)
|
|
||||||
"""
|
|
||||||
|
|
||||||
role: str
|
|
||||||
content: Union[str, List[Dict[str, Any]], None] = None
|
|
||||||
name: Optional[str] = None
|
|
||||||
tool_calls: Optional[List[ToolCall]] = None
|
|
||||||
tool_call_id: Optional[str] = None
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""Convert to dictionary format for API requests."""
|
|
||||||
result: Dict[str, Any] = {"role": self.role}
|
|
||||||
|
|
||||||
if self.content is not None:
|
|
||||||
result["content"] = self.content
|
|
||||||
|
|
||||||
if self.name:
|
|
||||||
result["name"] = self.name
|
|
||||||
|
|
||||||
if self.tool_calls:
|
|
||||||
result["tool_calls"] = [
|
|
||||||
{
|
|
||||||
"id": tc.id,
|
|
||||||
"type": tc.type,
|
|
||||||
"function": {
|
|
||||||
"name": tc.function.name,
|
|
||||||
"arguments": tc.function.arguments,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for tc in self.tool_calls
|
|
||||||
]
|
|
||||||
|
|
||||||
if self.tool_call_id:
|
|
||||||
result["tool_call_id"] = self.tool_call_id
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ChatResponseChoice:
|
|
||||||
"""
|
|
||||||
A single choice in a chat response.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
index: Index of this choice
|
|
||||||
message: The response message
|
|
||||||
finish_reason: Why the response ended
|
|
||||||
"""
|
|
||||||
|
|
||||||
index: int
|
|
||||||
message: ChatMessage
|
|
||||||
finish_reason: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ChatResponse:
|
|
||||||
"""
|
|
||||||
Response from a chat completion request.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Unique identifier for this response
|
|
||||||
choices: List of response choices
|
|
||||||
usage: Token usage statistics
|
|
||||||
model: Model that generated this response
|
|
||||||
created: Unix timestamp of creation
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
choices: List[ChatResponseChoice]
|
|
||||||
usage: Optional[UsageStats] = None
|
|
||||||
model: Optional[str] = None
|
|
||||||
created: Optional[int] = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def message(self) -> Optional[ChatMessage]:
|
|
||||||
"""Get the first choice's message."""
|
|
||||||
if self.choices:
|
|
||||||
return self.choices[0].message
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def content(self) -> Optional[str]:
|
|
||||||
"""Get the text content of the first choice."""
|
|
||||||
msg = self.message
|
|
||||||
if msg and isinstance(msg.content, str):
|
|
||||||
return msg.content
|
|
||||||
return None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def tool_calls(self) -> Optional[List[ToolCall]]:
|
|
||||||
"""Get tool calls from the first choice."""
|
|
||||||
msg = self.message
|
|
||||||
if msg:
|
|
||||||
return msg.tool_calls
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class StreamChunk:
|
|
||||||
"""
|
|
||||||
A single chunk from a streaming response.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Response ID
|
|
||||||
delta_content: New content in this chunk
|
|
||||||
finish_reason: Finish reason (if this is the last chunk)
|
|
||||||
usage: Usage stats (usually in the last chunk)
|
|
||||||
error: Error message if something went wrong
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
delta_content: Optional[str] = None
|
|
||||||
finish_reason: Optional[str] = None
|
|
||||||
usage: Optional[UsageStats] = None
|
|
||||||
error: Optional[str] = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ModelInfo:
|
|
||||||
"""
|
|
||||||
Information about an AI model.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
id: Unique model identifier
|
|
||||||
name: Display name
|
|
||||||
description: Model description
|
|
||||||
context_length: Maximum context window size
|
|
||||||
pricing: Pricing info (input/output per million tokens)
|
|
||||||
supported_parameters: List of supported API parameters
|
|
||||||
input_modalities: Supported input types (text, image, etc.)
|
|
||||||
output_modalities: Supported output types
|
|
||||||
"""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
description: str = ""
|
|
||||||
context_length: int = 0
|
|
||||||
pricing: Dict[str, float] = field(default_factory=dict)
|
|
||||||
supported_parameters: List[str] = field(default_factory=list)
|
|
||||||
input_modalities: List[str] = field(default_factory=lambda: ["text"])
|
|
||||||
output_modalities: List[str] = field(default_factory=lambda: ["text"])
|
|
||||||
|
|
||||||
def supports_images(self) -> bool:
|
|
||||||
"""Check if model supports image input."""
|
|
||||||
return "image" in self.input_modalities
|
|
||||||
|
|
||||||
def supports_tools(self) -> bool:
|
|
||||||
"""Check if model supports function calling/tools."""
|
|
||||||
return "tools" in self.supported_parameters or "functions" in self.supported_parameters
|
|
||||||
|
|
||||||
def supports_streaming(self) -> bool:
|
|
||||||
"""Check if model supports streaming responses."""
|
|
||||||
return "stream" in self.supported_parameters
|
|
||||||
|
|
||||||
def supports_online(self) -> bool:
|
|
||||||
"""Check if model supports web search (online mode)."""
|
|
||||||
return self.supports_tools()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ProviderCapabilities:
|
|
||||||
"""
|
|
||||||
Capabilities supported by a provider.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
streaming: Provider supports streaming responses
|
|
||||||
tools: Provider supports function calling
|
|
||||||
images: Provider supports image inputs
|
|
||||||
online: Provider supports web search
|
|
||||||
max_context: Maximum context length across all models
|
|
||||||
"""
|
|
||||||
|
|
||||||
streaming: bool = True
|
|
||||||
tools: bool = True
|
|
||||||
images: bool = True
|
|
||||||
online: bool = False
|
|
||||||
max_context: int = 128000
|
|
||||||
|
|
||||||
|
|
||||||
class AIProvider(ABC):
|
|
||||||
"""
|
|
||||||
Abstract base class for AI model providers.
|
|
||||||
|
|
||||||
All provider implementations must inherit from this class
|
|
||||||
and implement the required abstract methods.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, api_key: str, base_url: Optional[str] = None):
|
|
||||||
"""
|
|
||||||
Initialize the provider.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: API key for authentication
|
|
||||||
base_url: Optional custom base URL for the API
|
|
||||||
"""
|
|
||||||
self.api_key = api_key
|
|
||||||
self.base_url = base_url
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def name(self) -> str:
|
|
||||||
"""Get the provider name."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
|
||||||
@abstractmethod
|
|
||||||
def capabilities(self) -> ProviderCapabilities:
|
|
||||||
"""Get provider capabilities."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def list_models(self) -> List[ModelInfo]:
|
|
||||||
"""
|
|
||||||
Fetch available models from the provider.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of available models with their info
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
|
||||||
"""
|
|
||||||
Get information about a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: The model identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model information or None if not found
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def chat(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: List[ChatMessage],
|
|
||||||
stream: bool = False,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
|
||||||
tool_choice: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
|
||||||
"""
|
|
||||||
Send a chat completion request.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model ID to use
|
|
||||||
messages: List of chat messages
|
|
||||||
stream: Whether to stream the response
|
|
||||||
max_tokens: Maximum tokens in response
|
|
||||||
temperature: Sampling temperature
|
|
||||||
tools: List of tool definitions for function calling
|
|
||||||
tool_choice: How to handle tool selection ("auto", "none", etc.)
|
|
||||||
**kwargs: Additional provider-specific parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def chat_async(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: List[ChatMessage],
|
|
||||||
stream: bool = False,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
|
||||||
tool_choice: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
|
|
||||||
"""
|
|
||||||
Send an async chat completion request.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model ID to use
|
|
||||||
messages: List of chat messages
|
|
||||||
stream: Whether to stream the response
|
|
||||||
max_tokens: Maximum tokens in response
|
|
||||||
temperature: Sampling temperature
|
|
||||||
tools: List of tool definitions for function calling
|
|
||||||
tool_choice: How to handle tool selection
|
|
||||||
**kwargs: Additional provider-specific parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_credits(self) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get account credit/balance information.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with credit info or None if not supported
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
def validate_api_key(self) -> bool:
|
|
||||||
"""
|
|
||||||
Validate that the API key is valid.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if API key is valid
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
self.list_models()
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return False
|
|
||||||
@@ -1,630 +0,0 @@
|
|||||||
"""
|
|
||||||
OpenRouter provider implementation.
|
|
||||||
|
|
||||||
This module implements the AIProvider interface for OpenRouter,
|
|
||||||
supporting chat completions, streaming, and function calling.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
|
|
||||||
|
|
||||||
import requests
|
|
||||||
from openrouter import OpenRouter
|
|
||||||
|
|
||||||
from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL
|
|
||||||
from oai.providers.base import (
|
|
||||||
AIProvider,
|
|
||||||
ChatMessage,
|
|
||||||
ChatResponse,
|
|
||||||
ChatResponseChoice,
|
|
||||||
ModelInfo,
|
|
||||||
ProviderCapabilities,
|
|
||||||
StreamChunk,
|
|
||||||
ToolCall,
|
|
||||||
ToolFunction,
|
|
||||||
UsageStats,
|
|
||||||
)
|
|
||||||
from oai.utils.logging import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
class OpenRouterProvider(AIProvider):
|
|
||||||
"""
|
|
||||||
OpenRouter API provider implementation.
|
|
||||||
|
|
||||||
Provides access to multiple AI models through OpenRouter's unified API,
|
|
||||||
supporting chat completions, streaming responses, and function calling.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
client: The underlying OpenRouter client
|
|
||||||
_models_cache: Cached list of available models
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
api_key: str,
|
|
||||||
base_url: Optional[str] = None,
|
|
||||||
app_name: str = APP_NAME,
|
|
||||||
app_url: str = APP_URL,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Initialize the OpenRouter provider.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: OpenRouter API key
|
|
||||||
base_url: Optional custom base URL
|
|
||||||
app_name: Application name for API headers
|
|
||||||
app_url: Application URL for API headers
|
|
||||||
"""
|
|
||||||
super().__init__(api_key, base_url or DEFAULT_BASE_URL)
|
|
||||||
self.app_name = app_name
|
|
||||||
self.app_url = app_url
|
|
||||||
self.client = OpenRouter(api_key=api_key)
|
|
||||||
self._models_cache: Optional[List[ModelInfo]] = None
|
|
||||||
self._raw_models_cache: Optional[List[Dict[str, Any]]] = None
|
|
||||||
|
|
||||||
self.logger = get_logger()
|
|
||||||
self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}")
|
|
||||||
|
|
||||||
@property
|
|
||||||
def name(self) -> str:
|
|
||||||
"""Get the provider name."""
|
|
||||||
return "OpenRouter"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def capabilities(self) -> ProviderCapabilities:
|
|
||||||
"""Get provider capabilities."""
|
|
||||||
return ProviderCapabilities(
|
|
||||||
streaming=True,
|
|
||||||
tools=True,
|
|
||||||
images=True,
|
|
||||||
online=True,
|
|
||||||
max_context=2000000, # Claude models support up to 200k
|
|
||||||
)
|
|
||||||
|
|
||||||
def _get_headers(self) -> Dict[str, str]:
|
|
||||||
"""Get standard HTTP headers for API requests."""
|
|
||||||
headers = {
|
|
||||||
"HTTP-Referer": self.app_url,
|
|
||||||
"X-Title": self.app_name,
|
|
||||||
}
|
|
||||||
if self.api_key:
|
|
||||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
|
||||||
return headers
|
|
||||||
|
|
||||||
def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
|
|
||||||
"""
|
|
||||||
Parse raw model data into ModelInfo.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_data: Raw model data from API
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parsed ModelInfo object
|
|
||||||
"""
|
|
||||||
architecture = model_data.get("architecture", {})
|
|
||||||
pricing_data = model_data.get("pricing", {})
|
|
||||||
|
|
||||||
# Parse pricing (convert from string to float if needed)
|
|
||||||
pricing = {}
|
|
||||||
for key in ["prompt", "completion"]:
|
|
||||||
value = pricing_data.get(key)
|
|
||||||
if value is not None:
|
|
||||||
try:
|
|
||||||
# Convert from per-token to per-million-tokens
|
|
||||||
pricing[key] = float(value) * 1_000_000
|
|
||||||
except (ValueError, TypeError):
|
|
||||||
pricing[key] = 0.0
|
|
||||||
|
|
||||||
return ModelInfo(
|
|
||||||
id=model_data.get("id", ""),
|
|
||||||
name=model_data.get("name", model_data.get("id", "")),
|
|
||||||
description=model_data.get("description", ""),
|
|
||||||
context_length=model_data.get("context_length", 0),
|
|
||||||
pricing=pricing,
|
|
||||||
supported_parameters=model_data.get("supported_parameters", []),
|
|
||||||
input_modalities=architecture.get("input_modalities", ["text"]),
|
|
||||||
output_modalities=architecture.get("output_modalities", ["text"]),
|
|
||||||
)
|
|
||||||
|
|
||||||
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
|
|
||||||
"""
|
|
||||||
Fetch available models from OpenRouter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filter_text_only: If True, exclude video-only models
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of available models
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If API request fails
|
|
||||||
"""
|
|
||||||
if self._models_cache is not None:
|
|
||||||
return self._models_cache
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.get(
|
|
||||||
f"{self.base_url}/models",
|
|
||||||
headers=self._get_headers(),
|
|
||||||
timeout=10,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
raw_models = response.json().get("data", [])
|
|
||||||
self._raw_models_cache = raw_models
|
|
||||||
|
|
||||||
models = []
|
|
||||||
for model_data in raw_models:
|
|
||||||
# Optionally filter out video-only models
|
|
||||||
if filter_text_only:
|
|
||||||
modalities = model_data.get("modalities", [])
|
|
||||||
if modalities and "video" in modalities and "text" not in modalities:
|
|
||||||
continue
|
|
||||||
|
|
||||||
models.append(self._parse_model(model_data))
|
|
||||||
|
|
||||||
self._models_cache = models
|
|
||||||
self.logger.info(f"Fetched {len(models)} models from OpenRouter")
|
|
||||||
return models
|
|
||||||
|
|
||||||
except requests.RequestException as e:
|
|
||||||
self.logger.error(f"Failed to fetch models: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def get_raw_models(self) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get raw model data as returned by the API.
|
|
||||||
|
|
||||||
Useful for accessing provider-specific fields not in ModelInfo.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of raw model dictionaries
|
|
||||||
"""
|
|
||||||
if self._raw_models_cache is None:
|
|
||||||
self.list_models()
|
|
||||||
return self._raw_models_cache or []
|
|
||||||
|
|
||||||
def get_model(self, model_id: str) -> Optional[ModelInfo]:
|
|
||||||
"""
|
|
||||||
Get information about a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: The model identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model information or None if not found
|
|
||||||
"""
|
|
||||||
models = self.list_models()
|
|
||||||
for model in models:
|
|
||||||
if model.id == model_id:
|
|
||||||
return model
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get raw model data for a specific model.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: The model identifier
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Raw model dictionary or None if not found
|
|
||||||
"""
|
|
||||||
raw_models = self.get_raw_models()
|
|
||||||
for model in raw_models:
|
|
||||||
if model.get("id") == model_id:
|
|
||||||
return model
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Convert ChatMessage objects to API format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
messages: List of ChatMessage objects
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of message dictionaries for the API
|
|
||||||
"""
|
|
||||||
return [msg.to_dict() for msg in messages]
|
|
||||||
|
|
||||||
def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]:
|
|
||||||
"""
|
|
||||||
Parse usage data from API response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
usage_data: Raw usage data from API
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parsed UsageStats or None
|
|
||||||
"""
|
|
||||||
if not usage_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Handle both attribute and dict access
|
|
||||||
prompt_tokens = 0
|
|
||||||
completion_tokens = 0
|
|
||||||
total_cost = None
|
|
||||||
|
|
||||||
if hasattr(usage_data, "prompt_tokens"):
|
|
||||||
prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0
|
|
||||||
elif isinstance(usage_data, dict):
|
|
||||||
prompt_tokens = usage_data.get("prompt_tokens", 0) or 0
|
|
||||||
|
|
||||||
if hasattr(usage_data, "completion_tokens"):
|
|
||||||
completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0
|
|
||||||
elif isinstance(usage_data, dict):
|
|
||||||
completion_tokens = usage_data.get("completion_tokens", 0) or 0
|
|
||||||
|
|
||||||
# Try alternative naming (input_tokens/output_tokens)
|
|
||||||
if prompt_tokens == 0:
|
|
||||||
if hasattr(usage_data, "input_tokens"):
|
|
||||||
prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0
|
|
||||||
elif isinstance(usage_data, dict):
|
|
||||||
prompt_tokens = usage_data.get("input_tokens", 0) or 0
|
|
||||||
|
|
||||||
if completion_tokens == 0:
|
|
||||||
if hasattr(usage_data, "output_tokens"):
|
|
||||||
completion_tokens = getattr(usage_data, "output_tokens", 0) or 0
|
|
||||||
elif isinstance(usage_data, dict):
|
|
||||||
completion_tokens = usage_data.get("output_tokens", 0) or 0
|
|
||||||
|
|
||||||
# Get cost if available
|
|
||||||
# OpenRouter returns cost in different places:
|
|
||||||
# 1. As 'total_cost_usd' in usage object (rare)
|
|
||||||
# 2. As 'usage' at root level (common - this is the dollar amount)
|
|
||||||
total_cost = None
|
|
||||||
if hasattr(usage_data, "total_cost_usd"):
|
|
||||||
total_cost = getattr(usage_data, "total_cost_usd", None)
|
|
||||||
elif hasattr(usage_data, "usage"):
|
|
||||||
# OpenRouter puts cost as 'usage' field (dollar amount)
|
|
||||||
total_cost = getattr(usage_data, "usage", None)
|
|
||||||
elif isinstance(usage_data, dict):
|
|
||||||
total_cost = usage_data.get("total_cost_usd") or usage_data.get("usage")
|
|
||||||
|
|
||||||
return UsageStats(
|
|
||||||
prompt_tokens=prompt_tokens,
|
|
||||||
completion_tokens=completion_tokens,
|
|
||||||
total_tokens=prompt_tokens + completion_tokens,
|
|
||||||
total_cost_usd=float(total_cost) if total_cost else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]:
|
|
||||||
"""
|
|
||||||
Parse tool calls from API response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
tool_calls_data: Raw tool calls data
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ToolCall objects or None
|
|
||||||
"""
|
|
||||||
if not tool_calls_data:
|
|
||||||
return None
|
|
||||||
|
|
||||||
tool_calls = []
|
|
||||||
for tc in tool_calls_data:
|
|
||||||
# Handle both attribute and dict access
|
|
||||||
if hasattr(tc, "id"):
|
|
||||||
tc_id = tc.id
|
|
||||||
tc_type = getattr(tc, "type", "function")
|
|
||||||
func = tc.function
|
|
||||||
func_name = func.name
|
|
||||||
func_args = func.arguments
|
|
||||||
else:
|
|
||||||
tc_id = tc.get("id", "")
|
|
||||||
tc_type = tc.get("type", "function")
|
|
||||||
func = tc.get("function", {})
|
|
||||||
func_name = func.get("name", "")
|
|
||||||
func_args = func.get("arguments", "{}")
|
|
||||||
|
|
||||||
tool_calls.append(
|
|
||||||
ToolCall(
|
|
||||||
id=tc_id,
|
|
||||||
type=tc_type,
|
|
||||||
function=ToolFunction(name=func_name, arguments=func_args),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return tool_calls if tool_calls else None
|
|
||||||
|
|
||||||
def _parse_response(self, response: Any) -> ChatResponse:
|
|
||||||
"""
|
|
||||||
Parse API response into ChatResponse.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: Raw API response
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Parsed ChatResponse
|
|
||||||
"""
|
|
||||||
choices = []
|
|
||||||
for choice in response.choices:
|
|
||||||
msg = choice.message
|
|
||||||
message = ChatMessage(
|
|
||||||
role=msg.role if hasattr(msg, "role") else "assistant",
|
|
||||||
content=msg.content if hasattr(msg, "content") else None,
|
|
||||||
tool_calls=self._parse_tool_calls(
|
|
||||||
getattr(msg, "tool_calls", None)
|
|
||||||
),
|
|
||||||
)
|
|
||||||
choices.append(
|
|
||||||
ChatResponseChoice(
|
|
||||||
index=choice.index if hasattr(choice, "index") else 0,
|
|
||||||
message=message,
|
|
||||||
finish_reason=getattr(choice, "finish_reason", None),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
return ChatResponse(
|
|
||||||
id=response.id if hasattr(response, "id") else "",
|
|
||||||
choices=choices,
|
|
||||||
usage=self._parse_usage(getattr(response, "usage", None)),
|
|
||||||
model=getattr(response, "model", None),
|
|
||||||
created=getattr(response, "created", None),
|
|
||||||
)
|
|
||||||
|
|
||||||
def chat(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: List[ChatMessage],
|
|
||||||
stream: bool = False,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
|
||||||
tool_choice: Optional[str] = None,
|
|
||||||
transforms: Optional[List[str]] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[ChatResponse, Iterator[StreamChunk]]:
|
|
||||||
"""
|
|
||||||
Send a chat completion request to OpenRouter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model ID to use
|
|
||||||
messages: List of chat messages
|
|
||||||
stream: Whether to stream the response
|
|
||||||
max_tokens: Maximum tokens in response
|
|
||||||
temperature: Sampling temperature (0-2)
|
|
||||||
tools: List of tool definitions for function calling
|
|
||||||
tool_choice: How to handle tool selection ("auto", "none", etc.)
|
|
||||||
transforms: List of transforms (e.g., ["middle-out"])
|
|
||||||
**kwargs: Additional parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
|
|
||||||
"""
|
|
||||||
# Build request parameters
|
|
||||||
params: Dict[str, Any] = {
|
|
||||||
"model": model,
|
|
||||||
"messages": self._convert_messages(messages),
|
|
||||||
"stream": stream,
|
|
||||||
"http_headers": self._get_headers(),
|
|
||||||
}
|
|
||||||
|
|
||||||
# Request usage stats in streaming responses
|
|
||||||
if stream:
|
|
||||||
params["stream_options"] = {"include_usage": True}
|
|
||||||
|
|
||||||
if max_tokens is not None:
|
|
||||||
params["max_tokens"] = max_tokens
|
|
||||||
|
|
||||||
if temperature is not None:
|
|
||||||
params["temperature"] = temperature
|
|
||||||
|
|
||||||
if tools:
|
|
||||||
params["tools"] = tools
|
|
||||||
params["tool_choice"] = tool_choice or "auto"
|
|
||||||
|
|
||||||
if transforms:
|
|
||||||
params["transforms"] = transforms
|
|
||||||
|
|
||||||
# Add any additional parameters
|
|
||||||
params.update(kwargs)
|
|
||||||
|
|
||||||
self.logger.debug(f"Sending chat request to model {model}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = self.client.chat.send(**params)
|
|
||||||
|
|
||||||
if stream:
|
|
||||||
return self._stream_response(response)
|
|
||||||
else:
|
|
||||||
return self._parse_response(response)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Chat request failed: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
def _stream_response(self, response: Any) -> Iterator[StreamChunk]:
|
|
||||||
"""
|
|
||||||
Process a streaming response.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
response: Streaming response from API
|
|
||||||
|
|
||||||
Yields:
|
|
||||||
StreamChunk objects
|
|
||||||
"""
|
|
||||||
last_usage = None
|
|
||||||
|
|
||||||
try:
|
|
||||||
for chunk in response:
|
|
||||||
# Check for errors
|
|
||||||
if hasattr(chunk, "error") and chunk.error:
|
|
||||||
yield StreamChunk(
|
|
||||||
id=getattr(chunk, "id", ""),
|
|
||||||
error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error),
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Extract delta content
|
|
||||||
delta_content = None
|
|
||||||
finish_reason = None
|
|
||||||
|
|
||||||
if hasattr(chunk, "choices") and chunk.choices:
|
|
||||||
choice = chunk.choices[0]
|
|
||||||
if hasattr(choice, "delta"):
|
|
||||||
delta = choice.delta
|
|
||||||
if hasattr(delta, "content") and delta.content:
|
|
||||||
delta_content = delta.content
|
|
||||||
finish_reason = getattr(choice, "finish_reason", None)
|
|
||||||
|
|
||||||
# Track usage from last chunk
|
|
||||||
if hasattr(chunk, "usage") and chunk.usage:
|
|
||||||
last_usage = self._parse_usage(chunk.usage)
|
|
||||||
|
|
||||||
yield StreamChunk(
|
|
||||||
id=getattr(chunk, "id", ""),
|
|
||||||
delta_content=delta_content,
|
|
||||||
finish_reason=finish_reason,
|
|
||||||
usage=last_usage if finish_reason else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Stream error: {e}")
|
|
||||||
yield StreamChunk(id="", error=str(e))
|
|
||||||
|
|
||||||
async def chat_async(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: List[ChatMessage],
|
|
||||||
stream: bool = False,
|
|
||||||
max_tokens: Optional[int] = None,
|
|
||||||
temperature: Optional[float] = None,
|
|
||||||
tools: Optional[List[Dict[str, Any]]] = None,
|
|
||||||
tool_choice: Optional[str] = None,
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
|
|
||||||
"""
|
|
||||||
Send an async chat completion request.
|
|
||||||
|
|
||||||
Note: Currently wraps the sync implementation.
|
|
||||||
TODO: Implement true async support when OpenRouter SDK supports it.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model: Model ID to use
|
|
||||||
messages: List of chat messages
|
|
||||||
stream: Whether to stream the response
|
|
||||||
max_tokens: Maximum tokens in response
|
|
||||||
temperature: Sampling temperature
|
|
||||||
tools: List of tool definitions
|
|
||||||
tool_choice: Tool selection mode
|
|
||||||
**kwargs: Additional parameters
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ChatResponse for non-streaming, AsyncIterator for streaming
|
|
||||||
"""
|
|
||||||
# For now, use sync implementation
|
|
||||||
# TODO: Add true async when SDK supports it
|
|
||||||
result = self.chat(
|
|
||||||
model=model,
|
|
||||||
messages=messages,
|
|
||||||
stream=stream,
|
|
||||||
max_tokens=max_tokens,
|
|
||||||
temperature=temperature,
|
|
||||||
tools=tools,
|
|
||||||
tool_choice=tool_choice,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if stream and isinstance(result, Iterator):
|
|
||||||
# Convert sync iterator to async
|
|
||||||
async def async_iter() -> AsyncIterator[StreamChunk]:
|
|
||||||
for chunk in result:
|
|
||||||
yield chunk
|
|
||||||
|
|
||||||
return async_iter()
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def get_credits(self) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Get OpenRouter account credit information.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dict with credit info:
|
|
||||||
- total_credits: Total credits purchased
|
|
||||||
- used_credits: Credits used
|
|
||||||
- credits_left: Remaining credits
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
Exception: If API request fails
|
|
||||||
"""
|
|
||||||
if not self.api_key:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = requests.get(
|
|
||||||
f"{self.base_url}/credits",
|
|
||||||
headers=self._get_headers(),
|
|
||||||
timeout=10,
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
data = response.json().get("data", {})
|
|
||||||
total_credits = float(data.get("total_credits", 0))
|
|
||||||
total_usage = float(data.get("total_usage", 0))
|
|
||||||
credits_left = total_credits - total_usage
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total_credits": total_credits,
|
|
||||||
"used_credits": total_usage,
|
|
||||||
"credits_left": credits_left,
|
|
||||||
"total_credits_formatted": f"${total_credits:.2f}",
|
|
||||||
"used_credits_formatted": f"${total_usage:.2f}",
|
|
||||||
"credits_left_formatted": f"${credits_left:.2f}",
|
|
||||||
}
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
self.logger.error(f"Failed to fetch credits: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def clear_cache(self) -> None:
|
|
||||||
"""Clear the models cache to force a refresh."""
|
|
||||||
self._models_cache = None
|
|
||||||
self._raw_models_cache = None
|
|
||||||
self.logger.debug("Models cache cleared")
|
|
||||||
|
|
||||||
def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str:
|
|
||||||
"""
|
|
||||||
Get the effective model ID with online suffix if needed.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: Base model ID
|
|
||||||
online_enabled: Whether online mode is enabled
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Model ID with :online suffix if applicable
|
|
||||||
"""
|
|
||||||
if online_enabled and not model_id.endswith(":online"):
|
|
||||||
return f"{model_id}:online"
|
|
||||||
return model_id
|
|
||||||
|
|
||||||
def estimate_cost(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
input_tokens: int,
|
|
||||||
output_tokens: int,
|
|
||||||
) -> float:
|
|
||||||
"""
|
|
||||||
Estimate the cost for a completion.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_id: Model ID
|
|
||||||
input_tokens: Number of input tokens
|
|
||||||
output_tokens: Number of output tokens
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Estimated cost in USD
|
|
||||||
"""
|
|
||||||
model = self.get_model(model_id)
|
|
||||||
if model and model.pricing:
|
|
||||||
input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000
|
|
||||||
output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000
|
|
||||||
return input_cost + output_cost
|
|
||||||
|
|
||||||
# Fallback to default pricing if model not found
|
|
||||||
from oai.constants import MODEL_PRICING
|
|
||||||
|
|
||||||
input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
|
|
||||||
output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
|
|
||||||
return input_cost + output_cost
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
# Marker file for PEP 561
|
|
||||||
# This package supports type checking
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
"""Textual TUI interface for oAI."""
|
|
||||||
|
|
||||||
from oai.tui.app import oAIChatApp
|
|
||||||
|
|
||||||
__all__ = ["oAIChatApp"]
|
|
||||||
1002
oai/tui/app.py
1002
oai/tui/app.py
File diff suppressed because it is too large
Load Diff
@@ -1,21 +0,0 @@
|
|||||||
"""TUI screens for oAI."""
|
|
||||||
|
|
||||||
from oai.tui.screens.config_screen import ConfigScreen
|
|
||||||
from oai.tui.screens.conversation_selector import ConversationSelectorScreen
|
|
||||||
from oai.tui.screens.credits_screen import CreditsScreen
|
|
||||||
from oai.tui.screens.dialogs import AlertDialog, ConfirmDialog, InputDialog
|
|
||||||
from oai.tui.screens.help_screen import HelpScreen
|
|
||||||
from oai.tui.screens.model_selector import ModelSelectorScreen
|
|
||||||
from oai.tui.screens.stats_screen import StatsScreen
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"AlertDialog",
|
|
||||||
"ConfirmDialog",
|
|
||||||
"ConfigScreen",
|
|
||||||
"ConversationSelectorScreen",
|
|
||||||
"CreditsScreen",
|
|
||||||
"InputDialog",
|
|
||||||
"HelpScreen",
|
|
||||||
"ModelSelectorScreen",
|
|
||||||
"StatsScreen",
|
|
||||||
]
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
"""Configuration screen for oAI TUI."""
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import Container, Vertical
|
|
||||||
from textual.screen import ModalScreen
|
|
||||||
from textual.widgets import Button, Static
|
|
||||||
|
|
||||||
from oai.config.settings import Settings
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigScreen(ModalScreen[None]):
|
|
||||||
"""Modal screen displaying configuration settings."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
ConfigScreen {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConfigScreen > Container {
|
|
||||||
width: 70;
|
|
||||||
height: auto;
|
|
||||||
background: #1e1e1e;
|
|
||||||
border: solid #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConfigScreen .header {
|
|
||||||
dock: top;
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
color: #cccccc;
|
|
||||||
padding: 0 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConfigScreen .content {
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #1e1e1e;
|
|
||||||
padding: 2;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConfigScreen .footer {
|
|
||||||
dock: bottom;
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
padding: 1 2;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, settings: Settings):
|
|
||||||
super().__init__()
|
|
||||||
self.settings = settings
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the screen."""
|
|
||||||
with Container():
|
|
||||||
yield Static("[bold]Configuration[/]", classes="header")
|
|
||||||
with Vertical(classes="content"):
|
|
||||||
yield Static(self._get_config_text(), markup=True)
|
|
||||||
with Vertical(classes="footer"):
|
|
||||||
yield Button("Close", id="close", variant="primary")
|
|
||||||
|
|
||||||
def _get_config_text(self) -> str:
|
|
||||||
"""Generate the configuration text."""
|
|
||||||
from oai.constants import DEFAULT_SYSTEM_PROMPT
|
|
||||||
|
|
||||||
# API Key display
|
|
||||||
api_key_display = "***" + self.settings.api_key[-4:] if self.settings.api_key else "Not set"
|
|
||||||
|
|
||||||
# System prompt display
|
|
||||||
if self.settings.default_system_prompt is None:
|
|
||||||
system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
|
|
||||||
elif self.settings.default_system_prompt == "":
|
|
||||||
system_prompt_display = "[blank]"
|
|
||||||
else:
|
|
||||||
prompt = self.settings.default_system_prompt
|
|
||||||
system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt
|
|
||||||
|
|
||||||
return f"""
|
|
||||||
[bold cyan]═══ CONFIGURATION ═══[/]
|
|
||||||
|
|
||||||
[bold]API Key:[/] {api_key_display}
|
|
||||||
[bold]Base URL:[/] {self.settings.base_url}
|
|
||||||
[bold]Default Model:[/] {self.settings.default_model or "Not set"}
|
|
||||||
|
|
||||||
[bold]System Prompt:[/] {system_prompt_display}
|
|
||||||
|
|
||||||
[bold]Streaming:[/] {"on" if self.settings.stream_enabled else "off"}
|
|
||||||
[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}
|
|
||||||
[bold]Max Tokens:[/] {self.settings.max_tokens}
|
|
||||||
[bold]Default Online:[/] {"on" if self.settings.default_online_mode else "off"}
|
|
||||||
[bold]Log Level:[/] {self.settings.log_level}
|
|
||||||
|
|
||||||
[dim]Use /config [setting] [value] to modify settings[/]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
self.dismiss()
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key in ("escape", "enter"):
|
|
||||||
self.dismiss()
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
"""Conversation selector screen for oAI TUI."""
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import Container, Vertical
|
|
||||||
from textual.screen import ModalScreen
|
|
||||||
from textual.widgets import Button, DataTable, Input, Static
|
|
||||||
|
|
||||||
|
|
||||||
class ConversationSelectorScreen(ModalScreen[Optional[dict]]):
|
|
||||||
"""Modal screen for selecting a saved conversation."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
ConversationSelectorScreen {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConversationSelectorScreen > Container {
|
|
||||||
width: 80%;
|
|
||||||
height: 70%;
|
|
||||||
background: #1e1e1e;
|
|
||||||
border: solid #555555;
|
|
||||||
layout: vertical;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConversationSelectorScreen .header {
|
|
||||||
height: 3;
|
|
||||||
width: 100%;
|
|
||||||
background: #2d2d2d;
|
|
||||||
color: #cccccc;
|
|
||||||
padding: 0 2;
|
|
||||||
content-align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConversationSelectorScreen .search-input {
|
|
||||||
height: 3;
|
|
||||||
width: 100%;
|
|
||||||
background: #2a2a2a;
|
|
||||||
border: solid #555555;
|
|
||||||
margin: 0 0 1 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConversationSelectorScreen .search-input:focus {
|
|
||||||
border: solid #888888;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConversationSelectorScreen DataTable {
|
|
||||||
height: 1fr;
|
|
||||||
width: 100%;
|
|
||||||
background: #1e1e1e;
|
|
||||||
border: solid #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConversationSelectorScreen .footer {
|
|
||||||
height: 5;
|
|
||||||
width: 100%;
|
|
||||||
background: #2d2d2d;
|
|
||||||
padding: 1 2;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConversationSelectorScreen Button {
|
|
||||||
margin: 0 1;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, conversations: List[dict]):
|
|
||||||
super().__init__()
|
|
||||||
self.all_conversations = conversations
|
|
||||||
self.filtered_conversations = conversations
|
|
||||||
self.selected_conversation: Optional[dict] = None
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the screen."""
|
|
||||||
with Container():
|
|
||||||
yield Static(
|
|
||||||
f"[bold]Load Conversation[/] [dim]({len(self.all_conversations)} saved)[/]",
|
|
||||||
classes="header"
|
|
||||||
)
|
|
||||||
yield Input(placeholder="Search conversations...", id="search-input", classes="search-input")
|
|
||||||
yield DataTable(id="conv-table", cursor_type="row", show_header=True, zebra_stripes=True)
|
|
||||||
with Vertical(classes="footer"):
|
|
||||||
yield Button("Load", id="load", variant="success")
|
|
||||||
yield Button("Cancel", id="cancel", variant="error")
|
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
|
||||||
"""Initialize the table when mounted."""
|
|
||||||
table = self.query_one("#conv-table", DataTable)
|
|
||||||
|
|
||||||
# Add columns
|
|
||||||
table.add_column("#", width=5)
|
|
||||||
table.add_column("Name", width=40)
|
|
||||||
table.add_column("Messages", width=12)
|
|
||||||
table.add_column("Last Saved", width=20)
|
|
||||||
|
|
||||||
# Populate table
|
|
||||||
self._populate_table()
|
|
||||||
|
|
||||||
# Focus table if list is small (fits on screen), otherwise focus search
|
|
||||||
if len(self.all_conversations) <= 10:
|
|
||||||
table.focus()
|
|
||||||
else:
|
|
||||||
search_input = self.query_one("#search-input", Input)
|
|
||||||
search_input.focus()
|
|
||||||
|
|
||||||
def _populate_table(self) -> None:
|
|
||||||
"""Populate the table with conversations."""
|
|
||||||
table = self.query_one("#conv-table", DataTable)
|
|
||||||
table.clear()
|
|
||||||
|
|
||||||
for idx, conv in enumerate(self.filtered_conversations, 1):
|
|
||||||
name = conv.get("name", "Unknown")
|
|
||||||
message_count = str(conv.get("message_count", 0))
|
|
||||||
last_saved = conv.get("last_saved", "Unknown")
|
|
||||||
|
|
||||||
# Format timestamp if it's a full datetime
|
|
||||||
if "T" in last_saved or len(last_saved) > 20:
|
|
||||||
try:
|
|
||||||
# Truncate to just date and time
|
|
||||||
last_saved = last_saved[:19].replace("T", " ")
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
|
|
||||||
table.add_row(
|
|
||||||
str(idx),
|
|
||||||
name,
|
|
||||||
message_count,
|
|
||||||
last_saved,
|
|
||||||
key=str(idx)
|
|
||||||
)
|
|
||||||
|
|
||||||
def on_input_changed(self, event: Input.Changed) -> None:
|
|
||||||
"""Filter conversations based on search input."""
|
|
||||||
if event.input.id != "search-input":
|
|
||||||
return
|
|
||||||
|
|
||||||
search_term = event.value.lower()
|
|
||||||
|
|
||||||
if not search_term:
|
|
||||||
self.filtered_conversations = self.all_conversations
|
|
||||||
else:
|
|
||||||
self.filtered_conversations = [
|
|
||||||
c for c in self.all_conversations
|
|
||||||
if search_term in c.get("name", "").lower()
|
|
||||||
]
|
|
||||||
|
|
||||||
self._populate_table()
|
|
||||||
|
|
||||||
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
|
|
||||||
"""Handle row selection (click)."""
|
|
||||||
try:
|
|
||||||
row_index = int(event.row_key.value) - 1
|
|
||||||
if 0 <= row_index < len(self.filtered_conversations):
|
|
||||||
self.selected_conversation = self.filtered_conversations[row_index]
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_data_table_row_highlighted(self, event) -> None:
|
|
||||||
"""Handle row highlight (arrow key navigation)."""
|
|
||||||
try:
|
|
||||||
table = self.query_one("#conv-table", DataTable)
|
|
||||||
if table.cursor_row is not None:
|
|
||||||
row_data = table.get_row_at(table.cursor_row)
|
|
||||||
if row_data:
|
|
||||||
row_index = int(row_data[0]) - 1
|
|
||||||
if 0 <= row_index < len(self.filtered_conversations):
|
|
||||||
self.selected_conversation = self.filtered_conversations[row_index]
|
|
||||||
except (ValueError, IndexError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
if event.button.id == "load":
|
|
||||||
if self.selected_conversation:
|
|
||||||
self.dismiss(self.selected_conversation)
|
|
||||||
else:
|
|
||||||
self.dismiss(None)
|
|
||||||
else:
|
|
||||||
self.dismiss(None)
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key == "escape":
|
|
||||||
self.dismiss(None)
|
|
||||||
elif event.key == "enter":
|
|
||||||
# If in search input, move to table
|
|
||||||
search_input = self.query_one("#search-input", Input)
|
|
||||||
if search_input.has_focus:
|
|
||||||
table = self.query_one("#conv-table", DataTable)
|
|
||||||
table.focus()
|
|
||||||
# If in table, select current row
|
|
||||||
else:
|
|
||||||
table = self.query_one("#conv-table", DataTable)
|
|
||||||
if table.cursor_row is not None:
|
|
||||||
try:
|
|
||||||
row_data = table.get_row_at(table.cursor_row)
|
|
||||||
if row_data:
|
|
||||||
row_index = int(row_data[0]) - 1
|
|
||||||
if 0 <= row_index < len(self.filtered_conversations):
|
|
||||||
selected = self.filtered_conversations[row_index]
|
|
||||||
self.dismiss(selected)
|
|
||||||
except (ValueError, IndexError, AttributeError):
|
|
||||||
if self.selected_conversation:
|
|
||||||
self.dismiss(self.selected_conversation)
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
"""Credits screen for oAI TUI."""
|
|
||||||
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import Container, Vertical
|
|
||||||
from textual.screen import ModalScreen
|
|
||||||
from textual.widgets import Button, Static
|
|
||||||
|
|
||||||
from oai.core.client import AIClient
|
|
||||||
|
|
||||||
|
|
||||||
class CreditsScreen(ModalScreen[None]):
|
|
||||||
"""Modal screen displaying account credits."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
CreditsScreen {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
CreditsScreen > Container {
|
|
||||||
width: 60;
|
|
||||||
height: auto;
|
|
||||||
background: #1e1e1e;
|
|
||||||
border: solid #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
CreditsScreen .header {
|
|
||||||
dock: top;
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
color: #cccccc;
|
|
||||||
padding: 0 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
CreditsScreen .content {
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #1e1e1e;
|
|
||||||
padding: 2;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
CreditsScreen .footer {
|
|
||||||
dock: bottom;
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
padding: 1 2;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, client: AIClient):
|
|
||||||
super().__init__()
|
|
||||||
self.client = client
|
|
||||||
self.credits_data: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the screen."""
|
|
||||||
with Container():
|
|
||||||
yield Static("[bold]Account Credits[/]", classes="header")
|
|
||||||
with Vertical(classes="content"):
|
|
||||||
yield Static("[dim]Loading...[/]", id="credits-content", markup=True)
|
|
||||||
with Vertical(classes="footer"):
|
|
||||||
yield Button("Close", id="close", variant="primary")
|
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
|
||||||
"""Fetch credits when mounted."""
|
|
||||||
self.fetch_credits()
|
|
||||||
|
|
||||||
def fetch_credits(self) -> None:
|
|
||||||
"""Fetch and display credits information."""
|
|
||||||
try:
|
|
||||||
self.credits_data = self.client.provider.get_credits()
|
|
||||||
content = self.query_one("#credits-content", Static)
|
|
||||||
content.update(self._get_credits_text())
|
|
||||||
except Exception as e:
|
|
||||||
content = self.query_one("#credits-content", Static)
|
|
||||||
content.update(f"[red]Error fetching credits:[/]\n{str(e)}")
|
|
||||||
|
|
||||||
def _get_credits_text(self) -> str:
|
|
||||||
"""Generate the credits text."""
|
|
||||||
if not self.credits_data:
|
|
||||||
return "[yellow]No credit information available[/]"
|
|
||||||
|
|
||||||
total = self.credits_data.get("total_credits", 0)
|
|
||||||
used = self.credits_data.get("used_credits", 0)
|
|
||||||
remaining = self.credits_data.get("credits_left", 0)
|
|
||||||
|
|
||||||
# Calculate percentage used
|
|
||||||
if total > 0:
|
|
||||||
percent_used = (used / total) * 100
|
|
||||||
percent_remaining = (remaining / total) * 100
|
|
||||||
else:
|
|
||||||
percent_used = 0
|
|
||||||
percent_remaining = 0
|
|
||||||
|
|
||||||
# Color code based on remaining credits
|
|
||||||
if percent_remaining > 50:
|
|
||||||
remaining_color = "green"
|
|
||||||
elif percent_remaining > 20:
|
|
||||||
remaining_color = "yellow"
|
|
||||||
else:
|
|
||||||
remaining_color = "red"
|
|
||||||
|
|
||||||
return f"""
|
|
||||||
[bold cyan]═══ OPENROUTER CREDITS ═══[/]
|
|
||||||
|
|
||||||
[bold]Total Credits:[/] ${total:.2f}
|
|
||||||
[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]
|
|
||||||
[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]
|
|
||||||
|
|
||||||
[dim]Visit openrouter.ai to add more credits[/]
|
|
||||||
"""
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
self.dismiss()
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key in ("escape", "enter"):
|
|
||||||
self.dismiss()
|
|
||||||
@@ -1,236 +0,0 @@
|
|||||||
"""Modal dialog screens for oAI TUI."""
|
|
||||||
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import Container, Horizontal, Vertical
|
|
||||||
from textual.screen import ModalScreen
|
|
||||||
from textual.widgets import Button, Input, Label, Static
|
|
||||||
|
|
||||||
|
|
||||||
class ConfirmDialog(ModalScreen[bool]):
|
|
||||||
"""A confirmation dialog with Yes/No buttons."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
ConfirmDialog {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConfirmDialog > Container {
|
|
||||||
width: 60;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
border: solid #555555;
|
|
||||||
padding: 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConfirmDialog Label {
|
|
||||||
width: 100%;
|
|
||||||
content-align: center middle;
|
|
||||||
margin-bottom: 2;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConfirmDialog Horizontal {
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ConfirmDialog Button {
|
|
||||||
margin: 0 1;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message: str,
|
|
||||||
title: str = "Confirm",
|
|
||||||
yes_label: str = "Yes",
|
|
||||||
no_label: str = "No",
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.message = message
|
|
||||||
self.title = title
|
|
||||||
self.yes_label = yes_label
|
|
||||||
self.no_label = no_label
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the dialog."""
|
|
||||||
with Container():
|
|
||||||
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
|
|
||||||
yield Label(self.message)
|
|
||||||
with Horizontal():
|
|
||||||
yield Button(self.yes_label, id="yes", variant="success")
|
|
||||||
yield Button(self.no_label, id="no", variant="error")
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
if event.button.id == "yes":
|
|
||||||
self.dismiss(True)
|
|
||||||
else:
|
|
||||||
self.dismiss(False)
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key == "escape":
|
|
||||||
self.dismiss(False)
|
|
||||||
elif event.key == "enter":
|
|
||||||
self.dismiss(True)
|
|
||||||
|
|
||||||
|
|
||||||
class InputDialog(ModalScreen[Optional[str]]):
|
|
||||||
"""An input dialog for text entry."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
InputDialog {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
InputDialog > Container {
|
|
||||||
width: 70;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
border: solid #555555;
|
|
||||||
padding: 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
InputDialog Label {
|
|
||||||
width: 100%;
|
|
||||||
margin-bottom: 1;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
InputDialog Input {
|
|
||||||
width: 100%;
|
|
||||||
margin-bottom: 2;
|
|
||||||
background: #3a3a3a;
|
|
||||||
border: solid #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
InputDialog Input:focus {
|
|
||||||
border: solid #888888;
|
|
||||||
}
|
|
||||||
|
|
||||||
InputDialog Horizontal {
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
InputDialog Button {
|
|
||||||
margin: 0 1;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
message: str,
|
|
||||||
title: str = "Input",
|
|
||||||
default: str = "",
|
|
||||||
placeholder: str = "",
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.message = message
|
|
||||||
self.title = title
|
|
||||||
self.default = default
|
|
||||||
self.placeholder = placeholder
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the dialog."""
|
|
||||||
with Container():
|
|
||||||
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
|
|
||||||
yield Label(self.message)
|
|
||||||
yield Input(
|
|
||||||
value=self.default,
|
|
||||||
placeholder=self.placeholder,
|
|
||||||
id="input-field"
|
|
||||||
)
|
|
||||||
with Horizontal():
|
|
||||||
yield Button("OK", id="ok", variant="primary")
|
|
||||||
yield Button("Cancel", id="cancel")
|
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
|
||||||
"""Focus the input field when mounted."""
|
|
||||||
input_field = self.query_one("#input-field", Input)
|
|
||||||
input_field.focus()
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
if event.button.id == "ok":
|
|
||||||
input_field = self.query_one("#input-field", Input)
|
|
||||||
self.dismiss(input_field.value)
|
|
||||||
else:
|
|
||||||
self.dismiss(None)
|
|
||||||
|
|
||||||
def on_input_submitted(self, event: Input.Submitted) -> None:
|
|
||||||
"""Handle Enter key in input field."""
|
|
||||||
self.dismiss(event.value)
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key == "escape":
|
|
||||||
self.dismiss(None)
|
|
||||||
|
|
||||||
|
|
||||||
class AlertDialog(ModalScreen[None]):
|
|
||||||
"""A simple alert/message dialog."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
AlertDialog {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
AlertDialog > Container {
|
|
||||||
width: 60;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
border: solid #555555;
|
|
||||||
padding: 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
AlertDialog Label {
|
|
||||||
width: 100%;
|
|
||||||
content-align: center middle;
|
|
||||||
margin-bottom: 2;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
AlertDialog Horizontal {
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, message: str, title: str = "Alert", variant: str = "default"):
|
|
||||||
super().__init__()
|
|
||||||
self.message = message
|
|
||||||
self.title = title
|
|
||||||
self.variant = variant
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the dialog."""
|
|
||||||
# Choose color based on variant (using design system)
|
|
||||||
color = "$primary"
|
|
||||||
if self.variant == "error":
|
|
||||||
color = "$error"
|
|
||||||
elif self.variant == "success":
|
|
||||||
color = "$success"
|
|
||||||
elif self.variant == "warning":
|
|
||||||
color = "$warning"
|
|
||||||
|
|
||||||
with Container():
|
|
||||||
yield Static(f"[bold {color}]{self.title}[/]", classes="dialog-title")
|
|
||||||
yield Label(self.message)
|
|
||||||
with Horizontal():
|
|
||||||
yield Button("OK", id="ok", variant="primary")
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
self.dismiss()
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key in ("escape", "enter"):
|
|
||||||
self.dismiss()
|
|
||||||
@@ -1,138 +0,0 @@
|
|||||||
"""Help screen for oAI TUI."""
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import Container, Vertical
|
|
||||||
from textual.screen import ModalScreen
|
|
||||||
from textual.widgets import Button, Static
|
|
||||||
|
|
||||||
|
|
||||||
class HelpScreen(ModalScreen[None]):
|
|
||||||
"""Modal screen displaying help and commands."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
HelpScreen {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
HelpScreen > Container {
|
|
||||||
width: 90%;
|
|
||||||
height: 85%;
|
|
||||||
background: #1e1e1e;
|
|
||||||
border: solid #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
HelpScreen .header {
|
|
||||||
dock: top;
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
color: #cccccc;
|
|
||||||
padding: 0 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
HelpScreen .content {
|
|
||||||
height: 1fr;
|
|
||||||
background: #1e1e1e;
|
|
||||||
padding: 2;
|
|
||||||
overflow-y: auto;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
HelpScreen .footer {
|
|
||||||
dock: bottom;
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
padding: 1 2;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the screen."""
|
|
||||||
with Container():
|
|
||||||
yield Static("[bold]oAI Help & Commands[/]", classes="header")
|
|
||||||
with Vertical(classes="content"):
|
|
||||||
yield Static(self._get_help_text(), markup=True)
|
|
||||||
with Vertical(classes="footer"):
|
|
||||||
yield Button("Close", id="close", variant="primary")
|
|
||||||
|
|
||||||
def _get_help_text(self) -> str:
|
|
||||||
"""Generate the help text."""
|
|
||||||
return """
|
|
||||||
[bold cyan]═══ KEYBOARD SHORTCUTS ═══[/]
|
|
||||||
[bold]F1[/] Show this help (Ctrl+H may not work)
|
|
||||||
[bold]F2[/] Open model selector (Ctrl+M may not work)
|
|
||||||
[bold]Ctrl+S[/] Show session statistics
|
|
||||||
[bold]Ctrl+L[/] Clear chat display
|
|
||||||
[bold]Ctrl+P[/] Show previous message
|
|
||||||
[bold]Ctrl+N[/] Show next message
|
|
||||||
[bold]Ctrl+Q[/] Quit application
|
|
||||||
[bold]Up/Down[/] Navigate input history
|
|
||||||
[bold]ESC[/] Close dialogs
|
|
||||||
[dim]Note: Some Ctrl keys may be captured by your terminal[/]
|
|
||||||
|
|
||||||
[bold cyan]═══ SLASH COMMANDS ═══[/]
|
|
||||||
[bold yellow]Session Control:[/]
|
|
||||||
/reset Clear conversation history (with confirmation)
|
|
||||||
/clear Clear the chat display
|
|
||||||
/memory on/off Toggle conversation memory
|
|
||||||
/online on/off Toggle online search mode
|
|
||||||
/exit, /quit, /bye Exit the application
|
|
||||||
|
|
||||||
[bold yellow]Model & Configuration:[/]
|
|
||||||
/model [search] Open model selector with optional search
|
|
||||||
/config View configuration settings
|
|
||||||
/config api Set API key (prompts for input)
|
|
||||||
/config stream on Enable streaming responses
|
|
||||||
/system [prompt] Set session system prompt
|
|
||||||
/maxtoken [n] Set session token limit
|
|
||||||
|
|
||||||
[bold yellow]Conversation Management:[/]
|
|
||||||
/save [name] Save current conversation
|
|
||||||
/load [name] Load saved conversation (shows picker if no name)
|
|
||||||
/list List all saved conversations
|
|
||||||
/delete <name> Delete a saved conversation
|
|
||||||
|
|
||||||
[bold yellow]Export:[/]
|
|
||||||
/export md [file] Export as Markdown
|
|
||||||
/export json [file] Export as JSON
|
|
||||||
/export html [file] Export as HTML
|
|
||||||
|
|
||||||
[bold yellow]History Navigation:[/]
|
|
||||||
/prev Show previous message in history
|
|
||||||
/next Show next message in history
|
|
||||||
|
|
||||||
[bold yellow]MCP (Model Context Protocol):[/]
|
|
||||||
/mcp on Enable MCP file access
|
|
||||||
/mcp off Disable MCP
|
|
||||||
/mcp status Show MCP status
|
|
||||||
/mcp add <path> Add folder for file access
|
|
||||||
/mcp list List registered folders
|
|
||||||
/mcp write Toggle write permissions
|
|
||||||
|
|
||||||
[bold yellow]Information & Utilities:[/]
|
|
||||||
/help Show this help screen
|
|
||||||
/stats Show session statistics
|
|
||||||
/credits Check account credits
|
|
||||||
/retry Retry last prompt
|
|
||||||
/paste Paste from clipboard and send
|
|
||||||
|
|
||||||
[bold cyan]═══ TIPS ═══[/]
|
|
||||||
• Type [bold]/[/] to see command suggestions with [bold]Tab[/] to autocomplete
|
|
||||||
• Use [bold]Up/Down arrows[/] to navigate your input history
|
|
||||||
• Type [bold]//[/] at start to escape commands (sends /help as literal message)
|
|
||||||
• All messages support [bold]Markdown formatting[/] with syntax highlighting
|
|
||||||
• Responses stream in real-time for better interactivity
|
|
||||||
• Enable MCP to let AI access your local files and databases
|
|
||||||
• Use [bold]F1[/] or [bold]F2[/] if Ctrl shortcuts don't work in your terminal
|
|
||||||
"""
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
self.dismiss()
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key in ("escape", "enter"):
|
|
||||||
self.dismiss()
|
|
||||||
@@ -1,254 +0,0 @@
|
|||||||
"""Model selector screen for oAI TUI."""
|
|
||||||
|
|
||||||
from typing import List, Optional
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import Container, Vertical
|
|
||||||
from textual.screen import ModalScreen
|
|
||||||
from textual.widgets import Button, DataTable, Input, Label, Static
|
|
||||||
|
|
||||||
|
|
||||||
class ModelSelectorScreen(ModalScreen[Optional[dict]]):
|
|
||||||
"""Modal screen for selecting an AI model."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
ModelSelectorScreen {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ModelSelectorScreen > Container {
|
|
||||||
width: 90%;
|
|
||||||
height: 85%;
|
|
||||||
background: #1e1e1e;
|
|
||||||
border: solid #555555;
|
|
||||||
layout: vertical;
|
|
||||||
}
|
|
||||||
|
|
||||||
ModelSelectorScreen .header {
|
|
||||||
height: 3;
|
|
||||||
width: 100%;
|
|
||||||
background: #2d2d2d;
|
|
||||||
color: #cccccc;
|
|
||||||
padding: 0 2;
|
|
||||||
content-align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ModelSelectorScreen .search-input {
|
|
||||||
height: 3;
|
|
||||||
width: 100%;
|
|
||||||
background: #2a2a2a;
|
|
||||||
border: solid #555555;
|
|
||||||
margin: 0 0 1 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
ModelSelectorScreen .search-input:focus {
|
|
||||||
border: solid #888888;
|
|
||||||
}
|
|
||||||
|
|
||||||
ModelSelectorScreen DataTable {
|
|
||||||
height: 1fr;
|
|
||||||
width: 100%;
|
|
||||||
background: #1e1e1e;
|
|
||||||
border: solid #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
ModelSelectorScreen .footer {
|
|
||||||
height: 5;
|
|
||||||
width: 100%;
|
|
||||||
background: #2d2d2d;
|
|
||||||
padding: 1 2;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
ModelSelectorScreen Button {
|
|
||||||
margin: 0 1;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, models: List[dict], current_model: Optional[str] = None):
|
|
||||||
super().__init__()
|
|
||||||
self.all_models = models
|
|
||||||
self.filtered_models = models
|
|
||||||
self.current_model = current_model
|
|
||||||
self.selected_model: Optional[dict] = None
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the screen."""
|
|
||||||
with Container():
|
|
||||||
yield Static(
|
|
||||||
f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
|
|
||||||
classes="header"
|
|
||||||
)
|
|
||||||
yield Input(placeholder="Search to filter models...", id="search-input", classes="search-input")
|
|
||||||
yield DataTable(id="model-table", cursor_type="row", show_header=True, zebra_stripes=True)
|
|
||||||
with Vertical(classes="footer"):
|
|
||||||
yield Button("Select", id="select", variant="success")
|
|
||||||
yield Button("Cancel", id="cancel", variant="error")
|
|
||||||
|
|
||||||
def on_mount(self) -> None:
|
|
||||||
"""Initialize the table when mounted."""
|
|
||||||
table = self.query_one("#model-table", DataTable)
|
|
||||||
|
|
||||||
# Add columns
|
|
||||||
table.add_column("#", width=5)
|
|
||||||
table.add_column("Model ID", width=35)
|
|
||||||
table.add_column("Name", width=30)
|
|
||||||
table.add_column("Context", width=10)
|
|
||||||
table.add_column("Price", width=12)
|
|
||||||
table.add_column("Img", width=4)
|
|
||||||
table.add_column("Tools", width=6)
|
|
||||||
table.add_column("Online", width=7)
|
|
||||||
|
|
||||||
# Populate table
|
|
||||||
self._populate_table()
|
|
||||||
|
|
||||||
# Focus table if list is small (fits on screen), otherwise focus search
|
|
||||||
if len(self.filtered_models) <= 20:
|
|
||||||
table.focus()
|
|
||||||
else:
|
|
||||||
search_input = self.query_one("#search-input", Input)
|
|
||||||
search_input.focus()
|
|
||||||
|
|
||||||
def _populate_table(self) -> None:
|
|
||||||
"""Populate the table with models."""
|
|
||||||
table = self.query_one("#model-table", DataTable)
|
|
||||||
table.clear()
|
|
||||||
|
|
||||||
rows_added = 0
|
|
||||||
for idx, model in enumerate(self.filtered_models, 1):
|
|
||||||
try:
|
|
||||||
model_id = model.get("id", "")
|
|
||||||
name = model.get("name", "")
|
|
||||||
context = str(model.get("context_length", "N/A"))
|
|
||||||
|
|
||||||
# Format pricing
|
|
||||||
pricing = model.get("pricing", {})
|
|
||||||
prompt_price = pricing.get("prompt", "0")
|
|
||||||
completion_price = pricing.get("completion", "0")
|
|
||||||
|
|
||||||
# Convert to numbers and format
|
|
||||||
try:
|
|
||||||
prompt = float(prompt_price) * 1000000 # Convert to per 1M tokens
|
|
||||||
completion = float(completion_price) * 1000000
|
|
||||||
if prompt == 0 and completion == 0:
|
|
||||||
price = "Free"
|
|
||||||
else:
|
|
||||||
price = f"${prompt:.2f}/${completion:.2f}"
|
|
||||||
except:
|
|
||||||
price = "N/A"
|
|
||||||
|
|
||||||
# Check capabilities
|
|
||||||
architecture = model.get("architecture", {})
|
|
||||||
modality = architecture.get("modality", "")
|
|
||||||
supported_params = model.get("supported_parameters", [])
|
|
||||||
|
|
||||||
# Vision support: check if modality contains "image"
|
|
||||||
supports_vision = "image" in modality
|
|
||||||
|
|
||||||
# Tool support: check if "tools" or "tool_choice" in supported_parameters
|
|
||||||
supports_tools = "tools" in supported_params or "tool_choice" in supported_params
|
|
||||||
|
|
||||||
# Online support: check if model can use :online suffix (most models can)
|
|
||||||
# Models that already have :online in their ID support it
|
|
||||||
supports_online = ":online" in model_id or model_id not in ["openrouter/free"]
|
|
||||||
|
|
||||||
# Format capability indicators
|
|
||||||
img_indicator = "✓" if supports_vision else "-"
|
|
||||||
tools_indicator = "✓" if supports_tools else "-"
|
|
||||||
web_indicator = "✓" if supports_online else "-"
|
|
||||||
|
|
||||||
# Add row
|
|
||||||
table.add_row(
|
|
||||||
str(idx),
|
|
||||||
model_id,
|
|
||||||
name,
|
|
||||||
context,
|
|
||||||
price,
|
|
||||||
img_indicator,
|
|
||||||
tools_indicator,
|
|
||||||
web_indicator,
|
|
||||||
key=str(idx)
|
|
||||||
)
|
|
||||||
rows_added += 1
|
|
||||||
|
|
||||||
except Exception:
|
|
||||||
# Silently skip rows that fail to add
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_input_changed(self, event: Input.Changed) -> None:
|
|
||||||
"""Filter models based on search input."""
|
|
||||||
if event.input.id != "search-input":
|
|
||||||
return
|
|
||||||
|
|
||||||
search_term = event.value.lower()
|
|
||||||
|
|
||||||
if not search_term:
|
|
||||||
self.filtered_models = self.all_models
|
|
||||||
else:
|
|
||||||
self.filtered_models = [
|
|
||||||
m for m in self.all_models
|
|
||||||
if search_term in m.get("id", "").lower()
|
|
||||||
or search_term in m.get("name", "").lower()
|
|
||||||
]
|
|
||||||
|
|
||||||
self._populate_table()
|
|
||||||
|
|
||||||
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
|
|
||||||
"""Handle row selection (click or arrow navigation)."""
|
|
||||||
try:
|
|
||||||
row_index = int(event.row_key.value) - 1
|
|
||||||
if 0 <= row_index < len(self.filtered_models):
|
|
||||||
self.selected_model = self.filtered_models[row_index]
|
|
||||||
except (ValueError, IndexError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_data_table_row_highlighted(self, event) -> None:
|
|
||||||
"""Handle row highlight (arrow key navigation)."""
|
|
||||||
try:
|
|
||||||
table = self.query_one("#model-table", DataTable)
|
|
||||||
if table.cursor_row is not None:
|
|
||||||
row_data = table.get_row_at(table.cursor_row)
|
|
||||||
if row_data:
|
|
||||||
row_index = int(row_data[0]) - 1
|
|
||||||
if 0 <= row_index < len(self.filtered_models):
|
|
||||||
self.selected_model = self.filtered_models[row_index]
|
|
||||||
except (ValueError, IndexError, AttributeError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
if event.button.id == "select":
|
|
||||||
if self.selected_model:
|
|
||||||
self.dismiss(self.selected_model)
|
|
||||||
else:
|
|
||||||
# No selection, dismiss without result
|
|
||||||
self.dismiss(None)
|
|
||||||
else:
|
|
||||||
self.dismiss(None)
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key == "escape":
|
|
||||||
self.dismiss(None)
|
|
||||||
elif event.key == "enter":
|
|
||||||
# If in search input, move to table
|
|
||||||
search_input = self.query_one("#search-input", Input)
|
|
||||||
if search_input.has_focus:
|
|
||||||
table = self.query_one("#model-table", DataTable)
|
|
||||||
table.focus()
|
|
||||||
# If in table or anywhere else, select current row
|
|
||||||
else:
|
|
||||||
table = self.query_one("#model-table", DataTable)
|
|
||||||
# Get the currently highlighted row
|
|
||||||
if table.cursor_row is not None:
|
|
||||||
try:
|
|
||||||
row_key = table.get_row_at(table.cursor_row)
|
|
||||||
if row_key:
|
|
||||||
row_index = int(row_key[0]) - 1
|
|
||||||
if 0 <= row_index < len(self.filtered_models):
|
|
||||||
selected = self.filtered_models[row_index]
|
|
||||||
self.dismiss(selected)
|
|
||||||
except (ValueError, IndexError, AttributeError):
|
|
||||||
# Fall back to previously selected model
|
|
||||||
if self.selected_model:
|
|
||||||
self.dismiss(self.selected_model)
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
"""Statistics screen for oAI TUI."""
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import Container, Vertical
|
|
||||||
from textual.screen import ModalScreen
|
|
||||||
from textual.widgets import Button, Static
|
|
||||||
|
|
||||||
from oai.core.session import ChatSession
|
|
||||||
|
|
||||||
|
|
||||||
class StatsScreen(ModalScreen[None]):
|
|
||||||
"""Modal screen displaying session statistics."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
StatsScreen {
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatsScreen > Container {
|
|
||||||
width: 70;
|
|
||||||
height: auto;
|
|
||||||
background: #1e1e1e;
|
|
||||||
border: solid #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatsScreen .header {
|
|
||||||
dock: top;
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
color: #cccccc;
|
|
||||||
padding: 0 2;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatsScreen .content {
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #1e1e1e;
|
|
||||||
padding: 2;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
StatsScreen .footer {
|
|
||||||
dock: bottom;
|
|
||||||
width: 100%;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
padding: 1 2;
|
|
||||||
align: center middle;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, session: ChatSession):
|
|
||||||
super().__init__()
|
|
||||||
self.session = session
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the screen."""
|
|
||||||
with Container():
|
|
||||||
yield Static("[bold]Session Statistics[/]", classes="header")
|
|
||||||
with Vertical(classes="content"):
|
|
||||||
yield Static(self._get_stats_text(), markup=True)
|
|
||||||
with Vertical(classes="footer"):
|
|
||||||
yield Button("Close", id="close", variant="primary")
|
|
||||||
|
|
||||||
def _get_stats_text(self) -> str:
|
|
||||||
"""Generate the statistics text."""
|
|
||||||
stats = self.session.stats
|
|
||||||
|
|
||||||
# Calculate averages
|
|
||||||
avg_input = stats.total_input_tokens // stats.message_count if stats.message_count > 0 else 0
|
|
||||||
avg_output = stats.total_output_tokens // stats.message_count if stats.message_count > 0 else 0
|
|
||||||
avg_cost = stats.total_cost / stats.message_count if stats.message_count > 0 else 0
|
|
||||||
|
|
||||||
# Get model info
|
|
||||||
model_name = "None"
|
|
||||||
model_context = "N/A"
|
|
||||||
if self.session.selected_model:
|
|
||||||
model_name = self.session.selected_model.get("name", "Unknown")
|
|
||||||
model_context = str(self.session.selected_model.get("context_length", "N/A"))
|
|
||||||
|
|
||||||
# MCP status
|
|
||||||
mcp_status = "Disabled"
|
|
||||||
if self.session.mcp_manager and self.session.mcp_manager.enabled:
|
|
||||||
mode = self.session.mcp_manager.mode
|
|
||||||
if mode == "files":
|
|
||||||
write = " (Write)" if self.session.mcp_manager.write_enabled else ""
|
|
||||||
mcp_status = f"Enabled - Files{write}"
|
|
||||||
elif mode == "database":
|
|
||||||
db_idx = self.session.mcp_manager.selected_db_index
|
|
||||||
if db_idx is not None:
|
|
||||||
db_name = self.session.mcp_manager.databases[db_idx]["name"]
|
|
||||||
mcp_status = f"Enabled - Database ({db_name})"
|
|
||||||
|
|
||||||
return f"""
|
|
||||||
[bold cyan]═══ SESSION INFO ═══[/]
|
|
||||||
[bold]Messages:[/] {stats.message_count}
|
|
||||||
[bold]Current Model:[/] {model_name}
|
|
||||||
[bold]Context Length:[/] {model_context}
|
|
||||||
[bold]Memory:[/] {"Enabled" if self.session.memory_enabled else "Disabled"}
|
|
||||||
[bold]Online Mode:[/] {"Enabled" if self.session.online_enabled else "Disabled"}
|
|
||||||
[bold]MCP:[/] {mcp_status}
|
|
||||||
|
|
||||||
[bold cyan]═══ TOKEN USAGE ═══[/]
|
|
||||||
[bold]Input Tokens:[/] {stats.total_input_tokens:,}
|
|
||||||
[bold]Output Tokens:[/] {stats.total_output_tokens:,}
|
|
||||||
[bold]Total Tokens:[/] {stats.total_tokens:,}
|
|
||||||
|
|
||||||
[bold]Avg Input/Msg:[/] {avg_input:,}
|
|
||||||
[bold]Avg Output/Msg:[/] {avg_output:,}
|
|
||||||
|
|
||||||
[bold cyan]═══ COSTS ═══[/]
|
|
||||||
[bold]Total Cost:[/] ${stats.total_cost:.6f}
|
|
||||||
[bold]Avg Cost/Msg:[/] ${avg_cost:.6f}
|
|
||||||
|
|
||||||
[bold cyan]═══ HISTORY ═══[/]
|
|
||||||
[bold]History Size:[/] {len(self.session.history)} entries
|
|
||||||
[bold]Current Index:[/] {self.session.current_index + 1 if self.session.history else 0}
|
|
||||||
[bold]Memory Start:[/] {self.session.memory_start_index + 1}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def on_button_pressed(self, event: Button.Pressed) -> None:
|
|
||||||
"""Handle button press."""
|
|
||||||
self.dismiss()
|
|
||||||
|
|
||||||
def on_key(self, event) -> None:
|
|
||||||
"""Handle keyboard shortcuts."""
|
|
||||||
if event.key in ("escape", "enter"):
|
|
||||||
self.dismiss()
|
|
||||||
@@ -1,167 +0,0 @@
|
|||||||
/* Textual CSS for oAI TUI - Using Textual Design System */
|
|
||||||
|
|
||||||
Screen {
|
|
||||||
background: $background;
|
|
||||||
overflow: hidden;
|
|
||||||
}
|
|
||||||
|
|
||||||
Header {
|
|
||||||
dock: top;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
color: #cccccc;
|
|
||||||
padding: 0 1;
|
|
||||||
border-bottom: solid #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
ChatDisplay {
|
|
||||||
background: $background;
|
|
||||||
border: none;
|
|
||||||
padding: 1;
|
|
||||||
scrollbar-background: $background;
|
|
||||||
scrollbar-color: $primary;
|
|
||||||
overflow-y: auto;
|
|
||||||
}
|
|
||||||
|
|
||||||
UserMessageWidget {
|
|
||||||
margin: 0 0 1 0;
|
|
||||||
padding: 1;
|
|
||||||
background: $surface;
|
|
||||||
border-left: thick $success;
|
|
||||||
height: auto;
|
|
||||||
}
|
|
||||||
|
|
||||||
SystemMessageWidget {
|
|
||||||
margin: 0 0 1 0;
|
|
||||||
padding: 1;
|
|
||||||
background: #2a2a2a;
|
|
||||||
border-left: thick #888888;
|
|
||||||
height: auto;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
AssistantMessageWidget {
|
|
||||||
margin: 0 0 1 0;
|
|
||||||
padding: 1;
|
|
||||||
background: $panel;
|
|
||||||
border-left: thick $accent;
|
|
||||||
height: auto;
|
|
||||||
}
|
|
||||||
|
|
||||||
#assistant-label {
|
|
||||||
margin-bottom: 1;
|
|
||||||
color: #cccccc;
|
|
||||||
}
|
|
||||||
|
|
||||||
#assistant-content {
|
|
||||||
height: auto;
|
|
||||||
max-height: 100%;
|
|
||||||
color: $text;
|
|
||||||
}
|
|
||||||
|
|
||||||
InputBar {
|
|
||||||
dock: bottom;
|
|
||||||
height: auto;
|
|
||||||
background: #2d2d2d;
|
|
||||||
align: center middle;
|
|
||||||
border-top: solid #555555;
|
|
||||||
padding: 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
#input-prefix {
|
|
||||||
width: auto;
|
|
||||||
padding: 0 1;
|
|
||||||
content-align: center middle;
|
|
||||||
color: #888888;
|
|
||||||
}
|
|
||||||
|
|
||||||
#input-prefix.prefix-hidden {
|
|
||||||
display: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
#chat-input {
|
|
||||||
width: 85%;
|
|
||||||
height: 5;
|
|
||||||
min-height: 5;
|
|
||||||
background: #3a3a3a;
|
|
||||||
border: none;
|
|
||||||
padding: 1 2;
|
|
||||||
color: #ffffff;
|
|
||||||
content-align: left top;
|
|
||||||
}
|
|
||||||
|
|
||||||
#chat-input:focus {
|
|
||||||
background: #404040;
|
|
||||||
}
|
|
||||||
|
|
||||||
#command-dropdown {
|
|
||||||
display: none;
|
|
||||||
dock: bottom;
|
|
||||||
offset-y: -5;
|
|
||||||
offset-x: 7.5%;
|
|
||||||
height: auto;
|
|
||||||
max-height: 12;
|
|
||||||
width: 85%;
|
|
||||||
background: #2d2d2d;
|
|
||||||
border: solid #555555;
|
|
||||||
padding: 0;
|
|
||||||
layer: overlay;
|
|
||||||
}
|
|
||||||
|
|
||||||
#command-dropdown.visible {
|
|
||||||
display: block;
|
|
||||||
}
|
|
||||||
|
|
||||||
#command-dropdown #command-list {
|
|
||||||
background: #2d2d2d;
|
|
||||||
border: none;
|
|
||||||
scrollbar-background: #2d2d2d;
|
|
||||||
scrollbar-color: #555555;
|
|
||||||
}
|
|
||||||
|
|
||||||
Footer {
|
|
||||||
dock: bottom;
|
|
||||||
height: auto;
|
|
||||||
background: #252525;
|
|
||||||
color: #888888;
|
|
||||||
padding: 0 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* Button styles */
|
|
||||||
Button {
|
|
||||||
height: 3;
|
|
||||||
min-width: 10;
|
|
||||||
background: #3a3a3a;
|
|
||||||
color: #cccccc;
|
|
||||||
border: none;
|
|
||||||
}
|
|
||||||
|
|
||||||
Button:hover {
|
|
||||||
background: #4a4a4a;
|
|
||||||
}
|
|
||||||
|
|
||||||
Button:focus {
|
|
||||||
background: #505050;
|
|
||||||
}
|
|
||||||
|
|
||||||
Button.-primary {
|
|
||||||
background: #3a3a3a;
|
|
||||||
}
|
|
||||||
|
|
||||||
Button.-success {
|
|
||||||
background: #2d5016;
|
|
||||||
color: #90ee90;
|
|
||||||
}
|
|
||||||
|
|
||||||
Button.-success:hover {
|
|
||||||
background: #3a6b1e;
|
|
||||||
}
|
|
||||||
|
|
||||||
Button.-error {
|
|
||||||
background: #5a1a1a;
|
|
||||||
color: #ff6b6b;
|
|
||||||
}
|
|
||||||
|
|
||||||
Button.-error:hover {
|
|
||||||
background: #6e2222;
|
|
||||||
}
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
"""TUI widgets for oAI."""
|
|
||||||
|
|
||||||
from oai.tui.widgets.chat_display import ChatDisplay
|
|
||||||
from oai.tui.widgets.footer import Footer
|
|
||||||
from oai.tui.widgets.header import Header
|
|
||||||
from oai.tui.widgets.input_bar import InputBar
|
|
||||||
from oai.tui.widgets.message import AssistantMessageWidget, SystemMessageWidget, UserMessageWidget
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"ChatDisplay",
|
|
||||||
"Footer",
|
|
||||||
"Header",
|
|
||||||
"InputBar",
|
|
||||||
"UserMessageWidget",
|
|
||||||
"SystemMessageWidget",
|
|
||||||
"AssistantMessageWidget",
|
|
||||||
]
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
"""Chat display widget for oAI TUI."""
|
|
||||||
|
|
||||||
from textual.containers import ScrollableContainer
|
|
||||||
from textual.widgets import Static
|
|
||||||
|
|
||||||
|
|
||||||
class ChatDisplay(ScrollableContainer):
|
|
||||||
"""Scrollable container for chat messages."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(id="chat-display")
|
|
||||||
|
|
||||||
async def add_message(self, widget: Static) -> None:
|
|
||||||
"""Add a message widget to the display."""
|
|
||||||
await self.mount(widget)
|
|
||||||
self.scroll_end(animate=False)
|
|
||||||
|
|
||||||
def clear_messages(self) -> None:
|
|
||||||
"""Clear all messages from the display."""
|
|
||||||
for child in list(self.children):
|
|
||||||
child.remove()
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
"""Command dropdown menu for TUI input."""
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import VerticalScroll
|
|
||||||
from textual.widget import Widget
|
|
||||||
from textual.widgets import Label, OptionList
|
|
||||||
from textual.widgets.option_list import Option
|
|
||||||
|
|
||||||
from oai.commands import registry
|
|
||||||
|
|
||||||
|
|
||||||
class CommandDropdown(VerticalScroll):
|
|
||||||
"""Dropdown menu showing available commands."""
|
|
||||||
|
|
||||||
DEFAULT_CSS = """
|
|
||||||
CommandDropdown {
|
|
||||||
display: none;
|
|
||||||
height: auto;
|
|
||||||
max-height: 12;
|
|
||||||
width: 80;
|
|
||||||
background: #2d2d2d;
|
|
||||||
border: solid #555555;
|
|
||||||
padding: 0;
|
|
||||||
layer: overlay;
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandDropdown.visible {
|
|
||||||
display: block;
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandDropdown OptionList {
|
|
||||||
height: auto;
|
|
||||||
max-height: 12;
|
|
||||||
background: #2d2d2d;
|
|
||||||
border: none;
|
|
||||||
padding: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandDropdown OptionList > .option-list--option {
|
|
||||||
padding: 0 2;
|
|
||||||
color: #cccccc;
|
|
||||||
background: transparent;
|
|
||||||
}
|
|
||||||
|
|
||||||
CommandDropdown OptionList > .option-list--option-highlighted {
|
|
||||||
background: #3e3e3e;
|
|
||||||
color: #ffffff;
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize the command dropdown."""
|
|
||||||
super().__init__(id="command-dropdown")
|
|
||||||
self._all_commands = []
|
|
||||||
self._load_commands()
|
|
||||||
|
|
||||||
def _load_commands(self) -> None:
|
|
||||||
"""Load all available commands."""
|
|
||||||
# Get base commands with descriptions
|
|
||||||
base_commands = [
|
|
||||||
("/help", "Show help screen"),
|
|
||||||
("/model", "Select AI model"),
|
|
||||||
("/stats", "Show session statistics"),
|
|
||||||
("/credits", "Check account credits"),
|
|
||||||
("/clear", "Clear chat display"),
|
|
||||||
("/reset", "Reset conversation history"),
|
|
||||||
("/memory on", "Enable conversation memory"),
|
|
||||||
("/memory off", "Disable memory"),
|
|
||||||
("/online on", "Enable online search"),
|
|
||||||
("/online off", "Disable online search"),
|
|
||||||
("/save", "Save current conversation"),
|
|
||||||
("/load", "Load saved conversation"),
|
|
||||||
("/list", "List saved conversations"),
|
|
||||||
("/delete", "Delete a conversation"),
|
|
||||||
("/export md", "Export as Markdown"),
|
|
||||||
("/export json", "Export as JSON"),
|
|
||||||
("/export html", "Export as HTML"),
|
|
||||||
("/prev", "Show previous message"),
|
|
||||||
("/next", "Show next message"),
|
|
||||||
("/config", "View configuration"),
|
|
||||||
("/config api", "Set API key"),
|
|
||||||
("/system", "Set system prompt"),
|
|
||||||
("/maxtoken", "Set token limit"),
|
|
||||||
("/retry", "Retry last prompt"),
|
|
||||||
("/paste", "Paste from clipboard"),
|
|
||||||
("/mcp on", "Enable MCP file access"),
|
|
||||||
("/mcp off", "Disable MCP"),
|
|
||||||
("/mcp status", "Show MCP status"),
|
|
||||||
("/mcp add", "Add folder/database"),
|
|
||||||
("/mcp remove", "Remove folder/database"),
|
|
||||||
("/mcp list", "List folders"),
|
|
||||||
("/mcp write on", "Enable write mode"),
|
|
||||||
("/mcp write off", "Disable write mode"),
|
|
||||||
("/mcp files", "Switch to file mode"),
|
|
||||||
("/mcp db list", "List databases"),
|
|
||||||
]
|
|
||||||
|
|
||||||
self._all_commands = base_commands
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the dropdown."""
|
|
||||||
yield OptionList(id="command-list")
|
|
||||||
|
|
||||||
def show_commands(self, filter_text: str = "") -> None:
|
|
||||||
"""Show commands matching the filter.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filter_text: Text to filter commands by
|
|
||||||
"""
|
|
||||||
option_list = self.query_one("#command-list", OptionList)
|
|
||||||
option_list.clear_options()
|
|
||||||
|
|
||||||
if not filter_text.startswith("/"):
|
|
||||||
self.remove_class("visible")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Remove the leading slash for filtering
|
|
||||||
filter_without_slash = filter_text[1:].lower()
|
|
||||||
|
|
||||||
# Filter commands - show if filter text is contained anywhere in the command
|
|
||||||
if filter_without_slash:
|
|
||||||
matching = [
|
|
||||||
(cmd, desc) for cmd, desc in self._all_commands
|
|
||||||
if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching
|
|
||||||
]
|
|
||||||
else:
|
|
||||||
# Show all commands when just "/" is typed
|
|
||||||
matching = self._all_commands
|
|
||||||
|
|
||||||
if not matching:
|
|
||||||
self.remove_class("visible")
|
|
||||||
return
|
|
||||||
|
|
||||||
# Add options - limit to 10 results
|
|
||||||
for cmd, desc in matching[:10]:
|
|
||||||
# Format: command in white, description in gray, separated by spaces
|
|
||||||
label = f"{cmd} [dim]{desc}[/]" if desc else cmd
|
|
||||||
option_list.add_option(Option(label, id=cmd))
|
|
||||||
|
|
||||||
self.add_class("visible")
|
|
||||||
|
|
||||||
# Auto-select first option
|
|
||||||
if len(option_list._options) > 0:
|
|
||||||
option_list.highlighted = 0
|
|
||||||
|
|
||||||
def hide(self) -> None:
|
|
||||||
"""Hide the dropdown."""
|
|
||||||
self.remove_class("visible")
|
|
||||||
|
|
||||||
def get_selected_command(self) -> str | None:
|
|
||||||
"""Get the currently selected command.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Selected command text or None
|
|
||||||
"""
|
|
||||||
option_list = self.query_one("#command-list", OptionList)
|
|
||||||
if option_list.highlighted is not None:
|
|
||||||
option = option_list.get_option_at_index(option_list.highlighted)
|
|
||||||
return option.id
|
|
||||||
return None
|
|
||||||
|
|
||||||
def move_selection_up(self) -> None:
|
|
||||||
"""Move selection up in the list."""
|
|
||||||
option_list = self.query_one("#command-list", OptionList)
|
|
||||||
if option_list.option_count > 0:
|
|
||||||
if option_list.highlighted is None:
|
|
||||||
option_list.highlighted = option_list.option_count - 1
|
|
||||||
elif option_list.highlighted > 0:
|
|
||||||
option_list.highlighted -= 1
|
|
||||||
|
|
||||||
def move_selection_down(self) -> None:
|
|
||||||
"""Move selection down in the list."""
|
|
||||||
option_list = self.query_one("#command-list", OptionList)
|
|
||||||
if option_list.option_count > 0:
|
|
||||||
if option_list.highlighted is None:
|
|
||||||
option_list.highlighted = 0
|
|
||||||
elif option_list.highlighted < option_list.option_count - 1:
|
|
||||||
option_list.highlighted += 1
|
|
||||||
@@ -1,58 +0,0 @@
|
|||||||
"""Command suggester for TUI input."""
|
|
||||||
|
|
||||||
from typing import Iterable
|
|
||||||
|
|
||||||
from textual.suggester import Suggester
|
|
||||||
|
|
||||||
from oai.commands import registry
|
|
||||||
|
|
||||||
|
|
||||||
class CommandSuggester(Suggester):
|
|
||||||
"""Suggester that provides command completions."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize the command suggester."""
|
|
||||||
super().__init__(use_cache=False, case_sensitive=False)
|
|
||||||
# Get all command names from registry
|
|
||||||
self._commands = []
|
|
||||||
self._update_commands()
|
|
||||||
|
|
||||||
def _update_commands(self) -> None:
|
|
||||||
"""Update the list of available commands."""
|
|
||||||
# Get all registered command names
|
|
||||||
command_names = registry.get_all_names()
|
|
||||||
# Add common MCP subcommands for better UX
|
|
||||||
mcp_subcommands = [
|
|
||||||
"/mcp on",
|
|
||||||
"/mcp off",
|
|
||||||
"/mcp status",
|
|
||||||
"/mcp add",
|
|
||||||
"/mcp remove",
|
|
||||||
"/mcp list",
|
|
||||||
"/mcp write on",
|
|
||||||
"/mcp write off",
|
|
||||||
"/mcp files",
|
|
||||||
"/mcp db list",
|
|
||||||
]
|
|
||||||
self._commands = command_names + mcp_subcommands
|
|
||||||
|
|
||||||
async def get_suggestion(self, value: str) -> str | None:
|
|
||||||
"""Get a command suggestion based on the current input.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
value: Current input value
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Suggested completion or None
|
|
||||||
"""
|
|
||||||
if not value or not value.startswith("/"):
|
|
||||||
return None
|
|
||||||
|
|
||||||
# Find the first command that starts with the input
|
|
||||||
value_lower = value.lower()
|
|
||||||
for cmd in self._commands:
|
|
||||||
if cmd.lower().startswith(value_lower) and cmd.lower() != value_lower:
|
|
||||||
# Return the rest of the command (after what's already typed)
|
|
||||||
return cmd[len(value):]
|
|
||||||
|
|
||||||
return None
|
|
||||||
@@ -1,39 +0,0 @@
|
|||||||
"""Footer widget for oAI TUI."""
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.widgets import Static
|
|
||||||
|
|
||||||
|
|
||||||
class Footer(Static):
|
|
||||||
"""Footer displaying session metrics."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__()
|
|
||||||
self.tokens_in = 0
|
|
||||||
self.tokens_out = 0
|
|
||||||
self.cost = 0.0
|
|
||||||
self.messages = 0
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the footer."""
|
|
||||||
yield Static(self._format_footer(), id="footer-content")
|
|
||||||
|
|
||||||
def _format_footer(self) -> str:
|
|
||||||
"""Format the footer text."""
|
|
||||||
return (
|
|
||||||
f"[dim]Messages: {self.messages} | "
|
|
||||||
f"Tokens: {self.tokens_in + self.tokens_out:,} "
|
|
||||||
f"({self.tokens_in:,} in, {self.tokens_out:,} out) | "
|
|
||||||
f"Cost: ${self.cost:.4f}[/]"
|
|
||||||
)
|
|
||||||
|
|
||||||
def update_stats(
|
|
||||||
self, tokens_in: int, tokens_out: int, cost: float, messages: int
|
|
||||||
) -> None:
|
|
||||||
"""Update the displayed statistics."""
|
|
||||||
self.tokens_in = tokens_in
|
|
||||||
self.tokens_out = tokens_out
|
|
||||||
self.cost = cost
|
|
||||||
self.messages = messages
|
|
||||||
content = self.query_one("#footer-content", Static)
|
|
||||||
content.update(self._format_footer())
|
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
"""Header widget for oAI TUI."""
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.widgets import Static
|
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
|
|
||||||
|
|
||||||
class Header(Static):
|
|
||||||
"""Header displaying app title, version, current model, and capabilities."""
|
|
||||||
|
|
||||||
def __init__(self, version: str = "3.0.0", model: str = "", model_info: Optional[Dict[str, Any]] = None):
|
|
||||||
super().__init__()
|
|
||||||
self.version = version
|
|
||||||
self.model = model
|
|
||||||
self.model_info = model_info or {}
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the header."""
|
|
||||||
yield Static(self._format_header(), id="header-content")
|
|
||||||
|
|
||||||
def _format_capabilities(self) -> str:
|
|
||||||
"""Format capability icons based on model info."""
|
|
||||||
if not self.model_info:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
icons = []
|
|
||||||
|
|
||||||
# Check vision support
|
|
||||||
architecture = self.model_info.get("architecture", {})
|
|
||||||
modality = architecture.get("modality", "")
|
|
||||||
if "image" in modality:
|
|
||||||
icons.append("[bold cyan]👁️[/]") # Bright if supported
|
|
||||||
else:
|
|
||||||
icons.append("[dim]👁️[/]") # Dim if not supported
|
|
||||||
|
|
||||||
# Check tool support
|
|
||||||
supported_params = self.model_info.get("supported_parameters", [])
|
|
||||||
if "tools" in supported_params or "tool_choice" in supported_params:
|
|
||||||
icons.append("[bold cyan]🔧[/]")
|
|
||||||
else:
|
|
||||||
icons.append("[dim]🔧[/]")
|
|
||||||
|
|
||||||
# Check online support (most models support :online suffix)
|
|
||||||
model_id = self.model_info.get("id", "")
|
|
||||||
if ":online" in model_id or model_id not in ["openrouter/free"]:
|
|
||||||
icons.append("[bold cyan]🌐[/]")
|
|
||||||
else:
|
|
||||||
icons.append("[dim]🌐[/]")
|
|
||||||
|
|
||||||
return " ".join(icons) if icons else ""
|
|
||||||
|
|
||||||
def _format_header(self) -> str:
|
|
||||||
"""Format the header text."""
|
|
||||||
model_text = f" | {self.model}" if self.model else ""
|
|
||||||
capabilities = self._format_capabilities()
|
|
||||||
capabilities_text = f" {capabilities}" if capabilities else ""
|
|
||||||
return f"[bold cyan]oAI[/] [dim]v{self.version}[/]{model_text}{capabilities_text}"
|
|
||||||
|
|
||||||
def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None) -> None:
|
|
||||||
"""Update the displayed model and capabilities."""
|
|
||||||
self.model = model
|
|
||||||
if model_info:
|
|
||||||
self.model_info = model_info
|
|
||||||
content = self.query_one("#header-content", Static)
|
|
||||||
content.update(self._format_header())
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
"""Input bar widget for oAI TUI."""
|
|
||||||
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.containers import Horizontal
|
|
||||||
from textual.widgets import Input, Static
|
|
||||||
|
|
||||||
|
|
||||||
class InputBar(Horizontal):
|
|
||||||
"""Input bar with prompt prefix and text input."""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
super().__init__(id="input-bar")
|
|
||||||
self.mcp_status = ""
|
|
||||||
self.online_mode = False
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the input bar."""
|
|
||||||
yield Static(self._format_prefix(), id="input-prefix", classes="prefix-hidden" if not (self.mcp_status or self.online_mode) else "")
|
|
||||||
yield Input(
|
|
||||||
placeholder="Type a message or /command...",
|
|
||||||
id="chat-input"
|
|
||||||
)
|
|
||||||
|
|
||||||
def _format_prefix(self) -> str:
|
|
||||||
"""Format the input prefix with status indicators."""
|
|
||||||
indicators = []
|
|
||||||
if self.mcp_status:
|
|
||||||
indicators.append(f"[cyan]{self.mcp_status}[/]")
|
|
||||||
if self.online_mode:
|
|
||||||
indicators.append("[green]🌐[/]")
|
|
||||||
|
|
||||||
prefix = " ".join(indicators) + " " if indicators else ""
|
|
||||||
return f"{prefix}[bold]>[/]"
|
|
||||||
|
|
||||||
def update_mcp_status(self, status: str) -> None:
|
|
||||||
"""Update MCP status indicator."""
|
|
||||||
self.mcp_status = status
|
|
||||||
prefix = self.query_one("#input-prefix", Static)
|
|
||||||
prefix.update(self._format_prefix())
|
|
||||||
|
|
||||||
def update_online_mode(self, online: bool) -> None:
|
|
||||||
"""Update online mode indicator."""
|
|
||||||
self.online_mode = online
|
|
||||||
prefix = self.query_one("#input-prefix", Static)
|
|
||||||
prefix.update(self._format_prefix())
|
|
||||||
|
|
||||||
def get_input(self) -> Input:
|
|
||||||
"""Get the input widget."""
|
|
||||||
return self.query_one("#chat-input", Input)
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
"""Message widgets for oAI TUI."""
|
|
||||||
|
|
||||||
from typing import Any, AsyncIterator, Tuple
|
|
||||||
|
|
||||||
from rich.markdown import Markdown
|
|
||||||
from textual.app import ComposeResult
|
|
||||||
from textual.widgets import RichLog, Static
|
|
||||||
|
|
||||||
|
|
||||||
class UserMessageWidget(Static):
|
|
||||||
"""Widget for displaying user messages."""
|
|
||||||
|
|
||||||
def __init__(self, content: str):
|
|
||||||
super().__init__()
|
|
||||||
self.content = content
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the user message."""
|
|
||||||
yield Static(f"[bold green]You:[/] {self.content}")
|
|
||||||
|
|
||||||
|
|
||||||
class SystemMessageWidget(Static):
|
|
||||||
"""Widget for displaying system/info messages without 'You:' prefix."""
|
|
||||||
|
|
||||||
def __init__(self, content: str):
|
|
||||||
super().__init__()
|
|
||||||
self.content = content
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the system message."""
|
|
||||||
yield Static(self.content)
|
|
||||||
|
|
||||||
|
|
||||||
class AssistantMessageWidget(Static):
|
|
||||||
"""Widget for displaying assistant responses with streaming support."""
|
|
||||||
|
|
||||||
def __init__(self, model_name: str = "Assistant"):
|
|
||||||
super().__init__()
|
|
||||||
self.model_name = model_name
|
|
||||||
self.full_text = ""
|
|
||||||
|
|
||||||
def compose(self) -> ComposeResult:
|
|
||||||
"""Compose the assistant message."""
|
|
||||||
yield Static(f"[bold]{self.model_name}:[/]", id="assistant-label")
|
|
||||||
yield RichLog(id="assistant-content", highlight=True, markup=True, wrap=True)
|
|
||||||
|
|
||||||
async def stream_response(self, response_iterator: AsyncIterator) -> Tuple[str, Any]:
|
|
||||||
"""Stream tokens progressively and return final text and usage."""
|
|
||||||
log = self.query_one("#assistant-content", RichLog)
|
|
||||||
self.full_text = ""
|
|
||||||
usage = None
|
|
||||||
|
|
||||||
async for chunk in response_iterator:
|
|
||||||
if hasattr(chunk, "delta_content") and chunk.delta_content:
|
|
||||||
self.full_text += chunk.delta_content
|
|
||||||
log.clear()
|
|
||||||
log.write(Markdown(self.full_text))
|
|
||||||
|
|
||||||
if hasattr(chunk, "usage") and chunk.usage:
|
|
||||||
usage = chunk.usage
|
|
||||||
|
|
||||||
return self.full_text, usage
|
|
||||||
|
|
||||||
def set_content(self, content: str) -> None:
|
|
||||||
"""Set the complete content (non-streaming)."""
|
|
||||||
self.full_text = content
|
|
||||||
log = self.query_one("#assistant-content", RichLog)
|
|
||||||
log.clear()
|
|
||||||
log.write(Markdown(content))
|
|
||||||
@@ -1,20 +0,0 @@
|
|||||||
"""
|
|
||||||
Utility modules for oAI.
|
|
||||||
|
|
||||||
This package provides common utilities used throughout the application
|
|
||||||
including logging, file handling, and export functionality.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from oai.utils.logging import setup_logging, get_logger
|
|
||||||
from oai.utils.files import read_file_safe, is_binary_file
|
|
||||||
from oai.utils.export import export_as_markdown, export_as_json, export_as_html
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"setup_logging",
|
|
||||||
"get_logger",
|
|
||||||
"read_file_safe",
|
|
||||||
"is_binary_file",
|
|
||||||
"export_as_markdown",
|
|
||||||
"export_as_json",
|
|
||||||
"export_as_html",
|
|
||||||
]
|
|
||||||
@@ -1,248 +0,0 @@
|
|||||||
"""
|
|
||||||
Export utilities for oAI.
|
|
||||||
|
|
||||||
This module provides functions for exporting conversation history
|
|
||||||
in various formats including Markdown, JSON, and HTML.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import json
|
|
||||||
import datetime
|
|
||||||
from typing import List, Dict
|
|
||||||
from html import escape as html_escape
|
|
||||||
|
|
||||||
from oai.constants import APP_VERSION, APP_URL
|
|
||||||
|
|
||||||
|
|
||||||
def export_as_markdown(
|
|
||||||
session_history: List[Dict[str, str]],
|
|
||||||
session_system_prompt: str = ""
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Export conversation history as Markdown.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_history: List of message dictionaries with 'prompt' and 'response'
|
|
||||||
session_system_prompt: Optional system prompt to include
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Markdown formatted string
|
|
||||||
"""
|
|
||||||
lines = ["# Conversation Export", ""]
|
|
||||||
|
|
||||||
if session_system_prompt:
|
|
||||||
lines.extend([f"**System Prompt:** {session_system_prompt}", ""])
|
|
||||||
|
|
||||||
lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
|
|
||||||
lines.append(f"**Messages:** {len(session_history)}")
|
|
||||||
lines.append("")
|
|
||||||
lines.append("---")
|
|
||||||
lines.append("")
|
|
||||||
|
|
||||||
for i, entry in enumerate(session_history, 1):
|
|
||||||
lines.append(f"## Message {i}")
|
|
||||||
lines.append("")
|
|
||||||
lines.append("**User:**")
|
|
||||||
lines.append("")
|
|
||||||
lines.append(entry.get("prompt", ""))
|
|
||||||
lines.append("")
|
|
||||||
lines.append("**Assistant:**")
|
|
||||||
lines.append("")
|
|
||||||
lines.append(entry.get("response", ""))
|
|
||||||
lines.append("")
|
|
||||||
lines.append("---")
|
|
||||||
lines.append("")
|
|
||||||
|
|
||||||
lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*")
|
|
||||||
|
|
||||||
return "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
def export_as_json(
|
|
||||||
session_history: List[Dict[str, str]],
|
|
||||||
session_system_prompt: str = ""
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Export conversation history as JSON.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_history: List of message dictionaries
|
|
||||||
session_system_prompt: Optional system prompt to include
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
JSON formatted string
|
|
||||||
"""
|
|
||||||
export_data = {
|
|
||||||
"export_date": datetime.datetime.now().isoformat(),
|
|
||||||
"app_version": APP_VERSION,
|
|
||||||
"system_prompt": session_system_prompt,
|
|
||||||
"message_count": len(session_history),
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"index": i + 1,
|
|
||||||
"prompt": entry.get("prompt", ""),
|
|
||||||
"response": entry.get("response", ""),
|
|
||||||
"prompt_tokens": entry.get("prompt_tokens", 0),
|
|
||||||
"completion_tokens": entry.get("completion_tokens", 0),
|
|
||||||
"cost": entry.get("msg_cost", 0.0),
|
|
||||||
}
|
|
||||||
for i, entry in enumerate(session_history)
|
|
||||||
],
|
|
||||||
"totals": {
|
|
||||||
"prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history),
|
|
||||||
"completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history),
|
|
||||||
"total_cost": sum(e.get("msg_cost", 0.0) for e in session_history),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return json.dumps(export_data, indent=2, ensure_ascii=False)
|
|
||||||
|
|
||||||
|
|
||||||
def export_as_html(
|
|
||||||
session_history: List[Dict[str, str]],
|
|
||||||
session_system_prompt: str = ""
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Export conversation history as styled HTML.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session_history: List of message dictionaries
|
|
||||||
session_system_prompt: Optional system prompt to include
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
HTML formatted string with embedded CSS
|
|
||||||
"""
|
|
||||||
html_parts = [
|
|
||||||
"<!DOCTYPE html>",
|
|
||||||
"<html>",
|
|
||||||
"<head>",
|
|
||||||
" <meta charset='UTF-8'>",
|
|
||||||
" <meta name='viewport' content='width=device-width, initial-scale=1.0'>",
|
|
||||||
" <title>Conversation Export - oAI</title>",
|
|
||||||
" <style>",
|
|
||||||
" * { box-sizing: border-box; }",
|
|
||||||
" body {",
|
|
||||||
" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;",
|
|
||||||
" max-width: 900px;",
|
|
||||||
" margin: 40px auto;",
|
|
||||||
" padding: 20px;",
|
|
||||||
" background: #f5f5f5;",
|
|
||||||
" color: #333;",
|
|
||||||
" }",
|
|
||||||
" .header {",
|
|
||||||
" background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);",
|
|
||||||
" color: white;",
|
|
||||||
" padding: 30px;",
|
|
||||||
" border-radius: 10px;",
|
|
||||||
" margin-bottom: 30px;",
|
|
||||||
" box-shadow: 0 4px 6px rgba(0,0,0,0.1);",
|
|
||||||
" }",
|
|
||||||
" .header h1 {",
|
|
||||||
" margin: 0 0 10px 0;",
|
|
||||||
" font-size: 2em;",
|
|
||||||
" }",
|
|
||||||
" .export-info {",
|
|
||||||
" opacity: 0.9;",
|
|
||||||
" font-size: 0.95em;",
|
|
||||||
" margin: 5px 0;",
|
|
||||||
" }",
|
|
||||||
" .system-prompt {",
|
|
||||||
" background: #fff3cd;",
|
|
||||||
" padding: 20px;",
|
|
||||||
" border-radius: 8px;",
|
|
||||||
" margin-bottom: 25px;",
|
|
||||||
" border-left: 5px solid #ffc107;",
|
|
||||||
" box-shadow: 0 2px 4px rgba(0,0,0,0.05);",
|
|
||||||
" }",
|
|
||||||
" .system-prompt strong {",
|
|
||||||
" color: #856404;",
|
|
||||||
" display: block;",
|
|
||||||
" margin-bottom: 10px;",
|
|
||||||
" font-size: 1.1em;",
|
|
||||||
" }",
|
|
||||||
" .message-container { margin-bottom: 20px; }",
|
|
||||||
" .message {",
|
|
||||||
" background: white;",
|
|
||||||
" padding: 20px;",
|
|
||||||
" border-radius: 8px;",
|
|
||||||
" box-shadow: 0 2px 4px rgba(0,0,0,0.08);",
|
|
||||||
" margin-bottom: 12px;",
|
|
||||||
" }",
|
|
||||||
" .user-message { border-left: 5px solid #10b981; }",
|
|
||||||
" .assistant-message { border-left: 5px solid #3b82f6; }",
|
|
||||||
" .role {",
|
|
||||||
" font-weight: bold;",
|
|
||||||
" margin-bottom: 12px;",
|
|
||||||
" font-size: 1.05em;",
|
|
||||||
" text-transform: uppercase;",
|
|
||||||
" letter-spacing: 0.5px;",
|
|
||||||
" }",
|
|
||||||
" .user-role { color: #10b981; }",
|
|
||||||
" .assistant-role { color: #3b82f6; }",
|
|
||||||
" .content {",
|
|
||||||
" line-height: 1.8;",
|
|
||||||
" white-space: pre-wrap;",
|
|
||||||
" color: #333;",
|
|
||||||
" }",
|
|
||||||
" .message-number {",
|
|
||||||
" color: #6b7280;",
|
|
||||||
" font-size: 0.85em;",
|
|
||||||
" margin-bottom: 15px;",
|
|
||||||
" font-weight: 600;",
|
|
||||||
" }",
|
|
||||||
" .footer {",
|
|
||||||
" text-align: center;",
|
|
||||||
" margin-top: 40px;",
|
|
||||||
" padding: 20px;",
|
|
||||||
" color: #6b7280;",
|
|
||||||
" font-size: 0.9em;",
|
|
||||||
" }",
|
|
||||||
" .footer a { color: #667eea; text-decoration: none; }",
|
|
||||||
" .footer a:hover { text-decoration: underline; }",
|
|
||||||
" @media print {",
|
|
||||||
" body { background: white; }",
|
|
||||||
" .message { break-inside: avoid; }",
|
|
||||||
" }",
|
|
||||||
" </style>",
|
|
||||||
"</head>",
|
|
||||||
"<body>",
|
|
||||||
" <div class='header'>",
|
|
||||||
" <h1>Conversation Export</h1>",
|
|
||||||
f" <div class='export-info'>Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</div>",
|
|
||||||
f" <div class='export-info'>Total Messages: {len(session_history)}</div>",
|
|
||||||
" </div>",
|
|
||||||
]
|
|
||||||
|
|
||||||
if session_system_prompt:
|
|
||||||
html_parts.extend([
|
|
||||||
" <div class='system-prompt'>",
|
|
||||||
" <strong>System Prompt</strong>",
|
|
||||||
f" <div>{html_escape(session_system_prompt)}</div>",
|
|
||||||
" </div>",
|
|
||||||
])
|
|
||||||
|
|
||||||
for i, entry in enumerate(session_history, 1):
|
|
||||||
prompt = html_escape(entry.get("prompt", ""))
|
|
||||||
response = html_escape(entry.get("response", ""))
|
|
||||||
|
|
||||||
html_parts.extend([
|
|
||||||
" <div class='message-container'>",
|
|
||||||
f" <div class='message-number'>Message {i} of {len(session_history)}</div>",
|
|
||||||
" <div class='message user-message'>",
|
|
||||||
" <div class='role user-role'>User</div>",
|
|
||||||
f" <div class='content'>{prompt}</div>",
|
|
||||||
" </div>",
|
|
||||||
" <div class='message assistant-message'>",
|
|
||||||
" <div class='role assistant-role'>Assistant</div>",
|
|
||||||
f" <div class='content'>{response}</div>",
|
|
||||||
" </div>",
|
|
||||||
" </div>",
|
|
||||||
])
|
|
||||||
|
|
||||||
html_parts.extend([
|
|
||||||
" <div class='footer'>",
|
|
||||||
f" <p>Generated by oAI v{APP_VERSION} • <a href='{APP_URL}'>{APP_URL}</a></p>",
|
|
||||||
" </div>",
|
|
||||||
"</body>",
|
|
||||||
"</html>",
|
|
||||||
])
|
|
||||||
|
|
||||||
return "\n".join(html_parts)
|
|
||||||
@@ -1,323 +0,0 @@
|
|||||||
"""
|
|
||||||
File handling utilities for oAI.
|
|
||||||
|
|
||||||
This module provides safe file reading, type detection, and other
|
|
||||||
file-related operations used throughout the application.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
|
||||||
import mimetypes
|
|
||||||
import base64
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional, Dict, Any, Tuple
|
|
||||||
|
|
||||||
from oai.constants import (
|
|
||||||
MAX_FILE_SIZE,
|
|
||||||
CONTENT_TRUNCATION_THRESHOLD,
|
|
||||||
SUPPORTED_CODE_EXTENSIONS,
|
|
||||||
ALLOWED_FILE_EXTENSIONS,
|
|
||||||
)
|
|
||||||
from oai.utils.logging import get_logger
|
|
||||||
|
|
||||||
|
|
||||||
def is_binary_file(file_path: Path) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a file appears to be binary.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the file to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the file appears to be binary, False otherwise
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with open(file_path, "rb") as f:
|
|
||||||
# Read first 8KB to check for binary content
|
|
||||||
chunk = f.read(8192)
|
|
||||||
# Check for null bytes (common in binary files)
|
|
||||||
if b"\x00" in chunk:
|
|
||||||
return True
|
|
||||||
# Try to decode as UTF-8
|
|
||||||
try:
|
|
||||||
chunk.decode("utf-8")
|
|
||||||
return False
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
return True
|
|
||||||
except Exception:
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_type(file_path: Path) -> Tuple[Optional[str], str]:
|
|
||||||
"""
|
|
||||||
Determine the MIME type and category of a file.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple of (mime_type, category) where category is one of:
|
|
||||||
'image', 'pdf', 'code', 'text', 'binary', 'unknown'
|
|
||||||
"""
|
|
||||||
mime_type, _ = mimetypes.guess_type(str(file_path))
|
|
||||||
ext = file_path.suffix.lower()
|
|
||||||
|
|
||||||
if mime_type and mime_type.startswith("image/"):
|
|
||||||
return mime_type, "image"
|
|
||||||
elif mime_type == "application/pdf" or ext == ".pdf":
|
|
||||||
return mime_type or "application/pdf", "pdf"
|
|
||||||
elif ext in SUPPORTED_CODE_EXTENSIONS:
|
|
||||||
return mime_type or "text/plain", "code"
|
|
||||||
elif mime_type and mime_type.startswith("text/"):
|
|
||||||
return mime_type, "text"
|
|
||||||
elif is_binary_file(file_path):
|
|
||||||
return mime_type, "binary"
|
|
||||||
else:
|
|
||||||
return mime_type, "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
def read_file_safe(
|
|
||||||
file_path: Path,
|
|
||||||
max_size: int = MAX_FILE_SIZE,
|
|
||||||
truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Safely read a file with size limits and truncation support.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the file to read
|
|
||||||
max_size: Maximum file size to read (bytes)
|
|
||||||
truncate_threshold: Threshold for truncating large files
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Dictionary containing:
|
|
||||||
- content: File content (text or base64)
|
|
||||||
- size: File size in bytes
|
|
||||||
- truncated: Whether content was truncated
|
|
||||||
- encoding: 'text', 'base64', or None on error
|
|
||||||
- error: Error message if reading failed
|
|
||||||
"""
|
|
||||||
logger = get_logger()
|
|
||||||
|
|
||||||
try:
|
|
||||||
path = Path(file_path).resolve()
|
|
||||||
|
|
||||||
if not path.exists():
|
|
||||||
return {
|
|
||||||
"content": None,
|
|
||||||
"size": 0,
|
|
||||||
"truncated": False,
|
|
||||||
"encoding": None,
|
|
||||||
"error": f"File not found: {path}"
|
|
||||||
}
|
|
||||||
|
|
||||||
if not path.is_file():
|
|
||||||
return {
|
|
||||||
"content": None,
|
|
||||||
"size": 0,
|
|
||||||
"truncated": False,
|
|
||||||
"encoding": None,
|
|
||||||
"error": f"Not a file: {path}"
|
|
||||||
}
|
|
||||||
|
|
||||||
file_size = path.stat().st_size
|
|
||||||
|
|
||||||
if file_size > max_size:
|
|
||||||
return {
|
|
||||||
"content": None,
|
|
||||||
"size": file_size,
|
|
||||||
"truncated": False,
|
|
||||||
"encoding": None,
|
|
||||||
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)"
|
|
||||||
}
|
|
||||||
|
|
||||||
# Try to read as text first
|
|
||||||
try:
|
|
||||||
content = path.read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
# Check if truncation is needed
|
|
||||||
if file_size > truncate_threshold:
|
|
||||||
lines = content.split("\n")
|
|
||||||
total_lines = len(lines)
|
|
||||||
|
|
||||||
# Keep first 500 lines and last 100 lines
|
|
||||||
head_lines = 500
|
|
||||||
tail_lines = 100
|
|
||||||
|
|
||||||
if total_lines > (head_lines + tail_lines):
|
|
||||||
truncated_content = (
|
|
||||||
"\n".join(lines[:head_lines]) +
|
|
||||||
f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" +
|
|
||||||
"\n".join(lines[-tail_lines:])
|
|
||||||
)
|
|
||||||
logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)")
|
|
||||||
return {
|
|
||||||
"content": truncated_content,
|
|
||||||
"size": file_size,
|
|
||||||
"truncated": True,
|
|
||||||
"total_lines": total_lines,
|
|
||||||
"lines_shown": head_lines + tail_lines,
|
|
||||||
"encoding": "text",
|
|
||||||
"error": None
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(f"Read file: {path} ({file_size} bytes)")
|
|
||||||
return {
|
|
||||||
"content": content,
|
|
||||||
"size": file_size,
|
|
||||||
"truncated": False,
|
|
||||||
"encoding": "text",
|
|
||||||
"error": None
|
|
||||||
}
|
|
||||||
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
# File is binary, return base64 encoded
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
binary_data = f.read()
|
|
||||||
b64_content = base64.b64encode(binary_data).decode("utf-8")
|
|
||||||
logger.info(f"Read binary file: {path} ({file_size} bytes)")
|
|
||||||
return {
|
|
||||||
"content": b64_content,
|
|
||||||
"size": file_size,
|
|
||||||
"truncated": False,
|
|
||||||
"encoding": "base64",
|
|
||||||
"error": None
|
|
||||||
}
|
|
||||||
|
|
||||||
except PermissionError as e:
|
|
||||||
return {
|
|
||||||
"content": None,
|
|
||||||
"size": 0,
|
|
||||||
"truncated": False,
|
|
||||||
"encoding": None,
|
|
||||||
"error": f"Permission denied: {e}"
|
|
||||||
}
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error reading file {file_path}: {e}")
|
|
||||||
return {
|
|
||||||
"content": None,
|
|
||||||
"size": 0,
|
|
||||||
"truncated": False,
|
|
||||||
"encoding": None,
|
|
||||||
"error": str(e)
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_file_extension(file_path: Path) -> str:
|
|
||||||
"""
|
|
||||||
Get the lowercase file extension.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Lowercase extension including the dot (e.g., '.py')
|
|
||||||
"""
|
|
||||||
return file_path.suffix.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def is_allowed_extension(file_path: Path) -> bool:
|
|
||||||
"""
|
|
||||||
Check if a file has an allowed extension for attachment.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the file
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if the extension is allowed, False otherwise
|
|
||||||
"""
|
|
||||||
return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS
|
|
||||||
|
|
||||||
|
|
||||||
def format_file_size(size_bytes: int) -> str:
|
|
||||||
"""
|
|
||||||
Format a file size in human-readable format.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
size_bytes: Size in bytes
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Formatted string (e.g., '1.5 MB', '512 KB')
|
|
||||||
"""
|
|
||||||
for unit in ["B", "KB", "MB", "GB", "TB"]:
|
|
||||||
if abs(size_bytes) < 1024:
|
|
||||||
return f"{size_bytes:.1f} {unit}"
|
|
||||||
size_bytes /= 1024
|
|
||||||
return f"{size_bytes:.1f} PB"
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_file_attachment(
|
|
||||||
file_path: Path,
|
|
||||||
model_capabilities: Dict[str, Any]
|
|
||||||
) -> Optional[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
Prepare a file for attachment to an API request.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: Path to the file
|
|
||||||
model_capabilities: Model capability information
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Content block dictionary for the API, or None if unsupported
|
|
||||||
"""
|
|
||||||
logger = get_logger()
|
|
||||||
path = Path(file_path).resolve()
|
|
||||||
|
|
||||||
if not path.exists():
|
|
||||||
logger.warning(f"File not found: {path}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
mime_type, category = get_file_type(path)
|
|
||||||
file_size = path.stat().st_size
|
|
||||||
|
|
||||||
if file_size > MAX_FILE_SIZE:
|
|
||||||
logger.warning(f"File too large: {path} ({format_file_size(file_size)})")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(path, "rb") as f:
|
|
||||||
file_data = f.read()
|
|
||||||
|
|
||||||
if category == "image":
|
|
||||||
# Check if model supports images
|
|
||||||
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
|
|
||||||
if "image" not in input_modalities:
|
|
||||||
logger.warning(f"Model does not support images")
|
|
||||||
return None
|
|
||||||
|
|
||||||
b64_data = base64.b64encode(file_data).decode("utf-8")
|
|
||||||
return {
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": f"data:{mime_type};base64,{b64_data}"}
|
|
||||||
}
|
|
||||||
|
|
||||||
elif category == "pdf":
|
|
||||||
# Check if model supports PDFs
|
|
||||||
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
|
|
||||||
supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
|
|
||||||
if not supports_pdf:
|
|
||||||
logger.warning(f"Model does not support PDFs")
|
|
||||||
return None
|
|
||||||
|
|
||||||
b64_data = base64.b64encode(file_data).decode("utf-8")
|
|
||||||
return {
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {"url": f"data:application/pdf;base64,{b64_data}"}
|
|
||||||
}
|
|
||||||
|
|
||||||
elif category in ("code", "text"):
|
|
||||||
text_content = file_data.decode("utf-8")
|
|
||||||
return {
|
|
||||||
"type": "text",
|
|
||||||
"text": f"File: {path.name}\n\n{text_content}"
|
|
||||||
}
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.warning(f"Unsupported file type: {category} ({mime_type})")
|
|
||||||
return None
|
|
||||||
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
logger.error(f"Cannot decode file as UTF-8: {path}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Error preparing file attachment {path}: {e}")
|
|
||||||
return None
|
|
||||||
@@ -1,297 +0,0 @@
|
|||||||
"""
|
|
||||||
Logging configuration for oAI.
|
|
||||||
|
|
||||||
This module provides centralized logging setup with Rich formatting,
|
|
||||||
file rotation, and configurable log levels.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
import glob
|
|
||||||
import logging
|
|
||||||
import datetime
|
|
||||||
import shutil
|
|
||||||
from logging.handlers import RotatingFileHandler
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.logging import RichHandler
|
|
||||||
|
|
||||||
from oai.constants import (
|
|
||||||
LOG_FILE,
|
|
||||||
CONFIG_DIR,
|
|
||||||
DEFAULT_LOG_MAX_SIZE_MB,
|
|
||||||
DEFAULT_LOG_BACKUP_COUNT,
|
|
||||||
DEFAULT_LOG_LEVEL,
|
|
||||||
VALID_LOG_LEVELS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RotatingRichHandler(RotatingFileHandler):
|
|
||||||
"""
|
|
||||||
Custom log handler combining file rotation with Rich formatting.
|
|
||||||
|
|
||||||
This handler writes Rich-formatted log output to a rotating file,
|
|
||||||
providing colored and formatted logs even in file output while
|
|
||||||
managing file size and backups automatically.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
|
||||||
"""Initialize the handler with Rich console for formatting."""
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
# Create an internal console for Rich formatting
|
|
||||||
self.rich_console = Console(
|
|
||||||
file=io.StringIO(),
|
|
||||||
width=120,
|
|
||||||
force_terminal=False
|
|
||||||
)
|
|
||||||
self.rich_handler = RichHandler(
|
|
||||||
console=self.rich_console,
|
|
||||||
show_time=True,
|
|
||||||
show_path=True,
|
|
||||||
rich_tracebacks=True,
|
|
||||||
tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"]
|
|
||||||
)
|
|
||||||
|
|
||||||
def emit(self, record: logging.LogRecord) -> None:
|
|
||||||
"""
|
|
||||||
Emit a log record with Rich formatting.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
record: The log record to emit
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# Format with Rich
|
|
||||||
self.rich_handler.emit(record)
|
|
||||||
output = self.rich_console.file.getvalue()
|
|
||||||
self.rich_console.file.seek(0)
|
|
||||||
self.rich_console.file.truncate(0)
|
|
||||||
|
|
||||||
if output:
|
|
||||||
self.stream.write(output)
|
|
||||||
self.flush()
|
|
||||||
except Exception:
|
|
||||||
self.handleError(record)
|
|
||||||
|
|
||||||
|
|
||||||
class LoggingManager:
|
|
||||||
"""
|
|
||||||
Manages application logging configuration.
|
|
||||||
|
|
||||||
Provides methods to setup, configure, and manage logging with
|
|
||||||
support for runtime reconfiguration and level changes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
"""Initialize the logging manager."""
|
|
||||||
self.handler: Optional[RotatingRichHandler] = None
|
|
||||||
self.app_logger: Optional[logging.Logger] = None
|
|
||||||
self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
|
|
||||||
self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT
|
|
||||||
self.log_level: str = DEFAULT_LOG_LEVEL
|
|
||||||
|
|
||||||
def setup(
|
|
||||||
self,
|
|
||||||
max_size_mb: Optional[int] = None,
|
|
||||||
backup_count: Optional[int] = None,
|
|
||||||
log_level: Optional[str] = None
|
|
||||||
) -> logging.Logger:
|
|
||||||
"""
|
|
||||||
Setup or reconfigure logging.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_size_mb: Maximum log file size in MB
|
|
||||||
backup_count: Number of backup files to keep
|
|
||||||
log_level: Logging level string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The configured application logger
|
|
||||||
"""
|
|
||||||
# Update configuration if provided
|
|
||||||
if max_size_mb is not None:
|
|
||||||
self.max_size_mb = max_size_mb
|
|
||||||
if backup_count is not None:
|
|
||||||
self.backup_count = backup_count
|
|
||||||
if log_level is not None:
|
|
||||||
self.log_level = log_level
|
|
||||||
|
|
||||||
# Ensure config directory exists
|
|
||||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
# Get root logger
|
|
||||||
root_logger = logging.getLogger()
|
|
||||||
|
|
||||||
# Remove existing handler if present
|
|
||||||
if self.handler is not None:
|
|
||||||
root_logger.removeHandler(self.handler)
|
|
||||||
try:
|
|
||||||
self.handler.close()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Check if log needs manual rotation
|
|
||||||
self._check_rotation()
|
|
||||||
|
|
||||||
# Create new handler
|
|
||||||
max_bytes = self.max_size_mb * 1024 * 1024
|
|
||||||
self.handler = RotatingRichHandler(
|
|
||||||
filename=str(LOG_FILE),
|
|
||||||
maxBytes=max_bytes,
|
|
||||||
backupCount=self.backup_count,
|
|
||||||
encoding="utf-8"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.handler.setLevel(logging.NOTSET)
|
|
||||||
root_logger.setLevel(logging.WARNING)
|
|
||||||
root_logger.addHandler(self.handler)
|
|
||||||
|
|
||||||
# Suppress noisy third-party loggers
|
|
||||||
for logger_name in [
|
|
||||||
"asyncio", "urllib3", "requests", "httpx",
|
|
||||||
"httpcore", "openai", "openrouter"
|
|
||||||
]:
|
|
||||||
logging.getLogger(logger_name).setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
# Configure application logger
|
|
||||||
self.app_logger = logging.getLogger("oai_app")
|
|
||||||
level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO)
|
|
||||||
self.app_logger.setLevel(level)
|
|
||||||
self.app_logger.propagate = True
|
|
||||||
|
|
||||||
return self.app_logger
|
|
||||||
|
|
||||||
def _check_rotation(self) -> None:
|
|
||||||
"""Check if log file needs rotation and perform if necessary."""
|
|
||||||
if not LOG_FILE.exists():
|
|
||||||
return
|
|
||||||
|
|
||||||
current_size = LOG_FILE.stat().st_size
|
|
||||||
max_bytes = self.max_size_mb * 1024 * 1024
|
|
||||||
|
|
||||||
if current_size >= max_bytes:
|
|
||||||
# Perform manual rotation
|
|
||||||
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
|
||||||
backup_file = f"{LOG_FILE}.{timestamp}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
shutil.move(str(LOG_FILE), backup_file)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Clean old backups
|
|
||||||
self._cleanup_old_backups()
|
|
||||||
|
|
||||||
def _cleanup_old_backups(self) -> None:
|
|
||||||
"""Remove old backup files exceeding the backup count."""
|
|
||||||
log_dir = LOG_FILE.parent
|
|
||||||
backup_pattern = f"{LOG_FILE.name}.*"
|
|
||||||
backups = sorted(glob.glob(str(log_dir / backup_pattern)))
|
|
||||||
|
|
||||||
while len(backups) > self.backup_count:
|
|
||||||
oldest = backups.pop(0)
|
|
||||||
try:
|
|
||||||
os.remove(oldest)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
def set_level(self, level: str) -> bool:
|
|
||||||
"""
|
|
||||||
Set the application log level.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
level: Log level string (debug/info/warning/error/critical)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if level was set successfully, False otherwise
|
|
||||||
"""
|
|
||||||
level_lower = level.lower()
|
|
||||||
if level_lower not in VALID_LOG_LEVELS:
|
|
||||||
return False
|
|
||||||
|
|
||||||
self.log_level = level_lower
|
|
||||||
if self.app_logger:
|
|
||||||
self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower])
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def get_logger(self) -> logging.Logger:
|
|
||||||
"""
|
|
||||||
Get the application logger, initializing if necessary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The application logger
|
|
||||||
"""
|
|
||||||
if self.app_logger is None:
|
|
||||||
self.setup()
|
|
||||||
return self.app_logger
|
|
||||||
|
|
||||||
|
|
||||||
# Global logging manager instance
|
|
||||||
_logging_manager = LoggingManager()
|
|
||||||
|
|
||||||
|
|
||||||
def setup_logging(
|
|
||||||
max_size_mb: Optional[int] = None,
|
|
||||||
backup_count: Optional[int] = None,
|
|
||||||
log_level: Optional[str] = None
|
|
||||||
) -> logging.Logger:
|
|
||||||
"""
|
|
||||||
Setup application logging.
|
|
||||||
|
|
||||||
This is the main entry point for configuring logging. Call this
|
|
||||||
early in application startup.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_size_mb: Maximum log file size in MB
|
|
||||||
backup_count: Number of backup files to keep
|
|
||||||
log_level: Logging level string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The configured application logger
|
|
||||||
"""
|
|
||||||
return _logging_manager.setup(max_size_mb, backup_count, log_level)
|
|
||||||
|
|
||||||
|
|
||||||
def get_logger() -> logging.Logger:
|
|
||||||
"""
|
|
||||||
Get the application logger.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The application logger instance
|
|
||||||
"""
|
|
||||||
return _logging_manager.get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def set_log_level(level: str) -> bool:
|
|
||||||
"""
|
|
||||||
Set the application log level.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
level: Log level string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if successful, False otherwise
|
|
||||||
"""
|
|
||||||
return _logging_manager.set_level(level)
|
|
||||||
|
|
||||||
|
|
||||||
def reload_logging(
|
|
||||||
max_size_mb: Optional[int] = None,
|
|
||||||
backup_count: Optional[int] = None,
|
|
||||||
log_level: Optional[str] = None
|
|
||||||
) -> logging.Logger:
|
|
||||||
"""
|
|
||||||
Reload logging configuration.
|
|
||||||
|
|
||||||
Useful when settings change at runtime.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
max_size_mb: New maximum log file size
|
|
||||||
backup_count: New backup count
|
|
||||||
log_level: New log level
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The reconfigured logger
|
|
||||||
"""
|
|
||||||
return _logging_manager.setup(max_size_mb, backup_count, log_level)
|
|
||||||
134
pyproject.toml
134
pyproject.toml
@@ -1,134 +0,0 @@
|
|||||||
[build-system]
|
|
||||||
requires = ["setuptools>=61.0", "wheel"]
|
|
||||||
build-backend = "setuptools.build_meta"
|
|
||||||
|
|
||||||
[project]
|
|
||||||
name = "oai"
|
|
||||||
version = "3.0.0"
|
|
||||||
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
|
|
||||||
readme = "README.md"
|
|
||||||
license = {text = "MIT"}
|
|
||||||
authors = [
|
|
||||||
{name = "Rune", email = "rune@example.com"}
|
|
||||||
]
|
|
||||||
maintainers = [
|
|
||||||
{name = "Rune", email = "rune@example.com"}
|
|
||||||
]
|
|
||||||
keywords = [
|
|
||||||
"ai",
|
|
||||||
"chat",
|
|
||||||
"openrouter",
|
|
||||||
"cli",
|
|
||||||
"terminal",
|
|
||||||
"mcp",
|
|
||||||
"llm",
|
|
||||||
]
|
|
||||||
classifiers = [
|
|
||||||
"Development Status :: 4 - Beta",
|
|
||||||
"Environment :: Console",
|
|
||||||
"Intended Audience :: Developers",
|
|
||||||
"License :: OSI Approved :: MIT License",
|
|
||||||
"Operating System :: OS Independent",
|
|
||||||
"Programming Language :: Python :: 3",
|
|
||||||
"Programming Language :: Python :: 3.10",
|
|
||||||
"Programming Language :: Python :: 3.11",
|
|
||||||
"Programming Language :: Python :: 3.12",
|
|
||||||
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
||||||
"Topic :: Utilities",
|
|
||||||
]
|
|
||||||
requires-python = ">=3.10"
|
|
||||||
dependencies = [
|
|
||||||
"anyio>=4.0.0",
|
|
||||||
"click>=8.0.0",
|
|
||||||
"httpx>=0.24.0",
|
|
||||||
"markdown-it-py>=3.0.0",
|
|
||||||
"openrouter>=0.0.19",
|
|
||||||
"packaging>=21.0",
|
|
||||||
"pyperclip>=1.8.0",
|
|
||||||
"requests>=2.28.0",
|
|
||||||
"rich>=13.0.0",
|
|
||||||
"textual>=0.50.0",
|
|
||||||
"typer>=0.9.0",
|
|
||||||
"mcp>=1.0.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.optional-dependencies]
|
|
||||||
dev = [
|
|
||||||
"pytest>=7.0.0",
|
|
||||||
"pytest-asyncio>=0.21.0",
|
|
||||||
"pytest-cov>=4.0.0",
|
|
||||||
"black>=23.0.0",
|
|
||||||
"isort>=5.12.0",
|
|
||||||
"mypy>=1.0.0",
|
|
||||||
"ruff>=0.1.0",
|
|
||||||
]
|
|
||||||
|
|
||||||
[project.urls]
|
|
||||||
Homepage = "https://iurl.no/oai"
|
|
||||||
Repository = "https://gitlab.pm/rune/oai"
|
|
||||||
Documentation = "https://iurl.no/oai"
|
|
||||||
"Bug Tracker" = "https://gitlab.pm/rune/oai/issues"
|
|
||||||
|
|
||||||
[project.scripts]
|
|
||||||
oai = "oai.cli:main"
|
|
||||||
|
|
||||||
[tool.setuptools]
|
|
||||||
packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.tui", "oai.tui.widgets", "oai.tui.screens", "oai.utils"]
|
|
||||||
|
|
||||||
[tool.setuptools.package-data]
|
|
||||||
oai = ["py.typed"]
|
|
||||||
|
|
||||||
[tool.black]
|
|
||||||
line-length = 100
|
|
||||||
target-version = ["py310", "py311", "py312"]
|
|
||||||
include = '\.pyi?$'
|
|
||||||
exclude = '''
|
|
||||||
/(
|
|
||||||
\.git
|
|
||||||
| \.mypy_cache
|
|
||||||
| \.pytest_cache
|
|
||||||
| \.venv
|
|
||||||
| build
|
|
||||||
| dist
|
|
||||||
)/
|
|
||||||
'''
|
|
||||||
|
|
||||||
[tool.isort]
|
|
||||||
profile = "black"
|
|
||||||
line_length = 100
|
|
||||||
skip_gitignore = true
|
|
||||||
|
|
||||||
[tool.mypy]
|
|
||||||
python_version = "3.10"
|
|
||||||
warn_return_any = true
|
|
||||||
warn_unused_configs = true
|
|
||||||
ignore_missing_imports = true
|
|
||||||
exclude = [
|
|
||||||
"build",
|
|
||||||
"dist",
|
|
||||||
".venv",
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.ruff]
|
|
||||||
line-length = 100
|
|
||||||
target-version = "py310"
|
|
||||||
select = [
|
|
||||||
"E", # pycodestyle errors
|
|
||||||
"W", # pycodestyle warnings
|
|
||||||
"F", # Pyflakes
|
|
||||||
"I", # isort
|
|
||||||
"B", # flake8-bugbear
|
|
||||||
"C4", # flake8-comprehensions
|
|
||||||
"UP", # pyupgrade
|
|
||||||
]
|
|
||||||
ignore = [
|
|
||||||
"E501", # line too long (handled by black)
|
|
||||||
"B008", # do not perform function calls in argument defaults
|
|
||||||
"C901", # too complex
|
|
||||||
]
|
|
||||||
|
|
||||||
[tool.pytest.ini_options]
|
|
||||||
testpaths = ["tests"]
|
|
||||||
python_files = ["test_*.py"]
|
|
||||||
asyncio_mode = "auto"
|
|
||||||
addopts = "-v --tb=short"
|
|
||||||
Reference in New Issue
Block a user