3 Commits
2.1-RC2 ... 3.0

Author SHA1 Message Date
1191fa6d19 Updated README 2026-02-04 11:26:08 +01:00
6298158d3c oAI version 3.0 beta 1 2026-02-04 11:22:53 +01:00
b0cf88704e 2.1 (#2)
Final release of version 2.1.

Headlights:

### Core Features
- 🤖 Interactive chat with 300+ AI models via OpenRouter
- 🔍 Model selection with search and filtering
- 💾 Conversation save/load/export (Markdown, JSON, HTML)
- 📎 File attachments (images, PDFs, code files)
- 💰 Real-time cost tracking and credit monitoring
- 🎨 Rich terminal UI with syntax highlighting
- 📝 Persistent command history with search (Ctrl+R)
- 🌐 Online mode (web search capabilities)
- 🧠 Conversation memory toggle

### MCP Integration
- 🔧 **File Mode**: AI can read, search, and list local files
  - Automatic .gitignore filtering
  - Virtual environment exclusion
  - Large file handling (auto-truncates >50KB)

- ✍️ **Write Mode**: AI can modify files with permission
  - Create, edit, delete files
  - Move, copy, organize files
  - Always requires explicit opt-in

- 🗄️ **Database Mode**: AI can query SQLite databases
  - Read-only access (safe)
  - Schema inspection
  - Full SQL query support

Reviewed-on: #2
Co-authored-by: Rune Olsen <rune@rune.pm>
Co-committed-by: Rune Olsen <rune@rune.pm>
2026-02-03 09:02:44 +01:00
50 changed files with 13224 additions and 6464 deletions

7
.gitignore vendored
View File

@@ -23,6 +23,9 @@ Pipfile.lock # Consider if you want to include or exclude
*~.nib *~.nib
*~.xib *~.xib
# Claude Code local settings
.claude/
# Added by author # Added by author
*.zip *.zip
.note .note
@@ -39,3 +42,7 @@ b0.sh
*.old *.old
*.sh *.sh
*.back *.back
requirements.txt
system_prompt.txt
CLAUDE*
SESSION*_COMPLETE.md

696
README.md
View File

@@ -1,584 +1,326 @@
# oAI - OpenRouter AI Chat # oAI - OpenRouter AI Chat Client
A powerful terminal-based chat interface for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI agents to access local files and query SQLite databases directly. A powerful, modern **Textual TUI** chat client for OpenRouter API with **MCP (Model Context Protocol)** support, enabling AI to access local files and query SQLite databases.
## Description
oAI is a feature-rich command-line chat application that provides an interactive interface to OpenRouter's AI models. It now includes **MCP integration** for local file system access and read-only database querying, allowing AI to help with code analysis, data exploration, and more.
## Features ## Features
### Core Features ### Core Features
- 🖥️ **Modern Textual TUI** with async streaming and beautiful interface
- 🤖 Interactive chat with 300+ AI models via OpenRouter - 🤖 Interactive chat with 300+ AI models via OpenRouter
- 🔍 Model selection with search and capability filtering - 🔍 Model selection with search, filtering, and capability icons
- 💾 Conversation save/load/export (Markdown, JSON, HTML) - 💾 Conversation save/load/export (Markdown, JSON, HTML)
- 📎 File attachment support (images, PDFs, code files) - 📎 File attachments (images, PDFs, code files)
- 💰 Session cost tracking and credit monitoring - 💰 Real-time cost tracking and credit monitoring
- 🎨 Rich terminal formatting with syntax highlighting - 🎨 Dark theme with syntax highlighting and Markdown rendering
- 📝 Persistent command history with search (Ctrl+R) - 📝 Command history navigation (Up/Down arrows)
- ⚙️ Configurable system prompts and token limits
- 🗄️ SQLite-based configuration and conversation storage
- 🌐 Online mode (web search capabilities) - 🌐 Online mode (web search capabilities)
- 🧠 Conversation memory toggle (save costs with stateless mode) - 🧠 Conversation memory toggle
- ⌨️ Keyboard shortcuts (F1=Help, F2=Models, Ctrl+S=Stats)
### NEW: MCP (Model Context Protocol) v2.1.0-beta ### MCP Integration
- 🔧 **File Mode**: AI can read, search, and list your local files - 🔧 **File Mode**: AI can read, search, and list local files
- Automatic .gitignore filtering - Automatic .gitignore filtering
- Virtual environment exclusion (venv, node_modules, etc.) - Virtual environment exclusion
- Supports code files, text, JSON, YAML, and more
- Large file handling (auto-truncates >50KB) - Large file handling (auto-truncates >50KB)
- ✍️ **Write Mode** (NEW!): AI can modify files with your permission
- Create and edit files within allowed folders
- Delete files (always requires confirmation)
- Move, copy, and organize files
- Create directories
- Ignores .gitignore for write operations
- OFF by default - explicit opt-in required
- 🗄️ **Database Mode**: AI can query your SQLite databases
- Read-only access (no data modification possible)
- Schema inspection (tables, columns, indexes)
- Full-text search across all tables
- SQL query execution (SELECT, JOINs, CTEs, subqueries)
- Query validation and timeout protection
- Result limiting (max 1000 rows)
- 🔒 **Security Features**: - ✍️ **Write Mode**: AI can modify files with permission
- Explicit folder/database approval required - Create, edit, delete files
- System directory blocking - Move, copy, organize files
- Write mode OFF by default (non-persistent) - Always requires explicit opt-in
- Delete operations always require user confirmation
- Read-only database access - 🗄️ **Database Mode**: AI can query SQLite databases
- SQL injection protection - Read-only access (safe)
- Query timeout (5 seconds) - Schema inspection
- Full SQL query support
## Requirements ## Requirements
- Python 3.10-3.13 (3.14 not supported yet) - Python 3.10-3.13
- OpenRouter API key (get one at https://openrouter.ai) - OpenRouter API key ([get one here](https://openrouter.ai))
- Function-calling model required for MCP features (GPT-4, Claude, etc.)
## Screenshot
[<img src="https://gitlab.pm/rune/oai/raw/branch/main/images/screenshot_01.png">](https://gitlab.pm/rune/oai/src/branch/main/README.md)
*Screenshot from version 1.0 - MCP interface shows mode indicators like `[🔧 MCP: Files]` or `[🗄️ MCP: DB #1]`*
## Installation ## Installation
### Option 1: From Source (Recommended for Development) ### Option 1: Pre-built Binary (macOS/Linux) (Recommended)
#### 1. Install Dependencies Download from [Releases](https://gitlab.pm/rune/oai/releases):
- **macOS (Apple Silicon)**: `oai_v3.0.0_mac_arm64.zip`
```bash - **Linux (x86_64)**: `oai_v3.0.0_linux_x86_64.zip`
pip install -r requirements.txt
```
#### 2. Make Executable
```bash
chmod +x oai.py
```
#### 3. Copy to PATH
```bash
# Option 1: System-wide (requires sudo)
sudo cp oai.py /usr/local/bin/oai
# Option 2: User-local (recommended)
mkdir -p ~/.local/bin
cp oai.py ~/.local/bin/oai
# Add to PATH if needed (add to ~/.bashrc or ~/.zshrc)
export PATH="$HOME/.local/bin:$PATH"
```
#### 4. Verify Installation
```bash
oai --version
```
### Option 2: Pre-built Binaries
Download platform-specific binaries:
- **macOS (Apple Silicon)**: `oai_vx.x.x_mac_arm64.zip`
- **Linux (x86_64)**: `oai_vx.x.x-linux-x86_64.zip`
```bash ```bash
# Extract and install # Extract and install
unzip oai_vx.x.x_mac_arm64.zip # or `oai_vx.x.x-linux-x86_64.zip` unzip oai_v3.0.0_*.zip
chmod +x oai mkdir -p ~/.local/bin
mkdir -p ~/.local/bin # Remember to add this to your path. Or just move to folder already in your $PATH
mv oai ~/.local/bin/ mv oai ~/.local/bin/
# macOS only: Remove quarantine and approve
xattr -cr ~/.local/bin/oai
# Then right-click oai in Finder → Open With → Terminal → Click "Open"
``` ```
### Add to PATH
### Alternative: Shell Alias
```bash ```bash
# Add to ~/.bashrc or ~/.zshrc # Add to ~/.zshrc or ~/.bashrc
alias oai='python3 /path/to/oai.py' export PATH="$HOME/.local/bin:$PATH"
``` ```
### Option 2: Install from Source
```bash
# Clone the repository
git clone https://gitlab.pm/rune/oai.git
cd oai
# Install with pip
pip install -e .
```
## Quick Start ## Quick Start
### First Run Setup
```bash ```bash
oai # Start oAI (launches TUI)
```
On first run, you'll be prompted to enter your OpenRouter API key.
### Basic Usage
```bash
# Start chatting
oai oai
# Select a model # Or with options
You> /model oai --model gpt-4o --online --mcp
# Enable MCP for file access # Show version
You> /mcp on oai version
You> /mcp add ~/Documents
# Ask AI to help with files (read-only)
[🔧 MCP: Files] You> List all Python files in Documents
[🔧 MCP: Files] You> Read and explain main.py
# Enable write mode to let AI modify files
You> /mcp write on
[🔧✍️ MCP: Files+Write] You> Create a new Python file with helper functions
[🔧✍️ MCP: Files+Write] You> Refactor main.py to use async/await
# Switch to database mode
You> /mcp add db ~/myapp/data.db
You> /mcp db 1
[🗄️ MCP: DB #1] You> Show me all tables
[🗄️ MCP: DB #1] You> Find all users created this month
``` ```
## MCP Guide On first run, you'll be prompted for your OpenRouter API key.
### File Mode (Default) ### Basic Commands
**Setup:**
```bash ```bash
/mcp on # Start MCP server # In the TUI interface:
/mcp add ~/Projects # Grant access to folder /model # Select AI model (or press F2)
/mcp add ~/Documents # Add another folder /help # Show all commands (or press F1)
/mcp list # View all allowed folders /mcp on # Enable file/database access
/stats # View session statistics (or press Ctrl+S)
/config # View configuration settings
/credits # Check account credits
Ctrl+Q # Quit
``` ```
**Natural Language Usage:** ## MCP (Model Context Protocol)
```
MCP allows the AI to interact with your local files and databases.
### File Access
```bash
/mcp on # Enable MCP
/mcp add ~/Projects # Grant access to folder
/mcp list # View allowed folders
# Now ask the AI:
"List all Python files in Projects" "List all Python files in Projects"
"Read and explain config.yaml" "Read and explain main.py"
"Search for files containing 'TODO'" "Search for files containing 'TODO'"
"What's in my Documents folder?"
``` ```
**Available Tools (Read-Only):**
- `read_file` - Read complete file contents
- `list_directory` - List files/folders (recursive optional)
- `search_files` - Search by name or content
**Available Tools (Write Mode - requires `/mcp write on`):**
- `write_file` - Create new files or overwrite existing ones
- `edit_file` - Find and replace text in existing files
- `delete_file` - Delete files (always requires confirmation)
- `create_directory` - Create directories
- `move_file` - Move or rename files
- `copy_file` - Copy files to new locations
**Features:**
- ✅ Automatic .gitignore filtering (read operations only)
- ✅ Skips virtual environments (venv, node_modules)
- ✅ Handles large files (auto-truncates >50KB)
- ✅ Cross-platform (macOS, Linux, Windows via WSL)
- ✅ Write mode OFF by default for safety
- ✅ Delete operations require user confirmation with LLM's reason
### Database Mode
**Setup:**
```bash
/mcp add db ~/app/database.db # Add SQLite database
/mcp db list # View all databases
/mcp db 1 # Switch to database #1
```
**Natural Language Usage:**
```
"Show me all tables in this database"
"Find records mentioning 'error'"
"How many users registered last week?"
"Get the schema for the orders table"
"Show me the 10 most recent transactions"
```
**Available Tools:**
- `inspect_database` - View schema, tables, columns, indexes
- `search_database` - Full-text search across tables
- `query_database` - Execute read-only SQL queries
**Supported Queries:**
- ✅ SELECT statements
- ✅ JOINs (INNER, LEFT, RIGHT, FULL)
- ✅ Subqueries
- ✅ CTEs (Common Table Expressions)
- ✅ Aggregations (COUNT, SUM, AVG, etc.)
- ✅ WHERE, GROUP BY, HAVING, ORDER BY, LIMIT
- ❌ INSERT/UPDATE/DELETE (blocked for safety)
### Write Mode ### Write Mode
**Enable Write Mode:**
```bash ```bash
/mcp write on # Enable write mode (shows warning, requires confirmation) /mcp write on # Enable file modifications
# AI can now:
"Create a new file called utils.py"
"Edit config.json and update the API URL"
"Delete the old backup files" # Always asks for confirmation
``` ```
**Natural Language Usage:** ### Database Mode
```
"Create a new Python file called utils.py with helper functions"
"Edit main.py and replace the old API endpoint with the new one"
"Delete the backup.old file" (will prompt for confirmation)
"Create a directory called tests"
"Move config.json to the config folder"
```
**Important:**
- ⚠️ Write mode is **OFF by default** and resets each session
- ⚠️ Delete operations **always** require user confirmation
- ⚠️ All operations are limited to allowed MCP folders
- ✅ Write operations ignore .gitignore (can write to any file in allowed folders)
**Disable Write Mode:**
```bash
/mcp write off # Disable write mode (back to read-only)
```
### Mode Management
```bash ```bash
/mcp status # Show current mode, write mode, stats, folders/databases /mcp add db ~/app/data.db # Add database
/mcp files # Switch to file mode /mcp db 1 # Switch to database mode
/mcp db <number> # Switch to database mode
/mcp gitignore on # Enable .gitignore filtering (default) # Ask the AI:
/mcp write on|off # Enable/disable write mode "Show all tables"
/mcp remove 2 # Remove folder/database by number "Find users created this month"
"What's the schema for the orders table?"
``` ```
## Command Reference ## Command Reference
### Session Commands ### Chat Commands
``` | Command | Description |
/help [command] Show help menu or detailed command help |---------|-------------|
/help mcp Comprehensive MCP guide | `/help [cmd]` | Show help |
/clear or /cl Clear terminal screen (or Ctrl+L) | `/model [search]` | Select model |
/memory on|off Toggle conversation memory (save costs) | `/info [model]` | Model details |
/online on|off Enable/disable web search | `/memory on\|off` | Toggle context |
/paste [prompt] Paste clipboard content | `/online on\|off` | Toggle web search |
/retry Resend last prompt | `/retry` | Resend last message |
/reset Clear history and system prompt | `/clear` | Clear screen |
/prev View previous response
/next View next response
```
### MCP Commands ### MCP Commands
``` | Command | Description |
/mcp on Start MCP server |---------|-------------|
/mcp off Stop MCP server | `/mcp on\|off` | Enable/disable MCP |
/mcp status Show comprehensive status (includes write mode) | `/mcp status` | Show MCP status |
/mcp add <folder> Add folder for file access | `/mcp add <path>` | Add folder |
/mcp add db <path> Add SQLite database | `/mcp add db <path>` | Add database |
/mcp list List all folders | `/mcp list` | List folders |
/mcp db list List all databases | `/mcp db list` | List databases |
/mcp db <number> Switch to database mode | `/mcp db <n>` | Switch to database |
/mcp files Switch to file mode | `/mcp files` | Switch to file mode |
/mcp remove <num> Remove folder/database | `/mcp write on\|off` | Toggle write mode |
/mcp gitignore on Enable .gitignore filtering
/mcp write on Enable write mode (create/edit/delete files)
/mcp write off Disable write mode (read-only)
```
### Model Commands ### Conversation Commands
``` | Command | Description |
/model [search] Select/change AI model |---------|-------------|
/info [model_id] Show model details (pricing, capabilities) | `/save <name>` | Save conversation |
``` | `/load <name>` | Load conversation |
| `/list` | List saved conversations |
| `/delete <name>` | Delete conversation |
| `/export md\|json\|html <file>` | Export |
### Configuration ### Configuration
``` | Command | Description |
/config View all settings |---------|-------------|
/config api Set API key | `/config` | View settings |
/config model Set default model | `/config api` | Set API key |
/config online Set default online mode (on|off) | `/config model <id>` | Set default model |
/config stream Enable/disable streaming (on|off) | `/config stream on\|off` | Toggle streaming |
/config maxtoken Set max token limit | `/stats` | Session statistics |
/config costwarning Set cost warning threshold ($) | `/credits` | Check credits |
/config loglevel Set log level (debug/info/warning/error)
/config log Set log file size (MB)
```
### Conversation Management ## CLI Options
```
/save <name> Save conversation
/load <name|num> Load saved conversation
/delete <name|num> Delete conversation
/list List saved conversations
/export md|json|html <file> Export conversation
```
### Token & System
```
/maxtoken [value] Set session token limit
/system [prompt] Set system prompt (use 'clear' to reset)
/middleout on|off Enable prompt compression
```
### Monitoring
```
/stats View session statistics
/credits Check OpenRouter credits
```
### File Attachments
```
@/path/to/file Attach file (images, PDFs, code)
Examples:
Debug @script.py
Analyze @data.json
Review @screenshot.png
```
## Configuration Options
All configuration stored in `~/.config/oai/`:
### Files
- `oai_config.db` - SQLite database (settings, conversations, MCP config)
- `oai.log` - Application logs (rotating, configurable size)
- `history.txt` - Command history (searchable with Ctrl+R)
### Key Settings
- **API Key**: OpenRouter authentication
- **Default Model**: Auto-select on startup
- **Streaming**: Real-time response display
- **Max Tokens**: Global and session limits
- **Cost Warning**: Alert threshold for expensive requests
- **Online Mode**: Default web search setting
- **Log Level**: debug/info/warning/error/critical
- **Log Size**: Rotating file size in MB
## Supported File Types
### Code Files
`.py, .js, .ts, .cs, .java, .c, .cpp, .h, .hpp, .rb, .ruby, .php, .swift, .kt, .kts, .go, .sh, .bat, .ps1, .R, .scala, .pl, .lua, .dart, .elm`
### Data Files
`.json, .yaml, .yml, .xml, .csv, .txt, .md`
### Images
All standard formats: PNG, JPEG, JPG, GIF, WEBP, BMP
### Documents
PDF (models with document support)
### Size Limits
- Images: 10 MB max
- Code/Text: Auto-truncates files >50KB
- Binary data: Displayed as `<binary: X bytes>`
## MCP Security
### Access Control
- ✅ Explicit folder/database approval required
- ✅ System directories blocked automatically
- ✅ User confirmation for each addition
- ✅ .gitignore patterns respected (file mode)
### Database Safety
- ✅ Read-only mode (cannot modify data)
- ✅ SQL query validation (blocks INSERT/UPDATE/DELETE)
- ✅ Query timeout (5 seconds max)
- ✅ Result limits (1000 rows max)
- ✅ Database opened in `mode=ro`
### File System Safety
- ✅ Read-only by default (write mode requires explicit opt-in)
- ✅ Write mode OFF by default each session (non-persistent)
- ✅ Delete operations always require user confirmation
- ✅ Write operations limited to allowed folders only
- ✅ System directories blocked
- ✅ Virtual environment exclusion
- ✅ Build artifact filtering
- ✅ Maximum file size (10 MB)
## Tips & Tricks
### Command History
- **↑/↓ arrows**: Navigate previous commands
- **Ctrl+R**: Search command history
- **Auto-complete**: Start typing `/` for command suggestions
### Cost Optimization
```bash ```bash
/memory off # Disable context (stateless mode) oai [OPTIONS]
/maxtoken 1000 # Limit response length
/config costwarning 0.01 # Set alert threshold Options:
-m, --model TEXT Model ID to use
-s, --system TEXT System prompt
-o, --online Enable online mode
--mcp Enable MCP server
-v, --version Show version
--help Show help
``` ```
### MCP Best Practices Commands:
```bash ```bash
# Check status frequently oai # Launch TUI (default)
/mcp status oai version # Show version information
oai --help # Show help message
# Use specific paths to reduce search time
"List Python files in Projects/app/" # Better than
"List all Python files" # Slower
# Database queries - be specific
"SELECT * FROM users LIMIT 10" # Good
"SELECT * FROM users" # May hit row limit
``` ```
### Debugging ## Configuration
```bash
# Enable debug logging
/config loglevel debug
# Check log file Configuration is stored in `~/.config/oai/`:
tail -f ~/.config/oai/oai.log
# View MCP statistics | File | Purpose |
/mcp status # Shows tool call counts |------|---------|
| `oai_config.db` | Settings, conversations, MCP config |
| `oai.log` | Application logs |
| `history.txt` | Command history |
## Project Structure
```
oai/
├── oai/
│ ├── __init__.py
│ ├── __main__.py # Entry point for python -m oai
│ ├── cli.py # Main CLI entry point
│ ├── constants.py # Configuration constants
│ ├── commands/ # Slash command handlers
│ ├── config/ # Settings and database
│ ├── core/ # Chat client and session
│ ├── mcp/ # MCP server and tools
│ ├── providers/ # AI provider abstraction
│ ├── tui/ # Textual TUI interface
│ │ ├── app.py # Main TUI application
│ │ ├── widgets/ # Custom widgets
│ │ ├── screens/ # Modal screens
│ │ └── styles.tcss # TUI styling
│ └── utils/ # Logging, export, etc.
├── pyproject.toml # Package configuration
├── build.sh # Binary build script
└── README.md
``` ```
## Troubleshooting ## Troubleshooting
### MCP Not Working ### macOS Binary Issues
```bash
# 1. Check if MCP is installed
python3 -c "import mcp; print('MCP OK')"
# 2. Verify model supports function calling ```bash
# Remove quarantine attribute
xattr -cr ~/.local/bin/oai
# Then in Finder: right-click oai → Open With → Terminal → Click "Open"
# After this, oai works from any terminal
```
### MCP Not Working
```bash
# Check if model supports function calling
/info # Look for "tools" in supported parameters /info # Look for "tools" in supported parameters
# 3. Check MCP status # Check MCP status
/mcp status /mcp status
# 4. Review logs # View logs
tail ~/.config/oai/oai.log tail -f ~/.config/oai/oai.log
``` ```
### Import Errors ### Import Errors
```bash ```bash
# Reinstall dependencies # Reinstall package
pip install --force-reinstall -r requirements.txt pip install -e . --force-reinstall
```
### Binary Issues (macOS)
```bash
# Remove quarantine
xattr -cr ~/.local/bin/oai
# Check security settings
# System Settings > Privacy & Security > "Allow Anyway"
```
### Database Errors
```bash
# Verify it's a valid SQLite database
sqlite3 database.db ".tables"
# Check file permissions
ls -la database.db
``` ```
## Version History ## Version History
### v2.1.0-RC1 (Current) ### v3.0.0 (Current)
- **NEW**: MCP (Model Context Protocol) integration - 🎨 **Complete migration to Textual TUI** - Modern async terminal interface
- **NEW**: File system access (read, search, list) - 🗑️ **Removed CLI interface** - TUI-only for cleaner codebase (11.6% smaller)
- **NEW**: Write mode - AI can create, edit, and delete files - 🖱️ **Modal screens** - Help, stats, config, credits, model selector
- 6 write tools: write_file, edit_file, delete_file, create_directory, move_file, copy_file - ⌨️ **Keyboard shortcuts** - F1 (help), F2 (models), Ctrl+S (stats), etc.
- OFF by default - requires explicit `/mcp write on` activation - 🎯 **Capability indicators** - Visual icons for model features (vision, tools, online)
- Delete operations always require user confirmation - 🎨 **Consistent dark theme** - Professional styling throughout
- Non-persistent setting (resets each session) - 📊 **Enhanced model selector** - Search, filter, capability columns
- **NEW**: SQLite database querying (read-only) - 🚀 **Default command** - Just run `oai` to launch TUI
- **NEW**: Dual mode support (Files & Database) - 🧹 **Code cleanup** - Removed 1,300+ lines of CLI code
-**NEW**: .gitignore filtering
-**NEW**: Binary data handling in databases
-**NEW**: Mode indicators in prompt (shows ✍️ when write mode active)
-**NEW**: Comprehensive `/help mcp` guide
- 🔧 Improved error handling for tool calls
- 🔧 Enhanced logging for MCP operations
- 🔧 Statistics tracking for tool usage
### v1.9.6 ### v2.1.0
- Base version with core chat functionality - 🏗️ Complete codebase refactoring to modular package structure
- Conversation management - 🔌 Extensible provider architecture for adding new AI providers
- 📦 Proper Python packaging with pyproject.toml
- ✨ MCP integration (file access, write mode, database queries)
- 🔧 Command registry pattern for slash commands
- 📊 Improved cost tracking and session statistics
### v1.9.x
- Single-file implementation
- Core chat functionality
- File attachments - File attachments
- Cost tracking - Conversation management
- Export capabilities
## License ## License
MIT License MIT License - See [LICENSE](LICENSE) for details.
Copyright (c) 2024-2025 Rune Olsen
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
Full license: https://opensource.org/licenses/MIT
## Author ## Author
**Rune Olsen** **Rune Olsen**
- Homepage: https://ai.fubar.pm/
- Blog: https://blog.rune.pm
- Project: https://iurl.no/oai - Project: https://iurl.no/oai
- Repository: https://gitlab.pm/rune/oai
## Contributing ## Contributing
Contributions welcome! Please:
1. Fork the repository 1. Fork the repository
2. Create a feature branch 2. Create a feature branch
3. Submit a pull request with detailed description 3. Submit a pull request
## Acknowledgments
- OpenRouter team for the unified AI API
- Rich library for beautiful terminal output
- MCP community for the protocol specification
--- ---
**Star this project if you find it useful!** **Star this project if you find it useful!**
---
Did you really read all the way down here? WOW! You deserve a 🍾 🥂!

5987
oai.py

File diff suppressed because it is too large Load Diff

26
oai/__init__.py Normal file
View File

@@ -0,0 +1,26 @@
"""
oAI - OpenRouter AI Chat Client
A feature-rich terminal-based chat application that provides an interactive CLI
interface to OpenRouter's unified AI API with advanced Model Context Protocol (MCP)
integration for filesystem and database access.
Author: Rune
License: MIT
"""
__version__ = "2.1.0"
__author__ = "Rune"
__license__ = "MIT"
# Lazy imports to avoid circular dependencies and improve startup time
# Full imports are available via submodules:
# from oai.config import Settings, Database
# from oai.providers import OpenRouterProvider, AIProvider
# from oai.mcp import MCPManager
__all__ = [
"__version__",
"__author__",
"__license__",
]

8
oai/__main__.py Normal file
View File

@@ -0,0 +1,8 @@
"""
Entry point for running oAI as a module: python -m oai
"""
from oai.cli import main
if __name__ == "__main__":
main()

199
oai/cli.py Normal file
View File

@@ -0,0 +1,199 @@
"""
Main CLI entry point for oAI.
This module provides the command-line interface for the oAI TUI application.
"""
import sys
from typing import Optional
import typer
from oai import __version__
from oai.commands import register_all_commands
from oai.config.settings import Settings
from oai.constants import APP_URL, APP_VERSION
from oai.core.client import AIClient
from oai.core.session import ChatSession
from oai.mcp.manager import MCPManager
from oai.utils.logging import LoggingManager, get_logger
# Create Typer app
app = typer.Typer(
name="oai",
help=f"oAI - OpenRouter AI Chat Client (TUI)\n\nVersion: {APP_VERSION}",
add_completion=False,
epilog="For more information, visit: " + APP_URL,
)
@app.callback(invoke_without_command=True)
def main_callback(
ctx: typer.Context,
version_flag: bool = typer.Option(
False,
"--version",
"-v",
help="Show version information",
is_flag=True,
),
model: Optional[str] = typer.Option(
None,
"--model",
"-m",
help="Model ID to use",
),
system: Optional[str] = typer.Option(
None,
"--system",
"-s",
help="System prompt",
),
online: bool = typer.Option(
False,
"--online",
"-o",
help="Enable online mode",
),
mcp: bool = typer.Option(
False,
"--mcp",
help="Enable MCP server",
),
) -> None:
"""Main callback - launches TUI by default."""
if version_flag:
typer.echo(f"oAI version {APP_VERSION}")
raise typer.Exit()
# If no subcommand provided, launch TUI
if ctx.invoked_subcommand is None:
_launch_tui(model, system, online, mcp)
def _launch_tui(
model: Optional[str] = None,
system: Optional[str] = None,
online: bool = False,
mcp: bool = False,
) -> None:
"""Launch the Textual TUI interface."""
# Setup logging
logging_manager = LoggingManager()
logging_manager.setup()
logger = get_logger()
# Load settings
settings = Settings.load()
# Check API key
if not settings.api_key:
typer.echo("Error: No API key configured", err=True)
typer.echo("Run: oai config api to set your API key", err=True)
raise typer.Exit(1)
# Initialize client
try:
client = AIClient(
api_key=settings.api_key,
base_url=settings.base_url,
)
except Exception as e:
typer.echo(f"Error: Failed to initialize client: {e}", err=True)
raise typer.Exit(1)
# Register commands
register_all_commands()
# Initialize MCP manager (always create it, even if not enabled)
mcp_manager = MCPManager()
if mcp:
try:
result = mcp_manager.enable()
if result["success"]:
logger.info("MCP server enabled in files mode")
else:
logger.warning(f"MCP: {result.get('error', 'Failed to enable')}")
except Exception as e:
logger.warning(f"Failed to enable MCP: {e}")
# Create session with MCP manager
session = ChatSession(
client=client,
settings=settings,
mcp_manager=mcp_manager,
)
# Set system prompt if provided
if system:
session.set_system_prompt(system)
# Enable online mode if requested
if online:
session.online_enabled = True
# Set model if specified, otherwise use default
if model:
raw_model = client.get_raw_model(model)
if raw_model:
session.set_model(raw_model)
else:
logger.warning(f"Model '{model}' not found")
elif settings.default_model:
raw_model = client.get_raw_model(settings.default_model)
if raw_model:
session.set_model(raw_model)
else:
logger.warning(f"Default model '{settings.default_model}' not available")
# Run Textual app
from oai.tui.app import oAIChatApp
app_instance = oAIChatApp(session, settings, model)
app_instance.run()
@app.command()
def tui(
model: Optional[str] = typer.Option(
None,
"--model",
"-m",
help="Model ID to use",
),
system: Optional[str] = typer.Option(
None,
"--system",
"-s",
help="System prompt",
),
online: bool = typer.Option(
False,
"--online",
"-o",
help="Enable online mode",
),
mcp: bool = typer.Option(
False,
"--mcp",
help="Enable MCP server",
),
) -> None:
"""Start Textual TUI interface (alias for just running 'oai')."""
_launch_tui(model, system, online, mcp)
@app.command()
def version() -> None:
"""Show version information."""
typer.echo(f"oAI version {APP_VERSION}")
typer.echo(f"Visit {APP_URL} for more information")
def main() -> None:
"""Entry point for the CLI."""
app()
if __name__ == "__main__":
main()

24
oai/commands/__init__.py Normal file
View File

@@ -0,0 +1,24 @@
"""
Command system for oAI.
This module provides a command registry and handler system
for processing slash commands in the chat interface.
"""
from oai.commands.registry import (
Command,
CommandRegistry,
CommandContext,
CommandResult,
registry,
)
from oai.commands.handlers import register_all_commands
__all__ = [
"Command",
"CommandRegistry",
"CommandContext",
"CommandResult",
"registry",
"register_all_commands",
]

1478
oai/commands/handlers.py Normal file

File diff suppressed because it is too large Load Diff

382
oai/commands/registry.py Normal file
View File

@@ -0,0 +1,382 @@
"""
Command registry for oAI.
This module defines the command system infrastructure including
the Command base class, CommandContext for state, and CommandRegistry
for managing available commands.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
from oai.utils.logging import get_logger
if TYPE_CHECKING:
from oai.config.settings import Settings
from oai.providers.base import AIProvider, ModelInfo
from oai.mcp.manager import MCPManager
class CommandStatus(str, Enum):
"""Status of command execution."""
SUCCESS = "success"
ERROR = "error"
CONTINUE = "continue" # Continue to next handler
EXIT = "exit" # Exit the application
@dataclass
class CommandResult:
"""
Result of a command execution.
Attributes:
status: Execution status
message: Optional message to display
data: Optional data payload
should_continue: Whether to continue the main loop
"""
status: CommandStatus = CommandStatus.SUCCESS
message: Optional[str] = None
data: Optional[Any] = None
should_continue: bool = True
@classmethod
def success(cls, message: Optional[str] = None, data: Any = None) -> "CommandResult":
"""Create a success result."""
return cls(status=CommandStatus.SUCCESS, message=message, data=data)
@classmethod
def error(cls, message: str) -> "CommandResult":
"""Create an error result."""
return cls(status=CommandStatus.ERROR, message=message)
@classmethod
def exit(cls, message: Optional[str] = None) -> "CommandResult":
"""Create an exit result."""
return cls(status=CommandStatus.EXIT, message=message, should_continue=False)
@dataclass
class CommandContext:
"""
Context object providing state to command handlers.
Contains all the session state needed by commands including
settings, provider, conversation history, and MCP manager.
Attributes:
settings: Application settings
provider: AI provider instance
mcp_manager: MCP manager instance
selected_model: Currently selected model
session_history: Conversation history
session_system_prompt: Current system prompt
memory_enabled: Whether memory is enabled
online_enabled: Whether online mode is enabled
session_tokens: Session token counts
session_cost: Session cost total
"""
settings: Optional["Settings"] = None
provider: Optional["AIProvider"] = None
mcp_manager: Optional["MCPManager"] = None
selected_model: Optional["ModelInfo"] = None
selected_model_raw: Optional[Dict[str, Any]] = None
session_history: List[Dict[str, Any]] = field(default_factory=list)
session_system_prompt: str = ""
memory_enabled: bool = True
memory_start_index: int = 0
online_enabled: bool = False
middle_out_enabled: bool = False
session_max_token: int = 0
total_input_tokens: int = 0
total_output_tokens: int = 0
total_cost: float = 0.0
message_count: int = 0
is_tui: bool = False # Flag for TUI mode
current_index: int = 0
@dataclass
class CommandHelp:
"""
Help information for a command.
Attributes:
description: Brief description
usage: Usage syntax
examples: List of (description, example) tuples
notes: Additional notes
aliases: Command aliases
"""
description: str
usage: str = ""
examples: List[tuple] = field(default_factory=list)
notes: str = ""
aliases: List[str] = field(default_factory=list)
class Command(ABC):
"""
Abstract base class for all commands.
Commands implement the execute method to handle their logic.
They can also provide help information and aliases.
"""
@property
@abstractmethod
def name(self) -> str:
"""Get the primary command name (e.g., '/help')."""
pass
@property
def aliases(self) -> List[str]:
"""Get command aliases (e.g., ['/h'] for help)."""
return []
@property
@abstractmethod
def help(self) -> CommandHelp:
"""Get command help information."""
pass
@abstractmethod
def execute(self, args: str, context: CommandContext) -> CommandResult:
"""
Execute the command.
Args:
args: Arguments passed to the command
context: Command execution context
Returns:
CommandResult indicating success/failure
"""
pass
def matches(self, input_text: str) -> bool:
"""
Check if this command matches the input.
Args:
input_text: User input text
Returns:
True if this command should handle the input
"""
input_lower = input_text.lower()
cmd_word = input_lower.split()[0] if input_lower.split() else ""
# Check primary name
if cmd_word == self.name.lower():
return True
# Check aliases
for alias in self.aliases:
if cmd_word == alias.lower():
return True
return False
def get_args(self, input_text: str) -> str:
"""
Extract arguments from the input text.
Args:
input_text: Full user input
Returns:
Arguments portion of the input
"""
parts = input_text.split(maxsplit=1)
return parts[1] if len(parts) > 1 else ""
class CommandRegistry:
"""
Registry for managing available commands.
Provides registration, lookup, and execution of commands.
"""
def __init__(self):
"""Initialize an empty command registry."""
self._commands: Dict[str, Command] = {}
self._aliases: Dict[str, str] = {}
self.logger = get_logger()
def register(self, command: Command) -> None:
"""
Register a command.
Args:
command: Command instance to register
Raises:
ValueError: If command name already registered
"""
name = command.name.lower()
if name in self._commands:
raise ValueError(f"Command '{name}' already registered")
self._commands[name] = command
# Register aliases
for alias in command.aliases:
alias_lower = alias.lower()
if alias_lower in self._aliases:
self.logger.warning(
f"Alias '{alias}' already registered, overwriting"
)
self._aliases[alias_lower] = name
self.logger.debug(f"Registered command: {name}")
def register_function(
self,
name: str,
handler: Callable[[str, CommandContext], CommandResult],
description: str,
usage: str = "",
aliases: Optional[List[str]] = None,
examples: Optional[List[tuple]] = None,
notes: str = "",
) -> None:
"""
Register a function-based command.
Convenience method for simple commands that don't need
a full Command class.
Args:
name: Command name (e.g., '/help')
handler: Function to execute
description: Help description
usage: Usage syntax
aliases: Command aliases
examples: Example usages
notes: Additional notes
"""
aliases = aliases or []
examples = examples or []
class FunctionCommand(Command):
@property
def name(self) -> str:
return name
@property
def aliases(self) -> List[str]:
return aliases
@property
def help(self) -> CommandHelp:
return CommandHelp(
description=description,
usage=usage,
examples=examples,
notes=notes,
aliases=aliases,
)
def execute(self, args: str, context: CommandContext) -> CommandResult:
return handler(args, context)
self.register(FunctionCommand())
def get(self, name: str) -> Optional[Command]:
"""
Get a command by name or alias.
Args:
name: Command name or alias
Returns:
Command instance or None if not found
"""
name_lower = name.lower()
# Check direct match
if name_lower in self._commands:
return self._commands[name_lower]
# Check aliases
if name_lower in self._aliases:
return self._commands[self._aliases[name_lower]]
return None
def find(self, input_text: str) -> Optional[Command]:
"""
Find a command that matches the input.
Args:
input_text: User input text
Returns:
Matching Command or None
"""
cmd_word = input_text.lower().split()[0] if input_text.split() else ""
return self.get(cmd_word)
def execute(self, input_text: str, context: CommandContext) -> Optional[CommandResult]:
"""
Execute a command matching the input.
Args:
input_text: User input text
context: Execution context
Returns:
CommandResult or None if no matching command
"""
command = self.find(input_text)
if command:
args = command.get_args(input_text)
self.logger.debug(f"Executing command: {command.name} with args: {args}")
return command.execute(args, context)
return None
def is_command(self, input_text: str) -> bool:
"""
Check if input is a valid command.
Args:
input_text: User input text
Returns:
True if input matches a registered command
"""
return self.find(input_text) is not None
def list_commands(self) -> List[Command]:
"""
Get all registered commands.
Returns:
List of Command instances
"""
return list(self._commands.values())
def get_all_names(self) -> List[str]:
"""
Get all command names and aliases.
Returns:
List of command names including aliases
"""
names = list(self._commands.keys())
names.extend(self._aliases.keys())
return sorted(set(names))
# Global registry instance
registry = CommandRegistry()

11
oai/config/__init__.py Normal file
View File

@@ -0,0 +1,11 @@
"""
Configuration management for oAI.
This package handles all configuration persistence, settings management,
and database operations for the application.
"""
from oai.config.settings import Settings
from oai.config.database import Database
__all__ = ["Settings", "Database"]

472
oai/config/database.py Normal file
View File

@@ -0,0 +1,472 @@
"""
Database persistence layer for oAI.
This module provides a clean abstraction for SQLite operations including
configuration storage, conversation persistence, and MCP statistics tracking.
All database operations are centralized here for maintainability.
"""
import sqlite3
import json
import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
from contextlib import contextmanager
from oai.constants import DATABASE_FILE, CONFIG_DIR
class Database:
"""
SQLite database manager for oAI.
Handles all database operations including:
- Configuration key-value storage
- Conversation session persistence
- MCP configuration and statistics
- Database registrations for MCP
Uses context managers for safe connection handling and supports
automatic table creation on first use.
"""
def __init__(self, db_path: Optional[Path] = None):
"""
Initialize the database manager.
Args:
db_path: Optional custom database path. Defaults to standard location.
"""
self.db_path = db_path or DATABASE_FILE
self._ensure_directories()
self._ensure_tables()
def _ensure_directories(self) -> None:
"""Ensure the configuration directory exists."""
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
@contextmanager
def _connection(self):
"""
Context manager for database connections.
Yields:
sqlite3.Connection: Active database connection
Example:
with self._connection() as conn:
conn.execute("SELECT * FROM config")
"""
conn = sqlite3.connect(str(self.db_path))
try:
yield conn
conn.commit()
except Exception:
conn.rollback()
raise
finally:
conn.close()
def _ensure_tables(self) -> None:
"""Create all required tables if they don't exist."""
with self._connection() as conn:
# Main configuration table
conn.execute("""
CREATE TABLE IF NOT EXISTS config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
)
""")
# Conversation sessions table
conn.execute("""
CREATE TABLE IF NOT EXISTS conversation_sessions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
name TEXT NOT NULL,
timestamp TEXT NOT NULL,
data TEXT NOT NULL
)
""")
# MCP configuration table
conn.execute("""
CREATE TABLE IF NOT EXISTS mcp_config (
key TEXT PRIMARY KEY,
value TEXT NOT NULL
)
""")
# MCP statistics table
conn.execute("""
CREATE TABLE IF NOT EXISTS mcp_stats (
id INTEGER PRIMARY KEY AUTOINCREMENT,
timestamp TEXT NOT NULL,
tool_name TEXT NOT NULL,
folder TEXT,
success INTEGER NOT NULL,
error_message TEXT
)
""")
# MCP databases table
conn.execute("""
CREATE TABLE IF NOT EXISTS mcp_databases (
id INTEGER PRIMARY KEY AUTOINCREMENT,
path TEXT NOT NULL UNIQUE,
name TEXT NOT NULL,
size INTEGER,
tables TEXT,
added_timestamp TEXT NOT NULL
)
""")
# =========================================================================
# CONFIGURATION METHODS
# =========================================================================
def get_config(self, key: str) -> Optional[str]:
"""
Retrieve a configuration value by key.
Args:
key: The configuration key to retrieve
Returns:
The configuration value, or None if not found
"""
with self._connection() as conn:
cursor = conn.execute(
"SELECT value FROM config WHERE key = ?",
(key,)
)
result = cursor.fetchone()
return result[0] if result else None
def set_config(self, key: str, value: str) -> None:
"""
Set a configuration value.
Args:
key: The configuration key
value: The value to store
"""
with self._connection() as conn:
conn.execute(
"INSERT OR REPLACE INTO config (key, value) VALUES (?, ?)",
(key, value)
)
def delete_config(self, key: str) -> bool:
"""
Delete a configuration value.
Args:
key: The configuration key to delete
Returns:
True if a row was deleted, False otherwise
"""
with self._connection() as conn:
cursor = conn.execute(
"DELETE FROM config WHERE key = ?",
(key,)
)
return cursor.rowcount > 0
def get_all_config(self) -> Dict[str, str]:
"""
Retrieve all configuration values.
Returns:
Dictionary of all key-value pairs
"""
with self._connection() as conn:
cursor = conn.execute("SELECT key, value FROM config")
return dict(cursor.fetchall())
# =========================================================================
# MCP CONFIGURATION METHODS
# =========================================================================
def get_mcp_config(self, key: str) -> Optional[str]:
"""
Retrieve an MCP configuration value.
Args:
key: The MCP configuration key
Returns:
The configuration value, or None if not found
"""
with self._connection() as conn:
cursor = conn.execute(
"SELECT value FROM mcp_config WHERE key = ?",
(key,)
)
result = cursor.fetchone()
return result[0] if result else None
def set_mcp_config(self, key: str, value: str) -> None:
"""
Set an MCP configuration value.
Args:
key: The MCP configuration key
value: The value to store
"""
with self._connection() as conn:
conn.execute(
"INSERT OR REPLACE INTO mcp_config (key, value) VALUES (?, ?)",
(key, value)
)
# =========================================================================
# MCP STATISTICS METHODS
# =========================================================================
def log_mcp_stat(
self,
tool_name: str,
folder: Optional[str],
success: bool,
error_message: Optional[str] = None
) -> None:
"""
Log an MCP tool usage event.
Args:
tool_name: Name of the MCP tool that was called
folder: The folder path involved (if any)
success: Whether the call succeeded
error_message: Error message if the call failed
"""
timestamp = datetime.datetime.now().isoformat()
with self._connection() as conn:
conn.execute(
"""INSERT INTO mcp_stats
(timestamp, tool_name, folder, success, error_message)
VALUES (?, ?, ?, ?, ?)""",
(timestamp, tool_name, folder, 1 if success else 0, error_message)
)
def get_mcp_stats(self) -> Dict[str, Any]:
"""
Get aggregated MCP usage statistics.
Returns:
Dictionary containing usage statistics:
- total_calls: Total number of tool calls
- reads: Number of file reads
- lists: Number of directory listings
- searches: Number of file searches
- db_inspects: Number of database inspections
- db_searches: Number of database searches
- db_queries: Number of database queries
- last_used: Timestamp of last usage
"""
with self._connection() as conn:
cursor = conn.execute("""
SELECT
COUNT(*) as total_calls,
SUM(CASE WHEN tool_name = 'read_file' THEN 1 ELSE 0 END) as reads,
SUM(CASE WHEN tool_name = 'list_directory' THEN 1 ELSE 0 END) as lists,
SUM(CASE WHEN tool_name = 'search_files' THEN 1 ELSE 0 END) as searches,
SUM(CASE WHEN tool_name = 'inspect_database' THEN 1 ELSE 0 END) as db_inspects,
SUM(CASE WHEN tool_name = 'search_database' THEN 1 ELSE 0 END) as db_searches,
SUM(CASE WHEN tool_name = 'query_database' THEN 1 ELSE 0 END) as db_queries,
MAX(timestamp) as last_used
FROM mcp_stats
""")
row = cursor.fetchone()
return {
"total_calls": row[0] or 0,
"reads": row[1] or 0,
"lists": row[2] or 0,
"searches": row[3] or 0,
"db_inspects": row[4] or 0,
"db_searches": row[5] or 0,
"db_queries": row[6] or 0,
"last_used": row[7],
}
# =========================================================================
# MCP DATABASE REGISTRY METHODS
# =========================================================================
def add_mcp_database(self, db_info: Dict[str, Any]) -> int:
"""
Register a database for MCP access.
Args:
db_info: Dictionary containing:
- path: Database file path
- name: Display name
- size: File size in bytes
- tables: List of table names
- added: Timestamp when added
Returns:
The database ID
"""
with self._connection() as conn:
conn.execute(
"""INSERT INTO mcp_databases
(path, name, size, tables, added_timestamp)
VALUES (?, ?, ?, ?, ?)""",
(
db_info["path"],
db_info["name"],
db_info["size"],
json.dumps(db_info["tables"]),
db_info["added"]
)
)
cursor = conn.execute(
"SELECT id FROM mcp_databases WHERE path = ?",
(db_info["path"],)
)
return cursor.fetchone()[0]
def remove_mcp_database(self, db_path: str) -> bool:
"""
Remove a database from the MCP registry.
Args:
db_path: Path to the database file
Returns:
True if a row was deleted, False otherwise
"""
with self._connection() as conn:
cursor = conn.execute(
"DELETE FROM mcp_databases WHERE path = ?",
(db_path,)
)
return cursor.rowcount > 0
def get_mcp_databases(self) -> List[Dict[str, Any]]:
"""
Retrieve all registered MCP databases.
Returns:
List of database information dictionaries
"""
with self._connection() as conn:
cursor = conn.execute(
"""SELECT id, path, name, size, tables, added_timestamp
FROM mcp_databases ORDER BY id"""
)
databases = []
for row in cursor.fetchall():
tables_list = json.loads(row[4]) if row[4] else []
databases.append({
"id": row[0],
"path": row[1],
"name": row[2],
"size": row[3],
"tables": tables_list,
"added": row[5],
})
return databases
# =========================================================================
# CONVERSATION METHODS
# =========================================================================
def save_conversation(self, name: str, data: List[Dict[str, str]]) -> None:
"""
Save a conversation session.
Args:
name: Name/identifier for the conversation
data: List of message dictionaries with 'prompt' and 'response' keys
"""
timestamp = datetime.datetime.now().isoformat()
data_json = json.dumps(data)
with self._connection() as conn:
conn.execute(
"""INSERT INTO conversation_sessions
(name, timestamp, data) VALUES (?, ?, ?)""",
(name, timestamp, data_json)
)
def load_conversation(self, name: str) -> Optional[List[Dict[str, str]]]:
"""
Load a conversation by name.
Args:
name: Name of the conversation to load
Returns:
List of messages, or None if not found
"""
with self._connection() as conn:
cursor = conn.execute(
"""SELECT data FROM conversation_sessions
WHERE name = ?
ORDER BY timestamp DESC LIMIT 1""",
(name,)
)
result = cursor.fetchone()
if result:
return json.loads(result[0])
return None
def delete_conversation(self, name: str) -> int:
"""
Delete a conversation by name.
Args:
name: Name of the conversation to delete
Returns:
Number of rows deleted
"""
with self._connection() as conn:
cursor = conn.execute(
"DELETE FROM conversation_sessions WHERE name = ?",
(name,)
)
return cursor.rowcount
def list_conversations(self) -> List[Dict[str, Any]]:
"""
List all saved conversations.
Returns:
List of conversation summaries with name, timestamp, and message_count
"""
with self._connection() as conn:
cursor = conn.execute("""
SELECT name, MAX(timestamp) as last_saved, data
FROM conversation_sessions
GROUP BY name
ORDER BY last_saved DESC
""")
conversations = []
for row in cursor.fetchall():
name, timestamp, data_json = row
data = json.loads(data_json)
conversations.append({
"name": name,
"timestamp": timestamp,
"message_count": len(data),
})
return conversations
# Global database instance for convenience
_db: Optional[Database] = None
def get_database() -> Database:
"""
Get the global database instance.
Returns:
The shared Database instance
"""
global _db
if _db is None:
_db = Database()
return _db

361
oai/config/settings.py Normal file
View File

@@ -0,0 +1,361 @@
"""
Settings management for oAI.
This module provides a centralized settings class that handles all application
configuration with type safety, validation, and persistence.
"""
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path
from oai.constants import (
DEFAULT_BASE_URL,
DEFAULT_STREAM_ENABLED,
DEFAULT_MAX_TOKENS,
DEFAULT_ONLINE_MODE,
DEFAULT_COST_WARNING_THRESHOLD,
DEFAULT_LOG_MAX_SIZE_MB,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_LEVEL,
DEFAULT_SYSTEM_PROMPT,
VALID_LOG_LEVELS,
)
from oai.config.database import get_database
@dataclass
class Settings:
"""
Application settings with persistence support.
This class provides a clean interface for managing all configuration
options. Settings are automatically loaded from the database on
initialization and can be persisted back.
Attributes:
api_key: OpenRouter API key
base_url: API base URL
default_model: Default model ID to use
default_system_prompt: Custom system prompt (None = use hardcoded default, "" = blank)
stream_enabled: Whether to stream responses
max_tokens: Maximum tokens per request
cost_warning_threshold: Alert threshold for message cost
default_online_mode: Whether online mode is enabled by default
log_max_size_mb: Maximum log file size in MB
log_backup_count: Number of log file backups to keep
log_level: Logging level (debug/info/warning/error/critical)
"""
api_key: Optional[str] = None
base_url: str = DEFAULT_BASE_URL
default_model: Optional[str] = None
default_system_prompt: Optional[str] = None
stream_enabled: bool = DEFAULT_STREAM_ENABLED
max_tokens: int = DEFAULT_MAX_TOKENS
cost_warning_threshold: float = DEFAULT_COST_WARNING_THRESHOLD
default_online_mode: bool = DEFAULT_ONLINE_MODE
log_max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
log_backup_count: int = DEFAULT_LOG_BACKUP_COUNT
log_level: str = DEFAULT_LOG_LEVEL
@property
def effective_system_prompt(self) -> str:
"""
Get the effective system prompt to use.
Returns:
The custom prompt if set, hardcoded default if None, or blank if explicitly set to ""
"""
if self.default_system_prompt is None:
return DEFAULT_SYSTEM_PROMPT
return self.default_system_prompt
def __post_init__(self):
"""Validate settings after initialization."""
self._validate()
def _validate(self) -> None:
"""Validate all settings values."""
# Validate log level
if self.log_level.lower() not in VALID_LOG_LEVELS:
raise ValueError(
f"Invalid log level: {self.log_level}. "
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
)
# Validate numeric bounds
if self.max_tokens < 1:
raise ValueError("max_tokens must be at least 1")
if self.cost_warning_threshold < 0:
raise ValueError("cost_warning_threshold must be non-negative")
if self.log_max_size_mb < 1:
raise ValueError("log_max_size_mb must be at least 1")
if self.log_backup_count < 0:
raise ValueError("log_backup_count must be non-negative")
@classmethod
def load(cls) -> "Settings":
"""
Load settings from the database.
Returns:
Settings instance with values from database
"""
db = get_database()
# Helper to safely parse boolean
def parse_bool(value: Optional[str], default: bool) -> bool:
if value is None:
return default
return value.lower() in ("on", "true", "1", "yes")
# Helper to safely parse int
def parse_int(value: Optional[str], default: int) -> int:
if value is None:
return default
try:
return int(value)
except ValueError:
return default
# Helper to safely parse float
def parse_float(value: Optional[str], default: float) -> float:
if value is None:
return default
try:
return float(value)
except ValueError:
return default
# Get system prompt from DB: None means not set (use default), "" means explicitly blank
system_prompt_value = db.get_config("default_system_prompt")
return cls(
api_key=db.get_config("api_key"),
base_url=db.get_config("base_url") or DEFAULT_BASE_URL,
default_model=db.get_config("default_model"),
default_system_prompt=system_prompt_value,
stream_enabled=parse_bool(
db.get_config("stream_enabled"),
DEFAULT_STREAM_ENABLED
),
max_tokens=parse_int(
db.get_config("max_token"),
DEFAULT_MAX_TOKENS
),
cost_warning_threshold=parse_float(
db.get_config("cost_warning_threshold"),
DEFAULT_COST_WARNING_THRESHOLD
),
default_online_mode=parse_bool(
db.get_config("default_online_mode"),
DEFAULT_ONLINE_MODE
),
log_max_size_mb=parse_int(
db.get_config("log_max_size_mb"),
DEFAULT_LOG_MAX_SIZE_MB
),
log_backup_count=parse_int(
db.get_config("log_backup_count"),
DEFAULT_LOG_BACKUP_COUNT
),
log_level=db.get_config("log_level") or DEFAULT_LOG_LEVEL,
)
def save(self) -> None:
"""Persist all settings to the database."""
db = get_database()
# Only save API key if it exists
if self.api_key:
db.set_config("api_key", self.api_key)
db.set_config("base_url", self.base_url)
if self.default_model:
db.set_config("default_model", self.default_model)
# Save system prompt: None means not set (don't save), otherwise save the value (even if "")
if self.default_system_prompt is not None:
db.set_config("default_system_prompt", self.default_system_prompt)
db.set_config("stream_enabled", "on" if self.stream_enabled else "off")
db.set_config("max_token", str(self.max_tokens))
db.set_config("cost_warning_threshold", str(self.cost_warning_threshold))
db.set_config("default_online_mode", "on" if self.default_online_mode else "off")
db.set_config("log_max_size_mb", str(self.log_max_size_mb))
db.set_config("log_backup_count", str(self.log_backup_count))
db.set_config("log_level", self.log_level)
def set_api_key(self, api_key: str) -> None:
"""
Set and persist the API key.
Args:
api_key: The new API key
"""
self.api_key = api_key.strip()
get_database().set_config("api_key", self.api_key)
def set_base_url(self, url: str) -> None:
"""
Set and persist the base URL.
Args:
url: The new base URL
"""
self.base_url = url.strip()
get_database().set_config("base_url", self.base_url)
def set_default_model(self, model_id: str) -> None:
"""
Set and persist the default model.
Args:
model_id: The model ID to set as default
"""
self.default_model = model_id
get_database().set_config("default_model", model_id)
def set_default_system_prompt(self, prompt: str) -> None:
"""
Set and persist the default system prompt.
Args:
prompt: The system prompt to use for all new sessions.
Empty string "" means blank prompt (no system message).
"""
self.default_system_prompt = prompt
get_database().set_config("default_system_prompt", prompt)
def clear_default_system_prompt(self) -> None:
"""
Clear the custom system prompt and revert to hardcoded default.
This removes the custom prompt from the database, causing the
application to use the built-in DEFAULT_SYSTEM_PROMPT.
"""
self.default_system_prompt = None
# Remove from database to indicate "not set"
db = get_database()
with db._connection() as conn:
conn.execute("DELETE FROM config WHERE key = ?", ("default_system_prompt",))
conn.commit()
def set_stream_enabled(self, enabled: bool) -> None:
"""
Set and persist the streaming preference.
Args:
enabled: Whether to enable streaming
"""
self.stream_enabled = enabled
get_database().set_config("stream_enabled", "on" if enabled else "off")
def set_max_tokens(self, max_tokens: int) -> None:
"""
Set and persist the maximum tokens.
Args:
max_tokens: Maximum number of tokens
Raises:
ValueError: If max_tokens is less than 1
"""
if max_tokens < 1:
raise ValueError("max_tokens must be at least 1")
self.max_tokens = max_tokens
get_database().set_config("max_token", str(max_tokens))
def set_cost_warning_threshold(self, threshold: float) -> None:
"""
Set and persist the cost warning threshold.
Args:
threshold: Cost threshold in USD
Raises:
ValueError: If threshold is negative
"""
if threshold < 0:
raise ValueError("cost_warning_threshold must be non-negative")
self.cost_warning_threshold = threshold
get_database().set_config("cost_warning_threshold", str(threshold))
def set_default_online_mode(self, enabled: bool) -> None:
"""
Set and persist the default online mode.
Args:
enabled: Whether online mode should be enabled by default
"""
self.default_online_mode = enabled
get_database().set_config("default_online_mode", "on" if enabled else "off")
def set_log_level(self, level: str) -> None:
"""
Set and persist the log level.
Args:
level: The log level (debug/info/warning/error/critical)
Raises:
ValueError: If level is not valid
"""
level_lower = level.lower()
if level_lower not in VALID_LOG_LEVELS:
raise ValueError(
f"Invalid log level: {level}. "
f"Must be one of: {', '.join(VALID_LOG_LEVELS.keys())}"
)
self.log_level = level_lower
get_database().set_config("log_level", level_lower)
def set_log_max_size(self, size_mb: int) -> None:
"""
Set and persist the maximum log file size.
Args:
size_mb: Maximum size in megabytes
Raises:
ValueError: If size_mb is less than 1
"""
if size_mb < 1:
raise ValueError("log_max_size_mb must be at least 1")
# Cap at 100 MB for safety
self.log_max_size_mb = min(size_mb, 100)
get_database().set_config("log_max_size_mb", str(self.log_max_size_mb))
# Global settings instance
_settings: Optional[Settings] = None
def get_settings() -> Settings:
"""
Get the global settings instance.
Returns:
The shared Settings instance, loading from database if needed
"""
global _settings
if _settings is None:
_settings = Settings.load()
return _settings
def reload_settings() -> Settings:
"""
Force reload settings from the database.
Returns:
Fresh Settings instance
"""
global _settings
_settings = Settings.load()
return _settings

448
oai/constants.py Normal file
View File

@@ -0,0 +1,448 @@
"""
Application-wide constants for oAI.
This module contains all configuration constants, default values, and static
definitions used throughout the application. Centralizing these values makes
the codebase easier to maintain and configure.
"""
from pathlib import Path
from typing import Set, Dict, Any
import logging
# =============================================================================
# APPLICATION METADATA
# =============================================================================
APP_NAME = "oAI"
APP_VERSION = "3.0.0"
APP_URL = "https://iurl.no/oai"
APP_DESCRIPTION = "OpenRouter AI Chat Client with MCP Integration"
# =============================================================================
# FILE PATHS
# =============================================================================
HOME_DIR = Path.home()
CONFIG_DIR = HOME_DIR / ".config" / "oai"
CACHE_DIR = HOME_DIR / ".cache" / "oai"
HISTORY_FILE = CONFIG_DIR / "history.txt"
DATABASE_FILE = CONFIG_DIR / "oai_config.db"
LOG_FILE = CONFIG_DIR / "oai.log"
# =============================================================================
# API CONFIGURATION
# =============================================================================
DEFAULT_BASE_URL = "https://openrouter.ai/api/v1"
DEFAULT_STREAM_ENABLED = True
DEFAULT_MAX_TOKENS = 100_000
DEFAULT_ONLINE_MODE = False
# =============================================================================
# DEFAULT SYSTEM PROMPT
# =============================================================================
DEFAULT_SYSTEM_PROMPT = (
"You are a knowledgeable and helpful AI assistant. Provide clear, accurate, "
"and well-structured responses. Be concise yet thorough. When uncertain about "
"something, acknowledge your limitations. For technical topics, include relevant "
"details and examples when helpful."
)
# =============================================================================
# PRICING DEFAULTS (per million tokens)
# =============================================================================
DEFAULT_INPUT_PRICE = 3.0
DEFAULT_OUTPUT_PRICE = 15.0
MODEL_PRICING: Dict[str, float] = {
"input": DEFAULT_INPUT_PRICE,
"output": DEFAULT_OUTPUT_PRICE,
}
# =============================================================================
# CREDIT ALERTS
# =============================================================================
LOW_CREDIT_RATIO = 0.1 # Alert when credits < 10% of total
LOW_CREDIT_AMOUNT = 1.0 # Alert when credits < $1.00
DEFAULT_COST_WARNING_THRESHOLD = 0.01 # Alert when single message cost exceeds this
COST_WARNING_THRESHOLD = DEFAULT_COST_WARNING_THRESHOLD # Alias for convenience
# =============================================================================
# LOGGING CONFIGURATION
# =============================================================================
DEFAULT_LOG_MAX_SIZE_MB = 10
DEFAULT_LOG_BACKUP_COUNT = 2
DEFAULT_LOG_LEVEL = "info"
VALID_LOG_LEVELS: Dict[str, int] = {
"debug": logging.DEBUG,
"info": logging.INFO,
"warning": logging.WARNING,
"error": logging.ERROR,
"critical": logging.CRITICAL,
}
# =============================================================================
# FILE HANDLING
# =============================================================================
# Maximum file size for reading (10 MB)
MAX_FILE_SIZE = 10 * 1024 * 1024
# Content truncation threshold (50 KB)
CONTENT_TRUNCATION_THRESHOLD = 50 * 1024
# Maximum items in directory listing
MAX_LIST_ITEMS = 1000
# Supported code file extensions for syntax highlighting
SUPPORTED_CODE_EXTENSIONS: Set[str] = {
".py", ".js", ".ts", ".cs", ".java", ".c", ".cpp", ".h", ".hpp",
".rb", ".ruby", ".php", ".swift", ".kt", ".kts", ".go",
".sh", ".bat", ".ps1", ".r", ".scala", ".pl", ".lua", ".dart",
".elm", ".xml", ".json", ".yaml", ".yml", ".md", ".txt",
}
# All allowed file extensions for attachment
ALLOWED_FILE_EXTENSIONS: Set[str] = {
# Code files
".py", ".js", ".ts", ".jsx", ".tsx", ".vue", ".java", ".c", ".cpp", ".cc", ".cxx",
".h", ".hpp", ".hxx", ".rb", ".go", ".rs", ".swift", ".kt", ".kts", ".php",
".sh", ".bash", ".zsh", ".fish", ".bat", ".cmd", ".ps1",
# Data files
".json", ".csv", ".yaml", ".yml", ".toml", ".xml", ".sql", ".db", ".sqlite", ".sqlite3",
# Documents
".txt", ".md", ".log", ".conf", ".cfg", ".ini", ".env", ".properties",
# Images
".png", ".jpg", ".jpeg", ".gif", ".bmp", ".webp", ".svg", ".ico",
# Archives
".zip", ".tar", ".gz", ".bz2", ".7z", ".rar", ".xz",
# Config files
".lock", ".gitignore", ".dockerignore", ".editorconfig", ".eslintrc",
".prettierrc", ".babelrc", ".nvmrc", ".npmrc",
# Binary/Compiled
".pyc", ".pyo", ".pyd", ".so", ".dll", ".dylib", ".exe", ".app",
".dmg", ".pkg", ".deb", ".rpm", ".apk", ".ipa",
# ML/AI
".pkl", ".pickle", ".joblib", ".npy", ".npz", ".safetensors", ".onnx",
".pt", ".pth", ".ckpt", ".pb", ".tflite", ".mlmodel", ".coreml", ".rknn",
# Data formats
".wasm", ".proto", ".graphql", ".graphqls", ".grpc", ".avro", ".parquet",
".orc", ".feather", ".arrow", ".hdf5", ".h5", ".mat", ".rdata", ".rds",
# Other
".pdf", ".class", ".jar", ".war",
}
# =============================================================================
# SECURITY CONFIGURATION
# =============================================================================
# System directories that should never be accessed
SYSTEM_DIRS_BLACKLIST: Set[str] = {
# macOS
"/System", "/Library", "/private", "/usr", "/bin", "/sbin",
# Linux
"/boot", "/dev", "/proc", "/sys", "/root",
# Windows
"C:\\Windows", "C:\\Program Files", "C:\\Program Files (x86)",
}
# Directories to skip during file operations
SKIP_DIRECTORIES: Set[str] = {
# Python virtual environments
".venv", "venv", "env", "virtualenv",
"site-packages", "dist-packages",
# Python caches
"__pycache__", ".pytest_cache", ".mypy_cache",
# JavaScript/Node
"node_modules",
# Version control
".git", ".svn",
# IDEs
".idea", ".vscode",
# Build directories
"build", "dist", "eggs", ".eggs",
}
# =============================================================================
# DATABASE QUERIES - SQL SAFETY
# =============================================================================
# Maximum query execution timeout (seconds)
MAX_QUERY_TIMEOUT = 5
# Maximum rows returned from queries
MAX_QUERY_RESULTS = 1000
# Default rows per query
DEFAULT_QUERY_LIMIT = 100
# Keywords that are blocked in database queries
DANGEROUS_SQL_KEYWORDS: Set[str] = {
"INSERT", "UPDATE", "DELETE", "DROP", "CREATE",
"ALTER", "TRUNCATE", "REPLACE", "ATTACH", "DETACH",
"PRAGMA", "VACUUM", "REINDEX",
}
# =============================================================================
# MCP CONFIGURATION
# =============================================================================
# Maximum tool call iterations per request
MAX_TOOL_LOOPS = 5
# =============================================================================
# VALID COMMANDS
# =============================================================================
VALID_COMMANDS: Set[str] = {
"/retry", "/online", "/memory", "/paste", "/export", "/save", "/load",
"/delete", "/list", "/prev", "/next", "/stats", "/middleout", "/reset",
"/info", "/model", "/maxtoken", "/system", "/config", "/credits", "/clear",
"/cl", "/help", "/mcp",
}
# =============================================================================
# COMMAND HELP DATABASE
# =============================================================================
COMMAND_HELP: Dict[str, Dict[str, Any]] = {
"/clear": {
"aliases": ["/cl"],
"description": "Clear the terminal screen for a clean interface.",
"usage": "/clear\n/cl",
"examples": [
("Clear screen", "/clear"),
("Using short alias", "/cl"),
],
"notes": "You can also use the keyboard shortcut Ctrl+L.",
},
"/help": {
"description": "Display help information for commands.",
"usage": "/help [command|topic]",
"examples": [
("Show all commands", "/help"),
("Get help for a specific command", "/help /model"),
("Get detailed MCP help", "/help mcp"),
],
"notes": "Use /help without arguments to see the full command list.",
},
"mcp": {
"description": "Complete guide to MCP (Model Context Protocol).",
"usage": "See detailed examples below",
"examples": [],
"notes": """
MCP (Model Context Protocol) gives your AI assistant direct access to:
• Local files and folders (read, search, list)
• SQLite databases (inspect, search, query)
FILE MODE (default):
/mcp on Start MCP server
/mcp add ~/Documents Grant access to folder
/mcp list View all allowed folders
DATABASE MODE:
/mcp add db ~/app/data.db Add specific database
/mcp db list View all databases
/mcp db 1 Work with database #1
/mcp files Switch back to file mode
WRITE MODE (optional):
/mcp write on Enable file modifications
/mcp write off Disable write mode (back to read-only)
For command-specific help: /help /mcp
""",
},
"/mcp": {
"description": "Manage MCP for local file access and SQLite database querying.",
"usage": "/mcp <command> [args]",
"examples": [
("Enable MCP server", "/mcp on"),
("Disable MCP server", "/mcp off"),
("Show MCP status", "/mcp status"),
("", ""),
("━━━ FILE MODE ━━━", ""),
("Add folder for file access", "/mcp add ~/Documents"),
("Remove folder", "/mcp remove ~/Desktop"),
("List allowed folders", "/mcp list"),
("Enable write mode", "/mcp write on"),
("", ""),
("━━━ DATABASE MODE ━━━", ""),
("Add SQLite database", "/mcp add db ~/app/data.db"),
("List all databases", "/mcp db list"),
("Switch to database #1", "/mcp db 1"),
("Switch back to file mode", "/mcp files"),
],
"notes": "MCP allows AI to read local files and query SQLite databases.",
},
"/memory": {
"description": "Toggle conversation memory.",
"usage": "/memory [on|off]",
"examples": [
("Check current memory status", "/memory"),
("Enable conversation memory", "/memory on"),
("Disable memory (save costs)", "/memory off"),
],
"notes": "Memory is ON by default. Disabling saves tokens.",
},
"/online": {
"description": "Enable or disable online mode (web search).",
"usage": "/online [on|off]",
"examples": [
("Check online mode status", "/online"),
("Enable web search", "/online on"),
("Disable web search", "/online off"),
],
"notes": "Not all models support online mode.",
},
"/paste": {
"description": "Paste plain text from clipboard and send to the AI.",
"usage": "/paste [prompt]",
"examples": [
("Paste clipboard content", "/paste"),
("Paste with a question", "/paste Explain this code"),
],
"notes": "Only plain text is supported.",
},
"/retry": {
"description": "Resend the last prompt from conversation history.",
"usage": "/retry",
"examples": [("Retry last message", "/retry")],
"notes": "Requires at least one message in history.",
},
"/next": {
"description": "View the next response in conversation history.",
"usage": "/next",
"examples": [("Navigate to next response", "/next")],
"notes": "Use /prev to go backward.",
},
"/prev": {
"description": "View the previous response in conversation history.",
"usage": "/prev",
"examples": [("Navigate to previous response", "/prev")],
"notes": "Use /next to go forward.",
},
"/reset": {
"description": "Clear conversation history and reset system prompt.",
"usage": "/reset",
"examples": [("Reset conversation", "/reset")],
"notes": "Requires confirmation.",
},
"/info": {
"description": "Display detailed information about a model.",
"usage": "/info [model_id]",
"examples": [
("Show current model info", "/info"),
("Show specific model info", "/info gpt-4o"),
],
"notes": "Shows pricing, capabilities, and context length.",
},
"/model": {
"description": "Select or change the AI model.",
"usage": "/model [search_term]",
"examples": [
("List all models", "/model"),
("Search for GPT models", "/model gpt"),
("Search for Claude models", "/model claude"),
],
"notes": "Models are numbered for easy selection.",
},
"/config": {
"description": "View or modify application configuration.",
"usage": "/config [setting] [value]",
"examples": [
("View all settings", "/config"),
("Set API key", "/config api"),
("Set default model", "/config model"),
("Set system prompt", "/config system You are a helpful assistant"),
("Enable streaming", "/config stream on"),
],
"notes": "Available: api, url, model, system, stream, costwarning, maxtoken, online, loglevel.",
},
"/maxtoken": {
"description": "Set a temporary session token limit.",
"usage": "/maxtoken [value]",
"examples": [
("View current session limit", "/maxtoken"),
("Set session limit to 2000", "/maxtoken 2000"),
],
"notes": "Cannot exceed stored max token limit.",
},
"/system": {
"description": "Set or clear the session-level system prompt.",
"usage": "/system [prompt|clear|default <prompt>]",
"examples": [
("View current system prompt", "/system"),
("Set as Python expert", "/system You are a Python expert"),
("Multiline with newlines", r"/system You are an expert.\nBe clear and concise."),
("Save as default", "/system default You are a helpful assistant"),
("Revert to default", "/system clear"),
("Blank prompt", '/system ""'),
],
"notes": r"Use \n for newlines. /system clear reverts to hardcoded default.",
},
"/save": {
"description": "Save the current conversation history.",
"usage": "/save <name>",
"examples": [("Save conversation", "/save my_chat")],
"notes": "Saved conversations can be loaded later with /load.",
},
"/load": {
"description": "Load a saved conversation.",
"usage": "/load <name|number>",
"examples": [
("Load by name", "/load my_chat"),
("Load by number from /list", "/load 3"),
],
"notes": "Use /list to see numbered conversations.",
},
"/delete": {
"description": "Delete a saved conversation.",
"usage": "/delete <name|number>",
"examples": [("Delete by name", "/delete my_chat")],
"notes": "Requires confirmation. Cannot be undone.",
},
"/list": {
"description": "List all saved conversations.",
"usage": "/list",
"examples": [("Show saved conversations", "/list")],
"notes": "Conversations are numbered for use with /load and /delete.",
},
"/export": {
"description": "Export the current conversation to a file.",
"usage": "/export <format> <filename>",
"examples": [
("Export as Markdown", "/export md notes.md"),
("Export as JSON", "/export json conversation.json"),
("Export as HTML", "/export html report.html"),
],
"notes": "Available formats: md, json, html.",
},
"/stats": {
"description": "Display session statistics.",
"usage": "/stats",
"examples": [("View session statistics", "/stats")],
"notes": "Shows tokens, costs, and credits.",
},
"/credits": {
"description": "Display your OpenRouter account credits.",
"usage": "/credits",
"examples": [("Check credits", "/credits")],
"notes": "Shows total, used, and remaining credits.",
},
"/middleout": {
"description": "Toggle middle-out transform for long prompts.",
"usage": "/middleout [on|off]",
"examples": [
("Check status", "/middleout"),
("Enable compression", "/middleout on"),
],
"notes": "Compresses prompts exceeding context size.",
},
}

14
oai/core/__init__.py Normal file
View File

@@ -0,0 +1,14 @@
"""
Core functionality for oAI.
This module provides the main session management and AI client
classes that power the chat application.
"""
from oai.core.session import ChatSession
from oai.core.client import AIClient
__all__ = [
"ChatSession",
"AIClient",
]

422
oai/core/client.py Normal file
View File

@@ -0,0 +1,422 @@
"""
AI Client for oAI.
This module provides a high-level client for interacting with AI models
through the provider abstraction layer.
"""
import asyncio
import json
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
from oai.constants import APP_NAME, APP_URL, MODEL_PRICING
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ModelInfo,
StreamChunk,
ToolCall,
UsageStats,
)
from oai.providers.openrouter import OpenRouterProvider
from oai.utils.logging import get_logger
class AIClient:
"""
High-level AI client for chat interactions.
Provides a simplified interface for sending chat requests,
handling streaming, and managing tool calls.
Attributes:
provider: The underlying AI provider
default_model: Default model ID to use
http_headers: Custom HTTP headers for requests
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
provider_class: type = OpenRouterProvider,
app_name: str = APP_NAME,
app_url: str = APP_URL,
):
"""
Initialize the AI client.
Args:
api_key: API key for authentication
base_url: Optional custom base URL
provider_class: Provider class to use (default: OpenRouterProvider)
app_name: Application name for headers
app_url: Application URL for headers
"""
self.provider: AIProvider = provider_class(
api_key=api_key,
base_url=base_url,
app_name=app_name,
app_url=app_url,
)
self.default_model: Optional[str] = None
self.logger = get_logger()
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
Get available models.
Args:
filter_text_only: Whether to exclude video-only models
Returns:
List of ModelInfo objects
"""
return self.provider.list_models(filter_text_only=filter_text_only)
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: Model identifier
Returns:
ModelInfo or None if not found
"""
return self.provider.get_model(model_id)
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for provider-specific fields.
Args:
model_id: Model identifier
Returns:
Raw model dictionary or None
"""
if hasattr(self.provider, "get_raw_model"):
return self.provider.get_raw_model(model_id)
return None
def chat(
self,
messages: List[Dict[str, Any]],
model: Optional[str] = None,
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
system_prompt: Optional[str] = None,
online: bool = False,
transforms: Optional[List[str]] = None,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send a chat request.
Args:
messages: List of message dictionaries
model: Model ID (uses default if not specified)
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature
tools: Tool definitions for function calling
tool_choice: Tool selection mode
system_prompt: System prompt to prepend
online: Whether to enable online mode
transforms: List of transforms (e.g., ["middle-out"])
Returns:
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
Raises:
ValueError: If no model specified and no default set
"""
model_id = model or self.default_model
if not model_id:
raise ValueError("No model specified and no default set")
# Apply online mode suffix
if online and hasattr(self.provider, "get_effective_model_id"):
model_id = self.provider.get_effective_model_id(model_id, True)
# Convert dict messages to ChatMessage objects
chat_messages = []
# Add system prompt if provided
if system_prompt:
chat_messages.append(ChatMessage(role="system", content=system_prompt))
# Convert message dicts
for msg in messages:
# Convert tool_calls dicts to ToolCall objects if present
tool_calls_data = msg.get("tool_calls")
tool_calls_obj = None
if tool_calls_data:
from oai.providers.base import ToolCall, ToolFunction
tool_calls_obj = []
for tc in tool_calls_data:
# Handle both ToolCall objects and dicts
if isinstance(tc, ToolCall):
tool_calls_obj.append(tc)
elif isinstance(tc, dict):
func_data = tc.get("function", {})
tool_calls_obj.append(
ToolCall(
id=tc.get("id", ""),
type=tc.get("type", "function"),
function=ToolFunction(
name=func_data.get("name", ""),
arguments=func_data.get("arguments", "{}"),
),
)
)
chat_messages.append(
ChatMessage(
role=msg.get("role", "user"),
content=msg.get("content"),
tool_calls=tool_calls_obj,
tool_call_id=msg.get("tool_call_id"),
)
)
self.logger.debug(
f"Sending chat request: model={model_id}, "
f"messages={len(chat_messages)}, stream={stream}"
)
return self.provider.chat(
model=model_id,
messages=chat_messages,
stream=stream,
max_tokens=max_tokens,
temperature=temperature,
tools=tools,
tool_choice=tool_choice,
transforms=transforms,
)
def chat_with_tools(
self,
messages: List[Dict[str, Any]],
tools: List[Dict[str, Any]],
tool_executor: Callable[[str, Dict[str, Any]], Dict[str, Any]],
model: Optional[str] = None,
max_loops: int = 5,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
on_tool_call: Optional[Callable[[ToolCall], None]] = None,
on_tool_result: Optional[Callable[[str, Dict[str, Any]], None]] = None,
) -> ChatResponse:
"""
Send a chat request with automatic tool call handling.
Executes tool calls returned by the model and continues
the conversation until no more tool calls are requested.
Args:
messages: Initial messages
tools: Tool definitions
tool_executor: Function to execute tool calls
model: Model ID
max_loops: Maximum tool call iterations
max_tokens: Maximum response tokens
system_prompt: System prompt
on_tool_call: Callback when tool is called
on_tool_result: Callback when tool returns result
Returns:
Final ChatResponse after all tool calls complete
"""
model_id = model or self.default_model
if not model_id:
raise ValueError("No model specified and no default set")
# Build initial messages
chat_messages = []
if system_prompt:
chat_messages.append({"role": "system", "content": system_prompt})
chat_messages.extend(messages)
loop_count = 0
current_response: Optional[ChatResponse] = None
while loop_count < max_loops:
# Send request
response = self.chat(
messages=chat_messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
)
if not isinstance(response, ChatResponse):
raise ValueError("Expected non-streaming response")
current_response = response
# Check for tool calls
tool_calls = response.tool_calls
if not tool_calls:
break
self.logger.info(f"Model requested {len(tool_calls)} tool call(s)")
# Process each tool call
tool_results = []
for tc in tool_calls:
if on_tool_call:
on_tool_call(tc)
try:
args = json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse tool arguments: {e}")
result = {"error": f"Invalid arguments: {e}"}
else:
result = tool_executor(tc.function.name, args)
if on_tool_result:
on_tool_result(tc.function.name, result)
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps(result),
})
# Add assistant message with tool calls
assistant_msg = {
"role": "assistant",
"content": response.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
}
chat_messages.append(assistant_msg)
chat_messages.extend(tool_results)
loop_count += 1
if loop_count >= max_loops:
self.logger.warning(f"Reached max tool call loops ({max_loops})")
return current_response
def stream_chat(
self,
messages: List[Dict[str, Any]],
model: Optional[str] = None,
max_tokens: Optional[int] = None,
system_prompt: Optional[str] = None,
online: bool = False,
on_chunk: Optional[Callable[[StreamChunk], None]] = None,
) -> tuple[str, Optional[UsageStats]]:
"""
Stream a chat response and collect the full text.
Args:
messages: Chat messages
model: Model ID
max_tokens: Maximum tokens
system_prompt: System prompt
online: Online mode
on_chunk: Optional callback for each chunk
Returns:
Tuple of (full_response_text, usage_stats)
"""
response = self.chat(
messages=messages,
model=model,
stream=True,
max_tokens=max_tokens,
system_prompt=system_prompt,
online=online,
)
if isinstance(response, ChatResponse):
# Not actually streaming
return response.content or "", response.usage
full_text = ""
usage: Optional[UsageStats] = None
for chunk in response:
if chunk.error:
self.logger.error(f"Stream error: {chunk.error}")
break
if chunk.delta_content:
full_text += chunk.delta_content
if on_chunk:
on_chunk(chunk)
if chunk.usage:
usage = chunk.usage
return full_text, usage
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credit information.
Returns:
Credit info dict or None if unavailable
"""
return self.provider.get_credits()
def estimate_cost(
self,
model_id: str,
input_tokens: int,
output_tokens: int,
) -> float:
"""
Estimate cost for a completion.
Args:
model_id: Model ID
input_tokens: Number of input tokens
output_tokens: Number of output tokens
Returns:
Estimated cost in USD
"""
if hasattr(self.provider, "estimate_cost"):
return self.provider.estimate_cost(model_id, input_tokens, output_tokens)
# Fallback to default pricing
input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
return input_cost + output_cost
def set_default_model(self, model_id: str) -> None:
"""
Set the default model.
Args:
model_id: Model ID to use as default
"""
self.default_model = model_id
self.logger.info(f"Default model set to: {model_id}")
def clear_cache(self) -> None:
"""Clear the provider's model cache."""
if hasattr(self.provider, "clear_cache"):
self.provider.clear_cache()

891
oai/core/session.py Normal file
View File

@@ -0,0 +1,891 @@
"""
Chat session management for oAI.
This module provides the ChatSession class that manages an interactive
chat session including history, state, and message handling.
"""
import asyncio
import json
import time
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Tuple
from oai.commands.registry import CommandContext, CommandResult, registry
from oai.config.database import Database
from oai.config.settings import Settings
from oai.constants import (
COST_WARNING_THRESHOLD,
LOW_CREDIT_AMOUNT,
LOW_CREDIT_RATIO,
)
from oai.core.client import AIClient
from oai.mcp.manager import MCPManager
from oai.providers.base import ChatResponse, StreamChunk, UsageStats
from oai.utils.logging import get_logger
@dataclass
class SessionStats:
"""
Statistics for the current session.
Tracks tokens, costs, and message counts.
"""
total_input_tokens: int = 0
total_output_tokens: int = 0
total_cost: float = 0.0
message_count: int = 0
@property
def total_tokens(self) -> int:
"""Get total token count."""
return self.total_input_tokens + self.total_output_tokens
def add_usage(self, usage: Optional[UsageStats], cost: float = 0.0) -> None:
"""
Add usage stats from a response.
Args:
usage: Usage statistics
cost: Cost if not in usage
"""
if usage:
self.total_input_tokens += usage.prompt_tokens
self.total_output_tokens += usage.completion_tokens
if usage.total_cost_usd:
self.total_cost += usage.total_cost_usd
else:
self.total_cost += cost
else:
self.total_cost += cost
self.message_count += 1
@dataclass
class HistoryEntry:
"""
A single entry in the conversation history.
Stores the user prompt, assistant response, and metrics.
"""
prompt: str
response: str
prompt_tokens: int = 0
completion_tokens: int = 0
msg_cost: float = 0.0
timestamp: Optional[float] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format."""
return {
"prompt": self.prompt,
"response": self.response,
"prompt_tokens": self.prompt_tokens,
"completion_tokens": self.completion_tokens,
"msg_cost": self.msg_cost,
}
class ChatSession:
"""
Manages an interactive chat session.
Handles conversation history, state management, command processing,
and communication with the AI client.
Attributes:
client: AI client for API requests
settings: Application settings
mcp_manager: MCP manager for file/database access
history: Conversation history
stats: Session statistics
"""
def __init__(
self,
client: AIClient,
settings: Settings,
mcp_manager: Optional[MCPManager] = None,
):
"""
Initialize a chat session.
Args:
client: AI client instance
settings: Application settings
mcp_manager: Optional MCP manager
"""
self.client = client
self.settings = settings
self.mcp_manager = mcp_manager
self.db = Database()
self.history: List[HistoryEntry] = []
self.stats = SessionStats()
# Session state
self.system_prompt: str = settings.effective_system_prompt
self.memory_enabled: bool = True
self.memory_start_index: int = 0
self.online_enabled: bool = settings.default_online_mode
self.middle_out_enabled: bool = False
self.session_max_token: int = 0
self.current_index: int = 0
# Selected model
self.selected_model: Optional[Dict[str, Any]] = None
self.logger = get_logger()
def get_context(self) -> CommandContext:
"""
Get the current command context.
Returns:
CommandContext with current session state
"""
return CommandContext(
settings=self.settings,
provider=self.client.provider,
mcp_manager=self.mcp_manager,
selected_model_raw=self.selected_model,
session_history=[e.to_dict() for e in self.history],
session_system_prompt=self.system_prompt,
memory_enabled=self.memory_enabled,
memory_start_index=self.memory_start_index,
online_enabled=self.online_enabled,
middle_out_enabled=self.middle_out_enabled,
session_max_token=self.session_max_token,
total_input_tokens=self.stats.total_input_tokens,
total_output_tokens=self.stats.total_output_tokens,
total_cost=self.stats.total_cost,
message_count=self.stats.message_count,
current_index=self.current_index,
)
def set_model(self, model: Dict[str, Any]) -> None:
"""
Set the selected model.
Args:
model: Raw model dictionary
"""
self.selected_model = model
self.client.set_default_model(model["id"])
self.logger.info(f"Model selected: {model['id']}")
def build_api_messages(self, user_input: str) -> List[Dict[str, Any]]:
"""
Build the messages array for an API request.
Includes system prompt, history (if memory enabled), and current input.
Args:
user_input: Current user input
Returns:
List of message dictionaries
"""
messages = []
# Add system prompt
if self.system_prompt:
messages.append({"role": "system", "content": self.system_prompt})
# Add database context if in database mode
if self.mcp_manager and self.mcp_manager.enabled:
if self.mcp_manager.mode == "database" and self.mcp_manager.selected_db_index is not None:
db = self.mcp_manager.databases[self.mcp_manager.selected_db_index]
db_context = (
f"You are connected to SQLite database: {db['name']}\n"
f"Available tables: {', '.join(db['tables'])}\n\n"
"Use inspect_database, search_database, or query_database tools. "
"All queries are read-only."
)
messages.append({"role": "system", "content": db_context})
# Add history if memory enabled
if self.memory_enabled:
for i in range(self.memory_start_index, len(self.history)):
entry = self.history[i]
messages.append({"role": "user", "content": entry.prompt})
messages.append({"role": "assistant", "content": entry.response})
# Add current message
messages.append({"role": "user", "content": user_input})
return messages
def get_mcp_tools(self) -> Optional[List[Dict[str, Any]]]:
"""
Get MCP tool definitions if available.
Returns:
List of tool schemas or None
"""
if not self.mcp_manager or not self.mcp_manager.enabled:
return None
if not self.selected_model:
return None
# Check if model supports tools
supported_params = self.selected_model.get("supported_parameters", [])
if "tools" not in supported_params and "functions" not in supported_params:
return None
return self.mcp_manager.get_tools_schema()
async def execute_tool(
self,
tool_name: str,
tool_args: Dict[str, Any],
) -> Dict[str, Any]:
"""
Execute an MCP tool.
Args:
tool_name: Name of the tool
tool_args: Tool arguments
Returns:
Tool execution result
"""
if not self.mcp_manager:
return {"error": "MCP not available"}
return await self.mcp_manager.call_tool(tool_name, **tool_args)
def send_message(
self,
user_input: str,
stream: bool = True,
on_stream_chunk: Optional[Callable[[str], None]] = None,
) -> Tuple[str, Optional[UsageStats], float]:
"""
Send a message and get a response.
Args:
user_input: User's input text
stream: Whether to stream the response
on_stream_chunk: Callback for stream chunks
Returns:
Tuple of (response_text, usage_stats, response_time)
"""
if not self.selected_model:
raise ValueError("No model selected")
start_time = time.time()
messages = self.build_api_messages(user_input)
# Get MCP tools
tools = self.get_mcp_tools()
if tools:
# Disable streaming when tools are present
stream = False
# Build request parameters
model_id = self.selected_model["id"]
if self.online_enabled:
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
transforms = ["middle-out"] if self.middle_out_enabled else None
max_tokens = None
if self.session_max_token > 0:
max_tokens = self.session_max_token
if tools:
# Use tool handling flow
response = self._send_with_tools(
messages=messages,
model_id=model_id,
tools=tools,
max_tokens=max_tokens,
transforms=transforms,
)
response_time = time.time() - start_time
return response.content or "", response.usage, response_time
elif stream:
# Use streaming flow
full_text, usage = self._stream_response(
messages=messages,
model_id=model_id,
max_tokens=max_tokens,
transforms=transforms,
on_chunk=on_stream_chunk,
)
response_time = time.time() - start_time
return full_text, usage, response_time
else:
# Non-streaming request
response = self.client.chat(
messages=messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
transforms=transforms,
)
response_time = time.time() - start_time
if isinstance(response, ChatResponse):
return response.content or "", response.usage, response_time
else:
return "", None, response_time
def _send_with_tools(
self,
messages: List[Dict[str, Any]],
model_id: str,
tools: List[Dict[str, Any]],
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
) -> ChatResponse:
"""
Send a request with tool call handling.
Args:
messages: API messages
model_id: Model ID
tools: Tool definitions
max_tokens: Max tokens
transforms: Transforms list
Returns:
Final ChatResponse
"""
max_loops = 5
loop_count = 0
api_messages = list(messages)
while loop_count < max_loops:
response = self.client.chat(
messages=api_messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
transforms=transforms,
)
if not isinstance(response, ChatResponse):
raise ValueError("Expected ChatResponse")
tool_calls = response.tool_calls
if not tool_calls:
return response
# Tool calls requested by AI
tool_results = []
for tc in tool_calls:
try:
args = json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse tool arguments: {e}")
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
})
continue
# Display tool call
args_display = ", ".join(
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
for k, v in args.items()
)
# Executing tool: {tc.function.name}
# Execute tool
result = asyncio.run(self.execute_tool(tc.function.name, args))
if "error" in result:
# Tool execution error logged
pass
else:
# Tool execution successful
pass
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps(result),
})
# Add assistant message with tool calls
api_messages.append({
"role": "assistant",
"content": response.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
})
api_messages.extend(tool_results)
# Processing tool results
loop_count += 1
self.logger.warning(f"Reached max tool loops ({max_loops})")
return response
def _stream_response(
self,
messages: List[Dict[str, Any]],
model_id: str,
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
on_chunk: Optional[Callable[[str], None]] = None,
) -> Tuple[str, Optional[UsageStats]]:
"""
Stream a response with live display.
Args:
messages: API messages
model_id: Model ID
max_tokens: Max tokens
transforms: Transforms
on_chunk: Callback for chunks
Returns:
Tuple of (full_text, usage)
"""
response = self.client.chat(
messages=messages,
model=model_id,
stream=True,
max_tokens=max_tokens,
transforms=transforms,
)
if isinstance(response, ChatResponse):
return response.content or "", response.usage
full_text = ""
usage: Optional[UsageStats] = None
try:
for chunk in response:
if chunk.error:
self.logger.error(f"Stream error: {chunk.error}")
break
if chunk.delta_content:
full_text += chunk.delta_content
if on_chunk:
on_chunk(chunk.delta_content)
if chunk.usage:
usage = chunk.usage
except KeyboardInterrupt:
self.logger.info("Streaming interrupted")
return "", None
return full_text, usage
# ========== ASYNC METHODS FOR TUI ==========
async def send_message_async(
self,
user_input: str,
stream: bool = True,
) -> AsyncIterator[StreamChunk]:
"""
Async version of send_message for Textual TUI.
Args:
user_input: User's input text
stream: Whether to stream the response
Yields:
StreamChunk objects for progressive display
"""
if not self.selected_model:
raise ValueError("No model selected")
messages = self.build_api_messages(user_input)
tools = self.get_mcp_tools()
if tools:
# Disable streaming when tools are present
stream = False
model_id = self.selected_model["id"]
if self.online_enabled:
if hasattr(self.client.provider, "get_effective_model_id"):
model_id = self.client.provider.get_effective_model_id(model_id, True)
transforms = ["middle-out"] if self.middle_out_enabled else None
max_tokens = None
if self.session_max_token > 0:
max_tokens = self.session_max_token
if tools:
# Use async tool handling flow
async for chunk in self._send_with_tools_async(
messages=messages,
model_id=model_id,
tools=tools,
max_tokens=max_tokens,
transforms=transforms,
):
yield chunk
elif stream:
# Use async streaming flow
async for chunk in self._stream_response_async(
messages=messages,
model_id=model_id,
max_tokens=max_tokens,
transforms=transforms,
):
yield chunk
else:
# Non-streaming request
response = self.client.chat(
messages=messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
transforms=transforms,
)
if isinstance(response, ChatResponse):
# Yield single chunk with complete response
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
async def _send_with_tools_async(
self,
messages: List[Dict[str, Any]],
model_id: str,
tools: List[Dict[str, Any]],
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
) -> AsyncIterator[StreamChunk]:
"""
Async version of _send_with_tools for TUI.
Args:
messages: API messages
model_id: Model ID
tools: Tool definitions
max_tokens: Max tokens
transforms: Transforms list
Yields:
StreamChunk objects including tool call notifications
"""
max_loops = 5
loop_count = 0
api_messages = list(messages)
while loop_count < max_loops:
response = self.client.chat(
messages=api_messages,
model=model_id,
stream=False,
max_tokens=max_tokens,
tools=tools,
tool_choice="auto",
transforms=transforms,
)
if not isinstance(response, ChatResponse):
raise ValueError("Expected ChatResponse")
tool_calls = response.tool_calls
if not tool_calls:
# Final response, yield it
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
return
# Yield notification about tool calls
tool_notification = f"\n🔧 AI requesting {len(tool_calls)} tool call(s)...\n"
yield StreamChunk(id="", delta_content=tool_notification, usage=None, error=None)
tool_results = []
for tc in tool_calls:
try:
args = json.loads(tc.function.arguments)
except json.JSONDecodeError as e:
self.logger.error(f"Failed to parse tool arguments: {e}")
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps({"error": f"Invalid arguments: {e}"}),
})
continue
# Yield tool call display
args_display = ", ".join(
f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}"
for k, v in args.items()
)
tool_display = f"{tc.function.name}({args_display})\n"
yield StreamChunk(id="", delta_content=tool_display, usage=None, error=None)
# Execute tool (await instead of asyncio.run)
result = await self.execute_tool(tc.function.name, args)
if "error" in result:
error_msg = f" ✗ Error: {result['error']}\n"
yield StreamChunk(id="", delta_content=error_msg, usage=None, error=None)
else:
success_msg = self._format_tool_success(tc.function.name, result)
yield StreamChunk(id="", delta_content=success_msg, usage=None, error=None)
tool_results.append({
"tool_call_id": tc.id,
"role": "tool",
"name": tc.function.name,
"content": json.dumps(result),
})
# Add assistant message with tool calls
api_messages.append({
"role": "assistant",
"content": response.content,
"tool_calls": [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in tool_calls
],
})
# Add tool results
api_messages.extend(tool_results)
loop_count += 1
# Max loops reached
yield StreamChunk(
id="",
delta_content="\n⚠️ Maximum tool call loops reached\n",
usage=None,
error="Max loops reached"
)
def _format_tool_success(self, tool_name: str, result: Dict[str, Any]) -> str:
"""Format a success message for a tool call."""
if tool_name == "search_files":
count = result.get("count", 0)
return f" ✓ Found {count} file(s)\n"
elif tool_name == "read_file":
size = result.get("size", 0)
truncated = " (truncated)" if result.get("truncated") else ""
return f" ✓ Read {size} bytes{truncated}\n"
elif tool_name == "list_directory":
count = result.get("count", 0)
return f" ✓ Listed {count} item(s)\n"
elif tool_name == "inspect_database":
if "table" in result:
return f" ✓ Inspected table: {result['table']}\n"
else:
return f" ✓ Inspected database ({result.get('table_count', 0)} tables)\n"
elif tool_name == "search_database":
count = result.get("count", 0)
return f" ✓ Found {count} match(es)\n"
elif tool_name == "query_database":
count = result.get("count", 0)
return f" ✓ Query returned {count} row(s)\n"
else:
return " ✓ Success\n"
async def _stream_response_async(
self,
messages: List[Dict[str, Any]],
model_id: str,
max_tokens: Optional[int] = None,
transforms: Optional[List[str]] = None,
) -> AsyncIterator[StreamChunk]:
"""
Async version of _stream_response for TUI.
Args:
messages: API messages
model_id: Model ID
max_tokens: Max tokens
transforms: Transforms
Yields:
StreamChunk objects
"""
response = self.client.chat(
messages=messages,
model=model_id,
stream=True,
max_tokens=max_tokens,
transforms=transforms,
)
if isinstance(response, ChatResponse):
# Non-streaming response
chunk = StreamChunk(
id="",
delta_content=response.content,
usage=response.usage,
error=None,
)
yield chunk
return
# Stream the response
for chunk in response:
if chunk.error:
yield StreamChunk(id="", delta_content=None, usage=None, error=chunk.error)
break
yield chunk
# ========== END ASYNC METHODS ==========
def add_to_history(
self,
prompt: str,
response: str,
usage: Optional[UsageStats] = None,
cost: float = 0.0,
) -> None:
"""
Add an exchange to the history.
Args:
prompt: User prompt
response: Assistant response
usage: Usage statistics
cost: Cost if not in usage
"""
entry = HistoryEntry(
prompt=prompt,
response=response,
prompt_tokens=usage.prompt_tokens if usage else 0,
completion_tokens=usage.completion_tokens if usage else 0,
msg_cost=usage.total_cost_usd if usage and usage.total_cost_usd else cost,
timestamp=time.time(),
)
self.history.append(entry)
self.current_index = len(self.history) - 1
self.stats.add_usage(usage, cost)
def save_conversation(self, name: str) -> bool:
"""
Save the current conversation.
Args:
name: Name for the saved conversation
Returns:
True if saved successfully
"""
if not self.history:
return False
data = [e.to_dict() for e in self.history]
self.db.save_conversation(name, data)
self.logger.info(f"Saved conversation: {name}")
return True
def load_conversation(self, name: str) -> bool:
"""
Load a saved conversation.
Args:
name: Name of the conversation to load
Returns:
True if loaded successfully
"""
data = self.db.load_conversation(name)
if not data:
return False
self.history.clear()
for entry_dict in data:
self.history.append(HistoryEntry(
prompt=entry_dict.get("prompt", ""),
response=entry_dict.get("response", ""),
prompt_tokens=entry_dict.get("prompt_tokens", 0),
completion_tokens=entry_dict.get("completion_tokens", 0),
msg_cost=entry_dict.get("msg_cost", 0.0),
))
self.current_index = len(self.history) - 1
self.memory_start_index = 0
self.stats = SessionStats() # Reset stats for loaded conversation
self.logger.info(f"Loaded conversation: {name}")
return True
def reset(self) -> None:
"""Reset the session state."""
self.history.clear()
self.stats = SessionStats()
self.system_prompt = ""
self.memory_start_index = 0
self.current_index = 0
self.logger.info("Session reset")
def check_warnings(self) -> List[str]:
"""
Check for cost and credit warnings.
Returns:
List of warning messages
"""
warnings = []
# Check last message cost
if self.history:
last_cost = self.history[-1].msg_cost
threshold = self.settings.cost_warning_threshold
if last_cost > threshold:
warnings.append(
f"High cost: ${last_cost:.4f} exceeds threshold ${threshold:.4f}"
)
# Check credits
credits = self.client.get_credits()
if credits:
left = credits.get("credits_left", 0)
total = credits.get("total_credits", 0)
if left < LOW_CREDIT_AMOUNT:
warnings.append(f"Low credits: ${left:.2f} remaining!")
elif total > 0 and left < total * LOW_CREDIT_RATIO:
warnings.append(f"Credits low: less than 10% remaining (${left:.2f})")
return warnings

28
oai/mcp/__init__.py Normal file
View File

@@ -0,0 +1,28 @@
"""
Model Context Protocol (MCP) integration for oAI.
This package provides filesystem and database access capabilities
through the MCP standard, allowing AI models to interact with
local files and SQLite databases safely.
Key components:
- MCPManager: High-level manager for MCP operations
- MCPFilesystemServer: Filesystem and database access implementation
- GitignoreParser: Pattern matching for .gitignore support
- SQLiteQueryValidator: Query safety validation
- CrossPlatformMCPConfig: OS-specific configuration
"""
from oai.mcp.manager import MCPManager
from oai.mcp.server import MCPFilesystemServer
from oai.mcp.gitignore import GitignoreParser
from oai.mcp.validators import SQLiteQueryValidator
from oai.mcp.platform import CrossPlatformMCPConfig
__all__ = [
"MCPManager",
"MCPFilesystemServer",
"GitignoreParser",
"SQLiteQueryValidator",
"CrossPlatformMCPConfig",
]

166
oai/mcp/gitignore.py Normal file
View File

@@ -0,0 +1,166 @@
"""
Gitignore pattern parsing for oAI MCP.
This module implements .gitignore pattern matching to filter files
during MCP filesystem operations.
"""
import fnmatch
from pathlib import Path
from typing import List, Tuple
from oai.utils.logging import get_logger
class GitignoreParser:
"""
Parse and apply .gitignore patterns.
Supports standard gitignore syntax including:
- Wildcards (*) and double wildcards (**)
- Directory-only patterns (ending with /)
- Negation patterns (starting with !)
- Comments (lines starting with #)
Patterns are applied relative to the directory containing
the .gitignore file.
"""
def __init__(self):
"""Initialize an empty pattern collection."""
# List of (pattern, is_negation, source_dir)
self.patterns: List[Tuple[str, bool, Path]] = []
def add_gitignore(self, gitignore_path: Path) -> None:
"""
Parse and add patterns from a .gitignore file.
Args:
gitignore_path: Path to the .gitignore file
"""
logger = get_logger()
if not gitignore_path.exists():
return
try:
source_dir = gitignore_path.parent
with open(gitignore_path, "r", encoding="utf-8") as f:
for line_num, line in enumerate(f, 1):
line = line.rstrip("\n\r")
# Skip empty lines and comments
if not line or line.startswith("#"):
continue
# Check for negation pattern
is_negation = line.startswith("!")
if is_negation:
line = line[1:]
# Remove leading slash (make relative to gitignore location)
if line.startswith("/"):
line = line[1:]
self.patterns.append((line, is_negation, source_dir))
logger.debug(
f"Loaded {len(self.patterns)} patterns from {gitignore_path}"
)
except Exception as e:
logger.warning(f"Error reading {gitignore_path}: {e}")
def should_ignore(self, path: Path) -> bool:
"""
Check if a path should be ignored based on gitignore patterns.
Patterns are evaluated in order, with later patterns overriding
earlier ones. Negation patterns (starting with !) un-ignore
previously matched paths.
Args:
path: Path to check
Returns:
True if the path should be ignored
"""
if not self.patterns:
return False
ignored = False
for pattern, is_negation, source_dir in self.patterns:
# Only apply pattern if path is under the source directory
try:
rel_path = path.relative_to(source_dir)
except ValueError:
# Path is not relative to this gitignore's directory
continue
rel_path_str = str(rel_path)
# Check if pattern matches
if self._match_pattern(pattern, rel_path_str, path.is_dir()):
if is_negation:
ignored = False # Negation patterns un-ignore
else:
ignored = True
return ignored
def _match_pattern(self, pattern: str, path: str, is_dir: bool) -> bool:
"""
Match a gitignore pattern against a path.
Args:
pattern: The gitignore pattern
path: The relative path string to match
is_dir: Whether the path is a directory
Returns:
True if the pattern matches
"""
# Directory-only pattern (ends with /)
if pattern.endswith("/"):
if not is_dir:
return False
pattern = pattern[:-1]
# Handle ** patterns (matches any number of directories)
if "**" in pattern:
pattern_parts = pattern.split("**")
if len(pattern_parts) == 2:
prefix, suffix = pattern_parts
# Match if path starts with prefix and ends with suffix
if prefix:
if not path.startswith(prefix.rstrip("/")):
return False
if suffix:
suffix = suffix.lstrip("/")
if not (path.endswith(suffix) or f"/{suffix}" in path):
return False
return True
# Direct match using fnmatch
if fnmatch.fnmatch(path, pattern):
return True
# Match as subdirectory pattern (pattern without / matches in any directory)
if "/" not in pattern:
parts = path.split("/")
if any(fnmatch.fnmatch(part, pattern) for part in parts):
return True
return False
def clear(self) -> None:
"""Clear all loaded patterns."""
self.patterns = []
@property
def pattern_count(self) -> int:
"""Get the number of loaded patterns."""
return len(self.patterns)

1365
oai/mcp/manager.py Normal file

File diff suppressed because it is too large Load Diff

228
oai/mcp/platform.py Normal file
View File

@@ -0,0 +1,228 @@
"""
Cross-platform MCP configuration for oAI.
This module handles OS-specific configuration, path handling,
and security checks for the MCP filesystem server.
"""
import os
import platform
import subprocess
from pathlib import Path
from typing import List, Dict, Any, Optional
from oai.constants import SYSTEM_DIRS_BLACKLIST
from oai.utils.logging import get_logger
class CrossPlatformMCPConfig:
"""
Handle OS-specific MCP configuration.
Provides methods for path normalization, security validation,
and OS-specific default directories.
Attributes:
system: Operating system name
is_macos: Whether running on macOS
is_linux: Whether running on Linux
is_windows: Whether running on Windows
"""
def __init__(self):
"""Initialize platform detection."""
self.system = platform.system()
self.is_macos = self.system == "Darwin"
self.is_linux = self.system == "Linux"
self.is_windows = self.system == "Windows"
logger = get_logger()
logger.info(f"Detected OS: {self.system}")
def get_default_allowed_dirs(self) -> List[Path]:
"""
Get safe default directories for the current OS.
Returns:
List of default directories that are safe to access
"""
home = Path.home()
if self.is_macos:
return [
home / "Documents",
home / "Desktop",
home / "Downloads",
]
elif self.is_linux:
dirs = [home / "Documents"]
# Try to get XDG directories
try:
for xdg_dir in ["DOCUMENTS", "DESKTOP", "DOWNLOAD"]:
result = subprocess.run(
["xdg-user-dir", xdg_dir],
capture_output=True,
text=True,
timeout=1
)
if result.returncode == 0:
dir_path = Path(result.stdout.strip())
if dir_path.exists():
dirs.append(dir_path)
except (subprocess.TimeoutExpired, FileNotFoundError):
# Fallback to standard locations
dirs.extend([
home / "Desktop",
home / "Downloads",
])
return list(set(dirs))
elif self.is_windows:
return [
home / "Documents",
home / "Desktop",
home / "Downloads",
]
# Fallback for unknown OS
return [home]
def get_python_command(self) -> str:
"""
Get the Python executable path.
Returns:
Path to the Python executable
"""
import sys
return sys.executable
def get_filesystem_warning(self) -> str:
"""
Get OS-specific security warning message.
Returns:
Warning message for the current OS
"""
if self.is_macos:
return """
Note: macOS Security
The Filesystem MCP server needs access to your selected folder.
You may see a security prompt - click 'Allow' to proceed.
(System Settings > Privacy & Security > Files and Folders)
"""
elif self.is_linux:
return """
Note: Linux Security
The Filesystem MCP server will access your selected folder.
Ensure oAI has appropriate file permissions.
"""
elif self.is_windows:
return """
Note: Windows Security
The Filesystem MCP server will access your selected folder.
You may need to grant file access permissions.
"""
return ""
def normalize_path(self, path: str) -> Path:
"""
Normalize a path for the current OS.
Expands user directory (~) and resolves to absolute path.
Args:
path: Path string to normalize
Returns:
Normalized absolute Path
"""
return Path(os.path.expanduser(path)).resolve()
def is_system_directory(self, path: Path) -> bool:
"""
Check if a path is a protected system directory.
Args:
path: Path to check
Returns:
True if the path is a system directory
"""
path_str = str(path)
for blocked in SYSTEM_DIRS_BLACKLIST:
if path_str.startswith(blocked):
return True
return False
def is_safe_path(self, requested_path: Path, allowed_dirs: List[Path]) -> bool:
"""
Check if a path is within allowed directories.
Args:
requested_path: Path being requested
allowed_dirs: List of allowed parent directories
Returns:
True if the path is within an allowed directory
"""
try:
requested = requested_path.resolve()
for allowed in allowed_dirs:
try:
allowed_resolved = allowed.resolve()
requested.relative_to(allowed_resolved)
return True
except ValueError:
continue
return False
except Exception:
return False
def get_folder_stats(self, folder: Path) -> Dict[str, Any]:
"""
Get statistics for a folder.
Args:
folder: Path to the folder
Returns:
Dictionary with folder statistics:
- exists: Whether the folder exists
- file_count: Number of files (if exists)
- total_size: Total size in bytes (if exists)
- size_mb: Size in megabytes (if exists)
- error: Error message (if any)
"""
logger = get_logger()
try:
if not folder.exists() or not folder.is_dir():
return {"exists": False}
file_count = 0
total_size = 0
for item in folder.rglob("*"):
if item.is_file():
file_count += 1
try:
total_size += item.stat().st_size
except (OSError, PermissionError):
pass
return {
"exists": True,
"file_count": file_count,
"total_size": total_size,
"size_mb": total_size / (1024 * 1024),
}
except Exception as e:
logger.error(f"Error getting folder stats for {folder}: {e}")
return {"exists": False, "error": str(e)}

1368
oai/mcp/server.py Normal file

File diff suppressed because it is too large Load Diff

123
oai/mcp/validators.py Normal file
View File

@@ -0,0 +1,123 @@
"""
Query validation for oAI MCP database operations.
This module provides safety validation for SQL queries to ensure
only read-only operations are executed.
"""
import re
from typing import Tuple
from oai.constants import DANGEROUS_SQL_KEYWORDS
class SQLiteQueryValidator:
"""
Validate SQLite queries for read-only safety.
Ensures that only SELECT queries (including CTEs) are allowed
and blocks potentially dangerous operations like INSERT, UPDATE,
DELETE, DROP, etc.
"""
@staticmethod
def is_safe_query(query: str) -> Tuple[bool, str]:
"""
Validate that a query is a safe read-only SELECT.
The validation:
1. Checks that query starts with SELECT or WITH
2. Strips string literals before checking for dangerous keywords
3. Blocks any dangerous keywords outside of string literals
Args:
query: SQL query string to validate
Returns:
Tuple of (is_safe, error_message)
- is_safe: True if the query is safe to execute
- error_message: Description of why query is unsafe (empty if safe)
Examples:
>>> SQLiteQueryValidator.is_safe_query("SELECT * FROM users")
(True, "")
>>> SQLiteQueryValidator.is_safe_query("DELETE FROM users")
(False, "Only SELECT queries are allowed...")
>>> SQLiteQueryValidator.is_safe_query("SELECT 'DELETE' FROM users")
(True, "") # 'DELETE' is inside a string literal
"""
query_upper = query.strip().upper()
# Must start with SELECT or WITH (for CTEs)
if not (query_upper.startswith("SELECT") or query_upper.startswith("WITH")):
return False, "Only SELECT queries are allowed (including WITH/CTE)"
# Remove string literals before checking for dangerous keywords
# This prevents false positives when keywords appear in data
query_no_strings = re.sub(r"'[^']*'", "", query_upper)
query_no_strings = re.sub(r'"[^"]*"', "", query_no_strings)
# Check for dangerous keywords outside of quotes
for keyword in DANGEROUS_SQL_KEYWORDS:
if re.search(r"\b" + keyword + r"\b", query_no_strings):
return False, f"Keyword '{keyword}' not allowed in read-only mode"
return True, ""
@staticmethod
def sanitize_table_name(table_name: str) -> str:
"""
Sanitize a table name to prevent SQL injection.
Only allows alphanumeric characters and underscores.
Args:
table_name: Table name to sanitize
Returns:
Sanitized table name
Raises:
ValueError: If table name contains invalid characters
"""
# Remove any characters that aren't alphanumeric or underscore
sanitized = re.sub(r"[^\w]", "", table_name)
if not sanitized:
raise ValueError("Table name cannot be empty after sanitization")
if sanitized != table_name:
raise ValueError(
f"Table name contains invalid characters: {table_name}"
)
return sanitized
@staticmethod
def sanitize_column_name(column_name: str) -> str:
"""
Sanitize a column name to prevent SQL injection.
Only allows alphanumeric characters and underscores.
Args:
column_name: Column name to sanitize
Returns:
Sanitized column name
Raises:
ValueError: If column name contains invalid characters
"""
# Remove any characters that aren't alphanumeric or underscore
sanitized = re.sub(r"[^\w]", "", column_name)
if not sanitized:
raise ValueError("Column name cannot be empty after sanitization")
if sanitized != column_name:
raise ValueError(
f"Column name contains invalid characters: {column_name}"
)
return sanitized

32
oai/providers/__init__.py Normal file
View File

@@ -0,0 +1,32 @@
"""
Provider abstraction for oAI.
This module provides a unified interface for AI model providers,
enabling easy extension to support additional providers beyond OpenRouter.
"""
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ToolCall,
ToolFunction,
UsageStats,
ModelInfo,
ProviderCapabilities,
)
from oai.providers.openrouter import OpenRouterProvider
__all__ = [
# Base classes and types
"AIProvider",
"ChatMessage",
"ChatResponse",
"ToolCall",
"ToolFunction",
"UsageStats",
"ModelInfo",
"ProviderCapabilities",
# Provider implementations
"OpenRouterProvider",
]

413
oai/providers/base.py Normal file
View File

@@ -0,0 +1,413 @@
"""
Abstract base provider for AI model integration.
This module defines the interface that all AI providers must implement,
along with common data structures for requests and responses.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import (
Any,
AsyncIterator,
Dict,
Iterator,
List,
Optional,
Union,
)
class MessageRole(str, Enum):
"""Message roles in a conversation."""
SYSTEM = "system"
USER = "user"
ASSISTANT = "assistant"
TOOL = "tool"
@dataclass
class ToolFunction:
"""
Represents a function within a tool call.
Attributes:
name: The function name
arguments: JSON string of function arguments
"""
name: str
arguments: str
@dataclass
class ToolCall:
"""
Represents a tool/function call requested by the model.
Attributes:
id: Unique identifier for this tool call
type: Type of tool call (usually "function")
function: The function being called
"""
id: str
type: str
function: ToolFunction
@dataclass
class UsageStats:
"""
Token usage statistics from an API response.
Attributes:
prompt_tokens: Number of tokens in the prompt
completion_tokens: Number of tokens in the completion
total_tokens: Total tokens used
total_cost_usd: Cost in USD (if available from API)
"""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
total_cost_usd: Optional[float] = None
@property
def input_tokens(self) -> int:
"""Alias for prompt_tokens."""
return self.prompt_tokens
@property
def output_tokens(self) -> int:
"""Alias for completion_tokens."""
return self.completion_tokens
@dataclass
class ChatMessage:
"""
A single message in a chat conversation.
Attributes:
role: The role of the message sender
content: Message content (text or structured content blocks)
name: Optional name for the sender
tool_calls: List of tool calls (for assistant messages)
tool_call_id: Tool call ID this message responds to (for tool messages)
"""
role: str
content: Union[str, List[Dict[str, Any]], None] = None
name: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
tool_call_id: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary format for API requests."""
result: Dict[str, Any] = {"role": self.role}
if self.content is not None:
result["content"] = self.content
if self.name:
result["name"] = self.name
if self.tool_calls:
result["tool_calls"] = [
{
"id": tc.id,
"type": tc.type,
"function": {
"name": tc.function.name,
"arguments": tc.function.arguments,
},
}
for tc in self.tool_calls
]
if self.tool_call_id:
result["tool_call_id"] = self.tool_call_id
return result
@dataclass
class ChatResponseChoice:
"""
A single choice in a chat response.
Attributes:
index: Index of this choice
message: The response message
finish_reason: Why the response ended
"""
index: int
message: ChatMessage
finish_reason: Optional[str] = None
@dataclass
class ChatResponse:
"""
Response from a chat completion request.
Attributes:
id: Unique identifier for this response
choices: List of response choices
usage: Token usage statistics
model: Model that generated this response
created: Unix timestamp of creation
"""
id: str
choices: List[ChatResponseChoice]
usage: Optional[UsageStats] = None
model: Optional[str] = None
created: Optional[int] = None
@property
def message(self) -> Optional[ChatMessage]:
"""Get the first choice's message."""
if self.choices:
return self.choices[0].message
return None
@property
def content(self) -> Optional[str]:
"""Get the text content of the first choice."""
msg = self.message
if msg and isinstance(msg.content, str):
return msg.content
return None
@property
def tool_calls(self) -> Optional[List[ToolCall]]:
"""Get tool calls from the first choice."""
msg = self.message
if msg:
return msg.tool_calls
return None
@dataclass
class StreamChunk:
"""
A single chunk from a streaming response.
Attributes:
id: Response ID
delta_content: New content in this chunk
finish_reason: Finish reason (if this is the last chunk)
usage: Usage stats (usually in the last chunk)
error: Error message if something went wrong
"""
id: str
delta_content: Optional[str] = None
finish_reason: Optional[str] = None
usage: Optional[UsageStats] = None
error: Optional[str] = None
@dataclass
class ModelInfo:
"""
Information about an AI model.
Attributes:
id: Unique model identifier
name: Display name
description: Model description
context_length: Maximum context window size
pricing: Pricing info (input/output per million tokens)
supported_parameters: List of supported API parameters
input_modalities: Supported input types (text, image, etc.)
output_modalities: Supported output types
"""
id: str
name: str
description: str = ""
context_length: int = 0
pricing: Dict[str, float] = field(default_factory=dict)
supported_parameters: List[str] = field(default_factory=list)
input_modalities: List[str] = field(default_factory=lambda: ["text"])
output_modalities: List[str] = field(default_factory=lambda: ["text"])
def supports_images(self) -> bool:
"""Check if model supports image input."""
return "image" in self.input_modalities
def supports_tools(self) -> bool:
"""Check if model supports function calling/tools."""
return "tools" in self.supported_parameters or "functions" in self.supported_parameters
def supports_streaming(self) -> bool:
"""Check if model supports streaming responses."""
return "stream" in self.supported_parameters
def supports_online(self) -> bool:
"""Check if model supports web search (online mode)."""
return self.supports_tools()
@dataclass
class ProviderCapabilities:
"""
Capabilities supported by a provider.
Attributes:
streaming: Provider supports streaming responses
tools: Provider supports function calling
images: Provider supports image inputs
online: Provider supports web search
max_context: Maximum context length across all models
"""
streaming: bool = True
tools: bool = True
images: bool = True
online: bool = False
max_context: int = 128000
class AIProvider(ABC):
"""
Abstract base class for AI model providers.
All provider implementations must inherit from this class
and implement the required abstract methods.
"""
def __init__(self, api_key: str, base_url: Optional[str] = None):
"""
Initialize the provider.
Args:
api_key: API key for authentication
base_url: Optional custom base URL for the API
"""
self.api_key = api_key
self.base_url = base_url
@property
@abstractmethod
def name(self) -> str:
"""Get the provider name."""
pass
@property
@abstractmethod
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
pass
@abstractmethod
def list_models(self) -> List[ModelInfo]:
"""
Fetch available models from the provider.
Returns:
List of available models with their info
"""
pass
@abstractmethod
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: The model identifier
Returns:
Model information or None if not found
"""
pass
@abstractmethod
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send a chat completion request.
Args:
model: Model ID to use
messages: List of chat messages
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature
tools: List of tool definitions for function calling
tool_choice: How to handle tool selection ("auto", "none", etc.)
**kwargs: Additional provider-specific parameters
Returns:
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
"""
pass
@abstractmethod
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send an async chat completion request.
Args:
model: Model ID to use
messages: List of chat messages
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature
tools: List of tool definitions for function calling
tool_choice: How to handle tool selection
**kwargs: Additional provider-specific parameters
Returns:
ChatResponse for non-streaming, AsyncIterator[StreamChunk] for streaming
"""
pass
@abstractmethod
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get account credit/balance information.
Returns:
Dict with credit info or None if not supported
"""
pass
def validate_api_key(self) -> bool:
"""
Validate that the API key is valid.
Returns:
True if API key is valid
"""
try:
self.list_models()
return True
except Exception:
return False

630
oai/providers/openrouter.py Normal file
View File

@@ -0,0 +1,630 @@
"""
OpenRouter provider implementation.
This module implements the AIProvider interface for OpenRouter,
supporting chat completions, streaming, and function calling.
"""
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union
import requests
from openrouter import OpenRouter
from oai.constants import APP_NAME, APP_URL, DEFAULT_BASE_URL
from oai.providers.base import (
AIProvider,
ChatMessage,
ChatResponse,
ChatResponseChoice,
ModelInfo,
ProviderCapabilities,
StreamChunk,
ToolCall,
ToolFunction,
UsageStats,
)
from oai.utils.logging import get_logger
class OpenRouterProvider(AIProvider):
"""
OpenRouter API provider implementation.
Provides access to multiple AI models through OpenRouter's unified API,
supporting chat completions, streaming responses, and function calling.
Attributes:
client: The underlying OpenRouter client
_models_cache: Cached list of available models
"""
def __init__(
self,
api_key: str,
base_url: Optional[str] = None,
app_name: str = APP_NAME,
app_url: str = APP_URL,
):
"""
Initialize the OpenRouter provider.
Args:
api_key: OpenRouter API key
base_url: Optional custom base URL
app_name: Application name for API headers
app_url: Application URL for API headers
"""
super().__init__(api_key, base_url or DEFAULT_BASE_URL)
self.app_name = app_name
self.app_url = app_url
self.client = OpenRouter(api_key=api_key)
self._models_cache: Optional[List[ModelInfo]] = None
self._raw_models_cache: Optional[List[Dict[str, Any]]] = None
self.logger = get_logger()
self.logger.info(f"OpenRouter provider initialized with base URL: {self.base_url}")
@property
def name(self) -> str:
"""Get the provider name."""
return "OpenRouter"
@property
def capabilities(self) -> ProviderCapabilities:
"""Get provider capabilities."""
return ProviderCapabilities(
streaming=True,
tools=True,
images=True,
online=True,
max_context=2000000, # Claude models support up to 200k
)
def _get_headers(self) -> Dict[str, str]:
"""Get standard HTTP headers for API requests."""
headers = {
"HTTP-Referer": self.app_url,
"X-Title": self.app_name,
}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
return headers
def _parse_model(self, model_data: Dict[str, Any]) -> ModelInfo:
"""
Parse raw model data into ModelInfo.
Args:
model_data: Raw model data from API
Returns:
Parsed ModelInfo object
"""
architecture = model_data.get("architecture", {})
pricing_data = model_data.get("pricing", {})
# Parse pricing (convert from string to float if needed)
pricing = {}
for key in ["prompt", "completion"]:
value = pricing_data.get(key)
if value is not None:
try:
# Convert from per-token to per-million-tokens
pricing[key] = float(value) * 1_000_000
except (ValueError, TypeError):
pricing[key] = 0.0
return ModelInfo(
id=model_data.get("id", ""),
name=model_data.get("name", model_data.get("id", "")),
description=model_data.get("description", ""),
context_length=model_data.get("context_length", 0),
pricing=pricing,
supported_parameters=model_data.get("supported_parameters", []),
input_modalities=architecture.get("input_modalities", ["text"]),
output_modalities=architecture.get("output_modalities", ["text"]),
)
def list_models(self, filter_text_only: bool = True) -> List[ModelInfo]:
"""
Fetch available models from OpenRouter.
Args:
filter_text_only: If True, exclude video-only models
Returns:
List of available models
Raises:
Exception: If API request fails
"""
if self._models_cache is not None:
return self._models_cache
try:
response = requests.get(
f"{self.base_url}/models",
headers=self._get_headers(),
timeout=10,
)
response.raise_for_status()
raw_models = response.json().get("data", [])
self._raw_models_cache = raw_models
models = []
for model_data in raw_models:
# Optionally filter out video-only models
if filter_text_only:
modalities = model_data.get("modalities", [])
if modalities and "video" in modalities and "text" not in modalities:
continue
models.append(self._parse_model(model_data))
self._models_cache = models
self.logger.info(f"Fetched {len(models)} models from OpenRouter")
return models
except requests.RequestException as e:
self.logger.error(f"Failed to fetch models: {e}")
raise
def get_raw_models(self) -> List[Dict[str, Any]]:
"""
Get raw model data as returned by the API.
Useful for accessing provider-specific fields not in ModelInfo.
Returns:
List of raw model dictionaries
"""
if self._raw_models_cache is None:
self.list_models()
return self._raw_models_cache or []
def get_model(self, model_id: str) -> Optional[ModelInfo]:
"""
Get information about a specific model.
Args:
model_id: The model identifier
Returns:
Model information or None if not found
"""
models = self.list_models()
for model in models:
if model.id == model_id:
return model
return None
def get_raw_model(self, model_id: str) -> Optional[Dict[str, Any]]:
"""
Get raw model data for a specific model.
Args:
model_id: The model identifier
Returns:
Raw model dictionary or None if not found
"""
raw_models = self.get_raw_models()
for model in raw_models:
if model.get("id") == model_id:
return model
return None
def _convert_messages(self, messages: List[ChatMessage]) -> List[Dict[str, Any]]:
"""
Convert ChatMessage objects to API format.
Args:
messages: List of ChatMessage objects
Returns:
List of message dictionaries for the API
"""
return [msg.to_dict() for msg in messages]
def _parse_usage(self, usage_data: Any) -> Optional[UsageStats]:
"""
Parse usage data from API response.
Args:
usage_data: Raw usage data from API
Returns:
Parsed UsageStats or None
"""
if not usage_data:
return None
# Handle both attribute and dict access
prompt_tokens = 0
completion_tokens = 0
total_cost = None
if hasattr(usage_data, "prompt_tokens"):
prompt_tokens = getattr(usage_data, "prompt_tokens", 0) or 0
elif isinstance(usage_data, dict):
prompt_tokens = usage_data.get("prompt_tokens", 0) or 0
if hasattr(usage_data, "completion_tokens"):
completion_tokens = getattr(usage_data, "completion_tokens", 0) or 0
elif isinstance(usage_data, dict):
completion_tokens = usage_data.get("completion_tokens", 0) or 0
# Try alternative naming (input_tokens/output_tokens)
if prompt_tokens == 0:
if hasattr(usage_data, "input_tokens"):
prompt_tokens = getattr(usage_data, "input_tokens", 0) or 0
elif isinstance(usage_data, dict):
prompt_tokens = usage_data.get("input_tokens", 0) or 0
if completion_tokens == 0:
if hasattr(usage_data, "output_tokens"):
completion_tokens = getattr(usage_data, "output_tokens", 0) or 0
elif isinstance(usage_data, dict):
completion_tokens = usage_data.get("output_tokens", 0) or 0
# Get cost if available
# OpenRouter returns cost in different places:
# 1. As 'total_cost_usd' in usage object (rare)
# 2. As 'usage' at root level (common - this is the dollar amount)
total_cost = None
if hasattr(usage_data, "total_cost_usd"):
total_cost = getattr(usage_data, "total_cost_usd", None)
elif hasattr(usage_data, "usage"):
# OpenRouter puts cost as 'usage' field (dollar amount)
total_cost = getattr(usage_data, "usage", None)
elif isinstance(usage_data, dict):
total_cost = usage_data.get("total_cost_usd") or usage_data.get("usage")
return UsageStats(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
total_cost_usd=float(total_cost) if total_cost else None,
)
def _parse_tool_calls(self, tool_calls_data: Any) -> Optional[List[ToolCall]]:
"""
Parse tool calls from API response.
Args:
tool_calls_data: Raw tool calls data
Returns:
List of ToolCall objects or None
"""
if not tool_calls_data:
return None
tool_calls = []
for tc in tool_calls_data:
# Handle both attribute and dict access
if hasattr(tc, "id"):
tc_id = tc.id
tc_type = getattr(tc, "type", "function")
func = tc.function
func_name = func.name
func_args = func.arguments
else:
tc_id = tc.get("id", "")
tc_type = tc.get("type", "function")
func = tc.get("function", {})
func_name = func.get("name", "")
func_args = func.get("arguments", "{}")
tool_calls.append(
ToolCall(
id=tc_id,
type=tc_type,
function=ToolFunction(name=func_name, arguments=func_args),
)
)
return tool_calls if tool_calls else None
def _parse_response(self, response: Any) -> ChatResponse:
"""
Parse API response into ChatResponse.
Args:
response: Raw API response
Returns:
Parsed ChatResponse
"""
choices = []
for choice in response.choices:
msg = choice.message
message = ChatMessage(
role=msg.role if hasattr(msg, "role") else "assistant",
content=msg.content if hasattr(msg, "content") else None,
tool_calls=self._parse_tool_calls(
getattr(msg, "tool_calls", None)
),
)
choices.append(
ChatResponseChoice(
index=choice.index if hasattr(choice, "index") else 0,
message=message,
finish_reason=getattr(choice, "finish_reason", None),
)
)
return ChatResponse(
id=response.id if hasattr(response, "id") else "",
choices=choices,
usage=self._parse_usage(getattr(response, "usage", None)),
model=getattr(response, "model", None),
created=getattr(response, "created", None),
)
def chat(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
transforms: Optional[List[str]] = None,
**kwargs: Any,
) -> Union[ChatResponse, Iterator[StreamChunk]]:
"""
Send a chat completion request to OpenRouter.
Args:
model: Model ID to use
messages: List of chat messages
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature (0-2)
tools: List of tool definitions for function calling
tool_choice: How to handle tool selection ("auto", "none", etc.)
transforms: List of transforms (e.g., ["middle-out"])
**kwargs: Additional parameters
Returns:
ChatResponse for non-streaming, Iterator[StreamChunk] for streaming
"""
# Build request parameters
params: Dict[str, Any] = {
"model": model,
"messages": self._convert_messages(messages),
"stream": stream,
"http_headers": self._get_headers(),
}
# Request usage stats in streaming responses
if stream:
params["stream_options"] = {"include_usage": True}
if max_tokens is not None:
params["max_tokens"] = max_tokens
if temperature is not None:
params["temperature"] = temperature
if tools:
params["tools"] = tools
params["tool_choice"] = tool_choice or "auto"
if transforms:
params["transforms"] = transforms
# Add any additional parameters
params.update(kwargs)
self.logger.debug(f"Sending chat request to model {model}")
try:
response = self.client.chat.send(**params)
if stream:
return self._stream_response(response)
else:
return self._parse_response(response)
except Exception as e:
self.logger.error(f"Chat request failed: {e}")
raise
def _stream_response(self, response: Any) -> Iterator[StreamChunk]:
"""
Process a streaming response.
Args:
response: Streaming response from API
Yields:
StreamChunk objects
"""
last_usage = None
try:
for chunk in response:
# Check for errors
if hasattr(chunk, "error") and chunk.error:
yield StreamChunk(
id=getattr(chunk, "id", ""),
error=chunk.error.message if hasattr(chunk.error, "message") else str(chunk.error),
)
return
# Extract delta content
delta_content = None
finish_reason = None
if hasattr(chunk, "choices") and chunk.choices:
choice = chunk.choices[0]
if hasattr(choice, "delta"):
delta = choice.delta
if hasattr(delta, "content") and delta.content:
delta_content = delta.content
finish_reason = getattr(choice, "finish_reason", None)
# Track usage from last chunk
if hasattr(chunk, "usage") and chunk.usage:
last_usage = self._parse_usage(chunk.usage)
yield StreamChunk(
id=getattr(chunk, "id", ""),
delta_content=delta_content,
finish_reason=finish_reason,
usage=last_usage if finish_reason else None,
)
except Exception as e:
self.logger.error(f"Stream error: {e}")
yield StreamChunk(id="", error=str(e))
async def chat_async(
self,
model: str,
messages: List[ChatMessage],
stream: bool = False,
max_tokens: Optional[int] = None,
temperature: Optional[float] = None,
tools: Optional[List[Dict[str, Any]]] = None,
tool_choice: Optional[str] = None,
**kwargs: Any,
) -> Union[ChatResponse, AsyncIterator[StreamChunk]]:
"""
Send an async chat completion request.
Note: Currently wraps the sync implementation.
TODO: Implement true async support when OpenRouter SDK supports it.
Args:
model: Model ID to use
messages: List of chat messages
stream: Whether to stream the response
max_tokens: Maximum tokens in response
temperature: Sampling temperature
tools: List of tool definitions
tool_choice: Tool selection mode
**kwargs: Additional parameters
Returns:
ChatResponse for non-streaming, AsyncIterator for streaming
"""
# For now, use sync implementation
# TODO: Add true async when SDK supports it
result = self.chat(
model=model,
messages=messages,
stream=stream,
max_tokens=max_tokens,
temperature=temperature,
tools=tools,
tool_choice=tool_choice,
**kwargs,
)
if stream and isinstance(result, Iterator):
# Convert sync iterator to async
async def async_iter() -> AsyncIterator[StreamChunk]:
for chunk in result:
yield chunk
return async_iter()
return result
def get_credits(self) -> Optional[Dict[str, Any]]:
"""
Get OpenRouter account credit information.
Returns:
Dict with credit info:
- total_credits: Total credits purchased
- used_credits: Credits used
- credits_left: Remaining credits
Raises:
Exception: If API request fails
"""
if not self.api_key:
return None
try:
response = requests.get(
f"{self.base_url}/credits",
headers=self._get_headers(),
timeout=10,
)
response.raise_for_status()
data = response.json().get("data", {})
total_credits = float(data.get("total_credits", 0))
total_usage = float(data.get("total_usage", 0))
credits_left = total_credits - total_usage
return {
"total_credits": total_credits,
"used_credits": total_usage,
"credits_left": credits_left,
"total_credits_formatted": f"${total_credits:.2f}",
"used_credits_formatted": f"${total_usage:.2f}",
"credits_left_formatted": f"${credits_left:.2f}",
}
except Exception as e:
self.logger.error(f"Failed to fetch credits: {e}")
return None
def clear_cache(self) -> None:
"""Clear the models cache to force a refresh."""
self._models_cache = None
self._raw_models_cache = None
self.logger.debug("Models cache cleared")
def get_effective_model_id(self, model_id: str, online_enabled: bool) -> str:
"""
Get the effective model ID with online suffix if needed.
Args:
model_id: Base model ID
online_enabled: Whether online mode is enabled
Returns:
Model ID with :online suffix if applicable
"""
if online_enabled and not model_id.endswith(":online"):
return f"{model_id}:online"
return model_id
def estimate_cost(
self,
model_id: str,
input_tokens: int,
output_tokens: int,
) -> float:
"""
Estimate the cost for a completion.
Args:
model_id: Model ID
input_tokens: Number of input tokens
output_tokens: Number of output tokens
Returns:
Estimated cost in USD
"""
model = self.get_model(model_id)
if model and model.pricing:
input_cost = model.pricing.get("prompt", 0) * input_tokens / 1_000_000
output_cost = model.pricing.get("completion", 0) * output_tokens / 1_000_000
return input_cost + output_cost
# Fallback to default pricing if model not found
from oai.constants import MODEL_PRICING
input_cost = MODEL_PRICING["input"] * input_tokens / 1_000_000
output_cost = MODEL_PRICING["output"] * output_tokens / 1_000_000
return input_cost + output_cost

2
oai/py.typed Normal file
View File

@@ -0,0 +1,2 @@
# Marker file for PEP 561
# This package supports type checking

5
oai/tui/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""Textual TUI interface for oAI."""
from oai.tui.app import oAIChatApp
__all__ = ["oAIChatApp"]

1002
oai/tui/app.py Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,21 @@
"""TUI screens for oAI."""
from oai.tui.screens.config_screen import ConfigScreen
from oai.tui.screens.conversation_selector import ConversationSelectorScreen
from oai.tui.screens.credits_screen import CreditsScreen
from oai.tui.screens.dialogs import AlertDialog, ConfirmDialog, InputDialog
from oai.tui.screens.help_screen import HelpScreen
from oai.tui.screens.model_selector import ModelSelectorScreen
from oai.tui.screens.stats_screen import StatsScreen
__all__ = [
"AlertDialog",
"ConfirmDialog",
"ConfigScreen",
"ConversationSelectorScreen",
"CreditsScreen",
"InputDialog",
"HelpScreen",
"ModelSelectorScreen",
"StatsScreen",
]

View File

@@ -0,0 +1,107 @@
"""Configuration screen for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static
from oai.config.settings import Settings
class ConfigScreen(ModalScreen[None]):
"""Modal screen displaying configuration settings."""
DEFAULT_CSS = """
ConfigScreen {
align: center middle;
}
ConfigScreen > Container {
width: 70;
height: auto;
background: #1e1e1e;
border: solid #555555;
}
ConfigScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
ConfigScreen .content {
width: 100%;
height: auto;
background: #1e1e1e;
padding: 2;
color: #cccccc;
}
ConfigScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def __init__(self, settings: Settings):
super().__init__()
self.settings = settings
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static("[bold]Configuration[/]", classes="header")
with Vertical(classes="content"):
yield Static(self._get_config_text(), markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close", variant="primary")
def _get_config_text(self) -> str:
"""Generate the configuration text."""
from oai.constants import DEFAULT_SYSTEM_PROMPT
# API Key display
api_key_display = "***" + self.settings.api_key[-4:] if self.settings.api_key else "Not set"
# System prompt display
if self.settings.default_system_prompt is None:
system_prompt_display = f"[default] {DEFAULT_SYSTEM_PROMPT[:40]}..."
elif self.settings.default_system_prompt == "":
system_prompt_display = "[blank]"
else:
prompt = self.settings.default_system_prompt
system_prompt_display = prompt[:50] + "..." if len(prompt) > 50 else prompt
return f"""
[bold cyan]═══ CONFIGURATION ═══[/]
[bold]API Key:[/] {api_key_display}
[bold]Base URL:[/] {self.settings.base_url}
[bold]Default Model:[/] {self.settings.default_model or "Not set"}
[bold]System Prompt:[/] {system_prompt_display}
[bold]Streaming:[/] {"on" if self.settings.stream_enabled else "off"}
[bold]Cost Warning:[/] ${self.settings.cost_warning_threshold:.4f}
[bold]Max Tokens:[/] {self.settings.max_tokens}
[bold]Default Online:[/] {"on" if self.settings.default_online_mode else "off"}
[bold]Log Level:[/] {self.settings.log_level}
[dim]Use /config [setting] [value] to modify settings[/]
"""
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

View File

@@ -0,0 +1,205 @@
"""Conversation selector screen for oAI TUI."""
from typing import List, Optional
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, DataTable, Input, Static
class ConversationSelectorScreen(ModalScreen[Optional[dict]]):
"""Modal screen for selecting a saved conversation."""
DEFAULT_CSS = """
ConversationSelectorScreen {
align: center middle;
}
ConversationSelectorScreen > Container {
width: 80%;
height: 70%;
background: #1e1e1e;
border: solid #555555;
layout: vertical;
}
ConversationSelectorScreen .header {
height: 3;
width: 100%;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
content-align: center middle;
}
ConversationSelectorScreen .search-input {
height: 3;
width: 100%;
background: #2a2a2a;
border: solid #555555;
margin: 0 0 1 0;
}
ConversationSelectorScreen .search-input:focus {
border: solid #888888;
}
ConversationSelectorScreen DataTable {
height: 1fr;
width: 100%;
background: #1e1e1e;
border: solid #555555;
}
ConversationSelectorScreen .footer {
height: 5;
width: 100%;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
ConversationSelectorScreen Button {
margin: 0 1;
}
"""
def __init__(self, conversations: List[dict]):
super().__init__()
self.all_conversations = conversations
self.filtered_conversations = conversations
self.selected_conversation: Optional[dict] = None
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static(
f"[bold]Load Conversation[/] [dim]({len(self.all_conversations)} saved)[/]",
classes="header"
)
yield Input(placeholder="Search conversations...", id="search-input", classes="search-input")
yield DataTable(id="conv-table", cursor_type="row", show_header=True, zebra_stripes=True)
with Vertical(classes="footer"):
yield Button("Load", id="load", variant="success")
yield Button("Cancel", id="cancel", variant="error")
def on_mount(self) -> None:
"""Initialize the table when mounted."""
table = self.query_one("#conv-table", DataTable)
# Add columns
table.add_column("#", width=5)
table.add_column("Name", width=40)
table.add_column("Messages", width=12)
table.add_column("Last Saved", width=20)
# Populate table
self._populate_table()
# Focus table if list is small (fits on screen), otherwise focus search
if len(self.all_conversations) <= 10:
table.focus()
else:
search_input = self.query_one("#search-input", Input)
search_input.focus()
def _populate_table(self) -> None:
"""Populate the table with conversations."""
table = self.query_one("#conv-table", DataTable)
table.clear()
for idx, conv in enumerate(self.filtered_conversations, 1):
name = conv.get("name", "Unknown")
message_count = str(conv.get("message_count", 0))
last_saved = conv.get("last_saved", "Unknown")
# Format timestamp if it's a full datetime
if "T" in last_saved or len(last_saved) > 20:
try:
# Truncate to just date and time
last_saved = last_saved[:19].replace("T", " ")
except:
pass
table.add_row(
str(idx),
name,
message_count,
last_saved,
key=str(idx)
)
def on_input_changed(self, event: Input.Changed) -> None:
"""Filter conversations based on search input."""
if event.input.id != "search-input":
return
search_term = event.value.lower()
if not search_term:
self.filtered_conversations = self.all_conversations
else:
self.filtered_conversations = [
c for c in self.all_conversations
if search_term in c.get("name", "").lower()
]
self._populate_table()
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle row selection (click)."""
try:
row_index = int(event.row_key.value) - 1
if 0 <= row_index < len(self.filtered_conversations):
self.selected_conversation = self.filtered_conversations[row_index]
except (ValueError, IndexError):
pass
def on_data_table_row_highlighted(self, event) -> None:
"""Handle row highlight (arrow key navigation)."""
try:
table = self.query_one("#conv-table", DataTable)
if table.cursor_row is not None:
row_data = table.get_row_at(table.cursor_row)
if row_data:
row_index = int(row_data[0]) - 1
if 0 <= row_index < len(self.filtered_conversations):
self.selected_conversation = self.filtered_conversations[row_index]
except (ValueError, IndexError, AttributeError):
pass
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
if event.button.id == "load":
if self.selected_conversation:
self.dismiss(self.selected_conversation)
else:
self.dismiss(None)
else:
self.dismiss(None)
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key == "escape":
self.dismiss(None)
elif event.key == "enter":
# If in search input, move to table
search_input = self.query_one("#search-input", Input)
if search_input.has_focus:
table = self.query_one("#conv-table", DataTable)
table.focus()
# If in table, select current row
else:
table = self.query_one("#conv-table", DataTable)
if table.cursor_row is not None:
try:
row_data = table.get_row_at(table.cursor_row)
if row_data:
row_index = int(row_data[0]) - 1
if 0 <= row_index < len(self.filtered_conversations):
selected = self.filtered_conversations[row_index]
self.dismiss(selected)
except (ValueError, IndexError, AttributeError):
if self.selected_conversation:
self.dismiss(self.selected_conversation)

View File

@@ -0,0 +1,125 @@
"""Credits screen for oAI TUI."""
from typing import Optional, Dict, Any
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static
from oai.core.client import AIClient
class CreditsScreen(ModalScreen[None]):
"""Modal screen displaying account credits."""
DEFAULT_CSS = """
CreditsScreen {
align: center middle;
}
CreditsScreen > Container {
width: 60;
height: auto;
background: #1e1e1e;
border: solid #555555;
}
CreditsScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
CreditsScreen .content {
width: 100%;
height: auto;
background: #1e1e1e;
padding: 2;
color: #cccccc;
}
CreditsScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def __init__(self, client: AIClient):
super().__init__()
self.client = client
self.credits_data: Optional[Dict[str, Any]] = None
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static("[bold]Account Credits[/]", classes="header")
with Vertical(classes="content"):
yield Static("[dim]Loading...[/]", id="credits-content", markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close", variant="primary")
def on_mount(self) -> None:
"""Fetch credits when mounted."""
self.fetch_credits()
def fetch_credits(self) -> None:
"""Fetch and display credits information."""
try:
self.credits_data = self.client.provider.get_credits()
content = self.query_one("#credits-content", Static)
content.update(self._get_credits_text())
except Exception as e:
content = self.query_one("#credits-content", Static)
content.update(f"[red]Error fetching credits:[/]\n{str(e)}")
def _get_credits_text(self) -> str:
"""Generate the credits text."""
if not self.credits_data:
return "[yellow]No credit information available[/]"
total = self.credits_data.get("total_credits", 0)
used = self.credits_data.get("used_credits", 0)
remaining = self.credits_data.get("credits_left", 0)
# Calculate percentage used
if total > 0:
percent_used = (used / total) * 100
percent_remaining = (remaining / total) * 100
else:
percent_used = 0
percent_remaining = 0
# Color code based on remaining credits
if percent_remaining > 50:
remaining_color = "green"
elif percent_remaining > 20:
remaining_color = "yellow"
else:
remaining_color = "red"
return f"""
[bold cyan]═══ OPENROUTER CREDITS ═══[/]
[bold]Total Credits:[/] ${total:.2f}
[bold]Used:[/] ${used:.2f} [dim]({percent_used:.1f}%)[/]
[bold]Remaining:[/] [{remaining_color}]${remaining:.2f}[/] [dim]({percent_remaining:.1f}%)[/]
[dim]Visit openrouter.ai to add more credits[/]
"""
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

236
oai/tui/screens/dialogs.py Normal file
View File

@@ -0,0 +1,236 @@
"""Modal dialog screens for oAI TUI."""
from typing import Callable, Optional
from textual.app import ComposeResult
from textual.containers import Container, Horizontal, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Input, Label, Static
class ConfirmDialog(ModalScreen[bool]):
"""A confirmation dialog with Yes/No buttons."""
DEFAULT_CSS = """
ConfirmDialog {
align: center middle;
}
ConfirmDialog > Container {
width: 60;
height: auto;
background: #2d2d2d;
border: solid #555555;
padding: 2;
}
ConfirmDialog Label {
width: 100%;
content-align: center middle;
margin-bottom: 2;
color: #cccccc;
}
ConfirmDialog Horizontal {
width: 100%;
height: auto;
align: center middle;
}
ConfirmDialog Button {
margin: 0 1;
}
"""
def __init__(
self,
message: str,
title: str = "Confirm",
yes_label: str = "Yes",
no_label: str = "No",
):
super().__init__()
self.message = message
self.title = title
self.yes_label = yes_label
self.no_label = no_label
def compose(self) -> ComposeResult:
"""Compose the dialog."""
with Container():
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
yield Label(self.message)
with Horizontal():
yield Button(self.yes_label, id="yes", variant="success")
yield Button(self.no_label, id="no", variant="error")
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
if event.button.id == "yes":
self.dismiss(True)
else:
self.dismiss(False)
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key == "escape":
self.dismiss(False)
elif event.key == "enter":
self.dismiss(True)
class InputDialog(ModalScreen[Optional[str]]):
"""An input dialog for text entry."""
DEFAULT_CSS = """
InputDialog {
align: center middle;
}
InputDialog > Container {
width: 70;
height: auto;
background: #2d2d2d;
border: solid #555555;
padding: 2;
}
InputDialog Label {
width: 100%;
margin-bottom: 1;
color: #cccccc;
}
InputDialog Input {
width: 100%;
margin-bottom: 2;
background: #3a3a3a;
border: solid #555555;
}
InputDialog Input:focus {
border: solid #888888;
}
InputDialog Horizontal {
width: 100%;
height: auto;
align: center middle;
}
InputDialog Button {
margin: 0 1;
}
"""
def __init__(
self,
message: str,
title: str = "Input",
default: str = "",
placeholder: str = "",
):
super().__init__()
self.message = message
self.title = title
self.default = default
self.placeholder = placeholder
def compose(self) -> ComposeResult:
"""Compose the dialog."""
with Container():
yield Static(f"[bold]{self.title}[/]", classes="dialog-title")
yield Label(self.message)
yield Input(
value=self.default,
placeholder=self.placeholder,
id="input-field"
)
with Horizontal():
yield Button("OK", id="ok", variant="primary")
yield Button("Cancel", id="cancel")
def on_mount(self) -> None:
"""Focus the input field when mounted."""
input_field = self.query_one("#input-field", Input)
input_field.focus()
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
if event.button.id == "ok":
input_field = self.query_one("#input-field", Input)
self.dismiss(input_field.value)
else:
self.dismiss(None)
def on_input_submitted(self, event: Input.Submitted) -> None:
"""Handle Enter key in input field."""
self.dismiss(event.value)
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key == "escape":
self.dismiss(None)
class AlertDialog(ModalScreen[None]):
"""A simple alert/message dialog."""
DEFAULT_CSS = """
AlertDialog {
align: center middle;
}
AlertDialog > Container {
width: 60;
height: auto;
background: #2d2d2d;
border: solid #555555;
padding: 2;
}
AlertDialog Label {
width: 100%;
content-align: center middle;
margin-bottom: 2;
color: #cccccc;
}
AlertDialog Horizontal {
width: 100%;
height: auto;
align: center middle;
}
"""
def __init__(self, message: str, title: str = "Alert", variant: str = "default"):
super().__init__()
self.message = message
self.title = title
self.variant = variant
def compose(self) -> ComposeResult:
"""Compose the dialog."""
# Choose color based on variant (using design system)
color = "$primary"
if self.variant == "error":
color = "$error"
elif self.variant == "success":
color = "$success"
elif self.variant == "warning":
color = "$warning"
with Container():
yield Static(f"[bold {color}]{self.title}[/]", classes="dialog-title")
yield Label(self.message)
with Horizontal():
yield Button("OK", id="ok", variant="primary")
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

View File

@@ -0,0 +1,138 @@
"""Help screen for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static
class HelpScreen(ModalScreen[None]):
"""Modal screen displaying help and commands."""
DEFAULT_CSS = """
HelpScreen {
align: center middle;
}
HelpScreen > Container {
width: 90%;
height: 85%;
background: #1e1e1e;
border: solid #555555;
}
HelpScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
HelpScreen .content {
height: 1fr;
background: #1e1e1e;
padding: 2;
overflow-y: auto;
color: #cccccc;
}
HelpScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static("[bold]oAI Help & Commands[/]", classes="header")
with Vertical(classes="content"):
yield Static(self._get_help_text(), markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close", variant="primary")
def _get_help_text(self) -> str:
"""Generate the help text."""
return """
[bold cyan]═══ KEYBOARD SHORTCUTS ═══[/]
[bold]F1[/] Show this help (Ctrl+H may not work)
[bold]F2[/] Open model selector (Ctrl+M may not work)
[bold]Ctrl+S[/] Show session statistics
[bold]Ctrl+L[/] Clear chat display
[bold]Ctrl+P[/] Show previous message
[bold]Ctrl+N[/] Show next message
[bold]Ctrl+Q[/] Quit application
[bold]Up/Down[/] Navigate input history
[bold]ESC[/] Close dialogs
[dim]Note: Some Ctrl keys may be captured by your terminal[/]
[bold cyan]═══ SLASH COMMANDS ═══[/]
[bold yellow]Session Control:[/]
/reset Clear conversation history (with confirmation)
/clear Clear the chat display
/memory on/off Toggle conversation memory
/online on/off Toggle online search mode
/exit, /quit, /bye Exit the application
[bold yellow]Model & Configuration:[/]
/model [search] Open model selector with optional search
/config View configuration settings
/config api Set API key (prompts for input)
/config stream on Enable streaming responses
/system [prompt] Set session system prompt
/maxtoken [n] Set session token limit
[bold yellow]Conversation Management:[/]
/save [name] Save current conversation
/load [name] Load saved conversation (shows picker if no name)
/list List all saved conversations
/delete <name> Delete a saved conversation
[bold yellow]Export:[/]
/export md [file] Export as Markdown
/export json [file] Export as JSON
/export html [file] Export as HTML
[bold yellow]History Navigation:[/]
/prev Show previous message in history
/next Show next message in history
[bold yellow]MCP (Model Context Protocol):[/]
/mcp on Enable MCP file access
/mcp off Disable MCP
/mcp status Show MCP status
/mcp add <path> Add folder for file access
/mcp list List registered folders
/mcp write Toggle write permissions
[bold yellow]Information & Utilities:[/]
/help Show this help screen
/stats Show session statistics
/credits Check account credits
/retry Retry last prompt
/paste Paste from clipboard and send
[bold cyan]═══ TIPS ═══[/]
• Type [bold]/[/] to see command suggestions with [bold]Tab[/] to autocomplete
• Use [bold]Up/Down arrows[/] to navigate your input history
• Type [bold]//[/] at start to escape commands (sends /help as literal message)
• All messages support [bold]Markdown formatting[/] with syntax highlighting
• Responses stream in real-time for better interactivity
• Enable MCP to let AI access your local files and databases
• Use [bold]F1[/] or [bold]F2[/] if Ctrl shortcuts don't work in your terminal
"""
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

View File

@@ -0,0 +1,254 @@
"""Model selector screen for oAI TUI."""
from typing import List, Optional
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, DataTable, Input, Label, Static
class ModelSelectorScreen(ModalScreen[Optional[dict]]):
"""Modal screen for selecting an AI model."""
DEFAULT_CSS = """
ModelSelectorScreen {
align: center middle;
}
ModelSelectorScreen > Container {
width: 90%;
height: 85%;
background: #1e1e1e;
border: solid #555555;
layout: vertical;
}
ModelSelectorScreen .header {
height: 3;
width: 100%;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
content-align: center middle;
}
ModelSelectorScreen .search-input {
height: 3;
width: 100%;
background: #2a2a2a;
border: solid #555555;
margin: 0 0 1 0;
}
ModelSelectorScreen .search-input:focus {
border: solid #888888;
}
ModelSelectorScreen DataTable {
height: 1fr;
width: 100%;
background: #1e1e1e;
border: solid #555555;
}
ModelSelectorScreen .footer {
height: 5;
width: 100%;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
ModelSelectorScreen Button {
margin: 0 1;
}
"""
def __init__(self, models: List[dict], current_model: Optional[str] = None):
super().__init__()
self.all_models = models
self.filtered_models = models
self.current_model = current_model
self.selected_model: Optional[dict] = None
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static(
f"[bold]Select Model[/] [dim]({len(self.all_models)} available)[/]",
classes="header"
)
yield Input(placeholder="Search to filter models...", id="search-input", classes="search-input")
yield DataTable(id="model-table", cursor_type="row", show_header=True, zebra_stripes=True)
with Vertical(classes="footer"):
yield Button("Select", id="select", variant="success")
yield Button("Cancel", id="cancel", variant="error")
def on_mount(self) -> None:
"""Initialize the table when mounted."""
table = self.query_one("#model-table", DataTable)
# Add columns
table.add_column("#", width=5)
table.add_column("Model ID", width=35)
table.add_column("Name", width=30)
table.add_column("Context", width=10)
table.add_column("Price", width=12)
table.add_column("Img", width=4)
table.add_column("Tools", width=6)
table.add_column("Online", width=7)
# Populate table
self._populate_table()
# Focus table if list is small (fits on screen), otherwise focus search
if len(self.filtered_models) <= 20:
table.focus()
else:
search_input = self.query_one("#search-input", Input)
search_input.focus()
def _populate_table(self) -> None:
"""Populate the table with models."""
table = self.query_one("#model-table", DataTable)
table.clear()
rows_added = 0
for idx, model in enumerate(self.filtered_models, 1):
try:
model_id = model.get("id", "")
name = model.get("name", "")
context = str(model.get("context_length", "N/A"))
# Format pricing
pricing = model.get("pricing", {})
prompt_price = pricing.get("prompt", "0")
completion_price = pricing.get("completion", "0")
# Convert to numbers and format
try:
prompt = float(prompt_price) * 1000000 # Convert to per 1M tokens
completion = float(completion_price) * 1000000
if prompt == 0 and completion == 0:
price = "Free"
else:
price = f"${prompt:.2f}/${completion:.2f}"
except:
price = "N/A"
# Check capabilities
architecture = model.get("architecture", {})
modality = architecture.get("modality", "")
supported_params = model.get("supported_parameters", [])
# Vision support: check if modality contains "image"
supports_vision = "image" in modality
# Tool support: check if "tools" or "tool_choice" in supported_parameters
supports_tools = "tools" in supported_params or "tool_choice" in supported_params
# Online support: check if model can use :online suffix (most models can)
# Models that already have :online in their ID support it
supports_online = ":online" in model_id or model_id not in ["openrouter/free"]
# Format capability indicators
img_indicator = "" if supports_vision else "-"
tools_indicator = "" if supports_tools else "-"
web_indicator = "" if supports_online else "-"
# Add row
table.add_row(
str(idx),
model_id,
name,
context,
price,
img_indicator,
tools_indicator,
web_indicator,
key=str(idx)
)
rows_added += 1
except Exception:
# Silently skip rows that fail to add
pass
def on_input_changed(self, event: Input.Changed) -> None:
"""Filter models based on search input."""
if event.input.id != "search-input":
return
search_term = event.value.lower()
if not search_term:
self.filtered_models = self.all_models
else:
self.filtered_models = [
m for m in self.all_models
if search_term in m.get("id", "").lower()
or search_term in m.get("name", "").lower()
]
self._populate_table()
def on_data_table_row_selected(self, event: DataTable.RowSelected) -> None:
"""Handle row selection (click or arrow navigation)."""
try:
row_index = int(event.row_key.value) - 1
if 0 <= row_index < len(self.filtered_models):
self.selected_model = self.filtered_models[row_index]
except (ValueError, IndexError):
pass
def on_data_table_row_highlighted(self, event) -> None:
"""Handle row highlight (arrow key navigation)."""
try:
table = self.query_one("#model-table", DataTable)
if table.cursor_row is not None:
row_data = table.get_row_at(table.cursor_row)
if row_data:
row_index = int(row_data[0]) - 1
if 0 <= row_index < len(self.filtered_models):
self.selected_model = self.filtered_models[row_index]
except (ValueError, IndexError, AttributeError):
pass
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
if event.button.id == "select":
if self.selected_model:
self.dismiss(self.selected_model)
else:
# No selection, dismiss without result
self.dismiss(None)
else:
self.dismiss(None)
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key == "escape":
self.dismiss(None)
elif event.key == "enter":
# If in search input, move to table
search_input = self.query_one("#search-input", Input)
if search_input.has_focus:
table = self.query_one("#model-table", DataTable)
table.focus()
# If in table or anywhere else, select current row
else:
table = self.query_one("#model-table", DataTable)
# Get the currently highlighted row
if table.cursor_row is not None:
try:
row_key = table.get_row_at(table.cursor_row)
if row_key:
row_index = int(row_key[0]) - 1
if 0 <= row_index < len(self.filtered_models):
selected = self.filtered_models[row_index]
self.dismiss(selected)
except (ValueError, IndexError, AttributeError):
# Fall back to previously selected model
if self.selected_model:
self.dismiss(self.selected_model)

View File

@@ -0,0 +1,129 @@
"""Statistics screen for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Container, Vertical
from textual.screen import ModalScreen
from textual.widgets import Button, Static
from oai.core.session import ChatSession
class StatsScreen(ModalScreen[None]):
"""Modal screen displaying session statistics."""
DEFAULT_CSS = """
StatsScreen {
align: center middle;
}
StatsScreen > Container {
width: 70;
height: auto;
background: #1e1e1e;
border: solid #555555;
}
StatsScreen .header {
dock: top;
width: 100%;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 2;
}
StatsScreen .content {
width: 100%;
height: auto;
background: #1e1e1e;
padding: 2;
color: #cccccc;
}
StatsScreen .footer {
dock: bottom;
width: 100%;
height: auto;
background: #2d2d2d;
padding: 1 2;
align: center middle;
}
"""
def __init__(self, session: ChatSession):
super().__init__()
self.session = session
def compose(self) -> ComposeResult:
"""Compose the screen."""
with Container():
yield Static("[bold]Session Statistics[/]", classes="header")
with Vertical(classes="content"):
yield Static(self._get_stats_text(), markup=True)
with Vertical(classes="footer"):
yield Button("Close", id="close", variant="primary")
def _get_stats_text(self) -> str:
"""Generate the statistics text."""
stats = self.session.stats
# Calculate averages
avg_input = stats.total_input_tokens // stats.message_count if stats.message_count > 0 else 0
avg_output = stats.total_output_tokens // stats.message_count if stats.message_count > 0 else 0
avg_cost = stats.total_cost / stats.message_count if stats.message_count > 0 else 0
# Get model info
model_name = "None"
model_context = "N/A"
if self.session.selected_model:
model_name = self.session.selected_model.get("name", "Unknown")
model_context = str(self.session.selected_model.get("context_length", "N/A"))
# MCP status
mcp_status = "Disabled"
if self.session.mcp_manager and self.session.mcp_manager.enabled:
mode = self.session.mcp_manager.mode
if mode == "files":
write = " (Write)" if self.session.mcp_manager.write_enabled else ""
mcp_status = f"Enabled - Files{write}"
elif mode == "database":
db_idx = self.session.mcp_manager.selected_db_index
if db_idx is not None:
db_name = self.session.mcp_manager.databases[db_idx]["name"]
mcp_status = f"Enabled - Database ({db_name})"
return f"""
[bold cyan]═══ SESSION INFO ═══[/]
[bold]Messages:[/] {stats.message_count}
[bold]Current Model:[/] {model_name}
[bold]Context Length:[/] {model_context}
[bold]Memory:[/] {"Enabled" if self.session.memory_enabled else "Disabled"}
[bold]Online Mode:[/] {"Enabled" if self.session.online_enabled else "Disabled"}
[bold]MCP:[/] {mcp_status}
[bold cyan]═══ TOKEN USAGE ═══[/]
[bold]Input Tokens:[/] {stats.total_input_tokens:,}
[bold]Output Tokens:[/] {stats.total_output_tokens:,}
[bold]Total Tokens:[/] {stats.total_tokens:,}
[bold]Avg Input/Msg:[/] {avg_input:,}
[bold]Avg Output/Msg:[/] {avg_output:,}
[bold cyan]═══ COSTS ═══[/]
[bold]Total Cost:[/] ${stats.total_cost:.6f}
[bold]Avg Cost/Msg:[/] ${avg_cost:.6f}
[bold cyan]═══ HISTORY ═══[/]
[bold]History Size:[/] {len(self.session.history)} entries
[bold]Current Index:[/] {self.session.current_index + 1 if self.session.history else 0}
[bold]Memory Start:[/] {self.session.memory_start_index + 1}
"""
def on_button_pressed(self, event: Button.Pressed) -> None:
"""Handle button press."""
self.dismiss()
def on_key(self, event) -> None:
"""Handle keyboard shortcuts."""
if event.key in ("escape", "enter"):
self.dismiss()

167
oai/tui/styles.tcss Normal file
View File

@@ -0,0 +1,167 @@
/* Textual CSS for oAI TUI - Using Textual Design System */
Screen {
background: $background;
overflow: hidden;
}
Header {
dock: top;
height: auto;
background: #2d2d2d;
color: #cccccc;
padding: 0 1;
border-bottom: solid #555555;
}
ChatDisplay {
background: $background;
border: none;
padding: 1;
scrollbar-background: $background;
scrollbar-color: $primary;
overflow-y: auto;
}
UserMessageWidget {
margin: 0 0 1 0;
padding: 1;
background: $surface;
border-left: thick $success;
height: auto;
}
SystemMessageWidget {
margin: 0 0 1 0;
padding: 1;
background: #2a2a2a;
border-left: thick #888888;
height: auto;
color: #cccccc;
}
AssistantMessageWidget {
margin: 0 0 1 0;
padding: 1;
background: $panel;
border-left: thick $accent;
height: auto;
}
#assistant-label {
margin-bottom: 1;
color: #cccccc;
}
#assistant-content {
height: auto;
max-height: 100%;
color: $text;
}
InputBar {
dock: bottom;
height: auto;
background: #2d2d2d;
align: center middle;
border-top: solid #555555;
padding: 1;
}
#input-prefix {
width: auto;
padding: 0 1;
content-align: center middle;
color: #888888;
}
#input-prefix.prefix-hidden {
display: none;
}
#chat-input {
width: 85%;
height: 5;
min-height: 5;
background: #3a3a3a;
border: none;
padding: 1 2;
color: #ffffff;
content-align: left top;
}
#chat-input:focus {
background: #404040;
}
#command-dropdown {
display: none;
dock: bottom;
offset-y: -5;
offset-x: 7.5%;
height: auto;
max-height: 12;
width: 85%;
background: #2d2d2d;
border: solid #555555;
padding: 0;
layer: overlay;
}
#command-dropdown.visible {
display: block;
}
#command-dropdown #command-list {
background: #2d2d2d;
border: none;
scrollbar-background: #2d2d2d;
scrollbar-color: #555555;
}
Footer {
dock: bottom;
height: auto;
background: #252525;
color: #888888;
padding: 0 1;
}
/* Button styles */
Button {
height: 3;
min-width: 10;
background: #3a3a3a;
color: #cccccc;
border: none;
}
Button:hover {
background: #4a4a4a;
}
Button:focus {
background: #505050;
}
Button.-primary {
background: #3a3a3a;
}
Button.-success {
background: #2d5016;
color: #90ee90;
}
Button.-success:hover {
background: #3a6b1e;
}
Button.-error {
background: #5a1a1a;
color: #ff6b6b;
}
Button.-error:hover {
background: #6e2222;
}

View File

@@ -0,0 +1,17 @@
"""TUI widgets for oAI."""
from oai.tui.widgets.chat_display import ChatDisplay
from oai.tui.widgets.footer import Footer
from oai.tui.widgets.header import Header
from oai.tui.widgets.input_bar import InputBar
from oai.tui.widgets.message import AssistantMessageWidget, SystemMessageWidget, UserMessageWidget
__all__ = [
"ChatDisplay",
"Footer",
"Header",
"InputBar",
"UserMessageWidget",
"SystemMessageWidget",
"AssistantMessageWidget",
]

View File

@@ -0,0 +1,21 @@
"""Chat display widget for oAI TUI."""
from textual.containers import ScrollableContainer
from textual.widgets import Static
class ChatDisplay(ScrollableContainer):
"""Scrollable container for chat messages."""
def __init__(self):
super().__init__(id="chat-display")
async def add_message(self, widget: Static) -> None:
"""Add a message widget to the display."""
await self.mount(widget)
self.scroll_end(animate=False)
def clear_messages(self) -> None:
"""Clear all messages from the display."""
for child in list(self.children):
child.remove()

View File

@@ -0,0 +1,178 @@
"""Command dropdown menu for TUI input."""
from textual.app import ComposeResult
from textual.containers import VerticalScroll
from textual.widget import Widget
from textual.widgets import Label, OptionList
from textual.widgets.option_list import Option
from oai.commands import registry
class CommandDropdown(VerticalScroll):
"""Dropdown menu showing available commands."""
DEFAULT_CSS = """
CommandDropdown {
display: none;
height: auto;
max-height: 12;
width: 80;
background: #2d2d2d;
border: solid #555555;
padding: 0;
layer: overlay;
}
CommandDropdown.visible {
display: block;
}
CommandDropdown OptionList {
height: auto;
max-height: 12;
background: #2d2d2d;
border: none;
padding: 0;
}
CommandDropdown OptionList > .option-list--option {
padding: 0 2;
color: #cccccc;
background: transparent;
}
CommandDropdown OptionList > .option-list--option-highlighted {
background: #3e3e3e;
color: #ffffff;
}
"""
def __init__(self):
"""Initialize the command dropdown."""
super().__init__(id="command-dropdown")
self._all_commands = []
self._load_commands()
def _load_commands(self) -> None:
"""Load all available commands."""
# Get base commands with descriptions
base_commands = [
("/help", "Show help screen"),
("/model", "Select AI model"),
("/stats", "Show session statistics"),
("/credits", "Check account credits"),
("/clear", "Clear chat display"),
("/reset", "Reset conversation history"),
("/memory on", "Enable conversation memory"),
("/memory off", "Disable memory"),
("/online on", "Enable online search"),
("/online off", "Disable online search"),
("/save", "Save current conversation"),
("/load", "Load saved conversation"),
("/list", "List saved conversations"),
("/delete", "Delete a conversation"),
("/export md", "Export as Markdown"),
("/export json", "Export as JSON"),
("/export html", "Export as HTML"),
("/prev", "Show previous message"),
("/next", "Show next message"),
("/config", "View configuration"),
("/config api", "Set API key"),
("/system", "Set system prompt"),
("/maxtoken", "Set token limit"),
("/retry", "Retry last prompt"),
("/paste", "Paste from clipboard"),
("/mcp on", "Enable MCP file access"),
("/mcp off", "Disable MCP"),
("/mcp status", "Show MCP status"),
("/mcp add", "Add folder/database"),
("/mcp remove", "Remove folder/database"),
("/mcp list", "List folders"),
("/mcp write on", "Enable write mode"),
("/mcp write off", "Disable write mode"),
("/mcp files", "Switch to file mode"),
("/mcp db list", "List databases"),
]
self._all_commands = base_commands
def compose(self) -> ComposeResult:
"""Compose the dropdown."""
yield OptionList(id="command-list")
def show_commands(self, filter_text: str = "") -> None:
"""Show commands matching the filter.
Args:
filter_text: Text to filter commands by
"""
option_list = self.query_one("#command-list", OptionList)
option_list.clear_options()
if not filter_text.startswith("/"):
self.remove_class("visible")
return
# Remove the leading slash for filtering
filter_without_slash = filter_text[1:].lower()
# Filter commands - show if filter text is contained anywhere in the command
if filter_without_slash:
matching = [
(cmd, desc) for cmd, desc in self._all_commands
if filter_without_slash in cmd[1:].lower() # Skip the / in command for matching
]
else:
# Show all commands when just "/" is typed
matching = self._all_commands
if not matching:
self.remove_class("visible")
return
# Add options - limit to 10 results
for cmd, desc in matching[:10]:
# Format: command in white, description in gray, separated by spaces
label = f"{cmd} [dim]{desc}[/]" if desc else cmd
option_list.add_option(Option(label, id=cmd))
self.add_class("visible")
# Auto-select first option
if len(option_list._options) > 0:
option_list.highlighted = 0
def hide(self) -> None:
"""Hide the dropdown."""
self.remove_class("visible")
def get_selected_command(self) -> str | None:
"""Get the currently selected command.
Returns:
Selected command text or None
"""
option_list = self.query_one("#command-list", OptionList)
if option_list.highlighted is not None:
option = option_list.get_option_at_index(option_list.highlighted)
return option.id
return None
def move_selection_up(self) -> None:
"""Move selection up in the list."""
option_list = self.query_one("#command-list", OptionList)
if option_list.option_count > 0:
if option_list.highlighted is None:
option_list.highlighted = option_list.option_count - 1
elif option_list.highlighted > 0:
option_list.highlighted -= 1
def move_selection_down(self) -> None:
"""Move selection down in the list."""
option_list = self.query_one("#command-list", OptionList)
if option_list.option_count > 0:
if option_list.highlighted is None:
option_list.highlighted = 0
elif option_list.highlighted < option_list.option_count - 1:
option_list.highlighted += 1

View File

@@ -0,0 +1,58 @@
"""Command suggester for TUI input."""
from typing import Iterable
from textual.suggester import Suggester
from oai.commands import registry
class CommandSuggester(Suggester):
"""Suggester that provides command completions."""
def __init__(self):
"""Initialize the command suggester."""
super().__init__(use_cache=False, case_sensitive=False)
# Get all command names from registry
self._commands = []
self._update_commands()
def _update_commands(self) -> None:
"""Update the list of available commands."""
# Get all registered command names
command_names = registry.get_all_names()
# Add common MCP subcommands for better UX
mcp_subcommands = [
"/mcp on",
"/mcp off",
"/mcp status",
"/mcp add",
"/mcp remove",
"/mcp list",
"/mcp write on",
"/mcp write off",
"/mcp files",
"/mcp db list",
]
self._commands = command_names + mcp_subcommands
async def get_suggestion(self, value: str) -> str | None:
"""Get a command suggestion based on the current input.
Args:
value: Current input value
Returns:
Suggested completion or None
"""
if not value or not value.startswith("/"):
return None
# Find the first command that starts with the input
value_lower = value.lower()
for cmd in self._commands:
if cmd.lower().startswith(value_lower) and cmd.lower() != value_lower:
# Return the rest of the command (after what's already typed)
return cmd[len(value):]
return None

39
oai/tui/widgets/footer.py Normal file
View File

@@ -0,0 +1,39 @@
"""Footer widget for oAI TUI."""
from textual.app import ComposeResult
from textual.widgets import Static
class Footer(Static):
"""Footer displaying session metrics."""
def __init__(self):
super().__init__()
self.tokens_in = 0
self.tokens_out = 0
self.cost = 0.0
self.messages = 0
def compose(self) -> ComposeResult:
"""Compose the footer."""
yield Static(self._format_footer(), id="footer-content")
def _format_footer(self) -> str:
"""Format the footer text."""
return (
f"[dim]Messages: {self.messages} | "
f"Tokens: {self.tokens_in + self.tokens_out:,} "
f"({self.tokens_in:,} in, {self.tokens_out:,} out) | "
f"Cost: ${self.cost:.4f}[/]"
)
def update_stats(
self, tokens_in: int, tokens_out: int, cost: float, messages: int
) -> None:
"""Update the displayed statistics."""
self.tokens_in = tokens_in
self.tokens_out = tokens_out
self.cost = cost
self.messages = messages
content = self.query_one("#footer-content", Static)
content.update(self._format_footer())

65
oai/tui/widgets/header.py Normal file
View File

@@ -0,0 +1,65 @@
"""Header widget for oAI TUI."""
from textual.app import ComposeResult
from textual.widgets import Static
from typing import Optional, Dict, Any
class Header(Static):
"""Header displaying app title, version, current model, and capabilities."""
def __init__(self, version: str = "3.0.0", model: str = "", model_info: Optional[Dict[str, Any]] = None):
super().__init__()
self.version = version
self.model = model
self.model_info = model_info or {}
def compose(self) -> ComposeResult:
"""Compose the header."""
yield Static(self._format_header(), id="header-content")
def _format_capabilities(self) -> str:
"""Format capability icons based on model info."""
if not self.model_info:
return ""
icons = []
# Check vision support
architecture = self.model_info.get("architecture", {})
modality = architecture.get("modality", "")
if "image" in modality:
icons.append("[bold cyan]👁️[/]") # Bright if supported
else:
icons.append("[dim]👁️[/]") # Dim if not supported
# Check tool support
supported_params = self.model_info.get("supported_parameters", [])
if "tools" in supported_params or "tool_choice" in supported_params:
icons.append("[bold cyan]🔧[/]")
else:
icons.append("[dim]🔧[/]")
# Check online support (most models support :online suffix)
model_id = self.model_info.get("id", "")
if ":online" in model_id or model_id not in ["openrouter/free"]:
icons.append("[bold cyan]🌐[/]")
else:
icons.append("[dim]🌐[/]")
return " ".join(icons) if icons else ""
def _format_header(self) -> str:
"""Format the header text."""
model_text = f" | {self.model}" if self.model else ""
capabilities = self._format_capabilities()
capabilities_text = f" {capabilities}" if capabilities else ""
return f"[bold cyan]oAI[/] [dim]v{self.version}[/]{model_text}{capabilities_text}"
def update_model(self, model: str, model_info: Optional[Dict[str, Any]] = None) -> None:
"""Update the displayed model and capabilities."""
self.model = model
if model_info:
self.model_info = model_info
content = self.query_one("#header-content", Static)
content.update(self._format_header())

View File

@@ -0,0 +1,49 @@
"""Input bar widget for oAI TUI."""
from textual.app import ComposeResult
from textual.containers import Horizontal
from textual.widgets import Input, Static
class InputBar(Horizontal):
"""Input bar with prompt prefix and text input."""
def __init__(self):
super().__init__(id="input-bar")
self.mcp_status = ""
self.online_mode = False
def compose(self) -> ComposeResult:
"""Compose the input bar."""
yield Static(self._format_prefix(), id="input-prefix", classes="prefix-hidden" if not (self.mcp_status or self.online_mode) else "")
yield Input(
placeholder="Type a message or /command...",
id="chat-input"
)
def _format_prefix(self) -> str:
"""Format the input prefix with status indicators."""
indicators = []
if self.mcp_status:
indicators.append(f"[cyan]{self.mcp_status}[/]")
if self.online_mode:
indicators.append("[green]🌐[/]")
prefix = " ".join(indicators) + " " if indicators else ""
return f"{prefix}[bold]>[/]"
def update_mcp_status(self, status: str) -> None:
"""Update MCP status indicator."""
self.mcp_status = status
prefix = self.query_one("#input-prefix", Static)
prefix.update(self._format_prefix())
def update_online_mode(self, online: bool) -> None:
"""Update online mode indicator."""
self.online_mode = online
prefix = self.query_one("#input-prefix", Static)
prefix.update(self._format_prefix())
def get_input(self) -> Input:
"""Get the input widget."""
return self.query_one("#chat-input", Input)

View File

@@ -0,0 +1,69 @@
"""Message widgets for oAI TUI."""
from typing import Any, AsyncIterator, Tuple
from rich.markdown import Markdown
from textual.app import ComposeResult
from textual.widgets import RichLog, Static
class UserMessageWidget(Static):
"""Widget for displaying user messages."""
def __init__(self, content: str):
super().__init__()
self.content = content
def compose(self) -> ComposeResult:
"""Compose the user message."""
yield Static(f"[bold green]You:[/] {self.content}")
class SystemMessageWidget(Static):
"""Widget for displaying system/info messages without 'You:' prefix."""
def __init__(self, content: str):
super().__init__()
self.content = content
def compose(self) -> ComposeResult:
"""Compose the system message."""
yield Static(self.content)
class AssistantMessageWidget(Static):
"""Widget for displaying assistant responses with streaming support."""
def __init__(self, model_name: str = "Assistant"):
super().__init__()
self.model_name = model_name
self.full_text = ""
def compose(self) -> ComposeResult:
"""Compose the assistant message."""
yield Static(f"[bold]{self.model_name}:[/]", id="assistant-label")
yield RichLog(id="assistant-content", highlight=True, markup=True, wrap=True)
async def stream_response(self, response_iterator: AsyncIterator) -> Tuple[str, Any]:
"""Stream tokens progressively and return final text and usage."""
log = self.query_one("#assistant-content", RichLog)
self.full_text = ""
usage = None
async for chunk in response_iterator:
if hasattr(chunk, "delta_content") and chunk.delta_content:
self.full_text += chunk.delta_content
log.clear()
log.write(Markdown(self.full_text))
if hasattr(chunk, "usage") and chunk.usage:
usage = chunk.usage
return self.full_text, usage
def set_content(self, content: str) -> None:
"""Set the complete content (non-streaming)."""
self.full_text = content
log = self.query_one("#assistant-content", RichLog)
log.clear()
log.write(Markdown(content))

20
oai/utils/__init__.py Normal file
View File

@@ -0,0 +1,20 @@
"""
Utility modules for oAI.
This package provides common utilities used throughout the application
including logging, file handling, and export functionality.
"""
from oai.utils.logging import setup_logging, get_logger
from oai.utils.files import read_file_safe, is_binary_file
from oai.utils.export import export_as_markdown, export_as_json, export_as_html
__all__ = [
"setup_logging",
"get_logger",
"read_file_safe",
"is_binary_file",
"export_as_markdown",
"export_as_json",
"export_as_html",
]

248
oai/utils/export.py Normal file
View File

@@ -0,0 +1,248 @@
"""
Export utilities for oAI.
This module provides functions for exporting conversation history
in various formats including Markdown, JSON, and HTML.
"""
import json
import datetime
from typing import List, Dict
from html import escape as html_escape
from oai.constants import APP_VERSION, APP_URL
def export_as_markdown(
session_history: List[Dict[str, str]],
session_system_prompt: str = ""
) -> str:
"""
Export conversation history as Markdown.
Args:
session_history: List of message dictionaries with 'prompt' and 'response'
session_system_prompt: Optional system prompt to include
Returns:
Markdown formatted string
"""
lines = ["# Conversation Export", ""]
if session_system_prompt:
lines.extend([f"**System Prompt:** {session_system_prompt}", ""])
lines.append(f"**Export Date:** {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
lines.append(f"**Messages:** {len(session_history)}")
lines.append("")
lines.append("---")
lines.append("")
for i, entry in enumerate(session_history, 1):
lines.append(f"## Message {i}")
lines.append("")
lines.append("**User:**")
lines.append("")
lines.append(entry.get("prompt", ""))
lines.append("")
lines.append("**Assistant:**")
lines.append("")
lines.append(entry.get("response", ""))
lines.append("")
lines.append("---")
lines.append("")
lines.append(f"*Exported from oAI v{APP_VERSION} - {APP_URL}*")
return "\n".join(lines)
def export_as_json(
session_history: List[Dict[str, str]],
session_system_prompt: str = ""
) -> str:
"""
Export conversation history as JSON.
Args:
session_history: List of message dictionaries
session_system_prompt: Optional system prompt to include
Returns:
JSON formatted string
"""
export_data = {
"export_date": datetime.datetime.now().isoformat(),
"app_version": APP_VERSION,
"system_prompt": session_system_prompt,
"message_count": len(session_history),
"messages": [
{
"index": i + 1,
"prompt": entry.get("prompt", ""),
"response": entry.get("response", ""),
"prompt_tokens": entry.get("prompt_tokens", 0),
"completion_tokens": entry.get("completion_tokens", 0),
"cost": entry.get("msg_cost", 0.0),
}
for i, entry in enumerate(session_history)
],
"totals": {
"prompt_tokens": sum(e.get("prompt_tokens", 0) for e in session_history),
"completion_tokens": sum(e.get("completion_tokens", 0) for e in session_history),
"total_cost": sum(e.get("msg_cost", 0.0) for e in session_history),
}
}
return json.dumps(export_data, indent=2, ensure_ascii=False)
def export_as_html(
session_history: List[Dict[str, str]],
session_system_prompt: str = ""
) -> str:
"""
Export conversation history as styled HTML.
Args:
session_history: List of message dictionaries
session_system_prompt: Optional system prompt to include
Returns:
HTML formatted string with embedded CSS
"""
html_parts = [
"<!DOCTYPE html>",
"<html>",
"<head>",
" <meta charset='UTF-8'>",
" <meta name='viewport' content='width=device-width, initial-scale=1.0'>",
" <title>Conversation Export - oAI</title>",
" <style>",
" * { box-sizing: border-box; }",
" body {",
" font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;",
" max-width: 900px;",
" margin: 40px auto;",
" padding: 20px;",
" background: #f5f5f5;",
" color: #333;",
" }",
" .header {",
" background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);",
" color: white;",
" padding: 30px;",
" border-radius: 10px;",
" margin-bottom: 30px;",
" box-shadow: 0 4px 6px rgba(0,0,0,0.1);",
" }",
" .header h1 {",
" margin: 0 0 10px 0;",
" font-size: 2em;",
" }",
" .export-info {",
" opacity: 0.9;",
" font-size: 0.95em;",
" margin: 5px 0;",
" }",
" .system-prompt {",
" background: #fff3cd;",
" padding: 20px;",
" border-radius: 8px;",
" margin-bottom: 25px;",
" border-left: 5px solid #ffc107;",
" box-shadow: 0 2px 4px rgba(0,0,0,0.05);",
" }",
" .system-prompt strong {",
" color: #856404;",
" display: block;",
" margin-bottom: 10px;",
" font-size: 1.1em;",
" }",
" .message-container { margin-bottom: 20px; }",
" .message {",
" background: white;",
" padding: 20px;",
" border-radius: 8px;",
" box-shadow: 0 2px 4px rgba(0,0,0,0.08);",
" margin-bottom: 12px;",
" }",
" .user-message { border-left: 5px solid #10b981; }",
" .assistant-message { border-left: 5px solid #3b82f6; }",
" .role {",
" font-weight: bold;",
" margin-bottom: 12px;",
" font-size: 1.05em;",
" text-transform: uppercase;",
" letter-spacing: 0.5px;",
" }",
" .user-role { color: #10b981; }",
" .assistant-role { color: #3b82f6; }",
" .content {",
" line-height: 1.8;",
" white-space: pre-wrap;",
" color: #333;",
" }",
" .message-number {",
" color: #6b7280;",
" font-size: 0.85em;",
" margin-bottom: 15px;",
" font-weight: 600;",
" }",
" .footer {",
" text-align: center;",
" margin-top: 40px;",
" padding: 20px;",
" color: #6b7280;",
" font-size: 0.9em;",
" }",
" .footer a { color: #667eea; text-decoration: none; }",
" .footer a:hover { text-decoration: underline; }",
" @media print {",
" body { background: white; }",
" .message { break-inside: avoid; }",
" }",
" </style>",
"</head>",
"<body>",
" <div class='header'>",
" <h1>Conversation Export</h1>",
f" <div class='export-info'>Exported: {datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</div>",
f" <div class='export-info'>Total Messages: {len(session_history)}</div>",
" </div>",
]
if session_system_prompt:
html_parts.extend([
" <div class='system-prompt'>",
" <strong>System Prompt</strong>",
f" <div>{html_escape(session_system_prompt)}</div>",
" </div>",
])
for i, entry in enumerate(session_history, 1):
prompt = html_escape(entry.get("prompt", ""))
response = html_escape(entry.get("response", ""))
html_parts.extend([
" <div class='message-container'>",
f" <div class='message-number'>Message {i} of {len(session_history)}</div>",
" <div class='message user-message'>",
" <div class='role user-role'>User</div>",
f" <div class='content'>{prompt}</div>",
" </div>",
" <div class='message assistant-message'>",
" <div class='role assistant-role'>Assistant</div>",
f" <div class='content'>{response}</div>",
" </div>",
" </div>",
])
html_parts.extend([
" <div class='footer'>",
f" <p>Generated by oAI v{APP_VERSION} &bull; <a href='{APP_URL}'>{APP_URL}</a></p>",
" </div>",
"</body>",
"</html>",
])
return "\n".join(html_parts)

323
oai/utils/files.py Normal file
View File

@@ -0,0 +1,323 @@
"""
File handling utilities for oAI.
This module provides safe file reading, type detection, and other
file-related operations used throughout the application.
"""
import os
import mimetypes
import base64
from pathlib import Path
from typing import Optional, Dict, Any, Tuple
from oai.constants import (
MAX_FILE_SIZE,
CONTENT_TRUNCATION_THRESHOLD,
SUPPORTED_CODE_EXTENSIONS,
ALLOWED_FILE_EXTENSIONS,
)
from oai.utils.logging import get_logger
def is_binary_file(file_path: Path) -> bool:
"""
Check if a file appears to be binary.
Args:
file_path: Path to the file to check
Returns:
True if the file appears to be binary, False otherwise
"""
try:
with open(file_path, "rb") as f:
# Read first 8KB to check for binary content
chunk = f.read(8192)
# Check for null bytes (common in binary files)
if b"\x00" in chunk:
return True
# Try to decode as UTF-8
try:
chunk.decode("utf-8")
return False
except UnicodeDecodeError:
return True
except Exception:
return True
def get_file_type(file_path: Path) -> Tuple[Optional[str], str]:
"""
Determine the MIME type and category of a file.
Args:
file_path: Path to the file
Returns:
Tuple of (mime_type, category) where category is one of:
'image', 'pdf', 'code', 'text', 'binary', 'unknown'
"""
mime_type, _ = mimetypes.guess_type(str(file_path))
ext = file_path.suffix.lower()
if mime_type and mime_type.startswith("image/"):
return mime_type, "image"
elif mime_type == "application/pdf" or ext == ".pdf":
return mime_type or "application/pdf", "pdf"
elif ext in SUPPORTED_CODE_EXTENSIONS:
return mime_type or "text/plain", "code"
elif mime_type and mime_type.startswith("text/"):
return mime_type, "text"
elif is_binary_file(file_path):
return mime_type, "binary"
else:
return mime_type, "unknown"
def read_file_safe(
file_path: Path,
max_size: int = MAX_FILE_SIZE,
truncate_threshold: int = CONTENT_TRUNCATION_THRESHOLD
) -> Dict[str, Any]:
"""
Safely read a file with size limits and truncation support.
Args:
file_path: Path to the file to read
max_size: Maximum file size to read (bytes)
truncate_threshold: Threshold for truncating large files
Returns:
Dictionary containing:
- content: File content (text or base64)
- size: File size in bytes
- truncated: Whether content was truncated
- encoding: 'text', 'base64', or None on error
- error: Error message if reading failed
"""
logger = get_logger()
try:
path = Path(file_path).resolve()
if not path.exists():
return {
"content": None,
"size": 0,
"truncated": False,
"encoding": None,
"error": f"File not found: {path}"
}
if not path.is_file():
return {
"content": None,
"size": 0,
"truncated": False,
"encoding": None,
"error": f"Not a file: {path}"
}
file_size = path.stat().st_size
if file_size > max_size:
return {
"content": None,
"size": file_size,
"truncated": False,
"encoding": None,
"error": f"File too large: {file_size / (1024*1024):.1f}MB (max: {max_size / (1024*1024):.0f}MB)"
}
# Try to read as text first
try:
content = path.read_text(encoding="utf-8")
# Check if truncation is needed
if file_size > truncate_threshold:
lines = content.split("\n")
total_lines = len(lines)
# Keep first 500 lines and last 100 lines
head_lines = 500
tail_lines = 100
if total_lines > (head_lines + tail_lines):
truncated_content = (
"\n".join(lines[:head_lines]) +
f"\n\n... [TRUNCATED: {total_lines - head_lines - tail_lines} lines omitted] ...\n\n" +
"\n".join(lines[-tail_lines:])
)
logger.info(f"Read file (truncated): {path} ({file_size} bytes, {total_lines} lines)")
return {
"content": truncated_content,
"size": file_size,
"truncated": True,
"total_lines": total_lines,
"lines_shown": head_lines + tail_lines,
"encoding": "text",
"error": None
}
logger.info(f"Read file: {path} ({file_size} bytes)")
return {
"content": content,
"size": file_size,
"truncated": False,
"encoding": "text",
"error": None
}
except UnicodeDecodeError:
# File is binary, return base64 encoded
with open(path, "rb") as f:
binary_data = f.read()
b64_content = base64.b64encode(binary_data).decode("utf-8")
logger.info(f"Read binary file: {path} ({file_size} bytes)")
return {
"content": b64_content,
"size": file_size,
"truncated": False,
"encoding": "base64",
"error": None
}
except PermissionError as e:
return {
"content": None,
"size": 0,
"truncated": False,
"encoding": None,
"error": f"Permission denied: {e}"
}
except Exception as e:
logger.error(f"Error reading file {file_path}: {e}")
return {
"content": None,
"size": 0,
"truncated": False,
"encoding": None,
"error": str(e)
}
def get_file_extension(file_path: Path) -> str:
"""
Get the lowercase file extension.
Args:
file_path: Path to the file
Returns:
Lowercase extension including the dot (e.g., '.py')
"""
return file_path.suffix.lower()
def is_allowed_extension(file_path: Path) -> bool:
"""
Check if a file has an allowed extension for attachment.
Args:
file_path: Path to the file
Returns:
True if the extension is allowed, False otherwise
"""
return get_file_extension(file_path) in ALLOWED_FILE_EXTENSIONS
def format_file_size(size_bytes: int) -> str:
"""
Format a file size in human-readable format.
Args:
size_bytes: Size in bytes
Returns:
Formatted string (e.g., '1.5 MB', '512 KB')
"""
for unit in ["B", "KB", "MB", "GB", "TB"]:
if abs(size_bytes) < 1024:
return f"{size_bytes:.1f} {unit}"
size_bytes /= 1024
return f"{size_bytes:.1f} PB"
def prepare_file_attachment(
file_path: Path,
model_capabilities: Dict[str, Any]
) -> Optional[Dict[str, Any]]:
"""
Prepare a file for attachment to an API request.
Args:
file_path: Path to the file
model_capabilities: Model capability information
Returns:
Content block dictionary for the API, or None if unsupported
"""
logger = get_logger()
path = Path(file_path).resolve()
if not path.exists():
logger.warning(f"File not found: {path}")
return None
mime_type, category = get_file_type(path)
file_size = path.stat().st_size
if file_size > MAX_FILE_SIZE:
logger.warning(f"File too large: {path} ({format_file_size(file_size)})")
return None
try:
with open(path, "rb") as f:
file_data = f.read()
if category == "image":
# Check if model supports images
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
if "image" not in input_modalities:
logger.warning(f"Model does not support images")
return None
b64_data = base64.b64encode(file_data).decode("utf-8")
return {
"type": "image_url",
"image_url": {"url": f"data:{mime_type};base64,{b64_data}"}
}
elif category == "pdf":
# Check if model supports PDFs
input_modalities = model_capabilities.get("architecture", {}).get("input_modalities", [])
supports_pdf = any(mod in input_modalities for mod in ["document", "pdf", "file"])
if not supports_pdf:
logger.warning(f"Model does not support PDFs")
return None
b64_data = base64.b64encode(file_data).decode("utf-8")
return {
"type": "image_url",
"image_url": {"url": f"data:application/pdf;base64,{b64_data}"}
}
elif category in ("code", "text"):
text_content = file_data.decode("utf-8")
return {
"type": "text",
"text": f"File: {path.name}\n\n{text_content}"
}
else:
logger.warning(f"Unsupported file type: {category} ({mime_type})")
return None
except UnicodeDecodeError:
logger.error(f"Cannot decode file as UTF-8: {path}")
return None
except Exception as e:
logger.error(f"Error preparing file attachment {path}: {e}")
return None

297
oai/utils/logging.py Normal file
View File

@@ -0,0 +1,297 @@
"""
Logging configuration for oAI.
This module provides centralized logging setup with Rich formatting,
file rotation, and configurable log levels.
"""
import io
import os
import glob
import logging
import datetime
import shutil
from logging.handlers import RotatingFileHandler
from pathlib import Path
from typing import Optional
from rich.console import Console
from rich.logging import RichHandler
from oai.constants import (
LOG_FILE,
CONFIG_DIR,
DEFAULT_LOG_MAX_SIZE_MB,
DEFAULT_LOG_BACKUP_COUNT,
DEFAULT_LOG_LEVEL,
VALID_LOG_LEVELS,
)
class RotatingRichHandler(RotatingFileHandler):
"""
Custom log handler combining file rotation with Rich formatting.
This handler writes Rich-formatted log output to a rotating file,
providing colored and formatted logs even in file output while
managing file size and backups automatically.
"""
def __init__(self, *args, **kwargs):
"""Initialize the handler with Rich console for formatting."""
super().__init__(*args, **kwargs)
# Create an internal console for Rich formatting
self.rich_console = Console(
file=io.StringIO(),
width=120,
force_terminal=False
)
self.rich_handler = RichHandler(
console=self.rich_console,
show_time=True,
show_path=True,
rich_tracebacks=True,
tracebacks_suppress=["requests", "openrouter", "urllib3", "httpx", "openai"]
)
def emit(self, record: logging.LogRecord) -> None:
"""
Emit a log record with Rich formatting.
Args:
record: The log record to emit
"""
try:
# Format with Rich
self.rich_handler.emit(record)
output = self.rich_console.file.getvalue()
self.rich_console.file.seek(0)
self.rich_console.file.truncate(0)
if output:
self.stream.write(output)
self.flush()
except Exception:
self.handleError(record)
class LoggingManager:
"""
Manages application logging configuration.
Provides methods to setup, configure, and manage logging with
support for runtime reconfiguration and level changes.
"""
def __init__(self):
"""Initialize the logging manager."""
self.handler: Optional[RotatingRichHandler] = None
self.app_logger: Optional[logging.Logger] = None
self.max_size_mb: int = DEFAULT_LOG_MAX_SIZE_MB
self.backup_count: int = DEFAULT_LOG_BACKUP_COUNT
self.log_level: str = DEFAULT_LOG_LEVEL
def setup(
self,
max_size_mb: Optional[int] = None,
backup_count: Optional[int] = None,
log_level: Optional[str] = None
) -> logging.Logger:
"""
Setup or reconfigure logging.
Args:
max_size_mb: Maximum log file size in MB
backup_count: Number of backup files to keep
log_level: Logging level string
Returns:
The configured application logger
"""
# Update configuration if provided
if max_size_mb is not None:
self.max_size_mb = max_size_mb
if backup_count is not None:
self.backup_count = backup_count
if log_level is not None:
self.log_level = log_level
# Ensure config directory exists
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
# Get root logger
root_logger = logging.getLogger()
# Remove existing handler if present
if self.handler is not None:
root_logger.removeHandler(self.handler)
try:
self.handler.close()
except Exception:
pass
# Check if log needs manual rotation
self._check_rotation()
# Create new handler
max_bytes = self.max_size_mb * 1024 * 1024
self.handler = RotatingRichHandler(
filename=str(LOG_FILE),
maxBytes=max_bytes,
backupCount=self.backup_count,
encoding="utf-8"
)
self.handler.setLevel(logging.NOTSET)
root_logger.setLevel(logging.WARNING)
root_logger.addHandler(self.handler)
# Suppress noisy third-party loggers
for logger_name in [
"asyncio", "urllib3", "requests", "httpx",
"httpcore", "openai", "openrouter"
]:
logging.getLogger(logger_name).setLevel(logging.WARNING)
# Configure application logger
self.app_logger = logging.getLogger("oai_app")
level = VALID_LOG_LEVELS.get(self.log_level.lower(), logging.INFO)
self.app_logger.setLevel(level)
self.app_logger.propagate = True
return self.app_logger
def _check_rotation(self) -> None:
"""Check if log file needs rotation and perform if necessary."""
if not LOG_FILE.exists():
return
current_size = LOG_FILE.stat().st_size
max_bytes = self.max_size_mb * 1024 * 1024
if current_size >= max_bytes:
# Perform manual rotation
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
backup_file = f"{LOG_FILE}.{timestamp}"
try:
shutil.move(str(LOG_FILE), backup_file)
except Exception:
pass
# Clean old backups
self._cleanup_old_backups()
def _cleanup_old_backups(self) -> None:
"""Remove old backup files exceeding the backup count."""
log_dir = LOG_FILE.parent
backup_pattern = f"{LOG_FILE.name}.*"
backups = sorted(glob.glob(str(log_dir / backup_pattern)))
while len(backups) > self.backup_count:
oldest = backups.pop(0)
try:
os.remove(oldest)
except Exception:
pass
def set_level(self, level: str) -> bool:
"""
Set the application log level.
Args:
level: Log level string (debug/info/warning/error/critical)
Returns:
True if level was set successfully, False otherwise
"""
level_lower = level.lower()
if level_lower not in VALID_LOG_LEVELS:
return False
self.log_level = level_lower
if self.app_logger:
self.app_logger.setLevel(VALID_LOG_LEVELS[level_lower])
return True
def get_logger(self) -> logging.Logger:
"""
Get the application logger, initializing if necessary.
Returns:
The application logger
"""
if self.app_logger is None:
self.setup()
return self.app_logger
# Global logging manager instance
_logging_manager = LoggingManager()
def setup_logging(
max_size_mb: Optional[int] = None,
backup_count: Optional[int] = None,
log_level: Optional[str] = None
) -> logging.Logger:
"""
Setup application logging.
This is the main entry point for configuring logging. Call this
early in application startup.
Args:
max_size_mb: Maximum log file size in MB
backup_count: Number of backup files to keep
log_level: Logging level string
Returns:
The configured application logger
"""
return _logging_manager.setup(max_size_mb, backup_count, log_level)
def get_logger() -> logging.Logger:
"""
Get the application logger.
Returns:
The application logger instance
"""
return _logging_manager.get_logger()
def set_log_level(level: str) -> bool:
"""
Set the application log level.
Args:
level: Log level string
Returns:
True if successful, False otherwise
"""
return _logging_manager.set_level(level)
def reload_logging(
max_size_mb: Optional[int] = None,
backup_count: Optional[int] = None,
log_level: Optional[str] = None
) -> logging.Logger:
"""
Reload logging configuration.
Useful when settings change at runtime.
Args:
max_size_mb: New maximum log file size
backup_count: New backup count
log_level: New log level
Returns:
The reconfigured logger
"""
return _logging_manager.setup(max_size_mb, backup_count, log_level)

134
pyproject.toml Normal file
View File

@@ -0,0 +1,134 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "oai"
version = "3.0.0"
description = "OpenRouter AI Chat Client - A feature-rich terminal-based chat application"
readme = "README.md"
license = {text = "MIT"}
authors = [
{name = "Rune", email = "rune@example.com"}
]
maintainers = [
{name = "Rune", email = "rune@example.com"}
]
keywords = [
"ai",
"chat",
"openrouter",
"cli",
"terminal",
"mcp",
"llm",
]
classifiers = [
"Development Status :: 4 - Beta",
"Environment :: Console",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
"Topic :: Utilities",
]
requires-python = ">=3.10"
dependencies = [
"anyio>=4.0.0",
"click>=8.0.0",
"httpx>=0.24.0",
"markdown-it-py>=3.0.0",
"openrouter>=0.0.19",
"packaging>=21.0",
"pyperclip>=1.8.0",
"requests>=2.28.0",
"rich>=13.0.0",
"textual>=0.50.0",
"typer>=0.9.0",
"mcp>=1.0.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.0.0",
"pytest-asyncio>=0.21.0",
"pytest-cov>=4.0.0",
"black>=23.0.0",
"isort>=5.12.0",
"mypy>=1.0.0",
"ruff>=0.1.0",
]
[project.urls]
Homepage = "https://iurl.no/oai"
Repository = "https://gitlab.pm/rune/oai"
Documentation = "https://iurl.no/oai"
"Bug Tracker" = "https://gitlab.pm/rune/oai/issues"
[project.scripts]
oai = "oai.cli:main"
[tool.setuptools]
packages = ["oai", "oai.commands", "oai.config", "oai.core", "oai.mcp", "oai.providers", "oai.tui", "oai.tui.widgets", "oai.tui.screens", "oai.utils"]
[tool.setuptools.package-data]
oai = ["py.typed"]
[tool.black]
line-length = 100
target-version = ["py310", "py311", "py312"]
include = '\.pyi?$'
exclude = '''
/(
\.git
| \.mypy_cache
| \.pytest_cache
| \.venv
| build
| dist
)/
'''
[tool.isort]
profile = "black"
line_length = 100
skip_gitignore = true
[tool.mypy]
python_version = "3.10"
warn_return_any = true
warn_unused_configs = true
ignore_missing_imports = true
exclude = [
"build",
"dist",
".venv",
]
[tool.ruff]
line-length = 100
target-version = "py310"
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # Pyflakes
"I", # isort
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"UP", # pyupgrade
]
ignore = [
"E501", # line too long (handled by black)
"B008", # do not perform function calls in argument defaults
"C901", # too complex
]
[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = ["test_*.py"]
asyncio_mode = "auto"
addopts = "-v --tb=short"